# lib/fastest_tiktoken.ex

defmodule FastestTiktoken do
  @moduledoc """
  Fast Elixir bindings for the pure-Rust `tiktoken` crate.

  `FastestTiktoken` provides OpenAI-compatible tokenization from Elixir while
  keeping runtime calls explicit and safe: every operation returns
  `{:ok, value}` or `{:error, reason}`.

  All tokenization functions require a selector through either `:model` or
  `:encoding`.

      iex> FastestTiktoken.count_tokens("hello world", model: "gpt-4o")
      {:ok, 2}

      iex> FastestTiktoken.encode("hello world", encoding: :cl100k_base)
      {:ok, [15339, 1917]}

  ## Options

  Options may be given as a keyword list or a map:

  * `:model` - a model name (string or atom), resolved with `encoding_for_model/1`.
    An empty or non-string/non-atom value returns `{:error, :invalid_model}`.
  * `:encoding` - an encoding name (string or atom), e.g. `:cl100k_base`.
    `gpt2` is accepted as an alias for `r50k_base`.
  * `:allowed_special` - `:all`, a list of special-token strings, or a `MapSet`
    of them. Defaults to `[]`.

  Exactly one of `:model` and `:encoding` must be present. Passing both returns
  `{:error, :ambiguous_selector}`; passing neither returns
  `{:error, :missing_selector}`.

  ## Parity

  The public behavior is parity-tested against official OpenAI `tiktoken`
  `0.13.0` for the OpenAI encodings and API surfaces exposed here. That
  includes model mapping, GPT-2/r50k fixtures, regex edge cases, roundtrips,
  special-token behavior, `o200k_harmony`, large inputs, and batch helpers.

  Under the hood, this library wraps the high-performance pure-Rust
  [`tiktoken`](https://crates.io/crates/tiktoken) crate rather than older
  wrappers around `tiktoken-rs`.
  """

  alias FastestTiktoken.Native

  # Public encoding names that are not native encodings themselves, mapped to
  # the native encoding they stand for.
  @encoding_aliases %{
    "gpt2" => "r50k_base"
  }

  @official_model_to_encoding %{
    "o1" => "o200k_base",
    "o3" => "o200k_base",
    "o4-mini" => "o200k_base",
    "gpt-5" => "o200k_base",
    "gpt-4.1" => "o200k_base",
    "gpt-4o" => "o200k_base",
    "gpt-4" => "cl100k_base",
    "gpt-3.5-turbo" => "cl100k_base",
    "gpt-3.5" => "cl100k_base",
    "gpt-35-turbo" => "cl100k_base",
    "davinci-002" => "cl100k_base",
    "babbage-002" => "cl100k_base",
    "text-embedding-ada-002" => "cl100k_base",
    "text-embedding-3-small" => "cl100k_base",
    "text-embedding-3-large" => "cl100k_base",
    "text-davinci-003" => "p50k_base",
    "text-davinci-002" => "p50k_base",
    "text-davinci-001" => "r50k_base",
    "text-curie-001" => "r50k_base",
    "text-babbage-001" => "r50k_base",
    "text-ada-001" => "r50k_base",
    "davinci" => "r50k_base",
    "curie" => "r50k_base",
    "babbage" => "r50k_base",
    "ada" => "r50k_base",
    "code-davinci-002" => "p50k_base",
    "code-davinci-001" => "p50k_base",
    "code-cushman-002" => "p50k_base",
    "code-cushman-001" => "p50k_base",
    "davinci-codex" => "p50k_base",
    "cushman-codex" => "p50k_base",
    "text-davinci-edit-001" => "p50k_edit",
    "code-davinci-edit-001" => "p50k_edit",
    "text-similarity-davinci-001" => "r50k_base",
    "text-similarity-curie-001" => "r50k_base",
    "text-similarity-babbage-001" => "r50k_base",
    "text-similarity-ada-001" => "r50k_base",
    "text-search-davinci-doc-001" => "r50k_base",
    "text-search-curie-doc-001" => "r50k_base",
    "text-search-babbage-doc-001" => "r50k_base",
    "text-search-ada-doc-001" => "r50k_base",
    "code-search-babbage-code-001" => "r50k_base",
    "code-search-ada-code-001" => "r50k_base",
    "gpt2" => "gpt2",
    "gpt-2" => "gpt2"
  }

  # Checked in order; the first matching prefix wins, so more specific prefixes
  # must stay ahead of more general ones.
  @official_model_prefix_to_encoding [
    {"o1-", "o200k_base"},
    {"o3-", "o200k_base"},
    {"o4-mini-", "o200k_base"},
    {"gpt-5-", "o200k_base"},
    {"gpt-4.5-", "o200k_base"},
    {"gpt-4.1-", "o200k_base"},
    {"chatgpt-4o-", "o200k_base"},
    {"gpt-4o-", "o200k_base"},
    {"gpt-4-", "cl100k_base"},
    {"gpt-3.5-turbo-", "cl100k_base"},
    {"gpt-35-turbo-", "cl100k_base"},
    {"gpt-oss-", "o200k_harmony"},
    {"ft:gpt-4o", "o200k_base"},
    {"ft:gpt-4", "cl100k_base"},
    {"ft:gpt-3.5-turbo", "cl100k_base"},
    {"ft:davinci-002", "cl100k_base"},
    {"ft:babbage-002", "cl100k_base"}
  ]

  @type encoding :: atom() | String.t()
  @type allowed_special :: :all | [String.t()] | MapSet.t(String.t())
  @type selector_opts :: keyword() | map()
  @type reason ::
          :ambiguous_selector
          | :invalid_batch
          | :invalid_allowed_special
          | :invalid_model
          | :invalid_opts
          | :invalid_text
          | :invalid_token_ids
          | :missing_selector
          | {:decode_failed, String.t()}
          | {:native_error, String.t()}
          | {:unsupported_encoding, String.t()}
          | {:unsupported_model, String.t()}

  @doc """
  Returns all encoding names compiled into the native tokenizer.

  The list includes `gpt2` as an OpenAI-compatible alias for `r50k_base`.
  """
  @spec list_encodings() :: {:ok, [String.t()]} | {:error, reason()}
  def list_encodings do
    with {:ok, encodings} <- Native.list_encodings() do
      names = Enum.uniq(encodings ++ Map.keys(@encoding_aliases))
      {:ok, Enum.sort(names)}
    end
  end

  @doc """
  Resolves a model name to the encoding used by the native tokenizer.

  This mirrors the official OpenAI `tiktoken` model mapping for supported
  encodings, including common versioned model prefixes. Models unknown to the
  official table are looked up in the native crate as a fallback.

  Returns `{:error, :invalid_model}` for an empty string or a non-string.

      iex> FastestTiktoken.encoding_for_model("gpt-4o-2024-05-13")
      {:ok, "o200k_base"}

      iex> FastestTiktoken.encoding_for_model("text-davinci-003")
      {:ok, "p50k_base"}

      iex> FastestTiktoken.encoding_for_model("gpt-oss-120b")
      {:ok, "o200k_harmony"}
  """
  @spec encoding_for_model(String.t()) :: {:ok, String.t()} | {:error, reason()}
  def encoding_for_model(model) when is_binary(model) and model != "" do
    case official_encoding_name_for_model(model) do
      {:ok, encoding} ->
        if supported_encoding_name?(encoding) do
          {:ok, encoding}
        else
          {:error, {:unsupported_encoding, encoding}}
        end

      :error ->
        case Native.encoding_for_model(model) do
          {:ok, encoding} -> {:ok, encoding}
          {:error, _reason} -> {:error, {:unsupported_model, model}}
        end
    end
  end

  def encoding_for_model(_model), do: {:error, :invalid_model}

  @doc """
  Encodes text into token ids.

      iex> FastestTiktoken.encode("hello world", model: "gpt-4o")
      {:ok, [24912, 2375]}

      iex> FastestTiktoken.encode("hello world", encoding: :cl100k_base)
      {:ok, [15339, 1917]}
  """
  @spec encode(String.t(), selector_opts()) :: {:ok, [non_neg_integer()]} | {:error, reason()}
  def encode(text, opts) when is_binary(text) do
    with {:ok, encoding, mode, allowed_special} <- normalize_call_opts(opts) do
      encode_resolved(text, encoding, mode, allowed_special)
    end
  end

  def encode(_text, _opts), do: {:error, :invalid_text}

  @doc """
  Encodes text while treating special token strings as ordinary text.

  This matches the official `encode_ordinary` behavior. Any `:allowed_special`
  option passed in `opts` is ignored.

      iex> FastestTiktoken.encode_ordinary("hello <|endoftext|>", encoding: :cl100k_base)
      {:ok, [15339, 83739, 8862, 728, 428, 91, 29]}
  """
  @spec encode_ordinary(String.t(), selector_opts()) ::
          {:ok, [non_neg_integer()]} | {:error, reason()}
  def encode_ordinary(text, opts) when is_binary(text) do
    opts
    |> put_allowed_special([])
    |> then(&encode(text, &1))
  end

  def encode_ordinary(_text, _opts), do: {:error, :invalid_text}

  @doc """
  Decodes token ids into text using the selected encoding.

  Token ids are validated before the selector is resolved. Each id must be an
  integer in `0..4_294_967_295`.

      iex> FastestTiktoken.decode([24912, 2375], model: "gpt-4o")
      {:ok, "hello world"}
  """
  @spec decode([non_neg_integer()], selector_opts()) :: {:ok, String.t()} | {:error, reason()}
  def decode(ids, opts) do
    with {:ok, ids} <- normalize_token_ids(ids),
         {:ok, encoding} <- resolve_encoding(opts) do
      native_decode(encoding, ids)
    end
  end

  @doc """
  Counts tokens for text using the selected encoding.

  With `allowed_special: []`, this uses the native crate's zero-allocation count path.

      iex> FastestTiktoken.count_tokens("表情符号是\\n🦜🔗", model: "gpt-4o")
      {:ok, 11}
  """
  @spec count_tokens(String.t(), selector_opts()) ::
          {:ok, non_neg_integer()} | {:error, reason()}
  def count_tokens(text, opts) when is_binary(text) do
    with {:ok, encoding, mode, allowed_special} <- normalize_call_opts(opts) do
      map_native_error(Native.count_tokens(encoding, text, mode, allowed_special))
    end
  end

  def count_tokens(_text, _opts), do: {:error, :invalid_text}

  @doc """
  Encodes text and decodes each token id back into a token piece.

  Some valid token ids do not decode to valid UTF-8 in isolation. In that case,
  this function returns `{:error, {:decode_failed, reason}}`.

      iex> FastestTiktoken.split_tokens("hello world", model: "gpt-4o")
      {:ok, ["hello", " world"]}
  """
  @spec split_tokens(String.t(), selector_opts()) :: {:ok, [String.t()]} | {:error, reason()}
  def split_tokens(text, opts) when is_binary(text) do
    with {:ok, encoding, mode, allowed_special} <- normalize_call_opts(opts) do
      case Native.split_tokens(encoding, text, mode, allowed_special) do
        {:ok, pieces} -> {:ok, pieces}
        {:error, reason} -> {:error, {:decode_failed, reason}}
      end
    end
  end

  def split_tokens(_text, _opts), do: {:error, :invalid_text}

  @doc """
  Encodes a batch of texts with the same selector options.

  The options are resolved once for the whole batch. An invalid or missing
  selector is therefore reported even when `texts` is empty. Processing stops
  at the first failing element and returns its error.

      iex> FastestTiktoken.encode_batch(["hello world"], encoding: :cl100k_base)
      {:ok, [[15339, 1917]]}
  """
  @spec encode_batch([String.t()], selector_opts()) ::
          {:ok, [[non_neg_integer()]]} | {:error, reason()}
  def encode_batch(texts, opts) when is_list(texts) do
    with {:ok, encoding, mode, allowed_special} <- normalize_call_opts(opts) do
      map_batch(texts, &encode_resolved(&1, encoding, mode, allowed_special))
    end
  end

  def encode_batch(_texts, _opts), do: {:error, :invalid_batch}

  @doc """
  Encodes a batch of texts while treating special token strings as ordinary text.

  As with `encode_batch/2`, the selector is resolved once for the whole batch.
  """
  @spec encode_ordinary_batch([String.t()], selector_opts()) ::
          {:ok, [[non_neg_integer()]]} | {:error, reason()}
  def encode_ordinary_batch(texts, opts) when is_list(texts) do
    opts
    |> put_allowed_special([])
    |> then(&encode_batch(texts, &1))
  end

  def encode_ordinary_batch(_texts, _opts), do: {:error, :invalid_batch}

  @doc """
  Decodes a batch of token-id lists with the same selector options.

  The selector is resolved once for the whole batch. An invalid or missing
  selector is therefore reported even when `batch` is empty. Processing stops
  at the first failing element and returns its error.

      iex> FastestTiktoken.decode_batch([[15339, 1917]], encoding: :cl100k_base)
      {:ok, ["hello world"]}
  """
  @spec decode_batch([[non_neg_integer()]], selector_opts()) ::
          {:ok, [String.t()]} | {:error, reason()}
  def decode_batch(batch, opts) when is_list(batch) do
    with {:ok, encoding} <- resolve_encoding(opts) do
      map_batch(batch, &decode_resolved(&1, encoding))
    end
  end

  def decode_batch(_batch, _opts), do: {:error, :invalid_batch}

  # -- Per-item work once the selector is known --------------------------------

  # Encodes one text with an already-resolved encoding and special-token mode.
  defp encode_resolved(text, encoding, mode, allowed_special) when is_binary(text),
    do: map_native_error(Native.encode(encoding, text, mode, allowed_special))

  defp encode_resolved(_text, _encoding, _mode, _allowed_special),
    do: {:error, :invalid_text}

  # Validates one token-id list and decodes it with an already-resolved encoding.
  defp decode_resolved(ids, encoding) do
    with {:ok, ids} <- normalize_token_ids(ids), do: native_decode(encoding, ids)
  end

  # Calls the native decoder and tags failures as {:decode_failed, reason}.
  defp native_decode(encoding, ids) do
    case Native.decode(encoding, ids) do
      {:ok, text} -> {:ok, text}
      {:error, reason} -> {:error, {:decode_failed, reason}}
    end
  end

  # -- Option handling ---------------------------------------------------------

  # Returns {:ok, encoding_name, special_mode, allowed_special_list}: everything
  # a native encode/count/split call needs.
  defp normalize_call_opts(opts) do
    with {:ok, opts} <- normalize_opts(opts),
         {:ok, encoding} <- resolve_encoding_from_opts(opts),
         {:ok, mode, allowed_special} <- resolve_allowed_special_from_opts(opts) do
      {:ok, encoding, mode, allowed_special}
    end
  end

  defp resolve_encoding(opts) do
    with {:ok, opts} <- normalize_opts(opts), do: resolve_encoding_from_opts(opts)
  end

  defp resolve_encoding_from_opts(opts) do
    case {get_opt(opts, :model), get_opt(opts, :encoding)} do
      {nil, nil} -> {:error, :missing_selector}
      {nil, encoding} -> resolve_encoding_name(encoding)
      {model, nil} -> resolve_model_encoding(model)
      {_model, _encoding} -> {:error, :ambiguous_selector}
    end
  end

  # Only reached when a :model is present (non-nil) and :encoding is absent.
  # Invalid model values ("" or non-string/non-atom) must produce
  # :invalid_model here rather than falling through to :ambiguous_selector.
  defp resolve_model_encoding(model) when is_atom(model),
    do: model |> Atom.to_string() |> resolve_model_encoding()

  defp resolve_model_encoding(model) when is_binary(model) do
    with {:ok, encoding} <- encoding_for_model(model) do
      resolve_encoding_name(encoding)
    end
  end

  defp resolve_model_encoding(_model), do: {:error, :invalid_model}

  # Maps a public encoding name to the native encoding name. Aliases are
  # translated without consulting the native side.
  defp resolve_encoding_name(encoding) when is_atom(encoding),
    do: encoding |> Atom.to_string() |> resolve_encoding_name()

  defp resolve_encoding_name(encoding) when is_binary(encoding) and encoding != "" do
    case Map.fetch(@encoding_aliases, encoding) do
      {:ok, native_name} ->
        {:ok, native_name}

      :error ->
        if Native.encoding_exists(encoding),
          do: {:ok, encoding},
          else: {:error, {:unsupported_encoding, encoding}}
    end
  end

  defp resolve_encoding_name(_encoding), do: {:error, {:unsupported_encoding, ""}}

  # Produces the special-token mode integer handed to Native:
  # 0 when no special tokens are allowed, 1 for :all, 2 for an explicit list.
  defp resolve_allowed_special_from_opts(opts) do
    opts
    |> get_opt(:allowed_special, [])
    |> normalize_allowed_special()
  end

  defp normalize_allowed_special(:all), do: {:ok, 1, []}
  defp normalize_allowed_special([]), do: {:ok, 0, []}

  defp normalize_allowed_special(set) when is_struct(set, MapSet),
    do: set |> MapSet.to_list() |> normalize_allowed_special()

  defp normalize_allowed_special(allowed) when is_list(allowed) do
    if Enum.all?(allowed, &is_binary/1),
      do: {:ok, 2, allowed},
      else: {:error, :invalid_allowed_special}
  end

  defp normalize_allowed_special(_other), do: {:error, :invalid_allowed_special}

  defp normalize_token_ids(ids) when is_list(ids) do
    if Enum.all?(ids, &valid_token_id?/1) do
      {:ok, ids}
    else
      {:error, :invalid_token_ids}
    end
  end

  defp normalize_token_ids(_ids), do: {:error, :invalid_token_ids}

  # Token ids must fit in an unsigned 32-bit integer.
  defp valid_token_id?(id), do: is_integer(id) and id >= 0 and id <= 4_294_967_295

  defp normalize_opts(opts) when is_list(opts) do
    if Keyword.keyword?(opts), do: {:ok, opts}, else: {:error, :invalid_opts}
  end

  defp normalize_opts(%{} = opts), do: {:ok, opts}
  defp normalize_opts(_opts), do: {:error, :invalid_opts}

  defp get_opt(opts, key, default \\ nil)
  defp get_opt(opts, key, default) when is_list(opts), do: Keyword.get(opts, key, default)
  defp get_opt(opts, key, default) when is_map(opts), do: Map.get(opts, key, default)

  # Leaves unsupported opts shapes untouched so normalize_opts/1 can reject them.
  defp put_allowed_special(opts, value) when is_list(opts),
    do: Keyword.put(opts, :allowed_special, value)

  defp put_allowed_special(%{} = opts, value), do: Map.put(opts, :allowed_special, value)
  defp put_allowed_special(opts, _value), do: opts

  defp map_native_error({:ok, value}), do: {:ok, value}
  defp map_native_error({:error, reason}), do: {:error, {:native_error, reason}}

  # Applies `fun` to each element in order, stopping at the first
  # {:error, reason}. On success the results keep the input order.
  defp map_batch(values, fun) do
    values
    |> Enum.reduce_while({:ok, []}, fn value, {:ok, acc} ->
      case fun.(value) do
        {:ok, mapped} -> {:cont, {:ok, [mapped | acc]}}
        {:error, reason} -> {:halt, {:error, reason}}
      end
    end)
    |> case do
      {:ok, mapped} -> {:ok, Enum.reverse(mapped)}
      {:error, reason} -> {:error, reason}
    end
  end

  # Exact model names take precedence; otherwise the first matching prefix in
  # @official_model_prefix_to_encoding wins.
  defp official_encoding_name_for_model(model) do
    case Map.fetch(@official_model_to_encoding, model) do
      {:ok, encoding} ->
        {:ok, encoding}

      :error ->
        Enum.find_value(@official_model_prefix_to_encoding, :error, fn {prefix, encoding} ->
          if String.starts_with?(model, prefix), do: {:ok, encoding}
        end)
    end
  end

  defp supported_encoding_name?(encoding) when is_binary(encoding),
    do: Map.has_key?(@encoding_aliases, encoding) or Native.encoding_exists(encoding)
end