Skip to main content

native/rx_rust_nif/src/arrow.rs

use crate::owner::OwnerState;
use crate::r_api::{c_string, ParseStatus, RApi, Sexp, LGLSXP, RAWSXP};
use crate::work::{DecodeArrowWork, EncodeDataframeWork, WorkResult};

/// R source that decodes an Arrow IPC stream into a data frame.
///
/// `do_encode_dataframe_inner` evaluates it in an environment where `.rx_ipc_raw` is bound to a
/// raw vector holding the IPC bytes. The script raises an R error when the decoded value does not
/// inherit from `data.frame`, and otherwise yields the decoded value as its last expression.
const READ_IPC_DATAFRAME_SOURCE: &str = r#"
value <- arrow::read_ipc_stream(.rx_ipc_raw, as_data_frame = TRUE)
if (!inherits(value, "data.frame")) stop("Arrow IPC did not decode to a data.frame")
value
"#;

/// R expression testing whether the `.rx_df` binding inherits from `data.frame`; evaluated by
/// `do_decode_arrow_inner` for resources whose recorded kind is not `Dataframe`.
const INHERITS_DATAFRAME_SOURCE: &str = r#"inherits(.rx_df, "data.frame")"#;
/// R expression that serializes the `.rx_df` binding with `arrow::write_to_raw` in the "stream"
/// format; `do_decode_arrow_inner` expects the result to be a raw vector.
const WRITE_IPC_DATAFRAME_SOURCE: &str = r#"arrow::write_to_raw(.rx_df, format = "stream")"#;
/// R expression checking whether the `arrow` namespace can be loaded; evaluated by
/// `require_arrow_in_env`, which expects a logical result.
const REQUIRE_ARROW_SOURCE: &str = r#"requireNamespace("arrow", quietly = TRUE)"#;

pub fn do_encode_dataframe(state: &mut OwnerState, work: &EncodeDataframeWork) -> WorkResult {
    if let Some(failure) = state.terminal_init_failure() {
        return WorkResult::NativeInitFailed(failure);
    }

    if !state.is_initialized() {
        return WorkResult::Error("embedded R runtime is not initialized".to_owned());
    }

    let Some(api) = state.r.as_ref() else {
        return WorkResult::Error("embedded R API is not initialized".to_owned());
    };

    unsafe { do_encode_dataframe_inner(api, work) }
}

pub fn do_decode_arrow(state: &mut OwnerState, work: &DecodeArrowWork) -> WorkResult {
    if let Some(failure) = state.terminal_init_failure() {
        return WorkResult::NativeInitFailed(failure);
    }

    if !state.is_initialized() {
        return WorkResult::Error("embedded R runtime is not initialized".to_owned());
    }

    let Some(api) = state.r.as_ref() else {
        return WorkResult::Error("embedded R API is not initialized".to_owned());
    };

    let (sexp, resource_kind) = {
        let state = match work.resource.state.lock() {
            Ok(state) => state,
            Err(_error) => {
                return WorkResult::Error("native R object resource mutex was poisoned".to_owned())
            }
        };

        if work
            .resource
            .release_enqueued
            .load(std::sync::atomic::Ordering::SeqCst)
        {
            return WorkResult::Error(
                "native R object resource has already been released".to_owned(),
            );
        }

        match state.sexp {
            Some(sexp) => (sexp, state.kind),
            None => {
                return WorkResult::Error(
                    "native R object resource has already been released".to_owned(),
                )
            }
        }
    };

    unsafe { do_decode_arrow_inner(api, sexp, resource_kind) }
}

/// Reads `work.ipc_bytes` as an Arrow IPC stream inside R and wraps the decoded data frame.
///
/// The bytes are copied into a freshly allocated R raw vector, bound as `.rx_ipc_raw` in a
/// scratch environment, and decoded by evaluating `READ_IPC_DATAFRAME_SOURCE` there. On success
/// the decoded value is passed to `crate::codec::preserved_resource` with
/// `ResourceKind::Dataframe` and returned as `WorkResult::Resource`.
///
/// Failures from `require_arrow_in_env` and `eval_source_in_env` are returned unchanged; an input
/// length that does not fit in `isize` yields `WorkResult::Error`.
///
/// # Safety
///
/// Calls R through the function pointers in `api`. NOTE(review): presumably `api` must refer to
/// an initialized embedded R runtime that may be used from the calling thread; confirm.
unsafe fn do_encode_dataframe_inner(api: &RApi, work: &EncodeDataframeWork) -> WorkResult {
    let binding_name = c_string(".rx_ipc_raw").expect(".rx_ipc_raw contains no NUL bytes");

    // Scratch environment enclosed by the global environment. NOTE(review): the trailing
    // arguments (1, 29) are presumably the hash flag and initial size; confirm.
    let scratch_env = (api.protect)((api.new_env)(*api.global_env, 1, 29));

    let Ok(vector_len) = isize::try_from(work.ipc_bytes.len()) else {
        // Only the environment is on the protect stack at this point.
        (api.unprotect)(1);
        return WorkResult::Error("Arrow IPC input is too large".to_owned());
    };

    let raw_vector = (api.protect)((api.alloc_vector)(RAWSXP, vector_len));

    // Skip the copy entirely for empty input.
    if !work.ipc_bytes.is_empty() {
        let destination = (api.raw)(raw_vector);
        std::ptr::copy_nonoverlapping(
            work.ipc_bytes.as_ptr(),
            destination,
            work.ipc_bytes.len(),
        );
    }

    (api.define_var)(
        (api.install)(binding_name.as_ptr()),
        raw_vector,
        scratch_env,
    );

    // Every path below leaves exactly two objects (environment and raw vector) protected, so a
    // single unprotect after computing the outcome balances the stack.
    let outcome = match require_arrow_in_env(api, scratch_env) {
        Err(failure) => failure,
        Ok(()) => match eval_source_in_env(api, scratch_env, READ_IPC_DATAFRAME_SOURCE) {
            Err(failure) => failure,
            // The resource is created before the environment is unprotected.
            Ok(dataframe) => WorkResult::Resource(crate::codec::preserved_resource(
                api,
                dataframe,
                crate::resource::ResourceKind::Dataframe,
            )),
        },
    };

    (api.unprotect)(2);
    outcome
}

unsafe fn do_decode_arrow_inner(
    api: &RApi,
    sexp: Sexp,
    resource_kind: crate::resource::ResourceKind,
) -> WorkResult {
    let rx_df = c_string(".rx_df").expect(".rx_df contains no NUL bytes");
    let mut protect_count = 0;

    let env = (api.protect)((api.new_env)(*api.global_env, 1, 29));
    protect_count += 1;
    (api.define_var)((api.install)(rx_df.as_ptr()), sexp, env);

    if let Err(result) = require_arrow_in_env(api, env) {
        (api.unprotect)(protect_count);
        return result;
    }

    if resource_kind != crate::resource::ResourceKind::Dataframe {
        let inherits = match eval_source_in_env(api, env, INHERITS_DATAFRAME_SOURCE) {
            Ok(inherits) => inherits,
            Err(result) => {
                (api.unprotect)(protect_count);
                return result;
            }
        };

        let is_dataframe = (api.type_of)(inherits) == LGLSXP
            && (api.xlength)(inherits) >= 1
            && (api.logical_elt)(inherits, 0) == 1;

        if !is_dataframe {
            (api.unprotect)(protect_count);
            return WorkResult::TaggedError(
                "not_dataframe",
                "native object is not an Arrow-compatible data frame".to_owned(),
            );
        }
    }

    let raw = match eval_source_in_env(api, env, WRITE_IPC_DATAFRAME_SOURCE) {
        Ok(raw) => raw,
        Err(result) => {
            (api.unprotect)(protect_count);
            return result;
        }
    };

    if (api.type_of)(raw) != RAWSXP {
        (api.unprotect)(protect_count);
        return WorkResult::Error("R arrow writer did not return raw bytes".to_owned());
    }

    let output_len = match usize::try_from((api.xlength)(raw)) {
        Ok(len) => len,
        Err(_error) => {
            (api.unprotect)(protect_count);
            return WorkResult::Error("Arrow IPC output is too large".to_owned());
        }
    };

    let bytes = if output_len == 0 {
        Vec::new()
    } else {
        std::slice::from_raw_parts((api.raw)(raw), output_len).to_vec()
    };

    (api.unprotect)(protect_count);
    WorkResult::ArrowBytes(bytes)
}

/// Checks that the R `arrow` package can be loaded by evaluating `REQUIRE_ARROW_SOURCE` in `env`.
///
/// Returns `Ok(())` when the first element of the logical result is 1. Returns
/// `WorkResult::TaggedError("missing_r_package", ..)` for any other first element,
/// `WorkResult::Error` when the result is not a logical vector with at least one element, and
/// propagates evaluation failures from `eval_source_in_env` unchanged.
///
/// # Safety
///
/// Calls R through the function pointers in `api`. NOTE(review): presumably `api` must refer to
/// an initialized embedded R runtime and `env` must be a live R environment; confirm.
unsafe fn require_arrow_in_env(api: &RApi, env: Sexp) -> Result<(), WorkResult> {
    let answer = eval_source_in_env(api, env, REQUIRE_ARROW_SOURCE)?;

    // The length is only queried once the value is known to be a logical vector.
    let is_logical_scalar = (api.type_of)(answer) == LGLSXP && (api.xlength)(answer) >= 1;
    if !is_logical_scalar {
        return Err(WorkResult::Error(
            "R expression did not return a logical scalar".to_owned(),
        ));
    }

    match (api.logical_elt)(answer, 0) {
        1 => Ok(()),
        _ => Err(WorkResult::TaggedError(
            "missing_r_package",
            "R 'arrow' package is not installed".to_owned(),
        )),
    }
}

unsafe fn eval_source_in_env(api: &RApi, env: Sexp, source: &str) -> Result<Sexp, WorkResult> {
    let source = c_string(source)
        .map_err(|error| WorkResult::Error(format!("R source contains NUL byte: {error}")))?;
    let mut protect_count = 0;

    let source_sexp = (api.protect)((api.mk_string)(source.as_ptr()));
    protect_count += 1;

    let mut parse_status = ParseStatus::Null;
    let exprs = (api.protect)((api.parse_vector)(
        source_sexp,
        -1,
        &mut parse_status,
        *api.nil_value,
    ));
    protect_count += 1;

    if parse_status != ParseStatus::Ok && parse_status != ParseStatus::Null {
        (api.unprotect)(protect_count);
        return Err(WorkResult::Error("R parse failed".to_owned()));
    }

    let mut result = *api.nil_value;
    let expr_count = (api.xlength)(exprs);

    for index in 0..expr_count {
        let mut error_occurred = 0;
        result = (api.try_eval)((api.vector_elt)(exprs, index), env, &mut error_occurred);

        if error_occurred != 0 {
            (api.unprotect)(protect_count);
            return Err(WorkResult::Error("R evaluation failed".to_owned()));
        }
    }

    (api.unprotect)(protect_count);
    Ok(result)
}