defmodule Instructor do
  require Logger

  alias Instructor.JSONSchema

  # Recompile this module whenever the README changes, since @moduledoc below
  # is extracted from it at compile time.
  @external_resource "README.md"

  # The README carries the user-facing docs between a pair of `<!-- Docs -->`
  # markers; splitting on the marker yields [preamble, docs, trailer].
  [_, readme_docs, _] =
    "README.md"
    |> File.read!()
    |> String.split("<!-- Docs -->")

  @moduledoc """
  #{readme_docs}
  """

  @doc """
  Create a new chat completion for the provided messages and parameters.

  The parameters are passed directly to the LLM adapter.
  By default they shadow the OpenAI API parameters.
  For more information on the parameters, see the [OpenAI API docs](https://platform.openai.com/docs/api-reference/chat-completions/create).

  Additionally, the following parameters are supported:

    * `:response_model` - The Ecto schema to validate the response against.
    * `:max_retries` - The maximum number of times to retry the LLM call if it fails,
      or does not pass validations. (defaults to `0`)

  ## Examples

      iex> Instructor.chat_completion(%{
      ...>   model: "gpt-3.5-turbo",
      ...>   response_model: Instructor.Demos.SpamPrediction,
      ...>   messages: [
      ...>     %{
      ...>       role: "user",
      ...>       content: "Classify the following text: Hello, I am a Nigerian prince and I would like to give you $1,000,000."
      ...>     }
      ...>   ]
      ...> })
      {:ok,
        %Instructor.Demos.SpamPrediction{
          class: :spam,
          score: 0.999
        }}
  """
  @spec chat_completion(Keyword.t()) ::
          {:ok, Ecto.Schema.t()} | {:error, Ecto.Changeset.t()} | {:error, String.t()}
  def chat_completion(params) do
    response_model = Keyword.get(params, :response_model)

    # `function_exported?/3` returns `false` for modules that have not been
    # loaded yet (common in dev and iex), which would silently skip the user's
    # validations — ensure the module is loaded before probing for the
    # optional `validate_changeset/1` callback.
    validate =
      if Code.ensure_loaded?(response_model) and
           function_exported?(response_model, :validate_changeset, 1) do
        &response_model.validate_changeset/1
      else
        fn changeset -> changeset end
      end

    params =
      params
      |> Keyword.put(:validate, validate)
      |> Keyword.put_new(:max_retries, 0)
      |> Keyword.put_new(:mode, :tools)

    # Apply the tool/JSON-schema transformation exactly once, *before* the
    # retry loop. Doing it inside `do_chat_completion/1` would re-prepend the
    # schema system prompt on every retry, duplicating it in the context.
    params = params_for_tool(Keyword.fetch!(params, :mode), params)

    do_chat_completion(params)
  end

  # Runs one LLM round-trip and, on validation failure, retries (up to
  # `:max_retries` times) by feeding the model its previous answer plus the
  # changeset errors. Expects params already transformed by `params_for_tool/2`.
  defp do_chat_completion(params) do
    response_model = params[:response_model]
    validate = params[:validate]
    max_retries = params[:max_retries]
    mode = Keyword.get(params, :mode, :tools)

    # Note: bindings made in the `with` clauses (e.g. `decoded`, `changeset`)
    # are not visible in `else`, so the retry branch still sees the original
    # `params`.
    with {:llm, {:ok, response}} <- {:llm, adapter().chat_completion(params)},
         {:valid_json, {:ok, decoded}} <- {:valid_json, parse_response_for_mode(mode, response)},
         changeset <- to_changeset(response_model.__struct__(), decoded),
         {:validation, %Ecto.Changeset{valid?: true} = changeset, _response} <-
           {:validation, validate.(changeset), response} do
      {:ok, Ecto.Changeset.apply_changes(changeset)}
    else
      {:llm, {:error, error}} ->
        {:error, "LLM Adapter Error: #{inspect(error)}"}

      {:valid_json, {:error, error}} ->
        {:error, "Invalid JSON returned from LLM: #{inspect(error)}"}

      {:validation, changeset, response} ->
        if max_retries > 0 do
          errors = format_errors(changeset)
          Logger.debug("Retrying LLM call for #{response_model}...", errors: errors)

          # Echo the model's previous answer and the validation errors back so
          # it can correct itself on the next attempt.
          params =
            params
            |> Keyword.put(:max_retries, max_retries - 1)
            |> Keyword.update(:messages, [], fn messages ->
              messages ++
                echo_response(response) ++
                [
                  %{
                    role: "system",
                    content: """
                    The response did not pass validation. Please try again and fix the following validation errors:\n
                    #{errors}
                    """
                  }
                ]
            end)

          do_chat_completion(params)
        else
          {:error, changeset}
        end

      {:error, reason} ->
        {:error, reason}

      e ->
        {:error, e}
    end
  end

  # Extracts and decodes the JSON arguments of the (single, forced) tool call
  # from the adapter's response.
  defp parse_response_for_mode(:tools, %{
         choices: [
           %{
             "message" => %{
               "tool_calls" => [%{"function" => %{"arguments" => args}}]
             }
           }
         ]
       }) do
    Jason.decode(args)
  end

  # Any other shape (e.g. the model replied with plain content instead of a
  # tool call) is surfaced as an error tuple rather than a FunctionClauseError,
  # so it flows into the `{:valid_json, {:error, _}}` branch above.
  defp parse_response_for_mode(:tools, response) do
    {:error, response}
  end

  # Replays the model's previous tool call as an assistant message so the
  # retry prompt has the failed answer in context.
  defp echo_response(%{choices: [%{"message" => %{"tool_calls" => [function]}}]}) do
    [
      %{
        role: "assistant",
        content: Jason.encode!(function)
      }
    ]
  end

  #
  # Though technically correct for the tools api, seems to yield worse results.
  # Leaving here to investigate further later.
  # defp echo_response(%{
  #        choices: [
  #          %{
  #            "message" =>
  #              %{
  #                "tool_calls" => [
  #                  %{"id" => tool_call_id, "function" => %{"name" => name, "arguments" => args}} =
  #                    function
  #                ]
  #              } = message
  #          }
  #        ]
  #      }) do
  #   [
  #     Map.put(message, "content", function |> Jason.encode!()),
  #     %{
  #       role: "tool",
  #       tool_call_id: tool_call_id,
  #       name: name,
  #       content: args
  #     }
  #   ]
  # end

  # Injects the response model's JSON schema into the request: a system
  # message describing the schema, a single `tools` entry for the extraction
  # function, and a `tool_choice` forcing the model to call it.
  defp params_for_tool(:tools, params) do
    response_model = Keyword.fetch!(params, :response_model)
    json_schema = JSONSchema.from_ecto_schema(response_model)
    title = JSONSchema.title_for(response_model)

    params
    |> Keyword.update(:messages, [], fn messages ->
      sys_message = %{
        role: "system",
        content: """
        As a genius expert, your task is to understand the content and provide the parsed objects in json that match the following json_schema:\n
        #{json_schema}
        """
      }

      [sys_message | messages]
    end)
    |> Keyword.put(:tools, [
      %{
        type: "function",
        function: %{
          "description" =>
            "Correctly extracted `#{title}` with all the required parameters with correct types",
          "name" => title,
          "parameters" => json_schema |> Jason.decode!()
        }
      }
    ])
    |> Keyword.put(:tool_choice, %{
      type: "function",
      function: %{name: title}
    })
  end

  # Recursively casts the decoded JSON into a changeset for the schema,
  # handling plain fields, embeds, and associations separately.
  defp to_changeset(schema, params) do
    response_model = schema.__struct__
    fields = response_model.__schema__(:fields) |> MapSet.new()
    embedded_fields = response_model.__schema__(:embeds) |> MapSet.new()
    associated_fields = response_model.__schema__(:associations) |> MapSet.new()

    # Embeds and associations must be cast with cast_embed/cast_assoc, so
    # remove them from the plain-field cast.
    fields =
      fields
      |> MapSet.difference(embedded_fields)
      |> MapSet.difference(associated_fields)

    changeset = Ecto.Changeset.cast(schema, params, MapSet.to_list(fields))

    changeset =
      for field <- embedded_fields, reduce: changeset do
        changeset -> Ecto.Changeset.cast_embed(changeset, field, with: &to_changeset/2)
      end

    for field <- associated_fields, reduce: changeset do
      changeset -> Ecto.Changeset.cast_assoc(changeset, field, with: &to_changeset/2)
    end
  end

  # Renders all changeset errors as a single human-readable string for the
  # retry prompt, interpolating `%{...}` placeholders from the error opts.
  defp format_errors(changeset) do
    changeset
    |> Ecto.Changeset.traverse_errors(fn _changeset, _field, {msg, opts} ->
      Regex.replace(~r"%{(\w+)}", msg, fn _, key ->
        opts |> Keyword.get(String.to_existing_atom(key), key) |> to_string()
      end)
    end)
    |> Map.values()
    |> List.flatten()
    |> Enum.join(", and ")
  end

  # The LLM backend, configurable at runtime; defaults to the OpenAI adapter.
  defp adapter() do
    Application.get_env(:instructor, :adapter, Instructor.Adapters.OpenAI)
  end
end