defmodule Nx.Backend do
  @moduledoc """
  The behaviour for tensor backends.

  Each backend is a module that defines a struct and implements the callbacks
  defined in this module. The callbacks are mostly implementations of the
  functions in the `Nx` module with the tensor output shape given as first
  argument.

  `Nx` backends come in two flavors: opaque backends, whose data must not be
  accessed directly except through the functions in the `Nx` module, and
  public ones, whose data can be directly accessed and traversed. The former
  typically have the `Backend` suffix.

  `Nx` ships with the following backends:

    * `Nx.BinaryBackend` - an opaque backend written in pure Elixir
      that stores the data in Elixir's binaries. This is the default
      backend used by the `Nx` module. The backend itself (and its
      data) is private and must not be accessed directly.

    * `Nx.TemplateBackend` - an opaque backend that works as
      a template in APIs to declare the type, shape, and names of
      tensors to be expected in the future.

    * `Nx.Defn.Expr` - a public backend used by `defn` to build
      expression trees that are traversed by custom compilers.

  This module also includes functions that are meant to be shared
  across backends.
  """

  @type t :: %{__struct__: atom()}

  @type tensor :: Nx.Tensor.t()
  @type shape :: Nx.Tensor.shape()
  @type axis :: Nx.Tensor.axis()
  @type axes :: Nx.Tensor.axes()
  @type backend_options :: term()

  @callback init(keyword()) :: backend_options

  @callback constant(out :: tensor, number | Complex.t(), backend_options) :: tensor
  @callback from_binary(out :: tensor, binary, backend_options) :: tensor
  @callback eye(tensor, backend_options) :: tensor
  @callback iota(tensor, axis | nil, backend_options) :: tensor
  @callback random_uniform(tensor, tensor, tensor, backend_options) :: tensor
  @callback random_normal(tensor, mu :: tensor, sigma :: tensor, backend_options) :: tensor

  @callback backend_deallocate(tensor) :: :ok | :already_deallocated
  @callback backend_copy(tensor, module, backend_options) :: tensor
  @callback backend_transfer(tensor, module, backend_options) :: tensor
  @callback to_batched(out :: tensor, tensor, keyword) :: [tensor]
  @callback to_binary(tensor, limit :: non_neg_integer) :: binary
  @callback inspect(tensor, Inspect.Opts.t()) :: tensor

  @callback as_type(out :: tensor, tensor) :: tensor
  @callback bitcast(out :: tensor, tensor) :: tensor
  @callback reshape(out :: tensor, tensor) :: tensor
  @callback squeeze(out :: tensor, tensor, axes) :: tensor
  @callback broadcast(out :: tensor, tensor, shape, axes) :: tensor
  @callback transpose(out :: tensor, tensor, axes) :: tensor
  @callback pad(out :: tensor, tensor, pad_value :: tensor, padding_config :: list()) :: tensor
  @callback reverse(out :: tensor, tensor, axes) :: tensor

  @callback dot(out :: tensor, tensor, axes, axes, tensor, axes, axes) :: tensor
  @callback clip(out :: tensor, tensor, min :: tensor, max :: tensor) :: tensor
  @callback slice(out :: tensor, tensor, list, list, list) :: tensor
  @callback put_slice(out :: tensor, tensor, tensor, list) :: tensor
  @callback take(out :: tensor, input :: tensor, indices :: tensor, axis) :: tensor
  @callback take_along_axis(out :: tensor, input :: tensor, indices :: tensor, axis) :: tensor
  @callback gather(out :: tensor, input :: tensor, indices :: tensor) :: tensor
  @callback concatenate(out :: tensor, tensor, axis) :: tensor
  @callback select(out :: tensor, tensor, tensor, tensor) :: tensor

  @callback conv(out :: tensor, tensor, kernel :: tensor, keyword) :: tensor
  @callback all(out :: tensor, tensor, keyword) :: tensor
  @callback any(out :: tensor, tensor, keyword) :: tensor
  @callback sum(out :: tensor, tensor, keyword) :: tensor
  @callback product(out :: tensor, tensor, keyword) :: tensor
  @callback reduce_max(out :: tensor, tensor, keyword) :: tensor
  @callback reduce_min(out :: tensor, tensor, keyword) :: tensor
  @callback argmax(out :: tensor, tensor, keyword) :: tensor
  @callback argmin(out :: tensor, tensor, keyword) :: tensor
  @callback reduce(out :: tensor, tensor, acc :: tensor, keyword, fun) :: tensor
  @callback window_reduce(out :: tensor, tensor, acc :: tensor, shape, keyword, fun) :: tensor
  @callback window_sum(out :: tensor, tensor, shape, keyword) :: tensor
  @callback window_product(out :: tensor, tensor, shape, keyword) :: tensor
  @callback window_max(out :: tensor, tensor, shape, keyword) :: tensor
  @callback window_min(out :: tensor, tensor, shape, keyword) :: tensor
  @callback map(out :: tensor, tensor, keyword, fun) :: tensor
  @callback sort(out :: tensor, tensor, keyword) :: tensor
  @callback argsort(out :: tensor, tensor, keyword) :: tensor
  @callback window_scatter_max(out :: tensor, tensor, tensor, tensor, shape, keyword) :: tensor
  @callback window_scatter_min(out :: tensor, tensor, tensor, tensor, shape, keyword) :: tensor
  @callback indexed_add(out :: tensor, target :: tensor, indices :: tensor, updates :: tensor) ::
              tensor
  @callback indexed_put(out :: tensor, target :: tensor, indices :: tensor, updates :: tensor) ::
              tensor

  @callback cholesky(out :: tensor, tensor) :: tensor
  @callback lu({p :: tensor, l :: tensor, u :: tensor}, tensor, keyword) :: tensor
  @callback qr({q :: tensor, r :: tensor}, tensor, keyword) :: tensor
  @callback triangular_solve(out :: tensor, a :: tensor, b :: tensor, keyword) :: tensor
  @callback eigh({eigenvals :: tensor, eigenvecs :: tensor}, tensor, keyword) :: tensor
  @callback svd({u :: tensor, s :: tensor, v :: tensor}, tensor, keyword) :: tensor

  @callback fft(out :: tensor, tensor, keyword) :: tensor
  @callback ifft(out :: tensor, tensor, keyword) :: tensor

  # Element-wise binary operations all share the same `(out, left, right)` signature.
  binary_ops =
    [:add, :subtract, :multiply, :pow, :remainder, :divide, :atan2, :min, :max, :quotient] ++
      [:bitwise_and, :bitwise_or, :bitwise_xor, :left_shift, :right_shift] ++
      [:equal, :not_equal, :greater, :less, :greater_equal, :less_equal] ++
      [:logical_and, :logical_or, :logical_xor]

  for binary_op <- binary_ops do
    @callback unquote(binary_op)(out :: tensor, tensor, tensor) :: tensor
  end

  # Element-wise unary operations all share the same `(out, tensor)` signature.
  unary_ops =
    Enum.map(Nx.Shared.unary_math_funs(), &elem(&1, 0)) ++
      [:abs, :bitwise_not, :ceil, :conjugate, :floor, :negate, :round, :sign] ++
      [:count_leading_zeros, :population_count, :real, :imag, :is_nan, :is_infinity]

  for unary_op <- unary_ops do
    @callback unquote(unary_op)(out :: tensor, tensor) :: tensor
  end

  ## Optional Callbacks

  @doc """
  Invoked for execution of optional callbacks with a default implementation.

  First we will attempt to call the optional callback itself
  (one of the many callbacks defined below), then we attempt
  to call this callback (which is also optional), then we
  fall back to the default implementation.
  """
  @callback optional(atom, [term], fun) :: tensor

  @callback solve(out :: tensor, a :: tensor, b :: tensor) :: tensor
  @callback determinant(out :: tensor, t :: tensor) :: tensor
  @callback logical_not(out :: tensor, t :: tensor) :: tensor
  @callback phase(out :: tensor, t :: tensor) :: tensor
  @callback cumulative_sum(out :: tensor, t :: tensor, keyword) :: tensor
  @callback cumulative_product(out :: tensor, t :: tensor, keyword) :: tensor
  @callback cumulative_min(out :: tensor, t :: tensor, keyword) :: tensor
  @callback cumulative_max(out :: tensor, t :: tensor, keyword) :: tensor
  @callback all_close(out :: tensor, tensor, tensor, keyword) :: tensor
  @callback top_k(out :: tensor, tensor, keyword) :: tensor

  @optional_callbacks [
    optional: 3,
    solve: 3,
    determinant: 2,
    logical_not: 2,
    phase: 2,
    cumulative_sum: 3,
    cumulative_product: 3,
    cumulative_min: 3,
    cumulative_max: 3,
    all_close: 4,
    svd: 3,
    top_k: 3
  ]

  ## Inspect implementation

  require Nx.Shared
  alias Inspect.Algebra, as: IA

  @doc """
  Builds the `Inspect.Algebra` document for a tensor whose data is given by `binary`.

  Only the `:shape` and `:type` of the given tensor are used; the elements
  themselves are read from `binary`.

  Note the `binary` may have fewer elements than the tensor size but, in
  such cases, it must strictly have more elements than `inspect_opts.limit`.

  ## Options

  The following must be passed through `Inspect` `:custom_options`:

    * `:nx_precision` - Configures the floating-point number printing precision.
      If set to a non-negative integer, the fractional part of each
      floating-point number (including both parts of complex numbers) is
      truncated (not rounded) to at most that many digits, and trailing zeros
      are then removed while always keeping at least one fractional digit.
      The notation itself (plain or scientific) is the one chosen by
      `Float.to_string/1`. Otherwise, default Elixir printing rules are applied.
  """
  @spec inspect(tensor, binary, Inspect.Opts.t()) :: Inspect.Algebra.t()
  def inspect(%{shape: shape, type: type}, binary, inspect_opts) do
    docs = {
      IA.color("[", :list, inspect_opts),
      IA.color(",", :list, inspect_opts),
      IA.color("]", :list, inspect_opts)
    }

    # TODO: This is a palliative accessibility-related solution
    precision = inspect_opts.custom_options[:nx_precision]

    {data, _rest, _limit} =
      shape
      |> Tuple.to_list()
      |> chunk(binary, type, inspect_opts.limit, precision, docs)

    data
  end

  # Recursively renders one level of the tensor per dimension.
  # Returns `{doc, remaining_binary, remaining_limit}`.
  #
  # With no dimensions left, reads exactly one element of `type` from the
  # binary and consumes one unit of the limit (unless it is `:infinity`).
  defp chunk([], data, type, limit, precision, _docs) do
    {doc, tail} =
      Nx.Shared.match_types [type] do
        <<match!(head, 0), tail::binary>> = data
        {inspect_value(read!(head, 0), precision), tail}
      end

    if limit == :infinity, do: {doc, tail, limit}, else: {doc, tail, limit - 1}
  end

  defp chunk([dim | dims], data, type, limit, precision, {open, sep, close} = docs) do
    {acc, rest, limit} =
      chunk_each(dim, data, [], limit, fn chunk, limit ->
        chunk(dims, chunk, type, limit, precision, docs)
      end)

    # The innermost dimension is printed on a single line ("[1, 2, 3]"); outer
    # dimensions put each child on its own line, nested by two spaces.
    {open, sep, close, nest} =
      if dims == [] do
        {open, IA.concat(sep, " "), close, 0}
      else
        {IA.concat(open, IA.line()), IA.concat(sep, IA.line()), IA.concat(IA.line(), close), 2}
      end

    doc =
      open
      |> IA.concat(IA.concat(Enum.intersperse(acc, sep)))
      |> IA.nest(nest)
      |> IA.concat(close)

    {doc, rest, limit}
  end

  # Renders up to `dim` children with `fun`, threading the binary and the limit.
  # The `dim == 0` clause comes first so that exhausting the limit exactly at the
  # last child does not append a spurious "...".
  defp chunk_each(0, data, acc, limit, _fun) do
    {Enum.reverse(acc), data, limit}
  end

  # The limit ran out while children remain: elide them with "..." and stop
  # reading from the binary.
  defp chunk_each(_dim, data, acc, 0, _fun) do
    {Enum.reverse(["..." | acc]), data, 0}
  end

  defp chunk_each(dim, data, acc, limit, fun) do
    {doc, rest, limit} = fun.(data, limit)
    chunk_each(dim - 1, rest, [doc | acc], limit, fun)
  end

  # Converts a single element read from the binary into its printed form.
  defp inspect_value(integer, _) when is_integer(integer), do: Integer.to_string(integer)
  defp inspect_value(:neg_infinity, _), do: "-Inf"
  defp inspect_value(:infinity, _), do: "Inf"
  defp inspect_value(:nan, _), do: "NaN"
  defp inspect_value(%Complex{} = val, precision), do: complex_to_string(val, precision)
  defp inspect_value(float, precision), do: float_to_string(float, precision)

  # Renders a float with `Float.to_string/1`, then truncates its fractional
  # digits to at most `precision` (when given) and removes trailing zeros,
  # always keeping at least one fractional digit. Any exponent is preserved.
  defp float_to_string(float, precision) do
    [integer_part, decimal_part, exponent_part] =
      case String.split(Float.to_string(float), [".", "e"], parts: 3) do
        [i, d] -> [i, d, ""]
        [i, d, e] -> [i, d, "e" <> e]
      end

    decimal_part =
      decimal_part
      |> truncate_decimal(precision)
      |> String.trim_trailing("0")

    # Trimming can remove every digit (e.g. "00" left after truncating
    # "0012" to 2 digits, or a precision of 0). Keep a single "0" so the
    # result reads like "1.0" instead of the malformed "1.".
    decimal_part = if decimal_part == "", do: "0", else: decimal_part

    integer_part <> "." <> decimal_part <> exponent_part
  end

  # No precision configured: keep every digit produced by `Float.to_string/1`.
  defp truncate_decimal(digits, nil), do: digits

  defp truncate_decimal(digits, precision) when is_integer(precision) and precision >= 0 do
    binary_part(digits, 0, min(byte_size(digits), precision))
  end

  defp truncate_decimal(_digits, precision) do
    raise ArgumentError,
          ":nx_precision must be nil or a non-negative integer, got: #{inspect(precision)}"
  end

  @doc """
  Renders a complex number as `<re><sign><im>i`, for example `"1.0+2.0i"`.

  Both parts are printed with the same rules as real elements, honoring
  `precision` (see `inspect/3`). A `+` sign is inserted before the imaginary
  part unless its printed form already starts with `-`.
  """
  @spec complex_to_string(Complex.t(), nil | non_neg_integer()) :: String.t()
  def complex_to_string(%Complex{re: re, im: im}, precision) do
    re_str = inspect_value(re, precision)

    im_str =
      case inspect_value(im, precision) do
        "-" <> _ = negative -> negative
        non_negative -> "+" <> non_negative
      end

    re_str <> im_str <> "i"
  end
end