Skip to main content

native/whisper_ct2_native/src/transcribe.rs

//! Core transcription flow driving `ct2rs::sys::Whisper`. `transcribe_many`
//! stacks every chunk of every audio into one batched `encode`+`generate`
//! call so multi-chunk and multi-audio inputs share a single encoder pass;
//! `word_timestamps` adds one batched `align` on the same encoder output.

// `transcribe_many` is intentionally one long function — every step shares
// per-audio bookkeeping and splitting it costs clarity more than it saves.
#![allow(
    clippy::cast_precision_loss,
    clippy::cast_possible_truncation,
    clippy::cast_sign_loss,
    clippy::similar_names,
    clippy::too_many_lines
)]

use anyhow::{Context, Result, anyhow};
use ct2rs::sys::{Device, StorageView, Whisper, WhisperOptions};
use ct2rs::tokenizers::hf;

use crate::align::{ChunkAlignInput, DEFAULT_MEDIAN_FILTER_WIDTH, align_batch};
use crate::errors::{invalid_request, runtime_error};
use crate::preprocessor::Preprocessor;
use crate::tokens::{
    NO_TIMESTAMPS, PromptParts, SOT, STARTOFPREV, SpecialTokens, SubSegment, TRANSCRIBE,
    decode_ids, encode_plain, language_token, split_sub_segments, token_id,
};

/// Upper bound, in bytes, on the flattened mel feature buffer
/// (`total_chunks * n_mels * nb_max_frames` `f32` values) that a single
/// batched call may allocate. It exists so an oversized `transcribe_batch`
/// request is rejected up front instead of OOM-killing the BEAM.
///
/// The limit is 2 GiB (about 537 M `f32` elements). With the standard tiny
/// configuration (80 mel bins × 3000 frames × 4 bytes) one chunk occupies
/// 960 kB, so the cap leaves room for roughly 2200 chunks — around 18 hours
/// of audio in one call, far beyond any realistic batch.
const MAX_FEATURE_BUFFER_BYTES: usize = 2 * (1 << 30);

/// Request knobs for one transcribe call (applies to every audio in a batch).
pub(crate) struct TranscribeRequest {
    /// Language code (e.g. `"en"`) to pin for every audio. When `None`,
    /// `transcribe_many` uses English on English-only checkpoints and
    /// otherwise detects the language per audio from its first chunk.
    pub(crate) language: Option<String>,
    /// Ask for timestamped output. `transcribe_many` also enables
    /// timestamps in the prompt whenever `word_timestamps` is set.
    pub(crate) with_timestamps: bool,
    /// Optional prompt text. It is trimmed and tokenised with a single
    /// leading space; a blank or whitespace-only value is ignored.
    pub(crate) initial_prompt: Option<String>,
    /// Optional prefix text, handled the same way as `initial_prompt`
    /// (trimmed, leading space added, ignored when blank).
    pub(crate) prefix: Option<String>,
    /// When `true`, a batched alignment pass runs after generation and each
    /// returned segment carries `Some(words)`.
    pub(crate) word_timestamps: bool,
    /// Decoding options handed to `Whisper::generate`. `transcribe_many`
    /// works on a clone with `return_no_speech_prob` and `return_scores`
    /// forced to `true`, so the caller's values for those two are overridden.
    pub(crate) options: WhisperOptions,
}

/// One word with absolute time span.
///
/// Every field is copied verbatim from the `crate::align::Word` produced by
/// the alignment pass.
pub(crate) struct WordResult {
    /// The word's text as produced by alignment.
    pub(crate) text: String,
    /// Start of the word — presumably seconds from the start of the audio;
    /// confirm against `crate::align`.
    pub(crate) start: f32,
    /// End of the word, in the same unit as `start`.
    pub(crate) end: f32,
    /// Probability reported by alignment for this word.
    pub(crate) probability: f32,
}

/// One `<|t_start|> text <|t_end|>` segment.
pub(crate) struct SegmentResult {
    /// Decoded text of the segment's tokens with surrounding whitespace
    /// trimmed. Segments whose trimmed text is empty are never emitted.
    pub(crate) text: String,
    /// Segment start: the owning chunk's offset plus the start inside
    /// that chunk.
    pub(crate) start: f32,
    /// Segment end: the owning chunk's offset plus the end inside that chunk.
    pub(crate) end: f32,
    /// `no_speech_prob` of the chunk's generation result; every segment cut
    /// from the same chunk carries the same value.
    pub(crate) no_speech_prob: f32,
    /// First score of the chunk's generation result; shared by every
    /// segment cut from the same chunk.
    pub(crate) avg_logprob: f32,
    /// Text token ids of this segment (the same ids `text` was decoded from).
    pub(crate) tokens: Vec<u32>,
    /// Per-word timings; `Some` only when the request set `word_timestamps`.
    pub(crate) words: Option<Vec<WordResult>>,
}

/// Full transcription of one audio.
pub(crate) struct TranscriptionResult {
    /// Resolved language with the `<|` / `|>` token markers stripped
    /// (e.g. `<|en|>` becomes `en`).
    pub(crate) language: String,
    /// Audio length in seconds: sample count divided by the preprocessor's
    /// sampling rate.
    pub(crate) duration_s: f32,
    /// Segments in chunk order, then in sub-segment order within a chunk.
    pub(crate) segments: Vec<SegmentResult>,
}

/// Per-chunk state shared between the parse, align, and materialise
/// passes. Keeping all three fields together prevents the parallel-vec
/// indexing the original implementation had to do by hand.
///
/// `Default` is derived so the materialise pass can move a chunk's state out
/// of the vector with `std::mem::take`.
#[derive(Default)]
struct ChunkState {
    /// Sub-segments parsed from the chunk's generated token ids.
    sub_segments: Vec<SubSegment>,
    /// Start of this chunk within its audio, in seconds
    /// (chunk index × chunk duration).
    offset_s: f32,
    /// Valid frame count for this chunk, as computed by
    /// `encoder_frames_for_chunk` and handed to alignment.
    num_frames: usize,
}

/// Transcribes one audio by running it through [`transcribe_many`] as a
/// batch of size one and returning that batch's only result.
///
/// # Errors
///
/// Propagates every error from `transcribe_many`. Additionally returns a
/// `runtime_error` if the batched call unexpectedly comes back empty.
pub(crate) fn transcribe_one(
    whisper: &Whisper,
    tokenizer: &hf::Tokenizer,
    preprocessor: &Preprocessor,
    specials: &SpecialTokens,
    samples: &[f32],
    request: &TranscribeRequest,
) -> Result<TranscriptionResult> {
    let single_batch = [samples];
    let mut batch_results = transcribe_many(
        whisper,
        tokenizer,
        preprocessor,
        specials,
        &single_batch,
        request,
    )?
    .into_iter();

    match batch_results.next() {
        Some(only) => Ok(only),
        None => Err(runtime_error("transcribe_many returned no results")),
    }
}

/// Transcribes multiple audios in a single batched `generate` call.
/// Language detection (when not pinned) runs per-audio on the first chunk.
///
/// The function works in passes that all share the same flat chunk index
/// (`chunk_offsets[audio_idx] + within_audio_idx`):
///
/// 1. build mel chunks and the duration for every audio;
/// 2. resolve one language token per audio (pinned, forced English on
///    English-only checkpoints, or detected);
/// 3. size-check, validate and flatten all chunks into one feature buffer;
/// 4. build one prompt per chunk, encode once, and generate;
/// 5. parse each chunk's first hypothesis into sub-segments;
/// 6. optionally run one batched alignment for word timestamps;
/// 7. materialise one [`TranscriptionResult`] per audio.
///
/// Returns one result per entry of `audios`, in input order. An empty
/// `audios` slice yields an empty vector without touching the model.
///
/// # Errors
///
/// * `invalid_request` — an audio produced no mel chunks, the pinned
///   language is rejected for an English-only checkpoint or is not a known
///   language code, or the feature buffer exceeds
///   `MAX_FEATURE_BUFFER_BYTES` / overflows `usize`. The mixed-language
///   guard for word timestamps (via `build_align_start_sequence`) also
///   reports `invalid_request`.
/// * `runtime_error` — a mel chunk is not contiguous or contains NaN or
///   infinity.
/// * Other errors — failures from `build_chunks`, prompt tokenisation,
///   `StorageView::new`, `encode`, `generate`, `detect_language`,
///   `align_batch`, token decoding, or a generation/alignment result whose
///   shape does not match the batch.
pub(crate) fn transcribe_many(
    whisper: &Whisper,
    tokenizer: &hf::Tokenizer,
    preprocessor: &Preprocessor,
    specials: &SpecialTokens,
    audios: &[&[f32]],
    request: &TranscribeRequest,
) -> Result<Vec<TranscriptionResult>> {
    if audios.is_empty() {
        return Ok(Vec::new());
    }

    // Per-audio bookkeeping; all three vectors are indexed by `audio_idx`.
    let mut per_audio_chunks: Vec<Vec<ndarray::Array2<f32>>> = Vec::with_capacity(audios.len());
    let mut per_audio_languages: Vec<String> = Vec::with_capacity(audios.len());
    let mut per_audio_duration: Vec<f32> = Vec::with_capacity(audios.len());

    // Pass 1: mel chunks and duration for every audio. A single audio that
    // yields no chunks fails the whole batch.
    for samples in audios {
        let chunks = preprocessor
            .build_chunks(samples)
            .context("building mel chunks")?;
        if chunks.is_empty() {
            return Err(invalid_request("audio produced no mel chunks"));
        }
        let duration_s = samples.len() as f32 / preprocessor.sampling_rate as f32;
        per_audio_duration.push(duration_s);
        per_audio_chunks.push(chunks);
    }

    // Prefix sums over the chunk counts: `chunk_offsets[i]..chunk_offsets[i + 1]`
    // is the flat batch range belonging to audio `i`.
    let chunk_counts: Vec<usize> = per_audio_chunks.iter().map(Vec::len).collect();
    let chunk_offsets = compute_chunk_offsets(&chunk_counts);

    // Pass 2: one language token per audio.
    let multilingual = whisper.is_multilingual();
    if let Some(lang) = &request.language {
        reject_non_english_on_en_checkpoint(multilingual, lang)?;
        let token = language_token(tokenizer, lang)
            .map_err(|_| invalid_request(format!("invalid language code: {lang}")))?;
        per_audio_languages.extend(std::iter::repeat_n(token, audios.len()));
    } else if !multilingual {
        // English-only checkpoints (`*.en`) always transcribe English.
        // Skip `detect_language` and pin the result.
        per_audio_languages.extend(std::iter::repeat_n("<|en|>".to_owned(), audios.len()));
    } else {
        // `chunks[0]` cannot panic: empty chunk lists were rejected in pass 1.
        for chunks in &per_audio_chunks {
            let detected = detect_language(whisper, &chunks[0], preprocessor)?;
            per_audio_languages.push(detected);
        }
    }

    let total_chunks: usize = per_audio_chunks.iter().map(Vec::len).sum();
    let n_mels = preprocessor.feature_size;
    let chunk_length = preprocessor.nb_max_frames;

    // Guard against a pathological caller stacking enough audio that the
    // flat mel buffer would dwarf the BEAM's address space. The check
    // also catches usize overflow on the multiplication: if any of the
    // intermediate `checked_mul`s returns `None`, we treat the request
    // as oversized.
    let elements = total_chunks
        .checked_mul(n_mels)
        .and_then(|n| n.checked_mul(chunk_length));
    let bytes = elements.and_then(|n| n.checked_mul(std::mem::size_of::<f32>()));
    match bytes {
        Some(b) if b > MAX_FEATURE_BUFFER_BYTES => {
            return Err(invalid_request(format!(
                "batch feature buffer {b} bytes exceeds {MAX_FEATURE_BUFFER_BYTES} byte cap; \
                 split the input into smaller transcribe_batch calls"
            )));
        }
        None => {
            return Err(invalid_request(format!(
                "batch feature buffer size overflows usize \
                 (total_chunks={total_chunks}, n_mels={n_mels}, chunk_length={chunk_length})"
            )));
        }
        _ => {}
    }

    // Pass 3: copy every chunk, audio by audio, into one contiguous buffer.
    // `bytes` being `Some` above implies `elements` is `Some`, so the
    // `expect` cannot fire.
    let mut flat: Vec<f32> = Vec::with_capacity(elements.expect("checked above"));
    for chunks in &per_audio_chunks {
        for chunk in chunks {
            let slice = chunk
                .as_slice()
                .ok_or_else(|| runtime_error("mel chunk not contiguous"))?;
            // Reject non-finite features before they reach the encoder.
            if slice.iter().any(|v| !v.is_finite()) {
                return Err(runtime_error(
                    "mel features contain NaN or infinity; check for corrupted PCM input",
                ));
            }
            flat.extend_from_slice(slice);
        }
    }

    // View over `flat` with shape `[total_chunks, n_mels, chunk_length]`.
    let features = StorageView::new(
        &[total_chunks, n_mels, chunk_length],
        &mut flat,
        Device::CPU,
    )
    .map_err(|e| anyhow!("StorageView::new failed: {e}"))?;

    // faster-whisper prepends a space before tokenising both initial_prompt
    // and prefix so the first BPE token carries the leading-space marker.
    // Skipping it leaves the model wedged between a "continuation" subword
    // and the SOT block, which empties the output on certain prompts.
    let initial_prompt_tokens = match &request.initial_prompt {
        Some(text) if !text.trim().is_empty() => {
            encode_plain(tokenizer, &format!(" {}", text.trim()))?
        }
        _ => Vec::new(),
    };
    let prefix_tokens = match &request.prefix {
        Some(text) if !text.trim().is_empty() => {
            encode_plain(tokenizer, &format!(" {}", text.trim()))?
        }
        _ => Vec::new(),
    };

    // Word timestamps imply timestamped generation even when the caller
    // did not ask for `with_timestamps`.
    let emit_timestamps = request.with_timestamps || request.word_timestamps;

    // Pass 4: one prompt per chunk, pushed in flat batch order. Chunks of the
    // same audio share that audio's language token; everything else is the
    // same for the whole batch.
    let mut prompts: Vec<Vec<String>> = Vec::with_capacity(total_chunks);
    for (audio_idx, chunks) in per_audio_chunks.iter().enumerate() {
        let lang_token = &per_audio_languages[audio_idx];
        for _ in 0..chunks.len() {
            let parts = PromptParts {
                sot: SOT,
                startofprev: STARTOFPREV,
                language_token: lang_token,
                transcribe: TRANSCRIBE,
                no_timestamps: NO_TIMESTAMPS,
                initial_prompt: &initial_prompt_tokens,
                prefix: &prefix_tokens,
                with_timestamps: emit_timestamps,
                multilingual,
            };
            prompts.push(parts.build());
        }
    }
    // `generate` takes borrowed string slices; `prompts` keeps the owners alive.
    let prompt_refs: Vec<Vec<&str>> = prompts
        .iter()
        .map(|p| p.iter().map(String::as_str).collect())
        .collect();

    // Encode once; reuse for generate and (optionally) align.
    let encoder_output = whisper
        .encode(&features, false)
        .map_err(|e| anyhow!("Whisper::encode failed: {e}"))?;

    // The materialise pass reads `scores` and `no_speech_prob`, so both are
    // requested regardless of what the caller's options say.
    let mut opts = request.options.clone();
    opts.return_no_speech_prob = true;
    opts.return_scores = true;

    let generated = whisper
        .generate(&encoder_output, &prompt_refs, &opts)
        .map_err(|e| anyhow!("Whisper::generate failed: {e}"))?;
    // Everything below indexes `generated` by flat chunk index.
    if generated.len() != total_chunks {
        return Err(anyhow!(
            "expected {} generation results, got {}",
            total_chunks,
            generated.len()
        ));
    }

    let chunk_duration_s = preprocessor.n_samples as f32 / preprocessor.sampling_rate as f32;

    // Pass 5 (parse): turn each chunk's first hypothesis into sub-segments.
    // States are pushed in flat batch order, so `chunk_state[global_idx]`
    // lines up with `generated[global_idx]`.
    let mut chunk_state: Vec<ChunkState> = Vec::with_capacity(total_chunks);
    for (audio_idx, chunks) in per_audio_chunks.iter().enumerate() {
        for within_audio_idx in 0..chunks.len() {
            let chunk_offset_s = within_audio_idx as f32 * chunk_duration_s;
            let global_idx = chunk_offsets[audio_idx] + within_audio_idx;

            let token_ids = generated[global_idx]
                .sequences_ids
                .first()
                .ok_or_else(|| anyhow!("generation result missing first hypothesis"))?;
            let token_ids_u32: Vec<u32> = token_ids
                .iter()
                .map(|id| u32::try_from(*id).map_err(|_| anyhow!("token id {id} exceeds u32")))
                .collect::<Result<_>>()?;

            chunk_state.push(ChunkState {
                sub_segments: split_sub_segments(
                    &token_ids_u32,
                    specials.timestamp_begin,
                    chunk_duration_s,
                ),
                offset_s: chunk_offset_s,
                num_frames: encoder_frames_for_chunk(
                    audios[audio_idx].len(),
                    within_audio_idx,
                    preprocessor,
                ),
            });
        }
    }

    // Pass 6 (align): indexed `[global_idx][sub_idx]`; left empty when word
    // timestamps were not requested.
    let words_per_chunk: Vec<Vec<Vec<crate::align::Word>>> = if request.word_timestamps {
        // sys::Whisper::align takes one start_sequence for the whole batch,
        // so every audio in the batch must share the same SOT block we
        // used at generate time. For multilingual auto-detect this means
        // every audio's detected language must match.
        let align_start_sequence =
            build_align_start_sequence(tokenizer, specials, multilingual, &per_audio_languages)?;

        let align_inputs: Vec<ChunkAlignInput<'_>> = chunk_state
            .iter()
            .map(|c| ChunkAlignInput {
                sub_segments: &c.sub_segments,
                chunk_offset_s: c.offset_s,
                num_frames: c.num_frames,
            })
            .collect();
        align_batch(
            whisper,
            tokenizer,
            &encoder_output,
            &align_inputs,
            &align_start_sequence,
            preprocessor.seconds_per_encoder_frame(),
            DEFAULT_MEDIAN_FILTER_WIDTH,
        )?
    } else {
        Vec::new()
    };

    // Materialise per-audio results.
    let mut output: Vec<TranscriptionResult> = Vec::with_capacity(audios.len());
    for audio_idx in 0..audios.len() {
        let mut segments: Vec<SegmentResult> = Vec::new();
        let global_range = chunk_offsets[audio_idx]..chunk_offsets[audio_idx + 1];

        for global_idx in global_range {
            // Move the state out (leaving a default behind) so the
            // sub-segments can be consumed by value below.
            let chunk = std::mem::take(&mut chunk_state[global_idx]);
            let chunk_offset_s = chunk.offset_s;
            let subs = chunk.sub_segments;
            let result = &generated[global_idx];
            // return_scores is forced on in the request, so a missing score
            // is a real bug — not something to paper over with 0.0.
            let avg_logprob = *result.scores.first().ok_or_else(|| {
                anyhow!("ct2 generation result is missing scores despite return_scores=true")
            })?;

            // `sub_idx` counts every sub-segment, including ones skipped for
            // empty text, so it stays aligned with `words_per_chunk`.
            for (sub_idx, sub) in subs.into_iter().enumerate() {
                let text = decode_ids(tokenizer, &sub.text_token_ids)?
                    .trim()
                    .to_owned();
                if text.is_empty() {
                    continue;
                }
                let words = if request.word_timestamps {
                    let aligned_chunk = words_per_chunk.get(global_idx).ok_or_else(|| {
                        anyhow!(
                            "word_timestamps: alignment result missing for chunk {global_idx} \
                             (expected {total_chunks} chunks, got {})",
                            words_per_chunk.len()
                        )
                    })?;
                    let ws = aligned_chunk.get(sub_idx).ok_or_else(|| {
                        anyhow!(
                            "word_timestamps: alignment result missing for sub-segment \
                             {sub_idx} of chunk {global_idx} ({} sub-segments aligned)",
                            aligned_chunk.len()
                        )
                    })?;
                    Some(
                        ws.iter()
                            .map(|w| WordResult {
                                text: w.text.clone(),
                                start: w.start,
                                end: w.end,
                                probability: w.probability,
                            })
                            .collect::<Vec<_>>(),
                    )
                } else {
                    None
                };

                // Segment times are chunk-relative; add the chunk's offset
                // to place them on the audio's timeline.
                segments.push(SegmentResult {
                    text,
                    start: chunk_offset_s + sub.start_in_chunk,
                    end: chunk_offset_s + sub.end_in_chunk,
                    no_speech_prob: result.no_speech_prob,
                    avg_logprob,
                    tokens: sub.text_token_ids,
                    words,
                });
            }
        }

        // Report the bare code (`en`) rather than the token form (`<|en|>`).
        let lang_token = &per_audio_languages[audio_idx];
        let language = lang_token
            .trim_start_matches("<|")
            .trim_end_matches("|>")
            .to_owned();

        output.push(TranscriptionResult {
            language,
            duration_s: per_audio_duration[audio_idx],
            segments,
        });
    }

    Ok(output)
}

/// Runs `Whisper::detect_language` on a single mel chunk and returns the
/// language of the first candidate reported for it.
///
/// The chunk is copied into a private buffer because `StorageView::new`
/// needs mutable access to its backing storage.
///
/// # Errors
///
/// Fails when the chunk is not contiguous in memory, when the storage view
/// cannot be built, when detection itself fails, or when detection returns
/// no candidates.
///
/// NOTE(review): the non-contiguous case is reported with a bare `anyhow!`
/// here, while `transcribe_many` uses `runtime_error` for the same
/// condition — confirm whether the difference in error category is intended.
fn detect_language(
    whisper: &Whisper,
    chunk: &ndarray::Array2<f32>,
    preprocessor: &Preprocessor,
) -> Result<String> {
    let Some(mel) = chunk.as_slice() else {
        return Err(anyhow!("mel chunk not contiguous"));
    };
    let mut scratch = mel.to_vec();

    // Batch of one: `[1, n_mels, frames]`.
    let features = match StorageView::new(
        &[1, preprocessor.feature_size, preprocessor.nb_max_frames],
        &mut scratch,
        Device::CPU,
    ) {
        Ok(view) => view,
        Err(e) => return Err(anyhow!("StorageView::new for detect_language: {e}")),
    };

    let candidates = match whisper.detect_language(&features) {
        Ok(found) => found,
        Err(e) => return Err(anyhow!("Whisper::detect_language failed: {e}")),
    };

    // First batch entry, then its first candidate.
    let top_candidate = candidates
        .into_iter()
        .next()
        .and_then(|ranked| ranked.into_iter().next());
    match top_candidate {
        Some(best) => Ok(best.language),
        None => Err(anyhow!("detect_language returned no candidates")),
    }
}

/// Assembles the start-of-transcript token block passed to `Whisper::align`.
/// It mirrors the shape of the prompt given to `generate`, so word
/// boundaries are placed against the same context they were scored with.
///
/// * English-only (`*.en`) checkpoints: `[sot, no_timestamps]`.
/// * Multilingual checkpoints: `[sot, lang, transcribe, no_timestamps]`.
///
/// # Errors
///
/// On multilingual checkpoints the batch must resolve to one language,
/// because `sys::Whisper::align` accepts a single start sequence for the
/// whole batch; a mixed batch is rejected and the caller has to split the
/// work or pin `:language`. Also fails if the language token cannot be
/// mapped to an id.
fn build_align_start_sequence(
    tokenizer: &hf::Tokenizer,
    specials: &SpecialTokens,
    multilingual: bool,
    per_audio_languages: &[String],
) -> Result<Vec<usize>> {
    let sot = specials.sot as usize;
    let no_timestamps = specials.no_timestamps as usize;

    if multilingual {
        let shared_language = uniform_align_language(per_audio_languages)?;
        let language_id = token_id(tokenizer, shared_language)? as usize;
        let transcribe = specials.transcribe as usize;
        Ok(vec![sot, language_id, transcribe, no_timestamps])
    } else {
        Ok(vec![sot, no_timestamps])
    }
}

/// Rejects a pinned `:language` other than `"en"` on an English-only
/// (`*.en`) checkpoint.
///
/// Without this guard the mismatch would be silent: decoding on `*.en`
/// ignores the pinned token (its SOT block is just
/// `[<|startoftranscript|>]`) and produces English, yet the returned
/// `language` would still echo the pinned code and misroute any downstream
/// logic keyed on language.
///
/// Multilingual checkpoints accept any code here; validating the code
/// itself is left to the caller.
///
/// # Errors
///
/// Returns `invalid_request` when `multilingual` is `false` and `lang` is
/// not exactly `"en"`.
fn reject_non_english_on_en_checkpoint(multilingual: bool, lang: &str) -> Result<()> {
    if multilingual || lang == "en" {
        return Ok(());
    }
    Err(invalid_request(format!(
        "language {lang:?} requested on an English-only checkpoint; \
         only \"en\" is valid (or omit :language). Use a multilingual \
         checkpoint to transcribe other languages."
    )))
}

/// Returns the single language token shared by every audio in the batch.
///
/// This is the tokenizer-free half of `build_align_start_sequence`, kept
/// separate so the mixed-language guard can be unit-tested without loading
/// a tokenizer.
///
/// # Errors
///
/// * `runtime_error` when the slice is empty.
/// * `invalid_request` when at least one entry differs from the first.
fn uniform_align_language(per_audio_languages: &[String]) -> Result<&str> {
    let Some((head, rest)) = per_audio_languages.split_first() else {
        return Err(runtime_error("align: no audios in batch"));
    };

    if rest.iter().all(|lang| lang == head) {
        Ok(head.as_str())
    } else {
        Err(invalid_request(format!(
            "word_timestamps requires every audio in a batch to share the same \
             resolved language; got {per_audio_languages:?}. Pin :language or \
             split the batch."
        )))
    }
}

/// Computes the running totals of `chunk_counts`, starting from zero.
///
/// The result has `chunk_counts.len() + 1` entries, and
/// `offsets[i]..offsets[i + 1]` is the flat batch range owned by audio `i`,
/// so `(audio_idx, within_audio_idx)` maps to
/// `offsets[audio_idx] + within_audio_idx`. It lives outside
/// `transcribe_many` so the index arithmetic can be tested on its own.
fn compute_chunk_offsets(chunk_counts: &[usize]) -> Vec<usize> {
    let mut running = 0usize;
    std::iter::once(0)
        .chain(chunk_counts.iter().map(|&count| {
            running += count;
            running
        }))
        .collect()
}

/// Counts the valid frames in chunk `chunk_idx` of an audio that is
/// `samples_len` samples long, in the unit `sys::Whisper::align` expects.
///
/// The name says "encoder frames" (as does the API doc), but the value is
/// the mel-frame count (`samples / hop_length`, ~100 Hz). That is what
/// faster-whisper passes and what yields correct DTW output; the true
/// encoder-frame count (`mel / 2`, ~50 Hz) squeezes every word into the
/// first half of the clip. We follow faster-whisper.
///
/// The result is capped at `nb_max_frames` and is never below 1, including
/// for a chunk index that lies entirely past the end of the audio.
fn encoder_frames_for_chunk(
    samples_len: usize,
    chunk_idx: usize,
    preprocessor: &Preprocessor,
) -> usize {
    let samples_per_chunk = preprocessor.n_samples;
    let chunk_start = chunk_idx * samples_per_chunk;
    // Samples from this chunk's start to the end of the audio, limited to
    // one chunk's worth; zero when the chunk starts past the end.
    let samples_in_chunk = std::cmp::min(
        samples_len.saturating_sub(chunk_start),
        samples_per_chunk,
    );
    let mel_frames = samples_in_chunk / preprocessor.hop_length;
    let capped = std::cmp::min(mel_frames, preprocessor.nb_max_frames);
    std::cmp::max(capped, 1)
}

// Unit tests for the pure helpers in this file. None of them load a model
// or a tokenizer; `transcribe_many` itself is not exercised here.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::errors::kind_from_chain;
    use ndarray::Array2;

    /// Preprocessor with the standard tiny-model geometry (80 mel bins,
    /// 30 s chunks at 16 kHz, hop 160). The mel filter bank is all zeros
    /// because these tests only read the scalar fields.
    fn tiny_preprocessor() -> Preprocessor {
        Preprocessor {
            feature_size: 80,
            hop_length: 160,
            n_fft: 400,
            n_samples: 480_000,
            nb_max_frames: 3_000,
            sampling_rate: 16_000,
            mel_filters: Array2::<f64>::zeros((80, 201)),
        }
    }

    #[test]
    fn compute_chunk_offsets_is_a_prefix_sum() {
        assert_eq!(compute_chunk_offsets(&[]), vec![0]);
        assert_eq!(compute_chunk_offsets(&[3]), vec![0, 3]);
        // Three audios, each 2 / 5 / 1 chunks — offsets must let us
        // recover any (audio_idx, within_audio_idx) -> global_idx with
        // `offsets[audio_idx] + within_audio_idx`.
        let offsets = compute_chunk_offsets(&[2, 5, 1]);
        assert_eq!(offsets, vec![0, 2, 7, 8]);

        // Adjacent offsets must carve the flat index space into one
        // contiguous, non-overlapping range per audio.
        let ranges: Vec<_> = (0..3).map(|i| offsets[i]..offsets[i + 1]).collect();
        assert_eq!(ranges[0].clone().collect::<Vec<_>>(), vec![0, 1]);
        assert_eq!(ranges[1].clone().collect::<Vec<_>>(), vec![2, 3, 4, 5, 6]);
        assert_eq!(ranges[2].clone().collect::<Vec<_>>(), vec![7]);
    }

    #[test]
    fn encoder_frames_for_chunk_clamps_to_nb_max_frames() {
        let preprocessor = tiny_preprocessor();

        // Full 30 s chunk: samples / hop = 480_000 / 160 = 3000 frames,
        // matches `nb_max_frames` exactly. Clamp is a no-op.
        assert_eq!(encoder_frames_for_chunk(480_000, 0, &preprocessor), 3000);

        // Partial first chunk: half-second of audio = 8000 samples →
        // 8000 / 160 = 50 frames.
        assert_eq!(encoder_frames_for_chunk(8_000, 0, &preprocessor), 50);

        // Second chunk of 35 s audio: only 5 s remain → 80_000 / 160 = 500.
        assert_eq!(encoder_frames_for_chunk(560_000, 1, &preprocessor), 500);

        // Tail past the end must clamp to a non-zero minimum so `align`
        // never sees `num_frames = 0` (which the DTW path would divide by).
        assert_eq!(encoder_frames_for_chunk(480_000, 5, &preprocessor), 1);
    }

    #[test]
    fn uniform_align_language_passes_through_matching_batch() {
        let langs = vec!["<|en|>".to_owned(), "<|en|>".to_owned()];
        assert_eq!(uniform_align_language(&langs).unwrap(), "<|en|>");
    }

    #[test]
    fn uniform_align_language_rejects_mixed_languages_as_invalid_request() {
        // Mixed-language batch + word_timestamps is a caller bug, not an
        // inference failure. The error category must reflect that so
        // Elixir surfaces `:invalid_request`.
        let langs = vec!["<|en|>".to_owned(), "<|de|>".to_owned()];
        let err = uniform_align_language(&langs).unwrap_err();
        assert_eq!(kind_from_chain(&err), Some("invalid_request"));
        // The message must name the feature and list the offending tokens.
        let msg = format!("{err:#}");
        assert!(msg.contains("word_timestamps"), "got: {msg}");
        assert!(
            msg.contains("<|en|>") && msg.contains("<|de|>"),
            "got: {msg}"
        );
    }

    #[test]
    fn reject_non_english_on_en_checkpoint_allows_en() {
        assert!(reject_non_english_on_en_checkpoint(false, "en").is_ok());
    }

    #[test]
    fn reject_non_english_on_en_checkpoint_allows_anything_on_multilingual() {
        // Multilingual checkpoints decode the pinned language for real, so
        // the guard must not fire there.
        assert!(reject_non_english_on_en_checkpoint(true, "de").is_ok());
        assert!(reject_non_english_on_en_checkpoint(true, "fr").is_ok());
    }

    #[test]
    fn reject_non_english_on_en_checkpoint_rejects_non_en_as_invalid_request() {
        // Pinning :language to anything but "en" on an `.en` checkpoint
        // would otherwise be silently ignored by the prompt and then echoed
        // back in TranscriptionResult.language, misrouting downstream
        // language-based logic.
        let err = reject_non_english_on_en_checkpoint(false, "de").unwrap_err();
        assert_eq!(kind_from_chain(&err), Some("invalid_request"));
        let msg = format!("{err:#}");
        assert!(msg.contains("English-only"), "got: {msg}");
        assert!(msg.contains("\"de\""), "got: {msg}");
    }

    #[test]
    fn uniform_align_language_empty_batch_is_runtime_error() {
        // Reaching this guard with an empty per_audio_languages indicates
        // a NIF-internal bug (transcribe_many should already have early-
        // returned), so the category is runtime rather than invalid.
        let langs: Vec<String> = Vec::new();
        let err = uniform_align_language(&langs).unwrap_err();
        assert_eq!(kind_from_chain(&err), Some("runtime_error"));
    }
}