defmodule Ash.SatSolver do
@moduledoc """
Tools for working with the SAT solver that drives filter subset checking (for authorization).

This module is public only as a very low-level toolkit for writing authorizers; you almost
certainly do not need to look at it.

If you are looking for information about how authorization works, see the [policy guide](/documentation/topics/security/policies.md).
"""
alias Ash.Filter
alias Ash.Query.{BooleanExpression, Not, Ref}
# NOTE(review): Dialyzer warnings are suppressed for overlap?/2 only; the reason is not
# visible in this module — confirm it is still needed before removing.
@dialyzer {:nowarn_function, overlap?: 2}
# Tuple form of a boolean statement, as built by the `b/1` macro: `{:and, l, r}`,
# `{:or, l, r}`, `{:not, e}`, or a leaf expression.
@typep boolean_expr ::
{:and, boolean_expr, boolean_expr}
| {:or, boolean_expr, boolean_expr}
| {:not, boolean_expr}
| Ash.Expr.t()
@doc """
Builds a tuple-form boolean statement from Elixir's `and`, `or` and `not` syntax, then
normalises it with `balance/1`.

For example, `b(1 and 2)` evaluates to `{:and, 1, 2}`.
"""
defmacro b(statement) do
  tupled =
    Macro.prewalk(statement, fn node ->
      case node do
        {:and, _meta, [left, right]} ->
          quote do: {:and, unquote(left), unquote(right)}

        {:or, _meta, [left, right]} ->
          quote do: {:or, unquote(left), unquote(right)}

        {:not, _meta, [operand]} ->
          quote do: {:not, unquote(operand)}

        untouched ->
          untouched
      end
    end)

  quote do
    Ash.SatSolver.balance(unquote(tupled))
  end
end
@doc """
Checks whether the candidate filter returns the same data as, or less data than, the filter.

A filter whose expression is `nil` matches everything. The result is:

  * `true` if the filter has no expression, or if the SAT solver proves that no data matches
    the candidate without also matching the filter;
  * `false` if only the candidate has no expression, or if the two filters cannot both be
    satisfied at once;
  * `:maybe` otherwise.
"""
@spec strict_filter_subset(Ash.Filter.t(), Ash.Filter.t()) :: boolean | :maybe
def strict_filter_subset(%{expression: nil}, _candidate), do: true
def strict_filter_subset(_filter, %{expression: nil}), do: false
def strict_filter_subset(filter, candidate), do: do_strict_filter_subset(filter, candidate)
@doc """
Prepares an expression for comparison by the SAT solver.

Synonymous relationship paths are consolidated. Refs that read the destination attribute of a
relationship are rewritten to read its source attribute. The result is then converted to tuple
form, conjoined with constraints describing how its predicates relate to each other.
"""
@spec transform(Ash.Resource.t(), Ash.Expr.t()) :: boolean_expr()
def transform(resource, expression) do
  consolidated = consolidate_relationships(expression, resource)
  upgraded = upgrade_related_filters_to_join_keys(consolidated, resource)
  build_expr_with_predicate_information(upgraded)
end
@doc "Calls `transform/2`, converts the result to CNF with `to_cnf/1` and solves the clauses."
@spec transform_and_solve(Ash.Resource.t(), Ash.Expr.t()) ::
        {:ok, [integer()]} | {:error, :unsatisfiable}
def transform_and_solve(resource, expression) do
  {clauses, _bindings} = to_cnf(transform(resource, expression))
  Ash.SatSolver.Implementation.solve_expression(clauses)
end
@doc false
# Normalises a tuple-form boolean statement so that equivalent orderings compare equal.
# Both operands of every 3-tuple node are balanced and then placed in Erlang term order,
# and pairs of nested `{:not, _}` cancel out. Leaves are returned unchanged.
def balance({:not, {:not, inner}}), do: balance(inner)
def balance({:not, inner}), do: {:not, balance(inner)}

def balance({op, left, right}) do
  balanced_left = balance(left)
  balanced_right = balance(right)

  if balanced_left <= balanced_right do
    {op, balanced_left, balanced_right}
  else
    {op, balanced_right, balanced_left}
  end
end

def balance(leaf), do: leaf
@doc """
Returns `true` if the relationship paths are synonymous from a data perspective.

`search` is resolved against `left_resource` and `candidate` against `right_resource`,
which defaults to `left_resource`. Two empty paths are synonymous. An empty path is never
synonymous with a non-empty one.
"""
@spec synonymous_relationship_paths?(
        Ash.Resource.t(),
        [atom()],
        [atom()],
        Ash.Resource.t()
      ) :: boolean
def synonymous_relationship_paths?(
      left_resource,
      candidate,
      search,
      right_resource \\ nil
    )

def synonymous_relationship_paths?(_, [], [], _), do: true
def synonymous_relationship_paths?(_, [], _, _), do: false
def synonymous_relationship_paths?(_, _, [], _), do: false

def synonymous_relationship_paths?(
      left_resource,
      [candidate_first | candidate_rest],
      [first | rest],
      right_resource
    ) do
  right_resource = right_resource || left_resource
  relationship = Ash.Resource.Info.relationship(left_resource, first)
  candidate_relationship = Ash.Resource.Info.relationship(right_resource, candidate_first)

  if relationship && candidate_relationship do
    case {relationship.type, candidate_relationship.type} do
      # A many_to_many on the search side vs. a has_many on the candidate side: the has_many
      # must match the many_to_many's join relationship, and the candidate path must go on.
      # NOTE(review): the remaining paths are compared starting from
      # `left_resource`/`right_resource`, not from the relationships' destinations — confirm
      # this is intended.
      {:many_to_many, :has_many} ->
        synonymous_relationship_paths?(left_resource, [relationship.join_relationship], [
          candidate_first
        ]) && candidate_rest != [] &&
          synonymous_relationship_paths?(left_resource, candidate_rest, rest, right_resource)

      # The mirrored case: has_many on the search side, many_to_many on the candidate side.
      {:has_many, :many_to_many} ->
        synonymous_relationship_paths?(left_resource, [relationship.name], [
          candidate_relationship.join_relationship
        ]) && rest != [] &&
          synonymous_relationship_paths?(left_resource, candidate_rest, rest, right_resource)

      # Otherwise both relationships must agree on every data-relevant field, and the rest of
      # the paths must be synonymous from the shared destination.
      _types ->
        compared_fields = [
          :source_attribute,
          :destination_attribute,
          :source_attribute_on_join_resource,
          :destination_attribute_on_join_resource,
          :destination,
          :manual,
          :sort,
          :filter
        ]

        Map.take(relationship, compared_fields) ==
          Map.take(candidate_relationship, compared_fields) and
          synonymous_relationship_paths?(relationship.destination, candidate_rest, rest)
    end
  else
    false
  end
end
# Decides strict_filter_subset/2 when both filters have expressions. It asks the SAT solver
# two questions, in order:
#   1. Can `filter and candidate` be satisfied? If not, the answer is `false`.
#   2. Can `not filter and candidate` be satisfied? If not, nothing matches the candidate
#      without also matching the filter, so the answer is `true`. Otherwise it is `:maybe`.
# Every ref is first marked `input?: false`, so refs that differ only in that flag compare
# equal.
defp do_strict_filter_subset(filter, candidate) do
  clear_input = fn
    %Ref{} = ref -> %{ref | input?: false}
    other -> other
  end

  filter = Filter.map(filter, clear_input)
  candidate = Filter.map(candidate, clear_input)

  solve = fn expression ->
    case transform_and_solve(filter.resource, expression) do
      {:ok, _scenario} -> :satisfiable
      {:error, :unsatisfiable} -> :unsatisfiable
    end
  end

  cond do
    solve.(BooleanExpression.new(:and, filter.expression, candidate.expression)) ==
        :unsatisfiable ->
      false

    solve.(BooleanExpression.new(:and, Not.new(filter.expression), candidate.expression)) ==
        :unsatisfiable ->
      true

    true ->
      :maybe
  end
end
@doc """
Returns a list of statements expressing that the predicates are pairwise mutually exclusive.

For every pair `a`, `b` (in list order), the statement `b(not (a and b))` is prepended to
`acc`. The accumulated list is returned once all pairs have been visited.
"""
@spec mutually_exclusive([Ash.Expr.t()]) :: [boolean_expr()]
def mutually_exclusive(predicates, acc \\ [])
def mutually_exclusive([], acc), do: acc

def mutually_exclusive([predicate | rest], acc) do
  acc =
    Enum.reduce(rest, acc, fn other_predicate, statements ->
      [b(not (predicate and other_predicate)) | statements]
    end)

  mutually_exclusive(rest, acc)
end
@doc """
Returns a list of statements expressing that the predicates are mutually exclusive and
collectively exhaustive.

No two predicates hold at once, and each predicate holds exactly when none of the others
does. Fewer than two predicates produce no statements.
"""
@spec mutually_exclusive_and_collectively_exhaustive([Ash.Expr.t()]) :: [boolean_expr()]
def mutually_exclusive_and_collectively_exhaustive([]), do: []
def mutually_exclusive_and_collectively_exhaustive([_]), do: []

def mutually_exclusive_and_collectively_exhaustive(predicates) do
  # One statement per predicate. This must be `Enum.map/2`, not `Enum.flat_map/2`: each
  # `b/1` result is a tuple, not an enumerable, so flat_map raised for every call.
  exhaustive =
    Enum.map(predicates, fn predicate ->
      other_predicates = Enum.reject(predicates, &(&1 == predicate))

      other_predicates_union =
        Enum.reduce(other_predicates, nil, fn other_predicate, expr ->
          if expr do
            b(expr or other_predicate)
          else
            other_predicate
          end
        end)

      # Exactly one of `predicate` and "some other predicate" holds.
      b(
        not (predicate and other_predicates_union) and
          not (not predicate and not other_predicates_union)
      )
    end)

  mutually_exclusive(predicates) ++ exhaustive
end
@doc "Returns `b(not (left and right))`"
@spec left_excludes_right(Ash.Expr.t(), Ash.Expr.t()) :: boolean_expr()
def left_excludes_right(left, right), do: b(not (left and right))

@doc """
Returns `b(not (right and left))`.

Because `b/1` balances its result, this equals `left_excludes_right(left, right)`.
"""
@spec right_excludes_left(Ash.Expr.t(), Ash.Expr.t()) :: boolean_expr()
def right_excludes_left(left, right), do: left_excludes_right(right, left)
@doc """
Returns a list of statements expressing that the predicates are pairwise mutually inclusive.

For every pair `a`, `b` (in list order), the statement
`b((a and b) or (not a and not b))` is prepended to `acc`, meaning either both hold or
neither does.
"""
@spec mutually_inclusive([Ash.Expr.t()]) :: [boolean_expr()]
def mutually_inclusive(predicates, acc \\ [])
def mutually_inclusive([], acc), do: acc

def mutually_inclusive([predicate | rest], acc) do
  acc =
    Enum.reduce(rest, acc, fn other_predicate, statements ->
      [b((predicate and other_predicate) or (not predicate and not other_predicate)) | statements]
    end)

  # Recurse into mutually_inclusive/2 (previously mutually_exclusive/2), so the remaining
  # pairs also get inclusion constraints rather than exclusion constraints.
  mutually_inclusive(rest, acc)
end
@doc "Returns `b(not (right and not left))`, i.e. `right` implies `left`."
@spec right_implies_left(Ash.Expr.t(), Ash.Expr.t()) :: boolean_expr()
def right_implies_left(left, right) do
  b(not (right and not left))
end

@doc "Returns `b(not (left and not right))`, i.e. `left` implies `right`."
@spec left_implies_right(Ash.Expr.t(), Ash.Expr.t()) :: boolean_expr()
def left_implies_right(left, right) do
  b(not (left and not right))
end
# True when two predicates have at least one `%Ref{}` operand in common. The `input?` flag is
# ignored when comparing.
defp shares_ref?(left, right), do: any_refs_in_common?(refs(left), refs(right))

defp any_refs_in_common?(left_refs, right_refs) do
  Enum.any?(left_refs, fn ref -> Enum.member?(right_refs, ref) end)
end

# The direct `%Ref{}` operands of an operator or function, each with `input?` set to false.
# Any other expression has no refs.
defp refs(%{__operator__?: true, left: left, right: right}), do: ref_operands([left, right])
defp refs(%{__function__?: true, arguments: arguments}), do: ref_operands(arguments)
defp refs(_), do: []

defp ref_operands(operands) do
  for %Ref{} = ref <- operands, do: Map.put(ref, :input?, false)
end
# One simplification pass over a filter expression. Boolean, `not` and `exists` nodes are
# rebuilt around their simplified children. Each predicate whose module exports
# `simplify/1` is replaced by that function's result, unless it returns nil/false.
# Everything else is returned unchanged.
defp simplify(%BooleanExpression{op: op, left: left, right: right}) do
  BooleanExpression.new(op, simplify(left), simplify(right))
end

defp simplify(%Not{expression: expression}), do: Not.new(simplify(expression))

defp simplify(%Ash.Query.Exists{expr: expr} = exists), do: %{exists | expr: simplify(expr)}

defp simplify(%mod{__predicate__?: true} = predicate) do
  # `function_exported?/3` does not load the module. Without `Code.ensure_loaded?/1`, a
  # predicate whose module has not been loaded yet (e.g. in interactive mode) would
  # silently be left unsimplified.
  if Code.ensure_loaded?(mod) and function_exported?(mod, :simplify, 1) do
    mod.simplify(predicate) || predicate
  else
    predicate
  end
end

defp simplify(other), do: other
@doc """
Transforms a tuple-form boolean statement into Conjunctive Normal Form (CNF).

Returns `{clauses, bindings}`:

  * `clauses` is a list of lists of integers (positive for a variable, negative for its
    negation);
  * `bindings` is a map with `:temp_bindings` (the final numbering) and `:old_bindings`
    (the facts each original number was bound to).
"""
def to_cnf(expression) do
  # The constants `true` and `not false` are conjoined in, so both receive bindings too.
  {bindings, bound_expression} = extract_bindings(b(true and not false and expression))

  bound_expression
  |> to_conjunctive_normal_form()
  |> lift_clauses()
  |> negations_to_negative_numbers()
  |> Enum.map(fn clause -> Enum.sort_by(clause, &{abs(&1), &1}) end)
  |> group_predicates(bindings)
  |> rebind()
  |> unique_clauses()
end
# Drops duplicate clauses (keeping first occurrences) and passes the bindings through.
defp unique_clauses({clauses, bindings}), do: {Enum.uniq(clauses), bindings}
# Replaces recurring runs of literals shared between clauses with group bindings (see
# group_scenario_predicates/3). A single clause is returned untouched. Otherwise the clauses
# come back in reverse order, together with the updated bindings.
defp group_predicates([_only] = expression, bindings), do: {expression, bindings}

defp group_predicates(scenarios, bindings) do
  Enum.reduce(scenarios, {[], bindings}, fn scenario, {grouped, current_bindings} ->
    {grouped_scenario, next_bindings} =
      group_scenario_predicates(scenario, scenarios, current_bindings)

    {[grouped_scenario | grouped], next_bindings}
  end)
end
# Finds the ordered sublists of `scenario` that qualify as groups across `all_scenarios`.
# Overlapping candidates are discarded, shortest first. Each remaining group is replaced in
# the scenario by its group binding, which is registered if it is new.
defp group_scenario_predicates(scenario, all_scenarios, bindings) do
  usable_groups =
    scenario
    |> Ash.SatSolver.Utils.ordered_sublists()
    |> Enum.filter(&can_be_used_as_group?(&1, all_scenarios, bindings))
    |> Enum.sort_by(&length/1)
    |> remove_overlapping()

  Enum.reduce(usable_groups, {scenario, bindings}, fn group, {current_scenario, current_bindings} ->
    updated_bindings = add_group_binding(current_bindings, group)
    group_binding = updated_bindings[:groups][group]

    {Ash.SatSolver.Utils.replace_ordered_sublist(current_scenario, group, group_binding),
     updated_bindings}
  end)
end
# Keeps a candidate group only if none of its variables, in either polarity, appears in a
# later candidate; later candidates take precedence.
defp remove_overlapping([]), do: []

defp remove_overlapping([group | later]) do
  overlaps_later? =
    Enum.any?(group, fn var ->
      Enum.any?(later, fn other -> var in other or -var in other end)
    end)

  if overlaps_later?,
    do: remove_overlapping(later),
    else: [group | remove_overlapping(later)]
end
@doc false
# Maps clauses written in rebind's temporary variables back to the original variables.
# For a variable that stands for a group:
#   * a negated group becomes the negations of all its members, in the same clause;
#   * a positive group splits the clause into one clause per member, each combined with the
#     clause's other literals.
# Returns `{clauses, old_bindings}`.
@spec unbind([[integer()]], map()) :: {[[integer()]], map()}
def unbind(expression, %{temp_bindings: temp_bindings, old_bindings: old_bindings}) do
  unbound =
    Enum.flat_map(expression, fn statement ->
      statement
      |> Enum.flat_map(&unbind_var(&1, temp_bindings, old_bindings))
      |> expand_groups()
    end)

  {unbound, old_bindings}
end

# Translates one signed temporary variable into its original literal(s). A positive group is
# returned as an `{:expand, members}` marker.
defp unbind_var(var, temp_bindings, old_bindings) do
  original = temp_bindings[abs(var)]
  group = old_bindings[:reverse_groups][original]

  cond do
    is_nil(group) and var < 0 -> [-original]
    is_nil(group) -> [original]
    var < 0 -> Enum.map(group, &(-&1))
    true -> [{:expand, group}]
  end
end

# Expands every `{:expand, members}` marker in a clause into one clause per member (the
# cartesian product over all markers). A clause without markers yields itself.
defp expand_groups([]), do: [[]]

defp expand_groups([{:expand, group} | rest]) do
  tails = expand_groups(rest)
  for var <- group, tail <- tails, do: [var | tail]
end

defp expand_groups([var | rest]) do
  for tail <- expand_groups(rest), do: [var | tail]
end
# Renumbers the variables of the grouped clauses densely from 1, in order of first appearance
# (scanning clauses and literals left to right), and keeps each literal's sign.
# Returns `{clauses, %{temp_bindings: temp_bindings, old_bindings: bindings}}`, where
# `temp_bindings` maps each new number to the number it replaced. It also carries the
# bookkeeping keys `:current` and `:reverse` and is consumed by unbind/2.
# The clauses come back in reverse order; literals within each clause keep their order.
defp rebind({expression, bindings}) do
{expression, temp_bindings} =
# `:current` holds the last number handed out. `:reverse` (old number => new number) is
# created when the first variable is seen.
Enum.reduce(expression, {[], %{current: 0}}, fn statement, {statements, acc} ->
{statement, acc} =
Enum.reduce(statement, {[], acc}, fn var, {statement, acc} ->
case acc[:reverse][abs(var)] do
nil ->
# First sighting of this variable: assign the next number.
binding = acc.current + 1
value =
if var < 0 do
-binding
else
binding
end
{[value | statement],
acc
|> Map.put(:current, binding)
|> Map.update(:reverse, %{abs(var) => binding}, &Map.put(&1, abs(var), binding))
|> Map.put(binding, abs(var))}
value ->
# Already renumbered: reuse its number with this literal's sign.
value =
if var < 0 do
-value
else
value
end
{[value | statement], acc}
end
end)
# Literals were prepended, so restore their order within the clause.
{[Enum.reverse(statement) | statements], acc}
end)
bindings_with_old_bindings = %{temp_bindings: temp_bindings, old_bindings: bindings}
{expression, bindings_with_old_bindings}
end
# A run of literals may be used as a group if it is already registered as one, or if every
# clause either shares no variable with it or contains it as an ordered sublist.
defp can_be_used_as_group?(group, scenarios, bindings) do
  if Map.has_key?(bindings[:groups] || %{}, group) do
    true
  else
    Enum.all?(scenarios, &(has_no_overlap?(&1, group) or group_in_scenario?(&1, group)))
  end
end

# True when no variable of `group` appears in `scenario`, in either polarity.
defp has_no_overlap?(scenario, group) do
  Enum.all?(group, fn group_literal ->
    Enum.all?(scenario, &(abs(&1) != abs(group_literal)))
  end)
end

defp group_in_scenario?(scenario, group) do
  Ash.SatSolver.Utils.is_ordered_sublist_of?(group, scenario)
end
# Registers `group` under the next free binding number, in both `:groups`
# (group => number) and `:reverse_groups` (number => group), and advances `:current`.
# A group that is already registered leaves the bindings unchanged.
defp add_group_binding(bindings, group) do
  if bindings[:groups][group] do
    bindings
  else
    next = bindings[:current]

    bindings
    |> Map.update(:reverse_groups, %{next => group}, &Map.put(&1, next, group))
    |> Map.update(:groups, %{group => next}, &Map.put(&1, group, next))
    |> Map.put(:current, next + 1)
  end
end
@doc false
# Turns a solver solution (signed variables) into a map from each bound fact to whether it
# holds (`true` for positive variables). The map starts with `true: []` and `false: []`
# entries, which are overwritten if the facts `true`/`false` are themselves in the solution.
# Raises `Ash.Error.Framework.AssumptionFailed` for a variable that has no binding.
@spec solutions_to_predicate_values([integer()], map()) :: map()
def solutions_to_predicate_values(solution, bindings) do
  for var <- solution, reduce: %{true: [], false: []} do
    state ->
      case Map.get(bindings, abs(var)) do
        nil ->
          raise Ash.Error.Framework.AssumptionFailed.exception(
                  message: """
                  A fact from the SAT solver had no corresponding bound fact:
                  Bindings:
                  #{inspect(bindings)}
                  Missing:
                  #{inspect(var)}
                  """
                )

        fact ->
          Map.put(state, fact, var > 0)
      end
  end
end
# Replaces every leaf of a tuple-form statement with an integer variable. Leaves that are
# equal (==) share a variable. `bindings` maps each variable to its leaf and stores the next
# free variable under `:current`. Returns `{bindings, statement}`.
defp extract_bindings(expr, bindings \\ %{current: 1})

defp extract_bindings({operator, left, right}, bindings) do
  {bindings, left_var} = extract_bindings(left, bindings)
  {bindings, right_var} = extract_bindings(right, bindings)
  {bindings, {operator, left_var, right_var}}
end

defp extract_bindings({:not, value}, bindings) do
  {bindings, inner} = extract_bindings(value, bindings)
  {bindings, b(not inner)}
end

defp extract_bindings(value, %{current: current} = bindings) do
  existing =
    Enum.find_value(bindings, fn
      {:current, _} -> nil
      {key, bound} -> if bound == value, do: key
    end)

  if existing do
    {bindings, existing}
  else
    {bindings |> Map.put(:current, current + 1) |> Map.put(current, value), current}
  end
end
# Formats clauses as the input text we would give to picosat: a `p cnf <variables> <clauses>`
# header line, then one line per clause terminated by ` 0`.
@doc false
def to_picosat(clauses, variable_count) do
  header = "p cnf #{variable_count} #{Enum.count(clauses)}\n"
  body = Enum.map_join(clauses, "\n", &(format_clause(&1) <> " 0"))
  header <> body
end
# Rewrites lifted clauses into lists of signed integers. `{:not, n}` becomes `-n`, and
# nested nots cancel out in pairs. A bare integer literal (not wrapped in a list) becomes a
# one-element clause.
defp negations_to_negative_numbers(clauses) do
  for clause <- clauses do
    case clause do
      {:not, var} when is_integer(var) ->
        [negate_var(var)]

      var when is_integer(var) ->
        [var]

      literals ->
        for literal <- literals do
          case literal do
            {:not, var} -> negate_var(var)
            var -> var
          end
        end
    end
  end
end

# Multiplies the innermost value by -1, flipping the sign once more for every extra
# `{:not, _}` wrapper around it.
defp negate_var(var, multiplier \\ -1)
defp negate_var({:not, inner}, multiplier), do: negate_var(inner, -multiplier)
defp negate_var(value, multiplier), do: value * multiplier
# Renders one clause as space-separated literals; `{:not, n}` is written as `-n`.
defp format_clause(clause), do: Enum.map_join(clause, " ", &format_literal/1)

defp format_literal({:not, var}), do: "-" <> to_string(var)
defp format_literal(var), do: to_string(var)
# Flattens a CNF statement into a list of clauses. Each `and` splits into separate clauses,
# an `or` subtree becomes one clause listing all its disjuncts, and any other value becomes a
# one-literal clause.
defp lift_clauses({:and, left, right}), do: lift_clauses(left) ++ lift_clauses(right)
defp lift_clauses({:or, _left, _right} = disjunction), do: [lift_or_clauses(disjunction)]
defp lift_clauses(value), do: [[value]]

# Collects the disjuncts of a (possibly nested) `or` into a flat list.
defp lift_or_clauses({:or, left, right}), do: lift_or_clauses(left) ++ lift_or_clauses(right)
defp lift_or_clauses(value), do: [value]
# Rewrites a tuple-form statement with integer leaves into conjunctive normal form. First,
# De Morgan's laws push negations towards the leaves; then `or` is distributed over `and`.
# Each rewrite is repeated until the statement stops changing.
defp to_conjunctive_normal_form(expression) do
  expression
  |> apply_until_stable(&apply_demorgans_law/1)
  |> apply_until_stable(&apply_distributive_law/1)
end

# Applies `step` repeatedly until its result equals (==) its input.
defp apply_until_stable(expression, step) do
  next = step.(expression)

  if next == expression do
    expression
  else
    apply_until_stable(next, step)
  end
end

# One pass of the distributive law: `a or (b and c)` => `(a or b) and (a or c)`, plus the
# mirrored form, recursing through `and`, `or` and `not`.
defp apply_distributive_law({:or, left, {:and, right_a, right_b}}) do
  shared = apply_distributive_law(left)

  {:and, {:or, shared, apply_distributive_law(right_a)},
   {:or, shared, apply_distributive_law(right_b)}}
end

defp apply_distributive_law({:or, {:and, left_a, left_b}, right}) do
  shared = apply_distributive_law(right)

  {:and, {:or, apply_distributive_law(left_a), shared},
   {:or, apply_distributive_law(left_b), shared}}
end

defp apply_distributive_law({:not, inner}), do: {:not, apply_distributive_law(inner)}

defp apply_distributive_law({operator, left, right}) when operator in [:and, :or] do
  {operator, apply_distributive_law(left), apply_distributive_law(right)}
end

defp apply_distributive_law(var) when is_integer(var), do: var

# One pass of De Morgan's laws: `not (a and b)` => `not a or not b` and
# `not (a or b)` => `not a and not b`. The `or` case does not recurse into its operands;
# the next pass of apply_until_stable/2 reaches them.
defp apply_demorgans_law({:not, {:and, left, right}}) do
  {:or, {:not, apply_demorgans_law(left)}, {:not, apply_demorgans_law(right)}}
end

defp apply_demorgans_law({:not, {:or, left, right}}), do: {:and, {:not, left}, {:not, right}}

defp apply_demorgans_law({operator, left, right}) when operator in [:or, :and] do
  {operator, apply_demorgans_law(left), apply_demorgans_law(right)}
end

defp apply_demorgans_law({:not, inner}), do: {:not, apply_demorgans_law(inner)}
defp apply_demorgans_law(var) when is_integer(var), do: var
# Converts an Ash filter expression into tuple form:
#   * boolean expressions become `{op, left, right}`;
#   * `Not` becomes a balanced `{:not, _}`;
#   * leaves (nil, booleans, predicates, exists and parent expressions) pass through;
#   * a custom expression is replaced by its inner `expression` as-is.
# Anything else raises `ArgumentError`.
defp filter_to_expr(leaf) when leaf in [nil, true, false], do: leaf
defp filter_to_expr(%Filter{expression: expression}), do: filter_to_expr(expression)
defp filter_to_expr(%{__predicate__?: _} = predicate), do: predicate

defp filter_to_expr(%module{} = expr) when module in [Ash.Query.Exists, Ash.Query.Parent],
  do: expr

defp filter_to_expr(%Ash.CustomExpression{expression: expression}), do: expression
defp filter_to_expr(%Not{expression: expression}), do: b(not filter_to_expr(expression))

defp filter_to_expr(%BooleanExpression{op: op, left: left, right: right}),
  do: {op, filter_to_expr(left), filter_to_expr(right)}

defp filter_to_expr(expr) do
  raise ArgumentError, message: "Invalid filter expression #{inspect(expr)}"
end
# Walks a filter expression and rewrites refs that read the destination attribute of the last
# relationship on their path, so they read that relationship's source attribute one hop
# earlier instead. For example, with a relationship `author` whose source attribute is
# `author_id` and destination attribute is `id`, the ref `author.id` becomes `author_id`.
# This lets predicates written either way compare equal.
defp upgrade_related_filters_to_join_keys(
       %BooleanExpression{op: op, left: left, right: right},
       resource
     ) do
  BooleanExpression.new(
    op,
    upgrade_related_filters_to_join_keys(left, resource),
    upgrade_related_filters_to_join_keys(right, resource)
  )
end

defp upgrade_related_filters_to_join_keys(%Not{expression: expression}, resource) do
  Not.new(upgrade_related_filters_to_join_keys(expression, resource))
end

defp upgrade_related_filters_to_join_keys(
       %Ash.Query.Exists{path: path, expr: expr} = exists,
       resource
     ) do
  # The refs inside `exists` are resolved against the related resource at `path`.
  related = Ash.Resource.Info.related(resource, path)
  %{exists | expr: upgrade_related_filters_to_join_keys(expr, related)}
end

defp upgrade_related_filters_to_join_keys(
       %{__operator__?: true, left: left, right: right} = op,
       resource
     ) do
  %{op | left: upgrade_ref(left, resource), right: upgrade_ref(right, resource)}
end

defp upgrade_related_filters_to_join_keys(
       %{__function__?: true, arguments: arguments} = function,
       resource
     ) do
  %{function | arguments: Enum.map(arguments, &upgrade_ref(&1, resource))}
end

defp upgrade_related_filters_to_join_keys(expr, _), do: expr

# A 2-tuple with an atom key has its value upgraded in place.
defp upgrade_ref({key, ref}, resource) when is_atom(key) do
  {key, upgrade_ref(ref, resource)}
end

defp upgrade_ref(
       %Ash.Query.Ref{attribute: attribute, relationship_path: path} = ref,
       resource
     )
     when path != [] do
  with relationship when not is_nil(relationship) <-
         Ash.Resource.Info.relationship(resource, path),
       true <- attribute.name == relationship.destination_attribute,
       new_attribute when not is_nil(new_attribute) <-
         Ash.Resource.Info.attribute(relationship.source, relationship.source_attribute) do
    %{
      ref
      | relationship_path: :lists.droplast(path),
        attribute: new_attribute,
        # The new attribute belongs to the relationship's source resource. That equals the
        # root `resource` only for single-step paths; for longer paths it is an
        # intermediate resource.
        resource: relationship.source
    }
  else
    _ -> ref
  end
end

defp upgrade_ref(other, _), do: other
# Rewrites the relationship paths in `expression` (relative to `resource`) so that
# synonymous paths (see synonymous_relationship_paths?/4) are replaced by a previously seen
# synonymous path. Equal predicates over them then compare equal.
defp consolidate_relationships(expression, resource) do
  replacements = synonymous_path_replacements(expression, resource)
  do_consolidate_relationships(expression, replacements, resource)
end

# Maps each relationship path in `expression` that is synonymous with an already-kept path to
# that kept path. Paths without a synonym are kept and do not appear in the map.
defp synonymous_path_replacements(expression, resource) do
  {replacements, _kept_paths} =
    expression
    |> Filter.relationship_paths(true)
    |> Enum.uniq()
    |> Enum.reduce({%{}, []}, fn path, {replacements, kept_paths} ->
      case find_synonymous_relationship_path(resource, kept_paths, path) do
        nil ->
          {replacements, [path | kept_paths]}

        synonymous_path ->
          # The accumulator must stay a `{replacements, kept_paths}` tuple. Returning the bare
          # map here crashed on the next path, or on the final match.
          {Map.put(replacements, path, synonymous_path), kept_paths}
      end
    end)

  replacements
end

defp do_consolidate_relationships(
       %BooleanExpression{op: op, left: left, right: right},
       replacements,
       resource
     ) do
  BooleanExpression.new(
    op,
    do_consolidate_relationships(left, replacements, resource),
    do_consolidate_relationships(right, replacements, resource)
  )
end

defp do_consolidate_relationships(%Not{expression: expression}, replacements, resource) do
  Not.new(do_consolidate_relationships(expression, replacements, resource))
end

defp do_consolidate_relationships(
       %Ash.Query.Exists{at_path: at_path, path: path, expr: expr} = exists,
       replacements,
       resource
     ) do
  exists =
    case Map.get(replacements, at_path) do
      nil -> exists
      replacement -> %{exists | at_path: replacement}
    end

  related = Ash.Resource.Info.related(resource, at_path)

  exists =
    case Map.get(synonymous_path_replacements(expr, related), path) do
      nil -> exists
      replacement -> %{exists | path: replacement}
    end

  full_related = Ash.Resource.Info.related(related, path)
  %{exists | expr: consolidate_relationships(expr, full_related)}
end

defp do_consolidate_relationships(
       %Ash.Query.Ref{relationship_path: path} = ref,
       replacements,
       _resource
     )
     when path != [] do
  case Map.get(replacements, path) do
    nil -> ref
    replacement -> %{ref | relationship_path: replacement}
  end
end

defp do_consolidate_relationships(
       %{__function__?: true, arguments: args} = func,
       replacements,
       resource
     ) do
  %{func | arguments: Enum.map(args, &do_consolidate_relationships(&1, replacements, resource))}
end

defp do_consolidate_relationships(
       %{__operator__?: true, left: left, right: right} = op,
       replacements,
       resource
     ) do
  %{
    op
    | left: do_consolidate_relationships(left, replacements, resource),
      right: do_consolidate_relationships(right, replacements, resource)
  }
end

defp do_consolidate_relationships(other, _, _), do: other

# Returns the first path in `paths` that is synonymous with `path`, or nil.
defp find_synonymous_relationship_path(resource, paths, path) do
  Enum.find(paths, &synonymous_relationship_paths?(resource, &1, path))
end
# Converts `expression` into tuple form and conjoins it with constraints describing how its
# predicates relate to each other. These are pairwise constraints derived from
# `Ash.Filter.Predicate.compare/2` for predicates that share a ref, plus whatever each
# predicate module's optional `bulk_compare/1` returns.
defp build_expr_with_predicate_information(expression) do
  expression = fully_simplify(expression)

  all_predicates =
    expression
    |> Filter.list_predicates()
    |> Enum.uniq()

  comparison_expressions =
    all_predicates
    |> Enum.filter(fn %module{} ->
      # `function_exported?/3` does not load the module, so load it first. Otherwise
      # comparisons are silently skipped for predicates whose module is not loaded yet.
      Code.ensure_loaded?(module) and function_exported?(module, :compare, 2)
    end)
    |> Enum.reduce([], fn predicate, new_expressions ->
      all_predicates
      |> Enum.reject(&Kernel.==(&1, predicate))
      |> Enum.filter(&shares_ref?(&1, predicate))
      |> Enum.reduce(new_expressions, fn other_predicate, new_expressions ->
        case comparison_constraint(predicate, other_predicate) do
          nil -> new_expressions
          constraint -> [constraint | new_expressions]
        end
      end)
    end)
    |> Enum.uniq()

  expression = filter_to_expr(expression)

  expression_with_comparisons =
    Enum.reduce(comparison_expressions, expression, fn comparison_expression, expression ->
      b(comparison_expression and expression)
    end)

  all_predicates
  |> Enum.map(& &1.__struct__)
  |> Enum.uniq()
  |> Enum.flat_map(fn struct ->
    if Code.ensure_loaded?(struct) and function_exported?(struct, :bulk_compare, 1) do
      struct.bulk_compare(all_predicates)
    else
      []
    end
  end)
  |> Enum.reduce(expression_with_comparisons, fn comparison_expression, expression ->
    b(comparison_expression and expression)
  end)
end

# Taking `predicate` as a and `other_predicate` as b, returns the constraint implied by
# `Ash.Filter.Predicate.compare/2`. Returns nil when the relation is unknown, in which case
# both are assumed to possibly hold at once.
defp comparison_constraint(predicate, other_predicate) do
  case Ash.Filter.Predicate.compare(predicate, other_predicate) do
    # b || !a
    :right_includes_left ->
      b(other_predicate or not predicate)

    # a || !b
    :left_includes_right ->
      b(predicate or not other_predicate)

    # (a && b) || (!a && !b)
    :mutually_inclusive ->
      b((predicate and other_predicate) or (not predicate and not other_predicate))

    :mutually_exclusive ->
      b(not (other_predicate and predicate))

    :mutually_exclusive_and_collectively_exhaustive ->
      b(not (other_predicate and predicate) and not (not other_predicate and not predicate))

    _other ->
      nil
  end
end
# Simplifies the expression to a fixed point, splits overlapping `in`/`==` predicates apart
# (lift_equals_out_of_in/1), then simplifies again.
defp fully_simplify(expression) do
  expression
  |> do_fully_simplify()
  |> lift_equals_out_of_in()
  |> do_fully_simplify()
end

# Runs simplify/1 once. If anything changed (strict equality), it restarts the whole
# fully_simplify/1 pipeline on the result.
# NOTE(review): this recurses into fully_simplify/1 rather than do_fully_simplify/1, so the
# lifting step re-runs after every change — confirm that is intended.
defp do_fully_simplify(expression) do
  simplified = simplify(expression)

  if simplified === expression do
    expression
  else
    fully_simplify(simplified)
  end
end
# Repeatedly splits `in` predicates that partially overlap another `in`/`==` (see overlap?/2)
# into disjoint pieces, until find_non_equal_overlap/1 finds nothing more.
defp lift_equals_out_of_in(expression) do
  case find_non_equal_overlap(expression) do
    nil ->
      expression

    overlapping ->
      expression
      |> split_in_expressions(overlapping)
      |> lift_equals_out_of_in()
  end
end

# Searches for a sub-expression that overlaps?/2 some sub-expression of the same expression.
# Presumably it returns the first such node, or nil, following `Ash.Filter.find/2`'s
# semantics — confirm.
defp find_non_equal_overlap(expression) do
  Ash.Filter.find(expression, fn candidate ->
    Ash.Filter.find(expression, &overlap?(candidate, &1))
  end)
end
# Builds `base.left in values`, collapsing to `base.left == value` when exactly one value
# remains.
defp new_in(base, right) do
  if MapSet.size(right) == 1 do
    [only] = MapSet.to_list(right)
    %Ash.Query.Operator.Eq{left: base.left, right: only}
  else
    %Ash.Query.Operator.In{left: base.left, right: right}
  end
end
# Rewrites every `in` predicate in the tree that overlaps `non_equal_overlap` (per overlap?/2)
# into a disjunction of pieces that no longer partially overlap it. It recurses through
# filters, `not` and boolean expressions; all other nodes are returned unchanged.
#
# `in` vs. `==`: removes the `==` value from the `in` set and ORs the `==` back in.
defp split_in_expressions(
%Ash.Query.Operator.In{right: right} = sub_expr,
%Ash.Query.Operator.Eq{right: value} = non_equal_overlap
) do
if overlap?(non_equal_overlap, sub_expr) do
Ash.Query.BooleanExpression.new(
:or,
new_in(sub_expr, MapSet.delete(right, value)),
non_equal_overlap
)
else
sub_expr
end
end
# `in` vs. `in`: if every value of `sub_expr` is also in the other set, `sub_expr` becomes an
# OR of one `==` per value. Otherwise it becomes (values only in `sub_expr`) OR (values in
# both sets).
defp split_in_expressions(
%Ash.Query.Operator.In{} = sub_expr,
%Ash.Query.Operator.In{right: right} = non_equal_overlap
) do
if overlap?(sub_expr, non_equal_overlap) do
diff = MapSet.difference(sub_expr.right, right)
if MapSet.size(diff) == 0 do
# The fold starts from nil. Presumably BooleanExpression.new/3 collapses a nil operand —
# confirm.
Enum.reduce(sub_expr.right, nil, fn var, acc ->
BooleanExpression.new(:or, %Ash.Query.Operator.Eq{left: sub_expr.left, right: var}, acc)
end)
else
new_right = new_in(sub_expr, MapSet.intersection(sub_expr.right, right))
Ash.Query.BooleanExpression.new(
:or,
new_in(sub_expr, diff),
new_right
)
end
else
sub_expr
end
end
defp split_in_expressions(nil, _), do: nil
defp split_in_expressions(%Ash.Filter{expression: expression} = filter, non_equal_overlap),
do: %{filter | expression: split_in_expressions(expression, non_equal_overlap)}
defp split_in_expressions(%Not{expression: expression} = not_expr, non_equal_overlap),
do: %{not_expr | expression: split_in_expressions(expression, non_equal_overlap)}
defp split_in_expressions(
%BooleanExpression{left: left, right: right} = expr,
non_equal_overlap
),
do: %{
expr
| left: split_in_expressions(left, non_equal_overlap),
right: split_in_expressions(right, non_equal_overlap)
}
defp split_in_expressions(other, _), do: other
# Whether two predicates on the same left-hand side overlap without being identical. This is
# the case for:
#   * two `in`s whose value sets differ but intersect;
#   * an `==` (first) whose value is in an `in` (second).
# An `==` against a ref never counts as overlapping.
defp overlap?(
       %Ash.Query.Operator.In{left: left, right: %MapSet{} = left_values},
       %Ash.Query.Operator.In{left: left, right: %MapSet{} = right_values}
     ) do
  not MapSet.equal?(left_values, right_values) and
    not MapSet.disjoint?(left_values, right_values)
end

defp overlap?(_, %Ash.Query.Operator.Eq{right: %Ref{}}), do: false
defp overlap?(%Ash.Query.Operator.Eq{right: %Ref{}}, _), do: false

defp overlap?(
       %Ash.Query.Operator.Eq{left: left, right: value},
       %Ash.Query.Operator.In{left: left, right: %MapSet{} = values}
     ) do
  MapSet.member?(values, value)
end

defp overlap?(_left, _right), do: false
end