Skip to main content

native/ex_cubecl_nif/src/lib.rs

use rustler::{Encoder, Env, Error, NifResult, ResourceArc, Term};

pub mod ffi;

// NOTE: CubeCL GPU integration placeholder.
// When the cubecl crate is available, uncomment the feature flag in Cargo.toml
// and add cubecl::wgpu::WgpuRuntime backend support here.
// For now, all operations run on CPU with optimized integer-aware paths.

// ── DType ─────────────────────────────────────────────────────

/// Element types a tensor buffer can hold.
///
/// The declaration order is significant: `tensor_dtype` reports a dtype to the
/// caller as `dtype as usize`, i.e. the variant's position in this list, so
/// reordering variants changes the numbers callers receive.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DType {
    /// 32-bit float, parsed from `"f32"` (4 bytes per element).
    F32,
    /// 64-bit float, parsed from `"f64"` (8 bytes per element).
    F64,
    /// Signed 32-bit integer, parsed from `"s32"` (4 bytes per element).
    S32,
    /// Signed 64-bit integer, parsed from `"s64"` (8 bytes per element).
    S64,
    /// Unsigned 32-bit integer, parsed from `"u32"` (4 bytes per element).
    U32,
    /// Unsigned 8-bit integer, parsed from `"u8"` (1 byte per element); also
    /// used as the output type of the comparison and predicate ops.
    U8,
}

impl DType {
    /// Looks up a dtype by its textual name (`"f32"`, `"f64"`, `"s32"`,
    /// `"s64"`, `"u32"`, `"u8"`).
    ///
    /// The comparison is exact: no trimming or case folding is applied here.
    /// Returns `None` for any other string.
    pub fn from_str(s: &str) -> Option<Self> {
        const NAMES: [(&str, DType); 6] = [
            ("f32", DType::F32),
            ("f64", DType::F64),
            ("s32", DType::S32),
            ("s64", DType::S64),
            ("u32", DType::U32),
            ("u8", DType::U8),
        ];
        NAMES
            .iter()
            .find(|(name, _)| *name == s)
            .map(|&(_, dtype)| dtype)
    }

    /// Width of a single element of this dtype, in bytes.
    pub fn size_in_bytes(self) -> usize {
        match self {
            DType::U8 => 1,
            DType::F32 | DType::S32 | DType::U32 => 4,
            DType::F64 | DType::S64 => 8,
        }
    }
}

// ── BufResource ──────────────────────────────────────────────

/// A tensor held on the native side: raw element bytes plus the metadata
/// needed to interpret them. Registered as a NIF resource in `on_load`.
#[derive(Debug, Clone)]
pub struct BufResource {
    /// Raw element storage. `new_tensor` only accepts binaries whose length is
    /// `product(shape) * dtype.size_in_bytes()`; the helpers in this file
    /// reinterpret these bytes in place via `bytemuck::cast_slice`.
    pub data: Vec<u8>,
    /// Extent of each dimension; strides are derived from it by `strides_for`.
    pub shape: Vec<usize>,
    /// Element type used to interpret `data`.
    pub dtype: DType,
}

impl BufResource {
    /// Number of elements implied by `shape`: the product of all dimensions
    /// (1 when `shape` is empty).
    pub fn num_elements(&self) -> usize {
        self.shape.iter().fold(1, |count, &dim| count * dim)
    }
}

// ── Atoms ────────────────────────────────────────────────────

// Atoms used to tag NIF return values: replies in this file are encoded as
// `atoms::ok()` or as `(atoms::ok(), payload)` tuples.
mod atoms {
    rustler::atoms! { ok, error }
}

// ── Helpers ──────────────────────────────────────────────────

fn decode_dtype(s: &str) -> NifResult<DType> {
    DType::from_str(s.trim())
        .ok_or_else(|| Error::RaiseTerm(Box::new(format!("unknown dtype: {}", s))))
}

/// Computes row-major (last dimension contiguous) strides, in elements, for
/// `shape`.
///
/// The last stride is 1 and each earlier stride is the product of all later
/// dimensions. An empty shape yields an empty vector.
fn strides_for(shape: &[usize]) -> Vec<usize> {
    let mut strides = vec![1usize; shape.len()];
    // Walk the dimensions from last to first; dimension `i` determines the
    // stride of dimension `i - 1`. The first dimension never feeds a stride.
    for (i, &dim) in shape.iter().enumerate().skip(1).rev() {
        strides[i - 1] = strides[i] * dim;
    }
    strides
}

/// Decodes a raw element buffer into `f64` values according to `dtype`.
///
/// Multi-byte dtypes are reinterpreted in place with `bytemuck::cast_slice`,
/// which panics if `data` is not a whole number of elements or is not aligned
/// for the element type. `S64` values above 2^53 in magnitude are rounded by
/// the conversion to `f64`.
fn to_f64(data: &[u8], dtype: DType) -> Vec<f64> {
    match dtype {
        DType::U8 => data.iter().map(|&byte| f64::from(byte)).collect(),
        DType::U32 => {
            let typed: &[u32] = bytemuck::cast_slice(data);
            typed.iter().map(|&v| f64::from(v)).collect()
        }
        DType::S32 => {
            let typed: &[i32] = bytemuck::cast_slice(data);
            typed.iter().map(|&v| f64::from(v)).collect()
        }
        DType::S64 => {
            let typed: &[i64] = bytemuck::cast_slice(data);
            typed.iter().map(|&v| v as f64).collect()
        }
        DType::F32 => {
            let typed: &[f32] = bytemuck::cast_slice(data);
            typed.iter().map(|&v| f64::from(v)).collect()
        }
        DType::F64 => {
            let typed: &[f64] = bytemuck::cast_slice(data);
            typed.to_vec()
        }
    }
}

/// Encodes `f64` values back into a raw element buffer of type `dtype`.
///
/// Conversions use `as` casts, so floats narrowed to integers are truncated
/// toward zero, saturate at the target range, and map NaN to 0.
fn from_f64(vals: Vec<f64>, dtype: DType) -> Vec<u8> {
    // Converts every value with `cast` and returns the bytes of the typed result.
    fn pack<T: bytemuck::Pod>(vals: &[f64], cast: impl Fn(f64) -> T) -> Vec<u8> {
        let typed: Vec<T> = vals.iter().map(|&v| cast(v)).collect();
        bytemuck::cast_slice::<T, u8>(&typed).to_vec()
    }

    match dtype {
        DType::F64 => bytemuck::cast_slice::<f64, u8>(&vals).to_vec(),
        DType::F32 => pack(&vals, |v| v as f32),
        DType::S64 => pack(&vals, |v| v as i64),
        DType::S32 => pack(&vals, |v| v as i32),
        DType::U32 => pack(&vals, |v| v as u32),
        DType::U8 => vals.into_iter().map(|v| v as u8).collect(),
    }
}

/// Element-wise binary operation with a dedicated integer path.
///
/// When both operands share an integer dtype (`S32`, `S64`, `U32`, `U8`) each
/// pair is widened to `i64`, combined with `op_i64`, and narrowed back with an
/// `as` cast (which truncates). Every other dtype combination is converted to
/// `f64`, combined with `op_f64`, and encoded as `dtype_a`.
///
/// Operands are paired with `zip`, so the output has as many elements as the
/// shorter input.
fn binary_op_int(
    data_a: &[u8],
    dtype_a: DType,
    data_b: &[u8],
    dtype_b: DType,
    op_i64: impl Fn(i64, i64) -> i64,
    op_f64: impl Fn(f64, f64) -> f64,
) -> Vec<u8> {
    // Views both buffers as `[T]`, combines them pairwise, returns raw bytes.
    fn zip_typed<T: bytemuck::Pod>(lhs: &[u8], rhs: &[u8], f: impl Fn(T, T) -> T) -> Vec<u8> {
        let lhs: &[T] = bytemuck::cast_slice(lhs);
        let rhs: &[T] = bytemuck::cast_slice(rhs);
        let out: Vec<T> = lhs.iter().zip(rhs).map(|(&l, &r)| f(l, r)).collect();
        bytemuck::cast_slice::<T, u8>(&out).to_vec()
    }

    match (dtype_a, dtype_b) {
        (DType::S32, DType::S32) => zip_typed(data_a, data_b, |l: i32, r: i32| {
            op_i64(i64::from(l), i64::from(r)) as i32
        }),
        (DType::S64, DType::S64) => zip_typed(data_a, data_b, |l: i64, r: i64| op_i64(l, r)),
        (DType::U32, DType::U32) => zip_typed(data_a, data_b, |l: u32, r: u32| {
            op_i64(i64::from(l), i64::from(r)) as u32
        }),
        (DType::U8, DType::U8) => zip_typed(data_a, data_b, |l: u8, r: u8| {
            op_i64(i64::from(l), i64::from(r)) as u8
        }),
        _ => {
            // Float or mixed dtypes: compute in f64, store as the left dtype.
            let lhs = to_f64(data_a, dtype_a);
            let rhs = to_f64(data_b, dtype_b);
            let combined: Vec<f64> = lhs
                .iter()
                .zip(&rhs)
                .map(|(&l, &r)| op_f64(l, r))
                .collect();
            from_f64(combined, dtype_a)
        }
    }
}

/// Element-wise unary operation with a dedicated integer path.
///
/// Integer dtypes are widened to `i64`, transformed with `op_i64`, and
/// narrowed back with an `as` cast (which truncates). Float dtypes go through
/// `f64` and `op_f64`. The output keeps the input dtype.
fn unary_op_int(
    data: &[u8],
    dtype: DType,
    op_i64: impl Fn(i64) -> i64,
    op_f64: impl Fn(f64) -> f64,
) -> Vec<u8> {
    // Views the buffer as `[T]`, maps every element, returns raw bytes.
    fn map_typed<T: bytemuck::Pod>(raw: &[u8], f: impl Fn(T) -> T) -> Vec<u8> {
        let typed: &[T] = bytemuck::cast_slice(raw);
        let out: Vec<T> = typed.iter().map(|&v| f(v)).collect();
        bytemuck::cast_slice::<T, u8>(&out).to_vec()
    }

    match dtype {
        DType::S32 => map_typed(data, |v: i32| op_i64(i64::from(v)) as i32),
        DType::S64 => map_typed(data, |v: i64| op_i64(v)),
        DType::U32 => map_typed(data, |v: u32| op_i64(i64::from(v)) as u32),
        DType::U8 => map_typed(data, |v: u8| op_i64(i64::from(v)) as u8),
        DType::F32 | DType::F64 => {
            let mapped: Vec<f64> = to_f64(data, dtype)
                .into_iter()
                .map(|v| op_f64(v))
                .collect();
            from_f64(mapped, dtype)
        }
    }
}

/// Applies a floating-point function element-wise to a buffer of any dtype.
///
/// For integer dtypes each element is converted to `f64`, passed through `op`,
/// and the result is cast back to an integer with `as`.
fn unary_op(data: &[u8], dtype: DType, op: impl Fn(f64) -> f64 + Copy) -> Vec<u8> {
    // `op` is `Copy`, so the integer adapter gets its own copy of it.
    let via_float = move |v: i64| -> i64 { op(v as f64) as i64 };
    unary_op_int(data, dtype, via_float, op)
}

/// Applies a floating-point binary function element-wise to two buffers.
///
/// For matching integer dtypes each pair is converted to `f64`, passed through
/// `op`, and the result is cast back to an integer with `as`; the dispatch
/// itself is done by `binary_op_int`.
fn binary_op(
    data_a: &[u8],
    dtype_a: DType,
    data_b: &[u8],
    dtype_b: DType,
    op: impl Fn(f64, f64) -> f64 + Copy,
) -> Vec<u8> {
    // `op` is `Copy`, so the integer adapter gets its own copy of it.
    let via_float = move |l: i64, r: i64| -> i64 { op(l as f64, r as f64) as i64 };
    binary_op_int(data_a, dtype_a, data_b, dtype_b, via_float, op)
}

/// Evaluates a predicate on every element (after conversion to `f64`) and
/// returns one byte per element: 1 where it holds, 0 otherwise.
fn unary_op_bool(data: &[u8], dtype: DType, op: impl Fn(f64) -> bool) -> Vec<u8> {
    to_f64(data, dtype)
        .into_iter()
        .map(|value| u8::from(op(value)))
        .collect()
}

/// Evaluates a binary predicate on paired elements of two buffers (both
/// converted to `f64`) and returns one byte per pair: 1 where it holds, 0
/// otherwise. Pairing stops at the shorter input.
fn binary_op_bool(
    data_a: &[u8],
    dtype_a: DType,
    data_b: &[u8],
    dtype_b: DType,
    op: impl Fn(f64, f64) -> bool,
) -> Vec<u8> {
    let lhs = to_f64(data_a, dtype_a);
    let rhs = to_f64(data_b, dtype_b);
    let mut flags = Vec::with_capacity(lhs.len().min(rhs.len()));
    for (&l, &r) in lhs.iter().zip(&rhs) {
        flags.push(u8::from(op(l, r)));
    }
    flags
}

// ── NIF init ─────────────────────────────────────────────────

// Load callback for the NIF library: registers `BufResource` as a resource
// type so it can be handed out wrapped in `ResourceArc`. The macro's result is
// discarded and the callback always returns `true`.
// NOTE(review): the lint allowance is presumably needed because the
// `rustler::resource!` expansion defines an impl inside this function body —
// confirm against the rustler version in use.
#[allow(non_local_definitions)]
fn on_load(env: Env, _info: Term) -> bool {
    let _ = rustler::resource!(BufResource, env);
    true
}

// Registers this library under the module name `Elixir.ExCubecl.NIF`, with
// `on_load` as its load callback.
rustler::init!("Elixir.ExCubecl.NIF", load = on_load);

// ═══════════════════════════════════════════════════════════════
// Buffer lifecycle
// ═══════════════════════════════════════════════════════════════

#[rustler::nif(schedule = "DirtyCpu")]
fn new_tensor<'a>(
    env: Env<'a>,
    data: rustler::Binary<'a>,
    shape: Vec<usize>,
    dtype_str: String,
) -> NifResult<Term<'a>> {
    let dtype = decode_dtype(&dtype_str)?;
    let expected = shape.iter().product::<usize>() * dtype.size_in_bytes();
    if data.len() != expected {
        return Err(Error::RaiseTerm(Box::new(format!(
            "size mismatch: got {} expected {}",
            data.len(),
            expected
        ))));
    }
    Ok((
        atoms::ok(),
        ResourceArc::new(BufResource {
            data: data.as_slice().to_vec(),
            shape,
            dtype,
        }),
    )
        .encode(env))
}

/// Copies the tensor's raw bytes into a freshly allocated binary and returns
/// `{:ok, binary}`.
#[rustler::nif]
fn read_tensor<'a>(env: Env<'a>, buf: ResourceArc<BufResource>) -> NifResult<Term<'a>> {
    let bytes: &[u8] = &buf.data;
    let mut fresh = rustler::types::binary::NewBinary::new(env, bytes.len());
    fresh.as_mut_slice().copy_from_slice(bytes);
    let binary: rustler::Binary<'a> = fresh.into();
    Ok((atoms::ok(), binary).encode(env))
}

/// Replies `:ok` without touching the buffer; the only effect in this function
/// is that the `ResourceArc` argument is dropped when it returns.
#[rustler::nif]
fn deallocate_tensor(env: Env, _buf: ResourceArc<BufResource>) -> NifResult<Term> {
    let reply = atoms::ok().encode(env);
    Ok(reply)
}

/// Returns `{:ok, shape}` where `shape` is a copy of the tensor's dimensions.
#[rustler::nif]
fn tensor_shape(env: Env, buf: ResourceArc<BufResource>) -> NifResult<Term> {
    let dims: Vec<usize> = buf.shape.to_vec();
    Ok((atoms::ok(), dims).encode(env))
}

/// Returns `{:ok, code}` where `code` is the dtype's position in the `DType`
/// enum declaration (`dtype as usize`).
#[rustler::nif]
fn tensor_dtype(env: Env, buf: ResourceArc<BufResource>) -> NifResult<Term> {
    let code = buf.dtype as usize;
    Ok((atoms::ok(), code).encode(env))
}

// ═══════════════════════════════════════════════════════════════
// Binary ops
// ═══════════════════════════════════════════════════════════════

#[rustler::nif(schedule = "DirtyCpu")]
fn add<'a>(
    env: Env<'a>,
    a: ResourceArc<BufResource>,
    b: ResourceArc<BufResource>,
) -> NifResult<Term<'a>> {
    Ok((
        atoms::ok(),
        ResourceArc::new(BufResource {
            data: binary_op(&a.data, a.dtype, &b.data, b.dtype, |x, y| x + y),
            shape: a.shape.clone(),
            dtype: a.dtype,
        }),
    )
        .encode(env))
}

#[rustler::nif(schedule = "DirtyCpu")]
fn subtract<'a>(
    env: Env<'a>,
    a: ResourceArc<BufResource>,
    b: ResourceArc<BufResource>,
) -> NifResult<Term<'a>> {
    Ok((
        atoms::ok(),
        ResourceArc::new(BufResource {
            data: binary_op(&a.data, a.dtype, &b.data, b.dtype, |x, y| x - y),
            shape: a.shape.clone(),
            dtype: a.dtype,
        }),
    )
        .encode(env))
}

#[rustler::nif(schedule = "DirtyCpu")]
fn multiply<'a>(
    env: Env<'a>,
    a: ResourceArc<BufResource>,
    b: ResourceArc<BufResource>,
) -> NifResult<Term<'a>> {
    Ok((
        atoms::ok(),
        ResourceArc::new(BufResource {
            data: binary_op(&a.data, a.dtype, &b.data, b.dtype, |x, y| x * y),
            shape: a.shape.clone(),
            dtype: a.dtype,
        }),
    )
        .encode(env))
}

/// Element-wise `a / b` computed in `f64`; the result takes `a`'s shape and
/// dtype.
///
/// A zero divisor yields NaN rather than an infinity; for integer dtypes that
/// NaN becomes 0 when cast back.
#[rustler::nif(schedule = "DirtyCpu")]
fn divide<'a>(
    env: Env<'a>,
    a: ResourceArc<BufResource>,
    b: ResourceArc<BufResource>,
) -> NifResult<Term<'a>> {
    let data = binary_op(&a.data, a.dtype, &b.data, b.dtype, |num, den| {
        if den != 0.0 {
            num / den
        } else {
            f64::NAN
        }
    });
    let out = BufResource {
        data,
        shape: a.shape.clone(),
        dtype: a.dtype,
    };
    Ok((atoms::ok(), ResourceArc::new(out)).encode(env))
}

/// Element-wise `a` raised to the power `b`, computed with `f64::powf`; the
/// result takes `a`'s shape and dtype.
#[rustler::nif(schedule = "DirtyCpu")]
fn pow<'a>(
    env: Env<'a>,
    a: ResourceArc<BufResource>,
    b: ResourceArc<BufResource>,
) -> NifResult<Term<'a>> {
    let data = binary_op(&a.data, a.dtype, &b.data, b.dtype, |base, exponent| {
        base.powf(exponent)
    });
    let out = BufResource {
        data,
        shape: a.shape.clone(),
        dtype: a.dtype,
    };
    Ok((atoms::ok(), ResourceArc::new(out)).encode(env))
}

#[rustler::nif(schedule = "DirtyCpu")]
fn remainder<'a>(
    env: Env<'a>,
    a: ResourceArc<BufResource>,
    b: ResourceArc<BufResource>,
) -> NifResult<Term<'a>> {
    Ok((
        atoms::ok(),
        ResourceArc::new(BufResource {
            data: binary_op(&a.data, a.dtype, &b.data, b.dtype, |x, y| x % y),
            shape: a.shape.clone(),
            dtype: a.dtype,
        }),
    )
        .encode(env))
}

/// Element-wise four-quadrant arctangent `a.atan2(b)` computed in `f64`; the
/// result takes `a`'s shape and dtype.
#[rustler::nif(schedule = "DirtyCpu")]
fn atan2<'a>(
    env: Env<'a>,
    a: ResourceArc<BufResource>,
    b: ResourceArc<BufResource>,
) -> NifResult<Term<'a>> {
    let data = binary_op(&a.data, a.dtype, &b.data, b.dtype, |y, x| y.atan2(x));
    let out = BufResource {
        data,
        shape: a.shape.clone(),
        dtype: a.dtype,
    };
    Ok((atoms::ok(), ResourceArc::new(out)).encode(env))
}

#[rustler::nif(schedule = "DirtyCpu")]
fn min_tensor<'a>(
    env: Env<'a>,
    a: ResourceArc<BufResource>,
    b: ResourceArc<BufResource>,
) -> NifResult<Term<'a>> {
    Ok((
        atoms::ok(),
        ResourceArc::new(BufResource {
            data: binary_op(&a.data, a.dtype, &b.data, b.dtype, |x, y| x.min(y)),
            shape: a.shape.clone(),
            dtype: a.dtype,
        }),
    )
        .encode(env))
}

#[rustler::nif(schedule = "DirtyCpu")]
fn max_tensor<'a>(
    env: Env<'a>,
    a: ResourceArc<BufResource>,
    b: ResourceArc<BufResource>,
) -> NifResult<Term<'a>> {
    Ok((
        atoms::ok(),
        ResourceArc::new(BufResource {
            data: binary_op(&a.data, a.dtype, &b.data, b.dtype, |x, y| x.max(y)),
            shape: a.shape.clone(),
            dtype: a.dtype,
        }),
    )
        .encode(env))
}

#[rustler::nif(schedule = "DirtyCpu")]
fn quotient<'a>(
    env: Env<'a>,
    a: ResourceArc<BufResource>,
    b: ResourceArc<BufResource>,
) -> NifResult<Term<'a>> {
    Ok((
        atoms::ok(),
        ResourceArc::new(BufResource {
            data: binary_op(&a.data, a.dtype, &b.data, b.dtype, |x, y| {
                if y == 0.0 {
                    0.0
                } else {
                    (x / y).trunc()
                }
            }),
            shape: a.shape.clone(),
            dtype: a.dtype,
        }),
    )
        .encode(env))
}

// ── Bitwise ops (integer-efficient) ──────────────────────────

/// Element-wise bitwise AND; the result takes `a`'s shape and dtype.
///
/// Matching integer dtypes are combined as integers. Any other combination is
/// converted to `f64`, cast to `i64`, combined, and converted back.
#[rustler::nif(schedule = "DirtyCpu")]
fn bitwise_and<'a>(
    env: Env<'a>,
    a: ResourceArc<BufResource>,
    b: ResourceArc<BufResource>,
) -> NifResult<Term<'a>> {
    let data = binary_op_int(
        &a.data,
        a.dtype,
        &b.data,
        b.dtype,
        |lhs, rhs| lhs & rhs,
        |lhs, rhs| ((lhs as i64) & (rhs as i64)) as f64,
    );
    let out = BufResource {
        data,
        shape: a.shape.clone(),
        dtype: a.dtype,
    };
    Ok((atoms::ok(), ResourceArc::new(out)).encode(env))
}

/// Element-wise bitwise OR; the result takes `a`'s shape and dtype.
///
/// Matching integer dtypes are combined as integers. Any other combination is
/// converted to `f64`, cast to `i64`, combined, and converted back.
#[rustler::nif(schedule = "DirtyCpu")]
fn bitwise_or<'a>(
    env: Env<'a>,
    a: ResourceArc<BufResource>,
    b: ResourceArc<BufResource>,
) -> NifResult<Term<'a>> {
    let data = binary_op_int(
        &a.data,
        a.dtype,
        &b.data,
        b.dtype,
        |lhs, rhs| lhs | rhs,
        |lhs, rhs| ((lhs as i64) | (rhs as i64)) as f64,
    );
    let out = BufResource {
        data,
        shape: a.shape.clone(),
        dtype: a.dtype,
    };
    Ok((atoms::ok(), ResourceArc::new(out)).encode(env))
}

/// Element-wise bitwise XOR; the result takes `a`'s shape and dtype.
///
/// Matching integer dtypes are combined as integers. Any other combination is
/// converted to `f64`, cast to `i64`, combined, and converted back.
#[rustler::nif(schedule = "DirtyCpu")]
fn bitwise_xor<'a>(
    env: Env<'a>,
    a: ResourceArc<BufResource>,
    b: ResourceArc<BufResource>,
) -> NifResult<Term<'a>> {
    let data = binary_op_int(
        &a.data,
        a.dtype,
        &b.data,
        b.dtype,
        |lhs, rhs| lhs ^ rhs,
        |lhs, rhs| ((lhs as i64) ^ (rhs as i64)) as f64,
    );
    let out = BufResource {
        data,
        shape: a.shape.clone(),
        dtype: a.dtype,
    };
    Ok((atoms::ok(), ResourceArc::new(out)).encode(env))
}

/// Element-wise left shift of `a` by `b` bits; the result takes `a`'s shape
/// and dtype.
///
/// The shift is performed on the value widened to `i64` and then narrowed back
/// to the element type. A shift amount that is negative or at least 64 yields
/// 0 instead of overflowing: the plain `<<` operator panics on such amounts in
/// debug builds and silently masks them in release builds.
#[rustler::nif(schedule = "DirtyCpu")]
fn left_shift<'a>(
    env: Env<'a>,
    a: ResourceArc<BufResource>,
    b: ResourceArc<BufResource>,
) -> NifResult<Term<'a>> {
    // Total (never-panicking) left shift on i64.
    fn shift_left(value: i64, amount: i64) -> i64 {
        u32::try_from(amount)
            .ok()
            .and_then(|bits| value.checked_shl(bits))
            .unwrap_or(0)
    }

    let data = binary_op_int(
        &a.data,
        a.dtype,
        &b.data,
        b.dtype,
        shift_left,
        |x: f64, y: f64| shift_left(x as i64, y as i64) as f64,
    );
    let out = BufResource {
        data,
        shape: a.shape.clone(),
        dtype: a.dtype,
    };
    Ok((atoms::ok(), ResourceArc::new(out)).encode(env))
}

/// Element-wise right shift of `a` by `b` bits; the result takes `a`'s shape
/// and dtype.
///
/// The shift is an arithmetic shift on the value widened to `i64`. Amounts of
/// 64 or more are clamped to 63, which gives the sign fill (0 or -1). Negative
/// amounts yield 0. The plain `>>` operator panics on such amounts in debug
/// builds and silently masks them in release builds.
#[rustler::nif(schedule = "DirtyCpu")]
fn right_shift<'a>(
    env: Env<'a>,
    a: ResourceArc<BufResource>,
    b: ResourceArc<BufResource>,
) -> NifResult<Term<'a>> {
    // Total (never-panicking) arithmetic right shift on i64.
    fn shift_right(value: i64, amount: i64) -> i64 {
        if amount < 0 {
            0
        } else {
            value >> amount.min(63)
        }
    }

    let data = binary_op_int(
        &a.data,
        a.dtype,
        &b.data,
        b.dtype,
        shift_right,
        |x: f64, y: f64| shift_right(x as i64, y as i64) as f64,
    );
    let out = BufResource {
        data,
        shape: a.shape.clone(),
        dtype: a.dtype,
    };
    Ok((atoms::ok(), ResourceArc::new(out)).encode(env))
}

// ═══════════════════════════════════════════════════════════════
// Comparison ops
// ═══════════════════════════════════════════════════════════════

/// Element-wise equality test; returns a `U8` tensor with `a`'s shape holding
/// 1 where the elements are equal and 0 otherwise.
///
/// Equality is exact. The previous `|x - y| < EPSILON` test reported
/// `inf == inf` as false (the difference is NaN) and treated distinct values
/// closer than `f64::EPSILON` as equal. Operands that share an integer dtype
/// are compared byte-for-byte so `S64` values beyond 2^53 are not merged by
/// conversion to `f64`. Everything else is compared as `f64`, where NaN equals
/// nothing.
#[rustler::nif(schedule = "DirtyCpu")]
fn equal<'a>(
    env: Env<'a>,
    a: ResourceArc<BufResource>,
    b: ResourceArc<BufResource>,
) -> NifResult<Term<'a>> {
    let data: Vec<u8> = match a.dtype {
        DType::S32 | DType::S64 | DType::U32 | DType::U8 if a.dtype == b.dtype => {
            // Same integer type on both sides: equal values have equal bytes.
            let width = a.dtype.size_in_bytes();
            a.data
                .chunks_exact(width)
                .zip(b.data.chunks_exact(width))
                .map(|(l, r)| u8::from(l == r))
                .collect()
        }
        _ => binary_op_bool(&a.data, a.dtype, &b.data, b.dtype, |x, y| x == y),
    };
    let out = BufResource {
        data,
        shape: a.shape.clone(),
        dtype: DType::U8,
    };
    Ok((atoms::ok(), ResourceArc::new(out)).encode(env))
}

/// Element-wise inequality test; returns a `U8` tensor with `a`'s shape
/// holding 1 where the elements differ and 0 otherwise.
///
/// Inequality is exact. The previous `|x - y| >= EPSILON` test reported
/// distinct values closer than `f64::EPSILON` as not different, and reported
/// NaN as not different from anything. Operands that share an integer dtype
/// are compared byte-for-byte so `S64` values beyond 2^53 are not merged by
/// conversion to `f64`. Everything else is compared as `f64`, where NaN
/// differs from everything.
#[rustler::nif(schedule = "DirtyCpu")]
fn not_equal<'a>(
    env: Env<'a>,
    a: ResourceArc<BufResource>,
    b: ResourceArc<BufResource>,
) -> NifResult<Term<'a>> {
    let data: Vec<u8> = match a.dtype {
        DType::S32 | DType::S64 | DType::U32 | DType::U8 if a.dtype == b.dtype => {
            // Same integer type on both sides: different values have different bytes.
            let width = a.dtype.size_in_bytes();
            a.data
                .chunks_exact(width)
                .zip(b.data.chunks_exact(width))
                .map(|(l, r)| u8::from(l != r))
                .collect()
        }
        _ => binary_op_bool(&a.data, a.dtype, &b.data, b.dtype, |x, y| x != y),
    };
    let out = BufResource {
        data,
        shape: a.shape.clone(),
        dtype: DType::U8,
    };
    Ok((atoms::ok(), ResourceArc::new(out)).encode(env))
}

/// Element-wise `a > b` evaluated in `f64`; returns a `U8` tensor (1 = true,
/// 0 = false) with `a`'s shape.
#[rustler::nif(schedule = "DirtyCpu")]
fn greater<'a>(
    env: Env<'a>,
    a: ResourceArc<BufResource>,
    b: ResourceArc<BufResource>,
) -> NifResult<Term<'a>> {
    let flags = binary_op_bool(&a.data, a.dtype, &b.data, b.dtype, |lhs, rhs| lhs > rhs);
    let out = BufResource {
        data: flags,
        shape: a.shape.clone(),
        dtype: DType::U8,
    };
    Ok((atoms::ok(), ResourceArc::new(out)).encode(env))
}

/// Element-wise `a < b` evaluated in `f64`; returns a `U8` tensor (1 = true,
/// 0 = false) with `a`'s shape.
#[rustler::nif(schedule = "DirtyCpu")]
fn less<'a>(
    env: Env<'a>,
    a: ResourceArc<BufResource>,
    b: ResourceArc<BufResource>,
) -> NifResult<Term<'a>> {
    let flags = binary_op_bool(&a.data, a.dtype, &b.data, b.dtype, |lhs, rhs| lhs < rhs);
    let out = BufResource {
        data: flags,
        shape: a.shape.clone(),
        dtype: DType::U8,
    };
    Ok((atoms::ok(), ResourceArc::new(out)).encode(env))
}

/// Element-wise `a >= b` evaluated in `f64`; returns a `U8` tensor (1 = true,
/// 0 = false) with `a`'s shape.
#[rustler::nif(schedule = "DirtyCpu")]
fn greater_equal<'a>(
    env: Env<'a>,
    a: ResourceArc<BufResource>,
    b: ResourceArc<BufResource>,
) -> NifResult<Term<'a>> {
    let flags = binary_op_bool(&a.data, a.dtype, &b.data, b.dtype, |lhs, rhs| lhs >= rhs);
    let out = BufResource {
        data: flags,
        shape: a.shape.clone(),
        dtype: DType::U8,
    };
    Ok((atoms::ok(), ResourceArc::new(out)).encode(env))
}

/// Element-wise `a <= b` evaluated in `f64`; returns a `U8` tensor (1 = true,
/// 0 = false) with `a`'s shape.
#[rustler::nif(schedule = "DirtyCpu")]
fn less_equal<'a>(
    env: Env<'a>,
    a: ResourceArc<BufResource>,
    b: ResourceArc<BufResource>,
) -> NifResult<Term<'a>> {
    let flags = binary_op_bool(&a.data, a.dtype, &b.data, b.dtype, |lhs, rhs| lhs <= rhs);
    let out = BufResource {
        data: flags,
        shape: a.shape.clone(),
        dtype: DType::U8,
    };
    Ok((atoms::ok(), ResourceArc::new(out)).encode(env))
}

/// Element-wise logical AND, treating any non-zero element as true; returns a
/// `U8` tensor (1 = true, 0 = false) with `a`'s shape.
#[rustler::nif(schedule = "DirtyCpu")]
fn logical_and<'a>(
    env: Env<'a>,
    a: ResourceArc<BufResource>,
    b: ResourceArc<BufResource>,
) -> NifResult<Term<'a>> {
    let flags = binary_op_bool(&a.data, a.dtype, &b.data, b.dtype, |lhs, rhs| {
        let (l_true, r_true) = (lhs != 0.0, rhs != 0.0);
        l_true && r_true
    });
    let out = BufResource {
        data: flags,
        shape: a.shape.clone(),
        dtype: DType::U8,
    };
    Ok((atoms::ok(), ResourceArc::new(out)).encode(env))
}

/// Element-wise logical OR, treating any non-zero element as true; returns a
/// `U8` tensor (1 = true, 0 = false) with `a`'s shape.
#[rustler::nif(schedule = "DirtyCpu")]
fn logical_or<'a>(
    env: Env<'a>,
    a: ResourceArc<BufResource>,
    b: ResourceArc<BufResource>,
) -> NifResult<Term<'a>> {
    let flags = binary_op_bool(&a.data, a.dtype, &b.data, b.dtype, |lhs, rhs| {
        let (l_true, r_true) = (lhs != 0.0, rhs != 0.0);
        l_true || r_true
    });
    let out = BufResource {
        data: flags,
        shape: a.shape.clone(),
        dtype: DType::U8,
    };
    Ok((atoms::ok(), ResourceArc::new(out)).encode(env))
}

/// Element-wise logical XOR, treating any non-zero element as true; returns a
/// `U8` tensor (1 = true, 0 = false) with `a`'s shape.
#[rustler::nif(schedule = "DirtyCpu")]
fn logical_xor<'a>(
    env: Env<'a>,
    a: ResourceArc<BufResource>,
    b: ResourceArc<BufResource>,
) -> NifResult<Term<'a>> {
    let flags = binary_op_bool(&a.data, a.dtype, &b.data, b.dtype, |lhs, rhs| {
        let (l_true, r_true) = (lhs != 0.0, rhs != 0.0);
        l_true != r_true
    });
    let out = BufResource {
        data: flags,
        shape: a.shape.clone(),
        dtype: DType::U8,
    };
    Ok((atoms::ok(), ResourceArc::new(out)).encode(env))
}

// ═══════════════════════════════════════════════════════════════
// Unary ops
// ═══════════════════════════════════════════════════════════════

// Generates a dirty-CPU NIF named `$name` that applies the `f64 -> f64`
// expression `$op` to every element of its argument via `unary_op` and returns
// `{:ok, tensor}` with the input's shape and dtype.
macro_rules! unary_nif {
    ($name:ident, $op:expr) => {
        #[rustler::nif(schedule = "DirtyCpu")]
        fn $name<'a>(env: Env<'a>, a: ResourceArc<BufResource>) -> NifResult<Term<'a>> {
            // `$op` is passed straight to `unary_op` so that a closure literal
            // gets its parameter type from the callee's signature.
            let data = unary_op(&a.data, a.dtype, $op);
            let out = BufResource {
                data,
                shape: a.shape.clone(),
                dtype: a.dtype,
            };
            Ok((atoms::ok(), ResourceArc::new(out)).encode(env))
        }
    };
}

// Element-wise unary NIFs generated by `unary_nif!`. Each closure is evaluated
// in f64; for integer tensors the result is cast back to the element type.
unary_nif!(negate, |x| -x);
unary_nif!(abs_tensor, |x| x.abs());
unary_nif!(exp, |x| x.exp());
// `f64::ln` already returns -inf at 0 and NaN for negative input, so no guard
// is needed (the old guard turned log(0) into NaN).
unary_nif!(log, |x| x.ln());
unary_nif!(sqrt, |x| x.sqrt());
unary_nif!(sin, |x| x.sin());
unary_nif!(cos, |x| x.cos());
unary_nif!(tan, |x| x.tan());
unary_nif!(sigmoid, |x| 1.0 / (1.0 + (-x).exp()));
unary_nif!(relu, |x| if x > 0.0 { x } else { 0.0 });
unary_nif!(expm1, |x| x.exp_m1());
// `ln_1p` stays accurate for |x| near zero, where `(1.0 + x).ln()` loses the
// low-order bits of x.
unary_nif!(log1p, |x| x.ln_1p());
unary_nif!(cosh, |x| x.cosh());
unary_nif!(sinh, |x| x.sinh());
unary_nif!(tanh, |x| x.tanh());
unary_nif!(acos, |x| x.acos());
unary_nif!(asin, |x| x.asin());
unary_nif!(atan, |x| x.atan());
unary_nif!(acosh, |x| x.acosh());
unary_nif!(asinh, |x| x.asinh());
unary_nif!(atanh, |x| x.atanh());
unary_nif!(rsqrt, |x| 1.0 / x.sqrt());
unary_nif!(cbrt, |x| x.cbrt());
unary_nif!(ceil_tensor, |x| x.ceil());
unary_nif!(floor_tensor, |x| x.floor());
unary_nif!(round_tensor, |x| x.round());

/// Element-wise sign: 1 for positive elements, -1 for negative elements, and 0
/// otherwise (including zero and NaN). The result keeps the input's shape and
/// dtype.
#[rustler::nif(schedule = "DirtyCpu")]
fn sign_tensor<'a>(env: Env<'a>, a: ResourceArc<BufResource>) -> NifResult<Term<'a>> {
    let data = unary_op(&a.data, a.dtype, |x| match x.partial_cmp(&0.0) {
        Some(std::cmp::Ordering::Greater) => 1.0,
        Some(std::cmp::Ordering::Less) => -1.0,
        // Equal to zero, or unordered (NaN).
        _ => 0.0,
    });
    let out = BufResource {
        data,
        shape: a.shape.clone(),
        dtype: a.dtype,
    };
    Ok((atoms::ok(), ResourceArc::new(out)).encode(env))
}

/// Polynomial approximation of the error function.
///
/// Evaluates a degree-5 polynomial in `t = 1 / (1 + 0.3275911 * |x|)` (the
/// constant term is zero) scaled by `exp(-x^2)`, and mirrors the result for
/// negative `x` so the function is odd.
fn erf_approx(x: f64) -> f64 {
    // Coefficients after the leading one, in Horner order.
    const TAIL: [f64; 4] = [-1.453152027, 1.421413741, -0.284496736, 0.254829592];

    let magnitude = x.abs();
    let t = 1.0 / (1.0 + 0.3275911 * magnitude);
    let poly = TAIL.iter().fold(1.061405429, |acc, &coeff| acc * t + coeff);
    let positive_branch = 1.0 - poly * t * (-magnitude * magnitude).exp();
    if x < 0.0 {
        -1.0 * positive_branch
    } else {
        1.0 * positive_branch
    }
}

/// Element-wise error function, computed with `erf_approx`; the result keeps
/// the input's shape and dtype.
#[rustler::nif(schedule = "DirtyCpu")]
fn erf<'a>(env: Env<'a>, a: ResourceArc<BufResource>) -> NifResult<Term<'a>> {
    let data = unary_op(&a.data, a.dtype, erf_approx);
    let out = BufResource {
        data,
        shape: a.shape.clone(),
        dtype: a.dtype,
    };
    Ok((atoms::ok(), ResourceArc::new(out)).encode(env))
}

/// Element-wise complementary error function, computed as
/// `1 - erf_approx(x)`; the result keeps the input's shape and dtype.
#[rustler::nif(schedule = "DirtyCpu")]
fn erfc<'a>(env: Env<'a>, a: ResourceArc<BufResource>) -> NifResult<Term<'a>> {
    let data = unary_op(&a.data, a.dtype, |x| {
        let erf_x = erf_approx(x);
        1.0 - erf_x
    });
    let out = BufResource {
        data,
        shape: a.shape.clone(),
        dtype: a.dtype,
    };
    Ok((atoms::ok(), ResourceArc::new(out)).encode(env))
}

/// Element-wise inverse error function; the result keeps the input's shape and
/// dtype.
///
/// Inputs outside the domain are handled explicitly: NaN or |x| > 1 gives NaN,
/// and ±1 gives ±infinity. Previously these went through the Newton loop and
/// came back as arbitrary finite numbers. Inside (-1, 1) the value is found by
/// up to 50 Newton steps on `erf_approx(y) - x`, starting from 0 and using the
/// analytic derivative of erf, `2/sqrt(pi) * exp(-y^2)`.
#[rustler::nif(schedule = "DirtyCpu")]
fn erf_inv<'a>(env: Env<'a>, a: ResourceArc<BufResource>) -> NifResult<Term<'a>> {
    let data = unary_op(&a.data, a.dtype, |x: f64| {
        if x.is_nan() || x.abs() > 1.0 {
            return f64::NAN;
        }
        if x == 1.0 {
            return f64::INFINITY;
        }
        if x == -1.0 {
            return f64::NEG_INFINITY;
        }

        let mut y = 0.0f64;
        for _ in 0..50 {
            let e = erf_approx(y) - x;
            let d = 2.0 / std::f64::consts::PI.sqrt() * (-y * y).exp();
            // Stop once the slope is too flat for a meaningful Newton step.
            if d.abs() < 1e-15 {
                break;
            }
            y -= e / d;
        }
        y
    });
    let out = BufResource {
        data,
        shape: a.shape.clone(),
        dtype: a.dtype,
    };
    Ok((atoms::ok(), ResourceArc::new(out)).encode(env))
}

/// Element-wise bitwise complement; the result keeps the input's shape and
/// dtype. Integer dtypes are complemented as integers. Float dtypes are cast
/// to `i64`, complemented, and converted back.
#[rustler::nif(schedule = "DirtyCpu")]
fn bitwise_not<'a>(env: Env<'a>, a: ResourceArc<BufResource>) -> NifResult<Term<'a>> {
    let data = unary_op_int(
        &a.data,
        a.dtype,
        |bits| !bits,
        |value| {
            let flipped = !(value as i64);
            flipped as f64
        },
    );
    let out = BufResource {
        data,
        shape: a.shape.clone(),
        dtype: a.dtype,
    };
    Ok((atoms::ok(), ResourceArc::new(out)).encode(env))
}

/// Complex conjugate. None of the supported dtypes is complex, so this returns
/// an unchanged copy of the input.
#[rustler::nif(schedule = "DirtyCpu")]
fn conjugate<'a>(env: Env<'a>, a: ResourceArc<BufResource>) -> NifResult<Term<'a>> {
    let copy = BufResource {
        data: a.data.to_vec(),
        shape: a.shape.to_vec(),
        dtype: a.dtype,
    };
    Ok((atoms::ok(), ResourceArc::new(copy)).encode(env))
}

/// Element-wise count of leading zero bits; the result keeps the input's shape
/// and dtype.
///
/// Integer dtypes count within their own bit width. Float dtypes are cast to
/// `u64` first and counted over 64 bits.
#[rustler::nif(schedule = "DirtyCpu")]
fn count_leading_zeros<'a>(env: Env<'a>, a: ResourceArc<BufResource>) -> NifResult<Term<'a>> {
    // Views the buffer as `[T]`, maps every element, returns raw bytes.
    fn per_element<T: bytemuck::Pod>(raw: &[u8], f: impl Fn(T) -> T) -> Vec<u8> {
        let typed: &[T] = bytemuck::cast_slice(raw);
        let out: Vec<T> = typed.iter().map(|&v| f(v)).collect();
        bytemuck::cast_slice::<T, u8>(&out).to_vec()
    }

    let data = match a.dtype {
        DType::S32 => per_element(&a.data, |v: i32| v.leading_zeros() as i32),
        DType::S64 => per_element(&a.data, |v: i64| v.leading_zeros() as i64),
        DType::U32 => per_element(&a.data, |v: u32| v.leading_zeros()),
        DType::U8 => per_element(&a.data, |v: u8| v.leading_zeros() as u8),
        DType::F32 | DType::F64 => {
            let counts: Vec<f64> = to_f64(&a.data, a.dtype)
                .into_iter()
                .map(|v| f64::from((v as u64).leading_zeros()))
                .collect();
            from_f64(counts, a.dtype)
        }
    };
    let out = BufResource {
        data,
        shape: a.shape.clone(),
        dtype: a.dtype,
    };
    Ok((atoms::ok(), ResourceArc::new(out)).encode(env))
}

/// Element-wise count of set bits; the result keeps the input's shape and
/// dtype.
///
/// Integer dtypes count the bits of the stored value. Float dtypes are cast to
/// `u64` first.
#[rustler::nif(schedule = "DirtyCpu")]
fn population_count<'a>(env: Env<'a>, a: ResourceArc<BufResource>) -> NifResult<Term<'a>> {
    // Views the buffer as `[T]`, maps every element, returns raw bytes.
    fn per_element<T: bytemuck::Pod>(raw: &[u8], f: impl Fn(T) -> T) -> Vec<u8> {
        let typed: &[T] = bytemuck::cast_slice(raw);
        let out: Vec<T> = typed.iter().map(|&v| f(v)).collect();
        bytemuck::cast_slice::<T, u8>(&out).to_vec()
    }

    let data = match a.dtype {
        DType::S32 => per_element(&a.data, |v: i32| v.count_ones() as i32),
        DType::S64 => per_element(&a.data, |v: i64| v.count_ones() as i64),
        DType::U32 => per_element(&a.data, |v: u32| v.count_ones()),
        DType::U8 => per_element(&a.data, |v: u8| v.count_ones() as u8),
        DType::F32 | DType::F64 => {
            let counts: Vec<f64> = to_f64(&a.data, a.dtype)
                .into_iter()
                .map(|v| f64::from((v as u64).count_ones()))
                .collect();
            from_f64(counts, a.dtype)
        }
    };
    let out = BufResource {
        data,
        shape: a.shape.clone(),
        dtype: a.dtype,
    };
    Ok((atoms::ok(), ResourceArc::new(out)).encode(env))
}

/// Real part. None of the supported dtypes is complex, so this returns an
/// unchanged copy of the input.
#[rustler::nif(schedule = "DirtyCpu")]
fn real<'a>(env: Env<'a>, a: ResourceArc<BufResource>) -> NifResult<Term<'a>> {
    let copy = BufResource {
        data: a.data.to_vec(),
        shape: a.shape.to_vec(),
        dtype: a.dtype,
    };
    Ok((atoms::ok(), ResourceArc::new(copy)).encode(env))
}

/// Imaginary part. None of the supported dtypes is complex, so this returns an
/// all-zero tensor with the input's shape and dtype.
#[rustler::nif(schedule = "DirtyCpu")]
fn imag<'a>(env: Env<'a>, a: ResourceArc<BufResource>) -> NifResult<Term<'a>> {
    let byte_len = a.num_elements() * a.dtype.size_in_bytes();
    let zeros = BufResource {
        data: vec![0u8; byte_len],
        shape: a.shape.clone(),
        dtype: a.dtype,
    };
    Ok((atoms::ok(), ResourceArc::new(zeros)).encode(env))
}

/// Element-wise NaN test (after conversion to `f64`); returns a `U8` tensor
/// (1 = NaN, 0 = not NaN) with the input's shape.
#[rustler::nif(schedule = "DirtyCpu")]
fn is_nan<'a>(env: Env<'a>, a: ResourceArc<BufResource>) -> NifResult<Term<'a>> {
    let flags = unary_op_bool(&a.data, a.dtype, f64::is_nan);
    let out = BufResource {
        data: flags,
        shape: a.shape.clone(),
        dtype: DType::U8,
    };
    Ok((atoms::ok(), ResourceArc::new(out)).encode(env))
}

/// Element-wise infinity test (after conversion to `f64`); returns a `U8`
/// tensor (1 = ±infinity, 0 = otherwise) with the input's shape.
#[rustler::nif(schedule = "DirtyCpu")]
fn is_infinity<'a>(env: Env<'a>, a: ResourceArc<BufResource>) -> NifResult<Term<'a>> {
    let flags = unary_op_bool(&a.data, a.dtype, f64::is_infinite);
    let out = BufResource {
        data: flags,
        shape: a.shape.clone(),
        dtype: DType::U8,
    };
    Ok((atoms::ok(), ResourceArc::new(out)).encode(env))
}

// ═══════════════════════════════════════════════════════════════
// Shape operations
// ═══════════════════════════════════════════════════════════════

#[rustler::nif(schedule = "DirtyCpu")]
fn reshape_tensor<'a>(
    env: Env<'a>,
    buf: ResourceArc<BufResource>,
    new_shape: Vec<usize>,
) -> NifResult<Term<'a>> {
    let old_n: usize = buf.shape.iter().product();
    let new_n: usize = new_shape.iter().product();
    if old_n != new_n {
        return Err(Error::RaiseTerm(Box::new(format!(
            "reshape: {} vs {} elements",
            old_n, new_n
        ))));
    }
    Ok((
        atoms::ok(),
        ResourceArc::new(BufResource {
            data: buf.data.clone(),
            shape: new_shape,
            dtype: buf.dtype,
        }),
    )
        .encode(env))
}

/// Removes the listed size-1 dimensions from the tensor's shape.
///
/// `axes` may contain negative indices, counted from the end. Every axis is
/// normalised and validated against the original rank before any dimension is
/// removed. Duplicates are then dropped and the rest are removed from the
/// highest index down so earlier removals do not shift later ones. Sorting the
/// raw values, as before, mis-ordered mixed negative and positive axes and
/// could index out of bounds. Axes whose extent is not 1 are left in place. If
/// every dimension is removed the shape becomes `[1]`.
///
/// Raises `"squeeze: axis out of range"` for an axis outside the rank.
#[rustler::nif(schedule = "DirtyCpu")]
fn squeeze_tensor<'a>(
    env: Env<'a>,
    buf: ResourceArc<BufResource>,
    axes: Vec<i64>,
) -> NifResult<Term<'a>> {
    let rank = buf.shape.len() as i64;

    let mut targets: Vec<usize> = Vec::with_capacity(axes.len());
    for ax in axes {
        let idx = if ax < 0 { rank + ax } else { ax };
        if idx < 0 || idx >= rank {
            return Err(Error::RaiseTerm(Box::new("squeeze: axis out of range")));
        }
        targets.push(idx as usize);
    }
    // Highest index first, each index once.
    targets.sort_unstable_by(|l, r| r.cmp(l));
    targets.dedup();

    let mut ns = buf.shape.clone();
    for idx in targets {
        if ns[idx] == 1 {
            ns.remove(idx);
        }
    }
    if ns.is_empty() {
        ns.push(1);
    }

    let out = BufResource {
        data: buf.data.clone(),
        shape: ns,
        dtype: buf.dtype,
    };
    Ok((atoms::ok(), ResourceArc::new(out)).encode(env))
}

/// Expands `buf` to `target_shape`, where `axes[d]` names the output
/// dimension that input dimension `d` is laid along.
///
/// An input dimension follows its output coordinate only when its size equals
/// the size of the mapped output dimension; otherwise (including an `axes`
/// entry that is not a valid output dimension) index 0 is used for it.
///
/// Raises `"broadcast: axes length mismatch"` when `axes` does not have one
/// entry per input dimension. Values travel through `to_f64` / `from_f64`.
#[rustler::nif(schedule = "DirtyCpu")]
fn broadcast_tensor<'a>(
    env: Env<'a>,
    buf: ResourceArc<BufResource>,
    target_shape: Vec<usize>,
    axes: Vec<usize>,
) -> NifResult<Term<'a>> {
    let in_rank = buf.shape.len();
    if axes.len() != in_rank {
        return Err(Error::RaiseTerm(Box::new(
            "broadcast: axes length mismatch",
        )));
    }
    let out_rank = target_shape.len();
    let total_out = target_shape.iter().product::<usize>();
    let out_strides = strides_for(&target_shape);
    let in_strides = strides_for(&buf.shape);
    let source = to_f64(&buf.data, buf.dtype);

    // (output axis, input stride) for every input dimension that tracks an
    // output coordinate; all other input dimensions contribute offset 0.
    let followed: Vec<(usize, usize)> = axes
        .iter()
        .enumerate()
        .filter(|&(di, &ax)| ax < out_rank && buf.shape[di] == target_shape[ax])
        .map(|(di, &ax)| (ax, in_strides[di]))
        .collect();

    let mut coords = vec![0usize; out_rank];
    let out_vals: Vec<f64> = (0..total_out)
        .map(|flat_out| {
            // Unravel the flat output index into per-dimension coordinates.
            let mut rest = flat_out;
            for (coord, &stride) in coords.iter_mut().zip(out_strides.iter()) {
                *coord = rest / stride;
                rest %= stride;
            }
            let flat_in: usize = followed
                .iter()
                .map(|&(ax, stride)| coords[ax] * stride)
                .sum();
            source[flat_in]
        })
        .collect();

    let broadcasted = BufResource {
        data: from_f64(out_vals, buf.dtype),
        shape: target_shape,
        dtype: buf.dtype,
    };
    Ok((atoms::ok(), ResourceArc::new(broadcasted)).encode(env))
}

/// Permutes the dimensions of `buf`: output dimension `d` is input dimension
/// `axes[d]`.
///
/// Raises `"transpose: axes length mismatch"` when `axes` does not have one
/// entry per dimension, and `"transpose: axes must be a permutation of the
/// dimensions"` when an entry is out of range or repeated (previously the
/// former panicked on indexing and the latter silently produced wrong data).
///
/// Values travel through `to_f64` / `from_f64`.
#[rustler::nif(schedule = "DirtyCpu")]
fn transpose_tensor<'a>(
    env: Env<'a>,
    buf: ResourceArc<BufResource>,
    axes: Vec<usize>,
) -> NifResult<Term<'a>> {
    let rank = buf.shape.len();
    if axes.len() != rank {
        return Err(Error::RaiseTerm(Box::new(
            "transpose: axes length mismatch",
        )));
    }
    let mut seen = vec![false; rank];
    for &ax in &axes {
        if ax >= rank || seen[ax] {
            return Err(Error::RaiseTerm(Box::new(
                "transpose: axes must be a permutation of the dimensions",
            )));
        }
        seen[ax] = true;
    }

    let new_shape: Vec<usize> = axes.iter().map(|&a| buf.shape[a]).collect();
    let total: usize = buf.shape.iter().product();
    let in_strides = strides_for(&buf.shape);
    let out_strides = strides_for(&new_shape);
    // Input stride that a step along each *output* dimension corresponds to.
    let permuted_strides: Vec<usize> = axes.iter().map(|&a| in_strides[a]).collect();
    let in_f64 = to_f64(&buf.data, buf.dtype);

    let mut out_vals = Vec::with_capacity(total);
    for i in 0..total {
        let mut rem = i;
        let mut flat_in = 0usize;
        for d in 0..rank {
            flat_in += (rem / out_strides[d]) * permuted_strides[d];
            rem %= out_strides[d];
        }
        out_vals.push(in_f64[flat_in]);
    }
    Ok((
        atoms::ok(),
        ResourceArc::new(BufResource {
            data: from_f64(out_vals, buf.dtype),
            shape: new_shape,
            dtype: buf.dtype,
        }),
    )
        .encode(env))
}

/// Pads `buf` with the first element of `pad_ref` (0.0 when `pad_ref` is
/// empty).
///
/// * `padding_config` – one `(low, high, interior)` triple per dimension:
///   `low`/`high` elements are added before/after the data and `interior`
///   elements between neighbouring data elements.
///
/// Negative entries are treated as 0 here.
/// NOTE(review): cropping via negative padding is therefore not performed —
/// confirm whether the caller handles it.
///
/// Raises `"pad: config length mismatch"` when `padding_config` does not have
/// one entry per dimension. Values travel through `to_f64` / `from_f64`.
#[rustler::nif(schedule = "DirtyCpu")]
fn pad_tensor<'a>(
    env: Env<'a>,
    buf: ResourceArc<BufResource>,
    pad_ref: ResourceArc<BufResource>,
    padding_config: Vec<(i64, i64, i64)>,
) -> NifResult<Term<'a>> {
    let rank = buf.shape.len();
    if padding_config.len() != rank {
        return Err(Error::RaiseTerm(Box::new("pad: config length mismatch")));
    }
    let non_negative = |v: i64| if v < 0 { 0usize } else { v as usize };

    let pv = to_f64(&pad_ref.data, pad_ref.dtype);
    let pad_value = if pv.is_empty() { 0.0 } else { pv[0] };

    // Per dimension: leading pad and distance between consecutive data
    // elements in the output (1 + interior padding).
    let lows: Vec<usize> = padding_config
        .iter()
        .map(|&(lo, _, _)| non_negative(lo))
        .collect();
    let steps: Vec<usize> = padding_config
        .iter()
        .map(|&(_, _, interior)| non_negative(interior) + 1)
        .collect();
    let new_shape: Vec<usize> = buf
        .shape
        .iter()
        .zip(padding_config.iter())
        .map(|(&s, &(lo, hi, interior))| {
            s + non_negative(lo) + non_negative(hi) + s.saturating_sub(1) * non_negative(interior)
        })
        .collect();

    let total_out: usize = new_shape.iter().product();
    let out_strides = strides_for(&new_shape);
    let in_strides = strides_for(&buf.shape);
    let in_f64 = to_f64(&buf.data, buf.dtype);
    let mut out_vals = vec![pad_value; total_out];
    let total_in: usize = buf.shape.iter().product();
    for i in 0..total_in {
        let mut rem = i;
        let mut flat_out = 0usize;
        for d in 0..rank {
            let coord = rem / in_strides[d];
            rem %= in_strides[d];
            // Element `coord` of dimension `d` lands at low + coord * step.
            // The previous code carried the interior offset over from the
            // preceding dimensions instead of using this dimension's own.
            flat_out += (lows[d] + coord * steps[d]) * out_strides[d];
        }
        out_vals[flat_out] = in_f64[i];
    }
    Ok((
        atoms::ok(),
        ResourceArc::new(BufResource {
            data: from_f64(out_vals, buf.dtype),
            shape: new_shape,
            dtype: buf.dtype,
        }),
    )
        .encode(env))
}

/// Mirrors `buf` along every dimension listed in `axes`; the shape is
/// unchanged. Entries of `axes` that do not name a dimension are ignored.
///
/// Values travel through `to_f64` / `from_f64`.
#[rustler::nif(schedule = "DirtyCpu")]
fn reverse_tensor<'a>(
    env: Env<'a>,
    buf: ResourceArc<BufResource>,
    axes: Vec<usize>,
) -> NifResult<Term<'a>> {
    let rank = buf.shape.len();
    let total = buf.shape.iter().product::<usize>();
    let in_strides = strides_for(&buf.shape);
    let source = to_f64(&buf.data, buf.dtype);

    // One flag per dimension: is this dimension mirrored?
    let mut flipped = vec![false; rank];
    for &ax in axes.iter().filter(|&&ax| ax < rank) {
        flipped[ax] = true;
    }

    let out_vals: Vec<f64> = (0..total)
        .map(|flat_out| {
            let mut rest = flat_out;
            let mut flat_in = 0usize;
            for d in 0..rank {
                let coord = rest / in_strides[d];
                rest %= in_strides[d];
                let src = if flipped[d] {
                    buf.shape[d] - 1 - coord
                } else {
                    coord
                };
                flat_in += src * in_strides[d];
            }
            source[flat_in]
        })
        .collect();

    let reversed = BufResource {
        data: from_f64(out_vals, buf.dtype),
        shape: buf.shape.clone(),
        dtype: buf.dtype,
    };
    Ok((atoms::ok(), ResourceArc::new(reversed)).encode(env))
}

#[rustler::nif(schedule = "DirtyCpu")]
fn slice_tensor<'a>(
    env: Env<'a>,
    buf: ResourceArc<BufResource>,
    starts: Vec<usize>,
    lengths: Vec<usize>,
    strides: Vec<usize>,
) -> NifResult<Term<'a>> {
    let rank = buf.shape.len();
    if starts.len() != rank || lengths.len() != rank || strides.len() != rank {
        return Err(Error::RaiseTerm(Box::new(
            "slice: parameter length mismatch",
        )));
    }
    let out_shape = lengths.clone();
    let total_out: usize = out_shape.iter().product();
    let out_strides = strides_for(&out_shape);
    let in_strides = strides_for(&buf.shape);
    let in_f64 = to_f64(&buf.data, buf.dtype);
    let mut out_vals = vec![0.0f64; total_out];
    for i in 0..total_out {
        let mut out_coords = vec![0usize; rank];
        let mut rem = i;
        for d in 0..rank {
            out_coords[d] = rem / out_strides[d];
            rem %= out_strides[d];
        }
        let in_coords: Vec<usize> = (0..rank)
            .map(|d| starts[d] + out_coords[d] * strides[d])
            .collect();
        let flat_in: usize = in_coords
            .iter()
            .zip(in_strides.iter())
            .map(|(c, s)| c * s)
            .sum();
        out_vals[i] = in_f64[flat_in];
    }
    Ok((
        atoms::ok(),
        ResourceArc::new(BufResource {
            data: from_f64(out_vals, buf.dtype),
            shape: out_shape,
            dtype: buf.dtype,
        }),
    )
        .encode(env))
}

#[rustler::nif(schedule = "DirtyCpu")]
fn concatenate_tensors<'a>(
    env: Env<'a>,
    bufs: Vec<ResourceArc<BufResource>>,
    axis: usize,
) -> NifResult<Term<'a>> {
    if bufs.is_empty() {
        return Err(Error::RaiseTerm(Box::new("concat: empty")));
    }
    let dtype = bufs[0].dtype;
    let rank = bufs[0].shape.len();
    for b in &bufs {
        if b.shape.len() != rank {
            return Err(Error::RaiseTerm(Box::new("concat: rank mismatch")));
        }
    }
    let out_axis: usize = bufs.iter().map(|b| b.shape[axis]).sum();
    let mut out_shape = bufs[0].shape.clone();
    out_shape[axis] = out_axis;
    let total_out: usize = out_shape.iter().product();
    let out_strides = strides_for(&out_shape);
    let mut out_vals = vec![0.0f64; total_out];
    let mut offset = 0usize;
    for buf in &bufs {
        let in_strides = strides_for(&buf.shape);
        let in_f64 = to_f64(&buf.data, buf.dtype);
        let total_in: usize = buf.shape.iter().product();
        for i in 0..total_in {
            let mut in_coords = vec![0usize; rank];
            let mut rem = i;
            for d in 0..rank {
                in_coords[d] = rem / in_strides[d];
                rem %= in_strides[d];
            }
            let mut out_coords = in_coords.clone();
            out_coords[axis] += offset;
            let flat_out: usize = out_coords
                .iter()
                .zip(out_strides.iter())
                .map(|(c, s)| c * s)
                .sum();
            out_vals[flat_out] = in_f64[i];
        }
        offset += buf.shape[axis];
    }
    Ok((
        atoms::ok(),
        ResourceArc::new(BufResource {
            data: from_f64(out_vals, dtype),
            shape: out_shape,
            dtype,
        }),
    )
        .encode(env))
}

#[rustler::nif(schedule = "DirtyCpu")]
fn stack_tensors<'a>(
    env: Env<'a>,
    bufs: Vec<ResourceArc<BufResource>>,
    axis: usize,
) -> NifResult<Term<'a>> {
    if bufs.is_empty() {
        return Err(Error::RaiseTerm(Box::new("stack: empty")));
    }
    let dtype = bufs[0].dtype;
    let rank = bufs[0].shape.len();
    for b in &bufs {
        if b.shape != bufs[0].shape {
            return Err(Error::RaiseTerm(Box::new("stack: shape mismatch")));
        }
    }
    let mut out_shape = bufs[0].shape.clone();
    out_shape.insert(axis, bufs.len());
    let out_rank = out_shape.len();
    let total_out: usize = out_shape.iter().product();
    let out_strides = strides_for(&out_shape);
    let mut out_vals = vec![0.0f64; total_out];
    for (si, buf) in bufs.iter().enumerate() {
        let in_strides = strides_for(&buf.shape);
        let in_f64 = to_f64(&buf.data, buf.dtype);
        let total_in: usize = buf.shape.iter().product();
        for i in 0..total_in {
            let mut in_coords = vec![0usize; rank];
            let mut rem = i;
            for d in 0..rank {
                in_coords[d] = rem / in_strides[d];
                rem %= in_strides[d];
            }
            let mut out_coords = vec![0usize; out_rank];
            let mut j = 0;
            for d in 0..out_rank {
                if d == axis {
                    out_coords[d] = si;
                } else {
                    out_coords[d] = in_coords[j];
                    j += 1;
                }
            }
            let flat_out: usize = out_coords
                .iter()
                .zip(out_strides.iter())
                .map(|(c, s)| c * s)
                .sum();
            out_vals[flat_out] = in_f64[i];
        }
    }
    Ok((
        atoms::ok(),
        ResourceArc::new(BufResource {
            data: from_f64(out_vals, dtype),
            shape: out_shape,
            dtype,
        }),
    )
        .encode(env))
}

/// Computes the broadcast shape of three shapes, aligning them at their
/// trailing dimensions and treating missing leading dimensions as size 1.
///
/// For each aligned position the result is 0 if any of the three sizes is 0
/// (an empty dimension stays empty when broadcast against 1) and otherwise
/// the largest of the three sizes. No compatibility check is performed.
fn broadcast_shape_simple(a: &[usize], b: &[usize], c: &[usize]) -> Vec<usize> {
    let max_len = a.len().max(b.len()).max(c.len());
    // Size of `shape` counted `i` places from the end, or 1 when absent.
    let dim_from_end = |shape: &[usize], i: usize| {
        if i < shape.len() {
            shape[shape.len() - 1 - i]
        } else {
            1
        }
    };
    let mut result: Vec<usize> = (0..max_len)
        .map(|i| {
            let dims = [dim_from_end(a, i), dim_from_end(b, i), dim_from_end(c, i)];
            if dims.contains(&0) {
                0
            } else {
                dims[0].max(dims[1]).max(dims[2])
            }
        })
        .collect();
    // Built from the trailing dimension outward, so flip into natural order.
    result.reverse();
    result
}

#[rustler::nif(schedule = "DirtyCpu")]
fn select_tensor<'a>(
    env: Env<'a>,
    pred: ResourceArc<BufResource>,
    on_true: ResourceArc<BufResource>,
    on_false: ResourceArc<BufResource>,
) -> NifResult<Term<'a>> {
    let out_shape = broadcast_shape_simple(&pred.shape, &on_true.shape, &on_false.shape);
    let out_dtype = on_true.dtype;
    let total_out: usize = out_shape.iter().product();
    let out_strides = strides_for(&out_shape);
    let rank = out_shape.len();
    let pred_f = to_f64(&pred.data, pred.dtype);
    let true_f = to_f64(&on_true.data, on_true.dtype);
    let false_f = to_f64(&on_false.data, on_false.dtype);
    let pred_s = strides_for(&pred.shape);
    let true_s = strides_for(&on_true.shape);
    let false_s = strides_for(&on_false.shape);
    let mut out_vals = vec![0.0f64; total_out];
    for i in 0..total_out {
        let mut coords = vec![0usize; rank];
        let mut rem = i;
        for d in 0..rank {
            coords[d] = rem / out_strides[d];
            rem %= out_strides[d];
        }
        let pi = if pred.shape.len() == 1 && pred.shape[0] == 1 {
            0
        } else {
            coords
                .iter()
                .zip(pred_s.iter())
                .enumerate()
                .map(|(d, (c, s))| {
                    (if d < pred.shape.len() && pred.shape[d] == 1 {
                        0
                    } else {
                        *c
                    }) * s
                })
                .sum::<usize>()
        };
        let ti = if on_true.shape.len() == 1 && on_true.shape[0] == 1 {
            0
        } else {
            coords
                .iter()
                .zip(true_s.iter())
                .enumerate()
                .map(|(d, (c, s))| {
                    (if d < on_true.shape.len() && on_true.shape[d] == 1 {
                        0
                    } else {
                        *c
                    }) * s
                })
                .sum::<usize>()
        };
        let fi = if on_false.shape.len() == 1 && on_false.shape[0] == 1 {
            0
        } else {
            coords
                .iter()
                .zip(false_s.iter())
                .enumerate()
                .map(|(d, (c, s))| {
                    (if d < on_false.shape.len() && on_false.shape[d] == 1 {
                        0
                    } else {
                        *c
                    }) * s
                })
                .sum::<usize>()
        };
        out_vals[i] = if pred_f[pi] != 0.0 {
            true_f[ti]
        } else {
            false_f[fi]
        };
    }
    Ok((
        atoms::ok(),
        ResourceArc::new(BufResource {
            data: from_f64(out_vals, out_dtype),
            shape: out_shape,
            dtype: out_dtype,
        }),
    )
        .encode(env))
}

// ═══════════════════════════════════════════════════════════════
// Reductions
// ═══════════════════════════════════════════════════════════════

/// Folds `data` (interpreted as `dtype` with the given `shape`) along `axes`.
///
/// * `axes` – dimensions to reduce; entries that do not name a dimension have
///   no effect. With no axes nothing is reduced and every element is simply
///   combined once with `init`.
/// * `keep_dims` – keep reduced dimensions in the result with size 1 instead
///   of dropping them.
/// * `init` – starting accumulator for every output element.
/// * `op` – combines the accumulator with the next input value.
///
/// Returns the encoded result bytes and the result shape. A rank-0 input is
/// returned unchanged with an empty shape; if all dimensions are dropped the
/// result shape is `[1]`. For an input with a zero-sized dimension the result
/// shape is derived like any other and every output element equals `init`
/// (previously such inputs always produced a single `init` of shape `[1]`,
/// whichever axes were reduced).
///
/// Values travel through `to_f64` / `from_f64`.
fn reduce_along_axes(
    data: &[u8],
    dtype: DType,
    shape: &[usize],
    axes: &[usize],
    keep_dims: bool,
    init: f64,
    op: impl Fn(f64, f64) -> f64,
) -> (Vec<u8>, Vec<usize>) {
    if shape.is_empty() {
        return (data.to_vec(), vec![]);
    }
    let rank = shape.len();
    // Resolve axis membership once instead of scanning `axes` per element.
    let reduced: Vec<bool> = (0..rank).map(|d| axes.contains(&d)).collect();

    let mut out_shape = Vec::with_capacity(rank);
    for (&dim, &is_reduced) in shape.iter().zip(reduced.iter()) {
        if !is_reduced {
            out_shape.push(dim);
        } else if keep_dims {
            out_shape.push(1);
        }
    }
    if out_shape.is_empty() {
        out_shape = vec![1];
    }
    let out_strides = strides_for(&out_shape);

    // Output stride each input dimension contributes; reduced dimensions
    // contribute nothing, which collapses them onto one output element.
    let mut out_stride_for = vec![0usize; rank];
    let mut next_out = 0usize;
    for d in 0..rank {
        if !reduced[d] {
            out_stride_for[d] = out_strides[next_out];
            next_out += 1;
        } else if keep_dims {
            next_out += 1;
        }
    }

    let v = to_f64(data, dtype);
    let in_strides = strides_for(shape);
    let total_in: usize = shape.iter().product();
    let out_total: usize = out_shape.iter().product();
    let mut result = vec![init; out_total];
    for idx in 0..total_in {
        let mut remainder = idx;
        let mut out_idx = 0usize;
        for d in 0..rank {
            out_idx += (remainder / in_strides[d]) * out_stride_for[d];
            remainder %= in_strides[d];
        }
        if out_idx < result.len() {
            result[out_idx] = op(result[out_idx], v[idx]);
        }
    }
    (from_f64(result, dtype), out_shape)
}

fn decode_reduction_opts(opts: Term) -> NifResult<(Vec<usize>, bool)> {
    let mut axes = Vec::new();
    let mut keep_dims = false;
    if let Ok(list) = opts.decode::<Vec<Term>>() {
        for item in list {
            if let Ok((ref key, val)) = item.decode::<(String, Term)>() {
                match key.as_str() {
                    "axes" => {
                        axes = val.decode::<Vec<usize>>().unwrap_or_default();
                    }
                    "keep_axes" | "keep_dims" => {
                        keep_dims = val.decode::<bool>().unwrap_or(false);
                    }
                    _ => {}
                }
            }
        }
    }
    Ok((axes, keep_dims))
}

fn decode_axis_opts(opts: Term) -> NifResult<(usize, bool)> {
    let mut axis = 0usize;
    let mut keep_dims = false;
    if let Ok(list) = opts.decode::<Vec<Term>>() {
        for item in list {
            if let Ok((ref key, val)) = item.decode::<(String, Term)>() {
                match key.as_str() {
                    "axis" => {
                        axis = val.decode::<usize>().unwrap_or(0);
                    }
                    "keep_axes" | "keep_dims" => {
                        keep_dims = val.decode::<bool>().unwrap_or(false);
                    }
                    _ => {}
                }
            }
        }
    }
    Ok((axis, keep_dims))
}

fn decode_window_opts(opts: Term) -> NifResult<(Vec<usize>, Vec<usize>, bool)> {
    let mut shape = Vec::new();
    let mut axes = Vec::new();
    let mut keep_dims = false;
    if let Ok(list) = opts.decode::<Vec<Term>>() {
        for item in list {
            if let Ok((ref key, val)) = item.decode::<(String, Term)>() {
                match key.as_str() {
                    "shape" => {
                        shape = val.decode::<Vec<usize>>().unwrap_or_default();
                    }
                    "axes" => {
                        axes = val.decode::<Vec<usize>>().unwrap_or_default();
                    }
                    "keep_axes" | "keep_dims" => {
                        keep_dims = val.decode::<bool>().unwrap_or(false);
                    }
                    _ => {}
                }
            }
        }
    }
    Ok((shape, axes, keep_dims))
}

/// Sums `buf` along the axes given in `opts` (see `decode_reduction_opts`),
/// starting each accumulator at 0.0. The result keeps `buf`'s dtype.
#[rustler::nif(schedule = "DirtyCpu")]
fn sum_tensor<'a>(env: Env<'a>, buf: ResourceArc<BufResource>, opts: Term) -> NifResult<Term<'a>> {
    let (axes, keep) = decode_reduction_opts(opts)?;
    let add = |acc: f64, x: f64| acc + x;
    let (data, shape) =
        reduce_along_axes(&buf.data, buf.dtype, &buf.shape, &axes, keep, 0.0, add);
    let summed = BufResource {
        data,
        shape,
        dtype: buf.dtype,
    };
    Ok((atoms::ok(), ResourceArc::new(summed)).encode(env))
}

/// Multiplies the elements of `buf` along the axes given in `opts` (see
/// `decode_reduction_opts`), starting each accumulator at 1.0. The result
/// keeps `buf`'s dtype.
#[rustler::nif(schedule = "DirtyCpu")]
fn product_tensor<'a>(
    env: Env<'a>,
    buf: ResourceArc<BufResource>,
    opts: Term,
) -> NifResult<Term<'a>> {
    let (axes, keep) = decode_reduction_opts(opts)?;
    let multiply = |acc: f64, x: f64| acc * x;
    let (data, shape) =
        reduce_along_axes(&buf.data, buf.dtype, &buf.shape, &axes, keep, 1.0, multiply);
    let product = BufResource {
        data,
        shape,
        dtype: buf.dtype,
    };
    Ok((atoms::ok(), ResourceArc::new(product)).encode(env))
}

/// Takes the maximum of `buf` along the axes given in `opts` (see
/// `decode_reduction_opts`). The result keeps `buf`'s dtype.
///
/// The accumulator starts at negative infinity — the identity of `max` — so
/// that inputs consisting only of `-inf` reduce to `-inf`. The previous seed,
/// `f64::MIN`, is the most negative *finite* value and would win against
/// `-inf`.
#[rustler::nif(schedule = "DirtyCpu")]
fn reduce_max<'a>(env: Env<'a>, buf: ResourceArc<BufResource>, opts: Term) -> NifResult<Term<'a>> {
    let (axes, kd) = decode_reduction_opts(opts)?;
    let (data, shape) = reduce_along_axes(
        &buf.data,
        buf.dtype,
        &buf.shape,
        &axes,
        kd,
        f64::NEG_INFINITY,
        |a, b| a.max(b),
    );
    Ok((
        atoms::ok(),
        ResourceArc::new(BufResource {
            data,
            shape,
            dtype: buf.dtype,
        }),
    )
        .encode(env))
}

/// Takes the minimum of `buf` along the axes given in `opts` (see
/// `decode_reduction_opts`). The result keeps `buf`'s dtype.
///
/// The accumulator starts at positive infinity — the identity of `min` — so
/// that inputs consisting only of `+inf` reduce to `+inf`. The previous seed,
/// `f64::MAX`, is the largest *finite* value and would win against `+inf`.
#[rustler::nif(schedule = "DirtyCpu")]
fn reduce_min<'a>(env: Env<'a>, buf: ResourceArc<BufResource>, opts: Term) -> NifResult<Term<'a>> {
    let (axes, kd) = decode_reduction_opts(opts)?;
    let (data, shape) = reduce_along_axes(
        &buf.data,
        buf.dtype,
        &buf.shape,
        &axes,
        kd,
        f64::INFINITY,
        |a, b| a.min(b),
    );
    Ok((
        atoms::ok(),
        ResourceArc::new(BufResource {
            data,
            shape,
            dtype: buf.dtype,
        }),
    )
        .encode(env))
}

/// Logical AND along the axes given in `opts` (see `decode_reduction_opts`):
/// an output element is 1.0 when every contributing input is non-zero and 0.0
/// otherwise. The result keeps `buf`'s dtype.
#[rustler::nif(schedule = "DirtyCpu")]
fn all_tensor<'a>(env: Env<'a>, buf: ResourceArc<BufResource>, opts: Term) -> NifResult<Term<'a>> {
    let (axes, keep) = decode_reduction_opts(opts)?;
    let both_truthy = |acc: f64, x: f64| -> f64 {
        if acc == 0.0 || x == 0.0 {
            0.0
        } else {
            1.0
        }
    };
    let (data, shape) =
        reduce_along_axes(&buf.data, buf.dtype, &buf.shape, &axes, keep, 1.0, both_truthy);
    let reduced = BufResource {
        data,
        shape,
        dtype: buf.dtype,
    };
    Ok((atoms::ok(), ResourceArc::new(reduced)).encode(env))
}

/// Logical OR along the axes given in `opts` (see `decode_reduction_opts`):
/// an output element is 1.0 when at least one contributing input is non-zero
/// and 0.0 otherwise. The result keeps `buf`'s dtype.
#[rustler::nif(schedule = "DirtyCpu")]
fn any_tensor<'a>(env: Env<'a>, buf: ResourceArc<BufResource>, opts: Term) -> NifResult<Term<'a>> {
    let (axes, keep) = decode_reduction_opts(opts)?;
    let either_truthy = |acc: f64, x: f64| -> f64 {
        if acc == 0.0 && x == 0.0 {
            0.0
        } else {
            1.0
        }
    };
    let (data, shape) =
        reduce_along_axes(&buf.data, buf.dtype, &buf.shape, &axes, keep, 0.0, either_truthy);
    let reduced = BufResource {
        data,
        shape,
        dtype: buf.dtype,
    };
    Ok((atoms::ok(), ResourceArc::new(reduced)).encode(env))
}

/// Index of the largest element along one axis of `buf`.
///
/// `opts` is read with `decode_axis_opts`. The reduced axis is dropped from
/// the result shape, or kept with size 1 when keep-dims is set; if no
/// dimension remains the shape is `[1]`. The result dtype is `S32`, stored in
/// native byte order.
///
/// On ties the first index wins. The running best starts at negative
/// infinity, so a finite value always beats `-inf` (with the former
/// `f64::MIN` seed, `f64::MIN` itself could never be selected over an earlier
/// `-inf`).
///
/// Raises `"argmax: axis out of range"` when the axis is not a dimension of
/// `buf`. Values are compared after conversion through `to_f64`.
#[rustler::nif(schedule = "DirtyCpu")]
fn argmax_tensor<'a>(
    env: Env<'a>,
    buf: ResourceArc<BufResource>,
    opts: Term,
) -> NifResult<Term<'a>> {
    let (axis, kd) = decode_axis_opts(opts)?;
    let shape = &buf.shape;
    let rank = shape.len();
    if axis >= rank {
        return Err(Error::RaiseTerm(Box::new("argmax: axis out of range")));
    }
    let axis_size = shape[axis];
    let in_strides = strides_for(shape);
    let axis_stride = in_strides[axis];

    // Output shape, and for every output dimension the input stride it
    // advances (0 for the kept size-1 axis, whose coordinate is always 0).
    let mut out_shape = Vec::with_capacity(rank);
    let mut kept_strides = Vec::with_capacity(rank);
    for d in 0..rank {
        if d != axis {
            out_shape.push(shape[d]);
            kept_strides.push(in_strides[d]);
        } else if kd {
            out_shape.push(1);
            kept_strides.push(0);
        }
    }
    if out_shape.is_empty() {
        out_shape.push(1);
        kept_strides.push(0);
    }

    let total_out: usize = out_shape.iter().product();
    let out_strides = strides_for(&out_shape);
    let in_f64 = to_f64(&buf.data, buf.dtype);
    // Append native-endian bytes directly instead of reinterpreting a byte
    // buffer as `[i32]`, which depends on the allocation being 4-byte aligned.
    let mut out_data: Vec<u8> = Vec::with_capacity(total_out * std::mem::size_of::<i32>());
    for i in 0..total_out {
        let mut rem = i;
        let mut base = 0usize;
        for (&out_stride, &in_stride) in out_strides.iter().zip(kept_strides.iter()) {
            base += (rem / out_stride) * in_stride;
            rem %= out_stride;
        }
        let mut best_val = f64::NEG_INFINITY;
        let mut best_idx: i32 = 0;
        for j in 0..axis_size {
            let val = in_f64[base + j * axis_stride];
            if val > best_val {
                best_val = val;
                best_idx = j as i32;
            }
        }
        out_data.extend_from_slice(&best_idx.to_ne_bytes());
    }
    Ok((
        atoms::ok(),
        ResourceArc::new(BufResource {
            data: out_data,
            shape: out_shape,
            dtype: DType::S32,
        }),
    )
        .encode(env))
}

/// Index of the smallest element along one axis of `buf`.
///
/// `opts` is read with `decode_axis_opts`. The reduced axis is dropped from
/// the result shape, or kept with size 1 when keep-dims is set; if no
/// dimension remains the shape is `[1]`. The result dtype is `S32`, stored in
/// native byte order.
///
/// On ties the first index wins. The running best starts at positive
/// infinity, so a finite value always beats `+inf` (with the former
/// `f64::MAX` seed, `f64::MAX` itself could never be selected over an earlier
/// `+inf`).
///
/// Raises `"argmin: axis out of range"` when the axis is not a dimension of
/// `buf`. Values are compared after conversion through `to_f64`.
#[rustler::nif(schedule = "DirtyCpu")]
fn argmin_tensor<'a>(
    env: Env<'a>,
    buf: ResourceArc<BufResource>,
    opts: Term,
) -> NifResult<Term<'a>> {
    let (axis, kd) = decode_axis_opts(opts)?;
    let shape = &buf.shape;
    let rank = shape.len();
    if axis >= rank {
        return Err(Error::RaiseTerm(Box::new("argmin: axis out of range")));
    }
    let axis_size = shape[axis];
    let in_strides = strides_for(shape);
    let axis_stride = in_strides[axis];

    // Output shape, and for every output dimension the input stride it
    // advances (0 for the kept size-1 axis, whose coordinate is always 0).
    let mut out_shape = Vec::with_capacity(rank);
    let mut kept_strides = Vec::with_capacity(rank);
    for d in 0..rank {
        if d != axis {
            out_shape.push(shape[d]);
            kept_strides.push(in_strides[d]);
        } else if kd {
            out_shape.push(1);
            kept_strides.push(0);
        }
    }
    if out_shape.is_empty() {
        out_shape.push(1);
        kept_strides.push(0);
    }

    let total_out: usize = out_shape.iter().product();
    let out_strides = strides_for(&out_shape);
    let in_f64 = to_f64(&buf.data, buf.dtype);
    // Append native-endian bytes directly instead of reinterpreting a byte
    // buffer as `[i32]`, which depends on the allocation being 4-byte aligned.
    let mut out_data: Vec<u8> = Vec::with_capacity(total_out * std::mem::size_of::<i32>());
    for i in 0..total_out {
        let mut rem = i;
        let mut base = 0usize;
        for (&out_stride, &in_stride) in out_strides.iter().zip(kept_strides.iter()) {
            base += (rem / out_stride) * in_stride;
            rem %= out_stride;
        }
        let mut best_val = f64::INFINITY;
        let mut best_idx: i32 = 0;
        for j in 0..axis_size {
            let val = in_f64[base + j * axis_stride];
            if val < best_val {
                best_val = val;
                best_idx = j as i32;
            }
        }
        out_data.extend_from_slice(&best_idx.to_ne_bytes());
    }
    Ok((
        atoms::ok(),
        ResourceArc::new(BufResource {
            data: out_data,
            shape: out_shape,
            dtype: DType::S32,
        }),
    )
        .encode(env))
}

// ═══════════════════════════════════════════════════════════════
// Window operations
// ═══════════════════════════════════════════════════════════════

/// Sliding-window reduction with step 1 and no padding.
///
/// * `ws` – window size, indexed by input dimension.
/// * `axes` – dimensions the window slides over; other dimensions are passed
///   through element by element.
/// * `kd` – when `true` the result keeps every dimension, and a windowed
///   dimension of size `n` with window `w` has `n - w + 1` positions (0 when
///   the window is larger than the dimension or has size 0). When `false` the
///   windowed dimensions are dropped from the result and only the window
///   anchored at index 0 of those dimensions is reduced.
///   NOTE(review): the `false` case used to panic (it indexed the shorter
///   output coordinate list with input dimension numbers); confirm the
///   intended semantics with the caller.
/// * `init` / `op` – starting accumulator and combining function.
///
/// Returns the encoded result bytes and the result shape. Values travel
/// through `to_f64` / `from_f64`.
fn window_reduce(
    data: &[u8],
    dtype: DType,
    shape: &[usize],
    ws: &[usize],
    axes: &[usize],
    kd: bool,
    init: f64,
    op: impl Fn(f64, f64) -> f64,
) -> (Vec<u8>, Vec<usize>) {
    let rank = shape.len();
    let windowed: Vec<bool> = (0..rank).map(|d| axes.contains(&d)).collect();

    // Number of window positions along every input dimension.
    let positions: Vec<usize> = (0..rank)
        .map(|d| {
            if !windowed[d] {
                shape[d]
            } else if ws[d] == 0 {
                // Avoid the `ws[d] - 1` underflow; an empty window has no
                // valid position.
                0
            } else {
                shape[d].saturating_sub(ws[d] - 1)
            }
        })
        .collect();

    // Input dimensions that are present in the output, in order.
    let out_dims: Vec<usize> = (0..rank).filter(|&d| kd || !windowed[d]).collect();
    let out_shape: Vec<usize> = out_dims.iter().map(|&d| positions[d]).collect();

    let out_total: usize = out_shape.iter().product();
    let out_strides = strides_for(&out_shape);
    let in_strides = strides_for(shape);
    let in_f64 = to_f64(data, dtype);
    // Loop-invariant: number of elements inside one window.
    let n_window: usize = axes.iter().map(|&a| ws[a]).product();

    let mut out_vals = vec![init; out_total];
    let mut in_coords = vec![0usize; rank];
    for (i, slot) in out_vals.iter_mut().enumerate() {
        // Unravel the output index and place each coordinate on the input
        // dimension it belongs to; dropped dimensions stay at 0.
        let mut rem = i;
        for (o, &d) in out_dims.iter().enumerate() {
            in_coords[d] = rem / out_strides[o];
            rem %= out_strides[o];
        }

        let mut acc = init;
        'window: for w_idx in 0..n_window {
            let mut ic = in_coords.clone();
            let mut r = w_idx;
            for &ax in axes.iter().rev() {
                ic[ax] += r % ws[ax];
                r /= ws[ax];
                if ic[ax] >= shape[ax] {
                    // Window element lies outside the tensor along `ax`.
                    continue 'window;
                }
            }
            let flat_in: usize = ic.iter().zip(in_strides.iter()).map(|(c, s)| c * s).sum();
            if flat_in < in_f64.len() {
                acc = op(acc, in_f64[flat_in]);
            } else {
                break;
            }
        }
        *slot = acc;
    }
    (from_f64(out_vals, dtype), out_shape)
}

/// Sliding-window sum over `buf`, delegating to `window_reduce` with a 0.0
/// starting accumulator. Window size, axes and keep-dims come from `opts`
/// (see `decode_window_opts`). The result keeps `buf`'s dtype.
///
/// NOTE(review): the `shape` argument is not read here — confirm whether it
/// was meant to supply the window size.
#[rustler::nif(schedule = "DirtyCpu")]
fn window_sum<'a>(
    env: Env<'a>,
    buf: ResourceArc<BufResource>,
    shape: Vec<usize>,
    opts: Term,
) -> NifResult<Term<'a>> {
    let (window, axes, keep) = decode_window_opts(opts)?;
    let add = |acc: f64, x: f64| acc + x;
    let (data, out_shape) = window_reduce(
        &buf.data, buf.dtype, &buf.shape, &window, &axes, keep, 0.0, add,
    );
    let summed = BufResource {
        data,
        shape: out_shape,
        dtype: buf.dtype,
    };
    Ok((atoms::ok(), ResourceArc::new(summed)).encode(env))
}

/// Sliding-window maximum over `buf`, delegating to `window_reduce`. Window
/// size, axes and keep-dims come from `opts` (see `decode_window_opts`). The
/// result keeps `buf`'s dtype.
///
/// The accumulator starts at negative infinity — the identity of `max` — so a
/// window containing only `-inf` yields `-inf` rather than the finite
/// `f64::MIN` used before.
///
/// NOTE(review): the `shape` argument is not read here — confirm whether it
/// was meant to supply the window size.
#[rustler::nif(schedule = "DirtyCpu")]
fn window_max<'a>(
    env: Env<'a>,
    buf: ResourceArc<BufResource>,
    shape: Vec<usize>,
    opts: Term,
) -> NifResult<Term<'a>> {
    let (ws, axes, kd) = decode_window_opts(opts)?;
    let (data, out_shape) = window_reduce(
        &buf.data,
        buf.dtype,
        &buf.shape,
        &ws,
        &axes,
        kd,
        f64::NEG_INFINITY,
        |a, b| a.max(b),
    );
    Ok((
        atoms::ok(),
        ResourceArc::new(BufResource {
            data,
            shape: out_shape,
            dtype: buf.dtype,
        }),
    )
        .encode(env))
}

/// Sliding-window minimum over `buf`, delegating to `window_reduce`. Window
/// size, axes and keep-dims come from `opts` (see `decode_window_opts`). The
/// result keeps `buf`'s dtype.
///
/// The accumulator starts at positive infinity — the identity of `min` — so a
/// window containing only `+inf` yields `+inf` rather than the finite
/// `f64::MAX` used before.
///
/// NOTE(review): the `shape` argument is not read here — confirm whether it
/// was meant to supply the window size.
#[rustler::nif(schedule = "DirtyCpu")]
fn window_min<'a>(
    env: Env<'a>,
    buf: ResourceArc<BufResource>,
    shape: Vec<usize>,
    opts: Term,
) -> NifResult<Term<'a>> {
    let (ws, axes, kd) = decode_window_opts(opts)?;
    let (data, out_shape) = window_reduce(
        &buf.data,
        buf.dtype,
        &buf.shape,
        &ws,
        &axes,
        kd,
        f64::INFINITY,
        |a, b| a.min(b),
    );
    Ok((
        atoms::ok(),
        ResourceArc::new(BufResource {
            data,
            shape: out_shape,
            dtype: buf.dtype,
        }),
    )
        .encode(env))
}

// ═══════════════════════════════════════════════════════════════
// LinAlg / Type conversion / Creation
// ═══════════════════════════════════════════════════════════════

/// NIF: matrix product of the trailing two dimensions of `a` and `b`.
///
/// With `a` ending in `[m, k]` and `b` ending in `[k, n]`, the result is an `[m, n]` tensor
/// in `a`'s dtype. Both operands are converted to `f64` for the accumulation.
///
/// Limitations visible in this implementation:
/// * the four axis-list arguments (`_c1`, `_b1`, `_c2`, `_b2`) are not read;
/// * leading (batch) dimensions are not iterated — only the first `m * k` / `k * n`
///   elements of each operand contribute, and the output shape is always `[m, n]`.
///
/// Raises `"dot: need 2D"` when either operand has rank < 2 and
/// `"dot: inner mismatch …"` when the contracted extents differ.
#[rustler::nif(schedule = "DirtyCpu")]
fn dot_tensor<'a>(
    env: Env<'a>,
    a: ResourceArc<BufResource>,
    _c1: Vec<usize>,
    _b1: Vec<usize>,
    b: ResourceArc<BufResource>,
    _c2: Vec<usize>,
    _b2: Vec<usize>,
) -> NifResult<Term<'a>> {
    let (lhs_shape, rhs_shape) = (&a.shape, &b.shape);
    let (lhs_rank, rhs_rank) = (lhs_shape.len(), rhs_shape.len());
    if lhs_rank < 2 || rhs_rank < 2 {
        return Err(Error::RaiseTerm(Box::new("dot: need 2D")));
    }

    let (m, k) = (lhs_shape[lhs_rank - 2], lhs_shape[lhs_rank - 1]);
    let (k2, n) = (rhs_shape[rhs_rank - 2], rhs_shape[rhs_rank - 1]);
    if k != k2 {
        return Err(Error::RaiseTerm(Box::new(format!(
            "dot: inner mismatch {} vs {}",
            k, k2
        ))));
    }

    let out_dtype = a.dtype;
    let lhs = to_f64(&a.data, a.dtype);
    let rhs = to_f64(&b.data, b.dtype);

    // Row-major output: entry (row, col) is pushed at position row * n + col.
    let mut out_vals = Vec::with_capacity(m * n);
    for row in 0..m {
        for col in 0..n {
            let cell = (0..k).fold(0.0f64, |acc, l| acc + lhs[row * k + l] * rhs[l * n + col]);
            out_vals.push(cell);
        }
    }

    let product = ResourceArc::new(BufResource {
        data: from_f64(out_vals, out_dtype),
        shape: vec![m, n],
        dtype: out_dtype,
    });
    Ok((atoms::ok(), product).encode(env))
}

/// NIF: clamps every element of `buf` into `[lo, hi]`.
///
/// `lo` and `hi` are the first elements of `min_ref` and `max_ref` (converted to `f64`).
/// An empty bound tensor means "unbounded" on that side, expressed as ∓infinity so that
/// infinite inputs are left untouched.
///
/// The clamp is computed as `min(max(x, lo), hi)` instead of `f64::clamp`, because
/// `f64::clamp` panics when `lo > hi` or when either bound is NaN. NaN elements are passed
/// through unchanged, matching what `f64::clamp` does for a NaN receiver.
///
/// Returns `{:ok, resource}` with the input's shape and dtype.
#[rustler::nif(schedule = "DirtyCpu")]
fn clip_tensor<'a>(
    env: Env<'a>,
    buf: ResourceArc<BufResource>,
    min_ref: ResourceArc<BufResource>,
    max_ref: ResourceArc<BufResource>,
) -> NifResult<Term<'a>> {
    let lo = to_f64(&min_ref.data, min_ref.dtype)
        .first()
        .copied()
        .unwrap_or(f64::NEG_INFINITY);
    let hi = to_f64(&max_ref.data, max_ref.dtype)
        .first()
        .copied()
        .unwrap_or(f64::INFINITY);
    let clipped = unary_op(&buf.data, buf.dtype, |x| {
        if x.is_nan() {
            x
        } else {
            // `max`/`min` ignore a NaN argument, so a NaN bound simply does not constrain.
            x.max(lo).min(hi)
        }
    });
    Ok((
        atoms::ok(),
        ResourceArc::new(BufResource {
            data: clipped,
            shape: buf.shape.clone(),
            dtype: buf.dtype,
        }),
    )
        .encode(env))
}

/// NIF: converts `buf` to the dtype named by `dtype_str`.
///
/// When the requested dtype equals the current one the bytes are copied verbatim; otherwise
/// every element is round-tripped through `f64` (`to_f64` then `from_f64`). The shape is
/// preserved in both cases.
///
/// Raises (via `decode_dtype`) when `dtype_str` is not a known dtype name.
#[rustler::nif(schedule = "DirtyCpu")]
fn as_type_tensor<'a>(
    env: Env<'a>,
    buf: ResourceArc<BufResource>,
    dtype_str: String,
) -> NifResult<Term<'a>> {
    let target = decode_dtype(&dtype_str)?;
    let data = if target == buf.dtype {
        buf.data.clone()
    } else {
        from_f64(to_f64(&buf.data, buf.dtype), target)
    };
    let converted = ResourceArc::new(BufResource {
        data,
        shape: buf.shape.clone(),
        dtype: target,
    });
    Ok((atoms::ok(), converted).encode(env))
}

/// NIF: builds a tensor of the given `shape` and dtype with every element set to `value`.
///
/// The element count is the product of `shape` (so an empty shape yields one element).
/// `value` is supplied as `f64` and encoded into the target dtype by `from_f64`.
///
/// Raises (via `decode_dtype`) when `dtype_str` is not a known dtype name.
#[rustler::nif(schedule = "DirtyCpu")]
fn constant_tensor<'a>(
    env: Env<'a>,
    shape: Vec<usize>,
    dtype_str: String,
    value: f64,
) -> NifResult<Term<'a>> {
    let dtype = decode_dtype(&dtype_str)?;
    let count = shape.iter().product::<usize>();
    let filled = std::iter::repeat(value).take(count).collect::<Vec<f64>>();
    let tensor = ResourceArc::new(BufResource {
        data: from_f64(filled, dtype),
        shape,
        dtype,
    });
    Ok((atoms::ok(), tensor).encode(env))
}

/// NIF: builds a (batched) identity tensor.
///
/// The last two entries of `shape` are the rows and columns of each matrix; all preceding
/// entries are multiplied together into a batch count. Each matrix has ones where the row
/// index equals the column index and zeros elsewhere.
///
/// Raises `"eye: need >= 2D"` for rank < 2, or (via `decode_dtype`) for an unknown dtype.
#[rustler::nif(schedule = "DirtyCpu")]
fn eye_tensor<'a>(env: Env<'a>, shape: Vec<usize>, dtype_str: String) -> NifResult<Term<'a>> {
    let dtype = decode_dtype(&dtype_str)?;
    let rank = shape.len();
    if rank < 2 {
        return Err(Error::RaiseTerm(Box::new("eye: need >= 2D")));
    }
    let (rows, cols) = (shape[rank - 2], shape[rank - 1]);
    let batch: usize = shape[..rank - 2].iter().product();

    // Start from all zeros and only touch the main diagonal of every matrix.
    let mut vals = vec![0.0f64; batch * rows * cols];
    let diag_len = rows.min(cols);
    for b in 0..batch {
        let base = b * rows * cols;
        for d in 0..diag_len {
            vals[base + d * cols + d] = 1.0;
        }
    }

    let identity = ResourceArc::new(BufResource {
        data: from_f64(vals, dtype),
        shape,
        dtype,
    });
    Ok((atoms::ok(), identity).encode(env))
}

#[rustler::nif(schedule = "DirtyCpu")]
fn iota_tensor<'a>(
    env: Env<'a>,
    shape: Vec<usize>,
    dtype_str: String,
    axis: usize,
) -> NifResult<Term<'a>> {
    let dtype = decode_dtype(&dtype_str)?;
    let n: usize = shape.iter().product();
    let strides = strides_for(&shape);
    let mut vals = vec![0.0f64; n];
    for i in 0..n {
        let mut coords = vec![0usize; shape.len()];
        let mut rem = i;
        for d in 0..shape.len() {
            coords[d] = rem / strides[d];
            rem %= strides[d];
        }
        vals[i] = coords[axis] as f64;
    }
    Ok((
        atoms::ok(),
        ResourceArc::new(BufResource {
            data: from_f64(vals, dtype),
            shape,
            dtype,
        }),
    )
        .encode(env))
}

// ═══════════════════════════════════════════════════════════════
// Sorting
// ═══════════════════════════════════════════════════════════════

/// NIF: sorts `buf` in ascending order along one axis.
///
/// The axis is the first value returned by `decode_axis_opts`; the second value is ignored,
/// and an axis that is out of range falls back to axis 0. No sort-direction option is read
/// here. Rank-0 tensors and axes of extent <= 1 are returned as a plain copy.
///
/// Values are sorted as `f64` with a stable sort. NaN compares greater than every number
/// and equal to other NaNs, which keeps the comparator a total order (`sort_by` may panic
/// or produce an unspecified order otherwise).
///
/// Lane addressing: with `inner` = product of the dimensions after `axis` and `outer` =
/// product of the dimensions before it, lane `(o, i)` starts at `o * axis_size * inner + i`
/// and steps by `inner`. The earlier implementation decomposed the lane counter with the
/// full-tensor strides, which skipped lanes (leaving zeros in the output) for any axis
/// other than 0.
#[rustler::nif(schedule = "DirtyCpu")]
fn sort_tensor<'a>(env: Env<'a>, buf: ResourceArc<BufResource>, opts: Term) -> NifResult<Term<'a>> {
    let (axis, _) = decode_axis_opts(opts)?;
    let shape = &buf.shape;
    let rank = shape.len();
    let axis = if axis >= rank { 0 } else { axis };
    // Rank 0 has no axis at all; treat it like a length-1 axis (nothing to sort).
    let axis_size = shape.get(axis).copied().unwrap_or(1);
    if axis_size <= 1 {
        return Ok((
            atoms::ok(),
            ResourceArc::new(BufResource {
                data: buf.data.clone(),
                shape: buf.shape.clone(),
                dtype: buf.dtype,
            }),
        )
            .encode(env));
    }
    let inner = strides_for(shape)[axis];
    let outer: usize = shape[..axis].iter().product();
    let block = axis_size * inner;
    let mut vals = to_f64(&buf.data, buf.dtype);
    // Scratch buffer reused for every lane.
    let mut lane: Vec<f64> = Vec::with_capacity(axis_size);
    for o in 0..outer {
        for i in 0..inner {
            let base = o * block + i;
            lane.clear();
            lane.extend((0..axis_size).map(|j| vals[base + j * inner]));
            lane.sort_by(|x, y| {
                x.partial_cmp(y)
                    .unwrap_or_else(|| x.is_nan().cmp(&y.is_nan()))
            });
            for (j, &val) in lane.iter().enumerate() {
                vals[base + j * inner] = val;
            }
        }
    }
    Ok((
        atoms::ok(),
        ResourceArc::new(BufResource {
            data: from_f64(vals, buf.dtype),
            shape: buf.shape.clone(),
            dtype: buf.dtype,
        }),
    )
        .encode(env))
}

/// NIF: returns, for each lane along one axis, the indices that would sort it ascending.
///
/// The axis is the first value returned by `decode_axis_opts`; the second value is ignored,
/// and an axis that is out of range falls back to axis 0. No sort-direction option is read
/// here. The result has the input's shape and dtype `S64`. A rank-0 input yields a single
/// zero index.
///
/// Values are compared as `f64` with a stable sort, so equal values keep their original
/// relative order. NaN compares greater than every number and equal to other NaNs, which
/// keeps the comparator a total order.
///
/// Lane addressing: with `inner` = product of the dimensions after `axis` and `outer` =
/// product of the dimensions before it, lane `(o, i)` starts at `o * axis_size * inner + i`
/// and steps by `inner`. The earlier implementation decomposed the lane counter with the
/// full-tensor strides (skipping lanes for any axis other than 0) and divided by the axis
/// extent, which panicked for an extent of 0.
#[rustler::nif(schedule = "DirtyCpu")]
fn argsort_tensor<'a>(
    env: Env<'a>,
    buf: ResourceArc<BufResource>,
    opts: Term,
) -> NifResult<Term<'a>> {
    let (axis, _) = decode_axis_opts(opts)?;
    let shape = &buf.shape;
    let rank = shape.len();
    if rank == 0 {
        return Ok((
            atoms::ok(),
            ResourceArc::new(BufResource {
                data: vec![0u8; 8],
                shape: buf.shape.clone(),
                dtype: DType::S64,
            }),
        )
            .encode(env));
    }
    let axis = if axis >= rank { 0 } else { axis };
    let axis_size = shape[axis];
    let total: usize = shape.iter().product();
    let mut out_vals = vec![0i64; total];
    // An empty tensor has no lanes to sort; skipping also avoids reading zero strides.
    if total > 0 {
        let inner = strides_for(shape)[axis];
        let outer: usize = shape[..axis].iter().product();
        let block = axis_size * inner;
        let in_f64 = to_f64(&buf.data, buf.dtype);
        // Scratch buffer of (value, original position) pairs, reused for every lane.
        let mut indexed: Vec<(f64, usize)> = Vec::with_capacity(axis_size);
        for o in 0..outer {
            for i in 0..inner {
                let base = o * block + i;
                indexed.clear();
                indexed.extend((0..axis_size).map(|j| (in_f64[base + j * inner], j)));
                indexed.sort_by(|x, y| {
                    x.0.partial_cmp(&y.0)
                        .unwrap_or_else(|| x.0.is_nan().cmp(&y.0.is_nan()))
                });
                for (j, &(_, orig_idx)) in indexed.iter().enumerate() {
                    out_vals[base + j * inner] = orig_idx as i64;
                }
            }
        }
    }
    let out_data = bytemuck::cast_slice(&out_vals).to_vec();
    Ok((
        atoms::ok(),
        ResourceArc::new(BufResource {
            data: out_data,
            shape: buf.shape.clone(),
            dtype: DType::S64,
        }),
    )
        .encode(env))
}

// ═══════════════════════════════════════════════════════════════
// Additional operations (gather, put_slice, bitcast, conv, indexed)
// ═══════════════════════════════════════════════════════════════

/// NIF: gathers elements of `input` along a single axis using `indices`.
///
/// Options: `opts` is read as a list of `(String, Term)` pairs. `"axes"` (a list; only its
/// first entry is used) or `"axis"` select the gather axis; anything that fails to decode
/// is silently ignored and the axis stays at its default of 0.
///
/// Output: the shape is the `indices` shape with its last dimension dropped (or `[1]` when
/// `indices` has rank <= 1), and the dtype is `input`'s dtype. Each index is clamped to the
/// last valid position on the axis.
///
/// Raises `"gather: axis out of range"` when the axis is not a dimension of `input`.
///
/// NOTE(review): `axis_size - 1` below underflows when the gather axis has extent 0 —
/// confirm callers never pass such a tensor.
#[rustler::nif(schedule = "DirtyCpu")]
fn gather<'a>(
    env: Env<'a>,
    input: ResourceArc<BufResource>,
    indices: ResourceArc<BufResource>,
    opts: Term,
) -> NifResult<Term<'a>> {
    // Best-effort option parsing: undecodable options leave `axis` at 0.
    let mut axis = 0usize;
    if let Ok(list) = opts.decode::<Vec<Term>>() {
        for item in list {
            if let Ok((ref key, val)) = item.decode::<(String, Term)>() {
                match key.as_str() {
                    "axes" => {
                        if let Ok(axes) = val.decode::<Vec<usize>>() {
                            axis = axes.first().copied().unwrap_or(0);
                        }
                    }
                    "axis" => {
                        axis = val.decode::<usize>().unwrap_or(0);
                    }
                    _ => {}
                }
            }
        }
    }
    let in_shape = &input.shape;
    let idx_shape = &indices.shape;
    let rank = in_shape.len();
    if axis >= rank {
        return Err(Error::RaiseTerm(Box::new("gather: axis out of range")));
    }
    let axis_size = in_shape[axis];
    let in_strides = strides_for(in_shape);
    let idx_data = to_f64(&indices.data, indices.dtype);
    let in_f64 = to_f64(&input.data, input.dtype);
    let idx_strides = strides_for(idx_shape);
    let num_indices: usize = idx_shape.iter().product();
    // Output shape: the indices shape without its last dimension (the last dimension is
    // treated as holding one entry per gathered axis).
    // When the indices tensor has rank <= 1 there is nothing left after dropping that
    // dimension, and the output shape falls back to [1].
    let num_axes = 1; // We only support single axis for now
    let out_shape: Vec<usize> = if idx_shape.len() > num_axes {
        idx_shape[..idx_shape.len() - num_axes].to_vec()
    } else {
        vec![1]
    };
    let total_out: usize = out_shape.iter().product();
    let out_strides = strides_for(&out_shape);
    let mut out_vals = vec![0.0f64; total_out];
    for i in 0..total_out {
        // Compute output coordinates
        let mut out_coords = vec![0usize; out_shape.len()];
        let mut rem = i;
        for d in 0..out_shape.len() {
            out_coords[d] = rem / out_strides[d];
            rem %= out_strides[d];
        }
        // Compute the flat index into the indices tensor: either `i` itself (when the
        // output is 1-D with exactly one slot per index) or the output coordinates with a
        // trailing 0 appended for the dropped last dimension.
        let idx_flat = if out_shape.len() == 1 && out_shape[0] == num_indices {
            i
        } else {
            let mut idx_coords = out_coords.clone();
            for _ in 0..num_axes {
                idx_coords.push(0);
            }
            idx_coords
                .iter()
                .zip(idx_strides.iter())
                .map(|(c, s)| c * s)
                .sum()
        };
        // Out-of-range lookups into the indices buffer read as index 0.
        let idx_val = if idx_flat < num_indices {
            idx_data[idx_flat] as usize
        } else {
            0
        };
        // Clamp to the last valid position on the gather axis.
        let idx_val = idx_val.min(axis_size - 1);
        // Compute the flat index into the input tensor
        let flat_in = if rank == 1 {
            idx_val
        } else {
            // The gather axis takes the looked-up index; every other input dimension takes
            // the next unused output coordinate (or 0 once those run out).
            let mut in_coords = vec![0usize; rank];
            let mut oi = 0;
            for d in 0..rank {
                if d == axis {
                    in_coords[d] = idx_val;
                } else {
                    in_coords[d] = out_coords.get(oi).copied().unwrap_or(0);
                    oi += 1;
                }
            }
            in_coords
                .iter()
                .zip(in_strides.iter())
                .map(|(c, s)| c * s)
                .sum()
        };
        out_vals[i] = in_f64[flat_in];
    }
    Ok((
        atoms::ok(),
        ResourceArc::new(BufResource {
            data: from_f64(out_vals, input.dtype),
            shape: out_shape,
            dtype: input.dtype,
        }),
    )
        .encode(env))
}

#[rustler::nif(schedule = "DirtyCpu")]
fn put_slice<'a>(
    env: Env<'a>,
    buf: ResourceArc<BufResource>,
    starts: Vec<usize>,
    slice: ResourceArc<BufResource>,
) -> NifResult<Term<'a>> {
    let rank = buf.shape.len();
    let in_strides = strides_for(&buf.shape);
    let slice_strides = strides_for(&slice.shape);
    let mut out_data = to_f64(&buf.data, buf.dtype);
    let slice_f64 = to_f64(&slice.data, slice.dtype);
    let total_slice: usize = slice.shape.iter().product();
    for i in 0..total_slice {
        let mut slice_coords = vec![0usize; rank];
        let mut rem = i;
        for d in 0..rank {
            slice_coords[d] = rem / slice_strides[d];
            rem %= slice_strides[d];
        }
        let mut out_idx = 0usize;
        for d in 0..rank {
            out_idx += (starts[d] + slice_coords[d]) * in_strides[d];
        }
        if out_idx < out_data.len() {
            out_data[out_idx] = slice_f64[i];
        }
    }
    Ok((
        atoms::ok(),
        ResourceArc::new(BufResource {
            data: from_f64(out_data, buf.dtype),
            shape: buf.shape.clone(),
            dtype: buf.dtype,
        }),
    )
        .encode(env))
}

#[rustler::nif(schedule = "DirtyCpu")]
fn bitcast_tensor<'a>(
    env: Env<'a>,
    buf: ResourceArc<BufResource>,
    dtype_str: String,
) -> NifResult<Term<'a>> {
    let new_dtype = decode_dtype(&dtype_str)?;
    let old_size = buf.dtype.size_in_bytes();
    let new_size = new_dtype.size_in_bytes();
    if old_size == new_size {
        return Ok((
            atoms::ok(),
            ResourceArc::new(BufResource {
                data: buf.data.clone(),
                shape: buf.shape.clone(),
                dtype: new_dtype,
            }),
        )
            .encode(env));
    }
    if old_size > new_size {
        // Truncate: keep only the bytes that fit
        let new_n = buf.data.len() / new_size;
        let mut new_data = vec![0u8; new_n * new_size];
        for i in 0..new_n {
            new_data[i * new_size..(i + 1) * new_size]
                .copy_from_slice(&buf.data[i * old_size..i * old_size + new_size]);
        }
        let mut new_shape = buf.shape.clone();
        let last = new_shape.len() - 1;
        new_shape[last] = new_shape[last] * old_size / new_size;
        return Ok((
            atoms::ok(),
            ResourceArc::new(BufResource {
                data: new_data,
                shape: new_shape,
                dtype: new_dtype,
            }),
        )
            .encode(env));
    }
    // old_size < new_size: expand with zeros
    let repeat = new_size / old_size;
    let n = buf.data.len() / old_size;
    let mut new_data = vec![0u8; n * new_size];
    for i in 0..n {
        new_data[i * new_size..i * new_size + old_size]
            .copy_from_slice(&buf.data[i * old_size..(i + 1) * old_size]);
    }
    let mut new_shape = buf.shape.clone();
    let last = new_shape.len() - 1;
    new_shape[last] = new_shape[last] / repeat;
    Ok((
        atoms::ok(),
        ResourceArc::new(BufResource {
            data: new_data,
            shape: new_shape,
            dtype: new_dtype,
        }),
    )
        .encode(env))
}

#[rustler::nif(schedule = "DirtyCpu")]
fn conv<'a>(
    env: Env<'a>,
    input: ResourceArc<BufResource>,
    kernel: ResourceArc<BufResource>,
    _opts: Term,
) -> NifResult<Term<'a>> {
    // Simple 2D convolution (no padding, stride=1)
    let in_shape = &input.shape;
    let k_shape = &kernel.shape;
    if in_shape.len() < 2 || k_shape.len() < 2 {
        return Err(Error::RaiseTerm(Box::new("conv: need >= 2D")));
    }
    let in_h = in_shape[in_shape.len() - 2];
    let in_w = in_shape[in_shape.len() - 1];
    let k_h = k_shape[k_shape.len() - 2];
    let k_w = k_shape[k_shape.len() - 1];
    let out_h = in_h.saturating_sub(k_h - 1);
    let out_w = in_w.saturating_sub(k_w - 1);
    if out_h == 0 || out_w == 0 {
        return Err(Error::RaiseTerm(Box::new("conv: kernel larger than input")));
    }
    let batch: usize = in_shape[..in_shape.len() - 2].iter().product();
    let in_f = to_f64(&input.data, input.dtype);
    let k_f = to_f64(&kernel.data, kernel.dtype);
    let mut out_shape = in_shape[..in_shape.len() - 2].to_vec();
    out_shape.push(out_h);
    out_shape.push(out_w);
    let total_out = batch * out_h * out_w;
    let mut out_vals = vec![0.0f64; total_out];
    for b in 0..batch {
        for oh in 0..out_h {
            for ow in 0..out_w {
                let mut sum = 0.0;
                for kh in 0..k_h {
                    for kw in 0..k_w {
                        let ih = oh + kh;
                        let iw = ow + kw;
                        let in_idx = b * in_h * in_w + ih * in_w + iw;
                        let k_idx = kh * k_w + kw;
                        sum += in_f[in_idx] * k_f[k_idx];
                    }
                }
                out_vals[b * out_h * out_w + oh * out_w + ow] = sum;
            }
        }
    }
    Ok((
        atoms::ok(),
        ResourceArc::new(BufResource {
            data: from_f64(out_vals, input.dtype),
            shape: out_shape,
            dtype: input.dtype,
        }),
    )
        .encode(env))
}

#[rustler::nif(schedule = "DirtyCpu")]
fn indexed_add<'a>(
    env: Env<'a>,
    t: ResourceArc<BufResource>,
    indices: ResourceArc<BufResource>,
    updates: ResourceArc<BufResource>,
    _opts: Term,
) -> NifResult<Term<'a>> {
    let mut out_data = to_f64(&t.data, t.dtype);
    let idx_f = to_f64(&indices.data, indices.dtype);
    let upd_f = to_f64(&updates.data, updates.dtype);
    let n = idx_f.len();
    let rank = t.shape.len();
    let in_strides = strides_for(&t.shape);
    for i in 0..n {
        let idx = idx_f[i] as usize;
        if idx < t.shape[0] {
            let offset = idx * in_strides[0];
            let upd_offset = i * in_strides[0];
            for j in 0..in_strides[0] {
                if offset + j < out_data.len() && upd_offset + j < upd_f.len() {
                    out_data[offset + j] += upd_f[upd_offset + j];
                }
            }
        }
    }
    Ok((
        atoms::ok(),
        ResourceArc::new(BufResource {
            data: from_f64(out_data, t.dtype),
            shape: t.shape.clone(),
            dtype: t.dtype,
        }),
    )
        .encode(env))
}

#[rustler::nif(schedule = "DirtyCpu")]
fn indexed_put<'a>(
    env: Env<'a>,
    t: ResourceArc<BufResource>,
    indices: ResourceArc<BufResource>,
    updates: ResourceArc<BufResource>,
    _opts: Term,
) -> NifResult<Term<'a>> {
    let mut out_data = to_f64(&t.data, t.dtype);
    let idx_f = to_f64(&indices.data, indices.dtype);
    let upd_f = to_f64(&updates.data, updates.dtype);
    let n = idx_f.len();
    let in_strides = strides_for(&t.shape);
    for i in 0..n {
        let idx = idx_f[i] as usize;
        if idx < t.shape[0] {
            let offset = idx * in_strides[0];
            let upd_offset = i * in_strides[0];
            for j in 0..in_strides[0] {
                if offset + j < out_data.len() && upd_offset + j < upd_f.len() {
                    out_data[offset + j] = upd_f[upd_offset + j];
                }
            }
        }
    }
    Ok((
        atoms::ok(),
        ResourceArc::new(BufResource {
            data: from_f64(out_data, t.dtype),
            shape: t.shape.clone(),
            dtype: t.dtype,
        }),
    )
        .encode(env))
}

// ═══════════════════════════════════════════════════════════════
// Stubs for complex ops (fallback to BinaryBackend on Elixir side)
// ═══════════════════════════════════════════════════════════════

#[rustler::nif(schedule = "DirtyCpu")]
fn triangular_solve<'a>(
    env: Env<'a>,
    _a: ResourceArc<BufResource>,
    _b: ResourceArc<BufResource>,
    _c: Term,
) -> NifResult<Term<'a>> {
    Err(Error::RaiseTerm(Box::new(
        "triangular_solve: not implemented, use BinaryBackend fallback",
    )))
}
#[rustler::nif(schedule = "DirtyCpu")]
fn fft_tensor<'a>(env: Env<'a>, _a: ResourceArc<BufResource>, _b: Term) -> NifResult<Term<'a>> {
    Err(Error::RaiseTerm(Box::new(
        "fft: not implemented, use BinaryBackend fallback",
    )))
}
#[rustler::nif(schedule = "DirtyCpu")]
fn ifft_tensor<'a>(env: Env<'a>, _a: ResourceArc<BufResource>, _b: Term) -> NifResult<Term<'a>> {
    Err(Error::RaiseTerm(Box::new(
        "ifft: not implemented, use BinaryBackend fallback",
    )))
}
#[rustler::nif(schedule = "DirtyCpu")]
fn window_scatter_max<'a>(
    env: Env<'a>,
    _a: ResourceArc<BufResource>,
    _b: ResourceArc<BufResource>,
    _c: ResourceArc<BufResource>,
    _d: Vec<usize>,
    _e: Term,
) -> NifResult<Term<'a>> {
    Err(Error::RaiseTerm(Box::new(
        "window_scatter_max: not implemented",
    )))
}
#[rustler::nif(schedule = "DirtyCpu")]
fn window_scatter_min<'a>(
    env: Env<'a>,
    _a: ResourceArc<BufResource>,
    _b: ResourceArc<BufResource>,
    _c: ResourceArc<BufResource>,
    _d: Vec<usize>,
    _e: Term,
) -> NifResult<Term<'a>> {
    Err(Error::RaiseTerm(Box::new(
        "window_scatter_min: not implemented",
    )))
}
#[rustler::nif(schedule = "DirtyCpu")]
fn reduce<'a>(
    env: Env<'a>,
    _a: ResourceArc<BufResource>,
    _b: ResourceArc<BufResource>,
    _c: Term,
    _d: Term,
) -> NifResult<Term<'a>> {
    Err(Error::RaiseTerm(Box::new(
        "reduce: not implemented, use BinaryBackend fallback",
    )))
}
#[rustler::nif(schedule = "DirtyCpu")]
fn window_reduce<'a>(
    env: Env<'a>,
    _a: ResourceArc<BufResource>,
    _b: ResourceArc<BufResource>,
    _c: Vec<usize>,
    _d: Term,
    _e: Term,
) -> NifResult<Term<'a>> {
    Err(Error::RaiseTerm(Box::new(
        "window_reduce: not implemented, use BinaryBackend fallback",
    )))
}
#[rustler::nif(schedule = "DirtyCpu")]
fn window_product<'a>(
    env: Env<'a>,
    _a: ResourceArc<BufResource>,
    _b: Vec<usize>,
    _c: Term,
) -> NifResult<Term<'a>> {
    Err(Error::RaiseTerm(Box::new(
        "window_product: not implemented, use BinaryBackend fallback",
    )))
}