src/erllama_cache_policy.erl

%% Copyright (c) 2026 Benoit Chesneau. Licensed under the MIT License.
%% See the LICENSE file at the project root.
%%
%% @doc
%% Pure-Erlang policy decisions for the erllama_cache subsystem.
%%
%% Three responsibilities:
%%
%%   1. Boundary trim: cold saves persist a *trimmed, aligned prefix* of
%%      the prompt rather than the full live token list, so the next
%%      request whose prompt is a textual extension of this one still
%%      lands on the saved cache key after BPE retokenisation. The
%%      algorithm trims a fixed number of tokens off the tail and
%%      aligns the result down to a multiple of a configured chunk.
%%
%%   2. Save-reason gating: cold/continued/finish saves each have a
%%      simple guard (token-count thresholds and intervals). Eviction
%%      and shutdown saves are unconditional and do not pass through
%%      this module.
%%
%%   3. Config validation: validate_config/1 checks that the required
%%      keys are present and that their values are well-typed and
%%      consistently ordered.
%%
%% This module has no side effects; everything is testable as plain
%% data transformations.
%% @end
-module(erllama_cache_policy).

-export([
    trim_boundary/3,
    cold_save_split/2,
    should_continued_save/3,
    should_finish_save/2,
    validate_config/1
]).

-export_type([config/0, token/0]).

%% A single token, represented as a non-negative integer.
-type token() :: non_neg_integer().

%% Cache-policy configuration. The `:=` keys are the ones validate_config/1
%% requires. How this module uses each key:
%%   min_tokens             - floor on the live count for continued and
%%                            finish saves.
%%   cold_min_tokens,
%%   cold_max_tokens        - inclusive window of prompt lengths for which
%%                            cold_save_split/2 will consider a save.
%%   continued_interval     - new tokens required since the last save
%%                            before a continued save fires.
%%   boundary_trim_tokens   - tokens dropped off the tail before aligning.
%%   boundary_align_tokens  - the trimmed length is rounded down to a
%%                            multiple of this.
%%   session_resume_wait_ms - optional; neither read nor validated in this
%%                            module (presumably consumed elsewhere).
-type config() :: #{
    min_tokens := non_neg_integer(),
    cold_min_tokens := non_neg_integer(),
    cold_max_tokens := non_neg_integer(),
    continued_interval := pos_integer(),
    boundary_trim_tokens := non_neg_integer(),
    boundary_align_tokens := pos_integer(),
    session_resume_wait_ms => non_neg_integer()
}.

%% =============================================================================
%% Boundary trim
%% =============================================================================

%% Return the boundary-aligned prefix of Tokens that a cold save should
%% persist: drop `Trim` tokens off the tail, then round the remaining
%% length down to a multiple of `Align`.
%%
%% Returns {ok, Prefix}, or {skip, too_short} when fewer than `Align`
%% tokens remain after trimming. Non-list Tokens, a negative or
%% non-integer Trim, or a non-positive or non-integer Align fail the
%% guard (function_clause).
-spec trim_boundary([token()], non_neg_integer(), pos_integer()) ->
    {ok, [token()]} | {skip, too_short}.
trim_boundary(Tokens, Trim, Align) when
    is_list(Tokens), is_integer(Trim), is_integer(Align), Trim >= 0, Align > 0
->
    case trim_count(length(Tokens), Trim, Align) of
        {ok, Keep} -> {ok, lists:sublist(Tokens, Keep)};
        {skip, _} = Skip -> Skip
    end.

%% Compute how many leading tokens to keep: Len minus Trim, rounded down
%% to a multiple of Align. Yields {skip, too_short} when fewer than Align
%% tokens would remain, which includes the case where Trim exceeds Len.
-spec trim_count(non_neg_integer(), non_neg_integer(), pos_integer()) ->
    {ok, non_neg_integer()} | {skip, too_short}.
trim_count(Len, Trim, Align) ->
    case Len - Trim of
        Avail when Avail >= Align ->
            %% Avail - Avail rem Align is the same as (Avail div Align) * Align.
            {ok, Avail - Avail rem Align};
        _TooFew ->
            {skip, too_short}
    end.

%% =============================================================================
%% Save-reason gating
%% =============================================================================

%% Decide whether a cold save fires for the prompt Tokens. If it does,
%% split the prompt into {trim, Prefix, Rest}. Prefix is the
%% boundary-aligned prefix to pack/save, and Rest holds the tokens still
%% to be prefilled into the live context.
%%
%% Returns no_save in two cases:
%%   * the prompt length is outside the inclusive
%%     [cold_min_tokens, cold_max_tokens] window;
%%   * fewer than boundary_align_tokens tokens remain after trimming
%%     boundary_trim_tokens off the tail.
%%
%% All four config keys are read up front. A missing key therefore
%% raises {badkey, Key} even when the length check alone would have
%% returned no_save.
-spec cold_save_split([token()], config()) ->
    {trim, [token()], [token()]} | no_save.
cold_save_split(Tokens, Cfg) ->
    Len = length(Tokens),
    ColdMin = maps:get(cold_min_tokens, Cfg),
    ColdMax = maps:get(cold_max_tokens, Cfg),
    TrimBy = maps:get(boundary_trim_tokens, Cfg),
    AlignTo = maps:get(boundary_align_tokens, Cfg),
    if
        Len >= ColdMin, Len =< ColdMax ->
            case trim_count(Len, TrimBy, AlignTo) of
                {ok, Keep} ->
                    {Prefix, Rest} = lists:split(Keep, Tokens),
                    {trim, Prefix, Rest};
                {skip, _} ->
                    no_save
            end;
        true ->
            no_save
    end.

%% Decide whether a continued save should fire. Both conditions must hold:
%%   * at least `continued_interval` tokens have been added since the
%%     last save (LiveCount - LastSavedAtCount);
%%   * the live sequence has reached the global `min_tokens` floor.
%% When LiveCount is below LastSavedAtCount, the delta is negative, so
%% the result is false for any positive `continued_interval`.
-spec should_continued_save(non_neg_integer(), non_neg_integer(), config()) ->
    boolean().
should_continued_save(LiveCount, LastSavedAtCount, Cfg) when
    is_integer(LiveCount),
    is_integer(LastSavedAtCount),
    LiveCount >= 0,
    LastSavedAtCount >= 0
->
    Interval = maps:get(continued_interval, Cfg),
    MinTokens = maps:get(min_tokens, Cfg),
    SinceLastSave = LiveCount - LastSavedAtCount,
    if
        SinceLastSave < Interval -> false;
        true -> LiveCount >= MinTokens
    end.

%% Decide whether a finish save should fire at successful end-of-stream.
%% It fires when the live sequence length has reached the global
%% `min_tokens` floor.
-spec should_finish_save(non_neg_integer(), config()) -> boolean().
should_finish_save(LiveCount, Cfg) when is_integer(LiveCount), LiveCount >= 0 ->
    MinTokens = maps:get(min_tokens, Cfg),
    MinTokens =< LiveCount.

%% =============================================================================
%% Config validation
%% =============================================================================

%% Validate a cache-policy config map.
%%
%% First confirms that every required key is present. If any are absent,
%% returns {error, {missing_keys, Keys}}, listing the keys in the order
%% below. Otherwise delegates the value checks to check_invariants/1.
%% The optional session_resume_wait_ms key is not inspected here.
-spec validate_config(map()) -> ok | {error, term()}.
validate_config(Cfg) ->
    IsAbsent = fun(Key) -> not maps:is_key(Key, Cfg) end,
    Absent = lists:filter(IsAbsent, [
        min_tokens,
        cold_min_tokens,
        cold_max_tokens,
        continued_interval,
        boundary_trim_tokens,
        boundary_align_tokens
    ]),
    case Absent of
        [_ | _] -> {error, {missing_keys, Absent}};
        [] -> check_invariants(Cfg)
    end.

%% Check the value-level invariants of a config whose required keys are
%% already present. validate_config/1 verifies presence before calling
%% here; a missing key would make maps:get/2 raise {badkey, Key}.
%%
%% Each value is checked for type/sign before it is compared with its
%% neighbour. A non-integer cold_min_tokens or cold_max_tokens is
%% therefore reported as {invalid, Key, Value}, not as a misleading
%% ordering violation. Checks run in this order, and the first failure
%% wins:
%%
%%   min_tokens            integer >= 0
%%   cold_min_tokens       integer, then >= min_tokens
%%   cold_max_tokens       integer, then >= cold_min_tokens
%%   continued_interval    integer > 0
%%   boundary_trim_tokens  integer >= 0
%%   boundary_align_tokens integer > 0
%%
%% Returns ok or {error, Reason}. Reason is either {invalid, Key, Value},
%% {ordering, cold_min_tokens_lt_min_tokens}, or
%% {ordering, cold_max_tokens_lt_cold_min_tokens}.
-spec check_invariants(config()) -> ok | {error, term()}.
check_invariants(Cfg) ->
    Min = maps:get(min_tokens, Cfg),
    ColdMin = maps:get(cold_min_tokens, Cfg),
    ColdMax = maps:get(cold_max_tokens, Cfg),
    Interval = maps:get(continued_interval, Cfg),
    Trim = maps:get(boundary_trim_tokens, Cfg),
    Align = maps:get(boundary_align_tokens, Cfg),
    %% The list is built eagerly. That is safe because Erlang term
    %% comparison never raises, even on non-numeric values; only the
    %% first failing entry is reported.
    Checks = [
        {is_integer(Min) andalso Min >= 0, {invalid, min_tokens, Min}},
        {is_integer(ColdMin), {invalid, cold_min_tokens, ColdMin}},
        {ColdMin >= Min, {ordering, cold_min_tokens_lt_min_tokens}},
        {is_integer(ColdMax), {invalid, cold_max_tokens, ColdMax}},
        {ColdMax >= ColdMin, {ordering, cold_max_tokens_lt_cold_min_tokens}},
        {is_integer(Interval) andalso Interval > 0, {invalid, continued_interval, Interval}},
        {is_integer(Trim) andalso Trim >= 0, {invalid, boundary_trim_tokens, Trim}},
        {is_integer(Align) andalso Align > 0, {invalid, boundary_align_tokens, Align}}
    ],
    case lists:dropwhile(fun({Pass, _}) -> Pass end, Checks) of
        [] -> ok;
        [{_, Reason} | _] -> {error, Reason}
    end.