Skip to main content

lib/solver/search/search.ex

defmodule CPSolver.Search do
  @moduledoc """
  Entry point for search (branching) strategies.

  A strategy (`branching`) may be given as:

    * a `{variable_choice, partition_strategy}` tuple, resolved through
      `CPSolver.Search.VariableSelector` and `CPSolver.Search.Partition`;
    * a module exporting `branch/2` (and optionally `initialize/1`);
    * a 3-arity function, called as `fun.(:init, space_data, nil)` on
      initialization and as `fun.(:branch, variables, space_data)` on branching.
  """

  alias CPSolver.DefaultDomain, as: Domain

  alias CPSolver.Search.VariableSelector
  alias CPSolver.Search.Partition
  alias CPSolver.Variable.Interface
  alias CPSolver.Utils.Vector
  alias CPSolver.Variables.UnfixedTracker, as: Tracker

  require Logger

  @doc """
  Returns the brancher module used by `branch/3` when the requested strategy
  yields no branching (`nil`).
  """
  def default_strategy() do
    CPSolver.Search.DefaultBrancher
  end

  @doc """
  Initializes a search strategy against `space_data`.

    * Tuple strategies get both components initialized and are returned as a
      `{variable_selector, partition}` tuple.
    * Function strategies are called as `fun.(:init, space_data, nil)`.
    * Module strategies must be loadable and export `branch/2`. Their
      `initialize/1` is called when exported; otherwise the module itself is
      returned, which `branch/3` accepts as a `branching` argument.

  Throws `{:unknown_brancher, module}` when a module strategy cannot be loaded
  or does not export `branch/2`.
  """
  def initialize(search, space_data)

  def initialize({variable_choice, value_choice} = _search, space_data) do
    {
      VariableSelector.initialize(variable_choice, space_data),
      Partition.initialize(value_choice, space_data)
    }
  end

  def initialize(brancher_impl, data) when is_atom(brancher_impl) do
    cond do
      not brancher_module?(brancher_impl) ->
        throw({:unknown_brancher, brancher_impl})

      function_exported?(brancher_impl, :initialize, 1) ->
        brancher_impl.initialize(data)

      true ->
        # A brancher without initialize/1 carries no state to set up; previously
        # this case crashed with UndefinedFunctionError.
        brancher_impl
    end
  end

  def initialize(brancher_fun, space_data) when is_function(brancher_fun, 3) do
    brancher_fun.(:init, space_data, nil)
  end

  ### Helpers

  @doc """
  Computes the branches for the current state of the search.

  Fixed variables are filtered out first; `branching` is then applied to the
  remaining ones. If it yields `nil` (or `false`), `default_strategy/0` is used
  instead. Returns a list of reduction functions, each called as
  `reduction.(variables, space_data)` and returning a map with
  `:variable_copies`, `:domain_changes` and `:unfixed_variables_tracker`.

  Throws `:all_vars_fixed` when `space_data` carries an unfixed-variables
  tracker that is already empty, and `{:unknown_brancher, module}` for an
  unusable module strategy.
  """
  def branch(variables, branching, space_data \\ %{})

  def branch(variables, branching, space_data) do
    unfixed_vars = filter_fixed_variables(variables, space_data)

    partitions =
      branch_impl(unfixed_vars, branching, space_data) ||
        branch_impl(unfixed_vars, default_strategy(), space_data)

    partitions_impl(partitions, space_data)
  end

  defp branch_impl(variables, brancher_fun, space_data) when is_function(brancher_fun, 3) do
    brancher_fun.(:branch, variables, space_data)
  end

  defp branch_impl(variables, brancher_impl, space_data) when is_atom(brancher_impl) do
    if brancher_module?(brancher_impl) do
      brancher_impl.branch(variables, space_data)
    else
      throw({:unknown_brancher, brancher_impl})
    end
  end

  defp branch_impl(variables, {variable_choice, partition_strategy}, space_data) do
    variable_value_choice(variables, variable_choice, partition_strategy, space_data)
  end

  ## True when `module` can be loaded and exports `branch/2`.
  defp brancher_module?(module) do
    Code.ensure_loaded(module) == {:module, module} and
      function_exported?(module, :branch, 2)
  end

  @doc """
  Selects a variable with `variable_choice` and partitions its domain with
  `partition_strategy`.

  Returns `[]` when no variable is selected; otherwise returns the domain
  partitions produced by `Partition.partition/2`. Raises `MatchError` if
  partitioning does not return `{:ok, partitions}`.
  """
  def variable_value_choice(variables, variable_choice, partition_strategy, space_data) do
    case VariableSelector.select_variable(variables, space_data, variable_choice) do
      nil ->
        []

      selected_variable ->
        {:ok, domain_partitions} =
          Partition.partition(selected_variable, partition_strategy)

        domain_partitions
    end
  end

  ## Returns `variable` with its domain replaced by a `Domain.copy/1` of it
  ## (presumably so reductions on the copy leave the original untouched).
  defp copy_variable(%{domain: domain} = variable) do
    %{variable | domain: Domain.copy(domain)}
  end

  ## Returns the variables that are not yet fixed.
  ##
  ## Without a tracker in `space_data`, fixed variables are rejected from `vars`
  ## and a list is returned. With a tracker, the tracked indices are walked in
  ## their original order (strategies such as :input_order depend on it); fixed
  ## variables are deleted from the tracker and the rest are collected into a
  ## `Vector`. Tracker indices are offset by one relative to `vars`.
  ## Throws `:all_vars_fixed` if the tracker is already empty.
  ##
  ## NOTE(review): the two paths return different collection types (list vs
  ## `Vector`); brancher strategies presumably accept both - confirm.
  defp filter_fixed_variables(vars, space_data) do
    tracker = space_data[:unfixed_variables_tracker]

    cond do
      is_nil(tracker) ->
        Enum.reject(vars, fn var -> Interface.fixed?(var) end)

      Tracker.empty?(tracker) ->
        throw(:all_vars_fixed)

      true ->
        Tracker.iterate_ordered(tracker, Vector.new([]), fn idx, acc ->
          var = vars[idx - 1]

          if Interface.fixed?(var) do
            Tracker.delete(tracker, idx)
            acc
          else
            Vector.append(acc, var)
          end
        end)
    end
  end

  defp partitions_impl(nil, _space_data) do
    []
  end

  ## Flattens the per-variable partitions into one list of reductions, preserving
  ## order. flat_map avoids the quadratic cost of `acc ++ ...` in a reduce.
  defp partitions_impl(partitions, space_data) when is_list(partitions) do
    Enum.flat_map(partitions, fn variable_partition ->
      variable_partitions_impl(variable_partition, space_data)
    end)
  end

  ## Builds reductions for a single variable; a lone partition is wrapped into a
  ## list and `nil` yields no reductions.
  defp variable_partitions_impl(domain_partitions, _space_data) do
    domain_partitions
    |> List.wrap()
    |> Enum.map(&build_reduction/1)
  end

  ## Partition is a map %{var_id => reduction}, where `reduction` is a function
  ## that takes a (copied) variable and performs a domain reduction on it.
  ##
  ## The returned function copies every variable, applies the reductions whose
  ## ids appear in the partition, copies the unfixed-variables tracker (if any)
  ## and returns the copies together with the per-variable reduction results.
  defp build_reduction(partition) do
    fn variables, space_data ->
      {_count, variable_copies, domain_changes} =
        Vector.reduce(variables, {0, variables, %{}}, fn var, {var_idx, copies, changes} ->
          var_copy = copy_variable(var)

          changes =
            case Map.get(partition, var.id) do
              nil -> changes
              reduction -> Map.put(changes, var.id, reduction.(var_copy))
            end

          {var_idx + 1, Vector.update(copies, var_idx, var_copy), changes}
        end)

      tracker_copy =
        case space_data[:unfixed_variables_tracker] do
          nil -> nil
          tracker -> Tracker.copy(tracker)
        end

      %{
        variable_copies: variable_copies,
        domain_changes: domain_changes,
        unfixed_variables_tracker: tracker_copy
      }
    end
  end
end