# lib/pi/agent.ex

defmodule Pi.Agent do
  @moduledoc "Unified BEAM abstraction for top-level agents and child agents."

  alias Pi.Agent.Job
  alias Pi.Agent.Manager
  alias Pi.Agent.Result
  alias Pi.Agent.Run
  alias Pi.Agent.Step
  alias Pi.Protocol.LLM.Message
  alias Pi.Session, as: RuntimeSession
  alias Pi.Session.State

  @doc """
  Builds a session from `prompt_or_opts` and `opts`, starts a runtime session for it,
  and completes that runtime session.

  Returns `{:ok, %Result{}}` or `{:error, %Result{}}` once the runtime session has been
  completed. If the runtime session cannot be started, the value returned by the start
  call is passed through unchanged.
  """
  def run(prompt_or_opts, opts \\ []) do
    state = session(prompt_or_opts, opts)

    case start_runtime_session(state, opts) do
      {:ok, runtime} -> complete_runtime(runtime, state, opts)
      not_started -> not_started
    end
  end

  @doc """
  Same as `run/2`, but returns the result directly and raises on failure.

  Raises `RuntimeError` whose message is the inspected failure reason. The reason is
  taken from the `error` field when the failure is a `%Result{}`, and used as-is when
  `run/2` returns a bare `{:error, reason}`.
  """
  def run!(prompt_or_opts, opts \\ []) do
    case run(prompt_or_opts, opts) do
      {:ok, result} ->
        result

      {:error, %Result{error: reason}} ->
        raise RuntimeError, message: inspect(reason)

      # run/2 passes a failed session start through untouched, so the reason is not
      # always wrapped in a %Result{}; without this clause that path would surface as
      # a CaseClauseError instead of the RuntimeError callers expect.
      {:error, reason} ->
        raise RuntimeError, message: inspect(reason)
    end
  end

  @doc "Starts a job for the binary `task` by delegating to `Pi.Agent.Manager.start_job/2`."
  def start(task, opts \\ []) when is_binary(task) do
    Manager.start_job(task, opts)
  end

  @doc "Returns whatever `Pi.Agent.Manager.jobs/0` reports."
  def jobs do
    Manager.jobs()
  end

  @doc """
  Looks up a job's status through `Pi.Agent.Manager.status/1`.

  Accepts either a `%Job{}` (its `id` is used) or a binary job id.
  """
  def status(job_or_id)

  def status(%Job{id: job_id}) do
    status(job_id)
  end

  def status(job_id) when is_binary(job_id) do
    Manager.status(job_id)
  end

  @doc """
  Fetches a job's result through `Pi.Agent.Manager.result/1`.

  Accepts either a `%Job{}` (its `id` is used) or a binary job id.
  """
  def result(job_or_id)

  def result(%Job{id: job_id}) do
    result(job_id)
  end

  def result(job_id) when is_binary(job_id) do
    Manager.result(job_id)
  end

  @doc """
  Delegates to `Pi.Agent.Manager.cancel/1` for the given job.

  Accepts either a `%Job{}` (its `id` is used) or a binary job id.
  """
  def cancel(job_or_id)

  def cancel(%Job{id: job_id}) do
    cancel(job_id)
  end

  def cancel(job_id) when is_binary(job_id) do
    Manager.cancel(job_id)
  end

  @doc """
  Starts one job per entry in `specs`, passing `opts` as shared options.

  Each spec may be a binary task or a map with a binary `:task` key; the map's other
  entries are merged over `opts`. Every spec is attempted before the outcomes are
  inspected. Returns `{:ok, jobs}` when no start failed, otherwise the first
  `{:error, reason}` in spec order.
  """
  def run_many(specs, opts \\ []) when is_list(specs) do
    started = for spec <- specs, do: start_job_spec(spec, opts)
    first_failure = Enum.find(started, fn outcome -> match?({:error, _reason}, outcome) end)

    if is_nil(first_failure) do
      {:ok, Enum.map(started, fn {:ok, job} -> job end)}
    else
      first_failure
    end
  end

  @doc """
  Runs `run/2` with the given arguments inside a `Task` and returns the task.

  The task can be awaited with `await/2` or `Task.await/2`.
  """
  def async(prompt_or_opts, opts \\ []) do
    runner = fn -> run(prompt_or_opts, opts) end
    Task.async(runner)
  end

  @doc """
  Waits for a task or a job to finish, for at most `timeout` milliseconds.

  * A `%Task{}` is awaited with `Task.await/2`.
  * A `%Job{}` or a binary job id is polled until its status is no longer `:running`:
    `{:ok, job}` for `:done`, `{:error, job}` for `:failed` or `:cancelled`,
    `{:error, :timeout}` once the deadline has passed, and `{:error, reason}` when the
    status lookup itself fails.
  """
  def await(task_or_job_or_id, timeout \\ 60_000)

  def await(%Task{} = task, timeout) do
    Task.await(task, timeout)
  end

  def await(%Job{id: job_id}, timeout) do
    await(job_id, timeout)
  end

  def await(job_id, timeout) when is_binary(job_id) do
    await_job(job_id, deadline(timeout))
  end

  @doc """
  Awaits every entry in `tasks` (anything `await/2` accepts) in order and returns the
  list of their results.

  `timeout` is a budget in milliseconds for the whole call, not for each entry: one
  deadline is computed up front and each entry is awaited with whatever time is left.
  `:infinity` is passed through to every entry unchanged.
  """
  def await_many(tasks, timeout \\ 60_000) do
    deadline = await_many_deadline(timeout)

    Enum.map(tasks, fn entry -> await(entry, await_many_remaining(deadline)) end)
  end

  # Absolute monotonic deadline for the whole await_many/2 call.
  defp await_many_deadline(:infinity), do: :infinity
  defp await_many_deadline(timeout), do: System.monotonic_time(:millisecond) + timeout

  # Time left until the shared deadline, clamped at 0 so an exhausted budget still
  # yields a valid timeout for the remaining entries.
  defp await_many_remaining(:infinity), do: :infinity

  defp await_many_remaining(deadline),
    do: max(deadline - System.monotonic_time(:millisecond), 0)

  @doc """
  Runs every entry of `runs` concurrently as a child of a freshly started parent
  runtime session.

  Options read here:

    * `:name` - name of the parent session (default `:parallel`); it is not inherited
      by the children.
    * `:timeout` - milliseconds (default `60_000`); the tasks are awaited with one
      extra second on top of it.

  Returns `{:ok, %Run{}}` when every child returned `{:ok, %Result{}}`, otherwise
  `{:error, %Run{}}` carrying all child outcomes and the reason `:one_or_more_failed`.
  When the parent session cannot be started, `{:error, %Run{}}` is returned with no
  results and the start failure as the reason.
  """
  def parallel(runs, opts \\ []) when is_list(runs) do
    with {:ok, parent} <- RuntimeSession.start(name: Keyword.get(opts, :name, :parallel)) do
      timeout = Keyword.get(opts, :timeout, 60_000)

      results =
        runs
        |> Enum.with_index(1)
        |> Enum.map(fn {run, index} ->
          Task.async(fn -> run_child(parent, run, child_opts(run, opts, index)) end)
        end)
        |> await_many(timeout + 1_000)

      if Enum.all?(results, &match?({:ok, %Result{}}, &1)) do
        {:ok, Run.ok(:parallel, Enum.map(results, fn {:ok, result} -> result end))}
      else
        {:error, Run.error(:parallel, results, :one_or_more_failed)}
      end
    else
      # Keep the {:error, %Run{}} shape callers already handle instead of crashing
      # with a MatchError when the parent session fails to start.
      {:error, reason} -> {:error, Run.error(:parallel, [], reason)}
    end
  end

  @doc "Runs every entry of the list `inputs` through `parallel/2` with the same `opts`."
  def fanout(inputs, opts \\ []) when is_list(inputs) do
    parallel(inputs, opts)
  end

  @doc """
  Runs `steps` one after another with `run/2`, feeding each step the previous step's
  result.

  Returns `{:ok, %Run{}}` holding every step result in execution order, or
  `{:error, %Run{}}` holding the results gathered before the failing step together
  with the failure reason.
  """
  def chain(steps, opts \\ []) when is_list(steps) do
    steps
    |> reduce_chain(opts, nil, [])
    |> finish_chain()
  end

  # reduce_chain/4 accumulates results newest-first; restore execution order here.
  defp finish_chain({:ok, newest_first}) do
    {:ok, Run.ok(:chain, Enum.reverse(newest_first))}
  end

  defp finish_chain({:error, newest_first, reason}) do
    {:error, Run.error(:chain, Enum.reverse(newest_first), reason)}
  end

  @doc "Derives a child session state from `parent` by delegating to `Pi.Session.State.child/2`."
  def child(%State{} = parent, opts \\ []) do
    State.child(parent, opts)
  end

  @doc "Returns whatever `Pi.Session.list/0` reports."
  def sessions do
    RuntimeSession.list()
  end

  @doc """
  Returns the entries of `Pi.Session.list/0` whose `parent_id` equals the given parent.

  Accepts either a `%State{}` (its `id` is used) or a binary parent id.
  """
  def children(parent_or_id)

  def children(%State{id: parent_id}) do
    children(parent_id)
  end

  def children(parent_id) when is_binary(parent_id) do
    for candidate <- RuntimeSession.list(), candidate.parent_id == parent_id, do: candidate
  end

  @doc """
  Returns the messages of a session.

  When a runtime session is found for the id, its current messages are returned.
  Otherwise the fallback is used: the `messages` of the given `%State{}`, or `[]` when
  only a binary id was given.
  """
  def history(state_or_id)

  def history(%State{id: session_id, messages: known_messages}) do
    history(session_id, known_messages)
  end

  def history(session_id) when is_binary(session_id) do
    history(session_id, [])
  end

  @doc """
  Normalizes the accepted inputs into a session state.

    * `%State{}` - returned unchanged; `opts` is ignored.
    * `%Step{}` - converted with `Pi.Agent.Step.to_session/2`.
    * binary prompt - `opts` with `:messages` set to a single user message holding the
      prompt, passed to `Pi.Session.State.new/1`.
    * keyword list - merged with `opts` (entries in `opts` win) and passed to
      `Pi.Session.State.new/1`.
  """
  def session(prompt_or_opts, opts \\ [])

  def session(%State{} = existing, _opts), do: existing

  def session(%Step{} = step, opts), do: Step.to_session(step, opts)

  def session(prompt, opts) when is_binary(prompt) do
    user_message = %Message{role: :user, content: prompt}
    State.new(Keyword.put(opts, :messages, [user_message]))
  end

  def session(opts, extra_opts) when is_list(opts) do
    merged = Keyword.merge(opts, extra_opts)
    State.new(merged)
  end

  # Starts a single job from a run_many/2 spec.
  #
  # A binary spec is the task itself. A map spec must carry a binary :task; its other
  # entries are merged over the shared opts. Anything else is reported as
  # {:error, {:invalid_job_spec, spec}} without starting a job.
  defp start_job_spec(task, opts) when is_binary(task) do
    start(task, opts)
  end

  defp start_job_spec(%{task: task} = spec, opts) when is_binary(task) do
    overrides =
      spec
      |> Map.delete(:task)
      |> Map.to_list()

    start(task, Keyword.merge(opts, overrides))
  end

  defp start_job_spec(invalid, _opts) do
    {:error, {:invalid_job_spec, invalid}}
  end

  # Polls the job's status until it is no longer :running or the monotonic deadline
  # (in milliseconds) has passed.
  #
  # The status is always checked before the deadline, so a job that already finished
  # is reported even when the deadline has expired.
  defp await_job(id, deadline) do
    case status(id) do
      {:ok, %Job{status: :done} = finished} ->
        {:ok, finished}

      {:ok, %Job{status: terminal} = stopped} when terminal in [:failed, :cancelled] ->
        {:error, stopped}

      {:ok, %Job{status: :running}} ->
        poll_again_or_timeout(id, deadline)

      {:error, _reason} = failure ->
        failure
    end
  end

  # Gives up once the deadline is reached; otherwise waits 25 ms and polls again.
  defp poll_again_or_timeout(id, deadline) do
    expired? = System.monotonic_time(:millisecond) >= deadline

    case expired? do
      true ->
        {:error, :timeout}

      false ->
        Process.sleep(25)
        await_job(id, deadline)
    end
  end

  # Converts a relative timeout in milliseconds into an absolute monotonic deadline.
  #
  # :infinity is kept as-is so that await/2 accepts it for jobs just like Task.await/2
  # does for tasks. The polling loop compares `now >= deadline`; in Erlang term ordering
  # every number sorts before every atom, so that comparison is never true for
  # :infinity and the wait never times out.
  defp deadline(:infinity), do: :infinity
  defp deadline(timeout), do: System.monotonic_time(:millisecond) + timeout

  # Builds the options for one child of parallel/2.
  #
  # The parent's :name is always dropped so it is not inherited. A child that is a
  # keyword list with its own :name keeps it; every other child gets a generated name.
  defp child_opts(run, opts, index) do
    inherited = Keyword.delete(opts, :name)

    case child_has_name?(run) do
      true -> inherited
      false -> Keyword.put(inherited, :name, child_name(run, index))
    end
  end

  # True only when the run is given as a list that carries a :name key.
  defp child_has_name?(run) do
    is_list(run) and Keyword.has_key?(run, :name)
  end

  # Generates a display name for a child run.
  #
  # For a binary prompt: whitespace runs collapse to single spaces, the result is
  # trimmed and cut to at most 40 characters. If nothing is left (empty or
  # whitespace-only prompt) the positional fallback "child <index>" is used so a child
  # never ends up with an empty name. Non-binary runs always get the fallback.
  defp child_name(prompt, index) when is_binary(prompt) do
    label =
      prompt
      |> String.replace(~r/\s+/u, " ")
      |> String.trim()
      |> String.slice(0, 40)

    if label == "", do: "child #{index}", else: label
  end

  defp child_name(_run, index), do: "child #{index}"

  # Same flow as run/2, but the runtime session is started as a child of `parent`.
  # A failure to start the child is passed through unchanged.
  defp run_child(parent, prompt_or_opts, opts) do
    state = session(prompt_or_opts, opts)

    case start_runtime_child(parent, state, opts) do
      {:ok, runtime} -> complete_runtime(runtime, state, opts)
      not_started -> not_started
    end
  end

  # Completes the runtime session, passing the session id as the :agent option, and
  # wraps the outcome in a %Result{} tied to `session`.
  defp complete_runtime(runtime, %State{} = session, opts) do
    completion_opts = Keyword.put(opts, :agent, session.id)
    outcome = RuntimeSession.complete(runtime, completion_opts)

    case outcome do
      {:error, reason} -> {:error, Result.error(session, reason)}
      {:ok, value} -> {:ok, Result.ok(session, value)}
    end
  end

  # Starts a runtime session mirroring `session`. Metadata given under the :metadata
  # option is merged over the session's own metadata.
  defp start_runtime_session(%State{} = session, opts) do
    extra_metadata =
      opts
      |> Keyword.get(:metadata, %{})
      |> Map.new()

    start_opts = [
      id: session.id,
      parent_id: session.parent_id,
      name: session.name,
      system: session.system,
      messages: session.messages,
      metadata: Map.merge(session.metadata, extra_metadata)
    ]

    RuntimeSession.start(start_opts)
  end

  # Starts a runtime session mirroring `session` as a child of `parent`. Metadata given
  # under the :metadata option is merged over the session's own metadata.
  defp start_runtime_child(parent, %State{} = session, opts) do
    extra_metadata =
      opts
      |> Keyword.get(:metadata, %{})
      |> Map.new()

    child_opts = [
      id: session.id,
      name: session.name,
      system: session.system,
      messages: session.messages,
      metadata: Map.merge(session.metadata, extra_metadata)
    ]

    RuntimeSession.child(parent, child_opts)
  end

  # Returns the messages of the runtime session registered under `session_id`, or
  # `fallback` when the lookup reports {:error, :not_found}.
  defp history(session_id, fallback) do
    case RuntimeSession.lookup(session_id) do
      {:error, :not_found} ->
        fallback

      {:ok, pid} ->
        runtime_state = RuntimeSession.state(pid)
        runtime_state.messages
    end
  end

  # Runs the remaining chain steps in order.
  #
  # `previous` is the `result` field of the last successful step (nil before the first
  # step); `acc` collects the %Result{} structs newest-first. Stops at the first failing
  # step and returns the results gathered so far together with the failure.
  defp reduce_chain([], _opts, _previous, acc), do: {:ok, acc}

  defp reduce_chain([current | remaining], opts, previous, acc) do
    outcome =
      current
      |> chain_input(previous)
      |> run(opts)

    case outcome do
      {:error, reason} ->
        {:error, acc, reason}

      {:ok, %Result{result: output} = step_result} ->
        reduce_chain(remaining, opts, output, [step_result | acc])
    end
  end

  # Builds the input for a chain step from the step itself and the previous result.
  #
  #   * no previous result (nil): the step is used unchanged;
  #   * binary step: the inspected previous result is appended to the prompt text;
  #   * keyword-list step: a user message holding the inspected previous result is
  #     appended to :messages (or becomes the only message when :messages is absent);
  #   * %State{} step: the same user message is appended to the state's messages;
  #   * any other step accepted by run/2 (e.g. %Step{}): used unchanged.
  defp chain_input(step, nil), do: step

  defp chain_input(step, previous) when is_binary(step),
    do: step <> "\n\nPrevious result:\n" <> inspect(previous)

  defp chain_input(step, previous) when is_list(step) do
    carried = previous_result_message(previous)

    Keyword.update(step, :messages, [carried], fn messages -> messages ++ [carried] end)
  end

  defp chain_input(%State{messages: messages} = step, previous) do
    %{step | messages: List.wrap(messages) ++ [previous_result_message(previous)]}
  end

  # NOTE(review): no field of other step types (such as %Step{}) is visible here to
  # carry the previous result, so the step runs without it instead of crashing the
  # whole chain with a FunctionClauseError. Confirm whether %Step{} should receive it.
  defp chain_input(step, _previous), do: step

  # User message that hands the previous step's result to the next step.
  defp previous_result_message(previous) do
    %Message{role: :user, content: inspect(previous)}
  end
end