defmodule Axon.Recurrent do
@moduledoc """
Functional implementations of common recurrent neural network
routines.

Recurrent neural networks are commonly used for working with
sequences of data where there is some level of dependence between
outputs at different time steps.

This module contains three RNN cell functions and functions to
"unroll" cells over an entire sequence. Each cell function returns
a tuple:

    {new_carry, output}

where `new_carry` is the updated carry state and `output` is the
output for a single time step. In order to apply an RNN across
multiple time steps, you need to use either `static_unroll/6` or
`dynamic_unroll/6`.

Unrolling an RNN is equivalent to a `map_reduce` or `scan` starting
from an initial carry state and ending with a final carry state and
an output sequence.
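For example, here is a sketch of unrolling a GRU over a 10-step
sequence, with illustrative shapes (batch of 4, 8 input features,
16 hidden units) and constant tensors standing in for real
parameter initialization:

    sequence = Nx.broadcast(0.5, {4, 10, 8})
    carry = {Nx.broadcast(0.0, {4, 1, 16})}
    k = Nx.broadcast(0.1, {8, 16})
    h = Nx.broadcast(0.1, {16, 16})
    b = Nx.broadcast(0.0, {16})

    {final_carry, outputs} =
      Axon.Recurrent.static_unroll(
        &Axon.Recurrent.gru_cell/5,
        sequence,
        carry,
        {k, k, k},
        {h, h, h},
        {b, b, b, b}
      )
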
All of the functions in this module are implemented as
numerical functions and can be JIT or AOT compiled with
any supported `Nx` compiler.
"""
import Nx.Defn
import Axon.Layers
import Axon.Activations
@doc """
GRU Cell.

When combined with `Axon.Recurrent.*_unroll`, implements a
GRU-based RNN. GRUs are typically more memory efficient than
traditional LSTMs, as they carry only a hidden state and use
fewer parameters.
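## Examples

A single-step sketch with illustrative shapes (batch of 4, one time
step, 8 input features, 16 hidden units); constant tensors stand in
for real parameter initialization:

    input = Nx.broadcast(0.5, {4, 1, 8})
    carry = {Nx.broadcast(0.0, {4, 1, 16})}
    k = Nx.broadcast(0.1, {8, 16})
    h = Nx.broadcast(0.1, {16, 16})
    b = Nx.broadcast(0.0, {16})

    {{new_hidden}, output} =
      Axon.Recurrent.gru_cell(input, carry, {k, k, k}, {h, h, h}, {b, b, b, b})
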
## References

* [Empirical Evaluation of Gated Recurrent Neural Networks on Sequence Modeling](https://arxiv.org/pdf/1412.3555v1.pdf)
"""
defn gru_cell(
input,
carry,
input_kernel,
hidden_kernel,
bias,
gate_fn \\ &sigmoid/1,
activation_fn \\ &tanh/1
) do
{hidden} = carry
{wir, wiz, win} = input_kernel
{whr, whz, whn} = hidden_kernel
{br, bz, bin, bhn} = bias
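# r, z, and n are the reset, update, and candidate ("new") gates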
r = gate_fn.(dense(input, wir, br) + dense(hidden, whr, 0))
z = gate_fn.(dense(input, wiz, bz) + dense(hidden, whz, 0))
n = activation_fn.(dense(input, win, bin) + r * dense(hidden, whn, bhn))
new_h = (1.0 - z) * n + z * hidden
{{new_h}, new_h}
end
@doc """
LSTM Cell.

When combined with `Axon.Recurrent.*_unroll`, implements an
LSTM-based RNN.
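## Examples

A single-step sketch with illustrative shapes (batch of 4, one time
step, 8 input features, 16 hidden units); constant tensors stand in
for real parameter initialization:

    input = Nx.broadcast(0.5, {4, 1, 8})
    carry = {Nx.broadcast(0.0, {4, 1, 16}), Nx.broadcast(0.0, {4, 1, 16})}
    k = Nx.broadcast(0.1, {8, 16})
    h = Nx.broadcast(0.1, {16, 16})
    b = Nx.broadcast(0.0, {16})

    {{new_cell, new_hidden}, output} =
      Axon.Recurrent.lstm_cell(input, carry, {k, k, k, k}, {h, h, h, h}, {b, b, b, b})
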
## References

* [Long Short-Term Memory](http://www.bioinf.jku.at/publications/older/2604.pdf)
"""
defn lstm_cell(
input,
carry,
input_kernel,
hidden_kernel,
bias,
gate_fn \\ &sigmoid/1,
activation_fn \\ &tanh/1
) do
{cell, hidden} = carry
{wii, wif, wig, wio} = input_kernel
{whi, whf, whg, who} = hidden_kernel
{bi, bf, bg, bo} = bias
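# i, f, g, and o are the input, forget, cell candidate, and output gates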
i = gate_fn.(dense(input, wii, bi) + dense(hidden, whi, 0))
f = gate_fn.(dense(input, wif, bf) + dense(hidden, whf, 0))
g = activation_fn.(dense(input, wig, bg) + dense(hidden, whg, 0))
o = gate_fn.(dense(input, wio, bo) + dense(hidden, who, 0))
new_c = f * cell + i * g
new_h = o * activation_fn.(new_c)
{{new_c, new_h}, new_h}
end
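# Drops the size-1 time axis so `conv_lstm_cell` can apply convolutions
# to rank-4 {batch, channels, height, width} tensors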
defnp rank_down(rnn_data) do
transform(rnn_data, fn {{cell, hidden}, input} ->
[cell, hidden, input] =
for tensor <- [cell, hidden, input] do
Nx.squeeze(tensor, axes: [1])
end
{{cell, hidden}, input}
end)
end
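# Restores the size-1 time axis removed by `rank_down/1`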
defnp rank_up(rnn_data) do
transform(rnn_data, fn {{cell, hidden}, input} ->
[cell, hidden, input] =
for tensor <- [cell, hidden, input] do
new_shape =
Nx.shape(tensor)
|> Tuple.insert_at(1, 1)
Nx.reshape(tensor, new_shape)
end
{{cell, hidden}, input}
end)
end
@doc """
ConvLSTM Cell.

When combined with `Axon.Recurrent.*_unroll`, implements a
ConvLSTM-based RNN, which replaces the dense transformations of a
traditional LSTM with convolutions, making it suitable for
spatiotemporal inputs such as image sequences.

## Options

* `:strides` - convolution strides. Defaults to `1`.
* `:padding` - convolution padding. Defaults to `:same`.
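## Examples

A single-step sketch with illustrative shapes (batch of 2, one time
step, 3 input channels, 4 hidden channels, 8x8 spatial dimensions,
3x3 kernels); constant tensors stand in for real parameter
initialization. Note the kernels stack all four gates along the
output channel axis (4 gates * 4 hidden channels = 16):

    input = Nx.broadcast(0.5, {2, 1, 3, 8, 8})
    cell = Nx.broadcast(0.0, {2, 1, 4, 8, 8})
    hidden = Nx.broadcast(0.0, {2, 1, 4, 8, 8})
    ih = Nx.broadcast(0.1, {16, 3, 3, 3})
    hh = Nx.broadcast(0.1, {16, 4, 3, 3})
    b = Nx.broadcast(0.0, {16})

    {{new_cell, new_hidden}, output} =
      Axon.Recurrent.conv_lstm_cell(input, {cell, hidden}, {ih}, {hh}, {b})
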
## References

* [Convolutional LSTM Network: A Machine Learning Approach for Precipitation Nowcasting](https://arxiv.org/abs/1506.04214)
"""
defn conv_lstm_cell(input, carry, input_kernel, hidden_kernel, bias, opts \\ []) do
opts = keyword!(opts, strides: 1, padding: :same)
{ih} = input_kernel
{hh} = hidden_kernel
{bi} = bias
{{cell, hidden}, input} = rank_down({carry, input})
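# Convolutions over the input and hidden state compute all four
# gates at once, stacked along the channel axis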
gates =
Nx.add(
conv(input, ih, bi, strides: opts[:strides], padding: opts[:padding]),
conv(hidden, hh, 0, strides: opts[:strides], padding: opts[:padding])
)
{i, g, f, o} = split_gates(gates)
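# The +1 acts as a unit forget-gate bias, encouraging the cell to
# retain information early in training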
f = sigmoid(f + 1)
new_c = f * cell + sigmoid(i) * tanh(g)
new_h = sigmoid(o) * tanh(new_c)
rank_up({{new_c, new_h}, new_h})
end
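# Splits the stacked gate channels into four equal slices along the
# channel axis, in (input, cell, forget, output) order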
defnp split_gates(gates) do
transform(gates, fn gates ->
channels = elem(Nx.shape(gates), 1)
split_every = div(channels, 4)
split_dims =
for i <- 0..3 do
{i * split_every, split_every}
end
split_dims
|> Enum.map(fn {start, len} -> Nx.slice_along_axis(gates, start, len, axis: 1) end)
|> List.to_tuple()
end)
end
@doc """
Dynamically unrolls an RNN.

Unrolls implement a `scan` operation which applies a
transformation along the time axis of `input_sequence`, carrying
some state. In this instance `cell_fn` is an RNN cell function
such as `lstm_cell` or `gru_cell`.

This function makes use of a `defn` while-loop and thus may be
more efficient for long sequences.
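## Examples

A sketch with illustrative shapes (batch of 4, 10 time steps, 8
input features, 16 hidden units) and constant tensors standing in
for real parameter initialization:

    sequence = Nx.broadcast(0.5, {4, 10, 8})
    carry = {Nx.broadcast(0.0, {4, 1, 16})}
    k = Nx.broadcast(0.1, {8, 16})
    h = Nx.broadcast(0.1, {16, 16})
    b = Nx.broadcast(0.0, {16})

    # `outputs` below has shape {4, 10, 16}
    {final_carry, outputs} =
      Axon.Recurrent.dynamic_unroll(
        &Axon.Recurrent.gru_cell/5,
        sequence,
        carry,
        {k, k, k},
        {h, h, h},
        {b, b, b, b}
      )
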
"""
defn dynamic_unroll(cell_fn, input_sequence, carry, input_kernel, recurrent_kernel, bias) do
time_steps = transform(Nx.shape(input_sequence), &elem(&1, 1))
feature_dims = transform(Nx.rank(input_sequence), &List.duplicate(0, &1 - 2))
initial_shape =
transform({Nx.shape(input_sequence), Nx.shape(elem(input_kernel, 0))}, fn {shape, kernel} ->
put_elem(shape, 2, elem(kernel, 1))
end)
init_sequence = Nx.broadcast(0.0, initial_shape)
i = Nx.tensor(0)
{_, carry, output, _, _, _, _} =
while {i, carry, init_sequence, input_sequence, input_kernel, recurrent_kernel, bias},
Nx.less(i, time_steps) do
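# Slice out time step `i`, apply the cell, and write its output
# into the accumulated output sequence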
sequence = Nx.slice_along_axis(input_sequence, i, 1, axis: 1)
indices = transform({feature_dims, i}, fn {feature_dims, i} -> [0, i] ++ feature_dims end)
{carry, output} = cell_fn.(sequence, carry, input_kernel, recurrent_kernel, bias)
update_sequence = Nx.put_slice(init_sequence, indices, output)
{i + 1, carry, update_sequence, input_sequence, input_kernel, recurrent_kernel, bias}
end
{carry, output}
end
@doc """
Statically unrolls an RNN.

Unrolls implement a `scan` operation which applies a
transformation along the time axis of `input_sequence`, carrying
some state. In this instance `cell_fn` is an RNN cell function
such as `lstm_cell` or `gru_cell`.

This function inlines the unrolling of the sequence such that
the entire operation appears as a part of the compilation graph.
This makes it suitable for shorter sequences.
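## Examples

A sketch with illustrative shapes (batch of 4, 10 time steps, 8
input features, 16 hidden units) and constant tensors standing in
for real parameter initialization:

    sequence = Nx.broadcast(0.5, {4, 10, 8})
    carry = {Nx.broadcast(0.0, {4, 1, 16}), Nx.broadcast(0.0, {4, 1, 16})}
    k = Nx.broadcast(0.1, {8, 16})
    h = Nx.broadcast(0.1, {16, 16})
    b = Nx.broadcast(0.0, {16})

    # `outputs` below has shape {4, 10, 16}
    {final_carry, outputs} =
      Axon.Recurrent.static_unroll(
        &Axon.Recurrent.lstm_cell/5,
        sequence,
        carry,
        {k, k, k, k},
        {h, h, h, h},
        {b, b, b, b}
      )
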
"""
defn static_unroll(cell_fn, input_sequence, carry, input_kernel, recurrent_kernel, bias) do
transform(
{cell_fn, input_sequence, carry, input_kernel, recurrent_kernel, bias},
fn {cell_fn, input_sequence, carry, input_kernel, recurrent_kernel, bias} ->
time_steps = elem(Nx.shape(input_sequence), 1)
{carry, outputs} =
for t <- 0..(time_steps - 1), reduce: {carry, []} do
{carry, outputs} ->
input = Nx.slice_along_axis(input_sequence, t, 1, axis: 1)
{carry, output} = cell_fn.(input, carry, input_kernel, recurrent_kernel, bias)
{carry, [output | outputs]}
end
{carry, Nx.concatenate(Enum.reverse(outputs), axis: 1)}
end
)
end
end