defmodule EXLA.Op do
@moduledoc """
Wrapper around XLA's ops.
"""
alias __MODULE__
alias EXLA.{Builder, Computation, Shape}
@enforce_keys [:builder, :ref]
defstruct [:builder, :ref]
## Constructors
@doc """
Creates a scalar (rank-0) constant of type `dtype`.

`value` may be:

  * an atom, in which case the binary encoding is obtained by calling the
    `Nx.Type` function named `<atom>_binary` with `dtype`
  * a `Complex` struct, accepted for `{:c, 64}` and `{:c, 128}` types
  * a number, which is cast to `dtype` before being handed to the NIF

Raises if the NIF reports an error.
"""
def constant_r0(%Builder{} = builder, non_finite, dtype) when is_atom(non_finite) do
  encoder = :"#{non_finite}_binary"
  encoded = apply(Nx.Type, encoder, [dtype])
  scalar_shape = EXLA.Shape.make_shape(dtype, {})
  constant_from_binary(builder, encoded, scalar_shape)
end

def constant_r0(%Builder{} = builder, %Complex{re: real, im: imaginary}, {:c, size} = dtype) do
  # Each component takes half of the complex type's bit size.
  encoded =
    case size do
      64 -> <<real::32-float-native, imaginary::32-float-native>>
      128 -> <<real::64-float-native, imaginary::64-float-native>>
    end

  scalar_shape = Shape.make_shape(dtype, {})
  constant_from_binary(builder, encoded, scalar_shape)
end

def constant_r0(%Builder{ref: builder}, value, {_, _} = dtype) when is_number(value) do
  number = cast_number!(dtype, value)
  type_name = Shape.dtype_to_charlist(dtype)
  ref = unwrap!(EXLA.NIF.constant_r0(builder, number, type_name))
  %Op{builder: builder, ref: ref}
end
# Casts `number` to the given type. Predicates only accept the integers 0 and 1;
# every other type is delegated to `Nx.Type.cast_number!/2`.
defp cast_number!({:pred, 8}, number) do
  case number do
    0 -> 0
    1 -> 1
    other -> raise("cannot cast #{inspect(other)} to {:pred, 8}")
  end
end

defp cast_number!(type, number) do
  Nx.Type.cast_number!(type, number)
end
@doc """
Creates an n-dimensional constant of the given `shape` from binary `data`.

Raises `ArgumentError` when the bit size of `data` differs from the type's
bit size multiplied by the product of the shape's dimensions.
"""
def constant_from_binary(%Builder{ref: builder}, data, %Shape{} = shape)
    when is_binary(data) do
  %{dims: dims, dtype: {_, element_bits}, ref: shape_ref} = shape
  expected_bits = element_bits * tuple_product(dims)

  if bit_size(data) == expected_bits do
    ref = unwrap!(EXLA.NIF.constant_from_binary(builder, data, shape_ref))
    %Op{builder: builder, ref: ref}
  else
    raise ArgumentError, "binary does not match the given type and dimensions"
  end
end
@doc """
Specifies the parameter at position `i` in `builder`, with the given
`shape` and `name`.

`i` must be a non-negative integer and `name` a binary.
"""
def parameter(%Builder{ref: builder}, i, %Shape{ref: shape_ref}, name)
    when is_integer(i) and i >= 0 and is_binary(name) do
  param_ref = unwrap!(EXLA.NIF.parameter(builder, i, shape_ref, name))
  %Op{builder: builder, ref: param_ref}
end
@doc """
Builds a tuple operator out of the given list of `elements`.
"""
def tuple(%Builder{ref: builder}, elements) when is_list(elements) do
  refs = Enum.map(elements, fn element -> element.ref end)
  tuple_ref = unwrap!(EXLA.NIF.tuple(builder, refs))
  %Op{builder: builder, ref: tuple_ref}
end
@doc """
Creates a tensor of the given `shape` with a normal distribution
parameterised by the `mu` and `sigma` operators.

Both operators must belong to the same builder.
"""
def rng_normal(
      %Op{builder: builder, ref: mu_ref},
      %Op{builder: builder, ref: sigma_ref},
      %Shape{ref: shape_ref}
    ) do
  result = unwrap!(EXLA.NIF.rng_normal(mu_ref, sigma_ref, shape_ref))
  %Op{builder: builder, ref: result}
end
@doc """
Creates a tensor of the given `shape` with a uniform distribution
parameterised by the operators `a` and `b`.

Both operators must belong to the same builder.
"""
def rng_uniform(
      %Op{builder: builder, ref: a_ref},
      %Op{builder: builder, ref: b_ref},
      %Shape{ref: shape_ref}
    ) do
  result = unwrap!(EXLA.NIF.rng_uniform(a_ref, b_ref, shape_ref))
  %Op{builder: builder, ref: result}
end
@doc """
Creates an iota tensor of the given `shape` along dimension `dim`.
"""
def iota(%Builder{ref: builder}, %Shape{ref: shape_ref}, dim) do
  iota_ref = unwrap!(EXLA.NIF.iota(builder, shape_ref, dim))
  %Op{builder: builder, ref: iota_ref}
end
## Shape
@doc """
Gets the shape of an operator.

The shape reference returned by the NIF is converted with
`Shape.get_shape_info/1`.
"""
def get_shape(%Op{builder: builder, ref: operand}) do
  builder
  |> EXLA.NIF.get_shape(operand)
  |> unwrap!()
  |> Shape.get_shape_info()
end
@doc """
Reshapes the tensor to `shape`, given as a tuple of dimensions.
"""
def reshape(%Op{ref: operand} = op, shape) when is_tuple(shape) do
  reshaped = unwrap!(EXLA.NIF.reshape(operand, shape))
  %Op{op | ref: reshaped}
end
@doc """
Pads the tensor with the `value` operator according to `padding_config`.

The operand and the padding value must belong to the same builder.
"""
def pad(
      %Op{ref: operand, builder: builder},
      %Op{ref: pad_value, builder: builder},
      padding_config
    ) do
  padded = unwrap!(EXLA.NIF.pad(operand, pad_value, padding_config))
  %Op{builder: builder, ref: padded}
end
@doc """
Broadcasts the tensor to `shape` using the `broadcast_dims` tuple.

Both `shape` and `broadcast_dims` must be tuples.
"""
def broadcast_in_dim(%Op{ref: operand} = op, shape, broadcast_dims)
    when is_tuple(shape) and is_tuple(broadcast_dims) do
  broadcasted = unwrap!(EXLA.NIF.broadcast_in_dim(operand, shape, broadcast_dims))
  %Op{op | ref: broadcasted}
end
## Element-wise binary ops
arith = [:add, :subtract, :multiply, :divide, :max, :min, :remainder, :atan2, :power]
bitwise = [:bitwise_and, :bitwise_or, :bitwise_xor]
shift = [:left_shift, :right_shift_arithmetic, :right_shift_logical]
comparison = [:equal, :not_equal, :greater, :less, :greater_equal, :less_equal]

# One public function is generated per name; each forwards to the NIF of the same name.
binary_ops = Enum.concat([arith, bitwise, shift, comparison])

for op_name <- binary_ops do
  @doc """
  Element-wise #{op_name} with broadcasting.
  """
  def unquote(op_name)(
        %Op{builder: builder, ref: lhs},
        %Op{builder: builder, ref: rhs},
        broadcast_dims \\ {}
      )
      when is_tuple(broadcast_dims) do
    result = unwrap!(EXLA.NIF.unquote(op_name)(lhs, rhs, broadcast_dims))
    %Op{builder: builder, ref: result}
  end
end
## Element-wise unary ops
returns_float =
  [:exp, :expm1, :log, :log1p, :sigmoid, :cos, :sin, :tanh, :sqrt, :rsqrt, :cbrt] ++
    [:acosh, :asinh, :atanh, :acos, :asin, :atan, :cosh, :sinh] ++
    [:erf, :erfc, :erf_inv]

returns_any = [:negate]
requires_int = [:count_leading_zeros, :population_count, :bitwise_not]
requires_signed = [:abs, :sign]
requires_float = [:floor, :ceil, :round]

# One public function is generated per name; each forwards to the NIF of the same name.
unary_ops =
  Enum.concat([returns_float, returns_any, requires_int, requires_signed, requires_float])

for op_name <- unary_ops do
  @doc """
  Unary #{op_name}.
  """
  def unquote(op_name)(%Op{ref: operand} = op) do
    result = unwrap!(EXLA.NIF.unquote(op_name)(operand))
    %Op{op | ref: result}
  end
end
@doc """
Calls the `fft` NIF on the operand with the given `fft_size`.
"""
def fft(%Op{ref: operand} = op, fft_size) do
  transformed = unwrap!(EXLA.NIF.fft(operand, fft_size))
  %Op{op | ref: transformed}
end
@doc """
Calls the `ifft` NIF on the operand with the given `fft_size`.
"""
def ifft(%Op{ref: operand} = op, fft_size) do
  transformed = unwrap!(EXLA.NIF.ifft(operand, fft_size))
  %Op{op | ref: transformed}
end
@doc """
Element-wise NaN check; see `is_non_finite/6` for how each type is handled.
"""
def is_nan(op, type, shape, axes, state) do
  is_non_finite(&EXLA.NIF.is_nan/1, op, type, shape, axes, state)
end
@doc """
Element-wise infinity check; see `is_non_finite/6` for how each type is handled.
"""
def is_infinity(op, type, shape, axes, state) do
  is_non_finite(&EXLA.NIF.is_infinity/1, op, type, shape, axes, state)
end
@doc """
Shared implementation of `is_nan/5` and `is_infinity/5`.

`nif_function` is the one-argument NIF performing the check. Depending on `type`:

  * complex types: the check is applied to the real and imaginary parts
    separately and the two results are combined with `bitwise_or`
  * `:f` and `:bf` types: the check is applied directly
  * any other type: a `{:u, 8}` zero constant is broadcast to `shape`
    using `axes`, built with the builder found in `state`
"""
def is_non_finite(nif_function, %{ref: ref} = op, {:c, _}, _shape, _axes, _state) do
  real_ref = unwrap!(EXLA.NIF.real(ref))
  real_check = unwrap!(nif_function.(real_ref))

  imag_ref = unwrap!(EXLA.NIF.imag(ref))
  imag_check = unwrap!(nif_function.(imag_ref))

  combined = unwrap!(EXLA.NIF.bitwise_or(real_check, imag_check, {}))
  %{op | ref: combined}
end

def is_non_finite(nif_function, op, {kind, _}, _shape, _axes, _state)
    when kind in [:f, :bf] do
  %{ref: ref} = op
  checked = unwrap!(nif_function.(ref))
  %{op | ref: checked}
end

def is_non_finite(_nif_function, _op, _type, shape, axes, %{builder: builder}) do
  # Non-floating types can never hold a non-finite value, so the answer is
  # a boolean 0 tensor in the output shape.
  zero = constant_r0(builder, 0, {:u, 8})
  unit_dims = Tuple.duplicate(1, tuple_size(shape))
  broadcast_in_dim(reshape(zero, unit_dims), shape, List.to_tuple(axes))
end
## Ops
@doc """
Extracts the element at the integer `index` from a tuple operator.
"""
def get_tuple_element(%Op{ref: tuple_ref} = op, index) when is_integer(index) do
  element = unwrap!(EXLA.NIF.get_tuple_element(tuple_ref, index))
  %Op{op | ref: element}
end
@doc """
Two-branch conditional: passes the predicate, the operand and computation
for the true branch, and the operand and computation for the false branch
to the `conditional` NIF.

All three operators must belong to the same builder.
"""
def conditional(
      %Op{builder: builder, ref: pred_ref},
      %Op{builder: builder, ref: true_operand},
      %Computation{ref: true_branch},
      %Op{builder: builder, ref: false_operand},
      %Computation{ref: false_branch}
    ) do
  result =
    unwrap!(
      EXLA.NIF.conditional(pred_ref, true_operand, true_branch, false_operand, false_branch)
    )

  %Op{builder: builder, ref: result}
end
@doc """
Indexed conditional: passes the `index` operator together with the refs of
`branches` and of `operands` to the `conditional` NIF.
"""
def conditional(%Op{builder: builder, ref: index_ref}, branches, operands) do
  branch_refs = Enum.map(branches, fn branch -> branch.ref end)
  operand_refs = Enum.map(operands, fn operand -> operand.ref end)

  result = unwrap!(EXLA.NIF.conditional(index_ref, branch_refs, operand_refs))
  %Op{builder: builder, ref: result}
end
@doc """
Calls the `select` NIF with the predicate, `on_true` and `on_false` operators.

All three operators must belong to the same builder.
"""
def select(
      %Op{builder: builder, ref: pred_ref},
      %Op{builder: builder, ref: true_ref},
      %Op{builder: builder, ref: false_ref}
    ) do
  selected = unwrap!(EXLA.NIF.select(pred_ref, true_ref, false_ref))
  %Op{builder: builder, ref: selected}
end
@doc """
Slices the operand using the lists `start_indices`, `limit_indices` and `strides`.
"""
def slice(%Op{builder: builder, ref: operand}, start_indices, limit_indices, strides)
    when is_list(start_indices) and is_list(limit_indices) and is_list(strides) do
  sliced = unwrap!(EXLA.NIF.slice(operand, start_indices, limit_indices, strides))
  %Op{builder: builder, ref: sliced}
end
@doc """
Slices the operand at start positions given by a list of index operators,
taking `slice_sizes` (a list) in each dimension.
"""
def dynamic_slice(%Op{builder: builder, ref: operand}, indices, slice_sizes)
    when is_list(indices) and is_list(slice_sizes) do
  index_refs = for index <- indices, do: index.ref
  sliced = unwrap!(EXLA.NIF.dynamic_slice(operand, index_refs, slice_sizes))
  %Op{builder: builder, ref: sliced}
end
@doc """
Calls the `dynamic_update_slice` NIF with the operand, the `update` operator
and the refs of the index operators in `indices`.

The operand and the update must belong to the same builder.
"""
def dynamic_update_slice(
      %Op{builder: builder, ref: operand},
      %Op{builder: builder, ref: update_ref},
      indices
    )
    when is_list(indices) do
  index_refs = for index <- indices, do: index.ref
  updated = unwrap!(EXLA.NIF.dynamic_update_slice(operand, update_ref, index_refs))
  %Op{builder: builder, ref: updated}
end
@doc """
The XLA gather operation stitches together several slices
of an input array.
Note that this operation is extremely generic and far from
intuitive for regular usage. However, it can be used to implement
many specific operations that have to do with combining multiple
tensor slices.
## Parameters
The XLA docs are rather cryptic unless already understood,
so here's an attempt of a more intuitive description.
### `index_vector_dim`
Determines which dimension contains index vectors. In most cases
we want to set this to the last dimension.
given
start_indices = [[0, 1], [1, 1]]
and given
index_vector_dim = 1
then
index vectors are [0, 1] and [1, 1]
Note that we can set this to `last_dimension + 1`, in which case
`start_indices` are implicitly reshaped to have a trailing dimension
of 1.
given
start_indices = [[0, 1], [1, 1]]
and given
index_vector_dim = 2
then
start_indices <- [[[0], [1]], [[1], [1]]]
index vectors are [0], [1], [1], [1]
### `start_index_map`
Note: though given as a list, it can be treated as a map of `list_idx -> value`.
An index vector may have fewer elements than the operand tensor shape.
For example:
given
operand = [[1, 2], [3, 4]]
start_indices = [[1], [0]]
index_vector_dim = 1
As described above, in this case index vectors are `[1]`, `[0]` and they have
length 1. However, the operand has rank 2, so we need vectors of the form `[_, _]`
to point to a specific element in the operand. The `start_index_map` determines
where indices go into this template:
and given
start_index_map = [0] # effectively %{0 => 0}
then
actual index vectors are [1, _] and [0, _]
and given
start_index_map = [1] # effectively %{0 => 1}
then
actual index vectors are [_, 1] and [_, 0]
Finally, the missing elements (`_`) are assumed to be 0.
Complete examples:
given
operand = [[1, 2], [3, 4]]
start_indices = [[0], [1]]
index_vector_dim = 1
and given
start_index_map = [1] # effectively %{0 => 1}
then
actual index vectors are [0, 0], [0, 1] (leading 0 is inserted)
given
operand = [[1, 2], [3, 4]]
start_indices = [[0, 1], [1, 1]]
index_vector_dim = 1
and given
start_index_map = [0, 1] # effectively %{0 => 0, 1 => 1}
then
actual index vectors are [0, 1], [1, 1] (as expected)
given
operand = [[1, 2], [3, 4]]
start_indices = [[0, 1], [1, 1]]
index_vector_dim = 1
and given
start_index_map = [1, 0] # effectively %{0 => 1, 1 => 0}
then
actual index vectors are [1, 0], [1, 1] (see how the first vector is reversed)
### `slice_sizes`
For every starting point (as described above) we take a slice given
by `slice_sizes`. Naturally, `slice_sizes` must have the same length
as operand rank, so that we have one size per dimension.
given
operand = [[1, 2], [3, 4]]
actual index vector [1, 0]
and given
slice_sizes = [1, 2]
then
slice for actual index vector is [[3, 4]]
### `collapsed_slice_dims`
A list of dimensions that are collapsed (effectively removed) in
the slice shape. Only dimensions of size 1 can be collapsed.
given
slice is [[3, 4]] # shape: [1][2]
and given
collapsed_slice_dims = [0]
then
actual slice is [3, 4] # shape [2]
### `offset_dims`
A list of dimensions in the output tensor corresponding to the
non-collapsed dimensions in slice tensors. In other words, these
dimensions are used for indexing elements of the slice tensors.
given
operand = [[1, 2], [3, 4]]
start_indices = [[1, 0], [0, 0], [1, 0]]
index_vector_dim = 1
start_index_map = [0, 1] # effectively %{0 => 0, 1 => 1}
collapsed_slice_dims = [0]
and given
offset_dims = [1]
then
result is [[3, 4], [1, 2], [3, 4]]
In the above example the collapsed slices are `[3, 4]`, `[1, 2]`, `[3, 4]`
and have rank 1. Using `offset_dims` we specify that the first
dimension in each slice corresponds to the second dimension in
the output tensor.
If we use the first output dimension instead, we get:
and given
offset_dims = [0]
then
result is [[3, 1, 3], [4, 2, 4]]
## Docs
More formal specification can be found in [the XLA Gather docs](https://www.tensorflow.org/xla/operation_semantics#gather).
"""
# The operand and the start indices must belong to the same builder.
# `index_vector_dim` must be an integer; the remaining options must be lists.
def gather(
      %Op{builder: builder, ref: operand},
      %Op{builder: builder, ref: indices_ref},
      index_vector_dim,
      slice_sizes,
      offset_dims,
      collapsed_slice_dims,
      start_index_map
    )
    when is_integer(index_vector_dim) and is_list(slice_sizes) and is_list(offset_dims) and
           is_list(collapsed_slice_dims) and is_list(start_index_map) do
  nif_result =
    EXLA.NIF.gather(
      operand,
      indices_ref,
      index_vector_dim,
      slice_sizes,
      offset_dims,
      collapsed_slice_dims,
      start_index_map
    )

  %Op{builder: builder, ref: unwrap!(nif_result)}
end
@doc """
Calls the `dot` NIF on two operators.

`precision_config` must be `:default`, `:high` or `:highest`; anything else
raises `ArgumentError`. Both operators must belong to the same builder.
"""
def dot(%Op{builder: builder, ref: lhs}, %Op{builder: builder, ref: rhs}, precision_config) do
  precision = get_precision_config_int(precision_config)
  product = unwrap!(EXLA.NIF.dot(lhs, rhs, precision))
  %Op{builder: builder, ref: product}
end
@doc """
Calls the `dot_general` NIF on two operators with the dimension numbers `dimnos`.

`precision_config` must be `:default`, `:high` or `:highest`; anything else
raises `ArgumentError`. Both operators must belong to the same builder.
"""
def dot_general(
      %Op{builder: builder, ref: lhs},
      %Op{builder: builder, ref: rhs},
      dimnos,
      precision_config
    ) do
  precision = get_precision_config_int(precision_config)
  product = unwrap!(EXLA.NIF.dot_general(lhs, rhs, dimnos, precision))
  %Op{builder: builder, ref: product}
end
@doc """
Calls the `conv_general_dilated` NIF with an operand and a kernel.

`strides`, `lhs_dilation` and `rhs_dilation` must be lists. `precision_config`
must be `:default`, `:high` or `:highest`. The operand and the kernel must
belong to the same builder.
"""
def conv_general_dilated(
      %Op{builder: builder, ref: operand},
      %Op{builder: builder, ref: kernel},
      strides,
      padding,
      lhs_dilation,
      rhs_dilation,
      dim_nums,
      feature_group_count,
      batch_group_count,
      precision_config
    )
    when is_list(strides) and is_list(lhs_dilation) and is_list(rhs_dilation) do
  precision = get_precision_config_int(precision_config)

  nif_result =
    EXLA.NIF.conv_general_dilated(
      operand,
      kernel,
      strides,
      padding,
      lhs_dilation,
      rhs_dilation,
      dim_nums,
      feature_group_count,
      batch_group_count,
      precision
    )

  %Op{builder: builder, ref: unwrap!(nif_result)}
end
@doc """
Transposes the operand according to the `permutation` tuple.
"""
def transpose(%Op{builder: builder, ref: operand}, permutation)
    when is_tuple(permutation) do
  transposed = unwrap!(EXLA.NIF.transpose(operand, permutation))
  %Op{builder: builder, ref: transposed}
end
@doc """
Calls the `reduce` NIF with the operand, the initial value, the `reduction`
computation and `reduction_dimensions`.

The operand and the initial value must belong to the same builder.
"""
def reduce(
      %Op{builder: builder, ref: operand},
      %Op{builder: builder, ref: init_ref},
      %Computation{ref: reducer},
      reduction_dimensions
    ) do
  reduced = unwrap!(EXLA.NIF.reduce(operand, init_ref, reducer, reduction_dimensions))
  %Op{builder: builder, ref: reduced}
end
@doc """
Calls the `variadic_reduce` NIF with the refs of `operands` and `init_values`,
the `reduction` computation and `reduction_dimensions`.
"""
def variadic_reduce(
      %Builder{ref: builder},
      operands,
      init_values,
      %Computation{ref: reducer},
      reduction_dimensions
    ) do
  operand_refs = Enum.map(operands, fn operand -> operand.ref end)
  init_refs = Enum.map(init_values, fn init -> init.ref end)

  nif_result =
    EXLA.NIF.variadic_reduce(builder, operand_refs, init_refs, reducer, reduction_dimensions)

  %Op{builder: builder, ref: unwrap!(nif_result)}
end
@doc """
Calls the `window_reduce` NIF with the operand, the initial value, the
`reduction` computation and the window configuration.

`window_dimensions` must be a tuple; `window_strides` and `window_dilations`
must be lists. The operand and the initial value must belong to the same builder.
"""
def window_reduce(
      %Op{builder: builder, ref: operand},
      %Op{builder: builder, ref: init_ref},
      %Computation{ref: reducer},
      window_dimensions,
      window_strides,
      window_dilations,
      padding_config
    )
    when is_tuple(window_dimensions) and is_list(window_strides) and
           is_list(window_dilations) do
  nif_result =
    EXLA.NIF.window_reduce(
      operand,
      init_ref,
      reducer,
      window_dimensions,
      window_strides,
      window_dilations,
      padding_config
    )

  %Op{builder: builder, ref: unwrap!(nif_result)}
end
@doc """
Calls the `select_and_scatter` NIF with the operand, the select computation,
the window configuration, the source and initial value operators and the
scatter computation.

`window_dimensions` must be a tuple; `window_strides` and `padding_config`
must be lists. The operand, source and initial value must belong to the
same builder.
"""
def select_and_scatter(
      %Op{builder: builder, ref: operand},
      %Computation{ref: selector},
      window_dimensions,
      window_strides,
      padding_config,
      %Op{builder: builder, ref: source_ref},
      %Op{builder: builder, ref: init_ref},
      %Computation{ref: scatterer}
    )
    when is_tuple(window_dimensions) and is_list(window_strides) and
           is_list(padding_config) do
  nif_result =
    EXLA.NIF.select_and_scatter(
      operand,
      selector,
      window_dimensions,
      window_strides,
      padding_config,
      source_ref,
      init_ref,
      scatterer
    )

  %Op{builder: builder, ref: unwrap!(nif_result)}
end
@doc """
Calls the `scatter` NIF with the target, indices and updates operators,
the `scatter_fn` computation and the dimension configuration.

`indices_rank` must be an integer; `update_window_dims`,
`inserted_window_dims` and `index_dims_to_window_dims` must be lists.

The target, indices and updates must all belong to the same builder,
as with every other multi-operand op in this module; operands coming from
a different builder do not match this clause.
"""
def scatter(
      %Op{builder: builder, ref: target},
      %Op{builder: builder, ref: indices},
      %Op{builder: builder, ref: updates},
      %Computation{ref: scatter_fn},
      indices_rank,
      update_window_dims,
      inserted_window_dims,
      index_dims_to_window_dims
    )
    when is_integer(indices_rank) and is_list(update_window_dims) and
           is_list(inserted_window_dims) and is_list(index_dims_to_window_dims) do
  ref =
    EXLA.NIF.scatter(
      target,
      indices,
      updates,
      scatter_fn,
      indices_rank,
      update_window_dims,
      inserted_window_dims,
      index_dims_to_window_dims
    )
    |> unwrap!()

  %Op{builder: builder, ref: ref}
end
@doc """
Calls the `map` NIF with the operand, the `function` computation and `dimensions`.
"""
def map(%Op{builder: builder, ref: operand}, %Computation{ref: mapper}, dimensions) do
  mapped = unwrap!(EXLA.NIF.map(builder, operand, mapper, dimensions))
  %Op{builder: builder, ref: mapped}
end
@doc """
Calls the `while` NIF with the condition computation, the body computation
and the initial value operator.
"""
def while(
      %Computation{ref: condition},
      %Computation{ref: body},
      %Op{builder: builder, ref: initial}
    ) do
  looped = unwrap!(EXLA.NIF.while(condition, body, initial))
  %Op{builder: builder, ref: looped}
end
@doc """
Calls the `body_fn` computation with the given list of argument operators.
"""
def call(%Builder{ref: builder}, args, %Computation{ref: body}) do
  # All argument refs travel as one list, so a single NIF parameter
  # carries however many arguments the computation takes.
  arg_refs = Enum.map(args, fn arg -> arg.ref end)
  called = unwrap!(EXLA.NIF.call(builder, arg_refs, body))
  %Op{builder: builder, ref: called}
end
@doc """
Calls the `convert_element_type` NIF on the operand with `dtype`,
passed as the charlist produced by `Shape.dtype_to_charlist/1`.
"""
def convert_element_type(%Op{builder: builder, ref: operand}, dtype) do
  type_name = Shape.dtype_to_charlist(dtype)
  converted = unwrap!(EXLA.NIF.convert_element_type(operand, type_name))
  %Op{builder: builder, ref: converted}
end
@doc """
Calls the `bitcast_convert_type` NIF on the operand with `dtype`,
passed as the charlist produced by `Shape.dtype_to_charlist/1`.
"""
def bitcast_convert_type(%Op{builder: builder, ref: operand}, dtype) do
  type_name = Shape.dtype_to_charlist(dtype)
  converted = unwrap!(EXLA.NIF.bitcast_convert_type(operand, type_name))
  %Op{builder: builder, ref: converted}
end
@doc """
Calls the `clamp` NIF with the operand and the `min` and `max` operators.

All three operators must belong to the same builder.
"""
def clamp(
      %Op{builder: builder, ref: operand},
      %Op{builder: builder, ref: lower_ref},
      %Op{builder: builder, ref: upper_ref}
    ) do
  clamped = unwrap!(EXLA.NIF.clamp(operand, lower_ref, upper_ref))
  %Op{builder: builder, ref: clamped}
end
@doc """
Calls the `reverse` NIF on the operand with the given `dimensions`.
"""
def reverse(%Op{builder: builder, ref: operand}, dimensions) do
  reversed = unwrap!(EXLA.NIF.reverse(operand, dimensions))
  %Op{builder: builder, ref: reversed}
end
@doc """
Concatenates a non-empty list of operators along `dimension`.

The builder is taken from the first operator in the list.
"""
def concatenate([first | _] = operands, dimension) do
  %Op{builder: builder} = first
  operand_refs = Enum.map(operands, fn operand -> operand.ref end)
  joined = unwrap!(EXLA.NIF.concatenate(builder, operand_refs, dimension))
  %Op{builder: builder, ref: joined}
end
@doc """
Calls the `conj` NIF on the operand.
"""
def conjugate(%Op{builder: builder, ref: operand}) do
  conjugated = unwrap!(EXLA.NIF.conj(operand))
  %Op{builder: builder, ref: conjugated}
end
@doc """
Calls the `real` NIF on the operand.
"""
def real(%Op{builder: builder, ref: operand}) do
  real_part = unwrap!(EXLA.NIF.real(operand))
  %Op{builder: builder, ref: real_part}
end
@doc """
Calls the `imag` NIF on the operand.
"""
def imag(%Op{builder: builder, ref: operand}) do
  imag_part = unwrap!(EXLA.NIF.imag(operand))
  %Op{builder: builder, ref: imag_part}
end
@doc """
Calls the `cholesky` NIF on the operand.
"""
def cholesky(%Op{builder: builder, ref: operand}) do
  factor = unwrap!(EXLA.NIF.cholesky(operand))
  %Op{builder: builder, ref: factor}
end
@doc """
Calls the `eigh` NIF on the operand with the `lower` flag.

Returns a two-element tuple of operators, in the order produced by the NIF.
"""
def eigh(%Op{builder: builder, ref: operand}, lower) do
  {first_ref, second_ref} = unwrap!(EXLA.NIF.eigh(operand, lower))
  wrap = fn ref -> %Op{builder: builder, ref: ref} end
  {wrap.(first_ref), wrap.(second_ref)}
end
@doc """
Calls the `lu` NIF on the operand.

Returns a three-element tuple of operators: the LU result, the pivots
and the permutation, in the order produced by the NIF.
"""
def lu(%Op{builder: builder, ref: operand}) do
  {lu_ref, pivot_ref, permutation_ref} = unwrap!(EXLA.NIF.lu(operand))
  wrap = fn ref -> %Op{builder: builder, ref: ref} end
  {wrap.(lu_ref), wrap.(pivot_ref), wrap.(permutation_ref)}
end
@doc """
Calls the `qr` NIF on the operand.

`full_matrices` must be a boolean; it is passed to the NIF as 1 or 0.
Returns a `{q, r}` tuple of operators.
"""
def qr(%Op{builder: builder, ref: operand}, full_matrices)
    when is_boolean(full_matrices) do
  full_flag = boolean_to_int(full_matrices)
  {q_ref, r_ref} = unwrap!(EXLA.NIF.qr(operand, full_flag))
  wrap = fn ref -> %Op{builder: builder, ref: ref} end
  {wrap.(q_ref), wrap.(r_ref)}
end
@doc """
Calls the `svd` NIF on the operand.

`precision` must be `:default`, `:high` or `:highest`; anything else raises
`ArgumentError`. Returns a three-element tuple of operators, in the order
produced by the NIF.
"""
def svd(%Op{builder: builder, ref: operand}, precision) do
  precision_int = get_precision_config_int(precision)
  {u_ref, d_ref, v_ref} = unwrap!(EXLA.NIF.svd(operand, precision_int))
  wrap = fn ref -> %Op{builder: builder, ref: ref} end
  {wrap.(u_ref), wrap.(d_ref), wrap.(v_ref)}
end
@doc """
Calls the `triangular_solve` NIF with the operators `a` and `b`.

`left_side`, `lower` and `unit_diagonal` must be booleans and are passed to
the NIF as 1 or 0. `transpose_a` must be `:none`, `:transpose` or
`:conjugate`. Both operators must belong to the same builder.
"""
def triangular_solve(
      %Op{builder: builder, ref: a},
      %Op{builder: builder, ref: b},
      left_side,
      lower,
      unit_diagonal,
      transpose_a
    )
    when is_boolean(left_side) and is_boolean(lower) and is_boolean(unit_diagonal) do
  left_flag = boolean_to_int(left_side)
  lower_flag = boolean_to_int(lower)
  unit_flag = boolean_to_int(unit_diagonal)

  # Integer codes for the transpose mode handed to the NIF.
  transpose_code =
    case transpose_a do
      :none -> 0
      :transpose -> 1
      :conjugate -> 2
    end

  solved =
    unwrap!(EXLA.NIF.triangular_solve(a, b, left_flag, lower_flag, unit_flag, transpose_code))

  %Op{builder: builder, ref: solved}
end
@doc """
Calls the `sort` NIF on the operand with the `comparator` computation
and `dimension`.
"""
def sort(%Op{builder: builder, ref: operand}, %Computation{ref: compare}, dimension) do
  sorted = unwrap!(EXLA.NIF.sort(operand, compare, dimension))
  %Op{builder: builder, ref: sorted}
end
@doc """
Calls the `variadic_sort` NIF with the refs of `operands`, the `comparator`
computation and `dimension`.
"""
def variadic_sort(%Builder{ref: builder}, operands, %Computation{ref: compare}, dimension) do
  operand_refs = Enum.map(operands, fn operand -> operand.ref end)
  sorted = unwrap!(EXLA.NIF.variadic_sort(operand_refs, compare, dimension))
  %Op{builder: builder, ref: sorted}
end
@doc """
Creates a token operator in `builder`.
"""
def create_token(%Builder{ref: builder}) do
  token = unwrap!(EXLA.NIF.create_token(builder))
  %Op{builder: builder, ref: token}
end
@doc """
Calls the `infeed` NIF with a token operator and the given `shape`.
"""
def infeed(%Op{builder: builder, ref: token}, %Shape{ref: shape_ref}) do
  fed = unwrap!(EXLA.NIF.infeed(token, shape_ref))
  %Op{builder: builder, ref: fed}
end
@doc """
Calls the `outfeed` NIF with the operand, a token operator and the
operand's shape, which is looked up through the builder first.

The operand and the token must belong to the same builder.
"""
def outfeed(%Op{builder: builder, ref: operand}, %Op{builder: builder, ref: token}) do
  operand_shape = unwrap!(EXLA.NIF.get_shape(builder, operand))
  fed = unwrap!(EXLA.NIF.outfeed(operand, token, operand_shape))
  %Op{builder: builder, ref: fed}
end
@doc """
Calls the `optimization_barrier` NIF on `operand` and wraps the result
in a new operator on the same builder.

Raises if the NIF reports an error.
"""
def optimization_barrier(%Op{builder: builder, ref: operand}) do
  ref = EXLA.NIF.optimization_barrier(operand) |> unwrap!()
  %Op{builder: builder, ref: ref}
end

# This operation used to be exposed only under the misleading name `outfeed/1`,
# even though it never performed an outfeed. The old name is kept as a thin
# alias so existing callers keep working; new code should call
# `optimization_barrier/1`.
@doc false
def outfeed(%Op{} = op), do: optimization_barrier(op)
## Helpers
# Translates a precision atom into the integer code handed to the NIFs.
# Raises ArgumentError for anything other than :default, :high or :highest.
defp get_precision_config_int(:default), do: 0
defp get_precision_config_int(:high), do: 1
defp get_precision_config_int(:highest), do: 2

defp get_precision_config_int(precision_config) do
  raise ArgumentError,
        "expected precision configuration to be one of" <>
          " :default, :high, or :highest, got: #{inspect(precision_config)}"
end
# Encodes a boolean as 1 (true) or 0 (false); no clause matches other values.
defp boolean_to_int(flag) when is_boolean(flag) do
  if flag, do: 1, else: 0
end
# Multiplies all elements of a tuple together; the empty tuple yields 1.
defp tuple_product(tuple) do
  tuple
  |> Tuple.to_list()
  |> Enum.reduce(1, fn dim, product -> dim * product end)
end
# Returns the payload of an `{:ok, value}` NIF result. For `{:error, reason}`
# the reason is converted with `List.to_string/1` and raised as the message.
defp unwrap!({:ok, value}) do
  value
end

defp unwrap!({:error, reason}) do
  message = List.to_string(reason)
  raise message
end
end