lib/evaluator.ex

defmodule Dsqlex.Evaluator do
  @moduledoc """
  Evaluates a parsed AST against a context (map of field names to values).

  Example:
      context = %{
        "x" => Decimal.new("100.00"),
        "y" => Decimal.new("20.00"),
        "category" => "B",
        "z" => Decimal.new("5.00")
      }

      {:ok, ast} = Dsqlex.Parser.parse(tokens)
      {:ok, result} = Dsqlex.Evaluator.evaluate(ast, context)
  """

  def evaluate(ast, context, opts \\ []) when is_map(context) do
    try do
      {:ok, do_eval(ast, context, opts)}
    rescue
      e -> {:error, Exception.message(e)}
    end
  end

  # ============================================================
  # (optional) SELECT wrapper
  # ============================================================
  defp do_eval({:select, expr}, context, opts), do: do_eval(expr, context, opts)

  # ============================================================
  # Literals
  # ============================================================
  defp do_eval({:number, n}, _context, _opts), do: Decimal.new(n)
  defp do_eval({:string, s}, _context, _opts), do: s
  defp do_eval({:boolean, b}, _context, _opts), do: b
  defp do_eval({:null}, _context, _opts), do: nil

  # ============================================================
  # Identifier - lookup in context
  # ============================================================
  defp do_eval({:identifier, name}, context, opts) do
    case Map.fetch(context, name) do
      {:ok, value} ->
        value

      :error ->
        # Try dot-path access for nested maps (e.g. "config.pricing.margin_rate")
        if String.contains?(name, ".") do
          resolve_dot_path(name, context)
        else
          resolver = Keyword.get(opts, :resolver)

          if resolver do
            visited = Keyword.get(opts, :visited, MapSet.new())

            if MapSet.member?(visited, name) do
              raise "Circular reference detected: #{name}"
            end

            case resolver.(name, visited) do
              {:ok, value} -> value
              {:error, reason} -> raise reason
            end
          else
            raise "Unknown field: #{name}"
          end
        end
    end
  end

  # ============================================================
  # Binary operations - arithmetic
  # ============================================================
  defp do_eval({:binary_op, :plus, left, right}, context, opts) do
    Decimal.add(to_decimal(do_eval(left, context, opts)), to_decimal(do_eval(right, context, opts)))
  end

  defp do_eval({:binary_op, :minus, left, right}, context, opts) do
    Decimal.sub(to_decimal(do_eval(left, context, opts)), to_decimal(do_eval(right, context, opts)))
  end

  defp do_eval({:binary_op, :multiply, left, right}, context, opts) do
    Decimal.mult(to_decimal(do_eval(left, context, opts)), to_decimal(do_eval(right, context, opts)))
  end

  defp do_eval({:binary_op, :divide, left, right}, context, opts) do
    Decimal.div(to_decimal(do_eval(left, context, opts)), to_decimal(do_eval(right, context, opts)))
  end

  # ============================================================
  # Binary operations - comparison
  # ============================================================
  defp do_eval({:binary_op, :eq, left, right}, context, opts) do
    compare_values(do_eval(left, context, opts), do_eval(right, context, opts)) == :eq
  end

  defp do_eval({:binary_op, :neq, left, right}, context, opts) do
    compare_values(do_eval(left, context, opts), do_eval(right, context, opts)) != :eq
  end

  defp do_eval({:binary_op, :lt, left, right}, context, opts) do
    compare_values(do_eval(left, context, opts), do_eval(right, context, opts)) == :lt
  end

  defp do_eval({:binary_op, :gt, left, right}, context, opts) do
    compare_values(do_eval(left, context, opts), do_eval(right, context, opts)) == :gt
  end

  defp do_eval({:binary_op, :lte, left, right}, context, opts) do
    compare_values(do_eval(left, context, opts), do_eval(right, context, opts)) in [:lt, :eq]
  end

  defp do_eval({:binary_op, :gte, left, right}, context, opts) do
    compare_values(do_eval(left, context, opts), do_eval(right, context, opts)) in [:gt, :eq]
  end

  # ============================================================
  # Binary operations - logical
  # ============================================================
  defp do_eval({:binary_op, :and, left, right}, context, opts) do
    do_eval(left, context, opts) && do_eval(right, context, opts)
  end

  defp do_eval({:binary_op, :or, left, right}, context, opts) do
    do_eval(left, context, opts) || do_eval(right, context, opts)
  end

  # ============================================================
  # IN / NOT IN
  # ============================================================
  defp do_eval({:in, expr, items}, context, opts) do
    value = do_eval(expr, context, opts)
    Enum.any?(items, fn item -> compare_values(value, do_eval(item, context, opts)) == :eq end)
  end

  defp do_eval({:not_in, expr, items}, context, opts) do
    value = do_eval(expr, context, opts)
    not Enum.any?(items, fn item -> compare_values(value, do_eval(item, context, opts)) == :eq end)
  end

  # ============================================================
  # LIKE / NOT LIKE (case-insensitive, MySQL/MariaDB default)
  # ============================================================
  defp do_eval({:like, expr, pattern_expr}, context, opts) do
    value = do_eval(expr, context, opts) |> to_string()
    pattern = do_eval(pattern_expr, context, opts) |> to_string()
    like_match?(value, pattern)
  end

  defp do_eval({:not_like, expr, pattern_expr}, context, opts) do
    value = do_eval(expr, context, opts) |> to_string()
    pattern = do_eval(pattern_expr, context, opts) |> to_string()
    not like_match?(value, pattern)
  end

  # ============================================================
  # CASE expression
  # ============================================================
  defp do_eval({:case_expr, when_clauses, else_clause}, context, opts) do
    eval_when_clauses(when_clauses, else_clause, context, opts)
  end

  # ============================================================
  # Function calls
  # ============================================================
  defp do_eval({:call, :round, [value, precision]}, context, opts) do
    Decimal.round(to_decimal(do_eval(value, context, opts)), do_eval(precision, context, opts) |> Decimal.to_integer())
  end

  defp do_eval({:call, :coalesce, args}, context, opts) do
    Enum.find_value(args, fn arg ->
      result = do_eval(arg, context, opts)
      if result != nil, do: result, else: nil
    end)
  end

  defp do_eval({:call, :upper, [value]}, context, opts) do
    do_eval(value, context, opts) |> to_string() |> String.upcase()
  end

  defp do_eval({:call, :lower, [value]}, context, opts) do
    do_eval(value, context, opts) |> to_string() |> String.downcase()
  end

  defp do_eval({:call, :abs, [value]}, context, opts) do
    Decimal.abs(to_decimal(do_eval(value, context, opts)))
  end

  defp do_eval({:call, :concat, args}, context, opts) do
    args
    |> Enum.map(&do_eval(&1, context, opts))
    |> Enum.map(&to_string/1)
    |> Enum.join()
  end

  # EVENT(type, subtype) — evaluate referenced formula with current context
  defp do_eval({:call, :event, [{:identifier, type}, {:identifier, subtype}]}, context, opts) do
    resolve_event(type, subtype, context, opts)
  end

  # EVENT(type, subtype, context_source) — evaluate referenced formula with sub-entity context
  # If context_source resolves to a list, evaluates per item and sums the results
  defp do_eval({:call, :event, [{:identifier, type}, {:identifier, subtype}, {:identifier, source}]}, context, opts) do
    case Map.fetch(context, source) do
      {:ok, sub_context} when is_list(sub_context) ->
        sub_context
        |> Enum.map(fn item -> resolve_event(type, subtype, item, opts) end)
        |> Enum.reduce(Decimal.new(0), &Decimal.add/2)

      {:ok, sub_context} when is_map(sub_context) ->
        resolve_event(type, subtype, sub_context, opts)

      {:ok, _} ->
        raise "EVENT context source '#{source}' must be a map or list of maps"

      :error ->
        raise "EVENT context source '#{source}' not found in context"
    end
  end

  defp do_eval({:call, :event, _args}, _context, _opts) do
    raise "EVENT requires 2 or 3 arguments: EVENT(type, subtype) or EVENT(type, subtype, context_source)"
  end

  defp resolve_event(type, subtype, eval_context, opts) do
    event_resolver = Keyword.get(opts, :event_resolver)

    unless event_resolver do
      raise "EVENT() calls require an :event_resolver option"
    end

    event_key = "#{type}.#{subtype}"
    visited = Keyword.get(opts, :visited, MapSet.new())

    if MapSet.member?(visited, event_key) do
      raise "Circular reference detected: #{event_key}"
    end

    new_visited = MapSet.put(visited, event_key)
    new_opts = Keyword.put(opts, :visited, new_visited)

    case event_resolver.(type, subtype, eval_context, new_opts) do
      {:ok, result} -> result
      {:error, reason} -> raise reason
    end
  end

  # ============================================================
  # CASE/WHEN helpers
  # ============================================================
  defp eval_when_clauses([], else_clause, context, opts) do
    if else_clause do
      do_eval(else_clause, context, opts)
    else
      nil
    end
  end

  defp eval_when_clauses([{:when, condition, result} | rest], else_clause, context, opts) do
    if do_eval(condition, context, opts) do
      do_eval(result, context, opts)
    else
      eval_when_clauses(rest, else_clause, context, opts)
    end
  end

  # ============================================================
  # Helpers
  # ============================================================
  defp to_decimal(%Decimal{} = d), do: d
  defp to_decimal(n) when is_integer(n), do: Decimal.new(n)
  defp to_decimal(n) when is_float(n), do: Decimal.from_float(n)
  defp to_decimal(s) when is_binary(s), do: Decimal.new(s)

  defp compare_values(nil, nil), do: :eq
  defp compare_values(nil, _), do: :neq
  defp compare_values(_, nil), do: :neq
  defp compare_values(%Decimal{} = a, %Decimal{} = b), do: Decimal.compare(a, b)
  defp compare_values(%Decimal{} = a, b), do: Decimal.compare(a, to_decimal(b))
  defp compare_values(a, %Decimal{} = b), do: Decimal.compare(to_decimal(a), b)
  defp compare_values(a, b) when is_binary(a) and is_binary(b) do
    cond do
      a == b -> :eq
      a < b -> :lt
      a > b -> :gt
    end
  end
  defp compare_values(a, b) when a == b, do: :eq
  defp compare_values(a, b) when a < b, do: :lt
  defp compare_values(a, b) when a > b, do: :gt

  # Convert a SQL LIKE pattern to an Elixir regex and test (case-insensitive)
  # % = any sequence of characters, _ = any single character
  defp like_match?(value, pattern) do
    # 1. Replace LIKE wildcards with placeholders before escaping
    # 2. Escape remaining regex-special chars
    # 3. Replace placeholders with regex equivalents
    regex_str =
      pattern
      |> String.replace("%", "\x00PCT\x00")
      |> String.replace("_", "\x00UND\x00")
      |> Regex.escape()
      |> String.replace("\x00PCT\x00", ".*")
      |> String.replace("\x00UND\x00", ".")

    {:ok, regex} = Regex.compile("^#{regex_str}$", "i")
    Regex.match?(regex, value)
  end

  defp resolve_dot_path(path, context) do
    parts = String.split(path, ".")
    resolve_dot_parts(parts, context, path)
  end

  defp resolve_dot_parts([], acc, _path), do: acc

  defp resolve_dot_parts([key | rest], acc, path) when is_map(acc) do
    case Map.fetch(acc, key) do
      {:ok, value} -> resolve_dot_parts(rest, value, path)
      :error -> raise "Unknown field: #{path} (failed at '#{key}')"
    end
  end

  defp resolve_dot_parts([_key | _] = remaining, acc, path) when is_list(acc) do
    results = Enum.map(acc, fn item -> resolve_dot_parts(remaining, item, path) end)

    if Enum.all?(results, &decimal_like?/1) do
      results
      |> Enum.map(&to_decimal/1)
      |> Enum.reduce(Decimal.new(0), &Decimal.add/2)
    else
      results
    end
  end

  defp resolve_dot_parts([key | _rest], _acc, path) do
    raise "Cannot access '#{key}' on non-map value in path '#{path}'"
  end

  defp decimal_like?(%Decimal{}), do: true
  defp decimal_like?(n) when is_number(n), do: true
  defp decimal_like?(s) when is_binary(s) do
    case Decimal.parse(s) do
      {_, ""} -> true
      _ -> false
    end
  end
  defp decimal_like?(_), do: false
end