lib/solver/constraints/propagators/element_var.ex

defmodule CPSolver.Propagator.ElementVar do
  use CPSolver.Propagator

  @moduledoc """
  The propagator for Element constraint.
  array[index] = value,
  where array is an array of variables.
  """
  @doc """
  Builds a propagator for the constraint `var_array[var_index] = var_value`.

  The three arguments are packed into a single list and handed to `new/1`,
  which is not defined in this file (presumably injected by
  `use CPSolver.Propagator`).
  """
  def new(var_array, var_index, var_value), do: new([var_array, var_index, var_value])

  @doc """
  Normalizes the propagator arguments.

  The collection of element variables is converted into an `Arrays` array backed by
  `Aja.Vector`; the index and value variables are passed through unchanged.
  """
  @impl true
  def arguments([var_array, var_index, var_value]) do
    indexed_array = Arrays.new(var_array, implementation: Aja.Vector)
    [indexed_array, var_index, var_value]
  end

  @doc """
  Rebinds every variable held by the propagator.

  `Propagator.bind_to_variable/3` is applied, with the given `source` and `var_field`,
  to each element of the array, then to the index variable, then to the value
  variable. The propagator is returned with its `:args` replaced by the rebound list.
  """
  @impl true
  def bind(%{args: [var_array, var_index, var_value]} = propagator, source, var_field) do
    rebind = fn var -> Propagator.bind_to_variable(var, source, var_field) end

    # `:args` is guaranteed to exist by the match in the function head, so the
    # update syntax is safe here.
    %{
      propagator
      | args: [
          Arrays.map(var_array, rebind),
          rebind.(var_index),
          rebind.(var_value)
        ]
    }
  end

  @doc """
  Lists the variables of the propagator together with their triggering events.

  The array variables come first, each registered with `:fixed`; they are followed
  by the index variable and the value variable, both registered with
  `:domain_change`.
  """
  @impl true
  def variables([var_array, var_index, var_value]) do
    element_vars = for element_var <- var_array, do: set_propagate_on(element_var, :fixed)

    selector_vars =
      for selector_var <- [var_index, var_value] do
        set_propagate_on(selector_var, :domain_change)
      end

    element_vars ++ selector_vars
  end

  # First filtering pass.
  #
  # An empty list of element variables fails immediately.
  # NOTE(review): `filter/3` calls `Arrays.size/1` on the array before reaching this
  # function, so a bare `[]` may never arrive here — confirm whether this clause is
  # still needed.
  defp initial_reduction([], _var_index, _var_value, _state, _changes), do: throw(:fail)

  defp initial_reduction(var_array, var_index, var_value, state, changes) do
    # `var_index` addresses a position in `var_array` (0-based), so its domain is
    # clipped to 0..size-1 before the regular reduction runs.
    removeBelow(var_index, 0)
    last_position = Arrays.size(var_array) - 1
    removeAbove(var_index, last_position)

    reduction(var_array, var_index, var_value, state, changes)
  end

  @doc """
  Runs one filtering pass of the constraint.

  On the first call (`state` is `nil`) the propagator state is created and the
  initial reduction is performed. On later calls `filter_impl/5` is tried first.

  Returns `:passive` once both the index and the value variables are fixed,
  otherwise `{:state, state}`.
  """
  @impl true
  def filter([var_array, var_index, var_value] = args, state, changes) do
    # `variables/1` lists the array variables first, so the index variable sits at
    # position `size(array)` in that list (presumably the key used in `changes`).
    new_state = state || %{var_index_position: Arrays.size(var_array)}

    incremental_pass =
      state && filter_impl(var_array, var_index, var_value, new_state, changes)

    # NOTE(review): a falsy `incremental_pass` covers both "first call" and
    # "`filter_impl/5` saw no index/value change", so the latter also falls through
    # to the initial reduction — confirm this is intended.
    if !incremental_pass do
      initial_reduction(var_array, var_index, var_value, new_state, changes)
    end

    if passive?(args), do: :passive, else: {:state, new_state}
  end

  # Incremental pass: runs `reduction/5` only when `changes` has an entry for the
  # index variable (key `idx_position`) or for the value variable (the next key).
  # Returns `false` when neither entry is present, otherwise whatever
  # `reduction/5` returns.
  defp filter_impl(
         var_array,
         var_index,
         var_value,
         %{var_index_position: idx_position} = state,
         changes
       ) do
    value_position = idx_position + 1

    index_or_value_changed? =
      map_size(changes) > 0 and
        (is_map_key(changes, idx_position) or is_map_key(changes, value_position))

    if index_or_value_changed? do
      reduction(var_array, var_index, var_value, state, changes)
    else
      false
    end
  end

  # Core pruning step, in two phases:
  #
  #   1. For every index still in D(var_index), intersect D(var_value) with the domain
  #      of the array element at that index. An empty intersection means the index
  #      cannot be selected, so it is removed from D(var_index). Non-empty
  #      intersections are accumulated into the set of supported values.
  #   2. Every value of D(var_value) that no remaining element supports is removed
  #      from D(var_value).
  #
  # Both domains are read once, up front; the removals act on the variables
  # themselves. Returns `:ok`.
  defp reduction(var_array, var_index, var_value, _state, _changes) do
    index_domain = domain_values(var_index)
    value_domain = domain_values(var_value)

    supported_values =
      Enum.reduce(index_domain, MapSet.new(), fn idx, supported ->
        elem_var = Arrays.get(var_array, idx)

        # Every index in D(var_index) is expected to address an existing element.
        if is_nil(elem_var) do
          IO.inspect("Unexpected: no element at #{idx}")
          throw(:unexpected_no_element)
        end

        common_values = reduce_element_domain(value_domain, elem_var)

        if MapSet.size(common_values) == 0 do
          # No value is shared with this element: the index is infeasible.
          remove(var_index, idx)
          supported
        else
          MapSet.union(common_values, supported)
        end
      end)

    value_domain
    |> Enum.reject(fn val -> MapSet.member?(supported_values, val) end)
    |> Enum.each(fn val -> remove(var_value, val) end)
  end

  # Returns the values of `value_domain` that are also in the current domain of
  # `element_var`, i.e. the values this array element can still share with the
  # value variable. The result is a `MapSet`, empty when nothing is shared.
  #
  # The element variable itself is only read here. An earlier version issued a
  # `remove/2` on `element_var` for every value in `value_domain \ element_domain`;
  # those values are by construction absent from the element's domain, so the
  # removals (and the matching `MapSet.delete/2` calls on the local copy) could not
  # change anything and only added one call per such value on every pass.
  # NOTE(review): this assumes `remove/2` of a value that is not in the domain is a
  # no-op — confirm against the variable interface.
  defp reduce_element_domain(value_domain, element_var) do
    element_var
    |> domain_values()
    |> MapSet.intersection(value_domain)
  end

  # Tells whether the propagator has nothing left to do: truthy when both the index
  # and the value variables are fixed. In that case, before returning, the array
  # element at `min(var_index)` is fixed to `min(var_value)` (for a fixed variable
  # the minimum is presumably its single remaining value).
  defp passive?([var_array, var_index, var_value] = _args) do
    resolved? = fixed?(var_index) && fixed?(var_value)

    if resolved? do
      fix(Propagator.arg_at(var_array, min(var_index)), min(var_value))
    end

    resolved?
  end
end