lib/huggingface_client.ex

defmodule HuggingfaceClient do
  @moduledoc """
  Top-level public API for the HuggingFace Elixir client.

  Every inference function takes a `HuggingfaceClient.Client` struct as its first
  argument (shared defaults across calls — see `client/2` and `endpoint_client/3`)
  and a map or keyword list of per-call arguments as its second. The per-call
  arguments are combined with the client's defaults by
  `HuggingfaceClient.Client.merge_opts/2`. Passing anything other than a `Client`
  struct as the first argument returns
  `{:error, %HuggingfaceClient.Error.InputError{}}`.

  ## Quick start

      # Build a reusable client
      client = HuggingfaceClient.client("hf_your_token")

      # Chat completion
      {:ok, resp} = HuggingfaceClient.chat_completion(client, %{
        model:    "meta-llama/Llama-3-8B-Instruct",
        messages: [%{role: "user", content: "Tell me a joke."}]
      })

      # Streaming
      {:ok, stream} = HuggingfaceClient.chat_completion_stream(client, %{
        model:    "meta-llama/Llama-3-8B-Instruct",
        messages: [%{role: "user", content: "Count to five."}]
      })
      HuggingfaceClient.StreamHelpers.each_content(stream, &IO.write/1)

      # Text-to-image
      {:ok, png_bytes} = HuggingfaceClient.text_to_image(client, %{
        model:  "stabilityai/stable-diffusion-xl-base-1.0",
        inputs: "A cat wearing a space suit"
      })

  """

  alias HuggingfaceClient.Client
  alias HuggingfaceClient.Error.InputError
  alias HuggingfaceClient.Error.ProviderApiError
  alias HuggingfaceClient.Inference.Config, as: InferenceConfig
  alias HuggingfaceClient.Inference.ResponseHandler
  alias HuggingfaceClient.ProviderMapping
  alias HuggingfaceClient.ProviderRegistry
  alias HuggingfaceClient.SSE

  # Receive timeout (ms) handed to Req for one-shot inference requests.
  @request_timeout 120_000
  # Receive timeout (ms) handed to Req for streaming inference requests.
  @stream_receive_timeout 300_000
  # How long (ms) the stream consumer waits for the next chunk before halting.
  @stream_idle_timeout 120_000
  # Receive timeout (ms) for the Hub provider-mapping lookup.
  @hub_timeout 10_000

  # ── Client constructors ────────────────────────────────────────────────────────

  @doc """
  Creates a reusable `Client` struct.

  ## Options

  - `:provider` — inference provider, e.g. `"groq"`, `"together"`. Defaults to auto-routing.
  - `:bill_to` — org slug billed for router requests.
  - `:endpoint_url` — custom inference endpoint URL.
  - `:retry_on_503` — whether to retry once on 503 (default `true`).
  - `:req_opts` — extra keyword opts forwarded to `Req`.

  Raises `InputError` on unknown or mistyped options.
  """
  @spec client(String.t() | nil, keyword()) :: Client.t()
  def client(token, opts \\ []) do
    Client.new(token, opts)
  end

  @doc """
  Creates a `Client` bound to a specific inference `endpoint_url`.
  """
  @spec endpoint_client(String.t() | nil, String.t(), keyword()) :: Client.t()
  def endpoint_client(token, url, opts \\ []) do
    Client.endpoint(token, url, opts)
  end

  # ── Inference tasks ────────────────────────────────────────────────────────────

  @doc """
  Sends a chat-completion request. Returns `{:ok, response_map}` or `{:error, exception}`.
  """
  @spec chat_completion(Client.t() | map(), map()) ::
          {:ok, map()} | {:error, Exception.t()}
  def chat_completion(client, args) do
    run_task(client, args, "conversational", stream: false)
  end

  @doc """
  Sends a streaming chat-completion request.

  Returns `{:ok, stream}` where `stream` is a lazy enumerable of decoded
  `chat.completion.chunk` maps. Use `HuggingfaceClient.StreamHelpers` to consume it.

  `args` may be a map or a keyword list; `stream: true` is set on it automatically.

  The HTTP request is only issued when the stream is enumerated, and it runs in a
  task owned by the enumerating process, so the stream may be handed to another
  process and consumed there. Enumerating the stream again issues a new request.
  While enumerating:

    * a transport failure or a non-2xx HTTP status raises
      `HuggingfaceClient.Error.ProviderApiError`;
    * if no chunk arrives for 120 seconds the stream halts.
  """
  @spec chat_completion_stream(Client.t() | map(), map()) ::
          {:ok, Enumerable.t()} | {:error, Exception.t()}
  def chat_completion_stream(client, args) do
    # Normalise first so keyword-list args work like maps (Map.put/3 needs a map).
    stream_args = args |> normalize_args() |> Map.put(:stream, true)
    run_task(client, stream_args, "conversational", stream: true)
  end

  @doc "Generates an image from a text prompt. Returns `{:ok, binary}` (PNG/JPEG bytes)."
  @spec text_to_image(Client.t() | map(), map()) ::
          {:ok, binary()} | {:error, Exception.t()}
  def text_to_image(client, args) do
    run_task(client, args, "text-to-image", stream: false)
  end

  @doc "Extracts feature vectors / embeddings."
  @spec feature_extraction(Client.t() | map(), map()) ::
          {:ok, list()} | {:error, Exception.t()}
  def feature_extraction(client, args) do
    run_task(client, args, "feature-extraction", stream: false)
  end

  @doc "Transcribes audio. Pass raw audio bytes as `:inputs`."
  @spec automatic_speech_recognition(Client.t() | map(), map()) ::
          {:ok, map()} | {:error, Exception.t()}
  def automatic_speech_recognition(client, args) do
    run_task(client, args, "automatic-speech-recognition", stream: false)
  end

  @doc "Runs text generation (non-chat)."
  @spec text_generation(Client.t() | map(), map()) ::
          {:ok, map()} | {:error, Exception.t()}
  def text_generation(client, args) do
    run_task(client, args, "text-generation", stream: false)
  end

  @doc "Classifies text into provided `candidate_labels`."
  @spec zero_shot_classification(Client.t() | map(), map()) ::
          {:ok, list()} | {:error, Exception.t()}
  def zero_shot_classification(client, args) do
    run_task(client, args, "zero-shot-classification", stream: false)
  end

  @doc "Classifies images."
  @spec image_classification(Client.t() | map(), map()) ::
          {:ok, list()} | {:error, Exception.t()}
  def image_classification(client, args) do
    run_task(client, args, "image-classification", stream: false)
  end

  @doc "Summarises text."
  @spec summarization(Client.t() | map(), map()) ::
          {:ok, map()} | {:error, Exception.t()}
  def summarization(client, args) do
    run_task(client, args, "summarization", stream: false)
  end

  @doc "Translates text."
  @spec translation(Client.t() | map(), map()) ::
          {:ok, map()} | {:error, Exception.t()}
  def translation(client, args) do
    run_task(client, args, "translation", stream: false)
  end

  @doc "Scores sentence similarity (bi-encoder)."
  @spec sentence_similarity(Client.t() | map(), map()) ::
          {:ok, list()} | {:error, Exception.t()}
  def sentence_similarity(client, args) do
    run_task(client, args, "sentence-similarity", stream: false)
  end

  # ── Core dispatch ──────────────────────────────────────────────────────────────

  # Merges per-call `args` into the client's defaults and runs the task. Only
  # `Client` structs are accepted as the first argument.
  defp run_task(%Client{} = client, args, default_task, run_opts) do
    opts = Client.merge_opts(client, normalize_args(args))
    do_run(opts, default_task, run_opts)
  end

  defp run_task(_non_client, _args, _default_task, _run_opts) do
    {:error, InputError.exception("First argument must be a HuggingfaceClient.Client struct.")}
  end

  # Resolves the provider (optionally via the Hub mapping), looks up the provider
  # module for the resolved task, builds the shared params map, then performs
  # either a streaming or a one-shot request. The first non-matching `with`
  # step's value is returned as-is.
  defp do_run(opts, default_task, run_opts) do
    streaming? = Keyword.get(run_opts, :stream, false)
    provider_opt = opts[:provider]
    model = opts[:model]
    endpoint_url = opts[:endpoint_url]
    task = opts[:task] || default_task

    with {:ok, provider} <- ProviderMapping.resolve_provider(provider_opt, model, endpoint_url),
         {:ok, resolved_provider, resolved_task} <-
           maybe_fetch_mapping(provider, model, task, opts),
         {:ok, provider_mod} <-
           ProviderRegistry.get(effective_provider(resolved_provider, provider_opt), resolved_task) do
      params = build_params(opts, model, resolved_task, resolved_provider, provider_mod)

      if streaming? do
        do_stream(provider_mod, params, opts)
      else
        do_request(provider_mod, params, opts)
      end
    end
  end

  # When provider is "auto" and we have a model, fetch the model's provider mapping
  # from the Hub (and cache it via ProviderMapping.store_mappings/1). Only mappings
  # for the requested task are considered, so the task is never silently switched.
  # Falls back to "hf-inference" if the Hub is unreachable, returns anything other
  # than a 200 with an `inferenceProviderMapping`, or has no mapping for the task.
  defp maybe_fetch_mapping("auto", model, task, opts) when is_binary(model) do
    hub_url = HuggingfaceClient.Hub.Client.hub_url()
    token = opts[:access_token] || HuggingfaceClient.Config.token()
    req_opts = Keyword.get(opts, :req_opts, [])

    url = "#{hub_url}/api/models/#{model}"
    headers = if token, do: %{"authorization" => "Bearer #{token}"}, else: %{}

    request_opts =
      [finch: HuggingfaceClient.Finch, headers: headers, receive_timeout: @hub_timeout] ++
        req_opts

    case Req.get(url, request_opts) do
      {:ok, %{status: 200, body: %{"inferenceProviderMapping" => raw_mapping}}} ->
        mappings = ProviderMapping.normalise_mapping(raw_mapping, model)
        ProviderMapping.store_mappings(mappings)

        case pick_mapping(mappings, task) do
          nil -> {:ok, "hf-inference", task}
          best -> {:ok, best["provider"], best["task"] || task}
        end

      _unreachable_or_unmapped ->
        {:ok, "hf-inference", task}
    end
  end

  defp maybe_fetch_mapping(provider, _model, task, _opts), do: {:ok, provider, task}

  # Among the mappings for `task` (all mappings when `task` is nil), prefers one
  # whose "status" is "live", otherwise the first one; nil when none match.
  defp pick_mapping(mappings, task) do
    for_task = Enum.filter(mappings, &(is_nil(task) or &1["task"] == task))
    Enum.find(for_task, &(&1["status"] == "live")) || List.first(for_task)
  end

  # When the provider is still "auto", use "hf-inference" for the registry lookup.
  defp effective_provider("auto", _explicit), do: "hf-inference"
  defp effective_provider(resolved, _), do: resolved

  # Builds the params map handed to the provider module callbacks. The token comes
  # from the call options or, failing that, HuggingfaceClient.Config.token/0; tokens
  # starting with "hf_" are classed as :hf_token, any other binary as :provider_key.
  # NOTE(review): the whole merged option set (including :access_token and
  # :req_opts) becomes `args`; presumably provider modules pick what they need.
  defp build_params(opts, model, task, _provider, _provider_mod) do
    token = opts[:access_token] || HuggingfaceClient.Config.token()

    auth_method =
      cond do
        is_binary(token) and String.starts_with?(token, "hf_") -> :hf_token
        is_binary(token) -> :provider_key
        true -> :none
      end

    %{
      args: normalize_map(opts),
      model: model,
      task: task,
      access_token: token,
      auth_method: auth_method,
      endpoint_url: opts[:endpoint_url],
      bill_to: opts[:bill_to],
      retry_on_503: Keyword.get(opts, :retry_on_503, true),
      req_opts: Keyword.get(opts, :req_opts, []),
      url_transform: nil,
      mapping: nil,
      output_type: nil
    }
  end

  # Performs a one-shot request wrapped in a `[:huggingface_client, :request]`
  # telemetry span; stop metadata adds the HTTP status (0 when unknown).
  defp do_request(provider_mod, params, _opts) do
    url = provider_mod.make_url(params)
    payload = provider_mod.prepare_payload(params)
    headers = provider_mod.prepare_headers(params, is_binary(payload))
    metadata = %{provider: provider_mod.provider_id(), task: params.task, model: params.model}

    :telemetry.span([:huggingface_client, :request], metadata, fn ->
      result = execute_request(url, payload, headers, params, provider_mod)
      {result, Map.put(metadata, :status, telemetry_status(result))}
    end)
  end

  defp telemetry_status({:ok, _}), do: 200

  defp telemetry_status({:error, %{http_response: %{status: status}}}) when is_integer(status),
    do: status

  defp telemetry_status(_other), do: 0

  # POSTs the payload (raw body when binary, JSON otherwise). A 503 is retried
  # exactly once, immediately, when `params.retry_on_503` is truthy.
  defp execute_request(url, payload, headers, params, provider_mod) do
    options =
      request_options(headers, @request_timeout, params.req_opts) ++ body_option(payload)

    retry? = params.retry_on_503

    result =
      case Req.post(url, options) do
        {:ok, %{status: 503}} when retry? -> Req.post(url, options)
        other -> other
      end

    handle_response(result, url, payload, headers, params, provider_mod)
  end

  defp handle_response({:ok, %{status: 200, body: body}}, _url, _payload, _headers, params, provider_mod) do
    provider_mod.get_response(body, params)
    |> ResponseHandler.resolve(params)
  end

  defp handle_response({:ok, %{status: status, body: body}}, url, payload, headers, _params, _provider_mod) do
    redacted = redact_auth(headers)

    {:error,
     ProviderApiError.exception({
       "Provider returned HTTP #{status}",
       %{url: url, method: "POST", headers: redacted, body: sanitize_body(payload)},
       %{request_id: nil, status: status, body: body}
     })}
  end

  defp handle_response({:error, reason}, url, payload, headers, _params, _provider_mod) do
    redacted = redact_auth(headers)

    {:error,
     ProviderApiError.exception({
       "HTTP request failed: #{inspect(reason)}",
       %{url: url, method: "POST", headers: redacted, body: sanitize_body(payload)},
       %{request_id: nil, status: nil, body: nil}
     })}
  end

  # Returns `{:ok, stream}` immediately; the HTTP request itself is started lazily
  # by Stream.resource/3 in whichever process enumerates the stream.
  defp do_stream(provider_mod, params, _opts) do
    url = provider_mod.make_url(params)
    payload = provider_mod.prepare_payload(params)
    headers = provider_mod.prepare_headers(params, is_binary(payload))

    options =
      request_options(headers, @stream_receive_timeout, params.req_opts) ++
        body_option(payload)

    # Converts a failed final Req result into the exception handle_response/6
    # builds for the non-streaming path.
    to_error = fn failure ->
      {:error, exception} = handle_response(failure, url, payload, headers, params, provider_mod)
      exception
    end

    stream =
      Stream.resource(
        fn -> start_stream_task(url, options) end,
        &next_stream_chunk(&1, to_error),
        &stop_stream_task/1
      )
      |> SSE.parse_stream_json()

    {:ok, stream}
  end

  # Starts the request in a task owned by the current (consuming) process. Body
  # chunks of a 2xx response are forwarded as `{ref, {:chunk, binary}}`; a non-2xx
  # body is kept on the response instead so it can be reported. The final Req
  # result is sent last as `{ref, {:done, result}}`.
  defp start_stream_task(url, options) do
    ref = make_ref()
    consumer = self()

    into_fun = fn {:data, chunk}, {req, resp} ->
      if resp.status in 200..299 do
        send(consumer, {ref, {:chunk, chunk}})
        {:cont, {req, resp}}
      else
        error_body = if is_binary(resp.body), do: resp.body <> chunk, else: chunk
        {:cont, {req, %{resp | body: error_body}}}
      end
    end

    task =
      Task.async(fn ->
        result = Req.post(url, options ++ [into: into_fun])
        send(consumer, {ref, {:done, result}})
      end)

    {ref, task}
  end

  # Emits the next raw chunk, halts on a 2xx completion or after
  # @stream_idle_timeout ms of silence, and raises ProviderApiError on failure.
  defp next_stream_chunk({ref, _task} = state, to_error) do
    receive do
      {^ref, {:chunk, chunk}} -> {[chunk], state}
      {^ref, {:done, {:ok, %{status: status}}}} when status in 200..299 -> {:halt, state}
      {^ref, {:done, failure}} -> raise to_error.(failure)
    after
      @stream_idle_timeout -> {:halt, state}
    end
  end

  # Kills the request task (a no-op if it already finished) and drains any chunk
  # messages it sent that were never consumed, e.g. after an early halt.
  defp stop_stream_task({ref, task}) do
    Task.shutdown(task, :brutal_kill)
    flush_stream_messages(ref)
  end

  defp flush_stream_messages(ref) do
    receive do
      {^ref, _message} -> flush_stream_messages(ref)
    after
      0 -> :ok
    end
  end

  # ── Helpers ────────────────────────────────────────────────────────────────────

  # Common Req options; caller-supplied `req_opts` are appended after the defaults.
  defp request_options(headers, receive_timeout, req_opts) do
    [
      finch: HuggingfaceClient.Finch,
      headers: headers,
      receive_timeout: receive_timeout,
      retry: :never,
      user_agent: InferenceConfig.user_agent()
    ] ++ req_opts
  end

  # Binary payloads are sent as the raw body; anything else is JSON-encoded by Req.
  defp body_option(payload) when is_binary(payload), do: [body: payload]
  defp body_option(payload), do: [json: payload]

  defp normalize_args(args) when is_map(args), do: args
  defp normalize_args(args) when is_list(args), do: Map.new(args)

  defp normalize_map(keyword) when is_list(keyword), do: Map.new(keyword)
  defp normalize_map(map) when is_map(map), do: map

  # Masks the authorization header (matched case-insensitively) in a header map
  # or list of `{name, value}` tuples before it is attached to an exception.
  defp redact_auth(headers) when is_map(headers) and not is_struct(headers),
    do: Map.new(headers, &redact_header/1)

  defp redact_auth(headers) when is_list(headers), do: Enum.map(headers, &redact_header/1)

  defp redact_auth(headers), do: headers

  defp redact_header({name, value}) do
    if auth_header?(name), do: {name, "Bearer [redacted]"}, else: {name, value}
  end

  defp redact_header(other), do: other

  defp auth_header?(name) when is_binary(name), do: String.downcase(name) == "authorization"
  defp auth_header?(name) when is_atom(name), do: auth_header?(Atom.to_string(name))
  defp auth_header?(_name), do: false

  # Replaces binary payloads larger than 512 bytes with a size placeholder.
  defp sanitize_body(body) when is_binary(body) and byte_size(body) > 512,
    do: "<<binary, #{byte_size(body)} bytes>>"

  defp sanitize_body(body), do: body
end