// native/nx_vulkan_vulkano/src/lib.rs

//! nx_vulkan_vulkano — Pure-Rust Rustler NIF for synthesised chain
//! shader dispatch via vulkano.
//!
//! Sibling of `nx_vulkan_native` (the C++ shim + spirit Vulkan backend).
//! Resource lifetimes flow through `Arc<Buffer>` so stale `VkBuf*`
//! handles cannot escape — the bug class that surfaced in Mission II
//! R4 step 4 (Nx.Vulkan.Backend.to_binary ArgumentError on a freed
//! tensor reference) is structurally eliminated.
//!
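//! The headline NIF is the fused leapfrog chain dispatch below; the buffer
//! lifecycle NIFs (`buf_*`) and the compute NIFs (elementwise, reduce,
//! transpose, matmul) further down are documented where they are defined: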
//!
//!     leapfrog_chain_synth(q_bin, p_bin, extras_bin, push, k, spv_path)
//!         -> {:ok, {q_chain_bin, p_chain_bin, grad_chain_bin,
//!                   logp_chain_bin}}
//!         |  {:error, atom_or_string}
//!
//! All inputs and outputs are binaries; the NIF allocates fresh
//! Vulkan buffers per call (no persistent pool — that comes later
//! once the calling pattern is established).

use std::fs;
use std::sync::Arc;
use std::sync::Mutex;
use std::sync::OnceLock;

use rustler::{Binary, Encoder, Env, NewBinary, NifResult, ResourceArc, Term};
use vulkano::{
    buffer::{Buffer, BufferContents, BufferCreateInfo, BufferUsage, Subbuffer},
    command_buffer::{
        allocator::{StandardCommandBufferAllocator, StandardCommandBufferAllocatorCreateInfo},
        AutoCommandBufferBuilder, CommandBufferUsage,
    },
    descriptor_set::{
        allocator::StandardDescriptorSetAllocator, PersistentDescriptorSet, WriteDescriptorSet,
    },
    device::{
        physical::PhysicalDeviceType, Device, DeviceCreateInfo, Queue, QueueCreateInfo, QueueFlags,
    },
    instance::{Instance, InstanceCreateInfo},
    memory::allocator::{AllocationCreateInfo, MemoryTypeFilter, StandardMemoryAllocator},
    pipeline::{
        compute::ComputePipelineCreateInfo, layout::PipelineDescriptorSetLayoutCreateInfo,
        ComputePipeline, PipelineBindPoint, PipelineLayout, PipelineShaderStageCreateInfo,
    },
    shader::SpecializationConstant,
    shader::{ShaderModule, ShaderModuleCreateInfo},
    sync::{self, GpuFuture},
    VulkanLibrary,
};

mod atoms {
    //! Atoms used in the `{:ok, ...}` / `{:error, ...}` tuples returned to
    //! Elixir by the NIFs in this crate.
    rustler::atoms! {
        // Result tags.
        ok,
        error,
        // Argument validation failures.
        size_mismatch,
        bad_input,
        // Shader / device / dispatch failures.
        spv_read_failed,
        vulkan_init_failed,
        dispatch_failed,
        // Host <-> device transfer failures.
        upload_failed,
        download_failed,
    }
}

// -- Pipeline cache --------------------------------------------------------
//
// vulkano's StandardDescriptorSetAllocator (allocator.rs:448) creates a fresh
// DescriptorPool per unique layout identity. Per-call pipeline + layout
// creation produces a fresh layout every dispatch, so the allocator never
// recycles its 32-slot pool — it just keeps creating new pools, eventually
// exhausting driver-side limits (observed: ~5000 iterations on FreeBSD
// NVIDIA before `descriptor set: a non-validation error occurred`).
//
// Caching by (spv_path, op_code) means the same layout identity is used
// across calls; vulkano's allocator recycles slots within a single pool.
//
// op_code = -1 sentinel means "shader has no spec constant" (reduce_axis,
// transpose_2d, matmul, leapfrog_chain_synth).

/// A compiled compute pipeline paired with the layout it was built against.
///
/// Both halves are `Arc`s, so cloning an entry out of the cache only bumps
/// reference counts and every caller sees the same layout object.
#[derive(Clone)]
struct CachedPipeline {
    /// Layout whose descriptor-set layouts are used when allocating sets.
    layout: Arc<PipelineLayout>,
    /// The pipeline bound at dispatch time.
    pipeline: Arc<ComputePipeline>,
}

static PIPELINE_CACHE: OnceLock<Mutex<std::collections::HashMap<(String, i32), CachedPipeline>>> =
    OnceLock::new();

fn pipeline_cache() -> &'static Mutex<std::collections::HashMap<(String, i32), CachedPipeline>> {
    PIPELINE_CACHE.get_or_init(|| Mutex::new(std::collections::HashMap::new()))
}

/// Returns the compute pipeline for `(spv_path, op_code)`, compiling and
/// caching it on first use.
///
/// `op_code = Some(op)` specialises constant id 0 to `op` before building;
/// `None` uses the shader unspecialised and is keyed with the `-1` sentinel.
///
/// # Errors
/// A descriptive `String` if the cache mutex is poisoned, the Vulkan context
/// cannot be initialised, the SPIR-V file cannot be read or is not a whole
/// number of 32-bit words, or any vulkano object creation fails.
fn get_or_create_pipeline(
    spv_path: &str,
    op_code: Option<i32>,
) -> Result<CachedPipeline, String> {
    let key = (spv_path.to_string(), op_code.unwrap_or(-1));

    // Fast path. The lock is released before compiling so that a slow shader
    // build does not block dispatches that hit already-cached pipelines.
    {
        let guard = pipeline_cache().lock().map_err(|_| "cache poisoned".to_string())?;
        if let Some(cached) = guard.get(&key) {
            return Ok(cached.clone());
        }
    }

    let context = ctx()?;
    let spv_bytes = fs::read(spv_path).map_err(|e| format!("read spv: {e}"))?;
    let spv_words = bytes_to_u32_words(&spv_bytes)?;

    // `ShaderModule::new` is unsafe because vulkano cannot fully validate the
    // SPIR-V; the caller is trusted to pass a path to a valid module.
    let shader = unsafe {
        ShaderModule::new(context.device.clone(), ShaderModuleCreateInfo::new(&spv_words))
            .map_err(|e| format!("ShaderModule: {e}"))?
    };

    let entry = match op_code {
        Some(op) => {
            let mut spec: ahash::HashMap<u32, SpecializationConstant> =
                ahash::HashMap::default();
            spec.insert(0, SpecializationConstant::I32(op));
            let specialized = shader
                .specialize(spec)
                .map_err(|e| format!("specialize: {e}"))?;
            specialized
                .entry_point("main")
                .ok_or_else(|| "no main entry point".to_string())?
        }
        None => shader
            .entry_point("main")
            .ok_or_else(|| "no main entry point".to_string())?,
    };

    let stage = PipelineShaderStageCreateInfo::new(entry);
    let layout_info = PipelineDescriptorSetLayoutCreateInfo::from_stages([&stage])
        .into_pipeline_layout_create_info(context.device.clone())
        .map_err(|e| format!("layout info: {e}"))?;
    let layout = PipelineLayout::new(context.device.clone(), layout_info)
        .map_err(|e| format!("PipelineLayout: {e}"))?;

    let pipeline = ComputePipeline::new(
        context.device.clone(),
        None,
        ComputePipelineCreateInfo::stage_layout(stage, layout.clone()),
    )
    .map_err(|e| format!("ComputePipeline: {e}"))?;

    let built = CachedPipeline { layout, pipeline };

    // Another caller may have compiled the same key while we were building.
    // Keep whichever entry landed first, so every caller shares one layout
    // identity; that stability is the whole point of this cache. Plain
    // `insert` would replace the winner with our duplicate.
    let mut guard = pipeline_cache().lock().map_err(|_| "cache poisoned".to_string())?;
    let winner = guard.entry(key).or_insert(built).clone();
    Ok(winner)
}

/// Device buffer handed to Elixir as a NIF resource.
///
/// Ownership stays on the Rust side. The resource wraps a vulkano
/// `Subbuffer`, which holds an `Arc` to the underlying buffer and allocation.
/// Rustler drops the resource after the last Elixir reference is collected.
/// Once no other clone of the `Subbuffer` is alive, vulkano's `Drop` chain
/// destroys the buffer and frees its memory. No raw `VkBuffer` handle is
/// ever exposed to Elixir.
pub struct VulkanoTensor {
    /// Byte view over the whole device buffer.
    buf: Subbuffer<[u8]>,
    /// Size recorded at creation; used by the size checks and `buf_byte_size`.
    n_bytes: u64,
}

/// Process-wide Vulkan state. `ctx()` creates it lazily, and every NIF call
/// reuses it so that no dispatch pays for instance or device creation.
struct VkContext {
    /// Logical device that owns every buffer, pipeline and command buffer.
    device: Arc<Device>,
    /// Compute-capable queue that all command buffers are submitted to.
    queue: Arc<Queue>,
    /// Allocator backing buffer memory.
    mem_allocator: Arc<StandardMemoryAllocator>,
    /// Allocator for primary command buffers.
    cmd_allocator: Arc<StandardCommandBufferAllocator>,
    /// Allocator for descriptor sets.
    set_allocator: Arc<StandardDescriptorSetAllocator>,
}

static CTX: OnceLock<VkContext> = OnceLock::new();

fn ctx() -> Result<&'static VkContext, String> {
    if let Some(c) = CTX.get() {
        return Ok(c);
    }

    let library = VulkanLibrary::new().map_err(|e| format!("VulkanLibrary::new: {e}"))?;
    let instance = Instance::new(library, InstanceCreateInfo::default())
        .map_err(|e| format!("Instance::new: {e}"))?;

    let (physical, queue_family_index) = instance
        .enumerate_physical_devices()
        .map_err(|e| format!("enumerate_physical_devices: {e}"))?
        .filter_map(|p| {
            p.queue_family_properties()
                .iter()
                .enumerate()
                .position(|(_, q)| q.queue_flags.intersects(QueueFlags::COMPUTE))
                .map(|i| (p, i as u32))
        })
        .min_by_key(|(p, _)| match p.properties().device_type {
            PhysicalDeviceType::DiscreteGpu => 0,
            PhysicalDeviceType::IntegratedGpu => 1,
            PhysicalDeviceType::VirtualGpu => 2,
            PhysicalDeviceType::Cpu => 3,
            _ => 4,
        })
        .ok_or_else(|| "no compute-capable Vulkan device".to_string())?;

    eprintln!(
        "[nx_vulkan_vulkano] device: {} ({:?})",
        physical.properties().device_name,
        physical.properties().device_type
    );

    // Enable shaderFloat64 if the device supports it; required by the
    // _f64.spv shaders. Falls back gracefully on devices without it
    // (those will keep using the f32 paths + host fallback for f64).
    let supports_f64 = physical.supported_features().shader_float64;

    let enabled_features = vulkano::device::Features {
        shader_float64: supports_f64,
        ..Default::default()
    };

    let (device, mut queues) = Device::new(
        physical,
        DeviceCreateInfo {
            queue_create_infos: vec![QueueCreateInfo {
                queue_family_index,
                ..Default::default()
            }],
            enabled_features,
            ..Default::default()
        },
    )
    .map_err(|e| format!("Device::new: {e}"))?;

    let queue = queues.next().ok_or_else(|| "no queue".to_string())?;

    let mem_allocator = Arc::new(StandardMemoryAllocator::new_default(device.clone()));
    let cmd_allocator = Arc::new(StandardCommandBufferAllocator::new(
        device.clone(),
        StandardCommandBufferAllocatorCreateInfo::default(),
    ));
    // Default 32-slot pool is fine *if* layouts are stable (which the
    // pipeline cache ensures). Bumping set_count regressed RTX 3060 Ti
    // perf 6× on small-matmul without helping FreeBSD's failure mode —
    // see r1 of the race bench (May 20 2026).
    let set_allocator = Arc::new(StandardDescriptorSetAllocator::new(
        device.clone(),
        Default::default(),
    ));

    let ctx = VkContext {
        device,
        queue,
        mem_allocator,
        cmd_allocator,
        set_allocator,
    };

    let _ = CTX.set(ctx);
    Ok(CTX.get().unwrap())
}

/// Push constants for the leapfrog chain shader. `parse_push_block` decodes
/// them from the 20-byte little-endian block sent by Elixir.
///
/// `#[repr(C)]` pins field order and offsets to that block layout, so the
/// fields must not be reordered.
#[derive(Clone, Copy, BufferContents)]
#[repr(C)]
struct PushBlock {
    /// Bytes 0..4.
    k_steps: u32,
    /// Bytes 4..8.
    n_obs: u32,
    /// Bytes 8..12; also sizes the output chains in `leapfrog_chain_synth`.
    d: u32,
    /// Bytes 12..16: padding word, forwarded unchanged.
    _pad: u32,
    /// Bytes 16..20.
    eps: f32,
}

/// Decodes the little-endian push-constant block sent from Elixir.
///
/// Layout: bytes `0..4` k_steps (u32), `4..8` n_obs (u32), `8..12` d (u32),
/// `12..16` padding (u32), `16..20` eps (f32). Any bytes past offset 20 are
/// ignored.
///
/// # Errors
/// Returns a static message when fewer than 20 bytes are supplied.
fn parse_push_block(bytes: &[u8]) -> Result<PushBlock, &'static str> {
    let head = match bytes.get(..20) {
        Some(h) => h,
        None => return Err("push block must be >= 20 bytes"),
    };
    // The 4 raw bytes of 32-bit word number `idx` within the header.
    let word = |idx: usize| -> [u8; 4] {
        let at = idx * 4;
        [head[at], head[at + 1], head[at + 2], head[at + 3]]
    };
    let k_steps = u32::from_le_bytes(word(0));
    let n_obs = u32::from_le_bytes(word(1));
    let d = u32::from_le_bytes(word(2));
    let pad = u32::from_le_bytes(word(3));
    let eps = f32::from_le_bytes(word(4));
    Ok(PushBlock {
        k_steps,
        n_obs,
        d,
        _pad: pad,
        eps,
    })
}

/// Reinterprets a SPIR-V byte stream as little-endian 32-bit words.
///
/// # Errors
/// Returns a static message when the length is not a multiple of 4.
fn bytes_to_u32_words(bytes: &[u8]) -> Result<Vec<u32>, &'static str> {
    let words = bytes.chunks_exact(4);
    if !words.remainder().is_empty() {
        return Err("SPV bytes must be u32-aligned");
    }
    let mut out = Vec::with_capacity(bytes.len() / 4);
    for w in words {
        out.push(u32::from_le_bytes([w[0], w[1], w[2], w[3]]));
    }
    Ok(out)
}

/// Creates a device-preferred, host-writable buffer initialised with `bytes`.
///
/// # Errors
/// Returns an error string if `bytes` is empty, or if vulkano fails to
/// create or allocate the buffer. Vulkan forbids zero-sized buffers, and
/// vulkano's slice constructors treat a zero length as a programming error
/// (panic), so the empty case is rejected before reaching vulkano.
fn upload_buffer(
    alloc: Arc<StandardMemoryAllocator>,
    bytes: &[u8],
    usage: BufferUsage,
) -> Result<Subbuffer<[u8]>, String> {
    if bytes.is_empty() {
        return Err("upload buffer: empty data (Vulkan buffers must be non-zero-sized)".to_string());
    }
    Buffer::from_iter(
        alloc,
        BufferCreateInfo {
            usage,
            ..Default::default()
        },
        AllocationCreateInfo {
            memory_type_filter: MemoryTypeFilter::PREFER_DEVICE
                | MemoryTypeFilter::HOST_SEQUENTIAL_WRITE,
            ..Default::default()
        },
        bytes.iter().copied(),
    )
    .map_err(|e| format!("upload buffer: {e}"))
}

/// Creates a zero-filled, device-preferred buffer of `n_bytes` that the host
/// can read back (random-access host mapping).
///
/// # Errors
/// Returns an error string if `n_bytes` is 0, or if vulkano fails to create
/// or allocate the buffer. Vulkan forbids zero-sized buffers, and vulkano's
/// slice constructors treat a zero length as a programming error (panic), so
/// the zero case is rejected before reaching vulkano.
fn alloc_buffer(
    alloc: Arc<StandardMemoryAllocator>,
    n_bytes: usize,
    usage: BufferUsage,
) -> Result<Subbuffer<[u8]>, String> {
    if n_bytes == 0 {
        return Err("alloc buffer: zero size (Vulkan buffers must be non-zero-sized)".to_string());
    }
    Buffer::from_iter(
        alloc,
        BufferCreateInfo {
            usage,
            ..Default::default()
        },
        AllocationCreateInfo {
            memory_type_filter: MemoryTypeFilter::PREFER_DEVICE
                | MemoryTypeFilter::HOST_RANDOM_ACCESS,
            ..Default::default()
        },
        std::iter::repeat(0u8).take(n_bytes),
    )
    .map_err(|e| format!("alloc buffer: {e}"))
}

/// Copies the full contents of a host-readable buffer into a `Vec`.
///
/// # Errors
/// Returns `"read buffer: ..."` when vulkano refuses host access (for
/// example, while the GPU still holds the buffer).
fn download_buffer(buf: Subbuffer<[u8]>) -> Result<Vec<u8>, String> {
    let snapshot = buf.read().map(|view| view.to_vec());
    snapshot.map_err(|e| format!("read buffer: {e}"))
}

/// Run a K-step leapfrog dispatch against the synthesised SPV.
///
/// Returns {q_chain_bin, p_chain_bin, grad_chain_bin, logp_chain_bin}
/// as little-endian f32 binaries:
///   q/p/grad: K * d * 4 bytes
///   logp:    K * 4 bytes
#[rustler::nif(schedule = "DirtyIo")]
fn leapfrog_chain_synth<'a>(
    env: Env<'a>,
    q_init: Binary<'a>,
    p_init: Binary<'a>,
    extras: Binary<'a>,
    push: Binary<'a>,
    k: u32,
    spv_path: String,
) -> NifResult<Term<'a>> {
    if q_init.len() != p_init.len() {
        return Ok((atoms::error(), atoms::size_mismatch()).encode(env));
    }
    if k == 0 {
        return Ok((atoms::error(), atoms::bad_input()).encode(env));
    }
    if push.len() == 0 || push.len() > 128 {
        return Ok((atoms::error(), atoms::bad_input()).encode(env));
    }

    let push_block = match parse_push_block(push.as_slice()) {
        Ok(p) => p,
        Err(_) => return Ok((atoms::error(), atoms::bad_input()).encode(env)),
    };

    let d = push_block.d as usize;
    let chain_bytes = (k as usize) * d * 4;
    let logp_bytes = (k as usize) * 4;

    let context = match ctx() {
        Ok(c) => c,
        Err(e) => {
            return Ok((atoms::error(), atoms::vulkan_init_failed(), e).encode(env));
        }
    };

    let result = (|| -> Result<(Vec<u8>, Vec<u8>, Vec<u8>, Vec<u8>), String> {
        let cached = get_or_create_pipeline(&spv_path, None)?;
        let layout = cached.layout.clone();
        let pipeline = cached.pipeline.clone();

        let q_buf = upload_buffer(
            context.mem_allocator.clone(),
            q_init.as_slice(),
            BufferUsage::STORAGE_BUFFER,
        )?;
        let p_buf = upload_buffer(
            context.mem_allocator.clone(),
            p_init.as_slice(),
            BufferUsage::STORAGE_BUFFER,
        )?;
        let extras_buf = upload_buffer(
            context.mem_allocator.clone(),
            extras.as_slice(),
            BufferUsage::STORAGE_BUFFER,
        )?;

        let q_chain_buf = alloc_buffer(
            context.mem_allocator.clone(),
            chain_bytes,
            BufferUsage::STORAGE_BUFFER | BufferUsage::TRANSFER_SRC,
        )?;
        let p_chain_buf = alloc_buffer(
            context.mem_allocator.clone(),
            chain_bytes,
            BufferUsage::STORAGE_BUFFER | BufferUsage::TRANSFER_SRC,
        )?;
        let grad_chain_buf = alloc_buffer(
            context.mem_allocator.clone(),
            chain_bytes,
            BufferUsage::STORAGE_BUFFER | BufferUsage::TRANSFER_SRC,
        )?;
        let logp_chain_buf = alloc_buffer(
            context.mem_allocator.clone(),
            logp_bytes,
            BufferUsage::STORAGE_BUFFER | BufferUsage::TRANSFER_SRC,
        )?;

        let set = PersistentDescriptorSet::new(
            &context.set_allocator,
            layout.set_layouts()[0].clone(),
            [
                WriteDescriptorSet::buffer(0, q_buf.clone()),
                WriteDescriptorSet::buffer(1, p_buf.clone()),
                WriteDescriptorSet::buffer(2, extras_buf.clone()),
                WriteDescriptorSet::buffer(3, q_chain_buf.clone()),
                WriteDescriptorSet::buffer(4, p_chain_buf.clone()),
                WriteDescriptorSet::buffer(5, grad_chain_buf.clone()),
                WriteDescriptorSet::buffer(6, logp_chain_buf.clone()),
            ],
            [],
        )
        .map_err(|e| format!("descriptor set: {e}"))?;

        let mut cmd = AutoCommandBufferBuilder::primary(
            &context.cmd_allocator,
            context.queue.queue_family_index(),
            CommandBufferUsage::OneTimeSubmit,
        )
        .map_err(|e| format!("cmd builder: {e}"))?;

        cmd.bind_pipeline_compute(pipeline.clone())
            .map_err(|e| format!("bind pipeline: {e}"))?
            .bind_descriptor_sets(PipelineBindPoint::Compute, layout.clone(), 0, set.clone())
            .map_err(|e| format!("bind descriptor: {e}"))?
            .push_constants(layout.clone(), 0, push_block)
            .map_err(|e| format!("push_constants: {e}"))?
            .dispatch([1, 1, 1])
            .map_err(|e| format!("dispatch: {e}"))?;

        let cmd_buf = cmd.build().map_err(|e| format!("build cmd: {e}"))?;

        let future = sync::now(context.device.clone())
            .then_execute(context.queue.clone(), cmd_buf)
            .map_err(|e| format!("then_execute: {e}"))?
            .then_signal_fence_and_flush()
            .map_err(|e| format!("then_signal_fence_and_flush: {e}"))?;

        future.wait(None).map_err(|e| format!("wait: {e}"))?;

        Ok((
            download_buffer(q_chain_buf)?,
            download_buffer(p_chain_buf)?,
            download_buffer(grad_chain_buf)?,
            download_buffer(logp_chain_buf)?,
        ))
    })();

    match result {
        Ok((q, p, g, l)) => {
            let q_bin = bytes_to_nif_binary(env, &q);
            let p_bin = bytes_to_nif_binary(env, &p);
            let g_bin = bytes_to_nif_binary(env, &g);
            let l_bin = bytes_to_nif_binary(env, &l);
            Ok((atoms::ok(), (q_bin, p_bin, g_bin, l_bin)).encode(env))
        }
        Err(msg) => Ok((atoms::error(), atoms::dispatch_failed(), msg).encode(env)),
    }
}

/// Copies `bytes` into a freshly allocated BEAM binary owned by `env`.
fn bytes_to_nif_binary<'a>(env: Env<'a>, bytes: &[u8]) -> Binary<'a> {
    let mut fresh = NewBinary::new(env, bytes.len());
    let dst = fresh.as_mut_slice();
    dst.copy_from_slice(bytes);
    fresh.into()
}

// -- Buffer lifecycle NIFs ------------------------------------------------
//
// Sibling of the C++ shim's nxv_buf_* family, but every buffer is held
// behind a Rust Arc<Buffer> wrapped in a Subbuffer<[u8]> + ResourceArc.
// The stale-handle bug class is structurally absent. A Subbuffer keeps its
// underlying Buffer alive through an Arc. Rustler runs the resource's Drop
// (and with it vulkano's) only after the last Elixir reference has been
// garbage-collected, so no live Elixir term can point at freed device memory.

/// Allocate + upload a host binary into a fresh device buffer.
/// Returns `{:ok, resource}`.
#[rustler::nif(schedule = "DirtyIo")]
fn buf_upload<'a>(env: Env<'a>, data: Binary<'a>) -> NifResult<Term<'a>> {
    let context = match ctx() {
        Ok(c) => c,
        Err(e) => return Ok((atoms::error(), atoms::vulkan_init_failed(), e).encode(env)),
    };

    let buf = match upload_buffer(
        context.mem_allocator.clone(),
        data.as_slice(),
        BufferUsage::STORAGE_BUFFER | BufferUsage::TRANSFER_SRC | BufferUsage::TRANSFER_DST,
    ) {
        Ok(b) => b,
        Err(e) => return Ok((atoms::error(), atoms::upload_failed(), e).encode(env)),
    };

    let tensor = VulkanoTensor {
        buf,
        n_bytes: data.len() as u64,
    };
    Ok((atoms::ok(), ResourceArc::new(tensor)).encode(env))
}

/// Allocate a zero-initialised device buffer of `n_bytes`.
/// Returns `{:ok, resource}`.
#[rustler::nif(schedule = "DirtyIo")]
fn buf_alloc<'a>(env: Env<'a>, n_bytes: u64) -> NifResult<Term<'a>> {
    let context = match ctx() {
        Ok(c) => c,
        Err(e) => return Ok((atoms::error(), atoms::vulkan_init_failed(), e).encode(env)),
    };

    let buf = match alloc_buffer(
        context.mem_allocator.clone(),
        n_bytes as usize,
        BufferUsage::STORAGE_BUFFER | BufferUsage::TRANSFER_SRC | BufferUsage::TRANSFER_DST,
    ) {
        Ok(b) => b,
        Err(e) => return Ok((atoms::error(), atoms::upload_failed(), e).encode(env)),
    };

    let tensor = VulkanoTensor { buf, n_bytes };
    Ok((atoms::ok(), ResourceArc::new(tensor)).encode(env))
}

/// Download `tensor.n_bytes` bytes from a device buffer to the BEAM.
/// Returns `{:ok, binary}`.
#[rustler::nif(schedule = "DirtyIo")]
fn buf_download<'a>(env: Env<'a>, tensor: ResourceArc<VulkanoTensor>) -> NifResult<Term<'a>> {
    let bytes = match tensor.buf.read() {
        Ok(guard) => guard.to_vec(),
        Err(_) => return Ok((atoms::error(), atoms::download_failed()).encode(env)),
    };
    let bin = bytes_to_nif_binary(env, &bytes);
    Ok((atoms::ok(), bin).encode(env))
}

/// Tensor's buffer size in bytes.
#[rustler::nif]
fn buf_byte_size(tensor: ResourceArc<VulkanoTensor>) -> u64 {
    tensor.n_bytes
}

/// Overwrite an existing device buffer with new host data.
/// Returns `:ok` or `{:error, :size_mismatch}` when `data.len() != tensor.n_bytes`.
#[rustler::nif(schedule = "DirtyIo")]
fn buf_upload_into<'a>(
    env: Env<'a>,
    tensor: ResourceArc<VulkanoTensor>,
    data: Binary<'a>,
) -> NifResult<Term<'a>> {
    if data.len() as u64 != tensor.n_bytes {
        return Ok((atoms::error(), atoms::size_mismatch()).encode(env));
    }
    let mut guard = match tensor.buf.write() {
        Ok(g) => g,
        Err(_) => return Ok((atoms::error(), atoms::upload_failed()).encode(env)),
    };
    guard.copy_from_slice(data.as_slice());
    Ok(rustler::types::atom::ok().encode(env))
}

// -- Compute NIFs ---------------------------------------------------------

/// Push-constant block for the elementwise shaders: the element count `n`.
#[derive(Clone, Copy, BufferContents)]
#[repr(C)]
struct PushN {
    /// Number of elements to process.
    n: u32,
}

/// Elementwise binary op. `op_code` selects:
///   0=add, 1=mul, 2=sub, 3=div, 4=pow, 5=max, 6=min
/// Bindings: a, b, out at 0, 1, 2. Push: uint n.
/// Workgroup: 256 threads, ceil(n/256) groups.
///
/// `op_code` is baked in as specialization constant 0, so each op gets its
/// own cached pipeline. All three buffers must have identical byte sizes.
///
/// Returns one of:
/// * `:ok`
/// * `{:error, :size_mismatch}`
/// * `{:error, :vulkan_init_failed, msg}`
/// * `{:error, :dispatch_failed, msg}`
#[rustler::nif(schedule = "DirtyIo")]
fn apply_binary<'a>(
    env: Env<'a>,
    out_ref: ResourceArc<VulkanoTensor>,
    a_ref: ResourceArc<VulkanoTensor>,
    b_ref: ResourceArc<VulkanoTensor>,
    n: u32,
    op_code: u32,
    spv_path: String,
) -> NifResult<Term<'a>> {
    if a_ref.n_bytes != b_ref.n_bytes || a_ref.n_bytes != out_ref.n_bytes {
        return Ok((atoms::error(), atoms::size_mismatch()).encode(env));
    }

    let context = match ctx() {
        Ok(c) => c,
        Err(e) => return Ok((atoms::error(), atoms::vulkan_init_failed(), e).encode(env)),
    };

    let result = (|| -> Result<(), String> {
        let cached = get_or_create_pipeline(&spv_path, Some(op_code as i32))?;
        let layout = cached.layout.clone();
        let pipeline = cached.pipeline.clone();

        let set = PersistentDescriptorSet::new(
            &context.set_allocator,
            layout.set_layouts()[0].clone(),
            [
                WriteDescriptorSet::buffer(0, a_ref.buf.clone()),
                WriteDescriptorSet::buffer(1, b_ref.buf.clone()),
                WriteDescriptorSet::buffer(2, out_ref.buf.clone()),
            ],
            [],
        )
        .map_err(|e| format!("descriptor set: {e}"))?;

        // ceil(n / 256) computed in u64. In u32, `n + 255` overflows for n
        // near u32::MAX: a panic in debug builds, and a silent 0-group
        // dispatch in release. The quotient is at most 2^24, so the cast
        // back to u32 is lossless.
        let groups = ((u64::from(n) + 255) / 256) as u32;

        let mut cmd = AutoCommandBufferBuilder::primary(
            &context.cmd_allocator,
            context.queue.queue_family_index(),
            CommandBufferUsage::OneTimeSubmit,
        )
        .map_err(|e| format!("cmd builder: {e}"))?;

        cmd.bind_pipeline_compute(pipeline.clone())
            .map_err(|e| format!("bind pipeline: {e}"))?
            .bind_descriptor_sets(PipelineBindPoint::Compute, layout.clone(), 0, set.clone())
            .map_err(|e| format!("bind descriptor: {e}"))?
            .push_constants(layout.clone(), 0, PushN { n })
            .map_err(|e| format!("push_constants: {e}"))?
            .dispatch([groups, 1, 1])
            .map_err(|e| format!("dispatch: {e}"))?;

        let cmd_buf = cmd.build().map_err(|e| format!("build cmd: {e}"))?;

        let future = sync::now(context.device.clone())
            .then_execute(context.queue.clone(), cmd_buf)
            .map_err(|e| format!("then_execute: {e}"))?
            .then_signal_fence_and_flush()
            .map_err(|e| format!("fence: {e}"))?;
        future.wait(None).map_err(|e| format!("wait: {e}"))?;

        Ok(())
    })();

    match result {
        Ok(()) => Ok(rustler::types::atom::ok().encode(env)),
        Err(msg) => Ok((atoms::error(), atoms::dispatch_failed(), msg).encode(env)),
    }
}

/// Elementwise unary op. `op_code` selects:
///   0=exp 1=log 2=sqrt 3=abs 4=neg 5=sigmoid 6=tanh 7=relu
///   8=ceil 9=floor 10=sign 11=reciprocal 12=square
/// Bindings: a, out at 0, 1. Push: uint n. Workgroup: 256 threads.
///
/// `op_code` is baked in as specialization constant 0, so each op gets its
/// own cached pipeline. Input and output must have identical byte sizes.
///
/// Returns one of:
/// * `:ok`
/// * `{:error, :size_mismatch}`
/// * `{:error, :vulkan_init_failed, msg}`
/// * `{:error, :dispatch_failed, msg}`
#[rustler::nif(schedule = "DirtyIo")]
fn apply_unary<'a>(
    env: Env<'a>,
    out_ref: ResourceArc<VulkanoTensor>,
    a_ref: ResourceArc<VulkanoTensor>,
    n: u32,
    op_code: u32,
    spv_path: String,
) -> NifResult<Term<'a>> {
    if a_ref.n_bytes != out_ref.n_bytes {
        return Ok((atoms::error(), atoms::size_mismatch()).encode(env));
    }

    let context = match ctx() {
        Ok(c) => c,
        Err(e) => return Ok((atoms::error(), atoms::vulkan_init_failed(), e).encode(env)),
    };

    let result = (|| -> Result<(), String> {
        let cached = get_or_create_pipeline(&spv_path, Some(op_code as i32))?;
        let layout = cached.layout.clone();
        let pipeline = cached.pipeline.clone();

        let set = PersistentDescriptorSet::new(
            &context.set_allocator,
            layout.set_layouts()[0].clone(),
            [
                WriteDescriptorSet::buffer(0, a_ref.buf.clone()),
                WriteDescriptorSet::buffer(1, out_ref.buf.clone()),
            ],
            [],
        )
        .map_err(|e| format!("descriptor set: {e}"))?;

        // ceil(n / 256) computed in u64. In u32, `n + 255` overflows for n
        // near u32::MAX. The quotient is at most 2^24, so the cast back to
        // u32 is lossless.
        let groups = ((u64::from(n) + 255) / 256) as u32;

        let mut cmd = AutoCommandBufferBuilder::primary(
            &context.cmd_allocator,
            context.queue.queue_family_index(),
            CommandBufferUsage::OneTimeSubmit,
        )
        .map_err(|e| format!("cmd builder: {e}"))?;

        cmd.bind_pipeline_compute(pipeline.clone())
            .map_err(|e| format!("bind pipeline: {e}"))?
            .bind_descriptor_sets(PipelineBindPoint::Compute, layout.clone(), 0, set.clone())
            .map_err(|e| format!("bind descriptor: {e}"))?
            .push_constants(layout.clone(), 0, PushN { n })
            .map_err(|e| format!("push_constants: {e}"))?
            .dispatch([groups, 1, 1])
            .map_err(|e| format!("dispatch: {e}"))?;

        let cmd_buf = cmd.build().map_err(|e| format!("build cmd: {e}"))?;
        let future = sync::now(context.device.clone())
            .then_execute(context.queue.clone(), cmd_buf)
            .map_err(|e| format!("then_execute: {e}"))?
            .then_signal_fence_and_flush()
            .map_err(|e| format!("fence: {e}"))?;
        future.wait(None).map_err(|e| format!("wait: {e}"))?;

        Ok(())
    })();

    match result {
        Ok(()) => Ok(rustler::types::atom::ok().encode(env)),
        Err(msg) => Ok((atoms::error(), atoms::dispatch_failed(), msg).encode(env)),
    }
}

/// Push-constant block for `reduce_axis`. The input is viewed as
/// `outer × reduce_size × inner`, and `op` selects the reduction.
/// Field order is fixed by `#[repr(C)]`; do not reorder.
#[derive(Clone, Copy, BufferContents)]
#[repr(C)]
struct PushReduceAxis {
    outer: u32,
    reduce_size: u32,
    inner: u32,
    /// 0=sum, 1=max, 2=min.
    op: u32,
}

/// Per-axis reduction. `op`: 0=sum, 1=max, 2=min.
/// Bindings: a, out. Push: {outer, reduce_size, inner, op}.
/// dispatch ceil(outer*inner/256) workgroups.
///
/// The op is passed as a push constant rather than a specialization
/// constant, so one cached pipeline serves every reduce op.
///
/// Returns one of:
/// * `:ok`
/// * `{:error, :vulkan_init_failed, msg}`
/// * `{:error, :dispatch_failed, msg}`, including when `outer * inner`
///   does not fit in a u32.
#[rustler::nif(schedule = "DirtyIo")]
fn reduce_axis<'a>(
    env: Env<'a>,
    out_ref: ResourceArc<VulkanoTensor>,
    a_ref: ResourceArc<VulkanoTensor>,
    outer: u32,
    reduce_size: u32,
    inner: u32,
    op_code: u32,
    spv_path: String,
) -> NifResult<Term<'a>> {
    let context = match ctx() {
        Ok(c) => c,
        Err(e) => return Ok((atoms::error(), atoms::vulkan_init_failed(), e).encode(env)),
    };

    let result = (|| -> Result<(), String> {
        let cached = get_or_create_pipeline(&spv_path, None)?;
        let layout = cached.layout.clone();
        let pipeline = cached.pipeline.clone();

        let set = PersistentDescriptorSet::new(
            &context.set_allocator,
            layout.set_layouts()[0].clone(),
            [
                WriteDescriptorSet::buffer(0, a_ref.buf.clone()),
                WriteDescriptorSet::buffer(1, out_ref.buf.clone()),
            ],
            [],
        )
        .map_err(|e| format!("descriptor set: {e}"))?;

        // Unchecked, this product panics in debug and wraps in release,
        // which would dispatch too few groups.
        let n_slots = outer
            .checked_mul(inner)
            .ok_or_else(|| format!("outer * inner overflows u32 ({outer} * {inner})"))?;
        // ceil(n_slots / 256) in u64; the quotient fits in u32 (<= 2^24).
        let groups = ((u64::from(n_slots) + 255) / 256) as u32;

        let mut cmd = AutoCommandBufferBuilder::primary(
            &context.cmd_allocator,
            context.queue.queue_family_index(),
            CommandBufferUsage::OneTimeSubmit,
        )
        .map_err(|e| format!("cmd builder: {e}"))?;

        cmd.bind_pipeline_compute(pipeline.clone())
            .map_err(|e| format!("bind pipeline: {e}"))?
            .bind_descriptor_sets(PipelineBindPoint::Compute, layout.clone(), 0, set.clone())
            .map_err(|e| format!("bind descriptor: {e}"))?
            .push_constants(
                layout.clone(),
                0,
                PushReduceAxis {
                    outer,
                    reduce_size,
                    inner,
                    op: op_code,
                },
            )
            .map_err(|e| format!("push_constants: {e}"))?
            .dispatch([groups, 1, 1])
            .map_err(|e| format!("dispatch: {e}"))?;

        let cmd_buf = cmd.build().map_err(|e| format!("build cmd: {e}"))?;
        let future = sync::now(context.device.clone())
            .then_execute(context.queue.clone(), cmd_buf)
            .map_err(|e| format!("then_execute: {e}"))?
            .then_signal_fence_and_flush()
            .map_err(|e| format!("fence: {e}"))?;
        future.wait(None).map_err(|e| format!("wait: {e}"))?;

        Ok(())
    })();

    match result {
        Ok(()) => Ok(rustler::types::atom::ok().encode(env)),
        Err(msg) => Ok((atoms::error(), atoms::dispatch_failed(), msg).encode(env)),
    }
}

/// Push-constant block for `transpose_2d`: the input is `m` rows by `n`
/// columns. Field order is fixed by `#[repr(C)]`.
#[derive(Clone, Copy, BufferContents)]
#[repr(C)]
struct PushTranspose {
    /// Input row count.
    m: u32,
    /// Input column count.
    n: u32,
}

/// 2D transpose. Input A is M×N row-major; output is N×M row-major.
/// Bindings: a, out at 0, 1. Push: {m, n}. Workgroup 16×16.
#[rustler::nif(schedule = "DirtyIo")]
fn transpose_2d<'a>(
    env: Env<'a>,
    out_ref: ResourceArc<VulkanoTensor>,
    a_ref: ResourceArc<VulkanoTensor>,
    m: u32,
    n: u32,
    spv_path: String,
) -> NifResult<Term<'a>> {
    let context = match ctx() {
        Ok(c) => c,
        Err(e) => return Ok((atoms::error(), atoms::vulkan_init_failed(), e).encode(env)),
    };

    let result = (|| -> Result<(), String> {
        let spv_bytes = fs::read(&spv_path).map_err(|e| format!("read spv: {e}"))?;
        let spv_words = bytes_to_u32_words(&spv_bytes)?;

        let shader = unsafe {
            ShaderModule::new(context.device.clone(), ShaderModuleCreateInfo::new(&spv_words))
                .map_err(|e| format!("ShaderModule: {e}"))?
        };

        let entry = shader
            .entry_point("main")
            .ok_or_else(|| "no main entry point".to_string())?;
        let stage = PipelineShaderStageCreateInfo::new(entry);

        let layout_info = PipelineDescriptorSetLayoutCreateInfo::from_stages([&stage])
            .into_pipeline_layout_create_info(context.device.clone())
            .map_err(|e| format!("layout info: {e}"))?;
        let layout = PipelineLayout::new(context.device.clone(), layout_info)
            .map_err(|e| format!("PipelineLayout: {e}"))?;

        let pipeline = ComputePipeline::new(
            context.device.clone(),
            None,
            ComputePipelineCreateInfo::stage_layout(stage, layout.clone()),
        )
        .map_err(|e| format!("ComputePipeline: {e}"))?;

        let set = PersistentDescriptorSet::new(
            &context.set_allocator,
            layout.set_layouts()[0].clone(),
            [
                WriteDescriptorSet::buffer(0, a_ref.buf.clone()),
                WriteDescriptorSet::buffer(1, out_ref.buf.clone()),
            ],
            [],
        )
        .map_err(|e| format!("descriptor set: {e}"))?;

        let gx = (n + 15) / 16;
        let gy = (m + 15) / 16;

        let mut cmd = AutoCommandBufferBuilder::primary(
            &context.cmd_allocator,
            context.queue.queue_family_index(),
            CommandBufferUsage::OneTimeSubmit,
        )
        .map_err(|e| format!("cmd builder: {e}"))?;

        cmd.bind_pipeline_compute(pipeline.clone())
            .map_err(|e| format!("bind pipeline: {e}"))?
            .bind_descriptor_sets(PipelineBindPoint::Compute, layout.clone(), 0, set.clone())
            .map_err(|e| format!("bind descriptor: {e}"))?
            .push_constants(layout.clone(), 0, PushTranspose { m, n })
            .map_err(|e| format!("push_constants: {e}"))?
            .dispatch([gx, gy, 1])
            .map_err(|e| format!("dispatch: {e}"))?;

        let cmd_buf = cmd.build().map_err(|e| format!("build cmd: {e}"))?;
        let future = sync::now(context.device.clone())
            .then_execute(context.queue.clone(), cmd_buf)
            .map_err(|e| format!("then_execute: {e}"))?
            .then_signal_fence_and_flush()
            .map_err(|e| format!("fence: {e}"))?;
        future.wait(None).map_err(|e| format!("wait: {e}"))?;

        Ok(())
    })();

    match result {
        Ok(()) => Ok(rustler::types::atom::ok().encode(env)),
        Err(msg) => Ok((atoms::error(), atoms::dispatch_failed(), msg).encode(env)),
    }
}

/// Push-constant block for `matmul` (C = A · B, where A is m×k, B is k×n,
/// and C is m×n). Field order is fixed by `#[repr(C)]`.
#[derive(Clone, Copy, BufferContents)]
#[repr(C)]
struct PushMatmul {
    /// Rows of A and C.
    m: u32,
    /// Columns of B and C.
    n: u32,
    /// Shared inner dimension.
    k: u32,
}

/// 2D matmul. C = A · B where A is M×K, B is K×N, C is M×N.
/// All row-major f32. Bindings: a, b, out at 0, 1, 2. Push {m, n, k}.
/// Workgroup 16×16, dispatch ceil(N/16)×ceil(M/16).
///
/// The shader module and pipeline are built from `spv_path` on every call.
/// The call blocks until the GPU work has finished (it waits on a fence).
///
/// # Returns
/// - `:ok` on success.
/// - `{:error, :vulkan_init_failed, msg}` if the shared Vulkan context could
///   not be created.
/// - `{:error, :dispatch_failed, msg}` for any later failure, including
///   buffers that are too small for the requested dimensions.
///
/// Each buffer is validated before dispatch:
/// - A must hold at least M·K f32 values.
/// - B must hold at least K·N f32 values.
/// - out must hold at least M·N f32 values.
///
/// A shape mismatch is therefore reported as an error instead of letting the
/// shader index past the end of a buffer.
///
/// NOTE(review): `out_ref` presumably must not alias `a_ref` or `b_ref`,
/// since the kernel reads A/B while writing C. This is not checked here.
#[rustler::nif(schedule = "DirtyIo")]
fn matmul<'a>(
    env: Env<'a>,
    out_ref: ResourceArc<VulkanoTensor>,
    a_ref: ResourceArc<VulkanoTensor>,
    b_ref: ResourceArc<VulkanoTensor>,
    m: u32,
    n: u32,
    k: u32,
    spv_path: String,
) -> NifResult<Term<'a>> {
    let context = match ctx() {
        Ok(c) => c,
        Err(e) => return Ok((atoms::error(), atoms::vulkan_init_failed(), e).encode(env)),
    };

    let result = (|| -> Result<(), String> {
        // Byte width of one f32 element.
        const F32_BYTES: u64 = 4;

        // Ensure a buffer holds at least rows*cols f32 values. The product is
        // formed in u64, which cannot overflow for two u32 factors.
        let check_capacity =
            |name: &str, have_bytes: u64, rows: u32, cols: u32| -> Result<(), String> {
                let need_elems = u64::from(rows) * u64::from(cols);
                let have_elems = have_bytes / F32_BYTES;
                if have_elems < need_elems {
                    Err(format!(
                        "{name} buffer too small: holds {have_elems} f32, \
                         need {need_elems} ({rows}x{cols})"
                    ))
                } else {
                    Ok(())
                }
            };
        check_capacity("a", a_ref.buf.size(), m, k)?;
        check_capacity("b", b_ref.buf.size(), k, n)?;
        check_capacity("out", out_ref.buf.size(), m, n)?;

        let spv_bytes = fs::read(&spv_path).map_err(|e| format!("read spv: {e}"))?;
        let spv_words = bytes_to_u32_words(&spv_bytes)?;

        // SAFETY: vulkano cannot fully validate SPIR-V, so the caller is
        // trusted to point `spv_path` at a valid compiled compute shader whose
        // bindings and push constants match this NIF (see doc comment above).
        let shader = unsafe {
            ShaderModule::new(context.device.clone(), ShaderModuleCreateInfo::new(&spv_words))
        }
        .map_err(|e| format!("ShaderModule: {e}"))?;

        let entry = shader
            .entry_point("main")
            .ok_or_else(|| "no main entry point".to_string())?;
        let stage = PipelineShaderStageCreateInfo::new(entry);

        let layout_info = PipelineDescriptorSetLayoutCreateInfo::from_stages([&stage])
            .into_pipeline_layout_create_info(context.device.clone())
            .map_err(|e| format!("layout info: {e}"))?;
        let layout = PipelineLayout::new(context.device.clone(), layout_info)
            .map_err(|e| format!("PipelineLayout: {e}"))?;

        let pipeline = ComputePipeline::new(
            context.device.clone(),
            None,
            ComputePipelineCreateInfo::stage_layout(stage, layout.clone()),
        )
        .map_err(|e| format!("ComputePipeline: {e}"))?;

        let set = PersistentDescriptorSet::new(
            &context.set_allocator,
            layout.set_layouts()[0].clone(),
            [
                WriteDescriptorSet::buffer(0, a_ref.buf.clone()),
                WriteDescriptorSet::buffer(1, b_ref.buf.clone()),
                WriteDescriptorSet::buffer(2, out_ref.buf.clone()),
            ],
            [],
        )
        .map_err(|e| format!("descriptor set: {e}"))?;

        // ceil(dim / 16) without the `dim + 15` overflow near u32::MAX.
        let gx = n / 16 + u32::from(n % 16 != 0);
        let gy = m / 16 + u32::from(m % 16 != 0);

        let mut cmd = AutoCommandBufferBuilder::primary(
            &context.cmd_allocator,
            context.queue.queue_family_index(),
            CommandBufferUsage::OneTimeSubmit,
        )
        .map_err(|e| format!("cmd builder: {e}"))?;

        cmd.bind_pipeline_compute(pipeline)
            .map_err(|e| format!("bind pipeline: {e}"))?
            .bind_descriptor_sets(PipelineBindPoint::Compute, layout.clone(), 0, set)
            .map_err(|e| format!("bind descriptor: {e}"))?
            .push_constants(layout, 0, PushMatmul { m, n, k })
            .map_err(|e| format!("push_constants: {e}"))?
            .dispatch([gx, gy, 1])
            .map_err(|e| format!("dispatch: {e}"))?;

        let cmd_buf = cmd.build().map_err(|e| format!("build cmd: {e}"))?;
        let future = sync::now(context.device.clone())
            .then_execute(context.queue.clone(), cmd_buf)
            .map_err(|e| format!("then_execute: {e}"))?
            .then_signal_fence_and_flush()
            .map_err(|e| format!("fence: {e}"))?;
        future.wait(None).map_err(|e| format!("wait: {e}"))?;

        Ok(())
    })();

    match result {
        Ok(()) => Ok(rustler::types::atom::ok().encode(env)),
        Err(msg) => Ok((atoms::error(), atoms::dispatch_failed(), msg).encode(env)),
    }
}

fn load(env: rustler::Env, _info: rustler::Term) -> bool {
    rustler::resource!(VulkanoTensor, env);
    true
}

// NIF module registration: binds the functions listed below to the Erlang
// module `Elixir.Nx.Vulkan.NativeV` (i.e. `Nx.Vulkan.NativeV` from Elixir) and
// installs `load` as the load callback, which registers the `VulkanoTensor`
// resource type before any NIF can hand out a `ResourceArc<VulkanoTensor>`.
rustler::init!(
    "Elixir.Nx.Vulkan.NativeV",
    [
        leapfrog_chain_synth,
        buf_upload,
        buf_alloc,
        buf_download,
        buf_byte_size,
        buf_upload_into,
        apply_binary,
        apply_unary,
        reduce_axis,
        transpose_2d,
        matmul,
    ],
    load = load
);