defmodule DotPrompt.Parser.Validator do
@moduledoc """
Walks the AST checking types, bounds, nesting depth, and params.
"""
@max_nesting 3
@doc """
Validates a parsed prompt AST.

Structural checks on `ast.body` run first and all of their findings are reported together.
Only when the body is structurally sound are the declaration checks run (params, fragment
references, fragment specs), stopping at the first one that fails.

Returns `{:ok, []}` on success, or `{:error, message}` where `message` is a single string
(structural findings are joined with `"; "`).
"""
def validate(ast) do
  case collect_errors(ast.body, 0) do
    [] ->
      with :ok <- validate_params_declared(ast),
           :ok <- validate_fragments_declared(ast),
           :ok <- validate_fragments(ast) do
        {:ok, []}
      end

    errors ->
      # Each branch of an over-deep block (then / else / every elif) reports the very same
      # message, so collapse the repeats before building the error string.
      {:error, errors |> Enum.uniq() |> Enum.join("; ")}
  end
end
# Checks the fragment specs declared in the init block:
#   * a `from` path must not end with "/"
#   * a `matchRe` that starts with "@" must name a declared enum parameter
# Returns :ok, or {:error, message} with every finding joined by "; ".
defp validate_fragments(%{init: nil}), do: :ok

defp validate_fragments(ast) do
  fragment_specs = parse_fragment_declarations(ast.init)
  param_specs = parse_param_declarations(ast.init)

  problems =
    Enum.flat_map(fragment_specs, fn {_name, spec} ->
      source = Map.get(spec, :from)
      selector = Map.get(spec, :matchre) || Map.get(spec, :matchRe)

      path_problems =
        if source && String.ends_with?(source, "/") do
          ["invalid_fragment_path: trailing slashes not allowed in '#{source}'"]
        else
          []
        end

      selector_problems =
        if selector && String.starts_with?(selector, "@") do
          # The selector is looked up verbatim, "@" included
          case Map.get(param_specs, selector) do
            nil ->
              ["unknown_variable: #{selector} referenced in matchRe but not declared"]

            %{type: :enum} ->
              []

            declared ->
              [
                "invalid_matchre_type: matchRe requires enum variable, but #{selector} is #{declared.type}"
              ]
          end
        else
          []
        end

      # Path findings come before selector findings for the same fragment
      path_problems ++ selector_problems
    end)

  case problems do
    [] -> :ok
    _ -> {:error, Enum.join(problems, "; ")}
  end
end
# Walks a list of body nodes and returns a flat list of structural error strings.
# `depth` is the nesting level of `nodes`; anything deeper than @max_nesting is reported
# without looking at the nodes themselves. Non-list input yields no errors.
defp collect_errors(_nodes, depth) when depth > @max_nesting,
  do: ["nesting_exceeded: depth #{depth} exceeds maximum of #{@max_nesting}"]

defp collect_errors(nodes, depth) when is_list(nodes) do
  deeper = depth + 1

  # Shared by `case` and `vary`: a branch is either {id, label, nodes} or a bare `if` node
  walk_branches = fn branches ->
    Enum.flat_map(branches, fn
      {_id, _label, branch_nodes} -> collect_errors(branch_nodes || [], deeper)
      {:if, _, _, _, _, _} = nested_if -> collect_errors([nested_if], deeper)
      _other -> []
    end)
  end

  Enum.flat_map(nodes, fn
    {:if, _var, _condition, then_nodes, elifs, else_node} ->
      then_errors = collect_errors(then_nodes || [], deeper)

      # The else branch is checked ahead of the elif branches
      other_errors =
        Enum.flat_map([{nil, else_node} | elifs], fn {_, branch_nodes} ->
          collect_errors(branch_nodes || [], deeper)
        end)

      then_errors ++ other_errors

    {:case, _var, branches} ->
      walk_branches.(branches)

    {:vary, nil, _branches} ->
      ["invalid_vary: vary requires an enum variable"]

    {:vary, _var, branches} ->
      walk_branches.(branches)

    _other ->
      []
  end)
end

defp collect_errors(_, _), do: []
@doc """
Returns the warning list produced by `validate/1`, or `[]` when validation fails.
"""
def get_warnings(ast) do
  outcome = validate(ast)

  if match?({:ok, _}, outcome) do
    elem(outcome, 1)
  else
    []
  end
end
@doc """
Validates caller-supplied `params` against parsed `declarations`.

Presence is checked first; value types are only checked once presence passes. Returns `:ok`
or whatever the failing check returned.
"""
def validate_params(params, declarations) do
  with :ok <- validate_params_present(params, declarations) do
    validate_params_types(params, declarations)
  end
end
# Ensures every static fragment referenced in the body has a declaration in the init block.
# Dynamic fragment references are not checked. Only the first offender is reported.
defp validate_fragments_declared(ast) do
  declared =
    ast.init
    |> parse_fragment_declarations()
    |> Map.keys()
    |> MapSet.new()

  # The extractor returns its findings newest-first; reverse to get the order it met them in
  static_refs =
    ast.body
    |> extract_fragments_from_body([])
    |> Enum.reverse()
    |> Enum.flat_map(fn
      {:static, path} -> [path]
      {:dynamic, _path} -> []
    end)

  undeclared =
    Enum.reject(static_refs, fn raw ->
      # References carry their braces; declarations are keyed without them
      bare = raw |> String.trim_leading("{") |> String.trim_trailing("}")
      MapSet.member?(declared, bare)
    end)

  case undeclared do
    [] ->
      :ok

    [first | _] ->
      {:error,
       "unknown_fragment: #{first} referenced but not declared in init block. Inline fragment declarations are no longer supported."}
  end
end
# Collects every fragment reference found in `nodes`, descending into if / case / vary
# branches. Findings are prepended to `acc` as {:static, path} or {:dynamic, path}, so the
# result is newest-first, with duplicates removed.
#
# Branch containers hold *lists* of nodes (or nil for a missing else), so they are spliced
# into one flat node list before recursing; handing the list of lists straight back to this
# function would make every branch fall through the catch-all clause and be ignored.
defp extract_fragments_from_body(nodes, acc) when is_list(nodes) do
  nodes
  |> Enum.reduce(acc, fn
    {:fragment_static, path}, found ->
      [{:static, path} | found]

    {:fragment_dynamic, path}, found ->
      [{:dynamic, path} | found]

    {:if, _var, _condition, then_nodes, elifs, else_node}, found ->
      [then_nodes, else_node | Enum.map(elifs, &elem(&1, 1))]
      |> Enum.flat_map(&List.wrap/1)
      |> extract_fragments_from_body(found)

    {tag, _var, branches}, found when tag in [:case, :vary] ->
      branches
      |> Enum.flat_map(fn
        # Same branch shapes the nesting check accepts
        {_id, _label, branch_nodes} -> List.wrap(branch_nodes)
        {:if, _, _, _, _, _} = nested_if -> [nested_if]
        _other -> []
      end)
      |> extract_fragments_from_body(found)

    _other, found ->
      found
  end)
  |> Enum.uniq()
end
# Verifies that every variable the body uses has a declaration in the init block.
# With no declarations at all the check is skipped. Only the first unknown name is reported.
defp validate_params_declared(ast) do
  declarations = parse_param_declarations(ast.init)

  if map_size(declarations) == 0 do
    :ok
  else
    with {:ok, used_vars} <- extract_vars_from_body(ast.body, []) do
      # Declarations may be keyed with a leading "@"; used names never carry one
      known =
        declarations
        |> Map.keys()
        |> MapSet.new(&String.trim_leading(&1, "@"))

      case Enum.reject(used_vars, &MapSet.member?(known, &1)) do
        [] -> :ok
        [first | _] -> {:error, "unknown_variable: #{first} referenced but not declared"}
      end
    end
  end
end
# Collects the names (without "@") of every variable used in `nodes`:
#   * "@name" occurrences in text nodes
#   * the subject and condition of `if` nodes
#   * the subject of `case` nodes
# and everything found inside if / case / vary branches.
# Returns {:ok, names} with duplicates removed; a non-ok result from a nested walk is
# passed through untouched.
#
# Branch containers hold *lists* of nodes (or nil for a missing else), so they are spliced
# into one flat node list before recursing; handing the list of lists straight back to this
# function would make every branch fall through the catch-all clause and be ignored.
defp extract_vars_from_body(nodes, acc) when is_list(nodes) do
  scan = fn text ->
    ~r/@(\w+)/ |> Regex.scan(text, capture: :all_but_first) |> List.flatten()
  end

  descend = fn branch_lists, seen ->
    branch_lists
    |> Enum.flat_map(&List.wrap/1)
    |> extract_vars_from_body(seen)
  end

  # Same branch shapes the nesting check accepts for case / vary
  branch_nodes_of = fn branches ->
    Enum.map(branches, fn
      {_id, _label, branch_nodes} -> branch_nodes
      {:if, _, _, _, _, _} = nested_if -> [nested_if]
      _other -> []
    end)
  end

  result =
    Enum.reduce_while(nodes, {:ok, acc}, fn node, {:ok, seen} ->
      outcome =
        case node do
          {:text, text} ->
            {:ok, seen ++ scan.(text)}

          {:if, var, condition, then_nodes, elifs, else_node} ->
            guard_vars = [String.trim_leading(var, "@") | scan.(condition || "")]
            branches = [then_nodes, else_node | Enum.map(elifs, &elem(&1, 1))]
            descend.(branches, guard_vars ++ seen)

          {:case, var, branches} ->
            descend.(branch_nodes_of.(branches), [String.trim_leading(var, "@") | seen])

          {:vary, _var, branches} ->
            descend.(branch_nodes_of.(branches), seen)

          _other ->
            {:ok, seen}
        end

      case outcome do
        {:ok, _} = ok -> {:cont, ok}
        error -> {:halt, error}
      end
    end)

  case result do
    {:ok, vars} -> {:ok, Enum.uniq(vars)}
    error -> error
  end
end
@doc """
Builds a map of parameter name => parsed parameter info from an init block.

Anything that is not a map with a `:params` key (including `nil`) yields `%{}`.
"""
def parse_param_declarations(%{params: params}) do
  Map.new(params, fn {name, info} -> {name, parse_param_info(name, info)} end)
end

def parse_param_declarations(_init), do: %{}
# Turns one raw parameter entry into its normalized description: parsed type, the raw spec,
# doc text, lifecycle, typed default, and any enum/list values or int range.
defp parse_param_info(_name, info) when is_map(info) do
  raw_spec = Map.get(info, :type, "str")

  # The raw spec may carry a default after the type, e.g. a type followed by "= value"
  {type_spec, default_text} = split_type_and_default(raw_spec)
  {type, constraints} = parse_type_spec(type_spec)

  phase =
    cond do
      constraints[:lifecycle] -> constraints[:lifecycle]
      # A ranged int is resolved at compile time; a plain int follows the per-type rule
      type == :int and constraints[:range] -> :compile
      true -> lifecycle(type)
    end

  %{
    type: type,
    raw: raw_spec,
    doc: Map.get(info, :doc) || "",
    lifecycle: phase,
    default: cast_param_default(type, default_text),
    values: constraints[:values],
    range: constraints[:range]
  }
end

# Converts the textual default into a value fitting the parameter type. A bool without a
# default becomes false; text that cannot be converted is returned unchanged.
defp cast_param_default(:bool, nil), do: false
defp cast_param_default(:bool, text) when text in ["true", "TRUE", "1"], do: true
defp cast_param_default(:bool, text) when text in ["false", "FALSE", "0"], do: false

defp cast_param_default(:int, text) when is_binary(text) do
  case Integer.parse(text) do
    {number, ""} -> number
    _ -> text
  end
end

defp cast_param_default(:list, text) when is_binary(text) do
  inner = text |> String.trim_leading("[") |> String.trim_trailing("]")

  if inner == "" do
    []
  else
    for piece <- String.split(inner, ","), item = String.trim(piece), item != "", do: item
  end
end

defp cast_param_default(_type, text), do: text
# Splits a raw spec into {type_text, default_text}. `default_text` is nil when the spec has
# no top-level separator; otherwise it is the trimmed remainder with surrounding double and
# then single quotes removed.
defp split_type_and_default(raw) do
  case find_top_level_separator(raw) do
    nil ->
      {String.trim(raw), nil}

    position ->
      {type_text, from_separator} = String.split_at(raw, position)

      default_text =
        from_separator
        # Drop the separator character itself
        |> String.slice(1, String.length(from_separator))
        |> String.trim()
        |> String.trim("\"")
        |> String.trim("'")

      {String.trim(type_text), default_text}
  end
end
# Index of the first "=" that sits outside square brackets, or nil when there is none.
defp find_top_level_separator(raw) do
  raw
  |> String.to_charlist()
  |> do_find_separator(0, 0)
end
# Scans the charlist left to right. `idx` is the position of the head character and `depth`
# counts currently open "[" brackets. Returns the position of the first "=" seen while
# depth is 0, or nil when the input runs out.
defp do_find_separator(chars, idx, depth) do
  case {chars, depth} do
    {[], _} -> nil
    {[?[ | rest], _} -> do_find_separator(rest, idx + 1, depth + 1)
    {[?] | rest], _} -> do_find_separator(rest, idx + 1, depth - 1)
    {[?= | _rest], 0} -> idx
    {[_char | rest], _} -> do_find_separator(rest, idx + 1, depth)
  end
end
# Parses a type spec string into {type, constraints}.
#   "enum[a, b]"         -> {:enum, %{values: [...]}}
#   "int[1..5]"          -> {:int, %{range: [1, 5]}}
#   "int"                -> {:int, %{}}
#   "str" / "string"     -> {:str, %{}}
#   "bool" / "boolean"   -> {:bool, %{}}
#   "list[a, b]"         -> {:list, %{values: [...]}}
# Anything else is treated as a plain string type.
defp parse_type_spec(spec) do
  trimmed = String.trim(spec)
  split_values = fn inner -> inner |> String.split(",") |> Enum.map(&String.trim/1) end

  cond do
    enum = Regex.run(~r/^enum\s*\[(.*)\]$/, trimmed) ->
      [_whole, inner] = enum
      {:enum, %{values: split_values.(inner)}}

    range = Regex.run(~r/^int\s*\[(\d+)\.\.(\d+)\]$/, trimmed) ->
      [_whole, low, high] = range
      {:int, %{range: [String.to_integer(low), String.to_integer(high)]}}

    trimmed == "int" ->
      {:int, %{}}

    trimmed in ["str", "string"] ->
      {:str, %{}}

    trimmed in ["bool", "boolean"] ->
      {:bool, %{}}

    list = Regex.run(~r/^list\s*\[(.*)\]$/, trimmed) ->
      [_whole, inner] = list
      {:list, %{values: split_values.(inner)}}

    true ->
      {:str, %{}}
  end
end
# Default lifecycle per type: strings and ints are :runtime, every other type is :compile.
defp lifecycle(type) when type in [:str, :int], do: :runtime
defp lifecycle(_type), do: :compile
# Confirms that every compile-lifecycle parameter was supplied. A parameter counts as
# supplied when `params` has it under the existing atom, the bare name, or the declared
# name; enum and list parameters are never reported as missing. Only the first missing
# name is reported.
defp validate_params_present(params, declarations) do
  missing =
    declarations
    |> Enum.filter(fn {_name, spec} -> Map.get(spec, :lifecycle) == :compile end)
    |> Enum.map(fn {name, _spec} -> name end)
    |> Enum.reject(fn name ->
      bare = String.trim_leading(name, "@")
      spec = Map.get(declarations, name)

      supplied? =
        Enum.any?([to_existing_or_nil(bare), bare, name], fn key ->
          not is_nil(key) and Map.has_key?(params, key)
        end)

      supplied? or spec.type == :enum or spec.type == :list
    end)

  case missing do
    [] -> :ok
    [first | _] -> {:error, "missing_param: #{first} required but not provided"}
  end
end
# Type-checks the supplied value of every declared parameter.
# Returns :ok, or {:error, message} with every finding joined by "; ".
#
# The value is looked up under the same three keys the presence check accepts: the existing
# atom, the bare name, and the declared name. The first key that is present wins.
# Looking only at the atom key whenever that atom happens to exist would make string-keyed
# params skip validation, and `||` chaining would discard a legitimate `false`.
defp validate_params_types(params, declarations) do
  errors =
    Enum.reduce(declarations, [], fn {name, spec}, acc ->
      clean_name = String.trim_leading(name, "@")

      lookup =
        [to_existing_or_nil(clean_name), clean_name, name]
        |> Enum.reject(&is_nil/1)
        |> Enum.find_value(fn key ->
          case Map.fetch(params, key) do
            {:ok, found} -> {:found, found}
            :error -> nil
          end
        end)

      value =
        case lookup do
          {:found, found} -> found
          nil -> nil
        end

      case validate_value(value, spec, name) do
        :ok -> acc
        {:error, reason} -> [reason | acc]
      end
    end)

  if errors == [] do
    :ok
  else
    {:error, Enum.join(errors, "; ")}
  end
end
# Checks one supplied value against its parameter spec. Returns :ok or {:error, message}.
# A nil value is accepted for both lifecycles; strings accept anything.
defp validate_value(nil, %{lifecycle: phase}, _name) when phase in [:compile, :runtime],
  do: :ok

defp validate_value(_value, %{type: :str}, _name), do: :ok

# Ranged int: integers and fully numeric strings must fall inside [min, max]
defp validate_value(value, %{type: :int, range: [min, max]}, name) do
  cond do
    is_integer(value) ->
      if value < min or value > max do
        {:error, "out_of_range: #{name} value #{value} out of range int[#{min}..#{max}]"}
      else
        :ok
      end

    is_binary(value) ->
      case Integer.parse(value) do
        {number, ""} when number >= min and number <= max -> :ok
        _ -> {:error, "out_of_range: #{name} value #{value} out of range int[#{min}..#{max}]"}
      end

    is_nil(value) ->
      :ok

    true ->
      {:error, "invalid_type: #{name} expected int[#{min}..#{max}], got #{inspect(value)}"}
  end
end

# Plain int: an integer or a string that parses completely as one
defp validate_value(value, %{type: :int}, _name) when is_integer(value), do: :ok

defp validate_value(value, %{type: :int}, name) do
  if is_binary(value) and match?({_, ""}, Integer.parse(value)) do
    :ok
  else
    {:error, "invalid_type: #{name} expected int, got #{inspect(value)}"}
  end
end

defp validate_value(value, %{type: :bool}, _name) when is_boolean(value), do: :ok

defp validate_value(value, %{type: :bool}, name),
  do: {:error, "invalid_type: #{name} expected bool, got #{inspect(value)}"}

# Enum: the value is compared in its string form
defp validate_value(value, %{type: :enum, values: values}, name) do
  text = to_string(value)

  case text in values do
    true -> :ok
    false -> {:error, "invalid_enum: #{name} value #{text} not in enum[#{Enum.join(values, ", ")}]"}
  end
end

# List with allowed values: every item, in string form, must be one of them
defp validate_value(value, %{type: :list, values: enum_values}, name) when is_list(value) do
  case Enum.reject(value, fn item -> to_string(item) in enum_values end) do
    [] ->
      :ok

    rejected ->
      {:error,
       "invalid_enum: #{name} value(s) #{Enum.join(rejected, ", ")} not in list[#{Enum.join(enum_values, ", ")}]"}
  end
end

defp validate_value(value, %{type: :list}, _name) when is_list(value), do: :ok

defp validate_value(value, %{type: :list}, name),
  do: {:error, "invalid_type: #{name} expected list, got #{inspect(value)}"}
# Stores `value` under `key` unless the value is nil, in which case the map is returned as is.
defp maybe_put(map, key, value) do
  if is_nil(value), do: map, else: Map.put(map, key, value)
end
@doc """
Returns the parsed parameter declarations for an init block; `nil` yields `%{}`.

Delegates to `parse_param_declarations/1`.
"""
def parse_param_declarations_for_schema(init_nodes) do
  case init_nodes do
    nil -> %{}
    present -> parse_param_declarations(present)
  end
end
@doc """
Normalizes the `:def` map of an init block.

`:version` and `:major` are kept as integers, or converted from strings that parse completely
as an integer (other strings are kept as they are). Every other value is kept when it is
already a string and passed through `to_string/1` otherwise. Input without a `:def` map
(including `nil`) yields `%{}`.
"""
def parse_def_block(%{def: def_map}) when is_map(def_map) do
  for {key, value} <- def_map, into: %{} do
    normalized =
      cond do
        key in [:version, :major] and is_integer(value) ->
          value

        key in [:version, :major] and is_binary(value) ->
          case Integer.parse(value) do
            {number, ""} -> number
            _ -> value
          end

        is_binary(value) ->
          value

        true ->
          to_string(value)
      end

    {key, normalized}
  end
end

def parse_def_block(_init), do: %{}
@doc """
Returns the `:docs` entry of an init block, or `nil` when there is none.
"""
def parse_docs_block(init) when is_map(init), do: Map.get(init, :docs)
def parse_docs_block(_init), do: nil
@doc """
Builds a map of fragment name => fragment spec from the `:fragments` map of an init block.

Names are stored without surrounding braces. Each spec has `:type` and `:doc`, plus `:from`,
`:match`, `:matchRe`, `:limit` and `:order` when they are set, merged with any rules parsed
from the type text. A fragment with a non-empty `from` is `"static"`; one without a type is
`"dynamic"`. Input without a `:fragments` map (including `nil`) yields `%{}`.
"""
def parse_fragment_declarations(%{fragments: fragments}) when is_map(fragments) do
  Map.new(fragments, fn {name, info} ->
    # Entries that are not maps behave like an empty entry
    details = if is_map(info), do: info, else: %{}

    bare_name =
      name
      |> to_string()
      |> String.trim_leading("{")
      |> String.trim_trailing("}")

    {type, from, rules} =
      details
      |> Map.get(:type, "dynamic")
      |> parse_fragment_type_and_rules()

    kind =
      cond do
        from not in [nil, ""] -> "static"
        type in ["", nil] -> "dynamic"
        true -> type
      end

    optional = [
      from: from,
      match: details[:match],
      matchRe: details[:matchRe] || details[:matchre],
      limit: details[:limit],
      order: details[:order]
    ]

    spec =
      optional
      |> Enum.reduce(%{type: kind, doc: Map.get(details, :doc, "")}, fn {key, value}, acc ->
        maybe_put(acc, key, value)
      end)
      |> Map.merge(rules)

    {bare_name, spec}
  end)
end

def parse_fragment_declarations(_init), do: %{}
# Splits a raw fragment type text into {type, from, rules}.
# The first line holds the type, optionally followed by "from: <path>" (`from` is nil when
# absent). Each following "rule: value" line becomes an entry in `rules`, but only when the
# rule name is already a known atom; other lines are ignored.
defp parse_fragment_type_and_rules(raw) do
  [header | rule_lines] =
    raw
    |> String.split(["\n", "\r\n"])
    |> Enum.map(&String.trim/1)

  {type, from} =
    case String.split(header, "from:", parts: 2) do
      [type_text, from_text] -> {String.trim(type_text), String.trim(from_text)}
      [type_text] -> {String.trim(type_text), nil}
    end

  rules =
    for line <- rule_lines,
        [rule, val] <- [String.split(line, ":", parts: 2)],
        rule_name <- [to_existing_or_nil(String.trim(rule))],
        not is_nil(rule_name),
        into: %{} do
      {rule_name, String.trim(val)}
    end

  {type, from, rules}
end
# Converts a non-empty string to an atom only if that atom already exists; returns nil for
# unknown names, the empty string, and non-binary input. Never creates a new atom.
defp to_existing_or_nil(binary) when is_binary(binary) and binary != "" do
  String.to_existing_atom(binary)
rescue
  ArgumentError -> nil
end

defp to_existing_or_nil(_other), do: nil
end