%% @doc Client to interact with grisp.io
%%
%% This module contains a state machine to ensure connectivity with grisp.io.
%% JsonRPC traffic is managed here.
%% @end
-module(grisp_connect_client).
-behaviour(gen_statem).
-include_lib("kernel/include/logger.hrl").
% External API
-export([start_link/0]).
-export([connect/0]).
-export([is_connected/0]).
-export([wait_connected/1]).
-export([request/3]).
-export([notify/3]).
% Internal API
-export([reboot/0]).
% Behaviour gen_statem callback functions
-export([init/1, terminate/3, code_change/4, callback_mode/0]).
% State Functions
-export([idle/3]).
-export([waiting_ip/3]).
-export([connecting/3]).
-export([connected/3]).
%--- Types ---------------------------------------------------------------------
% State data carried by the state machine across all states.
-record(data, {
% Host name of the server, read from the `domain' application env key.
domain :: binary(),
% TCP port of the server, read from the `port' application env key.
port :: inet:port_number(),
% Path of the WebSocket endpoint, read from the `ws_path' env key.
ws_path :: binary(),
% Transport used for the connection, read from the `ws_transport' env key.
ws_transport :: tcp | tls,
% Pid of the current jarl connection process; undefined when there is none.
conn :: undefined | pid(),
% Number of reconnections attempted; reset to 0 when connected or giving up.
retry_count = 0 :: non_neg_integer(),
% Last connection failure reason; cleared when entering idle or connected.
last_error :: term(),
% Retry limit before going back to idle, from the `ws_max_retries' env key.
max_retries = infinity :: non_neg_integer() | infinity,
% Callers of wait_connected/1 that are still waiting for a reply.
wait_calls = [] :: [gen_statem:from()]
}).
% State data of the state machine.
-type data() :: #data{}.
% Callback invoked with the current data and the result of a request that
% succeeded; it returns the updated data.
-type on_result_fun() :: fun((data(), Result :: term()) -> data()).
% Callback invoked when a request failed. The second argument is `remote' when
% the error comes from an error response and `local' when it was raised by the
% connection itself (in which case message and data are undefined).
-type on_error_fun() :: fun((data(), local | remote,
Code :: atom() | integer(),
Message :: undefined | binary(),
Data :: term()) -> data()).
%--- Macros --------------------------------------------------------------------
% Protocol name given in the `protocols' option when starting a connection.
-define(GRISP_IO_PROTOCOL, <<"grisp-io-v1">>).
% Formats FMT with ARGS (io_lib:format/2) and returns the result as a binary.
-define(FORMAT(FMT, ARGS), iolist_to_binary(io_lib:format(FMT, ARGS))).
% State timeout armed once a connection process has been started; when it
% fires in the connecting state the attempt is considered failed.
-define(CONNECT_TIMEOUT, 5000).
% Reads the key KEY from the grisp_connect application environment.
% GUARDS is a guard expression over the variable V (the configured value).
% Exits with {invalid_env, KEY, V} if the guard fails and with
% {missing_env, KEY} if the key is not set. The fun wrapper keeps V local to
% each expansion so the macro can be used several times in one function.
-define(ENV(KEY, GUARDS), fun() ->
case application:get_env(grisp_connect, KEY) of
{ok, V} when GUARDS -> V;
{ok, V} -> erlang:exit({invalid_env, KEY, V});
undefined -> erlang:exit({missing_env, KEY})
end
end()).
% Same as ENV/2, but returns the expression CONV (which may use V) instead of
% the raw value.
-define(ENV(KEY, GUARDS, CONV), fun() ->
case application:get_env(grisp_connect, KEY) of
{ok, V} when GUARDS -> CONV;
{ok, V} -> erlang:exit({invalid_env, KEY, V});
undefined -> erlang:exit({missing_env, KEY})
end
end()).
% Expands to a catch-all clause of the enclosing state function that
% delegates the event to handle_common/4 together with the state name.
-define(HANDLE_COMMON,
?FUNCTION_NAME(EventType, EventContent, Data) ->
handle_common(EventType, EventContent, ?FUNCTION_NAME, Data)).
%--- External API Functions ----------------------------------------------------
%% @doc Starts the client state machine linked to the calling process and
%% registers it locally under the module name.
start_link() ->
    ServerName = {local, ?MODULE},
    gen_statem:start_link(ServerName, ?MODULE, [], []).
%% @doc Asks the client to start connecting to grisp.io.
%%
%% The request is asynchronous. It only has an effect in the idle state; it is
%% ignored in every other state.
-spec connect() -> ok.
connect() ->
    gen_statem:cast(?MODULE, connect).
%% @doc Tells whether the client is currently connected to grisp.io.
%%
%% Returns false when the client process is not running.
-spec is_connected() -> boolean().
is_connected() ->
    try
        gen_statem:call(?MODULE, ?FUNCTION_NAME)
    catch
        % gen_statem:call wraps the exit reason as {Reason, Location}, so a
        % missing server is reported as {noproc, {gen_statem, call, Args}}.
        exit:{noproc, _Location} -> false;
        % Kept for backward compatibility with a bare noproc exit.
        exit:noproc -> false
    end.
%% @doc Waits for the client to be connected to grisp.io.
%%
%% Returns ok once the connected state is entered, or {error, Reason} when the
%% client is idle or gives up connecting. Returns {error, noproc} when the
%% client process is not running. As with any gen_statem:call/3, the calling
%% process exits if no reply arrives within Timeout.
-spec wait_connected(timeout()) -> ok | {error, term()}.
wait_connected(Timeout) ->
    try
        gen_statem:call(?MODULE, ?FUNCTION_NAME, Timeout)
    catch
        % gen_statem:call wraps the exit reason as {Reason, Location}, so a
        % missing server is reported as {noproc, {gen_statem, call, Args}}.
        exit:{noproc, _Location} -> {error, noproc};
        % Kept for backward compatibility with a bare noproc exit.
        exit:noproc -> {error, noproc}
    end.
%% @doc Sends a request to grisp.io and waits for its outcome.
%%
%% The Type is added to Params under the key type before sending. The reply is
%% {ok, Result} on success and {error, Code} on failure; {error, disconnected}
%% is returned when the client is not connected.
request(Method, Type, Params) ->
    Call = {request, Method, Type, Params},
    gen_statem:call(?MODULE, Call).
%% @doc Sends a notification to grisp.io without waiting for any answer.
%%
%% The Type is added to Params under the key type before sending.
%% Notifications sent while the client is not connected are dropped.
notify(Method, Type, Params) ->
    Notification = {notify, Method, Type, Params},
    gen_statem:cast(?MODULE, Notification).
%--- Internal API Functions ----------------------------------------------------
%% @doc Schedules a reboot message for the client process.
%%
%% The message is delivered after one second; when the client handles it, it
%% calls init:stop/0. Returns the reference of the timer.
-spec reboot() -> reference().
reboot() ->
    DelayMs = 1000,
    erlang:send_after(DelayMs, ?MODULE, reboot).
%--- Behaviour gen_statem Callback Functions -----------------------------------
% gen_statem init callback.
% Reads and validates the connection settings from the grisp_connect
% application environment (exiting on a missing or invalid key) and starts in
% waiting_ip when the connect key is true, in idle otherwise.
init([]) ->
    % Exits of the connection process must be received as messages
    process_flag(trap_exit, true),
    ShouldConnect = ?ENV(connect, is_boolean(V)),
    Host = ?ENV(domain, is_binary(V) orelse is_list(V) orelse is_atom(V), as_bin(V)),
    PortNumber = ?ENV(port, is_integer(V) andalso V >= 0 andalso V < 65536),
    Transport = ?ENV(ws_transport, V =:= tls orelse V =:= tcp),
    Path = ?ENV(ws_path, is_binary(V) orelse is_list(V), as_bin(V)),
    RetryLimit = ?ENV(ws_max_retries, is_integer(V) orelse V =:= infinity),
    % The error list lives in a persistent term so that it does not bloat
    % the state data.
    persistent_term:put({?MODULE, self()}, generic_errors()),
    InitialState =
        case ShouldConnect of
            true -> waiting_ip;
            false -> idle
        end,
    InitialData = #data{
        domain = Host,
        port = PortNumber,
        ws_transport = Transport,
        ws_path = Path,
        max_retries = RetryLimit
    },
    {ok, InitialState, InitialData}.
% gen_statem terminate callback.
% Closes the connection, if any, and removes the persistent term that was
% stored for this process by init/1.
terminate(Reason, _StateName, Data) ->
    ErrorsKey = {?MODULE, self()},
    conn_close(Data, Reason),
    persistent_term:erase(ErrorsKey),
    ok.
% gen_statem code_change callback: state and data are kept unchanged.
code_change(_OldVsn, StateName, StateData, _Extra) ->
    {ok, StateName, StateData}.
% gen_statem callback_mode callback: one function per state, and every state
% function is also called with an enter event when the state is entered.
callback_mode() ->
    [state_functions, state_enter].
%--- Behaviour gen_statem State Callback Functions -----------------------------
% State idle: the client is not trying to connect and waits for an explicit
% connect request.
idle(enter, _PreviousState,
     #data{wait_calls = Waiters, last_error = Failure} = Data) ->
    % Every pending wait_connected call is answered with the last error,
    % as no more connection attempts will be made.
    lists:foreach(
        fun(Waiter) -> gen_statem:reply(Waiter, {error, Failure}) end,
        Waiters
    ),
    {keep_state, Data#data{wait_calls = [], last_error = undefined}};
idle({call, Caller}, wait_connected, _Data) ->
    % Nothing to wait for while idle
    {keep_state_and_data, [{reply, Caller, {error, not_connecting}}]};
idle(cast, connect, Data) ->
    {next_state, waiting_ip, Data};
?HANDLE_COMMON.
% @doc State waiting_ip is used to check that the device has an IP address.
% The first time this state is entered, the check is performed right away.
% If the device does not have an IP address, it waits a fixed amount of time
% and checks again, without incrementing the retry counter.
waiting_ip(enter, _OldState, _Data) ->
    % The first IP check is done without any delay
    {keep_state_and_data, [{state_timeout, 0, check_ip}]};
waiting_ip(state_timeout, check_ip, Data) ->
    case grisp_connect_utils:check_inet_ipv4() of
        {ok, IP} ->
            ?LOG_DEBUG(#{description => ?FORMAT("IP Address available: ~s",
                                                [format_ipv4(IP)]),
                         event => checked_ip, ip => format_ipv4(IP)}),
            {next_state, connecting, Data};
        invalid ->
            ?LOG_DEBUG(#{description => <<"Waiting for an IP address to connect to grisp.io">>,
                         event => waiting_ip}),
            % Check again in one second
            {keep_state_and_data, [{state_timeout, 1000, check_ip}]}
    end;
?HANDLE_COMMON.
% @doc State connecting is used to establish a connection to grisp.io.
% The attempt is delayed according to the current retry count, and it is
% abandoned if the connection is not established before the connect timeout.
connecting(enter, _PreviousState, #data{retry_count = Attempts}) ->
    Backoff = grisp_connect_utils:retry_delay(Attempts),
    ?LOG_DEBUG("Scheduling connection attempt in ~w ms", [Backoff]),
    {keep_state_and_data, [{state_timeout, Backoff, connect}]};
connecting(state_timeout, connect, #data{conn = undefined} = Data) ->
    ?LOG_INFO(#{description => <<"Connecting to grisp.io ...">>,
                event => connecting}),
    case conn_start(Data) of
        {error, Cause} ->
            ?LOG_WARNING(#{description => ?FORMAT("Failed to connect to grisp.io: ~p", [Cause]),
                           event => connection_failed, reason => Cause}),
            reconnect(Data, Cause);
        {ok, Started} ->
            % The connection process is running; give it a bounded amount
            % of time to report that it is connected.
            {keep_state, Started, [{state_timeout, ?CONNECT_TIMEOUT, timeout}]}
    end;
connecting(state_timeout, timeout, Data) ->
    Cause = connect_timeout,
    ?LOG_WARNING(#{description => <<"Timeout while connecting to grisp.io">>,
                   event => connection_failed, reason => Cause}),
    Closed = conn_close(Data, Cause),
    reconnect(Closed, Cause);
connecting(info, {jarl, Conn, {connected, _Info}}, #data{conn = Conn} = Data) ->
    % Sent by the current connection process once it is connected
    ?LOG_NOTICE(#{description => <<"Connected to grisp.io">>,
                  event => connected}),
    {next_state, connected, Data#data{retry_count = 0}};
?HANDLE_COMMON.
% State connected: the connection to grisp.io is established; requests and
% notifications are forwarded to the connection process.
connected(enter, _OldState, Data = #data{wait_calls = WaitCalls}) ->
    % When entering connected, we reply to all wait_connected calls with ok
    gen_statem:reply([{reply, F, ok} || F <- WaitCalls]),
    {keep_state, Data#data{wait_calls = [], last_error = undefined}};
connected({call, From}, is_connected, _) ->
    {keep_state_and_data, [{reply, From, true}]};
connected({call, From}, wait_connected, _) ->
    % Already connected: reply right away. Without this clause the call would
    % be queued by the common handler and only answered on the next state
    % entry, leaving the caller blocked until its timeout.
    {keep_state_and_data, [{reply, From, ok}]};
connected(info, {jarl, Conn, Msg}, Data = #data{conn = Conn}) ->
    % Message from the current connection process
    handle_connection_message(Data, Msg);
connected({call, From}, {request, Method, Type, Params}, Data) ->
    % The caller is answered later, from the result or error callback
    Data2 = conn_request(Data, Method, Type, Params,
        fun(D, R) -> gen_statem:reply(From, {ok, R}), D end,
        fun(D, _, C, _, _) -> gen_statem:reply(From, {error, C}), D end),
    {keep_state, Data2};
connected(cast, {notify, Method, Type, Params}, Data) ->
    conn_notify(Data, Method, Type, Params),
    keep_state_and_data;
?HANDLE_COMMON.
% Event handling shared by all the states. It is reached through the
% HANDLE_COMMON macro, used as the last clause of every state function, and
% receives the name of the current state as third argument.
handle_common(cast, connect, StateName, _Data) when StateName =/= idle ->
    % A connection is already in progress or established
    keep_state_and_data;
handle_common({call, Caller}, is_connected, StateName, _Data)
  when StateName =/= connected ->
    {keep_state_and_data, [{reply, Caller, false}]};
handle_common({call, Caller}, wait_connected, _StateName,
              #data{wait_calls = Pending} = Data) ->
    % The caller is parked; it gets its reply when idle or connected
    % is entered.
    {keep_state, Data#data{wait_calls = [Caller | Pending]}};
handle_common({call, Caller}, {request, _Method, _Type, _Params}, StateName, _Data)
  when StateName =/= connected ->
    {keep_state_and_data, [{reply, Caller, {error, disconnected}}]};
handle_common(cast, {notify, _Method, _Type, _Params}, _StateName, _Data) ->
    % Notifications sent while disconnected are dropped
    keep_state_and_data;
handle_common(info, reboot, _StateName, _Data) ->
    init:stop(),
    keep_state_and_data;
handle_common(info, {'EXIT', Conn, ExitReason}, _StateName,
              #data{conn = Conn} = Data) ->
    % The current connection process terminated; strip the shutdown wrapper
    % to keep only the actual reason.
    Cause =
        case ExitReason of
            {shutdown, Inner} -> Inner;
            Other -> Other
        end,
    ?LOG_WARNING(#{description => ?FORMAT("Connection to grisp.io terminated: ~p", [Cause]),
                   event => connection_failed, reason => Cause}),
    reconnect(conn_died(Data), Cause);
handle_common(info, {'EXIT', _OldConn, _ExitReason}, _StateName, _Data) ->
    % Exit of a process that is not the current connection, e.g. a
    % connection that was already closed.
    keep_state_and_data;
handle_common(info, {jarl, Conn, Msg}, StateName, _Data) ->
    ?LOG_DEBUG("Received message from unknown connection ~p in state ~w: ~p",
               [Conn, StateName, Msg]),
    keep_state_and_data;
handle_common(cast, Unexpected, _StateName, _Data) ->
    error({unexpected_cast, Unexpected});
handle_common({call, _Caller}, Unexpected, _StateName, _Data) ->
    error({unexpected_call, Unexpected});
handle_common(info, Unexpected, StateName, _Data) ->
    ?LOG_WARNING(#{description => <<"Unexpected message">>,
                   event => unexpected_info, info => Unexpected, state => StateName}),
    keep_state_and_data.
%--- Internal Functions --------------------------------------------------------
% Returns the list of error definitions given to the connection process.
% Each entry is a tuple made of the error atom, its numerical code and its
% message.
generic_errors() ->
    LinkingErrors = [
        {device_not_linked, -1, <<"Device not linked">>},
        {token_expired, -2, <<"Token expired">>},
        {device_already_linked, -3, <<"Device already linked">>},
        {invalid_token, -4, <<"Invalid token">>}
    ],
    UpdateErrors = [
        {grisp_updater_unavailable, -10, <<"Software update unavailable">>},
        {already_updating, -11, <<"Already updating">>},
        {boot_system_not_validated, -12, <<"Boot system not validated">>},
        {validate_from_unbooted, -13, <<"Validate from unbooted">>}
    ],
    LinkingErrors ++ UpdateErrors.
% Converts a binary, a list or an atom to a binary. Binaries are returned
% unchanged.
as_bin(Value) when is_binary(Value) ->
    Value;
as_bin(Chars) when is_list(Chars) ->
    iolist_to_binary(Chars);
as_bin(Name) when is_atom(Name) ->
    atom_to_binary(Name, utf8).
% Formats a 4-tuple IPv4 address as a dotted binary.
format_ipv4({Oct1, Oct2, Oct3, Oct4}) ->
    Octets = [Oct1, Oct2, Oct3, Oct4],
    ?FORMAT("~w.~w.~w.~w", Octets).
% Handles a message received from the current connection process while
% connected, and returns the gen_statem result.
% Outcomes of our own requests are dispatched to the callbacks stored in the
% request context; any other message is passed to grisp_connect_api and its
% answer, if any, is sent back through the connection.
handle_connection_message(Data, {response, Result, #{on_result := OnResult}}) ->
    case OnResult of
        undefined -> keep_state_and_data;
        _ -> {keep_state, OnResult(Data, Result)}
    end;
handle_connection_message(Data, {error, Code, Msg, ErrorData,
                                 #{on_error := OnError}}) ->
    case OnError of
        undefined ->
            ?LOG_WARNING("Unhandled remote request error ~w: ~s", [Code, Msg]),
            keep_state_and_data;
        _ ->
            {keep_state, OnError(Data, remote, Code, Msg, ErrorData)}
    end;
handle_connection_message(Data, {jarl_error, Reason, #{on_error := OnError}}) ->
    case OnError of
        undefined ->
            ?LOG_WARNING("Unhandled local request error ~w", [Reason]),
            keep_state_and_data;
        _ ->
            {keep_state, OnError(Data, local, Reason, undefined, undefined)}
    end;
handle_connection_message(Data, Inbound) ->
    case grisp_connect_api:handle_msg(Inbound) of
        ok ->
            ok;
        {error, Code, Message, ErData, ReqRef} ->
            conn_error(Data, Code, Message, ErData, ReqRef);
        {reply, Result, ReqRef} ->
            conn_result(Data, Result, ReqRef)
    end,
    keep_state_and_data.
% @doc Sets up the state machine to retry connecting to grisp.io if the
% maximum number of allowed attempts has not been reached.
% Otherwise, the state machine gives up and goes back to idle.
% The given reason is recorded as the last error; when it is undefined the
% previously recorded error is kept.
reconnect(Data = #data{retry_count = RetryCount,
                       max_retries = MaxRetries,
                       last_error = LastError}, Reason)
  when MaxRetries =/= infinity, RetryCount >= MaxRetries ->
    Error = case Reason of undefined -> LastError; E -> E end,
    % Report the error that made us give up, which is also the one the
    % pending wait_connected callers receive when idle is entered.
    ?LOG_ERROR(#{description => <<"Max retries reached, giving up connecting to grisp.io">>,
                 event => max_retries_reached, last_error => Error}),
    {next_state, idle, Data#data{retry_count = 0, last_error = Error}};
reconnect(Data = #data{retry_count = RetryCount, last_error = LastError},
          Reason) ->
    Error = case Reason of undefined -> LastError; E -> E end,
    % When reconnecting we always increment the retry counter, even if we
    % were connected and it was reset to 0; the next step will always be
    % retry number 1. It should never reconnect right away.
    {next_state, waiting_ip,
     Data#data{retry_count = RetryCount + 1, last_error = Error}}.
% Connection Functions
% Starts a connection process linked to the state machine.
% Must only be called when there is no current connection. Returns the data
% updated with the pid of the new connection, or the error from jarl.
conn_start(#data{conn = undefined,
                 domain = Host,
                 port = PortNumber,
                 ws_path = Path,
                 ws_transport = TransportKind} = Data) ->
    PingTimeout = ?ENV(ws_ping_timeout, V =:= infinity orelse is_integer(V)),
    RequestTimeout = ?ENV(ws_request_timeout, V =:= infinity orelse is_integer(V)),
    Transport =
        case TransportKind of
            tcp -> tcp;
            tls -> {tls, grisp_cryptoauth_tls:options(Host)}
        end,
    % The error list was stored by init/1
    KnownErrors = persistent_term:get({?MODULE, self()}),
    Options = #{
        domain => Host,
        port => PortNumber,
        transport => Transport,
        path => Path,
        errors => KnownErrors,
        ping_timeout => PingTimeout,
        request_timeout => RequestTimeout,
        protocols => [?GRISP_IO_PROTOCOL]
    },
    case jarl:start_link(self(), Options) of
        {ok, Pid} -> {ok, Data#data{conn = Pid}};
        {error, _Reason} = Failure -> Failure
    end.
% Disconnects the current connection, if any, and forgets it.
% Safe to call in any state.
conn_close(#data{conn = Current} = Data, _Reason) ->
    case Current of
        undefined ->
            Data;
        _ ->
            jarl:disconnect(Current),
            Data#data{conn = undefined}
    end.
% Forgets the current connection after its process terminated; nothing is
% sent to the connection process. Safe to call in any state.
conn_died(StateData) ->
    StateData#data{conn = undefined}.
% Sends a request through the current connection.
% The callbacks are stored in the request context, so that the outcome can be
% dispatched when the corresponding message comes back from the connection.
% If the request cannot be sent, the error callback is called right away with
% a local error; when no error callback was given the error is only logged.
-spec conn_request(data(), jarl:method(), atom(), map(),
                   undefined | on_result_fun(), undefined | on_error_fun())
    -> data().
conn_request(Data = #data{conn = Conn}, Method, Type, Params, OnResult, OnError)
  when Conn =/= undefined ->
    ReqCtx = #{on_result => OnResult, on_error => OnError},
    Params2 = maps:put(type, Type, Params),
    case jarl:request(Conn, Method, Params2, ReqCtx) of
        ok ->
            Data;
        {jarl_error, Reason} when OnError =:= undefined ->
            % The spec allows no error callback; calling undefined would crash
            ?LOG_WARNING("Unhandled local request error ~w", [Reason]),
            Data;
        {jarl_error, Reason} ->
            OnError(Data, local, Reason, undefined, undefined)
    end.
% Sends a notification through the current connection, after adding the
% given type to the parameters under the key type.
conn_notify(#data{conn = Conn}, Method, Type, Params) when Conn =/= undefined ->
    jarl:notify(Conn, Method, Params#{type => Type}).
% Sends the result of a request received from grisp.io back through the
% current connection.
conn_result(#data{conn = Conn}, Outcome, RequestRef) when Conn =/= undefined ->
    jarl:reply(Conn, Outcome, RequestRef).
% Sends an error reply for a request received from grisp.io back through the
% current connection.
% Error data that is neither a binary nor undefined is first formatted to a
% binary with ~p.
conn_error(#data{conn = Conn}, Code, Message, ErData, ReqRef)
  when Conn =/= undefined, is_binary(ErData) orelse ErData =:= undefined ->
    jarl:reply(Conn, Code, Message, ErData, ReqRef);
conn_error(Data = #data{conn = Conn}, Code, Message, ErData, ReqRef)
  when Conn =/= undefined ->
    % The connection guard is required here too: without it, a call made with
    % no connection would never match the first clause and would recurse
    % forever, formatting an ever-growing binary.
    BinErData = iolist_to_binary(io_lib:format("~p", [ErData])),
    conn_error(Data, Code, Message, BinErData, ReqRef).