defmodule Bumblebee.Audio.WhisperFeaturizer do
  alias Bumblebee.Shared

  import Nx.Defn

  # Option specification for this featurizer. Every entry pairs a `:default`
  # value with a `:doc` string. The list is consumed at compile time by
  # `Shared.option_defaults/1` (see `defstruct` below) to build the struct.
  options = [
    feature_size: [
      default: 80,
      doc: "the dimension of the extracted features. This corresponds to the number of Mel bins"
    ],
    sampling_rate: [
      default: 16_000,
      doc: "the sampling rate at which the audio files should be digitally expressed in Hertz"
    ],
    num_seconds: [
      default: 30,
      doc: """
      the maximum duration of the audio sequence. This implies that the maximum length of the
      input sequence is `:num_seconds` * `:sampling_rate`
      """
    ],
    hop_length: [
      default: 160,
      doc:
        "the hop between consecutive overlapping windows for the STFT used to obtain Mel Frequency coefficients"
    ],
    fft_length: [
      default: 400,
      doc: "the size of the fourier transform"
    ],
    padding_value: [
      default: 0.0,
      doc: "the value used to pad the audio. Should correspond to silence"
    ]
  ]

  @moduledoc """
  Whisper featurizer for audio data.

  Pads each 1-dimensional audio sample to a fixed length and converts it
  into Mel features computed from a short-time Fourier transform.

  ## Configuration

  #{Shared.options_doc(options)}
  """

  # The struct definition is derived from the `options` specification above
  # via `Shared.option_defaults/1`, so the option list is the single source
  # of truth for the struct.
  defstruct Shared.option_defaults(options)

  # Behaviours this module declares. The functions marked `@impl true`
  # below (`config/2` and `apply/3`) are its callback implementations.
  @behaviour Bumblebee.Featurizer
  @behaviour Bumblebee.Configurable

  # Returns `featurizer` updated with the configuration given in `opts`.
  #
  # `opts` is a keyword list (defaults to `[]`); the update itself is
  # delegated to `Shared.put_config_attrs/2`.
  @impl true
  def config(featurizer, opts \\ []) do
    Shared.put_config_attrs(featurizer, opts)
  end

  # Converts raw audio into the inputs map produced by this featurizer.
  #
  # `raw_samples` is a single tensor or a list of tensors (it goes through
  # `List.wrap/1`). Every sample must be rank 1, otherwise an
  # `ArgumentError` is raised. Each sample is brought to
  # `num_seconds * sampling_rate` elements with `Nx.pad/3`, run through the
  # jitted feature extraction and the per-sample results are stacked.
  #
  # `defn_options` is forwarded to `Nx.Defn.jit/2`.
  #
  # Returns `%{"input_features" => stacked_features}`.
  @impl true
  def apply(featurizer, raw_samples, defn_options) do
    max_length = featurizer.num_seconds * featurizer.sampling_rate

    # The jitted function and its options do not depend on the individual
    # sample, so they are built once rather than on every iteration.
    extract_fun = Nx.Defn.jit(&extract_fbank_features/2, defn_options)

    extract_opts = [
      fft_length: featurizer.fft_length,
      sampling_rate: featurizer.sampling_rate,
      mel_bins: featurizer.feature_size,
      hop_length: featurizer.hop_length
    ]

    transformed_samples =
      for sample <- List.wrap(raw_samples) do
        if Nx.rank(sample) != 1 do
          raise ArgumentError,
                "expected sample to be a 1-rank tensor, got: #{Nx.rank(sample)}-rank"
        end

        # NOTE(review): for a sample longer than `max_length` this is
        # negative; `Nx.pad/3` is documented to trim on negative padding,
        # which would truncate the sample — confirm this is intended.
        pad_size = max_length - Nx.axis_size(sample, 0)
        sample = Nx.pad(sample, featurizer.padding_value, [{0, pad_size, 0}])

        extract_fun.(sample, extract_opts)
      end

    %{"input_features" => Nx.stack(transformed_samples)}
  end

  # Computes Mel features for a single `waveform`.
  #
  # Required options: `:fft_length`, `:sampling_rate`, `:mel_bins` and
  # `:hop_length` (validated with `keyword!/2`).
  #
  # The waveform is transformed with `NxSignal.stft/3` using a periodic
  # Hann window of `:fft_length` points, and the resulting STFT is mapped
  # onto `:mel_bins` Mel bins with `NxSignal.stft_to_mel/3`.
  defnp extract_fbank_features(waveform, opts \\ []) do
    opts = keyword!(opts, [:fft_length, :sampling_rate, :mel_bins, :hop_length])

    window = NxSignal.Windows.hann(n: opts[:fft_length], is_periodic: true)

    # Consecutive windows start `:hop_length` samples apart, hence they
    # overlap by `fft_length - hop_length` samples.
    {stft, _, _} =
      NxSignal.stft(waveform, window,
        sampling_rate: opts[:sampling_rate],
        fft_length: opts[:fft_length],
        overlap_length: opts[:fft_length] - opts[:hop_length],
        window_padding: :reflect
      )

    # Magic numbers taken from the reference implementation. This yields
    # max_mel ~ 3016
    frequency_spacing = 200.0 / 3
    max_mel = frequency_spacing * 45.245640471924965

    NxSignal.stft_to_mel(stft, opts[:sampling_rate],
      fft_length: opts[:fft_length],
      mel_bins: opts[:mel_bins],
      max_mel: max_mel,
      mel_frequency_spacing: frequency_spacing
    )
  end

  defimpl Bumblebee.HuggingFace.Transformers.Config do
    # Loads featurizer configuration from `data`.
    #
    # Each entry maps a featurizer option to the key it is read from in
    # `data`, paired with a converter for the value (all numbers here).
    # Note that `:num_seconds` is read from "chunk_length" and `:fft_length`
    # from "n_fft". The converted options are applied through the target
    # module's `config/2` (`@for` is the module this implementation is for).
    def load(featurizer, data) do
      import Shared.Converters

      opts =
        convert!(data,
          feature_size: {"feature_size", number()},
          sampling_rate: {"sampling_rate", number()},
          hop_length: {"hop_length", number()},
          num_seconds: {"chunk_length", number()},
          fft_length: {"n_fft", number()},
          padding_value: {"padding_value", number()}
        )

      @for.config(featurizer, opts)
    end
  end