defmodule Nx.Defn.Expr do
@doc """
The expression used by `Nx.Defn.Compiler`.
`Nx.Defn.Compiler` changes `Nx` default backend from `Nx.BinaryBackend`
to `Nx.Defn.Expr`. It is a struct with the following fields:
* `:id` - a unique identifier
* `:op` - the operation name
* `:args` - the operation arguments
* `:context` - the context of the expression.
The default context is `:root`.
Convenience functions for traversing expressions and composite types
can be found in `Nx.Defn.Composite` and `Nx.Defn.Tree`.
## Syntax nodes
Most nodes are created directly via the `Nx` module and
therefore map directly to `Nx.Tensor` callbacks. However
the following syntax nodes exist:
* `parameter(integer)`
* `constant(number)`
* `tensor(tensor)`
* `metadata(expr, metadata)`
* `elem(tuple, pos)` - created automatically from
expression that return tuples. Note it may return
tuples too, which means we have nested tuples
* `fun(parameters, t, mfa)` - the `mfa` is used only for
introspection purposes
* `cond(clauses, otherwise)`
* `while(initial, condition, body)`
* `attach_token(token(%Nx.Defn.Token{}), expr)`
`defn` compilers must handle said nodes accordingly.
"""
alias Nx.Defn.{Composite, Expr, Tree}
alias Nx.Tensor, as: T
import Nx.Shared
# Every expression node carries a unique id, the operation name, the
# operation arguments, and the compilation context it was built under.
@enforce_keys [:id, :op, :args, :context]
defstruct [:id, :op, :args, :context]

## Public API
@doc """
Builds an tensor expression from the given tensor.
"""
def tensor(tensor), do: to_expr(tensor)
@doc """
Creates a tensor expression parameter at `pos` based on the given tensor expression.
"""
def parameter(%T{data: %Expr{context: context}} = tensor, pos) do
parameter(tensor, context, pos)
end
@doc """
Creates a tensor expression parameter at `pos` based on the given `tensor` and `context`.
"""
def parameter(tensor, context, pos) when is_integer(pos) and pos >= 0 do
expr(tensor, context, :parameter, [pos])
end
@doc """
Creates a tensor expression parameter at `pos` with the given `context`, `type`,
`shape`, and `pos`.
"""
def parameter(context, type, shape, pos) do
names = List.duplicate(nil, tuple_size(shape))
expr(%T{type: type, shape: shape, names: names}, context, :parameter, [pos])
end
@doc """
Creates a tensor expression metadata node wrapping the
given tensor expression.
The metadata is map. If the `inspect` key is present,
it will be used to annotate the metadata when inspected.
Otherwise the metadata node does not appear during
inspection.
"""
def metadata(expr, metadata) when is_map(metadata) do
case to_container_expr(expr) do
%{data: %{context: context}} = res ->
expr(res, context, :metadata, [expr, metadata])
t when is_tuple(t) ->
context = elem(t, 0).data.context
tuple(
expr(tuple_out(tuple_size(t)), context, :metadata, [expr, metadata]),
Tuple.to_list(t)
)
end
end
@doc """
Creates a tuple with elements in `list` that points to tuple
expression `expr`.
`list` must be a list of tensor expressions of the same size
as the tuple expression.
"""
def tuple(%T{type: {:tuple, size}, data: %{context: context}} = expr, list)
when is_list(list) do
tuple =
list
|> Enum.with_index(fn %T{} = tensor, i ->
expr(tensor, context, :elem, [expr, i])
end)
|> List.to_tuple()
^size = tuple_size(tuple)
tuple
end
@doc """
Creates a `cond` tensor expression.
"""
def cond([], last) do
last
end
def cond(clauses, last = out) do
{preds, exprs} = Enum.unzip(clauses)
{preds, context} = to_exprs(preds)
[last | exprs] =
[last | exprs]
|> Enum.map(&Composite.flatten_list([&1]))
|> Enum.zip_with(&broadcast_clause/1)
|> case do
# Handle the case where branches don't return anything
[] -> Enum.map([last | exprs], fn _ -> {} end)
clauses -> unzip_clauses(clauses)
end
clauses = Enum.zip(preds, exprs)
flatten_to_composite(out, context, exprs, &expr(&1, context, :cond, [clauses, last]))
end
# Unifies one "column" of branch results: folds the common numeric type
# and broadcasted shape/names over all branches, then casts/broadcasts
# every branch result to that common template.
defp broadcast_clause([type = last | exprs]) do
  %{shape: shape, names: names} = last = to_expr(last)

  {exprs, {type, shape, names}} =
    Enum.map_reduce(exprs, {type, shape, names}, fn expr, {type, shape, names} ->
      type = binary_type(type, expr)
      expr = to_expr(expr)
      {shape, names} = Nx.Shape.binary_broadcast(shape, names, expr.shape, expr.names)
      {expr, {type, shape, names}}
    end)

  for expr <- [last | exprs] do
    expr
    |> Nx.as_type(type)
    |> Nx.broadcast(shape, names: names)
  end
end
# Transposes a list of equally-sized clause lists back into one entry
# per clause: single-element columns stay scalars, larger ones become
# tuples (reversed to restore the accumulation order).
defp unzip_clauses([exprs | _] = clauses),
  do: unzip_clauses(clauses, List.duplicate([], length(exprs)))

defp unzip_clauses([exprs | tail], acc),
  do: unzip_clauses(tail, unzip_each(exprs, acc))

defp unzip_clauses([], acc) do
  Enum.map(acc, fn
    [entry] -> entry
    list -> List.to_tuple(Enum.reverse(list))
  end)
end
# Prepends each element onto its positional accumulator bucket.
defp unzip_each([value | values], [bucket | buckets]),
  do: [[value | bucket] | unzip_each(values, buckets)]

defp unzip_each([], []),
  do: []
@doc """
Creates a `while` tensor expression.
"""
def while(initial, context, arg, condition, body) do
[flatten_initial, flatten_arg, flatten_body] = clauses = flatten_clauses([initial, arg, body])
args = [flatten_initial, flatten_arg, condition, flatten_body]
flatten_to_composite(initial, context, clauses, &expr(&1, context, :while, args))
end
# Flattens each composite clause into a single expression or, when it
# has several leaves, a tuple of them.
defp flatten_clauses(clauses) do
  for clause <- clauses do
    case Composite.flatten_list([clause]) do
      [only] -> only
      many -> List.to_tuple(many)
    end
  end
end
# Rebuilds a composite result from a flattened expression. When the
# flattened head is a tuple, each composite leaf becomes an :elem
# access into the tuple expression produced by `fun`; otherwise `fun`
# produces the single resulting expression directly.
defp flatten_to_composite(out, context, [head | _], fun) when is_tuple(head) do
  size = tuple_size(head)
  expr = fun.(tuple_out(size))

  # Assert every tuple element was consumed by the traversal.
  {out, {[], ^size}} =
    Composite.traverse(out, {Tuple.to_list(head), 0}, fn _, {[head | tail], i} ->
      {expr(head, context, :elem, [expr, i]), {tail, i + 1}}
    end)

  out
end

defp flatten_to_composite(out, _context, [head | _], fun) do
  {out, []} = Composite.traverse(out, [fun.(head)], fn _, [head | tail] -> {head, tail} end)
  out
end
@impl true
# Builds an :optional node: it records both the high-level call
# (name/args) and the result of the default defn implementation so
# compilers can pick a native implementation or fall back.
def optional(name, args, fun) do
  # Trailing keyword lists are options; only the leading tensor
  # arguments become parameters of the default implementation.
  {args, opts} = Enum.split_while(args, &(not is_list(&1)))
  params = Enum.with_index(args, &parameter/2)

  case apply(fun, params ++ opts) do
    %{data: %{context: context}} = res ->
      expr(res, context, :optional, [expr(res, context, name, args), res])

    t when is_tuple(t) ->
      context = elem(t, 0).data.context
      out = expr(tuple_out(tuple_size(t)), context, name, args)
      tuple(expr(out, context, :optional, [out, t]), Tuple.to_list(t))
  end
end
## Nx.Defn AST callbacks

@doc false
# Node identifiers are plain, globally-unique references.
def id() do
  make_ref()
end
@doc false
# Registers a hook on the token and returns the updated token together
# with the expression the hook observes.
def add_hook(token, expr, name, function) do
  container = to_container_expr(expr)
  {Nx.Defn.Token.add_hook(token, container, name, function), container}
end
@doc false
# Attaches a token expression to every leaf of `expr`, forcing the
# hooks carried by the token into the expression graph.
def attach_token(%T{data: %{op: :token}} = token, expr) do
  Composite.traverse(expr, fn tensor ->
    expr = to_expr(tensor)
    expr(expr, expr.data.context, :attach_token, [token, expr])
  end)
end

def attach_token(%Nx.Defn.Token{} = token, expr) do
  # We first create an expression to store the token
  # so we have a shared ID to avoid multiple traversals.
  # The size of the tuple is not used, but the amount of
  # hooks is a good indicator.
  size = length(token.hooks)
  token = expr(%T{shape: {}, type: {:tuple, size}, names: []}, nil, :token, [token])
  attach_token(token, expr)
end
@doc false
# Builds a `cond` from the defn AST. Constant predicates are resolved
# at build time: always-false clauses are dropped, an always-true
# clause short-circuits, and the final clause must be provably true.
# Branch bodies are thunks so unused branches are never evaluated.
def defn_cond(file, [{meta, _} | _] = clauses) do
  clauses =
    for {meta, {pred, expr}} <- clauses,
        pred = to_pred(pred, meta[:line], file, :cond),
        # Eliminate all clauses that will never match
        not match?(%T{data: %Expr{op: :constant, args: [number]}} when number == 0, pred) do
      {meta, pred, expr}
    end

  case clauses do
    # At least one clause is expected
    [] ->
      raise CompileError,
        line: meta[:line],
        file: file,
        description: "cond/if expects at least one branch to always evaluate to true"

    # We found a clause that always matches, return it always
    [{_meta, %T{data: %Expr{op: :constant, args: [number]}}, expr} | _] when number != 0 ->
      expr.()

    # Otherwise, keep it as a cond and validate the last clause always returns true
    [{_, first_pred, first} | rest] ->
      first = first.()

      # All branches must return containers compatible with the first
      # branch, which serves as the template.
      [{last_pred, last} | reverse] =
        Enum.reduce(rest, [{first_pred, first}], fn {meta, pred, expr}, acc ->
          expr = expr.()

          if not Nx.Defn.Composite.compatible?(first, expr, fn _, _ -> true end) do
            raise CompileError,
              line: meta[:line],
              file: file,
              description: """
              cond/if expects all branches to return compatible tensor types.
              Got mismatching templates:
              #{inspect_as_template(first)}
              and
              #{inspect_as_template(expr)}
              """
          end

          [{pred, expr} | acc]
        end)

      case last_pred do
        %T{data: %Expr{op: :constant, args: [number]}} when number != 0 ->
          cond(Enum.reverse(reverse), last)

        _ ->
          raise CompileError,
            line: meta[:line],
            file: file,
            description: "cond/if expects at least one branch to always evaluate to true"
      end
  end
end
@doc false
# Builds a `while` from the defn AST. Without a generator the loop runs
# on an explicit condition; with a generator (tensor or range) the body
# is driven element by element through `while_range/7`.
def defn_while(file, line, initial, generator, condition_body, opts) do
  initial = to_container_expr(initial)

  case generator do
    :none ->
      if opts != [] do
        raise CompileError,
          line: line,
          file: file,
          description:
            "options are not supported on while with conditions, only with generators, got: #{inspect(opts)}"
      end

      {arg, context} = to_param_expr(initial, :while)
      {condition, body} = condition_body.(arg)
      condition = to_pred(condition, line, file, :while)
      body = to_container_expr(body)
      compatible_while!(file, line, initial, body)
      while(initial, context, arg, condition, body)

    {:while, %T{} = generator} ->
      # Tensor generators iterate over their leading axis.
      range =
        case Nx.shape(generator) do
          {} ->
            message = "cannot have a scalar tensor as generator"
            raise CompileError, line: line, file: file, description: message

          _ ->
            0..(Nx.axis_size(generator, 0) - 1)//1
        end

      condition_body = fn {{index, generator_expr}, acc} ->
        condition_body.({generator_expr[index], acc})
      end

      while_range(range, file, line, initial, generator, condition_body, opts)

    {:while, _.._//_ = range} ->
      condition_body = fn {{index, {}}, acc} ->
        condition_body.({index, acc})
      end

      while_range(range, file, line, initial, {}, condition_body, opts)

    {:while, other} ->
      raise CompileError,
        line: line,
        file: file,
        description: "generators in while expect a range or a tensor, got: #{inspect(other)}"
  end
end
# Compiles a generator-driven while over `range`. The :unroll option is
# either a boolean or a positive integer: `internal` is the part that
# runs inside a while expression (unrolled `unroll` steps per
# iteration) and `external` is the tail fully unrolled outside the
# loop. Fix: the invalid-:unroll message previously claimed the option
# "must be a boolean", although integers are accepted as well.
defp while_range(range, file, line, initial, generator, condition_body, opts) do
  opts = Keyword.validate!(opts, unroll: false)
  size = Range.size(range)

  {internal, external} =
    case opts[:unroll] do
      true ->
        {nil, range}

      false ->
        {{range, 0..0//1}, 0..-1//1}

      unroll when is_integer(unroll) and unroll >= size ->
        {nil, range}

      unroll when is_integer(unroll) and unroll > 0 ->
        # Keep the largest multiple of `unroll` in the loop; the
        # remainder is unrolled externally.
        {internal, external} = split_range(range, size - rem(size, unroll))
        {{internal, 0..(unroll - 1)//1}, external}

      unroll ->
        message = ":unroll must be a boolean or a positive integer, got: #{inspect(unroll)}"
        raise CompileError, line: line, file: file, description: message
    end

  result =
    case internal do
      {first..last//step, internal_unroll} ->
        gen_initial = {{tensor(first), generator}, initial}
        {{{index_param, generator_param}, arg}, context} = to_param_expr(gen_initial, :while)

        condition =
          if step > 0 do
            Nx.less_equal(index_param, tensor(last))
          else
            Nx.greater_equal(index_param, tensor(last))
          end

        body =
          Enum.reduce(internal_unroll, arg, fn index, acc ->
            next = Nx.add(index_param, step * index)
            {true, body} = condition_body.({{next, generator_param}, acc})
            body = to_container_expr(body)
            # Only the first unrolled step needs the compatibility check.
            index == 0 and compatible_while!(file, line, initial, body)
            body
          end)

        # The loop index advances by the number of unrolled steps.
        next = Range.size(internal_unroll) * step
        gen_arg = {{index_param, generator_param}, arg}
        gen_body = {{Nx.add(index_param, next), generator_param}, body}
        {_, result} = while(gen_initial, context, gen_arg, condition, gen_body)
        result

      nil ->
        initial
    end

  # Fully-unrolled tail: apply the body once per remaining index.
  Enum.reduce(external, result, fn index, acc ->
    {true, body} = condition_body.({{index, generator}, acc})
    body = to_container_expr(body)
    index == external.first and compatible_while!(file, line, initial, body)
    body
  end)
end
# TODO: Use Range.split/2 when we require Elixir v1.15+
defp split_range(first..last//step = range, split) when is_integer(split) do
  # Negative split positions count from the end of the range.
  position = if split >= 0, do: split, else: Range.size(range) + split
  split_range(first, last, step, position)
end
# Splits an increasing (or single-element) range: the first part keeps
# `split` elements (clamped to the range bounds) and the second part
# the rest. An empty part is a range that steps past itself.
defp split_range(first, last, step, split) when first < last or (first == last and step > 0) do
  if step > 0 do
    mid = max(min(first + step * (split - 1), last), first - step)
    {first..mid//step, (mid + step)..last//step}
  else
    # Wrong-direction step: the whole range is empty; split trivially.
    {first..(first - step)//step, (last + step)..last//step}
  end
end

# Decreasing ranges: mirror of the clause above with bounds swapped.
defp split_range(last, first, step, split) do
  if step < 0 do
    mid = min(max(last + step * (split - 1), first), last - step)
    {last..mid//step, (mid + step)..first//step}
  else
    {last..(last - step)//step, (first + step)..first//step}
  end
end
# Raises when the while body template differs (shape/type/names) from
# the initial accumulator template; returns nil otherwise.
defp compatible_while!(file, line, initial, body) do
  if not Nx.compatible?(initial, body) do
    raise CompileError,
      line: line,
      file: file,
      description: """
      the do-block in while must return tensors with the same shape, type, and names as the initial arguments.
      Body matches template:
      #{inspect_as_template(body)}
      and initial argument has template:
      #{inspect_as_template(initial)}
      """
  end
end
## Nx.Backend Callbacks

@behaviour Nx.Backend

@impl true
# Nx.Defn.Expr accepts no backend options: validate and return [].
# Fix: the error message previously named Nx.BinaryBackend, which is a
# different backend than this module.
def init(opts) do
  if opts != [] do
    raise ArgumentError, "Nx.Defn.Expr accepts no options"
  end

  opts
end
@impl true
# Materialize the binary on the binary backend, then lift it into an
# expression (scalars fold into constants).
def from_binary(binary, type, _options) do
  binary
  |> Nx.BinaryBackend.from_binary(type, [])
  |> to_expr()
end
@impl true
# Backend options are irrelevant for expression constants.
def constant(out, number, _options), do: constant(out, number)
@impl true
# Identity matrices are context-free leaf nodes.
def eye(out, _backend_options), do: expr(out, nil, :eye, [])
@impl true
# Iota tensors are context-free leaf nodes parameterized by axis.
def iota(out, axis, _backend_options), do: expr(out, nil, :iota, [axis])
@impl true
def random_uniform(out, min, max, _backend_options) do
  {[min_expr, max_expr], context} = to_exprs([min, max])
  expr(out, context, :random_uniform, [min_expr, max_expr])
end
@impl true
def random_normal(out, mu, sigma, _backend_options) do
  {[mu_expr, sigma_expr], context} = to_exprs([mu, sigma])
  expr(out, context, :random_normal, [mu_expr, sigma_expr])
end
# Element-wise unary callbacks are all defined alike at compile time:
# convert the argument to an expression and emit (or constant-fold,
# via unary_expr/4) a unary node.
unary_ops =
  [:exp, :expm1, :log, :log1p, :sigmoid, :cos, :sin, :tan, :cosh, :sinh, :tanh] ++
    [:acosh, :asinh, :atanh, :sqrt, :rsqrt, :cbrt, :negate, :sign, :abs, :bitwise_not] ++
    [:is_nan, :is_infinity] ++
    [:conjugate, :population_count, :count_leading_zeros, :floor, :ceil, :round] ++
    [:erf, :erfc, :erf_inv, :acos, :asin, :atan, :bitcast, :real, :imag]

for op <- unary_ops do
  @impl true
  def unquote(op)(out, tensor) do
    tensor = to_expr(tensor)
    unary_expr(out, tensor.data.context, unquote(op), tensor)
  end
end
@impl true
# Addition with algebraic simplifications: additions by zero are
# dropped, constants are folded, and constants are normalized to the
# left operand (see commute/7).
def add(out, t1, t2) do
  {[t1, t2], context} = to_exprs([t1, t2])
  c1 = maybe_constant(t1)
  c2 = maybe_constant(t2)

  cond do
    # 0 + t2 => t2 (cast/broadcast to the output template)
    c1 == 0 ->
      ensure_compatible(t2, out)

    # t1 + 0 => t1
    c2 == 0 ->
      ensure_compatible(t1, out)

    # Constant on the right: commute it to the left.
    c2 ->
      commute(out, context, :add, &Complex.add/2, c2, t2, t1)

    true ->
      case t2 do
        # t1 + (0 - t2) => t1 - t2
        %T{
          data: %Expr{
            op: :subtract,
            args: [%T{data: %Expr{op: :constant, args: [constant]}}, t2]
          }
        }
        when constant == 0 ->
          binary_expr(out, context, :subtract, t1, t2)

        %T{} ->
          commute(out, context, :add, &Complex.add/2, c1, t1, t2)
      end
  end
end
@impl true
# Subtraction with simplifications: subtracting zero is dropped and
# two constants are folded at build time.
def subtract(out, t1, t2) do
  {[lhs, rhs], context} = to_exprs([t1, t2])
  lhs_const = maybe_constant(lhs)
  rhs_const = maybe_constant(rhs)

  cond do
    rhs_const == 0 -> ensure_compatible(lhs, out)
    lhs_const && rhs_const -> constant(out, Complex.subtract(lhs_const, rhs_const))
    true -> binary_expr(out, context, :subtract, lhs, rhs)
  end
end
@impl true
# Multiplication with algebraic simplifications: products by one are
# dropped, constants are folded, and constants are normalized to the
# left operand (see commute/7).
def multiply(out, t1, t2) do
  {[t1, t2], context} = to_exprs([t1, t2])
  c1 = maybe_constant(t1)
  c2 = maybe_constant(t2)

  cond do
    # 1 * t2 => t2
    c1 == 1 ->
      ensure_compatible(t2, out)

    # t1 * 1 => t1
    c2 == 1 ->
      ensure_compatible(t1, out)

    # Constant on the right: commute it to the left.
    c2 ->
      commute(out, context, :multiply, &Complex.multiply/2, c2, t2, t1)

    true ->
      case t2 do
        # t1 * (1 / t2) => t1 / t2
        %T{
          data: %Expr{op: :divide, args: [%T{data: %Expr{op: :constant, args: [constant]}}, t2]}
        }
        when constant == 1 ->
          binary_expr(out, context, :divide, t1, t2)

        %T{} ->
          commute(out, context, :multiply, &Complex.multiply/2, c1, t1, t2)
      end
  end
end
@impl true
# Division by the constant one is elided.
def divide(out, t1, t2) do
  {[num, den], context} = to_exprs([t1, t2])

  if maybe_constant(den) == 1 do
    ensure_compatible(num, out)
  else
    binary_expr(out, context, :divide, num, den)
  end
end
@impl true
# Raising to the constant one is elided.
def pow(out, t1, t2) do
  {[base, exponent], context} = to_exprs([t1, t2])

  if maybe_constant(exponent) == 1 do
    ensure_compatible(base, out)
  else
    binary_expr(out, context, :pow, base, exponent)
  end
end
# Remaining element-wise binary callbacks share one generated
# definition: convert both operands and emit (or constant-fold, via
# binary_expr/5) a binary node.
binary_ops =
  [:remainder, :atan2, :max, :min, :quotient] ++
    [:bitwise_and, :bitwise_or, :bitwise_xor, :left_shift, :right_shift] ++
    [:equal, :not_equal, :greater, :less, :less_equal, :greater_equal] ++
    [:logical_and, :logical_or, :logical_xor]

for op <- binary_ops do
  @impl true
  def unquote(op)(out, t1, t2) do
    {[t1, t2], context} = to_exprs([t1, t2])
    binary_expr(out, context, unquote(op), t1, t2)
  end
end
# Aggregation callbacks: one generated definition taking the tensor
# plus an options list.
aggregate_ops = [:all, :any, :argmax, :argmin, :sum, :product, :reduce_min, :reduce_max]

for op <- aggregate_ops do
  @impl true
  def unquote(op)(out, tensor, opts) do
    tensor = to_expr(tensor)
    expr(out, tensor.data.context, unquote(op), [tensor, opts])
  end
end
# Window aggregation callbacks: one generated definition taking the
# tensor, the window dimensions, and an options list.
window_aggregate_ops = [:window_sum, :window_product, :window_max, :window_min]

for op <- window_aggregate_ops do
  @impl true
  def unquote(op)(out, tensor, window_dimensions, opts) do
    tensor = to_expr(tensor)
    expr(out, tensor.data.context, unquote(op), [tensor, window_dimensions, opts])
  end
end
@impl true
def reduce(%{type: type} = out, tensor, acc, opts, fun) do
  # The reducer runs in its own context with two scalar parameters.
  # Note `context` is rebound below to the operands' merged context.
  context = new_context(:reduce)
  args = [parameter(context, type, {}, 0), parameter(context, type, {}, 1)]
  {[tensor, acc], context} = to_exprs([tensor, acc])
  fun = apply_fun(context, fun, args, type)

  if fun.shape != {} do
    raise "reduce function must return a scalar tensor, got: #{inspect(fun.shape)}"
  end

  expr(out, context, :reduce, [tensor, acc, opts, fun])
end
@impl true
def window_reduce(
      %{type: type} = out,
      tensor,
      acc,
      window_dims,
      opts,
      fun
    ) do
  # The reducer runs in its own context with two scalar parameters.
  # Note `context` is rebound below to the operands' merged context.
  context = new_context(:window_reduce)
  args = [parameter(context, type, {}, 0), parameter(context, type, {}, 1)]
  {[tensor, acc], context} = to_exprs([tensor, acc])
  fun = apply_fun(context, fun, args, type)

  if fun.shape != {} do
    raise "window_reduce function must return a scalar tensor, got: #{inspect(fun.shape)}"
  end

  expr(out, context, :window_reduce, [tensor, acc, window_dims, opts, fun])
end
@impl true
# The mapper takes a single scalar parameter in a fresh context.
def map(%{type: type} = out, tensor, opts, fun) do
  fun_args = [parameter(new_context(:map), type, {}, 0)]
  %{data: %{context: context}} = tensor = to_expr(tensor)
  expr(out, context, :map, [tensor, opts, apply_fun(context, fun, fun_args, type)])
end
@impl true
def window_scatter_max(out, tensor, source, init_value, window_dims, opts) do
  {[t, src, init], context} = to_exprs([tensor, source, init_value])
  expr(out, context, :window_scatter_max, [t, src, init, window_dims, opts])
end
@impl true
def window_scatter_min(out, tensor, source, init_value, window_dims, opts) do
  {[t, src, init], context} = to_exprs([tensor, source, init_value])
  expr(out, context, :window_scatter_min, [t, src, init, window_dims, opts])
end
@impl true
def indexed_add(out, target, indices, updates) do
  {converted, context} = to_exprs([target, indices, updates])
  expr(out, context, :indexed_add, converted)
end
@impl true
def indexed_put(out, target, indices, updates) do
  {converted, context} = to_exprs([target, indices, updates])
  expr(out, context, :indexed_put, converted)
end
@impl true
def reshape(out, tensor) do
  converted = to_expr(tensor)
  expr(out, converted.data.context, :reshape, [converted])
end
@impl true
def squeeze(out, tensor, axes) do
  tensor = to_expr(tensor)

  # If we are in a sequence of squeezes, we collapse them.
  # This helps us fuse the access syntax.
  with %T{data: %Expr{op: :squeeze, args: [tensor, inner_axes]}} <- tensor do
    axes = merge_squeeze(Enum.sort(inner_axes), Enum.sort(axes), 0)
    expr(out, tensor.data.context, :squeeze, [tensor, axes])
  else
    _ -> expr(out, tensor.data.context, :squeeze, [tensor, axes])
  end
end
# Merges two consecutive squeezes into axes over the original tensor:
# `extra` counts how many inner axes were already removed before the
# current outer axis, shifting the outer axis accordingly. Both axis
# lists must be sorted.
defp merge_squeeze([inner_axis | inner_axes], [axis | axes], extra)
     when inner_axis <= axis + extra,
     do: [inner_axis | merge_squeeze(inner_axes, [axis | axes], extra + 1)]

defp merge_squeeze(inner_axes, [axis | axes], extra),
  do: [axis + extra | merge_squeeze(inner_axes, axes, extra)]

defp merge_squeeze([], [], _extra),
  do: []
@impl true
def transpose(out, tensor, axes) do
  converted = to_expr(tensor)
  expr(out, converted.data.context, :transpose, [converted, axes])
end
@impl true
# Casts go through unary_expr/4 so constant casts fold eagerly.
def as_type(out, tensor) do
  converted = to_expr(tensor)
  unary_expr(out, converted.data.context, :as_type, converted)
end
@impl true
def broadcast(out, tensor, shape, axes) do
  tensor = to_expr(tensor)

  # Collapse a broadcast of a broadcast when both axis lists are
  # contiguous (either from the front or aligned to the trailing
  # dimensions): a single broadcast to the final shape is equivalent.
  with %T{data: %Expr{op: :broadcast, args: [inner_tensor, inner_shape, inner_axes]}} <- tensor,
       true <-
         (contiguous?(inner_axes, 0) and contiguous?(axes, 0)) or
           (contiguous_last?(inner_axes, inner_shape, inner_tensor) and
              contiguous_last?(axes, shape, tensor)) do
    expr(out, tensor.data.context, :broadcast, [inner_tensor, shape, inner_axes])
  else
    _ ->
      # Broadcasting a constant yields a constant of the output shape.
      if c = maybe_constant(tensor) do
        constant(out, c)
      else
        expr(out, tensor.data.context, :broadcast, [tensor, shape, axes])
      end
  end
end
# Contiguity anchored at the trailing dimensions of the output shape.
defp contiguous_last?(axes, out_shape, in_shape),
  do: contiguous?(axes, Nx.rank(out_shape) - Nx.rank(in_shape))

# An axes list is contiguous when it counts up one by one from `offset`.
defp contiguous?(axes, offset) do
  axes
  |> Enum.with_index(offset)
  |> Enum.all?(fn {axis, expected} -> axis == expected end)
end
@impl true
def dot(out, t1, c1, b1, t2, c2, b2) do
  {[left, right], context} = to_exprs([t1, t2])
  expr(out, context, :dot, [left, c1, b1, right, c2, b2])
end
@impl true
def conv(out, inp, kernel, opts) do
  {[input_expr, kernel_expr], context} = to_exprs([inp, kernel])
  expr(out, context, :conv, [input_expr, kernel_expr, opts])
end
@impl true
def pad(out, expr, value, config) do
  # Renamed locally to avoid shadowing the expr/4 helper.
  {[padded, pad_value], context} = to_exprs([expr, value])
  expr(out, context, :pad, [padded, pad_value, config])
end
@impl true
def select(out, pred, on_true, on_false) do
  {converted, context} = to_exprs([pred, on_true, on_false])
  expr(out, context, :select, converted)
end
@impl true
def clip(out, operand, min, max) do
  {converted, context} = to_exprs([operand, min, max])
  expr(out, context, :clip, converted)
end
@impl true
def slice(out, tensor, start, lengths, strides) do
  # Static (all-integer) start indices keep the operand's context;
  # dynamic starts must be converted to expressions as well.
  all_static? = Enum.all?(start, &is_integer/1)

  {[tensor | start], context} =
    if all_static? do
      tensor = to_expr(tensor)
      {[tensor | start], tensor.data.context}
    else
      to_exprs([tensor | start])
    end

  # If we are in a sequence of slices, it is the access syntax,
  # so we compact them into a single slice.
  with true <- ones_stride?(strides),
       {slice, axes} <- maybe_squeeze(tensor),
       %T{data: %Expr{op: :slice, args: [tensor, inner_start, inner_lengths, strides]}} <-
         slice,
       true <- ones_stride?(strides) do
    {start, lengths} =
      0
      |> merge_slice(axes, inner_start, start, inner_lengths, lengths)
      |> Enum.unzip()

    tensor
    |> Nx.slice(start, lengths)
    |> Nx.squeeze(axes: axes)
  else
    _ ->
      expr(out, context, :slice, [tensor, start, lengths, strides])
  end
end
# All-ones strides mean the slice is a plain, compactable window.
defp ones_stride?(strides), do: Enum.all?(strides, fn stride -> stride == 1 end)

# Peels an intermediate squeeze off a slice, exposing its erased axes.
defp maybe_squeeze(%T{data: %Expr{op: :squeeze, args: [slice, axes]}}), do: {slice, axes}
defp maybe_squeeze(slice), do: {slice, []}
# Merges nested slice coordinates axis by axis into {start, length}
# pairs. Axes erased by an intermediate squeeze take their coordinates
# from the inner slice; remaining axes offset the outer start by the
# inner start and keep the outer length.
defp merge_slice(_axis, _axes, [], [], [], []), do: []

defp merge_slice(axis, axes, [is | inner_start], start, [il | inner_lengths], lengths) do
  # This is one of the erased axes, so we need to get coordinates from inner
  if axis in axes do
    [{is, il} | merge_slice(axis + 1, axes, inner_start, start, inner_lengths, lengths)]
  else
    [s | start] = start
    [l | lengths] = lengths

    [
      {Nx.Defn.Kernel.+(is, s), l}
      | merge_slice(axis + 1, axes, inner_start, start, inner_lengths, lengths)
    ]
  end
end
@impl true
def put_slice(out, tensor, start, slice) do
  {[target, patch | start_exprs], context} = to_exprs([tensor, slice | start])
  expr(out, context, :put_slice, [target, start_exprs, patch])
end
@impl true
def take(out, tensor, indices, axis) do
  {[tensor_expr, indices_expr], context} = to_exprs([tensor, indices])
  expr(out, context, :take, [tensor_expr, indices_expr, axis])
end
@impl true
def take_along_axis(out, tensor, indices, axis) do
  {[tensor_expr, indices_expr], context} = to_exprs([tensor, indices])
  expr(out, context, :take_along_axis, [tensor_expr, indices_expr, axis])
end
@impl true
def gather(out, tensor, indices) do
  {[tensor_expr, indices_expr], context} = to_exprs([tensor, indices])
  expr(out, context, :gather, [tensor_expr, indices_expr])
end
@impl true
def reverse(out, tensor, axes) do
  converted = to_expr(tensor)
  expr(out, converted.data.context, :reverse, [converted, axes])
end
@impl true
def concatenate(out, tensors, axis) do
  {converted, context} = to_exprs(tensors)
  expr(out, context, :concatenate, [converted, axis])
end
@impl true
def cholesky(out, tensor) do
  converted = to_expr(tensor)
  expr(out, converted.data.context, :cholesky, [converted])
end
@impl true
def triangular_solve(out, a, b, opts) do
  {[a_expr, b_expr], context} = to_exprs([a, b])
  expr(out, context, :triangular_solve, [a_expr, b_expr, opts])
end
@impl true
# LU returns a 3-tuple, so the node is tuple-typed and re-exposed
# element by element via tuple/2.
def lu({p, l, u}, tensor, opts) do
  %{data: %{context: context}} = tensor = to_expr(tensor)
  out = %T{names: [], shape: {}, type: {:tuple, 3}}
  tuple(expr(out, context, :lu, [{p, l, u}, tensor, opts]), [p, l, u])
end
@impl true
# QR returns a 2-tuple, so the node is tuple-typed and re-exposed
# element by element via tuple/2.
def qr({q, r}, t, opts) do
  %{data: %{context: context}} = tensor = to_expr(t)
  out = %T{names: [], shape: {}, type: {:tuple, 2}}
  tuple(expr(out, context, :qr, [{q, r}, tensor, opts]), [q, r])
end
@impl true
# Eigendecomposition returns a 2-tuple, so the node is tuple-typed and
# re-exposed element by element via tuple/2.
def eigh({evals, evecs}, tensor, opts) do
  %{data: %{context: context}} = tensor = to_expr(tensor)
  out = %T{names: [], shape: {}, type: {:tuple, 2}}
  tuple(expr(out, context, :eigh, [{evals, evecs}, tensor, opts]), [evals, evecs])
end
@impl true
def sort(out, tensor, opts) do
  converted = to_expr(tensor)
  expr(out, converted.data.context, :sort, [converted, opts])
end
@impl true
def argsort(out, tensor, opts) do
  converted = to_expr(tensor)
  expr(out, converted.data.context, :argsort, [converted, opts])
end
@impl true
def fft(out, tensor, opts) do
  converted = to_expr(tensor)
  expr(out, converted.data.context, :fft, [converted, opts])
end
@impl true
def ifft(out, tensor, opts) do
  converted = to_expr(tensor)
  expr(out, converted.data.context, :ifft, [converted, opts])
end
## Undefined

@impl true
# Transferring onto this very backend is a no-op.
def backend_transfer(out, __MODULE__, _), do: out
# The remaining backend callbacks are meaningless on expressions:
# they would require concrete data, which only exists at runtime.
ops = [
  backend_copy: 3,
  backend_deallocate: 1,
  backend_transfer: 3,
  to_binary: 2,
  to_batched: 3
]

for {op, arity} <- ops do
  args = Macro.generate_arguments(arity, __MODULE__)

  @impl true
  def unquote(op)(unquote_splicing(args)) do
    raise ArgumentError, """
    cannot invoke #{unquote(op)}/#{unquote(arity)} on Nx.Defn.Expr.
    This typically means you are invoking an unsupported Nx function
    inside `defn` or inside JIT compiled code
    """
  end
end
## Helpers

@compile {:inline, new_context: 1}

# A context is a unique {label, ref} pair identifying a scope (while
# bodies, reducers, mappers); merge_context!/2 uses it to detect
# variables leaking across scopes.
defp new_context(atom) when is_atom(atom) do
  {atom, make_ref()}
end
# Wraps `tensor` in a fresh expression node for `op` with `args`.
defp expr(tensor, context, op, args) do
  %{tensor | data: %Expr{id: id(), op: op, args: args, context: context}}
end
# Converts a tensor or number into an expression node. Expressions pass
# through, binary-backed tensors are inlined (scalars fold into
# constants), numbers become scalar constants, and everything else —
# including tensors on other backends — raises.
defp to_expr(%T{data: %Expr{}} = t),
  do: t

defp to_expr(%T{data: %Nx.BinaryBackend{}, shape: shape} = t) do
  case shape do
    {} -> constant(t, Nx.to_number(t))
    _ -> expr(t, nil, :tensor, [t])
  end
end

defp to_expr(%T{data: %backend{}} = t) do
  raise ArgumentError,
        "cannot convert tensor allocated on #{inspect(backend)} to an expression. " <>
          "This may mean you are passing a tensor to defn/jit as an optional argument " <>
          "or as closure in an anonymous function. For efficiency, it is preferred " <>
          "to always pass tensors as required arguments instead. Alternatively, you " <>
          "could call Nx.backend_copy/1 on the tensor, however this will copy its " <>
          "value and inline it inside the defn expression. Got: #{inspect(t)}"
end

defp to_expr(number) when is_number(number),
  do: constant(%T{shape: {}, names: [], type: Nx.Type.infer(number)}, number)

defp to_expr(other) do
  raise ArgumentError,
        "unable to build tensor expression, expected a tensor or a number, " <>
          "got: #{inspect(other)}"
end
# Converts every entry to an expression while asserting all of them
# share (or can merge into) one context.
defp to_exprs(list) do
  Enum.map_reduce(list, nil, fn tensor, context ->
    converted = to_expr(tensor)
    {converted, merge_context!(converted, context)}
  end)
end
# Converts every leaf of a container (or a single tensor) to an expression.
defp to_container_expr(container_or_tensor),
  do: Composite.traverse(container_or_tensor, &to_expr/1)
# Replaces every leaf of the container with a parameter node in a fresh
# context, numbering parameters in traversal order, while asserting all
# leaves agree on a single source context (which is returned).
defp to_param_expr(container_or_tensor, context_label) do
  context = new_context(context_label)

  {arg, {_, context}} =
    Composite.traverse(container_or_tensor, {0, nil}, fn expr, {counter, acc} ->
      {parameter(expr, context, counter), {counter + 1, merge_context!(expr, acc)}}
    end)

  {arg, context}
end
# Template tensor for expressions that return `size`-tuples.
defp tuple_out(size), do: %T{shape: {}, names: [], type: {:tuple, size}}
# Builds a :fun node from its parameters, body, and originating MFA
# (kept for introspection only). Tuple-returning bodies get a
# tuple-typed node.
defp fun(context, args, body, {_, _, _} = mfa) do
  case to_container_expr(body) do
    %T{} = tensor ->
      expr(tensor, context, :fun, [args, tensor, mfa])

    tuple when is_tuple(tuple) ->
      expr(tuple_out(tuple_size(tuple)), context, :fun, [args, tuple, mfa])
  end
end
# Applies an anonymous function to parameter expressions, building a
# :fun node tagged with the function's module/name/arity.
defp apply_fun(context, fun, args, type) when is_function(fun, length(args)) do
  {:module, mod} = Function.info(fun, :module)
  {:name, name} = Function.info(fun, :name)
  {:arity, arity} = Function.info(fun, :arity)

  # We modify the type after applying because the best form
  # to perform type conversions is always left to the compiler.
  %{fun(context, args, apply(fun, args), {mod, name, arity}) | type: type}
end
# Normalizes a predicate for cond/while: booleans become u8 constants,
# obviously invalid terms (atoms, binaries, lists) raise, and anything
# else must convert to a scalar tensor expression.
defp to_pred(pred, line, file, op) do
  pred =
    cond do
      is_boolean(pred) ->
        number = if pred == false, do: 0, else: 1

        %T{data: constant_expr({}, {:u, 8}, number), shape: {}, type: {:u, 8}, names: []}

      is_atom(pred) or is_binary(pred) or is_list(pred) ->
        raise CompileError,
          line: line,
          file: file,
          description:
            "#{Atom.to_string(op)} in defn expects the predicate to be true, false," <>
              " or a scalar tensor where 0 is false and everything else is true." <>
              " Unsupported value: #{inspect(pred)}"

      true ->
        to_expr(pred)
    end

  # Whatever the source, the predicate must end up scalar.
  if not match?(%T{shape: {}}, pred) do
    raise CompileError,
      line: line,
      file: file,
      description:
        "condition must be a scalar tensor, got: #{inspect(pred)}," <>
          " consider using Nx.all/1 or Nx.any/1 to obtain a scalar" <>
          " predicate from tensor"
  end

  pred
end
# Merges the contexts of two expressions, raising when both are set
# and differ — i.e. a variable crossed a scope boundary.
defp merge_context!(%{data: %{context: context}}, acc) do
  if context != acc and context != nil and acc != nil do
    raise """
    cannot build defn because expressions come from different contexts: \
    #{inspect(context)} and #{inspect(acc)}.
    This typically happens on "while" and inside anonymous functions when you \
    try to access an external variable. All variables you intend to use inside \
    "while" or anonymous functions in defn must be explicitly given as arguments.
    For example, this is not valid:
    defn increment_by_y_while_less_than_10(y) do
    while x = 0, Nx.less(x, 10) do
    x + y
    end
    end
    In the example above, we want to increment "x" by "y" while it is less than 10. \
    However, the code won't compile because "y" is used inside "while" but not \
    explicitly defined as part of "while". You must fix it like so:
    defn increment_by_y_while_less_than_10(y) do
    while {x = 0, y}, Nx.less(x, 10) do
    {x + y, y}
    end
    end
    """
  end

  context || acc
end
# Renders a value for error messages: tensors/containers are shown as
# templates (without data); everything else is plainly inspected.
defp inspect_as_template(data) do
  if is_number(data) or is_tuple(data) or
       (is_map(data) and Nx.Container.impl_for(data) != Nx.Container.Any) do
    data
    |> Nx.to_template()
    |> Kernel.inspect(custom_options: [skip_template_backend_header: true])
  else
    inspect(data)
  end
end
## Constant helpers and related optimizations

# Builds a :constant node, coercing the number to match the output
# type: integers are upcast for float outputs, while non-integers for
# integer outputs are rejected.
defp constant(%{shape: shape, type: type} = out, number) do
  number =
    cond do
      is_integer(number) and Nx.Type.float?(type) ->
        # Multiply by 1.0 through Complex to stay consistent with the
        # Complex-based constant folding used elsewhere in this module.
        Complex.multiply(1.0, number)

      not is_integer(number) and Nx.Type.integer?(type) ->
        raise ArgumentError,
              "value #{inspect(number)} is not valid for constant of type #{inspect(type)}"

      number ->
        number
    end

  %{out | data: constant_expr(shape, type, number)}
end
# Constants are deduplicated: the value/type/shape triple is the id.
defp constant_expr(shape, type, number),
  do: %Expr{id: {number, type, shape}, op: :constant, args: [number], context: nil}
# Materializes a constant as a scalar tensor on the binary backend.
defp constant_binary(tensor, c),
  do: Nx.BinaryBackend.constant(%T{type: tensor.type, names: [], shape: {}}, c, [])
# Returns the numeric value of a :constant node, or nil otherwise.
defp maybe_constant(%T{data: %Expr{op: :constant, args: [number]}}), do: number
defp maybe_constant(_expr), do: nil
# Casts and broadcasts `t` to the output template, keeping the
# output's dimension names.
defp ensure_compatible(t, out) do
  Map.replace!(Nx.broadcast(Nx.as_type(t, out.type), out.shape), :names, out.names)
end
# Rewrite commutative operations so the constant always come on the left
defp commute(out, context, op, fun, s1, t1, t2) do
  {a1, a2} =
    case t2 do
      # t2 is itself `op` with a constant on its left: fold the two
      # constants with `fun` (when s1 is a constant) or hoist t2's
      # constant outwards.
      %T{data: %Expr{op: ^op, args: [%T{data: %Expr{op: :constant, args: [s2]}}, t3]}} ->
        nullary_out = %{out | shape: {}, names: []}

        if s1 do
          {constant(nullary_out, fun.(s1, s2)), t3 |> Nx.broadcast(out.shape)}
        else
          {constant(nullary_out, s2), apply(Nx, op, [t1, t3]) |> Nx.broadcast(out.shape)}
        end

      %T{} ->
        case t1 do
          # Same hoisting when the nested constant lives under t1.
          %T{data: %Expr{op: ^op, args: [%T{data: %Expr{op: :constant, args: [s1]}}, t3]}} ->
            nullary_out = %{out | shape: {}, names: []}
            {constant(nullary_out, s1), apply(Nx, op, [t2, t3]) |> Nx.broadcast(out.shape)}

          %T{} ->
            {t1, t2}
        end
    end

  binary_expr(out, context, op, a1, a2)
end
# Emits a binary node, folding it eagerly on the binary backend when
# both operands are constants.
defp binary_expr(out, context, op, arg1, arg2) do
  c1 = maybe_constant(arg1)
  c2 = maybe_constant(arg2)

  if c1 && c2 do
    apply(Nx.BinaryBackend, op, [
      %{out | shape: {}, names: []},
      constant_binary(arg1, c1),
      constant_binary(arg2, c2)
    ])
    |> Nx.to_number()
    |> then(&constant(out, &1))
  else
    expr(out, context, op, [arg1, arg2])
  end
end
# Emits a unary node, folding it eagerly on the binary backend when
# the argument is a constant.
defp unary_expr(out, context, op, arg) do
  if c = maybe_constant(arg) do
    apply(Nx.BinaryBackend, op, [%{out | shape: {}, names: []}, constant_binary(arg, c)])
    |> Nx.to_number()
    |> then(&constant(out, &1))
  else
    expr(out, context, op, [arg])
  end
end
## Inspect
import Inspect.Algebra
@impl true
# Special case for constants since we show them inline in regular printing
def inspect(%T{data: %Expr{op: :constant, args: [constant]}}, opts) do
  concat([line(), color("Nx.Defn.Expr", :map, opts), line(), to_string(constant)])
end

# General case: walk the expression graph once, collecting one rendered line
# per parameter/tensor ("parameter a:0", "tensor b") and per operation
# ("c = add a, b"), then print them under an "Nx.Defn.Expr" header with an
# aligned type/shape column.
def inspect(tensor, opts) do
  # Accumulator shape: {exprs, params, var_map, cache} — reversed lists of
  # {string, tensor} lines, an id => variable-name map, and an id => true
  # set of already-rendered nodes.
  {t, acc} = inspect_expr(tensor, {[], [], %{}, %{}})
  # Walk the root's arguments once more; already-cached nodes are no-ops.
  # NOTE(review): this extra pass appears to cover roots whose own clause
  # short-circuits (e.g. constants/funs) — confirm against Tree.apply_args.
  {_, {exprs, params, _var_map, _cache}} = Tree.apply_args(t, acc, &inspect_expr/2)
  # Parameters first, then expressions, both in insertion order.
  all = Enum.reverse(params, Enum.reverse(exprs))
  header = concat(line(), color("Nx.Defn.Expr", :map, opts))
  # Width of the longest rendered line; used to align the type/shape column.
  # (`length` is a local binding here, not Kernel.length/1.)
  length = Enum.reduce(all, 0, fn {str, _tensor}, acc -> max(byte_size(str), acc) end)

  all
  |> Enum.map(fn {str, tensor} ->
    String.pad_trailing(str, length, " ") <> " " <> to_type_shape(tensor)
  end)
  |> Enum.uniq()
  |> Enum.reduce(header, &concat(&2, concat(line(), &1)))
end
# Constants and funs are shown as is
defp inspect_expr(%T{data: %Expr{op: :constant}} = t, acc), do: {t, acc}
defp inspect_expr(%T{data: %Expr{op: :fun}} = t, acc), do: {t, acc}

# Transparent wrappers: metadata without a custom :inspect entry and
# :optional nodes are skipped, inspecting the wrapped expression instead.
defp inspect_expr(%T{data: %Expr{op: :metadata, args: [expr, metadata]}}, acc)
     when not is_map_key(metadata, :inspect),
     do: inspect_expr(expr, acc)

defp inspect_expr(%T{data: %Expr{op: :optional, args: [expr, _default]}}, acc) do
  inspect_expr(expr, acc)
end

# Every other node is rendered at most once: the cache (id => true) prevents
# shared subexpressions of the DAG from being printed repeatedly.
defp inspect_expr(%T{data: %Expr{id: id}} = t, {exprs, params, var_map, cache} = acc) do
  case cache do
    %{^id => _} -> {t, acc}
    %{} -> cached_inspect_expr(t, {exprs, params, var_map, Map.put(cache, id, true)})
  end
end
# Parameters are rendered into the `params` section as "parameter <var>:<pos>";
# they produce no expression line.
defp cached_inspect_expr(%T{data: %Expr{op: :parameter, id: id, args: [i]}} = t, acc) do
  {exprs, params, var_map, cache} = acc
  {var, var_map} = var_for_id(var_map, id)
  param = "parameter " <> var <> ":" <> Integer.to_string(i)
  {t, {exprs, [{param, t} | params], var_map, cache}}
end

# Embedded tensors likewise go into the `params` section as "tensor <var>".
defp cached_inspect_expr(%T{data: %Expr{op: :tensor, id: id}} = t, acc) do
  {exprs, params, var_map, cache} = acc
  {var, var_map} = var_for_id(var_map, id)
  param = "tensor " <> var
  {t, {exprs, [{param, t} | params], var_map, cache}}
end

# Any other op: recurse into the arguments first (so operands get their
# variable names assigned before this node), then emit "<var> = <op> <args>".
defp cached_inspect_expr(%T{} = t, acc) do
  %{data: %Expr{id: id, op: op}} = t
  {args, {exprs, params, var_map, cache}} = traverse_args(op, t, acc)
  {var, var_map} = var_for_id(var_map, id)
  args_str = inspect_args(op, args, var_map)
  expr_str = var <> " = " <> Atom.to_string(op) <> " " <> args_str
  {t, {[{expr_str, t} | exprs], params, var_map, cache}}
end
# For :while only the initial value is traversed and printed; the remaining
# three args (presumably the loop parameter, condition, and body — confirm
# against the while node construction) are omitted from the flat printout.
defp traverse_args(:while, %T{data: %{args: [initial, _, _, _]}}, acc) do
  {initial, acc} = Composite.traverse(initial, acc, &inspect_expr/2)
  {[initial], acc}
end

# For :token, traverse each hook's expression and return {name, expr} pairs
# so inspect_args/3 can print them as "name: var".
defp traverse_args(:token, %T{data: %{args: [token]}}, acc) do
  {hooks, acc} =
    Enum.map_reduce(token.hooks, acc, fn %{name: name, expr: expr}, acc ->
      {expr, acc} = Composite.traverse(expr, acc, &inspect_expr/2)
      {{name, expr}, acc}
    end)

  {hooks, acc}
end

# Default: delegate to Tree.apply_args, which walks each argument of the
# node with inspect_expr/2.
defp traverse_args(_op, t, acc) do
  Tree.apply_args(t, acc, &inspect_expr/2)
end
# Renders the argument list of an op for its "<var> = <op> <args>" line.
# :token hooks (as produced by traverse_args/3) print as "name: var, ...".
defp inspect_args(:token, hooks, var_map) do
  IO.iodata_to_binary(
    Enum.map_intersperse(hooks, ", ", fn {key, val} ->
      "#{key}: " <> inspect_arg(val, var_map)
    end)
  )
end

# :while prints only its (already traversed) initial value.
defp inspect_args(:while, [initial], var_map) do
  IO.iodata_to_binary(inspect_arg(initial, var_map))
end

# :cond prints "pred -> expr, " for every clause followed by the fallback
# rendered as "true -> last".
defp inspect_args(:cond, [clauses, last], var_map) do
  clauses =
    Enum.map(clauses, fn {pred, expr} ->
      [inspect_arg(pred, var_map), " -> ", inspect_arg(expr, var_map), ", "]
    end)

  IO.iodata_to_binary([clauses, "true -> ", inspect_arg(last, var_map)])
end

# :metadata with a custom :inspect value prints the wrapped expression plus
# that value. (The local `inspect` binding shadows Kernel.inspect/1 here;
# `inspect(inspect)` is Kernel.inspect/1 applied to the bound value.)
defp inspect_args(:metadata, [expr, %{inspect: inspect}], var_map) do
  IO.iodata_to_binary([inspect_arg(expr, var_map), ", ", inspect(inspect)])
end

# A leading tuple argument is dropped from the printout — NOTE(review):
# appears to hide composite/container first arguments; confirm which ops
# carry a tuple head.
defp inspect_args(_op, [tuple | args], var_map) when is_tuple(tuple),
  do: inspect_args(args, var_map)

defp inspect_args(_op, args, var_map),
  do: inspect_args(args, var_map)

# Arity-2 helper: comma-joined rendering of a plain argument list.
defp inspect_args(args, var_map),
  do: Enum.map_join(args, ", ", &inspect_arg(&1, var_map))
# Renders a single argument as iodata/string: funs as "&M.f/a" captures,
# constants inline, other expression nodes by their assigned variable name,
# and plain Elixir terms (keywords, lists, tuples, anything else) via
# recursive rendering or Kernel.inspect/1.
defp inspect_arg(%T{data: %Expr{op: :fun, args: [_, _, {m, f, a}]}}, _var_map) do
  [?&, Exception.format_mfa(m, f, a)]
end

defp inspect_arg(%T{data: %Expr{op: :constant, args: [number]}}, _var_map) do
  to_string(number)
end

defp inspect_arg(%T{data: %Expr{id: id}}, var_map) do
  Map.fetch!(var_map, id)
end

defp inspect_arg(arg, var_map) do
  cond do
    # Non-empty keyword lists print as "key: value" pairs.
    Keyword.keyword?(arg) and arg != [] ->
      Enum.map_join(arg, ", ", fn {key, value} -> "#{key}: #{inspect(value)}" end)

    is_list(arg) ->
      [?[, inspect_args(arg, var_map), ?]]

    is_tuple(arg) ->
      [?{, inspect_args(Tuple.to_list(arg), var_map), ?}]

    true ->
      inspect(arg)
  end
end
# Returns the printable variable name ("a", "b", ..., and onward) for the
# given node id, allocating the next name in sequence when the id has not
# been seen before.
defp var_for_id(var_map, id) do
  case Map.fetch(var_map, id) do
    {:ok, var} ->
      {var, var_map}

    :error ->
      var = IO.iodata_to_binary(counter_to_name(map_size(var_map)))
      {var, Map.put(var_map, id, var)}
  end
end
# Maps a zero-based counter to a short variable name rendered as iodata
# (a nested charlist): 0..25 => "a".."z", then multi-letter names for
# larger counters (26 => "ba", 27 => "bb", ..., preserving the original
# positional encoding).
defp counter_to_name(counter) when counter >= 26 do
  [counter_to_name(div(counter, 26)) | counter_to_name(rem(counter, 26))]
end

# Base case: counter is within 0..25 (guaranteed by the rem/2 above and by
# callers passing non-negative counters). Direct codepoint arithmetic
# replaces the original O(n) `Enum.at(?a..?z, counter)` range scan with an
# O(1) computation producing the same character.
defp counter_to_name(counter), do: [?a + counter]
# Renders a tensor's type and shape as "type[d0][d1]...", e.g. "f32[2][3]".
defp to_type_shape(%{type: type, shape: shape}) do
  dims = for dim <- Tuple.to_list(shape), do: [?[, Integer.to_string(dim), ?]]
  IO.iodata_to_binary([Nx.Type.to_string(type) | dims])
end
end