// native/whisper_ct2_native/src/preprocessor.rs

//! Whisper feature extractor: reads `preprocessor_config.json` and turns
//! raw `f32` PCM into the `[feature_size, nb_max_frames]` log-mel chunks
//! `CTranslate2` consumes. Numerically matches
//! `faster_whisper.FeatureExtractor` (see `build_chunks`).

#![allow(
    // Sizes and sample rates are `usize` but the DSP arithmetic is done in
    // `f32`/`f64`, so `usize as f32` / `usize as f64` casts are intentional.
    clippy::cast_precision_loss,
    // Mel energies are accumulated in `f64` and narrowed to `f32` on store.
    clippy::cast_possible_truncation,
    // NOTE(review): no sign-losing cast is visible in this file's current
    // code — confirm whether this allow is still required.
    clippy::cast_sign_loss,
    // The STFT and mel-projection loops index multiple arrays in lockstep;
    // an `enumerate` rewrite is uglier than the explicit index loop.
    clippy::needless_range_loop
)]

use std::fs::File;
use std::io::BufReader;
use std::path::Path;

use std::f32::consts::PI;
use std::sync::Arc;

use anyhow::{Context, Result, anyhow};
use mel_spec::mel::mel;
use ndarray::Array2;
use rustfft::num_complex::Complex32;
use rustfft::{Fft, FftPlanner};
use serde::Deserialize;

/// File name, relative to the model directory, that `Preprocessor::load`
/// reads the feature-extractor configuration from.
const PREPROCESSOR_CONFIG_FILE: &str = "preprocessor_config.json";

/// Parsed `preprocessor_config.json` plus the (possibly synthesised) mel
/// filterbank used to produce log-mel features.
///
/// The scalar fields are copied verbatim from the JSON file by
/// `Preprocessor::load`.
pub(crate) struct Preprocessor {
    /// Number of mel bins; the row count of every chunk produced by
    /// `build_chunks`.
    pub(crate) feature_size: usize,
    /// Stride, in samples, between the starts of consecutive STFT frames.
    pub(crate) hop_length: usize,
    /// FFT length; also the Hann window length and the number of samples
    /// in one STFT frame. Half of it is used as the reflect-pad width.
    pub(crate) n_fft: usize,
    /// Number of input samples per chunk; shorter trailing chunks are
    /// zero-padded up to this length.
    pub(crate) n_samples: usize,
    /// Number of STFT frames (columns) in every output chunk.
    pub(crate) nb_max_frames: usize,
    /// Input sample rate in samples per second. Used to synthesise the mel
    /// filterbank and to convert encoder frames to seconds.
    pub(crate) sampling_rate: usize,
    /// Mel filterbank weights indexed as `(mel_bin, fft_bin)`.
    /// `build_chunks` reads rows `0..feature_size` and columns
    /// `0..n_fft / 2 + 1`.
    pub(crate) mel_filters: Array2<f64>,
}

/// Raw serde mirror of `preprocessor_config.json`.
///
/// Field names match the JSON keys one-to-one. Only `mel_filters` is
/// optional: when it is absent `Preprocessor::load` synthesises a
/// filterbank instead of reading one.
#[derive(Deserialize)]
struct PreprocessorJson {
    // Scalar settings, copied unchanged into `Preprocessor`.
    feature_size: usize,
    hop_length: usize,
    n_fft: usize,
    n_samples: usize,
    nb_max_frames: usize,
    sampling_rate: usize,
    // Optional pre-computed filterbank as a list of rows; converted to an
    // `Array2<f64>` with one array row per inner `Vec`.
    mel_filters: Option<Vec<Vec<f64>>>,
}

impl Preprocessor {
    /// Reads `preprocessor_config.json` from `model_dir` and builds the
    /// feature extractor.
    ///
    /// A `mel_filters` matrix present in the JSON is used as-is; otherwise a
    /// filterbank is synthesised with `mel_spec::mel::mel` from
    /// `sampling_rate`, `n_fft` and `feature_size`.
    ///
    /// # Errors
    ///
    /// Fails when the file cannot be opened or parsed, when the rows of
    /// `mel_filters` do not all have the same length, or when the resulting
    /// configuration would make `build_chunks` unusable (zero sizes or a
    /// filterbank that is too small).
    pub(crate) fn load<P: AsRef<Path>>(model_dir: P) -> Result<Self> {
        let path = model_dir.as_ref().join(PREPROCESSOR_CONFIG_FILE);
        let file = File::open(&path).with_context(|| format!("opening {}", path.display()))?;
        let aux: PreprocessorJson = serde_json::from_reader(BufReader::new(file))
            .with_context(|| format!("parsing {}", path.display()))?;

        let mel_filters = if let Some(rows) = aux.mel_filters {
            let n_rows = rows.len();
            let n_cols = rows.first().map_or(0, Vec::len);
            // A ragged matrix whose element count happens to equal
            // `n_rows * n_cols` would otherwise be reshaped silently into a
            // scrambled filterbank, so check every row explicitly.
            if let Some(bad) = rows.iter().position(|row| row.len() != n_cols) {
                return Err(anyhow!(
                    "mel_filters in {} is ragged: row {} has {} columns, expected {}",
                    path.display(),
                    bad,
                    rows[bad].len(),
                    n_cols
                ));
            }
            Array2::from_shape_vec((n_rows, n_cols), rows.into_iter().flatten().collect())
                .with_context(|| format!("building mel filterbank from {}", path.display()))?
        } else {
            mel(
                aux.sampling_rate as f64,
                aux.n_fft,
                aux.feature_size,
                None,
                None,
                false,
                true,
            )
        };

        let preprocessor = Self {
            feature_size: aux.feature_size,
            hop_length: aux.hop_length,
            n_fft: aux.n_fft,
            n_samples: aux.n_samples,
            nb_max_frames: aux.nb_max_frames,
            sampling_rate: aux.sampling_rate,
            mel_filters,
        };
        preprocessor
            .validate()
            .with_context(|| format!("validating {}", path.display()))?;
        Ok(preprocessor)
    }

    /// Rejects configurations that would make `build_chunks` panic or loop
    /// without progress: zero `n_fft` / `hop_length` / `n_samples`, or a
    /// filterbank with fewer than `feature_size` rows or `n_fft / 2 + 1`
    /// columns. Larger filterbanks are accepted; only the leading
    /// sub-matrix is read.
    fn validate(&self) -> Result<()> {
        if self.n_fft == 0 || self.hop_length == 0 || self.n_samples == 0 {
            return Err(anyhow!(
                "n_fft ({}), hop_length ({}) and n_samples ({}) must all be non-zero",
                self.n_fft,
                self.hop_length,
                self.n_samples
            ));
        }
        let n_freq = self.n_fft / 2 + 1;
        let (rows, cols) = self.mel_filters.dim();
        if rows < self.feature_size || cols < n_freq {
            return Err(anyhow!(
                "mel filterbank is {}x{} but at least {}x{} \
                 (feature_size x n_fft/2+1) is required",
                rows,
                cols,
                self.feature_size,
                n_freq
            ));
        }
        Ok(())
    }

    /// Splits `samples` into `n_samples`-long windows (30 s with the stock
    /// Whisper config) and produces one `[feature_size, nb_max_frames]`
    /// log-mel array per window.
    ///
    /// The pipeline mirrors `faster_whisper.FeatureExtractor` (librosa-style
    /// STFT with `center=True` and reflect padding); parity is checked
    /// within tolerance by the golden fixture in
    /// `test/fixtures/mel_golden/`. Each chunk is zero-padded to
    /// `n_samples`, extended with `n_fft/2` reflected samples on each side,
    /// framed at `hop_length` with a Hann window, FFT'd, squared,
    /// mel-projected, and normalised with Whisper's
    /// `(max(log10(max(x,1e-10)), max_log-8) + 4)/4` curve.
    ///
    /// Chunks are processed independently — no STFT state leaks between
    /// windows.
    ///
    /// # Errors
    ///
    /// Returns an error when `samples` is empty or when the configuration
    /// fails `validate` (zero sizes, undersized filterbank).
    pub(crate) fn build_chunks(&self, samples: &[f32]) -> Result<Vec<Array2<f32>>> {
        if samples.is_empty() {
            return Err(anyhow!("samples buffer is empty"));
        }
        // Fields are `pub(crate)`, so instances may be built without `load`;
        // re-check here so bad values surface as `Err`, not as a panic.
        self.validate()?;

        let window = hann_window(self.n_fft);
        let fft = self.fft();
        let mut scratch = vec![Complex32::default(); fft.get_inplace_scratch_len()];
        let mut frame_buf = vec![Complex32::default(); self.n_fft];
        let pad = self.n_fft / 2;
        let n_freq = self.n_fft / 2 + 1;
        let body_end = pad + self.n_samples;

        // Reused across chunks / frames to avoid per-iteration allocation.
        let mut padded = vec![0.0_f32; self.n_samples + 2 * pad];
        let mut power = vec![0.0_f64; n_freq];

        let mut out = Vec::with_capacity(samples.len().div_ceil(self.n_samples));
        for chunk in samples.chunks(self.n_samples) {
            // Zero-pad the chunk to n_samples so every output gets the full
            // nb_max_frames; faster-whisper zero-pads the tail the same way.
            padded.fill(0.0);
            padded[pad..pad + chunk.len()].copy_from_slice(chunk);

            // Leading reflect pad, taken from the zero-padded body so that a
            // chunk shorter than `pad + 1` samples reflects zeros (as
            // `np.pad(..., mode='reflect')` on the padded array would)
            // instead of repeating its last real sample.
            for i in 0..pad {
                padded[i] = padded[2 * pad - i];
            }
            // Trailing reflect pad: mirror around the last body sample.
            for i in 0..pad {
                let src = body_end.saturating_sub(2 + i);
                padded[body_end + i] = padded[src];
            }

            let mut mel_chunk = Array2::<f32>::zeros((self.feature_size, self.nb_max_frames));
            for f in 0..self.nb_max_frames {
                let start = f * self.hop_length;
                let end = start + self.n_fft;
                if end > padded.len() {
                    // Not enough samples left for a full frame; the
                    // remaining columns stay at zero.
                    break;
                }
                for (slot, (&sample, &w)) in frame_buf
                    .iter_mut()
                    .zip(padded[start..end].iter().zip(&window))
                {
                    *slot = Complex32::new(sample * w, 0.0);
                }
                fft.process_with_scratch(&mut frame_buf, &mut scratch);

                // Power spectrum of the non-redundant bins, computed once
                // per frame rather than once per mel bin.
                for (p, bin) in power.iter_mut().zip(&frame_buf) {
                    *p = f64::from(bin.re) * f64::from(bin.re)
                        + f64::from(bin.im) * f64::from(bin.im);
                }

                for m in 0..self.feature_size {
                    let energy = self
                        .mel_filters
                        .row(m)
                        .iter()
                        .zip(&power)
                        .fold(0.0_f64, |acc, (weight, p)| acc + weight * p);
                    mel_chunk[(m, f)] = energy as f32;
                }
            }

            normalise_log_mel(&mut mel_chunk);
            out.push(mel_chunk);
        }

        Ok(out)
    }

    /// Plans a forward FFT of length `n_fft`.
    fn fft(&self) -> Arc<dyn Fft<f32>> {
        FftPlanner::new().plan_fft_forward(self.n_fft)
    }

    /// Seconds per encoder output frame. Whisper's encoder downsamples mel
    /// frames by 2 (stride-2 conv), so each encoder frame covers
    /// `2 * hop_length / sampling_rate` seconds (0.02 s at 16 kHz / 160 hop).
    pub(crate) fn seconds_per_encoder_frame(&self) -> f32 {
        (2.0 * self.hop_length as f32) / self.sampling_rate as f32
    }
}

/// Periodic Hann window of length `n`: `w[i] = 0.5 * (1 - cos(2*pi*i / n))`.
///
/// The denominator is `n`, not `n - 1`, so this is the periodic (DFT-even)
/// form: `w[0]` is zero but `w[n - 1]` is not. That differs from the
/// symmetric window `np.hanning(n)` returns. An `n` of zero yields an
/// empty vector.
fn hann_window(n: usize) -> Vec<f32> {
    let len = n as f32;
    let mut taps = Vec::with_capacity(n);
    for idx in 0..n {
        let phase = 2.0 * PI * idx as f32 / len;
        taps.push(0.5 * (1.0 - phase.cos()));
    }
    taps
}

/// Applies the faster-whisper / OpenAI Whisper log-mel normalisation in
/// place.
///
/// Every element becomes `log10(max(x, 1e-10))`, is clamped from below at
/// `max_log - 8` (where `max_log` is the largest log value in the array),
/// and is finally mapped through `(x + 4) / 4`.
fn normalise_log_mel(mel: &mut Array2<f32>) {
    // Pass 1: take the floored log10 of every cell and track the peak.
    let peak = mel.iter_mut().fold(f32::NEG_INFINITY, |peak, cell| {
        *cell = cell.max(1e-10).log10();
        if *cell > peak { *cell } else { peak }
    });

    // Pass 2: limit the dynamic range to 8 decades below the peak, then
    // shift and scale.
    let lower_bound = peak - 8.0;
    mel.mapv_inplace(|value| (value.max(lower_bound) + 4.0) / 4.0);
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;

    /// Workspace-relative path to the checked-in golden mel fixture.
    /// Generated by `tools/mel-reference/generate.py`. Pinned to the
    /// faster-whisper FeatureExtractor at fixture-generation time.
    fn golden_dir() -> PathBuf {
        PathBuf::from(env!("CARGO_MANIFEST_DIR"))
            .join("..")
            .join("..")
            .join("test")
            .join("fixtures")
            .join("mel_golden")
    }

    /// Reads a raw little-endian `f32` dump. Panics if the file is missing
    /// or its size is not a multiple of four bytes.
    fn read_f32_le(path: &std::path::Path) -> Vec<f32> {
        let bytes = std::fs::read(path).expect("fixture file");
        assert!(bytes.len().is_multiple_of(4), "fixture not 4-byte aligned");
        bytes
            .chunks_exact(4)
            .map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]]))
            .collect()
    }

    #[test]
    fn seconds_per_encoder_frame_matches_whisper_default() {
        let preprocessor = Preprocessor {
            feature_size: 80,
            hop_length: 160,
            n_fft: 400,
            n_samples: 480_000,
            nb_max_frames: 3_000,
            sampling_rate: 16_000,
            mel_filters: Array2::<f64>::zeros((80, 201)),
        };
        // 2 * 160 / 16_000 = 0.02 s. Hard-coded here to catch a
        // regression in either the formula or the assumed defaults.
        // The tolerance is `f32::EPSILON`: anything tighter is below one
        // f32 ULP at 0.02 and would silently be an exact-equality check.
        let dt = preprocessor.seconds_per_encoder_frame();
        assert!((dt - 0.02).abs() < f32::EPSILON, "got {dt}");
    }

    #[test]
    fn build_chunks_matches_faster_whisper_golden_mel() {
        let dir = golden_dir();
        let preprocessor = Preprocessor::load(&dir).expect("load preprocessor config");
        let samples = read_f32_le(&dir.join("input.pcm_f32"));
        let reference = read_f32_le(&dir.join("mel.f32"));

        let chunks = preprocessor
            .build_chunks(&samples)
            .expect("build_chunks succeeds");
        assert_eq!(chunks.len(), 1, "fixture is a single 30 s window");
        let mel = chunks.into_iter().next().unwrap();
        assert_eq!(
            mel.shape(),
            &[preprocessor.feature_size, preprocessor.nb_max_frames]
        );
        assert_eq!(reference.len(), mel.len());

        let mel_slice = mel.as_slice().expect("contiguous mel chunk");
        let feature_size = preprocessor.feature_size;
        let n_frames = preprocessor.nb_max_frames;

        // faster-whisper's librosa STFT uses `center=True` with reflect
        // padding, so the first and last ~`n_fft / (2 * hop_length)`
        // frames have no streaming-STFT equivalent and must be
        // excluded. At n_fft=400, hop=160 that's ~1 frame either side;
        // we trim 2 to be safe.
        let trim = 2_usize;
        assert!(
            n_frames > 2 * trim,
            "fixture has only {n_frames} frames; need more than {}",
            2 * trim
        );

        // Mel layout is `[feature_size, n_frames]` row-major, so the
        // stride along the frame axis is 1. `mel[bin * n_frames + f]`
        // gives bin `bin` at frame `f`.
        //
        // Per-frame RMS captures average bin error within one time slice.
        // A wrong filterbank shows up as systemic non-zero RMS across
        // every frame; an STFT framing difference shows up as a tight
        // band of bad frames at chunk boundaries with a clean middle.
        let mut per_frame_rms = vec![0.0_f64; n_frames];
        for bin in 0..feature_size {
            for f in 0..n_frames {
                let idx = bin * n_frames + f;
                let d = f64::from(mel_slice[idx] - reference[idx]);
                per_frame_rms[f] += d * d;
            }
        }
        for v in &mut per_frame_rms {
            *v = (*v / feature_size as f64).sqrt();
        }

        // Surface the per-frame diagnostic: the interior frame with the
        // largest RMS is reported in the failure messages below so a
        // failing run points at where the drift is.
        let (worst_frame, worst_rms) = per_frame_rms
            .iter()
            .copied()
            .enumerate()
            .skip(trim)
            .take(n_frames - 2 * trim)
            .fold((trim, 0.0_f64), |worst, candidate| {
                if candidate.1 > worst.1 {
                    candidate
                } else {
                    worst
                }
            });

        let mut overall_max = 0.0_f64;
        let mut overall_sumsq = 0.0_f64;
        let mut counted = 0_usize;
        for f in trim..(n_frames - trim) {
            for bin in 0..feature_size {
                let idx = bin * n_frames + f;
                let d = f64::from(mel_slice[idx] - reference[idx]).abs();
                if d > overall_max {
                    overall_max = d;
                }
                overall_sumsq += d * d;
                counted += 1;
            }
        }
        let overall_rms = (overall_sumsq / counted as f64).sqrt();

        // Tolerances: per-element max < 0.05 (mel values normalise to
        // ~[-1, 1] after `normalise_log_mel`, so 5 % is a hard fail), and
        // RMS across all interior frames < 0.005. A wrong filterbank or
        // power-vs-magnitude bug pushes both into the 0.1+ range.
        assert!(
            overall_max < 0.05,
            "mel max-abs delta {overall_max:.4} exceeds 0.05 — \
             likely filterbank or STFT scaling drift \
             (worst interior frame {worst_frame}, RMS {worst_rms:.6})"
        );
        assert!(
            overall_rms < 0.005,
            "mel RMS delta {overall_rms:.6} exceeds 0.005 — \
             likely a systemic preprocessor parity issue \
             (worst interior frame {worst_frame}, RMS {worst_rms:.6})"
        );
    }
}