defmodule LangChain.Chains.LLMChain do
@doc """
  Define an LLMChain. This is the heart of the LangChain library.

  The chain manages functions, a function map, delta tracking, `last_message`
  tracking, conversation messages, and verbose logging. Separating these
  responsibilities from the LLM makes it easier to support additional LLMs,
  because each LLM module only needs to handle communication and message
  formats rather than all the surrounding logic.
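
  ## Example

  A minimal sketch of typical usage. The `ChatOpenAI` chat model shown is
  illustrative; any struct whose module implements the expected `call/4`
  function can serve as the `:llm`:

      {:ok, _updated_chain, response} =
        %{llm: ChatOpenAI.new!(%{model: "gpt-4"})}
        |> LLMChain.new!()
        |> LLMChain.add_message(Message.new_user!("Name an Elixir web framework."))
        |> LLMChain.run()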
"""
use Ecto.Schema
import Ecto.Changeset
require Logger
alias LangChain.PromptTemplate
alias __MODULE__
alias LangChain.Message
alias LangChain.MessageDelta
alias LangChain.Function
alias LangChain.LangChainError
@primary_key false
embedded_schema do
field :llm, :any, virtual: true
field :verbose, :boolean, default: false
field :functions, {:array, :any}, default: [], virtual: true
    # Set and managed privately through module functions.
field :function_map, :map, default: %{}, virtual: true
# List of `Message` structs for creating the conversation with the LLM.
field :messages, {:array, :any}, default: [], virtual: true
# Custom context data made available to functions when executed.
# Could include information like account ID, user data, etc.
field :custom_context, :any, virtual: true
# Track the current merged `%MessageDelta{}` struct received when streamed.
# Set to `nil` when there is no current delta being tracked. This happens
# when the final delta is received that completes the message. At that point,
# the delta is converted to a message and the delta is set to nil.
field :delta, :any, virtual: true
# Track the last `%Message{}` received in the chain.
field :last_message, :any, virtual: true
    # Track whether the chain expects a response from the LLM. This is true
    # after a user message is sent, and after a function_call is received and
    # a function response has been provided, since the LLM still needs to reply.
field :needs_response, :boolean, default: false
    # A callback function to execute when messages are added. Not settable
    # through `new/1`; it is set from `run/2` so that multiple chain instances
    # (across processes) don't both fire callbacks.
field :callback_fn, :any, virtual: true
end
@type t :: %LLMChain{}
@create_fields [:llm, :functions, :custom_context, :verbose]
@required_fields [:llm]
@doc """
Start a new LLMChain configuration.
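
  ## Example

  A sketch, reusing the illustrative `ChatOpenAI` chat model from the module
  docs:

      {:ok, chain} =
        LLMChain.new(%{llm: ChatOpenAI.new!(%{model: "gpt-4"}), verbose: true})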
"""
@spec new(attrs :: map()) :: {:ok, t} | {:error, Ecto.Changeset.t()}
def new(attrs \\ %{}) do
%LLMChain{}
|> cast(attrs, @create_fields)
|> common_validation()
|> apply_action(:insert)
end
@doc """
Start a new LLMChain configuration and return it or raise an error if invalid.
"""
@spec new!(attrs :: map()) :: t() | no_return()
def new!(attrs \\ %{}) do
case new(attrs) do
{:ok, chain} ->
chain
{:error, changeset} ->
raise LangChainError, changeset
end
end
  @doc false
  def common_validation(changeset) do
changeset
|> validate_required(@required_fields)
|> validate_llm_is_struct()
|> build_functions_map_from_functions()
end
defp validate_llm_is_struct(changeset) do
case get_change(changeset, :llm) do
nil -> changeset
llm when is_struct(llm) -> changeset
_other -> add_error(changeset, :llm, "LLM must be a struct")
end
end
@doc false
def build_functions_map_from_functions(changeset) do
functions = get_field(changeset, :functions, [])
    # Index the functions into a map keyed by function name.
fun_map =
Enum.reduce(functions, %{}, fn f, acc ->
Map.put(acc, f.name, f)
end)
put_change(changeset, :function_map, fun_map)
end
@doc """
  Add more functions to an LLMChain. Accepts a single `Function` or a list of
  them.
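
  ## Example

  A sketch; the `Function.new!/1` arguments shown are illustrative:

      weather =
        Function.new!(%{
          name: "get_weather",
          description: "Get the current weather",
          function: fn _args, _context -> "sunny" end
        })

      chain = LLMChain.add_functions(chain, weather)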
"""
@spec add_functions(t(), Function.t() | [Function.t()]) :: t() | no_return()
def add_functions(%LLMChain{} = chain, %Function{} = function) do
add_functions(chain, [function])
end
def add_functions(%LLMChain{functions: existing} = chain, functions) when is_list(functions) do
updated = existing ++ functions
chain
|> change()
|> cast(%{functions: updated}, [:functions])
|> build_functions_map_from_functions()
|> apply_action!(:update)
end
@doc """
  Run the chain on the LLM using messages and any registered functions. This
  formats the request for a ChatLLMChain where messages are passed to the API.

  When successful, it returns `{:ok, updated_chain, message_or_messages}`.

  ## Options

  - `:while_needs_response` - repeatedly evaluates functions and submits
    results to the LLM for as long as a response is still expected.
  - `:callback_fn` - the callback function to execute as messages are
    received (see the example below).
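
  ## Example

  A sketch of streaming with a callback and automatic function evaluation.
  The callback receives each `MessageDelta` as it streams in and each
  completed `Message`:

      callback = fn
        %MessageDelta{} = delta -> IO.write(delta.content || "")
        %Message{} = message -> IO.inspect(message.content, label: "COMPLETED")
      end

      {:ok, updated_chain, last_message} =
        LLMChain.run(chain, while_needs_response: true, callback_fn: callback)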
"""
@spec run(t(), Keyword.t()) :: {:ok, t(), Message.t() | [Message.t()]} | {:error, String.t()}
def run(chain, opts \\ [])
def run(%LLMChain{} = chain, opts) do
# set the callback function on the chain
chain = %LLMChain{chain | callback_fn: Keyword.get(opts, :callback_fn)}
if chain.verbose, do: IO.inspect(chain.llm, label: "LLM")
if chain.verbose, do: IO.inspect(chain.messages, label: "MESSAGES")
functions = chain.functions
if chain.verbose, do: IO.inspect(functions, label: "FUNCTIONS")
if Keyword.get(opts, :while_needs_response, false) do
run_while_needs_response(chain)
else
# run the chain and format the return
case do_run(chain) do
{:ok, chain} ->
{:ok, chain, chain.last_message}
{:error, _reason} = error ->
error
end
end
end
# Repeatedly run the chain while `needs_response` is true. This will execute
# functions and re-submit the function result to the LLM giving the LLM an
# opportunity to execute more functions or return a response.
@spec run_while_needs_response(t()) :: {:ok, t(), Message.t()} | {:error, String.t()}
defp run_while_needs_response(%LLMChain{needs_response: false} = chain) do
{:ok, chain, chain.last_message}
end
defp run_while_needs_response(%LLMChain{needs_response: true} = chain) do
chain
|> execute_function()
|> do_run()
|> case do
{:ok, updated_chain} ->
run_while_needs_response(updated_chain)
{:error, reason} ->
{:error, reason}
end
end
# internal reusable function for running the chain
@spec do_run(t()) :: {:ok, t()} | {:error, String.t()}
defp do_run(%LLMChain{} = chain) do
    # Submit to the LLM. The `llm` is a struct; match to extract its module,
    # then invoke that module's `call/4` function.
    %module{} = chain.llm

    # Handle the response.
case module.call(chain.llm, chain.messages, chain.functions, chain.callback_fn) do
{:ok, [%Message{} = message]} ->
if chain.verbose, do: IO.inspect(message, label: "SINGLE MESSAGE RESPONSE")
{:ok, add_message(chain, message)}
      {:ok, [%Message{} = message | _rest] = messages} ->
        if chain.verbose, do: IO.inspect(messages, label: "MULTIPLE MESSAGE RESPONSE")
        # Multiple messages happen when multiple "choices" are requested from
        # the LLM. Only the first response message is added to the chain.
        {:ok, add_message(chain, message)}
{:ok, [[%MessageDelta{} | _] | _] = deltas} ->
if chain.verbose, do: IO.inspect(deltas, label: "DELTA MESSAGE LIST RESPONSE")
{:ok, apply_deltas(chain, deltas)}
{:error, reason} ->
if chain.verbose, do: IO.inspect(reason, label: "ERROR")
Logger.error("Error during chat call. Reason: #{inspect(reason)}")
{:error, reason}
end
end
@doc """
Apply a received MessageDelta struct to the chain. The LLMChain tracks the
current merged MessageDelta state. When the final delta is received that
completes the message, the LLMChain is updated to clear the `delta` and the
`last_message` and list of messages are updated.
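
  ## Example

  A sketch of applying streamed deltas as they arrive. While the message is
  incomplete, the merged delta is tracked on `delta`; once the final delta
  arrives, `delta` resets to `nil` and `last_message` holds the finished
  `Message`:

      chain = LLMChain.apply_delta(chain, delta)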
"""
@spec apply_delta(t(), MessageDelta.t()) :: t()
def apply_delta(%LLMChain{delta: nil} = chain, %MessageDelta{} = new_delta) do
%LLMChain{chain | delta: new_delta}
end
def apply_delta(%LLMChain{delta: %MessageDelta{} = delta} = chain, %MessageDelta{} = new_delta) do
merged = MessageDelta.merge_delta(delta, new_delta)
    # If the merged delta is now complete, convert it to a message.
if merged.status in [:complete, :length] do
case MessageDelta.to_message(merged) do
{:ok, %Message{} = message} ->
fire_callback(chain, message)
add_message(%LLMChain{chain | delta: nil}, message)
{:error, reason} ->
# should not have failed, but it did. Log the error and return
# the chain unmodified.
Logger.warning("Error applying delta message. Reason: #{inspect(reason)}")
chain
end
else
# the delta message is not yet complete. Update the delta with the merged
# result.
%LLMChain{chain | delta: merged}
end
end
@doc """
Apply a list of deltas to the chain.
"""
@spec apply_deltas(t(), list()) :: t()
def apply_deltas(%LLMChain{} = chain, deltas) when is_list(deltas) do
deltas
|> List.flatten()
|> Enum.reduce(chain, fn d, acc -> apply_delta(acc, d) end)
end
@doc """
Add a received Message struct to the chain. The LLMChain tracks the
`last_message` received and the complete list of messages exchanged. Depending
on the message role, the chain may be in a pending or incomplete state where
a response from the LLM is anticipated.
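
  ## Example

  For example, adding a user message flags the chain as needing a response:

      chain = LLMChain.add_message(chain, Message.new_user!("Hello!"))
      chain.needs_response
      #=> true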
"""
@spec add_message(t(), Message.t()) :: t()
def add_message(%LLMChain{} = chain, %Message{} = new_message) do
needs_response =
cond do
new_message.role in [:user, :function_call, :function] -> true
Message.is_function_call?(new_message) -> true
new_message.role in [:system, :assistant] -> false
end
%LLMChain{
chain
| messages: chain.messages ++ [new_message],
last_message: new_message,
needs_response: needs_response
}
end
@doc """
Add a set of Message structs to the chain. This enables quickly building a chain
for submitting to an LLM.
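
  ## Example

  A sketch, assuming a `Message.new_system!/1` helper that accepts custom
  system content:

      chain =
        LLMChain.add_messages(chain, [
          Message.new_system!("You are a helpful assistant."),
          Message.new_user!("What is ElixirConf?")
        ])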
"""
@spec add_messages(t(), [Message.t()]) :: t()
def add_messages(%LLMChain{} = chain, messages) do
Enum.reduce(messages, chain, fn msg, acc ->
add_message(acc, msg)
end)
end
@doc """
Apply a set of PromptTemplates to the chain. The list of templates can also
include Messages with no templates. Provide the inputs to apply to the
templates for rendering as a message. The prepared messages are applied to the
chain.
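
  ## Example

  A sketch, assuming `PromptTemplate.from_template!/1` with EEx-style
  placeholders:

      template = PromptTemplate.from_template!("What is the capital of <%= @country %>?")
      chain = LLMChain.apply_prompt_templates(chain, [template], %{country: "France"})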
"""
@spec apply_prompt_templates(t(), [Message.t() | PromptTemplate.t()], %{atom() => any()}) ::
t() | no_return()
def apply_prompt_templates(%LLMChain{} = chain, templates, %{} = inputs) do
messages = PromptTemplate.to_messages!(templates, inputs)
add_messages(chain, messages)
end
@doc """
Convenience function for setting the prompt text for the LLMChain using
prepared text.
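
  ## Example

  This adds a default system message followed by the user's prompt text:

      chain = LLMChain.quick_prompt(chain, "What is an Elixir GenServer?")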
"""
@spec quick_prompt(t(), String.t()) :: t()
def quick_prompt(%LLMChain{} = chain, text) do
messages = [
Message.new_system!(),
Message.new_user!(text)
]
add_messages(chain, messages)
end
@doc """
  If the `last_message` is a `%Message{role: :function_call}`, the linked
  function is executed. If there is no `last_message`, or the `last_message`
  is not a `:function_call`, the LLMChain is returned with no action
  performed. This makes it safe to call at any time.

  The `context` is additional data passed to the executed function. A value
  given here overrides any `custom_context` set on the LLMChain; if not set,
  the chain's `custom_context` is used.

  https://platform.openai.com/docs/guides/gpt/function-calling
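
  ## Example

  A sketch: after the LLM responds with a function_call, execute it with
  per-call context. The `%{user_id: 123}` context shown is illustrative:

      chain = LLMChain.execute_function(chain, %{user_id: 123})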
"""
@spec execute_function(t(), context :: any()) :: t()
def execute_function(chain, context \\ nil)
def execute_function(%LLMChain{last_message: nil} = chain, _context), do: chain
def execute_function(
%LLMChain{last_message: %Message{} = message} = chain,
context
) do
if Message.is_function_call?(message) do
# context to use
use_context = context || chain.custom_context
# find and execute the linked function
case chain.function_map[message.function_name] do
%Function{} = function ->
if chain.verbose, do: IO.inspect(function.name, label: "EXECUTING FUNCTION")
# execute the function
result = Function.execute(function, message.arguments, use_context)
if chain.verbose, do: IO.inspect(result, label: "FUNCTION RESULT")
# add the :function response to the chain
function_result = Message.new_function!(function.name, result)
          # Fire the callback as this is a newly generated message.
fire_callback(chain, function_result)
LLMChain.add_message(chain, function_result)
nil ->
Logger.warning(
"Received function_call for missing function #{inspect(message.function_name)}"
)
chain
end
else
# Either not a function_call or an incomplete function_call, do nothing.
chain
end
end
# Fire the callback if set.
defp fire_callback(%LLMChain{callback_fn: nil}, _data), do: :ok
  # Execute the callback, handling both a single item and a list of items.
defp fire_callback(%LLMChain{callback_fn: callback_fn}, data) when is_function(callback_fn) do
case data do
value when is_list(value) ->
value
|> List.flatten()
|> Enum.each(fn item -> callback_fn.(item) end)
:ok
# not a list, pass the item as-is
item ->
callback_fn.(item)
:ok
end
end
@doc """
  Remove the in-progress `MessageDelta` from the chain, convert it to a
  `Message` with the desired status, and add it to the chain.
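
  ## Example

  A sketch: when the user aborts a streaming request, convert the partial
  delta into a message marked with the given status (the `:cancelled` status
  value is illustrative):

      chain = LLMChain.cancel_delta(chain, :cancelled)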
"""
def cancel_delta(%LLMChain{delta: nil} = chain, _message_status), do: chain
def cancel_delta(%LLMChain{delta: delta} = chain, message_status) do
# remove the in-progress delta
updated_chain = %LLMChain{chain | delta: nil}
case MessageDelta.to_message(%MessageDelta{delta | status: :complete}) do
{:ok, message} ->
message = %Message{message | status: message_status}
add_message(updated_chain, message)
{:error, reason} ->
Logger.error("Error attempting to cancel_delta. Reason: #{inspect(reason)}")
chain
end
end
end