lib/integrator/utils.ex

defmodule Integrator.Utils do
  @moduledoc """
  Various utility functions used in `Integrator`.
  """
  import Nx.Defn

  @doc """
  Performs a 3rd order Hermite interpolation. Adapted from function `hermite_cubic_interpolation` in
  [runge_kutta_interpolate.m](https://github.com/gnu-octave/octave/blob/default/scripts/ode/private/runge_kutta_interpolate.m).

  Arguments:

    * `t` - the time points of the interval; only `t[0]` and `t[1]` are read.
    * `x` - the solution values. Column 0 is the value at `t[0]` and column 1 the value at
      `t[1]` (Octave uses `x(:,end)`, which is the same column when `x` has two columns).
    * `der` - the derivatives. Column 0 is used at `t[0]` and column 3 at `t[1]`, i.e. `der`
      is assumed to have exactly 4 columns (Octave uses `der(:,end)`).
    * `t_out` - the time(s) at which the solution is interpolated.

  The columns are sliced with width 1, so they keep a trailing axis of size 1. `t_out` is
  broadcast against them.

  See [Wikipedia](https://en.wikipedia.org/wiki/Cubic_Hermite_spline).
  """
  @spec hermite_cubic_interpolation(Nx.t(), Nx.t(), Nx.t(), Nx.t()) :: Nx.t()
  defn hermite_cubic_interpolation(t, x, der, t_out) do
    # Octave:
    #   dt = (t(2) - t(1));
    #   s = (t_out - t(1)) / dt;
    #   x_out = ((1 + 2*s) .* (1-s).^2) .* x(:,1) + ...
    #           (s .* (1-s).^2 * dt   ) .* der(:,1) + ...
    #           ((3-2*s) .* s.^2      ) .* x(:,end) + ...
    #           ((s-1) .* s.^2   * dt ) .* der(:,end);

    dt = t[1] - t[0]
    # Rescale time so that t[0] maps to s = 0 and t[1] maps to s = 1
    s = (t_out - t[0]) / dt

    x_col1 = Nx.slice_along_axis(x, 0, 1, axis: 1)
    der_col_1 = Nx.slice_along_axis(der, 0, 1, axis: 1)
    x_col2 = Nx.slice_along_axis(x, 1, 1, axis: 1)
    # Octave's der(:,end); note that we are assuming "der" has 4 columns:
    der_last_col = Nx.slice_along_axis(der, 3, 1, axis: 1)

    one_minus_s = 1 - s
    one_minus_s_sq = one_minus_s * one_minus_s

    # Cubic Hermite basis functions multiplied by the end-point values / scaled derivatives
    x1 = (1 + 2 * s) * one_minus_s_sq * x_col1
    x2 = s * one_minus_s_sq * dt * der_col_1
    x3 = (3 - 2 * s) * s * s * x_col2
    x4 = (s - 1) * s * s * dt * der_last_col

    x1 + x2 + x3 + x4
  end

  # Weights applied to the 7 derivative columns to build a 4th order approximation of the
  # solution at the midpoint t + dt/2 (same values as `coefs_u_half` in Octave's
  # runge_kutta_interpolate.m).
  @coefs_u_half [
    6_025_192_743 / 30_085_553_152,
    0.0,
    51_252_292_925 / 65_400_821_598,
    -2_691_868_925 / 45_128_329_728,
    187_940_372_067 / 1_594_534_317_056,
    -1_776_094_331 / 19_743_644_256,
    11_237_099 / 235_043_384
  ]

  @doc """
  Performs a 4th order Hermite interpolation. Used by an ODE solver to interpolate the
  solution at the time `t_out`. As proposed by Shampine in Lawrence, Shampine,
  "Some Practical Runge-Kutta Formulas", 1986.

  Arguments:

    * `t` - the time points of the interval; only `t[0]` and `t[1]` are read.
    * `x` - the solution values. Column 0 is the value at `t[0]` and column 1 the value at `t[1]`.
    * `der` - the derivatives. It is assumed to have exactly 7 columns. All 7 are combined with
      the midpoint coefficients, column 0 is used at `t[0]`, and column 6 is used at `t[1]`.
    * `t_out` - the time(s) at which the solution is interpolated.

  See [hermite_quartic_interpolation function in Octave](https://github.com/gnu-octave/octave/blob/default/scripts/ode/private/runge_kutta_interpolate.m#L91).
  """
  @spec hermite_quartic_interpolation(Nx.t(), Nx.t(), Nx.t(), Nx.t()) :: Nx.t()
  defn hermite_quartic_interpolation(t, x, der, t_out) do
    dt = t[1] - t[0]
    x_col1 = Nx.slice_along_axis(x, 0, 1, axis: 1)

    # 4th order approximation of x in t+dt/2 as proposed by Shampine in
    # Lawrence, Shampine, "Some Practical Runge-Kutta Formulas", 1986.
    # The coefficients are built with the same type as `x` so the dot product does not upcast.
    coefs_u_half = Nx.tensor(@coefs_u_half, type: Nx.type(x))
    u_half = x_col1 + 0.5 * dt * Nx.new_axis(Nx.dot(der, coefs_u_half), 1)

    # Rescale time on [0,1]
    s = (t_out - t[0]) / dt

    s2 = s * s
    s3 = s2 * s
    s4 = s3 * s

    # Hermite basis functions

    # H0 = x1 = 1   - 11*s^2 + 18*s^3 -  8*s^4
    # H1 = x2 =   s -  4*s^2 +  5*s^3 -  2*s^4
    # H2 = x3 =       16*s^2 - 32*s^3 + 16*s^4
    # H3 = x4 =     -  5*s^2 + 14*s^3 -  8*s^4
    # H4 = x5 =          s^2 -  3*s^3 +  2*s^4

    x1 = (1.0 - 11.0 * s2 + 18.0 * s3 - 8.0 * s4) * x_col1

    der_col_1 = Nx.slice_along_axis(der, 0, 1, axis: 1)
    x2 = (s - 4.0 * s2 + 5.0 * s3 - 2.0 * s4) * (dt * der_col_1)

    x3 = (16.0 * s2 - 32.0 * s3 + 16.0 * s4) * u_half

    x_col2 = Nx.slice_along_axis(x, 1, 1, axis: 1)
    x4 = (-5.0 * s2 + 14.0 * s3 - 8.0 * s4) * x_col2

    # Octave's der(:,end); note that we are assuming that "der" has 7 columns here:
    der_last_col = Nx.slice_along_axis(der, 6, 1, axis: 1)
    x5 = (s2 - 3.0 * s3 + 2.0 * s4) * (dt * der_last_col)

    x1 + x2 + x3 + x4 + x5
  end

  @doc """
  Implements the Kahan summation algorithm, also known as compensated summation.
  Based on this [code in Octave](https://github.com/gnu-octave/octave/blob/default/scripts/ode/private/kahan.m).
  This is really a private function, but is made public so the docs are visible.

  The algorithm significantly reduces the numerical error in the total
  obtained by adding a sequence of finite precision floating point numbers
  compared to the straightforward approach.  For more details
  see [this Wikipedia entry](http://en.wikipedia.org/wiki/Kahan_summation_algorithm).
  This function is called by AdaptiveStepsize.integrate to better catch
  equality comparisons.

  The first input argument is the variable that will contain the summation.
  This variable is also returned as the first output argument in order to
  reuse it in subsequent calls to `kahan_sum/3` function.

  The second input argument contains the compensation term and is returned
  as the second output argument so that it can be reused in future calls of
  the same summation.

  The third input argument `term` is the variable to be added to `sum`.

  Returns `{new_sum, new_compensation}`.
  """
  @spec kahan_sum(Nx.t(), Nx.t(), Nx.t()) :: {Nx.t(), Nx.t()}
  defn kahan_sum(sum, comp, term) do
    # Octave code:
    #   x = term - comp;
    #   t = sum + x;
    #   comp = (t - sum) - x;
    #   sum = t;

    x = term - comp
    t = sum + x

    # (t - sum) - x recovers the low-order bits of x that were lost when adding it to sum
    {t, t - sum - x}
  end

  @doc """
  Returns the sign of a number as `-1.0` or `1.0` (or `0.0` when the number is zero).
  """
  @spec sign(number()) :: float()
  def sign(x) when x < 0.0, do: -1.0
  def sign(x) when x > 0.0, do: 1.0
  def sign(_x), do: 0.0

  @doc """
  Returns the columns of a matrix (a rank-2 tensor) as a list of vector tensors.

  Columns `start_index` through `end_index` (inclusive) are returned. When `end_index` is
  `nil` (the default), the list extends to the last column of the matrix.

  E.g., a tensor of the form:

      ~M[
        1  2  3  4
        5  6  7  8
        9 10 11 12
      ]s8

  will return

      [
        ~V[ 1  5   9 ]s8,
        ~V[ 2  6  10 ]s8,
        ~V[ 3  7  11 ]s8,
        ~V[ 4  8  12 ]s8
      ]

  """
  @spec columns_as_list(Nx.t(), integer(), integer() | nil) :: [Nx.t()]
  def columns_as_list(matrix, start_index, end_index \\ nil) do
    last_index = end_index || last_column_index(matrix)

    # Slice each column directly (width 1 along axis 1) instead of transposing the whole
    # matrix first. Flattening drops the size-1 axis.
    Enum.map(start_index..last_index, fn i ->
      matrix
      |> Nx.slice_along_axis(i, 1, axis: 1)
      |> Nx.flatten()
    end)
  end

  # Index of the last column of a rank-2 tensor (raises MatchError for other ranks).
  defp last_column_index(matrix) do
    {_n_rows, n_cols} = Nx.shape(matrix)
    n_cols - 1
  end

  @doc """
  Converts an Nx vector (a rank-1 tensor) into a list of its elements as scalar (0-D) tensors.
  A vector with no elements yields `[]`.

  Note that

      vector |> Nx.to_list() |> Enum.map(& Nx.tensor(&1, type: Nx.type(vector)))

  seems to introduce potential precision issues, which is why the elements are indexed
  directly instead.
  """
  @spec vector_as_list(Nx.t()) :: [Nx.t()]
  def vector_as_list(vector) do
    # Is there an existing Nx way to do this?  If so, swap the usage of this function
    # and then delete this.
    {size} = Nx.shape(vector)

    # The explicit step keeps the range empty (instead of counting down) when size is 0
    Enum.map(0..(size - 1)//1, fn i -> vector[i] end)
  end

  @doc """
  Given a list of elements, creates a new list containing only the unique elements of the
  list, sorted in ascending order.
  """
  @spec unique(list()) :: list()
  def unique(values) do
    values |> MapSet.new() |> MapSet.to_list() |> Enum.sort()
  end

  # Machine epsilons: the gap between 1.0 and the next representable float of each type.
  # Delete these and use Nx.Constants.epsilon/1 instead once Nx 0.5.4 is published
  # (it is currently in main)
  @epsilon_bf16 0.0078125
  @epsilon_f16 0.0009765625
  @epsilon_f32 1.1920929e-07
  @epsilon_f64 2.220446049250313e-16

  @doc """
  Gets the machine epsilon for a particular Nx floating-point type.

  Accepts both the shorthand (`:f32`) and tuple (`{:f, 32}`) forms of `:bf16`, `:f16`,
  `:f32` and `:f64`. Raises `ArgumentError` for any other type.
  """
  @spec epsilon(Nx.Type.t()) :: float()
  def epsilon(:f32), do: @epsilon_f32
  def epsilon({:f, 32}), do: @epsilon_f32
  def epsilon(:f64), do: @epsilon_f64
  def epsilon({:f, 64}), do: @epsilon_f64
  def epsilon(:f16), do: @epsilon_f16
  def epsilon({:f, 16}), do: @epsilon_f16
  def epsilon(:bf16), do: @epsilon_bf16
  def epsilon({:bf, 16}), do: @epsilon_bf16

  def epsilon(nx_type) do
    raise ArgumentError,
          "epsilon is only defined for floating-point types, got: #{inspect(nx_type)}"
  end

  @doc """
  Same as `epsilon/1`, but can also be called from inside `defn` code.

  It is a transform: it runs as regular Elixir when the `defn` is traced, so the type
  argument (e.g. the result of `Nx.type/1`) does not need to be a tensor.
  """
  @spec epsilon_nx(Nx.Type.t()) :: float()
  deftransform epsilon_nx(nx_type) do
    epsilon(nx_type)
  end
end