src/mqtt_sessions_router.erl

%% @doc Process owning the MQTT topic router.
%% @author Marc Worrell <marc@worrell.nl>
%% @copyright 2018-2020 Marc Worrell

%% Copyright 2018-2020 Marc Worrell
%%
%% Licensed under the Apache License, Version 2.0 (the "License");
%% you may not use this file except in compliance with the License.
%% You may obtain a copy of the License at
%%
%%     http://www.apache.org/licenses/LICENSE-2.0
%%
%% Unless required by applicable law or agreed to in writing, software
%% distributed under the License is distributed on an "AS IS" BASIS,
%% WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
%% See the License for the specific language governing permissions and
%% limitations under the License.

-module(mqtt_sessions_router).

-behaviour(gen_server).

-export([
    publish/3,
    publish/4,
    subscribe/4,
    subscribe/6,
    unsubscribe/3,
    unsubscribe_pid/2,
    start_link/1,
    name/1,
    info/1
    ]).

-export([
    init/1,
    handle_call/3,
    handle_cast/2,
    handle_info/2,
    code_change/3,
    terminate/2
    ]).

-type mqtt_msg() :: #{
        pool => atom(),
        topic => list( binary() ),
        topic_bindings => list( proplists:property() ),
        message => mqtt_packet_map:mqtt_packet(),
        publisher_context => term(),
        subscriber_context => term(),
        no_local => boolean(),
        qos => 0 | 1 | 2,
        retain_as_published => boolean(),
        retain_handling => integer()
    }.

-type subscriber_options() :: #{
        subscriber_context => term(),
        no_local => boolean(),
        qos => 0 | 1 | 2,
        retain_as_published => boolean(),
        retain_handling => integer()
    }.

-type subscriber() :: {pid() | mfa(), OwnerPid::pid(), subscriber_options()}.


-export_type([
    subscriber/0,
    mqtt_msg/0,
    subscriber_options/0
]).

-record(state, {
        pool :: atom(),
        router :: router:router(),
        monitors :: map()
    }).

-include_lib("router/include/router.hrl").
-include_lib("../include/mqtt_sessions.hrl").


-spec publish( atom(), list(), mqtt_packet_map:mqtt_packet() ) -> {ok, pid() | undefined} | {error, overload}.
publish( Pool, Topic, Msg ) ->
    publish(Pool, Topic, Msg, undefined).

-spec publish( atom(), list(), mqtt_packet_map:mqtt_packet(), term() ) -> {ok, pid() | undefined} | {error, overload}.
publish( Pool, Topic0, Msg, PublisherContext ) ->
    Topic = publish_topic(Topic0),
    Pool1 = maybe_map_to_default_pool(Pool, Topic),
    Routes = router:route(Pool1, Topic),
    case mqtt_sessions_job:publish(Pool1, Topic, Routes, Msg, PublisherContext) of
        {ok, JobPid} ->
            case maps:get(retain, Msg, false) of
                true -> mqtt_sessions_retain:retain(Pool1, Msg, PublisherContext);
                false -> ok
            end,
            {ok, JobPid};
        {error, _} = Error ->
            Error
    end.


-spec subscribe( atom(), list(), mqtt_sessions:callback(), term() ) -> ok | {error, invalid_subscriber}.
subscribe( Pool, Topic, {_, _, _} = MFA, SubscriberContext) ->
    subscribe( Pool, Topic, MFA, self(), #{}, SubscriberContext);
subscribe( Pool, Topic, Pid, SubscriberContext) when is_pid(Pid) ->
    subscribe( Pool, Topic, Pid, Pid, #{}, SubscriberContext).

subscribe( Pool, TopicFilter, Subscriber, OwnerPid, Options, SubscriberContext ) when is_pid(OwnerPid), is_map(Options) ->
    case is_valid_subscriber(Subscriber) of
        true ->
            Pool1 = maybe_map_to_default_pool(Pool, TopicFilter),
            case gen_server:call(name(Pool1), {subscribe, TopicFilter, Subscriber, OwnerPid, Options}, infinity) of
                {ok, IsNew} ->
                    % Check retained messages, publish to the Subscriber
                    maybe_publish_retained(Pool1, IsNew, TopicFilter, Subscriber, Options, SubscriberContext),
                    ok;
                {error, _} = Error ->
                    Error
            end;
        false ->
            {error, invalid_subscriber}
    end.

maybe_publish_retained(Pool, IsNew, TopicFilter, Subscriber, Options, SubscriberContext) ->
    case maps:get(retain_handling, Options, 0) of
        0 ->
            % All retained messages
            publish_retained(Pool, TopicFilter, Subscriber, Options, SubscriberContext);
        1 when IsNew ->
            % Only if new subscription
            publish_retained(Pool, TopicFilter, Subscriber, Options, SubscriberContext);
        _ ->
            ok
    end.

publish_retained(Pool, TopicFilter, Subscriber, Options, SubscriberContext) ->
    {ok, Ms} = mqtt_sessions_retain:lookup(Pool, TopicFilter),
    mqtt_sessions_job:publish_retained(Pool, TopicFilter, Ms, Subscriber, Options, SubscriberContext).

-spec unsubscribe( atom(), list(), pid() ) -> ok | {error, notfound}.
unsubscribe( Pool, TopicFilter, Pid ) ->
    Pool1 = maybe_map_to_default_pool(Pool, TopicFilter),
    gen_server:call(name(Pool1), {unsubscribe, TopicFilter, Pid}, infinity).

-spec unsubscribe_pid( atom(), pid() ) -> ok.
unsubscribe_pid( Pool, Pid ) ->
    gen_server:cast(name(Pool), {unsubscribe_pid, Pid}).


-spec start_link( atom() ) -> {ok, pid()} | {error, term()}.
start_link( Pool ) ->
    gen_server:start_link({local, name(Pool)}, ?MODULE, [Pool], []).


is_valid_subscriber({M, F, A}) when is_atom(M), is_atom(F), is_list(A) -> true;
is_valid_subscriber(Pid) when is_pid(Pid) -> true;
is_valid_subscriber(_) -> false.


% ---------------------------------------------------------------------------------------
% --------------------------- gen_server functions --------------------------------------
% ---------------------------------------------------------------------------------------

-spec init( [ atom() ]) -> {ok, #state{}}.
init([ Pool ]) ->
    {ok, #state{
        pool = Pool,
        router = router:new(Pool),
        monitors = #{}
    }}.

handle_call({subscribe, TopicFilter0, Subscriber, OwnerPid, Options}, _From,
            #state{ router = Router, monitors = Monitors } = State) ->
    TopicFilter = subscribe_topic(TopicFilter0),
    Current = maps:get(OwnerPid, Monitors, []),
    {Current1, IsNew} = case lists:keysearch(TopicFilter, 1, Current) of
        {value, {_Filter, PrevSubscriber}} ->
            router:remove_path(Router, TopicFilter, PrevSubscriber),
            {lists:keydelete(TopicFilter, 1, Current), false};
        false ->
            {Current, true}
    end,
    Destination = {Subscriber, OwnerPid, Options},
    ok = router:add(Router, TopicFilter, Destination),
    case maps:is_key(OwnerPid, Monitors) of
        false -> erlang:monitor(process, OwnerPid);
        true -> ok
    end,
    Monitors1 = Monitors#{
        OwnerPid => [ {TopicFilter, Destination} | Current1
    ]},
    {reply, {ok, IsNew}, State#state{ monitors = Monitors1 }};
handle_call({unsubscribe, TopicFilter0, Pid}, _From,
            #state{ router = Router, monitors = Monitors } = State) ->
    TopicFilter = subscribe_topic(TopicFilter0),
    Subs = maps:get(Pid, Monitors, []),
    case lists:keysearch(TopicFilter, 1, Subs) of
        {value, {_Filter, Destination}} ->
            router:remove_path(Router, TopicFilter, Destination),
            Subs1 = lists:keydelete(TopicFilter, 1, Subs),
            Monitors1 = Monitors#{ Pid => Subs1 },
            {reply, ok, State#state{ monitors = Monitors1 }};
        false ->
            {reply, {error, notfound}, State}
    end;
handle_call(Cmd, _From, State) ->
    {stop, {unknown_cmd, Cmd}, State}.

handle_cast({unsubscribe_pid, Pid}, State) ->
    {noreply, remove_subscriber(Pid, State)};
handle_cast(Cmd, State) ->
    {stop, {unknown_cmd, Cmd}, State}.

handle_info({'DOWN', _Mref, process, Pid, _Reason}, State) ->
    {noreply, remove_subscriber(Pid, State)};
handle_info(_Info, State) ->
    {noreply, State}.

code_change(_Vsn, State, _Extra) ->
    {ok, State}.

terminate(_Reason, _State) ->
    ok.

% ---------------------------------------------------------------------------------------
% ----------------------------- support functions ---------------------------------------
% ---------------------------------------------------------------------------------------


subscribe_topic(TopicFilter) ->
    lists:map(
        fun
            ('#') -> '#';
            ('+') -> '+';
            (<<"#">>) -> '#';
            (<<"+">>) -> '+';
            (T) when is_integer(T) -> integer_to_binary(T);
            (T) when is_binary(T) -> T;
            (T) when is_atom(T) -> atom_to_binary(T, utf8)
        end,
        TopicFilter).

publish_topic(Topic) ->
    lists:map(
        fun
            (T) when is_integer(T) -> integer_to_binary(T);
            (T) when is_binary(T) -> T;
            (T) when is_atom(T) -> atom_to_binary(T, utf8)
        end,
        Topic).


%% @doc Remove all subscriptions belonging to a certain process
remove_subscriber(Pid, #state{ router = Router, monitors = Monitors } = State) ->
    lists:foreach(
        fun({TopicFilter, Subscriber}) ->
            router:remove_path(Router, TopicFilter, Subscriber)
        end,
        maps:get(Pid, Monitors, [])),
    State#state{ monitors = maps:remove(Pid, Monitors) }.


-spec name( atom() ) -> atom().
name( Pool ) ->
    list_to_atom(atom_to_list(Pool) ++ "$router").

-spec info( atom() ) -> list().
info( Pool ) ->
    router:info(Pool).

%% @doc For subscribtions to a $SYS topic, we map to the default pool.
-spec maybe_map_to_default_pool( atom(), mqtt_sessions:topic()) -> atom().
maybe_map_to_default_pool(_Pool, [<<"$SYS">> | _]) -> ?MQTT_SESSIONS_DEFAULT;
maybe_map_to_default_pool(Pool, _Topic) -> Pool.