Skip to main content

native/whisper_ct2_native/src/align.rs

//! Word-level alignment.
//!
//! After `sys::Whisper::generate` we have per-chunk token sequences. To get
//! word boundaries we feed the encoder output and the **text** tokens of
//! every chunk into `sys::Whisper::align` in one batched call, then group
//! each chunk's per-text-token frame indices into Whisper words by walking
//! BPE leading-space markers and isolated-punctuation boundaries.
//!
//! The raw DTW boundaries get the same post-processing faster-whisper
//! applies before returning per-word timings:
//!
//! 1. **Median-duration + sentence-end clamp.** Words that absorbed
//!    silence end up multi-second long; we cap them at `2 * median(word
//!    duration)` when they border a sentence-end mark (`.。!!??`).
//! 2. **`merge_punctuations`.** Standalone punctuation tokens
//!    (`,`, `.`, `"`, `(`, ...) get folded back into the adjacent real
//!    word, mirroring faster-whisper's `prepend_punctuations` /
//!    `append_punctuations` defaults. The merged-into word keeps its
//!    original (already-clamped) end, which is the punctuation token's
//!    frame — so the silence after `"so,"` doesn't get attributed to
//!    `"so,"`.
//! 3. **Subsegment snap.** Each generated `<|t_start|> text <|t_end|>`
//!    block carries its own timestamps; when the DTW first/last word
//!    drifts past those, we snap it back. Same rule faster-whisper uses
//!    to avoid first-word overshoot at chunk boundaries.

// Frame indices and token ids stay within u32; the pedantic cast lints
// don't catch any real bug here.
#![allow(
    clippy::cast_precision_loss,
    clippy::cast_possible_truncation,
    clippy::cast_sign_loss,
    clippy::cast_possible_wrap
)]

use anyhow::{Result, anyhow};
use ct2rs::sys::{StorageView, Whisper};
use ct2rs::tokenizers::hf;
use tokenizers::Tokenizer as InnerTokenizer;

use crate::tokens::SubSegment;

/// One word with absolute time span.
#[derive(Debug)]
pub(crate) struct Word {
    pub(crate) text: String,
    pub(crate) start: f32,
    pub(crate) end: f32,
    pub(crate) probability: f32,
}

pub(crate) const DEFAULT_MEDIAN_FILTER_WIDTH: i64 = 7;

/// Punctuation lists mirror faster-whisper's defaults
/// (`transcribe.py:286-287`). Single-char tokens whose stripped form is in
/// one of these strings are folded into the adjacent real word by
/// `merge_punctuations`.
const PREPEND_PUNCT: &str = "\"'\u{201C}\u{00BF}([{-";
const APPEND_PUNCT: &str = "\"'.\u{3002},\u{FF0C}!\u{FF01}?\u{FF1F}:\u{FF1A}\u{201D})]}\u{3001}";
/// Sentence-final marks used to decide which side of an over-long word
/// gets clamped.
const SENTENCE_END_MARKS: &[&str] = &[".", "\u{3002}", "!", "\u{FF01}", "?", "\u{FF1F}"];
/// Upper bound on the per-chunk median word duration (seconds).
/// faster-whisper caps the median at 0.7 s so a chunk full of long words
/// can't lift the clamp threshold high enough to be useless.
const MEDIAN_DURATION_CAP_S: f32 = 0.7;

/// Per-chunk input to `align_batch`.
pub(crate) struct ChunkAlignInput<'a> {
    pub(crate) sub_segments: &'a [SubSegment],
    pub(crate) chunk_offset_s: f32,
    pub(crate) num_frames: usize,
}

/// Runs one batched `align` call across every chunk in `encoder_output` and
/// returns post-processed per-sub-segment word lists.
///
/// `start_sequence` must mirror the SOT block used at `generate` time:
/// `[sot, no_timestamps]` for `*.en` checkpoints, `[sot, lang, transcribe,
/// no_timestamps]` for multilingual. `sys::Whisper::align` takes a single
/// `start_sequence` for the whole batch, so all chunks in `encoder_output`
/// must have been generated against the same one.
pub(crate) fn align_batch(
    whisper: &Whisper,
    tokenizer: &hf::Tokenizer,
    encoder_output: &StorageView<'_>,
    chunks: &[ChunkAlignInput<'_>],
    start_sequence: &[usize],
    seconds_per_frame: f32,
    median_filter_width: i64,
) -> Result<Vec<Vec<Vec<Word>>>> {
    if chunks.is_empty() {
        return Ok(Vec::new());
    }

    let mut text_tokens_per_chunk: Vec<Vec<usize>> = Vec::with_capacity(chunks.len());
    let mut spans_per_chunk: Vec<Vec<std::ops::Range<usize>>> = Vec::with_capacity(chunks.len());
    let mut num_frames: Vec<usize> = Vec::with_capacity(chunks.len());

    for chunk in chunks {
        let mut flat: Vec<usize> = Vec::new();
        let mut spans: Vec<std::ops::Range<usize>> = Vec::new();
        for seg in chunk.sub_segments {
            let start = flat.len();
            flat.extend(seg.text_token_ids.iter().map(|id| *id as usize));
            spans.push(start..flat.len());
        }
        text_tokens_per_chunk.push(flat);
        spans_per_chunk.push(spans);
        num_frames.push(chunk.num_frames);
    }

    if text_tokens_per_chunk.iter().all(Vec::is_empty) {
        return Ok(chunks
            .iter()
            .map(|c| (0..c.sub_segments.len()).map(|_| Vec::new()).collect())
            .collect());
    }

    let results = whisper
        .align(
            encoder_output,
            start_sequence,
            &text_tokens_per_chunk,
            &num_frames,
            median_filter_width,
        )
        .map_err(|e| anyhow!("Whisper::align failed: {e}"))?;
    if results.len() != chunks.len() {
        return Err(anyhow!(
            "expected {} alignment results, got {}",
            chunks.len(),
            results.len()
        ));
    }

    let inner: &InnerTokenizer = tokenizer;
    let mut out: Vec<Vec<Vec<Word>>> = Vec::with_capacity(chunks.len());

    for (chunk_idx, alignment) in results.into_iter().enumerate() {
        let flat = &text_tokens_per_chunk[chunk_idx];
        let spans = &spans_per_chunk[chunk_idx];
        let chunk = &chunks[chunk_idx];

        if flat.is_empty() {
            out.push((0..chunk.sub_segments.len()).map(|_| Vec::new()).collect());
            continue;
        }

        let token_frames = token_first_frames(&alignment.alignments, flat.len());
        let token_probs = &alignment.text_token_probs;

        // 1. Pre-words: split on leading-space OR isolated punctuation,
        //    matching faster-whisper's `split_tokens_on_spaces`. Each
        //    pre-word carries the subsegment it came from.
        let mut pre_words = group_pre_words(inner, flat, token_probs, &token_frames, spans)?;

        // 2. Median-driven sentence-end clamp on the still-pre-merge list.
        let max_duration_frames = compute_max_duration_frames(&pre_words, seconds_per_frame);
        clamp_sentence_boundaries(&mut pre_words, max_duration_frames);

        // 3. Fold prepend/append punctuation back into the adjacent word.
        merge_punctuations(&mut pre_words);

        // 4. Partition surviving words back to sub-segments, then snap each
        //    sub-segment's first/last word to the `<|t_..|>` timestamps.
        let mut per_sub: Vec<Vec<Word>> = (0..spans.len()).map(|_| Vec::new()).collect();
        for pw in pre_words.iter().filter(|w| !w.text.is_empty()) {
            let words_of_sub = &mut per_sub[pw.subsegment_idx];
            words_of_sub.push(materialise_word(
                pw,
                chunk.chunk_offset_s,
                seconds_per_frame,
            ));
        }
        for (sub_idx, sub) in chunk.sub_segments.iter().enumerate() {
            snap_subsegment_boundaries(
                &mut per_sub[sub_idx],
                chunk.chunk_offset_s + sub.start_in_chunk,
                chunk.chunk_offset_s + sub.end_in_chunk,
                MEDIAN_DURATION_CAP_S.min(max_duration_frames as f32 * seconds_per_frame / 2.0),
            );
        }

        out.push(per_sub);
    }
    Ok(out)
}

/// One word in its pre-merge form. Frames are encoder frames within the
/// chunk; conversion to absolute seconds happens in `materialise_word`.
struct PreWord {
    text: String,
    tokens: Vec<u32>,
    probs: Vec<f32>,
    start_frame: usize,
    /// First frame of the next pre-word (or `last_token_frame + 1` for the
    /// chunk's last pre-word). Becomes the word's `end` after conversion.
    end_frame: usize,
    subsegment_idx: usize,
}

impl PreWord {
    fn duration_frames(&self) -> usize {
        self.end_frame.saturating_sub(self.start_frame)
    }
}

/// For each text-token index in `[0, n_tokens)`, returns the first encoder
/// frame the DTW path assigns to it. Tokens missing from the path inherit
/// the previous token's frame.
fn token_first_frames(
    alignments: &[ct2rs::sys::WhisperTokenAlignment],
    n_tokens: usize,
) -> Vec<usize> {
    let mut out = vec![usize::MAX; n_tokens];
    for align in alignments {
        let Ok(token_x) = usize::try_from(align.token_x) else {
            continue;
        };
        let frame_x = align.frame_x.max(0) as usize;
        if token_x < n_tokens && out[token_x] == usize::MAX {
            out[token_x] = frame_x;
        }
    }
    let mut last = 0usize;
    for slot in &mut out {
        if *slot == usize::MAX {
            *slot = last;
        } else {
            last = *slot;
        }
    }
    out
}

/// Groups subword tokens into pre-words. A new word starts on:
///
/// - a token whose surface form has a leading space (BPE `Ġ`/literal `' '`);
/// - an isolated punctuation token (single non-space character).
///
/// The second rule mirrors faster-whisper's `split_tokens_on_spaces`: it
/// lets the punctuation token surface as its own (very short) word that
/// then gets folded into the adjacent real word by `merge_punctuations`,
/// keeping the real word's end time at the punctuation token's frame
/// instead of the next word's.
fn group_pre_words(
    inner: &InnerTokenizer,
    token_ids: &[usize],
    token_probs: &[f32],
    token_frames: &[usize],
    spans: &[std::ops::Range<usize>],
) -> Result<Vec<PreWord>> {
    let mut words: Vec<PreWord> = Vec::new();

    for (sub_idx, span) in spans.iter().enumerate() {
        let mut current: Option<PreWord> = None;
        for idx in span.clone() {
            let id = token_ids[idx];
            let prob = token_probs[idx];
            let frame = token_frames[idx];
            let id_u32 = u32::try_from(id).map_err(|_| anyhow!("token id {id} exceeds u32"))?;
            let surface = inner
                .id_to_token(id_u32)
                .ok_or_else(|| anyhow!("token id {id} not in vocab"))?;
            let display = surface.replace('\u{0120}', " ");

            let starts_word = display.starts_with(' ') || is_isolated_punct(&display);

            match current.as_mut() {
                Some(w) if !starts_word => {
                    w.text.push_str(&display);
                    w.tokens.push(id_u32);
                    w.probs.push(prob);
                }
                _ => {
                    if let Some(prev) = current.take() {
                        words.push(prev);
                    }
                    current = Some(PreWord {
                        text: display,
                        tokens: vec![id_u32],
                        probs: vec![prob],
                        start_frame: frame,
                        end_frame: frame, // filled in below
                        subsegment_idx: sub_idx,
                    });
                }
            }
        }
        if let Some(prev) = current {
            words.push(prev);
        }
    }

    // Set each pre-word's end_frame to the next pre-word's start. The last
    // pre-word in the chunk has no next, so we extend it by one frame to
    // give the materialised word a non-zero duration.
    for i in 0..words.len() {
        let end = if i + 1 < words.len() {
            words[i + 1].start_frame
        } else {
            words[i].start_frame + 1
        };
        words[i].end_frame = end.max(words[i].start_frame);
    }

    Ok(words)
}

fn is_isolated_punct(s: &str) -> bool {
    let trimmed = s.trim();
    let mut chars = trimmed.chars();
    let Some(c) = chars.next() else {
        return false;
    };
    chars.next().is_none() && c.is_ascii_punctuation()
}

/// Caps any pre-word whose `[start, end]` exceeds `max_duration_frames`
/// and that borders (`words[i-1]` or `words[i]`) a sentence-end mark.
/// Mirrors `faster_whisper.transcribe._WhisperModel._add_word_timestamps`
/// lines 1607-1616.
fn clamp_sentence_boundaries(words: &mut [PreWord], max_duration_frames: usize) {
    if max_duration_frames == 0 {
        return;
    }
    let is_sentence_end = |s: &str| SENTENCE_END_MARKS.iter().any(|m| s.trim() == *m);
    for i in 1..words.len() {
        if words[i].duration_frames() > max_duration_frames {
            if is_sentence_end(&words[i].text) {
                words[i].end_frame = words[i].start_frame + max_duration_frames;
            } else if is_sentence_end(&words[i - 1].text) {
                words[i].start_frame = words[i].end_frame.saturating_sub(max_duration_frames);
            }
        }
    }
}

/// `2 * median(nonzero word durations)` in frames, capped at 0.7 s worth
/// of frames. Returns 0 when no word has a non-zero duration, which
/// disables the downstream clamp.
fn compute_max_duration_frames(words: &[PreWord], seconds_per_frame: f32) -> usize {
    let mut durations: Vec<usize> = words
        .iter()
        .map(PreWord::duration_frames)
        .filter(|d| *d > 0)
        .collect();
    if durations.is_empty() {
        return 0;
    }
    durations.sort_unstable();
    let median = durations[durations.len() / 2];
    let cap_frames = (MEDIAN_DURATION_CAP_S / seconds_per_frame).round() as usize;
    let capped = median.min(cap_frames);
    capped.saturating_mul(2)
}

/// Folds standalone prepend/append punctuation tokens back into the
/// adjacent real word. Adapted from faster-whisper's
/// `merge_punctuations` (`transcribe.py:1910-1941`). The merged-into
/// word's `start`/`end` are intentionally left alone — the original
/// frame boundaries become the timing of the combined word, which is
/// what gives `"so,"` an end at the comma's frame.
fn merge_punctuations(words: &mut Vec<PreWord>) {
    if words.len() < 2 {
        return;
    }

    // Prepend pass (right to left).
    let mut i = words.len() as isize - 2;
    let mut j: usize = words.len() - 1;
    while i >= 0 {
        let iu = i as usize;
        let candidate = is_single_char_member(&words[iu].text, PREPEND_PUNCT)
            && words[iu].text.starts_with(' ');
        if candidate {
            // Move prev's tokens/probs into following (j), inherit start.
            let prev_text = std::mem::take(&mut words[iu].text);
            let prev_tokens = std::mem::take(&mut words[iu].tokens);
            let prev_probs = std::mem::take(&mut words[iu].probs);
            let prev_start = words[iu].start_frame;

            let following = &mut words[j];
            following.text = prev_text + &following.text;
            following.tokens = [prev_tokens, std::mem::take(&mut following.tokens)].concat();
            following.probs = [prev_probs, std::mem::take(&mut following.probs)].concat();
            following.start_frame = prev_start;
        } else {
            j = iu;
        }
        i -= 1;
    }

    // Append pass (left to right).
    let mut i = 0usize;
    let mut j = 1usize;
    while j < words.len() {
        let mergeable = !words[i].text.is_empty()
            && !words[i].text.ends_with(' ')
            && is_single_char_member(&words[j].text, APPEND_PUNCT)
            && !words[j].text.starts_with(' ');
        if mergeable {
            let foll_text = std::mem::take(&mut words[j].text);
            let foll_tokens = std::mem::take(&mut words[j].tokens);
            let foll_probs = std::mem::take(&mut words[j].probs);

            let previous = &mut words[i];
            previous.text.push_str(&foll_text);
            previous.tokens.extend(foll_tokens);
            previous.probs.extend(foll_probs);
            // previous.end_frame is left alone: it already equals the
            // punctuation word's start_frame (== the punct token's frame).
        } else {
            i = j;
        }
        j += 1;
    }

    words.retain(|w| !w.text.is_empty());
}

fn is_single_char_member(s: &str, set: &str) -> bool {
    let trimmed = s.trim();
    let mut chars = trimmed.chars();
    let Some(c) = chars.next() else {
        return false;
    };
    if chars.next().is_some() {
        return false;
    }
    set.contains(c)
}

/// Aligns a sub-segment's first/last word with the `<|t_..|>` timestamps
/// the model emitted, when the raw DTW boundaries clearly overshoot.
/// Mirrors `_WhisperModel._add_word_timestamps` lines 1670-1692.
fn snap_subsegment_boundaries(
    words: &mut [Word],
    subseg_start_s: f32,
    subseg_end_s: f32,
    median_duration_s: f32,
) {
    if words.is_empty() {
        return;
    }
    let last_idx = words.len() - 1;

    // Snap word[0].start back when DTW placed it more than 0.5 s before
    // the segment's own timestamp.
    let first = &mut words[0];
    if subseg_start_s < first.end && subseg_start_s - 0.5 > first.start {
        first.start = 0.0_f32.max((first.end - median_duration_s).min(subseg_start_s));
    }
    // Snap word[-1].end forward when DTW placed it more than 0.5 s after
    // the segment's own timestamp.
    let last = &mut words[last_idx];
    if subseg_end_s > last.start && subseg_end_s + 0.5 < last.end {
        last.end = (last.start + median_duration_s).max(subseg_end_s);
    }
}

fn materialise_word(pre: &PreWord, chunk_offset_s: f32, seconds_per_frame: f32) -> Word {
    let start = chunk_offset_s + pre.start_frame as f32 * seconds_per_frame;
    let end = chunk_offset_s + pre.end_frame as f32 * seconds_per_frame;
    let probability = if pre.probs.is_empty() {
        0.0
    } else {
        pre.probs.iter().sum::<f32>() / pre.probs.len() as f32
    };
    Word {
        text: pre.text.trim_start().to_owned(),
        start,
        end: end.max(start),
        probability,
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    fn make_pre_word(start: usize, end: usize, text: &str) -> PreWord {
        PreWord {
            text: text.to_owned(),
            tokens: vec![1],
            probs: vec![1.0],
            start_frame: start,
            end_frame: end,
            subsegment_idx: 0,
        }
    }

    #[test]
    fn token_first_frames_fills_missing_with_previous() {
        // Two alignments touching tokens 0 and 2; token 1 has no DTW path
        // assignment and must inherit token 0's frame, not stay at usize::MAX
        // (which would translate to a giant `end` time downstream).
        let alignments = vec![
            ct2rs::sys::WhisperTokenAlignment {
                token_x: 0,
                frame_x: 5,
            },
            ct2rs::sys::WhisperTokenAlignment {
                token_x: 2,
                frame_x: 30,
            },
        ];
        assert_eq!(token_first_frames(&alignments, 4), vec![5, 5, 30, 30]);
    }

    #[test]
    fn token_first_frames_ignores_out_of_range_token_x() {
        // A token_x past `n_tokens` is library noise and must not panic
        // (the DTW backend can emit padding indices).
        let alignments = vec![ct2rs::sys::WhisperTokenAlignment {
            token_x: 100,
            frame_x: 7,
        }];
        assert_eq!(token_first_frames(&alignments, 3), vec![0, 0, 0]);
    }

    #[test]
    fn compute_max_duration_frames_returns_2x_median() {
        // Durations 10, 20, 30 frames; median = 20; 2x median = 40, well
        // under the 0.7 s cap at seconds_per_frame=0.02 (35 frames).
        // So the cap dominates: min(20, 35) * 2 = 40 .. wait — cap is
        // `min(median, cap_frames)`, then `* 2`. cap_frames at 0.02 s/frame
        // = 35. min(20, 35) = 20 → 40.
        let words = vec![
            make_pre_word(0, 10, "a"),
            make_pre_word(0, 20, "b"),
            make_pre_word(0, 30, "c"),
        ];
        assert_eq!(compute_max_duration_frames(&words, 0.02), 40);
    }

    #[test]
    fn compute_max_duration_frames_caps_at_median_duration_cap() {
        // Median 100 frames > 35-frame cap → uses cap, doubled to 70.
        let words = vec![
            make_pre_word(0, 80, "a"),
            make_pre_word(0, 100, "b"),
            make_pre_word(0, 120, "c"),
        ];
        assert_eq!(compute_max_duration_frames(&words, 0.02), 70);
    }

    #[test]
    fn compute_max_duration_frames_returns_zero_with_no_nonzero_durations() {
        // All-zero-duration input must yield 0 — that is the sentinel that
        // disables the sentence-end clamp downstream.
        let words = vec![make_pre_word(5, 5, "a"), make_pre_word(7, 7, "b")];
        assert_eq!(compute_max_duration_frames(&words, 0.02), 0);
    }

    #[test]
    fn clamp_sentence_boundaries_caps_word_after_sentence_end() {
        // `group_pre_words` splits standalone punctuation into its own
        // pre-word, so a sentence end here is literally `"."`. `is_sentence_end`
        // checks `text.trim() == mark` (not "ends with"). When the previous
        // word is the mark and the next word is too long, trim the next
        // word's start so silence after the period doesn't get absorbed.
        let mut words = vec![make_pre_word(0, 10, "."), make_pre_word(10, 200, "bar")];
        clamp_sentence_boundaries(&mut words, 40);
        assert_eq!(words[1].start_frame, 160);
        assert_eq!(words[1].end_frame, 200);
    }

    #[test]
    fn clamp_sentence_boundaries_caps_sentence_end_word_itself() {
        // `"foo" | "."` where `"."` is too long — `"."` is a sentence end,
        // so trim its end, not its start (its start is the previous word's
        // boundary).
        let mut words = vec![make_pre_word(0, 50, "foo"), make_pre_word(50, 250, ".")];
        clamp_sentence_boundaries(&mut words, 40);
        assert_eq!(words[1].start_frame, 50);
        assert_eq!(words[1].end_frame, 90);
    }

    #[test]
    fn clamp_sentence_boundaries_is_a_noop_with_zero_threshold() {
        // Threshold 0 means "no median was computable" → disable clamping
        // entirely (otherwise every word would be flagged as too long).
        let mut words = vec![make_pre_word(0, 10, "foo."), make_pre_word(10, 200, "bar")];
        clamp_sentence_boundaries(&mut words, 0);
        assert_eq!(words[1].start_frame, 10);
        assert_eq!(words[1].end_frame, 200);
    }

    #[test]
    fn is_isolated_punct_matches_single_ascii_punct() {
        assert!(is_isolated_punct(","));
        assert!(is_isolated_punct(" ."));
        assert!(!is_isolated_punct("hello"));
        assert!(!is_isolated_punct(",,"));
        assert!(!is_isolated_punct(""));
    }
}