defmodule CPSolver.BitVectorDomain do
import Bitwise
import CPSolver.BitUtils
@failure_value (1 <<< 64) - 1
def new([]) do
fail()
end
def new(domain) when is_integer(domain) do
new([domain])
end
def new(domain) when is_list(domain) or is_struct(domain, Range) or is_struct(domain, MapSet) do
offset = -Enum.min(domain)
domain_size = Enum.max(domain) + offset + 1
bv = :bit_vector.new(domain_size)
Enum.each(domain, fn idx -> :bit_vector.set(bv, idx + offset) end)
PackedMinMax.set_min(0, 0)
|> PackedMinMax.set_max(Enum.max(domain) + offset)
|> then(fn min_max -> set_min_max(bv, min_max) end)
{bv, offset}
end
def copy({{:bit_vector, ref} = bit_vector, offset} = _domain) do
%{
min_addr: %{block: current_min_block},
max_addr: %{block: current_max_block}
} = get_bound_addrs(bit_vector)
new_atomics_size = current_max_block + 1
new_atomics_ref = :atomics.new(new_atomics_size, [{:signed, false}])
Enum.each(
current_min_block..current_max_block,
fn block_idx ->
block_val = :atomics.get(ref, block_idx)
:atomics.put(new_atomics_ref, block_idx, block_val)
end
)
new_bit_vector = {:bit_vector, new_atomics_ref}
set_min_max(new_bit_vector, get_min_max_impl(bit_vector) |> elem(1))
{new_bit_vector, offset}
end
## 'next' value in the domain
def next({{:bit_vector, ref} = bit_vector, offset} = _domain, value) do
{_current_min_max, _min_max_idx, min_value, max_value} = get_min_max(bit_vector)
cond do
value + offset < min_value -> min_value - offset
value + offset >= max_value -> nil
true ->
{block_index, position} = vector_address(value + offset)
case next(block_index, position, ref, last_index(bit_vector)) do
nil -> nil
positional_value -> positional_value - offset
end
end
end
defp next(block_index, _position, _atomics_ref, last_index) when block_index > last_index do
nil
end
defp next(block_index, position, atomics_ref, last_index) do
block = :atomics.get(atomics_ref, block_index)
## Check if the block has bits on the right of the position of initial value.
shift = position + 1
shifted_left = block >>> shift
case lsb(shifted_left) do
nil -> ## nothing in this block, try next one
next(block_index + 1, -1, atomics_ref, last_index)
next_lsb ->
positional_value(block_index, next_lsb + shift)
end
end
def map(domain, mapper_fun) when is_function(mapper_fun) do
to_list(domain, mapper_fun)
end
## Reduce over domain values
def reduce(
{{:bit_vector, ref} = bit_vector, offset} = domain,
value_mapper_fun,
acc_init \\ MapSet.new(),
reduce_fun \\ &MapSet.union/2
) do
%{
min_addr: %{block: current_min_block, offset: _min_offset},
max_addr: %{block: current_max_block, offset: _max_offset}
} = get_bound_addrs(bit_vector)
lb = min(domain)
ub = max(domain)
Enum.reduce(current_min_block..current_max_block, acc_init, fn block_idx, acc ->
case :atomics.get(ref, block_idx) do
0 ->
acc
block ->
reduce_fun.(
acc,
bit_positions(block, fn val ->
case val + 64 * (block_idx - 1) - offset do
value when value >= lb and value <= ub ->
value_mapper_fun.(value)
_out_of_bounds ->
nil
end
end)
)
end
end)
end
def to_list(
domain,
value_mapper_fun \\ &Function.identity/1
) do
(fixed?(domain) && MapSet.new([value_mapper_fun.(min(domain))])) ||
reduce(domain, value_mapper_fun, MapSet.new(), &MapSet.union/2)
end
def fixed?({bit_vector, _offset} = _domain) do
{current_min_max, _min_max_idx, current_min, current_max} = get_min_max(bit_vector)
current_max == current_min && current_min_max != @failure_value
end
def failed?({:bit_vector, _ref} = bit_vector) do
failed?(elem(get_min_max_impl(bit_vector), 1))
end
def failed?({bit_vector, _offset} = _domain) do
failed?(bit_vector)
end
def failed?(min_max_value) when is_integer(min_max_value) do
min_max_value == @failure_value
end
def min({bit_vector, offset} = _domain) do
get_min(bit_vector) - offset
end
def max({bit_vector, offset} = _domain) do
get_max(bit_vector) - offset
end
def size({{:bit_vector, ref} = bit_vector, _offset}) do
%{
min_addr: %{block: current_min_block, offset: min_offset},
max_addr: %{block: current_max_block, offset: max_offset}
} = get_bound_addrs(bit_vector)
Enum.reduce(current_min_block..current_max_block, 0, fn idx, acc ->
n = :atomics.get(ref, idx)
if n == 0 do
acc
else
n1 = (idx == current_min_block && n >>> min_offset) || n
n2 = (idx == current_max_block && ((1 <<< (max_offset + 1)) - 1 &&& n1)) || n1
acc + bit_count(n2)
end
end)
end
def contains?({{:bit_vector, _ref} = bit_vector, offset}, value) do
{_current_min_max, _min_max_idx, min_value, max_value} = get_min_max(bit_vector)
vector_value = value + offset
contains?(bit_vector, vector_value, min_value, max_value)
end
def contains?(bit_vector, vector_value, min_value, max_value) do
vector_value >= min_value && vector_value <= max_value &&
:bit_vector.get(bit_vector, vector_value) == 1
end
def fix({bit_vector, offset} = _domain, value) do
min_max_info =
{_current_min_max, _min_max_idx, min_value, max_value} = get_min_max(bit_vector)
vector_value = value + offset
if contains?(bit_vector, vector_value, min_value, max_value) do
set_fixed(bit_vector, value + offset, min_max_info)
else
fail(bit_vector)
end
end
def remove({bit_vector, offset} = domain, value) do
{_current_min_max, _min_max_idx, min_value, max_value} = get_min_max(bit_vector)
vector_value = value + offset
cond do
## No value in the domain, do nothing
contains?(bit_vector, vector_value, min_value, max_value) ->
domain_change =
cond do
min_value == max_value && vector_value == min_value ->
## Fixed value: fail on removing attempt
fail(bit_vector)
min_value == vector_value ->
tighten_min(bit_vector, min_value, max_value)
max_value == vector_value ->
tighten_max(bit_vector, max_value, min_value)
true ->
:domain_change
end
{domain_change, domain}
|> tap(fn _ -> :bit_vector.clear(bit_vector, vector_value) end)
true ->
:no_change
end
end
def removeAbove({bit_vector, offset} = domain, value) do
{_current_min_max, _min_max_idx, min_value, max_value} = get_min_max(bit_vector)
vector_value = value + offset
cond do
vector_value >= max_value ->
:no_change
vector_value < min_value ->
fail(bit_vector)
true ->
## The value is strictly less than max
domain_change = tighten_max(bit_vector, vector_value + 1, min_value)
{domain_change, domain}
end
end
def removeBelow({bit_vector, offset} = domain, value) do
{_current_min_max, _min_max_idx, min_value, max_value} = get_min_max(bit_vector)
vector_value = value + offset
cond do
vector_value <= min_value ->
:no_change
vector_value > max_value ->
fail(bit_vector)
true ->
## The value is strictly greater than min
domain_change = tighten_min(bit_vector, vector_value - 1, max_value)
{domain_change, domain}
end
end
def raw({{:bit_vector, ref} = _bit_vector, offset} = _domain) do
%{
offset: offset,
content: Enum.map(1..:atomics.info(ref).size, fn i -> :atomics.get(ref, i) end)
}
end
def bits({{:bit_vector, ref} = _bit_vector, _offset} = _domain) do
Enum.map(1..:atomics.info(ref)[:size] - 1, fn idx -> :atomics.get(ref, idx) |> Integer.to_string(2) |> String.reverse() end)
end
## Last byte of bit_vector contains (packed) min and max
def last_index({:bit_vector, ref} = _bit_vector) do
:atomics.info(ref).size - 1
end
defp set_min_max({:bit_vector, ref} = bit_vector, min_max) do
bit_vector
|> min_max_index()
|> tap(fn idx ->
:atomics.put(ref, idx, min_max)
end)
end
def get_min(bit_vector) do
get_min_max(bit_vector) |> elem(2)
end
def get_max(bit_vector) do
get_min_max(bit_vector) |> elem(3)
end
defp min_max_index(bit_vector) do
last_index(bit_vector) + 1
end
def get_min_max(bit_vector) do
get_min_max_impl(bit_vector)
|> then(fn {min_max_index, min_max} ->
min_max == @failure_value && fail(bit_vector)
{min_max, min_max_index, PackedMinMax.get_min(min_max), PackedMinMax.get_max(min_max)}
end)
end
defp get_min_max_impl({:bit_vector, ref} = bit_vector) do
min_max_index = min_max_index(bit_vector)
{min_max_index, :atomics.get(ref, min_max_index)}
end
def set_min(bit_vector, new_min) do
set_min(bit_vector, new_min, get_min_max(bit_vector))
end
def set_min({:bit_vector, ref} = bit_vector, new_min, min_max_info) do
{current_min_max, min_max_idx, current_min, current_max} = min_max_info
cond do
new_min > current_max ->
## Inconsistency
fail(bit_vector)
new_min != current_min && current_min == current_max ->
## Attempt to re-fix
fail(bit_vector)
true ->
## Min change
min_max_value = PackedMinMax.set_min(current_min_max, new_min)
case :atomics.compare_exchange(ref, min_max_idx, current_min_max, min_max_value) do
:ok ->
cond do
new_min == current_max -> :fixed
new_min <= current_min -> :no_change
true -> :min_change
end
changed_by_other_thread ->
min2 = PackedMinMax.get_min(changed_by_other_thread)
max2 = PackedMinMax.get_max(changed_by_other_thread)
set_min(bit_vector, new_min, {changed_by_other_thread, min_max_idx, min2, max2})
end
end
end
def set_max(bit_vector, new_max) do
set_max(bit_vector, new_max, get_min_max(bit_vector))
end
def set_max({:bit_vector, ref} = bit_vector, new_max, min_max_info) do
{current_min_max, min_max_idx, current_min, current_max} = min_max_info
cond do
new_max < current_min ->
## Inconsistency
fail(bit_vector)
new_max != current_max && current_min == current_max ->
## Attempt to re-fix
fail(bit_vector)
true ->
## Max change
min_max_value = PackedMinMax.set_max(current_min_max, new_max)
case :atomics.compare_exchange(ref, min_max_idx, current_min_max, min_max_value) do
:ok ->
cond do
new_max == current_min -> :fixed
new_max >= current_max -> :no_change
true -> :max_change
end
changed_by_other_thread ->
min2 = PackedMinMax.get_min(changed_by_other_thread)
max2 = PackedMinMax.get_max(changed_by_other_thread)
set_max(bit_vector, new_max, {changed_by_other_thread, min_max_idx, min2, max2})
end
end
end
def set_fixed({:bit_vector, ref} = bit_vector, fixed_value, min_max_info) do
{current_min_max, min_max_idx, current_min, current_max} = min_max_info
if fixed_value != current_max && current_min == current_max do
## Attempt to re-fix
fail(bit_vector)
else
min_max_value = PackedMinMax.set_min(0, fixed_value) |> PackedMinMax.set_max(fixed_value)
case :atomics.compare_exchange(ref, min_max_idx, current_min_max, min_max_value) do
:ok ->
:fixed
changed_by_other_thread ->
min2 = PackedMinMax.get_min(changed_by_other_thread)
max2 = PackedMinMax.get_max(changed_by_other_thread)
set_fixed(bit_vector, fixed_value, {changed_by_other_thread, min_max_idx, min2, max2})
end
end
end
## Update (cached) min, if necessary
defp tighten_min(
{:bit_vector, atomics_ref} = bit_vector,
starting_at,
max_value
) do
{current_max_block, _} = vector_address(max_value)
{rightmost_block, position_in_block} = vector_address(starting_at + 1)
## Find a new min (on the right of the current one)
min_value =
Enum.reduce_while(rightmost_block..current_max_block, false, fn idx, min_block_empty? ->
case :atomics.get(atomics_ref, idx) do
0 ->
{:cont, min_block_empty?}
non_zero_block ->
block_lsb =
if min_block_empty? do
lsb(non_zero_block)
else
## Reset all bits in the block to the left of the position
shift = position_in_block
lsb(non_zero_block >>> shift <<< shift)
end
(block_lsb &&
{:halt, positional_value(idx, block_lsb)}) || {:cont, true}
end
end)
(is_integer(min_value) && set_min(bit_vector, min_value)) || fail(bit_vector)
end
## Update (cached) max
defp tighten_max(
{:bit_vector, atomics_ref} = bit_vector,
starting_at,
min_value
) do
{current_min_block_idx, _} = vector_address(min_value)
{leftmost_block_idx, position_in_block} = vector_address(starting_at - 1)
## Find a new max (on the left of the current one)
##
max_value =
Enum.reduce_while(
leftmost_block_idx..current_min_block_idx,
false,
fn idx, max_block_empty? ->
case :atomics.get(atomics_ref, idx) do
0 ->
{:cont, max_block_empty?}
non_zero_block ->
block_msb =
if max_block_empty? do
msb(non_zero_block)
else
## Reset all bits in the block to the right of the position
mask = (1 <<< (position_in_block + 1)) - 1
msb(non_zero_block &&& mask)
end
(block_msb &&
{:halt, positional_value(idx, block_msb)}) || {:cont, true}
end
end
)
(is_integer(max_value) && set_max(bit_vector, max_value)) || fail(bit_vector)
end
defp fail(bit_vector \\ nil) do
bit_vector && set_min_max(bit_vector, @failure_value)
throw(:fail)
end
def get_bound_addrs(bit_vector) do
{_, _, current_min, current_max} = get_min_max(bit_vector)
{current_min_block, current_min_offset} = vector_address(current_min)
{current_max_block, current_max_offset} = vector_address(current_max)
%{
min_addr: %{block: current_min_block, offset: current_min_offset},
max_addr: %{block: current_max_block, offset: current_max_offset}
}
end
## Find the index of atomics where the n-value resides
defp block_index(n) do
div(n, 64) + 1
end
def vector_address(n) do
{block_index(n), rem(n, 64)}
end
## Domain value, computed off the block index and position within the block
def positional_value(block_index, position) do
(block_index - 1) * 64 + position
end
def bit_positions(0, _mapper) do
MapSet.new()
end
def bit_positions(n, mapper) do
lsb = lsb(n)
msb = msb(n)
initial_set =
Enum.reduce([lsb, msb], MapSet.new(), fn value, acc ->
case mapper.(value) do
nil ->
acc
new_value ->
MapSet.put(acc, new_value)
end
end)
bit_positions(n >>> lsb, 1, lsb, msb, mapper, initial_set)
end
def bit_positions(_n, _shift, iteration, msb, _mapper, positions) when iteration == msb do
positions
end
def bit_positions(n, shift, iteration, msb, mapper, positions) do
acc =
((n &&& shift) > 0 &&
case mapper.(iteration) do
nil -> positions
new_value -> MapSet.put(positions, new_value)
end) ||
positions
bit_positions(n, shift <<< 1, iteration + 1, msb, mapper, acc)
end
end