priv/python/voxcpmex_bridge.py

#!/usr/bin/env python3
"""
VoxCPMEx Bridge v2.1 — Python bridge for the Elixir VoxCPMEx library.

Protocol: binary frames over stdin/stdout
  Frame: [4-byte big-endian total_length][msgpack-encoded message]
  total_length counts the 4-byte header itself plus the msgpack payload.

Audio is raw WAV bytes inside msgpack — no base64 encoding.

Streaming (v2.1): simplified — no stream IDs. The bridge processes one
streaming request at a time, emitting stream_start → N×stream_chunk →
stream_end in strict sequence.
"""

import sys
import io
import os
import struct
import signal

import msgpack
import numpy as np
import soundfile as sf

# ---------------------------------------------------------------------------
# Graceful shutdown on SIGPIPE / SIGTERM
# ---------------------------------------------------------------------------
_shutting_down = False

def _handle_signal(signum, frame):
    global _shutting_down
    _shutting_down = True
    sys.stderr.write(f"Bridge received signal {signum}, shutting down\n")
    sys.stderr.flush()

if hasattr(signal, "SIGPIPE"):
    # Keep SIGPIPE ignored so writes to a closed pipe raise BrokenPipeError
    # (handled in _write_frame) instead of killing the process outright.
    signal.signal(signal.SIGPIPE, signal.SIG_IGN)
signal.signal(signal.SIGTERM, _handle_signal)

# ---------------------------------------------------------------------------
# PyTorch early import check
# ---------------------------------------------------------------------------
try:
    import torch
except ImportError as e:
    # _write_frame is defined further down, so build the error frame inline
    # before exiting.
    err = msgpack.dumps({"status": "error", "error": f"Missing dependency: {e}"})
    sys.stdout.buffer.write(struct.pack(">I", len(err) + 4) + err)
    sys.stdout.buffer.flush()
    sys.exit(1)

# Default torch.load to CPU so checkpoints saved on GPU machines still load
# when the bridge runs on a CPU-only or MPS host.
_original_torch_load = torch.load

def _patched_torch_load(f, map_location=None, **kwargs):
    if map_location is None:
        map_location = "cpu"
    return _original_torch_load(f, map_location=map_location, **kwargs)

torch.load = _patched_torch_load


# ---------------------------------------------------------------------------
# I/O helpers
# ---------------------------------------------------------------------------
def _read_exact(n: int) -> bytes:
    """Read exactly n bytes from stdin. Raises EOFError on EOF."""
    data = b""
    while len(data) < n:
        chunk = sys.stdin.buffer.read(n - len(data))
        if not chunk:
            raise EOFError("stdin closed")
        data += chunk
    return data


def _read_frame() -> dict:
    """Read one msgpack frame from stdin."""
    header = _read_exact(4)
    total_len = struct.unpack(">I", header)[0]
    payload = _read_exact(total_len - 4)
    return msgpack.loads(payload, raw=False)


def _write_frame(data: bytes) -> bool:
    """Write one msgpack frame to stdout. Returns True on success, False on broken pipe."""
    try:
        frame_len = struct.pack(">I", len(data) + 4)
        sys.stdout.buffer.write(frame_len + data)
        sys.stdout.buffer.flush()
        return True
    except (BrokenPipeError, OSError):
        return False


def _send(msg: dict) -> bool:
    """Encode and send. Returns False on failure."""
    return _write_frame(msgpack.dumps(msg))


def _send_error(error: str) -> bool:
    return _send({"status": "error", "error": error})


# ---------------------------------------------------------------------------
# Device detection
# ---------------------------------------------------------------------------
def _detect_device(requested: str) -> str:
    """Resolve the requested device string, falling back to CPU when unavailable."""
    req = (requested or "cuda").strip().lower()
    if req.startswith("cuda"):
        return req if torch.cuda.is_available() else "cpu"
    if req == "mps":
        has_mps = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
        return "mps" if has_mps else "cpu"
    return "cpu"

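# For illustration: "cuda:1" resolves to "cpu" when no CUDA device is present,
# "mps" falls back to "cpu" without an MPS backend, and an available request
# such as "cuda:0" on a CUDA host is returned unchanged.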

# ---------------------------------------------------------------------------
# VoxCPM Bridge
# ---------------------------------------------------------------------------
class VoxCPMBridge:
    """Wraps the loaded VoxCPM model and serves requests from the main loop."""

    def __init__(self):
        self.model = None
        self.device = None
        self.sample_rate = None

    def init_model(self, msg: dict) -> dict:
        """Load the VoxCPM model described by the init request."""
        try:
            from voxcpm import VoxCPM

            hf_model_id = msg.get("model", "openbmb/VoxCPM2")
            load_denoiser = msg.get("load_denoiser", False)
            optimize = msg.get("optimize", True)
            requested_device = msg.get("device", "cuda")

            self.device = _detect_device(requested_device)

            sys.stderr.write(f"Loading {hf_model_id} on {self.device}...\n")
            sys.stderr.flush()

            self.model = VoxCPM.from_pretrained(
                hf_model_id,
                load_denoiser=load_denoiser,
                device=self.device,
                optimize=optimize,
            )

            self.sample_rate = self.model.tts_model.sample_rate
            sys.stderr.write(f"Loaded. device={self.device} sr={self.sample_rate}\n")
            sys.stderr.flush()

            return {"status": "ok", "device": self.device, "sample_rate": self.sample_rate}

        except Exception as e:
            return {"status": "error", "error": str(e)}

    def generate(self, msg: dict) -> dict:
        """Synthesize speech for one request and return the full WAV in the reply."""
        if self.model is None:
            return {"status": "error", "error": "Model not initialized"}

        try:
            text = msg["text"]
            wav = self.model.generate(
                text,
                cfg_value=msg.get("cfg_value", 2.0),
                inference_timesteps=msg.get("inference_timesteps", 10),
                min_len=msg.get("min_len", 2),
                max_len=msg.get("max_len", 4096),
                normalize=msg.get("normalize", False),
                denoise=msg.get("denoise", False),
                reference_wav_path=msg.get("audio_prompt"),
                prompt_wav_path=msg.get("prompt_wav_path"),
                prompt_text=msg.get("prompt_text"),
                retry_badcase=msg.get("retry_badcase", True),
                retry_badcase_max_times=msg.get("retry_badcase_max_times", 3),
                retry_badcase_ratio_threshold=msg.get("retry_badcase_ratio_threshold", 6.0),
            )

            audio_bytes = self._wav_to_bytes(wav, self.sample_rate)
            return {
                "status": "ok",
                "audio": audio_bytes,
                "sample_rate": self.sample_rate,
                "duration": len(wav) / self.sample_rate,
            }

        except Exception as e:
            return {"status": "error", "error": str(e)}

    def generate_streaming(self, msg: dict) -> None:
        """Generate speech with streaming — emits frames in strict sequence."""
        if self.model is None:
            _send_error("Model not initialized")
            return

        text = msg["text"]

        try:
            # Announce stream start
            if not _send({"type": "stream_start", "sample_rate": self.sample_rate}):
                return  # Elixir side disconnected

            idx = 0
            for chunk in self.model.generate_streaming(
                text,
                cfg_value=msg.get("cfg_value", 2.0),
                inference_timesteps=msg.get("inference_timesteps", 10),
                min_len=msg.get("min_len", 2),
                max_len=msg.get("max_len", 4096),
                normalize=msg.get("normalize", False),
                denoise=msg.get("denoise", False),
                reference_wav_path=msg.get("audio_prompt"),
                prompt_wav_path=msg.get("prompt_wav_path"),
                prompt_text=msg.get("prompt_text"),
                retry_badcase=msg.get("retry_badcase", False),
                retry_badcase_max_times=msg.get("retry_badcase_max_times", 3),
                retry_badcase_ratio_threshold=msg.get("retry_badcase_ratio_threshold", 6.0),
            ):
                if _shutting_down:
                    return

                chunk_bytes = chunk.astype(np.float32).tobytes()
                if not _send({
                    "type": "stream_chunk",
                    "chunk": chunk_bytes,
                    "index": idx,
                    "length": len(chunk),
                }):
                    return  # Elixir side disconnected
                idx += 1

            _send({
                "type": "stream_end",
                "total_chunks": idx,
            })

        except Exception as e:
            _send({"type": "stream_error", "error": str(e)})

    def load_lora(self, msg: dict) -> dict:
        if self.model is None:
            return {"status": "error", "error": "Model not initialized"}
        try:
            loaded, skipped = self.model.load_lora(msg["lora_path"])
            return {"status": "ok", "loaded": len(loaded), "skipped": len(skipped)}
        except Exception as e:
            return {"status": "error", "error": str(e)}

    def unload_lora(self) -> dict:
        if self.model is None:
            return {"status": "error", "error": "Model not initialized"}
        try:
            self.model.unload_lora()
            return {"status": "ok"}
        except Exception as e:
            return {"status": "error", "error": str(e)}

    def _wav_to_bytes(self, wav: np.ndarray, sr: int) -> bytes:
        """Encode a float32 waveform as WAV bytes for a non-streaming reply."""
        buf = io.BytesIO()
        if wav.dtype != np.float32:
            wav = wav.astype(np.float32)
        sf.write(buf, wav, sr, format="WAV")
        buf.seek(0)
        return buf.read()
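    # Purely illustrative: the receiving side could round-trip these bytes
    # back to a waveform with
    #     data, sr = sf.read(io.BytesIO(wav_bytes), dtype="float32")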


# ---------------------------------------------------------------------------
# Main loop
# ---------------------------------------------------------------------------
def main():
    """Read framed requests from stdin and dispatch them until EOF or shutdown."""
    bridge = VoxCPMBridge()

    while not _shutting_down:
        try:
            msg = _read_frame()
        except EOFError:
            sys.stderr.write("stdin closed, exiting\n")
            sys.stderr.flush()
            break
        except Exception as e:
            if not _send_error(f"Frame read error: {e}"):
                break
            continue

        msg_type = msg.get("type")

        try:
            if msg_type == "init":
                _send(bridge.init_model(msg))

            elif msg_type == "generate":
                _send(bridge.generate(msg))

            elif msg_type == "generate_streaming":
                bridge.generate_streaming(msg)

            elif msg_type == "load_lora":
                _send(bridge.load_lora(msg))

            elif msg_type == "unload_lora":
                _send(bridge.unload_lora())

            elif msg_type == "ping":
                _send({"status": "ok", "message": "pong"})

            else:
                _send_error(f"Unknown request type: {msg_type}")

        except Exception as e:
            _send_error(f"Unhandled error: {e}")


if __name__ == "__main__":
    main()