native/vodozemac_nif/src/lib.rs

//! NIF entry points for the `olm` Elixir package.
//!
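//! Wraps vodozemac primitives. The Elixir-side public API lives in
//! `lib/olm.ex`; each NIF entry point declared in `lib/olm/native.ex`
//! has a corresponding `#[rustler::nif]` function here. The
//! `rustler::init!` call at the bottom registers the library as
//! `Elixir.Vodozemac.Native`; it lists no functions, so the
//! `#[rustler::nif]` items are registered automatically by rustler.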
//!
//! ## Conventions
//!
//! * **Pickle round-trip on every call.** vodozemac's `Account` and
//!   session types aren't `Send + Sync` in a way that makes putting
//!   them in a Rustler `ResourceArc` straightforward, so each NIF
//!   accepts a pickle binary and returns a fresh pickle. The cost is
//!   negligible for the call frequencies we hit (account ops are once
//!   per user per refill, not per message).
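//! * **Pickle key** is a fixed all-zero key for now, so pickles are
//!   effectively unencrypted at rest and must be stored as secrets.
//!   Phase 5 will swap this for a KDF-derived per-user key supplied by
//!   the caller.
//! * **Errors** come back as `{:error, reason}` tuples (`reason` is
//!   usually a short string such as `"bad_pickle"`) rather than panics.
//!   The only remaining panics are binary-allocation and pickling
//!   failures, which are treated as impossible.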
//!
//! See `apps/olm/README.md` for the phased rollout plan.

use rustler::Binary;
use rustler::Encoder;
use rustler::Env;
use rustler::NifResult;
use rustler::OwnedBinary;
use rustler::Term;
use vodozemac::megolm::GroupSession;
use vodozemac::megolm::GroupSessionPickle;
use vodozemac::megolm::InboundGroupSession;
use vodozemac::megolm::InboundGroupSessionPickle;
use vodozemac::megolm::MegolmMessage;
use vodozemac::megolm::SessionConfig as MegolmSessionConfig;
use vodozemac::megolm::SessionKey;
use vodozemac::olm::Account;
use vodozemac::olm::OlmMessage;
use vodozemac::olm::PreKeyMessage;
use vodozemac::olm::Session;
use vodozemac::olm::SessionConfig as OlmSessionConfig;
use vodozemac::olm::SessionPickle;
use vodozemac::sas::EstablishedSas;
use vodozemac::sas::Mac as SasMac;
use vodozemac::sas::Sas;
use vodozemac::Curve25519PublicKey;
use vodozemac::Curve25519SecretKey;
use vodozemac::Ed25519PublicKey;
use vodozemac::Ed25519Signature;
use rustler::ResourceArc;
use std::sync::Mutex;

mod atoms {
    rustler::atoms! {
        ok,
        error,
        bad_pickle,
        bad_message,
    }
}

/// Symmetric key used for every pickle (de)serialization in this crate.
///
/// All 32 bytes are zero, so pickles are "encrypted" under a publicly known
/// key and offer no confidentiality at rest. Treat stored pickles as secret
/// material until a per-user key replaces this constant.
const PICKLE_KEY: [u8; 32] = [0; 32];

// ── Helpers ────────────────────────────────────────────────────────

/// Restore an [`Account`] from a libolm-format pickle encrypted with
/// [`PICKLE_KEY`].
///
/// Errors with `"bad_pickle_utf8"` when the bytes are not UTF-8 and with
/// `"bad_pickle"` when vodozemac rejects the pickle.
fn unpickle(pickle: Binary) -> Result<Account, &'static str> {
    let text = match std::str::from_utf8(pickle.as_slice()) {
        Ok(text) => text,
        Err(_) => return Err("bad_pickle_utf8"),
    };
    match Account::from_libolm_pickle(text, &PICKLE_KEY) {
        Ok(account) => Ok(account),
        Err(_) => Err("bad_pickle"),
    }
}

/// Serialize `account` as a libolm-format pickle under [`PICKLE_KEY`] and
/// copy the resulting text into a fresh BEAM binary.
///
/// Panics if vodozemac cannot produce the pickle; that failure is treated
/// as impossible with a valid key.
fn pickle_to_binary<'a>(env: Env<'a>, account: &Account) -> Binary<'a> {
    let result = account.to_libolm_pickle(&PICKLE_KEY);
    let text = result.expect("libolm pickle never fails with valid key");
    string_to_binary(env, text.as_str())
}

/// Copy the UTF-8 bytes of `s` into a newly allocated BEAM binary owned by
/// `env`.
///
/// Panics with `"alloc"` if the VM cannot allocate the binary.
fn string_to_binary<'a>(env: Env<'a>, s: &str) -> Binary<'a> {
    let src = s.as_bytes();
    let mut owned = OwnedBinary::new(src.len()).expect("alloc");
    owned.as_mut_slice().copy_from_slice(src);
    owned.release(env)
}

// ── Phase 1: Account ───────────────────────────────────────────────

/// NIF: generate a brand-new Olm [`Account`] and return its libolm-format
/// pickle as a bare binary (no `{:ok, _}` wrapper).
#[rustler::nif]
fn account_create<'a>(env: Env<'a>) -> Binary<'a> {
    pickle_to_binary(env, &Account::new())
}

/// NIF: read the long-term identity keys out of an account pickle.
///
/// Returns `{:ok, {curve25519, ed25519}}`, each key rendered with
/// vodozemac's `to_base64()`. A malformed pickle produces an error built by
/// [`rustler_err`] (`"bad_pickle_utf8"` / `"bad_pickle"`).
#[rustler::nif]
fn account_identity_keys<'a>(env: Env<'a>, pickle: Binary) -> NifResult<Term<'a>> {
    let identity = unpickle(pickle).map_err(rustler_err)?.identity_keys();
    let curve = string_to_binary(env, &identity.curve25519.to_base64());
    let ed = string_to_binary(env, &identity.ed25519.to_base64());
    Ok((atoms::ok(), (curve, ed)).encode(env))
}

/// NIF: ask vodozemac to generate `count` new one-time keys.
///
/// Returns `{:ok, new_account_pickle, keys}` where `keys` lists every
/// currently unpublished one-time key as `{key_id, public_key}`, both in
/// vodozemac's `to_base64()` form. The expected follow-up is:
/// 1. sign each key's payload with `account_sign/2`;
/// 2. upload the keys via `/keys/upload`;
/// 3. call `account_mark_published/1` and persist its pickle.
#[rustler::nif]
fn account_generate_one_time_keys<'a>(
    env: Env<'a>,
    pickle: Binary,
    count: u32,
) -> NifResult<Term<'a>> {
    let mut account = unpickle(pickle).map_err(rustler_err)?;
    account.generate_one_time_keys(count as usize);

    let unpublished = account.one_time_keys();
    let mut pending = Vec::with_capacity(unpublished.len());
    for (key_id, public_key) in &unpublished {
        pending.push((
            string_to_binary(env, &key_id.to_base64()),
            string_to_binary(env, &public_key.to_base64()),
        ));
    }

    let updated = pickle_to_binary(env, &account);
    Ok((atoms::ok(), updated, pending).encode(env))
}

/// NIF: `{:ok, n}` where `n` is vodozemac's
/// `Account::max_number_of_one_time_keys()` for the pickled account.
/// Callers presumably use it to size one-time-key refills. TODO confirm
/// against the Elixir caller.
#[rustler::nif]
fn account_max_one_time_keys<'a>(env: Env<'a>, pickle: Binary) -> NifResult<Term<'a>> {
    let limit = unpickle(pickle)
        .map_err(rustler_err)?
        .max_number_of_one_time_keys();
    Ok((atoms::ok(), limit).encode(env))
}

/// NIF: flag every pending one-time key as published.
///
/// Call this once `/keys/upload` has succeeded. Returns
/// `{:ok, new_account_pickle}`, which the caller must persist.
#[rustler::nif]
fn account_mark_published<'a>(env: Env<'a>, pickle: Binary) -> NifResult<Term<'a>> {
    unpickle(pickle)
        .map(|mut account| {
            account.mark_keys_as_published();
            (atoms::ok(), pickle_to_binary(env, &account)).encode(env)
        })
        .map_err(rustler_err)
}

/// NIF: sign the raw bytes of `message` with the account's Ed25519 identity
/// key.
///
/// Returns `{:ok, signature}` with the signature in vodozemac's
/// `to_base64()` form. This is used to sign each one-time key's
/// canonical-JSON payload before upload; the caller does the
/// canonicalisation.
#[rustler::nif]
fn account_sign<'a>(env: Env<'a>, pickle: Binary, message: Binary) -> NifResult<Term<'a>> {
    let account = unpickle(pickle).map_err(rustler_err)?;
    let signature_b64 = account.sign(message.as_slice()).to_base64();
    Ok((atoms::ok(), string_to_binary(env, &signature_b64)).encode(env))
}

/// Round-trip an account pickle to verify it parses. Returns the
/// same pickle re-serialized (semantically a no-op, but useful for
/// schema migrations and key-rotation flows).
#[rustler::nif]
fn account_unpickle<'a>(env: Env<'a>, pickle: Binary) -> NifResult<Term<'a>> {
    let account = unpickle(pickle).map_err(rustler_err)?;
    Ok((atoms::ok(), pickle_to_binary(env, &account)).encode(env))
}

/// Verify an Ed25519 signature over `message` using the given
/// `signer_ed25519_b64` public key. Useful for verifying inbound
/// device keys before trusting them. Returns `:ok` or
/// `{:error, :bad_signature}`.
#[rustler::nif]
fn verify_ed25519<'a>(
    env: Env<'a>,
    signer_ed25519_b64: Binary,
    message: Binary,
    signature_b64: Binary,
) -> Term<'a> {
    let signer = match std::str::from_utf8(signer_ed25519_b64.as_slice())
        .ok()
        .and_then(|s| Ed25519PublicKey::from_base64(s).ok())
    {
        Some(k) => k,
        None => return (atoms::error(), atoms::bad_message()).encode(env),
    };
    let signature = match std::str::from_utf8(signature_b64.as_slice())
        .ok()
        .and_then(|s| Ed25519Signature::from_base64(s).ok())
    {
        Some(s) => s,
        None => return (atoms::error(), atoms::bad_message()).encode(env),
    };

    match signer.verify(message.as_slice(), &signature) {
        Ok(()) => atoms::ok().encode(env),
        Err(_) => (atoms::error(), atoms::bad_message()).encode(env),
    }
}

fn rustler_err(msg: &'static str) -> rustler::Error {
    rustler::Error::Term(Box::new(msg))
}

// ── Phase 2: Inbound Megolm group sessions ─────────────────────────

/// Decrypt (with [`PICKLE_KEY`]) and restore an [`InboundGroupSession`]
/// from its vodozemac pickle text.
///
/// Errors: `"bad_pickle_utf8"` or `"bad_pickle"`.
fn unpickle_inbound_group(pickle: Binary) -> Result<InboundGroupSession, &'static str> {
    let text = std::str::from_utf8(pickle.as_slice()).map_err(|_| "bad_pickle_utf8")?;
    InboundGroupSessionPickle::from_encrypted(text, &PICKLE_KEY)
        .map(InboundGroupSession::from_pickle)
        .map_err(|_| "bad_pickle")
}

/// Encrypt `session`'s pickle under [`PICKLE_KEY`] and return it as a BEAM
/// binary.
fn pickle_inbound_group<'a>(env: Env<'a>, session: &InboundGroupSession) -> Binary<'a> {
    string_to_binary(env, &session.pickle().encrypt(&PICKLE_KEY))
}

/// NIF: build an [`InboundGroupSession`] from a base64 Megolm session key,
/// i.e. the `session_key` field of an `m.room_key` to-device event.
///
/// Returns `{:ok, session_id, pickle}`.
///
/// Errors: `"bad_session_key_utf8"` or `"bad_session_key"`.
#[rustler::nif]
fn inbound_group_session_create<'a>(
    env: Env<'a>,
    session_key_b64: Binary,
) -> NifResult<Term<'a>> {
    let key_text = match std::str::from_utf8(session_key_b64.as_slice()) {
        Ok(text) => text,
        Err(_) => return Err(rustler_err("bad_session_key_utf8")),
    };
    let session_key = match SessionKey::from_base64(key_text) {
        Ok(key) => key,
        Err(_) => return Err(rustler_err("bad_session_key")),
    };

    let session = InboundGroupSession::new(&session_key, MegolmSessionConfig::version_1());
    let id = string_to_binary(env, &session.session_id());
    let pickled = pickle_inbound_group(env, &session);
    Ok((atoms::ok(), id, pickled).encode(env))
}

/// NIF: decrypt the `ciphertext` field of a timeline `m.room.encrypted`
/// event with an inbound Megolm session.
///
/// Returns `{:ok, plaintext, message_index, new_pickle}`. The session is
/// re-pickled after `decrypt` because it takes the session mutably. The
/// original author's note is that Megolm tracks seen message indices for
/// replay defence.
///
/// Errors: pickle errors, `"bad_ciphertext_utf8"`, `"bad_megolm_message"`,
/// or `"megolm_decrypt_failed"`.
#[rustler::nif]
fn inbound_group_session_decrypt<'a>(
    env: Env<'a>,
    pickle: Binary,
    ciphertext_b64: Binary,
) -> NifResult<Term<'a>> {
    let mut session = unpickle_inbound_group(pickle).map_err(rustler_err)?;
    let message = std::str::from_utf8(ciphertext_b64.as_slice())
        .map_err(|_| rustler_err("bad_ciphertext_utf8"))
        .and_then(|text| {
            MegolmMessage::from_base64(text).map_err(|_| rustler_err("bad_megolm_message"))
        })?;

    let decrypted = match session.decrypt(&message) {
        Ok(out) => out,
        Err(_) => return Err(rustler_err("megolm_decrypt_failed")),
    };

    let mut plain = OwnedBinary::new(decrypted.plaintext.len()).expect("alloc");
    plain.as_mut_slice().copy_from_slice(&decrypted.plaintext);
    let index = decrypted.message_index;

    Ok((
        atoms::ok(),
        plain.release(env),
        index,
        pickle_inbound_group(env, &session),
    )
        .encode(env))
}

/// NIF: `{:ok, session_id}` for an inbound Megolm session pickle.
#[rustler::nif]
fn inbound_group_session_id<'a>(env: Env<'a>, pickle: Binary) -> NifResult<Term<'a>> {
    let id = unpickle_inbound_group(pickle).map_err(rustler_err)?.session_id();
    Ok((atoms::ok(), string_to_binary(env, &id)).encode(env))
}

// ── Phase 2: Outbound Megolm group sessions ────────────────────────
//
// Used for round-trip self-tests today; phase 3 wires it into the
// actual outbound send path.

/// Decrypt (with [`PICKLE_KEY`]) and restore an outbound [`GroupSession`]
/// from its vodozemac pickle text.
///
/// Errors: `"bad_pickle_utf8"` or `"bad_pickle"`.
fn unpickle_outbound_group(pickle: Binary) -> Result<GroupSession, &'static str> {
    let text = std::str::from_utf8(pickle.as_slice()).map_err(|_| "bad_pickle_utf8")?;
    GroupSessionPickle::from_encrypted(text, &PICKLE_KEY)
        .map(GroupSession::from_pickle)
        .map_err(|_| "bad_pickle")
}

/// Encrypt an outbound session's pickle under [`PICKLE_KEY`] and return it
/// as a BEAM binary.
fn pickle_outbound_group<'a>(env: Env<'a>, session: &GroupSession) -> Binary<'a> {
    string_to_binary(env, &session.pickle().encrypt(&PICKLE_KEY))
}

/// NIF: start a new outbound Megolm group session.
///
/// Returns `{:ok, session_id, session_key, pickle}`. `session_key` is the
/// value to wrap in `m.room_key` to-device events and share with each
/// recipient device over Olm.
#[rustler::nif]
fn outbound_group_session_create<'a>(env: Env<'a>) -> Term<'a> {
    let session = GroupSession::new(MegolmSessionConfig::version_1());
    let id_bin = string_to_binary(env, &session.session_id());
    let key_bin = string_to_binary(env, &session.session_key().to_base64());
    let pickle_bin = pickle_outbound_group(env, &session);
    (atoms::ok(), id_bin, key_bin, pickle_bin).encode(env)
}

/// NIF: Megolm-encrypt `plaintext` with an outbound session.
///
/// Returns `{:ok, ciphertext, new_pickle}`. The session is re-pickled
/// because encryption advances it.
#[rustler::nif]
fn outbound_group_session_encrypt<'a>(
    env: Env<'a>,
    pickle: Binary,
    plaintext: Binary,
) -> NifResult<Term<'a>> {
    let mut session = unpickle_outbound_group(pickle).map_err(rustler_err)?;
    let encoded = session.encrypt(plaintext.as_slice()).to_base64();
    let ciphertext_bin = string_to_binary(env, &encoded);
    let pickle_bin = pickle_outbound_group(env, &session);
    Ok((atoms::ok(), ciphertext_bin, pickle_bin).encode(env))
}

/// NIF: `{:ok, index}` with the outbound session's current message index.
#[rustler::nif]
fn outbound_group_session_message_index<'a>(
    env: Env<'a>,
    pickle: Binary,
) -> NifResult<Term<'a>> {
    unpickle_outbound_group(pickle)
        .map(|session| (atoms::ok(), session.message_index()).encode(env))
        .map_err(rustler_err)
}

/// NIF: export the current Megolm session key of a live outbound session as
/// `{:ok, session_key}`.
///
/// The exported key reflects the session's current ratchet state, so a
/// late recipient keyed from it can decrypt only from this index forward.
/// This matches the Matrix "joined after history" behaviour.
#[rustler::nif]
fn outbound_group_session_key<'a>(env: Env<'a>, pickle: Binary) -> NifResult<Term<'a>> {
    let session = unpickle_outbound_group(pickle).map_err(rustler_err)?;
    let exported = session.session_key().to_base64();
    Ok((atoms::ok(), string_to_binary(env, &exported)).encode(env))
}

#[rustler::nif]
fn outbound_group_session_id<'a>(env: Env<'a>, pickle: Binary) -> NifResult<Term<'a>> {
    let session = unpickle_outbound_group(pickle).map_err(rustler_err)?;
    Ok((atoms::ok(), string_to_binary(env, &session.session_id())).encode(env))
}

// ── Phase 2 / 3: Olm pairwise sessions ─────────────────────────────

/// Decrypt (with [`PICKLE_KEY`]) and restore a pairwise Olm [`Session`]
/// from its vodozemac pickle text.
///
/// Errors: `"bad_pickle_utf8"` or `"bad_pickle"`.
fn unpickle_olm_session(pickle: Binary) -> Result<Session, &'static str> {
    let text = std::str::from_utf8(pickle.as_slice()).map_err(|_| "bad_pickle_utf8")?;
    SessionPickle::from_encrypted(text, &PICKLE_KEY)
        .map(Session::from_pickle)
        .map_err(|_| "bad_pickle")
}

/// Encrypt an Olm session's pickle under [`PICKLE_KEY`] and return it as a
/// BEAM binary.
fn pickle_olm_session<'a>(env: Env<'a>, session: &Session) -> Binary<'a> {
    string_to_binary(env, &session.pickle().encrypt(&PICKLE_KEY))
}

/// NIF: open an outbound Olm session (config version 2) to a peer device.
///
/// Takes the peer's Curve25519 identity key and one of its one-time keys,
/// both as base64. This NIF does not check the one-time key's signature;
/// the caller must have verified it already.
///
/// The account is only read here, never modified. The returned account
/// pickle is a re-serialization of the input, kept in the tuple for a
/// uniform calling convention.
///
/// Returns `{:ok, session_id, session_pickle, account_pickle}`.
///
/// Errors: account pickle errors, `"bad_key_utf8"`, or `"bad_curve25519"`.
#[rustler::nif]
fn olm_session_create_outbound<'a>(
    env: Env<'a>,
    account_pickle: Binary,
    their_curve25519_b64: Binary,
    their_one_time_key_b64: Binary,
) -> NifResult<Term<'a>> {
    let account = unpickle(account_pickle).map_err(rustler_err)?;
    let identity_key = parse_curve25519(&their_curve25519_b64)?;
    let one_time_key = parse_curve25519(&their_one_time_key_b64)?;

    let session =
        account.create_outbound_session(OlmSessionConfig::version_2(), identity_key, one_time_key);

    let session_id = string_to_binary(env, &session.session_id());
    let session_pickle = pickle_olm_session(env, &session);
    let account_repickled = pickle_to_binary(env, &account);
    Ok((atoms::ok(), session_id, session_pickle, account_repickled).encode(env))
}

/// NIF: decrypt a pre-key Olm message (`type: 0`) and establish the inbound
/// session it carries.
///
/// This is the path used the first time a peer sends us an
/// `m.olm.v1.curve25519-aes-sha2` to-device event.
///
/// `create_inbound_session` takes the account mutably (vodozemac consumes
/// the one-time key the sender used), so the returned account pickle must
/// be persisted.
///
/// Returns `{:ok, plaintext, session_id, session_pickle, new_account_pickle}`.
///
/// Errors: account pickle errors, `"bad_key_utf8"`, `"bad_curve25519"`,
/// `"bad_ciphertext_utf8"`, `"bad_pre_key_message"`, or
/// `"inbound_session_failed"`.
#[rustler::nif]
fn olm_session_create_inbound<'a>(
    env: Env<'a>,
    account_pickle: Binary,
    their_curve25519_b64: Binary,
    ciphertext_b64: Binary,
) -> NifResult<Term<'a>> {
    let mut account = unpickle(account_pickle).map_err(rustler_err)?;
    let sender_key = parse_curve25519(&their_curve25519_b64)?;

    let text = std::str::from_utf8(ciphertext_b64.as_slice())
        .map_err(|_| rustler_err("bad_ciphertext_utf8"))?;
    let prekey_message = match PreKeyMessage::from_base64(text) {
        Ok(message) => message,
        Err(_) => return Err(rustler_err("bad_pre_key_message")),
    };

    let created = match account.create_inbound_session(sender_key, &prekey_message) {
        Ok(created) => created,
        Err(_) => return Err(rustler_err("inbound_session_failed")),
    };

    let mut plain = OwnedBinary::new(created.plaintext.len()).expect("alloc");
    plain.as_mut_slice().copy_from_slice(&created.plaintext);

    Ok((
        atoms::ok(),
        plain.release(env),
        string_to_binary(env, &created.session.session_id()),
        pickle_olm_session(env, &created.session),
        pickle_to_binary(env, &account),
    )
        .encode(env))
}

/// NIF: Olm-encrypt `plaintext` with an existing pairwise session.
///
/// Returns `{:ok, message_type, ciphertext, new_pickle}`, where
/// `message_type` is `0` for a pre-key message and `1` for a normal
/// message (the Matrix wire values).
#[rustler::nif]
fn olm_session_encrypt<'a>(
    env: Env<'a>,
    pickle: Binary,
    plaintext: Binary,
) -> NifResult<Term<'a>> {
    let mut session = unpickle_olm_session(pickle).map_err(rustler_err)?;
    let (kind, body) = match session.encrypt(plaintext.as_slice()) {
        OlmMessage::Normal(normal) => (1u32, normal.to_base64()),
        OlmMessage::PreKey(prekey) => (0u32, prekey.to_base64()),
    };

    let ciphertext_bin = string_to_binary(env, &body);
    let pickle_bin = pickle_olm_session(env, &session);
    Ok((atoms::ok(), kind, ciphertext_bin, pickle_bin).encode(env))
}

/// NIF: decrypt an Olm message with an existing pairwise session.
///
/// `message_type` is `0` (pre-key) or `1` (normal); any other value is
/// rejected with `"bad_message_type"`. Returns
/// `{:ok, plaintext, new_pickle}`.
///
/// Errors: pickle errors, `"bad_ciphertext_utf8"`, `"bad_pre_key_message"`,
/// `"bad_olm_message"`, `"bad_message_type"`, or `"olm_decrypt_failed"`.
#[rustler::nif]
fn olm_session_decrypt<'a>(
    env: Env<'a>,
    pickle: Binary,
    message_type: u32,
    ciphertext_b64: Binary,
) -> NifResult<Term<'a>> {
    let mut session = unpickle_olm_session(pickle).map_err(rustler_err)?;
    let text = std::str::from_utf8(ciphertext_b64.as_slice())
        .map_err(|_| rustler_err("bad_ciphertext_utf8"))?;

    let message = if message_type == 0 {
        PreKeyMessage::from_base64(text)
            .map(OlmMessage::PreKey)
            .map_err(|_| rustler_err("bad_pre_key_message"))?
    } else if message_type == 1 {
        vodozemac::olm::Message::from_base64(text)
            .map(OlmMessage::Normal)
            .map_err(|_| rustler_err("bad_olm_message"))?
    } else {
        return Err(rustler_err("bad_message_type"));
    };

    let decrypted = session
        .decrypt(&message)
        .map_err(|_| rustler_err("olm_decrypt_failed"))?;

    let mut plain = OwnedBinary::new(decrypted.len()).expect("alloc");
    plain.as_mut_slice().copy_from_slice(&decrypted);

    Ok((atoms::ok(), plain.release(env), pickle_olm_session(env, &session)).encode(env))
}

/// NIF: `{:ok, session_id}` for a pairwise Olm session pickle.
#[rustler::nif]
fn olm_session_id<'a>(env: Env<'a>, pickle: Binary) -> NifResult<Term<'a>> {
    unpickle_olm_session(pickle)
        .map(|session| (atoms::ok(), string_to_binary(env, &session.session_id())).encode(env))
        .map_err(rustler_err)
}

/// Parse a base64 Curve25519 public key held in a BEAM binary.
///
/// Errors: `"bad_key_utf8"` or `"bad_curve25519"`.
fn parse_curve25519(b: &Binary) -> Result<Curve25519PublicKey, rustler::Error> {
    match std::str::from_utf8(b.as_slice()) {
        Ok(text) => {
            Curve25519PublicKey::from_base64(text).map_err(|_| rustler_err("bad_curve25519"))
        }
        Err(_) => Err(rustler_err("bad_key_utf8")),
    }
}

// ── Phase 5: Cross-signing (Ed25519 keypairs) ──────────────────────
//
// Matrix cross-signing uses three Ed25519 keypairs per user:
//   * Master Signing Key (MSK)  — root of trust
//   * Self-Signing Key (SSK)    — signs the user's own devices
//   * User-Signing Key (USK)    — signs other users (for trust-through-friends)
//
// vodozemac doesn't expose cross-signing-specific types; we just
// generate raw `Ed25519SecretKey` instances and use them in the
// Matrix-spec-required signing contexts. The signing is plain
// Ed25519 over canonical-JSON; the caller does the canonical-JSON
// formatting.

/// NIF: generate a fresh Ed25519 keypair.
///
/// Returns `{:ok, secret, public}`, both in vodozemac's `to_base64()` form.
/// The secret leaves Rust as a plain binary, so the caller is responsible
/// for protecting it.
#[rustler::nif]
fn ed25519_keypair_new<'a>(env: Env<'a>) -> Term<'a> {
    let secret = vodozemac::Ed25519SecretKey::new();
    let secret_bin = string_to_binary(env, &secret.to_base64());
    let public_bin = string_to_binary(env, &secret.public_key().to_base64());
    (atoms::ok(), secret_bin, public_bin).encode(env)
}

/// NIF: sign the raw bytes of `message` with a base64 Ed25519 secret key.
///
/// Used for every cross-signing signature (master signing self-signing,
/// self-signing signing devices, and so on). Returns `{:ok, signature}` in
/// vodozemac's `to_base64()` form.
///
/// Errors: `"bad_secret_utf8"` or `"bad_secret"`.
#[rustler::nif]
fn ed25519_sign<'a>(
    env: Env<'a>,
    secret_b64: Binary,
    message: Binary,
) -> NifResult<Term<'a>> {
    let secret = std::str::from_utf8(secret_b64.as_slice())
        .map_err(|_| rustler_err("bad_secret_utf8"))
        .and_then(|text| {
            vodozemac::Ed25519SecretKey::from_base64(text).map_err(|_| rustler_err("bad_secret"))
        })?;

    let signature = secret.sign(message.as_slice()).to_base64();
    Ok((atoms::ok(), string_to_binary(env, &signature)).encode(env))
}

/// NIF: derive `{:ok, public_key}` (base64) from a base64 Ed25519 secret.
///
/// Handy when only the secret is stored, e.g. on restart when building the
/// `keys` field of a cross-signing payload.
///
/// Errors: `"bad_secret_utf8"` or `"bad_secret"`.
#[rustler::nif]
fn ed25519_public_key<'a>(env: Env<'a>, secret_b64: Binary) -> NifResult<Term<'a>> {
    let text = match std::str::from_utf8(secret_b64.as_slice()) {
        Ok(text) => text,
        Err(_) => return Err(rustler_err("bad_secret_utf8")),
    };
    let secret = match vodozemac::Ed25519SecretKey::from_base64(text) {
        Ok(secret) => secret,
        Err(_) => return Err(rustler_err("bad_secret")),
    };

    let public_b64 = secret.public_key().to_base64();
    Ok((atoms::ok(), string_to_binary(env, &public_b64)).encode(env))
}

// ── Phase 5c: Curve25519 (Megolm backup) ───────────────────────────
//
// The server-side key backup (`m.megolm_backup.v1.curve25519-aes-sha2`)
// uses a per-user Curve25519 keypair: the public half encrypts every
// stored Megolm session under an ephemeral-static ECDH wrap; the
// private half is itself stored as an SSSS secret so peers can
// restore the backup with just the recovery key.

/// NIF: generate a fresh Curve25519 keypair for Megolm key backup.
///
/// Returns `{:ok, secret, public}`:
/// - `secret` is the raw 32 bytes encoded by [`Base64Encode::encode_secret`]
///   (unpadded standard base64);
/// - `public` is vodozemac's `to_base64()` form.
#[rustler::nif]
fn curve25519_keypair_new<'a>(env: Env<'a>) -> Term<'a> {
    let secret = Curve25519SecretKey::new();
    let secret_b64 = Base64Encode::encode_secret(&secret);
    let public_b64 = Curve25519PublicKey::from(&secret).to_base64();

    (
        atoms::ok(),
        string_to_binary(env, &secret_b64),
        string_to_binary(env, &public_b64),
    )
        .encode(env)
}

/// NIF: X25519 Diffie-Hellman between our base64 secret key and a peer's
/// base64 public key.
///
/// Returns `{:ok, shared}`, where `shared` is the 32-byte shared secret as
/// raw bytes (NOT base64) so the caller can feed it straight into HKDF.
///
/// The secret may be padded or unpadded standard base64 (see
/// [`base64_decode`]) and must decode to exactly 32 bytes.
///
/// Errors:
/// - `"bad_secret_utf8"`, `"bad_base64"`, `"bad_secret_length"` for a bad
///   secret;
/// - `"bad_key_utf8"`, `"bad_curve25519"` for a bad peer key;
/// - `"non_contributory_ecdh"` when the peer key is a low-order point. In
///   that case the DH output does not depend on our secret, so any key
///   derived from it would be known to the attacker.
#[rustler::nif]
fn curve25519_ecdh<'a>(
    env: Env<'a>,
    our_secret_b64: Binary,
    their_public_b64: Binary,
) -> NifResult<Term<'a>> {
    let secret_str = std::str::from_utf8(our_secret_b64.as_slice())
        .map_err(|_| rustler_err("bad_secret_utf8"))?;
    let secret_bytes = base64_decode(secret_str)?;
    let secret_arr: [u8; 32] = secret_bytes
        .as_slice()
        .try_into()
        .map_err(|_| rustler_err("bad_secret_length"))?;
    let secret = Curve25519SecretKey::from_slice(&secret_arr);

    let their_public = parse_curve25519(&their_public_b64)?;

    let shared = secret.diffie_hellman(&their_public);
    // A low-order peer key makes X25519 emit a fixed (all-zero) output
    // regardless of our secret (RFC 7748 §6.1). The peer key here may come
    // from server-stored backup data, so refuse instead of handing back a
    // predictable "secret".
    if !shared.was_contributory() {
        return Err(rustler_err("non_contributory_ecdh"));
    }

    let bytes = shared.as_bytes();
    let mut bin = OwnedBinary::new(bytes.len()).expect("alloc");
    bin.as_mut_slice().copy_from_slice(bytes);
    Ok((atoms::ok(), Binary::from_owned(bin, env)).encode(env))
}

// vodozemac offers no `to_base64` for `Curve25519SecretKey` (the original
// author guesses this is deliberate, to discourage persisting secrets), so
// the raw key bytes are encoded here by hand.
struct Base64Encode;

impl Base64Encode {
    /// Encode the 32 raw bytes of `secret` as unpadded standard base64.
    fn encode_secret(secret: &Curve25519SecretKey) -> String {
        use base64::Engine as _;
        let raw = secret.to_bytes();
        base64::engine::general_purpose::STANDARD_NO_PAD.encode(&raw[..])
    }
}

/// Decode standard-alphabet base64 with or without `=` padding.
///
/// The unpadded form is tried first and the padded form second. Failure of
/// both maps to `rustler_err("bad_base64")`.
fn base64_decode(s: &str) -> Result<Vec<u8>, rustler::Error> {
    use base64::engine::general_purpose::{STANDARD, STANDARD_NO_PAD};
    use base64::Engine as _;
    match STANDARD_NO_PAD.decode(s) {
        Ok(bytes) => Ok(bytes),
        Err(_) => STANDARD.decode(s).map_err(|_| rustler_err("bad_base64")),
    }
}

// ── Phase 4b: SAS verification ─────────────────────────────────────
//
// Short Authentication String — the in-band emoji comparison flow
// users go through when verifying one of their devices from another.
// `Sas` holds an ephemeral curve25519 secret that vodozemac won't
// serialize (correctly so — leaking it would let an attacker who
// captures the to-device transcript later derive the same SAS bytes
// and impersonate the verification). We therefore keep live `Sas`
// state inside the BEAM via Rustler `ResourceArc`s and pass an
// opaque handle back and forth between Elixir and Rust.

/// Pre-key-agreement SAS state behind a `ResourceArc` handle.
///
/// The `Option` lets `sas_diffie_hellman` move the `Sas` out with `take()`
/// and leave `None` behind: `Sas::diffie_hellman` consumes `self`, so the
/// state is single-use. The `Mutex` supplies the interior mutability that
/// `take()` needs through a shared handle.
pub struct SasResource(Mutex<Option<Sas>>);
/// Post-key-agreement SAS state. It is only read through `&` (`bytes`,
/// `calculate_mac`, `verify_mac`), so it needs no lock.
pub struct EstablishedSasResource(EstablishedSas);

// Register `SasResource` as a rustler resource type so it can be handed to
// Elixir inside a `ResourceArc`.
#[rustler::resource_impl]
impl rustler::Resource for SasResource {}

// Same registration for the established (post-DH) SAS state.
#[rustler::resource_impl]
impl rustler::Resource for EstablishedSasResource {}

/// NIF: start a new SAS verification.
///
/// Returns `{:ok, handle, our_public_key}`:
/// - `handle` is an opaque `ResourceArc` to pass to `sas_diffie_hellman`;
/// - `our_public_key` (base64) is what goes to the peer in
///   `m.key.verification.key`.
#[rustler::nif]
fn sas_new<'a>(env: Env<'a>) -> Term<'a> {
    let state = Sas::new();
    let our_public_b64 = state.public_key().to_base64();
    let handle = ResourceArc::new(SasResource(Mutex::new(Some(state))));
    (atoms::ok(), handle, string_to_binary(env, &our_public_b64)).encode(env)
}

/// NIF: complete the SAS key agreement with the peer's base64 Curve25519
/// public key.
///
/// Returns `{:ok, established_handle}`. That handle is what `sas_bytes`,
/// `sas_calculate_mac` and `sas_verify_mac` operate on.
///
/// The pre-DH state in `sas_resource` is single-use: vodozemac's
/// `Sas::diffie_hellman` consumes it. After the first call that reaches the
/// key agreement (success or `"dh_failed"`), later calls return
/// `"sas_consumed"`. The peer key is validated *before* the state is taken,
/// so a malformed key (`"bad_public_utf8"` / `"bad_public"`) leaves the
/// handle usable for a retry.
///
/// Errors:
/// - `"sas_locked"` — the mutex is poisoned;
/// - `"sas_consumed"`;
/// - `"bad_public_utf8"`, `"bad_public"`;
/// - `"dh_failed"`.
#[rustler::nif]
fn sas_diffie_hellman<'a>(
    env: Env<'a>,
    sas_resource: ResourceArc<SasResource>,
    their_public_b64: Binary,
) -> NifResult<Term<'a>> {
    let mut guard = sas_resource
        .0
        .lock()
        .map_err(|_| rustler_err("sas_locked"))?;

    // Report an already-consumed handle first, as before. Do not take the
    // state out until the peer key parses: previously a malformed key
    // destroyed the local SAS state irrecoverably.
    if guard.is_none() {
        return Err(rustler_err("sas_consumed"));
    }

    let their_pub_str = std::str::from_utf8(their_public_b64.as_slice())
        .map_err(|_| rustler_err("bad_public_utf8"))?;
    let their_pub = Curve25519PublicKey::from_base64(their_pub_str)
        .map_err(|_| rustler_err("bad_public"))?;

    let sas = guard.take().ok_or_else(|| rustler_err("sas_consumed"))?;
    // The state has been moved out; release the lock before the DH work.
    drop(guard);

    let established = sas
        .diffie_hellman(their_pub)
        .map_err(|_| rustler_err("dh_failed"))?;

    let resource = ResourceArc::new(EstablishedSasResource(established));
    Ok((atoms::ok(), resource).encode(env))
}

/// NIF: derive the short authentication string for the Matrix `info`
/// string.
///
/// Returns a flat 3-tuple `{:ok, emoji_indices, {d1, d2, d3}}`, not a
/// nested `{:ok, {...}}`:
/// - `emoji_indices` is the 7-element list from vodozemac's
///   `emoji_indices()` (indices 0–63 into the Matrix SAS emoji table);
/// - `{d1, d2, d3}` are the three `u16` numbers from `decimals()`.
///
/// Errors with `"bad_info_utf8"` if `info` is not UTF-8.
#[rustler::nif]
fn sas_bytes<'a>(
    env: Env<'a>,
    sas: ResourceArc<EstablishedSasResource>,
    info: Binary,
) -> NifResult<Term<'a>> {
    let info_str = std::str::from_utf8(info.as_slice())
        .map_err(|_| rustler_err("bad_info_utf8"))?;
    let derived = sas.0.bytes(info_str);

    let emoji_list: Vec<u8> = derived.emoji_indices().to_vec();
    let (d1, d2, d3) = derived.decimals();
    Ok((atoms::ok(), emoji_list, (d1, d2, d3)).encode(env))
}

/// NIF: compute the SAS MAC of `input` under `info` and return
/// `{:ok, mac_base64}`.
///
/// Used in the `m.key.verification.mac` step to authenticate the Ed25519
/// keys we are committing to.
///
/// Errors: `"bad_input_utf8"` or `"bad_info_utf8"`.
#[rustler::nif]
fn sas_calculate_mac<'a>(
    env: Env<'a>,
    sas: ResourceArc<EstablishedSasResource>,
    input: Binary,
    info: Binary,
) -> NifResult<Term<'a>> {
    let input_str = match std::str::from_utf8(input.as_slice()) {
        Ok(text) => text,
        Err(_) => return Err(rustler_err("bad_input_utf8")),
    };
    let info_str = match std::str::from_utf8(info.as_slice()) {
        Ok(text) => text,
        Err(_) => return Err(rustler_err("bad_info_utf8")),
    };

    let mac_b64 = sas.0.calculate_mac(input_str, info_str).to_base64();
    Ok((atoms::ok(), string_to_binary(env, &mac_b64)).encode(env))
}

/// Verify a peer-computed MAC. Returns `:ok` or
/// `{:error, :bad_mac}` if the MAC doesn't match.
#[rustler::nif]
fn sas_verify_mac<'a>(
    env: Env<'a>,
    sas: ResourceArc<EstablishedSasResource>,
    input: Binary,
    info: Binary,
    tag_b64: Binary,
) -> Term<'a> {
    let result: Result<(), &'static str> = (|| {
        let input_str = std::str::from_utf8(input.as_slice()).map_err(|_| "bad_input_utf8")?;
        let info_str = std::str::from_utf8(info.as_slice()).map_err(|_| "bad_info_utf8")?;
        let tag_str = std::str::from_utf8(tag_b64.as_slice()).map_err(|_| "bad_tag_utf8")?;
        let tag = SasMac::from_base64(tag_str).map_err(|_| "bad_tag")?;
        sas.0.verify_mac(input_str, info_str, &tag).map_err(|_| "bad_mac")
    })();

    match result {
        Ok(()) => atoms::ok().encode(env),
        Err(_) => (atoms::error(), atoms::bad_message()).encode(env),
    }
}

// Register this NIF library under the Elixir module `Vodozemac.Native`.
// No function list is passed here, so this presumably relies on rustler's
// automatic registration of `#[rustler::nif]` items (rustler >= 0.34,
// consistent with the `#[rustler::resource_impl]` usage above). TODO
// confirm against the crate's pinned rustler version.
rustler::init!("Elixir.Vodozemac.Native");