Skip to main content

lib/jido/chat/state_adapter.ex

defmodule Jido.Chat.StateAdapter do
  @moduledoc """
  Behaviour and helpers for pluggable chat state storage.

  State adapters own subscriptions, dedupe windows, per-thread / per-channel
  state maps, and lock bookkeeping. The default adapter keeps everything in
  memory, but adapters may persist state elsewhere as long as they can expose a
  normalized snapshot.

  The public functions in this module check their arguments with guards and
  then dispatch to an adapter module implementing the callbacks below.
  Snapshots flowing through `init/3` and `snapshot/2` are passed through
  `normalize_snapshot/1`, which accepts atom- or string-keyed maps (including
  structs) and list or `MapSet` collections, and always returns the canonical
  shape described by `t:snapshot/0`.
  """

  @type dedupe_key :: {atom(), String.t()}

  @type snapshot :: %{
          subscriptions: MapSet.t(String.t()),
          dedupe: MapSet.t(dedupe_key()),
          dedupe_order: [dedupe_key()],
          thread_state: %{optional(String.t()) => map()},
          channel_state: %{optional(String.t()) => map()},
          locks: %{optional(String.t()) => map()},
          pending_locks: %{optional(String.t()) => [map()]}
        }

  @type state :: term()
  @type lock_result :: :acquired | :queued | :debounced | :busy
  @type release_result :: {:released, [map()]} | {:error, :not_owner}

  @callback init(snapshot(), keyword()) :: state()
  @callback snapshot(state()) :: snapshot() | map()
  @callback subscribed?(state(), String.t()) :: boolean()
  @callback subscribe(state(), String.t()) :: state()
  @callback unsubscribe(state(), String.t()) :: state()
  @callback thread_state(state(), String.t()) :: map()
  @callback put_thread_state(state(), String.t(), map()) :: state()
  @callback channel_state(state(), String.t()) :: map()
  @callback put_channel_state(state(), String.t(), map()) :: state()
  @callback duplicate?(state(), dedupe_key()) :: boolean()
  @callback mark_dedupe(state(), dedupe_key(), pos_integer()) :: state()
  @callback lock(state(), String.t(), String.t(), atom(), map()) :: {lock_result(), state()}
  @callback release_lock(state(), String.t(), String.t()) :: {release_result(), state()}
  @callback force_release_lock(state(), String.t()) :: {{:released, [map()]}, state()}

  @dialyzer {:nowarn_function, default_snapshot: 0}

  @doc "Initializes adapter state from a normalized snapshot."
  @spec init(module(), map(), keyword()) :: state()
  def init(adapter_module, snapshot, opts \\ []) when is_atom(adapter_module) do
    adapter_module.init(normalize_snapshot(snapshot), opts)
  end

  @doc "Returns a normalized snapshot for adapter-managed state."
  @spec snapshot(module(), state()) :: snapshot()
  def snapshot(adapter_module, state) when is_atom(adapter_module) do
    state
    |> adapter_module.snapshot()
    |> normalize_snapshot()
  end

  @doc "Returns true when the thread is subscribed in adapter-managed state."
  @spec subscribed?(module(), state(), String.t()) :: boolean()
  def subscribed?(adapter_module, state, thread_id)
      when is_atom(adapter_module) and is_binary(thread_id) do
    adapter_module.subscribed?(state, thread_id)
  end

  @doc "Adds a subscribed thread id to adapter-managed state."
  @spec subscribe(module(), state(), String.t()) :: state()
  def subscribe(adapter_module, state, thread_id)
      when is_atom(adapter_module) and is_binary(thread_id) do
    adapter_module.subscribe(state, thread_id)
  end

  @doc "Removes a subscribed thread id from adapter-managed state."
  @spec unsubscribe(module(), state(), String.t()) :: state()
  def unsubscribe(adapter_module, state, thread_id)
      when is_atom(adapter_module) and is_binary(thread_id) do
    adapter_module.unsubscribe(state, thread_id)
  end

  @doc "Returns thread state map from adapter-managed state."
  @spec thread_state(module(), state(), String.t()) :: map()
  def thread_state(adapter_module, state, thread_id)
      when is_atom(adapter_module) and is_binary(thread_id) do
    adapter_module.thread_state(state, thread_id)
  end

  @doc "Writes thread state map into adapter-managed state."
  @spec put_thread_state(module(), state(), String.t(), map()) :: state()
  def put_thread_state(adapter_module, state, thread_id, value)
      when is_atom(adapter_module) and is_binary(thread_id) and is_map(value) do
    adapter_module.put_thread_state(state, thread_id, value)
  end

  @doc "Returns channel state map from adapter-managed state."
  @spec channel_state(module(), state(), String.t()) :: map()
  def channel_state(adapter_module, state, channel_id)
      when is_atom(adapter_module) and is_binary(channel_id) do
    adapter_module.channel_state(state, channel_id)
  end

  @doc "Writes channel state map into adapter-managed state."
  @spec put_channel_state(module(), state(), String.t(), map()) :: state()
  def put_channel_state(adapter_module, state, channel_id, value)
      when is_atom(adapter_module) and is_binary(channel_id) and is_map(value) do
    adapter_module.put_channel_state(state, channel_id, value)
  end

  @doc "Returns true when a message dedupe key has already been seen."
  @spec duplicate?(module(), state(), dedupe_key()) :: boolean()
  def duplicate?(adapter_module, state, key)
      when is_atom(adapter_module) and is_tuple(key) do
    adapter_module.duplicate?(state, key)
  end

  @doc "Records a new dedupe key and trims state to the requested limit."
  @spec mark_dedupe(module(), state(), dedupe_key(), pos_integer()) :: state()
  def mark_dedupe(adapter_module, state, key, limit)
      when is_atom(adapter_module) and is_tuple(key) and is_integer(limit) and limit > 0 do
    adapter_module.mark_dedupe(state, key, limit)
  end

  @doc "Attempts to acquire a concurrency lock for the given key and owner."
  @spec lock(module(), state(), String.t(), String.t(), atom(), map()) :: {lock_result(), state()}
  def lock(adapter_module, state, key, owner, strategy, metadata \\ %{})
      when is_atom(adapter_module) and is_binary(key) and is_binary(owner) and is_atom(strategy) and
             is_map(metadata) do
    adapter_module.lock(state, key, owner, strategy, metadata)
  end

  @doc "Releases a held lock and returns any queued/debounced pending entries."
  @spec release_lock(module(), state(), String.t(), String.t()) :: {release_result(), state()}
  def release_lock(adapter_module, state, key, owner)
      when is_atom(adapter_module) and is_binary(key) and is_binary(owner) do
    adapter_module.release_lock(state, key, owner)
  end

  @doc "Force-releases a lock regardless of owner and returns pending entries."
  @spec force_release_lock(module(), state(), String.t()) :: {{:released, [map()]}, state()}
  def force_release_lock(adapter_module, state, key)
      when is_atom(adapter_module) and is_binary(key) do
    adapter_module.force_release_lock(state, key)
  end

  @doc "Returns the canonical empty snapshot."
  @spec default_snapshot() :: snapshot()
  def default_snapshot do
    %{
      subscriptions: MapSet.new(),
      dedupe: MapSet.new(),
      dedupe_order: [],
      thread_state: %{},
      channel_state: %{},
      locks: %{},
      pending_locks: %{}
    }
  end

  @doc """
  Normalizes maps, lists, and map-sets into the canonical state snapshot shape.

  Each field is read from an atom key first and then from the equivalent string
  key. Missing or malformed fields fall back to the values in
  `default_snapshot/0`. Structs are accepted and read like plain maps. Any
  non-map input yields `default_snapshot/0`.
  """
  @spec normalize_snapshot(term()) :: snapshot()
  def normalize_snapshot(snapshot) when is_map(snapshot) do
    defaults = default_snapshot()

    %{
      subscriptions: snapshot_field(snapshot, :subscriptions) || defaults.subscriptions,
      dedupe: snapshot_field(snapshot, :dedupe) || defaults.dedupe,
      dedupe_order: snapshot_field(snapshot, :dedupe_order) || defaults.dedupe_order,
      thread_state: snapshot_field(snapshot, :thread_state) || defaults.thread_state,
      channel_state: snapshot_field(snapshot, :channel_state) || defaults.channel_state,
      locks: snapshot_field(snapshot, :locks) || defaults.locks,
      pending_locks: snapshot_field(snapshot, :pending_locks) || defaults.pending_locks
    }
    |> normalize_subscriptions()
    |> normalize_dedupe()
    |> normalize_dedupe_order()
    |> normalize_thread_state()
    |> normalize_channel_state()
    |> normalize_locks()
    |> normalize_pending_locks()
  end

  def normalize_snapshot(_snapshot), do: default_snapshot()

  # Reads `field` from an atom key, falling back to its string form (e.g. from a
  # JSON-decoded snapshot). Uses `Map.get/2` rather than `snapshot[field]`
  # because structs do not implement `Access`; bracket access would raise for an
  # adapter whose `snapshot/1` returns a struct.
  defp snapshot_field(snapshot, field) when is_atom(field) do
    Map.get(snapshot, field) || Map.get(snapshot, Atom.to_string(field))
  end

  # MapSets are kept as-is; lists are converted element-wise to strings.
  defp normalize_subscriptions(snapshot) do
    subscriptions =
      case snapshot.subscriptions do
        %MapSet{} = subscriptions ->
          subscriptions

        subscriptions when is_list(subscriptions) ->
          subscriptions
          |> Enum.map(&to_string/1)
          |> MapSet.new()

        _ ->
          MapSet.new()
      end

    %{snapshot | subscriptions: subscriptions}
  end

  defp normalize_dedupe(snapshot) do
    dedupe =
      case snapshot.dedupe do
        %MapSet{} = dedupe ->
          dedupe

        entries when is_list(entries) ->
          entries
          |> normalize_dedupe_entries()
          |> MapSet.new()

        _ ->
          MapSet.new()
      end

    %{snapshot | dedupe: dedupe}
  end

  defp normalize_dedupe_order(snapshot) do
    dedupe_order =
      case snapshot.dedupe_order do
        entries when is_list(entries) -> normalize_dedupe_entries(entries)
        _ -> []
      end

    %{snapshot | dedupe_order: dedupe_order}
  end

  # Converts `[adapter, message_id]` lists (serialized form) and
  # `{adapter, message_id}` tuples into dedupe keys, preserving input order and
  # silently dropping entries of any other shape.
  defp normalize_dedupe_entries(entries) do
    Enum.flat_map(entries, fn
      [adapter_name, message_id] -> [build_dedupe_key(adapter_name, message_id)]
      {adapter_name, message_id} -> [build_dedupe_key(adapter_name, message_id)]
      _other -> []
    end)
  end

  defp build_dedupe_key(adapter_name, message_id) do
    {normalize_key_atom(adapter_name), to_string(message_id)}
  end

  defp normalize_thread_state(snapshot) do
    %{snapshot | thread_state: map_or_empty(snapshot.thread_state)}
  end

  defp normalize_channel_state(snapshot) do
    %{snapshot | channel_state: map_or_empty(snapshot.channel_state)}
  end

  defp normalize_locks(snapshot) do
    %{snapshot | locks: map_or_empty(snapshot.locks)}
  end

  defp map_or_empty(value) when is_map(value), do: value
  defp map_or_empty(_value), do: %{}

  # Stringifies lock keys and keeps only map entries in each pending list;
  # a non-list pending value becomes an empty list.
  defp normalize_pending_locks(snapshot) do
    pending_locks =
      case snapshot.pending_locks do
        pending when is_map(pending) ->
          Map.new(pending, fn {key, entries} ->
            {to_string(key), normalize_pending_entries(entries)}
          end)

        _ ->
          %{}
      end

    %{snapshot | pending_locks: pending_locks}
  end

  defp normalize_pending_entries(entries) when is_list(entries), do: Enum.filter(entries, &is_map/1)
  defp normalize_pending_entries(_entries), do: []

  # Restores adapter names to atoms so string-keyed (serialized) entries compare
  # equal to live `{atom, id}` dedupe keys; non-string values pass through.
  # NOTE(review): `String.to_atom/1` creates atoms that are never garbage
  # collected. This is only safe while adapter names come from a small, trusted
  # set — consider `String.to_existing_atom/1` if snapshots can carry untrusted data.
  defp normalize_key_atom(key) when is_atom(key), do: key
  defp normalize_key_atom(key) when is_binary(key), do: String.to_atom(key)
  defp normalize_key_atom(key), do: key
end