# lib/pixir/provider/connection.ex

defmodule Pixir.Provider.Connection do
  @moduledoc """
  Per-key WebSocket connection process for the Responses Provider.

  The process owns connection-local optimization state: socket, latest
  `previous_response_id`, prompt input prefix, keepalive timers, and temporary
  degraded/backoff state. It does not own Session History. If continuation state is
  missing or invalid, the caller can still replay from Pixir's Log over HTTP/SSE or a
  fresh WebSocket. The default idle window is deliberately agent-scale so ordinary
  pauses between Turns do not erase same-socket continuation evidence.
  """

  use GenServer

  alias Pixir.Provider.WebSocketClient
  alias Pixir.Tool

  @registry Pixir.Provider.ConnectionRegistry
  @default_timeout_ms 30_000
  @default_degraded_ms 5_000
  @default_idle_ms 30 * 60 * 1_000
  @default_keepalive_ms 25_000

  @doc """
  Builds the supervisor child specification for one connection process.

  `opts` must contain `:key`; the key becomes part of the child id so several
  connections can live under the same supervisor. The child is a `:transient`
  worker with a 5 second shutdown window.
  """
  def child_spec(opts) do
    connection_key = Keyword.fetch!(opts, :key)
    start_mfa = {__MODULE__, :start_link, [opts]}

    %{
      type: :worker,
      restart: :transient,
      shutdown: 5_000,
      start: start_mfa,
      id: {__MODULE__, connection_key}
    }
  end

  @doc """
  Starts a connection process registered under the `:key` found in `opts`.

  Raises `KeyError` when `:key` is missing.
  """
  def start_link(opts) do
    registered_name = opts |> Keyword.fetch!(:key) |> via()
    GenServer.start_link(__MODULE__, opts, name: registered_name)
  end

  @doc """
  Returns the `:via` tuple that addresses the connection process for `key`
  through the connection registry.
  """
  def via(key) do
    {:via, Registry, {@registry, key}}
  end

  @doc """
  Streams `http_request` through the connection process for `key`.

  The process is started on demand via the connection supervisor. `fun` is
  invoked with each stream event and the accumulator. The `GenServer.call/3`
  timeout is the `:timeout_ms` option (default #{@default_timeout_ms} ms) plus
  5 seconds of slack.

  Returns `{:ok, acc}` on success or `{:error, error, acc}`; when the process
  cannot be started the error is a `:websocket_start_failed` tool error.
  """
  @spec stream(term(), map(), acc, (term(), acc -> acc), keyword()) ::
          {:ok, acc} | {:error, map(), acc}
        when acc: term()
  def stream(key, http_request, acc, fun, opts \\ []) do
    case Pixir.Provider.ConnectionSupervisor.ensure_started(key) do
      {:error, reason} ->
        details = %{reason: inspect(reason), key: inspect(key)}

        start_error =
          Tool.error(
            :websocket_start_failed,
            "WebSocket connection process could not start.",
            details
          )

        {:error, start_error, acc}

      {:ok, pid} ->
        call_timeout_ms = Keyword.get(opts, :timeout_ms, @default_timeout_ms) + 5_000
        GenServer.call(pid, {:stream, http_request, acc, fun, opts}, call_timeout_ms)
    end
  end

  # Builds the initial, disconnected state: no socket, no continuation evidence,
  # no timers, not degraded. `:key` is required in `opts`.
  @impl true
  def init(opts) do
    connection_key = Keyword.fetch!(opts, :key)

    initial_state = %{
      key: connection_key,
      websocket_client: WebSocketClient,
      # Socket and the endpoint/headers it was opened with.
      socket: nil,
      initial_buffer: "",
      endpoint: nil,
      headers_fingerprint: nil,
      # Continuation evidence from the last successful stream.
      previous_response_id: nil,
      previous_input: nil,
      previous_model: nil,
      last_continuation_reset_reason: nil,
      # Failure backoff bookkeeping.
      degraded_until_ms: 0,
      failures: 0,
      # Timer references.
      idle_timer: nil,
      keepalive_timer: nil
    }

    {:ok, initial_state}
  end

  # Handles a `{:stream, ...}` request. While the connection is inside its degraded
  # window the call is rejected immediately with a `:websocket_degraded` error that
  # carries the remaining backoff; otherwise the request is streamed.
  @impl true
  def handle_call({:stream, http_request, acc, fun, opts}, _from, state) do
    now_ms = monotonic_ms()
    backoff_deadline_ms = state.degraded_until_ms

    # 0 is the "not degraded" sentinel and is checked explicitly rather than relying
    # on the comparison with the monotonic clock alone.
    if backoff_deadline_ms != 0 and backoff_deadline_ms > now_ms do
      degraded_error =
        Tool.error(:websocket_degraded, "WebSocket is temporarily degraded.", %{
          retry_after_ms: backoff_deadline_ms - now_ms,
          key: inspect(state.key)
        })

      {:reply, {:error, degraded_error, acc}, state}
    else
      reply_stream(http_request, acc, fun, opts, state)
    end
  end

  # Serves one stream request and returns the GenServer reply tuple.
  #
  # Steps: make sure a socket is open for the request's endpoint/headers, build the
  # full and wire payloads, stream the wire payload, and - if the provider reports
  # that the `previous_response_id` is unknown - reset continuation state and retry
  # once with the full payload using whatever is left of the timeout budget.
  #
  # Every failure path degrades the state that actually owns the open socket. In
  # particular, when the payload cannot be built after a fresh connect, the newly
  # connected state is degraded so its socket is closed instead of being leaked.
  defp reply_stream(http_request, acc, fun, opts, state) do
    client = Keyword.get(opts, :websocket_client, WebSocketClient)
    client_opts = stream_client_opts(opts)
    stream_started_ms = monotonic_ms()

    case ensure_connected(state, client, http_request, client_opts) do
      {:ok, connected} ->
        case build_payload(http_request, connected) do
          {:ok, full_payload, wire_payload, continuation} ->
            # `acc` here is the accumulator from before any streaming; it is what a
            # rejected continuation falls back to, because that outcome carries no
            # accumulator of its own.
            ctx = %{client: client, fun: fun, opts: opts, full_payload: full_payload, acc: acc}

            first_attempt =
              stream_wire_payload(
                connected,
                client,
                wire_payload,
                full_payload,
                continuation,
                acc,
                fun,
                client_opts,
                opts
              )

            case first_attempt do
              {:continuation_not_found, rejected} ->
                retry_client_opts =
                  client_opts_with_remaining_timeout(client_opts, stream_started_ms)

                rejected
                |> reset_continuation("previous_response_not_found")
                |> stream_wire_payload(
                  client,
                  full_payload,
                  full_payload,
                  continuation_metadata(false, "previous_response_not_found"),
                  acc,
                  fun,
                  retry_client_opts,
                  opts
                )
                |> reply_stream_result(ctx)

              outcome ->
                reply_stream_result(outcome, ctx)
            end

          {:error, error} ->
            # Degrade the connected state, not the pre-connect one, so a socket opened
            # by `ensure_connected/4` is closed.
            reply_stream_failure(connected, client, error, acc, opts)
        end

      {:error, error} ->
        reply_stream_failure(state, client, error, acc, opts)
    end
  end

  # Derives the options handed to the WebSocket client from the stream options:
  # `:websocket_client_opts` as the base, `:timeout_ms` filled in when absent, and
  # `:stream_activity` / `:stream_idle_timeout_ms` forwarded when provided.
  defp stream_client_opts(opts) do
    timeout_ms = Keyword.get(opts, :timeout_ms, @default_timeout_ms)

    base =
      opts
      |> Keyword.get(:websocket_client_opts, [])
      |> Keyword.put_new(:timeout_ms, timeout_ms)

    base =
      case Keyword.get(opts, :stream_activity) do
        activity when is_function(activity, 0) -> Keyword.put(base, :stream_activity, activity)
        _ -> base
      end

    case Keyword.get(opts, :stream_idle_timeout_ms) do
      nil -> base
      idle_ms -> Keyword.put(base, :stream_idle_timeout_ms, idle_ms)
    end
  end

  # Converts a `stream_wire_payload/9` outcome into the GenServer reply.
  #
  # Success records the continuation state and emits a final metadata event; an
  # error degrades the connection; a continuation rejection at this point means
  # the provider rejected the request even without a usable continuation, which
  # is reported as a `:provider_http_error`.
  defp reply_stream_result({:ok, acc, response, state}, ctx) do
    next_state = successful_state(state, ctx.full_payload, response, ctx.opts)

    metadata = %{
      "websocket_captured_response_id" => response_id_from_stream(response),
      "websocket_stored_previous_response_id" => next_state.previous_response_id
    }

    {:reply, {:ok, ctx.fun.({:metadata, metadata}, acc)}, next_state}
  end

  defp reply_stream_result({:error, error, acc, state}, ctx),
    do: reply_stream_failure(state, ctx.client, error, acc, ctx.opts)

  defp reply_stream_result({:continuation_not_found, state}, ctx) do
    error =
      Tool.error(
        :provider_http_error,
        "Provider rejected previous_response_id even after full replay retry.",
        %{reason: "previous_response_not_found"}
      )

    reply_stream_failure(state, ctx.client, error, ctx.acc, ctx.opts)
  end

  # Replies with the error and moves the connection into its degraded window.
  defp reply_stream_failure(state, client, error, acc, opts) do
    {:reply, {:error, error, acc}, mark_degraded(state, client, error, opts)}
  end

  # Sends `wire_payload` over the current socket and streams the response into `fun`.
  #
  # Before streaming, `fun` receives a transport metadata event (including whether
  # the wire payload carries `previous_response_id`) and a `{:status, 200}` event.
  #
  # Returns:
  #   * `{:ok, acc, response, state}` - stream finished; the initial buffer is cleared
  #   * `{:continuation_not_found, state}` - the wire payload used a
  #     `previous_response_id` and the accumulator holds a "previous response not
  #     found" stream error
  #   * `{:error, error, acc, state}` - the client reported an error; state unchanged
  defp stream_wire_payload(
         state,
         client,
         wire_payload,
         _full_payload,
         continuation,
         acc,
         fun,
         client_opts,
         _opts
       ) do
    continuing? = Map.has_key?(wire_payload, "previous_response_id")

    websocket_details =
      Map.merge(continuation, %{
        "used_previous_response_id" => continuing?,
        "websocket_key" => inspect(state.key),
        "websocket_stored_previous_response_id" => state.previous_response_id
      })

    announcement = transport_metadata("websocket", websocket_details)
    acc = fun.({:metadata, announcement}, acc)
    acc = fun.({:status, 200}, acc)

    stream_result =
      client.stream(state.socket, state.initial_buffer, wire_payload, acc, fun, client_opts)

    case stream_result do
      {:error, error, acc} ->
        {:error, error, acc, state}

      {:ok, acc, response} ->
        drained = %{state | initial_buffer: ""}

        cond do
          continuing? and previous_response_not_found?(acc) ->
            {:continuation_not_found, drained}

          true ->
            {:ok, acc, response, drained}
        end
    end
  end

  # Records the outcome of a successful stream: stores the continuation evidence
  # (response id, full input, model), clears degraded/failure bookkeeping, and
  # re-arms the keepalive and idle timers. When the response carries no id, the
  # previously stored id is kept.
  defp successful_state(state, full_payload, response, opts) do
    stored_response_id = response_id_from_stream(response) || state.previous_response_id

    refreshed = %{
      state
      | failures: 0,
        degraded_until_ms: 0,
        last_continuation_reset_reason: nil,
        previous_model: full_payload["model"],
        previous_input: full_payload["input"] || [],
        previous_response_id: stored_response_id,
        initial_buffer: ""
    }

    with_keepalive = schedule_keepalive(refreshed, opts)
    schedule_idle_close(with_keepalive, opts)
  end

  # Extracts the response id from a stream result map, accepting either the atom
  # key `:response_id` or the string key `"response_id"`. Returns `nil` for
  # non-map results or when neither key holds a truthy value.
  defp response_id_from_stream(response) do
    if is_map(response) do
      Map.get(response, :response_id) || Map.get(response, "response_id")
    else
      nil
    end
  end

  # Timer messages.
  #
  #   * `:idle_close` - closes the socket and stops the process normally.
  #   * `:keepalive_ping` - pings the socket. A successful ping re-arms the
  #     keepalive timer using default options; a failed ping closes the socket and
  #     drops connection and continuation state (reason "keepalive_failed"). With
  #     no socket the message only clears the timer reference.
  @impl true
  def handle_info(:idle_close, state) do
    close_socket(state)

    stopped = %{state | socket: nil, idle_timer: nil, keepalive_timer: nil}
    {:stop, :normal, stopped}
  end

  def handle_info(:keepalive_ping, %{socket: nil} = state),
    do: {:noreply, %{state | keepalive_timer: nil}}

  def handle_info(:keepalive_ping, %{websocket_client: client, socket: socket} = state) do
    case client.ping(socket) do
      {:error, _reason} ->
        close_socket(state)

        disconnected = %{
          socket: nil,
          initial_buffer: "",
          endpoint: nil,
          headers_fingerprint: nil,
          keepalive_timer: nil
        }

        next_state =
          state
          |> reset_continuation("keepalive_failed")
          |> Map.merge(disconnected)

        {:noreply, next_state}

      :ok ->
        rearmed = schedule_keepalive(%{state | keepalive_timer: nil}, [])
        {:noreply, rearmed}
    end
  end

  # Closes the socket (if any) when the process shuts down.
  @impl true
  def terminate(_reason, state) do
    :ok = close_socket(state)
  end

  # Reuses the current socket when it was opened for the same WebSocket endpoint and
  # the same header fingerprint as `http_request`; otherwise reconnects.
  # Returns `{:ok, state}` or `{:error, error}`.
  defp ensure_connected(state, client, http_request, opts) do
    target_endpoint = websocket_endpoint(http_request.url)
    target_fingerprint = stable_fingerprint(http_request.headers)

    case connected_to?(state, target_endpoint, target_fingerprint) do
      true ->
        {:ok, state}

      false ->
        reconnect(
          state,
          client,
          target_endpoint,
          http_request.headers,
          target_fingerprint,
          opts
        )
    end
  end

  # True when a socket is open and it was opened for exactly this endpoint and
  # header fingerprint.
  defp connected_to?(state, endpoint, headers_fingerprint) do
    socket_open? = not is_nil(state.socket)

    socket_open? and
      state.endpoint == endpoint and
      state.headers_fingerprint == headers_fingerprint
  end

  # Replaces the current socket with a fresh connection to `endpoint`.
  #
  # Pending idle/keepalive timers are cancelled and the old socket is closed before
  # connecting. On success the new socket, its initial buffer, endpoint, header
  # fingerprint and client module are stored, and any continuation state is reset
  # with a reason describing why the reconnect happened. On failure the client's
  # `{:error, error}` is returned and the locally updated state is discarded.
  defp reconnect(state, client, endpoint, headers, headers_fingerprint, opts) do
    quiesced = cancel_keepalive_timer(cancel_idle_timer(state))
    close_socket(quiesced)

    case client.connect(endpoint, headers, opts) do
      {:error, error} ->
        {:error, error}

      {:ok, socket, initial_buffer, _handshake} ->
        reset_reason = reconnect_reset_reason(quiesced, endpoint, headers_fingerprint)

        fresh_connection = %{
          websocket_client: client,
          socket: socket,
          initial_buffer: initial_buffer,
          endpoint: endpoint,
          headers_fingerprint: headers_fingerprint,
          idle_timer: nil,
          keepalive_timer: nil
        }

        connected =
          quiesced
          |> reset_continuation(reset_reason)
          |> Map.merge(fresh_connection)

        {:ok, connected}
    end
  end

  # Decodes the HTTP request body and derives the WebSocket payloads from it.
  #
  # The full payload is the decoded JSON object without "stream", tagged as a
  # "response.create" event with "store" set to false. The wire payload is either
  # the full payload or a continuation delta, as decided by `maybe_continue/2`.
  #
  # Returns `{:ok, full, wire, continuation}` or an `:invalid_provider_request`
  # error when the body does not decode to a JSON object.
  defp build_payload(%{body: body}, state) do
    case body |> IO.iodata_to_binary() |> Jason.decode() do
      {:ok, decoded} when is_map(decoded) ->
        full =
          decoded
          |> Map.delete("stream")
          |> Map.merge(%{"type" => "response.create", "store" => false})

        {wire, continuation} = maybe_continue(full, state)
        {:ok, full, wire, continuation}

      _undecodable ->
        invalid =
          Tool.error(:invalid_provider_request, "Could not decode Provider request body.", %{})

        {:error, invalid}
    end
  end

  # Chooses the payload that actually goes on the wire and describes that choice.
  #
  # Result: `{wire_payload, continuation_metadata}`. `"continuation_attempted"` is
  # true only when the wire payload really carries `previous_response_id` together
  # with a delta input. Whenever the full payload is sent instead, the cause is
  # recorded under `"continuation_reset_reason"` (string-keyed, ADR 0019):
  # "model_changed", "prefix_mismatch", "empty_delta", the last stored reset
  # reason, or "no_previous_response".
  defp maybe_continue(
         full,
         %{previous_response_id: id, previous_input: previous, previous_model: model}
       )
       when is_binary(id) and is_list(previous) do
    current_input = full["input"] || []

    cond do
      full["model"] != model ->
        {full, continuation_metadata(false, "model_changed")}

      true ->
        case suffix_after_prefix(current_input, previous) do
          :error ->
            {full, continuation_metadata(false, "prefix_mismatch")}

          {:ok, suffix} ->
            case continuation_delta(suffix) do
              [] ->
                {full, continuation_metadata(false, "empty_delta")}

              delta ->
                wire = Map.merge(full, %{"input" => delta, "previous_response_id" => id})
                {wire, continuation_metadata(true, nil)}
            end
        end
    end
  end

  defp maybe_continue(full, %{last_continuation_reset_reason: reason}) when is_binary(reason) do
    {full, continuation_metadata(false, reason)}
  end

  defp maybe_continue(full, _state) do
    {full, continuation_metadata(false, "no_previous_response")}
  end

  # Builds the string-keyed continuation evidence map attached to transport
  # metadata: whether a continuation was attempted and, if not, why.
  defp continuation_metadata(attempted?, reset_reason) do
    Map.new([
      {"continuation_attempted", attempted?},
      {"continuation_reset_reason", reset_reason}
    ])
  end

  # True when the accumulator holds a stream error saying the provider does not
  # know the `previous_response_id` that was sent.
  #
  # The error's code/type/message and its nested details' code/type are inspected
  # under both atom and string keys; any binary value matching "previous response
  # ... not found" (case-insensitive, with space/underscore/hyphen separators)
  # counts. All lookups go through `Map.get/2` so an error that is a struct (which
  # passes the `is_map/1` guard but does not implement `Access`) is inspected
  # instead of raising.
  defp previous_response_not_found?(%{stream_error: %{error: error}}) when is_map(error) do
    details =
      case Map.get(error, :details) || Map.get(error, "details") do
        %{} = details -> details
        _ -> %{}
      end

    candidates =
      for {source, keys} <- [{error, [:code, :type, :message]}, {details, [:code, :type]}],
          key <- keys,
          value <- [Map.get(source, key), Map.get(source, Atom.to_string(key))],
          is_binary(value),
          do: value

    Enum.any?(candidates, &(&1 =~ ~r/previous[ _-]?response.*not[ _-]?found/i))
  end

  defp previous_response_not_found?(_acc), do: false

  # Shrinks `:timeout_ms` in `client_opts` by the time elapsed since `started_ms`
  # (monotonic milliseconds), never going below 1 ms, so a retry stays inside the
  # original timeout budget.
  defp client_opts_with_remaining_timeout(client_opts, started_ms) do
    budget_ms = Keyword.get(client_opts, :timeout_ms, @default_timeout_ms)
    spent_ms = max(0, monotonic_ms() - started_ms)

    Keyword.put(client_opts, :timeout_ms, max(1, budget_ms - spent_ms))
  end

  # Returns `{:ok, suffix}` when `current` begins with every item of `previous`
  # (compared with `==`), where `suffix` is what follows that prefix. Returns
  # `:error` when `current` is shorter than `previous` or the prefix differs.
  defp suffix_after_prefix(current, previous) when length(current) >= length(previous) do
    prefix_length = length(previous)

    case Enum.take(current, prefix_length) == previous do
      true -> {:ok, Enum.drop(current, prefix_length)}
      false -> :error
    end
  end

  defp suffix_after_prefix(_current, _previous) do
    :error
  end

  # Keeps only the items of `suffix` that should be sent as continuation input:
  # function calls, reasoning items and assistant messages are dropped; everything
  # else is kept in order.
  defp continuation_delta(suffix) do
    Enum.filter(suffix, fn
      %{"type" => "message", "role" => "assistant"} -> false
      %{"type" => "reasoning"} -> false
      %{"type" => "function_call"} -> false
      _kept -> true
    end)
  end

  # Moves the connection into its degraded window after a failure.
  #
  # The socket is closed with `client`, both timers are cancelled, connection and
  # continuation state are dropped (reset reason "websocket_failed"), the failure
  # counter is bumped, and `degraded_until_ms` is set `:websocket_degraded_ms`
  # (default @default_degraded_ms) into the future on the monotonic clock.
  defp mark_degraded(state, client, error, opts) do
    close_socket(%{state | websocket_client: client})

    degraded_ms = Keyword.get(opts, :websocket_degraded_ms, @default_degraded_ms)

    # Cancel the timers while their references are still in the state. Clearing
    # `keepalive_timer` first would leave the real timer running and deliver a
    # stale `:keepalive_ping` to whatever socket is connected later.
    quiesced =
      state
      |> cancel_idle_timer()
      |> cancel_keepalive_timer()

    %{
      quiesced
      | socket: nil,
        initial_buffer: "",
        endpoint: nil,
        headers_fingerprint: nil,
        degraded_until_ms: monotonic_ms() + degraded_ms,
        failures: state.failures + 1
    }
    |> reset_continuation("websocket_failed")
    |> maybe_keep_previous_input(error)
  end

  # Explains why a reconnect invalidates stored continuation state. Returns `nil`
  # when there is no stored `previous_response_id` (nothing to reset); otherwise
  # "endpoint_changed", "headers_changed", or "websocket_reconnected" when neither
  # the endpoint nor the header fingerprint differs from the stored ones.
  defp reconnect_reset_reason(state, endpoint, headers_fingerprint) do
    case state do
      %{previous_response_id: nil} ->
        nil

      _continuing ->
        if state.endpoint != endpoint do
          "endpoint_changed"
        else
          if state.headers_fingerprint != headers_fingerprint,
            do: "headers_changed",
            else: "websocket_reconnected"
        end
    end
  end

  # Forgets the stored continuation evidence (response id, input prefix, model) and
  # remembers `reason` as the last reset reason. A `nil` reason means "nothing to
  # reset" and returns the state untouched.
  defp reset_continuation(state, reason) do
    case reason do
      nil ->
        state

      _ ->
        %{
          state
          | last_continuation_reset_reason: reason,
            previous_model: nil,
            previous_input: nil,
            previous_response_id: nil
        }
    end
  end

  # Re-arms the idle-close timer. Any existing idle timer is cancelled first; a new
  # `:idle_close` message is scheduled only when `:websocket_idle_ms` (default
  # @default_idle_ms) is a positive integer, otherwise no timer is left armed.
  defp schedule_idle_close(state, opts) do
    state = cancel_idle_timer(state)

    case Keyword.get(opts, :websocket_idle_ms, @default_idle_ms) do
      idle_ms when is_integer(idle_ms) and idle_ms > 0 ->
        timer_ref = Process.send_after(self(), :idle_close, idle_ms)
        %{state | idle_timer: timer_ref}

      _disabled ->
        state
    end
  end

  # Cancels the pending idle-close timer, if any, and clears its reference.
  #
  # `Process.cancel_timer/1` returns `false` when the timer has already fired. In
  # that case the `:idle_close` message may already sit in this process's mailbox
  # (for example because the timer expired while a long stream call was being
  # served), so it is flushed here. Without the flush a stale `:idle_close` would
  # stop the process right after the timer was supposedly cancelled and re-armed.
  defp cancel_idle_timer(%{idle_timer: nil} = state), do: state

  defp cancel_idle_timer(%{idle_timer: ref} = state) do
    if Process.cancel_timer(ref) == false do
      receive do
        :idle_close -> :ok
      after
        0 -> :ok
      end
    end

    %{state | idle_timer: nil}
  end

  # Re-arms the keepalive timer. Any existing keepalive timer is cancelled first; a
  # new `:keepalive_ping` message is scheduled only when `:websocket_keepalive_ms`
  # (default @default_keepalive_ms) is a positive integer.
  defp schedule_keepalive(state, opts) do
    state = cancel_keepalive_timer(state)

    case Keyword.get(opts, :websocket_keepalive_ms, @default_keepalive_ms) do
      keepalive_ms when is_integer(keepalive_ms) and keepalive_ms > 0 ->
        timer_ref = Process.send_after(self(), :keepalive_ping, keepalive_ms)
        %{state | keepalive_timer: timer_ref}

      _disabled ->
        state
    end
  end

  # Cancels the pending keepalive timer, if any, and clears its reference.
  #
  # `Process.cancel_timer/1` returns `false` when the timer has already fired; the
  # `:keepalive_ping` message may then already be queued in this process's mailbox,
  # so it is flushed to keep a stale ping from being handled after the timer was
  # cancelled or replaced.
  defp cancel_keepalive_timer(%{keepalive_timer: nil} = state), do: state

  defp cancel_keepalive_timer(%{keepalive_timer: ref} = state) do
    if Process.cancel_timer(ref) == false do
      receive do
        :keepalive_ping -> :ok
      after
        0 -> :ok
      end
    end

    %{state | keepalive_timer: nil}
  end

  # Currently a pass-through: the state is returned unchanged for every error.
  defp maybe_keep_previous_input(state, _error) do
    state
  end

  # Closes the state's socket through its WebSocket client module. The client's
  # return value is ignored; the result is always `:ok`. A state without a socket
  # is a no-op.
  defp close_socket(%{websocket_client: client, socket: socket}) when not is_nil(socket) do
    client.close(socket)
    :ok
  end

  defp close_socket(%{socket: nil}) do
    :ok
  end

  # Converts the provider's HTTP URL into the matching WebSocket URL: an "http"
  # scheme becomes "ws", every other scheme becomes "wss". The rest of the URL is
  # kept as parsed.
  defp websocket_endpoint(http_url) do
    parsed = URI.parse(http_url)

    ws_scheme =
      case parsed.scheme do
        "http" -> "ws"
        _ -> "wss"
      end

    URI.to_string(%URI{parsed | scheme: ws_scheme})
  end

  # Builds a transport metadata map: the preference is always "websocket", the
  # active transport is as given, and the entries of `extra` are merged on top
  # (overriding the defaults on key clashes).
  defp transport_metadata(active_transport, extra) do
    defaults = %{
      "active_transport" => active_transport,
      "transport_preference" => "websocket"
    }

    Map.merge(defaults, extra)
  end

  # Current monotonic time in milliseconds; used for backoff and elapsed-time math.
  defp monotonic_ms do
    System.monotonic_time(:millisecond)
  end

  # Computes an order-independent fingerprint of the request headers: each header
  # is normalized, the list is sorted, serialized with `:erlang.term_to_binary/1`,
  # hashed with SHA-256 and rendered as lowercase hex.
  defp stable_fingerprint(headers) do
    canonical =
      headers
      |> Enum.map(fn header -> normalize_header_fingerprint(header) end)
      |> Enum.sort()
      |> :erlang.term_to_binary()

    digest = :crypto.hash(:sha256, canonical)
    Base.encode16(digest, case: :lower)
  end

  # Normalizes one `{name, value}` header for fingerprinting: the name is
  # stringified and lowercased, the value stringified.
  #
  # OAuth access tokens rotate between Turns; fingerprinting the bearer value would
  # force a reconnect and discard connection-local continuation state. The
  # authorization value is therefore replaced by a fixed placeholder. Header names
  # are case-insensitive, so the check runs on the lowercased name and covers any
  # spelling ("authorization", "Authorization", "AUTHORIZATION", atom names).
  defp normalize_header_fingerprint({name, value}) do
    case String.downcase(to_string(name)) do
      "authorization" -> {"authorization", "<bearer>"}
      normalized_name -> {normalized_name, to_string(value)}
    end
  end
end