src/erllama_nif.erl

%% Copyright (c) 2026 Benoit Chesneau. Licensed under the MIT License.
%% See the LICENSE file at the project root.
%%
%% @doc
%% Single NIF entry module for erllama.
%%
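%% Core v0.2 surface (post step 2b). The other exported functions
%% (detokenize, prefill/decode, kv_pack/4, kv_seq_rm, chat templates,
%% embeddings, samplers and LoRA adapters) are documented next to
%% their definitions below.
%%
%%   crc32c/1            CRC32C of an iodata, dirty CPU.
%%   fsync_dir/1         fsync of a directory (dirty IO).
%%   load_model/2        path + opts -> {ok, ModelRef} | {error, _}.
%%   free_model/1        eager release (returns {ok, deferred} while
%%                       LoRA adapters still reference the model);
%%                       resource also freed on GC.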
%%   new_context/2       model + opts -> {ok, CtxRef} | {error, _}.
%%   free_context/1      eager release; resource also freed on GC.
%%   tokenize/3          model + text + opts -> [token_id()].
%%   kv_pack/3           ctx + tokens + n_tokens -> binary().
%%                       (Tokens/NTokens are informational; the in-
%%                       memory llama API saves whatever is currently
%%                       in the context's seq_id=0 KV cache. The model
%%                       layer prefills exactly the desired prefix
%%                       before calling.)
%%   kv_unpack/3         ctx + binary + seq_id -> ok | {error, _}.
%% @end
-module(erllama_nif).

-export([
    crc32c/1,
    fsync_dir/1,
    load_model/2,
    free_model/1,
    new_context/2,
    free_context/1,
    tokenize/3,
    detokenize/2,
    prefill/2,
    decode_one/1,
    kv_pack/3,
    kv_pack/4,
    kv_unpack/3,
    kv_seq_rm/4,
    apply_chat_template/2,
    embed/2,
    set_grammar/2,
    configure_sampler/2,
    clear_sampler/1,
    adapter_load/2,
    adapter_free/1,
    set_adapters/2,
    sampler_new/2,
    sampler_free/1
]).

-export_type([
    model_ref/0,
    context_ref/0,
    adapter_ref/0,
    sampler_ref/0,
    token_id/0
]).

-on_load(init/0).

%% Handles to NIF-managed resources; on the Erlang side they are
%% plain references and should be treated as opaque.
-type model_ref() :: reference().
-type context_ref() :: reference().
-type adapter_ref() :: reference().
-type sampler_ref() :: reference().

%% Integer id of a token in the model's vocabulary.
-type token_id() :: integer().

%% on_load hook: load the `erllama_nif` shared object from the
%% application's priv directory.
-spec init() -> ok | {error, term()}.
init() ->
    erlang:load_nif(filename:join(find_priv_dir(), "erllama_nif"), 0).

%% Resolve the priv directory for init/0. When the code server does
%% not know the `erllama` application, code:priv_dir/1 answers
%% `{error, bad_name}`. In that case fall back to `<AppDir>/priv`,
%% where AppDir is the parent of the directory holding this beam.
-spec find_priv_dir() -> file:filename().
find_priv_dir() ->
    find_priv_dir(code:priv_dir(erllama)).

find_priv_dir({error, bad_name}) ->
    AppDir = filename:dirname(filename:dirname(code:which(?MODULE))),
    filename:join(AppDir, "priv");
find_priv_dir(PrivDir) ->
    PrivDir.

%% =============================================================================
%% Public API
%% =============================================================================

%% CRC32C checksum of Bytes (any iodata), computed in the NIF.
-spec crc32c(iodata()) -> non_neg_integer().
crc32c(Bytes) ->
    nif_crc32c(Bytes).

%% fsync the directory at Dir. Failures come back as {error, Reason}.
-spec fsync_dir(iodata()) -> ok | {error, atom()}.
fsync_dir(Dir) ->
    nif_fsync_dir(Dir).

%% Load a model from the file at Path. Opts is handed to the NIF
%% unchanged; a non-map Opts fails with function_clause.
-spec load_model(iodata(), map()) -> {ok, model_ref()} | {error, atom()}.
load_model(Path, Opts) when is_map(Opts) -> nif_load_model(Path, Opts).

%% Eagerly release a model (the resource is also freed on GC).
%% LoRA adapters created with adapter_load/2 keep a reference on the
%% model. While any are still attached, the release is deferred and
%% `{ok, deferred}` is returned instead of `ok`. The spec therefore
%% lists both results, so Dialyzer accepts callers that match the
%% deferred case.
-spec free_model(model_ref()) -> ok | {ok, deferred}.
free_model(Model) -> nif_free_model(Model).

%% Open an inference context on Model. Opts is forwarded to the NIF
%% as-is; for example, `embeddings => true` is required before
%% calling embed/2.
-spec new_context(model_ref(), map()) -> {ok, context_ref()} | {error, atom()}.
new_context(Model, Opts) when is_map(Opts) ->
    nif_new_context(Model, Opts).

%% Eagerly release a context; the resource is also freed on GC.
-spec free_context(context_ref()) -> ok.
free_context(Ctx) ->
    nif_free_context(Ctx).

%% Tokenise Text with Model's vocabulary. Opts is forwarded to the NIF.
-spec tokenize(model_ref(), iodata(), map()) -> [token_id()] | {error, atom()}.
tokenize(Model, Text, Opts) when is_map(Opts) ->
    nif_tokenize(Model, Text, Opts).

%% Render a token list back into text.
-spec detokenize(model_ref(), [token_id()]) -> binary() | {error, atom()}.
detokenize(Model, TokenIds) ->
    nif_detokenize(Model, TokenIds).

%% Prefill Ctx with TokenIds, evaluating them into the context's KV
%% cache.
-spec prefill(context_ref(), [token_id()]) -> ok | {error, term()}.
prefill(Ctx, TokenIds) ->
    nif_prefill(Ctx, TokenIds).

%% Produce the next token from Ctx using its configured sampler.
%% Returns `{eog, Token}` when the token ends generation.
-spec decode_one(context_ref()) ->
    {ok, token_id()} | {eog, token_id()} | {error, term()}.
decode_one(Ctx) ->
    nif_decode_one(Ctx).

%% Serialise the KV state of sequence 0 into a binary. Tokens and
%% NTokens are informational: the in-memory llama API saves whatever
%% currently sits in the seq_id=0 cache.
-spec kv_pack(context_ref(), [token_id()], non_neg_integer()) ->
    binary() | {error, atom()}.
kv_pack(Ctx, Tokens, NTokens) ->
    nif_kv_pack(Ctx, Tokens, NTokens).

%% Sequence-aware variant of kv_pack/3: extracts the KV state of
%% SeqId. Intended for multi-sequence batching (v0.2+); v0.1 callers
%% keep using kv_pack/3, which works on seq_id=0. A SeqId that is not
%% a non-negative integer fails with function_clause.
-spec kv_pack(context_ref(), [token_id()], non_neg_integer(), non_neg_integer()) ->
    binary() | {error, atom()}.
kv_pack(Ctx, Tokens, NTokens, SeqId) when is_integer(SeqId) andalso SeqId >= 0 ->
    nif_kv_pack(Ctx, Tokens, NTokens, SeqId).

%% Restore a binary produced by kv_pack into sequence SeqId of Ctx.
-spec kv_unpack(context_ref(), binary(), non_neg_integer()) ->
    ok | {error, atom()}.
kv_unpack(Ctx, Packed, SeqId) ->
    nif_kv_unpack(Ctx, Packed, SeqId).

%% Drop the KV cells of sequence SeqId whose positions fall in the
%% half-open range [From, To). Pass To = -1 to clear everything from
%% From onwards. After kv_unpack/3 this is used to remove the last
%% cell, so that token can be prefilled again to regenerate logits.
-spec kv_seq_rm(context_ref(), integer(), integer(), integer()) ->
    ok | {error, atom()}.
kv_seq_rm(Ctx, SeqId, From, To) ->
    nif_kv_seq_rm(Ctx, SeqId, From, To).

%% Run a normalised chat request through the model's chat template,
%% which is read from GGUF metadata, and tokenise the rendered text.
%% Request is a map with `messages` and optional `system` / `tools`
%% entries. On success the token ids are returned as a list.
-spec apply_chat_template(model_ref(), map()) ->
    {ok, [token_id()]} | {error, atom()}.
apply_chat_template(Model, Request) when is_map(Request) -> nif_apply_chat_template(Model, Request).

%% Decode Tokens and return the pooled embedding vector of the
%% sequence. Ctx must have been created with `embeddings => true`.
-spec embed(context_ref(), [token_id()]) ->
    {ok, [float()]} | {error, atom()}.
embed(Ctx, Tokens) when is_list(Tokens) -> nif_embed(Ctx, Tokens).

%% Attach a GBNF grammar to the context's sampler. From then on,
%% decode_one/1 only samples tokens that keep the output on a valid
%% grammar path. clear_sampler/1 removes the grammar again, and the
%% next decode falls back to greedy sampling.
%%
%% Shorthand for `configure_sampler(Ctx, #{grammar => Grammar})`.
-spec set_grammar(context_ref(), binary()) -> ok | {error, atom()}.
set_grammar(Ctx, Grammar) when is_binary(Grammar) -> nif_set_grammar(Ctx, Grammar).

%% Build the whole sampler chain from a config map in a single call.
%% Every key is optional:
%%
%%   grammar             :: binary()           %% GBNF source
%%   repetition_penalty  :: float()            %% > 1.0 penalises repeats
%%   top_k               :: non_neg_integer()
%%   top_p               :: float()            %% (0, 1]
%%   min_p               :: float()            %% (0, 1]
%%   temperature         :: float()            %% 0.0 == greedy
%%   seed                :: non_neg_integer()  %% honoured only with temperature > 0
%%
%% The stages are always chained in this fixed order:
%% grammar -> repetition_penalty -> top_k -> top_p -> min_p ->
%% (temperature > 0 ? temp -> dist(seed) : greedy).
%%
%% Any chain configured earlier on the context is replaced atomically.
-spec configure_sampler(context_ref(), map()) -> ok | {error, atom()}.
configure_sampler(Ctx, Config) when is_map(Config) -> nif_configure_sampler(Ctx, Config).

%% Remove the configured sampler chain from the context.
-spec clear_sampler(context_ref()) -> ok.
clear_sampler(Ctx) -> nif_clear_sampler(Ctx).

%% Load a LoRA adapter from the GGUF file at Path and bind it to
%% Model. The adapter is released together with the model, or earlier
%% through adapter_free/1. The adapter resource keeps a reference on
%% the model, so free_model/1 returns `{ok, deferred}` until every
%% attached adapter has been dropped.
-spec adapter_load(model_ref(), iodata()) ->
    {ok, adapter_ref()} | {error, atom()}.
adapter_load(Model, Path) -> nif_adapter_load(Model, Path).

%% Release an adapter explicitly. It is safe to call this more than
%% once: any call after the first returns `{error, released}`. If the
%% reference is simply dropped without calling this, the resource
%% destructor frees the adapter.
-spec adapter_free(adapter_ref()) -> ok | {error, atom()}.
adapter_free(AdapterRef) -> nif_adapter_free(AdapterRef).

%% Install adapters on Ctx as a list of {adapter_ref(), Scale} pairs.
%% The list replaces whatever set was installed before; [] detaches
%% all adapters.
-spec set_adapters(context_ref(), [{adapter_ref(), float()}]) ->
    ok | {error, atom()}.
set_adapters(Ctx, Pairs) when is_list(Pairs) -> nif_set_adapters(Ctx, Pairs).

%% Create a standalone sampler chain from the same config map that
%% configure_sampler/2 takes. The sampler keeps a reference on Ctx,
%% so the context stays alive at least as long as the sampler.
%% v0.1 callers have no use for it. It is the building block for
%% v0.2 multi-sequence batching (one sampler per request).
-spec sampler_new(context_ref(), map()) ->
    {ok, sampler_ref()} | {error, atom()}.
sampler_new(Ctx, Config) when is_map(Config) -> nif_sampler_new(Ctx, Config).

%% Release a sampler explicitly. It is safe to call this more than
%% once: any call after the first returns `{error, released}`.
%% Samplers that are never freed are cleaned up by the resource
%% destructor when garbage-collected.
-spec sampler_free(sampler_ref()) -> ok | {error, atom()}.
sampler_free(SamplerRef) -> nif_sampler_free(SamplerRef).

%% =============================================================================
%% NIF stubs (replaced at on_load time)
%% =============================================================================

%% Each stub raises `nif_not_loaded`; the on_load hook swaps in the
%% native implementation when the shared object loads.

%% -- utilities --
nif_crc32c(_Iodata) -> erlang:nif_error(nif_not_loaded).
nif_fsync_dir(_DirPath) -> erlang:nif_error(nif_not_loaded).

%% -- model and context lifecycle --
nif_load_model(_ModelPath, _LoadOpts) -> erlang:nif_error(nif_not_loaded).
nif_free_model(_ModelRef) -> erlang:nif_error(nif_not_loaded).
nif_new_context(_ModelRef, _CtxOpts) -> erlang:nif_error(nif_not_loaded).
nif_free_context(_CtxRef) -> erlang:nif_error(nif_not_loaded).

%% -- tokens and decoding --
nif_tokenize(_ModelRef, _Input, _TokOpts) -> erlang:nif_error(nif_not_loaded).
nif_detokenize(_ModelRef, _TokenIds) -> erlang:nif_error(nif_not_loaded).
nif_prefill(_CtxRef, _TokenIds) -> erlang:nif_error(nif_not_loaded).
nif_decode_one(_CtxRef) -> erlang:nif_error(nif_not_loaded).

%% -- KV cache --
nif_kv_pack(_CtxRef, _TokenIds, _Count) -> erlang:nif_error(nif_not_loaded).
nif_kv_pack(_CtxRef, _TokenIds, _Count, _Seq) -> erlang:nif_error(nif_not_loaded).
nif_kv_unpack(_CtxRef, _Packed, _Seq) -> erlang:nif_error(nif_not_loaded).
nif_kv_seq_rm(_CtxRef, _Seq, _From, _To) -> erlang:nif_error(nif_not_loaded).

%% -- chat templates and embeddings --
nif_apply_chat_template(_ModelRef, _ChatReq) -> erlang:nif_error(nif_not_loaded).
nif_embed(_CtxRef, _TokenIds) -> erlang:nif_error(nif_not_loaded).

%% -- samplers --
nif_set_grammar(_CtxRef, _Gbnf) -> erlang:nif_error(nif_not_loaded).
nif_configure_sampler(_CtxRef, _Config) -> erlang:nif_error(nif_not_loaded).
nif_clear_sampler(_CtxRef) -> erlang:nif_error(nif_not_loaded).

%% -- LoRA adapters --
nif_adapter_load(_ModelRef, _AdapterPath) -> erlang:nif_error(nif_not_loaded).
nif_adapter_free(_AdapterRef) -> erlang:nif_error(nif_not_loaded).
nif_set_adapters(_CtxRef, _Pairs) -> erlang:nif_error(nif_not_loaded).

%% -- standalone samplers --
nif_sampler_new(_CtxRef, _Config) -> erlang:nif_error(nif_not_loaded).
nif_sampler_free(_SamplerRef) -> erlang:nif_error(nif_not_loaded).