lib/huggingface_client/hub/evaluation/evaluate.ex

defmodule HuggingfaceClient.Hub.Evaluate do
  @moduledoc """
  HuggingFace Evaluate — metrics computation API.

  Evaluate provides standardized metrics for NLP, vision, and other ML tasks.
  Metrics are hosted on the Hub and can be loaded by name.

  See: https://huggingface.co/docs/evaluate

  ## Available metric categories

  - **Text**: BLEU, ROUGE, BERTScore, SacreBLEU, WER, CER, METEOR, TER
  - **Classification**: Accuracy, F1, Precision, Recall, AUC, MCC
  - **Regression**: MSE, MAE, RMSE, R²
  - **Code**: CodeBLEU, pass@k
  - **Translation**: chrF, chrF++, TER

  ## Example

      # Compute BLEU score
      {:ok, result} = HuggingfaceClient.compute_metric("bleu",
        predictions: ["the cat is on the mat"],
        references: [["the cat sat on the mat"]]
      )
      IO.puts("BLEU: \#{result["bleu"]}")

      # Compute accuracy
      {:ok, result} = HuggingfaceClient.compute_metric("accuracy",
        predictions: [0, 1, 0, 1],
        references: [0, 1, 1, 1]
      )
      IO.puts("Accuracy: \#{result["accuracy"]}")

      # Multiple metrics at once
      {:ok, results} = HuggingfaceClient.evaluate_model(
        model: "bert-base-uncased",
        dataset: "glue",
        subset: "mrpc",
        split: "validation",
        metrics: ["accuracy", "f1"],
        access_token: "hf_..."
      )
  """

  alias HuggingfaceClient.Error.InputError
  alias HuggingfaceClient.Hub.Client

  # Base URL for server-side metric computation (`compute/2`) and model evaluation
  # (`evaluate_model/1`). NOTE(review): confirm this endpoint is actually served.
  @evaluate_api "https://api.huggingface.co/evaluate"

  # Maximum number of metric requests `compute_multiple/1` runs concurrently.
  @max_concurrency 5

  # Per-metric timeout (ms) applied by `compute_multiple/1`.
  @metric_timeout 60_000

  # ── Metric info ───────────────────────────────────────────────────────────────

  @doc """
  Lists all available metrics on the Hub.

  ## Options

  - `:search` — filter by name
  - `:type` — filter by type: `"metric"`, `"comparison"`, `"measurement"`
  - `:access_token`

  A successful response body that is not a list is wrapped into one
  (`nil` becomes `[]`).

  ## Example

      {:ok, metrics} = HuggingfaceClient.list_metrics()
      Enum.each(metrics, fn m -> IO.puts("\#{m["id"]}: \#{m["description"]}") end)
  """
  @spec list_metrics(keyword()) :: {:ok, [map()]} | {:error, Exception.t()}
  def list_metrics(opts \\ []) do
    # Only filters with a truthy value are sent.
    query = Enum.filter([{"type", opts[:type]}, {"search", opts[:search]}], fn {_k, v} -> v end)

    base = "#{Client.hub_url()}/api/metrics"
    url = if query == [], do: base, else: base <> "?" <> URI.encode_query(query)

    url
    |> Client.get(opts[:access_token], opts)
    |> wrap_list()
  end

  @doc """
  Gets detailed information about a specific metric.

  `metric_name` may be namespaced (e.g. `"org/metric"`); each path segment is
  percent-encoded before being placed in the URL.

  ## Options

  - `:access_token`

  ## Example

      {:ok, info} = HuggingfaceClient.metric_info("bleu")
      IO.puts("Description: \#{info["description"]}")
      IO.puts("Reference: \#{info["reference_urls"]}")
  """
  @spec metric_info(String.t(), keyword()) :: {:ok, map()} | {:error, Exception.t()}
  def metric_info(metric_name, opts \\ []) do
    url = "#{Client.hub_url()}/api/metrics/#{encode_path(metric_name)}"
    Client.get(url, opts[:access_token], opts)
  end

  # ── Compute metrics ───────────────────────────────────────────────────────────

  @doc """
  Computes `metric` given predictions and references.

  This calls the HuggingFace Evaluate API to compute the metric server-side.

  ## Parameters

  - `metric` — metric name, e.g. `"bleu"`, `"rouge"`, `"accuracy"`

  ## Options

  - `:predictions` — list of model predictions (required)
  - `:references` — list of ground truth references (required)
  - `:kwargs` — additional metric-specific parameters, sent as `"kwargs"` when truthy
  - `:access_token`

  Raises `HuggingfaceClient.Error.InputError` when `:predictions` or
  `:references` is missing.

  ## Examples

      # ROUGE score
      {:ok, result} = HuggingfaceClient.compute_metric("rouge",
        predictions: ["The cat sat on the mat"],
        references: ["The cat is on the mat"]
      )
      IO.inspect(result["rouge1"])

      # BLEU
      {:ok, result} = HuggingfaceClient.compute_metric("bleu",
        predictions: ["hello world"],
        references: [["hello world", "hi world"]]
      )

      # WER (word error rate) for ASR
      {:ok, result} = HuggingfaceClient.compute_metric("wer",
        predictions: ["it is raining"],
        references: ["it is raining"]
      )

      # F1 for classification
      {:ok, result} = HuggingfaceClient.compute_metric("f1",
        predictions: [0, 1, 0, 1],
        references: [0, 1, 1, 1],
        kwargs: %{"average" => "macro"}
      )
  """
  @spec compute(String.t(), keyword()) :: {:ok, map()} | {:error, Exception.t()}
  def compute(metric, opts) do
    predictions = fetch_required!(opts, :predictions)
    references = fetch_required!(opts, :references)

    body = %{"metric" => metric, "predictions" => predictions, "references" => references}
    body = if opts[:kwargs], do: Map.put(body, "kwargs", opts[:kwargs]), else: body

    Client.post("#{@evaluate_api}/compute", body, opts[:access_token], opts)
  end

  @doc """
  Computes multiple metrics at once, concurrently.

  ## Options

  - `:metrics` — list of metric names (required; a single name is also accepted)
  - `:predictions` — list of model predictions (required)
  - `:references` — list of ground truth references (required)
  - `:kwargs` — forwarded to every metric (see `compute/2`)
  - `:access_token`

  Always returns `{:ok, map}` keyed by metric name. A metric whose computation
  returns `{:error, _}`, exceeds the per-metric timeout, or exits is mapped to
  `nil` instead of failing the whole call. Raises
  `HuggingfaceClient.Error.InputError` when a required option is missing.

  ## Example

      {:ok, results} = HuggingfaceClient.compute_metrics(
        metrics: ["rouge", "bleu"],
        predictions: ["The cat sat"],
        references: ["The cat is on the mat"]
      )
      IO.inspect(results)
  """
  @spec compute_multiple(keyword()) :: {:ok, map()} | {:error, Exception.t()}
  def compute_multiple(opts) do
    metrics = opts |> fetch_required!(:metrics) |> List.wrap()
    # Validate in the caller so a missing option raises here, not once per spawned task.
    fetch_required!(opts, :predictions)
    fetch_required!(opts, :references)

    outcomes =
      Task.async_stream(metrics, &compute_or_nil(&1, opts),
        max_concurrency: @max_concurrency,
        timeout: @metric_timeout,
        # With the default (:exit) one slow metric would exit the calling process.
        on_timeout: :kill_task
      )

    # async_stream emits results in input order, so they pair up with `metrics`.
    results =
      metrics
      |> Enum.zip(outcomes)
      |> Map.new(fn
        {metric, {:ok, result}} -> {metric, result}
        {metric, {:exit, _reason}} -> {metric, nil}
      end)

    {:ok, results}
  end

  @doc """
  Evaluates a model on a dataset using specified metrics.

  Runs inference + metric computation server-side via HF API.

  ## Options

  - `:model` — HF model ID (required)
  - `:dataset` — HF dataset ID (required)
  - `:subset` — dataset configuration/subset
  - `:split` — dataset split (default: `"test"`)
  - `:metrics` — list of metric names (required)
  - `:task` — task type (e.g. `"text-classification"`)
  - `:access_token`

  Raises `HuggingfaceClient.Error.InputError` when a required option is missing.

  ## Example

      {:ok, results} = HuggingfaceClient.evaluate_model(
        model: "distilbert-base-uncased-finetuned-sst-2-english",
        dataset: "glue",
        subset: "sst2",
        split: "validation",
        metrics: ["accuracy", "f1"],
        task: "text-classification",
        access_token: "hf_..."
      )
      IO.puts("Accuracy: \#{results["accuracy"]}")
  """
  @spec evaluate_model(keyword()) :: {:ok, map()} | {:error, Exception.t()}
  def evaluate_model(opts) do
    model = fetch_required!(opts, :model)
    dataset = fetch_required!(opts, :dataset)
    metrics = fetch_required!(opts, :metrics)

    body =
      %{
        "model" => model,
        "dataset" => dataset,
        "metrics" => metrics,
        "split" => opts[:split] || "test"
      }
      |> put_opt("subset", opts[:subset])
      |> put_opt("task", opts[:task])

    Client.post("#{@evaluate_api}/evaluate", body, opts[:access_token], opts)
  end

  @doc """
  Gets a list of standard benchmark results for a model.

  A successful response body that is not a list is wrapped into one
  (`nil` becomes `[]`).

  ## Example

      {:ok, benchmarks} = HuggingfaceClient.model_benchmarks("meta-llama/Llama-3.1-8B-Instruct")
      Enum.each(benchmarks, fn b ->
        IO.puts("\#{b["benchmark"]}: \#{b["score"]}")
      end)
  """
  @spec model_benchmarks(String.t(), keyword()) :: {:ok, [map()]} | {:error, Exception.t()}
  def model_benchmarks(model_id, opts \\ []) do
    "#{Client.hub_url()}/api/models/#{encode_path(model_id)}/results"
    |> Client.get(opts[:access_token], opts)
    |> wrap_list()
  end

  # ── Private ───────────────────────────────────────────────────────────────────

  # Returns `opts[key]`, raising InputError (":<key> is required") when it is nil or false.
  defp fetch_required!(opts, key) do
    opts[key] || raise InputError, ":#{key} is required"
  end

  # Computes one metric for `compute_multiple/1`, collapsing `{:error, _}` to `nil`.
  defp compute_or_nil(metric, opts) do
    case compute(metric, opts) do
      {:ok, result} -> result
      {:error, _reason} -> nil
    end
  end

  # Normalizes a successful body to a list (a list is returned as-is, `nil` becomes `[]`,
  # anything else becomes a one-element list); error tuples pass through unchanged.
  defp wrap_list({:ok, body}), do: {:ok, List.wrap(body)}
  defp wrap_list(error), do: error

  # Percent-encodes each `/`-separated segment of a Hub identifier so characters such as
  # spaces, `?` or `#` cannot alter the request URL, while `org/name` keeps its slash.
  defp encode_path(id) do
    id
    |> to_string()
    |> String.split("/")
    |> Enum.map_join("/", &URI.encode(&1, fn char -> URI.char_unreserved?(char) end))
  end

  defp put_opt(map, _, nil), do: map
  defp put_opt(map, key, val), do: Map.put(map, key, val)
end