%% src/masque_server_session.erl

%%% @doc Per-tunnel server-side session.
%%%
%%% One of these is spawned by the listener's handler fun after a
%%% CONNECT-UDP request passes validation. It receives routed
%%% datagrams from the connection router (`masque_server_connection')
%%% and invokes the configured user handler module to produce reply
%%% datagrams or close the session.
-module(masque_server_session).
-behaviour(gen_server).

-export([start_link/1]).

-export([init/1, handle_call/3, handle_cast/2, handle_info/2,
         terminate/2, code_change/3]).

-include("masque.hrl").

%% Per-session gen_server state.
-record(state, {
    %% Connection handle used for every quic_h3 call made by this session.
    conn       :: pid(),
    %% Request stream of the CONNECT-UDP exchange (datagrams and capsules).
    stream_id  :: non_neg_integer(),
    %% Router process: monitored in init/1, asked to unregister_session
    %% for this stream in terminate/2.
    router     :: pid(),
    %% User handler module and the state threaded through its callbacks.
    handler    :: module(),
    h_state    :: term(),
    %% The request map passed to Handler:init/2.
    req        :: map(),
    %% Stream bytes not yet decoded into a complete capsule.
    cap_buf = <<>> :: binary(),
    %% Limit on buffered capsule bytes (handler_opts max_capsule_size,
    %% default ?MASQUE_DEFAULT_MAX_CAPSULE_SIZE); exceeding it aborts the
    %% stream with capsule_buffer_overflow.
    max_cap    :: pos_integer(),
    %% FIN flag of the most recent stream data.
    %% NOTE(review): written in handle_stream_bytes/3 but not read anywhere
    %% visible in this module - confirm whether it is still needed.
    cap_fin_seen = false :: boolean(),
    %% Actions from handler init, applied after finalize
    %% (finalize is only accepted while this is not `undefined').
    pending_actions :: [term()] | undefined,
    %% erlang:monotonic_time(millisecond) when finalize opened the tunnel;
    %% while `undefined' no tunnel_closed metric is emitted.
    start_time :: integer() | undefined
}).

%%====================================================================
%% API
%%====================================================================

%% @doc Start a session process linked to the caller. `Args' must contain
%% the keys matched by init/1: conn, stream_id, router, handler,
%% handler_opts and req.
-spec start_link(map()) -> {ok, pid()} | ignore | {error, term()}.
start_link(Args) ->
    StartOpts = [],
    gen_server:start_link(?MODULE, Args, StartOpts).

%%====================================================================
%% gen_server
%%====================================================================

init(#{conn := Conn, stream_id := StreamId, router := Router,
       handler := Handler, handler_opts := HOpts, req := Req}) ->
    %% Exit signals from linked processes arrive as messages instead of
    %% killing the session outright.
    process_flag(trap_exit, true),
    %% Watch the router from the start, so that its death - during the
    %% handler's init or later - stops this session.
    erlang:monitor(process, Router),
    MaxCap = maps:get(max_capsule_size, HOpts,
                      ?MASQUE_DEFAULT_MAX_CAPSULE_SIZE),
    %% RFC 9298 §3: a 2xx means the tunnel is up and the proxy can forward
    %% UDP, so Handler:init/2 (which, for the built-in proxy, opens the
    %% gen_udp socket and validates the target) must finish before any 200
    %% goes out. The response and the stream claim are deferred to the
    %% `finalize' call the router makes after registration, which keeps
    %% the router responsive while the handler initialises.
    case init_handler(Handler, Req, HOpts) of
        {stop, _Reason} = Stop ->
            Stop;
        {ok, HState, Actions} ->
            {ok, #state{conn = Conn,
                        stream_id = StreamId,
                        router = Router,
                        handler = Handler,
                        h_state = HState,
                        req = Req,
                        max_cap = MaxCap,
                        pending_actions = Actions}}
    end.

%% Extra headers for the 200 CONNECT-UDP response: advertise the Capsule
%% Protocol (structured-field boolean `?1').
response_headers() ->
    CapsuleProtocol = {<<"capsule-protocol">>, <<"?1">>},
    [CapsuleProtocol].

%% Make this process the handler of the request stream.
%%
%% quic_h3:set_stream_handler/3 either returns `ok' or `{ok, Chunks}' with
%% `{Data, Fin}' chunks received before the claim. Their bytes are appended
%% to cap_buf. A FIN among them is recorded in cap_fin_seen, just as
%% handle_stream_bytes/3 records FIN for data that arrives afterwards.
%% Returns `{ok, State}' or the `{error, _}' from quic_h3 unchanged.
claim_stream(#state{conn = Conn, stream_id = StreamId,
                    cap_buf = Buf, cap_fin_seen = FinSeen} = S) ->
    case quic_h3:set_stream_handler(Conn, StreamId, self()) of
        ok ->
            {ok, S};
        {ok, Chunks} ->
            More = iolist_to_binary([D || {D, _Fin} <- Chunks]),
            Fin = FinSeen orelse
                      lists:any(fun({_D, F}) -> F =:= true end, Chunks),
            {ok, S#state{cap_buf = <<Buf/binary, More/binary>>,
                         cap_fin_seen = Fin}};
        {error, _} = Err ->
            Err
    end.

%% `finalize' is called once the session has been registered. It commits
%% the tunnel: it sends the 200 response (RFC 9298 §3), claims the request
%% stream, records the tunnel as opened and runs the actions returned by
%% Handler:init/2. It is accepted only while those actions are pending.
%% Any other call gets `{error, unknown_call}'.
handle_call(finalize, From, #state{pending_actions = Actions} = S)
  when Actions =/= undefined ->
    finalize_tunnel(Actions, From, S);
handle_call(_Req, _From, S) ->
    {reply, {error, unknown_call}, S}.

%% Send the 200 and claim the stream. If either step fails, reply
%% `{error, stream_dead}' and keep the pre-finalize state.
finalize_tunnel(Actions, From,
                #state{conn = Conn, stream_id = StreamId} = S) ->
    case quic_h3:send_response(Conn, StreamId, 200, response_headers()) of
        ok ->
            case claim_stream(S#state{pending_actions = undefined}) of
                {ok, Claimed} ->
                    masque_metrics:tunnel_opened(
                        #{protocol => udp, transport => h3}),
                    Opened = Claimed#state{
                               start_time =
                                   erlang:monotonic_time(millisecond)},
                    open_tunnel(Actions, From, Opened);
                {error, _} ->
                    {reply, {error, stream_dead}, S}
            end;
        {error, _} ->
            {reply, {error, stream_dead}, S}
    end.

%% Run the init actions on the opened tunnel.
%%
%% run_init_actions/2 reports a stopping action (e.g. close_session) via
%% exit/1. If that escaped handle_call/3, the caller's call would exit and
%% terminate/2 would run with the pre-call state (start_time undefined,
%% so no tunnel_closed metric). The 200 has already been sent, so reply
%% `ok' first, then stop cleanly with the opened state.
open_tunnel(Actions, From, Opened) ->
    try run_init_actions(Actions, Opened) of
        Running ->
            resume_buffered_capsules(Running),
            {reply, ok, Running}
    catch
        exit:Reason ->
            gen_server:reply(From, ok),
            {stop, Reason, Opened}
    end.

%% claim_stream/1 may leave pre-claim bytes in cap_buf. Otherwise they
%% would only be decoded when further stream data arrives, which never
%% happens if the peer has already sent everything. Queue a zero-length
%% stream-data message to ourselves so handle_info/2 drains them right
%% after this call.
resume_buffered_capsules(#state{cap_buf = <<>>}) ->
    ok;
resume_buffered_capsules(#state{stream_id = StreamId, cap_fin_seen = Fin}) ->
    self() ! {masque_stream_data, StreamId, <<>>, Fin},
    ok.

%% A `connection_closed' cast stops the session with that reason. Every
%% other cast is ignored.
handle_cast(Msg, State) ->
    case Msg of
        connection_closed -> {stop, connection_closed, State};
        _Ignored          -> {noreply, State}
    end.

%% Messages handled for this session's stream:
%%   * {masque_datagram_in, StreamId, Payload}: an HTTP Datagram. A
%%     UDP-context payload of at most ?MASQUE_MAX_UDP_PAYLOAD bytes goes
%%     to Handler:handle_packet/2. Oversized payloads, unknown context ids
%%     and undecodable datagrams are dropped (RFC 9298 §5).
%%   * {masque_stream_data, ...} and quic_h3 {data, ...}: stream bytes for
%%     the capsule parser.
%%   * {masque_stream_reset, StreamId, _}: stops with `peer_reset'.
%%   * 'DOWN' for the router monitored in init/1: stops with `router_gone'.
%% Everything else, including 'DOWN' messages from monitors the handler
%% created itself, is passed to Handler:handle_info/2.
handle_info({masque_datagram_in, StreamId, Payload},
            #state{stream_id = StreamId} = S) ->
    masque_metrics:bytes_in(byte_size(Payload),
                            #{protocol => udp, transport => h3}),
    case masque_datagram:decode(Payload) of
        {ok, {?MASQUE_CONTEXT_ID_UDP, UdpBytes}}
          when byte_size(UdpBytes) =< ?MASQUE_MAX_UDP_PAYLOAD ->
            dispatch(handle_packet, [UdpBytes], S);
        {ok, {?MASQUE_CONTEXT_ID_UDP, _}} ->
            %% Oversized UDP payload - RFC 9298 §5 says drop.
            {noreply, S};
        {ok, {_Ctx, _Bytes}} ->
            %% Unknown context-id: RFC 9298 §5 says silently drop.
            {noreply, S};
        {error, _} ->
            {noreply, S}
    end;
handle_info({masque_stream_data, StreamId, Data, Fin},
            #state{stream_id = StreamId} = S) ->
    handle_stream_bytes(Data, Fin, S);
handle_info({quic_h3, _Conn, {data, StreamId, Data, Fin}},
            #state{stream_id = StreamId} = S) ->
    handle_stream_bytes(Data, Fin, S);
handle_info({masque_stream_reset, StreamId, _ErrorCode},
            #state{stream_id = StreamId} = S) ->
    {stop, peer_reset, S};
handle_info({'DOWN', _MRef, process, Router, _Reason},
            #state{router = Router} = S) ->
    %% Only the router's death ends the session. A 'DOWN' from any other
    %% monitor falls through to the handler below.
    {stop, router_gone, S};
handle_info(Msg, S) ->
    dispatch(handle_info, [Msg], S).

%% Clean-up on exit. Every stop reason records the tunnel duration,
%% finishes the request stream as appropriate, unregisters from the router
%% and finally runs the handler's optional terminate/2. quic_h3 and router
%% failures are ignored because the process is exiting anyway.
terminate(Reason, #state{router = Router, stream_id = StreamId,
                         handler = Handler, h_state = HState} = S) ->
    emit_tunnel_closed(S),
    release_stream(Reason, S),
    _ = (try masque_server_connection:unregister_session(Router, StreamId)
         catch _:_ -> ok end),
    try_callback(Handler, terminate, [Reason, HState]),
    ok.

%% What happens to the request stream for a given stop reason:
%%   normal: the session was closed deliberately, so send a FIN.
%%   peer reset or connection/router gone: leave the stream alone.
%%   capsule framing errors: reset_and_stop/2 has already cancelled the
%%     stream with H3_MESSAGE_ERROR, so do not cancel it again.
%%   anything else: cancel the stream with H3_MESSAGE_ERROR.
release_stream(normal, #state{conn = Conn, stream_id = StreamId}) ->
    _ = (try quic_h3:send_data(Conn, StreamId, <<>>, true)
         catch _:_ -> ok end),
    ok;
release_stream(Reason, _S)
  when Reason =:= connection_closed;
       Reason =:= router_gone;
       Reason =:= peer_reset;
       Reason =:= truncated_capsule;
       Reason =:= capsule_buffer_overflow;
       Reason =:= malformed_capsule ->
    ok;
release_stream(_Reason, #state{conn = Conn, stream_id = StreamId}) ->
    _ = (try quic_h3:cancel(Conn, StreamId, ?MASQUE_H3_MESSAGE_ERROR)
         catch _:_ -> ok end),
    ok.

%% Hot code upgrade: the state record is carried over as-is.
code_change(_PrevVsn, State, _UpgradeExtra) ->
    {ok, State}.

%% Report the tunnel's lifetime (in ms) to metrics. start_time is set only
%% when finalize opens the tunnel; sessions that never opened report
%% nothing.
emit_tunnel_closed(#state{start_time = Started}) ->
    case Started of
        undefined ->
            ok;
        _ ->
            Elapsed = erlang:monotonic_time(millisecond) - Started,
            masque_metrics:tunnel_closed(
                Elapsed, #{protocol => udp, transport => h3})
    end.

%%====================================================================
%% Handler dispatch
%%====================================================================

%% Run Handler:init/2 when it is exported and normalise its result to
%% {ok, HState, Actions} | {stop, Reason}. A handler without init/2 starts
%% with `undefined' state and no actions. Any unrecognised return becomes
%% {stop, {bad_init, Other}}.
init_handler(Handler, Req, HOpts) ->
    case exported(Handler, init, 2) of
        false ->
            {ok, undefined, []};
        true ->
            normalise_init(safe_apply(Handler, init, [Req, HOpts]))
    end.

%% Map a raw Handler:init/2 return onto the shapes init/1 expects.
normalise_init({ok, HState})          -> {ok, HState, []};
normalise_init({ok, _HState, _} = Ok) -> Ok;
normalise_init({stop, _Reason} = Stop) -> Stop;
normalise_init(Other)                 -> {stop, {bad_init, Other}}.

%% Invoke the optional handler callback CB with Extra ++ [HState] and turn
%% its result into a gen_server callback return:
%%   {ok, HS2}          -> noreply with the new handler state
%%   {ok, HS2, Actions} -> run Actions via apply_actions_noreply/2
%%   {stop, Reason, HS2} -> stop with Reason
%% Any other result is ignored and the previous handler state kept. That
%% includes the {stop, {handler_crash, _}} safe_apply/3 returns when the
%% callback raises. If CB is not exported, the event is ignored.
dispatch(CB, Extra, #state{handler = Handler} = S) ->
    Arity = length(Extra) + 1,
    case exported(Handler, CB, Arity) of
        false ->
            {noreply, S};
        true ->
            Result = safe_apply(Handler, CB, Extra ++ [S#state.h_state]),
            dispatch_result(Result, S)
    end.

%% Translate one handler callback result; see dispatch/3.
dispatch_result({ok, HS2}, S) ->
    {noreply, S#state{h_state = HS2}};
dispatch_result({ok, HS2, Actions}, S) ->
    apply_actions_noreply(Actions, S#state{h_state = HS2});
dispatch_result({stop, Reason, HS2}, S) ->
    {stop, Reason, S#state{h_state = HS2}};
dispatch_result(_Ignored, S) ->
    {noreply, S}.

%% Report whether Mod exports Fun/Arity.
%%
%% erlang:function_exported/3 only knows about modules already loaded in
%% this VM, so a user handler module seen for the first time would report
%% `false'. Try to load it first; a failed load simply leads to `false'.
exported(Mod, Fun, Arity) ->
    _LoadResult = code:ensure_loaded(Mod),
    erlang:function_exported(Mod, Fun, Arity).

%% Run handler actions from a gen_server callback and translate the
%% outcome into {noreply, State} or {stop, Reason, State}.
apply_actions_noreply(Actions, State) ->
    case do_actions(Actions, State) of
        {stop, _Reason, _Next} = Stop -> Stop;
        {ok, Next}                    -> {noreply, Next}
    end.

%% Apply the actions returned by Handler:init/2 and return the new state.
%% An action that stops the session (close_session) is raised as
%% exit(Reason), because callers expect a plain state back.
run_init_actions(Actions, S) ->
    case do_actions(Actions, S) of
        {ok, Next}       -> Next;
        {stop, Why, _Sx} -> exit(Why)
    end.

%% Execute handler actions in order. Returns {ok, State}, or
%% {stop, normal, State} as soon as an action closes the session (later
%% actions are skipped). Recognised actions:
%%   {send, Data}                 HTTP Datagram on ?MASQUE_CONTEXT_ID_UDP
%%   {send, Ctx, Data}            HTTP Datagram on context Ctx
%%   {send_capsule, Type, Value}  capsule written on the request stream
%%   close_session | {close_session, Code, Msg}  stop normally
%% Unrecognised actions are skipped.
do_actions([], S) ->
    {ok, S};
do_actions([{send, Data} | Rest], S) ->
    do_actions([{send, ?MASQUE_CONTEXT_ID_UDP, Data} | Rest], S);
do_actions([{send, Ctx, Data} | Rest], S) ->
    send_datagram_action(Ctx, Data, S),
    do_actions(Rest, S);
do_actions([{send_capsule, Type, Value} | Rest],
           #state{conn = Conn, stream_id = Sid} = S) ->
    Wire = iolist_to_binary(masque_capsule:encode(Type, Value)),
    _ = quic_h3:send_data(Conn, Sid, Wire, false),
    %% Capsule metrics count the encoded bytes, not just the value.
    masque_metrics:bytes_out(byte_size(Wire),
                             #{protocol => udp, transport => h3}),
    do_actions(Rest, S);
do_actions([close_session | _Skipped], S) ->
    {stop, normal, S};
do_actions([{close_session, _Code, _Msg} | _Skipped], S) ->
    {stop, normal, S};
do_actions([_Unrecognised | Rest], S) ->
    do_actions(Rest, S).

%% Send one HTTP Datagram, or drop it silently when it is too large -
%% RFC 9298 §5 (HTTP Datagrams are unreliable; the application can resend
%% if it cares). A UDP-context payload must not exceed
%% ?MASQUE_MAX_UDP_PAYLOAD, and when quic_h3 reports a positive maximum
%% datagram size, payload plus context-id length must fit within it.
send_datagram_action(Ctx, Data, #state{conn = Conn, stream_id = Sid}) ->
    Size = iolist_size(Data),
    Limit = quic_h3:max_datagram_size(Conn, Sid),
    CtxLen = ctx_overhead(Ctx),
    OverUdp = Ctx =:= ?MASQUE_CONTEXT_ID_UDP
                  andalso Size > ?MASQUE_MAX_UDP_PAYLOAD,
    OverQuic = Limit > 0 andalso (Size + CtxLen) > Limit,
    if
        OverUdp orelse OverQuic ->
            ok;
        true ->
            _ = quic_h3:send_datagram(Conn, Sid,
                                      masque_datagram:encode(Ctx, Data)),
            masque_metrics:bytes_out(Size,
                                     #{protocol => udp, transport => h3}),
            ok
    end.

%% Append newly received stream bytes to the capsule buffer, then decode
%% whatever is complete. If the buffer would grow beyond max_cap bytes,
%% abort the stream with capsule_buffer_overflow before decoding anything.
handle_stream_bytes(Data, Fin, #state{cap_buf = Pending,
                                      max_cap = Limit} = S) ->
    Buffered = <<Pending/binary, Data/binary>>,
    if
        byte_size(Buffered) > Limit ->
            reset_and_stop(capsule_buffer_overflow, S);
        true ->
            drain_capsules(Buffered, Fin, S#state{cap_fin_seen = Fin})
    end.

%% Decode every complete capsule at the front of `Buf' and hand each one to
%% Handler:handle_capsule/3, stopping early if the handler stops. Leftover
%% partial bytes are kept in cap_buf. A malformed capsule, or a FIN while a
%% partial capsule is buffered, aborts the HTTP/3 stream with
%% H3_MESSAGE_ERROR per RFC 9297 §3.3 / RFC 9114 §4.1.2.
drain_capsules(Buf, Fin, S) ->
    case masque_capsule:decode(Buf) of
        {error, _Reason} ->
            reset_and_stop(malformed_capsule, S);
        {more, _} when Fin =:= true, Buf =/= <<>> ->
            %% The stream ended in the middle of a capsule.
            reset_and_stop(truncated_capsule, S);
        {more, _} ->
            {noreply, S#state{cap_buf = Buf}};
        {ok, {Type, Value, Rest}} ->
            case dispatch(handle_capsule, [Type, Value], S) of
                {stop, _, _} = Stop ->
                    Stop;
                {noreply, Next} ->
                    drain_capsules(Rest, Fin, Next#state{cap_buf = <<>>})
            end
    end.

%% Abort the request stream with H3_MESSAGE_ERROR (best effort: any
%% failure is ignored) and stop the session with Reason.
reset_and_stop(Reason, #state{conn = Conn, stream_id = StreamId} = S) ->
    _Cancelled =
        try
            quic_h3:cancel(Conn, StreamId, ?MASQUE_H3_MESSAGE_ERROR)
        catch
            _:_ -> ok
        end,
    {stop, Reason, S}.

%% Call M:F(A...) on behalf of the user handler and shield the session
%% from its exceptions. On success the callback's own return value is
%% passed through unchanged. If it raises (any class), the failure and
%% stacktrace are logged through OTP logger (the replacement for the
%% legacy error_logger API) and {stop, {handler_crash, Reason}} is
%% returned.
safe_apply(M, F, A) ->
    try
        apply(M, F, A)
    catch
        Class:Reason:Stack ->
            logger:error("masque handler ~p:~p/~p failed: ~p:~p~n~p~n",
                         [M, F, length(A), Class, Reason, Stack]),
            {stop, {handler_crash, Reason}}
    end.

%% Best-effort call of an optional handler callback (used for the handler's
%% terminate/2 during shutdown).
%%
%% The export check goes through exported/3 so that a handler module not
%% yet loaded is loaded first, like every other callback lookup here.
%% The call goes through safe_apply/3, so an exception is logged instead of
%% being discarded silently, and it still never propagates. Returns `ok'
%% when the callback is not exported, otherwise safe_apply/3's result.
try_callback(Mod, Fun, Args) ->
    case exported(Mod, Fun, length(Args)) of
        true  -> safe_apply(Mod, Fun, Args);
        false -> ok
    end.

%% Byte length of context id V as a QUIC variable-length integer
%% (thresholds 2^6, 2^14 and 2^30 give 1, 2, 4 or 8 bytes).
ctx_overhead(V) ->
    if
        V < 16#40       -> 1;
        V < 16#4000     -> 2;
        V < 16#40000000 -> 4;
        true            -> 8
    end.