lib/instructor.ex

defmodule Instructor do
  require Logger

  alias Instructor.JSONSchema

  # Recompile this module whenever README.md changes, since its contents are
  # embedded into the module documentation below.
  @external_resource "README.md"

  # The README is split on the `<!-- Docs -->` marker and only the middle part is
  # kept. The three-element match fails at compile time unless the marker appears
  # exactly twice in the file.
  [_, readme_docs, _] =
    "README.md"
    |> File.read!()
    |> String.split("<!-- Docs -->")

  @moduledoc """
  #{readme_docs}
  """

  # Treats any atom as an Ecto schema module. This only checks atom-ness; it does
  # not verify that the module actually defines a schema. Callers in this module
  # send non-atom response models (schemaless type maps) down the `{%{}, types}` path.
  defguardp is_ecto_schema(mod) when is_atom(mod)

  @doc """
  Create a new chat completion for the provided messages and parameters.

  `params` is a keyword list. The parameters are passed directly to the LLM adapter.
  By default they shadow the OpenAI API parameters.
  For more information on the parameters, see the [OpenAI API docs](https://platform.openai.com/docs/api-reference/chat-completions/create).

  Additionally, the following parameters are supported:

    * `:response_model` - (required) The Ecto schema to validate the response against.
      A `KeyError` is raised when it is missing. When `:stream` is `true` it must be
      wrapped as `{:array, model}`, `{:partial, model}` or `{:partial, {:array, model}}`.
    * `:stream` - When `true`, a lazy stream of results is returned instead of a
      single result. (defaults to `false`)
    * `:validation_context` - Passed as the second argument to the response model's
      `validate_changeset/2`, when the model defines it. (defaults to `%{}`)
    * `:mode` - How the schema is presented to the LLM. (defaults to `:tools`)
    * `:max_retries` - The maximum number of times to retry the LLM call if it fails, or does not pass validations.
                       (defaults to `0`)

  ## Return values

  Without streaming the result is `{:ok, struct}` or `{:error, reason}`, where
  `reason` is an invalid changeset or an error message.

  With streaming the result is an enumerable. For `{:array, model}` and
  `{:partial, model}` each element is `{:ok, value}` or `{:error, changeset}`; for
  `{:partial, {:array, model}}` each element is a list of such tuples.

  ## Examples

      iex> Instructor.chat_completion(
      ...>   model: "gpt-3.5-turbo",
      ...>   response_model: Instructor.Demos.SpamPrediction,
      ...>   messages: [
      ...>     %{
      ...>       role: "user",
      ...>       content: "Classify the following text: Hello, I am a Nigerian prince and I would like to give you $1,000,000."
      ...>     }
      ...>   ]
      ...> )
      {:ok,
          %Instructor.Demos.SpamPrediction{
              class: :spam,
              score: 0.999
          }}
  """
  @spec chat_completion(Keyword.t()) ::
          {:ok, Ecto.Schema.t()}
          | {:error, Ecto.Changeset.t()}
          | {:error, String.t()}
          | Enumerable.t()
  def chat_completion(params) do
    params =
      params
      |> Keyword.put_new(:max_retries, 0)
      |> Keyword.put_new(:mode, :tools)

    is_stream = Keyword.get(params, :stream, false)
    response_model = Keyword.fetch!(params, :response_model)

    # Clause order matters: the most specific wrapper, `{:partial, {:array, _}}`,
    # must be tried before the generic `{:partial, _}`.
    case {response_model, is_stream} do
      {{:partial, {:array, response_model}}, true} ->
        do_streaming_partial_array_chat_completion(response_model, params)

      {{:partial, response_model}, true} ->
        do_streaming_partial_chat_completion(response_model, params)

      {{:array, response_model}, true} ->
        do_streaming_array_chat_completion(response_model, params)

      {response_model, false} ->
        do_chat_completion(response_model, params)
    end
  end

  @doc """
  Casts `params` onto every field of the given model.

  For a schemaless `{data, types}` tuple, every key of `types` is cast and then
  marked as required.

  For a schema struct, all fields that are neither embeds nor associations are
  cast directly, and each embed and association is cast recursively with this
  same function. No required-field validation is applied in this case.
  """
  def cast_all({data, types}, params) do
    all_keys = Map.keys(types)
    changeset = Ecto.Changeset.cast({data, types}, params, all_keys)

    Ecto.Changeset.validate_required(changeset, all_keys)
  end

  def cast_all(schema, params) do
    response_model = schema.__struct__

    embeds = MapSet.new(response_model.__schema__(:embeds))
    assocs = MapSet.new(response_model.__schema__(:associations))

    # Embeds and associations are handled separately below, so they are removed
    # from the list of directly castable fields.
    plain_fields =
      response_model.__schema__(:fields)
      |> MapSet.new()
      |> MapSet.difference(embeds)
      |> MapSet.difference(assocs)
      |> MapSet.to_list()

    base = Ecto.Changeset.cast(schema, params, plain_fields)

    with_embeds =
      Enum.reduce(embeds, base, fn embed, acc ->
        Ecto.Changeset.cast_embed(acc, embed, with: &cast_all/2)
      end)

    Enum.reduce(assocs, with_embeds, fn assoc, acc ->
      Ecto.Changeset.cast_assoc(acc, assoc, with: &cast_all/2)
    end)
  end

  # Streams a `{:partial, {:array, response_model}}` completion.
  #
  # Each element of the returned stream is a list with one entry per element of
  # the parsed `"value"` list; every entry is `{:ok, item}` or `{:error, changeset}`.
  defp do_streaming_partial_array_chat_completion(response_model, params) do
    validation_context = Keyword.get(params, :validation_context, %{})

    response_model
    |> stream_tool_chunks(:many, params)
    |> Instructor.JSONStreamParser.parse()
    |> Stream.map(fn partial ->
      partial
      |> Map.get("value", [])
      |> Enum.map(&cast_streamed_item(response_model, &1, validation_context))
    end)
  end

  # Streams a `{:partial, response_model}` completion.
  #
  # Each element of the returned stream is `{:ok, value}` or `{:error, changeset}`,
  # built from the `"value"` key of the parsed document (or `%{}` when absent).
  defp do_streaming_partial_chat_completion(response_model, params) do
    validation_context = Keyword.get(params, :validation_context, %{})

    response_model
    |> stream_tool_chunks(:one, params)
    |> Instructor.JSONStreamParser.parse()
    |> Stream.map(fn partial ->
      item_params = Map.get(partial, "value", %{})
      cast_streamed_item(response_model, item_params, validation_context)
    end)
  end

  # Streams an `{:array, response_model}` completion.
  #
  # Each element matched by the `[:root, "value", :all]` Jaxon query is emitted as
  # `{:ok, item}` or `{:error, changeset}`.
  defp do_streaming_array_chat_completion(response_model, params) do
    validation_context = Keyword.get(params, :validation_context, %{})

    response_model
    |> stream_tool_chunks(:many, params)
    |> Jaxon.Stream.from_enumerable()
    |> Jaxon.Stream.query([:root, "value", :all])
    |> Stream.map(&cast_streamed_item(response_model, &1, validation_context))
  end

  # Wraps `response_model` in a single-field `value` embed of the given cardinality,
  # asks the adapter for a completion and maps every raw chunk to its JSON fragment.
  defp stream_tool_chunks(response_model, cardinality, params) do
    wrapped_model = %{
      value:
        {:parameterized, Ecto.Embedded,
         %Ecto.Embedded{cardinality: cardinality, related: response_model}}
    }

    mode = Keyword.get(params, :mode, :tools)
    llm_params = params_for_tool(mode, wrapped_model, params)
    chunks = adapter().chat_completion(llm_params)

    Stream.map(chunks, &parse_stream_chunk_for_mode(mode, &1))
  end

  # Casts one decoded item onto `response_model` and runs the model's validation.
  # Returns `{:ok, applied_changes}` for a valid changeset, `{:error, changeset}` otherwise.
  defp cast_streamed_item(response_model, item_params, validation_context) do
    model =
      if is_ecto_schema(response_model) do
        response_model.__struct__()
      else
        {%{}, response_model}
      end

    changeset = cast_all(model, item_params)

    # Validation is the only step that can reject an item, so no further error
    # shapes need to be handled here.
    case call_validate(response_model, changeset, validation_context) do
      %Ecto.Changeset{valid?: true} = valid -> {:ok, Ecto.Changeset.apply_changes(valid)}
      invalid -> {:error, invalid}
    end
  end

  # Runs one non-streaming completion, casts the decoded tool arguments onto
  # `response_model` and validates the result.
  #
  # Returns `{:ok, value}` on success. On failure returns `{:error, message}` for
  # adapter or JSON errors, or `{:error, changeset}` when validation fails and no
  # retries are left. While `:max_retries` is positive, a validation failure
  # triggers another call with the previous response and the formatted errors
  # appended to the conversation.
  defp do_chat_completion(response_model, params) do
    validation_context = Keyword.get(params, :validation_context, %{})
    max_retries = Keyword.get(params, :max_retries, 0)
    mode = Keyword.get(params, :mode, :tools)

    # `params` is kept exactly as received. The tool-augmented copy is used only
    # for this request, so a retry does not prepend the schema system prompt a
    # second time when it recurses through this function.
    llm_params = params_for_tool(mode, response_model, params)

    model =
      if is_ecto_schema(response_model) do
        response_model.__struct__()
      else
        {%{}, response_model}
      end

    with {:llm, {:ok, response}} <- {:llm, adapter().chat_completion(llm_params)},
         {:valid_json, {:ok, decoded}} <- {:valid_json, parse_response_for_mode(mode, response)},
         changeset = cast_all(model, decoded),
         {:validation, %Ecto.Changeset{valid?: true} = changeset, _response} <-
           {:validation, call_validate(response_model, changeset, validation_context), response} do
      {:ok, Ecto.Changeset.apply_changes(changeset)}
    else
      {:llm, {:error, error}} ->
        {:error, "LLM Adapter Error: #{inspect(error)}"}

      {:valid_json, {:error, error}} ->
        {:error, "Invalid JSON returned from LLM: #{inspect(error)}"}

      {:validation, changeset, response} ->
        if max_retries > 0 do
          errors = Instructor.ErrorFormatter.format_errors(changeset)
          Logger.debug("Retrying LLM call for #{inspect(response_model)}...", errors: errors)

          retry_params =
            params
            |> Keyword.put(:max_retries, max_retries - 1)
            |> Keyword.update(:messages, [], fn messages ->
              messages ++
                echo_response(response) ++
                [
                  %{
                    role: "system",
                    content: """
                    The response did not pass validation. Please try again and fix the following validation errors:\n

                    #{errors}
                    """
                  }
                ]
            end)

          do_chat_completion(response_model, retry_params)
        else
          {:error, changeset}
        end

      # Every step above is tagged, so this only catches an adapter result that is
      # neither `{:ok, _}` nor `{:error, _}`.
      e ->
        {:error, e}
    end
  end

  # Decodes the JSON arguments of the single tool call in a `:tools` mode response.
  # Only responses with exactly one choice carrying exactly one tool call match.
  # Returns whatever `Jason.decode/1` returns for the arguments string.
  defp parse_response_for_mode(:tools, %{
         choices: [
           %{"message" => %{"tool_calls" => [%{"function" => %{"arguments" => encoded_args}}]}}
         ]
       }) do
    Jason.decode(encoded_args)
  end

  # Maps one streamed `:tools` chunk to the JSON fragment it carries: the tool
  # call's `"arguments"` text for a delta chunk, or an empty string for the
  # final chunk whose finish reason is `"stop"`.
  defp parse_stream_chunk_for_mode(:tools, %{
         "choices" => [
           %{"delta" => %{"tool_calls" => [%{"function" => %{"arguments" => fragment}}]}}
         ]
       }) do
    fragment
  end

  defp parse_stream_chunk_for_mode(:tools, %{"choices" => [%{"finish_reason" => "stop"}]}) do
    ""
  end

  # Rebuilds the two conversation messages that replay the LLM's tool call:
  # the assistant message (content replaced by the JSON-encoded tool call, keys
  # turned into atoms) followed by a `"tool"` message carrying the raw arguments.
  defp echo_response(%{
         choices: [
           %{
             "message" =>
               %{
                 "tool_calls" => [
                   %{"id" => call_id, "function" => %{"name" => fn_name, "arguments" => raw_args}} =
                     tool_call
                 ]
               } = message
           }
         ]
       }) do
    # NOTE(review): `String.to_atom/1` runs on keys taken from the provider
    # response; confirm the key set is bounded, since atoms are never collected.
    assistant_message =
      message
      |> Map.put("content", Jason.encode!(tool_call))
      |> Map.new(fn {key, value} -> {String.to_atom(key), value} end)

    tool_message = %{
      role: "tool",
      tool_call_id: call_id,
      name: fn_name,
      content: raw_args
    }

    [assistant_message, tool_message]
  end

  # Augments `params` for `:tools` mode: prepends a system message containing the
  # model's JSON schema to the existing `:messages`, registers a single function
  # tool described by that schema and forces the LLM to call it via `:tool_choice`.
  defp params_for_tool(:tools, response_model, params) do
    json_schema = JSONSchema.from_ecto_schema(response_model)
    title = response_model |> JSONSchema.title_for() |> sanitize()

    sys_message = %{
      role: "system",
      content: """
      As a genius expert, your task is to understand the content and provide the parsed objects in json that match the following json_schema:\n

      #{json_schema}
      """
    }

    tool = %{
      type: "function",
      function: %{
        "description" =>
          "Correctly extracted `#{title}` with all the required parameters with correct types",
        "name" => title,
        "parameters" => Jason.decode!(json_schema)
      }
    }

    tool_choice = %{type: "function", function: %{name: title}}

    params
    |> Keyword.update(:messages, [], fn existing -> [sys_message | existing] end)
    |> Keyword.put(:tools, [tool])
    |> Keyword.put(:tool_choice, tool_choice)
  end

  # Turns a schema title into a tool name: underscores and dots become dashes,
  # question marks are removed.
  defp sanitize(title) do
    dashed = String.replace(title, ["_", "."], "-")
    String.replace(dashed, "?", "")
  end

  # Runs the response model's optional `validate_changeset` callback.
  # Non-schema models are returned untouched. For schema modules the arity-1
  # callback takes precedence over the arity-2 one, which also receives `context`;
  # when neither is exported the changeset is returned as is.
  defp call_validate(response_model, changeset, _context)
       when not is_ecto_schema(response_model),
       do: changeset

  defp call_validate(response_model, changeset, context) do
    exported_arity =
      Enum.find([1, 2], fn arity ->
        function_exported?(response_model, :validate_changeset, arity)
      end)

    case exported_arity do
      1 -> response_model.validate_changeset(changeset)
      2 -> response_model.validate_changeset(changeset, context)
      nil -> changeset
    end
  end

  # Looks up the adapter module under the `:instructor` application's `:adapter`
  # key at call time, falling back to `Instructor.Adapters.OpenAI`.
  defp adapter do
    default_adapter = Instructor.Adapters.OpenAI
    Application.get_env(:instructor, :adapter, default_adapter)
  end
end