defmodule EXLA.Backend do
@moduledoc """
An Nx tensor backend for the data kept on the device.

You can directly transfer to this backend by calling
`Nx.backend_transfer/2` or `Nx.backend_copy/2`. It
allows the following options:

  * `:client` - the client to store the data on.
    Defaults to EXLA's default client.

  * `:device_id` - which device to store it on.
    Defaults to the client's default device.

To get the data out of the device backend into a regular
tensor, call `Nx.backend_transfer/1` (with the device
tensor as the single argument).

Note that the `EXLA.Backend` is asynchronous: operations
on its tensors *may* return immediately, before the tensor
data is available. The backend will then block only when
trying to read the data or when passing it to another operation.
"""
# This module implements the callbacks of the `Nx.Backend` behaviour.
@behaviour Nx.Backend
# The backend struct has a single, required field: `:buffer`. Within this module
# it is filled with the result of `EXLA.DeviceBuffer.place_on_device/4` or
# `EXLA.DeviceBuffer.copy_to_device/3`.
@enforce_keys [:buffer]
defstruct [:buffer]
alias Nx.Tensor, as: T
alias EXLA.Backend, as: B
@impl true
def init(opts) do
  # Only these option keys are accepted; `Keyword.validate!/2` raises
  # `ArgumentError` when `opts` contains any other key.
  allowed_keys = [:client, :device_id]
  Keyword.validate!(opts, allowed_keys)
end
@impl true
def constant(out, constant, backend_options) do
  # The constant is first built with the binary backend (using empty options)
  # and the resulting tensor is then transferred to this backend with the
  # caller-supplied options.
  out
  |> Nx.BinaryBackend.constant(constant, [])
  |> Nx.BinaryBackend.backend_transfer(__MODULE__, backend_options)
end
@impl true
def from_binary(%T{shape: dims, type: type} = tensor, binary, backend_options) do
  # Resolve which client/device the options refer to before touching the data.
  {client, device_id} = client_and_device_id(backend_options)

  # Place the raw binary on that device, described by an EXLA shape built from
  # the tensor's type and dimensions.
  device_buffer =
    EXLA.DeviceBuffer.place_on_device(
      binary,
      EXLA.Shape.make_shape(type, dims),
      client,
      device_id
    )

  # Return the same tensor with its data replaced by this backend's struct.
  %{tensor | data: %B{buffer: device_buffer}}
end
@impl true
def backend_copy(%T{data: %B{buffer: source}} = tensor, __MODULE__, backend_options) do
  # Copy between two EXLA.Backend locations.
  {client, device_id} = client_and_device_id(backend_options)

  if same_client_device?(source, client, device_id) do
    # We cannot copy to the same client/device using copy_to_device, so the
    # data is read back and placed again as a new buffer.
    contents = EXLA.DeviceBuffer.read(source)
    from_binary(tensor, contents, backend_options)
  else
    copied = EXLA.DeviceBuffer.copy_to_device(source, client, device_id)
    %{tensor | data: %B{buffer: copied}}
  end
end

def backend_copy(%T{data: %B{buffer: source}} = tensor, backend, backend_options) do
  # Any other backend receives the data as a binary read from the device and
  # builds its own representation through its `from_binary/3`.
  contents = EXLA.DeviceBuffer.read(source)
  backend.from_binary(tensor, contents, backend_options)
end
@impl true
def backend_transfer(%T{data: %B{buffer: source}} = tensor, backend, backend_options) do
  # `and` short-circuits: the options are only resolved to a client/device
  # when the target backend is this module.
  already_in_place? = backend == __MODULE__ and same_client_device?(source, backend_options)

  if already_in_place? do
    # Same backend, same client and device: nothing to move.
    tensor
  else
    try do
      backend_copy(tensor, backend, backend_options)
    after
      # The source buffer is deallocated whether or not the copy succeeded.
      EXLA.DeviceBuffer.deallocate(source)
    end
  end
end
@impl true
def backend_deallocate(%T{data: %B{} = data}) do
  # Delegates to the device buffer and returns whatever it returns.
  EXLA.DeviceBuffer.deallocate(data.buffer)
end
@impl true
def to_batched(out, tensor, opts) do
  leftover_mode = opts[:leftover]

  # The batch size is the first dimension of `out`; the number of entries to
  # split is the first dimension of `tensor`.
  batch_size = elem(out.shape, 0)
  total = elem(tensor.shape, 0)

  full_batches = div(total, batch_size)
  tail_size = rem(total, batch_size)

  # One extra batch is produced only when there is a partial tail and the
  # caller asked for `leftover: :repeat`.
  last_index =
    if tail_size != 0 and leftover_mode == :repeat do
      full_batches
    else
      full_batches - 1
    end

  # Lazy stream: each batch is computed when the stream is consumed.
  Stream.map(0..last_index, fn index ->
    if index == full_batches do
      # Tail batch: the leftover entries followed by entries taken from the
      # start of the tensor, so the batch has `batch_size` entries.
      tail_fun = fn t ->
        Nx.concatenate([
          Nx.slice_along_axis(t, full_batches * batch_size, tail_size),
          Nx.slice_along_axis(t, 0, batch_size - tail_size)
        ])
      end

      jit([], tail_fun, [tensor])
    else
      # Regular batch: the start offset is handed over as a JIT argument
      # instead of being captured by the closure.
      batch_fun = fn t, offset -> Nx.slice_along_axis(t, offset, batch_size) end
      jit([], batch_fun, [tensor, index * batch_size])
    end
  end)
end
@impl true
def to_binary(%T{data: %B{buffer: buffer}, type: {_kind, bits}}, limit) do
  # `limit` counts elements; the read size is expressed in bytes, so scale it
  # by the element width (the type's bit size divided by 8).
  bytes_per_element = div(bits, 8)
  EXLA.DeviceBuffer.read(buffer, limit * bytes_per_element)
end
@impl true
def inspect(%T{} = tensor, inspect_opts) do
  # Ask for one element more than the configured limit, unless it is unbounded.
  limit =
    case inspect_opts.limit do
      :infinity -> :infinity
      count -> count + 1
    end

  # In Erlang term ordering every number sorts before every atom, so
  # `min(:infinity, size)` yields the tensor size.
  element_count = min(limit, Nx.size(tensor))
  binary = to_binary(tensor, element_count)

  # Render the data, then optionally prefix the backend signature.
  rendered = Nx.Backend.inspect(tensor, binary, inspect_opts)
  maybe_add_signature(rendered, tensor)
end
# Which clause gets compiled is decided at compile time by the
# `:add_backend_on_inspect` setting of the `:exla` application (default: true).
if Application.compile_env(:exla, :add_backend_on_inspect, true) do
  # Prepends a line naming the backend, client, device and buffer reference.
  defp maybe_add_signature(result, %T{data: %B{buffer: buffer}}) do
    %EXLA.DeviceBuffer{client_name: client_name, device_id: device_id, ref: ref} = buffer

    # Strip the leading "#Ref<" from the printed reference; the remainder
    # supplies the closing ">" of the signature.
    ~c"#Ref<" ++ ref_tail = :erlang.ref_to_list(ref)

    header = "EXLA.Backend<#{client_name}:#{device_id}, #{ref_tail}"
    Inspect.Algebra.concat([header, Inspect.Algebra.line(), result])
  end
else
  # Signature disabled: the rendered tensor is returned untouched.
  defp maybe_add_signature(result, _tensor), do: result
end
## Helpers
# Resolves backend options into a `{client, device_id}` pair, falling back to
# EXLA's default client name and to that client's default device.
defp client_and_device_id(opts) do
  client_name = opts[:client] || EXLA.Client.default_name()
  client = EXLA.Client.fetch!(client_name)
  {client, opts[:device_id] || client.default_device_id}
end
# Variant taking raw backend options: resolves them to a concrete
# client/device pair first, then delegates to the three-argument version.
defp same_client_device?(buffer, opts) do
  {resolved_client, resolved_device_id} = client_and_device_id(opts)
  same_client_device?(buffer, resolved_client, resolved_device_id)
end
# True when the buffer's client name and device id both match the given
# client and device id.
defp same_client_device?(buffer, client, device_id) do
  same_client? = buffer.client_name == client.name
  same_client? and buffer.device_id == device_id
end
## JIT callbacks
@impl true
def concatenate(out, tensors, axis) do
  template = Nx.to_template(out)

  # The tensors reach the JIT-ed function as one tuple argument; the plain
  # list is passed separately so `jit/4` can look at each tensor's client.
  build_expr = fn tuple ->
    Nx.Defn.Expr.concatenate(template, Tuple.to_list(tuple), axis)
  end

  jit([], build_expr, tensors, [List.to_tuple(tensors)])
end
@impl true
def slice(out, tensor, start_indices, lengths, strides) do
  template = Nx.to_template(out)
  static_indices? = Enum.all?(start_indices, &is_integer/1)

  if static_indices? do
    # Every start index is an integer: the closure captures them and the
    # tensor is the only JIT argument.
    build_expr = fn t ->
      Nx.Defn.Expr.slice(template, t, start_indices, lengths, strides)
    end

    jit([], build_expr, [tensor])
  else
    # At least one start index is not an integer: the indices are passed to
    # the JIT-ed function as a tuple argument next to the tensor.
    build_expr = fn t, index_tuple ->
      Nx.Defn.Expr.slice(template, t, Tuple.to_list(index_tuple), lengths, strides)
    end

    jit([], build_expr, [tensor | start_indices], [tensor, List.to_tuple(start_indices)])
  end
end
@impl true
def put_slice(out, tensor, start_indices, slice) do
  template = Nx.to_template(out)
  static_indices? = Enum.all?(start_indices, &is_integer/1)

  if static_indices? do
    # Integer start indices are captured by the closure; only the tensor and
    # the slice are JIT arguments.
    build_expr = fn t, s ->
      Nx.Defn.Expr.put_slice(template, t, start_indices, s)
    end

    jit([], build_expr, [tensor, slice])
  else
    # Otherwise the indices travel as a tuple argument between the tensor and
    # the slice.
    build_expr = fn t, index_tuple, s ->
      Nx.Defn.Expr.put_slice(template, t, Tuple.to_list(index_tuple), s)
    end

    client_candidates = [tensor, slice | start_indices]
    jit_args = [tensor, List.to_tuple(start_indices), slice]
    jit([], build_expr, client_candidates, jit_args)
  end
end
@impl true
def optional(name, args, fun) do
  # The leading run of tensor arguments becomes the JIT input; everything from
  # the first non-tensor argument onwards is captured by the closure as-is.
  {leading_tensors, trailing} = Enum.split_while(args, &match?(%T{}, &1))

  wrapper = fn tuple ->
    Nx.Defn.Expr.optional(name, Tuple.to_list(tuple) ++ trailing, fun)
  end

  jit([], wrapper, leading_tensors, [List.to_tuple(leading_tensors)])
end
# Names of the element-wise callbacks that take a `left` and a `right` operand:
# arithmetic, bitwise, comparison and logical operations, in that order.
binary_ops =
  ~w(add subtract multiply pow remainder divide atan2 min max quotient)a ++
    ~w(bitwise_and bitwise_or bitwise_xor left_shift right_shift)a ++
    ~w(equal not_equal greater less greater_equal less_equal)a ++
    ~w(logical_and logical_or logical_xor)a
# Names of the element-wise callbacks that take a single `tensor` operand.
unary_ops =
  ~w(exp expm1 log log1p sigmoid cos sin tan)a ++
    ~w(cosh sinh tanh acos asin atan acosh asinh atanh)a ++
    ~w(sqrt rsqrt cbrt is_nan is_infinity erf erfc erf_inv)a ++
    ~w(abs bitwise_not ceil conjugate floor negate round sign)a ++
    ~w(count_leading_zeros population_count real imag)a
# Compile-time table driving the generated callbacks further down in this module.
# Each entry is `{name, args, tensor_args}`:
#   * `name`        - the callback name, also the `Nx.Defn.Expr` function it calls
#   * `args`        - the callback's arguments after the leading `out`
#   * `tensor_args` - the subset of `args` handed to the JIT-ed function as
#                     parameters; the remaining arguments are captured by its closure
# An argument named `:backend_options` is additionally forwarded to `jit`.
callbacks =
[
# Creation callbacks (these receive backend options)
{:eye, [:backend_options], []},
{:iota, [:axis, :backend_options], []},
{:random_uniform, [:min, :max, :backend_options], [:min, :max]},
{:random_normal, [:mu, :sigma, :backend_options], [:mu, :sigma]},
# Type and shape callbacks
{:as_type, [:tensor], [:tensor]},
{:bitcast, [:tensor], [:tensor]},
{:reshape, [:tensor], [:tensor]},
{:squeeze, [:tensor, :axes], [:tensor]},
{:broadcast, [:tensor, :shape, :axes], [:tensor]},
{:transpose, [:tensor, :axes], [:tensor]},
{:pad, [:tensor, :pad_value, :padding_config], [:tensor, :pad_value]},
{:reverse, [:tensor, :axes], [:tensor]},
{:dot, [:left, :c1, :b1, :right, :c2, :b2], [:left, :right]},
{:clip, [:tensor, :min, :max], [:tensor, :min, :max]},
# Indexing and selection callbacks
{:take, [:tensor, :indices, :axis], [:tensor, :indices]},
{:take_along_axis, [:tensor, :indices, :axis], [:tensor, :indices]},
{:gather, [:input, :indices], [:input, :indices]},
{:select, [:pred, :on_true, :on_false], [:pred, :on_true, :on_false]},
{:conv, [:tensor, :kernel, :opts], [:tensor, :kernel]},
# Reductions
{:all, [:tensor, :opts], [:tensor]},
{:any, [:tensor, :opts], [:tensor]},
{:sum, [:tensor, :opts], [:tensor]},
{:product, [:tensor, :opts], [:tensor]},
{:reduce_max, [:tensor, :opts], [:tensor]},
{:reduce_min, [:tensor, :opts], [:tensor]},
{:argmax, [:tensor, :opts], [:tensor]},
{:argmin, [:tensor, :opts], [:tensor]},
{:reduce, [:tensor, :acc, :opts, :fun], [:tensor, :acc]},
# Window operations
{:window_reduce, [:tensor, :acc, :shape, :opts, :fun], [:tensor, :acc]},
{:window_sum, [:tensor, :shape, :opts], [:tensor]},
{:window_product, [:tensor, :shape, :opts], [:tensor]},
{:window_max, [:tensor, :shape, :opts], [:tensor]},
{:window_min, [:tensor, :shape, :opts], [:tensor]},
{:map, [:tensor, :opts, :fun], [:tensor]},
{:sort, [:tensor, :opts], [:tensor]},
{:argsort, [:tensor, :opts], [:tensor]},
{:window_scatter_max, [:tensor, :source, :init_value, :window_dims, :opts],
[:tensor, :source, :init_value]},
{:window_scatter_min, [:tensor, :source, :init_value, :window_dims, :opts],
[:tensor, :source, :init_value]},
{:indexed_add, [:tensor, :indices, :updates], [:tensor, :indices, :updates]},
{:indexed_put, [:tensor, :indices, :updates], [:tensor, :indices, :updates]},
# Linear algebra and transforms
{:cholesky, [:tensor], [:tensor]},
{:lu, [:tensor, :opts], [:tensor]},
{:qr, [:tensor, :opts], [:tensor]},
{:triangular_solve, [:a, :b, :opts], [:a, :b]},
{:eigh, [:tensor, :opts], [:tensor]},
{:fft, [:tensor, :opts], [:tensor]},
{:ifft, [:tensor, :opts], [:tensor]}
] ++
# Every binary op takes `left`/`right`, every unary op takes `tensor`; all of
# those arguments are tensor arguments.
for(op <- binary_ops, do: {op, [:left, :right], [:left, :right]}) ++
for(op <- unary_ops, do: {op, [:tensor], [:tensor]})
# Generates one public callback function per entry of `callbacks`.
for {name, args, tensor_args} <- callbacks do
# Turn the argument-name atoms into variables scoped to this module so they can
# be spliced into the generated function head and body.
args = Enum.map(args, &Macro.var(&1, __MODULE__))
tensor_args = Enum.map(tensor_args, &Macro.var(&1, __MODULE__))
# Use the callback's own `backend_options` variable when it declares one;
# otherwise the generated code passes a literal empty list to `jit`.
backend_options = Enum.find(args, [], &match?({:backend_options, _, _}, &1))
@impl true
def unquote(name)(out, unquote_splicing(args)) do
out = Nx.to_template(out)
# Only the tensor arguments are parameters of the JIT-ed function; the other
# arguments are captured from the enclosing callback.
expr_fun = fn unquote_splicing(tensor_args) ->
Nx.Defn.Expr.unquote(name)(out, unquote_splicing(args))
end
jit(unquote(backend_options), expr_fun, [unquote_splicing(tensor_args)])
end
end
# Shorthand for the common case where the JIT arguments are also the values
# inspected to pick a client.
defp jit(backend_options, fun, args) do
  jit(backend_options, fun, args, args)
end
# Runs `fun` through `EXLA.jit_apply/3` with `args`, choosing the client from
# (in order) the `:client` backend option, the clients of the EXLA-backed
# tensors found in `tensors`, and finally EXLA's default client.
defp jit(backend_options, fun, tensors, args) do
  inferred =
    Enum.reduce(tensors, nil, fn
      %T{data: %B{buffer: %EXLA.DeviceBuffer{client_name: name}}}, chosen ->
        cond do
          # First EXLA tensor seen, or no conflict with the current choice.
          chosen == nil or chosen == name -> name
          # Conflict: a client whose platform is `:host` never displaces the
          # client chosen so far.
          EXLA.Client.fetch!(name).platform == :host -> chosen
          # Conflict with a non-host client: it takes over.
          true -> name
        end

      # Entries that are not tensors on this backend are ignored.
      _other, chosen ->
        chosen
    end)

  client = backend_options[:client] || inferred || EXLA.Client.default_name()
  EXLA.jit_apply(fun, args, on_conflict: :force, client: client)
end
end