%% Copyright (c) 2026 Benoit Chesneau. Licensed under the MIT License.
%% See the LICENSE file at the project root.
%%
%% @doc
%% Cache-key derivation.
%%
%% A cache key is the SHA-256 of model_fingerprint || quant_byte ||
%% ctx_params_hash || tokens_le32. Cache hits are token-exact by
%% construction; semantic / approximate matching is not allowed at
%% this layer.
%%
%% The quant byte is a stable cache-internal enumeration of
%% quantisation types. It is intentionally decoupled from llama.cpp's
%% GGUF tensor type IDs so we can add new entries without depending
%% on upstream renumbering.
%% @end
-module(erllama_cache_key).
-export([
make/1,
make/4,
effective_fingerprint/2,
quant_byte/1,
quant_atom/1,
encode_tokens/1,
decode_tokens/1
]).
-export_type([key/0, quant_type/0, components/0]).
%% A cache key: the raw 32-byte (256-bit) SHA-256 digest returned by
%% make/1 and make/4.
-type key() :: <<_:256>>.
%% Quantisation label. The named atoms are the ones quant_byte/1 maps to
%% a dedicated byte value. The trailing atom() admits any other label;
%% quant_byte/1 maps those to 255 through its catch-all clause.
-type quant_type() ::
f32
| f16
| bf16
| q4_0
| q4_1
| q5_0
| q5_1
| q8_0
| q2_k
| q3_k_s
| q3_k_m
| q3_k_l
| q4_k_m
| q4_k_s
| q5_k_m
| q5_k_s
| q6_k
| q8_k
| iq1_s
| iq1_m
| iq2_xxs
| iq2_xs
| iq2_s
| iq2_m
| iq3_xxs
| iq3_xs
| iq3_s
| iq3_m
| iq4_nl
| iq4_xs
| atom().
%% Input map accepted by make/1. All four keys are mandatory (`:=`).
-type components() :: #{
%% 32-byte model fingerprint.
fingerprint := <<_:256>>,
%% Quantisation label, converted to one byte by quant_byte/1.
quant_type := quant_type(),
%% 32-byte hash of the context parameters.
ctx_params_hash := <<_:256>>,
%% Token ids, serialised by encode_tokens/1 before hashing.
tokens := [non_neg_integer()]
}.
-doc """
Derive the cache key from a components map.

The map must carry `fingerprint` and `ctx_params_hash` as 32-byte
binaries, a `quant_type` label and `tokens` as a list. The token list
is serialised with `encode_tokens/1` and the result is handed to
`make/4`, so both entry points produce the same key for the same
inputs.

A map that lacks one of the four keys, or whose values fail the type /
size checks, raises `function_clause`.
""".
-spec make(components()) -> key().
make(#{
    fingerprint := Fingerprint,
    quant_type := Quant,
    ctx_params_hash := CtxParamsHash,
    tokens := TokenList
}) when
    is_list(TokenList),
    is_binary(Fingerprint),
    is_binary(CtxParamsHash),
    byte_size(Fingerprint) =:= 32,
    byte_size(CtxParamsHash) =:= 32
->
    %% Encode once, then share the hashing path with make/4.
    TokensBin = encode_tokens(TokenList),
    make(Fingerprint, Quant, CtxParamsHash, TokensBin).
-doc """
Derive the cache key from already-serialised tokens.

`TokensBin` must use the `encode_tokens/1` layout (u32-LE per token),
so its size has to be a multiple of 4 bytes. `Fp` and `CtxHash` must be
32-byte binaries.

This variant exists for the longest-prefix walk: the caller encodes the
token list once and then probes with
`binary:part(AllTokensBin, 0, N*4)` sub-binaries. Sub-binaries are O(1)
views, so each probe costs only the SHA-256 work instead of a list
traversal plus a fresh encoding.

The key is `sha256(Fp || quant_byte(QT) || CtxHash || TokensBin)`.
""".
-spec make(<<_:256>>, quant_type(), <<_:256>>, binary()) -> key().
make(Fp, QT, CtxHash, TokensBin) when
    is_binary(Fp),
    is_binary(CtxHash),
    is_binary(TokensBin),
    byte_size(Fp) =:= 32,
    byte_size(CtxHash) =:= 32,
    byte_size(TokensBin) rem 4 =:= 0
->
    %% quant_byte/1 returns 0..255, which iodata accepts as a single byte,
    %% so the preimage needs no intermediate one-byte binary.
    Preimage = [Fp, quant_byte(QT), CtxHash, TokensBin],
    crypto:hash(sha256, Preimage).
-doc """
Compute an effective fingerprint from a base model fingerprint and a
list of attached LoRA adapters.

LoRA changes the model's logits, not its inputs, so attached
adapters must enter the cache key. Two requests on the same model
with different adapter sets / scales must never collide or false-hit
each other.

`effective_fp = sha256(model_fp || sorted_pairs)` where
`sorted_pairs` is the byte concatenation of
`(adapter_sha256 || i64_le(scale_q32))` for every attached adapter,
sorted by `adapter_sha256` (ties broken by scale) for determinism.
`scale_q32` is the scale multiplied by `2^32` and rounded to the
nearest integer, written as a 64-bit little-endian two's-complement
value, so floating-point representation isn't part of the key.

An empty adapter list returns the base fingerprint unchanged.

Every entry must be a `{Sha, Scale}` tuple whose `Sha` is a 32-byte
binary. A malformed entry raises `error({badarg, Entry})`: silently
skipping it would make this adapter set hash to the same fingerprint
as the set without that adapter, which is exactly the false hit this
function exists to prevent.
""".
-spec effective_fingerprint(<<_:256>>, [{<<_:256>>, float()}]) -> <<_:256>>.
effective_fingerprint(Fp, []) when
    is_binary(Fp), byte_size(Fp) =:= 32
->
    Fp;
effective_fingerprint(Fp, Adapters) when
    is_binary(Fp), byte_size(Fp) =:= 32, is_list(Adapters)
->
    Pairs = lists:map(
        fun
            ({Sha, Scale}) when is_binary(Sha), byte_size(Sha) =:= 32 ->
                %% Fixed-point scale: 2^32 steps per unit.
                <<Sha/binary, (round(Scale * (1 bsl 32))):64/little-signed>>;
            (Malformed) ->
                %% Never drop an entry: a dropped adapter would alias
                %% another adapter set's fingerprint.
                error({badarg, Malformed})
        end,
        %% Sorting makes the result independent of attachment order.
        lists:sort(Adapters)
    ),
    crypto:hash(sha256, [Fp | Pairs]).
-doc """
Map a quantisation label to its cache-internal byte value.

The numbering is private to the cache and stable: known labels occupy
0..29. Any other term maps to 255.
""".
-spec quant_byte(quant_type()) -> 0..255.
quant_byte(QuantType) ->
    Known = #{
        f32 => 0,
        f16 => 1,
        q4_0 => 2,
        q4_1 => 3,
        q5_0 => 4,
        q5_1 => 5,
        q8_0 => 6,
        q4_k_m => 7,
        q4_k_s => 8,
        q5_k_m => 9,
        q5_k_s => 10,
        q6_k => 11,
        q8_k => 12,
        q2_k => 13,
        q3_k_s => 14,
        q3_k_m => 15,
        q3_k_l => 16,
        iq2_xxs => 17,
        iq2_xs => 18,
        iq2_s => 19,
        iq2_m => 20,
        iq3_xxs => 21,
        iq3_xs => 22,
        iq3_s => 23,
        iq3_m => 24,
        iq1_s => 25,
        iq1_m => 26,
        iq4_nl => 27,
        iq4_xs => 28,
        bf16 => 29
    },
    %% Default 255: any future or unknown quant label still yields a byte,
    %% so cache key derivation never crashes. Cache buckets are already
    %% differentiated by the model fingerprint, so conflating several
    %% unknown labels into byte 255 is harmless.
    maps:get(QuantType, Known, 255).
-doc """
Map a cache-internal quant byte to its quantisation label.

Returns `{ok, Label}` for the assigned values 0..29 and
`{error, unknown_quant}` for everything else.
""".
-spec quant_atom(0..255) -> {ok, quant_type()} | {error, unknown_quant}.
quant_atom(Byte) ->
    Labels = #{
        0 => f32,
        1 => f16,
        2 => q4_0,
        3 => q4_1,
        4 => q5_0,
        5 => q5_1,
        6 => q8_0,
        7 => q4_k_m,
        8 => q4_k_s,
        9 => q5_k_m,
        10 => q5_k_s,
        11 => q6_k,
        12 => q8_k,
        13 => q2_k,
        14 => q3_k_s,
        15 => q3_k_m,
        16 => q3_k_l,
        17 => iq2_xxs,
        18 => iq2_xs,
        19 => iq2_s,
        20 => iq2_m,
        21 => iq3_xxs,
        22 => iq3_xs,
        23 => iq3_s,
        24 => iq3_m,
        25 => iq1_s,
        26 => iq1_m,
        27 => iq4_nl,
        28 => iq4_xs,
        29 => bf16
    },
    case maps:find(Byte, Labels) of
        {ok, _Label} = Found -> Found;
        error -> {error, unknown_quant}
    end.
-doc """
Serialise a list of token ids into a binary: 4 bytes per token
(unsigned 32-bit, little-endian), in list order.

Values are not range-checked; the bit syntax keeps only the low 32
bits of an integer that does not fit.
""".
-spec encode_tokens([non_neg_integer()]) -> binary().
encode_tokens(Tokens) ->
    <<<<Token:32/little-unsigned-integer>> || Token <- Tokens>>.
-doc """
Inverse of `encode_tokens/1`: read consecutive unsigned 32-bit
little-endian integers from `Bin` and return them in order.

`Bin` must be a binary whose size is a multiple of 4 bytes; anything
else raises `function_clause`.
""".
-spec decode_tokens(binary()) -> [non_neg_integer()].
decode_tokens(Bin) when
    is_binary(Bin),
    byte_size(Bin) rem 4 =:= 0
->
    [Token || <<Token:32/little-unsigned-integer>> <= Bin].