defmodule Axon.LossScale do
@moduledoc """
Implementations of loss-scalers for use in mixed precision
training.
Loss scaling is used to prevent underflow when using mixed
precision during the model training process. Each loss-scale
implementation here returns a 3-tuple of the functions:
{init_fn, scale_fn, unscale_fn} = Axon.LossScale.static(Nx.pow(2, 15))
You can use these to scale/unscale loss and gradients as well
as adjust the loss scale state.
`Axon.Loop.trainer/3` builds loss-scaling in by default. You
can reference the `Axon.Loop.train_step/3` implementation to
see how loss-scaling is applied in practice.
"""
@default_loss_scale 2 ** 15
import Nx.Defn
import Axon.Shared
@doc """
Implements identity loss-scale.

Follows the same `{init_fn, scale_fn, unscale_fn}` contract as the
other loss scales, but leaves values completely untouched and keeps
an empty state.
"""
def identity(_opts \\ []) do
  init_fn = fn -> %{} end
  passthrough_fn = fn value, _state -> value end
  unscale_fn = fn value, state -> {value, state} end
  {init_fn, passthrough_fn, unscale_fn}
end
@doc """
Implements static loss-scale.

## Options

  * `:init_scale` - fixed loss scale used for the whole training
    run. Defaults to `2 ** 15`.
"""
def static(opts \\ []) do
  opts = Keyword.validate!(opts, init_scale: @default_loss_scale)

  # Pin the scale to the binary backend so the closure below does not
  # capture a device-specific tensor.
  init_scale = Nx.backend_copy(opts[:init_scale], Nx.BinaryBackend)
  init_fn = fn -> init_static(init_scale) end

  {init_fn, &scale_static/2, &unscale_static/2}
end
# Initial state for the static loss scale: just the fixed scale itself.
defnp init_static(loss_scale), do: %{loss_scale: loss_scale}
# Multiplies every leaf tensor in `value` by the static loss scale.
defnp scale_static(value, %{loss_scale: loss_scale}) do
  deep_new(value, fn leaf -> Nx.multiply(leaf, loss_scale) end)
end
# Undoes `scale_static/2` by multiplying by the reciprocal of the scale.
# The (unchanged) state is returned alongside the unscaled value so the
# static API mirrors the dynamic one.
defnp unscale_static(value, %{loss_scale: loss_scale} = state) do
  reciprocal = 1 / loss_scale
  {deep_new(value, fn leaf -> Nx.multiply(leaf, reciprocal) end), state}
end
@doc """
Implements dynamic loss-scale.

## Options

  * `:init_scale` - initial loss scale. Defaults to `2 ** 15`.
  * `:period` - number of consecutive all-finite steps before the
    scale is grown. Defaults to `2_000`.
  * `:factor` - multiplicative growth/shrink factor. Defaults to `2`.
  * `:min_loss_scale` - lower bound the scale never shrinks below.
    Defaults to `1`.
"""
def dynamic(opts \\ []) do
  opts =
    Keyword.validate!(opts,
      init_scale: @default_loss_scale,
      period: 2_000,
      factor: 2,
      min_loss_scale: 1
    )

  # :init_scale is guaranteed present after validate!, so no pop default
  # is needed; the remaining opts are the adjustment parameters.
  {init_scale, adjust_opts} = Keyword.pop(opts, :init_scale)
  init_scale = Nx.backend_copy(init_scale, Nx.BinaryBackend)

  init_fn = fn -> init_dynamic(init_scale) end
  unscale_fn = fn value, state -> unscale_dynamic(value, state, adjust_opts) end

  {init_fn, &scale_dynamic/2, unscale_fn}
end
# Initial state for the dynamic loss scale: the starting scale plus a
# counter of consecutive steps whose gradients were all finite.
defnp init_dynamic(loss_scale) do
  %{loss_scale: loss_scale, counter: 0}
end
# Multiplies every leaf tensor in `value` by the current loss scale.
defnp scale_dynamic(value, %{loss_scale: loss_scale}) do
  deep_new(value, fn leaf -> Nx.multiply(leaf, loss_scale) end)
end
# Undoes `scale_dynamic/2` and, in the same pass, produces the next
# loss-scale state by inspecting the still-scaled grads for overflow.
defnp unscale_dynamic(value, %{loss_scale: loss_scale} = state, opts \\ []) do
  reciprocal = 1 / loss_scale
  unscaled = deep_new(value, fn leaf -> Nx.multiply(leaf, reciprocal) end)
  {unscaled, adjust_dynamic(value, state, opts)}
end
# Computes the next dynamic loss-scale state from the current grads.
#
# While every gradient entry is finite the scale is kept, and once the
# counter reaches `:period` - 1 consecutive finite steps the scale is
# grown by `:factor` (only if the grown scale itself is still finite).
# If any gradient is non-finite, the scale is shrunk by `:factor`
# (clamped at `:min_loss_scale`).
defnp adjust_dynamic(grads, %{loss_scale: loss_scale, counter: counter}, opts \\ []) do
  opts = keyword!(opts, period: 2_000, factor: 2, min_loss_scale: 1)

  # 1 only when every leaf tensor in `grads` passes is_finite/1
  grads_are_finite =
    deep_reduce(grads, Nx.tensor(1), fn x, acc ->
      x
      |> is_finite()
      |> Nx.logical_and(acc)
    end)

  new_loss_scale =
    Nx.select(
      grads_are_finite,
      Nx.select(
        Nx.equal(counter, opts[:period] - 1),
        # grow the scale, but keep the current one if growing overflows
        first_finite(loss_scale * opts[:factor], loss_scale),
        loss_scale
      ),
      # non-finite grads: back the scale off, never below :min_loss_scale
      Nx.max(opts[:min_loss_scale], loss_scale / opts[:factor])
    )

  # advance the counter modulo :period; multiplying by the finite flag
  # resets it to 0 whenever non-finite grads were seen
  new_counter = Nx.remainder(counter + 1, opts[:period]) * grads_are_finite
  %{loss_scale: new_loss_scale, counter: new_counter}
end
# Returns a scalar 1/0 flag: 1 iff every element of `x` is finite.
#
# NaN must be treated as non-finite alongside infinity: overflowed
# mixed-precision grads frequently produce NaN (e.g. inf - inf or
# 0 * inf), and the original infinity-only check would never trigger a
# loss-scale back-off on NaN gradients.
defnp is_finite(x) do
  non_finite = Nx.logical_or(Nx.is_infinity(x), Nx.is_nan(x))
  Nx.all(Nx.logical_not(non_finite))
end
# Picks `a` when it is finite everywhere, otherwise falls back to `b`.
defnp first_finite(a, b) do
  Nx.select(is_finite(a), a, b)
end
end