lib/slipstream/socket.ex

defmodule Slipstream.Socket do
  @moduledoc """
  A data structure representing a potential websocket client connection.

  This structure closely resembles `t:Phoenix.Socket.t/0`, but is not
  compatible with its functions. All documented functions from this module
  are imported by `use Slipstream`.
  """

  # this module defines its own `send/2` (which calls `Kernel.send/2`
  # explicitly), so the Kernel version must not be imported alongside it
  import Kernel, except: [send: 2, pid: 1]

  alias Slipstream.{TelemetryHelper, Socket.Join}
  alias Slipstream.Events

  # `@derive Inspect` with `:only` is gated on the running Elixir version;
  # when derived, inspecting a socket shows only its assigns
  if Version.match?(System.version(), ">= 1.8.0") do
    @derive {Inspect, only: [:assigns]}
  end

  defstruct [
    :channel_pid,
    :socket_pid,
    :channel_config,
    :response_headers,
    metadata: %{},
    reconnect_counter: 0,
    joins: %{},
    assigns: %{}
  ]

  @typedoc """
  A socket data structure representing a potential websocket client connection
  """
  @typedoc since: "0.1.0"
  @type t :: %__MODULE__{
          channel_pid: nil | pid(),
          socket_pid: pid(),
          channel_config: Slipstream.Configuration.t() | nil,
          # NOTE(review): set outside this module (defaults to nil); the exact
          # shape is presumably the websocket upgrade response headers - confirm
          response_headers: term(),
          metadata: %{atom() => String.t() | %{String.t() => String.t()}},
          reconnect_counter: non_neg_integer(),
          assigns: map(),
          joins: %{String.t() => %Join{}}
        }

  # Builds a fresh socket owned by the calling process. The socket's telemetry
  # metadata gets a `socket_id` from `TelemetryHelper.trace_id/0`.
  @doc false
  @spec new() :: t()
  def new do
    %__MODULE__{
      socket_pid: self(),
      metadata: %{
        socket_id: TelemetryHelper.trace_id(),
        joins: %{}
      }
    }
  end

  @doc """
  Adds key-value pairs to socket assigns

  Behaves the same as `Phoenix.Socket.assign/3`

  ## Examples

      iex> assign(socket, :key, :value)
      iex> assign(socket, key: :value)
  """
  # and indeed the implementation is just about the same as well.
  # we can't defdelegate/2 though because the socket module is different
  # (hence the struct doesn't match)
  @doc since: "0.1.0"
  @spec assign(t(), key :: atom(), value :: any()) :: t()
  def assign(%__MODULE__{} = socket, key, value) when is_atom(key) do
    assign(socket, [{key, value}])
  end

  @doc """
  Adds every key-value pair in `attrs` (a keyword list or a map) to socket
  assigns, overwriting existing keys

  See `assign/3`.
  """
  @spec assign(t(), Keyword.t() | map()) :: t()
  def assign(%__MODULE__{} = socket, attrs)
      when is_list(attrs) or is_map(attrs) do
    %__MODULE__{socket | assigns: Map.merge(socket.assigns, Map.new(attrs))}
  end

  @doc """
  Updates an existing key in the socket assigns

  Raises a `KeyError` if the key is not present in `socket.assigns`.

  `func` should be a 1-arity function which takes the existing value at assign
  `key` and updates it to a new value. The new value will take the old value's
  place in `socket.assigns[key]`.

  This function is a useful alternative to `assign/3` when the key is already
  present in assigns and is a list, map, or similarly malleable data structure.

  ## Examples

      @impl Slipstream
      def handle_cast({:join, topic}, socket) do
        socket =
          socket
          |> update(:topics, &[topic | &1])
          |> join(topic)

        {:noreply, socket}
      end

      @impl Slipstream
      def handle_call({:join, topic}, from, socket) do
        socket =
          socket
          |> update(:join_requests, &Map.put(&1, topic, from))
          |> join(topic)

        # note: not replying here so we can provide a synchronous call to a
        # topic being joined
        {:noreply, socket}
      end

      @impl Slipstream
      def handle_join(topic, response, socket) do
        case Map.fetch(socket.assigns.join_requests, topic) do
          {:ok, from} -> GenServer.reply(from, {:ok, response})
          :error -> :ok
        end

        {:ok, socket}
      end
  """
  # again, can't defdelegate/2 because of the socket module being different
  # but see the `Phoenix.LiveView.update/3` implementation for the original
  # source
  @doc since: "0.5.0"
  @spec update(t(), key :: atom(), func :: (value :: any() -> value :: any())) ::
          t()
  def update(%__MODULE__{assigns: assigns} = socket, key, func)
      when is_atom(key) and is_function(func, 1) do
    case Map.fetch(assigns, key) do
      {:ok, value} -> assign(socket, [{key, func.(value)}])
      :error -> raise KeyError, key: key, term: assigns
    end
  end

  @doc """
  Checks if a channel is currently joined

  ## Examples

      iex> joined?(socket, "room:lobby")
      true
  """
  @doc since: "0.1.0"
  @spec joined?(t(), topic :: String.t()) :: boolean()
  def joined?(%__MODULE__{} = socket, topic) when is_binary(topic) do
    join_status(socket, topic) == :joined
  end

  @doc """
  Checks the status of a join request

  When a join is requested with `Slipstream.join/3`, the join request is
  considered to be in the `:requested` state. Once the topic is successfully
  joined, it is considered `:joined` until closed. If there is a failure to
  join the topic, if the topic crashes, or if the topic is left after being
  joined, the status of the join is considered `:closed`. Finally, if a topic
  has not been requested in a join so far for a socket, the status is `nil`.

  Notably, the status of a join will not automatically change to `:joined` once
  the remote server replies with a successful join. Either the join must be
  awaited with `Slipstream.await_join/2` or the status may be checked later
  in the `c:Slipstream.handle_join/3` callback.

  ## Examples

      iex> socket = join(socket, "room:lobby")
      iex> join_status(socket, "room:lobby")
      :requested
      iex> {:ok, socket, _join_response} = await_join(socket, "room:lobby")
      iex> join_status(socket, "room:lobby")
      :joined
  """
  @doc since: "0.1.0"
  @spec join_status(t(), topic :: String.t()) ::
          :requested | :joined | :closed | nil
  def join_status(%__MODULE__{} = socket, topic) when is_binary(topic) do
    case Map.fetch(socket.joins, topic) do
      {:ok, %Join{status: status}} -> status
      :error -> nil
    end
  end

  @doc """
  Checks if a socket is connected to a remote websocket host

  ## Examples

      iex> socket = connect(socket, uri: "ws://example.org")
      iex> socket = await_connect!(socket)
      iex> connected?(socket)
      true
  """
  @doc since: "0.1.0"
  @spec connected?(t()) :: boolean()
  def connected?(%__MODULE__{} = socket),
    do: socket |> channel_pid() |> is_pid()

  @doc """
  Gets the process ID of the connection

  The slipstream implementor module is not the same process as the GenServer
  which interfaces with the remote server for websocket communication. This
  other process, the Slipstream.Connection process, interfaces with the
  low-level WebSocket connection and communicates with the implementor module
  by passing messages (mostly with `Kernel.send/2`).

  It can be useful to have access to this pid for testing or debugging
  purposes, such as sending a fake disconnect message or for getting state
  with `:sys.get_state/1`.

  ## Examples

      iex> Slipstream.Socket.channel_pid(socket)
      #PID<0.1.2>
  """
  @doc since: "0.1.0"
  @spec channel_pid(t()) :: pid() | nil
  def channel_pid(%__MODULE__{channel_pid: pid}) do
    # a stored pid whose process has exited is reported as `nil`
    if is_pid(pid) and Process.alive?(pid), do: pid, else: nil
  end

  ## helper functions for implementing Slipstream

  # Sends `message` to the connection process when it is alive; otherwise the
  # message is silently dropped. Always returns the socket unchanged so it can
  # be used in pipelines.
  @doc false
  @spec send(t(), term()) :: t()
  def send(%__MODULE__{} = socket, message) do
    if pid = channel_pid(socket), do: Kernel.send(pid, message)

    socket
  end

  # Synchronously calls the connection process via `GenServer.call/3`, wrapping
  # the reply in `{:ok, reply}`. Returns `{:error, :not_connected}` when there
  # is no live connection process. Exits as `GenServer.call/3` does (e.g. on
  # timeout).
  @doc false
  @spec call(t(), term(), timeout()) :: {:ok, term()} | {:error, :not_connected}
  def call(%__MODULE__{} = socket, message, timeout) do
    if pid = channel_pid(socket) do
      {:ok, GenServer.call(pid, message, timeout)}
    else
      {:error, :not_connected}
    end
  end

  # Records a join request for `topic`. `Map.put_new/3` means an existing entry
  # for the topic (including its status and rejoin counter) is kept as-is and
  # the new `params` are not stored.
  @doc false
  @spec put_join_config(t(), String.t(), term()) :: t()
  def put_join_config(%__MODULE__{} = socket, topic, params) do
    join = Join.new(topic, params)

    %__MODULE__{socket | joins: Map.put_new(socket.joins, topic, join)}
  end

  # potentially changes a socket by applying an event to it; events without a
  # matching clause leave the socket unchanged
  @doc false
  @spec apply_event(t(), struct()) :: t()
  def apply_event(socket, event)

  def apply_event(socket, %Events.ChannelConnected{} = event) do
    socket = TelemetryHelper.conclude_connect(socket, event)

    %__MODULE__{
      socket
      | channel_pid: event.pid,
        # keep the previous config when the event does not carry one
        channel_config: event.config || socket.channel_config,
        # a successful connection resets the reconnect backoff
        reconnect_counter: 0
    }
  end

  def apply_event(socket, %Events.TopicJoinSucceeded{topic: topic} = event) do
    socket
    |> TelemetryHelper.conclude_join(event)
    |> put_in([Access.key(:joins), topic, Access.key(:status)], :joined)
    # a successful join resets this topic's rejoin backoff
    |> put_in([Access.key(:joins), topic, Access.key(:rejoin_counter)], 0)
  end

  def apply_event(socket, %event{topic: topic})
      when event in [
             Events.TopicLeft,
             Events.TopicJoinFailed,
             Events.TopicJoinClosed
           ] do
    put_in(socket, [Access.key(:joins), topic, Access.key(:status)], :closed)
  end

  def apply_event(socket, %Events.ChannelClosed{}) do
    # losing the connection marks every known join as closed
    closed_joins =
      Map.new(socket.joins, fn {topic, join} ->
        {topic, %Join{join | status: :closed}}
      end)

    %__MODULE__{socket | channel_pid: nil, joins: closed_joins}
  end

  def apply_event(socket, _event), do: socket

  # Returns the backoff (from `channel_config.reconnect_after_msec`) for the
  # current reconnect attempt, along with the socket whose reconnect counter
  # has been incremented. Requires `channel_config` to be set.
  @doc false
  @spec next_reconnect_time(t()) :: {non_neg_integer(), t()}
  def next_reconnect_time(%__MODULE__{reconnect_counter: attempt} = socket) do
    socket = %__MODULE__{socket | reconnect_counter: attempt + 1}

    time = retry_time(socket.channel_config.reconnect_after_msec, attempt)

    {time, socket}
  end

  # Returns the backoff (from `channel_config.rejoin_after_msec`) for the
  # current rejoin attempt of `topic`, along with the socket whose rejoin
  # counter for that topic has been incremented. The topic must already be
  # present in `socket.joins`.
  @doc false
  @spec next_rejoin_time(t(), String.t()) :: {non_neg_integer(), t()}
  def next_rejoin_time(%__MODULE__{} = socket, topic) do
    socket =
      update_in(
        socket,
        [Access.key(:joins), topic, Access.key(:rejoin_counter)],
        &(&1 + 1)
      )

    time =
      retry_time(
        socket.channel_config.rejoin_after_msec,
        socket.joins[topic].rejoin_counter - 1
      )

    {time, socket}
  end

  # Picks the backoff for 0-based attempt `try_number`. Once the attempts run
  # past the end of the list, the last value in the list is repeated.
  # NOTE(review): an empty `backoff_times` yields `nil`, which does not fit the
  # callers' specs - presumably configuration validation prevents it; confirm.
  defp retry_time(backoff_times, try_number) do
    default = Enum.at(backoff_times, -1)

    Enum.at(backoff_times, try_number, default)
  end
end