defmodule Markov.ModelActions do
@moduledoc """
Performs training, generation and probability shifting. Supposed to only ever
be used by `Markov.ModelServer`s.
"""
alias Markov.ModelServer.State
import Nx.Defn
@nx_batch_size 1024
@doc """
Computes a weight multiplier for every candidate next token.

`rows` is a list of `{to, tag, frequency}` tuples (the shape produced by
`next_state/3`). `tag_scores` is a map of `tag => score`.

For each distinct `to`, the scores of all distinct tags attached to it that
also appear in `tag_scores` are summed, and 1 is added so that a token with no
matching tag still keeps a multiplier of 1.

## Example

    process_scores([{"hello", :tag, 1}, {"world", :tag_two, 1}], %{tag: 2})
    #=> %{"hello" => 3, "world" => 1}
"""
@spec process_scores([{term(), term(), non_neg_integer()}], Markov.tag_query) :: %{term() => non_neg_integer()}
def process_scores(rows, tag_scores) do
  rows
  # %{"hello" => [:tag], "world" => [:tag_two]}
  |> Enum.group_by(fn {to, _tag, _freq} -> to end, fn {_to, tag, _freq} -> tag end)
  |> Map.new(fn {to, tags} ->
    bonus =
      tags
      # a tag seen on several rows for the same token only counts once
      |> Enum.uniq()
      # tags absent from the query contribute nothing
      |> Enum.reduce(0, fn tag, acc -> acc + Map.get(tag_scores, tag, 0) end)

    {to, bonus + 1}
  end)
end
@doc """
Trains the chain on one token sequence.

The sequence is padded with `order` leading `:start` markers and one trailing
`:end` marker, then cut into overlapping windows of `order + 1` tokens. For
every window and every tag, the counter stored in `state.main_table` under
`[from, tag, to]` is created (as 1) or incremented.

Only the `from` part is sanitized when the `:sanitize_tokens` option is set;
`to` is stored as given.

Always returns `:ok`. Any `Sidx.select/2` result other than `{:ok, _}` is not
matched and crashes the caller.
"""
@spec train(state :: State.t(), tokens :: [term()], tags :: [term()]) :: :ok
def train(state, tokens, tags) do
  order = state.options[:order]
  sanitize? = state.options[:sanitize_tokens]
  # List.duplicate/2 yields [] for order 0, unlike 0..(order - 1) which would
  # count downwards and produce two markers
  padded = List.duplicate(:start, order) ++ tokens ++ [:end]

  padded
  |> Markov.ListUtil.overlapping_stride(order + 1)
  |> Enum.each(fn bit ->
    from = Enum.drop(bit, -1)
    to = List.last(bit)
    from = if sanitize?, do: Enum.map(from, &Markov.TextUtil.sanitize_token/1), else: from

    Enum.each(tags, fn tag ->
      keys = [from, tag, to]

      case Sidx.select(state.main_table, keys) do
        {:ok, []} -> Sidx.insert(state.main_table, keys, 1)
        {:ok, [{[], val}]} -> Sidx.insert(state.main_table, keys, val + 1)
      end
    end)
  end)

  :ok
end
@doc """
Generates a token sequence by walking the chain from the start state.

The walk begins with a queue of `order` `:start` markers and is capped at 100
steps. `tag_query` is forwarded to `walk_chain/5` to bias the selection.

Returns `{{:ok, tokens}, state}` or `{{:error, reason}, state}`.
"""
@spec generate(State.t(), Markov.tag_query()) :: {{:ok, [term()]} | {:error, term()}, State.t()}
def generate(state, tag_query) do
  order = state.options[:order]
  # List.duplicate/2 yields [] for order 0, unlike 0..(order - 1) which would
  # count downwards and produce two markers
  initial_queue = List.duplicate(:start, order)
  walk_chain(state, [], initial_queue, 100, tag_query)
end
@doc """
Walks the chain, appending generated tokens to `acc`.

`queue` holds the current context used to look up the next token. The walk
stops when `limit` steps have been taken, when `:end` is drawn, or when
`next_state/3` reports an error.

Returns `{{:ok, tokens}, state}` or `{{:error, reason}, state}`.
"""
@spec walk_chain(State.t(), [term()], [term()], non_neg_integer(), Markov.tag_query())
  :: {{:ok, [term()]} | {:error, term()}, State.t()}
def walk_chain(state, acc, queue, limit, tag_query) do
  # accumulate in reverse internally so each step is a cheap prepend
  do_walk_chain(state, Enum.reverse(acc), queue, limit, tag_query)
end

# Step budget exhausted: stop before doing another table lookup.
defp do_walk_chain(state, reversed_acc, _queue, limit, _tag_query) when limit <= 0 do
  {{:ok, Enum.reverse(reversed_acc)}, state}
end

# Draws the next token and either finishes or recurses with the shifted queue.
defp do_walk_chain(state, reversed_acc, queue, limit, tag_query) do
  case next_state(state, queue, tag_query) do
    {:ok, :end, state} ->
      {{:ok, Enum.reverse(reversed_acc)}, state}

    {:error, err, state} ->
      {{:error, err}, state}

    {:ok, next, state} ->
      # slide the context window: drop the oldest token, append the new one
      shifted = Enum.drop(queue, 1) ++ [next]
      do_walk_chain(state, [next | reversed_acc], shifted, limit - 1, tag_query)
  end
end
# Picks the token that follows the context `current`.
#
# Looks up every `{to, tag, frequency}` row stored for the (optionally
# sanitized) context, optionally reshapes the frequencies via
# `apply_shifting/1`, multiplies each frequency by the tag score of its target
# token and draws one target at random, weighted by the resulting values.
#
# Returns `{:ok, token, state}`, or `{:error, {:no_matches, context}, state}`
# when nothing is stored for the context.
@spec next_state(State.t(), [term()], Markov.tag_query())
  :: {:ok, term(), State.t()} | {:error, term(), State.t()}
def next_state(state, current, tag_query) do
  current =
    if state.options[:sanitize_tokens],
      do: Enum.map(current, &Markov.TextUtil.sanitize_token/1),
      else: current

  case Sidx.select(state.main_table, [current]) do
    {:ok, []} ->
      {:error, {:no_matches, current}, state}

    {:ok, matches} ->
      candidates = Enum.map(matches, fn {[to, tag], freq} -> {to, tag, freq} end)

      candidates =
        if state.options[:shift_probabilities],
          do: apply_shifting(candidates),
          else: candidates

      multipliers = process_scores(candidates, tag_query)

      weighted =
        Enum.map(candidates, fn {to, _tag, freq} ->
          {to, freq * Map.get(multipliers, to)}
        end)

      total = Enum.reduce(weighted, 0, fn {_to, weight}, running -> running + weight end)
      pick = :rand.uniform(total) - 1
      {:ok, probabilistic_select(pick, weighted, total), state}
  end
end
# Returns the name of the choice whose weight interval contains `number`.
#
# The choices are laid out back to back on a number line starting at `acc`;
# each `{name, weight}` occupies `weight` consecutive integers. `sum` is only
# threaded through the recursion and never inspected.
@spec probabilistic_select(integer(), list({any(), integer()}), integer(), integer()) :: any()
defp probabilistic_select(number, choices, sum, acc \\ 0)

defp probabilistic_select(number, [{name, add} | _rest], _sum, acc)
     when number >= acc and number < acc + add,
     do: name

defp probabilistic_select(number, [{_name, add} | rest], sum, acc),
  do: probabilistic_select(number, rest, sum, acc + add)
@doc "Adjusts the probability of one connection"
# Computes the shifted frequency for a single connection.
#
# `param_tensor` is read by position; the layout matches the rows built in
# `apply_shifting/1` (`[idx | constant_params]`):
#   0: i          - position of this connection in the frequency-sorted list
#   1: peak       - position that should receive the highest shifted value
#   2: peak_prob  - original frequency of the connection at `peak`
#   3: first_prob - original frequency of the most frequent connection
#   4: ratio      - first_prob / peak_prob, capped at 5 by the caller
#   5: len        - total number of connections
#
# Returns a {1}-shaped tensor holding the rounded result.
defn adjust_one_prob(param_tensor) do
# unpack the six scalars; gather + squeeze turns each into a rank-0 tensor
i = Nx.gather(param_tensor, Nx.tensor([[0]])) |> Nx.squeeze
peak = Nx.gather(param_tensor, Nx.tensor([[1]])) |> Nx.squeeze
peak_prob = Nx.gather(param_tensor, Nx.tensor([[2]])) |> Nx.squeeze
first_prob = Nx.gather(param_tensor, Nx.tensor([[3]])) |> Nx.squeeze
ratio = Nx.gather(param_tensor, Nx.tensor([[4]])) |> Nx.squeeze
len = Nx.gather(param_tensor, Nx.tensor([[5]])) |> Nx.squeeze
# fixed exponent used by the rising branch below
power = Nx.tensor(1.7, type: :f32)
# https://www.desmos.com/calculator/mq3qjg8zpm
result = cond do
# before the peak: a power curve in `i` lifted by a constant offset
i < peak ->
offset = (first_prob / (ratio ** power))
coeff = (peak_prob - (peak_prob * ratio / (ratio ** power + 1))) / (peak ** ratio)
(coeff * (i ** ratio)) + offset
# at the peak: keep the peak's own frequency
i == peak -> peak_prob
# after the peak: a root curve that reaches 0 at the last index (i == len - 1)
i > peak ->
coeff = peak_prob / ((len - peak) ** (1 / ratio))
coeff * ((-i + len - 1) ** (1 / ratio))
# hopefully never reached
true -> Nx.Constants.nan
end
# round off and convert scalar to {1}-shape
Nx.round(result) |> Nx.tile([1])
end
@doc """
Adjusts the probabilities of a batch of connections.

`params` is a 2-D tensor with one row per connection, each row laid out as
expected by `adjust_one_prob/1`. Returns a `u32` tensor of length
`@nx_batch_size` in which entry `k` is the adjusted value for row `k` of
`params`.

Results are stored at the batch-local row position rather than at the
connection's global index (column 0 of the row). The global index only equals
the row position in the first batch; for later batches it lies past the end of
the result tensor.

When `params` has fewer than `@nx_batch_size` rows, the trailing entries of the
result carry no meaning and must be ignored by the caller.
"""
defn adjust_batch_probs(params) do
  results = Nx.iota({@nx_batch_size}, type: :u32) |> Nx.map(fn _ -> -1 end)

  {_, _, results} =
    while {i = 0, params, results}, i < @nx_batch_size do
      result =
        adjust_one_prob(params[i])
        |> Nx.as_type(:u32)

      # write at the loop position so row k of the batch maps to result k
      {i + 1, params, Nx.put_slice(results, [i], result)}
    end

  results
end
# Reshapes the frequency distribution of a set of candidate connections.
#
# `rows` is a non-empty list of `{to, tag, frequency}` tuples. The rows are
# sorted by descending frequency, a "peak" position is chosen from the row
# count, and every row's frequency is replaced by the value computed by
# `adjust_batch_probs/1` for its position. The result keeps the sorted order
# and the `{to, tag, frequency}` shape.
@spec apply_shifting([{term(), term(), non_neg_integer()}])
  :: [{term(), term(), non_neg_integer()}]
defp apply_shifting(rows) do
  # most frequent first; sort_by/3 keeps the incoming order of equal frequencies
  rows = Enum.sort_by(rows, fn {_to, _tag, freq} -> freq end, :desc)
  row_count = length(rows)

  # choose the peak: position 0 for a single row, otherwise at least position 1
  min_allowed = if row_count == 1, do: 0, else: 1
  peak = max(min_allowed, :math.sqrt(row_count) * 0.1) |> floor()
  {_, _, peak_prob} = Enum.at(rows, peak)

  # how much more likely the most probable path is than the peak, capped at 5
  {_, _, first_prob} = Enum.at(rows, 0)
  ratio = min(first_prob / peak_prob, 5)

  jitted = EXLA.jit(&Markov.ModelActions.adjust_batch_probs/1)
  constant_params = [peak, peak_prob, first_prob, ratio, row_count]

  rows
  |> Stream.with_index()
  |> Stream.chunk_every(@nx_batch_size)
  |> Stream.flat_map(fn batch ->
    processed =
      batch
      # one parameter row per connection: its global position plus the constants
      |> Enum.map(fn {_row, idx} -> [idx | constant_params] end)
      |> Nx.tensor(type: :f32)
      |> jitted.()
      |> Nx.to_flat_list()

    # zip stops at the end of the batch, dropping unused trailing results
    batch
    |> Enum.zip(processed)
    |> Enum.map(fn {{{to, tag, _old_freq}, _idx}, freq} -> {to, tag, freq} end)
  end)
  |> Enum.to_list()
end
end