# lib/statwise/nonparametric/mann_whitney.ex

defmodule Statwise.Nonparametric.MannWhitney do
  @moduledoc """
  Mann-Whitney U test with asymptotic (normal approximation) and exact p-values.
  """

  alias Statwise.Distributions.Normal
  alias Statwise.Sample
  alias Statwise.TestResult

  # Stored in the `reference` field of every `TestResult` this module builds.
  # NOTE(review): presumably the implementation whose results are reproduced — confirm.
  @reference %{package: :scipy, version: "1.16.0"}

  @doc """
  Runs the Mann-Whitney U test on samples `x` and `y`.

  Both samples are converted with `Sample.to_floats/2` and must be non-empty
  (`Sample.ensure_non_empty!/2`).

  ## Options

    * `:alternative` - passed through `Sample.validate_alternative!/1`;
      defaults to `:two_sided`.
    * `:method` - `:asymptotic` (default), `:exact` or `:auto`. `:auto` picks
      `:asymptotic` when the data contain ties or the smaller sample has more
      than 8 observations, and `:exact` otherwise. Any other value raises
      `ArgumentError`.
    * `:continuity` - whether the asymptotic p-value applies a 0.5 continuity
      correction; defaults to `true`.
    * `:nan_policy` - passed through `Sample.validate_nan_policy!/1`; defaults
      to `:raise`.

  Returns a `%TestResult{}` whose `statistic` is U1, the U statistic computed
  from the rank sum of `x`. When `:nan_policy` is `:propagate` and
  `Sample.has_nan?/1` reports a NaN in the pooled data, the statistic, p-value,
  effect sizes and U values are all `:nan` (ties are not inspected in that case).
  """
  def test(x, y, opts \\ []) do
    alternative = Sample.validate_alternative!(Keyword.get(opts, :alternative, :two_sided))
    requested_method = Keyword.get(opts, :method, :asymptotic)
    continuity = Keyword.get(opts, :continuity, true)
    nan_policy = Sample.validate_nan_policy!(Keyword.get(opts, :nan_policy, :raise))

    x = Sample.ensure_non_empty!(Sample.to_floats(x, nan_policy: nan_policy), "x")
    y = Sample.ensure_non_empty!(Sample.to_floats(y, nan_policy: nan_policy), "y")

    n_x = length(x)
    n_y = length(y)
    pooled = x ++ y

    # `and` short-circuits, so the pooled data is only scanned under :propagate.
    propagate_nan? = nan_policy == :propagate and Sample.has_nan?(pooled)

    if propagate_nan? do
      # Resolve (and validate) the method first so an unsupported method still raises.
      method = resolve_nan_method!(requested_method, n_x, n_y)
      undefined = nan()

      %TestResult{
        test: :mann_whitney_u,
        statistic: undefined,
        p_value: undefined,
        alternative: alternative,
        method: method,
        n: %{x: n_x, y: n_y},
        reference: @reference,
        effect_size: %{
          common_language: undefined,
          rank_biserial: undefined,
          cliffs_delta: undefined
        },
        metadata: %{
          u1: undefined,
          u2: undefined,
          continuity: continuity,
          requested_method: requested_method
        }
      }
    else
      ranks = rank_info(pooled, n_x)
      method = resolve_method!(requested_method, ranks.has_ties, n_x, n_y)

      test_from_values(ranks, n_x, n_y, alternative, method, requested_method, continuity)
    end
  end

  # Builds the final `TestResult` from precomputed rank information.
  #
  # `rank_info` is the map produced by `rank_info/2`; `method` is the already
  # resolved method (`:asymptotic` or `:exact`), while `requested_method` is only
  # echoed back in the result metadata.
  defp test_from_values(rank_info, n_x, n_y, alternative, method, requested_method, continuity) do
    # Smallest rank sum `x` can have: ranks 1..n_x.
    min_rank_sum = n_x * (n_x + 1) / 2.0
    u1 = rank_info.rank_sum_x - min_rank_sum
    u2 = n_x * n_y - u1

    p_value =
      case method do
        :exact ->
          exact_p_value(u1, n_x, n_y, alternative)

        :asymptotic ->
          tie_term = rank_info.tie_term
          asymptotic_p_value(u1, u2, n_x, n_y, tie_term, alternative, continuity)
      end

    details = %{u1: u1, u2: u2, continuity: continuity, requested_method: requested_method}

    %TestResult{
      test: :mann_whitney_u,
      statistic: u1,
      p_value: p_value,
      alternative: alternative,
      method: method,
      n: %{x: n_x, y: n_y},
      effect_size: effect_size(u1, n_x, n_y),
      reference: @reference,
      metadata: details
    }
  end

  # Maps the requested method to the method actually used.
  #
  # `:asymptotic` and `:exact` are honoured as-is. `:auto` chooses `:asymptotic`
  # when the data contain ties or the smaller sample has more than 8
  # observations, and `:exact` otherwise. Anything else raises `ArgumentError`.
  defp resolve_method!(method, _has_ties, _n_x, _n_y) when method in [:asymptotic, :exact],
    do: method

  defp resolve_method!(:auto, has_ties, n_x, n_y) do
    use_normal_approximation? = has_ties or min(n_x, n_y) > 8

    if use_normal_approximation? do
      :asymptotic
    else
      :exact
    end
  end

  defp resolve_method!(unsupported, _has_ties, _n_x, _n_y) do
    raise ArgumentError, "unsupported Mann-Whitney U method #{inspect(unsupported)}"
  end

  # Method resolution for the NaN-propagating path, where ties are not inspected.
  #
  # `:asymptotic` and `:exact` are honoured as-is; `:auto` chooses `:asymptotic`
  # only when both samples have more than 8 observations (i.e. the smaller one
  # does). Anything else raises `ArgumentError`.
  defp resolve_nan_method!(method, _n_x, _n_y) when method in [:asymptotic, :exact],
    do: method

  defp resolve_nan_method!(:auto, n_x, n_y) when n_x > 8 and n_y > 8, do: :asymptotic
  defp resolve_nan_method!(:auto, _n_x, _n_y), do: :exact

  defp resolve_nan_method!(unsupported, _n_x, _n_y) do
    raise ArgumentError, "unsupported Mann-Whitney U method #{inspect(unsupported)}"
  end

  # Exact p-value for the observed `u1`.
  #
  # A single tied pair (one observation per sample, U1 of 0.5) is reported as
  # 1.0 for every alternative; all other inputs go through the exact
  # distribution.
  # NOTE(review): presumably mirrors the reference implementation for this
  # degenerate case — confirm; other tied inputs are not special-cased.
  defp exact_p_value(u1, n_x, n_y, _alternative) when n_x == 1 and n_y == 1 and u1 == 0.5,
    do: 1.0

  defp exact_p_value(u1, n_x, n_y, alternative),
    do: exact_distribution_p_value(u1, n_x, n_y, alternative)

  # Tail probability of `u1` under the exact distribution of U.
  #
  # The distribution maps each attainable U value to the number of rank
  # assignments producing it. `:less` uses the share of assignments with
  # U <= u1, `:greater` the share with U >= u1, and `:two_sided` doubles the
  # smaller tail and caps the result at 1.0.
  defp exact_distribution_p_value(u1, n_x, n_y, alternative) do
    # Counts are integers, so a single pass yields the same totals as separate sums.
    {at_or_below, at_or_above, total} =
      for {candidate, ways} <- exact_distribution(n_x, n_y), reduce: {0, 0, 0} do
        {below, above, all} ->
          {
            if(candidate <= u1, do: below + ways, else: below),
            if(candidate >= u1, do: above + ways, else: above),
            all + ways
          }
      end

    case alternative do
      :less -> at_or_below / total
      :greater -> at_or_above / total
      :two_sided -> min(1.0, 2.0 * min(at_or_below, at_or_above) / total)
    end
  end

  # Returns the exact distribution for the given sample sizes, building it on
  # first use and storing it in `:persistent_term` under a key made of this
  # module, `:exact_distribution` and the two sizes.
  defp exact_distribution(n_x, n_y) do
    cache_key = {__MODULE__, :exact_distribution, n_x, n_y}

    # `with` returns any non-nil cached value unchanged; only a miss runs the body.
    with nil <- :persistent_term.get(cache_key, nil) do
      built = build_exact_distribution(n_x, n_y)
      :persistent_term.put(cache_key, built)
      built
    end
  end

  # Counts, for every attainable U value, how many ways `n_x` of the ranks
  # 1..(n_x + n_y) can be chosen so that their sum yields that U.
  #
  # Returns a map of `U => count`, where U is the chosen rank sum minus the
  # smallest possible rank sum n_x * (n_x + 1) / 2.
  defp build_exact_distribution(n_x, n_y) do
    total_ranks = n_x + n_y
    min_rank_sum = div(n_x * (n_x + 1), 2)

    # table: {ranks chosen so far, their sum} => number of ways to get there.
    table =
      Enum.reduce(1..total_ranks, %{{0, 0} => 1}, fn rank, table ->
        # States reached by additionally choosing `rank`. Built from the table as
        # it stood before this rank, so each rank is used at most once.
        with_rank =
          for {{picked, sum}, ways} <- table, picked < n_x, into: %{} do
            {{picked + 1, sum + rank}, ways}
          end

        Map.merge(table, with_rank, fn _state, skipped, taken -> skipped + taken end)
      end)

    # Keep only complete selections (exactly n_x ranks) and re-key them by U.
    for {{^n_x, rank_sum}, ways} <- table, into: %{} do
      {rank_sum - min_rank_sum, ways}
    end
  end

  # Normal-approximation p-value for the observed U statistics.
  #
  # U is centred on n_x * n_y / 2 and scaled by the tie-corrected standard
  # deviation. With `continuity` enabled the deviation is moved 0.5 towards the
  # mean before scaling. A zero standard deviation yields 1.0 for every
  # alternative.
  defp asymptotic_p_value(u1, u2, n_x, n_y, tie_term, alternative, continuity) do
    expected_u = n_x * n_y / 2.0
    spread = :math.sqrt(variance(n_x, n_y, tie_term))

    if spread == 0.0 do
      1.0
    else
      case alternative do
        :greater ->
          deviation = u1 - expected_u
          Normal.sf((deviation - continuity_correction(continuity, deviation)) / spread)

        :less ->
          deviation = u1 - expected_u
          Normal.cdf((deviation + continuity_correction(continuity, expected_u - u1)) / spread)

        :two_sided ->
          # Use the larger of the two U statistics and double its upper tail.
          deviation = max(u1, u2) - expected_u
          upper_tail = Normal.sf((deviation - continuity_correction(continuity, deviation)) / spread)
          min(1.0, 2.0 * upper_tail)
      end
    end
  end

  # Variance of U under the null hypothesis, reduced by the tie correction
  # `tie_term / (n * (n - 1))`, where n is the pooled sample size.
  defp variance(n_x, n_y, tie_term) do
    pooled = n_x + n_y
    tie_adjustment = tie_term / (pooled * (pooled - 1))

    n_x * n_y / 12.0 * (pooled + 1.0 - tie_adjustment)
  end

  # Size of the continuity correction: 0.5 when enabled, 0.0 when disabled.
  # The distance argument is accepted but not used. A non-boolean flag matches
  # no clause.
  defp continuity_correction(enabled, _distance) when is_boolean(enabled) do
    if enabled, do: 0.5, else: 0.0
  end

  # Ranks the pooled `values` and summarises them for the test.
  #
  # Each value is tagged with its position in the pooled list (positions below
  # `n_x` belong to `x`) and sorted by value; the sorted list is then folded
  # into a map with the rank sum of `x`, the tie term and a ties flag.
  defp rank_info(values, n_x) do
    sorted = values |> Enum.with_index() |> Enum.sort_by(&elem(&1, 0))

    ranked_groups(sorted, n_x, 1, 0.0, 0.0, false)
  end

  # Walks the sorted `{value, index}` list one group of equal values at a time.
  #
  # Every member of a group receives the average of the ranks the group spans.
  # Accumulates the rank sum of the `x` observations, the tie term (sum of
  # t^3 - t over groups of size t > 1) and whether any tie was seen.
  defp ranked_groups([], _n_x, _rank, rank_sum_x, tie_term, has_ties),
    do: %{rank_sum_x: rank_sum_x, tie_term: tie_term, has_ties: has_ties}

  defp ranked_groups([{value, _index} | _] = sorted, n_x, rank, rank_sum_x, tie_term, has_ties) do
    {size, from_x, remaining} = consume_rank_group(sorted, value, n_x, 0, 0)
    tied? = size > 1

    # The group occupies ranks rank..(rank + size - 1); this is their mean.
    midrank = (2 * rank + size - 1) / 2.0

    ranked_groups(
      remaining,
      n_x,
      rank + size,
      rank_sum_x + from_x * midrank,
      if(tied?, do: tie_term + size * size * size - size, else: tie_term),
      has_ties or tied?
    )
  end

  # Consumes the leading run of entries whose value equals `value`.
  #
  # Takes the sorted `{value, index}` list, the value that defines the group,
  # the size of the `x` sample and two running counters. Returns
  # `{group_size, group_x_count, rest}`: how many entries were consumed, how
  # many of them came from `x` (index below `n_x`), and the unconsumed tail.
  #
  # Membership uses numeric equality (`==`), not pattern matching on a repeated
  # variable. Matching is exact, so on runtimes where `0.0` and `-0.0` are
  # distinct terms (Erlang/OTP 27+) signed zeros would not be grouped as a tie
  # even though the sort treats them as equal.
  defp consume_rank_group([], _value, _n_x, group_size, group_x_count) do
    {group_size, group_x_count, []}
  end

  defp consume_rank_group([{candidate, index} | rest], value, n_x, group_size, group_x_count)
       when candidate == value do
    group_x_count = if index < n_x, do: group_x_count + 1, else: group_x_count
    consume_rank_group(rest, value, n_x, group_size + 1, group_x_count)
  end

  defp consume_rank_group(rest, _value, _n_x, group_size, group_x_count) do
    {group_size, group_x_count, rest}
  end

  # Effect sizes derived from U1.
  #
  # `common_language` is U1 divided by the number of x/y pairs; `rank_biserial`
  # rescales it to 2 * common_language - 1, and `cliffs_delta` is reported as
  # the same value.
  defp effect_size(u1, n_x, n_y) do
    share = u1 / (n_x * n_y)
    signed = 2.0 * share - 1.0

    %{common_language: share, rank_biserial: signed, cliffs_delta: signed}
  end

  # Marker for an undefined numeric result: the atom `:nan`.
  defp nan do
    :nan
  end
end