# priv/python/omnivoice_ex_bridge.py

#!/usr/bin/env python3
"""
OmniVoiceEx Bridge — Python bridge for Elixir OmniVoiceEx library.

Protocol: binary frames over stdin/stdout
  Frame: [4-byte big-endian total_length][msgpack-encoded message]

OmniVoice generates 24kHz audio. Audio is raw WAV bytes in msgpack.
"""

import sys
import io
import os
import struct
import signal

import msgpack
import numpy as np
import soundfile as sf

# ---------------------------------------------------------------------------
# Graceful shutdown
# ---------------------------------------------------------------------------
# Module-level flag set by the SIGTERM handler and polled by main() between
# frames, so the loop exits cleanly at the next frame boundary.
_shutting_down = False

def _handle_signal(signum, frame):
    """SIGTERM handler: request a graceful shutdown of the main loop.

    Only flips the flag — a blocking stdin read already in progress is not
    interrupted; the loop exits once that read returns or hits EOF.
    """
    global _shutting_down
    _shutting_down = True
    sys.stderr.write(f"Bridge received signal {signum}, shutting down\n")
    sys.stderr.flush()

# Restore default SIGPIPE disposition so the process terminates when the
# reader vanishes.  NOTE(review): with SIG_DFL the process may die on a
# broken pipe before the BrokenPipeError handling in _write_frame runs —
# confirm which behavior is intended.  SIGPIPE is POSIX-only; this assumes
# a non-Windows host.
signal.signal(signal.SIGPIPE, signal.SIG_DFL)
signal.signal(signal.SIGTERM, _handle_signal)

# ---------------------------------------------------------------------------
# PyTorch check
# ---------------------------------------------------------------------------
try:
    import torch
except ImportError as e:
    # BUG FIX: the original called _write_frame here, but that helper is
    # defined later in the file — a missing torch would die with NameError
    # instead of reporting the problem.  Emit the error frame inline.
    payload = msgpack.dumps({"status": "error", "error": f"Missing dependency: {e}"})
    sys.stdout.buffer.write(struct.pack(">I", len(payload) + 4) + payload)
    sys.stdout.buffer.flush()
    sys.exit(1)

_original_torch_load = torch.load

def _patched_torch_load(f, map_location=None, **kwargs):
    """torch.load wrapper that defaults map_location to "cpu".

    Lets checkpoints saved on CUDA machines load on CPU-only hosts; an
    explicit map_location from the caller is passed through untouched.
    """
    if map_location is None:
        map_location = "cpu"
    return _original_torch_load(f, map_location=map_location, **kwargs)

torch.load = _patched_torch_load

# ---------------------------------------------------------------------------
# I/O helpers
# ---------------------------------------------------------------------------
def _read_exact(n: int) -> bytes:
    """Read exactly *n* bytes from stdin.

    Raises EOFError if the stream closes before *n* bytes arrive.
    """
    buf = bytearray()
    remaining = n
    while remaining > 0:
        chunk = sys.stdin.buffer.read(remaining)
        if not chunk:
            raise EOFError("stdin closed")
        buf.extend(chunk)
        remaining -= len(chunk)
    return bytes(buf)

def _read_frame() -> dict:
    """Read one length-prefixed msgpack frame from stdin.

    The 4-byte big-endian header holds the TOTAL frame length (header
    included), so the payload is total_len - 4 bytes.

    Raises EOFError when stdin closes, ValueError on a corrupt header.
    """
    header = _read_exact(4)
    total_len = struct.unpack(">I", header)[0]
    # BUG FIX: a corrupt header with total_len < 4 previously passed a
    # negative count to _read_exact (yielding b"") and failed deep inside
    # msgpack with an opaque error; fail fast with a clear message instead.
    if total_len < 4:
        raise ValueError(f"Invalid frame length: {total_len}")
    payload = _read_exact(total_len - 4)
    return msgpack.loads(payload, raw=False)

def _write_frame(data: bytes) -> bool:
    """Write *data* to stdout as one length-prefixed frame.

    The header counts itself, i.e. it is len(data) + 4.  Returns False
    (instead of raising) when the pipe to the parent is gone.
    """
    header = struct.pack(">I", 4 + len(data))
    try:
        out = sys.stdout.buffer
        out.write(header)
        out.write(data)
        out.flush()
    except (BrokenPipeError, OSError):
        return False
    return True

def _send(msg: dict) -> bool:
    """Serialize *msg* with msgpack and ship it as a single frame."""
    packed = msgpack.dumps(msg)
    return _write_frame(packed)

def _send_error(error: str) -> bool:
    """Shortcut for reporting an error back to the Elixir side."""
    reply = {"status": "error", "error": error}
    return _send(reply)

# ---------------------------------------------------------------------------
# Resolve device
# ---------------------------------------------------------------------------
def _resolve_device(requested: str) -> str:
    """Map a requested device string onto one that is actually usable.

    Accepts "cuda", "cuda:N", "mps", "cpu" (case/whitespace-insensitive);
    None/empty defaults to "cuda".  Falls back to "cpu" whenever the
    requested accelerator is unavailable.
    """
    normalized = (requested or "cuda").strip().lower()
    if normalized.startswith("cuda"):
        if torch.cuda.is_available():
            return normalized
        return "cpu"
    if normalized == "mps":
        mps_backend = getattr(torch.backends, "mps", None)
        if mps_backend is not None and mps_backend.is_available():
            return "mps"
        return "cpu"
    return "cpu"

# ---------------------------------------------------------------------------
# OmniVoice Bridge
# ---------------------------------------------------------------------------
class OmniVoiceBridge:
    """Stateful wrapper around an OmniVoice TTS model.

    Lifecycle: init_model() loads the model once; generate() synthesizes
    audio per request.  Methods return msgpack-serializable dicts carrying
    a "status" key ("ok" / "error") instead of raising to the caller.
    """

    def __init__(self):
        self.model = None        # OmniVoice instance once init_model() succeeds
        self.device = None       # resolved device string ("cpu", "cuda", ...)
        self.sample_rate = None  # output sample rate reported by the model

    def init_model(self, msg: dict) -> dict:
        """Load the OmniVoice model described by *msg*.

        Recognized msg keys: "model" (HF model id), "dtype" (torch dtype
        name), "device" (requested device string).

        Returns {"status": "ok", "device", "sample_rate"} on success,
        {"status": "error", "error"} on any failure (including a missing
        omnivoice package).
        """
        try:
            from omnivoice import OmniVoice

            hf_model_id = msg.get("model", "k2-fsa/OmniVoice")
            dtype_str = msg.get("dtype", "float16")
            requested_device = msg.get("device", "cuda")

            self.device = _resolve_device(requested_device)

            # Unrecognized dtype names silently fall back to float16.
            dtype = getattr(torch, dtype_str, torch.float16)

            sys.stderr.write(f"Loading {hf_model_id} on {self.device} ({dtype_str})...\n")
            sys.stderr.flush()

            self.model = OmniVoice.from_pretrained(
                hf_model_id,
                device_map=self.device,
                dtype=dtype,
            )

            # Default to 24 kHz (see module docstring) if the model does
            # not report a sampling rate.
            self.sample_rate = self.model.sampling_rate or 24000
            sys.stderr.write(f"Loaded. device={self.device} sr={self.sample_rate}\n")
            sys.stderr.flush()

            return {"status": "ok", "device": self.device, "sample_rate": self.sample_rate}

        except Exception as e:
            return {"status": "error", "error": str(e)}

    def generate(self, msg: dict) -> dict:
        """Synthesize speech for msg["text"].

        Optional msg keys forwarded to the model: ref_audio/ref_text
        (voice cloning), instruct (voice design), language, duration,
        speed, num_step (default 32), guidance_scale (default 2.0).

        Returns {"status": "ok", "audio": <WAV bytes>, "sample_rate",
        "duration": <seconds>} or an error dict.
        """
        if self.model is None:
            return {"status": "error", "error": "Model not initialized"}

        try:
            text = msg["text"]
            gen_kwargs = {}

            # Voice cloning: ref_text only makes sense alongside ref_audio.
            ref_audio = msg.get("ref_audio")
            ref_text = msg.get("ref_text")
            if ref_audio:
                gen_kwargs["ref_audio"] = ref_audio
                if ref_text:
                    gen_kwargs["ref_text"] = ref_text

            # Voice design
            instruct = msg.get("instruct")
            if instruct:
                gen_kwargs["instruct"] = instruct

            # Language
            language = msg.get("language")
            if language:
                gen_kwargs["language"] = language

            # Duration / speed.  NOTE(review): falsy values (0, 0.0) are
            # deliberately skipped, matching the original behavior.
            duration = msg.get("duration")
            if duration:
                gen_kwargs["duration"] = duration
            speed = msg.get("speed")
            if speed:
                gen_kwargs["speed"] = speed

            # Generation config
            gen_kwargs["num_step"] = msg.get("num_step", 32)
            gen_kwargs["guidance_scale"] = msg.get("guidance_scale", 2.0)

            audio_list = self.model.generate(text, **gen_kwargs)
            # Coerce up front: SOURCE does not establish the output type
            # (ndarray vs list vs CPU tensor), and the previous code broke
            # on anything without .dtype/.astype.  Assumes a 1-D waveform
            # — TODO confirm against the omnivoice API.
            wav = np.asarray(audio_list[0])

            audio_bytes = self._wav_to_bytes(wav, self.sample_rate)

            return {
                "status": "ok",
                "audio": audio_bytes,
                "sample_rate": self.sample_rate,
                "duration": len(wav) / self.sample_rate,
            }

        except Exception as e:
            return {"status": "error", "error": str(e)}

    def _wav_to_bytes(self, wav, sr: int) -> bytes:
        """Encode a waveform as WAV container bytes at sample rate *sr*.

        Accepts anything np.asarray() understands (ndarray, list, CPU
        torch tensor); samples are converted to float32 before encoding.
        """
        samples = np.asarray(wav)
        if samples.dtype != np.float32:
            samples = samples.astype(np.float32)
        buf = io.BytesIO()
        sf.write(buf, samples, sr, format="WAV")
        return buf.getvalue()


# ---------------------------------------------------------------------------
# Main loop
# ---------------------------------------------------------------------------
def main():
    """Frame-dispatch loop: read requests from stdin, reply on stdout.

    Runs until stdin closes, SIGTERM sets _shutting_down, or the outgoing
    pipe breaks.
    """
    bridge = OmniVoiceBridge()

    while not _shutting_down:
        try:
            msg = _read_frame()
        except EOFError:
            sys.stderr.write("stdin closed, exiting\n")
            sys.stderr.flush()
            break
        except Exception as e:
            if not _send_error(f"Frame read error: {e}"):
                break
            continue

        # BUG FIX: a well-formed msgpack frame whose top level is not a
        # map previously crashed the loop via msg.get(...) AttributeError.
        if not isinstance(msg, dict):
            if not _send_error(f"Invalid request: expected map, got {type(msg).__name__}"):
                break
            continue

        msg_type = msg.get("type")

        try:
            if msg_type == "init":
                delivered = _send(bridge.init_model(msg))
            elif msg_type == "generate":
                delivered = _send(bridge.generate(msg))
            elif msg_type == "ping":
                delivered = _send({"status": "ok", "message": "pong"})
            else:
                delivered = _send_error(f"Unknown request type: {msg_type}")
        except Exception as e:
            delivered = _send_error(f"Unhandled error: {e}")

        # BUG FIX: send results were previously ignored here.  False means
        # the pipe to the Elixir side is broken — exit instead of spinning
        # against a dead peer.
        if not delivered:
            break


if __name__ == "__main__":
    main()