%% src/locus_http_download.erl

%% Copyright (c) 2017-2022 Guilherme Andrade
%%
%% Permission is hereby granted, free of charge, to any person obtaining a
%% copy  of this software and associated documentation files (the "Software"),
%% to deal in the Software without restriction, including without limitation
%% the rights to use, copy, modify, merge, publish, distribute, sublicense,
%% and/or sell copies of the Software, and to permit persons to whom the
%% Software is furnished to do so, subject to the following conditions:
%%
%% The above copyright notice and this permission notice shall be included in
%% all copies or substantial portions of the Software.
%%
%% THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
%% IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
%% FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
%% AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
%% LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
%% FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
%% DEALINGS IN THE SOFTWARE.
%%
%% locus is an independent project and has not been authorized, sponsored,
%% or otherwise approved by MaxMind.

%% @doc Downloads a file using HTTP(S) without blocking the caller
-module(locus_http_download).
-behaviour(gen_server).

%% ------------------------------------------------------------------
%% API Function Exports
%% ------------------------------------------------------------------

-export(
   [validate_opts/1,
    start_link/3
   ]).

-ignore_xref(
   [start_link/3
   ]).

%% ------------------------------------------------------------------
%% gen_server Function Exports
%% ------------------------------------------------------------------

-export(
   [init/1,
    handle_call/3,
    handle_cast/2,
    handle_info/2,
    terminate/2,
    code_change/3
   ]).

%% ------------------------------------------------------------------
%% Macro Definitions
%% ------------------------------------------------------------------

%% Default timeouts, in milliseconds (`timer:seconds/1' converts from seconds).
-define(DEFAULT_CONNECT_TIMEOUT, (timer:seconds(8))).
-define(DEFAULT_DOWNLOAD_START_TIMEOUT, (timer:seconds(5))).
-define(DEFAULT_IDLE_DOWNLOAD_TIMEOUT, (timer:seconds(5))).
%% Upper bound on the number of redirections followed for a single download.
-define(MAX_REDIRECTIONS, 5).

%% True for a non-negative integer or for the atom `infinity'.
-define(is_timeout(V), ((is_integer((V)) andalso ((V) >= 0)) orelse ((V) =:= infinity))).
%% True for a list made up exclusively of atoms; false for any non-list term.
%% The `is_list/1' check comes first because this macro is expanded inside
%% function bodies (not guards), where handing a non-list to a list BIF raises
%% an exception instead of simply evaluating to `false'.
-define(is_list_of_censored_query_keys(V), (is_list((V))
                                            andalso lists:all(fun is_atom/1, (V)))).

%% ------------------------------------------------------------------
%% Record and Type Definitions
%% ------------------------------------------------------------------

% Options understood by this module:
% - `connect_timeout': handed to `httpc' as its connect timeout;
% - `download_start_timeout': how long to wait, after a request is sent,
%   for the response to start streaming;
% - `idle_download_timeout': how long to wait between consecutive body parts
%   (the timer is re-armed whenever a body part arrives);
% - `insecure': when true, no TLS options are passed along with the request;
% - `censor_query': query keys to be censored in the URLs reported in events.
-type opt() ::
    {connect_timeout, timeout()} |
    {download_start_timeout, timeout()} |
    {idle_download_timeout, timeout()} |
    insecure |
    {insecure, boolean()} |
    {censor_query, CensoredKeys :: [atom()]}.
-export_type([opt/0]).

% Payload of the messages sent to the owner process,
% which receives them wrapped as `{DownloaderPid, msg()}'.
-type msg() ::
    {event, event()} |
    {finished, {success, success()}} |
    {finished, dismissed} |
    {finished, {error, term()}}.
-export_type([msg/0]).

% Progress notifications reported while a download is being attempted.
-type event() ::
    event_request_sent() |
    event_download_dismissed() |
    event_download_redirected() |
    event_download_failed_to_start() |
    event_download_started() |
    event_download_finished().
-export_type([event/0]).

% A request was handed to `httpc'; the URL is the (possibly) censored one.
-type event_request_sent() ::
    {request_sent, url(), headers()}.
-export_type([event_request_sent/0]).

% The server answered with status 304.
-type event_download_dismissed() ::
    {download_dismissed, full_http_response()}.
-export_type([event_download_dismissed/0]).

% The server answered with a redirection that is about to be followed;
% the URL is the (possibly) censored one.
-type event_download_redirected() ::
    {download_redirected, redirection()}.
-export_type([event_download_redirected/0]).

% The response never got to the point of being streamed.
-type event_download_failed_to_start() ::
    {download_failed_to_start, reason_for_download_failing_to_start()}.
-export_type([event_download_failed_to_start/0]).

% The response started streaming; carries the response headers.
-type event_download_started() ::
    {download_started, headers()}.
-export_type([event_download_started/0]).

% Streaming ended, either successfully or not; `BodySize' is the number of
% body bytes received up to that point.
-type event_download_finished() ::
    {download_finished, BodySize :: non_neg_integer(), {ok, TrailingHeaders :: headers()}} |
    {download_finished, BodySize :: non_neg_integer(), {error, term()}} |
    {download_finished, BodySize :: non_neg_integer(), {error, timeout}}.
-export_type([event_download_finished/0]).

-type reason_for_download_failing_to_start() ::
    full_http_response() |
    too_many_redirections |
    {invalid_redirection, term()} |
    {error, term()} |
    timeout.
-export_type([reason_for_download_failing_to_start/0]).

% A response that was received in one piece rather than streamed.
-type full_http_response() ::
    {http, response_status(), headers(), body()}.
-export_type([full_http_response/0]).

% Outcome of a successful download: response headers merged with any trailing
% headers, and the complete body.
-type success() ::
    #{ headers := headers(),
       body := binary()
     }.
-export_type([success/0]).

-type url() :: string().
-export_type([url/0]).

% NOTE(review): the second element is whatever reason phrase `httpc' hands
% over; its documentation describes that as a string(), so `binary()' may be
% inaccurate here - TODO confirm.
-type response_status() :: {100..999, binary()}.
-export_type([response_status/0]).

% Header names are lowercased by this module before being stored or reported,
% which allows for case insensitive lookups.
-type headers() :: [{string(), string()}].
-export_type([headers/0]).

-type body() :: binary().
-export_type([body/0]).

% Target of a redirection and whether the status code behind it was of
% the permanent (301, 308) or temporary (302, 303, 307) kind.
-type redirection() ::
    #{ url := url(),
       permanence := permanent | temporary
     }.
-export_type([redirection/0]).

% Server state for a single download attempt (redirections included).
-record(state, {
          % process that receives events and the final outcome
          owner_pid :: pid(),
          % URL currently being requested (replaced upon redirection)
          url :: url(),
          % same URL after censoring; this is the version reported in events
          censored_url :: url(),
          % request headers, with lowercased names
          headers :: headers(),
          opts :: [opt()],
          % armed timeouts, keyed by option name; the value is `infinity'
          % when the configured interval means no timer was started
          timeouts :: #{ term() => infinity | reference() },
          % how many redirections were followed so far
          redirections :: non_neg_integer(),
          % `httpc' request identifier; `undefined' while no request is ongoing
          request_id :: reference() | undefined,
          % set once the response starts streaming
          response_headers :: headers() | undefined,
          % body parts accumulated so far; set once the response starts streaming
          response_body :: iodata() | undefined
         }).
-type state() :: #state{}.

%% ------------------------------------------------------------------
%% API Function Definitions
%% ------------------------------------------------------------------

% Splits a mixed option list into the options understood by this module and
% all the remaining ones, checking the values of the former along the way.
% Returns `{ok, {OwnOpts, OtherOpts}}' on success, or `{error, BadOpt}' when
% an option of ours carries an unacceptable value.
-spec validate_opts(proplists:proplist())
        -> {ok, {[opt()], proplists:proplist()}} |
           {error, BadOpt :: term()}.
%% @private
validate_opts(MixedOpts) ->
    try lists:partition(fun is_own_opt/1, MixedOpts) of
        {_OwnOpts, _ForeignOpts} = Partitioned ->
            {ok, Partitioned}
    catch
        error:{badopt, Rejected} ->
            {error, Rejected}
    end.

% Partitioning predicate: `true' for a well formed option of ours, `false' for
% options we know nothing about. An option of ours with an unacceptable value
% raises `error:{badopt, Opt}'.
is_own_opt({connect_timeout, Timeout} = Opt) ->
    accept_or_reject(?is_timeout(Timeout), Opt);
is_own_opt({download_start_timeout, Timeout} = Opt) ->
    accept_or_reject(?is_timeout(Timeout), Opt);
is_own_opt({idle_download_timeout, Timeout} = Opt) ->
    accept_or_reject(?is_timeout(Timeout), Opt);
is_own_opt(insecure) ->
    true;
is_own_opt({insecure, Flag} = Opt) ->
    accept_or_reject(is_boolean(Flag), Opt);
is_own_opt({censor_query, Keys} = Opt) ->
    accept_or_reject(?is_list_of_censored_query_keys(Keys), Opt);
is_own_opt(_) ->
    false.

% Turns the outcome of a value check into either acceptance or a `badopt' error.
accept_or_reject(true, _Opt) ->
    true;
accept_or_reject(false, Opt) ->
    error({badopt, Opt}).

% Starts a downloader linked to the calling process,
% which also becomes the owner that events and the outcome are sent to.
-spec start_link(url(), headers(), [opt()]) -> {ok, pid()}.
%% @private
start_link(URL, Headers, Opts) ->
    OwnerPid = self(),
    InitArgs = [OwnerPid, URL, Headers, Opts],
    gen_server:start_link(?MODULE, InitArgs, []).

%% ------------------------------------------------------------------
%% gen_server Function Definitions
%% ------------------------------------------------------------------

% Prepares the initial state and queues a `send_request' message to itself, so
% that the request is only issued after `init/1' has returned.
-spec init([InitArg, ...]) -> {ok, state()}
        when InitArg :: OwnerPid | URL | Headers | Opts,
             OwnerPid :: pid(),
             URL :: url(),
             Headers :: headers(),
             Opts :: [opt()].
%% @private
init([OwnerPid, URL, Headers, Opts]) ->
    % Exits of linked processes are to arrive as messages.
    _ = process_flag(trap_exit, true),
    _ = erlang:send(self(), send_request),
    InitialState =
        #state{ owner_pid = OwnerPid,
                url = URL,
                censored_url = maybe_censor_url(URL, Opts),
                headers = lists:keymap(fun string:to_lower/1, 1, Headers),
                opts = Opts,
                timeouts = #{},
                redirections = 0 },
    {ok, InitialState}.

% No calls are part of this server's protocol: any call stops it.
-spec handle_call(term(), {pid(), reference()}, state())
        -> {stop, unexpected_call, state()}.
%% @private
handle_call(_Request, _Caller, State) ->
    StopReason = unexpected_call,
    {stop, StopReason, State}.

% No casts are part of this server's protocol: any cast stops it.
-spec handle_cast(term(), state())
        -> {stop, unexpected_cast, state()}.
%% @private
handle_cast(_Request, State) ->
    StopReason = unexpected_cast,
    {stop, StopReason, State}.

% Dispatches the messages this server expects: its own `send_request' trigger,
% `httpc' messages for the ongoing request, timeouts it armed itself and exits
% of linked processes. Anything else stops the server.
-spec handle_info(term(), state())
        -> {noreply, state()} |
           {stop, normal, state()} |
           {stop, unexpected_info, state()}.
%% @private
handle_info(send_request, State) ->
    {noreply, send_request(State)};
handle_info({http, HttpcMsg}, #state{request_id = RequestId} = State)
  when element(1, HttpcMsg) =:= RequestId ->
    handle_httpc_message(HttpcMsg, State);
handle_info({timeout, OptName}, #state{timeouts = Timeouts} = State) ->
    % The timeout has to be one we are still keeping track of.
    #{OptName := _} = Timeouts,
    RemainingTimeouts = maps:without([OptName], Timeouts),
    handle_timeout(OptName, State#state{ timeouts = RemainingTimeouts });
handle_info({'EXIT', LinkedPid, _Reason}, State) ->
    handle_linked_process_death(LinkedPid, State);
handle_info(_Unexpected, State) ->
    {stop, unexpected_info, State}.

% Nothing to clean up upon termination.
-spec terminate(term(), state()) -> ok.
%% @private
terminate(_StopReason, _FinalState) ->
    ok.

% The state is carried over unchanged across code upgrades.
-spec code_change(term(), state(), term()) -> {ok, state()}.
%% @private
code_change(_PreviousVsn, State = #state{}, _Extra) ->
    {ok, State}.

%% ------------------------------------------------------------------
%% Internal Function Definitions
%% ------------------------------------------------------------------

% Returns `URL' as is unless the `censor_query' option lists one or more keys,
% in which case the censoring is delegated to `locus_util:censor_url_query/2'.
-spec maybe_censor_url(url(), [opt()]) -> url().
maybe_censor_url(URL, Opts) ->
    case proplists:get_value(censor_query, Opts, []) of
        [] ->
            URL;
        [_|_] = KeysToCensor ->
            KeysAsStrings = lists:map(fun erlang:atom_to_list/1, KeysToCensor),
            locus_util:censor_url_query(URL, KeysAsStrings)
    end.

% Applies `maybe_censor_url/2' to the URL held by a redirection.
maybe_censor_redirection(#{url := TargetURL} = Redirection, Opts) ->
    maps:update(url, maybe_censor_url(TargetURL, Opts), Redirection).

% Issues an asynchronous, streamed GET request for the current URL, reports it
% to the owner and arms the download start timeout.
% May only be called while no request is ongoing.
-spec send_request(state()) -> state().
send_request(#state{request_id = undefined} = State) ->
    #state{url = URL, censored_url = CensoredURL, headers = Headers, opts = Opts} = State,
    ConnectTimeout = proplists:get_value(connect_timeout, Opts, ?DEFAULT_CONNECT_TIMEOUT),
    TLSOpts =
        case proplists:get_value(insecure, Opts, false) of
            true ->
                [];
            false ->
                [{ssl, tls_certificate_check:options(URL)}]
        end,
    % Redirections are not left to `httpc': the TLS validation set up in
    % `TLSOpts' can only account for the hostname of the current URL,
    % which is an issue for HTTPS downloads.
    HTTPOpts = [{connect_timeout, ConnectTimeout} | TLSOpts] ++ [{autoredirect, false}],

    {ok, RequestId} = httpc:request(get, {URL, Headers}, HTTPOpts,
                                    [{sync, false}, {stream, self}]),
    true = is_reference(RequestId),

    report_event({request_sent, CensoredURL, Headers}, State),
    schedule_download_start_timeout(State#state{ request_id = RequestId }).

% Handles a message from `httpc' concerning the ongoing request.
%
% First clause: nothing was streamed yet, therefore we expect either the start
% of the stream, a full (non streamed) response or an error.
% Second clause: the stream is ongoing, therefore we expect either another
% body part, the end of the stream or an error.
-spec handle_httpc_message(tuple(), state()) -> {noreply, state()} | {stop, normal, state()}.
handle_httpc_message(Msg, #state{response_headers = undefined} = State) ->
    case Msg of
        {_, stream_start, RawHeaders} ->
            ResponseHeaders = lists:keymap(fun string:to_lower/1, 1, RawHeaders),
            % From here on it's the pace of the body parts that is supervised.
            Rearmed = schedule_idle_download_timeout( cancel_download_start_timeout(State) ),
            Streaming = Rearmed#state{ response_headers = ResponseHeaders,
                                       response_body = <<>> },
            report_event({download_started, ResponseHeaders}, Streaming),
            {noreply, Streaming};
        {_, {{_, StatusCode, StatusDesc}, RawHeaders, Body}} ->
            ResponseHeaders = lists:keymap(fun string:to_lower/1, 1, RawHeaders),
            handle_download_start_http_failure(StatusCode, StatusDesc,
                                               ResponseHeaders, Body, State);
        {_, {error, Reason}} ->
            report_event({download_failed_to_start, {error, Reason}}, State),
            notify_owner({finished, {error, {http, Reason}}}, State),
            {stop, normal, State}
    end;
handle_httpc_message(Msg, #state{response_body = BodySoFar} = State)
  when BodySoFar =/= undefined ->
    case Msg of
        {_, stream, Chunk} ->
            % Parts are accumulated as iodata and only flattened at the very end.
            Grown = State#state{ response_body = [BodySoFar, Chunk] },
            {noreply, reschedule_idle_download_timeout(Grown)};
        {_, stream_end, RawTrailingHeaders} -> % no chunked encoding
            TrailingHeaders = lists:keymap(fun string:to_lower/1, 1, RawTrailingHeaders),
            AllHeaders = lists:usort(State#state.response_headers ++ TrailingHeaders),
            Body = iolist_to_binary(BodySoFar),
            report_event({download_finished, byte_size(Body), {ok, TrailingHeaders}}, State),
            handle_successful_download_conclusion(AllHeaders, Body, State);
        {_, {error, Reason}} ->
            SizeSoFar = iolist_size(BodySoFar),
            report_event({download_finished, SizeSoFar, {error, Reason}}, State),
            notify_owner({finished, {error, {http, Reason}}}, State),
            {stop, normal, State}
    end.

% Deals with a response that arrived in one piece rather than as a stream:
% the download is dismissed (304), a redirection is followed, or the download
% is given up on.
handle_download_start_http_failure(StatusCode, StatusDesc, CiHeaders, Body, State) ->
    FullResponse = {http, {StatusCode, StatusDesc}, CiHeaders, Body},
    case stream_start_failure_type(StatusCode, CiHeaders, State) of
        not_modified ->
            report_event({download_dismissed, FullResponse}, State),
            notify_owner({finished, dismissed}, State),
            {stop, normal, State};
        {redirection, Redirection} when State#state.redirections < ?MAX_REDIRECTIONS ->
            %% TODO test coverage of redirections
            follow_redirection(Redirection, State);
        {redirection, _} ->
            %% TODO test coverage of redirections
            give_up_on_download_start(too_many_redirections, too_many_redirections, State);
        {invalid_redirection, Why} ->
            %% TODO test coverage of redirections
            Invalid = {invalid_redirection, Why},
            give_up_on_download_start(Invalid, Invalid, State);
        error ->
            give_up_on_download_start(FullResponse, {http, StatusCode, StatusDesc}, State)
    end.

% Reports the redirection (censored), then issues a new request for its target.
follow_redirection(Redirection, State) ->
    CensoredRedirection = maybe_censor_redirection(Redirection, State#state.opts),
    report_event({download_redirected, CensoredRedirection}, State),
    #{url := TargetURL} = Redirection,
    #{url := CensoredTargetURL} = CensoredRedirection,
    Disarmed = cancel_download_start_timeout(State),
    Redirected = Disarmed#state{ request_id = undefined,
                                 url = TargetURL,
                                 censored_url = CensoredTargetURL,
                                 redirections = Disarmed#state.redirections + 1 },
    {noreply, send_request(Redirected)}.

% Reports why the download failed to start, lets the owner know and stops.
give_up_on_download_start(EventReason, ErrorReason, State) ->
    report_event({download_failed_to_start, EventReason}, State),
    notify_owner({finished, {error, ErrorReason}}, State),
    {stop, normal, State}.

% Classifies the status code of a non streamed response: `not_modified' (304),
% the outcome of resolving a redirection (301, 302, 303, 307, 308) or `error'.
%% https://developer.mozilla.org/en-US/docs/Web/HTTP/Redirections
stream_start_failure_type(304, _CiHeaders, _State) ->
    not_modified;
stream_start_failure_type(StatusCode, CiHeaders, State)
  when StatusCode =:= 301; StatusCode =:= 308 ->
    stream_start_redirect(permanent, CiHeaders, State);
stream_start_failure_type(StatusCode, CiHeaders, State)
  when StatusCode =:= 302; StatusCode =:= 303; StatusCode =:= 307 ->
    stream_start_redirect(temporary, CiHeaders, State);
stream_start_failure_type(_StatusCode, _CiHeaders, _State) ->
    error.

% Looks up the (lowercased) `location' header and, if present, resolves the
% redirection it points to.
stream_start_redirect(Permanence, CiHeaders, State) ->
    LocationLookup = lists:keyfind("location", 1, CiHeaders),
    case LocationLookup of
        {_, Location} ->
            stream_start_redirect_for_location(Permanence, Location, State);
        _ ->
            {invalid_redirection, missing_location_header}
    end.

% Resolves `NewLocation' against the URL of the current request, by means of
% `locus_util:resolve_http_location/2'.
stream_start_redirect_for_location(Permanence, NewLocation, State) ->
    CurrentURL = State#state.url,
    case locus_util:resolve_http_location(CurrentURL, NewLocation) of
        {error, Reason} ->
            {invalid_redirection, {bad_location, Reason}};
        {ok, ResolvedURL} ->
            {redirection, #{permanence => Permanence, url => ResolvedURL}}
    end.

% Concludes a fully streamed download: if a `content-length' header was
% received and it differs from the size of the body, the owner is told about
% the mismatch; otherwise it's handed the headers and the body.
handle_successful_download_conclusion(Headers, Body, State) ->
    % Sizes are compared textually, as the header value is a string.
    ActualLength = integer_to_list( byte_size(Body) ),
    Outcome =
        case lists:keyfind("content-length", 1, Headers) of
            {_, DeclaredLength} when DeclaredLength =/= ActualLength ->
                {error, {body_size_mismatch, #{declared_content_length => DeclaredLength,
                                               actual_content_length => ActualLength}}};
            _ ->
                {success, #{headers => Headers, body => Body}}
        end,
    notify_owner({finished, Outcome}, State),
    {stop, normal, State}.

%% ------------------------------------------------------------------
%% Internal Function Definitions - Timeouts
%% ------------------------------------------------------------------

% Arms the timeout that supervises the wait for a response to start streaming.
-spec schedule_download_start_timeout(state()) -> state().
schedule_download_start_timeout(State) ->
    OptName = download_start_timeout,
    schedule_timeout(OptName, ?DEFAULT_DOWNLOAD_START_TIMEOUT, State).

% Disarms the timeout that supervises the wait for a response to start streaming.
-spec cancel_download_start_timeout(state()) -> state().
cancel_download_start_timeout(State) ->
    OptName = download_start_timeout,
    cancel_timeout(OptName, State).

% Arms the timeout that supervises the wait for the next body part.
-spec schedule_idle_download_timeout(state()) -> state().
schedule_idle_download_timeout(State) ->
    OptName = idle_download_timeout,
    schedule_timeout(OptName, ?DEFAULT_IDLE_DOWNLOAD_TIMEOUT, State).

% Restarts the timeout that supervises the wait for the next body part.
-spec reschedule_idle_download_timeout(state()) -> state().
reschedule_idle_download_timeout(State) ->
    OptName = idle_download_timeout,
    reschedule_timeout(OptName, ?DEFAULT_IDLE_DOWNLOAD_TIMEOUT, State).

% Arms the timeout named `OptName', whose interval is taken from the options
% (or `DefaultValue' when absent). The timeout must not be armed already.
% An interval of `infinity' starts no timer but is still kept track of, so
% that cancelling works the same way in both cases.
schedule_timeout(OptName, DefaultValue, #state{opts = Opts, timeouts = Timeouts} = State) ->
    false = maps:is_key(OptName, Timeouts),
    TimerOrInfinity =
        case proplists:get_value(OptName, Opts, DefaultValue) of
            infinity ->
                infinity;
            Interval ->
                erlang:send_after(Interval, self(), {timeout, OptName})
        end,
    State#state{ timeouts = maps:put(OptName, TimerOrInfinity, Timeouts) }.

% Disarms the timeout named `OptName', which has to be armed. If a timer was
% started for it, the timer is cancelled (or its message flushed from the
% mailbox, in case it had already fired).
cancel_timeout(OptName, #state{timeouts = Timeouts} = State) ->
    #{OptName := TimerOrInfinity} = Timeouts,
    case TimerOrInfinity of
        infinity ->
            ok;
        Timer ->
            true = cancel_or_flush_timer(Timer, {timeout, OptName})
    end,
    State#state{ timeouts = maps:remove(OptName, Timeouts) }.

% Cancels `Timer'; if that's no longer possible, attempts to remove
% `TimeoutMsg' from the mailbox instead, without blocking.
% Returns `true' if the timer was cancelled or its message flushed,
% `false' otherwise.
cancel_or_flush_timer(Timer, TimeoutMsg) ->
    case erlang:cancel_timer(Timer) of
        TimeLeft when is_integer(TimeLeft) ->
            true;
        _ ->
            receive
                TimeoutMsg -> true
            after
                0 -> false
            end
    end.

% Restarts an armed timeout: it's cancelled and then armed anew.
reschedule_timeout(OptName, DefaultValue, State) ->
    schedule_timeout(OptName, DefaultValue, cancel_timeout(OptName, State)).

% Reacts to a timeout that fired: the ongoing request is cancelled, the
% failure is reported, the owner is notified and the server stops.
handle_timeout(download_start_timeout, #state{request_id = RequestId} = State) ->
    ok = httpc:cancel_request(RequestId),
    report_event({download_failed_to_start, timeout}, State),
    notify_owner({finished, {error, {timeout, waiting_stream_start}}}, State),
    {stop, normal, State};
handle_timeout(idle_download_timeout, #state{request_id = RequestId} = State) ->
    ok = httpc:cancel_request(RequestId),
    % Whatever was received of the body so far is accounted for in the event.
    SizeSoFar = iolist_size(State#state.response_body),
    report_event({download_finished, SizeSoFar, {error, timeout}}, State),
    notify_owner({finished, {error, {timeout, waiting_stream_end}}}, State),
    {stop, normal, State}.

%% ------------------------------------------------------------------
%% Internal Function Definitions - Events
%% ------------------------------------------------------------------

% Sends an event to the owner, wrapped as `{event, Event}'.
-spec report_event(event(), state()) -> ok.
report_event(Event, State) ->
    Msg = {event, Event},
    notify_owner(Msg, State).

% Sends `{self(), Msg}' to the owner using the `noconnect' send option;
% whatever the send returns is disregarded.
-spec notify_owner(msg(), state()) -> ok.
notify_owner(Msg, #state{owner_pid = OwnerPid}) ->
    Envelope = {self(), Msg},
    _ = erlang:send(OwnerPid, Envelope, [noconnect]),
    ok.

%% ------------------------------------------------------------------
%% Internal Function Definitions - Death
%% ------------------------------------------------------------------

% Stops the server once its owner has died.
% The exit of any other linked process is not handled (no clause matches it).
-spec handle_linked_process_death(pid(), state()) -> {stop, normal, state()}.
handle_linked_process_death(OwnerPid, #state{owner_pid = OwnerPid} = State) ->
    {stop, normal, State}.