# lib/skill_kit/test.ex

if Mix.env() == :test do
  defmodule SkillKit.Test do
    @moduledoc """
    Test helpers for SkillKit.

    Provides Mox convenience helpers for testing agents and LLM interactions.
    Provider-specific event builders live in their own modules
    (e.g., `Anthropic.Test`).

    ## Setup

        use SkillKit.Test

    This imports `SkillKit.Test` and sets up `Mox.verify_on_exit!/1`.
    """

    alias SkillKit.Agent
    alias SkillKit.Agent.Server
    alias SkillKit.Agent.ToolRunner
    alias SkillKit.Response.Error

    @doc """
    Injects the SkillKit test helpers into the calling test module.

    Imports `SkillKit.Test` and registers a `setup` callback that calls
    `Mox.verify_on_exit!/1` with the test context.
    """
    defmacro __using__(_opts) do
      quote do
        import SkillKit.Test

        # Call Mox by its fully qualified name. The bare `setup :verify_on_exit!`
        # form is resolved as a local call in the using module and therefore only
        # compiles when that module has separately imported Mox.
        setup context do
          Mox.verify_on_exit!(context)
        end
      end
    end

    @doc """
    Starts a bare Server process for unit testing.

    A uniquely named `Registry` and a `SkillKit.Catalog` are started with
    `ExUnit.Callbacks.start_supervised!/1`; the `ToolRunner` and the `Server`
    are started with `start_link/1` from the calling process. The calling
    process then allows the server to use `SkillKit.LLM.Mock` expectations.

    ## Options

      * `:agent_name` - name of the generated agent (defaults to a unique
        `"test-agent-N"` string)
      * `:caller` - defaults to `self()`
      * `:scope` - defaults to `nil`
      * `:tools` - defaults to `[]`
      * `:skills` - defaults to `[]`
      * `:agent` - a prebuilt `%SkillKit.Agent{}` used instead of the generated one

    Returns `{:ok, server_pid, context}` where context contains `:registry`,
    `:agent_name`, and `:agent`.
    """
    @spec start_server(keyword()) :: {:ok, pid(), map()}
    def start_server(opts \\ []) do
      agent_name =
        Keyword.get(opts, :agent_name, "test-agent-#{:erlang.unique_integer([:positive])}")

      caller = Keyword.get(opts, :caller, self())
      scope = Keyword.get(opts, :scope)
      tools = Keyword.get(opts, :tools, [])
      skills = Keyword.get(opts, :skills, [])
      registry_name = :"test_registry_#{:erlang.unique_integer([:positive])}"

      ExUnit.Callbacks.start_supervised!({Registry, keys: :unique, name: registry_name})

      # Only build the default agent when the caller did not pass one.
      agent =
        Keyword.get_lazy(opts, :agent, fn ->
          %Agent{
            name: agent_name,
            description: "Test agent",
            system_prompt: "You are a test agent.",
            caller: caller,
            scope: scope,
            tools: tools,
            skills: skills,
            registry: registry_name
          }
        end)

      # NOTE(review): when `:agent` is supplied, the catalog is still configured
      # from the `:tools`/`:skills`/`:scope` options and registered in
      # `registry_name`, not from the agent's own fields - confirm this is intended.
      ExUnit.Callbacks.start_supervised!(
        {SkillKit.Catalog,
         name: {:via, Registry, {registry_name, {agent.name, :catalog}}},
         tools: tools,
         skills: skills,
         scope: scope}
      )

      # Match on the result (as done for the Server below) so that a ToolRunner
      # start failure surfaces here instead of as a confusing failure later on.
      {:ok, _tool_runner} = ToolRunner.start_link(agent)

      {:ok, pid} = Server.start_link(agent)

      # Let the server process consume expectations set by the calling process.
      Mox.allow(SkillKit.LLM.Mock, self(), pid)

      context = %{registry: registry_name, agent_name: agent.name, agent: agent}
      {:ok, pid, context}
    end

    @doc """
    Registers one expected `stream/2` call on `SkillKit.LLM.Mock`.

    The arguments of that call are ignored; the mock replies with the value
    built from `response`.
    """
    @spec expect_response(struct()) :: :ok
    def expect_response(response) do
      reply = fn _messages, _opts -> build_event_stream(response) end
      Mox.expect(SkillKit.LLM.Mock, :stream, 1, reply)
      :ok
    end

    @doc """
    Registers one expected `stream/2` call on `SkillKit.LLM.Mock` that first
    runs `assertion_fn` and then replies with the value built from `response`.

    `assertion_fn` is invoked as `assertion_fn.(messages, opts)` with the
    arguments the LLM mock was called with.
    """
    @spec assert_response(struct(), (list(), keyword() -> any())) :: :ok
    def assert_response(response, assertion_fn) do
      checked_reply = fn messages, opts ->
        assertion_fn.(messages, opts)
        build_event_stream(response)
      end

      Mox.expect(SkillKit.LLM.Mock, :stream, 1, checked_reply)
      :ok
    end

    @doc """
    Sets up multi-call Mox expectations from a list of response types.

    Each element corresponds to one `LLM.stream` call in order: one expectation
    is registered per response, and Mox consumes expectations registered for the
    same function in the order they were added. An empty list registers nothing.
    """
    @spec expect_responses([struct()]) :: :ok
    def expect_responses(responses) do
      # Rely on Mox's own FIFO queueing of expectations rather than a hand-rolled
      # call counter, whose separate read and write steps were not atomic.
      Enum.each(responses, &expect_response/1)
    end

    @doc """
    Registers one expected LLM call that fails with the given `status` and
    `message`.
    """
    @spec expect_error(integer(), String.t()) :: :ok
    def expect_error(status, message) do
      error = %Error{status: status, message: message}
      expect_response(error)
    end

    # Translates a response struct into the value the mocked `stream/2` call
    # returns: `{:ok, event_stream}` for text, tool-call and empty responses,
    # `{:error, {status, message}}` for error responses.
    defp build_event_stream(%SkillKit.Response.Text{content: content}) do
      wrap_events([
        %SkillKit.Event.Delta{text: content},
        %SkillKit.Event.Done{stop_reason: :end_turn}
      ])
    end

    defp build_event_stream(%SkillKit.Response.ToolCall{name: tool, input: args}) do
      # Start and complete events share one generated tool-call id.
      call_id = "tc_test_#{:erlang.unique_integer([:positive])}"

      wrap_events([
        %SkillKit.Event.ToolCallStart{id: call_id, name: tool},
        %SkillKit.Event.ToolCallComplete{id: call_id, name: tool, input: args},
        %SkillKit.Event.Done{stop_reason: :tool_use}
      ])
    end

    defp build_event_stream(%SkillKit.Response.Empty{}) do
      wrap_events([%SkillKit.Event.Done{stop_reason: :end_turn}])
    end

    defp build_event_stream(%SkillKit.Response.Error{status: code, message: reason}) do
      {:error, {code, reason}}
    end

    # Wraps a list of events in a lazy stream tagged with `:ok`.
    defp wrap_events(events), do: {:ok, Stream.map(events, & &1)}
  end
end