lib/exgboost/training/callback.ex

defmodule EXGBoost.Training.Callback do
  @moduledoc """
  Callbacks are a mechanism to hook into the training process and perform custom actions.

  Callbacks are structs with the following fields:
  * `event` - the event that triggers the callback
  * `fun` - the function to call when the callback is triggered
  * `name` - the name of the callback
  * `init_state` - the initial state of the callback

  The following events are supported:
  * `:before_training` - called before the training starts
  * `:after_training` - called after the training ends
  * `:before_iteration` - called before each iteration
  * `:after_iteration` - called after each iteration

  The callback function is called with the following arguments:
  * `state` - the current training state

  The callback function should return one of the following:
  * `{:cont, state}` - continue training with the given state
  * `{:halt, state}` - stop training with the given state

  The following callbacks are provided in the `EXGBoost.Training.Callback` module:
  * `lr_scheduler` - sets the learning rate for each iteration
  * `early_stop` - performs early stopping
  * `eval_metrics` - evaluates metrics on the training and evaluation sets
  * `monitor_metrics` - prints evaluation metrics

  Callbacks can be added to the training process by passing them to `EXGBoost.Training.train/2`.

  ## Example

  ```elixir
  # Callback to perform setup before training
  setup_fn = fn state ->
    updated_state = put_in(state, [:meta_vars,:early_stop], %{best: 1, since_last_improvement: 0, mode: :max, patience: 5})
    {:cont, updated_state}
  end

  setup_callback = Callback.new(:before_training, setup_fn)
  ```

  """
  alias EXGBoost.Training.State
  # `:event` and `:fun` are mandatory; `:name` and `:init_state` default to `nil`
  # when the struct is built directly rather than through `new/2` or `new/4`.
  @enforce_keys [:event, :fun]
  defstruct [:event, :fun, :name, :init_state]

  @typedoc """
  A training callback.

  * `event` - the training event that triggers the callback
  * `fun` - the function invoked with the current training state
  * `name` - an optional name for the callback (`nil` when unnamed)
  * `init_state` - the initial state of the callback
  """
  @type t :: %__MODULE__{
          event: :before_training | :after_training | :before_iteration | :after_iteration,
          fun: (State.t() -> {:cont, State.t()} | {:halt, State.t()}),
          name: atom(),
          init_state: map() | nil
        }

  @doc """
  Factory for a new callback without a name or an initial state.

  Equivalent to calling `EXGBoost.Training.Callback.new/4` with `nil` as the name and an
  empty map as the initial state. See `EXGBoost.Training.Callback.new/4` for more details.
  """
  @spec new(
          event :: :before_training | :after_training | :before_iteration | :after_iteration,
          fun :: (State.t() -> {:cont, State.t()} | {:halt, State.t()})
        ) :: %__MODULE__{}
  def new(event, fun) do
    new(event, fun, nil, %{})
  end

  @doc """
  Factory for a new callback with a name and an initial state.

  ## Arguments

  * `event` - one of `:before_training`, `:after_training`, `:before_iteration` or
    `:after_iteration`
  * `fun` - the function to call when the callback is triggered
  * `name` - an atom naming the callback (`nil` is accepted, since `nil` is an atom)
  * `init_state` - a map holding the initial state of the callback

  Raises `FunctionClauseError` if `event` is not one of the supported events, `name` is not
  an atom, or `init_state` is not a map.
  """
  @spec new(
          event :: :before_training | :after_training | :before_iteration | :after_iteration,
          fun :: (State.t() -> {:cont, State.t()} | {:halt, State.t()}),
          name :: atom(),
          init_state :: map()
        ) :: %__MODULE__{}
  def new(event, fun, name, %{} = init_state)
      when event in [:before_training, :after_training, :before_iteration, :after_iteration] and
             is_atom(name) do
    %__MODULE__{event: event, fun: fun, name: name, init_state: init_state}
  end

  @doc """
  A callback that updates the booster's learning rate on every invocation.

  Expects `state.meta_vars` to contain `lr_scheduler: %{learning_rates: learning_rates}`, where
  `learning_rates` is either:

  * a list, indexed by the current iteration via `Enum.at/2` (an iteration past the end of the
    list therefore yields `nil` as the learning rate), or
  * a function that is called with the current iteration and returns the learning rate.

  The resulting value is applied with `EXGBoost.Booster.set_params/2` and training always
  continues: the return value is `{:cont, state}` with the updated booster.
  """
  def lr_scheduler(
        %State{
          iteration: iteration,
          booster: booster,
          meta_vars: %{lr_scheduler: %{learning_rates: schedule}}
        } = state
      ) do
    rate =
      case schedule do
        rates when is_list(rates) -> Enum.at(rates, iteration)
        rate_fun -> rate_fun.(iteration)
      end

    updated_booster = EXGBoost.Booster.set_params(booster, learning_rate: rate)

    {:cont, %{state | booster: updated_booster}}
  end

  # TODO: Ideally this would be generalized like it is in Axon to allow generic monitoring of metrics,
  # but for now we'll just do early stopping

  @doc """
  A callback function that performs early stopping.

  Requires that `state.meta_vars` contains an `:early_stop` map with the following keys:

  * `target_eval` - the evaluation set to monitor. It must be a key of `state.metrics`.
  * `target_metric` - the metric to monitor. It must be a key of `state.metrics[target_eval]`.
  * `mode` - either `:min` or `:max`, indicating whether the metric should be
     minimized or maximized.
  * `patience` - the number of iterations to wait for the metric to improve before stopping.
  * `since_last_improvement` - the number of iterations since the metric last improved.
  * `best` - the best value of the metric seen so far, or `nil` if none has been recorded.

  ## Return value

  * When the metric improves, `best` is updated, `since_last_improvement` is reset to `0`, the
    current iteration and score are recorded on the booster (both on the struct and through
    `EXGBoost.Booster.set_attr/2`), and `{:cont, state}` is returned.
  * Otherwise `since_last_improvement` is incremented. If it was already greater than or equal
    to `patience` before the increment, `{:halt, state}` is returned; if not, `{:cont, state}`.
    The best iteration and score previously recorded on the booster are left untouched.

  Raises `ArgumentError` if `target_eval` or `target_metric` cannot be found in `state.metrics`.
  """
  def early_stop(
        %State{
          booster: bst,
          meta_vars:
            %{
              early_stop: %{
                best: best,
                patience: patience,
                target_metric: target_metric,
                target_eval: target_eval,
                mode: mode,
                since_last_improvement: since_last_improvement
              }
            } = meta_vars,
          metrics: metrics
        } = state
      ) do
    if not Map.has_key?(metrics, target_eval) do
      raise ArgumentError,
            "target eval_set #{inspect(target_eval)} not found in metrics #{inspect(metrics)}"
    end

    if not Map.has_key?(metrics[target_eval], target_metric) do
      raise ArgumentError,
            "target metric #{inspect(target_metric)} not found in metrics #{inspect(metrics)}"
    end

    cur_criteria_value = metrics[target_eval][target_metric]

    # A `nil` best means nothing has been recorded yet, so the first value always counts as an
    # improvement. The explicit check is needed because comparing a number with `nil` follows
    # term ordering rather than raising.
    improved? =
      case mode do
        :min -> best == nil or cur_criteria_value < best
        :max -> best == nil or cur_criteria_value > best
      end

    if improved? do
      updated_meta_vars =
        meta_vars
        |> put_in([:early_stop, :best], cur_criteria_value)
        |> put_in([:early_stop, :since_last_improvement], 0)

      bst =
        bst
        |> struct(best_iteration: state.iteration, best_score: cur_criteria_value)
        |> EXGBoost.Booster.set_attr(
          best_iteration: state.iteration,
          best_score: cur_criteria_value
        )

      {:cont, %{state | meta_vars: updated_meta_vars, booster: bst}}
    else
      updated_meta_vars =
        put_in(meta_vars, [:early_stop, :since_last_improvement], since_last_improvement + 1)

      # The booster is deliberately not modified here: the current iteration did not improve
      # the metric, so it must not replace the recorded best iteration and best score.
      stalled_state = %{state | meta_vars: updated_meta_vars}

      if since_last_improvement >= patience do
        {:halt, stalled_state}
      else
        {:cont, stalled_state}
      end
    end
  end

  @doc """
  A callback function that evaluates metrics on the given evaluation sets.

  Requires that the following exist in the `state.meta_vars` that is passed to the callback:
   * eval_metrics:
      * evals: the evaluation sets passed to `EXGBoost.Booster.eval_set/3`
      * filter: a function given to `Map.filter/2`. It receives each
      `{eval_name, metrics}` entry as a tuple, where `metrics` is a map of metric name to
      value, and returns a truthy value if the entry should be kept in the results

  The results are grouped into a map of `eval_name => %{metric_name => value}` which replaces
  `state.metrics`. Always returns `{:cont, state}`.
  """
  def eval_metrics(
        %State{
          booster: bst,
          iteration: iter,
          meta_vars: %{eval_metrics: %{evals: evals, filter: filter}}
        } = state
      ) do
    metrics =
      bst
      |> EXGBoost.Booster.eval_set(evals, iter)
      # Group the flat `{eval_name, metric_name, value}` results by evaluation set.
      |> Enum.reduce(%{}, fn {evname, mname, value}, acc ->
        Map.update(acc, evname, %{mname => value}, &Map.put(&1, mname, value))
      end)
      |> Map.filter(filter)

    {:cont, %{state | metrics: metrics}}
  end

  @doc """
  A callback function that prints the current metrics every `period` iterations.

  Requires that the following exist in the `state.meta_vars` that is passed to the callback:
   * monitor_metrics:
      * period: metrics are printed when the iteration is a multiple of `period`;
      a `period` of `0` disables printing
      * filter: a function given to `Map.filter/2` over `state.metrics`. It receives each
      `{key, value}` entry as a tuple and returns a truthy value if the entry should be printed

  The state is never modified; the callback always returns `{:cont, state}`.
  """
  def monitor_metrics(
        %State{
          meta_vars: %{
            monitor_metrics: %{period: period, filter: filter}
          },
          metrics: metrics,
          iteration: iteration
        } = state
      ) do
    # `and` short-circuits, so `rem/2` is never evaluated with a zero period.
    should_print? = period != 0 and rem(iteration, period) == 0

    if should_print? do
      shown = Map.filter(metrics, filter)
      IO.puts("Iteration #{iteration}: #{inspect(shown)}")
    end

    {:cont, state}
  end
end