defmodule Polaris.Optimizers do
@moduledoc """
Implementations of common gradient-based optimization algorithms.

All of the methods in this module are written in terms of
the update methods defined in `Polaris.Updates`. Polaris treats
optimizers as the tuple:

    {init_fn, update_fn}

where `init_fn` returns an initial optimizer state and `update_fn`
scales input gradients. `init_fn` accepts a model's parameters
and attaches state to each parameter. `update_fn` accepts
gradients, optimizer state, and current model parameters and
returns updated optimizer state and gradients.

Custom optimizers are often created via the `Polaris.Updates` API.

## Example

Consider the following usage of the Adam optimizer in a basic
update function (assuming `objective`, `key`, `inputs`, and `targets`
are defined elsewhere):

    defmodule Learning do
      import Nx.Defn

      defn init(params, init_fn) do
        init_fn.(params)
      end

      defn update(params, optimizer_state, inputs, targets, update_fn) do
        {loss, gradient} = value_and_grad(params, &objective(&1, inputs, targets))
        {scaled_updates, new_optimizer_state} = update_fn.(gradient, optimizer_state, params)
        {Polaris.Updates.apply_updates(params, scaled_updates), new_optimizer_state, loss}
      end
    end

    {model_params, _key} = Nx.Random.uniform(key, shape: {784, 10})
    {init_fn, update_fn} = Polaris.Optimizers.adam(learning_rate: 0.005)

    optimizer_state =
      Learning.init(model_params, init_fn)

    {new_params, new_optimizer_state, loss} =
      Learning.update(model_params, optimizer_state, inputs, targets, update_fn)
"""
alias Polaris.Updates
@doc """
Adabelief optimizer.
## Options
* `:learning_rate` - the learning rate for the optimizer. Defaults to `1.0e-3`
* `:b1` - first moment decay. Defaults to `0.9`
* `:b2` - second moment decay. Defaults to `0.999`
* `:eps` - numerical stability term. Defaults to `0.0`
* `:eps_root` - numerical stability term. Defaults to `1.0e-16`
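## Examples

A minimal construction sketch; the parameter container and shape below are illustrative:

    {init_fn, update_fn} = Polaris.Optimizers.adabelief(learning_rate: 1.0e-3)

    # attach AdaBelief state to an (illustrative) parameter map
    optimizer_state = init_fn.(%{"kernel" => Nx.broadcast(0.0, {3, 3})})
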
## References
* [AdaBelief Optimizer: Adapting Stepsizes by the Belief in Observed Gradients](https://arxiv.org/abs/2010.07468)
"""
def adabelief(opts \\ []) do
{learning_rate, opts} = Keyword.pop(opts, :learning_rate, 1.0e-3)
Updates.scale_by_belief(opts)
|> scale_by_learning_rate(learning_rate)
end
@doc """
Adagrad optimizer.
## Options
* `:learning_rate` - the learning rate for the optimizer. Defaults to `1.0e-3`
* `:eps` - numerical stability term. Defaults to `1.0e-7`
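## Examples

A construction sketch; the option values are illustrative:

    # returns an `{init_fn, update_fn}` tuple, like every optimizer in this module
    {init_fn, update_fn} = Polaris.Optimizers.adagrad(learning_rate: 1.0e-2, eps: 1.0e-7)
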
## References
* [Adaptive Subgradient Methods for Online Learning and Stochastic Optimization](https://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf)
"""
def adagrad(opts \\ []) do
{learning_rate, opts} = Keyword.pop(opts, :learning_rate, 1.0e-3)
Updates.scale_by_rss(opts)
|> scale_by_learning_rate(learning_rate)
end
@doc """
Adam optimizer.
## Options
* `:learning_rate` - the learning rate for the optimizer. Defaults to `1.0e-3`
* `:b1` - first moment decay. Defaults to `0.9`
* `:b2` - second moment decay. Defaults to `0.999`
* `:eps` - numerical stability term. Defaults to `1.0e-8`
* `:eps_root` - numerical stability term. Defaults to `1.0e-15`
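## Examples

A construction sketch with non-default decay rates; the values are illustrative.
See the module documentation for a full training-step example:

    {init_fn, update_fn} = Polaris.Optimizers.adam(learning_rate: 1.0e-3, b1: 0.95, b2: 0.99)
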
## References
* [Adam: A Method for Stochastic Optimization](https://arxiv.org/abs/1412.6980)
"""
def adam(opts \\ []) do
{learning_rate, opts} = Keyword.pop(opts, :learning_rate, 1.0e-3)
Updates.scale_by_adam(opts)
|> scale_by_learning_rate(learning_rate)
end
@doc """
Adam with weight decay optimizer.
## Options
* `:learning_rate` - the learning rate for the optimizer. Defaults to `1.0e-3`
* `:b1` - first moment decay. Defaults to `0.9`
* `:b2` - second moment decay. Defaults to `0.999`
* `:eps` - numerical stability term. Defaults to `1.0e-8`
* `:eps_root` - numerical stability term. Defaults to `0.0`
* `:decay` - weight decay. Defaults to `0.0`
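## Examples

A construction sketch; the decay value is illustrative. Weight decay is added via
`Polaris.Updates.add_decayed_weights/2` before the learning rate scaling:

    {init_fn, update_fn} = Polaris.Optimizers.adamw(learning_rate: 1.0e-3, decay: 1.0e-2)
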
"""
def adamw(opts \\ []) do
{learning_rate, opts} = Keyword.pop(opts, :learning_rate, 1.0e-3)
{decay, opts} = Keyword.pop(opts, :decay, 0.0)
Updates.scale_by_adam(opts)
|> Updates.add_decayed_weights(decay: decay)
|> scale_by_learning_rate(learning_rate)
end
@doc """
Lamb optimizer.
## Options
* `:learning_rate` - the learning rate for the optimizer. Defaults to `1.0e-2`
* `:b1` - first moment decay. Defaults to `0.9`
* `:b2` - second moment decay. Defaults to `0.999`
* `:eps` - numerical stability term. Defaults to `1.0e-8`
* `:eps_root` - numerical stability term. Defaults to `0.0`
* `:decay` - weight decay. Defaults to `0.0`
* `:min_norm` - minimum norm value. Defaults to `0.0`
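## Examples

A construction sketch; the option values are illustrative:

    {init_fn, update_fn} =
      Polaris.Optimizers.lamb(learning_rate: 1.0e-2, decay: 1.0e-2, min_norm: 1.0e-6)
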
## References
* [Large Batch Optimization for Deep Learning: Training BERT in 76 minutes](https://arxiv.org/abs/1904.00962)
"""
def lamb(opts \\ []) do
{learning_rate, opts} = Keyword.pop(opts, :learning_rate, 1.0e-2)
{decay, opts} = Keyword.pop(opts, :decay, 0.0)
{min_norm, opts} = Keyword.pop(opts, :min_norm, 0.0)
Updates.scale_by_adam(opts)
|> Updates.add_decayed_weights(decay: decay)
|> Updates.scale_by_trust_ratio(min_norm: min_norm)
|> scale_by_learning_rate(learning_rate)
end
@doc """
Noisy SGD optimizer.
## Options
* `:learning_rate` - the learning rate for the optimizer. Defaults to `1.0e-2`
* `:eta` - used to compute variance of noise distribution. Defaults to `0.1`
* `:gamma` - used to compute variance of noise distribution. Defaults to `0.55`
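## Examples

A construction sketch; the noise parameters are illustrative:

    {init_fn, update_fn} =
      Polaris.Optimizers.noisy_sgd(learning_rate: 1.0e-2, eta: 0.01, gamma: 0.55)
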
"""
def noisy_sgd(opts \\ []) do
{learning_rate, opts} = Keyword.pop(opts, :learning_rate, 1.0e-2)
scale_by_learning_rate(learning_rate)
|> Updates.add_noise(opts)
end
@doc """
Rectified Adam optimizer.
## Options
* `:learning_rate` - the learning rate for the optimizer. Defaults to `1.0e-3`
* `:b1` - first moment decay. Defaults to `0.9`
* `:b2` - second moment decay. Defaults to `0.999`
* `:eps` - numerical stability term. Defaults to `1.0e-8`
* `:eps_root` - numerical stability term. Defaults to `0.0`
* `:threshold` - threshold term. Defaults to `5.0`
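## Examples

A construction sketch; the threshold is illustrative:

    {init_fn, update_fn} = Polaris.Optimizers.radam(learning_rate: 1.0e-3, threshold: 5.0)
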
## References
* [On the Variance of Adaptive Learning Rate and Beyond](https://arxiv.org/pdf/1908.03265.pdf)
"""
def radam(opts \\ []) do
{learning_rate, opts} = Keyword.pop(opts, :learning_rate, 1.0e-3)
Updates.scale_by_radam(opts)
|> scale_by_learning_rate(learning_rate)
end
@doc """
RMSProp optimizer.
## Options
* `:learning_rate` - the learning rate for the optimizer. Defaults to `1.0e-2`
* `:centered` - whether to scale by centered root of EMA of squares. Defaults to `false`
* `:momentum` - momentum term. If set, the scaled updates are additionally
  accumulated with momentum via `Polaris.Updates.trace/2`, with decay set to
  the value of this term.
* `:nesterov` - whether to use Nesterov momentum. Defaults to `false`
* `:initial_scale` - initial value of EMA. Defaults to `0.0`
* `:decay` - EMA decay rate. Defaults to `0.9`
* `:eps` - numerical stability term. Defaults to `1.0e-8`
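## Examples

A construction sketch of the centered variant with momentum; the values are illustrative:

    {init_fn, update_fn} =
      Polaris.Optimizers.rmsprop(learning_rate: 1.0e-2, centered: true, momentum: 0.9)
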
"""
def rmsprop(opts \\ []) do
{learning_rate, opts} = Keyword.pop(opts, :learning_rate, 1.0e-2)
{centered, opts} = Keyword.pop(opts, :centered, false)
{nesterov?, opts} = Keyword.pop(opts, :nesterov, false)
{momentum, opts} = Keyword.pop(opts, :momentum, nil)
combinator =
if centered do
Updates.scale_by_stddev(opts)
else
Updates.scale_by_rms(opts)
end
|> scale_by_learning_rate(learning_rate)
if momentum,
do: Updates.trace(combinator, decay: momentum, nesterov: nesterov?),
else: combinator
end
@doc """
SGD optimizer.
## Options
* `:learning_rate` - the learning rate for the optimizer. Defaults to `1.0e-2`
* `:momentum` - momentum term. If set, uses SGD with momentum, with the
  momentum decay set to the value of this term.
* `:nesterov` - whether to use Nesterov momentum. Defaults to `false`
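## Examples

Construction sketches; the values are illustrative. The learning rate may also be
given as an arity-1 schedule function of the update count:

    # classic momentum SGD with Nesterov correction
    {init_fn, update_fn} =
      Polaris.Optimizers.sgd(learning_rate: 1.0e-2, momentum: 0.9, nesterov: true)

    # learning rate driven by a schedule function (illustrative 1/t decay)
    {init_fn, update_fn} =
      Polaris.Optimizers.sgd(learning_rate: fn count -> Nx.divide(1.0e-2, Nx.add(count, 1)) end)
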
"""
def sgd(opts \\ []) do
{learning_rate, opts} = Keyword.pop(opts, :learning_rate, 1.0e-2)
momentum = opts[:momentum]
nesterov? = opts[:nesterov] || false
if momentum do
Updates.trace(decay: momentum, nesterov: nesterov?)
|> scale_by_learning_rate(learning_rate)
else
scale_by_learning_rate(learning_rate)
end
end
@doc """
Yogi optimizer.
## Options
* `:learning_rate` - the learning rate for the optimizer. Defaults to `1.0e-2`
* `:initial_accumulator_value` - initial value for first and second moment. Defaults to `0.0`
* `:b1` - first moment decay. Defaults to `0.9`
* `:b2` - second moment decay. Defaults to `0.999`
* `:eps` - numerical stability term. Defaults to `1.0e-8`
* `:eps_root` - numerical stability term. Defaults to `0.0`
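## Examples

A construction sketch; the accumulator value is illustrative:

    {init_fn, update_fn} =
      Polaris.Optimizers.yogi(learning_rate: 1.0e-2, initial_accumulator_value: 1.0e-6)
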
## References
* [Adaptive Methods for Nonconvex Optimization](https://papers.nips.cc/paper/2018/file/90365351ccc7437a1309dc64e4db32a3-Paper.pdf)
"""
def yogi(opts \\ []) do
{learning_rate, opts} = Keyword.pop(opts, :learning_rate, 1.0e-2)
Updates.scale_by_yogi(opts)
|> scale_by_learning_rate(learning_rate)
end
## Helpers
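# The learning rate is either a number or an arity-1 schedule function of the
# update count. In both cases the scale is negated so that applying the scaled
# updates with `Polaris.Updates.apply_updates/2` performs gradient descent.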
defp scale_by_learning_rate(combinator \\ Updates.identity(), lr)
defp scale_by_learning_rate(combinator, schedule) when is_function(schedule, 1) do
Updates.scale_by_schedule(combinator, fn count -> Nx.negate(schedule.(count)) end)
end
defp scale_by_learning_rate(combinator, lr) do
Updates.scale_by_state(combinator, -lr)
end
end