Skip to main content

src/babble.gleam

//// A small Markov chain text generator: train it on example text, then generate
//// new sentences that sound _almost_ like the source. It offers incremental
//// training, sentence-aware generation that stops at a natural full stop, and a
//// pluggable sampler that puts you in control of _how_ each next word is chosen.
////
//// ```gleam
//// let model =
////   babble.new(order: 2, tokenization: babble.Words)
////   |> babble.train("the cat sat on the mat.")
////   |> babble.train("the dog sat on the log.")
////
//// let assert Ok(sentence) = babble.generate(model, babble.weighted, max_tokens: 200)
//// ```
////
//// Generation is driven by a [`Sampler`](#Sampler) — see that type along with
//// [`weighted`](#weighted) and [`most_likely`](#most_likely).

import gleam/dict
import gleam/int
import gleam/list
import gleam/option
import gleam/order
import gleam/string

/// A single unit in the chain: a sentence-start marker, a sentence-end marker,
/// or a concrete word/grapheme. Internal to the model.
///
/// Every trained sentence is prefixed with `order` copies of `Start` and
/// suffixed with one `End` (see `pad`). `Start` therefore only ever appears
/// inside contexts. `End` appears as a successor and is handed to samplers as
/// `Stop`.
type Token {
  /// Sentence-start padding; fills contexts before the first real token.
  Start
  /// Sentence end; follows the last base token of a trained sentence.
  End
  /// A base token: a whole word or a single grapheme, per `Tokenization`.
  Word(String)
}

/// Whether base tokens are whole words or single characters.
pub type Tokenization {
  /// Split sentences on spaces, tabs and line breaks. Generated tokens are
  /// re-joined with single spaces.
  Words
  /// Split sentences into grapheme clusters. Generated tokens are concatenated
  /// with no separator.
  Characters
}

/// Tunable parameters for a model, fixed at construction.
///
/// `order` is the n-gram context length; [`new`](#new) clamps it to >= 1.
/// `tokenization` selects word- vs character-level base tokens. The generation
/// length cap is a [`generate`](#generate) argument, not a model setting.
pub type Config {
  Config(order: Int, tokenization: Tokenization)
}

/// Why a generation request produced no output.
pub type GenerateError {
  /// The model has no learned transitions to start from. Either it was never
  /// trained, or every message it saw was empty or whitespace-only.
  EmptyModel
}

/// The next step in a sentence: emit a word, or stop. A [`Sampler`](#Sampler)
/// chooses one of these from the weighted candidates at each point in the walk.
pub type Step {
  /// Continue the sentence by emitting this word/grapheme.
  Continue(String)
  /// End the sentence here. Offered when training saw a sentence end after
  /// the current context.
  Stop
}

/// A strategy for choosing the next [`Step`](#Step) during generation.
///
/// At each point in the walk, babble hands the sampler every successor it has
/// seen for the current context. This is a non-empty list of `#(step, count)`
/// pairs, where `count` is how many times that step followed the context in
/// training. The sampler returns one step:
/// - [`Continue`](#Step) emits a word and the walk goes on.
/// - [`Stop`](#Step) ends the sentence.
///
/// A `Stop` candidate is present whenever the model saw a sentence end after
/// this context.
///
/// A sampler is expected to return one of the steps it was given. If it
/// returns a `Continue` whose word leads to a context the model never saw,
/// the walk simply ends after emitting that word.
///
/// babble ships two: [`weighted`](#weighted) (random, proportional to the
/// counts) and [`most_likely`](#most_likely) (deterministic). Everything else
/// (temperature, top-k, blocklists, biasing) you write yourself, because a
/// sampler is just an ordinary function:
///
/// ```gleam
/// // Ignore the counts and pick a successor uniformly at random.
/// fn uniform(candidates: List(#(babble.Step, Int))) -> babble.Step {
///   case list.drop(candidates, int.random(list.length(candidates))) {
///     [#(step, _), ..] -> step
///     [] -> babble.Stop
///   }
/// }
/// ```
///
/// A sampler is called once per word and keeps no state between calls. For
/// _reproducible_ output, reach for [`most_likely`](#most_likely) rather than
/// trying to seed randomness here.
pub type Sampler =
  fn(List(#(Step, Int))) -> Step

/// An opaque Markov model. Create one with [`new`](#new) and grow it with
/// [`train`](#train). Its fields are only reachable inside this module.
pub opaque type Model {
  Model(
    // Settings fixed by `new`; `order` is already clamped to >= 1.
    config: Config,
    // Each context of `order` tokens mapped to a multiset (token -> count) of
    // the tokens seen immediately after it.
    transitions: dict.Dict(List(Token), dict.Dict(Token, Int)),
    // Number of trained messages that contributed at least one sentence.
    messages: Int,
  )
}

/// Create an empty model, ready for [`train`](#train).
///
/// Both settings are fixed for the model's lifetime. The learned counts only
/// make sense under the settings they were gathered with, so there are no
/// setters.
/// - `order` is how many preceding tokens condition the choice of the next
///   one. Values below 1 are raised to 1. Larger orders read more coherently
///   but copy the corpus more verbatim; 2 is a sensible default.
/// - `tokenization` is `Words` or `Characters`.
///
/// The generation length cap is an argument to [`generate`](#generate), not a
/// model setting.
///
/// ## Examples
///
/// ```gleam
/// let model = babble.new(order: 2, tokenization: babble.Words)
/// assert babble.is_empty(model)
/// ```
pub fn new(order order: Int, tokenization tokenization: Tokenization) -> Model {
  let clamped_order = case order < 1 {
    True -> 1
    False -> order
  }
  let settings = Config(order: clamped_order, tokenization: tokenization)
  Model(config: settings, transitions: dict.new(), messages: 0)
}

/// The configuration this model was created with, with `order` already
/// clamped to at least 1.
pub fn config(model: Model) -> Config {
  let Model(config: settings, ..) = model
  settings
}

/// Whether the model has no learned transitions: it has not been trained yet,
/// or has only seen blank text.
pub fn is_empty(model: Model) -> Bool {
  model.transitions
  |> dict.is_empty
}

/// The number of messages given to `train` that contained at least one
/// non-blank sentence.
pub fn message_count(model: Model) -> Int {
  let Model(messages: seen, ..) = model
  seen
}

/// Fold one message into the model and return the updated model.
///
/// The message is split into sentences. Each non-empty sentence is tokenised,
/// preceded by `order` `Start` markers and followed by one `End`. Every
/// `order`-token context is then counted together with the token after it.
/// The message counter goes up by one when at least one sentence contributed.
/// Training is incremental and never rebuilds, so you can keep feeding in new
/// text.
///
/// ## Examples
///
/// ```gleam
/// let model =
///   babble.new(order: 2, tokenization: babble.Words)
///   |> babble.train("the cat sat.")
///
/// assert babble.message_count(model) == 1
/// ```
pub fn train(model: Model, message: String) -> Model {
  let Config(order: order, tokenization: tokenization) = model.config

  // One padded token list per non-empty sentence; blank ones are skipped.
  let padded_sentences =
    message
    |> sentences
    |> list.filter_map(fn(sentence) {
      case tokenize(sentence, tokenization) {
        [] -> Error(Nil)
        base -> Ok(pad(base, order))
      }
    })

  let transitions =
    list.fold(padded_sentences, model.transitions, fn(table, padded) {
      count_window(table, padded, order)
    })

  let messages = case list.is_empty(padded_sentences) {
    True -> model.messages
    False -> model.messages + 1
  }

  Model(..model, transitions: transitions, messages: messages)
}

/// Train on each message in turn, left to right. This is equivalent to piping
/// the model through [`train`](#train) once per message.
pub fn train_many(model: Model, messages: List(String)) -> Model {
  use trained, message <- list.fold(messages, model)
  train(trained, message)
}

/// Wrap a sentence's base tokens as `Word`s, preceded by `order` `Start`
/// markers and followed by a single `End`.
fn pad(base: List(String), order: Int) -> List(Token) {
  let words = list.map(base, Word)
  list.append(list.repeat(Start, order), list.append(words, [End]))
}

/// Count every transition in one padded sentence.
///
/// Each run of `order + 1` consecutive tokens contributes its first `order`
/// tokens as the context and its final token as the successor.
fn count_window(
  transitions: dict.Dict(List(Token), dict.Dict(Token, Int)),
  padded: List(Token),
  order: Int,
) -> dict.Dict(List(Token), dict.Dict(Token, Int)) {
  padded
  |> list.window(order + 1)
  |> list.fold(transitions, fn(table, window) {
    // `list.window` only yields full windows, so the split always leaves
    // exactly one successor; the fallback arm is purely defensive.
    case list.split(window, order) {
      #(context, [next]) -> bump(table, context, next)
      _ -> table
    }
  })
}

/// Add one to the number of times `next` has followed `context`. The context's
/// successor table is created the first time the context is seen.
fn bump(
  transitions: dict.Dict(List(Token), dict.Dict(Token, Int)),
  context: List(Token),
  next: Token,
) -> dict.Dict(List(Token), dict.Dict(Token, Int)) {
  let increment = fn(current: option.Option(Int)) -> Int {
    case current {
      option.Some(n) -> n + 1
      option.None -> 1
    }
  }
  dict.upsert(transitions, context, fn(existing) {
    case existing {
      option.Some(successors) -> dict.upsert(successors, next, increment)
      option.None -> dict.upsert(dict.new(), next, increment)
    }
  })
}

/// Generate one sentence, choosing each next word with `sampler` and emitting
/// at most `max_tokens` of them.
///
/// The walk starts at the beginning of a sentence and asks `sampler` for the
/// next step at each point. It ends when the sampler picks a learned sentence
/// end, or when `max_tokens` (clamped to >= 1) words have been emitted.
/// Returns `Error(EmptyModel)` if the model has never been trained.
///
/// Pass [`weighted`](#weighted) for varied, corpus-like output or
/// [`most_likely`](#most_likely) for deterministic output. See
/// [`Sampler`](#Sampler) to write your own.
///
/// ## Examples
///
/// ```gleam
/// // Varied output — a different sentence each call:
/// let assert Ok(sentence) = babble.generate(model, babble.weighted, max_tokens: 200)
///
/// // No data yet:
/// let empty = babble.new(order: 2, tokenization: babble.Words)
/// assert babble.generate(empty, babble.weighted, max_tokens: 50) == Error(babble.EmptyModel)
/// ```
pub fn generate(
  model: Model,
  sampler: Sampler,
  max_tokens max_tokens: Int,
) -> Result(String, GenerateError) {
  let can_start = startable(model)
  case can_start {
    True -> {
      let sentence =
        generate_sentence(model, start_context(model), [], sampler, max_tokens)
      Ok(sentence)
    }
    False -> Error(EmptyModel)
  }
}

/// Generate `sentences` sentences (at least 1) with `sampler` and join them
/// with single spaces. Each sentence is capped at `max_tokens` words.
///
/// Returns `Error(EmptyModel)` if the model has never been trained.
pub fn generate_paragraph(
  model: Model,
  sentences: Int,
  sampler: Sampler,
  max_tokens max_tokens: Int,
) -> Result(String, GenerateError) {
  case startable(model) {
    False -> Error(EmptyModel)
    True -> {
      let how_many = int.max(1, sentences)
      let start = start_context(model)
      let one_sentence = fn(_) {
        generate_sentence(model, start, [], sampler, max_tokens)
      }
      list.repeat(Nil, how_many)
      |> list.map(one_sentence)
      |> string.join(" ")
      |> Ok
    }
  }
}

/// Generate a sentence that begins with `prefix`, choosing with `sampler` and
/// emitting at most `max_tokens` tokens after the prefix.
///
/// The continuation is seeded from the last `order` prefix tokens, left-padded
/// with `Start` when the prefix is shorter. If the model never saw that
/// context, generation starts from the sentence start instead. Either way the
/// prefix tokens are kept at the front of the result. An empty model returns
/// `Error(EmptyModel)`.
pub fn generate_starting_with(
  model: Model,
  prefix: String,
  sampler: Sampler,
  max_tokens max_tokens: Int,
) -> Result(String, GenerateError) {
  case is_empty(model) {
    True -> Error(EmptyModel)
    False -> {
      let Config(order: order, tokenization: tokenization) = model.config
      let prefix_tokens = tokenize(prefix, tokenization)
      let seeded =
        prefix_tokens
        |> list.map(Word)
        |> seed_context(order)
      // Fall back to a fresh sentence start when the seed was never learned.
      let start = case dict.get(model.transitions, seeded) {
        Ok(_) -> seeded
        Error(Nil) -> start_context(model)
      }
      // `generate_sentence` expects already-emitted tokens newest-first.
      let already_emitted = list.reverse(prefix_tokens)
      Ok(generate_sentence(model, start, already_emitted, sampler, max_tokens))
    }
  }
}

/// A [`Sampler`](#Sampler) that picks a successor at random, with probability
/// proportional to how often it followed the context in training.
///
/// This is the natural "talk like the corpus" behaviour and the one you'll
/// want most of the time. It uses the platform RNG, so output varies between
/// calls. Pass it straight to [`generate`](#generate); you rarely call it
/// yourself.
///
/// ## Examples
///
/// ```gleam
/// let assert Ok(sentence) = babble.generate(model, babble.weighted, max_tokens: 200)
/// ```
pub fn weighted(candidates: List(#(Step, Int))) -> Step {
  let total =
    candidates
    |> list.map(fn(candidate) { candidate.1 })
    |> int.sum
  // `int.random(n)` draws from 0..n-1. Clamp to 1 so an all-zero candidate
  // list still yields a valid draw (0) instead of calling `int.random(0)`.
  pick(candidates, int.random(int.max(1, total)))
}

/// Walk the candidates subtracting weights until the draw `r` falls inside
/// one. The final candidate is always taken, whatever is left of `r`. An empty
/// list yields `Stop`.
fn pick(candidates: List(#(Step, Int)), r: Int) -> Step {
  case candidates {
    [] -> Stop
    [#(last, _)] -> last
    [#(step, weight), ..] if r < weight -> step
    [#(_, weight), ..remaining] -> pick(remaining, r - weight)
  }
}

/// A [`Sampler`](#Sampler) that always picks the most frequent successor.
/// Ties are broken deterministically, so the result never depends on internal
/// map ordering.
///
/// Generation with this sampler is fully reproducible: a given model always
/// produces the same sentence. That makes it ideal for tests and snapshots, or
/// for a fixed "house style" output. Because it always takes the single
/// most-travelled path, its output tends to reproduce whole training sentences
/// verbatim.
///
/// ## Examples
///
/// ```gleam
/// let model =
///   babble.new(order: 2, tokenization: babble.Words)
///   |> babble.train("the cat sat.")
///
/// assert babble.generate(model, babble.most_likely, max_tokens: 50) == Ok("the cat sat.")
/// ```
pub fn most_likely(candidates: List(#(Step, Int))) -> Step {
  // True when `challenger` ranks strictly ahead of `champion`: a higher count
  // wins, and equal counts fall back to the smaller `step_key`.
  let beats = fn(challenger: #(Step, Int), champion: #(Step, Int)) -> Bool {
    case int.compare(challenger.1, champion.1) {
      order.Gt -> True
      order.Lt -> False
      order.Eq ->
        string.compare(step_key(challenger.0), step_key(champion.0)) == order.Lt
    }
  }
  case candidates {
    [] -> Stop
    [first, ..rest] -> {
      let #(winner, _) =
        list.fold(rest, first, fn(best, candidate) {
          case beats(candidate, best) {
            True -> candidate
            False -> best
          }
        })
      winner
    }
  }
}

/// Tie-breaking key for `most_likely`. The "0"/"1" prefixes order `Stop`
/// ahead of every `Continue`. Words then compare alphabetically after the
/// shared "1" prefix.
fn step_key(step: Step) -> String {
  case step {
    Continue(word) -> "1" <> word
    Stop -> "0"
  }
}

/// Walk from `context` to a sentence end with `sampler`, emitting at most
/// `max_tokens` (clamped to >= 1) new tokens, and join the result.
///
/// `acc` holds tokens that should already precede the output, such as a
/// prefix, newest-first. They do not count toward the cap.
fn generate_sentence(
  model: Model,
  context: List(Token),
  acc: List(String),
  sampler: Sampler,
  max_tokens: Int,
) -> String {
  let cap = int.max(1, max_tokens)
  gen_loop(model, context, acc, 0, sampler, cap)
  |> list.reverse
  |> join(model.config.tokenization, _)
}

/// One step of the walk. `emitted` is newest-first and `count` is how many
/// tokens this walk has added so far. Each chosen word is appended to the
/// context window while the oldest token is dropped.
fn gen_loop(
  model: Model,
  context: List(Token),
  emitted: List(String),
  count: Int,
  sampler: Sampler,
  max_tokens: Int,
) -> List(String) {
  case count < max_tokens {
    // Length cap reached.
    False -> emitted
    True ->
      case dict.get(model.transitions, context) {
        // Unseen context: nothing is known to follow, so stop here.
        Error(Nil) -> emitted
        Ok(successors) ->
          case sampler(candidates(successors)) {
            Stop -> emitted
            Continue(word) -> {
              let next_context =
                list.drop(context, 1)
                |> list.append([Word(word)])
              gen_loop(
                model,
                next_context,
                [word, ..emitted],
                count + 1,
                sampler,
                max_tokens,
              )
            }
          }
      }
  }
}

/// Convert a successor table into the `#(step, count)` pairs a sampler sees.
/// `Word`s become `Continue`. The marker tokens become `Stop`; in practice
/// only `End` can be a successor.
fn candidates(counts: dict.Dict(Token, Int)) -> List(#(Step, Int)) {
  counts
  |> dict.to_list
  |> list.map(fn(entry) {
    let #(token, count) = entry
    case token {
      Word(word) -> #(Continue(word), count)
      Start | End -> #(Stop, count)
    }
  })
}

/// Whether a walk can begin, i.e. whether the all-`Start` context has any
/// learned successors.
fn startable(model: Model) -> Bool {
  let origin = start_context(model)
  model.transitions
  |> dict.has_key(origin)
}

/// The context every fresh sentence begins from: `order` copies of `Start`.
fn start_context(model: Model) -> List(Token) {
  Start
  |> list.repeat(model.config.order)
}

/// Turn prefix tokens into a context of exactly `order` tokens. Longer
/// prefixes keep only their last `order` tokens. Shorter ones are left-padded
/// with `Start`.
fn seed_context(words: List(Token), order: Int) -> List(Token) {
  let missing = order - list.length(words)
  case missing > 0 {
    True -> list.append(list.repeat(Start, missing), words)
    False -> list.drop(words, -missing)
  }
}

/// Join base tokens, already in final order, back into text. Words are
/// separated by single spaces; characters are concatenated directly.
fn join(tokenization: Tokenization, tokens: List(String)) -> String {
  let separator = case tokenization {
    Words -> " "
    Characters -> ""
  }
  string.join(tokens, separator)
}

// --- Tokenization (internal; `pub` only so the test suite can reach it) ---

/// Split a message into sentences at terminal punctuation (`.`, `!`, `?`).
///
/// A cut happens only when the punctuation ends the message or is followed by
/// whitespace, so "3.14" stays whole. Punctuation stays attached to its
/// sentence. Each sentence is trimmed, and blank ones are dropped.
@internal
pub fn sentences(message: String) -> List(String) {
  // `segment` returns its pieces newest-first.
  let pieces = segment(string.to_graphemes(message), "", [])
  pieces
  |> list.reverse
  |> list.filter_map(fn(piece) {
    case string.trim(piece) {
      "" -> Error(Nil)
      trimmed -> Ok(trimmed)
    }
  })
}

/// Split a sentence into base tokens.
///
/// `Words` splits on spaces, tabs and line breaks (LF, CR and CRLF) and drops
/// empty pieces. `Characters` splits into grapheme clusters.
@internal
pub fn tokenize(sentence: String, tokenization: Tokenization) -> List(String) {
  case tokenization {
    // "\r\n" is replaced first because it forms a single grapheme cluster. On
    // a target whose replace is grapheme-aware, the separate "\n" and "\r"
    // passes may not match inside it, which would glue adjacent words.
    Words ->
      sentence
      |> string.replace("\r\n", " ")
      |> string.replace("\n", " ")
      |> string.replace("\t", " ")
      |> string.replace("\r", " ")
      |> string.split(" ")
      |> list.filter(fn(s) { s != "" })
    Characters -> string.to_graphemes(sentence)
  }
}

/// Scan graphemes, accumulating them into `buffer`. A sentence is cut right
/// after a terminal `.`, `!` or `?` that either ends the input or is followed
/// by a space, tab or line break (LF, CR or CRLF).
///
/// Returns the pieces newest-first, including the final, possibly empty,
/// buffer. `sentences` reverses, trims and filters them.
fn segment(
  graphemes: List(String),
  buffer: String,
  acc: List(String),
) -> List(String) {
  case graphemes {
    [] -> [buffer, ..acc]
    [grapheme, ..rest] -> {
      let buffer = buffer <> grapheme
      let terminal = grapheme == "." || grapheme == "!" || grapheme == "?"
      let boundary = case rest {
        [] -> True
        // `string.to_graphemes` yields "\r\n" as ONE grapheme cluster, so a
        // Windows line break must be matched as a unit, not as "\r" or "\n".
        [next, ..] ->
          next == " "
          || next == "\n"
          || next == "\t"
          || next == "\r"
          || next == "\r\n"
      }
      case terminal && boundary {
        True -> segment(rest, "", [buffer, ..acc])
        False -> segment(rest, buffer, acc)
      }
    }
  }
}