// native/whisper_ct2_native/src/lib.rs

//! Rustler NIF wrapping `ct2rs::sys::Whisper`. Every entry point returns
//! `{:ok, value}` or `{:error, %{type, message, details}}`; PCM input is
//! little-endian IEEE-754 `f32` mono at 16 kHz, chunked into Whisper
//! 30 s windows internally.

// `deny` rather than `forbid`: ndarray's `s!` macro expands to code that
// uses `#[allow(unsafe_code)]` internally. Our own code stays unsafe-free
// (cargo will still error on any unsafe block we write).
#![deny(unsafe_code)]

use std::collections::HashMap;
use std::panic::{AssertUnwindSafe, catch_unwind};
use std::path::PathBuf;

use ct2rs::sys::{ComputeType, Config, Device, Whisper, WhisperOptions, get_device_count};
use ct2rs::tokenizers::hf;
use parking_lot::Mutex;
use rustler::types::binary::Binary;
use rustler::{Encoder, Env, NifMap, ResourceArc, Term};

mod align;
mod errors;
mod preprocessor;
mod tokens;
mod transcribe;

use errors::kind_from_chain;
use preprocessor::Preprocessor;
use tokens::SpecialTokens;
use transcribe::{SegmentResult, TranscribeRequest, TranscriptionResult, WordResult};

// Atoms used to tag NIF return values: `encode_result` wraps successes in
// `{:ok, value}` and failures in `{:error, map}`.
#[allow(missing_docs)]
mod atoms {
    rustler::atoms! {
        ok,
        error,
    }
}
use atoms::{error, ok};

/// `true` when this build was compiled with any CUDA cargo feature
/// (`cuda` or `cuda-dynamic`). Evaluated at compile time via `cfg!`, so it
/// describes the build, not whether a GPU is actually present at runtime.
const CUDA_SUPPORTED: bool = cfg!(any(feature = "cuda", feature = "cuda-dynamic"));

/// Structured error returned to Elixir as a `NifMap`. The Elixir side maps
/// `type` to a `WhisperCt2.Error.reason` atom.
///
/// Categories used in this file are `invalid_request`, `load_error`,
/// `inference_error` and `nif_panic`.
#[derive(Debug, NifMap)]
struct NativeError {
    /// Error category string (raw identifier because `type` is a keyword).
    r#type: String,
    /// Human-readable description of the failure.
    message: String,
    /// Optional key/value context attached via `with_detail`.
    details: HashMap<String, String>,
}

impl NativeError {
    /// Builds an error of category `type_name` carrying `message`, with an
    /// empty `details` map.
    fn new(type_name: &str, message: impl Into<String>) -> Self {
        let details = HashMap::new();
        Self {
            r#type: String::from(type_name),
            message: message.into(),
            details,
        }
    }

    /// Builder-style helper: stores `key => value` in `details` (replacing
    /// any earlier value for the same key) and returns the error so calls
    /// can be chained.
    fn with_detail(self, key: &str, value: impl Into<String>) -> Self {
        let Self {
            r#type,
            message,
            mut details,
        } = self;
        details.insert(String::from(key), value.into());
        Self {
            r#type,
            message,
            details,
        }
    }
}

impl From<anyhow::Error> for NativeError {
    /// Converts an `anyhow` error into the structured NIF error.
    ///
    /// The category is whatever `kind_from_chain` recovers from the error
    /// chain (attached at the originating call-site via
    /// `errors::invalid_request` / `errors::runtime_error`). Errors without
    /// a category are reported as `inference_error`, the bucket for real
    /// CTranslate2 failures bubbled up through `anyhow!`. The message is the
    /// alternate (`{:#}`) rendering of the error.
    fn from(err: anyhow::Error) -> Self {
        let rendered = format!("{err:#}");
        match kind_from_chain(&err) {
            Some(kind) => NativeError::new(kind, rendered),
            None => NativeError::new("inference_error", rendered),
        }
    }
}

/// Opaque BEAM resource holding a loaded Whisper model plus everything we
/// need to feed `ct2rs::sys::Whisper` (mel filterbank, tokenizer, special
/// token IDs).
///
/// `sys::Whisper` is serialised through a [`Mutex`] for inference. The
/// CTranslate2 engine itself is thread-safe; load multiple models if you
/// need parallel inference.
struct WhisperResource {
    /// The loaded model; both transcribe NIFs lock it for the whole call.
    whisper: Mutex<Whisper>,
    /// Tokenizer loaded from the model directory.
    tokenizer: hf::Tokenizer,
    /// Audio preprocessor loaded from the model directory.
    preprocessor: Preprocessor,
    /// Special token IDs resolved from `tokenizer` at load time.
    specials: SpecialTokens,
    /// Copy of `preprocessor.sampling_rate`, cached for `nif_model_info`.
    sampling_rate: usize,
    /// Copy of `preprocessor.n_samples`, cached for `nif_model_info`.
    n_samples: usize,
    /// Result of `Whisper::is_multilingual()` captured at load time.
    multilingual: bool,
    /// Label of the device the model was loaded on (see `device_label`).
    device: &'static str,
    /// Label of the compute type in use (see `compute_type_label`).
    compute_type: &'static str,
}

// Marks `WhisperResource` as a Rustler resource type so it can be wrapped in
// a `ResourceArc`; the type is registered with the VM in `on_load`.
impl rustler::Resource for WhisperResource {}

/// Options map accepted by `nif_load_model`. Every field is optional; unset
/// fields keep the defaults described below.
#[derive(NifMap)]
struct LoadOpts {
    /// Requested device name; `None` behaves like `"auto"` (see
    /// `resolve_device`).
    device: Option<String>,
    /// Compute type name parsed by `parse_compute_type`; `None` selects
    /// `ComputeType::default()`.
    compute_type: Option<String>,
    /// Device indices handed to CTranslate2; `None` means `[0]`, and an
    /// empty list is rejected as `invalid_request`.
    device_indices: Option<Vec<i32>>,
    /// Overrides `Config::num_threads_per_replica` when set.
    num_threads_per_replica: Option<u32>,
    /// Overrides `Config::max_queued_batches` when set.
    max_queued_batches: Option<i32>,
    /// Overrides `Config::cpu_core_offset` when set.
    cpu_core_offset: Option<i32>,
}

/// Options map accepted by `nif_transcribe` and `nif_transcribe_batch`.
///
/// `build_request` turns this into a `TranscribeRequest`: the first five
/// fields are request-level settings, and every remaining field overrides
/// the `WhisperOptions` field of the same name when it is set.
#[derive(NifMap)]
struct TranscribeOpts {
    /// Passed through unchanged as `TranscribeRequest::language`.
    language: Option<String>,
    /// Passed through unchanged as `TranscribeRequest::initial_prompt`.
    initial_prompt: Option<String>,
    /// Passed through unchanged as `TranscribeRequest::prefix`.
    prefix: Option<String>,
    /// Defaults to `false` when unset.
    word_timestamps: Option<bool>,
    /// Defaults to `true` when unset.
    with_timestamps: Option<bool>,
    // Decoding overrides: each one replaces the matching `WhisperOptions`
    // default only when present.
    beam_size: Option<u32>,
    patience: Option<f32>,
    length_penalty: Option<f32>,
    repetition_penalty: Option<f32>,
    no_repeat_ngram_size: Option<u32>,
    sampling_temperature: Option<f32>,
    sampling_topk: Option<u32>,
    suppress_blank: Option<bool>,
    max_length: Option<u32>,
    num_hypotheses: Option<u32>,
    max_initial_timestamp_index: Option<u32>,
    suppress_tokens: Option<Vec<i32>>,
}

/// Map returned by `nif_model_info`; every field is copied from the values
/// cached on `WhisperResource` at load time.
#[derive(NifMap)]
struct ModelInfo {
    /// Sampling rate reported by the model's preprocessor.
    sampling_rate: usize,
    /// `n_samples` reported by the model's preprocessor.
    n_samples: usize,
    /// Whether the loaded model reported itself as multilingual.
    multilingual: bool,
    /// Device label, e.g. `"cpu"` or `"cuda"`.
    device: String,
    /// Compute type label, e.g. `"float16"`.
    compute_type: String,
}

/// Map returned by `nif_available_devices`.
#[derive(NifMap)]
struct AvailableDevices {
    /// `get_device_count(Device::CPU)`.
    cpu: i32,
    /// `get_device_count(Device::CUDA)`, or `0` when the build has no CUDA
    /// support (the count is not queried in that case).
    cuda: i32,
    /// Mirrors the compile-time `CUDA_SUPPORTED` flag.
    cuda_supported: bool,
}

/// NIF-encodable mirror of `transcribe::WordResult`; fields are copied
/// one-to-one by the `From<WordResult>` impl.
// NOTE(review): `start`/`end` are presumably offsets in seconds — confirm
// against `transcribe`/`align`.
#[derive(NifMap)]
struct NifWord {
    text: String,
    start: f32,
    end: f32,
    probability: f32,
}

/// NIF-encodable mirror of `transcribe::SegmentResult`; fields are copied
/// one-to-one by the `From<SegmentResult>` impl.
// NOTE(review): `start`/`end` are presumably offsets in seconds — confirm
// against `transcribe`.
#[derive(NifMap)]
struct NifSegment {
    text: String,
    start: f32,
    end: f32,
    no_speech_prob: f32,
    avg_logprob: f32,
    tokens: Vec<u32>,
    /// `None` when the source segment carries no word list.
    words: Option<Vec<NifWord>>,
}

/// NIF-encodable mirror of `transcribe::TranscriptionResult`; this is the
/// value returned inside `{:ok, _}` by the transcribe NIFs.
#[derive(NifMap)]
struct NifTranscription {
    language: String,
    duration_s: f32,
    segments: Vec<NifSegment>,
}

impl From<WordResult> for NifWord {
    fn from(w: WordResult) -> Self {
        Self {
            text: w.text,
            start: w.start,
            end: w.end,
            probability: w.probability,
        }
    }
}

impl From<SegmentResult> for NifSegment {
    fn from(s: SegmentResult) -> Self {
        Self {
            text: s.text,
            start: s.start,
            end: s.end,
            no_speech_prob: s.no_speech_prob,
            avg_logprob: s.avg_logprob,
            tokens: s.tokens,
            words: s
                .words
                .map(|ws| ws.into_iter().map(NifWord::from).collect()),
        }
    }
}

impl From<TranscriptionResult> for NifTranscription {
    fn from(t: TranscriptionResult) -> Self {
        Self {
            language: t.language,
            duration_s: t.duration_s,
            segments: t.segments.into_iter().map(NifSegment::from).collect(),
        }
    }
}

/// Runs `f`, converting any panic it raises into a `nif_panic` error
/// instead of letting the unwind continue into the caller.
///
/// The error message is the panic payload when that payload is a `String`
/// or a `&str`; any other payload is reported as `"unknown panic"`. A
/// closure that returns normally has its `Ok`/`Err` passed through as-is.
fn run_with_panic_protection<T, F>(f: F) -> Result<T, NativeError>
where
    F: FnOnce() -> Result<T, NativeError>,
{
    match catch_unwind(AssertUnwindSafe(f)) {
        Ok(outcome) => outcome,
        Err(payload) => {
            let message: &str = if let Some(owned) = payload.downcast_ref::<String>() {
                owned.as_str()
            } else if let Some(literal) = payload.downcast_ref::<&str>() {
                *literal
            } else {
                "unknown panic"
            };
            Err(NativeError::new("nif_panic", message))
        }
    }
}

fn encode_result<T: Encoder>(env: Env<'_>, result: Result<T, NativeError>) -> Term<'_> {
    match result {
        Ok(value) => (ok(), value).encode(env),
        Err(err) => (error(), err).encode(env),
    }
}

/// Parses a device name, ignoring ASCII case. `"cpu"` maps to the CPU
/// device; `"cuda"` and `"gpu"` both map to the CUDA device.
///
/// # Errors
/// Any other name yields an `invalid_request` error whose details carry the
/// lowercased input (`device`) and the supported names (`supported`).
fn parse_device(s: &str) -> Result<Device, NativeError> {
    let normalized = s.to_ascii_lowercase();
    let device = match normalized.as_str() {
        "cpu" => Device::CPU,
        "cuda" | "gpu" => Device::CUDA,
        unsupported => {
            return Err(NativeError::new("invalid_request", "unknown device")
                .with_detail("device", unsupported)
                .with_detail("supported", "cpu,cuda"));
        }
    };
    Ok(device)
}

/// Canonical lowercase label for a device: `"cpu"`, `"cuda"`, or
/// `"unknown"` for any other value.
#[inline]
fn device_label(d: Device) -> &'static str {
    if matches!(d, Device::CPU) {
        "cpu"
    } else if matches!(d, Device::CUDA) {
        "cuda"
    } else {
        "unknown"
    }
}

/// Parses a compute type name, ignoring ASCII case. Besides the canonical
/// names, the short aliases `fp32`, `fp16` and `bf16` are accepted.
///
/// # Errors
/// An unrecognised name yields an `invalid_request` error whose
/// `compute_type` detail holds the lowercased input.
fn parse_compute_type(s: &str) -> Result<ComputeType, NativeError> {
    let normalized = s.to_ascii_lowercase();
    let compute_type = match normalized.as_str() {
        "default" => ComputeType::DEFAULT,
        "auto" => ComputeType::AUTO,
        "float32" | "fp32" => ComputeType::FLOAT32,
        "float16" | "fp16" => ComputeType::FLOAT16,
        "bfloat16" | "bf16" => ComputeType::BFLOAT16,
        "int8" => ComputeType::INT8,
        "int8_float32" => ComputeType::INT8_FLOAT32,
        "int8_float16" => ComputeType::INT8_FLOAT16,
        "int8_bfloat16" => ComputeType::INT8_BFLOAT16,
        "int16" => ComputeType::INT16,
        unsupported => {
            return Err(NativeError::new("invalid_request", "unknown compute_type")
                .with_detail("compute_type", unsupported));
        }
    };
    Ok(compute_type)
}

/// Canonical lowercase label for a compute type (the inverse of the
/// canonical names accepted by `parse_compute_type`); any value not listed
/// here is labelled `"unknown"`.
#[inline]
fn compute_type_label(c: ComputeType) -> &'static str {
    match c {
        // Floating-point types.
        ComputeType::FLOAT32 => "float32",
        ComputeType::FLOAT16 => "float16",
        ComputeType::BFLOAT16 => "bfloat16",
        // Integer and mixed int8 types.
        ComputeType::INT16 => "int16",
        ComputeType::INT8 => "int8",
        ComputeType::INT8_FLOAT32 => "int8_float32",
        ComputeType::INT8_FLOAT16 => "int8_float16",
        ComputeType::INT8_BFLOAT16 => "int8_bfloat16",
        // Selection policies rather than concrete types.
        ComputeType::DEFAULT => "default",
        ComputeType::AUTO => "auto",
        _ => "unknown",
    }
}

/// Picks the device a model should be loaded on.
///
/// A missing request or `"auto"` (any ASCII case) selects CUDA when the
/// build has CUDA support and at least one CUDA device is visible, and the
/// CPU otherwise. Any other name goes through `parse_device`.
///
/// # Errors
/// Returns `invalid_request` when the name is unknown, or when CUDA is
/// requested explicitly but the build lacks CUDA support or no CUDA device
/// is visible.
fn resolve_device(requested: Option<&str>) -> Result<Device, NativeError> {
    let normalized = requested.map(str::to_ascii_lowercase);
    let explicit = match normalized.as_deref() {
        Some(name) if name != "auto" => name,
        _ => {
            // Automatic selection never fails: fall back to the CPU.
            let cuda_usable = CUDA_SUPPORTED && get_device_count(Device::CUDA) > 0;
            return Ok(if cuda_usable {
                Device::CUDA
            } else {
                Device::CPU
            });
        }
    };

    let device = parse_device(explicit)?;
    if !matches!(device, Device::CUDA) {
        return Ok(device);
    }

    // Explicit CUDA request: check the build first, then the hardware.
    if !CUDA_SUPPORTED {
        return Err(NativeError::new(
            "invalid_request",
            "this build of whisper_ct2 was not compiled with GPU support",
        )
        .with_detail(
            "rebuild_with",
            "WHISPER_CT2_FEATURES=cuda-dynamic mix compile  # NVIDIA",
        ));
    }
    if get_device_count(Device::CUDA) == 0 {
        return Err(NativeError::new(
            "invalid_request",
            "no CUDA devices visible to CTranslate2",
        ));
    }
    Ok(device)
}

/// Reports how many CPU and CUDA devices `CTranslate2` can see, together
/// with whether this build was compiled with CUDA support. The CUDA count
/// is reported as `0` without querying the runtime when support is absent.
#[rustler::nif]
fn nif_available_devices(env: Env<'_>) -> Term<'_> {
    let report = run_with_panic_protection(|| {
        let cuda = CUDA_SUPPORTED
            .then(|| get_device_count(Device::CUDA))
            .unwrap_or(0);
        let cpu = get_device_count(Device::CPU);
        Ok(AvailableDevices {
            cpu,
            cuda,
            cuda_supported: CUDA_SUPPORTED,
        })
    });
    encode_result(env, report)
}

/// Loads a CT2-converted Whisper directory.
///
/// Validates `path` and `opts`, builds the CTranslate2 `Config`, then loads
/// the model, tokenizer, special tokens and preprocessor (in that order)
/// from the directory. Returns `{:ok, resource}` on success.
///
/// Errors are reported as `{:error, map}`: option/path problems use
/// `invalid_request`, any failure while loading files uses `load_error`,
/// and a panic inside the body becomes `nif_panic`.
#[rustler::nif(schedule = "DirtyCpu")]
#[allow(clippy::needless_pass_by_value)] // Rustler decodes nif args by value.
fn nif_load_model(env: Env<'_>, path: String, opts: LoadOpts) -> Term<'_> {
    let result = run_with_panic_protection(|| {
        // Reject anything that is not an existing directory up front.
        let path_buf = PathBuf::from(&path);
        if !path_buf.is_dir() {
            return Err(
                NativeError::new("invalid_request", "model path is not a directory")
                    .with_detail("path", path.clone()),
            );
        }

        // Unset device means automatic selection; unset compute type means
        // the `ComputeType` default.
        let device = resolve_device(opts.device.as_deref())?;
        let compute_type = opts
            .compute_type
            .as_deref()
            .map_or_else(|| Ok(ComputeType::default()), parse_compute_type)?;

        // Default to device index 0; an explicitly empty list is an error.
        let device_indices = match opts.device_indices {
            Some(v) if v.is_empty() => {
                return Err(NativeError::new(
                    "invalid_request",
                    "device_indices must be non-empty",
                ));
            }
            Some(v) => v,
            None => vec![0],
        };

        let mut config = Config {
            device,
            compute_type,
            device_indices,
            ..Config::default()
        };

        // Optional tuning knobs override the `Config` defaults only when set.
        if let Some(v) = opts.num_threads_per_replica {
            config.num_threads_per_replica = v as usize;
        }
        if let Some(v) = opts.max_queued_batches {
            config.max_queued_batches = v;
        }
        if let Some(v) = opts.cpu_core_offset {
            config.cpu_core_offset = v;
        }

        // Load the model first; the error carries the resolved device and
        // compute type to help diagnose unsupported combinations.
        let whisper = Whisper::new(&path_buf, config).map_err(|reason| {
            NativeError::new("load_error", "failed to load Whisper model")
                .with_detail("reason", reason.to_string())
                .with_detail("path", path.clone())
                .with_detail("device", device_label(device))
                .with_detail("compute_type", compute_type_label(compute_type))
        })?;

        // Then the auxiliary pieces that live next to the model weights.
        let tokenizer = hf::Tokenizer::new(&path_buf).map_err(|reason| {
            NativeError::new("load_error", "failed to load tokenizer.json")
                .with_detail("reason", reason.to_string())
                .with_detail("path", path.clone())
        })?;
        let specials = SpecialTokens::resolve(&tokenizer).map_err(|reason| {
            NativeError::new(
                "load_error",
                "tokenizer is missing required Whisper special tokens",
            )
            .with_detail("reason", reason.to_string())
        })?;
        let preprocessor = Preprocessor::load(&path_buf).map_err(|reason| {
            NativeError::new("load_error", "failed to load preprocessor_config.json")
                .with_detail("reason", reason.to_string())
                .with_detail("path", path.clone())
        })?;

        // Cache the metadata `nif_model_info` reports so that call never
        // needs to take the model lock.
        let sampling_rate = preprocessor.sampling_rate;
        let n_samples = preprocessor.n_samples;
        let multilingual = whisper.is_multilingual();

        Ok(ResourceArc::new(WhisperResource {
            whisper: Mutex::new(whisper),
            tokenizer,
            preprocessor,
            specials,
            sampling_rate,
            n_samples,
            multilingual,
            device: device_label(device),
            compute_type: compute_type_label(compute_type),
        }))
    });

    encode_result(env, result)
}

/// Returns metadata cached at load time.
#[rustler::nif]
#[allow(clippy::needless_pass_by_value)] // Rustler decodes nif args by value.
fn nif_model_info(env: Env<'_>, model: ResourceArc<WhisperResource>) -> Term<'_> {
    let result = run_with_panic_protection(|| {
        Ok(ModelInfo {
            sampling_rate: model.sampling_rate,
            n_samples: model.n_samples,
            multilingual: model.multilingual,
            device: model.device.to_owned(),
            compute_type: model.compute_type.to_owned(),
        })
    });
    encode_result(env, result)
}

/// Reinterprets a raw byte buffer as little-endian `f32` samples.
///
/// # Errors
/// Returns an `invalid_request` error (with a `byte_length` detail) when
/// the buffer length is not a multiple of four bytes.
fn decode_pcm_f32(bytes: &[u8]) -> Result<Vec<f32>, NativeError> {
    const SAMPLE_BYTES: usize = std::mem::size_of::<f32>();

    let byte_length = bytes.len();
    if byte_length % SAMPLE_BYTES == 0 {
        let mut samples = Vec::with_capacity(byte_length / SAMPLE_BYTES);
        for quad in bytes.chunks_exact(SAMPLE_BYTES) {
            samples.push(f32::from_le_bytes([quad[0], quad[1], quad[2], quad[3]]));
        }
        return Ok(samples);
    }

    Err(NativeError::new(
        "invalid_request",
        "samples binary length must be a multiple of 4 (f32)",
    )
    .with_detail("byte_length", byte_length.to_string()))
}

// Verified 2026-05: every field exposed in `TranscribeOpts` whose ct2rs
// default we inherit (`beam_size=5`, `patience=1.0`, `length_penalty=1.0`,
// `repetition_penalty=1.0`, `no_repeat_ngram_size=0`, `max_length=448`,
// `sampling_topk=1` (greedy), `sampling_temperature=1.0`,
// `num_hypotheses=1`, `max_initial_timestamp_index=50` (= 1.0 s),
// `suppress_blank=true`, `suppress_tokens=[-1]`) matches the
// `faster_whisper.TranscriptionOptions` default for the same field. The
// one apparent mismatch — faster-whisper's `temperature=[0.0, 0.2, ...]`
// fallback list vs ct2rs's `1.0` — is a no-op at `sampling_topk=1`
// because beam/greedy decoding ignores temperature.
//
// If you bump the ct2rs version, re-run the comparison before assuming
// the inheritance still holds.
/// Builds the internal `TranscribeRequest` from the caller's options.
///
/// Decoding options start from `WhisperOptions::default()` and only the
/// fields the caller actually set are overridden. `with_timestamps`
/// defaults to `true` and `word_timestamps` to `false`.
fn build_request(opts: &TranscribeOpts) -> TranscribeRequest {
    let mut decoding = WhisperOptions::default();

    // Copies `opts.<field>` over the inherited default when the caller set
    // it; the `as usize` form widens the `u32` wire value on the way.
    macro_rules! override_with {
        ($field:ident as usize) => {
            if let Some(value) = opts.$field {
                decoding.$field = value as usize;
            }
        };
        ($field:ident) => {
            if let Some(value) = opts.$field {
                decoding.$field = value;
            }
        };
    }

    override_with!(beam_size as usize);
    override_with!(patience);
    override_with!(length_penalty);
    override_with!(repetition_penalty);
    override_with!(no_repeat_ngram_size as usize);
    override_with!(sampling_temperature);
    override_with!(sampling_topk as usize);
    override_with!(suppress_blank);
    override_with!(max_length as usize);
    override_with!(num_hypotheses as usize);
    override_with!(max_initial_timestamp_index as usize);
    if let Some(tokens) = opts.suppress_tokens.as_ref() {
        decoding.suppress_tokens.clone_from(tokens);
    }

    let with_timestamps = opts.with_timestamps.unwrap_or(true);
    let word_timestamps = opts.word_timestamps.unwrap_or(false);

    TranscribeRequest {
        language: opts.language.clone(),
        with_timestamps,
        initial_prompt: opts.initial_prompt.clone(),
        prefix: opts.prefix.clone(),
        word_timestamps,
        options: decoding,
    }
}

/// Transcribes one PCM buffer, which may be longer than the 30 s Whisper
/// window; chunks are batched internally.
///
/// The samples are decoded and the request is built before the model lock
/// is taken, so the lock is only held for the transcription itself.
#[rustler::nif(schedule = "DirtyCpu")]
#[allow(clippy::needless_pass_by_value)] // Rustler decodes nif args by value.
fn nif_transcribe<'a>(
    env: Env<'a>,
    model: ResourceArc<WhisperResource>,
    samples_bin: Binary,
    opts: TranscribeOpts,
) -> Term<'a> {
    let pcm_bytes = samples_bin.as_slice();
    let outcome = run_with_panic_protection(|| {
        let pcm = decode_pcm_f32(pcm_bytes)?;
        let request = build_request(&opts);
        let resource: &WhisperResource = &model;

        // A `parking_lot::Mutex` has no poisoning, so a panic while the
        // lock is held leaves the loaded model usable for later calls.
        let engine = resource.whisper.lock();

        let transcription = transcribe::transcribe_one(
            &engine,
            &resource.tokenizer,
            &resource.preprocessor,
            &resource.specials,
            &pcm,
            &request,
        )?;
        Ok(NifTranscription::from(transcription))
    });

    encode_result(env, outcome)
}

/// Transcribes a list of PCM buffers in one batched `generate` call. The
/// caller passes a `Vec<Binary>`; each buffer is decoded, chunked, and
/// stacked so the encoder runs once across every chunk in the batch.
///
/// Decoding stops at the first malformed buffer and that error is
/// returned; the model lock is only taken once every buffer has decoded.
#[rustler::nif(schedule = "DirtyCpu")]
#[allow(clippy::needless_pass_by_value)] // Rustler decodes nif args by value.
fn nif_transcribe_batch<'a>(
    env: Env<'a>,
    model: ResourceArc<WhisperResource>,
    samples_bins: Vec<Binary>,
    opts: TranscribeOpts,
) -> Term<'a> {
    let outcome = run_with_panic_protection(|| {
        let mut decoded: Vec<Vec<f32>> = Vec::with_capacity(samples_bins.len());
        for bin in &samples_bins {
            decoded.push(decode_pcm_f32(bin.as_slice())?);
        }
        let request = build_request(&opts);

        let audios: Vec<&[f32]> = decoded.iter().map(Vec::as_slice).collect();
        let resource: &WhisperResource = &model;

        let engine = resource.whisper.lock();

        let transcriptions = transcribe::transcribe_many(
            &engine,
            &resource.tokenizer,
            &resource.preprocessor,
            &resource.specials,
            &audios,
            &request,
        )?;

        let mut encoded = Vec::new();
        encoded.extend(transcriptions.into_iter().map(NifTranscription::from));
        Ok(encoded)
    });

    encode_result(env, outcome)
}

/// Load hook passed to `rustler::init!`: registers `WhisperResource` as a
/// resource type and reports whether the registration succeeded.
fn on_load(env: Env<'_>, _info: Term<'_>) -> bool {
    let registration = env.register::<WhisperResource>();
    registration.is_ok()
}

// Binds this NIF library to the `WhisperCt2.Native` Elixir module and runs
// `on_load` when the library is loaded.
rustler::init!("Elixir.WhisperCt2.Native", load = on_load);

// Unit tests for the pure helpers in this file. Nothing here loads a model;
// the device tests only assert when the build/host has no usable CUDA.
#[cfg(test)]
mod tests {
    use super::*;

    // Encoding samples as little-endian bytes and decoding them again must
    // give back the exact same values.
    #[test]
    fn decode_pcm_f32_round_trips_samples() {
        let mut bytes = Vec::new();
        for v in [0.0_f32, 1.0, -1.0, 0.5, -0.25] {
            bytes.extend_from_slice(&v.to_le_bytes());
        }
        let decoded = decode_pcm_f32(&bytes).unwrap();
        assert_eq!(decoded, vec![0.0, 1.0, -1.0, 0.5, -0.25]);
    }

    // A buffer whose length is not a multiple of 4 is rejected, and the
    // offending length is reported in the error details.
    #[test]
    fn decode_pcm_f32_rejects_misaligned_length() {
        let err = decode_pcm_f32(&[1, 2, 3]).unwrap_err();
        assert_eq!(err.r#type, "invalid_request");
        assert_eq!(
            err.details.get("byte_length").map(String::as_str),
            Some("3")
        );
    }

    // Device names are case-insensitive, `gpu` aliases `cuda`, and unknown
    // names fail.
    #[test]
    fn parse_device_accepts_canonical_names() {
        assert!(matches!(parse_device("cpu").unwrap(), Device::CPU));
        assert!(matches!(parse_device("CUDA").unwrap(), Device::CUDA));
        assert!(matches!(parse_device("gpu").unwrap(), Device::CUDA));
        assert!(parse_device("tpu").is_err());
    }

    // Short aliases and compound names parse; unknown names fail.
    #[test]
    fn parse_compute_type_accepts_aliases() {
        assert!(matches!(
            parse_compute_type("fp16").unwrap(),
            ComputeType::FLOAT16
        ));
        assert!(matches!(
            parse_compute_type("int8_float16").unwrap(),
            ComputeType::INT8_FLOAT16
        ));
        assert!(parse_compute_type("nibble").is_err());
    }

    // Automatic selection must pick the CPU when CUDA cannot be used. The
    // test asserts nothing on a CUDA build with a visible device.
    #[test]
    fn resolve_device_auto_falls_back_to_cpu_without_gpu() {
        if !CUDA_SUPPORTED || get_device_count(Device::CUDA) == 0 {
            assert!(matches!(resolve_device(None).unwrap(), Device::CPU));
            assert!(matches!(resolve_device(Some("auto")).unwrap(), Device::CPU));
        }
    }

    // An explicit CUDA request on a build without CUDA support must fail
    // with a message that mentions GPU support.
    #[test]
    fn resolve_device_rejects_cuda_when_unavailable() {
        if !CUDA_SUPPORTED {
            let err = resolve_device(Some("cuda")).unwrap_err();
            assert_eq!(err.r#type, "invalid_request");
            assert!(err.message.contains("GPU"));
        }
    }

    // A panic with a string-literal payload is converted into a
    // `nif_panic` error carrying that message.
    #[test]
    fn run_with_panic_protection_catches_string_panic() {
        let result: Result<(), _> = run_with_panic_protection(|| panic!("boom"));
        let err = result.unwrap_err();
        assert_eq!(err.r#type, "nif_panic");
        assert_eq!(err.message, "boom");
    }

    // Closures that return normally have their result passed through
    // untouched, for both `Ok` and `Err`.
    #[test]
    fn run_with_panic_protection_passes_through_ok_and_err() {
        let ok_result = run_with_panic_protection(|| Ok::<_, NativeError>(42));
        assert_eq!(ok_result.unwrap(), 42);

        let err_result: Result<(), _> =
            run_with_panic_protection(|| Err(NativeError::new("load_error", "x")));
        assert_eq!(err_result.unwrap_err().r#type, "load_error");
    }
}