defmodule Markov.ModelActions do
@moduledoc """
Performs training, generation and probability shifting. Intended to be used
only by `Markov.ModelServer`.
"""
alias Markov.ModelServer.State
use Amnesia
alias Markov.Database.{Link, Master, Operation, Weight}
import Nx.Defn
@nx_batch_size 1024
# WARNING: match specifications ahead
@doc """
Converts a tag query into a match specification over `Link` records.

The specification matches links whose `mod_from` equals `mf`, filters them
with the condition produced by `tq2msc/1` and yields `{tag, to}` tuples.
"""
@spec tq2ms({term(), [term()]}, Markov.tag_query()) :: :ets.match_spec()
def tq2ms(mf, query) do
  # "$1" is bound to the tag and "$2" to the target token
  head = {Link, mf, :"$1", :"$2"}
  guards = [tq2msc(query)]
  # the doubled braces make the match spec body return a tuple
  body = [{{:"$1", :"$2"}}]

  [{head, guards, body}]
end
@doc """
Converts a tag query into a match specification condition.

The condition refers to the tag as `:"$1"`. `true` matches every tag,
`{:not, q}` negates, `{a, :or, b}` is a disjunction, `{q, :score, _}`
reduces to the condition of `q`, and any other term is compared with the
tag for equality.
"""
@spec tq2msc(Markov.tag_query()) :: term()
def tq2msc(query) do
  case query do
    true ->
      # always-true guard
      {:==, 1, 1}

    {:not, inner} ->
      {:not, tq2msc(inner)}

    {left, :or, right} ->
      {:orelse, tq2msc(left), tq2msc(right)}

    {scored, :score, _weights} ->
      # scores do not restrict the selection, only the base query does
      tq2msc(scored)

    tag ->
      # :const keeps the tag from being interpreted as a match spec expression
      {:==, :"$1", {:const, tag}}
  end
end
@doc """
Processes `{_, :score, _}` tag queries.

For every `{query, score}` pair of the scored query, selects the targets of
all links from `mf` whose tag satisfies `query`. Returns a map from target
to the sum of the scores of all pairs that selected it, restricted to the
targets present in `rows`. Any other kind of tag query yields an empty map.
"""
def process_scores(mf, rows, {_, :score, queries}) do
  scored_targets =
    for {query, score} <- queries do
      spec = [{{Link, mf, :"$1", :"$2"}, [tq2msc(query)], [:"$2"]}]
      targets = spec |> Link.select() |> Amnesia.Selection.values() |> MapSet.new()
      {targets, score}
    end

  candidates = MapSet.new(for {to, _} <- rows, do: to)

  # accumulate the score of every sub-query that hit a candidate target
  for {targets, score} <- scored_targets,
      to <- MapSet.intersection(candidates, targets),
      reduce: %{} do
    totals -> Map.put(totals, to, Map.get(totals, to, 0) + score)
  end
end

def process_scores(_mf, _rows, _query), do: %{}
@doc """
Trains the model `state.name` on one token sequence.

The sequence is prefixed with `order` `:start` markers and suffixed with
`:end`, then handed to `Markov.ListUtil.overlapping_stride/2` with a width of
`order + 1`. For every resulting chunk and every tag, a link from all tokens
but the last one to the last one is written and its weight counter is
incremented by one.

When the `:sanitize_tokens` option is set, only the source tokens of a link
are sanitized; the target token is stored as given.

Always returns `:ok`.
"""
@spec train(state :: State.t(), tokens :: [term()], tags :: [term()]) :: :ok
def train(state, tokens, tags) do
  order = state.options[:order]
  # read once here so that the Flow closure does not have to carry the state
  sanitize = state.options[:sanitize_tokens]
  name = state.name

  # List.duplicate/2 yields exactly `order` markers; a `0..(order - 1)` range
  # counts down and yields two markers when `order` is 0
  padded = List.duplicate(:start, order) ++ tokens ++ [:end]

  padded
  |> Markov.ListUtil.overlapping_stride(order + 1)
  |> Flow.from_enumerable()
  |> Flow.map(fn bit ->
    # split off the last token without relying on a negative-step range
    {from, rest} = Enum.split(bit, -1)
    to = List.last(rest)

    from = if sanitize, do: Enum.map(from, &Markov.TextUtil.sanitize_token/1), else: from
    mf = {name, from}

    Amnesia.ets do
      for tag <- tags do
        link = %Link{mod_from: mf, tag: tag, to: to} |> Link.write
        :mnesia.dirty_update_counter(Weight, link, 1)
      end
    end
  end)
  |> Flow.run()

  :ok
end
@doc """
Generates a token sequence by walking the chain from the start state.

The walk begins with a window of `order` `:start` markers and takes at most
`limit` steps (100 unless specified). Returns the result of `walk_chain/5`:
`{{:ok, tokens}, state}` or `{{:error, reason}, state}`.
"""
@spec generate(State.t(), Markov.tag_query()) :: {{:ok, [term()]} | {:error, term()}, State.t()}
@spec generate(State.t(), Markov.tag_query(), non_neg_integer())
  :: {{:ok, [term()]} | {:error, term()}, State.t()}
def generate(state, tag_query, limit \\ 100) do
  # exactly `order` markers, also when `order` is 0 (a `0..(order - 1)` range
  # would count down and yield two of them)
  initial_queue = List.duplicate(:start, state.options[:order])
  walk_chain(state, [], initial_queue, limit, tag_query)
end
@doc """
Walks the chain starting from the window `queue`, appending every generated
token to `acc`.

Stops and returns `{{:ok, tokens}, state}` when `:end` is generated or when
`limit` steps have been taken; returns `{{:error, reason}, state}` as soon as
`next_state/3` reports an error.
"""
@spec walk_chain(State.t(), [term()], [term()], non_neg_integer(), Markov.tag_query())
  :: {{:ok, [term()]} | {:error, term()}, State.t()}
def walk_chain(state, acc, queue, limit, tag_query) do
  # the helper accumulates in reverse so that each step is a cheap prepend
  do_walk_chain(state, Enum.reverse(acc), queue, limit, tag_query)
end

# Stops before querying the next state once the step budget is used up.
defp do_walk_chain(state, reversed_acc, _queue, limit, _tag_query) when limit <= 0 do
  {{:ok, Enum.reverse(reversed_acc)}, state}
end

# Takes one step and recurses with the window shifted by one token.
defp do_walk_chain(state, reversed_acc, queue, limit, tag_query) do
  case next_state(state, queue, tag_query) do
    {:ok, :end, state} ->
      {{:ok, Enum.reverse(reversed_acc)}, state}

    {:error, err, state} ->
      {{:error, err}, state}

    {:ok, next, state} ->
      # drop the oldest token of the window and push the new one
      shifted = Enum.drop(queue, 1) ++ [next]
      do_walk_chain(state, [next | reversed_acc], shifted, limit - 1, tag_query)
  end
end
@doc """
Randomly picks the token that follows the window `current`.

Links from `current` (sanitized first when the `:sanitize_tokens` option is
set) that satisfy `tag_query` are looked up together with their weights.
The weights are optionally reshaped by `apply_shifting/1` (option
`:shift_probabilities`), multiplied by one plus the score computed by
`process_scores/3`, and used for a weighted random choice.

Returns `{:ok, token, state}`, or `{:error, {:no_matches, current}, state}`
when no link matches.
"""
@spec next_state(State.t(), [term()], Markov.tag_query())
  :: {:ok, term(), State.t()} | {:error, term(), State.t()}
def next_state(state, current, tag_query) do
  options = state.options

  current =
    if options[:sanitize_tokens],
      do: Enum.map(current, &Markov.TextUtil.sanitize_token/1),
      else: current

  mf = {state.name, current}

  Amnesia.ets do
    matches = mf |> tq2ms(tag_query) |> Link.select() |> Amnesia.Selection.values()

    case matches do
      [] ->
        {:error, {:no_matches, current}, state}

      links ->
        # attach the stored weight to every matching link
        weighted =
          Enum.map(links, fn {tag, to} ->
            %Weight{value: frequency} = Weight.read(%Link{mod_from: mf, tag: tag, to: to})
            {to, frequency}
          end)

        weighted = if options[:shift_probabilities], do: apply_shifting(weighted), else: weighted
        bonuses = process_scores(mf, weighted, tag_query)

        # a target without a score keeps its weight (multiplier of 1)
        candidates =
          Enum.map(weighted, fn {to, frequency} ->
            {to, frequency * (Map.get(bonuses, to, 0) + 1)}
          end)

        total = candidates |> Enum.map(&elem(&1, 1)) |> Enum.sum()
        # :rand.uniform/1 is 1-based, the selection expects a 0-based number
        {:ok, probabilistic_select(:rand.uniform(total) - 1, candidates, total), state}
    end
  end
end
# Walks the weighted choices and returns the name of the one whose weight
# interval `[acc, acc + weight)` contains `number`. `acc` is the sum of the
# weights already skipped; `sum` is only passed along.
@spec probabilistic_select(integer(), list({any(), integer()}), integer(), integer()) :: any()
defp probabilistic_select(number, choices, sum, acc \\ 0)

defp probabilistic_select(number, [{name, weight} | _rest], _sum, acc)
     when number >= acc and number < acc + weight,
     do: name

defp probabilistic_select(number, [{_name, weight} | rest], sum, acc),
  do: probabilistic_select(number, rest, sum, acc + weight)
# Computes the shifted weight of a single row.
#
# `param_tensor` is read as a 6-element vector
# `[i, peak, peak_prob, first_prob, ratio, len]`, which is the layout that
# `apply_shifting/1` builds (`[idx | constant_params]`):
#   * `i`          - position of the row in the list sorted by weight
#   * `peak`       - position chosen as the peak
#   * `peak_prob`  - weight of the row at the peak position
#   * `first_prob` - weight of the first (heaviest) row
#   * `ratio`      - `first_prob / peak_prob`, capped at 5 by the caller
#   * `len`        - total number of rows
#
# Returns the rounded result as a tensor of shape `{1}`.
@doc "Adjusts the probability of one connection"
defn adjust_one_prob(param_tensor) do
  # unpack the parameter vector into scalars
  i = Nx.gather(param_tensor, Nx.tensor([[0]])) |> Nx.squeeze
  peak = Nx.gather(param_tensor, Nx.tensor([[1]])) |> Nx.squeeze
  peak_prob = Nx.gather(param_tensor, Nx.tensor([[2]])) |> Nx.squeeze
  first_prob = Nx.gather(param_tensor, Nx.tensor([[3]])) |> Nx.squeeze
  ratio = Nx.gather(param_tensor, Nx.tensor([[4]])) |> Nx.squeeze
  len = Nx.gather(param_tensor, Nx.tensor([[5]])) |> Nx.squeeze
  # previous implementation, kept for reference:
  # linear approximation
  # result = cond do
  # i < peak ->
  # a = peak_prob / ratio
  # k = (peak_prob - a) / peak
  # k * i + a
  #
  # i == peak -> peak_prob
  #
  # i > peak ->
  # last = len - 1
  # peak_to_last = last - peak
  # k = -Nx.min((ratio - 1) / peak_to_last, peak_prob / peak_to_last)
  # a = -k + peak_prob
  # k * i + a
  #
  # # hopefully never reached
  # true -> Nx.Constants.nan
  # end
  # exponent applied to `ratio` in the branch before the peak
  power = Nx.tensor(1.7, type: :f32)
  # plot of the curve: https://www.desmos.com/calculator/mq3qjg8zpm
  result = cond do
    # before the peak: power curve in `i` on top of a constant offset
    i < peak ->
      offset = (first_prob / (ratio ** power))
      coeff = (peak_prob - (peak_prob * ratio / (ratio ** power + 1))) / (peak ** ratio)
      (coeff * (i ** ratio)) + offset
    # the peak itself keeps its original weight
    i == peak -> peak_prob
    # after the peak: root curve whose base `-i + len - 1` reaches zero at
    # the last position (`i == len - 1`)
    i > peak ->
      coeff = peak_prob / ((len - peak) ** (1 / ratio))
      coeff * ((-i + len - 1) ** (1 / ratio))
    # hopefully never reached
    true -> Nx.Constants.nan
  end
  # round off and convert scalar to {1}-shape
  Nx.round(result) |> Nx.tile([1])
end
@doc """
Adjust the probabilities of a batch of connections.

`params` holds one parameter row per connection, in the layout expected by
`adjust_one_prob/1`. Returns a `u32` tensor of `@nx_batch_size` elements in
which element `n` is the adjusted weight of row `n` of `params`.

The loop always runs `@nx_batch_size` times. Elements past the number of
rows in `params` carry no meaning and must be ignored by the caller.
"""
defn adjust_batch_probs(params) do
  results = Nx.iota({@nx_batch_size}, type: :u32) |> Nx.map(fn _ -> -1 end)

  {_, _, results} =
    while {i = 0, params, results}, i < @nx_batch_size do
      result =
        adjust_one_prob(params[i])
        |> Nx.as_type(:u32)

      # Store at the batch-local position `i`. The first column of a row is
      # its global index, which is >= @nx_batch_size from the second batch on;
      # using it as the write position makes put_slice clamp every write to
      # the last slot.
      {i + 1, params, Nx.put_slice(results, [i], result)}
    end

  results
end
# Reshapes the weights of `rows` around a peak position.
#
# Rows are sorted by descending weight, a peak position is derived from the
# number of rows, and every row is given a new weight by
# `adjust_batch_probs/1`, evaluated in batches of `@nx_batch_size` through
# EXLA. Returns the rows in sorted order with their new weights.
@spec apply_shifting([{[term()], non_neg_integer()}])
  :: [{[term()], non_neg_integer()}]
defp apply_shifting(rows) do
  # heaviest rows first
  sorted = Enum.sort(rows, fn {_, left}, {_, right} -> left > right end)
  count = length(sorted)

  # a single row can only peak at position 0, otherwise the peak is at least 1
  lowest_peak = if count == 1, do: 0, else: 1
  peak = floor(max(lowest_peak, :math.sqrt(count) * 0.1))
  {_, peak_prob} = Enum.at(sorted, peak)

  # how much more likely the heaviest row is than the peak, capped at 5
  {_, first_prob} = Enum.at(sorted, 0)
  ratio = min(first_prob / peak_prob, 5)

  shift = EXLA.jit(&Markov.ModelActions.adjust_batch_probs/1)
  shared = [peak, peak_prob, first_prob, ratio, count]

  sorted
  |> Enum.with_index()
  |> Enum.chunk_every(@nx_batch_size)
  |> Enum.flat_map(fn batch ->
    # one parameter row per connection: its index followed by the shared values
    shifted =
      batch
      |> Enum.map(fn {_row, idx} -> [idx | shared] end)
      |> Nx.tensor(type: :f32)
      |> shift.()
      |> Nx.to_flat_list()

    # the zip drops the results that do not belong to a row of this batch
    batch
    |> Enum.zip(shifted)
    |> Enum.map(fn {{{to, _old_weight}, _idx}, weight} -> {to, weight} end)
  end)
end
@doc """
Deletes everything stored for the model `name`.

Inside one transaction, removes all links of the model together with their
weights, then the model's `Master` and `Operation` entries.
"""
@spec nuke(name :: term()) :: :ok
def nuke(name) do
  # match spec: every link whose `mod_from` is `{name, _}`, returned as
  # `{from, tag, to}`
  links_of_model = [
    {
      {Link, {name, :"$1"}, :"$2", :"$3"},
      [],
      [{{:"$1", :"$2", :"$3"}}]
    }
  ]

  Amnesia.transaction do
    links_of_model
    |> Link.select()
    |> Amnesia.Selection.values()
    |> Enum.each(fn {from, tag, to} ->
      mod_from = {name, from}
      Link.delete(mod_from)
      # weights are keyed by the link record itself
      Weight.delete(%Link{mod_from: mod_from, tag: tag, to: to})
    end)

    Master.delete(name)
    Operation.delete(name)
    :ok
  end
end
end