/* c_src/nx_vulkan_shim.h */

/* nx_vulkan_shim.h — extern "C" interface bridging Rust to spirit's
 * C++ Vulkan backend.
 *
 * Spirit's Backend_par_vulkan.{hpp,cpp} use C++ namespaces, classes,
 * and STL types — none of which Rust's bindgen handles cleanly. This
 * header declares a flat C ABI that the Rust NIF binds against; the
 * implementation in nx_vulkan_shim.cpp is a thin C++ file that
 * delegates to Spirit's helpers.
 *
 * Naming convention: nxv_* (Nx.Vulkan).
 */

#ifndef NX_VULKAN_SHIM_H
#define NX_VULKAN_SHIM_H

#ifdef __cplusplus
extern "C" {
#endif

/* Lifecycle ---------------------------------------------------------- */

/* Bring up the global Vulkan context. Repeated calls are harmless
 * (idempotent).
 *
 * Returns: 0 on success; non-zero when no Vulkan-capable device is
 * found. */
int nxv_init(void);

/* Destroy the global Vulkan context. Repeated calls are harmless
 * (idempotent). */
void nxv_destroy(void);

/* Introspection ------------------------------------------------------ */

/* Name of the selected device.
 *
 * Returns: pointer to the device name string, or NULL when init has not
 * run. The pointer stays valid until the next nxv_destroy. */
const char *nxv_device_name(void);

/* f64 capability query for the selected device.
 *
 * Returns: 1 when f64 is supported, 0 otherwise. */
int nxv_has_f64(void);

/* H3 dispatch timing accumulators.
 *
 *   count        number of dispatch() calls since the last reset
 *   *_ns         accumulated times, in nanoseconds
 *
 * Every out parameter of nxv_timing_get is optional: pass NULL to skip
 * the corresponding value. */
void nxv_timing_reset(void);
void nxv_timing_get(
    unsigned long long *count,
    unsigned long long *dispatch_ns,
    unsigned long long *submit_ns,
    unsigned long long *wait_ns,
    unsigned long long *record_ns);

/* Batched upload: N host buffers are copied into N device buffers using
 * one submit_and_wait round-trip.
 *
 *   dsts       array of VkBuf* handles (destinations)
 *   data       array of host source pointers
 *   sizes      array of byte counts
 *   n_buffers  N, the number of entries
 *
 * Returns: 0 on success. */
int nxv_buf_upload_batch(
    void **dsts, const void **data,
    const unsigned long *sizes, unsigned int n_buffers);

/* Batched download: N device buffers are copied into N out_data
 * pointers using one submit_and_wait round-trip rather than N of them.
 *
 *   srcs       array of VkBuf* handles — the same opaque pointers that
 *              nxv_buf_alloc returns
 *   out_data   array of host destination pointers
 *   sizes      array of byte counts
 *   n_buffers  N, the number of entries
 *
 * Returns: 0 on success. */
int nxv_buf_download_batch(
    void **srcs, void **out_data,
    const unsigned long *sizes, unsigned int n_buffers);

/* Phase 2 W5 — pipeline cache disk persistence (load side).
 *
 * If a blob exists on disk at `path`, it is loaded and the spirit
 * pipeline cache is rebuilt from it. The header is sniffed inside
 * spirit; a blob with a mismatched UUID is discarded with a warning on
 * stderr instead of being reported as an error.
 *
 * Returns: 0 on success. "No file" and "header mismatch" both count as
 * success and leave a fresh empty cache. Negative when the file exists
 * but cannot be read. */
int nxv_pipeline_cache_load(const char *path);

/* Write the current pipeline cache to `path`.
 *
 * The data goes to a temporary file first and is then renamed into
 * place (atomic when both are on the same filesystem). The parent
 * directory must already exist; creating it is the caller's job.
 *
 * Returns: 0 on success, negative on I/O error. */
int nxv_pipeline_cache_persist(const char *path);

/* Copy the current device's pipelineCacheUUID into out[16].
 *
 * Returns: 0 on success, negative when the device is not initialized. */
int nxv_device_uuid(unsigned char out[16]);

/* K-step leapfrog chain dispatch for synthesized shaders (generic
 * entry point).
 *
 * Buffers are bound in the same order as in every family-specific
 * nxv_leapfrog_chain_* entry: q_init, p_init, inv_mass, q_chain,
 * p_chain, grad_chain, logp_chain occupy bindings 0-6. Note that the C
 * argument list below starts with the four chain buffers, so argument
 * order and binding order are not the same.
 *
 * The shim treats the push-constants block as OPAQUE: `push_data` is a
 * raw blob of `push_size` bytes that the caller assembles (the
 * Elixir-side codegen knows each shader's layout). At most 128 bytes,
 * which is the push-constants size Vulkan guarantees as a minimum.
 *
 * Returns: 0 on success, negative on error. */
int nxv_leapfrog_chain_synth(
    void *q_chain, void *p_chain,
    void *grad_chain, void *logp_chain,
    void *q_init, void *p_init, void *inv_mass,
    const void *push_data, unsigned int push_size,
    const char *spv_path);

/* Tensor primitives (v0.0.2) ------------------------------------------ */
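/* Stubs placed here so Rust can declare them; implementations land in
 * the next iteration once the resource type lifetime is in place.
 * NOTE(review): the pooling behaviour documented on nxv_buf_free
 * suggests these are already implemented — confirm and drop this
 * paragraph if it is stale. */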

/* Allocate a device-local buffer `n_bytes` in size.
 *
 * Returns: an opaque handle (a VkBuf* cast to void*), or NULL when the
 * allocation fails. */
void *nxv_buf_alloc(unsigned long n_bytes);

/* Release a buffer handle.
 *
 * The buffer goes back to the pool when capacity allows. vkFreeMemory
 * is only actually called once the per-size-class cap is exceeded, or
 * when nxv_pool_clear / nxv_destroy runs. */
void nxv_buf_free(void *handle);

/* Hand every pooled buffer back to the device.
 *
 * Intended to be called at idle time to reclaim memory; without it the
 * pool grows to the working-set size and stays that large. Idempotent. */
void nxv_pool_clear(void);

/* Buffer pool statistics. Each out-pointer is optional and may be
 * NULL. */
void nxv_pool_stats(
    unsigned long *hits,
    unsigned long *misses,
    unsigned long *freed,
    unsigned long *size_classes,
    unsigned long *total_pooled);

/* Copy `n_bytes` of host data from `data` into the buffer behind
 * `handle`.
 *
 * Returns: 0 on success. */
int nxv_buf_upload(void *handle, const void *data, unsigned long n_bytes);

/* Copy `n_bytes` from the buffer behind `handle` into host memory at
 * `data`.
 *
 * Returns: 0 on success. */
int nxv_buf_download(void *handle, void *data, unsigned long n_bytes);

/* Compute primitives (v0.0.3) ----------------------------------------- */

/* Elementwise binary operation over buffers `out`, `a` and `b`, each
 * holding `n` f32 elements.
 *
 * `op` selects the elementwise_binary.spv spec constant:
 *   0=add, 1=mul, 2=sub, 3=div, 4=pow, 5=max, 6=min.
 *
 * The shim creates the pipeline on first use for each (shader_path, op)
 * pair and caches it, which avoids the ~22 ms pipeline-create overhead
 * documented in the reductions section of spirit's
 * RESULTS_RTX_3060_TI.md.
 *
 * Returns: 0 on success. */
int nxv_apply_binary(
    void *out, void *a, void *b,
    unsigned int n, unsigned int op,
    const char *spv_path);

/* Elementwise unary operation over buffers `out` and `a`, each holding
 * `n` f32 elements.
 *
 * `op` spec constant:
 *   0=exp, 1=log, 2=sqrt, 3=abs, 4=neg, 5=sigmoid, 6=tanh, 7=relu,
 *   8=ceil, 9=floor, 10=sign, 11=reciprocal, 12=square. */
int nxv_apply_unary(
    void *out, void *a,
    unsigned int n, unsigned int op,
    const char *spv_path);

/* Full reduction: `n` f32 elements of `in` collapse to a single f32
 * that is written to `out_scalar`.
 *
 * `op`: 0=sum, 1=min, 2=max.
 * NOTE(review): nxv_reduce_axis documents 1=max, 2=min — the opposite
 * order for codes 1 and 2. Confirm against the shaders that both
 * listings are intentional. */
int nxv_reduce(
    float *out_scalar, void *in,
    unsigned int n, unsigned int op,
    const char *spv_path);

/* Naive matrix multiply: C[M*N] = A[M*K] · B[K*N]. All three matrices
 * are row-major f32. */
int nxv_matmul(
    void *out, void *a, void *b,
    unsigned int m, unsigned int n, unsigned int k,
    const char *spv_path);

/* Matmul variant in which the caller supplies the workgroup output
 * tile size. The dispatch uses gy=ceil(M/tile_m), gx=ceil(N/tile_n).
 *
 * Tile sizes expected by each shader:
 *   matmul.spv          : tile_m=16, tile_n=16 (compatible with nxv_matmul)
 *   matmul_tiled.spv    : tile_m=16, tile_n=16
 *   matmul_tiled32.spv  : tile_m=32, tile_n=32
 *   matmul_tiled16x2.spv: tile_m=32, tile_n=16 (each thread does 2 rows) */
int nxv_matmul_v(
    void *out, void *a, void *b,
    unsigned int m, unsigned int n, unsigned int k,
    unsigned int tile_m, unsigned int tile_n,
    const char *spv_path);

/* Random fill: writes `n` f32 values into `out`.
 *
 * `dist`: 0 = uniform on [0,1); 1 = normal N(0,1) via Box-Muller. */
int nxv_random(
    void *out,
    unsigned int n, unsigned int seed, unsigned int dist,
    const char *spv_path);

/* 2D transpose. The input A is M×N row-major and the output C is N×M
 * row-major, with C[j, i] = A[i, j] for i in 0..M and j in 0..N. */
int nxv_transpose(
    void *out, void *a,
    unsigned int m, unsigned int n,
    const char *spv_path);

/* Cast between f32 and f64. The two directions live in separate shader
 * files, so `spv_path` decides which way the cast goes.
 *
 * `n` counts elements, not bytes. Input and output element widths
 * differ, and sizing the buffers correctly (n*4 vs n*8) is the caller's
 * responsibility. */
int nxv_cast(
    void *out, void *a,
    unsigned int n,
    const char *spv_path);

/* Reduction along one axis. The input is viewed as a virtual 3-D
 * row-major tensor of shape (outer, reduce, inner); the output is
 * row-major with shape (outer, inner).
 *
 * `op`: 0=sum, 1=max, 2=min.
 * NOTE(review): nxv_reduce documents 1=min, 2=max — the opposite order
 * for codes 1 and 2. Confirm against the shaders that both listings are
 * intentional. */
int nxv_reduce_axis(
    void *out, void *a,
    unsigned int outer, unsigned int reduce_size, unsigned int inner,
    unsigned int op,
    const char *spv_path);

/* Elementwise binary operation with broadcasting.
 *
 *   op         spec constant 0..9:
 *              add/mul/sub/div/pow/max/min/equal/less/greater
 *   ndim       1..4
 *   out_shape, a_strides, b_strides
 *              4-element arrays; unused trailing entries should be 0.
 *              A stride of 0 broadcasts along that axis. */
int nxv_apply_binary_broadcast(
    void *out, void *a, void *b,
    unsigned int op, unsigned int ndim,
    const unsigned int *out_shape,
    const unsigned int *a_strides,
    const unsigned int *b_strides,
    const char *spv_path);

/* Fused n-way elementwise chain: as many as 8 ops in a single dispatch.
 *
 * `ops` is an array of length 8; pad unused slots with 255 (nop).
 * Op codes:
 *   binary 0..6     add/mul/sub/div/pow/max/min — the second operand
 *                   comes from buf B
 *   unary  100..114 exp/log/sqrt/abs/neg/sigmoid/tanh/relu/ceil/floor/
 *                   sign/reciprocal/square/erf/expm1
 *
 * The chain starts from a[i] and is applied left to right, with b[i]
 * feeding the binary steps. The output is c[i], of length n. */
int nxv_fused_chain(
    void *out, void *a, void *b,
    unsigned int n, unsigned int n_ops,
    const unsigned int *ops,
    const char *spv_path);

/* Fused chain with four inputs.
 *
 * `ops` and `buf_idx` are both length-8 arrays; pad `ops` with 255 and
 * `buf_idx` with 1. A `buf_idx` entry selects the second operand:
 * 1=b, 2=c, 3=d. It is ignored for unary ops. */
int nxv_fused_chain_4(
    void *out, void *a, void *b, void *c, void *d,
    unsigned int n, unsigned int n_ops,
    const unsigned int *ops,
    const unsigned int *buf_idx,
    const char *spv_path);

/* kinetic_energy.spv — fused 0.5 * sum(p² * inv_mass).
 *
 * `out` receives partial sums, one f32 per workgroup; the final
 * reduction is left to the caller. */
int nxv_kinetic_energy(
    void *out, void *p, void *inv_mass,
    unsigned int n,
    const char *spv_path);

/* normal_logpdf.spv — fused
 *   -0.5*((x-mu)/sigma)² - log(sigma) - 0.5*log(2π).
 *
 * The output has the same shape as x, mu and sigma. */
int nxv_normal_logpdf(
    void *out, void *x, void *mu, void *sigma,
    unsigned int n,
    const char *spv_path);

/* leapfrog_normal.spv — one fused NUTS leapfrog step for a univariate
 * Normal log-density model. A single dispatch stands in for the ~12
 * elementwise dispatches the IR walker would otherwise issue.
 *
 * Five buffers: q, p, inv_mass are read; q_new, p_new are written.
 * Push constants: {uint n; float eps; float mu; float sigma} = 16 bytes. */
int nxv_leapfrog_normal(
    void *q_new, void *p_new,
    void *q, void *p, void *inv_mass,
    unsigned int n,
    float eps, float mu, float sigma,
    const char *spv_path);

/* leapfrog_chain_normal.spv — K fused NUTS leapfrog steps for a
 * univariate Normal log-density model.
 *
 * Outputs: q_chain, p_chain and grad_chain hold K*n floats each;
 * logp_chain holds K floats (a per-step reduction). The caller
 * pre-allocates all four output buffers.
 *
 * Assumes n <= 256 (single workgroup). Seven buffers in total: 3 read,
 * 4 written. Push constants {n, K, eps, mu, sigma} = 20 bytes. */
int nxv_leapfrog_chain_normal(
    void *q_chain, void *p_chain,
    void *grad_chain, void *logp_chain,
    void *q_init, void *p_init, void *inv_mass,
    unsigned int n, unsigned int K,
    float eps, float mu, float sigma,
    const char *spv_path);

/* leapfrog_chain_normal_lg.spv — multi-workgroup variant that removes
 * the n <= 256 constraint.
 *
 * `partial_logp` holds K * num_workgroups floats: per-workgroup
 * partials for every step. The caller sums the num_workgroups partials
 * of a step to obtain that step's logp, where
 * num_workgroups = ceil(n / 256). Workgroup 0 already includes the
 * constant -n*(log(sigma) + 0.5*log(2pi)), so the host-side sum is the
 * final logp.
 *
 * Push constants: 24 bytes. */
int nxv_leapfrog_chain_normal_lg(
    void *q_chain, void *p_chain,
    void *grad_chain, void *partial_logp,
    void *q_init, void *p_init, void *inv_mass,
    unsigned int n, unsigned int K,
    unsigned int num_workgroups,
    float eps, float mu, float sigma,
    const char *spv_path);

/* leapfrog_chain_exponential.spv — Phase 2 sibling of
 * leapfrog_chain_normal.
 *
 * Same I/O shape as that shader (4 output buffers, K * n + K).
 * NOTE(review): "K * n + K" presumably means K*n floats for each of
 * q/p/grad plus K floats for logp, as documented on
 * nxv_leapfrog_chain_normal — confirm.
 *
 * Single-workgroup (n<=256). The unconstrained gradient has the closed
 * form grad_q_uc = 1 - lambda * exp(q_uc).
 * Push constants {n, K, eps, lambda} = 16 bytes. */
int nxv_leapfrog_chain_exponential(
    void *q_chain, void *p_chain,
    void *grad_chain, void *logp_chain,
    void *q_init, void *p_init, void *inv_mass,
    unsigned int n, unsigned int K,
    float eps, float lambda,
    const char *spv_path);

/* leapfrog_chain_studentt.spv — Phase 2: real-valued Student-t with
 * `nu` degrees of freedom.
 *
 * Push constants {n, K, eps, mu, sigma, nu, logp_const} = 28 bytes. */
int nxv_leapfrog_chain_studentt(
    void *q_chain, void *p_chain,
    void *grad_chain, void *logp_chain,
    void *q_init, void *p_init, void *inv_mass,
    unsigned int n, unsigned int K,
    float eps, float mu, float sigma,
    float nu, float logp_const,
    const char *spv_path);

/* leapfrog_chain_cauchy.spv — Phase 2: real-valued Cauchy(loc, scale).
 *
 * Push constants {n, K, eps, loc, scale, log_pi_scale} = 24 bytes. */
int nxv_leapfrog_chain_cauchy(
    void *q_chain, void *p_chain,
    void *grad_chain, void *logp_chain,
    void *q_init, void *p_init, void *inv_mass,
    unsigned int n, unsigned int K,
    float eps, float loc, float scale,
    float log_pi_scale,
    const char *spv_path);

/* leapfrog_chain_halfnormal.spv — Phase 2: positive HalfNormal(sigma),
 * mapped to the unconstrained line with a log-transform.
 *
 * Push constants {n, K, eps, sigma, log_const} = 20 bytes. */
int nxv_leapfrog_chain_halfnormal(
    void *q_chain, void *p_chain,
    void *grad_chain, void *logp_chain,
    void *q_init, void *p_init, void *inv_mass,
    unsigned int n, unsigned int K,
    float eps, float sigma, float log_const,
    const char *spv_path);

/* leapfrog_chain_weibull.spv — Weibull(k, lambda), mapped to the
 * unconstrained line with a log-transform.
 *
 * The gradient has the closed form `k * (1 - (exp(q_uc)/lambda)^k)`.
 * `logp_const` carries `n * (log(k) - k * log(lambda))`, which the host
 * precomputes.
 *
 * Push constants {n, K, eps, k, lambda, logp_const} = 24 bytes. */
int nxv_leapfrog_chain_weibull(
    void *q_chain, void *p_chain,
    void *grad_chain, void *logp_chain,
    void *q_init, void *p_init, void *inv_mass,
    unsigned int n, unsigned int K,
    float eps, float k, float lambda,
    float logp_const,
    const char *spv_path);

/* leapfrog_chain_normal_f64.spv — f64 sibling of leapfrog_chain_normal.
 *
 * Every buffer holds doubles. Output sizes in bytes: K * n * 8 for each
 * of q, p and grad; K * 8 for logp.
 * Push constants {n, K, eps, mu, sigma} = 32 bytes, with eps, mu and
 * sigma passed as double. */
int nxv_leapfrog_chain_normal_f64(
    void *q_chain, void *p_chain,
    void *grad_chain, void *logp_chain,
    void *q_init, void *p_init, void *inv_mass,
    unsigned int n, unsigned int K,
    double eps, double mu, double sigma,
    const char *spv_path);

#ifdef __cplusplus
}
#endif

#endif