Skip to main content

examples/xor_classifier.exs

# XOR Classifier with ExBurn
# Run with: elixir examples/xor_classifier.exs
# (not `mix run`: Mix.install/2 below refuses to run inside a Mix project)
#
# Demonstrates training a neural network on the classic XOR problem.
# XOR is not linearly separable, so a hidden layer is required.
# This example shows:
#   1. A minimal non-linear classification problem
#   2. Using Model.summary() to inspect architecture
#   3. Training with validation data and early stopping
#   4. Using the Dataset module for splitting

# Script dependencies, installed on start-up. Mix.install/2 only works outside
# a Mix project, so launch this file with `elixir`, not `mix run`. ExBurn is
# taken from the directory one level above this script.
deps = [
  {:nx, "~> 0.7"},
  {:axon, "~> 0.7"},
  {:ex_burn, path: Path.expand("..", __DIR__)}
]

Mix.install(deps)

defmodule XORClassifier do
  @moduledoc """
  XOR classification: learns the non-linear XOR function.
  Input: 2D binary vectors -> Output: XOR of the two bits.
  """

  # The four unique XOR inputs. Used both to build the input tensor and to
  # label the truth-table printout, so the two cannot drift apart.
  @xor_inputs [[0, 0], [0, 1], [1, 0], [1, 1]]

  @doc """
  Runs the whole example.

  Builds an augmented XOR dataset and splits it into train/validation.
  Trains a 2 -> 8 (relu) -> 1 (sigmoid) network with early stopping, then
  prints the predicted class for each truth-table row and the validation
  accuracy.
  """
  def run do
    IO.puts("=== XOR Classifier with ExBurn ===\n")

    # 1. Create XOR dataset
    x = Nx.tensor(@xor_inputs, type: {:f, 32})
    y = Nx.tensor([[0], [1], [1], [0]], type: {:f, 32})

    # Augment with noisy copies so the model has enough data to learn
    {x_aug, y_aug} = augment_data(x, y, 50)

    IO.puts("Dataset: #{Nx.axis_size(x_aug, 0)} samples (4 unique + augmented)")
    IO.puts("  Input shape:  #{inspect(Nx.shape(x_aug))}")
    IO.puts("  Target shape: #{inspect(Nx.shape(y_aug))}")

    # 2. Split into train / validation.
    # NOTE: both splits are noisy copies of the same 4 points, so validation
    # accuracy measures fit to XOR, not generalization to unseen inputs.
    {{train_x, train_y}, {val_x, val_y}} =
      ExBurn.Dataset.split({x_aug, y_aug}, val_split: 0.2, seed: 42)

    IO.puts("  Train: #{Nx.axis_size(train_x, 0)} samples")
    IO.puts("  Val:   #{Nx.axis_size(val_x, 0)} samples\n")

    # 3. Define model: 2 -> 8 (relu) -> 1 (sigmoid)
    model =
      Axon.input("input", shape: {nil, 2})
      |> Axon.dense(8, activation: :relu, name: "hidden")
      |> Axon.dense(1, activation: :sigmoid, name: "output")

    IO.puts("Model architecture: 2 -> 8 (relu) -> 1 (sigmoid)")

    # 4. Compile
    compiled =
      ExBurn.Model.compile(model,
        loss: :binary_cross_entropy,
        optimizer: :adam,
        learning_rate: 0.1
      )

    IO.puts(ExBurn.Model.summary(compiled))

    # 5. Train with early stopping
    IO.puts("Training...\n")

    trained =
      ExBurn.Training.fit(compiled, {train_x, train_y},
        epochs: 200,
        batch_size: 32,
        validation_data: {val_x, val_y},
        verbose: false,
        callbacks: [
          ExBurn.Training.EarlyStoppingCallback.wait(20, 1.0e-5),
          &log_progress/1
        ]
      )

    # 6. Evaluate on all 4 XOR inputs with a single batched prediction.
    IO.puts("\nPredictions on XOR truth table:")

    # Nx.to_flat_list/1 works for a {4, 1} or {4} output. Nx.to_number/1 only
    # accepts scalars, so it raised on the per-row {1, 1} predictions.
    truth_preds = trained |> predict(x) |> Nx.to_flat_list()

    for {[a, b], prob} <- Enum.zip(@xor_inputs, truth_preds) do
      IO.puts("  (#{a},#{b}) -> #{round_pred(prob)}")
    end

    # Compute accuracy on validation set
    val_acc = trained |> predict(val_x) |> compute_accuracy(val_y)
    IO.puts("\nValidation accuracy: #{Float.round(val_acc * 100, 1)}%")

    IO.puts("\n=== Done ===")
  end

  # Training callback that prints loss / val_loss on every 50th epoch.
  # It always returns the metrics map it received, unchanged. The previous
  # version rebuilt a 3-key map on logging epochs, dropping any other keys.
  # NOTE(review): assumes ExBurn passes the callback's return value along
  # (e.g. to the next callback); confirm against ExBurn.Training.fit/3.
  defp log_progress(%{epoch: epoch, loss: loss, val_loss: val_loss} = metrics)
       when rem(epoch, 50) == 0 do
    IO.puts(
      "  Epoch #{String.pad_leading("#{epoch}", 3)}: loss=#{format_metric(loss)} val_loss=#{format_metric(val_loss)}"
    )

    metrics
  end

  defp log_progress(metrics), do: metrics

  # Rounds float metrics for display. Any other value (e.g. an integer) is
  # inspected instead, so the logger never crashes training via Float.round/2.
  defp format_metric(value) when is_float(value), do: Float.round(value, 5)
  defp format_metric(value), do: inspect(value)

  # Runs inference and asserts success. A failed prediction is a bug in this
  # example, so a MatchError is preferable to continuing with bad data.
  defp predict(model, input) do
    {:ok, output} = ExBurn.Model.predict(model, input)
    output
  end

  # Thresholds a sigmoid probability at 0.5. The is_number/1 guard stops
  # non-numeric values such as Nx's :nan atom from comparing greater than 0.5
  # under Erlang term ordering (atoms sort after numbers).
  defp round_pred(val) when is_number(val) and val >= 0.5, do: 1
  defp round_pred(_val), do: 0

  # Fraction of rows where the thresholded prediction (>= 0.5 -> 1.0) equals
  # the target, returned as a plain number.
  defp compute_accuracy(preds, targets) do
    pred_classes = Nx.greater_equal(preds, 0.5) |> Nx.as_type(:f32)
    correct = Nx.equal(pred_classes, targets) |> Nx.as_type(:f32)
    Nx.mean(correct) |> Nx.to_number()
  end

  # Tiles (x, y) `copies` times along axis 0 and adds N(0, 0.05) noise to the
  # inputs only, clipped back into [0, 1]. Targets are tiled without noise, so
  # row i of x_aug still corresponds to row i of y_aug. A fixed PRNG key keeps
  # the augmentation reproducible. Expects a rank-2 `x`; the feature width is
  # read from its shape instead of being hard-coded.
  defp augment_data(x, y, copies) do
    {num_rows, num_features} = Nx.shape(x)

    {noise, _key} =
      Nx.Random.key(123)
      |> Nx.Random.normal(0.0, 0.05, shape: {num_rows * copies, num_features})

    x_aug =
      x
      |> Nx.tile([copies, 1])
      |> Nx.add(noise)
      |> Nx.clip(0.0, 1.0)

    y_aug = Nx.tile(y, [copies, 1])
    {x_aug, y_aug}
  end
end

# Script entry point: runs the whole example when this file is evaluated.
XORClassifier.run()