%% @author Marc Worrell <marc@worrell.nl>
%% @copyright 2018-2024 Marc Worrell
%% @doc Process handling one single MQTT session.
%% MQTT connections attach and detach from this session. Buffers outgoing
%% messages if there is no connection attached.
%% @end
%% Copyright 2018-2024 Marc Worrell
%%
%% Licensed under the Apache License, Version 2.0 (the "License");
%% you may not use this file except in compliance with the License.
%% You may obtain a copy of the License at
%%
%% http://www.apache.org/licenses/LICENSE-2.0
%%
%% Unless required by applicable law or agreed to in writing, software
%% distributed under the License is distributed on an "AS IS" BASIS,
%% WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
%% See the License for the specific language governing permissions and
%% limitations under the License.
%% TODO: Limit in-flight acks (both ways)
%% TODO: Refuse incoming publish messages if too many publish_jobs
%% TODO: Limit incoming_data buffer size
%% Cleanup awaiting_rel for too old waiting
%% Refactor use of awaiting_ack to not use buffer
% MQTTv5 spec http://docs.oasis-open.org/mqtt/mqtt/v5.0/cos01/mqtt-v5.0-cos01.html
% MQTTv3.1.1 spec http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html
-module(mqtt_sessions_process).
-behaviour(gen_server).
%% Public API: user-context access, transport lookup, session termination,
%% entry points for packets/data arriving from a connection, and start_link.
-export([
get_user_context/1,
set_user_context/2,
update_user_context/2,
get_transport/1,
kill/1,
incoming_connect/3,
incoming_data/2,
start_link/3
]).
%% gen_server callbacks.
-export([
init/1,
handle_call/3,
handle_cast/2,
handle_info/2,
code_change/3,
terminate/2
]).
-define(MAX_PACKET_ID, 65535). % Upper bound of the packet_id() type below
-define(RECEIVE_MAXIMUM, 65535). % Initial value of #state.send_quota
-define(KEEP_ALIVE_DEFAULT, 30). % Default keep alive in seconds
-define(SESSION_EXPIRY, 900). % Default session expiration (15 minutes)
-define(SESSION_EXPIRY_MAX, 3600). % Maximum allowed session expiration (1 hour)
-define(MESSAGE_EXPIRY_DEFAULT, 3600). % NOTE(review): presumably seconds, like the other expiry values - confirm at point of use
-define(ACK_EXPIRY, 600). % NOTE(review): presumably seconds - confirm at point of use
-define(MAX_BUFFERED, 500). % Max buffered QoS 0 messages
-define(MAX_INFLIGHT_ACK, 500). % Max in-flight QoS 1/2 messages
-define(KILL_TIMEOUT, 5000). % Milliseconds kill/1 waits for a normal stop before forcing an exit
-type packet_id() :: 0..65535. % ?MAX_PACKET_ID
%% Entry of the #state.buffer map (keyed by message number): an outgoing
%% message held in the session buffer.
-record(queued, {
msg_nr :: pos_integer(), % Sequence number, also the key in #state.buffer
type :: atom(),
packet_id = undefined :: undefined | packet_id(),
queued :: non_neg_integer(), % NOTE(review): presumably the timestamp of queueing - confirm
expiry :: non_neg_integer(), % NOTE(review): presumably the timestamp after which the message is dropped - confirm
qos = 0 :: 0..2, % Entries with qos 0 are removed when the connection is lost
message :: mqtt_packet_map:mqtt_packet()
}).
%% Entry of the #state.awaiting_ack map (keyed by packet id): bookkeeping for
%% a message relayed to the client for which an acknowledgement is pending.
-record(wait_for, {
msg_nr :: pos_integer(), % Sequence number, used to order resends after a reconnect
type :: atom(), % The acknowledgement packet type we wait for (e.g. puback, pubrec, pubcomp)
message = undefined :: undefined | mqtt_packet_map:mqtt_packet(), % The packet to resend, if any
is_sent = true :: boolean(), % false if there was no transport when the message was accepted
queued :: non_neg_integer() % Value of mqtt_sessions_timestamp:timestamp() when the entry was made
}).
%% State of a single MQTT session process.
-record(state, {
protocol_version :: mqtt_packet_map:mqtt_version(),
pool :: atom(),
% Callback module, read from the `runtime' env of the mqtt_sessions application
runtime :: atom(),
client_id :: binary(),
routing_id :: binary(),
% Opaque term created and updated by the runtime callback module
user_context :: term(),
transport = undefined :: mqtt_sessions:transport() | undefined,
% The connection process whose incoming data is accepted; monitored for 'DOWN'
connection_pid = undefined :: pid() | undefined,
% Set to true after the first successful connack
is_session_present = false :: boolean(),
is_connected = false :: boolean(),
buffer = #{} :: #{ non_neg_integer() => #queued{} },
packet_id = 1 :: packet_id(),
send_quota = ?RECEIVE_MAXIMUM :: non_neg_integer(),
awaiting_ack = #{} :: #{ non_neg_integer() => #wait_for{} }, % Initiated by server
awaiting_rel = #{} :: map(), % Initiated by client
will = undefined :: undefined | map(),
will_pid = undefined :: undefined | pid(),
msg_nr = 0 :: non_neg_integer(), % Incremental counter to keep the buffer in sequence
keep_alive = ?KEEP_ALIVE_DEFAULT :: non_neg_integer(),
% Decremented on every keep_alive tick, reset to 3 on incoming data; the
% transport is disconnected when a tick finds it at 0
keep_alive_counter = 3 :: integer(),
keep_alive_ref :: undefined | reference(),
session_expiry_interval = ?SESSION_EXPIRY :: non_neg_integer(),
% Number of times we had a successful connect to this session
connect_count = 0 :: non_neg_integer(),
% Buffering incoming data for a complete packet
incoming_data = <<>> :: binary(),
% Tracking publish jobs
publish_jobs = #{} :: map()
}).
-include_lib("kernel/include/logger.hrl").
-include_lib("mqtt_packet_map/include/mqtt_packet_map.hrl").
-include_lib("../include/mqtt_sessions.hrl").
%% @doc Fetch the user context held by the session process.
%% Returns `{error, noproc}' if the session process is not running.
-spec get_user_context( pid() ) -> {ok, term()} | {error, noproc}.
get_user_context(Pid) ->
    try gen_server:call(Pid, get_user_context, infinity) of
        Reply -> Reply
    catch
        exit:{noproc, _CallInfo} -> {error, noproc}
    end.
%% @doc Replace the user context held by the session process.
%% Returns `{error, noproc}' if the session process is not running.
-spec set_user_context( pid(), term() ) -> ok | {error, noproc}.
set_user_context(Pid, UserContext) ->
    Request = {set_user_context, UserContext},
    try gen_server:call(Pid, Request, infinity) of
        Reply -> Reply
    catch
        exit:{noproc, _CallInfo} -> {error, noproc}
    end.
%% @doc Update the user context by applying Fun to it. Fun runs inside the
%% session process. Returns `{error, noproc}' if that process is not running.
-spec update_user_context( pid(), fun( (term()) -> term() ) ) -> ok | {error, noproc}.
update_user_context(Pid, Fun) ->
    Request = {update_user_context, Fun},
    try gen_server:call(Pid, Request, infinity) of
        Reply -> Reply
    catch
        exit:{noproc, _CallInfo} -> {error, noproc}
    end.
%% @doc Fetch the transport currently attached to the session.
%% Returns `{error, notransport}' if no transport is attached and
%% `{error, noproc}' if the session process is not running.
-spec get_transport( pid() ) -> {ok, mqtt_sessions:transport()} | {error, notransport | noproc}.
get_transport(Pid) ->
    try gen_server:call(Pid, get_transport, infinity) of
        Reply -> Reply
    catch
        exit:{noproc, _CallInfo} -> {error, noproc}
    end.
%% @doc Stop the session process and wait until it is down. First asks the
%% process to stop itself; if it is still alive after ?KILL_TIMEOUT
%% milliseconds it is killed with an untrappable exit signal.
-spec kill( pid() ) -> ok.
kill(Pid) when is_pid(Pid) ->
    Monitor = erlang:monitor(process, Pid),
    gen_server:cast(Pid, kill),
    case kill_await_down(Monitor, Pid, ?KILL_TIMEOUT) of
        ok ->
            ok;
        timeout ->
            erlang:exit(Pid, kill),
            % After the forced exit we wait without a timeout.
            kill_await_down(Monitor, Pid, infinity)
    end.

%% Wait for the 'DOWN' message of the monitored process, at most Timeout ms.
kill_await_down(Monitor, Pid, Timeout) ->
    receive
        {'DOWN', Monitor, process, Pid, _Reason} -> ok
    after Timeout ->
        timeout
    end.
%% @doc Hand a decoded CONNECT packet and its options to the session process.
%% Asynchronous (gen_server cast): returns `ok' without waiting for the session
%% to process the packet.
-spec incoming_connect(pid(), mqtt_packet_map:mqtt_packet(), mqtt_sessions:msg_options()) -> ok.
incoming_connect(Pid, Msg, Options) when is_map(Options) ->
gen_server:cast(Pid, {incoming_connect, Msg, Options}).
%% @doc Pass raw data received by the calling connection process to the session.
%% The pid of the caller is sent along; the session only accepts data from the
%% connection process it is currently attached to and replies
%% `{error, wrong_connection}' otherwise.
%% Synchronous call without an explicit timeout, so the gen_server default
%% call timeout applies.
-spec incoming_data(pid(), binary()) -> ok | {error, wrong_connection | mqtt_packet_map:decode_error()}.
incoming_data(Pid, Data) ->
gen_server:call(Pid, {incoming_data, Data, self()}).
%% @doc Start a session process, linked to the caller, for the given client id
%% in the given pool. The process is not registered under a local name.
-spec start_link( Pool::atom(), ClientId::binary(), mqtt_sessions:session_options() ) -> {ok, pid()}.
start_link( Pool, ClientId, SessionOptions ) ->
gen_server:start_link(?MODULE, [ Pool, ClientId, SessionOptions ], []).
% ---------------------------------------------------------------------------------------
% --------------------------- gen_server functions --------------------------------------
% ---------------------------------------------------------------------------------------
%% Initialize the session: register it for the client id, start and monitor
%% the will process, create the runtime's user context and schedule the first
%% keep-alive tick.
init([ Pool, ClientId, SessionOptions ]) ->
    RoutingId = mqtt_sessions_registry:routing_id(Pool),
    mqtt_sessions_registry:register(Pool, ClientId, self()),
    {ok, WillPid} = mqtt_sessions_will_sup:start(Pool, self()),
    {ok, Runtime} = application:get_env(mqtt_sessions, runtime),
    erlang:monitor(process, WillPid),
    OptionsWithRouting = SessionOptions#{ routing_id => RoutingId },
    % The ticker fires every half keep-alive period (seconds * 500 ms).
    TickRef = erlang:make_ref(),
    erlang:send_after(?KEEP_ALIVE_DEFAULT * 500, self(), {keep_alive, TickRef}),
    UserContext = Runtime:new_user_context(Pool, ClientId, OptionsWithRouting),
    InitialState = #state{
        pool = Pool,
        runtime = Runtime,
        client_id = ClientId,
        routing_id = RoutingId,
        user_context = UserContext,
        will_pid = WillPid,
        buffer = #{},
        keep_alive = ?KEEP_ALIVE_DEFAULT,
        keep_alive_counter = 3,
        keep_alive_ref = TickRef
    },
    {ok, InitialState}.
%% Synchronous requests: user-context access, transport lookup and raw data
%% coming in from the attached connection process. Any other request stops
%% the session process.
handle_call(get_user_context, _From, State) ->
    {reply, {ok, State#state.user_context}, State};
handle_call({set_user_context, NewContext}, _From, State) ->
    {reply, ok, State#state{ user_context = NewContext }};
handle_call({update_user_context, UpdateFun}, _From, State) ->
    Updated = UpdateFun(State#state.user_context),
    {reply, ok, State#state{ user_context = Updated }};
handle_call(get_transport, _From, #state{ transport = undefined } = State) ->
    {reply, {error, notransport}, State};
handle_call(get_transport, _From, State) ->
    {reply, {ok, State#state.transport}, State};
handle_call({incoming_data, Chunk, ConnPid}, _From, #state{ connection_pid = ConnPid } = State) ->
    % Data from the attached connection: append to what is left over from the
    % previous call and process all complete packets.
    Pending = State#state.incoming_data,
    case handle_incoming_data(<< Pending/binary, Chunk/binary >>, State) of
        {ok, {Remaining, NextState}} ->
            NextState1 = NextState#state{
                keep_alive_counter = 3,
                incoming_data = Remaining
            },
            {reply, ok, NextState1};
        {error, Reason} when is_atom(Reason) ->
            % illegal packet, disconnect and wait for new connection
            ?LOG_WARNING(#{
                in => mqtt_sessions,
                text => <<"Error decoding incoming data - disconnecting">>,
                result => error,
                reason => Reason
            }),
            {reply, {error, Reason}, force_disconnect(State)}
    end;
handle_call({incoming_data, _Chunk, OtherPid}, _From, State) ->
    % Data from a process that is not the attached connection is refused.
    ?LOG_DEBUG(#{
        in => mqtt_sessions,
        text => <<"MQTT session incoming data from unexpected Pid">>,
        from_pid => OtherPid,
        expected_pid => State#state.connection_pid
    }),
    {reply, {error, wrong_connection}, State};
handle_call(Unknown, _From, State) ->
    {stop, {unknown_cmd, Unknown}, State}.
%% Asynchronous requests: an incoming CONNECT packet or a request to stop.
handle_cast({incoming_connect, ConnectMsg, Options}, State) ->
    NextState = case handle_incoming_with_context(ConnectMsg, Options, State) of
        {ok, Accepted} ->
            Accepted#state{ keep_alive_counter = 3, incoming_data = <<>> };
        {error, _Reason} ->
            force_disconnect(State)
    end,
    {noreply, NextState};
handle_cast(kill, State) ->
    {stop, shutdown, State}.
%% Plain messages: publishes to relay to the client, keep-alive ticks,
%% publish-job bookkeeping and 'DOWN' messages of monitored processes.
handle_info({mqtt_msg, #{ type := publish } = MqttMsg}, State) ->
    {noreply, relay_publish(MqttMsg, State)};
handle_info({keep_alive, Ref}, #state{ keep_alive_counter = 0, keep_alive_ref = Ref } = State) ->
    % The counter ran out. No new tick is scheduled here.
    ?LOG_DEBUG("MQTT past keep_alive, disconnecting transport"),
    {noreply, force_disconnect(State)};
handle_info({keep_alive, Ref}, #state{ keep_alive_counter = Count, keep_alive_ref = Ref } = State) ->
    % Tick of the current ticker: schedule the next one and count down.
    erlang:send_after(State#state.keep_alive * 500, self(), {keep_alive, Ref}),
    Remaining = erlang:max(Count - 1, 0),
    {noreply, State#state{ keep_alive_counter = Remaining }};
handle_info({keep_alive, _OtherRef}, State) ->
    % Tick carrying a reference other than the current keep_alive_ref; ignore.
    {noreply, State};
handle_info({publish_job, undefined}, State) ->
    {noreply, State};
handle_info({publish_job, JobPid}, State) when is_pid(JobPid) ->
    % Only track (and monitor) jobs that are still running.
    NextState = case erlang:is_process_alive(JobPid) of
        false ->
            State;
        true ->
            Jobs = State#state.publish_jobs,
            MRef = erlang:monitor(process, JobPid),
            State#state{ publish_jobs = Jobs#{ JobPid => MRef } }
    end,
    {noreply, NextState};
handle_info({'DOWN', _MRef, process, ConnPid, _Reason}, #state{ connection_pid = ConnPid } = State) ->
    % The attached connection process is gone.
    {noreply, do_disconnected(State)};
handle_info({'DOWN', _MRef, process, WillPid, _Reason}, #state{ will_pid = WillPid } = State) ->
    % The will process is gone: tell the client and stop the session.
    Disconnect = #{
        type => disconnect,
        reason_code => ?MQTT_RC_ERROR
    },
    send_transport(Disconnect, State),
    {stop, shutdown, State};
handle_info({'DOWN', _MRef, process, Pid, _Reason}, State) ->
    % Any other monitored process: forget it if it was a publish job.
    % Removing a key that is not present leaves the map unchanged.
    Jobs = maps:remove(Pid, State#state.publish_jobs),
    {noreply, State#state{ publish_jobs = Jobs }};
handle_info(Other, State) ->
    ?LOG_INFO(#{
        in => mqtt_sessions,
        text => <<"Ignored unknown info message">>,
        info_msg => Other
    }),
    {noreply, State}.
%% gen_server callback for hot code upgrades: the state is kept unchanged.
code_change(_Vsn, State, _Extra) ->
{ok, State}.
%% gen_server callback on termination: nothing is cleaned up here.
terminate(_Reason, _State) ->
ok.
% ---------------------------------------------------------------------------------------
% ----------------------------- support functions ---------------------------------------
% ---------------------------------------------------------------------------------------
%% Decode and handle all complete packets in Data. Returns
%% `{ok, {Rest, State}}' where Rest is the (possibly empty) tail holding an
%% incomplete packet, or `{error, Reason}' on a decode or handling error.
handle_incoming_data(<<>>, State) ->
    {ok, {<<>>, State}};
handle_incoming_data(Data, #state{ protocol_version = Version } = State) ->
    case mqtt_packet_map:decode(Version, Data) of
        {ok, {Packet, Remaining}} ->
            case handle_incoming_with_context(Packet, #{}, State) of
                {ok, NextState} ->
                    handle_incoming_data(Remaining, NextState);
                {error, _} = HandleError ->
                    HandleError
            end;
        {error, incomplete_packet} ->
            % Keep the partial packet until more data arrives.
            % @todo Limit buffer size, disconnect if over max size
            {ok, {Data, State}};
        {error, _} = DecodeError ->
            DecodeError
    end.
%% Let the runtime check the message against the user context before it is
%% handled; messages the runtime rejects result in `{error, invalid_message}'.
handle_incoming_with_context(Msg, Options, #state{ runtime = Runtime } = State) ->
    IsValid = Runtime:is_valid_message(Msg, Options, State#state.user_context),
    case IsValid of
        false ->
            % We don't want this here, drop connection
            {error, invalid_message};
        true ->
            handle_incoming(Msg, Options, State)
    end.
%% Dispatch an incoming packet on its type. CONNECT and AUTH are always
%% accepted; all other packets need an attached connection and a session for
%% which a successful CONNACK has been sent.
handle_incoming(#{ type := connect } = Msg, Options, State) ->
    % Either the first connect of this session, or a client reopening a
    % connection. In the latter case the credentials must match the current
    % session credentials (otherwise someone else might steal this session).
    packet_connect(Msg, Options, State);
handle_incoming(#{ type := auth } = Msg, _Options, State) ->
    packet_connect_auth(Msg, State);
handle_incoming(#{ type := PacketType }, _Options, #state{ connection_pid = undefined } = State) ->
    ?LOG_INFO(#{
        in => mqtt_sessions,
        text => <<"Dropping packet for MQTT session when not connected.">>,
        result => error,
        reason => not_connected,
        pool => State#state.pool,
        client_id => State#state.client_id,
        session_pid => self(),
        message_type => PacketType
    }),
    {error, not_connected};
handle_incoming(#{ type := PacketType }, _Options, #state{ is_session_present = false } = State) ->
    % Only AUTH and CONNECT before the CONNACK
    ?LOG_INFO(#{
        in => mqtt_sessions,
        text => <<"MQTT received non AUTH or CONNECT before CONNACK - killed session">>,
        result => error,
        reason => no_connack,
        pool => State#state.pool,
        client_id => State#state.client_id,
        session_pid => self(),
        message_type => PacketType
    }),
    % NOTE(review): the callers visible in this file only match {ok, _} and
    % {error, _}, so this return value seems to end the session through a
    % case_clause crash - confirm that this is intended.
    {stop, State};
handle_incoming(#{ type := publish } = Msg, _Options, State) ->
    packet_publish(Msg, State);
handle_incoming(#{ type := pubrel } = Msg, _Options, State) ->
    % PUBREL is for publish messages sent by the client
    packet_pubrel(Msg, State);
% PUBREC, PUBACK, PUBCOMP is for publish messages sent by us to the client
handle_incoming(#{ type := pubrec } = Msg, _Options, State) ->
    packet_pubrec(Msg, State);
handle_incoming(#{ type := pubcomp } = Msg, _Options, State) ->
    packet_pubcomp(Msg, State);
handle_incoming(#{ type := puback } = Msg, _Options, State) ->
    packet_puback(Msg, State);
handle_incoming(#{ type := subscribe } = Msg, _Options, State) ->
    packet_subscribe(Msg, State);
handle_incoming(#{ type := unsubscribe } = Msg, _Options, State) ->
    packet_unsubscribe(Msg, State);
handle_incoming(#{ type := pingreq }, _Options, State) ->
    {ok, reply_or_drop(#{ type => pingresp }, State)};
handle_incoming(#{ type := pingresp }, _Options, State) ->
    {ok, State};
handle_incoming(#{ type := disconnect } = Msg, _Options, State) ->
    packet_disconnect(Msg, State);
handle_incoming(#{ type := PacketType }, _Options, State) ->
    ?LOG_INFO(#{
        in => mqtt_sessions,
        text => <<"MQTT dropping unhandled packet with type">>,
        message_type => PacketType
    }),
    {ok, State}.
% ---------------------------------------------------------------------------------------
% --------------------------- message type functions ------------------------------------
% ---------------------------------------------------------------------------------------
%% @doc Handle the connect message. Either this is a re-connect or the first connect.
%% Returns `{ok, State}' when the connect is accepted (or an auth handshake
%% continues) and `{error, Reason}' when it is refused.
%%
%% The state is only changed if the runtime accepts the connect: the
%% prospective state and the current state are both passed on to
%% handle_connect_auth/4.
packet_connect(#{ protocol_version := V, protocol_name := <<"MQTT">> }, Options, #state{ protocol_version = PV } = State)
        when is_integer(PV), V =/= PV ->
    % Do not change protocol versions for an existing session
    ConnAck = #{
        type => connack,
        reason_code => ?MQTT_RC_NOT_AUTHORIZED
    },
    _ = reply_to_transport(ConnAck, set_connection(Options, State)),
    {error, protocol_version_changed};
packet_connect(#{ protocol_version := 5, protocol_name := <<"MQTT">>, properties := Props } = Msg, Options, State) ->
    % MQTT v5
    % The client may request a session expiry; it is capped at
    % ?SESSION_EXPIRY_MAX so a client cannot keep a session alive longer
    % than the allowed maximum.
    ExpiryInterval = case maps:get(session_expiry_interval, Props, none) of
        none -> ?SESSION_EXPIRY;
        EI -> min(EI, ?SESSION_EXPIRY_MAX)
    end,
    KeepAlive = maps:get(keep_alive, Msg, ?KEEP_ALIVE_DEFAULT),
    StateIfAccept = State#state{
        protocol_version = 5,
        will = extract_will(Msg),
        session_expiry_interval = ExpiryInterval,
        keep_alive = KeepAlive,
        incoming_data = <<>>
    },
    StateIfAccept1 = set_connection(Options, StateIfAccept),
    handle_connect_auth(Msg, Options, StateIfAccept1, State);
packet_connect(#{ protocol_version := 4, protocol_name := <<"MQTT">> } = Msg, Options, State) ->
    % MQTT v3.1.1
    % There is no session expiry property in this protocol version; three
    % times the keep alive is used instead.
    KeepAlive = maps:get(keep_alive, Msg, ?KEEP_ALIVE_DEFAULT),
    StateIfAccept = State#state{
        protocol_version = 4,
        will = extract_will(Msg),
        session_expiry_interval = KeepAlive * 3,
        keep_alive = KeepAlive,
        incoming_data = <<>>
    },
    StateIfAccept1 = set_connection(Options, StateIfAccept),
    handle_connect_auth(Msg, Options, StateIfAccept1, State);
packet_connect(_ConnectMsg, Options, State) ->
    % Unknown protocol name or version.
    ConnAck = #{
        type => connack,
        reason_code => ?MQTT_RC_PROTOCOL_VERSION
    },
    _ = reply_to_transport(ConnAck, set_connection(Options, State)),
    {error, protocol_version}.
%% Handle an AUTH packet: let the runtime re-authenticate and process its
%% answer. The current state is used both as the state to accept and as the
%% state to fall back to.
packet_connect_auth(Msg, State) ->
    #state{ runtime = Runtime, user_context = UserContext } = State,
    AuthResult = Runtime:reauth(Msg, UserContext),
    handle_connect_auth_1(AuthResult, Msg, State, State).
%% Ask the runtime to accept or refuse the CONNECT packet and process its
%% answer. StateIfAccept is the state to continue with when accepted.
handle_connect_auth(Msg, Options, StateIfAccept, State) ->
    #state{
        runtime = Runtime,
        is_session_present = IsSessionPresent,
        user_context = UserContext
    } = State,
    ConnectResult = Runtime:connect(Msg, IsSessionPresent, Options, UserContext),
    handle_connect_auth_1(ConnectResult, Msg, StateIfAccept, State).
%% @doc Accept the new connection with the given ConnAck or Auth message.
%% If an Auth message is sent then we need further authentication handshakes.
%% Only after a successful connack we will set the is_session_present flag.
handle_connect_auth_1({ok, #{ type := connack, reason_code := ?MQTT_RC_SUCCESS } = ConnAck, UserContext1},
                      #{ clean_start := CleanStart }, StateIfAccept,
                      #state{ is_session_present = IsSessionPresent }) ->
    StateCleaned = maybe_clean_start(CleanStart, StateIfAccept),
    %% Set the session_present flag if the runtime omitted it: true if there
    %% was a session present and this is not a clean start.
    ConnAck1 = case maps:is_key(session_present, ConnAck) of
        true ->
            ConnAck;
        false ->
            ConnAck#{ session_present => IsSessionPresent andalso CleanStart =/= true }
    end,
    Connects = StateCleaned#state.connect_count + 1,
    State1 = reply_connack(ConnAck1, StateCleaned#state{
        user_context = UserContext1,
        is_session_present = true,
        is_connected = true,
        will = undefined,
        connect_count = Connects
    }),
    % The will is handed over to the will process, it is not kept in our state.
    mqtt_sessions_will:connected(
        State1#state.will_pid,
        StateIfAccept#state.will,
        State1#state.session_expiry_interval,
        State1#state.user_context),
    {ok, resend_buffered_and_unacknowledged(State1)};
handle_connect_auth_1({ok, #{ type := connack, reason_code := ReasonCode } = ConnAck, _UserContext1},
                      _Msg, StateIfAccept, _State) ->
    % Refused by the runtime: relay the connack, the state change is dropped.
    _ = reply_connack(ConnAck, StateIfAccept),
    ?LOG_INFO(#{
        in => mqtt_sessions,
        text => <<"MQTT connect/auth refused">>,
        result => error,
        reason => connection_refused,
        reason_code => ReasonCode,
        connack => ConnAck
    }),
    {error, connection_refused};
handle_connect_auth_1({ok, #{ type := auth } = Auth, UserContext1}, _Msg, StateIfAccept, _State) ->
    % The authentication handshake continues with an AUTH packet.
    State1 = reply_or_drop(Auth, StateIfAccept#state{
        user_context = UserContext1,
        is_connected = true
    }),
    mqtt_sessions_will:connected(
        State1#state.will_pid,
        undefined,
        State1#state.session_expiry_interval,
        State1#state.user_context),
    {ok, State1};
handle_connect_auth_1({error, Reason}, Msg, _StateIfAccept, _State) ->
    ?LOG_INFO(#{
        in => mqtt_sessions,
        text => <<"MQTT connect/auth refused">>,
        result => error,
        reason => connection_refused,
        msg_reason => Reason,
        message => Msg
    }),
    {error, connection_refused}.
%% @doc On a clean start: remove all subscriptions of this session process
%% and forget all buffered messages and pending acknowledgements.
maybe_clean_start(true, State) ->
    mqtt_sessions_router:unsubscribe_pid(State#state.pool, self()),
    State#state{
        buffer = #{},
        awaiting_ack = #{},
        awaiting_rel = #{}
    };
maybe_clean_start(false, State) ->
    State.
%% @doc Handle a publish request from remote to here. Always returns
%% `{ok, State}'.
%%
%% QoS 0: published if the runtime allows it, no acknowledgement is sent.
%%        Topics of the form `$client/ClientId/...' for our own client id
%%        are handed to the runtime as control messages.
%% QoS 1: answered with a PUBACK carrying the result.
%% QoS 2: answered with a PUBREC; accepted packet ids are remembered in
%%        awaiting_rel until the PUBREL arrives.
packet_publish(#{ topic := Topic, qos := 0 } = Msg,
               #state{ runtime = Runtime, user_context = UCtx, client_id = ClientId } = State) ->
    case Topic of
        [ <<"$client">>, ClientId | Rest ] ->
            MsgPub = mqtt_sessions_payload:decode(Msg#{ dup => false }),
            {ok, UCtx1} = Runtime:control_message(Rest, MsgPub, UCtx),
            {ok, State#state{ user_context = UCtx1 }};
        _ ->
            % No acknowledgement for QoS 0, the reason code is not used.
            _ = packet_publish_if_allowed(Topic, Msg, State),
            {ok, State}
    end;
packet_publish(#{ topic := Topic, qos := 1, dup := Dup, packet_id := PacketId } = Msg,
               #state{ awaiting_rel = WaitRel } = State) ->
    RC = case maps:find(PacketId, WaitRel) of
        {ok, _} when not Dup ->
            % There is a qos 2 level message with the same packet id
            ?MQTT_RC_PACKET_ID_IN_USE;
        {ok, {pubrel, RelRC, _}} when Dup ->
            % There is a qos 2 level message with the same packet id
            % But the received message is a duplicate, just ack.
            RelRC;
        error ->
            packet_publish_if_allowed(Topic, Msg, State)
    end,
    packet_publish_reply(puback, PacketId, RC, State);
packet_publish(#{ topic := Topic, qos := 2, dup := Dup, packet_id := PacketId } = Msg,
               #state{ awaiting_rel = WaitRel } = State) ->
    case maps:find(PacketId, WaitRel) of
        {ok, _} when not Dup ->
            packet_publish_reply(pubrec, PacketId, ?MQTT_RC_PACKET_ID_IN_USE, State);
        {ok, {pubrel, RelRC, _}} when Dup ->
            % Duplicate of a message we already accepted, repeat the answer.
            packet_publish_reply(pubrec, PacketId, RelRC, State);
        error ->
            RC = packet_publish_if_allowed(Topic, Msg, State),
            State1 = if
                RC < 16#80 ->
                    % Accepted, the client will follow up with a PUBREL.
                    State#state{
                        awaiting_rel = WaitRel#{ PacketId => {pubrel, RC, mqtt_sessions_timestamp:timestamp()} }
                    };
                true ->
                    State
            end,
            packet_publish_reply(pubrec, PacketId, RC, State1)
    end.

%% Publish the message if the runtime allows it. Returns the MQTT reason code
%% for the acknowledgement. The publish job is reported to ourselves so that
%% it can be tracked in the publish_jobs of the state.
packet_publish_if_allowed(Topic, Msg, #state{ runtime = Runtime, user_context = UCtx, pool = Pool }) ->
    case Runtime:is_allowed(publish, Topic, Msg, UCtx) of
        true ->
            MsgPub = mqtt_sessions_payload:decode(Msg#{ dup => false }),
            {ok, JobPid} = mqtt_sessions_router:publish(Pool, Topic, MsgPub, UCtx),
            self() ! {publish_job, JobPid},
            ?MQTT_RC_SUCCESS;
        false ->
            ?MQTT_RC_NOT_AUTHORIZED
    end.

%% Send the acknowledgement (puback or pubrec) for a received publish.
packet_publish_reply(AckType, PacketId, RC, State) ->
    Ack = #{
        type => AckType,
        packet_id => PacketId,
        reason_code => RC
    },
    {ok, reply_or_drop(Ack, State)}.
%% @doc Handle the pubrel, the client's follow-up on our pubrec for a QoS 2
%% publish. Answers a successful pubrel with a pubcomp.
packet_pubrel(#{ packet_id := PacketId, reason_code := ?MQTT_RC_SUCCESS }, State) ->
    Waiting = State#state.awaiting_rel,
    {RC, Waiting1} = case maps:find(PacketId, Waiting) of
        {ok, {pubrel, _RC, _Timestamp}} ->
            {?MQTT_RC_SUCCESS, maps:remove(PacketId, Waiting)};
        error ->
            {?MQTT_RC_PACKET_ID_NOT_FOUND, Waiting}
    end,
    PubComp = #{
        type => pubcomp,
        packet_id => PacketId,
        reason_code => RC
    },
    State1 = reply_or_drop(PubComp, State),
    {ok, State1#state{ awaiting_rel = Waiting1 }};
packet_pubrel(#{ packet_id := PacketId, reason_code := RC }, State) ->
    % Error server/client out of sync - remove the wait-rel for this packet_id
    ?LOG_INFO(#{
        in => mqtt_sessions,
        text => <<"PUBREL with non success reason for packet">>,
        reason_code => RC,
        packet_id => PacketId
    }),
    Waiting1 = maps:remove(PacketId, State#state.awaiting_rel),
    {ok, State#state{ awaiting_rel = Waiting1 }}.
%% @doc Handle puback for QoS 1 publish messages sent to the client.
%% The pending acknowledgement is removed, unless the message was never sent.
packet_puback(#{ packet_id := PacketId }, State) ->
    Pending = State#state.awaiting_ack,
    Pending1 = case maps:find(PacketId, Pending) of
        error ->
            Pending;
        {ok, #wait_for{ is_sent = false }} ->
            % Not yet sent, so this ack cannot be for this message.
            Pending;
        {ok, #wait_for{ type = puback }} ->
            maps:remove(PacketId, Pending);
        {ok, #wait_for{ type = Wait, message = Msg }} ->
            ?LOG_WARNING(#{
                in => mqtt_sessions,
                text => <<"PUBACK for message wating for something else - dropping pending ack">>,
                result => error,
                packet_id => PacketId,
                wait => Wait,
                message => Msg
            }),
            maps:remove(PacketId, Pending)
    end,
    {ok, State#state{ awaiting_ack = Pending1 }}.
%% @doc Handle pubrec for QoS 2 publish messages sent to the client.
%% A pubrec with an error reason code (>= 16#80) ends the exchange; any other
%% pubrec is answered with a pubrel and we continue waiting for the pubcomp.
packet_pubrec(#{ packet_id := PacketId, reason_code := RC }, #state{ awaiting_ack = Pending } = State)
        when RC >= 16#80 ->
    Pending1 = case maps:find(PacketId, Pending) of
        {ok, #wait_for{ is_sent = false }} ->
            Pending;
        {ok, #wait_for{ type = AckType }} when AckType =:= pubrec; AckType =:= pubcomp ->
            maps:remove(PacketId, Pending);
        {ok, #wait_for{ type = Wait, message = Msg }} ->
            ?LOG_WARNING(#{
                in => mqtt_sessions,
                text => <<"PUBREC for message wating for something else - dropping pending ack">>,
                result => error,
                packet_id => PacketId,
                wait => Wait,
                message => Msg
            }),
            maps:remove(PacketId, Pending);
        error ->
            Pending
    end,
    {ok, State#state{ awaiting_ack = Pending1 }};
packet_pubrec(#{ packet_id := PacketId }, #state{ awaiting_ack = Pending } = State) ->
    {Pending1, RelRC} = case maps:find(PacketId, Pending) of
        {ok, #wait_for{ msg_nr = MsgNr, type = pubrec }} ->
            % Received by the client, from now on we wait for the pubcomp.
            NextWait = #wait_for{
                msg_nr = MsgNr,
                type = pubcomp,
                queued = mqtt_sessions_timestamp:timestamp()
            },
            {Pending#{ PacketId => NextWait }, ?MQTT_RC_SUCCESS};
        {ok, #wait_for{ type = pubcomp }} ->
            {Pending, ?MQTT_RC_SUCCESS};
        {ok, #wait_for{ is_sent = false }} ->
            {Pending, ?MQTT_RC_SUCCESS};
        {ok, #wait_for{ type = Wait, message = Msg }} ->
            ?LOG_WARNING(#{
                in => mqtt_sessions,
                text => <<"PUBREC for message wating for something else - dropping pending ack">>,
                result => error,
                packet_id => PacketId,
                wait => Wait,
                message => Msg
            }),
            {maps:remove(PacketId, Pending), ?MQTT_RC_PACKET_ID_NOT_FOUND};
        error ->
            {Pending, ?MQTT_RC_PACKET_ID_NOT_FOUND}
    end,
    PubRel = #{
        type => pubrel,
        packet_id => PacketId,
        reason_code => RelRC
    },
    {ok, reply_or_drop(PubRel, State#state{ awaiting_ack = Pending1 })}.
%% @doc Handle pubcomp for QoS 2 publish messages sent to the client. This
%% completes the exchange, the pending acknowledgement is removed.
packet_pubcomp(#{ packet_id := PacketId }, State) ->
    Pending = State#state.awaiting_ack,
    Pending1 = case maps:find(PacketId, Pending) of
        error ->
            Pending;
        {ok, #wait_for{ type = pubcomp }} ->
            maps:remove(PacketId, Pending);
        {ok, #wait_for{ is_sent = false }} ->
            Pending;
        {ok, #wait_for{ type = Wait, message = Msg }} ->
            ?LOG_WARNING(#{
                in => mqtt_sessions,
                text => <<"PUBCOMP for message wating for something else - dropping pending ack">>,
                result => error,
                packet_id => PacketId,
                wait => Wait,
                message => Msg
            }),
            maps:remove(PacketId, Pending)
    end,
    {ok, State#state{ awaiting_ack = Pending1 }}.
%% @doc Handle a subscribe request. Every requested topic filter gets its own
%% entry in the suback: `{ok, QoS}' or `{error, ReasonCode}'.
packet_subscribe(#{ topics := Topics } = Msg, #state{ runtime = _, user_context = _ } = State) ->
    Acks = lists:map(
        fun(#{ topic := _ } = Sub) ->
            packet_subscribe_topic(Sub, Msg, State)
        end,
        Topics),
    SubAck = #{
        type => suback,
        packet_id => maps:get(packet_id, Msg, 0),
        acks => Acks
    },
    {ok, reply_or_drop(SubAck, State)}.

%% Validate the topic filter of a single subscription and subscribe to it.
packet_subscribe_topic(#{ topic := TopicFilter0 } = Sub, Msg, State) ->
    case mqtt_packet_map_topic:validate_topic(TopicFilter0) of
        {ok, TopicFilter} ->
            packet_subscribe_valid(TopicFilter, Sub, Msg, State);
        {error, _} ->
            {error, ?MQTT_RC_TOPIC_FILTER_INVALID}
    end.

%% Subscribe this session process to a valid topic filter, if the runtime
%% allows it. The subscription options are the requested subscription
%% without its topic, with defaults for qos and no_local.
packet_subscribe_valid(TopicFilter, Sub, Msg, #state{ runtime = Runtime, user_context = UCtx, pool = Pool }) ->
    case Runtime:is_allowed(subscribe, TopicFilter, Msg, UCtx) of
        false ->
            {error, ?MQTT_RC_NOT_AUTHORIZED};
        true ->
            QoS = maps:get(qos, Sub, 0),
            WithDefaults = Sub#{
                qos => QoS,
                no_local => maps:get(no_local, Sub, false)
            },
            SubOptions = maps:remove(topic, WithDefaults),
            case mqtt_sessions_router:subscribe(Pool, TopicFilter, self(), self(), SubOptions, UCtx) of
                ok -> {ok, QoS};
                {error, _} -> {error, ?MQTT_RC_ERROR}
            end
    end.
%% @doc Handle the unsubscribe request. Every topic filter gets its own entry
%% in the unsuback: `{ok, found}' or `{ok, notfound}'.
packet_unsubscribe(#{ topics := Topics } = Msg, State) ->
    Pool = State#state.pool,
    Unsubscribe = fun(TopicFilter) ->
        case mqtt_sessions_router:unsubscribe(Pool, TopicFilter, self()) of
            {error, notfound} -> {ok, notfound};
            ok -> {ok, found}
        end
    end,
    UnsubAck = #{
        type => unsuback,
        packet_id => maps:get(packet_id, Msg, 0),
        acks => lists:map(Unsubscribe, Topics)
    },
    {ok, reply_or_drop(UnsubAck, State)}.
%% @doc Handle a disconnect from the client. The will process is told whether
%% the will must be sent (any reason code other than success) and what the
%% session expiry is. With an expiry of 0 the session stops itself.
packet_disconnect(#{ reason_code := RC, properties := Props },
                  #state{ will_pid = WillPid, session_expiry_interval = CurrentExpiry } = State) ->
    Expiry = packet_disconnect_expiry(CurrentExpiry, Props),
    mqtt_sessions_will:disconnected(WillPid, RC =/= ?MQTT_RC_SUCCESS, Expiry),
    State1 = force_disconnect(State),
    case Expiry of
        0 -> gen_server:cast(self(), kill);
        _ -> ok
    end,
    {ok, State1}.

%% The session expiry after the disconnect: the client can override it in the
%% disconnect properties, but not if the current expiry is 0.
packet_disconnect_expiry(0 = CurrentExpiry, _Props) ->
    % TODO: If the props.session_expiry_interval > 0 then send disconnect with MQTT_RC_PROTOCOL_ERROR
    CurrentExpiry;
packet_disconnect_expiry(CurrentExpiry, Props) ->
    maps:get(session_expiry_interval, Props, CurrentExpiry).
% ---------------------------------------------------------------------------------------
% --------------------------- relay publish to client -----------------------------------
% ---------------------------------------------------------------------------------------
%% Relay a publish from the router to the client. Returns the new state.
%%
%% The QoS used is the lowest of the QoS of the message and the QoS given in
%% the enclosing router message. QoS 0 messages are sent or queued. QoS 1/2
%% messages get a packet id and are tracked in awaiting_ack until the client
%% acknowledged them; they are dropped if ?MAX_INFLIGHT_ACK or more
%% acknowledgements are pending.
relay_publish(#{ type := publish, message := Msg } = MqttMsg, State) ->
    QoS = erlang:min( maps:get(qos, Msg, 0), maps:get(qos, MqttMsg, 0) ),
    Msg2 = mqtt_sessions_payload:encode(Msg#{
        qos => QoS,
        dup => false
    }),
    StatePurged = maybe_purge(State),
    case QoS of
        0 ->
            State2 = #state{ msg_nr = MsgNr } = inc_msg_nr(StatePurged),
            reply_or_queue(Msg2#{ packet_id => 0 }, MsgNr, State2);
        _ ->
            case maps:size(StatePurged#state.awaiting_ack) >= ?MAX_INFLIGHT_ACK of
                true when State#state.transport =/= undefined ->
                    % Use the same `in' value as all other log entries of this
                    % module, so that filtering on it also finds this entry.
                    ?LOG_INFO(#{
                        in => mqtt_sessions,
                        text => <<"Not accepting QoS 1/2 message, too many inflight or queued acks">>,
                        result => error,
                        reason => buffer_full
                    }),
                    StatePurged;
                true ->
                    % Dormant session, just drop excess messages.
                    StatePurged;
                false ->
                    State1 = #state{ packet_id = PacketId } = inc_packet_id(StatePurged),
                    State2 = #state{ msg_nr = MsgNr } = inc_msg_nr(State1),
                    % The acknowledgement we expect from the client.
                    AckRec = case QoS of
                        1 -> puback;
                        2 -> pubrec
                    end,
                    Msg3 = Msg2#{
                        packet_id => PacketId
                    },
                    % Without transport the message is only registered as
                    % unsent; it is sent after the next successful connect.
                    {IsSent, State3} = if
                        State2#state.transport =:= undefined ->
                            {false, State2};
                        true ->
                            {true, reply_or_drop(Msg3, State2)}
                    end,
                    WaitFor = #wait_for{
                        msg_nr = MsgNr,
                        message = Msg3,
                        type = AckRec,
                        is_sent = IsSent,
                        queued = mqtt_sessions_timestamp:timestamp()
                    },
                    State3#state{
                        awaiting_ack = (State3#state.awaiting_ack)#{
                            PacketId => WaitFor
                        }
                    }
            end
    end.
% ---------------------------------------------------------------------------------------
% ------------------------------- queue functions ---------------------------------------
% ---------------------------------------------------------------------------------------
%% Remove all QoS 0 messages from the buffer, messages with a higher QoS are kept.
delete_buffered_qos0(State) ->
    KeepFun = fun(_MsgNr, #queued{ qos = QoS }) -> 0 < QoS end,
    State#state{ buffer = maps:filter(KeepFun, State#state.buffer) }.
%% Called after a successful connect: (re)send everything the client did not
%% yet receive or acknowledge, in the order of the message numbers.
%%
%% The buffered messages are merged with the messages derived from the
%% pending acknowledgements:
%%  - publish messages that were never sent are sent as-is;
%%  - publish messages awaiting a puback or pubrec are resent with dup set;
%%  - for exchanges awaiting a pubcomp the pubrel is resent.
%% Sending stops as soon as the transport is lost.
resend_buffered_and_unacknowledged(#state{ awaiting_ack = AwaitAck, buffer = Buffer } = State) ->
    ResendMap = maps:fold(
        fun
            (_PacketId, #wait_for{ is_sent = false, msg_nr = MsgNr, message = Msg }, Acc) ->
                % Unsent QoS 1 or 2 message
                Acc#{ MsgNr => Msg };
            (_PacketId, #wait_for{ msg_nr = MsgNr, type = puback, message = Msg }, Acc) ->
                Acc#{ MsgNr => Msg#{ dup => true } };
            (_PacketId, #wait_for{ msg_nr = MsgNr, type = pubrec, message = Msg }, Acc) ->
                Acc#{ MsgNr => Msg#{ dup => true } };
            (PacketId, #wait_for{ msg_nr = MsgNr, type = pubcomp }, Acc) ->
                % We received the pubrec and answered with a pubrel, which is
                % the packet the client might have missed. A pubrec is only
                % ever sent by the receiver of a publish.
                PubRel = #{
                    type => pubrel,
                    packet_id => PacketId
                },
                Acc#{ MsgNr => PubRel };
            (_PacketId, #wait_for{ msg_nr = MsgNr, type = suback, message = Msg }, Acc) ->
                Acc#{ MsgNr => Msg };
            (_PacketId, #wait_for{ msg_nr = MsgNr, type = unsuback, message = Msg }, Acc) ->
                Acc#{ MsgNr => Msg };
            (_PacketId, _, Acc) ->
                Acc
        end,
        Buffer,
        AwaitAck),
    ResendList = lists:sort(maps:to_list(ResendMap)),
    lists:foldl(
        fun
            (_Msg, #state{ transport = undefined } = AccState) ->
                % Transport lost, nothing more can be sent.
                AccState;
            ({_MsgNr, #{ type := publish, packet_id := PacketId } = Msg}, AccState) ->
                AccState1 = mark_packet_sent(PacketId, AccState),
                reply_or_drop(Msg, AccState1);
            ({_MsgNr, #{ type := _} = Msg}, AccState) ->
                reply_or_drop(Msg, AccState);
            ({MsgNr, #queued{ message = Msg } = Q}, AccState) ->
                case reply_or_drop(Msg, AccState) of
                    #state{ transport = undefined, buffer = AccBuffer } = AccState1 ->
                        % Transport lost on sending, keep the message buffered.
                        AccBuffer1 = AccBuffer#{ MsgNr => Q },
                        AccState1#state{ buffer = AccBuffer1 };
                    #state{} = AccState1 ->
                        AccState1
                end
        end,
        State#state{ buffer = #{} },
        ResendList).
%% @doc Set the is_sent flag of the awaiting_ack entry for the packet id.
%% The packet id must be present in awaiting_ack.
mark_packet_sent(PacketId, #state{ awaiting_ack = AwaitAck } = State) ->
    Pending = maps:update_with(
        PacketId,
        fun(Wait) -> Wait#wait_for{ is_sent = true } end,
        AwaitAck),
    State#state{ awaiting_ack = Pending }.
% ---------------------------------------------------------------------------------------
% -------------------------------- misc functions ---------------------------------------
% ---------------------------------------------------------------------------------------
%% @doc Called when the connection disconnects or crashes/stops. Reports the
%% disconnect to mqtt_sessions_will for the will pid of this session and
%% resets the connection fields of the state.
do_disconnected(#state{} = State) ->
    mqtt_sessions_will:disconnected(State#state.will_pid),
    cleanup_state_disconnected(State).
%% @doc Reset the state to 'not connected': forget the connection pid and the
%% transport, and drop the buffered QoS 0 messages.
%% @todo Cleanup pending messages and awaiting states.
cleanup_state_disconnected(State) ->
    Detached = State#state{
        is_connected = false,
        transport = undefined,
        connection_pid = undefined
    },
    delete_buffered_qos0(Detached).
%% @doc Send a connack to the remote. A successful connack is extended with the
%% session properties taken from the state, these overrule properties with the
%% same key that are already in the connack. Any other connack is sent as-is.
reply_connack(#{ type := connack, reason_code := ?MQTT_RC_SUCCESS } = ConnAck, State) ->
    GivenProps = maps:get(properties, ConnAck, #{}),
    SessionProps = #{
        session_expiry_interval => State#state.session_expiry_interval,
        server_keep_alive => State#state.keep_alive,
        assigned_client_identifier => State#state.client_id,
        subscription_identifier_available => false,
        shared_subscription_available => false,
        <<"cotonic-routing-id">> => State#state.routing_id
    },
    % The second map of maps:merge/2 wins on duplicate keys
    Props = maps:merge(GivenProps, SessionProps),
    reply_to_transport(ConnAck#{ properties => Props }, State);
reply_connack(#{ type := connack } = ConnAck, State) ->
    reply_to_transport(ConnAck, State).
%% @doc Check the connect packet, extract the will as a map for the will-watchdog.
%% Without will flag an empty map is returned. For protocol version 4 the expiry
%% interval is always 0, otherwise it is read from the will_expiry_interval
%% property of the connect packet (default 0).
extract_will(#{ type := connect, will_flag := false }) ->
    #{};
extract_will(#{ protocol_version := 4, type := connect, will_flag := true } = Connect) ->
    extract_will_1(0, Connect);
extract_will(#{ type := connect, will_flag := true, properties := ConnectProps } = Connect) ->
    ExpiryInterval = maps:get(will_expiry_interval, ConnectProps, 0),
    extract_will_1(ExpiryInterval, Connect).

%% Build the will map from the will_* keys of the connect packet. The will_topic
%% is required, all other keys have a default.
extract_will_1(ExpiryInterval, Connect) ->
    #{
        expiry_interval => ExpiryInterval,
        topic => maps:get(will_topic, Connect),
        payload => maps:get(will_payload, Connect, <<>>),
        properties => maps:get(will_properties, Connect, #{}),
        qos => maps:get(will_qos, Connect, 0),
        retain => maps:get(will_retain, Connect, false)
    }.
%% @doc Drop the current connection from the session side. Tells the transport
%% and the connection process to disconnect and resets the connection fields of
%% the state. If is_session_present is false then a 'kill' cast is sent to this
%% process. Does nothing if there is neither a connection pid nor a transport.
force_disconnect(#state{ connection_pid = undefined, transport = undefined } = State) ->
    State;
force_disconnect(#state{ connection_pid = ConnectionPid } = State) ->
    NoTransport = disconnect_transport(State),
    case is_pid(ConnectionPid) of
        true ->
            ConnectionPid ! {mqtt_transport, self(), disconnect};
        false ->
            ok
    end,
    Cleaned = cleanup_state_disconnected(NoTransport),
    case Cleaned#state.is_session_present of
        true ->
            ok;
        false ->
            gen_server:cast(self(), kill)
    end,
    Cleaned.
%% @doc Tell the transport to disconnect and remove it from the state. The
%% transport is either a pid (receives a mqtt_transport message), a function
%% (called with 'disconnect') or a {M, F, A} tuple (called with 'disconnect'
%% prepended to the arguments).
disconnect_transport(#state{ transport = undefined } = State) ->
    State;
disconnect_transport(#state{ transport = Transport } = State) ->
    case Transport of
        Pid when is_pid(Pid) ->
            Pid ! {mqtt_transport, self(), disconnect};
        Fun when is_function(Fun) ->
            Fun(disconnect);
        {M, F, A} ->
            erlang:apply(M, F, [disconnect | A])
    end,
    State#state{ transport = undefined, is_connected = false }.
%% @doc Send a message to the transport, also when the session is not (yet)
%% flagged as connected. Without transport the message is discarded. A send
%% error forces a disconnect.
reply_to_transport(_Msg, #state{ transport = undefined } = State) ->
    State;
reply_to_transport(Msg, State) ->
    case send_transport(Msg, State) of
        {error, _Reason} ->
            force_disconnect(State);
        ok ->
            State
    end.
%% @doc Send a message if the session is connected, otherwise the message is
%% dropped. A send error forces a disconnect.
reply_or_drop(_Msg, #state{ is_connected = false } = State) ->
    State;
reply_or_drop(Msg, State) ->
    case send_transport(Msg, State) of
        {error, _Reason} ->
            force_disconnect(State);
        ok ->
            State
    end.
%% @doc Send a message if the session is connected. If the session is not
%% connected, or if sending fails, the message is stored in the buffer under
%% its message number, after which expired and overflowing entries are purged.
reply_or_queue(Msg, MsgNr, #state{ is_connected = false } = State) ->
    Queued = queue(Msg, MsgNr, State),
    maybe_purge(Queued);
reply_or_queue(Msg, MsgNr, State) ->
    case send_transport(Msg, State) of
        {error, _Reason} ->
            Disconnected = force_disconnect(State),
            Queued = queue(Msg, MsgNr, Disconnected),
            maybe_purge(Queued);
        ok ->
            State
    end.
%% @doc Hand a message to the transport. A message map is first encoded for the
%% protocol version of the session. Without transport this is a no-op returning
%% 'ok'. For a pid transport the data is sent as a mqtt_transport message, or
%% {error, transport_down} is returned if the process is not alive. For a
%% function or {M, F, A} transport the result of the call is returned.
send_transport(_Data, #state{ transport = undefined }) ->
    ok;
send_transport(Packet, #state{ protocol_version = Version } = State) when is_map(Packet) ->
    Encoded = encode(Version, Packet),
    send_transport(Encoded, State);
send_transport(Data, #state{ transport = Pid }) when is_pid(Pid) ->
    case erlang:is_process_alive(Pid) of
        false ->
            {error, transport_down};
        true ->
            Pid ! {mqtt_transport, self(), Data},
            ok
    end;
send_transport(Data, #state{ transport = Fun }) when is_function(Fun) ->
    Fun(Data);
send_transport(Data, #state{ transport = {Mod, Fn, Args} }) ->
    erlang:apply(Mod, Fn, [Data | Args]).
%% @doc Store a message in the buffer under its message number. The type, the
%% packet id (default 0), the QoS (default 1) and the expiry are extracted from
%% the message. The expiry is the current timestamp plus the
%% message_expiry_interval property, or ?MESSAGE_EXPIRY_DEFAULT if not set.
queue(#{ type := Type } = Msg, MsgNr, #state{ buffer = Buffer } = State) ->
    Props = maps:get(properties, Msg, #{}),
    Now = mqtt_sessions_timestamp:timestamp(),
    Lifetime = maps:get(message_expiry_interval, Props, ?MESSAGE_EXPIRY_DEFAULT),
    PacketId = maps:get(packet_id, Msg, 0),
    QoS = maps:get(qos, Msg, 1),
    Entry = #queued{
        msg_nr = MsgNr,
        type = Type,
        queued = Now,
        packet_id = PacketId,
        expiry = Now + Lifetime,
        qos = QoS,
        message = Msg
    },
    State#state{ buffer = Buffer#{ MsgNr => Entry } }.
%% @doc Purge both the message buffer and the map with packets awaiting an ack.
maybe_purge(#state{ buffer = Buffer, awaiting_ack = WaitAcks } = State) ->
    PurgedBuffer = maybe_purge_buffer(Buffer),
    PurgedAcks = maybe_purge_ack(WaitAcks),
    State#state{
        buffer = PurgedBuffer,
        awaiting_ack = PurgedAcks
    }.
%% @doc Drop expired acks and expired waiting QoS 1/2 messages.
%% An entry that has not been sent yet is kept until its message expiry passes:
%% the queued timestamp plus the message_expiry_interval property of the message
%% (default ?MESSAGE_EXPIRY_DEFAULT). An entry that was sent and is waiting for
%% its acknowledgement is kept for ?ACK_EXPIRY after it was queued.
maybe_purge_ack(WaitAcks) ->
    Now = mqtt_sessions_timestamp:timestamp(),
    maps:filter(
        fun
            (_PacketId, #wait_for{ is_sent = false, message = Msg, queued = Queued }) ->
                % Message expiry only applies as long as delivery has not started.
                % Sent entries must not take this branch: it reads the message as a
                % map, which fails if a sent entry has no message map stored.
                Props = maps:get(properties, Msg, #{}),
                Expiry = Queued + maps:get(message_expiry_interval, Props, ?MESSAGE_EXPIRY_DEFAULT),
                Expiry > Now;
            (_PacketId, #wait_for{ queued = Queued }) ->
                % Sent, waiting for the acknowledgement of the remote
                Queued + ?ACK_EXPIRY > Now
        end,
        WaitAcks).
%% @doc Purge expired and oldest messages from the buffer.
%% First all entries whose expiry has passed are removed. If more than
%% ?MAX_BUFFERED entries remain then the 20% with the lowest message numbers
%% (the oldest) of the remaining entries are dropped as well.
maybe_purge_buffer(Buffer) ->
    Now = mqtt_sessions_timestamp:timestamp(),
    Buffer1 = maps:filter(
        fun(_MsgNr, #queued{ expiry = Expiry }) -> Now < Expiry end,
        Buffer),
    case maps:size(Buffer1) > ?MAX_BUFFERED of
        true ->
            % Work on the filtered buffer, using the unfiltered one would bring
            % the expired messages back and miscount the 20%.
            Sorted = lists:sort(maps:to_list(Buffer1)),
            {_Dropped, Kept} = lists:split(maps:size(Buffer1) div 5, Sorted),
            maps:from_list(Kept);
        false ->
            Buffer1
    end.
%% @doc Encode a packet map to binary for the given protocol version. Crashes
%% if mqtt_packet_map:encode/2 does not return an ok tuple.
-spec encode( mqtt_packet_map:mqtt_version(), mqtt_packet_map:mqtt_packet() | list( mqtt_packet_map:mqtt_packet() )) -> binary().
encode(ProtocolVersion, Packet) when is_map(Packet) ->
    {ok, Encoded} = mqtt_packet_map:encode(ProtocolVersion, Packet),
    Encoded.
%% @doc Set the new connection, disconnect existing transport.
%% If the connection pid is the one already attached then only the transport is
%% replaced. If another connection pid is attached then that process is told to
%% disconnect before the new connection is attached.
set_connection(#{ connection_pid := NewPid, transport := Transport }, State) ->
    CurrentPid = State#state.connection_pid,
    if
        CurrentPid =:= NewPid ->
            State#state{ transport = Transport };
        CurrentPid =:= undefined ->
            set_connection_1(NewPid, Transport, State);
        true ->
            CurrentPid ! {mqtt_transport, self(), disconnect},
            set_connection_1(NewPid, Transport, State)
    end.
%% @doc Monitor the new connection process, store it with its transport in the
%% state and start the keep-alive timer.
set_connection_1(ConnectionPid, Transport, State) ->
    _MonitorRef = erlang:monitor(process, ConnectionPid),
    Attached = State#state{
        connection_pid = ConnectionPid,
        transport = Transport
    },
    start_keep_alive(Attached).
%% @doc Start the keep-alive timer. With a keep_alive of 0 nothing is scheduled.
%% Otherwise a {keep_alive, Ref} message is sent to this process after
%% keep_alive * 500 milliseconds, the counter is set to 3 and the fresh
%% reference is stored in the state.
start_keep_alive(#state{ keep_alive = 0 } = State) ->
    State;
start_keep_alive(#state{ keep_alive = KeepAlive } = State) ->
    TimerRef = erlang:make_ref(),
    DelayMsec = KeepAlive * 500,
    erlang:send_after(DelayMsec, self(), {keep_alive, TimerRef}),
    State#state{
        keep_alive_ref = TimerRef,
        keep_alive_counter = 3
    }.
%% @doc Increment the message number, this number is used for order of resending buffered messages
inc_msg_nr(#state{} = State) ->
    Next = State#state.msg_nr + 1,
    State#state{ msg_nr = Next }.
%% @doc Fetch a packet id that is not yet used. The packet id is incremented,
%% wrapping around to 1 after ?MAX_PACKET_ID. Ids that are still in use by an
%% entry in awaiting_ack are skipped.
inc_packet_id(#state{ packet_id = Current, awaiting_ack = Pending } = State) ->
    Next = if
        Current >= ?MAX_PACKET_ID -> 1;
        true -> Current + 1
    end,
    Advanced = State#state{ packet_id = Next },
    case Pending of
        #{ Next := _ } ->
            % Still waiting for an ack on this id, try the next one
            inc_packet_id(Advanced);
        #{} ->
            Advanced
    end.