Skip to main content

lib/nx_vulkan/compiler.ex

defmodule Nx.Vulkan.Compiler do
  @moduledoc """
  Path A.2 v2 (partial) — `Nx.Defn.Compiler` that auto-detects fusable
  elementwise chains and dispatches `Nx.Vulkan.fused_chain/3` instead
  of N separate shader calls.

  ## What it does today

  At `__jit__` time, calls `fun.(vars)` once to materialize the IR.
  Walks the result looking for a chain of supported elementwise ops
  whose only inputs are the two function arguments. If matched: skip
  Evaluator entirely and dispatch a single `fused_chain` call.

  ## What it doesn't do yet

    * **Multi-output**: only single-output chains. A defn that returns
      a tuple falls through to `Nx.Defn.Evaluator`.
    * **Branched chains**: only linear chains. A node that's used
      twice falls through.
    * **More than 2 vars**: 2-arg functions only (matches the shader's
      two-input layout). Wider arities fall through.
    * **Chains > 8 ops**: shader limit; longer chains fall through.
    * **Non-elementwise ops**: any reduce/reshape/dot in the chain
      falls through.

  All fall-through cases delegate to `Nx.Defn.Evaluator` so behavior
  stays correct — the worst case is "no fusion, same speed as before."

  ## Configuration

      config :exmc, :compiler, :vulkan

  `Exmc.JIT` then routes to `Nx.Vulkan.jit/2`, which uses this compiler
  when available (defaults to `Nx.Defn.Evaluator` if unsupported).
  """

  @behaviour Nx.Defn.Compiler

  alias Nx.Defn.Expr
  alias Nx.Tensor, as: T

  @binary_ops %{
    add: 0,
    multiply: 1,
    subtract: 2,
    divide: 3,
    pow: 4,
    max: 5,
    min: 6
  }

  # Commutative binary ops — if first arg is the b var and second is
  # not, we can swap and still produce the correct result. Subtract,
  # divide, pow are NOT commutative and so the chain bails on the
  # reversed pattern.
  @commutative_ops [:add, :multiply, :max, :min]

  @unary_ops [
    :exp,
    :log,
    :sqrt,
    :abs,
    :negate,
    :sigmoid,
    :tanh,
    :relu,
    :ceil,
    :floor,
    :sign,
    :reciprocal,
    :square,
    :erf,
    :expm1
  ]

  @impl true
  def __partitions_options__(opts) do
    List.duplicate(opts, Keyword.get(opts, :max_concurrency, 1))
  end

  @impl true
  def __to_backend__(_opts) do
    Nx.default_backend()
  end

  @impl true
  def __jit__(key, vars, fun, args_list, opts) do
    __compile__(key, vars, fun, opts).(args_list)
  end

  @impl true
  def __compile__(key, vars, fun, opts) do
    expr = trace(fun, vars)
    var_ids = collect_var_ids(vars)

    case detect_chain(expr, var_ids) do
      {:ok, {:fused_4in, ops_with_buf, a_id, leaf_to_buf, _vars_list}, [], _, _} ->
        if System.get_env("NXV_FUSE_DEBUG") == "1" do
          IO.puts("[Nx.Vulkan.Compiler] FUSED_4IN: ops=#{inspect(ops_with_buf)}")
        end

        compile_fused_4in(ops_with_buf, a_id, leaf_to_buf, var_ids, expr)

      {:ok, outer_ops, b_pre_ops, a_var_id, b_var_id} ->
        if System.get_env("NXV_FUSE_DEBUG") == "1" do
          tag = if b_pre_ops == [], do: "FUSED", else: "FUSED+pre"
          IO.puts("[Nx.Vulkan.Compiler] #{tag}: pre=#{inspect(b_pre_ops)} outer=#{inspect(outer_ops)}")
        end

        compile_fused(outer_ops, b_pre_ops, a_var_id, b_var_id, vars, expr)

      :no_match ->
        if System.get_env("NXV_FUSE_DEBUG") == "1" do
          IO.puts("[Nx.Vulkan.Compiler] no_match — vars=#{length(var_ids)} root_op=#{inspect_root(expr)}")
        end

        # Fall through: evaluator handles the rest.
        Nx.Defn.Evaluator.__compile__(key, vars, fun, opts)
    end
  end

  defp inspect_root(%T{data: %Expr{op: op, args: args}}) do
    arg_ops =
      Enum.map(args, fn
        %T{data: %Expr{op: o}} -> o
        other -> "#{inspect(other)}"
      end)

    "#{op} args=#{inspect(arg_ops)}"
  end

  defp inspect_root(other), do: inspect(other)

  @impl true
  def __shard_jit__(_key, _mesh, _vars, _fun, _args_list, _opts) do
    raise "sharding is not supported by Nx.Vulkan.Compiler"
  end

  # --- Tracing ---------------------------------------------------------

  defp trace(fun, vars) do
    # Trace under the Nx.Defn.Expr backend so calls produce IR nodes
    # rather than executing on whichever backend is current.
    previous = Process.put(Nx.Shared.backend_pdict_key(), {Nx.Defn.Expr, []})

    try do
      vars
      |> fun.()
    after
      if previous,
        do: Process.put(Nx.Shared.backend_pdict_key(), previous),
        else: Process.delete(Nx.Shared.backend_pdict_key())
    end
  end

  defp collect_var_ids(vars) do
    vars
    |> Enum.map(fn
      %T{data: %Expr{op: :parameter, args: [i]}} = t ->
        # Capture {position, id, shape, type} for matching during walk.
        {i, t.data.id, t.shape, t.type}

      _ ->
        nil
    end)
    |> Enum.reject(&is_nil/1)
  end

  # --- Chain detection -------------------------------------------------

  # A chain is recognizable iff the result IR is a single-output tensor
  # whose op is elementwise, recursing into a chain that ultimately
  # bottoms out at one of the input vars (the "a" var). Every binary
  # step's second operand must be the "b" var.
  defp detect_chain(%T{data: %Expr{}} = root, var_ids) when length(var_ids) == 2 do
    [{0, a_id, _shape_a, _type_a}, {1, b_id, _shape_b, _type_b}] = var_ids

    case walk_chain(root, a_id, b_id, []) do
      {:ok, ops, b_pre_ops}
      when length(ops) >= 1 and length(ops) <= 8 and length(b_pre_ops) <= 8 ->
        {:ok, ops, b_pre_ops, a_id, b_id}

      _ ->
        :no_match
    end
  end

  # 1-arg defn — pass `a` as both buffers. Only unary chains can fuse
  # this way (binary ops would degenerate to op(a, a) which is
  # rarely what the user wrote). The shader still reads b unconditionally
  # but ignores its value when no binary op fires.
  defp detect_chain(%T{data: %Expr{}} = root, var_ids) when length(var_ids) == 1 do
    [{0, a_id, _shape, _type}] = var_ids

    case walk_unary_only(root, a_id, []) do
      {:ok, ops, []} when length(ops) >= 1 and length(ops) <= 8 ->
        {:ok, ops, [], a_id, a_id}

      _ ->
        :no_match
    end
  end

  # 3-arg or 4-arg defn — try the 4-input shader path. The chain
  # register starts at one of the parameters; the others appear as
  # second-operand of binary ops with assigned buf_idx (1, 2, 3 → b, c, d).
  defp detect_chain(%T{data: %Expr{}} = root, var_ids)
       when length(var_ids) == 3 or length(var_ids) == 4 do
    detect_chain_n(root, var_ids)
  end

  defp detect_chain(_, _), do: :no_match

  # Try each parameter as the chain start (`a`); first successful walk wins.
  defp detect_chain_n(root, var_ids) do
    Enum.find_value(var_ids, :no_match, fn {_pos, a_id, _shape, _type} ->
      case find_chain_to(a_id, root, %{}) do
        {:ok, ops_with_buf, leaf_to_buf}
        when length(ops_with_buf) >= 1 and length(ops_with_buf) <= 8 ->
          format_4in_match(ops_with_buf, leaf_to_buf, a_id, var_ids)

        _ ->
          nil
      end
    end)
  end

  # Pack the result into the 5-tuple shape detect_chain returns.
  # ops_with_buf: list of `op_atom` (unary) or `{op_atom, buf_idx}` (binary).
  # leaf_to_buf: %{leaf_id → buf_idx} mapping for non-`a` parameters.
  defp format_4in_match(ops_with_buf, leaf_to_buf, a_id, var_ids) do
    # Sentinel signaling 4-input fused dispatch; compile path branches on it.
    {:ok, {:fused_4in, ops_with_buf, a_id, leaf_to_buf, var_ids}, [], a_id, a_id}
  end

  # find_chain_to(target, expr, leaf_to_buf) walks the IR from root toward
  # the target leaf. Returns:
  #   {:ok, ops_with_buf_in_exec_order, updated_leaf_to_buf} | :no_match
  # Sibling subtrees of binary ops along the path must be parameter leaves.
  defp find_chain_to(target, %T{data: %Expr{op: :parameter, id: id}}, map)
       when id == target do
    {:ok, [], map}
  end

  # Unary op — recurse, append op AFTER the inner chain ops.
  defp find_chain_to(target, %T{data: %Expr{op: op, args: [arg]}}, map)
       when op in @unary_ops do
    case find_chain_to(target, arg, map) do
      {:ok, ops, m} -> {:ok, ops ++ [op], m}
      :no_match -> :no_match
    end
  end

  # Binary op — descend into first; if not on the path, try second
  # (only if commutative).
  defp find_chain_to(target, %T{data: %Expr{op: op, args: [first, second]}}, map) do
    cond do
      not Map.has_key?(@binary_ops, op) ->
        :no_match

      true ->
        case find_chain_to(target, first, map) do
          {:ok, ops, m} ->
            case classify_b_leaf(second, m) do
              {:ok, idx, m2} -> {:ok, ops ++ [{op, idx}], m2}
              :no_match -> :no_match
            end

          :no_match when op in @commutative_ops ->
            case find_chain_to(target, second, map) do
              {:ok, ops, m} ->
                case classify_b_leaf(first, m) do
                  {:ok, idx, m2} -> {:ok, ops ++ [{op, idx}], m2}
                  :no_match -> :no_match
                end

              :no_match ->
                :no_match
            end

          :no_match ->
            :no_match
        end
    end
  end

  defp find_chain_to(_, _, _), do: :no_match

  # The "non-chain side" of a binary op must be a parameter leaf in v1.
  # Assigns or reuses a buf_idx slot (1/2/3 = b/c/d).
  defp classify_b_leaf(%T{data: %Expr{op: :parameter, id: id}}, map) do
    case Map.get(map, id) do
      nil ->
        next_idx = map_size(map) + 1

        if next_idx > 3 do
          :no_match
        else
          {:ok, next_idx, Map.put(map, id, next_idx)}
        end

      existing ->
        {:ok, existing, map}
    end
  end

  defp classify_b_leaf(_, _), do: :no_match

  # 1-arg variant: only unary ops; bottom out at the single var.
  defp walk_unary_only(%T{data: %Expr{id: id, op: :parameter}}, a_id, acc)
       when id == a_id do
    {:ok, acc, []}
  end

  defp walk_unary_only(%T{data: %Expr{op: op, args: [arg]}}, a_id, acc)
       when op in @unary_ops do
    walk_unary_only(arg, a_id, [op | acc])
  end

  defp walk_unary_only(_, _, _), do: :no_match

  # Walks a sub-expression that bottoms out at b. Recognized shapes:
  #   - parameter b → empty pre-chain (just use b directly)
  #   - unary(sub) → recurse, prepend the unary
  #   - multiply(b, b) → :square peephole
  defp walk_b_subchain(%T{data: %Expr{op: :parameter, id: id}}, b_id, acc)
       when id == b_id do
    {:ok, acc}
  end

  defp walk_b_subchain(%T{data: %Expr{op: op, args: [arg]}}, b_id, acc)
       when op in @unary_ops do
    walk_b_subchain(arg, b_id, [op | acc])
  end

  # mult(b, b) ⇒ square(b) peephole
  defp walk_b_subchain(
         %T{
           data: %Expr{
             op: :multiply,
             args: [%T{data: %Expr{op: :parameter, id: id1}}, %T{data: %Expr{op: :parameter, id: id2}}]
           }
         },
         b_id,
         acc
       )
       when id1 == b_id and id2 == b_id do
    {:ok, [:square | acc]}
  end

  defp walk_b_subchain(_, _, _), do: :no_match

  # walk_chain returns {:ok, ops, b_pre_ops} or :no_match.
  # b_pre_ops is the unary chain to apply to b BEFORE the outer chain.
  # Most paths return [] (b used directly); right-folded patterns return
  # the pre-eval ops for b's sub-expression.

  # Reached `a` — bottom of chain. No b pre-eval.
  defp walk_chain(%T{data: %Expr{id: id, op: :parameter}}, a_id, _b_id, acc)
       when id == a_id do
    {:ok, acc, []}
  end

  # Unary fusable op — record and recurse. Propagates b_pre_ops from below.
  defp walk_chain(%T{data: %Expr{op: op, args: [arg]}}, a_id, b_id, acc)
       when op in @unary_ops do
    walk_chain(arg, a_id, b_id, [op | acc])
  end

  # Binary fusable op — second arg must be `b`. If first is `b` and the
  # op is commutative, swap and continue. If neither but first is the
  # `a` parameter and second is a chain on b, switch to right-folded
  # mode (pre-eval the second-arg sub-chain on b once, then dispatch
  # the outer chain with that temp as b).
  defp walk_chain(%T{data: %Expr{op: op, args: [first, second]}}, a_id, b_id, acc) do
    cond do
      not Map.has_key?(@binary_ops, op) ->
        :no_match

      var_id(second) == b_id ->
        walk_chain(first, a_id, b_id, [op | acc])

      var_id(first) == b_id and op in @commutative_ops ->
        walk_chain(second, a_id, b_id, [op | acc])

      var_id(first) == a_id ->
        # Right-folded: chain ends here on the a side; pre-eval the
        # sub-chain on b and combine with op.
        case walk_b_subchain(second, b_id, []) do
          {:ok, b_pre_ops} -> {:ok, [op | acc], b_pre_ops}
          :no_match -> :no_match
        end

      var_id(second) == a_id and op in @commutative_ops ->
        case walk_b_subchain(first, b_id, []) do
          {:ok, b_pre_ops} -> {:ok, [op | acc], b_pre_ops}
          :no_match -> :no_match
        end

      true ->
        :no_match
    end
  end

  defp walk_chain(_, _, _, _), do: :no_match

  defp var_id(%T{data: %Expr{op: :parameter, id: id}}), do: id
  defp var_id(_), do: nil

  # --- Compilation -----------------------------------------------------

  # 4-input variant. Maps each parameter to a thunk index (positional)
  # and a buffer slot. Builds a closure that grabs the right tensors
  # and dispatches Nx.Vulkan.fused_chain_4 with one shader invocation.
  defp compile_fused_4in(ops_with_buf, a_id, leaf_to_buf, var_ids, expr) do
    out_shape = expr.shape
    out_type = expr.type

    # Build a position-sorted list of {pos, id} so we can index thunks.
    pos_to_id =
      var_ids
      |> Enum.map(fn {pos, id, _shape, _type} -> {pos, id} end)
      |> Enum.into(%{})

    # Reverse map: id → position in thunks list.
    id_to_pos = Map.new(pos_to_id, fn {pos, id} -> {id, pos} end)

    a_pos = Map.fetch!(id_to_pos, a_id)

    # buf_pos: idx → param_position (idx ∈ {1, 2, 3}).
    # leaf_to_buf maps id → idx; we want idx → pos.
    buf_pos =
      Enum.into(leaf_to_buf, %{}, fn {id, idx} ->
        {idx, Map.fetch!(id_to_pos, id)}
      end)

    fn [params] ->
      thunks = params

      a_tensor = Enum.fetch!(thunks, a_pos).()

      b_tensor = lookup_or(buf_pos, 1, thunks, a_tensor)
      c_tensor = lookup_or(buf_pos, 2, thunks, a_tensor)
      d_tensor = lookup_or(buf_pos, 3, thunks, a_tensor)

      [run_fused_4in(a_tensor, b_tensor, c_tensor, d_tensor, ops_with_buf, out_shape, out_type)]
    end
  end

  # Returns the thunk's tensor at position buf_pos[idx], or fallback when
  # that buf_idx isn't in use (the shader won't read it; just satisfy
  # Vulkan's bind requirement).
  defp lookup_or(buf_pos, idx, thunks, fallback) do
    case Map.get(buf_pos, idx) do
      nil -> fallback
      pos -> Enum.fetch!(thunks, pos).()
    end
  end

  defp run_fused_4in(a, b, c, d, ops_with_buf, out_shape, out_type) do
    case {a.data, b.data, c.data, d.data} do
      {%Nx.Vulkan.Backend{ref: ar},
       %Nx.Vulkan.Backend{ref: br},
       %Nx.Vulkan.Backend{ref: cr},
       %Nx.Vulkan.Backend{ref: dr}} ->
        {:ok, ref} = Nx.Vulkan.fused_chain_4(ar, br, cr, dr, ops_with_buf)

        %T{
          data: %Nx.Vulkan.Backend{ref: ref, shape: out_shape, type: out_type},
          shape: out_shape,
          type: out_type,
          names: List.duplicate(nil, tuple_size(out_shape)),
          vectorized_axes: []
        }

      _ ->
        # Operands aren't all on Vulkan — fall through to per-op execution.
        run_4in_fallback(a, b, c, d, ops_with_buf)
    end
  end

  defp run_4in_fallback(a, b, c, d, ops_with_buf) do
    Enum.reduce(ops_with_buf, a, fn
      op, acc when op in @unary_ops ->
        apply(Nx, op, [acc])

      {op, idx}, acc ->
        other =
          case idx do
            1 -> b
            2 -> c
            3 -> d
          end

        apply(Nx, op, [acc, other])
    end)
  end

  defp compile_fused(outer_ops, b_pre_ops, _a_var_id, _b_var_id, _vars, expr) do
    out_shape = expr.shape
    out_type = expr.type

    fn [params] ->
      case params do
        [a_thunk] ->
          # 1-arg path — no b_pre_ops possible.
          a_tensor = a_thunk.()
          [run_fused(a_tensor, a_tensor, outer_ops, out_shape, out_type)]

        [a_thunk, b_thunk | _] ->
          a_tensor = a_thunk.()
          b_tensor = b_thunk.()

          # If b_pre_ops is non-empty, evaluate that unary chain on b
          # first to produce a temp buffer used as b in the outer
          # fused chain. One pre-dispatch + one fused dispatch.
          b_eff =
            if b_pre_ops == [] do
              b_tensor
            else
              run_fused(b_tensor, b_tensor, b_pre_ops, b_tensor.shape, b_tensor.type)
            end

          [run_fused(a_tensor, b_eff, outer_ops, out_shape, out_type)]
      end
    end
  end

  defp run_fused(a_tensor, b_tensor, ops, out_shape, out_type) do
    case {a_tensor.data, b_tensor.data} do
      {%Nx.Vulkan.Backend{ref: a_ref}, %Nx.Vulkan.Backend{ref: b_ref}} ->
        {:ok, ref} = Nx.Vulkan.fused_chain(a_ref, b_ref, ops)

        # Build the result tensor matching the Nx.Tensor struct shape.
        # vectorized_axes must be present (defaults to []) so downstream
        # ops that pattern-match it (to_binary, etc) don't crash.
        %T{
          data: %Nx.Vulkan.Backend{ref: ref, shape: out_shape, type: out_type},
          shape: out_shape,
          type: out_type,
          names: List.duplicate(nil, tuple_size(out_shape)),
          vectorized_axes: []
        }

      _ ->
        # Operands aren't both on the Vulkan backend — fall through to
        # ordinary execution. The simplest "fall through" here is to
        # just rebuild the chain on whatever backend the inputs are on.
        Enum.reduce(ops, a_tensor, &apply_chain_op(&1, &2, b_tensor))
    end
  end

  # Helper for the fall-through path: apply each op in order.
  defp apply_chain_op(op, acc, _b_tensor) when op in @unary_ops do
    apply(Nx, op, [acc])
  end

  defp apply_chain_op(op, acc, b_tensor) do
    apply(Nx, op, [acc, b_tensor])
  end
end