src/erllama_scheduler.erl

%% Copyright (c) 2026 Benoit Chesneau. Licensed under the MIT License.
%% See the LICENSE file at the project root.
%%
-module(erllama_scheduler).
-moduledoc """
Memory-pressure-driven cache eviction.

Periodically polls a pluggable pressure source (`erllama_pressure`)
and, when the used/total ratio reaches or exceeds `high_watermark`,
asks the cache to evict slabs to bring the ratio back toward
`low_watermark`. The eviction call is `erllama_cache:evict_bytes/2`
with a target of `(high - low) * Total` bytes, raised to
`min_evict_bytes` when that product is smaller; the cache may free
less if no evictable slabs remain.

Tier policy: by default the scheduler evicts only `ram` and
`ram_file` slabs. Disk-tier slabs are left in place — disk is the
cheap tier, and the deployment usually wants to keep as much warm
state as possible there. Disk eviction happens via the cache's own
per-tier quota or via an explicit `erllama_cache:gc/0` call.
Override with `evict_tiers => all` (or a custom list) to include
disk in scheduler-driven eviction.

Disabled by default. Enable via the `erllama` app environment:

```
{erllama, [
  {scheduler, #{
    enabled         => true,
    pressure_source => system,    %% noop | system | nvidia_smi | {module, M}
    interval_ms     => 5000,
    high_watermark  => 0.85,
    low_watermark   => 0.75,
    min_evict_bytes => 1048576,   %% lower bound on the eviction target (1 MiB)
    evict_tiers     => [ram, ram_file]
  }}
]}
```

The scheduler always starts (so it can be enabled at runtime via
`enable/1`), but its timer only fires when `enabled = true`.
""".
-behaviour(gen_server).

-export([
    start_link/0,
    start_link/1,
    enable/1,
    set_pressure_source/1,
    set_thresholds/2,
    sample/0,
    force_check/0,
    status/0,
    validate_config/1
]).

-export([init/1, handle_call/3, handle_cast/2, handle_info/2, terminate/2]).

%% Locally registered name of the scheduler process (same as the module).
-define(SERVER, ?MODULE).
%% Defaults used when the corresponding key is absent from the config map.
%% Delay between two periodic checks, in milliseconds.
-define(DEFAULT_INTERVAL_MS, 5000).
%% Used/total ratio at or above which an eviction is requested.
-define(DEFAULT_HIGH, 0.85).
%% Lower ratio; together with the high watermark it sizes the eviction target.
-define(DEFAULT_LOW, 0.75).
%% Lower bound applied to the eviction target, in bytes (1 MiB).
-define(DEFAULT_MIN_EVICT, 1024 * 1024).
%% Tiers handed to the cache when evicting.
-define(DEFAULT_EVICT_TIERS, [ram, ram_file]).

%% gen_server state of the scheduler.
-record(state, {
    %% When false no tick timer is armed and checks report `{skipped, disabled}`.
    enabled :: boolean(),
    %% Passed to `erllama_pressure:sample/1` to obtain `{Used, Total}`.
    pressure_source :: erllama_pressure:source(),
    %% Delay handed to `erlang:send_after/3` between ticks.
    interval_ms :: pos_integer(),
    %% Eviction is requested when Used / Total is at or above this ratio.
    high_watermark :: float(),
    %% With `high_watermark`, sizes the target: (high - low) * Total bytes.
    low_watermark :: float(),
    %% Lower bound applied to the computed eviction target.
    min_evict_bytes :: non_neg_integer(),
    %% Second argument of `erllama_cache:evict_bytes/2`.
    evict_tiers :: all | [erllama_cache:tier()],
    %% Reference of the pending `tick` timer; `undefined` when none is armed.
    timer_ref :: reference() | undefined,
    %% Most recent sample (initialised to 0 / 1 before any sample is taken).
    last_used :: non_neg_integer(),
    last_total :: non_neg_integer(),
    %% Bytes reported by the most recent eviction that freed something.
    last_evicted_bytes :: non_neg_integer(),
    %% Monotonic time (ns) recorded for that eviction; `undefined` until one happens.
    last_evicted_at :: integer() | undefined,
    %% Monotonic time (ns) of the most recent sample; `undefined` until one is taken.
    sampled_at :: integer() | undefined
}).

%% Internal server state.
-type state() :: #state{}.
%% Options accepted by `start_link/1`. Every key is optional; `init/1`
%% substitutes a default for each key that is absent.
-type config() :: #{
    enabled => boolean(),
    pressure_source => erllama_pressure:source(),
    interval_ms => pos_integer(),
    high_watermark => float(),
    low_watermark => float(),
    min_evict_bytes => non_neg_integer(),
    evict_tiers => all | [erllama_cache:tier()]
}.

-export_type([config/0]).

%% =============================================================================
%% Public API
%% =============================================================================

-doc "Start the scheduler with the `scheduler` map from the `erllama` app environment (empty map if unset).".
-spec start_link() -> {ok, pid()} | {error, term()}.
start_link() ->
    Config = env_config(),
    start_link(Config).

-doc "Start the scheduler, registered locally under the module name, with an explicit config map.".
-spec start_link(config()) -> {ok, pid()} | {error, term()}.
start_link(Config) ->
    RegisteredName = {local, ?SERVER},
    InitArgs = [Config],
    gen_server:start_link(RegisteredName, ?MODULE, InitArgs, []).

-doc """
Turn periodic checking on or off at runtime.

Any pending tick timer is cancelled; a new one is armed only when the
scheduler ends up enabled.
""".
-spec enable(boolean()) -> ok.
enable(Enabled) when is_boolean(Enabled) ->
    Request = {enable, Enabled},
    gen_server:call(?SERVER, Request).

-doc "Replace the pressure source used for subsequent samples.".
-spec set_pressure_source(erllama_pressure:source()) -> ok.
set_pressure_source(Source) ->
    Request = {set_source, Source},
    gen_server:call(?SERVER, Request).

-doc """
Update the high and low watermarks.

The pair must satisfy `0.0 =< Low < High =< 1.0`; otherwise nothing is
sent to the server and `{error, bad_thresholds}` is returned.
""".
-spec set_thresholds(float(), float()) -> ok | {error, term()}.
set_thresholds(High, Low) ->
    Valid =
        is_number(High) andalso
            is_number(Low) andalso
            High > Low andalso
            High =< 1.0 andalso
            Low >= 0.0,
    case Valid of
        true ->
            gen_server:call(?SERVER, {set_thresholds, High, Low});
        false ->
            {error, bad_thresholds}
    end.

-doc """
Take a single pressure sample without acting on it.

The server reads its current pressure source, records the reading and
its timestamp in its state, and returns the `{Used, Total}` reading it
just took. No eviction is attempted, and the call works whether or not
the scheduler is enabled.
""".
-spec sample() -> erllama_pressure:reading().
sample() ->
    gen_server:call(?SERVER, sample).

-doc """
Run one check immediately: take a sample and evict if warranted.

Returns `{evicted, Count, Bytes}` as reported by the cache when an
eviction freed something. Otherwise returns `{skipped, Reason}`:
`disabled` when the scheduler is off, `below_watermark` when the
reading does not call for eviction, and `nothing_to_evict` when the
cache reported that it freed nothing.
""".
-spec force_check() ->
    {evicted, non_neg_integer(), non_neg_integer()}
    | {skipped, below_watermark | disabled | nothing_to_evict}.
force_check() ->
    Request = force_check,
    gen_server:call(?SERVER, Request).

-doc "Return a map with the current settings, the last sample and the last eviction.".
-spec status() -> map().
status() ->
    Request = status,
    gen_server:call(?SERVER, Request).

%% =============================================================================
%% gen_server callbacks
%% =============================================================================

%% Validate the config, build the initial state and arm the first tick
%% (only when enabled). An invalid config stops the server with the
%% validation error as reason.
-spec init([config()]) -> {ok, state()} | {stop, term()}.
init([Config]) ->
    case validate_config(Config) of
        ok ->
            {ok, schedule_next(initial_state(Config))};
        {error, Reason} ->
            {stop, Reason}
    end.

%% Build the state record from the config map, falling back to the
%% module defaults for absent keys.
initial_state(Config) ->
    Opt = fun(Key, Default) -> maps:get(Key, Config, Default) end,
    Source = Opt(pressure_source, noop),
    %% The `system` source needs os_mon running before the first sample.
    maybe_start_os_mon(Source),
    #state{
        enabled = Opt(enabled, false),
        pressure_source = Source,
        interval_ms = Opt(interval_ms, ?DEFAULT_INTERVAL_MS),
        high_watermark = Opt(high_watermark, ?DEFAULT_HIGH),
        low_watermark = Opt(low_watermark, ?DEFAULT_LOW),
        min_evict_bytes = Opt(min_evict_bytes, ?DEFAULT_MIN_EVICT),
        evict_tiers = Opt(evict_tiers, ?DEFAULT_EVICT_TIERS),
        last_used = 0,
        %% Starts at 1 so a ratio can be computed before any sample exists.
        last_total = 1,
        last_evicted_bytes = 0
    }.

%% Synchronous requests issued by the public API functions.
handle_call({enable, Flag}, _From, State) ->
    Updated = State#state{enabled = Flag},
    %% Drop any pending tick first; a new one is armed only if now enabled.
    Idle = cancel_timer(Updated),
    {reply, ok, schedule_next(Idle)};
handle_call({set_source, NewSource}, _From, State) ->
    maybe_start_os_mon(NewSource),
    {reply, ok, State#state{pressure_source = NewSource}};
handle_call({set_thresholds, NewHigh, NewLow}, _From, State) ->
    Updated = State#state{high_watermark = NewHigh, low_watermark = NewLow},
    {reply, ok, Updated};
handle_call(sample, _From, State) ->
    %% Sample only: the reading is recorded but never triggers eviction.
    Reading = erllama_pressure:sample(State#state.pressure_source),
    {Used, Total} = Reading,
    Updated = State#state{
        last_used = Used,
        last_total = Total,
        sampled_at = monotonic_ns()
    },
    {reply, Reading, Updated};
handle_call(force_check, _From, State) ->
    {Outcome, Updated} = check_once(State),
    {reply, Outcome, Updated};
handle_call(status, _From, State) ->
    {reply, snapshot(State), State};
handle_call(_Unknown, _From, State) ->
    {reply, {error, unknown_call}, State}.

%% No casts are part of the protocol; anything received is ignored.
handle_cast(_Ignored, State) ->
    {noreply, State}.

%% `tick` is the periodic timer message: run one check, then re-arm the
%% timer. Every other message is dropped.
handle_info(tick, State0) ->
    {_Outcome, State1} = check_once(State0),
    Rearmed = schedule_next(State1),
    {noreply, Rearmed};
handle_info(_Other, State) ->
    {noreply, State}.

%% Nothing to release on shutdown.
terminate(_Why, _State) ->
    ok.

%% =============================================================================
%% Internal
%% =============================================================================

%% Read the `scheduler` entry of the `erllama` app environment. An unset
%% entry, or one that is not a map, yields an empty config.
env_config() ->
    Value = application:get_env(erllama, scheduler, #{}),
    if
        is_map(Value) -> Value;
        true -> #{}
    end.

%% Validate a raw config map before the state record is constructed.
%%
%% Absent keys fall back to the module defaults, which always pass. The
%% checks run in a fixed order (watermarks, interval_ms, enabled,
%% min_evict_bytes, evict_tiers) and the first failure is reported as
%% `{error, {invalid_config, {Key, Explanation}}}`. Validating every
%% field that ends up in the record keeps the record's declared field
%% types accurate and makes a bad value fail at startup instead of on
%% the first eviction. `pressure_source` is not checked here.
-spec validate_config(term()) -> ok | {error, {invalid_config, {atom(), string()}}}.
validate_config(Cfg) when is_map(Cfg) ->
    High = maps:get(high_watermark, Cfg, ?DEFAULT_HIGH),
    Low = maps:get(low_watermark, Cfg, ?DEFAULT_LOW),
    Interval = maps:get(interval_ms, Cfg, ?DEFAULT_INTERVAL_MS),
    Enabled = maps:get(enabled, Cfg, false),
    MinEvict = maps:get(min_evict_bytes, Cfg, ?DEFAULT_MIN_EVICT),
    Tiers = maps:get(evict_tiers, Cfg, ?DEFAULT_EVICT_TIERS),
    %% Each entry is {Passed, ErrorDetail}; order defines error precedence.
    Checks = [
        {watermarks_ok(High, Low), {watermarks, "require 0.0 <= low < high <= 1.0"}},
        {is_integer(Interval) andalso Interval > 0, {interval_ms, "must be a positive integer"}},
        {is_boolean(Enabled), {enabled, "must be a boolean"}},
        {
            is_integer(MinEvict) andalso MinEvict >= 0,
            {min_evict_bytes, "must be a non-negative integer"}
        },
        {Tiers =:= all orelse is_list(Tiers), {evict_tiers, "must be 'all' or a list of tiers"}}
    ],
    case [Detail || {false, Detail} <- Checks] of
        [] ->
            ok;
        [First | _] ->
            {error, {invalid_config, First}}
    end;
validate_config(_NotAMap) ->
    %% This is an exported validator: report a malformed config instead
    %% of crashing with badmap.
    {error, {invalid_config, {config, "must be a map"}}}.

%% True when both values are numbers and 0.0 =< Low < High =< 1.0.
watermarks_ok(High, Low) ->
    is_number(High) andalso
        is_number(Low) andalso
        High > Low andalso
        High =< 1.0 andalso
        Low >= 0.0.

%% Ensure os_mon is started when the `system` pressure source is chosen.
%% The start result is deliberately ignored; every source returns `ok`.
maybe_start_os_mon(Source) ->
    case Source of
        system ->
            _ = application:ensure_all_started(os_mon),
            ok;
        _Other ->
            ok
    end.

%% Cancel the pending tick timer, if any, and clear its reference.
cancel_timer(#state{timer_ref = Ref} = State) ->
    case Ref of
        undefined ->
            State;
        _ ->
            _ = erlang:cancel_timer(Ref),
            State#state{timer_ref = undefined}
    end.

%% Cancel whatever timer is pending, then arm a fresh `tick` after
%% `interval_ms` unless the scheduler is disabled.
schedule_next(#state{} = State) ->
    Idle = cancel_timer(State),
    case Idle#state.enabled of
        false ->
            Idle;
        _ ->
            Ref = erlang:send_after(Idle#state.interval_ms, self(), tick),
            Idle#state{timer_ref = Ref}
    end.

%% Run a single check. Returns `{Outcome, NewState}`.
%% A disabled scheduler does not sample at all. Otherwise the reading is
%% recorded in the state and, unless Total is 0, handed to maybe_evict/4.
check_once(#state{enabled = false} = State) ->
    {{skipped, disabled}, State};
check_once(#state{pressure_source = Source} = State0) ->
    {Used, Total} = erllama_pressure:sample(Source),
    SampledAt = monotonic_ns(),
    State1 = State0#state{
        last_used = Used,
        last_total = Total,
        sampled_at = SampledAt
    },
    if
        %% A zero total would make the ratio undefined; treat it as no pressure.
        Total =:= 0 ->
            {{skipped, below_watermark}, State1};
        true ->
            maybe_evict(Used, Total, State1, SampledAt)
    end.

%% Request an eviction when Used / Total is at or above the high
%% watermark. The target is (high - low) * Total bytes, truncated to an
%% integer and raised to `min_evict_bytes` when smaller.
maybe_evict(Used, Total, State, NowNs) ->
    Ratio = Used / Total,
    High = State#state.high_watermark,
    case Ratio < High of
        true ->
            {{skipped, below_watermark}, State};
        false ->
            Span = High - State#state.low_watermark,
            Wanted = trunc(Span * Total),
            Floor = State#state.min_evict_bytes,
            do_evict(max(Wanted, Floor), State, NowNs)
    end.

%% Ask the cache to free `Target` bytes from the configured tiers.
%% A non-positive target is skipped without calling the cache. A cache
%% reply of `{evicted, 0, 0}` is reported as `nothing_to_evict`; any other
%% `{evicted, _, Bytes}` reply is returned as-is and recorded in the state.
do_evict(Target, State, NowNs) ->
    case Target =< 0 of
        true ->
            {{skipped, below_watermark}, State};
        false ->
            Tiers = State#state.evict_tiers,
            case erllama_cache:evict_bytes(Target, Tiers) of
                {evicted, 0, 0} ->
                    {{skipped, nothing_to_evict}, State};
                {evicted, _Count, Freed} = Result ->
                    Updated = State#state{
                        last_evicted_bytes = Freed,
                        last_evicted_at = NowNs
                    },
                    {Result, Updated}
            end
    end.

%% Render the state as the map returned by status/0. `last_ratio` is
%% derived from the last sample, with the divisor clamped to at least 1.
snapshot(State) ->
    #state{
        enabled = Enabled,
        pressure_source = Source,
        interval_ms = IntervalMs,
        high_watermark = High,
        low_watermark = Low,
        min_evict_bytes = MinEvict,
        evict_tiers = Tiers,
        last_used = Used,
        last_total = Total,
        last_evicted_bytes = EvictedBytes,
        last_evicted_at = EvictedAt,
        sampled_at = SampledAt
    } = State,
    Divisor = max(Total, 1),
    #{
        enabled => Enabled,
        pressure_source => Source,
        interval_ms => IntervalMs,
        high_watermark => High,
        low_watermark => Low,
        min_evict_bytes => MinEvict,
        evict_tiers => Tiers,
        last_used => Used,
        last_total => Total,
        last_ratio => Used / Divisor,
        last_evicted_bytes => EvictedBytes,
        sampled_at => SampledAt,
        last_evicted_at => EvictedAt
    }.

%% Current Erlang monotonic time in nanoseconds; used to stamp samples
%% and evictions.
monotonic_ns() ->
    Unit = nanosecond,
    erlang:monotonic_time(Unit).