# lib/nx_vulkan/vulkano_backend.ex

defmodule Nx.Vulkan.VulkanoBackend do
  @moduledoc """
  Pure-Rust (vulkano) `Nx.Backend` implementation. Sibling of
  `Nx.Vulkan.Backend` (C++ spirit-backed); same compute fabric,
  different memory-management story.

  Tensors are represented by:

      %Nx.Vulkan.VulkanoBackend{ref: ResourceArc<VulkanoTensor>,
                                shape: tuple, type: {kind, bits}}

  The `ref` is a Rustler resource owning an `Arc<Subbuffer<u8>>` in
  vulkano. When the BEAM GCs the Elixir reference, vulkano's `Drop`
  runs `vkDestroyBuffer` + `vkFreeMemory`. Stale-handle bugs (where
  a freed `VkBuf*` is read back at the C++ layer) are structurally
  impossible: the `Subbuffer` cannot outlive its `Buffer`.

  ## Status

  Storage callbacks (`from_binary/3`, `to_binary/2`, `backend_copy/3`,
  `backend_transfer/3`, `backend_deallocate/1`, `inspect/2`) and the
  creation callbacks (`constant/3`, `iota/3`, `eye/2`) are implemented.
  Creation ops are materialised on the host and uploaded.

  GPU fast paths, dispatched through `Nx.Vulkan.NativeV`:

    - elementwise binary: add, multiply, subtract, divide, pow, max,
      min. Requires f32/f64 and operands with exactly the output's
      shape and type.
    - elementwise unary: exp, log, sqrt, abs, negate, sigmoid, tanh,
      floor, ceil, sign. Requires f32/f64 and an input already of the
      output type.
    - reductions: sum, reduce_max, reduce_min. Requires f32/f64 and a
      device-resident input. The reduced axes must be all axes, a
      leading run of axes, or a trailing run of axes.
    - 2-D `[1, 0]` transpose of 32-bit element types
    - rank-2 × rank-2 `dot`, f32 only
    - zero-copy `reshape`, `squeeze` and same-type `as_type`

  All other cases of these ops are computed on `Nx.BinaryBackend`, and
  the result is uploaded back. The same applies to casts, comparisons,
  `select`, `all`/`any`, `block` and `slice`. Callbacks not defined in
  this module are not implemented by this backend.
  """

  @behaviour Nx.Backend

  @enforce_keys [:ref, :shape, :type]
  defstruct [:ref, :shape, :type]

  alias Nx.Tensor, as: T
  alias Nx.Vulkan.NativeV

  # ---------------------------------------------------------------- init

  @impl true
  def init(opts), do: opts

  # ---------------------------------------------------------------- storage

  @impl true
  def from_binary(%T{} = tensor, binary, _opts), do: upload(tensor, binary)

  # `limit` is the number of leading *elements* to return (the
  # `Nx.Backend.to_binary/2` contract). Device buffers may carry slack
  # past the logical end (e.g. after a zero-copy reshape), so the
  # download is trimmed to that prefix. A buffer shorter than requested
  # is returned unchanged.
  @impl true
  def to_binary(%T{data: %__MODULE__{ref: ref}, shape: shape, type: type}, limit) do
    {:ok, bin} = NativeV.buf_download(ref)
    wanted = min(limit, element_count(shape)) * element_bytes(type)

    if byte_size(bin) > wanted, do: binary_part(bin, 0, wanted), else: bin
  end

  # Downloads the whole tensor and hands the bytes to `target_backend`.
  @impl true
  def backend_copy(%T{shape: shape} = tensor, target_backend, opts) do
    bin = to_binary(tensor, element_count(shape))
    target_backend.from_binary(tensor, bin, opts)
  end

  @impl true
  def backend_transfer(%T{} = tensor, backend, opts), do: backend_copy(tensor, backend, opts)

  # Device memory is released when the resource is garbage-collected
  # (see the moduledoc), so there is nothing to free eagerly here.
  @impl true
  def backend_deallocate(%T{}), do: :ok

  # ---------------------------------------------------------------- inspect

  # Renders through Nx.BinaryBackend after downloading the tensor. Any
  # raise, throw or exit degrades to a placeholder string rather than
  # crashing the caller's inspect.
  @impl true
  def inspect(%T{} = tensor, opts) do
    tensor
    |> backend_copy(Nx.BinaryBackend, [])
    |> Nx.BinaryBackend.inspect(opts)
  catch
    _kind, _reason -> Inspect.Algebra.string("#Nx.Vulkan.VulkanoBackend<unreadable>")
  end

  # ---------------------------------------------------------------- creation

  # Scalar encoding is delegated to Nx.BinaryBackend, so bf16 and every
  # other Nx type get exactly the byte layout Nx itself produces. The
  # encoded buffer is then uploaded.
  @impl true
  def constant(%T{} = out, scalar, _opts) do
    upload_host(out, Nx.BinaryBackend.constant(out, scalar, []))
  end

  # Materialised on the host via BinaryBackend, then uploaded.
  @impl true
  def iota(%T{shape: shape, type: type} = out, axis, _opts) do
    upload_host(out, Nx.iota(shape, type: type, axis: axis, backend: Nx.BinaryBackend))
  end

  @impl true
  def eye(%T{shape: shape, type: type} = out, _opts) do
    upload_host(out, Nx.eye(shape, type: type, backend: Nx.BinaryBackend))
  end

  # ---------------------------------------------------------------- elementwise binary

  # Op codes match `priv/shaders/elementwise_binary.spv` spec constant ID 0:
  #   0=add  1=mul  2=sub  3=div  4=pow  5=max  6=min
  @binary_ops [
    add: 0,
    multiply: 1,
    subtract: 2,
    divide: 3,
    pow: 4,
    max: 5,
    min: 6
  ]

  @elementwise_binary_spv Path.expand(
                            "../../priv/shaders/elementwise_binary.spv",
                            __DIR__
                          )
  @elementwise_binary_f64_spv Path.expand(
                                "../../priv/shaders/elementwise_binary_f64.spv",
                                __DIR__
                              )

  defp binary_spv({:f, 32}), do: @elementwise_binary_spv
  defp binary_spv({:f, 64}), do: @elementwise_binary_f64_spv
  defp binary_spv(_), do: nil

  # GPU path only when both operands already have the output's exact
  # shape and type (no broadcasting or implicit casts); otherwise the
  # op is computed on the host.
  for {op, code} <- @binary_ops do
    @impl true
    def unquote(op)(%T{shape: shape, type: type} = out, a, b) do
      spv = binary_spv(type)

      if spv != nil and same_layout?(a, out) and same_layout?(b, out) do
        %T{data: %__MODULE__{ref: a_ref}} = ensure_on_backend(a)
        %T{data: %__MODULE__{ref: b_ref}} = ensure_on_backend(b)
        n = element_count(shape)
        {:ok, out_ref} = NativeV.buf_alloc(n * element_bytes(type))
        :ok = NativeV.apply_binary(out_ref, a_ref, b_ref, n, unquote(code), spv)
        wrap(out, out_ref)
      else
        host_fallback(out, unquote(op), [a, b])
      end
    end
  end

  # True when `t` has exactly the shape and type of `out`.
  defp same_layout?(%T{shape: shape, type: type}, %T{shape: shape, type: type}), do: true
  defp same_layout?(%T{}, %T{}), do: false

  # ---------------------------------------------------------------- elementwise unary

  # Op codes match `priv/shaders/elementwise_unary.spv` spec constant ID 0:
  #   0=exp  1=log  2=sqrt  3=abs  4=neg  5=sigmoid  6=tanh  7=relu
  #   8=ceil  9=floor  10=sign  11=reciprocal  12=square
  # (relu/reciprocal/square are not wired to Nx callbacks here.)
  @unary_ops [
    exp: 0,
    log: 1,
    sqrt: 2,
    abs: 3,
    negate: 4,
    sigmoid: 5,
    tanh: 6,
    floor: 9,
    ceil: 8,
    sign: 10
  ]

  @elementwise_unary_spv Path.expand(
                           "../../priv/shaders/elementwise_unary.spv",
                           __DIR__
                         )
  @elementwise_unary_f64_spv Path.expand(
                               "../../priv/shaders/elementwise_unary_f64.spv",
                               __DIR__
                             )

  defp unary_spv({:f, 32}), do: @elementwise_unary_spv
  defp unary_spv({:f, 64}), do: @elementwise_unary_f64_spv
  defp unary_spv(_), do: nil

  for {op, code} <- @unary_ops do
    @impl true
    def unquote(op)(%T{shape: shape, type: type} = out, a) do
      spv = unary_spv(type)

      if spv != nil and a.type == type do
        %T{data: %__MODULE__{ref: a_ref}} = ensure_on_backend(a)
        n = element_count(shape)
        {:ok, out_ref} = NativeV.buf_alloc(n * element_bytes(type))
        :ok = NativeV.apply_unary(out_ref, a_ref, n, unquote(code), spv)
        wrap(out, out_ref)
      else
        host_fallback(out, unquote(op), [a])
      end
    end
  end

  # ---------------------------------------------------------------- reductions

  @reduce_axis_spv Path.expand("../../priv/shaders/reduce_axis.spv", __DIR__)
  @reduce_axis_f64_spv Path.expand("../../priv/shaders/reduce_axis_f64.spv", __DIR__)

  defp reduce_spv({:f, 32}), do: @reduce_axis_spv
  defp reduce_spv({:f, 64}), do: @reduce_axis_f64_spv
  defp reduce_spv(_), do: nil

  # Op codes passed to `NativeV.reduce_axis/7`: 0=sum 1=max 2=min.
  @reduce_ops [sum: 0, reduce_max: 1, reduce_min: 2]

  for {op, code} <- @reduce_ops do
    @impl true
    def unquote(op)(out, tensor, opts) do
      do_reduce(out, tensor, opts, unquote(op), unquote(code))
    end
  end

  # Takes the GPU path when a reduce shader exists for the output type,
  # the input is device-resident with the same type, and `opts[:axes]`
  # maps to an (outer, reduce_size, inner) slab layout. Otherwise `Nx.<op>`
  # is computed on the host. Missing/nil axes mean "all axes".
  defp do_reduce(%T{shape: out_shape, type: type} = out, %T{shape: in_shape} = tensor, opts, op, code) do
    axes = Keyword.get(opts, :axes) || all_axes(in_shape)

    with spv when spv != nil <- reduce_spv(type),
         %T{type: ^type, data: %__MODULE__{ref: a_ref}} <- tensor,
         {:ok, {outer, reduce_size, inner}} <- classify_reduce_axes(in_shape, axes) do
      n_out = max(element_count(out_shape), 1)
      {:ok, out_ref} = NativeV.buf_alloc(n_out * element_bytes(type))
      :ok = NativeV.reduce_axis(out_ref, a_ref, outer, reduce_size, inner, code, spv)
      wrap(out, out_ref)
    else
      _ -> host_fallback(out, op, [tensor, opts])
    end
  end

  # `//1` keeps the range empty for rank-0 shapes; a bare `0..-1` would
  # count down and yield [0, -1].
  defp all_axes(shape), do: Enum.to_list(0..(tuple_size(shape) - 1)//1)

  # Classify the reduction shape (row-major layout):
  #   - Leading axes  → outer=1, reduce=product(reduced), inner=product(remaining)
  #     (reducing all axes is the special case where nothing remains, inner=1)
  #   - Trailing axes → outer=product(remaining), reduce=product(reduced), inner=1
  # Any other axis set returns :fallback.
  defp classify_reduce_axes(in_shape, axes) do
    rank = tuple_size(in_shape)
    dims = Tuple.to_list(in_shape)
    sorted = Enum.sort(axes)
    count = length(sorted)

    cond do
      sorted == Enum.to_list(0..(count - 1)//1) ->
        {reduced, remaining} = Enum.split(dims, count)
        {:ok, {1, product(reduced), product(remaining)}}

      sorted == Enum.to_list((rank - count)..(rank - 1)//1) ->
        {kept, reduced} = Enum.split(dims, rank - count)
        {:ok, {product(kept), product(reduced), 1}}

      true ->
        :fallback
    end
  end

  # ---------------------------------------------------------------- shape / movement

  @transpose_spv Path.expand("../../priv/shaders/transpose.spv", __DIR__)

  # Reshape + squeeze are zero-copy: same buffer, new shape metadata.
  # The buffer might be physically larger than the new shape implies
  # if it came from an operation that allocated extra slack — that's
  # fine, the metadata determines what bytes get read out.

  @impl true
  def reshape(%T{} = out, %T{} = tensor), do: rewrap(out, tensor)

  @impl true
  def squeeze(%T{} = out, %T{} = tensor, _axes), do: rewrap(out, tensor)

  # GPU path: rank-2 input, axes [1, 0], 32-bit element type.
  # `transpose_2d` receives no element width, so it is only used for
  # 32-bit elements (NOTE(review): assumes transpose.spv moves 32-bit
  # words — confirm). Every other permutation or type is transposed on
  # the host.
  @impl true
  def transpose(%T{type: type} = out, %T{shape: in_shape} = tensor, axes) do
    case {in_shape, axes, type} do
      {{m, n}, [1, 0], {_kind, 32}} ->
        %T{data: %__MODULE__{ref: a_ref}} = ensure_on_backend(tensor)
        {:ok, out_ref} = NativeV.buf_alloc(m * n * element_bytes(type))
        :ok = NativeV.transpose_2d(out_ref, a_ref, m, n, @transpose_spv)
        wrap(out, out_ref)

      _ ->
        host_fallback(out, :transpose, [tensor, [axes: axes]])
    end
  end

  # ---------------------------------------------------------------- host-fallback ops

  # as_type: a cast to the same type just rewraps the existing buffer.
  # A real cast round-trips through the host.
  @impl true
  def as_type(%T{type: type} = out, %T{type: type} = tensor), do: rewrap(out, tensor)
  def as_type(%T{type: type} = out, %T{} = tensor), do: host_fallback(out, :as_type, [tensor, type])

  # Comparison ops are computed on the host. The elementwise_binary.spv
  # catalog has op codes 7/8/9 (equal/less/greater), but its output is
  # f32 rather than the u8 Nx expects from a comparison. Routing through
  # BinaryBackend keeps the Nx type contract correct. Scholar uses
  # comparison + select heavily, so this unblocks the classical-ML target.
  for op <- [:equal, :not_equal, :less, :less_equal, :greater, :greater_equal] do
    @impl true
    def unquote(op)(out, a, b), do: host_fallback(out, unquote(op), [a, b])
  end

  # select(cond, on_true, on_false), computed on the host.
  @impl true
  def select(out, pred, on_true, on_false) do
    host_fallback(out, :select, [pred, on_true, on_false])
  end

  # all/3 and any/3 are boolean reductions, computed on the host.
  for op <- [:all, :any] do
    @impl true
    def unquote(op)(out, tensor, opts), do: host_fallback(out, unquote(op), [tensor, opts])
  end

  # block/4 dispatches Nx.Block-derived structs (SVD, QR, LU, ...) and
  # runs them on the host. Top-level tensors of this backend among
  # `inputs` (a single tensor, a list or a tuple) are moved to
  # BinaryBackend, and `Nx.BinaryBackend.block/4` evaluates the block.
  # Tensor leaves of the result (a single tensor, tuple or list) are then
  # transferred back here. Scholar's linear regression uses
  # Nx.Block.LinAlg.SVD internally, so this unblocks it.
  @impl true
  def block(out, block_def, inputs, opts) do
    move_to_host = fn
      %T{data: %__MODULE__{}} = t -> to_host(t)
      other -> other
    end

    move_to_device = fn
      %T{} = t -> Nx.backend_transfer(t, __MODULE__)
      other -> other
    end

    host_inputs =
      cond do
        is_list(inputs) -> Enum.map(inputs, move_to_host)
        is_tuple(inputs) -> map_tuple(inputs, move_to_host)
        true -> move_to_host.(inputs)
      end

    case Nx.BinaryBackend.block(out, block_def, host_inputs, opts) do
      %T{} = t -> move_to_device.(t)
      tuple when is_tuple(tuple) -> map_tuple(tuple, move_to_device)
      list when is_list(list) -> Enum.map(list, move_to_device)
    end
  end

  # ---------------------------------------------------------------- slice (host fallback)

  # Slice is host-routed: the source tensor is copied to BinaryBackend,
  # sliced there, and the slab is uploaded back. A future stage adds a
  # GPU-side slice shader for contiguous prefixes. Until then this is
  # correct but copies through host memory.
  @impl true
  def slice(out, tensor, start_indices, lengths, strides) do
    host_fallback(out, :slice, [tensor, start_indices, lengths, [strides: strides]])
  end

  # ---------------------------------------------------------------- linalg

  @matmul_spv Path.expand("../../priv/shaders/matmul.spv", __DIR__)

  # Dot product (matmul). Nx callback signature:
  #   dot(out, a, contracting_axes_a, batched_axes_a,
  #            b, contracting_axes_b, batched_axes_b)
  #
  # Fast path: rank-2 × rank-2 with contracting axis [1] of a against
  # [0] of b (standard matmul A·B). It is f32 only, because matmul.spv
  # has no f64 variant in priv/shaders/ yet. Everything else is computed
  # on BinaryBackend.
  @impl true
  def dot(%T{type: type} = out, a, axes_a, batched_a, b, axes_b, batched_b) do
    fast_path =
      type == {:f, 32} and a.type == {:f, 32} and b.type == {:f, 32} and
        tuple_size(a.shape) == 2 and tuple_size(b.shape) == 2 and
        axes_a == [1] and axes_b == [0] and
        batched_a == [] and batched_b == []

    if fast_path do
      %T{data: %__MODULE__{ref: a_ref}, shape: {m, k}} = ensure_on_backend(a)
      %T{data: %__MODULE__{ref: b_ref}, shape: {_k, n}} = ensure_on_backend(b)
      {:ok, out_ref} = NativeV.buf_alloc(m * n * element_bytes(type))
      :ok = NativeV.matmul(out_ref, a_ref, b_ref, m, n, k, @matmul_spv)
      wrap(out, out_ref)
    else
      host_fallback(out, :dot, [a, axes_a, batched_a, b, axes_b, batched_b])
    end
  end

  # ---------------------------------------------------------------- helpers

  # Computes `apply(Nx, fun, args)` on Nx.BinaryBackend and uploads the
  # result as `out`'s data. Top-level tensor arguments are moved to the
  # host directly, with no GPU round trip. Every other argument (axes,
  # option lists, types) is passed through unchanged.
  defp host_fallback(%T{} = out, fun, args) do
    host_args =
      Enum.map(args, fn
        %T{} = t -> to_host(t)
        other -> other
      end)

    upload_host(out, apply(Nx, fun, host_args))
  end

  # Shares `tensor`'s device buffer under `out`'s shape/type metadata.
  # Uploads `tensor` first if it lives on another backend.
  defp rewrap(%T{} = out, %T{} = tensor) do
    %T{data: %__MODULE__{ref: ref}} = ensure_on_backend(tensor)
    wrap(out, ref)
  end

  # Uploads a host (BinaryBackend) tensor's bytes as `out`'s data.
  defp upload_host(%T{} = out, %T{} = host_tensor), do: upload(out, Nx.to_binary(host_tensor))

  # Uploads raw bytes into a fresh device buffer and wraps it as `out`'s data.
  defp upload(%T{} = out, binary) do
    {:ok, ref} = NativeV.buf_upload(binary)
    wrap(out, ref)
  end

  defp wrap(%T{shape: shape, type: type} = out, ref) do
    %{out | data: %__MODULE__{ref: ref, shape: shape, type: type}}
  end

  # Tolerates inputs from other backends. Nx.Defn.Evaluator may hand us
  # tensors that have not been transferred yet (e.g. an Nx.constant
  # produced on BinaryBackend before the op dispatches here).
  defp ensure_on_backend(%T{data: %__MODULE__{}} = t), do: t
  defp ensure_on_backend(%T{} = t), do: Nx.backend_transfer(t, __MODULE__)

  defp to_host(%T{} = t), do: Nx.backend_transfer(t, Nx.BinaryBackend)

  defp map_tuple(tuple, fun), do: tuple |> Tuple.to_list() |> Enum.map(fun) |> List.to_tuple()

  # Number of elements described by a shape tuple ({} → 1).
  defp element_count(shape) when is_tuple(shape), do: shape |> Tuple.to_list() |> product()

  defp product(dims), do: Enum.reduce(dims, 1, &(&1 * &2))

  # Byte width of one element of a byte-aligned Nx type. For example,
  # {:f, 32} → 4 and {:c, 64} → 8. Sub-byte types are not supported.
  defp element_bytes({_kind, bits}) when rem(bits, 8) == 0, do: div(bits, 8)
end