# Model Management with ExBurn
# Run with: elixir examples/model_management.exs
# (the script uses Mix.install/2, which cannot be called via `mix run` inside a Mix project)
#
# Demonstrates ExBurn.Model utilities for model inspection,
# serialization, and manipulation:
# 1. Model compilation with different optimizers
# 2. Model.summary/1 — Keras-style architecture table
# 3. Model.info/1 — parameter counts, memory estimates
# 4. Model.benchmark/3 — forward pass timing
# 5. Model.save/2 and Model.load/2 — serialization
# 6. Model.export/3 and Model.import_params/3 — JSON format
# 7. Model.serialize_params/1 and Model.deserialize_params/1 — params to a binary and back
# 8. Model.quantize/2 — reduce precision for deployment
# 9. Model.freeze/2 and Model.unfreeze/2 — transfer learning
# 10. Model.clone/1 — deep copy
# 11. Model.update_params/2 — manual parameter updates
Mix.install([
{:nx, "~> 0.7"},
{:axon, "~> 0.7"},
{:ex_burn, path: Path.expand("..", __DIR__)}
])
defmodule ModelManagement do
  @moduledoc """
  Showcases ExBurn.Model utilities for model management.

  `run/0` walks through compilation, summary/info inspection, benchmarking,
  save/load, JSON export/import, parameter (de)serialization, quantization,
  layer freezing, cloning, manual parameter updates and an optimizer
  comparison, printing the result of each step.
  """

  # Benchmark settings. They are interpolated into the printed header too,
  # so the label can never drift from the options passed to benchmark/3.
  @bench_warmup 3
  @bench_runs 10

  @doc """
  Runs every demo section in order, printing the results to stdout.

  Files written by the save/export sections are placed in
  `System.tmp_dir!/0` and removed after they have been read back.
  Returns `:ok`.
  """
  def run do
    IO.puts("=== Model Management with ExBurn ===\n")

    # ── 1. Create a model ──────────────────────────────────────
    model = build_model()

    # ── 2. Compile with Adam ────────────────────────────────────
    compiled =
      ExBurn.Model.compile(model,
        loss: :cross_entropy,
        optimizer: :adam,
        learning_rate: 0.001,
        weight_decay: 1.0e-4
      )

    # ── 3. Model summary ────────────────────────────────────────
    IO.puts(ExBurn.Model.summary(compiled))

    # ── 4. Model info ───────────────────────────────────────────
    print_info(compiled)

    # ── 5. Benchmark forward pass ──────────────────────────────
    print_benchmark(compiled)

    # ── 6. Save and load ────────────────────────────────────────
    save_and_load(compiled)

    # ── 7. Export and import (JSON) ─────────────────────────────
    export_and_import(compiled)

    # ── 8. Serialize / deserialize params ──────────────────────
    serialize_roundtrip(compiled)

    # ── 9. Quantize ─────────────────────────────────────────────
    print_quantized(compiled)

    # ── 10. Layer freezing (transfer learning) ──────────────────
    demo_freezing(compiled)

    # ── 11. Clone ───────────────────────────────────────────────
    print_clone(compiled)

    # ── 12. Update params manually ─────────────────────────────
    demo_zeroed_params(compiled)

    # ── 13. Compare optimizers ─────────────────────────────────
    compare_optimizers(model)

    IO.puts("\n=== Done ===")
  end

  # Builds the small dense classifier used by every section and prints it.
  defp build_model do
    model =
      Axon.input("input", shape: {nil, 10})
      |> Axon.dense(64, activation: :relu, name: "encoder")
      |> Axon.dropout(rate: 0.3)
      |> Axon.dense(32, activation: :relu, name: "bottleneck")
      |> Axon.dense(5, name: "classifier")

    # Axon.Display only offers as_table/2 (which needs the optional
    # :table_rex dependency) and as_graph/3; inspecting the struct needs no
    # extra deps and still shows inputs, outputs and node count.
    IO.puts("Model architecture:")
    IO.inspect(model)

    model
  end

  # Prints the fields reported by ExBurn.Model.info/1.
  defp print_info(compiled) do
    info = ExBurn.Model.info(compiled)
    IO.puts("Model info:")
    IO.puts(" Total params: #{info.total_params}")
    IO.puts(" Layer count: #{info.layer_count}")
    IO.puts(" Loss function: #{info.loss_function}")
    IO.puts(" Optimizer: #{info.optimizer}")
    IO.puts(" Learning rate: #{info.learning_rate}")
    IO.puts(" Device: #{info.device}")
    IO.puts(" Weight decay: #{info.weight_decay}")
    IO.puts(" Est. memory: #{info.estimated_memory_mb} MB")
    IO.puts(" Compiled: #{info.compiled}\n")
  end

  # Times the forward pass on one random {1, 10} input and prints the stats.
  defp print_benchmark(compiled) do
    # Nx.Random.uniform/4 returns {tensor, next_key}; only the tensor is needed.
    {input, _next_key} = Nx.Random.uniform(Nx.Random.key(1), 0.0, 1.0, shape: {1, 10})
    bench = ExBurn.Model.benchmark(compiled, input, warmup: @bench_warmup, runs: @bench_runs)

    IO.puts("Benchmark (#{@bench_runs} runs, #{@bench_warmup} warmup):")
    IO.puts(" Avg: #{bench.avg_ms} ms")
    IO.puts(" Min: #{bench.min_ms} ms")
    IO.puts(" Max: #{bench.max_ms} ms")
    IO.puts(" Median: #{bench.median_ms} ms")
    IO.puts(" Std: #{bench.std_ms} ms\n")
  end

  # Saves the compiled model, loads it back and compares param counts.
  defp save_and_load(compiled) do
    model_path = tmp_path("ex_burn_model.model")
    IO.puts("Saving model to #{model_path}...")
    :ok = ExBurn.Model.save(compiled, model_path)
    {:ok, loaded} = ExBurn.Model.load(compiled, model_path)
    IO.puts(" Loaded successfully: #{loaded.compiled}")
    IO.puts(" Params match: #{map_size(loaded.params) == map_size(compiled.params)}\n")
    remove_quietly(model_path)
  end

  # Exports params as JSON and imports them back into the compiled model.
  defp export_and_import(compiled) do
    json_path = tmp_path("ex_burn_model.json")
    IO.puts("Exporting to JSON: #{json_path}...")
    :ok = ExBurn.Model.export(compiled, json_path, format: :json)
    {:ok, imported} = ExBurn.Model.import_params(compiled, json_path, format: :json)
    IO.puts(" Imported successfully: #{imported.compiled}\n")
    remove_quietly(json_path)
  end

  # Round-trips the params through a binary and compares key counts.
  defp serialize_roundtrip(compiled) do
    binary = ExBurn.Model.serialize_params(compiled)
    {:ok, deserialized} = ExBurn.Model.deserialize_params(binary)
    IO.puts("Serialize/deserialize params:")
    IO.puts(" Binary size: #{byte_size(binary)} bytes")
    IO.puts(" Keys match: #{map_size(deserialized) == map_size(compiled.params)}\n")
  end

  # Quantizes to f16 and prints the resulting param count and memory estimate.
  defp print_quantized(compiled) do
    q_info = compiled |> ExBurn.Model.quantize(:f16) |> ExBurn.Model.info()
    IO.puts("Quantized model (f16):")
    IO.puts(" Total params: #{q_info.total_params}")
    IO.puts(" Est. memory: #{q_info.estimated_memory_mb} MB\n")
  end

  # Freezes two layers, reports per-layer status, then unfreezes one.
  defp demo_freezing(compiled) do
    frozen = ExBurn.Model.freeze(compiled, ["encoder", "bottleneck"])
    IO.puts("Frozen layers: encoder, bottleneck")

    for layer <- ["encoder", "bottleneck", "classifier"] do
      IO.puts(" #{layer} frozen: #{ExBurn.Model.frozen?(frozen, layer)}")
    end

    # Sorting gives a deterministic order and works whether frozen_layers/1
    # returns a MapSet or a list.
    all_frozen =
      frozen
      |> ExBurn.Model.frozen_layers()
      |> Enum.sort()
      |> Enum.join(", ")

    IO.puts(" All frozen: #{all_frozen}")

    unfrozen = ExBurn.Model.unfreeze(frozen, ["encoder"])
    IO.puts("\nUnfrozen: encoder")
    IO.puts(" encoder frozen: #{ExBurn.Model.frozen?(unfrozen, "encoder")}")
    IO.puts(" bottleneck frozen: #{ExBurn.Model.frozen?(unfrozen, "bottleneck")}\n")
  end

  # Clones the model and checks the optimizer and loss configuration carried over.
  defp print_clone(compiled) do
    snapshot = ExBurn.Model.clone(compiled)

    same_config? =
      snapshot.optimizer == compiled.optimizer and snapshot.loss_fn == compiled.loss_fn

    IO.puts("Cloned model:")
    IO.puts(" Params count: #{map_size(snapshot.params)}")
    IO.puts(" Same config: #{same_config?}\n")
  end

  # Replaces every parameter with zeros and shows the param count is unchanged.
  defp demo_zeroed_params(compiled) do
    zeroed_info =
      compiled
      |> ExBurn.Model.update_params(zeros_like(compiled.params))
      |> ExBurn.Model.info()

    IO.puts("Model with zeroed params:")
    IO.puts(" Total params: #{zeroed_info.total_params}")
    IO.puts(" (params are zeroed but count is preserved)\n")
  end

  # Compiles the same architecture with several optimizers and prints each config.
  defp compare_optimizers(model) do
    IO.puts("Optimizer comparison:")

    Enum.each([:adam, :sgd, :rmsprop], fn opt ->
      info =
        model
        |> ExBurn.Model.compile(loss: :cross_entropy, optimizer: opt, learning_rate: 0.001)
        |> ExBurn.Model.info()

      name = opt |> Atom.to_string() |> String.pad_trailing(8)
      IO.puts(" #{name} lr=#{info.learning_rate} params=#{info.total_params}")
    end)
  end

  # Returns a zero tensor with the same shape and type. The integer literal 0
  # is valid for every Nx type, unlike 0.0 for integer tensors.
  defp zeros_like(%Nx.Tensor{} = tensor) do
    Nx.broadcast(Nx.tensor(0, type: Nx.type(tensor)), Nx.shape(tensor))
  end

  # Zeros every tensor in a (possibly nested) map of params, keeping its keys.
  defp zeros_like(params) when is_map(params) do
    Map.new(params, fn {key, value} -> {key, zeros_like(value)} end)
  end

  # Builds a path in the OS temp dir; portable, unlike a hard-coded "/tmp".
  defp tmp_path(file_name), do: Path.join(System.tmp_dir!(), file_name)

  # Best-effort cleanup. The result is ignored on purpose, e.g. in case the
  # library wrote to a different file name than the one requested.
  defp remove_quietly(path) do
    _ = File.rm(path)
    :ok
  end
end
# Script entry point: run the whole demo when this file is executed.
_ = ModelManagement.run()