lib/rollex/distribution.ex

defmodule Rollex.Distribution do
  @moduledoc """
  This module provides functions to calculate various outcome distributions
  of sets of dice. This includes minimums, maximums, and histograms.

  Histograms are maps from a rolled total to the percentage chance (0 to 100,
  rounded to three decimal places) of rolling that total.
  """

  # Arbitrary "nice" upper bound on the effort counter used while enumerating rolls for a
  # histogram; once it is reached the computation bails out with an empty histogram.
  @max_permutations 12_000_000

  @doc "Calculates the minimum value possible from a roll definition."
  @spec min(Rollex.t()) :: {:ok, number} | {:error, reason :: String.t()}
  def min(%{valid: true, compiled: tokens}), do: limit(tokens, :min)
  def min(%{error: error}), do: invalid_roll(error)

  @doc "Calculates the maximum value possible from a roll definition."
  @spec max(Rollex.t()) :: {:ok, number} | {:error, reason :: String.t()}
  def max(%{valid: true, compiled: tokens}), do: limit(tokens, :max)
  def max(%{error: error}), do: invalid_roll(error)

  @doc "Calculates the minimum and maximum value possible from a roll definition."
  @spec range(Rollex.t()) ::
          {:ok, min :: number, max :: number} | {:error, reason :: String.t()}
  def range(%{valid: true, compiled: tokens}) do
    # A failing step's `{:error, reason}` falls through `with` unchanged.
    with {:ok, low} <- limit(tokens, :min),
         {:ok, high} <- limit(tokens, :max) do
      {:ok, low, high}
    end
  end

  def range(%{error: error}), do: invalid_roll(error)

  # How the histogram is computed:
  # for each total that is possible to roll:
  #   1. generate the "most 1s" roll, in descending value order;
  #      e.g. for 9 on 3d6 (quantity 3, sides 6) that is [6, 2, 1]
  #   2. iterate all other possible rolls for that total by "shuffling" values rightwards,
  #      so [6, 2, 1] becomes [5, 3, 1]. This is done recursively for each suffix,
  #      so [6, 3, 2, 1] also has this process done for [3, 2, 1] with 6 prepended.
  #      Care is taken that every newly generated roll is *always* in descending order
  #      (which avoids ever having to Enum.sort them)
  #   3. add these rolls to a MapSet to ensure uniqueness
  #   4. for each roll, count how many times each value appears (6 appears twice in
  #      [6, 6, 1]) and calculate the number of orderings: quantity! / (c1! * c2! * ...)
  #   5. sum the orderings of all rolls, divide by the total number of outcomes
  #      (sides ^ quantity) and voila!
  #
  # Note that there is *no* memoization happening. It may be worth looking into that to see
  # if this could be improved.
  #
  # It also bails at a hard-coded effort limit (@max_permutations) to prevent DoS, as the
  # algorithm does not run in constant, or even linear, time.
  @doc """
  Calculates the distribution of totals for `quantity` dice with `sides` sides each.

  Returns `{histogram, effort_count}`: `histogram` maps every possible total to its
  percentage chance (rounded to three decimal places) and `effort_count` is the effort
  counter, continued from the `effort_count` passed in.

  If the counter reaches the internal effort limit the histogram is empty. A `quantity`
  or `sides` below 1 also yields an empty histogram, with the effort count unchanged.
  """
  @spec histogram(quantity :: pos_integer, sides :: pos_integer, effort_count :: non_neg_integer) ::
          {map, effort_count :: non_neg_integer}
  def histogram(quantity, sides, effort_count \\ 0)

  def histogram(_quantity, _sides, effort_count) when effort_count >= @max_permutations do
    {%{}, effort_count}
  end

  # Degenerate input: the enumeration below would never terminate for quantity < 1
  # and would produce meaningless percentages for sides < 1.
  def histogram(quantity, sides, effort_count) when quantity < 1 or sides < 1 do
    {%{}, effort_count}
  end

  def histogram(quantity, sides, effort_count) do
    low = quantity
    high = quantity * sides
    width = high - low

    # The distribution of a sum of identical dice is symmetric, so only the lower half
    # (low..midpoint) is enumerated; the upper half is mirrored from it afterwards.
    midpoint = low + Integer.floor_div(width, 2)

    {lower_half, final_effort_count} =
      Enum.reduce_while(low..midpoint, {%{}, effort_count}, fn target, {acc, count} ->
        {result, count} = histogram_entry(target, quantity, sides, count)

        if count >= @max_permutations do
          {:halt, {%{}, count}}
        else
          {:cont, {Map.put(acc, target, result), count}}
        end
      end)

    {mirror_upper_half(lower_half, midpoint, high, width), final_effort_count}
  end

  @doc """
  Takes two histograms and combines them into one.

  Every pair of totals is added together and their percentage chances multiplied; the
  combined chances are expressed as percentages again and rounded to three decimal places.
  """
  @spec zip_histograms(left :: map, right :: map) :: map
  def zip_histograms(left, right) do
    left
    |> Enum.reduce(%{}, fn {l_total, l_odds}, acc ->
      Enum.reduce(right, acc, fn {r_total, r_odds}, acc ->
        odds = l_odds * r_odds
        Map.update(acc, l_total + r_total, odds, &(&1 + odds))
      end)
    end)
    # percentage * percentage carries an extra factor of 100
    |> Map.new(fn {total, odds} -> {total, Float.round(odds / 100, 3)} end)
  end

  @doc """
  Takes a histogram, a translation magnitude, and a 3-arity function to apply to entries in
  the histogram, returning a translated histogram. Useful for performing arithmetic on
  histograms, allowing for the representation of things such as "1d8+2".

  Each key is replaced by `merge_op.(:arithmetic, key, magnitude)`. If several keys
  translate to the same result, their chances are summed.
  """
  @spec translate_histogram(
          h :: map,
          magnitude :: number,
          Rollex.Utilities.merge_operation()
        ) :: map
  def translate_histogram(h, magnitude, merge_op) when is_function(merge_op, 3) do
    Enum.reduce(h, %{}, fn {key, value}, acc ->
      # Distinct totals may collide after translation (e.g. if merge_op multiplies by
      # zero), so accumulate rather than overwrite to keep the chances summing correctly.
      Map.update(acc, merge_op.(:arithmetic, key, magnitude), value, &(&1 + value))
    end)
  end

  # Builds the error tuple returned for an invalid roll definition.
  defp invalid_roll(error), do: {:error, error || "Invalid roll definition"}

  # Evaluates the compiled tokens with every die fixed at its minimum or maximum face.
  @spec limit(Rollex.tokens(), type :: :min | :max) ::
          {:ok, limit :: integer} | {:error, reason :: String.t()}
  defp limit(tokenized_input, type) do
    tokenized_input
    |> set_dice_results(type)
    |> Rollex.PrattParser.evaluate_expression()
    |> build_evaluator_output()
  end

  # Fills in totals midpoint+1..high by reflecting the already computed lower half around
  # the centre of the distribution. A bailed (empty) computation stays empty.
  defp mirror_upper_half(histogram, _midpoint, _high, _width) when map_size(histogram) == 0,
    do: %{}

  # A single possible total (width 0, e.g. Nd1) has nothing to mirror; without this clause
  # the range below would count downwards and insert bogus entries.
  defp mirror_upper_half(histogram, midpoint, high, _width) when midpoint >= high,
    do: histogram

  defp mirror_upper_half(histogram, midpoint, high, width) do
    # With an odd width the centre falls between two totals, shifting the reflection by one.
    offset = if Integer.mod(width, 2) == 0, do: 0, else: 1

    Enum.reduce((midpoint + 1)..high, histogram, fn total, acc ->
      Map.put(acc, total, Map.fetch!(acc, 2 * midpoint + offset - total))
    end)
  end

  # Replaces every dice token with a number token holding that die's minimum or maximum,
  # leaving all other tokens untouched and in order.
  defp set_dice_results(tokens, type) do
    Enum.map(tokens, fn
      %{is_dice: true} = token -> %Rollex.Tokens.Number{value: dice_limit(token, type)}
      token -> token
    end)
  end

  defp dice_limit(token, :min), do: Rollex.Dice.min(token)
  defp dice_limit(token, :max), do: Rollex.Dice.max(token)

  # Converts the parser's result into the public `{:ok, value} | {:error, reason}` shape.
  defp build_evaluator_output({:error, errors}), do: {:error, errors}

  defp build_evaluator_output({:ok, %{arithmetic: final_total}, [%Rollex.Tokens.End{}]}) do
    {:ok, final_total}
  end

  defp build_evaluator_output(
         {:ok, _final_total, [%Rollex.Tokens.RightParenthesis{} | _rest_of_input]}
       ) do
    {:error, "Missing opening parenthesis!"}
  end

  defp build_evaluator_output(
         {:ok, _final_total, [%Rollex.Tokens.LeftParenthesis{} | _rest_of_input]}
       ) do
    {:error, "Missing closing parenthesis!"}
  end

  # Builds the descending "most 1s" roll of `quantity` dice summing to `target` (step 1),
  # prepending onto `dice`: each remaining die is a 1 unless the others could no longer
  # reach the rest of the target, in which case it takes the smallest value that still can.
  defp generate_starting_roll(dice, target, sides, quantity) do
    cond do
      quantity == 1 ->
        [target | dice]

      quantity > 1 and (quantity - 1) * sides >= target ->
        generate_starting_roll([1 | dice], target - 1, sides, quantity - 1)

      true ->
        min_roll = target - (quantity - 1) * sides
        generate_starting_roll([min_roll | dice], target - min_roll, sides, quantity - 1)
    end
  end

  # n! computed tail-recursively; 0! (and anything below) is 1 instead of looping forever.
  defp factorial(n, acc \\ 1)
  defp factorial(n, acc) when n <= 1, do: acc
  defp factorial(n, acc), do: factorial(n - 1, n * acc)

  # For a descending roll, the product of the factorials of the lengths of each run of
  # equal values: the c1! * c2! * ... denominator from step 4.
  defp roll_permutations([], _curr, curr_count, acc), do: factorial(curr_count) * acc

  defp roll_permutations([curr | rest], curr, curr_count, acc),
    do: roll_permutations(rest, curr, curr_count + 1, acc)

  defp roll_permutations([new | rest], _curr, curr_count, acc),
    do: roll_permutations(rest, new, 1, factorial(curr_count) * acc)

  defp roll_permutations([new | rest]), do: roll_permutations(rest, new, 1, 1)

  # Step 2: derives another descending roll with the same total by shifting a point from a
  # die towards the die on its right. `acc` holds the dice already scanned, reversed.
  # Returns nil when the scan reaches the last die; on reaching a 1 it returns the
  # remaining dice followed by `acc`, which the caller discards when it equals its input.
  defp generate_next_roll([_last], _acc), do: nil

  defp generate_next_roll([1 | _] = rest, acc) do
    rest ++ acc
  end

  defp generate_next_roll([head, neck | rest], acc) when head > neck + 1 do
    Enum.reverse(rest ++ [neck + 1, head - 1 | acc])
  end

  defp generate_next_roll([head, neck | rest], acc) when head > neck do
    generate_next_roll([neck + 1 | rest], [head - 1 | acc])
  end

  defp generate_next_roll([head | rest], acc) do
    generate_next_roll(rest, [head | acc])
  end

  # Collects into the MapSet `acc` every roll reachable from `numbers` (with `prefix`
  # prepended), recursing into each suffix and each newly shifted roll. `count` is the
  # effort counter; recursion stops as soon as it reaches @max_permutations.
  defp all_unique_rolls(_, _, acc, count) when count >= @max_permutations do
    {acc, count}
  end

  defp all_unique_rolls([_], _, acc, count) do
    {acc, count + 1}
  end

  defp all_unique_rolls([head | rest] = numbers, prefix, acc, count) do
    {acc, count} = all_unique_rolls(rest, prefix ++ [head], acc, count)
    next = generate_next_roll(numbers, [])

    if next != nil and next != numbers do
      all_unique_rolls(next, prefix, MapSet.put(acc, prefix ++ next), count)
    else
      {acc, count + 1}
    end
  end

  # Returns the percentage chance (rounded to three decimals) of rolling exactly `target`
  # with `quantity` dice of `sides` sides, plus the updated effort count. If the arithmetic
  # raises ArithmeticError (e.g. numbers too large for a float), it reports 0 and saturates
  # the effort count so that `histogram/3` bails out.
  defp histogram_entry(target, quantity, sides, count) do
    quantity_factorial = factorial(quantity)

    start = generate_starting_roll([], target, sides, quantity)
    {rolls, count} = all_unique_rolls(start, [], MapSet.new([start]), count)

    try do
      total =
        Enum.reduce(rolls, 0, fn roll, acc ->
          acc + quantity_factorial / roll_permutations(roll)
        end)

      {Float.round(total / :math.pow(sides, quantity) * 100, 3), count}
    rescue
      ArithmeticError ->
        {0, @max_permutations}
    end
  end
end