lib/nx/binary_backend.ex

defmodule Nx.BinaryBackend do
  @moduledoc """
  An opaque backend written in pure Elixir that stores
  the data in Elixir's binaries.

  This is the default backend used by the `Nx` module.
  The backend itself (and its data) is private and must
  not be accessed directly.
  """
  use Complex.Kernel

  @behaviour Nx.Backend

  @doc false
  defstruct [:state]

  alias Nx.Tensor, as: T
  alias Nx.BinaryBackend, as: B

  import Nx.Shared
  import Bitwise, only: [>>>: 2, &&&: 2]

  ## Creation

  @impl true
  def constant(%{type: type, shape: shape} = out, constant, _backend_options) do
    data = :binary.copy(number_to_binary(constant, type), Nx.size(shape))
    from_binary(out, data)
  end

  @impl true
  def random_uniform(%{type: type, shape: shape} = out, min, max, _backend_options) do
    min = scalar_to_number(min)
    max = scalar_to_number(max)

    gen_float = fn -> (max - min) * :rand.uniform() + min end

    gen =
      case type do
        {:s, _} -> fn -> min + :rand.uniform(max - min) - 1 end
        {:u, _} -> fn -> min + :rand.uniform(max - min) - 1 end
        {:c, _} -> fn -> Complex.new(gen_float.(), gen_float.()) end
        {_, _} -> gen_float
      end

    data = for _ <- 1..Nx.size(shape), into: "", do: number_to_binary(gen.(), type)
    from_binary(out, data)
  end

  @impl true
  def random_normal(%{type: type, shape: shape} = out, mu, sigma, _backend_options) do
    mu = scalar_to_number(mu)
    sigma = scalar_to_number(sigma)
    variance = sigma * sigma

    gen =
      case type do
        {:c, _} ->
          fn ->
            Complex.new(:rand.normal(mu, variance), :rand.normal(mu, variance))
          end

        _ ->
          fn -> :rand.normal(mu, variance) end
      end

    data =
      for _ <- 1..Nx.size(shape),
          into: "",
          do: number_to_binary(gen.(), type)

    from_binary(out, data)
  end

  @impl true
  def iota(%{shape: {}, type: type} = out, nil, _backend_options) do
    from_binary(out, number_to_binary(0, type))
  end

  def iota(%{shape: shape, type: type} = out, nil, backend_options) do
    t = iota(%T{type: type, shape: {Nx.size(shape)}, names: [nil]}, 0, backend_options)
    %{out | data: t.data}
  end

  def iota(%{shape: {n}, type: type} = out, 0, _backend_options) do
    data = for i <- 0..(n - 1), do: number_to_binary(i, type)
    from_binary(out, data)
  end

  def iota(%{shape: shape, type: type} = out, axis, _backend_options) do
    {dims_before, [dim | dims_after]} =
      shape
      |> Tuple.to_list()
      |> Enum.split(axis)

    # Number of repetitions of an index in memory
    repeat_blocks =
      dims_after
      |> Enum.reduce(1, &*/2)

    # Number of cycles of the counting pattern
    cycles =
      dims_before
      |> Enum.reduce(1, &*/2)

    data =
      for _ <- 1..cycles,
          i <- 0..(dim - 1),
          _ <- 1..repeat_blocks,
          into: "",
          do: number_to_binary(i, type)

    from_binary(out, data)
  end

  @impl true
  def eye(%{shape: {n, n}, type: type} = out, _backend_options) do
    one = number_to_binary(1, type)
    zero = number_to_binary(0, type)

    data =
      for i <- 1..n, j <- 1..n, into: <<>> do
        if i == j, do: one, else: zero
      end

    from_binary(out, data)
  end

  ## Conversions

  @impl true
  def from_binary(t, binary, _backend_options), do: from_binary(t, binary)

  defp from_binary(t, binary) when is_binary(binary), do: %{t | data: %B{state: binary}}
  defp from_binary(t, other), do: %{t | data: %B{state: IO.iodata_to_binary(other)}}

  @impl true
  def to_binary(%{type: {_backend_options, size}} = t, limit) do
    limit = limit * div(size, 8)
    binary = to_binary(t)

    if byte_size(binary) == limit do
      binary
    else
      binary_part(binary, 0, limit)
    end
  end

  defp to_binary(%T{data: %{state: data}}), do: data

  @impl true
  def backend_copy(tensor, backend, opts) do
    backend_transfer(tensor, backend, opts)
  end

  @impl true
  def backend_transfer(tensor, Nx.Tensor, _opts) do
    tensor
  end

  def backend_transfer(tensor, Nx.BinaryBackend, _opts) do
    tensor
  end

  def backend_transfer(tensor, backend, opts) do
    backend.from_binary(tensor, to_binary(tensor), opts)
  end

  @impl true
  def backend_deallocate(_tensor) do
    :ok
  end

  @impl true
  def to_batched(out, %{type: {_, size}} = tensor, opts) do
    leftover = opts[:leftover]

    batch_size = elem(out.shape, 0)
    axis_size = elem(tensor.shape, 0)

    remainder = rem(axis_size, batch_size)
    num_full_batches = div(axis_size, batch_size)

    range =
      if remainder != 0 and leftover == :repeat do
        0..num_full_batches
      else
        0..(num_full_batches - 1)
      end

    binary = to_binary(tensor)
    batch_bytes = Nx.size(out) * div(size, 8)

    Stream.map(range, fn
      ^num_full_batches ->
        before = num_full_batches * batch_bytes
        available = byte_size(binary) - before
        missing = batch_bytes - available

        from_binary(out, [binary_part(binary, before, available), binary_part(binary, 0, missing)])

      i ->
        from_binary(out, binary_part(binary, i * batch_bytes, batch_bytes))
    end)
  end

  ## Shape

  @impl true
  def reshape(out, tensor), do: from_binary(out, to_binary(tensor))

  @impl true
  def squeeze(out, tensor, _axes), do: from_binary(out, to_binary(tensor))

  ## Broadcast

  @impl true
  def broadcast(out, t, shape, axes) do
    from_binary(out, broadcast_data(t, shape, axes))
  end

  defp broadcast_data(%{shape: shape} = t, shape),
    do: to_binary(t)

  defp broadcast_data(t, shape),
    do: broadcast_data(t, shape, Nx.Shape.broadcast_axes(t.shape, shape))

  defp broadcast_data(%T{shape: {}} = t, shape, []) do
    t
    |> to_binary()
    |> :binary.copy(Nx.size(shape))
  end

  defp broadcast_data(%T{shape: old_shape, type: {_, size}} = t, new_shape, axes) do
    chunk_size = size * Nx.size(old_shape)

    new_shape
    |> Tuple.to_list()
    |> unary_broadcast(0, old_shape, 0, axes, to_binary(t), chunk_size)
    |> IO.iodata_to_binary()
  end

  # Old and new match
  defp unary_broadcast([dim | dims], axis, old_shape, old_pos, [axis | axes], data, chunk_size)
       when elem(old_shape, old_pos) == dim do
    chunk_size = div(chunk_size, dim)

    for <<chunk::size(chunk_size)-bitstring <- data>> do
      unary_broadcast(dims, axis + 1, old_shape, old_pos + 1, axes, chunk, chunk_size)
    end
  end

  # Implicit broadcasting
  defp unary_broadcast([dim | dims], axis, old_shape, old_pos, [axis | axes], data, chunk_size)
       when elem(old_shape, old_pos) == 1 do
    for _ <- 1..dim do
      unary_broadcast(dims, axis + 1, old_shape, old_pos + 1, axes, data, chunk_size)
    end
  end

  # Explicit broadcasting (unmapped axes)
  defp unary_broadcast([dim | dims], axis, old_shape, old_pos, axes, data, chunk_size) do
    for _ <- 1..dim do
      unary_broadcast(dims, axis + 1, old_shape, old_pos, axes, data, chunk_size)
    end
  end

  defp unary_broadcast([], _axis, _old_shape, _old_pos, [], data, _chunk_size) do
    data
  end

  ## Shape

  @impl true
  def transpose(out, %T{shape: shape, type: {_, size}} = t, axes) do
    data = to_binary(t)
    {list, min, max} = transpose_axes(shape, axes)
    weighted_shape = weighted_shape(shape, size)

    # The chunk size is computed based on all dimensions
    # before the minimum one being changed. For example,
    # for {0, 1, 2, 3} and the swap is between 1 and 2,
    # the chunk_size will be d1 * d2 * d3 * size.
    chunk_size = weighted_chunk(weighted_shape, min, size)

    # All of the major dimensions not being transposed can be
    # read at once. For example, for {0, 1, 2, 3} and the swap
    # is between 1 and 2, the read_size will be d3 * size.
    read_size = weighted_chunk(weighted_shape, max + 1, size)

    # And now how we will traverse
    traverse_list = Enum.map(list, &Enum.fetch!(weighted_shape, &1))

    data =
      for <<chunk::size(chunk_size)-bitstring <- data>> do
        weighted_traverse(traverse_list, chunk, read_size)
      end

    from_binary(out, data)
  end

  defp transpose_axes(shape, axes) do
    size = tuple_size(shape)
    {axes, min} = transpose_min(axes, 0)
    {axes, max} = transpose_max(Enum.reverse(axes), size - 1)
    {axes, min, max}
  end

  defp transpose_min([head | tail], head), do: transpose_min(tail, head + 1)
  defp transpose_min(tail, head), do: {tail, head}

  defp transpose_max([head | tail], head), do: transpose_max(tail, head - 1)
  defp transpose_max(tail, head), do: {Enum.reverse(tail), head}

  ## Pad

  # We ignore the out because we need to recur over the shape
  # as we transpose and build the rest.
  @impl true
  def pad(out, t, pad_value, padding_config) do
    pad_value = to_binary(as_type(out, pad_value))

    case t.shape do
      {} ->
        t

      {_} ->
        [{edge_low, edge_high, interior}] = padding_config
        pad_last_dim(t, pad_value, edge_low, edge_high, interior)

      _ ->
        permutation = for i <- 0..(Nx.rank(t) - 2), do: i
        permutation = [Nx.rank(t) - 1 | permutation]

        for {edge_low, edge_high, interior} <- Enum.reverse(padding_config), reduce: t do
          acc ->
            Nx.transpose(pad_last_dim(acc, pad_value, edge_low, edge_high, interior),
              axes: permutation
            )
        end
    end
  end

  # Add padding to the high and low ends of the last dimension of a tensor
  defp pad_last_dim(
         %T{shape: shape, type: {_, size} = type} = t,
         value,
         edge_low,
         edge_high,
         interior
       ) do
    view = aggregate_axes(to_binary(t), [tuple_size(shape) - 1], shape, size)
    new_shape = pad_in_dim(shape, tuple_size(shape) - 1, edge_low, edge_high, interior)

    edge_high_padding =
      if edge_high <= 0,
        do: <<>>,
        else: for(_ <- 1..edge_high, into: <<>>, do: value)

    edge_low_padding =
      if edge_low <= 0,
        do: <<>>,
        else: for(_ <- 1..edge_low, into: <<>>, do: value)

    interior_padding =
      if interior == 0,
        do: <<>>,
        else: for(_ <- 1..interior, into: <<>>, do: value)

    interior_padding_size = interior * size

    interior_padded =
      for bin <- view do
        padded =
          for <<dim::size(size)-bitstring <- bin>>, into: <<>> do
            <<dim::size(size)-bitstring, interior_padding::bitstring>>
          end

        new_bytes = byte_size(padded) * 8 - interior_padding_size
        <<new_bin::size(new_bytes)-bitstring, _::bitstring>> = padded
        new_bin
      end

    data =
      for bin <- interior_padded, into: <<>> do
        cond do
          edge_low < 0 and edge_high < 0 ->
            low_byte = abs(edge_low) * size
            high_byte = abs(edge_high) * size
            new_bytes = byte_size(bin) * 8 - high_byte - low_byte

            <<_::size(low_byte)-bitstring, new_bin::size(new_bytes)-bitstring, _::bitstring>> =
              bin

            new_bin

          edge_low < 0 and edge_high >= 0 ->
            low_byte = abs(edge_low) * size
            <<_::size(low_byte)-bitstring, new_bin::bitstring>> = bin
            <<new_bin::bitstring, edge_high_padding::bitstring>>

          edge_low >= 0 and edge_high < 0 ->
            high_byte = abs(edge_high) * size
            new_bytes = byte_size(bin) * 8 - high_byte
            <<new_bin::size(new_bytes)-bitstring, _::bitstring>> = bin
            <<edge_low_padding::bitstring, new_bin::bitstring>>

          true ->
            <<edge_low_padding::bitstring, bin::bitstring, edge_high_padding::bitstring>>
        end
      end

    from_binary(%{t | type: type, shape: new_shape}, data)
  end

  defp pad_in_dim(shape, dim, edge_low, edge_high, interior) do
    dim_size = elem(shape, dim)
    interior_padding_factor = (dim_size - 1) * interior
    new_dim = dim_size + interior_padding_factor + edge_high + edge_low
    put_elem(shape, dim, new_dim)
  end

  @impl true
  def reverse(out, %{type: {_, size}, shape: shape} = t, axes) do
    data = to_binary(t)
    weighted_shape = weighted_shape(shape, size)

    # Nx guaranteex axes is sorted and non-empty.
    min = List.first(axes)
    max = List.last(axes) + 1

    # The chunk size is computed based on all dimensions
    # before the minimum one being changed. For example,
    # for {0, 1, 2, 3} and the reverse is between 1 and 2,
    # the chunk_size will be d1 * d2 * d3 * size.
    chunk_size = weighted_chunk(weighted_shape, min, size)

    # All of the major dimensions not being reverse can be
    # read at once. For example, for {0, 1, 2, 3} and the reverse
    # is between 1 and 2, the read_size will be d3 * size.
    read_size = weighted_chunk(weighted_shape, max, size)

    # And now how we will traverse
    traverse =
      weighted_shape
      |> Enum.take(max)
      |> Enum.drop(min)
      |> reverse_traverse(min, axes)

    data =
      for <<chunk::size(chunk_size)-bitstring <- data>> do
        weighted_traverse(traverse, chunk, read_size)
      end

    from_binary(out, data)
  end

  defp reverse_traverse([head | tail], axis, axes) do
    if axis in axes do
      [&Enum.reverse/1, head | reverse_traverse(tail, axis + 1, axes)]
    else
      [head | reverse_traverse(tail, axis + 1, axes)]
    end
  end

  defp reverse_traverse([], _axis, _axes), do: []

  ## Two-element

  @impl true
  def dot(out, left, contract_axes1, [], right, contract_axes2, []) do
    # dot/4 is directed to this specific clause so we can keep a more efficient implementation
    # for non-batched dot products. See the clause below for batched dot products
    data = bin_dot(left, contract_axes1, right, contract_axes2, out.type)
    from_binary(out, data)
  end

  def dot(
        out,
        %{shape: left_shape, type: {_, left_size}, names: left_names} = left,
        left_contract_axes,
        left_batch_axes,
        %{shape: right_shape, type: {_, right_size}, names: right_names} = right,
        right_contract_axes,
        right_batch_axes
      ) do
    left_binary = to_binary(left)
    right_binary = to_binary(right)

    left_batch_contract_axes =
      Enum.map(left_contract_axes, fn axis -> axis - length(left_batch_axes) end)

    right_batch_contract_axes =
      Enum.map(right_contract_axes, fn axis -> axis - length(right_batch_axes) end)

    {left_batch_shape, _left_batch_names} =
      Nx.Shape.contract(left_shape, left_batch_axes, left_names, false)

    {right_batch_shape, _right_batch_names} =
      Nx.Shape.contract(right_shape, right_batch_axes, right_names, false)

    left_batch_item_length = Nx.size(left_batch_shape)
    right_batch_item_length = Nx.size(right_batch_shape)

    batch_count = Enum.reduce(left_batch_axes, 1, fn x, acc -> elem(left_shape, x) * acc end)

    range = if batch_count == 0, do: [], else: 0..(batch_count - 1)

    left_batch_item_template = %{left | shape: left_batch_shape}
    right_batch_item_template = %{right | shape: right_batch_shape}

    bin_result =
      for index <- range do
        left_offset = index * left_batch_item_length
        right_offset = index * right_batch_item_length

        left_offset_bits = left_offset * left_size
        right_offset_bits = right_offset * right_size

        left_batch_item_bits = left_batch_item_length * left_size
        right_batch_item_bits = right_batch_item_length * right_size

        <<_::bitstring-size(left_offset_bits),
          left_batch_item_binary::bitstring-size(left_batch_item_bits),
          _::bitstring>> = left_binary

        <<_::bitstring-size(right_offset_bits),
          right_batch_item_binary::bitstring-size(right_batch_item_bits),
          _::bitstring>> = right_binary

        bin_dot(
          from_binary(left_batch_item_template, left_batch_item_binary),
          left_batch_contract_axes,
          from_binary(right_batch_item_template, right_batch_item_binary),
          right_batch_contract_axes,
          out.type
        )
      end

    from_binary(out, bin_result)
  end

  defp bin_dot(%{type: t1} = left, contract_axes1, %{type: t2} = right, contract_axes2, type) do
    {left, left_contract_axes} = bin_dot_transpose_contract_axes(left, contract_axes1)

    {right, right_contract_axes} = bin_dot_transpose_contract_axes(right, contract_axes2)

    bin_zip_reduce(left, left_contract_axes, right, right_contract_axes, type, 0, fn
      lhs, rhs, acc ->
        res = binary_to_number(lhs, t1) * binary_to_number(rhs, t2) + acc
        {res, res}
    end)
  end

  defp bin_dot_transpose_contract_axes(tensor, contract_axes) do
    # The intution here is that we can pre-condense the contracting axes into a
    # single dimension, which will then be contracted through bin_zip_reduce below.
    # This takes a shape {a, m, n, b} which contracts on m, n and turns it into
    # {a, b, m * n}, contracting on the last dimension. This is necessary because
    # bin_zip_reduce and aggregate_axes are order independent but dot depends
    # on the axes order.

    axes = Nx.axes(tensor)

    remaining_axes =
      contract_axes
      |> Enum.sort(:desc)
      |> Enum.reduce(axes, &List.delete_at(&2, &1))

    transpose_axes = remaining_axes ++ contract_axes

    transposed =
      if transpose_axes == axes do
        tensor
      else
        {shape, names} = Nx.Shape.transpose(tensor.shape, transpose_axes, tensor.names)
        transpose(%{tensor | shape: shape, names: names}, tensor, transpose_axes)
      end

    {kept, contracted} =
      transposed.shape
      |> Tuple.to_list()
      |> Enum.split(length(remaining_axes))

    kept_shape = List.to_tuple(kept)

    kept_size = tuple_size(kept_shape)

    reduced_shape = Tuple.insert_at(kept_shape, kept_size, Enum.product(contracted))

    {%{transposed | shape: reduced_shape, names: List.duplicate(nil, tuple_size(reduced_shape))},
     [kept_size]}
  end

  ## Element wise ternary ops

  @impl true
  def select(out, %{shape: {}} = pred, on_true, on_false) do
    if scalar_to_number(pred) == 0,
      do: from_binary(out, broadcast_data(on_false, out.shape)),
      else: from_binary(out, broadcast_data(on_true, out.shape))
  end

  def select(%{shape: shape, type: type} = out, pred, on_true, on_false) do
    %T{type: {_, pred_size} = pred_type} = pred
    %T{type: {_, left_size} = left_type} = on_true
    %T{type: {_, right_size} = right_type} = on_false

    pred_data = to_binary(pred)
    on_true_data = broadcast_data(on_true, shape)
    on_false_data = broadcast_data(on_false, shape)

    data =
      for i <- 0..(Nx.size(shape) - 1), into: <<>> do
        pred =
          match_types [pred_type] do
            consumed = i * pred_size
            <<_::size(consumed)-bitstring, match!(pred, 0), _::bitstring>> = pred_data
            read!(pred, 0)
          end

        result =
          if pred == 0 do
            match_types [right_type] do
              consumed = i * right_size
              <<_::size(consumed)-bitstring, match!(x, 0), _::bitstring>> = on_false_data
              read!(x, 0)
            end
          else
            match_types [left_type] do
              consumed = i * left_size
              <<_::size(consumed)-bitstring, match!(x, 0), _::bitstring>> = on_true_data
              read!(x, 0)
            end
          end

        number_to_binary(result, type)
      end

    from_binary(out, data)
  end

  ## Element wise bin ops

  for fun <-
        [:add, :subtract, :multiply, :power, :remainder, :divide, :atan2, :min, :max, :quotient] ++
          [:bitwise_and, :bitwise_or, :bitwise_xor, :left_shift, :right_shift] ++
          [:equal, :not_equal, :greater, :less, :greater_equal, :less_equal] ++
          [:logical_and, :logical_or, :logical_xor] do
    capture = Macro.var(:"element_#{fun}", __MODULE__)

    @impl true
    def unquote(fun)(out, left, right) do
      element_wise_bin_op(out, left, right, &(unquote(capture) / 3))
    end
  end

  defp element_wise_bin_op(%{type: type} = out, %{shape: {}} = left, right, fun) do
    number = scalar_to_number(left)

    data =
      binary_to_binary(to_binary(right), right.type, type, fn x ->
        fun.(type, number, x)
      end)

    from_binary(out, data)
  end

  defp element_wise_bin_op(%{type: type} = out, left, %{shape: {}} = right, fun) do
    number = scalar_to_number(right)

    data =
      binary_to_binary(to_binary(left), left.type, type, fn x ->
        fun.(type, x, number)
      end)

    from_binary(out, data)
  end

  defp element_wise_bin_op(%{shape: shape, type: type} = out, left, right, fun) do
    %T{type: {_, left_size} = left_type} = left
    %T{type: {_, right_size} = right_type} = right

    count = Nx.size(shape)
    left_data = broadcast_data(left, shape)
    right_data = broadcast_data(right, shape)

    data =
      match_types [left_type, right_type, type] do
        for i <- 0..(count - 1), into: <<>> do
          left_consumed = i * left_size
          <<_::size(left_consumed)-bitstring, match!(x, 0), _::bitstring>> = left_data
          x = read!(x, 0)

          right_consumed = i * right_size
          <<_::size(right_consumed)-bitstring, match!(y, 1), _::bitstring>> = right_data
          y = read!(y, 1)

          <<write!(fun.(type, x, y), 2)>>
        end
      end

    from_binary(out, data)
  end

  defp element_add(_, a, b), do: Complex.add(a, b)
  defp element_subtract(_, a, b), do: Complex.subtract(a, b)
  defp element_multiply(_, a, b), do: Complex.multiply(a, b)
  defp element_divide(_, a, b), do: Complex.divide(a, b)
  defp element_quotient(_, a, b), do: div(a, b)

  defp element_remainder(_, a, b) when is_integer(a) and is_integer(b), do: rem(a, b)
  defp element_remainder(_, a, b), do: :math.fmod(a, b)

  defp element_atan2(_, a, b), do: Complex.atan2(a, b)

  defp element_max(_, :nan, _), do: :nan
  defp element_max(_, _, :nan), do: :nan
  defp element_max(_, :infinity, _), do: :infinity
  defp element_max(_, _, :infinity), do: :infinity
  defp element_max(_, :neg_infinity, x), do: x
  defp element_max(_, x, :neg_infinity), do: x
  defp element_max(_, a, b) when is_number(a) and is_number(b), do: max(a, b)

  defp element_min(_, :nan, _), do: :nan
  defp element_min(_, _, :nan), do: :nan
  defp element_min(_, :infinity, x), do: x
  defp element_min(_, x, :infinity), do: x
  defp element_min(_, :neg_infinity, _), do: :neg_infinity
  defp element_min(_, _, :neg_infinity), do: :neg_infinity
  defp element_min(_, a, b) when is_number(a) and is_number(b), do: min(a, b)

  defp element_power({type, _}, a, b) when type in [:s, :u], do: Integer.pow(a, b)
  defp element_power(_, a, b), do: Complex.power(a, b)

  defp element_bitwise_and(_, a, b), do: :erlang.band(a, b)
  defp element_bitwise_or(_, a, b), do: :erlang.bor(a, b)
  defp element_bitwise_xor(_, a, b), do: :erlang.bxor(a, b)

  defp element_left_shift(_, a, b) when is_number(b) and b >= 0,
    do: :erlang.bsl(a, b)

  defp element_left_shift(_, _, b), do: raise(ArgumentError, "cannot left shift by #{b}")

  defp element_right_shift(_, a, b) when is_number(b) and b >= 0,
    do: :erlang.bsr(a, b)

  defp element_right_shift(_, _, b), do: raise(ArgumentError, "cannot right shift by #{b}")

  defp element_equal(_, :nan, _), do: 0
  defp element_equal(_, _, :nan), do: 0
  defp element_equal(_, a, b), do: boolean_as_number(a == b)

  defp element_not_equal(_, :nan, _), do: 1
  defp element_not_equal(_, _, :nan), do: 1
  defp element_not_equal(_, a, b), do: boolean_as_number(a != b)

  defp element_logical_and(_, a, b), do: boolean_as_number(as_boolean(a) and as_boolean(b))
  defp element_logical_or(_, a, b), do: boolean_as_number(as_boolean(a) or as_boolean(b))
  defp element_logical_xor(_, a, b), do: boolean_as_number(as_boolean(a) != as_boolean(b))

  defp element_greater(_, :nan, _), do: 0
  defp element_greater(_, _, :nan), do: 0
  defp element_greater(_, x, x), do: 0
  defp element_greater(_, :infinity, _), do: 1
  defp element_greater(_, _, :neg_infinity), do: 1
  defp element_greater(_, :neg_infinity, _), do: 0
  defp element_greater(_, _, :infinity), do: 0
  defp element_greater(_, a, b), do: boolean_as_number(a > b)

  defp element_less(_, :nan, _), do: 0
  defp element_less(_, _, :nan), do: 0
  defp element_less(_, :infinity, _), do: 0
  defp element_less(_, _, :neg_infinity), do: 0
  defp element_less(_, x, x), do: 0
  defp element_less(_, _, :infinity), do: 1
  defp element_less(_, :neg_infinity, _), do: 1
  defp element_less(_, a, b), do: boolean_as_number(a < b)

  defp element_greater_equal(_, :nan, _), do: 0
  defp element_greater_equal(_, _, :nan), do: 0
  defp element_greater_equal(_, x, x), do: 1
  defp element_greater_equal(_, :neg_infinity, _), do: 0
  defp element_greater_equal(_, _, :infinity), do: 0
  defp element_greater_equal(_, :infinity, _), do: 1
  defp element_greater_equal(_, _, :neg_infinity), do: 1
  defp element_greater_equal(_, a, b), do: boolean_as_number(a >= b)

  defp element_less_equal(_, :nan, _), do: 0
  defp element_less_equal(_, _, :nan), do: 0
  defp element_less_equal(_, _, :infinity), do: 1
  defp element_less_equal(_, :neg_infinity, _), do: 1
  defp element_less_equal(_, x, x), do: 1
  defp element_less_equal(_, :infinity, _), do: 0
  defp element_less_equal(_, _, :neg_infinity), do: 0
  defp element_less_equal(_, a, b), do: boolean_as_number(a <= b)

  defp as_boolean(n) when n == 0, do: false
  defp as_boolean(%Complex{re: re, im: im}) when re == 0 and im == 0, do: false
  defp as_boolean(_), do: true

  defp boolean_as_number(true), do: 1
  defp boolean_as_number(false), do: 0

  ## Element wise unary ops

  for {name, {_desc, code, _formula}} <- Nx.Shared.unary_math_funs() do
    @impl true
    def unquote(name)(out, tensor) do
      element_wise_unary_op(out, tensor, fn x -> unquote(code) end)
    end
  end

  @impl true
  def count_leading_zeros(out, %{type: {_, size}} = tensor) do
    element_wise_bit_op(out, tensor, &element_clz(&1, size))
  end

  @impl true
  def population_count(out, tensor) do
    element_wise_bit_op(out, tensor, &element_popcount(&1, 0))
  end

  defp element_wise_bit_op(out, %{type: {_, size}} = tensor, fun) do
    data =
      match_types [out.type] do
        for <<seg::unsigned-size(size)-native <- to_binary(tensor)>>, into: <<>> do
          <<write!(fun.(seg), 0)>>
        end
      end

    from_binary(out, data)
  end

  @impl true
  def abs(out, tensor), do: element_wise_unary_op(out, tensor, &Complex.abs/1)

  @impl true
  def conjugate(out, tensor), do: element_wise_unary_op(out, tensor, &Complex.conjugate/1)

  @impl true
  def real(%{type: {_, component_size}} = out, %{type: {:c, _}} = tensor) do
    data = to_binary(tensor)

    result =
      for <<real::bitstring-size(component_size), _::bitstring-size(component_size) <- data>>,
        into: <<>>,
        do: real

    from_binary(out, result)
  end

  @impl true
  def imag(%{type: {_, component_size}} = out, %{type: {:c, _}} = tensor) do
    data = to_binary(tensor)

    result =
      for <<_::bitstring-size(component_size), imag::bitstring-size(component_size) <- data>>,
        into: <<>>,
        do: imag

    from_binary(out, result)
  end

  @impl true
  def bitwise_not(out, tensor), do: element_wise_unary_op(out, tensor, &:erlang.bnot/1)

  @impl true
  def is_nan(out, %{type: {t, _}}) when t in [:u, :s] do
    # integers cannot represent nans, so we can just create
    # a zero boolean tensor

    # 8 bits per entry because we return u8
    size = Nx.size(out.shape) * 8
    from_binary(out, <<0::size(size)>>)
  end

  def is_nan(out, tensor) do
    element_wise_unary_op(out, tensor, fn
      %Complex{re: :nan} -> 1
      %Complex{im: :nan} -> 1
      :nan -> 1
      _ -> 0
    end)
  end

  @impl true
  def is_infinity(out, %{type: {t, _}}) when t in [:u, :s] do
    # integers cannot represent nans, so we can just create
    # a zero boolean tensor

    # 8 bits per entry because we return u8
    size = Nx.size(out.shape) * 8
    from_binary(out, <<0::size(size)>>)
  end

  def is_infinity(out, tensor) do
    element_wise_unary_op(out, tensor, fn
      %Complex{re: re} when re in [:infinity, :neg_infinity] -> 1
      %Complex{im: im} when im in [:infinity, :neg_infinity] -> 1
      :infinity -> 1
      :neg_infinity -> 1
      _ -> 0
    end)
  end

  @impl true
  def ceil(out, tensor), do: element_wise_unary_op(out, tensor, &:erlang.ceil/1)

  @impl true
  def floor(out, tensor), do: element_wise_unary_op(out, tensor, &:erlang.floor/1)

  @impl true
  def negate(out, tensor), do: element_wise_unary_op(out, tensor, &Complex.negate/1)

  @impl true
  def round(out, tensor), do: element_wise_unary_op(out, tensor, &:erlang.round/1)

  @impl true
  def sign(out, tensor), do: element_wise_unary_op(out, tensor, &element_sign/1)

  defp element_sign(n) when n < 0, do: -1
  defp element_sign(n) when n > 0, do: 1
  defp element_sign(n), do: n

  # https://en.wikipedia.org/wiki/Hamming_weight
  # There are algorithms with faster worst case but they are size specific.
  # The implementation below is also the most efficient for low counts. Given
  # our integers are always 64 bits internally, we will have a lot of zeros
  # internally, so this should be the fastest.
  defp element_popcount(0, count), do: count
  defp element_popcount(n, count), do: element_popcount(n &&& n - 1, count + 1)

  defp element_wise_unary_op(out, tensor, fun) do
    data = binary_to_binary(to_binary(tensor), tensor.type, out.type, fun)
    from_binary(out, data)
  end

  defp element_clz(0, size), do: size
  defp element_clz(n, 64), do: element_clz64(n)
  defp element_clz(n, 32), do: element_clz32(n)
  defp element_clz(n, 16), do: element_clz16(n)
  defp element_clz(n, 8), do: element_clz8(n)

  defp element_clz64(num) do
    case num &&& 0xFFFFFFFF00000000 do
      0 -> 32 + element_clz32(num)
      _ -> element_clz32(num >>> 32)
    end
  end

  defp element_clz32(num) do
    case num &&& 0xFFFF0000 do
      0 -> 16 + element_clz16(num)
      _ -> element_clz16(num >>> 16)
    end
  end

  defp element_clz16(num) do
    case num &&& 0xFF00 do
      0 -> 8 + element_clz8(num)
      _ -> element_clz8(num >>> 8)
    end
  end

  defp element_clz8(num) do
    case num &&& 0xF0 do
      0 -> 4 + element_clz4(num)
      _ -> element_clz4(num >>> 4)
    end
  end

  defp element_clz4(num) do
    case num &&& 0xC do
      0 -> 2 + element_clz2(num)
      _ -> element_clz2(num >>> 2)
    end
  end

  defp element_clz2(0), do: 2
  defp element_clz2(1), do: 1
  defp element_clz2(_), do: 0

  ## Inspect

  @impl true
  def inspect(tensor, inspect_opts) do
    limit = inspect_opts.limit
    binary = Nx.to_binary(tensor, if(limit == :infinity, do: [], else: [limit: limit + 1]))
    Nx.Backend.inspect(tensor, binary, inspect_opts)
  end

  ## Conv

  @impl true
  def conv(out, t, k, opts) do
    padding = opts[:padding]
    strides = opts[:strides]
    input_dilation = opts[:input_dilation]
    kernel_dilation = opts[:kernel_dilation]
    feature_groups = opts[:feature_group_size]
    batch_groups = opts[:batch_group_size]
    input_permutation = opts[:input_permutation]
    kernel_permutation = opts[:kernel_permutation]
    output_permutation = opts[:output_permutation]

    output_permutation =
      output_permutation
      |> Enum.with_index()
      |> Enum.sort()
      |> Enum.map(&elem(&1, 1))

    # Consider an image representation, the input shape should be:
    # {batch, channels, height, width}
    #
    # The kernel then has the following shape:
    # {num_filters, channels, filter_height, filter_width}
    #
    # The output type is merged between the input tensor
    # and the input kernel, both inputs must be floating types
    #
    # The output shape is a product of the batch size,
    # number of filters in the input kernel, and the transformation
    # on the spatial dimensions.
    #
    # The shape of each spatial dimension is calculated as:
    #
    # (spatial_dim - filter_size + 2 * padding_size) / stride + 1
    #
    # where spatial dim is the current spatial dimension size, filter size is
    # the size of the corresponding spatial dimension in the kernel,
    # padding size is the amount of padding applied to either side of the
    # spatial dimension, and stride is the input stride
    #
    # The final shape is then given as:
    # {batch, num_filters, spatial_dim0, spatial_dim1, ...}
    %T{type: {_, input_size} = input_type} = t = Nx.transpose(t, axes: input_permutation)
    %T{type: {_, kernel_size} = kernel_type} = k = Nx.transpose(k, axes: kernel_permutation)

    %{type: output_type, shape: output_shape, names: output_names} = out

    inverse_output_permutation =
      output_permutation |> Enum.with_index() |> Enum.sort() |> Enum.map(&elem(&1, 1))

    {untransposed_shape, untransposed_names} =
      Nx.Shape.transpose(output_shape, inverse_output_permutation, output_names)

    untransposed_out = %{out | names: untransposed_names, shape: untransposed_shape}

    # We need to dilate the spatial dimensions of the kernel first...
    dilation_padding = [
      {0, 0, 0},
      {0, 0, 0} | Enum.map(kernel_dilation, &{0, 0, &1 - 1})
    ]

    %T{shape: kernel_shape} = k = Nx.pad(k, 0, dilation_padding)

    # The size of the "window" or the size of each filter
    # removes the first two dimensions of the kernel
    #
    # This "window" is applied `num_filters` times across
    # the input tensors spatial dimensions
    filter_shape =
      kernel_shape
      |> Tuple.delete_at(0)
      |> Tuple.delete_at(0)

    num_filters = elem(kernel_shape, 0)
    num_input_channels = elem(kernel_shape, 1)

    # We first pad the input tensor with the following semantics
    #   :valid - no padding
    #   :same - pad with 0 such that the input's spatial dims
    #           remain unchanged WITHOUT considering strides
    #   general - pad with the given padding configuration
    #
    # Padding configuration is guaranteed to always be provided
    # as a list because it is handled in the Nx module before
    # lowering to the implementation; however, the padding
    # configuration only accounts for spatial dims
    spatial_padding_config =
      Enum.zip_with(padding, input_dilation, fn {lo, hi}, dilation ->
        {lo, hi, dilation - 1}
      end)

    padding_config = [
      {0, 0, 0},
      {0, 0, 0} | spatial_padding_config
    ]

    %T{shape: padded_shape} = padded_t = Nx.pad(t, 0, padding_config)

    single_data_dims = Tuple.delete_at(padded_shape, 0)
    batch_size = Nx.size(single_data_dims) * input_size

    # We will traverse the input tensor exactly the same as we traversed
    # the binary in window_reduce, but the window is equal to the filter
    # size of the kernel plus the channel size of the input tensor
    window_shape = Tuple.insert_at(filter_shape, 0, num_input_channels)

    input_data = to_binary(padded_t)
    kernel_data = to_binary(k)

    filter_groups_with_index =
      split_into_feature_groups(kernel_data, kernel_shape, kernel_type, feature_groups)

    batch_groups_with_index =
      split_into_batch_groups(input_data, padded_shape, input_type, batch_groups)

    # If we're not splitting across input channels, we're splitting across input batches
    # So we split the filters into groups to apply to the corresponding batch
    filter_groups_with_index =
      if batch_groups > 1,
        do: Enum.chunk_every(filter_groups_with_index, div(num_filters, batch_groups)),
        else: filter_groups_with_index

    batch_weighted_shape =
      weighted_shape(Tuple.delete_at(padded_shape, 0), input_size, window_shape)

    # We calculate our "anchors" using just the spatial dimensions
    # but they also need to consider the depth or channels of the input
    # tensor, so we always anchor on the `0th` channel
    padded_spatial_dims =
      padded_shape
      |> Tuple.delete_at(0)
      |> Tuple.delete_at(0)

    anchors = Enum.sort(make_anchors(padded_spatial_dims, strides, filter_shape))

    output_data =
      for {batch_group, batch_group_index} <- batch_groups_with_index, into: <<>> do
        current_filter_group_with_index =
          if batch_groups > 1,
            do: Enum.at(filter_groups_with_index, batch_group_index),
            else: filter_groups_with_index

        # Traverse the batch dim first
        for <<batch::size(batch_size)-bitstring <- batch_group>>,
            # Traverse the filters next, this allows us to rebuild
            # the resulting binary correctly
            {filter, feature_group} <- current_filter_group_with_index,
            # Then we traverse the spatial dimension, applying
            # the filter at each step
            anchor <- anchors,
            into: <<>> do
          offset =
            weighted_offset(batch_weighted_shape, [feature_group * num_input_channels | anchor])

          # The shape of the window is {channels} + filter_shape
          # The shape of the kernel is {num_filters, channels} + filter_shape
          window =
            batch_weighted_shape
            |> weighted_traverse(batch, input_size, offset)
            |> IO.iodata_to_binary()

          # The receptive field size of each binary in bytes
          input_field_size = Nx.size(filter_shape) * input_size
          filter_field_size = Nx.size(filter_shape) * kernel_size

          # For each channel in both filter and input...
          # The output from a single filter being applied over a window
          # of the input tensor is the sum of the element-wise products
          values =
            for i <- 0..(num_input_channels - 1) do
              current_input_pos = i * input_field_size
              current_filter_pos = i * filter_field_size
              <<_::size(current_input_pos)-bitstring, input_receptive_field::bitstring>> = window

              <<_::size(current_filter_pos)-bitstring, filter_receptive_field::bitstring>> =
                filter

              for j <- 0..(Nx.size(filter_shape) - 1) do
                x =
                  match_types [input_type] do
                    left_consumed = j * input_size

                    <<_::size(left_consumed)-bitstring, match!(x, 0), _::bitstring>> =
                      input_receptive_field

                    read!(x, 0)
                  end

                y =
                  match_types [kernel_type] do
                    right_consumed = j * kernel_size

                    <<_::size(right_consumed)-bitstring, match!(y, 0), _::bitstring>> =
                      filter_receptive_field

                    read!(y, 0)
                  end

                x * y
              end
            end

          values
          |> List.flatten()
          |> Enum.reduce(&+/2)
          |> number_to_binary(output_type)
        end
      end

    Nx.transpose(from_binary(untransposed_out, output_data), axes: output_permutation)
  end

  defp split_into_feature_groups(data, shape, {_, size}, groups) do
    group_size = div(Nx.size(shape) * size, groups)
    data_size = div(Nx.size(shape) * size, elem(shape, 0))

    for i <- 0..(groups - 1),
        offset = group_size * i,
        <<_::size(offset)-bitstring, data_group::size(group_size)-bitstring, _::bitstring>> =
          data,
        <<out_data::size(data_size)-bitstring <- data_group>>,
        do: {out_data, i}
  end

  defp split_into_batch_groups(data, shape, {_, size}, groups) do
    item_size = div(Nx.size(shape) * size, elem(shape, 0))
    num_groups = div(elem(shape, 0), groups)

    output_batch_groups =
      for i <- 0..(num_groups - 1),
          j <- 0..(groups - 1) do
        offset = (i + j * num_groups) * item_size
        <<_::size(offset)-bitstring, item::size(item_size)-bitstring, _::bitstring>> = data
        item
      end

    output_batch_groups |> Enum.with_index() |> Enum.map(fn {x, i} -> {x, rem(i, groups)} end)
  end

  @impl true
  def cholesky(
        %T{type: output_type, shape: output_shape} = out,
        %T{type: input_type} = tensor
      ) do
    data = to_binary(tensor)
    rank = tuple_size(output_shape)
    n = elem(output_shape, rank - 1)

    l =
      bin_batch_reduce(data, n * n, input_type, <<>>, fn matrix, acc ->
        l = B.Matrix.cholesky(matrix, input_type, {n, n}, output_type)
        acc <> l
      end)

    from_binary(out, l)
  end

  @impl true
  def qr(
        {%{shape: q_holder_shape, type: output_type} = q_holder,
         %{shape: r_holder_shape, type: output_type} = r_holder},
        %{type: input_type, shape: input_shape} = tensor,
        opts
      ) do
    bin = to_binary(tensor)
    rank = tuple_size(input_shape)
    m = elem(q_holder_shape, rank - 2)
    k = elem(q_holder_shape, rank - 1)
    n = elem(r_holder_shape, rank - 1)

    {q, r} =
      bin_batch_reduce(bin, m * n, input_type, {<<>>, <<>>}, fn matrix, {q_acc, r_acc} ->
        {q, r} = B.Matrix.qr(matrix, input_type, {m, n}, output_type, m, k, n, opts)
        {q_acc <> q, r_acc <> r}
      end)

    {from_binary(q_holder, q), from_binary(r_holder, r)}
  end

  @impl true
  def eigh(
        {%{type: output_type} = eigenvals_holder, eigenvecs_holder},
        %{type: input_type, shape: input_shape} = tensor,
        opts
      ) do
    bin = to_binary(tensor)
    rank = tuple_size(input_shape)
    n = elem(input_shape, rank - 1)

    {eigenvals, eigenvecs} =
      bin_batch_reduce(bin, n * n, input_type, {<<>>, <<>>}, fn matrix, {vals_acc, vecs_acc} ->
        {vals, vecs} = B.Matrix.eigh(matrix, input_type, {n, n}, output_type, opts)
        {vals_acc <> vals, vecs_acc <> vecs}
      end)

    {from_binary(eigenvals_holder, eigenvals), from_binary(eigenvecs_holder, eigenvecs)}
  end

  @impl true
  def svd(
        {u_holder, %{type: output_type} = s_holder, v_holder},
        %{type: input_type, shape: input_shape} = tensor,
        opts
      ) do
    bin = to_binary(tensor)
    rank = tuple_size(input_shape)
    m = elem(input_shape, rank - 2)
    n = elem(input_shape, rank - 1)

    vt_rows = elem(v_holder.shape, rank - 2)
    vt_cols = elem(v_holder.shape, rank - 1)

    if m < n do
      error_msg =
        if rank == 2 do
          "SVD not implemented for wide matrices (tensors with shape {m, n} where m < n)"
        else
          "SVD not implemented for batches of wide matrices (tensors with shape {..., m, n} where m < n)"
        end

      raise ArgumentError, error_msg
    end

    {u, s, v} =
      bin_batch_reduce(bin, m * n, input_type, {<<>>, <<>>, <<>>}, fn matrix,
                                                                      {u_acc, s_acc, v_acc} ->
        {u, s, v} =
          B.Matrix.svd(matrix, input_type, {m, n}, output_type, {vt_rows, vt_cols}, opts)

        {u_acc <> u, s_acc <> s, v_acc <> v}
      end)

    {from_binary(u_holder, u), from_binary(s_holder, s), from_binary(v_holder, v)}
  end

  @impl true
  def lu(
        {%{type: p_type} = p_holder, %{type: l_type} = l_holder, %{type: u_type} = u_holder},
        %{type: input_type, shape: input_shape} = tensor,
        opts
      ) do
    bin = to_binary(tensor)
    rank = tuple_size(input_shape)
    n = elem(input_shape, rank - 1)

    {p, l, u} =
      bin_batch_reduce(bin, n * n, input_type, {<<>>, <<>>, <<>>}, fn matrix,
                                                                      {p_acc, l_acc, u_acc} ->
        {p, l, u} = B.Matrix.lu(matrix, input_type, {n, n}, p_type, l_type, u_type, opts)
        {p_acc <> p, l_acc <> l, u_acc <> u}
      end)

    {from_binary(p_holder, p), from_binary(l_holder, l), from_binary(u_holder, u)}
  end

  @impl true
  def triangular_solve(
        %{type: output_type} = out,
        %{type: {_, a_size} = a_type, shape: a_shape} = a,
        %{type: b_type, shape: b_shape} = b,
        opts
      ) do
    a_data = to_binary(a)
    b_data = to_binary(b)

    m = elem(a_shape, tuple_size(a_shape) - 1)

    a_batch_byte_size = (m * m * a_size) |> div(8)
    batches_num = byte_size(a_data) |> div(a_batch_byte_size)

    a_batches =
      Enum.map(
        0..(batches_num - 1),
        &binary_part(a_data, &1 * a_batch_byte_size, a_batch_byte_size)
      )

    b_batch_byte_size = byte_size(b_data) |> div(batches_num)

    b_batches =
      Enum.map(
        0..(batches_num - 1),
        &binary_part(b_data, &1 * b_batch_byte_size, b_batch_byte_size)
      )

    b_batch_shape =
      if tuple_size(b_shape) == 1 do
        b_shape
      else
        {a_batch_shape, _} = a_shape |> Tuple.to_list() |> Enum.split(-2)
        {b_1d_batch_shape, [b_n]} = b_shape |> Tuple.to_list() |> Enum.split(-1)
        {b_2d_batch_shape, [b_m, ^b_n]} = b_shape |> Tuple.to_list() |> Enum.split(-2)

        case a_batch_shape do
          ^b_1d_batch_shape -> {b_n}
          ^b_2d_batch_shape -> {b_m, b_n}
        end
      end

    [a_batches, b_batches]
    |> Enum.zip_reduce(<<>>, fn [a, b], acc ->
      acc <> B.Matrix.ts(a, a_type, {m, m}, b, b_type, b_batch_shape, output_type, opts)
    end)
    |> then(&from_binary(out, &1))
  end

  ## Aggregation

  @impl true
  def all(out, %{type: type} = tensor, opts) do
    data =
      bin_reduce(tensor, out.type, 1, opts, fn bin, acc ->
        res = if binary_to_number(bin, type) != 0, do: acc, else: 0

        {res, res}
      end)

    from_binary(out, data)
  end

  @impl true
  def any(out, %{type: type} = tensor, opts) do
    data =
      bin_reduce(tensor, out.type, 0, opts, fn bin, acc ->
        res = if binary_to_number(bin, type) != 0, do: 1, else: acc
        {res, res}
      end)

    from_binary(out, data)
  end

  @impl true
  def sum(out, %{type: type} = tensor, opts) do
    data =
      bin_reduce(tensor, out.type, 0, opts, fn bin, acc ->
        res = binary_to_number(bin, type) + acc
        {res, res}
      end)

    from_binary(out, data)
  end

  @impl true
  def product(out, %{type: type} = tensor, opts) do
    data =
      bin_reduce(tensor, out.type, 1, opts, fn bin, acc ->
        res = binary_to_number(bin, type) * acc
        {res, res}
      end)

    from_binary(out, data)
  end

  @impl true
  def reduce_max(out, %{type: type} = tensor, opts) do
    data =
      bin_reduce(tensor, out.type, :first, opts, fn bin, acc ->
        val = binary_to_number(bin, type)
        res = if acc == :first, do: val, else: Kernel.max(acc, val)
        {res, res}
      end)

    from_binary(out, data)
  end

  @impl true
  def reduce_min(out, %{type: type} = tensor, opts) do
    data =
      bin_reduce(tensor, out.type, :first, opts, fn bin, acc ->
        val = binary_to_number(bin, type)
        res = if acc == :first, do: val, else: Kernel.min(acc, val)
        {res, res}
      end)

    from_binary(out, data)
  end

  @impl true
  def argmin(out, tensor, opts) do
    comparator =
      case opts[:tie_break] do
        :high -> &<=/2
        :low -> &</2
      end

    argmin_or_max(out, tensor, comparator, opts[:axis])
  end

  @impl true
  def argmax(out, tensor, opts) do
    comparator =
      case opts[:tie_break] do
        :high -> &>=/2
        :low -> &>/2
      end

    argmin_or_max(out, tensor, comparator, opts[:axis])
  end

  defp argmin_or_max(out, %{type: type} = tensor, comparator, axis) do
    opts = if axis, do: [axes: [axis]], else: []

    data =
      bin_reduce(tensor, out.type, {0, :first, -1}, opts, fn
        bin, {i, cur_extreme_x, cur_extreme_i} ->
          x = binary_to_number(bin, type)

          if comparator.(x, cur_extreme_x) or cur_extreme_x == :first do
            {i, {i + 1, x, i}}
          else
            {cur_extreme_i, {i + 1, cur_extreme_x, cur_extreme_i}}
          end
      end)

    from_binary(out, data)
  end

  @impl true
  def reduce(out, tensor, acc, opts, fun) do
    each = %{tensor | shape: {}}

    data =
      bin_reduce(tensor, out.type, acc, opts, fn bin, acc ->
        res = fun.(from_binary(each, bin), acc)
        {res, res}
      end)

    from_binary(out, data)
  end

  @impl true
  def window_reduce(out, tensor, acc, window_dimensions, opts, fun) do
    padding_config = opts[:padding]
    strides = opts[:strides]
    dilations = opts[:window_dilations]

    %T{shape: padded_shape, type: {_, size} = type} =
      tensor = Nx.pad(tensor, acc, Enum.map(padding_config, &Tuple.append(&1, 0)))

    acc = scalar_to_number(acc)

    data = to_binary(tensor)
    weighted_shape = weighted_shape(padded_shape, size, window_dimensions, dilations)
    anchors = Enum.sort(make_anchors(padded_shape, strides, window_dimensions, dilations))

    data =
      match_types [type] do
        for anchor <- anchors, into: <<>> do
          offset = weighted_offset(weighted_shape, anchor, dilations)
          window = IO.iodata_to_binary(weighted_traverse(weighted_shape, data, size, offset))

          window_val =
            for <<match!(x, 0) <- window>>,
              reduce: acc,
              do: (acc -> fun.(read!(x, 0), acc))

          scalar_to_binary!(window_val, type)
        end
      end

    from_binary(out, data)
  end

  @impl true
  def window_sum(out, tensor, window_dimensions, opts) do
    %{type: type} = out

    init_value = number_to_binary(0, type)
    init_value = from_binary(%{out | shape: {}, names: []}, init_value)

    fun = fn a, b -> element_add(type, a, b) end
    window_reduce(out, tensor, init_value, window_dimensions, opts, fun)
  end

  @impl true
  def window_max(out, tensor, window_dimensions, opts) do
    %{type: type} = out

    init_value = Nx.Type.min_binary(type)
    init_value = from_binary(%{out | shape: {}, names: []}, init_value)

    fun = fn a, b -> element_max(type, a, b) end
    window_reduce(out, tensor, init_value, window_dimensions, opts, fun)
  end

  @impl true
  def window_min(out, tensor, window_dimensions, opts) do
    %{type: type} = out

    init_value = Nx.Type.max_binary(type)
    init_value = from_binary(%{out | shape: {}, names: []}, init_value)

    fun = fn a, b -> element_min(type, a, b) end
    window_reduce(out, tensor, init_value, window_dimensions, opts, fun)
  end

  @impl true
  def window_product(out, tensor, window_dimensions, opts) do
    %{type: type} = out

    init_value = number_to_binary(1, type)
    init_value = from_binary(%{out | shape: {}, names: []}, init_value)

    fun = fn a, b -> element_multiply(type, a, b) end
    window_reduce(out, tensor, init_value, window_dimensions, opts, fun)
  end

  @impl true
  def map(%{type: output_type} = out, %{type: {_, size}} = tensor, _opts, fun) do
    data = to_binary(tensor)
    template = %{tensor | shape: {}}

    output_data =
      for <<bin::size(size)-bitstring <- data>>, into: <<>> do
        tensor = put_in(template.data.state, bin)
        number_to_binary(scalar_to_number(fun.(tensor)), output_type)
      end

    from_binary(out, output_data)
  end

  @impl true
  def window_scatter_max(out, tensor, source, init_value, window_dimensions, opts) do
    select_and_scatter(
      out,
      tensor,
      source,
      &Nx.greater_equal/2,
      window_dimensions,
      opts,
      init_value,
      &Nx.add/2
    )
  end

  @impl true
  def window_scatter_min(out, tensor, source, init_value, window_dimensions, opts) do
    select_and_scatter(
      out,
      tensor,
      source,
      &Nx.less_equal/2,
      window_dimensions,
      opts,
      init_value,
      &Nx.add/2
    )
  end

  defp select_and_scatter(
         %{type: {_, output_size} = output_type, shape: output_shape} = out,
         t,
         source,
         select_fn,
         window_dimensions,
         opts,
         init_value,
         scatter_fn
       ) do
    padding = opts[:padding]
    strides = opts[:strides]

    init_value = scalar_to_number(init_value)

    %T{shape: padded_shape, type: {_, size} = type} =
      tensor = Nx.pad(t, init_value, Enum.map(padding, &Tuple.append(&1, 0)))

    input_data = to_binary(tensor)
    input_weighted_shape = weighted_shape(padded_shape, size, window_dimensions)
    input_anchors = Enum.sort(make_anchors(padded_shape, strides, window_dimensions))

    %T{type: {_, source_size} = source_type} = source
    source_data = to_binary(source)

    output_windows =
      for {anchor, i} <- Enum.with_index(input_anchors) do
        offset = weighted_offset(input_weighted_shape, anchor)

        window =
          IO.iodata_to_binary(weighted_traverse(input_weighted_shape, input_data, size, offset))

        # Get the index where `select_fn` is true
        {_, index, _} =
          match_types [type] do
            <<match!(first_elem, 0), _::bitstring>> = window

            for <<match!(x, 0) <- window>>, reduce: {0, 0, read!(first_elem, 0)} do
              {cur_index, selected_index, acc} ->
                if select_fn.(read!(x, 0), acc) == Nx.tensor(1, type: {:u, 8}) do
                  {cur_index + 1, cur_index, read!(x, 0)}
                else
                  {cur_index + 1, selected_index, acc}
                end
            end
          end

        offset_from_anchor =
          flattened_index_to_offset(index, Tuple.to_list(window_dimensions), 0, [])

        absolute_index =
          anchor
          |> Enum.zip(offset_from_anchor)
          |> Enum.map(fn {x, y} -> x + y end)

        source_consumed = i * source_size

        <<_::size(source_consumed)-bitstring, from_source::size(source_size)-bitstring,
          _::bitstring>> = source_data

        source_value = binary_to_number(from_source, source_type)
        {source_value, absolute_index}
      end

    output_weighted_shape = weighted_shape(output_shape, output_size)

    # Fold duplicate indices into one another using `scatter_fn` and sort
    # by absolute offset
    values_with_indices =
      output_windows
      |> Enum.group_by(&elem(&1, 1), &elem(&1, 0))
      |> Enum.map(fn {index, value} ->
        offset = weighted_offset(output_weighted_shape, index)
        {offset, Enum.reduce(value, init_value, scatter_fn)}
      end)
      |> Enum.sort_by(&elem(&1, 0))

    init_binary = number_to_binary(init_value, output_type)

    {final_offset, unpadded_output_data} =
      for {offset, value} <- values_with_indices, reduce: {0, <<>>} do
        {acc_offset, acc_binary} ->
          num_vals_before = div(offset - acc_offset, output_size)
          vals_before = List.duplicate(init_binary, num_vals_before)
          source_val = to_binary(value)
          new_binary = IO.iodata_to_binary([vals_before, source_val])

          {offset + output_size,
           <<acc_binary::size(acc_offset)-bitstring, new_binary::bitstring>>}
      end

    num_vals_left = div(output_size * Nx.size(output_shape) - final_offset, output_size)
    vals_left = IO.iodata_to_binary(List.duplicate(init_binary, num_vals_left))
    output_data = <<unpadded_output_data::size(final_offset)-bitstring, vals_left::bitstring>>

    from_binary(out, output_data)
  end

  defp flattened_index_to_offset(leftover, [dim | []], _, acc),
    do: Enum.reverse([rem(leftover, dim) | acc])

  defp flattened_index_to_offset(leftover, [dim | rest], n, acc) do
    if leftover < Nx.size(List.to_tuple(rest)) do
      flattened_index_to_offset(leftover, rest, 0, [n | acc])
    else
      flattened_index_to_offset(
        leftover - Nx.size(List.to_tuple(rest)),
        [dim | rest],
        n + 1,
        acc
      )
    end
  end

  @impl true
  def indexed_add(out, target, indices, updates) do
    resolve_updates = fn upds -> Enum.reduce(upds, 0, &+/2) end
    update_element = &+/2

    indexed_op(out, target, indices, updates, resolve_updates, update_element)
  end

  @impl true
  def indexed_put(out, target, indices, updates) do
    resolve_updates = fn upds -> Enum.at(upds, -1) end
    update_element = fn _cur, upd -> upd end

    indexed_op(out, target, indices, updates, resolve_updates, update_element)
  end

  defp indexed_op(
         %T{} = out,
         %T{shape: shape, type: {_, target_size}} = target,
         %T{shape: {indices_rows, _indices_cols} = indices_shape} = indices,
         %T{shape: {indices_rows}} = updates,
         resolve_updates,
         update_element
       ) do
    indices_bin_list =
      indices |> to_binary() |> aggregate_axes([1], indices_shape, elem(indices.type, 1))

    offsets_list =
      for idx_bin <- indices_bin_list do
        idx = binary_to_list(idx_bin, indices.type)
        offset = index_to_binary_offset(idx, shape)
        offset * target_size
      end

    updates_list = binary_to_list(to_binary(updates), updates.type)

    {offsets_with_updates, _last_offset} =
      offsets_list
      |> Enum.zip(updates_list)
      |> Enum.group_by(fn {off, _} -> off end, fn {_, upd} -> upd end)
      |> Enum.sort_by(fn {off, _} -> off end)
      |> Enum.map_reduce(0, fn {next_offset, upds}, previous_offset ->
        {{
           previous_offset + target_size,
           next_offset,
           resolve_updates.(upds)
         }, next_offset}
      end)

    target_binary = to_binary(target)

    offsets_with_updates =
      List.update_at(offsets_with_updates, 0, fn {_prev, current, update} ->
        {0, current, update}
      end)

    {result, tail} =
      for {previous, current, update} <- offsets_with_updates, reduce: {<<>>, target_binary} do
        {traversed, to_traverse} ->
          before_slice_size = current - previous

          {before_offset, updated_element, to_traverse} =
            match_types [target.type, out.type] do
              <<before_offset::bitstring-size(before_slice_size), match!(element, 0),
                to_traverse::bitstring>> = to_traverse

              updated_element = <<write!(update_element.(read!(element, 0), update), 1)>>
              {before_offset, updated_element, to_traverse}
            end

          # this can be a list of binaries because we are accumulation an iodata list
          before_offset =
            if target.type == out.type do
              before_offset
            else
              binary_to_binary(before_offset, target.type, out.type, & &1)
            end

          {[traversed, before_offset | updated_element], to_traverse}
      end

    # this can be a list of binaries because we are accumulation an iodata list
    tail =
      if target.type == out.type do
        tail
      else
        binary_to_binary(tail, target.type, out.type, & &1)
      end

    from_binary(out, IO.iodata_to_binary([result, tail]))
  end

  @impl true
  def clip(out, tensor, min, max) do
    in_data = to_binary(tensor)
    min = binary_to_number(to_binary(min), min.type)
    max = binary_to_number(to_binary(max), max.type)
    out_data = binary_to_binary(in_data, tensor.type, out.type, &min(max(&1, min), max))
    from_binary(out, out_data)
  end

  @impl true
  def slice(out, tensor, start_indices, lengths, strides) do
    # If you think of a slice as drawing a bounding box in the dimensions
    # of the tensor, then it's clear we can simply use a weighted
    # traversal to construct the new tensor
    %T{type: {_, size}, shape: shape} = tensor
    %{shape: output_shape} = out

    tensor
    |> to_binary()
    |> bin_slice(shape, size, start_indices, lengths, strides, output_shape)
    |> then(&from_binary(out, &1))
  end

  defp bin_slice(data, shape, size, start_indices, lengths, strides, output_shape) do
    start_indices = clamp_indices(start_indices, shape, lengths)

    if hd(strides) == 1 and top_dimension_slice?(tuple_size(shape), shape, output_shape) do
      length = Nx.size(output_shape) * div(size, 8)
      offset = div(length, elem(output_shape, 0)) * hd(start_indices)
      binary_part(data, offset, length)
    else
      # Anchored around the start indices
      weighted_shape = weighted_shape(shape, size, output_shape)
      offset = weighted_offset(weighted_shape, start_indices)

      # The weighted shape is altered such that we traverse
      # with respect to the stride for each dimension
      weighted_shape =
        Enum.zip_with(weighted_shape, strides, fn {d, dim_size}, s ->
          {d, dim_size + (s - 1) * dim_size}
        end)

      IO.iodata_to_binary(weighted_traverse(weighted_shape, data, size, offset))
    end
  end

  defp clamp_indices(start_indices, shape, lengths) do
    Enum.zip_with([Tuple.to_list(shape), start_indices, lengths], fn [dim_size, idx, len] ->
      idx = scalar_to_number(idx)
      min(max(idx, 0), dim_size - len)
    end)
  end

  defp top_dimension_slice?(1, _, _), do: true

  defp top_dimension_slice?(i, is, os)
       when :erlang.element(i, is) == :erlang.element(i, os),
       do: top_dimension_slice?(i - 1, is, os)

  defp top_dimension_slice?(_, _, _), do: false

  @impl true
  def put_slice(out, tensor, start_indices, slice, combine_fn \\ fn _prev, new -> new end) do
    %T{type: {_, size} = type, shape: shape} = tensor = as_type(out, tensor)
    %T{shape: slice_shape} = slice = as_type(%{slice | type: type}, slice)

    start_indices = clamp_indices(start_indices, shape, Tuple.to_list(slice_shape))
    weighted_shape = weighted_shape(shape, size)

    rank = Nx.rank(shape)
    ones = List.duplicate(1, rank)

    offsets =
      slice_shape
      |> make_anchors(ones, List.to_tuple(ones))
      |> Enum.map(fn index ->
        pos =
          index
          |> Enum.zip(start_indices)
          |> Enum.map(fn {x, s} -> x + s end)

        weighted_offset(weighted_shape, pos)
      end)
      |> Enum.sort()

    {_, data} =
      for offset <- offsets, reduce: {to_binary(slice), to_binary(tensor)} do
        {<<cur_elem::size(size)-bitstring, rest_of_slice::bitstring>>, binary} ->
          <<before::size(offset)-bitstring, prev_elem::size(size)-bitstring,
            rest_of_tensor::bitstring>> = binary

          new_elem = combine_fn.(prev_elem, cur_elem)

          {rest_of_slice,
           <<before::size(offset)-bitstring, new_elem::size(size)-bitstring,
             rest_of_tensor::bitstring>>}
      end

    from_binary(out, data)
  end

  @impl true
  def take(out, tensor, indices, axis) do
    # We iterate over the indices in a flat manner,
    # and take a unit tensor slice along axis given
    # by each index. Then we concatenate the tensors
    # along the axis, which gives us the result with
    # index dimensions flattened and we just reshape.

    %T{type: {_, size}, shape: shape} = tensor
    %T{type: {_, idx_size}} = indices

    data = to_binary(tensor)
    tensor_rank = tuple_size(shape)
    slice_start = List.duplicate(0, tensor_rank)
    slice_lengths = shape |> Tuple.to_list() |> List.replace_at(axis, 1)
    slice_shape = List.to_tuple(slice_lengths)
    strides = List.duplicate(1, tensor_rank)

    slices =
      for <<bin::size(idx_size)-bitstring <- to_binary(indices)>> do
        idx = binary_to_number(bin, indices.type)

        if idx < 0 or idx >= elem(shape, axis) do
          raise ArgumentError,
                "index #{idx} is out of bounds for axis #{axis} in shape #{inspect(shape)}"
        end

        slice_start = List.replace_at(slice_start, axis, idx)

        slice_data =
          bin_slice(data, shape, size, slice_start, slice_lengths, strides, slice_shape)

        {slice_data, slice_shape}
      end

    concat_shape = put_elem(tensor.shape, axis, length(slices))
    result_data = bin_concatenate(slices, size, axis, concat_shape)

    from_binary(out, result_data)
  end

  @impl true
  def take_along_axis(
        %T{type: output_type} = output,
        %T{shape: t_shape, type: {_, t_size} = t_type} = tensor,
        %T{shape: idx_shape, type: {_, idx_size} = idx_type} = indices,
        axis
      ) do
    permutation =
      tensor
      |> Nx.axes()
      |> List.delete(axis)
      |> List.insert_at(Nx.rank(tensor) - 1, axis)

    inverse_permutation =
      permutation
      |> Enum.with_index()
      |> Enum.sort_by(fn {x, _} -> x end)
      |> Enum.map(fn {_, i} -> i end)

    shape_list = Tuple.to_list(output.shape)
    permuted_shape = permutation |> Enum.map(&Enum.at(shape_list, &1)) |> List.to_tuple()

    t_view = tensor |> to_binary() |> aggregate_axes([axis], t_shape, t_size)
    idx_view = indices |> to_binary() |> aggregate_axes([axis], idx_shape, idx_size)

    [t_view, idx_view]
    |> Enum.zip_with(fn [data_bin, idx_bin] ->
      data = binary_to_list(data_bin, t_type)

      binary_to_binary(idx_bin, idx_type, output_type, fn idx ->
        if idx < 0 or idx >= elem(tensor.shape, axis) do
          raise ArgumentError,
                "index #{idx} is out of bounds for axis #{axis} in shape #{inspect(tensor.shape)}"
        end

        Enum.at(data, idx)
      end)
    end)
    |> then(&from_binary(%{output | shape: permuted_shape}, &1))
    |> then(&transpose(output, &1, inverse_permutation))
  end

  @impl true
  def gather(out, tensor, indices) do
    %T{type: {_, size}, shape: shape} = tensor
    %T{type: {_, idx_size}} = indices

    data = to_binary(tensor)
    rank = tuple_size(shape)
    byte_size = div(size, 8)

    idx_last_dim_bin_size = rank * idx_size

    new_data =
      for <<bin::size(idx_last_dim_bin_size)-bitstring <- to_binary(indices)>>, into: <<>> do
        slice_start =
          for <<bin::size(idx_size)-bitstring <- bin>>, do: binary_to_number(bin, indices.type)

        offset = index_to_binary_offset(slice_start, shape)
        binary_part(data, offset * byte_size, byte_size)
      end

    from_binary(out, new_data)
  end

  defp index_to_binary_offset(index, input_shape) when is_list(index) and is_tuple(input_shape) do
    {offset, []} =
      index
      |> Enum.with_index()
      |> Enum.reduce(
        {0, Tuple.to_list(input_shape)},
        fn {idx, axis}, {offset, [dim_size | shape]} ->
          if idx < 0 or idx >= dim_size do
            raise ArgumentError,
                  "index #{idx} is out of bounds for axis #{axis} in shape #{inspect(input_shape)}"
          end

          {offset + idx * Enum.product(shape), shape}
        end
      )

    offset
  end

  @impl true
  def concatenate(out, tensors, axis) do
    %{shape: output_shape, type: {_, size} = output_type} = out

    tensors
    |> Enum.map(fn %{shape: shape} = t ->
      t = as_type(%{t | type: output_type}, t)
      {to_binary(t), shape}
    end)
    |> bin_concatenate(size, axis, output_shape)
    |> then(&from_binary(out, &1))
  end

  defp bin_concatenate(binaries_with_shape, _size, 0, _output_shape) do
    binaries_with_shape |> Enum.map(&elem(&1, 0)) |> IO.iodata_to_binary()
  end

  defp bin_concatenate(binaries_with_shape, size, axis, output_shape) do
    rank = tuple_size(output_shape)
    steps = product_part(output_shape, 0, axis)

    data =
      for step <- 1..steps,
          {binary, shape} <- binaries_with_shape do
        product = product_part(shape, axis, rank) * size
        before = (step - 1) * product
        <<_::bitstring-size(before), part::bitstring-size(product), _::bitstring>> = binary
        part
      end

    IO.iodata_to_binary(data)
  end

  defp product_part(_tuple, n, n), do: 1
  defp product_part(tuple, n, limit), do: elem(tuple, n) * product_part(tuple, n + 1, limit)

  @impl true
  def as_type(out, tensor) do
    %{type: output_type} = out

    case tensor do
      %T{type: ^output_type} ->
        tensor

      %T{type: input_type} ->
        float_output? = Nx.Type.float?(output_type)
        data = to_binary(tensor)

        output_data =
          match_types [input_type] do
            for <<match!(x, 0) <- data>>, into: <<>> do
              x = read!(x, 0)

              case x do
                %Complex{re: re} when float_output? ->
                  number_to_binary(re, output_type)

                _ when float_output? ->
                  number_to_binary(x, output_type)

                %Complex{re: re} ->
                  number_to_binary(trunc(re), output_type)

                _ when is_number(x) ->
                  number_to_binary(trunc(x), output_type)

                :nan ->
                  number_to_binary(0, output_type)

                :infinity ->
                  Nx.Type.max_finite_binary(output_type)

                :neg_infinity ->
                  Nx.Type.min_finite_binary(output_type)
              end
            end
          end

        from_binary(out, output_data)
    end
  end

  @impl true
  def bitcast(out, tensor), do: from_binary(out, to_binary(tensor))

  @impl true
  def sort(output, t, opts), do: do_sort(output, t, opts, false)

  @impl true
  def argsort(output, t, opts), do: do_sort(output, t, opts, true)

  defp do_sort(output, t, opts, return_indices) do
    %T{shape: shape, type: type} = t
    last_axis = Nx.rank(t) - 1

    axis = opts[:axis]

    comparator =
      case opts[:direction] do
        :desc ->
          fn a, b ->
            a = binary_to_number(a, type)
            b = binary_to_number(b, type)
            a >= b
          end

        :asc ->
          fn a, b ->
            a = binary_to_number(a, type)
            b = binary_to_number(b, type)
            a <= b
          end
      end

    case shape do
      {} when return_indices ->
        sort_last_dim(t, comparator, output, return_indices)

      {} ->
        t

      _ when axis == last_axis ->
        sort_last_dim(t, comparator, output, return_indices)

      _ ->
        permutation = Nx.axes(t)

        permutation =
          permutation
          |> List.delete(axis)
          |> List.insert_at(Nx.rank(t) - 1, axis)

        inverse_permutation =
          permutation
          |> Enum.with_index()
          |> Enum.sort_by(fn {x, _} -> x end)
          |> Enum.map(fn {_, i} -> i end)

        permuted_t = Nx.transpose(t, axes: permutation)

        permuted_t
        |> sort_last_dim(comparator, %{permuted_t | type: output.type}, return_indices)
        |> Nx.transpose(axes: inverse_permutation)
    end
  end

  defp sort_last_dim(
         %T{shape: shape, type: {_, size}} = t,
         comparator,
         output,
         return_indices
       ) do
    view = aggregate_axes(to_binary(t), [tuple_size(shape) - 1], shape, size)

    new_data =
      for bin <- view, into: <<>> do
        data = for <<x::size(size)-bitstring <- bin>>, do: x

        sorted =
          if return_indices do
            data
            |> Enum.with_index()
            |> Enum.sort_by(&elem(&1, 0), comparator)
            |> Enum.map(fn {_, index} -> number_to_binary(index, output.type) end)
          else
            Enum.sort(data, comparator)
          end

        IO.iodata_to_binary(sorted)
      end

    from_binary(output, new_data)
  end

  @impl true
  def fft(out, tensor, opts), do: calculate_fft(out, tensor, opts, :fft)

  @impl true
  def ifft(out, tensor, opts), do: calculate_fft(out, tensor, opts, :ifft)

  defp calculate_fft(out, %{shape: shape, type: {_, size}} = tensor, opts, kind) do
    eps = opts[:eps]
    n = opts[:length]

    input_view = aggregate_axes(to_binary(tensor), [tuple_size(shape) - 1], shape, size)

    [row | _] =
      input_data =
      for row <- input_view do
        binary_to_list(row, tensor.type)
      end

    len_data = length(row)

    data =
      for row <- input_data do
        cond do
          len_data < n -> row ++ List.duplicate(0, n - len_data)
          len_data > n -> Enum.take(row, n)
          true -> row
        end
      end

    result =
      for row <- data do
        case kind do
          :fft ->
            fft_list(row, n)

          :ifft ->
            row
            |> Enum.map(&Complex.conjugate/1)
            |> fft_list(n)
            |> Enum.map(&(Complex.conjugate(&1) / n))
        end
      end

    output_data =
      match_types [out.type] do
        for row <- result, %Complex{re: re, im: im} <- row, into: <<>> do
          re = if abs(re) <= eps, do: 0, else: re
          im = if abs(im) <= eps, do: 0, else: im

          <<write!(Complex.new(re, im), 0)>>
        end
      end

    from_binary(out, output_data)
  end

  defp fft_list(data, bigN) when bigN < 2, do: data

  defp fft_list(data, bigN) when rem(bigN, 2) != 0 do
    # Naive DFT case for when bigN is odd > 2
    for k <- 0..(bigN - 1) do
      minus_two_pi_k_over_bigN = -2 * :math.pi() * k / bigN

      data
      |> Enum.with_index(fn x, n ->
        x * Complex.exp(Complex.new(0, minus_two_pi_k_over_bigN * n))
      end)
      |> Enum.reduce(&+/2)
    end
  end

  defp fft_list([_ | tail] = data, bigN) do
    # Here bigN is guaranteed to be even, so the recursion will
    # work on at least one level
    bigN_over_2 = div(bigN, 2)

    even = fft_list(Enum.take_every(data, 2), bigN_over_2)
    odd = fft_list(Enum.take_every(tail, 2), bigN_over_2)

    t =
      Enum.with_index(odd, fn item, k ->
        Complex.exp(Complex.new(0, -2 * :math.pi() * k / bigN)) * item
      end)

    left = Enum.zip_with([even, t], fn [x, y] -> x + y end)
    right = Enum.zip_with([even, t], fn [x, y] -> x - y end)

    left ++ right
  end

  ## Binary reducers

  defp bin_reduce(tensor, type, acc, opts, fun) do
    %T{type: {_, size}, shape: shape} = tensor

    view =
      if axes = opts[:axes] do
        aggregate_axes(to_binary(tensor), axes, shape, size)
      else
        [to_binary(tensor)]
      end

    for axis <- view, into: <<>> do
      {result, _} =
        for <<bin::size(size)-bitstring <- axis>>, reduce: {<<>>, acc} do
          {_, acc} -> fun.(bin, acc)
        end

      scalar_to_binary!(result, type)
    end
  end

  defp bin_zip_reduce(t1, [], t2, [], type, acc, fun) do
    %{type: {_, s1}} = t1
    %{type: {_, s2}} = t2
    b1 = to_binary(t1)
    b2 = to_binary(t2)

    match_types [t1.type, t2.type] do
      for <<d1::size(s1)-bitstring <- b1>>, <<d2::size(s2)-bitstring <- b2>>, into: <<>> do
        {result, _} = fun.(d1, d2, acc)
        scalar_to_binary!(result, type)
      end
    end
  end

  defp bin_zip_reduce(t1, [_ | _] = axes1, t2, [_ | _] = axes2, type, acc, fun) do
    {_, s1} = t1.type
    {_, s2} = t2.type

    v1 = aggregate_axes(to_binary(t1), axes1, t1.shape, s1)
    v2 = aggregate_axes(to_binary(t2), axes2, t2.shape, s2)

    for b1 <- v1, b2 <- v2 do
      {bin, _acc} = bin_zip_reduce_axis(b1, b2, s1, s2, <<>>, acc, fun)
      scalar_to_binary!(bin, type)
    end
  end

  # Helper for reducing down a single axis over two tensors,
  # returning tensor data and a final accumulator.
  defp bin_zip_reduce_axis(<<>>, <<>>, _s1, _s2, bin, acc, _fun),
    do: {bin, acc}

  defp bin_zip_reduce_axis(b1, b2, s1, s2, _bin, acc, fun) do
    <<x::size(s1)-bitstring, rest1::bitstring>> = b1
    <<y::size(s2)-bitstring, rest2::bitstring>> = b2
    {bin, acc} = fun.(x, y, acc)
    bin_zip_reduce_axis(rest1, rest2, s1, s2, bin, acc, fun)
  end

  defp bin_batch_reduce(bin, batch_size, {_, size}, acc, fun) do
    batch_byte_size = (batch_size * size) |> div(8)
    batches = byte_size(bin) |> div(batch_byte_size)

    for i <- 0..(batches - 1), reduce: acc do
      acc ->
        batch = binary_part(bin, i * batch_byte_size, batch_byte_size)
        fun.(batch, acc)
    end
  end

  ## Conversion helpers

  defp scalar_to_number(n) when is_number(n), do: n
  defp scalar_to_number(%Complex{} = n), do: n
  defp scalar_to_number(t), do: binary_to_number(to_binary(t), t.type)

  defp scalar_to_binary!(%Complex{re: re, im: im}, type) do
    real_type = Nx.Type.to_real(type)
    number_to_binary(re, real_type) <> number_to_binary(im, real_type)
  end

  defp scalar_to_binary!(value, type)
       when is_number(value) or value in [:nan, :neg_infinity, :infinity],
       do: number_to_binary(value, type)

  defp scalar_to_binary!(%T{shape: {}, type: type} = t, type),
    do: to_binary(t)

  defp scalar_to_binary!(t, type) do
    raise ArgumentError,
          "expected a number or a scalar tensor of type #{inspect(type)}, got: #{inspect(t)}"
  end

  defp number_to_binary(number, type),
    do: match_types([type], do: <<write!(number, 0)>>)

  defp binary_to_number(bin, type) do
    match_types [type] do
      <<match!(value, 0)>> = bin
      read!(value, 0)
    end
  end

  defp binary_to_list(binary, type) do
    match_types [type] do
      for <<match!(x, 0) <- binary>>, do: read!(x, 0)
    end
  end

  defp binary_to_binary(binary, in_type, out_type, fun) do
    match_types [in_type, out_type] do
      for <<match!(seg, 0) <- binary>>, into: <<>> do
        <<write!(fun.(read!(seg, 0)), 1)>>
      end
    end
  end

  ## Aggregation helpers

  defp aggregate_axes(binary, axes, shape, size) do
    {chunk_size, read_size, path} = aggregate_axes(axes, shape, size)

    view =
      for <<chunk::size(chunk_size)-bitstring <- binary>> do
        weighted_traverse(path, chunk, read_size)
      end

    List.flatten(view)
  end

  defp aggregate_axes([_ | _] = axes, shape, size) do
    axes = Enum.sort(axes)
    min = hd(axes)
    weighted_shape = weighted_shape(shape, size)
    [{axis_count, axis_weight} | _] = weighted_shape = Enum.drop(weighted_shape, min)
    chunk_size = axis_count * axis_weight

    # The goal of aggregate path is to split the paths
    # we are reducing from the ones we are keeping as is.
    {reverse_pre, reverse_pos} = aggregate_path(weighted_shape, axes, min, [], [])

    # Now if we are reducing on the last dimensions, we
    # can increase the read size.
    {reverse_pos, read_size} =
      aggregate_read(reverse_pos, tuple_size(shape) - 1, Enum.reverse(axes), size)

    path = Enum.reverse(reverse_pre, [(&IO.iodata_to_binary/1) | Enum.reverse(reverse_pos)])
    {chunk_size, read_size, path}
  end

  defp aggregate_path([pair | shape], [i | axes], i, pre, pos),
    do: aggregate_path(shape, axes, i + 1, pre, [pair | pos])

  defp aggregate_path([pair | shape], axes, i, pre, pos),
    do: aggregate_path(shape, axes, i + 1, [pair | pre], pos)

  defp aggregate_path([], [], _i, pre, pos), do: {pre, pos}

  defp aggregate_read([{axis, weight} | shape], i, [i | axis], _size),
    do: aggregate_read(shape, i - 1, axis, axis * weight)

  defp aggregate_read(shape, _i, _axis, size),
    do: {shape, size}

  ## Weighted shapes

  # Converts the shape to a weight shape list.
  #
  # A weighted shape is a list of tuples where the first
  # element is the number of elements in the dimension
  # and the second element is the size to be traversed in
  # the binary to fetch the next element.
  #
  # This is often given to `weighted_traverse/4` as a general
  # mechanism to traverse binaries.
  defp weighted_shape(shape, size, limits \\ :none, dilations \\ 1) do
    rank = tuple_size(shape)

    dilations =
      if is_list(dilations),
        do: Enum.reverse(dilations),
        else: List.duplicate(dilations, rank)

    weighted_shape(shape, rank, size, limits, dilations, [])
  end

  defp weighted_shape(_shape, 0, _weight, _limits, [], acc), do: acc

  defp weighted_shape(shape, pos, weight, limits, [dilation | dilations], acc) do
    shape_elem = :erlang.element(pos, shape)

    element =
      if limits == :none, do: shape_elem, else: min(:erlang.element(pos, limits), shape_elem)

    dilation_factor =
      if element == 1,
        do: 1,
        else: dilation

    acc = [{element, dilation_factor * weight} | acc]
    weighted_shape(shape, pos - 1, weight * shape_elem, limits, dilations, acc)
  end

  # Reads the chunk size from a weighted list at the given position.
  defp weighted_chunk(list, at, size) do
    {element, size} = Enum.at(list, at, {1, size})
    element * size
  end

  # Traverses a binary using the elements and shape given by `weighted_shape`.
  #
  # When all dimensions are traversed, we read `read_size`.
  #
  # The `weighted_shape` can also contain functions, which are applied to the
  # result of the remaining of the weighted shape.
  defp weighted_traverse(weighted_shape, binary, read_size, offset \\ 0)

  defp weighted_traverse([], data, read_size, offset) do
    <<_::size(offset)-bitstring, chunk::size(read_size)-bitstring, _::bitstring>> = data
    chunk
  end

  defp weighted_traverse([{dim, size} | dims], data, read_size, offset) do
    weighted_traverse(dim, size, dims, data, read_size, offset)
  end

  defp weighted_traverse([fun | dims], data, read_size, offset) do
    fun.(weighted_traverse(dims, data, read_size, offset))
  end

  defp weighted_traverse(dim, dim_size, dims, data, read_size, offset) do
    head = weighted_traverse(dims, data, read_size, offset)

    case dim do
      1 ->
        [head]

      _ ->
        <<_::size(dim_size)-bitstring, data::bitstring>> = data
        [head | weighted_traverse(dim - 1, dim_size, dims, data, read_size, offset)]
    end
  end

  # Makes anchors for traversing a binary with a window.
  defp make_anchors(shape, strides, window, dilations \\ 1)

  defp make_anchors(shape, strides, window, dilations)
       when is_tuple(shape) and is_tuple(window) and is_list(strides) do
    dilations =
      if is_integer(dilations),
        do: List.duplicate(dilations, tuple_size(shape)),
        else: dilations

    make_anchors(Tuple.to_list(shape), strides, Tuple.to_list(window), dilations, :init)
  end

  defp make_anchors([], [], _window, _dilations, anchors), do: anchors

  defp make_anchors([dim | shape], [s | strides], [w | window], [dil | dilation], :init) do
    dims = for i <- 0..(dim - 1), rem(i, s) == 0 and i + (w - 1) * dil < dim, do: [i]
    make_anchors(shape, strides, window, dilation, dims)
  end

  defp make_anchors([dim | shape], [s | strides], [w | window], [dil | dilation], anchors) do
    anchors =
      for i <- 0..(dim - 1),
          rem(i, s) == 0 and i + (w - 1) * dil < dim,
          anchor <- anchors,
          do: anchor ++ [i]

    make_anchors(shape, strides, window, dilation, anchors)
  end

  # Calculates the offset needed to reach a specified position
  # in the binary from a weighted shape list.
  defp weighted_offset(weighted_shape, pos, dilations \\ 1)

  defp weighted_offset(weighted_shape, pos, dilations) when is_list(pos) do
    dilations =
      if is_list(dilations),
        do: dilations,
        else: List.duplicate(dilations, length(weighted_shape))

    sum_weighted_offset(weighted_shape, pos, dilations)
  end

  defp sum_weighted_offset([], [], []), do: 0

  defp sum_weighted_offset([{s, size} | dims], [x | pos], [d | dilation]) do
    # Account for the dilation
    dilation_factor =
      if s == 1,
        do: 1,
        else: d

    div(size, dilation_factor) * x + weighted_offset(dims, pos, dilation)
  end
end