# lib/council_ex/bias_detector.ex

defmodule CouncilEx.BiasDetector do
  @moduledoc """
  Diagnostic detector for demographic-laden disagreement across member outputs.

  Inspired by the `Bias Detector` component in Wu et al. *Council Mode*
  (arXiv:2604.02923). The paper observes that when heterogeneous models
  disagree, the disagreement sometimes correlates with demographic axes
  (gender, ethnicity, religion, age, ability). This module surfaces that
  correlation as a structured report.

  **Diagnostic only — does not mitigate.** Reading the report is on you.

  ## Backends

    * `:lexicon` (default) — case-insensitive whole-word match against a
      built-in term list per axis. A trailing plural `s`/`es` is
      tolerated, so "Christians" matches the term `christian`, while
      `he` no longer matches inside "the". Cheap. False-positive prone
      (any neutral mention of "women" or "Christians" trips it). Ships a
      default lexicon; users can extend or replace it. Pass
      `match: :substring` to `analyze/2` for raw substring matching.

  Future backends planned: `:llm_judge` (separate LLM rates the
  disagreement), `:embedding_cluster` (cluster responses, correlate
  cluster membership with demographic phrasing).

  ## Usage

      member_results = %{
        a: %CouncilEx.MemberResult{
          status: :ok,
          response: %CouncilEx.Response{content: "..."}
        },
        b: %CouncilEx.MemberResult{...}
      }

      report = CouncilEx.BiasDetector.analyze(member_results)
      # %{
      #   flagged: true,
      #   axes: [%{axis: :gender, score: 0.5, evidence: [...]}, ...],
      #   baseline_disagreement: 0.4
      # }

  ## Report shape

    * `:flagged` — boolean. True if any axis score is at or above the
      threshold (default `0.3`).
    * `:axes` — list of `%{axis, score, evidence}`, one entry per lexicon
      axis, sorted by descending score.
      `:score` is `coverage_variance ∈ [0, 1]`: how unevenly the members
      cover the axis's terms.
      `:evidence` is a list of `{member_id, [matched_terms]}`; members
      with no matches are omitted.
    * `:baseline_disagreement` — content-similarity proxy across all
      members (one minus the mean pairwise Jaccard similarity over token
      sets, 0 = identical, 1 = no overlap).
      High baseline + low axis scores = members disagree on substance,
      not demographics.
  """

  alias CouncilEx.{MemberResult, Response}

  @default_lexicon %{
    gender: ~w(
      woman women female girl mother daughter wife sister feminine she her hers
      man men male boy father son husband brother masculine he him his
      gender nonbinary transgender trans cisgender
    ),
    ethnicity: ~w(
      black white asian hispanic latino latina latinx african european
      indigenous native immigrant minority ethnic race racial
    ),
    religion: ~w(
      christian christianity catholic protestant muslim islam islamic jew jewish
      hindu hinduism buddhist buddhism atheist agnostic religious secular
    ),
    age: ~w(
      young old elderly senior teen teenager child kids adult aging youth
    ),
    ability: ~w(
      disabled disability blind deaf wheelchair handicapped impaired
      able-bodied neurodiverse autistic
    )
  }

  @default_threshold 0.3

  # Supported values for the `:match` option of `analyze/2`.
  @match_modes [:word, :substring]

  # Regex character class for "part of a word". Used in look-arounds so a
  # term only matches when it is not embedded in a longer word.
  @word_char "[\\p{L}\\p{N}_]"

  @type axis :: :gender | :ethnicity | :religion | :age | :ability | atom()

  @type report :: %{
          flagged: boolean(),
          axes: [%{axis: axis(), score: float(), evidence: [{atom(), [String.t()]}]}],
          baseline_disagreement: float()
        }

  @doc "Returns the default lexicon. Pass your own via `:lexicon` opt to `analyze/2`."
  @spec default_lexicon() :: %{axis() => [String.t()]}
  def default_lexicon, do: @default_lexicon

  @doc """
  Analyze member results and return a bias report.

  Only members whose result has `status: :ok` and a binary response
  `content` are considered; everything else is ignored.

  ## Options

    * `:lexicon` — `%{axis => [terms]}` to override the default.
    * `:threshold` — float 0.0..1.0. Any axis scoring at or above this
      sets `:flagged => true` for the whole report. Default `0.3`.
    * `:match` — `:word` (default) matches terms as whole words,
      tolerating a trailing `s`/`es`; `:substring` matches terms anywhere
      in the text, including inside other words.
    * `:backend` — `:lexicon` (default). Reserved for future backends.

  Raises `ArgumentError` if `:match` is not one of the supported modes.
  """
  @spec analyze(%{atom() => MemberResult.t()}, keyword()) :: report()
  def analyze(member_results, opts \\ []) when is_map(member_results) do
    lexicon = Keyword.get(opts, :lexicon, @default_lexicon)
    threshold = Keyword.get(opts, :threshold, @default_threshold)
    match_mode = Keyword.get(opts, :match, :word)

    if match_mode not in @match_modes do
      raise ArgumentError,
            "invalid :match option #{inspect(match_mode)}, expected one of #{inspect(@match_modes)}"
    end

    # Content is downcased once here; all matching below is therefore
    # case-insensitive.
    ok_results =
      for {id, %MemberResult{status: :ok, response: %Response{content: text}}} <-
            member_results,
          is_binary(text) do
        {id, String.downcase(text)}
      end

    axes =
      Enum.map(lexicon, fn {axis, terms} ->
        score_axis(axis, terms, ok_results, match_mode)
      end)

    %{
      flagged: Enum.any?(axes, fn %{score: score} -> score >= threshold end),
      axes: Enum.sort_by(axes, & &1.score, :desc),
      baseline_disagreement: baseline_disagreement(ok_results)
    }
  end

  # Builds the report entry for one axis: which terms each member
  # mentioned, and how unevenly that coverage is spread across members.
  defp score_axis(axis, terms, ok_results, match_mode) do
    # Matchers are built once per term rather than once per member.
    # An empty term would match every response, so it is skipped.
    matchers =
      terms
      |> Enum.uniq()
      |> Enum.reject(&(&1 == ""))
      |> Enum.map(&{&1, build_matcher(&1, match_mode)})

    evidence =
      for {id, content} <- ok_results do
        matched = for {term, matches?} <- matchers, matches?.(content), do: term
        {id, matched}
      end

    score =
      evidence
      |> Enum.map(fn {_id, matched} -> length(matched) end)
      |> coverage_variance()

    %{
      axis: axis,
      score: Float.round(score, 4),
      evidence: Enum.reject(evidence, fn {_id, matched} -> matched == [] end)
    }
  end

  # Returns a one-argument predicate telling whether (already downcased)
  # content mentions `term`.
  defp build_matcher(term, :substring) do
    needle = String.downcase(term)
    &String.contains?(&1, needle)
  end

  defp build_matcher(term, :word) do
    escaped = term |> String.downcase() |> Regex.escape()

    # Look-arounds instead of `\b` so terms containing punctuation (e.g.
    # "able-bodied") still work. The optional `s`/`es` keeps simple
    # plurals such as "christians" matching.
    regex =
      Regex.compile!("(?<!#{@word_char})#{escaped}(?:e?s)?(?!#{@word_char})", "u")

    &Regex.match?(regex, &1)
  end

  # Normalize counts to [0, 1] coverage rates per member (relative to the
  # member with the most matches), then take population variance.
  # Variance ∈ [0, 0.25]; rescale to [0, 1]. Fewer than two members
  # cannot disagree, so they score 0.0.
  defp coverage_variance([]), do: 0.0
  defp coverage_variance([_single]), do: 0.0

  defp coverage_variance(counts) do
    case Enum.max(counts) do
      0 ->
        0.0

      max_count ->
        n = length(counts)
        rates = Enum.map(counts, &(&1 / max_count))
        mean = Enum.sum(rates) / n
        variance = Enum.sum(Enum.map(rates, &:math.pow(&1 - mean, 2))) / n
        min(variance * 4.0, 1.0)
    end
  end

  # One minus the mean pairwise Jaccard similarity of the members' token
  # sets. Fewer than two members yields 0.0.
  defp baseline_disagreement([]), do: 0.0
  defp baseline_disagreement([_single]), do: 0.0

  defp baseline_disagreement(ok_results) do
    token_sets =
      Enum.map(ok_results, fn {_id, content} ->
        content
        # `u` makes `\W` Unicode-aware so non-ASCII letters stay inside tokens.
        |> String.split(~r/\W+/u, trim: true)
        |> MapSet.new()
      end)

    indexed = Enum.with_index(token_sets)

    # Pair members by position, not by value: two members with identical
    # token sets are a legitimate pair (similarity 1.0) and must count
    # towards the average.
    similarities = for {a, i} <- indexed, {b, j} <- indexed, i < j, do: jaccard(a, b)

    avg_jaccard = Enum.sum(similarities) / length(similarities)
    Float.round(1.0 - avg_jaccard, 4)
  end

  # Jaccard similarity of two sets. Two empty sets are treated as
  # identical (1.0).
  defp jaccard(a, b) do
    case MapSet.size(MapSet.union(a, b)) do
      0 -> 1.0
      union_size -> MapSet.size(MapSet.intersection(a, b)) / union_size
    end
  end
end