src/oidcc_jwt_util.erl

%%%-------------------------------------------------------------------
%% @doc JWT Utilities
%% @end
%%%-------------------------------------------------------------------
-module(oidcc_jwt_util).

-include_lib("jose/include/jose_jwk.hrl").
-include_lib("jose/include/jose_jws.hrl").
-include_lib("jose/include/jose_jwt.hrl").

-export([client_secret_oct_keys/2]).
-export([refresh_jwks_fun/1]).
-export([verify_claims/2]).
-export([verify_signature/3]).

-export_type([claims/0]).
-export_type([error/0]).
-export_type([refresh_jwks_for_unknown_kid_fun/0]).

-type refresh_jwks_for_unknown_kid_fun() ::
    fun((Jwks :: jose_jwk:key(), Kid :: binary()) -> {ok, jose_jwk:key()} | {error, term()}).

-type error() :: no_matching_key | invalid_jwt_token | {no_matching_key_with_kid, Kid :: binary()}.

-type claims() :: #{binary() => term()}.

%% Function to decide if the jwks should be reladed to find a matching key for `Kid'
%%
%% A default function is provided in {@link oidcc:retrieve_token/5}
%% and {@link oidcc:retrieve_userinfo/5}.
%%
%% The default implementation does not implement any rate limiting.

%% @private
%% Checking of jwk sets is a bit wonky because of partial support
%% in jose. see: https://github.com/potatosalad/erlang-jose/issues/28
-spec verify_signature(Token, AllowAlgorithms, Jwks) ->
    {ok, {Jwt, Jws}}
    | {error, error()}
when
    Token :: binary(),
    AllowAlgorithms :: [binary()],
    Jwks :: jose_jwk:key(),
    Jwt :: #jose_jwt{},
    Jws :: #jose_jws{}.
verify_signature(Token, AllowAlgorithms, #jose_jwk{keys = {jose_jwk_set, Keys}}) ->
    lists:foldl(
        fun
            (_Key, {ok, _Res} = Acc) ->
                Acc;
            (Key, Acc) ->
                case {verify_signature(Token, AllowAlgorithms, Key), Acc} of
                    {{ok, Res}, _Acc} ->
                        {ok, Res};
                    {_Res, {error, {no_matching_key_with_kid, Kid}}} ->
                        {error, {no_matching_key_with_kid, Kid}};
                    {Res, _Acc} ->
                        Res
                end
        end,
        {error, no_matching_key},
        Keys
    );
verify_signature(Token, AllowAlgorithms, #jose_jwk{} = Jwks) ->
    try
        Kid =
            case jose_jwt:peek_protected(Token) of
                #jose_jws{fields = #{<<"kid">> := IntKid}} ->
                    IntKid;
                #jose_jws{} ->
                    none
            end,

        case Jwks of
            #jose_jwk{fields = #{<<"kid">> := CmpKid}} when CmpKid =/= Kid, Kid =/= none ->
                {error, {no_matching_key_with_kid, Kid}};
            #jose_jwk{} ->
                case jose_jwt:verify_strict(Jwks, AllowAlgorithms, Token) of
                    {true, Jwt, Jws} ->
                        {ok, {Jwt, Jws}};
                    {false, _Jwt, _Jws} ->
                        {error, no_matching_key}
                end
        end
    catch
        error:{badarg, [_Token]} ->
            {error, invalid_jwt_token}
    end.

%% @private
-spec verify_claims(Claims, ExpClaims) -> ok | {error, {missing_claim, ExpClaim, Claims}} when
    Claims :: claims(),
    ExpClaim :: {binary(), term()},
    ExpClaims :: [ExpClaim].
verify_claims(Claims, ExpClaims) ->
    CheckExpectedClaims =
        fun({Key, Value}) ->
            case maps:get(Key, Claims, none) of
                Value ->
                    false;
                _Other ->
                    true
            end
        end,
    case lists:filter(CheckExpectedClaims, ExpClaims) of
        [] ->
            ok;
        [Claim | _Rest] ->
            {error, {missing_claim, Claim, Claims}}
    end.

%% @private
-spec client_secret_oct_keys(AllowedAlgorithms, ClientSecret) -> jose_jwk:key() | none when
    AllowedAlgorithms :: [binary()],
    ClientSecret :: binary().
client_secret_oct_keys(AllowedAlgorithms, ClientSecret) ->
    case
        lists:member(<<"HS256">>, AllowedAlgorithms) or
            lists:member(<<"HS384">>, AllowedAlgorithms) or
            lists:member(<<"HS512">>, AllowedAlgorithms)
    of
        true ->
            jose_jwk:from_oct(ClientSecret);
        false ->
            none
    end.

%% @private
-spec refresh_jwks_fun(ProviderConfigurationWorkerName) ->
    refresh_jwks_for_unknown_kid_fun()
when
    ProviderConfigurationWorkerName :: gen_server:server_ref().
refresh_jwks_fun(ProviderConfigurationWorkerName) ->
    fun(_Jwks, Kid) ->
        oidcc_provider_configuration_worker:refresh_jwks_for_unknown_kid(
            ProviderConfigurationWorkerName,
            Kid
        ),
        {ok, oidcc_provider_configuration_worker:get_jwks(ProviderConfigurationWorkerName)}
    end.