defmodule EMLXAxon do
@moduledoc """
Axon model rewrites that swap supported nodes to `EMLX.Fast` Metal shaders.
Pass an `%Axon{}` model through `rewrite/1` before compiling it with
`Axon.build/2` or `Bumblebee.Text.generation/4` to replace supported
normalization, rotary-embedding, attention, and SwiGLU nodes with
single-kernel MLX equivalents.
## Supported rewrites
| Key | Matched node | Replaced with |
|------------------------|------------------------------------------|------------------------------------------------------------------------------|
| `:rms_norm` | `op_name: :rms_norm`, `shift: 0.0` | `EMLX.Fast.rms_norm/3` |
| `:layer_norm` | `op_name: :layer_norm` | `EMLX.Fast.layer_norm/3,4` |
| `:rotary_embedding` | Bumblebee `apply_rotary_embedding/5` | `EMLX.Fast.rope_with_positions/6` or `rope_with_freqs/6` |
| `:sdpa` | Bumblebee `attention_output_impl/3` | `EMLX.Fast.scaled_dot_product_attention_causal_key_masked/5` or unmasked `/4` |
| `:dropout` | `op_name: :dropout` (inference) | identity pass-through |
| `:swiglu` | `multiply(container(up, silu(gate)))` | `EMLX.Fast.swiglu/2` |
| `:attn_weights` | `op_name: :bb_attn_weights` (patch) | constant-zero layer |
| `:if_present` | `op_name: :if_present` KV-cache cond | the "cache present" branch |
| `:native_attention` | Bumblebee causal self-attention | `EMLX.kv_cache_attention_masked/8` |
| `:nullify_block_cache` | Bumblebee `put_block_cache` | identity pass-through |
## Usage
{:ok, %{model: model, params: params, spec: spec}} = Bumblebee.load_model({:hf, "Qwen/Qwen3-0.6B"})
model = EMLXAxon.rewrite(model)
serving = Bumblebee.Text.generation(
  %{model: model, params: params, spec: spec},
  tokenizer, generation_config,
  compile: [batch_size: 1, sequence_length: 256],
  defn_options: [compiler: EMLX]
)
## Limitations
- **`:rms_norm`** rewrite requires `shift: 0.0`. Nodes with a non-zero shift
are skipped because `EMLX.Fast.rms_norm(x, w, eps)` computes `x/rms(x)*w`,
not `x/rms(x)*(shift+w)`. (Reference semantics for the fused kernels are
sketched after this list.)
- **`:rotary_embedding`** rewrite assumes sequential position IDs within each
batch example (standard causal LM). Non-sequential schemes (packed sequences,
custom position offsets) will produce incorrect results. Bumblebee's
`apply_rotary_embedding/5` is matched by function identity via `function_info/1` —
this is tied to Bumblebee's internal implementation and may break across major
Bumblebee version changes.
RoPE scaling strategies: `:llama3` precomputes the inv-frequency tensor at
rewrite time and dispatches to `EMLX.Fast.rope_with_freqs/6`. Other strategies
(`:linear`, `:dynamic`, `:longrope`) fall back to `rope_with_positions` with the
standard base frequency — they are not frequency-precomputable because they are
seq-len conditional or require post-multiply of cos/sin that `mlx::fast::rope`
cannot absorb.
- **`:sdpa`** rewrite threads `key_mask` (from input padding) through to the
C++ NIF, which checks at runtime whether the mask is all-ones and fast-paths
to the pure causal Metal kernel when it is. Padded batches get a combined
causal + key_mask additive mask. Sliding-window attention falls back to the
original `attention_output_impl`. Inference-only: dropout is elided.
- **`:dropout`** rewrite replaces `op_name: :dropout` nodes with an identity
pass-through. Dropout at inference time is a no-op regardless of rate, so the
rewrite eliminates the per-decode-step NIF-boundary crossings without any
functional change. Not appropriate for training graphs.
- **`:swiglu`** rewrite matches `:multiply` nodes backed by a `:container` node whose
two parents include one `:silu` node (the Bumblebee SwiGLU pattern:
`multiply(container(up_proj, silu(gate_proj)))`). Replaces the multiply + container +
silu triple with a single `EMLX.Fast.swiglu/2` call. Does not match generic
multiplications or containers without a silu child.
- **`:attn_weights`** rewrite replaces `:bb_attn_weights` passthrough nodes (added by
the local Bumblebee patch) with a no-arg constant-zero layer, cutting the
`attention_weights_impl` sub-graph and its K-side `repeat_interleave` nodes out of the
reachable graph entirely. Inference-only: the attention weights tensor is never used
for token generation.
- **`:if_present`** rewrite replaces Bumblebee's KV-cache conditional nodes with their
"cache present" branch. In compiled serving the KV cache is always initialized (never
`%Axon.None{}`), so the else branch is dead code. Removing the `:if_present` nodes
and their `:optional` wrappers eliminates per-step Axon dispatch overhead without any
functional change.
- **`:gqa_cache_fix`** rewrite fixes a shape mismatch that arises when GQA head expansion
(`repeat_interleave`) runs *before* `update_attention_cache`. The standard Bumblebee
transformer block expands keys/values from `num_key_value_heads` to `num_attention_heads`
before the cache update, but the cache is allocated with `num_key_value_heads`. This
rewrite strips the `repeat_interleave` from the key and value inputs to every
`update_attention_cache` Axon layer, so the cache receives the compact GQA tensors.
The SDPA rewriter (`maybe_strip_repeat_interleave`) already handles the expanded-head
removal on the SDPA side, and MLX fast SDPA handles GQA natively.
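
For numerics checks, the normalization and MLP kernels have these pure-Nx
reference semantics (a sketch of what the fused shaders compute, not how they
compute it):

    # EMLX.Fast.rms_norm(x, w, eps) — RMS-normalize over the last axis, scale by w
    ms = Nx.mean(Nx.pow(x, 2), axes: [-1], keep_axes: true)
    Nx.multiply(Nx.multiply(x, Nx.rsqrt(Nx.add(ms, eps))), w)

    # EMLX.Fast.swiglu(gate, up) — silu(gate) * up
    Nx.multiply(Nx.multiply(gate, Nx.sigmoid(gate)), up)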
"""
# matches on Bumblebee 0.7+
@bumblebee_rope_mfa {Bumblebee.Layers, :apply_rotary_embedding, 5}
@bumblebee_attn_mfa {Bumblebee.Layers, :attention_output_impl, 3}
@bumblebee_repeat_interleave_mfa {Bumblebee.Layers, :"-repeat_interleave/3-fun-0-", 2}
@bumblebee_update_attn_cache_mfa {Bumblebee.Layers.Decoder, :update_attention_cache, 5}
@bumblebee_put_block_cache_mfa {Bumblebee.Layers.Decoder, :"-put_block_cache/3-fun-0-", 3}
@kv_cache_proc_key :"$emlx_axon_native_attention_kv_cache"
# Memoizes Nx.to_number(offset_tensor) across all layers of a single forward pass.
# Stores {MapSet.t(layer_key), cached_integer_offset}.
@step_offset_proc_key :"$emlx_axon_step_offset_cache"
@default_rewrites [
:rms_norm,
:layer_norm,
:rotary_embedding,
:sdpa,
:dropout,
:swiglu,
:attn_weights,
:if_present,
:native_attention,
:nullify_block_cache
]
@doc """
Rewrites all supported nodes in `model` to their `EMLX.Fast` equivalents.
## Options
* `:only` — list of atoms selecting which rewrites to apply. Defaults to
`[:rms_norm, :layer_norm, :rotary_embedding, :sdpa, :dropout, :swiglu,
:attn_weights, :if_present, :native_attention, :nullify_block_cache]`.
Pass `:gqa_cache_fix` explicitly when targeting a Bumblebee build whose
`init_cache` allocates the KV cache with `num_key_value_heads` rather than
`num_attention_heads` (i.e. the upstream PR branch patch).
## Example
model = EMLXAxon.rewrite(model)
model = EMLXAxon.rewrite(model, only: [:rms_norm, :layer_norm])
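
    # Opt in to the GQA cache fix on top of selected rewrites:
    model = EMLXAxon.rewrite(model, only: [:rms_norm, :layer_norm, :sdpa, :gqa_cache_fix])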
"""
@spec rewrite(Axon.t(), keyword()) :: Axon.t()
def rewrite(%Axon{} = model, opts \\ []) do
opts = Keyword.validate!(opts, only: @default_rewrites)
cache = :ets.new(:emlx_axon_rewrite_cache, [:set, :public])
try do
enabled = expand_enabled(opts[:only])
rewriters =
[]
|> maybe_add(:rms_norm, rms_norm_rewriter(), enabled)
|> maybe_add(:layer_norm, layer_norm_rewriter(), enabled)
|> maybe_add(:rotary_embedding, rotary_embedding_rewriter(cache), enabled)
|> maybe_add(:sdpa, sdpa_rewriter(), enabled)
|> maybe_add(:dropout, dropout_rewriter(), enabled)
|> maybe_add(:swiglu, swiglu_rewriter(), enabled)
|> maybe_add(:attn_weights, attn_weights_rewriter(), enabled)
|> maybe_add(:if_present, if_present_rewriter(), enabled)
|> maybe_add(:gqa_cache_fix, gqa_cache_fix_rewriter(), enabled)
|> maybe_add(:native_attention, native_attention_rewriter(), enabled)
|> maybe_add(:nullify_block_cache, nullify_block_cache_rewriter(), enabled)
Axon.rewrite_nodes(model, fn node ->
Enum.find_value(rewriters, :skip, fn {_key, fun} ->
case fun.(node) do
:skip -> nil
rewriter -> rewriter
end
end)
end)
after
:ets.delete(cache)
end
end
# ── load_quantized ───────────────────────────────────────────────────────────
@doc """
Loads a quantized Bumblebee model from an MLX-4bit checkpoint directory.
Combines three steps into one call:
1. Loads the Axon model structure from `config.json` via `Bumblebee.load_model/2`.
2. Loads the MLX-4bit safetensors weights via `EMLXAxon.MLX4BitParams.load/1`,
dequantizing and transposing to Bumblebee `{in, out}` layout (BF16).
3. Re-quantizes all eligible weight matrices via `EMLXAxon.QuantizeParams.quantize/1`
so that `Nx.dot` dispatch routes to `EMLX.quantized_matmul` at serving time.
Returns `{:ok, model_info}` compatible with `Bumblebee.Text.generation/4`.
## Usage
path = "~/models/Qwen3-0.6B-MLX-4bit"
{:ok, model_info} = EMLXAxon.load_quantized({:local, path})
{:ok, tokenizer} = Bumblebee.load_tokenizer({:local, path})
{:ok, gen_cfg} = Bumblebee.load_generation_config({:local, path})
gen_cfg = Bumblebee.configure(gen_cfg, max_new_tokens: 100)
serving = Bumblebee.Text.generation(model_info, tokenizer, gen_cfg,
compile: [batch_size: 1, sequence_length: 256],
defn_options: [compiler: EMLX]
)
result = Nx.Serving.run(serving, "The capital of France is")
## Notes
- **Do not apply `EMLXAxon.rewrite/2` after `load_quantized`** — the rotary
embedding rewrite is incompatible with the standard Bumblebee `native_kv_cache: false`
path and produces incorrect outputs. BF16 fast ops (rms_norm, swiglu, dropout, sdpa)
may be added once the rotary embedding rewrite is fixed.
- Model architecture is inferred from `config.json` in the checkpoint directory.
Validated with Bumblebee and Qwen3-0.6B.
- Quantization metadata: QuantizeParams logs shape-mismatch warnings for tensors whose
physical packed dimensions differ from the Bumblebee model's expected shapes. These
warnings are benign — the quantized tensors are still used correctly via the EMLX
backend's quantized_matmul dispatch.
"""
@spec load_quantized({:local, Path.t()}, keyword()) :: {:ok, map()} | {:error, term()}
def load_quantized(source, opts \\ [])
def load_quantized({:local, path}, opts) do
path = Path.expand(path)
load_model_opts = Keyword.merge([backend: {EMLX.Backend, device: :gpu}, type: :bf16], opts)
with {:ok, model_info} <- Bumblebee.load_model({:local, path}, load_model_opts) do
params = EMLXAxon.MLX4BitParams.load(path)
bb_keys = model_info.params.data |> Map.keys() |> MapSet.new()
mlx_keys = params.data |> Map.keys() |> MapSet.new()
missing = MapSet.difference(bb_keys, mlx_keys)
extra = MapSet.difference(mlx_keys, bb_keys)
if MapSet.size(missing) > 0 or MapSet.size(extra) > 0 do
require Logger
Logger.warning(
"EMLXAxon.load_quantized: param key mismatch — " <>
"#{MapSet.size(missing)} missing from checkpoint, " <>
"#{MapSet.size(extra)} extra in checkpoint. " <>
"Missing: #{inspect(Enum.take(Enum.sort(missing), 5))}. " <>
"This may indicate an unsupported model or mismatched checkpoint."
)
end
quant_params = EMLXAxon.QuantizeParams.quantize(params)
patched = %{model_info.params | data: quant_params.data}
{:ok, %{model_info | params: patched}}
end
end
# ── rms_norm ────────────────────────────────────────────────────────────────
@doc """
Returns the rewriter function for `rms_norm` nodes.
Replaces `op_name: :rms_norm` nodes whose `shift` is `0.0` with an Axon layer
that calls `EMLX.Fast.rms_norm/3` — a single fused Metal shader.
"""
@spec rms_norm_rewriter() ::
(Axon.Node.t() -> ([Axon.t()], Axon.t() -> Axon.t()) | :skip)
def rms_norm_rewriter do
fn
%Axon.Node{op_name: :rms_norm, opts: node_opts, name: name_fn} ->
eps = Keyword.get(node_opts, :epsilon, 1.0e-6)
shift = Keyword.get(node_opts, :shift, 0.0)
if shift == 0.0 do
fn [x], _placeholder ->
# Recreate the weight parameter with the same name and shape as the
# original rms_norm weight so model_state keys match after loading.
# Bumblebee always uses channel_index: -1 (last axis) for rms_norm.
weight =
Axon.param(
"weight",
fn input_shape ->
{elem(input_shape, Nx.rank(input_shape) - 1)}
end,
initializer: :ones
)
Axon.layer(
fn x, w, op_opts ->
EMLX.Fast.rms_norm(x, w, op_opts[:epsilon])
end,
[x, weight],
name: name_fn,
op_name: :fast_rms_norm,
epsilon: eps
)
end
else
:skip
end
_ ->
:skip
end
end
# ── layer_norm ───────────────────────────────────────────────────────────────
@doc """
Returns the rewriter function for `layer_norm` nodes.
Replaces `op_name: :layer_norm` nodes (Axon's built-in layer normalization)
with an Axon layer that calls `EMLX.Fast.layer_norm/3,4` — a single fused
Metal shader. Skips nodes where `channel_index` is not `-1` (last axis),
as the kernel only normalizes over the last axis.
"""
@spec layer_norm_rewriter() ::
(Axon.Node.t() -> ([Axon.t()], Axon.t() -> Axon.t()) | :skip)
def layer_norm_rewriter do
fn
%Axon.Node{op_name: :layer_norm, opts: node_opts, name: name_fn} ->
channel_index = Keyword.get(node_opts, :channel_index, -1)
eps = Keyword.get(node_opts, :epsilon, 1.0e-5)
if channel_index == -1 do
fn inputs, _placeholder ->
case inputs do
[x, gamma, beta] ->
# With bias (gamma + beta as parents)
Axon.layer(
fn x, gamma, beta, op_opts ->
EMLX.Fast.layer_norm(x, gamma, beta, op_opts[:epsilon])
end,
[x, gamma, beta],
name: name_fn,
op_name: :fast_layer_norm,
epsilon: eps
)
[x, gamma] ->
# No bias (weight only)
Axon.layer(
fn x, gamma, op_opts ->
EMLX.Fast.layer_norm(x, gamma, op_opts[:epsilon])
end,
[x, gamma],
name: name_fn,
op_name: :fast_layer_norm,
epsilon: eps
)
end
end
else
:skip
end
_ ->
:skip
end
end
# ── rotary_embedding ─────────────────────────────────────────────────────────
@doc """
Returns the rewriter function for Bumblebee's `rotary_embedding` nodes.
Matches `%Axon.Node{op: &Bumblebee.Layers.apply_rotary_embedding/5}` by MFA
identity via `function_info/1`, then replaces it with an `EMLX.Fast.rope_with_positions/6`
call on both Q and K. The replacement node returns `{q_rotated, k_rotated}` —
downstream `Axon.nx(_, &elem(&1, i))` unwrap nodes continue to work unchanged.
**Assumes sequential positions** — see `EMLXAxon` moduledoc for the limitation.
"""
@spec rotary_embedding_rewriter(reference() | nil) ::
(Axon.Node.t() -> ([Axon.t()], Axon.t() -> Axon.t()) | :skip)
def rotary_embedding_rewriter(cache \\ nil) do
fn %Axon.Node{op: op, opts: node_opts, name: name_fn} ->
if function_info(op) == @bumblebee_rope_mfa do
size = Keyword.get(node_opts, :size, nil)
base = Keyword.get(node_opts, :base, 10_000)
strategy = Keyword.get(node_opts, :scaling_strategy)
max_pos = Keyword.get(node_opts, :max_positions, 2048)
# Precompute inv-frequency tensor once per unique hyperparameter combo.
# Returns nil for unsupported strategies → falls back to rope_with_positions.
freqs =
if strategy do
cached(cache, {:rope_freqs, strategy, size, base, max_pos}, fn ->
precompute_rope_freqs(strategy, size, base, max_pos)
end)
end
fn [q_axon, k_axon, pos_axon, _mask_axon], _placeholder ->
Axon.layer(
fn q, k, pos, _op_opts ->
# Bumblebee uses rotate_half → traditional: false.
# freqs and base/size are captured from rewrite-time scope.
if freqs do
q_rot = EMLX.Fast.rope_with_freqs(q, pos, size, false, 1.0, freqs)
k_rot = EMLX.Fast.rope_with_freqs(k, pos, size, false, 1.0, freqs)
{q_rot, k_rot}
else
q_rot = EMLX.Fast.rope_with_positions(q, pos, size, false, base, 1.0)
k_rot = EMLX.Fast.rope_with_positions(k, pos, size, false, base, 1.0)
{q_rot, k_rot}
end
end,
[q_axon, k_axon, pos_axon],
name: name_fn,
op_name: :fast_rotary_embedding,
dims: size,
base: base
)
end
else
:skip
end
end
end
# ── dropout ──────────────────────────────────────────────────────────────────
@doc """
Returns the rewriter function for `dropout` nodes.
At inference time, dropout is always a pass-through regardless of rate. This
rewriter replaces every `:dropout` node with an identity layer, eliminating the
NIF-boundary crossing without any functional change.
**Not appropriate for training graphs** — only enable this rewriter when the
model will be used for inference only.
"""
@spec dropout_rewriter() ::
(Axon.Node.t() -> ([Axon.t()], Axon.t() -> Axon.t()) | :skip)
def dropout_rewriter do
fn
%Axon.Node{op_name: :dropout, name: name_fn} ->
fn [x], _placeholder ->
Axon.layer(fn x, _opts -> x end, [x], name: name_fn, op_name: :dropout_identity)
end
_ ->
:skip
end
end
# ── Attention weights elision ─────────────────────────────────────────────
@doc """
Returns the rewriter function for Bumblebee aux attention-weights nodes.
Matches `:bb_attn_weights` nodes — a passthrough layer inserted by the local
Bumblebee patch around the `{output, weights}` return of `Layers.attention/8`.
Replaces the node with a no-arg constant-zero layer so the entire
`attention_weights_impl` sub-graph (and the K-side `repeat_interleave` it
consumes) becomes unreachable. Inference-only: attention weight tensors are
never used for token generation.
"""
@spec attn_weights_rewriter() ::
(Axon.Node.t() -> ([Axon.t(), ...], Axon.t() -> Axon.t()) | :skip)
def attn_weights_rewriter do
fn
%Axon.Node{op_name: :bb_attn_weights} ->
fn [weights_axon], _placeholder ->
# Build a constant-zero node and merge its nodes map with the existing
# graph so rewrite_nodes can locate original_id during the ID swap.
# The constant has no parents, so attention_weights_impl and K's
# repeat_interleave become orphaned and are pruned by Axon automatically.
%Axon{output: const_id, nodes: const_nodes} = Axon.constant(Nx.tensor(0.0))
merged = Map.merge(weights_axon.nodes, const_nodes)
%Axon{output: const_id, nodes: merged}
end
_ ->
:skip
end
end
# ── if_present elision ───────────────────────────────────────────────────────
@doc """
Returns the rewriter function for `:if_present` nodes.
Bumblebee wraps every KV-cache operation in `Layers.if_present(cache, ...)` to
handle the case where no cache is provided. In compiled serving the cache is
always initialized (never `%Axon.None{}`), so the conditional is dead code.
The rewriter unconditionally selects the "cache present" branch (`on_true`) and
lets the "no cache" branch (`on_false`) and all its `:optional` wrappers become
unreachable, pruning them from the compiled graph.
**Do not enable for training graphs** — training models typically run without a
KV cache and rely on the `else` branch.
"""
@spec if_present_rewriter() ::
(Axon.Node.t() -> ([Axon.t()], Axon.t() -> Axon.t()) | :skip)
def if_present_rewriter do
fn
# 3-parent :if_present: [optional(condition), optional(on_true), optional(on_false)]
%Axon.Node{op_name: :if_present, parent: [_, _, _]} ->
fn [_cond_opt, on_true_opt, _on_false_opt], _placeholder ->
# Skip the :optional wrapper to return the underlying on_true node.
# The cache is always non-None in compiled serving.
optional_node = on_true_opt.nodes[on_true_opt.output]
on_true_id =
case {optional_node.op_name, optional_node.parent} do
{:optional, [id]} -> id
# If the node isn't an :optional for some reason, use it as-is.
_ -> on_true_opt.output
end
%Axon{output: on_true_id, nodes: on_true_opt.nodes}
end
_ ->
:skip
end
end
# ── GQA cache fix ────────────────────────────────────────────────────────────
@doc """
Returns the rewriter function for GQA key/value cache shape fix.
In the standard Bumblebee transformer block, GQA head expansion via
`repeat_interleave` (expanding from `num_key_value_heads` to `num_attention_heads`)
is applied to the key and value tensors *before* `update_attention_cache`. However,
`init_cache` allocates the preallocated buffer with `num_key_value_heads`, causing
a shape mismatch at Axon compile time when `num_key_value_heads < num_attention_heads`.
This rewriter fixes the graph by stripping the `repeat_interleave` node from the
key and value inputs of every `update_attention_cache` layer. The cache update then
operates on the compact GQA tensors. The SDPA rewriter (`maybe_strip_repeat_interleave`)
separately handles the expanded-head removal on the SDPA path, and MLX fast SDPA
handles GQA natively without explicit head repetition.
Only applies when the key or value parent is a Bumblebee `repeat_interleave` node.
Models without GQA (or where repeat_interleave is already absent) are unaffected.
"""
@spec gqa_cache_fix_rewriter() ::
(Axon.Node.t() -> ([Axon.t()], Axon.t() -> Axon.t()) | :skip)
def gqa_cache_fix_rewriter do
fn
%Axon.Node{op: op, parent: [_key_id, _value_id | _]} ->
if function_info(op) == @bumblebee_update_attn_cache_mfa do
fn [key_axon, value_axon, cache_axon, offset_axon], _placeholder ->
# Strip GQA head expansion from K and V so the cache update receives
# {B, 1, Hkv, D} tensors matching the preallocated cache shape.
key_raw = maybe_strip_repeat_interleave(key_axon.nodes, key_axon.output)
val_raw = maybe_strip_repeat_interleave(value_axon.nodes, value_axon.output)
Axon.layer(op, [key_raw, val_raw, cache_axon, offset_axon])
end
else
:skip
end
_ ->
:skip
end
end
# ── SwiGLU ──────────────────────────────────────────────────────────────────
@doc """
Returns the rewriter function for SwiGLU nodes.
Matches `:multiply` nodes backed by a single `:container` parent whose two
children include one `:silu` node (the Bumblebee SwiGLU pattern:
`multiply(container(up_proj, silu(gate_proj)))`). Replaces the
multiply + container + silu triple with a single `EMLX.Fast.swiglu/2` call,
passing the gate's raw input (pre-silu) and the up-projection directly to the
fused NIF.
Generic `:multiply` nodes (no `:container` parent, or container without a
`:silu` child) are reconstructed identically.
"""
@spec swiglu_rewriter() ::
(Axon.Node.t() -> ([Axon.t()], Axon.t() -> Axon.t()) | :skip)
def swiglu_rewriter do
fn
# Bumblebee's SwiGLU: multiply(container(up_proj, silu(gate))).
# The :multiply node has ONE parent (the container), and the container
# has TWO parents: one :silu (the gate activation) and one other node
# (the up-projection).
%Axon.Node{op_name: :multiply, parent: [container_id], name: name_fn} ->
fn [container_axon], _placeholder ->
node_map = container_axon.nodes
case detect_swiglu_pattern(node_map, container_id) do
{gate_id, up_id} ->
gate_axon = %Axon{output: gate_id, nodes: node_map}
up_axon = %Axon{output: up_id, nodes: node_map}
Axon.layer(
fn gate, up, _opts -> EMLX.Fast.swiglu(gate, up) end,
[gate_axon, up_axon],
name: name_fn,
op_name: :fast_swiglu
)
:skip ->
# Not a SwiGLU container — reconstruct the original multiply.
Axon.layer(
fn container, _opts -> Nx.multiply(elem(container, 0), elem(container, 1)) end,
[container_axon],
name: name_fn,
op_name: :multiply
)
end
end
_ ->
:skip
end
end
# Checks if `container_id` is a :container node with exactly two parents, one
# of which is a :silu node with a single parent (the gate input).
# Returns {gate_id, up_id} on match, or :skip otherwise.
defp detect_swiglu_pattern(nodes, container_id) do
container = nodes[container_id]
with %Axon.Node{op_name: :container, parent: [a_id, b_id]} <- container do
cond do
nodes[a_id].op_name == :silu and length(nodes[a_id].parent) == 1 ->
# a is silu(gate), b is up_proj
{hd(nodes[a_id].parent), b_id}
nodes[b_id].op_name == :silu and length(nodes[b_id].parent) == 1 ->
# b is silu(gate), a is up_proj
{hd(nodes[b_id].parent), a_id}
true ->
:skip
end
else
_ -> :skip
end
end
# ── SDPA ─────────────────────────────────────────────────────────────────────
@doc """
Returns the rewriter function for Bumblebee's attention output nodes.
Matches `%Axon.Node{op: &Bumblebee.Layers.attention_output_impl/3}` by MFA
identity via `function_info/1`, then navigates up through the dropout and
`attention_weights_impl` nodes to recover Q and K, and replaces the whole
attention chain with a single `EMLX.Fast` SDPA call.
- **`causal: true`, no window_size** — uses `scaled_dot_product_attention_causal_key_masked/5`.
The `key_mask` is threaded through; the C++ NIF checks if it is all-ones at
runtime and dispatches to the pure causal Metal kernel (no mask allocation) or
builds a combined additive mask for padded batches.
- **`causal: false`, no window_size** — uses `scaled_dot_product_attention/4`
(unmasked, for cross-attention or prefix LM heads).
- **`window_size` set** — re-applies the original `attention_output_impl` unchanged.
**Inference-only**: attention dropout is elided (a no-op at inference time). Nodes
with `dropout_rate > 0` are skipped to preserve training-time stochastic behavior.
"""
@spec sdpa_rewriter() ::
(Axon.Node.t() -> ([Axon.t()], Axon.t() -> Axon.t()) | :skip)
def sdpa_rewriter do
fn %Axon.Node{op: op, opts: node_opts} ->
dropout_rate = Keyword.get(node_opts, :dropout_rate, 0.0)
if function_info(op) == @bumblebee_attn_mfa and dropout_rate == 0.0 do
original_op = op
fn [weights_dropped_axon, v_axon], _placeholder ->
nodes = weights_dropped_axon.nodes
weights_dropped_id = weights_dropped_axon.output
# Navigate: weights_dropped (dropout) → attention_weights_impl
dropout_node = nodes[weights_dropped_id]
[attn_weights_id] = dropout_node.parent
attn_weights_node = nodes[attn_weights_id]
causal = Keyword.get(attn_weights_node.opts, :causal, false)
window_size = Keyword.get(attn_weights_node.opts, :window_size)
scale_opt = Keyword.get(attn_weights_node.opts, :scale)
# parents: [q_id, k_id, key_mask_id, head_mask_id, bias_id, offset_id]
[q_id, k_id, key_mask_id | _] = attn_weights_node.parent
cond do
not is_nil(window_size) ->
# Sliding-window attention: fall back to original.
Axon.layer(original_op, [weights_dropped_axon, v_axon])
causal ->
# Causal SDPA. The key_mask check runs at C++ level: if all-ones
# (no padding — the common single-sequence case), dispatches to the
# pure causal Metal kernel. Otherwise builds a combined additive
# mask. No Nx.cond double-evaluation required.
q_axon = %Axon{output: q_id, nodes: nodes}
# Strip repeat_interleave from K and V if present — after the local
# Bumblebee GQA cache patch, expand runs after cache write, so
# mlx::fast::sdpa receives raw 8-head K/V and handles GQA natively.
k_axon = maybe_strip_repeat_interleave(nodes, k_id)
v_axon_8 = maybe_strip_repeat_interleave(nodes, v_axon.output)
if System.get_env("NATIVE_ATTN_DEBUG") in ["1", "2"] do
km_node = nodes[key_mask_id]
km_inner = maybe_unwrap_optional(nodes, key_mask_id)
if km_node do
IO.puts(
"[sdpa_build] key_mask_id=#{inspect(key_mask_id)} op_name=#{km_node.op_name} inner_id=#{inspect(km_inner)}"
)
end
end
key_mask_axon = %Axon{output: key_mask_id, nodes: nodes}
build_sdpa_layer(q_axon, k_axon, v_axon_8, key_mask_axon, scale_opt)
true ->
# Non-causal, no mask (e.g. cross-attention in encoder-decoder models).
q_axon = %Axon{output: q_id, nodes: nodes}
k_axon = %Axon{output: k_id, nodes: nodes}
build_sdpa_layer(q_axon, k_axon, v_axon, scale_opt, :none)
end
end
else
:skip
end
end
end
# If `node_id` is a `repeat_interleave` custom node (GQA head expand), return
# an Axon wrapping its single parent (the compact `num_key_value_heads` tensor).
# Otherwise return the node unchanged. Safe to call when the patch is absent —
# the else branch is a no-op.
defp maybe_strip_repeat_interleave(nodes, node_id) do
node = nodes[node_id]
if node.op_name == :custom and
function_info(node.op) == @bumblebee_repeat_interleave_mfa and
length(node.parent) == 1 do
%Axon{output: hd(node.parent), nodes: nodes}
else
%Axon{output: node_id, nodes: nodes}
end
end
# Causal SDPA with key_mask: delegates all-ones check to the C++ NIF.
defp build_sdpa_layer(q_axon, k_axon, v_axon, key_mask_axon, scale_opt)
when is_struct(key_mask_axon, Axon) do
Axon.layer(
fn q, k, v, key_mask, op_opts ->
# Q, K, V arrive in {B, T, N, D}. SDPA expects {B, N, T, D}.
q_t = Nx.transpose(q, axes: [0, 2, 1, 3])
k_t = Nx.transpose(k, axes: [0, 2, 1, 3])
v_t = Nx.transpose(v, axes: [0, 2, 1, 3])
head_dim = elem(Nx.shape(q_t), 3)
scale = op_opts[:scale] || 1.0 / :math.sqrt(head_dim)
out =
EMLX.Fast.scaled_dot_product_attention_causal_key_masked(q_t, k_t, v_t, scale, key_mask)
Nx.transpose(out, axes: [0, 2, 1, 3])
end,
[q_axon, k_axon, v_axon, key_mask_axon],
op_name: :fast_sdpa,
scale: if(is_number(scale_opt), do: scale_opt, else: nil)
)
end
# Non-causal SDPA (no mask).
defp build_sdpa_layer(q_axon, k_axon, v_axon, scale_opt, :none) do
Axon.layer(
fn q, k, v, op_opts ->
q = Nx.transpose(q, axes: [0, 2, 1, 3])
k = Nx.transpose(k, axes: [0, 2, 1, 3])
v = Nx.transpose(v, axes: [0, 2, 1, 3])
head_dim = elem(Nx.shape(q), 3)
scale = op_opts[:scale] || 1.0 / :math.sqrt(head_dim)
out = EMLX.Fast.scaled_dot_product_attention(q, k, v, scale)
Nx.transpose(out, axes: [0, 2, 1, 3])
end,
[q_axon, k_axon, v_axon],
op_name: :fast_sdpa,
scale: if(is_number(scale_opt), do: scale_opt, else: nil)
)
end
# ── Native KV attention ─────────────────────────────────────────────────────
@doc """
Returns the rewriter function for Bumblebee causal self-attention nodes.
This rewrite replaces the attention output with a single `Nx.runtime_call`
callback that updates a per-layer K/V cache held in the process dictionary and
calls `EMLX.kv_cache_attention_masked/8`. It intentionally matches only causal
attention without sliding-window masking; cross-attention and local attention
fall back to the original graph.
"""
@spec native_attention_rewriter() ::
(Axon.Node.t() -> ([Axon.t()], Axon.t() -> Axon.t()) | :skip)
def native_attention_rewriter do
fn %Axon.Node{op: op, opts: node_opts} ->
dropout_rate = Keyword.get(node_opts, :dropout_rate, 0.0)
if function_info(op) == @bumblebee_attn_mfa and dropout_rate == 0.0 do
original_op = op
fn [weights_dropped_axon, v_axon], _placeholder ->
nodes = weights_dropped_axon.nodes
weights_dropped_id = weights_dropped_axon.output
dropout_node = nodes[weights_dropped_id]
[attn_weights_id] = dropout_node.parent
attn_weights_node = nodes[attn_weights_id]
causal = Keyword.get(attn_weights_node.opts, :causal, false)
window_size = Keyword.get(attn_weights_node.opts, :window_size)
scale_opt = Keyword.get(attn_weights_node.opts, :scale)
# parents: [q_id, k_id, key_mask_id, head_mask_id, bias_id, offset_id]
[q_id, k_id, key_mask_id, _head_mask_id, _bias_id, _offset_id] =
attn_weights_node.parent
with true <- causal,
true <- is_nil(window_size),
{:ok, k_key_id, _k_value_id, _k_cache_id, k_offset_id} <-
find_update_attention_cache(nodes, k_id),
{:ok, _v_key_id, v_value_id, _v_cache_id, v_offset_id} <-
find_update_attention_cache(nodes, v_axon.output),
true <- k_offset_id == v_offset_id do
q_axon = %Axon{output: q_id, nodes: nodes}
new_k_axon = maybe_strip_repeat_interleave(nodes, k_key_id)
new_v_axon = maybe_strip_repeat_interleave(nodes, v_value_id)
offset_axon = %Axon{output: k_offset_id, nodes: nodes}
key_mask_axon = %Axon{output: key_mask_id, nodes: nodes}
build_native_attention_layer(
q_axon,
new_k_axon,
new_v_axon,
offset_axon,
key_mask_axon,
scale_opt
)
else
reason ->
IO.puts(
"[native_attention_rewriter] FALLTHROUGH causal=#{causal} window=#{inspect(window_size)} reason=#{inspect(reason)}"
)
Axon.layer(original_op, [weights_dropped_axon, v_axon])
end
end
else
:skip
end
end
end
defp build_native_attention_layer(
q_axon,
new_k_axon,
new_v_axon,
offset_axon,
key_mask_axon,
scale_opt
) do
layer_key = make_ref()
Axon.layer(
fn q, new_k, new_v, offset, key_mask, op_opts ->
out = Nx.template(Nx.shape(q), Nx.type(q))
head_dim = elem(Nx.shape(q), 3)
scale = op_opts[:scale] || 1.0 / :math.sqrt(head_dim)
# Both prefill and decode are handled entirely inside the callback:
# - Prefill (t_new > 1): callback computes SDPA eagerly, stores K/V in ETS.
# - Decode (t_new == 1): callback reads ETS cache, calls kv_cache_attention_masked.
Nx.runtime_call(
out,
{q, new_k, new_v, offset, key_mask},
[layer_key: op_opts[:layer_key], scale: scale],
&__MODULE__.native_kv_attn_callback/2
)
end,
[q_axon, new_k_axon, new_v_axon, offset_axon, key_mask_axon],
op_name: :native_kv_attention,
layer_key: layer_key,
scale: if(is_number(scale_opt), do: scale_opt, else: nil)
)
end
@doc false
def native_kv_attn_callback(
{query, new_k, new_v, offset_tensor, key_mask},
opts
) do
layer_key = Keyword.fetch!(opts, :layer_key)
t_new = elem(Nx.shape(new_k), 1)
if t_new > 1 do
# Prefill path: always read the actual offset tensor.
#
# The step-offset cache may hold a stale value from a previous serving's
# last decode step when this layer_key is brand new (never seen before).
# For prefill the optimization is irrelevant (only 1 GPU→CPU sync per
# layer anyway since it's a single step), so bypass get_step_offset.
#
# Also accumulate this layer_key into the seen set so decode step 1 finds
# it already "seen" and issues a fresh read instead of using the (now
# correct) prefill cached_offset at the wrong decode position.
offset = Nx.to_number(offset_tensor)
register_prefill_layer(layer_key, offset)
if offset == 0 do
native_kv_prefill(query, new_k, new_v, key_mask, layer_key, opts)
else
native_kv_decode(query, new_k, new_v, offset, key_mask, layer_key, opts)
end
else
offset = get_step_offset(offset_tensor, layer_key)
native_kv_decode(query, new_k, new_v, offset, key_mask, layer_key, opts)
end
end
# Registers a layer_key seen during the prefill step.
# Accumulates all prefill keys into the seen set so that decode step 1
# will find each layer_key already "seen" and issue a fresh step-boundary read
# at the correct decode offset rather than reusing the prefill offset (0).
defp register_prefill_layer(layer_key, offset) do
new_state =
case Process.get(@step_offset_proc_key) do
nil -> {MapSet.new([layer_key]), offset}
{seen, _prev} -> {MapSet.put(seen, layer_key), offset}
end
Process.put(@step_offset_proc_key, new_state)
end
defp native_kv_prefill(query, new_k, new_v, key_mask, layer_key, opts) do
t_new = elem(Nx.shape(new_k), 1)
scale = Keyword.fetch!(opts, :scale)
# Bumblebee compiles with max_length = seq_length + max_new_tokens, so
# key_mask is {B, max_length} while new_k is {B, seq_length, N_kv, D}.
# Pad new_k/new_v with zeros to max_length before storing so the
# decode path can retrieve a buffer of the expected size.
max_len =
case Nx.shape(key_mask) do
{_, max} -> max
{_, _, _, max} -> max
end
pad_len = max_len - t_new
{b, _, nkv, d} = Nx.shape(new_k)
type = Nx.type(new_k)
{k_full, v_full} =
if pad_len > 0 do
zeros_k = Nx.broadcast(Nx.tensor(0, type: type), {b, pad_len, nkv, d})
zeros_v = Nx.broadcast(Nx.tensor(0, type: type), {b, pad_len, nkv, d})
{Nx.concatenate([new_k, zeros_k], axis: 1), Nx.concatenate([new_v, zeros_v], axis: 1)}
else
{new_k, new_v}
end
# Store raw {dev, ref} tuples — no ETS, no to_nx overhead for k/v.
cache_map = Process.get(@kv_cache_proc_key, %{})
Process.put(
@kv_cache_proc_key,
Map.put(cache_map, layer_key, {EMLX.Backend.from_nx(k_full), EMLX.Backend.from_nx(v_full)})
)
# Compute prefill SDPA eagerly here in the callback (avoids splitting the
# computation across the Axon layer boundary and the Nx.add(sdpa, zeros) trick).
# Use k_full/v_full (padded to max_length) and the FULL key_mask rather than
# slicing both to t_new. This exactly matches the default sdpa rewriter path
# (T_kv = max_length), ensuring identical NaN-propagation behavior on Metal
# for left-padded input sequences.
# Q/K/V: {B, T, N, D} → transpose to {B, N, T, D} for the NIF.
q_t = Nx.transpose(query, axes: [0, 2, 1, 3])
k_t = Nx.transpose(k_full, axes: [0, 2, 1, 3])
v_t = Nx.transpose(v_full, axes: [0, 2, 1, 3])
# kv_offset = 0 for prefill (lower-triangular causal mask from position 0).
sdpa_t =
EMLX.fast_sdpa_causal_key_masked(
EMLX.Backend.from_nx(q_t),
EMLX.Backend.from_nx(k_t),
EMLX.Backend.from_nx(v_t),
scale,
EMLX.Backend.from_nx(key_mask),
0
)
|> EMLX.Backend.to_nx()
# Transpose back to {B, T, N, D} and match query dtype.
out = Nx.transpose(sdpa_t, axes: [0, 2, 1, 3])
if Nx.type(out) == Nx.type(query), do: out, else: Nx.as_type(out, Nx.type(query))
end
defp native_kv_decode(query, new_k, new_v, offset, key_mask, layer_key, opts) do
t_new = elem(Nx.shape(new_k), 1)
valid_len = offset + t_new
{batch_size, max_length, mask_axis} =
case Nx.shape(key_mask) do
{batch_size, max_length} -> {batch_size, max_length, 1}
{batch_size, _heads, _query_len, max_length} -> {batch_size, max_length, 3}
end
{_, _, kv_heads, head_dim} = Nx.shape(new_k)
full_shape = {batch_size, max_length, kv_heads, head_dim}
type = Nx.type(new_k)
scale = Keyword.fetch!(opts, :scale)
# Get cached refs directly from process dict — no ETS lookup, no term copying.
cache_map = Process.get(@kv_cache_proc_key, %{})
{k_cache_ref, v_cache_ref} =
case Map.get(cache_map, layer_key) do
nil ->
# No prefill ran — initialize zero-filled cache buffer.
zeros = Nx.broadcast(Nx.tensor(0, type: type), full_shape)
{EMLX.Backend.from_nx(zeros), EMLX.Backend.from_nx(zeros)}
refs ->
refs
end
key_mask_sliced = Nx.slice_along_axis(key_mask, 0, valid_len, axis: mask_axis)
{attn_ref, k_upd_ref, v_upd_ref} =
EMLX.kv_cache_attention_masked(
EMLX.Backend.from_nx(query),
EMLX.Backend.from_nx(new_k),
EMLX.Backend.from_nx(new_v),
k_cache_ref,
v_cache_ref,
offset,
scale,
EMLX.Backend.from_nx(key_mask_sliced)
)
# Store raw refs directly — no to_nx for k/v, no ETS insert.
Process.put(@kv_cache_proc_key, Map.put(cache_map, layer_key, {k_upd_ref, v_upd_ref}))
attn_out = EMLX.Backend.to_nx(attn_ref)
if Nx.type(attn_out) == Nx.type(query) do
attn_out
else
Nx.as_type(attn_out, Nx.type(query))
end
end
defp find_update_attention_cache(nodes, node_id, seen \\ MapSet.new()) do
cond do
MapSet.member?(seen, node_id) ->
:error
node = nodes[node_id] ->
if function_info(node.op) == @bumblebee_update_attn_cache_mfa do
[key_id, value_id, cache_id, offset_id] = node.parent
{:ok, key_id, value_id, cache_id, offset_id}
else
seen = MapSet.put(seen, node_id)
# For :if_present nodes, only follow parent[1] (the on_true / cache-present branch).
# parent[0] is optional(condition) which chains through put_block_cache to OTHER
# layers' UAC nodes. parent[2] is the fallback (no cache). Only parent[1] leads
# to the current layer's own UAC.
parents_to_search =
if node.op_name == :if_present do
case node.parent do
[_cond, on_true, _on_false] -> [on_true]
_ -> node.parent
end
else
node.parent
end
Enum.find_value(parents_to_search, :error, fn parent_id ->
case find_update_attention_cache(nodes, parent_id, seen) do
{:ok, _key_id, _value_id, _cache_id, _offset_id} = found -> found
:error -> nil
end
end)
end
true ->
:error
end
end
defp maybe_unwrap_optional(nodes, node_id) do
case nodes[node_id] do
%Axon.Node{op_name: :optional, parent: [inner_id]} -> inner_id
_ -> node_id
end
end
@doc """
Returns the rewriter function for Bumblebee block-cache update nodes.
When native attention owns K/V state in the process dictionary, the Axon
block-cache update chain is dead. Replacing `put_block_cache` with an identity
lets DCE prune `get_block_cache`, `update_attention_cache`, and the container
plumbing.
"""
@spec nullify_block_cache_rewriter() ::
(Axon.Node.t() -> ([Axon.t()], Axon.t() -> Axon.t()) | :skip)
def nullify_block_cache_rewriter do
fn %Axon.Node{op: op} ->
if function_info(op) == @bumblebee_put_block_cache_mfa do
fn [cache_axon, _block_cache_axon], _placeholder -> cache_axon end
else
:skip
end
end
end
# Memoizes Nx.to_number(offset_tensor) within a single forward pass (decode step).
#
# All layer callbacks within one forward pass share the same offset value. Calling
# Nx.to_number once per layer forces one GPU→CPU sync per layer. Instead, call it
# once per step: detect the step boundary by tracking which layer_keys have been
# called in the current step. When a layer_key appears AGAIN (it was already seen in
# the previous step), we know a new step has begun and issue one fresh Nx.to_number;
# all other layers reuse the cache.
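# Concretely, for a 2-layer model:
#   step N:   layer A → fresh read, seen = {A}; layer B → cached, seen = {A, B}
#   step N+1: layer A already seen → new step → fresh read, seen = {A};
#             layer B → cached, seen = {A, B}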
defp get_step_offset(offset_tensor, layer_key) do
case Process.get(@step_offset_proc_key) do
nil ->
offset = Nx.to_number(offset_tensor)
Process.put(@step_offset_proc_key, {MapSet.new([layer_key]), offset})
offset
{seen, cached_offset} ->
if MapSet.member?(seen, layer_key) do
# layer_key already seen in the prior step cycle → this is the start of a new step.
offset = Nx.to_number(offset_tensor)
Process.put(@step_offset_proc_key, {MapSet.new([layer_key]), offset})
offset
else
Process.put(@step_offset_proc_key, {MapSet.put(seen, layer_key), cached_offset})
cached_offset
end
end
end
# ── RoPE frequency precomputation ────────────────────────────────────────────
# Returns a precomputed {dims/2} inv-frequency tensor for strategies where
# the frequency vector can be baked at graph-rewrite time and passed to the
# fast::rope freqs overload. Returns nil for strategies that are seq-len
# conditional or require post-multiply of cos/sin (:longrope, :linear,
# :dynamic), falling back to rope_with_positions with the standard base freq.
# :llama3 — smooth ramp interpolation between low- and high-freq components.
# strategy_opts must provide :factor, :low_freq_factor, :high_freq_factor,
# and :original_max_positions (or we use Meta's published defaults).
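# In closed form (per dimension i, with wavelen_i = 2π / inv_freq_i):
#   s_i         = clip((original_max_pos / wavelen_i - low_freq_factor) /
#                      (high_freq_factor - low_freq_factor), 0, 1)
#   inv_freq'_i = (1 - s_i) * inv_freq_i / factor + s_i * inv_freq_i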
defp precompute_rope_freqs({:llama3, strategy_opts}, size, base, _max_positions) do
factor = Map.get(strategy_opts, :factor, 8.0)
low_freq_factor = Map.get(strategy_opts, :low_freq_factor, 1.0)
high_freq_factor = Map.get(strategy_opts, :high_freq_factor, 4.0)
original_max_pos = Map.get(strategy_opts, :original_max_positions, 8_192)
dims = div(size, 2)
range = Nx.iota({dims}) |> Nx.multiply(2) |> Nx.divide(size)
inv_freq = Nx.divide(1.0, Nx.pow(base, range))
wavelen = Nx.multiply(2.0 * :math.pi(), Nx.divide(1.0, inv_freq))
# Llama3 smoothing (Meta/HF reference): the smoothing factor is linear in
# frequency (original_max_pos / wavelen) — it is 1 for high-frequency
# components (short wavelengths, kept as-is) and 0 for low-frequency
# components (long wavelengths, scaled to inv_freq / factor). Clipping to
# [0, 1] reproduces the reference implementation's where-branches exactly.
smooth =
  Nx.clip(
    Nx.divide(
      Nx.subtract(Nx.divide(original_max_pos, wavelen), low_freq_factor),
      high_freq_factor - low_freq_factor
    ),
    0.0,
    1.0
  )

Nx.add(
  Nx.multiply(Nx.subtract(1.0, smooth), Nx.divide(inv_freq, factor)),
  Nx.multiply(smooth, inv_freq)
)
end
# All other strategies fall back to rope_with_positions (standard base freq).
defp precompute_rope_freqs(_strategy, _size, _base, _max_positions), do: nil
# ── Helpers ──────────────────────────────────────────────────────────────────
@doc """
Extracts `{module, name, arity}` from a function reference, or returns `nil`
for non-function values.
Works for both named functions (`def`/`defp`/`defn`/`defnp`) and closures.
Closures report the module where they were defined and a generated name like
`"-foo/2-fun-0-"`, which is distinct from any hand-written function name and
therefore safe to use in MFA comparisons.
Note: Nx's `defnp` may compile to a closure rather than a named function,
so this helper intentionally does not filter by `:erlang.fun_info(:type)`.
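## Examples

    iex> EMLXAxon.function_info(&Enum.map/2)
    {Enum, :map, 2}

    iex> EMLXAxon.function_info(:not_a_function)
    nil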
"""
@spec function_info(term()) :: {module(), atom(), non_neg_integer()} | nil
def function_info(fun) when is_function(fun) do
{:module, m} = Function.info(fun, :module)
{:name, n} = Function.info(fun, :name)
{:arity, a} = Function.info(fun, :arity)
{m, n, a}
end
def function_info(_), do: nil
defp cached(nil, _key, compute_fn), do: compute_fn.()
defp cached(cache, key, compute_fn) do
case :ets.lookup(cache, key) do
[{^key, value}] ->
value
[] ->
value = compute_fn.()
:ets.insert(cache, {key, value})
value
end
end
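# :native_attention assumes the cache-present branch is the live path through
# the graph, so the :if_present elision is force-enabled whenever it is selected.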
defp expand_enabled(enabled) do
if :native_attention in enabled do
Enum.uniq([:if_present | enabled])
else
enabled
end
end
defp maybe_add(acc, key, fun, enabled) do
if key in enabled, do: [{key, fun} | acc], else: acc
end
end
defmodule EMLXAxon.QuantizeParams do
@moduledoc """
Post-load param quantization for Bumblebee models.
Traverses a Bumblebee params map and quantizes eligible 2-D weight tensors to
4-bit so that `Nx.dot` dispatches to `EMLX.quantized_matmul` via the backend's
transparent dispatch (A6-1 of the emlx#108 investigation).
## Usage
{:ok, model_info} = Bumblebee.load_model(source, backend: {EMLX.Backend, device: :gpu})
model_info = %{model_info | params: EMLXAxon.QuantizeParams.quantize(model_info.params)}
model_info = %{model_info | model: EMLXAxon.rewrite(model_info.model)}
## Eligibility
A tensor is quantized if ALL of the following hold:
- rank is 2
- first dimension (in_features) is divisible by `group_size` (default 64)
- first dimension < `skip_vocab_threshold` (default 100_000) — skips embed_tokens / lm_head
- both dimensions ≥ `2 * group_size`
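
Worked example with the defaults (`group_size: 64`, `skip_vocab_threshold: 100_000`):

    # {in, out} = {1024, 2048}:  rem(1024, 64) == 0, both dims >= 128,
    #                            1024 < 100_000          → quantized to 4-bit
    # {in, out} = {151_936, 1024} (token embedding): 151_936 >= 100_000 → left as-is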
"""
@doc """
Traverse `params` and quantize all eligible weight tensors.
## Options
* `:bits` — quantization bit-width, 4 (default) or 8.
* `:group_size` — quantization group size, must evenly divide in_features (default 64).
* `:skip_vocab_threshold` — skip tensors whose first dim exceeds this (default 100_000).
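## Example

    params = EMLXAxon.QuantizeParams.quantize(model_info.params, bits: 4, group_size: 64)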
"""
@spec quantize(map(), keyword()) :: map()
def quantize(params, opts \\ []) do
bits = Keyword.get(opts, :bits, 4)
group_size = Keyword.get(opts, :group_size, 64)
skip_vocab = Keyword.get(opts, :skip_vocab_threshold, 100_000)
deep_map(params, fn tensor ->
if eligible?(tensor, group_size, skip_vocab) do
original_type = Nx.type(tensor)
original_shape = Nx.shape(tensor)
# Bumblebee uses {in_features, out_features} but MLX quantize expects
# {out_features, in_features} and packs along the last (in_features) dim.
# Transpose first, then quantize, so the physical storage is {out, in/8}.
# Set cfg.transpose=true so quantized_dot calls mlx::quantized_matmul with
# transpose=true (act @ dequant(w).T = act_{...,in} @ {in,out} = {..out}) ✓
qw = EMLX.quantize(Nx.transpose(tensor), type: {:s, bits}, group_size: group_size)
# Patch the config and restore the Bumblebee {in, out} logical shape + type
# so Axon's shape-checking and Nx.dot's right_axes=[0] remain unaware of
# the internal {out, in/8} physical layout.
new_cfg = %{qw.data.quantization_config | transpose: true}
new_data = %{qw.data | quantization_config: new_cfg}
%Nx.Tensor{} = qw
%{qw | data: new_data, shape: original_shape, type: original_type}
else
tensor
end
end)
end
# in_features is the first dim (rows) in Bumblebee {in, out} convention.
# After transposing to {out, in}, quantize packs along in_features (last dim),
# so we check rem(rows, group_size) == 0.
defp eligible?(%Nx.Tensor{} = tensor, group_size, skip_vocab) do
  case {Nx.rank(tensor), EMLX.Quantization.quantized?(tensor)} do
    {2, false} ->
      {rows, cols} = Nx.shape(tensor)

      rem(rows, group_size) == 0 and
        rows >= 2 * group_size and
        cols >= 2 * group_size and
        rows < skip_vocab

    _ ->
      false
  end
end
# Recursively traverse nested maps/lists, applying fun to Nx.Tensor leaves.
defp deep_map(%Nx.Tensor{} = tensor, fun), do: fun.(tensor)
# Axon.ModelState: only traverse `data` (contains params); leave metadata fields alone.
defp deep_map(%Axon.ModelState{data: data} = model_state, fun) do
%{model_state | data: deep_map(data, fun)}
end
# Plain maps (not structs): recurse into values.
defp deep_map(map, fun) when is_map(map) and not is_struct(map) do
Map.new(map, fn {k, v} -> {k, deep_map(v, fun)} end)
end
defp deep_map(list, fun) when is_list(list), do: Enum.map(list, &deep_map(&1, fun))
defp deep_map(other, _fun), do: other
end