%% src/cets_dist_blocker.erl

%% @doc Disallow distributed erlang connections until cleaning is done.
%%
%% This module prevents a node from reconnecting, until cleaning activity is
%% finished. It prevents race conditions.
%%
%% This module assumes all nodes share the same cookie.
-module(cets_dist_blocker).
-behaviour(gen_server).
-include_lib("kernel/include/logger.hrl").

%% API
-export([
    start_link/0,
    add_cleaner/1,
    cleaning_done/2
]).

%% gen_server callbacks
-export([
    init/1,
    handle_call/3,
    handle_cast/2,
    handle_info/2,
    terminate/2,
    code_change/3
]).

-ignore_xref([
    start_link/0,
    add_cleaner/1,
    cleaning_done/2
]).

%% Pid of a process registered through add_cleaner/1.
-type cleaner_pid() :: pid().
%% Pending confirmations: a `{Node, CleanerPid}' entry is added for every
%% registered cleaner when Node goes down, and is removed when that cleaner
%% calls cleaning_done/2 for Node or when the cleaner process dies.
-type waiting() :: [{node(), cleaner_pid()}].

%% State of the blocker server.
-type state() :: #{
    %% Registered cleaners (stored through lists:usort/1, so no duplicates).
    cleaners := [cleaner_pid()],
    %% Confirmations still awaited, see waiting().
    waiting := waiting()
}.

%% @doc Start the blocker server and link it to the caller.
%%
%% The server is registered locally under the module name, so only one
%% instance can run per node.
-spec start_link() -> {ok, pid()} | ignore | {error, term()}.
start_link() ->
    gen_server:start_link({local, ?MODULE}, ?MODULE, [], []).

%% @doc Register `CleanerPid' as a cleaner.
%%
%% This is a synchronous call to the locally registered blocker server.
%% Once registered, the cleaner is expected to confirm each nodedown
%% with `cleaning_done/2'.
-spec add_cleaner(pid()) -> ok.
add_cleaner(CleanerPid) ->
    Request = {add_cleaner, CleanerPid},
    gen_server:call(?MODULE, Request).

%% @doc Confirm that `CleanerPid' has finished cleaning after `Node' went down.
%%
%% A cleaner should call this function once it has received nodedown
%% for `Node' and completed its cleaning. This is a synchronous call to the
%% locally registered blocker server.
-spec cleaning_done(pid(), node()) -> ok.
cleaning_done(CleanerPid, Node) ->
    Confirmation = {cleaning_done, CleanerPid, Node},
    gen_server:call(?MODULE, Confirmation).

%%--------------------------------------------------------------------
%% gen_server callbacks
%%--------------------------------------------------------------------
%% Subscribe to nodeup/nodedown messages, then treat every node that is
%% already connected as if it had just come up.
init([]) ->
    ok = net_kernel:monitor_nodes(true),
    EmptyState = #{cleaners => [], waiting => []},
    ConnectedNodes = nodes(),
    {ok, lists:foldl(fun handle_nodeup/2, EmptyState, ConnectedNodes)}.

%% Synchronous requests: cleaner registration and cleaning confirmations.
%% Any other request is logged and answered with an error tuple.
handle_call({add_cleaner, CleanerPid}, _From, State) ->
    Registered = handle_add_cleaner(CleanerPid, State),
    {reply, ok, Registered};
handle_call({cleaning_done, CleanerPid, Node}, _From, State) ->
    Confirmed = handle_cleaning_done(CleanerPid, Node, State),
    %% Nodes that dropped out of the waiting list get their cookie restored.
    Unblocked = maybe_unblock(State, Confirmed),
    {reply, ok, Unblocked};
handle_call(Request, _From, State) ->
    ?LOG_ERROR(#{what => unexpected_call, msg => Request}),
    {reply, {error, unexpected_call}, State}.

%% No casts are part of the protocol: log the message and keep the state.
handle_cast(Unexpected, State) ->
    ?LOG_ERROR(#{what => unexpected_cast, msg => Unexpected}),
    {noreply, State}.

%% Node monitoring messages, cleaner 'DOWN' notifications and anything else
%% that lands in the mailbox.
handle_info({nodeup, Node}, State) ->
    Blocked = handle_nodeup(Node, State),
    {noreply, Blocked};
handle_info({nodedown, Node}, State) ->
    Updated = handle_nodedown(Node, State),
    {noreply, Updated};
handle_info({'DOWN', _Ref, process, Pid, _Info}, State) ->
    %% A dead cleaner cannot confirm anything: forget it, then unblock
    %% the nodes that were waiting only for it.
    Pruned = handle_cleaner_down(Pid, State),
    {noreply, maybe_unblock(State, Pruned)};
handle_info(Info, State) ->
    ?LOG_ERROR(#{what => unexpected_info, msg => Info}),
    {noreply, State}.

%% Restore the cookie of every node that is still waiting for cleaners.
%% NOTE(review): nodes that are currently connected keep the blocking cookie
%% set by handle_nodeup/2; only nodes in the waiting list are restored here.
%% Confirm that this is intended.
terminate(_Reason, State) ->
    NothingPending = State#{waiting := []},
    _ = maybe_unblock(State, NothingPending),
    ok.

%% The state format has not changed between versions: keep it as is.
code_change(_FromVsn, CurrentState, _Extra) ->
    {ok, CurrentState}.

%%--------------------------------------------------------------------
%% internal functions
%%--------------------------------------------------------------------

%% Block future reconnects of Node by assigning it the blocking cookie.
%%
%% The cookie is swapped right after the node connects. Doing it on nodedown
%% instead would be racy: nodedown messages are asynchronous, so the node
%% could reconnect before the cookie is changed.
-spec handle_nodeup(node(), state()) -> state().
handle_nodeup(Node, State) ->
    Cookie = blocking_cookie(),
    erlang:set_cookie(Node, Cookie),
    State.

%% Build a cookie that differs from the shared one, so that a node assigned
%% this cookie cannot connect. The result has the form
%% `<our cookie>_blocked_by_<our node name>'.
-spec blocking_cookie() -> atom().
blocking_cookie() ->
    OwnCookie = erlang:get_cookie(),
    list_to_atom(lists:concat([OwnCookie, "_blocked_by_", node()])).

%% Let Node connect to us again by assigning it our own cookie.
%% The state is returned untouched, which makes this function usable
%% as a fold callback.
-spec unblock_node(node(), state()) -> state().
unblock_node(Node, State) ->
    OwnCookie = erlang:get_cookie(),
    erlang:set_cookie(Node, OwnCookie),
    State.

%% React to a disconnected node.
%% Without registered cleaners there is nothing to wait for, so the node is
%% unblocked at once. Otherwise one `{Node, CleanerPid}' entry per registered
%% cleaner is merged into the waiting list.
-spec handle_nodedown(node(), state()) -> state().
handle_nodedown(Node, #{cleaners := []} = State) ->
    unblock_node(Node, State);
handle_nodedown(Node, #{cleaners := Cleaners, waiting := Waiting} = State) ->
    Pending = lists:map(fun(CleanerPid) -> {Node, CleanerPid} end, Cleaners),
    %% usort keeps the list free of duplicates if the node went down before.
    State#{waiting := lists:usort(Pending ++ Waiting)}.

%% Register CleanerPid and start monitoring it.
%%
%% Registering a pid that is already known is a no-op. Creating another
%% monitor on every call would leave redundant monitors behind and make the
%% server receive one 'DOWN' message per registration when the cleaner dies.
-spec handle_add_cleaner(cleaner_pid(), state()) -> state().
handle_add_cleaner(CleanerPid, State = #{cleaners := Cleaners}) ->
    case lists:member(CleanerPid, Cleaners) of
        true ->
            %% Already registered, hence already monitored.
            State;
        false ->
            _MonitorRef = erlang:monitor(process, CleanerPid),
            State#{cleaners := lists:usort([CleanerPid | Cleaners])}
    end.

%% Drop the `{Node, CleanerPid}' entry from the waiting list, if present.
%% The caller is responsible for unblocking nodes afterwards.
-spec handle_cleaning_done(cleaner_pid(), node(), state()) -> state().
handle_cleaning_done(CleanerPid, Node, #{waiting := Waiting} = State) ->
    Confirmed = {Node, CleanerPid},
    State#{waiting := Waiting -- [Confirmed]}.

%% Forget a cleaner whose process has died: remove it from the list of
%% cleaners and discard every confirmation that was awaited from it.
-spec handle_cleaner_down(cleaner_pid(), state()) -> state().
handle_cleaner_down(CleanerPid, #{cleaners := Cleaners, waiting := Waiting} = State) ->
    IsOtherCleaner = fun
        ({_Node, Owner}) -> Owner =/= CleanerPid;
        (_Other) -> false
    end,
    State#{
        cleaners := Cleaners -- [CleanerPid],
        waiting := lists:filter(IsOtherCleaner, Waiting)
    }.

%% Unblock every node that had pending confirmations in the old state but has
%% none left in the new state, i.e. the last cleaner confirmed or went away.
%% Call this each time entries are removed from the waiting list.
%% Returns the new state.
-spec maybe_unblock(state(), state()) -> state().
maybe_unblock(#{waiting := WaitingBefore}, #{waiting := WaitingAfter} = NewState) ->
    BlockedBefore = cast_waiting_to_nodes(WaitingBefore),
    StillBlocked = cast_waiting_to_nodes(WaitingAfter),
    Released = lists:subtract(BlockedBefore, StillBlocked),
    lists:foldl(fun unblock_node/2, NewState, Released).

%% Sorted list of distinct nodes that appear in the waiting list.
-spec cast_waiting_to_nodes(waiting()) -> [node()].
cast_waiting_to_nodes(Waiting) ->
    Nodes = [element(1, Entry) || {_, _} = Entry <- Waiting],
    lists:usort(Nodes).