Skip to main content

lib/pi/llm/broker.ex

defmodule Pi.LLM.Broker do
  @moduledoc """
  Multiplexes BEAM-initiated LLM requests over the active pi transport.

  The broker is a singleton `GenServer` registered under its module name. Its
  state tracks two kinds of in-flight work:

    * `pending` - one-shot requests, keyed by request id, each holding the
      blocked caller (`from`) and the timer that fails the request on timeout;
    * `streams` - streaming requests, keyed by stream id, each holding the
      owner pid that receives stream events and the monitor reference used to
      clean the entry up if that owner dies.

  Results come back through `deliver/2` and `deliver_stream/3`, which route the
  payload to the waiting caller or the stream owner.
  """

  use GenServer

  alias Pi.LLM.Stream, as: LLMStream
  alias Pi.Protocol.LLM.Cancel
  alias Pi.Protocol.Response
  alias Pi.Transport.Stdio

  # Default request / per-event stream timeout, in milliseconds.
  @timeout 60_000

  @stream_tags [:pi_llm_chunk, :pi_llm_done, :pi_llm_error]

  @doc "Starts the broker, registered under the module name."
  def start_link(opts \\ []), do: GenServer.start_link(__MODULE__, opts, name: __MODULE__)

  @doc """
  Ensures the broker is running.

  Returns `:ok` when the broker is already registered (including when another
  process wins the start-up race); otherwise returns the result of
  `start_link/1`.
  """
  def install do
    case Process.whereis(__MODULE__) do
      nil ->
        # NOTE(review): this links the broker to whichever process happens to
        # call first; consider starting it under a supervisor instead.
        case start_link([]) do
          # Another process registered the name between whereis and start_link.
          {:error, {:already_started, _pid}} -> :ok
          other -> other
        end

      _pid ->
        :ok
    end
  end

  @doc """
  Sends an `:llm_complete` request and blocks until its result is delivered.

  `opts` is forwarded in the payload as a map. The `:timeout` option
  (milliseconds, default `#{@timeout}`) bounds the wait; on expiry the call
  returns `{:error, "Pi LLM request timed out"}`.
  """
  def complete(messages, opts \\ []) do
    request(
      :llm_complete,
      %{messages: messages, opts: Map.new(opts)},
      Keyword.get(opts, :timeout, @timeout)
    )
  end

  @doc """
  Sends the request `op` with `payload` and blocks until a result is delivered.

  Returns the delivered result after it has been passed through
  `Pi.Protocol.Response.to_result/1`, or `{:error, "Pi LLM request timed out"}`
  when nothing arrives within `timeout` milliseconds.
  """
  def request(op, payload, timeout \\ @timeout) when is_atom(op) and is_map(payload) do
    install()
    # The call deadline is one second later than the broker's own timer, so the
    # broker normally gets to reply with the timeout error before the call exits.
    GenServer.call(__MODULE__, {:request, op, payload, timeout}, timeout + 1_000)
  end

  @doc """
  Starts an `:llm_stream` request and returns a `Pi.LLM.Stream` struct.

  The struct's `stream` field is a lazy enumerable that yields each chunk delta
  followed by the final result. It raises `RuntimeError` on an error event and
  halts if no event arrives within the `:timeout` option (milliseconds, default
  `#{@timeout}`). Events are sent to the process that called this function.
  """
  def stream(messages, opts \\ []) do
    request_stream(:llm_stream, %{messages: messages, opts: Map.new(opts)}, opts)
  end

  @doc "Hands the result for the one-shot request `id` to the broker."
  def deliver(id, result) when is_binary(id) do
    install()
    GenServer.cast(__MODULE__, {:deliver, id, result})
  end

  @doc "Hands a stream `event` (`:chunk`, `:done` or `:error`) for stream `id` to the broker."
  def deliver_stream(id, event, payload)
      when is_binary(id) and event in [:chunk, :done, :error] do
    install()
    GenServer.cast(__MODULE__, {:deliver_stream, id, event, payload})
  end

  @impl true
  def init(_opts), do: {:ok, %{next_id: 0, pending: %{}, streams: %{}}}

  @impl true
  def handle_call({:request, op, payload, timeout}, from, state) do
    next_id = state.next_id + 1
    id = request_id(next_id)
    timer = Process.send_after(self(), {:timeout, id}, timeout)
    Stdio.emit_request(id, op, payload)

    # The reply is deferred until deliver/2 or the timeout timer fires.
    pending = Map.put(state.pending, id, %{from: from, timer: timer})
    {:noreply, %{state | next_id: next_id, pending: pending}}
  end

  def handle_call({:register_stream, id, owner}, _from, state) do
    # Monitor the owner so an abandoned stream does not stay registered forever.
    entry = %{owner: owner, monitor: Process.monitor(owner)}
    {:reply, :ok, %{state | streams: Map.put(state.streams, id, entry)}}
  end

  def handle_call({:unregister_stream, id}, _from, state) do
    {:reply, :ok, drop_stream(state, id)}
  end

  @impl true
  def handle_cast({:deliver, id, result}, state) do
    {:noreply, reply(state, id, Response.to_result(result))}
  end

  def handle_cast({:deliver_stream, id, event, payload}, state) do
    case Map.get(state.streams, id) do
      nil ->
        # Unknown or already closed stream: drop the event.
        {:noreply, state}

      %{owner: owner} ->
        send_stream(owner, id, event, payload)

        # :done and :error are terminal, so the registration ends with them.
        state = if event in [:done, :error], do: drop_stream(state, id), else: state
        {:noreply, state}
    end
  end

  @impl true
  def handle_info({:timeout, id}, state) do
    {:noreply, reply(state, id, {:error, "Pi LLM request timed out"})}
  end

  def handle_info({:DOWN, monitor, :process, _pid, _reason}, state) do
    # A :DOWN only reaches us for streams that were never closed (closed ones
    # are demonitored), so the owner died while the stream was still open.
    case Enum.find(state.streams, fn {_id, entry} -> entry.monitor == monitor end) do
      nil ->
        {:noreply, state}

      {id, _entry} ->
        Stdio.emit(%Cancel{type: :llm_cancel, id: id, reason: "closed"})
        {:noreply, %{state | streams: Map.delete(state.streams, id)}}
    end
  end

  # Defining handle_info/2 replaces the default from `use GenServer`, so without
  # this clause any unexpected message would crash the broker.
  def handle_info(_message, state), do: {:noreply, state}

  # Replies to the caller waiting on `id`, if it is still pending, and cancels
  # its timeout timer. Late or duplicate deliveries are ignored.
  defp reply(state, id, result) do
    case Map.pop(state.pending, id) do
      {nil, _pending} ->
        state

      {%{from: from, timer: timer}, pending} ->
        Process.cancel_timer(timer)
        GenServer.reply(from, result)
        %{state | pending: pending}
    end
  end

  # Removes the stream registration for `id` (if any) and stops monitoring its owner.
  defp drop_stream(state, id) do
    case Map.pop(state.streams, id) do
      {nil, _streams} ->
        state

      {%{monitor: monitor}, streams} ->
        Process.demonitor(monitor, [:flush])
        %{state | streams: streams}
    end
  end

  # Registers the calling process as stream owner, emits the request and wraps
  # the incoming stream messages in a lazy enumerable.
  defp request_stream(op, payload, opts) do
    install()
    timeout = Keyword.get(opts, :timeout, @timeout)
    id = request_id(System.unique_integer([:positive]))
    GenServer.call(__MODULE__, {:register_stream, id, self()})
    Stdio.emit_request(id, op, payload)

    stream =
      Elixir.Stream.resource(
        fn -> id end,
        fn
          :done -> {:halt, :done}
          stream_id -> await_stream_event(stream_id, timeout)
        end,
        &close_stream/1
      )

    %LLMStream{id: id, stream: stream}
  end

  # Waits for the next event of `stream_id`; halts with a timeout marker when
  # nothing arrives within `timeout` milliseconds.
  defp await_stream_event(stream_id, timeout) do
    receive do
      {:pi_llm_chunk, ^stream_id, delta} -> {[delta], stream_id}
      {:pi_llm_done, ^stream_id, result} -> {[result], :done}
      {:pi_llm_error, ^stream_id, error} -> raise RuntimeError, message: inspect(error)
    after
      timeout -> {:halt, {:timed_out, stream_id}}
    end
  end

  # Stream.resource/3 cleanup. A finished stream needs nothing; otherwise exactly
  # one cancel is emitted, carrying the reason the stream ended early.
  defp close_stream(:done), do: :ok
  defp close_stream({:timed_out, stream_id}), do: cancel_stream(stream_id, "timeout")
  defp close_stream(stream_id) when is_binary(stream_id), do: cancel_stream(stream_id, "closed")

  defp cancel_stream(stream_id, reason) do
    GenServer.call(__MODULE__, {:unregister_stream, stream_id})
    Stdio.emit(%Cancel{type: :llm_cancel, id: stream_id, reason: reason})
    # The broker forwards nothing for this id once the unregister call has
    # returned, so whatever is still queued here can be discarded safely.
    flush_stream_messages(stream_id)
  end

  # Discards stream messages for `stream_id` already sitting in the caller's mailbox.
  defp flush_stream_messages(stream_id) do
    receive do
      {tag, ^stream_id, _payload} when tag in @stream_tags ->
        flush_stream_messages(stream_id)
    after
      0 -> :ok
    end
  end

  defp send_stream(owner, id, :chunk, delta), do: send(owner, {:pi_llm_chunk, id, delta})
  defp send_stream(owner, id, :done, result), do: send(owner, {:pi_llm_done, id, result})
  defp send_stream(owner, id, :error, error), do: send(owner, {:pi_llm_error, id, error})

  # Builds a request id from a VM-unique integer plus the given suffix.
  defp request_id(next_id), do: "llm_#{System.unique_integer([:positive])}_#{next_id}"
end