Skip to main content

lib/nx_vulkan.ex

defmodule Nx.Vulkan do
  @moduledoc """
  Nx tensor backend on Vulkan compute.

  Wraps Spirit's Vulkan compute kernels (elementwise, reductions,
  matmul, random) as an `Nx.Backend`. Works on FreeBSD with NVIDIA
  hardware where CUDA does not. Same backend code runs on Linux,
  macOS (via MoltenVK), and any Vulkan-capable GPU.

  ## Status

  v0.0.1 — bootstrap. The plan in `PLAN.md` lays out the 10-milestone
  path to v0.1. This release just initializes Vulkan and reports
  which physical device was selected; tensor types and operators
  land in subsequent commits.

  ## Usage (target)

      iex> Nx.Vulkan.init()
      :ok

      iex> Nx.tensor([1.0, 2.0, 3.0], backend: Nx.Vulkan.Backend)
      ...

      iex> Nx.Defn.default_options(default_backend: Nx.Vulkan.Backend)

  ## Files

    * `lib/nx_vulkan.ex`           - this module (top-level API)
    * `lib/nx_vulkan/native.ex`    - Rustler NIF binding (skeleton)
    * `lib/nx_vulkan/backend.ex`   - `Nx.Backend` impl (TBD)
    * `native/nx_vulkan_native/`   - Rust NIF crate
    * `c_src/`                     - extern "C" shim into spirit's
                                     C++ Vulkan backend
  """

  @doc """
  Initialize the Vulkan compute context. Call once at startup.
  Returns `:ok` on success, `{:error, reason}` if no Vulkan-capable
  device is found.
  """
  defdelegate init(), to: Nx.Vulkan.Native

  @doc """
  Returns the name of the selected physical device, or `nil` if
  `init/0` has not been called.
  """
  defdelegate device_name(), to: Nx.Vulkan.Native

  @doc "Returns true if the selected device supports f64 (double precision)."
  defdelegate has_f64?(), to: Nx.Vulkan.Native, as: :has_f64

  # ------------------------------------------------------------------
  # v0.0.2 — tensor lifetime + round-trip
  # ------------------------------------------------------------------

  @doc """
  Upload a list of f32 values to a freshly-allocated GPU buffer.
  Returns `{:ok, tensor_ref}` where `tensor_ref` is an opaque
  `ResourceArc` whose underlying VkBuf is freed when GC'd.

      iex> Nx.Vulkan.init()
      :ok
      iex> {:ok, t} = Nx.Vulkan.upload_f32([1.0, 2.0, 3.0, 4.0])
      iex> {:ok, [1.0, 2.0, 3.0, 4.0]} = Nx.Vulkan.download_f32(t, 4)
  """
  def upload_f32(list) when is_list(list) do
    bin =
      list
      |> Enum.flat_map(fn x -> [<<x::float-32-native>>] end)
      |> IO.iodata_to_binary()

    Nx.Vulkan.Native.upload_binary(bin)
  end

  @doc "Upload a raw binary (already packed f32 little-endian) to GPU memory."
  def upload_binary(bin) when is_binary(bin) do
    Nx.Vulkan.Native.upload_binary(bin)
  end

  @doc """
  Download a GPU buffer back into a list of f32 values. `n_elements`
  must match what was uploaded.

  Non-finite values (NaN, +Inf, -Inf) are returned as the atoms
  `:nan`, `:infinity`, `:neg_infinity`. Erlang's float pattern
  `<<x::float-32-native>>` rejects these bit patterns; we decode
  the raw 32-bit pattern and check the IEEE 754 exponent/mantissa
  to recover them.
  """
  def download_f32(tensor, n_elements) when is_integer(n_elements) and n_elements >= 0 do
    case Nx.Vulkan.Native.download_binary(tensor, n_elements * 4) do
      {:ok, bin} -> {:ok, decode_f32_list(bin)}
      err -> err
    end
  end

  defp decode_f32_list(<<>>), do: []

  defp decode_f32_list(<<bits::32-native, rest::binary>>) do
    [decode_f32(bits) | decode_f32_list(rest)]
  end

  # IEEE 754 binary32: 1 sign bit, 8 exponent, 23 mantissa.
  # We only need the exponent==255 branch for special values; finite
  # values are pulled directly via the float pattern.
  defp decode_f32(bits) do
    <<f::float-32-native>> = <<bits::32-native>>
    f
  rescue
    MatchError -> decode_f32_special(bits)
  end

  # Big-endian view for the spec-defined breakdown — the field
  # extraction is byte-order agnostic when we operate on the integer.
  defp decode_f32_special(bits) do
    sign = Bitwise.band(Bitwise.bsr(bits, 31), 1)
    exponent = Bitwise.band(Bitwise.bsr(bits, 23), 0xFF)
    mantissa = Bitwise.band(bits, 0x7FFFFF)

    cond do
      exponent == 255 and mantissa == 0 and sign == 0 -> :infinity
      exponent == 255 and mantissa == 0 and sign == 1 -> :neg_infinity
      exponent == 255 -> :nan
      true -> raise MatchError, term: bits
    end
  end

  @doc "Download as a raw binary (caller does the unpack)."
  def download_binary(tensor, n_bytes) do
    Nx.Vulkan.Native.download_binary(tensor, n_bytes)
  end

  import Kernel, except: [byte_size: 1, max: 2, min: 2]

  @doc "Byte size of an uploaded tensor (in bytes)."
  defdelegate byte_size(tensor), to: Nx.Vulkan.Native

  # ------------------------------------------------------------------
  # v0.0.3 — elementwise binary ops
  # ------------------------------------------------------------------

  @ops_binary %{
    add: 0,
    multiply: 1,
    subtract: 2,
    divide: 3,
    pow: 4,
    max: 5,
    min: 6,
    equal: 7,
    less: 8,
    greater: 9
  }

  for {name, op_const} <- @ops_binary do
    @doc """
    Elementwise `#{name}` of two GPU tensors of equal length.
    Returns `{:ok, tensor}` or `{:error, reason}`.
    """
    def unquote(name)(a, b) do
      Nx.Vulkan.Native.apply_binary(
        a,
        b,
        unquote(op_const),
        shader_path("elementwise_binary.spv")
      )
    end
  end

  @doc false
  def shader_path(name) do
    :nx_vulkan
    |> :code.priv_dir()
    |> Path.join("shaders")
    |> Path.join(name)
  end

  # ------------------------------------------------------------------
  # v0.0.4 — elementwise unary ops
  # ------------------------------------------------------------------

  @ops_unary %{
    exp: 0,
    log: 1,
    sqrt: 2,
    abs: 3,
    negate: 4,
    sigmoid: 5,
    tanh: 6,
    relu: 7,
    ceil: 8,
    floor: 9,
    sign: 10,
    reciprocal: 11,
    square: 12,
    erf: 13,
    expm1: 14
  }

  for {name, op_const} <- @ops_unary do
    @doc "Elementwise `#{name}` of a GPU tensor."
    def unquote(name)(a) do
      Nx.Vulkan.Native.apply_unary(a, unquote(op_const), shader_path("elementwise_unary.spv"))
    end
  end

  # ------------------------------------------------------------------
  # f64 elementwise — Day 6 / step 2c
  # ------------------------------------------------------------------

  @doc "f64 elementwise binary; dispatches elementwise_binary_f64.spv."
  def apply_binary_f64(a, b, op) do
    code = Map.fetch!(@ops_binary, op)
    Nx.Vulkan.Native.apply_binary_f64(a, b, code,
                                       shader_path("elementwise_binary_f64.spv"))
  end

  @doc "f64 elementwise unary; dispatches elementwise_unary_f64.spv."
  def apply_unary_f64(a, op) do
    code = Map.fetch!(@ops_unary, op)
    Nx.Vulkan.Native.apply_unary_f64(a, code,
                                      shader_path("elementwise_unary_f64.spv"))
  end

  @doc "f64 per-axis reduction; dispatches reduce_axis_f64.spv."
  def reduce_axis_f64(a, outer, reduce_size, inner, op) do
    Nx.Vulkan.Native.reduce_axis_f64(a, outer, reduce_size, inner, op,
                                      shader_path("reduce_axis_f64.spv"))
  end

  @doc "f64 broadcast elementwise binary; dispatches elementwise_binary_broadcast_f64.spv."
  def apply_binary_broadcast_f64(a, b, op, ndim, out_shape, a_strides, b_strides) do
    op_const = Map.fetch!(@ops_binary, op)
    Nx.Vulkan.Native.apply_binary_broadcast_f64(
      a, b, op_const, ndim,
      pad4(out_shape), pad4(a_strides), pad4(b_strides),
      shader_path("elementwise_binary_broadcast_f64.spv")
    )
  end

  @doc """
  Numerically-stable logsumexp over a single virtual reduce axis.
  `log(sum(exp(x - max(x))))`-shape inside one shader dispatch via the
  two-pass shader. f32 only.
  """
  def logsumexp(a, outer, reduce_size, inner) do
    Nx.Vulkan.Native.logsumexp(a, outer, reduce_size, inner,
                                shader_path("logsumexp.spv"))
  end

  # ------------------------------------------------------------------
  # v0.0.5 — reductions (return host-side scalar)
  # ------------------------------------------------------------------

  @doc "Sum of all elements (returns a host-side f32)."
  def sum(t), do: Nx.Vulkan.Native.reduce_scalar(t, 0, shader_path("reduce.spv"))

  @doc "Min of all elements."
  def reduce_min(t), do: Nx.Vulkan.Native.reduce_scalar(t, 1, shader_path("reduce.spv"))

  @doc "Max of all elements."
  def reduce_max(t), do: Nx.Vulkan.Native.reduce_scalar(t, 2, shader_path("reduce.spv"))

  @doc "Mean of all elements (sum + host-side divide)."
  def mean(t) do
    case sum(t) do
      {:ok, s} ->
        n = div(Nx.Vulkan.Native.byte_size(t), 4)
        {:ok, s / n}

      err ->
        err
    end
  end

  # ------------------------------------------------------------------
  # v0.0.6 — matmul (naive)
  # ------------------------------------------------------------------

  @doc """
  Matrix multiply: `C[M*N] = A[M*K] · B[K*N]`. All row-major f32.
  Returns `{:ok, c_tensor}`.

  Auto-selects the best shader variant based on `(M, N, K)`:

    * Tiny (M*N*K < 4096): naive `matmul.spv` — dispatch overhead
      dominates; tiling adds no win.
    * Medium (4096 ≤ M*N*K < 256³): `matmul_tiled.spv` (16×16 shared-
      memory tiles) — good cache behavior, modest GPU occupancy.
    * Large (M*N*K ≥ 256³ ≈ 16M): `matmul_tiled16x2.spv` — each thread
      computes 2 output rows; mac-248 measured **4.2× win at 1024×1024**
      vs the naive variant.

  `matmul_tiled32.spv` exists in spirit too but only wins on Ampere+
  (1024 threads/SM); on Kepler/Maxwell it loses to 16x2 due to shared
  memory pressure (8 KB tile vs 3 KB). Not auto-selected; reachable
  via `matmul_variant/6`.
  """
  def matmul(a, b, m, n, k) do
    {shader, tile_m, tile_n} = pick_matmul(m, n, k)
    matmul_variant(a, b, m, n, k, shader, tile_m, tile_n)
  end

  @doc """
  Matrix multiply with explicit shader variant. Use when you know
  better than the heuristic, or to benchmark.

      :matmul                 # naive, gx=ceil(N/16), gy=ceil(M/16)
      :matmul_tiled           # 16×16 shared-mem tiles
      :matmul_tiled32         # 32×32 tiles (Ampere wins)
      :matmul_tiled16x2       # 32×16 output (2 rows per thread)
  """
  def matmul_variant(a, b, m, n, k, variant)
      when variant in [:matmul, :matmul_tiled, :matmul_tiled32, :matmul_tiled16x2] do
    {tile_m, tile_n} = variant_tiles(variant)
    matmul_variant(a, b, m, n, k, "#{variant}.spv", tile_m, tile_n)
  end

  @doc false
  def matmul_variant(a, b, m, n, k, shader_name, tile_m, tile_n) do
    Nx.Vulkan.Native.matmul_v(a, b, m, n, k, tile_m, tile_n, shader_path(shader_name))
  end

  @doc """
  Picks the best matmul shader for a given `(M, N, K)` shape. Returns
  `{shader_name, tile_m, tile_n}`. Public so benchmarks can introspect
  the heuristic.
  """
  def pick_matmul(m, n, k) do
    flops = m * n * k

    cond do
      # Below 4K total ops: dispatch + descriptor write costs dominate.
      flops < 4_096 -> {"matmul.spv", 16, 16}
      # 256³ = 16 777 216. mac-248's bench shows the 16x2 variant
      # taking the lead from this size onward.
      flops >= 16_777_216 -> {"matmul_tiled16x2.spv", 32, 16}
      # Middle ground: classic 16×16 tile.
      true -> {"matmul_tiled.spv", 16, 16}
    end
  end

  defp variant_tiles(:matmul), do: {16, 16}
  defp variant_tiles(:matmul_tiled), do: {16, 16}
  defp variant_tiles(:matmul_tiled32), do: {32, 32}
  defp variant_tiles(:matmul_tiled16x2), do: {32, 16}

  # ------------------------------------------------------------------
  # v0.0.7 — random
  # ------------------------------------------------------------------

  @doc "Generate `n` uniform [0,1) f32 values, deterministic via `seed`."
  def uniform(n, seed \\ 42) when is_integer(n) and is_integer(seed) do
    Nx.Vulkan.Native.random(n, seed, 0, shader_path("random_philox.spv"))
  end

  @doc "Generate `n` standard-normal N(0,1) f32 values via Box-Muller."
  def normal(n, seed \\ 42) when is_integer(n) and is_integer(seed) do
    Nx.Vulkan.Native.random(n, seed, 1, shader_path("random_philox.spv"))
  end

  # ------------------------------------------------------------------
  # v0.1 phase 1.1 — comparisons via composition
  # ------------------------------------------------------------------

  @doc """
  Branchless select: `cond_true_or_false ? t : f`. `cond` is a 0/1
  tensor (typically the output of `equal/2`, `less/2`, `greater/2`).

  Implemented compositionally as `cond * t + (1 - cond) * f`. Once
  the v0.1 broadcast shader supports scalar broadcast we'll switch
  to a 3-input shader; today this composition is the right shape
  and adds two dispatches' worth of overhead per select.
  """
  def select(cond, t, f) do
    n_bytes = Nx.Vulkan.Native.byte_size(cond)
    n = div(n_bytes, 4)

    with {:ok, ones} <- upload_constant(1.0, n),
         {:ok, inv} <- subtract(ones, cond),
         {:ok, on_t} <- multiply(cond, t),
         {:ok, on_f} <- multiply(inv, f),
         {:ok, out} <- add(on_t, on_f) do
      {:ok, out}
    end
  end

  @doc """
  Clip every element to `[low, high]`. Implemented as
  `max(low, min(high, a))` with broadcasted scalar tensors.
  Replaceable with a single-shader `clip.comp` once the broadcast
  story matures (currently materializes scalars to N-element
  buffers).
  """
  def clip(a, low, high) when is_number(low) and is_number(high) do
    n_bytes = Nx.Vulkan.Native.byte_size(a)
    n = div(n_bytes, 4)

    with {:ok, low_t} <- upload_constant(low, n),
         {:ok, high_t} <- upload_constant(high, n),
         {:ok, capped} <- min(a, high_t),
         {:ok, floored} <- max(low_t, capped) do
      {:ok, floored}
    end
  end

  defp upload_constant(value, n) when is_number(value) do
    f = value / 1.0
    bin = :binary.copy(<<f::float-32-native>>, n)
    Nx.Vulkan.Native.upload_binary(bin)
  end

  # ------------------------------------------------------------------
  # v0.1 phase 1.2 — reshape (metadata-only) + broadcast (Backend) +
  # transpose (new shader)
  # ------------------------------------------------------------------

  @doc """
  2D transpose: `c = a^T` where `a` is M×N and `c` is N×M, both
  row-major f32. Returns `{:ok, c_tensor}`.
  """
  def transpose_2d(a, m, n) do
    Nx.Vulkan.Native.transpose(a, m, n, shader_path("transpose.spv"))
  end

  # ------------------------------------------------------------------
  # v0.1 phase 1.8 GPU path — f32↔f64 cast
  # ------------------------------------------------------------------

  @doc "Cast f32 tensor → f64 (allocates 8-byte output)."
  def cast_f32_to_f64(a, n) do
    Nx.Vulkan.Native.cast(a, n, 8, shader_path("cast_f32_to_f64.spv"))
  end

  @doc "Cast f64 tensor → f32 (allocates 4-byte output)."
  def cast_f64_to_f32(a, n) do
    Nx.Vulkan.Native.cast(a, n, 4, shader_path("cast_f64_to_f32.spv"))
  end

  # ------------------------------------------------------------------
  # v0.1 phase 1.4 GPU path — per-axis reduce
  # ------------------------------------------------------------------

  @doc """
  Per-axis reduction over a virtual 3-D layout (outer, reduce, inner).
  `op`: 0=sum, 1=max, 2=min. Output is (outer * inner) f32.
  """
  def reduce_axis(a, outer, reduce_size, inner, op) do
    Nx.Vulkan.Native.reduce_axis(a, outer, reduce_size, inner, op, shader_path("reduce_axis.spv"))
  end

  # ------------------------------------------------------------------
  # Path A — fused elementwise chain (FUSION_RESEARCH.md)
  # ------------------------------------------------------------------

  @op_codes %{
    # Binary ops — second operand is always buffer `b`.
    add: 0,
    multiply: 1,
    subtract: 2,
    divide: 3,
    pow: 4,
    max: 5,
    min: 6,
    # Unary ops — operate on the running register only.
    exp: 100,
    log: 101,
    sqrt: 102,
    abs: 103,
    negate: 104,
    sigmoid: 105,
    tanh: 106,
    relu: 107,
    ceil: 108,
    floor: 109,
    sign: 110,
    reciprocal: 111,
    square: 112,
    erf: 113,
    expm1: 114
  }

  @doc """
  Run a chain of up to 8 elementwise ops in a single shader dispatch.

  Replaces N separate dispatches with one. Each binary step combines the
  running register with `b`; each unary step transforms the register only.

      iex> {:ok, a} = Nx.Vulkan.upload_f32([1.0, 2.0, 3.0])
      iex> {:ok, b} = Nx.Vulkan.upload_f32([0.5, 0.5, 0.5])
      iex> # (a * b) + b → exp
      iex> {:ok, c} = Nx.Vulkan.fused_chain(a, b, [:multiply, :add, :exp])
      iex> {:ok, vals} = Nx.Vulkan.download_f32(c, 3)
      iex> vals  # exp((a*b)+b) = exp(1.0), exp(1.5), exp(2.0)
      [2.71828..., 4.48168..., 7.38905...]

  Op atoms supported:

    * Binary (combine register with `b`): `:add`, `:multiply`, `:subtract`,
      `:divide`, `:pow`, `:max`, `:min`
    * Unary (transform register): `:exp`, `:log`, `:sqrt`, `:abs`,
      `:negate`, `:sigmoid`, `:tanh`, `:relu`, `:ceil`, `:floor`,
      `:sign`, `:reciprocal`, `:square`

  Note: `:erf` (113) and `:expm1` (114) became fully functional in
  the chain after spirit `161296d1` — `apply_unary` switched in cases
  13 and 14. Earlier versions of the fused shader passed them through
  unchanged.

  Chains longer than 8 ops should be split: dispatch fused_chain twice
  with the running tensor used as `a` for the second call.
  """
  def fused_chain(a_ref, b_ref, ops) when is_list(ops) do
    codes = Enum.map(ops, &Map.fetch!(@op_codes, &1))
    Nx.Vulkan.Native.fused_chain(a_ref, b_ref, codes, shader_path("fused_elementwise.spv"))
  end

  @doc """
  4-input fused chain. `ops_with_buf` items are either `{op_atom, idx}`
  for binary (idx ∈ {1, 2, 3} for b/c/d) or plain `op_atom` for unary.
  All 4 buffers must be the same byte size; up to 8 ops.
  """
  def fused_chain_4(a_ref, b_ref, c_ref, d_ref, ops_with_buf)
      when is_list(ops_with_buf) do
    {codes, buf_idx} =
      ops_with_buf
      |> Enum.map(fn
        {op, idx} -> {Map.fetch!(@op_codes, op), idx}
        op when is_atom(op) -> {Map.fetch!(@op_codes, op), 1}
      end)
      |> Enum.unzip()

    Nx.Vulkan.Native.fused_chain_4(
      a_ref, b_ref, c_ref, d_ref, codes, buf_idx,
      shader_path("fused_elementwise_4in.spv")
    )
  end

  @doc """
  Fused kinetic-energy primitive: `0.5 * sum(p² * inv_mass)` reduced
  per workgroup. Returns a buffer of `ceil(n/256)` partial f32 sums;
  caller does the final reduction (typically via `Nx.Vulkan.sum/1` or
  on the host).
  """
  def kinetic_energy(p_ref, inv_mass_ref) do
    Nx.Vulkan.Native.kinetic_energy(p_ref, inv_mass_ref,
                                     shader_path("kinetic_energy.spv"))
  end

  @doc """
  Fused Normal log-density primitive:
  `-0.5*((x-mu)/sigma)² - log(sigma) - 0.5*log(2π)`.
  Output shape matches `x`. f32 only.
  """
  def normal_logpdf(x_ref, mu_ref, sigma_ref) do
    Nx.Vulkan.Native.normal_logpdf(x_ref, mu_ref, sigma_ref,
                                    shader_path("normal_logpdf.spv"))
  end

  @doc """
  Fused NUTS leapfrog step for a univariate Normal log-density model.
  One Vulkan dispatch per leapfrog step instead of ~12 elementwise
  dispatches via the IR walker. Returns `{q_new_ref, p_new_ref}`.

  `q_ref`, `p_ref`, `inv_mass_ref` are f32 buffers of identical size.
  `eps`, `mu`, `sigma` are scalars (f32 in the shader push constants;
  f64 here for caller convenience). f32 only.

  Closed-form gradient:
  `grad_q log N(q | mu, sigma) = -(q - mu) / sigma²` — no autodiff
  machinery in the shader.
  """
  def leapfrog_normal(q_ref, p_ref, inv_mass_ref, eps, mu, sigma) do
    Nx.Vulkan.Native.leapfrog_normal(
      q_ref, p_ref, inv_mass_ref,
      eps, mu, sigma,
      shader_path("leapfrog_normal.spv")
    )
  end

  @doc """
  Fused **K-step chain** of NUTS leapfrog steps for a univariate Normal
  log-density model. Performs `k` consecutive leapfrog steps in one Vulkan
  dispatch and returns all `k` intermediate states:
  `{q_chain_ref, p_chain_ref, grad_chain_ref, logp_chain_ref}`.

  - `q_chain`, `p_chain`, `grad_chain` — each `k * n` f32 elements,
    laid out row-major (`step k, dimension i` at offset `k*n + i`).
  - `logp_chain` — `k` f32 elements; per-step log-density reduced
    across the `n` dimensions.

  Per-step amortized cost is `(per_dispatch_baseline + k * compute) / k`.
  At `k=32` on the dev box this is ~16 µs per leapfrog step vs ~537 µs
  for the single-step `leapfrog_normal` and ~6000 µs for the unfused
  IR-walker path.

  Constraints (Phase 1.5):
  - `n ≤ 256` (single workgroup; multi-workgroup version is future work).
  - f32 only; long chains (`k ≥ 64`) may accumulate measurable drift
    relative to a f64 reference.
  - Univariate Normal log-density only — closed-form gradient
    `−(q − mu) / sigma²` baked into the shader.
  """
  def leapfrog_chain_normal(q_ref, p_ref, inv_mass_ref, k, eps, mu, sigma)
      when is_integer(k) and k > 0 do
    Nx.Vulkan.Native.leapfrog_chain_normal(
      q_ref, p_ref, inv_mass_ref,
      k, eps, mu, sigma,
      shader_path("leapfrog_chain_normal.spv")
    )
  end

  @doc """
  Multi-workgroup variant of `leapfrog_chain_normal/7` for `n > 256`.
  Returns `{q_chain_ref, p_chain_ref, grad_chain_ref, partial_logp_ref}`
  where `partial_logp_ref` is a buffer of `K * num_workgroups` f32 floats
  (per-workgroup partial sums per step). The caller does the per-step sum
  across the `num_workgroups` axis to recover the final per-step logp.
  Workgroup 0 includes the constant term so the host sum gives final logp
  directly.
  """
  def leapfrog_chain_normal_lg(q_ref, p_ref, inv_mass_ref, k, eps, mu, sigma)
      when is_integer(k) and k > 0 do
    Nx.Vulkan.Native.leapfrog_chain_normal_lg(
      q_ref, p_ref, inv_mass_ref,
      k, eps, mu, sigma,
      shader_path("leapfrog_chain_normal_lg.spv")
    )
  end

  @doc """
  Phase 2 sibling of `leapfrog_chain_normal/7` for the Exponential(lambda)
  family on the unconstrained line (log-transform). Same I/O shape: returns
  `{q_chain_ref, p_chain_ref, grad_chain_ref, logp_chain_ref}`.

  Closed-form unconstrained gradient: `grad_q_uc = 1 - lambda * exp(q_uc)`.
  `n ≤ 256` (single workgroup); see `leapfrog_chain_normal_lg/7` for the
  multi-workgroup pattern when an `_lg` exponential variant is needed.
  """
  def leapfrog_chain_exponential(q_ref, p_ref, inv_mass_ref, k, eps, lambda)
      when is_integer(k) and k > 0 do
    Nx.Vulkan.Native.leapfrog_chain_exponential(
      q_ref, p_ref, inv_mass_ref,
      k, eps, lambda,
      shader_path("leapfrog_chain_exponential.spv")
    )
  end

  @doc """
  Phase 2 chain shader for Student-t(ν, μ, σ). Returns 4-tuple
  `{q_chain_ref, p_chain_ref, grad_chain_ref, logp_chain_ref}`.
  `logp_const` should be precomputed by the caller as
  `log Γ((ν+1)/2) − log Γ(ν/2) − ½ log(πν) − log σ`.
  """
  def leapfrog_chain_studentt(q_ref, p_ref, inv_mass_ref, k, eps, mu, sigma, nu, logp_const)
      when is_integer(k) and k > 0 do
    Nx.Vulkan.Native.leapfrog_chain_studentt(
      q_ref, p_ref, inv_mass_ref,
      k, eps, mu, sigma, nu, logp_const,
      shader_path("leapfrog_chain_studentt.spv")
    )
  end

  @doc """
  Phase 2 chain shader for Cauchy(loc, scale). Returns 4-tuple of refs.
  `log_pi_scale` is precomputed as `−log(π · scale)`.
  """
  def leapfrog_chain_cauchy(q_ref, p_ref, inv_mass_ref, k, eps, loc, scale, log_pi_scale)
      when is_integer(k) and k > 0 do
    Nx.Vulkan.Native.leapfrog_chain_cauchy(
      q_ref, p_ref, inv_mass_ref,
      k, eps, loc, scale, log_pi_scale,
      shader_path("leapfrog_chain_cauchy.spv")
    )
  end

  @doc """
  Phase 2 chain shader for HalfNormal(σ) on the unconstrained line
  via log-transform `q_uc = log(q)`. Returns 4-tuple of refs.
  `log_const` is precomputed as `−log(σ) − ½ log(π)`.

  **Numerical caveat**: the gradient `1 − exp(2·q_uc)/σ²` overflows
  in f32 when `q_uc > ~44`; for σ ≈ 1 the unconstrained range is
  comfortably small.
  """
  def leapfrog_chain_halfnormal(q_ref, p_ref, inv_mass_ref, k, eps, sigma, log_const)
      when is_integer(k) and k > 0 do
    Nx.Vulkan.Native.leapfrog_chain_halfnormal(
      q_ref, p_ref, inv_mass_ref,
      k, eps, sigma, log_const,
      shader_path("leapfrog_chain_halfnormal.spv")
    )
  end

  @doc """
  f64 sibling of `leapfrog_chain_normal/7`. Same I/O contract but all
  buffers use 8 bytes per element (input refs must be f64-typed Vulkan
  tensors). Useful when chain integration needs higher precision than
  f32 (e.g., long chains, sensitive log-densities).
  """
  def leapfrog_chain_normal_f64(q_ref, p_ref, inv_mass_ref, k, eps, mu, sigma)
      when is_integer(k) and k > 0 do
    Nx.Vulkan.Native.leapfrog_chain_normal_f64(
      q_ref, p_ref, inv_mass_ref,
      k, eps, mu, sigma,
      shader_path("leapfrog_chain_normal_f64.spv")
    )
  end

  @doc """
  Phase 2 chain shader for Weibull(k, lambda) on the unconstrained
  line via log-transform `q_uc = log(q)`. Returns 4-tuple of refs.

  Closed-form gradient: `∇logp(q_uc) = k · (1 − (exp(q_uc)/lambda)^k)`.
  `logp_const` is precomputed as `n · (log(k) − k · log(lambda))` —
  no `lgamma` in the shader.
  """
  def leapfrog_chain_weibull(q_ref, p_ref, inv_mass_ref, k_steps, eps, weibull_k, lambda, logp_const)
      when is_integer(k_steps) and k_steps > 0 do
    Nx.Vulkan.Native.leapfrog_chain_weibull(
      q_ref, p_ref, inv_mass_ref,
      k_steps, eps, weibull_k, lambda, logp_const,
      shader_path("leapfrog_chain_weibull.spv")
    )
  end

  # ------------------------------------------------------------------
  # Phase 2 — Nx.Defn JIT integration
  # ------------------------------------------------------------------

  @doc """
  JIT-compile a function so each op dispatches through the Vulkan backend.

  Symmetric counterpart of `EXLA.jit/2` and `EMLX.jit/2`. There's no
  kernel fusion in v0.1 — each `Nx.*` call inside the defn becomes one
  shader dispatch via `Nx.Defn.Evaluator`. Combined-shader fusion is the
  v0.2 work (see FUSION_RESEARCH.md).

  Sets `Nx.Vulkan.Backend` as the global default if it isn't already, so
  scalars and tensors created inside the defn land on the GPU. Calls
  `Nx.Vulkan.init/0` (idempotent).

      iex> Nx.Vulkan.init()
      :ok
      iex> f = fn x -> Nx.add(x, x) end
      iex> Nx.Vulkan.jit(f).(Nx.tensor([1.0, 2.0]))
      #Nx.Tensor<f32[2] [2.0, 4.0]>
  """
  def jit(fun, opts \\ []) do
    ensure_default_backend!()
    compiler = Keyword.get(opts, :compiler, Nx.Vulkan.Compiler)
    Nx.Defn.jit(fun, [{:compiler, compiler} | Keyword.delete(opts, :compiler)])
  end

  defp ensure_default_backend! do
    case Nx.default_backend() do
      {Nx.Vulkan.Backend, _} ->
        :ok

      _ ->
        :ok = init()
        Nx.global_default_backend(Nx.Vulkan.Backend)
        :ok
    end
  end

  # ------------------------------------------------------------------
  # Broadcast elementwise binary
  # ------------------------------------------------------------------

  @doc """
  Dispatch the broadcast variant of an elementwise binary op. `op` is
  one of the binary atom keys in `@ops_binary`, ndim ≤ 4. Stride of 0
  on an axis means broadcast on that axis. `out_shape`, `a_strides`,
  `b_strides` are lists; the helper pads to length 4.

  Use `Nx.Vulkan.broadcast_strides/2` to compute strides from a source
  shape against the output shape.
  """
  def apply_binary_broadcast(a, b, op, ndim, out_shape, a_strides, b_strides) do
    op_const = Map.fetch!(@ops_binary, op)

    Nx.Vulkan.Native.apply_binary_broadcast(
      a,
      b,
      op_const,
      ndim,
      pad4(out_shape),
      pad4(a_strides),
      pad4(b_strides),
      shader_path("elementwise_binary_broadcast.spv")
    )
  end

  @doc """
  Per-axis strides for broadcasting `src_shape` to `out_shape`.

  Returns a length-4 list (zero-padded). Stride is 0 on a broadcast axis
  (size 1 in `src` but >1 in `out`); otherwise it's the row-major
  product of trailing source dims.

      iex> Nx.Vulkan.broadcast_strides({1, 4}, {3, 4})
      [0, 1, 0, 0]
      iex> Nx.Vulkan.broadcast_strides({2, 1}, {2, 4})
      [1, 0, 0, 0]
  """
  def broadcast_strides(src_shape, out_shape) do
    src = Tuple.to_list(src_shape)
    out = Tuple.to_list(out_shape)
    rank = length(out)

    # Right-align: pad src with 1s on the left so the trailing dims align.
    pad = rank - length(src)
    src_aligned = List.duplicate(1, pad) ++ src

    {strides, _} =
      Enum.zip(src_aligned, out)
      |> Enum.reverse()
      |> Enum.reduce({[], 1}, fn {sd, od}, {acc, running} ->
        cond do
          sd == od ->
            {[running | acc], running * sd}

          sd == 1 ->
            {[0 | acc], running}

          true ->
            raise ArgumentError,
                  "shapes don't broadcast: #{inspect(src_shape)}#{inspect(out_shape)}"
        end
      end)

    pad4(strides)
  end

  defp pad4(list) do
    # Kernel.max because Kernel.max/2 is excluded at the top of this
    # module (Nx.Vulkan.max/2 is the GPU op, not the integer max).
    n_pad = Kernel.max(0, 4 - length(list))
    Enum.take(list, 4) ++ List.duplicate(0, n_pad)
  end

  # ------------------------------------------------------------------
  # Buffer pool — Week 1 step 1a (PATH_TO_FULL_PASS.md)
  # ------------------------------------------------------------------

  @doc """
  Release every pooled VkBuf back to the device. Call at idle time to
  reclaim memory; otherwise the pool grows to working-set size and stays
  there. Idempotent.
  """
  defdelegate pool_clear(), to: Nx.Vulkan.Native

  @doc """
  Buffer pool stats. Returns `{:ok, %{hits, misses, freed,
  size_classes, total_pooled}}`. `hits/misses` count alloc requests
  served from / missed by the pool; `freed` counts buffers actually
  vkFreeMemory'd (pool-overflow or explicit clear); `size_classes` is
  the number of distinct sizes currently held; `total_pooled` is the
  total VkBuf count waiting for reuse.

      iex> Nx.Vulkan.init()
      iex> Nx.Vulkan.pool_stats()
      {:ok, %{hits: _, misses: _, freed: _, size_classes: _, total_pooled: _}}
  """
  defdelegate pool_stats(), to: Nx.Vulkan.Native
end