# lib/nx/defn/composite.ex

defmodule Nx.Defn.Composite do
  @moduledoc """
  Functions to deal with composite data types according to `Nx.Container`.

  The functions in this module can be used both inside and outside `defn`.
  """

  alias Nx.Defn.Expr
  alias Nx.Tensor, as: T

  import Nx, only: [is_tensor: 1]

  @doc """
  Traverses two composite types to see if they are compatible.

  Two values are structurally compatible when they are both leaves
  (numbers, complex numbers, or tensors), tuples of the same size,
  structs of the same module, or plain maps with the same keys, and
  all of their corresponding children are compatible as well.

  For non-composite types, the given `fun` will be called to
  compare numbers/tensors pairwise. Any structural mismatch,
  including a struct compared against a non-struct, returns `false`
  without invoking `fun` on the mismatching pair.
  """
  def compatible?(left, right, fun)
      when is_tensor(left) and is_tensor(right),
      do: fun.(left, right)

  def compatible?(left, right, fun) when tuple_size(left) == tuple_size(right) do
    Tuple.to_list(left)
    |> Enum.zip(Tuple.to_list(right))
    |> Enum.all?(fn {l, r} -> compatible?(l, r, fun) end)
  end

  def compatible?(%mod{} = left, %mod{} = right, fun) do
    left = Nx.Container.reduce(left, [], &[&1 | &2])
    right = Nx.Container.reduce(right, [], &[&1 | &2])

    # Enum.zip/2 stops at the shorter list, so a differing number of
    # children must be rejected explicitly instead of being ignored.
    length(left) == length(right) and
      Enum.all?(Enum.zip(left, right), fn {l, r} -> compatible?(l, r, fun) end)
  end

  # Structs of different modules, or a struct paired with a non-struct, are
  # never compatible. These clauses must come before the plain-map clause:
  # structs are maps too, but they do not implement Enumerable, so letting
  # one reach `Enum.all?/2` below would raise instead of returning false.
  def compatible?(%_{}, _right, _fun),
    do: false

  def compatible?(_left, %_{}, _fun),
    do: false

  def compatible?(left, right, fun) when map_size(left) == map_size(right) do
    Enum.all?(left, fn {k, v1} ->
      case right do
        %{^k => v2} -> compatible?(v1, v2, fun)
        %{} -> false
      end
    end)
  end

  def compatible?(_, _, _),
    do: false

  @doc """
  Returns how many leaves (non-composite values) the composite type holds.

  A value accepted by the `Nx.is_tensor/1` guard counts as a single leaf.
  Anything else is walked with `Nx.Container.reduce/3` and the leaves
  found inside of it are added up.

  ## Examples

      iex> Nx.Defn.Composite.count(123)
      1
      iex> Nx.Defn.Composite.count({1, {2, 3}})
      3
      iex> Nx.Defn.Composite.count({Complex.new(1), {Nx.tensor(2), 3}})
      3

  """
  def count(tree), do: count(tree, 0)

  # A leaf bumps the running total by one.
  defp count(leaf, total) when is_tensor(leaf), do: total + 1

  # A container threads the running total through each of its children.
  defp count(composite, total) do
    Nx.Container.reduce(composite, total, fn child, running -> count(child, running) end)
  end

  @doc """
  Maps `fun` over every leaf of the given composite type.

  When a composite value is given, such as a tuple, it is walked
  recursively and a composite is returned with each leaf replaced by
  the result of `fun`.

  When a leaf is given directly (be it a number, complex, or actual
  tensor), `fun` is invoked with it and its result is returned.
  """
  def traverse(expr, fun) when is_function(fun, 1) do
    # Delegate to traverse/3, threading an empty list as a dummy
    # accumulator that must come back untouched.
    stateless = fn leaf, [] -> {fun.(leaf), []} end
    {traversed, []} = traverse(expr, [], stateless)
    traversed
  end

  @doc """
  Maps `fun` over every leaf of the given composite type while
  threading the accumulator `acc`.

  When a composite value is given, such as a tuple, it is walked
  recursively through `Nx.Container.traverse/3`.

  When a leaf is reached (be it a number, complex, or actual tensor;
  booleans are treated as leaves too), `fun` is invoked with the leaf
  and the current accumulator, and whatever it returns is used as
  the result for that leaf.
  """
  def traverse(expr, acc, fun)
      when is_function(fun, 2) and (is_tensor(expr) or is_boolean(expr)) do
    fun.(expr, acc)
  end

  def traverse(container, acc, fun) do
    Nx.Container.traverse(container, acc, fn child, state -> traverse(child, state, fun) end)
  end

  @doc """
  Folds `fun` over every leaf of the given composite type, starting
  from `acc`.

  When a composite value is given, such as a tuple, it is walked
  recursively through `Nx.Container.reduce/3`.

  When a non-composite tensor expression is reached, `fun` is invoked
  with it and the current accumulator. The expression's own arguments
  are not visited.
  """
  def reduce(expr, acc, fun) when is_function(fun, 2) and is_tensor(expr) do
    fun.(expr, acc)
  end

  def reduce(container, acc, fun) do
    Nx.Container.reduce(container, acc, fn child, state -> reduce(child, state, fun) end)
  end

  @doc """
  Flattens a list of composite types into a flat list of leaves.

  The leaves are returned in traversal order, followed by `tail`.

  Tensors are always kept as is. Numbers and `Complex` numbers are
  passed through `fun`, which defaults to the identity function, so
  they are also kept as is unless a custom function is given.

  ## Examples

      iex> Nx.Defn.Composite.flatten_list([1, {2, 3}])
      [1, 2, 3]

      iex> Nx.Defn.Composite.flatten_list([1, {2, 3}], [Nx.tensor(4)])
      [1, 2, 3, Nx.tensor(4)]

      iex> Nx.Defn.Composite.flatten_list([1, {2, 3}], [Nx.tensor(4)], &Nx.tensor/1)
      [Nx.tensor(1), Nx.tensor(2), Nx.tensor(3), Nx.tensor(4)]

  """
  def flatten_list(args, tail \\ [], fun \\ & &1) when is_list(args) do
    # Leaves are prepended while walking, so the list is built backwards
    # and flipped once at the end with `tail` attached after it.
    reversed_leaves =
      Enum.reduce(args, [], fn arg, leaves -> flatten_each(arg, leaves, fun) end)

    Enum.reverse(reversed_leaves, tail)
  end

  # Tensors bypass `fun` entirely.
  defp flatten_each(%T{} = tensor, leaves, _fun) do
    [tensor | leaves]
  end

  # Plain numbers go through `fun`.
  defp flatten_each(number, leaves, fun) when is_number(number) do
    [fun.(number) | leaves]
  end

  # Complex numbers go through `fun` as well.
  defp flatten_each(complex, leaves, fun) when is_struct(complex, Complex) do
    [fun.(complex) | leaves]
  end

  # Anything else is treated as a container and walked recursively.
  defp flatten_each(container, leaves, fun) do
    Nx.Container.reduce(container, leaves, fn child, inner -> flatten_each(child, inner, fun) end)
  end

  ## Nx.Defn callbacks

  # Splits `args` in two groups:
  #
  #   * functions, and tuples whose first element is a function, are
  #     prepended one by one onto `cache` (so they end up in reverse order
  #     in front of whatever `cache` already held)
  #   * every other argument is collected, in its original order
  #
  # Returns `{cache, other_args}`.
  @doc false
  def split_compile_args(args, cache), do: split_compile_args(args, cache, [])

  defp split_compile_args([], fns, reversed_vars) do
    {fns, Enum.reverse(reversed_vars)}
  end

  # An empty tuple makes `elem/2` fail inside the guard, which simply
  # rejects this clause and sends the argument to the one below.
  defp split_compile_args([head | rest], fns, reversed_vars)
       when is_function(head) or (is_tuple(head) and is_function(elem(head, 0))) do
    split_compile_args(rest, [head | fns], reversed_vars)
  end

  defp split_compile_args([head | rest], fns, reversed_vars) do
    split_compile_args(rest, fns, [head | reversed_vars])
  end

  # Rebuilds a full argument list from `full_args`: functions, and tuples
  # whose first element is a function, are kept in place, while every other
  # position is filled with the next element of `partial_args`.
  #
  # `partial_args` must be consumed exactly; leftovers fail the final match.
  @doc false
  def join_compile_args(partial_args, full_args) do
    fill_in = fn
      fun_arg, remaining
      when is_function(fun_arg) or (is_tuple(fun_arg) and is_function(elem(fun_arg, 0))) ->
        {fun_arg, remaining}

      _placeholder, [replacement | remaining] ->
        {replacement, remaining}
    end

    {joined, []} = Enum.map_reduce(full_args, partial_args, fill_in)
    joined
  end

  # Flattens `args` (followed by `tail`) into a list of leaves, converting
  # numbers and complex numbers with `Nx.to_tensor/1`, and raises if any
  # element of the resulting list is a tensor backed by `Nx.Defn.Expr`.
  @doc false
  def flatten_runtime_args(args, tail) do
    flattened = flatten_list(args, tail, &Nx.to_tensor/1)

    Enum.each(flattened, fn
      %T{data: %Expr{}} = tensor ->
        raise ArgumentError,
              "cannot pass a tensor expression as argument to defn, got: #{inspect(tensor)}"

      _other ->
        :ok
    end)

    flattened
  end

  # Replaces every leaf in `args` with an `Expr.parameter/3` rooted at
  # `:root`. Parameter positions start at zero and increase by one per
  # leaf across the whole argument list. Booleans are rejected.
  @doc false
  def to_inputs(args) do
    to_parameter = fn
      leaf, _position when is_boolean(leaf) ->
        raise ArgumentError, "booleans are not supported as defn inputs"

      leaf, position ->
        {Expr.parameter(Nx.to_tensor(leaf), :root, position), position + 1}
    end

    # The final counter is of no use to the caller, only the mapped arguments are.
    args
    |> Enum.map_reduce(0, fn arg, position -> traverse(arg, position, to_parameter) end)
    |> elem(0)
  end

  # Passes every leaf of `container` through `Expr.tensor/1`, keeping the
  # composite structure intact.
  @doc false
  def to_output(container),
    do: traverse(container, fn leaf -> Expr.tensor(leaf) end)
end