defmodule Axon.Loop do
@moduledoc """
Abstraction for modeling a reduction of a dataset with an accumulated
state for a number of epochs.
Inspired heavily by [PyTorch Ignite](https://pytorch.org/ignite/index.html).
The main abstraction is the `%Axon.Loop{}` struct, which controls a nested
reduction of the form:
Enum.reduce(1..max_epochs, state, fn epoch, state ->
Enum.reduce(data, state, &batch_step/2)
end)
`data` is assumed to be an `Enumerable` or `Stream` of input data which is
handled by a processing function, `batch_step`. The purpose of the loop
abstraction is to take away much of the boilerplate code used in solving machine
learning tasks. Tasks such as normalizing a dataset, hyperparameter optimization,
or training machine learning models boil down to writing one function:
defn batch_step(batch, state) do
# ...do something with batch...
updated_state
end
For tasks such as training a neural network, `state` will encapsulate things
such as model and optimizer state. For supervised learning tasks, `batch_step`
might look something like:
defn batch_step({inputs, targets}, state) do
%{parameters: params, optimizer_state: optim_state} = state
gradients = grad(params, objective_fn.(&1, inputs, targets))
{updates, new_optim_state} = optimizer.(optim_state, params, gradients)
new_params = apply_updates(params, updates)
%{parameters: new_params, optimizer_state: optim_state}
end
`batch_step` takes a batch of `{input, target}` pairs and the current state,
and updates the model parameters based on the gradients received from some arbitrary
objective function. This function will run in a nested loop, iterating over the entire
dataset for `N` epochs before finally returning the trained model state. By defining
1 function, we've created a training loop that works for most machine learning models.
In actuality, the loop abstraction accumulates a struct, `%Axon.Loop.State{}`, which looks
like (assuming `container` is a generic Elixir container of tensors, e.g. map, tuple, etc.):
%Axon.Loop.State{
epoch: integer(),
max_epoch: integer(),
iteration: integer(),
max_iteration: integer(),
metrics: map(string(), container()),
times: map(integer(), integer()),
step_state: container()
}
`batch_step` takes in the batch and the step state field and returns a `step_state`,
which is a generic container of state accumulated at each iteration. The rest of the fields
in the state struct are updated automatically behind the scenes.
The loop must start from some initial step state, thus most tasks must also provide
an additional initialization function to provide some starting point for the step
state. For machine learning tasks, the initialization function will return things like
initial model parameters and optimizer state.
Typically, the final output of the loop is the accumulated final state; however, you
may optionally apply an output transform to extract specific values at the end of the
loop. For example, `Axon.Loop.trainer/4` by default extracts trained model state:
output_transform = fn state ->
state.step_state[:model_state]
end
## Initialize and Step
The core of the Axon loop are the init and step functions. The initialization is an
arity-0 function which provides an initial step state:
init = fn ->
%{params: Axon.init(model)}
end
While the step function is the `batch_step` function mentioned earlier:
step = fn data, state ->
new_state = # ...do something...
new_state
end
Note that any optimization and training anonymous functions that need to be used in the
`batch_step` function can be passed as extra arguments. For example:
step_with_training_arguments = fn data, state, optimizer_update_fn, state_update_fn ->
# ...do something...
end
step = &(step_with_training_arguments.(&1, &2, actual_optimizer_update_fn, actual_state_update_fn))
## Metrics
Often times you want to compute metrics associated with your training iterations.
To accomplish this, you can attach metrics to each `Axon.Loop`. Assuming a `batch_step`
function which looks like:
defn batch_step({inputs, targets}, state) do
%{parameters: params, optimizer_state: optim_state} = state
gradients = grad(params, objective_fn.(&1, inputs, targets))
{updates, new_optim_state} = optimizer.(optim_state, params, gradients)
new_params = apply_updates(params, updates)
# Shown for simplicity, you can optimize this by calculating preds
# along with the gradient calculation
preds = model_fn.(params, inputs)
%{
y_true: targets,
y_pred: preds,
parameters: new_params,
optimizer_state: optim_state
}
end
You can attach metrics to this by using `Axon.Loop.metric/4`:
Axon.Loop.loop(&batch_step/2)
|> Axon.Loop.metric("Accuracy", :accuracy, fn %{y_true: y_, y_pred: y} -> [y_, y] end)
|> Axon.Loop.run(data)
Because metrics work directly on `step_state`, you typically need to provide an output
transform to indicate which values should be passed to your metric function. By default,
Axon assumes a supervised training task with the fields `:y_true` and `:y_pred` present
in the step state. See `Axon.Loop.metric/4` for more information.
Metrics will be tracked in the loop state using the user-provided key. Metrics integrate
seamlessly with the supervised metrics defined in `Axon.Metrics`. You can also use metrics
to keep running averages of some values in the original dataset.
## Events and Handlers
You can instrument several points in the loop using event handlers. By default, several events
are fired when running a loop:
events = [
:started, # After loop state initialization
:epoch_started, # On epoch start
:iteration_started, # On iteration start
:iteration_completed, # On iteration complete
:epoch_completed, # On epoch complete
:epoch_halted, # On epoch halt, if early halted
]
You can attach event handlers to events using `Axon.Loop.handle_event/4`:
loop
|> Axon.Loop.handle_event(:iteration_completed, &log_metrics/1, every: 100)
|> Axon.Loop.run(data)
The above will trigger `log_metrics/1` every 100 times the `:iteration_completed` event
is fired. Event handlers must return a tuple `{status, state}`, where `status` is an
atom with one of the following values:
:continue # Continue epoch, continue looping
:halt_epoch # Halt the epoch, continue looping
:halt_loop # Halt looping
And `state` is an updated `Axon.Loop.State` struct. Handler functions take as input
the current loop state.
It's important to note that event handlers are triggered in the order they are attached
to the loop. If you have two handlers on the same event, they will trigger in order:
loop
|> Axon.Loop.handle_event(:epoch_completed, &normalize_state/1) # Runs first
|> Axon.Loop.handle_event(:epoch_completed, &log_state/1) # Runs second
You may provide filters to filter when event handlers trigger. See `Axon.Loop.handle_event/4`
for more details on valid filters.
## Factories
Axon loops are typically created from one of the factory functions provided in this
module:
* `Axon.Loop.loop/3` - Creates a loop from step function and optional initialization
functions and output transform functions.
* `Axon.Loop.trainer/3` - Creates a supervised training loop from model, loss, and
optimizer.
* `Axon.Loop.evaluator/1` - Creates a supervised evaluator loop from model.
## Running loops
In order to execute a loop, you should use `Axon.Loop.run/3`:
Axon.Loop.run(loop, data, epochs: 10)
## Resuming loops
At times you may want to resume a loop from some previous state. You can accomplish this
with `Axon.Loop.from_state/2`:
loop
|> Axon.Loop.from_state(state)
|> Axon.Loop.run(data)
"""
require Axon.Updates
require Logger
alias __MODULE__, as: Loop
alias Axon.Loop.State
import Axon.Shared
import Nx.Defn
@file_version 1
@default_events [
:started,
:epoch_started,
:iteration_started,
:iteration_completed,
:epoch_completed,
:epoch_halted
]
@default_handlers %{
started: [],
epoch_started: [],
iteration_started: [],
iteration_completed: [],
epoch_completed: [],
epoch_halted: [],
halted: [],
completed: []
}
@valid_axon_losses [
:binary_cross_entropy,
:categorical_cross_entropy,
:categorical_hinge,
:hinge,
:kl_divergence,
:log_cosh,
:mean_absolute_error,
:mean_squared_error,
:poisson,
:soft_margin
]
@valid_axon_optimizers [
:adabelief,
:adagrad,
:adam,
:adamw,
:lamb,
:noisy_sgd,
:radam,
:rmsprop,
:sgd,
:yogi
]
@valid_axon_loss_scale [:identity, :dynamic, :static]
@doc false
@derive {Inspect, only: [:metrics, :handlers]}
@enforce_keys [:init, :step]
defstruct [
:init,
:step,
:attached_state,
:output_transform,
metrics: %{},
handlers: @default_handlers
]
## Step Factories
@doc """
Creates a supervised train step from a model, loss function, and
optimizer.
This function is intended for more fine-grained control over the loop
creation process. It returns a tuple of `{init_fn, step_fn}` where `init_fn`
is an initialization function which returns an initial step state and
`step_fn` is a supervised train step constructed from `model`, `loss`,
and `optimizer`.
`model` must be an Axon struct, a valid defn container
of Axon structs, or a `{init_fn, apply_fn}`-tuple where `init_fn` is
an arity-2 function which initializes the model state and `apply_fn` is
an arity-2 function which applies the forward pass of the model. The forward
pass of the model must return a map with keys `:prediction` and `:state`
representing the model's prediction and updated state for layers which
aggregate state during training.
`loss` must be an atom which matches a function in `Axon.Losses`, a list
of `{loss, weight}` tuples representing a basic weighted loss function
for multi-output models, or an arity-2 function representing a custom loss
function.
`optimizer` must be an atom matching the name of a valid optimizer in `Axon.Optimizers`,
or a `{init_fn, update_fn}` tuple where `init_fn` is an arity-1 function which
initializes the optimizer state from the model parameters and `update_fn` is an
arity-3 function that receives `(gradient, optimizer_state, model_parameters)` and
scales gradient updates with respect to input parameters, optimizer state, and gradients.
The `update_fn` returns `{scaled_updates, optimizer_state}`, which can then be applied to
the model through `model_parameters = Axon.Update.apply_updates(model_parameters, scaled_updates)`.
See `Axon.Updates` for more information on building optimizers.
## Options
* `:seed` - seed to use when constructing models. Seed controls random initialization
of model parameters. Defaults to no seed which constructs a random seed for you at
model build time.
* `:loss_scale` - type of loss-scaling to use, if any. Loss-scaling is necessary when
doing mixed precision training for numerical stability. Defaults to `:identity` or
no loss-scaling.
* `:gradient_accumulation_steps` - number of gradient accumulation steps to take during
training. Gradient accumulation decreases the number of updates by accumulating gradients
between steps, increasing the effective batch size on smaller devices. Defaults to 1.
"""
def train_step(model, loss, optimizer, opts \\ []) do
opts = Keyword.validate!(opts, [:seed, loss_scale: :identity, gradient_accumulation_steps: 1])
loss_scale = opts[:loss_scale] || :identity
gradient_accumulation_steps = opts[:gradient_accumulation_steps] || 1
{init_model_fn, forward_model_fn} = build_model_fns(model, :train, opts)
loss_fn = build_loss_fn(loss)
{init_optimizer_fn, update_optimizer_fn} = build_optimizer_fns(optimizer)
{init_loss_scale, scale_loss, unscale_grads} = build_loss_scale_fns(loss_scale)
init_fn = fn
{inp, tar}, %{} = init_model_state ->
model_state = init_model_fn.(inp, init_model_state)
optimizer_state = init_optimizer_fn.(model_state)
loss_scale_state = init_loss_scale.()
# TODO: is this expensive? Will it compute the entire
# forward?
%{prediction: output} = forward_model_fn.(model_state, inp)
%{
i: Nx.tensor(0),
y_true: zeros_like(tar),
y_pred: zeros_like(output),
loss: Nx.tensor(0.0),
gradient_step: Nx.tensor(0),
model_state: model_state,
gradient_state: zeros_like(model_state, type: :f32),
optimizer_state: optimizer_state,
loss_scale_state: loss_scale_state
}
data, state ->
raise_bad_training_inputs!(data, state)
end
# TODO: We should probably compute in same compute policy as MP
# here
objective_fn = fn model_state, loss_scale_state, inp, tar ->
model_out = forward_model_fn.(model_state, inp)
{scaled_loss, unscaled_loss} =
tar
|> loss_fn.(model_out.prediction)
|> then(fn loss ->
scaled =
loss
|> scale_loss.(loss_scale_state)
|> Nx.divide(gradient_accumulation_steps)
{scaled, Nx.divide(loss, gradient_accumulation_steps)}
end)
{model_out, scaled_loss, unscaled_loss}
end
step_fn = fn
{inp, tar}, %{} = state ->
%{
i: i,
gradient_step: gradient_step,
loss_scale_state: loss_scale_state,
gradient_state: gradient_state,
model_state: model_state,
optimizer_state: optimizer_state,
loss: loss
} = state
{{model_out, _batch_scaled_loss, batch_unscaled_loss}, gradients} =
Nx.Defn.value_and_grad(
model_state,
&objective_fn.(&1, loss_scale_state, inp, tar),
fn x -> elem(x, 1) end
)
{gradients, new_loss_scale_state} = unscale_grads.(gradients, loss_scale_state)
preds = model_out.prediction
new_state = model_out.state
new_loss =
loss
|> Nx.multiply(i)
|> Nx.add(Nx.multiply(batch_unscaled_loss, gradient_accumulation_steps))
|> Nx.divide(Nx.add(i, 1))
{new_model_state, new_optimizer_state, new_gradient_state, new_gradient_step} =
accumulate_gradients(
gradients,
model_state,
new_state,
optimizer_state,
gradient_state,
gradient_step,
update_optimizer_fn,
steps: gradient_accumulation_steps
)
%{
state
| i: Nx.add(i, 1),
gradient_step: new_gradient_step,
y_true: tar,
y_pred: preds,
loss: new_loss,
model_state: new_model_state,
gradient_state: new_gradient_state,
optimizer_state: new_optimizer_state,
loss_scale_state: new_loss_scale_state
}
data, state ->
raise_bad_training_inputs!(data, state)
end
{
Nx.Defn.jit(init_fn, on_conflict: :reuse),
Nx.Defn.jit(step_fn, on_conflict: :reuse)
}
end
defnp accumulate_gradients(
gradients,
model_state,
new_state,
optimizer_state,
gradient_state,
gradient_step,
update_optimizer_fn,
opts \\ []
) do
opts = keyword!(opts, [:steps])
steps = opts[:steps]
{_, new_model_state, _, new_optimizer_state, new_gradient_state, new_gradient_step, _} =
while {gradients, model_state, new_state, optimizer_state, gradient_state, gradient_step,
flag = Nx.tensor(1)},
flag do
if Nx.greater_equal(gradient_step, steps - 1) do
{updates, new_optimizer_state} =
update_optimizer_fn.(gradients, optimizer_state, model_state)
new_gradient_state = zeros_like(model_state)
new_model_state = Axon.Updates.apply_updates(model_state, updates, new_state)
{gradients, new_model_state, new_state, new_optimizer_state, new_gradient_state, 0,
Nx.tensor(0)}
else
acc_gradients = deep_merge(gradient_state, gradients, fn x, y -> x + y end)
{gradients, model_state, new_state, optimizer_state, acc_gradients, gradient_step + 1,
Nx.tensor(0)}
end
end
{new_model_state, new_optimizer_state, new_gradient_state, new_gradient_step}
end
defp raise_bad_training_inputs!(data, state) do
raise ArgumentError,
"invalid arguments given to train-step initialization," <>
" this usually happens when you pass a invalid parameters" <>
" to Axon.Loop.run with a loop constructed using Axon.Loop.trainer" <>
" or Axon.Loop.evaluator, supervised training and evaluation loops" <>
" expect a stream or enumerable of inputs" <>
" of the form {x_train, y_train} where x_train and y_train" <>
" are batches of tensors, you must also provide an initial model" <>
" state such as an empty map: Axon.Loop.run(loop, data, %{}), got" <>
" input data: #{inspect(data)} and initial model state: " <>
" #{inspect(state)}"
end
@doc """
Creates a supervised evaluation step from a model and model state.
This function is intended for more fine-grained control over the loop
creation process. It returns a tuple of `{init_fn, step_fn}` where
`init_fn` returns an initial step state and `step_fn` performs a
single evaluation step.
"""
def eval_step(model) do
{_, forward_model_fn} = build_model_fns(model, :inference, [])
init_fn = fn
{inp, tar}, state ->
# TODO: Is this expensive
output = forward_model_fn.(state, inp)
output_type = Nx.type(output)
output_shape = Nx.shape(output)
y_pred = Nx.broadcast(Nx.tensor(0, type: output_type), output_shape)
%{
model_state: state,
y_true: zeros_like(tar),
y_pred: y_pred
}
data, state ->
raise_bad_training_inputs!(data, state)
end
step_fn = fn
{inp, tar}, %{model_state: model_state} ->
%{
model_state: model_state,
y_true: tar,
y_pred: forward_model_fn.(model_state, inp)
}
data, state ->
raise_bad_training_inputs!(data, state)
end
{
Nx.Defn.jit(init_fn, on_conflict: :reuse),
Nx.Defn.jit(step_fn, on_conflict: :reuse)
}
end
## Loop Factories
@doc """
Creates a loop from `step_fn`, an optional `init_fn`, and an
optional `output_transform`.
`step_fn` is an arity-2 function which takes a batch and state
and returns an updated step state:
defn batch_step(batch, step_state) do
step_state + 1
end
`init_fn` by default is an identity function which forwards its
initial arguments as the model state. You should define a custom
initialization function if you require a different behavior:
defn init_step_state(state) do
Map.merge(%{foo: 1}, state)
end
You may use `state` in conjunction with initialization functions in
`init_fn`. For example, `train_step/3` uses initial state as initial
model parameters to allow initializing models from partial parameterizations.
`step_batch/2` and `init_step_state/1` are typically called from
within `Nx.Defn.jit/3`. While JIT-compilation will work with anonymous functions,
`def`, and `defn`, it is recommended that you use the stricter `defn` to define
both functions in order to avoid bugs or cryptic errors.
`output_transform/1` applies a transformation on the final accumulated loop state.
This is useful for extracting specific fields from a loop and piping them into
additional functions.
"""
def loop(step_fn, init_fn \\ &default_init/2, output_transform \\ & &1)
when is_function(step_fn, 2) and is_function(init_fn, 2) and
is_function(output_transform, 1) do
%Loop{
init: init_fn,
step: step_fn,
output_transform: output_transform
}
end
defp default_init(_data, state), do: state
@doc """
Creates a supervised training loop from a model, loss function,
and optimizer.
This function is useful for training models on most standard supervised
learning tasks. It assumes data consists of tuples of input-target pairs,
e.g. `[{x0, y0}, {x1, y1}, ..., {xN, yN}]` where `x0` and `y0` are batched
tensors or containers of batched tensors.
It defines an initialization function which first initializes model state
using the given model and then initializes optimizer state using the initial
model state. The step function uses a differentiable objective function
defined with respect to the model parameters, input data, and target data
using the given loss function. It then updates model parameters using the
given optimizer in order to minimize loss with respect to the model parameters.
`model` must be an Axon struct, a valid defn container
of Axon structs, or a `{init_fn, apply_fn}`-tuple where `init_fn` is
an arity-2 function which initializes the model state and `apply_fn` is
an arity-2 function which applies the forward pass of the model.
`loss` must be an atom which matches a function in `Axon.Losses`, a list
of `{loss, weight}` tuples representing a basic weighted loss function
for multi-output models, or an arity-2 function representing a custom loss
function.
`optimizer` must be an atom matching the name of a valid optimizer in `Axon.Optimizers`,
or a `{init_fn, update_fn}` tuple where `init_fn` is an arity-1 function which
initializes the optimizer state from attached parameters and `update_fn` is an
arity-3 function which scales gradient updates with respect to input parameters,
optimizer state, and gradients. See `Axon.Updates` for more information on building
optimizers.
This function creates a step function which outputs a map consisting of the following
fields for `step_state`:
%{
y_pred: tensor() | container(tensor()), # Model predictions for use in metrics
y_true: tensor() | container(tensor()), # True labels for use in metrics
loss: tensor(), # Running average of loss over epoch
model_state: container(tensor()), # Model parameters and state
optimizer_state: container(tensor()) # Optimizer state associated with each parameter
}
## Examples
### Basic usage
data = Stream.zip(input, target)
model = Axon.input("input", shape: {nil, 32}) |> Axon.dense(1, activation: :sigmoid)
model
|> Axon.Loop.trainer(:binary_cross_entropy, :adam)
|> Axon.Loop.run(data)
### Customizing Optimizer
model
|> Axon.Loop.trainer(:binary_cross_entropy, Axon.Optimizers.adam(0.05))
|> Axon.Loop.run(data)
### Custom loss
loss_fn = fn y_true, y_pred -> Nx.cos(y_true, y_pred) end
model
|> Axon.Loop.trainer(loss_fn, Axon.Optimizers.rmsprop(0.01))
|> Axon.Loop.run(data)
### Multiple objectives with multi-output model
model = {Axon.input("input_0", shape: {nil, 1}), Axon.input("input_1", shape: {nil, 2})}
loss_weights = [mean_squared_error: 0.5, mean_absolute_error: 0.5]
model
|> Axon.Loop.trainer(loss_weights, :sgd)
|> Axon.Loop.run(data)
## Options
* `:log` - training loss and metric log interval. Set to 0 to silence
training logs. Defaults to 50
* `:seed` - seed to use when constructing models. Seed controls random initialization
of model parameters. Defaults to no seed which constructs a random seed for you at
model build time.
* `:loss_scale` - type of loss-scaling to use, if any. Loss-scaling is necessary when
doing mixed precision training for numerical stability. Defaults to `:identity` or
no loss-scaling.
* `:gradient_accumulation_steps` - number of gradient accumulation steps to take during
training. Gradient accumulation decreases the number of updates by accumulating gradients
between steps, increasing the effective batch size on smaller devices. Defaults to 1.
"""
def trainer(model, loss, optimizer, opts \\ []) do
opts = Keyword.validate!(opts, [:seed, :loss_scale, :gradient_accumulation_steps, log: 50])
# Build loss now so we can use it as a metric
loss_fn = build_loss_fn(loss)
step_opts = Keyword.take(opts, [:gradient_accumulation_steps, :loss_cale, :seed])
{init_fn, step_fn} = train_step(model, loss_fn, optimizer, step_opts)
log_interval = opts[:log] || 50
output_transform = fn state -> state.step_state[:model_state] end
loop =
step_fn
|> loop(init_fn, output_transform)
|> metric(loss_fn, "loss")
if log_interval > 0 do
loop
|> log(&supervised_log_message_fn/1,
event: :iteration_completed,
filter: [every: log_interval]
)
|> log(fn _ -> "\n" end, event: :epoch_completed)
else
loop
end
end
defp format_metric({name, val}) do
{type, _} = val.type
unless Nx.size(val) == 1 do
raise ArgumentError,
"metric value is not a scalar, this may happen if you forget" <>
" to specify a reduction such as mean or sum in a metric or" <>
" loss function, if this is a loss function, try adding" <>
" `reduction: :mean` as an option"
end
case type do
t when t in [:s, :u] -> "#{name}: #{Nx.to_number(val)}"
:f -> "#{name}: #{float_format(~c"~.7f", Nx.to_number(val))}"
:bf -> "#{name}: #{float_format(~c"~.3f", Nx.to_number(val))}"
_ -> "#{name}: unsupported type of metric #{inspect(type)}"
end
end
defp float_format(_format, :nan), do: "NaN"
defp float_format(_format, :infinity), do: "Inf"
defp float_format(_format, :neg_infinity), do: "-Inf"
defp float_format(format, val) when is_float(val), do: :io_lib.format(format, [val])
defp supervised_log_message_fn(state, log_epochs \\ true) do
%State{metrics: metrics, epoch: epoch, iteration: iter} = state
metrics =
metrics
|> Enum.map(&format_metric/1)
|> Enum.join(" ")
if log_epochs do
"\rEpoch: #{Nx.to_number(epoch)}, Batch: #{Nx.to_number(iter)}, #{metrics}"
else
"\rBatch: #{Nx.to_number(iter)}, #{metrics}"
end
end
@doc """
Creates a supervised evaluator from a model.
An evaluator can be used for things such as testing and validation of models
after or during training. It assumes `model` is an Axon struct, container of
structs, or a tuple of `init` / `apply` functions. `model_state` must be a
container usable from within `model`.
The evaluator returns a step state of the form:
%{
y_true: labels,
y_pred: predictions
}
Such that you can attach any number of supervised metrics to the evaluation
loop:
model
|> Axon.Loop.evaluator()
|> Axon.Loop.metric("Accuracy", :accuracy)
You must pass a compatible trained model state to `Axon.Loop.run/4` when using
supervised evaluation loops. For example, if you've binded the result of a training
run to `trained_model_state`, you can run the trained model through an evaluation
run like this:
model
|> Axon.Loop.evaluator()
|> Axon.Loop.run(data, trained_model_state, compiler: EXLA)
This function applies an output transform which returns the map of metrics accumulated
over the given loop.
"""
def evaluator(model) do
{init_fn, step_fn} = eval_step(model)
output_transform = fn state -> state.metrics end
loop(step_fn, init_fn, output_transform)
|> log(&supervised_log_message_fn(&1, false), event: :iteration_completed)
end
@doc """
Adds a metric of the given name to the loop.
A metric is a function which tracks or measures some value with respect
to values in the step state. For example, when training classification
models, it's common to track the model's accuracy during training:
loop
|> Axon.Loop.metric(:accuracy, "Accuracy")
By default, metrics assume a supervised learning task and extract the fields
`[:y_true, :y_pred]` from the step state. If you wish to work on a different
value, you can use an output transform. An output transform is a list of keys
to extract from the output state, or a function which returns a flattened list
of values to pass to the given metric function. Values received from output
transforms are passed to the given metric using:
value = output_transform.(step_state)
apply(metric, value)
Thus, even if you want your metric to work on a container, your output transform
must return a list.
`metric` must be an atom which matches the name of a metric in `Axon.Metrics`, or
an arbitrary function which returns a tensor or container.
`name` must be a string or atom used to store the computed metric in the loop
state. If names conflict, the last attached metric will take precedence:
loop
|> Axon.Loop.metric(:mean_squared_error, "Error") # Will be overwritten
|> Axon.Loop.metric(:mean_absolute_error, "Error") # Will be used
By default, metrics keep a running average of the metric calculation. You can
override this behavior by changing `accumulate`:
loop
|> Axon.Loop.metric(:true_negatives, "tn", :running_sum)
Accumulation function can be one of the accumulation combinators in Axon.Metrics
or an arity-3 function of the form: `accumulate(acc, obs, i) :: new_acc`.
"""
def metric(
%Loop{metrics: metric_fns} = loop,
metric,
name \\ nil,
accumulate \\ :running_average,
transform_or_fields \\ [:y_true, :y_pred]
) do
name =
case name do
nil ->
if is_atom(metric) do
Atom.to_string(metric)
else
raise ArgumentError, "must provide name if using a custom metric"
end
name ->
name
end
case metric_fns do
%{^name => _} ->
Logger.warning(
"Metric #{name} declared twice in loop. Original metric will be overridden."
)
_ ->
:ok
end
metric_fn = build_metric_fn(metric, accumulate, transform_or_fields)
# For internal use we keep the raw metric as well as the compiled metric
# function
%Loop{loop | metrics: Map.put(metric_fns, name, {metric_fn, metric})}
end
@doc """
Adds a handler function to the loop which will be triggered on `event`
with an optional filter.
Events take place at different points during loop execution. The default
events are:
events = [
:started, # After loop state initialization
:epoch_started, # On epoch start
:iteration_started, # On iteration start
:iteration_completed, # On iteration complete
:epoch_completed, # On epoch complete
:epoch_halted, # On epoch halt, if early halted
]
Generally, event handlers are side-effecting operations which provide some
sort of inspection into the loop's progress. It's important to note that
if you define multiple handlers to be triggered on the same event, they
will execute in order from when they were attached to the training
loop:
loop
|> Axon.Loop.handle_event(:epoch_started, &normalize_step_state/1) # executes first
|> Axon.Loop.handle_event(:epoch_started, &log_step_state/1) # executes second
Thus, if you have separate handlers which alter or depend on loop state,
you need to ensure they are ordered correctly, or combined into a single
event handler for maximum control over execution.
`event` must be an atom representing the event to trigger `handler` or a
list of atoms indicating `handler` should be triggered on multiple events.
`event` may be `:all` which indicates the handler should be triggered on
every event during loop processing.
`handler` must be an arity-1 function which takes as input loop state and
returns `{status, state}`, where `status` is an atom with one of the following
values:
:continue # Continue epoch, continue looping
:halt_epoch # Halt the epoch, continue looping
:halt_loop # Halt looping
`filter` is an atom representing a valid filter predicate, a keyword of
predicate-value pairs, or a function which takes loop state and returns
a `true`, indicating the handler should run, or `false`, indicating the
handler should not run. Valid predicates are:
:always # Always trigger event
:once # Trigger on first event firing
Valid predicate-value pairs are:
every: N # Trigger every `N` event
only: N # Trigger on `N` event
**Warning: If you modify the step state in an event handler, it will trigger
potentially excessive recompilation and result in significant additinal overhead
during loop execution.**
"""
def handle_event(%Loop{handlers: handle_fns} = loop, event, handler, filter \\ :always) do
filter = build_filter_fn(filter)
handle_fns =
case event do
[_ | _] = events ->
Enum.reduce(events, handle_fns, &add_event_handler(&1, &2, {handler, filter}))
:all ->
Enum.reduce(@default_events, handle_fns, &add_event_handler(&1, &2, {handler, filter}))
event when is_atom(event) ->
add_event_handler(event, handle_fns, {handler, filter})
end
%Loop{loop | handlers: handle_fns}
end
@doc false
@deprecated "handle/4 is deprecated, use handle_event/4 instead"
def handle(%Loop{} = loop, event, handler, filter \\ :always) do
handle_event(loop, event, handler, filter)
end
@doc """
Adds a handler function which logs the given message produced
by `message_fn` to the given IO device every `event` satisfying
`filter`.
In most cases, this is useful for inspecting the contents of
the loop state at intermediate stages. For example, the default
`trainer` loop factory attaches IO logging of epoch, batch, loss
and metrics.
It's also possible to log loop state to files by changing the
given IO device. By default, the IO device is `:stdio`.
`message_fn` should take the loop state and return a binary
representing the message to be written to the IO device.
"""
def log(%Loop{} = loop, message_fn, opts \\ []) when is_function(message_fn, 1) do
opts = Keyword.validate!(opts, event: :iteration_completed, filter: :always, device: :stdio)
event = opts[:event] || :iteration_completed
filter = opts[:filter] || :always
device = opts[:device] || :stdio
log_fn = fn %State{} = state ->
try do
msg = message_fn.(state)
IO.write(device, msg)
{:continue, state}
rescue
error ->
Logger.error(
"Error on Axon.Loop.log/5 callback: " <>
Exception.format(:error, error, __STACKTRACE__)
)
{:halt_loop, state}
end
end
handle_event(loop, event, log_fn, filter)
end
@doc """
Adds a handler function which tests the performance of `model`
against the given validation set.
This handler assumes the loop state matches the state initialized
in a supervised training loop. Typically, you'd call this immediately
after creating a supervised training loop:
model
|> Axon.Loop.trainer(:mean_squared_error, :sgd)
|> Axon.Loop.validate(model, validation_data)
Please note that you must pass the same (or an equivalent) model
into this method so it can be used during the validation loop. The
metrics which are computed are those which are present BEFORE the
validation handler was added to the loop. For the following loop:
model
|> Axon.Loop.trainer(:mean_squared_error, :sgd)
|> Axon.Loop.metric(:mean_absolute_error)
|> Axon.Loop.validate(model, validation_data)
|> Axon.Loop.metric(:binary_cross_entropy)
only `:mean_absolute_error` will be computed at validation time.
The returned loop state is altered to contain validation
metrics for use in later handlers such as early stopping and model
checkpoints. Since the order of execution of event handlers is in
the same order they are declared in the training loop, you MUST call
this method before any other handler which expects or may use
validation metrics.
By default the validation loop runs after every epoch; however, you
can customize it by overriding the default event and event filters:
model
|> Axon.Loop.trainer(:mean_squared_error, :sgd)
|> Axon.Loop.metric(:mean_absolute_error)
|> Axon.Loop.validate(model, validation_data, event: :iteration_completed, filter: [every: 10_000])
|> Axon.Loop.metric(:binary_cross_entropy)
"""
def validate(
%Loop{metrics: metric_fns} = loop,
model,
validation_data,
opts \\ []
) do
opts = Keyword.validate!(opts, event: :epoch_completed, filter: :always)
event = opts[:event] || :epoch_completed
filter = opts[:filter] || :always
evaluator = evaluator(model)
validation_loop = fn %State{metrics: metrics, step_state: step_state} = state ->
%{model_state: model_state} = step_state
metrics =
Enum.reduce(metric_fns, evaluator, fn {k, {_, v}}, loop -> metric(loop, v, k) end)
|> run(validation_data, model_state)
|> Access.get(0)
|> Map.new(fn {k, v} ->
{"validation_#{k}", v}
end)
|> Map.merge(metrics, fn _, _, v -> v end)
{:continue, %{state | metrics: metrics}}
end
handle_event(loop, event, validation_loop, filter)
end
@doc """
Adds a handler function which monitors the given metric
and fires some action when the given metric meets some
criteria.
This function is a generalization of handlers such as
`Axon.Loop.reduce_lr_on_plateau/3` and `Axon.Loop.early_stop/3`.
You must specify a metric to monitor that is present in
the state metrics. This handler will then monitor the value
of the metric at the specified intervals and fire the specified
function if the criteria is met.
You must also specify a name for the monitor attached to the
given metric. This will be used to store metadata associated
with the monitor.
The common case of monitor is to track improvement of metrics
and take action if metrics haven't improved after a certain number
of events. However, you can also set a monitor up to trigger if
a metric hits some criteria (such as a threshold) by passing a
custom monitoring mode.
## Options
* `:event` - event to fire handler on. Defaults to `:epoch_completed`.
* `:filter` - event filter to attach to handler. Defaults to `:always`.
* `:patience` - number of given events to wait for improvement. Defaults
to `3`.
* `:mode` - whether given metric is being minimized or maximized. One of
`:min`, `:max` or an arity-1 function which returns `true` or `false`.
Defaults to `:min`.
"""
def monitor(%Loop{} = loop, metric, fun, name, opts \\ []) do
opts =
Keyword.validate!(opts, event: :epoch_completed, filter: :always, mode: :max, patience: 3)
event = opts[:event] || :epoch_completed
filter = opts[:filter] || :always
mode = opts[:mode] || :min
patience = opts[:patience] || 3
handle_event(loop, event, &monitor_impl(&1, metric, fun, name, mode, patience), filter)
end
defp monitor_impl(
%State{metrics: metrics, handler_metadata: handler_meta} = state,
monitor,
fun,
name,
mode,
patience
) do
unless Map.has_key?(metrics, monitor) do
raise ArgumentError,
"invalid metric to monitor, key #{inspect(monitor)} not present in metrics"
end
cur_criteria_value = metrics[monitor]
{prev_criteria_value, since_last_improvement} =
case handler_meta[name] do
nil ->
{nil, 0}
meta ->
{meta[monitor], meta[:since_last_improvement]}
end
improved? =
case mode do
:min ->
prev_criteria_value == nil or
Nx.to_number(Nx.less(cur_criteria_value, prev_criteria_value)) == 1
:max ->
prev_criteria_value == nil or
Nx.to_number(Nx.greater(cur_criteria_value, prev_criteria_value)) == 1
fun when is_function(fun, 1) ->
fun.(cur_criteria_value)
end
over_patience? = since_last_improvement >= patience
cond do
improved? ->
default = %{monitor => cur_criteria_value, :since_last_improvement => 0}
updated_handler_meta =
Map.update(handler_meta, name, default, fn meta ->
meta
|> Map.update(monitor, cur_criteria_value, fn _ -> cur_criteria_value end)
|> Map.update(:since_last_improvement, 0, fn _ -> 0 end)
end)
{:continue, %{state | handler_metadata: updated_handler_meta}}
not improved? and not over_patience? ->
default = %{monitor => prev_criteria_value, :since_last_improvement => 0}
updated_handler_meta =
Map.update(handler_meta, name, default, fn meta ->
Map.update(meta, :since_last_improvement, 0, fn x -> x + 1 end)
end)
{:continue, %{state | handler_metadata: updated_handler_meta}}
true ->
{status, state} = fun.(state)
default = %{monitor => cur_criteria_value, :since_last_improvement => 0}
updated_handler_meta = Map.put(handler_meta, name, default)
{status, %{state | handler_metadata: updated_handler_meta}}
end
end
@doc """
Adds a handler function which saves loop checkpoints on a given
event, optionally with metric-based criteria.
By default, loop checkpoints will be saved at the end of every
epoch in the current working directory under the `checkpoint/`
path. Checkpoints are serialized representations of loop state
obtained from `Axon.Loop.serialize_state/2`. Serialization
options will be forwarded to `Axon.Loop.serialize_state/2`.
You can customize checkpoint events by passing `:event` and `:filter`
options:
loop
|> Axon.Loop.checkpoint(event: :iteration_completed, filter: [every: 50])
Checkpoints are saved under the `checkpoint/` directory with a pattern
of `checkpoint_{epoch}.ckpt`. You can customize the path and pattern
with the `:path` and `:file_pattern` options:
my_file_pattern =
fn %Axon.Loop.State{epoch: epoch, iteration: iter} ->
"checkpoint_\#{epoch}_\#{iter}"
end
loop
|> Axon.Loop.checkpoint(path: "my_checkpoints", file_pattern: my_file_pattern)
If you'd like to only save checkpoints based on some metric criteria,
you can specify the `:criteria` option. `:criteria` must be a valid key
in metrics:
loop
|> Axon.Loop.checkpoint(criteria: "validation_loss")
The default criteria mode is `:min`, meaning the min score metric will
be considered "best" when deciding to save on a given event. Valid modes
are `:min` and `:max`:
loop
|> Axon.Loop.checkpoint(criteria: "validation_accuracy", mode: :max)
## Options
* `:event` - event to fire handler on. Defaults to `:epoch_completed`.
* `:filter` - event filter to attach to handler. Defaults to `:always`.
* `:patience` - number of given events to wait for improvement. Defaults
to `3`.
* `:mode` - whether given metric is being minimized or maximized. One of
`:min`, `:max` or an arity-1 function which returns `true` or `false`.
Defaults to `:min`.
* `:path` - path to directory to save checkpoints. Defaults to `checkpoint`
* `:file_pattern` - arity-1 function which returns a string file pattern
based on the current loop state. Defaults to saving checkpoints to files
`checkpoint_\#{epoch}_\#{iteration}.ckpt`.
"""
def checkpoint(%Loop{} = loop, opts \\ []) do
{event, opts} = Keyword.pop(opts, :event, :epoch_completed)
{filter, opts} = Keyword.pop(opts, :filter, :always)
{path, opts} = Keyword.pop(opts, :path, "checkpoint")
{file_pattern, opts} = Keyword.pop(opts, :file_pattern, &default_checkpoint_file/1)
{criteria, opts} = Keyword.pop(opts, :criteria)
{mode, serialize_opts} = Keyword.pop(opts, :mode, :min)
checkpoint_fun = &checkpoint_impl(&1, path, file_pattern, serialize_opts)
if criteria do
monitor(loop, criteria, checkpoint_fun, :checkpoint,
mode: mode,
event: event,
filter: filter
)
else
handle_event(loop, event, checkpoint_fun, filter)
end
end
defp default_checkpoint_file(%State{epoch: epoch, iteration: step}),
do: "checkpoint_#{epoch}_#{step}.ckpt"
defp checkpoint_impl(%State{} = state, path, file_pattern, serialize_opts) do
serialized_state = serialize_state(state, serialize_opts)
filename = Path.join([path, file_pattern.(state)])
dirname = Path.dirname(filename)
File.mkdir_p!(dirname)
File.write!(filename, serialized_state)
{:continue, state}
end
@doc """
Adds a handler function which halts a loop if the given
metric does not improve between events.
By default, this will run after each epoch and track the
improvement of a given metric.
You must specify a metric to monitor and the metric must
be present in the loop state. Typically, this will be
a validation metric:
model
|> Axon.Loop.trainer(loss, optim)
|> Axon.Loop.metric(:accuracy)
|> Axon.Loop.validate(val_data)
|> Axon.Loop.early_stop("validation_accuracy")
It's important to remember that handlers are executed in the
order they are added to the loop. For example, if you'd like
to checkpoint a loop after every epoch and use early stopping,
most likely you want to add the checkpoint handler before
the early stopping handler:
model
|> Axon.Loop.trainer(loss, optim)
|> Axon.Loop.metric(:accuracy)
|> Axon.Loop.checkpoint()
|> Axon.Loop.early_stop("accuracy")
That will ensure checkpoint is always fired, even if the loop
exited early.
"""
def early_stop(%Loop{} = loop, monitor, opts \\ []) do
event = opts[:event] || :epoch_completed
filter = opts[:filter] || :always
patience = opts[:patience] || 3
mode = opts[:mode] || :min
early_stop_fn = fn state -> {:halt_loop, state} end
monitor(loop, monitor, early_stop_fn, :early_stop,
event: event,
filter: filter,
patience: patience,
mode: mode
)
end
@doc """
Adds a handler function which reduces the learning rate by
the given factor if the given metric does not improve between
events.
By default, this will run after each epoch and track the
improvement of a given metric.
You must specify a metric to monitor and the metric must
be present in the loop state. Typically, this will be
a validation metric:
model
|> Axon.Loop.trainer(loss, optim)
|> Axon.Loop.metric(:accuracy)
|> Axon.Loop.validate(model, val_data)
|> Axon.Loop.reduce_lr_on_plateau("accuracy", mode: :max)
## Options
* `:event` - event to fire handler on. Defaults to `:epoch_completed`.
* `:filter` - event filter to attach to handler. Defaults to `:always`.
* `:patience` - number of given events to wait for improvement. Defaults
to `3`.
* `:mode` - whether given metric is being minimized or maximized. Defaults
to `:min`.
* `:factor` - factor to decrease learning rate by. Defaults to `0.1`.
"""
def reduce_lr_on_plateau(%Loop{} = loop, monitor, opts \\ []) do
event = opts[:event] || :epoch_completed
filter = opts[:filter] || :always
patience = opts[:patience] || 3
mode = opts[:mode] || :min
factor = opts[:factor] || 0.1
reduce_lr_fn = fn %State{step_state: step_state} = state ->
unless Map.has_key?(step_state, :optimizer_state) do
raise ArgumentError,
"given loop state is not a supervised training loop, key `:optimizer_state`" <>
" was not present in the given step state"
end
# TODO: This is a strong assumption
%{scale: current_lr} = elem(step_state[:optimizer_state], 0)
updated_lr = Nx.multiply(current_lr, factor)
updated_optimizer_state = put_elem(step_state[:optimizer_state], 0, %{scale: updated_lr})
updated_step_state = %{step_state | optimizer_state: updated_optimizer_state}
{:continue, %{state | step_state: updated_step_state}}
end
monitor(loop, monitor, reduce_lr_fn, :reduce_lr,
event: event,
filter: filter,
mode: mode,
patience: patience
)
end
@compile {:no_warn_undefined, Kino.VegaLite}
@doc """
Adds a handler function which updates a `Kino.VegaLite` plot.
By default, this will run after every iteration.
You must specify a plot to push to and a metric to track. The `:x` axis will be the iteration count, labeled `"step"`. The metric must match the name given to the `:y` axis in your `VegaLite` plot:
plot =
Vl.new()
|> Vl.mark(:line)
|> Vl.encode_field(:x, "step", type: :quantitative)
|> Vl.encode_field(:y, "loss", type: :quantitative)
|> Kino.VegaLite.new()
|> Kino.render()
model
|> Axon.Loop.trainer(loss, optim)
|> Axon.Loop.kino_vega_lite_plot(plot, "loss")
## Options
* `:event` - event to fire handler on. Defaults to `:iteration_completed`.
* `:filter` - event filter to attach to handler. Defaults to `:always`.
"""
def kino_vega_lite_plot(loop, plot, metric, opts \\ []) do
assert_kino_vega_lite!("plot/5")
opts = Keyword.validate!(opts, event: :iteration_completed, filter: :always)
handle_event(
loop,
opts[:event],
fn %{
metrics: metrics,
handler_metadata: handler_meta
} = state ->
unless Map.has_key?(metrics, metric) do
raise ArgumentError,
"invalid metric to plot, key #{inspect(metric)} not present in metrics"
end
{iteration, handler_meta} = absolute_iteration(handler_meta)
Kino.VegaLite.push(plot, %{
"step" => iteration,
metric => Nx.to_number(metrics[metric])
})
{:continue, %{state | handler_metadata: handler_meta}}
end,
opts[:filter]
)
end
defp absolute_iteration(
%{"plot" => %{"absolute_iteration" => absolute_iteration}} = handler_meta
),
do:
{absolute_iteration,
put_in(handler_meta, ["plot", "absolute_iteration"], absolute_iteration + 1)}
defp absolute_iteration(handler_meta),
do: {0, Map.put(handler_meta, "plot", %{"absolute_iteration" => 1})}
defp assert_kino_vega_lite!(fn_name) do
unless Code.ensure_loaded?(Kino.VegaLite) do
raise RuntimeError, """
#{fn_name} depends on the :kino_vega_lite package.
You can install it by adding
{:kino_vega_lite, "~> 0.1.7"}
to your dependency list.
"""
end
end
@doc """
Attaches `state` to the given loop in order to resume looping
from a previous state.
It's important to note that a loop's attached state takes precedence
over defined initialization functions. Given initialization function:
defn init_state(), do: %{foo: 1, bar: 2}
And an attached state:
state = %State{step_state: %{foo: 2, bar: 3}}
`init_state/0` will never execute, and instead the initial step state
of `%{foo: 2, bar: 3}` will be used.
"""
def from_state(%Loop{} = loop, %State{} = state) do
%{loop | attached_state: state}
end
@doc """
Serializes loop state to a binary for saving and loading
loop from previous states.
You can consider the serialized state to be a checkpoint of
all state at a given iteration and epoch.
By default, the step state is serialized using `Nx.serialize/2`;
however, this behavior can be changed if step state is an application
specific container. For example, if you introduce your own data
structure into step_state, `Nx.serialize/2` will not be sufficient
for serialization - you must pass custom serialization as an option
with `:serialize_step_state`.
Additional `opts` controls serialization options such as compression.
It is forwarded to `:erlang.term_to_binary/2`.
"""
def serialize_state(%State{} = state, opts \\ []) do
{serialize_step_state_fn, opts} = Keyword.pop(opts, :serialize_step_state, &Nx.serialize/2)
serialized_step_state = serialize_step_state_fn.(state.step_state, opts)
serialized_metrics = Nx.serialize(state.metrics, opts)
state_map = Map.from_struct(state)
state_map = %{state_map | step_state: serialized_step_state, metrics: serialized_metrics}
:erlang.term_to_binary({@file_version, state_map}, opts)
end
@doc """
Deserializes loop state from a binary.
It is the opposite of `Axon.Loop.serialize_state/2`.
By default, the step state is deserialized using `Nx.deserialize.2`;
however, this behavior can be changed if step state is an application
specific container. For example, if you introduce your own data
structure into step_state and you customized the serialization logic,
`Nx.deserialize/2` will not be sufficient for deserialization. - you
must pass custom logic with `:deserialize_step_state`.
"""
def deserialize_state(serialized, opts \\ []) do
{deserialize_step_state_fn, opts} =
Keyword.pop(opts, :deserialize_step_state, &Nx.deserialize/2)
{1, state_map} = :erlang.binary_to_term(serialized, [:safe | opts])
step_state = deserialize_step_state_fn.(state_map.step_state, opts)
metrics = Nx.deserialize(state_map.metrics, opts)
state_map = %{state_map | step_state: step_state, metrics: metrics}
struct!(Axon.Loop.State, state_map)
end
@doc """
Runs the given loop on data with the given options.
`loop` must be a valid Axon.Loop struct built from one of the
loop factories provided in this module.
`data` must be an Enumerable or Stream which yields batches of
data on each iteration.
## Options
* `:epochs` - max epochs to run loop for. Must be non-negative integer.
Defaults to `1`.
* `:iterations` - max iterations to run each epoch. Must be non-negative
integer. Defaults to `-1` or no max iterations.
* `:jit_compile?` - whether or not to JIT compile initialization and step
functions. JIT compilation must be used for gradient computations. Defaults
to true.
* `:strict?` - whether or not to compile step functions strictly. If this flag
is set, the loop will raise on any cache miss during the training loop. Defaults
to true.
* `:debug` - run loop in debug mode to trace loop progress. Defaults to
false.
Additional options are forwarded to `Nx.Defn.jit` as JIT-options. If no JIT
options are set, the default options set with `Nx.Defn.default_options` are
used.
"""
def run(loop, data, init_state \\ %{}, opts \\ []) do
{max_epochs, opts} = Keyword.pop(opts, :epochs, 1)
{max_iterations, opts} = Keyword.pop(opts, :iterations, -1)
{jit_compile?, opts} = Keyword.pop(opts, :jit_compile?, true)
{strict?, jit_opts} = Keyword.pop(opts, :strict?, true)
debug? = Keyword.get(jit_opts, :debug, false)
if jit_opts != [] do
Logger.debug("Forwarding options: #{inspect(jit_opts)} to JIT compiler")
end
%Loop{
init: init_fn,
step: step_fn,
handlers: handler_fns,
metrics: metric_fns,
attached_state: attached_state,
output_transform: output_transform
} = loop
sample_data =
case Enum.take(data, 1) do
[sample_data | _] ->
sample_data
[] ->
raise ArgumentError,
"Axon.Loop.run received empty dataset, this can happen" <>
" if you've built a stream and accidentally filtered" <>
" out every value, your dataset must have at least one" <>
" entry"
end
if debug? do
Logger.debug("Axon.Loop started initializing loop state")
end
{time, loop_state} =
:timer.tc(fn ->
init_loop_state(
init_fn,
sample_data,
init_state,
attached_state,
max_epochs,
max_iterations,
jit_compile?,
jit_opts
)
end)
epoch_start = loop_state.epoch
epoch_end = max_epochs + epoch_start - 1
if debug? do
Logger.debug("Axon.Loop finished initializing loop state in #{us_to_ms(time)}ms")
end
# TODO: Can we infer here?
zero_metrics = Map.new(metric_fns, fn {k, _} -> {k, Nx.tensor(0, type: :f32)} end)
final_metrics_map =
epoch_start..epoch_end
|> Map.new(&{&1, zero_metrics})
|> Map.merge(loop_state.metrics)
loop_state = %{loop_state | metrics: zero_metrics}
{status, final_metrics, state} =
case fire_event(:started, handler_fns, loop_state, debug?) do
{:halt_epoch, state} ->
{:halted, final_metrics_map, state}
{:halt_loop, state} ->
{:halted, final_metrics_map, state}
{:continue, state} ->
batch_fn =
{:non_compiled, build_batch_fn(step_fn, metric_fns), jit_compile?, strict?, jit_opts}
epoch_start..epoch_end//1
|> Enum.reduce_while(
{batch_fn, final_metrics_map, state},
fn epoch, {batch_fn, final_metrics_map, loop_state} ->
case fire_event(:epoch_started, handler_fns, loop_state, debug?) do
{:halt_epoch, state} ->
halt_epoch(handler_fns, batch_fn, final_metrics_map, state, debug?)
{:halt_loop, state} ->
{:halt, {final_metrics_map, state}}
{:continue, state} ->
if debug? do
Logger.debug("Axon.Loop started running epoch #{epoch}")
end
{time, status_batch_fn_and_state} =
:timer.tc(&run_epoch/5, [batch_fn, handler_fns, state, data, debug?])
if debug? do
Logger.debug("Axon.Loop finished running epoch in #{us_to_ms(time)} ms")
end
case status_batch_fn_and_state do
{:halt_epoch, batch_fn, state} ->
halt_epoch(handler_fns, batch_fn, final_metrics_map, state, debug?)
{:halt_loop, _, state} ->
{:halt, {final_metrics_map, state}}
{:continue, batch_fn, state} ->
new_loop_state = put_in(state.times[epoch], time)
case fire_event(:epoch_completed, handler_fns, new_loop_state, debug?) do
{:halt_epoch, state} ->
halt_epoch(handler_fns, batch_fn, final_metrics_map, state, debug?)
{:halt_loop, state} ->
{:halt, {final_metrics_map, state}}
{:continue, state} ->
{:cont,
{batch_fn, %{final_metrics_map | epoch => state.metrics},
%State{
state
| epoch: epoch + 1,
metrics: zero_metrics,
iteration: 0,
max_iteration: state.max_iteration
}}}
end
end
end
end
)
|> case do
{final_metrics_map, state} -> {:halted, final_metrics_map, state}
{_batch_fn, final_metrics_map, state} -> {:completed, final_metrics_map, state}
end
end
state = %State{state | metrics: final_metrics, status: status}
output_transform.(state)
end
## Helpers
defp init_loop_state(
init_fn,
sample_data,
init_state,
attached_state,
max_epochs,
max_iterations,
jit_compile?,
jit_opts
) do
case attached_state do
%State{} = state ->
%{state | max_epoch: max_epochs + state.epoch}
nil ->
step_state = maybe_jit(init_fn, [sample_data, init_state], jit_compile?, jit_opts)
%State{
epoch: 0,
max_epoch: max_epochs,
iteration: 0,
max_iteration: max_iterations,
step_state: step_state,
metrics: %{},
times: %{}
}
end
end
defp run_epoch(batch_fn, handler_fns, loop_state, data, debug?) do
Enum.reduce_while(data, {:continue, batch_fn, loop_state}, fn data, {_, batch_fn, state} ->
case fire_event(:iteration_started, handler_fns, state, debug?) do
{:halt_epoch, state} ->
{:halt, {:halt_epoch, batch_fn, state}}
{:halt_loop, state} ->
{:halt, {:halt_loop, batch_fn, state}}
{:continue, state} ->
%State{
iteration: iters,
max_iteration: max_iters,
step_state: step_state,
metrics: metrics
} = state
batch_fn =
case batch_fn do
{:non_compiled, batch_fn, jit_compile?, strict?, jit_opts} ->
cond do
jit_compile? and strict? ->
Nx.Defn.compile(batch_fn, [data, iters, step_state, metrics], jit_opts)
jit_compile? ->
Nx.Defn.jit(batch_fn, jit_opts)
true ->
batch_fn
end
{:compiled, batch_fn} ->
batch_fn
end
if debug? do
Logger.debug("Axon.Loop started batch step execution")
end
{time, {new_step_state, new_metrics}} =
:timer.tc(fn -> batch_fn.(data, iters, step_state, metrics) end)
if debug? do
Logger.debug("Axon.Loop finished batch step execution in #{us_to_ms(time)}ms")
end
batch_fn = {:compiled, batch_fn}
state = %{state | step_state: new_step_state, metrics: new_metrics}
case fire_event(:iteration_completed, handler_fns, state, debug?) do
{:halt_epoch, state} ->
{:halt, {:halt_epoch, batch_fn, state}}
{:halt_loop, state} ->
{:halt, {:halt_loop, batch_fn, state}}
{:continue, state} ->
state = %{state | iteration: iters + 1}
if max_iterations_reached?(max_iters, iters) do
{:halt, {:continue, batch_fn, state}}
else
{:cont, {:continue, batch_fn, state}}
end
end
end
end)
end
defp max_iterations_reached?(max_iters, iters) do
iters >= max_iters - 1 and max_iters > 0
end
# Adds an event handler to the map of handler funs by prepending handler
# to the existing handler funs. Because we prepend here, we must reverse
# handler funs in fire_event.
# TODO(seanmor5): Custom events
defp add_event_handler(event, handle_fns, handler) do
Map.update!(handle_fns, event, fn event_funs -> [handler | event_funs] end)
end
# Fires event `event` using handler_fns associated with the event. We
# must reverse handler funs in order to enforce order that handlers are
# attached to the loop.
# TODO(seanmor5): Custom events
defp fire_event(event, handler_fns, state, debug?) do
handler_fns[event]
|> Enum.reverse()
|> Enum.reduce_while({:continue, state}, fn {handler, filter}, {_, state} ->
if debug? do
Logger.debug("Axon.Loop fired event #{inspect(event)}")
end
state = update_counts(state, event)
if filter.(state, event) do
case handler.(state) do
{:continue, %State{} = state} ->
if debug? do
Logger.debug("Axon.Loop handled event #{inspect(event)} with status :continue")
end
{:cont, {:continue, state}}
{:halt_epoch, %State{} = state} ->
if debug? do
Logger.debug("Axon.Loop handled event #{inspect(event)} with status :halt_epoch")
end
{:halt, {:halt_epoch, state}}
{:halt_loop, %State{} = state} ->
if debug? do
Logger.debug("Axon.Loop handled event #{inspect(event)} with status :halt_loop")
end
{:halt, {:halt_loop, state}}
invalid ->
raise ArgumentError,
"invalid value #{inspect(invalid)} returned from event handler" <>
" triggered on #{inspect(event)}, event handler must return" <>
" a tuple of {status, state} where status is one of :halt_epoch," <>
" :halt_loop, or :continue and state is an updated State struct"
end
else
if debug? do
Logger.debug("Axon.Loop no handlers fired for event #{inspect(event)}")
end
{:cont, {:continue, state}}
end
end)
end
defp update_counts(%State{event_counts: event_counts} = state, event) do
%{state | event_counts: Map.update(event_counts, event, 1, fn x -> x + 1 end)}
end
# Halts an epoch during looping
defp halt_epoch(handler_fns, batch_fn, final_metrics_map, loop_state, debug?) do
case fire_event(:epoch_halted, handler_fns, loop_state, debug?) do
{:halt_epoch, state} ->
{:cont,
{batch_fn, final_metrics_map, %State{state | epoch: state.epoch + 1, iteration: 0}}}
{:halt_loop, state} ->
{:halt, {final_metrics_map, state}}
{:continue, state} ->
{:cont, {batch_fn, final_metrics_map, state}}
end
end
# Builds the overall batch step function from the given
# step function and metrics. We need to run both step and metric
# functions from within here to ensure they can be JIT compiled
# if that's desired
defp build_batch_fn(step_fn, metric_fns) do
fn data, iter, pstate, metrics ->
new_step_state = step_fn.(data, pstate)
new_metrics =
metrics
|> Enum.zip_with(metric_fns, fn {k, avg}, {k, {v, _}} ->
# In some instances the metric is actually present in the
# step state e.g. in a supervised training loop when we
# are computing loss but it's already computed as a part
# of the step state, so we need to check here
metric = String.to_atom(k)
case pstate do
%{^metric => value} ->
{k, value}
%{} ->
{k, v.(avg, List.wrap(new_step_state), iter)}
end
end)
|> Map.new()
{new_step_state, new_metrics}
end
end
# Builds a loss function from an atom, function, or list of. Valid loss
# functions must be one of an atom matching the name of a function in
# Axon.Losses, an arity-2 function of the form loss(y_true, y_pred),
# or a list of 2-tuples of {loss, weight} for constructing a simple
# joint, multi-objective loss function.
# TODO(seanmor5): Configurable per-batch reductions
# TODO(seanmor5): Configurable multi-objective reductions
# TODO(seanmor5): Should we trace custom loss functions and provide a
# more clear error if the output shape is wrong?
defp build_loss_fn(loss) do
case loss do
loss_name when is_atom(loss_name) and loss_name in @valid_axon_losses ->
&apply(Axon.Losses, loss_name, [&1, &2, [reduction: :mean]])
loss_fn when is_function(loss, 2) ->
loss_fn
[{_, _} | _] = losses ->
fn y_true, y_pred ->
{_, loss} =
Enum.reduce(losses, {0, Nx.tensor(0)}, fn {loss, weight}, {i, acc_loss} ->
loss_fn = build_loss_fn(loss)
y_true_i = elem(y_true, i)
y_pred_i = elem(y_pred, i)
new_acc_loss =
y_true_i
|> loss_fn.(y_pred_i)
|> Nx.multiply(weight)
|> Nx.add(acc_loss)
{i + 1, new_acc_loss}
end)
loss
end
invalid ->
raise ArgumentError,
"Invalid loss function #{inspect(invalid)}, a valid loss" <>
" function is an atom which matches a function in Axon.Losses," <>
" an arity-2 function of the form loss(y_true, y_pred), or a list" <>
" of 2-tuples of {loss, weight} for multi-objective models"
end
end
# Builds model init and forward functions from an Axon struct,
# a tuple of Axon structs, or a tuple of init / forward
# functions. Model functions are essentially just model
# init / apply functions.
defp build_model_fns(%Axon{} = model, mode, opts) do
Axon.build(model, [mode: mode] ++ opts)
end
defp build_model_fns({init_fn, forward_fn}, _, _opts)
when is_function(init_fn, 2) and is_function(forward_fn, 2) do
{init_fn, forward_fn}
end
defp build_model_fns(invalid, _, _) do
raise ArgumentError,
"Invalid model #{inspect(invalid)}, a valid model" <>
" is an Axon struct or a tuple of {init_fn, forward_fn} with signatures" <>
" init_fn() :: model_state, forward_fn(model_state, inp) :: prediction"
end
# Builds optimizer init and update functions either from an atom
# or a tuple of init / update functions. The init and update functions
# match the signatures of those defined in Axon.Updates. If the
# optimizer is an atom, it must match the name of a function in
# Axon.Optimizers.
defp build_optimizer_fns(optimizer)
when is_atom(optimizer) and optimizer in @valid_axon_optimizers do
apply(Axon.Optimizers, optimizer, [])
end
defp build_optimizer_fns({init_optimizer_fn, update_optimizer_fn})
when is_function(init_optimizer_fn, 1) and is_function(update_optimizer_fn, 3) do
{init_optimizer_fn, update_optimizer_fn}
end
defp build_optimizer_fns(invalid) do
raise ArgumentError,
"Invalid optimizer #{inspect(invalid)}, a valid optimizer" <>
" is an atom matching the name of an optimizer in Axon.Optimizers" <>
" or a tuple of {init_fn, update_fn}. See Axon.Updates for more" <>
" information on building optimizers using the low-level API"
end
# Builds loss scale init, scale, and unscale functions either from an
# atom or a tuple of init, scale, unscale functions. The init, scale, and
# unscale functions match the signatures of those defined in Axon.LossScale.
# If the loss scale is an atom, it must match the name of a function in
# Axon.LossScale
defp build_loss_scale_fns(loss_scale)
when is_atom(loss_scale) and loss_scale in @valid_axon_loss_scale do
apply(Axon.LossScale, loss_scale, [])
end
defp build_loss_scale_fns({init_scale_fn, scale_fn, unscale_fn})
when is_function(init_scale_fn, 0) and is_function(scale_fn, 2) and
is_function(unscale_fn, 2) do
{init_scale_fn, scale_fn, unscale_fn}
end
defp build_loss_scale_fns(invalid) do
raise ArgumentError,
"Invalid loss scale #{inspect(invalid)}, a valid" <>
" loss scale is an atom matching the name of a loss" <>
" scale implementation in Axon.LossScale or a 3-tuple" <>
" of {init_scale, scale_fn, unscale_fn}. See Axon.LossScale" <>
" for more information"
end
# Builds a metric function from an atom or function and an output transform.
# A valid metric is an atom which matches the name of a function in
# Axon.Metrics or a function which takes an arbitrary number of parameters
# and returns an output of arbitrary shape/type. Output transforms are field(s)
# to extract from the step state, or a function which transforms the step
# state before it is passed to the metric function.
# TODO(seanmor5): Reconsider the form of output transform
defp build_metric_fn(metric, accumulator, transform_or_fields) do
transform_fn =
case transform_or_fields do
[_ | _] = fields ->
fn output ->
fields
|> Enum.reduce([], fn field, acc -> [output[field] | acc] end)
|> Enum.reverse()
end
field when is_atom(field) ->
fn output ->
output[field]
end
transform when is_function(transform, 1) ->
transform
invalid ->
raise ArgumentError,
"Invalid output transform #{inspect(invalid)}, a valid output" <>
" transform is an atom or list of atoms specifying field(s)" <>
" to extract from the step state, or an arity-1 function" <>
" applied to the step state"
end
metric_fn =
case metric do
metric when is_atom(metric) ->
fn output ->
output
|> transform_fn.()
|> then(&apply(Axon.Metrics, metric, &1))
end
metric_fn when is_function(metric) ->
fn output ->
output
|> transform_fn.()
|> then(&apply(metric_fn, &1))
# |> List.wrap()
end
invalid ->
raise ArgumentError,
"Invalid metric #{inspect(invalid)}, a valid metric" <>
" is an atom which matches the name of a function in" <>
" Axon.Metrics or a function which takes a transformed" <>
" step state and returns a value"
end
case accumulator do
acc_fun when acc_fun in [:running_average, :running_sum] ->
apply(Axon.Metrics, acc_fun, [metric_fn])
acc_fun when is_function(acc_fun, 3) ->
&acc_fun.(&1, apply(metric_fn, &2), &3)
invalid ->
raise ArgumentError,
"Invalid accumulation function #{inspect(invalid)}, a valid" <>
" accumulation function is an atom which matches the name" <>
" of an accumulation function in Axon.Metrics, or an arity-3" <>
" function which takes current accumulator, observation, and" <>
" iteration and returns an updated accumulator"
end
end
# Builds a filter function from an atom, keyword list, or function. A
# valid filter is an atom which matches on of the valid predicates `:always`
# or `:once`, a keyword which matches one of the valid predicate-value pairs
# such as `every: N`, or a function which takes loop state and the current event
# and returns `true` to run the handler of `false` to avoid it.
defp build_filter_fn(filter) do
case filter do
:always ->
fn _, _ -> true end
:first ->
fn %State{event_counts: counts}, event ->
counts[event] == 1
end
filters when is_list(filters) ->
Enum.reduce(filters, fn _, _ -> true end, fn
{:every, n}, acc ->
fn state, event ->
acc.(state, event) and filter_every_n(state, event, n)
end
{:before, n}, acc ->
fn state, event ->
acc.(state, event) and filter_before_n(state, event, n)
end
{:after, n}, acc ->
fn state, event ->
acc.(state, event) and filter_after_n(state, event, n)
end
{:once, n}, acc ->
fn state, event ->
acc.(state, event) and filter_once_n(state, event, n)
end
end)
fun when is_function(fun, 2) ->
fun
invalid ->
raise ArgumentError,
"Invalid filter #{inspect(invalid)}, a valid filter" <>
" is an atom which matches a valid filter predicate" <>
" such as :always or :once, a keyword of predicate-value" <>
" pairs such as every: N, or an arity-2 function which takes" <>
" loop state and current event and returns true or false"
end
end
defp filter_every_n(%State{event_counts: counts}, event, n) do
rem(counts[event] - 1, n) == 0
end
defp filter_after_n(%State{event_counts: counts}, event, n) do
counts[event] > n
end
defp filter_before_n(%State{event_counts: counts}, event, n) do
counts[event] < n
end
defp filter_once_n(%State{event_counts: counts}, event, n) do
counts[event] == n
end
# JIT-compiles the given function if jit_compile? is true
# otherwise just applies the function with the given arguments
defp maybe_jit(fun, args, jit_compile?, jit_opts) do
if jit_compile? do
apply(Nx.Defn.jit(fun, jit_opts), args)
else
apply(fun, args)
end
end
defp us_to_ms(time), do: Float.round(time / 1000, 1)
end