Skip to main content

lib/whisper_ct2.ex

defmodule WhisperCt2 do
  @moduledoc """
  Native Elixir bindings for Whisper speech-to-text via CTranslate2.

  Calls `ct2rs::sys::Whisper` directly through a Rustler NIF: no Python, no
  HTTP gateway. The NIF owns the mel spectrogram, tokenizer, and prompt
  construction, so structured per-segment results, `:initial_prompt` /
  `:prefix` biasing, word-level timestamps, and batched transcribe across
  multiple audios are all first-class.

  ## Quickstart

      {:ok, model} = WhisperCt2.load_model("/path/to/faster-whisper-tiny")
      pcm = File.read!("audio.pcm")

      {:ok, %WhisperCt2.Transcription{text: text, segments: segs}} =
        WhisperCt2.transcribe(model, {:pcm_f32, pcm}, language: "en")

      IO.puts(text)
      for s <- segs, do: IO.puts("[\#{s.start}-\#{s.end}] \#{s.text}")

  ## Audio contract

  CTranslate2 expects **mono `f32` PCM samples** at the model's
  `:sampling_rate` (16 kHz for every published Whisper checkpoint),
  normalised to the `-1.0..1.0` range. `transcribe/3` and
  `transcribe_batch/3` accept exactly one shape:

  - `{:pcm_f32, binary}` - little-endian f32 samples at the model's
    sample rate.

  Decoding `.wav`, resampling, downmixing, and any other format
  conversion is the caller's job. There is no bundled audio decoder;
  use `ffmpeg`, a dedicated library, or your platform's audio stack
  before calling in.

  Audio longer than the Whisper 30 s window is chunked internally; the
  encoder runs once across every chunk in the batch. Diarization-driven
  workflows that need many short splices should use
  `transcribe_batch/3`.
  """

  alias WhisperCt2.{Error, Model, Native, Segment, Transcription, Word}

  # Raw PCM is the only accepted shape; `resolve_audio/1` turns anything else
  # into an `:invalid_request` error.
  @typedoc "Audio sources accepted by `transcribe/3` and `transcribe_batch/3`."
  @type audio :: {:pcm_f32, binary()}

  # Runtime validation for these options lives in `transcribe_validators/0`.
  # Keep the two in sync: a key listed here but missing there is rejected as
  # an unknown option. Note that the validators also accept integers for the
  # knobs typed as `float()` below.
  @typedoc "Options accepted by `transcribe/3` / `transcribe_batch/3`."
  @type transcribe_opt ::
          {:language, String.t() | nil}
          | {:initial_prompt, String.t() | nil}
          | {:prefix, String.t() | nil}
          | {:word_timestamps, boolean()}
          | {:with_timestamps, boolean()}
          | {:beam_size, pos_integer()}
          | {:patience, float()}
          | {:length_penalty, float()}
          | {:repetition_penalty, float()}
          | {:no_repeat_ngram_size, non_neg_integer()}
          | {:sampling_temperature, float()}
          | {:sampling_topk, pos_integer()}
          | {:suppress_blank, boolean()}
          | {:max_length, pos_integer()}
          | {:num_hypotheses, pos_integer()}
          | {:max_initial_timestamp_index, non_neg_integer()}
          | {:suppress_tokens, [integer()]}

  # Runtime validation for these options lives in `load_validators/0`; keep
  # the two in sync.
  @typedoc "Options accepted by `load_model/2`."
  @type load_opt ::
          {:device, :cpu | :cuda | :auto}
          | {:compute_type, Model.compute_type()}
          | {:device_indices, [non_neg_integer(), ...]}
          | {:num_threads_per_replica, non_neg_integer()}
          | {:max_queued_batches, integer()}
          | {:cpu_core_offset, integer()}

  # Closed sets of atoms accepted for the `:device` and `:compute_type` load
  # options. They back the membership checks in `load_validators/0` and are
  # the source of the `@device_atoms` / `@compute_type_atoms` lookup tables
  # used to decode the labels reported by the NIF.
  @devices [:cpu, :cuda, :auto]
  @compute_types [
    :default,
    :auto,
    :float32,
    :float16,
    :bfloat16,
    :int8,
    :int8_float32,
    :int8_float16,
    :int8_bfloat16,
    :int16
  ]

  @doc """
  Reports CTranslate2 device support for this build.

  Returns `{:ok, %{cpu: n, cuda: n, cuda_supported: bool}}` on success.
  `cuda_supported` reflects compile-time CUDA features (build with
  `WHISPER_CT2_FEATURES=cuda-dynamic mix compile` to enable). `cuda` is the
  count of NVIDIA GPU devices visible at runtime, or `0` when CUDA is not
  built in.

  When the NIF reports a failure, returns `{:error, error}` where `error`
  is the `t:WhisperCt2.Error.t/0` built from the native error payload.
  """
  @spec available_devices() ::
          {:ok,
           %{
             cpu: non_neg_integer(),
             cuda: non_neg_integer(),
             cuda_supported: boolean()
           }}
          | {:error, Error.t()}
  # Routed through `native_call/1` so the translation of NIF error payloads
  # lives in one place, shared with model loading.
  def available_devices, do: native_call(Native.available_devices())

  @doc """
  Loads a CTranslate2 Whisper model from the directory at `path`.

  `path` must be a non-blank string pointing at a converted model
  directory (for example the `faster-whisper-tiny` directory used in the
  module-doc Quickstart). Options are validated before the NIF is called;
  an unknown option or an invalid value yields an `:invalid_request`
  error.

  ## Options

  - `:device` - `:cpu`, `:cuda`, or `:auto` (default). `:auto` picks CUDA
    when the binary was built with CUDA support and a device is visible;
    otherwise CPU.
  - `:compute_type` - precision used at inference. `:default` keeps the
    model's stored quantisation; `:auto` picks the fastest supported on
    this device.
  - `:device_indices` - non-empty list of GPU indices (default `[0]`).
  - `:num_threads_per_replica` - intra-op threads. `0` lets CTranslate2 pick.
  - `:max_queued_batches`, `:cpu_core_offset` - passed through to
    CTranslate2.

  ## Returns

  `{:ok, %WhisperCt2.Model{}}` on success, `{:error, error}` otherwise.
  """
  @spec load_model(Path.t(), [load_opt()]) :: {:ok, Model.t()} | {:error, Error.t()}
  def load_model(path, opts \\ [])

  def load_model(path, opts) when is_binary(path) and is_list(opts) do
    # Validate on the Elixir side first so the NIF is only reached with a
    # non-blank path and known, well-typed options.
    with :ok <- validate_non_empty_string(path, :path),
         :ok <- validate_options(opts, load_validators()),
         do: do_load_model(path, opts)
  end

  # Anything that is not (binary, list) is reported rather than raised.
  def load_model(_path, _opts) do
    message = "path must be a string and opts a keyword list"
    {:error, Error.new(:invalid_request, message)}
  end

  # Loads the model through the NIF, then reads back what was actually
  # loaded via `Native.model_info/1` to populate the `%Model{}` struct.
  # The device and compute-type labels reported by the NIF are mapped to
  # the known atoms; an unrecognised label aborts with a `:runtime_error`.
  defp do_load_model(path, opts) do
    native_opts = build_load_opts(opts)

    with {:ok, handle} <- native_call(Native.load_model(path, native_opts)),
         {:ok, meta} <- native_call(Native.model_info(handle)),
         {:ok, device} <- decode_device(meta.device),
         {:ok, compute_type} <- decode_compute_type(meta.compute_type) do
      model = %Model{
        ref: handle,
        path: path,
        device: device,
        compute_type: compute_type,
        sampling_rate: meta.sampling_rate,
        n_samples: meta.n_samples,
        multilingual: meta.multilingual
      }

      {:ok, model}
    end
  end

  # Compile-time lookup tables from the string labels reported by the NIF
  # back to the atoms in `@devices` / `@compute_types`. Because the tables
  # are fixed, no atom is ever created from a NIF-supplied string.
  @device_atoms for device <- @devices, into: %{}, do: {Atom.to_string(device), device}
  @compute_type_atoms for type <- @compute_types, into: %{}, do: {Atom.to_string(type), type}

  # Maps a device label from the NIF to its atom, or returns a
  # `:runtime_error` carrying the unknown label.
  defp decode_device(label) when is_binary(label) do
    case @device_atoms do
      %{^label => device} ->
        {:ok, device}

      _unknown ->
        {:error, Error.new(:runtime_error, "NIF reported unknown device", %{device: label})}
    end
  end

  # Maps a compute-type label from the NIF to its atom, or returns a
  # `:runtime_error` carrying the unknown label.
  defp decode_compute_type(label) when is_binary(label) do
    case @compute_type_atoms do
      %{^label => compute_type} ->
        {:ok, compute_type}

      _unknown ->
        details = %{compute_type: label}
        {:error, Error.new(:runtime_error, "NIF reported unknown compute_type", details)}
    end
  end

  # Normalises a NIF result: error payloads are converted with
  # `Error.from_native/1`, successful results pass through untouched.
  defp native_call({:error, payload}), do: {:error, Error.from_native(payload)}
  defp native_call({:ok, _value} = success), do: success

  # Builds the option map handed to `Native.load_model/2`. Every key is
  # always present; options the caller did not set are `nil`. The two
  # atom-valued options are sent as strings.
  defp build_load_opts(opts) do
    passthrough =
      Map.new(
        [:device_indices, :num_threads_per_replica, :max_queued_batches, :cpu_core_offset],
        fn key -> {key, Keyword.get(opts, key)} end
      )

    Map.merge(passthrough, %{
      device: atom_to_string(Keyword.get(opts, :device)),
      compute_type: atom_to_string(Keyword.get(opts, :compute_type))
    })
  end

  # Stringifies an atom, keeping `nil` (itself an atom) as `nil` so that
  # unset options stay unset.
  defp atom_to_string(value) when is_atom(value) do
    if is_nil(value), do: nil, else: Atom.to_string(value)
  end

  @doc """
  Transcribes a single `audio` input with `model`.

  On success returns `{:ok, %WhisperCt2.Transcription{}}`. Each entry in
  `:segments` carries absolute start/end times, `no_speech_prob`,
  `avg_logprob`, the underlying text tokens, and (when `:word_timestamps`
  is set) per-word timing. `no_speech_prob` and `avg_logprob` are always
  populated.

  Options and audio are validated before the NIF is called; invalid input
  yields `{:error, error}` with an `:invalid_request` error.

  ## Options

  - `:language` - ISO code (`"en"`). `nil` (default) auto-detects.
  - `:initial_prompt` - free-text context prepended via `<|startofprev|>`
    to bias decoding.
  - `:prefix` - forced text the generation must start with.
  - `:word_timestamps` - when `true`, attaches `:words` to each segment
    via one extra batched DTW alignment pass. Default `false`.
  - `:with_timestamps` - when `true` (default) the prompt asks the model
    to emit `<|t_..|>` timestamp tokens that split the output into
    sub-segments. Set to `false` for fine-tunes that emit text without
    timestamps; the chunk's full text becomes one segment spanning
    `[0, chunk_duration_s)`. Implicitly forced to `true` whenever
    `:word_timestamps` is enabled because alignment needs the timestamp
    scaffolding.
  - Decoding knobs forwarded to CTranslate2: `:beam_size`, `:patience`,
    `:length_penalty`, `:repetition_penalty`, `:no_repeat_ngram_size`,
    `:sampling_temperature`, `:sampling_topk`, `:suppress_blank`,
    `:max_length`, `:num_hypotheses`, `:max_initial_timestamp_index`,
    `:suppress_tokens`.
  """
  @spec transcribe(Model.t(), audio(), [transcribe_opt()]) ::
          {:ok, Transcription.t()} | {:error, Error.t()}
  def transcribe(model, audio, opts \\ [])

  def transcribe(%Model{} = model, audio, opts) when is_list(opts) do
    # Options are checked before the audio, so a bad option is reported
    # even when the audio is also invalid.
    with :ok <- validate_options(opts, transcribe_validators()),
         {:ok, samples} <- resolve_audio(audio),
         do: do_transcribe(model, samples, opts)
  end

  # Wrong model type or non-list opts: report instead of raising.
  def transcribe(_model, _audio, _opts) do
    message = "expected a %WhisperCt2.Model{} and a keyword list"
    {:error, Error.new(:invalid_request, message)}
  end

  # Runs the NIF on already-validated samples and converts the result:
  # success payloads become a `%Transcription{}`, error payloads go
  # through `Error.from_native/1`.
  defp do_transcribe(%Model{ref: ref}, samples, opts) do
    native_opts = build_transcribe_opts(opts)

    ref
    |> Native.transcribe(samples, native_opts)
    |> case do
      {:error, reason} -> {:error, Error.from_native(reason)}
      {:ok, raw} -> {:ok, build_transcription(raw)}
    end
  end

  @doc """
  Transcribes a list of audios in one batched `generate` call. Every
  chunk of every input shares a single encoder forward pass; output
  preserves input order.

  Options are the same as `transcribe/3`. `:language` applies to every
  audio in the batch; pass `nil` to auto-detect per-audio.

  Options and every audio are validated before the NIF is called, and the
  first invalid item is returned as `{:error, error}`. An empty `audios`
  list returns `{:ok, []}` without calling the NIF, but its options are
  still validated so that a bad option is reported regardless of batch
  size.
  """
  @spec transcribe_batch(Model.t(), [audio()], [transcribe_opt()]) ::
          {:ok, [Transcription.t()]} | {:error, Error.t()}
  def transcribe_batch(model, audios, opts \\ [])

  def transcribe_batch(%Model{} = model, audios, opts)
      when is_list(audios) and is_list(opts) do
    with :ok <- validate_options(opts, transcribe_validators()),
         {:ok, samples_list} <- resolve_audios(audios) do
      # An empty batch is answered locally; there is nothing for the NIF
      # to do. This check runs after validation on purpose, so invalid
      # options are not silently accepted when the batch happens to be
      # empty.
      if samples_list == [] do
        {:ok, []}
      else
        do_transcribe_batch(model, samples_list, opts)
      end
    end
  end

  # Wrong model type, non-list audios, or non-list opts: report instead of
  # raising.
  def transcribe_batch(_model, _audios, _opts) do
    {:error,
     Error.new(
       :invalid_request,
       "expected a %WhisperCt2.Model{}, a list of audios, and a keyword list"
     )}
  end

  # Batched counterpart of `do_transcribe/3`: one NIF call for the whole
  # list, one `%Transcription{}` per returned payload, in the order the
  # NIF returned them.
  defp do_transcribe_batch(%Model{ref: ref}, samples_list, opts) do
    native_opts = build_transcribe_opts(opts)

    case Native.transcribe_batch(ref, samples_list, native_opts) do
      {:error, reason} ->
        {:error, Error.from_native(reason)}

      {:ok, raw_results} ->
        {:ok, Enum.map(raw_results, fn raw -> build_transcription(raw) end)}
    end
  end

  # Validates every audio in order, stopping at the first invalid one.
  # Returns `{:ok, samples}` with the binaries in input order, or the
  # first `{:error, _}` produced by `resolve_audio/1`.
  defp resolve_audios(audios) do
    audios
    |> Enum.reduce_while([], fn audio, collected ->
      case resolve_audio(audio) do
        {:ok, samples} -> {:cont, [samples | collected]}
        {:error, _} = failure -> {:halt, failure}
      end
    end)
    |> case do
      # The accumulator was built by prepending, so restore input order.
      collected when is_list(collected) -> {:ok, Enum.reverse(collected)}
      {:error, _} = failure -> failure
    end
  end

  # Checks a single audio input and unwraps its sample binary.
  #
  # Accepts only `{:pcm_f32, binary}` where the binary is non-empty and a
  # whole number of 4-byte f32 samples; everything else is reported as an
  # `:invalid_request` error.
  defp resolve_audio({:pcm_f32, <<>>}) do
    {:error, Error.new(:invalid_request, "PCM binary is empty")}
  end

  defp resolve_audio({:pcm_f32, samples})
       when is_binary(samples) and rem(byte_size(samples), 4) != 0 do
    details = %{byte_size: byte_size(samples)}

    {:error,
     Error.new(:invalid_request, "PCM binary length must be a multiple of 4 (f32)", details)}
  end

  defp resolve_audio({:pcm_f32, samples}) when is_binary(samples), do: {:ok, samples}

  defp resolve_audio(_other) do
    message = "audio must be {:pcm_f32, binary}; decode/resample upstream"
    {:error, Error.new(:invalid_request, message)}
  end

  # The three `build_*` functions destructure the NIF map strictly in the
  # function head, which makes the expected shape a static contract: when
  # the Elixir 1.18 type checker can prove that a call site passes a map
  # that cannot match the head, it reports a type warning at compile time.
  # They are exposed as `@doc false def` so the contract tests in
  # `nif_contract_test.exs` can pin which atom keys map to which struct
  # fields without a loaded model; they are not part of the public API.

  # Converts one NIF transcription payload into a `%Transcription{}`.
  # `:text` is the segment texts joined with single spaces and trimmed at
  # both ends.
  @doc false
  @spec build_transcription(map()) :: Transcription.t()
  def build_transcription(%{
        language: language,
        duration_s: duration_s,
        segments: raw_segments
      }) do
    segments = Enum.map(raw_segments, fn raw -> build_segment(raw) end)
    joined = segments |> Enum.map(fn segment -> segment.text end) |> Enum.join(" ")

    %Transcription{
      language: language,
      duration_s: duration_s,
      segments: segments,
      text: String.trim(joined)
    }
  end

  # Converts one NIF segment map into a `%Segment{}`. A falsy `:words`
  # value (such as `nil`) is kept as-is; otherwise each word map is
  # converted with `build_word/1`.
  @doc false
  @spec build_segment(map()) :: Segment.t()
  def build_segment(%{
        text: text,
        start: start,
        end: end_s,
        no_speech_prob: no_speech_prob,
        avg_logprob: avg_logprob,
        tokens: tokens,
        words: words
      }) do
    word_structs = if words, do: Enum.map(words, &build_word/1), else: words

    %Segment{
      text: text,
      start: start,
      end: end_s,
      tokens: tokens,
      words: word_structs,
      no_speech_prob: no_speech_prob,
      avg_logprob: avg_logprob
    }
  end

  # Converts one NIF word map into a `%Word{}`; the four keys are required
  # by the function head.
  @doc false
  @spec build_word(map()) :: Word.t()
  def build_word(%{text: text, start: start_s, end: end_s, probability: prob}) do
    %Word{
      text: text,
      start: start_s,
      end: end_s,
      probability: prob
    }
  end

  # Builds the option map handed to the transcribe NIFs. Every supported
  # key is always present; options the caller did not set are `nil`.
  defp build_transcribe_opts(opts) do
    keys = [
      :language,
      :initial_prompt,
      :prefix,
      :word_timestamps,
      :with_timestamps,
      :beam_size,
      :patience,
      :length_penalty,
      :repetition_penalty,
      :no_repeat_ngram_size,
      :sampling_temperature,
      :sampling_topk,
      :suppress_blank,
      :max_length,
      :num_hypotheses,
      :max_initial_timestamp_index,
      :suppress_tokens
    ]

    Map.new(keys, fn key -> {key, Keyword.get(opts, key)} end)
  end

  # Returns `:ok` unless `value` is empty or whitespace-only, in which case
  # an `:invalid_request` error naming the argument (`name`) is returned.
  @spec validate_non_empty_string(String.t(), atom()) :: :ok | {:error, Error.t()}
  defp validate_non_empty_string(value, name) do
    case String.trim(value) do
      "" -> {:error, Error.new(:invalid_request, "#{name} must be a non-empty string")}
      _trimmed -> :ok
    end
  end

  # Option name => predicate table for `load_model/2`, consumed by
  # `validate_options/2`. A key absent from this map is an unknown option.
  defp load_validators do
    %{
      cpu_core_offset: &is_integer/1,
      max_queued_batches: &is_integer/1,
      num_threads_per_replica: &non_neg_integer?/1,
      device_indices: &non_empty_list_of_non_neg_integers?/1,
      compute_type: fn value -> value in @compute_types end,
      device: fn value -> value in @devices end
    }
  end

  # Option name => predicate table for `transcribe/3` and
  # `transcribe_batch/3`, consumed by `validate_options/2`. A key absent
  # from the resulting map is an unknown option. Entries are grouped by
  # the kind of value they accept; the groups share no keys.
  defp transcribe_validators do
    text_opts = %{
      language: &valid_optional_string?/1,
      initial_prompt: &valid_optional_string?/1,
      prefix: &valid_optional_string?/1
    }

    flag_opts = %{
      word_timestamps: &is_boolean/1,
      with_timestamps: &is_boolean/1,
      suppress_blank: &is_boolean/1
    }

    integer_opts = %{
      beam_size: &positive_integer?/1,
      sampling_topk: &positive_integer?/1,
      max_length: &positive_integer?/1,
      num_hypotheses: &positive_integer?/1,
      no_repeat_ngram_size: &non_neg_integer?/1,
      max_initial_timestamp_index: &non_neg_integer?/1,
      suppress_tokens: &list_of_integers?/1
    }

    number_opts = %{
      # `patience` is faster-whisper's beam-search patience; values < 1.0
      # are documented as effectively disabling it.
      patience: &positive_number?/1,
      # CTranslate2 accepts any sign for `length_penalty`, including
      # negative values that bias toward shorter generations.
      length_penalty: &number?/1,
      # `repetition_penalty` < 1.0 amplifies repetition; documented values
      # are >= 1.0 (1.0 = neutral). Reject < 1.0 at the boundary.
      repetition_penalty: &repetition_penalty?/1,
      # Negative temperatures are nonsensical; 0.0 = greedy.
      sampling_temperature: &non_neg_number?/1
    }

    Enum.reduce([flag_opts, integer_opts, number_opts], text_opts, &Map.merge/2)
  end

  # Validates an option list against a `name => predicate` map.
  #
  # Walks `opts` in order and stops at the first problem. Returns `:ok`
  # when every entry is a `{key, value}` pair whose key is in `validators`
  # and whose value satisfies the predicate; otherwise returns an
  # `:invalid_request` error describing the first offending entry.
  #
  # `opts` is only known to be a list at this point (the public functions
  # guard with `is_list/1`), so entries that are not pairs must be handled
  # here rather than assumed away.
  @spec validate_options(list(), map()) :: :ok | {:error, Error.t()}
  defp validate_options(opts, validators) do
    Enum.reduce_while(opts, :ok, fn entry, :ok -> check_option(entry, validators) end)
  end

  # Checks one entry and returns the `Enum.reduce_while/3` instruction:
  # `{:cont, :ok}` to keep going, `{:halt, {:error, _}}` to stop.
  defp check_option({key, value}, validators) do
    case Map.fetch(validators, key) do
      :error ->
        {:halt, {:error, Error.new(:invalid_request, "unknown option #{inspect(key)}")}}

      {:ok, validator} ->
        if validator.(value) do
          {:cont, :ok}
        else
          {:halt,
           {:error,
            Error.new(
              :invalid_request,
              "invalid value for option #{inspect(key)}: #{inspect(value)}"
            )}}
        end
    end
  end

  # A list element that is not a `{key, value}` pair (e.g. `[:cuda]`) means
  # the caller did not pass a keyword list. Report it as a request error
  # instead of crashing with a FunctionClauseError.
  defp check_option(other, _validators) do
    {:halt,
     {:error,
      Error.new(
        :invalid_request,
        "options must be a keyword list, got element #{inspect(other)}"
      )}}
  end

  # True for `nil` (option unset) or a string with at least one
  # non-whitespace character; false for anything else.
  defp valid_optional_string?(value) do
    cond do
      is_nil(value) -> true
      is_binary(value) -> String.trim(value) != ""
      true -> false
    end
  end

  # Numeric option predicates. Each returns a strict boolean and is false
  # for any non-numeric term.

  # Integer strictly greater than zero.
  defp positive_integer?(v) when is_integer(v), do: v > 0
  defp positive_integer?(_), do: false

  # Integer greater than or equal to zero.
  defp non_neg_integer?(v) when is_integer(v), do: v >= 0
  defp non_neg_integer?(_), do: false

  # Any integer or float.
  defp number?(v), do: is_number(v)

  # Integer or float strictly greater than zero.
  defp positive_number?(v) when is_number(v), do: v > 0
  defp positive_number?(_), do: false

  # Integer or float greater than or equal to zero.
  defp non_neg_number?(v) when is_number(v), do: v >= 0
  defp non_neg_number?(_), do: false

  # Integer or float of at least 1 (1 = neutral).
  defp repetition_penalty?(v) when is_number(v), do: v >= 1
  defp repetition_penalty?(_), do: false

  # True for a list (possibly empty) whose elements are all integers.
  defp list_of_integers?(v), do: is_list(v) and Enum.all?(v, &is_integer/1)

  # True for a list with at least one element, all of them integers >= 0.
  defp non_empty_list_of_non_neg_integers?(v) do
    case v do
      [_ | _] -> Enum.all?(v, fn item -> is_integer(item) and item >= 0 end)
      _ -> false
    end
  end
end