lib/markov.ex

defmodule Markov do
  @moduledoc """
  Markov-chain-based trainable text generator.
  Next-token prediction uses the two previous tokens.
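
  ## Example
      chain = %Markov{}
          |> Markov.train("hello, world!")

      Markov.generate_text(chain)
      # => "hello, world!"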
  """

  import Nx.Defn
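  # how many connections are adjusted per JIT-compiled Nx batch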
  @nx_batch_size 1000

  defstruct links: %{[:start, :start] => %{end: 1}},
            sanitize_tokens: false,
            shift: false

  # Conditionally sanitizes a token list
  @spec cond_sanitize_tokens([any()], %Markov{}) :: [any()]
  defp cond_sanitize_tokens(tokens, chain) do
    if chain.sanitize_tokens do
      tokens |> Enum.map(&Markov.TextUtil.sanitize_token/1)
    else tokens end
  end

  @doc "Adjusts the probability of one connection"
  defn adjust_one_prob(param_tensor) do
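    # param_tensor packs six f32 scalars:
    # [this link's index, peak index, peak prob, first prob, first-to-peak ratio, link count]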
    i          = Nx.gather(param_tensor, Nx.tensor([[0]])) |> Nx.squeeze
    peak       = Nx.gather(param_tensor, Nx.tensor([[1]])) |> Nx.squeeze
    peak_prob  = Nx.gather(param_tensor, Nx.tensor([[2]])) |> Nx.squeeze
    first_prob = Nx.gather(param_tensor, Nx.tensor([[3]])) |> Nx.squeeze
    ratio      = Nx.gather(param_tensor, Nx.tensor([[4]])) |> Nx.squeeze
    len        = Nx.gather(param_tensor, Nx.tensor([[5]])) |> Nx.squeeze

    # linear approximation (kept for reference; superseded by the power curve below)
    # result = cond do
    #   i < peak ->
    #     a = peak_prob / ratio
    #     k = (peak_prob - a) / peak
    #     k * i + a
    #
    #   i == peak -> peak_prob
    #
    #   i > peak ->
    #     last = len - 1
    #     peak_to_last = last - peak
    #     k = -Nx.min((ratio - 1) / peak_to_last, peak_prob / peak_to_last)
    #     a = -k + peak_prob
    #     k * i + a
    #
    #   # hopefully never reached
    #   true -> Nx.Constants.nan
    # end

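    # exponent of the power curve that replaces the linear approximation above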
    power = Nx.tensor(1.7, type: :f32)

    # https://www.desmos.com/calculator/mq3qjg8zpm
    result = cond do
      i < peak ->
        offset = (first_prob / (ratio ** power))
        coeff = (peak_prob - (peak_prob * ratio / (ratio ** power + 1))) / (peak ** ratio)
        (coeff * (i ** ratio)) + offset

      i == peak -> peak_prob

      i > peak ->
        coeff = peak_prob / ((len - peak) ** (1 / ratio))
        coeff * ((-i + len - 1) ** (1 / ratio))

      # hopefully never reached
      true -> Nx.Constants.nan
    end

    # round off and convert scalar to {1}-shape
    Nx.round(result) |> Nx.tile([1])
  end

  @doc "Adjust the probabilities of a batch of connections"
  defn adjust_batch_probs(params) do
    results = Nx.broadcast(-1, {@nx_batch_size}) |> Nx.as_type(:u32)

    {_, _, results} = while {i = 0, params, results}, i < @nx_batch_size do
      result = adjust_one_prob(params[i])
        |> Nx.as_type(:u32)

      # write the result at this row's position within the batch; using the
      # row's global index (params[i][0]) here would overflow `results` for
      # every chunk after the first one
      {i + 1, params, Nx.put_slice(results, [i], result)}
    end

    results
  end
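
  # Illustrative direct call (values invented; each row follows the
  # [index, peak, peak_prob, first_prob, ratio, len] layout):
  #
  #     rows = for i <- 0..999, do: [i, 3, 40, 80, 2, 1000]
  #     Nx.tensor(rows, type: :f32) |> Markov.adjust_batch_probs()
  #     #=> {1000}-shaped u32 tensor of adjusted weights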

  @doc "Shifts probabilities if the model has a corresponding flag"
  @spec cond_shift_probs(%{[any()] => any()}, %Markov{}) :: %{[any()] => any()}
  def cond_shift_probs(links, %Markov{shift: shift}) when shift and map_size(links) >= 2 do
    # sort links by their weight, most likely first
    links = links
      |> Enum.into([])
      |> Enum.sort(fn {_, w1}, {_, w2} -> w1 > w2 end)

    # choose the peak
    peak = max(1, :math.sqrt(length(links)) * 0.1) |> floor()
    {_, peak_prob} = links |> Enum.at(peak)
    # determine by how much the first most probable path
    # is more likely than the peak
    {_, first_prob} = links |> Enum.at(0)
    ratio = min(first_prob / peak_prob, 5)

    jitted = EXLA.jit(&Markov.adjust_batch_probs/1)
    constant_params = [peak, peak_prob, first_prob, ratio, length(links)]

    Stream.with_index(links)
      |> Stream.chunk_every(@nx_batch_size)
      |> Stream.map(fn batch ->
        processed = batch
          |> Enum.map(&elem(&1, 1))
          |> Enum.map(fn idx -> [idx | constant_params] end)
          |> Nx.tensor(type: :f32)
          |> jitted.()
          |> Nx.to_flat_list
        Enum.zip(batch, processed) |> Enum.map(fn {{{k, _}, _}, v} -> {k, v} end)
      end)
      |> Enum.into([])
      |> List.flatten
      |> Enum.into(%{})
  end
  def cond_shift_probs(links, _), do: links
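
  # Illustrative call (weights invented): with `shift` set, every weight in
  # the map is recomputed along the power curve in adjust_one_prob/1:
  #
  #     links = Map.new(1..20, fn n -> {"tok#{n}", 21 - n} end)
  #     Markov.cond_shift_probs(links, %Markov{shift: true})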

  @doc """
  Trains `chain` using `text` or a list of `tokens`.

  Returns the modified chain.

  ## Example
      chain = %Markov{}
          |> Markov.train("hello, world!")
          |> Markov.train("example string number two")
          |> Markov.train("hello, Elixir!")
          |> Markov.train("fourth string")

      chain = %Markov{}
          |> Markov.train(["individual tokens", :can_be, 'arbitrary terms'])
  """
  @spec train(%Markov{}, String.t() | [any()]) :: %Markov{}
  def train(%Markov{}=chain, text) when is_binary(text) do
    tokens = String.split(text)
    train(chain, tokens)
  end

  def train(%Markov{}=chain, tokens) when is_list(tokens) do
    # add start and end tokens
    tokens = [:start, :start] ++ tokens ++ [:end]

    # adjust link weights
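    # ttuples/1 (defined in Markov.ListUtil, not shown here) presumably yields
    # sliding triples of consecutive tokens,
    # e.g. {:start, :start, "a"}, {:start, "a", "b"}, ...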
    new_links = Enum.reduce(Markov.ListUtil.ttuples(tokens), chain.links, fn {first, second, third}, acc ->
      from = [first, second] |> cond_sanitize_tokens(chain)
      to = third
      # create the link or increment its weight
      Map.update(acc, from, %{to => 1}, &Map.update(&1, to, 1, fn weight -> weight + 1 end))
    end)

    # forcefully break the start -> end link
    new_links = Map.put(new_links, [:start, :start], Map.delete(new_links[[:start, :start]], :end))

    %{chain | links: new_links}
  end

  @doc """
  Removes a `token` from all generation paths `chain` could produce.

  Returns the modified chain.

  ## Example
      iex> %Markov{} |>
      ...> Markov.train("a b c") |>
      ...> Markov.forget_token("b") |>
      ...> Markov.generate_text()
      "a"
  """
  @spec forget_token(%Markov{}, any()) :: %Markov{}
  def forget_token(%Markov{}=chain, token) do
    # sanitize the token
    token = if chain.sanitize_tokens do
      token |> Markov.TextUtil.sanitize_token
    else token end
    # remove links that point to the token
    %{chain | links: chain.links |> Enum.map(fn
      {[_, _]=k, v} ->
        {k, Enum.filter(v, fn {k, _} -> k != token end) |> Enum.into(%{})}
      {k, v} -> {k, v}
    end) |> Enum.into(%{})
    # terminate states that point nowhere
    |> Enum.map(fn
      {k, %{}=map} when map_size(map) == 0 ->
        {k, %{end: 1}}
      {k, v} -> {k, v}
    end) |> Enum.into(%{})}
  end

  @doc """
  Predicts the next state of a `chain` assuming `current` state.

  Note: the current state consists of two tokens.

  Returns the next predicted state.

  ## Example
      iex> %Markov{} |> Markov.train("1 2 3 4 5") |> Markov.next_state(["2", "3"])
      "4"

      iex> %Markov{} |> Markov.train("1 2") |> Markov.next_state([:start, :start])
      "1"

      iex> %Markov{} |> Markov.train([:a, :b, :c]) |> Markov.next_state([:a, :b])
      :c
  """
  @spec next_state(%Markov{}, [any()]) :: any()
  def next_state(%Markov{}=chain, current) do
    # sanitize state
    current = current |> cond_sanitize_tokens(chain)
    # get links from current state
    # (enforce constant order by converting to proplist)
    links = chain.links[current]
      |> cond_shift_probs(chain)
      |> Enum.into([])

    # pick a random token, weighted by link counts
    sum = links
      |> Enum.map(&elem(&1, 1))
      |> Enum.sum()
    (:rand.uniform(sum) - 1) |> probabilistic_select(links, sum)
  end

  @doc """
  Generates a list of tokens using the `chain`

  Optionally prepends `acc` to the result and assumes the previous
  two tokens were `[state1, state2]=state`. The length of the
  resulting token list is limited by `limit`.

  Returns the generated list.

  ## Example
      iex> %Markov{} |> Markov.train([:a, :b, :c]) |> Markov.generate_tokens()
      [:a, :b, :c]

      iex> %Markov{} |> Markov.train([:a, :b, :c]) |>
      ...> Markov.generate_tokens([], [:a, :b])
      [:c]
  """
  @spec generate_tokens(%Markov{}, [any()], [any()], integer()) :: [any()]
  def generate_tokens(%Markov{}=chain, acc \\ [], state \\ [:start, :start], limit \\ 100) do
    # iterate through states until :end
    new_state = next_state(chain, state)
    if new_state == :end or limit <= 0 do
      acc
    else
      generate_tokens(chain, acc ++ [new_state], [state |> Enum.at(1), new_state], limit - 1)
    end
  end

  @doc """
  Generates a string of text using the `chain`

  Optionally assumes the previous two states were `[state1, state2]=state`.

  Returns the generated text.

  ## Example
      iex> %Markov{} |> Markov.train("hello, world!") |> Markov.generate_text()
      "hello, world!"

      iex> %Markov{} |> Markov.train("hello, world!") |>
      ...> Markov.generate_text([:start, "hello,"])
      "world!"
  """
  @spec generate_text(%Markov{}, [any()]) :: String.t()
  def generate_text(%Markov{}=chain, state \\ [:start, :start]) do
    generate_tokens(chain, [], state) |> Enum.join(" ")
  end

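  # Walks `choices` accumulating weights and returns the first token whose
  # [acc, acc + weight) interval contains `number`; e.g. given choices
  # [a: 2, b: 3] and number 3, 3 falls in [2, 5), so :b is picked.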
  @spec probabilistic_select(integer(), list({any(), integer()}), integer(), integer()) :: any()
  defp probabilistic_select(number, [{name, add} | tail]=_choices, sum, acc \\ 0) do
    if (number >= acc) and (number < acc + add) do
      name
    else
      probabilistic_select(number, tail, sum, acc + add)
    end
  end

  @doc """
  Enables token sanitization on a `chain`.
  When this mode is enabled, the chain doesn't distinguish between similar
  textual tokens (as defined by `Markov.TextUtil.sanitize_token/1`).
  This mode can't be disabled once it has been enabled.

  Returns the modified chain.
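
  ## Example
      chain = %Markov{}
          |> Markov.train("hello, World")
          |> Markov.enable_token_sanitization()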
  """
  @spec enable_token_sanitization(%Markov{}) :: %Markov{}
  def enable_token_sanitization(%Markov{}=chain) do
    sanitize = fn t -> t |> Enum.map(&Markov.TextUtil.sanitize_token/1) end

    find_similar_states = fn [_,_]=state ->
      state = state |> sanitize.()
      chain.links |> Map.keys |> Enum.filter(fn s ->
        s |> sanitize.() == state
      end)
    end

    combine_states = fn states ->
      states |> Enum.reduce(%{}, fn state, acc ->
        Map.merge(acc, state, fn _, v1, v2 -> v1 + v2 end)
      end)
    end

    {new_links, _} = chain.links |> Map.keys |> Enum.reduce({%{}, []}, fn k, {map, ignore} ->
      sanitized = sanitize.(k)
      if sanitized in ignore do
        {map, ignore}
      else
        similar = find_similar_states.(k) # also includes `k` itself
        combined = similar |> Enum.map(fn k -> Map.get(chain.links, k) end) |> combine_states.()
        {Map.put(map, sanitized, combined), ignore ++ [sanitized]}
      end
    end)

    %{chain | links: new_links, sanitize_tokens: true}
  end
end