Skip to main content

native/whisper_cpp_native/src/lib.rs

//! Rustler NIF wrapping whisper.cpp via the `whisper-rs` crate.
//!
//! Fallible entry points return `{:ok, value}` or
//! `{:error, %{type, message, details}}`; the abort-handle helpers return
//! bare terms (a resource, `:ok`, or a boolean) instead. PCM input is
//! little-endian IEEE-754 `f32` mono at 16 kHz.

#![deny(unsafe_code)]

use std::collections::HashMap;
use std::panic::{AssertUnwindSafe, catch_unwind};
use std::path::PathBuf;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};

use parking_lot::Mutex;
use rustler::types::binary::Binary;
use rustler::{Encoder, Env, LocalPid, NifMap, ResourceArc, Term};
use whisper_rs::{WhisperContext, WhisperContextParameters};

mod errors;
mod transcribe;

use errors::kind_from_chain;
use transcribe::{SegmentResult, TranscribeRequest, TranscriptionResult, WordResult};

// Atoms used to tag NIF results. `rustler::atoms!` generates the accessor
// functions (`ok()`, `error()`) called by the rest of this file, so there
// is nothing to document by hand — hence the lint exemption.
#[allow(missing_docs)]
mod atoms {
    rustler::atoms! { ok, error }
}
use atoms::{error, ok};

/// Label of the GPU backend this artefact was compiled with, or `None` for
/// a CPU-only build.
///
/// Only one label is ever reported: the cargo features are checked in the
/// order listed below and the first enabled one wins.
const GPU_BACKEND: Option<&str> = match (
    cfg!(feature = "cuda"),
    cfg!(feature = "hipblas"),
    cfg!(feature = "vulkan"),
    cfg!(feature = "metal"),
    cfg!(feature = "coreml"),
    cfg!(feature = "intel-sycl"),
) {
    (true, ..) => Some("cuda"),
    (_, true, ..) => Some("hipblas"),
    (_, _, true, ..) => Some("vulkan"),
    (_, _, _, true, ..) => Some("metal"),
    (_, _, _, _, true, _) => Some("coreml"),
    // The feature is spelled with a hyphen, the reported label with an
    // underscore.
    (_, _, _, _, _, true) => Some("intel_sycl"),
    _ => None,
};

/// Structured error payload handed back to the caller as the second element
/// of an `{:error, _}` tuple (see `encode_result`).
#[derive(Debug, NifMap)]
struct NativeError {
    /// Machine-readable category. This file produces `"invalid_request"`,
    /// `"load_error"`, `"nif_panic"` and `"inference_error"`; others may
    /// come from `errors::kind_from_chain`.
    r#type: String,
    /// Human-readable description of the failure.
    message: String,
    /// Extra string key/value context attached through `with_detail`.
    details: HashMap<String, String>,
}

impl NativeError {
    /// Creates an error of category `type_name` with the given message and
    /// an empty `details` map.
    fn new(type_name: &str, message: impl Into<String>) -> Self {
        let details = HashMap::new();
        Self {
            r#type: String::from(type_name),
            message: message.into(),
            details,
        }
    }

    /// Builder-style helper: records `key => value` in `details` and returns
    /// the error. A repeated key overwrites the earlier value.
    fn with_detail(mut self, key: &str, value: impl Into<String>) -> Self {
        let (owned_key, owned_value) = (String::from(key), value.into());
        self.details.insert(owned_key, owned_value);
        self
    }
}

impl From<anyhow::Error> for NativeError {
    fn from(err: anyhow::Error) -> Self {
        let kind = kind_from_chain(&err).unwrap_or("inference_error");
        NativeError::new(kind, format!("{err:#}"))
    }
}

/// Opaque BEAM resource holding a loaded whisper.cpp context. The mutex
/// only wraps the brief `create_state()` step; inference itself runs
/// without it. See `transcribe::transcribe_one`.
///
/// All metadata fields are captured once in `nif_load_model` and never
/// change afterwards.
struct WhisperResource {
    /// The loaded model context, passed by reference to
    /// `transcribe::transcribe_one`.
    ctx: Mutex<WhisperContext>,
    /// Expected input sample rate in Hz; `nif_load_model` always stores
    /// `16_000`.
    sampling_rate: usize,
    /// Result of `WhisperContext::is_multilingual()` at load time.
    multilingual: bool,
    /// Vocabulary size reported by the context; stored as `0` if the
    /// reported value does not fit in `usize`.
    n_vocab: usize,
    /// `<|endoftext|>` token id - text tokens occupy `[0, token_eot)`,
    /// everything above is timestamp / language / control. Read from the
    /// loaded model at load time so the boundary stays correct across
    /// checkpoint variants (en-only vs multilingual vs future vocabs).
    token_eot: u32,
    /// Backend label chosen by `resolve_device` (`"cpu"` or the compiled-in
    /// GPU label).
    device: &'static str,
}

// Marker impl so `WhisperResource` can be wrapped in a `ResourceArc`; the
// type is registered with the runtime in `on_load`.
impl rustler::Resource for WhisperResource {}

/// Cooperative cancellation flag. whisper.cpp polls the abort callback
/// between encoder/decoder steps and returns early when set.
///
/// Created by `nif_new_abort_handle`, set by `nif_abort_handle_signal`, and
/// read back by `nif_abort_handle_aborted`.
pub(crate) struct AbortHandle {
    /// Shared flag, initially `false`. `nif_transcribe` hands a clone of
    /// this `Arc` to `transcribe::transcribe_one`.
    pub(crate) flag: Arc<AtomicBool>,
}

// Marker impl so `AbortHandle` can be wrapped in a `ResourceArc`; the type
// is registered with the runtime in `on_load`.
impl rustler::Resource for AbortHandle {}

/// Options accepted by `nif_load_model`.
#[derive(NifMap)]
struct LoadOpts {
    /// Requested backend name, matched case-insensitively by
    /// `resolve_device`. `None` and `"auto"` pick the compiled-in GPU
    /// backend when there is one, otherwise the CPU.
    device: Option<String>,
}

/// Per-call options accepted by `nif_transcribe`.
///
/// `build_request` converts this into a `TranscribeRequest`. Unless noted
/// on the field, values are forwarded unchanged (including `None`), so the
/// meaning and defaults of those fields are decided in the `transcribe`
/// module.
#[derive(NifMap)]
struct TranscribeOpts {
    language: Option<String>,
    /// `None` is treated as `false` by `build_request`.
    translate: Option<bool>,
    initial_prompt: Option<String>,
    /// `None` is treated as `false` by `build_request`.
    word_timestamps: Option<bool>,
    beam_size: Option<u32>,
    best_of: Option<u32>,
    temperature: Option<f32>,
    n_threads: Option<u32>,
    n_max_text_ctx: Option<u32>,
    offset_ms: Option<u32>,
    duration_ms: Option<u32>,
    no_speech_thold: Option<f32>,
    logprob_thold: Option<f32>,
    suppress_blank: Option<bool>,
    suppress_non_speech_tokens: Option<bool>,
    single_segment: Option<bool>,
    /// `None` is treated as `false` by `build_request`.
    print_progress: Option<bool>,
}

/// Metadata map returned by `nif_model_info`; every value is copied from
/// the `WhisperResource` fields captured at load time.
#[derive(NifMap)]
struct ModelInfo {
    /// Expected input sample rate in Hz.
    sampling_rate: usize,
    /// Whether the loaded context reported itself as multilingual.
    multilingual: bool,
    /// Vocabulary size as stored on the resource.
    n_vocab: usize,
    /// Backend label the model was loaded with.
    device: String,
}

/// Map returned by `nif_available_devices`.
#[derive(NifMap)]
struct AvailableDevices {
    /// Always starts with `"cpu"`, followed by the compiled-in GPU label
    /// when `GPU_BACKEND` is `Some`.
    backends: Vec<String>,
    /// `true` exactly when `GPU_BACKEND` is `Some`.
    gpu_supported: bool,
}

/// NIF-encodable copy of `transcribe::WordResult` (see the `From` impl).
// NOTE(review): units of `start`/`end` are defined by the `transcribe`
// module and are not visible here — confirm before documenting them.
#[derive(NifMap)]
struct NifWord {
    text: String,
    start: f32,
    end: f32,
    probability: f32,
}

/// NIF-encodable copy of `transcribe::SegmentResult` (see the `From` impl).
// NOTE(review): units of `start`/`end` are defined by the `transcribe`
// module and are not visible here — confirm before documenting them.
#[derive(NifMap)]
struct NifSegment {
    text: String,
    start: f32,
    end: f32,
    no_speech_prob: f32,
    avg_logprob: f32,
    tokens: Vec<u32>,
    /// `None` when the source segment carries no word list; otherwise each
    /// `WordResult` converted to a `NifWord`.
    words: Option<Vec<NifWord>>,
}

/// NIF-encodable copy of `transcribe::TranscriptionResult`; this is the
/// success value of `nif_transcribe`.
#[derive(NifMap)]
struct NifTranscription {
    language: String,
    // Presumably seconds, going by the `_s` suffix — confirm in `transcribe`.
    duration_s: f32,
    /// Segments in the order produced by `transcribe`.
    segments: Vec<NifSegment>,
}

impl From<WordResult> for NifWord {
    fn from(w: WordResult) -> Self {
        Self {
            text: w.text,
            start: w.start,
            end: w.end,
            probability: w.probability,
        }
    }
}

impl From<SegmentResult> for NifSegment {
    fn from(s: SegmentResult) -> Self {
        Self {
            text: s.text,
            start: s.start,
            end: s.end,
            no_speech_prob: s.no_speech_prob,
            avg_logprob: s.avg_logprob,
            tokens: s.tokens,
            words: s
                .words
                .map(|ws| ws.into_iter().map(NifWord::from).collect()),
        }
    }
}

impl From<TranscriptionResult> for NifTranscription {
    fn from(t: TranscriptionResult) -> Self {
        Self {
            language: t.language,
            duration_s: t.duration_s,
            segments: t.segments.into_iter().map(NifSegment::from).collect(),
        }
    }
}

/// Runs `f`, turning an unwinding panic into a `nif_panic` error.
///
/// A normal return (`Ok` or `Err`) is passed through untouched. On panic
/// the error message is the panic payload when it is a `String` or a
/// `&str`, and `"unknown panic"` for any other payload type.
fn run_with_panic_protection<T, F>(f: F) -> Result<T, NativeError>
where
    F: FnOnce() -> Result<T, NativeError>,
{
    match catch_unwind(AssertUnwindSafe(f)) {
        Ok(outcome) => outcome,
        Err(payload) => {
            // Formatted panics carry a `String`; literal-only panics carry a
            // `&'static str`.
            let message = if let Some(text) = payload.downcast_ref::<String>() {
                text.as_str()
            } else if let Some(text) = payload.downcast_ref::<&str>() {
                *text
            } else {
                "unknown panic"
            };
            Err(NativeError::new("nif_panic", message))
        }
    }
}

fn encode_result<T: Encoder>(env: Env<'_>, result: Result<T, NativeError>) -> Term<'_> {
    match result {
        Ok(value) => (ok(), value).encode(env),
        Err(err) => (error(), err).encode(env),
    }
}

/// Maps a requested device name onto `(use_gpu, label)`.
///
/// The name is compared ASCII-case-insensitively:
/// - `None` / `"auto"`: the compiled-in GPU backend if there is one,
///   otherwise the CPU.
/// - `"cpu"`: always the CPU.
/// - the compiled-in GPU label: that backend.
///
/// # Errors
///
/// Any other name yields an `invalid_request` error whose details carry the
/// lower-cased `requested` name and the `enabled` backend (`"cpu"` for a
/// CPU-only build).
fn resolve_device(requested: Option<&str>) -> Result<(bool, &'static str), NativeError> {
    const CPU: (bool, &str) = (false, "cpu");

    let lowered = requested.map(str::to_ascii_lowercase);
    // Matching on the request and the compiled backend together lets every
    // arm bind the label directly, so no `unwrap()` is needed.
    match (lowered.as_deref(), GPU_BACKEND) {
        (None | Some("auto"), Some(label)) => Ok((true, label)),
        (None | Some("auto"), None) | (Some("cpu"), _) => Ok(CPU),
        (Some(name), Some(label)) if name == label => Ok((true, label)),
        (Some(name), enabled) => Err(NativeError::new(
            "invalid_request",
            "requested device backend is not enabled in this NIF artefact",
        )
        .with_detail("requested", name)
        .with_detail("enabled", enabled.unwrap_or("cpu"))),
    }
}

/// Lists the backends this NIF artefact can run on.
///
/// Returns `{:ok, %{backends, gpu_supported}}` where `backends` is `"cpu"`
/// followed by the compiled-in GPU label, if any.
#[rustler::nif]
fn nif_available_devices(env: Env<'_>) -> Term<'_> {
    let report = run_with_panic_protection(|| {
        let backends: Vec<String> = std::iter::once("cpu")
            .chain(GPU_BACKEND)
            .map(str::to_owned)
            .collect();
        Ok(AvailableDevices {
            gpu_supported: GPU_BACKEND.is_some(),
            backends,
        })
    });
    encode_result(env, report)
}

/// Loads a GGML / GGUF whisper.cpp model file and wraps it in a
/// `WhisperResource`.
///
/// Runs on a dirty CPU scheduler. Returns `{:ok, resource}` or an error of
/// type `invalid_request` (path is not a regular file, or the requested
/// device is not enabled) or `load_error` (whisper.cpp rejected the file).
#[rustler::nif(schedule = "DirtyCpu")]
#[allow(clippy::needless_pass_by_value)]
fn nif_load_model(env: Env<'_>, path: String, opts: LoadOpts) -> Term<'_> {
    let outcome = run_with_panic_protection(|| {
        let model_path = PathBuf::from(&path);
        if !model_path.is_file() {
            let not_a_file =
                NativeError::new("invalid_request", "model path is not a regular file")
                    .with_detail("path", path.clone());
            return Err(not_a_file);
        }

        let (use_gpu, device) = resolve_device(opts.device.as_deref())?;

        let mut params = WhisperContextParameters::default();
        params.use_gpu(use_gpu);

        let ctx = match WhisperContext::new_with_params(&model_path, params) {
            Ok(ctx) => ctx,
            Err(reason) => {
                let load_failure =
                    NativeError::new("load_error", "failed to load whisper.cpp model")
                        .with_detail("reason", reason.to_string())
                        .with_detail("path", path.clone())
                        .with_detail("device", device.to_owned());
                return Err(load_failure);
            }
        };

        // Metadata is read while `ctx` is still borrowable; the context
        // itself is moved into the mutex last.
        let resource = WhisperResource {
            // whisper.cpp's published checkpoints all run at 16 kHz; the C
            // API does not expose the rate.
            sampling_rate: 16_000,
            multilingual: ctx.is_multilingual(),
            n_vocab: usize::try_from(ctx.n_vocab()).unwrap_or(0),
            token_eot: u32::try_from(ctx.token_eot()).unwrap_or(u32::MAX),
            device,
            ctx: Mutex::new(ctx),
        };
        Ok(ResourceArc::new(resource))
    });

    encode_result(env, outcome)
}

/// Returns `{:ok, %{sampling_rate, multilingual, n_vocab, device}}` built
/// from the metadata stored on the resource at load time.
#[rustler::nif]
#[allow(clippy::needless_pass_by_value)]
fn nif_model_info(env: Env<'_>, model: ResourceArc<WhisperResource>) -> Term<'_> {
    let snapshot = run_with_panic_protection(|| {
        let resource: &WhisperResource = &model;
        let info = ModelInfo {
            sampling_rate: resource.sampling_rate,
            multilingual: resource.multilingual,
            n_vocab: resource.n_vocab,
            device: String::from(resource.device),
        };
        Ok(info)
    });
    encode_result(env, snapshot)
}

/// Decodes a byte buffer into `f32` samples, reading each consecutive group
/// of four bytes as one little-endian value.
///
/// # Errors
///
/// Returns `invalid_request` when the buffer is empty, or when its length
/// is not a multiple of four (the offending length is reported in the
/// `byte_length` detail).
fn decode_pcm_f32(bytes: &[u8]) -> Result<Vec<f32>, NativeError> {
    match bytes.len() {
        0 => {
            return Err(NativeError::new(
                "invalid_request",
                "samples binary is empty",
            ));
        }
        len if !len.is_multiple_of(4) => {
            let misaligned = NativeError::new(
                "invalid_request",
                "samples binary length must be a multiple of 4 (f32)",
            )
            .with_detail("byte_length", len.to_string());
            return Err(misaligned);
        }
        _ => {}
    }

    let mut samples = Vec::with_capacity(bytes.len() / 4);
    for quad in bytes.chunks_exact(4) {
        samples.push(f32::from_le_bytes([quad[0], quad[1], quad[2], quad[3]]));
    }
    Ok(samples)
}

fn build_request(opts: TranscribeOpts) -> TranscribeRequest {
    TranscribeRequest {
        language: opts.language,
        translate: opts.translate.unwrap_or(false),
        initial_prompt: opts.initial_prompt,
        word_timestamps: opts.word_timestamps.unwrap_or(false),
        beam_size: opts.beam_size,
        best_of: opts.best_of,
        temperature: opts.temperature,
        n_threads: opts.n_threads,
        n_max_text_ctx: opts.n_max_text_ctx,
        offset_ms: opts.offset_ms,
        duration_ms: opts.duration_ms,
        no_speech_thold: opts.no_speech_thold,
        logprob_thold: opts.logprob_thold,
        suppress_blank: opts.suppress_blank,
        suppress_non_speech_tokens: opts.suppress_non_speech_tokens,
        single_segment: opts.single_segment,
        print_progress: opts.print_progress.unwrap_or(false),
    }
}

/// Transcribes a single PCM buffer. The buffer may be longer than the
/// 30 s Whisper window; whisper.cpp chunks internally. `abort` and
/// `progress_pid` are optional cooperative hooks.
#[rustler::nif(schedule = "DirtyCpu")]
#[allow(clippy::needless_pass_by_value)]
fn nif_transcribe<'a>(
    env: Env<'a>,
    model: ResourceArc<WhisperResource>,
    samples_bin: Binary,
    opts: TranscribeOpts,
    abort: Option<ResourceArc<AbortHandle>>,
    progress_pid: Option<LocalPid>,
) -> Term<'a> {
    let bytes = samples_bin.as_slice();
    let abort_flag = abort.map(|h| Arc::clone(&h.flag));
    let result = run_with_panic_protection(|| {
        let samples = decode_pcm_f32(bytes)?;
        let request = build_request(opts);
        let transcription = transcribe::transcribe_one(
            &model.ctx,
            &samples,
            &request,
            model.token_eot,
            abort_flag,
            progress_pid,
        )?;
        Ok(NifTranscription::from(transcription))
    });

    encode_result(env, result)
}

/// Creates a new cancellation handle whose flag starts out `false`.
#[rustler::nif]
fn nif_new_abort_handle() -> ResourceArc<AbortHandle> {
    let flag = Arc::new(AtomicBool::new(false));
    ResourceArc::new(AbortHandle { flag })
}

/// Signals an `AbortHandle`; in-flight transcribe calls observing this
/// flag will return early.
#[rustler::nif]
#[allow(clippy::needless_pass_by_value)]
fn nif_abort_handle_signal(handle: ResourceArc<AbortHandle>) -> rustler::Atom {
    handle.flag.store(true, Ordering::SeqCst);
    ok()
}

/// Reads the current state of an `AbortHandle` (`true` once signalled).
#[rustler::nif]
#[allow(clippy::needless_pass_by_value)]
fn nif_abort_handle_aborted(handle: ResourceArc<AbortHandle>) -> bool {
    handle.flag.load(Ordering::SeqCst)
}

/// Load callback: registers both resource types with the runtime.
///
/// Returns `false` if either registration fails; `AbortHandle` is not
/// attempted when `WhisperResource` fails.
fn on_load(env: Env<'_>, _info: Term<'_>) -> bool {
    if env.register::<WhisperResource>().is_err() {
        return false;
    }
    env.register::<AbortHandle>().is_ok()
}

// Declares the NIF library under the name `Elixir.WhisperCpp.Native` and
// wires `on_load` in as its load callback.
rustler::init!("Elixir.WhisperCpp.Native", load = on_load);

/// Unit tests for the pure helpers in this file. Nothing here loads a model
/// or calls into whisper.cpp.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn decode_pcm_f32_round_trips_samples() {
        let mut bytes = Vec::new();
        for v in [0.0_f32, 1.0, -1.0, 0.5, -0.25] {
            bytes.extend_from_slice(&v.to_le_bytes());
        }
        let decoded = decode_pcm_f32(&bytes).unwrap();
        assert_eq!(decoded, vec![0.0, 1.0, -1.0, 0.5, -0.25]);
    }

    #[test]
    fn decode_pcm_f32_rejects_empty_input() {
        let err = decode_pcm_f32(&[]).unwrap_err();
        assert_eq!(err.r#type, "invalid_request");
    }

    #[test]
    fn decode_pcm_f32_rejects_misaligned_length() {
        let err = decode_pcm_f32(&[1, 2, 3]).unwrap_err();
        assert_eq!(err.r#type, "invalid_request");
        assert_eq!(
            err.details.get("byte_length").map(String::as_str),
            Some("3")
        );
    }

    #[test]
    fn decode_pcm_f32_rejects_trailing_partial_sample() {
        // One whole sample followed by a single stray byte.
        let err = decode_pcm_f32(&[0, 0, 0, 0, 1]).unwrap_err();
        assert_eq!(err.r#type, "invalid_request");
        assert_eq!(
            err.details.get("byte_length").map(String::as_str),
            Some("5")
        );
    }

    #[test]
    fn resolve_device_auto_falls_back_to_cpu_without_gpu() {
        if GPU_BACKEND.is_none() {
            let (use_gpu, label) = resolve_device(None).unwrap();
            assert!(!use_gpu);
            assert_eq!(label, "cpu");

            let (use_gpu, label) = resolve_device(Some("auto")).unwrap();
            assert!(!use_gpu);
            assert_eq!(label, "cpu");
        }
    }

    #[test]
    fn resolve_device_auto_selects_compiled_backend() {
        // Counterpart of the CPU-only test, so GPU builds are covered too.
        if let Some(backend) = GPU_BACKEND {
            assert_eq!(resolve_device(None).unwrap(), (true, backend));
            assert_eq!(resolve_device(Some("auto")).unwrap(), (true, backend));
            assert_eq!(resolve_device(Some(backend)).unwrap(), (true, backend));
        }
    }

    #[test]
    fn resolve_device_cpu_works_in_any_build() {
        let (use_gpu, label) = resolve_device(Some("cpu")).unwrap();
        assert!(!use_gpu);
        assert_eq!(label, "cpu");
    }

    #[test]
    fn resolve_device_is_case_insensitive() {
        assert_eq!(resolve_device(Some("CPU")).unwrap(), (false, "cpu"));
        assert_eq!(
            resolve_device(Some("AuTo")).unwrap(),
            resolve_device(Some("auto")).unwrap()
        );
    }

    #[test]
    fn resolve_device_rejects_gpu_when_not_built_in() {
        if GPU_BACKEND.is_none() {
            assert!(resolve_device(Some("cuda")).is_err());
            assert!(resolve_device(Some("hipblas")).is_err());
        }
    }

    #[test]
    fn resolve_device_error_reports_requested_and_enabled() {
        // No build has a backend with this name, so this fails everywhere.
        let err = resolve_device(Some("Not-A-Backend")).unwrap_err();
        assert_eq!(err.r#type, "invalid_request");
        assert_eq!(
            err.details.get("requested").map(String::as_str),
            Some("not-a-backend")
        );
        assert_eq!(
            err.details.get("enabled").map(String::as_str),
            Some(GPU_BACKEND.unwrap_or("cpu"))
        );
    }

    #[test]
    fn run_with_panic_protection_passes_results_through() {
        let ok: Result<u8, NativeError> = run_with_panic_protection(|| Ok(5));
        assert_eq!(ok.unwrap(), 5);

        let err: Result<u8, NativeError> =
            run_with_panic_protection(|| Err(NativeError::new("invalid_request", "nope")));
        let err = err.unwrap_err();
        assert_eq!(err.r#type, "invalid_request");
        assert_eq!(err.message, "nope");
    }

    #[test]
    fn run_with_panic_protection_catches_string_panic() {
        let result: Result<(), _> = run_with_panic_protection(|| panic!("boom"));
        let err = result.unwrap_err();
        assert_eq!(err.r#type, "nif_panic");
        assert_eq!(err.message, "boom");
    }

    #[test]
    fn run_with_panic_protection_catches_formatted_panic() {
        // A formatted panic carries a `String` payload rather than `&str`.
        let code = 7;
        let result: Result<(), _> = run_with_panic_protection(|| panic!("boom {code}"));
        let err = result.unwrap_err();
        assert_eq!(err.r#type, "nif_panic");
        assert_eq!(err.message, "boom 7");
    }

    #[test]
    fn run_with_panic_protection_reports_unknown_payload() {
        let result: Result<(), _> =
            run_with_panic_protection(|| std::panic::panic_any(42_i32));
        let err = result.unwrap_err();
        assert_eq!(err.r#type, "nif_panic");
        assert_eq!(err.message, "unknown panic");
    }
}