Skip to main content

native/erltoken_nif/src/lib.rs

use rustler::{atoms, Binary, Encoder, Env, NewBinary, Term};
use tiktoken::CoreBpe;

// Atoms used by the NIFs in this file: the `ok` / `error` reply tags and the
// reasons that accompany `error`.
atoms! {
    // Tag of a successful `(ok, value)` reply.
    ok,
    // Tag of a failed `(error, reason)` reply.
    error,
    // Reason: a binary argument failed UTF-8 validation.
    invalid_utf8,
    // Reason: no encoding could be resolved from the given name, neither
    // directly nor by treating it as a model name.
    unknown_encoding_or_model,
    // Reason: the tiktoken lookup for the given model name returned nothing.
    unknown_model
}

/// Allocates a new binary of `bytes.len()` bytes in `env`, copies `bytes`
/// into it and returns the result as a term.
fn binary_term<'a>(env: Env<'a>, bytes: &[u8]) -> Term<'a> {
    let len = bytes.len();
    let mut fresh = NewBinary::new(env, len);

    // Source and destination have the same length by construction, which is
    // what `copy_from_slice` requires.
    let dest = fresh.as_mut_slice();
    dest.copy_from_slice(bytes);

    let term: Term<'a> = fresh.into();
    term
}

/// Returns a binary term containing the UTF-8 bytes of `value`.
fn binary_str_term<'a>(env: Env<'a>, value: &str) -> Term<'a> {
    let utf8 = value.as_bytes();
    binary_term(env, utf8)
}

/// Views the contents of `bin` as a string slice.
///
/// Returns `Err` carrying the `invalid_utf8` atom (already encoded as a term)
/// when the bytes are not valid UTF-8.
fn binary_to_str<'a>(env: Env<'a>, bin: Binary<'a>) -> Result<&'a str, Term<'a>> {
    match std::str::from_utf8(bin.as_slice()) {
        Ok(text) => Ok(text),
        Err(_) => Err(invalid_utf8().encode(env)),
    }
}

/// Looks up an encoding from `name`.
///
/// `name` is first tried as an encoding name. If that fails it is tried as a
/// model name, and the encoding mapped to that model is looked up instead.
/// Returns `None` when both lookups fail.
fn resolve_encoding(name: &str) -> Option<&'static CoreBpe> {
    if let Some(encoding) = tiktoken::get_encoding(name) {
        return Some(encoding);
    }

    // Not an encoding name: fall back to the model -> encoding mapping.
    let encoding_name = tiktoken::model_to_encoding(name)?;
    tiktoken::get_encoding(encoding_name)
}

/// NIF: returns a list of binaries, one per name yielded by
/// `tiktoken::list_encodings()`, in the same order.
#[rustler::nif]
fn list_encodings_nif<'a>(env: Env<'a>) -> Term<'a> {
    let mut items: Vec<Term<'a>> = Vec::new();
    for name in tiktoken::list_encodings().iter() {
        items.push(binary_str_term(env, name));
    }
    items.encode(env)
}

#[rustler::nif]
fn encoding_for_model_nif<'a>(env: Env<'a>, model: Binary<'a>) -> Term<'a> {
    let model = match binary_to_str(env, model) {
        Ok(model) => model,
        Err(reason) => return (error(), reason).encode(env),
    };

    match tiktoken::model_to_encoding(model) {
        Some(name) => (ok(), binary_str_term(env, name)).encode(env),
        None => (error(), unknown_model()).encode(env),
    }
}

#[rustler::nif]
fn count_nif<'a>(env: Env<'a>, name: Binary<'a>, text: Binary<'a>) -> Term<'a> {
    let name = match binary_to_str(env, name) {
        Ok(name) => name,
        Err(reason) => return (error(), reason).encode(env),
    };
    let text = match binary_to_str(env, text) {
        Ok(text) => text,
        Err(reason) => return (error(), reason).encode(env),
    };

    match resolve_encoding(name) {
        Some(encoding) => (ok(), encoding.count(text)).encode(env),
        None => (error(), unknown_encoding_or_model()).encode(env),
    }
}

/// NIF: tokenises `text` with the encoding (or model) `name` by delegating to
/// `encode_impl` with the special-token flag switched off.
#[rustler::nif]
fn encode_nif<'a>(env: Env<'a>, name: Binary<'a>, text: Binary<'a>) -> Term<'a> {
    let special = false;
    encode_impl(env, name, text, special)
}

/// NIF: tokenises `text` with the encoding (or model) `name` by delegating to
/// `encode_impl` with the special-token flag switched on.
#[rustler::nif]
fn encode_with_special_tokens_nif<'a>(
    env: Env<'a>,
    name: Binary<'a>,
    text: Binary<'a>,
) -> Term<'a> {
    let special = true;
    encode_impl(env, name, text, special)
}

fn encode_impl<'a>(env: Env<'a>, name: Binary<'a>, text: Binary<'a>, special: bool) -> Term<'a> {
    let name = match binary_to_str(env, name) {
        Ok(name) => name,
        Err(reason) => return (error(), reason).encode(env),
    };
    let text = match binary_to_str(env, text) {
        Ok(text) => text,
        Err(reason) => return (error(), reason).encode(env),
    };

    match resolve_encoding(name) {
        Some(encoding) => {
            let tokens = if special {
                encoding.encode_with_special_tokens(text)
            } else {
                encoding.encode(text)
            };
            (ok(), tokens).encode(env)
        }
        None => (error(), unknown_encoding_or_model()).encode(env),
    }
}

#[rustler::nif]
fn decode_nif<'a>(env: Env<'a>, name: Binary<'a>, tokens: Vec<u32>) -> Term<'a> {
    let name = match binary_to_str(env, name) {
        Ok(name) => name,
        Err(reason) => return (error(), reason).encode(env),
    };

    match resolve_encoding(name) {
        Some(encoding) => (ok(), binary_term(env, &encoding.decode(&tokens))).encode(env),
        None => (error(), unknown_encoding_or_model()).encode(env),
    }
}

#[rustler::nif]
fn estimate_cost_usd_micro_nif<'a>(
    env: Env<'a>,
    model: Binary<'a>,
    input_tokens: u64,
    output_tokens: u64,
) -> Term<'a> {
    let model = match binary_to_str(env, model) {
        Ok(model) => model,
        Err(reason) => return (error(), reason).encode(env),
    };

    match tiktoken::pricing::estimate_cost(model, input_tokens, output_tokens) {
        Some(cost) => {
            let micros = (cost * 1_000_000.0).round() as u64;
            (ok(), micros).encode(env)
        }
        None => (error(), unknown_model()).encode(env),
    }
}

// Registers this library as the NIF implementation of the Erlang module
// named `erltoken`.
// NOTE(review): no function list is passed here, so this presumably relies on
// rustler discovering the `#[rustler::nif]` functions automatically — confirm
// against the rustler version pinned in Cargo.toml.
rustler::init!("erltoken");