# lib/utils/chain_result.ex

defmodule LangChain.Utils.ChainResult do
  @moduledoc """
  Helpers for extracting the final result of an `LLMChain`.

  Every function inspects the chain's `last_message`. The non-bang functions
  return `{:ok, value}` / `{:error, reason}` tuples; the bang variants return
  the bare value or raise a `LangChainError`.
  """
  alias LangChain.LangChainError
  alias LangChain.Chains.LLMChain
  alias LangChain.Message
  alias __MODULE__

  @doc """
  Returns the content of the chain's last message in an `:ok` tuple.

  The content is only returned when the last message has the `:assistant` role
  and a `:complete` status. Otherwise an `{:error, reason}` tuple is returned,
  where `reason` is one of:

  - `"Message is incomplete"` - the last message is from the assistant but its
    status is anything other than `:complete`.
  - `"Message is not from assistant"` - the last message has a different role.
  - `"No last message"` - the chain's `last_message` is `nil`.
  """
  @spec to_string(LLMChain.t()) :: {:ok, String.t()} | {:error, String.t()}
  def to_string(%LLMChain{
        last_message: %Message{role: :assistant, status: :complete, content: content}
      }) do
    {:ok, content}
  end

  # Clause order matters: a `:complete` assistant message was handled above, so
  # any assistant message that reaches this clause has some other status.
  def to_string(%LLMChain{last_message: %Message{role: :assistant}}),
    do: {:error, "Message is incomplete"}

  def to_string(%LLMChain{last_message: %Message{}}),
    do: {:error, "Message is not from assistant"}

  def to_string(%LLMChain{last_message: nil}),
    do: {:error, "No last message"}

  @doc """
  Returns the last message's content directly when it is valid to use it.

  Raises a `LangChainError` carrying the failure reason otherwise. See
  `LangChain.Utils.ChainResult.to_string/1` for the conditions under which the
  content is considered usable.
  """
  @spec to_string!(LLMChain.t()) :: String.t() | no_return()
  def to_string!(%LLMChain{} = chain) do
    chain
    |> ChainResult.to_string()
    |> unwrap!()
  end

  @doc """
  Stores the chain's result in `map` under `key`.

  Returns `{:ok, updated_map}` on success. When the result cannot be extracted,
  the `{:error, reason}` tuple from `LangChain.Utils.ChainResult.to_string/1` is
  returned unchanged.
  """
  @spec to_map(LLMChain.t(), map(), any()) :: {:ok, map()} | {:error, String.t()}
  def to_map(%LLMChain{} = chain, map, key) do
    # A non-matching `{:error, reason}` falls through `with` and is returned as-is.
    with {:ok, text} <- ChainResult.to_string(chain) do
      {:ok, Map.put(map, key, text)}
    end
  end

  @doc """
  Stores the chain's result in `map` under `key` and returns the updated map.

  Raises a `LangChainError` carrying the failure reason when the result cannot
  be extracted.
  """
  @spec to_map!(LLMChain.t(), map(), any()) :: map() | no_return()
  def to_map!(%LLMChain{} = chain, map, key) do
    chain
    |> ChainResult.to_map(map, key)
    |> unwrap!()
  end

  # Unwraps an `:ok` tuple, or raises a `LangChainError` with the error reason.
  defp unwrap!({:ok, value}), do: value
  defp unwrap!({:error, reason}), do: raise(LangChainError, reason)
end