%%% @doc RFC 9484 §5 — CONNECT-IP capsule codec.
%%%
%%% Encodes and decodes the three CONNECT-IP capsule types on top of
%%% `masque_capsule' (RFC 9297 framing):
%%%
%%% <ul>
%%% <li>`ADDRESS_ASSIGN' (0x01) — Request IDs may be 0 (unprompted).</li>
%%% <li>`ADDRESS_REQUEST' (0x02) — at least one entry, Request IDs
%%% MUST be nonzero and unique per sender.</li>
%%% <li>`ROUTE_ADVERTISEMENT' (0x03) — lexicographic ordering by
%%% (Version, IP Protocol, Start Address); ranges disjoint
%%% within equal (Version, Protocol); protocol-0 ranges MUST NOT
%%% overlap nonzero-protocol ranges for the same Version.</li>
%%% </ul>
%%%
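%%% Every decode-side validation failure returns `{error, Reason}' —
%%% callers map this into an H3_MESSAGE_ERROR stream reset per RFC 9297
%%% §3.3. The encoders do not return error tuples; on invalid input they
%%% raise an exception via `erlang:error/1'.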
-module(masque_ip_capsule).
-export([encode_address_assign/1,
encode_address_request/1,
encode_route_advertisement/1]).
-export([decode_address_assign/1,
decode_address_request/1,
decode_route_advertisement/1]).
-export([encode/2, decode_body/2]).
-include("masque_ip.hrl").
-type address_entry() :: #ip_assignment{}.
-type request_entry() :: #ip_prefix_request{}.
-type route_entry() :: #ip_route{}.
%% Failure reasons reported as `{error, Reason}' by the decode functions.
%% Keep this in sync with every reason thrown or returned on the decode
%% path (per-entry parsing, ADDRESS_REQUEST checks, route validation and
%% the unknown-type fallback of decode_body/2).
-type decode_error() ::
        truncated
      | malformed_varint
      | bad_ip_version
      | bad_prefix_length
      | non_canonical_prefix
      | empty_address_request
      | duplicate_request_id
      | zero_request_id_in_request
      | route_range_reversed
      | unordered_routes
      | overlapping_routes
      | proto_zero_overlap
      | {unknown_capsule_type, non_neg_integer()}.
-export_type([address_entry/0, request_entry/0, route_entry/0,
              decode_error/0]).
%%====================================================================
%% Capsule-level encode (body -> framed capsule)
%%====================================================================
%% @doc Frame a list of CONNECT-IP entries as a complete RFC 9297 capsule.
%% The first argument selects both the body encoder and the capsule type
%% code. Exceptions raised by the body encoder propagate unchanged.
-spec encode(address_assign | address_request | route_advertisement,
             [address_entry() | request_entry() | route_entry()]) ->
    iodata().
encode(address_assign, Entries) ->
    masque_capsule:encode(?MASQUE_CAPSULE_ADDRESS_ASSIGN,
                          encode_address_assign(Entries));
encode(address_request, Entries) ->
    masque_capsule:encode(?MASQUE_CAPSULE_ADDRESS_REQUEST,
                          encode_address_request(Entries));
encode(route_advertisement, Entries) ->
    masque_capsule:encode(?MASQUE_CAPSULE_ROUTE_ADVERTISEMENT,
                          encode_route_advertisement(Entries)).
%% @doc Decode the body bytes of a CONNECT-IP capsule into typed
%% entries. The caller determines the capsule type from the capsule frame
%% (e.g. via `masque_capsule:decode/1').
%%
%% Returns `{error, {unknown_capsule_type, Type}}' for any type this
%% module does not handle. Such capsules are expected to be dropped by the
%% caller rather than treated as fatal.
-spec decode_body(non_neg_integer(), binary()) ->
    {ok, [address_entry() | request_entry() | route_entry()]}
    | {error, decode_error() | {unknown_capsule_type, non_neg_integer()}}.
decode_body(?MASQUE_CAPSULE_ADDRESS_ASSIGN, Body) ->
    decode_address_assign(Body);
decode_body(?MASQUE_CAPSULE_ADDRESS_REQUEST, Body) ->
    decode_address_request(Body);
decode_body(?MASQUE_CAPSULE_ROUTE_ADVERTISEMENT, Body) ->
    decode_route_advertisement(Body);
decode_body(Type, _Body) ->
    %% Unknown capsule types are handled (dropped) by the caller.
    {error, {unknown_capsule_type, Type}}.
%%====================================================================
%% ADDRESS_ASSIGN (0x01) — request_id >= 0 allowed
%%====================================================================
%% @doc Encode the body of an ADDRESS_ASSIGN capsule. The result is the
%% concatenation of each entry's wire encoding, in list order. A Request
%% ID of 0 is accepted here.
-spec encode_address_assign([address_entry()]) -> binary().
encode_address_assign(Entries) ->
    iolist_to_binary(lists:map(fun encode_addr_entry/1, Entries)).
%% @doc Decode the body of an ADDRESS_ASSIGN capsule. Any reason thrown
%% while parsing an entry is returned as `{error, Reason}'.
-spec decode_address_assign(binary()) ->
    {ok, [address_entry()]} | {error, decode_error()}.
decode_address_assign(Body) ->
    try decode_addr_entries(Body, fun make_assignment/4, []) of
        Assignments -> {ok, Assignments}
    catch
        throw:Reason -> {error, Reason}
    end.
%% Record constructor that decode_address_assign/1 passes to
%% decode_addr_entries/3. Any Request ID is accepted, including 0.
make_assignment(ReqId, Ver, Addr, Pfx) ->
    #ip_assignment{version = Ver,
                   address = Addr,
                   prefix_len = Pfx,
                   request_id = ReqId}.
%%====================================================================
%% ADDRESS_REQUEST (0x02) — >=1 entry, all request_id > 0 and unique
%%====================================================================
%% @doc Encode the body of an ADDRESS_REQUEST capsule.
%%
%% Raises (via `erlang:error/1'):
%% - `empty_address_request' for an empty list;
%% - `zero_request_id_in_request' if any Request ID is 0;
%% - `duplicate_request_id' if two entries share a Request ID.
%% These are the same conditions decode_address_request/1 rejects, so the
%% encoder never emits a body that its own decoder would refuse.
-spec encode_address_request([request_entry()]) -> binary().
encode_address_request([]) ->
    erlang:error(empty_address_request);
encode_address_request(Entries) ->
    Ids = [R#ip_prefix_request.request_id || R <- Entries],
    %% Even if the record field is typed pos_integer(), record field
    %% types are only checked by Dialyzer, never at runtime, so a 0 ID has
    %% to be rejected explicitly.
    case lists:member(0, Ids) of
        true -> erlang:error(zero_request_id_in_request);
        false -> ok
    end,
    case length(Ids) =:= length(lists:usort(Ids)) of
        true -> ok;
        false -> erlang:error(duplicate_request_id)
    end,
    iolist_to_binary([encode_req_entry(E) || E <- Entries]).
%% @doc Decode the body of an ADDRESS_REQUEST capsule.
%% The body must contain at least one entry, and Request IDs must be
%% nonzero (enforced per entry by make_request/4) and pairwise distinct.
%% Parse failures are returned as `{error, Reason}'.
-spec decode_address_request(binary()) ->
    {ok, [request_entry()]} | {error, decode_error()}.
decode_address_request(Body) ->
    try decode_addr_entries(Body, fun make_request/4, []) of
        [] ->
            {error, empty_address_request};
        Requests ->
            Ids = [Id || #ip_prefix_request{request_id = Id} <- Requests],
            case length(lists:usort(Ids)) =:= length(Ids) of
                true -> {ok, Requests};
                false -> {error, duplicate_request_id}
            end
    catch
        throw:Reason -> {error, Reason}
    end.
%% Record constructor that decode_address_request/1 passes to
%% decode_addr_entries/3. A Request ID of 0 aborts the whole decode by
%% throwing `zero_request_id_in_request'.
make_request(ReqId, Ver, Addr, Pfx) when ReqId =/= 0 ->
    #ip_prefix_request{version = Ver,
                       address = Addr,
                       prefix_len = Pfx,
                       request_id = ReqId};
make_request(_ReqId, _Ver, _Addr, _Pfx) ->
    throw(zero_request_id_in_request).
%%====================================================================
%% ROUTE_ADVERTISEMENT (0x03)
%%====================================================================
%% @doc Encode the body of a ROUTE_ADVERTISEMENT capsule. The route list
%% is validated first; an invalid list raises its validation reason as an
%% exception.
-spec encode_route_advertisement([route_entry()]) -> binary().
encode_route_advertisement(Entries) ->
    ok = validate_routes(Entries),
    << <<(encode_route_entry(Route))/binary>> || Route <- Entries >>.
%% @doc Decode the body of a ROUTE_ADVERTISEMENT capsule, then check the
%% ordering, disjointness and protocol-0 rules on the decoded routes.
%% Parse and validation failures are both returned as `{error, Reason}'.
-spec decode_route_advertisement(binary()) ->
    {ok, [route_entry()]} | {error, decode_error()}.
decode_route_advertisement(Body) ->
    try decode_route_entries(Body, []) of
        Routes ->
            case validate_routes_result(Routes) of
                ok -> {ok, Routes};
                {error, _} = Failure -> Failure
            end
    catch
        throw:Reason -> {error, Reason}
    end.
%%====================================================================
%% Internal — per-entry encode (ADDRESS_ASSIGN / ADDRESS_REQUEST)
%%====================================================================
%% Wire-encode one ADDRESS_ASSIGN or ADDRESS_REQUEST entry. Both record
%% kinds share the (Request ID, Version, Address, Prefix Length) layout.
encode_addr_entry(#ip_assignment{request_id = ReqId, version = Ver,
                                 address = Addr, prefix_len = PfxLen}) ->
    encode_addr_tuple(ReqId, Ver, Addr, PfxLen);
encode_addr_entry(#ip_prefix_request{} = Request) ->
    encode_req_entry(Request).
%% Wire-encode one ADDRESS_REQUEST entry.
encode_req_entry(#ip_prefix_request{} = Request) ->
    #ip_prefix_request{request_id = ReqId, version = Ver,
                       address = Addr, prefix_len = PfxLen} = Request,
    encode_addr_tuple(ReqId, Ver, Addr, PfxLen).
%% Encode one (Request ID, IP Version, IP Address, IP Prefix Length)
%% entry as iodata.
%%
%% Raises (via `erlang:error/1'):
%% - `bad_ip_version' if the version is not 4 or 6, the address tuple has
%%   the wrong arity, or the prefix length is not an integer within the
%%   address width;
%% - `bad_address' if an address component is not an integer that fits
%%   its 8-bit (IPv4) or 16-bit (IPv6) field. Without this check the bit
%%   syntax would silently truncate the component;
%% - `non_canonical_prefix' if any host bit beyond the prefix is set. The
%%   decoder rejects such entries, so they must not be emitted.
encode_addr_tuple(Id, 4, {A, B, C, D} = Addr, Pfx)
  when is_integer(Pfx), Pfx >= 0, Pfx =< 32 ->
    check_encodable_addr(4, Addr, Pfx, 16#FF),
    [quic_varint:encode(Id), <<4:8, A:8, B:8, C:8, D:8, Pfx:8>>];
encode_addr_tuple(Id, 6, Addr, Pfx)
  when is_integer(Pfx), Pfx >= 0, Pfx =< 128, tuple_size(Addr) =:= 8 ->
    check_encodable_addr(6, Addr, Pfx, 16#FFFF),
    [quic_varint:encode(Id), <<6:8>>, v6_bin(Addr), <<Pfx:8>>];
encode_addr_tuple(_Id, _V, _A, _P) ->
    erlang:error(bad_ip_version).

%% Reject address tuples the wire format cannot represent faithfully.
%% The range check runs first: the prefix check assumes in-range
%% components.
check_encodable_addr(Ver, Addr, Pfx, MaxComponent) ->
    Fits = fun(X) -> is_integer(X) andalso X >= 0 andalso X =< MaxComponent end,
    case lists:all(Fits, tuple_to_list(Addr)) of
        true -> ok;
        false -> erlang:error(bad_address)
    end,
    case prefix_host_bits_zero(Ver, Addr, Pfx) of
        true -> ok;
        false -> erlang:error(non_canonical_prefix)
    end.
%% Serialise an 8-tuple of 16-bit IPv6 address groups into 16 big-endian
%% bytes.
v6_bin(Groups) when tuple_size(Groups) =:= 8 ->
    << <<Group:16>> || Group <- tuple_to_list(Groups) >>.
%%====================================================================
%% Internal — per-entry decode (ADDRESS_ASSIGN / ADDRESS_REQUEST)
%%====================================================================
%% Decode a concatenation of (Request ID, IP Version, IP Address, IP
%% Prefix Length) entries. Each entry is turned into a record via Mk/4,
%% and the records are returned in wire order. Any malformed entry aborts
%% the whole decode by throwing `malformed_varint', `truncated',
%% `bad_ip_version', `bad_prefix_length' or `non_canonical_prefix'. Mk
%% itself may also throw.
decode_addr_entries(<<>>, _Mk, Acc) ->
    lists:reverse(Acc);
decode_addr_entries(Bin, Mk, Acc) ->
    {ReqId, AfterId} = decode_varint(Bin),
    {Ver, Addr, Pfx, Tail} = decode_addr_fields(AfterId),
    Entry = Mk(ReqId, Ver, Addr, Pfx),
    decode_addr_entries(Tail, Mk, [Entry | Acc]).

%% Split the fixed-layout part of one entry (version, address, prefix
%% length) off the front of the input. Returns it together with the
%% remaining bytes.
decode_addr_fields(<<4:8, A:8, B:8, C:8, D:8, Pfx:8, Tail/binary>>) ->
    checked_addr_fields(4, {A, B, C, D}, Pfx, 32, Tail);
decode_addr_fields(<<6:8, A:16, B:16, C:16, D:16,
                     E:16, F:16, G:16, H:16, Pfx:8, Tail/binary>>) ->
    checked_addr_fields(6, {A, B, C, D, E, F, G, H}, Pfx, 128, Tail);
decode_addr_fields(<<Ver:8, _/binary>>) when Ver =/= 4, Ver =/= 6 ->
    throw(bad_ip_version);
decode_addr_fields(_) ->
    throw(truncated).

%% Enforce the prefix-length bound for the address family, then the
%% host-bits-zero (canonical prefix) rule.
checked_addr_fields(_Ver, _Addr, Pfx, MaxPfx, _Tail) when Pfx > MaxPfx ->
    throw(bad_prefix_length);
checked_addr_fields(Ver, Addr, Pfx, _MaxPfx, Tail) ->
    case prefix_host_bits_zero(Ver, Addr, Pfx) of
        true -> {Ver, Addr, Pfx, Tail};
        false -> throw(non_canonical_prefix)
    end.
%%====================================================================
%% Internal — per-entry encode/decode (ROUTE_ADVERTISEMENT)
%%====================================================================
%% Wire-encode one ROUTE_ADVERTISEMENT entry as a binary:
%% Version, Start Address, End Address, IP Protocol. Any record that
%% does not have a supported shape raises `bad_ip_version'.
encode_route_entry(Route) ->
    case Route of
        #ip_route{version = 4, start_addr = {S1, S2, S3, S4},
                  end_addr = {E1, E2, E3, E4}, ip_protocol = Proto}
          when Proto >= 0, Proto =< 255 ->
            <<4, S1, S2, S3, S4, E1, E2, E3, E4, Proto>>;
        #ip_route{version = 6, start_addr = Start, end_addr = End,
                  ip_protocol = Proto}
          when Proto >= 0, Proto =< 255,
               tuple_size(Start) =:= 8, tuple_size(End) =:= 8 ->
            <<6, (v6_bin(Start))/binary, (v6_bin(End))/binary, Proto>>;
        _ ->
            erlang:error(bad_ip_version)
    end.
%% Decode a concatenation of ROUTE_ADVERTISEMENT entries into #ip_route{}
%% records, returned in wire order. Throws `bad_ip_version' for an
%% unknown version byte and `truncated' for an incomplete entry.
decode_route_entries(<<>>, Acc) ->
    lists:reverse(Acc);
decode_route_entries(<<4:8, Start:4/binary, End:4/binary, Proto:8,
                       Rest/binary>>, Acc) ->
    decode_route_entries(Rest, [route_from_bins(4, Start, End, Proto, 8) | Acc]);
decode_route_entries(<<6:8, Start:16/binary, End:16/binary, Proto:8,
                       Rest/binary>>, Acc) ->
    decode_route_entries(Rest, [route_from_bins(6, Start, End, Proto, 16) | Acc]);
decode_route_entries(<<Ver:8, _/binary>>, _Acc) when Ver =/= 4, Ver =/= 6 ->
    throw(bad_ip_version);
decode_route_entries(_, _Acc) ->
    throw(truncated).

%% Build a route record from the raw start/end address bytes. Each
%% address is split into Width-bit unsigned components.
route_from_bins(Ver, StartBin, EndBin, Proto, Width) ->
    #ip_route{version = Ver,
              start_addr = bin_to_addr(StartBin, Width),
              end_addr = bin_to_addr(EndBin, Width),
              ip_protocol = Proto}.

%% Turn an address binary into a tuple of Width-bit big-endian integers.
bin_to_addr(Bin, Width) ->
    list_to_tuple([Word || <<Word:Width>> <= Bin]).
%%====================================================================
%% Internal — ROUTE_ADVERTISEMENT validation
%%====================================================================
%% Encode-side wrapper around validate_routes_result/1. It raises the
%% failure reason as an exception instead of returning it.
validate_routes(Routes) ->
    case validate_routes_result(Routes) of
        {error, Reason} -> erlang:error(Reason);
        ok -> ok
    end.
%% Run the ROUTE_ADVERTISEMENT checks in order and report the first
%% failure: range sanity, then sort order and disjointness, then
%% protocol-0 overlap. An empty advertisement is trivially valid.
validate_routes_result([]) ->
    ok;
validate_routes_result(Entries) ->
    run_route_checks([fun check_each_range/1,
                      fun check_sort_and_disjoint/1,
                      fun check_proto_zero_overlap/1],
                     Entries).

%% Apply each check to Entries, stopping at the first `{error, _}'.
run_route_checks([], _Entries) ->
    ok;
run_route_checks([Check | Remaining], Entries) ->
    case Check(Entries) of
        ok -> run_route_checks(Remaining, Entries);
        {error, _} = Failure -> Failure
    end.
%% Every advertised range must be non-empty: its start address may not
%% compare greater than its end address (ROUTE_ADVERTISEMENT rule). Any
%% element that is not a well-formed range also yields
%% `route_range_reversed'.
check_each_range(Routes) ->
    NonEmpty = fun(#ip_route{start_addr = Start, end_addr = End}) -> Start =< End;
                  (_Other) -> false
               end,
    case lists:all(NonEmpty, Routes) of
        true -> ok;
        false -> {error, route_range_reversed}
    end.
%% Verify the list is sorted by (Version, Protocol, Start) with strict
%% disjointness within equal (Version, Protocol) buckets.
%%
%% Checking adjacent pairs is enough. validate_routes_result/1 runs
%% check_each_range/1 first, so every range has start =< end, and
%% "A ends before B starts" then carries over transitively.
%%
%% Returns `{error, unordered_routes}' when the sort order is violated.
%% Returns `{error, overlapping_routes}' when two ranges in the same bucket
%% are in start-address order but intersect (including sharing a start).
check_sort_and_disjoint([]) ->
    ok;
check_sort_and_disjoint([_]) ->
    ok;
check_sort_and_disjoint([A, B | Rest]) ->
    case compare_route(A, B) of
        lt -> check_sort_and_disjoint([B | Rest]);
        gt -> {error, unordered_routes};
        overlap -> {error, overlapping_routes}
    end.

%% Compare two adjacent routes:
%% - `lt': A strictly precedes B;
%% - `gt': A belongs after B (wrong order);
%% - `overlap': same (Version, Protocol) bucket, A does not start after B,
%%   but A's range reaches B's start.
compare_route(#ip_route{version = V1}, #ip_route{version = V2}) when V1 < V2 ->
    lt;
compare_route(#ip_route{version = V1}, #ip_route{version = V2}) when V1 > V2 ->
    gt;
compare_route(#ip_route{ip_protocol = P1}, #ip_route{ip_protocol = P2})
  when P1 < P2 -> lt;
compare_route(#ip_route{ip_protocol = P1}, #ip_route{ip_protocol = P2})
  when P1 > P2 -> gt;
%% From here on both routes share the same (Version, Protocol) bucket.
compare_route(#ip_route{start_addr = S1}, #ip_route{start_addr = S2})
  when S2 < S1 -> gt;
compare_route(#ip_route{end_addr = E1}, #ip_route{start_addr = S2})
  when E1 < S2 -> lt;
compare_route(#ip_route{}, #ip_route{}) ->
    overlap.
%% Within each IP Version, ranges advertised for protocol 0 must not
%% overlap any range advertised for a nonzero protocol (RFC 9484 §4.7.3).
%% Routes are grouped by version, and the first failing group's error is
%% returned.
check_proto_zero_overlap(Entries) ->
    Groups = lists:foldl(
               fun(#ip_route{version = Ver} = Route, Acc) ->
                       Acc#{Ver => [Route | maps:get(Ver, Acc, [])]}
               end, #{}, Entries),
    lists:foldl(fun(Group, ok) -> check_version_zero_overlap(Group);
                   (_Group, Failure) -> Failure
                end, ok, maps:values(Groups)).
%% Apply the protocol-0 overlap rule to the routes of a single version.
%% If there are no protocol-0 routes, there is nothing to check.
check_version_zero_overlap(Routes) ->
    IsProtoZero = fun(#ip_route{ip_protocol = Proto}) -> Proto =:= 0 end,
    case lists:partition(IsProtoZero, Routes) of
        {[], _NonZeros} -> ok;
        {Zeros, NonZeros} -> check_zero_vs_nonzero(Zeros, NonZeros)
    end.
%% Fail with `proto_zero_overlap' if any nonzero-protocol range intersects
%% any protocol-0 range.
check_zero_vs_nonzero(Zeros, NonZeros) ->
    HitsAZero = fun(NonZero) ->
                        lists:any(fun(Zero) -> ranges_overlap(Zero, NonZero) end,
                                  Zeros)
                end,
    case lists:any(HitsAZero, NonZeros) of
        true -> {error, proto_zero_overlap};
        false -> ok
    end.
%% True when two inclusive address ranges share at least one address.
%% This is written as "every start is =< every end", which is equivalent
%% to max(starts) =< min(ends).
ranges_overlap(#ip_route{start_addr = StartA, end_addr = EndA},
               #ip_route{start_addr = StartB, end_addr = EndB}) ->
    StartA =< EndB andalso StartB =< EndA
        andalso StartA =< EndA andalso StartB =< EndB.
%% ADDRESS_ASSIGN / ADDRESS_REQUEST prefixes must be canonical: every
%% address bit beyond the prefix length (the host bits) is zero.
%% A /0 prefix makes all bits host bits, so at length 0 only the
%% all-zeros address (0.0.0.0 or ::) is canonical. There is no special
%% case that accepts any address at /0.
prefix_host_bits_zero(4, {A, B, C, D}, Pfx) when Pfx =< 32 ->
    Addr = (A bsl 24) bor (B bsl 16) bor (C bsl 8) bor D,
    low_bits_zero(Addr, 32 - Pfx);
prefix_host_bits_zero(6, {A, B, C, D, E, F, G, H}, Pfx) when Pfx =< 128 ->
    Addr = (A bsl 112) bor (B bsl 96) bor (C bsl 80) bor (D bsl 64)
        bor (E bsl 48) bor (F bsl 32) bor (G bsl 16) bor H,
    low_bits_zero(Addr, 128 - Pfx).

%% True when the lowest Count bits of N are all zero.
low_bits_zero(N, Count) ->
    (N band ((1 bsl Count) - 1)) =:= 0.
%%====================================================================
%% Internal — varint wrapper
%%====================================================================
%% Decode a QUIC variable-length integer from the front of Bin, returning
%% whatever quic_varint:decode/1 returns on success. Any error exception
%% it raises becomes a `malformed_varint' throw, which the public
%% decoders report as `{error, malformed_varint}'.
decode_varint(Bin) ->
    try quic_varint:decode(Bin) of
        Decoded -> Decoded
    catch
        error:_Why -> throw(malformed_varint)
    end.