lib/nx/defn/tree.ex

defmodule Nx.Defn.Tree do
  @moduledoc """
  Helper functions to traverse defn expressions,
  either as single nodes or recursively.
  """

  alias Nx.Defn.{Composite, Expr}
  alias Nx.Tensor, as: T

  @doc """
  Checks whether the given tree contains any of the given hooks.

  Returns `true` as soon as a token is found holding a hook that
  either has a truthy entry under its name in `hooks` or carries
  its own callback. Returns `false` if the whole tree is traversed
  without finding one.
  """
  def has_hooks?(tree, hooks) do
    visit = fn node, seen -> detect_hook(node, seen, hooks) end

    try do
      # The traversal is aborted with a throw on the first matching hook,
      # so getting past this line means that none was found.
      Composite.reduce(tree, %{}, visit)
      false
    catch
      :side_effect -> true
    end
  end

  # Tokens abort the traversal with a throw as soon as one of their hooks
  # has a truthy entry under its name in `hooks` or has its own callback.
  defp detect_hook(%T{data: %Expr{op: :token, args: [token]}} = t, cache, hooks) do
    active? = Enum.any?(token.hooks, fn hook -> hooks[hook.name] || hook.callback end)

    if active?, do: throw(:side_effect)
    fallback_detect_hook(t, cache, hooks)
  end

  # Every other node is only inspected through its arguments.
  defp detect_hook(t, cache, hooks) do
    fallback_detect_hook(t, cache, hooks)
  end

  # Looks for hooks in the arguments of `t`. The cache is keyed by the
  # node id, so each node has its arguments visited only once.
  defp fallback_detect_hook(%T{data: %Expr{id: id}} = t, cache, hooks) do
    if Map.has_key?(cache, id) do
      cache
    else
      visit = fn arg, seen -> {arg, detect_hook(arg, seen, hooks)} end
      {_args, seen} = apply_args(t, cache, visit)
      Map.put(seen, id, true)
    end
  end

  @doc """
  Collects the IDs of all expressions in the same scope as `expr`.

  Expressions in `while`'s condition and body, `fun`'s body and
  similar belong to a different scope and are not collected.
  An expression inside a `cond` is only collected if it is also
  used outside of that `cond` or in several distinct conds.
  Constants are ignored too, as they have global IDs based on
  the constants themselves.

  The result is a map from each ID to the operation of its node.
  An existing map of `ids` can be given to accumulate on top of it.
  """
  def scope_ids(expr, ids \\ %{}) do
    visit = fn node, acc -> scope_ids_each(node, acc, nil) end
    {ids, _cond_ids} = Composite.reduce(expr, {ids, %{}}, visit)
    ids
  end

  # Constants are never tracked, whatever the scope.
  defp scope_ids_each(%T{data: %Expr{op: :constant}}, {_ids, _cond_ids} = acc, _scope) do
    acc
  end

  # Root scope (`scope` is nil): the node is stored in `ids`, unless it is
  # already there, and then its arguments are visited.
  defp scope_ids_each(
         %T{data: %Expr{id: id, op: op, args: args}} = t,
         {ids, cond_ids} = acc,
         nil
       ) do
    case ids do
      %{^id => _} ->
        acc

      %{} when op == :cond ->
        [[{first_pred, first_body} | other_clauses], last] = args
        in_cond = fn node, acc -> scope_ids_each(node, acc, id) end

        # The first predicate is treated as part of the current scope
        # to avoid executing it more than once.
        acc = {Map.put(ids, id, op), cond_ids}
        acc = scope_ids_each(first_pred, acc, nil)
        acc = Composite.reduce(first_body, acc, in_cond)

        # Everything else is traversed as in apply_args, but tagged with
        # the id of this cond as its scope.
        acc =
          Enum.reduce(other_clauses, acc, fn {pred, body}, acc ->
            Composite.reduce(body, in_cond.(pred, acc), in_cond)
          end)

        Composite.reduce(last, acc, in_cond)

      %{} ->
        scope_ids_args(t, {Map.put(ids, id, op), cond_ids}, nil)
    end
  end

  # Inside a cond, IDs are collected separately in `cond_ids`, tagged with
  # the id of the cond they were found in. An ID that shows up in more than
  # one direct cond is moved to the parent scope.
  defp scope_ids_each(%T{data: %Expr{id: id}} = t, {ids, cond_ids} = acc, scope) do
    case Map.fetch(cond_ids, id) do
      # Already seen in this very cond
      {:ok, ^scope} ->
        acc

      # Seen in another cond before, so it is visited again as a root node
      {:ok, _other_scope} ->
        scope_ids_each(t, acc, nil)

      # First occurrence, it is tagged with the current cond
      :error ->
        scope_ids_args(t, {ids, Map.put(cond_ids, id, scope)}, scope)
    end
  end

  # Visits the arguments of `t` that are in the same scope,
  # returning only the updated accumulator.
  defp scope_ids_args(t, acc, scope) do
    visit = fn arg, acc -> {arg, scope_ids_each(arg, acc, scope)} end
    {_args, acc} = apply_args(t, :scope, acc, visit)
    acc
  end

  @doc """
  Replaces the arguments of the given tensor expression.

  The returned tensor has `args` as the arguments of its
  expression, which is also given a new id.
  """
  def put_args(%T{data: %Expr{} = expr} = t, args) do
    updated = %{expr | args: args, id: Expr.id()}
    %{t | data: updated}
  end

  @doc """
  Applies the given function to the arguments of the node,
  with the given accumulator as a starting value.

  `fun` receives a tensor expression and the accumulator and must
  return a `{new_expr, new_acc}` tuple. The result is a tuple with
  the new arguments, in the same layout as the arguments of the
  node, and the final accumulator. Arguments that are containers
  are traversed with `Nx.Defn.Composite.traverse/3`.

  By default, `type` is `:all`, which means all arguments
  are traversed. If `type` is `:scope`, only expressions
  that are in the same scope are traversed. Therefore,
  expressions such as `while`'s condition and body,
  `optional`'s default implementation, functions, and so forth
  are not traversed. Note `cond`s are always traversed because,
  while they introduce a new scope, they can also access its
  parents directly, so you must take `cond`s into account
  accordingly.

  Warning: be very careful when using this function to traverse
  the expression recursively. If you plan to do so, you should
  consider also storing the visited nodes to avoid multiple
  traversals by using `tensor.data.id` as cache key.
  """
  def apply_args(expr, type \\ :all, acc, fun)

  def apply_args(%T{data: %Expr{op: :fun, args: [args, expr, mfa]}}, type, acc, fun) do
    {args, acc} = Enum.map_reduce(args, acc, &Composite.traverse(&1, &2, fun))
    # The body of the function is a scope of its own
    {expr, acc} = apply_nested_scope(expr, type, acc, fun)
    {[args, expr, mfa], acc}
  end

  # Conds are traversed regardless of `type`
  def apply_args(%T{data: %Expr{op: :cond, args: [clauses, last]}}, _type, acc, fun) do
    {clauses, acc} =
      Enum.map_reduce(clauses, acc, fn {pred, expr}, acc ->
        {pred, acc} = fun.(pred, acc)
        {expr, acc} = Composite.traverse(expr, acc, fun)
        {{pred, expr}, acc}
      end)

    {last, acc} = Composite.traverse(last, acc, fun)
    {[clauses, last], acc}
  end

  def apply_args(%T{data: %Expr{op: :while, args: args}}, type, acc, fun) do
    [initial, arg, pred, block] = args
    {initial, acc} = Composite.traverse(initial, acc, fun)

    case type do
      :all ->
        {arg, acc} = Composite.traverse(arg, acc, fun)
        {pred, acc} = fun.(pred, acc)
        {block, acc} = Composite.traverse(block, acc, fun)
        {[initial, arg, pred, block], acc}

      :scope ->
        # Only the initial value belongs to the current scope
        {[initial, arg, pred, block], acc}
    end
  end

  def apply_args(%T{data: %Expr{op: :optional, args: args}}, type, acc, fun) do
    [call, expr] = args
    {call, acc} = fun.(call, acc)
    # The default implementation is a scope of its own
    {expr, acc} = apply_nested_scope(expr, type, acc, fun)
    {[call, expr], acc}
  end

  def apply_args(%T{data: %Expr{op: :token, args: [token]}}, _type, acc, fun) do
    {hooks, acc} =
      Enum.map_reduce(token.hooks, acc, fn %{expr: expr} = hook, acc ->
        {expr, acc} = Composite.traverse(expr, acc, fun)
        {%{hook | expr: expr}, acc}
      end)

    {[%{token | hooks: hooks}], acc}
  end

  def apply_args(%T{data: %Expr{op: :concatenate, args: [list | args]}}, _type, acc, fun) do
    {list, acc} = Enum.map_reduce(list, acc, fun)
    {[list | args], acc}
  end

  def apply_args(%T{data: %Expr{op: :slice, args: args}}, _type, acc, fun) do
    [tensor, start_indices | args] = args
    {tensor, acc} = fun.(tensor, acc)
    {start_indices, acc} = apply_start_indices(start_indices, acc, fun)
    {[tensor, start_indices | args], acc}
  end

  def apply_args(%T{data: %Expr{op: :put_slice, args: args}}, _type, acc, fun) do
    [tensor, start_indices, slice] = args
    {tensor, acc} = fun.(tensor, acc)
    {slice, acc} = fun.(slice, acc)
    {start_indices, acc} = apply_start_indices(start_indices, acc, fun)
    {[tensor, start_indices, slice], acc}
  end

  def apply_args(%T{data: %Expr{op: :metadata, args: [expr, metadata]}}, _type, acc, fun) do
    {expr, acc} = Composite.traverse(expr, acc, fun)
    {[expr, metadata], acc}
  end

  # For all other nodes, only the arguments that are
  # tensor expressions are given to `fun`.
  def apply_args(%T{data: %Expr{args: args}}, _type, acc, fun) do
    Enum.map_reduce(args, acc, fn
      %T{data: %Expr{}} = arg, acc -> fun.(arg, acc)
      arg, acc -> {arg, acc}
    end)
  end

  # Traverses a composite that lives in a nested scope, which is
  # skipped altogether when only the current scope was requested.
  defp apply_nested_scope(expr, type, acc, fun) do
    case type do
      :all -> Composite.traverse(expr, acc, fun)
      :scope -> {expr, acc}
    end
  end

  # Start indices are either integers, which are kept as is,
  # or anything else, which is given to `fun`.
  defp apply_start_indices(start_indices, acc, fun) do
    Enum.map_reduce(start_indices, acc, fn
      index, acc when is_integer(index) -> {index, acc}
      index, acc -> fun.(index, acc)
    end)
  end

  ## Type helpers

  @doc """
  Rewrites the types of the given tensor expressions according to
  the given options.

  Each option sets a maximum type: any type of the same kind whose
  size is greater than or equal to the size of the maximum is
  rewritten to the maximum itself. If all options have their default
  values, the expression is returned as is.

  ## Options

    * `:max_float_type` - set the max float type, which is also
      applied to `:bf` types. It must be a float type, otherwise
      an `ArgumentError` is raised. Defaults to `{:f, 64}`
    * `:max_signed_type` - set the max signed integer type.
      Defaults to `{:s, 64}`
    * `:max_unsigned_type` - set the max unsigned integer type.
      Defaults to `{:u, 64}`

  """
  def rewrite_types(tensor_expr, opts \\ []) when is_list(opts) do
    float_cap = opts[:max_float_type] || {:f, 64}
    signed_cap = opts[:max_signed_type] || {:s, 64}
    unsigned_cap = opts[:max_unsigned_type] || {:u, 64}

    {_, float_size} = float_cap
    {_, signed_size} = signed_cap
    {_, unsigned_size} = unsigned_cap

    if not Nx.Type.float?(float_cap) do
      raise ArgumentError, ":max_float_type must be float type, got: #{inspect(float_cap)}"
    end

    all_defaults? =
      float_cap == {:f, 64} and signed_cap == {:s, 64} and unsigned_cap == {:u, 64}

    if all_defaults? do
      # Nothing to cap, so the traversal is skipped altogether
      tensor_expr
    else
      cap = fn
        {:u, size} when size >= unsigned_size -> unsigned_cap
        {:s, size} when size >= signed_size -> signed_cap
        {:f, size} when size >= float_size -> float_cap
        {:bf, size} when size >= float_size -> float_cap
        other -> other
      end

      rewrite_type(tensor_expr, cap)
    end
  end

  # Rewrites all tensor expressions in `expr`, starting from an empty
  # cache that is shared across the whole traversal.
  defp rewrite_type(expr, fun) do
    expr
    |> Composite.traverse(%{}, fn node, cache -> rewrite_type(node, cache, fun) end)
    |> elem(0)
  end

  # Rewrites a single node and its arguments. Results are cached by
  # node id, so a node that is reached many times is rewritten once.
  defp rewrite_type(%T{data: %Expr{id: id, op: op}} = t, cache, fun) do
    case Map.fetch(cache, id) do
      {:ok, cached} ->
        {cached, cache}

      :error ->
        visit = fn arg, cache -> rewrite_type(arg, cache, fun) end
        {args, cache} = apply_args(t, cache, visit)
        rewritten = rewrite_type(op, args, t, fun)
        {rewritten, Map.put(cache, id, rewritten)}
    end
  end

  # Parameters in the root context are cast with Nx.as_type/2
  # instead of being rebuilt.
  defp rewrite_type(:parameter, _args, %{data: %{context: :root}} = param, type_fun) do
    target = type_fun.(param.type)
    Nx.as_type(param, target)
  end

  # The tensor given as argument is cast to the same type as the node.
  defp rewrite_type(:tensor, [tensor], t, type_fun) do
    target = type_fun.(t.type)
    rewrite_type_args(:tensor, t, target, [Nx.as_type(tensor, target)])
  end

  # Everything else is rebuilt with the rewritten type and the given args.
  defp rewrite_type(op, args, t, type_fun) do
    target = type_fun.(t.type)
    rewrite_type_args(op, t, target, args)
  end

  # Constants are built once more by Expr.constant/3 with the new type.
  defp rewrite_type_args(:constant, t, type, [value]) do
    retyped = %{t | type: type}
    Expr.constant(retyped, value, [])
  end

  # Any other node gets the new type, the new args and a new id.
  defp rewrite_type_args(_op, %{data: data} = t, type, args) do
    data = %{data | id: Expr.id(), args: args}
    %{t | data: data, type: type}
  end
end