Skip to main content

lib/omni/ui/turn.ex

defmodule Omni.UI.Turn do
  @moduledoc """
  A UI-oriented view of a conversation exchange.

  Each turn collapses a sequence of tree nodes — a user prompt, any intermediate
  tool-use rounds, and the final assistant response — into a single renderable
  struct. `all/1` walks the active path and chunks it into turns; `get/2`
  returns a single turn by its starting node ID; `new/3` builds a turn from
  raw messages (used during streaming).

  ## Fields

    * `id` — node ID of the user message that starts this turn.
    * `res_id` — node ID of the first assistant message in this turn, or `nil`
      if no assistant response exists yet (e.g. turn just submitted).
    * `status` — `:complete`, `:streaming`, or `:error`.
    * `user_text` — list of `Omni.Content.Text` blocks from the user message.
    * `user_attachments` — list of `Omni.Content.Attachment` blocks from the
      user message.
    * `user_timestamp` — `DateTime` from the user's message.
    * `content` — all assistant content blocks (`Text`, `Thinking`, `ToolUse`)
      accumulated across all assistant messages in the turn.
    * `timestamp` — `DateTime` from the last assistant message, `nil` while
      streaming.
    * `tool_results` — `%{tool_use_id => ToolResult}` extracted from
      intermediate user messages.
    * `error` — error reason string when `status == :error`.
    * `usage` — cumulative `Omni.Usage` for this turn.

  ## Branching metadata

    * `edits` — sorted node IDs of **all** children of this turn's parent node
      (i.e. the user messages that share the same parent as this turn's user
      message), **including the active node** (`id`). Length > 1 means the user
      edited their prompt. The active node is included so that UI components can
      compute position (e.g. "2/3") and prev/next navigation without needing to
      re-insert it.
    * `regens` — sorted node IDs of **all** children of this turn's user message
      (i.e. its assistant responses), **including the active node** (`res_id`).
      Length > 1 means the user regenerated the response.
  """

  alias Omni.Session.Tree

  defstruct [
    :id,
    :res_id,
    status: :complete,
    edits: [],
    regens: [],
    user_text: [],
    user_attachments: [],
    user_timestamp: nil,
    content: [],
    timestamp: nil,
    tool_results: %{},
    error: nil,
    usage: %Omni.Usage{}
  ]

  @typedoc """
  A UI-oriented view of a single conversation exchange. See module docs for fields.
  """
  @type t :: %__MODULE__{
          id: Tree.node_id() | nil,
          res_id: Tree.node_id() | nil,
          status: :complete | :streaming | :error,
          edits: [Tree.node_id()],
          regens: [Tree.node_id()],
          user_text: [Omni.Content.Text.t()],
          user_attachments: [Omni.Content.Attachment.t()],
          user_timestamp: DateTime.t() | nil,
          content: [Omni.Message.content()],
          timestamp: DateTime.t() | nil,
          tool_results: %{String.t() => Omni.Content.ToolResult.t()},
          error: String.t() | nil,
          usage: Omni.Usage.t()
        }

  @doc """
  Builds a turn from a node ID, a list of messages, and cumulative usage.

  The first message must be the user prompt (an `Omni.Message` with
  `role: :user`); its `Text` and `Attachment` blocks populate `user_text` and
  `user_attachments`. Subsequent messages are reduced into the turn: assistant
  messages append to `content` (and set `timestamp`), and user messages
  containing tool results are collected into `tool_results`.

  Raises `FunctionClauseError` if `messages` is empty or does not start with a
  user message, so a malformed path surfaces as a crash instead of a turn that
  renders an assistant message as the user's prompt.
  """
  @spec new(Tree.node_id() | nil, [Omni.Message.t()], Omni.Usage.t()) :: t()
  def new(node_id, [%Omni.Message{role: :user} = user | rest], %Omni.Usage{} = usage) do
    turn = %__MODULE__{
      id: node_id,
      user_text: Enum.filter(user.content, &match?(%Omni.Content.Text{}, &1)),
      user_attachments: Enum.filter(user.content, &match?(%Omni.Content.Attachment{}, &1)),
      user_timestamp: user.timestamp,
      usage: usage
    }

    Enum.reduce(rest, turn, &reduce_content/2)
  end

  @doc """
  Converts a tree's active path into a list of turns.

  Chunks the path at user-message boundaries (skipping tool-result user messages),
  then builds a turn from each chunk with edits and regens populated from the
  full tree structure.
  """
  @spec all(Tree.t()) :: [t()]
  def all(%Tree{} = tree) do
    # Built once and shared by every turn, rather than recomputed per chunk.
    children_map = children_map(tree)

    tree
    |> Enum.chunk_while([], &tree_chunk/2, &after_tree_chunk/1)
    |> Enum.map(&from_tree_nodes(&1, children_map))
  end

  @doc """
  Returns a single turn from the tree starting at the given node ID.

  Walks the active path forward from `node_id`, collecting nodes until the
  next turn boundary (a non-tool-result user message), then builds a turn
  with branching metadata from the full tree structure.

  `node_id` is expected to be the user message that starts a turn; passing an
  assistant node raises `FunctionClauseError` (see `new/3`).

  Returns `nil` if `node_id` is not on the active path.
  """
  @spec get(Tree.t(), Tree.node_id()) :: t() | nil
  def get(%Tree{} = tree, node_id) do
    case Enum.drop_while(tree.path, &(&1 != node_id)) do
      [] ->
        nil

      [first_id | rest_ids] ->
        # Resolve nodes lazily: only the nodes up to the next boundary are
        # needed, not the whole remainder of the active path.
        turn_nodes =
          rest_ids
          |> Stream.map(&tree.nodes[&1])
          |> Enum.take_while(&(not turn_boundary?(&1.message)))

        from_tree_nodes([tree.nodes[first_id] | turn_nodes], children_map(tree))
    end
  end

  @doc """
  Appends a content block to the turn's assistant content.

  Called during streaming when the agent starts a new content block
  (e.g. a new `Text` or `Thinking` block).
  """
  @spec push_content(t(), Omni.Message.content()) :: t()
  def push_content(%__MODULE__{} = turn, content_block) do
    %{turn | content: turn.content ++ [content_block]}
  end

  @doc """
  Appends a text delta to the last content block.

  Called during streaming as text chunks arrive from the agent. Assumes the
  last content block has a `:text` field (i.e. `push_content/2` was called
  first to start the block). If the turn has no content blocks yet, the delta
  is discarded and the turn is returned unchanged.
  """
  @spec push_delta(t(), String.t()) :: t()
  def push_delta(%__MODULE__{} = turn, delta) do
    # `List.update_at/3` returns the list untouched when index -1 is out of
    # range, which is why an empty `content` silently drops the delta.
    content = List.update_at(turn.content, -1, &%{&1 | text: &1.text <> delta})
    %{turn | content: content}
  end

  @doc """
  Replaces a content block in the turn by matching its `id`.

  Called during streaming when a content block is finalised — e.g. a tool-use
  block that started as a stub on `:tool_use_start` and is replaced with the
  fully-formed struct on `:tool_use_end`. For blocks with a non-nil `id` field
  (like `ToolUse`), the match is by id so parallel blocks don't clobber each
  other; raises `ArgumentError` if no block in the turn has that id.

  Falls back to replacing the last block for id-less content types (a no-op
  when the turn has no content yet).
  """
  @spec replace_content(t(), Omni.Message.content()) :: t()
  def replace_content(%__MODULE__{} = turn, %{id: id} = content_block) when not is_nil(id) do
    case Enum.find_index(turn.content, &(Map.get(&1, :id) == id)) do
      nil ->
        raise ArgumentError, "no content block with id #{inspect(id)} in turn #{inspect(turn.id)}"

      idx ->
        %{turn | content: List.replace_at(turn.content, idx, content_block)}
    end
  end

  def replace_content(%__MODULE__{} = turn, content_block) do
    %{turn | content: List.replace_at(turn.content, -1, content_block)}
  end

  @doc """
  Stores a tool result, keyed by its `tool_use_id`.

  Called during streaming when a tool execution completes. The result is
  stored so the corresponding `ToolUse` content block can render it. A later
  result with the same `tool_use_id` replaces the earlier one.
  """
  @spec put_tool_result(t(), Omni.Content.ToolResult.t()) :: t()
  def put_tool_result(%__MODULE__{} = turn, tool_result) do
    tool_results = Map.put(turn.tool_results, tool_result.tool_use_id, tool_result)
    %{turn | tool_results: tool_results}
  end

  @doc """
  Returns the concatenated text content for the given role in a turn.

  Multiple text blocks are joined with double newlines. For `:assistant`,
  non-text content blocks (e.g. `Thinking`, `ToolUse`) are filtered out.
  Returns `""` when there are no text blocks.
  """
  @spec get_text(t(), :user | :assistant) :: String.t()
  def get_text(%__MODULE__{user_text: texts}, :user) do
    Enum.map_join(texts, "\n\n", & &1.text)
  end

  def get_text(%__MODULE__{content: content}, :assistant) do
    content
    |> Enum.filter(&match?(%Omni.Content.Text{}, &1))
    |> Enum.map_join("\n\n", & &1.text)
  end

  # Private

  # Builds a turn from one chunk of tree nodes (turn-starting node first) and
  # attaches branching metadata looked up in `children_map`.
  defp from_tree_nodes([%{id: node_id, parent_id: parent_id} | rest] = nodes, children_map) do
    # Sum usage over the nodes after the prompt; nodes whose `usage` is not an
    # `%Omni.Usage{}` (e.g. nil) are skipped.
    usage =
      rest
      |> Enum.filter(&match?(%Omni.Usage{}, &1.usage))
      |> Enum.reduce(%Omni.Usage{}, &Omni.Usage.add(&2, &1.usage))

    # The node directly after the prompt is the first response, if any.
    res_id =
      case rest do
        [%{id: id} | _] -> id
        [] -> nil
      end

    # Siblings of the prompt are its edits; children of the prompt are its
    # regenerations. Both lists include the active node.
    edits = Map.get(children_map, parent_id, [])
    regens = Map.get(children_map, node_id, [])

    turn = new(node_id, Enum.map(nodes, & &1.message), usage)
    %{turn | res_id: res_id, edits: edits, regens: regens}
  end

  # Maps each parent ID (nil for root nodes) to the sorted IDs of its children
  # across the whole tree, not just the active path.
  defp children_map(%Tree{nodes: nodes}) do
    nodes
    |> Enum.group_by(fn {_id, node} -> node.parent_id end, fn {id, _node} -> id end)
    |> Map.new(fn {parent_id, child_ids} -> {parent_id, Enum.sort(child_ids)} end)
  end

  # User messages after the prompt only contribute their tool results; any
  # text they carry is not rendered as part of this turn.
  defp reduce_content(%Omni.Message{role: :user} = msg, turn) do
    tool_results =
      msg.content
      |> Enum.filter(&match?(%Omni.Content.ToolResult{}, &1))
      |> Enum.reduce(turn.tool_results, &Map.put(&2, &1.tool_use_id, &1))

    %{turn | tool_results: tool_results}
  end

  # Assistant content accumulates across tool-use rounds; the timestamp ends up
  # being that of the last assistant message reduced.
  defp reduce_content(%Omni.Message{role: :assistant} = msg, turn) do
    %{turn | content: turn.content ++ msg.content, timestamp: msg.timestamp}
  end

  # A user message starts a new turn unless it is carrying tool results back to
  # the assistant (an intermediate round of the current turn).
  defp turn_boundary?(%Omni.Message{role: :user, content: content}) do
    not Enum.any?(content, &match?(%Omni.Content.ToolResult{}, &1))
  end

  defp turn_boundary?(_message), do: false

  # Chunk function for `Enum.chunk_while/4`: a turn boundary closes the
  # accumulated chunk (emitted in path order) and starts a new one; every other
  # node joins the current chunk.
  #
  # There is deliberately no special case for a path that starts with a
  # non-boundary node while the accumulator is empty: such nodes are kept
  # rather than silently dropped. A chunk led by an assistant message then
  # fails the `role: :user` match in `new/3`, surfacing the upstream problem as
  # a crash. (A chunk led by a tool-result user message is still accepted,
  # since it is a `:user` message.)
  defp tree_chunk(node, acc) do
    if turn_boundary?(node.message) and acc != [] do
      {:cont, Enum.reverse(acc), [node]}
    else
      {:cont, [node | acc]}
    end
  end

  # Emits the final, still-open chunk (if any) once the path is exhausted.
  defp after_tree_chunk([]), do: {:cont, []}
  defp after_tree_chunk(acc), do: {:cont, Enum.reverse(acc), []}
end