// native/metamorphic_crypto_nif/metamorphic-crypto/src/hybrid.rs

//! Hybrid post-quantum KEM: ML-KEM-768 + X25519.
//!
//! This module implements the exact same hybrid KEM as noble-post-quantum's
//! `ml_kem768_x25519` export, ensuring byte-level compatibility with existing
//! production ciphertext.
//!
//! ## Construction (from noble source)
//!
//! ```text
//! combineKEMS(
//!   seedLen = 32,
//!   outputLen = 32,
//!   expandSeed = SHAKE256(seed, dkLen=96),
//!   combiner = SHA3-256(ss_mlkem || ss_x25519 || ct_x25519 || pk_x25519 || b"\\.//^\\"),
//!   ml_kem768,
//!   ecdhKem(x25519)
//! )
//! ```
//!
//! ## Key layout
//!
//! | Component | Size | Description |
//! |-----------|------|-------------|
//! | Secret key (seed) | 32 bytes | Root seed expanded via SHAKE256 |
//! | Public key | 1216 bytes | ML-KEM-768 ek (1184) ‖ X25519 pk (32) |
//! | Ciphertext | 1120 bytes | ML-KEM-768 ct (1088) ‖ X25519 ephemeral pk (32) |
//! | Shared secret | 32 bytes | SHA3-256 combiner output |
//!
//! ## Sealed-box ciphertext format (Metamorphic v2)
//!
//! ```text
//! 0x02 || hybrid_ciphertext (1120 B) || nonce (24 B) || secretbox_ct
//! ```

use ml_kem::{Decapsulate, MlKem768};
use ml_kem::{DecapsulationKey, EncapsulationKey, KeyExport};
use sha3::Shake256;
use sha3::digest::{ExtendableOutput, Update, XofReader};
use x25519_dalek::{PublicKey as X25519PublicKey, StaticSecret as X25519StaticSecret};
use zeroize::Zeroize;

use crypto_secretbox::aead::Aead;
use crypto_secretbox::aead::generic_array::GenericArray;
use crypto_secretbox::{KeyInit, XSalsa20Poly1305};

use crate::CryptoError;
use crate::b64;

// === Constants ===

/// Version tag for hybrid ciphertext (first byte of every sealed blob).
const VERSION_HYBRID: u8 = 0x02;
/// XSalsa20 nonce length in bytes.
const NONCE_LEN: usize = 24;
/// ML-KEM-768 encapsulation key size in bytes.
const MLKEM_EK_LEN: usize = 1184;
/// ML-KEM-768 ciphertext size in bytes.
const MLKEM_CT_LEN: usize = 1088;
/// X25519 key size in bytes (secret key, public key and ephemeral public key alike).
const X25519_LEN: usize = 32;
/// Combined public key: ML-KEM ek (1184) + X25519 pk (32).
const COMBINED_PK_LEN: usize = MLKEM_EK_LEN + X25519_LEN;
/// Combined ciphertext: ML-KEM ct (1088) + X25519 ephemeral pk (32).
const COMBINED_CT_LEN: usize = MLKEM_CT_LEN + X25519_LEN;
/// Root seed length in bytes.
const SEED_LEN: usize = 32;
/// ML-KEM-768 seed portion of the expanded seed, in bytes.
const MLKEM_SEED_LEN: usize = 64;
/// Expanded seed length: ML-KEM seed (64) + X25519 secret (32) = 96 bytes.
///
/// Derived from its parts (like the combined key/ciphertext lengths above) so
/// the split points used when slicing the expanded seed cannot drift apart.
const EXPANDED_SEED_LEN: usize = MLKEM_SEED_LEN + X25519_LEN;
/// Noble's domain-separation label, appended last to the combiner input.
const LABEL: &[u8] = b"\\.//^\\";
/// Poly1305 MAC (authentication tag) length in bytes.
const MAC_LEN: usize = 16;

// === Types ===

/// A hybrid ML-KEM-768 + X25519 keypair (base64-encoded).
///
/// `Debug` is implemented by hand rather than derived so that formatting a
/// keypair (e.g. in a log line or a panic message) never prints the secret
/// seed; only the public key is shown.
#[derive(Clone)]
pub struct HybridKeyPair {
    /// Combined public key (1216 bytes): ML-KEM ek ‖ X25519 pk. Base64.
    pub public_key: String,
    /// Root seed (32 bytes). Base64.
    pub secret_key: String,
}

impl std::fmt::Debug for HybridKeyPair {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("HybridKeyPair")
            .field("public_key", &self.public_key)
            // Never render secret key material in diagnostic output.
            .field("secret_key", &"<redacted>")
            .finish()
    }
}

// === Helpers ===

/// Overwrite every byte of `buf` with randomness obtained via `getrandom`.
///
/// # Panics
///
/// Panics with the message "OS CSPRNG unavailable" if `getrandom` reports an
/// error.
#[inline]
fn random_bytes(buf: &mut [u8]) {
    let filled = getrandom::getrandom(buf);
    filled.expect("OS CSPRNG unavailable");
}

/// Expand a 32-byte seed to 96 bytes using SHAKE256 (matching noble's expandSeedXof).
fn expand_seed(seed: &[u8; SEED_LEN]) -> [u8; EXPANDED_SEED_LEN] {
    let mut hasher = Shake256::default();
    hasher.update(seed);
    let mut reader = hasher.finalize_xof();
    let mut out = [0u8; EXPANDED_SEED_LEN];
    reader.read(&mut out);
    out
}

/// Derive the 32-byte hybrid shared secret:
/// `SHA3-256(ss_mlkem || ss_x25519 || ct_x25519 || pk_x25519 || label)`.
///
/// The parts are absorbed in exactly that order; `LABEL` is always last.
fn combine(
    ss_mlkem: &[u8],
    ss_x25519: &[u8],
    ct_x25519: &[u8; X25519_LEN],
    pk_x25519: &[u8; X25519_LEN],
) -> [u8; 32] {
    use sha3::Digest;

    // Hashing the parts one after another is equivalent to hashing their
    // concatenation, so no intermediate buffer is needed.
    let transcript: [&[u8]; 5] = [ss_mlkem, ss_x25519, ct_x25519, pk_x25519, LABEL];

    let mut sponge = sha3::Sha3_256::new();
    for part in transcript {
        Digest::update(&mut sponge, part);
    }
    sponge.finalize().into()
}

// === Public API ===

/// Generate a hybrid ML-KEM-768 + X25519 keypair.
///
/// The secret key is the 32-byte root seed. The public key is the concatenation
/// of the ML-KEM-768 encapsulation key (1184 B) and the X25519 public key (32 B).
/// Both are returned base64-encoded.
///
/// Every local copy of secret material made here (root seed, expanded seed and
/// the two sub-seeds sliced out of it) is zeroized before returning.
///
/// # Panics
///
/// Panics if the OS random source is unavailable (see `random_bytes`).
pub fn generate_hybrid_keypair() -> HybridKeyPair {
    let mut seed = [0u8; SEED_LEN];
    random_bytes(&mut seed);

    // Split the expanded seed into its ML-KEM and X25519 halves, then wipe the
    // expanded buffer as soon as both halves have been copied out.
    let mut expanded = expand_seed(&seed);
    let mut mlkem_seed: [u8; MLKEM_SEED_LEN] = expanded[..MLKEM_SEED_LEN]
        .try_into()
        .expect("expanded seed starts with MLKEM_SEED_LEN bytes");
    let mut x25519_sk_bytes: [u8; X25519_LEN] = expanded[MLKEM_SEED_LEN..]
        .try_into()
        .expect("expanded seed ends with X25519_LEN bytes");
    expanded.zeroize();

    // ML-KEM-768 keypair from seed. The array is `Copy`, so `into()` hands a
    // copy to the library and our local copy must be wiped explicitly.
    let dk = DecapsulationKey::<MlKem768>::from_seed(mlkem_seed.into());
    mlkem_seed.zeroize();
    let ek = dk.encapsulation_key();
    let ek_bytes = ek.to_bytes();

    // X25519 keypair (same reasoning: wipe our copy of the secret bytes).
    let x25519_sk = X25519StaticSecret::from(x25519_sk_bytes);
    x25519_sk_bytes.zeroize();
    let x25519_pk = X25519PublicKey::from(&x25519_sk);

    // Combined public key: ML-KEM ek || X25519 pk
    let mut combined_pk = Vec::with_capacity(COMBINED_PK_LEN);
    combined_pk.extend_from_slice(&ek_bytes);
    combined_pk.extend_from_slice(x25519_pk.as_bytes());

    let pair = HybridKeyPair {
        public_key: b64::encode(&combined_pk),
        secret_key: b64::encode(&seed),
    };

    // The root seed now lives only (base64-encoded) inside `pair`.
    seed.zeroize();

    pair
}

/// Seal `plaintext` to a hybrid public key.
///
/// Returns base64: `0x02 || hybrid_ct (1120 B) || nonce (24 B) || secretbox_ct`.
///
/// # Errors
///
/// * Propagates the error from `b64::decode` if `combined_pk_b64` cannot be decoded.
/// * `CryptoError::InvalidLength` if the decoded key is not `COMBINED_PK_LEN` bytes.
/// * `CryptoError::Hybrid` if the ML-KEM encapsulation key is rejected or the
///   secretbox encryption fails.
///
/// # Panics
///
/// Panics if the OS random source is unavailable (see `random_bytes`).
pub fn hybrid_seal(plaintext: &[u8], combined_pk_b64: &str) -> Result<String, CryptoError> {
    let pk_bytes = b64::decode(combined_pk_b64)?;
    if pk_bytes.len() != COMBINED_PK_LEN {
        return Err(CryptoError::InvalidLength {
            expected: COMBINED_PK_LEN,
            got: pk_bytes.len(),
        });
    }

    // Split combined public key
    let mlkem_ek_bytes = &pk_bytes[..MLKEM_EK_LEN];
    let x25519_pk_bytes: [u8; X25519_LEN] = pk_bytes[MLKEM_EK_LEN..]
        .try_into()
        .expect("public key length checked above");

    // ML-KEM-768 encapsulate
    let ek = EncapsulationKey::<MlKem768>::new(
        mlkem_ek_bytes
            .try_into()
            .map_err(|_| CryptoError::Hybrid("invalid ML-KEM ek".into()))?,
    )
    .map_err(|_| CryptoError::Hybrid("invalid ML-KEM encapsulation key".into()))?;

    let mut mlkem_coins = [0u8; 32];
    random_bytes(&mut mlkem_coins);
    let (mlkem_ct, ss_mlkem) = ek.encapsulate_deterministic(&mlkem_coins.into());
    mlkem_coins.zeroize();

    // X25519 encapsulate (ephemeral DH)
    let mut x25519_eph_bytes = [0u8; X25519_LEN];
    random_bytes(&mut x25519_eph_bytes);
    let x25519_eph_sk = X25519StaticSecret::from(x25519_eph_bytes);
    let x25519_eph_pk = X25519PublicKey::from(&x25519_eph_sk);
    let x25519_recipient_pk = X25519PublicKey::from(x25519_pk_bytes);
    let ss_x25519 = x25519_eph_sk.diffie_hellman(&x25519_recipient_pk);
    x25519_eph_bytes.zeroize();

    // Combine shared secrets; the X25519 "ciphertext" is the ephemeral public key.
    let ct_x25519: [u8; X25519_LEN] = *x25519_eph_pk.as_bytes();
    let mut shared_secret = combine(
        ss_mlkem.as_slice(),
        ss_x25519.as_bytes(),
        &ct_x25519,
        &x25519_pk_bytes,
    );

    // Encrypt plaintext with the combined shared secret
    let cipher = XSalsa20Poly1305::new(GenericArray::from_slice(&shared_secret));
    let mut nonce_buf = [0u8; NONCE_LEN];
    random_bytes(&mut nonce_buf);
    let nonce = GenericArray::from_slice(&nonce_buf);
    let sealed = cipher
        .encrypt(nonce, plaintext)
        .map_err(|_| CryptoError::Hybrid("secretbox encrypt failed".into()));

    // Wipe the key before propagating any encryption error, so the failure
    // path does not leave it behind (same pattern as `hybrid_open`).
    shared_secret.zeroize();
    let ct = sealed?;

    // Assemble: version || mlkem_ct || x25519_eph_pk || nonce || secretbox_ct
    let mut out = Vec::with_capacity(1 + COMBINED_CT_LEN + NONCE_LEN + ct.len());
    out.push(VERSION_HYBRID);
    out.extend_from_slice(mlkem_ct.as_slice());
    out.extend_from_slice(&ct_x25519);
    out.extend_from_slice(&nonce_buf);
    out.extend_from_slice(&ct);

    Ok(b64::encode(&out))
}

/// Open a hybrid-sealed ciphertext with the recipient's secret key (32-byte seed).
///
/// Both arguments are base64. The decoded seed and every key copy derived from
/// it inside this function are zeroized on all exit paths, including errors.
///
/// # Errors
///
/// * Propagates the error from `b64::decode` if either input cannot be decoded.
/// * `CryptoError::Hybrid` if the first byte is not the hybrid version tag, or
///   the ML-KEM ciphertext cannot be converted.
/// * `CryptoError::InvalidLength` if the seed is not `SEED_LEN` bytes.
/// * `CryptoError::TooShort` if the blob cannot hold the header, nonce and MAC.
/// * `CryptoError::Decryption` if secretbox decryption fails.
pub fn hybrid_open(ct_b64: &str, seed_b64: &str) -> Result<Vec<u8>, CryptoError> {
    let combined = b64::decode(ct_b64)?;
    let mut seed_bytes = b64::decode(seed_b64)?;

    // All fallible work happens in the helper so the decoded seed is wiped
    // exactly once here, no matter which path the helper returns through.
    let result = open_with_seed(&combined, &seed_bytes);
    seed_bytes.zeroize();
    result
}

/// Validate and decrypt an already-decoded hybrid blob using the raw seed bytes.
///
/// Checks run in this order: version tag, seed length, minimum blob length.
fn open_with_seed(combined: &[u8], seed_bytes: &[u8]) -> Result<Vec<u8>, CryptoError> {
    if combined.first() != Some(&VERSION_HYBRID) {
        return Err(CryptoError::Hybrid(
            "not a hybrid ciphertext (bad version tag)".into(),
        ));
    }
    if seed_bytes.len() != SEED_LEN {
        return Err(CryptoError::InvalidLength {
            expected: SEED_LEN,
            got: seed_bytes.len(),
        });
    }
    if combined.len() < 1 + COMBINED_CT_LEN + NONCE_LEN + MAC_LEN {
        return Err(CryptoError::TooShort);
    }

    // Expand seed to recover both secret keys, wiping each intermediate copy
    // as soon as it has been consumed.
    let mut seed: [u8; SEED_LEN] = seed_bytes
        .try_into()
        .expect("seed length checked above");
    let mut expanded = expand_seed(&seed);
    seed.zeroize();
    let mut mlkem_seed: [u8; MLKEM_SEED_LEN] = expanded[..MLKEM_SEED_LEN]
        .try_into()
        .expect("expanded seed starts with MLKEM_SEED_LEN bytes");
    let mut x25519_sk_bytes: [u8; X25519_LEN] = expanded[MLKEM_SEED_LEN..]
        .try_into()
        .expect("expanded seed ends with X25519_LEN bytes");
    expanded.zeroize();

    // Build both secret keys before any further fallible step; the arrays are
    // `Copy`, so the library receives copies and ours must be wiped explicitly.
    let dk = DecapsulationKey::<MlKem768>::from_seed(mlkem_seed.into());
    mlkem_seed.zeroize();
    let x25519_sk = X25519StaticSecret::from(x25519_sk_bytes);
    x25519_sk_bytes.zeroize();

    // Parse ciphertext components (bounds guaranteed by the length check above).
    let mlkem_ct = &combined[1..1 + MLKEM_CT_LEN];
    let x25519_eph_pk_bytes: [u8; X25519_LEN] = combined[1 + MLKEM_CT_LEN..1 + COMBINED_CT_LEN]
        .try_into()
        .expect("slice is exactly X25519_LEN bytes");
    let nonce_slice = &combined[1 + COMBINED_CT_LEN..1 + COMBINED_CT_LEN + NONCE_LEN];
    let encrypted = &combined[1 + COMBINED_CT_LEN + NONCE_LEN..];

    // ML-KEM-768 decapsulate
    let kem_ct = mlkem_ct
        .try_into()
        .map_err(|_| CryptoError::Hybrid("invalid ML-KEM ciphertext".into()))?;
    let ss_mlkem = dk.decapsulate(kem_ct);

    // X25519 decapsulate (DH with ephemeral public key)
    let x25519_eph_pk = X25519PublicKey::from(x25519_eph_pk_bytes);
    let ss_x25519 = x25519_sk.diffie_hellman(&x25519_eph_pk);

    // Recover recipient's X25519 public key for the combiner
    let x25519_pk = X25519PublicKey::from(&x25519_sk);
    let pk_x25519: [u8; X25519_LEN] = *x25519_pk.as_bytes();

    // Combine shared secrets (same combiner as seal)
    let mut shared_secret = combine(
        ss_mlkem.as_slice(),
        ss_x25519.as_bytes(),
        &x25519_eph_pk_bytes,
        &pk_x25519,
    );

    // Decrypt; keep the Result so the key is wiped on success and failure alike.
    let cipher = XSalsa20Poly1305::new(GenericArray::from_slice(&shared_secret));
    let nonce = GenericArray::from_slice(nonce_slice);
    let result = cipher
        .decrypt(nonce, encrypted)
        .map_err(|_| CryptoError::Decryption);

    shared_secret.zeroize();
    result
}

/// Reports whether `ct_b64` decodes to a blob whose first byte is the hybrid
/// version tag (0x02).
///
/// Returns `false` when the input cannot be base64-decoded or decodes to an
/// empty buffer. Only the leading byte is inspected; nothing else about the
/// blob is validated.
pub fn is_hybrid_ciphertext(ct_b64: &str) -> bool {
    match b64::decode(ct_b64) {
        Ok(raw) => raw.first() == Some(&VERSION_HYBRID),
        Err(_) => false,
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Sealing then opening with the matching keypair returns the plaintext.
    #[test]
    fn roundtrip() {
        let kp = generate_hybrid_keypair();
        let pt = b"32-byte symmetric context key!!!";
        let ct = hybrid_seal(pt, &kp.public_key).unwrap();
        assert!(is_hybrid_ciphertext(&ct));
        let opened = hybrid_open(&ct, &kp.secret_key).unwrap();
        assert_eq!(opened, pt);
    }

    /// A ciphertext sealed to one keypair cannot be opened with another.
    #[test]
    fn wrong_key_fails() {
        let kp1 = generate_hybrid_keypair();
        let kp2 = generate_hybrid_keypair();
        let ct = hybrid_seal(b"x", &kp1.public_key).unwrap();
        assert!(hybrid_open(&ct, &kp2.secret_key).is_err());
    }

    /// The first byte of a sealed blob is the hybrid version tag.
    #[test]
    fn version_tag() {
        let kp = generate_hybrid_keypair();
        let raw = b64::decode(&hybrid_seal(b"x", &kp.public_key).unwrap()).unwrap();
        assert_eq!(raw[0], VERSION_HYBRID);
    }

    /// Sealing the same plaintext twice yields different ciphertexts.
    #[test]
    fn nondeterministic() {
        let kp = generate_hybrid_keypair();
        let c1 = hybrid_seal(b"x", &kp.public_key).unwrap();
        let c2 = hybrid_seal(b"x", &kp.public_key).unwrap();
        assert_ne!(c1, c2);
    }

    /// An empty plaintext survives a seal/open roundtrip.
    #[test]
    fn empty_plaintext() {
        let kp = generate_hybrid_keypair();
        let ct = hybrid_seal(b"", &kp.public_key).unwrap();
        assert_eq!(hybrid_open(&ct, &kp.secret_key).unwrap(), b"");
    }

    /// Generated keys decode to the documented sizes.
    #[test]
    fn key_sizes() {
        let kp = generate_hybrid_keypair();
        let pk = b64::decode(&kp.public_key).unwrap();
        let sk = b64::decode(&kp.secret_key).unwrap();
        assert_eq!(pk.len(), COMBINED_PK_LEN); // 1216
        assert_eq!(sk.len(), SEED_LEN); // 32
    }

    /// A 32-byte plaintext produces a blob of the documented total size.
    #[test]
    fn ciphertext_size() {
        let kp = generate_hybrid_keypair();
        let pt = b"exactly 32 bytes of key material";
        let raw = b64::decode(&hybrid_seal(pt, &kp.public_key).unwrap()).unwrap();
        // 1 + 1120 + 24 + 32 + 16 = 1193
        assert_eq!(raw.len(), 1 + COMBINED_CT_LEN + NONCE_LEN + 32 + MAC_LEN);
    }

    /// A blob that does not start with 0x02 is not reported as hybrid.
    #[test]
    fn legacy_not_hybrid() {
        let legacy = b64::encode(&[0x01, 0x02, 0x03]);
        assert!(!is_hybrid_ciphertext(&legacy));
    }

    /// Flipping a bit in the encrypted payload makes opening fail.
    #[test]
    fn tampered_ciphertext_fails() {
        let kp = generate_hybrid_keypair();
        let mut raw = b64::decode(&hybrid_seal(b"x", &kp.public_key).unwrap()).unwrap();
        let last = raw.len() - 1;
        raw[last] ^= 0x01;
        assert!(hybrid_open(&b64::encode(&raw), &kp.secret_key).is_err());
    }

    /// A blob shorter than header + nonce + MAC is rejected.
    #[test]
    fn truncated_ciphertext_fails() {
        let kp = generate_hybrid_keypair();
        // An empty plaintext yields exactly the minimum length; drop one byte.
        let raw = b64::decode(&hybrid_seal(b"", &kp.public_key).unwrap()).unwrap();
        let short = &raw[..raw.len() - 1];
        assert!(hybrid_open(&b64::encode(short), &kp.secret_key).is_err());
    }

    /// Overwriting the version byte makes both the probe and the open reject it.
    #[test]
    fn bad_version_tag_rejected() {
        let kp = generate_hybrid_keypair();
        let mut raw = b64::decode(&hybrid_seal(b"x", &kp.public_key).unwrap()).unwrap();
        raw[0] = 0x01;
        let mistagged = b64::encode(&raw);
        assert!(!is_hybrid_ciphertext(&mistagged));
        assert!(hybrid_open(&mistagged, &kp.secret_key).is_err());
    }

    /// The combiner hashes the four inputs followed by `LABEL`, in that order.
    #[test]
    fn combiner_uses_label() {
        use sha3::Digest;

        let ss_mlkem = [0xAAu8; 32];
        let ss_x25519 = [0xBBu8; 32];
        let ct_x25519 = [0xCCu8; 32];
        let pk_x25519 = [0xDDu8; 32];

        let result = combine(&ss_mlkem, &ss_x25519, &ct_x25519, &pk_x25519);
        assert_eq!(result.len(), 32);

        // Reference hashes over the same inputs, with and without the label.
        let mut labeled = sha3::Sha3_256::new();
        let mut unlabeled = sha3::Sha3_256::new();
        for part in [
            &ss_mlkem[..],
            &ss_x25519[..],
            &ct_x25519[..],
            &pk_x25519[..],
        ] {
            Digest::update(&mut labeled, part);
            Digest::update(&mut unlabeled, part);
        }
        Digest::update(&mut labeled, LABEL);
        let labeled: [u8; 32] = labeled.finalize().into();
        let unlabeled: [u8; 32] = unlabeled.finalize().into();

        // The label must actually be part of the hashed transcript.
        assert_eq!(result, labeled);
        assert_ne!(result, unlabeled);

        // Different input → different output
        let ss_mlkem2 = [0xEEu8; 32];
        let result2 = combine(&ss_mlkem2, &ss_x25519, &ct_x25519, &pk_x25519);
        assert_ne!(result, result2);
    }

    /// Seed expansion is deterministic, seed-dependent and 96 bytes long.
    #[test]
    fn seed_expansion_produces_96_bytes() {
        let seed = [0x42u8; SEED_LEN];
        let expanded = expand_seed(&seed);
        assert_eq!(expanded.len(), EXPANDED_SEED_LEN);

        // Deterministic
        let expanded2 = expand_seed(&seed);
        assert_eq!(expanded, expanded2);

        // Different seed → different expansion
        let seed2 = [0x43u8; SEED_LEN];
        let expanded3 = expand_seed(&seed2);
        assert_ne!(expanded, expanded3);
    }
}