src/oidcc_jwt_util.erl

-module(oidcc_jwt_util).

-feature(maybe_expr, enable).

-include("internal/doc.hrl").
?MODULEDOC("JWT Utilities").

-include_lib("jose/include/jose_jwe.hrl").
-include_lib("jose/include/jose_jwk.hrl").
-include_lib("jose/include/jose_jws.hrl").
-include_lib("jose/include/jose_jwt.hrl").

-export([client_secret_oct_keys/2]).
-export([merge_client_secret_oct_keys/3]).
-export([decrypt_and_verify/5]).
-export([encrypt/4]).
-export([evaluate_for_all_keys/2]).
-export([merge_jwks/2]).
-export([peek_payload/1]).
-export([refresh_jwks_fun/1]).
-export([sign/3]).
-export([sign/4]).
-export([sign_dpop/3]).
-export([thumbprint/1]).
-export([verify_claims/2]).
-export([verify_not_none_alg/1]).
-export([verify_signature/3]).

-export_type([claims/0]).
-export_type([error/0]).
-export_type([refresh_jwks_for_unknown_kid_fun/0]).

?DOC(#{since => <<"3.0.0">>}).
-type refresh_jwks_for_unknown_kid_fun() ::
    fun((Jwks :: jose_jwk:key(), Kid :: binary()) -> {ok, jose_jwk:key()} | {error, term()}).

?DOC(#{since => <<"3.0.0">>}).
-type error() ::
    no_matching_key
    | invalid_jwt_token
    | {no_matching_key_with_kid, Kid :: binary()}
    | none_alg_used
    | {none_alg_used, Jwt :: #jose_jwt{}, Jws :: #jose_jws{}}
    | not_encrypted.

?DOC(#{since => <<"3.0.0">>}).
-type claims() :: #{binary() => term()}.

%% Function to decide if the jwks should be reloaded to find a matching key for `Kid`
%%
%% A default function is provided in `oidcc:retrieve_token/5`
%% and `oidcc:retrieve_userinfo/5`.
%%
%% The default implementation does not implement any rate limiting.
%%
%% Checking of jwk sets is a bit wonky because of partial support
%% in jose. see: https://github.com/potatosalad/erlang-jose/issues/28
?DOC(false).
-spec verify_signature(Token, AllowAlgorithms, Jwks) ->
    {ok, {Jwt, Jws}}
    | {error, error()}
when
    Token :: binary(),
    AllowAlgorithms :: [binary()],
    Jwks :: jose_jwk:key(),
    Jwt :: #jose_jwt{},
    Jws :: #jose_jws{}.
%% Verifies the signature of the compact `Token` with `Jwks`, allowing only the
%% algorithms listed in `AllowAlgorithms`.
%%
%% Key set: the keys are tried in order and the first `{ok, _}` result wins.
%% When no key verifies the token, a `{no_matching_key_with_kid, Kid}` error
%% reported for any key is kept; otherwise the error of the last key tried is
%% returned (`no_matching_key` for an empty set).
verify_signature(Token, AllowAlgorithms, #jose_jwk{keys = {jose_jwk_set, Keys}}) ->
    lists:foldl(
        fun
            %% A previous key already verified the token: keep that result.
            (_Key, {ok, _Res} = Acc) ->
                Acc;
            (Key, Acc) ->
                case {verify_signature(Token, AllowAlgorithms, Key), Acc} of
                    {{ok, Res}, _Acc} ->
                        {ok, Res};
                    %% Keep the more specific "kid" error over later failures.
                    {_Res, {error, {no_matching_key_with_kid, Kid}}} ->
                        {error, {no_matching_key_with_kid, Kid}};
                    {Res, _Acc} ->
                        Res
                end
        end,
        {error, no_matching_key},
        Keys
    );
%% Single key: reads the `kid` of the token's protected header (`none` when
%% absent). A key whose own `kid` differs from the token's `kid` is rejected
%% without attempting verification; keys without a `kid`, and tokens without
%% one, are always passed to `jose_jwt:verify_strict/3`.
verify_signature(Token, AllowAlgorithms, #jose_jwk{} = Jwks) ->
    try
        Kid =
            case jose_jwt:peek_protected(Token) of
                #jose_jws{fields = #{<<"kid">> := IntKid}} ->
                    IntKid;
                #jose_jws{} ->
                    none
            end,
        case Jwks of
            #jose_jwk{fields = #{<<"kid">> := CmpKid}} when CmpKid =/= Kid, Kid =/= none ->
                {error, {no_matching_key_with_kid, Kid}};
            #jose_jwk{} ->
                case jose_jwt:verify_strict(Jwks, AllowAlgorithms, Token) of
                    {true, Jwt, Jws} ->
                        {ok, {Jwt, Jws}};
                    %% Unverified token using the `none` algorithm: reported
                    %% separately, together with the parsed JWT and JWS.
                    {false, Jwt, #jose_jws{alg = {jose_jws_alg_none, none}} = Jws} ->
                        {error, {none_alg_used, Jwt, Jws}};
                    {false, _Jwt, _Jws} ->
                        {error, no_matching_key}
                end
        end
    catch
        %% Both exceptions below are mapped to `invalid_jwt_token`; any other
        %% exception raised inside the `try` body propagates to the caller.
        error:{badarg, [_Token]} ->
            {error, invalid_jwt_token};
        %% Some Keys crash if a non matching alg is provided
        error:function_clause ->
            {error, invalid_jwt_token}
    end.

?DOC(false).
-spec verify_claims(Claims, ExpClaims) -> ok | {error, {missing_claim, ExpClaim, Claims}} when
    Claims :: claims(),
    ExpClaim :: {binary(), term() | {regex, binary()}},
    ExpClaims :: [ExpClaim].
%% Checks that every expected claim in `ExpClaims` is satisfied by `Claims`.
%%
%% An expectation is either `{Name, ExactValue}` (the claim must be exactly
%% that value) or `{Name, {regex, Pattern}}` (the claim must be a binary that
%% matches `Pattern`). Returns `ok` when all expectations hold, otherwise
%% `{error, {missing_claim, Expectation, Claims}}` for the first expectation,
%% in list order, that does not hold.
verify_claims(Claims, ExpClaims) ->
    IsSatisfied =
        fun
            ({Name, {regex, Pattern}}) ->
                case maps:get(Name, Claims, none) of
                    Actual when is_binary(Actual) ->
                        re:run(Actual, Pattern, [{capture, none}]) =:= match;
                    _AbsentOrNotBinary ->
                        false
                end;
            ({Name, Expected}) ->
                maps:get(Name, Claims, none) =:= Expected
        end,
    %% Every expectation is evaluated; only the unsatisfied ones are inspected.
    case lists:partition(IsSatisfied, ExpClaims) of
        {_Satisfied, []} ->
            ok;
        {_Satisfied, [FirstUnsatisfied | _Others]} ->
            {error, {missing_claim, FirstUnsatisfied, Claims}}
    end.

?DOC(false).
-spec client_secret_oct_keys(AllowedAlgorithms, ClientSecret) -> jose_jwk:key() | none when
    AllowedAlgorithms :: [binary()] | undefined,
    ClientSecret :: binary() | unauthenticated.
%% Builds an octet key from the client secret when at least one HMAC
%% algorithm (HS256 / HS384 / HS512) is allowed.
%%
%% Returns `none` when no algorithm list is configured, when the client is
%% unauthenticated, or when none of the HMAC algorithms is in the list.
client_secret_oct_keys(undefined, _ClientSecret) ->
    none;
client_secret_oct_keys(_AllowedAlgorithms, unauthenticated) ->
    none;
client_secret_oct_keys(AllowedAlgorithms, ClientSecret) ->
    HmacAlgorithms = [<<"HS256">>, <<"HS384">>, <<"HS512">>],
    case [Alg || Alg <- HmacAlgorithms, lists:member(Alg, AllowedAlgorithms)] of
        [] ->
            none;
        [_ | _] ->
            jose_jwk:from_oct(ClientSecret)
    end.

?DOC(false).
-spec merge_client_secret_oct_keys(Jwks :: jose_jwk:key(), AllowedAlgorithms, ClientSecret) ->
    jose_jwk:key()
when
    AllowedAlgorithms :: [binary()] | undefined,
    ClientSecret :: binary() | unauthenticated.
%% Appends the octet key derived from the client secret to `Jwks`.
%% `Jwks` is returned unchanged when `client_secret_oct_keys/2` yields `none`.
merge_client_secret_oct_keys(Jwks, AllowedAlgorithms, ClientSecret) ->
    case client_secret_oct_keys(AllowedAlgorithms, ClientSecret) of
        SecretKeys when SecretKeys =/= none ->
            merge_jwks(Jwks, SecretKeys);
        none ->
            Jwks
    end.

?DOC(false).
-spec refresh_jwks_fun(ProviderConfigurationWorkerName) ->
    refresh_jwks_for_unknown_kid_fun()
when
    ProviderConfigurationWorkerName :: gen_server:server_ref().
%% Returns a two-argument fun that asks the given provider configuration
%% worker to refresh its jwks for an unknown `Kid` and then returns
%% `{ok, Jwks}` with whatever the worker reports afterwards. The jwks passed
%% to the fun are ignored.
refresh_jwks_fun(ProviderConfigurationWorkerName) ->
    Worker = ProviderConfigurationWorkerName,
    fun(_CurrentJwks, UnknownKid) ->
        oidcc_provider_configuration_worker:refresh_jwks_for_unknown_kid(Worker, UnknownKid),
        RefreshedJwks = oidcc_provider_configuration_worker:get_jwks(Worker),
        {ok, RefreshedJwks}
    end.

?DOC(false).
-spec merge_jwks(Left :: jose_jwk:key(), Right :: jose_jwk:key()) -> jose_jwk:key().
%% Merges two JWKs into a single key set.
%%
%% Either argument may be a single key or a key set; a single key is first
%% wrapped into a one-element set. The resulting set contains the left keys
%% followed by the right keys. The set-level `fields` maps are merged with
%% `maps:merge/2`, so values from the right set win on conflicting keys.
%%
%% Both arguments must be `#jose_jwk{}` records. The final clause matches on
%% the record explicitly: a catch-all clause here would re-wrap `Right`
%% forever when `Left` is not a `#jose_jwk{}`, instead of failing with
%% `function_clause`.
merge_jwks(
    #jose_jwk{keys = {jose_jwk_set, LeftKeys}, fields = LeftFields},
    #jose_jwk{keys = {jose_jwk_set, RightKeys}, fields = RightFields}
) ->
    #jose_jwk{
        keys = {jose_jwk_set, LeftKeys ++ RightKeys}, fields = maps:merge(LeftFields, RightFields)
    };
merge_jwks(#jose_jwk{} = Left, #jose_jwk{keys = {jose_jwk_set, _RightKeys}} = Right) ->
    %% Left is a single key: lift it into a set, then merge set with set.
    merge_jwks(#jose_jwk{keys = {jose_jwk_set, [Left]}}, Right);
merge_jwks(#jose_jwk{} = Left, #jose_jwk{} = Right) ->
    %% Right is a single key: lift it into a set and retry.
    merge_jwks(Left, #jose_jwk{keys = {jose_jwk_set, [Right]}}).

?DOC(false).
-spec sign(Jwt :: #jose_jwt{}, Jwk :: jose_jwk:key(), SupportedAlgorithms :: [binary()]) ->
    {ok, binary()} | {error, no_supported_alg_or_key}.
%% Signs `Jwt` without any additional JWS header fields; see `sign/4`.
sign(Jwt, Jwk, SupportedAlgorithms) ->
    NoExtraJwsFields = #{},
    sign(Jwt, Jwk, SupportedAlgorithms, NoExtraJwsFields).

?DOC(false).
-spec sign(
    Jwt :: #jose_jwt{}, Jwk :: jose_jwk:key(), SupportedAlgorithms :: [binary()], JwsFields :: map()
) ->
    {ok, binary()} | {error, no_supported_alg_or_key}.
%% Signs `Jwt` with the first algorithm / key combination that works.
%%
%% The algorithms are tried in list order. For each algorithm, every key in
%% `Jwk` is offered to the signing callback through `evaluate_for_all_keys/2`
%% and the first token produced is returned as `{ok, Token}`. When the current
%% algorithm yields no token the remaining algorithms are tried; once the list
%% is exhausted `{error, no_supported_alg_or_key}` is returned.
%%
%% `JwsFields0` holds extra JWS header fields; `alg` is set to the algorithm
%% currently being tried.
sign(_Jwt, _Jwk, [], _JwsFields) ->
    {error, no_supported_alg_or_key};
sign(Jwt, Jwk, [Algorithm | RestAlgorithms], JwsFields0) ->
    maybe
        #jose_jws{fields = JwsFields} =
            Jws0 ?= jose_jws:from_map(JwsFields0#{<<"alg">> => Algorithm}),
        SigningCallback = fun
            (#jose_jwk{fields = Fields} = Key) when Algorithm =/= <<"none">> ->
                %% add the kid field to the JWS signature if present
                %% (fields already present in JwsFields take precedence over
                %% the key's kid, since they are the second map of the merge)
                KidField = maps:with([<<"kid">>], Fields),
                Jws = Jws0#jose_jws{fields = maps:merge(KidField, JwsFields)},
                try
                    %% ensure key is either for signatures, or not specified
                    %% (a key with another `use` makes the `ok =` match fail;
                    %% the resulting badmatch is turned into `error` below)
                    ok =
                        case Fields of
                            #{<<"use">> := <<"sig">>} -> ok;
                            #{<<"use">> := _} -> error;
                            #{} -> ok
                        end,
                    {_Jws, Token} = jose_jws:compact(jose_jwt:sign(Key, Jws, Jwt)),
                    {ok, Token}
                catch
                    %% Each of these means "this key cannot sign with this
                    %% algorithm"; the next key is tried.
                    error:{badmatch, _} -> error;
                    error:not_supported -> error;
                    error:{not_supported, _Alg} -> error;
                    %% Some Keys crash if a public key is provided
                    error:function_clause -> error
                end;
            (#jose_jwk{} = Key) when Algorithm == <<"none">> ->
                %% `none` algorithm: no kid is added and the key's `use` is
                %% not checked.
                try
                    {_Jws, Token} = jose_jws:compact(jose_jwt:sign(Key, Jws0, Jwt)),
                    {ok, Token}
                catch
                    error:not_supported -> error;
                    error:{not_supported, _Alg} -> error
                end;
            (_Key) ->
                error
        end,
        {ok, Token} ?= evaluate_for_all_keys(Jwk, SigningCallback),
        {ok, Token}
    else
        %% Any value that failed a `?=` match above: move on to the next
        %% algorithm.
        _ -> sign(Jwt, Jwk, RestAlgorithms, JwsFields0)
    end.

?DOC(false).
-spec decrypt_and_verify(
    Jwt :: binary(),
    Jwks :: jose_jwk:key(),
    SigningAlgs :: [binary()] | undefined,
    EncryptionAlgs :: [binary()] | undefined,
    EncryptionEncs :: [binary()] | undefined
) ->
    {ok, {#jose_jwt{}, #jose_jwe{} | #jose_jws{}}} | {error, error()}.
%% Verifies a token that is either an encrypted JWE (possibly wrapping a
%% signed JWT) or a plain signed JWT.
%%
%% The JWE header is peeked before calling `decrypt/4' so that unencrypted
%% tokens are still handled when EncryptionAlgs / EncryptionEncs are
%% undefined (in which case `decrypt/4' returns
%% {error, no_supported_alg_or_key}).
decrypt_and_verify(Jwt, Jwks, SigningAlgs, EncryptionAlgs, EncryptionEncs) ->
    case jwe_peek_protected(Jwt) of
        {error, not_encrypted} ->
            %% signed JWT
            verify_signature(Jwt, SigningAlgs, Jwks);
        {error, _PeekReason} = PeekError ->
            PeekError;
        {ok, Jwe} ->
            case decrypt(Jwt, Jwks, EncryptionAlgs, EncryptionEncs) of
                {error, _DecryptReason} = DecryptError ->
                    DecryptError;
                {ok, Plaintext} ->
                    verify_decrypted_token(Plaintext, SigningAlgs, Jwe, Jwks)
            end
    end.

-spec jwe_peek_protected(Jwt :: binary()) ->
    {ok, #jose_jwe{}} | {error, not_encrypted | no_matching_key}.
%% Reads the protected header of a compact JWE without decrypting it.
%%
%% The token is split on "."; a five-segment token has its first segment
%% base64url-decoded and parsed as a JWE header, giving `{ok, Jwe}`.
%% A three-segment token yields `{error, not_encrypted}`; every other
%% non-matching value (including a parse result that is not a `#jose_jwe{}`)
%% yields `{error, no_matching_key}`.
%%
%% NOTE(review): exceptions raised by the decode / parse calls are not caught
%% in this function - confirm that callers are fine with that for malformed
%% tokens.
jwe_peek_protected(Jwt) ->
    %% jose_jwt:peek_protected(Jwt) doesn't work with encrypted tokens
    maybe
        [ProtectedEncoded, _, _, _, _] ?= binary:split(Jwt, <<".">>, [global]),
        Protected = jose_jwa_base64url:decode(ProtectedEncoded),
        #jose_jwe{} = Jwe ?= jose_jwe:from(Protected),
        {ok, Jwe}
    else
        %% The else clauses apply to both `?=` matches above.
        [_, _, _] ->
            {error, not_encrypted};
        _ ->
            {error, no_matching_key}
    end.

-spec decrypt(
    Jwt :: binary(),
    Jwk :: jose_jwk:key(),
    SupportedAlgorithms :: [binary()] | undefined,
    SupportedEncValues :: [binary()] | undefined
) ->
    {ok, binary()} | {error, error()}.
%% Decrypts the compact JWE `Jwt` with `Jwk`, accepting only the listed key
%% management algorithms and content encryption values.
%%
%% Returns `{ok, Plaintext}` or `{error, Reason}`. When either list is
%% `undefined` nothing is attempted and `{error, no_supported_alg_or_key}` is
%% returned.
decrypt(_Jwt, _Jwk, undefined, _SupportedEncValues) ->
    {error, no_supported_alg_or_key};
decrypt(_Jwt, _Jwk, _SupportedAlgorithms, undefined) ->
    {error, no_supported_alg_or_key};
%% Key set: the keys are tried in order and the first `{ok, _}` wins. When no
%% key decrypts the token, a `{no_matching_key_with_kid, Kid}` error reported
%% for any key is kept; otherwise the error of the last key tried is returned
%% (`no_matching_key` for an empty set).
decrypt(Jwt, #jose_jwk{keys = {jose_jwk_set, Keys}}, SupportedAlgorithms, SupportedEncValues) ->
    lists:foldl(
        fun
            (_Key, {ok, _Res} = Acc) ->
                Acc;
            (Key, Acc) ->
                case {decrypt(Jwt, Key, SupportedAlgorithms, SupportedEncValues), Acc} of
                    {{ok, Res}, _Acc} ->
                        {ok, Res};
                    {_Res, {error, {no_matching_key_with_kid, Kid}}} ->
                        {error, {no_matching_key_with_kid, Kid}};
                    {Res, _Acc} ->
                        Res
                end
        end,
        {error, no_matching_key},
        Keys
    );
%% Single key. The `maybe` block has no `else`, so a value that fails one of
%% the `?=` matches (an error from `jwe_peek_protected/1` or
%% `verify_in_list/2`) is returned to the caller unchanged.
decrypt(Jwt, #jose_jwk{} = Jwk, SupportedAlgorithms, SupportedEncValues) ->
    maybe
        {ok, Jwe} ?= jwe_peek_protected(Jwt),
        %% Plain `=`: a header without "alg" or "enc" raises badmatch here.
        {_, #{<<"alg">> := JwtAlg, <<"enc">> := JwtEnc}} = jose_jwe:to_map(Jwe),
        ok ?= verify_in_list(JwtAlg, SupportedAlgorithms),
        ok ?= verify_in_list(JwtEnc, SupportedEncValues),
        %% `kid` of the JWE header, or `none` when the header has none.
        Kid =
            case Jwe of
                #jose_jwe{fields = #{<<"kid">> := IntKid}} ->
                    IntKid;
                #jose_jwe{} ->
                    none
            end,
        case Jwk of
            %% Token names a kid and this key has a different one: skip it.
            #jose_jwk{fields = #{<<"kid">> := CmpKid}} when CmpKid =/= Kid, Kid =/= none ->
                {error, {no_matching_key_with_kid, Kid}};
            %% Key is declared for a use other than encryption: skip it.
            #jose_jwk{fields = #{<<"use">> := NotEnc}} when NotEnc =/= <<"enc">> ->
                {error, no_matching_key};
            _ ->
                try
                    {Token, _Jwe} = jose_jwe:block_decrypt(Jwk, Jwt),
                    {ok, Token}
                catch
                    %% Any `error`-class exception while decrypting is treated
                    %% as "this key does not fit".
                    error:_ when Kid =:= none ->
                        {error, no_matching_key};
                    error:_ ->
                        {error, {no_matching_key_with_kid, Kid}}
                end
        end
    end.

%% Returns `ok` when `Value` is a member of `List`, otherwise
%% `{error, no_matching_key}`.
verify_in_list(Value, List) ->
    IsListed = lists:member(Value, List),
    case IsListed of
        false -> {error, no_matching_key};
        true -> ok
    end.

%% Interprets the plaintext `Jwt` obtained by decrypting a JWE.
%%
%% If the plaintext is itself a signed JWT (nested JWT), the verified
%% `{ok, {Jwt, Jws}}` result of `verify_signature/3` is returned. If
%% `verify_signature/3` reports `invalid_jwt_token`, the plaintext is parsed
%% as a plain (encrypted-only) JWT and returned together with the JWE header
%% `Jwe`. Every other verification error is passed through unchanged.
verify_decrypted_token(Jwt, SigningAlgs, Jwe, Jwks) ->
    case verify_signature(Jwt, SigningAlgs, Jwks) of
        {ok, Result} ->
            %% encrypted + signed (nested) JWT
            {ok, Result};
        {error, invalid_jwt_token} ->
            %% encrypted JWT
            try jose_jwt:from_binary(Jwt) of
                DecodedJwt ->
                    {ok, {DecodedJwt, Jwe}}
            catch
                %% A bare `catch _ ->` only matches class `throw`, so
                %% `error`-class exceptions raised while parsing would crash
                %% the caller instead of yielding the documented error tuple.
                %% Both classes are handled explicitly.
                throw:_ ->
                    {error, invalid_jwt_token};
                error:_ ->
                    {error, invalid_jwt_token}
            end;
        {error, Reason} ->
            {error, Reason}
    end.

?DOC(false).
-spec encrypt(
    Jwt :: binary(),
    Jwk :: jose_jwk:key(),
    SupportedAlgorithms :: [binary()] | undefined,
    SupportedEncValues :: [binary()] | undefined
) ->
    {ok, binary()} | {error, no_supported_alg_or_key}.
%% Encrypts `Jwt` with the first algorithm / enc value / key combination that
%% works. When either the algorithm list or the enc value list is `undefined`
%% nothing is attempted and `{error, no_supported_alg_or_key}` is returned.
encrypt(_Jwt, _Jwk, SupportedAlgorithms, SupportedEncValues) when
    SupportedAlgorithms =:= undefined; SupportedEncValues =:= undefined
->
    {error, no_supported_alg_or_key};
encrypt(Jwt, Jwk, SupportedAlgorithms, SupportedEncValues) ->
    %% The last argument tracks the enc values still to be tried for the
    %% current algorithm; it starts out as the full list.
    encrypt(Jwt, Jwk, SupportedAlgorithms, SupportedEncValues, SupportedEncValues).

-spec encrypt(
    Jwt :: binary(),
    Jwk :: jose_jwk:key(),
    SupportedAlgorithms :: [binary()],
    SupportedEncValues :: [binary()],
    AccEncValues :: [binary()]
) ->
    {ok, binary()} | {error, no_supported_alg_or_key}.
%% Walks the algorithm / enc value combinations: for the algorithm at the head
%% of the list each enc value is tried in order against every key in `Jwk`;
%% when the enc values are exhausted the next algorithm is tried with the full
%% enc value list again. Returns the first compact JWE produced, or
%% `{error, no_supported_alg_or_key}` once no algorithm is left.
encrypt(_Jwt, _Jwk, [], _SupportedEncValues, _AccEncValues) ->
    {error, no_supported_alg_or_key};
encrypt(Jwt, Jwk, [_ExhaustedAlgorithm | OtherAlgorithms], SupportedEncValues, []) ->
    encrypt(Jwt, Jwk, OtherAlgorithms, SupportedEncValues, SupportedEncValues);
encrypt(
    Jwt,
    Jwk,
    [Algorithm | _] = SupportedAlgorithms,
    SupportedEncValues,
    [EncValue | PendingEncValues]
) ->
    BaseHeader = #{<<"alg">> => Algorithm, <<"enc">> => EncValue},
    TryKey = fun
        %% Keys declared for a use other than encryption are skipped.
        (#jose_jwk{fields = #{<<"use">> := Use}}) when Use =/= <<"enc">> ->
            error;
        (#jose_jwk{fields = KeyFields} = Key) ->
            try
                %% Carry the key's kid over into the JWE header, if it has one.
                Header =
                    case maps:get(<<"kid">>, KeyFields, undefined) of
                        undefined -> BaseHeader;
                        Kid -> BaseHeader#{<<"kid">> => Kid}
                    end,
                Encrypted = jose_jwk:block_encrypt(Jwt, jose_jwe:from_map(Header), Key),
                {_Jwe, Compact} = jose_jwe:compact(Encrypted),
                {ok, Compact}
            catch
                error:{not_supported, _Alg} -> error;
                error:undef -> error
            end;
        (_NotAKey) ->
            error
    end,
    case evaluate_for_all_keys(Jwk, TryKey) of
        error ->
            encrypt(Jwt, Jwk, SupportedAlgorithms, SupportedEncValues, PendingEncValues);
        {ok, _Token} = Success ->
            Success
    end.

?DOC(false).
-spec thumbprint(Jwk :: jose_jwk:key()) -> {ok, binary()} | error.
%% Returns `{ok, Thumbprint}` of the first key in `Jwk` that is not declared
%% for a use other than signing, or `error` when there is no such key.
thumbprint(Jwk) ->
    ThumbprintUnlessOtherUse = fun
        (#jose_jwk{fields = #{<<"use">> := Use}}) when Use =/= <<"sig">> ->
            error;
        (Candidate) ->
            {ok, jose_jwk:thumbprint(Candidate)}
    end,
    evaluate_for_all_keys(Jwk, ThumbprintUnlessOtherUse).

?DOC(false).
-spec sign_dpop(Jwt :: #jose_jwt{}, Jwk :: jose_jwk:key(), SigningAlgSupported :: [binary()]) ->
    {ok, binary()} | {error, no_supported_alg_or_key}.
%% Signs `Jwt` as a DPoP proof: the JWS header gets `typ` set to `dpop+jwt`
%% and `jwk` set to the public part of the signing key.
%%
%% Every key in `Jwk` that is not declared for a use other than signing is
%% tried in order with the algorithms in `SigningAlgSupported`. Returns
%% `{ok, Token}` for the first key that can sign, otherwise
%% `{error, no_supported_alg_or_key}`.
sign_dpop(Jwt, Jwk, SigningAlgSupported) ->
    SignWithKey = fun
        (#jose_jwk{fields = #{<<"use">> := Use}}) when Use =/= <<"sig">> ->
            error;
        (Key) ->
            {_, PublicJwk} = jose_jwk:to_public_map(Key),
            DpopHeader = #{<<"typ">> => <<"dpop+jwt">>, <<"jwk">> => PublicJwk},
            %% `evaluate_for_all_keys/2` only understands `{ok, _} | error`
            %% from its callback; handing it the `{error, _}` tuple of
            %% `sign/4` would crash the iteration over a key set as soon as
            %% a non-last key cannot sign.
            case sign(Jwt, Key, SigningAlgSupported, DpopHeader) of
                {ok, Proof} -> {ok, Proof};
                {error, _Reason} -> error
            end
    end,
    %% Translate the bare `error` back into the error tuple promised by the
    %% spec (also covers the case where every key was skipped).
    case evaluate_for_all_keys(Jwk, SignWithKey) of
        {ok, Token} -> {ok, Token};
        error -> {error, no_supported_alg_or_key}
    end.

?DOC(false).
-spec evaluate_for_all_keys(Jwk :: jose_jwk:key(), fun((jose_jwk:key()) -> {ok, Result} | error)) ->
    {ok, Result} | error
when
    Result :: term().
%% Applies `Callback` to the keys contained in `Jwk` until one call returns
%% `{ok, Result}`, which is then returned. Key sets are traversed in order
%% (nested sets included); a single key is passed to `Callback` directly.
%% Returns `error` when no key produced a result.
evaluate_for_all_keys(#jose_jwk{keys = {jose_jwk_set, KeyList}}, Callback) ->
    TryNext = fun
        (Candidate, error) ->
            evaluate_for_all_keys(Candidate, Callback);
        (_Candidate, {ok, Found}) ->
            %% An earlier key already succeeded: keep its result.
            {ok, Found}
    end,
    lists:foldl(TryNext, error, KeyList);
evaluate_for_all_keys(#jose_jwk{} = SingleKey, Callback) ->
    Callback(SingleKey).

?DOC(false).
-spec verify_not_none_alg(#jose_jws{}) -> ok | {error, none_alg_used}.
%% Returns `{error, none_alg_used}` when the JWS header fields name the
%% `none` algorithm, `ok` for every other JWS.
verify_not_none_alg(#jose_jws{fields = HeaderFields}) ->
    case HeaderFields of
        #{<<"alg">> := <<"none">>} -> {error, none_alg_used};
        _ -> ok
    end.

?DOC(false).
-spec peek_payload(binary()) -> {ok, #jose_jwt{}} | {error, invalid_jwt_token}.
%% Reads the payload of `Jwt` without verifying it. Returns `{ok, Payload}`,
%% or `{error, invalid_jwt_token}` when `jose_jwt:peek_payload/1` raises
%% `{badarg, [_]}` or `function_clause`.
peek_payload(Jwt) ->
    try jose_jwt:peek_payload(Jwt) of
        Payload ->
            {ok, Payload}
    catch
        error:function_clause ->
            {error, invalid_jwt_token};
        error:{badarg, [_Token]} ->
            {error, invalid_jwt_token}
    end.