Skip to main content

lib/pixir/provider/cache.ex

defmodule Pixir.Provider.Cache do
  @moduledoc """
  Prompt-cache key helpers for the OpenAI Responses dialect.

  A cache key is a routing hint for a stable prefix family. It must be short and safe:
  no raw workspace paths, user text, request ids, timestamps, emails, or secrets. The
  Provider still combines this hint with its own prompt-prefix hash; Pixir treats the key
  as optimization metadata, never as durable state.
  """

  @max_key_bytes 96
  @hash_bytes 8

  # The Prompt Contract version (ADR 0020): the leading key segment names the
  # instructions/input layering this request family was built under. Bump it on
  # every prompt-contract change so an intentional fleet-wide cache break is
  # attributable in provider_usage evidence — never an unexplained hit-rate drop.
  # px1 = workspace path in the first instructions sentence. px2 = byte-stable
  # Layer 0 (discovery rule + checkpoint contract), late developer context.
  # px3 = Skill index rendered as routing-only metadata with when_to_use fields.
  @prompt_contract_version "px3"

  @doc "The current Prompt Contract version segment (leads every cache key)."
  @spec prompt_contract_version() :: String.t()
  def prompt_contract_version, do: @prompt_contract_version

  @doc """
  Build safe prompt-cache metadata for one Provider call.

  Expected fields are `:session_id`, `:model`, `:mode`, `:tools`, and `:skill_index`;
  `:fork_root_session_id` is optional and defaults to `:session_id` (a fork passes its
  fork-tree ROOT so the whole tree shares one cache family — ADR 0020).
  Returns a string-keyed map so it can be copied into a `provider_usage` Event.
  """
  @spec metadata(map()) :: {:ok, map()} | {:error, :invalid_args}
  def metadata(
        %{
          session_id: sid,
          model: model,
          mode: mode,
          tools: tools,
          skill_index: skill_index
        } = input
      ) do
    fork_root = normalize_fork_root(Map.get(input, :fork_root_session_id), sid)

    with {:ok, toolset_hash} <- stable_hash(tools),
         {:ok, skill_index_hash} <- stable_hash(skill_index),
         {:ok, session_family_hash} <- stable_hash(fork_root) do
      key =
        [
          @prompt_contract_version,
          "m_" <> slug(model, 14),
          "r_" <> slug(to_string(mode), 6),
          "s_" <> session_family_hash,
          "t_" <> toolset_hash,
          "k_" <> skill_index_hash
        ]
        |> Enum.join(":")
        |> byte_slice(@max_key_bytes)

      {:ok,
       %{
         "prompt_cache_key" => key,
         "prompt_contract_version" => @prompt_contract_version,
         "toolset_hash" => toolset_hash,
         "skill_index_hash" => skill_index_hash,
         "session_family_hash" => session_family_hash
       }}
    else
      # Never leak a raw exception from stable_hash — the spec promises
      # {:error, :invalid_args}, and callers degrade on that atom (a struct here
      # crashed the Turn task via to_string/1 before this clause existed).
      {:error, _hash_failure} -> {:error, :invalid_args}
    end
  end

  def metadata(_input), do: {:error, :invalid_args}

  defp normalize_fork_root(root, _sid) when is_binary(root) and root != "", do: root
  defp normalize_fork_root(_root, sid), do: sid

  @doc "Return a stable short hash for maps, lists, and scalar values."
  @spec stable_hash(term()) :: {:ok, String.t()} | {:error, term()}
  def stable_hash(term) do
    {:ok, stable_hash_raw(term)}
  rescue
    error -> {:error, error}
  end

  defp stable_hash_raw(term) do
    term
    |> stable_term()
    |> :erlang.term_to_binary()
    |> then(&:crypto.hash(:sha256, &1))
    |> Base.encode16(case: :lower)
    |> binary_part(0, @hash_bytes * 2)
  end

  defp stable_term(map) when is_map(map) do
    map
    |> Enum.map(fn {key, value} -> {to_string(key), stable_term(value)} end)
    |> Enum.sort_by(fn {key, _value} -> key end)
  end

  defp stable_term(list) when is_list(list), do: Enum.map(list, &stable_term/1)
  defp stable_term(value), do: value

  defp slug(value, max) do
    value
    |> to_string()
    |> String.replace(~r/[^A-Za-z0-9_.-]+/, "-")
    |> String.trim("-")
    |> case do
      "" -> "unknown"
      safe -> safe
    end
    |> byte_slice(max)
  end

  defp byte_slice(text, max) when byte_size(text) <= max, do: text

  defp byte_slice(text, max) do
    text
    |> binary_part(0, max)
    |> String.split("", trim: true)
    |> Enum.join()
  rescue
    ArgumentError ->
      text
      |> String.to_charlist()
      |> take_bytes(max, [], 0)
      |> List.to_string()
  end

  defp take_bytes(_chars, max, acc, bytes) when bytes >= max, do: Enum.reverse(acc)
  defp take_bytes([], _max, acc, _bytes), do: Enum.reverse(acc)

  defp take_bytes([char | rest], max, acc, bytes) do
    char_bytes = char |> List.wrap() |> List.to_string() |> byte_size()

    if bytes + char_bytes <= max do
      take_bytes(rest, max, [char | acc], bytes + char_bytes)
    else
      Enum.reverse(acc)
    end
  end
end