Skip to main content

native/ex_burn_nif/src/lib.rs

//! # ExBurn NIF
//!
//! Rust NIF bridge between Elixir and the Burn deep learning framework.
//! Supports multiple backends:
//!   - **CUDA** (NVIDIA GPUs) — via `burn/cuda` feature
//!   - **Metal** (Apple GPUs) — via `burn/metal` feature
//!   - **Vulkan** (Android/Linux/Windows) — via `burn/vulkan` feature
//!   - **NdArray** (CPU fallback) — always available
//!
//! The backend is selected at compile time via Cargo features on the `burn` crate.
//! The default feature is `cuda`.  Build with `--features metal` or
//! `--features vulkan` to target other GPUs, or `--no-default-features`
//! for CPU-only (NdArray).

use rustler::ResourceArc;
use std::cell::RefCell;
use std::panic::{RefUnwindSafe, UnwindSafe};
use std::sync::OnceLock;

use burn::prelude::ElementConversion;
use burn::tensor::backend::AutodiffBackend;
use burn::tensor::Tensor;
use burn_autodiff::Autodiff;
#[cfg(not(any(feature = "cuda", feature = "metal", feature = "vulkan")))]
use burn_ndarray::NdArray;

// ── Backend selection ─────────────────────────────────────────────
//
// The active GPU backend is chosen at compile time via Cargo features.
// Each feature gate brings in the corresponding Burn backend and device
// types.  When no GPU feature is enabled we fall back to NdArray (CPU).

/// Concrete compute backend picked by Cargo features:
///
/// - `cuda`   → `burn::backend::Cuda` (NVIDIA GPUs)
/// - `metal`  → `burn::backend::Metal` (Apple GPUs)
/// - `vulkan` → `burn::backend::Vulkan`
/// - none of the above → `NdArray` (CPU fallback)
#[cfg(feature = "cuda")]
type GpuBackend = burn::backend::Cuda;

#[cfg(feature = "metal")]
type GpuBackend = burn::backend::Metal;

#[cfg(feature = "vulkan")]
type GpuBackend = burn::backend::Vulkan;

#[cfg(not(any(feature = "cuda", feature = "metal", feature = "vulkan")))]
type GpuBackend = NdArray;

/// Autodiff-enabled backend that every tensor in this NIF is built on.
///
/// In the CPU-only build `GpuBackend` is itself `NdArray`. This single alias
/// therefore covers both the GPU and the CPU-fallback configurations.
type B = Autodiff<GpuBackend>;

// ── Device ────────────────────────────────────────────────────────
// The `BackendDevice` trait abstracts over the concrete device type
// for each backend.  Each backend cfg provides its own impl so the
// rest of the code can call `device()` uniformly.

/// Supplies the default compute device for a backend type.
///
/// Exactly one impl below is compiled (matching the active backend feature).
/// The free function `device()` can therefore be called identically in every build.
trait BackendDevice {
    /// The backend's device handle type.
    type Device;
    /// Returns the backend's default device.
    fn device() -> Self::Device;
}

#[cfg(feature = "cuda")]
impl BackendDevice for GpuBackend {
    type Device = burn::backend::cuda::CudaDevice;

    fn device() -> Self::Device {
        Default::default()
    }
}

#[cfg(feature = "metal")]
impl BackendDevice for GpuBackend {
    type Device = burn::backend::metal::MetalDevice;

    fn device() -> Self::Device {
        Default::default()
    }
}

#[cfg(feature = "vulkan")]
impl BackendDevice for GpuBackend {
    type Device = burn::backend::vulkan::VulkanDevice;

    fn device() -> Self::Device {
        Default::default()
    }
}

#[cfg(not(any(feature = "cuda", feature = "metal", feature = "vulkan")))]
impl BackendDevice for NdArray {
    type Device = burn_ndarray::NdArrayDevice;

    fn device() -> Self::Device {
        Default::default()
    }
}

/// Default device of the compiled backend; every tensor in this NIF is created on it.
fn device() -> <GpuBackend as BackendDevice>::Device {
    <GpuBackend as BackendDevice>::device()
}

// ── GPU availability cache ────────────────────────────────────────

static GPU_AVAILABLE: OnceLock<bool> = OnceLock::new();

/// Returns whether the compiled GPU backend is usable, probing it at most once.
fn gpu_available_cached() -> bool {
    *GPU_AVAILABLE.get_or_init(probe_gpu_available)
}

/// Probes the compiled backend by allocating a one-element tensor on the
/// default device and reading it back.
///
/// NOTE(review): device-initialisation failures (no driver, no device) are
/// assumed to surface as panics inside Burn. The probe therefore runs under
/// `catch_unwind` and reports `false` instead of propagating the panic to the
/// NIF caller. The panic hook may still print the message to stderr.
/// CPU-only builds always report `false`.
fn probe_gpu_available() -> bool {
    #[cfg(any(feature = "cuda", feature = "metal", feature = "vulkan"))]
    {
        std::panic::catch_unwind(|| {
            let t = Tensor::<B, 1>::from_floats([0.0f32], &device());
            // Reading the data back forces the backend to actually execute.
            let _ = t.to_data();
        })
        .is_ok()
    }
    #[cfg(not(any(feature = "cuda", feature = "metal", feature = "vulkan")))]
    {
        false
    }
}

// ── Tensor enum ───────────────────────────────────────────────────

/// A Burn f32 tensor on the autodiff backend `B`, tagged by rank (1–4).
///
/// The rank is encoded in the variant because Burn's `Tensor` carries its rank
/// as a const generic parameter.
#[derive(Clone)]
pub enum BurnTensor {
    /// Rank-1 f32 tensor.
    F32x1(Tensor<B, 1>),
    /// Rank-2 f32 tensor.
    F32x2(Tensor<B, 2>),
    /// Rank-3 f32 tensor.
    F32x3(Tensor<B, 3>),
    /// Rank-4 f32 tensor.
    F32x4(Tensor<B, 4>),
}

/// Payload of the BEAM resource handed to Elixir for each tensor.
pub struct TensorResource {
    /// The tensor itself.
    pub tensor: BurnTensor,
    /// Shape recorded when the resource was built (returned as-is by `nif_tensor_shape`).
    pub shape: Vec<usize>,
    /// Element type name; only `"f32"` is produced by the constructors in this file.
    pub dtype: String,
}

// Registers `TensorResource` as a rustler resource type.
#[rustler::resource_impl]
impl rustler::Resource for TensorResource {}
// NOTE(review): presumably marked unwind-safe so resources can cross the
// panic-catching boundary rustler places around each NIF — confirm.
impl RefUnwindSafe for TensorResource {}
impl UnwindSafe for TensorResource {}

// ── Helpers ───────────────────────────────────────────────────────

fn tensor_to_bytes(t: &BurnTensor) -> (Vec<usize>, String, Vec<u8>) {
    match t {
        BurnTensor::F32x1(t) => {
            let dims: Vec<usize> = t.shape().dims::<1>().to_vec();
            let vals: Vec<f32> = t.to_data().into_vec().unwrap_or_default();
            let bytes: Vec<u8> = vals.iter().flat_map(|v| v.to_le_bytes()).collect();
            (dims, "f32".into(), bytes)
        }
        BurnTensor::F32x2(t) => {
            let dims: Vec<usize> = t.shape().dims::<2>().to_vec();
            let vals: Vec<f32> = t.to_data().into_vec().unwrap_or_default();
            let bytes: Vec<u8> = vals.iter().flat_map(|v| v.to_le_bytes()).collect();
            (dims, "f32".into(), bytes)
        }
        BurnTensor::F32x3(t) => {
            let dims: Vec<usize> = t.shape().dims::<3>().to_vec();
            let vals: Vec<f32> = t.to_data().into_vec().unwrap_or_default();
            let bytes: Vec<u8> = vals.iter().flat_map(|v| v.to_le_bytes()).collect();
            (dims, "f32".into(), bytes)
        }
        BurnTensor::F32x4(t) => {
            let dims: Vec<usize> = t.shape().dims::<4>().to_vec();
            let vals: Vec<f32> = t.to_data().into_vec().unwrap_or_default();
            let bytes: Vec<u8> = vals.iter().flat_map(|v| v.to_le_bytes()).collect();
            (dims, "f32".into(), bytes)
        }
    }
}

/// Builds an f32 tensor of the given shape (rank 1–4) on the default device.
///
/// `vals` is the flat element list (assumed row-major). `from_floats` on a
/// flat slice produces 1D data, and Burn rejects data whose rank differs from
/// the tensor's. The tensor is therefore created flat and then reshaped to
/// `shape`, instead of passing the flat slice to a rank-2+ constructor.
///
/// # Panics
/// If `shape` has rank 0 or above 4, or if `vals.len()` differs from the
/// product of `shape`.
fn make_f32_tensor(vals: &[f32], shape: &[usize]) -> BurnTensor {
    if shape.is_empty() || shape.len() > 4 {
        panic!("Unsupported rank {}", shape.len());
    }
    let numel: usize = shape.iter().product();
    assert_eq!(
        vals.len(),
        numel,
        "make_f32_tensor: {} values cannot fill shape {:?}",
        vals.len(),
        shape
    );

    let flat = Tensor::<B, 1>::from_floats(vals, &device());
    match shape {
        [d0] => BurnTensor::F32x1(flat.reshape([*d0])),
        [d0, d1] => BurnTensor::F32x2(flat.reshape([*d0, *d1])),
        [d0, d1, d2] => BurnTensor::F32x3(flat.reshape([*d0, *d1, *d2])),
        [d0, d1, d2, d3] => BurnTensor::F32x4(flat.reshape([*d0, *d1, *d2, *d3])),
        _ => unreachable!("rank validated above"),
    }
}

/// Decodes a raw little-endian byte buffer into a tensor of `shape`.
///
/// Only `dtype == "f32"` is supported.
///
/// # Panics
/// On an unsupported dtype, or when the buffer length is not a multiple of 4
/// bytes. Such trailing bytes used to be dropped silently by `chunks_exact`.
fn make_tensor_from_bytes(data: Vec<u8>, shape: Vec<usize>, dtype: String) -> BurnTensor {
    match dtype.as_str() {
        "f32" => {
            let chunks = data.chunks_exact(4);
            if !chunks.remainder().is_empty() {
                panic!(
                    "f32 tensor data must be a multiple of 4 bytes, got {}",
                    data.len()
                );
            }
            let vals: Vec<f32> = chunks
                .map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]]))
                .collect();
            make_f32_tensor(&vals, &shape)
        }
        other => panic!("Unsupported dtype: {}", other),
    }
}

/// Wraps a tensor and its metadata in a rustler `ResourceArc` for Elixir.
fn build_resource(t: BurnTensor, shape: Vec<usize>, dtype: String) -> ResourceArc<TensorResource> {
    let resource = TensorResource {
        shape,
        dtype,
        tensor: t,
    };
    ResourceArc::new(resource)
}

// ═══════════════════════════════════════════════════════════════════
// NIF Functions
// ═══════════════════════════════════════════════════════════════════

// ── Tensor Creation ───────────────────────────────────────────────

/// Creates a tensor from a raw little-endian byte buffer, a shape and a dtype name.
#[inline(never)]
#[rustler::nif]
fn nif_new_tensor(data: Vec<u8>, shape: Vec<usize>, dtype: String) -> ResourceArc<TensorResource> {
    let tensor = make_tensor_from_bytes(data, shape.to_vec(), dtype.to_string());
    build_resource(tensor, shape, dtype)
}

/// Raises every element of `a` to the power `exp`.
///
/// Uses `powf_scalar` rather than building an exponent tensor. The old 2D path
/// passed 1D data (`[exp]`) to a rank-2 constructor, which Burn rejects.
/// Every rank `BurnTensor` can hold is supported. The output has the input's
/// shape and dtype, so the input metadata is reused instead of reading the
/// whole result back to the host.
#[inline(never)]
#[rustler::nif]
fn nif_pow_tensor(a: ResourceArc<TensorResource>, exp: f32) -> ResourceArc<TensorResource> {
    let result = match &a.tensor {
        BurnTensor::F32x1(t) => BurnTensor::F32x1(t.clone().powf_scalar(exp)),
        BurnTensor::F32x2(t) => BurnTensor::F32x2(t.clone().powf_scalar(exp)),
        BurnTensor::F32x3(t) => BurnTensor::F32x3(t.clone().powf_scalar(exp)),
        BurnTensor::F32x4(t) => BurnTensor::F32x4(t.clone().powf_scalar(exp)),
    };
    build_resource(result, a.shape.clone(), a.dtype.clone())
}

/// Allocates an "empty" tensor of `shape`. Its contents are zero-filled.
///
/// The byte buffer is sized for 4-byte (f32) elements.
#[inline(never)]
#[rustler::nif]
fn nif_empty_tensor(shape: Vec<usize>, dtype: String) -> ResourceArc<TensorResource> {
    let bytes = vec![0u8; 4 * shape.iter().product::<usize>()];
    let tensor = make_tensor_from_bytes(bytes, shape.clone(), dtype.clone());
    build_resource(tensor, shape, dtype)
}

/// Creates a tensor of `shape` with every element set to zero.
///
/// An all-zero byte buffer is also all-zero f32 values, so the buffer is
/// built directly as 4 zero bytes per element.
#[inline(never)]
#[rustler::nif]
fn nif_zeros_tensor(shape: Vec<usize>, dtype: String) -> ResourceArc<TensorResource> {
    let element_count = shape.iter().product::<usize>();
    let zeroed = vec![0u8; 4 * element_count];
    let tensor = make_tensor_from_bytes(zeroed, shape.clone(), dtype.clone());
    build_resource(tensor, shape, dtype)
}

/// Creates a tensor of `shape` with every element set to `1.0`.
#[inline(never)]
#[rustler::nif]
fn nif_ones_tensor(shape: Vec<usize>, dtype: String) -> ResourceArc<TensorResource> {
    let element_count = shape.iter().product::<usize>();
    // Repeat the little-endian encoding of 1.0f32 once per element.
    let bytes: Vec<u8> = 1.0f32.to_le_bytes().repeat(element_count);
    let tensor = make_tensor_from_bytes(bytes, shape.clone(), dtype.clone());
    build_resource(tensor, shape, dtype)
}

/// Creates a `size × size` identity matrix.
///
/// The requested type is ignored; the result is always recorded as `"f32"`.
#[inline(never)]
#[rustler::nif]
fn nif_eye_tensor(size: usize, _type: String) -> ResourceArc<TensorResource> {
    let identity = Tensor::<B, 2>::eye(size, &device());
    build_resource(BurnTensor::F32x2(identity), vec![size; 2], String::from("f32"))
}

/// Returns the 1D tensor `[0.0, 1.0, …, n-1]`, where `n = shape[axis]`.
///
/// An out-of-range `axis` yields `n = 1`. The requested type is ignored.
///
/// NOTE(review): the result is always 1D with `shape[axis]` elements. The
/// other dimensions of `shape` are ignored rather than producing a
/// full-`shape` tensor counting along `axis`. Confirm this matches what the
/// Elixir side expects.
#[inline(never)]
#[rustler::nif]
fn nif_iota_tensor(shape: Vec<usize>, axis: usize, _type: String) -> ResourceArc<TensorResource> {
    let dev = device();
    let len = match shape.get(axis) {
        Some(&n) => n,
        None => 1,
    };
    let counts: Vec<f32> = (0..len).map(|i| i as f32).collect();
    let tensor = Tensor::<B, 1>::from_floats(counts.as_slice(), &dev);
    build_resource(BurnTensor::F32x1(tensor), vec![len], "f32".into())
}

// ── Tensor Inspection ─────────────────────────────────────────────

/// Returns the shape recorded on the resource (no device access).
#[inline(never)]
#[rustler::nif]
fn nif_tensor_shape(tensor: ResourceArc<TensorResource>) -> Vec<usize> {
    tensor.shape.to_vec()
}

/// Returns the dtype name recorded on the resource.
#[inline(never)]
#[rustler::nif]
fn nif_tensor_dtype(tensor: ResourceArc<TensorResource>) -> String {
    String::from(tensor.dtype.as_str())
}

/// Copies the tensor to the host and returns its elements as little-endian bytes.
#[inline(never)]
#[rustler::nif]
fn nif_tensor_to_binary(tensor: ResourceArc<TensorResource>) -> Vec<u8> {
    tensor_to_bytes(&tensor.tensor).2
}

/// Number of elements, computed from the recorded shape (product of its dims).
#[inline(never)]
#[rustler::nif]
fn nif_tensor_numel(tensor: ResourceArc<TensorResource>) -> usize {
    tensor.shape.iter().copied().product()
}

// ── Element-wise Arithmetic ───────────────────────────────────────

#[inline(never)]
#[rustler::nif]
fn nif_add_tensor(
    a: ResourceArc<TensorResource>,
    b: ResourceArc<TensorResource>,
) -> ResourceArc<TensorResource> {
    let result = match (&a.tensor, &b.tensor) {
        (BurnTensor::F32x1(t1), BurnTensor::F32x1(t2)) => {
            BurnTensor::F32x1(t1.clone() + t2.clone())
        }
        (BurnTensor::F32x2(t1), BurnTensor::F32x2(t2)) => {
            BurnTensor::F32x2(t1.clone() + t2.clone())
        }
        _ => panic!("Shape/dtype mismatch in add"),
    };
    let (shape, dtype, _) = tensor_to_bytes(&result);
    build_resource(result, shape, dtype)
}

#[inline(never)]
#[rustler::nif]
fn nif_sub_tensor(
    a: ResourceArc<TensorResource>,
    b: ResourceArc<TensorResource>,
) -> ResourceArc<TensorResource> {
    let result = match (&a.tensor, &b.tensor) {
        (BurnTensor::F32x1(t1), BurnTensor::F32x1(t2)) => {
            BurnTensor::F32x1(t1.clone() - t2.clone())
        }
        (BurnTensor::F32x2(t1), BurnTensor::F32x2(t2)) => {
            BurnTensor::F32x2(t1.clone() - t2.clone())
        }
        _ => panic!("Shape/dtype mismatch in sub"),
    };
    let (shape, dtype, _) = tensor_to_bytes(&result);
    build_resource(result, shape, dtype)
}

#[inline(never)]
#[rustler::nif]
fn nif_mul_tensor(
    a: ResourceArc<TensorResource>,
    b: ResourceArc<TensorResource>,
) -> ResourceArc<TensorResource> {
    let result = match (&a.tensor, &b.tensor) {
        (BurnTensor::F32x1(t1), BurnTensor::F32x1(t2)) => {
            BurnTensor::F32x1(t1.clone() * t2.clone())
        }
        (BurnTensor::F32x2(t1), BurnTensor::F32x2(t2)) => {
            BurnTensor::F32x2(t1.clone() * t2.clone())
        }
        _ => panic!("Shape/dtype mismatch in mul"),
    };
    let (shape, dtype, _) = tensor_to_bytes(&result);
    build_resource(result, shape, dtype)
}

#[inline(never)]
#[rustler::nif]
fn nif_div_tensor(
    a: ResourceArc<TensorResource>,
    b: ResourceArc<TensorResource>,
) -> ResourceArc<TensorResource> {
    let result = match (&a.tensor, &b.tensor) {
        (BurnTensor::F32x1(t1), BurnTensor::F32x1(t2)) => {
            BurnTensor::F32x1(t1.clone() / t2.clone())
        }
        (BurnTensor::F32x2(t1), BurnTensor::F32x2(t2)) => {
            BurnTensor::F32x2(t1.clone() / t2.clone())
        }
        _ => panic!("Shape/dtype mismatch in div"),
    };
    let (shape, dtype, _) = tensor_to_bytes(&result);
    build_resource(result, shape, dtype)
}

/// Element-wise negation `-a`, for every rank `BurnTensor` can hold (1D–4D).
///
/// The output has the input's shape and dtype. The input metadata is reused
/// rather than reading the whole result back to the host just to learn its
/// shape.
#[inline(never)]
#[rustler::nif]
fn nif_neg_tensor(a: ResourceArc<TensorResource>) -> ResourceArc<TensorResource> {
    let result = match &a.tensor {
        BurnTensor::F32x1(t) => BurnTensor::F32x1(-t.clone()),
        BurnTensor::F32x2(t) => BurnTensor::F32x2(-t.clone()),
        BurnTensor::F32x3(t) => BurnTensor::F32x3(-t.clone()),
        BurnTensor::F32x4(t) => BurnTensor::F32x4(-t.clone()),
    };
    build_resource(result, a.shape.clone(), a.dtype.clone())
}

/// Element-wise absolute value `|a|`, for every rank `BurnTensor` can hold (1D–4D).
///
/// The output has the input's shape and dtype. The input metadata is reused
/// rather than reading the whole result back to the host just to learn its
/// shape.
#[inline(never)]
#[rustler::nif]
fn nif_abs_tensor(a: ResourceArc<TensorResource>) -> ResourceArc<TensorResource> {
    let result = match &a.tensor {
        BurnTensor::F32x1(t) => BurnTensor::F32x1(t.clone().abs()),
        BurnTensor::F32x2(t) => BurnTensor::F32x2(t.clone().abs()),
        BurnTensor::F32x3(t) => BurnTensor::F32x3(t.clone().abs()),
        BurnTensor::F32x4(t) => BurnTensor::F32x4(t.clone().abs()),
    };
    build_resource(result, a.shape.clone(), a.dtype.clone())
}

/// Element-wise natural exponential `e^a`, for every rank `BurnTensor` can hold (1D–4D).
///
/// Previously only 1D was accepted, unlike `nif_neg_tensor`/`nif_abs_tensor`.
/// The output has the input's shape and dtype. The input metadata is reused
/// instead of reading the result back to the host.
#[inline(never)]
#[rustler::nif]
fn nif_exp_tensor(a: ResourceArc<TensorResource>) -> ResourceArc<TensorResource> {
    let result = match &a.tensor {
        BurnTensor::F32x1(t) => BurnTensor::F32x1(t.clone().exp()),
        BurnTensor::F32x2(t) => BurnTensor::F32x2(t.clone().exp()),
        BurnTensor::F32x3(t) => BurnTensor::F32x3(t.clone().exp()),
        BurnTensor::F32x4(t) => BurnTensor::F32x4(t.clone().exp()),
    };
    build_resource(result, a.shape.clone(), a.dtype.clone())
}

/// Element-wise natural logarithm `ln(a)`, for every rank `BurnTensor` can hold (1D–4D).
///
/// Previously only 1D was accepted. The output has the input's shape and
/// dtype. The input metadata is reused instead of reading the result back to
/// the host.
#[inline(never)]
#[rustler::nif]
fn nif_log_tensor(a: ResourceArc<TensorResource>) -> ResourceArc<TensorResource> {
    let result = match &a.tensor {
        BurnTensor::F32x1(t) => BurnTensor::F32x1(t.clone().log()),
        BurnTensor::F32x2(t) => BurnTensor::F32x2(t.clone().log()),
        BurnTensor::F32x3(t) => BurnTensor::F32x3(t.clone().log()),
        BurnTensor::F32x4(t) => BurnTensor::F32x4(t.clone().log()),
    };
    build_resource(result, a.shape.clone(), a.dtype.clone())
}

/// Element-wise square root, for every rank `BurnTensor` can hold (1D–4D).
///
/// Previously only 1D was accepted. The output has the input's shape and
/// dtype. The input metadata is reused instead of reading the result back to
/// the host.
#[inline(never)]
#[rustler::nif]
fn nif_sqrt_tensor(a: ResourceArc<TensorResource>) -> ResourceArc<TensorResource> {
    let result = match &a.tensor {
        BurnTensor::F32x1(t) => BurnTensor::F32x1(t.clone().sqrt()),
        BurnTensor::F32x2(t) => BurnTensor::F32x2(t.clone().sqrt()),
        BurnTensor::F32x3(t) => BurnTensor::F32x3(t.clone().sqrt()),
        BurnTensor::F32x4(t) => BurnTensor::F32x4(t.clone().sqrt()),
    };
    build_resource(result, a.shape.clone(), a.dtype.clone())
}

/// Element-wise logistic sigmoid `1 / (1 + e^-a)`, for every rank `BurnTensor`
/// can hold (1D–4D; previously only 1D).
///
/// The output has the input's shape and dtype. The input metadata is reused
/// instead of reading the result back to the host.
///
/// NOTE(review): unlike its siblings this NIF has no `nif_` prefix; the name is
/// kept because the Elixir stub must match it — confirm that is intentional.
#[inline(never)]
#[rustler::nif]
fn sigmoid_tensor(a: ResourceArc<TensorResource>) -> ResourceArc<TensorResource> {
    let result = match &a.tensor {
        BurnTensor::F32x1(t) => {
            let one = Tensor::<B, 1>::ones(t.shape(), &device());
            BurnTensor::F32x1(one.clone() / (one + (-t.clone()).exp()))
        }
        BurnTensor::F32x2(t) => {
            let one = Tensor::<B, 2>::ones(t.shape(), &device());
            BurnTensor::F32x2(one.clone() / (one + (-t.clone()).exp()))
        }
        BurnTensor::F32x3(t) => {
            let one = Tensor::<B, 3>::ones(t.shape(), &device());
            BurnTensor::F32x3(one.clone() / (one + (-t.clone()).exp()))
        }
        BurnTensor::F32x4(t) => {
            let one = Tensor::<B, 4>::ones(t.shape(), &device());
            BurnTensor::F32x4(one.clone() / (one + (-t.clone()).exp()))
        }
    };
    build_resource(result, a.shape.clone(), a.dtype.clone())
}

/// Element-wise hyperbolic tangent, for every rank `BurnTensor` can hold (1D–4D).
///
/// Previously only 1D was accepted. The output has the input's shape and
/// dtype. The input metadata is reused instead of reading the result back to
/// the host.
#[inline(never)]
#[rustler::nif]
fn nif_tanh_tensor(a: ResourceArc<TensorResource>) -> ResourceArc<TensorResource> {
    let result = match &a.tensor {
        BurnTensor::F32x1(t) => BurnTensor::F32x1(t.clone().tanh()),
        BurnTensor::F32x2(t) => BurnTensor::F32x2(t.clone().tanh()),
        BurnTensor::F32x3(t) => BurnTensor::F32x3(t.clone().tanh()),
        BurnTensor::F32x4(t) => BurnTensor::F32x4(t.clone().tanh()),
    };
    build_resource(result, a.shape.clone(), a.dtype.clone())
}

/// Element-wise ReLU `max(a, 0)`, for every rank `BurnTensor` can hold (1D–4D;
/// previously only 1D).
///
/// The output has the input's shape and dtype. The input metadata is reused
/// instead of reading the result back to the host.
#[inline(never)]
#[rustler::nif]
fn nif_relu_tensor(a: ResourceArc<TensorResource>) -> ResourceArc<TensorResource> {
    let result = match &a.tensor {
        BurnTensor::F32x1(t) => {
            let z = Tensor::<B, 1>::zeros(t.shape(), &device());
            BurnTensor::F32x1(t.clone().max_pair(z))
        }
        BurnTensor::F32x2(t) => {
            let z = Tensor::<B, 2>::zeros(t.shape(), &device());
            BurnTensor::F32x2(t.clone().max_pair(z))
        }
        BurnTensor::F32x3(t) => {
            let z = Tensor::<B, 3>::zeros(t.shape(), &device());
            BurnTensor::F32x3(t.clone().max_pair(z))
        }
        BurnTensor::F32x4(t) => {
            let z = Tensor::<B, 4>::zeros(t.shape(), &device());
            BurnTensor::F32x4(t.clone().max_pair(z))
        }
    };
    build_resource(result, a.shape.clone(), a.dtype.clone())
}

// ── Reductions ────────────────────────────────────────────────────

/// Sum of all elements of a 1D or 2D tensor, returned as a 1D tensor.
///
/// # Panics
/// For 3D/4D inputs.
#[inline(never)]
#[rustler::nif]
fn nif_sum_tensor(a: ResourceArc<TensorResource>) -> ResourceArc<TensorResource> {
    let total = match &a.tensor {
        BurnTensor::F32x1(vector) => vector.clone().sum(),
        BurnTensor::F32x2(matrix) => matrix.clone().sum(),
        _ => panic!("Unsupported tensor type for sum"),
    };
    let total = BurnTensor::F32x1(total);
    let (shape, dtype, _) = tensor_to_bytes(&total);
    build_resource(total, shape, dtype)
}

/// Mean of all elements of a 1D or 2D tensor, returned as a 1D tensor.
///
/// 2D input is accepted for consistency with `nif_sum_tensor`.
///
/// # Panics
/// For 3D/4D inputs.
#[inline(never)]
#[rustler::nif]
fn nif_mean_tensor(a: ResourceArc<TensorResource>) -> ResourceArc<TensorResource> {
    let result = match &a.tensor {
        BurnTensor::F32x1(t) => BurnTensor::F32x1(t.clone().mean()),
        BurnTensor::F32x2(t) => BurnTensor::F32x1(t.clone().mean()),
        _ => panic!("Unsupported tensor type for mean"),
    };
    let (shape, dtype, _) = tensor_to_bytes(&result);
    build_resource(result, shape, dtype)
}

/// Maximum over all elements of a 1D or 2D tensor, returned as a 1D tensor.
///
/// 2D input is accepted for consistency with `nif_sum_tensor`.
///
/// # Panics
/// For 3D/4D inputs.
#[inline(never)]
#[rustler::nif]
fn nif_max_tensor(a: ResourceArc<TensorResource>) -> ResourceArc<TensorResource> {
    let result = match &a.tensor {
        BurnTensor::F32x1(t) => BurnTensor::F32x1(t.clone().max()),
        BurnTensor::F32x2(t) => BurnTensor::F32x1(t.clone().max()),
        _ => panic!("Unsupported tensor type for max"),
    };
    let (shape, dtype, _) = tensor_to_bytes(&result);
    build_resource(result, shape, dtype)
}

/// Minimum over all elements of a 1D or 2D tensor, returned as a 1D tensor.
///
/// 2D input is accepted for consistency with `nif_sum_tensor`.
///
/// # Panics
/// For 3D/4D inputs.
#[inline(never)]
#[rustler::nif]
fn nif_min_tensor(a: ResourceArc<TensorResource>) -> ResourceArc<TensorResource> {
    let result = match &a.tensor {
        BurnTensor::F32x1(t) => BurnTensor::F32x1(t.clone().min()),
        BurnTensor::F32x2(t) => BurnTensor::F32x1(t.clone().min()),
        _ => panic!("Unsupported tensor type for min"),
    };
    let (shape, dtype, _) = tensor_to_bytes(&result);
    build_resource(result, shape, dtype)
}

// ── Linear Algebra ────────────────────────────────────────────────

/// Matrix product `a × b` of two 2D tensors.
///
/// # Panics
/// If either operand is not 2D.
#[inline(never)]
#[rustler::nif]
fn nif_matmul_tensor(
    a: ResourceArc<TensorResource>,
    b: ResourceArc<TensorResource>,
) -> ResourceArc<TensorResource> {
    let product = if let (BurnTensor::F32x2(lhs), BurnTensor::F32x2(rhs)) = (&a.tensor, &b.tensor)
    {
        BurnTensor::F32x2(lhs.clone().matmul(rhs.clone()))
    } else {
        panic!("Matmul requires 2D tensors")
    };
    let (shape, dtype, _) = tensor_to_bytes(&product);
    build_resource(product, shape, dtype)
}

/// Transposes a 2D tensor.
///
/// # Panics
/// If the input is not 2D.
#[inline(never)]
#[rustler::nif]
fn nif_transpose_tensor(a: ResourceArc<TensorResource>) -> ResourceArc<TensorResource> {
    let transposed = if let BurnTensor::F32x2(matrix) = &a.tensor {
        BurnTensor::F32x2(matrix.clone().transpose())
    } else {
        panic!("Transpose requires 2D tensor")
    };
    let (shape, dtype, _) = tensor_to_bytes(&transposed);
    build_resource(transposed, shape, dtype)
}

/// Inner product of two equal-length 1D tensors, returned as a 1D tensor.
///
/// Computed as `sum(a * b)` from element-wise multiply and sum, both used
/// elsewhere in this file. The lengths are checked explicitly. Otherwise a
/// length-1 operand could be broadcast by the multiply and silently yield
/// `s * sum(x)` instead of an error.
///
/// # Panics
/// If either operand is not 1D or their lengths differ.
#[inline(never)]
#[rustler::nif]
fn nif_dot_tensor(
    a: ResourceArc<TensorResource>,
    b: ResourceArc<TensorResource>,
) -> ResourceArc<TensorResource> {
    let result = match (&a.tensor, &b.tensor) {
        (BurnTensor::F32x1(t1), BurnTensor::F32x1(t2)) => {
            let n1 = t1.shape().dims::<1>()[0];
            let n2 = t2.shape().dims::<1>()[0];
            assert_eq!(
                n1, n2,
                "dot requires vectors of equal length, got {} and {}",
                n1, n2
            );
            BurnTensor::F32x1((t1.clone() * t2.clone()).sum())
        }
        _ => panic!("dot requires 1D f32 tensors"),
    };
    let (shape, dtype, _) = tensor_to_bytes(&result);
    build_resource(result, shape, dtype)
}

// ── Shape Manipulation ────────────────────────────────────────────

/// Reshapes `a` to `new_shape`, which must hold the same number of elements.
///
/// Supported (source rank → target rank): 1D → 1D–4D, and 2D → 1D/2D.
///
/// # Panics
/// On an element-count mismatch or an unsupported rank combination.
#[inline(never)]
#[rustler::nif]
fn nif_reshape_tensor(
    a: ResourceArc<TensorResource>,
    new_shape: Vec<usize>,
) -> ResourceArc<TensorResource> {
    let old_numel = a.shape.iter().product::<usize>();
    let new_numel = new_shape.iter().product::<usize>();
    if old_numel != new_numel {
        panic!("Cannot reshape {} elements into {:?}", old_numel, new_shape);
    }

    let reshaped = match (&a.tensor, &new_shape[..]) {
        (BurnTensor::F32x1(src), &[n0]) => BurnTensor::F32x1(src.clone().reshape([n0])),
        (BurnTensor::F32x1(src), &[n0, n1]) => BurnTensor::F32x2(src.clone().reshape([n0, n1])),
        (BurnTensor::F32x1(src), &[n0, n1, n2]) => {
            BurnTensor::F32x3(src.clone().reshape([n0, n1, n2]))
        }
        (BurnTensor::F32x1(src), &[n0, n1, n2, n3]) => {
            BurnTensor::F32x4(src.clone().reshape([n0, n1, n2, n3]))
        }
        (BurnTensor::F32x2(src), &[n0]) => BurnTensor::F32x1(src.clone().reshape([n0])),
        (BurnTensor::F32x2(src), &[n0, n1]) => BurnTensor::F32x2(src.clone().reshape([n0, n1])),
        _ => panic!("reshape: unsupported"),
    };
    let (shape, dtype, _) = tensor_to_bytes(&reshaped);
    build_resource(reshaped, shape, dtype)
}

/// "Broadcasts" `a` to `target_shape`. Supported (source rank → target rank):
/// 1D → 1D, 1D → 2D, and 2D → 2D.
///
/// NOTE(review): every arm is implemented with `reshape`. It therefore only
/// succeeds when `target_shape` holds the same number of elements as `a`, and
/// never replicates data along new or size-1 axes as a true broadcast would.
/// Confirm this is what callers expect.
#[inline(never)]
#[rustler::nif]
fn nif_broadcast_tensor(
    a: ResourceArc<TensorResource>,
    target_shape: Vec<usize>,
) -> ResourceArc<TensorResource> {
    let expanded = match (&a.tensor, &target_shape[..]) {
        (BurnTensor::F32x1(src), &[len]) => BurnTensor::F32x1(src.clone().reshape([len])),
        (BurnTensor::F32x1(src), &[rows, cols]) => {
            BurnTensor::F32x2(src.clone().reshape([rows, cols]))
        }
        (BurnTensor::F32x2(src), &[rows, cols]) => {
            BurnTensor::F32x2(src.clone().reshape([rows, cols]))
        }
        _ => panic!("broadcast: unsupported"),
    };
    let (shape, dtype, _) = tensor_to_bytes(&expanded);
    build_resource(expanded, shape, dtype)
}

/// Concatenates `a` and `b` along axis 0. Both must be 1D, or both must be 2D.
///
/// Uses `Tensor::cat`, so the result stays on the device and connected to both
/// inputs in the autodiff graph. The previous implementation copied both
/// operands to the host and rebuilt a fresh tensor, which detached it from the
/// graph. It also silently skipped any operand that was not 1D.
///
/// # Panics
/// If the operands are not both 1D or both 2D. For 2D inputs, Burn also
/// rejects a column-count mismatch.
#[inline(never)]
#[rustler::nif]
fn nif_concat_tensor(
    a: ResourceArc<TensorResource>,
    b: ResourceArc<TensorResource>,
) -> ResourceArc<TensorResource> {
    let result = match (&a.tensor, &b.tensor) {
        (BurnTensor::F32x1(t1), BurnTensor::F32x1(t2)) => {
            BurnTensor::F32x1(Tensor::<B, 1>::cat(vec![t1.clone(), t2.clone()], 0))
        }
        (BurnTensor::F32x2(t1), BurnTensor::F32x2(t2)) => {
            BurnTensor::F32x2(Tensor::<B, 2>::cat(vec![t1.clone(), t2.clone()], 0))
        }
        _ => panic!("concat requires two 1D or two 2D f32 tensors"),
    };
    let (shape, dtype, _) = tensor_to_bytes(&result);
    build_resource(result, shape, dtype)
}

// ── Device Management ─────────────────────────────────────────────

/// Returns whether the compiled GPU backend passed its probe.
///
/// The probe runs on the first call that needs it (this NIF or
/// `nif_device_name`). Once it completes, the result is cached in a process-wide
/// `OnceLock`. CPU-only builds always return `false`.
#[inline(never)]
#[rustler::nif]
fn nif_gpu_available() -> bool {
    gpu_available_cached()
}

/// Human-readable name of the compute device in use.
///
/// Returns the compiled backend's name when the GPU probe succeeds, and
/// `"NdArray (CPU)"` otherwise.
///
/// NOTE(review): when the probe fails in a GPU build, `device()` still targets
/// the GPU backend. The `"NdArray (CPU)"` label is then not what tensors
/// actually run on — confirm whether a real CPU fallback is intended.
#[inline(never)]
#[rustler::nif]
fn nif_device_name() -> String {
    #[cfg(feature = "cuda")]
    let compiled_name = "CUDA (NVIDIA GPU)";
    #[cfg(feature = "metal")]
    let compiled_name = "Metal (Apple GPU)";
    #[cfg(feature = "vulkan")]
    let compiled_name = "Vulkan (GPU)";
    #[cfg(not(any(feature = "cuda", feature = "metal", feature = "vulkan")))]
    let compiled_name = "NdArray (CPU)";

    let name = if gpu_available_cached() {
        compiled_name
    } else {
        "NdArray (CPU)"
    };
    name.to_string()
}

/// Round-trips the tensor through host memory and rebuilds it on the compiled
/// backend's default device.
///
/// All tensors are already created on that device, so the effect is a forced
/// device→host read followed by a fresh upload under the same shape and dtype.
/// NOTE(review): the rebuilt tensor is created from raw values and is
/// presumably no longer connected to the input's autodiff graph.
#[inline(never)]
#[rustler::nif]
fn nif_to_gpu(tensor: ResourceArc<TensorResource>) -> ResourceArc<TensorResource> {
    let (_, _, host_bytes) = tensor_to_bytes(&tensor.tensor);
    let shape = tensor.shape.clone();
    let dtype = tensor.dtype.clone();
    let uploaded = make_tensor_from_bytes(host_bytes, shape.clone(), dtype.clone());
    build_resource(uploaded, shape, dtype)
}

/// Copies the tensor's data to the host, then rebuilds a tensor from it.
///
/// Despite the name, the rebuilt tensor is created on the compiled backend's
/// default device (the same one `nif_to_gpu` uses), not on a separate CPU
/// device. NOTE(review): like `nif_to_gpu`, the result is built from raw values
/// and is presumably detached from the input's autodiff graph.
#[inline(never)]
#[rustler::nif]
fn nif_to_cpu(tensor: ResourceArc<TensorResource>) -> ResourceArc<TensorResource> {
    let host_bytes = tensor_to_bytes(&tensor.tensor).2;
    let rebuilt = make_tensor_from_bytes(host_bytes, tensor.shape.clone(), tensor.dtype.clone());
    build_resource(rebuilt, tensor.shape.clone(), tensor.dtype.clone())
}

// ── Memory Management ─────────────────────────────────────────────

/// Drops this call's handle to the resource and returns `:ok`.
///
/// `ResourceArc` handles are reference counted. The tensor itself is only
/// released once no Elixir term references the resource any more.
#[inline(never)]
#[rustler::nif]
fn nif_free_tensor(_tensor: ResourceArc<TensorResource>) -> rustler::Atom {
    drop(_tensor);
    rustler::types::atom::ok()
}

// ── Neural Network Operations ─────────────────────────────────────

/// Numerically stable softmax (the max is subtracted before exponentiating).
///
/// For 1D input, `dim` is ignored and softmax runs over the whole vector.
/// For 2D input, softmax runs along `dim`. A negative `dim` counts from the
/// end (`-1` = last axis, `-2` = first). Previously any negative `dim` was
/// mapped to the rank itself, which is always out of range.
///
/// The output has the input's shape and dtype. The input metadata is reused
/// instead of reading the result back to the host.
///
/// # Panics
/// For 3D/4D inputs, or if `dim` is out of range for a 2D input.
#[inline(never)]
#[rustler::nif]
fn nif_softmax_tensor(a: ResourceArc<TensorResource>, dim: i64) -> ResourceArc<TensorResource> {
    let result = match &a.tensor {
        BurnTensor::F32x1(t) => {
            let max_val = t.clone().max();
            let shifted = t.clone() - max_val;
            let exp_vals = shifted.exp();
            let sum_exp = exp_vals.clone().sum();
            BurnTensor::F32x1(exp_vals / sum_exp)
        }
        BurnTensor::F32x2(t) => {
            let axis = if dim < 0 { 2 + dim } else { dim };
            assert!(
                (0..2).contains(&axis),
                "softmax: dim {} is out of range for a 2D tensor",
                dim
            );
            let axis = axis as usize;
            let max_val = t.clone().max_dim(axis);
            let shifted = t.clone() - max_val;
            let exp_vals = shifted.exp();
            let sum_exp = exp_vals.clone().sum_dim(axis);
            BurnTensor::F32x2(exp_vals / sum_exp)
        }
        _ => panic!("softmax only supports f32 tensors"),
    };
    build_resource(result, a.shape.clone(), a.dtype.clone())
}

/// Layer normalisation without affine parameters: `(x - mean) / sqrt(var + eps)`.
///
/// For 1D input, statistics are taken over the whole vector. For 2D input they
/// are taken along `dim`, where a negative `dim` counts from the end (`-1` =
/// last axis, `-2` = first). The variance is the biased (population) mean of
/// squared deviations.
///
/// The square is computed with `powf_scalar`. The old 2D path built the
/// exponent from 1D data (`[2.0]`) for a rank-2 tensor, which Burn rejects.
/// Every negative `dim` used to be forced to 1.
///
/// # Panics
/// For 3D/4D inputs, or if `dim` is out of range for a 2D input.
#[inline(never)]
#[rustler::nif]
fn nif_layer_norm_tensor(
    a: ResourceArc<TensorResource>,
    dim: i64,
    eps: f64,
) -> ResourceArc<TensorResource> {
    let eps = eps as f32;
    let result = match &a.tensor {
        BurnTensor::F32x2(t) => {
            let axis = if dim < 0 { 2 + dim } else { dim };
            assert!(
                (0..2).contains(&axis),
                "layer_norm: dim {} is out of range for a 2D tensor",
                dim
            );
            let norm_dim = axis as usize;
            let mean = t.clone().mean_dim(norm_dim);
            let shifted = t.clone() - mean;
            let var = shifted.clone().powf_scalar(2.0_f32).mean_dim(norm_dim);
            BurnTensor::F32x2(shifted / (var + eps).sqrt())
        }
        BurnTensor::F32x1(t) => {
            let mean = t.clone().mean();
            let shifted = t.clone() - mean;
            let var = shifted.clone().powf_scalar(2.0_f32).mean();
            BurnTensor::F32x1(shifted / (var + eps).sqrt())
        }
        _ => panic!("layer_norm only supports 1D/2D f32 tensors"),
    };
    build_resource(result, a.shape.clone(), a.dtype.clone())
}

// ═══════════════════════════════════════════════════════════════════
// Autodiff gradient storage
// ═══════════════════════════════════════════════════════════════════

type GradientsType = <B as AutodiffBackend>::Gradients;

thread_local! {
    static LAST_GRADIENTS: RefCell<Option<GradientsType>> = const { RefCell::new(None) };
}

// ── Autodiff / Backward ────────────────────────────────────────-

/// Triggers backward pass on a scalar tensor and stores gradients
/// in a thread-local for subsequent grad() calls.
#[inline(never)]
#[rustler::nif]
fn nif_backward_tensor(a: ResourceArc<TensorResource>) -> rustler::Atom {
    if let BurnTensor::F32x1(t) = &a.tensor {
        if a.shape.iter().product::<usize>() == 1 {
            let grads = t.clone().backward();
            LAST_GRADIENTS.with(|g| {
                *g.borrow_mut() = Some(grads);
            });
        }
    }

    rustler::types::atom::ok()
}

#[inline(never)]
#[rustler::nif]
fn nif_grad_tensor(
    tensor: ResourceArc<TensorResource>,
    _var: ResourceArc<TensorResource>,
) -> ResourceArc<TensorResource> {
    let grad_tensor = LAST_GRADIENTS.with(|grads_cell| {
        let grads_ref = grads_cell.borrow();
        let grads = grads_ref.as_ref();

        match &tensor.tensor {
            BurnTensor::F32x1(t) => match grads {
                Some(g) => match t.grad(g) {
                    Some(grad) => {
                        let data = grad.to_data();
                        let vals: Vec<f32> = data.into_vec().unwrap_or_default();
                        let dev = device();
                        BurnTensor::F32x1(Tensor::<B, 1>::from_floats(vals.as_slice(), &dev))
                    }
                    None => {
                        let dev = device();
                        let n: usize = tensor.shape.iter().product();
                        BurnTensor::F32x1(Tensor::<B, 1>::zeros([n], &dev))
                    }
                },
                None => {
                    let dev = device();
                    let n: usize = tensor.shape.iter().product();
                    BurnTensor::F32x1(Tensor::<B, 1>::zeros([n], &dev))
                }
            },
            BurnTensor::F32x2(t) => match grads {
                Some(g) => match t.grad(g) {
                    Some(grad) => {
                        let data = grad.to_data();
                        let vals: Vec<f32> = data.into_vec().unwrap_or_default();
                        let dev = device();
                        BurnTensor::F32x2(Tensor::<B, 2>::from_floats(vals.as_slice(), &dev))
                    }
                    None => {
                        let dev = device();
                        let n: usize = tensor.shape.iter().product();
                        BurnTensor::F32x1(Tensor::<B, 1>::zeros([n], &dev))
                    }
                },
                None => {
                    let dev = device();
                    let n: usize = tensor.shape.iter().product();
                    BurnTensor::F32x1(Tensor::<B, 1>::zeros([n], &dev))
                }
            },
            _ => panic!("grad_tensor only supports 1D/2D f32 tensors"),
        }
    });
    let (shape, dtype, _) = tensor_to_bytes(&grad_tensor);
    build_resource(grad_tensor, shape, dtype)
}

// ═══════════════════════════════════════════════════════════════════
// NIF init
// ═══════════════════════════════════════════════════════════════════

// Registers this library's NIFs with the BEAM under the `Elixir.ExBurn.Nif` module.
// NOTE(review): no function list is passed here, so this relies on rustler
// collecting every `#[rustler::nif]` function automatically (including those
// defined below this line) — confirm the rustler version in use supports that.
rustler::init!("Elixir.ExBurn.Nif");

// ── Loss Functions ──────────────────────────────────────────────

/// Cross-entropy between logits `pred` and targets `target`, returned as a
/// one-element 1D tensor.
///
/// - 2D: rows are batch items. A log-softmax is applied along dim 1, and the
///   loss is averaged over the batch (`dims[0]`).
/// - 1D: a single example; the loss is the summed (not averaged) value.
///
/// `target` is multiplied element-wise with the log-probabilities. It is
/// presumably one-hot or a per-row probability distribution.
///
/// The loss tensor is returned directly. Previously it was read back with
/// `into_scalar` and rebuilt from the host value, which produced a tensor
/// disconnected from the autodiff graph (so `backward` on it could not reach
/// `pred`).
///
/// # Panics
/// If the operands are not both 1D or both 2D.
#[inline(never)]
#[rustler::nif]
fn nif_cross_entropy_loss(
    pred: ResourceArc<TensorResource>,
    target: ResourceArc<TensorResource>,
) -> ResourceArc<TensorResource> {
    let result = match (&pred.tensor, &target.tensor) {
        (BurnTensor::F32x2(logits), BurnTensor::F32x2(targets)) => {
            let max_val = logits.clone().max_dim(1);
            let shifted = logits.clone() - max_val;
            let log_sum_exp = shifted.clone().exp().sum_dim(1).log();
            let log_probs = shifted - log_sum_exp;
            let batch_size = log_probs.shape().dims::<2>()[0];
            BurnTensor::F32x1(-(targets.clone() * log_probs).sum() / (batch_size as f32))
        }
        (BurnTensor::F32x1(logits), BurnTensor::F32x1(targets)) => {
            let max_val = logits.clone().max();
            let shifted = logits.clone() - max_val;
            let log_sum_exp = shifted.clone().exp().sum().log();
            let log_probs = shifted - log_sum_exp;
            BurnTensor::F32x1(-(targets.clone() * log_probs).sum())
        }
        _ => panic!("cross_entropy_loss: unsupported tensor shapes"),
    };
    let (shape, dtype, _) = tensor_to_bytes(&result);
    build_resource(result, shape, dtype)
}

/// Mean squared error `mean((pred - target)^2)` over all elements, returned as
/// a one-element 1D tensor.
///
/// The loss tensor is returned directly. Previously it was read back with
/// `into_scalar` and rebuilt from the host value, which detached it from the
/// autodiff graph.
///
/// # Panics
/// If the operands are not both 1D or both 2D.
#[inline(never)]
#[rustler::nif]
fn nif_mse_loss(
    pred: ResourceArc<TensorResource>,
    target: ResourceArc<TensorResource>,
) -> ResourceArc<TensorResource> {
    let result = match (&pred.tensor, &target.tensor) {
        (BurnTensor::F32x1(p), BurnTensor::F32x1(t)) => {
            let diff = p.clone() - t.clone();
            BurnTensor::F32x1((diff.clone() * diff).mean())
        }
        (BurnTensor::F32x2(p), BurnTensor::F32x2(t)) => {
            let diff = p.clone() - t.clone();
            BurnTensor::F32x1((diff.clone() * diff).mean())
        }
        _ => panic!("mse_loss: unsupported tensor shapes"),
    };
    let (shape, dtype, _) = tensor_to_bytes(&result);
    build_resource(result, shape, dtype)
}

// ── Regularization ──────────────────────────────────────────────

/// Inverted dropout.
///
/// Each element is zeroed with probability `prob`; surviving elements are
/// scaled by `1 / (1 - prob)`. With `prob == 1.0` every element is zeroed.
///
/// The 2D mask is now built flat and reshaped. Previously 1D mask data was
/// passed to a rank-2 constructor, which Burn rejects. The output has the
/// input's shape and dtype, so the input metadata is reused instead of reading
/// the result back to the host.
///
/// # Panics
/// If `prob` is outside `[0, 1]` (including NaN), or for 3D/4D inputs.
#[inline(never)]
#[rustler::nif]
fn nif_dropout(tensor: ResourceArc<TensorResource>, prob: f64) -> ResourceArc<TensorResource> {
    assert!(
        (0.0..=1.0).contains(&prob),
        "dropout: prob must be in [0, 1], got {}",
        prob
    );
    let prob = prob as f32;
    let scale = 1.0 / (1.0 - prob);
    let dev = device();

    // One mask value per element: `scale` to keep, 0.0 to drop.
    let mask_values = |total: usize| -> Vec<f32> {
        (0..total)
            .map(|_| {
                let r: f32 = rand::random();
                if r > prob {
                    scale
                } else {
                    0.0
                }
            })
            .collect()
    };

    let result = match &tensor.tensor {
        BurnTensor::F32x2(t) => {
            let dims = t.shape().dims::<2>();
            let values = mask_values(dims[0] * dims[1]);
            let mask = Tensor::<B, 1>::from_floats(values.as_slice(), &dev).reshape(dims);
            BurnTensor::F32x2(t.clone() * mask)
        }
        BurnTensor::F32x1(t) => {
            let dims = t.shape().dims::<1>();
            let values = mask_values(dims[0]);
            let mask = Tensor::<B, 1>::from_floats(values.as_slice(), &dev);
            BurnTensor::F32x1(t.clone() * mask)
        }
        _ => panic!("dropout: unsupported tensor type"),
    };
    build_resource(result, tensor.shape.clone(), tensor.dtype.clone())
}

/// Debug helper: the number of NIFs registered in rustler's inventory.
///
/// `inventory::iter::<T>` is a value that implements `IntoIterator`, not a
/// function. The original `iter::<T>()` call therefore does not type-check; the
/// value is iterated directly instead.
///
/// NOTE(review): without `#[no_mangle]` this symbol's name is mangled, so C code
/// cannot look it up by name — confirm how (or whether) it is meant to be called.
pub extern "C" fn debug_nif_count() -> i32 {
    rustler::codegen_runtime::inventory::iter::<rustler::Nif>
        .into_iter()
        .count() as i32
}