Skip to main content

src/auth/livery_auth.erl

-module(livery_auth).
-moduledoc """
JWT verification against a JWK set.

Verifies compact-serialization JSON Web Tokens signed with RS256
or ES256, then validates the registered claims (`exp`, `nbf`,
`iss`, `aud`). Signature verification and key handling use the OTP
`public_key` and `crypto` modules; no third-party crypto is
pulled in.

The JWK set is supplied by the caller. OIDC discovery and live
JWKS rotation over HTTP are a thin layer that can sit on top of
this module (a follow-up); keeping verification network-free makes
it cheap to test and embed.

```erlang
{ok, Claims} = livery_auth:verify(Token, #{
    keys     => JwkList,
    issuer   => <<"https://issuer.example">>,
    audience => <<"my-api">>
}).
```

A JWK is a map with binary keys, e.g. for RSA:
`#{<<"kty">> => <<"RSA">>, <<"kid">> => _, <<"n">> => _, <<"e">> => _}`
and for EC P-256:
`#{<<"kty">> => <<"EC">>, <<"crv">> => <<"P-256">>, <<"x">> => _, <<"y">> => _}`.
""".

-include_lib("public_key/include/public_key.hrl").

-export([verify/2, tls_opts/0]).

-export_type([jwk/0, verify_opts/0, claims/0, error_reason/0]).

-type jwk() :: #{binary() => binary()}.
-type claims() :: #{binary() => term()}.

-type verify_opts() :: #{
    keys := [jwk()],
    issuer => binary() | undefined,
    audience => binary() | [binary()] | undefined,
    now => non_neg_integer(),
    leeway => non_neg_integer()
}.

-type error_reason() ::
    malformed
    | invalid_json
    | {unsupported_alg, binary()}
    | no_matching_key
    | bad_signature
    | expired
    | not_yet_valid
    | {issuer_mismatch, binary()}
    | audience_mismatch.

%%====================================================================
%% Public API
%%====================================================================

-doc """
Check a compact-serialization JWT and hand back its claims.

The token is split on `.` into header, payload and signature
segments. A token that does not have exactly three segments is
rejected with `{error, malformed}`. Otherwise decoding, key lookup,
signature verification and claim validation are delegated to the
internal pipeline, and its result is returned unchanged.
""".
-spec verify(binary(), verify_opts()) ->
    {ok, claims()} | {error, error_reason()}.
verify(Token, Opts) when is_binary(Token) ->
    case split(Token) of
        error ->
            {error, malformed};
        {ok, HeaderSeg, PayloadSeg, SigSeg, SignedPart} ->
            with_decoded(HeaderSeg, PayloadSeg, SigSeg, SignedPart, Opts)
    end.

-doc """
Client-side TLS options that authenticate an HTTPS peer.

Intended for the fetchers that retrieve OIDC metadata, JWK sets or
introspection results: the signing keys are the trust root for every
token, so the channel they arrive over must itself be verified.

The returned option list enables peer verification, trusts the CA
certificates returned by `public_key:cacerts_get/0`, and installs the
`https` hostname match fun from `public_key`.
""".
-spec tls_opts() -> [ssl:tls_client_option()].
tls_opts() ->
    TrustedCerts = public_key:cacerts_get(),
    HostnameMatch = public_key:pkix_verify_hostname_match_fun(https),
    [
        {verify, verify_peer},
        {cacerts, TrustedCerts},
        {customize_hostname_check, [{match_fun, HostnameMatch}]},
        %% Generous certification-path depth limit.
        {depth, 99}
    ].

%%====================================================================
%% Internals
%%====================================================================

%% Break a compact JWT into its three dot-separated segments.
%%
%% Returns `{ok, Header, Payload, Signature, SigningInput}` where
%% `SigningInput` is `Header.Payload` (the bytes the signature covers),
%% or `error` when the token does not consist of exactly three segments.
-spec split(binary()) ->
    {ok, binary(), binary(), binary(), binary()} | error.
split(Token) ->
    Segments = binary:split(Token, <<".">>, [global]),
    case Segments of
        [Header, Payload, Signature] ->
            SigningInput = <<Header/binary, $., Payload/binary>>,
            {ok, Header, Payload, Signature, SigningInput};
        _WrongSegmentCount ->
            error
    end.

%% Decode the three base64url segments and continue with verification.
%%
%% Header and payload must base64url-decode and parse as JSON; the
%% signature only needs to base64url-decode. Any failure among the
%% three is reported as `{error, invalid_json}`.
with_decoded(HeaderB64, PayloadB64, SigB64, SigningInput, Opts) ->
    HeaderResult = decode_json(HeaderB64),
    ClaimsResult = decode_json(PayloadB64),
    SigResult = b64url(SigB64),
    case [HeaderResult, ClaimsResult, SigResult] of
        [{ok, Header}, {ok, Claims}, {ok, Sig}] ->
            verify_decoded(Header, Claims, Sig, SigningInput, Opts);
        _AnyFailure ->
            {error, invalid_json}
    end.

%% Select a key for the decoded header, verify the signature, and only
%% then validate the claims.
%%
%% `alg` and `kid` are read from the header (both default to
%% `undefined`); candidate keys come from the `keys` option.
verify_decoded(Header, Claims, Sig, SigningInput, Opts) ->
    Alg = maps:get(<<"alg">>, Header, undefined),
    Kid = maps:get(<<"kid">>, Header, undefined),
    Candidates = maps:get(keys, Opts, []),
    case find_key(Alg, Kid, Candidates) of
        {error, not_found} ->
            {error, no_matching_key};
        {error, unsupported} ->
            {error, {unsupported_alg, Alg}};
        {ok, Jwk} ->
            SignatureValid = verify_signature(Alg, SigningInput, Sig, Jwk),
            case SignatureValid of
                false -> {error, bad_signature};
                true -> validate_claims(Claims, Opts)
            end
    end.

%%====================================================================
%% Key selection
%%====================================================================

%% Pick the JWK to verify a token with.
%%
%% Only RS256 and ES256 are supported; any other `Alg` yields
%% `{error, unsupported}`. Among `Keys`, the first one that is usable
%% for the algorithm and matches the token's `kid` is returned as
%% `{ok, Jwk}`, or `{error, not_found}` when there is none.
%%
%% A key that declares a `kty` is only considered when that key type
%% fits the algorithm. Without this, a kid-less token checked against a
%% mixed RSA/EC key set would be paired with whichever key happens to
%% come first and fail signature verification even though a suitable
%% key is present.
find_key(Alg, Kid, Keys) ->
    case alg_kty(Alg) of
        {ok, Kty} ->
            Usable = [K || K <- Keys, kty_matches(K, Kty), key_matches(K, Kid)],
            case Usable of
                [Jwk | _] -> {ok, Jwk};
                [] -> {error, not_found}
            end;
        error ->
            {error, unsupported}
    end.

%% JWK key type (`kty`) required by each supported JWS algorithm.
alg_kty(<<"RS256">>) -> {ok, <<"RSA">>};
alg_kty(<<"ES256">>) -> {ok, <<"EC">>};
alg_kty(_) -> error.

%% A declared `kty` must equal the required one. Keys that do not
%% declare a `kty` are still accepted, as they were before this filter
%% existed; key construction rejects them later if they lack the
%% needed members.
kty_matches(#{<<"kty">> := Declared}, Kty) -> Declared =:= Kty;
kty_matches(_Jwk, _Kty) -> true.

%% When the token carries a kid, require it to match; otherwise fall
%% back to any key (single-key deployments routinely omit kid).
key_matches(_Jwk, undefined) -> true;
key_matches(Jwk, Kid) -> maps:get(<<"kid">>, Jwk, undefined) =:= Kid.

%%====================================================================
%% Signature verification
%%====================================================================

%% Verify `Sig` over `SigningInput` with the given JWK, using SHA-256.
%%
%% RS256 passes the signature bytes straight to `public_key:verify/4`;
%% ES256 first converts the raw signature to DER. Returns `false` when
%% the key cannot be built from the JWK or the ES256 signature has the
%% wrong shape.
verify_signature(<<"RS256">>, SigningInput, Sig, Jwk) ->
    verify_with_key(rsa_public_key(Jwk), {ok, Sig}, SigningInput);
verify_signature(<<"ES256">>, SigningInput, Sig, Jwk) ->
    verify_with_key(ec_public_key(Jwk), raw_to_der_sig(Sig), SigningInput).

%% Run the actual check once both the key and the encoded signature
%% are available; a failure to obtain either one means "not verified".
verify_with_key({ok, PubKey}, {ok, EncodedSig}, SigningInput) ->
    public_key:verify(SigningInput, sha256, EncodedSig, PubKey);
verify_with_key({ok, _PubKey}, error, _SigningInput) ->
    false;
verify_with_key(error, _EncodedSig, _SigningInput) ->
    false.

%% Build an `#'RSAPublicKey'{}` from the `n` (modulus) and `e`
%% (exponent) members of a JWK. Both members are base64url-decoded and
%% read as big-endian unsigned integers. Returns `error` when either
%% member is missing or does not decode.
-spec rsa_public_key(jwk()) -> {ok, #'RSAPublicKey'{}} | error.
rsa_public_key(Jwk) ->
    case Jwk of
        #{<<"n">> := ModulusB64, <<"e">> := ExponentB64} ->
            rsa_key_from(b64url(ModulusB64), b64url(ExponentB64));
        _ ->
            error
    end.

rsa_key_from({ok, ModulusBytes}, {ok, ExponentBytes}) ->
    Key = #'RSAPublicKey'{
        modulus = binary:decode_unsigned(ModulusBytes),
        publicExponent = binary:decode_unsigned(ExponentBytes)
    },
    {ok, Key};
rsa_key_from(_, _) ->
    error.

%% Build the public-key term for an EC P-256 JWK.
%%
%% Returns `{ok, {Point, Params}}`, where `Point` is the uncompressed
%% point (`16#04 || X || Y`) and `Params` names secp256r1, or `error`
%% when:
%%   * the JWK declares a `crv` other than `P-256`,
%%   * `x` or `y` is missing or not valid base64url, or
%%   * either coordinate is not exactly 32 bytes long.
-spec ec_public_key(jwk()) -> {ok, term()} | error.
ec_public_key(#{<<"crv">> := Crv}) when Crv =/= <<"P-256">> ->
    %% ES256 is ECDSA over P-256 only. Other curves also have 32-byte
    %% coordinates (e.g. secp256k1), so the size check below cannot
    %% stand in for this one. Keys that omit `crv` are still accepted.
    error;
ec_public_key(#{<<"x">> := X64, <<"y">> := Y64}) ->
    case {b64url(X64), b64url(Y64)} of
        {{ok, X}, {ok, Y}} when byte_size(X) =:= 32, byte_size(Y) =:= 32 ->
            Point = #'ECPoint'{point = <<4, X/binary, Y/binary>>},
            Params = {namedCurve, ?'secp256r1'},
            {ok, {Point, Params}};
        _ ->
            error
    end;
ec_public_key(_) ->
    error.

%% JWS ECDSA signatures are the raw r||s concatenation (RFC 7518
%% §3.4). OTP's public_key:verify wants a DER-encoded
%% ECDSA-Sig-Value.
-spec raw_to_der_sig(binary()) -> {ok, binary()} | error.
raw_to_der_sig(<<R:256/big-unsigned-integer, S:256/big-unsigned-integer>>) ->
    %% Exactly 64 bytes: two 32-byte big-endian unsigned integers.
    SigValue = #'ECDSA-Sig-Value'{r = R, s = S},
    {ok, public_key:der_encode('ECDSA-Sig-Value', SigValue)};
raw_to_der_sig(_WrongShape) ->
    error.

%%====================================================================
%% Claim validation
%%====================================================================

%% Validate the registered claims of an already signature-checked token.
%%
%% The reference time comes from the `now` option (default: current OS
%% time in seconds) and the tolerance from `leeway` (default 0).
%% `exp`, `nbf`, `iss` and `aud` are checked in that order; the first
%% failure is returned, otherwise `{ok, Claims}`.
validate_claims(Claims, Opts) ->
    Now = maps:get(now, Opts, os:system_time(second)),
    Leeway = maps:get(leeway, Opts, 0),
    ExpiryCheck = fun() -> check_exp(Claims, Now, Leeway) end,
    NotBeforeCheck = fun() -> check_nbf(Claims, Now, Leeway) end,
    IssuerCheck = fun() ->
        check_iss(Claims, maps:get(issuer, Opts, undefined))
    end,
    AudienceCheck = fun() ->
        check_aud(Claims, maps:get(audience, Opts, undefined))
    end,
    run_checks([ExpiryCheck, NotBeforeCheck, IssuerCheck, AudienceCheck], Claims).

%% Run zero-arity checks in order, stopping at the first `{error, _}`.
%% Returns `{ok, Claims}` once every check has answered `ok`.
run_checks([Next | Pending], Claims) ->
    Outcome = Next(),
    case Outcome of
        {error, _} = Failure -> Failure;
        ok -> run_checks(Pending, Claims)
    end;
run_checks([], Claims) ->
    {ok, Claims}.

%% Validate the `exp` (expiration time) claim.
%%
%% Returns `ok` when the claim is absent or when `Now` is strictly
%% before `exp + Leeway`; otherwise `{error, expired}`.
%%
%% Two details follow RFC 7519:
%%   * The current time must be *before* the expiration time, so a
%%     token is already expired at the instant `Now =:= exp` (with no
%%     leeway).
%%   * `exp` is a NumericDate, which may carry a fractional part, so
%%     any number is accepted, not just integers. A non-numeric `exp`
%%     is still reported as `expired`.
check_exp(Claims, Now, Leeway) ->
    case maps:get(<<"exp">>, Claims, undefined) of
        undefined -> ok;
        Exp when is_number(Exp), Now < Exp + Leeway -> ok;
        _ -> {error, expired}
    end.

%% Validate the `nbf` (not before) claim.
%%
%% Returns `ok` when the claim is absent or when `Now + Leeway` has
%% reached `nbf` (the boundary itself is valid); otherwise
%% `{error, not_yet_valid}`.
%%
%% `nbf` is a NumericDate (RFC 7519), which may carry a fractional
%% part, so any number is accepted, not just integers. A non-numeric
%% `nbf` is still reported as `not_yet_valid`.
check_nbf(Claims, Now, Leeway) ->
    case maps:get(<<"nbf">>, Claims, undefined) of
        undefined -> ok;
        Nbf when is_number(Nbf), Now + Leeway >= Nbf -> ok;
        _ -> {error, not_yet_valid}
    end.

%% Validate the `iss` claim against the expected issuer.
%%
%% No expected issuer (`undefined`) disables the check. Otherwise the
%% claim must be present and exactly equal to `Expected`; a mismatch
%% reports the expected value, not the one found in the token.
check_iss(_Claims, undefined) ->
    ok;
check_iss(Claims, Expected) ->
    Issuer = maps:get(<<"iss">>, Claims, undefined),
    case Issuer =:= Expected of
        true -> ok;
        false -> {error, {issuer_mismatch, Expected}}
    end.

%% Validate the `aud` claim against the expected audience(s).
%%
%% No expected audience (`undefined`) disables the check. Otherwise the
%% token's `aud` (or `undefined` when absent) is handed to
%% `audience_ok/2`, and a negative answer becomes
%% `{error, audience_mismatch}`.
check_aud(_Claims, undefined) ->
    ok;
check_aud(Claims, Expected) ->
    Presented = maps:get(<<"aud">>, Claims, undefined),
    case audience_ok(Presented, Expected) of
        false -> {error, audience_mismatch};
        true -> ok
    end.

%% Decide whether the token's audience satisfies the expectation.
%%
%% Either side may be a single binary or a list. Both are normalised
%% to lists, and the answer is `true` when at least one accepted value
%% appears among the presented ones. Anything else (a missing claim or
%% values of another type) is a mismatch.
audience_ok(Aud, Expected) when is_binary(Expected) ->
    audience_ok(Aud, [Expected]);
audience_ok(Aud, Accepted) when is_binary(Aud) ->
    audience_ok([Aud], Accepted);
audience_ok(Presented, Accepted) when is_list(Presented), is_list(Accepted) ->
    lists:any(fun(Candidate) -> lists:member(Candidate, Presented) end, Accepted);
audience_ok(_, _) ->
    false.

%%====================================================================
%% base64url + JSON
%%====================================================================

%% Decode a base64url segment that must contain a JSON object.
%%
%% Returns `{ok, Map}` only when the segment base64url-decodes, parses
%% as JSON, and the top-level value is an object. Any other outcome is
%% `error`.
%%
%% The object check matters because the segment comes from an untrusted
%% token: `json:decode/1` accepts any JSON value, and a header or
%% payload such as `1` or `[]` would otherwise reach the callers' map
%% lookups and crash them with `badmap`.
-spec decode_json(binary()) -> {ok, map()} | error.
decode_json(B64) ->
    case b64url(B64) of
        {ok, Bin} ->
            %% Only the parse itself is protected; the result is
            %% inspected in the unprotected `of` section.
            try json:decode(Bin) of
                Object when is_map(Object) -> {ok, Object};
                _NotAnObject -> error
            catch
                _:_ -> error
            end;
        error ->
            error
    end.

%% Decode URL-safe base64 without requiring `=` padding.
%% Any exception raised by the decoder is turned into `error`.
-spec b64url(binary()) -> {ok, binary()} | error.
b64url(B64) ->
    DecodeOpts = #{mode => urlsafe, padding => false},
    try base64:decode(B64, DecodeOpts) of
        Decoded -> {ok, Decoded}
    catch
        _:_ -> error
    end.