Skip to main content

lib/llm/stream.ex

defmodule LLM.Stream do
  @moduledoc """
  Low-level stream handling for LLM API responses.

  For most use cases prefer `LLM.stream/3` which opens the stream and
  calls `collect/2` for you, returning a `LLM.Response` directly.

  ## Basic usage

      {:ok, stream} = LLM.Stream.start(context, opts)
      {:ok, response} = LLM.Stream.collect(stream, on_chunk: fn c -> IO.write(c.text) end)

  Here `start/2` opens the streaming request and `collect/2` drains it into the
  final `LLM.Response`; `LLM.stream/3` wraps these two calls for you.

  ## Chunk types

  Each call to `next/1` returns a list of typed chunk structs:

  | Struct | Meaning |
  |--------|---------|
  | `LLM.Stream.Chunk` | Text delta (`:text` field) |
  | `LLM.Stream.Thinking` | Reasoning delta (`:text` and optional `:signature`) |
  | `LLM.Stream.ToolCall` | Tool invocation (`:id`, `:name`, `:arguments`, `:complete`) |
  | `LLM.Stream.Stop` | End of generation (`:reason`, `:usage`) |
  | `LLM.Stream.Error` | Provider-level error (`:message`) |

  ## Manual iteration

      {:ok, stream} = LLM.Stream.start(context, provider: :anthropic, model: "claude-sonnet-4-6")

      chunks =
        stream
        |> Stream.unfold(fn stream ->
          case LLM.Stream.next(stream) do
            {:ok, chunks, stream} -> {chunks, stream}
            {:halt, _stream} -> nil
            {:error, reason} -> raise "stream failed: \#{inspect(reason)}"
          end
        end)
        |> Enum.concat()

  ## Tool call loop

  When `auto_tools: true` (default), `collect/2` automatically:

  1. Detects complete `ToolCall` chunks at end of stream
  2. Looks up and executes each tool in `context.tools`
  3. Appends the assistant turn + tool results to the conversation
  4. Opens a new stream and repeats (up to `max_rounds` times)

  Use `:on_message` to observe each completed turn. It fires once per finished
  `LLM.Message` in conversation order — every assistant turn, and every tool
  result (as a `role: :tool` message) — including the final assistant turn:

      {:ok, response} = LLM.Stream.collect(stream,
        on_message: fn
          %LLM.Message{role: :assistant} = msg -> IO.puts("assistant: \#{msg.content}")
          %LLM.Message{role: :tool} = msg -> IO.puts("tool \#{msg.name}: \#{msg.content}")
        end
      )

  Pair it with `:on_chunk` to observe both token-level deltas and turn-level
  messages in a single pass. The full reconstructed conversation is also
  available afterwards as `response.messages`.

  Avoid sending messages to the calling process's mailbox from inside the
  callback, as `next/1`'s `receive` loop will consume and discard unknown
  messages — use an `Agent` or ETS table to accumulate instead.
  """

  # State for one streaming request. Built by `start/2`, advanced by `next/1`
  # and `parse_chunks/2`; every tool round gets a fresh struct from `start/2`.
  defstruct [
    # `Req.Response` of the open request; mailbox messages are parsed against it
    :ref,
    # Adapter module taken from the resolved provider
    :adapter,
    # Resolved provider (supplies `adapter`, `base_url` and `api_key`)
    :provider,
    # Conversation context the request was built from
    :context,
    # Options given to `start/2`, reused when a tool round re-opens the stream
    :opts,
    # Final request URL and headers
    :url,
    :headers,
    # Adapter-owned decoding state threaded through every decoded event
    adapter_state: %{},
    # Number of tool rounds already executed
    rounds: 0,
    # Length of the `:messages` option; messages past this index belong to the session
    base_message_count: 0,
    # Trailing SSE data that does not yet form a complete event
    buffer: "",
    # Set once the stream has signalled completion
    done: false,
    # Wait (ms) used by `next/1` only when `deadline` is nil
    timeout: 30_000,
    # Monotonic-clock instant (ms) after which `next/1` returns `{:error, :timeout}`
    deadline: nil,
    # Usage summed over the completed tool rounds
    accumulated_usage: %LLM.Usage{}
  ]

  @type t :: %__MODULE__{
          ref: Req.Response.t(),
          adapter: module(),
          provider: map(),
          context: LLM.Context.t(),
          opts: keyword(),
          url: String.t(),
          headers: [{String.t(), String.t()}],
          adapter_state: map(),
          rounds: non_neg_integer(),
          base_message_count: non_neg_integer(),
          buffer: String.t(),
          done: boolean(),
          timeout: non_neg_integer(),
          deadline: integer() | nil,
          accumulated_usage: LLM.Usage.t()
        }

  # Every struct that can appear in the chunk list returned by `next/1`.
  @type chunk_type ::
          LLM.Stream.Chunk.t()
          | LLM.Stream.ToolCall.t()
          | LLM.Stream.Thinking.t()
          | LLM.Stream.Stop.t()
          | LLM.Stream.Error.t()

  @doc """
  Start a streaming request. Returns `{:ok, stream}` or `{:error, reason}`.

  Resolves the provider, builds the request body and headers through the
  provider's adapter, and POSTs with `into: :self`, so response messages are
  delivered to the calling process's mailbox and consumed by `next/1`.

  ## Options

  - `:model` — required; an `ArgumentError` is raised when it is missing
  - `:provider` — provider to resolve (default: `:openai`)
  - `:messages` — prior conversation; only its length is read here, to mark
    where this session's messages begin
  - `:stream_timeout` — milliseconds the whole stream may take before `next/1`
    returns `{:error, :timeout}` (default: `30_000`)

  The full option list is also handed to the adapter's `build_request/2` and
  `stream_headers/1`.

  A non-200 response yields `{:error, %{status: status, body: body}}`, with a
  JSON body decoded when possible. Transport errors are returned unchanged.
  """
  @spec start(LLM.Context.t(), keyword()) :: {:ok, t()} | {:error, term()}
  def start(context, opts) do
    provider = LLM.Provider.Resolver.resolve(opts[:provider] || :openai)
    adapter = provider.adapter
    model = opts[:model] || raise ArgumentError, "model is required"

    # Previously a hard-coded 30_000 that ignored the struct's `timeout` field.
    timeout = opts[:stream_timeout] || 30_000

    request_body = adapter.build_request(context, [model: model] ++ opts)
    path = build_path(adapter.stream_path(), model)

    url = build_url(provider.base_url, path)
    headers = build_headers(provider, opts)

    req =
      Req.new(
        url: url,
        method: :post,
        headers: headers,
        json: request_body,
        into: :self
      )

    case LLM.HTTPClient.request(req) do
      {:ok, %Req.Response{status: 200} = resp} ->
        stream = %__MODULE__{
          ref: resp,
          adapter: adapter,
          provider: provider,
          context: context,
          opts: opts,
          url: url,
          headers: headers,
          adapter_state: init_adapter_state(adapter),
          base_message_count: length(opts[:messages] || []),
          timeout: timeout,
          # Absolute deadline on the monotonic clock, so repeated `next/1`
          # calls share one budget instead of restarting the wait each time.
          deadline: System.monotonic_time(:millisecond) + timeout
        }

        {:ok, stream}

      {:ok, %Req.Response{status: status, body: body} = response} ->
        {:error, %{status: status, body: normalize_error_body(body, response)}}

      {:error, _} = err ->
        err
    end
  end

  # Makes the body of a failed response inspectable:
  #   * an async body is drained into a binary, then JSON-decoded when applicable
  #   * a binary body is JSON-decoded when applicable
  #   * anything else is returned untouched
  @doc false
  def normalize_error_body(%Req.Response.Async{} = async_body, response) do
    collected = Enum.to_list(async_body)
    maybe_decode_json_body(IO.iodata_to_binary(collected), response)
  end

  def normalize_error_body(body, response) when is_binary(body) do
    maybe_decode_json_body(body, response)
  end

  def normalize_error_body(other, _response), do: other

  # Returns the decoded JSON term when the body looks like JSON and parses
  # cleanly; otherwise hands the body back unchanged.
  @doc false
  def maybe_decode_json_body(body, response) do
    with true <- should_decode_json_body?(body, response),
         {:ok, decoded} <- Jason.decode(body) do
      decoded
    else
      _not_json -> body
    end
  end

  # A body is worth decoding when the response declares a JSON media type or,
  # failing that, when the payload itself starts like a JSON document.
  defp should_decode_json_body?(body, response) do
    if json_content_type?(response), do: true, else: json_body?(body)
  end

  # True when the first `content-type` header value names a JSON media type.
  @doc false
  def json_content_type?(%Req.Response{} = response) do
    with [first | _others] <- Req.Response.get_header(response, "content-type"),
         true <- is_binary(first) do
      json_media_type?(first)
    else
      _missing -> false
    end
  end

  # Strips parameters (everything from the first `;`), then accepts
  # `application/json` and any `+json` structured-syntax suffix, ignoring case.
  defp json_media_type?(content_type) do
    [media_type | _params] = String.split(content_type, ";", parts: 2)
    essence = media_type |> String.downcase() |> String.trim()

    String.ends_with?(essence, "+json") or essence == "application/json"
  end

  # Cheap sniff: after leading whitespace, does the body open a JSON object or array?
  @doc false
  def json_body?(body) do
    body
    |> String.trim_leading()
    |> String.starts_with?(["{", "["])
  end

  # Substitutes every `{model}` placeholder in the adapter's stream path.
  # Paths without the placeholder are returned as-is.
  @doc false
  def build_path(path, model) do
    case String.contains?(path, "{model}") do
      true -> String.replace(path, "{model}", model)
      false -> path
    end
  end

  # Joins `base_url` and `path` without repeating segments they share: when the
  # end of the base path equals the start of `path` (for example both carry
  # `v1`), the overlap is emitted once. Query and fragment come from `path`
  # when present, otherwise from `base_url`.
  @doc false
  def build_url(base_url, path) do
    %URI{path: base_path, query: base_query, fragment: base_fragment} =
      base = URI.parse(base_url)

    %URI{path: extra_path, query: extra_query, fragment: extra_fragment} = URI.parse(path)

    base_parts = split_path_segments(base_path)
    extra_parts = split_path_segments(extra_path)
    shared = overlapping_segments(base_parts, extra_parts)

    joined = join_path_segments(base_parts ++ Enum.drop(extra_parts, shared))

    URI.to_string(%URI{
      base
      | path: joined,
        query: extra_query || base_query,
        fragment: extra_fragment || base_fragment
    })
  end

  # Breaks a URI path into its non-empty segments; a missing path has none.
  defp split_path_segments(nil), do: []

  # `trim: true` already discards the empty strings produced by leading,
  # trailing or doubled slashes.
  defp split_path_segments(path), do: String.split(path, "/", trim: true)

  # Rebuilds an absolute path from segments; no segments means the root path.
  defp join_path_segments([]), do: "/"
  defp join_path_segments(segments), do: Enum.map_join(segments, &("/" <> &1))

  # Length of the longest run of segments that both ends `base_segments` and
  # begins `path_segments`. Candidates are tried longest first; 0 always matches.
  defp overlapping_segments(base_segments, path_segments) do
    longest = min(length(base_segments), length(path_segments))

    Enum.find(longest..0//-1, 0, fn count ->
      base_tail = Enum.take(base_segments, -count)
      List.starts_with?(path_segments, base_tail)
    end)
  end

  @doc """
  Receive the next chunk from the stream.

  Blocks on the calling process's mailbox until a message belonging to this
  stream arrives or the time budget runs out. The budget is the time left
  until `stream.deadline` when one is set, otherwise `stream.timeout`.
  Messages that do not belong to this stream are consumed and dropped.

  Returns:
  - `{:ok, [chunk], stream}` — one or more chunks received
  - `{:halt, stream}` — stream ended
  - `{:error, reason}` — error occurred (`:timeout` when the budget is spent)
  """
  @spec next(t()) :: {:ok, [chunk_type()], t()} | {:halt, t()} | {:error, term()}
  def next(%__MODULE__{} = stream) do
    wait_ms =
      case stream.deadline do
        unset when unset in [nil, false] -> stream.timeout
        at -> max(at - System.monotonic_time(:millisecond), 0)
      end

    if wait_ms > 0 do
      receive do
        message ->
          case Req.parse_message(stream.ref, message) do
            # Not ours: drop it and keep waiting with a freshly computed budget.
            :unknown -> next(stream)
            {:ok, raw_chunks} -> parse_chunks(raw_chunks, stream)
            {:error, _reason} = error -> error
          end
      after
        wait_ms -> {:error, :timeout}
      end
    else
      {:error, :timeout}
    end
  end

  @doc """
  Collect all chunks into a final response, executing tool calls automatically.

  Drains the stream to completion. When the model requests tools and
  `:auto_tools` is enabled, the tools are run, their results appended to the
  conversation, and a follow-up stream is opened — repeatedly, until the model
  stops calling tools or `:max_rounds` is reached.

  ## Options

  - `:auto_tools` — auto-execute tool calls (default: `true`)
  - `:max_rounds` — max tool call rounds (default: `10`)
  - `:on_chunk` — callback for each chunk, `fn chunk -> ... end`
  - `:on_message` — callback fired once per completed `LLM.Message` (each assistant
    turn and each tool result, in order), `fn message -> ... end`
  """
  @spec collect(t(), keyword()) :: {:ok, LLM.Response.t()} | {:error, term()}
  def collect(stream, opts \\ []) do
    do_collect(
      stream,
      [],
      Keyword.get(opts, :auto_tools, true),
      Keyword.get(opts, :max_rounds, 10),
      Keyword.get(opts, :on_chunk),
      Keyword.get(opts, :on_message)
    )
  end

  # Round budget spent: stop reading, cancel the in-flight request and report
  # whatever was gathered with the `:max_rounds` stop reason.
  defp do_collect(stream, acc, _auto, max_rounds, _on_chunk, on_message)
       when stream.rounds >= max_rounds do
    cancel_stream(stream)

    finalize_chunks(
      acc,
      :max_rounds,
      stream.context.provider_state,
      session_messages(stream),
      on_message,
      stream.accumulated_usage
    )
  end

  # Pulls batches from `next/1`, reporting each chunk to `on_chunk` and
  # accumulating them in arrival order, until the stream ends or fails.
  # Errors from `next/1` are returned unchanged.
  defp do_collect(stream, acc, auto_tools, max_rounds, on_chunk, on_message) do
    case next(stream) do
      {:halt, halted} ->
        # Continue with the stream `next/1` handed back, not the one passed in:
        # it carries the final buffer and adapter state, so anything recorded
        # by the last batch (e.g. a `:response_id`) reaches the tool round.
        maybe_continue_with_tools(halted, acc, auto_tools, max_rounds, on_chunk, on_message)

      {:ok, chunks, stream} ->
        if on_chunk, do: Enum.each(chunks, fn chunk -> on_chunk.(chunk) end)
        new_acc = acc ++ chunks

        if stream.done do
          maybe_continue_with_tools(stream, new_acc, auto_tools, max_rounds, on_chunk, on_message)
        else
          do_collect(stream, new_acc, auto_tools, max_rounds, on_chunk, on_message)
        end

      {:error, _} = err ->
        err
    end
  end

  # True when any chunk is a tool call, whether or not it is complete.
  @doc false
  def has_tool_calls?(chunks) do
    Enum.any?(chunks, fn
      %LLM.Stream.ToolCall{} -> true
      _other -> false
    end)
  end

  # End-of-round decision: either run the requested tools and keep collecting
  # from the follow-up stream, or wrap the gathered chunks into the final
  # response. A failure to open the follow-up stream is returned unchanged.
  defp maybe_continue_with_tools(stream, chunks, auto_tools, max_rounds, on_chunk, on_message) do
    run_tools? = auto_tools and has_complete_tool_calls?(chunks)

    case run_tools? do
      true ->
        with {:ok, next_stream} <- execute_tool_calls(chunks, stream, on_message) do
          do_collect(next_stream, [], auto_tools, max_rounds, on_chunk, on_message)
        end

      false ->
        finalize_chunks(
          chunks,
          nil,
          stream.context.provider_state,
          session_messages(stream),
          on_message,
          stream.accumulated_usage
        )
    end
  end

  # The slice of the conversation that belongs to this collect session: the
  # prompt that kicked it off and every assistant/tool turn added since. Prior
  # history supplied through the `:messages` option is skipped by dropping the
  # first `base_message_count` entries.
  defp session_messages(%__MODULE__{} = stream) do
    Enum.drop(stream.context.messages, stream.base_message_count)
  end

  # True when at least one tool call has been received in full.
  defp has_complete_tool_calls?(chunks) do
    Enum.any?(chunks, fn
      %LLM.Stream.ToolCall{complete: true} -> true
      _other -> false
    end)
  end

  # Runs one tool round and opens the follow-up stream.
  #
  # Steps:
  #   1. execute every complete tool call of this round concurrently
  #   2. build the assistant message (text, thinking, tool calls, usage)
  #   3. report the assistant message and each tool result to `on_message`
  #   4. append them to the conversation and call `start/2` again
  #
  # Returns `{:ok, stream}` for the new round — with `rounds` incremented and
  # usage accumulated — or the error from `start/2`.
  defp execute_tool_calls(chunks, stream, on_message) do
    # Identical complete calls (same index, id, name and arguments) collapse
    # into one. The same list feeds both execution and the assistant message,
    # so every recorded call has exactly one matching result.
    tool_calls =
      chunks
      |> Enum.filter(&match?(%LLM.Stream.ToolCall{complete: true}, &1))
      |> Enum.uniq_by(fn tc -> {tc.index, tc.id, tc.name, tc.arguments} end)

    tools = Enum.map(stream.context.tools, &LLM.Tool.normalize/1)
    tool_context = %{messages: stream.context.messages}

    results =
      tool_calls
      |> Task.async_stream(
        fn tc -> run_tool_call(tc, tools, tool_context) end,
        ordered: true,
        timeout: Keyword.get(stream.opts, :tool_timeout, 30_000),
        # Without this a slow tool exits the calling process; with it the task
        # is killed and the stream emits `{:exit, :timeout}` for that call.
        on_timeout: :kill_task
      )
      |> Enum.zip_with(tool_calls, fn
        {:ok, result}, _tc ->
          result

        {:exit, reason}, tc ->
          %{id: tc.id, name: tc.name, content: "Error: #{inspect(reason)}"}
      end)

    text_content =
      chunks
      |> Enum.filter(&match?(%LLM.Stream.Chunk{}, &1))
      |> Enum.map(& &1.text)
      |> IO.iodata_to_binary()

    thinking_parts =
      chunks
      |> Enum.filter(&match?(%LLM.Stream.Thinking{}, &1))
      |> merge_thinking_parts()

    # Usage may arrive in a trailing Stop separate from the `finish_reason` Stop, so
    # merge across all Stop chunks rather than taking the first.
    round_usage = round_usage_from_chunks(chunks)

    assistant_message = %{
      build_message(thinking_parts, text_content, tool_calls)
      | usage: round_usage
    }

    # Accumulate usage across rounds
    new_accumulated_usage = LLM.Usage.add(stream.accumulated_usage, round_usage)

    tool_result_messages =
      Enum.map(results, fn r ->
        %LLM.Message{role: :tool, tool_call_id: r.id, name: r.name, content: r.content}
      end)

    # Emit each completed message in conversation order: the assistant turn first,
    # then each tool result, as `LLM.Message` structs with `role: :tool`.
    if on_message do
      on_message.(assistant_message)
      Enum.each(tool_result_messages, on_message)
    end

    new_messages =
      stream.context.messages ++ [assistant_message] ++ tool_result_messages

    # When the adapter recorded a response id, pass it on so the next request
    # can reference the previous response.
    new_provider_state =
      case Map.get(stream.adapter_state, :response_id) do
        id when is_binary(id) ->
          Map.put(stream.context.provider_state, "previous_response_id", id)

        _ ->
          stream.context.provider_state
      end

    new_context = %{stream.context | messages: new_messages, provider_state: new_provider_state}

    case start(new_context, stream.opts) do
      {:ok, new_stream} ->
        {:ok, %{new_stream | rounds: stream.rounds + 1, accumulated_usage: new_accumulated_usage}}

      {:error, _} = err ->
        err
    end
  end

  # Resolves the tool named by `tc` and produces its result map. An unknown
  # tool yields an error string as content rather than a failure, so the
  # model sees the problem in the tool result.
  defp run_tool_call(tc, tools, tool_context) do
    content =
      case Enum.find(tools, fn t -> t.name == tc.name end) do
        nil -> "Error: Unknown tool #{inspect(tc.name)}"
        tool -> invoke_tool(tool, tc.arguments, tool_context)
      end

    %{id: tc.id, name: tc.name, content: content}
  end

  # Calls the tool's `execute` function. Tools are user-supplied code, so any
  # exception (including an unexpected return shape) is turned into an error
  # string instead of crashing the round.
  defp invoke_tool(tool, arguments, tool_context) do
    case tool.execute.(arguments, tool_context) do
      {:ok, result} -> to_string(result)
      {:error, err} -> "Error: #{err}"
    end
  rescue
    e ->
      "Error: #{Exception.message(e)}"
  end

  # Folds the chunks of the last round into the final `LLM.Response`:
  #   * text deltas are concatenated into the assistant message content
  #   * complete tool calls and merged thinking are attached to that message
  #   * usage from this round is added to `accumulated_usage`
  #   * the stop reason is `forced_stop_reason` when given, else the reason of
  #     the first Stop chunk, else nil
  # `on_message`, when present, is called with the assistant message.
  @doc false
  def finalize_chunks(
        chunks,
        forced_stop_reason \\ nil,
        provider_state \\ %{},
        context_messages \\ [],
        on_message \\ nil,
        accumulated_usage \\ %LLM.Usage{}
      ) do
    text = IO.iodata_to_binary(for %LLM.Stream.Chunk{text: delta} <- chunks, do: delta)

    complete_calls = Enum.filter(chunks, &match?(%LLM.Stream.ToolCall{complete: true}, &1))

    thinking_parts =
      merge_thinking_parts(for %LLM.Stream.Thinking{} = part <- chunks, do: part)

    round_usage = round_usage_from_chunks(chunks)

    message = %{build_message(thinking_parts, text, complete_calls) | usage: round_usage}

    if on_message, do: on_message.(message)

    first_stop = Enum.find(chunks, &match?(%LLM.Stream.Stop{}, &1))

    response = %LLM.Response{
      message: message,
      messages: context_messages ++ [message],
      usage: LLM.Usage.add(accumulated_usage, round_usage),
      stop_reason: forced_stop_reason || (first_stop && first_stop.reason),
      provider_state: provider_state
    }

    {:ok, response}
  end

  # Decodes one batch of parsed `Req` messages into typed chunks.
  #
  # `raw_chunks` is the list produced by `Req.parse_message/2`:
  #   * `{:data, binary}` is appended to `stream.buffer`, split into complete
  #     SSE events and decoded through the adapter
  #   * `:done` marks the end of the response
  #   * any other entry is ignored
  #
  # Returns `{:ok, chunks, stream}` when chunks were decoded, `{:halt, stream}`
  # when the stream finished without producing any, and otherwise the result
  # of `next/1` (the batch held no complete event yet, so keep waiting).
  @doc false
  def parse_chunks(raw_chunks, stream) do
    # Accumulator: {unconsumed buffer, decoded chunks, done?, adapter state}
    {buffer, decoded, has_done, adapter_state} =
      Enum.reduce(raw_chunks, {stream.buffer, [], false, stream.adapter_state}, fn
        {:data, data}, {buf, acc, done, state} ->
          merged = buf <> data
          {remaining, events} = split_events(merged)

          # Decode events in order, threading the adapter state; the buffer is
          # replaced by the incomplete tail left over after splitting.
          Enum.reduce(events, {remaining, acc, done, state}, fn event,
                                                                {_r, a, d, current_state} ->
            {chunks, new_state, event_done} =
              decode_sse_event(event, stream.adapter, current_state)

            {remaining, a ++ chunks, d || event_done, new_state}
          end)

        :done, {buf, acc, _done, state} ->
          {buf, acc, true, state}

        _, acc ->
          acc
      end)

    # Once the stream is done no terminator will follow, so whatever is still
    # buffered is decoded as a final event.
    {buffer, decoded, adapter_state} =
      if has_done and buffer != "" do
        {chunks, state, _} = decode_sse_event(buffer, stream.adapter, adapter_state)
        {"", decoded ++ chunks, state}
      else
        {buffer, decoded, adapter_state}
      end

    stream = %{stream | buffer: buffer, done: has_done, adapter_state: adapter_state}

    if has_done and decoded == [] do
      {:halt, stream}
    else
      case decoded do
        # Nothing decodable yet (e.g. a partial event): wait for more data.
        [] -> next(stream)
        _ -> {:ok, decoded, stream}
      end
    end
  end

  # Splits buffered SSE data into complete events plus the unfinished tail.
  #
  # Returns `{remainder, events}`: `events` are the blank-line-terminated
  # events in order (empty ones removed) and `remainder` is the trailing data
  # that has not been terminated yet, to be prepended to the next read.
  #
  # The SSE format allows CRLF line endings as well as LF. CRLF pairs are
  # normalised to LF first; otherwise a server using CRLF never produces the
  # "\n\n" separator and no event is emitted until the connection closes.
  # A lone trailing "\r" stays in the remainder and is paired with its "\n"
  # on the next read.
  defp split_events(data) do
    normalized = String.replace(data, "\r\n", "\n")

    case String.split(normalized, "\n\n") do
      [incomplete] ->
        {incomplete, []}

      parts ->
        {remainder, complete} = List.pop_at(parts, -1)
        {remainder, Enum.reject(complete, &(&1 == ""))}
    end
  end

  # Decodes one SSE event into `{chunks, adapter_state, done?}`.
  #
  # The whole event is offered to the adapter first. Only when that yields
  # nothing AND the event spans several lines (e.g. `id:`/`event:` fields next
  # to `data:`) is each `data: ` line decoded on its own. A single-line event
  # is never retried: the first attempt may already have updated the adapter
  # state (e.g. streamed tool-call argument fragments), and decoding it again
  # would apply that update twice and corrupt the JSON.
  defp decode_sse_event(event, adapter, adapter_state) do
    payload = String.trim_trailing(event)
    whole = adapter |> decode_chunk(payload, adapter_state) |> decode_chunk_result()

    case whole do
      {[], state, false} ->
        if String.contains?(payload, "\n"),
          do: decode_data_lines(payload, adapter, state),
          else: whole

      decoded ->
        decoded
    end
  end

  # Decodes every `data: ` line of a multi-line event in order, threading the
  # adapter state from one line to the next.
  defp decode_data_lines(payload, adapter, state) do
    for line <- String.split(payload, "\n"),
        String.starts_with?(line, "data: "),
        reduce: {[], state, false} do
      {collected, current_state, done} ->
        {chunks, next_state, line_done} =
          adapter |> decode_chunk(line, current_state) |> decode_chunk_result()

        {collected ++ chunks, next_state, done || line_done}
    end
  end

  # Calls the adapter's decoder. Adapters exporting `decode_chunk/2` receive
  # the adapter state and return their own result; for `decode_chunk/1`
  # adapters the result is paired with the unchanged state. The module is
  # loaded first because `function_exported?/3` does not load it.
  defp decode_chunk(adapter, event, adapter_state) do
    Code.ensure_loaded(adapter)

    case function_exported?(adapter, :decode_chunk, 2) do
      true -> adapter.decode_chunk(event, adapter_state)
      false -> {adapter.decode_chunk(event), adapter_state}
    end
  end

  # Normalises whatever the adapter returned into `{chunks, state, done?}`.
  # Accepted shapes: `{:done, state}`, `{chunk_list, state}`, bare `:done` and
  # a bare chunk list (the bare forms carry an empty state). Anything else
  # counts as "nothing decoded", also with an empty state.
  defp decode_chunk_result({:done, state}), do: {[], state, true}

  defp decode_chunk_result({chunks, state}) when is_list(chunks),
    do: fold_decoded(chunks, state)

  defp decode_chunk_result(:done), do: {[], %{}, true}

  defp decode_chunk_result(chunks) when is_list(chunks), do: fold_decoded(chunks, %{})

  defp decode_chunk_result(_), do: {[], %{}, false}

  # Strips `:done` markers out of a chunk list while remembering that one was
  # seen; a terminal chunk also marks the result as done.
  defp fold_decoded(chunks, state) do
    {kept, done} =
      Enum.reduce(chunks, {[], false}, fn
        :done, {kept, _done} -> {kept, true}
        chunk, {kept, done} -> {[chunk | kept], done || terminal_chunk?(chunk)}
      end)

    {Enum.reverse(kept), state, done}
  end

  # Whether a chunk on its own ends the stream. A `Stop` does so only once it
  # carries usage: some OpenAI-compatible providers (e.g. OpenRouter) send the
  # `finish_reason` Stop first and the token counts in a separate trailing
  # chunk, so stopping at a usage-less Stop would miss them. If no usage ever
  # follows, the provider's `[DONE]` marker or the connection close still
  # ends the stream.
  defp terminal_chunk?(%LLM.Stream.Stop{usage: nil}), do: false
  defp terminal_chunk?(%LLM.Stream.Stop{}), do: true
  defp terminal_chunk?(_other), do: false

  # Sums the usage of every Stop chunk in a round that has one. More than one
  # Stop can appear because providers may send the `finish_reason` Stop
  # (without usage) separately from a trailing Stop carrying token counts.
  defp round_usage_from_chunks(chunks) do
    reported = for %LLM.Stream.Stop{usage: usage} <- chunks, not is_nil(usage), do: usage

    LLM.Usage.sum(reported)
  end

  # Request headers for a stream: the adapter's stream headers plus, when the
  # provider has an API key, authentication headers. Adapters exporting
  # `auth_headers/2` supply their own; all others get a Bearer token.
  @doc false
  def build_headers(provider, opts) do
    adapter = provider.adapter
    base_headers = adapter.stream_headers(opts)

    if is_nil(provider.api_key) do
      base_headers
    else
      # `function_exported?/3` does not load the module by itself.
      Code.ensure_loaded(adapter)

      case function_exported?(adapter, :auth_headers, 2) do
        true -> adapter.auth_headers(provider, opts) ++ base_headers
        false -> [{"authorization", "Bearer #{provider.api_key}"} | base_headers]
      end
    end
  end

  # Initial decoding state: the adapter's `init_stream_state/0` when it
  # exports one, otherwise an empty map.
  defp init_adapter_state(adapter) do
    Code.ensure_loaded(adapter)

    case function_exported?(adapter, :init_stream_state, 0) do
      true -> adapter.init_stream_state()
      false -> %{}
    end
  end

  # Cancels the in-flight request when the response body is still an async
  # handle; any other stream shape is a no-op. Always returns `:ok`.
  defp cancel_stream(%__MODULE__{
         ref: %Req.Response{body: %Req.Response.Async{cancel_fun: cancel, ref: request_ref}}
       }) do
    cancel.(request_ref)
    :ok
  end

  defp cancel_stream(_stream), do: :ok

  # Assembles the assistant `LLM.Message` for one turn. Empty text becomes
  # `nil` content, an empty tool-call list becomes `nil` tools, and only the
  # first thinking part is kept.
  defp build_message(thinking_parts, text, tool_calls) do
    content =
      case text do
        "" -> nil
        present -> present
      end

    %LLM.Message{
      role: :assistant,
      content: content,
      thinking: extract_thinking(thinking_parts),
      tools: message_tools(tool_calls)
    }
  end

  # Converts streamed tool calls into the map shape stored on a message.
  defp message_tools([]), do: nil

  defp message_tools(calls) do
    Enum.map(calls, fn %LLM.Stream.ToolCall{} = call ->
      %{id: call.id, name: call.name, args: call.arguments}
    end)
  end

  # Value of the first `{:thinking, value}` entry in the list, or nil when
  # there is none. Entries of any other shape are skipped.
  defp extract_thinking(parts) do
    case Enum.find(parts, &match?({:thinking, _}, &1)) do
      {:thinking, value} -> value
      nil -> nil
    end
  end

  # Collapses consecutive thinking deltas into `{:thinking, value}` parts, in
  # stream order. Unsigned deltas become plain strings; signed deltas become
  # `%{text: ..., signature: ...}` maps. A delta is appended to the previous
  # part only when it continues it (unsigned after unsigned, or a matching
  # signature); otherwise it starts a new part.
  defp merge_thinking_parts(thinking_chunks) do
    thinking_chunks
    |> Enum.reduce([], &absorb_thinking/2)
    |> Enum.reverse()
  end

  # The accumulator is newest-first, so its head is the part being extended.
  # Clause order matters: the two "extend" clauses must come before the two
  # "start a new part" clauses.
  defp absorb_thinking(
         %LLM.Stream.Thinking{text: text, signature: nil},
         [{:thinking, existing} | rest]
       )
       when is_binary(existing) do
    [{:thinking, existing <> text} | rest]
  end

  defp absorb_thinking(
         %LLM.Stream.Thinking{text: text, signature: signature},
         [{:thinking, %{text: existing, signature: same_signature}} | rest]
       )
       when signature == same_signature do
    [{:thinking, %{text: existing <> text, signature: signature}} | rest]
  end

  defp absorb_thinking(%LLM.Stream.Thinking{text: text, signature: nil}, parts) do
    [{:thinking, text} | parts]
  end

  defp absorb_thinking(%LLM.Stream.Thinking{text: text, signature: signature}, parts) do
    [{:thinking, %{text: text, signature: signature}} | parts]
  end
end