%% vim: ts=4 sw=4 et
-module(simple_bridge_websocket).
-export([
attempt_hijacking/2,
%% These three are used by cowboy or yaws and should basically never be
%% called except from within the simple_bridge app
call_init/2,
keepalive_timeout/2,
schedule_keepalive_msg/1,
%% Exported for code-reloading
websocket_loop/8
]).
%-compile(export_all).
-include("simple_bridge.hrl").
%% Catch-all guard for `if' expressions, so the last branch can be written as
%% `?otherwise ->'.
-define(otherwise, true).
%% Fixed string appended to the client's Sec-WebSocket-Key before hashing when
%% the handshake response key is computed (see prepare_response_key/1).
-define(WS_MAGIC, "258EAFA5-E914-47DA-95CA-C5AB0DC85B11").
%% The only websocket protocol version this module will accept in the
%% Sec-WebSocket-Version request header.
-define(WS_VERSION, "13").
%% Values of the one-bit MASK flag in a frame header.
-define(WS_MASKED, 1).
-define(WS_UNMASKED, 0).
%% Values of the 7-bit payload length field that signal that the real length
%% follows in a 16-bit or 64-bit extended length field.
-define(WS_EXTENDED_PAYLOAD_16BIT, 126).
-define(WS_EXTENDED_PAYLOAD_64BIT, 127).
%% Frame opcodes handled by this module.
-define(WS_CONTINUATION, 0).
-define(WS_TEXT, 1).
-define(WS_BINARY, 2).
-define(WS_CLOSE, 8).
-define(WS_PING, 9).
-define(WS_PONG, 10).
%% Guard-safe test that is true when Op is none of the opcodes defined above.
-define(IS_INVALID_OPCODE(Op),
not(Op=:=?WS_CONTINUATION orelse
Op=:=?WS_TEXT orelse
Op=:=?WS_BINARY orelse
Op=:=?WS_CLOSE orelse
Op=:=?WS_PING orelse
Op=:=?WS_PONG)).
%% A single websocket frame. `fin', `rsv', `opcode', `masked' and
%% `payload_len' mirror the header fields, `mask_key' is the 32-bit masking key
%% as an integer, and `data' is the payload (unmasked for parsed frames).
-record(frame, {fin=1, rsv=0, opcode, masked=0, payload_len=0, mask_key, data = <<>>}).
%% State carried between iterations of websocket_loop/8: `data' holds bytes
%% that did not yet form a complete frame, `message_frames' holds the frames of
%% a fragmented message that has not yet been completed.
-record(partial_data, {data = <<>>, message_frames=[]}).
%% @doc Decide whether the request carried by `Bridge' is a websocket upgrade
%% request and, if it is, take the connection over.
%%
%% The request is treated as an upgrade only when all of the following hold:
%% the protocol version is at least HTTP/1.1, the Upgrade header is
%% "websocket", the Connection header mentions "upgrade", and a
%% Sec-WebSocket-Version header is present.
%%
%% Returns `spared' when the request is not an upgrade request. Otherwise
%% returns `{hijacked, Result}', where `Result' is what hijack/2 returned (the
%% websocket loop has finished by then) or, when the client did not offer
%% version ?WS_VERSION, the 400 response bridge built by hijack_request_fail/1.
-spec attempt_hijacking(bridge(), Handler :: atom()) -> spared | {hijacked, closed} | {hijacked, bridge()}.
attempt_hijacking(Bridge, Handler) ->
    ProtocolVersion = sbw:protocol_version(Bridge),
    UpgradeHeader = sbw:header_lower(upgrade, Bridge),
    WSVersionHead = sbw:header("Sec-WebSocket-Version", Bridge),
    ConnectionHeaderHasUpgrade = does_connection_header_have_upgrade(Bridge),
    if
        ProtocolVersion >= {1,1},
        UpgradeHeader =:= "websocket",
        ConnectionHeaderHasUpgrade,
        WSVersionHead =/= undefined ->
            %% The header may list several versions ("8, 13"). Split on runs
            %% of commas and spaces. The previous pattern had a stray `]'
            %% after the character class, so it only split on separators
            %% followed by a literal `]' and a multi-version header was never
            %% split at all.
            WSVersions = re:split(WSVersionHead, "[, ]+", [{return, list}]),
            HijackedBridge = case lists:member(?WS_VERSION, WSVersions) of
                true ->
                    hijack(Bridge, Handler);
                false ->
                    hijack_request_fail(Bridge)
            end,
            {hijacked, HijackedBridge};
        ?otherwise ->
            spared %% Spared from being hijacked
    end.
%% Returns true when the (lower-cased) Connection request header exists and
%% contains the substring "upgrade", false otherwise.
does_connection_header_have_upgrade(Bridge) ->
    Header = sbw:header_lower(connection, Bridge),
    Header =/= undefined andalso re:run(Header, "upgrade") =/= nomatch.
%% Run the handler's optional ws_init/1 callback and return the initial
%% handler state. The state is `undefined' when the callback is not exported
%% or when it returns a bare `ok'; `{ok, State}' yields `State'.
call_init(Handler, Bridge) ->
    case erlang:function_exported(Handler, ws_init, 1) of
        false ->
            undefined;
        true ->
            case Handler:ws_init(Bridge) of
                {ok, InitialState} -> InitialState;
                ok -> undefined
            end
    end.
%% Compute the overall inactivity timeout from the keepalive settings.
%%
%% A ping is triggered every KAInterval milliseconds and the peer then has
%% KATimeout milliseconds to send something back (a pong or any other
%% message), so the longest silence that is tolerated is the sum of the two.
%% An `infinity' interval disables keepalive and yields `infinity'.
keepalive_timeout(KAInterval, KATimeout)
        when is_integer(KAInterval) andalso is_integer(KATimeout) ->
    KAInterval + KATimeout;
keepalive_timeout(infinity, _KATimeout) ->
    infinity.
%% Arrange for `simple_bridge_send_ping' to be sent to the calling process
%% after KAInterval milliseconds. Returns whatever timer:send_after/2 returns,
%% or `ok' without scheduling anything when the interval is `infinity'.
schedule_keepalive_msg(KAInterval) when KAInterval =/= infinity ->
    timer:send_after(KAInterval, simple_bridge_send_ping);
schedule_keepalive_msg(infinity) ->
    ok.
%% Cancel a pending pong-timeout timer. `undefined' means no timer is running
%% and is answered with `ok'; for a real timer reference the cancellation is
%% asserted to succeed and `{ok, cancel}' is returned.
cancel_pong_timer(TRef) when TRef =/= undefined ->
    {ok, cancel} = timer:cancel(TRef);
cancel_pong_timer(undefined) ->
    ok.
%% Build the response for an upgrade request that did not offer a supported
%% websocket version: status 400, a Sec-Websocket-Version header advertising
%% ?WS_VERSION, and an explanatory body. Returns the updated bridge.
hijack_request_fail(Bridge) ->
    Body = ["Invalid Websocket Upgrade Request. Please use Websocket version ",?WS_VERSION],
    WithStatus = sbw:set_status_code(400, Bridge),
    WithVersion = sbw:set_header("Sec-Websocket-Version", ?WS_VERSION, WithStatus),
    sbw:set_response_data(Body, WithVersion).
%% Derive the Sec-WebSocket-Accept value from the client's key: append
%% ?WS_MAGIC to the key (a string), hash the result with ?HASH (defined
%% outside this file, presumably in simple_bridge.hrl) and base64-encode the
%% digest.
prepare_response_key(WSKey) ->
    Digest = ?HASH(WSKey ++ ?WS_MAGIC),
    base64:encode(Digest).
%% Take over the connection behind Bridge and run the websocket loop on it.
%%
%% Steps, in order: answer the handshake, switch the socket to
%% `{active, once}', initialise the handler state through call_init/2, start
%% the keepalive ping schedule, and enter the loop. Returns whatever
%% websocket_loop_init/8 returns once the loop has ended.
hijack(Bridge, Handler) ->
    ClientKey = sbw:header("Sec-Websocket-Key", Bridge),
    AcceptKey = prepare_response_key(ClientKey),
    Sock = sbw:socket(Bridge),
    inet:setopts(Sock, [{buffer,65535}]),
    send_handshake_response(Sock, AcceptKey),
    inet:setopts(Sock, [{active, once}]),
    HandlerState = call_init(Handler, Bridge),
    {KAInterval, KATimeout} =
        simple_bridge_util:get_websocket_keepalive_interval_timeout(
            simple_bridge_util:get_env(backend)),
    schedule_keepalive_msg(KAInterval),
    websocket_loop_init(Sock, Bridge, Handler, KAInterval, KATimeout,
                        undefined, HandlerState, #partial_data{}).
%% Write the "101 Switching Protocols" handshake response to Socket, carrying
%% ResponseKey in the Sec-WebSocket-Accept header. Returns the result of
%% gen_tcp:send/2.
send_handshake_response(Socket, ResponseKey) ->
    StatusLine = <<"HTTP/1.1 101 Switching Protocols\r\n">>,
    Headers = [
        <<"Upgrade: websocket\r\n">>,
        <<"Connection: Upgrade\r\n">>,
        [<<"Sec-WebSocket-Accept: ">>, ResponseKey, <<"\r\n">>]
    ],
    gen_tcp:send(Socket, [StatusLine, Headers, <<"\r\n">>]).
%% Run websocket_loop/8 under a handler for the protocol-violation exits
%% raised by close_with_purpose/2. When such an exit is caught, a close frame
%% carrying its code is sent, the socket is closed, and the process exits with
%% reason `normal'. Any other outcome of the loop is returned unchanged.
websocket_loop_init(Socket, Bridge, Handler, KAInterval, KATimeout, PongTimer, State, PartialData) ->
    try
        websocket_loop(Socket, Bridge, Handler, KAInterval, KATimeout,
                       PongTimer, State, PartialData)
    catch
        exit:{websocket, CloseCode, _Details} ->
            send(Socket, {close, CloseCode}),
            gen_tcp:close(Socket),
            %% the close has been handled; end the process quietly
            exit(normal)
    end.
%% Main receive loop of a hijacked websocket connection.
%%
%% Each iteration handles exactly one message and then either recurses through
%% a fully qualified ?MODULE: call (so that a newly loaded version of this
%% module is picked up) or returns `closed'.
%%
%% Arguments:
%%   Socket      - the TCP socket, used in `{active, once}' mode
%%   Bridge      - passed through unchanged to the handler callbacks
%%   Handler     - callback module
%%   KAInterval  - delay between keepalive pings
%%   KATimeout   - how long to wait after a ping before giving up
%%   PongTimer   - timer reference of the pending pong timeout, or `undefined'
%%   State       - handler state
%%   PartialData - #partial_data{} with unparsed bytes and pending fragments
websocket_loop(Socket, Bridge, Handler, KAInterval, KATimeout, PongTimer, State, PartialData) ->
receive
{tcp, Socket, Data} ->
%% Any inbound data proves the peer is alive, so the pong timeout is
%% cancelled before the data is looked at.
cancel_pong_timer(PongTimer),
%% Prepend whatever was left over from the previous read and try to
%% parse the combined bytes into frames.
AttemptPacket = <<(PartialData#partial_data.data)/binary, Data/binary>>,
Frames = parse_frame(AttemptPacket),
PendingFrames = PartialData#partial_data.message_frames,
case process_frames(Frames, Socket, Bridge, Handler, State, PendingFrames) of
{PendingFrames2, RemainderData, NewState} ->
%% Re-arm the socket for the next packet and carry the unparsed
%% remainder and unfinished fragments into the next iteration.
inet:setopts(Socket, [{active, once}]),
?MODULE:websocket_loop(Socket, Bridge, Handler, KAInterval, KATimeout, undefined, NewState, #partial_data{data=RemainderData, message_frames=PendingFrames2});
closed -> closed
end;
{tcp_closed, _Socket} ->
%% NOTE(review): the handler's ws_terminate/3 is not called on this
%% path, unlike the close-frame path in process_frames/6 - confirm
%% that this is intended.
closed;
simple_bridge_pong_timeout ->
%% If this message is received, it means no TCP message was
%% received in the expected timeframe, so we kill the
%% connection. Any TCP message received would have cancelled
%% the timer.
%% NOTE(review): RFC 6455 section 7.4.1 appears to reserve 1006 as a
%% code that must not be sent in a close frame - confirm whether a
%% different code should be used here.
send(Socket, {close, 1006}),
gen_tcp:close(Socket),
closed;
simple_bridge_send_ping ->
%% Keepalive tick: schedule the next tick, ping the peer, and start the
%% timer that fires simple_bridge_pong_timeout unless data arrives
%% first.
Reply = {ping, <<"simple bridge websocket">>},
schedule_keepalive_msg(KAInterval),
send(Socket, Reply),
{ok, NewPongTimer} = timer:send_after(KATimeout, simple_bridge_pong_timeout),
?MODULE:websocket_loop(Socket, Bridge, Handler, KAInterval, KATimeout, NewPongTimer, State, PartialData);
Msg ->
%% Any other Erlang message is offered to the handler's optional
%% ws_info/3 callback; its reply (if any) is sent to the peer.
{Reply, NewState} = call_info(Handler, Bridge, Msg, State),
send(Socket, Reply),
?MODULE:websocket_loop(Socket, Bridge, Handler, KAInterval, KATimeout, PongTimer, NewState, PartialData)
end.
%% Offer a non-websocket Erlang message to the handler's optional ws_info/3
%% callback. Returns a `{Reply, NewState}' pair; when the callback is not
%% exported the pair is `{noreply, State}'.
call_info(Handler, Bridge, Msg, State) ->
    case erlang:function_exported(Handler, ws_info, 3) of
        false ->
            {noreply, State};
        true ->
            Returned = Handler:ws_info(Msg, Bridge, State),
            %% assert that the normalised result is a two-element tuple
            {_, _} = extract_reply_state(State, Returned)
    end.
%% Normalise a handler callback's return value into `{Reply, NewState}'.
%%
%% `Reply' is something send/2 understands (`noreply', `{reply, _}' or
%% `{close, Code}') and `NewState' is the handler state to carry forward;
%% `State' is kept when the handler did not supply a new one.
%%
%% The close forms used to return `{close, 1000}' / `{close, CloseReason}'.
%% Callers destructure the result as `{Reply, NewState}', so the close code
%% became the new handler state and the requested code was never sent (the
%% reply was the bare atom `close', which send/2 maps to 1000). They now return
%% `{{close, Code}, State}'. A non-integer close reason cannot be encoded in a
%% close frame, so it falls back to 1000, which is what was effectively sent
%% before.
%%
%% Any other return value raises a case_clause error.
extract_reply_state(State, InfoMsgReturn) ->
    case InfoMsgReturn of
        noreply -> {noreply, State};
        {noreply, NewState} -> {noreply, NewState};
        {reply, Reply} -> {{reply, Reply}, State};
        {reply, Reply, NewState} -> {{reply, Reply}, NewState};
        close -> {{close, 1000}, State};
        {close, Code} when is_integer(Code) -> {{close, Code}, State};
        {close, _NonIntegerReason} -> {{close, 1000}, State}
    end.
%% Abort processing by exiting with `{websocket, Code, Details}'. The exit is
%% caught in websocket_loop_init/8, which sends Code to the peer in a close
%% frame. This function never returns.
close_with_purpose(CloseCode, Details) ->
    erlang:exit({websocket, CloseCode, Details}).
%% Turn a reply term into one or more websocket frames and write them to
%% Socket.
%%
%% Accepted reply terms:
%%   noreply                    - nothing is sent, returns `do_nothing'
%%   close | {close, Code}      - close frame; a bare `close' uses code 1000
%%   {ping, Data} | {pong, Data}
%%   {reply, {text, Data}} | {reply, {binary, Data}}
%%   {reply, Fragments}         - a list, handed to send_fragments/2
send(_Socket, noreply) ->
    do_nothing;
send(Socket, close) ->
    send(Socket, {close, 1000});
send(Socket, {close, ReasonCode}) ->
    %% the close payload is the status code as a 16-bit integer
    Body = <<ReasonCode:16>>,
    send_frame(Socket, #frame{opcode=?WS_CLOSE, data=Body});
send(Socket, {ping, Payload}) ->
    send_frame(Socket, #frame{opcode=?WS_PING, data=Payload});
send(Socket, {pong, Payload}) ->
    send_frame(Socket, #frame{opcode=?WS_PONG, data=Payload});
send(Socket, {reply, Fragments}) when is_list(Fragments) ->
    send_fragments(Socket, Fragments);
send(Socket, {reply, {text, Payload}}) ->
    send_frame(Socket, #frame{opcode=?WS_TEXT, data=Payload});
send(Socket, {reply, {binary, Payload}}) ->
    send_frame(Socket, #frame{opcode=?WS_BINARY, data=Payload}).
%%send(Socket, {stream, {binary, Fun}}) ->
%% send_frame(#frame{opcode=?WS_BINARY, fin=0, data=Fu
%% Send a list of `{text, Data}' / `{binary, Data}' fragments as one
%% fragmented websocket message.
%%
%% An empty list sends nothing and returns `do_nothing'. With two or more
%% fragments, the first goes out with its own opcode and fin=0 and the rest
%% are sent as continuation frames by send_fragments_rest/2.
send_fragments(_Socket, []) ->
    do_nothing;
%% If it's a list of one fragment, just send a self-contained frame. The
%% fragment has to be wrapped in `{reply, _}' because send/2 has no clause for
%% a bare `{text, _}' / `{binary, _}' tuple; passing it unwrapped crashed with
%% function_clause.
send_fragments(Socket, [{Type, _Data} = Fragment]) when Type =:= text; Type =:= binary ->
    send(Socket, {reply, Fragment});
%% Any other single element is still handed to send/2 unchanged, as before.
send_fragments(Socket, [Other]) ->
    send(Socket, Other);
send_fragments(Socket, [{binary,Data}|T]) ->
    send_frame(Socket, #frame{fin=0, opcode=?WS_BINARY, data=Data}),
    send_fragments_rest(Socket, T);
send_fragments(Socket, [{text,Data}|T]) ->
    send_frame(Socket, #frame{fin=0, opcode=?WS_TEXT, data=Data}),
    send_fragments_rest(Socket, T).
%% Send every fragment after the first as a continuation frame. All of them
%% carry fin=0 except the last one, which carries fin=1 and so completes the
%% message. The fragment's own type tag is ignored here.
send_fragments_rest(Socket, [{_Type, Payload} | Remaining]) when Remaining =/= [] ->
    send_frame(Socket, #frame{fin=0, opcode=?WS_CONTINUATION, data=Payload}),
    send_fragments_rest(Socket, Remaining);
send_fragments_rest(Socket, [{_Type, Payload}]) ->
    send_frame(Socket, #frame{fin=1, opcode=?WS_CONTINUATION, data=Payload}).
%% Encode one #frame{} and write it to Socket. Returns the result of
%% gen_tcp:send/2.
send_frame(Socket, Frame) ->
    gen_tcp:send(Socket, encode_frame(Frame)).
%% Serialise a #frame{} into its wire format.
%%
%% Only `fin', `rsv', `opcode' and `data' are read. The MASK bit is always
%% written as 0 and no masking key is emitted, because frames sent by the
%% server are not masked. The length is encoded in the shortest form: 7 bits
%% for payloads under 126 bytes, marker 126 plus 16 bits for payloads under
%% 65536 bytes, and marker 127 plus 64 bits otherwise.
encode_frame(#frame{fin=Fin, rsv=RSV, opcode=Opcode, data=Data}) ->
    Payload = iolist_to_binary(Data),
    Size = byte_size(Payload),
    LengthField =
        if
            Size < 126 -> <<Size:7>>;
            Size < 65536 -> <<126:7, Size:16>>;
            true -> <<127:7, Size:64>>
        end,
    <<Fin:1, RSV:3, Opcode:4, 0:1, LengthField/bitstring, Payload/binary>>.
%% This goes through each frame in "Frames" (the output of parse_frame/1) and
%% processes it: responding to control frames, dispatching to the handler when
%% a message is complete, and, while in the middle of a series of fragments,
%% storing those fragments by appending them to PendingFrames. When a series
%% of fragments is completed, the reassembled message is dispatched to the
%% handler module and the pending frames are discarded.
%%
%% Returns `{PendingFrames, RemainderBinary, NewState}' when the list has been
%% consumed, or `closed' after a close frame has been handled. Protocol
%% violations do not return; they exit through close_with_purpose/2.
%%
%% The order of the clauses matters: the validation clauses come first so that
%% a malformed frame is rejected before any clause tries to act on it.
%% Done processing frames
process_frames([], _Socket, _Bridge, _Handler, State, PendingFrames) ->
{PendingFrames, <<>>, State};
%% Done processing frames, and we have some left-over binary data
process_frames([Bin], _Socket, _Bridge, _Handler, State, PendingFrames) when is_binary(Bin) ->
{PendingFrames, Bin, State};
%% Handling erroneous frames:
%% Control frames whose payload length is 126 or more
process_frames([#frame{opcode=Ctl, payload_len=PayloadLen} | _], _Socket, _Bridge, _Handler, _State, _Pending)
when (Ctl=:=?WS_PING orelse Ctl=:=?WS_PONG orelse Ctl=:=?WS_CLOSE) andalso PayloadLen >= ?WS_EXTENDED_PAYLOAD_16BIT ->
close_with_purpose(1002, {control_frame_payload_to_large, PayloadLen});
%% RSV bits set to anything but zero
process_frames([#frame{rsv=RSV} | _], _Socket, _Bridge, _Handler, _State, _Pending)
when RSV =/= 0 ->
close_with_purpose(1002, {invalid_rsv, RSV});
%% Invalid Opcode
process_frames([#frame{opcode=Op} | _], _Socket, _Bridge, _Handler, _State, _Pending)
when ?IS_INVALID_OPCODE(Op) ->
close_with_purpose(1002, {invalid_opcode, Op});
%% Single Text or Binary Frame (PendingFrames must be empty)
process_frames([_F = #frame{opcode=Opcode, fin=1, data=Data} |Rest], Socket, Bridge, Handler, State, [])
when Opcode=:=?WS_BINARY; Opcode=:=?WS_TEXT ->
Type = type(Opcode),
%% Side-effects, look out! This exits with code 1007 when a text payload
%% is not valid UTF-8.
close_on_invalid_utf8_text(Type, Data),
HandlerReturn = Handler:ws_message({Type, Data}, Bridge, State),
{Reply, NewState} = extract_reply_state(State, HandlerReturn),
send(Socket, Reply),
process_frames(Rest, Socket, Bridge, Handler, NewState, []);
%% First Text or Binary Fragment (PendingFrames must be empty)
process_frames([F = #frame{opcode=Opcode, fin=0}|Rest], Socket, Bridge, Handler, State, [])
when Opcode=:=?WS_BINARY; Opcode=:=?WS_TEXT ->
process_frames(Rest, Socket, Bridge, Handler, State, [F]);
%% Continuation Frame (only valid while a fragmented message is pending)
process_frames([F = #frame{opcode=?WS_CONTINUATION, fin=0}|Rest], Socket, Bridge, Handler, State, PendingFrames=[_|_]) ->
process_frames(Rest, Socket, Bridge, Handler, State, PendingFrames ++ [F]);
%% Last fragment of a fragmented message
process_frames([F = #frame{opcode=?WS_CONTINUATION, fin=1}|Rest], Socket, Bridge, Handler, State, PendingFrames=[_|_]) ->
ReorderedFrames = PendingFrames ++ [F],
%% The message type is taken from the first fragment; the continuation
%% frames only carry opcode 0.
Type = type((hd(ReorderedFrames))#frame.opcode),
Msg = defragment_data(ReorderedFrames),
%% Side-effects, look out! UTF-8 validity is checked on the reassembled
%% message, not on the individual fragments.
close_on_invalid_utf8_text(Type, Msg),
HandlerReturn = Handler:ws_message({Type, Msg}, Bridge, State),
{Reply, NewState} = extract_reply_state(State, HandlerReturn),
send(Socket, Reply),
process_frames(Rest, Socket, Bridge, Handler, NewState, []);
%% Pong: nothing to answer. Pending fragments are left untouched, so a pong
%% may arrive in the middle of a fragmented message.
process_frames([_F = #frame{opcode=?WS_PONG}|Rest], Socket, Bridge, Handler, State, PendingFrames) ->
%% do nothing?
process_frames(Rest, Socket, Bridge, Handler, State, PendingFrames);
%% Ping: echo the payload back in a pong. Only unfragmented pings (fin=1)
%% match here.
process_frames([#frame{opcode=?WS_PING, data=Data, fin=1}|Rest], Socket, Bridge, Handler, State, PendingFrames) ->
send(Socket, {pong, Data}),
process_frames(Rest, Socket, Bridge, Handler, State, PendingFrames);
%% Close: work out the status code to answer with, reply, notify the handler
%% and close the socket. Any frames after the close frame are ignored.
process_frames([_F = #frame{opcode=?WS_CLOSE, data=Data}|_Rest], Socket, Bridge, Handler, State, _PendingFrames) ->
StatusCode = case Data of
%% a one-byte payload cannot hold a 16-bit status code
<<_:8>> ->
1002;
%% a status code, optionally followed by a textual reason
<<ReasonCode:16,Text/binary>> ->
case is_valid_close_code(ReasonCode) of
true ->
case is_utf8(Text) of
true -> ReasonCode;
false -> 1007
end;
false -> 1002
end;
%% anything else, i.e. an empty payload
_ -> 1000
end,
send(Socket, {close, StatusCode}),
Handler:ws_terminate(closed, Bridge, State),
inet:close(Socket),
closed;
%% None of the above caught it. something must be wrong, so let's just die
process_frames([F|_], _Socket, _Bridge, _Handler, _State, _PendingFrames) ->
close_with_purpose(1002, {unknown_error_processing_frame, F}).
%% Concatenate the payloads of a list of frames, in list order, into a single
%% binary.
defragment_data(Frames) ->
    Payloads = [Fragment#frame.data || Fragment <- Frames],
    iolist_to_binary(Payloads).
%% Returns true when Code may be echoed back to a peer that sent it in a close
%% frame: anything from 3000 up to (but not including) 5000, or one of the
%% codes enumerated below.
is_valid_close_code(Code) when Code >= 3000, Code < 5000 ->
    true;
is_valid_close_code(Code) ->
    Accepted = [1000, 1001, 1002, 1003, 1007, 1008, 1009, 1010, 1011],
    lists:member(Code, Accepted).
%% Map a data-frame opcode to the message type atom given to the handler.
%% Only the text and binary opcodes are accepted.
type(Opcode) when Opcode =:= ?WS_TEXT -> text;
type(Opcode) when Opcode =:= ?WS_BINARY -> binary.
%% Shared body of the masked parse_frame/1 clauses. It relies on the variables
%% those clause heads bind (Raw, Fin, RSV, Op, PayloadLen, Data). When the
%% whole payload has arrived the frame is built by do_frames/6; otherwise the
%% untouched input is returned so that parsing can be retried once more data
%% is available.
-define(DO_FRAMES(Mask),
    if
        byte_size(Data) >= PayloadLen ->
            do_frames(Fin,RSV,Op,PayloadLen,Mask,Data);
        ?otherwise ->
            [Raw]
    end).

%% @doc Takes some binary data, assumed to be the beginning of a frame, and
%% tries to parse it into one or more frames. It returns a list of complete
%% frames whose last element is the binary that could not be parsed yet
%% (possibly empty). The complete frames are processed and discarded by the
%% caller, and the trailing binary is passed back into websocket_loop's
%% partial data.
%%
%% A frame from the client whose MASK bit is clear is a protocol violation:
%% the connection is closed with code 1002 through close_with_purpose/2, which
%% does not return.
%%
%% My apologies for making these lines extremely long. I wanted to model a
%% single frame packet in pattern matching horizontally so it's obvious what's
%% going on.
-spec parse_frame(binary()) -> [#frame{} | binary()].
parse_frame(<<>>) -> [<<>>];
%FRAME:               FIN    RSV    OPCODE MASKED        PAYLOAD_LEN                   EXT_PAYLOAD_LEN MASK_KEY PAYLOAD
parse_frame(Raw = <<Fin:1, RSV:3, Op:4, ?WS_MASKED:1, ?WS_EXTENDED_PAYLOAD_16BIT:7, PayloadLen:16, Mask:32, Data/binary>>) ->
    ?DO_FRAMES(Mask);
parse_frame(Raw = <<Fin:1, RSV:3, Op:4, ?WS_MASKED:1, ?WS_EXTENDED_PAYLOAD_64BIT:7, PayloadLen:64, Mask:32, Data/binary>>) ->
    ?DO_FRAMES(Mask);
parse_frame(Raw = <<Fin:1, RSV:3, Op:4, ?WS_MASKED:1, PayloadLen:7, Mask:32, Data/binary>>) ->
    ?DO_FRAMES(Mask);
%% The tail has to be matched as `bits', not `binary': after the first 9 bits
%% the rest of a byte-aligned binary is never a whole number of bytes, so the
%% former `_/binary' tail could never match and unmasked frames were silently
%% buffered as "incomplete" instead of being rejected.
parse_frame(<<_:8, ?WS_UNMASKED:1, _/bits>>) ->
    close_with_purpose(1002, unmasked_packet_received_from_client);
parse_frame(Data) ->
    [Data]. %% not enough data to parse the frame, so just return the data, and we'll try after we get more data
%% Build a masked #frame{} from the parsed header fields, with an empty
%% payload, and let append_frame_data_and_parse_remainder/2 fill in the
%% payload from Data and parse whatever follows it.
do_frames(Fin, RSV, Op, PayloadLen, Mask, Data) ->
    append_frame_data_and_parse_remainder(
        #frame{fin=Fin, rsv=RSV, opcode=Op, masked=1,
               payload_len=PayloadLen, mask_key=Mask, data = <<>>},
        Data).
%% Complete frame F with its payload and parse whatever follows it.
%%
%% The payload already stored in F and the newly received Data are joined; the
%% first `payload_len' bytes are unmasked and become the frame's data, and the
%% bytes after them are handed to parse_frame/1. The result is the completed
%% frame followed by parse_frame/1's result, so an exact fit yields
%% `[Frame, <<>>]' because parse_frame(<<>>) returns `[<<>>]'.
%%
%% The remainder used to be cut out of Data with offsets computed on the
%% joined binary, which is only correct while F's existing payload is empty.
%% Splitting the joined binary with a single pattern match removes that
%% assumption.
%%
%% Callers must make sure at least `payload_len' bytes are available (the
%% DO_FRAMES macro does); with fewer bytes the match below fails.
append_frame_data_and_parse_remainder(F = #frame{payload_len=PayloadLen, mask_key=Mask, data=CurrentPayload}, Data) ->
    FullData = <<CurrentPayload/binary, Data/binary>>,
    <<FrameData:PayloadLen/binary, RemainingData/binary>> = FullData,
    Unmasked = apply_mask(Mask, FrameData),
    [F#frame{data=Unmasked} | parse_frame(RemainingData)].
%% XOR Data with the 32-bit masking key Mask (an integer), repeating the key
%% every four bytes. The same operation masks and unmasks.
apply_mask(Mask, Data) ->
    apply_mask(Mask, Data, <<>>).

%% Worker for apply_mask/2: consumes the input four bytes at a time and
%% appends each XORed word to the accumulator.
apply_mask(_Mask, <<>>, Out) ->
    Out;
apply_mask(Mask, <<Word:32, Tail/binary>>, Out) ->
    apply_mask(Mask, Tail, <<Out/binary, (Word bxor Mask):32>>);
apply_mask(Mask, Tail, Out) when is_binary(Tail) ->
    %% One to three bytes are left. They line up with the start of the key,
    %% so XOR them with the key's leading bits.
    Bits = bit_size(Tail),
    <<Last:Bits>> = Tail,
    <<KeyPrefix:Bits, _/bits>> = <<Mask:32>>,
    <<Out/binary, (Last bxor KeyPrefix):Bits>>.
%% Enforce that text messages are valid UTF-8. Returns `do_nothing' for binary
%% messages and for valid text; for invalid text it does not return but closes
%% the connection with code 1007 through close_with_purpose/2.
close_on_invalid_utf8_text(text, Data) ->
    case is_utf8(Data) of
        false -> close_with_purpose(1007, {invalid_utf8, Data});
        true -> do_nothing
    end;
close_on_invalid_utf8_text(binary, _Data) ->
    do_nothing.
%% UTF8-Validation Below:
%% Copyright 2011-2013 Loïc Hoguin <essen@ninenines.eu>
%% Borrowed from Cowboy:
%% https://github.com/extend/cowboy/blob/0d5a12c3ecd3bd093c33e9a8126f1d129719b9ea/src/cowboy_websocket.erl#L491
%% Returns true when the whole input is a sequence of well-formed UTF-8
%% encoded characters (the empty binary included) and false for anything else.
%% Validity of each character is decided by the bit syntax's `utf8' segment
%% type, so every input that does not start with a valid character ends up in
%% the final catch-all branch.
-spec is_utf8(binary()) -> boolean().
is_utf8(Bin) ->
    case Bin of
        <<>> ->
            true;
        <<_/utf8, Rest/binary>> ->
            is_utf8(Rest);
        _ ->
            false
    end.