lib/sat_solver.ex

defmodule Ash.SatSolver do
  @moduledoc """
  Tools for working with the satsolver that drives filter subset checking (for authorization)
  """

  alias Ash.Filter
  alias Ash.Query.{BooleanExpression, Not, Ref}

  @dialyzer {:nowarn_function, overlap?: 2}

  defmacro b(statement) do
    value =
      Macro.prewalk(
        statement,
        fn
          {:and, _, [left, right]} ->
            quote do
              {:and, unquote(left), unquote(right)}
            end

          {:or, _, [left, right]} ->
            quote do
              {:or, unquote(left), unquote(right)}
            end

          {:not, _, [value]} ->
            quote do
              {:not, unquote(value)}
            end

          other ->
            other
        end
      )

    quote do
      unquote(value)
      |> Ash.SatSolver.balance()
    end
  end

  def balance({op, left, right}) do
    left = balance(left)
    right = balance(right)
    [left, right] = Enum.sort([left, right])

    {op, left, right}
  end

  def balance({:not, {:not, right}}) do
    balance(right)
  end

  def balance({:not, statement}) do
    {:not, balance(statement)}
  end

  def balance(other), do: other

  def strict_filter_subset(filter, candidate) do
    case {filter, candidate} do
      {%{expression: nil}, %{expression: nil}} ->
        true

      {%{expression: nil}, _candidate_expr} ->
        true

      {_filter_expr, %{expression: nil}} ->
        false

      {filter, candidate} ->
        do_strict_filter_subset(filter, candidate)
    end
  end

  defp do_strict_filter_subset(filter, candidate) do
    expr = BooleanExpression.new(:and, filter.expression, candidate.expression)

    case transform_and_solve(
           filter.resource,
           expr
         ) do
      {:error, :unsatisfiable} ->
        false

      {:ok, _scenario} ->
        expr = BooleanExpression.new(:and, Not.new(filter.expression), candidate.expression)

        case transform_and_solve(
               filter.resource,
               expr
             ) do
          {:error, :unsatisfiable} ->
            true

          {:ok, _scenario} ->
            :maybe
        end
    end
  end

  defp filter_to_expr(nil), do: nil
  defp filter_to_expr(false), do: false
  defp filter_to_expr(true), do: true
  defp filter_to_expr(%Filter{expression: expression}), do: filter_to_expr(expression)
  defp filter_to_expr(%{__predicate__?: _} = op_or_func), do: op_or_func
  defp filter_to_expr(%Not{expression: expression}), do: b(not filter_to_expr(expression))

  defp filter_to_expr(%BooleanExpression{op: op, left: left, right: right}) do
    {op, filter_to_expr(left), filter_to_expr(right)}
  end

  defp filter_to_expr(expr) do
    raise ArgumentError, message: "Invalid filter expression #{inspect(expr)}"
  end

  def transform(resource, expression) do
    expression
    |> consolidate_relationships(resource)
    |> upgrade_related_filters_to_join_keys(resource)
    |> build_expr_with_predicate_information()
  end

  def transform_and_solve(resource, expression) do
    resource
    |> transform(expression)
    |> solve_expression()
  end

  defp upgrade_related_filters_to_join_keys(
         %BooleanExpression{op: op, left: left, right: right},
         resource
       ) do
    BooleanExpression.new(
      op,
      upgrade_related_filters_to_join_keys(left, resource),
      upgrade_related_filters_to_join_keys(right, resource)
    )
  end

  defp upgrade_related_filters_to_join_keys(%Not{expression: expression}, resource) do
    Not.new(upgrade_related_filters_to_join_keys(expression, resource))
  end

  defp upgrade_related_filters_to_join_keys(
         %{__operator__?: true, left: left, right: right} = op,
         resource
       ) do
    %{op | left: upgrade_ref(left, resource), right: upgrade_ref(right, resource)}
  end

  defp upgrade_related_filters_to_join_keys(
         %{__function__?: true, arguments: arguments} = function,
         resource
       ) do
    %{function | arguments: Enum.map(arguments, &upgrade_ref(&1, resource))}
  end

  defp upgrade_related_filters_to_join_keys(expr, _), do: expr

  defp upgrade_ref({key, ref}, resource) when is_atom(key) do
    {key, upgrade_ref(ref, resource)}
  end

  defp upgrade_ref(
         %Ash.Query.Ref{attribute: attribute, relationship_path: path} = ref,
         resource
       )
       when path != [] do
    with relationship when not is_nil(relationship) <-
           Ash.Resource.Info.relationship(resource, path),
         true <- attribute.name == relationship.destination_field,
         new_attribute when not is_nil(new_attribute) <-
           Ash.Resource.Info.attribute(relationship.source, relationship.source_field) do
      %{
        ref
        | relationship_path: :lists.droplast(path),
          attribute: new_attribute,
          resource: resource
      }
    else
      _ ->
        ref
    end
  end

  defp upgrade_ref(other, _), do: other

  defp consolidate_relationships(expression, resource) do
    {replacements, _all_relationship_paths} =
      expression
      |> Filter.relationship_paths()
      |> Enum.reduce({%{}, []}, fn path, {replacements, kept_paths} ->
        case find_synonymous_relationship_path(resource, kept_paths, path) do
          nil ->
            {replacements, [path | kept_paths]}

          synonymous_path ->
            Map.put(replacements, path, synonymous_path)
        end
      end)

    do_consolidate_relationships(expression, replacements)
  end

  defp do_consolidate_relationships(
         %BooleanExpression{op: op, left: left, right: right},
         replacements
       ) do
    BooleanExpression.new(
      op,
      do_consolidate_relationships(left, replacements),
      do_consolidate_relationships(right, replacements)
    )
  end

  defp do_consolidate_relationships(%Not{expression: expression}, replacements) do
    Not.new(do_consolidate_relationships(expression, replacements))
  end

  defp do_consolidate_relationships(%Ash.Query.Ref{relationship_path: path} = ref, replacements)
       when path != [] do
    case Map.fetch(replacements, path) do
      {:ok, replacement} when not is_nil(replacement) -> %{ref | relationship_path: replacement}
      :error -> ref
    end
  end

  defp do_consolidate_relationships(%{__function__?: true, arguments: args} = func, replacements) do
    %{func | arguments: Enum.map(args, &do_consolidate_relationships(&1, replacements))}
  end

  defp do_consolidate_relationships(
         %{__operator__?: true, left: left, right: right} = op,
         replacements
       ) do
    %{
      op
      | left: do_consolidate_relationships(left, replacements),
        right: do_consolidate_relationships(right, replacements)
    }
  end

  defp do_consolidate_relationships(other, _), do: other

  defp find_synonymous_relationship_path(resource, paths, path) do
    Enum.find_value(paths, fn candidate_path ->
      if synonymous_relationship_paths?(resource, candidate_path, path) do
        candidate_path
      else
        false
      end
    end)
  end

  # def synonymous_relationship_paths?(_, [], []), do: true

  # def synonymous_relationship_paths?(_resource, candidate_path, path)
  #     when length(candidate_path) != length(path),
  #     do: false

  # def synonymous_relationship_paths?(resource, [candidate_first | candidate_rest], [first | rest])
  #     when first == candidate_first do
  #   synonymous_relationship_paths?(
  #     Ash.Resource.Info.relationship(resource, candidate_first).destination,
  #     candidate_rest,
  #     rest
  #   )
  # end

  def synonymous_relationship_paths?(
        left_resource,
        candidate,
        search,
        right_resource \\ nil
      )

  def synonymous_relationship_paths?(_, [], [], _), do: true
  def synonymous_relationship_paths?(_, [], _, _), do: false
  def synonymous_relationship_paths?(_, _, [], _), do: false

  def synonymous_relationship_paths?(
        left_resource,
        [candidate_first | candidate_rest],
        [first | rest],
        right_resource
      ) do
    right_resource = right_resource || left_resource
    relationship = Ash.Resource.Info.relationship(left_resource, first)
    candidate_relationship = Ash.Resource.Info.relationship(right_resource, candidate_first)

    cond do
      !relationship || !candidate_relationship ->
        false

      relationship.type == :many_to_many && candidate_relationship.type == :has_many ->
        synonymous_relationship_paths?(left_resource, [relationship.join_relationship], [
          candidate_first
        ]) &&
          synonymous_relationship_paths?(
            left_resource,
            candidate_rest,
            rest,
            right_resource
          )

      relationship.type == :has_many && candidate_relationship.type == :many_to_many ->
        synonymous_relationship_paths?(left_resource, [relationship.name], [
          candidate_relationship.join_relationship
        ]) &&
          synonymous_relationship_paths?(
            left_resource,
            candidate_rest,
            rest,
            right_resource
          )

      true ->
        comparison_keys = [
          :source_field,
          :destination_field,
          :source_field_on_join_table,
          :destination_field_on_join_table,
          :destination_field,
          :destination
        ]

        Map.take(relationship, comparison_keys) ==
          Map.take(candidate_relationship, comparison_keys) and
          synonymous_relationship_paths?(relationship.destination, candidate_rest, rest)
    end
  end

  defp build_expr_with_predicate_information(expression) do
    expression = fully_simplify(expression)

    all_predicates =
      expression
      |> Filter.list_predicates()
      |> Enum.uniq()

    comparison_expressions =
      all_predicates
      |> Enum.filter(fn %module{} ->
        :erlang.function_exported(module, :compare, 2)
      end)
      |> Enum.reduce([], fn predicate, new_expressions ->
        all_predicates
        |> Enum.reject(&Kernel.==(&1, predicate))
        |> Enum.filter(&shares_ref?(&1, predicate))
        |> Enum.reduce(new_expressions, fn other_predicate, new_expressions ->
          # With predicate as a and other_predicate as b
          case Ash.Filter.Predicate.compare(predicate, other_predicate) do
            :right_includes_left ->
              # b || !a

              [b(other_predicate or not predicate) | new_expressions]

            :left_includes_right ->
              # a || ! b
              [b(predicate or not other_predicate) | new_expressions]

            :mutually_inclusive ->
              # (a && b) || (! a && ! b)
              [
                b((predicate and other_predicate) or (not predicate and not other_predicate))
                | new_expressions
              ]

            :mutually_exclusive ->
              [b(not (other_predicate and predicate)) | new_expressions]

            :mutually_exclusive_and_collectively_exhaustive ->
              [
                b(
                  not (other_predicate and predicate) and
                    not (not other_predicate and not predicate)
                )
                | new_expressions
              ]

            _other ->
              # If we can't tell, we assume that both could be true
              new_expressions
          end
        end)
      end)
      |> Enum.uniq()

    expression = filter_to_expr(expression)

    expression_with_comparisons =
      Enum.reduce(comparison_expressions, expression, fn comparison_expression, expression ->
        b(comparison_expression and expression)
      end)

    all_predicates
    |> Enum.map(& &1.__struct__)
    |> Enum.uniq()
    |> Enum.flat_map(fn struct ->
      if :erlang.function_exported(struct, :bulk_compare, 1) do
        struct.bulk_compare(all_predicates)
      else
        []
      end
    end)
    |> Enum.reduce(expression_with_comparisons, fn comparison_expression, expression ->
      b(comparison_expression and expression)
    end)
  end

  def fully_simplify(expression) do
    expression
    |> do_fully_simplify()
    |> lift_equals_out_of_in()
    |> do_fully_simplify()
  end

  defp do_fully_simplify(expression) do
    expression
    |> simplify()
    |> case do
      ^expression ->
        expression

      simplified ->
        fully_simplify(simplified)
    end
  end

  def lift_equals_out_of_in(expression) do
    case find_non_equal_overlap(expression) do
      nil ->
        expression

      non_equal_overlap ->
        expression
        |> split_in_expressions(non_equal_overlap)
        |> lift_equals_out_of_in()
    end
  end

  def find_non_equal_overlap(expression) do
    Ash.Filter.find(expression, fn sub_expr ->
      Ash.Filter.find(expression, fn sub_expr2 ->
        # if has_call_or_expression?(sub_expr) || has_call_or_expression?(sub_expr2) do
        #   false
        # else
        overlap?(sub_expr, sub_expr2)
        # end
      end)
    end)
  end

  defp new_in(base, right) do
    case MapSet.size(right) do
      1 ->
        %Ash.Query.Operator.Eq{left: base.left, right: Enum.at(right, 0)}

      _ ->
        %Ash.Query.Operator.In{left: base.left, right: right}
    end
  end

  def split_in_expressions(
        %Ash.Query.Operator.In{right: right} = sub_expr,
        %Ash.Query.Operator.Eq{right: value} = non_equal_overlap
      ) do
    if overlap?(non_equal_overlap, sub_expr) do
      Ash.Query.BooleanExpression.new(
        :or,
        new_in(sub_expr, MapSet.delete(right, value)),
        non_equal_overlap
      )
    else
      sub_expr
    end
  end

  def split_in_expressions(
        %Ash.Query.Operator.In{} = sub_expr,
        %Ash.Query.Operator.In{right: right} = non_equal_overlap
      ) do
    if overlap?(sub_expr, non_equal_overlap) do
      diff = MapSet.difference(sub_expr.right, right)

      if MapSet.size(diff) == 0 do
        Enum.reduce(sub_expr.right, nil, fn var, acc ->
          BooleanExpression.new(:or, %Ash.Query.Operator.Eq{left: sub_expr.left, right: var}, acc)
        end)
      else
        new_right = new_in(sub_expr, MapSet.intersection(sub_expr.right, right))

        Ash.Query.BooleanExpression.new(
          :or,
          new_in(sub_expr, diff),
          new_right
        )
      end
    else
      sub_expr
    end
  end

  def split_in_expressions(nil, _), do: nil

  def split_in_expressions(%Ash.Filter{expression: expression} = filter, non_equal_overlap),
    do: %{filter | expression: split_in_expressions(expression, non_equal_overlap)}

  def split_in_expressions(%Not{expression: expression} = not_expr, non_equal_overlap),
    do: %{not_expr | expression: split_in_expressions(expression, non_equal_overlap)}

  def split_in_expressions(
        %BooleanExpression{left: left, right: right} = expr,
        non_equal_overlap
      ),
      do: %{
        expr
        | left: split_in_expressions(left, non_equal_overlap),
          right: split_in_expressions(right, non_equal_overlap)
      }

  def split_in_expressions(other, _), do: other

  def overlap?(
        %Ash.Query.Operator.In{left: left, right: %{__struct__: MapSet} = left_right},
        %Ash.Query.Operator.In{left: left, right: %{__struct__: MapSet} = right_right}
      ) do
    if MapSet.equal?(left_right, right_right) do
      false
    else
      overlap? =
        left_right
        |> MapSet.intersection(right_right)
        |> MapSet.size()
        |> Kernel.>(0)

      if overlap? do
        true
      else
        false
      end
    end
  end

  def overlap?(_, %Ash.Query.Operator.Eq{right: %Ref{}}),
    do: false

  def overlap?(%Ash.Query.Operator.Eq{right: %Ref{}}, _),
    do: false

  def overlap?(
        %Ash.Query.Operator.Eq{left: left, right: left_right},
        %Ash.Query.Operator.In{left: left, right: %{__struct__: MapSet} = right_right}
      ) do
    MapSet.member?(right_right, left_right)
  end

  def overlap?(_left, _right) do
    false
  end

  def mutually_exclusive(predicates, acc \\ [])
  def mutually_exclusive([], acc), do: acc

  def mutually_exclusive([predicate | rest], acc) do
    new_acc =
      Enum.reduce(rest, acc, fn other_predicate, acc ->
        [b(not (predicate and other_predicate)) | acc]
      end)

    mutually_exclusive(rest, new_acc)
  end

  def mutually_exclusive_and_collectively_exhaustive([]), do: []

  def mutually_exclusive_and_collectively_exhaustive([_]), do: []

  def mutually_exclusive_and_collectively_exhaustive(predicates) do
    mutually_exclusive(predicates) ++
      Enum.flat_map(predicates, fn predicate ->
        other_predicates = Enum.reject(predicates, &(&1 == predicate))

        other_predicates_union =
          Enum.reduce(other_predicates, nil, fn other_predicate, expr ->
            if expr do
              b(expr or other_predicate)
            else
              other_predicate
            end
          end)

        b(
          not (predicate and other_predicates_union) and
            not (not predicate and not other_predicates_union)
        )
      end)
  end

  def left_excludes_right(left, right) do
    b(not (left and right))
  end

  def right_excludes_left(left, right) do
    b(not (right and left))
  end

  def mutually_inclusive(predicates, acc \\ [])
  def mutually_inclusive([], acc), do: acc

  def mutually_inclusive([predicate | rest], acc) do
    new_acc =
      Enum.reduce(rest, acc, fn other_predicate, acc ->
        [b((predicate and other_predicate) or (not predicate and not other_predicate)) | acc]
      end)

    mutually_exclusive(rest, new_acc)
  end

  def right_implies_left(left, right) do
    b(not (right and not left))
  end

  def left_implies_right(left, right) do
    b(not (left and not right))
  end

  defp shares_ref?(left, right) do
    any_refs_in_common?(refs(left), refs(right))
  end

  defp any_refs_in_common?(left_refs, right_refs) do
    Enum.any?(left_refs, &(&1 in right_refs))
  end

  defp refs(%{__operator__?: true, left: left, right: right}) do
    Enum.filter([left, right], &match?(%Ref{}, &1))
  end

  defp refs(%{__function__?: true, arguments: arguments}) do
    Enum.filter(arguments, &match?(%Ref{}, &1))
  end

  defp refs(_), do: []

  defp simplify(%BooleanExpression{op: op, left: left, right: right}) do
    BooleanExpression.new(op, simplify(left), simplify(right))
  end

  defp simplify(%Not{expression: expression}) do
    Not.new(simplify(expression))
  end

  defp simplify(%mod{__predicate__?: true} = predicate) do
    if :erlang.function_exported(mod, :simplify, 1) do
      predicate
      |> mod.simplify()
      |> Kernel.||(predicate)
    else
      predicate
    end
  end

  defp simplify(other), do: other

  def solve_expression(expression) do
    expression_with_constants = b(true and not false and expression)

    {bindings, expression} = extract_bindings(expression_with_constants)

    expression
    |> to_conjunctive_normal_form()
    |> lift_clauses()
    |> negations_to_negative_numbers()
    |> Enum.uniq()
    |> Picosat.solve()
    |> solutions_to_predicate_values(bindings)
  end

  defp solutions_to_predicate_values({:ok, solution}, bindings) do
    scenario =
      Enum.reduce(solution, %{true: [], false: []}, fn var, state ->
        fact = Map.get(bindings, abs(var))

        Map.put(state, fact, var > 0)
      end)

    {:ok, scenario}
  end

  defp solutions_to_predicate_values({:error, error}, _), do: {:error, error}

  defp extract_bindings(expr, bindings \\ %{current: 1})

  defp extract_bindings({operator, left, right}, bindings) do
    {bindings, left_extracted} = extract_bindings(left, bindings)
    {bindings, right_extracted} = extract_bindings(right, bindings)

    {bindings, {operator, left_extracted, right_extracted}}
  end

  defp extract_bindings({:not, value}, bindings) do
    {bindings, extracted} = extract_bindings(value, bindings)

    {bindings, b(not extracted)}
  end

  defp extract_bindings(value, %{current: current} = bindings) do
    current_binding =
      Enum.find(bindings, fn {key, binding_value} ->
        key != :current && binding_value == value
      end)

    case current_binding do
      nil ->
        new_bindings =
          bindings
          |> Map.put(:current, current + 1)
          |> Map.put(current, value)

        {new_bindings, current}

      {binding, _} ->
        {bindings, binding}
    end
  end

  # A helper function for formatting to the same output we'd give to picosat
  @doc false
  def to_picosat(clauses, variable_count) do
    clause_count = Enum.count(clauses)

    formatted_input =
      Enum.map_join(clauses, "\n", fn clause ->
        format_clause(clause) <> " 0"
      end)

    "p cnf #{variable_count} #{clause_count}\n" <> formatted_input
  end

  defp negations_to_negative_numbers(clauses) do
    Enum.map(
      clauses,
      fn
        {:not, var} when is_integer(var) ->
          [negate_var(var)]

        var when is_integer(var) ->
          [var]

        clause ->
          Enum.map(clause, fn
            {:not, var} -> negate_var(var)
            var -> var
          end)
      end
    )
  end

  defp negate_var(var, multiplier \\ -1)

  defp negate_var({:not, value}, multiplier) do
    negate_var(value, multiplier * -1)
  end

  defp negate_var(value, multiplier), do: value * multiplier

  defp format_clause(clause) do
    Enum.map_join(clause, " ", fn
      {:not, var} -> "-#{var}"
      var -> "#{var}"
    end)
  end

  defp lift_clauses({:and, left, right}) do
    lift_clauses(left) ++ lift_clauses(right)
  end

  defp lift_clauses({:or, left, right}) do
    [lift_or_clauses(left) ++ lift_or_clauses(right)]
  end

  defp lift_clauses(value), do: [[value]]

  defp lift_or_clauses({:or, left, right}) do
    lift_or_clauses(left) ++ lift_or_clauses(right)
  end

  defp lift_or_clauses(value), do: [value]

  defp to_conjunctive_normal_form(expression) do
    expression
    |> demorgans_law()
    |> distributive_law()
  end

  defp distributive_law(expression) do
    distributive_law_applied = apply_distributive_law(expression)

    if expression == distributive_law_applied do
      expression
    else
      distributive_law(distributive_law_applied)
    end
  end

  defp apply_distributive_law({:or, left, {:and, right1, right2}}) do
    left_distributed = apply_distributive_law(left)

    {:and, {:or, left_distributed, apply_distributive_law(right1)},
     {:or, left_distributed, apply_distributive_law(right2)}}
  end

  defp apply_distributive_law({:or, {:and, left1, left2}, right}) do
    right_distributed = apply_distributive_law(right)

    {:and, {:or, apply_distributive_law(left1), right_distributed},
     {:or, apply_distributive_law(left2), right_distributed}}
  end

  defp apply_distributive_law({:not, expression}) do
    {:not, apply_distributive_law(expression)}
  end

  defp apply_distributive_law({operator, left, right}) when operator in [:and, :or] do
    {operator, apply_distributive_law(left), apply_distributive_law(right)}
  end

  defp apply_distributive_law(var) when is_integer(var) do
    var
  end

  defp demorgans_law(expression) do
    demorgans_law_applied = apply_demorgans_law(expression)

    if expression == demorgans_law_applied do
      expression
    else
      demorgans_law(demorgans_law_applied)
    end
  end

  defp apply_demorgans_law({:not, {:and, left, right}}) do
    {:or, {:not, apply_demorgans_law(left)}, {:not, apply_demorgans_law(right)}}
  end

  defp apply_demorgans_law({:not, {:or, left, right}}) do
    {:and, {:not, left}, {:not, right}}
  end

  defp apply_demorgans_law({operator, left, right}) when operator in [:or, :and] do
    {operator, apply_demorgans_law(left), apply_demorgans_law(right)}
  end

  defp apply_demorgans_law({:not, expression}) do
    {:not, apply_demorgans_law(expression)}
  end

  defp apply_demorgans_law(var) when is_integer(var) do
    var
  end
end