// native/rx_rust_nif/src/eval.rs

use crate::owner::OwnerState;
use crate::r_api::{c_string, ParseStatus, RApi, Sexp, CE_UTF8, LGLSXP, STRSXP, VECSXP};
use crate::work::{
    EvalOutput, EvalStringWork, EvalSuccess, EvalWork, InitConfig, InitState, InitWork,
    StructuredError, WorkResult,
};

/// R source for the closure that `do_eval_inner` invokes as `helper(exprs, env)`.
///
/// The closure evaluates each expression of `exprs` in `env`, one after another.
/// While it does so it:
/// * diverts standard output into a local `textConnection` (named `stdout`) via `sink`,
/// * records and muffles every `message` and `warning` condition,
/// * catches `error` conditions with `tryCatch`.
///
/// `finish()` removes the sink and closes the connection on both the success and
/// the error path, before the captured streams are copied into the result.
///
/// The closure always returns an 8-element list. The Rust side reads it with 0-based
/// `vector_elt` indices:
///
/// | index | success                                      | error                                  |
/// |-------|----------------------------------------------|----------------------------------------|
/// | 0     | `TRUE`                                       | `FALSE`                                |
/// | 1     | value of the last evaluated expression       | `NULL`                                 |
/// | 2     | `NULL`                                       | `conditionMessage(e)`                  |
/// | 3     | `NULL`                                       | `class(e)`                             |
/// | 4     | `NULL`                                       | deparsed `conditionCall(e)`, or `NULL` |
/// | 5     | captured stdout lines                        | captured stdout lines                  |
/// | 6     | captured message texts                       | captured message texts                 |
/// | 7     | captured warning texts                       | captured warning texts                 |
///
/// NOTE(review): only `error` conditions are handled. A non-error unwind out of the
/// evaluation (e.g. an `abort` restart) would skip `finish()` and leave the output
/// sink installed. Confirm whether that case matters to callers.
const EVAL_HELPER_SOURCE: &str = r#"function(exprs, env) {
  messages <- character()
  warnings <- character()
  stdout <- character()
  value <- NULL
  con <- textConnection("stdout", "w", local = TRUE)
  sinking <- FALSE
  closed <- FALSE
  finish <- function() {
    if (sinking) {
      sink(type = "output")
      sinking <<- FALSE
    }
    if (!closed) {
      close(con)
      closed <<- TRUE
    }
  }
  sink(con, type = "output")
  sinking <- TRUE
  tryCatch(
    {
      result <- withCallingHandlers({
        for (expr in exprs) value <- eval(expr, envir = env)
        list(TRUE, value, NULL, NULL, NULL, stdout, messages, warnings)
      },
      message = function(m) {
        messages <<- c(messages, conditionMessage(m))
        invokeRestart("muffleMessage")
      },
      warning = function(w) {
        warnings <<- c(warnings, conditionMessage(w))
        invokeRestart("muffleWarning")
      })
      finish()
      result[[6]] <- stdout
      result[[7]] <- messages
      result[[8]] <- warnings
      result
    },
    error = function(e) {
      finish()
      call <- conditionCall(e)
      call_text <- if (is.null(call)) NULL else paste(deparse(call), collapse = "\n")
      list(FALSE, NULL, conditionMessage(e), class(e), call_text, stdout, messages, warnings)
    }
  )
}"#;

pub fn do_init(state: &mut OwnerState, work: &InitWork) -> WorkResult {
    let requested = InitConfig::from(work);
    state.init_attempt_count += 1;
    state.attempted_init_config = Some(requested.clone());

    match state.init_state {
        InitState::Initialized => {
            if state.init_config.as_ref() == Some(&requested) {
                return WorkResult::Ok;
            }

            let current = state
                .init_config
                .clone()
                .expect("initialized native state without init config");
            let mut mismatches = Vec::new();

            if current.r_home != requested.r_home {
                mismatches.push("r_home");
            }
            if current.lib_r_path != requested.lib_r_path {
                mismatches.push("lib_r_path");
            }
            if current.lib_paths != requested.lib_paths {
                mismatches.push("lib_paths");
            }

            state.init_mismatch_count += 1;

            return WorkResult::NativeInitMismatch(crate::work::InitMismatch {
                message: "embedded native R is already initialized with different native init options; restart the BEAM to change native init options".to_owned(),
                mismatches,
                current,
                requested,
            });
        }
        InitState::Failed => {
            if let Some(error) = state.last_init_error.as_ref() {
                if !error.retryable {
                    return WorkResult::NativeInitFailed(error.clone());
                }
            }
        }
        InitState::Uninitialized => {}
    }

    std::env::set_var("R_HOME", &work.r_home);
    std::env::set_var("R_ENABLE_JIT", "0");
    std::env::set_var("R_DEFAULT_PACKAGES", "NULL");

    let api = unsafe {
        match RApi::load(&work.lib_r_path) {
            Ok(api) => api,
            Err(message) => {
                let failure =
                    record_init_failure(state, retryable_loader_stage(&message), message, true);
                state.r = None;
                return WorkResult::NativeInitFailed(failure);
            }
        }
    };

    let argv = [
        "R",
        "--silent",
        "--no-save",
        "--no-restore",
        "--no-readline",
    ];

    let mut c_argv = match argv
        .iter()
        .map(|value| c_string(value))
        .collect::<Result<Vec<_>, _>>()
    {
        Ok(values) => values,
        Err(error) => {
            let failure = record_init_failure(
                state,
                "resolve_boot_symbols",
                format!("failed to build R argv: {error}"),
                true,
            );
            return WorkResult::NativeInitFailed(failure);
        }
    };

    let mut argv_ptrs = c_argv
        .iter_mut()
        .map(|value| value.as_ptr() as *mut _)
        .collect::<Vec<_>>();

    unsafe {
        api.configure_c_stack();
        let init_status = (api.rf_initialize_r)(argv_ptrs.len() as i32, argv_ptrs.as_mut_ptr());

        if init_status < 0 {
            let failure =
                record_init_failure(state, "rf_initialize_r", "Rf_initialize_R failed", false);
            state.r = Some(api);
            return WorkResult::NativeInitFailed(failure);
        }

        if init_fault_stage_enabled("after_rf_initialize") {
            let failure = record_init_failure(
                state,
                "rf_initialize_r",
                "fault injected after Rf_initialize_R",
                false,
            );
            state.r = Some(api);
            return WorkResult::NativeInitFailed(failure);
        }

        api.configure_c_stack();
        (api.setup_rmainloop)();

        // The BEAM owns a process-wide SIGCHLD handler whose signal-dispatcher
        // thread consumes child-process exits. Embedded R's libc
        // system()/popen() then block forever in wait4() waiting for a child
        // whose exit the BEAM already reaped, hanging every R call that shells
        // out (system/system2/pipe and therefore requireNamespace/package
        // loading). Restore the default SIGCHLD disposition so R's own
        // system()/popen() reap their children. Safe for the BEAM: OTP reaps
        // Port children through erl_child_setup, not through this handler.
        // Mirrors reset_sigchld_to_default() in c_src/rx_nif.c.
        libc::signal(libc::SIGCHLD, libc::SIG_DFL);
    }

    state.r = Some(api);

    if let Some(api) = state.r.as_ref() {
        match apply_lib_paths(api, &work.lib_paths) {
            WorkResult::Ok => {
                state.init_state = InitState::Initialized;
                state.init_config = Some(requested);
                state.last_init_error = None;
                WorkResult::Ok
            }
            WorkResult::Error(message) => {
                let failure = record_init_failure(state, "apply_lib_paths", message, false);
                WorkResult::NativeInitFailed(failure)
            }
            other => WorkResult::Error(format!("unexpected apply_lib_paths result: {other:?}")),
        }
    } else {
        WorkResult::Error("embedded R API was not stored after init".to_owned())
    }
}

/// Picks the init stage name under which a libR loader error is reported.
///
/// Messages containing `libR is missing` are attributed to `resolve_boot_symbols`.
/// Every other loader message is reported as a `dlopen_lib_r` failure.
fn retryable_loader_stage(message: &str) -> &'static str {
    const MISSING_LIBR_MARKER: &str = "libR is missing";

    if !message.contains(MISSING_LIBR_MARKER) {
        return "dlopen_lib_r";
    }
    "resolve_boot_symbols"
}

/// Reports whether crash-reproduction hooks are enabled.
///
/// They are enabled only when `RX_CRASH_REPRO` is set to exactly `1`.
fn crash_repro_enabled() -> bool {
    matches!(std::env::var("RX_CRASH_REPRO"), Ok(flag) if flag == "1")
}

/// Reports whether a fault should be injected at the given init `stage`.
///
/// This requires crash-repro mode (see `crash_repro_enabled`) and
/// `RX_NATIVE_FAULT_INIT_STAGE` set to exactly `stage`.
fn init_fault_stage_enabled(stage: &str) -> bool {
    if !crash_repro_enabled() {
        return false;
    }
    matches!(
        std::env::var("RX_NATIVE_FAULT_INIT_STAGE"),
        Ok(requested) if requested == stage
    )
}

/// Marks the owner's init as failed and returns a copy of the recorded failure.
///
/// Sets `state.init_state` to `InitState::Failed` and stores the failure in
/// `state.last_init_error`. `restart_required` is always the negation of `retryable`.
fn record_init_failure(
    state: &mut OwnerState,
    stage: &'static str,
    message: impl Into<String>,
    retryable: bool,
) -> crate::work::InitFailure {
    let recorded = crate::work::InitFailure {
        restart_required: !retryable,
        retryable,
        message: message.into(),
        stage,
    };

    state.init_state = InitState::Failed;
    state.last_init_error = Some(recorded.clone());
    recorded
}

/// Prepends `lib_paths` to the embedded R session's library search path.
///
/// It reads the current value by evaluating `.libPaths()`, then evaluates
/// `.libPaths(c(lib_paths, current))`. Both calls run in the global environment.
/// `do_init` calls this after the R main loop has been set up.
///
/// Returns:
/// * `WorkResult::Ok` immediately, without touching R, when `lib_paths` is empty.
/// * `WorkResult::Ok` after a successful update.
/// * `WorkResult::Error` when a path contains a NUL byte or either `.libPaths` call
///   raises an R error.
///
/// NOTE(review): any normalisation `.libPaths()` itself applies to the new vector
/// (deduplication, dropping missing directories) is R's behaviour, not this
/// function's. Confirm against the R documentation if it matters.
fn apply_lib_paths(api: &RApi, lib_paths: &[String]) -> WorkResult {
    if lib_paths.is_empty() {
        return WorkResult::Ok;
    }

    // SAFETY: NOTE(review) — assumes R has been initialized and that this runs on the
    // thread that owns the embedded runtime.
    unsafe {
        // Every PROTECT below increments this counter, so each exit path can pop
        // all of them with a single UNPROTECT.
        let mut protect_count = 0;
        // Character vector holding the requested paths, in caller order.
        let paths = (api.protect)((api.alloc_vector)(STRSXP, lib_paths.len() as isize));
        protect_count += 1;

        for (index, path) in lib_paths.iter().enumerate() {
            let path = match c_string(path) {
                Ok(path) => path,
                Err(error) => {
                    (api.unprotect)(protect_count);
                    return WorkResult::Error(format!("invalid native R library path: {error}"));
                }
            };

            let path_sexp = (api.protect)((api.mk_char_len_ce)(
                path.as_ptr(),
                path.as_bytes().len() as i32,
                CE_UTF8,
            ));
            protect_count += 1;
            (api.set_string_elt)(paths, index as isize, path_sexp);
        }

        // Evaluate `.libPaths()` with no arguments to read the current search path.
        let lib_paths_name = c_string(".libPaths").expect(".libPaths contains no NUL");
        let current_call = (api.protect)((api.lang1)((api.install)(lib_paths_name.as_ptr())));
        protect_count += 1;

        let mut error_occurred = 0;
        let current = (api.protect)((api.try_eval)(
            current_call,
            *api.global_env,
            &mut error_occurred,
        ));
        protect_count += 1;

        if error_occurred != 0 {
            (api.unprotect)(protect_count);
            return WorkResult::Error("failed to read native R library paths".to_owned());
        }

        // Build c(paths, current): the requested paths first, then the existing entries.
        let current_len = (api.xlength)(current);
        let total_len = lib_paths.len() as isize + current_len;
        let combined = (api.protect)((api.alloc_vector)(STRSXP, total_len));
        protect_count += 1;

        for index in 0..lib_paths.len() {
            (api.set_string_elt)(
                combined,
                index as isize,
                (api.string_elt)(paths, index as isize),
            );
        }

        for index in 0..current_len {
            (api.set_string_elt)(
                combined,
                lib_paths.len() as isize + index,
                (api.string_elt)(current, index),
            );
        }

        // Evaluate `.libPaths(combined)`. Its return value is not used, so it is
        // not protected.
        let set_call = (api.protect)((api.lang2)(
            (api.install)(lib_paths_name.as_ptr()),
            combined,
        ));
        protect_count += 1;

        error_occurred = 0;
        (api.try_eval)(set_call, *api.global_env, &mut error_occurred);

        if error_occurred != 0 {
            (api.unprotect)(protect_count);
            return WorkResult::Error("failed to apply native R library paths".to_owned());
        }

        (api.unprotect)(protect_count);
    }

    WorkResult::Ok
}

/// Parses `work.source` and evaluates each top-level expression in R's global
/// environment, returning the decoded value of the last one.
///
/// Guards:
/// * A terminal init failure is replayed as `NativeInitFailed`.
/// * An uninitialized runtime, or a missing R API, yields `WorkResult::Error`.
///
/// Errors:
/// * A NUL byte in the source yields `WorkResult::Error`.
/// * A parse failure yields `TaggedError("parse_error", ..)`.
/// * An R error in any expression yields `TaggedError("r_error", ..)`. Expressions
///   evaluated before it have already run in the global environment.
/// * A value the codec cannot decode yields `TaggedError("unsupported_r_value", ..)`.
///
/// If the parse yields no expressions, the value decoded is R `NULL`.
pub fn do_eval_string(state: &mut OwnerState, work: &EvalStringWork) -> WorkResult {
    if let Some(failure) = state.terminal_init_failure() {
        return WorkResult::NativeInitFailed(failure);
    }

    if !state.is_initialized() {
        return WorkResult::Error("embedded R runtime is not initialized".to_owned());
    }

    let Some(api) = state.r.as_ref() else {
        return WorkResult::Error("embedded R API is not initialized".to_owned());
    };

    // SAFETY: the runtime is initialized (checked above) and `api` is the loaded libR
    // API. Every PROTECT below is counted in `protect_count` and popped on each exit
    // path.
    unsafe {
        let source = match c_string(&work.source) {
            Ok(source) => source,
            Err(error) => return WorkResult::Error(format!("source contains NUL byte: {error}")),
        };

        let mut protect_count = 0;
        let source_sexp = (api.protect)((api.mk_string)(source.as_ptr()));
        protect_count += 1;

        let mut parse_status = ParseStatus::Null;
        let exprs = (api.protect)((api.parse_vector)(
            source_sexp,
            -1,
            &mut parse_status,
            *api.nil_value,
        ));
        protect_count += 1;

        if parse_status != ParseStatus::Ok {
            (api.unprotect)(protect_count);
            return WorkResult::TaggedError("parse_error", "R parse failed".to_owned());
        }

        let expr_count = (api.xlength)(exprs);
        let mut result = *api.nil_value;

        for index in 0..expr_count {
            let mut error_occurred = 0;
            result = (api.try_eval)(
                (api.vector_elt)(exprs, index),
                *api.global_env,
                &mut error_occurred,
            );

            if error_occurred != 0 {
                (api.unprotect)(protect_count);
                return WorkResult::TaggedError("r_error", "R evaluation failed".to_owned());
            }
        }

        // Values returned by try_eval are not on the protect stack. Intermediate
        // results are discarded, but the final one must stay alive through any R
        // allocation the codec performs while decoding it.
        let result = (api.protect)(result);
        protect_count += 1;

        let decoded = crate::codec::decode_simple_value(api, result);
        (api.unprotect)(protect_count);

        match decoded {
            Ok(value) => WorkResult::EvalString(value),
            Err(message) => WorkResult::TaggedError("unsupported_r_value", message),
        }
    }
}

pub fn do_eval(state: &mut OwnerState, work: &EvalWork) -> WorkResult {
    if let Some(failure) = state.terminal_init_failure() {
        return WorkResult::NativeInitFailed(failure);
    }

    if !state.is_initialized() {
        return WorkResult::Error("embedded R runtime is not initialized".to_owned());
    }

    let Some(api) = state.r.as_ref() else {
        return WorkResult::Error("embedded R API is not initialized".to_owned());
    };

    unsafe { do_eval_inner(api, work) }
}

/// Evaluates `work.source` with `work.globals` bound in a fresh environment.
///
/// Steps:
/// 1. Creates an environment whose enclosure is the global environment. Each
///    `(name, resource)` in `work.globals` is bound into it.
/// 2. Parses the source. Any parse status other than `Ok` or `Null` yields
///    `StructuredError(parse_error())`.
/// 3. If nothing was parsed (no expressions, or status `Null`), returns `EvalSuccess`
///    with `result: None`, empty output, and the environment's bindings as `globals`.
/// 4. Otherwise calls the `EVAL_HELPER_SOURCE` closure as `helper(exprs, env)` and
///    translates its 8-element result list:
///    * An R-level error becomes a `StructuredError`.
///    * Success becomes `EvalSuccess`, carrying the last value, the captured
///      stdout/messages/warnings, and every binding left in `env`.
///
/// Internal failures are reported as `WorkResult::Error`. These include NUL bytes in
/// the source or a global name, codec errors, and a helper result that is not a list
/// of at least 8 elements.
///
/// # Safety
///
/// `api` must belong to an initialized R runtime; `do_eval` only calls this after
/// `state.is_initialized()` succeeds. NOTE(review): presumably this must also run on
/// the thread that owns the runtime — confirm.
unsafe fn do_eval_inner(api: &RApi, work: &EvalWork) -> WorkResult {
    let source = match c_string(&work.source) {
        Ok(source) => source,
        Err(error) => return WorkResult::Error(format!("source contains NUL byte: {error}")),
    };

    // Every PROTECT is counted here so each exit path can pop them all at once.
    let mut protect_count = 0;
    // NOTE(review): `1, 29` look like R_NewEnv's (hash, size) arguments — confirm
    // against RApi.
    let env = (api.protect)((api.new_env)(*api.global_env, 1, 29));
    protect_count += 1;

    for (name, resource) in &work.globals {
        let name = match c_string(name) {
            Ok(name) => name,
            Err(error) => {
                (api.unprotect)(protect_count);
                return WorkResult::Error(format!("global name contains NUL byte: {error}"));
            }
        };

        let value = match crate::codec::live_resource_sexp(resource) {
            Ok(value) => value,
            Err(message) => {
                (api.unprotect)(protect_count);
                return WorkResult::Error(message);
            }
        };

        (api.define_var)((api.install)(name.as_ptr()), value, env);
    }

    let source_sexp = (api.protect)((api.mk_string)(source.as_ptr()));
    protect_count += 1;

    let mut parse_status = ParseStatus::Null;
    let exprs = (api.protect)((api.parse_vector)(
        source_sexp,
        -1,
        &mut parse_status,
        *api.nil_value,
    ));
    protect_count += 1;

    // `ParseStatus::Null` is not an error here; it is handled as "nothing to
    // evaluate" below.
    if parse_status != ParseStatus::Ok && parse_status != ParseStatus::Null {
        (api.unprotect)(protect_count);
        return WorkResult::StructuredError(parse_error());
    }

    let expr_count = (api.xlength)(exprs);

    if expr_count == 0 || parse_status == ParseStatus::Null {
        let globals = match collect_eval_globals(api, env) {
            Ok(globals) => globals,
            Err(message) => {
                (api.unprotect)(protect_count);
                return WorkResult::Error(message);
            }
        };

        (api.unprotect)(protect_count);
        return WorkResult::EvalSuccess(EvalSuccess {
            result: None,
            globals,
            output: EvalOutput::empty(),
        });
    }

    // `helper` comes back preserved; that preservation is released once `call`
    // references it.
    let helper = match build_eval_helper(api) {
        Ok(helper) => helper,
        Err(message) => {
            (api.unprotect)(protect_count);
            return WorkResult::Error(message);
        }
    };

    let call = (api.protect)((api.lang3)(helper, exprs, env));
    protect_count += 1;
    // `call` is protected and holds the helper, so the extra preservation can go.
    (api.release_object)(helper);

    let mut error_occurred = 0;
    let helper_result = (api.protect)((api.try_eval)(call, *api.global_env, &mut error_occurred));
    protect_count += 1;

    // The helper catches R errors itself. A try_eval failure or a malformed list
    // therefore means the helper did not run as expected.
    if error_occurred != 0
        || (api.type_of)(helper_result) != VECSXP
        || (api.xlength)(helper_result) < 8
    {
        (api.unprotect)(protect_count);
        return WorkResult::Error("R evaluation failed".to_owned());
    }

    if !helper_result_ok(api, helper_result) {
        let error = match structured_error_from_helper_result(api, helper_result) {
            Ok(error) => error,
            Err(message) => {
                (api.unprotect)(protect_count);
                return WorkResult::Error(message);
            }
        };

        (api.unprotect)(protect_count);
        return WorkResult::StructuredError(error);
    }

    let output = match output_from_helper_result(api, helper_result) {
        Ok(output) => output,
        Err(message) => {
            (api.unprotect)(protect_count);
            return WorkResult::Error(message);
        }
    };

    // Slot 1 of the helper result holds the value of the last evaluated expression.
    let result = crate::codec::preserved_resource(
        api,
        (api.vector_elt)(helper_result, 1),
        crate::resource::ResourceKind::Generic,
    );

    let globals = match collect_eval_globals(api, env) {
        Ok(globals) => globals,
        Err(message) => {
            (api.unprotect)(protect_count);
            return WorkResult::Error(message);
        }
    };

    (api.unprotect)(protect_count);
    WorkResult::EvalSuccess(EvalSuccess {
        result: Some(result),
        globals,
        output,
    })
}

/// Parses and evaluates `EVAL_HELPER_SOURCE` in the global environment and returns
/// the resulting closure.
///
/// The returned `Sexp` has been passed to `preserve_object`. The caller owns that
/// preservation and must balance it with `release_object`, as `do_eval_inner` does
/// once the closure is referenced from a protected call.
///
/// # Errors
///
/// Returns `Err` if the helper source does not parse to at least one expression, or
/// if evaluating the first expression raises an R error.
///
/// # Panics
///
/// Panics only if `EVAL_HELPER_SOURCE` contains a NUL byte.
///
/// # Safety
///
/// `api` must belong to an initialized R runtime. NOTE(review): presumably callers
/// must also be on the thread that owns that runtime — confirm.
pub(crate) unsafe fn build_eval_helper(api: &RApi) -> Result<Sexp, String> {
    let source = c_string(EVAL_HELPER_SOURCE).expect("eval helper source contains no NUL bytes");
    let mut protect_count = 0;
    let source_sexp = (api.protect)((api.mk_string)(source.as_ptr()));
    protect_count += 1;

    let mut parse_status = ParseStatus::Null;
    let exprs = (api.protect)((api.parse_vector)(
        source_sexp,
        -1,
        &mut parse_status,
        *api.nil_value,
    ));
    protect_count += 1;

    if parse_status != ParseStatus::Ok || (api.xlength)(exprs) < 1 {
        (api.unprotect)(protect_count);
        return Err("failed to parse native eval helper".to_owned());
    }

    // Evaluating the `function(...) { ... }` expression yields the closure itself.
    let mut error_occurred = 0;
    let helper = (api.try_eval)(
        (api.vector_elt)(exprs, 0),
        *api.global_env,
        &mut error_occurred,
    );

    if error_occurred != 0 {
        (api.unprotect)(protect_count);
        return Err("failed to initialize native eval helper".to_owned());
    }

    // NOTE(review): `helper` is not protected between try_eval and preserve_object.
    // This relies on nothing allocating in between, and on preserve_object keeping its
    // argument alive while it allocates — confirm.
    (api.preserve_object)(helper);
    (api.unprotect)(protect_count);
    Ok(helper)
}

/// Reports whether the helper's status slot (element 0) is a logical `TRUE`.
///
/// A status that is not a logical vector, or that is empty, counts as failure.
///
/// # Safety
///
/// `helper_result` must be a live R list with at least one element, owned by the
/// runtime behind `api`.
pub(crate) unsafe fn helper_result_ok(api: &RApi, helper_result: Sexp) -> bool {
    let status = (api.vector_elt)(helper_result, 0);

    if (api.type_of)(status) != LGLSXP {
        return false;
    }
    if (api.xlength)(status) < 1 {
        return false;
    }
    (api.logical_elt)(status, 0) == 1
}

/// Copies the captured stdout, message, and warning streams out of the helper result.
///
/// The streams are read from elements 5, 6, and 7 respectively. Each character vector
/// is flattened into newline-terminated bytes.
unsafe fn output_from_helper_result(api: &RApi, helper_result: Sexp) -> Result<EvalOutput, String> {
    let stdout_slot = (api.vector_elt)(helper_result, 5);
    let stdout = copy_character_vector_with_newlines(api, stdout_slot)?;

    let messages_slot = (api.vector_elt)(helper_result, 6);
    let messages = copy_character_vector_with_newlines(api, messages_slot)?;

    let warnings_slot = (api.vector_elt)(helper_result, 7);
    let warnings = copy_character_vector_with_newlines(api, warnings_slot)?;

    Ok(EvalOutput {
        stdout,
        messages,
        warnings,
    })
}

/// Builds the structured error returned when submitted source fails to parse.
///
/// The error has R class `parseError`, no call, and no captured output.
fn parse_error() -> StructuredError {
    let message = b"parse error: R parse failed".to_vec();
    let r_class = vec![b"parseError".to_vec()];

    StructuredError {
        output: EvalOutput::empty(),
        call: None,
        r_class,
        message,
    }
}

/// Converts a failed helper result into a `StructuredError`.
///
/// Fields are read from these slots:
/// * message: element 2, falling back to `R evaluation failed` when absent or empty,
/// * R class vector: element 3,
/// * deparsed call: element 4,
/// * captured output streams: elements 5-7.
unsafe fn structured_error_from_helper_result(
    api: &RApi,
    helper_result: Sexp,
) -> Result<StructuredError, String> {
    let message = match copy_optional_character(api, (api.vector_elt)(helper_result, 2))? {
        Some(text) if !text.is_empty() => text,
        _ => b"R evaluation failed".to_vec(),
    };
    let r_class = copy_character_vector(api, (api.vector_elt)(helper_result, 3))?;
    let call = copy_optional_character(api, (api.vector_elt)(helper_result, 4))?;
    let output = output_from_helper_result(api, helper_result)?;

    Ok(StructuredError {
        message,
        r_class,
        call,
        output,
    })
}

/// Copies the first element of an R character value, treating `NULL` and `""` as absent.
///
/// # Errors
///
/// Returns `Err` when `value` is neither `NULL` nor a non-empty character vector, or
/// when copying the string fails.
///
/// # Safety
///
/// `value` must be a live R object owned by the runtime behind `api`.
pub(crate) unsafe fn copy_optional_character(
    api: &RApi,
    value: Sexp,
) -> Result<Option<Vec<u8>>, String> {
    if value == *api.nil_value {
        return Ok(None);
    }

    let is_string_scalar = (api.type_of)(value) == STRSXP && (api.xlength)(value) >= 1;
    if !is_string_scalar {
        return Err("native eval helper returned invalid character scalar".to_owned());
    }

    let first = (api.string_elt)(value, 0);
    let copied = crate::codec::copy_string(api, first)?;
    Ok(Some(copied).filter(|bytes| !bytes.is_empty()))
}

/// Copies every element of an R character vector into owned byte strings.
///
/// `NULL` is treated as an empty vector.
///
/// # Errors
///
/// Returns `Err` when `value` is not a character vector, when its length does not
/// fit in `usize`, or when copying any element fails.
///
/// # Safety
///
/// `value` must be a live R object owned by the runtime behind `api`.
pub(crate) unsafe fn copy_character_vector(
    api: &RApi,
    value: Sexp,
) -> Result<Vec<Vec<u8>>, String> {
    if value == *api.nil_value {
        return Ok(Vec::new());
    }

    if (api.type_of)(value) != STRSXP {
        return Err("native eval helper returned invalid character vector".to_owned());
    }

    let element_count = (api.xlength)(value);
    let reserved = match usize::try_from(element_count) {
        Ok(reserved) => reserved,
        Err(_) => return Err("native character vector is too large".to_owned()),
    };

    let mut copied = Vec::with_capacity(reserved);
    for position in 0..element_count {
        let element = (api.string_elt)(value, position);
        copied.push(crate::codec::copy_string(api, element)?);
    }
    Ok(copied)
}

unsafe fn copy_character_vector_with_newlines(api: &RApi, value: Sexp) -> Result<Vec<u8>, String> {
    let values = copy_character_vector(api, value)?;
    let mut output = Vec::new();

    for value in values {
        output.extend_from_slice(&value);
        output.push(b'\n');
    }

    Ok(output)
}

/// `(name, resource)` pairs for the bindings captured from an evaluation environment.
type EvalGlobals = Vec<(Vec<u8>, rustler::ResourceArc<crate::resource::RxResource>)>;

/// Snapshots every binding in `env` as a preserved resource.
///
/// The names come from `ls_internal(env, 1)`. Each value is looked up in `env`'s own
/// frame and wrapped with `preserved_resource` as a `Generic` resource.
///
/// # Errors
///
/// Returns `Err` when the name count does not fit in `usize`, when a name cannot be
/// copied, or when a name contains a NUL byte.
///
/// # Safety
///
/// `env` must be a live R environment owned by the runtime behind `api` and kept
/// reachable by the caller.
unsafe fn collect_eval_globals(api: &RApi, env: Sexp) -> Result<EvalGlobals, String> {
    let names = (api.protect)((api.ls_internal)(env, 1));
    // Do the fallible walk in a helper so every exit path, including `?` propagation,
    // passes through this single unprotect. Previously a failed copy_string returned
    // early and leaked this protect-stack entry.
    let globals = collect_globals_for_names(api, env, names);
    (api.unprotect)(1);
    globals
}

/// Builds a preserved resource for each name in the character vector `names`,
/// reading the values from `env`'s frame.
unsafe fn collect_globals_for_names(
    api: &RApi,
    env: Sexp,
    names: Sexp,
) -> Result<EvalGlobals, String> {
    let len = (api.xlength)(names);
    let capacity = usize::try_from(len)
        .map_err(|_| "native eval produced too many globals".to_owned())?;
    let mut globals = Vec::with_capacity(capacity);

    for index in 0..len {
        let name = crate::codec::copy_string(api, (api.string_elt)(names, index))?;
        let symbol = std::ffi::CString::new(name.as_slice())
            .map_err(|error| format!("native eval global name contains NUL byte: {error}"))?;

        let value = (api.find_var_in_frame)(env, (api.install)(symbol.as_ptr()));
        globals.push((
            name,
            crate::codec::preserved_resource(api, value, crate::resource::ResourceKind::Generic),
        ));
    }

    Ok(globals)
}