# lib/nx/lazy_container.ex

defprotocol Nx.LazyContainer do
  @moduledoc """
  Converts a data structure to a container lazily.

  When you invoke a `defn`, its arguments must implement
  the `Nx.LazyContainer` protocol and return a data structure
  that implements `Nx.Container`. In other words, this
  protocol teaches Nx to work with additional data types
  beyond numbers and tensors.

  This protocol defines a single function, `traverse/3`,
  which emits the tensor template and a function that computes
  the tensor as two distinct values. This way, a tensor is only
  allocated if necessary.

  Data structures without a dedicated implementation fall back
  to their `Nx.Container` implementation, if they have one.

  You can convert any `Nx.LazyContainer` to a tensor by using
  `Nx.stack/2` and `Nx.concatenate/2`.
  """

  @fallback_to_any true

  @doc """
  Recursively traverses the tensors in a data structure with `acc` and `fun`.

  For each tensor in the container, `fun` receives a tensor
  template, an anonymous function to build the actual tensor,
  and the accumulator. It returns a two-element tuple with
  a non-lazy `Nx.Container` and the accumulator.

  This function returns the updated container and the accumulator.

  Note this function is recursive by default. Therefore, if you
  are implementing this function and any of the elements of your
  data structure may be containers, you must call
  `Nx.LazyContainer.traverse/3` on said elements so they are
  recursively traversed.
  """
  @spec traverse(t(), acc, (Nx.template(), (() -> Nx.Tensor.t()), acc -> {term(), acc})) ::
          {Nx.Container.t(), acc}
        when acc: term()
  def traverse(data, acc, fun)
end

defimpl Nx.LazyContainer, for: Nx.Tensor do
  # The template is the tensor itself with its data replaced by the template
  # backend; the builder function simply hands back the original tensor.
  def traverse(tensor, acc, fun) do
    template = %{tensor | data: %Nx.TemplateBackend{}}
    build = fn -> tensor end

    fun.(template, build, acc)
  end
end

defimpl Nx.LazyContainer, for: [Integer, Float, Complex] do
  # Numbers are converted to a tensor up front; the template is derived from
  # that tensor and the builder function returns the same converted tensor.
  def traverse(number, acc, fun) do
    converted = Nx.to_tensor(number)
    template = %{converted | data: %Nx.TemplateBackend{}}

    fun.(template, fn -> converted end, acc)
  end
end

# Dedicated implementation so tuples do not go through the slower
# fallback to container.
defimpl Nx.LazyContainer, for: Tuple do
  # Traverses every element in order, threading the accumulator through,
  # and rebuilds a tuple of the same size from the traversed elements.
  def traverse(tuple, acc, fun) do
    {elements, acc} =
      tuple
      |> Tuple.to_list()
      |> Enum.map_reduce(acc, fn element, acc ->
        Nx.LazyContainer.traverse(element, acc, fun)
      end)

    {List.to_tuple(elements), acc}
  end
end

# Dedicated implementation so maps do not go through the slower
# fallback to container.
defimpl Nx.LazyContainer, for: Map do
  # Entries are sorted before traversal so the accumulator is threaded
  # through the values in a deterministic order. Keys are kept as-is;
  # only the values are traversed.
  def traverse(map, acc, fun) do
    sorted_pairs = map |> Map.to_list() |> Enum.sort()

    {traversed_pairs, acc} =
      Enum.map_reduce(sorted_pairs, acc, fn {key, value}, acc ->
        {traversed, acc} = Nx.LazyContainer.traverse(value, acc, fun)
        {{key, traversed}, acc}
      end)

    {Map.new(traversed_pairs), acc}
  end
end

defimpl Nx.LazyContainer, for: Atom do
  # Atoms are never valid inputs. Booleans get a dedicated message pointing
  # to Nx.tensor/1; every other atom raises the plain protocol error.
  def traverse(atom, _acc, _fun) when atom in [true, false] do
    description =
      "booleans are not valid tensors (and therefore not supported as defn inputs). " <>
        "However, you can convert them to tensors using Nx.tensor/1"

    raise Protocol.UndefinedError, protocol: @protocol, value: atom, description: description
  end

  def traverse(atom, _acc, _fun) do
    raise Protocol.UndefinedError, protocol: @protocol, value: atom
  end
end

defimpl Nx.LazyContainer, for: List do
  @description "lists are not valid tensors (and therefore not supported as defn inputs). " <>
                 "However, you can convert them to tensors using Nx.tensor/1"

  # Lists are always rejected, with a message pointing to Nx.tensor/1.
  def traverse(list, _acc, _fun) do
    raise Protocol.UndefinedError, protocol: @protocol, value: list, description: @description
  end
end

defimpl Nx.LazyContainer, for: Any do
  # Fallback for data without a dedicated Nx.LazyContainer implementation:
  # delegate to its Nx.Container implementation when one exists, lazily
  # traversing each element it yields; otherwise raise.
  def traverse(data, acc, fun) do
    case Nx.Container.impl_for(data) do
      nil ->
        raise Protocol.UndefinedError,
          protocol: @protocol,
          value: data,
          description:
            "data-structures given to defn/Nx must implement either Nx.LazyContainer or Nx.Container"

      impl ->
        impl.traverse(data, acc, fn element, element_acc ->
          Nx.LazyContainer.traverse(element, element_acc, fun)
        end)
    end
  end
end