lib/nx_signal.ex

defmodule NxSignal do
  @moduledoc """
  Nx library extension for digital signal processing.
  """

  import Nx.Defn

  @doc ~S"""
  Computes the Short-Time Fourier Transform of a tensor.

  Returns the complex spectrum Z, the time in seconds for
  each frame and the frequency bins in Hz.

  The time assigned to each frame is the center of its window,
  measured in the original (unpadded) input signal.

  The STFT is parameterized through:

    * $k$: length of the Discrete Fourier Transform (DFT)
    * $N$: length of each frame
    * $H$: hop (in samples) between frames (calculated as $H = N - \text{overlap\\_length}$)
    * $M$: number of frames
    * $x[n]$: the input time-domain signal
    * $w[n]$: the window function to be applied to each frame

  $$
  DFT(x, w) := \sum_{n=0}^{N - 1} x[n]w[n]e^\frac{-2 \pi i k n}{N} \\\\
  X[m, k] = DFT(x[mH..(mH + N - 1)], w)
  $$

  where $m$ assumes all values in the interval $[0, M - 1]$

  See also: `NxSignal.Windows`, `istft/3`, `stft_to_mel/3`

  ## Options

    * `:sampling_rate` - the sampling frequency $F_s$ for the input in Hz. Defaults to `1000`.
    * `:fft_length` - the DFT length that will be passed to `Nx.fft/2`. Defaults to `:power_of_two`.
    * `:overlap_length` - the number of samples for the overlap between frames.
      Defaults to half the window size.
    * `:window_padding` - `:valid` (default), `:same`, `:reflect` or a padding configuration
      such as `[{2, 2}]`. See `as_windowed/2` for more details.
    * `:scaling` - `nil`, `:spectrum` or `:psd`.
      * `:spectrum` - each frame is divided by $\sum_{i} window[i]$.
      * `nil` - No scaling is applied.
      * `:psd` - each frame is divided by $\sqrt{F\_s\sum_{i} window[i]^2}$.

  ## Examples

      iex> {z, t, f} = NxSignal.stft(Nx.iota({4}), NxSignal.Windows.rectangular(n: 2), overlap_length: 1, fft_length: 2, sampling_rate: 400)
      iex> z
      #Nx.Tensor<
        c64[frames: 3][frequencies: 2]
        [
          [1.0+0.0i, -1.0+0.0i],
          [3.0+0.0i, -1.0+0.0i],
          [5.0+0.0i, -1.0+0.0i]
        ]
      >
      iex> t
      #Nx.Tensor<
        f32[frames: 3]
        [0.0024999999441206455, 0.004999999888241291, 0.007499999832361937]
      >
      iex> f
      #Nx.Tensor<
        f32[frequencies: 2]
        [0.0, 200.0]
      >
  """
  @doc type: :time_frequency
  deftransform stft(data, window, opts \\ []) do
    {frame_length} = Nx.shape(window)

    opts =
      Keyword.validate!(opts, [
        :overlap_length,
        # not read by this function; kept so callers that pass it still pass validation
        :window,
        :scaling,
        window_padding: :valid,
        # must agree with the documented default and with `istft/3`, otherwise a
        # `:psd` round-trip that omits `:sampling_rate` would not invert
        sampling_rate: 1000,
        fft_length: :power_of_two
      ])

    sampling_rate = opts[:sampling_rate] || raise ArgumentError, "missing sampling_rate option"

    overlap_length = opts[:overlap_length] || div(frame_length, 2)

    stft_n(data, window, sampling_rate, Keyword.put(opts, :overlap_length, overlap_length))
  end

  defnp stft_n(data, window, sampling_rate, opts) do
    {frame_length} = Nx.shape(window)
    padding = opts[:window_padding]
    fft_length = opts[:fft_length]
    overlap_length = opts[:overlap_length]
    hop_length = frame_length - overlap_length

    spectrum =
      data
      |> as_windowed(
        padding: padding,
        window_length: frame_length,
        stride: hop_length
      )
      |> Nx.multiply(window)
      |> Nx.fft(length: fft_length)

    {num_frames, fft_length} = Nx.shape(spectrum)

    frequencies = fft_frequencies(sampling_rate, fft_length: fft_length)

    # Frame m starts at sample m * H of the padded signal, so the middle of its
    # window lies at m * H + N / 2 - left_padding in the original signal.
    center_offset = stft_frame_center_offset(padding, frame_length)

    times =
      (Nx.iota({num_frames}, names: [:frames]) * hop_length + center_offset) / sampling_rate

    output =
      case opts[:scaling] do
        :spectrum ->
          spectrum / Nx.sum(window)

        :psd ->
          spectrum / Nx.sqrt(sampling_rate * Nx.sum(window ** 2))

        nil ->
          spectrum

        scaling ->
          raise ArgumentError,
                "invalid :scaling, expected one of :spectrum, :psd or nil, got: #{inspect(scaling)}"
      end

    {Nx.reshape(output, spectrum.shape, names: [:frames, :frequencies]), times, frequencies}
  end

  # Position (in samples, relative to the start of the unpadded signal) of the middle
  # of the first frame. `as_windowed/2` prepends `left_padding` samples, so the first
  # window spans [-left_padding, N - left_padding).
  deftransformp stft_frame_center_offset(padding, frame_length) do
    left_padding =
      case padding do
        :reflect -> div(frame_length, 2)
        # mirrors the left padding chosen by `as_windowed_to_padding_config/3`
        :same -> div(frame_length - 1, 2)
        [{left, _right} | _] -> left
        _ -> 0
      end

    frame_length / 2 - left_padding
  end

  @doc """
  Computes the frequency bins for a FFT with given options.

  ## Arguments

    * `sampling_rate` - Sampling frequency in Hz.

  ## Options

    * `:fft_length` - Number of FFT frequency bins.
    * `:type` - Optional output type. Defaults to `{:f, 32}`
    * `:name` - Optional axis name for the tensor. Defaults to `:frequencies`
    * `:endpoint` - Whether the last bin is exactly `sampling_rate`. Defaults to `false`,
      in which case the bins are spaced by `sampling_rate / fft_length`.

  ## Examples

      iex> NxSignal.fft_frequencies(1.6e4, fft_length: 10)
      #Nx.Tensor<
        f32[frequencies: 10]
        [0.0, 1.6e3, 3.2e3, 4.8e3, 6.4e3, 8.0e3, 9.6e3, 1.12e4, 1.28e4, 1.44e4]
      >
  """
  @doc type: :time_frequency
  defn fft_frequencies(sampling_rate, opts \\ []) do
    opts = keyword!(opts, [:fft_length, type: {:f, 32}, name: :frequencies, endpoint: false])
    fft_length = opts[:fft_length]

    step = sampling_rate / fft_length

    Nx.linspace(0, step * fft_length,
      n: fft_length,
      type: opts[:type],
      name: opts[:name],
      endpoint: opts[:endpoint]
    )
  end

  @doc """
  Splits a 1D tensor into a `{K, N}` tensor of K windows of length N.

  ## Options

    * `:window_length` - the number of samples in a window. Required.
    * `:stride` - The number of samples to skip between windows. Defaults to `1`.
    * `:padding` - Padding mode, can be `:reflect` or a valid padding as per `Nx.pad/3` over the
      input tensor's shape (`:valid`, `:same` or a `[{low, high}]` configuration).
      Defaults to `:valid`. The padding is applied to the whole input rather than to
      individual windows. With `:reflect` the input is reflect-padded by half a window on
      each side, so the first window is centered at the start of the signal (`:same`
      centers it approximately). `:same` and explicit configurations pad with zeros, so
      windows that would otherwise be incomplete are effectively zero-padded.

  ## Examples

      iex> NxSignal.as_windowed(Nx.tensor([0, 1, 2, 3, 4, 10, 11, 12]), window_length: 4)
      #Nx.Tensor<
        s64[5][4]
        [
          [0, 1, 2, 3],
          [1, 2, 3, 4],
          [2, 3, 4, 10],
          [3, 4, 10, 11],
          [4, 10, 11, 12]
        ]
      >

      iex> NxSignal.as_windowed(Nx.tensor([0, 1, 2, 3, 4, 10, 11, 12]), window_length: 3)
      #Nx.Tensor<
        s64[6][3]
        [
          [0, 1, 2],
          [1, 2, 3],
          [2, 3, 4],
          [3, 4, 10],
          [4, 10, 11],
          [10, 11, 12]
        ]
      >

      iex> NxSignal.as_windowed(Nx.tensor([0, 1, 2, 3, 4, 10, 11]), window_length: 2, stride: 2, padding: [{0, 3}])
      #Nx.Tensor<
        s64[5][2]
        [
          [0, 1],
          [2, 3],
          [4, 10],
          [11, 0],
          [0, 0]
        ]
      >

      iex> t = Nx.iota({7});
      iex> NxSignal.as_windowed(t, window_length: 6, padding: :reflect, stride: 1)
      #Nx.Tensor<
        s64[8][6]
        [
          [3, 2, 1, 0, 1, 2],
          [2, 1, 0, 1, 2, 3],
          [1, 0, 1, 2, 3, 4],
          [0, 1, 2, 3, 4, 5],
          [1, 2, 3, 4, 5, 6],
          [2, 3, 4, 5, 6, 5],
          [3, 4, 5, 6, 5, 4],
          [4, 5, 6, 5, 4, 3]
        ]
      >

      iex> NxSignal.as_windowed(Nx.iota({10}), window_length: 6, padding: :reflect, stride: 2)
      #Nx.Tensor<
        s64[6][6]
        [
          [3, 2, 1, 0, 1, 2],
          [1, 0, 1, 2, 3, 4],
          [1, 2, 3, 4, 5, 6],
          [3, 4, 5, 6, 7, 8],
          [5, 6, 7, 8, 9, 8],
          [7, 8, 9, 8, 7, 6]
        ]
      >
  """
  @doc type: :windowing
  deftransform as_windowed(tensor, opts \\ []) do
    if opts[:padding] == :reflect do
      as_windowed_reflect_padding(tensor, opts)
    else
      as_windowed_non_reflect_padding(tensor, opts)
    end
  end

  deftransformp as_windowed_parse_reflect_opts(shape, opts) do
    window_length = as_windowed_fetch_window_length!(opts)
    half_window = div(window_length, 2)

    # reflect padding adds half a window on each side; reuse the non-reflect parser
    # with the equivalent explicit configuration to compute the output shape
    as_windowed_parse_non_reflect_opts(
      shape,
      Keyword.put(opts, :padding, [{half_window, half_window}])
    )
  end

  deftransformp as_windowed_parse_non_reflect_opts(shape, opts) do
    opts = Keyword.validate!(opts, [:window_length, padding: :valid, stride: 1])
    window_length = as_windowed_fetch_window_length!(opts)
    window_dimensions = {window_length}

    padding = opts[:padding]

    [stride] =
      strides =
      case opts[:stride] do
        stride when is_list(stride) ->
          stride

        stride when is_integer(stride) and stride >= 1 ->
          [stride]

        stride ->
          raise ArgumentError,
                "expected an integer >= 1 or a list of integers, got: #{inspect(stride)}"
      end

    padding_config = as_windowed_to_padding_config(shape, window_dimensions, padding)

    # trick so that we can get Nx to calculate the pooled shape for us
    %{shape: pooled_shape} =
      Nx.window_max(
        Nx.iota(shape, backend: Nx.Defn.Expr),
        window_dimensions,
        padding: padding,
        strides: strides
      )

    output_shape = {Tuple.product(pooled_shape), window_length}

    {window_length, stride, padding_config, output_shape}
  end

  # Fails early with a clear message instead of an obscure arithmetic/shape error
  # further down when `:window_length` is missing or not a positive integer.
  defp as_windowed_fetch_window_length!(opts) do
    case opts[:window_length] do
      window_length when is_integer(window_length) and window_length >= 1 ->
        window_length

      other ->
        raise ArgumentError,
              "expected :window_length to be an integer >= 1, got: #{inspect(other)}"
    end
  end

  # Translates the `:padding` option into an `Nx.pad/3` configuration
  # (a list of `{low, high, interior}` tuples, one per axis).
  defp as_windowed_to_padding_config(shape, kernel_size, mode) do
    case mode do
      :valid ->
        List.duplicate({0, 0, 0}, tuple_size(shape))

      :same ->
        Enum.zip_with(Tuple.to_list(shape), Tuple.to_list(kernel_size), fn dim, k ->
          padding_size = max(dim - 1 + k - dim, 0)
          {floor(padding_size / 2), ceil(padding_size / 2), 0}
        end)

      config when is_list(config) ->
        Enum.map(config, fn
          {x, y} when is_integer(x) and is_integer(y) ->
            {x, y, 0}

          _other ->
            raise ArgumentError,
                  "padding must be a list of {low, high} tuples, where each element is an integer. " <>
                    "Got: #{inspect(config)}"
        end)

      mode ->
        raise ArgumentError,
              "invalid padding mode specified, padding must be one" <>
                " of :valid, :same, or a padding configuration, got:" <>
                " #{inspect(mode)}"
    end
  end

  defnp as_windowed_non_reflect_padding(tensor, opts \\ []) do
    # current implementation only supports windowing 1D tensors
    {window_length, stride, padding, output_shape} =
      as_windowed_parse_non_reflect_opts(Nx.shape(tensor), opts)

    tensor = Nx.pad(tensor, 0, padding)

    as_windowed_apply(tensor, stride, output_shape, window_length)
  end

  defnp as_windowed_reflect_padding(tensor, opts \\ []) do
    # current implementation only supports windowing 1D tensors
    {window_length, stride, _padding, output_shape} =
      as_windowed_parse_reflect_opts(Nx.shape(tensor), opts)

    half_window = div(window_length, 2)
    tensor = Nx.reflect(tensor, padding_config: [{half_window, half_window}])

    as_windowed_apply(tensor, stride, output_shape, window_length)
  end

  # Copies `num_windows` slices of `window_length` samples, `stride` samples apart,
  # from the (already padded) tensor into the rows of the output.
  defnp as_windowed_apply(tensor, stride, output_shape, window_length) do
    output = Nx.broadcast(Nx.tensor(0, type: tensor.type), output_shape)
    {num_windows, _} = Nx.shape(output)

    [output, tensor] = Nx.broadcast_vectors([output, tensor])

    {output, _, _, _} =
      while {output, i = 0, current_window = 0, t = tensor}, current_window < num_windows do
        window = t |> Nx.slice([i], [window_length])
        updated = Nx.put_slice(output, [current_window, 0], Nx.new_axis(window, 0))
        {updated, i + stride, current_window + 1, t}
      end

    output
  end

  @doc """
  Generates weights for converting an STFT representation into MEL-scale.

  See also: `stft/3`, `istft/3`, `stft_to_mel/3`

  ## Arguments

    * `fft_length` - Number of FFT bins
    * `mel_bins` - Number of target MEL bins
    * `sampling_rate` - Sampling frequency in Hz

  ## Options
    * `:max_mel` - the pitch for the last MEL bin before log scaling. Defaults to 3016
    * `:mel_frequency_spacing` - the distance in Hz between two MEL bins before log scaling. Defaults to 66.6
    * `:type` - Target output type. Defaults to `{:f, 32}`

  ## Examples

      iex> NxSignal.mel_filters(10, 5, 8.0e3)
      #Nx.Tensor<
        f32[mels: 5][frequencies: 10]
        [
          [0.0, 8.129207999445498e-4, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
          [0.0, 9.972016559913754e-4, 2.1870288765057921e-4, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
          [0.0, 0.0, 9.510891977697611e-4, 4.150509194005281e-4, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
          [0.0, 0.0, 0.0, 4.035891906823963e-4, 5.276656011119485e-4, 2.574124082457274e-4, 0.0, 0.0, 0.0, 0.0],
          [0.0, 0.0, 0.0, 0.0, 7.329034269787371e-5, 2.342205698369071e-4, 3.8295105332508683e-4, 2.8712040511891246e-4, 1.9128978601656854e-4, 9.545915963826701e-5]
        ]
      >
  """
  @doc type: :time_frequency
  deftransform mel_filters(fft_length, mel_bins, sampling_rate, opts \\ []) do
    opts =
      Keyword.validate!(opts,
        max_mel: 3016,
        mel_frequency_spacing: 200 / 3,
        type: {:f, 32}
      )

    mel_filters_n(sampling_rate, opts[:max_mel], opts[:mel_frequency_spacing],
      type: opts[:type],
      fft_length: fft_length,
      mel_bins: mel_bins
    )
  end

  defnp mel_filters_n(sampling_rate, max_mel, f_sp, opts) do
    fft_length = opts[:fft_length]
    mel_bins = opts[:mel_bins]
    type = opts[:type]

    fftfreqs = fft_frequencies(sampling_rate, type: type, fft_length: fft_length)

    # mel_bins + 2 edges: each triangular filter needs a lower, center and upper edge
    mels = Nx.linspace(0, max_mel / f_sp, type: type, n: mel_bins + 2, name: :mels)
    freqs = f_sp * mels

    min_log_hz = 1_000
    min_log_mel = min_log_hz / f_sp

    # numpy uses the f64 value by default
    logstep = Nx.log(6.4) / 27

    log_t = mels >= min_log_mel

    # This is the same as freqs[log_t] = min_log_hz * Nx.exp(logstep * (mels[log_t] - min_log_mel))
    # notice that since freqs and mels are indexed by the same conditional tensor, we don't
    # need to slice either of them
    mel_f = Nx.select(log_t, min_log_hz * Nx.exp(logstep * (mels - min_log_mel)), freqs)

    fdiff = Nx.new_axis(mel_f[1..-1//1] - mel_f[0..-2//1], 1)
    ramps = Nx.new_axis(mel_f, 1) - fftfreqs

    lower = -ramps[0..(mel_bins - 1)] / fdiff[0..(mel_bins - 1)]
    upper = ramps[2..(mel_bins + 1)//1] / fdiff[1..mel_bins]
    weights = Nx.max(0, Nx.min(lower, upper))

    # area normalization of each triangle
    enorm = 2.0 / (mel_f[2..(mel_bins + 1)] - mel_f[0..(mel_bins - 1)])

    weights * Nx.new_axis(enorm, 1)
  end

  @doc """
  Converts a given STFT time-frequency spectrum into a MEL-scale time-frequency spectrum.

  See also: `stft/3`, `istft/3`, `mel_filters/4`

  ## Arguments

    * `z` - STFT spectrum, with named `:frames` and `:frequencies` axes (as returned by `stft/3`)
    * `sampling_rate` - Sampling frequency in Hz

  ## Options

    * `:fft_length` - Number of FFT bins. Defaults to the size of the `:frequencies` axis of `z`
    * `:mel_bins` - Number of target MEL bins. Defaults to 128
    * `:max_mel` - forwarded to `mel_filters/4`
    * `:mel_frequency_spacing` - forwarded to `mel_filters/4`
    * `:type` - Target output type. Defaults to `{:f, 32}`

  ## Examples

      iex> fft_length = 16
      iex> sampling_rate = 8.0e3
      iex> {z, _, _} = NxSignal.stft(Nx.iota({10}), NxSignal.Windows.hann(n: 4), overlap_length: 2, fft_length: fft_length, sampling_rate: sampling_rate, window_padding: :reflect)
      iex> Nx.axis_size(z, :frequencies)
      16
      iex> Nx.axis_size(z, :frames)
      6
      iex> NxSignal.stft_to_mel(z, sampling_rate, fft_length: fft_length, mel_bins: 4)
      #Nx.Tensor<
        f32[frames: 6][mel: 4]
        [
          [0.2900530695915222, 0.17422175407409668, 0.18422472476959229, 0.09807997941970825],
          [0.6093881130218506, 0.5647397041320801, 0.4353824257850647, 0.08635270595550537],
          [0.7584103345870972, 0.7085014581680298, 0.5636920928955078, 0.179118812084198],
          [0.8461772203445435, 0.7952491044998169, 0.6470762491226196, 0.2520409822463989],
          [0.908548891544342, 0.8572604656219482, 0.7078656554222107, 0.3086767792701721],
          [0.908548891544342, 0.8572604656219482, 0.7078656554222107, 0.3086767792701721]
        ]
      >
  """
  @doc type: :time_frequency
  defn stft_to_mel(z, sampling_rate, opts \\ []) do
    opts =
      keyword!(opts, [
        :fft_length,
        :max_mel,
        :mel_frequency_spacing,
        mel_bins: 128,
        type: {:f, 32}
      ])

    fft_length = stft_to_mel_fft_length(z, opts[:fft_length])

    magnitudes = Nx.abs(z) ** 2

    filters = mel_filters(fft_length, opts[:mel_bins], sampling_rate, mel_filters_opts(opts))

    # only the non-negative half of the spectrum is projected onto the mel filters
    freq_size = div(fft_length, 2)

    real_freqs_mag = Nx.slice_along_axis(magnitudes, 0, freq_size, axis: :frequencies)
    real_freqs_filters = Nx.slice_along_axis(filters, 0, freq_size, axis: :frequencies)

    mel_spec =
      Nx.dot(
        real_freqs_mag,
        [:frequencies],
        real_freqs_filters,
        [:frequencies]
      )

    mel_spec = Nx.reshape(mel_spec, Nx.shape(mel_spec), names: [:frames, :mel])

    log_spec = Nx.log(Nx.clip(mel_spec, 1.0e-10, :infinity)) / Nx.log(10)
    log_spec = Nx.max(log_spec, Nx.reduce_max(log_spec) - 8)
    (log_spec + 4) / 4
  end

  # Uses the explicit `:fft_length` when given, otherwise the spectrum's own size.
  deftransformp stft_to_mel_fft_length(z, fft_length) do
    fft_length || Nx.axis_size(z, :frequencies)
  end

  deftransformp mel_filters_opts(opts) do
    Keyword.take(opts, [:max_mel, :mel_frequency_spacing, :type])
  end

  @doc ~S"""
  Computes the Inverse Short-Time Fourier Transform of a tensor.

  Takes a spectrum such as the one returned by `stft/3` and returns the
  reconstructed time-domain signal. Each frame is inverse-transformed, trimmed
  back to the window length (undoing the zero-padding up to `fft_length`),
  multiplied by the window and overlap-added; the sum is then normalized by the
  overlap-added squared window.

  See also: `NxSignal.Windows`, `stft/3`

  ## Options

    * `:fft_length` - the DFT length that will be passed to `Nx.ifft/2`. Defaults to `:power_of_two`.
    * `:overlap_length` - the number of samples for the overlap between frames.
      Defaults to half the window size.
    * `:sampling_rate` - the sampling rate $F_s$ in Hz. Only used by the `:psd` scaling.
      Defaults to `1000`.
    * `:scaling` - `nil`, `:spectrum` or `:psd`.
      * `:spectrum` - each frame is multiplied by $\sum_{i} window[i]$.
      * `nil` - No scaling is applied.
      * `:psd` - each frame is multiplied by $\sqrt{F\_s\sum_{i} window[i]^2}$.

  ## Examples

  In general, `istft/3` takes in the same parameters and window as the `stft/3` that generated the spectrum.
  In the first example, we can notice that the reconstruction is mostly perfect, aside from the first sample.

  This is because the Hann window only ensures perfect reconstruction in overlapping regions, so the edges
  of the signal end up being distorted.

      iex> t = Nx.tensor([10, 10, 1, 0, 10, 10, 2, 20])
      iex> w = NxSignal.Windows.hann(n: 4)
      iex> opts = [sampling_rate: 1, fft_length: 4]
      iex> {z, _time, _freqs} = NxSignal.stft(t, w, opts)
      iex> result = NxSignal.istft(z, w, opts)
      iex> Nx.as_type(result, Nx.type(t))
      #Nx.Tensor<
        s64[8]
        [0, 10, 1, 0, 10, 10, 2, 20]
      >

  Different scaling options are available (see `stft/3` for a more detailed explanation).
  For perfect reconstruction, you want to use the same scaling as the STFT:

      iex> t = Nx.tensor([10, 10, 1, 0, 10, 10, 2, 20])
      iex> w = NxSignal.Windows.hann(n: 4)
      iex> opts = [scaling: :spectrum, sampling_rate: 1, fft_length: 4]
      iex> {z, _time, _freqs} = NxSignal.stft(t, w, opts)
      iex> result = NxSignal.istft(z, w, opts)
      iex> Nx.as_type(result, Nx.type(t))
      #Nx.Tensor<
        s64[8]
        [0, 10, 1, 0, 10, 10, 2, 20]
      >

      iex> t = Nx.tensor([10, 10, 1, 0, 10, 10, 2, 20], type: :f32)
      iex> w = NxSignal.Windows.hann(n: 4)
      iex> opts = [scaling: :psd, sampling_rate: 1, fft_length: 4]
      iex> {z, _time, _freqs} = NxSignal.stft(t, w, opts)
      iex> result = NxSignal.istft(z, w, opts)
      iex> Nx.as_type(result, Nx.type(t))
      #Nx.Tensor<
        f32[8]
        [0.0, 10.0, 0.9999999403953552, -2.1900146407460852e-7, 10.0, 10.0, 2.000000238418579, 20.0]
      >
  """
  @doc type: :time_frequency
  defn istft(data, window, opts \\ []) do
    opts = keyword!(opts, [:fft_length, :overlap_length, :scaling, sampling_rate: 1000])

    fft_length =
      case opts[:fft_length] do
        nil ->
          :power_of_two

        fft_length ->
          fft_length
      end

    overlap_length =
      case opts[:overlap_length] do
        nil ->
          div(Nx.size(window), 2)

        overlap_length ->
          overlap_length
      end

    sampling_rate =
      case {opts[:scaling], opts[:sampling_rate]} do
        {:psd, nil} -> raise ArgumentError, ":sampling_rate is mandatory if scaling is :psd"
        {_, sampling_rate} -> sampling_rate
      end

    # trimming keeps the frame length (and therefore the overlap-add hop) equal to the
    # window length even when `fft_length` is larger than the window
    frames =
      data
      |> Nx.ifft(length: fft_length)
      |> istft_trim_frames(Nx.size(window))

    frames_rescaled =
      case opts[:scaling] do
        :spectrum ->
          frames * Nx.sum(window)

        :psd ->
          frames * Nx.sqrt(sampling_rate * Nx.sum(window ** 2))

        nil ->
          frames

        scaling ->
          raise ArgumentError,
                "invalid :scaling, expected one of :spectrum, :psd or nil, got: #{inspect(scaling)}"
      end

    result_non_normalized =
      overlap_and_add(frames_rescaled * window, overlap_length: overlap_length)

    normalization_factor =
      overlap_and_add(Nx.broadcast(Nx.abs(window) ** 2, Nx.shape(frames)),
        overlap_length: overlap_length
      )

    # avoid dividing by ~0 where no window energy overlaps (e.g. window edges)
    normalization_factor = Nx.select(normalization_factor > 1.0e-10, normalization_factor, 1.0)

    result_non_normalized / normalization_factor
  end

  # `Nx.fft/2` zero-pads each frame up to the FFT length, so only the first
  # `frame_length` samples of each inverse-transformed frame carry the windowed signal.
  deftransformp istft_trim_frames(frames, frame_length) do
    if Nx.axis_size(frames, -1) > frame_length do
      Nx.slice_along_axis(frames, 0, frame_length, axis: -1)
    else
      frames
    end
  end

  @doc """
  Performs the overlap-and-add algorithm over
  an {..., M, N}-shaped tensor, where M is the number of
  windows and N is the window size.

  Consecutive windows are placed `N - overlap_length` samples apart and
  summed, so the result has `M * (N - overlap_length) + overlap_length`
  samples along its last axis.

  ## Options

    * `:overlap_length` - The number of overlapping samples between windows.
      Must be less than the window size.
    * `:type` - output type for casting the accumulated result.
      Defaults to the type of the input tensor.

  ## Examples

      iex> NxSignal.overlap_and_add(Nx.iota({3, 4}), overlap_length: 0)
      #Nx.Tensor<
        s64[12]
        [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
      >

      iex> NxSignal.overlap_and_add(Nx.iota({3, 4}), overlap_length: 3)
      #Nx.Tensor<
        s64[6]
        [0, 5, 15, 18, 17, 11]
      >

      iex> t = Nx.tensor([[[[0, 1, 2, 3], [4, 5, 6, 7]]], [[[10, 11, 12, 13], [14, 15, 16, 17]]]]) |> Nx.vectorize(x: 2, y: 1)
      iex> NxSignal.overlap_and_add(t, overlap_length: 3)
      #Nx.Tensor<
        vectorized[x: 2][y: 1]
        s64[5]
        [
          [
            [0, 5, 7, 9, 7]
          ],
          [
            [10, 25, 27, 29, 17]
          ]
        ]
      >
  """
  @doc type: :windowing
  defn overlap_and_add(tensor, opts \\ []) do
    opts = keyword!(opts, [:overlap_length, type: Nx.type(tensor)])
    overlap_length = opts[:overlap_length]

    %{vectorized_axes: vectorized_axes, shape: input_shape} = tensor
    num_windows = Nx.axis_size(tensor, -2)
    window_length = Nx.axis_size(tensor, -1)

    if overlap_length >= window_length do
      raise ArgumentError,
            "overlap_length must be a number less than the window size #{window_length}, got: #{inspect(overlap_length)}"
    end

    # collapse all leading (vectorized and batch) axes so each window is one vectorized entry
    tensor =
      Nx.revectorize(tensor, [condensed_vectors: :auto, windows: num_windows],
        target_shape: {window_length}
      )

    stride = window_length - overlap_length
    output_holder_shape = {num_windows * stride + overlap_length}

    out =
      Nx.broadcast(
        Nx.tensor(0, type: tensor.type),
        output_holder_shape
      )

    # window w contributes to output positions [w * stride, w * stride + window_length)
    idx_template = Nx.iota({window_length, 1}, vectorized_axes: [windows: 1])
    i = Nx.iota({num_windows}) |> Nx.vectorize(:windows)
    idx = idx_template + i * stride

    [%{vectorized_axes: [condensed_vectors: n, windows: _]} = tensor, idx] =
      Nx.broadcast_vectors([tensor, idx])

    tensor = Nx.revectorize(tensor, [condensed_vectors: n], target_shape: {:auto})
    idx = Nx.revectorize(idx, [condensed_vectors: n], target_shape: {:auto, 1})

    out_shape = overlap_and_add_output_shape(out.shape, input_shape)

    out
    |> Nx.indexed_add(idx, tensor)
    |> Nx.as_type(opts[:type])
    |> Nx.revectorize(vectorized_axes, target_shape: out_shape)
  end

  # Replaces the trailing {M, N} axes of the input shape with the overlap-added length.
  deftransformp overlap_and_add_output_shape({out_len}, in_shape) do
    idx = tuple_size(in_shape) - 2

    in_shape
    |> Tuple.delete_at(idx)
    |> Tuple.delete_at(idx)
    |> Tuple.append(out_len)
  end
end