%%% File: src/masque_udp_bind_proxy_handler.erl

%%% @doc Default Connect-UDP-Bind handler. Owns the per-session
%%% upstream `gen_udp' socket, computes the `Proxy-Public-Address'
%%% list for the response, gates inbound packets via
%%% `peer_filter_fun', and emits actions for the bind session to
%%% put on the wire.
%%%
%%% This module ships a `masque_handler'-shaped callback set
%%% (`init/2', `handle_info/2', `terminate/2') plus the
%%% `handle_bind_packet/3' entry the bind session calls when a
%%% datagram arrives from the client. `handle_bind_packet/3' is not
%%% on the existing `masque_handler' behaviour because no other
%%% protocol uses it; the bind sessions invoke it directly.
%%%
%%% Configurable via `handler_opts':
%%%
%%% <ul>
%%%   <li>`bind_address :: inet:ip_address() | any' - which
%%%       interface to bind to. Default `any'.</li>
%%%   <li>`bind_port :: inet:port_number()' - default `0'
%%%       (kernel-assigned ephemeral).</li>
%%%   <li>`bind_socket_opts :: [gen_udp:option()]' - merged on top of
%%%       `[binary, {active, true}]'.</li>
%%%   <li>`public_addresses :: [{ip_address(), port()}]' -
%%%       list emitted on `Proxy-Public-Address'. Required if the
%%%       socket is bound to a wildcard address; otherwise sockname
%%%       is the fallback.</li>
%%%   <li>`public_address_fun :: fun((sockname()) -> [{ip, port}])'
%%%       - alternative to the static list. Takes precedence when
%%%       set.</li>
%%%   <li>`peer_filter_fun :: fun((ip(), port()) -> ok | {drop, atom()})'
%%%       - per-packet egress policy. Default rejects RFC 1918,
%%%       link-local, and multicast unless `allow_private => true';
%%%       loopback is allowed by default for testability.</li>
%%%   <li>`scrub_fun :: fun((Packet, State) -> {pass, Packet, State} |
%%%                                            {drop, Reason, State})'
%%%       - data-plane policy hook for DDoS scrubbing or other
%%%       per-packet filtering. Default identity.</li>
%%% </ul>
-module(masque_udp_bind_proxy_handler).

-export([init/2, handle_bind_packet/3, handle_info/2, terminate/2]).

-include("masque_udp_bind.hrl").

%% Per-session handler state threaded through every callback.
-record(state, {
    %% Upstream UDP socket opened in init/2 and closed in terminate/2.
    socket               :: gen_udp:socket(),
    %% Addresses advertised to the client in `Proxy-Public-Address'.
    public_addresses     :: [{inet:ip_address(), inet:port_number()}],
    %% Distinct IP families (4 and/or 6) of `public_addresses'; inbound
    %% datagrams from any other family are not relayed to the client.
    advertised_families  :: [4 | 6],
    %% Egress policy consulted for every client datagram, before scrubbing.
    peer_filter_fun      :: fun((inet:ip_address(),
                                 inet:port_number()) ->
                                ok | {drop, atom()}),
    %% Data-plane hook applied to payloads that passed the peer filter;
    %% it receives and returns `user_state'.
    scrub_fun            :: fun((binary(), term()) ->
                                {pass, binary(), term()}
                              | {drop, atom(), term()}),
    %% Opaque value passed to and returned from `scrub_fun'; seeded from
    %% the `user_state' option (default `undefined').
    user_state           :: term()
}).

%% The `handler_opts' map given to init/2.
-type opts() :: map().
%% A peer endpoint as handed to handle_bind_packet/3.
-type ip_port() :: {inet:ip_address(), inet:port_number()}.

%%====================================================================
%% Lifecycle
%%====================================================================

%% @doc Open the per-session upstream UDP socket and build handler state.
%%
%% On success returns `{ok, State, [{response_headers, Headers}]}', where
%% Headers carries the bind header and `Proxy-Public-Address'.
%% Returns `{stop, {udp_open, Reason}}' when the socket cannot be opened,
%% or `{stop, Reason}' (after closing the socket) when no public address
%% can be resolved.
%%
%% Peer filter selection: an explicit `peer_filter_fun' always wins.
%% Otherwise `allow_private => true' selects a filter that accepts every
%% destination, and anything else falls back to default_peer_filter/2.
-spec init(masque_handler:req(), opts()) ->
    {ok, #state{}, [term()]} | {stop, term()}.
init(_Req, Opts) ->
    BindAddr = maps:get(bind_address, Opts, any),
    BindPort = maps:get(bind_port, Opts, 0),
    %% Caller-supplied socket options are appended after the fixed ones.
    SocketOpts = [binary, {active, true},
                  {ip, BindAddr}
                  | maps:get(bind_socket_opts, Opts, [])],
    case gen_udp:open(BindPort, SocketOpts) of
        {ok, Socket} ->
            case resolve_public_addresses(Socket, Opts) of
                {ok, Addresses} ->
                    State = #state{
                        socket = Socket,
                        public_addresses = Addresses,
                        advertised_families = families(Addresses),
                        peer_filter_fun =
                            maps:get(peer_filter_fun, Opts,
                                     default_peer_filter_fun(Opts)),
                        scrub_fun =
                            maps:get(scrub_fun, Opts,
                                     fun default_scrub/2),
                        user_state = maps:get(user_state, Opts, undefined)
                    },
                    Headers = response_headers(Addresses),
                    {ok, State, [{response_headers, Headers}]};
                {error, Reason} ->
                    %% Refusing the session: release the socket right away.
                    _ = gen_udp:close(Socket),
                    {stop, Reason}
            end;
        {error, Reason} ->
            {stop, {udp_open, Reason}}
    end.

%% Fallback egress filter used when no `peer_filter_fun' is configured.
%% `allow_private => true' (exactly `true') disables the private-range
%% rejection of default_peer_filter/2; any other value keeps it.
default_peer_filter_fun(Opts) ->
    case maps:get(allow_private, Opts, false) of
        true -> fun(_IP, _Port) -> ok end;
        _    -> fun default_peer_filter/2
    end.

%% @doc Release the upstream socket when the session ends.
%% The result of the close is intentionally ignored; always returns `ok'.
-spec terminate(term(), #state{}) -> ok.
terminate(_Why, #state{socket = UdpSocket}) ->
    _ = gen_udp:close(UdpSocket),
    ok.

%%====================================================================
%% Inbound from the client (already decoded by the session)
%%====================================================================

%% @doc Entry point used by the bind session for each client datagram.
%%
%% Takes the decoded `{IP, Port}' destination and the inner UDP payload.
%% The peer filter runs first; a `{drop, Reason}' verdict is returned to
%% the session as `{drop, Reason, State}' so it can be attributed in
%% metrics. Otherwise the payload goes through the scrub hook and is sent.
-spec handle_bind_packet(ip_port(), binary(), #state{}) ->
    {ok, #state{}} | {drop, atom(), #state{}}.
handle_bind_packet({PeerIP, PeerPort}, Datagram,
                   #state{peer_filter_fun = Filter} = State)
  when is_binary(Datagram) ->
    Verdict = Filter(PeerIP, PeerPort),
    case Verdict of
        {drop, Why} ->
            {drop, Why, State};
        ok ->
            scrub_then_send(PeerIP, PeerPort, Datagram, State)
    end.

%% Run the scrub hook over the payload, carrying the returned user state
%% forward whether the packet passes or is dropped. A passing (possibly
%% rewritten) payload is handed to send_to_peer/4.
scrub_then_send(IP, Port, Payload,
                #state{scrub_fun = Scrub, user_state = UserState0} = State0) ->
    case Scrub(Payload, UserState0) of
        {drop, Why, UserState} ->
            {drop, Why, State0#state{user_state = UserState}};
        {pass, Cleaned, UserState} ->
            send_to_peer(IP, Port, Cleaned,
                         State0#state{user_state = UserState})
    end.

%% Write the payload to the upstream peer. Any send failure is reported
%% as a `socket_error' drop; the underlying reason is discarded.
send_to_peer(IP, Port, Payload, #state{socket = Upstream} = State) ->
    SendResult = gen_udp:send(Upstream, IP, Port, Payload),
    case SendResult of
        {error, _Why} -> {drop, socket_error, State};
        ok            -> {ok, State}
    end.

%%====================================================================
%% {udp, ...} from kernel: outbound to the client
%%====================================================================

%% @doc Handle messages delivered to the handler process.
%%
%% A `{udp, ...}' datagram on the upstream socket becomes a
%% `send_bind_packet' action toward the client, but only when the
%% sender's address family is one we advertised. Socket errors and
%% closure stop the handler; all other messages are ignored.
-spec handle_info(term(), #state{}) ->
    {ok, #state{}} | {ok, #state{}, [term()]} | {stop, term(), #state{}}.
handle_info({udp, Sock, PeerIP, PeerPort, Datagram},
            #state{socket = Sock, advertised_families = Advertised} = State) ->
    %% A source family missing from Proxy-Public-Address is something the
    %% client is not expecting (or odd kernel behaviour): drop silently.
    case lists:member(family_of(PeerIP), Advertised) of
        true ->
            Action = {send_bind_packet, {PeerIP, PeerPort}, Datagram},
            {ok, State, [Action]};
        false ->
            {ok, State}
    end;
handle_info({udp_passive, Sock}, #state{socket = Sock} = State) ->
    %% Re-arm delivery. NOTE(review): this always switches to
    %% `{active, true}', even if `bind_socket_opts' asked for `{active, N}'.
    _ = inet:setopts(Sock, [{active, true}]),
    {ok, State};
handle_info({udp_error, Sock, Why}, #state{socket = Sock} = State) ->
    {stop, {bind_socket_error, Why}, State};
handle_info({udp_closed, Sock}, #state{socket = Sock} = State) ->
    {stop, bind_socket_closed, State};
handle_info(_Unrelated, State) ->
    {ok, State}.

%%====================================================================
%% Public-address resolution
%%====================================================================

%% Decide which addresses go into `Proxy-Public-Address'.
%%
%% Precedence:
%% 1. `public_address_fun' (arity 1), called with the socket's sockname.
%% 2. The static `public_addresses' list.
%% 3. The socket's own sockname, via sockname_fallback/1.
%%
%% An empty result from either option is refused with
%% `{error, no_public_addresses}'.
resolve_public_addresses(Socket, Opts) ->
    FunLookup = maps:find(public_address_fun, Opts),
    ListLookup = maps:find(public_addresses, Opts),
    case {FunLookup, ListLookup} of
        {{ok, AddrFun}, _} when is_function(AddrFun, 1) ->
            addresses_from_fun(AddrFun, Socket);
        {error, {ok, []}} ->
            {error, no_public_addresses};
        {error, {ok, Static}} when is_list(Static) ->
            {ok, Static};
        {error, error} ->
            sockname_fallback(Socket)
    end.

%% Feed the bound sockname to the caller-supplied fun. An empty list is
%% refused; a non-list result crashes with case_clause.
addresses_from_fun(AddrFun, Socket) ->
    case inet:sockname(Socket) of
        {error, Why} ->
            {error, {sockname, Why}};
        {ok, SockName} ->
            case AddrFun(SockName) of
                [] -> {error, no_public_addresses};
                Computed when is_list(Computed) -> {ok, Computed}
            end
    end.

%% Last-resort public address: the socket's own sockname. This only
%% works when the socket is bound to a concrete interface. A wildcard
%% bind (0.0.0.0 or ::) is refused, because that address cannot be
%% advertised as a public address.
sockname_fallback(Socket) ->
    case inet:sockname(Socket) of
        {error, Why} ->
            {error, {sockname, Why}};
        {ok, {Addr, _Port} = Bound} ->
            case Addr of
                {0, 0, 0, 0}             -> {error, no_public_addresses};
                {0, 0, 0, 0, 0, 0, 0, 0} -> {error, no_public_addresses};
                _                        -> {ok, [Bound]}
            end
    end.

%% Response headers for a successful bind. The list holds two entries:
%% whatever masque_uri_udp_bind:format_bind_header/0 produces, and the
%% `Proxy-Public-Address' header built from Addresses.
response_headers(Addresses) ->
    BindHeader = masque_uri_udp_bind:format_bind_header(),
    PublicValue =
        masque_uri_udp_bind:format_proxy_public_address(Addresses),
    [BindHeader, {?MASQUE_HF_PROXY_PUBLIC_ADDRESS, PublicValue}].

%% Sorted, de-duplicated list of the IP families (4 / 6) present in an
%% address list. Entries that are not `{IP, Port}' pairs are skipped by
%% the comprehension's generator pattern.
families(Addresses) ->
    Found = [family_of(Addr) || {Addr, _Port} <- Addresses],
    lists:usort(Found).

%% 4-tuples are IPv4 and 8-tuples are IPv6; any other shape crashes.
family_of(IP) when tuple_size(IP) =:= 4 -> 4;
family_of(IP) when tuple_size(IP) =:= 8 -> 6.

%%====================================================================
%% Default policies
%%====================================================================

%% Default egress policy, used when the handler is not given a
%% `peer_filter_fun'. Returns `{drop, peer_filter}' for destinations
%% that is_private/1 classifies as non-public, and `ok' otherwise. The
%% port is not inspected. Loopback is allowed because CT scaffolding
%% runs the upstream peer on 127.0.0.1.
default_peer_filter(IP, _Port) ->
    case is_private(IP) of
        true  -> {drop, peer_filter};
        false -> ok
    end.

%% Return `true' when a destination must be rejected by the default
%% policy, and `false' when it may be contacted.
%%
%% IPv4: RFC 1918 (10/8, 172.16/12, 192.168/16), link-local (169.254/16)
%% and multicast (224/4) are rejected. Loopback (127/8) is allowed.
%%
%% IPv6: link-local (fe80::/10), multicast (ff00::/8) and unique-local
%% (fc00::/7) are rejected. ::1 is allowed. IPv4-mapped addresses
%% (::ffff:a.b.c.d) are classified by their embedded IPv4 address, so
%% they are subject to exactly the IPv4 rules above.
is_private({127, _, _, _})                       -> false;  %% loopback ok
is_private({10, _, _, _})                        -> true;
is_private({172, B, _, _}) when B >= 16, B =< 31 -> true;
is_private({192, 168, _, _})                     -> true;
is_private({169, 254, _, _})                     -> true;   %% link-local
is_private({A, _, _, _}) when A >= 224, A =< 239 -> true;   %% multicast
is_private({0, 0, 0, 0, 0, 0, 0, 1})             -> false;  %% v6 loopback
is_private({0, 0, 0, 0, 0, 16#FFFF, Hi, Lo}) ->
    %% ::ffff:0:0/96 - unwrap so the v4 rules cannot be bypassed.
    is_private({Hi bsr 8, Hi band 16#FF, Lo bsr 8, Lo band 16#FF});
is_private({A, _, _, _, _, _, _, _})
  when A >= 16#FE80, A =< 16#FEBF -> true;                  %% fe80::/10
is_private({A, _, _, _, _, _, _, _})
  when A >= 16#FF00, A =< 16#FFFF -> true;                  %% ff00::/8
is_private({A, _, _, _, _, _, _, _})
  when A >= 16#FC00, A =< 16#FDFF -> true;                  %% fc00::/7 ULA
is_private(_) -> false.

%% Identity scrub policy: every payload passes unchanged, and the user
%% state is returned untouched.
default_scrub(Payload, UserState) -> {pass, Payload, UserState}.