src/oidcc_jwt_util.erl

%%%-------------------------------------------------------------------
%% @doc JWT Utilities
%% @end
%%%-------------------------------------------------------------------
-module(oidcc_jwt_util).

-include_lib("jose/include/jose_jwk.hrl").
-include_lib("jose/include/jose_jws.hrl").
-include_lib("jose/include/jose_jwt.hrl").

-export([client_secret_oct_keys/2]).
-export([encrypt/4]).
-export([evaluate_for_all_keys/2]).
-export([merge_jwks/2]).
-export([refresh_jwks_fun/1]).
-export([sign/3]).
-export([verify_claims/2]).
-export([verify_signature/3]).

-export_type([claims/0]).
-export_type([error/0]).
-export_type([refresh_jwks_for_unknown_kid_fun/0]).

%% Callback invoked with the current JWKS and a `Kid' for which no matching
%% key was found. Returns `{ok, Jwks}' with a JWKS, or `{error, Reason}'.
-type refresh_jwks_for_unknown_kid_fun() ::
    fun((Jwks :: jose_jwk:key(), Kid :: binary()) -> {ok, jose_jwk:key()} | {error, term()}).

%% Error reasons produced by signature verification in this module.
-type error() :: no_matching_key | invalid_jwt_token | {no_matching_key_with_kid, Kid :: binary()}.

%% JWT claims as a map from claim name to claim value.
-type claims() :: #{binary() => term()}.

%% Function (see `refresh_jwks_for_unknown_kid_fun()') to decide if the jwks should be
%% reloaded to find a matching key for `Kid'
%%
%% A default function is provided in {@link oidcc:retrieve_token/5}
%% and {@link oidcc:retrieve_userinfo/5}.
%%
%% The default implementation does not implement any rate limiting.

%% @private
%% Verify the signature of `Token' with `Jwks', accepting only the
%% algorithms listed in `AllowAlgorithms'.
%%
%% A key set is handled by trying its keys in order until one verifies the
%% token. Checking of jwk sets is a bit wonky because of partial support
%% in jose. see: https://github.com/potatosalad/erlang-jose/issues/28
%%
%% A single key whose `kid' differs from the token header's `kid' yields
%% `{error, {no_matching_key_with_kid, Kid}}' without attempting
%% verification. While scanning a set, once that error has been recorded it
%% is kept in preference to any later failure.
-spec verify_signature(Token, AllowAlgorithms, Jwks) ->
    {ok, {Jwt, Jws}}
    | {error, error()}
when
    Token :: binary(),
    AllowAlgorithms :: [binary()],
    Jwks :: jose_jwk:key(),
    Jwt :: #jose_jwt{},
    Jws :: #jose_jws{}.
verify_signature(Token, AllowAlgorithms, #jose_jwk{keys = {jose_jwk_set, Keys}}) ->
    verify_signature_with_any(Token, AllowAlgorithms, Keys, {error, no_matching_key});
verify_signature(Token, AllowAlgorithms, #jose_jwk{} = Key) ->
    try
        TokenKid =
            case jose_jwt:peek_protected(Token) of
                #jose_jws{fields = #{<<"kid">> := HeaderKid}} -> HeaderKid;
                #jose_jws{} -> none
            end,
        case Key of
            #jose_jwk{fields = #{<<"kid">> := KeyKid}} when
                TokenKid =/= none, KeyKid =/= TokenKid
            ->
                {error, {no_matching_key_with_kid, TokenKid}};
            #jose_jwk{} ->
                case jose_jwt:verify_strict(Key, AllowAlgorithms, Token) of
                    {true, Jwt, Jws} -> {ok, {Jwt, Jws}};
                    {false, _Jwt, _Jws} -> {error, no_matching_key}
                end
        end
    catch
        %% A `{badarg, [_]}' error (presumably jose rejecting a malformed
        %% token) is reported as an invalid token.
        error:{badarg, [_Token]} ->
            {error, invalid_jwt_token};
        %% Some Keys crash if a non matching alg is provided
        error:function_clause ->
            {error, invalid_jwt_token}
    end.

%% Try the keys of a set in order and return the first `{ok, _}'
%% verification. `Outcome' is what is returned when no key verifies: it
%% starts as `{error, no_matching_key}' and is replaced by each failure,
%% except that a recorded `no_matching_key_with_kid' error is never replaced.
verify_signature_with_any(_Token, _AllowAlgorithms, [], Outcome) ->
    Outcome;
verify_signature_with_any(Token, AllowAlgorithms, [Key | Rest], Outcome) ->
    case verify_signature(Token, AllowAlgorithms, Key) of
        {ok, _Verified} = Success ->
            Success;
        Failure ->
            NextOutcome =
                case Outcome of
                    {error, {no_matching_key_with_kid, _Kid}} -> Outcome;
                    _Other -> Failure
                end,
            verify_signature_with_any(Token, AllowAlgorithms, Rest, NextOutcome)
    end.

%% @private
%% Check that every expected `{Key, Value}' pair in `ExpClaims' is present in
%% `Claims' with exactly that value (compared by pattern matching, i.e. `=:=').
%%
%% Returns `ok' when all expected claims match. Otherwise it returns
%% `{error, {missing_claim, ExpClaim, Claims}}' for the first expected claim,
%% in list order, that is absent or has a different value.
-spec verify_claims(Claims, ExpClaims) -> ok | {error, {missing_claim, ExpClaim, Claims}} when
    Claims :: claims(),
    ExpClaim :: {binary(), term()},
    ExpClaims :: [ExpClaim].
verify_claims(Claims, ExpClaims) ->
    IsUnsatisfied =
        fun({Key, Value}) ->
            %% maps:find/2 rather than maps:get/3 with a default: a default
            %% sentinel (e.g. `none') would let an expected value equal to
            %% that sentinel be satisfied by an absent claim.
            case maps:find(Key, Claims) of
                {ok, Value} -> false;
                _MissingOrDifferent -> true
            end
        end,
    %% lists:search/2 stops at the first unsatisfied claim, which is the one
    %% reported.
    case lists:search(IsUnsatisfied, ExpClaims) of
        false ->
            ok;
        {value, Claim} ->
            {error, {missing_claim, Claim, Claims}}
    end.

%% @private
%% Build a symmetric (`oct') key from `ClientSecret' when at least one HMAC
%% algorithm (`HS256', `HS384' or `HS512') is in `AllowedAlgorithms'.
%%
%% The returned key gets `"use": "sig"' added to its fields. Returns `none'
%% when `AllowedAlgorithms' is `undefined' or contains no HMAC algorithm.
-spec client_secret_oct_keys(AllowedAlgorithms, ClientSecret) -> jose_jwk:key() | none when
    AllowedAlgorithms :: [binary()] | undefined,
    ClientSecret :: binary().
client_secret_oct_keys(undefined, _ClientSecret) ->
    none;
client_secret_oct_keys(AllowedAlgorithms, ClientSecret) ->
    HmacAlgorithms = [<<"HS256">>, <<"HS384">>, <<"HS512">>],
    %% lists:any/2 short-circuits on the first HMAC algorithm found, unlike
    %% chaining lists:member/2 results with the strict `or' operator.
    case lists:any(fun(Alg) -> lists:member(Alg, HmacAlgorithms) end, AllowedAlgorithms) of
        true ->
            #jose_jwk{fields = Fields} = Jwk = jose_jwk:from_oct(ClientSecret),
            Jwk#jose_jwk{fields = Fields#{<<"use">> => <<"sig">>}};
        false ->
            none
    end.

%% @private
%% Build a `refresh_jwks_for_unknown_kid_fun()' bound to the given provider
%% configuration worker.
%%
%% The returned fun ignores the JWKS it receives. It calls
%% oidcc_provider_configuration_worker:refresh_jwks_for_unknown_kid/2 with the
%% unknown `Kid' and discards that call's return value. It then returns
%% `{ok, Jwks}' with the JWKS from
%% oidcc_provider_configuration_worker:get_jwks/1.
%% NOTE(review): whether the refresh has completed before get_jwks/1 runs
%% depends on the worker implementation, which is not visible here.
-spec refresh_jwks_fun(ProviderConfigurationWorkerName) ->
    refresh_jwks_for_unknown_kid_fun()
when
    ProviderConfigurationWorkerName :: gen_server:server_ref().
refresh_jwks_fun(ProviderConfigurationWorkerName) ->
    Worker = ProviderConfigurationWorkerName,
    fun(_CurrentJwks, UnknownKid) ->
        oidcc_provider_configuration_worker:refresh_jwks_for_unknown_kid(Worker, UnknownKid),
        RefreshedJwks = oidcc_provider_configuration_worker:get_jwks(Worker),
        {ok, RefreshedJwks}
    end.

%% @private
%% Merge two JWKs into a single `jose_jwk_set'.
%%
%% A single key is first wrapped into a one-element set. The key lists are
%% then concatenated (`Left' keys first). The top-level `fields' of the two
%% sets are combined with maps:merge/2, so `Right' wins on conflicting
%% entries.
-spec merge_jwks(Left :: jose_jwk:key(), Right :: jose_jwk:key()) -> jose_jwk:key().
merge_jwks(#jose_jwk{keys = {jose_jwk_set, LeftKeys}, fields = LeftFields}, #jose_jwk{
    keys = {jose_jwk_set, RightKeys}, fields = RightFields
}) ->
    #jose_jwk{
        keys = {jose_jwk_set, LeftKeys ++ RightKeys}, fields = maps:merge(LeftFields, RightFields)
    };
merge_jwks(#jose_jwk{} = Left, #jose_jwk{keys = {jose_jwk_set, _RightKeys}} = Right) ->
    merge_jwks(#jose_jwk{keys = {jose_jwk_set, [Left]}}, Right);
%% `Left' must be a #jose_jwk{} here. This clause only wraps `Right' and never
%% changes `Left', so a non-key `Left' would recurse forever, nesting `Right'
%% one level deeper on every call. It now fails with `function_clause'.
merge_jwks(#jose_jwk{} = Left, Right) ->
    merge_jwks(Left, #jose_jwk{keys = {jose_jwk_set, [Right]}}).

%% @private
%% Sign `Jwt' with the first usable combination of algorithm and key.
%%
%% Algorithms are tried in the order of `SupportedAlgorithms'. For each one,
%% the keys in `Jwk' are tried via evaluate_for_all_keys/2. Keys marked
%% `"use": "sig"' are eligible for any algorithm; any other key is used only
%% for the `none' algorithm. Returns `{ok, CompactToken}' for the first
%% successful signature, or `{error, no_supported_alg_or_key}' when no
%% combination works.
-spec sign(Jwt :: #jose_jwt{}, Jwk :: jose_jwk:key(), SupportedAlgorithms :: [binary()]) ->
    {ok, binary()} | {error, no_supported_alg_or_key}.
sign(_Jwt, _Jwk, []) ->
    {error, no_supported_alg_or_key};
sign(Jwt, Jwk, [Algorithm | RestAlgorithms]) ->
    Jws = jose_jws:from_map(#{<<"alg">> => Algorithm}),
    SigningCallback = fun
        (#jose_jwk{fields = #{<<"use">> := <<"sig">>}} = Key) ->
            sign_with_key(Key, Jws, Jwt);
        %% Both branches go through the same guarded helper, so a rejected
        %% `none' signature falls through to the next algorithm instead of
        %% crashing.
        (#jose_jwk{} = Key) when Algorithm =:= <<"none">> ->
            sign_with_key(Key, Jws, Jwt);
        (_Key) ->
            error
    end,
    case evaluate_for_all_keys(Jwk, SigningCallback) of
        {ok, Token} -> {ok, Token};
        error -> sign(Jwt, Jwk, RestAlgorithms)
    end.

%% Sign `Jwt' with `Key' using the header `Jws' and return the compact
%% serialization. Returns `error' when jose raises `not_supported',
%% `{not_supported, _}' or `function_clause' (some keys, e.g. public keys,
%% crash that way).
-spec sign_with_key(Key :: jose_jwk:key(), Jws :: #jose_jws{}, Jwt :: #jose_jwt{}) ->
    {ok, binary()} | error.
sign_with_key(Key, Jws, Jwt) ->
    try
        {_SignedFields, Token} = jose_jws:compact(jose_jwt:sign(Key, Jws, Jwt)),
        {ok, Token}
    catch
        error:not_supported -> error;
        error:{not_supported, _Alg} -> error;
        error:function_clause -> error
    end.

%% @private
%% Encrypt `Jwt' with the first usable combination of key management
%% algorithm (`alg'), content encryption (`enc') and key from `Jwk'.
%%
%% If either candidate list is `undefined', no encryption is attempted and
%% `{error, no_supported_alg_or_key}' is returned.
-spec encrypt(
    Jwt :: binary(),
    Jwk :: jose_jwk:key(),
    SupportedAlgorithms :: [binary()] | undefined,
    SupportedEncValues :: [binary()] | undefined
) ->
    {ok, binary()} | {error, no_supported_alg_or_key}.
encrypt(_Jwt, _Jwk, SupportedAlgorithms, SupportedEncValues) when
    SupportedAlgorithms =:= undefined; SupportedEncValues =:= undefined
->
    {error, no_supported_alg_or_key};
encrypt(Jwt, Jwk, SupportedAlgorithms, SupportedEncValues) ->
    %% Begin with the first algorithm and the complete list of enc values.
    encrypt(Jwt, Jwk, SupportedAlgorithms, SupportedEncValues, SupportedEncValues).

%% Walk the `alg' x `enc' candidate grid. For the current algorithm (head of
%% `SupportedAlgorithms'), try each remaining enc value in `AccEncValues'
%% against all keys. When those are exhausted, move on to the next algorithm
%% and restart from the full `SupportedEncValues' list.
-spec encrypt(
    Jwt :: binary(),
    Jwk :: jose_jwk:key(),
    SupportedAlgorithms :: [binary()],
    SupportedEncValues :: [binary()],
    AccEncValues :: [binary()]
) ->
    {ok, binary()} | {error, no_supported_alg_or_key}.
encrypt(_Jwt, _Jwk, [], _SupportedEncValues, _AccEncValues) ->
    {error, no_supported_alg_or_key};
encrypt(Jwt, Jwk, [_Algorithm | RestAlgorithms], SupportedEncValues, []) ->
    encrypt(Jwt, Jwk, RestAlgorithms, SupportedEncValues, SupportedEncValues);
encrypt(
    Jwt, Jwk, [Algorithm | _] = SupportedAlgorithms, SupportedEncValues, [EncValue | RestEncValues]
) ->
    TryKey = fun(Key) -> encrypt_with_key(Jwt, Key, Algorithm, EncValue) end,
    case evaluate_for_all_keys(Jwk, TryKey) of
        {ok, _Token} = Encrypted ->
            Encrypted;
        error ->
            encrypt(Jwt, Jwk, SupportedAlgorithms, SupportedEncValues, RestEncValues)
    end.

%% Encrypt `Jwt' with one key for one `alg'/`enc' pair and return the compact
%% JWE. Only keys marked `"use": "enc"' are used. The key's `kid', when
%% present, is copied into the JWE header. Returns `error' for any other key,
%% or when jose raises `{not_supported, _}'.
-spec encrypt_with_key(Jwt :: binary(), Key :: term(), Algorithm :: binary(), EncValue :: binary()) ->
    {ok, binary()} | error.
encrypt_with_key(
    Jwt, #jose_jwk{fields = #{<<"use">> := <<"enc">>} = Fields} = Key, Algorithm, EncValue
) ->
    try
        BaseHeader = #{<<"alg">> => Algorithm, <<"enc">> => EncValue},
        Header =
            case maps:get(<<"kid">>, Fields, undefined) of
                undefined -> BaseHeader;
                Kid -> BaseHeader#{<<"kid">> => Kid}
            end,
        Jwe = jose_jwe:from_map(Header),
        {_JweFields, Token} = jose_jwe:compact(jose_jwk:block_encrypt(Jwt, Jwe, Key)),
        {ok, Token}
    catch
        error:{not_supported, _Alg} -> error
    end;
encrypt_with_key(_Jwt, _Key, _Algorithm, _EncValue) ->
    error.

%% @private
%% Apply `Callback' to `Jwk'. When `Jwk' is a key set, apply it to each
%% contained key in order, descending into nested sets, until one returns
%% `{ok, Result}'. Returns that first `{ok, Result}', or `error' if no key
%% produced one (including for an empty set).
-spec evaluate_for_all_keys(Jwk :: jose_jwk:key(), fun((jose_jwk:key()) -> {ok, Result} | error)) ->
    {ok, Result} | error
when
    Result :: term().
evaluate_for_all_keys(#jose_jwk{keys = {jose_jwk_set, Keys}}, Callback) ->
    first_ok_key(Keys, Callback, error);
evaluate_for_all_keys(#jose_jwk{} = Jwk, Callback) ->
    Callback(Jwk).

%% Evaluate `Keys' one by one until an `{ok, _}' result is obtained.
%% `Previous' is the result of the last evaluated key (initially `error').
%% Remaining keys are not evaluated once a result has been found.
first_ok_key(_Keys, _Callback, {ok, _Result} = Found) ->
    Found;
first_ok_key([], _Callback, Previous) ->
    Previous;
first_ok_key([Key | Rest], Callback, error) ->
    first_ok_key(Rest, Callback, evaluate_for_all_keys(Key, Callback)).