// Skip to main content
//
// c_src/nx_vulkan_shim.cpp

/* nx_vulkan_shim.cpp — flat C ABI on top of spirit::Engine::Backend::vulkan.
 *
 * Includes spirit's Vulkan backend header at compile time. The header
 * lives at ../../../spirit/core/include/engine/Backend_par_vulkan.hpp;
 * the source at ../../../spirit/core/src/engine/Backend_par_vulkan.cpp.
 * build.rs adds those paths to the include + source list.
 */

#include "nx_vulkan_shim.h"
#include <engine/Backend_par_vulkan.hpp>
#include <cstdio>
#include <cstring>
#include <map>
#include <mutex>
#include <string>
#include <unistd.h>
#include <utility>
#include <vector>

using namespace Engine::Backend::vulkan;

/* ---------------------------------------------------------------- *
 * Buffer pool — persistent device-resident allocations
 *
 * Per the EXMC port plan step 1a (PATH_TO_FULL_PASS.md): the largest
 * per-op overhead in the NIF path is vkAllocateMemory + vkFreeMemory
 * on every call. A NUTS leapfrog allocates ~30 × N steps × 1000 of
 * fresh output buffers; pool them by byte-size and cycle counts drop
 * 100-700×.
 *
 * Lifetime model (matches PERSISTENT_BUFFERS_PLAN.md decision #1):
 * caller-owned. nxv_buf_alloc returns a VkBuf (from pool when match
 * exists, fresh otherwise); nxv_buf_free returns to the pool instead
 * of releasing. The Rust ResourceArc Drop hits nxv_buf_free, so
 * tensor GC silently feeds the pool. nxv_pool_clear actually frees
 * everything back to the device — call at idle time. nxv_destroy
 * also flushes the pool.
 *
 * Concurrency: pool has its own mutex. Independent of SUBMIT_LOCK
 * which lives in the Rust NIF and serialises queue submits. Any NIF
 * call path that reaches the pool must be safe regardless of which
 * lock the caller holds.
 *
 * Eviction: soft cap per size class of POOL_CAP_PER_SIZE entries.
 * Beyond that, actually vkFreeMemory the buffer instead of pooling.
 * Prevents runaway in adversarial allocation patterns.
 * ---------------------------------------------------------------- */

/* Soft cap on idle buffers kept per byte-size class; once a class holds
 * this many, nxv_buf_free releases to the device instead of pooling. */
static const size_t POOL_CAP_PER_SIZE = 64;
/* Guards g_buf_pool; nxv_pool_stats also holds it while reading the counters. */
static std::mutex g_pool_mutex;
/* Free list: exact byte size -> idle buffers available for reuse. */
static std::map<unsigned long, std::vector<VkBuf*>> g_buf_pool;
/* Counters reported by nxv_pool_stats. */
static unsigned long g_pool_hits = 0;    /* allocations served from the pool */
static unsigned long g_pool_misses = 0;  /* allocations that fell through to alloc_fresh */
static unsigned long g_pool_freed = 0;   /* buffers actually released to the device */

/* Internal: build a brand-new device-local VkBuf without consulting the
 * pool (storage + transfer src/dst usage). Returns nullptr when spirit's
 * buf_alloc reports failure; the heap wrapper is deleted in that case. */
static VkBuf* alloc_fresh(unsigned long n_bytes) {
    VkBufferUsageFlags usage_flags = VK_BUFFER_USAGE_STORAGE_BUFFER_BIT
                                   | VK_BUFFER_USAGE_TRANSFER_SRC_BIT
                                   | VK_BUFFER_USAGE_TRANSFER_DST_BIT;
    VkMemoryPropertyFlags memory_flags = VK_MEMORY_PROPERTY_DEVICE_LOCAL_BIT;

    VkBuf* fresh = new VkBuf();
    if (buf_alloc(fresh, (VkDeviceSize) n_bytes, usage_flags, memory_flags) == 0) {
        return fresh;
    }

    delete fresh;
    return nullptr;
}

/* Internal: give a VkBuf's device resources back through buf_free and
 * delete the heap wrapper. A null pointer is a no-op. Called when the
 * pool overflows and when it is cleared. */
static void release_to_device(VkBuf* buf) {
    if (buf != nullptr) {
        buf_free(buf);
        delete buf;
    }
}

/* The selected device name, cached after nxv_init so we can hand out
 * a stable pointer to Rust. Filled by nxv_init, emptied by nxv_destroy;
 * nxv_device_name returns nullptr while it is empty. */
static std::string g_device_name;

/* Pipeline cache keyed on (spv_path, op_spec_constant). Persistent
 * across calls — first dispatch pays the create cost, subsequent ones
 * reuse the pipeline. Cleared in nxv_destroy.
 *
 * push_size for each shader family:
 *   binary    : sizeof(uint)        (n)
 *   unary     : sizeof(uint)        (n)
 *   reduce    : sizeof(uint)        (n)        — handled by spirit's reduce()
 *   matmul    : 3 * sizeof(uint)    (M, N, K)
 *   random    : 2 * sizeof(uint)    (n, seed)
 *
 * For the cache key we add n_buffers because matmul/binary share a
 * spec-const-0 path with different binding counts. */
/* Cache key for compiled pipelines. Ordered lexicographically by
 * (path, op, n_buffers) so it can be used as a std::map key. */
struct PipeKey {
    std::string path;
    unsigned int op;
    unsigned int n_buffers;

    bool operator<(const PipeKey& rhs) const {
        const int by_path = path.compare(rhs.path);
        if (by_path != 0) return by_path < 0;
        if (op != rhs.op) return op < rhs.op;
        return n_buffers < rhs.n_buffers;
    }
};
/* Owns its VkPipe values: filled by get_or_create_pipe, destroyed and
 * deleted in nxv_destroy. */
static std::map<PipeKey, VkPipe*> g_pipe_cache;

/* Return the cached pipeline for (spv_path, op, n_buffers), building and
 * caching it on first use.
 *
 *   spv_path   path handed to load_shader
 *   op         value passed to create_pipeline as the spec constant
 *   n_buffers  number of storage-buffer bindings
 *
 * Returns nullptr when the shader cannot be loaded or create_pipeline
 * fails; failures are not cached, so a later call retries.
 *
 * NOTE(review): the VkShaderModule from load_shader is never destroyed
 * here — confirm whether create_pipeline takes ownership of it. */
static VkPipe* get_or_create_pipe(const std::string& spv_path, unsigned int op,
                                  unsigned int n_buffers) {
    PipeKey key{spv_path, op, n_buffers};
    auto it = g_pipe_cache.find(key);
    if (it != g_pipe_cache.end()) return it->second;

    VkShaderModule shader = load_shader(spv_path);
    if (!shader) return nullptr;

    /* push_size: declare 128 bytes. That is the minimum
     * maxPushConstantsSize every Vulkan implementation must support, and
     * it is the largest block nxv_leapfrog_chain_synth accepts. Pushing
     * bytes outside the range declared in the pipeline layout is invalid
     * usage, so the range has to cover the biggest caller; a range wider
     * than what a given shader reads is harmless. The fixed-layout
     * families all fit well inside it (binary/unary/random/transpose
     * ≤12, reduce_axis 16, fused_elementwise 40,
     * elementwise_binary_broadcast 56, fused_elementwise_4in 72). */
    uint32_t push_size = 128;

    VkPipe* pipe = new VkPipe();
    int rc = create_pipeline(pipe, shader, n_buffers, push_size, (int32_t) op);
    if (rc != 0) {
        delete pipe;
        return nullptr;
    }

    g_pipe_cache[key] = pipe;
    return pipe;
}

extern "C" {

/* Initialise the Vulkan backend via vk_init(0) and, on success, cache the
 * selected device's name. Returns vk_init's status (0 = success). */
int nxv_init(void) {
    const int status = vk_init(0);
    if (status == 0) {
        g_device_name = g_vk_ctx.device_props.deviceName;
    }
    return status;
}

void nxv_destroy(void) {
    /* Tear down cached pipelines first (they reference the device). */
    for (auto& kv : g_pipe_cache) {
        destroy_pipeline(kv.second);
        delete kv.second;
    }
    g_pipe_cache.clear();

    /* Flush the buffer pool so vkDestroyDevice doesn't see leaked allocations. */
    nxv_pool_clear();

    vk_destroy();
    g_device_name.clear();
}

/* Pointer to the cached device name, or nullptr when no name is cached. */
const char* nxv_device_name(void) {
    if (g_device_name.empty()) {
        return nullptr;
    }
    return g_device_name.c_str();
}

/* 1 when the backend context reports float64 support, 0 otherwise. */
int nxv_has_f64(void) {
    if (g_vk_ctx.has_float64) {
        return 1;
    }
    return 0;
}

/* Read a serialized pipeline cache from `path` and hand it to
 * pipeline_cache_create.
 *
 * Returns:
 *    0   file does not exist / cannot be opened (start with empty cache)
 *   -1   path is null
 *   -2   file size could not be determined (seek/tell failed)
 *   -3   short read
 *   otherwise whatever pipeline_cache_create returns. */
int nxv_pipeline_cache_load(const char* path) {
    if (!path) return -1;
    FILE* f = fopen(path, "rb");
    if (!f) return 0;  // missing file is OK — start with empty cache

    /* Check every positioning call: on a non-seekable stream an ignored
     * fseek failure would otherwise leave us trusting a bogus size. */
    long sz = -1;
    if (fseek(f, 0, SEEK_END) == 0) sz = ftell(f);
    if (sz < 0 || fseek(f, 0, SEEK_SET) != 0) {
        fclose(f);
        return -2;
    }

    std::vector<uint8_t> buf((size_t) sz);
    size_t got = buf.empty() ? 0 : fread(buf.data(), 1, buf.size(), f);
    fclose(f);
    if (got != buf.size()) return -3;

    return Engine::Backend::vulkan::pipeline_cache_create(buf.data(), buf.size());
}

/* Serialize the pipeline cache to `path` atomically (write to a
 * pid-suffixed temp file next to it, then rename over the target).
 *
 * Returns:
 *    0   success, or nothing to persist
 *   -1   path is null
 *   -2   size query failed
 *   -3   data query failed
 *   -4   temp file could not be opened
 *   -5   write, flush or close of the temp file failed
 *   -6   rename failed */
int nxv_pipeline_cache_persist(const char* path) {
    if (!path) return -1;

    size_t sz = 0;
    int rc = Engine::Backend::vulkan::pipeline_cache_get_data(nullptr, &sz);
    if (rc != 0) return -2;
    if (sz == 0) return 0;  // nothing to persist

    std::vector<uint8_t> buf(sz);
    rc = Engine::Backend::vulkan::pipeline_cache_get_data(buf.data(), &sz);
    if (rc != 0) return -3;

    // Atomic write: temp file in same dir then rename.
    std::string tmp(path);
    tmp += ".tmp.";
    tmp += std::to_string((unsigned long) getpid());

    FILE* f = fopen(tmp.c_str(), "wb");
    if (!f) return -4;

    /* Buffered data may only hit the disk at fflush/fclose, so a full
     * disk can surface there rather than at fwrite. All three must
     * succeed before the temp file is allowed to replace the real one.
     * fclose is always called so the handle never leaks. */
    const bool write_ok = fwrite(buf.data(), 1, sz, f) == sz;
    const bool flush_ok = fflush(f) == 0;
    const bool close_ok = fclose(f) == 0;
    if (!write_ok || !flush_ok || !close_ok) {
        std::remove(tmp.c_str());
        return -5;
    }

    if (std::rename(tmp.c_str(), path) != 0) {
        std::remove(tmp.c_str());
        return -6;
    }
    return 0;
}

/* Copy the device's 16-byte pipelineCacheUUID into `out`.
 * Returns 0 on success, -1 if `out` is null, -2 if no device exists. */
int nxv_device_uuid(unsigned char out[16]) {
    if (out == nullptr) return -1;

    auto& ctx = Engine::Backend::vulkan::g_vk_ctx;
    if (!ctx.device) return -2;

    std::memcpy(out, ctx.device_props.pipelineCacheUUID, 16);
    return 0;
}

/* C-ABI forwarder to the backend's timing_reset(). */
void nxv_timing_reset(void)
{
    Engine::Backend::vulkan::timing_reset();
}

/* Fetch the backend's timing counters via timing_get. Every out-pointer
 * is optional; null ones are skipped. */
void nxv_timing_get(unsigned long long* count,
                    unsigned long long* dispatch_ns,
                    unsigned long long* submit_ns,
                    unsigned long long* wait_ns,
                    unsigned long long* record_ns) {
    uint64_t totals[5] = {0, 0, 0, 0, 0};
    Engine::Backend::vulkan::timing_get(&totals[0], &totals[1], &totals[2],
                                        &totals[3], &totals[4]);

    unsigned long long* const sinks[5] = {
        count, dispatch_ns, submit_ns, wait_ns, record_ns
    };
    for (int i = 0; i < 5; ++i) {
        if (sinks[i] != nullptr) {
            *sinks[i] = (unsigned long long) totals[i];
        }
    }
}

int nxv_buf_download_batch(void** srcs, void** out_data,
                           const unsigned long* sizes, unsigned int n_buffers) {
    if (!srcs || !out_data || !sizes || n_buffers == 0) return -1;
    std::vector<Engine::Backend::vulkan::VkBuf*> bufs(n_buffers);
    std::vector<VkDeviceSize> vsizes(n_buffers);
    for (unsigned int i = 0; i < n_buffers; i++) {
        bufs[i] = (Engine::Backend::vulkan::VkBuf*) srcs[i];
        vsizes[i] = (VkDeviceSize) sizes[i];
    }
    return Engine::Backend::vulkan::download_batch(
        bufs.data(), out_data, vsizes.data(), n_buffers);
}

int nxv_buf_upload_batch(void** dsts, const void** data,
                         const unsigned long* sizes, unsigned int n_buffers) {
    if (!dsts || !data || !sizes || n_buffers == 0) return -1;
    std::vector<Engine::Backend::vulkan::VkBuf*> bufs(n_buffers);
    std::vector<VkDeviceSize> vsizes(n_buffers);
    for (unsigned int i = 0; i < n_buffers; i++) {
        bufs[i] = (Engine::Backend::vulkan::VkBuf*) dsts[i];
        vsizes[i] = (VkDeviceSize) sizes[i];
    }
    return Engine::Backend::vulkan::upload_batch(
        bufs.data(), data, vsizes.data(), n_buffers);
}

/* Tensor primitives — heap-allocate a VkBuf so the handle survives
 * across NIF calls. Lifetime is owned by the Rust ResourceArc; when
 * the Elixir reference is GC'd, Rust calls nxv_buf_free, which returns
 * the buffer to the pool (or, once that size class is full, releases
 * it via spirit's buf_free + delete). */

/* Hand out a VkBuf of exactly n_bytes: a pooled one when a buffer of that
 * size is idle, otherwise a fresh device allocation.
 * Returns the handle, or nullptr when a fresh allocation fails. A reused
 * buffer is returned as-is; nothing here clears its contents. */
void* nxv_buf_alloc(unsigned long n_bytes) {
    {
        std::lock_guard<std::mutex> lk(g_pool_mutex);

        /* Pool fast path: if a buffer of exactly this size is free, reuse. */
        auto it = g_buf_pool.find(n_bytes);
        if (it != g_buf_pool.end() && !it->second.empty()) {
            VkBuf* reused = it->second.back();
            it->second.pop_back();
            g_pool_hits++;
            return (void*) reused;
        }

        /* Count the miss while still holding the lock: nxv_pool_stats
         * reads the counters under g_pool_mutex, so an unlocked increment
         * here would be a data race. */
        g_pool_misses++;
    }

    /* Slow path: actually call vkAllocateMemory (outside the lock, so a
     * slow device allocation does not stall other pool users). */
    VkBuf* buf = alloc_fresh(n_bytes);
    return (void*) buf;
}

/* Give a buffer back. It is parked in the pool under its recorded size
 * unless that size class already holds POOL_CAP_PER_SIZE entries, in
 * which case it is released to the device. Null is a no-op. Either way
 * the caller no longer owns the handle afterwards. */
void nxv_buf_free(void* handle) {
    if (!handle) return;
    VkBuf* buf = (VkBuf*) handle;

    /* Match the n_bytes back from the buffer's recorded size. spirit's
     * VkBuf carries `VkDeviceSize size` after buf_alloc. */
    unsigned long n_bytes = (unsigned long) buf->size;

    bool pooled = false;
    {
        std::lock_guard<std::mutex> lk(g_pool_mutex);
        auto& slot = g_buf_pool[n_bytes];
        if (slot.size() < POOL_CAP_PER_SIZE) {
            slot.push_back(buf);
            pooled = true;
        } else {
            /* Bump the counter under the lock — nxv_pool_stats reads it
             * under g_pool_mutex, so an unlocked increment would race. */
            g_pool_freed++;
        }
    }

    /* Pool is full for this size class — actually release. Done outside
     * the lock so the device call does not block other pool users. */
    if (!pooled) release_to_device(buf);
}

void nxv_pool_clear(void) {
    std::lock_guard<std::mutex> lk(g_pool_mutex);
    for (auto& kv : g_buf_pool) {
        for (VkBuf* buf : kv.second) {
            release_to_device(buf);
            g_pool_freed++;
        }
        kv.second.clear();
    }
    g_buf_pool.clear();
}

/* Snapshot the pool counters under g_pool_mutex. Every out-pointer is
 * optional. size_classes is the number of distinct byte sizes tracked;
 * total_pooled is the number of idle buffers across all of them. */
void nxv_pool_stats(unsigned long* hits, unsigned long* misses,
                    unsigned long* freed, unsigned long* size_classes,
                    unsigned long* total_pooled) {
    std::lock_guard<std::mutex> guard(g_pool_mutex);

    if (hits != nullptr)         *hits = g_pool_hits;
    if (misses != nullptr)       *misses = g_pool_misses;
    if (freed != nullptr)        *freed = g_pool_freed;
    if (size_classes != nullptr) *size_classes = g_buf_pool.size();

    if (total_pooled == nullptr) return;

    unsigned long idle = 0;
    for (auto it = g_buf_pool.begin(); it != g_buf_pool.end(); ++it) {
        idle += it->second.size();
    }
    *total_pooled = idle;
}

/* Copy n_bytes from host `data` into the buffer behind `handle` via
 * upload(). Returns -1 on a null argument, else upload's result. */
int nxv_buf_upload(void* handle, const void* data, unsigned long n_bytes) {
    if (handle == nullptr || data == nullptr) {
        return -1;
    }
    return upload((VkBuf*) handle, data, (VkDeviceSize) n_bytes);
}

/* Copy n_bytes from the buffer behind `handle` into host `data` via
 * download(). Returns -1 on a null argument, else download's result. */
int nxv_buf_download(void* handle, void* data, unsigned long n_bytes) {
    if (handle == nullptr || data == nullptr) {
        return -1;
    }
    return download((VkBuf*) handle, data, (VkDeviceSize) n_bytes);
}

int nxv_apply_binary(void* out, void* a, void* b,
                     unsigned int n, unsigned int op,
                     const char* spv_path) {
    if (!out || !a || !b || !spv_path) return -1;

    VkPipe* pipe = get_or_create_pipe(std::string(spv_path), op, 3);
    if (!pipe) return -2;

    VkBuf* buf_a   = (VkBuf*) a;
    VkBuf* buf_b   = (VkBuf*) b;
    VkBuf* buf_out = (VkBuf*) out;

    /* Shader binding order: a, b, out. Push constant: n. */
    VkBuffer bufs[3] = { buf_a->buffer, buf_b->buffer, buf_out->buffer };
    unsigned int push_n = n;
    unsigned int groups = (n + 255) / 256;

    return dispatch(pipe, bufs, 3, groups, sizeof(unsigned int), &push_n);
}

int nxv_apply_unary(void* out, void* a,
                    unsigned int n, unsigned int op,
                    const char* spv_path) {
    if (!out || !a || !spv_path) return -1;
    VkPipe* pipe = get_or_create_pipe(std::string(spv_path), op, 2);
    if (!pipe) return -2;

    VkBuf* buf_a   = (VkBuf*) a;
    VkBuf* buf_out = (VkBuf*) out;

    /* Shader binding order: a, out. Push constant: n. */
    VkBuffer bufs[2] = { buf_a->buffer, buf_out->buffer };
    unsigned int push_n = n;
    unsigned int groups = (n + 255) / 256;

    return dispatch(pipe, bufs, 2, groups, sizeof(unsigned int), &push_n);
}

int nxv_reduce(float* out_scalar, void* in, unsigned int n, unsigned int op,
               const char* spv_path) {
    if (!out_scalar || !in || !spv_path) return -1;
    VkBuf* buf_in = (VkBuf*) in;
    *out_scalar = reduce(buf_in, (int) n, (ReduceOp) op, std::string(spv_path));
    return 0;
}

int nxv_matmul(void* out, void* a, void* b,
               unsigned int m, unsigned int n, unsigned int k,
               const char* spv_path) {
    if (!out || !a || !b || !spv_path) return -1;

    /* matmul has no spec constant; cache key is just the spv path. */
    VkPipe* pipe = get_or_create_pipe(std::string(spv_path), 0, 3);
    if (!pipe) return -2;

    VkBuf* buf_a   = (VkBuf*) a;
    VkBuf* buf_b   = (VkBuf*) b;
    VkBuf* buf_out = (VkBuf*) out;

    /* matmul uses 2D dispatch (gx, gy) but spirit's dispatch helper is
     * 1D-only. Inline the dispatch dance — same pattern as
     * test_matmul.cpp. Push constants: M, N, K (12 bytes). */
    auto& ctx = g_vk_ctx;

    VkBuffer bufs[3] = { buf_a->buffer, buf_b->buffer, buf_out->buffer };

    VkDescriptorBufferInfo bi[3];
    VkWriteDescriptorSet w[3];
    for (int i = 0; i < 3; i++) {
        bi[i] = {bufs[i], 0, VK_WHOLE_SIZE};
        w[i] = {};
        w[i].sType = VK_STRUCTURE_TYPE_WRITE_DESCRIPTOR_SET;
        w[i].dstSet = pipe->descriptor_set;
        w[i].dstBinding = (uint32_t) i;
        w[i].descriptorCount = 1;
        w[i].descriptorType = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER;
        w[i].pBufferInfo = &bi[i];
    }
    vkUpdateDescriptorSets(ctx.device, 3, w, 0, nullptr);

    VkCommandBufferAllocateInfo ai{};
    ai.sType = VK_STRUCTURE_TYPE_COMMAND_BUFFER_ALLOCATE_INFO;
    ai.commandPool = ctx.command_pool;
    ai.level = VK_COMMAND_BUFFER_LEVEL_PRIMARY;
    ai.commandBufferCount = 1;
    VkCommandBuffer cmd;
    vkAllocateCommandBuffers(ctx.device, &ai, &cmd);

    VkCommandBufferBeginInfo bb{};
    bb.sType = VK_STRUCTURE_TYPE_COMMAND_BUFFER_BEGIN_INFO;
    bb.flags = VK_COMMAND_BUFFER_USAGE_ONE_TIME_SUBMIT_BIT;
    vkBeginCommandBuffer(cmd, &bb);
    vkCmdBindPipeline(cmd, VK_PIPELINE_BIND_POINT_COMPUTE, pipe->pipeline);
    vkCmdBindDescriptorSets(cmd, VK_PIPELINE_BIND_POINT_COMPUTE,
                            pipe->pipeline_layout, 0, 1, &pipe->descriptor_set, 0, nullptr);

    unsigned int push[3] = { m, n, k };
    vkCmdPushConstants(cmd, pipe->pipeline_layout, VK_SHADER_STAGE_COMPUTE_BIT,
                       0, sizeof(push), push);

    unsigned int gx = (n + 15) / 16;
    unsigned int gy = (m + 15) / 16;
    vkCmdDispatch(cmd, gx, gy, 1);
    vkEndCommandBuffer(cmd);

    VkSubmitInfo si{};
    si.sType = VK_STRUCTURE_TYPE_SUBMIT_INFO;
    si.commandBufferCount = 1;
    si.pCommandBuffers = &cmd;
    vkQueueSubmit(ctx.compute_queue, 1, &si, VK_NULL_HANDLE);
    vkQueueWaitIdle(ctx.compute_queue);
    vkFreeCommandBuffers(ctx.device, ctx.command_pool, 1, &cmd);

    return 0;
}

int nxv_matmul_v(void* out, void* a, void* b,
                 unsigned int m, unsigned int n, unsigned int k,
                 unsigned int tile_m, unsigned int tile_n,
                 const char* spv_path) {
    if (!out || !a || !b || !spv_path) return -1;

    /* Cache key on path only — matmul has no spec constant. */
    VkPipe* pipe = get_or_create_pipe(std::string(spv_path), 0, 3);
    if (!pipe) return -2;

    VkBuf* buf_a   = (VkBuf*) a;
    VkBuf* buf_b   = (VkBuf*) b;
    VkBuf* buf_out = (VkBuf*) out;

    /* Same dispatch dance as nxv_matmul; only the grid changes. */
    auto& ctx = g_vk_ctx;

    VkBuffer bufs[3] = { buf_a->buffer, buf_b->buffer, buf_out->buffer };

    VkDescriptorBufferInfo bi[3];
    VkWriteDescriptorSet w[3];
    for (int i = 0; i < 3; i++) {
        bi[i] = {bufs[i], 0, VK_WHOLE_SIZE};
        w[i] = {};
        w[i].sType = VK_STRUCTURE_TYPE_WRITE_DESCRIPTOR_SET;
        w[i].dstSet = pipe->descriptor_set;
        w[i].dstBinding = (uint32_t) i;
        w[i].descriptorCount = 1;
        w[i].descriptorType = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER;
        w[i].pBufferInfo = &bi[i];
    }
    vkUpdateDescriptorSets(ctx.device, 3, w, 0, nullptr);

    VkCommandBufferAllocateInfo ai{};
    ai.sType = VK_STRUCTURE_TYPE_COMMAND_BUFFER_ALLOCATE_INFO;
    ai.commandPool = ctx.command_pool;
    ai.level = VK_COMMAND_BUFFER_LEVEL_PRIMARY;
    ai.commandBufferCount = 1;
    VkCommandBuffer cmd;
    vkAllocateCommandBuffers(ctx.device, &ai, &cmd);

    VkCommandBufferBeginInfo bb{};
    bb.sType = VK_STRUCTURE_TYPE_COMMAND_BUFFER_BEGIN_INFO;
    bb.flags = VK_COMMAND_BUFFER_USAGE_ONE_TIME_SUBMIT_BIT;
    vkBeginCommandBuffer(cmd, &bb);
    vkCmdBindPipeline(cmd, VK_PIPELINE_BIND_POINT_COMPUTE, pipe->pipeline);
    vkCmdBindDescriptorSets(cmd, VK_PIPELINE_BIND_POINT_COMPUTE,
                            pipe->pipeline_layout, 0, 1, &pipe->descriptor_set, 0, nullptr);

    unsigned int push[3] = { m, n, k };
    vkCmdPushConstants(cmd, pipe->pipeline_layout, VK_SHADER_STAGE_COMPUTE_BIT,
                       0, sizeof(push), push);

    unsigned int gx = (n + tile_n - 1) / tile_n;
    unsigned int gy = (m + tile_m - 1) / tile_m;
    vkCmdDispatch(cmd, gx, gy, 1);
    vkEndCommandBuffer(cmd);

    VkSubmitInfo si{};
    si.sType = VK_STRUCTURE_TYPE_SUBMIT_INFO;
    si.commandBufferCount = 1;
    si.pCommandBuffers = &cmd;
    vkQueueSubmit(ctx.compute_queue, 1, &si, VK_NULL_HANDLE);
    vkQueueWaitIdle(ctx.compute_queue);
    vkFreeCommandBuffers(ctx.device, ctx.command_pool, 1, &cmd);

    return 0;
}

int nxv_random(void* out, unsigned int n, unsigned int seed, unsigned int dist,
               const char* spv_path) {
    if (!out || !spv_path) return -1;
    VkPipe* pipe = get_or_create_pipe(std::string(spv_path), dist, 1);
    if (!pipe) return -2;

    VkBuf* buf_out = (VkBuf*) out;

    /* Shader binding: out only. Push constants: {n, seed} = 8 bytes. */
    VkBuffer bufs[1] = { buf_out->buffer };
    struct { unsigned int n; unsigned int seed; } push = { n, seed };
    unsigned int groups = (n + 255) / 256;

    return dispatch(pipe, bufs, 1, groups, sizeof(push), &push);
}

int nxv_cast(void* out, void* a, unsigned int n, const char* spv_path) {
    if (!out || !a || !spv_path) return -1;
    VkPipe* pipe = get_or_create_pipe(std::string(spv_path), 0, 2);
    if (!pipe) return -2;

    VkBuf* buf_a   = (VkBuf*) a;
    VkBuf* buf_out = (VkBuf*) out;

    /* Shader binding order: a, out. Push: n. */
    VkBuffer bufs[2] = { buf_a->buffer, buf_out->buffer };
    unsigned int push_n = n;
    unsigned int groups = (n + 255) / 256;

    return dispatch(pipe, bufs, 2, groups, sizeof(unsigned int), &push_n);
}

int nxv_reduce_axis(void* out, void* a,
                    unsigned int outer, unsigned int reduce_size, unsigned int inner,
                    unsigned int op,
                    const char* spv_path) {
    if (!out || !a || !spv_path) return -1;
    VkPipe* pipe = get_or_create_pipe(std::string(spv_path), 0, 2);
    if (!pipe) return -2;

    VkBuf* buf_a   = (VkBuf*) a;
    VkBuf* buf_out = (VkBuf*) out;

    /* Shader binding order: a, out. Push: {outer, reduce_size, inner, op}. */
    VkBuffer bufs[2] = { buf_a->buffer, buf_out->buffer };
    unsigned int push[4] = { outer, reduce_size, inner, op };
    unsigned int n_slots = outer * inner;
    unsigned int groups = (n_slots + 255) / 256;

    return dispatch(pipe, bufs, 2, groups, sizeof(push), push);
}

/* Element-wise binary op with broadcasting, for outputs of rank <= 4.
 *
 *   op         selects the pipeline variant (spec constant)
 *   ndim       output rank, at most 4
 *   out_shape  4 entries; the first ndim are multiplied to get the
 *              output element count
 *   a_strides  4 entries
 *   b_strides  4 entries
 *
 * All four entries of each array are read and pushed to the shader, so
 * callers must pass 4-element arrays (the original comment says entries
 * beyond ndim are zero-padded).
 *
 * Bindings: a, b, out.
 * Returns -1 on a null argument, -3 if ndim > 4, -2 if the pipeline is
 * unavailable, else dispatch's result. */
int nxv_apply_binary_broadcast(void* out, void* a, void* b,
                                unsigned int op, unsigned int ndim,
                                const unsigned int* out_shape,
                                const unsigned int* a_strides,
                                const unsigned int* b_strides,
                                const char* spv_path) {
    if (!out || !a || !b || !out_shape || !a_strides || !b_strides || !spv_path)
        return -1;
    /* The shape/stride arrays hold exactly 4 entries; a larger ndim would
     * read past the end of out_shape below. */
    if (ndim > 4) return -3;

    VkPipe* pipe = get_or_create_pipe(std::string(spv_path), op, 3);
    if (!pipe) return -2;

    VkBuf* buf_a   = (VkBuf*) a;
    VkBuf* buf_b   = (VkBuf*) b;
    VkBuf* buf_out = (VkBuf*) out;

    /* Total output count: product of the first ndim shape entries. */
    unsigned int n = 1;
    for (unsigned int d = 0; d < ndim; d++) n *= out_shape[d];

    /* Push: {n, ndim, out_shape[4], a_strides[4], b_strides[4]} = 56 bytes. */
    struct {
        unsigned int n;
        unsigned int ndim;
        unsigned int out_shape[4];
        unsigned int a_strides[4];
        unsigned int b_strides[4];
    } push;
    push.n = n;
    push.ndim = ndim;
    for (int i = 0; i < 4; i++) {
        push.out_shape[i] = out_shape[i];
        push.a_strides[i] = a_strides[i];
        push.b_strides[i] = b_strides[i];
    }

    VkBuffer bufs[3] = { buf_a->buffer, buf_b->buffer, buf_out->buffer };
    unsigned int groups = (n + 255) / 256;

    return dispatch(pipe, bufs, 3, groups, sizeof(push), &push);
}

int nxv_kinetic_energy(void* out, void* p, void* inv_mass,
                        unsigned int n, const char* spv_path) {
    if (!out || !p || !inv_mass || !spv_path) return -1;
    /* 3 buffers: p, inv_mass, out. */
    VkPipe* pipe = get_or_create_pipe(std::string(spv_path), 0, 3);
    if (!pipe) return -2;

    VkBuf* buf_p   = (VkBuf*) p;
    VkBuf* buf_m   = (VkBuf*) inv_mass;
    VkBuf* buf_out = (VkBuf*) out;

    VkBuffer bufs[3] = { buf_p->buffer, buf_m->buffer, buf_out->buffer };
    unsigned int push_n = n;
    unsigned int groups = (n + 255) / 256;

    return dispatch(pipe, bufs, 3, groups, sizeof(unsigned int), &push_n);
}

int nxv_normal_logpdf(void* out, void* x, void* mu, void* sigma,
                       unsigned int n, const char* spv_path) {
    if (!out || !x || !mu || !sigma || !spv_path) return -1;
    /* The shader has 4 buffers: x, mu, sigma, out. */
    VkPipe* pipe = get_or_create_pipe(std::string(spv_path), 0, 4);
    if (!pipe) return -2;

    VkBuf* buf_x   = (VkBuf*) x;
    VkBuf* buf_mu  = (VkBuf*) mu;
    VkBuf* buf_s   = (VkBuf*) sigma;
    VkBuf* buf_out = (VkBuf*) out;

    VkBuffer bufs[4] = { buf_x->buffer, buf_mu->buffer, buf_s->buffer, buf_out->buffer };
    unsigned int push_n = n;
    unsigned int groups = (n + 255) / 256;

    return dispatch(pipe, bufs, 4, groups, sizeof(unsigned int), &push_n);
}

int nxv_leapfrog_chain_normal_lg(void* q_chain, void* p_chain,
                                  void* grad_chain, void* partial_logp,
                                  void* q_init, void* p_init, void* inv_mass,
                                  unsigned int n, unsigned int K,
                                  unsigned int num_workgroups,
                                  float eps, float mu, float sigma,
                                  const char* spv_path) {
    if (!q_chain || !p_chain || !grad_chain || !partial_logp ||
        !q_init || !p_init || !inv_mass || !spv_path) return -1;
    /* Same 7-buffer binding order as the single-workgroup variant. */
    VkPipe* pipe = get_or_create_pipe(std::string(spv_path), 0, 7);
    if (!pipe) return -2;

    VkBuf* buf_qi = (VkBuf*) q_init;
    VkBuf* buf_pi = (VkBuf*) p_init;
    VkBuf* buf_m  = (VkBuf*) inv_mass;
    VkBuf* buf_qc = (VkBuf*) q_chain;
    VkBuf* buf_pc = (VkBuf*) p_chain;
    VkBuf* buf_gc = (VkBuf*) grad_chain;
    VkBuf* buf_pl = (VkBuf*) partial_logp;

    VkBuffer bufs[7] = {
        buf_qi->buffer, buf_pi->buffer, buf_m->buffer,
        buf_qc->buffer, buf_pc->buffer, buf_gc->buffer, buf_pl->buffer
    };

    /* Push: {n, K, num_workgroups, eps, mu, sigma} = 24 bytes. */
    struct {
        unsigned int n;
        unsigned int K;
        unsigned int num_workgroups;
        float eps;
        float mu;
        float sigma;
    } push;
    push.n = n;
    push.K = K;
    push.num_workgroups = num_workgroups;
    push.eps = eps;
    push.mu = mu;
    push.sigma = sigma;

    /* Multi-workgroup: dispatch ceil(n/256) workgroups. */
    return dispatch(pipe, bufs, 7, num_workgroups, sizeof(push), &push);
}

int nxv_leapfrog_chain_exponential(void* q_chain, void* p_chain,
                                    void* grad_chain, void* logp_chain,
                                    void* q_init, void* p_init, void* inv_mass,
                                    unsigned int n, unsigned int K,
                                    float eps, float lambda,
                                    const char* spv_path) {
    if (!q_chain || !p_chain || !grad_chain || !logp_chain ||
        !q_init || !p_init || !inv_mass || !spv_path) return -1;
    /* Same 7-buffer binding order as Normal chain. */
    VkPipe* pipe = get_or_create_pipe(std::string(spv_path), 0, 7);
    if (!pipe) return -2;

    VkBuf* buf_qi = (VkBuf*) q_init;
    VkBuf* buf_pi = (VkBuf*) p_init;
    VkBuf* buf_m  = (VkBuf*) inv_mass;
    VkBuf* buf_qc = (VkBuf*) q_chain;
    VkBuf* buf_pc = (VkBuf*) p_chain;
    VkBuf* buf_gc = (VkBuf*) grad_chain;
    VkBuf* buf_lc = (VkBuf*) logp_chain;

    VkBuffer bufs[7] = {
        buf_qi->buffer, buf_pi->buffer, buf_m->buffer,
        buf_qc->buffer, buf_pc->buffer, buf_gc->buffer, buf_lc->buffer
    };

    /* Push: {n, K, eps, lambda} = 16 bytes. */
    struct {
        unsigned int n;
        unsigned int K;
        float eps;
        float lambda;
    } push;
    push.n = n;
    push.K = K;
    push.eps = eps;
    push.lambda = lambda;

    /* Single workgroup (n <= 256). */
    return dispatch(pipe, bufs, 7, 1, sizeof(push), &push);
}

int nxv_leapfrog_chain_studentt(void* q_chain, void* p_chain,
                                 void* grad_chain, void* logp_chain,
                                 void* q_init, void* p_init, void* inv_mass,
                                 unsigned int n, unsigned int K,
                                 float eps, float mu, float sigma,
                                 float nu, float logp_const,
                                 const char* spv_path) {
    if (!q_chain || !p_chain || !grad_chain || !logp_chain ||
        !q_init || !p_init || !inv_mass || !spv_path) return -1;
    VkPipe* pipe = get_or_create_pipe(std::string(spv_path), 0, 7);
    if (!pipe) return -2;

    VkBuffer bufs[7] = {
        ((VkBuf*) q_init)->buffer, ((VkBuf*) p_init)->buffer, ((VkBuf*) inv_mass)->buffer,
        ((VkBuf*) q_chain)->buffer, ((VkBuf*) p_chain)->buffer,
        ((VkBuf*) grad_chain)->buffer, ((VkBuf*) logp_chain)->buffer
    };

    struct { unsigned int n, K; float eps, mu, sigma, nu, logp_const; } push;
    push.n = n; push.K = K; push.eps = eps; push.mu = mu;
    push.sigma = sigma; push.nu = nu; push.logp_const = logp_const;

    return dispatch(pipe, bufs, 7, 1, sizeof(push), &push);
}

int nxv_leapfrog_chain_cauchy(void* q_chain, void* p_chain,
                               void* grad_chain, void* logp_chain,
                               void* q_init, void* p_init, void* inv_mass,
                               unsigned int n, unsigned int K,
                               float eps, float loc, float scale,
                               float log_pi_scale,
                               const char* spv_path) {
    if (!q_chain || !p_chain || !grad_chain || !logp_chain ||
        !q_init || !p_init || !inv_mass || !spv_path) return -1;
    VkPipe* pipe = get_or_create_pipe(std::string(spv_path), 0, 7);
    if (!pipe) return -2;

    VkBuffer bufs[7] = {
        ((VkBuf*) q_init)->buffer, ((VkBuf*) p_init)->buffer, ((VkBuf*) inv_mass)->buffer,
        ((VkBuf*) q_chain)->buffer, ((VkBuf*) p_chain)->buffer,
        ((VkBuf*) grad_chain)->buffer, ((VkBuf*) logp_chain)->buffer
    };

    struct { unsigned int n, K; float eps, loc, scale, log_pi_scale; } push;
    push.n = n; push.K = K; push.eps = eps; push.loc = loc;
    push.scale = scale; push.log_pi_scale = log_pi_scale;

    return dispatch(pipe, bufs, 7, 1, sizeof(push), &push);
}

int nxv_leapfrog_chain_weibull(void* q_chain, void* p_chain,
                                void* grad_chain, void* logp_chain,
                                void* q_init, void* p_init, void* inv_mass,
                                unsigned int n, unsigned int K,
                                float eps, float k, float lambda,
                                float logp_const,
                                const char* spv_path) {
    if (!q_chain || !p_chain || !grad_chain || !logp_chain ||
        !q_init || !p_init || !inv_mass || !spv_path) return -1;
    VkPipe* pipe = get_or_create_pipe(std::string(spv_path), 0, 7);
    if (!pipe) return -2;

    VkBuffer bufs[7] = {
        ((VkBuf*) q_init)->buffer, ((VkBuf*) p_init)->buffer, ((VkBuf*) inv_mass)->buffer,
        ((VkBuf*) q_chain)->buffer, ((VkBuf*) p_chain)->buffer,
        ((VkBuf*) grad_chain)->buffer, ((VkBuf*) logp_chain)->buffer
    };

    /* Push: {n, K, eps, k, lambda, logp_const} = 24 bytes. */
    struct { unsigned int n, K; float eps, k, lambda, logp_const; } push;
    push.n = n; push.K = K; push.eps = eps;
    push.k = k; push.lambda = lambda; push.logp_const = logp_const;

    return dispatch(pipe, bufs, 7, 1, sizeof(push), &push);
}

int nxv_leapfrog_chain_halfnormal(void* q_chain, void* p_chain,
                                   void* grad_chain, void* logp_chain,
                                   void* q_init, void* p_init, void* inv_mass,
                                   unsigned int n, unsigned int K,
                                   float eps, float sigma, float log_const,
                                   const char* spv_path) {
    if (!q_chain || !p_chain || !grad_chain || !logp_chain ||
        !q_init || !p_init || !inv_mass || !spv_path) return -1;
    VkPipe* pipe = get_or_create_pipe(std::string(spv_path), 0, 7);
    if (!pipe) return -2;

    VkBuffer bufs[7] = {
        ((VkBuf*) q_init)->buffer, ((VkBuf*) p_init)->buffer, ((VkBuf*) inv_mass)->buffer,
        ((VkBuf*) q_chain)->buffer, ((VkBuf*) p_chain)->buffer,
        ((VkBuf*) grad_chain)->buffer, ((VkBuf*) logp_chain)->buffer
    };

    struct { unsigned int n, K; float eps, sigma, log_const; } push;
    push.n = n; push.K = K; push.eps = eps;
    push.sigma = sigma; push.log_const = log_const;

    return dispatch(pipe, bufs, 7, 1, sizeof(push), &push);
}

int nxv_leapfrog_chain_normal_f64(void* q_chain, void* p_chain,
                                   void* grad_chain, void* logp_chain,
                                   void* q_init, void* p_init, void* inv_mass,
                                   unsigned int n, unsigned int K,
                                   double eps, double mu, double sigma,
                                   const char* spv_path) {
    if (!q_chain || !p_chain || !grad_chain || !logp_chain ||
        !q_init || !p_init || !inv_mass || !spv_path) return -1;
    VkPipe* pipe = get_or_create_pipe(std::string(spv_path), 0, 7);
    if (!pipe) return -2;

    VkBuffer bufs[7] = {
        ((VkBuf*) q_init)->buffer, ((VkBuf*) p_init)->buffer, ((VkBuf*) inv_mass)->buffer,
        ((VkBuf*) q_chain)->buffer, ((VkBuf*) p_chain)->buffer,
        ((VkBuf*) grad_chain)->buffer, ((VkBuf*) logp_chain)->buffer
    };

    /* Push: {uint n, K; double eps, mu, sigma} = 4 + 4 + 8 + 8 + 8 = 32 bytes. */
    struct { unsigned int n, K; double eps, mu, sigma; } push;
    push.n = n; push.K = K; push.eps = eps; push.mu = mu; push.sigma = sigma;

    return dispatch(pipe, bufs, 7, 1, sizeof(push), &push);
}

int nxv_leapfrog_chain_synth(void* q_chain, void* p_chain,
                              void* grad_chain, void* logp_chain,
                              void* q_init, void* p_init, void* inv_mass,
                              const void* push_data, unsigned int push_size,
                              const char* spv_path) {
    if (!q_chain || !p_chain || !grad_chain || !logp_chain ||
        !q_init || !p_init || !inv_mass || !push_data || !spv_path) return -1;
    if (push_size == 0 || push_size > 128) return -3;

    VkPipe* pipe = get_or_create_pipe(std::string(spv_path), 0, 7);
    if (!pipe) return -2;

    VkBuffer bufs[7] = {
        ((VkBuf*) q_init)->buffer, ((VkBuf*) p_init)->buffer, ((VkBuf*) inv_mass)->buffer,
        ((VkBuf*) q_chain)->buffer, ((VkBuf*) p_chain)->buffer,
        ((VkBuf*) grad_chain)->buffer, ((VkBuf*) logp_chain)->buffer
    };

    return dispatch(pipe, bufs, 7, 1, push_size, push_data);
}

int nxv_leapfrog_chain_normal(void* q_chain, void* p_chain,
                               void* grad_chain, void* logp_chain,
                               void* q_init, void* p_init, void* inv_mass,
                               unsigned int n, unsigned int K,
                               float eps, float mu, float sigma,
                               const char* spv_path) {
    if (!q_chain || !p_chain || !grad_chain || !logp_chain ||
        !q_init || !p_init || !inv_mass || !spv_path) return -1;
    /* 7 buffers in shader binding order: q_init, p_init, inv_mass,
     * q_chain, p_chain, grad_chain, logp_chain. */
    VkPipe* pipe = get_or_create_pipe(std::string(spv_path), 0, 7);
    if (!pipe) return -2;

    VkBuf* buf_qi = (VkBuf*) q_init;
    VkBuf* buf_pi = (VkBuf*) p_init;
    VkBuf* buf_m  = (VkBuf*) inv_mass;
    VkBuf* buf_qc = (VkBuf*) q_chain;
    VkBuf* buf_pc = (VkBuf*) p_chain;
    VkBuf* buf_gc = (VkBuf*) grad_chain;
    VkBuf* buf_lc = (VkBuf*) logp_chain;

    VkBuffer bufs[7] = {
        buf_qi->buffer, buf_pi->buffer, buf_m->buffer,
        buf_qc->buffer, buf_pc->buffer, buf_gc->buffer, buf_lc->buffer
    };

    /* Push: {uint n; uint K; float eps; float mu; float sigma} = 20 bytes. */
    struct {
        unsigned int n;
        unsigned int K;
        float eps;
        float mu;
        float sigma;
    } push;
    push.n = n;
    push.K = K;
    push.eps = eps;
    push.mu = mu;
    push.sigma = sigma;

    /* Single workgroup of 256 threads (assumes n <= 256). The shader
     * carries each dimension's chain state through K iterations within
     * one invocation; no need for multi-workgroup dispatch here. */
    return dispatch(pipe, bufs, 7, 1, sizeof(push), &push);
}

int nxv_leapfrog_normal(void* q_new, void* p_new,
                         void* q, void* p, void* inv_mass,
                         unsigned int n,
                         float eps, float mu, float sigma,
                         const char* spv_path) {
    if (!q_new || !p_new || !q || !p || !inv_mass || !spv_path) return -1;
    /* 5 buffers in shader binding order: q, p, inv_mass, q_new, p_new. */
    VkPipe* pipe = get_or_create_pipe(std::string(spv_path), 0, 5);
    if (!pipe) return -2;

    VkBuf* buf_q  = (VkBuf*) q;
    VkBuf* buf_p  = (VkBuf*) p;
    VkBuf* buf_m  = (VkBuf*) inv_mass;
    VkBuf* buf_qn = (VkBuf*) q_new;
    VkBuf* buf_pn = (VkBuf*) p_new;

    VkBuffer bufs[5] = {
        buf_q->buffer, buf_p->buffer, buf_m->buffer,
        buf_qn->buffer, buf_pn->buffer
    };

    /* Push: {uint n; float eps; float mu; float sigma} = 16 bytes. */
    struct {
        unsigned int n;
        float eps;
        float mu;
        float sigma;
    } push;
    push.n = n;
    push.eps = eps;
    push.mu = mu;
    push.sigma = sigma;

    unsigned int groups = (n + 255) / 256;
    return dispatch(pipe, bufs, 5, groups, sizeof(push), &push);
}

int nxv_fused_chain_4(void* out, void* a, void* b, void* c, void* d,
                      unsigned int n, unsigned int n_ops,
                      const unsigned int* ops,
                      const unsigned int* buf_idx,
                      const char* spv_path) {
    if (!out || !a || !b || !c || !d || !ops || !buf_idx || !spv_path) return -1;
    /* 5 buffers: a, b, c, d, out. */
    VkPipe* pipe = get_or_create_pipe(std::string(spv_path), 0, 5);
    if (!pipe) return -2;

    VkBuf* buf_a   = (VkBuf*) a;
    VkBuf* buf_b   = (VkBuf*) b;
    VkBuf* buf_c   = (VkBuf*) c;
    VkBuf* buf_d   = (VkBuf*) d;
    VkBuf* buf_out = (VkBuf*) out;

    VkBuffer bufs[5] = {
        buf_a->buffer, buf_b->buffer, buf_c->buffer, buf_d->buffer, buf_out->buffer
    };

    /* Push: {n, n_ops, ops[8], buf_idx[8]} = 72 bytes. */
    struct {
        unsigned int n;
        unsigned int n_ops;
        unsigned int ops[8];
        unsigned int buf_idx[8];
    } push;
    push.n = n;
    push.n_ops = n_ops;
    for (int i = 0; i < 8; i++) {
        push.ops[i] = ops[i];
        push.buf_idx[i] = buf_idx[i];
    }

    unsigned int groups = (n + 255) / 256;
    return dispatch(pipe, bufs, 5, groups, sizeof(push), &push);
}

int nxv_fused_chain(void* out, void* a, void* b,
                    unsigned int n, unsigned int n_ops,
                    const unsigned int* ops,
                    const char* spv_path) {
    if (!out || !a || !b || !ops || !spv_path) return -1;
    VkPipe* pipe = get_or_create_pipe(std::string(spv_path), 0, 3);
    if (!pipe) return -2;

    VkBuf* buf_a   = (VkBuf*) a;
    VkBuf* buf_b   = (VkBuf*) b;
    VkBuf* buf_out = (VkBuf*) out;

    /* Shader bindings: a, b, out. Push: {n, n_ops, ops[8]} = 40 bytes. */
    VkBuffer bufs[3] = { buf_a->buffer, buf_b->buffer, buf_out->buffer };

    struct {
        unsigned int n;
        unsigned int n_ops;
        unsigned int ops[8];
    } push;
    push.n = n;
    push.n_ops = n_ops;
    for (int i = 0; i < 8; i++) push.ops[i] = ops[i];

    unsigned int groups = (n + 255) / 256;
    return dispatch(pipe, bufs, 3, groups, sizeof(push), &push);
}

int nxv_transpose(void* out, void* a, unsigned int m, unsigned int n,
                  const char* spv_path) {
    if (!out || !a || !spv_path) return -1;
    VkPipe* pipe = get_or_create_pipe(std::string(spv_path), 0, 2);
    if (!pipe) return -2;

    VkBuf* buf_a   = (VkBuf*) a;
    VkBuf* buf_out = (VkBuf*) out;

    /* 2D dispatch — 16×16 tiles; same dance as matmul. The existing
     * spirit dispatch() helper is 1D-only, so inline. */
    auto& ctx = g_vk_ctx;

    VkBuffer bufs[2] = { buf_a->buffer, buf_out->buffer };

    VkDescriptorBufferInfo bi[2];
    VkWriteDescriptorSet w[2];
    for (int i = 0; i < 2; i++) {
        bi[i] = {bufs[i], 0, VK_WHOLE_SIZE};
        w[i] = {};
        w[i].sType = VK_STRUCTURE_TYPE_WRITE_DESCRIPTOR_SET;
        w[i].dstSet = pipe->descriptor_set;
        w[i].dstBinding = (uint32_t) i;
        w[i].descriptorCount = 1;
        w[i].descriptorType = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER;
        w[i].pBufferInfo = &bi[i];
    }
    vkUpdateDescriptorSets(ctx.device, 2, w, 0, nullptr);

    VkCommandBufferAllocateInfo ai{};
    ai.sType = VK_STRUCTURE_TYPE_COMMAND_BUFFER_ALLOCATE_INFO;
    ai.commandPool = ctx.command_pool;
    ai.level = VK_COMMAND_BUFFER_LEVEL_PRIMARY;
    ai.commandBufferCount = 1;
    VkCommandBuffer cmd;
    vkAllocateCommandBuffers(ctx.device, &ai, &cmd);

    VkCommandBufferBeginInfo bb{};
    bb.sType = VK_STRUCTURE_TYPE_COMMAND_BUFFER_BEGIN_INFO;
    bb.flags = VK_COMMAND_BUFFER_USAGE_ONE_TIME_SUBMIT_BIT;
    vkBeginCommandBuffer(cmd, &bb);
    vkCmdBindPipeline(cmd, VK_PIPELINE_BIND_POINT_COMPUTE, pipe->pipeline);
    vkCmdBindDescriptorSets(cmd, VK_PIPELINE_BIND_POINT_COMPUTE,
                            pipe->pipeline_layout, 0, 1, &pipe->descriptor_set, 0, nullptr);

    unsigned int push[2] = { m, n };
    vkCmdPushConstants(cmd, pipe->pipeline_layout, VK_SHADER_STAGE_COMPUTE_BIT,
                       0, sizeof(push), push);

    unsigned int gx = (n + 15) / 16;
    unsigned int gy = (m + 15) / 16;
    vkCmdDispatch(cmd, gx, gy, 1);
    vkEndCommandBuffer(cmd);

    VkSubmitInfo si{};
    si.sType = VK_STRUCTURE_TYPE_SUBMIT_INFO;
    si.commandBufferCount = 1;
    si.pCommandBuffers = &cmd;
    vkQueueSubmit(ctx.compute_queue, 1, &si, VK_NULL_HANDLE);
    vkQueueWaitIdle(ctx.compute_queue);
    vkFreeCommandBuffers(ctx.device, ctx.command_pool, 1, &cmd);

    return 0;
}

}