%% src/tflite_beam/tokenizer/tflite_beam_full_tokenizer.erl

%% @doc
%% Runs end-to-end tokenization.
%%
%% Related link: https://github.com/tensorflow/examples/blob/master/lite/examples/bert_qa/ios/BertQACore/Models/Tokenizers/FullTokenizer.swift

-module(tflite_beam_full_tokenizer).
-export([
    tokenize/3,
    convert_to_id/2
]).

%% @doc
%% End-to-end tokenization: split `Text' into basic tokens, then run each
%% basic token through the wordpiece tokenizer and return the flattened
%% list of sub-tokens.
%%
%% `Text' may be a binary or a charlist; `IsCaseInsensitive' controls the
%% basic tokenizer's case folding; `Vocab' is the token->id vocabulary map
%% consumed by the wordpiece tokenizer.
-spec tokenize(binary() | list(), boolean(), map()) -> list(binary()).
tokenize(Text, IsCaseInsensitive, Vocab) when (is_binary(Text) orelse is_list(Text)) andalso is_boolean(IsCaseInsensitive) andalso is_map(Vocab) ->
    %% lists:flatmap/2 replaces lists:flatten(lists:map(...)): each
    %% wordpiece call yields a flat list of binaries (per this spec's
    %% return type), so one level of flattening is all that is needed.
    lists:flatmap(
        fun(BasicToken) ->
            tflite_beam_wordpiece_tokenizer:tokenize(BasicToken, Vocab)
        end,
        tflite_beam_basic_tokenizer:tokenize(Text, IsCaseInsensitive)
    ).

%% @doc
%% Convert tokens to their IDs in the vocab.
%%
%% Returns `{ok, Ids}' (in the same order as `Tokens') when every token is
%% present in `Vocab'; otherwise returns `{error, Reason}' where `Reason' is
%% a binary joining one message per missing token with `"; "'.
-spec convert_to_id(list(binary()), map()) -> {ok, list(integer())} | {error, binary()}.
convert_to_id(Tokens, Vocab) ->
    %% Tag each lookup result so IDs and error messages can never be
    %% confused (the previous is_binary/1 filter would misclassify a
    %% binary-valued vocab entry as an error). maps:find/2 does a single
    %% lookup instead of maps:is_key/2 followed by maps:get/2.
    Results =
        lists:map(
            fun(Token) ->
                case maps:find(Token, Vocab) of
                    {ok, Id} ->
                        {ok, Id};
                    error ->
                        Msg = io_lib:format("Cannot find token `~ts` in the given vocabulary map", [Token]),
                        {error, unicode:characters_to_binary(Msg)}
                end
            end,
            Tokens
        ),
    %% Pattern-match on the collected errors instead of the O(n)
    %% `length(...) > 0' guard; lists:join/2 avoids the foldl-concat plus
    %% binary:part/2 trick that stripped the leading separator.
    case [Reason || {error, Reason} <- Results] of
        [] ->
            {ok, [Id || {ok, Id} <- Results]};
        Reasons ->
            {error, iolist_to_binary(lists:join(<<"; ">>, Reasons))}
    end.