defmodule Axon.Schedules do
  @moduledoc """
  Parameter Schedules.

  Parameter schedules are often used to anneal hyperparameters
  such as the learning rate during the training process. Schedules
  provide a mapping from the current time step to a learning rate
  or another hyperparameter.

  Choosing a good learning rate and consequently a good learning
  rate schedule is typically a process of trial and error. Learning
  rates should be relatively small such that the learning curve
  does not oscillate violently during the training process, but
  not so small that learning proceeds too slowly. Using a
  schedule gradually reduces oscillations during the training
  process such that, as the model converges, training also
  becomes more stable.

  Every public function in this module takes an initial value plus
  options and returns a single-arity function that maps the current
  step to the scheduled value.

  All of the functions in this module are implemented as
  numerical functions and can be JIT or AOT compiled with
  any supported `Nx` compiler.
  """

  import Nx.Defn

  @doc ~S"""
  Linear decay schedule.

  Ramps the value linearly from `0` up to `init_value` over the first
  `:warmup` steps, then decays it linearly from `init_value` down to
  `0`, reaching `0` at step `:steps`. The value stays at `0` for all
  later steps.

  $$
  \gamma(t) =
  \begin{cases}
    \gamma_0 \frac{t}{\max(1, w)} & t < w \\
    \gamma_0 \max\left(0, \frac{s - t}{\max(1, s - w)}\right) & t \geq w
  \end{cases}
  $$

  ## Options

    * `:warmup` - number of warmup steps. $w$ in above formulation.
      Defaults to `0`

    * `:steps` - step at which the decayed value reaches `0`, counted
      from the start of the schedule (warmup steps included). $s$ in
      above formulation. Defaults to `1000`
  """
  def linear_decay(init_value, opts \\ []) do
    &apply_linear_decay(&1, [{:init_value, init_value} | opts])
  end

  defnp apply_linear_decay(step, opts \\ []) do
    opts = keyword!(opts, init_value: 1.0e-2, warmup: 0, steps: 1000)

    init_value = opts[:init_value]
    warmup = opts[:warmup]
    steps = opts[:steps]

    # Both branches are scaled by `init_value` so the schedule peaks at
    # the requested value (at step `warmup`) instead of at 1.0.
    # `Nx.max(1, ...)` guards the divisions against a zero denominator.
    if step < warmup do
      init_value * step / Nx.max(1, warmup)
    else
      init_value * Nx.max(0.0, (steps - step) / Nx.max(1, steps - warmup))
    end
  end

  @doc ~S"""
  Exponential decay schedule.

  $$\gamma(t) = \gamma_0 * r^{\frac{t}{k}}$$

  where $t$ is the number of steps elapsed since `:transition_begin`.
  While the transition has not yet begun ($t \leq 0$), `init_value`
  is returned unchanged.

  ## Options

    * `:decay_rate` - rate of decay. $r$ in above formulation.
      Defaults to `0.95`

    * `:transition_steps` - steps per transition. $k$ in above
      formulation. Defaults to `10`

    * `:transition_begin` - step to begin transition. Defaults to `0`

    * `:staircase` - discretize outputs by flooring the exponent
      $\frac{t}{k}$, so the value decays in discrete steps.
      Defaults to `false`
  """
  def exponential_decay(init_value, opts \\ []) do
    &apply_exponential_decay(&1, [{:init_value, init_value} | opts])
  end

  defnp apply_exponential_decay(step, opts \\ []) do
    opts =
      keyword!(opts,
        init_value: 1.0e-2,
        decay_rate: 0.95,
        transition_steps: 10,
        transition_begin: 0,
        staircase: false
      )

    init_value = opts[:init_value]
    rate = opts[:decay_rate]
    staircase? = opts[:staircase]
    k = opts[:transition_steps]
    start = opts[:transition_begin]

    # Steps elapsed since the transition began (negative before it begins).
    t = Nx.subtract(step, start)

    p =
      if staircase? do
        t
        |> Nx.divide(k)
        |> Nx.floor()
      else
        t
        |> Nx.divide(k)
      end

    decayed_value =
      rate
      |> Nx.power(p)
      |> Nx.multiply(init_value)

    # Hold the initial value until the transition begins.
    Nx.select(
      Nx.less_equal(t, 0),
      init_value,
      decayed_value
    )
  end

  @doc ~S"""
  Cosine decay schedule.

  $$\gamma(t) = \gamma_0 * \left(\frac{1}{2}(1 - \alpha)\left(1 + \cos\left(\pi \frac{t}{k}\right)\right) + \alpha\right)$$

  where $t$ is clamped to at most $k$, so once `:decay_steps` steps
  have elapsed the value stays at $\gamma_0 \alpha$.

  ## Options

    * `:decay_steps` - number of steps to apply decay for.
      $k$ in above formulation. Defaults to `10`

    * `:alpha` - minimum value of multiplier adjusting learning rate.
      $\alpha$ in above formulation. Defaults to `0.0`

  ## References

    * [SGDR: Stochastic Gradient Descent with Warm Restarts](https://openreview.net/forum?id=Skq89Scxx&noteId=Skq89Scxx)
  """
  def cosine_decay(init_value, opts \\ []) do
    &apply_cosine_decay(&1, [{:init_value, init_value} | opts])
  end

  defnp apply_cosine_decay(step, opts \\ []) do
    opts = keyword!(opts, init_value: 1.0e-2, decay_steps: 10, alpha: 0.0)

    init_value = opts[:init_value]
    decay_steps = opts[:decay_steps]
    alpha = opts[:alpha]

    # Clamp the step so the cosine does not rise again after `decay_steps`.
    step
    |> Nx.min(decay_steps)
    |> Nx.divide(decay_steps)
    |> Nx.multiply(3.1415926535897932384626433832795028841971)
    |> Nx.cos()
    |> Nx.add(1)
    |> Nx.divide(2)
    |> Nx.multiply(1 - alpha)
    |> Nx.add(alpha)
    |> Nx.multiply(init_value)
  end

  @doc ~S"""
  Constant schedule.

  $$\gamma(t) = \gamma_0$$

  Every step maps to `init_value`.
  """
  def constant(init_value, opts \\ []) do
    &apply_constant(&1, [{:init_value, init_value} | opts])
  end

  defnp apply_constant(_step, opts \\ []) do
    opts = keyword!(opts, init_value: 0.01)
    opts[:init_value]
  end

  @doc ~S"""
  Polynomial schedule.

  $$\gamma(t) = (\gamma_0 - \gamma_n) * \left(1 - \frac{t}{k}\right)^p + \gamma_n$$

  where $t$ is the number of steps elapsed since `:transition_begin`,
  clipped to the range $[0, k]$. The value is therefore $\gamma_0$
  before the transition begins and $\gamma_n$ once it has finished.

  ## Options

    * `:end_value` - end value of annealed scalar. $\gamma_n$ in above formulation.
      Defaults to `1.0e-3`

    * `:power` - power of polynomial. $p$ in above formulation. Defaults to `2`

    * `:transition_steps` - number of steps over which annealing takes place.
      $k$ in above formulation. Defaults to `10`

    * `:transition_begin` - step to begin transition. Defaults to `0`
  """
  def polynomial_decay(init_value, opts \\ []) do
    &apply_polynomial_decay(&1, [{:init_value, init_value} | opts])
  end

  defnp apply_polynomial_decay(step, opts \\ []) do
    opts =
      keyword!(opts,
        init_value: 1.0e-2,
        end_value: 1.0e-3,
        power: 2,
        transition_steps: 10,
        transition_begin: 0
      )

    init_value = opts[:init_value]
    end_value = opts[:end_value]
    start = opts[:transition_begin]
    k = opts[:transition_steps]
    p = opts[:power]

    # Clip elapsed steps to [0, k] so the value is held at `init_value`
    # before the transition and at `end_value` after it.
    step
    |> Nx.subtract(start)
    |> Nx.clip(0, k)
    |> Nx.divide(k)
    |> Nx.negate()
    |> Nx.add(1)
    |> Nx.power(p)
    |> Nx.multiply(Nx.subtract(init_value, end_value))
    |> Nx.add(end_value)
  end
end