# examples/training_callbacks.exs

# Training Callbacks with ExBurn
# Run with: elixir examples/training_callbacks.exs
# (Mix.install/2 below cannot be used inside a Mix project, so `mix run` will not work.)
#
# Demonstrates the ExBurn.Training callback system:
#   1. LoggingCallback — log metrics after each epoch
#   2. EarlyStoppingCallback — stop when validation loss plateaus
#   3. CheckpointCallback — save model checkpoints at intervals
#   4. WarmupCallback — gradually increase learning rate
#   5. ReduceLROnPlateauCallback — reduce LR when loss stops improving
#   6. HistoryCallback — record all metrics for later analysis
#   7. Custom callbacks — writing your own

Mix.install([
  {:nx, "~> 0.7"},
  {:axon, "~> 0.7"},
  {:ex_burn, path: Path.expand("..", __DIR__)}
])

defmodule TrainingCallbacks do
  @moduledoc """
  Showcases ExBurn.Training callbacks for customizing training.

  `run/0` builds a synthetic 3-class dataset, trains a small Axon MLP with
  `ExBurn.Training.fit/3` using the built-in callbacks plus custom ones, then
  prints the recorded history, a final evaluation and a single-step profile.
  """

  @doc """
  Runs the whole demo, printing progress to standard output.
  """
  def run do
    IO.puts("=== Training Callbacks with ExBurn ===\n")

    # ── 1. Create synthetic classification data ────────────────
    {train_x, train_y, val_x, val_y} = generate_data(500, 50)
    # Read the dimensions back from the tensors so the model and the printed
    # summary cannot drift out of sync with generate_data/2.
    {n_train, input_dim} = Nx.shape(train_x)
    {_, num_classes} = Nx.shape(train_y)

    IO.puts("Dataset:")
    IO.puts("  Train: #{n_train} samples, #{input_dim} features")
    IO.puts("  Val:   #{elem(Nx.shape(val_x), 0)} samples")
    IO.puts("  Classes: #{num_classes}\n")

    # ── 2. Define model ────────────────────────────────────────
    compiled =
      ExBurn.Model.compile(build_model(input_dim, num_classes),
        loss: :cross_entropy,
        optimizer: :adam,
        learning_rate: 0.001
      )

    # ── 3. Train with all callbacks ────────────────────────────
    IO.puts("Training with callbacks...\n")

    history_pid = setup_history()

    trained =
      ExBurn.Training.fit(compiled, {train_x, train_y},
        epochs: 50,
        batch_size: 32,
        validation_data: {val_x, val_y},
        verbose: false,
        # NOTE(review): this schedule, WarmupCallback and ReduceLROnPlateauCallback
        # all adjust the learning rate — confirm how ExBurn composes them.
        lr_schedule: {:cosine, 0.001, 1.0e-5},
        clip_norm: 1.0,
        weight_decay: 1.0e-4,
        accuracy: true,
        callbacks: callbacks(history_pid)
      )

    # ── 4. Review history ──────────────────────────────────────
    print_history(history_pid, 5)

    # ── 5. Final evaluation ────────────────────────────────────
    IO.puts("\nFinal evaluation:")
    print_evaluation("Train:", trained, {train_x, train_y})
    print_evaluation("Val:  ", trained, {val_x, val_y})

    # ── 6. Profile a training step ─────────────────────────────
    IO.puts("\nProfile single training step:")
    print_profile(trained, {train_x, train_y}, 32)

    IO.puts("\n=== Done ===")
  end

  # ── Model ────────────────────────────────────────────────────

  # Small MLP: input_dim -> 32 (relu) -> dropout 0.2 -> 16 (relu) -> num_classes logits.
  defp build_model(input_dim, num_classes) do
    Axon.input("input", shape: {nil, input_dim})
    |> Axon.dense(32, activation: :relu, name: "hidden1")
    |> Axon.dropout(rate: 0.2)
    |> Axon.dense(16, activation: :relu, name: "hidden2")
    |> Axon.dense(num_classes, name: "output")
  end

  # ── Callbacks ────────────────────────────────────────────────

  # Assembles the callback list passed to fit/3. The custom entries are arity-1
  # functions that receive an epoch's metrics map and return it; presumably the
  # built-in constructors produce the same kind of function — TODO confirm.
  defp callbacks(history_pid) do
    [
      # Log every epoch
      &ExBurn.Training.LoggingCallback.log/1,

      # Stop early if no improvement for 10 epochs
      ExBurn.Training.EarlyStoppingCallback.wait(10, 1.0e-4),

      # Save checkpoint every 10 epochs
      ExBurn.Training.CheckpointCallback.every(10, "/tmp/ex_burn_checkpoints"),

      # Warmup: lr from 1e-5 to 0.001 over 5 epochs
      ExBurn.Training.WarmupCallback.linear(5, 1.0e-5, 0.001),

      # Reduce LR on plateau
      ExBurn.Training.ReduceLROnPlateauCallback.new(patience: 5, factor: 0.5, min_lr: 1.0e-6),

      # Record history
      ExBurn.Training.HistoryCallback.new(),

      # Custom callback: copy each epoch's metrics into `history_pid`.
      # HistoryCallback.new/0 takes no argument, so it cannot reach the Agent created
      # by setup_history/0; without this recorder the printed history would be empty.
      record_history(history_pid),

      # Custom callback: print a message at epoch 25
      fn
        %{epoch: 25} = metrics ->
          IO.puts("  [Custom] Reached epoch 25! loss=#{Float.round(metrics.loss, 4)}")
          metrics

        metrics ->
          metrics
      end
    ]
  end

  # Returns a callback that prepends each metrics map to the Agent's list and
  # passes the metrics through unchanged.
  defp record_history(history_pid) do
    fn metrics ->
      # Prepend is O(1); print_history/2 re-orders by epoch.
      Agent.update(history_pid, &[metrics | &1])
      metrics
    end
  end

  # ── Reporting ────────────────────────────────────────────────

  # Prints the `count` most recent epochs recorded in `history_pid`, oldest first.
  defp print_history(history_pid, count) do
    IO.puts("\nTraining history (last #{count} epochs):")

    history_pid
    |> Agent.get(& &1)
    |> Enum.sort_by(& &1.epoch)
    # Negative count takes from the end, i.e. the latest epochs.
    |> Enum.take(-count)
    |> Enum.each(&print_epoch/1)
  end

  defp print_epoch(metrics) do
    epoch = String.pad_leading("#{metrics.epoch}", 2)
    loss = Float.round(metrics.loss, 4)
    # Map.get/2 so a metric that was not reported prints "N/A" instead of raising.
    val_loss = format_metric(Map.get(metrics, :val_loss), &Float.round(&1, 4))
    acc = format_metric(Map.get(metrics, :accuracy), fn a -> "#{Float.round(a * 100, 1)}%" end)

    IO.puts("  Epoch #{epoch}: loss=#{loss} val_loss=#{val_loss} acc=#{acc}")
  end

  # Formats an optional metric, printing "N/A" when it is absent.
  defp format_metric(nil, _format), do: "N/A"
  defp format_metric(value, format), do: format.(value)

  # Evaluates `model` on `data` and prints one aligned "loss / acc" line.
  defp print_evaluation(label, model, data) do
    {loss, acc} = ExBurn.Training.evaluate(model, data, true)
    IO.puts("  #{label} loss=#{Float.round(loss, 4)} acc=#{Float.round(acc * 100, 1)}%")
  end

  # Profiles one training step on the first `batch_size` samples (fewer if the
  # dataset is smaller). Only the batch axis is sliced, so feature/class widths
  # come from the data rather than being hard-coded.
  defp print_profile(model, {x, y}, batch_size) do
    size = min(batch_size, elem(Nx.shape(x), 0))
    batch_x = Nx.slice_along_axis(x, 0, size, axis: 0)
    batch_y = Nx.slice_along_axis(y, 0, size, axis: 0)
    profile = ExBurn.Training.profile_step(model, {batch_x, batch_y})

    IO.puts("  Forward:   #{profile.forward_ms} ms")
    IO.puts("  Backward:  #{profile.backward_ms} ms")
    IO.puts("  Optimizer: #{profile.optimizer_ms} ms")
    IO.puts("  Total:     #{profile.total_ms} ms")
  end

  # ── Data Generation ──────────────────────────────────────────

  # Returns {train_x, train_y, val_x, val_y} with 10 standard-normal features,
  # labels drawn uniformly from 0..2 and one-hot targets (seed 42).
  defp generate_data(n_train, n_val) do
    input_dim = 10
    num_classes = 3
    key = Nx.Random.key(42)

    # Thread the PRNG key through every draw: reusing a key yields correlated samples.
    {train_x, train_y, key} = generate_split(key, n_train, input_dim, num_classes)
    {val_x, val_y, _key} = generate_split(key, n_val, input_dim, num_classes)

    {train_x, train_y, val_x, val_y}
  end

  # Draws one split of `n` samples and returns {x, y, next_key}. Samples of class
  # `c` get +0.5 added to feature `c` so the classes are learnable.
  defp generate_split(key, n, input_dim, num_classes) do
    {x, key} = Nx.Random.normal(key, 0.0, 1.0, shape: {n, input_dim})
    # randint's upper bound is exclusive, so labels are in 0..num_classes-1.
    {labels, key} = Nx.Random.randint(key, 0, num_classes, shape: {n}, type: {:s, 64})
    y = ExBurn.Dataset.one_hot(labels, num_classes: num_classes)

    # The one-hot targets are only `num_classes` columns wide. Nx.slice cannot widen
    # a tensor (it raises), so pad with zero columns up to `input_dim` instead.
    signal =
      y
      |> Nx.multiply(0.5)
      |> Nx.pad(0.0, [{0, 0, 0}, {0, input_dim - num_classes, 0}])

    {Nx.add(x, signal), y, key}
  end

  # Starts a linked Agent holding the list of per-epoch metrics maps.
  defp setup_history do
    {:ok, pid} = Agent.start_link(fn -> [] end)
    pid
  end
end

# Script entry point: runs the demo when this file is executed.
TrainingCallbacks.run()