Skip to main content

src/wa_raft_transport_worker.erl

%%% Copyright (c) Meta Platforms, Inc. and affiliates. All rights reserved.
%%%
%%% This source code is licensed under the Apache 2.0 license found in
%%% the LICENSE file in the root directory of this source tree.

-module(wa_raft_transport_worker).
-compile(warn_missing_spec_all).
-behaviour(gen_server).

-include_lib("wa_raft/include/wa_raft.hrl").
-include_lib("wa_raft/include/wa_raft_logger.hrl").

%% OTP supervision
-export([
    child_spec/2,
    start_link/2
]).

%% gen_server callbacks
-export([
    init/1,
    handle_call/3,
    handle_cast/2,
    handle_info/2,
    terminate/2
]).

%% Every callback returns this as its gen_server timeout. A zero timeout makes
%% the worker receive a `timeout` info message as soon as its mailbox is empty,
%% which is what drives processing of the next queued job in handle_info/2.
-define(CONTINUE_TIMEOUT, 0).

-record(state, {
    %% Node this worker was started for; passed to Module:transport_init/1.
    node :: node(),
    %% Worker index; used as the prefix of every log line.
    number :: non_neg_integer(),
    %% FIFO of pending work: transports to pop files from and files to send.
    jobs = queue:new() :: queue:queue(job()),
    %% Callback state of each transport module, keyed by module. The values are
    %% whatever transport_init/1 and transport_send/3 hand back (opaque to this
    %% worker), not this worker's own #state{} record.
    states = #{} :: #{module() => term()}
}).
-type state() :: #state{}.

%% Job: fetch the next pending file of transport `id`.
-record(transport, {
    id :: wa_raft_transport:transport_id(),
    table :: wa_raft:table() | undefined
}).
%% Job: (continue to) send file `file` belonging to transport `id`.
-record(file, {
    id :: wa_raft_transport:transport_id(),
    table :: wa_raft:table() | undefined,
    file :: wa_raft_transport:file_id()
}).
-type job() :: #transport{} | #file{}.

%%% ------------------------------------------------------------------------
%%%  OTP supervision callbacks
%%%

-spec child_spec(Node :: node(), Number :: non_neg_integer()) -> supervisor:child_spec().
%% Supervisor child specification for the Number-th worker serving Node.
%% The {module, Node, Number} id keeps sibling workers distinct.
child_spec(Node, Number) ->
    StartArgs = [Node, Number],
    #{
        id => {?MODULE, Node, Number},
        modules => [?MODULE],
        restart => permanent,
        shutdown => 5000,
        start => {?MODULE, start_link, StartArgs}
    }.

-spec start_link(Node :: node(), Number :: non_neg_integer()) -> gen_server:start_ret().
%% Start an unregistered worker linked to the calling process; init/1
%% receives {Node, Number}.
start_link(Node, Number) ->
    InitArgs = {Node, Number},
    gen_server:start_link(?MODULE, InitArgs, []).

%%% ------------------------------------------------------------------------
%%%  gen_server callbacks
%%%

-spec init(Args :: {node(), non_neg_integer()}) -> {ok, State :: state(), Timeout :: timeout()}.
%% Build the initial (empty-queue) state; the continue timeout immediately
%% triggers the job loop in handle_info/2.
init({Node, Number}) ->
    InitialState = #state{number = Number, node = Node},
    {ok, InitialState, ?CONTINUE_TIMEOUT}.

-spec handle_call(Request :: term(), From :: {Pid :: pid(), Tag :: term()}, State :: state()) ->
    {noreply, NewState :: state(), Timeout :: timeout()}.
%% The worker has no synchronous API: every call is logged and ignored.
%% No reply is ever sent, so the caller waits until its own call timeout.
handle_call(Request, From, State) ->
    #state{number = Number} = State,
    ?RAFT_LOG_WARNING("[~p] received unrecognized call ~p from ~p", [Number, Request, From]),
    {noreply, State, ?CONTINUE_TIMEOUT}.

-spec handle_cast(Request, State :: state()) -> {noreply, NewState :: state(), Timeout :: timeout()}
    when Request :: {notify, wa_raft_transport:transport_id(), wa_raft:table() | undefined}.
%% `notify` enqueues a transport job so its pending files get picked up by the
%% job loop; any other cast is logged and ignored.
handle_cast({notify, ID, Table}, #state{jobs = Jobs} = State) ->
    TransportJob = #transport{id = ID, table = Table},
    UpdatedJobs = queue:in(TransportJob, Jobs),
    {noreply, State#state{jobs = UpdatedJobs}, ?CONTINUE_TIMEOUT};
handle_cast(Request, State) ->
    #state{number = Number} = State,
    ?RAFT_LOG_WARNING("[~p] received unrecognized cast ~p", [Number, Request]),
    {noreply, State, ?CONTINUE_TIMEOUT}.

-spec handle_info(Info :: term(), State :: state()) ->
      {noreply, NewState :: state()}
    | {noreply, NewState :: state(), Timeout :: timeout() | hibernate}.
%% The `timeout` message (produced by returning ?CONTINUE_TIMEOUT from every
%% callback) drives the worker: each one processes exactly one queued job.
%%  - #transport{} jobs pop the transport's next pending file and enqueue a
%%    #file{} job for it.
%%  - #file{} jobs make one transport_send/3 attempt; `continue` requeues the
%%    file, anything else completes it and requeues the transport.
%% Any other info message is logged and ignored.
handle_info(timeout, #state{jobs = Jobs} = State) ->
    case queue:out(Jobs) of
        {empty, NewJobs} ->
            %% Nothing queued: hibernate until the next notify cast arrives.
            {noreply, State#state{jobs = NewJobs}, hibernate};
        {{value, #transport{} = Job}, NewJobs} ->
            handle_transport_job(Job, State#state{jobs = NewJobs});
        {{value, #file{} = Job}, NewJobs} ->
            handle_file_job(Job, State#state{jobs = NewJobs})
    end;
handle_info(Info, #state{number = Number} = State) ->
    ?RAFT_LOG_WARNING("[~p] received unrecognized info ~p", [Number, Info]),
    {noreply, State, ?CONTINUE_TIMEOUT}.

%% Pop the next pending file of a transport, mark it as sending, and schedule it
%% as a #file{} job. If pop_file/1 yields anything but {ok, FileID}, the
%% transport job is dropped.
-spec handle_transport_job(Job :: #transport{}, State :: state()) ->
    {noreply, NewState :: state(), Timeout :: timeout()}.
handle_transport_job(#transport{id = ID, table = Table}, #state{jobs = Jobs} = State) ->
    case wa_raft_transport:pop_file(ID) of
        {ok, FileID} ->
            ?RAFT_COUNT(Table, 'transport.file.send'),
            wa_raft_transport:update_file_info(ID, FileID,
                fun (Info) -> Info#{status => sending, start_ts => erlang:system_time(millisecond)} end),
            NewJob = #file{id = ID, table = Table, file = FileID},
            {noreply, State#state{jobs = queue:in(NewJob, Jobs)}, ?CONTINUE_TIMEOUT};
        _Other ->
            {noreply, State, ?CONTINUE_TIMEOUT}
    end.

%% Make one send attempt for a file. On `continue`, the same file job is
%% requeued. Any other result completes the file and requeues its transport
%% so the next file is picked up. The table is carried along so the
%% 'transport.file.send' counter stays attributed to the right table.
-spec handle_file_job(Job :: #file{}, State :: state()) ->
    {noreply, NewState :: state(), Timeout :: timeout()}.
handle_file_job(#file{id = ID, table = Table, file = FileID} = Job, State) ->
    case send_file(ID, FileID, State) of
        {continue, #state{jobs = Jobs} = NewState} ->
            {noreply, NewState#state{jobs = queue:in(Job, Jobs)}, ?CONTINUE_TIMEOUT};
        {Result, #state{jobs = Jobs} = NewState} ->
            wa_raft_transport:complete(ID, FileID, Result),
            NextJob = #transport{id = ID, table = Table},
            {noreply, NewState#state{jobs = queue:in(NextJob, Jobs)}, ?CONTINUE_TIMEOUT}
    end.

%% Resolve the transport's module and its callback state, then delegate the
%% send. Returns the send result (ok | continue | {stop, Reason} | {Class, Error}
%% on exception, or whatever non-{ok, _} value get_module_state/2 produced)
%% together with the updated worker state.
-spec send_file(ID :: wa_raft_transport:transport_id(), FileID :: wa_raft_transport:file_id(), State :: state()) ->
    {Result :: term(), NewState :: state()}.
send_file(ID, FileID, #state{number = Number} = State) ->
    case wa_raft_transport:transport_info(ID) of
        {ok, #{module := Module}} ->
            %% Only get_module_state/2 is protected by this try; the `of` body
            %% (including send_with_module/5) is outside its scope.
            try get_module_state(Module, State) of
                {ok, ModuleState} ->
                    send_with_module(Module, ID, FileID, ModuleState, State);
                Other ->
                    {Other, State}
            catch
                T:E:S ->
                    ?RAFT_LOG_WARNING(
                        "[~p] module ~p failed to get/init module state due to ~p ~p: ~p",
                        [Number, Module, T, E, S]
                    ),
                    {{T, E}, State}
            end;
        _ ->
            ?RAFT_LOG_WARNING("[~p] trying to send for unknown transfer ~p", [Number, ID]),
            {{stop, invalid_transport}, State}
    end.

%% Call Module:transport_send/3 once and store the module's updated callback
%% state. An exception is logged and turned into a {Class, Error} result, and
%% the module state is left unchanged.
-spec send_with_module(Module :: module(), ID :: wa_raft_transport:transport_id(),
                       FileID :: wa_raft_transport:file_id(), ModuleState :: term(),
                       State :: state()) ->
    {Result :: term(), NewState :: state()}.
send_with_module(Module, ID, FileID, ModuleState0, #state{number = Number, states = States} = State) ->
    try Module:transport_send(ID, FileID, ModuleState0) of
        {ok, ModuleState1} ->
            {ok, State#state{states = States#{Module => ModuleState1}}};
        {continue, ModuleState1} ->
            {continue, State#state{states = States#{Module => ModuleState1}}};
        {stop, Reason, ModuleState1} ->
            {{stop, Reason}, State#state{states = States#{Module => ModuleState1}}}
    catch
        T:E:S ->
            ?RAFT_LOG_WARNING(
                "[~p] module ~p failed to send file ~p:~p due to ~p ~p: ~p",
                [Number, Module, ID, FileID, T, E, S]
            ),
            {{T, E}, State}
    end.

-spec terminate(term(), state()) -> ok.
%% Give every transport module with cached state the chance to clean up through
%% its optional transport_terminate/2 callback. Each callback runs in isolation,
%% so a crash in one module does not skip cleanup for the remaining modules.
terminate(Reason, #state{number = Number, states = States}) ->
    maps:foreach(
        fun (Module, ModuleState) -> terminate_module(Number, Module, Reason, ModuleState) end,
        States
    ).

%% Invoke Module:transport_terminate/2 when exported; failures are logged rather
%% than propagated.
-spec terminate_module(Number :: non_neg_integer(), Module :: module(), Reason :: term(),
                       ModuleState :: term()) -> ok.
terminate_module(Number, Module, Reason, ModuleState) ->
    case erlang:function_exported(Module, transport_terminate, 2) of
        true ->
            try Module:transport_terminate(Reason, ModuleState) of
                _ -> ok
            catch
                T:E:S ->
                    ?RAFT_LOG_WARNING(
                        "[~p] module ~p failed to terminate due to ~p ~p: ~p",
                        [Number, Module, T, E, S]
                    ),
                    ok
            end;
        false ->
            ok
    end.

%% Return the cached callback state for Module, or obtain a fresh one from
%% Module:transport_init/1 on first use. The freshly initialized state is not
%% cached by this function. The returned state belongs to the transport module
%% and is opaque here; it is not this worker's #state{} record.
-spec get_module_state(module(), state()) -> {ok, ModuleState :: term()} | {stop, Reason :: term()}.
get_module_state(Module, #state{node = Node, states = States}) ->
    case States of
        #{Module := ModuleState} -> {ok, ModuleState};
        _                        -> Module:transport_init(Node)
    end.