lib/integrator/adaptive_stepsize.ex

defmodule Integrator.AdaptiveStepsize do
  @moduledoc """
  Integrates a set of ODEs with an adaptive timestep.
  """
  import Nx.Defn

  alias Integrator.{MaxErrorsExceededError, NonLinearEqnRoot, RungeKutta, Utils}
  alias Integrator.AdaptiveStepsize.{ArgPrecisionError, MaxErrorsExceededError}

  defmodule ComputedStep do
    @moduledoc """
    Holds the values produced by computing a single Runge-Kutta step.
    """

    defstruct [:t_new, :x_new, :k_vals, :options_comp]

    @type t :: %__MODULE__{
            t_new: Nx.t(),
            x_new: Nx.t(),
            k_vals: Nx.t(),
            options_comp: Nx.t()
          }
  end

  # The state carried through the integration loop; this is also the struct returned by `integrate/10`.
  @type t :: %__MODULE__{
          t_old: Nx.t() | nil,
          x_old: Nx.t() | nil,
          #
          t_new: Nx.t() | nil,
          x_new: Nx.t() | nil,
          #
          # t & x used for Runge-Kutta interpolations:
          t_new_rk_interpolate: Nx.t() | nil,
          x_new_rk_interpolate: Nx.t() | nil,
          #
          dt: Nx.t() | nil,
          k_vals: Nx.t() | nil,
          nx_type: Nx.Type.t(),
          #
          options_comp: Nx.t() | nil,
          #
          # Fixed output times; e.g., integration output computed at times [0.1, 0.2, 0.3, ...] via interpolation:
          fixed_times: [Nx.t()] | nil,
          #
          count_loop__increment_step: integer(),
          count_cycles__compute_step: integer(),
          #
          # ireject in Octave:
          error_count: integer(),
          i_step: integer(),
          #
          terminal_event: integration_status(),
          terminal_output: integration_status(),
          #
          # The output of the Runge-Kutta integration:
          ode_t: [Nx.t()],
          ode_x: [Nx.t()],
          #
          # The output of the integration, plus the interpolated points:
          output_t: [Nx.t()],
          output_x: [Nx.t()],
          #
          # The last chunk of points for this computed step; will include the computed point plus the
          # interpolated points (if interpolation is enabled) or just the computed point (if interpolation is disabled):
          t_new_chunk: [Nx.t()],
          x_new_chunk: [Nx.t()],
          #
          # Timestamps (in milliseconds) of the most recent update and of the start of the integration:
          timestamp_ms: integer() | nil,
          timestamp_start_ms: integer() | nil
        }
  # Field defaults for the integration state; see the `t()` type for the meaning of each field.
  defstruct [
    :t_old,
    :x_old,
    #
    :t_new,
    :x_new,
    #
    # t & x used for Runge-Kutta interpolations:
    :t_new_rk_interpolate,
    :x_new_rk_interpolate,
    #
    :dt,
    :k_vals,
    nx_type: :f32,
    #
    options_comp: 0.0,
    #
    # Fixed output times; e.g., integration output computed at times [0.1, 0.2, 0.3, ...] via interpolation:
    fixed_times: nil,
    #
    count_loop__increment_step: 0,
    count_cycles__compute_step: 0,
    #
    # ireject in Octave:
    error_count: 0,
    i_step: 0,
    #
    terminal_event: :continue,
    terminal_output: :continue,
    #
    # The output of the Runge-Kutta integration:
    ode_t: [],
    ode_x: [],
    #
    # The output of the integration, plus the interpolated points:
    output_t: [],
    output_x: [],
    #
    # The last chunk of points for this computed step; will include the computed point plus the
    # interpolated points (if interpolation is enabled) or just the computed point (if interpolation is disabled):
    t_new_chunk: [],
    x_new_chunk: [],
    #
    # Timestamps (in milliseconds) of the most recent update and of the start of the integration:
    timestamp_ms: nil,
    timestamp_start_ms: nil
  ]

  @type integration_status :: :halt | :continue
  @type refine_strategy :: integer() | :fixed_times

  @type event_fn_t :: (Nx.t(), Nx.t() -> {integration_status(), Nx.t()})
  @type output_fn_t :: ([Nx.t()], [Nx.t()] -> any())

  # Base zero_tolerance on precision?
  @zero_tolerance 1.0e-07

  # Option definitions specific to adaptive-stepsize integration. They are compiled into NimbleOptions
  # schemas below and are also rendered into the documentation of `integrate/10`.
  #
  # `:refine` and `:speed` only accept the single atom that the integration loop understands
  # (`:fixed_times` and `:no_delay` respectively), so that a mistyped atom is rejected during option
  # validation rather than failing later with an arithmetic error.
  options = [
    abs_tol: [
      type: :any,
      doc: """
      The absolute tolerance used when computing the absolute relative norm. Defaults to 1.0e-06 in the Nx type that's been specified.
      """
    ],
    event_fn: [
      type: {:or, [{:fun, 2}, nil]},
      doc: "A 2 arity function which determines whether an event has occurred.  If so, the integration is halted.",
      default: nil
    ],
    max_number_of_errors: [
      type: :integer,
      doc: "The maximum number of permissible errors before the integration is halted.",
      default: 5_000
    ],
    max_step: [
      type: :any,
      doc: """
      The default max time step.  The default value is determined by the start and end times.
      """
    ],
    norm_control: [
      type: :boolean,
      doc: "Indicates whether norm control is to be used when computing the absolute relative norm.",
      default: true
    ],
    output_fn: [
      type: {:or, [{:fun, 2}, nil]},
      doc: "A 2 arity function which is called at each output point.",
      default: nil
    ],
    refine: [
      type: {:or, [{:in, [:fixed_times]}, :pos_integer]},
      doc: """
      Indicates the number of additional interpolated points. `1` means no interpolation; `2` means one
      additional interpolated point; etc. `:fixed_times` means that the output times are fixed.
      """,
      default: 4
    ],
    rel_tol: [
      type: :any,
      doc: """
      The relative tolerance used when computing the absolute relative norm. Defaults to 1.0e-03 in the Nx type that's been specified.
      """
    ],
    speed: [
      type: {:or, [{:in, [:no_delay]}, :float]},
      doc: """
      `:no_delay` means to simulate as fast as possible. `1.0` means real time, `2.0` means twice as fast as real time,
      `0.5` means half as fast as real time, etc.
      """,
      default: :no_delay
    ],
    store_results?: [
      type: :boolean,
      doc: "Indicates whether or not to store the results of the integration.",
      default: true
    ]
  ]

  # Schema built only from the options defined in this module (used for the generated docs of `integrate/10`).
  @options_schema_adaptive_stepsize_only NimbleOptions.new!(options)
  def options_schema_adaptive_stepsize_only, do: @options_schema_adaptive_stepsize_only

  # Full schema: the `NonLinearEqnRoot` options merged with this module's options. Because this module's
  # options are the second argument to `Keyword.merge/2`, they win if a key is defined in both.
  @options_schema NimbleOptions.new!(NonLinearEqnRoot.options_schema().schema |> Keyword.merge(options))
  def options_schema, do: @options_schema

  # `:abs_tol`, `:rel_tol` and `:max_step` have no default in the schema above, so they are supplied here
  # (as nil) before validating; `option_keys/0` returns the keys of the validated keyword list.
  @options_currently_without_nimble_defaults [abs_tol: nil, rel_tol: nil, max_step: nil]
  def option_keys, do: NimbleOptions.validate!(@options_currently_without_nimble_defaults, @options_schema) |> Keyword.keys()

  @type options_t() :: unquote(NimbleOptions.option_typespec(@options_schema))

  # :no_delay means to perform the integration as fast as possible
  # For float values, 1.0 means to integrate in real-time, 0.5 means half of real-time, 2.0 means twice as fast as real time, etc.
  @type speed :: :no_delay | float()

  @doc """
  Integrates a set of ODEs.

  ## Arguments

  * `stepper_fn` - computes one Runge-Kutta step
  * `interpolate_fn` - interpolates between two computed points
  * `ode_fn` - the function describing the differential equations
  * `t_start`, `t_end` - the start and end times of the integration
  * `fixed_times` - `nil`, or a list of fixed output times; the first entry of the list is dropped.
    When a list is given, the `:refine` option is overridden with `:fixed_times`.
  * `initial_tstep` - the initial timestep; its absolute value is limited to `:max_step`
  * `x0` - the initial state
  * `order` - the order of the Runge-Kutta method

  Returns the integration struct with the stored results in chronological order.

  Raises `Integrator.AdaptiveStepsize.ArgPrecisionError` if the Nx type of an argument does not
  match the `:type` option.

  ## Options

  #{NimbleOptions.docs(@options_schema_adaptive_stepsize_only)}

  ### Additional Options

  Also see the options for the `Integrator.NonLinearEqnRoot.find_zero/4` which are passed
  into `integrate/10`.

  Originally adapted from the Octave
  [integrate_adaptive.m](https://github.com/gnu-octave/octave/blob/default/scripts/ode/private/integrate_adaptive.m)

  See [Wikipedia](https://en.wikipedia.org/wiki/Adaptive_stepsize)
  """
  @spec integrate(
          stepper_fn :: RungeKutta.stepper_fn_t(),
          interpolate_fn :: RungeKutta.interpolate_fn_t(),
          ode_fn :: RungeKutta.ode_fn_t(),
          t_start :: Nx.t(),
          t_end :: Nx.t(),
          fixed_times :: [Nx.t()] | nil,
          initial_tstep :: Nx.t(),
          x0 :: Nx.t(),
          order :: integer(),
          opts :: Keyword.t()
        ) :: t()
  def integrate(stepper_fn, interpolate_fn, ode_fn, t_start, t_end, fixed_times, initial_tstep, x0, order, opts \\ []) do
    opts =
      opts
      |> NimbleOptions.validate!(@options_schema)
      |> abs_rel_norm_opts()
      |> Keyword.put_new_lazy(:max_step, fn -> default_max_step(t_start, t_end) end)

    fixed_times = fixed_times |> drop_first_point()

    # The Nx types of :initial_tstep and opts[:max_step] need to be checked PRIOR to the call to Nx.min()
    # as Nx.min() will convert :f32's to :f64's:
    nx_type = opts[:type]
    check_nx_type([initial_tstep: initial_tstep, max_step: opts[:max_step]], nx_type)
    initial_tstep = Nx.min(Nx.abs(initial_tstep), opts[:max_step])

    # Broadcast the starting conditions (t_start & x0) as the first output point (if there is an output function):
    if fun = opts[:output_fn], do: fun.([t_start], [x0])

    opts =
      case fixed_times do
        nil ->
          opts

        [] ->
          # Only the starting time was given, so there is no fixed time left to spot-check; the
          # output times are nevertheless fixed, so refine can no longer be an integer value:
          Keyword.merge(opts, refine: :fixed_times)

        [first_fixed_time | _rest] ->
          # Spot-check the Nx type of the first time value in the list of fixed times:
          check_nx_type([fixed_times: first_fixed_time], nx_type)

          # If there are fixed output times, then refine can no longer be an integer value (such as 1 or 4):
          Keyword.merge(opts, refine: :fixed_times)
      end

    check_nx_type(
      [
        t_start: t_start,
        t_end: t_end,
        x0: x0,
        abs_tol: opts[:abs_tol],
        rel_tol: opts[:rel_tol],
        ode_fn: ode_fn.(t_start, x0)
      ],
      nx_type
    )

    timestamp_now = timestamp_ms()

    %__MODULE__{
      t_new: t_start,
      x_new: x0,
      # t_old must be set on the initial struct in case there's an error when computing the first step (used in t_next/2)
      t_old: t_start,
      dt: initial_tstep,
      k_vals: initial_empty_k_vals(order, x0),
      fixed_times: fixed_times,
      nx_type: nx_type,
      options_comp: Nx.tensor(0.0, type: nx_type),
      timestamp_ms: timestamp_now,
      timestamp_start_ms: timestamp_now
    }
    |> store_first_point(t_start, x0, opts[:store_results?])
    |> step_forward(Nx.to_number(t_start), Nx.to_number(t_end), :continue, stepper_fn, interpolate_fn, ode_fn, order, opts)
    |> reverse_results()
    # Capture end timestamp:
    |> Map.put(:timestamp_ms, timestamp_ms())
  end

  @doc """
  Estimates a reasonable first timestep for an ODE solver of order `order`, following the
  algorithm in the reference below.

  `ode_fn` is the function describing the differential equations, `t0` is the initial time
  and `x0` is the initial condition.  `abs_tol` and `rel_tol` are the absolute and relative
  tolerances used for the ODE integration.

  Originally based on the Octave
  [`starting_stepsize.m`](https://github.com/gnu-octave/octave/blob/default/scripts/ode/private/starting_stepsize.m).

  Reference:

  E. Hairer, S.P. Norsett and G. Wanner,
  "Solving Ordinary Differential Equations I: Nonstiff Problems",
  Springer.
  """
  @spec starting_stepsize(
          order :: integer(),
          ode_fn :: RungeKutta.ode_fn_t(),
          t0 :: Nx.t(),
          x0 :: Nx.t(),
          abs_tol :: Nx.t(),
          rel_tol :: Nx.t(),
          opts :: Keyword.t()
        ) :: Nx.t()
  defn starting_stepsize(order, ode_fn, t0, x0, abs_tol, rel_tol, opts \\ []) do
    nx_type = Nx.type(x0)
    zeros = zero_vector(x0)

    # Norm of the initial conditions, and of the derivative at the initial conditions:
    d0 = abs_rel_norm(x0, x0, zeros, abs_tol, rel_tol, opts)
    dx0 = ode_fn.(t0, x0)
    d1 = abs_rel_norm(dx0, dx0, zeros, abs_tol, rel_tol, opts)

    # First guess of the step size:
    h0 =
      if d0 < 1.0e-5 or d1 < 1.0e-5 do
        Nx.tensor(1.0e-6, type: nx_type)
      else
        Nx.tensor(0.01, type: nx_type) * (d0 / d1)
      end

    # Take one explicit Euler step and evaluate the derivative there:
    x_euler = x0 + h0 * dx0
    dx_euler = ode_fn.(t0 + h0, x_euler)

    # Approximate the norm of the derivative's rate of change:
    delta = dx_euler - dx0
    d2 = Nx.tensor(1.0, type: nx_type) / h0 * abs_rel_norm(delta, delta, zeros, abs_tol, rel_tol, opts)

    d_max = max(d1, d2)
    one = Nx.tensor(1, type: nx_type)

    # Second guess of the step size:
    h1 =
      if d_max <= 1.0e-15 do
        max(Nx.tensor(1.0e-06, type: nx_type), h0 * Nx.tensor(1.0e-03, type: nx_type))
      else
        Nx.pow(Nx.tensor(1.0e-02, type: nx_type) / d_max, one / (order + one))
      end

    min(Nx.tensor(100.0, type: nx_type) * h0, h1)
  end

  @doc """
  Fills in the default tolerances used by the absolute-relative norm. `:abs_tol` (1.0e-06) and
  `:rel_tol` (1.0e-03) are added only when they are not already present in `opts`; the defaults
  are created in the Nx type given by `opts[:type]` (`:f64` if no type is given).
  """
  @spec abs_rel_norm_opts(Keyword.t()) :: Keyword.t()
  def abs_rel_norm_opts(opts) do
    nx_type = Keyword.get(opts, :type, :f64)

    Enum.reduce([abs_tol: 1.0e-06, rel_tol: 1.0e-03], opts, fn {key, default}, acc ->
      Keyword.put_new_lazy(acc, key, fn -> Nx.tensor(default, type: nx_type) end)
    end)
  end

  @doc """
  Returns the total elapsed time for the integration (in milliseconds); i.e., the difference
  between the `timestamp_ms` and `timestamp_start_ms` fields of `step`.

  The result is zero when both timestamps fall within the same millisecond.
  """
  @spec elapsed_time_ms(t()) :: integer()
  def elapsed_time_ms(step), do: step.timestamp_ms - step.timestamp_start_ms

  # ===========================================================================
  # Private functions below here:

  # Removes the first entry from the list of fixed output times. `nil` (no fixed times) and an
  # empty list are returned unchanged rather than raising a MatchError.
  @spec drop_first_point([Nx.t()] | nil) :: [Nx.t()] | nil
  defp drop_first_point(nil), do: nil
  defp drop_first_point([]), do: []
  defp drop_first_point([_first_point | rest_of_fixed_times]), do: rest_of_fixed_times

  # Computes the default maximum step: one tenth of the length of the integration interval.
  #
  # The 0.1 factor is created in the Nx type of `t_start` so that the result keeps the precision of the
  # time arguments; a hard-coded :f64 factor would promote an :f32 interval to :f64 and then fail the
  # precision check performed on `:max_step` in `integrate/10`.
  @spec default_max_step(Nx.t(), Nx.t()) :: Nx.t()
  defnp default_max_step(t_start, t_end) do
    # See Octave: integrate_adaptive.m:89
    Nx.tensor(0.1, type: Nx.type(t_start)) * Nx.abs(t_start - t_end)
  end

  # Seeds the result lists with the starting point when results are being stored; otherwise
  # returns the step untouched.
  @spec store_first_point(t(), Nx.t(), Nx.t(), boolean()) :: t()
  defp store_first_point(step, t_start, x0, store_results?) do
    case store_results? do
      true -> %{step | output_t: [t_start], output_x: [x0], ode_t: [t_start], ode_x: [x0]}
      _ -> step
    end
  end

  # Builds the all-zero k-values tensor used before the first step has been computed. It has one
  # row per element of `x` and `order + 2` columns, in the Nx type of `x`.
  # (Open question carried over from the original: should the column count depend on the order
  # of the Runge-Kutta method in some other way?)
  @spec initial_empty_k_vals(integer(), Nx.t()) :: Nx.t()
  defp initial_empty_k_vals(order, x) do
    {n_states} = Nx.shape(x)

    0.0
    |> Nx.tensor(type: Nx.type(x))
    |> Nx.broadcast({n_states, order + 2})
  end

  #  The main integration loop. Recurses until the end time is reached or the integration is halted.
  #
  #  * `t_old` and `t_end` are plain floats (so that they can be used in the guard below)
  #  * `status` is `:halt` once an event function or output function has requested a halt
  @spec step_forward(
          step :: t(),
          t_old :: float(),
          t_end :: float(),
          status :: integration_status(),
          stepper_fn :: RungeKutta.stepper_fn_t(),
          interpolate_fn :: RungeKutta.interpolate_fn_t(),
          ode_fn :: RungeKutta.ode_fn_t(),
          order :: integer(),
          opts :: Keyword.t()
        ) :: t()
  # Done: the current time is within @zero_tolerance of the end time, or is past it
  defp step_forward(step, t_old, t_end, _status, _stepper_fn, _interpolate_fn, _ode_fn, _order, _opts)
       when abs(t_old - t_end) < @zero_tolerance or t_old > t_end do
    step
  end

  # Done: a halt was requested
  defp step_forward(step, _t_old, _t_end, status, _stepper_fn, _interpolate_fn, _ode_fn, _order, _opts)
       when status == :halt do
    step
  end

  defp step_forward(step, _t_old, t_end, _status, stepper_fn, interpolate_fn, ode_fn, order, opts) do
    {new_step, error_est} = compute_step(step, stepper_fn, ode_fn, opts)
    step = step |> increment_compute_counter()

    step =
      if less_than_one?(error_est) do
        # The step is accepted: merge it into the state, check for events, interpolate, and then
        # store and broadcast the new points
        step
        |> increment_and_reset_counters()
        |> merge_new_step(new_step)
        |> call_event_fn(opts[:event_fn], interpolate_fn, opts)
        |> interpolate(interpolate_fn, opts[:refine])
        |> store_resuts(opts[:store_results?])
        |> call_output_fn(opts[:output_fn])
      else
        # The step is rejected: the state is kept as is (apart from the error count) and the step is
        # retried with the timestep computed below. Raises once there have been too many errors.
        bump_error_count(step, opts)
      end

    # The next timestep is computed from the error estimate for both accepted and rejected steps:
    dt = compute_next_timestep(step.dt, error_est, order, step.t_new, t_end, opts)
    step = %{step | dt: dt} |> delay_simulation(opts[:speed])

    step
    |> step_forward(t_next(step, dt), t_end, halt?(step), stepper_fn, interpolate_fn, ode_fn, order, opts)
  end

  # Returns true if the (scalar) error estimate is strictly less than one.
  #
  # The result of the comparison is read back with `Nx.to_number/1` rather than being compared
  # structurally (`==`) against a freshly-created tensor: structural equality of tensor structs
  # depends on how the backend represents its data and is not a reliable value comparison.
  @spec less_than_one?(Nx.t()) :: boolean()
  defp less_than_one?(error_est), do: Nx.to_number(Nx.less(error_est, 1.0)) == 1

  # Reports `:halt` when either the event function or the output function has asked for the
  # integration to stop; `:continue` otherwise.
  @spec halt?(t()) :: integration_status()
  defp halt?(step) do
    case step do
      %{terminal_event: :halt} -> :halt
      %{terminal_output: :halt} -> :halt
      _ -> :continue
    end
  end

  # Adds one to the error count and returns the updated step. Raises `MaxErrorsExceededError`
  # when the new count is greater than `opts[:max_number_of_errors]`.
  @spec bump_error_count(t(), Keyword.t()) :: t()
  defp bump_error_count(step, opts) do
    bumped = %{step | error_count: step.error_count + 1}
    limit = opts[:max_number_of_errors]

    if bumped.error_count > limit do
      raise MaxErrorsExceededError,
        message: "Too many errors",
        error_count: bumped.error_count,
        max_number_of_errors: limit,
        step: bumped
    else
      bumped
    end
  end

  # Returns (as a number) the time that the loop should consider itself to be at:
  #   * after a rejected step (error count above zero): the old time plus the new timestep
  #   * after an accepted step (error count of zero): the new time
  @spec t_next(t(), Nx.t()) :: float()
  defp t_next(%{error_count: error_count, t_old: t_old}, dt) when error_count > 0 do
    # Update this into step somehow???
    t_old
    |> Nx.add(dt)
    |> Nx.to_number()
  end

  defp t_next(%{error_count: error_count, t_new: t_new}, _dt) when error_count == 0,
    do: Nx.to_number(t_new)

  # While integrating, results are prepended to their lists (which is fast); this puts all four
  # result lists back into chronological order before they are handed to the user.
  @spec reverse_results(t()) :: t()
  defp reverse_results(step) do
    Enum.reduce([:output_x, :output_t, :ode_x, :ode_t], step, fn field, acc ->
      Map.update!(acc, field, fn values -> Enum.reverse(values) end)
    end)
  end

  @stepsize_factor_min 0.8
  @stepsize_factor_max 1.5

  # The exponent 1 / (1 + order), computed in the given Nx type
  @spec exponent(Nx.t(), Nx.Type.t()) :: Nx.t()
  defnp exponent(order, nx_type) do
    one = Nx.tensor(1, type: nx_type)
    one / (one + order)
  end

  # The factor 0.38 ^ (1 / (1 + order)) applied when scaling the timestep, computed in the given Nx type
  @spec factor(Nx.t(), Nx.Type.t()) :: Nx.t()
  defnp factor(order, nx_type) do
    Nx.pow(Nx.tensor(0.38, type: nx_type), exponent(order, nx_type))
  end

  # Computes the next timestep from the current timestep and the error estimate (formula taken
  # from Hairer). The scaling applied to `dt` is kept between @stepsize_factor_min and
  # @stepsize_factor_max; the result is limited to `opts[:max_step]` and to the time remaining
  # until `t_end`.
  @spec compute_next_timestep(Nx.t(), Nx.t(), integer(), Nx.t(), Nx.t(), Keyword.t()) :: Nx.t()
  defnp compute_next_timestep(dt, error, order, t_old, t_end, opts) do
    nx_type = opts[:type]
    one = Nx.tensor(1.0, type: nx_type)

    # Avoid divisions by zero:
    nonzero_error = error + Utils.epsilon_nx(nx_type)

    # Octave:
    #   dt *= min (facmax, max (facmin, fac * (1 / err)^(1 / (order + 1))));
    unbounded_scale = factor(order, nx_type) * (one / nonzero_error) ** exponent(order, nx_type)
    largest_scale = Nx.tensor(@stepsize_factor_max, type: nx_type)
    smallest_scale = Nx.tensor(@stepsize_factor_min, type: nx_type)
    scaled_dt = dt * min(largest_scale, max(smallest_scale, unbounded_scale))

    limited_dt = min(Nx.abs(scaled_dt), opts[:max_step])

    # Make sure we don't go past t_end:
    min(Nx.abs(limited_dt), Nx.abs(t_end - t_old))
  end

  # Book-keeping for an accepted step: bumps the loop counter and the step index, and clears the
  # error count.
  @spec increment_and_reset_counters(t()) :: t()
  defp increment_and_reset_counters(step) do
    step
    |> Map.update!(:count_loop__increment_step, &(&1 + 1))
    |> Map.update!(:i_step, &(&1 + 1))
    |> Map.replace!(:error_count, 0)
  end

  # Folds a freshly computed Runge-Kutta step into the integration state: the current point
  # becomes the old point, and the computed point becomes both the new point and the point used
  # for Runge-Kutta interpolation.
  @spec merge_new_step(t(), ComputedStep.t()) :: t()
  defp merge_new_step(step, computed_step) do
    x_computed = computed_step.x_new
    t_computed = computed_step.t_new

    %{
      step
      | t_old: step.t_new,
        x_old: step.x_new,
        t_new: t_computed,
        x_new: x_computed,
        t_new_rk_interpolate: t_computed,
        x_new_rk_interpolate: x_computed,
        k_vals: computed_step.k_vals,
        options_comp: computed_step.options_comp
    }
  end

  # When results are being stored, prepends the new computed point to the ODE results and the new
  # chunk of output points to the output results; otherwise leaves the step as it is.
  @spec store_resuts(t(), boolean()) :: t()
  defp store_resuts(step, true = _store_results?) do
    ode_t = [step.t_new | step.ode_t]
    ode_x = [step.x_new | step.ode_x]
    output_t = step.t_new_chunk ++ step.output_t
    output_x = step.x_new_chunk ++ step.output_x

    %{step | ode_t: ode_t, ode_x: ode_x, output_t: output_t, output_x: output_x}
  end

  defp store_resuts(step, false = _store_results?), do: step

  # Counts one more call to compute_step/4 (whether or not the step ends up being accepted)
  @spec increment_compute_counter(t()) :: t()
  defp increment_compute_counter(step),
    do: Map.update!(step, :count_cycles__compute_step, &(&1 + 1))

  # Paces the simulation against the wall clock. Nothing happens for `:no_delay` or after a
  # rejected step; otherwise the process sleeps for whatever remains of the simulated time
  # interval (divided by `speed`) once the time already spent is subtracted, and the step's
  # timestamp is refreshed.
  @spec delay_simulation(t(), speed()) :: t()
  defp delay_simulation(step, :no_delay), do: step

  defp delay_simulation(%{error_count: error_count} = step, _speed) when error_count > 0, do: step

  defp delay_simulation(step, speed) do
    simulated_interval = step.t_new |> Nx.subtract(step.t_old) |> Nx.to_number()
    target_ms = simulated_interval * 1000 / speed
    spent_ms = timestamp_ms() - step.timestamp_ms

    case trunc(target_ms - spent_ms) do
      remaining_ms when remaining_ms > 0 -> Process.sleep(remaining_ms)
      _ -> :ok
    end

    %{step | timestamp_ms: timestamp_ms()}
  end

  # Computes the next Runge-Kutta step starting from the step's current ("new") point. This is
  # the plain-Elixir wrapper around compute_step_nx/8, which does the numerical work; the
  # results are packaged into a ComputedStep and returned together with the error estimate.
  @spec compute_step(t(), RungeKutta.stepper_fn_t(), RungeKutta.ode_fn_t(), Keyword.t()) :: {ComputedStep.t(), Nx.t()}
  defp compute_step(step, stepper_fn, ode_fn, opts) do
    {t_next, x_next, k_vals, options_comp, error} =
      compute_step_nx(stepper_fn, ode_fn, step.t_new, step.x_new, step.k_vals, step.options_comp, step.dt, opts)

    computed_step = %ComputedStep{
      t_new: t_next,
      x_new: x_next,
      k_vals: k_vals,
      options_comp: options_comp
    }

    {computed_step, error}
  end

  # The numerical part of computing a step: advances the time with a Kahan (compensated) sum,
  # calls the stepper function, and computes the error estimate for the step via abs_rel_norm.
  @spec compute_step_nx(
          stepper_fn :: RungeKutta.stepper_fn_t(),
          ode_fn :: RungeKutta.ode_fn_t(),
          t_old :: Nx.t(),
          x_old :: Nx.t(),
          k_vals_old :: Nx.t(),
          options_comp_old :: Nx.t(),
          dt :: Nx.t(),
          opts :: Keyword.t()
        ) :: {
          t_next :: Nx.t(),
          x_next :: Nx.t(),
          k_vals :: Nx.t(),
          options_comp :: Nx.t(),
          error :: Nx.t()
        }
  defnp compute_step_nx(stepper_fn, ode_fn, t_old, x_old, k_vals_old, options_comp_old, dt, opts) do
    {t_advanced, compensation} = Utils.kahan_sum(t_old, options_comp_old, dt)
    {x_advanced, x_estimate, k_vals_advanced} = stepper_fn.(ode_fn, t_old, x_old, dt, k_vals_old, t_advanced)

    step_error =
      abs_rel_norm(x_advanced, x_old, x_estimate, opts[:abs_tol], opts[:rel_tol], norm_control: opts[:norm_control])

    {t_advanced, x_advanced, k_vals_advanced, compensation, step_error}
  end

  # Moves every pending fixed output time that has been reached by `step.t_new` onto the current
  # chunk of output points, interpolating the value of x at each of those times. Stops at the
  # first fixed time that has not yet been reached, or when no fixed times remain.
  #
  # The result of add_fixed_point?/2 is read back with `Nx.to_number/1` rather than being compared
  # structurally (`==`) against a freshly-created tensor: structural equality of tensor structs
  # depends on how the backend represents its data and is not a reliable value comparison.
  @spec add_fixed_point(t(), RungeKutta.interpolate_fn_t()) :: t()
  defp add_fixed_point(%{fixed_times: []} = step, _interpolate_fn) do
    step
  end

  defp add_fixed_point(step, interpolate_fn) do
    [fixed_time | remaining_fixed_times] = step.fixed_times

    if Nx.to_number(add_fixed_point?(fixed_time, step.t_new)) == 1 do
      x_at_fixed_time = interpolate_one_point(fixed_time, step, interpolate_fn)

      step = %{
        step
        | t_new_chunk: [fixed_time | step.t_new_chunk],
          x_new_chunk: [x_at_fixed_time | step.x_new_chunk],
          fixed_times: remaining_fixed_times
      }

      add_fixed_point(step, interpolate_fn)
    else
      step
    end
  end

  # A fixed time is due when it lies before `t_new`, or within @zero_tolerance of it.
  # Returns the result as a tensor (1 for true, 0 for false).
  defnp add_fixed_point?(fixed_time, t_new) do
    before_t_new = Nx.less(fixed_time, t_new)
    at_t_new = Nx.less(Nx.abs(fixed_time - t_new), @zero_tolerance)
    Nx.logical_or(before_t_new, at_t_new)
  end

  # Builds the chunk of output points for the step that was just accepted:
  #   * `:fixed_times` - the fixed output times that have been reached (found via interpolation)
  #   * `1`            - only the computed point itself
  #   * `n > 1`        - `n` equally-spaced points, ending at the computed point
  # The chunk is stored newest-first.
  @spec interpolate(t(), RungeKutta.interpolate_fn_t(), refine_strategy()) :: t()
  defp interpolate(step, interpolate_fn, refine) when refine == :fixed_times do
    step
    |> Map.merge(%{t_new_chunk: [], x_new_chunk: []})
    |> add_fixed_point(interpolate_fn)
  end

  defp interpolate(%{t_new: t_new, x_new: x_new} = step, _interpolate_fn, refine) when refine == 1 do
    %{step | t_new_chunk: [t_new], x_new_chunk: [x_new]}
  end

  defp interpolate(step, interpolate_fn, refine) when refine > 1 do
    # refine + 1 equally-spaced times from t_old to t_new; the slice then discards the first of
    # them (t_old itself):
    times =
      step.t_old
      |> Nx.linspace(step.t_new, n: refine + 1, type: Nx.type(step.x_old))
      |> Nx.slice_along_axis(1, refine, axis: 0)

    x_newest_first = step |> do_interpolation(interpolate_fn, times) |> Enum.reverse()
    t_newest_first = times |> Utils.vector_as_list() |> Enum.reverse()

    %{step | x_new_chunk: x_newest_first, t_new_chunk: t_newest_first}
  end

  # Interpolates the value of x at the single time `t_new`, which is first converted to a tensor
  # of the step's Nx type
  @spec interpolate_one_point(Nx.t(), t(), RungeKutta.interpolate_fn_t()) :: Nx.t()
  defp interpolate_one_point(t_new, step, interpolate_fn) do
    t_as_tensor = Nx.tensor(t_new, type: step.nx_type)

    step
    |> do_interpolation(interpolate_fn, t_as_tensor)
    |> List.first()
  end

  # Runs the interpolation function between the step's old point and its Runge-Kutta
  # interpolation point, for the time(s) in `tadd` (a scalar or a vector), and returns the
  # interpolated columns as a list.
  @spec do_interpolation(t(), RungeKutta.interpolate_fn_t(), Nx.t()) :: [Nx.t()]
  defp do_interpolation(step, interpolate_fn, tadd) do
    last_column =
      case Nx.shape(tadd) do
        {} -> 1 - 1
        {n_times} -> n_times - 1
      end

    t_endpoints = Nx.stack([step.t_old, step.t_new_rk_interpolate])
    x_endpoints = [step.x_old, step.x_new_rk_interpolate] |> Nx.stack() |> Nx.transpose()

    interpolated = interpolate_fn.(t_endpoints, x_endpoints, step.k_vals, tadd)
    Utils.columns_as_list(interpolated, 0, last_column)
  end

  # Hands the latest chunk of points (put back into chronological order) to the output function,
  # if there is one, e.g. for plotting while the simulation is in progress. Whatever the output
  # function returns is recorded as the step's terminal output status.
  @spec call_output_fn(t(), output_fn_t()) :: t()
  defp call_output_fn(step, nil), do: step

  defp call_output_fn(step, output_fn) do
    t_in_order = Enum.reverse(step.t_new_chunk)
    x_in_order = Enum.reverse(step.x_new_chunk)

    %{step | terminal_output: output_fn.(t_in_order, x_in_order)}
  end

  # Evaluates the event function (if there is one) at the new point; e.g., to check whether a
  # bouncing ball has collided with a surface. When the event function reports `:halt`, the new
  # point is replaced by the point at which the event occurred and the step is marked as halted.
  @spec call_event_fn(t(), event_fn_t(), RungeKutta.interpolate_fn_t(), Keyword.t()) :: t()
  defp call_event_fn(step, nil, _interpolate_fn, _opts), do: step

  defp call_event_fn(step, event_fn, interpolate_fn, opts) do
    # Pass opts to event_fn?
    case event_fn.(step.t_new, step.x_new) do
      {:halt, _value} ->
        event_step = compute_new_event_fn_step(step, event_fn, interpolate_fn, opts)
        %{step | terminal_event: :halt, x_new: event_step.x_new, t_new: event_step.t_new}

      {:continue, _value} ->
        step
    end
  end

  # Homes in (via interpolation) on the time at which the event function's value goes to zero,
  # searching between the step's old and new times, and returns the interpolated point at that time.
  @spec compute_new_event_fn_step(t(), event_fn_t(), RungeKutta.interpolate_fn_t(), Keyword.t()) :: ComputedStep.t()
  defp compute_new_event_fn_step(step, event_fn, interpolate_fn, opts) do
    # The function whose zero is sought: the event function's value at the interpolated x for time t
    event_value_at = fn t ->
      {_status, value} = event_fn.(t, interpolate_one_point(t, step, interpolate_fn))
      Nx.to_number(value)
    end

    search_interval = [Nx.to_number(step.t_old), Nx.to_number(step.t_new)]
    root = NonLinearEqnRoot.find_zero(event_value_at, search_interval, only_non_linear_eqn_root_opts(opts))
    t_event = root.x

    %ComputedStep{
      t_new: Nx.tensor(t_event, type: opts[:type]),
      x_new: interpolate_one_point(t_event, step, interpolate_fn),
      k_vals: step.k_vals,
      options_comp: step.options_comp
    }
  end

  # Keeps only those options whose keys are options of NonLinearEqnRoot
  @spec only_non_linear_eqn_root_opts(Keyword.t()) :: Keyword.t()
  defp only_non_linear_eqn_root_opts(opts), do: Keyword.take(opts, NonLinearEqnRoot.option_keys())

  # Computes the error of `x` relative to `t`, scaled by the absolute and relative tolerances.
  #
  # Originally based on
  # [Octave function AbsRelNorm](https://github.com/gnu-octave/octave/blob/default/scripts/ode/private/AbsRel_norm.m)

  # Options
  # * `:norm_control` - Control error relative to norm; i.e., control the error `e` at each step using the norm of the
  #   solution rather than its absolute value.

  # See [Matlab documentation](https://www.mathworks.com/help/matlab/ref/odeset.html#bu2m9z6-NormControl)
  # for a description of norm control.
  @spec abs_rel_norm(Nx.t(), Nx.t(), Nx.t(), float(), float(), Keyword.t()) :: Nx.t()
  defnp abs_rel_norm(t, t_old, x, abs_tolerance, rel_tolerance, opts \\ []) do
    if opts[:norm_control] do
      # Octave code
      # sc = max (AbsTol(:), RelTol * max (sqrt (sumsq (t)), sqrt (sumsq (t_old))));
      # retval = sqrt (sumsq ((t - x))) / sc;

      larger_norm = Nx.max(sum_sq(t), sum_sq(t_old))
      scale = Nx.max(abs_tolerance, rel_tolerance * larger_norm)
      sum_sq(t - x) / scale
    else
      # Octave code:
      # sc = max (AbsTol(:), RelTol .* max (abs (t), abs (t_old)));
      # retval = max (abs (t - x) ./ sc);

      larger_abs = Nx.max(Nx.abs(t), Nx.abs(t_old))
      scale = Nx.max(abs_tolerance, rel_tolerance * larger_abs)
      Nx.reduce_max(Nx.abs(t - x) / scale)
    end
  end

  # The norm of a vector: the square root of the vector's dot product with itself
  @spec sum_sq(Nx.t()) :: Nx.t()
  defnp sum_sq(x), do: Nx.sqrt(Nx.dot(x, x))

  # Builds a vector of zeros with the same length and Nx type as the vector `x`
  # (Is there a better built-in Nx way of doing this?)
  @spec zero_vector(Nx.t()) :: Nx.t()
  defnp zero_vector(x) do
    {n} = Nx.shape(x)
    Nx.broadcast(Nx.tensor(0.0, type: Nx.type(x)), {n})
  end

  # Verifies that each named argument has the expected Nx type, raising an ArgPrecisionError for
  # the first one that does not. This avoids args with mismatched types.
  @spec check_nx_type(Keyword.t(), Nx.Type.t()) :: atom()
  defp check_nx_type(args, expected_nx_type) do
    type_name = fn value -> value |> Nx.type() |> Nx.Type.to_string() end

    mismatch =
      Enum.find(args, fn {_arg_name, arg_value} ->
        type_name.(arg_value) != Atom.to_string(expected_nx_type)
      end)

    case mismatch do
      nil ->
        :ok

      {arg_name, arg_value} ->
        raise ArgPrecisionError,
          invalid_argument: arg_value,
          expected_precision: expected_nx_type,
          actual_precision: type_name.(arg_value),
          argument_name: arg_name
    end
  end

  # The current OS system time, in milliseconds
  @spec timestamp_ms() :: pos_integer()
  defp timestamp_ms() do
    :os.system_time(:millisecond)
  end
end