# examples/dataset_utils.exs

# Dataset Utilities with ExBurn
# Run with: elixir examples/dataset_utils.exs  (Mix.install/2 cannot run inside `mix run`)
#
# Demonstrates the ExBurn.Dataset module for common data preprocessing:
#   1. Splitting data into train/validation sets
#   2. Creating batched data loaders with shuffling
#   3. Normalization (standard, minmax, l2)
#   4. One-hot encoding
#   5. Dataset statistics

# ExBurn is loaded from the parent of this script's directory (the project checkout);
# Nx and Axon come from Hex.
ex_burn_path = Path.expand("..", __DIR__)

Mix.install([
  {:nx, "~> 0.7"},
  {:axon, "~> 0.7"},
  {:ex_burn, path: ex_burn_path}
])

defmodule DatasetUtils do
  @moduledoc """
  Showcases ExBurn.Dataset utilities for data preprocessing.

  `run/0` prints, in order: dataset statistics, a train/validation split,
  standard / min-max / L2 normalization, one-hot encoding, and batched data
  loaders with and without `drop_last`.
  """

  # Synthetic dataset and demo settings, kept in one place so the printed
  # labels always match the values actually used.
  @seed 42
  @num_samples 100
  @num_features 5
  @num_classes 3
  @val_split 0.25
  @batch_size 16

  @doc """
  Runs every demo section, printing results to stdout. Returns `:ok`.
  """
  @spec run() :: :ok
  def run do
    IO.puts("=== Dataset Utilities with ExBurn ===\n")

    dataset = build_dataset()
    print_stats(dataset)

    {{train_x, train_y}, {val_x, _val_y}} = split_dataset(dataset)

    demo_normalization(train_x, val_x)
    demo_one_hot(train_y)
    demo_loaders({train_x, train_y})

    IO.puts("\n=== Done ===")
  end

  # Builds {features, labels}: standard-normal features of shape
  # {@num_samples, @num_features} and s64 class labels in 0..(@num_classes - 1).
  defp build_dataset do
    key = Nx.Random.key(@seed)
    {x, key} = Nx.Random.normal(key, 0.0, 1.0, shape: {@num_samples, @num_features})

    # randint draws from the half-open interval [0, @num_classes), so every class
    # is equally likely and no float-to-int truncation is needed.
    {labels, _key} =
      Nx.Random.randint(key, 0, @num_classes, shape: {@num_samples}, type: {:s, 64})

    {x, labels}
  end

  # Prints the summary fields returned by ExBurn.Dataset.stats/1.
  defp print_stats(dataset) do
    IO.puts("Original dataset:")
    stats = ExBurn.Dataset.stats(dataset)
    IO.puts("  Samples: #{stats.num_samples}")
    IO.puts("  Input shape: #{inspect(stats.input_shape)}")
    IO.puts("  Target shape: #{inspect(stats.target_shape)}")
    IO.puts("  Input type: #{inspect(stats.input_type)}")
    IO.puts("  Target type: #{inspect(stats.target_type)}\n")
  end

  # Splits off a validation set, prints both sizes, and returns {train, val},
  # each a {features, labels} tuple.
  defp split_dataset(dataset) do
    {train, val} = ExBurn.Dataset.split(dataset, val_split: @val_split, seed: @seed)
    {train_x, _train_y} = train
    {val_x, _val_y} = val

    IO.puts("After split (val_split=#{@val_split}, seed=#{@seed}):")
    IO.puts("  Train: #{Nx.axis_size(train_x, 0)} samples")
    IO.puts("  Val:   #{Nx.axis_size(val_x, 0)} samples\n")

    {train, val}
  end

  # Fits each normalization method on the training features, applies the fitted
  # statistics to the validation features, and prints a sanity check per method.
  defp demo_normalization(train_x, val_x) do
    IO.puts("Normalization:")

    # Validation data is transformed with statistics fitted on the training
    # split, never with its own.
    {train_std, std_stats} = ExBurn.Dataset.normalize(train_x, method: :standard)
    _val_std = ExBurn.Dataset.normalize_with_stats(val_x, std_stats)

    # Mean of feature column 0 over all samples; Nx.to_number/1 needs a scalar,
    # so the column is reduced with Nx.mean/1 before conversion.
    first_feature_mean =
      train_std
      |> Nx.slice_along_axis(0, 1, axis: 1)
      |> Nx.mean()

    IO.puts("  Standard (z-score):")
    IO.puts("    Train mean (first feature): #{round4(first_feature_mean)}")

    {train_minmax, minmax_stats} = ExBurn.Dataset.normalize(train_x, method: :minmax)
    _val_minmax = ExBurn.Dataset.normalize_with_stats(val_x, minmax_stats)

    IO.puts("  Min-Max [0, 1]:")
    IO.puts("    Train min: #{train_minmax |> Nx.reduce_min() |> round4()}")
    IO.puts("    Train max: #{train_minmax |> Nx.reduce_max() |> round4()}")

    {train_l2, l2_stats} = ExBurn.Dataset.normalize(train_x, method: :l2)
    _val_l2 = ExBurn.Dataset.normalize_with_stats(val_x, l2_stats)

    # Euclidean norm of every row; for row-wise L2 output each should be ~1.
    row_norms =
      train_l2
      |> Nx.multiply(train_l2)
      |> Nx.sum(axes: [-1])
      |> Nx.sqrt()

    IO.puts("  L2 (row-wise unit norm):")
    IO.puts("    Mean row norm: #{row_norms |> Nx.mean() |> round4()}")
  end

  # Converts a scalar tensor to a float rounded to 4 decimal places for display.
  defp round4(scalar), do: scalar |> Nx.to_number() |> Float.round(4)

  # One-hot encodes the training labels and prints the first rows, each prefixed
  # with the label it encodes.
  defp demo_one_hot(train_y) do
    IO.puts("\nOne-hot encoding (#{@num_classes} classes):")
    one_hot = ExBurn.Dataset.one_hot(train_y, num_classes: @num_classes)
    IO.puts("  Shape: #{inspect(Nx.shape(one_hot))}")

    shown = min(3, Nx.axis_size(one_hot, 0))
    IO.puts("  First #{shown} rows:")

    labels = train_y |> Nx.slice([0], [shown]) |> Nx.to_flat_list()
    rows = one_hot |> Nx.slice([0, 0], [shown, @num_classes]) |> Nx.to_list()

    labels
    |> Enum.zip(rows)
    |> Enum.each(fn {label, row} ->
      IO.puts("    Class #{label}: [#{Enum.map_join(row, ", ", &format_cell/1)}]")
    end)
  end

  # Rounds float cells for display and passes integer cells through unchanged,
  # so the output works whether one_hot/2 returns a float or an integer tensor.
  defp format_cell(value) when is_float(value), do: Float.round(value, 1)
  defp format_cell(value), do: value

  # Prints the shapes of the first batches from a shuffled loader, then compares
  # batch counts with and without drop_last.
  defp demo_loaders(train) do
    IO.puts("\nBatched data loader (batch_size=#{@batch_size}, shuffle=true):")

    train
    |> ExBurn.Dataset.loader(batch_size: @batch_size, shuffle: true)
    |> Stream.take(3)
    |> Enum.with_index(1)
    |> Enum.each(fn {{batch_x, batch_y}, idx} ->
      IO.puts("  Batch #{idx}: x=#{inspect(Nx.shape(batch_x))}, y=#{inspect(Nx.shape(batch_y))}")
    end)

    IO.puts("\nData loader with drop_last=true:")
    IO.puts("  Total batches: #{count_batches(train, drop_last: true)}")
    IO.puts("  Total batches (keep last): #{count_batches(train, drop_last: false)}")
  end

  # Counts the batches a loader yields without building an intermediate list.
  defp count_batches(train, opts) do
    train
    |> ExBurn.Dataset.loader([batch_size: @batch_size] ++ opts)
    |> Enum.count()
  end
end

DatasetUtils.run()