# lib/council_ex/providers/mock.ex

defmodule CouncilEx.Providers.Mock do
  @moduledoc """
  In-process mock provider for tests and demos.

  Scripts are stored in a shared, named ETS table keyed by
  `{owner_pid, member_id}`. When `complete/2` or `stream/3` runs in a spawned
  Task (e.g. inside `Task.async_stream`), the lookup walks the `$callers`
  ancestor chain to find the originating test process and reads the scripts
  registered by that process.

  This makes `Mock` safe to use under `async: true` across the Runner's
  parallel member execution: each test owns its own script set, and child
  tasks discover the right set via the caller chain. The table itself is owned
  by a dedicated long-lived process, so it is not destroyed when the test
  process that happened to create it exits.

  ## Usage

      Mock.script(:my_member, content: "fake answer")
      Mock.script(:other_member, fn req -> {:ok, Response.new(content: req.model, model: "mock")} end)
      Mock.script(:flaky_member, error: :timeout)
      Mock.script(:flaky_member, error: [kind: :transient, reason: :rate_limit])

      # A non-empty list of keyword lists is a sequence: each call consumes one entry.
      Mock.script(:retry_member, [[error: :timeout], [content: "second try"]])

      # `stream/3` emits these chunks one by one; `complete/2` returns them joined.
      Mock.script(:streamer, stream: ["Hel", "lo"], chunk_delay_ms: 5)

  ## Notes

  - In the keyword form the first present key among `:tool_calls`,
    `:respond_input`, `:error`, `:stream` and `:content` (in that order) decides
    the script kind — in particular `:error` takes precedence over `:content`.
  - A fully consumed sequence returns
    `{:error, %CouncilEx.Error{kind: :permanent, reason: {:script_exhausted, member_id}}}`.
    Consumption is recorded against the process that registered the script, so
    a sequence advances across every task spawned by that process.
  - `Request.member_id` must not be `nil`. Mock returns
    `{:error, %CouncilEx.Error{kind: :permanent, reason: :missing_member_id}}` in that case.
  - All errors are returned as `%CouncilEx.Error{}` structs. Legacy non-keyword
    errors (e.g. `error: :boom`) are auto-classified as `:permanent`.
  """

  @behaviour CouncilEx.Provider

  require Logger

  alias CouncilEx.{Request, Response}

  @table __MODULE__

  @doc """
  Returns the scripted result for `req.member_id`.

  Sequence scripts consume one entry per call; every other script kind is
  replayed on each call. When nothing is registered for the member anywhere in
  the caller chain, logs a warning and returns a `%CouncilEx.Error{}` with
  reason `{:no_script, member_id}`.
  """
  @impl CouncilEx.Provider
  def complete(%Request{member_id: nil}, _opts), do: missing_member_id()

  def complete(%Request{member_id: member_id} = req, _opts) do
    case fetch_script(member_id) do
      {:ok, owner, {:sequence, entries}} ->
        advance_sequence(owner, member_id, entries, &run_entry(&1, req))

      {:ok, _owner, entry} ->
        run_entry(entry, req)

      :missing ->
        missing_script(member_id)
    end
  end

  @doc """
  Like `complete/2`, but also delivers output through `sink` as
  `%CouncilEx.StreamChunk{}` structs.

  `stream:` scripts emit one chunk per element (sleeping `chunk_delay_ms` after
  each when positive) followed by an empty `:stop` chunk. `content:` scripts
  emit a single `:stop` chunk carrying the whole content. Function, error,
  tool-call and `respond_input` scripts emit no chunks and return the same
  result as `complete/2`.
  """
  @impl CouncilEx.Provider
  def stream(%Request{member_id: nil}, _opts, _sink), do: missing_member_id()

  def stream(%Request{member_id: member_id} = req, _opts, sink) do
    case fetch_script(member_id) do
      {:ok, owner, {:sequence, entries}} ->
        advance_sequence(owner, member_id, entries, &stream_entry(&1, req, sink))

      {:ok, _owner, entry} ->
        stream_entry(entry, req, sink)

      :missing ->
        missing_script(member_id)
    end
  end

  defp missing_member_id,
    do: {:error, %CouncilEx.Error{kind: :permanent, reason: :missing_member_id}}

  defp missing_script(member_id) do
    warn_missing_script(member_id)
    {:error, %CouncilEx.Error{kind: :permanent, reason: {:no_script, member_id}}}
  end

  # Logs which keys *this* caller chain has registered, so a term-type mismatch
  # (atom vs string) is easy to spot. Keys registered by unrelated processes
  # (e.g. other async tests) are deliberately left out.
  defp warn_missing_script(member_id) do
    chain = owner_chain(self())

    Logger.warning(fn ->
      registered =
        for pid <- chain,
            [id] <- :ets.match(@table, {{pid, :"$1"}, :_}),
            uniq: true,
            do: id

      "CouncilEx.Providers.Mock: no script for #{inspect(member_id)} " <>
        "(type=#{type_name(member_id)}). Registered keys: #{inspect(registered)}. " <>
        "Mock matches member_id by exact term equality — atom vs string mismatches " <>
        "are a common cause. Ensure Mock.script/2 uses the same term type as " <>
        "the council's member id."
    end)
  end

  defp type_name(v) when is_atom(v), do: :atom
  defp type_name(v) when is_binary(v), do: :string
  defp type_name(_), do: :other

  # Consumes the head of a sequence and runs it with `runner`. The remaining
  # entries (or an `:exhausted` marker) are written back under `owner` — the
  # process that registered the script, not necessarily `self()` — so the
  # sequence advances even when successive calls come from different tasks.
  # The write happens before the entry runs to narrow (not close) the window in
  # which a concurrent caller could replay the same entry.
  defp advance_sequence(owner, member_id, [entry | rest], runner) do
    next = if rest == [], do: {:exhausted, member_id}, else: {:sequence, rest}
    put(owner, member_id, next)
    runner.(entry)
  end

  defp advance_sequence(_owner, _member_id, [], _runner) do
    {:error, %CouncilEx.Error{kind: :permanent, reason: :empty_sequence}}
  end

  # Streams a single stored entry through `sink`. Entries without streaming
  # semantics fall back to `run_entry/2` and emit no chunks.
  defp stream_entry({:stream_attrs, attrs}, req, sink) do
    chunks = Keyword.fetch!(attrs, :stream)
    delay = Keyword.get(attrs, :chunk_delay_ms, 0)
    emit_chunks(chunks, delay, sink, req)
  end

  defp stream_entry({:content, content}, req, sink) do
    sink.(%CouncilEx.StreamChunk{content: content, index: 0, finish_reason: :stop})
    {:ok, Response.new(content: content, model: req.model)}
  end

  defp stream_entry({:response_attrs, attrs}, req, sink) do
    response = build_response(attrs, req)

    sink.(%CouncilEx.StreamChunk{
      content: Keyword.fetch!(attrs, :content),
      index: 0,
      finish_reason: :stop
    })

    {:ok, response}
  end

  defp stream_entry(entry, req, _sink), do: run_entry(entry, req)

  # Sends each chunk to `sink` (sleeping `delay` ms after each when positive),
  # then a terminating empty chunk with `finish_reason: :stop`. The final
  # response content is the concatenation of all chunks.
  defp emit_chunks(chunks, delay, sink, req) do
    chunks
    |> Enum.with_index()
    |> Enum.each(fn {chunk, index} ->
      sink.(%CouncilEx.StreamChunk{content: chunk, index: index, finish_reason: nil})
      if delay > 0, do: Process.sleep(delay)
    end)

    sink.(%CouncilEx.StreamChunk{content: "", index: length(chunks), finish_reason: :stop})
    {:ok, Response.new(content: Enum.join(chunks), model: req.model)}
  end

  # Turns one stored script entry into a provider result.
  defp run_entry({:fun, fun}, req), do: normalize_result(fun.(req))

  defp run_entry({:response_attrs, attrs}, req), do: {:ok, build_response(attrs, req)}

  defp run_entry({:content, content}, req),
    do: {:ok, Response.new(content: content, model: req.model)}

  # A non-streaming call against a `stream:` script returns the joined chunks,
  # the same content `stream/3` would return.
  defp run_entry({:stream_attrs, attrs}, req) do
    content = attrs |> Keyword.fetch!(:stream) |> Enum.join()
    {:ok, Response.new(content: content, model: req.model)}
  end

  defp run_entry({:tool_calls_attrs, tool_calls}, req) do
    {:ok,
     Response.new(
       content: "",
       parsed: :tool_calls_pending,
       model: req.model,
       raw: %{tool_calls: tool_calls}
     )}
  end

  defp run_entry({:respond_input, input}, %Request{response_schema: nil}) do
    {:ok,
     %Response{
       content: Jason.encode!(input),
       parsed: input,
       model: "mock"
     }}
  end

  defp run_entry({:respond_input, input}, %Request{response_schema: schema}) do
    fields = schema.__schema__(:fields)
    types = Map.new(fields, fn f -> {f, schema.__schema__(:type, f)} end)

    # Schemaless cast over the schema's own field types; keys are stringified
    # so atom- and string-keyed inputs are cast the same way.
    changeset =
      Ecto.Changeset.cast(
        {struct(schema), types},
        Map.new(input, fn {k, v} -> {to_string(k), v} end),
        fields
      )

    if changeset.valid? do
      {:ok,
       %Response{
         content: Jason.encode!(input),
         parsed: Ecto.Changeset.apply_changes(changeset),
         model: "mock"
       }}
    else
      {:error, %CouncilEx.Error{kind: :validation, reason: changeset}}
    end
  end

  defp run_entry({:error, %CouncilEx.Error{} = err}, _req), do: {:error, err}

  defp run_entry({:error, reason}, _req),
    do: {:error, %CouncilEx.Error{kind: :permanent, reason: reason}}

  defp run_entry({:exhausted, member_id}, _req),
    do: {:error, %CouncilEx.Error{kind: :permanent, reason: {:script_exhausted, member_id}}}

  # Function scripts may return bare error reasons; wrap them like keyword
  # `error:` scripts. Any other return shape is a scripting bug and raises
  # FunctionClauseError.
  defp normalize_result({:ok, %Response{}} = ok), do: ok
  defp normalize_result({:error, %CouncilEx.Error{}} = err), do: err

  defp normalize_result({:error, reason}),
    do: {:error, %CouncilEx.Error{kind: :permanent, reason: reason}}

  defp build_response(attrs, req) do
    Response.new(
      content: Keyword.fetch!(attrs, :content),
      model: Keyword.get(attrs, :model, req.model),
      parsed: Keyword.get(attrs, :parsed),
      usage: Keyword.get(attrs, :usage, %{input_tokens: 0, output_tokens: 0}),
      raw: Keyword.get(attrs, :raw),
      latency_ms: Keyword.get(attrs, :latency_ms)
    )
  end

  @doc """
  Registers a scripted response for `member_id` under the calling process.

  `script` may be:

    * a 1-arity function receiving the `Request` and returning
      `{:ok, Response.t()}` or `{:error, reason}`;
    * a keyword list describing one response, replayed on every call; or
    * a non-empty list of keyword lists describing a sequence, one entry
      consumed per call.

  Overwrites any script this process previously registered for the same id.
  Raises `ArgumentError` when a keyword script has none of the supported keys,
  and `KeyError` when a keyword `error:` list lacks `:kind`.
  """
  @spec script(
          term(),
          (Request.t() -> {:ok, Response.t()} | {:error, term()}) | keyword() | [keyword()]
        ) :: :ok
  def script(member_id, fun) when is_function(fun, 1) do
    put(self(), member_id, {:fun, fun})
  end

  def script(member_id, scripts) when is_list(scripts) do
    entry =
      if scripts != [] and Enum.all?(scripts, &is_list/1) do
        {:sequence,
         Enum.map(scripts, &compile_entry(&1, "Mock script entry needs :error or :content"))}
      else
        compile_keyword_script(scripts)
      end

    put(self(), member_id, entry)
  end

  # Keyword form: identical to a sequence entry, except `tool_calls:` expands
  # to a two-step sequence (tool calls, then the `:then` content).
  defp compile_keyword_script(opts) do
    if Keyword.has_key?(opts, :tool_calls) do
      {:sequence,
       [
         {:tool_calls_attrs, Keyword.fetch!(opts, :tool_calls)},
         {:content, Keyword.get(opts, :then, "")}
       ]}
    else
      compile_entry(opts, "Mock.script/2 needs :content or :error or a 1-arity fun")
    end
  end

  # Converts one keyword script into its stored entry; the cond order defines
  # key precedence. `missing_message` is raised when no known key is present.
  defp compile_entry(opts, missing_message) when is_list(opts) do
    cond do
      Keyword.has_key?(opts, :tool_calls) ->
        {:tool_calls_attrs, Keyword.fetch!(opts, :tool_calls)}

      Keyword.has_key?(opts, :respond_input) ->
        {:respond_input, Keyword.fetch!(opts, :respond_input)}

      Keyword.has_key?(opts, :error) ->
        {:error, compile_error(Keyword.fetch!(opts, :error))}

      Keyword.has_key?(opts, :stream) ->
        {:stream_attrs, opts}

      Keyword.has_key?(opts, :content) ->
        {:response_attrs, opts}

      true ->
        raise ArgumentError, missing_message
    end
  end

  # `error: [kind: ..., reason: ..., context: ...]` builds a full error struct;
  # any other term is kept as a bare reason and classified `:permanent` later.
  defp compile_error(err_opts) when is_list(err_opts) do
    %CouncilEx.Error{
      kind: Keyword.fetch!(err_opts, :kind),
      reason: Keyword.get(err_opts, :reason),
      context: Keyword.get(err_opts, :context, %{})
    }
  end

  defp compile_error(term), do: term

  @doc "Clear all scripts for the current owner process."
  @spec reset() :: :ok
  def reset do
    ensure_table()
    :ets.match_delete(@table, {{self(), :_}, :_})
    :ok
  end

  # Stores `entry` for `member_id` under `owner`, replacing any previous entry.
  defp put(owner, member_id, entry) do
    ensure_table()
    :ets.insert(@table, {{owner, member_id}, entry})
    :ok
  end

  # Looks up `member_id` for the calling process, then for each pid in its
  # `$callers` list. Returns `{:ok, owner_pid, entry}` for the first hit so
  # callers can write updates back to the owning key, or `:missing`.
  defp fetch_script(member_id) do
    ensure_table()

    self()
    |> owner_chain()
    |> Enum.find_value(:missing, fn pid ->
      case :ets.lookup(@table, {pid, member_id}) do
        [{_key, entry}] -> {:ok, pid, entry}
        [] -> nil
      end
    end)
  end

  # `pid` followed by the pids in the current process's `$callers` entry.
  defp owner_chain(pid), do: [pid | Process.get(:"$callers", [])]

  defp ensure_table do
    if :ets.whereis(@table) == :undefined, do: start_table_owner()
    :ok
  end

  # ETS tables are destroyed when their owner exits. Creating the table from a
  # test process would wipe every other (async) test's scripts as soon as that
  # test finished, so the table is created by a dedicated, unlinked process
  # that stays alive for the lifetime of the VM.
  defp start_table_owner do
    caller = self()
    ref = make_ref()

    {pid, monitor} =
      spawn_monitor(fn ->
        case create_table() do
          :created ->
            send(caller, {ref, :ready})
            Process.sleep(:infinity)

          :exists ->
            send(caller, {ref, :ready})
        end
      end)

    receive do
      {^ref, :ready} -> Process.demonitor(monitor, [:flush])
      {:DOWN, ^monitor, :process, ^pid, _reason} -> :ok
    end

    :ok
  end

  defp create_table do
    :ets.new(@table, [:set, :public, :named_table, read_concurrency: true])
    :created
  rescue
    # Another process created the named table after our whereis/1 check.
    ArgumentError -> :exists
  end
end