defmodule CPSolver.Propagator.AllDifferent.DC do
use CPSolver.Propagator
alias CPSolver.ValueGraph
alias CPSolver.Propagator.AllDifferent.Utils, as: AllDiffUtils
alias Iter.Iterable
@moduledoc """
The domain-consistent propagator for AllDifferent constraint,
based on:
J.-C. Régin, A filtering algorithm for constraints of difference in CSPs
"""
@impl true
def arguments(args) do
Arrays.new(args, implementation: Aja.Vector)
end
@impl true
def variables(args) do
Enum.map(args, fn x_el -> set_propagate_on(x_el, :domain_change) end)
end
@impl true
def filter(vars, state, changes) do
state = (state &&
state
|> Map.put(:propagator_variables, vars)
|> apply_changes(changes)) || initial_state(vars)
finalize(state)
end
defp finalize(state) do
(entailed?(state) && :passive) ||
{:state, state}
end
defp entailed?(%{sccs: sccs} = _state) do
Enum.empty?(sccs)
end
def apply_changes(
%{
sccs: sccs,
value_graph: value_graph,
propagator_variables: vars
} = state, _changes
) do
state = Map.put(state, :value_graph,
set_neighbor_finder(value_graph, ValueGraph.default_neighbor_finder(vars)))
## Apply changes to affected SCCs
Enum.reduce(sccs, Map.put(state, :sccs, MapSet.new()),
fn component, state_acc ->
%{value_graph: reduced_graph, sccs: derived_sccs} = reduce_component(component, state_acc)
state_acc
|> Map.put(:value_graph, reduced_graph)
|> Map.update!(:sccs, fn existing -> MapSet.union(existing, derived_sccs) end)
end)
end
def initial_state(vars) do
%{value_graph: value_graph, left_partition: variable_vertices, fixed_matching: _fixed_matching} =
ValueGraph.build(vars, check_matching: true)
reduce_component(MapSet.new(variable_vertices, fn {:variable, var_index} -> var_index end),
value_graph, vars)
|> Map.put(:propagator_variables, vars)
end
def reduce_component(component,
%{
propagator_variables: vars,
value_graph: value_graph
} = _state) do
reduce_component(component, value_graph, vars)
end
def reduce_component(component, value_graph, vars) do
reduction(vars, value_graph, MapSet.new(component,
fn component_index -> {:variable, component_index}
end), %{})
end
def reduction(vars, value_graph, variable_vertices, fixed_matching) do
matching = find_matching(value_graph, variable_vertices, fixed_matching)
%{value_graph: _reduced_graph, sccs: _sccs} =
reduce_graph(value_graph, vars, matching)
end
def find_matching(value_graph, variable_vertices, fixed_matching) do
try do
BitGraph.Algorithms.bipartite_matching(
value_graph,
variable_vertices,
fixed_matching: fixed_matching,
required_size: MapSet.size(variable_vertices)
)
|> tap(fn matching -> matching || fail() end)
catch {:error, _} ->
fail()
end
end
def reduce_graph(value_graph, variables, %{free: free_nodes, matching: matching} = _matching_record) do
value_graph
|> build_residual_graph(variables, matching, free_nodes)
|> reduce_residual_graph(variables, matching)
|> then(fn {sccs, reduced_graph} ->
%{
sccs: sccs,
value_graph:
reduced_graph
|> remove_sink_node()
|> set_neighbor_finder(ValueGraph.default_neighbor_finder(variables))
}
end)
end
def build_residual_graph(graph, variables, matching, free_nodes) do
graph
|> add_sink_node(free_nodes)
|> then(fn g ->
set_neighbor_finder(g,
residual_graph_neighbor_finder(g, variables, matching, free_nodes)
)
end)
end
defp set_neighbor_finder(graph, neighbor_finder) do
BitGraph.set_neighbor_finder(graph, neighbor_finder)
end
defp add_sink_node(graph, free_nodes) do
Enum.empty?(free_nodes) && graph ||
BitGraph.add_vertex(graph, :sink)
end
defp remove_sink_node(graph) do
case BitGraph.V.get_vertex_index(graph, :sink) do
nil -> graph
sink_index -> BitGraph.V.delete_vertex(graph, sink_index)
end
end
defp residual_graph_neighbor_finder(value_graph, variables, matching, free_nodes) do
num_variables = ValueGraph.get_variable_count(value_graph)
base_neighbor_finder = ValueGraph.matching_neighbor_finder(value_graph, variables, matching, free_nodes)
free_node_indices = MapSet.new(free_nodes, fn value_vertex -> BitGraph.V.get_vertex_index(value_graph, value_vertex) end)
matching_value_indices = MapSet.new(Map.values(matching), fn value_vertex -> BitGraph.V.get_vertex_index(value_graph, value_vertex) end)
sink_node_index = BitGraph.V.get_vertex_index(value_graph, :sink)
fn _graph, nil, _direction ->
## "Stray" vertex index.
## This could happen if the vertex is not in the graph,
## for instance, as a result of it being removed during graph processing;
## TODO: review
MapSet.new()
graph, vertex_index, direction ->
neighbors = base_neighbor_finder.(graph, vertex_index, direction)
## By construction of value graph, the variable vertices go first,
## followed by value vertices; the last on is 'sink' vertex
cond do
vertex_index == sink_node_index && direction == :out->
matching_value_indices
vertex_index == sink_node_index && direction == :in ->
free_node_indices
vertex_index <= num_variables ->
neighbors
direction == :in && vertex_index in free_node_indices ->
neighbors
direction == :out && vertex_index in free_node_indices ->
MapSet.new([sink_node_index])
direction == :in && vertex_index in matching_value_indices ->
Iterable.append(neighbors, sink_node_index)
direction == :out && vertex_index in matching_value_indices ->
neighbors
true ->
MapSet.new()
end
end
end
def reduce_residual_graph(residual_graph, vars, matching) do
AllDiffUtils.split_to_sccs(residual_graph, Map.keys(matching),
AllDiffUtils.default_remove_edge_fun(vars))
end
defp fail() do
throw(:fail)
end
end