Skip to main content

native/kameleoon_elixir_bridge/src/utils_ex.rs

use std::sync::Arc;

use kameleoon_core::error::KameleoonError;
use rustler::env::SavedTerm;
use rustler::Env;
use rustler::{types::LocalPid, Atom, Decoder, Encoder, Error as NifError, NifResult, Reference, Term};
use tokio::runtime::Handle;

use crate::client_ex::{error, kameleoon_native, nil, ok};
use crate::error_ex::ErrorEx;

// MARK: - AsyncReply
pub(crate) struct AsyncReply {
    owner: LocalPid,
    env: rustler::OwnedEnv,
    saved_ref: SavedTerm,
    runtime_handle: Handle,
}

impl AsyncReply {
    pub(crate) fn new(owner: LocalPid, reply_ref: Reference, runtime_handle: Handle) -> Self {
        let env = rustler::OwnedEnv::new();
        let saved_ref = env.save(reply_ref);
        Self {
            owner,
            env,
            saved_ref,
            runtime_handle,
        }
    }

    pub(crate) fn send<T>(self, result: Result<T, KameleoonError>)
    where
        T: Encoder + Send + 'static,
    {
        if rustler::thread::is_scheduler_thread() {
            let runtime_handle = self.runtime_handle.clone();
            runtime_handle.spawn(async move {
                self.send_from_rust_thread(result);
            });
        } else {
            self.send_from_rust_thread(result);
        }
    }

    fn send_from_rust_thread<T>(self, result: Result<T, KameleoonError>)
    where
        T: Encoder,
    {
        let Self {
            owner,
            mut env,
            saved_ref,
            runtime_handle: _,
        } = self;

        let _ = env.send_and_clear(&owner, |env| {
            let reply_ref = saved_ref.load(env);
            let result = match result {
                Ok(value) => (ok(), value).encode(env),
                Err(err) => (error(), ErrorEx::from(&err)).encode(env),
            };
            (kameleoon_native(), reply_ref, result)
        });
    }
}

// MARK: - Converters

pub(crate) fn map_get<'a>(term: Term<'a>, key: Atom) -> Option<Term<'a>> {
    term.map_get(key).ok()
}

/// Checks that `term` is a map.
///
/// # Errors
/// Returns `NifError::BadArg` when `term` is not a map.
pub(crate) fn ensure_map(term: Term) -> NifResult<()> {
    if !term.is_map() {
        return Err(NifError::BadArg);
    }
    Ok(())
}

/// Returns `true` when `term` equals the atom returned by `nil()`
/// (presumably Elixir's `nil`).
pub(crate) fn is_nil(term: Term) -> bool {
    let nil_atom = nil();
    nil_atom == term
}

/// Decodes the optional value stored under `key` in `term`.
///
/// A missing key, a failed lookup, or a `nil` value all yield `Ok(None)`.
///
/// # Errors
/// Propagates the decoder's error when a non-nil value is present but is not a valid `T`.
pub(crate) fn optional_field<'a, T>(term: Term<'a>, key: Atom) -> NifResult<Option<T>>
where
    T: Decoder<'a>,
{
    map_get(term, key)
        .filter(|value| !is_nil(*value))
        .map(T::decode)
        .transpose()
}

/// Decodes the mandatory value stored under `key` in `term`.
///
/// Unlike `optional_field`, a `nil` value is not special-cased; it is handed to the decoder.
///
/// # Errors
/// Returns `NifError::BadArg` when the lookup fails, or the decoder's error when the value
/// is not a valid `T`.
pub(crate) fn required_field<'a, T>(term: Term<'a>, key: Atom) -> NifResult<T>
where
    T: Decoder<'a>,
{
    match map_get(term, key) {
        Some(value) => T::decode(value),
        None => Err(NifError::BadArg),
    }
}

/// Like `optional_field`, but substitutes `default` when the field is absent or `nil`.
///
/// # Errors
/// Propagates the decoder's error when a non-nil value is present but is not a valid `T`.
pub(crate) fn optional_field_or<'a, T>(term: Term<'a>, key: Atom, default: T) -> NifResult<T>
where
    T: Decoder<'a>,
{
    optional_field(term, key).map(|found| found.unwrap_or(default))
}

/// Decodes `term` as a `String`, falling back to the atom's name when string decoding fails.
///
/// # Errors
/// Returns the error from `atom_to_string` when `term` is neither a decodable string nor an atom.
pub(crate) fn string_or_atom(term: Term) -> NifResult<String> {
    match String::decode(term) {
        Ok(text) => Ok(text),
        Err(_) => term.atom_to_string(),
    }
}

// MARK: - ArcStr

/// Immutable string backed by `Arc<str>`: clones share the buffer instead of copying it.
/// It encodes to an Erlang term exactly as a `&str` would.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub(crate) struct ArcStr(Arc<str>);

impl From<Arc<str>> for ArcStr {
    /// Wraps an existing shared string without copying it.
    fn from(shared: Arc<str>) -> Self {
        ArcStr(shared)
    }
}

impl From<String> for ArcStr {
    fn from(owned: String) -> Self {
        ArcStr(owned.into())
    }
}

impl From<&str> for ArcStr {
    fn from(borrowed: &str) -> Self {
        ArcStr(borrowed.into())
    }
}

impl Encoder for ArcStr {
    /// Delegates to the `str` encoder for the wrapped text.
    fn encode<'a>(&self, env: Env<'a>) -> Term<'a> {
        let text: &str = &self.0;
        text.encode(env)
    }
}