lib/nostrum/api/ratelimiter.ex

defmodule Nostrum.Api.Ratelimiter do
  @moduledoc """
  Ratelimit implementation specific to Discord's API.
  Only to be used when starting in a rest-only manner.

  Requests are queued through this process. A request whose bucket reports no
  pending timeout is performed immediately on the `:gun` connection held as
  this process's state; otherwise a task sleeps for the reported timeout and
  then re-queues the request.
  """

  use GenServer

  alias Nostrum.Api.{Base, Bucket}
  alias Nostrum.Constants
  alias Nostrum.Error.ApiError
  alias Nostrum.Util

  require Logger

  @typedoc """
  Return values of start functions.
  """
  @type on_start ::
          {:ok, pid}
          | :ignore
          | {:error, {:already_started, pid} | term}

  @major_parameters ["channels", "guilds", "webhooks"]
  @gregorian_epoch 62_167_219_200
  @bucket_cleanup_interval :timer.hours(1)

  @doc """
  Starts the ratelimiter.
  """
  @spec start_link([]) :: on_start
  def start_link([]) do
    # NOTE: the process is registered under the bare `Ratelimiter` alias atom,
    # not `__MODULE__`; `wait_for_timeout/3` addresses it by this same name.
    GenServer.start_link(__MODULE__, [], name: Ratelimiter)
  end

  @impl true
  def init([]) do
    :ets.new(:ratelimit_buckets, [:set, :public, :named_table])
    domain = to_charlist(Constants.domain())

    open_opts = %{retry: 1_000_000_000, tls_opts: Constants.gun_tls_opts()}
    {:ok, conn_pid} = :gun.open(domain, 443, open_opts)

    # Only an HTTP/2 connection is accepted; any other outcome crashes init.
    {:ok, :http2} = :gun.await_up(conn_pid)

    # Start the old route cleanup loop
    schedule_bucket_cleanup()

    {:ok, conn_pid}
  end

  @doc """
  Empties all buckets, voiding any saved ratelimit values.
  """
  @spec empty_buckets() :: true
  def empty_buckets do
    :ets.delete_all_objects(:ratelimit_buckets)
  end

  @impl true
  def handle_call({:queue, request, original_from}, from, conn) do
    # When `original_from` is nil the reply goes to the caller itself.
    dispatch(request, original_from || from, conn)
    {:noreply, conn}
  end

  # Re-queue path used by `wait_for_timeout/3`. A cast is used so the waiting
  # task does not block forever on a call whose reply is sent to someone else.
  @impl true
  def handle_cast({:queue, request, reply_to}, conn) do
    dispatch(request, reply_to, conn)
    {:noreply, conn}
  end

  @impl true
  def handle_info({:gun_down, _conn, _proto, _reason, _killed_streams}, state) do
    {:noreply, state}
  end

  def handle_info({:gun_up, _conn, _proto}, state) do
    {:noreply, state}
  end

  def handle_info({:gun_response, _conn, _ref, :nofin, status, headers}, state) do
    Logger.debug("Got unexpected response with status #{status}, headers #{inspect(headers)}")
    {:noreply, state}
  end

  def handle_info(:remove_old_buckets, state) do
    Bucket.remove_old_buckets()
    schedule_bucket_cleanup()
    {:noreply, state}
  end

  # Any other message (for example other stray `:gun` messages) would otherwise
  # raise a FunctionClauseError and take the whole ratelimiter down.
  def handle_info(message, state) do
    Logger.debug("Ratelimiter ignoring unexpected message #{inspect(message)}")
    {:noreply, state}
  end

  # Performs `request` now when its bucket has no pending timeout and replies
  # to `reply_to` with the formatted result; otherwise spawns a task that waits
  # out the timeout and re-queues the request.
  defp dispatch(request, reply_to, conn) do
    retry_time =
      request.route
      |> get_endpoint(request.method)
      |> Bucket.get_ratelimit_timeout()

    if ready?(retry_time) do
      GenServer.reply(reply_to, do_request(request, conn))
    else
      Task.start(fn -> wait_for_timeout(request, retry_time, reply_to) end)
    end
  end

  # `:now` or a negative remaining timeout means the request may be sent.
  defp ready?(:now), do: true
  defp ready?(time), do: time < 0

  defp do_request(request, conn) do
    conn
    |> Base.request(request.method, request.route, request.body, request.headers, request.params)
    |> handle_headers(get_endpoint(request.route, request.method))
    |> format_response
  end

  # Returns the value of the first header named `key`, or nil when absent.
  @spec header_value([{String.t(), String.t()}], String.t()) :: String.t() | nil
  defp header_value(headers, key) do
    case List.keyfind(headers, key, 0) do
      {_key, value} -> value
      nil -> nil
    end
  end

  # Parses an integer header; nil when the header is absent.
  defp integer_header(headers, key) do
    case header_value(headers, key) do
      nil -> nil
      value -> String.to_integer(value)
    end
  end

  # Parses a numeric header as a float; nil when the header is absent.
  defp float_header(headers, key) do
    case header_value(headers, key) do
      nil ->
        nil

      value ->
        # Float.parse/1 also accepts values without a "." (e.g. "1"), which
        # String.to_float/1 would reject; trailing garbage still raises.
        {parsed, ""} = Float.parse(value)
        parsed
    end
  end

  defp handle_headers({:error, reason}, _route), do: {:error, reason}

  defp handle_headers({:ok, {_status, headers, _body}} = response, route) do
    # Per https://discord.com/developers/docs/topics/rate-limits, all of these
    # headers are optional, so each parsed value is nil when its header is absent.
    global_limit = header_value(headers, "x-ratelimit-global")
    remaining = integer_header(headers, "x-ratelimit-remaining")
    reset = float_header(headers, "x-ratelimit-reset")
    retry_after = float_header(headers, "retry-after")
    latency = request_latency(headers)

    # If we have hit a global limit, Discord responds with a 429 and informs
    # us when we can retry. Our global bucket keeps track of this ratelimit.
    # Without a retry-after value there is nothing to schedule.
    if global_limit != nil and retry_after != nil do
      update_global_bucket(route, 0, retry_after, latency)
    end

    # If Discord did send us other ratelimit information, we can also update
    # the ratelimiter bucket for this route. For some endpoints, such as
    # when creating a DM with a user, we may not retrieve ratelimit headers.
    if reset != nil and remaining != nil do
      update_bucket(route, remaining, reset, latency)
    end

    response
  end

  # Absolute difference between the `date` header timestamp and `Util.now()`.
  # Falls back to 0 when the header is missing or unparseable instead of crashing.
  defp request_latency(headers) do
    case headers |> header_value("date") |> date_string_to_unix() do
      nil -> 0
      origin_timestamp -> abs(origin_timestamp - Util.now())
    end
  end

  # `reset_time` is scaled by 1000 before being stored (presumably s -> ms).
  defp update_bucket(route, remaining, reset_time, latency) do
    Bucket.update_bucket(route, remaining, reset_time * 1000, latency)
  end

  defp update_global_bucket(_route, _remaining, retry_after, latency) do
    Bucket.update_bucket("GLOBAL", 0, retry_after + Util.now(), latency)
  end

  # Runs in a separate task: sleeps out the bucket timeout, then hands the
  # request back to the ratelimiter without waiting for a reply.
  defp wait_for_timeout(request, timeout, reply_to) do
    truncated = :erlang.ceil(timeout)

    Logger.info(
      "RATELIMITER: Waiting #{truncated}ms to process request with route #{request.route}"
    )

    Process.sleep(truncated)
    GenServer.cast(Ratelimiter, {:queue, request, reply_to})
  end

  # Converts an HTTP date header to a millisecond Unix timestamp, or nil when
  # the header is missing or cannot be parsed.
  defp date_string_to_unix(nil), do: nil

  defp date_string_to_unix(header) do
    case header |> String.to_charlist() |> :httpd_util.convert_request_date() do
      :bad_date -> nil
      datetime -> erl_datetime_to_timestamp(datetime)
    end
  end

  defp erl_datetime_to_timestamp(datetime) do
    (:calendar.datetime_to_gregorian_seconds(datetime) - @gregorian_epoch) * 1000
  end

  @doc """
  Retrieves a proper ratelimit endpoint from a given route and HTTP method.

  IDs that follow a major parameter (`channels`, `guilds`, `webhooks`) are kept;
  other IDs become `_id`, webhook tokens become `_token` and reaction emojis
  become `_emoji`. Message deletions get their own `delete:` prefixed endpoint.
  """
  @spec get_endpoint(String.t(), atom()) :: String.t()
  def get_endpoint(route, method) do
    endpoint =
      Regex.replace(~r/\/([a-z-]+)\/(?:[0-9]{17,19})/i, route, fn capture, param ->
        if param in @major_parameters, do: capture, else: "/#{param}/_id"
      end)
      |> replace_webhook_token()
      |> replace_emojis()

    if String.ends_with?(endpoint, "/messages/_id") and method == :delete do
      "delete:" <> endpoint
    else
      endpoint
    end
  end

  defp format_response(response) do
    case response do
      {:error, error} ->
        {:error, error}

      {:ok, {status, _, body}} when status in [200, 201] ->
        {:ok, body}

      {:ok, {204, _, _}} ->
        {:ok}

      {:ok, {status, _, body}} ->
        {:error, %ApiError{status_code: status, response: decode_error_body(body)}}
    end
  end

  # Error bodies are decoded as JSON when possible; a non-JSON body is kept raw
  # rather than raising inside the GenServer.
  # NOTE(review): `keys: :atoms` creates atoms from response data — consider
  # `:atoms!` or string keys if the body is not fully trusted.
  defp decode_error_body(body) do
    case Jason.decode(body, keys: :atoms) do
      {:ok, decoded} -> decoded
      {:error, _reason} -> body
    end
  end

  defp replace_emojis(endpoint) do
    Regex.replace(
      ~r/\/reactions\/[^\/]+\/?(@me|_id)?/i,
      endpoint,
      "/reactions/_emoji/\\g{1}/"
    )
  end

  defp replace_webhook_token(endpoint) do
    Regex.replace(
      ~r/\/webhooks\/([0-9]{17,19})\/[^\/]+\/?/i,
      endpoint,
      "/webhooks/\\g{1}/_token/"
    )
  end

  defp schedule_bucket_cleanup do
    Process.send_after(self(), :remove_old_buckets, @bucket_cleanup_interval)
  end
end