defprotocol Nx.Container do
  @moduledoc """
  A protocol that teaches Nx how to traverse data structures
  non-recursively.

  `Nx` and `defn` expect the arguments to be numbers, tensors,
  or one of the following composite data types:

    1. tuples of numbers/tensors
    2. maps of any key with numbers/tensors as values
    3. any struct that implements `Nx.Container`

  If you need to pass additional values, you can implement
  or derive this protocol. For example:

      @derive {Nx.Container,
               containers: [:field_name, :other_field]}
      defstruct [:field_name, :other_field, ...]

  The `:containers` option is required and it must specify a
  list of fields that contain tensors. Inside `defn`, the
  container fields will be automatically converted to tensor
  expressions. All other fields will be reset to their default
  value, unless you explicitly declare them to be kept:

      @derive {Nx.Container,
               containers: [:field_name, :other_field],
               keep: [:another_field]}
      defstruct [:field_name, :other_field, :another_field, ...]

  Note the functions in this module are not recursive.
  If you want to deeply traverse and reduce containers,
  use the functions in `Nx.Defn.Composite` instead.

  > **Careful!**: If you keep a field, its value will be part
  > of the `Nx.Defn` compiler cache key (i.e. if you give a
  > struct with two different values for a kept field,
  > `Nx.Defn` will have to compile and cache it twice).
  > You must only keep fields that you are certain to be used
  > inside `defn` during compilation time.
  """

  @doc """
  Traverses non-recursively tensors in a data structure with `acc` and `fun`.

  `fun` is invoked with each tensor or tensor container in the
  data structure plus an accumulator. It must return a two element
  tuple with the updated value and accumulator.

  This function returns the updated container and the accumulator.

  Given `fun` may receive containers, it is not recursive by default.
  See `Nx.Defn.Composite.traverse/3` for a recursive variant.
  """
  # The updated value is whatever `fun` hands back for each element and the
  # result is the rebuilt container, so both are typed as `t()` rather than
  # being narrowed to a single tensor.
  @spec traverse(t(), acc, (t(), acc -> {t(), acc})) :: {t(), acc}
        when acc: term()
  def traverse(data, acc, fun)

  @doc """
  Reduces non-recursively tensors in a data structure with `acc` and `fun`.

  `fun` is invoked with each tensor or tensor container in the
  data structure plus an accumulator. It must return the new
  accumulator.

  This function returns the final accumulator.

  Given `fun` may receive containers, it is not recursive by default.
  See `Nx.Defn.Composite.reduce/3` for a recursive variant.
  """
  @spec reduce(t(), acc, (t(), acc -> acc)) :: acc when acc: term()
  def reduce(data, acc, fun)
end
defimpl Nx.Container, for: Tuple do
  # Visits the tuple's elements from first to last, calling `fun` on each one
  # with the running accumulator, and rebuilds a tuple from the values `fun`
  # returns. Returns `{new_tuple, final_acc}`.
  def traverse(tuple, acc, fun) do
    {elements, acc} = Enum.map_reduce(Tuple.to_list(tuple), acc, fun)
    {List.to_tuple(elements), acc}
  end

  # Folds `fun` over the tuple's elements from first to last and returns the
  # final accumulator.
  def reduce(tuple, acc, fun) do
    Enum.reduce(Tuple.to_list(tuple), acc, fun)
  end
end
defimpl Nx.Container, for: Map do
  # Calls `fun` on every value of the map, threading the accumulator, and
  # returns `{new_map, final_acc}` where each key maps to the value returned
  # by `fun`. The `{key, value}` pairs are sorted with `Enum.sort/1` first, so
  # values are visited in that sorted order.
  def traverse(map, acc, fun) do
    {pairs, acc} =
      map
      |> Map.to_list()
      |> Enum.sort()
      |> Enum.map_reduce(acc, fn {key, value}, acc ->
        {value, acc} = fun.(value, acc)
        {{key, value}, acc}
      end)

    {Map.new(pairs), acc}
  end

  # Folds `fun` over the map's values, visiting them in the same sorted order
  # as `traverse/3`, and returns the final accumulator. Keys are ignored.
  def reduce(map, acc, fun) do
    sorted_pairs = map |> Map.to_list() |> Enum.sort()

    for {_key, value} <- sorted_pairs, reduce: acc do
      acc -> fun.(value, acc)
    end
  end
end
defimpl Nx.Container, for: [Integer, Float, Complex, Nx.Tensor] do
  # Numbers and tensors have nothing nested to visit, so both callbacks hand
  # the value itself to `fun` exactly once.

  # `fun` returns `{updated_value, acc}`, which is exactly the shape
  # `traverse/3` must return, so its result is passed through unchanged.
  # Wrapping it again would discard the updated value and turn the
  # accumulator into a tuple.
  def traverse(tensor, acc, fun), do: fun.(tensor, acc)

  # `fun` returns the new accumulator, which is the result of `reduce/3`.
  def reduce(tensor, acc, fun), do: fun.(tensor, acc)
end
defimpl Nx.Container, for: Any do
  # Builds the `Nx.Container` implementation for `module` when the protocol is
  # derived with `@derive {Nx.Container, containers: [...], keep: [...]}`.
  #
  # `:containers` is required (`Keyword.fetch!/2` raises when it is missing)
  # and `:keep` defaults to `[]`. In the generated code, container fields are
  # passed to `fun`, kept fields are carried over as-is, and every other field
  # is rebuilt from the value found in `struct`.
  defmacro __deriving__(module, struct, options) do
    containers = Keyword.fetch!(options, :containers)
    keep = Keyword.get(options, :keep, [])

    # Validate each listed field and pair it with the variable that stands for
    # it in the generated function heads.
    container_pattern = for field <- containers, do: field_var(struct, field)
    keep_pattern = for field <- keep, do: field_var(struct, field)
    full_pattern = container_pattern ++ keep_pattern

    # traverse/3 body: rebind each container variable to the value returned by
    # `fun`, threading `acc` from one field to the next.
    updates =
      Enum.map(containers, fn field ->
        var = Macro.var(field, Nx.Container)
        quote(do: {unquote(var), acc} = fun.(unquote(var), acc))
      end)

    # reduce/3 body: only the accumulator is rebound.
    reduces =
      Enum.map(containers, fn field ->
        var = Macro.var(field, Nx.Container)
        quote(do: acc = fun.(unquote(var), acc))
      end)

    # Fields that are neither containers nor kept are emitted as escaped
    # literals taken from `struct`; the remaining keys point at the variables.
    defaults =
      struct
      |> Map.to_list()
      |> Keyword.drop(keep ++ containers)
      |> Macro.escape()

    return = Keyword.merge(defaults, full_pattern)

    quote do
      defimpl Nx.Container, for: unquote(module) do
        def traverse(%{unquote_splicing(full_pattern)} = struct, acc, fun) do
          unquote_splicing(updates)
          {%{unquote_splicing(return)}, acc}
        end

        def reduce(%{unquote_splicing(container_pattern)} = struct, acc, fun) do
          unquote_splicing(reduces)
          acc
        end
      end
    end
  end

  # Returns `{field, variable}` for use in a map pattern or literal. Raises
  # `ArgumentError` when `struct` has no key named `field`.
  defp field_var(struct, field) do
    if Map.has_key?(struct, field) do
      {field, Macro.var(field, Nx.Container)}
    else
      raise ArgumentError,
            "cannot derive Nx.Container for struct #{inspect(struct.__struct__)} " <>
              "because it does not have field #{inspect(field)}"
    end
  end

  # The `Any` implementation itself never traverses anything: both callbacks
  # raise `Protocol.UndefinedError` for the given value.
  def traverse(data, _acc, _fun) do
    raise Protocol.UndefinedError, protocol: @protocol, value: data
  end

  def reduce(data, _acc, _fun) do
    raise Protocol.UndefinedError, protocol: @protocol, value: data
  end
end