//! Rustler NIF for Nx.Vulkan.
//!
//! Three layers down from Elixir:
//!
//! Elixir → Rust NIF → extern "C" shim → C++ spirit::vulkan
//!
//! v0.0.1 wires bootstrap (init, device_name, has_f64).
//! v0.0.2 adds tensor lifetime + upload/download via ResourceArc.
use rustler::{Binary, Encoder, Env, Error, NifResult, OwnedBinary, ResourceArc, Term};
use std::ffi::CStr;
use std::os::raw::{c_char, c_void};
use std::sync::Mutex;
// Atoms shared by the NIFs in this file. `ok` / `error` tag result
// tuples; the remaining atoms are the reasons carried as the second
// element of `{:error, reason}`.
mod atoms {
rustler::atoms! {
ok,
error,
// `init`: nxv_init returned a non-zero status.
no_device,
// nxv_buf_alloc returned a null handle.
alloc_failed,
// nxv_buf_upload / nxv_buf_upload_batch returned non-zero.
upload_failed,
// nxv_buf_download / nxv_buf_download_batch returned non-zero.
download_failed,
// A byte count supplied by the caller disagrees with a tensor's size.
size_mismatch,
// A compute shim (apply/reduce/matmul/...) returned non-zero.
dispatch_failed,
// An op / dist code or shape argument is outside the accepted range.
bad_op,
// nxv_pipeline_cache_load returned non-zero.
load_failed,
// nxv_pipeline_cache_persist returned non-zero.
persist_failed,
// `device_uuid`: nxv_device_uuid returned non-zero.
not_initialized,
}
}
// extern "C" declarations matching c_src/nx_vulkan_shim.h.
//
// Conventions visible from the call sites in this file:
//   * every `i32` return is a status code where 0 means success;
//   * `*mut c_void` buffer arguments are handles obtained from
//     `nxv_buf_alloc` (stored in `VulkanTensor::handle`);
//   * `spv_path` is a NUL-terminated path built with `CString`.
// NOTE(review): these signatures must stay in sync with the C header by
// hand — nothing in this file checks them against it.
unsafe extern "C" {
// --- Bootstrap and device queries ---
fn nxv_init() -> i32;
fn nxv_device_name() -> *const c_char;
fn nxv_has_f64() -> i32;
// --- Dispatch timing accumulators (see `timing_reset` / `timing_get`) ---
fn nxv_timing_reset();
fn nxv_timing_get(
count: *mut u64,
dispatch_ns: *mut u64,
submit_ns: *mut u64,
wait_ns: *mut u64,
record_ns: *mut u64,
);
// --- Batched transfers: `n_buffers` parallel entries in each array ---
fn nxv_buf_download_batch(
srcs: *const *mut c_void,
out_data: *const *mut c_void,
sizes: *const u64,
n_buffers: u32,
) -> i32;
fn nxv_buf_upload_batch(
dsts: *const *mut c_void,
data: *const *const c_void,
sizes: *const u64,
n_buffers: u32,
) -> i32;
// Leapfrog chain taking an opaque push-constant blob (`push_data`,
// `push_size` bytes) instead of typed scalar parameters.
fn nxv_leapfrog_chain_synth(
q_chain: *mut c_void,
p_chain: *mut c_void,
grad_chain: *mut c_void,
logp_chain: *mut c_void,
q_init: *mut c_void,
p_init: *mut c_void,
inv_mass: *mut c_void,
push_data: *const c_void,
push_size: u32,
spv_path: *const c_char,
) -> i32;
// --- Pipeline cache persistence ---
fn nxv_pipeline_cache_load(path: *const c_char) -> i32;
fn nxv_pipeline_cache_persist(path: *const c_char) -> i32;
// Writes into `out`; the `device_uuid` NIF passes a 16-byte buffer.
fn nxv_device_uuid(out: *mut u8) -> i32;
// --- Buffer lifetime and single-buffer transfers ---
// Returns null on failure (callers map that to `:alloc_failed`).
fn nxv_buf_alloc(n_bytes: u64) -> *mut c_void;
fn nxv_buf_free(handle: *mut c_void);
fn nxv_buf_upload(handle: *mut c_void, data: *const c_void, n_bytes: u64) -> i32;
fn nxv_buf_download(handle: *mut c_void, data: *mut c_void, n_bytes: u64) -> i32;
// --- Elementwise / reduction / linear-algebra dispatches ---
// `n` is an element count, not a byte count (callers divide the byte
// size by the element width before passing it).
fn nxv_apply_binary(
out: *mut c_void,
a: *mut c_void,
b: *mut c_void,
n: u32,
op: u32,
spv_path: *const c_char,
) -> i32;
fn nxv_apply_unary(
out: *mut c_void,
a: *mut c_void,
n: u32,
op: u32,
spv_path: *const c_char,
) -> i32;
// Unlike the other dispatches, the result is written to a host-side
// f32 rather than to a GPU buffer.
fn nxv_reduce(
out_scalar: *mut f32,
input: *mut c_void,
n: u32,
op: u32,
spv_path: *const c_char,
) -> i32;
fn nxv_matmul(
out: *mut c_void,
a: *mut c_void,
b: *mut c_void,
m: u32,
n: u32,
k: u32,
spv_path: *const c_char,
) -> i32;
fn nxv_random(
out: *mut c_void,
n: u32,
seed: u32,
dist: u32,
spv_path: *const c_char,
) -> i32;
fn nxv_transpose(
out: *mut c_void,
a: *mut c_void,
m: u32,
n: u32,
spv_path: *const c_char,
) -> i32;
fn nxv_cast(
out: *mut c_void,
a: *mut c_void,
n: u32,
spv_path: *const c_char,
) -> i32;
fn nxv_reduce_axis(
out: *mut c_void,
a: *mut c_void,
outer: u32,
reduce_size: u32,
inner: u32,
op: u32,
spv_path: *const c_char,
) -> i32;
// --- Buffer pool maintenance / statistics ---
fn nxv_pool_clear();
fn nxv_pool_stats(
hits: *mut u64,
misses: *mut u64,
freed: *mut u64,
size_classes: *mut u64,
total_pooled: *mut u64,
);
// The notes below explain why there are no dedicated f64 / logsumexp
// declarations: the corresponding NIFs reuse nxv_apply_binary,
// nxv_apply_unary, nxv_reduce_axis and nxv_apply_binary_broadcast.
// f64 elementwise — same C shim, different .spv path. Caller computes
// n_elems and out_bytes per element width.
// f64 reduce_axis and broadcast use the existing C shims unchanged
// (the C side is type-opaque). Out-buffer sizes scale with element width.
// logsumexp uses the same shim as reduce_axis (same push layout) but
// is f32-only — output 4 bytes/element.
// Matmul with a caller-selected workgroup tile (`tile_m` x `tile_n`).
fn nxv_matmul_v(
out: *mut c_void,
a: *mut c_void,
b: *mut c_void,
m: u32,
n: u32,
k: u32,
tile_m: u32,
tile_n: u32,
spv_path: *const c_char,
) -> i32;
// `out_shape`, `a_strides`, `b_strides` point at u32 arrays; the NIF
// wrappers in this file always pass exactly 4 entries each.
fn nxv_apply_binary_broadcast(
out: *mut c_void,
a: *mut c_void,
b: *mut c_void,
op: u32,
ndim: u32,
out_shape: *const u32,
a_strides: *const u32,
b_strides: *const u32,
spv_path: *const c_char,
) -> i32;
// `ops` points at a u32 array; the `fused_chain` NIF passes 8 entries
// with unused slots set to 255.
fn nxv_fused_chain(
out: *mut c_void,
a: *mut c_void,
b: *mut c_void,
n: u32,
n_ops: u32,
ops: *const u32,
spv_path: *const c_char,
) -> i32;
fn nxv_fused_chain_4(
out: *mut c_void,
a: *mut c_void,
b: *mut c_void,
c: *mut c_void,
d: *mut c_void,
n: u32,
n_ops: u32,
ops: *const u32,
buf_idx: *const u32,
spv_path: *const c_char,
) -> i32;
// --- Sampler kernels (kinetic energy, log-pdf, leapfrog) ---
// None of the declarations below are called in the portion of the file
// reviewed here; their semantics are defined by the C shim.
fn nxv_kinetic_energy(
out: *mut c_void,
p: *mut c_void,
inv_mass: *mut c_void,
n: u32,
spv_path: *const c_char,
) -> i32;
fn nxv_normal_logpdf(
out: *mut c_void,
x: *mut c_void,
mu: *mut c_void,
sigma: *mut c_void,
n: u32,
spv_path: *const c_char,
) -> i32;
fn nxv_leapfrog_normal(
q_new: *mut c_void,
p_new: *mut c_void,
q: *mut c_void,
p: *mut c_void,
inv_mass: *mut c_void,
n: u32,
eps: f32,
mu: f32,
sigma: f32,
spv_path: *const c_char,
) -> i32;
// The `nxv_leapfrog_chain_*` family shares the same seven buffer
// arguments followed by `n`, `K` and distribution-specific scalars.
fn nxv_leapfrog_chain_normal(
q_chain: *mut c_void,
p_chain: *mut c_void,
grad_chain: *mut c_void,
logp_chain: *mut c_void,
q_init: *mut c_void,
p_init: *mut c_void,
inv_mass: *mut c_void,
n: u32,
K: u32,
eps: f32,
mu: f32,
sigma: f32,
spv_path: *const c_char,
) -> i32;
// Differs from nxv_leapfrog_chain_normal by taking `partial_logp` in
// place of `logp_chain` and an explicit `num_workgroups`.
fn nxv_leapfrog_chain_normal_lg(
q_chain: *mut c_void,
p_chain: *mut c_void,
grad_chain: *mut c_void,
partial_logp: *mut c_void,
q_init: *mut c_void,
p_init: *mut c_void,
inv_mass: *mut c_void,
n: u32,
K: u32,
num_workgroups: u32,
eps: f32,
mu: f32,
sigma: f32,
spv_path: *const c_char,
) -> i32;
fn nxv_leapfrog_chain_exponential(
q_chain: *mut c_void,
p_chain: *mut c_void,
grad_chain: *mut c_void,
logp_chain: *mut c_void,
q_init: *mut c_void,
p_init: *mut c_void,
inv_mass: *mut c_void,
n: u32,
K: u32,
eps: f32,
lambda: f32,
spv_path: *const c_char,
) -> i32;
fn nxv_leapfrog_chain_studentt(
q_chain: *mut c_void, p_chain: *mut c_void,
grad_chain: *mut c_void, logp_chain: *mut c_void,
q_init: *mut c_void, p_init: *mut c_void, inv_mass: *mut c_void,
n: u32, K: u32,
eps: f32, mu: f32, sigma: f32, nu: f32, logp_const: f32,
spv_path: *const c_char,
) -> i32;
fn nxv_leapfrog_chain_cauchy(
q_chain: *mut c_void, p_chain: *mut c_void,
grad_chain: *mut c_void, logp_chain: *mut c_void,
q_init: *mut c_void, p_init: *mut c_void, inv_mass: *mut c_void,
n: u32, K: u32,
eps: f32, loc: f32, scale: f32, log_pi_scale: f32,
spv_path: *const c_char,
) -> i32;
fn nxv_leapfrog_chain_halfnormal(
q_chain: *mut c_void, p_chain: *mut c_void,
grad_chain: *mut c_void, logp_chain: *mut c_void,
q_init: *mut c_void, p_init: *mut c_void, inv_mass: *mut c_void,
n: u32, K: u32,
eps: f32, sigma: f32, log_const: f32,
spv_path: *const c_char,
) -> i32;
fn nxv_leapfrog_chain_weibull(
q_chain: *mut c_void, p_chain: *mut c_void,
grad_chain: *mut c_void, logp_chain: *mut c_void,
q_init: *mut c_void, p_init: *mut c_void, inv_mass: *mut c_void,
n: u32, K: u32,
eps: f32, k: f32, lambda: f32, logp_const: f32,
spv_path: *const c_char,
) -> i32;
// Same layout as nxv_leapfrog_chain_normal but with f64 scalars.
fn nxv_leapfrog_chain_normal_f64(
q_chain: *mut c_void, p_chain: *mut c_void,
grad_chain: *mut c_void, logp_chain: *mut c_void,
q_init: *mut c_void, p_init: *mut c_void, inv_mass: *mut c_void,
n: u32, K: u32,
eps: f64, mu: f64, sigma: f64,
spv_path: *const c_char,
) -> i32;
}
// One-shot guard so Elixir can call init/0 idempotently. Vulkan's
// vk_init is itself idempotent at the spirit level (returns 0 if
// already inited) but tracking the state in Rust gives us cleaner
// error semantics on the Elixir side.
//
// Holds `true` once `nxv_init` has returned 0. `init` is the only writer;
// `device_name` reads it to decide whether to query the device at all.
static INIT_STATE: Mutex<bool> = Mutex::new(false);
// Global submit serializer.
//
// Vulkan's VkQueue is "externally synchronized" — the spec says concurrent
// vkQueueSubmit calls from multiple host threads to the same VkQueue are
// undefined behaviour. Spirit's compute backend uses a single global
// queue, so any pair of NIF calls that submit (dispatch, upload, download,
// reduce, matmul, random) must NOT run on different threads at the same
// time.
//
// Without this lock, a stress test with 100 concurrent processes
// reproducibly triggers VK_ERROR_DEVICE_LOST within seconds.
//
// This lock serializes the entire submit-and-wait, which costs us
// concurrency on the GPU — but Spirit's submit_and_wait is itself
// blocking (no async dispatch), so we lose nothing real. Async dispatch
// + multiple queues is a v0.2 optimization.
//
// Usage pattern in this file: each submitting NIF binds the guard to a
// local (`let _g = SUBMIT_LOCK.lock()...`) so it is held until the NIF
// returns; a poisoned lock is reported to the caller as `Error::BadArg`.
static SUBMIT_LOCK: Mutex<()> = Mutex::new(());
// VulkanTensor owns a heap-allocated VkBuf via the C++ shim. When the
// Elixir reference is GC'd, ResourceArc drops this struct, which
// frees the GPU buffer through nxv_buf_free.
//
// SAFETY: the underlying handle is a void* pointer to a heap C++
// object. Send/Sync are unsafe-impl'd because BEAM may move the
// resource between schedulers; the C++ side serializes Vulkan calls
// through the global compute queue, so concurrent access from
// multiple NIF threads is bounded by Vulkan's own synchronization.
pub struct VulkanTensor {
// Opaque buffer handle returned by `nxv_buf_alloc`; every constructor
// in this file checks it for null before building the struct.
handle: *mut c_void,
// Size in bytes that was requested from `nxv_buf_alloc` for `handle`.
n_bytes: u64,
}
unsafe impl Send for VulkanTensor {}
unsafe impl Sync for VulkanTensor {}
impl Drop for VulkanTensor {
// Releases the GPU buffer exactly once, when the last reference to the
// resource goes away.
// NOTE(review): this calls into the shim without taking SUBMIT_LOCK —
// confirm that nxv_buf_free never touches the queue.
fn drop(&mut self) {
unsafe { nxv_buf_free(self.handle) };
}
}
#[rustler::nif]
fn init<'a>(env: Env<'a>) -> NifResult<Term<'a>> {
let mut state = INIT_STATE.lock().map_err(|_| Error::BadArg)?;
if *state {
return Ok((atoms::ok()).encode(env));
}
let rc = unsafe { nxv_init() };
if rc == 0 {
*state = true;
Ok((atoms::ok()).encode(env))
} else {
Ok((atoms::error(), atoms::no_device()).encode(env))
}
}
#[rustler::nif]
fn device_name<'a>(env: Env<'a>) -> NifResult<Term<'a>> {
let state = INIT_STATE.lock().map_err(|_| Error::BadArg)?;
if !*state {
return Ok((rustler::types::atom::nil()).encode(env));
}
let ptr = unsafe { nxv_device_name() };
if ptr.is_null() {
return Ok((rustler::types::atom::nil()).encode(env));
}
let cstr = unsafe { CStr::from_ptr(ptr) };
let s = cstr.to_string_lossy().to_string();
Ok(s.encode(env))
}
/// Reports whether the shim says f64 is supported.
///
/// Encodes `true` when `nxv_has_f64` returns any non-zero value and
/// `false` otherwise. The initialisation flag is not consulted.
#[rustler::nif]
fn has_f64<'a>(env: Env<'a>) -> NifResult<Term<'a>> {
    let supported = unsafe { nxv_has_f64() } != 0;
    Ok(supported.encode(env))
}
/// H3 dispatch timing — clear the shim's timing accumulators.
///
/// Always returns `:ok`.
#[rustler::nif]
fn timing_reset<'a>(env: Env<'a>) -> NifResult<Term<'a>> {
    unsafe {
        nxv_timing_reset();
    }
    let ok = rustler::types::atom::ok();
    Ok(ok.encode(env))
}
/// W5 — load the on-disk pipeline cache blob at `path`. Missing file is
/// OK (creates a fresh empty cache). Header mismatch is OK (silently
/// discarded). Returns `:ok` or `{:error, :load_failed}` on read error.
#[rustler::nif]
fn pipeline_cache_load<'a>(env: Env<'a>, path: String) -> NifResult<Term<'a>> {
let cstr = std::ffi::CString::new(path).map_err(|_| Error::BadArg)?;
let rc = unsafe { nxv_pipeline_cache_load(cstr.as_ptr()) };
if rc == 0 {
Ok(rustler::types::atom::ok().encode(env))
} else {
Ok((atoms::error(), atoms::load_failed()).encode(env))
}
}
/// W5 — persist the current pipeline cache to `path` (atomic
/// write-temp-rename). Returns `:ok`, `:empty`, or
/// `{:error, :persist_failed}`.
#[rustler::nif]
fn pipeline_cache_persist<'a>(env: Env<'a>, path: String) -> NifResult<Term<'a>> {
let cstr = std::ffi::CString::new(path).map_err(|_| Error::BadArg)?;
let rc = unsafe { nxv_pipeline_cache_persist(cstr.as_ptr()) };
if rc == 0 {
Ok(rustler::types::atom::ok().encode(env))
} else {
Ok((atoms::error(), atoms::persist_failed()).encode(env))
}
}
/// W5 — read the current device's pipelineCacheUUID as a 16-byte binary.
#[rustler::nif]
fn device_uuid<'a>(env: Env<'a>) -> NifResult<Term<'a>> {
let mut buf = [0u8; 16];
let rc = unsafe { nxv_device_uuid(buf.as_mut_ptr()) };
if rc != 0 {
return Ok((atoms::error(), atoms::not_initialized()).encode(env));
}
let mut out = OwnedBinary::new(16).ok_or_else(|| Error::Term(Box::new("alloc")))?;
out.as_mut_slice().copy_from_slice(&buf);
Ok((atoms::ok(), out.release(env).encode(env)).encode(env))
}
/// H3 dispatch timing — snapshot the shim's accumulators.
///
/// Returns the 5-tuple `{count, dispatch_ns, submit_ns, wait_ns, record_ns}`.
/// Each slot starts at 0 and is filled in by `nxv_timing_get`.
#[rustler::nif]
fn timing_get<'a>(env: Env<'a>) -> NifResult<Term<'a>> {
    let (mut count, mut dispatch_ns, mut submit_ns, mut wait_ns, mut record_ns) =
        (0u64, 0u64, 0u64, 0u64, 0u64);
    unsafe {
        nxv_timing_get(
            &mut count,
            &mut dispatch_ns,
            &mut submit_ns,
            &mut wait_ns,
            &mut record_ns,
        );
    }
    let snapshot = (count, dispatch_ns, submit_ns, wait_ns, record_ns);
    Ok(snapshot.encode(env))
}
/// Upload an Elixir binary (raw bytes — typically packed f32) to a
/// freshly-allocated GPU buffer. Returns a ResourceArc wrapping the
/// VulkanTensor; when the Elixir reference is GC'd, the buffer is
/// freed automatically.
#[rustler::nif]
fn upload_binary<'a>(env: Env<'a>, data: Binary<'a>) -> NifResult<Term<'a>> {
let n_bytes = data.len() as u64;
let _g = SUBMIT_LOCK.lock().map_err(|_| Error::BadArg)?;
let handle = unsafe { nxv_buf_alloc(n_bytes) };
if handle.is_null() {
return Ok((atoms::error(), atoms::alloc_failed()).encode(env));
}
let rc = unsafe {
nxv_buf_upload(handle, data.as_slice().as_ptr() as *const c_void, n_bytes)
};
if rc != 0 {
unsafe { nxv_buf_free(handle) };
return Ok((atoms::error(), atoms::upload_failed()).encode(env));
}
let tensor = VulkanTensor { handle, n_bytes };
let resource = ResourceArc::new(tensor);
Ok((atoms::ok(), resource).encode(env))
}
/// Batched in-place upload of 2 binaries into 2 existing GPU buffers in a
/// single submit_and_wait round-trip. Saves 1 fence wait versus two
/// `upload_binary_into/2` calls.
#[rustler::nif]
fn upload_binary_into_batch2<'a>(
env: Env<'a>,
t1: ResourceArc<VulkanTensor>,
d1: Binary<'a>,
t2: ResourceArc<VulkanTensor>,
d2: Binary<'a>,
) -> NifResult<Term<'a>> {
let n1 = d1.len() as u64;
let n2 = d2.len() as u64;
if n1 != t1.n_bytes || n2 != t2.n_bytes {
return Ok((atoms::error(), atoms::size_mismatch()).encode(env));
}
let _g = SUBMIT_LOCK.lock().map_err(|_| Error::BadArg)?;
let dsts: [*mut c_void; 2] = [t1.handle, t2.handle];
let datas: [*const c_void; 2] = [
d1.as_slice().as_ptr() as *const c_void,
d2.as_slice().as_ptr() as *const c_void,
];
let sizes: [u64; 2] = [n1, n2];
let rc = unsafe {
nxv_buf_upload_batch(dsts.as_ptr(), datas.as_ptr(), sizes.as_ptr(), 2)
};
if rc != 0 {
return Ok((atoms::error(), atoms::upload_failed()).encode(env));
}
Ok(rustler::types::atom::ok().encode(env))
}
/// Upload an Elixir binary into an existing GPU buffer. Skips allocation —
/// reuses the buffer in `tensor`. `data.len()` must match `tensor.n_bytes`.
/// Returns `:ok` on success.
#[rustler::nif]
fn upload_binary_into<'a>(
env: Env<'a>,
tensor: ResourceArc<VulkanTensor>,
data: Binary<'a>,
) -> NifResult<Term<'a>> {
let n_bytes = data.len() as u64;
if n_bytes != tensor.n_bytes {
return Ok((atoms::error(), atoms::size_mismatch()).encode(env));
}
let _g = SUBMIT_LOCK.lock().map_err(|_| Error::BadArg)?;
let rc = unsafe {
nxv_buf_upload(tensor.handle, data.as_slice().as_ptr() as *const c_void, n_bytes)
};
if rc != 0 {
return Ok((atoms::error(), atoms::upload_failed()).encode(env));
}
Ok(rustler::types::atom::ok().encode(env))
}
/// Download `n_bytes` from a GPU tensor back into an Elixir binary.
/// `n_bytes` must match the buffer's size.
#[rustler::nif]
fn download_binary<'a>(
env: Env<'a>,
tensor: ResourceArc<VulkanTensor>,
n_bytes: u64,
) -> NifResult<Term<'a>> {
if n_bytes != tensor.n_bytes {
return Ok((atoms::error(), atoms::size_mismatch()).encode(env));
}
let mut bin = OwnedBinary::new(n_bytes as usize)
.ok_or_else(|| Error::Term(Box::new("could not allocate Elixir binary")))?;
let _g = SUBMIT_LOCK.lock().map_err(|_| Error::BadArg)?;
let rc = unsafe {
nxv_buf_download(
tensor.handle,
bin.as_mut_slice().as_mut_ptr() as *mut c_void,
n_bytes,
)
};
if rc != 0 {
return Ok((atoms::error(), atoms::download_failed()).encode(env));
}
let term = bin.release(env).encode(env);
Ok((atoms::ok(), term).encode(env))
}
/// Batched download of 4 GPU tensors in a single submit_and_wait round-trip.
/// Returns `{:ok, {b1, b2, b3, b4}}` where each bi is the corresponding
/// tensor's contents as an Elixir binary. Saves 3 fence waits versus 4
/// individual `download_binary/2` calls.
#[rustler::nif]
fn download_binary_batch4<'a>(
env: Env<'a>,
t1: ResourceArc<VulkanTensor>,
t2: ResourceArc<VulkanTensor>,
t3: ResourceArc<VulkanTensor>,
t4: ResourceArc<VulkanTensor>,
) -> NifResult<Term<'a>> {
let n1 = t1.n_bytes;
let n2 = t2.n_bytes;
let n3 = t3.n_bytes;
let n4 = t4.n_bytes;
let mut b1 = OwnedBinary::new(n1 as usize)
.ok_or_else(|| Error::Term(Box::new("alloc b1")))?;
let mut b2 = OwnedBinary::new(n2 as usize)
.ok_or_else(|| Error::Term(Box::new("alloc b2")))?;
let mut b3 = OwnedBinary::new(n3 as usize)
.ok_or_else(|| Error::Term(Box::new("alloc b3")))?;
let mut b4 = OwnedBinary::new(n4 as usize)
.ok_or_else(|| Error::Term(Box::new("alloc b4")))?;
let _g = SUBMIT_LOCK.lock().map_err(|_| Error::BadArg)?;
let srcs: [*mut c_void; 4] = [t1.handle, t2.handle, t3.handle, t4.handle];
let outs: [*mut c_void; 4] = [
b1.as_mut_slice().as_mut_ptr() as *mut c_void,
b2.as_mut_slice().as_mut_ptr() as *mut c_void,
b3.as_mut_slice().as_mut_ptr() as *mut c_void,
b4.as_mut_slice().as_mut_ptr() as *mut c_void,
];
let sizes: [u64; 4] = [n1, n2, n3, n4];
let rc = unsafe {
nxv_buf_download_batch(
srcs.as_ptr(),
outs.as_ptr(),
sizes.as_ptr(),
4,
)
};
if rc != 0 {
return Ok((atoms::error(), atoms::download_failed()).encode(env));
}
let bins = (
b1.release(env).encode(env),
b2.release(env).encode(env),
b3.release(env).encode(env),
b4.release(env).encode(env),
);
Ok((atoms::ok(), bins).encode(env))
}
/// Size of the tensor's GPU buffer in bytes, as recorded at allocation.
#[rustler::nif]
fn byte_size<'a>(env: Env<'a>, tensor: ResourceArc<VulkanTensor>) -> NifResult<Term<'a>> {
    let size: u64 = tensor.n_bytes;
    Ok(size.encode(env))
}
/// Apply an elementwise binary op to two f32 GPU tensors.
///
/// Allocates an output buffer with the same byte size as the inputs,
/// dispatches the shader at `spv_path`, and returns `{:ok, tensor}`.
///
/// `op` is the shader's spec constant: 0=add, 1=mul, 2=sub, 3=div, 4=pow,
/// 5=max, 6=min; 7..=9 are the comparison ops (equal/less/greater).
///
/// Errors:
/// * `{:error, :size_mismatch}` — `a` and `b` differ in byte size, or the
///   element count does not fit the `u32` the shim takes.
/// * `{:error, :bad_op}` — `op` is greater than 9.
/// * `{:error, :alloc_failed}` / `{:error, :dispatch_failed}` — reported
///   by the shim.
/// * `Error::BadArg` — `spv_path` contains a NUL byte, or the submit lock
///   is poisoned.
#[rustler::nif]
fn apply_binary<'a>(
    env: Env<'a>,
    a: ResourceArc<VulkanTensor>,
    b: ResourceArc<VulkanTensor>,
    op: u32,
    spv_path: String,
) -> NifResult<Term<'a>> {
    if a.n_bytes != b.n_bytes {
        return Ok((atoms::error(), atoms::size_mismatch()).encode(env));
    }
    // Op range bumped 6→9 in v0.1 phase 1.1 — equal/less/greater
    // added to elementwise_binary.spv. Update spirit's .comp + .spv
    // when adding more.
    if op > 9 {
        return Ok((atoms::error(), atoms::bad_op()).encode(env));
    }
    let n_bytes = a.n_bytes;
    // f32 elements. A plain `as u32` would silently drop high bits and
    // process only part of an oversized buffer.
    let Ok(n_elems) = u32::try_from(n_bytes / 4) else {
        return Ok((atoms::error(), atoms::size_mismatch()).encode(env));
    };
    // Convert the path before allocating: failing here after
    // nxv_buf_alloc would return early and leak the output buffer.
    let cstr = std::ffi::CString::new(spv_path).map_err(|_| Error::BadArg)?;
    let _g = SUBMIT_LOCK.lock().map_err(|_| Error::BadArg)?;
    let out_handle = unsafe { nxv_buf_alloc(n_bytes) };
    if out_handle.is_null() {
        return Ok((atoms::error(), atoms::alloc_failed()).encode(env));
    }
    let rc = unsafe {
        nxv_apply_binary(out_handle, a.handle, b.handle, n_elems, op, cstr.as_ptr())
    };
    if rc != 0 {
        // Not yet owned by a VulkanTensor, so release it manually.
        unsafe { nxv_buf_free(out_handle) };
        return Ok((atoms::error(), atoms::dispatch_failed()).encode(env));
    }
    let out = VulkanTensor { handle: out_handle, n_bytes };
    Ok((atoms::ok(), ResourceArc::new(out)).encode(env))
}
/// Apply an elementwise unary op. Allocates a fresh output buffer.
#[rustler::nif]
fn apply_unary<'a>(
env: Env<'a>,
a: ResourceArc<VulkanTensor>,
op: u32,
spv_path: String,
) -> NifResult<Term<'a>> {
if op > 14 {
return Ok((atoms::error(), atoms::bad_op()).encode(env));
}
let n_bytes = a.n_bytes;
let n_elems = (n_bytes / 4) as u32;
let _g = SUBMIT_LOCK.lock().map_err(|_| Error::BadArg)?;
let out_handle = unsafe { nxv_buf_alloc(n_bytes) };
if out_handle.is_null() {
return Ok((atoms::error(), atoms::alloc_failed()).encode(env));
}
let cstr = std::ffi::CString::new(spv_path).map_err(|_| Error::BadArg)?;
let rc = unsafe { nxv_apply_unary(out_handle, a.handle, n_elems, op, cstr.as_ptr()) };
if rc != 0 {
unsafe { nxv_buf_free(out_handle) };
return Ok((atoms::error(), atoms::dispatch_failed()).encode(env));
}
let out = VulkanTensor { handle: out_handle, n_bytes };
Ok((atoms::ok(), ResourceArc::new(out)).encode(env))
}
/// Reduction (sum/min/max). Returns a host-side f32 scalar.
#[rustler::nif]
fn reduce_scalar<'a>(
env: Env<'a>,
input: ResourceArc<VulkanTensor>,
op: u32,
spv_path: String,
) -> NifResult<Term<'a>> {
if op > 2 {
return Ok((atoms::error(), atoms::bad_op()).encode(env));
}
let n_elems = (input.n_bytes / 4) as u32;
let mut out_scalar: f32 = 0.0;
let cstr = std::ffi::CString::new(spv_path).map_err(|_| Error::BadArg)?;
let _g = SUBMIT_LOCK.lock().map_err(|_| Error::BadArg)?;
let rc = unsafe {
nxv_reduce(&mut out_scalar as *mut f32, input.handle, n_elems, op, cstr.as_ptr())
};
if rc != 0 {
return Ok((atoms::error(), atoms::dispatch_failed()).encode(env));
}
Ok((atoms::ok(), out_scalar).encode(env))
}
/// Matmul C[M*N] = A[M*K] · B[K*N] on f32 tensors. Allocates the output.
///
/// `a` must hold exactly `m * k` f32 values and `b` exactly `k * n`.
/// Returns `{:ok, tensor}` with an `m * n * 4`-byte result.
///
/// Errors:
/// * `{:error, :size_mismatch}` — an input's byte size disagrees with the
///   given dimensions.
/// * `{:error, :alloc_failed}` — the output could not be allocated (this
///   includes an output size that does not fit in `u64`).
/// * `{:error, :dispatch_failed}` — reported by the shim.
/// * `Error::BadArg` — `spv_path` contains a NUL byte, or the submit lock
///   is poisoned.
#[rustler::nif]
fn matmul<'a>(
    env: Env<'a>,
    a: ResourceArc<VulkanTensor>,
    b: ResourceArc<VulkanTensor>,
    m: u32,
    n: u32,
    k: u32,
    spv_path: String,
) -> NifResult<Term<'a>> {
    // Sizes are computed in u64: in u32 the product wraps in release
    // builds, which let mismatched buffers pass the check below.
    let f32_bytes = |rows: u32, cols: u32| (u64::from(rows) * u64::from(cols)).checked_mul(4);
    if Some(a.n_bytes) != f32_bytes(m, k) || Some(b.n_bytes) != f32_bytes(k, n) {
        return Ok((atoms::error(), atoms::size_mismatch()).encode(env));
    }
    let Some(out_bytes) = f32_bytes(m, n) else {
        return Ok((atoms::error(), atoms::alloc_failed()).encode(env));
    };
    // Convert the path before allocating: failing here after
    // nxv_buf_alloc would return early and leak the output buffer.
    let cstr = std::ffi::CString::new(spv_path).map_err(|_| Error::BadArg)?;
    let _g = SUBMIT_LOCK.lock().map_err(|_| Error::BadArg)?;
    let out_handle = unsafe { nxv_buf_alloc(out_bytes) };
    if out_handle.is_null() {
        return Ok((atoms::error(), atoms::alloc_failed()).encode(env));
    }
    let rc = unsafe { nxv_matmul(out_handle, a.handle, b.handle, m, n, k, cstr.as_ptr()) };
    if rc != 0 {
        // Not yet owned by a VulkanTensor, so release it manually.
        unsafe { nxv_buf_free(out_handle) };
        return Ok((atoms::error(), atoms::dispatch_failed()).encode(env));
    }
    let out = VulkanTensor { handle: out_handle, n_bytes: out_bytes };
    Ok((atoms::ok(), ResourceArc::new(out)).encode(env))
}
/// Random fill. Allocates an output buffer of `n` f32 elements and asks
/// the shim to fill it using `seed`.
///
/// `dist`: 0=uniform [0,1), 1=normal N(0,1).
///
/// Errors:
/// * `{:error, :bad_op}` — `dist` is greater than 1.
/// * `{:error, :alloc_failed}` / `{:error, :dispatch_failed}` — reported
///   by the shim.
/// * `Error::BadArg` — `spv_path` contains a NUL byte, or the submit lock
///   is poisoned.
#[rustler::nif]
fn random<'a>(
    env: Env<'a>,
    n: u32,
    seed: u32,
    dist: u32,
    spv_path: String,
) -> NifResult<Term<'a>> {
    if dist > 1 {
        return Ok((atoms::error(), atoms::bad_op()).encode(env));
    }
    // Widen before multiplying: `n * 4` in u32 wraps once n exceeds 2^30,
    // which would allocate a buffer smaller than the shader writes.
    let n_bytes = u64::from(n) * 4;
    // Convert the path before allocating: failing here after
    // nxv_buf_alloc would return early and leak the output buffer.
    let cstr = std::ffi::CString::new(spv_path).map_err(|_| Error::BadArg)?;
    let _g = SUBMIT_LOCK.lock().map_err(|_| Error::BadArg)?;
    let out_handle = unsafe { nxv_buf_alloc(n_bytes) };
    if out_handle.is_null() {
        return Ok((atoms::error(), atoms::alloc_failed()).encode(env));
    }
    let rc = unsafe { nxv_random(out_handle, n, seed, dist, cstr.as_ptr()) };
    if rc != 0 {
        // Not yet owned by a VulkanTensor, so release it manually.
        unsafe { nxv_buf_free(out_handle) };
        return Ok((atoms::error(), atoms::dispatch_failed()).encode(env));
    }
    let out = VulkanTensor { handle: out_handle, n_bytes };
    Ok((atoms::ok(), ResourceArc::new(out)).encode(env))
}
/// 2D transpose. Input M×N row-major; output N×M row-major.
/// Allocates the output buffer (same byte size as the input).
///
/// Errors:
/// * `{:error, :size_mismatch}` — `a` is not exactly `m * n` f32 values.
/// * `{:error, :alloc_failed}` / `{:error, :dispatch_failed}` — reported
///   by the shim.
/// * `Error::BadArg` — `spv_path` contains a NUL byte, or the submit lock
///   is poisoned.
#[rustler::nif]
fn transpose<'a>(
    env: Env<'a>,
    a: ResourceArc<VulkanTensor>,
    m: u32,
    n: u32,
    spv_path: String,
) -> NifResult<Term<'a>> {
    // Computed in u64: `m * n * 4` in u32 wraps in release builds and
    // could accept a buffer that is too small for the given shape.
    let expected = (u64::from(m) * u64::from(n)).checked_mul(4);
    if Some(a.n_bytes) != expected {
        return Ok((atoms::error(), atoms::size_mismatch()).encode(env));
    }
    // Convert the path before allocating: failing here after
    // nxv_buf_alloc would return early and leak the output buffer.
    let cstr = std::ffi::CString::new(spv_path).map_err(|_| Error::BadArg)?;
    let _g = SUBMIT_LOCK.lock().map_err(|_| Error::BadArg)?;
    let out_handle = unsafe { nxv_buf_alloc(a.n_bytes) };
    if out_handle.is_null() {
        return Ok((atoms::error(), atoms::alloc_failed()).encode(env));
    }
    let rc = unsafe { nxv_transpose(out_handle, a.handle, m, n, cstr.as_ptr()) };
    if rc != 0 {
        // Not yet owned by a VulkanTensor, so release it manually.
        unsafe { nxv_buf_free(out_handle) };
        return Ok((atoms::error(), atoms::dispatch_failed()).encode(env));
    }
    let out = VulkanTensor { handle: out_handle, n_bytes: a.n_bytes };
    Ok((atoms::ok(), ResourceArc::new(out)).encode(env))
}
/// Cast f32↔f64. The .spv file determines direction; the output buffer
/// is `n * out_elem_bytes` bytes, where `n` is the element count.
///
/// Returns `{:ok, tensor}` on success.
///
/// Errors:
/// * `{:error, :alloc_failed}` / `{:error, :dispatch_failed}` — reported
///   by the shim.
/// * `Error::BadArg` — `spv_path` contains a NUL byte, or the submit lock
///   is poisoned.
///
/// NOTE(review): the size of `a` is not checked against `n` here; the
/// caller is trusted to pass a consistent element count.
#[rustler::nif]
fn cast<'a>(
    env: Env<'a>,
    a: ResourceArc<VulkanTensor>,
    n: u32,
    out_elem_bytes: u32,
    spv_path: String,
) -> NifResult<Term<'a>> {
    // Two u32 factors widened to u64 cannot overflow.
    let out_bytes = (n as u64) * (out_elem_bytes as u64);
    // Convert the path before allocating: failing here after
    // nxv_buf_alloc would return early and leak the output buffer.
    let cstr = std::ffi::CString::new(spv_path).map_err(|_| Error::BadArg)?;
    let _g = SUBMIT_LOCK.lock().map_err(|_| Error::BadArg)?;
    let out_handle = unsafe { nxv_buf_alloc(out_bytes) };
    if out_handle.is_null() {
        return Ok((atoms::error(), atoms::alloc_failed()).encode(env));
    }
    let rc = unsafe { nxv_cast(out_handle, a.handle, n, cstr.as_ptr()) };
    if rc != 0 {
        // Not yet owned by a VulkanTensor, so release it manually.
        unsafe { nxv_buf_free(out_handle) };
        return Ok((atoms::error(), atoms::dispatch_failed()).encode(env));
    }
    let out = VulkanTensor { handle: out_handle, n_bytes: out_bytes };
    Ok((atoms::ok(), ResourceArc::new(out)).encode(env))
}
/// Per-axis reduction over a virtual 3-D layout (outer, reduce, inner).
/// Output is (outer, inner) row-major, `outer * inner * 4` bytes (f32).
///
/// `op` must be in `0..=2`. Returns `{:ok, tensor}` on success.
///
/// Errors:
/// * `{:error, :bad_op}` — `op` is greater than 2.
/// * `{:error, :alloc_failed}` — the output could not be allocated (this
///   includes an output size that does not fit in `u64`).
/// * `{:error, :dispatch_failed}` — reported by the shim.
/// * `Error::BadArg` — `spv_path` contains a NUL byte, or the submit lock
///   is poisoned.
///
/// NOTE(review): the size of `a` is not checked against
/// `outer * reduce_size * inner`; the caller is trusted.
#[rustler::nif]
fn reduce_axis<'a>(
    env: Env<'a>,
    a: ResourceArc<VulkanTensor>,
    outer: u32,
    reduce_size: u32,
    inner: u32,
    op: u32,
    spv_path: String,
) -> NifResult<Term<'a>> {
    if op > 2 {
        return Ok((atoms::error(), atoms::bad_op()).encode(env));
    }
    // outer * inner fits in u64, but the `* 4` can wrap; a wrapped size
    // would allocate a buffer smaller than the shader writes.
    let Some(out_bytes) = (u64::from(outer) * u64::from(inner)).checked_mul(4) else {
        return Ok((atoms::error(), atoms::alloc_failed()).encode(env));
    };
    // Convert the path before allocating: failing here after
    // nxv_buf_alloc would return early and leak the output buffer.
    let cstr = std::ffi::CString::new(spv_path).map_err(|_| Error::BadArg)?;
    let _g = SUBMIT_LOCK.lock().map_err(|_| Error::BadArg)?;
    let out_handle = unsafe { nxv_buf_alloc(out_bytes) };
    if out_handle.is_null() {
        return Ok((atoms::error(), atoms::alloc_failed()).encode(env));
    }
    let rc = unsafe {
        nxv_reduce_axis(out_handle, a.handle, outer, reduce_size, inner, op, cstr.as_ptr())
    };
    if rc != 0 {
        // Not yet owned by a VulkanTensor, so release it manually.
        unsafe { nxv_buf_free(out_handle) };
        return Ok((atoms::error(), atoms::dispatch_failed()).encode(env));
    }
    let out = VulkanTensor { handle: out_handle, n_bytes: out_bytes };
    Ok((atoms::ok(), ResourceArc::new(out)).encode(env))
}
/// Broadcast elementwise binary op on f32 tensors.
///
/// `out_shape`, `a_strides`, `b_strides` are length-4 vectors padded with
/// 0. A stride of 0 on an axis means broadcast (any coord on that axis
/// maps to index 0). The output holds the product of the first `ndim`
/// entries of `out_shape` f32 values.
///
/// Errors:
/// * `{:error, :bad_op}` — `op` is greater than 9, `ndim` is outside
///   `1..=4`, or one of the vectors is not exactly 4 long.
/// * `{:error, :alloc_failed}` — the output could not be allocated (this
///   includes an output size that does not fit in `u64`).
/// * `{:error, :dispatch_failed}` — reported by the shim.
/// * `Error::BadArg` — `spv_path` contains a NUL byte, or the submit lock
///   is poisoned.
#[rustler::nif]
fn apply_binary_broadcast<'a>(
    env: Env<'a>,
    a: ResourceArc<VulkanTensor>,
    b: ResourceArc<VulkanTensor>,
    op: u32,
    ndim: u32,
    out_shape: Vec<u32>,
    a_strides: Vec<u32>,
    b_strides: Vec<u32>,
    spv_path: String,
) -> NifResult<Term<'a>> {
    if op > 9 {
        return Ok((atoms::error(), atoms::bad_op()).encode(env));
    }
    if ndim == 0 || ndim > 4 || out_shape.len() != 4
        || a_strides.len() != 4 || b_strides.len() != 4 {
        return Ok((atoms::error(), atoms::bad_op()).encode(env));
    }
    // bytes = 4 * product(out_shape[..ndim]), with overflow detected
    // instead of wrapping to an undersized allocation.
    let out_bytes = out_shape[..ndim as usize]
        .iter()
        .try_fold(4u64, |acc, &dim| acc.checked_mul(u64::from(dim)));
    let Some(out_bytes) = out_bytes else {
        return Ok((atoms::error(), atoms::alloc_failed()).encode(env));
    };
    // Convert the path before allocating: failing here after
    // nxv_buf_alloc would return early and leak the output buffer.
    let cstr = std::ffi::CString::new(spv_path).map_err(|_| Error::BadArg)?;
    let _g = SUBMIT_LOCK.lock().map_err(|_| Error::BadArg)?;
    let out_handle = unsafe { nxv_buf_alloc(out_bytes) };
    if out_handle.is_null() {
        return Ok((atoms::error(), atoms::alloc_failed()).encode(env));
    }
    let rc = unsafe {
        nxv_apply_binary_broadcast(
            out_handle, a.handle, b.handle,
            op, ndim,
            out_shape.as_ptr(), a_strides.as_ptr(), b_strides.as_ptr(),
            cstr.as_ptr(),
        )
    };
    if rc != 0 {
        // Not yet owned by a VulkanTensor, so release it manually.
        unsafe { nxv_buf_free(out_handle) };
        return Ok((atoms::error(), atoms::dispatch_failed()).encode(env));
    }
    let out = VulkanTensor { handle: out_handle, n_bytes: out_bytes };
    Ok((atoms::ok(), ResourceArc::new(out)).encode(env))
}
/// Path A — fused elementwise chain over two f32 tensors.
///
/// `ops` holds 1 to 8 op codes; the array handed to the shim is always 8
/// long, with unused slots set to 255 (nop). Op codes:
///   0..6     binary (add/mul/sub/div/pow/max/min)
///   100..114 unary  (exp..expm1)
///
/// Returns `{:ok, tensor}` with the same byte size as the inputs.
///
/// Errors:
/// * `{:error, :bad_op}` — `ops` is empty or longer than 8.
/// * `{:error, :size_mismatch}` — `a` and `b` differ in byte size, or the
///   element count does not fit the `u32` the shim takes.
/// * `{:error, :alloc_failed}` / `{:error, :dispatch_failed}` — reported
///   by the shim.
/// * `Error::BadArg` — `spv_path` contains a NUL byte, or the submit lock
///   is poisoned.
#[rustler::nif]
fn fused_chain<'a>(
    env: Env<'a>,
    a: ResourceArc<VulkanTensor>,
    b: ResourceArc<VulkanTensor>,
    ops: Vec<u32>,
    spv_path: String,
) -> NifResult<Term<'a>> {
    if ops.is_empty() || ops.len() > 8 {
        return Ok((atoms::error(), atoms::bad_op()).encode(env));
    }
    if a.n_bytes != b.n_bytes {
        return Ok((atoms::error(), atoms::size_mismatch()).encode(env));
    }
    // f32 elements; refuse counts that an `as u32` cast would truncate.
    let Ok(n) = u32::try_from(a.n_bytes / 4) else {
        return Ok((atoms::error(), atoms::size_mismatch()).encode(env));
    };
    let n_ops = ops.len() as u32; // at most 8, checked above
    let mut padded: [u32; 8] = [255; 8];
    padded[..ops.len()].copy_from_slice(&ops);
    // Convert the path before allocating: failing here after
    // nxv_buf_alloc would return early and leak the output buffer.
    let cstr = std::ffi::CString::new(spv_path).map_err(|_| Error::BadArg)?;
    let _g = SUBMIT_LOCK.lock().map_err(|_| Error::BadArg)?;
    let out_handle = unsafe { nxv_buf_alloc(a.n_bytes) };
    if out_handle.is_null() {
        return Ok((atoms::error(), atoms::alloc_failed()).encode(env));
    }
    let rc = unsafe {
        nxv_fused_chain(out_handle, a.handle, b.handle, n, n_ops, padded.as_ptr(), cstr.as_ptr())
    };
    if rc != 0 {
        // Not yet owned by a VulkanTensor, so release it manually.
        unsafe { nxv_buf_free(out_handle) };
        return Ok((atoms::error(), atoms::dispatch_failed()).encode(env));
    }
    let out = VulkanTensor { handle: out_handle, n_bytes: a.n_bytes };
    Ok((atoms::ok(), ResourceArc::new(out)).encode(env))
}
/// f64 reduce_axis. Output is (outer*inner) f64 (8 bytes/element).
///
/// Uses the same `nxv_reduce_axis` shim as the f32 variant; only the
/// output element width differs. `op` must be in `0..=2`.
///
/// Errors:
/// * `{:error, :bad_op}` — `op` is greater than 2.
/// * `{:error, :alloc_failed}` — the output could not be allocated (this
///   includes an output size that does not fit in `u64`).
/// * `{:error, :dispatch_failed}` — reported by the shim.
/// * `Error::BadArg` — `spv_path` contains a NUL byte, or the submit lock
///   is poisoned.
///
/// NOTE(review): the size of `a` is not checked against
/// `outer * reduce_size * inner`; the caller is trusted.
#[rustler::nif]
fn reduce_axis_f64<'a>(
    env: Env<'a>,
    a: ResourceArc<VulkanTensor>,
    outer: u32,
    reduce_size: u32,
    inner: u32,
    op: u32,
    spv_path: String,
) -> NifResult<Term<'a>> {
    if op > 2 {
        return Ok((atoms::error(), atoms::bad_op()).encode(env));
    }
    // outer * inner fits in u64, but the `* 8` can wrap; a wrapped size
    // would allocate a buffer smaller than the shader writes.
    let Some(out_bytes) = (u64::from(outer) * u64::from(inner)).checked_mul(8) else {
        return Ok((atoms::error(), atoms::alloc_failed()).encode(env));
    };
    // Convert the path before allocating: failing here after
    // nxv_buf_alloc would return early and leak the output buffer.
    let cstr = std::ffi::CString::new(spv_path).map_err(|_| Error::BadArg)?;
    let _g = SUBMIT_LOCK.lock().map_err(|_| Error::BadArg)?;
    let out_handle = unsafe { nxv_buf_alloc(out_bytes) };
    if out_handle.is_null() {
        return Ok((atoms::error(), atoms::alloc_failed()).encode(env));
    }
    let rc = unsafe {
        nxv_reduce_axis(out_handle, a.handle, outer, reduce_size, inner, op, cstr.as_ptr())
    };
    if rc != 0 {
        // Not yet owned by a VulkanTensor, so release it manually.
        unsafe { nxv_buf_free(out_handle) };
        return Ok((atoms::error(), atoms::dispatch_failed()).encode(env));
    }
    let out = VulkanTensor { handle: out_handle, n_bytes: out_bytes };
    Ok((atoms::ok(), ResourceArc::new(out)).encode(env))
}
/// f64 broadcast elementwise binary. Same shim as f32 broadcast; the
/// output holds the product of the first `ndim` entries of `out_shape`
/// f64 values (8 bytes each).
///
/// `out_shape`, `a_strides`, `b_strides` must each be exactly 4 long.
///
/// Errors:
/// * `{:error, :bad_op}` — `op` is greater than 9, `ndim` is outside
///   `1..=4`, or one of the vectors is not exactly 4 long.
/// * `{:error, :alloc_failed}` — the output could not be allocated (this
///   includes an output size that does not fit in `u64`).
/// * `{:error, :dispatch_failed}` — reported by the shim.
/// * `Error::BadArg` — `spv_path` contains a NUL byte, or the submit lock
///   is poisoned.
#[rustler::nif]
fn apply_binary_broadcast_f64<'a>(
    env: Env<'a>,
    a: ResourceArc<VulkanTensor>,
    b: ResourceArc<VulkanTensor>,
    op: u32,
    ndim: u32,
    out_shape: Vec<u32>,
    a_strides: Vec<u32>,
    b_strides: Vec<u32>,
    spv_path: String,
) -> NifResult<Term<'a>> {
    if op > 9 {
        return Ok((atoms::error(), atoms::bad_op()).encode(env));
    }
    if ndim == 0 || ndim > 4 || out_shape.len() != 4
        || a_strides.len() != 4 || b_strides.len() != 4 {
        return Ok((atoms::error(), atoms::bad_op()).encode(env));
    }
    // bytes = 8 * product(out_shape[..ndim]), with overflow detected
    // instead of wrapping to an undersized allocation.
    let out_bytes = out_shape[..ndim as usize]
        .iter()
        .try_fold(8u64, |acc, &dim| acc.checked_mul(u64::from(dim)));
    let Some(out_bytes) = out_bytes else {
        return Ok((atoms::error(), atoms::alloc_failed()).encode(env));
    };
    // Convert the path before allocating: failing here after
    // nxv_buf_alloc would return early and leak the output buffer.
    let cstr = std::ffi::CString::new(spv_path).map_err(|_| Error::BadArg)?;
    let _g = SUBMIT_LOCK.lock().map_err(|_| Error::BadArg)?;
    let out_handle = unsafe { nxv_buf_alloc(out_bytes) };
    if out_handle.is_null() {
        return Ok((atoms::error(), atoms::alloc_failed()).encode(env));
    }
    let rc = unsafe {
        nxv_apply_binary_broadcast(
            out_handle, a.handle, b.handle,
            op, ndim,
            out_shape.as_ptr(), a_strides.as_ptr(), b_strides.as_ptr(),
            cstr.as_ptr(),
        )
    };
    if rc != 0 {
        // Not yet owned by a VulkanTensor, so release it manually.
        unsafe { nxv_buf_free(out_handle) };
        return Ok((atoms::error(), atoms::dispatch_failed()).encode(env));
    }
    let out = VulkanTensor { handle: out_handle, n_bytes: out_bytes };
    Ok((atoms::ok(), ResourceArc::new(out)).encode(env))
}
/// logsumexp over a single reduced axis of an f32 tensor.
///
/// Reuses nxv_reduce_axis's shim (same push layout); the `op` argument is
/// unused by this shader but passed as 0 for parity. Output is
/// `outer * inner` f32 values. The numerically-stable two-pass scheme
/// mentioned in earlier notes lives in the shader, not in this wrapper.
///
/// Errors:
/// * `{:error, :alloc_failed}` — the output could not be allocated (this
///   includes an output size that does not fit in `u64`).
/// * `{:error, :dispatch_failed}` — reported by the shim.
/// * `Error::BadArg` — `spv_path` contains a NUL byte, or the submit lock
///   is poisoned.
#[rustler::nif]
fn logsumexp<'a>(
    env: Env<'a>,
    a: ResourceArc<VulkanTensor>,
    outer: u32,
    reduce_size: u32,
    inner: u32,
    spv_path: String,
) -> NifResult<Term<'a>> {
    // outer * inner fits in u64, but the `* 4` can wrap; a wrapped size
    // would allocate a buffer smaller than the shader writes.
    let Some(out_bytes) = (u64::from(outer) * u64::from(inner)).checked_mul(4) else {
        return Ok((atoms::error(), atoms::alloc_failed()).encode(env));
    };
    // Convert the path before allocating: failing here after
    // nxv_buf_alloc would return early and leak the output buffer.
    let cstr = std::ffi::CString::new(spv_path).map_err(|_| Error::BadArg)?;
    let _g = SUBMIT_LOCK.lock().map_err(|_| Error::BadArg)?;
    let out_handle = unsafe { nxv_buf_alloc(out_bytes) };
    if out_handle.is_null() {
        return Ok((atoms::error(), atoms::alloc_failed()).encode(env));
    }
    let rc = unsafe {
        nxv_reduce_axis(out_handle, a.handle, outer, reduce_size, inner, 0, cstr.as_ptr())
    };
    if rc != 0 {
        // Not yet owned by a VulkanTensor, so release it manually.
        unsafe { nxv_buf_free(out_handle) };
        return Ok((atoms::error(), atoms::dispatch_failed()).encode(env));
    }
    let out = VulkanTensor { handle: out_handle, n_bytes: out_bytes };
    Ok((atoms::ok(), ResourceArc::new(out)).encode(env))
}
/// f64 elementwise binary. Accepts op codes 0..=6 (the same meanings as
/// the f32 path); the shader at `spv_path` makes the precision choice.
///
/// Goes through the same `nxv_apply_binary` shim as the f32 variant, with
/// the element count computed at 8 bytes per element.
///
/// Errors:
/// * `{:error, :bad_op}` — `op` is greater than 6.
/// * `{:error, :size_mismatch}` — `a` and `b` differ in byte size, or the
///   element count does not fit the `u32` the shim takes.
/// * `{:error, :alloc_failed}` / `{:error, :dispatch_failed}` — reported
///   by the shim.
/// * `Error::BadArg` — `spv_path` contains a NUL byte, or the submit lock
///   is poisoned.
#[rustler::nif]
fn apply_binary_f64<'a>(
    env: Env<'a>,
    a: ResourceArc<VulkanTensor>,
    b: ResourceArc<VulkanTensor>,
    op: u32,
    spv_path: String,
) -> NifResult<Term<'a>> {
    if op > 6 {
        return Ok((atoms::error(), atoms::bad_op()).encode(env));
    }
    if a.n_bytes != b.n_bytes {
        return Ok((atoms::error(), atoms::size_mismatch()).encode(env));
    }
    // f64 elements; refuse counts that an `as u32` cast would truncate.
    let Ok(n_elems) = u32::try_from(a.n_bytes / 8) else {
        return Ok((atoms::error(), atoms::size_mismatch()).encode(env));
    };
    // Convert the path before allocating: failing here after
    // nxv_buf_alloc would return early and leak the output buffer.
    let cstr = std::ffi::CString::new(spv_path).map_err(|_| Error::BadArg)?;
    let _g = SUBMIT_LOCK.lock().map_err(|_| Error::BadArg)?;
    let out_handle = unsafe { nxv_buf_alloc(a.n_bytes) };
    if out_handle.is_null() {
        return Ok((atoms::error(), atoms::alloc_failed()).encode(env));
    }
    let rc = unsafe {
        nxv_apply_binary(out_handle, a.handle, b.handle, n_elems, op, cstr.as_ptr())
    };
    if rc != 0 {
        // Not yet owned by a VulkanTensor, so release it manually.
        unsafe { nxv_buf_free(out_handle) };
        return Ok((atoms::error(), atoms::dispatch_failed()).encode(env));
    }
    let out = VulkanTensor { handle: out_handle, n_bytes: a.n_bytes };
    Ok((atoms::ok(), ResourceArc::new(out)).encode(env))
}
/// f64 elementwise unary.
#[rustler::nif]
fn apply_unary_f64<'a>(
env: Env<'a>,
a: ResourceArc<VulkanTensor>,
op: u32,
spv_path: String,
) -> NifResult<Term<'a>> {
if op > 14 {
return Ok((atoms::error(), atoms::bad_op()).encode(env));
}
let n_elems = (a.n_bytes / 8) as u32;
let _g = SUBMIT_LOCK.lock().map_err(|_| Error::BadArg)?;
let out_handle = unsafe { nxv_buf_alloc(a.n_bytes) };
if out_handle.is_null() {
return Ok((atoms::error(), atoms::alloc_failed()).encode(env));
}
let cstr = std::ffi::CString::new(spv_path).map_err(|_| Error::BadArg)?;
let rc = unsafe { nxv_apply_unary(out_handle, a.handle, n_elems, op, cstr.as_ptr()) };
if rc != 0 {
unsafe { nxv_buf_free(out_handle) };
return Ok((atoms::error(), atoms::dispatch_failed()).encode(env));
}
let out = VulkanTensor { handle: out_handle, n_bytes: a.n_bytes };
Ok((atoms::ok(), ResourceArc::new(out)).encode(env))
}
/// Matmul variant — caller picks the .spv path and the workgroup
/// output tile size. Used by Nx.Vulkan auto-select to dispatch the
/// right shader for a given (M, N, K) shape.
///
/// `a` must be exactly `m * k * 4` bytes and `b` exactly `k * n * 4`
/// bytes; the output is `m * n * 4` bytes.
///
/// Returns `{:ok, tensor}`, or `{:error, reason}` where reason is
/// `:size_mismatch`, `:alloc_failed` or `:dispatch_failed`. Fails with
/// `badarg` if `spv_path` contains a NUL byte or the submit lock is
/// poisoned.
#[rustler::nif]
fn matmul_v<'a>(
    env: Env<'a>,
    a: ResourceArc<VulkanTensor>,
    b: ResourceArc<VulkanTensor>,
    m: u32,
    n: u32,
    k: u32,
    tile_m: u32,
    tile_n: u32,
    spv_path: String,
) -> NifResult<Term<'a>> {
    // Byte size of a rows × cols f32 matrix, computed in u64. Doing the
    // multiplication in u32 (as `m * k * 4`) overflows for large shapes
    // before the widening cast ever happens.
    let f32_bytes =
        |rows: u32, cols: u32| (u64::from(rows) * u64::from(cols)).checked_mul(4);

    if Some(a.n_bytes) != f32_bytes(m, k) || Some(b.n_bytes) != f32_bytes(k, n) {
        return Ok((atoms::error(), atoms::size_mismatch()).encode(env));
    }
    let Some(out_bytes) = f32_bytes(m, n) else {
        return Ok((atoms::error(), atoms::alloc_failed()).encode(env));
    };
    // Convert the path before allocating so a bad path cannot leak the
    // output buffer.
    let cstr = std::ffi::CString::new(spv_path).map_err(|_| Error::BadArg)?;
    let _g = SUBMIT_LOCK.lock().map_err(|_| Error::BadArg)?;

    let out_handle = unsafe { nxv_buf_alloc(out_bytes) };
    if out_handle.is_null() {
        return Ok((atoms::error(), atoms::alloc_failed()).encode(env));
    }
    // SAFETY: out_handle was just checked non-null, and cstr outlives the call.
    let rc = unsafe {
        nxv_matmul_v(out_handle, a.handle, b.handle, m, n, k, tile_m, tile_n, cstr.as_ptr())
    };
    if rc != 0 {
        unsafe { nxv_buf_free(out_handle) };
        return Ok((atoms::error(), atoms::dispatch_failed()).encode(env));
    }
    let out = VulkanTensor { handle: out_handle, n_bytes: out_bytes };
    Ok((atoms::ok(), ResourceArc::new(out)).encode(env))
}
/// kinetic_energy: 0.5 * sum(p² * inv_mass) per workgroup.
/// Output is `ceil(n / 256)` partial sums (4 bytes each) — caller
/// reduces them on host or via a follow-up reduce_axis dispatch.
///
/// `p` and `inv_mass` must have the same byte size; `n` is that size
/// divided by 4.
///
/// Returns `{:ok, tensor}`, or `{:error, reason}` where reason is
/// `:size_mismatch`, `:alloc_failed` or `:dispatch_failed`. Fails with
/// `badarg` if `spv_path` contains a NUL byte or the submit lock is
/// poisoned.
#[rustler::nif]
fn kinetic_energy<'a>(
    env: Env<'a>,
    p: ResourceArc<VulkanTensor>,
    inv_mass: ResourceArc<VulkanTensor>,
    spv_path: String,
) -> NifResult<Term<'a>> {
    if p.n_bytes != inv_mass.n_bytes {
        return Ok((atoms::error(), atoms::size_mismatch()).encode(env));
    }
    // Refuse element counts that would be truncated by the shim's u32.
    let Ok(n) = u32::try_from(p.n_bytes / 4) else {
        return Ok((atoms::error(), atoms::size_mismatch()).encode(env));
    };
    // Round up in u64: `(n + 255) / 256` in u32 overflows near u32::MAX.
    let n_groups = u64::from(n).div_ceil(256);
    let out_bytes = n_groups * 4;
    // Convert the path before allocating so a bad path cannot leak the
    // output buffer.
    let cstr = std::ffi::CString::new(spv_path).map_err(|_| Error::BadArg)?;
    let _g = SUBMIT_LOCK.lock().map_err(|_| Error::BadArg)?;

    let out_handle = unsafe { nxv_buf_alloc(out_bytes) };
    if out_handle.is_null() {
        return Ok((atoms::error(), atoms::alloc_failed()).encode(env));
    }
    // SAFETY: out_handle was just checked non-null, and cstr outlives the call.
    let rc = unsafe {
        nxv_kinetic_energy(out_handle, p.handle, inv_mass.handle, n, cstr.as_ptr())
    };
    if rc != 0 {
        unsafe { nxv_buf_free(out_handle) };
        return Ok((atoms::error(), atoms::dispatch_failed()).encode(env));
    }
    let out = VulkanTensor { handle: out_handle, n_bytes: out_bytes };
    Ok((atoms::ok(), ResourceArc::new(out)).encode(env))
}
/// normal_logpdf: -0.5*((x-mu)/sigma)² - log(sigma) - 0.5*log(2π).
/// Output shape matches x.
///
/// `x`, `mu` and `sigma` must all have the same byte size.
///
/// Returns `{:ok, tensor}`, or `{:error, reason}` where reason is
/// `:size_mismatch`, `:alloc_failed` or `:dispatch_failed`. Fails with
/// `badarg` if `spv_path` contains a NUL byte or the submit lock is
/// poisoned.
#[rustler::nif]
fn normal_logpdf<'a>(
    env: Env<'a>,
    x: ResourceArc<VulkanTensor>,
    mu: ResourceArc<VulkanTensor>,
    sigma: ResourceArc<VulkanTensor>,
    spv_path: String,
) -> NifResult<Term<'a>> {
    if x.n_bytes != mu.n_bytes || x.n_bytes != sigma.n_bytes {
        return Ok((atoms::error(), atoms::size_mismatch()).encode(env));
    }
    // 4 bytes per element; refuse counts that would be truncated by u32.
    let Ok(n) = u32::try_from(x.n_bytes / 4) else {
        return Ok((atoms::error(), atoms::size_mismatch()).encode(env));
    };
    let out_bytes = x.n_bytes;
    // Convert the path before allocating so a bad path cannot leak the
    // output buffer.
    let cstr = std::ffi::CString::new(spv_path).map_err(|_| Error::BadArg)?;
    let _g = SUBMIT_LOCK.lock().map_err(|_| Error::BadArg)?;

    let out_handle = unsafe { nxv_buf_alloc(out_bytes) };
    if out_handle.is_null() {
        return Ok((atoms::error(), atoms::alloc_failed()).encode(env));
    }
    // SAFETY: out_handle was just checked non-null, and cstr outlives the call.
    let rc = unsafe {
        nxv_normal_logpdf(out_handle, x.handle, mu.handle, sigma.handle, n, cstr.as_ptr())
    };
    if rc != 0 {
        unsafe { nxv_buf_free(out_handle) };
        return Ok((atoms::error(), atoms::dispatch_failed()).encode(env));
    }
    let out = VulkanTensor { handle: out_handle, n_bytes: out_bytes };
    Ok((atoms::ok(), ResourceArc::new(out)).encode(env))
}
/// leapfrog_normal: fused NUTS leapfrog step for univariate Normal.
/// Returns {q_new, p_new}. mu, sigma, eps come in as f32 push constants.
/// q, p, inv_mass must all share byte size (n elements × 4 bytes).
#[rustler::nif]
fn leapfrog_normal<'a>(
env: Env<'a>,
q: ResourceArc<VulkanTensor>,
p: ResourceArc<VulkanTensor>,
inv_mass: ResourceArc<VulkanTensor>,
eps: f64,
mu: f64,
sigma: f64,
spv_path: String,
) -> NifResult<Term<'a>> {
if q.n_bytes != p.n_bytes || q.n_bytes != inv_mass.n_bytes {
return Ok((atoms::error(), atoms::size_mismatch()).encode(env));
}
let n = (q.n_bytes / 4) as u32;
let out_bytes = q.n_bytes;
let _g = SUBMIT_LOCK.lock().map_err(|_| Error::BadArg)?;
let q_new_handle = unsafe { nxv_buf_alloc(out_bytes) };
if q_new_handle.is_null() {
return Ok((atoms::error(), atoms::alloc_failed()).encode(env));
}
let p_new_handle = unsafe { nxv_buf_alloc(out_bytes) };
if p_new_handle.is_null() {
unsafe { nxv_buf_free(q_new_handle) };
return Ok((atoms::error(), atoms::alloc_failed()).encode(env));
}
let cstr = std::ffi::CString::new(spv_path).map_err(|_| Error::BadArg)?;
let rc = unsafe {
nxv_leapfrog_normal(
q_new_handle,
p_new_handle,
q.handle,
p.handle,
inv_mass.handle,
n,
eps as f32,
mu as f32,
sigma as f32,
cstr.as_ptr(),
)
};
if rc != 0 {
unsafe { nxv_buf_free(q_new_handle) };
unsafe { nxv_buf_free(p_new_handle) };
return Ok((atoms::error(), atoms::dispatch_failed()).encode(env));
}
let q_new = VulkanTensor { handle: q_new_handle, n_bytes: out_bytes };
let p_new = VulkanTensor { handle: p_new_handle, n_bytes: out_bytes };
Ok((atoms::ok(), (ResourceArc::new(q_new), ResourceArc::new(p_new))).encode(env))
}
/// leapfrog_chain_synth: dispatch a synthesized chain shader.
///
/// Generic NIF for shaders generated by `Nx.Vulkan.ShaderTemplate`.
/// The push-constants block is opaque to this NIF — `push` is a raw
/// binary assembled by the Elixir-side codegen (it knows the per-shader
/// layout). Maximum 128 bytes.
///
/// Allocates the four chain output buffers internally (q_chain, p_chain,
/// grad_chain are k*n*4 bytes each; logp_chain is k*4 bytes). `n` is
/// derived from `q_init.n_bytes`.
///
/// Returns `{:ok, {q_chain, p_chain, grad_chain, logp_chain}}`.
#[rustler::nif]
fn leapfrog_chain_synth<'a>(
env: Env<'a>,
q_init: ResourceArc<VulkanTensor>,
p_init: ResourceArc<VulkanTensor>,
inv_mass: ResourceArc<VulkanTensor>,
push: Binary<'a>,
k: u32,
spv_path: String,
) -> NifResult<Term<'a>> {
// R2.2.3 repack: the inv_mass slot at binding 2 may carry
// obs[0..n_obs-1] followed by inv_mass[0..d-1] (total
// (n_obs+d)*4 bytes), so it need not match q_init/p_init's
// d*4 byte size. Only q_init and p_init must agree — chain_bytes
// is computed from q_init.n_bytes regardless.
if q_init.n_bytes != p_init.n_bytes {
return Ok((atoms::error(), atoms::size_mismatch()).encode(env));
}
if k == 0 {
return Ok((atoms::error(), atoms::bad_op()).encode(env));
}
if push.len() == 0 || push.len() > 128 {
return Ok((atoms::error(), atoms::bad_op()).encode(env));
}
let chain_bytes = q_init.n_bytes * (k as u64);
let logp_bytes = (k as u64) * 4;
let _g = SUBMIT_LOCK.lock().map_err(|_| Error::BadArg)?;
let q_chain_handle = unsafe { nxv_buf_alloc(chain_bytes) };
if q_chain_handle.is_null() {
return Ok((atoms::error(), atoms::alloc_failed()).encode(env));
}
let p_chain_handle = unsafe { nxv_buf_alloc(chain_bytes) };
if p_chain_handle.is_null() {
unsafe { nxv_buf_free(q_chain_handle) };
return Ok((atoms::error(), atoms::alloc_failed()).encode(env));
}
let grad_chain_handle = unsafe { nxv_buf_alloc(chain_bytes) };
if grad_chain_handle.is_null() {
unsafe { nxv_buf_free(q_chain_handle) };
unsafe { nxv_buf_free(p_chain_handle) };
return Ok((atoms::error(), atoms::alloc_failed()).encode(env));
}
let logp_chain_handle = unsafe { nxv_buf_alloc(logp_bytes) };
if logp_chain_handle.is_null() {
unsafe { nxv_buf_free(q_chain_handle) };
unsafe { nxv_buf_free(p_chain_handle) };
unsafe { nxv_buf_free(grad_chain_handle) };
return Ok((atoms::error(), atoms::alloc_failed()).encode(env));
}
let cstr = std::ffi::CString::new(spv_path).map_err(|_| Error::BadArg)?;
let rc = unsafe {
nxv_leapfrog_chain_synth(
q_chain_handle,
p_chain_handle,
grad_chain_handle,
logp_chain_handle,
q_init.handle,
p_init.handle,
inv_mass.handle,
push.as_slice().as_ptr() as *const c_void,
push.len() as u32,
cstr.as_ptr(),
)
};
if rc != 0 {
unsafe { nxv_buf_free(q_chain_handle) };
unsafe { nxv_buf_free(p_chain_handle) };
unsafe { nxv_buf_free(grad_chain_handle) };
unsafe { nxv_buf_free(logp_chain_handle) };
return Ok((atoms::error(), atoms::dispatch_failed()).encode(env));
}
let q_chain = VulkanTensor { handle: q_chain_handle, n_bytes: chain_bytes };
let p_chain = VulkanTensor { handle: p_chain_handle, n_bytes: chain_bytes };
let grad_chain = VulkanTensor { handle: grad_chain_handle, n_bytes: chain_bytes };
let logp_chain = VulkanTensor { handle: logp_chain_handle, n_bytes: logp_bytes };
Ok((
atoms::ok(),
(
ResourceArc::new(q_chain),
ResourceArc::new(p_chain),
ResourceArc::new(grad_chain),
ResourceArc::new(logp_chain),
),
)
.encode(env))
}
/// leapfrog_chain_normal: K-step fused leapfrog chain for univariate Normal.
/// Returns {q_chain, p_chain, grad_chain, logp_chain}. All four are
/// allocated by this NIF; q/p/grad chains are K*n*4 bytes each, logp_chain
/// is K*4 bytes. K must be a positive u32; n is derived from input byte size.
#[rustler::nif]
fn leapfrog_chain_normal<'a>(
env: Env<'a>,
q_init: ResourceArc<VulkanTensor>,
p_init: ResourceArc<VulkanTensor>,
inv_mass: ResourceArc<VulkanTensor>,
k: u32,
eps: f64,
mu: f64,
sigma: f64,
spv_path: String,
) -> NifResult<Term<'a>> {
if q_init.n_bytes != p_init.n_bytes || q_init.n_bytes != inv_mass.n_bytes {
return Ok((atoms::error(), atoms::size_mismatch()).encode(env));
}
if k == 0 {
return Ok((atoms::error(), atoms::bad_op()).encode(env));
}
let n = (q_init.n_bytes / 4) as u32;
let chain_bytes = q_init.n_bytes * (k as u64); // K * n * 4
let logp_bytes = (k as u64) * 4;
let _g = SUBMIT_LOCK.lock().map_err(|_| Error::BadArg)?;
let q_chain_handle = unsafe { nxv_buf_alloc(chain_bytes) };
if q_chain_handle.is_null() {
return Ok((atoms::error(), atoms::alloc_failed()).encode(env));
}
let p_chain_handle = unsafe { nxv_buf_alloc(chain_bytes) };
if p_chain_handle.is_null() {
unsafe { nxv_buf_free(q_chain_handle) };
return Ok((atoms::error(), atoms::alloc_failed()).encode(env));
}
let grad_chain_handle = unsafe { nxv_buf_alloc(chain_bytes) };
if grad_chain_handle.is_null() {
unsafe { nxv_buf_free(q_chain_handle) };
unsafe { nxv_buf_free(p_chain_handle) };
return Ok((atoms::error(), atoms::alloc_failed()).encode(env));
}
let logp_chain_handle = unsafe { nxv_buf_alloc(logp_bytes) };
if logp_chain_handle.is_null() {
unsafe { nxv_buf_free(q_chain_handle) };
unsafe { nxv_buf_free(p_chain_handle) };
unsafe { nxv_buf_free(grad_chain_handle) };
return Ok((atoms::error(), atoms::alloc_failed()).encode(env));
}
let cstr = std::ffi::CString::new(spv_path).map_err(|_| Error::BadArg)?;
let rc = unsafe {
nxv_leapfrog_chain_normal(
q_chain_handle,
p_chain_handle,
grad_chain_handle,
logp_chain_handle,
q_init.handle,
p_init.handle,
inv_mass.handle,
n,
k,
eps as f32,
mu as f32,
sigma as f32,
cstr.as_ptr(),
)
};
if rc != 0 {
unsafe { nxv_buf_free(q_chain_handle) };
unsafe { nxv_buf_free(p_chain_handle) };
unsafe { nxv_buf_free(grad_chain_handle) };
unsafe { nxv_buf_free(logp_chain_handle) };
return Ok((atoms::error(), atoms::dispatch_failed()).encode(env));
}
let q_chain = VulkanTensor { handle: q_chain_handle, n_bytes: chain_bytes };
let p_chain = VulkanTensor { handle: p_chain_handle, n_bytes: chain_bytes };
let grad_chain = VulkanTensor { handle: grad_chain_handle, n_bytes: chain_bytes };
let logp_chain = VulkanTensor { handle: logp_chain_handle, n_bytes: logp_bytes };
Ok((
atoms::ok(),
(
ResourceArc::new(q_chain),
ResourceArc::new(p_chain),
ResourceArc::new(grad_chain),
ResourceArc::new(logp_chain),
),
)
.encode(env))
}
/// leapfrog_chain_normal_lg: multi-workgroup K-step chain.
/// Returns {q_chain, p_chain, grad_chain, partial_logp}. partial_logp is
/// K * num_workgroups f32 (host sums num_workgroups partials per step).
#[rustler::nif]
fn leapfrog_chain_normal_lg<'a>(
env: Env<'a>,
q_init: ResourceArc<VulkanTensor>,
p_init: ResourceArc<VulkanTensor>,
inv_mass: ResourceArc<VulkanTensor>,
k: u32,
eps: f64,
mu: f64,
sigma: f64,
spv_path: String,
) -> NifResult<Term<'a>> {
if q_init.n_bytes != p_init.n_bytes || q_init.n_bytes != inv_mass.n_bytes {
return Ok((atoms::error(), atoms::size_mismatch()).encode(env));
}
if k == 0 {
return Ok((atoms::error(), atoms::bad_op()).encode(env));
}
let n = (q_init.n_bytes / 4) as u32;
let num_workgroups = (n + 255) / 256;
let chain_bytes = q_init.n_bytes * (k as u64);
let partial_bytes = (k as u64) * (num_workgroups as u64) * 4;
let _g = SUBMIT_LOCK.lock().map_err(|_| Error::BadArg)?;
let q_h = unsafe { nxv_buf_alloc(chain_bytes) };
if q_h.is_null() { return Ok((atoms::error(), atoms::alloc_failed()).encode(env)); }
let p_h = unsafe { nxv_buf_alloc(chain_bytes) };
if p_h.is_null() {
unsafe { nxv_buf_free(q_h) };
return Ok((atoms::error(), atoms::alloc_failed()).encode(env));
}
let g_h = unsafe { nxv_buf_alloc(chain_bytes) };
if g_h.is_null() {
unsafe { nxv_buf_free(q_h); nxv_buf_free(p_h) };
return Ok((atoms::error(), atoms::alloc_failed()).encode(env));
}
let pl_h = unsafe { nxv_buf_alloc(partial_bytes) };
if pl_h.is_null() {
unsafe { nxv_buf_free(q_h); nxv_buf_free(p_h); nxv_buf_free(g_h) };
return Ok((atoms::error(), atoms::alloc_failed()).encode(env));
}
let cstr = std::ffi::CString::new(spv_path).map_err(|_| Error::BadArg)?;
let rc = unsafe {
nxv_leapfrog_chain_normal_lg(
q_h, p_h, g_h, pl_h,
q_init.handle, p_init.handle, inv_mass.handle,
n, k, num_workgroups,
eps as f32, mu as f32, sigma as f32,
cstr.as_ptr(),
)
};
if rc != 0 {
unsafe { nxv_buf_free(q_h); nxv_buf_free(p_h);
nxv_buf_free(g_h); nxv_buf_free(pl_h) };
return Ok((atoms::error(), atoms::dispatch_failed()).encode(env));
}
let q_c = VulkanTensor { handle: q_h, n_bytes: chain_bytes };
let p_c = VulkanTensor { handle: p_h, n_bytes: chain_bytes };
let g_c = VulkanTensor { handle: g_h, n_bytes: chain_bytes };
let pl_c = VulkanTensor { handle: pl_h, n_bytes: partial_bytes };
Ok((
atoms::ok(),
(
ResourceArc::new(q_c),
ResourceArc::new(p_c),
ResourceArc::new(g_c),
ResourceArc::new(pl_c),
),
)
.encode(env))
}
/// leapfrog_chain_exponential: K-step chain for Exp(lambda) on the
/// unconstrained line (log-transform). Same I/O shape as the Normal
/// chain (returns {q, p, grad, logp} 4-tuple).
#[rustler::nif]
fn leapfrog_chain_exponential<'a>(
env: Env<'a>,
q_init: ResourceArc<VulkanTensor>,
p_init: ResourceArc<VulkanTensor>,
inv_mass: ResourceArc<VulkanTensor>,
k: u32,
eps: f64,
lambda: f64,
spv_path: String,
) -> NifResult<Term<'a>> {
if q_init.n_bytes != p_init.n_bytes || q_init.n_bytes != inv_mass.n_bytes {
return Ok((atoms::error(), atoms::size_mismatch()).encode(env));
}
if k == 0 {
return Ok((atoms::error(), atoms::bad_op()).encode(env));
}
let n = (q_init.n_bytes / 4) as u32;
let chain_bytes = q_init.n_bytes * (k as u64);
let logp_bytes = (k as u64) * 4;
let _g = SUBMIT_LOCK.lock().map_err(|_| Error::BadArg)?;
let q_h = unsafe { nxv_buf_alloc(chain_bytes) };
if q_h.is_null() { return Ok((atoms::error(), atoms::alloc_failed()).encode(env)); }
let p_h = unsafe { nxv_buf_alloc(chain_bytes) };
if p_h.is_null() {
unsafe { nxv_buf_free(q_h) };
return Ok((atoms::error(), atoms::alloc_failed()).encode(env));
}
let g_h = unsafe { nxv_buf_alloc(chain_bytes) };
if g_h.is_null() {
unsafe { nxv_buf_free(q_h); nxv_buf_free(p_h) };
return Ok((atoms::error(), atoms::alloc_failed()).encode(env));
}
let lc_h = unsafe { nxv_buf_alloc(logp_bytes) };
if lc_h.is_null() {
unsafe { nxv_buf_free(q_h); nxv_buf_free(p_h); nxv_buf_free(g_h) };
return Ok((atoms::error(), atoms::alloc_failed()).encode(env));
}
let cstr = std::ffi::CString::new(spv_path).map_err(|_| Error::BadArg)?;
let rc = unsafe {
nxv_leapfrog_chain_exponential(
q_h, p_h, g_h, lc_h,
q_init.handle, p_init.handle, inv_mass.handle,
n, k,
eps as f32, lambda as f32,
cstr.as_ptr(),
)
};
if rc != 0 {
unsafe { nxv_buf_free(q_h); nxv_buf_free(p_h);
nxv_buf_free(g_h); nxv_buf_free(lc_h) };
return Ok((atoms::error(), atoms::dispatch_failed()).encode(env));
}
let q_c = VulkanTensor { handle: q_h, n_bytes: chain_bytes };
let p_c = VulkanTensor { handle: p_h, n_bytes: chain_bytes };
let g_c = VulkanTensor { handle: g_h, n_bytes: chain_bytes };
let lc_c = VulkanTensor { handle: lc_h, n_bytes: logp_bytes };
Ok((
atoms::ok(),
(
ResourceArc::new(q_c),
ResourceArc::new(p_c),
ResourceArc::new(g_c),
ResourceArc::new(lc_c),
),
)
.encode(env))
}
// --- Phase 2 chain NIFs (Student-t, Cauchy, HalfNormal, Weibull) + f64 chain ---
//
// All five follow the same allocate-4-output-buffers / dispatch /
// return-4-tuple pattern as leapfrog_chain_exponential. The differences
// are the push-constant scalars and the underlying nxv_* dispatch.
// The f64 chain uses 8 bytes per element instead of 4.
#[rustler::nif]
fn leapfrog_chain_studentt<'a>(
env: Env<'a>,
q_init: ResourceArc<VulkanTensor>,
p_init: ResourceArc<VulkanTensor>,
inv_mass: ResourceArc<VulkanTensor>,
k: u32,
eps: f64, mu: f64, sigma: f64, nu: f64, logp_const: f64,
spv_path: String,
) -> NifResult<Term<'a>> {
if q_init.n_bytes != p_init.n_bytes || q_init.n_bytes != inv_mass.n_bytes {
return Ok((atoms::error(), atoms::size_mismatch()).encode(env));
}
if k == 0 { return Ok((atoms::error(), atoms::bad_op()).encode(env)); }
let n = (q_init.n_bytes / 4) as u32;
let chain_bytes = q_init.n_bytes * (k as u64);
let logp_bytes = (k as u64) * 4;
let _g = SUBMIT_LOCK.lock().map_err(|_| Error::BadArg)?;
let (qh, ph, gh, lh) = match alloc_4(chain_bytes, chain_bytes, chain_bytes, logp_bytes) {
Ok(t) => t,
Err(_) => return Ok((atoms::error(), atoms::alloc_failed()).encode(env)),
};
let cstr = std::ffi::CString::new(spv_path).map_err(|_| Error::BadArg)?;
let rc = unsafe {
nxv_leapfrog_chain_studentt(
qh, ph, gh, lh,
q_init.handle, p_init.handle, inv_mass.handle,
n, k,
eps as f32, mu as f32, sigma as f32, nu as f32, logp_const as f32,
cstr.as_ptr(),
)
};
encode_chain_result(env, rc, qh, ph, gh, lh, chain_bytes, logp_bytes)
}
#[rustler::nif]
fn leapfrog_chain_cauchy<'a>(
env: Env<'a>,
q_init: ResourceArc<VulkanTensor>,
p_init: ResourceArc<VulkanTensor>,
inv_mass: ResourceArc<VulkanTensor>,
k: u32,
eps: f64, loc: f64, scale: f64, log_pi_scale: f64,
spv_path: String,
) -> NifResult<Term<'a>> {
if q_init.n_bytes != p_init.n_bytes || q_init.n_bytes != inv_mass.n_bytes {
return Ok((atoms::error(), atoms::size_mismatch()).encode(env));
}
if k == 0 { return Ok((atoms::error(), atoms::bad_op()).encode(env)); }
let n = (q_init.n_bytes / 4) as u32;
let chain_bytes = q_init.n_bytes * (k as u64);
let logp_bytes = (k as u64) * 4;
let _g = SUBMIT_LOCK.lock().map_err(|_| Error::BadArg)?;
let (qh, ph, gh, lh) = match alloc_4(chain_bytes, chain_bytes, chain_bytes, logp_bytes) {
Ok(t) => t,
Err(_) => return Ok((atoms::error(), atoms::alloc_failed()).encode(env)),
};
let cstr = std::ffi::CString::new(spv_path).map_err(|_| Error::BadArg)?;
let rc = unsafe {
nxv_leapfrog_chain_cauchy(
qh, ph, gh, lh,
q_init.handle, p_init.handle, inv_mass.handle,
n, k,
eps as f32, loc as f32, scale as f32, log_pi_scale as f32,
cstr.as_ptr(),
)
};
encode_chain_result(env, rc, qh, ph, gh, lh, chain_bytes, logp_bytes)
}
#[rustler::nif]
fn leapfrog_chain_halfnormal<'a>(
env: Env<'a>,
q_init: ResourceArc<VulkanTensor>,
p_init: ResourceArc<VulkanTensor>,
inv_mass: ResourceArc<VulkanTensor>,
k: u32,
eps: f64, sigma: f64, log_const: f64,
spv_path: String,
) -> NifResult<Term<'a>> {
if q_init.n_bytes != p_init.n_bytes || q_init.n_bytes != inv_mass.n_bytes {
return Ok((atoms::error(), atoms::size_mismatch()).encode(env));
}
if k == 0 { return Ok((atoms::error(), atoms::bad_op()).encode(env)); }
let n = (q_init.n_bytes / 4) as u32;
let chain_bytes = q_init.n_bytes * (k as u64);
let logp_bytes = (k as u64) * 4;
let _g = SUBMIT_LOCK.lock().map_err(|_| Error::BadArg)?;
let (qh, ph, gh, lh) = match alloc_4(chain_bytes, chain_bytes, chain_bytes, logp_bytes) {
Ok(t) => t,
Err(_) => return Ok((atoms::error(), atoms::alloc_failed()).encode(env)),
};
let cstr = std::ffi::CString::new(spv_path).map_err(|_| Error::BadArg)?;
let rc = unsafe {
nxv_leapfrog_chain_halfnormal(
qh, ph, gh, lh,
q_init.handle, p_init.handle, inv_mass.handle,
n, k,
eps as f32, sigma as f32, log_const as f32,
cstr.as_ptr(),
)
};
encode_chain_result(env, rc, qh, ph, gh, lh, chain_bytes, logp_bytes)
}
#[rustler::nif]
fn leapfrog_chain_weibull<'a>(
env: Env<'a>,
q_init: ResourceArc<VulkanTensor>,
p_init: ResourceArc<VulkanTensor>,
inv_mass: ResourceArc<VulkanTensor>,
k: u32,
eps: f64, weibull_k: f64, lambda: f64, logp_const: f64,
spv_path: String,
) -> NifResult<Term<'a>> {
if q_init.n_bytes != p_init.n_bytes || q_init.n_bytes != inv_mass.n_bytes {
return Ok((atoms::error(), atoms::size_mismatch()).encode(env));
}
if k == 0 { return Ok((atoms::error(), atoms::bad_op()).encode(env)); }
let n = (q_init.n_bytes / 4) as u32;
let chain_bytes = q_init.n_bytes * (k as u64);
let logp_bytes = (k as u64) * 4;
let _g = SUBMIT_LOCK.lock().map_err(|_| Error::BadArg)?;
let (qh, ph, gh, lh) = match alloc_4(chain_bytes, chain_bytes, chain_bytes, logp_bytes) {
Ok(t) => t,
Err(_) => return Ok((atoms::error(), atoms::alloc_failed()).encode(env)),
};
let cstr = std::ffi::CString::new(spv_path).map_err(|_| Error::BadArg)?;
let rc = unsafe {
nxv_leapfrog_chain_weibull(
qh, ph, gh, lh,
q_init.handle, p_init.handle, inv_mass.handle,
n, k,
eps as f32, weibull_k as f32, lambda as f32, logp_const as f32,
cstr.as_ptr(),
)
};
encode_chain_result(env, rc, qh, ph, gh, lh, chain_bytes, logp_bytes)
}
#[rustler::nif]
fn leapfrog_chain_normal_f64<'a>(
env: Env<'a>,
q_init: ResourceArc<VulkanTensor>,
p_init: ResourceArc<VulkanTensor>,
inv_mass: ResourceArc<VulkanTensor>,
k: u32,
eps: f64, mu: f64, sigma: f64,
spv_path: String,
) -> NifResult<Term<'a>> {
// f64: 8 bytes per element. n derived from input byte size.
if q_init.n_bytes != p_init.n_bytes || q_init.n_bytes != inv_mass.n_bytes {
return Ok((atoms::error(), atoms::size_mismatch()).encode(env));
}
if k == 0 { return Ok((atoms::error(), atoms::bad_op()).encode(env)); }
let n = (q_init.n_bytes / 8) as u32;
let chain_bytes = q_init.n_bytes * (k as u64);
let logp_bytes = (k as u64) * 8;
let _g = SUBMIT_LOCK.lock().map_err(|_| Error::BadArg)?;
let (qh, ph, gh, lh) = match alloc_4(chain_bytes, chain_bytes, chain_bytes, logp_bytes) {
Ok(t) => t,
Err(_) => return Ok((atoms::error(), atoms::alloc_failed()).encode(env)),
};
let cstr = std::ffi::CString::new(spv_path).map_err(|_| Error::BadArg)?;
let rc = unsafe {
nxv_leapfrog_chain_normal_f64(
qh, ph, gh, lh,
q_init.handle, p_init.handle, inv_mass.handle,
n, k,
eps, mu, sigma,
cstr.as_ptr(),
)
};
encode_chain_result(env, rc, qh, ph, gh, lh, chain_bytes, logp_bytes)
}
// Helper: grab four output buffers of the requested byte sizes, in order.
// If any allocation comes back null, every buffer obtained so far is
// released and Err(()) is returned, so the caller never owns a partial set.
fn alloc_4(b1: u64, b2: u64, b3: u64, b4: u64)
    -> Result<(*mut c_void, *mut c_void, *mut c_void, *mut c_void), ()>
{
    let requested = [b1, b2, b3, b4];
    let mut acquired: [*mut c_void; 4] = [std::ptr::null_mut(); 4];

    for (slot, &n_bytes) in requested.iter().enumerate() {
        let fresh = unsafe { nxv_buf_alloc(n_bytes) };
        if fresh.is_null() {
            // Roll back the earlier allocations, oldest first.
            for &earlier in &acquired[..slot] {
                unsafe { nxv_buf_free(earlier) };
            }
            return Err(());
        }
        acquired[slot] = fresh;
    }

    let [first, second, third, fourth] = acquired;
    Ok((first, second, third, fourth))
}
// Helper: dispatch result → tuple-encode or free-and-error.
fn encode_chain_result<'a>(
env: Env<'a>,
rc: i32,
qh: *mut c_void, ph: *mut c_void, gh: *mut c_void, lh: *mut c_void,
chain_bytes: u64, logp_bytes: u64,
) -> NifResult<Term<'a>> {
if rc != 0 {
unsafe { nxv_buf_free(qh); nxv_buf_free(ph);
nxv_buf_free(gh); nxv_buf_free(lh); };
return Ok((atoms::error(), atoms::dispatch_failed()).encode(env));
}
let q_c = VulkanTensor { handle: qh, n_bytes: chain_bytes };
let p_c = VulkanTensor { handle: ph, n_bytes: chain_bytes };
let g_c = VulkanTensor { handle: gh, n_bytes: chain_bytes };
let l_c = VulkanTensor { handle: lh, n_bytes: logp_bytes };
Ok((
atoms::ok(),
(ResourceArc::new(q_c), ResourceArc::new(p_c),
ResourceArc::new(g_c), ResourceArc::new(l_c)),
).encode(env))
}
/// 4-input fused chain. ops + buf_idx are length-≤8 vecs;
/// padded to 8 with [255, 1] respectively. All 4 input buffers must
/// be the same byte size (single output of that size).
///
/// `ops` and `buf_idx` must be non-empty and of equal length.
///
/// Returns `{:ok, tensor}`, or `{:error, reason}` where reason is
/// `:bad_op` (bad ops/buf_idx lengths), `:size_mismatch`, `:alloc_failed`
/// or `:dispatch_failed`. Fails with `badarg` if `spv_path` contains a NUL
/// byte or the submit lock is poisoned.
#[rustler::nif]
fn fused_chain_4<'a>(
    env: Env<'a>,
    a: ResourceArc<VulkanTensor>,
    b: ResourceArc<VulkanTensor>,
    c: ResourceArc<VulkanTensor>,
    d: ResourceArc<VulkanTensor>,
    ops: Vec<u32>,
    buf_idx: Vec<u32>,
    spv_path: String,
) -> NifResult<Term<'a>> {
    if ops.is_empty() || ops.len() > 8 || ops.len() != buf_idx.len() {
        return Ok((atoms::error(), atoms::bad_op()).encode(env));
    }
    if a.n_bytes != b.n_bytes || a.n_bytes != c.n_bytes || a.n_bytes != d.n_bytes {
        return Ok((atoms::error(), atoms::size_mismatch()).encode(env));
    }
    // 4 bytes per element; refuse counts that would be truncated by u32.
    let Ok(n) = u32::try_from(a.n_bytes / 4) else {
        return Ok((atoms::error(), atoms::size_mismatch()).encode(env));
    };
    // ops.len() is in 1..=8 here, so the cast cannot truncate.
    let n_ops = ops.len() as u32;

    // Fixed 8-slot arrays for the shim; unused slots keep the pad values.
    let mut padded_ops: [u32; 8] = [255; 8];
    let mut padded_buf: [u32; 8] = [1; 8];
    padded_ops[..ops.len()].copy_from_slice(&ops);
    padded_buf[..buf_idx.len()].copy_from_slice(&buf_idx);

    // Convert the path before allocating so a bad path cannot leak the
    // output buffer.
    let cstr = std::ffi::CString::new(spv_path).map_err(|_| Error::BadArg)?;
    let _g = SUBMIT_LOCK.lock().map_err(|_| Error::BadArg)?;

    let out_handle = unsafe { nxv_buf_alloc(a.n_bytes) };
    if out_handle.is_null() {
        return Ok((atoms::error(), atoms::alloc_failed()).encode(env));
    }
    // SAFETY: out_handle was just checked non-null; the padded arrays and
    // cstr all outlive the call.
    let rc = unsafe {
        nxv_fused_chain_4(
            out_handle, a.handle, b.handle, c.handle, d.handle,
            n, n_ops, padded_ops.as_ptr(), padded_buf.as_ptr(),
            cstr.as_ptr(),
        )
    };
    if rc != 0 {
        unsafe { nxv_buf_free(out_handle) };
        return Ok((atoms::error(), atoms::dispatch_failed()).encode(env));
    }
    let out = VulkanTensor { handle: out_handle, n_bytes: a.n_bytes };
    Ok((atoms::ok(), ResourceArc::new(out)).encode(env))
}
/// Hand every pooled VkBuf back to the device. Meant to be called while
/// idle, when reclaiming the memory is worthwhile. Always replies `:ok`.
#[rustler::nif]
fn pool_clear<'a>(env: Env<'a>) -> NifResult<Term<'a>> {
    unsafe {
        nxv_pool_clear();
    }
    let reply = atoms::ok();
    Ok(reply.encode(env))
}
/// Pool stats: returns {:ok, %{hits, misses, freed, size_classes, total_pooled}}.
#[rustler::nif]
fn pool_stats<'a>(env: Env<'a>) -> NifResult<Term<'a>> {
let mut hits: u64 = 0;
let mut misses: u64 = 0;
let mut freed: u64 = 0;
let mut size_classes: u64 = 0;
let mut total_pooled: u64 = 0;
unsafe {
nxv_pool_stats(&mut hits, &mut misses, &mut freed,
&mut size_classes, &mut total_pooled);
}
let map = rustler::Term::map_from_pairs(
env,
&[
(rustler::types::atom::Atom::from_str(env, "hits").unwrap().encode(env), hits.encode(env)),
(rustler::types::atom::Atom::from_str(env, "misses").unwrap().encode(env), misses.encode(env)),
(rustler::types::atom::Atom::from_str(env, "freed").unwrap().encode(env), freed.encode(env)),
(rustler::types::atom::Atom::from_str(env, "size_classes").unwrap().encode(env), size_classes.encode(env)),
(rustler::types::atom::Atom::from_str(env, "total_pooled").unwrap().encode(env), total_pooled.encode(env)),
],
).map_err(|_| Error::BadArg)?;
Ok((atoms::ok(), map).encode(env))
}
/// NIF load callback: registers `VulkanTensor` as a rustler resource type
/// (the type every NIF here wraps in `ResourceArc`) and reports success.
/// `_info` (the load-info term) is unused.
fn on_load(env: Env, _info: Term) -> bool {
    rustler::resource!(VulkanTensor, env);
    true
}
// rustler 0.36 deprecated the second arg (functions are auto-discovered
// via the #[rustler::nif] attribute). One-arg form is the new shape.
// `load = on_load` wires `on_load` in as the load callback.
rustler::init!("Elixir.Nx.Vulkan.Native", load = on_load);