Skip to main content

src/masque_tcp_proxy_handler.erl

%%% @doc Built-in MASQUE handler that bridges a CONNECT-TCP tunnel to
%%% a real TCP connection on the server.
%%%
%%% For every accepted tunnel the handler resolves the target host,
%%% opens a `gen_tcp' connection, and relays bytes both ways:
%%%
%%% <ul>
%%%   <li>Client-to-target: `handle_data/2' writes bytes to the TCP socket.</li>
%%%   <li>Target-to-client: `{tcp, Socket, Bytes}' messages arrive on the
%%%       session and are emitted as `{send_data, Bytes}' actions.</li>
%%%   <li>Target FIN: `{tcp_closed, Socket}' closes the tunnel.</li>
%%% </ul>
%%%
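%%% Accepts the same policy hooks as the UDP proxy (`allow', `resolver',
%%% `family'), plus `connect_timeout' (default 5000 ms). Resolved
%%% addresses that are not public are refused unless `allow_private' is
%%% `true' (default `false'); extra `gen_tcp' options can be supplied via
%%% `socket_opts'.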
-module(masque_tcp_proxy_handler).
-behaviour(masque_handler).

-export([accept/1, init/2, handle_data/2, handle_eof/1,
         handle_info/2, terminate/2]).

%% Per-tunnel handler state, built by init/2 once the target connection
%% has been established.
-record(state, {
    %% Connected socket to the target; closed in terminate/2.
    socket    :: gen_tcp:socket(),
    %% Grace timer armed by handle_eof/1 and cancelled when the target
    %% closes; `undefined' until handle_eof/1 has run.
    eof_timer :: reference() | undefined
}).

%%====================================================================
%% Behaviour callbacks
%%====================================================================

%% @doc Admission check run before the tunnel is set up.
%%
%% Looks for an `allow' predicate inside the request's `handler_opts'
%% (an empty map when absent) and calls it with `{Host, Port}'. With no
%% predicate configured every target is admitted. The predicate must
%% return a boolean; any other value crashes with `case_clause'.
-spec accept(masque_handler:req()) -> masque_handler:accept_result().
accept(#{target_host := Host, target_port := Port} = Req) ->
    HandlerOpts = maps:get(handler_opts, Req, #{}),
    Allow = maps:get(allow, HandlerOpts, fun(_) -> true end),
    Target = {Host, Port},
    case Allow(Target) of
        false -> {reject, forbidden};
        true  -> accept
    end.

%% @doc Resolve the target and enforce the private-address policy, then
%% open the TCP connection to the target.
%%
%% Options read from `Opts':
%% <ul>
%%   <li>`resolver': 1-arity fun returning `{ok, IP} | {error, Reason}'
%%       (default `default_resolver/1').</li>
%%   <li>`family': `inet | inet6 | auto' (default `auto').</li>
%%   <li>`connect_timeout': milliseconds passed to `gen_tcp:connect/4'
%%       (default 5000).</li>
%%   <li>`allow_private': when `false' (default) the resolved address must
%%       satisfy `masque_ip:is_public/1'.</li>
%%   <li>`socket_opts': extra `gen_tcp' options appended after the
%%       handler's own `binary', `{active, true}' and family options.</li>
%% </ul>
%%
%% With `family => auto' the socket family follows the resolved address:
%% a 4-tuple uses `inet' and an 8-tuple uses `inet6'. A host name that only
%% resolves to IPv6 therefore no longer fails with `eafnosupport'. For any
%% other resolver result, the literal-host heuristic of `pick_family/2' is
%% used.
%%
%% Every failure returns `{stop, {resolution_failed, Detail}}', including
%% TCP connect errors. This error shape is unchanged from earlier versions.
-spec init(masque_handler:req(), term()) -> {ok, #state{}} | {stop, term()}.
init(#{target_host := Host, target_port := Port}, Opts) ->
    ResolverFun = maps:get(resolver, Opts, fun default_resolver/1),
    Requested = maps:get(family, Opts, auto),
    %% Evaluated before resolving so an invalid `family' option still
    %% fails fast, exactly as before.
    HostFamily = pick_family(Requested, Host),
    AllowPrivate = maps:get(allow_private, Opts, false),
    case resolve(ResolverFun, Host) of
        {ok, IP} ->
            case AllowPrivate orelse masque_ip:is_public(IP) of
                false ->
                    {stop, {resolution_failed, private_address}};
                true ->
                    Family = connect_family(Requested, IP, HostFamily),
                    connect_target(IP, Port, Family, Opts)
            end;
        {error, Reason} ->
            {stop, {resolution_failed, {resolve, Reason}}}
    end.

%% Choose the socket family for the connect: in `auto' mode trust the
%% shape of the resolved address; otherwise use the family from the host.
connect_family(auto, {_, _, _, _}, _HostFamily) -> inet;
connect_family(auto, {_, _, _, _, _, _, _, _}, _HostFamily) -> inet6;
connect_family(_Requested, _IP, HostFamily) -> HostFamily.

%% Open the TCP connection to the already-vetted address. Connect errors
%% keep the existing `resolution_failed' tag.
connect_target(IP, Port, Family, Opts) ->
    ConnTimeout = maps:get(connect_timeout, Opts, 5000),
    TcpOpts = [binary, {active, true}, Family
               | maps:get(socket_opts, Opts, [])],
    case gen_tcp:connect(IP, Port, TcpOpts, ConnTimeout) of
        {ok, Socket} ->
            {ok, #state{socket = Socket}};
        {error, Reason} ->
            {stop, {resolution_failed, {tcp_connect, Reason}}}
    end.

%% @doc Client-to-target direction: write `Data' to the target socket.
%%
%% Any send failure ends the tunnel. `{error, closed}' is reported as
%% `target_closed' and every other error as `{target_error, Reason}'.
-spec handle_data(binary(), #state{}) -> {ok, #state{}} | {stop, term(), #state{}}.
handle_data(Data, #state{socket = Sock} = State) ->
    send_outcome(gen_tcp:send(Sock, Data), State).

%% Translate a gen_tcp:send/2 result into a handler return value.
send_outcome(ok, State) ->
    {ok, State};
send_outcome({error, closed}, State) ->
    {stop, target_closed, State};
send_outcome({error, Reason}, State) ->
    {stop, {target_error, Reason}, State}.

%% @doc Client-side end of stream: half-close the target socket and start
%% a grace timer.
%%
%% Only the write direction is shut down, so bytes the target still sends
%% keep reaching handle_info/2. After 30 seconds an `eof_timeout' message
%% is sent to this process, and handle_info/2 stops the tunnel on it.
-spec handle_eof(#state{}) -> {ok, #state{}} | {stop, term(), #state{}}.
handle_eof(#state{socket = Sock} = State) ->
    %% Best effort: the shutdown result is intentionally ignored.
    _ = gen_tcp:shutdown(Sock, write),
    GraceMs = 30000,
    Timer = erlang:send_after(GraceMs, self(), eof_timeout),
    {ok, State#state{eof_timer = Timer}}.

%% @doc Target-to-client relay and socket lifecycle events.
%%
%% <ul>
%%   <li>`{tcp, Socket, Bytes}': emitted to the client as
%%       `{send_data, Bytes}'.</li>
%%   <li>`{tcp_closed, Socket}': cancels any pending EOF grace timer and
%%       stops with `target_closed'.</li>
%%   <li>`{tcp_error, Socket, Reason}': stops with
%%       `{target_error, Reason}'.</li>
%%   <li>`eof_timeout': the half-close grace period expired; stops with
%%       `eof_timeout'.</li>
%% </ul>
%% Socket events for a different socket, and any other message, are
%% ignored.
-spec handle_info(term(), #state{}) -> {ok, #state{}} | {ok, #state{}, [term()]} | {stop, term(), #state{}}.
handle_info({tcp, Sock, Payload}, #state{socket = Sock} = State) ->
    {ok, State, [{send_data, Payload}]};
handle_info({tcp_closed, Sock}, #state{socket = Sock, eof_timer = Timer} = State) ->
    ok = cancel_eof_timer(Timer),
    {stop, target_closed, State};
handle_info({tcp_error, Sock, Why}, #state{socket = Sock} = State) ->
    {stop, {target_error, Why}, State};
handle_info(eof_timeout, State) ->
    {stop, eof_timeout, State};
handle_info(_Unrelated, State) ->
    {ok, State}.

%% Cancel the EOF grace timer if one was armed; the cancel result is
%% deliberately discarded.
cancel_eof_timer(undefined) ->
    ok;
cancel_eof_timer(Timer) ->
    _ = erlang:cancel_timer(Timer),
    ok.

%% @doc Release the target socket when the tunnel ends, whatever the
%% reason.
-spec terminate(term(), #state{}) -> ok.
terminate(_Why, #state{socket = TargetSock}) ->
    %% The close result is deliberately discarded; this always returns ok.
    _ = gen_tcp:close(TargetSock),
    ok.

%%====================================================================
%% Helpers
%%====================================================================

%% Default resolver: accept an IP literal as-is; otherwise look the name
%% up, preferring an IPv4 address and falling back to IPv6. Binary hosts
%% are converted to strings first. Returns `{ok, IP} | {error, Reason}'.
default_resolver(Host) when is_binary(Host) ->
    default_resolver(binary_to_list(Host));
default_resolver(Host) when is_list(Host) ->
    case inet:parse_address(Host) of
        {ok, _} = Literal -> Literal;
        _NotLiteral       -> lookup_v4_then_v6(Host)
    end.

%% Resolve Name as IPv4; on any error, retry as IPv6 and return that result.
lookup_v4_then_v6(Name) ->
    case inet:getaddr(Name, inet) of
        {error, _} -> inet:getaddr(Name, inet6);
        {ok, _} = V4 -> V4
    end.

%% Map the configured `family' option to a gen_tcp family. An explicit
%% `inet'/`inet6' is returned unchanged. For `auto', the result is `inet6'
%% only when the host is an IPv6 literal; anything else, including host
%% names, gives `inet'.
pick_family(auto, Host) when is_binary(Host) ->
    pick_family(auto, binary_to_list(Host));
pick_family(auto, Host) ->
    case inet:parse_address(Host) of
        {ok, Addr} when tuple_size(Addr) =:= 8 -> inet6;
        _ -> inet
    end;
pick_family(Explicit, _Host) when Explicit =:= inet; Explicit =:= inet6 ->
    Explicit.

%% Apply the configured 1-arity resolver fun to the target host and
%% return its result unchanged.
resolve(Resolver, TargetHost) ->
    Resolver(TargetHost).