// native/rx_rust_nif/src/plot.rs

use crate::owner::OwnerState;
use crate::r_api::{c_string, ParseStatus, RApi, Sexp, INTSXP, RAWSXP, STRSXP, VECSXP};
use crate::work::{EvalOutput, PlotSuccess, PlotWork, StructuredError, WorkResult};
use libc::c_void;

/// R source for the plot-capture helper, parsed and evaluated by `build_plot_helper`.
///
/// The source is a single anonymous function `function(exprs, env, options)` that:
/// - attaches grDevices, graphics, stats, utils and datasets when they are not already on the
///   search path;
/// - opens a PNG device (cairo type when `capabilities("cairo")` is TRUE) that writes
///   `page-%06d.png` files into a private temporary directory (created with mode 0700 and
///   removed on exit);
/// - evaluates each expression of `exprs` in `env`, printing visible ggplot values so they
///   actually draw;
/// - temporarily sets `options(device=)` to a sentinel that raises an error, so a plot issued
///   when the capture device is not open fails instead of opening some other device, and
///   reports an error if user code replaced that option;
/// - closes every graphics device opened during the call and re-selects the previously
///   current device when it is still open;
/// - captures stdout with `capture.output`, and collects (and muffles) messages and warnings;
/// - reads the pages back, enforcing `options$max_pages`, `options$max_bytes`, non-empty files
///   and the PNG signature.
///
/// It always returns an 8-element list `(ok, payload, message, class, call, stdout, messages,
/// warnings)`. On success `ok` is `TRUE` and `payload` is `list(width, height, pages)` where
/// `pages` is a list of raw vectors; on failure `ok` is `FALSE`, `payload` is `NULL`, and slots
/// 2-4 describe the condition. The Rust side reads these slots by position, so this layout
/// must stay in sync with `plot_success_from_helper_result` and
/// `structured_plot_error_from_helper_result`.
///
/// NOTE(review): the closure is evaluated in the global environment, so unqualified base calls
/// such as `stop` or `tryCatch` resolve through the global environment first and could be
/// shadowed by same-named objects defined there — consider evaluating it in the base
/// environment.
///
/// NOTE(review): warnings are joined with `collapse = ""`; R warning messages usually carry no
/// trailing newline, so several warnings may run together — confirm this is intended.
const PLOT_HELPER_SOURCE: &str = r#"function(exprs, env, options) {
  load_packages <- function() {
    for (pkg in c("grDevices", "graphics", "stats", "utils", "datasets")) {
      if (!paste0("package:", pkg) %in% search()) {
        suppressPackageStartupMessages(library(pkg, character.only = TRUE, pos = length(search())))
      }
    }
  }

  is_ggplot <- function(value) {
    inherits(value, "ggplot") || inherits(value, "ggplot2::ggplot")
  }

  read_pages <- function(tmpdir, max_pages, max_bytes) {
    files <- sort(list.files(tmpdir, pattern = "^page-[0-9]+[.]png$", full.names = TRUE))
    if (length(files) > max_pages) stop("plot produced too many pages", call. = FALSE)
    if (length(files) == 0L) return(list())
    info <- file.info(files)
    if (any(is.na(info$size) | info$size <= 0L)) {
      stop("plot device produced an empty PNG file", call. = FALSE)
    }
    if (sum(info$size) > max_bytes) {
      stop("plot PNG output exceeds byte limit", call. = FALSE)
    }
    sig <- as.raw(c(0x89, 0x50, 0x4e, 0x47, 0x0d, 0x0a, 0x1a, 0x0a))
    lapply(seq_along(files), function(i) {
      bytes <- readBin(files[[i]], what = "raw", n = info$size[[i]])
      if (length(bytes) < 8L || !identical(bytes[1:8], sig)) {
        stop("plot device did not produce a valid PNG file", call. = FALSE)
      }
      bytes
    })
  }

  messages <- character()
  warnings <- character()
  plot_error <- NULL
  close_error <- NULL
  read_error <- NULL
  opts <- NULL
  plot_dev <- NULL
  old_devices <- integer()
  old_dev <- 1L
  old_options <- NULL
  devices_snapshot <- FALSE
  device_sentinel <- function(...) {
    stop("Rx plot capture device is not available", call. = FALSE)
  }

  close_plot_device <- function() {
    if (!devices_snapshot) return()

    open_devices <- grDevices::dev.list()
    if (!is.null(open_devices)) {
      opened_devices <- setdiff(unname(as.integer(open_devices)), old_devices)
      for (device in rev(sort(opened_devices))) {
        tryCatch(
          invisible(grDevices::dev.off(which = device)),
          error = function(e) { if (is.null(close_error)) close_error <<- e }
        )
      }
    }

    open_devices <- grDevices::dev.list()
    if (!is.null(plot_dev) && !is.null(open_devices) &&
        plot_dev %in% unname(as.integer(open_devices))) {
      tryCatch(
        invisible(grDevices::dev.off(which = plot_dev)),
        error = function(e) { if (is.null(close_error)) close_error <<- e }
      )
    }

    open_devices <- grDevices::dev.list()
    if (!is.null(open_devices) && old_dev != 1L &&
        old_dev %in% unname(as.integer(open_devices))) {
      try(invisible(grDevices::dev.set(which = old_dev)), silent = TRUE)
    }
  }

  tmpdir <- tempfile("rx-native-plot-")
  dir.create(tmpdir, mode = "0700")
  on.exit(unlink(tmpdir, recursive = TRUE, force = TRUE), add = TRUE)

  stdout <- utils::capture.output({
    invisible(withCallingHandlers(
      tryCatch(
        {
          load_packages()
          opts <- list(
            width = as.integer(options$width),
            height = as.integer(options$height),
            res = as.integer(options$res),
            pointsize = as.integer(options$pointsize),
            max_pages = as.integer(options$max_pages),
            max_bytes = as.numeric(options$max_bytes)
          )

          old_devices <- grDevices::dev.list()
          old_devices <- if (is.null(old_devices)) integer() else unname(as.integer(old_devices))
          old_dev <- unname(as.integer(grDevices::dev.cur()))
          devices_snapshot <- TRUE

          old_options <- base::options(device = device_sentinel)

          file_pattern <- file.path(tmpdir, "page-%06d.png")
          png_args <- list(
            filename = file_pattern,
            width = opts$width,
            height = opts$height,
            units = "px",
            pointsize = opts$pointsize,
            bg = "white",
            res = opts$res
          )
          if (isTRUE(capabilities("cairo"))) {
            png_args$type <- "cairo"
          }

          do.call(grDevices::png, png_args)
          plot_dev <- unname(as.integer(grDevices::dev.cur()))

          for (expr in exprs) {
            evaluated <- withVisible(eval(expr, envir = env))
            if (isTRUE(evaluated$visible) && is_ggplot(evaluated$value)) {
              print(evaluated$value)
            }
          }

          if (!identical(base::getOption("device"), device_sentinel)) {
            stop("Rx plot capture does not allow changing options(device=...)", call. = FALSE)
          }

          NULL
        },
        error = function(e) { plot_error <<- e; NULL },
        finally = {
          close_plot_device()
          if (!is.null(old_options)) base::options(old_options)
        }
      ),
      message = function(m) {
        messages <<- c(messages, conditionMessage(m))
        invokeRestart("muffleMessage")
      },
      warning = function(w) {
        warnings <<- c(warnings, conditionMessage(w))
        invokeRestart("muffleWarning")
      }
    ))
  })

  output_stdout <- paste(stdout, collapse = "\n")
  output_messages <- paste(messages, collapse = "")
  output_warnings <- paste(warnings, collapse = "")

  error_result <- function(e) {
    call <- conditionCall(e)
    call_text <- if (is.null(call)) NULL else paste(deparse(call), collapse = "\n")
    list(FALSE, NULL, conditionMessage(e), class(e), call_text,
         output_stdout, output_messages, output_warnings)
  }

  if (!is.null(plot_error)) return(error_result(plot_error))
  if (!is.null(close_error)) return(error_result(close_error))

  pages <- tryCatch(
    read_pages(tmpdir, opts$max_pages, opts$max_bytes),
    error = function(e) { read_error <<- e; NULL }
  )

  if (!is.null(read_error)) return(error_result(read_error))
  if (length(pages) == 0L) return(error_result(simpleError("R code produced no plot")))

  payload <- list(width = opts$width, height = opts$height, pages = pages)
  list(TRUE, payload, NULL, NULL, NULL, output_stdout, output_messages, output_warnings)
}"#;

pub fn do_plot(state: &mut OwnerState, work: &PlotWork) -> WorkResult {
    if let Some(failure) = state.terminal_init_failure() {
        return WorkResult::NativeInitFailed(failure);
    }

    if !state.is_initialized() {
        return WorkResult::Error("embedded R runtime is not initialized".to_owned());
    }

    let Some(api) = state.r.as_ref() else {
        return WorkResult::Error("embedded R API is not initialized".to_owned());
    };

    let mut ctx = PlotExecContext {
        api: api as *const RApi,
        work: work as *const PlotWork,
        result: None,
    };

    let ok = unsafe {
        (api.toplevel_exec)(
            plot_toplevel_exec,
            &mut ctx as *mut PlotExecContext as *mut c_void,
        )
    };

    if ok == 0 {
        return WorkResult::Error("R non-local error during plot".to_owned());
    }

    ctx.result
        .unwrap_or_else(|| WorkResult::Error("R plot failed".to_owned()))
}

/// State handed across the C boundary from `do_plot` to the `plot_toplevel_exec` trampoline.
///
/// It travels as a `*mut c_void`, so the borrowed inputs are stored as raw pointers that are
/// only valid while `do_plot` is on the stack.
struct PlotExecContext {
    /// Written by the trampoline; remains `None` if R exits non-locally before it is set.
    result: Option<WorkResult>,
    /// R API table borrowed from the owner state.
    api: *const RApi,
    /// Plot request borrowed from the caller.
    work: *const PlotWork,
}

/// C-ABI callback run by R's top-level executor on behalf of `do_plot`.
///
/// # Safety
/// `data` must point to a live `PlotExecContext` whose `api` and `work` pointers are valid for
/// the whole call.
unsafe extern "C" fn plot_toplevel_exec(data: *mut c_void) {
    let exec = data.cast::<PlotExecContext>();
    let outcome = do_plot_inner(&*(*exec).api, &*(*exec).work);
    (*exec).result = Some(outcome);
}

/// Evaluates `work.source` through the native plot helper and converts the helper's result.
///
/// Runs as the body of `plot_toplevel_exec`. It creates a fresh environment enclosed by the
/// global environment and binds every `work.globals` entry in it. It then parses
/// `work.source`, calls the helper closure as `helper(exprs, env, options)` via `try_eval`,
/// and translates the helper's 8-slot result list into a `WorkResult`.
///
/// `protect_count` tracks how many objects this call has pushed onto the R protect stack,
/// including those pushed inside `make_plot_options`. Every return path unprotects exactly
/// that many.
///
/// # Safety
/// `api` must be a valid API table for an initialized R runtime. The call must happen on the
/// thread that owns that runtime (presumably guaranteed by the caller — TODO confirm).
unsafe fn do_plot_inner(api: &RApi, work: &PlotWork) -> WorkResult {
    let source = match c_string(&work.source) {
        Ok(source) => source,
        Err(error) => return WorkResult::Error(format!("source contains NUL byte: {error}")),
    };

    // Number of PROTECT calls made so far; each early return unprotects exactly this many.
    let mut protect_count = 0;
    // Fresh environment for the user's code, enclosed by the global environment.
    // NOTE(review): assumes `new_env` wraps `R_NewEnv(enclos, hash, size)`, i.e. a hashed
    // environment with initial size 29 — confirm against the RApi binding.
    let env = (api.protect)((api.new_env)(*api.global_env, 1, 29));
    protect_count += 1;

    // Bind each requested global into `env` so the user's code can refer to it by name.
    for (name, resource) in &work.globals {
        let name = match c_string(name) {
            Ok(name) => name,
            Err(error) => {
                (api.unprotect)(protect_count);
                return WorkResult::Error(format!("global name contains NUL byte: {error}"));
            }
        };

        let value = match crate::codec::live_resource_sexp(resource) {
            Ok(value) => value,
            Err(message) => {
                (api.unprotect)(protect_count);
                return WorkResult::Error(message);
            }
        };

        // NOTE(review): `value` is not protected here while `install` may allocate. This is
        // only GC-safe if `live_resource_sexp` returns an object kept alive elsewhere — confirm.
        (api.define_var)((api.install)(name.as_ptr()), value, env);
    }

    let source_sexp = (api.protect)((api.mk_string)(source.as_ptr()));
    protect_count += 1;

    let mut parse_status = ParseStatus::Null;
    let exprs = (api.protect)((api.parse_vector)(
        source_sexp,
        -1,
        &mut parse_status,
        *api.nil_value,
    ));
    protect_count += 1;

    // Any status other than Ok/Null is reported as a generic parse error with no output.
    if parse_status != ParseStatus::Ok && parse_status != ParseStatus::Null {
        (api.unprotect)(protect_count);
        return WorkResult::StructuredError(parse_error());
    }

    // The helper comes back preserved rather than protected. It is therefore released
    // explicitly below instead of being counted in `protect_count`.
    let helper = match build_plot_helper(api) {
        Ok(helper) => helper,
        Err(message) => {
            (api.unprotect)(protect_count);
            return WorkResult::Error(message);
        }
    };

    // `make_plot_options` adds its own PROTECT calls to `protect_count`.
    let options = match make_plot_options(api, work, &mut protect_count) {
        Ok(options) => options,
        Err(message) => {
            (api.release_object)(helper);
            (api.unprotect)(protect_count);
            return WorkResult::Error(message);
        }
    };

    // Once the protected `call` references the helper, its preservation can be dropped.
    let call = (api.protect)((api.lang4)(helper, exprs, env, options));
    protect_count += 1;
    (api.release_object)(helper);

    // R errors raised by the call are reported through `error_occurred`.
    let mut error_occurred = 0;
    let helper_result = (api.protect)((api.try_eval)(call, *api.global_env, &mut error_occurred));
    protect_count += 1;

    // The helper always returns an 8-element list; anything else means the call itself failed.
    if error_occurred != 0
        || (api.type_of)(helper_result) != VECSXP
        || (api.xlength)(helper_result) < 8
    {
        (api.unprotect)(protect_count);
        return WorkResult::Error("R plot failed".to_owned());
    }

    // NOTE(review): presumably `helper_result_ok` inspects slot 0 (the TRUE/FALSE flag) — confirm.
    if !crate::eval::helper_result_ok(api, helper_result) {
        // Copy the error out of R memory before `helper_result` loses its protection.
        let error = match structured_plot_error_from_helper_result(api, helper_result) {
            Ok(error) => error,
            Err(message) => {
                (api.unprotect)(protect_count);
                return WorkResult::Error(message);
            }
        };

        (api.unprotect)(protect_count);
        return WorkResult::StructuredError(error);
    }

    // Pages are copied into owned buffers before protection is released.
    let success = match plot_success_from_helper_result(api, helper_result, work) {
        Ok(success) => success,
        Err(message) => {
            (api.unprotect)(protect_count);
            return WorkResult::Error(message);
        }
    };

    (api.unprotect)(protect_count);
    WorkResult::PlotSuccess(success)
}

/// Parses and evaluates `PLOT_HELPER_SOURCE`, returning the resulting R closure.
///
/// On success the closure has been passed to `preserve_object`. The caller must balance that
/// with `release_object`. No protect-stack entries are left behind on any path.
unsafe fn build_plot_helper(api: &RApi) -> Result<Sexp, String> {
    let helper_source =
        c_string(PLOT_HELPER_SOURCE).expect("plot helper source contains no NUL bytes");

    // Two objects are protected below: the source string and the parsed expression vector.
    let source_sexp = (api.protect)((api.mk_string)(helper_source.as_ptr()));
    let mut status = ParseStatus::Null;
    let parsed = (api.protect)((api.parse_vector)(
        source_sexp,
        -1,
        &mut status,
        *api.nil_value,
    ));
    let protected = 2;

    let parsed_ok = status == ParseStatus::Ok && (api.xlength)(parsed) >= 1;
    if !parsed_ok {
        (api.unprotect)(protected);
        return Err("failed to parse native plot helper".to_owned());
    }

    let mut failed = 0;
    let helper = (api.try_eval)((api.vector_elt)(parsed, 0), *api.global_env, &mut failed);
    if failed == 0 {
        // Preserve before dropping protection so the closure survives the unprotect below.
        (api.preserve_object)(helper);
    }
    (api.unprotect)(protected);

    if failed != 0 {
        return Err("failed to initialize native plot helper".to_owned());
    }
    Ok(helper)
}

/// Builds the named R list of scalar integers passed to the helper as `options`.
///
/// The names are `width`, `height`, `res`, `pointsize`, `max_pages` and `max_bytes`, and the
/// values are taken from `work`. Every object protected here is added to `*protect_count`;
/// the caller is responsible for unprotecting them. Currently never returns `Err`.
unsafe fn make_plot_options(
    api: &RApi,
    work: &PlotWork,
    protect_count: &mut i32,
) -> Result<Sexp, String> {
    let list = (api.protect)((api.alloc_vector)(VECSXP, 6));
    *protect_count += 1;
    let labels = (api.protect)((api.alloc_vector)(STRSXP, 6));
    *protect_count += 1;

    let entries = [
        ("width", work.width),
        ("height", work.height),
        ("res", work.res),
        ("pointsize", work.pointsize),
        ("max_pages", work.max_pages),
        ("max_bytes", work.max_bytes),
    ];

    for (slot, &(label, value)) in entries.iter().enumerate() {
        let label = c_string(label).expect("plot option name contains no NUL bytes");
        let label_len = label.as_bytes().len() as i32;
        let label_sexp = (api.protect)((api.mk_char_len_ce)(
            label.as_ptr(),
            label_len,
            crate::r_api::CE_UTF8,
        ));
        *protect_count += 1;

        let slot = slot as isize;
        (api.set_string_elt)(labels, slot, label_sexp);
        (api.set_vector_elt)(list, slot, (api.scalar_integer)(value));
    }

    (api.set_attrib)(list, *api.names_symbol, labels);
    Ok(list)
}

/// Copies a successful helper result into an owned `PlotSuccess`.
///
/// Re-validates the payload the R helper produced rather than trusting it: its shape, the
/// page count, the byte budget and each page's PNG signature. Every page is then copied out
/// of R memory, so the result does not depend on R objects staying alive.
///
/// # Errors
/// Returns a message in these cases:
/// - the payload is malformed;
/// - it contains no pages;
/// - it exceeds `work.max_pages` or the hard limit of 1000 pages;
/// - it exceeds `work.max_bytes`;
/// - any page is empty or not a PNG.
unsafe fn plot_success_from_helper_result(
    api: &RApi,
    helper_result: Sexp,
    work: &PlotWork,
) -> Result<PlotSuccess, String> {
    // Absolute ceiling on pages, applied even if `work.max_pages` is configured higher.
    const HARD_PAGE_LIMIT: usize = 1000;

    let payload = (api.vector_elt)(helper_result, 1);
    if (api.type_of)(payload) != VECSXP || (api.xlength)(payload) < 3 {
        return Err("native plot helper returned malformed payload".to_owned());
    }

    let width = (api.vector_elt)(payload, 0);
    let height = (api.vector_elt)(payload, 1);
    let pages = (api.vector_elt)(payload, 2);

    if (api.type_of)(width) != INTSXP
        || (api.xlength)(width) < 1
        || (api.type_of)(height) != INTSXP
        || (api.xlength)(height) < 1
        || (api.type_of)(pages) != VECSXP
    {
        return Err("native plot helper returned invalid payload fields".to_owned());
    }

    // A negative configured limit can never be satisfied by any plot.
    let max_pages =
        usize::try_from(work.max_pages).map_err(|_| "plot produced too many pages".to_owned())?;

    let page_count = (api.xlength)(pages);
    // An empty page list means nothing was drawn; report that instead of a page-limit error.
    if page_count <= 0 {
        return Err("R code produced no plot".to_owned());
    }

    let capacity = usize::try_from(page_count)
        .map_err(|_| "native plot page count is too large".to_owned())?;
    if capacity > max_pages || capacity > HARD_PAGE_LIMIT {
        return Err("plot produced too many pages".to_owned());
    }

    let max_bytes = usize::try_from(work.max_bytes)
        .map_err(|_| "plot PNG output exceeds byte limit".to_owned())?;
    let mut copied_pages = Vec::with_capacity(capacity);
    let mut total_bytes: usize = 0;

    for index in 0..page_count {
        let page = (api.vector_elt)(pages, index);
        if (api.type_of)(page) != RAWSXP {
            return Err("native plot helper returned non-raw page".to_owned());
        }

        let len = usize::try_from((api.xlength)(page))
            .map_err(|_| "native plot page is too large".to_owned())?;
        if len == 0 {
            return Err("plot device produced an empty PNG file".to_owned());
        }

        // Enforce the byte budget before touching the page data.
        total_bytes = total_bytes
            .checked_add(len)
            .ok_or_else(|| "plot PNG output exceeds byte limit".to_owned())?;
        if total_bytes > max_bytes {
            return Err("plot PNG output exceeds byte limit".to_owned());
        }

        let raw_ptr = (api.raw)(page);
        if raw_ptr.is_null() {
            return Err("native plot helper returned invalid raw page".to_owned());
        }

        // The pointer is non-null and `page` is a raw vector of exactly `len` bytes.
        let raw = std::slice::from_raw_parts(raw_ptr, len);
        if !is_png_bytes(raw) {
            return Err("plot device did not produce a valid PNG file".to_owned());
        }

        copied_pages.push(raw.to_vec());
    }

    Ok(PlotSuccess {
        width: (api.integer_elt)(width, 0),
        height: (api.integer_elt)(height, 0),
        pages: copied_pages,
        output: output_from_plot_helper_result(api, helper_result)?,
    })
}

/// Converts a failed helper result into a `StructuredError`.
///
/// Reads the message (slot 2), condition class (slot 3), deparsed call (slot 4) and the
/// captured output. A missing or empty message falls back to `"R plot failed"`.
unsafe fn structured_plot_error_from_helper_result(
    api: &RApi,
    helper_result: Sexp,
) -> Result<StructuredError, String> {
    let message =
        match crate::eval::copy_optional_character(api, (api.vector_elt)(helper_result, 2))? {
            Some(text) if !text.is_empty() => text,
            _ => b"R plot failed".to_vec(),
        };
    let r_class = crate::eval::copy_character_vector(api, (api.vector_elt)(helper_result, 3))?;
    let call = crate::eval::copy_optional_character(api, (api.vector_elt)(helper_result, 4))?;
    let output = output_from_plot_helper_result(api, helper_result)?;

    Ok(StructuredError {
        message,
        r_class,
        call,
        output,
    })
}

/// Collects the helper's captured stdout, messages and warnings (result slots 5, 6 and 7).
///
/// Stdout elements are joined with newlines; messages and warnings are concatenated as-is.
unsafe fn output_from_plot_helper_result(
    api: &RApi,
    helper_result: Sexp,
) -> Result<EvalOutput, String> {
    let stdout = copy_character_vector_join(api, (api.vector_elt)(helper_result, 5), b"\n")?;
    let messages = copy_character_vector_join(api, (api.vector_elt)(helper_result, 6), b"")?;
    let warnings = copy_character_vector_join(api, (api.vector_elt)(helper_result, 7), b"")?;

    Ok(EvalOutput {
        stdout,
        messages,
        warnings,
    })
}

/// Copies every element of an R character vector into one byte buffer, with `separator`
/// inserted between consecutive elements.
///
/// R `NULL` yields an empty buffer; any other non-character value is an error.
unsafe fn copy_character_vector_join(
    api: &RApi,
    value: Sexp,
    separator: &[u8],
) -> Result<Vec<u8>, String> {
    if value == *api.nil_value {
        return Ok(Vec::new());
    }

    if (api.type_of)(value) != STRSXP {
        return Err("native plot helper returned invalid character vector".to_owned());
    }

    let mut joined = Vec::new();
    for index in 0..(api.xlength)(value) {
        // Appending an empty separator is a no-op, so no special case is needed for it.
        if index > 0 {
            joined.extend_from_slice(separator);
        }
        let element = crate::codec::copy_string(api, (api.string_elt)(value, index))?;
        joined.extend_from_slice(&element);
    }

    Ok(joined)
}

/// Returns `true` when `bytes` begins with the eight-byte PNG file signature.
fn is_png_bytes(bytes: &[u8]) -> bool {
    const PNG_SIGNATURE: [u8; 8] = [0x89, b'P', b'N', b'G', b'\r', b'\n', 0x1a, b'\n'];
    bytes.starts_with(&PNG_SIGNATURE)
}

/// Structured error reported when `work.source` fails to parse.
///
/// Carries the `parseError` class, no call, and no captured output.
fn parse_error() -> StructuredError {
    let message = b"parse error: R parse failed".to_vec();
    let r_class = vec![b"parseError".to_vec()];

    StructuredError {
        message,
        r_class,
        call: None,
        output: EvalOutput::empty(),
    }
}