defmodule Axon do
@moduledoc """
A high-level interface for creating neural network models.
Axon is built entirely on top of Nx numerical definitions,
so every neural network can be JIT or AOT compiled using
any Nx compiler, or even transformed into high-level neural
network formats like TensorFlow Lite and
[ONNX](https://github.com/elixir-nx/axon_onnx).
## Model Creation
All Axon models start with an input layer, specifying the
expected input shape of the training data:
input = Axon.input({nil, 784}, "input")
Notice you can specify some dimensions as `nil`, indicating
that the dimension size will be filled in at model runtime.
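Here is a rough picture of how Axon can infer output shapes before those sizes are known. Its private `infer_shape/3` helper swaps each `nil` for a placeholder via `Axon.Shape.replace_nil/1`, traces the layer, and remembers which positions were `nil`. `NilShapeSketch` below is a hypothetical plain-Elixir stand-in for that first step. It assumes a placeholder size of `1` and is not Axon's implementation:

```elixir
# Hypothetical module, not part of Axon. Assumes nil dimensions are
# replaced by a placeholder size of 1 before tracing.
defmodule NilShapeSketch do
  # Returns the shape with every nil replaced by 1, plus the indices
  # of the dimensions that were nil.
  def replace_nil(shape) when is_tuple(shape) do
    dims = Tuple.to_list(shape)
    nil_indices = for {nil, index} <- Enum.with_index(dims), do: index
    concrete = dims |> Enum.map(fn dim -> dim || 1 end) |> List.to_tuple()
    {concrete, nil_indices}
  end
end

NilShapeSketch.replace_nil({nil, 784})
# => {{1, 784}, [0]}
```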
You can then compose inputs with other layers:
model =
input
|> Axon.dense(128, activation: :relu)
|> Axon.batch_norm()
|> Axon.dropout(rate: 0.8)
|> Axon.dense(64)
|> Axon.tanh()
|> Axon.dense(10)
|> Axon.activation(:softmax)
You can inspect the model for a nice summary:

    IO.inspect(model)

    ---------------------------------------------------------------------------------------------------------
                                                          Model
    =========================================================================================================
     Layer                                   Shape        Policy              Parameters   Parameters Memory
    =========================================================================================================
     input ( input )                         {nil, 784}   p=f32 c=f32 o=f32   0            0 bytes
     dense_0 ( dense["input"] )              {nil, 128}   p=f32 c=f32 o=f32   100480       401920 bytes
     relu_0 ( relu["dense_0"] )              {nil, 128}   p=f32 c=f32 o=f32   0            0 bytes
     batch_norm_0 ( batch_norm["relu_0"] )   {nil, 128}   p=f32 c=f32 o=f32   512          2048 bytes
     dropout_0 ( dropout["batch_norm_0"] )   {nil, 128}   p=f32 c=f32 o=f32   0            0 bytes
     dense_1 ( dense["dropout_0"] )          {nil, 64}    p=f32 c=f32 o=f32   8256         33024 bytes
     tanh_0 ( tanh["dense_1"] )              {nil, 64}    p=f32 c=f32 o=f32   0            0 bytes
     dense_2 ( dense["tanh_0"] )             {nil, 10}    p=f32 c=f32 o=f32   650          2600 bytes
     softmax_0 ( softmax["dense_2"] )        {nil, 10}    p=f32 c=f32 o=f32   0            0 bytes
    ---------------------------------------------------------------------------------------------------------
    Total Parameters: 109898
    Total Parameters Memory: 439592 bytes
    Inputs: %{"input" => {nil, 784}}
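The **Parameters** column can be checked by hand. The snippet below is a plain-Elixir sketch; the `ParamCount` module is hypothetical and not part of Axon. It assumes three things:

* a dense layer holds an `{inputs, units}` kernel plus a `{units}` bias;
* batch normalization stores four per-feature vectors (scale, offset, running mean, running variance);
* each `f32` parameter occupies 4 bytes.

```elixir
# Hypothetical helper module, not part of Axon.
defmodule ParamCount do
  # A dense layer has an {inputs, units} kernel and a {units} bias.
  def dense(inputs, units), do: inputs * units + units

  # Assumes four per-feature vectors: scale, offset, running mean
  # and running variance.
  def batch_norm(features), do: 4 * features
end

counts = [
  ParamCount.dense(784, 128),
  ParamCount.batch_norm(128),
  ParamCount.dense(128, 64),
  ParamCount.dense(64, 10)
]

total = Enum.sum(counts)
# Assumes 4 bytes per f32 parameter.
bytes = total * 4

# counts => [100480, 512, 8256, 650]
# total  => 109898
# bytes  => 439592
```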
### Multiple Inputs
Creating a model with multiple inputs is as easy as declaring an
additional input in your Axon graph. Every input layer present in
the final Axon graph must be given a value when the model is
executed.
inp1 = Axon.input({nil, 1}, "input_0")
inp2 = Axon.input({nil, 1}, "input_1")
# Both inputs will be used
model1 = Axon.add(inp1, inp2)
# Only inp2 will be used
model2 = Axon.add(inp2, inp2)
Axon graphs are immutable, which means composing and manipulating
an Axon graph creates an entirely new graph. Additionally, layer
names are lazily generated at model execution time. To avoid
non-deterministic input orderings and names, Axon requires each
input to have a unique binary identifier. You can then reference
inputs by name when passing to models at execution time:
inp1 = Axon.input({nil, 1}, "input_0")
inp2 = Axon.input({nil, 1}, "input_1")
model1 = Axon.add(inp1, inp2)
params1 = Axon.init(model1)
# Inputs are referenced by name
Axon.predict(model1, params1, %{"input_0" => x, "input_1" => y})
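To make the rule about required inputs concrete, the sketch below models a graph as nested tuples. It lists the input names that are reachable from an output but missing from the inputs map. `GraphSketch` and its `{:input, name}` / `{:op, name, parents}` node format are hypothetical and much simpler than `%Axon{}` structs. The example mirrors `model1` and `model2` from the multiple-inputs example:

```elixir
# Hypothetical node format: {:input, name} or {:op, name, parents}.
defmodule GraphSketch do
  # Returns the set of input names reachable from a node.
  def required_inputs({:input, name}), do: MapSet.new([name])

  def required_inputs({:op, _name, parents}) do
    parents
    |> Enum.map(&required_inputs/1)
    |> Enum.reduce(MapSet.new(), &MapSet.union/2)
  end

  # Returns the required input names that are missing from `inputs`.
  def missing_inputs(node, inputs) when is_map(inputs) do
    node
    |> required_inputs()
    |> Enum.reject(&Map.has_key?(inputs, &1))
    |> Enum.sort()
  end
end

inp1 = {:input, "input_0"}
inp2 = {:input, "input_1"}
model1 = {:op, "add", [inp1, inp2]}
model2 = {:op, "add", [inp2, inp2]}

GraphSketch.missing_inputs(model1, %{"input_1" => 1.0})
# => ["input_0"]
GraphSketch.missing_inputs(model2, %{"input_1" => 1.0})
# => []
```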
### Multiple Outputs
Nx offers robust [container](https://hexdocs.pm/nx/Nx.Container.html) support
which is extended to Axon. Axon allows you to wrap any valid Nx container
in a layer. Containers are most commonly used to structure outputs:
inp1 = Axon.input({nil, 1}, "input_0")
inp2 = Axon.input({nil, 1}, "input_1")
model = Axon.container(%{foo: inp1, bar: inp2})
Containers can be arbitrarily nested:
inp1 = Axon.input({nil, 1}, "input_0")
inp2 = Axon.input({nil, 1}, "input_1")
model = Axon.container({%{foo: {inp1, %{bar: inp2}}}})
You can even use custom structs which implement the container protocol:
inp1 = Axon.input({nil, 1}, "input_0")
inp2 = Axon.input({nil, 1}, "input_1")
model = Axon.container(%MyStruct{foo: inp1, bar: inp2})
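A container layer's output mirrors the container's structure. Its output shape is therefore obtained by mapping over every leaf, which this module does with `Nx.Container.traverse/3`. The sketch below imitates that idea with a hypothetical `ContainerSketch.deep_map/2`, using layer names as leaves and a lookup map of shapes. Unlike Axon, it handles only plain maps and tuples: structs and the `Nx.Container` protocol are out of scope.

```elixir
# Hypothetical module, not part of Axon. Handles plain maps and tuples
# only; structs and the Nx.Container protocol are out of scope.
defmodule ContainerSketch do
  def deep_map(map, fun) when is_map(map) do
    Map.new(map, fn {key, value} -> {key, deep_map(value, fun)} end)
  end

  def deep_map(tuple, fun) when is_tuple(tuple) do
    tuple
    |> Tuple.to_list()
    |> Enum.map(&deep_map(&1, fun))
    |> List.to_tuple()
  end

  # Any other term is treated as a leaf.
  def deep_map(leaf, fun), do: fun.(leaf)
end

shapes = %{"input_0" => {nil, 1}, "input_1" => {nil, 1}}
container = {%{foo: {"input_0", %{bar: "input_1"}}}}

ContainerSketch.deep_map(container, &Map.fetch!(shapes, &1))
# => {%{foo: {{nil, 1}, %{bar: {nil, 1}}}}}
```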
### Custom Layers
If you find that Axon's built-in layers are insufficient for your needs,
you can create your own using the custom layer API. All of Axon's built-in
layers (aside from special ones such as `input`, `constant`, and `container`)
make use of this same API.
Axon layers are really just placeholders for Nx computations with trainable
parameters and possibly state. To define a custom layer, you just need to
define a `defn` implementation:
defn my_layer(x, weight, _opts \\ []) do
Nx.atan2(x, weight)
end
Notice the only stipulation is that your custom layer implementation must
accept at least 1 input and a list of options. At execution time, every
layer will be passed a `:mode` option which can be used to control behavior
at training and inference time.
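The hypothetical function below sketches one way an implementation might use the `:mode` option: it scales its input in training mode and is the identity otherwise. It works on plain lists of numbers for illustration only. A real layer would receive Nx tensors and would typically be written with `defn`.

```elixir
# Hypothetical layer implementation, not part of Axon.
defmodule ModeSketch do
  def train_only_scale(x, opts) do
    case Keyword.fetch!(opts, :mode) do
      :train -> Enum.map(x, &(&1 * 0.5))
      :inference -> x
    end
  end
end

ModeSketch.train_only_scale([2.0, 4.0], mode: :train)
# => [1.0, 2.0]
ModeSketch.train_only_scale([2.0, 4.0], mode: :inference)
# => [2.0, 4.0]
```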
Inputs to your custom layer can be either Axon graph inputs or trainable
parameters. You can pass Axon graph inputs as-is to a custom layer. To
declare trainable parameters, use `Axon.param/3`:
weight = Axon.param("weight", input_shape)
To create a custom layer, you "wrap" your implementation and inputs into
a layer using `Axon.layer`. You'll notice the API mirrors Elixir's `apply`:
def atan2_layer(%Axon{output_shape: shape} = input) do
# assumes `shape` contains no `nil` dimensions
weight = Axon.param("weight", shape)
Axon.layer(&my_layer/3, [input, weight])
end
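The comparison with `apply` is close to literal. Axon's shape inference (the private `infer_shape/3` in this module) calls `apply(fun, tensors ++ [opts])`, so the options always arrive as the last argument. The snippet mimics that call with plain floats instead of tensors; `atan2_like` is a hypothetical stand-in for `my_layer/3`.

```elixir
# Plain-number stand-in for my_layer/3; real layers receive Nx tensors.
atan2_like = fn x, weight, _opts -> :math.atan2(x, weight) end

inputs = [1.0, 1.0]
opts = [mode: :inference]

# Options are appended after the positional inputs.
result = apply(atan2_like, inputs ++ [opts])
# result is approximately :math.pi() / 4
```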
## Model Execution
Under the hood, Axon models are represented as Elixir structs. You
can initialize and apply models using the macros `Axon.init/3` and
`Axon.predict/4`:
params = Axon.init(model, compiler: EXLA)
Axon.predict(model, params, inputs, compiler: EXLA, mode: :train)
It is suggested that you set compiler options globally rather than pass
them as options to execution macros:
EXLA.set_as_nx_default([:tpu, :cuda, :rocm, :host])
params = Axon.init(model)
Axon.predict(model, params, inputs, mode: :train)
`Axon.predict/4` by default runs in inference mode, which performs certain
optimizations and removes layers such as dropout layers. If constructing
a training step using `Axon.predict/4`, be sure to specify `mode: :train`.
## Model Training
Combining the Axon model creation API with the optimization and training
APIs, you can create and train neural networks with ease:
model =
Axon.input({nil, 784}, "input_0")
|> Axon.dense(128, activation: :relu)
|> Axon.layer_norm()
|> Axon.dropout()
|> Axon.dense(10, activation: :softmax)
IO.inspect model
model_state =
model
|> Axon.Loop.trainer(:categorical_cross_entropy, Axon.Optimizers.adamw(0.005))
|> Axon.Loop.run(train_data, epochs: 10, compiler: EXLA)
See `Axon.Updates` and `Axon.Loop` for a more in-depth treatment of
model optimization and model training.
"""
alias __MODULE__, as: Axon
alias Axon.Parameter
# Axon serialization version
@file_version 1
@type t :: %__MODULE__{}
defstruct [
:id,
:name,
:output_shape,
:parent,
:parameters,
:args,
:op,
:policy,
:hooks,
:opts,
:op_name
]
@doc """
Custom Axon layer with given inputs.
Inputs may be other Axon layers or trainable parameters created
with `Axon.param`. At inference time, `op` will be applied with
inputs in specified order and an additional `opts` parameter which
specifies inference options. All options passed to layer are forwarded
to inference function except:
* `:shape` - specify layer output shape to bypass shape inference.
* `:name` - layer name.
* `:op_name` - layer operation for inspection and building parameter
map.
Note this means your layer should not use these as input options,
as they will always be dropped during inference compilation.
Axon's compiler will additionally forward the following options to
every layer at inference time:
* `:mode` - `:inference` or `:train`. To control layer behavior
based on inference or train time.
`op` is a function of the form:
fun = fn input, weight, bias, _opts ->
Nx.add(Nx.multiply(input, weight), bias)
end
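The option handling described above can be traced with `Keyword` alone. The snippet mirrors the `Keyword.pop/3` calls at the top of `layer/3`, then adds `:mode` the way Axon's shape inference does. `:my_flag` is a hypothetical option, used only to show what gets forwarded to `op`.

```elixir
opts = [name: "my_layer", shape: {nil, 10}, op_name: :my_op, my_flag: true]

# Reserved options are removed before op is ever called.
{_name, opts} = Keyword.pop(opts, :name)
{_shape, opts} = Keyword.pop(opts, :shape)
{_op_name, opts} = Keyword.pop(opts, :op_name, :custom)

# Whatever is left is forwarded, together with :mode.
forwarded = Keyword.put(opts, :mode, :inference)
# => [mode: :inference, my_flag: true]
```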
"""
@doc type: :special
def layer(op, inputs, opts \\ []) when (is_atom(op) or is_function(op)) and is_list(inputs) do
{inputs, params, args, input_shapes} = split_inputs(op, inputs)
inputs = Enum.reverse(inputs)
params = Enum.reverse(params)
args = Enum.reverse(args)
input_shapes = Enum.reverse(input_shapes)
{name, opts} = Keyword.pop(opts, :name)
{shape, opts} = Keyword.pop(opts, :shape)
{op_name, opts} = Keyword.pop(opts, :op_name, :custom)
{id, name} = unique_identifiers(op_name, name)
output_shape =
if shape do
shape
else
infer_shape(input_shapes, op, opts)
end
%Axon{
id: id,
name: name,
output_shape: output_shape,
parent: inputs,
parameters: params,
args: args,
op: op,
policy: Axon.MixedPrecision.create_policy(),
hooks: [],
opts: opts,
op_name: op_name
}
end
defp split_inputs(:container, [container] = inputs) do
input_shapes = deep_new(container, fn %Axon{output_shape: shape} -> shape end)
args = [:layer]
params = []
{inputs, params, args, [input_shapes]}
end
defp split_inputs(_op, inputs) do
Enum.reduce(inputs, {[], [], [], []}, fn
%Axon{output_shape: shape} = layer, {layers, params, args, shapes} ->
{[layer | layers], params, [:layer | args], [shape | shapes]}
%Parameter{shape: shape} = param, {layers, params, args, shapes} ->
{layers, [param | params], [:parameter | args], [shape | shapes]}
invalid, _ ->
raise ArgumentError, "invalid input given to layer: #{inspect(invalid)}"
end)
end
defp infer_shape(input_shapes, fun, opts) do
{inputs, indices} =
Enum.reduce(input_shapes, {[], []}, fn shape, {input_shapes, indices} ->
{template, template_indices} = template_shape(shape)
{[template | input_shapes], [template_indices | indices]}
end)
inputs = Enum.reverse(inputs)
opts = Keyword.put(opts, :mode, :inference)
wrapper_fun = fn tensors ->
tensors = Tuple.to_list(tensors)
apply(fun, tensors ++ [opts])
end
expr = Nx.Defn.jit(wrapper_fun, [List.to_tuple(inputs)], compiler: Axon.Defn)
indices = Enum.map(indices, &MapSet.new/1)
indices_that_are_1 =
expr.shape
|> Tuple.to_list()
|> Enum.with_index()
|> Enum.filter(fn {x, _} -> x == 1 end)
|> Enum.map(fn {_, i} -> i end)
|> MapSet.new()
indices_to_make_nil =
case indices do
[] ->
[]
indices ->
indices
|> Enum.reduce(MapSet.new(), &MapSet.union/2)
|> MapSet.intersection(indices_that_are_1)
|> Enum.to_list()
end
Enum.reduce(indices_to_make_nil, expr.shape, fn i, shape ->
put_elem(shape, i, nil)
end)
end
defp template_shape(shape) when is_map(shape) do
Nx.Container.traverse(shape, [], &recur_template_shape/2)
end
defp template_shape(shape) do
if tuple_size(shape) == 0 do
{Nx.template({}, {:f, 32}), []}
else
first_elem = elem(shape, 0)
if is_integer(first_elem) or is_nil(first_elem) do
{shape, template_indices} = Axon.Shape.replace_nil(shape)
template = Nx.template(shape, {:f, 32})
{template, List.wrap(template_indices)}
else
Nx.Container.traverse(shape, [], &recur_template_shape/2)
end
end
end
defp recur_template_shape(shape, indices) do
case shape do
shape when is_map(shape) ->
{template, template_indices} = template_shape(shape)
{template, indices ++ template_indices}
shape when is_tuple(shape) ->
{template, template_indices} = template_shape(shape)
{template, indices ++ template_indices}
end
end
@doc """
Trainable Axon parameter used to create custom layers.
Parameters are specified in usages of `Axon.layer` and will
be automatically initialized and used in subsequent applications
of Axon models.
Parameters *must* be specified in order of their usage.
## Options
* `:initializer` - parameter initializer. Defaults to `:glorot_uniform`.
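`param/3` validates its options with `Keyword.validate!/2`. A missing `:initializer` therefore falls back to the default, and unknown keys raise. The snippet below runs the same validation in isolation; `Keyword.validate!/2` requires Elixir 1.13 or later. `:initialiser` is a deliberately misspelled key.

```elixir
defaults = [initializer: :glorot_uniform]

# No options given: the default initializer is used.
Keyword.validate!([], defaults)[:initializer]
# => :glorot_uniform

# An explicit value overrides the default.
Keyword.validate!([initializer: :zeros], defaults)[:initializer]
# => :zeros

# Unknown keys raise ArgumentError.
outcome =
  try do
    Keyword.validate!([initialiser: :zeros], defaults)
    :accepted
  rescue
    ArgumentError -> :rejected
  end

# outcome => :rejected
```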
"""
def param(name, shape, opts \\ []) when is_binary(name) and is_tuple(shape) do
opts = Keyword.validate!(opts, initializer: :glorot_uniform)
initializer = opts[:initializer]
validate_initializer!(initializer)
id = System.unique_integer([:positive, :monotonic])
%Axon.Parameter{
id: id,
name: name,
shape: shape,
initializer: initializer
}
end
@doc """
Adds an input layer to the network.
Input layers specify a model's inputs. Input layers are
always the root layers of the neural network.
You must specify the input layer's name, which is used to
uniquely identify it when a model has multiple inputs.
"""
@doc type: :special
def input(input_shape, name) when is_tuple(input_shape) and is_binary(name) do
output_shape = Axon.Shape.input(input_shape)
layer(:input, [], name: name, shape: output_shape, op_name: :input)
end
@doc """
Adds a constant layer to the network.
Constant layers encapsulate Nx tensors in an Axon layer for ease
of use with other Axon layers. They can be used interchangeably
with other Axon layers:
inp = Axon.input({nil, 32}, "input")
my_constant = Axon.constant(Nx.iota({1, 32}))
model = Axon.add(inp, my_constant)
Constant layers will be cast according to the mixed precision policy.
If it's important for your constant to retain its type during
the computation, you will need to set the mixed precision policy to
ignore constant layers.
## Options
* `:name` - layer name.
"""
def constant(tensor, opts \\ [])
@doc type: :special
def constant(%Nx.Tensor{shape: output_shape} = tensor, opts) do
opts = Keyword.validate!(opts, [:name])
layer(:constant, [], name: opts[:name], value: tensor, shape: output_shape, op_name: :constant)
end
def constant(value, _) do
raise ArgumentError,
"value passed to constant must be an Nx tensor" <>
" but got #{inspect(value)}, if you are passing" <>
" a number, wrap it with a call to Nx.tensor/2"
end
@doc """
Adds a container layer to the network.
In certain cases you may want your model to have multiple
outputs. In order to make this work, you must "join" the
outputs into an Axon layer using this function for use in
initialization and inference later on.
The given container can be any valid Axon Nx container.
## Options
* `:name` - layer name.
## Examples
iex> inp1 = Axon.input({nil, 1}, "input_0")
iex> inp2 = Axon.input({nil, 2}, "input_1")
iex> model = Axon.container(%{a: inp1, b: inp2})
iex> %{a: a, b: b} = Axon.predict(model, %{}, %{
...> "input_0" => Nx.tensor([[1.0]]),
...> "input_1" => Nx.tensor([[1.0, 2.0]])
...> })
iex> a
#Nx.Tensor<
f32[1][1]
[
[1.0]
]
>
iex> b
#Nx.Tensor<
f32[1][2]
[
[1.0, 2.0]
]
>
"""
@doc type: :special
def container(container, opts \\ []) do
opts = Keyword.validate!(opts, [:name])
output_shape =
deep_new(container, fn %Axon{output_shape: shape} ->
shape
end)
layer(:container, [container], name: opts[:name], shape: output_shape, op_name: :container)
end
# TODO: This should not be duplicated
defp deep_new(map, fun) do
{cont, :ok} = Nx.Container.traverse(map, :ok, &recur_traverse(&1, &2, fun))
cont
end
defp recur_traverse(item, :ok, fun) do
case item do
%Axon{} = t ->
{fun.(t), :ok}
%{axon: :axon} = t ->
{fun.(t), :ok}
container ->
{deep_new(container, fun), :ok}
end
end
@doc """
Wraps an Axon model into a namespace.
A namespace is a part of an Axon model which is meant to
be a self-contained collection of Axon layers. Namespaces
are guaranteed to always generate with the same internal
layer names and can be re-used universally across models.
Namespaces are most useful for containing large collections
of layers and offering a straightforward means for accessing
the parameters of individual model components. A common application
of namespaces is to use them with a pre-trained model for
fine-tuning:
{base, resnet_params} = resnet()
base = base |> Axon.namespace("resnet")
model = base |> Axon.dense(1)
Axon.init(model, %{"resnet" => resnet_params})
Notice you can use `Axon.init` in conjunction with namespaces
to specify which portion of a model you'd like to initialize
from a fixed starting point.
Namespaces have fixed names, which means it's easy to run into namespace
collisions. Re-using namespaces, re-using inner parts of a namespace,
and attempting to share layers between namespaces are still sharp
edges in namespace usage.
"""
def namespace(%Axon{output_shape: shape} = axon, name) when is_binary(name) do
layer(:namespace, [axon], name: name, shape: shape)
end
@doc """
Adds a dense layer to the network.
The dense layer implements:
output = activation(dot(input, kernel) + bias)
where `activation` is given by the `:activation` option and both
`kernel` and `bias` are layer parameters. `units` specifies the
number of output units.
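The formula can be traced on plain lists. `DenseSketch` below is a hypothetical single-example implementation, not `Axon.Layers.dense/4`. It assumes the kernel is laid out as `{inputs, units}`, with one row per input feature.

```elixir
# Hypothetical sketch, not Axon.Layers.dense/4.
defmodule DenseSketch do
  # input: list of features; kernel: one row per input feature, each row
  # holding one weight per unit; bias: one value per unit.
  def dense(input, kernel, bias, activation) do
    kernel
    |> transpose()
    |> Enum.zip(bias)
    |> Enum.map(fn {column, b} -> activation.(dot(input, column) + b) end)
  end

  defp dot(a, b), do: a |> Enum.zip(b) |> Enum.map(fn {x, y} -> x * y end) |> Enum.sum()

  defp transpose(rows), do: rows |> Enum.zip() |> Enum.map(&Tuple.to_list/1)
end

relu = fn x -> max(x, 0.0) end
input = [1.0, 2.0]
kernel = [[1.0, -1.0], [0.5, -2.0]]
bias = [0.0, 1.0]

DenseSketch.dense(input, kernel, bias, relu)
# => [2.0, 0.0]
```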
Compiles to `Axon.Layers.dense/4`.
## Options
* `:name` - layer name.
* `:kernel_initializer` - initializer for `kernel` weights.
Defaults to `:glorot_uniform`.
* `:bias_initializer` - initializer for `bias` weights. Defaults
to `:zeros`.
* `:activation` - element-wise activation function.
* `:use_bias` - whether the layer should add bias to the output.
Defaults to `true`.
"""
@doc type: :linear
def dense(%Axon{output_shape: parent_shape} = x, units, opts \\ [])
when is_integer(units) and units > 0 do
opts =
Keyword.validate!(opts, [
:name,
:activation,
kernel_initializer: :glorot_uniform,
bias_initializer: :zeros,
use_bias: true
])
kernel_shape = Axon.Shape.dense_kernel(parent_shape, units)
bias_shape = Axon.Shape.dense_bias(parent_shape, units)
output_shape = Axon.Shape.dense(parent_shape, units)
kernel = param("kernel", kernel_shape, initializer: opts[:kernel_initializer])
{inputs, op} =
if opts[:use_bias] do
bias = param("bias", bias_shape, initializer: opts[:bias_initializer])
{[x, kernel, bias], :dense}
else
{[x, kernel], &Axon.Layers.dense(&1, &2, 0, &3)}
end
node = layer(op, inputs, name: opts[:name], shape: output_shape, op_name: :dense)
if activation = opts[:activation] do
activation(node, activation)
else
node
end
end
@doc """
Adds a bilinear layer to the network.
The bilinear layer implements:
output = activation(dot(dot(input1, kernel), input2) + bias)
where `activation` is given by the `:activation` option and both
`kernel` and `bias` are layer parameters. `units` specifies the
number of output units.
All dimensions but the last of `input1` and `input2` must match. The
batch sizes of both inputs must also match or at least one must be `nil`.
Inferred output batch size coerces to the strictest input batch size.
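For a single example, each output unit `u` computes `sum_ij input1[i] * kernel[u][i][j] * input2[j] + bias[u]`. The hypothetical `BilinearSketch` below spells that out on plain lists. It assumes a kernel laid out as `{units, input1_features, input2_features}`; the layout used by `Axon.Layers.bilinear/5` may differ.

```elixir
# Hypothetical sketch, not Axon.Layers.bilinear/5.
defmodule BilinearSketch do
  # kernel: one matrix per output unit, with one row per input1 feature
  # and one column per input2 feature.
  def bilinear(input1, input2, kernel, bias) do
    kernel
    |> Enum.zip(bias)
    |> Enum.map(fn {matrix, b} ->
      input1
      |> Enum.zip(matrix)
      |> Enum.map(fn {x, row} -> x * dot(row, input2) end)
      |> Enum.sum()
      |> Kernel.+(b)
    end)
  end

  defp dot(a, b), do: a |> Enum.zip(b) |> Enum.map(fn {x, y} -> x * y end) |> Enum.sum()
end

input1 = [1.0, 2.0]
input2 = [3.0]
kernel = [[[1.0], [1.0]], [[0.0], [2.0]]]
bias = [0.0, 0.5]

BilinearSketch.bilinear(input1, input2, kernel, bias)
# => [9.0, 12.5]
```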
Compiles to `Axon.Layers.bilinear/5`.
## Options
* `:name` - layer name.
* `:kernel_initializer` - initializer for `kernel` weights.
Defaults to `:glorot_uniform`.
* `:bias_initializer` - initializer for `bias` weights. Defaults
to `:zeros`.
* `:activation` - element-wise activation function.
* `:use_bias` - whether the layer should add bias to the output.
Defaults to `true`.
"""
@doc type: :linear
def bilinear(
%Axon{output_shape: parent1_shape} = input1,
%Axon{output_shape: parent2_shape} = input2,
units,
opts \\ []
)
when is_integer(units) and units > 0 do
opts =
Keyword.validate!(opts, [
:name,
:activation,
kernel_initializer: :glorot_uniform,
bias_initializer: :zeros,
use_bias: true
])
kernel_shape = Axon.Shape.bilinear_kernel(parent1_shape, parent2_shape, units)
bias_shape = Axon.Shape.bilinear_bias(parent1_shape, parent2_shape, units)
output_shape = Axon.Shape.bilinear(parent1_shape, parent2_shape, units)
kernel = param("kernel", kernel_shape, initializer: opts[:kernel_initializer])
{inputs, op} =
if opts[:use_bias] do
bias = param("bias", bias_shape, initializer: opts[:bias_initializer])
{[input1, input2, kernel, bias], :bilinear}
else
{[input1, input2, kernel], &Axon.Layers.bilinear(&1, &2, &3, 0, &4)}
end
node = layer(op, inputs, name: opts[:name], shape: output_shape, op_name: :bilinear)
if activation = opts[:activation] do
activation(node, activation)
else
node
end
end
@doc """
Adds a convolution layer to the network.
The convolution layer implements a general N-dimensional
convolution, which convolves a kernel over the spatial dimensions
of the input to produce an output.
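A one-dimensional, single-channel case shows what `:kernel_size`, `:strides` and `:valid` padding control. `Conv1DSketch` is hypothetical and far narrower than `Axon.Layers.conv/4`: it has no channels, batching, dilation or padding. It slides the kernel across the input without flipping it (a cross-correlation). With `:valid` padding, an input of length `n` yields `div(n - k, s) + 1` outputs for kernel size `k` and stride `s`.

```elixir
# Hypothetical sketch, not Axon.Layers.conv/4.
defmodule Conv1DSketch do
  # Slides `kernel` over `input` with the given stride, keeping only full
  # windows (the behaviour of :valid padding). The kernel is not flipped.
  def conv1d(input, kernel, stride) do
    input
    |> Enum.chunk_every(length(kernel), stride, :discard)
    |> Enum.map(fn window -> dot(window, kernel) end)
  end

  defp dot(a, b), do: a |> Enum.zip(b) |> Enum.map(fn {x, y} -> x * y end) |> Enum.sum()
end

input = [1.0, 3.0, 2.0, 5.0, 4.0]
kernel = [1.0, 0.0, -1.0]

Conv1DSketch.conv1d(input, kernel, 1)
# => [-1.0, -2.0, -2.0]
Conv1DSketch.conv1d(input, kernel, 2)
# => [-1.0, -2.0]
```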
Compiles to `Axon.Layers.conv/4`.
## Options
* `:name` - layer name.
* `:kernel_initializer` - initializer for `kernel` weights.
Defaults to `:glorot_uniform`.
* `:bias_initializer` - initializer for `bias` weights. Defaults
to `:zeros`.
* `:activation` - element-wise activation function.
* `:use_bias` - whether the layer should add bias to the output.
Defaults to `true`.
* `:kernel_size` - size of the kernel spatial dimensions. Defaults
to `1`.
* `:strides` - stride during convolution. Defaults to `1`.
* `:padding` - padding to the spatial dimensions of the input.
Defaults to `:valid`.
* `:input_dilation` - dilation to apply to input. Defaults to `1`.
* `:kernel_dilation` - dilation to apply to kernel. Defaults to `1`.
* `:feature_group_size` - feature group size for convolution. Defaults
to `1`.
* `:channels` - channels location. One of `:first` or `:last`.
Defaults to `:first`.
"""
@doc type: :convolution
def conv(%Axon{output_shape: parent_shape} = x, units, opts \\ [])
when is_integer(units) and units > 0 do
opts =
Keyword.validate!(opts, [
:name,
:activation,
kernel_initializer: :glorot_uniform,
bias_initializer: :zeros,
use_bias: true,
kernel_size: 1,
strides: 1,
padding: :valid,
input_dilation: 1,
kernel_dilation: 1,
channels: :first,
feature_group_size: 1
])
kernel_size = opts[:kernel_size]
strides = opts[:strides]
padding = opts[:padding]
input_dilation = opts[:input_dilation]
kernel_dilation = opts[:kernel_dilation]
channels = opts[:channels]
feature_group_size = opts[:feature_group_size]
inner_rank = Nx.rank(parent_shape) - 2
kernel_size = tuple_or_duplicate(:kernel_size, kernel_size, inner_rank)
strides = list_or_duplicate(:strides, strides, inner_rank)
input_dilation = list_or_duplicate(:input_dilation, input_dilation, inner_rank)
kernel_dilation = list_or_duplicate(:kernel_dilation, kernel_dilation, inner_rank)
kernel_shape = Axon.Shape.conv_kernel(parent_shape, units, kernel_size, channels)
bias_shape = Axon.Shape.conv_bias(parent_shape, units, kernel_size, channels)
output_shape =
Axon.Shape.conv(
parent_shape,
kernel_shape,
strides,
padding,
input_dilation,
kernel_dilation,
channels,
feature_group_size
)
kernel = param("kernel", kernel_shape, initializer: opts[:kernel_initializer])
{inputs, op} =
if opts[:use_bias] do
bias = param("bias", bias_shape, initializer: opts[:bias_initializer])
{[x, kernel, bias], :conv}
else
{[x, kernel], &Axon.Layers.conv(&1, &2, 0, &3)}
end
node =
layer(op, inputs,
name: opts[:name],
strides: strides,
padding: padding,
input_dilation: input_dilation,
kernel_dilation: kernel_dilation,
feature_group_size: feature_group_size,
channels: channels,
shape: output_shape,
op_name: :conv
)
if activation = opts[:activation] do
activation(node, activation)
else
node
end
end
@doc """
Adds a transposed convolution layer to the network.
The transposed convolution layer is sometimes referred to as a
fractionally strided convolution or (incorrectly) as a deconvolution.
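One way to see why the layer is called fractionally strided: in one dimension, each input element scatters a scaled copy of the kernel into the output, `stride` positions apart. The stride therefore acts as an upsampling factor. `ConvTransposeSketch` is a hypothetical single-channel illustration without padding or dilation, not `Axon.Layers.conv_transpose/4`. It produces `(n - 1) * stride + k` outputs for an input of length `n` and kernel size `k`.

```elixir
# Hypothetical sketch, not Axon.Layers.conv_transpose/4.
defmodule ConvTransposeSketch do
  def conv_transpose1d(input, kernel, stride) do
    k = length(kernel)
    output = List.duplicate(0.0, (length(input) - 1) * stride + k)

    input
    |> Enum.with_index()
    |> Enum.reduce(output, fn {x, i}, acc ->
      # Add x * kernel into the output, starting at position i * stride.
      kernel
      |> Enum.with_index()
      |> Enum.reduce(acc, fn {w, j}, out ->
        List.update_at(out, i * stride + j, &(&1 + x * w))
      end)
    end)
  end
end

ConvTransposeSketch.conv_transpose1d([1.0, 2.0], [1.0, 1.0, 1.0], 2)
# => [1.0, 1.0, 3.0, 2.0, 2.0]
```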
Compiles to `Axon.Layers.conv_transpose/4`.
## Options
* `:name` - layer name.
* `:kernel_initializer` - initializer for `kernel` weights.
Defaults to `:glorot_uniform`.
* `:bias_initializer` - initializer for `bias` weights. Defaults
to `:zeros`.
* `:activation` - element-wise activation function.
* `:use_bias` - whether the layer should add bias to the output.
Defaults to `true`.
* `:kernel_size` - size of the kernel spatial dimensions. Defaults
to `1`.
* `:strides` - stride during convolution. Defaults to `1`.
* `:padding` - padding to the spatial dimensions of the input.
Defaults to `:valid`.
* `:kernel_dilation` - dilation to apply to kernel. Defaults to `1`.
* `:channels` - channels location. One of `:first` or `:last`.
Defaults to `:first`.
"""
@doc type: :convolution
def conv_transpose(%Axon{output_shape: parent_shape} = x, units, opts \\ []) do
opts =
Keyword.validate!(opts, [
:name,
:activation,
kernel_initializer: :glorot_uniform,
bias_initializer: :zeros,
use_bias: true,
kernel_size: 1,
strides: 1,
padding: :valid,
kernel_dilation: 1,
channels: :first
])
kernel_size = opts[:kernel_size]
strides = opts[:strides]
padding = opts[:padding]
kernel_dilation = opts[:kernel_dilation]
channels = opts[:channels]
inner_rank = Nx.rank(parent_shape) - 2
kernel_size = tuple_or_duplicate(:kernel_size, kernel_size, inner_rank)
strides = list_or_duplicate(:strides, strides, inner_rank)
kernel_dilation = list_or_duplicate(:kernel_dilation, kernel_dilation, inner_rank)
kernel_shape = Axon.Shape.conv_kernel(parent_shape, units, kernel_size, channels)
bias_shape = Axon.Shape.conv_bias(parent_shape, units, kernel_size, channels)
kernel = param("kernel", kernel_shape, initializer: opts[:kernel_initializer])
{inputs, op} =
if opts[:use_bias] do
bias = param("bias", bias_shape, initializer: opts[:bias_initializer])
{[x, kernel, bias], :conv_transpose}
else
{[x, kernel], &Axon.Layers.conv_transpose(&1, &2, 0, &3)}
end
output_shape =
Axon.Shape.conv_transpose(
parent_shape,
kernel_shape,
strides,
padding,
kernel_dilation,
channels
)
node =
layer(op, inputs,
name: opts[:name],
strides: strides,
padding: padding,
kernel_dilation: kernel_dilation,
channels: channels,
shape: output_shape,
op_name: :conv_transpose
)
if activation = opts[:activation] do
activation(node, activation)
else
node
end
end
@doc """
Adds a depthwise convolution layer to the network.
The depthwise convolution layer implements a general
N-dimensional depthwise convolution, which is a convolution
where the feature group size is equal to the number of
input channels.
The channel multiplier grows the input channels by the given
factor. A factor of 1 means the output channels
are the same as the input channels.
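With `c` input channels and a channel multiplier `m`, the layer produces `c * m` output channels, and each output channel sees only one input channel. `DepthwiseSketch` below is a hypothetical single-example, one-dimensional illustration that groups outputs by input channel. The channel ordering and kernel layout of `Axon.Layers.depthwise_conv/4` may differ.

```elixir
# Hypothetical sketch, not Axon.Layers.depthwise_conv/4.
defmodule DepthwiseSketch do
  # input: one list of values per input channel.
  # kernels: for each input channel, a list of `multiplier` 1-D kernels.
  def depthwise_conv1d(input, kernels) do
    input
    |> Enum.zip(kernels)
    |> Enum.flat_map(fn {channel, channel_kernels} ->
      # Every kernel for this channel sees only this channel.
      Enum.map(channel_kernels, &correlate(channel, &1))
    end)
  end

  defp correlate(signal, kernel) do
    signal
    |> Enum.chunk_every(length(kernel), 1, :discard)
    |> Enum.map(fn window -> dot(window, kernel) end)
  end

  defp dot(a, b), do: a |> Enum.zip(b) |> Enum.map(fn {x, y} -> x * y end) |> Enum.sum()
end

# 2 input channels, channel multiplier 2, kernel size 2
input = [[1.0, 2.0, 3.0], [10.0, 20.0, 30.0]]
kernels = [[[1.0, 1.0], [1.0, -1.0]], [[0.5, 0.5], [0.0, 1.0]]]

DepthwiseSketch.depthwise_conv1d(input, kernels)
# => [[3.0, 5.0], [-1.0, -1.0], [15.0, 25.0], [20.0, 30.0]]
```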
Compiles to `Axon.Layers.depthwise_conv/4`.
## Options
* `:name` - layer name.
* `:kernel_initializer` - initializer for `kernel` weights.
Defaults to `:glorot_uniform`.
* `:bias_initializer` - initializer for `bias` weights. Defaults
to `:zeros`.
* `:activation` - element-wise activation function.
* `:use_bias` - whether the layer should add bias to the output.
Defaults to `true`.
* `:kernel_size` - size of the kernel spatial dimensions. Defaults
to `1`.
* `:strides` - stride during convolution. Defaults to `1`.
* `:padding` - padding to the spatial dimensions of the input.
Defaults to `:valid`.
* `:input_dilation` - dilation to apply to input. Defaults to `1`.
* `:kernel_dilation` - dilation to apply to kernel. Defaults to `1`.
* `:channels` - channels location. One of `:first` or `:last`.
Defaults to `:first`.
"""
@doc type: :convolution
def depthwise_conv(%Axon{output_shape: parent_shape} = x, channel_multiplier, opts \\ [])
when is_integer(channel_multiplier) and channel_multiplier >= 1 do
opts =
Keyword.validate!(opts, [
:name,
:activation,
kernel_initializer: :glorot_uniform,
bias_initializer: :zeros,
use_bias: true,
kernel_size: 1,
strides: 1,
padding: :valid,
input_dilation: 1,
kernel_dilation: 1,
channels: :first
])
kernel_size = opts[:kernel_size]
strides = opts[:strides]
padding = opts[:padding]
input_dilation = opts[:input_dilation]
kernel_dilation = opts[:kernel_dilation]
channels = opts[:channels]
inner_rank = Nx.rank(parent_shape) - 2
kernel_size = tuple_or_duplicate(:kernel_size, kernel_size, inner_rank)
strides = list_or_duplicate(:strides, strides, inner_rank)
input_dilation = list_or_duplicate(:input_dilation, input_dilation, inner_rank)
kernel_dilation = list_or_duplicate(:kernel_dilation, kernel_dilation, inner_rank)
kernel_shape =
Axon.Shape.depthwise_conv_kernel(parent_shape, channel_multiplier, kernel_size, channels)
bias_shape =
Axon.Shape.depthwise_conv_bias(parent_shape, channel_multiplier, kernel_size, channels)
output_shape =
Axon.Shape.depthwise_conv(
parent_shape,
kernel_shape,
strides,
padding,
input_dilation,
kernel_dilation,
channels
)
kernel = param("kernel", kernel_shape, initializer: opts[:kernel_initializer])
{inputs, op} =
if opts[:use_bias] do
bias = param("bias", bias_shape, initializer: opts[:bias_initializer])
{[x, kernel, bias], :depthwise_conv}
else
{[x, kernel], &Axon.Layers.depthwise_conv(&1, &2, 0, &3)}
end
node =
layer(op, inputs,
name: opts[:name],
strides: strides,
padding: padding,
input_dilation: input_dilation,
kernel_dilation: kernel_dilation,
channels: channels,
shape: output_shape,
op_name: :depthwise_conv
)
if activation = opts[:activation] do
activation(node, activation)
else
node
end
end
@doc """
Adds a depthwise separable 2-dimensional convolution to the
network.
Depthwise separable convolutions factor the kernel into one
kernel per spatial dimension of the input and apply a depthwise
convolution over the input with each kernel in turn.
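The factorization works because a 2-D kernel that is the outer product of a height kernel and a width kernel can be applied as two 1-D passes. The hypothetical `SeparableSketch` below checks this on a single-channel image without biases. Filtering along the height axis and then along the width axis gives the same result as one 2-D pass with the outer-product kernel. It illustrates the idea only and is not `Axon.Layers.separable_conv2d/6`.

```elixir
# Hypothetical sketch, not Axon.Layers.separable_conv2d/6.
defmodule SeparableSketch do
  # Filter every column with k_height, then every row with k_width.
  def separable2d(image, k_height, k_width) do
    image
    |> transpose()
    |> Enum.map(&correlate(&1, k_height))
    |> transpose()
    |> Enum.map(&correlate(&1, k_width))
  end

  # Full 2-D valid cross-correlation, for comparison.
  def correlate2d(image, kernel) do
    kh = length(kernel)
    kw = length(hd(kernel))

    for i <- 0..(length(image) - kh) do
      for j <- 0..(length(hd(image)) - kw) do
        window = image |> Enum.slice(i, kh) |> Enum.map(&Enum.slice(&1, j, kw))
        dot(List.flatten(window), List.flatten(kernel))
      end
    end
  end

  def outer(a, b), do: for(x <- a, do: for(y <- b, do: x * y))

  defp correlate(signal, kernel) do
    signal
    |> Enum.chunk_every(length(kernel), 1, :discard)
    |> Enum.map(&dot(&1, kernel))
  end

  defp dot(a, b), do: a |> Enum.zip(b) |> Enum.map(fn {x, y} -> x * y end) |> Enum.sum()

  defp transpose(rows), do: rows |> Enum.zip() |> Enum.map(&Tuple.to_list/1)
end

image = [[1.0, 2.0, 0.0], [0.0, 1.0, 3.0], [2.0, 0.0, 1.0]]
k_height = [1.0, -1.0]
k_width = [2.0, 1.0]

SeparableSketch.separable2d(image, k_height, k_width)
# => [[3.0, -1.0], [-3.0, 4.0]]
SeparableSketch.correlate2d(image, SeparableSketch.outer(k_height, k_width))
# => [[3.0, -1.0], [-3.0, 4.0]]
```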
Compiles to `Axon.Layers.separable_conv2d/6`.
## Options
* `:name` - layer name.
* `:kernel_initializer` - initializer for `kernel` weights.
Defaults to `:glorot_uniform`.
* `:bias_initializer` - initializer for `bias` weights. Defaults
to `:zeros`.
* `:activation` - element-wise activation function.
* `:use_bias` - whether the layer should add bias to the output.
Defaults to `true`.
* `:kernel_size` - size of the kernel spatial dimensions. Defaults
to `1`.
* `:strides` - stride during convolution. Defaults to `1`.
* `:padding` - padding to the spatial dimensions of the input.
Defaults to `:valid`.
* `:input_dilation` - dilation to apply to input. Defaults to `1`.
* `:kernel_dilation` - dilation to apply to kernel. Defaults to `1`.
* `:channels` - channels location. One of `:first` or `:last`.
Defaults to `:first`.
"""
@doc type: :convolution
def separable_conv2d(%Axon{output_shape: parent_shape} = x, channel_multiplier, opts \\ [])
when is_integer(channel_multiplier) and channel_multiplier >= 1 do
opts =
Keyword.validate!(opts, [
:name,
:activation,
kernel_initializer: :glorot_uniform,
bias_initializer: :zeros,
use_bias: true,
kernel_size: 1,
strides: 1,
padding: :valid,
input_dilation: 1,
kernel_dilation: 1,
channels: :first
])
kernel_size = opts[:kernel_size]
strides = opts[:strides]
padding = opts[:padding]
input_dilation = opts[:input_dilation]
kernel_dilation = opts[:kernel_dilation]
channels = opts[:channels]
inner_rank = Nx.rank(parent_shape) - 2
kernel_size = tuple_or_duplicate(:kernel_size, kernel_size, inner_rank)
strides = list_or_duplicate(:strides, strides, inner_rank)
input_dilation = list_or_duplicate(:input_dilation, input_dilation, inner_rank)
kernel_dilation = list_or_duplicate(:kernel_dilation, kernel_dilation, inner_rank)
k1_shape =
Axon.Shape.separable_conv2d_kernel(
parent_shape,
channel_multiplier,
kernel_size,
1,
channels
)
k2_shape =
Axon.Shape.separable_conv2d_kernel(
parent_shape,
channel_multiplier,
kernel_size,
2,
channels
)
b1_shape =
Axon.Shape.separable_conv2d_bias(parent_shape, channel_multiplier, kernel_size, channels)
b2_shape =
Axon.Shape.separable_conv2d_bias(parent_shape, channel_multiplier, kernel_size, channels)
output_shape =
Axon.Shape.depthwise_conv(
parent_shape,
Axon.Shape.depthwise_conv_kernel(parent_shape, channel_multiplier, kernel_size, channels),
strides,
padding,
input_dilation,
kernel_dilation,
channels
)
kernel_initializer = opts[:kernel_initializer]
k1 = param("kernel_1", k1_shape, initializer: kernel_initializer)
k2 = param("kernel_2", k2_shape, initializer: kernel_initializer)
{inputs, op} =
if opts[:use_bias] do
bias_initializer = opts[:bias_initializer]
b1 = param("bias_1", b1_shape, initializer: bias_initializer)
b2 = param("bias_2", b2_shape, initializer: bias_initializer)
{[x, k1, b1, k2, b2], :separable_conv2d}
else
{[x, k1, k2], &Axon.Layers.separable_conv2d(&1, &2, 0, &3, 0, &4)}
end
node =
layer(
op,
inputs,
name: opts[:name],
strides: strides,
padding: padding,
input_dilation: input_dilation,
kernel_dilation: kernel_dilation,
channels: channels,
shape: output_shape,
op_name: :separable_conv2d
)
if activation = opts[:activation] do
activation(node, activation)
else
node
end
end
@doc """
Adds a depthwise separable 3-dimensional convolution to the
network.
Depthwise separable convolutions factor the kernel into one
kernel per spatial dimension of the input and apply a depthwise
convolution over the input with each kernel in turn.
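Part of the appeal of the factorization is the smaller number of weights. Assume each of the three kernels spans `k` positions along one spatial axis and size 1 along the other two. A separable 3-D filter then uses `3 * k` weights per channel path, where a full `k x k x k` depthwise kernel uses `k * k * k`, biases ignored. The arithmetic, with hypothetical helper functions:

```elixir
# Hypothetical helpers; weights per channel path, biases ignored.
full_weights = fn k -> k * k * k end
separable_weights = fn k -> 3 * k end

for k <- [3, 5], do: {k, full_weights.(k), separable_weights.(k)}
# => [{3, 27, 9}, {5, 125, 15}]
```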
Compiles to `Axon.Layers.separable_conv3d/8`.
## Options
* `:name` - layer name.
* `:kernel_initializer` - initializer for `kernel` weights.
Defaults to `:glorot_uniform`.
* `:bias_initializer` - initializer for `bias` weights. Defaults
to `:zeros`.
* `:activation` - element-wise activation function.
* `:use_bias` - whether the layer should add bias to the output.
Defaults to `true`.
* `:kernel_size` - size of the kernel spatial dimensions. Defaults
to `1`.
* `:strides` - stride during convolution. Defaults to `1`.
* `:padding` - padding to the spatial dimensions of the input.
Defaults to `:valid`.
* `:input_dilation` - dilation to apply to input. Defaults to `1`.
* `:kernel_dilation` - dilation to apply to kernel. Defaults to `1`.
* `:channels` - channels location. One of `:first` or `:last`.
Defaults to `:first`.
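As an illustration of the factorization described above, the sketch
below derives the spatial extent of each of the three kernels. It
assumes that kernel `i` keeps the size of spatial dimension `i` and
uses `1` for every other spatial dimension; batch and channel
dimensions are left out. `SeparableKernelSketch` is a hypothetical
module, not part of Axon.

```elixir
defmodule SeparableKernelSketch do
  # Hypothetical helper, not part of Axon: kernel i keeps the extent of
  # spatial dimension i and uses 1 for every other spatial dimension.
  def spatial_kernel_shapes(kernel_size) when is_tuple(kernel_size) do
    rank = tuple_size(kernel_size)

    for i <- 0..(rank - 1) do
      1
      |> List.duplicate(rank)
      |> List.replace_at(i, elem(kernel_size, i))
      |> List.to_tuple()
    end
  end
end

SeparableKernelSketch.spatial_kernel_shapes({3, 5, 7})
# => [{3, 1, 1}, {1, 5, 1}, {1, 1, 7}]
```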
"""
@doc type: :convolution
def separable_conv3d(%Axon{output_shape: parent_shape} = x, channel_multiplier, opts \\ [])
when is_integer(channel_multiplier) and channel_multiplier >= 1 do
opts =
Keyword.validate!(opts, [
:name,
:activation,
kernel_initializer: :glorot_uniform,
bias_initializer: :zeros,
use_bias: true,
kernel_size: 1,
strides: 1,
padding: :valid,
input_dilation: 1,
kernel_dilation: 1,
channels: :first
])
kernel_size = opts[:kernel_size]
strides = opts[:strides]
padding = opts[:padding]
input_dilation = opts[:input_dilation]
kernel_dilation = opts[:kernel_dilation]
channels = opts[:channels]
inner_rank = Nx.rank(parent_shape) - 2
kernel_size = tuple_or_duplicate(:kernel_size, kernel_size, inner_rank)
strides = list_or_duplicate(:strides, strides, inner_rank)
input_dilation = list_or_duplicate(:input_dilation, input_dilation, inner_rank)
kernel_dilation = list_or_duplicate(:kernel_dilation, kernel_dilation, inner_rank)
k1_shape =
Axon.Shape.separable_conv3d_kernel(
parent_shape,
channel_multiplier,
kernel_size,
1,
channels
)
k2_shape =
Axon.Shape.separable_conv3d_kernel(
parent_shape,
channel_multiplier,
kernel_size,
2,
channels
)
k3_shape =
Axon.Shape.separable_conv3d_kernel(
parent_shape,
channel_multiplier,
kernel_size,
3,
channels
)
b1_shape =
Axon.Shape.separable_conv3d_bias(parent_shape, channel_multiplier, kernel_size, channels)
b2_shape =
Axon.Shape.separable_conv3d_bias(parent_shape, channel_multiplier, kernel_size, channels)
b3_shape =
Axon.Shape.separable_conv3d_bias(parent_shape, channel_multiplier, kernel_size, channels)
output_shape =
Axon.Shape.depthwise_conv(
parent_shape,
Axon.Shape.depthwise_conv_kernel(parent_shape, channel_multiplier, kernel_size, channels),
strides,
padding,
input_dilation,
kernel_dilation,
channels
)
kernel_initializer = opts[:kernel_initializer]
k1 = param("kernel_1", k1_shape, initializer: kernel_initializer)
k2 = param("kernel_2", k2_shape, initializer: kernel_initializer)
k3 = param("kernel_3", k3_shape, initializer: kernel_initializer)
{inputs, op} =
if opts[:use_bias] do
bias_initializer = opts[:bias_initializer]
b1 = param("bias_1", b1_shape, initializer: bias_initializer)
b2 = param("bias_2", b2_shape, initializer: bias_initializer)
b3 = param("bias_3", b3_shape, initializer: bias_initializer)
{[x, k1, b1, k2, b2, k3, b3], :separable_conv3d}
else
{[x, k1, k2, k3], &Axon.Layers.separable_conv3d(&1, &2, 0, &3, 0, &4, 0, &5)}
end
node =
layer(
op,
inputs,
name: opts[:name],
strides: strides,
padding: padding,
input_dilation: input_dilation,
kernel_dilation: kernel_dilation,
channels: channels,
shape: output_shape,
op_name: :separable_conv3d
)
if activation = opts[:activation] do
activation(node, activation)
else
node
end
end
@activation_layers [
{:celu, "Continuously-differentiable exponential linear unit", "a"},
{:elu, "Exponential linear unit", "an"},
{:exp, "Exponential", "an"},
{:gelu, "Gaussian error linear unit", "a"},
{:hard_sigmoid, "Hard sigmoid", "a"},
{:hard_silu, "Hard sigmoid weighted linear unit", "a"},
{:hard_tanh, "Hard hyperbolic tangent", "a"},
{:leaky_relu, "Leaky rectified linear unit", "a"},
{:linear, "Linear", "a"},
{:log_sigmoid, "Log-sigmoid", "a"},
{:log_softmax, "Log-softmax", "a"},
{:mish, "Mish", "a"},
{:relu, "Rectified linear unit", "a"},
{:relu6, "Rectified linear unit 6", "a"},
{:sigmoid, "Sigmoid", "a"},
{:silu, "Sigmoid weighted linear unit", "a"},
{:selu, "Scaled exponential linear unit", "a"},
{:softmax, "Softmax", "a"},
{:softplus, "Softplus", "a"},
{:softsign, "Softsign", "a"},
{:tanh, "Hyperbolic tangent", "a"}
]
@doc """
Adds an activation layer to the network.
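Activation layers are element-wise functions typically called
after the output of another layer. `activation` may be an atom
naming a function in `Axon.Activations` (for example `:relu`) or
an arity-1 function applied to the layer input.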
## Options
* `:name` - layer name.
"""
@doc type: :activation
def activation(x, activation, opts \\ [])
def activation(%Axon{output_shape: shape} = x, activation, opts) when is_atom(activation) do
opts = [shape: shape, op_name: activation] ++ opts
layer(activation, [x], opts)
end
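def activation(%Axon{output_shape: shape} = x, activation, opts)
when is_function(activation, 1) do
# Layer functions are called with the layer options as a trailing
# argument, so wrap the arity-1 activation as `nx/3` does.
fun = fn input, _opts -> activation.(input) end
layer(fun, [x], [shape: shape] ++ opts)
end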
## Activation
for {activation, name, a_or_an} <- @activation_layers do
@doc """
Adds #{a_or_an} #{name} activation layer to the network.
See `Axon.Activations.#{Atom.to_string(activation)}/1` for more details.
## Options
* `:name` - layer name.
"""
@doc type: :activation
def unquote(activation)(%Axon{} = x, opts \\ []) do
activation(x, unquote(activation), opts)
end
end
## Dropout
@dropout_layers [
{:dropout, "Dropout", "a"},
{:feature_alpha_dropout, "Feature alpha dropout", "a"},
{:spatial_dropout, "Spatial dropout", "a"},
{:alpha_dropout, "Alpha dropout", "an"}
]
for {dropout, name, a_or_an} <- @dropout_layers do
@doc """
Adds #{a_or_an} #{name} layer to the network.
See `Axon.Layers.#{Atom.to_string(dropout)}/2` for more details.
## Options
* `:name` - layer name.
* `:rate` - dropout rate. Defaults to `0.5`.
"""
@doc type: :dropout
def unquote(dropout)(%Axon{} = x, opts \\ []) do
dropout(x, unquote(dropout), opts)
end
end
defp dropout(%Axon{output_shape: parent_shape} = x, dropout, opts) do
opts = Keyword.validate!(opts, [:name, rate: 0.5])
layer(dropout, [x],
name: opts[:name],
rate: opts[:rate],
shape: parent_shape,
op_name: dropout
)
end
## Pooling
@pooling_layers [
{:max_pool, "Max pool", "a"},
{:avg_pool, "Average pool", "an"},
{:lp_pool, "Power average pool", "a"}
]
for {pool, name, a_or_an} <- @pooling_layers do
@doc """
Adds #{a_or_an} #{name} layer to the network.
See `Axon.Layers.#{Atom.to_string(pool)}/2` for more details.
## Options
* `:name` - layer name.
* `:kernel_size` - size of the kernel spatial dimensions. Defaults
to `1`.
* `:strides` - stride of the pooling window. Defaults to the kernel size.
* `:padding` - padding to the spatial dimensions of the input.
Defaults to `:valid`.
* `:dilations` - window dilations. Defaults to `1`.
* `:norm` - norm used by `lp_pool`; ignored by the other pooling
layers. Defaults to `2`.
* `:channels` - channels location. One of `:first` or `:last`.
Defaults to `:first`.
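For `:valid` padding, the spatial output size follows the usual window
arithmetic. The sketch below illustrates it for one spatial dimension,
including the fallback of `strides` to the kernel size; it does not
handle `:same` or explicit padding, and `PoolShapeSketch` is a
hypothetical module, not part of Axon.

```elixir
defmodule PoolShapeSketch do
  # Hypothetical helper, not part of Axon: output size of one spatial
  # dimension under :valid padding. A nil stride falls back to the
  # kernel size, mirroring this layer's default.
  def valid_output_size(input_size, kernel_size, stride, dilation) do
    stride = stride || kernel_size
    # A dilated window spans (kernel_size - 1) * dilation + 1 elements.
    window = (kernel_size - 1) * dilation + 1
    div(input_size - window, stride) + 1
  end
end

# 28 inputs, 2-wide window, default stride (2), no dilation
PoolShapeSketch.valid_output_size(28, 2, nil, 1)
# => 14

# 28 inputs, 3-wide window, stride 1, dilation 2 (window spans 5)
PoolShapeSketch.valid_output_size(28, 3, 1, 2)
# => 24
```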
"""
@doc type: :pooling
def unquote(pool)(%Axon{} = x, opts \\ []) do
pool(x, unquote(pool), opts)
end
end
defp pool(%Axon{output_shape: parent_shape} = x, pool, opts) do
opts =
Keyword.validate!(opts, [
:name,
:strides,
kernel_size: 1,
padding: :valid,
channels: :first,
dilations: 1,
norm: 2
])
kernel_size = opts[:kernel_size]
strides = opts[:strides]
padding = opts[:padding]
channels = opts[:channels]
dilations = opts[:dilations]
inner_rank = Nx.rank(parent_shape) - 2
kernel_size = tuple_or_duplicate(:kernel_size, kernel_size, inner_rank)
strides = if strides, do: strides, else: Tuple.to_list(kernel_size)
strides = list_or_duplicate(:strides, strides, inner_rank)
dilations = list_or_duplicate(:dilations, dilations, inner_rank)
output_shape =
Axon.Shape.pool(parent_shape, kernel_size, strides, padding, dilations, channels)
name = opts[:name]
opts =
if pool == :lp_pool do
norm = opts[:norm]
[
name: name,
kernel_size: kernel_size,
strides: strides,
padding: padding,
channels: channels,
window_dilations: dilations,
norm: norm,
shape: output_shape,
op_name: pool
]
else
[
name: name,
kernel_size: kernel_size,
strides: strides,
padding: padding,
channels: channels,
window_dilations: dilations,
shape: output_shape,
op_name: pool
]
end
layer(pool, [x], opts)
end
## Adaptive Pooling
@adaptive_pooling_layers [
{:adaptive_avg_pool, "Adaptive average pool", "an"},
{:adaptive_max_pool, "Adaptive max pool", "an"},
{:adaptive_lp_pool, "Adaptive power average pool", "an"}
]
for {pool, name, a_or_an} <- @adaptive_pooling_layers do
@doc """
Adds #{a_or_an} #{name} layer to the network.
See `Axon.Layers.#{Atom.to_string(pool)}/2` for more details.
## Options
* `:name` - layer name.
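* `:output_size` - spatial output size, either an integer or a
tuple with one entry per spatial dimension. Defaults to the
spatial shape of the input.
* `:norm` - norm used by `adaptive_lp_pool`; ignored by the other
adaptive pooling layers. Defaults to `2`.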
* `:channels` - channel configuration. One of `:first` or `:last`.
Defaults to `:first`.
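When `:output_size` is omitted, this layer falls back to the spatial
shape of its input. The sketch below reproduces that fallback in plain
Elixir by dropping the batch axis and the channel axis;
`AdaptivePoolSketch` is a hypothetical module, not part of Axon.

```elixir
defmodule AdaptivePoolSketch do
  # Hypothetical helper, not part of Axon. Mirrors the default used when
  # :output_size is not given: drop the batch axis and the channel axis
  # and keep the remaining spatial dimensions.
  def default_output_size(input_shape, channels) do
    channel_idx =
      case channels do
        :first -> 1
        :last -> tuple_size(input_shape) - 1
      end

    input_shape
    |> Tuple.delete_at(channel_idx)
    |> Tuple.delete_at(0)
  end
end

AdaptivePoolSketch.default_output_size({nil, 3, 32, 32}, :first)
# => {32, 32}
```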
"""
@doc type: :pooling
def unquote(pool)(%Axon{} = x, opts \\ []) do
adaptive_pool(x, unquote(pool), opts)
end
end
defp adaptive_pool(%Axon{output_shape: parent_shape} = x, pool, opts) do
opts = Keyword.validate!(opts, [:name, :output_size, channels: :first, norm: 2])
channels = opts[:channels]
idx =
if channels == :first do
1
else
Nx.rank(parent_shape) - 1
end
output_size =
if size = opts[:output_size] do
size
else
parent_shape
|> Tuple.delete_at(0)
|> Tuple.delete_at(idx - 1)
end
inner_rank = Nx.rank(parent_shape) - 2
output_size = tuple_or_duplicate(:output_size, output_size, inner_rank)
output_shape = Axon.Shape.adaptive_pool(parent_shape, output_size, channels)
name = opts[:name]
opts =
if pool == :adaptive_lp_pool do
norm = opts[:norm]
[
name: name,
output_size: output_size,
norm: norm,
channels: channels,
shape: output_shape,
op_name: pool
]
else
[
name: name,
output_size: output_size,
channels: channels,
shape: output_shape,
op_name: pool
]
end
layer(pool, [x], opts)
end
## Global Pooling
@global_pooling_layers [
{:global_avg_pool, "Global average pool"},
{:global_max_pool, "Global max pool"},
{:global_lp_pool, "Global LP pool"}
]
for {pool, name} <- @global_pooling_layers do
@doc """
Adds a #{name} layer to the network.
See `Axon.Layers.#{Atom.to_string(pool)}/2` for more details.
Typically used to connect feature extractors such as those in convolutional
neural networks to fully-connected models by reducing inputs along spatial
dimensions to only feature and batch dimensions.
## Options
* `:name` - layer name.
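* `:keep_axes` - whether to keep the reduced spatial axes with a
dimension size of 1. Defaults to `false`.
* `:norm` - norm used by `global_lp_pool`; ignored by the other
global pooling layers. Defaults to `2`.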
* `:channels` - channel configuration. One of `:first` or `:last`.
Defaults to `:first`.
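The sketch below shows the output shapes one would expect, assuming
every spatial axis is reduced away (or kept with size `1` when
`:keep_axes` is `true`) while the batch and channel dimensions are
preserved. `GlobalPoolShapeSketch` is a hypothetical module, not part
of Axon.

```elixir
defmodule GlobalPoolShapeSketch do
  # Hypothetical helper, not part of Axon: every spatial axis is
  # reduced; with keep_axes it stays in the shape with size 1.
  def output_shape(input_shape, keep_axes, channels) do
    rank = tuple_size(input_shape)
    batch = elem(input_shape, 0)
    channel_idx = if channels == :first, do: 1, else: rank - 1
    num_channels = elem(input_shape, channel_idx)
    ones = List.duplicate(1, rank - 2)

    cond do
      not keep_axes -> {batch, num_channels}
      channels == :first -> List.to_tuple([batch, num_channels | ones])
      true -> List.to_tuple([batch | ones] ++ [num_channels])
    end
  end
end

GlobalPoolShapeSketch.output_shape({nil, 64, 7, 7}, false, :first)
# => {nil, 64}
GlobalPoolShapeSketch.output_shape({nil, 7, 7, 64}, true, :last)
# => {nil, 1, 1, 64}
```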
"""
@doc type: :pooling
def unquote(pool)(%Axon{} = x, opts \\ []) do
global_pool(x, unquote(pool), opts)
end
end
defp global_pool(%Axon{output_shape: parent_shape} = x, pool, opts) do
opts = Keyword.validate!(opts, [:name, keep_axes: false, channels: :first, norm: 2])
keep_axes = opts[:keep_axes]
name = opts[:name]
channels = opts[:channels]
output_shape = Axon.Shape.global_pool(parent_shape, keep_axes, channels)
opts =
if pool == :global_lp_pool do
norm = opts[:norm]
[
name: name,
channels: channels,
keep_axes: keep_axes,
norm: norm,
shape: output_shape,
op_name: pool
]
else
[name: name, channels: channels, keep_axes: keep_axes, shape: output_shape, op_name: pool]
end
layer(pool, [x], opts)
end
## Normalization
@normalization_with_stats_layers [
{:batch_norm, "Batch normalization", "a"},
{:instance_norm, "Instance normalization", "an"}
]
for {norm, name, a_or_an} <- @normalization_with_stats_layers do
@doc """
Adds #{a_or_an} #{name} layer to the network.
See `Axon.Layers.#{Atom.to_string(norm)}/6` for more details.
## Options
* `:name` - layer name.
* `:gamma_initializer` - gamma parameter initializer. Defaults
to `:glorot_uniform`.
* `:beta_initializer` - beta parameter initializer. Defaults to
`:zeros`.
* `:channel_index` - input feature index used for calculating
mean and variance. Defaults to `1`.
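* `:epsilon` - numerical stability term. Defaults to `1.0e-5`.
* `:momentum` - momentum used to update the running `mean` and
`var` statistics. Defaults to `0.1`.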
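For reference, the sketch below applies the standard normalization
formula `(x - mean) / sqrt(var + epsilon) * gamma + beta` to one
channel's values, which is how the `gamma`, `beta`, `mean` and `var`
parameters created by this layer are typically combined. It is plain
Elixir over a list for illustration only; `NormSketch` is a
hypothetical module, and the exact computation is defined by
`Axon.Layers`.

```elixir
defmodule NormSketch do
  # Hypothetical helper, not part of Axon: normalize values with the
  # given mean and variance, then scale by gamma and shift by beta.
  def normalize(values, mean, var, gamma, beta, epsilon) do
    scale = gamma / :math.sqrt(var + epsilon)
    Enum.map(values, fn v -> (v - mean) * scale + beta end)
  end
end

# Statistics of [1.0, 2.0, 3.0]: mean 2.0, (biased) variance 2/3.
NormSketch.normalize([1.0, 2.0, 3.0], 2.0, 2.0 / 3, 1.0, 0.0, 1.0e-5)
# => approximately [-1.2247, 0.0, 1.2247]
```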
"""
@doc type: :normalization
def unquote(norm)(%Axon{} = x, opts \\ []) do
norm_with_stats(x, unquote(norm), opts)
end
end
defp norm_with_stats(%Axon{output_shape: shape} = x, norm, opts) do
opts =
Keyword.validate!(opts, [
:name,
gamma_initializer: :glorot_uniform,
beta_initializer: :zeros,
channel_index: 1,
epsilon: 1.0e-5,
momentum: 0.1
])
channel_index = opts[:channel_index]
gamma_shape = Axon.Shape.norm_param(shape, channel_index)
beta_shape = Axon.Shape.norm_param(shape, channel_index)
mean_shape = Axon.Shape.norm_param(shape, channel_index)
var_shape = Axon.Shape.norm_param(shape, channel_index)
gamma = param("gamma", gamma_shape, initializer: opts[:gamma_initializer])
beta = param("beta", beta_shape, initializer: opts[:beta_initializer])
mean = param("mean", mean_shape, initializer: :zeros)
var = param("var", var_shape, initializer: :ones)
layer(
norm,
[x, gamma, beta, mean, var],
name: opts[:name],
epsilon: opts[:epsilon],
channel_index: channel_index,
momentum: opts[:momentum],
shape: shape,
op_name: norm
)
end
@normalization_layers [
{:layer_norm, "Layer normalization", "a"}
]
for {norm, name, a_or_an} <- @normalization_layers do
@doc """
Adds #{a_or_an} #{name} layer to the network.
See `Axon.Layers.#{Atom.to_string(norm)}/4` for more details.
## Options
* `:name` - layer name.
* `:gamma_initializer` - gamma parameter initializer. Defaults
to `:glorot_uniform`.
* `:beta_initializer` - beta parameter initializer. Defaults to
`:zeros`.
* `:channel_index` - input feature index used for calculating
mean and variance. Defaults to `1`.
* `:epsilon` - numerical stability term. Defaults to `1.0e-5`.
"""
@doc type: :normalization
def unquote(norm)(%Axon{} = x, opts \\ []) do
norm(x, unquote(norm), opts)
end
end
defp norm(%Axon{output_shape: shape} = x, norm, opts) do
opts =
Keyword.validate!(opts, [
:name,
gamma_initializer: :glorot_uniform,
beta_initializer: :zeros,
channel_index: 1,
epsilon: 1.0e-5
])
channel_index = opts[:channel_index]
gamma_shape = Axon.Shape.norm_param(shape, channel_index)
beta_shape = Axon.Shape.norm_param(shape, channel_index)
gamma = param("gamma", gamma_shape, initializer: opts[:gamma_initializer])
beta = param("beta", beta_shape, initializer: opts[:beta_initializer])
layer(norm, [x, gamma, beta],
name: opts[:name],
epsilon: opts[:epsilon],
channel_index: channel_index,
shape: shape,
op_name: norm
)
end
@doc """
Adds a group normalization layer to the network.
See `Axon.Layers.group_norm/4` for more details. `group_size` must
be a positive integer and is forwarded to the layer as the
`:group_size` option.
## Options
* `:name` - layer name.
* `:gamma_initializer` - gamma parameter initializer. Defaults
to `:glorot_uniform`.
* `:beta_initializer` - beta parameter initializer. Defaults to
`:zeros`.
* `:channel_index` - input feature index used for calculating
mean and variance. Defaults to `1`.
* `:epsilon` - numerical stability term. Defaults to `1.0e-5`.
"""
@doc type: :normalization
def group_norm(%Axon{output_shape: shape} = x, group_size, opts \\ [])
when is_integer(group_size) and group_size >= 1 do
opts =
Keyword.validate!(opts, [
:name,
gamma_initializer: :glorot_uniform,
beta_initializer: :zeros,
channel_index: 1,
epsilon: 1.0e-5
])
channel_index = opts[:channel_index]
gamma_shape = Axon.Shape.norm_param(shape, channel_index)
beta_shape = Axon.Shape.norm_param(shape, channel_index)
gamma = param("gamma", gamma_shape, initializer: opts[:gamma_initializer])
beta = param("beta", beta_shape, initializer: opts[:beta_initializer])
layer(:group_norm, [x, gamma, beta],
name: opts[:name],
epsilon: opts[:epsilon],
channel_index: channel_index,
group_size: group_size,
shape: shape,
op_name: :group_norm
)
end
@doc """
Applies the given `Nx` expression to the input.
Nx layers are meant for quick applications of functions without
trainable parameters. For example, they are useful for applying
accessors to containers:
model = Axon.container({foo, bar})
Axon.nx(model, &elem(&1, 0))
## Options
* `:name` - layer name.
"""
@doc type: :special
def nx(input, fun, opts \\ [])
def nx(%Axon{output_shape: input_shape} = x, fun, opts) when is_function(fun, 1) do
opts = Keyword.validate!(opts, [:name])
{name, opts} = Keyword.pop(opts, :name)
fun_with_params = fn x, _opts -> fun.(x) end
output_shape = infer_shape([input_shape], fun_with_params, opts)
layer(fun_with_params, [x], name: name, shape: output_shape, op_name: :nx)
end
@doc """
Adds a flatten layer to the network.
This layer flattens all but the batch dimension of the input
into a single dimension. It is typically used to flatten the
output of a convolution for use with a dense layer.
## Options
* `:name` - layer name.
* `:ignore_batch?` - whether to leave the batch dimension out of
the flatten operation. Defaults to `true`, unless the input is
an Axon constant.
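The shape arithmetic for the default `ignore_batch?: true` case can be
sketched as follows: the (possibly `nil`) batch dimension is kept and
every other dimension is multiplied together. `FlattenShapeSketch` is
a hypothetical module, not part of Axon.

```elixir
defmodule FlattenShapeSketch do
  # Hypothetical helper, not part of Axon: keep the batch dimension and
  # collapse the remaining dimensions into their product.
  def output_shape(input_shape) do
    [batch | rest] = Tuple.to_list(input_shape)
    {batch, Enum.reduce(rest, 1, &(&1 * &2))}
  end
end

FlattenShapeSketch.output_shape({nil, 28, 28, 1})
# => {nil, 784}
```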
"""
@doc type: :shape
def flatten(%Axon{op: op, output_shape: shape} = x, opts \\ []) do
opts = Keyword.validate!(opts, [:name, ignore_batch?: op != :constant])
ignore_batch? = opts[:ignore_batch?]
output_shape = Axon.Shape.flatten(shape, ignore_batch?)
layer(:flatten, [x],
name: opts[:name],
ignore_batch?: ignore_batch?,
shape: output_shape,
op_name: :flatten
)
end
@doc """
Adds a reshape layer to the network.
This layer implements a special case of `Nx.reshape` which accounts
for possible batch dimensions in the input tensor. If the input contains
batch dimensions, the reshape operation is performed on all non-batch
dimensions of the input - preserving the original batch size.
If the input is an Axon constant, the reshape behavior matches that of
`Nx.reshape`.
## Options
* `:name` - layer name.
* `:ignore_batch?` - whether to leave the batch dimension out of the
reshape operation. Defaults to `true`, unless the input is an Axon
constant.
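A minimal sketch of the batch-preserving case, assuming that
`new_shape` describes only the non-batch dimensions and must hold the
same number of elements as the input's non-batch dimensions.
`ReshapeShapeSketch` is a hypothetical module, not part of Axon.

```elixir
defmodule ReshapeShapeSketch do
  # Hypothetical helper, not part of Axon: keep the batch dimension and
  # replace the rest with new_shape, checking that element counts match.
  def output_shape(input_shape, new_shape) do
    [batch | rest] = Tuple.to_list(input_shape)
    old_size = Enum.reduce(rest, 1, &(&1 * &2))
    new_size = new_shape |> Tuple.to_list() |> Enum.reduce(1, &(&1 * &2))

    if old_size != new_size do
      raise ArgumentError, "cannot reshape: element counts differ"
    end

    List.to_tuple([batch | Tuple.to_list(new_shape)])
  end
end

ReshapeShapeSketch.output_shape({nil, 784}, {28, 28, 1})
# => {nil, 28, 28, 1}
```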
"""
@doc type: :shape
def reshape(%Axon{op: op, output_shape: shape} = x, new_shape, opts \\ []) do
opts = Keyword.validate!(opts, [:name, ignore_batch?: op != :constant])
ignore_batch? = opts[:ignore_batch?]
output_shape = Axon.Shape.reshape(shape, new_shape, ignore_batch?)
layer(:reshape, [x],
name: opts[:name],
ignore_batch?: ignore_batch?,
shape: output_shape,
to: output_shape,
op_name: :reshape
)
end
@doc """
Adds a transpose layer to the network.
## Options
* `:name` - layer name.
* `:ignore_batch?` - whether to leave the batch dimension out of the
transpose operation. Defaults to `true`, unless the input is an Axon
constant.
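The sketch below illustrates one plausible reading of
`ignore_batch?: true`: it assumes `permutation` refers to the non-batch
axes only, so each index is shifted by one and the batch axis stays
first. This is an assumption for illustration; `TransposeSketch` is a
hypothetical module, not part of Axon.

```elixir
defmodule TransposeSketch do
  # Hypothetical helper, not part of Axon. Assumes the permutation
  # covers non-batch axes only, so the batch axis is pinned at 0.
  def full_axes(permutation) do
    [0 | Enum.map(permutation, &(&1 + 1))]
  end

  def output_shape(input_shape, permutation) do
    permutation
    |> full_axes()
    |> Enum.map(&elem(input_shape, &1))
    |> List.to_tuple()
  end
end

TransposeSketch.output_shape({nil, 10, 32}, [1, 0])
# => {nil, 32, 10}
```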
"""
@doc type: :shape
def transpose(%Axon{op: op, output_shape: shape} = x, permutation, opts \\ []) do
opts = Keyword.validate!(opts, [:name, ignore_batch?: op != :constant])
ignore_batch? = opts[:ignore_batch?]
output_shape = Axon.Shape.transpose(shape, permutation, ignore_batch?)
layer(:transpose, [x],
name: opts[:name],
axes: permutation,
ignore_batch?: ignore_batch?,
shape: output_shape,
op_name: :transpose
)
end
@doc """
Adds a pad layer to the network.
This layer pads the spatial dimensions of the input. `config`
is a list of padding tuples, one for each spatial dimension,
and `value` is the padding value, which defaults to `0.0`.
## Options
* `:name` - layer name.
* `:channels` - channel configuration. One of `:first` or
`:last`. Defaults to `:first`.
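A rough sketch of the resulting shape, assuming a channels-first input
and that each padding tuple has the form `{low, high}` (elements added
before and after the dimension). `PadShapeSketch` is a hypothetical
module, not part of Axon.

```elixir
defmodule PadShapeSketch do
  # Hypothetical helper, not part of Axon. Assumes a channels-first
  # input and a config of {low, high} tuples, one per spatial dimension.
  def output_shape(input_shape, config) do
    [batch, channels | spatial] = Tuple.to_list(input_shape)

    padded =
      spatial
      |> Enum.zip(config)
      |> Enum.map(fn {size, {low, high}} -> size + low + high end)

    List.to_tuple([batch, channels | padded])
  end
end

PadShapeSketch.output_shape({nil, 3, 28, 28}, [{1, 1}, {2, 0}])
# => {nil, 3, 30, 30}
```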
"""
@doc type: :shape
def pad(%Axon{output_shape: shape} = x, config, value \\ 0.0, opts \\ [])
when is_list(config) and is_number(value) do
opts = Keyword.validate!(opts, [:name, channels: :first])
channels = opts[:channels]
output_shape = Axon.Shape.pad(shape, config)
layer(:pad, [x],
name: opts[:name],
padding_config: config,
value: value,
channels: channels,
shape: output_shape,
op_name: :pad
)
end
@doc """
Adds a resize layer to the network.
Resizing can be used for interpolation or upsampling input
values in a neural network. For example, you can use this
layer as an upsampling layer within a GAN.
Resize shape must be a tuple representing the resized spatial
dimensions of the input tensor.
Compiles to `Axon.Layers.resize/2`.
## Options
* `:name` - layer name.
* `:method` - resize method. Defaults to `:nearest`.
* `:channels` - channel configuration. One of `:first` or
`:last`. Defaults to `:first`.
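Since `resize_shape` describes only the spatial dimensions, the output
shape keeps the batch and channel dimensions in place. A minimal sketch
of that arithmetic for both channel layouts; `ResizeShapeSketch` is a
hypothetical module, not part of Axon.

```elixir
defmodule ResizeShapeSketch do
  # Hypothetical helper, not part of Axon: replace the spatial
  # dimensions with resize_shape, keeping batch and channels.
  def output_shape(input_shape, resize_shape, :first) do
    [batch, channels | _spatial] = Tuple.to_list(input_shape)
    List.to_tuple([batch, channels | Tuple.to_list(resize_shape)])
  end

  def output_shape(input_shape, resize_shape, :last) do
    batch = elem(input_shape, 0)
    channels = elem(input_shape, tuple_size(input_shape) - 1)
    List.to_tuple([batch | Tuple.to_list(resize_shape)] ++ [channels])
  end
end

ResizeShapeSketch.output_shape({nil, 3, 32, 32}, {64, 64}, :first)
# => {nil, 3, 64, 64}
```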
"""
@doc type: :shape
def resize(%Axon{output_shape: shape} = x, resize_shape, opts \\ []) do
opts = Keyword.validate!(opts, [:name, method: :nearest, channels: :first])
channels = opts[:channels]
output_shape = Axon.Shape.resize(shape, resize_shape, channels)
layer(:resize, [x],
name: opts[:name],
method: opts[:method],
channels: channels,
shape: output_shape,
to: resize_shape,
op_name: :resize
)
end
@doc """
Adds a concatenate layer to the network.
This layer will concatenate inputs along the last
dimension unless specified otherwise.
## Options
* `:name` - layer name.
* `:axis` - concatenate axis. Defaults to `-1`.
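The output shape sums the input sizes along `:axis` and keeps the
remaining dimensions. The sketch below illustrates that, resolving a
negative axis from the end as with the default of `-1`; it does not
validate the other dimensions, and `ConcatShapeSketch` is a
hypothetical module, not part of Axon.

```elixir
defmodule ConcatShapeSketch do
  # Hypothetical helper, not part of Axon. Sums the sizes along axis
  # and keeps every other dimension from the first shape; it does not
  # check that the other dimensions agree.
  def output_shape([first | _] = shapes, axis) do
    axis = if axis < 0, do: tuple_size(first) + axis, else: axis
    total = shapes |> Enum.map(&elem(&1, axis)) |> Enum.sum()
    put_elem(first, axis, total)
  end
end

ConcatShapeSketch.output_shape([{nil, 32}, {nil, 16}], -1)
# => {nil, 48}
```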
"""
@doc type: :combinator
def concatenate(%Axon{output_shape: x_shape} = x, %Axon{output_shape: y_shape} = y, opts)
when is_list(opts) do
opts = Keyword.validate!(opts, [:name, axis: -1])
axis = opts[:axis]
output_shape = Axon.Shape.concatenate([x_shape, y_shape], axis)
layer(:concatenate, [container({x, y})],
name: opts[:name],
axis: axis,
shape: output_shape,
op_name: :concatenate
)
end
@doc type: :combinator
def concatenate([%Axon{} | _] = inputs, opts)
when is_list(inputs) and is_list(opts) do
opts = Keyword.validate!(opts, [:name, axis: -1])
axis = opts[:axis]
input_shapes = inputs |> Enum.map(fn %Axon{output_shape: shape} -> shape end)
output_shape = Axon.Shape.concatenate(input_shapes, axis)
layer(:concatenate, [container(List.to_tuple(inputs))],
name: opts[:name],
axis: axis,
shape: output_shape,
op_name: :concatenate
)
end
@doc false
def concatenate(%Axon{} = x, %Axon{} = y), do: concatenate(x, y, [])
@doc false
def concatenate(inputs) when is_list(inputs), do: concatenate(inputs, [])
@element_wise_layers [:add, :subtract, :multiply]
for op <- @element_wise_layers do
@doc """
Adds an element-wise #{op} layer to the network.
This layer applies an element-wise #{Atom.to_string(op)} operation
to the two input layers, which must be capable of being broadcast
together.
If one shape has a static batch size, all other shapes must have a
static batch size as well.
## Options
* `:name` - layer name.
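To make "capable of being broadcast together" concrete, the sketch
below implements Nx-style broadcasting of two static shapes: shapes are
aligned from the right, and each pair of dimensions must be equal or
contain a `1`. It does not model `nil` (dynamic) batch dimensions, and
`BroadcastSketch` is a hypothetical module, not part of Axon.

```elixir
defmodule BroadcastSketch do
  # Hypothetical helper, not part of Axon: right-aligned broadcasting
  # of two shapes with integer dimensions.
  def broadcast(lhs, rhs) do
    l = Tuple.to_list(lhs)
    r = Tuple.to_list(rhs)
    rank = max(length(l), length(r))
    # Left-pad the shorter shape with 1s so both have the same rank.
    l = List.duplicate(1, rank - length(l)) ++ l
    r = List.duplicate(1, rank - length(r)) ++ r

    l
    |> Enum.zip(r)
    |> Enum.map(fn
      {d, d} -> d
      {1, d} -> d
      {d, 1} -> d
      {a, b} -> raise ArgumentError, "cannot broadcast " <> inspect(a) <> " and " <> inspect(b)
    end)
    |> List.to_tuple()
  end
end

BroadcastSketch.broadcast({32, 1, 10}, {5, 10})
# => {32, 5, 10}
```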
"""
@doc type: :combinator
def unquote(op)(%Axon{output_shape: lhs_shape} = x, %Axon{output_shape: rhs_shape} = y, opts) do
opts = Keyword.validate!(opts, [:name])
output_shape = Axon.Shape.element_wise([lhs_shape, rhs_shape])
layer(unquote(op), [container({x, y})],
name: opts[:name],
shape: output_shape,
op_name: unquote(op)
)
end
@doc """
Adds an element-wise #{op} layer to the network.
This layer applies an element-wise #{Atom.to_string(op)} operation
to all input layers, which must be capable of being broadcast
together.
## Options
* `:name` - layer name.
"""
@doc type: :combinator
def unquote(op)(inputs, opts) when is_list(inputs) and is_list(opts) do
opts = Keyword.validate!(opts, [:name])
shapes =
Enum.map(inputs, fn
%Axon{output_shape: shape} -> shape
invalid -> raise ArgumentError, "invalid input #{inspect(invalid)}"
end)
output_shape = Axon.Shape.element_wise(shapes)
layer(unquote(op), [container(List.to_tuple(inputs))],
name: opts[:name],
shape: output_shape,
op_name: unquote(op)
)
end
@doc false
def unquote(op)(%Axon{} = x, %Axon{} = y) do
unquote(op)(x, y, [])
end
@doc false
def unquote(op)([%Axon{} | _] = inputs), do: unquote(op)(inputs, [])
end
@doc """
Adds a conditional layer which conditionally executes
`true_graph` or `false_graph` based on the condition `cond_fn`
at runtime.
`cond_fn` is an arity-1 function executed on the output of the
parent graph. It must return a boolean scalar tensor (e.g. 1 or 0).
The shapes of `true_graph` and `false_graph` must be equal.
## Options
* `:name` - layer name.
"""
@doc type: :combinator
def cond(
%Axon{} = parent,
cond_fn,
%Axon{output_shape: out_shape} = true_graph,
%Axon{output_shape: out_shape} = false_graph,
opts \\ []
)
when is_function(cond_fn, 1) do
opts = Keyword.validate!(opts, [:name])
layer(:cond, [parent, true_graph, false_graph],
name: opts[:name],
cond: cond_fn,
shape: out_shape,
op_name: :cond
)
end
@doc """
Splits the input graph along the given axis and returns a tuple
of graphs. `splits` is either an integer `n`, which produces `n`
equally sized splits, or a list of split sizes.
## Options
* `:name` - layer name. May also be a list of names, one per
split.
* `:axis` - split axis. Defaults to `-1`.
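The two forms of `splits` translate into `{offset, length}` slices
along the split axis. The sketch below mirrors that bookkeeping in
plain Elixir, assuming the axis size is divisible by `n` in the integer
case; `SplitSketch` is a hypothetical module, not part of Axon.

```elixir
defmodule SplitSketch do
  # Hypothetical helpers, not part of Axon, mirroring how this layer
  # computes {offset, length} pairs along the split axis.

  # Integer n: n slices of div(axis_size, n) elements each.
  def slices(axis_size, n) when is_integer(n) do
    size = div(axis_size, n)
    for i <- 0..(n - 1), do: {i * size, size}
  end

  # List of sizes: each slice starts where the previous one ended.
  def slices(_axis_size, sizes) when is_list(sizes) do
    {slices, _end} =
      Enum.map_reduce(sizes, 0, fn size, offset ->
        {{offset, size}, offset + size}
      end)

    slices
  end
end

SplitSketch.slices(12, 3)
# => [{0, 4}, {4, 4}, {8, 4}]
SplitSketch.slices(12, [2, 4, 6])
# => [{0, 2}, {2, 4}, {6, 6}]
```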
"""
@doc type: :combinator
def split(parent, splits, opts \\ [])
def split(%Axon{} = parent, splits, opts) when is_list(splits) do
opts = Keyword.validate!(opts, [:name, axis: -1])
axis = opts[:axis]
{_, split_layers} =
for {split, i} <- Enum.with_index(splits), reduce: {0, []} do
{num_split, split_layers} ->
name =
case opts[:name] do
names when is_list(names) ->
Enum.at(names, i)
name ->
name
end
layer =
layer(
fn x, _ -> Nx.slice_along_axis(x, num_split, split, axis: axis) end,
[parent],
name: name,
op_name: :split
)
{num_split + split, [layer | split_layers]}
end
split_layers |> Enum.reverse() |> List.to_tuple()
end
def split(%Axon{output_shape: shape} = parent, n, opts) when is_integer(n) do
opts = Keyword.validate!(opts, [:name, axis: -1])
axis = opts[:axis]
{slice_size, split_shape} = Axon.Shape.split(shape, n, axis)
splits =
for i <- 0..(n - 1) do
name =
case opts[:name] do
names when is_list(names) ->
Enum.at(names, i)
name ->
name
end
layer(
fn x, _ -> Nx.slice_along_axis(x, i * slice_size, slice_size, axis: axis) end,
[parent],
name: name,
shape: split_shape,
op_name: :split
)
end
List.to_tuple(splits)
end
@doc """
See `lstm/3`.
"""
@doc type: :recurrent
def lstm(%Axon{} = x, units) when is_integer(units) and units > 0 do
lstm(x, units, [])
end
@doc """
Adds a long short-term memory (LSTM) layer to the network
with a random initial hidden state.
See `lstm/4` for more details.
## Additional options
* `:recurrent_initializer` - initializer for hidden state.
Defaults to `:glorot_uniform`.
"""
@doc type: :recurrent
def lstm(%Axon{output_shape: shape} = x, units, opts)
when is_integer(units) and units > 0 and is_list(opts) do
{recurrent_initializer, opts} = Keyword.pop(opts, :recurrent_initializer, :glorot_uniform)
c = rnn_state(x, shape, units, :lstm, opts[:name], "c", recurrent_initializer)
h = rnn_state(x, shape, units, :lstm, opts[:name], "h", recurrent_initializer)
lstm(x, {c, h}, units, opts)
end
def lstm(%Axon{} = x, {%Axon{}, %Axon{}} = hidden_state, units)
when is_integer(units) and units > 0 do
lstm(x, hidden_state, units, [])
end
@doc """
Adds a long short-term memory (LSTM) layer to the network
with the given initial hidden state.
LSTMs apply `Axon.Recurrent.lstm_cell/7` over an entire input
sequence and return:
{{new_cell, new_hidden}, output_sequence}
You can use the output state as the hidden state of another
LSTM layer.
## Options
* `:name` - layer name.
* `:activation` - recurrent activation. Defaults to `:tanh`.
* `:gate` - recurrent gate function. Defaults to `:sigmoid`.
* `:unroll` - `:dynamic` (loop preserving) or `:static` (compiled)
unrolling of RNN. Defaults to `:dynamic`.
* `:kernel_initializer` - initializer for kernel weights. Defaults
to `:glorot_uniform`.
* `:bias_initializer` - initializer for bias weights. Defaults to
`:zeros`.
* `:use_bias` - whether the layer should add bias to the output.
Defaults to `true`.
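This layer creates four input kernels, four hidden kernels and, with
`:use_bias`, four biases (one of each per gate). The sketch below counts
the resulting parameters, assuming per-gate shapes of
`{input_features, units}` for input kernels, `{units, units}` for hidden
kernels and `{units}` for biases. `LSTMParamSketch` is a hypothetical
module, not part of Axon.

```elixir
defmodule LSTMParamSketch do
  # Hypothetical helper, not part of Axon. Assumed per-gate shapes:
  # input kernel {input_features, units}, hidden kernel {units, units},
  # bias {units}; this layer creates 4 of each.
  def param_count(input_features, units, use_bias) do
    kernels = 4 * (input_features * units + units * units)
    biases = if use_bias, do: 4 * units, else: 0
    kernels + biases
  end
end

LSTMParamSketch.param_count(32, 64, true)
# => 24832
```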
"""
@doc type: :recurrent
def lstm(
%Axon{output_shape: shape} = x,
{%Axon{output_shape: h_shape}, %Axon{output_shape: h_shape}} = hidden_state,
units,
opts
)
when is_integer(units) and units > 0 and is_list(opts) do
opts =
Keyword.validate!(opts, [
:name,
activation: :tanh,
gate: :sigmoid,
unroll: :dynamic,
use_bias: true,
kernel_initializer: :glorot_uniform,
bias_initializer: :zeros
])
activation = opts[:activation]
gate = opts[:gate]
unroll = opts[:unroll]
output_shape = Axon.Shape.rnn(shape, units, :lstm)
input_kernel_shape = Axon.Shape.rnn_input_kernel(shape, units, :lstm)
hidden_kernel_shape = Axon.Shape.rnn_hidden_kernel(shape, units, :lstm)
bias_shape = Axon.Shape.rnn_bias(shape, units, :lstm)
kernel_initializer = opts[:kernel_initializer]
# Parameters
input_kernel =
param("input_kernel", {:tuple, List.duplicate(input_kernel_shape, 4)},
initializer: kernel_initializer
)
hidden_kernel =
param("hidden_kernel", {:tuple, List.duplicate(hidden_kernel_shape, 4)},
initializer: kernel_initializer
)
hidden_state_name =
case opts[:name] do
nil ->
fn _, op_counts ->
"lstm_#{op_counts[:lstm]}_hidden_state"
end
name when is_binary(name) ->
"#{name}_hidden_state"
end
hidden_state = Axon.container(hidden_state, name: hidden_state_name)
{inputs, op} =
if opts[:use_bias] do
bias_initializer = opts[:bias_initializer]
bias =
param("bias", {:tuple, List.duplicate(bias_shape, 4)}, initializer: bias_initializer)
{[x, hidden_state, input_kernel, hidden_kernel, bias], :lstm}
else
{[x, hidden_state, input_kernel, hidden_kernel], &Axon.Layers.lstm(&1, &2, &3, &4, 0, &5)}
end
output =
layer(
op,
inputs,
name: opts[:name],
activation: activation,
gate: gate,
unroll: unroll,
shape: {{h_shape, h_shape}, output_shape},
op_name: :lstm
)
new_c_name =
case opts[:name] do
nil ->
fn _, op_counts ->
"lstm_#{op_counts[:lstm]}_c_hidden_state"
end
name when is_binary(name) ->
"#{name}_c_hidden_state"
end
new_h_name =
case opts[:name] do
nil ->
fn _, op_counts ->
"lstm_#{op_counts[:lstm]}_h_hidden_state"
end
name when is_binary(name) ->
"#{name}_h_hidden_state"
end
output_sequence_name =
case opts[:name] do
nil ->
fn _, op_counts ->
"lstm_#{op_counts[:lstm]}_output_sequence"
end
name when is_binary(name) ->
"#{name}_output_sequence"
end
new_c =
layer(fn x, _ -> elem(elem(x, 0), 0) end, [output],
name: new_c_name,
shape: h_shape,
op_name: :elem
)
new_h =
layer(fn x, _ -> elem(elem(x, 0), 1) end, [output],
name: new_h_name,
shape: h_shape,
op_name: :elem
)
output_sequence =
layer(fn x, _ -> elem(x, 1) end, [output],
name: output_sequence_name,
shape: output_shape,
op_name: :elem
)
{{new_c, new_h}, output_sequence}
end
@doc """
See `gru/3`.
"""
@doc type: :recurrent
def gru(%Axon{} = x, units) when is_integer(units) and units > 0 do
gru(x, units, [])
end
@doc """
Adds a gated recurrent unit (GRU) layer to the network with
a random initial hidden state.
See `gru/4` for more details.
## Additional options
* `:recurrent_initializer` - initializer for hidden state.
Defaults to `:glorot_uniform`.
"""
@doc type: :recurrent
def gru(%Axon{output_shape: shape} = x, units, opts)
when is_integer(units) and units > 0 and is_list(opts) do
{recurrent_initializer, opts} = Keyword.pop(opts, :recurrent_initializer, :glorot_uniform)
h = rnn_state(x, shape, units, :gru, opts[:name], "h", recurrent_initializer)
gru(x, {h}, units, opts)
end
def gru(%Axon{} = x, {%Axon{}} = hidden_state, units) when is_integer(units) and units > 0 do
gru(x, hidden_state, units, [])
end
@doc """
Adds a gated recurrent unit (GRU) layer to the network with
the given initial hidden state.
GRUs apply `Axon.Recurrent.gru_cell/7` over an entire input
sequence and return:
{{new_hidden}, output_sequence}
You can use the output state as the hidden state of another
GRU layer.
## Options
* `:name` - layer name.
* `:activation` - recurrent activation. Defaults to `:tanh`.
* `:gate` - recurrent gate function. Defaults to `:sigmoid`.
* `:unroll` - `:dynamic` (loop preserving) or `:static` (compiled)
unrolling of RNN. Defaults to `:dynamic`.
* `:kernel_initializer` - initializer for kernel weights. Defaults
to `:glorot_uniform`.
* `:bias_initializer` - initializer for bias weights. Defaults to
`:zeros`.
* `:use_bias` - whether the layer should add bias to the output.
Defaults to `true`.
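Unlike the LSTM layer, this layer creates three input kernels, three
hidden kernels and, with `:use_bias`, four biases. The sketch below
counts the resulting parameters, assuming shapes of
`{input_features, units}` for input kernels, `{units, units}` for hidden
kernels and `{units}` for biases. `GRUParamSketch` is a hypothetical
module, not part of Axon.

```elixir
defmodule GRUParamSketch do
  # Hypothetical helper, not part of Axon. Assumed shapes: input kernel
  # {input_features, units}, hidden kernel {units, units}, bias {units};
  # this layer creates 3 input kernels, 3 hidden kernels and 4 biases.
  def param_count(input_features, units, use_bias) do
    kernels = 3 * (input_features * units + units * units)
    biases = if use_bias, do: 4 * units, else: 0
    kernels + biases
  end
end

GRUParamSketch.param_count(32, 64, true)
# => 18688
```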
"""
@doc type: :recurrent
def gru(
%Axon{output_shape: shape} = x,
{%Axon{output_shape: h_shape}} = hidden_state,
units,
opts
)
when is_integer(units) and units > 0 and is_list(opts) do
opts =
Keyword.validate!(opts, [
:name,
activation: :tanh,
gate: :sigmoid,
unroll: :dynamic,
use_bias: true,
kernel_initializer: :glorot_uniform,
bias_initializer: :zeros
])
activation = opts[:activation]
gate = opts[:gate]
unroll = opts[:unroll]
output_shape = Axon.Shape.rnn(shape, units, :gru)
input_kernel_shape = Axon.Shape.rnn_input_kernel(shape, units, :gru)
hidden_kernel_shape = Axon.Shape.rnn_hidden_kernel(shape, units, :gru)
bias_shape = Axon.Shape.rnn_bias(shape, units, :gru)
kernel_initializer = opts[:kernel_initializer]
input_kernel =
param("input_kernel", {:tuple, List.duplicate(input_kernel_shape, 3)},
initializer: kernel_initializer
)
hidden_kernel =
param("hidden_kernel", {:tuple, List.duplicate(hidden_kernel_shape, 3)},
initializer: kernel_initializer
)
hidden_state_name =
case opts[:name] do
nil ->
fn _, op_counts ->
"gru_#{op_counts[:gru]}_hidden_state"
end
name when is_binary(name) ->
"#{name}_hidden_state"
end
hidden_state = Axon.container(hidden_state, name: hidden_state_name)
{inputs, op} =
if opts[:use_bias] do
bias_initializer = opts[:bias_initializer]
bias =
param("bias", {:tuple, List.duplicate(bias_shape, 4)}, initializer: bias_initializer)
{[x, hidden_state, input_kernel, hidden_kernel, bias], :gru}
else
{[x, hidden_state, input_kernel, hidden_kernel], &Axon.Layers.gru(&1, &2, &3, &4, 0, &5)}
end
output =
layer(
op,
inputs,
name: opts[:name],
activation: activation,
gate: gate,
unroll: unroll,
shape: {{h_shape}, output_shape},
op_name: :gru
)
new_h_name =
case opts[:name] do
nil ->
fn _, op_counts ->
"gru_#{op_counts[:gru]}_h_hidden_state"
end
name when is_binary(name) ->
"#{name}_h_hidden_state"
end
output_sequence_name =
case opts[:name] do
nil ->
fn _, op_counts ->
"gru_#{op_counts[:gru]}_output_sequence"
end
name when is_binary(name) ->
"#{name}_output_sequence"
end
new_h =
layer(fn x, _ -> elem(elem(x, 0), 0) end, [output],
name: new_h_name,
shape: h_shape,
op_name: :elem
)
output_sequence =
layer(fn x, _ -> elem(x, 1) end, [output],
name: output_sequence_name,
shape: output_shape,
op_name: :elem
)
{{new_h}, output_sequence}
end
@doc """
See `conv_lstm/3`.
"""
@doc type: :recurrent
def conv_lstm(%Axon{} = x, units) when is_integer(units) and units > 0 do
conv_lstm(x, units, [])
end
@doc """
Adds a convolutional long short-term memory (LSTM) layer to the network
with an initial hidden state generated by `:recurrent_initializer`
(random by default).
See `conv_lstm/4` for more details.
## Additional options
* `:recurrent_initializer` - initializer for the initial cell and hidden
states. Defaults to `:glorot_uniform`.
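
A minimal sketch of a ConvLSTM whose initial state is all zeros. It assumes the
input is laid out as `{batch, sequence, channels, height, width}`, which matches
the shape handling in `conv_lstm/4` (the sequence is axis 1, and channels come
first once it is removed). The input name and sizes are illustrative only.

```elixir
# Hypothetical video-like input: 8 frames of 3x32x32 images per sample
input = Axon.input({nil, 8, 3, 32, 32}, "frames")

# Both the cell state and the hidden state start at zero instead of
# the default :glorot_uniform initialization
{{cell_state, hidden_state}, output_sequence} =
  Axon.conv_lstm(input, 16, recurrent_initializer: :zeros)
```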
"""
@doc type: :recurrent
def conv_lstm(%Axon{output_shape: shape} = x, units, opts)
when is_integer(units) and units > 0 and is_list(opts) do
{recurrent_initializer, opts} = Keyword.pop(opts, :recurrent_initializer, :glorot_uniform)
c = rnn_state(x, shape, units, :conv_lstm, opts[:name], "c", recurrent_initializer)
h = rnn_state(x, shape, units, :conv_lstm, opts[:name], "h", recurrent_initializer)
conv_lstm(x, {c, h}, units, opts)
end
def conv_lstm(%Axon{} = x, {%Axon{}, %Axon{}} = hidden_state, units)
when is_integer(units) and units > 0 do
conv_lstm(x, hidden_state, units, [])
end
@doc """
Adds a convolutional long short-term memory (LSTM) layer to the network
with the given initial hidden state.
ConvLSTMs apply `Axon.Recurrent.conv_lstm_cell/5` over an entire input
sequence and return:
{{new_cell, new_hidden}, output_sequence}
You can use the output state as the hidden state of another
ConvLSTM layer.
## Options
* `:name` - layer name.
* `:padding` - convolutional padding. Defaults to `:same`.
* `:kernel_size` - convolutional kernel size. Defaults to `1`.
* `:strides` - convolutional strides. Defaults to `1`.
* `:unroll` - `:dynamic` (loop preserving) or `:static` (compiled)
unrolling of the RNN. Defaults to `:dynamic`.
* `:kernel_initializer` - initializer for kernel weights. Defaults
to `:glorot_uniform`.
* `:bias_initializer` - initializer for bias weights. Defaults to
`:zeros`.
* `:use_bias` - whether the layer should add bias to the output.
Defaults to `true`.
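
As a sketch of chaining layers, the final `{cell, hidden}` state of one ConvLSTM
can seed a second one. This assumes a `{batch, sequence, channels, height, width}`
input layout and keeps `units` equal across both layers, since the carried state
has `units` channels. All names and sizes are illustrative.

```elixir
input = Axon.input({nil, 8, 3, 32, 32}, "frames")

# First layer: its initial state is created by conv_lstm/3
{state, sequence} = Axon.conv_lstm(input, 16)

# Second layer: starts from the first layer's final {cell, hidden} state
# and uses a 3x3 kernel; :same padding keeps the spatial size unchanged
{_final_state, output_sequence} =
  Axon.conv_lstm(sequence, state, 16, kernel_size: 3, padding: :same)
```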
"""
@doc type: :recurrent
def conv_lstm(
%Axon{output_shape: shape} = x,
{%Axon{output_shape: h_shape}, %Axon{output_shape: h_shape}} = hidden_state,
units,
opts
)
when is_integer(units) and units > 0 and is_list(opts) do
opts =
Keyword.validate!(opts, [
:name,
padding: :same,
kernel_size: 1,
strides: 1,
unroll: :dynamic,
kernel_initializer: :glorot_uniform,
bias_initializer: :zeros,
use_bias: true
])
padding = opts[:padding]
kernel_size = opts[:kernel_size]
strides = opts[:strides]
unroll = opts[:unroll]
inner_rank = Nx.rank(shape) - 3
sequence_length = elem(shape, 1)
kernel_size = tuple_or_duplicate(:kernel_size, kernel_size, inner_rank)
strides = list_or_duplicate(:strides, strides, inner_rank)
input_dilation = List.duplicate(1, inner_rank)
kernel_dilation = List.duplicate(1, inner_rank)
conv_shape = Tuple.delete_at(shape, 1)
conv_hidden_state_shape = Tuple.delete_at(h_shape, 1)
hidden_kernel_shape =
Axon.Shape.conv_kernel(conv_hidden_state_shape, 4 * units, kernel_size, :first)
input_kernel_shape = Axon.Shape.conv_kernel(conv_shape, 4 * units, kernel_size, :first)
bias_shape = Axon.Shape.conv_bias(conv_shape, 4 * units, kernel_size, :first)
output_kernel_shape =
Axon.Shape.conv_kernel(conv_hidden_state_shape, units, kernel_size, :first)
output_shape =
conv_hidden_state_shape
|> Axon.Shape.conv(
output_kernel_shape,
strides,
padding,
input_dilation,
kernel_dilation,
:first,
1
)
|> Tuple.insert_at(1, sequence_length)
kernel_initializer = opts[:kernel_initializer]
wi = param("input_kernel", {:tuple, [input_kernel_shape]}, initializer: kernel_initializer)
wh = param("hidden_kernel", {:tuple, [hidden_kernel_shape]}, initializer: kernel_initializer)
hidden_state_name =
case opts[:name] do
nil ->
fn _, op_counts ->
"conv_lstm_#{op_counts[:conv_lstm]}_hidden_state"
end
name when is_binary(name) ->
"#{name}_hidden_state"
end
hidden_state = Axon.container(hidden_state, name: hidden_state_name)
{inputs, op} =
if opts[:use_bias] do
bias_initializer = opts[:bias_initializer]
b = param("bias", {:tuple, [bias_shape]}, initializer: bias_initializer)
{[x, hidden_state, wi, wh, b], :conv_lstm}
else
{[x, hidden_state, wi, wh], &Axon.Layers.conv_lstm(&1, &2, &3, &4, {0}, &5)}
end
output =
layer(
op,
inputs,
name: opts[:name],
conv_opts: [
strides: strides,
padding: padding
],
unroll: unroll,
shape: output_shape,
op_name: :conv_lstm
)
new_c_name =
case opts[:name] do
nil ->
fn _, op_counts ->
"conv_lstm_#{op_counts[:conv_lstm]}_c_hidden_state"
end
name when is_binary(name) ->
"#{name}_c_hidden_state"
end
new_h_name =
case opts[:name] do
nil ->
fn _, op_counts ->
"conv_lstm_#{op_counts[:conv_lstm]}_h_hidden_state"
end
name when is_binary(name) ->
"#{name}_h_hidden_state"
end
output_sequence_name =
case opts[:name] do
nil ->
fn _, op_counts ->
"conv_lstm_#{op_counts[:conv_lstm]}_output_sequence"
end
name when is_binary(name) ->
"#{name}_output_sequence"
end
new_c =
layer(fn x, _ -> elem(elem(x, 0), 0) end, [output],
name: new_c_name,
shape: h_shape,
op_name: :elem
)
new_h =
layer(fn x, _ -> elem(elem(x, 0), 1) end, [output],
name: new_h_name,
shape: h_shape,
op_name: :elem
)
output_sequence =
layer(fn x, _ -> elem(x, 1) end, [output],
name: output_sequence_name,
shape: output_shape,
op_name: :elem
)
{{new_c, new_h}, output_sequence}
end
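# Builds a layer that creates an initial recurrent state using `initializer`,
# taking the batch size from the input `x` at runtime.
defp rnn_state(x, shape, units, rnn_type, parent_name, state_name, initializer) do
initializer = initializer || :glorot_uniform
name =
case parent_name do
nil ->
fn _, op_counts ->
"#{rnn_type}_#{op_counts[rnn_type]}_#{state_name}_hidden_state"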
end
parent_name when is_binary(parent_name) ->
"#{parent_name}_#{state_name}_hidden_state"
end
shape = Axon.Shape.rnn_hidden_state(shape, units, rnn_type)
fun = fn inputs, _opts ->
shape = put_elem(shape, 0, elem(Nx.shape(inputs), 0))
case initializer do
fun when is_function(fun) ->
fun.(shape)
fun when is_atom(fun) ->
fun = apply(Axon.Initializers, fun, [])
fun.(shape, {:f, 32})
end
end
layer(fun, [x], name: name, op_name: :recurrent_state)
end
@doc """
Adds an embedding layer to the network.
An embedding layer initializes a kernel of shape `{vocab_size, embedding_size}`
which acts as a lookup table for sequences of discrete tokens (e.g. sentences).
Embeddings are typically used to obtain a dense representation of a sparse input
space.
## Options
* `:name` - layer name.
* `:kernel_initializer` - initializer for `kernel` weights. Defaults
to `:uniform`.
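
For example, a minimal sketch that maps token ids from a 10,000-token vocabulary
to 64-dimensional vectors. The input name and sizes are illustrative; inputs are
expected to hold integer token ids in `0..vocab_size - 1`, and the output adds an
`embedding_size` axis to the input shape.

```elixir
# Batches of 16-token sequences
input = Axon.input({nil, 16}, "tokens")

# Kernel of shape {10_000, 64}: one 64-dimensional row per token id
model = Axon.embedding(input, 10_000, 64)
```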
"""
@doc type: :linear
def embedding(%Axon{output_shape: shape} = x, vocab_size, embedding_size, opts \\ []) do
opts = Keyword.validate!(opts, [:name, kernel_initializer: :uniform])
kernel_shape = Axon.Shape.embedding_kernel(shape, vocab_size, embedding_size)
output_shape = Axon.Shape.embedding(shape, vocab_size, embedding_size)
kernel = param("kernel", kernel_shape, initializer: opts[:kernel_initializer])
layer(:embedding, [x, kernel], name: opts[:name], shape: output_shape, op_name: :embedding)
end
@doc """
Adds a bias layer to the network.
A bias layer simply adds a trainable bias to an input.
## Options
* `:name` - layer name.
* `:bias_initializer` - initializer for `bias` weights. Defaults
to `:zeros`.
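
A small sketch: a dense layer built without its own bias, followed by a standalone
trainable bias. It assumes `Axon.dense/3` accepts `use_bias: false` in this version.
The bias is sized by the last axis of its input (10 here), and the output shape
equals the input shape.

```elixir
model =
  Axon.input({nil, 32}, "input")
  # Dense projection without a built-in bias term (assumed option)
  |> Axon.dense(10, use_bias: false)
  # Separate trainable bias, one value per output unit
  |> Axon.bias(bias_initializer: :zeros)
```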
"""
@doc type: :linear
def bias(%Axon{output_shape: shape} = x, opts \\ []) do
opts = Keyword.validate!(opts, [:name, bias_initializer: :zeros])
units = elem(shape, tuple_size(shape) - 1)
bias_shape = Axon.Shape.dense_bias(shape, units)
bias = param("bias", bias_shape, initializer: opts[:bias_initializer])
layer(:bias, [x, bias], name: opts[:name], shape: shape, op_name: :bias)
end
@doc """
Freezes parameters returned from `fun` in the given
model. `fun` takes the model's parameter list and returns
the list of parameters it wishes to freeze. `fun` defaults
to the identity function, freezing all of the parameters in
`model`.
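
As a sketch, `fun` can select a subset of the parameters. Because parameters are
matched by name, the illustrative filter below freezes every parameter named
`kernel` in the model and leaves the biases trainable.

```elixir
model =
  Axon.input({nil, 784}, "input")
  |> Axon.dense(128, activation: :relu)
  |> Axon.dense(10)

# Select parameters by name; every "kernel" in the model is frozen
frozen_model =
  Axon.freeze(model, fn params ->
    Enum.filter(params, fn param -> param.name == "kernel" end)
  end)
```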
Freezing parameters is useful when performing transfer learning
to leverage features learned from another problem in a new problem.
For example, it's common to combine the convolutional base from
larger models trained on ImageNet with fresh fully-connected classifiers.
The combined model is then trained on fresh data, with the convolutional
base frozen so that the features it has already learned are not overwritten.
In code, this looks like:
cnn_base = get_pretrained_cnn_base()
model =
cnn_base
|> Axon.freeze()
|> Axon.flatten()
|> Axon.dense(1024, activation: :relu)
|> Axon.dropout()
|> Axon.dense(1000, activation: :softmax)
model
|> Axon.Loop.trainer(:categorical_cross_entropy, Axon.Optimizers.adam(0.005))
|> Axon.Loop.run(data, epochs: 10)
When compiled, frozen parameters are wrapped in `Nx.Defn.Kernel.stop_grad/1`,
which zeros out the gradient with respect to each frozen parameter. Because
their gradients are `0.0`, frozen parameters are not changed during the
update process.
"""
def freeze(%Axon{} = model, fun \\ & &1) when is_function(fun, 1) do
parameters =
tree_reduce(model, MapSet.new(), fn %Axon{parameters: params}, acc ->
Enum.reduce(params, acc, fn param, acc ->
MapSet.put(acc, param)
end)
end)
parameters_to_freeze = fun.(Enum.to_list(parameters))
tree_map(model, fn %Axon{parameters: params} = axon ->
frozen_params =
Enum.map(params, fn %{name: param_name} = v ->
if Enum.any?(parameters_to_freeze, fn %{name: name} -> name == param_name end) do
%{v | frozen: true}
else
v
end
end)
%{axon | parameters: frozen_params}
end)
end
@doc """
Attaches a hook to the given Axon model.
Hooks compile down to `Nx.Defn.Kernel.hook/3` and provide the same
functionality for adding side-effecting operations to a compiled
model. For example, you can use hooks to inspect intermediate activations,
send data to an external service, and more.
Hooks can be configured to be invoked on the following events:
* `:initialize` - on model initialization.
* `:pre_forward` - before layer forward pass is invoked.
* `:forward` - after layer forward pass is invoked.
* `:backward` - after layer backward pass is invoked.
To invoke a hook on every single event, you may pass `:all` to `on:`.
Axon.input({nil, 1}, "input") |> Axon.attach_hook(&IO.inspect/1, on: :all)
The default event is `:forward`, assuming you want a hook invoked
on the layer's forward pass.
You may configure hooks to run only in training mode or only in inference
mode using the `:mode` option. The default mode is `:both`, which invokes
the hook during both training and inference.
Axon.input({nil, 1}, "input") |> Axon.attach_hook(&IO.inspect/1, on: :forward, mode: :train)
You can also attach multiple hooks to a single layer. Hooks are invoked in
the order in which they are declared. If order is important, you should attach
hooks in the order you want them to be executed:
Axon.input({nil, 1}, "input")
# I will be executed first
|> Axon.attach_hook(&IO.inspect/1)
# I will be executed second
|> Axon.attach_hook(fn _ -> IO.write("HERE") end)
Hooks are executed at their point of attachment. You must insert hooks at each point
you want a hook to execute during model execution.
Axon.input({nil, 1}, "input")
|> Axon.attach_hook(&IO.inspect/1)
|> Axon.relu()
|> Axon.attach_hook(&IO.inspect/1)
"""
def attach_hook(%Axon{hooks: hooks} = axon, fun, opts \\ []) do
opts = Keyword.validate!(opts, on: :forward, mode: :both)
on_event = opts[:on]
mode = opts[:mode]
%{axon | hooks: [{on_event, mode, fun} | hooks]}
end
## Traversal
@doc """
Traverses a model tree applying `fun` to each layer.
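
As a sketch, `tree_map/2` can rewrite every layer, for example to strip all
attached hooks. `fun` receives each layer struct and must return an `%Axon{}`
struct, since the traversal then sets its `:parent` field. The model is
illustrative.

```elixir
model =
  Axon.input({nil, 1}, "input")
  |> Axon.attach_hook(&IO.inspect/1)
  |> Axon.relu()
  |> Axon.attach_hook(&IO.inspect/1)

# Return each layer unchanged except for an empty hook list
model_without_hooks = Axon.tree_map(model, fn layer -> %{layer | hooks: []} end)
```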
"""
def tree_map(%Axon{op: :container, parent: [container]} = axon, fun) do
x = deep_new(container, fun)
%{fun.(axon) | parent: [x]}
end
def tree_map(%Axon{parent: x} = axon, fun) when is_list(x) do
x = Enum.map(x, &tree_map(&1, fun))
%{fun.(axon) | parent: x}
end
@doc """
Traverses a model applying `fun` with an accumulator.
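
A minimal sketch that counts the input layers of a two-input model, using the
same `op: :input` match that `get_model_signature/1` relies on. The traversal
keeps no record of visited layers, so a layer reachable through several paths
is visited once per path.

```elixir
inp1 = Axon.input({nil, 32}, "input_0")
inp2 = Axon.input({nil, 32}, "input_1")
model = Axon.concatenate(inp1, inp2)

num_inputs =
  Axon.tree_reduce(model, 0, fn
    %Axon{op: :input}, acc -> acc + 1
    _layer, acc -> acc
  end)
```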
"""
def tree_reduce(%Axon{op: :container, parent: [container]} = axon, acc, fun) do
deep_reduce(container, fun.(axon, acc), fun)
end
def tree_reduce(%Axon{parent: x} = axon, acc, fun) when is_list(x) do
Enum.reduce(x, fun.(axon, acc), &tree_reduce(&1, &2, fun))
end
# TODO: Should not be duplicated
def deep_reduce(map, acc, fun) do
Nx.Container.reduce(map, acc, &recur_deep_reduce(&1, &2, fun))
end
defp recur_deep_reduce(value, acc, fun) do
case value do
%Axon{} = val ->
fun.(val, acc)
%Nx.Tensor{} = val ->
fun.(val, acc)
{:leaf, val} ->
fun.(val, acc)
val ->
deep_reduce(val, acc, fun)
end
end
## Utilities
@doc """
Returns the model's signature as a tuple of `{input_shape, output_shape}`.
If the model has more than one input, `input_shape` is a tuple of the
input shapes.
## Examples
iex> model = Axon.input({nil, 32}, "input") |> Axon.dense(10)
iex> {inp, out} = Axon.get_model_signature(model)
iex> inp
{nil, 32}
iex> out
{nil, 10}
iex> inp1 = Axon.input({nil, 32}, "input_0")
iex> inp2 = Axon.input({nil, 32}, "input_1")
iex> model = Axon.concatenate(inp1, inp2)
iex> {{inp1_shape, inp2_shape}, out} = Axon.get_model_signature(model)
iex> inp1_shape
{nil, 32}
iex> inp2_shape
{nil, 32}
iex> out
{nil, 64}
"""
def get_model_signature(%Axon{output_shape: output_shape} = axon) do
# TODO: Refactor for tuples and use `tree_*` when they support
# tuple inputs
input_shapes =
tree_reduce(axon, [], fn
%Axon{op: :input, output_shape: shape}, acc -> [shape | acc]
_, acc -> acc
end)
case input_shapes do
[input_shape] ->
{input_shape, output_shape}
shapes ->
{List.to_tuple(Enum.reverse(shapes)), output_shape}
end
end
@doc """
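Compiles the given model to `{init_fn, predict_fn}`, where `init_fn`
initializes the model parameters and `predict_fn` runs the model's
forward pass. Once compiled, the model can be passed as an argument
to `Nx.Defn`.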
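
A sketch of using the compiled pair directly. Mirroring `init/3` and `predict/4`,
`init_fn` is called with a (possibly empty) initial parameter map and `predict_fn`
with the parameters and an input. The model and input are illustrative.

```elixir
model = Axon.input({nil, 32}, "input") |> Axon.dense(10)

{init_fn, predict_fn} = Axon.compile(model)

# Start from an empty parameter map so every parameter is initialized
params = init_fn.(%{})

# Run a forward pass on a batch of one example
output = predict_fn.(params, Nx.iota({1, 32}, type: {:f, 32}))
```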
"""
@doc type: :compilation
def compile(model, opts \\ []) when is_list(opts) do
{Axon.Compiler.compile_init(model, opts), Axon.Compiler.compile_predict(model, opts)}
end
@doc """
Compiles and runs the given model's initialization function
with the given compiler options.
You may optionally specify initial parameters for some layers or
namespaces by passing a partial parameter map:
Axon.init(model, %{"dense_0" => dense_params})
The parameter map will be merged with the initialized model
parameters.
"""
@doc type: :execution
def init(model, params \\ %{}, opts \\ []) when is_list(opts) do
Axon.Compiler.compile_init(model, opts).(params)
end
@doc """
Compiles and runs the given Axon model with `params` on
`input` with the given compiler options.
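
For example, a minimal sketch with an illustrative model. The `nil` batch
dimension declared in the input shape is filled in from the actual input,
here a batch of 4.

```elixir
model = Axon.input({nil, 32}, "input") |> Axon.dense(10, activation: :softmax)
params = Axon.init(model)

# Four examples of 32 features each
input = Nx.broadcast(0.5, {4, 32})
output = Axon.predict(model, params, input)
```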
"""
@doc type: :execution
def predict(%Axon{} = model, params, input, opts \\ []) when is_list(opts) do
Axon.Compiler.compile_predict(model, opts).(params, input)
end
## Inspection
defimpl Inspect do
import Inspect.Algebra
import Axon.Shared
alias Axon.Parameter
def inspect(axon, _opts) do
title = "Model"
header = ["Layer", "Shape", "Policy", "Parameters", "Parameters Memory"]
model_info = %{num_params: 0, total_param_byte_size: 0, inputs: []}
{_, _, cache, _, model_info} = axon_to_rows(axon, %{}, %{}, model_info)
rows =
cache
|> Enum.sort()
|> Enum.unzip()
|> elem(1)
|> Enum.map(&elem(&1, 0))
rows
|> TableRex.Table.new(header, title)
|> TableRex.Table.render!(
header_separator_symbol: "=",
title_separator_symbol: "=",
vertical_style: :off
)
|> then(&(&1 <> "Total Parameters: #{model_info.num_params}\n"))
|> then(&(&1 <> "Total Parameters Memory: #{model_info.total_param_byte_size} bytes\n"))
|> then(&(&1 <> "Inputs: #{inspect(Map.new(model_info.inputs))}\n"))
|> string()
end
defp axon_to_rows(%{id: id, op_name: op_name} = graph, cache, op_counts, model_info) do
case cache do
%{^id => {row, name}} ->
{row, name, cache, op_counts, model_info}
%{} ->
{row, name, cache, op_counts, model_info} =
do_axon_to_rows(graph, cache, op_counts, model_info)
cache = Map.put(cache, id, {row, name})
op_counts = Map.update(op_counts, op_name, 1, fn x -> x + 1 end)
{row, name, cache, op_counts, model_info}
end
end
defp do_axon_to_rows(
%Axon{
op: :container,
parent: [parents],
name: name_fn,
output_shape: shape,
policy: policy
},
cache,
op_counts,
model_info
) do
{input_names, {cache, op_counts, model_info}} =
deep_map_reduce(parents, {cache, op_counts, model_info}, fn
graph, {cache, op_counts, model_info} ->
{_, name, cache, op_counts, model_info} =
axon_to_rows(graph, cache, op_counts, model_info)
{name, {cache, op_counts, model_info}}
end)
op_string = "container"
name = name_fn.(:container, op_counts)
row = [
"#{name} ( #{op_string} #{inspect(input_names)} )",
"#{inspect(shape)}",
"#{inspect(policy)}",
0,
"0 bytes"
]
{row, name, cache, op_counts, model_info}
end
defp do_axon_to_rows(
%Axon{
op: :namespace,
parent: parents,
name: name_fn,
output_shape: shape,
policy: policy
},
cache,
op_counts,
model_info
) do
init_model_info = %{num_params: 0, total_param_byte_size: 0, inputs: []}
{_input_names, {_cache, op_counts, namespace_model_info}} =
Enum.map_reduce(parents, {%{}, op_counts, init_model_info}, fn
graph, {cache, op_counts, model_info} ->
{_, name, cache, op_counts, model_info} =
axon_to_rows(graph, cache, op_counts, model_info)
{name, {cache, op_counts, model_info}}
end)
name = name_fn.(:namespace, op_counts)
num_params = namespace_model_info.num_params
param_byte_size = namespace_model_info.total_param_byte_size
inputs = namespace_model_info.inputs
model_info =
model_info
|> Map.update(:num_params, 0, fn x -> x + num_params end)
|> Map.update(:total_param_byte_size, 0, fn x -> x + param_byte_size end)
|> Map.update(:inputs, [], fn x -> x ++ inputs end)
row = [
"#{name} ( #{inputs |> Map.new() |> Map.keys()} )",
"#{inspect(shape)}",
"#{inspect(policy)}",
"#{num_params}",
"#{param_byte_size} bytes"
]
{row, name, cache, op_counts, model_info}
end
defp do_axon_to_rows(
%Axon{
parent: parents,
parameters: params,
name: name_fn,
output_shape: shape,
policy: %{params: {_, bitsize}} = policy,
op_name: op_name
},
cache,
op_counts,
model_info
) do
{input_names, {cache, op_counts, model_info}} =
Enum.map_reduce(parents, {cache, op_counts, model_info}, fn
graph, {cache, op_counts, model_info} ->
{_, name, cache, op_counts, model_info} =
axon_to_rows(graph, cache, op_counts, model_info)
{name, {cache, op_counts, model_info}}
end)
num_params =
Enum.reduce(params, 0, fn
%Parameter{shape: {:tuple, shapes}}, acc ->
Enum.reduce(shapes, acc, &(Nx.size(&1) + &2))
%Parameter{shape: shape}, acc ->
acc + Nx.size(shape)
end)
param_byte_size = num_params * div(bitsize, 8)
op_inspect = Atom.to_string(op_name)
inputs =
case input_names do
[] ->
""
[_ | _] = input_names ->
"#{inspect(input_names)}"
end
name = name_fn.(op_name, op_counts)
row = [
"#{name} ( #{op_inspect}#{inputs} )",
"#{inspect(shape)}",
"#{inspect(policy)}",
"#{num_params}",
"#{param_byte_size} bytes"
]
model_info =
model_info
|> Map.update(:num_params, 0, &(&1 + num_params))
|> Map.update(:total_param_byte_size, 0, &(&1 + param_byte_size))
|> Map.update(:inputs, [], fn inputs ->
if op_name == :input, do: [{name, shape} | inputs], else: inputs
end)
{row, name, cache, op_counts, model_info}
end
end
## Serialization
@doc """
Serializes a model and its parameters for persisting
models to disk or elsewhere.
Model and parameters are serialized as a tuple, where the
model is converted to a recursive map to ensure compatibility
with future Axon versions and the parameters are serialized
using `Nx.serialize/2`. Some additional metadata, such as the current
serialization version, is included for compatibility.
Serialization `opts` are forwarded to `Nx.serialize/2` and
`:erlang.term_to_binary/2` for controlling compression options.
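
A minimal sketch of persisting a model to disk with `File.write!/2` and restoring
it with `File.read!/1`. The file name is arbitrary and the model is illustrative.

```elixir
model = Axon.input({nil, 2}, "input") |> Axon.dense(1)
params = Axon.init(model)

# Write the serialized binary to a file...
File.write!("model.axon", Axon.serialize(model, params))

# ...and restore the model and parameters later
{restored_model, restored_params} =
  "model.axon"
  |> File.read!()
  |> Axon.deserialize()
```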
## Examples
iex> model = Axon.input({nil, 2}, "input") |> Axon.dense(1, kernel_initializer: :zeros, activation: :relu)
iex> params = Axon.init(model)
iex> serialized = Axon.serialize(model, params)
iex> {saved_model, saved_params} = Axon.deserialize(serialized)
iex> Axon.predict(saved_model, saved_params, Nx.tensor([[1.0, 1.0]]))
#Nx.Tensor<
f32[1][1]
[
[0.0]
]
>
"""
def serialize(%Axon{} = model, params, opts \\ []) do
model_meta = axon_to_map(model)
params = Nx.serialize(params, opts)
:erlang.term_to_binary({@file_version, model_meta, params}, opts)
end
defp axon_to_map(%Axon{op: :container, parent: [parents]} = model) do
parents = deep_new(parents, &axon_to_map/1)
axon_map = Map.from_struct(model) |> Map.put(:axon, :axon)
%{axon_map | parent: List.wrap(parents)}
end
defp axon_to_map(%Axon{parent: parents} = model) do
parents = Enum.map(parents, &axon_to_map/1)
axon_map = Map.from_struct(model) |> Map.put(:axon, :axon)
%{axon_map | parent: parents}
end
@doc """
Deserializes a serialized model and its parameters into a `{model, params}`
tuple.
It is the opposite of `Axon.serialize/3`.
## Examples
iex> model = Axon.input({nil, 2}, "input") |> Axon.dense(1, kernel_initializer: :zeros, activation: :relu)
iex> params = Axon.init(model)
iex> serialized = Axon.serialize(model, params)
iex> {saved_model, saved_params} = Axon.deserialize(serialized)
iex> Axon.predict(saved_model, saved_params, Nx.tensor([[1.0, 1.0]]))
#Nx.Tensor<
f32[1][1]
[
[0.0]
]
>
"""
def deserialize(serialized, opts \\ []) do
{1, model_meta, serialized_params} = :erlang.binary_to_term(serialized, [:safe | opts])
model = map_to_axon(model_meta)
params = Nx.deserialize(serialized_params, opts)
{model, params}
end
defp map_to_axon(%{op: :container, parent: [parents]} = model) do
parents = deep_new(parents, &map_to_axon/1)
model = Map.drop(model, [:axon])
model = %{model | parent: List.wrap(parents)}
struct(__MODULE__, model)
end
defp map_to_axon(%{axon: :axon, parent: parents} = model) do
parents = Enum.map(parents, &map_to_axon/1)
model = Map.drop(model, [:axon])
model = %{model | parent: parents}
struct(__MODULE__, model)
end
## Helpers
@valid_initializers [:zeros, :ones, :uniform, :normal, :identity] ++
[:lecun_uniform, :lecun_normal, :he_uniform, :he_normal] ++
[:glorot_uniform, :glorot_normal, :variance_scaling]
defp validate_initializer!(initializer)
when is_atom(initializer) and initializer in @valid_initializers do
:ok
end
defp validate_initializer!(initializer) when is_function(initializer, 2) do
:ok
end
defp validate_initializer!(initializer) do
raise ArgumentError,
"initializer must be one of #{inspect(@valid_initializers)}," <>
" or an arity-2 function accepting initializer shape and type," <>
" got: #{inspect(initializer)}"
end
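# Normalizes an option given per spatial dimension (such as `:kernel_size`)
# into a `rank`-element tuple, duplicating a bare integer across all dimensions.
defp tuple_or_duplicate(key, tuple_or_integer, rank) do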
cond do
is_tuple(tuple_or_integer) ->
if tuple_size(tuple_or_integer) != rank do
raise ArgumentError,
"expected #{inspect(key)} to be a #{rank}-element tuple, " <>
"got: #{inspect(tuple_or_integer)}"
end
tuple_or_integer
is_integer(tuple_or_integer) ->
Tuple.duplicate(tuple_or_integer, rank)
true ->
raise ArgumentError,
"expected #{inspect(key)} to be an integer or a tuple, " <>
"got: #{inspect(tuple_or_integer)}"
end
end
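# Same as `tuple_or_duplicate/3`, but normalizes into a `rank`-element list
# (used for options such as `:strides`).
defp list_or_duplicate(key, list_or_integer, rank) do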
cond do
is_list(list_or_integer) ->
if length(list_or_integer) != rank do
raise ArgumentError,
"expected #{inspect(key)} to be a #{rank}-element list, " <>
"got: #{inspect(list_or_integer)}"
end
list_or_integer
is_integer(list_or_integer) ->
List.duplicate(list_or_integer, rank)
true ->
raise ArgumentError,
"expected #{inspect(key)} to be an integer or a list, " <>
"got: #{inspect(list_or_integer)}"
end
end
# Names are generated lazily at inspect, initialization, and compile
# time, so for name we return a function which takes `op` and `op_counts`
# and returns a unique name for the given model.
defp unique_identifiers(type, nil) do
id = System.unique_integer([:positive, :monotonic])
name = fn op, op_counts ->
count = op_counts[op] || 0
Atom.to_string(type) <> "_#{count}"
end
{id, name}
end
defp unique_identifiers(_type, name_fn) when is_function(name_fn, 2) do
id = System.unique_integer([:positive, :monotonic])
{id, name_fn}
end
defp unique_identifiers(_type, name) when is_binary(name) do
{System.unique_integer([:positive, :monotonic]), fn _, _ -> name end}
end
defp unique_identifiers(_, name) do
raise ArgumentError,
"expected layer name to be a binary, a function or nil, " <>
"got: #{inspect(name)}"
end
end