lib/nx/lazy_container.ex

defprotocol Nx.LazyContainer do
  @moduledoc """
  Converts a data structure to a lazy container.

  Sometimes building tensors for a container is an expensive
  operation, so we want to allow that to happen lazily.

  This protocol defines a single `traverse/3` function that
  emits, for each tensor, its template and a function that
  computes the actual tensor as two distinct values. Then a
  tensor is only allocated if necessary.

  This protocol is used throughout the `Nx.Defn` API. This means
  compilation, jitting, and streaming will only realize lazy
  tensors when necessary.

  If a data structure does not implement this protocol,
  a default implementation converts eager to lazy using
  `Nx.Container`.
  """

  @fallback_to_any true

  @doc """
  Recursively traverses the tensors in a data structure with `acc` and `fun`.

  For each tensor in the container, `fun` receives a tensor
  template, an anonymous function to build the actual tensor,
  and the accumulator. It returns a two element tuple with
  the value that takes the tensor's place in the container and
  the updated accumulator.

  This function returns a two element tuple with the updated
  container and the final accumulator.

  Note this function is recursive by default. Therefore if you
  are implementing this function and your data structure may
  hold other containers, you must call `Nx.LazyContainer.traverse/3`
  on them so they are recursively traversed.
  """
  # The result is `{container, acc}` (not just `acc`): implementations
  # rebuild the traversed structure and pair it with the accumulator,
  # mirroring the `{term(), acc}` shape returned by `fun`.
  @spec traverse(t(), acc, (Nx.template(), (() -> Nx.Tensor.t()), acc -> {term(), acc})) ::
          {term(), acc}
        when acc: term()
  def traverse(data, acc, fun)
end

defimpl Nx.LazyContainer, for: Nx.Tensor do
  # An already-built tensor: hand `fun` a template copy (data swapped for
  # the template backend) plus a thunk that returns the original tensor.
  def traverse(tensor, acc, fun) do
    template = %{tensor | data: %Nx.TemplateBackend{}}
    builder = fn -> tensor end
    fun.(template, builder, acc)
  end
end

defimpl Nx.LazyContainer, for: [Integer, Float, Complex] do
  # Numbers are converted with `Nx.to_tensor/1` up front; `fun` receives
  # a template derived from that tensor and a thunk returning it.
  def traverse(number, acc, fun) do
    eager = Nx.to_tensor(number)
    builder = fn -> eager end
    template = %{eager | data: %Nx.TemplateBackend{}}
    fun.(template, builder, acc)
  end
end

# Implement to speed up fallback to container.
defimpl Nx.LazyContainer, for: Tuple do
  # Traverses each element left to right, threading the accumulator,
  # then rebuilds a tuple of the same size from the traversed elements.
  def traverse(tuple, acc, fun) do
    elements = Tuple.to_list(tuple)

    {traversed, acc} =
      Enum.map_reduce(elements, acc, fn element, inner_acc ->
        Nx.LazyContainer.traverse(element, inner_acc, fun)
      end)

    {List.to_tuple(traversed), acc}
  end
end

# Implement to speed up fallback to container.
defimpl Nx.LazyContainer, for: Map do
  # Visits entries in sorted order (so traversal is deterministic), only
  # traversing values; keys are carried through unchanged.
  def traverse(map, acc, fun) do
    sorted_entries = map |> Map.to_list() |> Enum.sort()

    {entries, acc} =
      Enum.map_reduce(sorted_entries, acc, fn {key, value}, inner_acc ->
        {new_value, inner_acc} = Nx.LazyContainer.traverse(value, inner_acc, fun)
        {{key, new_value}, inner_acc}
      end)

    {Map.new(entries), acc}
  end
end

defimpl Nx.LazyContainer, for: Atom do
  # Atoms are never lazy containers: always raise Protocol.UndefinedError.
  # Booleans get an extra hint pointing users at Nx.tensor/1.
  def traverse(atom, _acc, _fun) do
    extra =
      case atom do
        flag when is_boolean(flag) ->
          [
            description:
              "booleans are not valid tensors (and therefore not supported as defn inputs). " <>
                "However, you can convert them to tensors using Nx.tensor/1"
          ]

        _other ->
          []
      end

    raise Protocol.UndefinedError, [protocol: @protocol, value: atom] ++ extra
  end
end

defimpl Nx.LazyContainer, for: List do
  # Hint shown to users who pass a plain list where a tensor was expected.
  @list_hint "lists are not valid tensors (and therefore not supported as defn inputs). " <>
               "However, you can convert them to tensors using Nx.tensor/1"

  # Lists (including the empty list) are rejected outright.
  def traverse(list, _acc, _fun) do
    raise Protocol.UndefinedError, protocol: @protocol, value: list, description: @list_hint
  end
end

defimpl Nx.LazyContainer, for: Any do
  # Fallback: delegate structure walking to Nx.Container and re-enter this
  # protocol for every child so nested lazy containers are honored.
  def traverse(data, acc, fun) do
    recurse = fn child, inner_acc -> Nx.LazyContainer.traverse(child, inner_acc, fun) end
    Nx.Container.traverse(data, acc, recurse)
  end
end