# Regularizers

This guide covers the structured regularizer composition system in Tinkex, which enables modular loss engineering for training LLMs. You'll learn how to implement custom regularizers, compose multiple regularization strategies, track gradient norms, and integrate with the Tinker API.

> Looking for the live training flow that sends gradients to the backend? See `docs/guides/custom_loss_training.md` for the `forward_backward_custom/4` pipeline that operates on per-datum logprobs and returns `ForwardBackwardOutput`.

## Overview

Regularizers add penalty terms to the base loss function during training to encourage desired model behaviors such as:

- **Sparsity** (L1): Encourage sparse activations or weight distributions
- **Weight decay** (L2): Prevent large weights and overfitting
- **Entropy**: Promote diversity in predictions
- **Custom constraints**: Domain-specific penalties (KL divergence, feature correlation, etc.)

The regularizer system in Tinkex composes multiple weighted regularizers into a total loss:

```
loss_total = base_loss + Σ(weight_i × regularizer_i)
```

Each regularizer is executed independently (optionally in parallel), with full telemetry and optional gradient norm tracking for monitoring training dynamics.
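
To make the composition formula concrete, the following sketch works through the arithmetic in plain Elixir, without Nx or Tinkex. The raw values and weights are hypothetical numbers chosen for illustration; only the weighting and summing mirror the formula above.

```elixir
# Hypothetical raw values, chosen only to illustrate the arithmetic.
base_loss = 0.98

# {name, weight, raw_value}
regularizers = [
  {"l1", 0.01, 4.9},
  {"l2", 0.005, 6.83},
  {"entropy", 0.001, -1.5}
]

# contribution_i = weight_i * regularizer_i
contributions =
  Map.new(regularizers, fn {name, weight, value} -> {name, weight * value} end)

regularizer_total = contributions |> Map.values() |> Enum.sum()
loss_total = base_loss + regularizer_total

IO.inspect(contributions, label: "contributions")
IO.inspect(regularizer_total, label: "regularizer_total")
IO.inspect(loss_total, label: "loss_total")
```

A regularizer value can be negative (as with a negative-entropy term), in which case its contribution lowers the total.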

## Core Concepts

### The Regularizer Behaviour

Regularizers implement the `Tinkex.Regularizer` behaviour, which defines a required `compute/3` callback and an optional `name/0` callback:

```elixir
@callback compute(
  data :: list(Datum.t()),
  logprobs :: Nx.Tensor.t(),
  opts :: keyword()
) :: {Nx.Tensor.t(), %{String.t() => number()}}

@callback name() :: String.t()
```

The `compute/3` callback:
- Takes training data and log probabilities from the forward pass
- Returns a tuple of `{loss_tensor, metrics_map}`
- Should return a scalar loss tensor (a non-scalar tensor is summed automatically)
- Reports custom measurements for telemetry in the metrics map (e.g., `%{"l1_value" => 0.042}`)

The optional `name/0` callback provides a unique identifier for telemetry and logging. If not implemented, the name must be provided via `RegularizerSpec`.

### RegularizerSpec

The `RegularizerSpec` struct configures how a regularizer is executed:

```elixir
%RegularizerSpec{
  fn: function() | module(),  # Regularizer function or module
  weight: float(),            # Non-negative multiplier
  name: String.t(),           # Unique identifier
  async: boolean()            # Whether fn returns a Task (default: false)
}
```

**Fields:**
- `fn`: Either an anonymous function (arity 2 or 3) or a module implementing the `Regularizer` behaviour
- `weight`: Multiplier applied to the regularizer loss (must be >= 0)
- `name`: Unique name for telemetry events and output indexing
- `async`: If `true`, the function should return a `Task.t()` for async execution

Create a spec using `RegularizerSpec.new/1`:

```elixir
spec = RegularizerSpec.new(%{
  fn: &my_regularizer/2,
  weight: 0.01,
  name: "l1_sparsity"
})
```

## Implementing Regularizers

### As Anonymous Functions

The simplest approach is to use anonymous functions:

```elixir
# Arity 2: (data, logprobs) -> {loss, metrics}
l1_regularizer = fn _data, logprobs ->
  l1_loss = Nx.sum(Nx.abs(logprobs))
  {l1_loss, %{}}
end

spec = RegularizerSpec.new(%{
  fn: l1_regularizer,
  weight: 0.01,
  name: "l1_sparsity"
})
```

You can also use arity 3 to receive options:

```elixir
# Arity 3: (data, logprobs, opts) -> {loss, metrics}
configurable_l1 = fn _data, logprobs, opts ->
  threshold = Keyword.get(opts, :threshold, 0.0)

  # Only penalize values above threshold
  masked = Nx.select(Nx.greater(Nx.abs(logprobs), threshold), logprobs, 0)
  l1_loss = Nx.sum(Nx.abs(masked))

  {l1_loss, %{"threshold" => threshold}}
end

spec = RegularizerSpec.new(%{
  fn: configurable_l1,
  weight: 0.01,
  name: "l1_sparsity"
})
```

### As Behaviour-Implementing Modules

For reusable regularizers, implement the behaviour in a module:

```elixir
defmodule MyRegularizers.L1Sparsity do
  @behaviour Tinkex.Regularizer

  @impl true
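  def compute(_data, logprobs, _opts) do
    l1_loss = Nx.sum(Nx.abs(logprobs))
    # Nx.to_number/1 works for plain evaluation but breaks gradient tracing
    # when `track_grad_norms: true` is set (see Gradient Tracking Compatibility).
    l1_value = Nx.to_number(l1_loss)

    {l1_loss, %{"l1_value" => l1_value}}
  end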

  @impl true
  def name, do: "l1_sparsity"
end

# Use in a spec
spec = RegularizerSpec.new(%{
  fn: MyRegularizers.L1Sparsity,
  weight: 0.01,
  name: MyRegularizers.L1Sparsity.name()
})
```

### Gradient Tracking Compatibility

**Important**: When using gradient norm tracking (`track_grad_norms: true`), avoid calling `Nx.to_number/1` inside the regularizer function. Nx's automatic differentiation traces the function, so every intermediate value must remain a tensor; converting to a number mid-trace breaks the gradient computation.

```elixir
# BAD: Calls Nx.to_number inside the function
bad_regularizer = fn _data, logprobs ->
  l1 = Nx.sum(Nx.abs(logprobs))
  # This breaks gradient computation!
  {l1, %{"l1_value" => Nx.to_number(l1)}}
end

# GOOD: Returns empty metrics or computes them from the tensor later
good_regularizer = fn _data, logprobs ->
  l1 = Nx.sum(Nx.abs(logprobs))
  # Metrics will be computed from the loss value by the pipeline
  {l1, %{}}
end
```

## Common Regularizer Examples

### L1 Sparsity

Encourages sparsity by penalizing the L1 norm of the log probabilities:

```elixir
l1_spec = RegularizerSpec.new(%{
  fn: fn _data, logprobs ->
    {Nx.sum(Nx.abs(logprobs)), %{}}
  end,
  weight: 0.01,
  name: "l1_sparsity"
})
```

### L2 Weight Decay

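Penalizes large values via the squared L2 norm. Note that the regularizer receives log probabilities rather than model weights, so this is a weight-decay-style penalty on the outputs: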

```elixir
l2_spec = RegularizerSpec.new(%{
  fn: fn _data, logprobs ->
    {Nx.sum(Nx.pow(logprobs, 2)), %{}}
  end,
  weight: 0.005,
  name: "l2_weight_decay"
})
```

### Entropy Regularization

Encourages diversity in predictions by maximizing entropy:

```elixir
entropy_spec = RegularizerSpec.new(%{
  fn: fn _data, logprobs ->
    # Convert log probs to probs
    probs = Nx.exp(logprobs)
    # Negative entropy (we minimize, so negate to maximize entropy)
    neg_entropy = Nx.sum(Nx.multiply(probs, logprobs))
    {neg_entropy, %{}}
  end,
  weight: 0.001,
  name: "entropy"
})
```
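
As a sanity check on the sign convention, the sketch below evaluates the same `Σ p · log p` expression in plain Elixir for two hypothetical four-way distributions. It assumes the inputs are log probabilities of a full distribution (summing to 1 after exponentiation), which may not hold for per-token logprobs returned by a forward pass.

```elixir
# Negative entropy: sum(p * log p), the quantity the regularizer minimizes.
neg_entropy = fn logprobs ->
  logprobs
  |> Enum.map(fn lp -> :math.exp(lp) * lp end)
  |> Enum.sum()
end

# Hypothetical distributions over four outcomes, expressed as log probabilities.
uniform = Enum.map([0.25, 0.25, 0.25, 0.25], &:math.log/1)
peaked = Enum.map([0.97, 0.01, 0.01, 0.01], &:math.log/1)

uniform_penalty = neg_entropy.(uniform)
peaked_penalty = neg_entropy.(peaked)

IO.inspect(uniform_penalty, label: "uniform")
IO.inspect(peaked_penalty, label: "peaked")

# The uniform distribution has the highest entropy, hence the lowest penalty.
true = uniform_penalty < peaked_penalty
```

Because the penalty is lowest for the uniform case, minimizing it with a positive weight pushes toward more diverse predictions.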

### KL Divergence from Target Distribution

Encourage the model to match a target distribution:

```elixir
kl_spec = RegularizerSpec.new(%{
  fn: fn _data, logprobs ->
    # Assume uniform target distribution
    target_logprobs = Nx.broadcast(
      Nx.log(1.0 / Nx.size(logprobs)),
      Nx.shape(logprobs)
    )

    # KL(target || model) = sum(target * (log(target) - log(model)))
    target_probs = Nx.exp(target_logprobs)
    kl = Nx.sum(
      Nx.multiply(
        target_probs,
        Nx.subtract(target_logprobs, logprobs)
      )
    )

    {kl, %{}}
  end,
  weight: 0.002,
  name: "kl_uniform"
})
```

### NxPenalties adapter options: KL direction/symmetric and entropy temperature

The built-in NxPenalties-backed adapters expose additional controls:

```elixir
# Forward (default) vs reverse KL, plus symmetric averaging
kl_forward = RegularizerSpec.new(%{
  fn: fn data, logprobs ->
    Regularizers.KLDivergence.compute(data, logprobs,
      reference_field: :reference_logprobs,
      direction: :forward
    )
  end,
  weight: 0.01,
  name: "kl_forward"
})

kl_reverse = RegularizerSpec.new(%{
  fn: fn data, logprobs ->
    Regularizers.KLDivergence.compute(data, logprobs,
      reference_field: :reference_logprobs,
      direction: :reverse # mode-seeking; penalizes mass outside sharp targets
    )
  end,
  weight: 0.01,
  name: "kl_reverse"
})

kl_symmetric = RegularizerSpec.new(%{
  fn: fn data, logprobs ->
    Regularizers.KLDivergence.compute(data, logprobs,
      reference_field: :reference_logprobs,
      symmetric: true # (KL(P||Q) + KL(Q||P)) / 2
    )
  end,
  weight: 0.01,
  name: "kl_symmetric"
})

# Entropy temperature scaling (sharper < 1.0, flatter > 1.0)
entropy_cool = RegularizerSpec.new(%{
  fn: fn data, logprobs ->
    Regularizers.Entropy.compute(data, logprobs,
      mode: :maximize,
      temperature: 0.5
    )
  end,
  weight: 0.001,
  name: "entropy_cool"
})
```
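
The `direction` and `symmetric` options matter because KL divergence is not symmetric. The plain-Elixir sketch below computes both directions and their average for two hypothetical distributions. Which direction the adapter labels `:forward` is not specified in this guide, so the code names the two directions explicitly rather than assuming a mapping.

```elixir
# KL(a || b) = sum(a_i * (log a_i - log b_i)) over two probability lists.
kl = fn a, b ->
  a
  |> Enum.zip(b)
  |> Enum.map(fn {ai, bi} -> ai * (:math.log(ai) - :math.log(bi)) end)
  |> Enum.sum()
end

# Hypothetical reference (p) and model (q) distributions.
p = [0.7, 0.2, 0.1]
q = [0.4, 0.4, 0.2]

kl_pq = kl.(p, q)
kl_qp = kl.(q, p)

# Symmetric variant: (KL(P||Q) + KL(Q||P)) / 2
kl_sym = (kl_pq + kl_qp) / 2

IO.inspect(kl_pq, label: "KL(P||Q)")
IO.inspect(kl_qp, label: "KL(Q||P)")
IO.inspect(kl_sym, label: "symmetric")
```

The two directions give different values for the same pair of distributions, so the choice changes both the penalty and its gradient.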

## Composing Regularizer Pipelines

### Basic Pipeline Execution

Use `Regularizer.Pipeline.compute/4` to compose base loss with regularizers:

```elixir
alias Tinkex.Regularizer.Pipeline
alias Tinkex.Types.RegularizerSpec

# Define base loss function
base_loss_fn = fn _data, logprobs ->
  # Negative log-likelihood
  nll = Nx.negate(Nx.mean(logprobs))
  {nll, %{}}
end

# Define regularizers
regularizers = [
  RegularizerSpec.new(%{fn: &l1/2, weight: 0.01, name: "l1"}),
  RegularizerSpec.new(%{fn: &l2/2, weight: 0.005, name: "l2"}),
  RegularizerSpec.new(%{fn: &entropy/2, weight: 0.001, name: "entropy"})
]

# Execute pipeline
{:ok, output} = Pipeline.compute(data, logprobs, base_loss_fn,
  regularizers: regularizers
)

# Access results
IO.puts("Total loss: #{output.loss_total}")
IO.puts("Base loss: #{output.base_loss.value}")
IO.puts("Regularizer total: #{output.regularizer_total}")

# Per-regularizer breakdown
for {name, reg} <- output.regularizers do
  IO.puts("#{name}: value=#{reg.value}, contribution=#{reg.contribution}")
end
```

### Pipeline Options

`Pipeline.compute/4` accepts the following options:

- `:regularizers` - List of `RegularizerSpec` structs (default: `[]`)
- `:track_grad_norms` - Compute gradient norms for monitoring (default: `false`)
- `:parallel` - Execute regularizers in parallel (default: `true`)
- `:timeout` - Timeout for async operations in milliseconds (default: `30_000`)
- `:max_concurrency` - Max parallel tasks (default: `System.schedulers_online()`)

Example with options:

```elixir
{:ok, output} = Pipeline.compute(data, logprobs, base_loss_fn,
  regularizers: regularizers,
  track_grad_norms: true,
  parallel: true,
  timeout: 60_000,
  max_concurrency: 4
)
```

### Sequential vs Parallel Execution

By default, regularizers execute in parallel for better throughput:

```elixir
# Parallel execution (default)
{:ok, output} = Pipeline.compute(data, logprobs, base_loss_fn,
  regularizers: regularizers,
  parallel: true
)
```

For deterministic execution order or debugging, use sequential mode:

```elixir
# Sequential execution
{:ok, output} = Pipeline.compute(data, logprobs, base_loss_fn,
  regularizers: regularizers,
  parallel: false
)
```
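
The trade-off can be pictured with the standard library alone. The sketch below is not Tinkex's executor; it uses `Task.async_stream/3`, whose `:max_concurrency` and `:timeout` options happen to share names with the pipeline options, on two hypothetical list-based regularizers.

```elixir
# Hypothetical stand-ins for regularizers: {name, function over a list of floats}.
regularizers = [
  {"l1", fn xs -> xs |> Enum.map(&abs/1) |> Enum.sum() end},
  {"l2", fn xs -> xs |> Enum.map(&(&1 * &1)) |> Enum.sum() end}
]

logprobs = [-0.5, -1.2, -0.8, -2.1, -0.3]

# Parallel: one task per regularizer, bounded by max_concurrency.
# Results come back in input order because async_stream is ordered by default.
parallel =
  regularizers
  |> Task.async_stream(fn {name, fun} -> {name, fun.(logprobs)} end,
    max_concurrency: System.schedulers_online(),
    timeout: 30_000
  )
  |> Enum.map(fn {:ok, result} -> result end)

# Sequential: same work, one regularizer after another in the calling process.
sequential = Enum.map(regularizers, fn {name, fun} -> {name, fun.(logprobs)} end)

IO.inspect(parallel, label: "parallel")
IO.inspect(sequential, label: "sequential")
```

Both modes should produce the same values; sequential mode mainly makes stack traces and log ordering easier to follow.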

## Gradient Norm Tracking

Gradient norms help you monitor which components dominate the training signal. Enable tracking with `track_grad_norms: true`:

```elixir
{:ok, output} = Pipeline.compute(data, logprobs, base_loss_fn,
  regularizers: regularizers,
  track_grad_norms: true
)

# Gradient norms are L2 norms: sqrt(sum(grad^2))
IO.puts("Base loss grad norm: #{output.base_loss.grad_norm}")
IO.puts("Total grad norm: #{output.total_grad_norm}")

for {name, reg} <- output.regularizers do
  IO.puts("#{name} grad norm: #{reg.grad_norm}")
  IO.puts("#{name} weighted grad norm: #{reg.grad_norm_weighted}")
end
```

### Understanding Gradient Norms

- **Base loss grad norm**: Gradient contribution from the base loss alone
- **Per-regularizer grad norm**: Gradient contribution from each regularizer (unweighted)
- **Weighted grad norm**: `weight × grad_norm` (actual contribution to total gradient)
- **Total grad norm**: L2 norm of the complete composed gradient

These metrics help identify:
- Which regularizers dominate training
- Whether regularizers are too strong/weak
- Training instability (exploding/vanishing gradients)
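
For intuition, these quantities can be reproduced by hand for simple losses. The sketch below is plain Elixir with hand-derived gradients (no automatic differentiation and no Tinkex calls), using hypothetical logprobs and weights. It also shows that summing weighted norms gives only an upper bound on the norm of the composed gradient.

```elixir
logprobs = [-0.5, -1.2, -0.8, -2.1, -0.3]
n = length(logprobs)

# L2 norm of a gradient vector: sqrt(sum(g^2))
l2_norm = fn grad ->
  grad |> Enum.map(&(&1 * &1)) |> Enum.sum() |> :math.sqrt()
end

# Hand-derived gradients with respect to each logprob:
#   base = -mean(x)  =>  d/dx_i = -1/n
#   l1   = sum(|x|)  =>  d/dx_i = sign(x_i)  (every input here is negative)
#   l2   = sum(x^2)  =>  d/dx_i = 2 * x_i
base_grad = Enum.map(logprobs, fn _ -> -1.0 / n end)
l1_grad = Enum.map(logprobs, fn x -> if x < 0, do: -1.0, else: 1.0 end)
l2_grad = Enum.map(logprobs, &(2.0 * &1))

# Hypothetical regularizer weights
{w1, w2} = {0.01, 0.005}

base_norm = l2_norm.(base_grad)
l1_weighted = w1 * l2_norm.(l1_grad)
l2_weighted = w2 * l2_norm.(l2_grad)

# Composed gradient: base + w1 * l1 + w2 * l2, element by element
total_grad =
  [base_grad, l1_grad, l2_grad]
  |> Enum.zip()
  |> Enum.map(fn {b, g1, g2} -> b + w1 * g1 + w2 * g2 end)

total_norm = l2_norm.(total_grad)

IO.inspect(base_norm, label: "base grad norm")
IO.inspect({l1_weighted, l2_weighted}, label: "weighted regularizer norms")
IO.inspect(total_norm, label: "total grad norm")
```

By the triangle inequality, `total_norm` never exceeds `base_norm + l1_weighted + l2_weighted`, so ratios built from summed norms tend to overstate how much the regularizers contribute.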

### Direct Gradient Computation

For custom gradient analysis, use `GradientTracker` directly:

```elixir
alias Tinkex.Regularizer.GradientTracker

# Compute gradient norm for a loss function
loss_fn = fn logprobs -> Nx.sum(Nx.abs(logprobs)) end
grad_norm = GradientTracker.compute_grad_norm(loss_fn, logprobs)

# Compute gradient norm for a regularizer spec
grad_norm = GradientTracker.grad_norm_for_regularizer(spec, data, logprobs)

# Compute total composed gradient norm
total_norm = GradientTracker.total_grad_norm(base_loss_fn, regularizers, data, logprobs)
```

## Executing Regularizers

### Via Pipeline (Recommended)

The pipeline is the high-level API that handles everything:

```elixir
{:ok, output} = Pipeline.compute(data, logprobs, base_loss_fn,
  regularizers: regularizers
)
```

### Via Executor (Low-Level)

For fine-grained control, use `Executor` directly:

```elixir
alias Tinkex.Regularizer.Executor

# Execute a single regularizer
{:ok, output} = Executor.execute_one(spec, data, logprobs,
  timeout: 5000,
  track_grad_norms: true
)

# Execute all regularizers
{:ok, outputs} = Executor.execute_all(regularizers, data, logprobs,
  parallel: true,
  timeout: 30_000,
  track_grad_norms: true
)
```

### Via Regularizer Module (Direct)

Execute regularizers directly without specs:

```elixir
alias Tinkex.Regularizer

# With anonymous function (arity 2)
{loss, metrics} = Regularizer.execute(
  fn _data, logprobs -> {Nx.sum(logprobs), %{}} end,
  data,
  logprobs
)

# With anonymous function (arity 3)
{loss, metrics} = Regularizer.execute(
  fn _data, logprobs, opts -> {Nx.sum(logprobs), opts} end,
  data,
  logprobs,
  custom_option: "value"
)

# With module
{loss, metrics} = Regularizer.execute(MyRegularizer, data, logprobs)
```

## Async Regularizers

For I/O-bound operations (e.g., calling external APIs, querying databases), use async regularizers:

```elixir
async_spec = RegularizerSpec.new(%{
  fn: fn _data, logprobs ->
    Task.async(fn ->
      # Simulate external API call
      :timer.sleep(100)

      # Compute penalty based on external validation
      penalty = Nx.mean(Nx.abs(logprobs))
      {penalty, %{"external_validated" => true}}
    end)
  end,
  weight: 0.02,
  name: "async_validator",
  async: true  # Mark as async
})

{:ok, output} = Pipeline.compute(data, logprobs, base_loss_fn,
  regularizers: [async_spec],
  timeout: 5000  # Wait up to 5s for async tasks
)
```

The executor awaits the returned task with `Task.await/2`, using the specified timeout.
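
For readers who want to see how a timeout can become an error tuple rather than a crash, here is a minimal standard-library sketch. It is an assumption about one possible approach, not the Tinkex executor: it uses the `Task.yield/2` plus `Task.shutdown/2` idiom, and `await_with_timeout` is a hypothetical helper name.

```elixir
# Hypothetical helper: wait for a task, returning a tagged tuple on timeout.
await_with_timeout = fn task, timeout ->
  case Task.yield(task, timeout) || Task.shutdown(task, :brutal_kill) do
    {:ok, result} -> {:ok, result}
    {:exit, reason} -> {:error, {:exit, reason}}
    nil -> {:error, :timeout}
  end
end

# A task that finishes quickly, returning a {loss, metrics}-shaped tuple.
fast = Task.async(fn -> {0.42, %{}} end)

# A task that outlives its timeout.
slow =
  Task.async(fn ->
    Process.sleep(1_000)
    {0.42, %{}}
  end)

fast_result = await_with_timeout.(fast, 500)
slow_result = await_with_timeout.(slow, 50)

IO.inspect(fast_result, label: "fast")
IO.inspect(slow_result, label: "slow")
```

Unlike a bare `Task.await/2`, which exits the caller on timeout, this pattern shuts the task down and leaves the caller free to decide what to do next.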

## Telemetry Integration

The regularizer system emits comprehensive telemetry events for observability.

### Custom Loss Pipeline Events

**`[:tinkex, :custom_loss, :start]`**
- Measurements: `%{system_time: integer()}`
- Metadata: `%{regularizer_count: integer(), track_grad_norms: boolean()}`

**`[:tinkex, :custom_loss, :stop]`**
- Measurements: `%{duration: integer(), loss_total: float(), regularizer_total: float()}`
- Metadata: `%{regularizer_count: integer()}`

**`[:tinkex, :custom_loss, :exception]`**
- Measurements: `%{duration: integer()}`
- Metadata: `%{reason: term()}`

### Per-Regularizer Events

**`[:tinkex, :regularizer, :compute, :start]`**
- Measurements: `%{system_time: integer()}`
- Metadata: `%{regularizer_name: String.t(), weight: float(), async: boolean()}`

**`[:tinkex, :regularizer, :compute, :stop]`**
- Measurements: `%{duration: integer(), value: float(), contribution: float(), grad_norm: float() | nil}`
- Metadata: `%{regularizer_name: String.t(), weight: float(), async: boolean()}`

**`[:tinkex, :regularizer, :compute, :exception]`**
- Measurements: `%{duration: integer()}`
- Metadata: `%{regularizer_name: String.t(), weight: float(), reason: term()}`

### Attaching Telemetry Handlers

Use the built-in telemetry helper:

```elixir
alias Tinkex.Regularizer.Telemetry

# Attach logger (logs all events)
handler_id = Telemetry.attach_logger(level: :info)

# Run pipeline (emits telemetry)
{:ok, output} = Pipeline.compute(data, logprobs, base_loss_fn,
  regularizers: regularizers,
  track_grad_norms: true
)

# Detach when done
Telemetry.detach(handler_id)
```

Or attach custom handlers:

```elixir
:telemetry.attach(
  "my-regularizer-handler",
  [:tinkex, :regularizer, :compute, :stop],
  fn _event, measurements, metadata, _config ->
    IO.puts("Regularizer #{metadata.regularizer_name} completed in #{measurements.duration}μs")
    IO.puts("  Value: #{measurements.value}")
    IO.puts("  Contribution: #{measurements.contribution}")
  end,
  nil
)
```
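
To handle success and failure events with one handler, `:telemetry.attach_many/4` can be used with a captured module function. The sketch below is self-contained: it installs `:telemetry` with `Mix.install/1` and simulates the events with `:telemetry.execute/3` instead of running a real pipeline. The module name `RegularizerMetricsHandler` and the simulated measurement values are hypothetical; the event names and map keys follow the tables above.

```elixir
Mix.install([{:telemetry, "~> 1.0"}])

defmodule RegularizerMetricsHandler do
  # Hypothetical handler: forwards each event to the pid passed as config.
  def handle_event([:tinkex, :regularizer, :compute, :stop], measurements, metadata, pid) do
    send(pid, {:regularizer_stop, metadata.regularizer_name, measurements.contribution})
  end

  def handle_event([:tinkex, :regularizer, :compute, :exception], _measurements, metadata, pid) do
    send(pid, {:regularizer_failed, metadata.regularizer_name, metadata.reason})
  end
end

:ok =
  :telemetry.attach_many(
    "regularizer-metrics",
    [
      [:tinkex, :regularizer, :compute, :stop],
      [:tinkex, :regularizer, :compute, :exception]
    ],
    &RegularizerMetricsHandler.handle_event/4,
    self()
  )

# Simulated events (a real run would emit these from the executor).
:telemetry.execute(
  [:tinkex, :regularizer, :compute, :stop],
  %{duration: 1_200, value: 4.9, contribution: 0.049, grad_norm: nil},
  %{regularizer_name: "l1", weight: 0.01, async: false}
)

:telemetry.execute(
  [:tinkex, :regularizer, :compute, :exception],
  %{duration: 300},
  %{regularizer_name: "failing", weight: 0.01, reason: :oops}
)

# Handlers run synchronously in the emitting process, so both messages
# are already in the mailbox at this point.
events =
  for _ <- 1..2 do
    receive do
      message -> message
    after
      100 -> :none
    end
  end

IO.inspect(events, label: "events")
:telemetry.detach("regularizer-metrics")
```

Passing a captured function such as `&Module.function/4` instead of an anonymous function is the form the `:telemetry` documentation recommends for handlers.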

## Output Structure

### CustomLossOutput

The pipeline returns a `CustomLossOutput` struct:

```elixir
%CustomLossOutput{
  loss_total: float(),              # Total composed loss
  base_loss: %{                     # Base loss component
    value: float(),
    metrics: map(),
    grad_norm: float() | nil
  },
  regularizers: %{                  # Per-regularizer outputs
    String.t() => RegularizerOutput.t()
  },
  regularizer_total: float(),       # Sum of all regularizer contributions
  total_grad_norm: float() | nil    # Total gradient L2 norm
}
```

### RegularizerOutput

Each regularizer produces a `RegularizerOutput`:

```elixir
%RegularizerOutput{
  name: String.t(),                 # Regularizer name
  value: float(),                   # Raw loss value
  weight: float(),                  # Weight multiplier
  contribution: float(),            # weight × value (added to total)
  custom_metrics: map(),            # Custom metrics from compute/3
  grad_norm: float() | nil,         # Gradient L2 norm
  grad_norm_weighted: float() | nil # weight × grad_norm
}
```

### JSON Serialization

Both output types implement `Jason.Encoder` for easy serialization:

```elixir
{:ok, output} = Pipeline.compute(data, logprobs, base_loss_fn,
  regularizers: regularizers,
  track_grad_norms: true
)

# Serialize to JSON
json = Jason.encode!(output, pretty: true)
File.write!("training_metrics.json", json)

# Deserialize: yields a plain map with string keys, not the original struct,
# so any struct must be rebuilt manually
decoded = Jason.decode!(json)
```
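
Manual reconstruction could look like the following sketch. It assumes the encoded JSON keys mirror the struct field names listed above, which this guide does not confirm. `Demo.RegularizerOutput` is a local stand-in defined only for this example, not the Tinkex struct, and the decoded map is a hypothetical payload written inline so the snippet runs without Jason.

```elixir
defmodule Demo.RegularizerOutput do
  # Local stand-in mirroring the documented fields; not the Tinkex struct.
  defstruct [
    :name,
    :value,
    :weight,
    :contribution,
    :grad_norm,
    :grad_norm_weighted,
    custom_metrics: %{}
  ]

  # Builds the struct from a decoded JSON map with string keys.
  def from_decoded(map) when is_map(map) do
    %__MODULE__{
      name: Map.fetch!(map, "name"),
      value: Map.fetch!(map, "value"),
      weight: Map.fetch!(map, "weight"),
      contribution: Map.fetch!(map, "contribution"),
      custom_metrics: Map.get(map, "custom_metrics", %{}),
      grad_norm: Map.get(map, "grad_norm"),
      grad_norm_weighted: Map.get(map, "grad_norm_weighted")
    }
  end
end

# Hypothetical payload, shaped like the result of Jason.decode!/1.
decoded = %{
  "loss_total" => 1.029,
  "regularizers" => %{
    "l1" => %{
      "name" => "l1",
      "value" => 4.9,
      "weight" => 0.01,
      "contribution" => 0.049,
      "custom_metrics" => %{},
      "grad_norm" => nil,
      "grad_norm_weighted" => nil
    }
  }
}

regularizers =
  Map.new(decoded["regularizers"], fn {name, fields} ->
    {name, Demo.RegularizerOutput.from_decoded(fields)}
  end)

IO.inspect(regularizers["l1"], label: "l1")
```

Using `Map.fetch!/2` for required fields makes a missing key fail loudly, while `Map.get/3` tolerates the optional gradient-norm fields being absent.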

## Error Handling

The pipeline and executor report failures as tagged `{:error, reason}` tuples, as the patterns below show.

### Common Error Patterns

**Duplicate regularizer names:**

```elixir
regularizers = [
  RegularizerSpec.new(%{fn: &l1/2, weight: 0.01, name: "dup"}),
  RegularizerSpec.new(%{fn: &l2/2, weight: 0.02, name: "dup"})
]

{:error, {:pipeline_failed, %ArgumentError{message: msg}}} =
  Pipeline.compute(data, logprobs, base_loss_fn, regularizers: regularizers)

# msg: "Duplicate regularizer names: [\"dup\"]"
```

**Invalid base loss function:**

```elixir
{:error, {:pipeline_failed, %ArgumentError{}}} =
  Pipeline.compute(data, logprobs, "not a function")
```

**Regularizer execution failure:**

```elixir
failing_spec = RegularizerSpec.new(%{
  fn: fn _data, _logprobs -> raise "oops" end,
  weight: 0.01,
  name: "failing"
})

{:error, {:regularizer_failed, "failing", exception}} =
  Executor.execute_one(failing_spec, data, logprobs)
```

**Timeout:**

```elixir
slow_spec = RegularizerSpec.new(%{
  fn: fn _data, logprobs ->
    Task.async(fn ->
      :timer.sleep(10_000)
      {Nx.sum(logprobs), %{}}
    end)
  end,
  weight: 0.01,
  name: "slow",
  async: true
})

{:error, :timeout} =
  Executor.execute_one(slow_spec, data, logprobs, timeout: 100)
```

### Handling Errors

Always pattern match on error tuples:

```elixir
require Logger

case Pipeline.compute(data, logprobs, base_loss_fn, regularizers: regularizers) do
  {:ok, output} ->
    # Success - use output
    process_training_step(output)

  {:error, {:pipeline_failed, exception}} ->
    # Pipeline-level error: log it, then raise
    Logger.error("Pipeline failed: #{Exception.message(exception)}")
    raise exception

  {:error, {:regularizer_failed, name, exception}} ->
    # Specific regularizer failed
    Logger.error("Regularizer #{name} failed: #{inspect(exception)}")
    :retry

  {:error, {:regularizer_exit, name, reason}} ->
    # Regularizer process exited
    Logger.error("Regularizer #{name} exited: #{inspect(reason)}")
    :halt

  {:error, other} ->
    # Other errors
    Logger.error("Unknown error: #{inspect(other)}")
    :halt
end
```

## Integration with Training API

When using Tinkex with a live Tinker backend, pass regularizers to `TrainingClient.forward_backward_custom/4`:

```elixir
alias Tinkex.Types.{Datum, ModelInput, RegularizerSpec}

# 1. Create training client
config = Tinkex.Config.new(api_key: System.fetch_env!("TINKER_API_KEY"))
{:ok, service} = Tinkex.ServiceClient.start_link(config: config)
{:ok, training} = Tinkex.ServiceClient.create_lora_training_client(service, "meta-llama/Llama-3.1-8B",
  lora_config: %Tinkex.Types.LoraConfig{rank: 16}
)

# 2. Prepare training data
{:ok, model_input} = ModelInput.from_text("The quick brown fox",
  model_name: "meta-llama/Llama-3.1-8B",
  training_client: training
)

datum = Datum.new(%{
  model_input: model_input,
  loss_fn_inputs: %{
    target_tokens: Nx.tensor([1, 2, 3, 4, 5]),
    weights: Nx.tensor([1.0, 1.0, 1.0, 1.0, 1.0])
  }
})

# 3. Define base loss and regularizers
base_loss_fn = fn _data, logprobs ->
  nll = Nx.negate(Nx.mean(logprobs))
  {nll, %{}}
end

regularizers = [
  RegularizerSpec.new(%{fn: &l1/2, weight: 0.01, name: "l1"}),
  RegularizerSpec.new(%{fn: &entropy/2, weight: 0.001, name: "entropy"})
]

# 4. Execute forward-backward pass with custom loss
{:ok, task} = Tinkex.TrainingClient.forward_backward_custom(
  training,
  [datum],
  base_loss_fn,
  regularizers: regularizers,
  track_grad_norms: true
)

# 5. Await results
{:ok, output} = Task.await(task, :infinity)

# output is computed from the real logprobs returned by the server
IO.puts("Total loss: #{output.loss_total}")
IO.puts("Base loss: #{output.base_loss.value}")
IO.puts("Regularizer total: #{output.regularizer_total}")
```

The `TrainingClient.forward_backward_custom/4` function:
1. Sends the training data to the Tinker server
2. Performs a forward pass to get log probabilities
3. Executes `Pipeline.compute/4` locally with the returned logprobs
4. Returns the composed `CustomLossOutput`

## Complete Example

Here's a complete example covering the main features (composition, gradient norm tracking, parallel execution, and JSON serialization):

```elixir
alias Tinkex.Regularizer.Pipeline
alias Tinkex.Types.RegularizerSpec

# Define base loss
base_loss_fn = fn _data, logprobs ->
  nll = Nx.negate(Nx.mean(logprobs))
  {nll, %{}}
end

# Define regularizers
l1_spec = RegularizerSpec.new(%{
  fn: fn _data, logprobs ->
    {Nx.sum(Nx.abs(logprobs)), %{}}
  end,
  weight: 0.01,
  name: "l1_sparsity"
})

l2_spec = RegularizerSpec.new(%{
  fn: fn _data, logprobs ->
    {Nx.sum(Nx.pow(logprobs, 2)), %{}}
  end,
  weight: 0.005,
  name: "l2_weight_decay"
})

entropy_spec = RegularizerSpec.new(%{
  fn: fn _data, logprobs ->
    probs = Nx.exp(logprobs)
    neg_entropy = Nx.sum(Nx.multiply(probs, logprobs))
    {neg_entropy, %{}}
  end,
  weight: 0.001,
  name: "entropy"
})

regularizers = [l1_spec, l2_spec, entropy_spec]

# Mock data
logprobs = Nx.tensor([-0.5, -1.2, -0.8, -2.1, -0.3])
data = []

# Execute pipeline with all features
{:ok, output} = Pipeline.compute(data, logprobs, base_loss_fn,
  regularizers: regularizers,
  track_grad_norms: true,
  parallel: true,
  timeout: 30_000
)

# Display results
IO.puts("=== Training Step Results ===")
IO.puts("Total Loss: #{Float.round(output.loss_total, 6)}")
IO.puts("Base Loss: #{Float.round(output.base_loss.value, 6)}")
IO.puts("Regularizer Total: #{Float.round(output.regularizer_total, 6)}")

if output.total_grad_norm do
  IO.puts("Total Grad Norm: #{Float.round(output.total_grad_norm, 6)}")
end

IO.puts("\n=== Per-Regularizer Breakdown ===")
for {name, reg} <- output.regularizers do
  IO.puts("\n#{name}:")
  IO.puts("  value: #{Float.round(reg.value, 6)}")
  IO.puts("  weight: #{reg.weight}")
  IO.puts("  contribution: #{Float.round(reg.contribution, 6)}")

  if reg.grad_norm do
    IO.puts("  grad_norm: #{Float.round(reg.grad_norm, 6)}")
    IO.puts("  grad_norm_weighted: #{Float.round(reg.grad_norm_weighted, 6)}")
  end
end

# Serialize to JSON
json = Jason.encode!(output, pretty: true)
File.write!("training_step.json", json)
IO.puts("\n✓ Saved to training_step.json")
```

## Best Practices

### 1. Start with Small Weights

Begin with small regularizer weights and increase gradually:

```elixir
# Start small
regularizers = [
  RegularizerSpec.new(%{fn: &l1/2, weight: 0.001, name: "l1"}),
  RegularizerSpec.new(%{fn: &l2/2, weight: 0.0005, name: "l2"})
]

# Monitor gradient norms to tune weights
{:ok, output} = Pipeline.compute(data, logprobs, base_loss_fn,
  regularizers: regularizers,
  track_grad_norms: true
)

# Adjust if regularizers dominate base loss
```

### 2. Use Gradient Norms for Tuning

Track gradient norms to ensure balanced contributions:

```elixir
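# Check if regularizers are dominating
# (requires track_grad_norms: true; otherwise the grad_norm fields are nil)
base_norm = output.base_loss.grad_norm
reg_norms = Enum.map(output.regularizers, fn {_name, reg} ->
  reg.grad_norm_weighted
end)
# Sum of weighted norms: an upper bound on the norm of the combined regularizer gradient
total_reg_norm = Enum.sum(reg_norms)

ratio = total_reg_norm / base_norm
IO.puts("Regularizer/Base gradient ratio: #{ratio}")

# As a rough starting point, aim for ratio ~0.1 to 0.5 (regularizers shouldn't dominate)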
```

### 3. Avoid Nx.to_number in Regularizers

Keep operations as tensors for gradient compatibility:

```elixir
# BAD
bad = fn _data, logprobs ->
  loss = Nx.sum(logprobs)
  {loss, %{"value" => Nx.to_number(loss)}}  # Breaks gradients!
end

# GOOD
good = fn _data, logprobs ->
  loss = Nx.sum(logprobs)
  {loss, %{}}  # Pipeline will compute metrics
end
```

### 4. Use Unique Names

Ensure each regularizer has a unique name. Names identify telemetry events and index the output map, and the pipeline rejects duplicates:

```elixir
# BAD - duplicate names
regularizers = [
  RegularizerSpec.new(%{fn: &l1/2, weight: 0.01, name: "reg"}),
  RegularizerSpec.new(%{fn: &l2/2, weight: 0.01, name: "reg"})  # Error!
]

# GOOD - unique names
regularizers = [
  RegularizerSpec.new(%{fn: &l1/2, weight: 0.01, name: "l1"}),
  RegularizerSpec.new(%{fn: &l2/2, weight: 0.01, name: "l2"})
]
```

### 5. Handle Errors Gracefully

Always pattern match on error results:

```elixir
case Pipeline.compute(data, logprobs, base_loss_fn, regularizers: regularizers) do
  {:ok, output} ->
    process_output(output)

  {:error, reason} ->
    Logger.error("Training step failed: #{inspect(reason)}")
    :retry
end
```

### 6. Use Parallel Execution

Enable parallel execution for multiple regularizers:

```elixir
# Parallel (default) - better throughput
{:ok, output} = Pipeline.compute(data, logprobs, base_loss_fn,
  regularizers: regularizers,
  parallel: true
)

# Sequential - only for debugging
{:ok, output} = Pipeline.compute(data, logprobs, base_loss_fn,
  regularizers: regularizers,
  parallel: false
)
```

### 7. Monitor with Telemetry

Attach telemetry handlers for production monitoring:

```elixir
:telemetry.attach(
  "my-training-monitor",
  [:tinkex, :custom_loss, :stop],
  fn _event, measurements, metadata, _config ->
    # Log to monitoring system
    MyMonitoring.record_metric("training.loss", measurements.loss_total)
    MyMonitoring.record_metric("training.regularizers", metadata.regularizer_count)
  end,
  nil
)
```

## See Also

- **API Reference**: `docs/guides/api_reference.md`
- **Training Loop**: `docs/guides/training_loop.md`
- **Examples**: `examples/structured_regularizers.exs`, `examples/structured_regularizers_live.exs`
- **Source Code**: `lib/tinkex/regularizer/`