lib/voxcpmex/server.ex

defmodule VoxCPMEx.Server do
  @moduledoc """
  GenServer that manages a Python VoxCPM2 bridge via Erlang Port.

  ## Protocol (v2.1 — MessagePack binary framing, single-ref streaming)

  Frame format: `[4-byte BE total_length][msgpack-encoded payload]`

  Audio is raw WAV bytes inside msgpack — no base64.

  Streaming: unlike v2.0 which used a dual stream_id/ref mapping, v2.1
  eliminates stream IDs entirely. The Elixir ref (sent in the request) is
  the sole identifier — Python just echoes stream_start/chunk/end in order.
  """

  use GenServer

  require Logger

  @type model_option ::
          {:model, String.t()}
          | {:device, String.t()}
          | {:load_denoiser, boolean()}
          | {:optimize, boolean()}
          | {:name, atom()}
  @type start_opts :: [model_option()]

  @type generate_opt ::
          {:audio_prompt, String.t()}
          | {:prompt_wav_path, String.t()}
          | {:prompt_text, String.t()}
          | {:cfg_value, float()}
          | {:inference_timesteps, pos_integer()}
          | {:min_len, pos_integer()}
          | {:max_len, pos_integer()}
          | {:normalize, boolean()}
          | {:denoise, boolean()}

  @default_model "openbmb/VoxCPM2"
  @default_device "cuda"
  @frame_header_bytes 4
  @stream_ttl_ms 60_000

  # =========================================================================
  # Client API
  # =========================================================================

  @spec start_link(start_opts()) :: GenServer.on_start()
  def start_link(opts \\ []) do
    name = Keyword.get(opts, :name)
    GenServer.start_link(__MODULE__, opts, name: name)
  end

  @doc "Returns runtime model info: device, sample_rate, status."
  @spec info(GenServer.server()) :: map()
  def info(server) do
    GenServer.call(server, :info)
  end

  @doc """
  Waits for the model to finish loading.

  Returns `:ok` when ready, `{:error, :loading}` if still initializing,
  or `{:error, reason}` if initialization failed.
  """
  @spec await_ready(GenServer.server(), timeout()) :: :ok | {:error, term()}
  def await_ready(server, timeout \\ 120_000) do
    GenServer.call(server, :await_ready, timeout)
  end

  @doc "Gracefully stops the GenServer and the Python bridge process."
  @spec stop(GenServer.server()) :: :ok
  def stop(server) do
    GenServer.stop(server, :normal)
  end

  @spec generate(GenServer.server(), String.t(), [generate_opt()]) :: {:ok, binary()} | {:error, term()}
  def generate(server, text, opts \\ []) do
    generate(server, text, opts, 120_000)
  end

  @spec generate(GenServer.server(), String.t(), [generate_opt()], timeout()) :: {:ok, binary()} | {:error, term()}
  def generate(server, text, opts, timeout) do
    GenServer.call(server, {:generate, text, opts}, timeout)
  end

  @doc """
  Starts async streaming. Returns `{:ok, ref}` immediately.

  Poll with `next_chunk/2`: `{:ok, chunk}` (raw float32 PCM) | `:eos` | `{:error, reason}`.
  Collect everything with `collect_stream/2`: `{:ok, wav_binary}`.
  """
  @spec generate_streaming_async(GenServer.server(), String.t(), [generate_opt()]) :: {:ok, reference()} | {:error, term()}
  def generate_streaming_async(server, text, opts \\ []) do
    GenServer.call(server, {:generate_streaming_async, text, opts}, 10_000)
  end

  @spec next_chunk(GenServer.server(), reference()) :: {:ok, binary()} | :eos | {:error, term()}
  def next_chunk(server, ref) do
    GenServer.call(server, {:next_chunk, ref}, 60_000)
  end

  @spec collect_stream(GenServer.server(), reference()) :: {:ok, binary()} | {:error, term()}
  def collect_stream(server, ref) do
    GenServer.call(server, {:collect_stream, ref}, 120_000)
  end

  @spec save(binary(), Path.t()) :: :ok | {:error, term()}
  def save(audio, path) when is_binary(audio) do
    File.write(path, audio)
  end

  @spec load_lora(GenServer.server(), String.t()) :: {:ok, non_neg_integer(), non_neg_integer()} | {:error, term()}
  def load_lora(server, lora_path) do
    GenServer.call(server, {:load_lora, lora_path}, 30_000)
  end

  @spec unload_lora(GenServer.server()) :: :ok | {:error, term()}
  def unload_lora(server) do
    GenServer.call(server, :unload_lora, 15_000)
  end

  # =========================================================================
  # GenServer Callbacks
  # =========================================================================

  @impl true
  def init(opts) do
    model = Keyword.get(opts, :model, @default_model)
    device = Keyword.get(opts, :device, @default_device)
    load_denoiser = Keyword.get(opts, :load_denoiser, false)
    optimize = Keyword.get(opts, :optimize, true)

    bridge_path = Path.join(:code.priv_dir(:voxcpmex), "python/voxcpmex_bridge.py")

    unless File.exists?(bridge_path) do
      {:stop, "Python bridge not found: #{bridge_path}"}
    else
      python_cmd = System.find_executable("python3") || System.find_executable("python") || "python3"

      port =
        Port.open({:spawn_executable, python_cmd}, [
          :binary,
          :use_stdio,
          :exit_status,
          :stderr_to_stdout,
          args: ["-u", bridge_path]
        ])

      send_frame(port, %{
        "type" => "init",
        "model" => model,
        "device" => device,
        "load_denoiser" => load_denoiser,
        "optimize" => optimize
      })

      state = %{
        port: port,
        model: model,
        status: :loading,
        device: nil,
        sample_rate: nil,
        buffer: <<>>,
        pending: :queue.new(),
        streams: %{},
        stream_timers: %{}
      }

      {:ok, state}
    end
  end

  @impl true
  def terminate(_reason, state) do
    if state.port do
      Port.close(state.port)
    end

    # Cancel all stream timers
    for {_ref, tref} <- state.stream_timers do
      Process.cancel_timer(tref)
    end

    :ok
  end

  # -- call handlers ----------------------------------------------------------

  @impl true
  def handle_call(:info, _from, state) do
    {:reply, %{
      status: state.status,
      device: state.device,
      sample_rate: state.sample_rate,
      model: state.model,
      stream_count: map_size(state.streams)
    }, state}
  end

  def handle_call(:await_ready, _from, state) do
    case state.status do
      :ready -> {:reply, :ok, state}
      :loading -> {:reply, {:error, :loading}, state}
      {:error, reason} -> {:reply, {:error, reason}, state}
    end
  end

  def handle_call({:generate, text, opts}, from, state) do
    if state.status != :ready do
      {:reply, {:error, :not_ready}, state}
    else
      t0 = System.monotonic_time()
      msg = Map.merge(%{"type" => "generate", "text" => text}, stringify_keys(opts))
      send_frame(state.port, msg)
      {:noreply, %{state | pending: :queue.in({from, :generate, t0}, state.pending)}}
    end
  end

  def handle_call({:generate_streaming_async, text, opts}, from, state) do
    if state.status != :ready do
      {:reply, {:error, :not_ready}, state}
    else
      ref = make_ref()
      msg = Map.merge(%{"type" => "generate_streaming", "text" => text}, stringify_keys(opts))
      send_frame(state.port, msg)

      # Set cleanup timer
      tref = Process.send_after(self(), {:stream_timeout, ref}, @stream_ttl_ms)

      streams = Map.put(state.streams, ref, %{
        chunks: [],
        sample_rate: nil,
        done: false,
        waiter: nil,
        collector: nil,
        t0: System.monotonic_time()
      })

      timers = Map.put(state.stream_timers, ref, tref)

      {:noreply,
       %{state | pending: :queue.in({from, :stream_ref, ref}, state.pending),
                 streams: streams, stream_timers: timers}}
    end
  end

  def handle_call({:next_chunk, ref}, from, state) do
    case Map.get(state.streams, ref) do
      nil ->
        {:reply, {:error, :unknown_stream}, state}

      %{chunks: [chunk | rest], done: _done} = stream ->
        # Reset TTL
        timers = reset_stream_timer(state.stream_timers, ref)

        new_streams = Map.put(state.streams, ref, %{stream | chunks: rest})
        {:reply, {:ok, chunk}, %{state | streams: new_streams, stream_timers: timers}}

      %{chunks: [], done: true} ->
        # Clean up
        timers = cancel_stream_timer(state.stream_timers, ref)
        {:reply, :eos, %{state | streams: Map.delete(state.streams, ref), stream_timers: timers}}

      %{chunks: [], done: false} ->
        # Reset TTL and block
        timers = reset_stream_timer(state.stream_timers, ref)
        {:noreply,
         %{state | streams: put_in(state.streams[ref].waiter, from), stream_timers: timers}}
    end
  end

  def handle_call({:collect_stream, ref}, from, state) do
    case Map.get(state.streams, ref) do
      nil ->
        {:reply, {:error, :unknown_stream}, state}

      %{done: false} ->
        {:noreply,
         %{state | streams: put_in(state.streams[ref].collector, from)}}

      %{chunks: chunks, sample_rate: sr, done: true} ->
        timers = cancel_stream_timer(state.stream_timers, ref)
        audio = build_wav(chunks, sr)
        {:reply, {:ok, audio}, %{state | streams: Map.delete(state.streams, ref), stream_timers: timers}}
    end
  end

  def handle_call({:load_lora, lora_path}, from, state) do
    send_frame(state.port, %{"type" => "load_lora", "lora_path" => lora_path})
    {:noreply, %{state | pending: :queue.in({from, :lora_load, nil}, state.pending)}}
  end

  def handle_call(:unload_lora, from, state) do
    send_frame(state.port, %{"type" => "unload_lora"})
    {:noreply, %{state | pending: :queue.in({from, :lora_unload, nil}, state.pending)}}
  end

  # -- port messages ----------------------------------------------------------

  @impl true
  def handle_info({_port, {:data, data}}, state) do
    {msgs, new_buffer} = parse_frames(state.buffer <> data)

    state = Enum.reduce(msgs, %{state | buffer: new_buffer}, fn msg, acc ->
      handle_message(msg, acc)
    end)

    {:noreply, state}
  end

  @impl true
  def handle_info({_port, {:exit_status, status}}, state) do
    Logger.error("VoxCPM bridge exited with status #{status}")
    {:stop, {:bridge_exit, status}, state}
  end

  # Stream timeout — cleanup stale streams
  def handle_info({:stream_timeout, ref}, state) do
    case Map.get(state.streams, ref) do
      nil ->
        {:noreply, state}
      stream ->
        Logger.warning("Stream #{inspect(ref)} timed out after #{@stream_ttl_ms}ms — cleaning up")
        if stream.waiter, do: GenServer.reply(stream.waiter, {:error, :timeout})
        if stream.collector, do: GenServer.reply(stream.collector, {:error, :timeout})
        timers = Map.delete(state.stream_timers, ref)
        {:noreply, %{state | streams: Map.delete(state.streams, ref), stream_timers: timers}}
    end
  end

  # =========================================================================
  # Message dispatch
  # =========================================================================

  defp handle_message(%{"status" => "ok"} = msg, state) do
    cond do
      Map.has_key?(msg, "device") ->
        Logger.info("VoxCPM loaded on #{msg["device"]}, sr=#{msg["sample_rate"]}")
        %{state | status: :ready, device: msg["device"], sample_rate: msg["sample_rate"]}

      Map.has_key?(msg, "audio") ->
        {{:value, {from, _type, t0}}, new_pending} = :queue.out(state.pending)
        emit_telemetry(:generate, t0)
        GenServer.reply(from, {:ok, msg["audio"]})
        %{state | pending: new_pending}

      Map.has_key?(msg, "loaded") ->
        {{:value, {from, _type, _}}, new_pending} = :queue.out(state.pending)
        GenServer.reply(from, {:ok, msg["loaded"], msg["skipped"]})
        %{state | pending: new_pending}

      true ->
        case :queue.out(state.pending) do
          {{:value, {from, _type, _}}, new_pending} ->
            GenServer.reply(from, :ok)
            %{state | pending: new_pending}
          {:empty, _} -> state
        end
    end
  end

  defp handle_message(%{"status" => "error", "error" => error}, state) do
    Logger.error("Bridge error: #{error}")

    {new_pending, updated_state} =
      case :queue.out(state.pending) do
        {{:value, {from, _type, _t0}}, new_q} ->
          GenServer.reply(from, {:error, error})
          {new_q, state}
        {:empty, _} ->
          {state.pending, state}
      end

    # If this error arrived during init, update status
    updated_state =
      if updated_state.status == :loading do
        %{updated_state | status: {:error, error}}
      else
        updated_state
      end

    %{updated_state | pending: new_pending}
  end

  # -- streaming messages (v2.1: no stream_id, ref is implicit from pending) ---

  defp handle_message(%{"type" => "stream_start"} = msg, state) do
    sample_rate = msg["sample_rate"]

    {{:value, {from, :stream_ref, ref}}, new_pending} = :queue.out(state.pending)
    GenServer.reply(from, {:ok, ref})

    streams = put_in(state.streams[ref].sample_rate, sample_rate)
    %{state | pending: new_pending, streams: streams}
  end

  defp handle_message(%{"type" => "stream_chunk"} = msg, state) do
    chunk = msg["chunk"]

    # The stream is the only one whose ref is not in streams... 
    # We need to find which stream is active. Since only one streaming
    # session can be active, just find the one that isn't done.
    case Enum.find(state.streams, fn {_ref, s} -> not s.done end) do
      {ref, stream} ->
        new_chunks = stream.chunks ++ [chunk]
        streams = put_in(state.streams[ref].chunks, new_chunks)
        state = %{state | streams: streams}

        case Map.get(state.streams, ref) do
          %{waiter: waiter} when not is_nil(waiter) ->
            %{chunks: [next | rest]} = Map.get(state.streams, ref)
            GenServer.reply(waiter, {:ok, next})
            new_s = put_in(state.streams[ref].chunks, rest)
            new_s = put_in(new_s[ref].waiter, nil)
            %{state | streams: new_s}
          _ ->
            state
        end

      nil ->
        Logger.warning("Stream chunk with no active stream")
        state
    end
  end

  defp handle_message(%{"type" => "stream_end"} = msg, state) do
    total_chunks = msg["total_chunks"]

    case Enum.find(state.streams, fn {_ref, s} -> not s.done end) do
      {ref, stream} ->
        t0 = stream.t0
        streams = put_in(state.streams[ref].done, true)
        state = %{state | streams: streams}
        timers = cancel_stream_timer(state.stream_timers, ref)

        emit_telemetry(:stream, t0, %{chunks: total_chunks})

        cond do
          not is_nil(stream.collector) ->
            audio = build_wav(stream.chunks, stream.sample_rate)
            GenServer.reply(stream.collector, {:ok, audio})
            %{state | streams: Map.delete(state.streams, ref), stream_timers: timers}

          not is_nil(stream.waiter) ->
            case stream.chunks do
              [last] ->
                GenServer.reply(stream.waiter, {:ok, last})
                %{state | streams: Map.delete(state.streams, ref), stream_timers: timers}
              chunks when chunks != [] ->
                [next | rest] = chunks
                GenServer.reply(stream.waiter, {:ok, next})
                new_s = put_in(state.streams[ref].chunks, rest)
                new_s = put_in(new_s[ref].waiter, nil)
                %{state | streams: new_s, stream_timers: timers}
              [] ->
                GenServer.reply(stream.waiter, :eos)
                %{state | streams: Map.delete(state.streams, ref), stream_timers: timers}
            end

          true ->
            %{state | stream_timers: timers}
        end

      nil ->
        Logger.warning("Stream end with no active stream")
        state
    end
  end

  defp handle_message(%{"type" => "stream_error"} = msg, state) do
    error = msg["error"]
    Logger.error("Stream error: #{error}")

    case Enum.find(state.streams, fn {_ref, s} -> not s.done end) do
      {ref, stream} ->
        if stream.waiter, do: GenServer.reply(stream.waiter, {:error, error})
        if stream.collector, do: GenServer.reply(stream.collector, {:error, error})
        timers = cancel_stream_timer(state.stream_timers, ref)
        %{state | streams: Map.delete(state.streams, ref), stream_timers: timers}
      nil ->
        state
    end
  end

  # =========================================================================
  # Helpers
  # =========================================================================

  defp send_frame(port, msg) do
    data = Msgpax.pack!(msg)
    len = byte_size(data) + @frame_header_bytes
    frame = <<len::32-unsigned-big-integer>> <> data
    send(port, {self(), {:command, frame}})
  end

  defp parse_frames(binary) do
    parse_frames_loop(binary, [])
  end

  defp parse_frames_loop(<<len::32-unsigned-big-integer, rest::binary>>, acc)
       when byte_size(rest) >= len - @frame_header_bytes do
    payload_size = len - @frame_header_bytes
    <<payload::binary-size(payload_size), remaining::binary>> = rest

    msg = Msgpax.unpack!(payload)
    parse_frames_loop(remaining, [msg | acc])
  end

  defp parse_frames_loop(partial, acc) do
    {Enum.reverse(acc), partial}
  end

  defp stringify_keys(kwlist) do
    Map.new(kwlist, fn {k, v} -> {to_string(k), v} end)
  end

  # -- WAV builder ------------------------------------------------------------

  # Build a valid WAV file from raw float32 PCM chunks.
  # Chunks are raw float32 interleaved samples.
  defp build_wav(chunks, sample_rate) when is_list(chunks) and is_integer(sample_rate) do
    pcm = IO.iodata_to_binary(chunks)
    data_size = byte_size(pcm)
    # WAV header: 44 bytes for PCM
    riff_size = data_size + 36
    header = <<
      "RIFF", riff_size::32-little-unsigned,
      "WAVE",
      "fmt ", 16::32-little-unsigned,
      3::16-little-unsigned,        # IEEE float
      1::16-little-unsigned,        # mono
      sample_rate::32-little-unsigned,
      sample_rate * 4::32-little-unsigned,  # byte rate (float32 mono)
      4::16-little-unsigned,        # block align
      32::16-little-unsigned,       # bits per sample
      "data", data_size::32-little-unsigned
    >>
    header <> pcm
  end

  defp build_wav(_, _), do: <<>>

  # -- Stream timer helpers ---------------------------------------------------

  defp reset_stream_timer(timers, ref) do
    if Map.has_key?(timers, ref) do
      _ = Process.cancel_timer(timers[ref])
      Map.put(timers, ref, Process.send_after(self(), {:stream_timeout, ref}, @stream_ttl_ms))
    else
      timers
    end
  end

  defp cancel_stream_timer(timers, ref) do
    if Map.has_key?(timers, ref) do
      Process.cancel_timer(timers[ref])
      Map.delete(timers, ref)
    else
      timers
    end
  end

  # -- Telemetry --------------------------------------------------------------

  defp emit_telemetry(event, t0, extra \\ %{}) do
    t1 = System.monotonic_time()
    duration_us = System.convert_time_unit(t1 - t0, :native, :microsecond)

    :telemetry.execute([:voxcpmex, event], %{duration_us: duration_us}, extra)
  end
end