%% @doc Process handling one single MQTT session.
%% Transports attach to and detach from this session.
%% @author Marc Worrell <marc@worrell.nl>
%% @copyright 2018-2022 Marc Worrell
%% Copyright 2018-2022 Marc Worrell
%%
%% Licensed under the Apache License, Version 2.0 (the "License");
%% you may not use this file except in compliance with the License.
%% You may obtain a copy of the License at
%%
%% http://www.apache.org/licenses/LICENSE-2.0
%%
%% Unless required by applicable law or agreed to in writing, software
%% distributed under the License is distributed on an "AS IS" BASIS,
%% WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
%% See the License for the specific language governing permissions and
%% limitations under the License.
%% TODO: Limit in-flight acks (both ways)
%% TODO: Refuse incoming publish messages if too many publish_jobs
%% TODO: Limit incoming_data buffer size
% MQTTv3.1.1 spec http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html
% MQTTv5 spec http://docs.oasis-open.org/mqtt/mqtt/v5.0/cos01/mqtt-v5.0-cos01.html
-module(mqtt_sessions_process).
-behaviour(gen_server).
-export([
get_user_context/1,
set_user_context/2,
update_user_context/2,
get_transport/1,
kill/1,
incoming_connect/3,
incoming_data/2,
fetch_queue/1,
start_link/3
]).
-export([
init/1,
handle_call/3,
handle_cast/2,
handle_info/2,
code_change/3,
terminate/2
]).
%% Highest packet id we hand out; inc_packet_id/1 wraps around to 1 after it.
-define(MAX_PACKET_ID, 65535).
-define(QUEUE_PURGE_LEN, 500). % Pending queue length above which maybe_purge/1 purges before queueing more.
%% Default value of the send_quota field of #state{}.
-define(RECEIVE_MAXIMUM, 65535).
-define(KEEP_ALIVE_DEFAULT, 30). % Default keep alive in seconds, used until a CONNECT provides one
-define(SESSION_EXPIRY, 600). % Session expiry of a fresh session, before any CONNECT was accepted
-define(SESSION_EXPIRY_DEFAULT, 3600). % Session expiry for an MQTT v5 CONNECT without a session_expiry_interval property
%% Added to the queue time of a pending message that has no message_expiry_interval
%% property (same unit as mqtt_sessions_timestamp:timestamp(), presumably seconds — TODO confirm).
-define(MESSAGE_EXPIRY_DEFAULT, 3600).
%% Milliseconds kill/1 waits for a graceful stop before killing the process outright.
-define(KILL_TIMEOUT, 5000).
%% MQTT packet identifier; 0 is used for relayed QoS 0 publishes, which carry no real id.
-type packet_id() :: 0..65535. % ?MAX_PACKET_ID
-record(state, {
% MQTT protocol version (4 or 5), set by the first accepted CONNECT and never changed afterwards
protocol_version :: mqtt_packet_map:mqtt_version(),
% Pool this session is registered in
pool :: atom(),
% Callback module (application env mqtt_sessions/runtime) for authentication, ACLs and user contexts
runtime :: atom(),
client_id :: binary(),
% Routing id of the pool, sent to the client as the "cotonic-routing-id" CONNACK property
routing_id :: binary(),
% Opaque runtime context, passed to every runtime callback
user_context :: term(),
% Outgoing data sink: a pid receiving {mqtt_transport, SessionPid, Data} messages, or a fun called with the data
transport = undefined :: pid() | function() | undefined,
% Monitored connection process; incoming_data is only accepted from this pid
connection_pid = undefined :: pid() | undefined,
% Becomes true after the first successful CONNACK
is_session_present = false :: boolean(),
% Undeliverable CONNACK or AUTH, handed out by fetch_queue before the pending queue
pending_connack = undefined :: term(),
% #queued{} records waiting to be fetched by the transport
pending :: queue:queue(),
% Last packet id used for a packet sent by us
packet_id = 1 :: packet_id(),
send_quota = ?RECEIVE_MAXIMUM :: non_neg_integer(),
% Initiated by server: PacketId => {MsgNr, puback | pubrec | pubcomp, Msg}
awaiting_ack = #{} :: map(), % Initiated by server
% Initiated by client: PacketId => {pubrel, ReasonCode, Timestamp} for QoS 2 publishes
awaiting_rel = #{} :: map(), % Initiated by client
% Will extracted from the CONNECT, handed to the will watchdog on CONNACK
will = undefined :: undefined | map(),
% Monitored will watchdog process; the session stops when it goes down
will_pid = undefined :: undefined | pid(),
% Increasing counter that orders queued and resent messages
msg_nr = 0 :: non_neg_integer(),
% Keep alive in seconds; the keep-alive timer ticks every keep_alive * 500 ms
keep_alive = ?KEEP_ALIVE_DEFAULT :: non_neg_integer(),
% Ticks left before disconnecting; reset to 3 whenever data arrives
keep_alive_counter = 3 :: integer(),
% Reference of the current keep-alive timer; ticks with another reference are ignored
keep_alive_ref :: undefined | reference(),
session_expiry_interval = ?SESSION_EXPIRY :: non_neg_integer(),
% Number of times we had a successful connect to this session
% NOTE(review): not updated anywhere in this module — confirm whether it is still used.
connect_count = 0 :: non_neg_integer(),
% Buffering incoming data for a complete packet
incoming_data = <<>> :: binary(),
% Tracking publish jobs: JobPid => monitor reference
publish_jobs = #{} :: map()
}).
%% An outgoing packet waiting in the pending queue.
-record(queued, {
% Packet type, e.g. publish
type :: atom(),
% Value of #state.msg_nr when queued
msg_nr :: pos_integer(),
packet_id = undefined :: undefined | non_neg_integer(),
% mqtt_sessions_timestamp:timestamp() at queue time
queued :: non_neg_integer(),
% Timestamp after which purge/1 may drop the message
expiry :: non_neg_integer(),
% QoS of the message; queue_1/2 uses 1 when the packet has no qos
qos :: 0..2,
message :: mqtt_packet_map:mqtt_packet()
}).
-include_lib("kernel/include/logger.hrl").
-include_lib("mqtt_packet_map/include/mqtt_packet_map.hrl").
-include_lib("../include/mqtt_sessions.hrl").
%% @doc Fetch the user context held by the session process.
%% Returns {error, noproc} when the session process is not running.
-spec get_user_context( pid() ) -> {ok, term()} | {error, noproc}.
get_user_context(Pid) ->
    try gen_server:call(Pid, get_user_context, infinity) of
        Reply -> Reply
    catch
        exit:{noproc, _} -> {error, noproc}
    end.
%% @doc Replace the user context held by the session process.
%% Returns {error, noproc} when the session process is not running.
-spec set_user_context( pid(), term() ) -> ok | {error, noproc}.
set_user_context(Pid, UserContext) ->
    Request = {set_user_context, UserContext},
    try gen_server:call(Pid, Request, infinity) of
        Reply -> Reply
    catch
        exit:{noproc, _} -> {error, noproc}
    end.
%% @doc Update the user context with Fun; Fun runs inside the session process.
%% Returns {error, noproc} when the session process is not running.
-spec update_user_context( pid(), fun( (term()) -> term() ) ) -> ok | {error, noproc}.
update_user_context(Pid, Fun) ->
    Request = {update_user_context, Fun},
    try gen_server:call(Pid, Request, infinity) of
        Reply -> Reply
    catch
        exit:{noproc, _} -> {error, noproc}
    end.
%% @doc Return the transport currently attached to the session.
%% Returns {error, notransport} when nothing is attached and {error, noproc} when the
%% session process is not running.
-spec get_transport( pid() ) -> {ok, pid()} | {error, notransport | noproc}.
get_transport(Pid) ->
    try gen_server:call(Pid, get_transport, infinity) of
        Reply -> Reply
    catch
        exit:{noproc, _} -> {error, noproc}
    end.
%% @doc Stop the session process and wait until it is gone.
%% First a graceful stop is requested with a 'kill' cast. If the process is still alive
%% after ?KILL_TIMEOUT milliseconds, it is killed unconditionally. Always returns ok,
%% also when the process was already dead (the monitor then fires immediately).
-spec kill( pid() ) -> ok.
kill(Pid) when is_pid(Pid) ->
    MRef = erlang:monitor(process, Pid),
    gen_server:cast(Pid, kill),
    case wait_for_down(MRef, Pid, ?KILL_TIMEOUT) of
        ok ->
            ok;
        timeout ->
            erlang:exit(Pid, kill),
            wait_for_down(MRef, Pid, infinity)
    end.

%% Wait for the 'DOWN' message of monitor MRef; returns timeout if none arrives in time.
wait_for_down(MRef, Pid, Timeout) ->
    receive
        {'DOWN', MRef, process, Pid, _Reason} -> ok
    after Timeout ->
        timeout
    end.
%% @doc Hand a decoded CONNECT packet to the session, asynchronously.
%% Options must be a map; set_connection/2 later reads its connection_pid and transport keys.
-spec incoming_connect(pid(), mqtt_packet_map:mqtt_packet(), mqtt_sessions:msg_options()) -> ok.
incoming_connect(Pid, Msg, Options) when is_map(Options) ->
    Request = {incoming_connect, Msg, Options},
    gen_server:cast(Pid, Request).
%% @doc Feed raw bytes received by a connection into the session.
%% The calling process identifies the connection. Data from any process other than the
%% session's current connection is rejected with {error, wrong_connection}.
-spec incoming_data(pid(), binary()) -> ok | {error, wrong_connection | mqtt_packet_map:decode_error()}.
incoming_data(Pid, Data) ->
    ConnectionPid = self(),
    % 5000 ms is the gen_server:call/2 default timeout.
    gen_server:call(Pid, {incoming_data, Data, ConnectionPid}, 5000).
%% @doc Fetch, and remove from the session, the packets queued for the client.
%% The session replies with the pending CONNACK/AUTH if there is one. Otherwise it
%% replies with all queued messages. Either way the packets arrive as one encoded binary
%% (handle_call/3 replies {ok, encode(...)}), so the spec says binary(), not a list.
-spec fetch_queue(pid()) -> {ok, binary()}.
fetch_queue( Pid ) ->
    gen_server:call(Pid, fetch_queue, infinity).
%% @doc Start a linked session process for ClientId in Pool.
-spec start_link( Pool::atom(), ClientId::binary(), mqtt_sessions:session_options() ) -> {ok, pid()}.
start_link( Pool, ClientId, SessionOptions ) ->
    InitArgs = [ Pool, ClientId, SessionOptions ],
    gen_server:start_link(?MODULE, InitArgs, []).
% ---------------------------------------------------------------------------------------
% --------------------------- gen_server functions --------------------------------------
% ---------------------------------------------------------------------------------------
%% @doc Register the session, start and monitor its will watchdog, schedule the first
%% keep-alive tick (after half the default keep-alive period) and create the runtime
%% user context.
init([ Pool, ClientId, SessionOptions ]) ->
    Self = self(),
    RoutingId = mqtt_sessions_registry:routing_id(Pool),
    mqtt_sessions_registry:register(Pool, ClientId, Self),
    {ok, WillPid} = mqtt_sessions_will_sup:start(Pool, Self),
    {ok, Runtime} = application:get_env(mqtt_sessions, runtime),
    % handle_info/2 stops the session when the will watchdog goes down.
    erlang:monitor(process, WillPid),
    KeepAliveRef = erlang:make_ref(),
    erlang:send_after(?KEEP_ALIVE_DEFAULT * 500, Self, {keep_alive, KeepAliveRef}),
    UserContext = Runtime:new_user_context(Pool, ClientId, SessionOptions#{ routing_id => RoutingId }),
    State = #state{
        pool = Pool,
        runtime = Runtime,
        user_context = UserContext,
        client_id = ClientId,
        routing_id = RoutingId,
        pending = queue:new(),
        will_pid = WillPid,
        keep_alive = ?KEEP_ALIVE_DEFAULT,
        keep_alive_counter = 3,
        keep_alive_ref = KeepAliveRef
    },
    {ok, State}.
%% @doc Synchronous requests.
%% - fetch_queue: the pending CONNACK/AUTH if any, otherwise all queued messages, encoded
%%   as one binary; whatever is returned is removed from the session.
%% - get_user_context / set_user_context / update_user_context: access the user context.
%% - get_transport: the attached transport, or {error, notransport}.
%% - {incoming_data, Data, ConnectionPid}: decode and handle data from the current connection.
%% Any other request stops the session with reason {unknown_cmd, Cmd}.
handle_call(fetch_queue, _From, #state{ pending_connack = undefined, pending = Pending, protocol_version = PV } = State) ->
    Packets = [ Packet || #queued{ message = Packet } <- queue:to_list(Pending) ],
    {reply, {ok, encode(PV, Packets)}, State#state{ pending = queue:new() }};
handle_call(fetch_queue, _From, #state{ pending_connack = ConnAck, protocol_version = PV } = State) ->
    % The CONNACK/AUTH goes out on its own; the pending queue is left for the next fetch.
    {reply, {ok, encode(PV, ConnAck)}, State#state{ pending_connack = undefined }};
handle_call(get_user_context, _From, #state{ user_context = UserContext } = State) ->
    {reply, {ok, UserContext}, State};
handle_call({set_user_context, NewUserContext}, _From, State) ->
    {reply, ok, State#state{ user_context = NewUserContext }};
handle_call({update_user_context, Fun}, _From, #state{ user_context = UserContext } = State) ->
    {reply, ok, State#state{ user_context = Fun(UserContext) }};
handle_call(get_transport, _From, #state{ transport = Transport } = State) ->
    Reply = case Transport of
        undefined -> {error, notransport};
        _ -> {ok, Transport}
    end,
    {reply, Reply, State};
handle_call({incoming_data, NewData, ConnectionPid}, _From,
            #state{ connection_pid = ConnectionPid, incoming_data = Buffered } = State) ->
    case handle_incoming_data(<< Buffered/binary, NewData/binary >>, State) of
        {ok, {Rest, State1}} ->
            % Any received data counts as keep-alive activity.
            {reply, ok, State1#state{ keep_alive_counter = 3, incoming_data = Rest }};
        {error, Reason} when is_atom(Reason) ->
            % Illegal packet: disconnect and wait for a new connection.
            ?LOG_INFO("Error decoding incoming data: ~p", [ Reason ]),
            {reply, {error, Reason}, force_disconnect(State)}
    end;
handle_call({incoming_data, _NewData, ConnectionPid}, _From, #state{ connection_pid = ExpectedPid } = State) ->
    ?LOG_DEBUG("MQTT session incoming data from ~p, expected from ~p", [ConnectionPid, ExpectedPid]),
    {reply, {error, wrong_connection}, State};
handle_call(Cmd, _From, State) ->
    {stop, {unknown_cmd, Cmd}, State}.
%% @doc Asynchronous requests: a CONNECT from a (new) connection, or a request to stop.
%% A rejected CONNECT disconnects the transport (see force_disconnect/1).
handle_cast({incoming_connect, Msg, Options}, State) ->
    case handle_incoming_with_context(Msg, Options, State) of
        {error, _Reason} ->
            {noreply, force_disconnect(State)};
        {ok, Connected} ->
            {noreply, Connected#state{ keep_alive_counter = 3, incoming_data = <<>> }}
    end;
handle_cast(kill, State) ->
    {stop, shutdown, State}.
%% @doc Messages from the router, the keep-alive timer and monitored processes.
handle_info({mqtt_msg, #{ type := publish } = MqttMsg}, State) ->
    % Routed publish for one of our subscriptions: forward it to the client.
    {noreply, relay_publish(MqttMsg, State)};
handle_info({keep_alive, Ref}, #state{ keep_alive_ref = Ref, keep_alive_counter = Counter } = State) ->
    case Counter of
        0 ->
            % No incoming data since the counter was last reset: give up on the connection.
            ?LOG_DEBUG("MQTT past keep_alive, disconnecting transport"),
            {noreply, force_disconnect(State)};
        _ ->
            erlang:send_after(State#state.keep_alive * 500, self(), {keep_alive, Ref}),
            {noreply, State#state{ keep_alive_counter = erlang:max(Counter - 1, 0) }}
    end;
handle_info({keep_alive, _StaleRef}, State) ->
    % Tick of a superseded keep-alive timer.
    {noreply, State};
handle_info({publish_job, undefined}, State) ->
    {noreply, State};
handle_info({publish_job, JobPid}, #state{ publish_jobs = Jobs } = State) when is_pid(JobPid) ->
    % Only jobs that are still running need a monitor.
    case erlang:is_process_alive(JobPid) of
        true ->
            MRef = erlang:monitor(process, JobPid),
            {noreply, State#state{ publish_jobs = Jobs#{ JobPid => MRef } }};
        false ->
            {noreply, State}
    end;
handle_info({'DOWN', _MRef, process, Pid, _Reason}, #state{ connection_pid = Pid } = State) ->
    {noreply, do_disconnected(State)};
handle_info({'DOWN', _MRef, process, Pid, _Reason}, #state{ will_pid = Pid } = State) ->
    % The will watchdog is gone: tell the client and stop the session.
    send_transport(#{ type => disconnect, reason_code => ?MQTT_RC_ERROR }, State),
    {stop, shutdown, State};
handle_info({'DOWN', _MRef, process, Pid, _Reason}, #state{ publish_jobs = Jobs } = State) ->
    % Finished publish job; removing an unknown pid leaves the map unchanged.
    {noreply, State#state{ publish_jobs = maps:remove(Pid, Jobs) }};
handle_info(Info, State) ->
    ?LOG_INFO("Unknown info message ~p", [Info]),
    {noreply, State}.
%% @doc No state conversion is needed between versions.
code_change(_OldVsn, SessionState, _Extra) ->
    {ok, SessionState}.
%% @doc Nothing to clean up on termination.
terminate(_Reason, _SessionState) ->
    ok.
% ---------------------------------------------------------------------------------------
% ----------------------------- support functions ---------------------------------------
% ---------------------------------------------------------------------------------------
%% @doc Decode and handle every complete packet in Data.
%% Returns {ok, {Rest, State}}, where Rest is an incomplete trailing packet to be kept for
%% the next call. Otherwise returns the first error from decoding or from handling a packet.
handle_incoming_data(<<>>, State) ->
    {ok, {<<>>, State}};
handle_incoming_data(Data, #state{ protocol_version = ProtocolVersion } = State) ->
    case mqtt_packet_map:decode(ProtocolVersion, Data) of
        {error, incomplete_packet} ->
            % TODO: Limit buffer size, disconnect if over max size
            {ok, {Data, State}};
        {error, _Reason} = DecodeError ->
            DecodeError;
        {ok, {Packet, Rest}} ->
            case handle_incoming_with_context(Packet, #{}, State) of
                {ok, State1} -> handle_incoming_data(Rest, State1);
                {error, _Reason} = HandleError -> HandleError
            end
    end.
%% @doc Let the runtime veto a packet before it is handled.
%% A rejected packet gives {error, invalid_message}; the callers then drop the connection.
handle_incoming_with_context(Msg, Options, #state{ runtime = Runtime, user_context = UserContext } = State) ->
    IsValid = Runtime:is_valid_message(Msg, Options, UserContext),
    case IsValid of
        false -> {error, invalid_message};
        true -> handle_incoming(Msg, Options, State)
    end.
%% @doc Dispatch a decoded packet to its handler.
%% CONNECT and AUTH are always handled. Every other packet needs a connection and an
%% accepted CONNECT (is_session_present = true).
%% Always returns {ok, State} or {error, Reason}: both callers (handle_incoming_data/2 and
%% handle_cast/2) only match those shapes and disconnect the transport on an error.
handle_incoming(#{ type := connect } = Msg, Options, State) ->
    % First connect, or a client reopening a connection to this session.
    % NOTE(review): credentials of a reconnect are not compared here; Runtime:connect/4
    % receives is_session_present and presumably does that check — confirm.
    packet_connect(Msg, Options, State);
handle_incoming(#{ type := auth } = Msg, _Options, State) ->
    packet_connect_auth(Msg, State);
handle_incoming(#{ type := Type }, _Options, #state{ connection_pid = undefined } = State) ->
    ?LOG_INFO("Dropping packet for MQTT session ~p ~s (~p) for receiving ~p when not connected.",
              [State#state.pool, State#state.client_id, self(), Type]),
    {error, not_connected};
handle_incoming(#{ type := Type }, _Options, #state{ is_session_present = false } = State) ->
    % Only AUTH and CONNECT are allowed before the CONNACK. The error makes the caller run
    % force_disconnect/1, which stops a session that has no is_session_present yet.
    ?LOG_INFO("Killing MQTT session ~p ~s (~p) for receiving ~p when no session started.",
              [State#state.pool, State#state.client_id, self(), Type]),
    {error, session_not_present};
handle_incoming(#{ type := publish } = Msg, _Options, State) ->
    packet_publish(Msg, State);
% PUBREL is for publish messages sent by the client
handle_incoming(#{ type := pubrel } = Msg, _Options, State) ->
    packet_pubrel(Msg, State);
% PUBREC, PUBACK, PUBCOMP is for publish messages sent by us to the client
handle_incoming(#{ type := pubrec } = Msg, _Options, State) ->
    packet_pubrec(Msg, State);
handle_incoming(#{ type := pubcomp } = Msg, _Options, State) ->
    packet_pubcomp(Msg, State);
handle_incoming(#{ type := puback } = Msg, _Options, State) ->
    packet_puback(Msg, State);
handle_incoming(#{ type := subscribe } = Msg, _Options, State) ->
    packet_subscribe(Msg, State);
handle_incoming(#{ type := unsubscribe } = Msg, _Options, State) ->
    packet_unsubscribe(Msg, State);
handle_incoming(#{ type := pingreq }, _Options, State) ->
    {ok, reply(#{ type => pingresp }, State)};
handle_incoming(#{ type := pingresp }, _Options, State) ->
    {ok, State};
handle_incoming(#{ type := disconnect } = Msg, _Options, State) ->
    packet_disconnect(Msg, State);
handle_incoming(#{ type := Type }, _Options, State) ->
    ?LOG_INFO("MQTT dropping unhandled packet with type ~p", [Type]),
    {ok, State}.
% ---------------------------------------------------------------------------------------
% --------------------------- message type functions ------------------------------------
% ---------------------------------------------------------------------------------------
%% @doc Handle the connect message. Either this is a re-connect or the first connect.
%% A CONNECT that would change the protocol version of the session, or that uses an
%% unsupported protocol, is answered with a refusing CONNACK and an error. Otherwise the
%% state to adopt on acceptance is built and the runtime is asked to authenticate.
packet_connect(#{ protocol_version := V, protocol_name := <<"MQTT">> }, Options, #state{ protocol_version = PV } = State)
        when is_integer(PV), V =/= PV ->
    % Do not change protocol versions for an existing session
    refuse_connect(?MQTT_RC_NOT_AUTHORIZED, Options, State),
    {error, protocol_version_changed};
packet_connect(#{ protocol_version := 5, protocol_name := <<"MQTT">>, properties := Props } = Msg, Options, State) ->
    % MQTT v5
    ExpiryInterval = case maps:get(session_expiry_interval, Props, none) of
        none -> ?SESSION_EXPIRY_DEFAULT;
        EI -> EI
    end,
    KeepAlive = maps:get(keep_alive, Msg, ?KEEP_ALIVE_DEFAULT),
    accept_connect(5, ExpiryInterval, KeepAlive, Msg, Options, State);
packet_connect(#{ protocol_version := 4, protocol_name := <<"MQTT">> } = Msg, Options, State) ->
    % MQTT v3.1.1 has no session expiry property; it is derived from the keep alive.
    KeepAlive = maps:get(keep_alive, Msg, ?KEEP_ALIVE_DEFAULT),
    accept_connect(4, KeepAlive * 3, KeepAlive, Msg, Options, State);
packet_connect(_ConnectMsg, Options, State) ->
    refuse_connect(?MQTT_RC_PROTOCOL_VERSION, Options, State),
    {error, protocol_version}.

%% Build the state to adopt if the runtime accepts the CONNECT, then ask the runtime.
accept_connect(ProtocolVersion, ExpiryInterval, KeepAlive, Msg, Options, State) ->
    StateIfAccept = State#state{
        protocol_version = ProtocolVersion,
        will = extract_will(Msg),
        session_expiry_interval = ExpiryInterval,
        keep_alive = KeepAlive,
        incoming_data = <<>>
    },
    handle_connect_auth(Msg, Options, set_connection(Options, StateIfAccept), State).

%% Send a refusing CONNACK over the connection in Options; the resulting state is discarded.
refuse_connect(ReasonCode, Options, State) ->
    ConnAck = #{ type => connack, reason_code => ReasonCode },
    _ = reply(ConnAck, set_connection(Options, State)),
    ok.
%% @doc Handle an AUTH packet: let the runtime (re)authenticate with it.
%% The current state is both the state to adopt on success and the fallback.
packet_connect_auth(Msg, #state{ runtime = Runtime, user_context = UserContext } = State) ->
    AuthResult = Runtime:reauth(Msg, UserContext),
    handle_connect_auth_1(AuthResult, Msg, State, State).
%% @doc Ask the runtime to authenticate a CONNECT.
%% StateIfAccept is the state to adopt on success; State is the state before the CONNECT.
handle_connect_auth(Msg, Options, StateIfAccept, #state{ runtime = Runtime, is_session_present = IsSessionPresent, user_context = UserContext } = State) ->
    ConnectResult = Runtime:connect(Msg, IsSessionPresent, Options, UserContext),
    handle_connect_auth_1(ConnectResult, Msg, StateIfAccept, State).
%% @doc Accept the new connection with the given ConnAck or Auth message.
%% If an Auth message is sent then further authentication handshakes are needed.
%% Only after a successful connack is the is_session_present flag set.
%%
%% Msg is the CONNECT packet, or the AUTH packet when called from packet_connect_auth/2.
%% An AUTH packet has no clean_start flag, so it is treated as false. A successful CONNACK
%% at the end of an AUTH exchange is therefore accepted, instead of failing the map match
%% and falling into the "refused" clause (which sent the success CONNACK and then
%% disconnected).
handle_connect_auth_1({ok, #{ type := connack, reason_code := ?MQTT_RC_SUCCESS } = ConnAck, UserContext1},
                      Msg, StateIfAccept, #state{ is_session_present = IsSessionPresent }) ->
    CleanStart = maps:get(clean_start, Msg, false),
    StateCleaned = maybe_clean_start(CleanStart, StateIfAccept),
    % When the runtime omitted session_present, report whether an existing session is
    % continued: there must be a session and the client must not request a clean start.
    ConnAck1 = case maps:find(session_present, ConnAck) of
        {ok, _} -> ConnAck;
        error ->
            ConnAck#{ session_present => IsSessionPresent andalso not (CleanStart =:= true) }
    end,
    State1 = StateCleaned#state{
        user_context = UserContext1,
        is_session_present = true,
        will = undefined
    },
    State2 = reply_connack(ConnAck1, State1),
    % Hand the will from the CONNECT to the will watchdog.
    mqtt_sessions_will:connected(State2#state.will_pid, StateIfAccept#state.will,
                                 State2#state.session_expiry_interval, State2#state.user_context),
    State3 = resend_unacknowledged( cleanup_pending(State2) ),
    {ok, State3};
handle_connect_auth_1({ok, #{ type := connack, reason_code := ReasonCode } = ConnAck, _UserContext1}, _Msg, StateIfAccept, _State) ->
    _ = reply_connack(ConnAck, StateIfAccept),
    ?LOG_DEBUG("MQTT connect/auth refused (~p): ~p", [ReasonCode, ConnAck]),
    {error, connection_refused};
handle_connect_auth_1({ok, #{ type := auth } = Auth, UserContext1}, _Msg, StateIfAccept, _State) ->
    % Authentication continues: send the AUTH challenge; no will is active yet.
    State1 = StateIfAccept#state{
        user_context = UserContext1
    },
    State2 = reply(Auth, State1),
    mqtt_sessions_will:connected(State2#state.will_pid, undefined,
                                 State2#state.session_expiry_interval, State2#state.user_context),
    {ok, State2};
handle_connect_auth_1({error, Reason}, Msg, _StateIfAccept, _State) ->
    ?LOG_INFO("MQTT connect/auth refused (~p): ~p", [Reason, Msg]),
    {error, connection_refused}.
%% @doc On a clean start, discard the existing session state: all subscriptions, pending
%% messages and the in-flight QoS 1/2 administration in both directions.
%% The in-flight maps must be cleared too. Otherwise resend_unacknowledged/1, which runs
%% right after a successful CONNACK, resends messages of the session the client asked to
%% discard.
maybe_clean_start(false, State) ->
    State;
maybe_clean_start(true, #state{ pool = Pool } = State) ->
    mqtt_sessions_router:unsubscribe_pid(Pool, self()),
    State#state{
        pending = queue:new(),
        awaiting_ack = #{},
        awaiting_rel = #{}
    }.
%% @doc Handle a PUBLISH packet received from the client.
%% Always returns {ok, State}, the shape handle_incoming_data/2 and handle_cast/2 match.
%% Previously the "packet id in use" and duplicate branches for QoS 1 and 2 returned a
%% bare #state{}, which crashed the session with a case_clause.
%% - QoS 0: topics [<<"$client">>, OwnClientId | _] are control messages for the runtime;
%%   anything else is routed when the runtime allows it.
%% - QoS 1: answered with a PUBACK.
%% - QoS 2: answered with a PUBREC; on success the packet id is kept in awaiting_rel until
%%   the client's PUBREL.
packet_publish(#{ topic := Topic, qos := 0 } = Msg,
               #state{ runtime = Runtime, user_context = UCtx, client_id = ClientId } = State) ->
    case Topic of
        [ <<"$client">>, ClientId | Rest ] ->
            MsgPub = mqtt_sessions_payload:decode(Msg#{ dup => false }),
            {ok, UCtx1} = Runtime:control_message(Rest, MsgPub, UCtx),
            {ok, State#state{ user_context = UCtx1 }};
        _ ->
            _ = publish_if_allowed(Topic, Msg, State),
            {ok, State}
    end;
packet_publish(#{ topic := Topic, qos := 1, dup := Dup, packet_id := PacketId } = Msg,
               #state{ awaiting_rel = WaitRel } = State) ->
    RC = case maps:find(PacketId, WaitRel) of
        {ok, _} when not Dup ->
            % There is a qos 2 level message with the same packet id
            ?MQTT_RC_PACKET_ID_IN_USE;
        {ok, {pubrel, PrevRC, _}} when Dup ->
            % There is a qos 2 level message with the same packet id,
            % but the received message is a duplicate: just ack.
            PrevRC;
        error ->
            publish_if_allowed(Topic, Msg, State)
    end,
    PubAck = #{
        type => puback,
        packet_id => PacketId,
        reason_code => RC
    },
    {ok, reply(PubAck, State)};
packet_publish(#{ topic := Topic, qos := 2, dup := Dup, packet_id := PacketId } = Msg,
               #state{ awaiting_rel = WaitRel } = State) ->
    {RC, State1} = case maps:find(PacketId, WaitRel) of
        {ok, _} when not Dup ->
            {?MQTT_RC_PACKET_ID_IN_USE, State};
        {ok, {pubrel, PrevRC, _}} when Dup ->
            % Retransmission of a message we already received: repeat the earlier answer.
            {PrevRC, State};
        error ->
            NewRC = publish_if_allowed(Topic, Msg, State),
            % Only successfully received messages wait for the client's PUBREL.
            WaitRel1 = if
                NewRC < 16#80 ->
                    WaitRel#{ PacketId => {pubrel, NewRC, mqtt_sessions_timestamp:timestamp()} };
                true ->
                    WaitRel
            end,
            {NewRC, State#state{ awaiting_rel = WaitRel1 }}
    end,
    PubRec = #{
        type => pubrec,
        packet_id => PacketId,
        reason_code => RC
    },
    {ok, reply(PubRec, State1)}.

%% Route Msg to Topic if the runtime allows it; returns the reason code for the ack.
publish_if_allowed(Topic, Msg, #state{ runtime = Runtime, user_context = UCtx, pool = Pool }) ->
    case Runtime:is_allowed(publish, Topic, Msg, UCtx) of
        true ->
            MsgPub = mqtt_sessions_payload:decode(Msg#{ dup => false }),
            {ok, JobPid} = mqtt_sessions_router:publish(Pool, Topic, MsgPub, UCtx),
            % handle_info/2 starts monitoring the job once we are back in the server loop.
            self() ! {publish_job, JobPid},
            ?MQTT_RC_SUCCESS;
        false ->
            ?MQTT_RC_NOT_AUTHORIZED
    end.
%% @doc Handle a PUBREL from the client: the release step of a QoS 2 publish by the client.
%% With a success reason code a PUBCOMP is sent and the awaiting_rel entry is dropped. The
%% PUBCOMP carries PACKET_ID_NOT_FOUND when the packet id is unknown. A PUBREL with any
%% other reason code only drops the awaiting_rel entry.
packet_pubrel(#{ packet_id := PacketId, reason_code := ?MQTT_RC_SUCCESS }, #state{ awaiting_rel = WaitRel } = State) ->
    {CompRC, WaitRel1} = case maps:find(PacketId, WaitRel) of
        {ok, {pubrel, _RC, _Tm}} ->
            {?MQTT_RC_SUCCESS, maps:remove(PacketId, WaitRel)};
        error ->
            {?MQTT_RC_PACKET_ID_NOT_FOUND, WaitRel}
    end,
    PubComp = #{
        type => pubcomp,
        packet_id => PacketId,
        reason_code => CompRC
    },
    State1 = reply(PubComp, State),
    {ok, State1#state{ awaiting_rel = WaitRel1 }};
packet_pubrel(#{ packet_id := PacketId, reason_code := RC }, #state{ awaiting_rel = WaitRel } = State) ->
    % Error server/client out of sync - remove the wait-rel for this packet_id
    ?LOG_INFO("PUBREL with reason ~p for packet ~p",
              [ RC, PacketId ]),
    {ok, State#state{ awaiting_rel = maps:remove(PacketId, WaitRel) }}.
%% @doc Handle puback for QoS 1 publish messages sent to the client.
%% The awaiting_ack entry is removed, with a warning if it was not waiting for a PUBACK.
%% Unknown packet ids are ignored.
packet_puback(#{ packet_id := PacketId }, #state{ awaiting_ack = WaitAck } = State) ->
    case maps:find(PacketId, WaitAck) of
        error ->
            {ok, State};
        {ok, {_MsgNr, Wait, Msg}} ->
            case Wait of
                puback ->
                    ok;
                _ ->
                    ?LOG_WARNING("PUBACK for message ~p waiting for ~p. Message: ~p",
                                 [ PacketId, Wait, Msg ])
            end,
            {ok, State#state{ awaiting_ack = maps:remove(PacketId, WaitAck) }}
    end.
%% @doc Handle pubrec for QoS 2 publish messages sent to the client.
%% An error reason code (>= 16#80) ends the exchange: the awaiting_ack entry is removed.
%% Otherwise a PUBREL is sent and the entry advances from pubrec to pubcomp. For an unknown
%% packet id, or one waiting for something else, the PUBREL carries PACKET_ID_NOT_FOUND.
packet_pubrec(#{ packet_id := PacketId, reason_code := RC }, #state{ awaiting_ack = WaitAck } = State) when RC >= 16#80 ->
    case maps:find(PacketId, WaitAck) of
        error ->
            {ok, State};
        {ok, {_MsgNr, Wait, Msg}} ->
            case lists:member(Wait, [pubrec, pubcomp]) of
                true ->
                    ok;
                false ->
                    ?LOG_WARNING("PUBREC for message ~p waiting for ~p. Message: ~p",
                                 [ PacketId, Wait, Msg ])
            end,
            {ok, State#state{ awaiting_ack = maps:remove(PacketId, WaitAck) }}
    end;
packet_pubrec(#{ packet_id := PacketId }, #state{ awaiting_ack = WaitAck } = State) ->
    {WaitAck1, RelRC} = case maps:find(PacketId, WaitAck) of
        {ok, {MsgNr, pubrec, _Msg}} ->
            % Received by the client; the message body is no longer needed, await PUBCOMP.
            {WaitAck#{ PacketId => {MsgNr, pubcomp, undefined} }, ?MQTT_RC_SUCCESS};
        {ok, {_MsgNr, pubcomp, _Msg}} ->
            % Already waiting for PUBCOMP: just send the PUBREL again.
            {WaitAck, ?MQTT_RC_SUCCESS};
        {ok, {_MsgNr, Wait, Msg}} ->
            ?LOG_WARNING("PUBREC for message ~p waiting for ~p. Message: ~p",
                         [ PacketId, Wait, Msg ]),
            {maps:remove(PacketId, WaitAck), ?MQTT_RC_PACKET_ID_NOT_FOUND};
        error ->
            {WaitAck, ?MQTT_RC_PACKET_ID_NOT_FOUND}
    end,
    PubRel = #{
        type => pubrel,
        packet_id => PacketId,
        reason_code => RelRC
    },
    {ok, reply(PubRel, State#state{ awaiting_ack = WaitAck1 })}.
%% @doc Handle pubcomp for QoS 2 publish messages sent to the client: the last step.
%% The awaiting_ack entry is removed, with a warning if it was not waiting for a PUBCOMP.
%% Unknown packet ids are ignored. The warning names PUBCOMP (the original said PUBREC,
%% a copy-paste error that made the log misleading).
packet_pubcomp(#{ packet_id := PacketId }, #state{ awaiting_ack = WaitAck } = State) ->
    WaitAck1 = case maps:find(PacketId, WaitAck) of
        {ok, {_MsgNr, pubcomp, _Msg}} ->
            maps:remove(PacketId, WaitAck);
        {ok, {_MsgNr, Wait, Msg}} ->
            ?LOG_WARNING("PUBCOMP for message ~p waiting for ~p. Message: ~p",
                         [ PacketId, Wait, Msg ]),
            maps:remove(PacketId, WaitAck);
        error ->
            WaitAck
    end,
    {ok, State#state{ awaiting_ack = WaitAck1 }}.
%% @doc Handle a SUBSCRIBE request.
%% Every topic filter is validated, authorized and registered with the router. The SUBACK
%% holds one result per requested filter, in request order: {ok, GrantedQoS} or
%% {error, ReasonCode}.
packet_subscribe(#{ topics := Topics } = Msg, #state{} = State) ->
    Acks = [ subscribe_topic(Sub, Msg, State) || Sub <- Topics ],
    SubAck = #{
        type => suback,
        packet_id => maps:get(packet_id, Msg, 0),
        acks => Acks
    },
    {ok, reply(SubAck, State)}.

%% Subscribe this session to one requested topic filter; returns the SUBACK entry for it.
subscribe_topic(#{ topic := TopicFilter0 } = Sub, Msg, #state{ runtime = Runtime, user_context = UCtx, pool = Pool }) ->
    case mqtt_packet_map_topic:validate_topic(TopicFilter0) of
        {error, _} ->
            {error, ?MQTT_RC_TOPIC_FILTER_INVALID};
        {ok, TopicFilter} ->
            case Runtime:is_allowed(subscribe, TopicFilter, Msg, UCtx) of
                false ->
                    {error, ?MQTT_RC_NOT_AUTHORIZED};
                true ->
                    QoS = maps:get(qos, Sub, 0),
                    SubOptions = maps:remove(topic, Sub#{
                        qos => QoS,
                        no_local => maps:get(no_local, Sub, false)
                    }),
                    case mqtt_sessions_router:subscribe(Pool, TopicFilter, self(), self(), SubOptions, UCtx) of
                        ok -> {ok, QoS};
                        {error, _} -> {error, ?MQTT_RC_ERROR}
                    end
            end
    end.
%% @doc Handle an UNSUBSCRIBE request.
%% The UNSUBACK reports, per topic filter, whether a subscription was found.
packet_unsubscribe(#{ topics := Topics } = Msg, #state{ pool = Pool } = State) ->
    Acks = [
        case mqtt_sessions_router:unsubscribe(Pool, TopicFilter, self()) of
            ok -> {ok, found};
            {error, notfound} -> {ok, notfound}
        end
        || TopicFilter <- Topics
    ],
    UnsubAck = #{
        type => unsuback,
        packet_id => maps:get(packet_id, Msg, 0),
        acks => Acks
    },
    {ok, reply(UnsubAck, State)}.
%% @doc Handle a DISCONNECT from the client.
%% The will watchdog is told to send the will unless the reason code is success. The
%% client may change the session expiry, unless it was 0. A resulting expiry of 0 stops
%% the session after the transport is disconnected.
packet_disconnect(#{ reason_code := RC, properties := Props },
                  #state{ will_pid = WillPid, session_expiry_interval = CurrentExpiry } = State) ->
    NewExpiry = if
        CurrentExpiry =:= 0 ->
            % TODO: If the props.session_expiry_interval > 0 then send disconnect with MQTT_RC_PROTOCOL_ERROR
            CurrentExpiry;
        true ->
            maps:get(session_expiry_interval, Props, CurrentExpiry)
    end,
    IsSendWill = RC =/= ?MQTT_RC_SUCCESS,
    mqtt_sessions_will:disconnected(WillPid, IsSendWill, NewExpiry),
    Disconnected = force_disconnect(State),
    if
        NewExpiry =:= 0 -> gen_server:cast(self(), kill);
        true -> ok
    end,
    {ok, Disconnected}.
% ---------------------------------------------------------------------------------------
% --------------------------- relay publish to client -----------------------------------
% ---------------------------------------------------------------------------------------
%% @doc Forward a routed publish to the client.
%% The effective QoS is the lower of the message's qos and the envelope's qos (presumably
%% the subscription QoS); both default to 0. QoS 1/2 messages get a fresh packet id and are
%% recorded in awaiting_ack until the client acknowledges them.
relay_publish(#{ type := publish, message := Msg } = MqttMsg, State) ->
    QoS = erlang:min( maps:get(qos, Msg, 0), maps:get(qos, MqttMsg, 0) ),
    Outgoing = mqtt_sessions_payload:encode(Msg#{ qos => QoS, dup => false }),
    {State1, Packet} = track_relayed(QoS, Outgoing, State),
    reply(Packet, State1).

%% Assign the packet id; for QoS > 0 also register the acknowledgement we expect.
track_relayed(0, Outgoing, State) ->
    {State, Outgoing#{ packet_id => 0 }};
track_relayed(QoS, Outgoing, State) ->
    #state{ packet_id = PacketId } = State1 = inc_packet_id(State),
    #state{ msg_nr = MsgNr, awaiting_ack = WaitAck } = State2 = inc_msg_nr(State1),
    ExpectedAck = case QoS of
        1 -> puback;
        2 -> pubrec
    end,
    Packet = Outgoing#{ packet_id => PacketId },
    {State2#state{ awaiting_ack = WaitAck#{ PacketId => {MsgNr, ExpectedAck, Packet} } }, Packet}.
% ---------------------------------------------------------------------------------------
% ------------------------------- queue functions ---------------------------------------
% ---------------------------------------------------------------------------------------
%% @doc Keep only QoS 0 publish messages in the pending queue; drop everything else.
%% The queue holds #queued{} records (see queue_1/2), so the filter matches on the record.
%% The original filter matched raw packet maps, never matched, and so emptied the queue
%% every time, losing QoS 0 messages queued while disconnected.
%% NOTE(review): QoS 1/2 publishes are presumably dropped here because they stay in
%% awaiting_ack and are resent by resend_unacknowledged/1 on reconnect — confirm.
cleanup_pending(#state{ pending = Pending } = State) ->
    IsQoS0Publish = fun
        (#queued{ type = publish, qos = 0 }) -> true;
        (_) -> false
    end,
    State#state{ pending = queue:filter(IsQoS0Publish, Pending) }.
%% @doc Resend all in-flight traffic the client has not acknowledged, in original send
%% order (by msg_nr), using the original packet ids:
%% - awaiting puback (QoS 1) or pubrec (QoS 2): resend the PUBLISH with dup set;
%% - awaiting pubcomp (QoS 2, our PUBREL was sent): resend the PUBREL.
%% The original skipped QoS 1 messages entirely and sent a PUBREC instead of a PUBREL for
%% messages awaiting a PUBCOMP. The server, as sender, sends PUBREL; see packet_pubrec/2.
resend_unacknowledged(#state{ awaiting_ack = AwaitAck } = State) ->
    Msgs = maps:fold(
        fun
            (_PacketId, {MsgNr, puback, Msg}, Acc) ->
                [ {MsgNr, Msg#{ dup => true }} | Acc ];
            (_PacketId, {MsgNr, pubrec, Msg}, Acc) ->
                [ {MsgNr, Msg#{ dup => true }} | Acc ];
            (PacketId, {MsgNr, pubcomp, _Msg}, Acc) ->
                PubRel = #{
                    type => pubrel,
                    packet_id => PacketId,
                    reason_code => ?MQTT_RC_SUCCESS
                },
                [ {MsgNr, PubRel} | Acc ];
            (_PacketId, {MsgNr, suback, Msg}, Acc) ->
                [ {MsgNr, Msg} | Acc ];
            (_PacketId, {MsgNr, unsuback, Msg}, Acc) ->
                [ {MsgNr, Msg} | Acc ];
            (_PacketId, _, Acc) ->
                Acc
        end,
        [],
        AwaitAck),
    lists:foldl(
        fun({_Nr, Msg}, StateAcc) ->
            reply(Msg, StateAcc)
        end,
        State,
        lists:sort(Msgs)).
% ---------------------------------------------------------------------------------------
% -------------------------------- misc functions ---------------------------------------
% ---------------------------------------------------------------------------------------
%% @doc Called when the connection disconnects or crashes/stops.
%% Notifies the will watchdog and forgets the connection.
do_disconnected(State) ->
    mqtt_sessions_will:disconnected(State#state.will_pid),
    cleanup_state_disconnected(State).
%% @doc Forget the connection, its transport and the per-connection state: the pending
%% CONNACK/AUTH and the QoS 2 awaiting_rel administration. Then trim the pending queue
%% with cleanup_pending/1.
%% @todo Cleanup pending messages and awaiting states.
cleanup_state_disconnected(State) ->
    Detached = State#state{
        connection_pid = undefined,
        transport = undefined,
        pending_connack = undefined,
        awaiting_rel = #{}
    },
    cleanup_pending(Detached).
%% @doc Send a connack to the remote.
%% A successful CONNACK gets this session's expiry, keep alive, client id and routing id
%% as properties, overriding same-named properties from the runtime. Other CONNACKs are
%% sent as they are.
reply_connack(#{ type := connack, reason_code := ?MQTT_RC_SUCCESS } = ConnAck, State) ->
    SessionProps = #{
        session_expiry_interval => State#state.session_expiry_interval,
        server_keep_alive => State#state.keep_alive,
        assigned_client_identifier => State#state.client_id,
        subscription_identifier_available => false,
        shared_subscription_available => false,
        <<"cotonic-routing-id">> => State#state.routing_id
    },
    Props = maps:merge(maps:get(properties, ConnAck, #{}), SessionProps),
    reply(ConnAck#{ properties => Props }, State);
reply_connack(#{ type := connack } = ConnAck, State) ->
    reply(ConnAck, State).
%% @doc Check the connect packet and extract the will as a map for the will watchdog.
%% Returns #{} when the CONNECT has no will. MQTT v3.1.1 wills have no expiry interval
%% (0); for MQTT v5 it is read from the CONNECT properties (will_expiry_interval, default 0).
extract_will(#{ type := connect, will_flag := false }) ->
    #{};
extract_will(#{ protocol_version := 4, type := connect, will_flag := true } = Msg) ->
    will_map(0, Msg);
extract_will(#{ type := connect, will_flag := true, properties := Props } = Msg) ->
    will_map(maps:get(will_expiry_interval, Props, 0), Msg).

%% Collect the will fields of a CONNECT packet; will_topic is required.
will_map(ExpiryInterval, Msg) ->
    #{
        expiry_interval => ExpiryInterval,
        topic => maps:get(will_topic, Msg),
        payload => maps:get(will_payload, Msg, <<>>),
        properties => maps:get(will_properties, Msg, #{}),
        qos => maps:get(will_qos, Msg, 0),
        retain => maps:get(will_retain, Msg, false)
    }.
%% @doc Disconnect the transport and forget the connection.
%% A session that never received a successful CONNACK is stopped as well.
force_disconnect(State) ->
    Cleaned = cleanup_state_disconnected(disconnect_transport(State)),
    case Cleaned#state.is_session_present of
        true -> ok;
        false -> gen_server:cast(self(), kill)
    end,
    Cleaned.
%% @doc Tell the attached transport to disconnect and detach it.
%% A pid transport gets a {mqtt_transport, self(), disconnect} message; a fun transport is
%% called with disconnect.
disconnect_transport(#state{ transport = undefined } = State) ->
    State;
disconnect_transport(#state{ transport = Transport } = State) ->
    if
        is_pid(Transport) -> Transport ! {mqtt_transport, self(), disconnect};
        is_function(Transport) -> Transport(disconnect)
    end,
    State#state{ transport = undefined }.
%% @doc Send a packet to the client, or queue it when that is not possible.
%% Nothing is sent for undefined. When sending fails, the transport is detached and the
%% packet queued.
reply(undefined, State) ->
    State;
reply(Msg, #state{ transport = undefined } = State) ->
    queue(Msg, State);
reply(Msg, State) ->
    case send_transport(Msg, State) of
        {error, _} -> queue(Msg, State#state{ transport = undefined });
        ok -> State
    end.
%% @doc Deliver a packet (a map, encoded first) or encoded data to the transport.
%% Returns ok, or for a fun transport whatever the fun returns.
%% NOTE(review): data for a transport pid that is no longer alive is silently dropped
%% (still ok), so reply/2 does not queue it.
send_transport(_Msg, #state{ transport = undefined }) ->
    ok;
send_transport(Msg, #state{ protocol_version = ProtocolVersion } = State) when is_map(Msg) ->
    send_transport(encode(ProtocolVersion, Msg), State);
send_transport(Data, #state{ transport = TransportPid }) when is_pid(TransportPid) ->
    case erlang:is_process_alive(TransportPid) of
        true -> TransportPid ! {mqtt_transport, self(), Data};
        false -> ignore
    end,
    ok;
send_transport(Data, #state{ transport = TransportFun }) when is_function(TransportFun) ->
    TransportFun(Data).
%% @doc Queue a packet that cannot be delivered right now.
%% A CONNACK or AUTH is kept apart, replacing an earlier one, and is handed out first by
%% fetch_queue. Every other packet is numbered and appended to the pending queue.
queue(#{ type := Type } = Msg, State) when Type =:= connack; Type =:= auth ->
    State#state{ pending_connack = Msg };
queue(Msg, State) ->
    queue_1(Msg, inc_msg_nr(State)).
%% @doc Append a packet to the pending queue, purging it first if it is too long.
%% Records the type, queue time, expiry (message_expiry_interval property or the default)
%% and QoS (1 when the packet has none).
queue_1(#{ type := Type } = Msg, #state{ msg_nr = MsgNr, pending = Pending } = State) ->
    ExpiryInterval = maps:get(message_expiry_interval, maps:get(properties, Msg, #{}), ?MESSAGE_EXPIRY_DEFAULT),
    Now = mqtt_sessions_timestamp:timestamp(),
    Item = #queued{
        msg_nr = MsgNr,
        type = Type,
        queued = Now,
        expiry = Now + ExpiryInterval,
        qos = maps:get(qos, Msg, 1),
        message = Msg
    },
    Trimmed = maybe_purge(Pending),
    State#state{ pending = queue:in(Item, Trimmed) }.
%% @doc Purge the queue only when it holds more than ?QUEUE_PURGE_LEN items.
maybe_purge(Queue) ->
    IsTooLong = queue:len(Queue) > ?QUEUE_PURGE_LEN,
    if
        IsTooLong -> purge(Queue);
        true -> Queue
    end.
%% @doc Drop expired messages from a non-empty queue.
%% QoS 0 messages are also dropped once they are older than half the time span between
%% the oldest and the newest queued message.
purge(Queue) ->
    {value, #queued{ queued = Oldest }} = queue:peek(Queue),
    {value, #queued{ queued = Newest }} = queue:peek_r(Queue),
    Now = mqtt_sessions_timestamp:timestamp(),
    MaxQoS0Age = (Newest - Oldest) / 2,
    IsKept = fun
        (#queued{ expiry = Expiry }) when Now >= Expiry -> false;
        (#queued{ qos = 0, queued = Queued }) -> Now < (Queued + MaxQoS0Age);
        (#queued{}) -> true
    end,
    queue:filter(IsKept, Queue).
%% @doc Encode one packet, or a list of packets concatenated, into a binary.
-spec encode( mqtt_packet_map:mqtt_version(), mqtt_packet_map:mqtt_packet() | list( mqtt_packet_map:mqtt_packet() )) -> binary().
encode(ProtocolVersion, Packet) when is_map(Packet) ->
    {ok, Encoded} = mqtt_packet_map:encode(ProtocolVersion, Packet),
    Encoded;
encode(ProtocolVersion, Packets) when is_list(Packets) ->
    iolist_to_binary(lists:map(fun(Packet) -> encode(ProtocolVersion, Packet) end, Packets)).
%% @doc Set the new connection, disconnect existing transport.
%% Options must hold connection_pid and transport. Re-announcing the current connection
%% changes nothing. A different previous connection is told to disconnect.
set_connection(#{ connection_pid := ConnectionPid, transport := Transport }, State) ->
    case State#state.connection_pid of
        ConnectionPid ->
            State;
        Previous ->
            if
                Previous =:= undefined -> ok;
                true -> Previous ! {mqtt_transport, self(), disconnect}
            end,
            set_connection_1(ConnectionPid, Transport, State)
    end.
%% @doc Monitor and adopt a new connection and transport, and restart the keep-alive.
set_connection_1(ConnectionPid, Transport, State) ->
    _MRef = erlang:monitor(process, ConnectionPid),
    Connected = State#state{ connection_pid = ConnectionPid, transport = Transport },
    start_keep_alive(Connected).
%% @doc (Re)start the keep-alive ticker for the current keep_alive (seconds).
%% The timer fires every keep_alive * 500 ms; each tick decrements keep_alive_counter and
%% incoming data resets it to 3.
%% A keep_alive of 0 disables the check. The reference of any running ticker is cleared so
%% handle_info/2 ignores its next tick. Keeping the old reference would make that tick
%% reschedule itself with 0 * 500 ms, draining the counter at once and disconnecting a
%% client that asked for no keep-alive.
start_keep_alive(#state{ keep_alive = 0 } = State) ->
    State#state{ keep_alive_ref = undefined };
start_keep_alive(#state{ keep_alive = N } = State) ->
    Ref = erlang:make_ref(),
    erlang:send_after(N * 500, self(), {keep_alive, Ref}),
    State#state{ keep_alive_counter = 3, keep_alive_ref = Ref }.
%% @doc Increment the message number; it orders queued and resent messages.
inc_msg_nr(#state{ msg_nr = Nr } = State) ->
    NextNr = Nr + 1,
    State#state{ msg_nr = NextNr }.
%% @doc Advance to the next packet id that is not in awaiting_ack, wrapping from
%% ?MAX_PACKET_ID back to 1.
%% NOTE(review): this never terminates if every id 1..?MAX_PACKET_ID is awaiting an ack.
inc_packet_id(#state{ packet_id = PacketId, awaiting_ack = Acks } = State) ->
    State#state{ packet_id = next_free_packet_id(PacketId, Acks) }.

%% Next candidate after PacketId that is not a key of Acks.
next_free_packet_id(PacketId, Acks) ->
    Candidate = case PacketId >= ?MAX_PACKET_ID of
        true -> 1;
        false -> PacketId + 1
    end,
    case maps:is_key(Candidate, Acks) of
        true -> next_free_packet_id(Candidate, Acks);
        false -> Candidate
    end.