%%% @doc Default CONNECT-IP proxy handler (RFC 9484).
%%%
%%% Responsibilities:
%%% <ul>
%%% <li>Allocate addresses from the configured `address_pool' in
%%% response to `ADDRESS_REQUEST' capsules (first fit: the lowest free,
%%% prefix-aligned block).</li>
%%% <li>Emit the initial `ROUTE_ADVERTISEMENT' combining the
%%% configured static `routes' with any `resolved_addresses'
%%% populated by the listener's DNS step.</li>
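%%% <li>Reject requests whose resolved targets are not public (SSRF
%%% gate in `accept/1', via `masque_ip:is_public/1') unless
%%% `allow_private' is set.</li>
%%% <li>BCP-38 source-address filtering on inbound packets: the source
%%% must match an assignment the proxy made to this client.</li>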
%%% <li>Hand each accepted IP packet to the user-supplied
%%% `forward_fun' (default: drop).</li>
%%% </ul>
%%%
%%% Phase 2 replaces this handler with `masque_ip_tun_proxy_handler'
%%% which owns a TUN device instead of a forwarder-fun.
-module(masque_ip_proxy_handler).
-behaviour(masque_handler).
-export([accept/1, init/2,
handle_ip_packet/2,
handle_address_request/2,
handle_address_assign/2,
handle_route_advertisement/2,
terminate/2]).
%% Public helpers reused by downstream consumers (TUN/router) so they
%% emit the same drop counter as the default handler (and, through
%% emit_drop/3, the same `packet_dropped' lifecycle event).
-export([emit_drop/2, emit_drop/3]).
-include("masque_ip.hrl").
%% Per-session handler state, built by init/2.
-record(state, {
%% Handler opts exactly as passed to init/2.
opts :: map(),
%% The request's `resolved_addresses' (populated by the listener's DNS step).
resolved = [] :: [inet:ip_address()],
%% Allocation pools as normalized by normalize_pools/1: always a list of
%% `#ip_route{}' ranges, even though the `address_pool' opt itself may be
%% a single prefix/route or a list of them.
pools = [] :: [#ip_route{}],
%% Prefixes handed out to this session so far, newest first, as
%% {Version, Address, PrefixLen} triples.
assigned = [] :: [ip_assignment_tuple()],
%% Negotiated URI scope: target / ipproto from the request line.
%% `'*'' on either axis means "any" and skips the per-packet check.
target = '*' :: masque_uri_ip:ip_target(),
ipproto = '*' :: masque_uri_ip:ip_ipproto()
}).
%% One entry of `#state.assigned': IP version, first address of the
%% assigned block, and its prefix length.
-type ip_assignment_tuple() ::
{4, inet:ip4_address(), 0..32}
| {6, inet:ip6_address(), 0..128}.
%%====================================================================
%% accept — SSRF gate using the resolved-addresses list
%%====================================================================
%% @doc SSRF gate. Admit the request when the handler opts set
%% `allow_private => true', when no target address was resolved, or
%% when every resolved address is public; otherwise reject it as
%% `forbidden'.
accept(Req) ->
    HandlerOpts = maps:get(handler_opts, Req, #{}),
    AllowPrivate = maps:get(allow_private, HandlerOpts, false),
    Resolved = maps:get(resolved_addresses, Req, []),
    Permitted = AllowPrivate
        orelse Resolved =:= []
        orelse lists:all(fun masque_ip:is_public/1, Resolved),
    case Permitted of
        false -> {reject, forbidden};
        true -> accept
    end.
%%====================================================================
%% init — publish the initial ROUTE_ADVERTISEMENT
%%====================================================================
%% @doc Build the per-session state from the request and handler opts.
%% When there is anything to announce, also request an initial
%% ROUTE_ADVERTISEMENT. It is made of the static `routes' opt plus one
%% single-address route per resolved target, sorted and de-duplicated.
init(Req, Opts) ->
    Resolved = maps:get(resolved_addresses, Req, []),
    Pools = normalize_pools(maps:get(address_pool, Opts, [])),
    Static = maps:get(routes, Opts, []),
    PerTarget = [route_for(Addr) || Addr <- Resolved],
    Advertised = lists:usort(Static ++ PerTarget),
    State = #state{opts = Opts,
                   resolved = Resolved,
                   pools = Pools,
                   target = maps:get(ip_target, Req, '*'),
                   ipproto = maps:get(ip_ipproto, Req, '*')},
    advertise_initial(Advertised, State).

%% Produce the init/2 result. Only a non-empty route list bumps the
%% advertise counter, fires `route_advertised' and emits the action.
advertise_initial([], State) ->
    {ok, State};
advertise_initial([_ | _] = Routes, #state{opts = Opts} = State) ->
    masque_metrics:ip_advertise_inc(),
    invoke_lifecycle(Opts, route_advertised, #{routes => Routes}, Opts),
    {ok, State, [{advertise, Routes}]}.
%%====================================================================
%% ADDRESS_REQUEST — first-fit allocator over the configured pools
%%====================================================================
%% @doc Answer an ADDRESS_REQUEST capsule. The entries produced by
%% allocate_or_reject/2 (assignments and/or rejections) are sent back
%% in a single ADDRESS_ASSIGN action.
%% NOTE(review): RFC 9484 describes ADDRESS_ASSIGN as carrying the full
%% set of current assignments. This returns only the entries for
%% `Requests'; confirm that a second request does not implicitly
%% withdraw earlier assignments on the client side.
handle_address_request(Requests, #state{} = State0) ->
    {Entries, State} = allocate_or_reject(Requests, State0),
    Actions = [{assign, Entries}],
    {ok, State, Actions}.
%% Bump the assignment counter and fire the `address_assigned'
%% lifecycle event for one freshly allocated entry.
emit_assigned(#ip_assignment{} = Entry, #state{opts = Opts}) ->
    #ip_assignment{version = Version,
                   address = Addr,
                   prefix_len = PrefixLen} = Entry,
    Event = #{version => Version, address => Addr,
              prefix_len => PrefixLen, entry => Entry},
    masque_metrics:ip_assign_inc(),
    invoke_lifecycle(Opts, address_assigned, Event, Opts).
%% Resolve every ADDRESS_REQUEST entry.
%%
%% With no pool configured, all of them are rejected through
%% masque_ip:reject_requests/1. Otherwise each one goes to
%% allocate_one/2, and the state is threaded through so later requests
%% see earlier allocations.
allocate_or_reject(Requests, #state{pools = Pools} = State) ->
    case Pools of
        [] ->
            {masque_ip:reject_requests(Requests), State};
        _ ->
            lists:mapfoldl(fun allocate_one/2, State, Requests)
    end.
%% Allocate one prefix for a single ADDRESS_REQUEST entry.
%%
%% The prefix length comes from effective_prefix/3. RFC 9484 §4.6 lets
%% the proxy answer with the requested length or a more specific one,
%% and `min_assignable_prefix' caps how wide it may be. On success the
%% block is registered with the session registry and `address_assigned'
%% fires. When the pool has no room, a single rejection entry (all-zero
%% address, full-length prefix) is returned and the state is left
%% untouched.
allocate_one(#ip_prefix_request{request_id = Id, version = V,
                                prefix_len = ReqPfx},
             #state{opts = Opts} = State) ->
    Pfx = effective_prefix(V, ReqPfx, Opts),
    case next_free(V, Pfx, State) of
        none ->
            {reject_one(Id, V), State};
        {ok, Addr, NewState} ->
            Entry = #ip_assignment{request_id = Id, version = V,
                                   address = Addr, prefix_len = Pfx},
            register_with_registry(V, Addr, Pfx, NewState),
            emit_assigned(Entry, NewState),
            {Entry, NewState}
    end.

%% Rejection entry for request `Id': the request is echoed with an
%% all-zero address and the family's maximum prefix length.
reject_one(Id, V) ->
    Probe = #ip_prefix_request{request_id = Id, version = V,
                               address = zero_addr(V),
                               prefix_len = max_prefix(V)},
    [Reject] = masque_ip:reject_requests([Probe]),
    Reject.
%% Choose the prefix length to hand out for a request of `ReqPfx' bits
%% in IP family `V'.
%%
%% A non-integer or out-of-range request falls back to a full-length
%% (single-address) prefix. The result is then clamped into
%% [min_assignable(V, Opts), max_prefix(V)]. RFC 9484 lets the proxy
%% answer with a longer (more specific) prefix than requested, but
%% never a wider one than the configured minimum.
%%
%% The upper clamp matters when `min_assignable_prefix' is one bare
%% integer shared by both families (e.g. 64). Without it, an IPv4
%% request would get a length above 32. The allocator's stride
%% (1 bsl (Max - Pfx)) would then be 0, and align_up/2 would crash
%% with function_clause.
effective_prefix(V, ReqPfx, Opts) ->
    Min = min_assignable(V, Opts),
    Max = max_prefix(V),
    Requested =
        case is_integer(ReqPfx) andalso ReqPfx >= 0
                 andalso ReqPfx =< Max of
            true -> ReqPfx;
            false -> Max
        end,
    erlang:min(erlang:max(Requested, Min), Max).
%% Widest prefix length the proxy will assign for family `V', read from
%% the `min_assignable_prefix' opt. The opt is either one integer for
%% every family or a map keyed by IP version; a missing opt or missing
%% map key falls back to the family's maximum (single addresses only).
min_assignable(V, Opts) ->
    FullLength = max_prefix(V),
    case maps:get(min_assignable_prefix, Opts, undefined) of
        PerFamily when is_map(PerFamily) ->
            maps:get(V, PerFamily, FullLength);
        Bits when is_integer(Bits) ->
            Bits;
        undefined ->
            FullLength
    end.
%% Record an assigned prefix via masque_ip_session_registry:register/5
%% together with the owning session pid and IP datagram context id.
%% The handler runs inside the session process, so `self()' is the
%% default pid. ?MASQUE_CONTEXT_ID_IP is the default context. The opts
%% `session_pid' and `ip_context_id' override both (embedded uses).
%% NOTE(review): the registry's reply is discarded. `assigned' is
%% per-session state, so two sessions sharing one pool could pick the
%% same address. Confirm whether the registry rejects duplicates and
%% whether that result should feed back into allocation.
register_with_registry(V, Addr, Pfx, #state{opts = Opts}) ->
    SessionPid = maps:get(session_pid, Opts, self()),
    ContextId = maps:get(ip_context_id, Opts, ?MASQUE_CONTEXT_ID_IP),
    _ = masque_ip_session_registry:register(V, Addr, Pfx,
                                            SessionPid, ContextId),
    ok.
%% Find the lowest free, `Pfx'-aligned block of family `V' across the
%% configured pools, and record it in `assigned'.
%%
%% Pools are tried in configuration order. When one is exhausted, the
%% search moves on to the next pool of the same family instead of
%% giving up, so an `address_pool' given as a list of prefixes is fully
%% usable. Returns `{ok, Addr, NewState}', or `none' when every
%% matching pool is full (or no pool exists for `V').
next_free(V, Pfx, #state{pools = Pools, assigned = Assigned} = S) ->
    case first_free(V, Pfx, Pools, Assigned) of
        {ok, Addr} ->
            {ok, Addr, S#state{assigned = [{V, Addr, Pfx} | Assigned]}};
        none ->
            none
    end.

%% Walk the pools of family `V' in order and return the first free block.
first_free(V, Pfx, Pools, Assigned) ->
    case pick_pool(V, Pools) of
        undefined ->
            none;
        Pool ->
            %% Continue after the pool pick_pool/2 just returned; it is
            %% the first element equal to `Pool'.
            [Pool | Rest] = lists:dropwhile(fun(P) -> P =/= Pool end,
                                            Pools),
            case iter_pool(V, Pfx, Pool, Assigned) of
                {ok, _Addr} = Found -> Found;
                exhausted -> first_free(V, Pfx, Rest, Assigned)
            end
    end.
%% First pool in `Pools' whose IP version is `V', or `undefined'.
%% Elements that are not `#ip_route{}' records are skipped.
pick_pool(V, Pools) ->
    IsFamily = fun(#ip_route{version = PV}) -> PV =:= V;
                  (_) -> false
               end,
    case lists:search(IsFamily, Pools) of
        {value, Pool} -> Pool;
        false -> undefined
    end.
%% Search one pool range for the lowest free block of 2^(Max - Pfx)
%% addresses whose first address is aligned to the block size.
%% Returns `{ok, Addr}' or `exhausted'.
iter_pool(V, Pfx, #ip_route{start_addr = First, end_addr = Last},
          Assigned) ->
    BlockSize = 1 bsl (max_prefix(V) - Pfx),
    Lo = align_up(addr_to_int(V, First), BlockSize),
    Hi = addr_to_int(V, Last),
    iter_range_strided(V, Lo, Hi, Pfx, BlockSize, Assigned).
%% Walk candidate blocks [Cur, Cur + Stride - 1] upward in integer
%% address space. Returns `{ok, Addr}' for the first block that ends
%% at or before `End' and does not overlap an existing assignment, or
%% `exhausted' once the next block would pass `End'. `Pfx' is only
%% carried along through the recursion.
iter_range_strided(V, Cur, End, Pfx, Stride, Assigned) when Cur =< End ->
    BlockEnd = Cur + Stride - 1,
    if
        BlockEnd > End ->
            exhausted;
        true ->
            case overlaps_assigned(V, Cur, BlockEnd, Assigned) of
                false ->
                    {ok, int_to_addr(V, Cur)};
                true ->
                    iter_range_strided(V, Cur + Stride, End, Pfx, Stride,
                                       Assigned)
            end
    end;
iter_range_strided(_V, _Cur, _End, _Pfx, _Stride, _Assigned) ->
    exhausted.
%% True when the integer range [Lo, Hi] of family `V' intersects any
%% block already assigned for that family. Entries of other families
%% are ignored.
overlaps_assigned(V, Lo, Hi, Assigned) ->
    Bits = max_prefix(V),
    Intersects =
        fun({Family, Addr, Len}) when Family =:= V ->
                BlockLo = addr_to_int(V, Addr),
                BlockHi = BlockLo + (1 bsl (Bits - Len)) - 1,
                erlang:max(Lo, BlockLo) =< erlang:min(Hi, BlockHi);
           (_) ->
                false
        end,
    lists:any(Intersects, Assigned).
%% Round `N' up to the next multiple of `Stride'. This assumes `Stride'
%% is a power of two (it is always `1 bsl K' here); the result is the
%% bitmask trick, not a general ceiling.
align_up(N, Stride) when Stride > 0 ->
    LowBits = Stride - 1,
    (bnot LowBits) band (N + LowBits).
%% Big-endian integer value of an IPv4 (4 x 8-bit) or IPv6 (8 x 16-bit)
%% address tuple.
addr_to_int(4, {_, _, _, _} = Addr) ->
    fold_segments(8, Addr);
addr_to_int(6, {_, _, _, _, _, _, _, _} = Addr) ->
    fold_segments(16, Addr).

%% Concatenate the tuple's elements, most significant first, `Width'
%% bits apiece.
fold_segments(Width, Addr) ->
    lists:foldl(fun(Seg, Acc) -> (Acc bsl Width) bor Seg end,
                0, tuple_to_list(Addr)).
%% Inverse of addr_to_int/2: split integer `N' into an IPv4 (4 x 8-bit)
%% or IPv6 (8 x 16-bit) address tuple, most significant segment first.
int_to_addr(4, N) ->
    split_segments(N, 8, 4);
int_to_addr(6, N) ->
    split_segments(N, 16, 8).

%% Take the low `Count' segments of `Width' bits from `N', high to low.
split_segments(N, Width, Count) ->
    Mask = (1 bsl Width) - 1,
    list_to_tuple([(N bsr (Width * I)) band Mask
                   || I <- lists:seq(Count - 1, 0, -1)]).
%% Turn an `address_pool' option (a prefix tuple, an `#ip_route{}', or
%% a list of those) into the list of `#ip_route{}' ranges the allocator
%% draws from.
normalize_pools(Pool) ->
    Specs = case is_list(Pool) of
                true -> Pool;
                false -> [Pool]
            end,
    [normalize_pool(Spec) || Spec <- Specs].
%% Normalize one pool spec into an `#ip_route{}' spanning the whole
%% prefix (ip_protocol 0). Routes pass through untouched. Prefix tuples
%% are `{4, Ip4, 0..32}' or `{6, Ip6, 0..128}'; anything else fails
%% with function_clause.
normalize_pool(#ip_route{} = Route) ->
    Route;
normalize_pool({4, {_, _, _, _} = Addr, Len}) when Len >= 0, Len =< 32 ->
    pool_route(4, prefix_range_v4(Addr, Len));
normalize_pool({6, Addr, Len}) when tuple_size(Addr) =:= 8,
                                    Len >= 0, Len =< 128 ->
    pool_route(6, prefix_range_v6(Addr, Len)).

%% Route record for an inclusive [Start, End] address range.
pool_route(V, {Start, End}) ->
    #ip_route{version = V, start_addr = Start, end_addr = End,
              ip_protocol = 0}.
%% First and last address of the IPv4 prefix `Addr'/`Pfx' (host bits
%% cleared / set).
prefix_range_v4(Addr, Pfx) ->
    AddrInt = ip_int(Addr, 32),
    HostMask = (1 bsl (32 - Pfx)) - 1,
    Network = AddrInt band (bnot HostMask),
    {int_ip(Network, 32, 4), int_ip(Network bor HostMask, 32, 4)}.
%% First and last address of the IPv6 prefix `Addr'/`Pfx' (host bits
%% cleared / set).
prefix_range_v6(Addr, Pfx) ->
    AddrInt = ip_int(Addr, 128),
    HostMask = (1 bsl (128 - Pfx)) - 1,
    Network = AddrInt band (bnot HostMask),
    {int_ip(Network, 128, 6), int_ip(Network bor HostMask, 128, 6)}.
%% Big-endian integer of an address tuple; the second argument is the
%% address width in bits (32 for IPv4, 128 for IPv6).
ip_int({A, B, C, D}, 32) ->
    lists:foldl(fun(Octet, Acc) -> (Acc bsl 8) bor Octet end,
                0, [A, B, C, D]);
ip_int({A, B, C, D, E, F, G, H}, 128) ->
    lists:foldl(fun(Hextet, Acc) -> (Acc bsl 16) bor Hextet end,
                0, [A, B, C, D, E, F, G, H]).
%% Address tuple for the low 32 (IPv4) or 128 (IPv6) bits of `N'.
int_ip(N, 32, 4) ->
    list_to_tuple([(N bsr Shift) band 16#FF
                   || Shift <- [24, 16, 8, 0]]);
int_ip(N, 128, 6) ->
    list_to_tuple([(N bsr Shift) band 16#FFFF
                   || Shift <- [112, 96, 80, 64, 48, 32, 16, 0]]).
%% All-zero address tuple of IP family `V' (used in rejection entries).
zero_addr(4) -> erlang:make_tuple(4, 0);
zero_addr(6) -> erlang:make_tuple(8, 0).
%% Address width in bits for IP family `V', i.e. its longest prefix.
max_prefix(4) -> 4 * 8;    % four octets
max_prefix(6) -> 8 * 16.   % eight 16-bit groups
%%====================================================================
%% Data-plane: forward_fun
%%====================================================================
%% @doc Data-plane entry for one IP packet from the client. Packets that
%% fail accept_inbound/2 are dropped with telemetry (reason from the
%% failing check); the rest go to forward/2.
handle_ip_packet(Packet, #state{opts = Opts} = S) ->
    Verdict = accept_inbound(Packet, S),
    case Verdict of
        {drop, Reason} ->
            emit_drop(Reason, drop_detail(Packet), Opts),
            {ok, S};
        ok ->
            forward(Packet, S)
    end.
%% RFC 9484 §5: the proxy MUST drop packets that fail BCP-38 source
%% filtering or fall outside the negotiated `target' / `ipproto' scope.
%% The source filter runs first (`{drop, bcp38}'). Otherwise the
%% scope check's error reason is passed through as `{drop, Reason}', so
%% telemetry can attribute the cause.
accept_inbound(Packet, #state{target = Target, ipproto = IPProto} = State) ->
    case src_filter_passes(Packet, State) of
        true ->
            case masque_ip_packet:scope_check(Packet, Target, IPProto) of
                {error, Why} -> {drop, Why};
                ok -> ok
            end;
        false ->
            {drop, bcp38}
    end.
%% Hand an accepted packet to the `forward_fun' (arity 2) from the
%% handler opts, and translate its result into a handler return. With
%% no `forward_fun' configured, the packet goes nowhere (no telemetry)
%% and the state is kept.
%%
%% Accepted forward_fun results:
%%   {reply, Pkt, S2}     send Pkt back to the client
%%   {drop, S2}           count a `forward_drop'
%%   {forward, S2}        state update only
%%   ok | {error, _}      keep the current state
%%   {actions, List, S2}  action list handled by process_forward_actions/3;
%%                        recognised actions match the IP-server-session
%%                        interpreter: {send_ip_packet, binary()},
%%                        {icmp_error, {Kind, Spec, Invoking}},
%%                        {drop, atom()} (telemetry only)
forward(Packet, #state{opts = Opts} = S) ->
    case maps:find(forward_fun, Opts) of
        error ->
            {ok, S};
        {ok, Fun} when is_function(Fun, 2) ->
            interpret_forward(Fun(Packet, S), Packet, Opts, S)
    end.

%% Map one forward_fun result onto the handler return value.
interpret_forward(Result, Packet, Opts, S) ->
    case Result of
        {reply, Reply, S2} ->
            {ok, S2, [{send_ip_packet, Reply}]};
        {drop, S2} ->
            emit_drop(forward_drop, drop_detail(Packet), Opts),
            {ok, S2};
        {forward, S2} ->
            {ok, S2};
        {actions, Actions, S2} when is_list(Actions) ->
            {Wire, _Drops} = process_forward_actions(Actions, Packet, Opts),
            {ok, S2, Wire};
        ok ->
            {ok, S};
        {error, _} ->
            {ok, S}
    end.
%% Interpret a forward_fun action list.
%%
%% `{drop, Reason}' entries are telemetry only: each one goes through
%% emit_drop/3 and is collected, most recent first, in the second
%% element of the result. Every other action is kept, in its original
%% order, as a wire action for the session (first element).
%%
%% Wire actions are accumulated in reverse and reversed once at the
%% end; the previous `Wire ++ [Action]' per element was quadratic.
process_forward_actions(Actions, Packet, Opts) ->
    {RevWire, Drops} =
        lists:foldl(
          fun({drop, Reason}, {WireAcc, DropAcc}) ->
                  emit_drop(Reason, drop_detail(Packet), Opts),
                  {WireAcc, [Reason | DropAcc]};
             (Action, {WireAcc, DropAcc}) ->
                  {[Action | WireAcc], DropAcc}
          end, {[], []}, Actions),
    {lists:reverse(RevWire), Drops}.
%%====================================================================
%% Drop emit / lifecycle hook
%%====================================================================
%% @doc Bump the drop counter for `Reason'. Public so a TUN/router
%% consumer that runs its own data path can drive the same telemetry
%% without re-implementing it. This arity has no handler opts, so no
%% `lifecycle_fun' can be found and only the counter moves. Use
%% emit_drop/3 to also fire the `packet_dropped' lifecycle event.
-spec emit_drop(atom(), map()) -> ok.
emit_drop(Reason, Detail) ->
    emit_drop(Reason, Detail, #{}).
%% @doc Bump the drop counter for `Reason', then fire `packet_dropped'
%% (with `Detail' extended by `reason') through the `lifecycle_fun' in
%% `Opts', if one is configured.
-spec emit_drop(atom(), map(), map()) -> ok.
emit_drop(Reason, Detail, Opts) ->
    masque_metrics:ip_drop_inc(Reason),
    DropEvent = Detail#{reason => Reason},
    invoke_lifecycle(Opts, packet_dropped, DropEvent, Opts).
%% Detail map attached to drop telemetry: the packet size in bytes.
drop_detail(Packet) when is_binary(Packet) ->
    Size = byte_size(Packet),
    #{packet_size => Size}.
%% Call the optional `lifecycle_fun' (arity 2) from `Opts' with `Event'
%% and `Detail'. Anything else stored under that key is ignored. The
%% first argument is unused; it is kept for the existing call sites.
%% Always returns `ok'.
%%
%% The hook is user-supplied and must not take the session down, so its
%% exceptions are still contained. They are now logged instead of
%% silently discarded, so a broken hook becomes visible. A hook that
%% keeps failing logs once per event, including per dropped packet.
invoke_lifecycle(_Carrier, Event, Detail, Opts) ->
    case maps:find(lifecycle_fun, Opts) of
        {ok, Fun} when is_function(Fun, 2) ->
            try Fun(Event, Detail) of
                _ -> ok
            catch
                Class:Reason:Stack ->
                    logger:warning(
                      "masque_ip_proxy_handler: lifecycle_fun failed on "
                      "~p: ~p:~p~n~p",
                      [Event, Class, Reason, Stack]),
                    ok
            end;
        _ ->
            ok
    end.
%% BCP-38-style source check for a packet sent by the client.
%%
%% Before anything has been assigned, the `allow_private' handler opt
%% decides (default `true', i.e. no filtering). Once the client holds
%% assignments, the packet's source address must lie inside one of the
%% prefixes assigned for its IP version. Assignments may be shorter
%% than /32 or /128 (see effective_prefix/3), and the client may use
%% any address in an assigned prefix. The check is therefore prefix
%% containment; exact equality would drop every packet not sourced from
%% the prefix's first address. Truncated or non-IPv4/IPv6 packets fail.
%% NOTE(review): the pre-assignment default (`true') differs from the
%% `false' default accept/1 uses for the same opt; confirm intended.
src_filter_passes(_Packet, #state{opts = Opts, assigned = []}) ->
    maps:get(allow_private, Opts, true);
src_filter_passes(<<4:4, _/bitstring>> = Packet, #state{assigned = Assigned}) ->
    case Packet of
        <<_:12/binary, Src:32, _/binary>> ->
            src_in_assigned(4, Src, Assigned);
        _ ->
            false
    end;
src_filter_passes(<<6:4, _/bitstring>> = Packet, #state{assigned = Assigned}) ->
    case Packet of
        <<_:8/binary, Src:128, _/binary>> ->
            src_in_assigned(6, Src, Assigned);
        _ ->
            false
    end;
src_filter_passes(_, _) ->
    false.

%% True when integer source address `Src' lies inside any prefix
%% assigned for IP version `V'.
src_in_assigned(V, Src, Assigned) ->
    Max = max_prefix(V),
    lists:any(
      fun({V0, Addr, Pfx}) when V0 =:= V ->
              Base = addr_to_int(V, Addr),
              Src >= Base andalso Src < Base + (1 bsl (Max - Pfx));
         (_) ->
              false
      end, Assigned).
%%====================================================================
%% Peer-initiated control-plane (bidirectional per §8.2)
%%====================================================================
%% Peer-sent ADDRESS_ASSIGN capsules are accepted and ignored; the
%% state is returned unchanged.
handle_address_assign(_Entries, State) ->
    {ok, State}.
%% Peer-sent ROUTE_ADVERTISEMENT capsules are accepted and ignored; the
%% state is returned unchanged.
handle_route_advertisement(_Entries, State) ->
    {ok, State}.
%% Session teardown: release every prefix this session was assigned
%% (newest first), firing `address_released' for each.
terminate(_Reason, #state{assigned = Assigned, opts = Opts}) ->
    lists:foreach(fun(Entry) -> release_one(Entry, Opts) end, Assigned),
    ok.
%% Remove one {Version, Address, PrefixLen} assignment from the session
%% registry (result ignored), then fire `address_released'.
release_one({V, Addr, Pfx}, Opts) ->
    _ = masque_ip_session_registry:release(V, Addr, Pfx),
    Released = #{version => V, address => Addr, prefix_len => Pfx},
    invoke_lifecycle(Opts, address_released, Released, Opts),
    ok.
%%====================================================================
%% Helpers
%%====================================================================
%% Single-address route covering exactly `Addr', with ip_protocol 0
%% (in RFC 9484 route encoding, 0 stands for all protocols).
route_for({_, _, _, _} = Addr) ->
    host_route(4, Addr);
route_for({_, _, _, _, _, _, _, _} = Addr) ->
    host_route(6, Addr).

%% Route record whose range starts and ends at `Addr'.
host_route(V, Addr) ->
    #ip_route{version = V, start_addr = Addr, end_addr = Addr,
              ip_protocol = 0}.