%% Copyright (c) 2017-2024 Guilherme Andrade
%%
%% Permission is hereby granted, free of charge, to any person obtaining a
%% copy of this software and associated documentation files (the "Software"),
%% to deal in the Software without restriction, including without limitation
%% the rights to use, copy, modify, merge, publish, distribute, sublicense,
%% and/or sell copies of the Software, and to permit persons to whom the
%% Software is furnished to do so, subject to the following conditions:
%%
%% The above copyright notice and this permission notice shall be included in
%% all copies or substantial portions of the Software.
%%
%% THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
%% IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
%% FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
%% AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
%% LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
%% FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
%% DEALINGS IN THE SOFTWARE.
%%
%% locus is an independent project and has not been authorized, sponsored,
%% or otherwise approved by MaxMind.
%% @doc Launches `locus_database_loader' and later shares the unpacked database,
%% as well as any updates to it, through `persistent_term'.
-module(locus_database).
-behaviour(gen_server).
-ifdef(TEST).
-include_lib("eunit/include/eunit.hrl").
-endif.
%% ------------------------------------------------------------------
%% API Function Exports
%% ------------------------------------------------------------------
-export(
[version/1
]).
%% ------------------------------------------------------------------
%% "Private" API Function Exports
%% ------------------------------------------------------------------
-export(
[start/3,
stop/2,
start_link/3,
dynamic_child_spec/1,
static_child_spec/4,
async_get_version_or_subscribe/1,
find/1
]).
-ignore_xref(
[start_link/3
]).
-ifdef(TEST).
-export([whereis/1]).
-export([list_subscribers/1]).
-endif.
%% ------------------------------------------------------------------
%% gen_server Function Exports
%% ------------------------------------------------------------------
-export(
[init/1,
handle_call/3,
handle_cast/2,
handle_info/2,
terminate/2,
code_change/3
]).
%% ------------------------------------------------------------------
%% Macro Definitions
%% ------------------------------------------------------------------
-define(HIBERNATE_AFTER, (timer:seconds(5))).
-define(SHARED_STATE_KEY(DatabaseId), {'locus_database.shared_state', DatabaseId}).
%% ------------------------------------------------------------------
%% API Record and Type Definitions
%% ------------------------------------------------------------------
-type opt() ::
database_opt() |
locus_loader:opt().
-export_type([opt/0]).
-type database_opt() :: {event_subscriber, atom() | pid()}.
-export_type([database_opt/0]).
-type static_child_spec() ::
#{ id := term(),
start := {?MODULE, start_link, [atom() | origin() | [opt()], ...]},
restart := permanent,
shutdown := non_neg_integer(),
type := worker,
modules := [?MODULE, ...]
}.
-export_type([static_child_spec/0]).
-type origin() :: locus_loader:origin().
-export_type([origin/0]).
-type event() ::
locus_loader:event() |
event_load_attempt_finished().
-export_type([event/0]).
-type event_load_attempt_finished() ::
{load_attempt_finished, locus_loader:source(), {ok, Version :: calendar:datetime()}} |
{load_attempt_finished, locus_loader:source(), {error, term()}}.
-export_type([event_load_attempt_finished/0]).
%% ------------------------------------------------------------------
%% Internal Record and Type Definitions
%% ------------------------------------------------------------------
-record(state, {
id :: atom(),
loader_pid :: pid(),
subscribers :: [atom() | pid()],
subscriber_mons :: #{monitor() => pid()}
}).
-type state() :: #state{}.
-type monitor() :: reference().
-record(shared_state, {
database :: locus_mmdb:database(),
source :: locus_loader:source(),
version :: calendar:datetime()
}).
%% ------------------------------------------------------------------
%% API Function Definitions
%% ------------------------------------------------------------------
%% @doc Returns the database version based on its build epoch (UNIX timestamp)
-spec version(BuildEpoch) -> Version
when BuildEpoch :: non_neg_integer(),
Version :: calendar:datetime().
version(BuildEpoch) ->
GregorianEpoch = calendar:datetime_to_gregorian_seconds({{1970, 1, 1}, {0, 0, 0}}),
calendar:gregorian_seconds_to_datetime(GregorianEpoch + BuildEpoch).
%% ------------------------------------------------------------------
%% "Private" API Function Definitions
%% ------------------------------------------------------------------
-spec start(atom(), origin(), [opt()])
-> ok |
{error, already_started} |
{error, {invalid_opt, term()}} |
{error, application_not_running}.
%% @private
start(Id, Origin, Opts) ->
case locus_database_sup:start_child([Id, Origin, Opts]) of
{ok, _Pid} ->
ok;
{error, {already_started, _Pid}} ->
{error, already_started};
{error, Reason} ->
{error, Reason}
end.
-spec stop(atom(), term()) -> ok | {error, not_found}.
%% @private
stop(Id, Reason) ->
ServerName = server_name(Id),
try gen:stop(ServerName, Reason, 5000) of
ok -> ok
catch
exit:noproc -> {error, not_found};
exit:normal -> ok;
exit:shutdown -> ok;
exit:{shutdown, _} -> ok
end.
-spec start_link(atom(), origin(), [opt()]) -> {ok, pid()}.
%% @private
start_link(Id, Origin, Opts) ->
ServerName = server_name(Id),
ServerOpts = [{hibernate_after, ?HIBERNATE_AFTER}],
gen_server:start_link({local, ServerName}, ?MODULE, [Id, Origin, Opts], ServerOpts).
-spec dynamic_child_spec(term()) -> supervisor:child_spec().
%% @private
dynamic_child_spec(ChildId) ->
#{ id => ChildId,
start => {?MODULE, start_link, []},
restart => transient,
shutdown => timer:seconds(5),
type => worker,
modules => [?MODULE]
}.
-spec static_child_spec(term(), atom(), origin(), [opt()]) -> static_child_spec().
%% @private
static_child_spec(ChildId, DatabaseId, Origin, Opts) ->
#{ id => ChildId,
start => {?MODULE, start_link, [DatabaseId, Origin, Opts]},
restart => permanent,
shutdown => timer:seconds(5),
type => worker,
modules => [?MODULE]
}.
-spec async_get_version_or_subscribe(atom()) -> {await, reference()}.
%% @private
async_get_version_or_subscribe(Id) ->
ServerName = server_name(Id),
ServerPid = erlang:whereis(ServerName),
Ref = monitor(process, ServerPid),
gen_server:cast(ServerPid, {get_version_or_subscribe, {self(), Ref}}),
{await, Ref}.
-spec find(atom()) -> {ok, locus_mmdb:database(), locus_loader:source(), calendar:datetime()}
| {error, database_not_loaded}
| {error, database_unknown}.
%% @private
find(Id) ->
case find_shared_state(Id) of
{ok, #shared_state{database = Database, source = Source, version = Version}} ->
{ok, Database, Source, Version};
error ->
ServerName = server_name(Id),
DoesLoaderExist = (erlang:whereis(ServerName) =/= undefined),
ErrorReason = maps:get(DoesLoaderExist, #{true => database_not_loaded,
false => database_unknown}),
{error, ErrorReason}
end.
-ifdef(TEST).
whereis(Id) ->
ServerName = server_name(Id),
erlang:whereis(ServerName).
list_subscribers(Id) ->
ServerName = server_name(Id),
State = sys:get_state(ServerName),
State#state.subscribers.
-endif.
%% ------------------------------------------------------------------
%% gen_server Function Definitions
%% ------------------------------------------------------------------
-spec init([atom() | origin() | [opt()], ...])
-> {ok, state()} |
{stop, {invalid_opt, term()}}.
%% @private
init([Id, Origin, Opts]) ->
_ = process_flag(trap_exit, true),
case validate_opts(Origin, Opts) of
{ok, {DatabaseOpts, LoaderOpts, FetcherOpts}} ->
init(Id, Origin, DatabaseOpts, LoaderOpts, FetcherOpts);
{error, BadOpt} ->
{stop, {invalid_opt, BadOpt}}
end.
-spec handle_call(term(), {pid(), reference()}, state())
-> {stop, unexpected_call, state()}.
%% @private
handle_call(_Call, _From, State) ->
{stop, unexpected_call, State}.
-spec handle_cast(term(), state())
-> {noreply, state()} |
{stop, unexpected_cast, state()}.
%% @private
handle_cast({get_version_or_subscribe, {Pid, Ref}}, State) ->
case find_shared_state(State#state.id) of
error ->
Mon = monitor(process, Pid),
UpdatedSubscribers = [Pid | State#state.subscribers],
UpdatedSubscriberMons = maps:put(Mon, Pid, State#state.subscriber_mons),
UpdatedState = State#state{ subscribers = UpdatedSubscribers,
subscriber_mons = UpdatedSubscriberMons },
{noreply, UpdatedState};
{ok, #shared_state{version = LastVersion}} ->
_ = Pid ! {Ref, {version, LastVersion}},
{noreply, State}
end;
handle_cast(_Cast, State) ->
{stop, unexpected_cast, State}.
-spec handle_info(term(), state())
-> {noreply, state()} |
{stop, term(), state()}.
%% @private
handle_info({LoaderPid, Msg}, State)
when LoaderPid =:= State#state.loader_pid ->
handle_loader_msg(Msg, State);
handle_info({'DOWN', Ref, process, _, _}, State) ->
handle_monitored_process_death(Ref, State);
handle_info({'EXIT', Pid, Reason}, State) ->
handle_linked_process_death(Pid, Reason, State);
handle_info(_Info, State) ->
{stop, unexpected_info, State}.
-spec terminate(term(), state()) -> ok.
%% @private
terminate(Reason, State) ->
_ = (locus_util:is_termination_reason_harmless(Reason)
% Avoid erasing shared state if we're in a crash cycle
% (as frequent restarts could put quite a strain on
% the GC and/or memory consumption.)
andalso erase_shared_state(State#state.id)),
ok.
-spec code_change(term(), state(), term()) -> {ok, state()}.
%% @private
code_change(_OldVsn, #state{} = State, _Extra) ->
{ok, State}.
%% ------------------------------------------------------------------
%% Internal Function Definitions - Initialization and Event Handling
%% ------------------------------------------------------------------
-spec server_name(atom()) -> atom().
server_name(Id) ->
list_to_atom(
atom_to_list(?MODULE)
++ "."
++ atom_to_list(Id)
).
-spec validate_opts(origin(), list())
-> {ok, {[database_opt()], [locus_loader:loader_opt()], [locus_loader:fetcher_opt()]}} |
{error, term()}.
validate_opts(Origin, Opts) ->
case locus_loader:validate_opts(Origin, Opts) of
{ok, {LoaderOpts, FetcherOpts, DatabaseOpts}} ->
validate_database_opts(DatabaseOpts, LoaderOpts, FetcherOpts);
{error, BadOpt} ->
{error, BadOpt}
end.
-spec validate_database_opts(list(), [locus_loader:loader_opt()], [locus_loader:fetcher_opt()])
-> {ok, {[database_opt()], [locus_loader:loader_opt()], [locus_loader:fetcher_opt()]}} |
{error, term()}.
validate_database_opts(DatabaseOpts, LoaderOpts, FetcherOpts) ->
case
locus_util:lists_anymap(
fun ({event_subscriber, Module}) when is_atom(Module) ->
Module =:= undefined;
({event_subscriber, Pid}) ->
not is_pid(Pid);
(_) ->
true
end,
DatabaseOpts)
of
{true, BadOpt} ->
{error, BadOpt};
false ->
{ok, {DatabaseOpts, LoaderOpts, FetcherOpts}}
end.
-spec init(atom(), origin(), [database_opt()], [locus_loader:loader_opt()],
[locus_loader:fetcher_opt()]) -> {ok, state()}.
init(Id, Origin, DatabaseOpts, LoaderOpts, FetcherOpts) ->
{ok, LoaderPid} = locus_loader:start_link(Id, Origin, LoaderOpts, FetcherOpts),
BaseState =
#state{
id = Id,
loader_pid = LoaderPid,
subscribers = [],
subscriber_mons = #{}
},
init_opts(DatabaseOpts, BaseState).
-spec init_opts([database_opt()], state()) -> {ok, state()}.
init_opts([{event_subscriber, Module} | Opts], State)
when is_atom(Module) ->
#state{subscribers = Subscribers} = State,
UpdatedSubscribers = [Module | Subscribers],
UpdatedState = State#state{ subscribers = UpdatedSubscribers },
init_opts(Opts, UpdatedState);
init_opts([{event_subscriber, Pid} | Opts], State) ->
#state{subscribers = Subscribers, subscriber_mons = SubscriberMons} = State,
Mon = monitor(process, Pid),
UpdatedSubscribers = [Pid | Subscribers],
UpdatedSubscriberMons = SubscriberMons#{Mon => Pid},
UpdatedState = State#state{ subscribers = UpdatedSubscribers,
subscriber_mons = UpdatedSubscriberMons },
init_opts(Opts, UpdatedState);
init_opts([], State) ->
_ = process_flag(trap_exit, true), % ensure `:terminate/2' is called (unless killed)
{ok, State}.
-spec handle_loader_msg(locus_loader:msg(), state()) -> {noreply, state()}.
handle_loader_msg({event, Event}, State) ->
report_event(Event, State),
{noreply, State};
handle_loader_msg({load_success, Source, Version, Database}, State) ->
update_shared_state(State#state.id, Database, Source, Version),
report_event({load_attempt_finished, Source, {ok, Version}}, State),
{noreply, State};
handle_loader_msg({load_failure, Source, Reason}, State) ->
report_event({load_attempt_finished, Source, {error, Reason}}, State),
{noreply, State}.
-spec report_event(event(), state()) -> ok.
report_event(Event, #state{id = Id, subscribers = Subscribers}) ->
lists:foreach(
fun (Module) when is_atom(Module) ->
locus_event_subscriber:report(Module, Id, Event);
(Pid) ->
erlang:send(Pid, {locus, Id, Event}, [noconnect])
end,
Subscribers).
-spec handle_monitored_process_death(monitor(), state()) -> {noreply, state()}.
handle_monitored_process_death(Ref, State) ->
#state{subscribers = Subscribers, subscriber_mons = SubscriberMons} = State,
{Pid, UpdatedSubscriberMons} = maps:take(Ref, SubscriberMons),
{ok, UpdatedSubscribers} = locus_util:lists_take(Pid, Subscribers),
UpdatedState = State#state{ subscribers = UpdatedSubscribers,
subscriber_mons = UpdatedSubscriberMons },
{noreply, UpdatedState}.
-spec handle_linked_process_death(pid(), term(), state())
-> {stop, {loader_stopped, pid(), term()}, state()}.
handle_linked_process_death(Pid, Reason, State)
when Pid =:= State#state.loader_pid ->
{stop, {loader_stopped, Pid, Reason}, State}.
%% ------------------------------------------------------------------
%% Internal Function Definitions - Shared State Management
%% ------------------------------------------------------------------
erase_shared_state(DatabaseId) ->
Key = ?SHARED_STATE_KEY(DatabaseId),
persistent_term:erase(Key).
find_shared_state(DatabaseId) ->
Key = ?SHARED_STATE_KEY(DatabaseId),
try persistent_term:get(Key) of
SharedState ->
{ok, SharedState}
catch
error:badarg ->
error
end.
update_shared_state(DatabaseId, Database, Source, Version) ->
Key = ?SHARED_STATE_KEY(DatabaseId),
SharedState = #shared_state{database = Database, source = Source, version = Version},
persistent_term:put(Key, SharedState).
%% ------------------------------------------------------------------
%% Unit Tests
%% ------------------------------------------------------------------
-ifdef(TEST).
large_binaries_in_gcollected_persistent_term_are_copied_by_ref_test() ->
%
% Confirm that the VM behaves as we expect it to - otherwise
% using `persistent_term' to publish `#shared_state{}' would be
% a giant foot gun.
%
Key = {?MODULE, ?FUNCTION_NAME, make_ref()},
Randomizer = rand:uniform(1024),
BinarySize = 1024 * 1024, % large enough to be ref counted
Binary = <<0:BinarySize/integer-unit:8>>,
SubBinarySize = 1024 + Randomizer,
<<SubBinary:SubBinarySize/bytes, _/bytes>> = Binary,
persistent_term:put(Key, SubBinary),
TestPid = self(),
HelperPid
= spawn_link(
fun () ->
RetrievedSubBinary = persistent_term:get(Key),
TestPid ! ready_to_check,
receive
check_it ->
erlang:garbage_collect(), % just to be sure
?assertEqual(BinarySize,
binary:referenced_byte_size(RetrievedSubBinary))
end
end),
receive ready_to_check -> ok end,
RetrievedSubBinary = persistent_term:get(Key),
persistent_term:put(Key, updated_value),
HelperPid ! check_it,
erlang:garbage_collect(), % just to be sure
?assertEqual(BinarySize, binary:referenced_byte_size(RetrievedSubBinary)).
shared_state_is_cleared_when_gracefully_stopped_test() ->
{ok, _} = application:ensure_all_started(locus),
PrevTrapExit = process_flag(trap_exit, true),
try
Id = ?FUNCTION_NAME,
launch_database_loader_and_await_it(Id),
?assertMatch({ok, #shared_state{}}, find_shared_state(Id)),
ok = stop(Id, _Reason = normal),
?assertMatch(error, find_shared_state(Id))
after
true = process_flag(trap_exit, PrevTrapExit)
end.
shared_state_is_kept_when_crashed_test() ->
{ok, _} = application:ensure_all_started(locus),
PrevTrapExit = process_flag(trap_exit, true),
try
Id = ?FUNCTION_NAME,
launch_database_loader_and_await_it(Id),
?assertMatch({ok, #shared_state{}}, find_shared_state(Id)),
ok = stop(Id, _Reason = {shutdown, simulated_crash}),
?assertMatch({ok, #shared_state{}}, find_shared_state(Id))
after
true = process_flag(trap_exit, PrevTrapExit)
end.
launch_database_loader_and_await_it(Id) ->
Origin = {filesystem, "test/priv/GeoLite2-Country.mmdb"},
{ok, _Pid} = start_link(Id, Origin, _Opts = []),
{await, Ref} = async_get_version_or_subscribe(Id),
{ok, _} = wait_for_database_to_load(Ref, Id),
ok.
wait_for_database_to_load(Ref, Id) ->
receive
{Ref, {version, _Version}} ->
ok;
{locus, Id, {load_attempt_finished, _, Result}} ->
{ok, Result};
{locus, Id, _Event} ->
% ct:pal("~p", [Event]),
wait_for_database_to_load(Ref, Id);
{'DOWN', Ref, process, _, Reason} ->
exit(Reason)
end.
-endif. % -ifdef(TEST).