defmodule Relyra do
@moduledoc """
Public entry points for strict-by-default SAML protocol flows.
"""
alias Relyra.ConnectionResolver
alias Relyra.Error
alias Relyra.Protocol.AuthnRequest
alias Relyra.Protocol.Binding
alias Relyra.Protocol.ValidationPipeline
alias Relyra.RequestStore
alias Relyra.Security.RelayState
@default_request_intent_ttl_seconds 300
@doc """
Starts an SP-initiated login for `connection` over the redirect binding.

The work is delegated to the internal login step and wrapped in a `[:login]`
telemetry span. The span metadata carries the `:connection_id`,
`:organization_id` and `:provider_preset` read from `connection`, the flow and
binding, and finally the `:outcome` (plus `:error_code` on failure).

Per the spec, the result is `{:ok, map}` or `{:error, %Relyra.Error{}}`.
"""
@spec start_login(map(), map(), keyword()) :: {:ok, map()} | {:error, Relyra.Error.t()}
def start_login(connection, relay_context, opts \\ []) do
  connection_fields =
    Map.new([:connection_id, :organization_id, :provider_preset], fn field ->
      {field, read_field(connection, field)}
    end)

  span_metadata = Map.merge(connection_fields, %{flow: :sp_initiated, binding: :redirect})

  Relyra.Telemetry.span([:login], span_metadata, fn ->
    case do_start_login(connection, relay_context, opts) do
      {:ok, _payload} = success ->
        {success, Map.put(span_metadata, :outcome, :ok)}

      {:error, %Error{type: error_code}} = failure ->
        {failure, Map.merge(span_metadata, %{outcome: :error, error_code: error_code})}
    end
  end)
end
# Builds the AuthnRequest, issues a RelayState bound to the request id, persists
# the request intent and encodes the redirect parameters, all inside an
# `[:authn_request]` telemetry span.
#
# The span callback returns `{result, metadata}`; metadata always carries the
# `:outcome`, and the request-store latency is attached whenever the store was
# reached and its result decided the outcome (or the whole flow succeeded).
defp do_start_login(connection, relay_context, opts) do
  base_metadata = %{
    flow: :sp_initiated,
    binding: :redirect,
    connection_id: read_field(connection, :connection_id),
    organization_id: read_field(connection, :organization_id),
    provider_preset: read_field(connection, :provider_preset)
  }

  # Shapes the span return value for any failed step; `extra_metadata` lets a
  # caller attach step-specific measurements.
  failed = fn %Error{type: error_code} = error, extra_metadata ->
    failure_metadata =
      base_metadata
      |> Map.merge(%{outcome: :error, error_code: error_code})
      |> Map.merge(extra_metadata)

    {{:error, error}, failure_metadata}
  end

  Relyra.Telemetry.span([:authn_request], base_metadata, fn ->
    with {:ok, request_fields} <- AuthnRequest.build(connection, relay_context, opts),
         request_id = Map.fetch!(request_fields, :id),
         xml = AuthnRequest.to_xml(request_fields),
         {:ok, relay_state} <-
           RelayState.issue(Map.put(relay_context, :request_id, request_id), opts) do
      issued_at = intent_issued_at(opts)
      expires_at = intent_expires_at(issued_at, opts)
      intent = build_request_intent(request_id, relay_state, connection, issued_at, expires_at)

      # Time only the store write itself.
      store_started_at = System.monotonic_time()
      store_outcome = persist_request_intent(relay_state, intent, opts)
      store_latency_ms = duration_ms(store_started_at)

      case store_outcome do
        {:error, %Error{} = error} ->
          failed.(error, %{request_store_latency_ms: store_latency_ms})

        :ok ->
          case Binding.encode_redirect(xml, relay_state) do
            {:error, %Error{} = error} ->
              failed.(error, %{})

            {:ok, redirect_params} ->
              encoded_request = Map.get(redirect_params, "SAMLRequest") || ""

              login = %{
                request_id: request_id,
                authn_request: xml,
                relay_state: relay_state,
                redirect_params: redirect_params
              }

              success_metadata =
                Map.merge(base_metadata, %{
                  outcome: :ok,
                  xml_bytes: byte_size(xml),
                  base64_bytes: byte_size(encoded_request),
                  request_store_latency_ms: store_latency_ms
                })

              {{:ok, login}, success_metadata}
          end
      end
    else
      {:error, %Error{} = error} ->
        failed.(error, %{})
    end
  end)
end
@doc """
Validates and consumes a SAML response.

`request_intent_or_opts` is either a stored request intent map (with the
options passed in `opts`) or a keyword list of options from which the intent is
looked up.

The call runs inside a `[:response, :consume]` telemetry span. On success the
span metadata carries the `:connection_id` of the authenticated principal; on
failure it carries the `:error_code`.

Raised exceptions, throws and exits are not propagated: they are converted into
an `:internal_protocol_error` `Relyra.Error`, so the function returns
`{:ok, login_result}` or `{:error, %Relyra.Error{}}`.
"""
@spec consume_response(binary(), map() | keyword(), keyword()) ::
        {:ok, map()} | {:error, Relyra.Error.t()}
def consume_response(response_payload, request_intent_or_opts, opts \\ []) do
  metadata = %{
    flow: :sp_initiated
  }

  Relyra.Telemetry.span([:response, :consume], metadata, fn ->
    try do
      result = do_consume_response(response_payload, request_intent_or_opts, opts)

      case result do
        {:ok, login_result} ->
          # The normalized login result keeps the connection id on its nested
          # principal, so read it from there first; the top-level lookup is kept
          # as a fallback for results that expose it directly.
          connection_id =
            case login_result do
              %{principal: %{connection_id: id}} when not is_nil(id) -> id
              _ -> read_field(login_result, :connection_id)
            end

          final_metadata =
            Map.merge(metadata, %{
              outcome: :ok,
              connection_id: connection_id
            })

          {{:ok, login_result}, final_metadata}

        {:error, %Error{} = error} ->
          {{:error, error}, Map.merge(metadata, %{outcome: :error, error_code: error.type})}
      end
    rescue
      exception ->
        # Boundary handler: callers always get a tagged error instead of a crash.
        error =
          Error.new(
            :internal_protocol_error,
            "consume_response/3 raised an unexpected exception",
            %{stage: :consume_response, reason: Exception.message(exception)}
          )

        {{:error, error}, Map.merge(metadata, %{outcome: :error, error_code: error.type})}
    catch
      kind, reason ->
        # Throws and exits are reported the same way as raised exceptions.
        error =
          Error.new(
            :internal_protocol_error,
            "consume_response/3 trapped a non-local exit",
            %{stage: :consume_response, kind: kind, reason: inspect(reason)}
          )

        {{:error, error}, Map.merge(metadata, %{outcome: :error, error_code: error.type})}
    end
  end)
end
# Runs the response-consumption steps in order: resolve the request intent,
# check RelayState, intent shape and expiry, resolve the connection, run the
# validation pipeline, then consume the replay key and the request intent before
# normalizing the result.
#
# The `:now` override is read from the *resolved* options. When the caller uses
# the keyword-list form (`consume_response(payload, opts)`), the third argument
# is empty and the options live in the second argument, so reading `:now` from
# the third argument would silently ignore the override.
defp do_consume_response(response_payload, request_intent_or_opts, opts) do
  with {:ok, request_intent, consume_opts} <-
         resolve_request_intent(request_intent_or_opts, opts),
       # Only call the clock when no override is supplied.
       now = Keyword.get_lazy(consume_opts, :now, &DateTime.utc_now/0),
       :ok <- validate_relay_state_opt(consume_opts, request_intent),
       :ok <- validate_request_intent(request_intent, consume_opts),
       :ok <- validate_request_intent_expiry(request_intent, now),
       {:ok, connection} <- resolve_connection_context(request_intent, consume_opts),
       {:ok, result_map} <-
         ValidationPipeline.run(response_payload, request_intent, connection, consume_opts),
       :ok <- consume_replay_key(result_map, connection, consume_opts),
       :ok <- consume_request_intent(request_intent, consume_opts),
       {:ok, login_result} <- normalize_consume_result(result_map) do
    {:ok, login_result}
  else
    # Only well-formed errors pass through; any other shape raises a
    # WithClauseError instead of being returned to the caller.
    {:error, %Error{} = error} ->
      {:error, error}
  end
end
# Accepts an intent only while its `:expires_at` is strictly later than `now`.
# Anything without an `:expires_at` key (including a nil intent) passes.
defp validate_request_intent_expiry(%{expires_at: raw_expiry}, now) do
  deadline = maybe_parse_iso8601(raw_expiry)

  if DateTime.compare(deadline, now) == :gt do
    :ok
  else
    {:error, Error.new(:request_intent_expired, "Request intent has expired")}
  end
end

defp validate_request_intent_expiry(_intent, _now), do: :ok
# Coerces a stored timestamp into a DateTime. DateTime values pass through and
# ISO 8601 strings are parsed; anything else (or an unparseable string) becomes
# the Unix epoch.
defp maybe_parse_iso8601(%DateTime{} = datetime), do: datetime

defp maybe_parse_iso8601(value) when is_binary(value) do
  with {:ok, parsed, _utc_offset} <- DateTime.from_iso8601(value) do
    parsed
  else
    _unparseable -> DateTime.from_unix!(0)
  end
end

defp maybe_parse_iso8601(_other), do: DateTime.from_unix!(0)
# Converts the validation pipeline's result map into a `Relyra.LoginResult`
# wrapping a `Relyra.Principal`. Keys absent from `result` become nil.
defp normalize_consume_result(result) when is_map(result) do
  value_of = fn key -> Map.get(result, key) end

  {:ok,
   %Relyra.LoginResult{
     principal: %Relyra.Principal{
       name_id: value_of.(:name_id),
       name_id_format: value_of.(:name_id_format),
       session_index: value_of.(:session_index),
       attributes: value_of.(:attributes),
       connection_id: value_of.(:connection_id)
     },
     connection: value_of.(:connection),
     relay_state: value_of.(:relay_state),
     issuer: value_of.(:issuer),
     in_response_to: value_of.(:in_response_to),
     return_to: value_of.(:return_to)
   }}
end
# Normalizes the two calling conventions of `consume_response/3` into
# `{:ok, intent_or_nil, opts}`.
#
#   * intent map + opts: the intent is used as given; `:relay_state` must be
#     present in `opts`.
#   * keyword list (+ optional extra opts): the intent is fetched from the
#     request store by `:relay_state`. A nil intent is returned when there is no
#     RelayState or the store has no usable entry.
defp resolve_request_intent(request_intent, opts) when is_map(request_intent) do
  if Keyword.get(opts, :relay_state) do
    {:ok, request_intent, opts}
  else
    {:error, Error.new(:relay_state_missing, "RelayState is required in opts")}
  end
end

defp resolve_request_intent(opts, []) when is_list(opts) do
  relay_state = Keyword.get(opts, :relay_state)

  if relay_state do
    case RequestStore.fetch_intent(relay_state, opts) do
      {:ok, intent} ->
        {:ok, intent, opts}

      {:error, %Error{type: :adapter_not_configured} = error} ->
        {:error, error}

      {:error, _} ->
        # If not found in store, we might be in an IdP-initiated flow.
        # We return nil intent and let the pipeline decide based on connection config.
        # NOTE(review): every other store error is also treated as "not found"
        # here - confirm that is intended for transient store failures.
        {:ok, nil, opts}
    end
  else
    # No relay_state and no intent map provided.
    # We return nil intent and assume connection will be provided in opts.
    {:ok, nil, opts}
  end
end

# Keyword-list form with additional options in the third argument. The public
# spec allows this combination; merge both lists (the explicit third argument
# wins on conflicts) and resolve as the plain keyword form.
defp resolve_request_intent(primary_opts, extra_opts)
     when is_list(primary_opts) and is_list(extra_opts) do
  primary_opts
  |> Keyword.merge(extra_opts)
  |> resolve_request_intent([])
end
# Checks that a stored intent carries every field needed to validate a response.
# A nil intent (no stored request) is accepted as-is.
defp validate_request_intent(nil, _opts), do: :ok

defp validate_request_intent(intent, _opts) do
  absent =
    for field <- [:request_id, :sp_entity_id, :acs_url],
        Map.get(intent, field) == nil,
        do: field

  case absent do
    [] ->
      :ok

    _ ->
      {:error,
       Error.new(:request_intent_invalid, "Stored request intent is missing required fields", %{
         missing: absent
       })}
  end
end
# Compares the RelayState supplied in `opts` with the one stored on the intent.
# A mismatch is reported only when both values are present and differ; a missing
# value on either side is accepted.
defp validate_relay_state_opt(opts, request_intent) do
  with actual when not is_nil(actual) <- Keyword.get(opts, :relay_state),
       expected when not is_nil(expected) and expected != actual <-
         Map.get(request_intent || %{}, :relay_state) do
    {:error,
     Error.new(
       :relay_state_mismatch,
       "Provided relay_state does not match stored intent",
       %{expected: expected, actual: actual}
     )}
  else
    _no_conflict -> :ok
  end
end
# Picks the connection used to validate the response.
#
# A map passed as `:connection` (or `:resolved_connection`) in `opts` is used
# directly. Otherwise the connection is resolved through `ConnectionResolver`
# from the connection and organization ids recorded on the request intent.
#
# The intent may be nil (no stored request, e.g. an IdP-initiated flow). In that
# case both ids are nil and the resolver decides the outcome, instead of this
# function crashing on the field lookup.
defp resolve_connection_context(request_intent, opts) do
  case Keyword.get(opts, :connection) || Keyword.get(opts, :resolved_connection) do
    connection when is_map(connection) ->
      {:ok, connection}

    _ ->
      intent = request_intent || %{}

      request_context = %{
        connection_id: read_field(intent, :connection_id),
        organization_id: read_field(intent, :organization_id)
      }

      ConnectionResolver.resolve_connection(request_context, opts)
  end
end
# Derives the replay key from the validated result (connection id, issuer and
# signed XML id) and hands it to the replay store together with descriptive
# metadata. The connection argument is currently unused.
defp consume_replay_key(login_result, _connection, opts) do
  [connection_id, issuer, assertion_id] =
    Enum.map([:connection_id, :issuer, :signed_xml_id], &Map.get(login_result, &1))

  replay_metadata = %{connection_id: connection_id, issuer: issuer, assertion_id: assertion_id}

  connection_id
  |> build_replay_key(issuer, assertion_id)
  |> Relyra.ReplayStore.consume_replay_key(replay_metadata, opts)
end
# Marks the stored intent as consumed, identified by its RelayState and request
# id. Nothing to do when there is no intent.
defp consume_request_intent(nil, _opts), do: :ok

defp consume_request_intent(request_intent, opts) do
  request_intent
  |> Map.get(:relay_state)
  |> RequestStore.consume_intent(Map.get(request_intent, :request_id), opts)
end
# Stores the request intent under its RelayState via the configured request store.
defp persist_request_intent(relay_state, intent, opts),
  do: RequestStore.put_intent(relay_state, intent, opts)
# Issue time for a new request intent: the `:now` option when given, otherwise
# the current UTC time.
defp intent_issued_at(opts) do
  case Keyword.fetch(opts, :now) do
    {:ok, now} -> now
    :error -> DateTime.utc_now()
  end
end
# Expiry for a request intent: `issued_at` plus the `:ttl_seconds` option, or the
# module default TTL when the option is absent.
defp intent_expires_at(issued_at, opts) do
  lifetime_seconds =
    case Keyword.fetch(opts, :ttl_seconds) do
      {:ok, seconds} -> seconds
      :error -> @default_request_intent_ttl_seconds
    end

  DateTime.add(issued_at, lifetime_seconds, :second)
end
# Assembles the intent stored for an outgoing AuthnRequest: the request and
# RelayState identifiers, the validity window, and the connection details later
# needed to validate the response. `:connection_id` falls back to the
# connection's `:id`, and `:sp_entity_id` to its `:issuer`.
defp build_request_intent(request_id, relay_state, connection, issued_at, expires_at) do
  from_connection = &read_field(connection, &1)

  connection_details = %{
    connection_id: from_connection.(:connection_id) || from_connection.(:id),
    organization_id: from_connection.(:organization_id),
    sp_entity_id: from_connection.(:sp_entity_id) || from_connection.(:issuer),
    acs_url: from_connection.(:acs_url)
  }

  Map.merge(connection_details, %{
    request_id: request_id,
    relay_state: relay_state,
    issued_at: issued_at,
    expires_at: expires_at
  })
end
# Deterministic replay-detection key: the three identifiers rendered as strings
# and joined with ":" (nil renders as an empty segment).
defp build_replay_key(connection_id, issuer, signed_xml_id) do
  Enum.join([connection_id, issuer, signed_xml_id], ":")
end
# Milliseconds elapsed since `start_time`, a `System.monotonic_time/0` reading in
# native units.
defp duration_ms(start_time) do
  elapsed_native = System.monotonic_time() - start_time
  System.convert_time_unit(elapsed_native, :native, :millisecond)
end
# Reads `key` from a map that may use atom or string keys: the atom key is tried
# first, then its string form. A nil or false value under the atom key is
# treated like a missing one.
#
# A nil container yields nil. Callers in this module may pass an absent request
# intent, and a lookup on "nothing" should read as "no value" rather than raise.
defp read_field(nil, key) when is_atom(key), do: nil

defp read_field(map, key) when is_map(map) and is_atom(key) do
  Map.get(map, key) || Map.get(map, Atom.to_string(key))
end
end