%% src/barrel_mcp_auth_bearer.erl

%%%-------------------------------------------------------------------
%%% @author Benoit Chesneau
%%% @copyright 2024-2026 Benoit Chesneau
%%% @doc Bearer token authentication provider for barrel_mcp.
%%%
%%% Supports JWT validation (HS256 built-in, RS256/ES256 via custom verifier),
%%% opaque tokens, and standard claims validation (iss, aud, exp, nbf).
%%%
%%% == Configuration Options ==
%%%
%%% <ul>
%%%   <li>`verifier' - Custom verification function for tokens</li>
%%%   <li>`secret' - HMAC secret for HS256 JWT validation</li>
%%%   <li>`issuer' - Expected issuer (iss claim)</li>
%%%   <li>`audience' - Expected audience (aud claim)</li>
%%%   <li>`clock_skew' - Allowed clock skew in seconds (default: 60)</li>
%%%   <li>`scope_claim' - Claim name for scopes (default: scope)</li>
%%%   <li>`realm' - Realm for WWW-Authenticate header</li>
%%%   <li>`resource' - Resource identifier for RFC 8707</li>
%%% </ul>
%%%
%%% @see barrel_mcp_auth
%%% @end
%%%-------------------------------------------------------------------
-module(barrel_mcp_auth_bearer).

-behaviour(barrel_mcp_auth).

%% barrel_mcp_auth callbacks
-export([
    init/1,
    authenticate/2,
    challenge/2
]).

%% JWT utilities (exported for testing)
-export([
    decode_jwt/1,
    verify_hs256/2,
    validate_claims/2
]).

-define(DEFAULT_CLOCK_SKEW, 60).
-define(DEFAULT_SCOPE_CLAIM, <<"scope">>).

%%====================================================================
%% barrel_mcp_auth callbacks
%%====================================================================

%% @doc Initialize the Bearer token provider.
%%
%% Copies the recognised options from `Opts' into the provider state and
%% fills in defaults:
%% <ul>
%%   <li>`clock_skew' defaults to ?DEFAULT_CLOCK_SKEW seconds.</li>
%%   <li>`scope_claim' defaults to ?DEFAULT_SCOPE_CLAIM.</li>
%%   <li>`realm' defaults to `<<"mcp">>'.</li>
%%   <li>Every other option defaults to `undefined'.</li>
%% </ul>
%%
%% `resource_metadata_url' is the RFC 9728 protected-resource metadata
%% URL. challenge/2 reads it from the state and advertises it as
%% `resource_metadata="..."' in the WWW-Authenticate header. It must be
%% copied here, otherwise setting it in `Opts' has no effect.
-spec init(map()) -> {ok, map()}.
init(Opts) ->
    State = #{
        verifier => maps:get(verifier, Opts, undefined),
        secret => maps:get(secret, Opts, undefined),
        issuer => maps:get(issuer, Opts, undefined),
        audience => maps:get(audience, Opts, undefined),
        clock_skew => maps:get(clock_skew, Opts, ?DEFAULT_CLOCK_SKEW),
        scope_claim => maps:get(scope_claim, Opts, ?DEFAULT_SCOPE_CLAIM),
        realm => maps:get(realm, Opts, <<"mcp">>),
        resource => maps:get(resource, Opts, undefined),
        resource_metadata_url => maps:get(resource_metadata_url, Opts, undefined)
    },
    {ok, State}.

%% @doc Authenticate a request using a Bearer token.
%%
%% Reads the `headers' map from `Request' (an empty map when absent) and
%% asks barrel_mcp_auth to extract the bearer token. A request with no
%% token is rejected with `{error, unauthorized}'. Otherwise the token is
%% verified against the provider state.
-spec authenticate(map(), map()) ->
    {ok, barrel_mcp_auth:auth_info()} | {error, barrel_mcp_auth:auth_error()}.
authenticate(Request, State) ->
    RequestHeaders = maps:get(headers, Request, #{}),
    case barrel_mcp_auth:extract_bearer_token(RequestHeaders) of
        {error, no_token} ->
            {error, unauthorized};
        {ok, BearerToken} ->
            verify_token(BearerToken, State)
    end.

%% @doc Generate a WWW-Authenticate challenge.
%%
%% Maps `Reason' to an HTTP status, an OAuth error code and a
%% description. Returns `{Status, Headers, Body}':
%% <ul>
%%   <li>`Headers' carries the `www-authenticate' challenge and a JSON
%%       content type.</li>
%%   <li>`Body' is a JSON object with `error' and `error_description'.</li>
%% </ul>
-spec challenge(barrel_mcp_auth:auth_error(), map()) ->
    {integer(), map(), binary()}.
challenge(Reason, State) ->
    {Status, Code, Description} = error_details(Reason),
    %% RFC 9728 / MCP authorization: when a protected-resource metadata
    %% URL is configured, advertise it as `resource_metadata="<URL>"' so
    %% the client can discover the authorization server. The RFC 8707
    %% `resource' (audience) string is deliberately not emitted here.
    WwwAuthenticate = build_challenge(
        maps:get(realm, State, <<"mcp">>),
        Code,
        Description,
        maps:get(resource_metadata_url, State, undefined)
    ),
    ErrorJson = json:encode(#{
        <<"error">> => Code,
        <<"error_description">> => Description
    }),
    Headers = #{
        <<"www-authenticate">> => WwwAuthenticate,
        <<"content-type">> => <<"application/json">>
    },
    {Status, Headers, iolist_to_binary(ErrorJson)}.

%%====================================================================
%% Token verification
%%====================================================================

%% Verify a bearer token using the configured strategy.
%%
%% Precedence:
%% <ol>
%%   <li>A custom `verifier' fun/1. Its claims still go through standard
%%       claim validation.</li>
%%   <li>An HMAC `secret', which enables built-in HS256 JWT
%%       verification.</li>
%%   <li>Otherwise the token is rejected.</li>
%% </ol>
%% With a secret only `alg' = HS256 is accepted. Any other algorithm
%% (including `none') is logged and rejected as `invalid_token'.
verify_token(Token, #{verifier := Verifier} = State) when is_function(Verifier, 1) ->
    case Verifier(Token) of
        {ok, Claims} when is_map(Claims) ->
            validate_and_build_auth_info(Claims, State);
        {error, _} = Error ->
            Error
    end;
verify_token(Token, #{secret := Secret} = State) when Secret =/= undefined ->
    case decode_jwt(Token) of
        {ok, #{<<"alg">> := <<"HS256">>}, Claims, _Signature} ->
            verify_hs256_token(Token, Secret, Claims, State);
        {ok, Header, _Claims, _Signature} ->
            %% The header is attacker-controlled. Read `alg' without
            %% assuming a map, and bound the printed depth in the log.
            Alg =
                case Header of
                    #{<<"alg">> := A} -> A;
                    _ -> undefined
                end,
            logger:warning("Unsupported JWT algorithm ~tP, use custom verifier", [Alg, 10]),
            {error, invalid_token};
        {error, _} = Error ->
            Error
    end;
verify_token(_Token, _State) ->
    %% Neither a usable verifier nor a secret is configured.
    %% NOTE(review): the reason is deliberately nested as
    %% `{error, no_verifier_configured}'. Confirm this wrapping is intended
    %% by the barrel_mcp_auth caller before flattening it.
    {error, {error, no_verifier_configured}}.

%% Check the HS256 signature of `Token' and, only when it is valid,
%% validate the claims and build the auth info.
verify_hs256_token(Token, Secret, Claims, State) ->
    case verify_hs256(Token, Secret) of
        true -> validate_and_build_auth_info(Claims, State);
        false -> {error, invalid_token}
    end.

%% Validate the standard claims. When they pass, turn them into an
%% auth_info map. Otherwise propagate the `{error, _}' from validation
%% unchanged.
validate_and_build_auth_info(Claims, State) ->
    case validate_claims(Claims, State) of
        {error, _} = Rejected ->
            Rejected;
        ok ->
            build_auth_info(Claims, State)
    end.

%%====================================================================
%% JWT decoding and verification
%%====================================================================

%% @doc Decode a JWT without verifying its signature.
%%
%% Splits `Token' into its three dot-separated segments and
%% base64url-decodes them. The header and claims are then JSON-decoded.
%%
%% Returns `{ok, Header, Claims, Signature}', where Header and Claims are
%% maps and Signature holds the raw signature bytes. Returns
%% `{error, invalid_token}' when:
%% <ul>
%%   <li>the token does not have exactly three segments;</li>
%%   <li>a segment is not valid base64url or JSON;</li>
%%   <li>the header or claims are valid JSON but not JSON objects.</li>
%% </ul>
%% The object check matters because callers treat Header and Claims as
%% maps. A payload such as a JSON array would otherwise crash them with
%% `badmap'.
-spec decode_jwt(binary()) ->
    {ok, map(), map(), binary()} | {error, term()}.
decode_jwt(Token) ->
    try
        case binary:split(Token, <<".">>, [global]) of
            [HeaderB64, ClaimsB64, SignatureB64] ->
                Header = json:decode(base64url_decode(HeaderB64)),
                Claims = json:decode(base64url_decode(ClaimsB64)),
                Signature = base64url_decode(SignatureB64),
                case is_map(Header) andalso is_map(Claims) of
                    true -> {ok, Header, Claims, Signature};
                    false -> {error, invalid_token}
                end;
            _ ->
                {error, invalid_token}
        end
    catch
        %% Malformed base64url or JSON (or a non-binary token) raise
        %% errors. At this trust boundary they all mean "invalid token".
        error:_ ->
            {error, invalid_token}
    end.

%% @doc Verify the HS256 (HMAC-SHA256) signature of a compact JWT.
%%
%% Recomputes the HMAC with `Secret' over `<header>.<claims>', using the
%% still-encoded segments. The result is compared in constant time with
%% the decoded signature segment.
%%
%% Returns `false' when the token does not have three segments, when the
%% signature segment is not valid base64url, or when the signature is not
%% exactly the HMAC length.
-spec verify_hs256(binary(), binary()) -> boolean().
verify_hs256(Token, Secret) ->
    case binary:split(Token, <<".">>, [global]) of
        [HeaderB64, ClaimsB64, SignatureB64] ->
            SigningInput = <<HeaderB64/binary, ".", ClaimsB64/binary>>,
            ExpectedSig = crypto:mac(hmac, sha256, Secret, SigningInput),
            ActualSig =
                try
                    base64url_decode(SignatureB64)
                catch
                    error:_ -> invalid
                end,
            %% crypto:hash_equals/2 raises badarg on binaries of different
            %% sizes, so check the length first. The length is not secret.
            %% Then do the constant-time comparison.
            is_binary(ActualSig) andalso
                byte_size(ActualSig) =:= byte_size(ExpectedSig) andalso
                crypto:hash_equals(ExpectedSig, ActualSig);
        _ ->
            false
    end.

%% @doc Validate JWT claims.
%%
%% Runs the expiration, not-before, issuer and audience checks in that
%% order against the current system time in seconds. The configured
%% `clock_skew' tolerance defaults to ?DEFAULT_CLOCK_SKEW. Returns `ok',
%% or the `{error, _}' from the first failing check.
-spec validate_claims(map(), map()) -> ok | {error, term()}.
validate_claims(Claims, State) ->
    NowSeconds = erlang:system_time(second),
    Skew = maps:get(clock_skew, State, ?DEFAULT_CLOCK_SKEW),
    Expiration = fun() -> check_expiration(Claims, NowSeconds, Skew) end,
    NotBefore = fun() -> check_not_before(Claims, NowSeconds, Skew) end,
    Issuer = fun() -> check_issuer(Claims, State) end,
    Audience = fun() -> check_audience_claim(Claims, State) end,
    run_checks([Expiration, NotBefore, Issuer, Audience]).

%% Evaluate zero-arity check funs in order and stop at the first one that
%% returns `{error, _}'. Later checks are never called after a failure.
%% Returns `ok' when every check (or no check) passes.
run_checks(Checks) ->
    lists:foldl(
        fun
            (Check, ok) ->
                case Check() of
                    ok -> ok;
                    {error, _} = Error -> Error
                end;
            (_Skipped, {error, _} = Error) ->
                Error
        end,
        ok,
        Checks
    ).

%% Check the `exp' (expiration time) claim, RFC 7519 section 4.1.4.
%%
%% A missing `exp' is accepted. `exp' is a NumericDate, which may be an
%% integer or a non-integer (float) number, and both are honoured. The
%% RFC requires the current time to be strictly before `exp', so the
%% token counts as expired once `Now' reaches `exp + ClockSkew'.
%% A present but non-numeric `exp' is rejected as `invalid_token'.
%% Silently ignoring it would make the token never expire.
check_expiration(Claims, Now, ClockSkew) ->
    case maps:get(<<"exp">>, Claims, undefined) of
        undefined ->
            ok;
        Exp when is_number(Exp), Exp + ClockSkew =< Now ->
            {error, expired_token};
        Exp when is_number(Exp) ->
            ok;
        _Malformed ->
            {error, invalid_token}
    end.

%% Check the `nbf' (not before) claim, RFC 7519 section 4.1.5.
%%
%% A missing `nbf' is accepted. `nbf' is a NumericDate (integer or
%% float). The token is accepted once `Now' is at or after
%% `nbf - ClockSkew'. A present but non-numeric `nbf' is rejected as
%% `invalid_token' rather than silently ignored.
check_not_before(Claims, Now, ClockSkew) ->
    case maps:get(<<"nbf">>, Claims, undefined) of
        undefined ->
            ok;
        Nbf when is_number(Nbf), Nbf - ClockSkew > Now ->
            {error, invalid_token};
        Nbf when is_number(Nbf) ->
            ok;
        _Malformed ->
            {error, invalid_token}
    end.

%% Check the `iss' claim against the configured `issuer'.
%%
%% With no issuer configured (missing or `undefined') any token passes,
%% and the claims are not inspected. Otherwise `iss' must be present and
%% exactly equal to the configured value. A mismatch yields
%% `{error, invalid_token}'.
check_issuer(Claims, State) ->
    case maps:get(issuer, State, undefined) of
        undefined ->
            ok;
        Expected ->
            case maps:find(<<"iss">>, Claims) of
                {ok, Expected} -> ok;
                _MissingOrDifferent -> {error, invalid_token}
            end
    end.

%% Check the `aud' claim when an `audience' is configured. Without one
%% every token passes. Otherwise the token's `aud' (`undefined' when
%% absent) is matched by check_audience/2.
check_audience_claim(Claims, State) ->
    case maps:get(audience, State, undefined) of
        undefined ->
            ok;
        Expected ->
            Presented = maps:get(<<"aud">>, Claims, undefined),
            check_audience(Expected, Presented)
    end.

%% Match the configured audience against the token's `aud' value.
%%
%% Either side may be a single binary or a list. The check passes when
%% at least one configured audience appears among the presented ones
%% (exact term equality). Everything else yields
%% `{error, invalid_token}': a missing `aud' (`undefined'), a non-binary
%% non-list value, or an empty list on either side.
check_audience(Expected, Actual) ->
    AsList = fun
        (Single) when is_binary(Single) -> [Single];
        (Many) when is_list(Many) -> Many;
        (_Other) -> []
    end,
    Presented = AsList(Actual),
    Overlaps = lists:any(fun(Aud) -> lists:member(Aud, Presented) end, AsList(Expected)),
    case Overlaps of
        true -> ok;
        false -> {error, invalid_token}
    end.

%%====================================================================
%% Auth info building
%%====================================================================

%% Build the auth_info map returned to barrel_mcp_auth from verified
%% claims.
%%
%% Standard claims (`sub', `iss', `aud', `exp') become `undefined' when
%% absent. Scopes are read from the configured `scope_claim' (default
%% ?DEFAULT_SCOPE_CLAIM) and normalised by extract_scopes/1. The raw
%% claims are kept under `claims'.
build_auth_info(Claims, State) ->
    Claim = fun(Name) -> maps:get(Name, Claims, undefined) end,
    ScopeKey = maps:get(scope_claim, State, ?DEFAULT_SCOPE_CLAIM),
    RawScopes = maps:get(ScopeKey, Claims, <<>>),
    {ok, #{
        subject => Claim(<<"sub">>),
        issuer => Claim(<<"iss">>),
        audience => Claim(<<"aud">>),
        scopes => extract_scopes(RawScopes),
        expires_at => Claim(<<"exp">>),
        claims => Claims,
        metadata => #{provider => barrel_mcp_auth_bearer}
    }}.

%% Normalise a scope claim into a list.
%%
%% A list is returned as-is. A binary is treated as space-separated
%% scopes, with empty pieces (runs of spaces, leading or trailing spaces)
%% discarded. Any other value yields no scopes.
extract_scopes(Scopes) when is_list(Scopes) ->
    Scopes;
extract_scopes(ScopeStr) when is_binary(ScopeStr) ->
    Pieces = binary:split(ScopeStr, <<" ">>, [global]),
    lists:filter(fun(Piece) -> Piece =/= <<>> end, Pieces);
extract_scopes(_Other) ->
    [].

%%====================================================================
%% Challenge building
%%====================================================================

%% Map an authentication failure reason to
%% `{HttpStatus, OAuthErrorCode, Description}'.
%%
%% A missing token gives 401 `invalid_request'. Insufficient scope gives
%% 403 `insufficient_scope'. Everything else is a 401 `invalid_token',
%% described by invalid_token_description/1.
error_details(Reason) ->
    case Reason of
        unauthorized ->
            {401, <<"invalid_request">>, <<"Authorization required">>};
        insufficient_scope ->
            {403, <<"insufficient_scope">>, <<"The access token has insufficient scope">>};
        _ ->
            {401, <<"invalid_token">>, invalid_token_description(Reason)}
    end.

%% Human-readable description for a 401 `invalid_token' reason. For
%% `{error, Binary}' the binary itself is used.
invalid_token_description(invalid_token) -> <<"The access token is invalid">>;
invalid_token_description(expired_token) -> <<"The access token has expired">>;
invalid_token_description({error, Desc}) when is_binary(Desc) -> Desc;
invalid_token_description({error, _}) -> <<"Token verification failed">>;
invalid_token_description(_) -> <<"Authentication failed">>.

%% Build the WWW-Authenticate header value (RFC 6750 section 3,
%% RFC 9728 / MCP authorization).
%%
%% Always emits `realm'. The `error' code is omitted for
%% `invalid_request', which covers the no-token case. `error_description'
%% is omitted when empty, and `resource_metadata' when no URL is
%% configured. Parameters are separated by ", ".
%%
%% Every value is emitted as an HTTP quoted-string. `"' and `\' are
%% escaped as quoted-pairs, and control characters (including CR/LF) are
%% replaced by a space. The description may come from a custom verifier,
%% so without this the header could be malformed or could inject new
%% header lines.
build_challenge(Realm, ErrorCode, ErrorDesc, ResourceMetadataUrl) ->
    Params =
        [{<<"realm">>, Realm}] ++
            [{<<"error">>, ErrorCode} || ErrorCode =/= <<"invalid_request">>] ++
            [{<<"error_description">>, ErrorDesc} || ErrorDesc =/= <<>>] ++
            [{<<"resource_metadata">>, ResourceMetadataUrl} || ResourceMetadataUrl =/= undefined],
    Rendered = [[Name, <<"=\"">>, quote_param(Value), <<"\"">>] || {Name, Value} <- Params],
    iolist_to_binary([<<"Bearer ">>, lists:join(<<", ">>, Rendered)]).

%% Escape a binary so it can sit between double quotes in an HTTP
%% quoted-string.
quote_param(Value) ->
    <<<<(escape_char(C))/binary>> || <<C>> <= Value>>.

%% Escape one byte for quote_param/1.
escape_char($") -> <<"\\\"">>;
escape_char($\\) -> <<"\\\\">>;
escape_char(C) when C < 16#20; C =:= 16#7F -> <<" ">>;
escape_char(C) -> <<C>>.

%%====================================================================
%% Base64URL utilities
%%====================================================================

%% Decode unpadded base64url (RFC 4648 section 5), as used in JWT
%% segments.
%%
%% `-' and `_' are mapped back to the standard alphabet's `+' and `/'.
%% Missing `=' padding is then restored and base64:decode/1 does the
%% decoding. A length that is 1 modulo 4 is never valid base64 and
%% raises a case_clause error.
base64url_decode(Data) ->
    Standard = <<
        <<(case Char of
                $- -> $+;
                $_ -> $/;
                _ -> Char
            end)>>
     || <<Char>> <= Data
    >>,
    Padding =
        case byte_size(Standard) rem 4 of
            0 -> <<>>;
            2 -> <<"==">>;
            3 -> <<"=">>
        end,
    base64:decode(<<Standard/binary, Padding/binary>>).