lib/solver/domain/bitvector_domain.ex

defmodule CPSolver.BitVectorDomain do
  @moduledoc """
  Integer domain stored in a bit vector.

  A domain is a `{bit_vector, offset}` pair, where `bit_vector` is
  `{:bit_vector, atomics_ref}`. Value `v` is represented by the bit at vector
  index `v + offset`; vector index `i` lives in atomics block `div(i, 64) + 1`,
  at bit position `rem(i, 64)`.

  The last atomics slot holds the current min and max vector indices, packed into
  one integer by `PackedMinMax`. `set_min/3`, `set_max/3` and `set_fixed/3`
  update it with `:atomics.compare_exchange/4`. A slot value with all 64 bits set
  marks a failed domain.

  `fix/2`, `removeAbove/2` and `removeBelow/2` only move the bounds without
  clearing bits, so bits outside `[min, max]` may still be set; `size/1`,
  `reduce/4` and `contains?/2` only take the bits within the bounds into account.

  Operations that empty the domain, or find it already failed, mark it failed
  and `throw(:fail)`.
  """
  import Bitwise

  ## All 64 bits set: sentinel stored in the min/max slot of a failed domain
  @failure_value (1 <<< 64) - 1
  ## Shift value at which bit_positions/5 stops (after 64 bits)
  @max64_value 1 <<< 64

  ## Build a domain from a single integer, an enumerable of integers,
  ## or an existing domain (returned as is). An empty input fails.
  def new([]) do
    fail()
  end

  def new(value) when is_integer(value) do
    new([value])
  end

  def new({{:bit_vector, _ref} = _bitmap, _offset} = domain) do
    domain
  end

  def new(domain) do
    case Enum.min_max(domain, fn -> :empty end) do
      :empty ->
        ## Same outcome as new([]) for any empty enumerable
        fail()

      {min_value, max_value} ->
        ## Shift values so that the smallest one sits at vector index 0
        offset = -min_value
        max_index = max_value + offset
        bv = :bit_vector.new(max_index + 1)
        Enum.each(domain, fn value -> :bit_vector.set(bv, value + offset) end)

        PackedMinMax.set_min(0, 0)
        |> PackedMinMax.set_max(max_index)
        |> then(fn min_max -> set_min_max(bv, min_max) end)

        {bv, offset}
    end
  end

  ## Copy the domain into a fresh atomics array.
  ## Only the blocks between the current min and max blocks are copied;
  ## the offset is shared with the original.
  def copy({{:bit_vector, ref} = bit_vector, offset} = _domain) do
    ## One snapshot of the bounds, so the copied blocks and copied bounds agree
    {min_max, _min_max_idx, current_min, current_max} = get_min_max(bit_vector)
    {min_block, _} = vector_address(current_min)
    {max_block, _} = vector_address(current_max)

    ## Data blocks 1..max_block, plus one trailing slot for the packed min/max
    new_atomics_ref = :atomics.new(max_block + 1, [{:signed, false}])

    Enum.each(min_block..max_block//1, fn block_idx ->
      :atomics.put(new_atomics_ref, block_idx, :atomics.get(ref, block_idx))
    end)

    new_bit_vector = {:bit_vector, new_atomics_ref}
    set_min_max(new_bit_vector, min_max)
    {new_bit_vector, offset}
  end

  ## MapSet of the domain values passed through mapper_fun
  def map(domain, mapper_fun) when is_function(mapper_fun) do
    to_list(domain, mapper_fun)
  end

  ## Reduce over domain values, one 64-bit block at a time.
  ## For every block, the values of the set bits within [min, max] are passed
  ## through value_mapper_fun and collected into a MapSet, which is combined
  ## with the accumulator by reduce_fun. Blocks with no such bits are skipped.
  ## Clipping happens on vector indices, so value_mapper_fun need not be monotonic.
  def reduce(
        {{:bit_vector, ref} = bit_vector, offset} = _domain,
        value_mapper_fun,
        reduce_fun \\ &MapSet.union/2,
        acc_init \\ MapSet.new()
      ) do
    %{
      min_addr: %{block: min_block} = min_addr,
      max_addr: %{block: max_block} = max_addr
    } = get_bound_addrs(bit_vector)

    Enum.reduce(min_block..max_block//1, acc_init, fn block_idx, acc ->
      case clip_block(:atomics.get(ref, block_idx), block_idx, min_addr, max_addr) do
        0 ->
          acc

        block ->
          ## Domain value represented by bit 0 of this block
          base = (block_idx - 1) * 64 - offset

          mapped =
            block
            |> set_bit_indices()
            |> MapSet.new(fn bit -> value_mapper_fun.(base + bit) end)

          reduce_fun.(acc, mapped)
      end
    end)
  end

  ## MapSet of (optionally mapped) domain values
  def to_list(
        domain,
        value_mapper_fun \\ &Function.identity/1
      ) do
    reduce(domain, value_mapper_fun, &MapSet.union/2, MapSet.new())
  end

  ## True when min == max.
  ## Note: get_min_max/1 throws :fail on a failed domain.
  def fixed?({bit_vector, _offset} = _domain) do
    {current_min_max, _min_max_idx, current_min, current_max} = get_min_max(bit_vector)
    current_max == current_min && current_min_max != @failure_value
  end

  ## Failure check that does not throw; accepts a bit vector, a domain,
  ## or a raw min/max slot value.
  def failed?({:bit_vector, _ref} = bit_vector) do
    failed?(elem(get_min_max_impl(bit_vector), 1))
  end

  def failed?({bit_vector, _offset} = _domain) do
    failed?(bit_vector)
  end

  def failed?(min_max_value) when is_integer(min_max_value) do
    min_max_value == @failure_value
  end

  def min({bit_vector, offset} = _domain) do
    get_min(bit_vector) - offset
  end

  def max({bit_vector, offset} = _domain) do
    get_max(bit_vector) - offset
  end

  ## Number of set bits within [min, max]
  def size({{:bit_vector, ref} = bit_vector, _offset}) do
    %{
      min_addr: %{block: min_block} = min_addr,
      max_addr: %{block: max_block} = max_addr
    } = get_bound_addrs(bit_vector)

    Enum.reduce(min_block..max_block//1, 0, fn block_idx, acc ->
      block = clip_block(:atomics.get(ref, block_idx), block_idx, min_addr, max_addr)
      acc + bit_count(block)
    end)
  end

  def contains?({{:bit_vector, _ref} = bit_vector, offset}, value) do
    {_current_min_max, _min_max_idx, min_value, max_value} = get_min_max(bit_vector)
    vector_value = value + offset
    contains?(bit_vector, vector_value, min_value, max_value)
  end

  ## vector_value is a vector index; min_value/max_value are the current bounds
  def contains?(bit_vector, vector_value, min_value, max_value) do
    vector_value >= min_value && vector_value <= max_value &&
      :bit_vector.get(bit_vector, vector_value) == 1
  end

  ## Fix the domain to `value` (returns :fixed); fails if the value is absent.
  def fix({bit_vector, offset} = _domain, value) do
    min_max_info =
      {_current_min_max, _min_max_idx, min_value, max_value} = get_min_max(bit_vector)

    vector_value = value + offset

    if contains?(bit_vector, vector_value, min_value, max_value) do
      set_fixed(bit_vector, vector_value, min_max_info)
    else
      fail(bit_vector)
    end
  end

  ## Remove `value`. Returns :no_change if it is not in the domain,
  ## otherwise {change, domain}; fails when removing the only value.
  def remove({bit_vector, offset} = domain, value) do
    {_current_min_max, _min_max_idx, min_value, max_value} = get_min_max(bit_vector)
    vector_value = value + offset

    if contains?(bit_vector, vector_value, min_value, max_value) do
      domain_change =
        cond do
          min_value == max_value && vector_value == min_value ->
            ## Fixed value: fail on removing attempt
            fail(bit_vector)

          min_value == vector_value ->
            tighten_min(bit_vector, min_value, max_value)

          max_value == vector_value ->
            tighten_max(bit_vector, max_value, min_value)

          true ->
            :domain_change
        end

      ## Cleared only after the bounds were updated successfully
      :bit_vector.clear(bit_vector, vector_value)
      {domain_change, domain}
    else
      ## No value in the domain, do nothing
      :no_change
    end
  end

  ## Remove all values greater than `value` (bounds only, bits are kept)
  def removeAbove({bit_vector, offset} = domain, value) do
    {_current_min_max, _min_max_idx, min_value, max_value} = get_min_max(bit_vector)
    vector_value = value + offset

    cond do
      vector_value >= max_value ->
        :no_change

      vector_value < min_value ->
        fail(bit_vector)

      true ->
        ## The value is strictly less than max
        domain_change = tighten_max(bit_vector, vector_value + 1, min_value)

        {domain_change, domain}
    end
  end

  ## Remove all values less than `value` (bounds only, bits are kept)
  def removeBelow({bit_vector, offset} = domain, value) do
    {_current_min_max, _min_max_idx, min_value, max_value} = get_min_max(bit_vector)
    vector_value = value + offset

    cond do
      vector_value <= min_value ->
        :no_change

      vector_value > max_value ->
        fail(bit_vector)

      true ->
        ## The value is strictly greater than min
        domain_change = tighten_min(bit_vector, vector_value - 1, max_value)

        {domain_change, domain}
    end
  end

  ## Offset and every atomics slot (data blocks followed by the min/max slot)
  def raw({{:bit_vector, ref} = _bit_vector, offset} = _domain) do
    %{
      offset: offset,
      content: Enum.map(1..:atomics.info(ref).size, fn i -> :atomics.get(ref, i) end)
    }
  end

  ## Index of the last data block; the slot after it holds the packed min/max
  def last_index({:bit_vector, ref} = _bit_vector) do
    :atomics.info(ref).size - 1
  end

  defp set_min_max({:bit_vector, ref} = bit_vector, min_max) do
    bit_vector
    |> min_max_index()
    |> tap(fn idx ->
      :atomics.put(ref, idx, min_max)
    end)
  end

  def get_min(bit_vector) do
    get_min_max(bit_vector) |> elem(2)
  end

  def get_max(bit_vector) do
    get_min_max(bit_vector) |> elem(3)
  end

  defp min_max_index(bit_vector) do
    last_index(bit_vector) + 1
  end

  ## {packed_min_max, slot_index, min, max}; throws :fail on a failed domain
  def get_min_max(bit_vector) do
    {min_max_index, min_max} = get_min_max_impl(bit_vector)
    if min_max == @failure_value, do: fail(bit_vector)
    {min_max, min_max_index, PackedMinMax.get_min(min_max), PackedMinMax.get_max(min_max)}
  end

  defp get_min_max_impl({:bit_vector, ref} = bit_vector) do
    min_max_index = min_max_index(bit_vector)
    {min_max_index, :atomics.get(ref, min_max_index)}
  end

  def set_min(bit_vector, new_min) do
    set_min(bit_vector, new_min, get_min_max(bit_vector))
  end

  ## Raise the min (vector index) to new_min.
  ## Returns :fixed, :min_change or :no_change; a looser min is never written.
  def set_min({:bit_vector, ref} = bit_vector, new_min, min_max_info) do
    {current_min_max, min_max_idx, current_min, current_max} = min_max_info

    cond do
      new_min > current_max ->
        ## Inconsistency
        fail(bit_vector)

      new_min != current_min && current_min == current_max ->
        ## Attempt to re-fix
        fail(bit_vector)

      new_min < current_min ->
        ## Nothing to remove; writing it would bring removed values back
        :no_change

      true ->
        ## Min change
        min_max_value = PackedMinMax.set_min(current_min_max, new_min)

        case :atomics.compare_exchange(ref, min_max_idx, current_min_max, min_max_value) do
          :ok ->
            cond do
              new_min == current_max -> :fixed
              new_min == current_min -> :no_change
              true -> :min_change
            end

          @failure_value ->
            ## The domain was marked failed after min_max_info was read
            fail(bit_vector)

          changed_by_other_thread ->
            min2 = PackedMinMax.get_min(changed_by_other_thread)
            max2 = PackedMinMax.get_max(changed_by_other_thread)
            set_min(bit_vector, new_min, {changed_by_other_thread, min_max_idx, min2, max2})
        end
    end
  end

  def set_max(bit_vector, new_max) do
    set_max(bit_vector, new_max, get_min_max(bit_vector))
  end

  ## Lower the max (vector index) to new_max.
  ## Returns :fixed, :max_change or :no_change; a looser max is never written.
  def set_max({:bit_vector, ref} = bit_vector, new_max, min_max_info) do
    {current_min_max, min_max_idx, current_min, current_max} = min_max_info

    cond do
      new_max < current_min ->
        ## Inconsistency
        fail(bit_vector)

      new_max != current_max && current_min == current_max ->
        ## Attempt to re-fix
        fail(bit_vector)

      new_max > current_max ->
        ## Nothing to remove; writing it would bring removed values back
        :no_change

      true ->
        ## Max change
        min_max_value = PackedMinMax.set_max(current_min_max, new_max)

        case :atomics.compare_exchange(ref, min_max_idx, current_min_max, min_max_value) do
          :ok ->
            cond do
              new_max == current_min -> :fixed
              new_max == current_max -> :no_change
              true -> :max_change
            end

          @failure_value ->
            ## The domain was marked failed after min_max_info was read
            fail(bit_vector)

          changed_by_other_thread ->
            min2 = PackedMinMax.get_min(changed_by_other_thread)
            max2 = PackedMinMax.get_max(changed_by_other_thread)
            set_max(bit_vector, new_max, {changed_by_other_thread, min_max_idx, min2, max2})
        end
    end
  end

  ## Set both min and max to fixed_value (vector index); returns :fixed.
  def set_fixed({:bit_vector, ref} = bit_vector, fixed_value, min_max_info) do
    {current_min_max, min_max_idx, current_min, current_max} = min_max_info

    if fixed_value < current_min or fixed_value > current_max do
      ## Outside the current bounds; this includes re-fixing a fixed domain
      ## to a different value
      fail(bit_vector)
    else
      min_max_value = PackedMinMax.set_min(0, fixed_value) |> PackedMinMax.set_max(fixed_value)

      case :atomics.compare_exchange(ref, min_max_idx, current_min_max, min_max_value) do
        :ok ->
          :fixed

        @failure_value ->
          ## The domain was marked failed after min_max_info was read
          fail(bit_vector)

        changed_by_other_thread ->
          min2 = PackedMinMax.get_min(changed_by_other_thread)
          max2 = PackedMinMax.get_max(changed_by_other_thread)
          set_fixed(bit_vector, fixed_value, {changed_by_other_thread, min_max_idx, min2, max2})
      end
    end
  end

  ## Update (cached) min: move it to the first set bit strictly after
  ## `starting_at` (vector index), searching up to the block holding `max_value`.
  ## A bit found beyond the max makes set_min/2 fail; no bit at all fails too.
  defp tighten_min(
         {:bit_vector, atomics_ref} = bit_vector,
         starting_at,
         max_value
       ) do
    {max_block, _} = vector_address(max_value)
    {start_block, start_position} = vector_address(starting_at + 1)

    new_min =
      Enum.find_value(start_block..max_block//1, fn block_idx ->
        block = :atomics.get(atomics_ref, block_idx)

        ## Only the starting block is partially searched (by index, not by
        ## "previous block was non-empty"): drop bits below the start position
        block =
          if block_idx == start_block,
            do: (block >>> start_position) <<< start_position,
            else: block

        case lsb(block) do
          nil -> nil
          position -> (block_idx - 1) * 64 + position
        end
      end)

    if is_integer(new_min), do: set_min(bit_vector, new_min), else: fail(bit_vector)
  end

  ## Update (cached) max: move it to the last set bit strictly before
  ## `starting_at` (vector index), searching down to the block holding `min_value`.
  ## A bit found below the min makes set_max/2 fail; no bit at all fails too.
  defp tighten_max(
         {:bit_vector, atomics_ref} = bit_vector,
         starting_at,
         min_value
       ) do
    {min_block, _} = vector_address(min_value)
    {start_block, start_position} = vector_address(starting_at - 1)

    new_max =
      Enum.find_value(start_block..min_block//-1, fn block_idx ->
        block = :atomics.get(atomics_ref, block_idx)

        ## Only the starting block is partially searched: drop bits above the start position
        block =
          if block_idx == start_block,
            do: block &&& (1 <<< (start_position + 1)) - 1,
            else: block

        case msb(block) do
          nil -> nil
          position -> (block_idx - 1) * 64 + position
        end
      end)

    if is_integer(new_max), do: set_max(bit_vector, new_max), else: fail(bit_vector)
  end

  ## Mark the bit vector (if given) as failed and throw :fail
  defp fail(bit_vector \\ nil) do
    if bit_vector, do: set_min_max(bit_vector, @failure_value)
    throw(:fail)
  end

  def get_bound_addrs(bit_vector) do
    {_, _, current_min, current_max} = get_min_max(bit_vector)
    {current_min_block, current_min_offset} = vector_address(current_min)
    {current_max_block, current_max_offset} = vector_address(current_max)

    %{
      min_addr: %{block: current_min_block, offset: current_min_offset},
      max_addr: %{block: current_max_block, offset: current_max_offset}
    }
  end

  ## Find the index of atomics where the n-value resides
  defp block_index(n) do
    div(n, 64) + 1
  end

  defp vector_address(n) do
    {block_index(n), rem(n, 64)}
  end

  ## Clear the bits of a block that fall outside [min, max].
  ## Within min_block..max_block only the two boundary blocks can hold such bits.
  defp clip_block(
         block,
         block_idx,
         %{block: min_block, offset: min_offset},
         %{block: max_block, offset: max_offset}
       ) do
    block =
      if block_idx == min_block,
        do: (block >>> min_offset) <<< min_offset,
        else: block

    if block_idx == max_block,
      do: block &&& (1 <<< (max_offset + 1)) - 1,
      else: block
  end

  ## Positions (0..63, ascending) of the set bits of a block
  defp set_bit_indices(block) do
    for bit <- 0..63, ((block >>> bit) &&& 1) == 1, do: bit
  end

  ## Find least significant bit; nil for 0
  defp lsb(0), do: nil
  defp lsb(n), do: lsb(n, 0)

  defp lsb(n, position) when (n &&& 1) == 1, do: position
  defp lsb(n, position), do: lsb(n >>> 1, position + 1)

  ## Find most significant bit; nil for 0.
  ## Integer-only, so no floating-point rounding on 64-bit values.
  defp msb(0), do: nil
  defp msb(n), do: msb(n >>> 1, 0)

  defp msb(0, position), do: position
  defp msb(n, position), do: msb(n >>> 1, position + 1)

  def bit_count_iter(n) do
    for <<bit::1 <- :binary.encode_unsigned(n)>>, reduce: 0 do
      acc -> acc + bit
    end
  end

  ## Population count of a 64-bit non-negative integer (SWAR)
  def bit_count(0) do
    0
  end

  def bit_count(n) do
    n = (n &&& 0x5555555555555555) + (n >>> 1 &&& 0x5555555555555555)
    n = (n &&& 0x3333333333333333) + (n >>> 2 &&& 0x3333333333333333)
    n = (n &&& 0x0F0F0F0F0F0F0F0F) + (n >>> 4 &&& 0x0F0F0F0F0F0F0F0F)
    n = (n &&& 0x00FF00FF00FF00FF) + (n >>> 8 &&& 0x00FF00FF00FF00FF)
    n = (n &&& 0x0000FFFF0000FFFF) + (n >>> 16 &&& 0x0000FFFF0000FFFF)
    (n &&& 0x00000000FFFFFFFF) + (n >>> 32 &&& 0x00000000FFFFFFFF)
  end

  ## MapSet of mapped positions of the set bits among the low 64 bits of n.
  ## mapper receives the bit position and returns {lb, ub, value};
  ## value is kept only when lb <= value <= ub.
  def bit_positions(n, mapper) do
    bit_positions(n, 1, 0, mapper, MapSet.new())
  end

  def bit_positions(_n, @max64_value, _iteration, _mapper, positions) do
    positions
  end

  def bit_positions(n, shift, iteration, mapper, positions) do
    positions =
      if (n &&& shift) > 0 do
        {lb, ub, new_value} = mapper.(iteration)

        if new_value >= lb && new_value <= ub,
          do: MapSet.put(positions, new_value),
          else: positions
      else
        positions
      end

    bit_positions(n, shift <<< 1, iteration + 1, mapper, positions)
  end
end