defmodule Nx.Batch do
@moduledoc """
Creates a batch of tensors (and containers).
A batch is lazily traversed, concatenated, and padded upon `defn` invocation.
"""
@axis 0
@doc """
A Nx.Batch struct.
The `:size` field is public.
"""
@derive {Inspect, only: [:size, :pad]}
defstruct stack: [], size: 0, template: nil, pad: 0
@type t :: %Nx.Batch{
stack: list(),
size: non_neg_integer(),
template: Nx.Container.t() | Nx.Tensor.t() | nil,
pad: non_neg_integer()
}
@doc """
Returns a new empty batch.
"""
def new, do: %Nx.Batch{}
@doc """
Merges two batches.
The tensors on the left will appear before the tensors on the right.
The size and padding of both batches are summed. The padding still
applies only at the end of batch.
It will raise if the batch templates are incompatible.
## Examples
iex> batch1 = Nx.Batch.stack([Nx.tensor(1), Nx.tensor(2), Nx.tensor(3)])
iex> batch2 = Nx.Batch.concatenate([Nx.tensor([4, 5]), Nx.tensor([6, 7, 8])])
iex> batch = Nx.Batch.merge(batch1, batch2)
iex> batch.size
8
iex> Nx.Defn.jit_apply(&Function.identity/1, [batch])
#Nx.Tensor<
s64[8]
[1, 2, 3, 4, 5, 6, 7, 8]
>
"""
def merge(left, right), do: merge([left, right])
@doc """
Merges a list of batches.
See `merge/2`.
"""
def merge([]), do: new()
def merge([%Nx.Batch{} = head | tail]) do
%{template: template, stack: stack, pad: pad, size: size} = head
{template, stack, pad, size} =
Enum.reduce(tail, {template, stack, pad, size}, fn batch, acc ->
%Nx.Batch{template: template, stack: stack, pad: pad, size: size} = batch
{acc_template, acc_stack, acc_pad, acc_size} = acc
if template != nil and acc_template != nil and not Nx.compatible?(template, acc_template) do
raise ArgumentError, """
cannot merge batches due to incompatible templates:
#{inspect(template)}
and:
#{inspect(acc_template)}
"""
end
{acc_template || template, stack ++ acc_stack, pad + acc_pad, size + acc_size}
end)
%Nx.Batch{template: template, stack: stack, pad: pad, size: size}
end
@doc """
Splits a batch in two, where the first one has at most `n` elements.
If there is any padding and the batch is not full, the amount of padding
necessary will be moved to the first batch and the remaining stays in the
second batch.
## Examples
iex> batch = Nx.Batch.concatenate([Nx.tensor([1, 2]), Nx.tensor([3, 4, 5])])
iex> {left, right} = Nx.Defn.jit_apply(&Function.identity/1, [Nx.Batch.split(batch, 3)])
iex> left
#Nx.Tensor<
s64[3]
[1, 2, 3]
>
iex> right
#Nx.Tensor<
s64[2]
[4, 5]
>
"""
def split(%Nx.Batch{} = batch, n) when is_integer(n) and n > 0 do
%{template: template, stack: stack, pad: pad, size: size} = batch
if n < size do
{left, right} = drop_split(stack, size - n, [])
{%{batch | stack: left, size: n, pad: 0},
%Nx.Batch{template: template, pad: pad, size: size - n, stack: right}}
else
right_pad = max(size + pad - n, 0)
left_pad = pad - right_pad
{%{batch | pad: left_pad}, %Nx.Batch{template: template, pad: right_pad}}
end
end
defp drop_split([{funs, size} | stack], n, acc) when size < n do
drop_split(stack, n - size, [{funs, size} | acc])
end
defp drop_split([{funs, size} | stack], n, acc) when size == n do
{stack, Enum.reverse([{funs, size} | acc])}
end
defp drop_split([{funs, size} | stack], n, acc) when size > n do
left_start = 0
left_size = size - n
left_funs =
Enum.map(funs, fn fun ->
fn -> Nx.slice_along_axis(fun.(), left_start, left_size, axis: @axis) end
end)
right_start = size - n
right_size = n
right_funs =
Enum.map(funs, fn fun ->
fn -> Nx.slice_along_axis(fun.(), right_start, right_size, axis: @axis) end
end)
{[{left_funs, left_size} | stack], Enum.reverse([{right_funs, right_size} | acc])}
end
@doc """
Configures the batch with the given padding.
The batch will be padded when consumed:
iex> batch = Nx.Batch.stack([Nx.tensor(1), Nx.tensor(2), Nx.tensor(3)])
iex> Nx.Defn.jit_apply(&Function.identity/1, [Nx.Batch.pad(batch, 2)])
#Nx.Tensor<
s64[5]
[1, 2, 3, 0, 0]
>
"""
def pad(%Nx.Batch{} = batch, pad) when is_integer(pad) and pad >= 0 do
%{batch | pad: pad}
end
@doc """
Concatenates the given entries to the batch.
Entries are concatenated based on their first axis.
If the first axis has multiple entries, each entry
is added to the size of the batch.
You can either concatenate to an existing batch
or skip the batch argument to create a new batch.
See `stack/2` if you want to stack entries instead
of concatenating them.
## Examples
If no batch is given, one is automatically created:
iex> batch = Nx.Batch.concatenate([Nx.tensor([1]), Nx.tensor([2]), Nx.tensor([3])])
iex> Nx.Defn.jit_apply(&Function.identity/1, [batch])
#Nx.Tensor<
s64[3]
[1, 2, 3]
>
But you can also concatenate to existing batches:
iex> batch = Nx.Batch.concatenate([Nx.tensor([1]), Nx.tensor([2])])
iex> batch = Nx.Batch.concatenate(batch, [Nx.tensor([3]), Nx.tensor([4])])
iex> Nx.Defn.jit_apply(&Function.identity/1, [batch])
#Nx.Tensor<
s64[4]
[1, 2, 3, 4]
>
If the first axis has multiple entries, each entry counts
towards the size of the batch:
iex> batch = Nx.Batch.concatenate([Nx.tensor([1, 2]), Nx.tensor([3, 4, 5])])
iex> batch.size
5
iex> Nx.Defn.jit_apply(&Function.identity/1, [batch])
#Nx.Tensor<
s64[5]
[1, 2, 3, 4, 5]
>
What makes batches powerful is that they can concatenate
across containers:
iex> container1 = {Nx.tensor([11]), Nx.tensor([21])}
iex> container2 = {Nx.tensor([12]), Nx.tensor([22])}
iex> batch = Nx.Batch.concatenate([container1, container2])
iex> {batched1, batched2} = Nx.Defn.jit_apply(&Function.identity/1, [batch])
iex> batched1
#Nx.Tensor<
s64[2]
[11, 12]
>
iex> batched2
#Nx.Tensor<
s64[2]
[21, 22]
>
"""
def concatenate(%Nx.Batch{} = batch \\ new(), entries) when is_list(entries),
do: add(batch, entries, false)
@doc """
Stacks the given entries to the batch.
Each entry counts exactly as a single entry.
You can either stack to an existing batch
or skip the batch argument to create a new batch.
See `concatenate/2` if you want to concatenate entries
instead of stacking them.
## Examples
If no batch is given, one is automatically created:
iex> batch = Nx.Batch.stack([Nx.tensor(1), Nx.tensor(2), Nx.tensor(3)])
iex> batch.size
3
iex> Nx.Defn.jit_apply(&Function.identity/1, [batch])
#Nx.Tensor<
s64[3]
[1, 2, 3]
>
But you can also stack an existing batch:
iex> batch = Nx.Batch.stack([Nx.tensor(1), Nx.tensor(2)])
iex> batch = Nx.Batch.stack(batch, [Nx.tensor(3), Nx.tensor(4)])
iex> batch.size
4
iex> Nx.Defn.jit_apply(&Function.identity/1, [batch])
#Nx.Tensor<
s64[4]
[1, 2, 3, 4]
>
What makes batches powerful is that they can concatenate
across containers:
iex> container1 = {Nx.tensor(11), Nx.tensor(21)}
iex> container2 = {Nx.tensor(12), Nx.tensor(22)}
iex> batch = Nx.Batch.stack([container1, container2])
iex> {batched1, batched2} = Nx.Defn.jit_apply(&Function.identity/1, [batch])
iex> batched1
#Nx.Tensor<
s64[2]
[11, 12]
>
iex> batched2
#Nx.Tensor<
s64[2]
[21, 22]
>
"""
def stack(%Nx.Batch{} = batch \\ new(), entries) when is_list(entries),
do: add(batch, entries, true)
defp add(batch, [], _new_axis?), do: batch
defp add(batch, [head | tail], new_axis?) do
%{template: template, stack: stack, size: size} = batch
{head_template, head_size, head_funs} = traverse(head, new_axis?)
acc = {head_size + size, [{head_funs, head_size} | stack]}
{size, stack} =
Enum.reduce(tail, acc, fn arg, {acc_size, acc_stack} ->
{arg_template, size, arg_funs} = traverse(arg, new_axis?)
unless Nx.compatible?(arg_template, head_template) do
raise ArgumentError, """
cannot add to batch due to incompatible tensors/containers.
The head of the list has shape:
#{inspect(head_template)}
But another list element has template:
#{inspect(arg_template)}
From entry:
#{inspect(arg)}
"""
end
{size + acc_size, [{arg_funs, size} | acc_stack]}
end)
if template == nil or Nx.compatible?(template, head_template) do
%{batch | template: head_template, stack: stack, size: size}
else
raise ArgumentError, """
cannot add to batch due to incompatible tensors/containers.
The batch has shape:
#{inspect(template)}
But then the head of the list has template:
#{inspect(head_template)}
From entry:
#{inspect(head)}
"""
end
end
defp traverse(container, true) do
{template, funs} =
Nx.LazyContainer.traverse(container, [], fn template, fun, acc ->
{template, [fn -> Nx.new_axis(fun.(), @axis) end | acc]}
end)
{template, 1, funs}
end
defp traverse(container, false) do
{template, {size, funs}} =
Nx.LazyContainer.traverse(container, {nil, []}, fn template, fun, {acc_size, acc_funs} ->
%Nx.Tensor{shape: shape, names: names} = template
if shape == {} do
raise ArgumentError, "cannot concatenate scalar tensor in #{inspect(container)}"
end
size = elem(shape, @axis)
if acc_size != nil and size != acc_size do
raise ArgumentError,
"concatenate expects all tensors in the same container to have the same value " <>
"for first axis, got #{size} and #{acc_size} in #{inspect(container)}"
end
template = %{template | shape: Tuple.delete_at(shape, @axis), names: tl(names)}
{template, {size, [fun | acc_funs]}}
end)
if size == nil do
raise ArgumentError, "cannot have an empty container in concatenate: #{inspect(container)}"
end
{template, size, funs}
end
end
defimpl Nx.LazyContainer, for: Nx.Batch do
@axis 0
def traverse(%{stack: []}, _acc, _acc_fun) do
raise ArgumentError, "cannot traverse/jit/compile Nx.Batch without entries"
end
def traverse(%{stack: funs_size, pad: pad, template: template, size: size}, acc, acc_fun) do
total = size + pad
funs =
funs_size
|> first_reverse([])
|> Enum.zip_with(fn funs ->
fn ->
funs
|> apply_each()
|> Nx.concatenate(axis: @axis)
|> maybe_pad(pad)
end
end)
|> Enum.reverse()
{template, {acc, []}} =
Nx.Defn.Composite.traverse(template, {acc, funs}, fn template, {acc, [fun | funs]} ->
%{shape: shape, names: names} = template
template = %{template | shape: Tuple.insert_at(shape, 0, total), names: [nil | names]}
{template, acc} = acc_fun.(template, fun, acc)
{template, {acc, funs}}
end)
{template, acc}
end
defp first_reverse([{fun, _} | funs], acc), do: first_reverse(funs, [fun | acc])
defp first_reverse([], acc), do: acc
defp apply_each([fun | funs]), do: [fun.() | apply_each(funs)]
defp apply_each([]), do: []
defp maybe_pad(tensor, 0), do: tensor
defp maybe_pad(tensor, pad_size) do
padding =
{0, 0, 0}
|> List.duplicate(Nx.rank(tensor))
|> List.replace_at(@axis, {0, pad_size, 0})
Nx.pad(tensor, Nx.tensor(0, type: Nx.type(tensor)), padding)
end
end