# lib/db_connection/connection.ex

defmodule DBConnection.ConnectionError do
  @moduledoc """
  A generic connection error exception.

  The raised exception might include the reason which would be useful
  to programmatically determine what was causing the error.
  """

  defexception [:message, severity: :error, reason: :error]

  # Builds the exception from `message` through the generated `exception/1`
  # and then overrides the default `:reason` field with `reason`.
  @doc false
  def exception(message, reason) do
    %{exception(message) | reason: reason}
  end
end

defmodule DBConnection.Connection do
  @moduledoc false

  @behaviour :gen_statem

  require Logger
  alias DBConnection.Backoff
  alias DBConnection.Holder

  @timeout 15_000

  @doc false
  # Starts the connection state machine linked to the caller. Only the
  # `:debug` and `:spawn_opt` entries of `opts` are passed as gen_statem start
  # options; the complete `opts` travel inside the init argument.
  def start_link(mod, opts, pool, tag) do
    init_arg = {mod, opts, pool, tag}
    statem_opts = Keyword.take(opts, [:debug, :spawn_opt])
    :gen_statem.start_link(__MODULE__, init_arg, statem_opts)
  end

  @doc false
  # Builds a child spec whose start MFA is `start_link/4` of this module and
  # lets `child_opts` override it through `Supervisor.child_spec/2`.
  def child_spec(mod, opts, pool, tag, child_opts) do
    start_mfa = {__MODULE__, :start_link, [mod, opts, pool, tag]}
    base_spec = %{id: __MODULE__, start: start_mfa}
    Supervisor.child_spec(base_spec, child_opts)
  end

  @doc false
  # Asynchronously asks the connection process `pid` to disconnect with `err`,
  # handing over the latest driver `state`. `ref` identifies the checkout.
  def disconnect({pid, ref}, err, state) do
    request = {:disconnect, ref, err, state}
    :gen_statem.cast(pid, request)
  end

  @doc false
  # Asynchronously asks the connection process `pid` to stop with `err`,
  # handing over the latest driver `state`. `ref` identifies the checkout.
  def stop({pid, ref}, err, state) do
    request = {:stop, ref, err, state}
    :gen_statem.cast(pid, request)
  end

  @doc false
  # Asynchronously asks the connection process `pid` to ping using the given
  # driver `state`. `ref` identifies the checkout.
  def ping({pid, ref}, state) do
    request = {:ping, ref, state}
    :gen_statem.cast(pid, request)
  end

  ## gen_statem API

  @doc false
  @impl :gen_statem
  # All events are dispatched to `handle_event/4`.
  def callback_mode do
    :handle_event_function
  end

  @doc false
  @impl :gen_statem
  # Builds the initial data map (no driver state yet, client marked `:closed`)
  # and queues an internal `{:connect, :init}` event so the first connection
  # attempt is made right after init returns.
  def init({mod, opts, pool, tag}) do
    backoff = Backoff.new(opts)
    listeners = Keyword.get(opts, :connection_listeners, [])
    after_connect = Keyword.get(opts, :after_connect)
    after_connect_timeout = Keyword.get(opts, :after_connect_timeout, @timeout)

    data = %{
      mod: mod,
      opts: opts,
      state: nil,
      client: :closed,
      pool: pool,
      tag: tag,
      timer: nil,
      backoff: backoff,
      connection_listeners: listeners,
      after_connect: after_connect,
      after_connect_timeout: after_connect_timeout
    }

    first_connect = {:next_event, :internal, {:connect, :init}}
    {:ok, :no_state, data, first_connect}
  end

  @impl :gen_statem
  def handle_event(type, info, state, s)

  # Internal `{:connect, _}` event: calls the driver's `connect/1`.
  #
  # * On success a fresh ref is stored as `{ref, :connect}` and a cast is sent
  #   to self: `{:after_connect, ref}` when an `:after_connect` hook is
  #   configured, `{:connected, ref}` otherwise (the backoff is reset only in
  #   the latter case).
  # * On `{:error, err}` the failure is logged; without a backoff the error is
  #   raised, otherwise a `{:timeout, :backoff}` timer is armed.
  # * Exceptions raised by `connect/1` go through `maybe_sanitize_exception/3`
  #   before being re-raised.
  def handle_event(:internal, {:connect, _info}, :no_state, s) do
    %{mod: mod, opts: opts, backoff: backoff, after_connect: after_connect} = s

    try do
      apply(mod, :connect, [connect_opts(opts)])
    rescue
      exception ->
        {sanitized, stack} = maybe_sanitize_exception(exception, __STACKTRACE__, opts)
        reraise sanitized, stack
    else
      {:ok, conn_state} when after_connect != nil ->
        ref = make_ref()
        :gen_statem.cast(self(), {:after_connect, ref})
        {:keep_state, %{s | state: conn_state, client: {ref, :connect}}}

      {:ok, conn_state} ->
        fresh_backoff = backoff && Backoff.reset(backoff)
        ref = make_ref()
        :gen_statem.cast(self(), {:connected, ref})
        {:keep_state, %{s | state: conn_state, client: {ref, :connect}, backoff: fresh_backoff}}

      {:error, err} ->
        Logger.error(
          fn ->
            [
              inspect(mod),
              ?\s,
              ?(,
              inspect(self()),
              ") failed to connect: "
              | Exception.format_banner(:error, err, [])
            ]
          end,
          crash_reason: {err, []}
        )

        if is_nil(backoff) do
          # No backoff configured: crash so the supervisor decides what to do.
          raise err
        else
          {timeout, next_backoff} = Backoff.backoff(backoff)
          {:keep_state, %{s | backoff: next_backoff}, {{:timeout, :backoff}, timeout, nil}}
        end
    end
  end

  # Internal `{:disconnect, {log, err}}` event: optionally logs the error
  # (using the severity carried by a `DBConnection.ConnectionError`, `:error`
  # otherwise), releases the client monitor and checkout timer, calls the
  # driver's `disconnect/2` and notifies the connection listeners.
  #
  # Afterwards: stop when no backoff is configured, wait for the backoff timer
  # when the disconnect happened during `:after_connect`, and reconnect
  # immediately in every other case.
  def handle_event(:internal, {:disconnect, {log, err}}, :no_state, %{mod: mod} = s) do
    if log == :log do
      level =
        case err do
          %DBConnection.ConnectionError{severity: severity} -> severity
          _ -> :error
        end

      Logger.log(level, fn ->
        [
          inspect(mod),
          ?\s,
          ?(,
          inspect(self()),
          ") disconnected: " | Exception.format_banner(:error, err, [])
        ]
      end)
    end

    %{state: state, client: client, timer: timer, backoff: backoff} = s
    demonitor(client)
    cancel_timer(timer)
    :ok = apply(mod, :disconnect, [err, state])
    closed = %{s | state: nil, client: :closed, timer: nil}

    notify_connection_listeners(:disconnected, closed)

    cond do
      backoff == nil ->
        {:stop, {:shutdown, err}, closed}

      match?({_, :after_connect}, client) ->
        {timeout, next_backoff} = Backoff.backoff(backoff)
        {:keep_state, %{closed | backoff: next_backoff}, {{:timeout, :backoff}, timeout, nil}}

      true ->
        {:keep_state, closed, {:next_event, :internal, {:connect, :disconnect}}}
    end
  end

  # The backoff timer fired: retry by queueing an internal connect event.
  def handle_event({:timeout, :backoff}, _content, :no_state, s) do
    retry = {:next_event, :internal, {:connect, :backoff}}
    {:keep_state, s, retry}
  end

  # Ping request for the checkout `ref` currently held by the pool: runs the
  # driver's `ping/1` and either hands the connection back to the pool or
  # starts a (logged) disconnect.
  def handle_event(:cast, {:ping, ref, state}, :no_state, %{client: {ref, :pool}, mod: mod} = s) do
    ping_result = apply(mod, :ping, [state])

    case ping_result do
      {:disconnect, err, new_state} ->
        disconnect = {:next_event, :internal, {:disconnect, {:log, err}}}
        {:keep_state, %{s | state: new_state}, disconnect}

      {:ok, new_state} ->
        pool_update(new_state, s)
    end
  end

  # Disconnect request from the current client (`ref` matches): store the
  # driver state it handed back and start a logged disconnect.
  def handle_event(:cast, {:disconnect, ref, err, state}, :no_state, %{client: {ref, _}} = s) do
    disconnect = {:next_event, :internal, {:disconnect, {:log, err}}}
    {:keep_state, %{s | state: state}, disconnect}
  end

  # Stop request from the current client (`ref` matches).
  #
  # Unless `err` is a regular shutdown reason (`:normal`, `:shutdown` or
  # `{:shutdown, term}`), a termination report is written through
  # `:error_logger`. The process then stops with `{err, stack}`, where `stack`
  # is this process' current stacktrace, and keeps the driver state handed
  # back by the client so `terminate/3` can use it.
  def handle_event(:cast, {:stop, ref, err, state}, :no_state, %{client: {ref, _}} = s) do
    {_, stack} = :erlang.process_info(self(), :current_stacktrace)

    case err do
      ok when ok in [:normal, :shutdown] ->
        :ok

      {:shutdown, _term} ->
        :ok

      _ ->
        reason =
          case err do
            %{__exception__: true} -> Exception.format_banner(:error, err, stack)
            _other -> "** #{inspect(err)}"
          end

        # `reason` is a UTF-8 binary, so it is printed with `~ts`: plain `~s`
        # treats the bytes as latin1 and garbles non-ASCII messages.
        format =
          ~c"** State machine ~p terminating~n" ++
            ~c"** Reason for termination ==~n" ++
            ~c"~ts~n"

        :error_logger.format(format, [self(), reason])
    end

    {:stop, {err, stack}, %{s | state: state}}
  end

  # A disconnect/stop cast that did not match the clauses above (its ref is
  # not the current client's): leave the data untouched.
  def handle_event(:cast, {tag, _ref, _err, _state}, :no_state, s)
      when tag == :disconnect or tag == :stop do
    handle_timeout(s)
  end

  # `{:after_connect, ref}` cast sent to self after a successful connect when
  # an `:after_connect` hook is configured. Notifies listeners, checks the
  # connection out and runs the hook in a child task; the task's ref becomes
  # the client and a timer bounds how long the hook may hold the connection.
  def handle_event(:cast, {:after_connect, ref}, :no_state, %{client: {ref, :connect}} = s) do
    %{
      mod: mod,
      state: state,
      after_connect: after_connect,
      after_connect_timeout: timeout,
      opts: opts
    } = s

    notify_connection_listeners(:connected, s)

    case apply(mod, :checkout, [state]) do
      {:disconnect, err, new_state} ->
        disconnect = {:next_event, :internal, {:disconnect, {:log, err}}}
        {:keep_state, %{s | state: new_state}, disconnect}

      {:ok, new_state} ->
        task_opts = [{:timeout, timeout} | opts]
        {pid, task_ref} = DBConnection.Task.run_child(mod, new_state, after_connect, task_opts)
        timer = start_timer(pid, timeout)
        {:keep_state, %{s | client: {task_ref, :after_connect}, timer: timer, state: new_state}}
    end
  end

  # An `:after_connect` cast whose ref is not the current connect ref: ignore.
  def handle_event(:cast, {:after_connect, _ref}, :no_state, _s) do
    :keep_state_and_data
  end

  # `{:connected, ref}` cast sent to self after a successful connect when no
  # `:after_connect` hook is configured. Notifies listeners, checks the
  # connection out of the driver and hands it to the pool.
  def handle_event(:cast, {:connected, ref}, :no_state, %{client: {ref, :connect}} = s) do
    %{mod: mod, state: state} = s

    notify_connection_listeners(:connected, s)

    checkout_result = apply(mod, :checkout, [state])

    case checkout_result do
      {:disconnect, err, new_state} ->
        disconnect = {:next_event, :internal, {:disconnect, {:log, err}}}
        {:keep_state, %{s | state: new_state}, disconnect}

      {:ok, new_state} ->
        pool_update(new_state, s)
    end
  end

  # A `:connected` cast whose ref is not the current connect ref: ignore.
  def handle_event(:cast, {:connected, _ref}, :no_state, _s) do
    :keep_state_and_data
  end

  # The `:after_connect` task (monitored through `ref`) went down: clear the
  # monitor ref from the client and disconnect. `down_log/1` decides whether
  # the exit reason is worth logging.
  def handle_event(
        :info,
        {:DOWN, ref, _, pid, reason},
        :no_state,
        %{client: {ref, :after_connect}} = s
      ) do
    err =
      DBConnection.ConnectionError.exception(
        "client #{inspect(pid)} exited: " <> Exception.format_exit(reason)
      )

    disconnect = {:next_event, :internal, {:disconnect, {down_log(reason), err}}}
    {:keep_state, %{s | client: {nil, :after_connect}}, disconnect}
  end

  # The monitored client (`mon` is the second element of the client tuple)
  # went down: drop the monitor from the client and disconnect.
  def handle_event(:info, {:DOWN, mon, _, pid, reason}, :no_state, %{client: {ref, mon}} = s) do
    err =
      DBConnection.ConnectionError.exception(
        "client #{inspect(pid)} exited: " <> Exception.format_exit(reason)
      )

    disconnect = {:next_event, :internal, {:disconnect, {down_log(reason), err}}}
    {:keep_state, %{s | client: {ref, nil}}, disconnect}
  end

  # The checkout timer started by `start_timer/2` fired while it is still the
  # active timer: build an error describing the timeout (with the client's
  # current stacktrace when it can be fetched) and start a logged disconnect.
  def handle_event(
        :info,
        {:timeout, timer, {__MODULE__, pid, timeout}},
        :no_state,
        %{timer: timer} = s
      )
      when is_reference(timer) do
    summary =
      "client #{inspect(pid)} timed out because it checked out " <>
        "the connection for longer than #{timeout}ms"

    text =
      case Process.info(pid, :current_stacktrace) do
        {:current_stacktrace, stacktrace} ->
          summary <>
            "\n\n#{inspect(pid)} was at location:\n\n" <>
            Exception.format_stacktrace(stacktrace)

        _ ->
          summary
      end

    exc = DBConnection.ConnectionError.exception(text)
    disconnect = {:next_event, :internal, {:disconnect, {:log, exc}}}

    {:keep_state, %{s | timer: nil}, disconnect}
  end

  # The `:after_connect` task transferred the holder table back. The driver
  # state is taken out of the holder, the checkout timer is cancelled, and the
  # message tag decides between a checkin, a disconnect or a stop.
  def handle_event(
        :info,
        {:"ETS-TRANSFER", holder, _pid, {msg, ref, extra}},
        :no_state,
        %{client: {ref, :after_connect}, timer: timer} = s
      ) do
    {_, conn_state} = Holder.delete(holder)
    cancel_timer(timer)
    untimed = %{s | timer: nil}

    case msg do
      :stop -> handle_event(:cast, {:stop, ref, extra, conn_state}, :no_state, untimed)
      :disconnect -> handle_event(:cast, {:disconnect, ref, extra, conn_state}, :no_state, untimed)
      :checkin -> handle_checkin(conn_state, untimed)
    end
  end

  # Any other info message is unexpected: log it at info level and carry on
  # with the data unchanged.
  def handle_event(:info, msg, :no_state, %{mod: mod} = s) do
    Logger.info(fn ->
      prefix = [inspect(mod), ?\s, ?(, inspect(self()), ") missed message: "]
      prefix ++ inspect(msg)
    end)

    handle_timeout(s)
  end

  @doc false
  @impl :gen_statem
  # If client is :closed then the connection was previously disconnected
  # and cleanup is not required.
  def terminate(_reason, _state, %{client: :closed}), do: :ok

  # Otherwise the driver still holds a connection: disconnect it with an
  # error describing the exit reason.
  def terminate(reason, _state, s) do
    %{mod: mod, state: state} = s

    error =
      DBConnection.ConnectionError.exception(
        "connection exited: " <> Exception.format_exit(reason)
      )

    mod.disconnect(error, state)
  end

  @doc false
  @impl :gen_statem
  # Closed connection: there is no driver state to show, only the module.
  def format_status(info, [_pdict, :no_state, %{client: :closed, mod: mod}]) do
    case info do
      :terminate -> mod
      :normal -> [{:data, [{~c"Module", mod}]}]
    end
  end

  # Open connection: use the driver's own `format_status/2` when it exports
  # one, and fall back to showing the raw driver state otherwise.
  def format_status(info, [pdict, :no_state, %{mod: mod, state: state}]) do
    custom? = function_exported?(mod, :format_status, 2)

    case info do
      :normal when custom? -> normal_status(mod, pdict, state)
      :normal -> normal_status_default(mod, state)
      :terminate when custom? -> {mod, terminate_status(mod, pdict, state)}
      :terminate -> {mod, state}
    end
  end

  ## Helpers

  # Returns the `{exception, stacktrace}` pair to re-raise after `connect`
  # raised. Unless `:show_sensitive_data_on_connection_error` is truthy, the
  # exception is replaced by a `RuntimeError` that only names the original
  # exception module, and the stacktrace goes through `cleanup_stacktrace/1`.
  defp maybe_sanitize_exception(e, stack, opts) do
    case Keyword.get(opts, :show_sensitive_data_on_connection_error, false) do
      hide when hide in [nil, false] ->
        message =
          "connect raised #{inspect(e.__struct__)} exception#{sanitized_message(e)}. " <>
            "The exception details are hidden, as they may contain sensitive data such as " <>
            "database credentials. You may set :show_sensitive_data_on_connection_error " <>
            "to true when starting your connection if you wish to see all of the details"

        {RuntimeError.exception(message), cleanup_stacktrace(stack)}

      _show ->
        {e, stack}
    end
  end

  # Extra detail appended to the sanitized connect error. Only `KeyError` gets
  # one: its message with the `:term` field blanked out. Everything else
  # yields an empty string.
  defp sanitized_message(exception) do
    case exception do
      %KeyError{} = key_error ->
        ": " <> Exception.message(%{key_error | term: nil})

      _other ->
        ""
    end
  end

  # Resolves the options passed to the driver's `connect/1`. Without a
  # `:configure` option `opts` is returned as is; otherwise `:configure` is
  # either an `{mod, fun, args}` tuple (called with `opts` prepended to
  # `args`) or a 1-arity function (called with `opts`).
  defp connect_opts(opts) do
    configure = Keyword.get(opts, :configure)

    case configure do
      nil ->
        opts

      callback when is_function(callback, 1) ->
        callback.(opts)

      {mod, fun, args} ->
        apply(mod, fun, [opts | args])
    end
  end

  # Decides whether a client exit is worth logging: regular exits
  # (`:normal`, `:shutdown`, `{:shutdown, term}`) are not, anything else is.
  defp down_log(reason) do
    case reason do
      :normal -> :nolog
      :shutdown -> :nolog
      {:shutdown, _} -> :nolog
      _ -> :log
    end
  end

  # Keeps the current state and data unchanged.
  defp handle_timeout(s) do
    {:keep_state, s}
  end

  # Drops the monitor held in the client tuple, flushing any `:DOWN` message
  # already delivered. The monitor ref is normally the second element; for an
  # `:after_connect` client it is the first. Clients without a monitor ref
  # (and `nil`) need no work.
  defp demonitor({_ref, monitor}) when is_reference(monitor),
    do: Process.demonitor(monitor, [:flush])

  defp demonitor({monitor, :after_connect}) when is_reference(monitor),
    do: Process.demonitor(monitor, [:flush])

  defp demonitor({_ref, _tag}), do: true
  defp demonitor(nil), do: true

  # Starts the checkout timer for client `pid`. An `:infinity` timeout means
  # no timer at all (`nil`); otherwise the timer message carries this module,
  # the client pid and the timeout so the handler can report them.
  defp start_timer(_pid, :infinity) do
    nil
  end

  defp start_timer(pid, timeout) do
    payload = {__MODULE__, pid, timeout}
    :erlang.start_timer(timeout, self(), payload)
  end

  # Cancels the checkout timer, if any. When the timer can no longer be
  # cancelled (`:erlang.cancel_timer/1` returned `false`), the timeout message
  # is removed from the mailbox by `flush_timer/1` instead.
  defp cancel_timer(nil), do: :ok

  defp cancel_timer(timer) do
    if :erlang.cancel_timer(timer) == false do
      flush_timer(timer)
    else
      :ok
    end
  end

  # Removes the already-delivered timeout message of `timer` from the mailbox
  # without blocking. Raises `ArgumentError` when no such message is there.
  defp flush_timer(timer) do
    receive do
      {:timeout, ^timer, {__MODULE__, _pid, _timeout}} -> :ok
    after
      0 -> raise ArgumentError, message: "timer #{inspect(timer)} does not exist"
    end
  end

  # The connection was checked back in: reset the backoff (when one is
  # configured), drop the client's monitor and hand the connection to the
  # pool with the returned driver `state`.
  defp handle_checkin(state, s) do
    %{backoff: backoff, client: client} = s
    fresh_backoff = if backoff, do: Backoff.reset(backoff), else: backoff
    demonitor(client)

    idle = %{s | client: nil, backoff: fresh_backoff}
    pool_update(state, idle)
  end

  # Hands the connection (with driver `state`) to the pool through
  # `Holder.update/4`. On success the returned ref becomes the `:pool` client
  # and the process hibernates; on `:error` the process stops with
  # `{:shutdown, :no_more_pool}`.
  defp pool_update(state, %{pool: pool, tag: tag, mod: mod} = s) do
    update_result = Holder.update(pool, tag, mod, state)

    case update_result do
      :error ->
        {:stop, {:shutdown, :no_more_pool}, s}

      {:ok, holder_ref} ->
        {:keep_state, %{s | client: {holder_ref, :pool}, state: state}, :hibernate}
    end
  end

  # Asks the driver to format its state for a `:normal` status request. Any
  # raise, throw or exit from the callback falls back to the default layout.
  defp normal_status(mod, pdict, state) do
    mod.format_status(:normal, [pdict, state])
  catch
    _kind, _reason ->
      normal_status_default(mod, state)
  end

  # Default `:normal` status layout: the driver module and its raw state.
  defp normal_status_default(mod, state) do
    entries = [{~c"Module", mod}, {~c"State", state}]
    [{:data, entries}]
  end

  # Asks the driver to format its state for a `:terminate` status request.
  # Any raise, throw or exit from the callback falls back to the raw state.
  defp terminate_status(mod, pdict, state) do
    mod.format_status(:terminate, [pdict, state])
  catch
    _kind, _reason ->
      state
  end

  # Strips the call arguments from the first frame of `stack`, replacing the
  # argument list by its arity, so the arguments passed to the failing call
  # are not carried along in the stacktrace.
  #
  # Both frame shapes are handled: `{mod, fun, args_or_arity, location}` and
  # `{fun, args_or_arity, location}` (anonymous functions). Any other input,
  # including an empty stacktrace or a frame that already holds an arity, is
  # returned unchanged. This runs inside a `rescue`, so it must never raise
  # itself, which would mask the original connect error.
  defp cleanup_stacktrace(stack) do
    case stack do
      [{mod, fun, args, info} | rest] when is_list(args) ->
        [{mod, fun, length(args), info} | rest]

      [{fun, args, info} | rest] when is_list(args) ->
        [{fun, length(args), info} | rest]

      other ->
        other
    end
  end

  # Sends `action` to every configured connection listener. Listeners are
  # either a plain list (message `{action, self()}`) or a `{list, tag}` tuple
  # (message `{action, self(), tag}`).
  defp notify_connection_listeners(action, %{} = state) do
    %{connection_listeners: configured} = state

    {targets, message} =
      case configured do
        {targets, tag} when is_list(targets) ->
          {targets, {action, self(), tag}}

        targets when is_list(targets) ->
          {targets, {action, self()}}
      end

    Enum.each(targets, fn target -> send(target, message) end)
  end
end