src/grisp_connect_cluster.erl

-module(grisp_connect_cluster).

-behaviour(gen_server).


%--- Includes -------------------------------------------------------------------

-include_lib("kernel/include/logger.hrl").
-include_lib("grisp/include/grisp.hrl").


%--- Exports -------------------------------------------------------------------

% API Functions
-export([start_link/0]).
-export([system_info/0]).
-export([join/2]).
-export([leave/1]).
-export([list/0]).
-export([is_allowed/1]).

% Behaviour gen_server callback functions
-export([init/1]).
-export([handle_call/3]).
-export([handle_cast/2]).
-export([handle_info/2]).

% Disable dialyzer warnings
-dialyzer({nowarn_function, clear_dist_pem_cache/0}).


%--- Types ---------------------------------------------------------------------

% Configuration and runtime bookkeeping for one cluster peer.
-record(peer, {
    % Distribution node name of the peer.
    node :: atom(),
    % Host name registered for `address` in inet_db by register_peer/1.
    hostname :: binary(),
    % IPv4 address the host name is mapped to.
    address :: inet:ip4_address(),
    % Cookie set for this node only (erlang:set_cookie/2).
    cookie :: atom(),
    % PEM data written into the allowed CA chain file.
    ca :: binary(),
    % 32-byte certificate fingerprint accepted by is_allowed/1.
    fingerprint :: binary(),
    % When true, failed/dropped connections are retried instead of the
    % peer being removed.
    monitor :: boolean(),
    % Pending reconnection timer created by schedule_retry/1, if any.
    timer_ref :: undefined | reference()
}).

% Server state.
-record(state, {
    % Known peers keyed by their node name (see set_peer/2).
    peers = #{} :: #{atom() => #peer{}}
}).

% Options accepted by join/2. Every key except `monitor` is mandatory;
% `monitor` defaults to false. Invalid values make join/2 raise badarg.
-type node_options() :: #{
    hostname := binary(),
    address := inet:ip4_address(),
    cookie := atom(),
    ca := binary(),          % PEM data, must decode to at least one entry
    fingerprint := binary(), % must be exactly 32 bytes
    monitor => boolean()     % retry connecting instead of dropping the peer
}.


%--- Macros --------------------------------------------------------------------

% Locally registered name of the server process.
-define(SERVER, ?MODULE).
% Delay before a monitored peer is reconnected after a failure or nodedown.
-define(RETRY_DELAY, 1000). % ms
% Format into a binary; used for logger descriptions.
-define(FORMAT(FMT, ARGS), iolist_to_binary(io_lib:format(FMT, ARGS))).
% Named ETS table mapping accepted certificate fingerprints to node names.
-define(FINGERPRINT_TABLE, grisp_connect_cluster_fingerprints).


%--- API FUNCTIONS -------------------------------------------------------------

% Start the cluster server, registered locally under ?SERVER.
start_link() ->
    InitArgs = [],
    gen_server:start_link({local, ?SERVER}, ?MODULE, InitArgs, []).

% Describe this node for clustering: the cluster support flag, the
% distribution node name and the host name as a binary.
system_info() ->
    {ok, Host} = inet:gethostname(),
    #{nodename => node(),
      hostname => list_to_binary(Host),
      cluster_enabled => true}.

% Validate Opts and ask the server to register Node and connect to it.
% Invalid or missing options raise error(badarg) in the calling process.
% Returns true when connected, false when a retry was scheduled (monitored
% peer), and error when an unmonitored peer could not be reached.
-spec join(Node :: atom(), Opts :: node_options()) -> true | false | error.
join(Node, Opts) when is_atom(Node), is_map(Opts) ->
    Hostname = required(binary, hostname, Opts),
    Address = required(ipv4, address, Opts),
    Cookie = required(atom, cookie, Opts),
    CA = required(pem, ca, Opts),
    Fingerprint = required(fingerprint, fingerprint, Opts),
    Monitor = optional(bool, false, monitor, Opts),
    Request = {join, #peer{node = Node, hostname = Hostname,
                           address = Address, cookie = Cookie, ca = CA,
                           fingerprint = Fingerprint, monitor = Monitor}},
    gen_server:call(?SERVER, Request).

% Ask the server to forget Node. Returns true if Node was connected at that
% moment, false otherwise (including when it was never joined).
-spec leave(Node :: atom()) -> true | false.
leave(Node) when is_atom(Node) ->
    Request = {leave, Node},
    gen_server:call(?SERVER, Request).

% List every known peer with its current connection status.
-spec list() -> [#{nodename := atom(), connected := boolean()}].
list() ->
    Request = list,
    gen_server:call(?SERVER, Request).

% Tell whether a certificate fingerprint belongs to a registered peer.
% Reads the named ETS table directly, without going through the server;
% raises badarg if the table has not been created yet.
-spec is_allowed(CertFingerprint :: binary()) -> boolean().
is_allowed(CertFingerprint) ->
    ets:member(?FINGERPRINT_TABLE, CertFingerprint).


%--- Behaviour gen_server Callback Functions -----------------------------------

% Subscribe to node up/down events, create the fingerprint table read by
% is_allowed/1, restrict distribution to joined nodes and write the
% initial certificate files (board certificate, then an empty CA bundle).
init([]) ->
    MonitorOpts = #{node_type => all, nodedown_reason => true},
    net_kernel:monitor_nodes(true, MonitorOpts),
    % protected: only this server process writes, other processes may read.
    ets:new(?FINGERPRINT_TABLE, [named_table, protected, set, {keypos, 1}]),
    % Intended to prevent connecting to any node that was not joined.
    % NOTE(review): net_kernel documents allow/1 as additive only (nodes are
    % never removed again on leave); also confirm that an empty list really
    % blocks all nodes on the targeted OTP release.
    net_kernel:allow([]),
    State = store_board_certs(#state{}),
    {ok, store_ca_certs(State)}.

% Synchronous requests issued by join/2, leave/1 and list/0.
% Any other request is first answered with {error, unexpected_call}, then
% the server crashes with {unexpected_call, Request}.
handle_call({join, Peer}, _From, State) ->
    case join_node(State, Peer) of
        {error, NewState} ->
            {reply, error, NewState};
        {ok, Connected, NewState} ->
            {reply, Connected, NewState}
    end;
handle_call({leave, Node}, _From, State) ->
    {Reply, NewState} = leave_node(State, Node),
    {reply, Reply, NewState};
handle_call(list, _From, State) ->
    Nodes = list_node(State),
    {reply, Nodes, State};
handle_call(Request, From, _State) ->
    gen_server:reply(From, {error, unexpected_call}),
    erlang:error({unexpected_call, Request}).

% The server defines no casts; receiving one crashes it.
handle_cast(Request, _State) ->
    erlang:error({unexpected_cast, Request}).

% Raw messages: retry timers armed by schedule_retry/1 and node monitor
% events from net_kernel. Anything else is logged and dropped.
handle_info({retry_connecting, Node}, State) ->
    NewState = case retry_node(State, Node) of
        {error, S} -> S;
        {ok, _Connected, S} -> S
    end,
    {noreply, NewState};
handle_info({nodedown, Node, Info}, State) ->
    NewState = node_down(State, Node, Info),
    {noreply, NewState};
handle_info({nodeup, _Node, _Info}, State) ->
    % Nothing to do on nodeup.
    {noreply, State};
handle_info(Unexpected, State) ->
    ?LOG_WARNING(#{
        description => ?FORMAT("Received unexpected message: ~w", [Unexpected]),
        event => unexpected_message, message => Unexpected}),
    {noreply, State}.


%--- Internal Functions --------------------------------------------------------

% Clear the PEM cache used by TLS distribution.
% ssl:clear_pem_cache/0 does not handle the distribution cache, so its
% server is called directly through ssl_pem_cache (hence the dialyzer
% suppression for this function). A missing cache process is not an error.
clear_dist_pem_cache() ->
    try
        DistCache = ssl_pem_cache:name(dist),
        gen_server:call(DistCache,
                        {unconditionally_clear_pem_cache, self()},
                        infinity)
    catch
        exit:{noproc, _} ->
            % No distribution PEM cache running
            ok
    end.

% Fetch a mandatory option Key from Map and check it is of the given Type.
% Raises error(badarg) when the key is missing or the value is invalid.
%   atom        - any atom
%   binary      - any binary
%   fingerprint - a binary of exactly 32 bytes
%   pem         - a binary containing at least one decodable PEM entry
%   ipv4        - an inet:ip4_address(): four integers in 0..255
required(atom, Key, Map) ->
    case maps:find(Key, Map) of
        {ok, V} when is_atom(V) -> V;
        _ -> error(badarg)
    end;
required(binary, Key, Map) ->
    case maps:find(Key, Map) of
        {ok, V} when is_binary(V) -> V;
        _ -> error(badarg)
    end;
required(fingerprint, Key, Map) ->
    case required(binary, Key, Map) of
        V when byte_size(V) =:= 32 -> V;
        _ -> error(badarg)
    end;
required(pem, Key, Map) ->
    case maps:find(Key, Map) of
        {ok, V} when is_binary(V) ->
            % pem_decode/1 may raise on malformed content between the PEM
            % markers; report that as badarg like any other invalid option.
            try public_key:pem_decode(V) of
                [] -> error(badarg);
                _ -> V
            catch
                error:_ -> error(badarg)
            end;
        _ -> error(badarg)
    end;
required(ipv4, Key, Map) ->
    case maps:find(Key, Map) of
        {ok, IPv4} ->
            % Requires integer octets in 0..255 (rejects 256, floats, atoms).
            case inet:is_ipv4_address(IPv4) of
                true -> IPv4;
                false -> error(badarg)
            end;
        error ->
            error(badarg)
    end.

% Fetch an optional boolean option Key from Map, returning Default when the
% key is absent. A present but non-boolean value raises error(badarg).
optional(bool, Default, Key, Map) ->
    case maps:find(Key, Map) of
        error -> Default;
        {ok, Flag} when Flag =:= true; Flag =:= false -> Flag;
        {ok, _} -> error(badarg)
    end.

% On real hardware (when ?IS_EMULATED is false), read the primary
% certificate in DER form via grisp_cryptoauth and write it as PEM to
% /etc/board.pem. Does nothing when emulated. Returns State unchanged.
store_board_certs(State) ->
    case ?IS_EMULATED of
        false ->
            BoardCert = grisp_cryptoauth:read_cert(primary, der),
            ok = file:write_file("/etc/board.pem", der_list_to_pem([BoardCert]));
        true ->
            ok
    end,
    State.

% Rewrite the allowed CA chain file (path from the grisp_connect
% `allowed_ca_chain` environment key) with the distinct CA PEMs of all
% known peers, newline separated, then clear the distribution PEM cache.
% Returns State unchanged.
store_ca_certs(State = #state{peers = Peers}) ->
    {ok, Path} = application:get_env(grisp_connect, allowed_ca_chain),
    CAs = [Peer#peer.ca || Peer <- maps:values(Peers)],
    ok = file:write_file(Path, lists:join("\n", unique(CAs))),
    clear_dist_pem_cache(),
    State.

% Ping Peer and record the outcome in State. Returns:
%   {ok, true, State}  - connected; the peer is stored
%   {ok, false, State} - monitored peer unreachable; a retry is scheduled
%   {error, State}     - unmonitored peer unreachable; it is unregistered,
%                        removed and the CA bundle rewritten
connect_node(State, Peer = #peer{node = Node}) ->
    case connect_peer(Peer) of
        {ok, Joined} ->
            ?LOG_NOTICE(#{
                description => ?FORMAT("Joined node ~w cluster", [Node]),
                event => cluster_join, node => Node}),
            {ok, true, set_peer(State, Joined)};
        {error, Unreachable = #peer{monitor = Monitored}} ->
            case Monitored of
                true ->
                    ?LOG_DEBUG(#{
                        description => ?FORMAT("Failed to join node ~w cluster, postpone connection", [Node]),
                        event => cluster_join_postpone, node => Node}),
                    Retrying = schedule_retry(Unreachable),
                    {ok, false, set_peer(State, Retrying)};
                _ ->
                    ?LOG_DEBUG(#{
                        description => ?FORMAT("Failed to join node ~w cluster", [Node]),
                        event => cluster_join_failed, node => Node}),
                    Dropped = unregister_peer(State, Unreachable),
                    {error, store_ca_certs(del_peer(State, Dropped))}
            end
    end.

% Register Peer, or update the existing registration for the same node, then
% try to connect to it. Returns whatever connect_node/2 returns.
% Configurations are compared without timer_ref. The requested peer never
% carries a timer, while a stored peer may have a pending retry, so a
% whole-record comparison would treat every such re-join as a config change.
join_node(State, Peer = #peer{node = Node}) ->
    {State2, Peer2} = case find_peer(State, Node) of
        {ok, OldPeer} ->
            SameConfig = OldPeer#peer{timer_ref = undefined}
                =:= Peer#peer{timer_ref = undefined},
            case SameConfig of
                true ->
                    % Keep OldPeer so schedule_retry/1 can cancel its
                    % pending timer instead of arming a second one.
                    {State, OldPeer};
                false ->
                    ?LOG_DEBUG(#{
                        description => ?FORMAT("Update node ~w cluster configuration", [Node]),
                        event => cluster_update, node => Node}),
                    % Otherwise the old timer would keep firing for the new
                    % configuration alongside any new retry.
                    cancel_retry(OldPeer),
                    unregister_peer(State, disconnect_peer(OldPeer)),
                    NewPeer = register_peer(Peer),
                    {store_ca_certs(set_peer(State, NewPeer)), NewPeer}
            end;
        error ->
            ?LOG_DEBUG(#{
                description => ?FORMAT("Register node ~w cluster configuration", [Node]),
                event => cluster_register, node => Node}),
            NewPeer = register_peer(Peer),
            {store_ca_certs(set_peer(State, NewPeer)), NewPeer}
    end,
    connect_node(State2, Peer2).

% Cancel the peer's pending reconnection timer, if it has one.
cancel_retry(#peer{timer_ref = undefined}) ->
    ok;
cancel_retry(#peer{timer_ref = Ref}) ->
    erlang:cancel_timer(Ref),
    ok.

% Forget Node: cancel its pending retry timer, disconnect and unregister it,
% and rewrite the CA bundle from the remaining peers.
% Returns {WasConnected, NewState}, or {false, State} when Node is unknown.
leave_node(State, Node) ->
    case find_peer(State, Node) of
        error ->
            {false, State};
        {ok, Peer = #peer{timer_ref = TimerRef}} ->
            ?LOG_NOTICE(#{
                description => ?FORMAT("Leaved node ~w cluster", [Node]),
                event => cluster_leave, node => Node}),
            % An armed retry timer would otherwise outlive the peer and, if
            % the node re-joins before it fires, cause an extra retry.
            case TimerRef of
                undefined -> ok;
                _ -> erlang:cancel_timer(TimerRef)
            end,
            WasConnected = is_peer_connected(Peer),
            Peer2 = unregister_peer(State, disconnect_peer(Peer)),
            {WasConnected, store_ca_certs(del_peer(State, Peer2))}
    end.

% Summaries of all known peers for list/0: node name and whether it is
% currently connected.
list_node(#state{peers = Peers}) ->
    lists:map(
        fun(Peer) ->
            #{nodename => Peer#peer.node,
              connected => is_peer_connected(Peer)}
        end,
        maps:values(Peers)).

% Handle a fired retry timer for Node. The timer is cleared from the stored
% peer and a new connection attempt is made through connect_node/2.
% Returns {error, State} if Node is no longer known.
retry_node(State, Node) ->
    case find_peer(State, Node) of
        {ok, Peer} ->
            Cleared = Peer#peer{timer_ref = undefined},
            connect_node(set_peer(State, Cleared), Cleared);
        error ->
            {error, State}
    end.

% React to a nodedown event from net_kernel:monitor_nodes/2.
% Unknown nodes are ignored. For a known peer the disconnection is logged;
% a monitored peer gets a reconnection retry, while any other peer is
% unregistered and removed, and the CA bundle is rewritten.
node_down(State, Node, Info) ->
    #{node_type := NodeType, nodedown_reason := Reason} = Info,
    case find_peer(State, Node) of
        {ok, Peer = #peer{monitor = Monitored}} ->
            ?LOG_ERROR(#{
                description => ?FORMAT("Disconnected from node ~w cluster: ~w", [Node, Reason]),
                event => cluster_disconnected, node => Node,
                node_type => NodeType, reason => Reason}),
            case Monitored of
                true ->
                    set_peer(State, schedule_retry(Peer));
                false ->
                    ?LOG_NOTICE(#{
                        description => ?FORMAT("Leaved node ~w cluster", [Node]),
                        event => cluster_leaved, node => Node}),
                    Forgotten = unregister_peer(State, Peer),
                    store_ca_certs(del_peer(State, Forgotten))
            end;
        error ->
            State
    end.

% Insert or replace Peer in the state, keyed by its node name.
set_peer(State = #state{peers = Peers}, Peer = #peer{node = Node}) ->
    State#state{peers = maps:put(Node, Peer, Peers)}.

% Remove the peer with the same node name from the state (no-op if absent).
del_peer(State = #state{peers = Peers}, #peer{node = Node}) ->
    State#state{peers = maps:without([Node], Peers)}.

% Look up a peer by node name: {ok, Peer} or error.
find_peer(State = #state{}, Node) ->
    maps:find(Node, State#state.peers).

% Arm a ?RETRY_DELAY timer that sends {retry_connecting, Node} to this
% process, first cancelling any timer already recorded on the peer.
% Returns the peer carrying the new timer reference.
schedule_retry(Peer = #peer{node = Node, timer_ref = Pending}) ->
    case Pending of
        undefined -> ok;
        _ -> erlang:cancel_timer(Pending)
    end,
    Ref = erlang:send_after(?RETRY_DELAY, self(), {retry_connecting, Node}),
    Peer#peer{timer_ref = Ref}.

% Make Peer reachable and trusted: record its fingerprint for is_allowed/1,
% map its address to its host name in inet_db, set its per-node cookie and
% add it to the net_kernel allow list. Returns Peer unchanged.
register_peer(Peer = #peer{}) ->
    Node = Peer#peer.node,
    ets:insert(?FINGERPRINT_TABLE, {Peer#peer.fingerprint, Node}),
    inet_db:add_host(Peer#peer.address, [binary_to_list(Peer#peer.hostname)]),
    erlang:set_cookie(Node, Peer#peer.cookie),
    net_kernel:allow([Node]),
    Peer.

% Undo register_peer/1 for Peer. Its fingerprint is dropped from the table,
% and the inet_db host entry for its address is removed unless another peer
% in the state uses the same address, which is why the full state is
% required. The cookie and allow-list entries are left untouched.
% Returns Peer unchanged.
unregister_peer(#state{peers = Peers},
                Peer = #peer{node = Node, fingerprint = Fingerprint,
                             address = Address}) ->
    ets:delete(?FINGERPRINT_TABLE, Fingerprint),
    AddressShared = lists:any(
        fun(Other) ->
            Other#peer.node =/= Node andalso Other#peer.address =:= Address
        end,
        maps:values(Peers)),
    case AddressShared of
        true -> ok;
        false -> inet_db:del_host(Address)
    end,
    Peer.

% Ping the peer's node with net_adm:ping/1 and tag Peer with the outcome:
% {ok, Peer} on pong, {error, Peer} on pang.
connect_peer(Peer = #peer{node = Node}) ->
    Tag = case net_adm:ping(Node) of
        pong -> ok;
        pang -> error
    end,
    {Tag, Peer}.

% Drop the distribution connection to the peer's node; returns Peer.
disconnect_peer(#peer{node = Node} = Peer) ->
    erlang:disconnect_node(Node),
    Peer.

% True when the peer's node is among the visible connected nodes.
is_peer_connected(#peer{node = Node}) ->
    Connected = erlang:nodes(visible),
    lists:member(Node, Connected).

% Remove duplicates from a list, using a map as a set.
unique(List) ->
    Set = lists:foldl(fun(Item, Acc) -> Acc#{Item => true} end, #{}, List),
    maps:keys(Set).

% Convert a list of DER-encoded certificates into a PEM iolist, one
% CERTIFICATE block per entry, in order.
der_list_to_pem(DerCerts) when is_list(DerCerts) ->
    [der_to_pem_block(Der) || Der <- DerCerts].

% Wrap one DER certificate in PEM armour: its Base64 encoding split into
% 64-character lines between BEGIN/END CERTIFICATE markers (iolist).
der_to_pem_block(Der) when is_binary(Der) ->
    Lines = wrap_base64(base64:encode(Der), 64),
    Header = "-----BEGIN CERTIFICATE-----\n",
    Footer = "-----END CERTIFICATE-----\n",
    [Header, Lines, Footer].

% Split a Base64 binary into newline-terminated lines of at most LineLen
% bytes (iolist).
wrap_base64(Base64, LineLen) ->
    Acc0 = [],
    wrap_lines(Base64, LineLen, Acc0).

% Accumulate Data as LineLen-byte chunks, each followed by "\n"; a final
% shorter remainder also gets a trailing "\n". Acc holds the output reversed.
wrap_lines(<<>>, _LineLen, Acc) ->
    lists:reverse(Acc);
wrap_lines(Data, LineLen, Acc) when byte_size(Data) >= LineLen ->
    <<Line:LineLen/binary, Rest/binary>> = Data,
    wrap_lines(Rest, LineLen, ["\n", Line | Acc]);
wrap_lines(Tail, _LineLen, Acc) ->
    lists:reverse(Acc, [Tail, "\n"]).