lib/tokenizers/tokenizer.ex

defmodule Tokenizers.Tokenizer do
  @moduledoc """
  The struct and associated functions for a tokenizer.

  A `Tokenizers.Tokenizer.t()` is a container that holds the constituent parts of the tokenization pipeline.

  When you call `Tokenizers.Tokenizer.encode/3`, the input text goes through the following pipeline:

  - normalization
  - pre-tokenization
  - model
  - post-processing

  This returns a `Tokenizers.Encoding.t()`, which can then give you the token ids for each token in the input text. These token ids are usually used as the input for natural language processing machine learning models.
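
  A minimal usage sketch (assuming network access to the Hugging Face Hub; the
  `Tokenizers.Encoding.get_ids/1` call is shown as the typical way to read the
  resulting ids):

      {:ok, tokenizer} = Tokenizers.Tokenizer.from_pretrained("bert-base-cased")
      {:ok, encoding} = Tokenizers.Tokenizer.encode(tokenizer, "Hello, world!")
      ids = Tokenizers.Encoding.get_ids(encoding)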
  """

  @type t :: %__MODULE__{resource: binary(), reference: reference()}
  defstruct resource: nil, reference: nil

  alias Tokenizers.Model
  alias Tokenizers.Native
  alias Tokenizers.Shared

  @typedoc """
  An input subject to tokenization.

  Can be either a single sequence or a pair of sequences.
  """
  @type encode_input :: String.t() | {String.t(), String.t()}

  @doc """
  Instantiate a new tokenizer from an existing file on the Hugging Face Hub.

  This downloads the tokenizer file, saves it to disk, and loads it from there.

  ## Options

    * `:http_client` - A tuple with a module and options. The module must implement
      a `request/1` function that accepts a keyword list of request options.
      This contract is inspired by `Req.request/1`: https://hexdocs.pm/req/Req.html#request/1

      The default HTTP client config is `{Tokenizers.HTTPClient, []}`.
      Since the contract mirrors `Req`, that client can be used without any
      adjustments (see the example below).

      When making a request, the `:url` and `:method` options are overridden.
      `:headers` gets a default "user-agent" entry.

    * `:revision` - The revision (branch, tag, or commit) to fetch the tokenizer
      from on Hugging Face. Defaults to `"main"`.

    * `:use_cache` - Whether to read from the cache when the file already exists.
      Defaults to `true`.

    * `:cache_dir` - The directory the cache is saved to. Files are written to the
      cache even if `:use_cache` is `false`. By default `:filename.basedir/2` is used
      to derive a cache directory from the "tokenizers_elixir" application name.

    * `:additional_special_tokens` - A list of special tokens to append to the tokenizer.
      Defaults to `[]`.
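
  ## Examples

  A sketch of fetching a tokenizer with `Req` plugged in as the HTTP client
  (the `"bert-base-cased"` identifier is illustrative; any module exposing a
  `Req`-style `request/1` works):

      {:ok, tokenizer} =
        Tokenizers.Tokenizer.from_pretrained("bert-base-cased",
          revision: "main",
          http_client: {Req, []}
        )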
  """
  @spec from_pretrained(String.t(), Keyword.t()) :: {:ok, t()} | {:error, term()}
  def from_pretrained(identifier, opts \\ []) do
    opts =
      Keyword.validate!(opts,
        revision: "main",
        use_cache: true,
        cache_dir: :filename.basedir(:user_cache, "tokenizers_elixir"),
        http_client: {Tokenizers.HTTPClient, []},
        additional_special_tokens: []
      )

    {http_client, http_opts} = opts[:http_client]

    {:ok, app_version} = :application.get_key(:tokenizers, :vsn)
    app_version = List.to_string(app_version)

    headers = [{"user-agent", "tokenizers-elixir/#{app_version}"}]
    url = "/#{identifier}/resolve/#{opts[:revision]}/tokenizer.json"

    http_opts =
      http_opts
      |> Keyword.put_new(:base_url, "https://huggingface.co")
      |> Keyword.put(:url, url)
      |> Keyword.put(:method, :get)
      |> Keyword.update(:headers, headers, fn existing -> existing ++ headers end)

    cache_dir = opts[:cache_dir]

    file_path_fun = fn etag ->
      Path.join(cache_dir, entry_filename(url, etag))
    end

    tokenizer_opts = Keyword.take(opts, [:additional_special_tokens])

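    # With the cache enabled, issue a HEAD request first to read the ETag,
    # derive the cache file name from it, and only download the file when
    # that cache entry is missing. With the cache disabled, always download,
    # but still write the result into the cache directory.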
    if opts[:use_cache] do
      with {:ok, response} <- request(http_client, Keyword.put(http_opts, :method, :head)) do
        etag = fetch_etag(response.headers)
        file_path = file_path_fun.(etag)

        if File.exists?(file_path) do
          from_file(file_path, tokenizer_opts)
        else
          with {:ok, response} <- request(http_client, http_opts) do
            File.mkdir_p!(cache_dir)
            File.write!(file_path, response.body)

            from_file(file_path, tokenizer_opts)
          end
        end
      end
    else
      with {:ok, response} <- request(http_client, http_opts) do
        etag = fetch_etag(response.headers)
        file_path = file_path_fun.(etag)

        File.mkdir_p!(cache_dir)
        File.write!(file_path, response.body)

        from_file(file_path, tokenizer_opts)
      end
    end
  end

  defp fetch_etag(headers) do
    {_, etag} = List.keyfind!(headers, "etag", 0)

    etag
  end

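  # Runs the request through the configured client and treats any non-2xx
  # status as an error, special-casing 404 as `:not_found`.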
  defp request(http_client, http_opts) do
    case http_client.request(http_opts) do
      {:ok, response} ->
        case response.status do
          status when status in 200..299 ->
            {:ok, response}

          404 ->
            {:error, :not_found}

          other ->
            {:error,
             "download of pretrained file failed with status #{other}. Response: #{inspect(response.body)}"}
        end

      {:error, _} = error ->
        error
    end
  end

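  # Cache entries are named after an MD5 hash of the URL plus the
  # Base32-encoded ETag, so a new upstream revision yields a new cache file.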
  defp entry_filename(url, etag) do
    encode_url(url) <> "." <> encode_etag(etag)
  end

  defp encode_url(url) do
    url |> :erlang.md5() |> Base.encode32(case: :lower, padding: false)
  end

  defp encode_etag(etag) do
    Base.encode32(etag, case: :lower, padding: false)
  end

  @doc """
  Instantiate a new tokenizer from the file at the given path.

  ## Options

    * `:additional_special_tokens` - A list of special tokens to append to the tokenizer.
      Defaults to `[]`.
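
  ## Examples

  A minimal sketch (the `"tokenizer.json"` path is illustrative):

      {:ok, tokenizer} = Tokenizers.Tokenizer.from_file("tokenizer.json")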
  """
  @spec from_file(String.t(), Keyword.t()) :: {:ok, t()} | {:error, term()}
  def from_file(path, opts \\ []) do
    opts = Keyword.validate!(opts, additional_special_tokens: [])
    Native.from_file(path, opts[:additional_special_tokens])
  end

  @doc """
  Save the tokenizer to the provided path.
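
  ## Examples

  A minimal sketch, assuming `tokenizer` is an already loaded
  `Tokenizers.Tokenizer.t()` and the path is illustrative:

      {:ok, path} = Tokenizers.Tokenizer.save(tokenizer, "tokenizer.json")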
  """
  @spec save(t(), String.t()) :: {:ok, String.t()} | {:error, term()}
  def save(tokenizer, path) do
    case Native.save(tokenizer, path, true) do
      {:ok, _} -> {:ok, path}
      {:error, reason} -> {:error, reason}
    end
  end

  @doc """
  Encode the given sequence or batch of sequences, returning a single
  `Tokenizers.Encoding.t()` or a list of encodings for a batch.

  ## Options

    * `:add_special_tokens` - whether to add special tokens to the encoding. Defaults to `true`.

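  ## Examples

  A minimal sketch, assuming `tokenizer` is an already loaded tokenizer; a pair
  of sequences is passed as a tuple and a batch as a list:

      {:ok, encoding} = Tokenizers.Tokenizer.encode(tokenizer, "Hello!")
      {:ok, pair} = Tokenizers.Tokenizer.encode(tokenizer, {"Question?", "Answer."})
      {:ok, encodings} = Tokenizers.Tokenizer.encode(tokenizer, ["One", "Two"])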
  """
  @spec encode(t(), encode_input() | [encode_input()], Keyword.t()) ::
          {:ok, Tokenizers.Encoding.t() | [Tokenizers.Encoding.t()]} | {:error, term()}
  def encode(tokenizer, input, opts \\ []) do
    add_special_tokens = Keyword.get(opts, :add_special_tokens, true)
    do_encode(tokenizer, input, add_special_tokens)
  end

  defp do_encode(tokenizer, input, add_special_tokens) when is_list(input) do
    Native.encode_batch(tokenizer, input, add_special_tokens)
  end

  defp do_encode(tokenizer, input, add_special_tokens) do
    Native.encode(tokenizer, input, add_special_tokens)
  end

  @doc """
  Decode the given list of ids, or list of lists of ids, back to a string or
  a list of strings.

  ## Options

    * `:skip_special_tokens` - whether the special tokens should be removed from the decoded string. Defaults to `true`.
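
  ## Examples

  A minimal sketch, assuming `tokenizer` is an already loaded tokenizer (the
  ids are illustrative):

      {:ok, text} = Tokenizers.Tokenizer.decode(tokenizer, [101, 7592, 102])

      {:ok, texts} =
        Tokenizers.Tokenizer.decode(tokenizer, [[101, 1015, 102], [101, 1016, 102]])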
  """
  @spec decode(t(), [non_neg_integer()] | [[non_neg_integer()]], Keyword.t()) ::
          {:ok, String.t() | [String.t()]} | {:error, term()}
  def decode(tokenizer, ids, opts \\ []) do
    skip_special_tokens = Keyword.get(opts, :skip_special_tokens, true)
    do_decode(tokenizer, ids, skip_special_tokens)
  end

  defp do_decode(tokenizer, [first | _] = ids, skip_special_tokens) when is_integer(first),
    do: Native.decode(tokenizer, ids, skip_special_tokens)

  defp do_decode(tokenizer, [first | _] = ids, skip_special_tokens) when is_list(first),
    do: Native.decode_batch(tokenizer, ids, skip_special_tokens)

  @doc """
  Get the tokenizer's vocabulary as a map of token to id.
  """
  @spec get_vocab(t()) :: %{binary() => integer()}
  def get_vocab(tokenizer), do: tokenizer |> Native.get_vocab(false) |> Shared.unwrap()

  @doc """
  Get the number of tokens in the vocabulary.
  """
  @spec get_vocab_size(t()) :: non_neg_integer()
  def get_vocab_size(tokenizer), do: tokenizer |> Native.get_vocab_size(true) |> Shared.unwrap()

  @doc """
  Convert a given id to its token.
  """
  @spec id_to_token(t(), integer()) :: String.t()
  def id_to_token(tokenizer, id), do: tokenizer |> Native.id_to_token(id) |> Shared.unwrap()

  @doc """
  Convert a given token to its id.
  """
  @spec token_to_id(t(), binary()) :: non_neg_integer()
  def token_to_id(tokenizer, token), do: tokenizer |> Native.token_to_id(token) |> Shared.unwrap()

  @doc """
  Get the `Tokenizer`'s `Model`.
  """
  @spec get_model(t()) :: Model.t()
  def get_model(tokenizer), do: tokenizer |> Native.get_model() |> Shared.unwrap()
end

defimpl Inspect, for: Tokenizers.Tokenizer do
  import Inspect.Algebra

  alias Tokenizers.Model
  alias Tokenizers.Tokenizer

  def inspect(tokenizer, opts) do
    model_details =
      tokenizer
      |> Tokenizer.get_model()
      |> Model.get_model_details()
      |> Keyword.new(fn {k, v} -> {String.to_atom(k), v} end)

    attrs =
      Keyword.merge(
        [
          vocab_size: Tokenizer.get_vocab_size(tokenizer)
        ],
        model_details
      )

    concat(["#Tokenizers.Tokenizer<", to_doc(attrs, opts), ">"])
  end
end