lib/nx/defn/tree.ex

defmodule Nx.Defn.Tree do
  @moduledoc """
  Helper functions to traverse expressions,
  either as single nodes or recursively.
  """

  alias Nx.Defn.{Composite, Expr}
  alias Nx.Tensor, as: T

  @doc """
  Checks if the given tree has any of the given hooks in it.

  `tree` is a single expression or a composite of expressions, walked
  with `Nx.Defn.Composite.reduce/3`. `hooks` is read as `hooks[name]`,
  so, for example, a map keyed by hook name works.

  Returns `true` as soon as a `:token` node is found that carries a hook
  which either has a truthy entry in `hooks` under its name or defines
  its own callback. Returns `false` otherwise.
  """
  def has_hooks?(tree, hooks) do
    Composite.reduce(tree, %{}, &detect_hook(&1, &2, hooks))
    false
  catch
    # Thrown by detect_hook/3 to abort the traversal on the first match.
    :side_effect -> true
  end

  # Token nodes are checked against `hooks` first. Any other node, or a
  # token with no matching hook, recurses through fallback_detect_hook/3.
  defp detect_hook(%T{data: %Expr{op: :token, args: [token]}} = t, cache, hooks) do
    if Enum.any?(token.hooks, &(hooks[&1.name] || &1.callback)) do
      throw(:side_effect)
    else
      fallback_detect_hook(t, cache, hooks)
    end
  end

  defp detect_hook(t, cache, hooks), do: fallback_detect_hook(t, cache, hooks)

  # Recurses into the node's arguments. `cache` is keyed by expression id,
  # so a node shared by several parents is only descended into once.
  defp fallback_detect_hook(%T{data: %Expr{id: id}} = t, cache, hooks) do
    case cache do
      %{^id => _} ->
        cache

      %{} ->
        {_, cache} = apply_args(t, cache, &{&1, detect_hook(&1, &2, hooks)})
        Map.put(cache, id, true)
    end
  end

  @doc """
  Puts new args in the given tensor expression and gives it a new id.
  """
  def put_args(%T{data: %Expr{} = expr} = t, args) do
    %{t | data: %{expr | id: Expr.id(), args: args}}
  end

  @doc """
  Replaces args in the given tensor expression.

  Use this function with extreme care. Changing the args but keeping
  the same id may mean you have different versions of the same node.
  Do this change only if you guarantee all nodes in the tree have been
  replaced equally.
  """
  def replace_args(%T{data: %Expr{} = expr} = t, args) do
    %{t | data: %{expr | args: args}}
  end

  @doc """
  Applies the given function to the arguments of the node,
  with the given accumulator as a starting value.

  `fun` receives an argument and the accumulator and must return
  `{new_arg, new_acc}`. The result is `{new_args, acc}`, where `new_args`
  has the same layout as the node's original args.

  Only expression arguments are given to `fun`; any other argument is
  returned unchanged. Arguments that may be composites (the hook
  expressions of `:token`, the body of `:fun`, the branches of `:cond`,
  and the state/body of `:while`) are walked with
  `Nx.Defn.Composite.traverse/3`. Integer start indices of `:slice` and
  `:put_slice` are left untouched.

  Warning: be very careful when using this function to traverse the expression
  recursively. If you plan to do so, you should consider also storing the visited
  nodes to avoid multiple traversals.
  """
  def apply_args(expr, acc, fun)

  def apply_args(%T{data: %Expr{op: :token, args: [token]}}, acc, fun) do
    {hooks, acc} =
      Enum.map_reduce(token.hooks, acc, fn %{expr: expr} = hook, acc ->
        {expr, acc} = Composite.traverse(expr, acc, fun)
        {%{hook | expr: expr}, acc}
      end)

    {[%{token | hooks: hooks}], acc}
  end

  def apply_args(%T{data: %Expr{op: :fun, args: [args, expr, mfa]}}, acc, fun) do
    {args, acc} = Enum.map_reduce(args, acc, &Composite.traverse(&1, &2, fun))
    {expr, acc} = Composite.traverse(expr, acc, fun)
    {[args, expr, mfa], acc}
  end

  def apply_args(%T{data: %Expr{op: :cond, args: [clauses, last]}}, acc, fun) do
    # Each predicate is handed to `fun` directly; clause bodies and the
    # fallback `last` branch go through Composite.traverse/3.
    {clauses, acc} =
      Enum.map_reduce(clauses, acc, fn {pred, expr}, acc ->
        {pred, acc} = fun.(pred, acc)
        {expr, acc} = Composite.traverse(expr, acc, fun)
        {{pred, expr}, acc}
      end)

    {last, acc} = Composite.traverse(last, acc, fun)
    {[clauses, last], acc}
  end

  def apply_args(%T{data: %Expr{op: :while, args: args}}, acc, fun) do
    [initial, arg, pred, block] = args
    {initial, acc} = Composite.traverse(initial, acc, fun)
    {arg, acc} = Composite.traverse(arg, acc, fun)
    {pred, acc} = fun.(pred, acc)
    {block, acc} = Composite.traverse(block, acc, fun)
    {[initial, arg, pred, block], acc}
  end

  def apply_args(%T{data: %Expr{op: :concatenate, args: [list | args]}}, acc, fun) do
    # Only the list of tensors being concatenated is visited; the trailing
    # args are passed through as-is.
    {list, acc} = Enum.map_reduce(list, acc, fun)
    {[list | args], acc}
  end

  def apply_args(%T{data: %Expr{op: :slice, args: [tensor, start_indices | args]}}, acc, fun) do
    {tensor, acc} = fun.(tensor, acc)
    {start_indices, acc} = apply_start_indices(start_indices, acc, fun)
    {[tensor, start_indices | args], acc}
  end

  def apply_args(
        %T{data: %Expr{op: :put_slice, args: [tensor, start_indices, slice]}},
        acc,
        fun
      ) do
    # Visit order (tensor, slice, then indices) is kept stable so that
    # order-sensitive accumulators see the same sequence as before.
    {tensor, acc} = fun.(tensor, acc)
    {slice, acc} = fun.(slice, acc)
    {start_indices, acc} = apply_start_indices(start_indices, acc, fun)
    {[tensor, start_indices, slice], acc}
  end

  def apply_args(%T{data: %Expr{op: :optional, args: [expr, default_impl_expr]}}, acc, fun) do
    {expr, acc} = fun.(expr, acc)
    {default_impl_expr, acc} = fun.(default_impl_expr, acc)
    {[expr, default_impl_expr], acc}
  end

  def apply_args(%T{data: %Expr{args: args}}, acc, fun) do
    Enum.map_reduce(args, acc, fn
      %T{data: %Expr{}} = arg, acc -> fun.(arg, acc)
      arg, acc -> {arg, acc}
    end)
  end

  # Shared by :slice and :put_slice: integer start indices are static and
  # kept as-is, every other index is passed to `fun`.
  defp apply_start_indices(start_indices, acc, fun) do
    Enum.map_reduce(start_indices, acc, fn
      index, acc when is_integer(index) -> {index, acc}
      index, acc -> fun.(index, acc)
    end)
  end

  ## Type helpers

  @doc """
  Rewrites the types of the given tensor expressions according to
  the given options.

  Any `{:u, size}`, `{:s, size}` or `{:f, size}` type whose size is
  greater than or equal to the size of the matching max type is replaced
  by that max type. `{:bf, size}` types are compared against the max
  float size. Smaller types are kept. Root parameters are not retyped in
  place; they are converted with `Nx.as_type/2` instead.

  When all options are at their defaults (`{:f, 64}`, `{:s, 64}` and
  `{:u, 64}`), the expression is returned unchanged.

  ## Options

    * `:max_float_type` - set the max float type
    * `:max_signed_type` - set the max signed integer type
    * `:max_unsigned_type` - set the max unsigned integer type

  Raises `ArgumentError` if `:max_float_type` is not a float type,
  `:max_signed_type` is not a signed integer type (`{:s, size}`), or
  `:max_unsigned_type` is not an unsigned integer type (`{:u, size}`).
  """
  def rewrite_types(tensor_expr, opts \\ []) when is_list(opts) do
    {_, max_float_size} = max_float_type = opts[:max_float_type] || {:f, 64}
    {_, max_signed_size} = max_signed_type = opts[:max_signed_type] || {:s, 64}
    {_, max_unsigned_size} = max_unsigned_type = opts[:max_unsigned_type] || {:u, 64}

    if not Nx.Type.float?(max_float_type) do
      raise ArgumentError, ":max_float_type must be float type, got: #{inspect(max_float_type)}"
    end

    # Without these checks a wrong-kind max type would silently retype,
    # e.g., every s64 node to a float or unsigned type.
    if not match?({:s, _}, max_signed_type) do
      raise ArgumentError,
            ":max_signed_type must be signed integer type, got: #{inspect(max_signed_type)}"
    end

    if not match?({:u, _}, max_unsigned_type) do
      raise ArgumentError,
            ":max_unsigned_type must be unsigned integer type, got: #{inspect(max_unsigned_type)}"
    end

    if max_float_type != {:f, 64} or max_signed_type != {:s, 64} or max_unsigned_type != {:u, 64} do
      rewrite_type(tensor_expr, fn
        {:u, size} when size >= max_unsigned_size -> max_unsigned_type
        {:s, size} when size >= max_signed_size -> max_signed_type
        {:f, size} when size >= max_float_size -> max_float_type
        {:bf, size} when size >= max_float_size -> max_float_type
        type -> type
      end)
    else
      tensor_expr
    end
  end

  defp rewrite_type(expr, fun) do
    {res, _} = Composite.traverse(expr, %{}, &rewrite_type(&1, &2, fun))
    res
  end

  # Rewrites a node after its arguments, memoizing results by expression id
  # so a shared subexpression is rewritten once and stays shared.
  defp rewrite_type(%T{data: %Expr{id: id, op: op}} = t, cache, fun) do
    case cache do
      %{^id => res} ->
        {res, cache}

      %{} ->
        {args, cache} = apply_args(t, cache, &rewrite_type(&1, &2, fun))
        res = rewrite_type(op, args, t, fun)
        {res, Map.put(cache, id, res)}
    end
  end

  # Root parameters keep their own type and are wrapped in an as_type node.
  # NOTE(review): presumably because the runtime input keeps its original
  # type — confirm.
  defp rewrite_type(:parameter, _args, %{data: %{context: :root}} = t, type_fun) do
    Nx.as_type(t, type_fun.(t.type))
  end

  # The tensor stored as the node's argument is converted as well, so it
  # agrees with the node's new type.
  defp rewrite_type(:tensor, [arg], t, type_fun) do
    type = type_fun.(t.type)
    rewrite_type_args(t, type, [Nx.as_type(arg, type)])
  end

  # The constant argument is kept as is; only the node type changes.
  defp rewrite_type(:constant, [arg], t, type_fun) do
    type = type_fun.(t.type)
    rewrite_type_args(t, type, [arg])
  end

  defp rewrite_type(_op, args, t, type_fun) do
    rewrite_type_args(t, type_fun.(t.type), args)
  end

  # Builds the retyped node with the rewritten args and a fresh id.
  defp rewrite_type_args(%{data: data} = t, type, args) do
    %{t | data: %{data | id: Expr.id(), args: args}, type: type}
  end
end