use crate::owner::OwnerState;
use crate::r_api::{c_string, ParseStatus, RApi, Sexp, LGLSXP, RAWSXP};
use crate::work::{DecodeArrowWork, EncodeDataframeWork, WorkResult};
/// R source that decodes the raw vector bound to `.rx_ipc_raw` with
/// `arrow::read_ipc_stream(..., as_data_frame = TRUE)` and raises an R error
/// unless the result inherits from `data.frame`. The decoded value is also
/// bound to `value` in the evaluation environment and is the last expression,
/// so it is what the evaluation yields.
const READ_IPC_DATAFRAME_SOURCE: &str = r#"
value <- arrow::read_ipc_stream(.rx_ipc_raw, as_data_frame = TRUE)
if (!inherits(value, "data.frame")) stop("Arrow IPC did not decode to a data.frame")
value
"#;
/// R source whose logical result reports whether `.rx_df` inherits from
/// `data.frame`; used to vet objects not already tagged as data frames.
const INHERITS_DATAFRAME_SOURCE: &str = r#"inherits(.rx_df, "data.frame")"#;
/// R source that serializes `.rx_df` with `arrow::write_to_raw` using the
/// Arrow IPC "stream" format; the caller expects a raw vector back.
const WRITE_IPC_DATAFRAME_SOURCE: &str = r#"arrow::write_to_raw(.rx_df, format = "stream")"#;
/// R source whose logical result is treated as "the `arrow` package is
/// available" when it is `TRUE`; `quietly = TRUE` keeps the probe silent.
const REQUIRE_ARROW_SOURCE: &str = r#"requireNamespace("arrow", quietly = TRUE)"#;
pub fn do_encode_dataframe(state: &mut OwnerState, work: &EncodeDataframeWork) -> WorkResult {
if let Some(failure) = state.terminal_init_failure() {
return WorkResult::NativeInitFailed(failure);
}
if !state.is_initialized() {
return WorkResult::Error("embedded R runtime is not initialized".to_owned());
}
let Some(api) = state.r.as_ref() else {
return WorkResult::Error("embedded R API is not initialized".to_owned());
};
unsafe { do_encode_dataframe_inner(api, work) }
}
pub fn do_decode_arrow(state: &mut OwnerState, work: &DecodeArrowWork) -> WorkResult {
if let Some(failure) = state.terminal_init_failure() {
return WorkResult::NativeInitFailed(failure);
}
if !state.is_initialized() {
return WorkResult::Error("embedded R runtime is not initialized".to_owned());
}
let Some(api) = state.r.as_ref() else {
return WorkResult::Error("embedded R API is not initialized".to_owned());
};
let (sexp, resource_kind) = {
let state = match work.resource.state.lock() {
Ok(state) => state,
Err(_error) => {
return WorkResult::Error("native R object resource mutex was poisoned".to_owned())
}
};
if work
.resource
.release_enqueued
.load(std::sync::atomic::Ordering::SeqCst)
{
return WorkResult::Error(
"native R object resource has already been released".to_owned(),
);
}
match state.sexp {
Some(sexp) => (sexp, state.kind),
None => {
return WorkResult::Error(
"native R object resource has already been released".to_owned(),
)
}
}
};
unsafe { do_decode_arrow_inner(api, sexp, resource_kind) }
}
/// Copies `work.ipc_bytes` into a fresh R raw vector, binds it as
/// `.rx_ipc_raw`, and evaluates `READ_IPC_DATAFRAME_SOURCE` to obtain a
/// `data.frame`. The result is wrapped with `crate::codec::preserved_resource`
/// as `ResourceKind::Dataframe`.
///
/// Evaluation happens in a new environment created by `new_env` with the
/// global environment as its first argument (presumably the enclosure, if
/// `new_env` mirrors `R_NewEnv`; TODO confirm).
///
/// # Safety
///
/// `api` must reference a live, initialized R API table, and the caller must
/// be on the thread that owns the embedded R runtime.
/// NOTE(review): an R allocation error inside `alloc_vector` and similar calls
/// would longjmp across these Rust frames. That is not guarded here.
unsafe fn do_encode_dataframe_inner(api: &RApi, work: &EncodeDataframeWork) -> WorkResult {
    let ipc_symbol = c_string(".rx_ipc_raw").expect(".rx_ipc_raw contains no NUL bytes");

    // Protection stack for this call: [env] -> [env, raw].
    let env = (api.protect)((api.new_env)(*api.global_env, 1, 29));
    let Ok(raw_len) = isize::try_from(work.ipc_bytes.len()) else {
        (api.unprotect)(1);
        return WorkResult::Error("Arrow IPC input is too large".to_owned());
    };
    let raw = (api.protect)((api.alloc_vector)(RAWSXP, raw_len));

    let byte_count = work.ipc_bytes.len();
    if byte_count > 0 {
        // SAFETY: `raw` was just allocated with exactly `byte_count` elements,
        // and the source and destination cannot overlap.
        std::ptr::copy_nonoverlapping(work.ipc_bytes.as_ptr(), (api.raw)(raw), byte_count);
    }
    (api.define_var)((api.install)(ipc_symbol.as_ptr()), raw, env);

    let outcome = 'decode: {
        if let Err(failure) = require_arrow_in_env(api, env) {
            break 'decode failure;
        }
        match eval_source_in_env(api, env, READ_IPC_DATAFRAME_SOURCE) {
            // The decoded frame is also bound to `value` inside the still
            // protected `env`, so it stays reachable until it is preserved.
            Ok(decoded) => WorkResult::Resource(crate::codec::preserved_resource(
                api,
                decoded,
                crate::resource::ResourceKind::Dataframe,
            )),
            Err(failure) => failure,
        }
    };
    (api.unprotect)(2);
    outcome
}
unsafe fn do_decode_arrow_inner(
api: &RApi,
sexp: Sexp,
resource_kind: crate::resource::ResourceKind,
) -> WorkResult {
let rx_df = c_string(".rx_df").expect(".rx_df contains no NUL bytes");
let mut protect_count = 0;
let env = (api.protect)((api.new_env)(*api.global_env, 1, 29));
protect_count += 1;
(api.define_var)((api.install)(rx_df.as_ptr()), sexp, env);
if let Err(result) = require_arrow_in_env(api, env) {
(api.unprotect)(protect_count);
return result;
}
if resource_kind != crate::resource::ResourceKind::Dataframe {
let inherits = match eval_source_in_env(api, env, INHERITS_DATAFRAME_SOURCE) {
Ok(inherits) => inherits,
Err(result) => {
(api.unprotect)(protect_count);
return result;
}
};
let is_dataframe = (api.type_of)(inherits) == LGLSXP
&& (api.xlength)(inherits) >= 1
&& (api.logical_elt)(inherits, 0) == 1;
if !is_dataframe {
(api.unprotect)(protect_count);
return WorkResult::TaggedError(
"not_dataframe",
"native object is not an Arrow-compatible data frame".to_owned(),
);
}
}
let raw = match eval_source_in_env(api, env, WRITE_IPC_DATAFRAME_SOURCE) {
Ok(raw) => raw,
Err(result) => {
(api.unprotect)(protect_count);
return result;
}
};
if (api.type_of)(raw) != RAWSXP {
(api.unprotect)(protect_count);
return WorkResult::Error("R arrow writer did not return raw bytes".to_owned());
}
let output_len = match usize::try_from((api.xlength)(raw)) {
Ok(len) => len,
Err(_error) => {
(api.unprotect)(protect_count);
return WorkResult::Error("Arrow IPC output is too large".to_owned());
}
};
let bytes = if output_len == 0 {
Vec::new()
} else {
std::slice::from_raw_parts((api.raw)(raw), output_len).to_vec()
};
(api.unprotect)(protect_count);
WorkResult::ArrowBytes(bytes)
}
/// Checks that the R `arrow` package is usable by evaluating
/// `REQUIRE_ARROW_SOURCE` in `env`.
///
/// Returns `Ok(())` only when the result is a logical vector whose first
/// element is `TRUE`. A non-logical or empty result yields `WorkResult::Error`.
/// Any other logical value (including `NA`) yields the `missing_r_package`
/// tagged error. Evaluation failures from `eval_source_in_env` are propagated.
///
/// # Safety
///
/// `api` must reference a live, initialized R API table, `env` must be a valid
/// (protected) environment, and the caller must be on the R owner thread.
unsafe fn require_arrow_in_env(api: &RApi, env: Sexp) -> Result<(), WorkResult> {
    let probe = eval_source_in_env(api, env, REQUIRE_ARROW_SOURCE)?;

    let is_logical_scalar = (api.type_of)(probe) == LGLSXP && (api.xlength)(probe) >= 1;
    if !is_logical_scalar {
        return Err(WorkResult::Error(
            "R expression did not return a logical scalar".to_owned(),
        ));
    }

    match (api.logical_elt)(probe, 0) {
        1 => Ok(()),
        _ => Err(WorkResult::TaggedError(
            "missing_r_package",
            "R 'arrow' package is not installed".to_owned(),
        )),
    }
}
unsafe fn eval_source_in_env(api: &RApi, env: Sexp, source: &str) -> Result<Sexp, WorkResult> {
let source = c_string(source)
.map_err(|error| WorkResult::Error(format!("R source contains NUL byte: {error}")))?;
let mut protect_count = 0;
let source_sexp = (api.protect)((api.mk_string)(source.as_ptr()));
protect_count += 1;
let mut parse_status = ParseStatus::Null;
let exprs = (api.protect)((api.parse_vector)(
source_sexp,
-1,
&mut parse_status,
*api.nil_value,
));
protect_count += 1;
if parse_status != ParseStatus::Ok && parse_status != ParseStatus::Null {
(api.unprotect)(protect_count);
return Err(WorkResult::Error("R parse failed".to_owned()));
}
let mut result = *api.nil_value;
let expr_count = (api.xlength)(exprs);
for index in 0..expr_count {
let mut error_occurred = 0;
result = (api.try_eval)((api.vector_elt)(exprs, index), env, &mut error_occurred);
if error_occurred != 0 {
(api.unprotect)(protect_count);
return Err(WorkResult::Error("R evaluation failed".to_owned()));
}
}
(api.unprotect)(protect_count);
Ok(result)
}