README.md

<p align="center">
  <img src="assets/nx_penalties.svg" alt="NxPenalties" width="400">
</p>

<p align="center">
  <strong>Composable Regularization Penalties for Elixir ML</strong>
</p>

<p align="center">
  <a href="https://hex.pm/packages/nx_penalties">
    <img src="https://img.shields.io/hexpm/v/nx_penalties.svg" alt="Hex.pm Version">
  </a>
  <a href="https://hexdocs.pm/nx_penalties">
    <img src="https://img.shields.io/badge/docs-hexdocs-blue.svg" alt="Documentation">
  </a>
  <a href="https://opensource.org/licenses/MIT">
    <img src="https://img.shields.io/badge/License-MIT-yellow.svg" alt="License: MIT">
  </a>
</p>

---

## Overview

NxPenalties is a tensor-only library of regularization primitives for [Nx](https://github.com/elixir-nx/nx). It is designed to be composable inside `defn` code and training loops, leaving any data-aware adaptation (e.g., resolving references from data structures) to downstream libraries such as Tinkex.

### Features (v0.1.0)

- **Penalties**: L1, L2 (with centering/clipping), Elastic Net
- **Divergences**: KL, JS, Entropy (bonus/penalty, normalization)
- **Constraints**: Orthogonality (soft/hard/spectral, axis options), Consistency (MSE/L1/Cosine/KL)
- **Gradient Penalties**: Gradient norm, interpolated (WGAN-GP), output magnitude proxy
- **Pipeline**: Compose penalties with weights, enable/disable, gradient-compatible computation
- **Integrations**: Axon loss wrapping, Polaris gradient transforms (L1/L2/Elastic Net decay)
- **Debugging**: Gradient norm tracking, NaN/Inf validation
- **Telemetry**: Pipeline compute events

## Installation

Add `nx_penalties` to your list of dependencies in `mix.exs`:

```elixir
def deps do
  [
    {:nx_penalties, "~> 0.1.0"}
  ]
end
```

## Quick Start

### Simple Penalties

```elixir
# Example weights (consistent with the results shown below)
weights = Nx.tensor([1.0, -2.0, 3.0, -0.5])

# L1 penalty (promotes sparsity)
l1_loss = NxPenalties.l1(weights)
# => Nx.tensor(6.5)

# L2 penalty (weight decay)
l2_loss = NxPenalties.l2(weights, lambda: 0.01)
# => Nx.tensor(0.1425)

# Elastic Net (combined L1 + L2)
elastic_loss = NxPenalties.elastic_net(weights, l1_ratio: 0.5)
# => Nx.tensor(10.375)

# Add to your training loss
total_loss = Nx.add(base_loss, l1_loss)
```

### Pipeline Composition

Compose multiple penalties with individual weights:

```elixir
pipeline =
  NxPenalties.pipeline([
    {:l1, weight: 0.001},
    {:l2, weight: 0.01},
    {:entropy, weight: 0.1, opts: [mode: :bonus]}
  ])

{total_penalty, metrics} = NxPenalties.compute(pipeline, model_outputs)
total_loss = Nx.add(base_loss, total_penalty)
```

### Dynamic Weight Adjustment

Useful for curriculum learning or adaptive regularization:

```elixir
# Update weights during training
pipeline =
  pipeline
  |> NxPenalties.Pipeline.update_weight(:l1, 0.01)  # Increase L1
  |> NxPenalties.Pipeline.update_weight(:l2, 0.001) # Decrease L2

# Enable/disable penalties
pipeline = NxPenalties.Pipeline.set_enabled(pipeline, :entropy, false)
```
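
A minimal curriculum sketch, assuming your own outer training loop; `train_one_epoch` and `state` are placeholders, not part of NxPenalties:

```elixir
# Ramp the L1 weight up each epoch inside your own training loop
{pipeline, state} =
  for epoch <- 1..10, reduce: {pipeline, state} do
    {pipe, st} ->
      pipe = NxPenalties.Pipeline.update_weight(pipe, :l1, 0.001 * epoch)
      {pipe, train_one_epoch.(st, pipe)}
  end
```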

### Gradient-Compatible Computation

Use `compute_total/3` inside `defn`:

```elixir
total = NxPenalties.compute_total(pipeline, tensor)

grad_fn = Nx.Defn.grad(fn t -> NxPenalties.compute_total(pipeline, t) end)
gradients = grad_fn.(tensor)
```

## Divergences

For probability distributions (log-space inputs):

```elixir
# KL Divergence - knowledge distillation
kl_loss = NxPenalties.kl_divergence(student_logprobs, teacher_logprobs)

# JS Divergence - symmetric comparison
js_loss = NxPenalties.js_divergence(p_logprobs, q_logprobs)

# Entropy - encourage/discourage confidence
entropy_penalty = NxPenalties.entropy(logprobs, mode: :penalty)  # Minimize entropy
entropy_bonus = NxPenalties.entropy(logprobs, mode: :bonus)      # Maximize entropy
```
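
All divergence inputs are log-probabilities. If you start from raw logits, a numerically stable log-softmax is easy to express in Nx; `to_logprobs` and `student_logits` below are placeholders, not part of NxPenalties:

```elixir
defmodule MyProject.LogProbs do
  import Nx.Defn

  # Numerically stable log-softmax along the last axis
  defn to_logprobs(logits) do
    shifted = logits - Nx.reduce_max(logits, axes: [-1], keep_axes: true)
    shifted - Nx.log(Nx.sum(Nx.exp(shifted), axes: [-1], keep_axes: true))
  end
end

student_logprobs = MyProject.LogProbs.to_logprobs(student_logits)
```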

## Gradient Penalties

For enforcing Lipschitz constraints (WGAN-GP style):

```elixir
# Full gradient penalty (expensive - use sparingly)
loss_fn = fn x -> Nx.sum(Nx.pow(x, 2)) end
gp = NxPenalties.gradient_penalty(loss_fn, tensor, target_norm: 1.0)

# Cheaper proxy - output magnitude penalty
mag_penalty = NxPenalties.output_magnitude_penalty(model_output, target: 1.0)

# Interpolated gradient penalty (WGAN-GP)
interp_gp = NxPenalties.interpolated_gradient_penalty(loss_fn, fake, real, target_norm: 1.0)
```

**Performance Warning**: Gradient penalties differentiate through a gradient computation, so backpropagating them requires second-order derivatives and is computationally expensive. Best practices:
- Apply the penalty every N training steps instead of every step (see the sketch below)
- Use `output_magnitude_penalty/2` as a cheaper alternative
- See `examples/gradient_penalty.exs` for usage patterns
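
One way to amortize the cost, sketched with placeholder bindings (`step`, `loss_fn`, `tensor`, `model_output`, `base_loss`) from your own training loop:

```elixir
penalty =
  if rem(step, 5) == 0 do
    # Full gradient penalty every 5 steps
    NxPenalties.gradient_penalty(loss_fn, tensor, target_norm: 1.0)
  else
    # Cheap magnitude proxy on the other steps
    NxPenalties.output_magnitude_penalty(model_output, target: 1.0)
  end

total_loss = Nx.add(base_loss, penalty)
```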

## Constraints

Structural penalties for representations:

```elixir
alias NxPenalties.Constraints

# Orthogonality - encourage uncorrelated representations
hidden_states = Nx.tensor([[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]])

# Soft mode: penalize off-diagonal correlations only
penalty = Constraints.orthogonality(hidden_states, mode: :soft)

# Hard mode: penalize deviation from identity matrix
penalty = Constraints.orthogonality(hidden_states, mode: :hard)

# Spectral mode: encourage uniform singular values
penalty = Constraints.orthogonality(hidden_states, mode: :spectral)

# Axis options for 3D tensors [batch, seq, vocab]
penalty = Constraints.orthogonality(logits, axis: :sequence)   # Decorrelate positions
penalty = Constraints.orthogonality(embeddings, axis: :vocabulary)  # Decorrelate dimensions
```

```elixir
# Consistency - penalize divergence between paired outputs
clean_output = model.(clean_input)
noisy_output = model.(add_noise.(clean_input))

# MSE (default)
penalty = Constraints.consistency(clean_output, noisy_output)

# L1 distance
penalty = Constraints.consistency(clean_output, noisy_output, metric: :l1)

# Cosine distance
penalty = Constraints.consistency(clean_output, noisy_output, metric: :cosine)

# Symmetric KL for log-probabilities
penalty = Constraints.consistency(logprobs1, logprobs2, metric: :kl)
```

## Polaris Integration

Gradient-level weight decay transforms (AdamW-style):

```elixir
alias NxPenalties.Integration.Polaris, as: PolarisIntegration

# Add L2 weight decay to any optimizer
optimizer =
  Polaris.Optimizers.adam(learning_rate: 0.001)
  |> PolarisIntegration.add_l2_decay(0.01)

# Add L1 decay for sparsity
optimizer =
  Polaris.Optimizers.sgd(learning_rate: 0.01)
  |> PolarisIntegration.add_l1_decay(0.001)

# Elastic Net decay (combined L1 + L2)
optimizer =
  Polaris.Optimizers.adam(learning_rate: 0.001)
  |> PolarisIntegration.add_elastic_net_decay(0.01, 0.3)  # 30% L1, 70% L2

# Compose multiple transforms
optimizer =
  Polaris.Optimizers.adam(learning_rate: 0.001)
  |> PolarisIntegration.add_l2_decay(0.01)
  |> PolarisIntegration.add_l1_decay(0.001)
```

**Loss-Based vs Gradient-Based**: Loss-based regularization (pipeline) adds penalty to loss before backprop. Gradient-based (Polaris transforms) modifies gradients directly. They're equivalent for SGD but differ for adaptive optimizers like Adam—gradient-based is generally preferred for modern training.
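
Side by side, reusing the alias above (a sketch; `weights` and `base_loss` are placeholders):

```elixir
# Loss-based: the penalty enters the loss and flows through backprop
total_loss = Nx.add(base_loss, NxPenalties.l2(weights, lambda: 0.01))

# Gradient-based: the optimizer update itself is modified (AdamW-style)
optimizer =
  Polaris.Optimizers.adam(learning_rate: 0.001)
  |> PolarisIntegration.add_l2_decay(0.01)
```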

## Axon Integration

Wrap your loss function with regularization:

```elixir
alias NxPenalties.Integration.Axon, as: AxonIntegration

# Create penalty pipeline
pipeline = NxPenalties.pipeline([
  {:l2, weight: 0.01}
])

# Wrap loss function
regularized_loss = AxonIntegration.wrap_loss_with_pipeline(
  &Axon.Losses.mean_squared_error/2,
  pipeline
)

# Use in training
model
|> Axon.Loop.trainer(regularized_loss, optimizer)
|> Axon.Loop.run(data, epochs: 10)
```

## API Reference

### Penalty Functions

| Function | Description | Options |
|----------|-------------|---------|
| `l1/2` | L1 norm (Lasso) | `lambda`, `reduction` |
| `l2/2` | L2 norm squared (Ridge) | `lambda`, `reduction`, `center`, `clip` |
| `elastic_net/2` | Combined L1+L2 | `lambda`, `l1_ratio`, `reduction` |

### Divergence Functions

| Function | Description | Options |
|----------|-------------|---------|
| `kl_divergence/3` | KL(P \|\| Q) | `reduction` |
| `js_divergence/3` | Jensen-Shannon | `reduction` |
| `entropy/2` | Shannon entropy | `mode`, `reduction`, `normalize` |

### Gradient Penalty Functions

| Function | Description | Options |
|----------|-------------|---------|
| `gradient_penalty/3` | Gradient norm penalty (expensive) | `target_norm` |
| `output_magnitude_penalty/2` | Cheaper proxy for gradient penalty | `target`, `reduction` |
| `interpolated_gradient_penalty/4` | WGAN-GP style interpolated penalty | `target_norm` |

### Pipeline Functions

| Function | Description |
|----------|-------------|
| `pipeline/1` | Create pipeline from keyword list |
| `compute/3` | Execute pipeline, return `{total, metrics}` |
| `compute_total/3` | Execute pipeline, return tensor only (gradient-safe) |
| `Pipeline.add/4` | Add penalty to pipeline |
| `Pipeline.update_weight/3` | Change penalty weight |
| `Pipeline.set_enabled/3` | Enable/disable penalty |

### Constraint Functions

| Function | Description | Options |
|----------|-------------|---------|
| `Constraints.orthogonality/2` | Decorrelation penalty | `mode` (`:soft`/`:hard`/`:spectral`), `normalize`, `axis` (`:rows`/`:sequence`/`:vocabulary`) |
| `Constraints.consistency/3` | Paired output consistency | `metric` (`:mse`/`:l1`/`:cosine`/`:kl`), `reduction` |

### Polaris Transforms

| Function | Description | Parameters |
|----------|-------------|------------|
| `Integration.Polaris.add_l2_decay/2` | AdamW-style weight decay | `decay` (default: `0.01`) |
| `Integration.Polaris.add_l1_decay/2` | Sparsity-inducing decay | `decay` (default: `0.001`) |
| `Integration.Polaris.add_elastic_net_decay/3` | Combined L1+L2 decay | `decay`, `l1_ratio` |

### Utility Functions

| Function | Description | Returns |
|----------|-------------|---------|
| `NxPenalties.validate/1` | Check for NaN/Inf | `{:ok, tensor}` or `{:error, :nan\|:inf}` |
| `GradientTracker.compute_grad_norm/2` | Gradient L2 norm | `float() \| nil` |
| `GradientTracker.pipeline_grad_norms/2` | Per-penalty grad norms | `map()` |
| `GradientTracker.total_grad_norm/2` | Total pipeline grad norm | `float() \| nil` |

## Telemetry Events

NxPenalties emits telemetry events for monitoring:

```elixir
# Attach a handler (Logger.info/1 is a macro, so `require Logger` first)
require Logger

:telemetry.attach(
  "nx-penalties-logger",
  [:nx_penalties, :pipeline, :compute, :stop],
  fn _event, measurements, metadata, _config ->
    # Assuming :telemetry.span semantics, duration is in native time units
    duration_ns = System.convert_time_unit(measurements.duration, :native, :nanosecond)
    Logger.info("Pipeline computed in #{duration_ns}ns")
    Logger.info("Metrics: #{inspect(metadata.metrics)}")
  end,
  nil
)
```

| Event | Measurements | Metadata |
|-------|-------------|----------|
| `[:nx_penalties, :pipeline, :compute, :start]` | `system_time` | `size` |
| `[:nx_penalties, :pipeline, :compute, :stop]` | `duration` | `metrics`, `total` |

## Debugging & Monitoring

### Gradient Tracking

Monitor which penalties contribute most to the gradient signal:

```elixir
pipeline = NxPenalties.pipeline([
  {:l1, weight: 0.001},
  {:l2, weight: 0.01},
  {:entropy, weight: 0.1, opts: [mode: :penalty]}
])

# Enable gradient norm tracking
{total, metrics} = NxPenalties.compute(pipeline, tensor, track_grad_norms: true)

metrics["l1_grad_norm"]       # L2 norm of L1 penalty's gradient
metrics["l2_grad_norm"]       # L2 norm of L2 penalty's gradient
metrics["entropy_grad_norm"]  # L2 norm of entropy penalty's gradient
metrics["total_grad_norm"]    # Combined gradient norm
```

**What it measures**: These are `∂penalty/∂(pipeline_input)`, not `∂penalty/∂params`. The "pipeline input" is whatever tensor you pass to `compute/3`—typically model outputs, activations, or logprobs.

**Performance note**: Gradient tracking requires additional backward passes. Only enable when debugging or for periodic monitoring (e.g., every 100 steps).
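
For example, request norms only on a schedule (a sketch; `step` is a placeholder from your training loop):

```elixir
track? = rem(step, 100) == 0
{total, metrics} = NxPenalties.compute(pipeline, tensor, track_grad_norms: track?)
```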

### Validation

Check tensors for numerical issues:

```elixir
case NxPenalties.validate(tensor) do
  {:ok, tensor} -> tensor  # Finite tensor, proceed with it
  {:error, :nan} -> Logger.warning("NaN detected in tensor")
  {:error, :inf} -> Logger.warning("Inf detected in tensor")
end
```
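
A common pattern is to guard the penalty before it reaches the loss (a sketch; `base_loss` is a placeholder):

```elixir
penalty = NxPenalties.compute_total(pipeline, tensor)

total_loss =
  case NxPenalties.validate(penalty) do
    {:ok, penalty} -> Nx.add(base_loss, penalty)
    # Drop a non-finite penalty rather than poisoning the loss
    {:error, _reason} -> base_loss
  end
```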

## Performance

All penalty functions are implemented using `Nx.Defn` for JIT compilation:

- **GPU acceleration** - Automatically uses EXLA/CUDA when available (see the backend configuration below)
- **Fused operations** - Penalties compose efficiently in the computation graph
- **Minimal overhead** - No runtime option parsing in hot path
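
To get JIT compilation and GPU execution, configure EXLA as the default Nx backend (assuming `exla` is already in your dependencies):

```elixir
# config/config.exs
import Config

config :nx, default_backend: EXLA.Backend
```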

## Testing

```bash
# Run tests
mix test

# Run with coverage
mix coveralls.html

# Run quality checks
mix quality  # format + credo + dialyzer
```

## Examples

See the `examples/` directory for complete usage examples:

- `basic_usage.exs` - L1, L2, Elastic Net penalty functions
- `pipeline_composition.exs` - Pipeline creation and manipulation
- `curriculum_learning.exs` - Dynamic weight adjustment over epochs
- `axon_training.exs` - Axon neural network integration
- `polaris_integration.exs` - Gradient-level weight decay transforms
- `constraints.exs` - Orthogonality and consistency penalties
- `entropy_normalization.exs` - Entropy bonus/penalty with normalization
- `gradient_penalty.exs` - Gradient penalties and proxies
- `gradient_tracking.exs` - Monitoring gradient norms

Run examples with:

```bash
mix run examples/basic_usage.exs
./examples/run_all.sh  # Run all examples
```

## Notes

- NxPenalties is tensor-only. Data-aware adapters (e.g., selecting targets from `loss_fn_inputs`) live in downstream libraries such as Tinkex.
- Gradient penalties are computationally heavy; use sparingly and consider proxies.

## Contributing

Contributions are welcome! Please read our contributing guidelines and submit PRs to the `main` branch.

1. Fork the repository
2. Create your feature branch (`git checkout -b feature/amazing-feature`)
3. Write tests first (TDD)
4. Ensure all checks pass (`mix quality && mix test`)
5. Submit a pull request

## License

MIT License - Copyright (c) 2025 North-Shore-AI

See [LICENSE](LICENSE) for details.

## Acknowledgments

- [Nx](https://github.com/elixir-nx/nx) - Numerical computing for Elixir
- [Axon](https://github.com/elixir-nx/axon) - Neural network library
- [Polaris](https://github.com/elixir-nx/polaris) - Gradient optimization