Skip to main content

src/nquic_protocol_migration.erl

-module(nquic_protocol_migration).
-moduledoc """
Connection migration and PMTUD for the QUIC protocol state.

Pure functions over `#conn_state{}` covering RFC 9000 Section 9
(connection migration: peer-address update, path validation outcome,
preferred-address handoff) and RFC 8899 (Packetisation Layer Path MTU
Discovery: probe acknowledgement, loss, and black-hole reaction).
Extracted from `nquic_protocol` as part of REVIEW_PLAN.md Phase 4.4.

External side effects are limited to delegating into `nquic_path`,
`nquic_pmtud`, `nquic_loss`, `nquic_socket`, `nquic_protocol_cid`, and
`nquic_protocol_timer` (path validation timeout computation).
PATH_CHALLENGE frames produced during preferred-address migration are
queued through `nquic_protocol_send_queues:queue_app_frame/2`.
""".

-include("nquic_conn.hrl").
-include("nquic_path.hrl").
-include("nquic_transport.hrl").
-export([
    apply_peer_update/2,
    check_migration_allowed/1,
    complete_migration/1,
    handle_preferred_address/2,
    revert_migration/1,
    select_preferred_peer/2
]).
-export([
    enable_pmtud/1,
    maybe_detect_black_hole/1,
    pmtud_on_black_hole/1,
    pmtud_on_probe_acked/1,
    pmtud_on_probe_lost/1
]).

-define(BLACK_HOLE_PTO_THRESHOLD, 3).

%%%-----------------------------------------------------------------------------
%% PATH VALIDATION / CONNECTION MIGRATION (RFC 9000 SECTION 9)
%%%-----------------------------------------------------------------------------
-doc """
Record a possible change of the peer's address.

Asks `nquic_path:detect_peer_change/2` whether `Source` differs from the
peer tracked in the current path state. When it reports `unchanged` the
state is returned as is; when it reports `{changed, NewPathState}` the
new path state is stored and the connection's `peer` field is set to the
peer held by that path state.
""".
-spec apply_peer_update(nquic_socket:sockaddr(), nquic_protocol:state()) -> nquic_protocol:state().
apply_peer_update(Source, #conn_state{path = PathMgmt} = State) ->
    TrackedPath = PathMgmt#conn_path_mgmt.path_state,
    Verdict = nquic_path:detect_peer_change(TrackedPath, Source),
    adopt_peer_change(Verdict, PathMgmt, State).

%% Fold the outcome of peer-change detection back into the connection state.
adopt_peer_change(unchanged, _PathMgmt, State) ->
    State;
adopt_peer_change({changed, UpdatedPath}, PathMgmt, State) ->
    State#conn_state{
        peer = UpdatedPath#path_state.peer,
        path = PathMgmt#conn_path_mgmt{path_state = UpdatedPath}
    }.

-doc """
Check whether this endpoint may migrate to a new path.

Returns `ok` when both conditions hold:

  * the peer's transport parameters do not set `disable_active_migration`
    (parameters that are still `undefined` count as "not disabled");
  * at least one known peer connection ID differs from the DCID currently
    in use.

Otherwise returns `{error, migration_disabled}` or
`{error, no_available_cids}`. The disable flag is checked first.
""".
-spec check_migration_allowed(nquic_protocol:state()) -> ok | {error, nquic_error:any_reason()}.
check_migration_allowed(#conn_state{
    remote_params = PeerParams,
    path = #conn_path_mgmt{peer_cids = KnownCids},
    dcid = ActiveCid
}) ->
    MigrationDisabled = peer_disabled_migration(PeerParams),
    %% Visit every entry (no short-circuit) and remember whether any CID
    %% other than the active one was seen.
    HasSpareCid = maps:fold(
        fun(_Seq, #{cid := Cid}, Seen) -> Seen orelse Cid =/= ActiveCid end,
        false,
        KnownCids
    ),
    maybe
        false ?= MigrationDisabled,
        true ?= HasSpareCid,
        ok
    else
        %% `true` can only come from the disable flag, `false` only from
        %% the spare-CID check.
        true -> {error, migration_disabled};
        false -> {error, no_available_cids}
    end.

%% Read the peer's disable_active_migration flag; unknown params mean "not disabled".
peer_disabled_migration(undefined) ->
    false;
peer_disabled_migration(#transport_params{disable_active_migration = Flag}) ->
    Flag.

-doc """
Finish a migration once the new path is accepted.

Rotates the destination connection ID through
`nquic_protocol_cid:rotate_dcid/1`; if no CID is available the error is
returned unchanged. Unless `nquic_path:is_nat_rebinding/1` reports `true`
for the path state held before rotation, the loss state is re-initialised
with the previous congestion-control algorithm and pacer configuration.
Finally the path is marked address-validated and both anti-amplification
byte counters are reset to zero.
""".
-spec complete_migration(nquic_protocol:state()) ->
    {ok, nquic_protocol:state()} | {error, no_available_cids}.
complete_migration(#conn_state{path = PathBefore, loss_state = LossBefore} = State) ->
    PathState = PathBefore#conn_path_mgmt.path_state,
    case nquic_protocol_cid:rotate_dcid(State) of
        {ok, Rotated} ->
            Settled = fresh_congestion_unless_rebinding(
                nquic_path:is_nat_rebinding(PathState), LossBefore, Rotated
            ),
            PathMgmt = Settled#conn_state.path,
            Validated = PathMgmt#conn_path_mgmt{
                address_validated = true,
                anti_amp_bytes_sent = 0,
                anti_amp_bytes_received = 0
            },
            {ok, Settled#conn_state{path = Validated}};
        {error, no_available_cids} = NoCids ->
            NoCids
    end.

%% Keep the loss state on NAT rebinding; otherwise start over with the same
%% congestion-control algorithm and pacer configuration.
fresh_congestion_unless_rebinding(true, _PriorLoss, State) ->
    State;
fresh_congestion_unless_rebinding(false, PriorLoss, State) ->
    CcAlgorithm = nquic_loss:get_cc_algorithm(PriorLoss),
    PacerConfig = nquic_loss:pacer_config(PriorLoss),
    State#conn_state{loss_state = nquic_loss:init(CcAlgorithm, PacerConfig)}.

-doc """
Start migrating to the peer's preferred address (RFC 9000 Section 9.6).

Steps, in order:

  1. Take the connection ID and stateless reset token from the
     `preferred_address` value and store them as peer CID sequence 1.
  2. Pick the target address with `select_preferred_peer/2`.
  3. Begin path validation via `nquic_path:initiate_validation/2` and queue
     the resulting PATH_CHALLENGE frame as an application frame.
  4. Switch the connection's `peer` to the target address.

Returns the updated state together with a single `set_timer` action for
the `path_validation` timer.
""".
-spec handle_preferred_address(nquic_transport:preferred_address(), nquic_protocol:state()) ->
    {ok, nquic_protocol:state(), [nquic_protocol:timeout_action()]}.
handle_preferred_address(PA, #conn_state{path = PathMgmt, peer = CurrentPeer} = State) ->
    #conn_path_mgmt{path_state = PathState, peer_cids = KnownCids} = PathMgmt,
    #{cid := PreferredCid, stateless_reset_token := ResetToken} = PA,
    PreferredPeer = select_preferred_peer(PA, CurrentPeer),
    %% The preferred-address CID is installed at sequence number 1.
    CidsWithPreferred = KnownCids#{1 => #{cid => PreferredCid, token => ResetToken}},
    {ValidatingPath, PathChallenge} = nquic_path:initiate_validation(PathState, PreferredPeer),
    Migrating = State#conn_state{
        peer = PreferredPeer,
        path = PathMgmt#conn_path_mgmt{
            peer_cids = CidsWithPreferred,
            path_state = ValidatingPath
        }
    },
    {ok, Queued} = nquic_protocol_send_queues:queue_app_frame(PathChallenge, Migrating),
    Timeout = nquic_protocol_timer:compute_path_validation_timeout(Queued),
    {ok, Queued, [{set_timer, path_validation, Timeout}]}.

%%%-----------------------------------------------------------------------------
%% PMTUD (RFC 8899)
%%%-----------------------------------------------------------------------------
-doc """
Enable PMTUD on the connection.

If the connection has no PMTUD state yet, a fresh one is created with
`nquic_pmtud:new/0` first; the (new or existing) state is then passed to
`nquic_pmtud:enable/1` and stored back.
""".
-spec enable_pmtud(nquic_protocol:state()) -> nquic_protocol:state().
enable_pmtud(#conn_state{pmtud = Existing} = State) ->
    Seed =
        case Existing of
            undefined -> nquic_pmtud:new();
            _ -> Existing
        end,
    State#conn_state{pmtud = nquic_pmtud:enable(Seed)}.

-doc """
Trigger the black-hole reaction when the path looks like it drops large packets.

Does nothing when PMTUD state is `undefined`. Otherwise, if the PTO count
reported by `nquic_loss:get_pto_count/1` has reached
`?BLACK_HOLE_PTO_THRESHOLD` and the current MTU is above 1200 bytes,
`pmtud_on_black_hole/1` is applied; in every other case the state is
returned unchanged.
""".
-spec maybe_detect_black_hole(nquic_protocol:state()) -> nquic_protocol:state().
maybe_detect_black_hole(#conn_state{pmtud = undefined} = State) ->
    State;
maybe_detect_black_hole(#conn_state{pmtud = Pmtud, loss_state = Loss} = State) ->
    Ptos = nquic_loss:get_pto_count(Loss),
    Mtu = nquic_pmtud:get_current_mtu(Pmtud),
    if
        Ptos >= ?BLACK_HOLE_PTO_THRESHOLD, Mtu > 1200 ->
            pmtud_on_black_hole(State);
        true ->
            State
    end.

-doc """
React to a detected PMTU black hole.

Does nothing when PMTUD state is `undefined`. Otherwise the PMTUD state is
advanced with `nquic_pmtud:on_black_hole/1` and the loss state's maximum
datagram size is set to 1200 bytes.
""".
-spec pmtud_on_black_hole(nquic_protocol:state()) -> nquic_protocol:state().
pmtud_on_black_hole(#conn_state{pmtud = undefined} = State) ->
    State;
pmtud_on_black_hole(#conn_state{pmtud = Pmtud, loss_state = Loss} = State) ->
    State#conn_state{
        pmtud = nquic_pmtud:on_black_hole(Pmtud),
        loss_state = nquic_loss:set_max_datagram_size(Loss, 1200)
    }.

-doc """
Handle acknowledgement of a PMTUD probe.

Does nothing when PMTUD state is `undefined`. Otherwise the PMTUD state is
advanced with `nquic_pmtud:on_probe_acked/1`; if that raised the current
MTU, the new value is also pushed into the loss state as its maximum
datagram size.
""".
-spec pmtud_on_probe_acked(nquic_protocol:state()) -> nquic_protocol:state().
pmtud_on_probe_acked(#conn_state{pmtud = undefined} = State) ->
    State;
pmtud_on_probe_acked(#conn_state{pmtud = PmtudBefore, loss_state = Loss} = State) ->
    PmtudAfter = nquic_pmtud:on_probe_acked(PmtudBefore),
    MtuAfter = nquic_pmtud:get_current_mtu(PmtudAfter),
    MtuBefore = nquic_pmtud:get_current_mtu(PmtudBefore),
    State#conn_state{
        pmtud = PmtudAfter,
        loss_state = raise_datagram_size(MtuAfter > MtuBefore, MtuAfter, Loss)
    }.

%% Propagate the MTU into the loss state only when it actually grew.
raise_datagram_size(true, Mtu, Loss) ->
    nquic_loss:set_max_datagram_size(Loss, Mtu);
raise_datagram_size(false, _Mtu, Loss) ->
    Loss.

-doc """
Handle loss of a PMTUD probe.

Does nothing when PMTUD state is `undefined`; otherwise stores the result
of `nquic_pmtud:on_probe_lost/1`. The loss state is not touched.
""".
-spec pmtud_on_probe_lost(nquic_protocol:state()) -> nquic_protocol:state().
pmtud_on_probe_lost(#conn_state{pmtud = undefined} = State) ->
    State;
pmtud_on_probe_lost(#conn_state{pmtud = Pmtud} = State) ->
    AfterLoss = nquic_pmtud:on_probe_lost(Pmtud),
    State#conn_state{pmtud = AfterLoss}.

-doc """
Undo a peer switch by restoring the connection's `peer` field from the
peer recorded in the current path state. No other field is changed.
""".
-spec revert_migration(nquic_protocol:state()) -> nquic_protocol:state().
revert_migration(#conn_state{path = #conn_path_mgmt{path_state = PathState}} = State) ->
    RecordedPeer = PathState#path_state.peer,
    State#conn_state{peer = RecordedPeer}.

-doc """
Choose the preferred-address endpoint matching the current peer's family.

The `family` key of `CurrentPeer` decides which entry is used: `inet`
(also the default when the key is absent) selects the IPv4 address and
port, `inet6` selects the IPv6 pair. The chosen pair is turned into a
socket address with `nquic_socket:make_sockaddr/2`.
""".
-spec select_preferred_peer(nquic_transport:preferred_address(), nquic_socket:sockaddr()) ->
    nquic_socket:sockaddr().
select_preferred_peer(#{ipv4 := {IPv4, V4Port}, ipv6 := {IPv6, V6Port}}, CurrentPeer) ->
    {Address, Port} =
        case maps:get(family, CurrentPeer, inet) of
            inet -> {IPv4, V4Port};
            inet6 -> {IPv6, V6Port}
        end,
    nquic_socket:make_sockaddr(Address, Port).