Skip to main content

lib/llm/adapter.ex

defmodule LLM.Adapter do
  @moduledoc """
  Behaviour for wire format translation between normalized messages and provider APIs.

  Each adapter handles:
  - Encoding `LLM.Context` into provider-specific JSON request bodies
  - Providing the streaming endpoint path and extra headers
  - Decoding full responses back into `LLM.Response.t()`
  - Decoding SSE chunks back into normalized stream structs

  Built-in adapters:

  | Module | Provider API |
  |---|---|
  | `LLM.Adapter.OpenAI` | OpenAI Chat Completions |
  | `LLM.Adapter.OpenAIResponse` | OpenAI Responses API |
  | `LLM.Adapter.Anthropic` | Anthropic Messages API |
  | `LLM.Adapter.Gemini` | Google Gemini API |

  ## Implementing a custom adapter

  Implement all required callbacks plus any optional ones your provider needs:

      defmodule MyApp.Adapter do
        @behaviour LLM.Adapter

        @impl true
        def build_request(context, opts) do
          %{
            "model" => Keyword.fetch!(opts, :model),
            # encode_messages/1 is your own provider-specific helper
            "messages" => encode_messages(context)
          }
        end

        @impl true
        def decode_response(%{"choices" => [choice | _]} = raw) do
          {:ok,
           %LLM.Response{
             message: %LLM.Message{
               role: :assistant,
               content: get_in(choice, ["message", "content"])
             },
             raw: raw
           }}
        end

        def decode_response(raw), do: {:error, {:unexpected_response, raw}}

        @impl true
        def stream_path, do: "/v1/chat/completions"

        # Content-Type and auth headers are added by `LLM.Stream`;
        # only return provider-specific extras here.
        @impl true
        def stream_headers(_opts), do: []

        @impl true
        def decode_chunk("data: [DONE]", state), do: {:done, state}

        def decode_chunk("data: " <> json, state) do
          case Jason.decode!(json) do
            %{"choices" => [%{"delta" => %{"content" => text}} | _]} when is_binary(text) ->
              {[%LLM.Stream.Chunk{text: text}], state}

            _other ->
              {[], state}
          end
        end

        # Lines without a `data:` payload produce no chunks.
        def decode_chunk(_line, state), do: {[], state}
      end
  """

  @type chunk_type ::
          LLM.Stream.Chunk.t()
          | LLM.Stream.ToolCall.t()
          | LLM.Stream.Thinking.t()
          | LLM.Stream.Stop.t()
          | LLM.Stream.Error.t()

  @type model_info :: %{
          id: String.t(),
          name: String.t() | nil,
          description: String.t() | nil,
          context_window: non_neg_integer() | nil,
          max_output: non_neg_integer() | nil,
          capabilities: [atom()] | nil
        }

  @doc """
  Build the request body from a context and options.

  Receives an `LLM.Context.t()` and a keyword list of options (including `:model`).
  Returns a map suitable for `Jason.encode!/1`.
  """
  @callback build_request(LLM.Context.t(), keyword()) :: map()

  @doc """
  Decode a full (non-streaming) response from the provider.

  Receives the decoded JSON response body. Returns `{:ok, LLM.Response.t()}` or
  `{:error, reason}`.
  """
  @callback decode_response(map()) :: {:ok, LLM.Response.t()} | {:error, term()}

  @doc """
  Return the streaming endpoint path (e.g., `"/v1/chat/completions"`).

  May contain `{model}` which is replaced with the actual model name.
  """
  @callback stream_path() :: String.t()

  @doc """
  Return the non-streaming endpoint path.

  Optional. If unimplemented, the streaming path is used.
  May contain `{model}` which is replaced with the actual model name.
  """
  @callback non_stream_path() :: String.t()

  @doc """
  Return extra HTTP headers needed for the streaming request.

  Content-Type and auth headers are added separately by `LLM.Stream`, so
  implementations should only return provider-specific extras (or `[]`).
  """
  @callback stream_headers(keyword()) :: [{String.t(), String.t()}]

  @doc """
  Decode a single SSE event into stream chunks.

  Receives the raw SSE data string and the current adapter state. Returns
  `{chunks, new_state}`, where `chunks` is a (possibly empty) list of stream
  structs, or `{:done, new_state}` to signal the end of the stream.
  """
  @callback decode_chunk(String.t(), map()) :: {[chunk_type()] | :done, map()}

  @doc """
  Decode a single SSE event (stateless variant).

  Optional. Intended for adapters that don't need to track state across chunks.
  Returns a list of stream structs, or `:done` at the end of the stream.
  """
  @callback decode_chunk(String.t()) :: [chunk_type()] | :done

  @doc """
  Normalize a thinking/reasoning option to the provider's expected format.

  Optional. Maps atoms like `:low`, `:medium`, `:high` to provider-specific values.
  """
  @callback normalize_thinking(term()) :: term()

  @doc """
  Return authentication headers for the provider.

  Optional. Receives the provider config map and options. Returns a list of
  header tuples.
  """
  @callback auth_headers(map(), keyword()) :: [{String.t(), String.t()}]

  @doc """
  Initialize the adapter state for a new streaming request.

  Optional. Called once when a stream starts; the returned map is passed as the
  initial state to `decode_chunk/2`.
  """
  @callback init_stream_state() :: map()

  @doc """
  List available models from the provider.

  Returns `{:ok, [model_info()]}` or `{:error, reason}`.
  Optional: adapters for providers without a model-listing endpoint may leave
  it unimplemented.
  """
  @callback list_models(map()) :: {:ok, [model_info()]} | {:error, term()}

  @optional_callbacks [
    decode_chunk: 1,
    normalize_thinking: 1,
    auth_headers: 2,
    init_stream_state: 0,
    list_models: 1,
    non_stream_path: 0
  ]

  @doc """
  Extract the name and schema from a structured output spec.

  Handles three shapes, tried in this order:
  - `%{name: name, schema: schema}` (atom keys)
  - `%{"name" => name, "schema" => schema}` (string keys)
  - A bare schema map (defaults to name `"output"`)

  In both keyed shapes the name is converted with `to_string/1`, so atom or
  string names are accepted and a binary is always returned.

  ## Examples

      iex> LLM.Adapter.extract_schema(%{name: :person, schema: %{"type" => "object"}})
      {"person", %{"type" => "object"}}

      iex> LLM.Adapter.extract_schema(%{"type" => "object"})
      {"output", %{"type" => "object"}}
  """
  @spec extract_schema(map()) :: {String.t(), map()}
  def extract_schema(%{name: name, schema: schema}), do: {to_string(name), schema}
  # Normalize here too so the return always matches the `String.t()` spec,
  # even when a caller mixes string keys with an atom name.
  def extract_schema(%{"name" => name, "schema" => schema}), do: {to_string(name), schema}
  def extract_schema(schema), do: {"output", schema}
end