lib/integrator/non_linear_eqn_root.ex

defmodule Integrator.NonLinearEqnRoot do
  @moduledoc """
  Finds the roots (i.e., zeros) of a non-linear equation.
  Based on [`fzero.m`](https://github.com/gnu-octave/octave/blob/default/scripts/optimization/fzero.m)
  from Octave.

  This is essentially the ACM algorithm 748: Enclosing Zeros of Continuous Functions
  due to Alefeld, Potra and Shi, ACM Transactions on Mathematical Software, Vol. 21,
  No. 3, September 1995. Although the workflow is the same, the structure of
  the algorithm has been transformed non-trivially, and the algorithm has also been
  slightly modified.
  """

  import Integrator.Utils, only: [sign: 1]
  alias Integrator.Utils

  alias Integrator.NonLinearEqnRoot.{
    BracketingFailureError,
    InvalidInitialBracketError,
    MaxFnEvalsExceededError,
    MaxIterationsExceededError
  }

  # The function whose zero is sought: called with a float, returns a float.
  @type zero_fn_t :: (float() -> float())

  # Optional per-iteration output function. call_output_fn/2 invokes it as
  # `output_fn.(z.x, z)`, i.e. with the latest estimate and the full step struct
  # (not a second float, as the previous spec claimed). Its return value is ignored.
  @type output_fn_t :: (float(), t() -> any())

  # Strategies understood by interpolate/2 for producing a new candidate point:
  @type interpolation_type ::
          :bisect
          | :double_secant
          | :inverse_cubic_interpolation
          | :quadratic_interpolation_plus_newton
          | :secant

  # Returned by converged?/3 and bracket/1; :halt stops the iteration loop.
  @type convergence_status :: :halt | :continue

  # Kind of step compute_iteration/1 takes next (see that function for the mapping).
  @type iter_type :: 1 | 2 | 3 | 4 | 5

  # State of the root finder. The same struct is threaded through every iteration
  # and is what find_zero/4 returns once the loop halts.
  @type t :: %__MODULE__{
          # Current bracket endpoints. find_zero/4 orders them so that a <= b, and
          # bracket/1 replaces one of them with the new point c on each iteration:
          a: float() | nil,
          b: float() | nil,
          # Candidate point from compute_iteration/1, clamped by
          # adjust_if_too_close_to_a_or_b/3 before zero_fn is evaluated there:
          c: float() | nil,
          # d starts at u, is set to u again by the first compute_iteration/1 step, and
          # afterwards holds the endpoint most recently replaced by bracket/1.
          # e starts at u and takes the previous d in check_for_non_monotonicity/1.
          # Together with a and b they feed the interpolation steps:
          d: float() | nil,
          e: float() | nil,
          # Whichever of a and b currently has the smaller |f| (see update_u/1):
          u: float() | nil,
          #
          # Function evaluations; e.g., fb is fn(b):
          # (exception: check_for_non_monotonicity/1 may set fe to fc, deliberately
          # duplicating a known value to disable inverse cubic interpolation)
          fa: float() | nil,
          fb: float() | nil,
          fc: float() | nil,
          fd: float() | nil,
          fe: float() | nil,
          fu: float() | nil,
          #
          # x (and fx) are the actual found values (i.e., fx should be very close to zero):
          # fn_eval_new_point/3 sets them to the latest evaluated point; find_zero/4
          # sets them to u/fu when it halts before iterating.
          x: float() | nil,
          fx: float() | nil,
          #
          # Bracket-width threshold: when a bisection step (iter_type 5) is pending and
          # b - a <= mu_ba, skip_bisection_if_successful_reduction/1 skips it. Recomputed
          # as @initial_mu * (b - a).
          mu_ba: float() | nil,
          #
          # Number of zero_fn calls so far (including any the caller reported via
          # find_zero/4's fn_evals argument) and number of iterations performed:
          fn_eval_count: integer(),
          iteration_count: integer(),
          # Change iter_type to a more descriptive atom later (possibly?):
          # 1 = initial secant/bisection, 2 and 3 = inverse cubic or quadratic + Newton,
          # 4 = double secant, 5 = bisection (see compute_iteration/1).
          iter_type: iter_type()
        }

  defstruct [
    # Bracket endpoints, candidate point, and interpolation history points:
    :a,
    :b,
    :c,
    :d,
    :e,
    :u,
    #
    # Function evaluations; e.g., fb is fn(b):
    :fa,
    :fb,
    :fc,
    :fd,
    :fe,
    :fu,
    #
    # x (and fx) are the actual found values (i.e., fx should be very close to zero):
    :x,
    :fx,
    #
    # Threshold used to decide whether the next bisection step can be skipped:
    :mu_ba,
    #
    fn_eval_count: 0,
    iteration_count: 0,
    # Change iter_type to a more descriptive atom later (possibly?):
    iter_type: 1
  ]

  options = [
    max_iterations: [
      type: :integer,
      doc: "The maximum allowed number of iterations when finding a root.",
      default: 1000
    ],
    max_fn_eval_count: [
      type: :integer,
      doc: "The maximum allowed number of function evaluations when finding a root.",
      default: 1000
    ],
    type: [
      type: {:in, [:f32, :f64]},
      doc: "The Nx type.",
      default: :f64
    ],
    machine_eps: [
      type: :float,
      doc: "The machine epsilon. Defaults to Nx.Constants.epsilon/1 for this Nx type."
    ],
    tolerance: [
      type: :float,
      doc: "The tolerance for the convergence when finding a root. Defaults to Nx.Constants.epsilon/1 for this Nx type."
    ],
    nonlinear_eqn_root_output_fn: [
      # Ideally the type for this should be set to a function with arity 2, but I could not get that to work:
      # NOTE(review): newer NimbleOptions versions may accept `{:or, [{:fun, 2}, nil]}` here — worth retrying.
      type: :any,
      doc:
        "An output function to call so intermediate results can be retrieved when finding a root. " <>
          "It is called after every iteration with two arguments: the latest estimate and the current step struct.",
      default: nil
    ]
  ]

  @options_schema NimbleOptions.new!(options)
  @type options_t() :: unquote(NimbleOptions.option_typespec(@options_schema))

  @doc "Returns the compiled `NimbleOptions` schema used to validate the options of `find_zero/4`."
  @spec options_schema() :: NimbleOptions.t()
  def options_schema, do: @options_schema

  # The fixed reduction factor (Octave's `mu`): mu_ba is always set to this fraction of
  # the current bracket width, and skip_bisection_if_successful_reduction/1 skips the
  # pending bisection step once the bracket has shrunk to within mu_ba.
  @initial_mu 0.5

  @doc """
  Finds a zero for a function in an interval `[a, b]` (if the 2nd argument is a list) or
  in the vicinity of `a` (if the 2nd argument is a float).

  When a single starting point is given, a second point where the function does not have
  the same sign is searched for first, and the resulting pair is used as the initial bracket.

  If either endpoint of the initial bracket is already an exact zero, that endpoint is
  returned immediately.

  Raises `InvalidInitialBracketError` if the function has the same (non-zero) sign at both
  endpoints, and `MaxIterationsExceededError` / `MaxFnEvalsExceededError` when the
  corresponding limits are exceeded.

  ## Options

  #{NimbleOptions.docs(@options_schema)}

  """

  # I can't get this to work for some unknown reason:
  # @spec find_zero(zero_fn_t(), [float()] | float(), options_t()) :: t()

  # `fn_evals` is the number of function evaluations the caller has already spent
  # (e.g. while searching for a second starting point); it is added to fn_eval_count.
  @spec find_zero(zero_fn_t(), [float()] | float(), Keyword.t(), integer()) :: t()
  def find_zero(zero_fn, initial_values, opts \\ [], fn_evals \\ 0)

  def find_zero(zero_fn, [a, b], opts, fn_evals) do
    opts = opts |> NimbleOptions.validate!(@options_schema) |> merge_default_opts()

    fa = zero_fn.(a)
    fb = zero_fn.(b)
    fn_eval_count = 2 + fn_evals

    # Order the endpoints so that a <= b:
    {a, b, fa, fb} = if b < a, do: {b, a, fb, fa}, else: {a, b, fa, fb}

    if sign(fa) * sign(fb) > 0.0 do
      raise InvalidInitialBracketError, step: initial_step(a, b, fa, fb, fn_eval_count)
    end

    # As in Octave's fzero.m, an endpoint that is already an exact zero collapses the
    # bracket onto itself, so the convergence check below halts right away. Without this
    # (and assuming Utils.sign/1 returns 0 for 0, as Octave's sign does), the first
    # bracket/1 call could find no strict sign change and raise BracketingFailureError.
    {a, b, fa, fb} =
      cond do
        fa == 0 -> {a, a, fa, fa}
        fb == 0 -> {b, b, fb, fb}
        true -> {a, b, fa, fb}
      end

    z = initial_step(a, b, fa, fb, fn_eval_count)

    case converged?(z, opts[:machine_eps], opts[:tolerance]) do
      :continue -> iterate(z, :continue, zero_fn, opts)
      :halt -> %{z | x: z.u, fx: z.fu}
    end
  end

  def find_zero(zero_fn, solo_point, opts, fn_evals) when is_number(solo_point) do
    second_point = find_2nd_starting_point(zero_fn, solo_point)

    # Keep any evaluations the caller already reported, plus those spent searching:
    find_zero(zero_fn, [solo_point, second_point.b], opts, fn_evals + second_point.fn_eval_count)
  end

  # Builds the initial step for the ordered bracket [a, b]: u is the endpoint with the
  # smaller |f| (ties go to b, as in update_u/1), d and e start at u, and mu_ba is the
  # first bisection-skip threshold.
  @spec initial_step(float(), float(), float(), float(), integer()) :: t()
  defp initial_step(a, b, fa, fb, fn_eval_count) do
    {u, fu} = if abs(fa) < abs(fb), do: {a, fa}, else: {b, fb}

    %__MODULE__{
      a: a,
      b: b,
      d: u,
      e: u,
      u: u,
      #
      fa: fa,
      fb: fb,
      fd: fu,
      fe: fu,
      fu: fu,
      #
      fn_eval_count: fn_eval_count,
      iter_type: 1,
      mu_ba: @initial_mu * (b - a)
    }
  end

  @doc "Returns the current bracket `[a, b]` that encloses the zero."
  @spec bracket_x(t()) :: [float()]
  def bracket_x(z), do: [z.a, z.b]

  @doc "Returns the function values `[f(a), f(b)]` at the current bracket endpoints."
  @spec bracket_fx(t()) :: [float()]
  def bracket_fx(z), do: [z.fa, z.fb]

  @doc """
  Returns the keys of all options accepted by `find_zero/4`, including those that have no
  default value (`:machine_eps` and `:tolerance`).
  """
  @spec option_keys() :: [atom()]
  # Read the keys straight from the schema: validating an empty list only yields the keys
  # that have defaults, which silently dropped :machine_eps and :tolerance.
  def option_keys, do: Keyword.keys(@options_schema.schema)

  # ===========================================================================
  # Private functions below here:

  # Main root-finding loop. Each pass:
  #   1. computes a candidate point c (compute_iteration/1),
  #   2. keeps c away from the bracket ends (adjust_if_too_close_to_a_or_b/3),
  #   3. evaluates zero_fn at c, enforcing the iteration/evaluation limits,
  #   4. notes whether monotonicity was broken, then re-brackets around c,
  #   5. updates the bisection-skip bookkeeping and u, and calls the output function.
  # It stops when bracket/1 lands on an exact zero or converged?/3 reports that the
  # bracket is narrow enough.
  @spec iterate(t(), atom(), zero_fn_t(), Keyword.t()) :: t()
  defp iterate(z, :halt, _zero_fn, _opts), do: z

  defp iterate(z, _status, zero_fn, opts) do
    eps = opts[:machine_eps]
    tol = opts[:tolerance]

    {bracket_status, bracketed} =
      z
      |> compute_iteration()
      |> adjust_if_too_close_to_a_or_b(eps, tol)
      |> fn_eval_new_point(zero_fn, opts)
      |> check_for_non_monotonicity()
      |> bracket()

    updated =
      bracketed
      |> skip_bisection_if_successful_reduction()
      |> update_u()
      |> call_output_fn(opts[:nonlinear_eqn_root_output_fn])

    next_status = halt?(bracket_status, converged?(updated, eps, tol))

    iterate(updated, next_status, zero_fn, opts)
  end

  # Merges the status reported by bracket/1 with the one from converged?/3:
  # the loop must stop as soon as either of them says :halt.
  @spec halt?(convergence_status(), convergence_status()) :: convergence_status()
  defp halt?(from_bracket, from_convergence) do
    if :halt in [from_bracket, from_convergence], do: :halt, else: :continue
  end

  # In the first step, fa and fb count as "comparable" when neither magnitude exceeds the
  # other by more than this factor (the `1e3` in Octave's fzero.m).
  @secant_magnitude_ratio 1000

  # Computes the next candidate point c according to z.iter_type, and picks the step
  # type for the following iteration:
  #
  #   1    -> secant if |fa| and |fb| are comparable, else bisection; next: 5
  #   2, 3 -> inverse cubic or quadratic + Newton (compute_iteration_two_or_three/1); next: +1
  #   4    -> double secant, or bisection if that lands too far from u; next: 5
  #   5    -> bisection; next: 2
  @spec compute_iteration(t()) :: t()
  defp compute_iteration(%{iter_type: 1} = z) do
    # Octave:
    #   if (abs (fa) <= 1e3*abs (fb) && abs (fb) <= 1e3*abs (fa))
    #     c = u - (a - b) / (fa - fb) * fu;   # secant step
    #   else
    #     c = 0.5*(a + b);                     # bisection step
    #   endif
    #   d = u; fd = fu;
    #   iter_type = 5;
    comparable? =
      abs(z.fa) <= @secant_magnitude_ratio * abs(z.fb) and
        abs(z.fb) <= @secant_magnitude_ratio * abs(z.fa)

    step = if comparable?, do: :secant, else: :bisect

    %{z | c: interpolate(z, step), d: z.u, fd: z.fu, iter_type: 5}
  end

  defp compute_iteration(%{iter_type: iter_type} = z) when iter_type in [2, 3] do
    compute_iteration_two_or_three(z)
  end

  defp compute_iteration(%{iter_type: 4} = z) do
    # Octave: double secant step, c = u - 2*(b - a)/(fb - fa)*fu, replaced by the
    # bisection point if abs(c - u) > 0.5*(b - a); then iter_type = 5.
    double_secant_c = interpolate(z, :double_secant)

    c =
      case too_far?(double_secant_c, z) do
        true -> interpolate(z, :bisect)
        false -> double_secant_c
      end

    %{z | iter_type: 5, c: c}
  end

  defp compute_iteration(%{iter_type: 5} = z) do
    # Octave: bisection step, c = 0.5 * (b + a); then iter_type = 2.
    %{z | iter_type: 2, c: interpolate(z, :bisect)}
  end

  # Iteration types 2 and 3 (Octave's `case {2, 3}`):
  #
  #   * when Utils.unique/1 reports four distinct values among fa, fb, fd and fe, try
  #     inverse cubic interpolation through a, b, d and e;
  #   * if the values are not all distinct, OR the inverse cubic point lies outside the
  #     bracket (sign(c - a) * sign(c - b) > 0), use quadratic interpolation plus
  #     Newton steps instead.
  #
  # The outside-the-bracket test must look at the freshly computed inverse cubic point.
  # The earlier version tested the stale z.c in an unreachable branch, so the quadratic
  # fallback never happened.
  #
  # Advances iter_type by one (2 -> 3, 3 -> 4).
  @spec compute_iteration_two_or_three(t()) :: t()
  defp compute_iteration_two_or_three(z) do
    distinct_fvals = length(Utils.unique([z.fa, z.fb, z.fd, z.fe]))

    cubic_c =
      if distinct_fvals == 4 do
        interpolate(z, :inverse_cubic_interpolation)
      else
        nil
      end

    # `or` short-circuits, so cubic_c is only inspected when it was computed:
    c =
      if distinct_fvals < 4 or sign(cubic_c - z.a) * sign(cubic_c - z.b) > 0 do
        interpolate(z, :quadratic_interpolation_plus_newton)
      else
        cubic_c
      end

    %{z | iter_type: z.iter_type + 1, c: c}
  end

  # Relative offsets tried, in order, when searching for a second starting point. The
  # signs alternate (matching the list in Octave's fzero.m), so the search widens on
  # both sides of the starting point.
  @search_values [-0.01, 0.025, -0.05, 0.10, -0.25, 0.50, -1.0, 2.5, -5.0, 10.0, -50.0, 100.0, -500.0, 1000.0]

  defmodule SearchFor2ndPoint do
    @moduledoc false
    defstruct [:a, :fa, :b, :fb, :fn_eval_count]
  end

  @type search_for_2nd_point_t :: %SearchFor2ndPoint{
          a: float() | nil,
          b: float() | nil,
          #
          # Function evaluations; e.g., fb is fn(b):
          fa: float() | nil,
          fb: float() | nil,
          #
          fn_eval_count: integer()
        }

  # Given a single starting point `a`, looks for a second point `b` where f(a) * f(b)
  # does not have a positive sign, trying `center + center * offset` for each offset in
  # @search_values.
  #
  # `a` and `fa` always refer to the caller's starting point. For |a| < 0.001, only the
  # search *center* moves to ±0.1 (0.1 when a == 0). Otherwise the sign test would
  # compare against f(center) while find_zero/4 brackets with the original `a`, which
  # can produce an invalid bracket.
  #
  # If no offset yields a sign change, the last point tried is returned. The bracket
  # check in find_zero/4 then raises InvalidInitialBracketError, instead of this search
  # crashing with a MatchError.
  @spec find_2nd_starting_point(zero_fn_t(), float()) :: search_for_2nd_point_t()
  defp find_2nd_starting_point(zero_fn, a) do
    fa = zero_fn.(a)

    # For very small values, switch to absolute rather than relative search:
    center =
      cond do
        abs(a) >= 0.001 -> a
        a == 0 -> 0.1
        true -> sign(a) * 0.1
      end

    x = %SearchFor2ndPoint{a: a, fa: fa, b: nil, fb: nil, fn_eval_count: 1}

    # Search in an ever-widening range around the center:
    searching_for_2nd_point(:continue, zero_fn, x, center, @search_values)
  end

  @spec searching_for_2nd_point(atom(), zero_fn_t(), search_for_2nd_point_t(), float(), [float()]) ::
          search_for_2nd_point_t()
  defp searching_for_2nd_point(:found, _zero_fn, x, _center, _search_values), do: x

  # Every offset tried without a sign change: hand back the last point for validation.
  defp searching_for_2nd_point(:continue, _zero_fn, x, _center, []), do: x

  defp searching_for_2nd_point(:continue, zero_fn, x, center, [search | rest_of_search_values]) do
    b = center + center * search
    fb = zero_fn.(b)
    x = %{x | b: b, fb: fb, fn_eval_count: x.fn_eval_count + 1}
    status = if sign(x.fa) * sign(fb) <= 0, do: :found, else: :continue
    searching_for_2nd_point(status, zero_fn, x, center, rest_of_search_values)
  end

  # Returns a new candidate point using the requested strategy:
  #
  #   :quadratic_interpolation_plus_newton - zero of the quadratic through (a, fa),
  #       (b, fb) and (d, fd), refined with up to z.iter_type Newton steps
  #   :inverse_cubic_interpolation - inverse cubic interpolation through a, b, d and e
  #   :double_secant - u - 2 * (b - a) / (fb - fa) * fu
  #   :bisect        - midpoint of [a, b]
  #   :secant        - u - (a - b) / (fa - fb) * fu
  @spec interpolate(t(), interpolation_type()) :: float()
  defp interpolate(z, :quadratic_interpolation_plus_newton) do
    a0 = z.fa
    a1 = (z.fb - z.fa) / (z.b - z.a)
    a2 = ((z.fd - z.fb) / (z.d - z.b) - a1) / (z.d - z.a)

    ## Modification 1: this is simpler and does not seem to be worse.
    c0 = z.a - a0 / a1

    if a2 != 0 do
      # Newton iterations on p(c) = a0 + (a1 + a2 * (c - b)) * (c - a).
      # When the derivative vanishes, Octave resets c to the Modification 1 point and
      # breaks out of the loop. :halt stops here the same way, instead of applying more
      # Newton steps from the reset point.
      Enum.reduce_while(1..z.iter_type, c0, fn _i, c ->
        pc = a0 + (a1 + a2 * (c - z.b)) * (c - z.a)
        pdc = a1 + a2 * (2 * c - z.a - z.b)

        if pdc == 0 do
          {:halt, c0}
        else
          {:cont, c - pc / pdc}
        end
      end)
    else
      c0
    end
  end

  defp interpolate(z, :inverse_cubic_interpolation) do
    # Only reached when fa, fb, fd and fe are pairwise distinct
    # (see compute_iteration_two_or_three/1), so none of the denominators is zero.
    q11 = (z.d - z.e) * z.fd / (z.fe - z.fd)
    q21 = (z.b - z.d) * z.fb / (z.fd - z.fb)
    q31 = (z.a - z.b) * z.fa / (z.fb - z.fa)
    d21 = (z.b - z.d) * z.fd / (z.fd - z.fb)
    d31 = (z.a - z.b) * z.fb / (z.fb - z.fa)

    q22 = (d21 - q11) * z.fb / (z.fe - z.fb)
    q32 = (d31 - q21) * z.fa / (z.fd - z.fa)
    d32 = (d31 - q21) * z.fd / (z.fd - z.fa)
    q33 = (d32 - q22) * z.fa / (z.fe - z.fa)

    z.a + q31 + q32 + q33
  end

  defp interpolate(z, :double_secant) do
    z.u - 2.0 * (z.b - z.a) / (z.fb - z.fa) * z.fu
  end

  defp interpolate(z, :bisect) do
    0.5 * (z.b + z.a)
  end

  defp interpolate(z, :secant) do
    z.u - (z.a - z.b) / (z.fa - z.fb) * z.fu
  end

  # True when the proposed point c is more than half the bracket width away from u;
  # compute_iteration/1 then discards a double-secant step in favour of bisection.
  @spec too_far?(float(), t()) :: boolean()
  defp too_far?(c, z) do
    half_width = 0.5 * (z.b - z.a)
    abs(c - z.u) > half_width
  end

  # Evaluates zero_fn at the (already clamped) candidate point c, records it as the
  # latest estimate x/fx, and increments both the iteration and evaluation counters.
  #
  # The limits are checked *before* calling zero_fn, so no evaluation is wasted on a
  # step whose result would be thrown away by the raise.
  #
  # Raises MaxIterationsExceededError or MaxFnEvalsExceededError. Both carry
  # `step: z`, the step as it was before this evaluation.
  @spec fn_eval_new_point(t(), zero_fn_t(), Keyword.t()) :: t()
  defp fn_eval_new_point(z, zero_fn, opts) do
    # Perhaps move the incrementing of the iteration count elsewhere?
    iteration_count = z.iteration_count + 1
    fn_eval_count = z.fn_eval_count + 1

    if iteration_count > opts[:max_iterations] do
      raise MaxIterationsExceededError, step: z, iteration_count: iteration_count
    end

    if fn_eval_count > opts[:max_fn_eval_count] do
      raise MaxFnEvalsExceededError, step: z, fn_eval_count: fn_eval_count
    end

    fc = zero_fn.(z.c)

    %{
      z
      | fc: fc,
        x: z.c,
        fx: fc,
        fn_eval_count: fn_eval_count,
        iteration_count: iteration_count
    }
  end

  # Modification 2: skip inverse cubic interpolation if nonmonotonicity is detected.
  #
  # If f(c) does not lie strictly between f(a) and f(b), the new point broke
  # monotonicity. In that case fe is set to fc. Once bracket/1 moves c into a or b, fe
  # duplicates fa or fb, so the distinct-values check in
  # compute_iteration_two_or_three/1 rules out inverse cubic interpolation. Otherwise
  # the interpolation history shifts and e takes over the previous d.
  @spec check_for_non_monotonicity(t()) :: t()
  defp check_for_non_monotonicity(z) do
    case sign(z.fc - z.fa) * sign(z.fc - z.fb) >= 0 do
      true -> %{z | fe: z.fc}
      false -> %{z | e: z.d, fe: z.fd}
    end
  end

  # Keeps the candidate point c at least `delta` inside both ends of the bracket, where
  # delta scales with |u| * machine_eps plus the absolute tolerance. When the bracket is
  # no wider than 2 * delta, c is simply the midpoint.
  @spec adjust_if_too_close_to_a_or_b(t(), float(), float()) :: t()
  defp adjust_if_too_close_to_a_or_b(z, machine_eps, tolerance) do
    delta = 2 * 0.7 * (2 * abs(z.u) * machine_eps + tolerance)
    lower = z.a + delta
    upper = z.b - delta

    new_c =
      cond do
        z.b - z.a <= 2 * delta -> (z.a + z.b) / 2
        true -> max(lower, min(upper, z.c))
      end

    %{z | c: new_c}
  end

  # Re-brackets the zero around the freshly evaluated point c:
  #   * f changes sign between a and c -> c becomes the new b (old b saved in d/fd);
  #   * f changes sign between c and b -> c becomes the new a (old a saved in d/fd);
  #   * f(c) is exactly zero           -> collapse the bracket onto c and :halt.
  # Any other outcome means the bracket was lost. That should never happen, and it
  # raises BracketingFailureError.
  @spec bracket(t()) :: {convergence_status(), t()}
  defp bracket(z) do
    cond do
      sign(z.fa) * sign(z.fc) < 0 ->
        {:continue, %{z | d: z.b, fd: z.fb, b: z.c, fb: z.fc}}

      sign(z.fb) * sign(z.fc) < 0 ->
        {:continue, %{z | d: z.a, fd: z.fa, a: z.c, fa: z.fc}}

      z.fc == 0.0 ->
        {:halt, %{z | a: z.c, b: z.c, fa: z.fc, fb: z.fc}}

      true ->
        raise BracketingFailureError, step: z
    end
  end

  # Invokes the user's output function (if one was configured) with the latest estimate
  # z.x and the whole step struct, then returns z unchanged. The output function's
  # return value is ignored. The spec now admits nil, which the first clause handles.
  @spec call_output_fn(t(), output_fn_t() | nil) :: t()
  defp call_output_fn(z, nil = _output_fn), do: z

  defp call_output_fn(z, output_fn) do
    _ = output_fn.(z.x, z)
    z
  end

  # Points u/fu at whichever bracket endpoint has the smaller |f|; ties go to b.
  # Octave:
  #   if (abs (fa) < abs (fb))
  #     u = a; fu = fa;
  #   else
  #     u = b; fu = fb;
  #   endif
  @spec update_u(t()) :: t()
  defp update_u(z) do
    {u, fu} =
      if abs(z.fa) < abs(z.fb),
        do: {z.a, z.fa},
        else: {z.b, z.fb}

    %{z | u: u, fu: fu}
  end

  # Returns :halt once the bracket [a, b] is no wider than
  # 2 * (2 * |u| * machine_eps + tolerance), a mixed relative/absolute tolerance
  # around u, and :continue otherwise.
  @spec converged?(t(), float(), float()) :: convergence_status()
  defp converged?(z, machine_eps, tolerance) do
    max_width = 2 * (2 * abs(z.u) * machine_eps + tolerance)
    if z.b - z.a > max_width, do: :continue, else: :halt
  end

  # Octave:
  #   if (iter_type == 5 && (b - a) <= mba)
  #     iter_type = 2;
  #   endif
  #   if (iter_type == 2)
  #     mba = mu * (b - a);
  #   endif
  #
  # If a bisection step (iter_type 5) is pending but the bracket has already shrunk to
  # at most mu_ba, the bisection is skipped and interpolation resumes (iter_type 2).
  # Whenever iter_type is 2, mu_ba is recomputed from the current width. In this port,
  # Octave's `mu` is the module constant @initial_mu, so using it here (rather than
  # scaling the previous mu_ba) is the faithful translation.
  @spec skip_bisection_if_successful_reduction(t()) :: t()
  defp skip_bisection_if_successful_reduction(z) do
    iter_type =
      if z.iter_type == 5 and z.b - z.a <= z.mu_ba, do: 2, else: z.iter_type

    mu_ba =
      if iter_type == 2, do: @initial_mu * (z.b - z.a), else: z.mu_ba

    %{z | iter_type: iter_type, mu_ba: mu_ba}
  end

  # ---------------------------------------
  # Option handling

  # Fills in :tolerance and :machine_eps when the caller did not supply them. Both
  # default to Utils.epsilon/1 for the configured :type, and caller-supplied values are
  # left untouched. An alternative default would be
  # Nx.to_number(Nx.Constants.epsilon(opts[:type])).
  @spec merge_default_opts(Keyword.t()) :: Keyword.t()
  defp merge_default_opts(opts) do
    epsilon_for_type = fn -> Utils.epsilon(opts[:type]) end

    opts
    |> Keyword.put_new_lazy(:tolerance, epsilon_for_type)
    |> Keyword.put_new_lazy(:machine_eps, epsilon_for_type)
  end
end