lib/utils/chain_result.ex

defmodule LangChain.Utils.ChainResult do
  @moduledoc """
  Module to help when working with the results of a chain.
  """
  alias LangChain.LangChainError
  alias LangChain.Chains.LLMChain
  alias LangChain.Message
  alias LangChain.Message.ContentPart
  alias __MODULE__

  @doc """
  Return the result of the chain as a string, wrapped in an `:ok` tuple.

  Accepts either the final, updated `LLMChain` or the result of the
  `LLMChain.run/2` function. An `{:ok, chain}` tuple is unwrapped and an
  `{:error, chain, reason}` tuple is returned unchanged.

  The text is taken from the chain's `last_message` when it is a complete
  `:assistant` message:
  - A single text `ContentPart` returns that part's content.
  - Any other list of content parts returns the binary text parts joined
    together, separated by a blank line. Non-text parts are skipped.
  - Content that is not a list is returned unchanged.

  An `{:error, chain, reason}` is returned when:
  - The last message of the chain is not an `:assistant` message.
  - The last message of the chain is incomplete.
  - There is no last message.
  """
  @spec to_string(
          LLMChain.t()
          | {:ok, LLMChain.t()}
          | {:error, LLMChain.t(), LangChainError.t()}
        ) ::
          {:ok, String.t()} | {:error, LLMChain.t(), LangChainError.t()}
  def to_string({:error, chain, %LangChainError{} = reason}) do
    # if an error was passed in, forward it through.
    {:error, chain, reason}
  end

  def to_string({:ok, %LLMChain{} = chain}) do
    ChainResult.to_string(chain)
  end

  # when received a single ContentPart
  def to_string(
        %LLMChain{
          last_message: %Message{
            role: :assistant,
            status: :complete,
            content: [%ContentPart{type: :text} = part]
          }
        } = _chain
      ) do
    {:ok, part.content}
  end

  # When the content is any other list of parts, collapse it to one string so
  # the `{:ok, String.t()}` contract holds instead of leaking the raw list.
  def to_string(
        %LLMChain{
          last_message: %Message{role: :assistant, status: :complete, content: parts}
        } = _chain
      )
      when is_list(parts) do
    # The generator pattern skips anything that is not a text ContentPart.
    texts = for %ContentPart{type: :text, content: text} <- parts, is_binary(text), do: text
    {:ok, Enum.join(texts, "\n\n")}
  end

  # Content that is not a list is passed through untouched.
  def to_string(%LLMChain{last_message: %Message{role: :assistant, status: :complete}} = chain) do
    {:ok, chain.last_message.content}
  end

  def to_string(%LLMChain{last_message: %Message{role: :assistant, status: _incomplete}} = chain) do
    {:error, chain, LangChainError.exception(type: "to_string", message: "Message is incomplete")}
  end

  def to_string(%LLMChain{last_message: %Message{}} = chain) do
    {:error, chain,
     LangChainError.exception(type: "to_string", message: "Message is not from assistant")}
  end

  def to_string(%LLMChain{last_message: nil} = chain) do
    {:error, chain, LangChainError.exception(type: "to_string", message: "No last message")}
  end

  @doc """
  Return the last message's content when it is valid to use it. Otherwise it
  raises an exception with the reason why it cannot be used. See the docs for
  `to_string/1` for details.

  Accepts the same inputs as `to_string/1`: an `LLMChain`, an `{:ok, chain}`
  tuple, or an `{:error, chain, reason}` tuple. For an error result, the
  contained `LangChainError` is raised.
  """
  @spec to_string!(
          LLMChain.t()
          | {:ok, LLMChain.t()}
          | {:error, LLMChain.t(), LangChainError.t()}
        ) :: String.t() | no_return()
  def to_string!(chain_or_result) do
    # Delegate input handling to to_string/1 so both functions accept exactly
    # the same shapes; unsupported input fails there with a FunctionClauseError.
    case ChainResult.to_string(chain_or_result) do
      {:ok, result} -> result
      {:error, _chain, %LangChainError{} = exception} -> raise exception
    end
  end

  @doc """
  Store the chain's result in `map` under `key`.

  The value is obtained with `to_string/1`. On success the updated map is
  returned as `{:ok, map}`. When the result cannot be extracted, the
  `{:error, chain, reason}` tuple from `to_string/1` is returned unchanged.
  """
  @spec to_map(LLMChain.t(), map(), any()) ::
          {:ok, map()} | {:error, LLMChain.t(), LangChainError.t()}
  def to_map(%LLMChain{} = chain, map, key) do
    # A non-matching (error) result falls through `with` and is returned as-is.
    with {:ok, text} <- ChainResult.to_string(chain) do
      {:ok, Map.put(map, key, text)}
    end
  end

  @doc """
  Store the chain's result in `map` under `key` and return the updated map.

  Works like `to_map/3`, except that the map is returned without the `:ok`
  tuple and an error result raises the contained `LangChainError`.
  """
  @spec to_map!(LLMChain.t(), map(), any()) :: map() | no_return()
  def to_map!(%LLMChain{} = chain, map, key) do
    chain
    |> ChainResult.to_map(map, key)
    |> case do
      {:ok, result_map} -> result_map
      {:error, _chain, %LangChainError{} = error} -> raise error
    end
  end
end