lib/solver/constraints/propagators/all_different/all_different_fwc.ex

defmodule CPSolver.Propagator.AllDifferent.FWC do
  use CPSolver.Propagator

  @moduledoc """
  The forward-checking propagator for AllDifferent constraint.

  Whenever a variable becomes fixed, its value is removed from the domains of
  all variables that are not fixed yet. Such removals may fix further
  variables, whose values are removed in turn, until no new values get fixed.
  Two variables fixed to the same value cause a failure (`throw(:fail)`).

  The propagator state keeps the indices of still-unfixed variables under
  `:unresolved`. Once at most one variable is left unfixed, the constraint
  can no longer be violated, and the propagator reports `:passive`.
  """

  @impl true
  def arguments(args) do
    Arrays.new(args, implementation: Aja.Vector)
  end

  @impl true
  def variables(args) do
    # Only fixing events matter for forward checking.
    Enum.map(args, fn x_el -> set_propagate_on(x_el, :fixed) end)
  end

  @impl true
  def filter(all_vars, state, changes) do
    # Keys of `changes` are used as argument indices of the newly fixed variables.
    new_fixed = changes |> Map.keys() |> MapSet.new()

    {unresolved, fixed} =
      if state do
        # Values fixed in earlier runs were already removed from every unresolved
        # variable, so only the newly fixed values need to be propagated.
        {MapSet.difference(state[:unresolved], new_fixed), fixed_values(all_vars, new_fixed)}
      else
        # First run: scan all variables.
        initial_split(all_vars)
      end

    case fwc(all_vars, unresolved, fixed) do
      false -> :passive
      unfixed_updated_set -> {:state, %{unresolved: unfixed_updated_set}}
    end
  end

  ## Collects the values of the (fixed) variables at the given indices.
  ## Fails if two of them are fixed to the same value.
  defp fixed_values(vars, fixed) do
    Enum.reduce(fixed, MapSet.new(), fn idx, values_acc ->
      vars
      |> Propagator.arg_at(idx)
      |> min()
      |> add_distinct_value(values_acc)
    end)
  end

  ## Splits variables into a set of unfixed indices and a set of fixed values.
  ## Fails if two fixed variables share a value.
  defp initial_split(vars) do
    # Explicit step: for an empty argument list, a bare `0..-1` would count
    # down and visit the non-existent indices 0 and -1.
    last_idx = Arrays.size(vars) - 1

    Enum.reduce(0..last_idx//1, {MapSet.new(), MapSet.new()}, fn idx,
                                                                 {unfixed_acc, fixed_vals_acc} ->
      var = Propagator.arg_at(vars, idx)

      if fixed?(var) do
        {unfixed_acc, var |> min() |> add_distinct_value(fixed_vals_acc)}
      else
        {MapSet.put(unfixed_acc, idx), fixed_vals_acc}
      end
    end)
  end

  ## Adds `value` to `values`, failing if it is already there (AllDifferent violation).
  defp add_distinct_value(value, values) do
    if MapSet.member?(values, value), do: fail(), else: MapSet.put(values, value)
  end

  ## Removes `fixed_values` from the unfixed variables, repeating with any
  ## values that become fixed along the way.
  ## Returns `false` when at most one variable remains unfixed: its domain no
  ## longer contains any fixed value, so the constraint is entailed.
  ## Otherwise it returns the set of indices that are still unfixed.
  defp fwc(vars, unfixed_set, fixed_values) do
    {updated_unfixed, updated_fixed_vals} = remove_values(vars, unfixed_set, fixed_values)

    newly_fixed_vals = MapSet.difference(updated_fixed_vals, fixed_values)

    cond do
      MapSet.size(updated_unfixed) <= 1 -> false
      MapSet.size(newly_fixed_vals) > 0 -> fwc(vars, updated_unfixed, newly_fixed_vals)
      true -> updated_unfixed
    end
  end

  ## unfixed_set - set of indices for yet unfixed variables
  ## fixed_values - the set of fixed values we will use to reduce unfixed set.
  ## Returns {indices still unfixed, fixed_values plus all values fixed here}.
  defp remove_values(vars, unfixed_set, fixed_values) do
    for idx <- unfixed_set, reduce: {MapSet.new(), fixed_values} do
      {still_unfixed_acc, fixed_vals_acc} ->
        case remove_all(Propagator.arg_at(vars, idx), fixed_vals_acc) do
          false ->
            ## Variable is still unfixed, keep it
            {MapSet.put(still_unfixed_acc, idx), fixed_vals_acc}

          new_fixed_value ->
            fixed_vals_acc = MapSet.put(fixed_vals_acc, new_fixed_value)

            # Variables visited later will see the new value through the
            # accumulator; the ones visited earlier have to be revisited now.
            {unfixed_here, fixed_here} =
              remove_values(vars, still_unfixed_acc, MapSet.new([new_fixed_value]))

            {unfixed_here, MapSet.union(fixed_here, fixed_vals_acc)}
        end
    end
  end

  ## Removes `values` from the domain of `var`.
  ## Returns the value `var` is fixed to (if it is fixed after removal), or `false`.
  ## Fails if the variable ends up fixed to one of `values`.
  defp remove_all(var, values) do
    # Stops at the first removal that fixes the variable; after that, the
    # remaining values only matter through the membership check below.
    fixed_by_removal? = Enum.any?(values, fn val -> remove(var, val) == :fixed end)

    if fixed_by_removal? or fixed?(var) do
      var |> min() |> then(fn value -> (MapSet.member?(values, value) && fail()) || value end)
    else
      false
    end
  end

  defp fail() do
    throw(:fail)
  end
end