Skip to main content

native/rustq_nif/src/template.rs

use std::collections::HashMap;

use rustler::NifMap;
use syn::punctuated::Punctuated;
use syn::token::Comma;
use syn::visit_mut::{self, VisitMut};
use syn::{
    Arm, Expr, ExprMatch, ExprStruct, Field, FieldValue, Fields, File, FnArg, ImplItem, Item,
    Lifetime, Pat, Signature, Stmt, Type,
};

/// Structured description of a template rendering failure, handed back to
/// the caller through the `NifMap` derive.
#[derive(NifMap)]
pub(crate) struct ErrorInfo {
    /// Error category. The constructors in this file use `"invalid_template"`,
    /// `"invalid_splice"` or `"invalid_binding"`.
    r#type: String,
    /// Where the failure occurred: `"template"`, a splice fragment kind such as
    /// `"stmt"`, or a binding kind such as `"expr_binding"`.
    context: String,
    /// Human-readable message taken from the underlying `syn::Error`.
    message: String,
    /// Splice or binding name involved, when the error concerns one.
    name: Option<String>,
    /// Source text of the offending fragment or bound value, when applicable.
    fragment: Option<String>,
}

/// Lookup tables built from the caller-supplied bindings and splices.
///
/// Keys are placeholder names with the `__rq_` prefix already removed.
struct Context {
    /// Binding name -> source text substituted for the placeholder.
    bindings: HashMap<String, String>,
    /// Splice name -> fragments expanded in place of the placeholder.
    splices: HashMap<String, Vec<String>>,
}

/// Wraps a failure to parse the template source itself.
///
/// The resulting error carries no splice/binding name or fragment.
pub(crate) fn template_error(error: syn::Error) -> ErrorInfo {
    let message = error.to_string();
    ErrorInfo {
        r#type: String::from("invalid_template"),
        context: String::from("template"),
        message,
        name: None,
        fragment: None,
    }
}

/// Builds the error reported when a splice fragment fails to parse.
///
/// `context` is the fragment kind (e.g. `"stmt"`), `name` the splice name and
/// `fragment` the source text that failed.
fn splice_error(context: &str, name: &str, fragment: &str, error: syn::Error) -> ErrorInfo {
    let message = error.to_string();
    ErrorInfo {
        r#type: "invalid_splice".into(),
        context: context.into(),
        message,
        name: Some(name.into()),
        fragment: Some(fragment.into()),
    }
}

/// Builds the error reported when a bound value cannot be parsed as the
/// syntax its placeholder requires.
///
/// `context` is the binding kind (e.g. `"expr_binding"`), `name` the binding
/// name and `value` the bound source text, stored as the error's fragment.
fn binding_error(context: &str, name: &str, value: &str, error: syn::Error) -> ErrorInfo {
    let message = error.to_string();
    ErrorInfo {
        r#type: "invalid_binding".to_owned(),
        context: context.to_owned(),
        message,
        name: Some(name.to_owned()),
        fragment: Some(value.to_owned()),
    }
}

pub(crate) fn render_source(
    source: &str,
    bindings: Vec<(String, String)>,
    splices: Vec<(String, Vec<String>)>,
) -> Result<String, Vec<ErrorInfo>> {
    let mut file = syn::parse_file(source).map_err(|error| vec![template_error(error)])?;
    let context = Context {
        bindings: bindings.into_iter().collect(),
        splices: splices.into_iter().collect(),
    };

    splice_file(&mut file, &context)?;

    let mut binder = Binder::new(&context.bindings);
    binder.visit_file_mut(&mut file);
    binder.finish()?;

    Ok(prettyplease::unparse(&file))
}

/// Expands splice placeholders across the top-level items of `file`.
fn splice_file(file: &mut File, context: &Context) -> Result<(), Vec<ErrorInfo>> {
    let top_level = &mut file.items;
    splice_items(top_level, context)
}

/// Rebuilds `items`, expanding placeholders in two ways:
/// - an item-position `__rq_name!` macro is replaced by the items parsed
///   from splice `name`;
/// - every other item is recursed into via [`splice_item`].
///
/// Returns at the first failure; `items` is left partially rebuilt in that case.
fn splice_items(items: &mut Vec<Item>, context: &Context) -> Result<(), Vec<ErrorInfo>> {
    let original = std::mem::take(items);
    let mut expanded = Vec::with_capacity(original.len());

    for mut item in original {
        match item_splice_name(&item) {
            Some(name) => expanded.extend(parse_items(&name, context)?),
            None => {
                splice_item(&mut item, context)?;
                expanded.push(item);
            }
        }
    }

    *items = expanded;
    Ok(())
}

/// Applies nested splices inside one item.
///
/// Covered item kinds:
/// - free functions: arguments and top-level body statements;
/// - impl blocks: their items;
/// - inline modules: their items;
/// - structs: named fields.
///
/// All other item kinds pass through unchanged.
fn splice_item(item: &mut Item, context: &Context) -> Result<(), Vec<ErrorInfo>> {
    match item {
        Item::Fn(item_fn) => {
            splice_signature_inputs(&mut item_fn.sig, context)?;
            splice_stmts(&mut item_fn.block.stmts, context)
        }
        Item::Impl(item_impl) => splice_impl_items(&mut item_impl.items, context),
        Item::Mod(item_mod) => match &mut item_mod.content {
            Some((_, nested)) => splice_items(nested, context),
            None => Ok(()),
        },
        Item::Struct(item_struct) => splice_fields(&mut item_struct.fields, context),
        _ => Ok(()),
    }
}

/// Rebuilds an impl block's items.
///
/// An `__rq_name!` impl-item macro is replaced by the impl items of splice
/// `name`. Each method gets its argument list and top-level body statements
/// spliced.
fn splice_impl_items(items: &mut Vec<ImplItem>, context: &Context) -> Result<(), Vec<ErrorInfo>> {
    let mut expanded = Vec::with_capacity(items.len());

    for mut impl_item in std::mem::take(items) {
        if let Some(name) = impl_item_splice_name(&impl_item) {
            expanded.extend(parse_impl_items(&name, context)?);
            continue;
        }

        if let ImplItem::Fn(method) = &mut impl_item {
            splice_signature_inputs(&mut method.sig, context)?;
            splice_stmts(&mut method.block.stmts, context)?;
        }
        expanded.push(impl_item);
    }

    *items = expanded;
    Ok(())
}

/// Replaces each argument written as `__rq_name: _` in `signature` with the
/// arguments parsed from splice `name`.
fn splice_signature_inputs(
    signature: &mut Signature,
    context: &Context,
) -> Result<(), Vec<ErrorInfo>> {
    let previous = std::mem::take(&mut signature.inputs);
    let mut rebuilt = Punctuated::<FnArg, Comma>::new();

    for arg in previous {
        match arg_splice_name(&arg) {
            Some(name) => rebuilt.extend(parse_args(&name, context)?),
            None => rebuilt.push(arg),
        }
    }

    signature.inputs = rebuilt;
    Ok(())
}

/// Replaces each `__rq_name` field of a braced struct with the fields parsed
/// from splice `name`.
///
/// Tuple and unit structs are left unchanged.
fn splice_fields(fields: &mut Fields, context: &Context) -> Result<(), Vec<ErrorInfo>> {
    if let Fields::Named(braced) = fields {
        let mut rebuilt = Punctuated::<Field, Comma>::new();

        for field in std::mem::take(&mut braced.named) {
            match field_splice_name(&field) {
                Some(name) => rebuilt.extend(parse_fields(&name, context)?),
                None => rebuilt.push(field),
            }
        }

        braced.named = rebuilt;
    }
    Ok(())
}

/// Rebuilds a statement list.
///
/// A statement-position `__rq_name!` macro is replaced by the statements of
/// splice `name`. Every other statement is handed to a fresh [`Splicer`] for
/// expression-level splicing.
fn splice_stmts(stmts: &mut Vec<Stmt>, context: &Context) -> Result<(), Vec<ErrorInfo>> {
    let mut expanded = Vec::with_capacity(stmts.len());

    for mut stmt in std::mem::take(stmts) {
        let Some(name) = stmt_splice_name(&stmt) else {
            let mut splicer = Splicer::new(context);
            splicer.visit_stmt_mut(&mut stmt);
            splicer.finish()?;
            expanded.push(stmt);
            continue;
        };
        expanded.extend(parse_stmts(&name, context)?);
    }

    *stmts = expanded;
    Ok(())
}

/// Returns the splice name when `item` is an item-position `__rq_name!` macro.
fn item_splice_name(item: &Item) -> Option<String> {
    match item {
        Item::Macro(item_macro) => splice_name(&item_macro.mac.path),
        _ => None,
    }
}

/// Returns the splice name when `item` is an `__rq_name!` macro inside an impl
/// block.
fn impl_item_splice_name(item: &ImplItem) -> Option<String> {
    match item {
        ImplItem::Macro(item_macro) => splice_name(&item_macro.mac.path),
        _ => None,
    }
}

/// Returns the splice name for a typed argument whose pattern is the plain
/// identifier `__rq_name`.
///
/// Receivers and non-identifier patterns yield `None`.
fn arg_splice_name(arg: &FnArg) -> Option<String> {
    if let FnArg::Typed(pat_type) = arg {
        if let Pat::Ident(binding) = pat_type.pat.as_ref() {
            let ident = binding.ident.to_string();
            return ident.strip_prefix("__rq_").map(String::from);
        }
    }
    None
}

/// Returns the splice name when `stmt` is a statement-position `__rq_name!`
/// macro.
fn stmt_splice_name(stmt: &Stmt) -> Option<String> {
    if let Stmt::Macro(stmt_macro) = stmt {
        splice_name(&stmt_macro.mac.path)
    } else {
        None
    }
}

/// Returns the splice name for a named struct field called `__rq_name`.
///
/// Unnamed (tuple) fields yield `None`.
fn field_splice_name(field: &Field) -> Option<String> {
    let text = match &field.ident {
        Some(ident) => ident.to_string(),
        None => return None,
    };
    text.strip_prefix("__rq_").map(String::from)
}

/// Returns `name` when `path` is the single identifier `__rq_name`.
///
/// Multi-segment paths and identifiers without the prefix yield `None`.
fn splice_name(path: &syn::Path) -> Option<String> {
    let text = path.get_ident()?.to_string();
    text.strip_prefix("__rq_").map(str::to_owned)
}

/// Parses each fragment of splice `name` as a file body.
///
/// A single fragment may therefore contribute any number of items.
fn parse_items(name: &str, context: &Context) -> Result<Vec<Item>, Vec<ErrorInfo>> {
    parse_many_fragments("item", name, context, |source| {
        syn::parse_str::<syn::File>(source).map(|parsed| parsed.items)
    })
}

/// Parses each fragment of splice `name` as the contents of an impl block.
///
/// A single fragment may contribute several impl items.
fn parse_impl_items(name: &str, context: &Context) -> Result<Vec<ImplItem>, Vec<ErrorInfo>> {
    parse_many_fragments("impl_item", name, context, |source| {
        // Parse inside a throwaway impl block and keep only its items.
        syn::parse_str::<syn::ItemImpl>(&format!("impl __RustQ {{ {source} }}"))
            .map(|item_impl| item_impl.items)
    })
}

fn parse_fields(name: &str, context: &Context) -> Result<Vec<Field>, Vec<ErrorInfo>> {
    parse_fragments("field", name, context, parse_field)
}

/// Parses each fragment of splice `name` as one function argument.
fn parse_args(name: &str, context: &Context) -> Result<Vec<FnArg>, Vec<ErrorInfo>> {
    parse_fragments("arg", name, context, |source| syn::parse_str::<FnArg>(source))
}

fn parse_field_values(name: &str, context: &Context) -> Result<Vec<FieldValue>, Vec<ErrorInfo>> {
    parse_fragments("field_value", name, context, parse_field_value)
}

/// Parses each fragment of splice `name` as one statement.
fn parse_stmts(name: &str, context: &Context) -> Result<Vec<Stmt>, Vec<ErrorInfo>> {
    parse_fragments("stmt", name, context, |source| syn::parse_str::<Stmt>(source))
}

/// Parses each fragment of splice `name` as one match arm.
fn parse_arms(name: &str, context: &Context) -> Result<Vec<Arm>, Vec<ErrorInfo>> {
    parse_fragments("arm", name, context, |source| syn::parse_str::<Arm>(source))
}

fn parse_field(source: &str) -> syn::Result<Field> {
    let wrapped = format!("struct __RustQ {{ {source} }}");
    let item = syn::parse_str::<syn::ItemStruct>(&wrapped)?;
    item.fields
        .into_iter()
        .next()
        .ok_or_else(|| syn::Error::new(proc_macro2::Span::call_site(), "expected field"))
}

fn parse_field_value(source: &str) -> syn::Result<FieldValue> {
    let wrapped = format!("__RustQ {{ {source} }}");
    let expr = syn::parse_str::<ExprStruct>(&wrapped)?;
    expr.fields
        .into_iter()
        .next()
        .ok_or_else(|| syn::Error::new(proc_macro2::Span::call_site(), "expected field value"))
}

/// Parses every fragment of splice `name`, each yielding exactly one value.
///
/// This adapts a single-value parser to [`parse_many_fragments`], which
/// provides the lookup and error-collection behaviour.
fn parse_fragments<T, F>(
    fragment_context: &str,
    name: &str,
    context: &Context,
    mut parse: F,
) -> Result<Vec<T>, Vec<ErrorInfo>>
where
    F: FnMut(&str) -> syn::Result<T>,
{
    parse_many_fragments(fragment_context, name, context, |source| {
        let value = parse(source)?;
        Ok(vec![value])
    })
}

fn parse_many_fragments<T, F>(
    fragment_context: &str,
    name: &str,
    context: &Context,
    mut parse: F,
) -> Result<Vec<T>, Vec<ErrorInfo>>
where
    F: FnMut(&str) -> syn::Result<Vec<T>>,
{
    let Some(fragments) = context.splices.get(name) else {
        return Ok(Vec::new());
    };

    let mut parsed = Vec::new();
    let mut errors = Vec::new();

    for fragment in fragments {
        match parse(fragment) {
            Ok(values) => parsed.extend(values),
            Err(error) => errors.push(splice_error(fragment_context, name, fragment, error)),
        }
    }

    if errors.is_empty() {
        Ok(parsed)
    } else {
        Err(errors)
    }
}

/// Expression-level splicer run over function-body statements.
///
/// It expands `__rq_*` placeholders inside nested syntax such as struct
/// literals and match arms. Parse errors are accumulated rather than aborting
/// the walk.
struct Splicer<'a> {
    /// Splice fragments to expand placeholders with.
    context: &'a Context,
    /// Errors collected during the visit; reported by `finish`.
    errors: Vec<ErrorInfo>,
}

impl<'a> Splicer<'a> {
    fn new(context: &'a Context) -> Self {
        Self {
            context,
            errors: Vec::new(),
        }
    }

    fn finish(self) -> Result<(), Vec<ErrorInfo>> {
        if self.errors.is_empty() {
            Ok(())
        } else {
            Err(self.errors)
        }
    }
}

impl VisitMut for Splicer<'_> {
    fn visit_expr_struct_mut(&mut self, expr_struct: &mut ExprStruct) {
        let mut next = Punctuated::<FieldValue, Comma>::new();

        for mut field in std::mem::take(&mut expr_struct.fields) {
            if let Some(name) = field_value_splice_name(&field) {
                match parse_field_values(&name, self.context) {
                    Ok(fields) => {
                        for field in fields {
                            next.push(field);
                        }
                    }
                    Err(errors) => self.errors.extend(errors),
                }
            } else {
                visit_mut::visit_field_value_mut(self, &mut field);
                next.push(field);
            }
        }

        expr_struct.fields = next;
    }

    fn visit_expr_match_mut(&mut self, expr_match: &mut ExprMatch) {
        let mut next = Vec::new();

        for mut arm in std::mem::take(&mut expr_match.arms) {
            if let Some(name) = arm_splice_name(&arm) {
                match parse_arms(&name, self.context) {
                    Ok(arms) => next.extend(arms),
                    Err(errors) => self.errors.extend(errors),
                }
            } else {
                visit_mut::visit_arm_mut(self, &mut arm);
                next.push(arm);
            }
        }

        expr_match.arms = next;
        visit_mut::visit_expr_mut(self, &mut expr_match.expr);
    }
}

/// Returns the splice name for a struct-literal field whose member is the
/// identifier `__rq_name`.
///
/// Positional members (`0`, `1`, ...) yield `None`.
fn field_value_splice_name(field: &FieldValue) -> Option<String> {
    match &field.member {
        syn::Member::Named(ident) => {
            let text = ident.to_string();
            text.strip_prefix("__rq_").map(String::from)
        }
        _ => None,
    }
}

/// Returns the splice name for a match arm whose whole pattern is the
/// identifier `__rq_name`.
fn arm_splice_name(arm: &Arm) -> Option<String> {
    if let Pat::Ident(binding) = &arm.pat {
        let text = binding.ident.to_string();
        return text.strip_prefix("__rq_").map(str::to_owned);
    }
    None
}

/// Substitutes `__rq_*` placeholders with bound source text.
///
/// Covered placeholder kinds are identifiers, lifetimes, expression-position
/// macros and type-position macros. Errors are accumulated and reported by
/// `finish`.
struct Binder<'a> {
    /// Binding name (without the `__rq_` prefix) -> replacement source text.
    bindings: &'a HashMap<String, String>,
    /// Errors collected during the visit.
    errors: Vec<ErrorInfo>,
}

impl<'a> Binder<'a> {
    fn new(bindings: &'a HashMap<String, String>) -> Self {
        Self {
            bindings,
            errors: Vec::new(),
        }
    }

    fn finish(self) -> Result<(), Vec<ErrorInfo>> {
        if self.errors.is_empty() {
            Ok(())
        } else {
            Err(self.errors)
        }
    }

    fn binding_for_ident(&self, ident: &syn::Ident, prefix: &str) -> Option<&'a str> {
        ident
            .to_string()
            .strip_prefix(prefix)
            .and_then(|name| self.bindings.get(name))
            .map(String::as_str)
    }

    fn binding_for_macro_path(&self, path: &syn::Path, prefix: &str) -> Option<(String, &'a str)> {
        let ident = path.get_ident()?;
        let ident = ident.to_string();
        let name = ident.strip_prefix(prefix)?.to_string();
        let value = self.bindings.get(&name)?;
        Some((name, value.as_str()))
    }
}

impl VisitMut for Binder<'_> {
    fn visit_ident_mut(&mut self, ident: &mut syn::Ident) {
        if let Some(value) = self.binding_for_ident(ident, "__rq_") {
            match syn::parse_str::<syn::Ident>(value) {
                Ok(parsed) => *ident = parsed,
                Err(error) => self.errors.push(binding_error(
                    "ident_binding",
                    &ident.to_string(),
                    value,
                    error,
                )),
            }
        }
    }

    fn visit_lifetime_mut(&mut self, lifetime: &mut Lifetime) {
        let name = lifetime.ident.to_string();

        if let Some(value) = name
            .strip_prefix("__rq_")
            .and_then(|name| self.bindings.get(name))
        {
            let value = value.trim_start_matches('\'');
            *lifetime = Lifetime::new(&format!("'{value}"), lifetime.apostrophe);
        }
    }

    fn visit_expr_mut(&mut self, expr: &mut Expr) {
        if let Expr::Macro(expr_macro) = expr {
            if let Some((name, value)) = self.binding_for_macro_path(&expr_macro.mac.path, "__rq_")
            {
                match syn::parse_str::<Expr>(value) {
                    Ok(parsed) => *expr = parsed,
                    Err(error) => {
                        self.errors
                            .push(binding_error("expr_binding", &name, value, error))
                    }
                }
                return;
            }
        }

        visit_mut::visit_expr_mut(self, expr);
    }

    fn visit_type_mut(&mut self, ty: &mut Type) {
        if let Type::Macro(type_macro) = ty {
            if let Some((name, value)) = self.binding_for_macro_path(&type_macro.mac.path, "__rq_")
            {
                match syn::parse_str::<Type>(value) {
                    Ok(parsed) => *ty = parsed,
                    Err(error) => {
                        self.errors
                            .push(binding_error("type_binding", &name, value, error))
                    }
                }
                return;
            }
        }

        visit_mut::visit_type_mut(self, ty);
    }

    fn visit_arm_mut(&mut self, arm: &mut Arm) {
        visit_mut::visit_arm_mut(self, arm);
    }
}