Skip to main content

src/masque_ip_session_registry.erl

%%% @doc Cross-session address registry for CONNECT-IP.
%%%
%%% Stores `{Version, StartAddr, EndAddr, Prefix} -> {SessionPid,
%%% ContextId}' assignments so an out-of-band consumer (e.g. a TUN
%%% device owner) can answer "which session serves destination X?"
%%% without reaching into per-session state.
%%%
%%% Storage is a single ETS `ordered_set' keyed by
%%% `{Version, StartIntAddr}'. Lookup does interval inclusion: find
%%% the entry with the largest start address `=<' the probe and
%%% verify the end address `>=' the probe. Because registrations are
%%% rejected when their range overlaps an existing entry, at most one
%%% interval can cover any given address.
%%%
%%% The gen_server itself is only on the write path (register /
%%% release / monitor handling). Lookups go directly through ETS so
%%% read traffic does not serialise on the server.
%%%
%%% All write APIs are tolerant of the registry not being started -
%%% they become no-ops when `whereis(?MODULE) =:= undefined'. This
%%% keeps test environments that don't boot the application from
%%% needing extra setup, and keeps the proxy handler simple (it
%%% always calls `register/5' / `release/3' without a feature flag).
-module(masque_ip_session_registry).
-behaviour(gen_server).

-export([start_link/0]).
-export([register/5, release/3, release_pid/1, lookup/1, all/0]).

-export([init/1, handle_call/3, handle_cast/2, handle_info/2,
         terminate/2, code_change/3]).

%% Name of the ETS table holding the registrations. init/1 creates it as
%% a `named_table', so both the server and direct readers (lookup/1,
%% all/0) address it by this name.
-define(TABLE, ?MODULE).

%% Server state. `monitors' maps each monitor reference created when a
%% range is registered to the `{Version, StartInt, Prefix}' ranges
%% recorded under it; the 'DOWN' handler uses it to know which ETS rows
%% to delete. In the current code every registration creates its own
%% monitor, so each list holds a single range.
-record(state, {monitors = #{} :: #{reference() => [tuple()]}}).

%%====================================================================
%% Public API
%%====================================================================

%% Start the registry, registered locally under the module name - the
%% name every public API function probes with whereis/1.
start_link() ->
    ServerName = {local, ?MODULE},
    gen_server:start_link(ServerName, ?MODULE, [], []).

%% @doc Register a host address or prefix as served by `SessionPid'.
%%
%% `Pfx' is the prefix length in bits (32 for a single IPv4 host, 128
%% for a single IPv6 host). Returns `ok' when the range was recorded and
%% `{error, conflict}' when it overlaps an existing registration. An
%% exact repeat of a range that is already registered is also reported
%% as `{error, conflict}', even when it comes from the same pid and
%% context id: the registry deliberately keeps that contract strict
%% rather than treating repeats as idempotent. Malformed input is
%% rejected by the server with another `{error, Reason}' reply
%% (e.g. `{error, bad_address}').
%%
%% When the registry process is not running the call is a no-op that
%% returns `ok'.
-spec register(4 | 6,
               inet:ip4_address() | inet:ip6_address(),
               non_neg_integer(),
               pid(),
               non_neg_integer()) ->
          ok | {error, conflict | bad_address | bad_prefix}.
register(V, Addr, Pfx, Pid, ContextId)
  when (V =:= 4 orelse V =:= 6),
       is_integer(Pfx), is_pid(Pid), is_integer(ContextId) ->
    case whereis(?MODULE) of
        undefined -> ok;
        _Server   ->
            gen_server:call(?MODULE,
                            {register, V, Addr, Pfx, Pid, ContextId})
    end.

%% @doc Drop the registration for `Addr'/`Pfx' of IP version `V'.
%% Always returns `ok': whether or not such a registration existed, and
%% also when the registry process is not running.
-spec release(4 | 6,
              inet:ip4_address() | inet:ip6_address(),
              non_neg_integer()) -> ok.
release(V, Addr, Pfx) when V =:= 4 orelse V =:= 6 ->
    Registry = whereis(?MODULE),
    if
        Registry =:= undefined -> ok;
        true -> gen_server:call(?MODULE, {release, V, Addr, Pfx})
    end.

%% @doc Drop every range currently owned by `Pid'. The registry already
%% performs equivalent cleanup on its own when a monitored session sends
%% 'DOWN'; this entry point is for tests and explicit cleanup. Returns
%% `ok', including when the registry is not running.
-spec release_pid(pid()) -> ok.
release_pid(Pid) when is_pid(Pid) ->
    Registry = whereis(?MODULE),
    if
        Registry =:= undefined -> ok;
        true -> gen_server:call(?MODULE, {release_pid, Pid})
    end.

%% @doc Return the session whose registered range contains `IP'. Reads
%% the ETS table directly, without a round trip through the server.
%% Yields `not_found' when nothing covers `IP', when `IP' is not an
%% address tuple, or when the table does not exist.
-spec lookup(inet:ip_address()) ->
    {ok, pid(), non_neg_integer()} | not_found.
lookup(IP) ->
    try
        Normalised = normalise(IP),
        case Normalised of
            error -> not_found;
            {Version, Int} -> lookup_int(Version, Int)
        end
    catch
        %% ETS raises badarg when the named table is missing, i.e. the
        %% registry is not running.
        error:badarg -> not_found
    end.

%% @doc Diagnostic snapshot of every registration as
%% `{Version, StartInt, EndInt, Prefix, Pid, ContextId}' tuples, in
%% table key order. Returns `[]' when the table does not exist.
-spec all() -> [{4 | 6, non_neg_integer(), non_neg_integer(),
                 non_neg_integer(), pid(), non_neg_integer()}].
all() ->
    %% Only the table read can fail (badarg: table missing); the
    %% projection below cannot raise, so it sits outside the protected
    %% part of the try.
    try ets:tab2list(?TABLE) of
        Rows ->
            [{V, Start, End, Pfx, Pid, Ctx}
             || {{V, Start}, End, Pfx, Pid, Ctx, _MRef} <- Rows]
    catch
        error:badarg -> []
    end.

%%====================================================================
%% gen_server callbacks
%%====================================================================

%% Create the registration table and start with no monitors.
%%
%% The table is `protected': only this process (the owner) may write it.
%% That matches how the module already behaves - every insert and delete
%% happens inside this server's callbacks. lookup/1 and all/0 can still
%% read it from any process without a server round trip. A `public'
%% table would let any process write rows directly, bypassing the
%% overlap check and the monitor bookkeeping that keep the table
%% consistent. With a single writer there is no write contention, so
%% only read concurrency is enabled.
init([]) ->
    _ = ets:new(?TABLE,
                [ordered_set, named_table, protected,
                 {read_concurrency, true}]),
    {ok, #state{}}.

%% Synchronous write path. Every ETS insert/delete happens in this
%% process (here or in handle_info/2), which keeps the "no two
%% registrations overlap" invariant intact.
%%
%% register: the address is first reduced to its prefix (host bits
%% cleared), so the stored interval is exactly the prefix's address
%% range. Replies:
%%   - `{error, bad_address}' for an unparsable address;
%%   - `{error, bad_prefix}' for a prefix length outside 0..32 (IPv4)
%%     or 0..128 (IPv6). This is checked here so a bad caller argument
%%     cannot make range_end/3 raise `function_clause' and take the
%%     registry, and the ETS table it owns, down with it;
%%   - `{error, conflict}' for an overlapping range, including an exact
%%     repeat.
%%
%% release: applies the same address/prefix reduction as register, so
%% the arguments that registered a range also release it. Always `ok'.
%%
%% release_pid: drops every range owned by the given pid. Always `ok'.
handle_call({register, V, Addr, Pfx, Pid, Ctx}, _From,
            #state{monitors = Mons} = S) ->
    case prefix_bounds(V, Addr, Pfx) of
        {error, _Reason} = Error ->
            {reply, Error, S};
        {ok, Start, End} ->
            case overlap_check(V, Start, End) of
                conflict ->
                    {reply, {error, conflict}, S};
                ok ->
                    %% Each registration gets its own monitor. The
                    %% reference is stored in the row and keys the
                    %% ranges the 'DOWN' handler must clear.
                    MRef = erlang:monitor(process, Pid),
                    true = ets:insert(?TABLE,
                                      {{V, Start}, End, Pfx, Pid, Ctx, MRef}),
                    Mons1 = Mons#{MRef => [{V, Start, Pfx}]},
                    {reply, ok, S#state{monitors = Mons1}}
            end
    end;
handle_call({release, V, Addr, Pfx}, _From, S) ->
    case prefix_bounds(V, Addr, Pfx) of
        {ok, Start, _End} ->
            {reply, ok, release_key({V, Start}, Pfx, S)};
        {error, _Reason} ->
            %% register rejects this address/prefix, so nothing can
            %% have been stored under it.
            {reply, ok, S}
    end;
handle_call({release_pid, Pid}, _From, S) ->
    {reply, ok, release_by_pid(Pid, S)};
handle_call(_Other, _From, S) ->
    {reply, {error, bad_call}, S}.

%% Reduce `Addr'/`Pfx' to the inclusive integer interval of the prefix.
%% Host bits below the prefix length are cleared, so `Start' is the
%% first address of the prefix. Returns `{ok, Start, End}',
%% `{error, bad_address}' or `{error, bad_prefix}'.
prefix_bounds(V, Addr, Pfx) ->
    case to_int(V, Addr) of
        error ->
            {error, bad_address};
        Int ->
            %% to_int/2 only succeeds for V =:= 4 or V =:= 6.
            MaxPfx = case V of 4 -> 32; 6 -> 128 end,
            case is_integer(Pfx) andalso Pfx >= 0 andalso Pfx =< MaxPfx of
                true ->
                    %% range_end(0, V, Pfx) is exactly the host-bit mask.
                    HostMask = range_end(0, V, Pfx),
                    Start = Int band (bnot HostMask),
                    {ok, Start, range_end(Start, V, Pfx)};
                false ->
                    {error, bad_prefix}
            end
    end.

%% No asynchronous requests are defined; every cast is ignored.
handle_cast(_Request, State) ->
    {noreply, State}.

%% A monitored session went down: delete every range registered under
%% that monitor reference and stop tracking the reference. 'DOWN'
%% messages for references no longer tracked, and any other message,
%% leave the state unchanged.
handle_info({'DOWN', MRef, process, _Pid, _Reason},
            #state{monitors = Mons} = S) ->
    case maps:find(MRef, Mons) of
        error ->
            {noreply, S};
        {ok, Ranges} ->
            Drop = fun({V, Start, _Pfx}) ->
                           _ = ets:delete(?TABLE, {V, Start}),
                           masque_metrics:ip_release_inc()
                   end,
            lists:foreach(Drop, Ranges),
            {noreply, S#state{monitors = maps:remove(MRef, Mons)}}
    end;
handle_info(_Unexpected, S) ->
    {noreply, S}.

%% Nothing to clean up explicitly on shutdown.
terminate(_Why, _State) ->
    ok.

%% Hot upgrades keep the state as-is.
code_change(_FromVsn, State, _Extra) ->
    {ok, State}.

%%====================================================================
%% Internal
%%====================================================================

%% Map an address tuple to `{Version, Int}', or `error' when it is not a
%% well-formed IPv4/IPv6 address. A failed to_int/2 must become `error'
%% here. Returning `{V, error}' would push the atom into lookup_int/2's
%% arithmetic and raise `badarith', which lookup/1 does not catch.
normalise({_,_,_,_} = A) -> versioned(4, to_int(4, A));
normalise({_,_,_,_,_,_,_,_} = A) -> versioned(6, to_int(6, A));
normalise(_) -> error.

%% Tag a successful conversion with its IP version; pass `error' through.
versioned(_V, error) -> error;
versioned(V, Int) -> {V, Int}.

%% Convert an address tuple to its big-endian integer value, or return
%% `error'. Every component must be an integer that fits its field
%% (8 bits for IPv4, 16 bits for IPv6). Anything else yields `error'
%% rather than raising `badarith' (non-integer components) or silently
%% aliasing bits into a neighbouring field (oversized components).
to_int(4, {_,_,_,_} = Addr) ->
    fold_words(tuple_to_list(Addr), 8, 0);
to_int(6, {_,_,_,_,_,_,_,_} = Addr) ->
    fold_words(tuple_to_list(Addr), 16, 0);
to_int(_, _) ->
    error.

%% Accumulate fixed-width unsigned words, most significant first.
%% Returns `error' as soon as a word is not an in-range integer.
fold_words([], _Width, Acc) ->
    Acc;
fold_words([Word | Rest], Width, Acc)
  when is_integer(Word), Word >= 0, Word < (1 bsl Width) ->
    fold_words(Rest, Width, (Acc bsl Width) bor Word);
fold_words(_Words, _Width, _Acc) ->
    error.

%% Last address (as an integer) of the prefix of length `Pfx' that
%% starts at `Start': all bits below the prefix are set. A prefix
%% length outside 0..32 (IPv4) / 0..128 (IPv6) matches no clause.
range_end(Start, 4, Pfx) when Pfx >= 0, Pfx =< 32 ->
    set_host_bits(Start, 32 - Pfx);
range_end(Start, 6, Pfx) when Pfx >= 0, Pfx =< 128 ->
    set_host_bits(Start, 128 - Pfx).

%% Set the lowest `HostBits' bits of `Int'.
set_host_bits(Int, HostBits) ->
    Int bor ((1 bsl HostBits) - 1).

%% Decide whether the interval [Start, End] of version V may be added.
%% Returns `conflict' when:
%%   - a registration already starts at exactly Start (an exact repeat
%%     is deliberately not treated as idempotent);
%%   - the nearest same-version entry below Start reaches Start; or
%%   - the nearest same-version entry above Start begins at or before
%%     End.
%% Otherwise returns `ok'. Stored ranges never overlap one another, so
%% checking the two neighbours of Start is sufficient. Entries of the
%% other IP version never conflict.
overlap_check(V, Start, End) ->
    Key = {V, Start},
    case ets:member(?TABLE, Key) orelse prev_reaches(V, Key, Start) of
        true  -> conflict;
        false -> overlap_check_next(V, Key, End)
    end.

%% True when the closest same-version entry before `Key' ends at or
%% after `Start'.
prev_reaches(V, Key, Start) ->
    case ets:prev(?TABLE, Key) of
        {V, _} = Prev ->
            case ets:lookup(?TABLE, Prev) of
                [{Prev, PrevEnd, _, _, _, _}] -> PrevEnd >= Start;
                _ -> false
            end;
        _NoneOrOtherVersion ->
            false
    end.

%% `conflict' when the first same-version entry after `Key' begins at
%% or before `End', i.e. inside the proposed interval; otherwise `ok'.
overlap_check_next(V, Key, End) ->
    case ets:next(?TABLE, Key) of
        {V, NextStart} when NextStart =< End -> conflict;
        _NoneOrDisjoint -> ok
    end.

%% Find the registration covering integer address `Int' of version V.
%% Probing with `Int + 1' makes ets:prev/2 return the greatest key that
%% is =< {V, Int}. Since registered ranges never overlap, that row is
%% the only candidate, and it covers `Int' exactly when its end >= Int.
lookup_int(V, Int) ->
    case ets:prev(?TABLE, {V, Int + 1}) of
        {V, _Start} = Floor ->
            case ets:lookup(?TABLE, Floor) of
                [{Floor, End, _Pfx, Pid, Ctx, _MRef}] when Int =< End ->
                    {ok, Pid, Ctx};
                _Uncovered ->
                    not_found
            end;
        _NoneOrOtherVersion ->
            not_found
    end.

%% Delete the row at `Key' if it was registered with exactly `Pfx', and
%% forget the range in the monitor bookkeeping. The monitor is released
%% (and any pending 'DOWN' flushed) once it no longer covers any range.
%% A missing row or a prefix mismatch leaves the state untouched.
release_key({V, Start} = Key, Pfx, #state{monitors = Mons} = S) ->
    case ets:lookup(?TABLE, Key) of
        [{Key, _End, Pfx, _Pid, _Ctx, MRef}] ->
            true = ets:delete(?TABLE, Key),
            masque_metrics:ip_release_inc(),
            S#state{monitors = forget_range(MRef, {V, Start, Pfx}, Mons)};
        _NotRegisteredAsGiven ->
            S
    end.

%% Remove `Range' from the list tracked under `MRef'; demonitor and drop
%% the reference when its list becomes empty.
forget_range(MRef, Range, Mons) ->
    case maps:find(MRef, Mons) of
        {ok, Ranges} ->
            case lists:filter(fun(R) -> R =/= Range end, Ranges) of
                [] ->
                    erlang:demonitor(MRef, [flush]),
                    maps:remove(MRef, Mons);
                Remaining ->
                    Mons#{MRef => Remaining}
            end;
        error ->
            Mons
    end.

%% Release every range owned by `Pid'. A monitor reference is dropped
%% (and demonitored with flush) only when all of its ranges belong to
%% `Pid'; the matching ETS rows are then deleted and counted in the
%% release metric. Other references are kept unchanged.
release_by_pid(Pid, #state{monitors = Mons} = S) ->
    {Kept, DroppedChunks} =
        maps:fold(
          fun(MRef, Ranges, {KeepAcc, DropAcc}) ->
                  Owned = ranges_for_pid(Pid, Ranges),
                  case Owned =/= [] andalso Owned =:= Ranges of
                      true ->
                          erlang:demonitor(MRef, [flush]),
                          {KeepAcc, [Owned | DropAcc]};
                      false ->
                          {KeepAcc#{MRef => Ranges}, DropAcc}
                  end
          end, {#{}, []}, Mons),
    %% Restore fold order before flattening the dropped chunks.
    Dropped = lists:append(lists:reverse(DroppedChunks)),
    lists:foreach(fun({V, Start, _Pfx}) ->
                          _ = ets:delete(?TABLE, {V, Start}),
                          masque_metrics:ip_release_inc()
                  end, Dropped),
    S#state{monitors = Kept}.

%% Keep only the `{V, Start, Pfx}' ranges whose ETS row currently names
%% `Pid' as owner. Entries of any other shape are skipped.
ranges_for_pid(Pid, Ranges) ->
    IsOwned =
        fun({V, Start, _Pfx}) ->
                case ets:lookup(?TABLE, {V, Start}) of
                    [{_, _End, _RowPfx, Pid, _Ctx, _MRef}] -> true;
                    _ -> false
                end;
           (_Malformed) ->
                false
        end,
    lists:filter(IsOwned, Ranges).