defmodule HuggingfaceClient do
@moduledoc """
Top-level public API for the HuggingFace Elixir client.
All inference functions take a `HuggingfaceClient.Client` struct (built with
`client/2` or `endpoint_client/3`) as their first argument; per-call arguments
are passed as a map (or keyword list) in the second argument.
## Quick start
# Build a reusable client
client = HuggingfaceClient.client("hf_your_token")
# Chat completion
{:ok, resp} = HuggingfaceClient.chat_completion(client, %{
model: "meta-llama/Llama-3-8B-Instruct",
messages: [%{role: "user", content: "Tell me a joke."}]
})
# Streaming
{:ok, stream} = HuggingfaceClient.chat_completion_stream(client, %{
model: "meta-llama/Llama-3-8B-Instruct",
messages: [%{role: "user", content: "Count to five."}]
})
HuggingfaceClient.StreamHelpers.each_content(stream, &IO.write/1)
# Text-to-image
{:ok, png_bytes} = HuggingfaceClient.text_to_image(client, %{
model: "stabilityai/stable-diffusion-xl-base-1.0",
inputs: "A cat wearing a space suit"
})
"""
alias HuggingfaceClient.Client
alias HuggingfaceClient.Error.InputError
alias HuggingfaceClient.Error.ProviderApiError
alias HuggingfaceClient.Inference.Config, as: InferenceConfig
alias HuggingfaceClient.Inference.ResponseHandler
alias HuggingfaceClient.ProviderMapping
alias HuggingfaceClient.ProviderRegistry
alias HuggingfaceClient.SSE
# ── Client constructors ────────────────────────────────────────────────────────
@doc """
Builds a reusable `Client` struct from an access token and per-client options.

## Options
- `:provider` — inference provider, e.g. `"groq"`, `"together"`. Defaults to auto-routing.
- `:bill_to` — org slug billed for router requests.
- `:endpoint_url` — custom inference endpoint URL.
- `:retry_on_503` — whether to retry once on 503 (default `true`).
- `:req_opts` — extra keyword opts forwarded to `Req`.

Raises `InputError` on unknown or mistyped options.
"""
@spec client(String.t() | nil, keyword()) :: Client.t()
def client(token, opts \\ []), do: Client.new(token, opts)
@doc """
Builds a `Client` that always talks to the given inference `endpoint_url`.
"""
@spec endpoint_client(String.t() | nil, String.t(), keyword()) :: Client.t()
def endpoint_client(token, url, opts \\ []), do: Client.endpoint(token, url, opts)
# ── Inference tasks ────────────────────────────────────────────────────────────
@doc """
Runs a non-streaming chat-completion request.

Returns `{:ok, response_map}` on success or `{:error, exception}` on failure.
"""
@spec chat_completion(Client.t() | map(), map()) ::
        {:ok, map()} | {:error, Exception.t()}
def chat_completion(client, args), do: run_task(client, args, "conversational", stream: false)
@doc """
Sends a streaming chat-completion request.

Returns `{:ok, stream}` where `stream` is a lazy enumerable of decoded
`chat.completion.chunk` maps. Use `HuggingfaceClient.StreamHelpers` to consume it.
"""
@spec chat_completion_stream(Client.t() | map(), map() | keyword()) ::
        {:ok, Enumerable.t()} | {:error, Exception.t()}
def chat_completion_stream(client, args) do
  # Normalize before forcing :stream — `args` may be a keyword list (the
  # other task functions accept one via normalize_args/1), and Map.put/3
  # raises BadMapError on a list.
  args =
    args
    |> normalize_args()
    |> Map.put(:stream, true)

  run_task(client, args, "conversational", stream: true)
end
@doc "Renders an image from a text prompt. Returns `{:ok, binary}` (PNG/JPEG bytes)."
@spec text_to_image(Client.t() | map(), map()) ::
        {:ok, binary()} | {:error, Exception.t()}
def text_to_image(client, args), do: run_task(client, args, "text-to-image", stream: false)
@doc "Computes feature vectors / embeddings for the given inputs."
@spec feature_extraction(Client.t() | map(), map()) ::
        {:ok, list()} | {:error, Exception.t()}
def feature_extraction(client, args), do: run_task(client, args, "feature-extraction", stream: false)
@doc "Transcribes audio to text. Pass raw audio bytes as `:inputs`."
@spec automatic_speech_recognition(Client.t() | map(), map()) ::
        {:ok, map()} | {:error, Exception.t()}
def automatic_speech_recognition(client, args),
  do: run_task(client, args, "automatic-speech-recognition", stream: false)
@doc "Plain (non-chat) text generation."
@spec text_generation(Client.t() | map(), map()) ::
        {:ok, map()} | {:error, Exception.t()}
def text_generation(client, args), do: run_task(client, args, "text-generation", stream: false)
@doc "Scores text against the provided `candidate_labels` (zero-shot)."
@spec zero_shot_classification(Client.t() | map(), map()) ::
        {:ok, list()} | {:error, Exception.t()}
def zero_shot_classification(client, args),
  do: run_task(client, args, "zero-shot-classification", stream: false)
@doc "Runs image classification."
@spec image_classification(Client.t() | map(), map()) ::
        {:ok, list()} | {:error, Exception.t()}
def image_classification(client, args), do: run_task(client, args, "image-classification", stream: false)
@doc "Produces a summary of the input text."
@spec summarization(Client.t() | map(), map()) ::
        {:ok, map()} | {:error, Exception.t()}
def summarization(client, args), do: run_task(client, args, "summarization", stream: false)
@doc "Translates the input text."
@spec translation(Client.t() | map(), map()) ::
        {:ok, map()} | {:error, Exception.t()}
def translation(client, args), do: run_task(client, args, "translation", stream: false)
@doc "Computes sentence-similarity / bi-encoder scores."
@spec sentence_similarity(Client.t() | map(), map()) ::
        {:ok, list()} | {:error, Exception.t()}
def sentence_similarity(client, args), do: run_task(client, args, "sentence-similarity", stream: false)
# ── Core dispatch ──────────────────────────────────────────────────────────────
# Entry point for every task function: merges per-call args over the client's
# defaults and hands off to do_run/3.
defp run_task(%Client{} = client, args, default_task, run_opts) do
  client
  |> Client.merge_opts(normalize_args(args))
  |> do_run(default_task, run_opts)
end

# Anything that is not a Client struct is rejected up front.
defp run_task(_other, _args, _default_task, _run_opts),
  do: {:error, InputError.exception("First argument must be a HuggingfaceClient.Client struct.")}
# Core dispatch: resolve provider + task, look up the provider module, build
# the request parameters, then issue either a blocking or streaming request.
# Any failing `with` step returns its `{:error, _}` tuple unchanged.
defp do_run(opts, default_task, run_opts) do
  stream? = Keyword.get(run_opts, :stream, false)
  requested_provider = opts[:provider]
  model = opts[:model]
  task = opts[:task] || default_task

  with {:ok, provider} <-
         ProviderMapping.resolve_provider(requested_provider, model, opts[:endpoint_url]),
       {:ok, mapped_provider, mapped_task} <-
         maybe_fetch_mapping(provider, model, task, opts),
       {:ok, provider_mod} <-
         ProviderRegistry.get(effective_provider(mapped_provider, requested_provider), mapped_task) do
    params = build_params(opts, model, mapped_task, mapped_provider, provider_mod)

    case stream? do
      true -> do_stream(provider_mod, params, opts)
      false -> do_request(provider_mod, params, opts)
    end
  end
end
# When provider is "auto" and we have a model, try to fetch the mapping from Hub.
# Falls back to "hf-inference" if the Hub is unreachable or no mapping is found.
#
# Returns {:ok, provider, task}. Note that on success the provider/task pair
# comes from the Hub's `inferenceProviderMapping` payload; the chosen entry's
# task overrides the requested one when present.
defp maybe_fetch_mapping("auto", model, task, opts) when is_binary(model) do
hub_url = HuggingfaceClient.Hub.Client.hub_url()
# Per-call token wins over the globally configured one.
token = opts[:access_token] || HuggingfaceClient.Config.token()
req_opts = Keyword.get(opts, :req_opts, [])
url = "#{hub_url}/api/models/#{model}"
# Token is optional: public models can be resolved without authentication.
headers = if token, do: %{"authorization" => "Bearer #{token}"}, else: %{}
case Req.get(url, [finch: HuggingfaceClient.Finch, headers: headers, receive_timeout: 10_000] ++ req_opts) do
{:ok, %{status: 200, body: %{"inferenceProviderMapping" => raw_mapping}}} ->
mappings = ProviderMapping.normalise_mapping(raw_mapping, model)
# Cache the mappings for later lookups (side effect).
ProviderMapping.store_mappings(mappings)
# Pick first live mapping that matches the requested task (or any if nil task)
best =
Enum.find(mappings, fn m ->
m["status"] == "live" and (is_nil(task) or m["task"] == task)
end) || List.first(mappings)
# NOTE(review): the `|| List.first(mappings)` fallback can select a
# non-live or task-mismatched mapping — presumably intentional as a
# best-effort choice; confirm.
if best do
{:ok, best["provider"], best["task"] || task}
else
# Empty mapping list: route through the default HF inference backend.
{:ok, "hf-inference", task}
end
# Any non-200 status, missing mapping key, or transport error falls back.
_ ->
{:ok, "hf-inference", task}
end
end
# Explicit (non-"auto") providers pass through untouched.
defp maybe_fetch_mapping(provider, _model, task, _opts), do: {:ok, provider, task}
# The registry has no "auto" entry: an unresolved "auto" provider is looked up
# as "hf-inference"; everything else is used as-is.
defp effective_provider(provider, _explicit) do
  case provider do
    "auto" -> "hf-inference"
    other -> other
  end
end
# Builds the provider-agnostic parameter map handed to provider modules.
#
# Fix: `provider_mod` was bound but never used, emitting a compiler warning —
# it is now underscored (callers are unaffected; same arity/order).
defp build_params(opts, model, task, _provider, _provider_mod) do
  token = opts[:access_token] || HuggingfaceClient.Config.token()

  # Official HF tokens start with "hf_"; any other binary is assumed to be a
  # provider-issued API key; nil means unauthenticated.
  auth_method =
    cond do
      is_binary(token) and String.starts_with?(token, "hf_") -> :hf_token
      is_binary(token) -> :provider_key
      true -> :none
    end

  %{
    args: normalize_map(opts),
    model: model,
    task: task,
    access_token: token,
    auth_method: auth_method,
    endpoint_url: opts[:endpoint_url],
    bill_to: opts[:bill_to],
    retry_on_503: Keyword.get(opts, :retry_on_503, true),
    req_opts: Keyword.get(opts, :req_opts, []),
    # Filled in by provider modules / later pipeline stages, not here.
    url_transform: nil,
    mapping: nil,
    output_type: nil
  }
end
# Performs a blocking inference request, wrapped in a :telemetry.span/3 so
# callers can observe latency and outcome. The span function must return
# {result, stop_metadata}; `result` is what do_request/3 itself returns.
defp do_request(provider_mod, params, opts) do
url = provider_mod.make_url(params)
payload = provider_mod.prepare_payload(params)
# Binary payloads (e.g. raw audio/image bytes) need different content headers
# than JSON payloads.
binary? = is_binary(payload)
headers = provider_mod.prepare_headers(params, binary?)
:telemetry.span(
[:huggingface_client, :request],
%{provider: provider_mod.provider_id(), task: params.task, model: params.model},
fn ->
result = execute_request(url, payload, headers, params, provider_mod)
# Derive an HTTP status for telemetry: 200 on success, the error's
# status when available, 0 otherwise (e.g. transport failures).
# NOTE(review): assumes the error exception exposes `http_response.status`
# — confirm against ProviderApiError's struct fields.
status =
case result do
{:ok, _} -> 200
{:error, %{http_response: %{status: s}}} -> s
_ -> 0
end
{result,
%{provider: provider_mod.provider_id(), task: params.task, model: params.model, status: status}}
end
)
end
# Issues the POST (binary body or JSON) and funnels the outcome through
# handle_response/6. Retries exactly once when the provider answers 503 and
# `retry_on_503` is enabled.
#
# Fix: the Req.post body/json branch was duplicated for the initial attempt
# and the retry; it is now built once as a closure.
defp execute_request(url, payload, headers, params, provider_mod) do
  retry? = params.retry_on_503

  req_base =
    [
      finch: HuggingfaceClient.Finch,
      headers: headers,
      receive_timeout: 120_000,
      # Req's own retry machinery is disabled; retries are handled here.
      retry: :never,
      user_agent: InferenceConfig.user_agent()
    ] ++ params.req_opts

  # Binary payloads go out verbatim; everything else is JSON-encoded by Req.
  body_opt = if is_binary(payload), do: [body: payload], else: [json: payload]
  post = fn -> Req.post(url, req_base ++ body_opt) end

  case post.() do
    {:ok, %{status: 503}} when retry? ->
      # Single retry on 503 (service temporarily unavailable).
      post.()
      |> handle_response(url, payload, headers, params, provider_mod)

    other ->
      handle_response(other, url, payload, headers, params, provider_mod)
  end
end
# 200 OK: let the provider module decode the body, then post-process it.
defp handle_response({:ok, %{status: 200, body: body}}, _url, _payload, _headers, params, provider_mod) do
  body
  |> provider_mod.get_response(params)
  |> ResponseHandler.resolve(params)
end

# Non-200 HTTP status: wrap status + body in a ProviderApiError, with the
# (auth-redacted) request attached for debugging.
defp handle_response({:ok, %{status: status, body: body}}, url, payload, headers, _params, _provider_mod) do
  request = %{url: url, method: "POST", headers: redact_auth(headers), body: sanitize_body(payload)}
  response = %{request_id: nil, status: status, body: body}
  {:error, ProviderApiError.exception({"Provider returned HTTP #{status}", request, response})}
end

# Transport-level failure (no HTTP response at all).
defp handle_response({:error, reason}, url, payload, headers, _params, _provider_mod) do
  request = %{url: url, method: "POST", headers: redact_auth(headers), body: sanitize_body(payload)}
  response = %{request_id: nil, status: nil, body: nil}
  {:error, ProviderApiError.exception({"HTTP request failed: #{inspect(reason)}", request, response})}
end
# Streaming request: a background Task POSTs with Req's `:into` collector and
# forwards each raw chunk to the calling process; a Stream.resource re-emits
# those chunks lazily and SSE.parse_stream_json/1 decodes them into maps.
#
# NOTE(review): chunks are sent to `caller = self()` captured at creation
# time, so the returned stream must be consumed in the same process that
# called do_stream/3 — confirm this is documented for callers.
defp do_stream(provider_mod, params, _opts) do
url = provider_mod.make_url(params)
payload = provider_mod.prepare_payload(params)
binary? = is_binary(payload)
headers = provider_mod.prepare_headers(params, binary?)
req_opts = params.req_opts
req_base =
[
finch: HuggingfaceClient.Finch,
headers: headers,
receive_timeout: 300_000,
retry: :never,
user_agent: InferenceConfig.user_agent()
] ++ req_opts
# Unique ref ties messages from this request to this stream instance.
ref = make_ref()
caller = self()
task =
Task.async(fn ->
# Relay every received body chunk to the caller as it arrives.
into_fun = fn {:data, chunk}, acc ->
send(caller, {ref, {:chunk, chunk}})
{:cont, acc}
end
# NOTE(review): the Req.post result is discarded — an HTTP error or
# transport failure here just ends the stream via :done with no error
# surfaced to the consumer; consider propagating it.
Req.post(url, req_base ++ if(binary?, do: [body: payload], else: [json: payload]) ++ [into: into_fun])
send(caller, {ref, :done})
end)
stream =
Stream.resource(
fn -> {ref, task} end,
fn {r, t} = state ->
receive do
{^r, {:chunk, chunk}} -> {[chunk], state}
{^r, :done} -> {:halt, state}
after
# Per-chunk inactivity timeout. NOTE(review): 120s here is shorter
# than the 300s :receive_timeout above — confirm the mismatch is
# intentional.
120_000 ->
Task.shutdown(t, :brutal_kill)
{:halt, state}
end
end,
# Cleanup on early termination (e.g. Enum.take/2 on the stream).
fn {_r, t} -> Task.shutdown(t, :brutal_kill) end
)
|> SSE.parse_stream_json()
{:ok, stream}
end
# ── Helpers ────────────────────────────────────────────────────────────────────
# Per-call arguments may arrive as a keyword list or a map; always yield a map.
defp normalize_args(list) when is_list(list), do: Map.new(list)
defp normalize_args(map) when is_map(map), do: map
# Coerces a keyword list into a map; maps pass through untouched.
defp normalize_map(map) when is_map(map), do: map
defp normalize_map(list) when is_list(list), do: Map.new(list)
# Masks the authorization header value before it is attached to error
# structs, so tokens never leak into logs or exception messages.
defp redact_auth(%{"authorization" => _} = headers),
  do: %{headers | "authorization" => "Bearer [redacted]"}

# No auth header (or not a map): nothing to redact.
defp redact_auth(headers), do: headers
# Keeps error structs small: large binary payloads (e.g. images/audio) are
# replaced with a size placeholder; everything else passes through.
defp sanitize_body(body) do
  if is_binary(body) and byte_size(body) > 512 do
    "<<binary, #{byte_size(body)} bytes>>"
  else
    body
  end
end
end