Skip to main content

src/barrel_mcp_session.erl

%%%-------------------------------------------------------------------
%%% @author Benoit Chesneau
%%% @copyright 2024-2026 Benoit Chesneau
%%% @doc MCP Session Management.
%%%
%%% Provides ETS-based session management for MCP Streamable HTTP transport.
%%% Sessions track client connections, protocol versions, and activity.
%%%
%%% @end
%%%-------------------------------------------------------------------
-module(barrel_mcp_session).

-behaviour(gen_server).

%% API
-export([
    start_link/0,
    create/1,
    get/1,
    update_activity/1,
    delete/1,
    generate_id/0,
    list/0,
    cleanup_expired/1,
    %% Capability tracking (set during MCP `initialize').
    set_client_capabilities/2,
    has_sampling/1,
    list_sampling_capable/0,
    has_elicitation/1,
    list_elicitation_capable/0,
    has_roots/1,
    list_roots_capable/0,
    %% Per-session log level (`logging/setLevel').
    set_log_level/2,
    get_log_level/1,
    log_level_priority/1,
    %% Negotiated protocol version (after `initialize').
    set_protocol_version/2,
    get_protocol_version/1,
    %% sse_pid management.
    set_sse_pid/2,
    get_sse_pid/1,
    %% Resource subscription tracking (server-side, used to emit
    %% notifications/resources/updated when an exposed resource changes).
    subscribe_resource/2,
    unsubscribe_resource/2,
    subscribers_for/1,
    %% Server -> client request via the session's SSE channel.
    sampling_create_message/3,
    elicit_create/3,
    roots_list/2,
    deliver_response/2,
    %% Server -> client notifications.
    broadcast_list_changed/1,
    notify_progress/4,
    %% In-flight tool tracking (used by `notifications/cancelled').
    record_in_flight/4,
    cancel_in_flight/2,
    clear_in_flight/2,
    %% SSE replay (Last-Event-ID).
    record_sse_event/3,
    events_since/2,
    set_sse_buffer_max/2
]).

%% gen_server callbacks
-export([init/1, handle_call/3, handle_cast/2, handle_info/2, terminate/2]).

-include("barrel_mcp.hrl").

-define(SESSION_TABLE, barrel_mcp_sessions).
-define(SUBSCRIPTIONS_TABLE, barrel_mcp_resource_subs).
-define(PENDING_TABLE, barrel_mcp_pending_requests).
%% In-flight tool calls per session: {{SessionId, RequestId} => #in_flight{}}
-define(INFLIGHT_TABLE, barrel_mcp_inflight).
%% 1 minute
-define(CLEANUP_INTERVAL, 60000).
-define(DEFAULT_SAMPLING_TIMEOUT, 30000).

-record(mcp_session, {
    id :: binary(),
    created_at :: integer(),
    last_activity :: integer(),
    client_info :: map(),
    client_capabilities :: map(),
    protocol_version :: binary(),
    %% Process handling SSE stream
    sse_pid :: pid() | undefined,
    %% Recent SSE events (newest first) for `Last-Event-ID' replay.
    sse_buffer = [] :: [{binary(), map()}],
    sse_buffer_max = 256 :: pos_integer(),
    %% Per-session log level set by `logging/setLevel'. Default
    %% `info' per the MCP spec. Filters `notifications/message' on
    %% emit.
    log_level = info :: log_level()
}).

-type log_level() ::
    debug
    | info
    | notice
    | warning
    | error
    | critical
    | alert
    | emergency.

%% Async tool call in-flight tracking.
-record(in_flight, {
    session_id :: binary(),
    request_id :: integer() | binary(),
    worker_pid :: pid(),
    waiter_pid :: pid()
}).

-record(pending, {
    id :: binary(),
    session_id :: binary(),
    caller :: pid(),
    caller_ref :: reference(),
    expires_at :: integer(),
    tag = sampling_response :: atom()
}).

%%====================================================================
%% API
%%====================================================================

%% @doc Start the session manager.
-spec start_link() -> {ok, pid()} | {error, term()}.
start_link() ->
    gen_server:start_link({local, ?MODULE}, ?MODULE, [], []).

%% @doc Create a new session.
-spec create(Opts) -> {ok, binary()} when
    Opts :: #{
        client_info => map(),
        protocol_version => binary()
    }.
create(Opts) ->
    gen_server:call(?MODULE, {create, Opts}).

%% @doc Get a session by ID.
-spec get(binary()) -> {ok, map()} | {error, not_found}.
get(SessionId) ->
    case ets:lookup(?SESSION_TABLE, SessionId) of
        [{_, Session}] ->
            {ok, session_to_map(Session)};
        [] ->
            {error, not_found}
    end.

%% @doc Update last activity timestamp.
-spec update_activity(binary()) -> ok | {error, not_found}.
update_activity(SessionId) ->
    gen_server:call(?MODULE, {update_activity, SessionId}).

%% @doc Delete a session.
-spec delete(binary()) -> ok.
delete(SessionId) ->
    gen_server:call(?MODULE, {delete, SessionId}).

%% @doc Generate a unique session ID.
-spec generate_id() -> binary().
generate_id() ->
    Rand = crypto:strong_rand_bytes(16),
    Hex = binary:encode_hex(Rand, lowercase),
    <<"mcp_", Hex/binary>>.

%% @doc List all sessions.
-spec list() -> [map()].
list() ->
    ets:foldl(
        fun({_, Session}, Acc) ->
            [session_to_map(Session) | Acc]
        end,
        [],
        ?SESSION_TABLE
    ).

%% @doc Set the client_capabilities map for a session. Called from the
%% protocol handler after parsing the `initialize' request.
-spec set_client_capabilities(binary(), map()) -> ok | {error, not_found}.
set_client_capabilities(SessionId, Capabilities) when is_map(Capabilities) ->
    gen_server:call(?MODULE, {set_client_capabilities, SessionId, Capabilities}).

%% @doc Record the negotiated protocol version on a session. Called
%% by the HTTP transport after a successful `initialize' so later
%% requests on the same session can fall back to it when the client
%% omits the `MCP-Protocol-Version' header.
-spec set_protocol_version(binary(), binary()) -> ok | {error, not_found}.
set_protocol_version(SessionId, Version) when is_binary(Version) ->
    gen_server:call(?MODULE, {set_protocol_version, SessionId, Version}).

%% @doc Look up the negotiated protocol version for a session.
-spec get_protocol_version(binary()) -> {ok, binary()} | {error, not_found}.
get_protocol_version(SessionId) ->
    case ets:lookup(?SESSION_TABLE, SessionId) of
        [{_, #mcp_session{protocol_version = V}}] when is_binary(V) ->
            {ok, V};
        [{_, _}] ->
            {ok, ?MCP_PROTOCOL_VERSION};
        [] ->
            {error, not_found}
    end.

%% @doc Whether a session declared sampling capability in its initialize
%% request.
-spec has_sampling(binary()) -> boolean().
has_sampling(SessionId) ->
    case ets:lookup(?SESSION_TABLE, SessionId) of
        [{_, #mcp_session{client_capabilities = Caps}}] ->
            maps:is_key(<<"sampling">>, Caps);
        [] ->
            false
    end.

%% @doc List session ids whose client declared sampling capability.
-spec list_sampling_capable() -> [binary()].
list_sampling_capable() ->
    ets:foldl(
        fun({Id, #mcp_session{client_capabilities = Caps}}, Acc) ->
            case maps:is_key(<<"sampling">>, Caps) of
                true -> [Id | Acc];
                false -> Acc
            end
        end,
        [],
        ?SESSION_TABLE
    ).

%% @doc Whether a session declared elicitation capability in its
%% initialize request.
-spec has_elicitation(binary()) -> boolean().
has_elicitation(SessionId) ->
    case ets:lookup(?SESSION_TABLE, SessionId) of
        [{_, #mcp_session{client_capabilities = Caps}}] ->
            maps:is_key(<<"elicitation">>, Caps);
        [] ->
            false
    end.

%% @doc List session ids whose client declared elicitation capability.
-spec list_elicitation_capable() -> [binary()].
list_elicitation_capable() ->
    ets:foldl(
        fun({Id, #mcp_session{client_capabilities = Caps}}, Acc) ->
            case maps:is_key(<<"elicitation">>, Caps) of
                true -> [Id | Acc];
                false -> Acc
            end
        end,
        [],
        ?SESSION_TABLE
    ).

%% @doc Whether a session declared roots capability in its initialize
%% request.
-spec has_roots(binary()) -> boolean().
has_roots(SessionId) ->
    case ets:lookup(?SESSION_TABLE, SessionId) of
        [{_, #mcp_session{client_capabilities = Caps}}] ->
            maps:is_key(<<"roots">>, Caps);
        [] ->
            false
    end.

%% @doc List session ids whose client declared roots capability.
-spec list_roots_capable() -> [binary()].
list_roots_capable() ->
    ets:foldl(
        fun({Id, #mcp_session{client_capabilities = Caps}}, Acc) ->
            case maps:is_key(<<"roots">>, Caps) of
                true -> [Id | Acc];
                false -> Acc
            end
        end,
        [],
        ?SESSION_TABLE
    ).

%% @doc Set the per-session log level (driven by `logging/setLevel').
%% `Level' is one of the eight RFC 5424 levels accepted by the MCP
%% spec; rejects anything else with `{error, invalid_level}'.
-spec set_log_level(binary(), log_level() | binary()) ->
    ok | {error, not_found | invalid_level}.
set_log_level(SessionId, Level) ->
    case parse_level(Level) of
        {ok, L} ->
            gen_server:call(?MODULE, {set_log_level, SessionId, L});
        error ->
            {error, invalid_level}
    end.

%% @doc Read the current log level for a session. Defaults to `info'
%% before any `logging/setLevel' is received.
-spec get_log_level(binary()) -> {ok, log_level()} | {error, not_found}.
get_log_level(SessionId) ->
    case ets:lookup(?SESSION_TABLE, SessionId) of
        [{_, #mcp_session{log_level = L}}] -> {ok, L};
        [] -> {error, not_found}
    end.

%% @doc Numeric priority for the eight RFC 5424 levels (debug=0,
%% emergency=7). Higher = more severe. Used for filtering: a
%% notification at priority N is delivered iff N >= configured level.
-spec log_level_priority(log_level() | binary()) -> 0..7 | error.
log_level_priority(Level) ->
    case parse_level(Level) of
        {ok, L} -> level_priority(L);
        error -> error
    end.

level_priority(debug) -> 0;
level_priority(info) -> 1;
level_priority(notice) -> 2;
level_priority(warning) -> 3;
level_priority(error) -> 4;
level_priority(critical) -> 5;
level_priority(alert) -> 6;
level_priority(emergency) -> 7.

parse_level(L) when is_atom(L) ->
    case
        lists:member(L, [
            debug,
            info,
            notice,
            warning,
            error,
            critical,
            alert,
            emergency
        ])
    of
        true -> {ok, L};
        false -> error
    end;
parse_level(<<"debug">>) ->
    {ok, debug};
parse_level(<<"info">>) ->
    {ok, info};
parse_level(<<"notice">>) ->
    {ok, notice};
parse_level(<<"warning">>) ->
    {ok, warning};
parse_level(<<"error">>) ->
    {ok, error};
parse_level(<<"critical">>) ->
    {ok, critical};
parse_level(<<"alert">>) ->
    {ok, alert};
parse_level(<<"emergency">>) ->
    {ok, emergency};
parse_level(_) ->
    error.

%% @doc Set the SSE process pid for a session.
-spec set_sse_pid(binary(), pid() | undefined) -> ok | {error, not_found}.
set_sse_pid(SessionId, Pid) ->
    gen_server:call(?MODULE, {set_sse_pid, SessionId, Pid}).

-spec get_sse_pid(binary()) -> {ok, pid()} | {error, not_found | no_sse}.
get_sse_pid(SessionId) ->
    case ets:lookup(?SESSION_TABLE, SessionId) of
        [{_, #mcp_session{sse_pid = undefined}}] -> {error, no_sse};
        [{_, #mcp_session{sse_pid = Pid}}] when is_pid(Pid) -> {ok, Pid};
        [] -> {error, not_found}
    end.

%% @doc Subscribe a session to resource updates for a given URI.
-spec subscribe_resource(binary(), binary()) -> ok.
subscribe_resource(SessionId, Uri) when
    is_binary(SessionId), is_binary(Uri)
->
    gen_server:call(?MODULE, {subscribe_resource, SessionId, Uri}).

-spec unsubscribe_resource(binary(), binary()) -> ok.
unsubscribe_resource(SessionId, Uri) ->
    gen_server:call(?MODULE, {unsubscribe_resource, SessionId, Uri}).

%% @doc Return all session ids that subscribed to a URI.
-spec subscribers_for(binary()) -> [binary()].
subscribers_for(Uri) when is_binary(Uri) ->
    _ = ensure_subs_table(),
    %% match-spec to find all {SessionId, Uri} for the given Uri
    Pattern = {{'$1', Uri}},
    Match = [{Pattern, [], ['$1']}],
    ets:select(?SUBSCRIPTIONS_TABLE, Match).

%% @doc Send `sampling/createMessage' to the client behind a session and
%% wait for the response. The session must (a) exist, (b) have an active
%% sse_pid, and (c) have declared sampling capability in initialize.
-spec sampling_create_message(binary(), map(), map()) ->
    {ok, map(), map()}
    | {error, timeout | not_supported | no_sse | not_found | term()}.
sampling_create_message(SessionId, Params, Opts) ->
    case has_sampling(SessionId) of
        false ->
            {error, not_supported};
        true ->
            case get_sse_pid(SessionId) of
                {error, _} = E -> E;
                {ok, Pid} -> do_sampling(SessionId, Pid, Params, Opts)
            end
    end.

%% @doc Send `elicitation/create' to the client behind a session and wait
%% for the response. The session must (a) exist, (b) have an active
%% sse_pid, and (c) have declared elicitation capability in initialize.
-spec elicit_create(binary(), map(), map()) ->
    {ok, map()}
    | {error, timeout | not_supported | no_sse | not_found | term()}.
elicit_create(SessionId, Params, Opts) ->
    case has_elicitation(SessionId) of
        false ->
            {error, not_supported};
        true ->
            case get_sse_pid(SessionId) of
                {error, _} = E -> E;
                {ok, Pid} -> do_elicit(SessionId, Pid, Params, Opts)
            end
    end.

%% @doc Send `roots/list' to the client behind a session and wait for
%% the response. The session must (a) exist, (b) have an active sse_pid,
%% and (c) have declared roots capability in initialize.
-spec roots_list(binary(), map()) ->
    {ok, [map()]}
    | {error, timeout | not_supported | no_sse | not_found | term()}.
roots_list(SessionId, Opts) ->
    case has_roots(SessionId) of
        false ->
            {error, not_supported};
        true ->
            case get_sse_pid(SessionId) of
                {error, _} = E -> E;
                {ok, Pid} -> do_roots_list(SessionId, Pid, Opts)
            end
    end.

%% @doc Deliver a JSON-RPC response from the client back to the waiting
%% caller. Called by the HTTP handler when an inbound POST contains a
%% `result' or `error' for a server-initiated id.
-spec deliver_response(binary() | integer(), map()) -> ok | {error, unknown_id}.
deliver_response(Id, Response) ->
    gen_server:call(?MODULE, {deliver_response, id_to_binary(Id), Response}).

%% @doc Push a `notifications/<kind>/list_changed' envelope to every
%% session that has an active SSE channel. Tolerates a missing
%% session manager (e.g. during stdio-only operation).
-spec broadcast_list_changed(handler_type()) -> ok.
broadcast_list_changed(Kind) ->
    case {whereis(?MODULE), list_changed_method(Kind)} of
        {undefined, _} ->
            ok;
        %% kind has no list_changed notification
        {_, undefined} ->
            ok;
        {_, Method} ->
            Notif = #{
                <<"jsonrpc">> => <<"2.0">>,
                <<"method">> => Method,
                <<"params">> => #{}
            },
            broadcast_to_sse_sessions(Notif)
    end.

list_changed_method(tool) -> <<"notifications/tools/list_changed">>;
list_changed_method(resource) -> <<"notifications/resources/list_changed">>;
list_changed_method(resource_template) -> <<"notifications/resources/list_changed">>;
list_changed_method(prompt) -> <<"notifications/prompts/list_changed">>;
list_changed_method(completion) -> undefined.

broadcast_to_sse_sessions(Notification) ->
    %% Reads from a `protected' ETS via direct ets:foldl/3 work fine
    %% from any process. We only need the gen_server when we mutate
    %% the table.
    case ets:whereis(?SESSION_TABLE) of
        undefined ->
            ok;
        _ ->
            ets:foldl(
                fun
                    ({_Id, #mcp_session{sse_pid = Pid}}, Acc) when is_pid(Pid) ->
                        Pid ! {sse_send_message, Notification},
                        Acc;
                    (_, Acc) ->
                        Acc
                end,
                ok,
                ?SESSION_TABLE
            )
    end.

%% @doc Record an in-flight tool call so a later
%% `notifications/cancelled' can find the worker and waiter.
-spec record_in_flight(binary(), integer() | binary(), pid(), pid()) -> ok.
record_in_flight(SessionId, RequestId, WorkerPid, WaiterPid) ->
    gen_server:call(
        ?MODULE,
        {record_in_flight, SessionId, RequestId, WorkerPid, WaiterPid}
    ).

%% @doc Cancel an in-flight tool call. Sends `{cancel, RequestId}'
%% to the worker and `{cancelled, RequestId}' to the waiter, then
%% drops the entry. Idempotent: a missing entry returns `ok'.
-spec cancel_in_flight(binary(), integer() | binary()) -> ok.
cancel_in_flight(SessionId, RequestId) ->
    gen_server:call(?MODULE, {cancel_in_flight, SessionId, RequestId}).

%% @doc Drop an in-flight entry (called by the waiter after a normal
%% completion).
-spec clear_in_flight(binary(), integer() | binary()) -> ok.
clear_in_flight(SessionId, RequestId) ->
    gen_server:call(?MODULE, {clear_in_flight, SessionId, RequestId}).

%% @doc Append an SSE event to the session's ring buffer for later
%% replay via `Last-Event-ID'.
-spec record_sse_event(binary(), binary(), map()) -> ok.
record_sse_event(SessionId, EventId, Payload) ->
    gen_server:call(
        ?MODULE,
        {record_sse_event, SessionId, EventId, Payload}
    ).

%% @doc Return SSE events newer than `LastId' (oldest first), or
%% `truncated' when `LastId' is older than the oldest buffered event.
-spec events_since(binary(), binary()) ->
    {ok, [{binary(), map()}]} | truncated | {error, not_found}.
events_since(SessionId, LastId) ->
    case ets:lookup(?SESSION_TABLE, SessionId) of
        [{_, #mcp_session{sse_buffer = Buf}}] ->
            collect_after(Buf, LastId);
        [] ->
            {error, not_found}
    end.

%% Buffer is newest-first. Return events after `LastId' in
%% chronological order (oldest first), or `truncated' if LastId
%% isn't in the window.
collect_after(Buf, LastId) ->
    case lists:splitwith(fun({Id, _}) -> Id =/= LastId end, Buf) of
        {_, []} ->
            %% LastId not found — buffer rolled over.
            truncated;
        {Newer, [_ | _]} ->
            {ok, lists:reverse(Newer)}
    end.

%% @doc Configure the maximum number of SSE events buffered per
%% session for replay.
-spec set_sse_buffer_max(binary(), pos_integer()) -> ok | {error, not_found}.
set_sse_buffer_max(SessionId, Max) when is_integer(Max), Max > 0 ->
    gen_server:call(?MODULE, {set_sse_buffer_max, SessionId, Max}).

%% @doc Push a `notifications/progress' envelope to a specific
%% session over its SSE channel. `Token' is the progressToken the
%% client supplied on the originating request.
-spec notify_progress(binary(), term(), number(), number() | undefined) -> ok.
notify_progress(SessionId, Token, Progress, Total) ->
    case get_sse_pid(SessionId) of
        {ok, Pid} ->
            Params0 = #{
                <<"progressToken">> => Token,
                <<"progress">> => Progress
            },
            Params =
                case Total of
                    undefined -> Params0;
                    _ -> Params0#{<<"total">> => Total}
                end,
            Pid !
                {sse_send_message, #{
                    <<"jsonrpc">> => <<"2.0">>,
                    <<"method">> => <<"notifications/progress">>,
                    <<"params">> => Params
                }},
            ok;
        _ ->
            ok
    end.

%% @doc Cleanup sessions older than TTL milliseconds. Routes through
%% the gen_server (the table owner under the new `protected'
%% visibility); the handler deletes expired entries inline.
-spec cleanup_expired(pos_integer()) -> non_neg_integer().
cleanup_expired(TTL) ->
    gen_server:call(?MODULE, {cleanup_expired, TTL}).

%% Trim a newest-first list to at most `Max' entries.
trim(List, Max) when length(List) =< Max -> List;
trim(List, Max) -> lists:sublist(List, Max).

%% Inline session delete, only called from inside the gen_server.
delete_inline(SessionId) ->
    case ets:lookup(?SESSION_TABLE, SessionId) of
        [{_, #mcp_session{sse_pid = Pid}}] when is_pid(Pid) ->
            Pid ! session_terminated;
        _ ->
            ok
    end,
    true = ets:delete(?SESSION_TABLE, SessionId),
    ok.

%%====================================================================
%% gen_server callbacks
%%====================================================================

init([]) ->
    %% Create ETS tables if they don't exist
    _ = ensure_session_table(),
    _ = ensure_subs_table(),
    _ = ensure_pending_table(),
    _ = ensure_inflight_table(),
    %% Schedule periodic cleanup
    _ = erlang:send_after(?CLEANUP_INTERVAL, self(), cleanup),
    {ok, #{}}.

handle_call({create, Opts}, _From, State) ->
    SessionId = generate_id(),
    Now = erlang:system_time(millisecond),
    Session = #mcp_session{
        id = SessionId,
        created_at = Now,
        last_activity = Now,
        client_info = maps:get(client_info, Opts, #{}),
        client_capabilities = maps:get(client_capabilities, Opts, #{}),
        protocol_version = maps:get(protocol_version, Opts, <<"2025-03-26">>),
        sse_pid = undefined
    },
    true = ets:insert(?SESSION_TABLE, {SessionId, Session}),
    {reply, {ok, SessionId}, State};
handle_call({update_activity, SessionId}, _From, State) ->
    Reply =
        case ets:lookup(?SESSION_TABLE, SessionId) of
            [{_, Session}] ->
                Now = erlang:system_time(millisecond),
                Updated = Session#mcp_session{last_activity = Now},
                true = ets:insert(?SESSION_TABLE, {SessionId, Updated}),
                ok;
            [] ->
                {error, not_found}
        end,
    {reply, Reply, State};
handle_call({delete, SessionId}, _From, State) ->
    case ets:lookup(?SESSION_TABLE, SessionId) of
        [{_, #mcp_session{sse_pid = Pid}}] when is_pid(Pid) ->
            Pid ! session_terminated;
        _ ->
            ok
    end,
    true = ets:delete(?SESSION_TABLE, SessionId),
    {reply, ok, State};
handle_call({set_client_capabilities, SessionId, Caps}, _From, State) ->
    Reply =
        case ets:lookup(?SESSION_TABLE, SessionId) of
            [{_, Session}] ->
                Updated = Session#mcp_session{client_capabilities = Caps},
                true = ets:insert(?SESSION_TABLE, {SessionId, Updated}),
                ok;
            [] ->
                {error, not_found}
        end,
    {reply, Reply, State};
handle_call({set_log_level, SessionId, Level}, _From, State) ->
    Reply =
        case ets:lookup(?SESSION_TABLE, SessionId) of
            [{_, Session}] ->
                Updated = Session#mcp_session{log_level = Level},
                true = ets:insert(?SESSION_TABLE, {SessionId, Updated}),
                ok;
            [] ->
                {error, not_found}
        end,
    {reply, Reply, State};
handle_call({set_protocol_version, SessionId, Version}, _From, State) ->
    Reply =
        case ets:lookup(?SESSION_TABLE, SessionId) of
            [{_, Session}] ->
                Updated = Session#mcp_session{protocol_version = Version},
                true = ets:insert(?SESSION_TABLE, {SessionId, Updated}),
                ok;
            [] ->
                {error, not_found}
        end,
    {reply, Reply, State};
handle_call({set_sse_pid, SessionId, Pid}, _From, State) ->
    Reply =
        case ets:lookup(?SESSION_TABLE, SessionId) of
            [{_, Session}] ->
                Updated = Session#mcp_session{sse_pid = Pid},
                true = ets:insert(?SESSION_TABLE, {SessionId, Updated}),
                ok;
            [] ->
                {error, not_found}
        end,
    {reply, Reply, State};
handle_call({subscribe_resource, SessionId, Uri}, _From, State) ->
    true = ets:insert(?SUBSCRIPTIONS_TABLE, {{SessionId, Uri}}),
    {reply, ok, State};
handle_call({unsubscribe_resource, SessionId, Uri}, _From, State) ->
    true = ets:delete(?SUBSCRIPTIONS_TABLE, {SessionId, Uri}),
    {reply, ok, State};
handle_call({register_pending, RequestId, Pending}, _From, State) ->
    true = ets:insert(?PENDING_TABLE, {RequestId, Pending}),
    {reply, ok, State};
handle_call({discard_pending, RequestId}, _From, State) ->
    true = ets:delete(?PENDING_TABLE, RequestId),
    {reply, ok, State};
handle_call({deliver_response, Key, Response}, _From, State) ->
    Reply =
        case ets:lookup(?PENDING_TABLE, Key) of
            [{_, #pending{caller = Caller, caller_ref = Ref, tag = Tag}}] ->
                true = ets:delete(?PENDING_TABLE, Key),
                Caller ! {Tag, Ref, Response},
                ok;
            [] ->
                {error, unknown_id}
        end,
    {reply, Reply, State};
handle_call(
    {record_in_flight, SessionId, RequestId, Worker, Waiter},
    _From,
    State
) ->
    InFlight = #in_flight{
        session_id = SessionId,
        request_id = RequestId,
        worker_pid = Worker,
        waiter_pid = Waiter
    },
    true = ets:insert(?INFLIGHT_TABLE, {{SessionId, RequestId}, InFlight}),
    {reply, ok, State};
handle_call({cancel_in_flight, SessionId, RequestId}, _From, State) ->
    case ets:lookup(?INFLIGHT_TABLE, {SessionId, RequestId}) of
        [{_, #in_flight{worker_pid = W, waiter_pid = Wt}}] ->
            _ =
                (try
                    W ! {cancel, RequestId}
                catch
                    _:_ -> ok
                end),
            _ =
                (try
                    Wt ! {cancelled, RequestId}
                catch
                    _:_ -> ok
                end),
            true = ets:delete(?INFLIGHT_TABLE, {SessionId, RequestId});
        [] ->
            ok
    end,
    {reply, ok, State};
handle_call({clear_in_flight, SessionId, RequestId}, _From, State) ->
    true = ets:delete(?INFLIGHT_TABLE, {SessionId, RequestId}),
    {reply, ok, State};
handle_call({record_sse_event, SessionId, EventId, Payload}, _From, State) ->
    case ets:lookup(?SESSION_TABLE, SessionId) of
        [{_, #mcp_session{sse_buffer = Buf, sse_buffer_max = Max} = S}] ->
            NewBuf = trim([{EventId, Payload} | Buf], Max),
            true = ets:insert(
                ?SESSION_TABLE,
                {SessionId, S#mcp_session{sse_buffer = NewBuf}}
            ),
            ok;
        [] ->
            ok
    end,
    {reply, ok, State};
handle_call({set_sse_buffer_max, SessionId, Max}, _From, State) ->
    Reply =
        case ets:lookup(?SESSION_TABLE, SessionId) of
            [{_, S}] ->
                true = ets:insert(
                    ?SESSION_TABLE,
                    {SessionId, S#mcp_session{sse_buffer_max = Max}}
                ),
                ok;
            [] ->
                {error, not_found}
        end,
    {reply, Reply, State};
handle_call({cleanup_expired, TTL}, _From, State) ->
    Now = erlang:system_time(millisecond),
    Cutoff = Now - TTL,
    Expired = ets:foldl(
        fun
            ({Id, #mcp_session{last_activity = LA}}, Acc) when
                LA < Cutoff
            ->
                [Id | Acc];
            (_, Acc) ->
                Acc
        end,
        [],
        ?SESSION_TABLE
    ),
    lists:foreach(fun delete_inline/1, Expired),
    {reply, length(Expired), State};
handle_call(_Request, _From, State) ->
    {reply, {error, unknown_request}, State}.

handle_cast(_Msg, State) ->
    {noreply, State}.

handle_info(cleanup, State) ->
    %% Inline the cleanup. We can't call the public `cleanup_expired/1'
    %% because it goes through `gen_server:call(?MODULE, …)' — a
    %% self-call that would deadlock.
    TTL = application:get_env(barrel_mcp, session_ttl, 1800000),
    Now = erlang:system_time(millisecond),
    Cutoff = Now - TTL,
    Expired = ets:foldl(
        fun
            ({Id, #mcp_session{last_activity = LA}}, Acc) when
                LA < Cutoff
            ->
                [Id | Acc];
            (_, Acc) ->
                Acc
        end,
        [],
        ?SESSION_TABLE
    ),
    lists:foreach(fun delete_inline/1, Expired),
    case Expired of
        [] ->
            ok;
        _ ->
            logger:debug(
                "Cleaned up ~p expired MCP sessions",
                [length(Expired)]
            )
    end,
    erlang:send_after(?CLEANUP_INTERVAL, self(), cleanup),
    {noreply, State};
handle_info(_Info, State) ->
    {noreply, State}.

terminate(_Reason, _State) ->
    ok.

%%====================================================================
%% Internal functions
%%====================================================================

session_to_map(#mcp_session{
    id = Id,
    created_at = CreatedAt,
    last_activity = LastActivity,
    client_info = ClientInfo,
    client_capabilities = Caps,
    protocol_version = ProtocolVersion,
    sse_pid = SsePid
}) ->
    #{
        id => Id,
        created_at => CreatedAt,
        last_activity => LastActivity,
        client_info => ClientInfo,
        client_capabilities => Caps,
        protocol_version => ProtocolVersion,
        sse_pid => SsePid
    }.

%% ============================================================================
%% Internal helpers (table init + sampling implementation)
%% ============================================================================

ensure_session_table() ->
    case ets:whereis(?SESSION_TABLE) of
        undefined ->
            ets:new(?SESSION_TABLE, [
                named_table,
                protected,
                set,
                {read_concurrency, true},
                {write_concurrency, true}
            ]);
        _ ->
            ok
    end.

ensure_subs_table() ->
    case ets:whereis(?SUBSCRIPTIONS_TABLE) of
        undefined ->
            ets:new(?SUBSCRIPTIONS_TABLE, [
                named_table,
                protected,
                set,
                {read_concurrency, true}
            ]);
        _ ->
            ok
    end.

ensure_pending_table() ->
    case ets:whereis(?PENDING_TABLE) of
        undefined ->
            ets:new(?PENDING_TABLE, [
                named_table,
                protected,
                set,
                {read_concurrency, true}
            ]);
        _ ->
            ok
    end.

ensure_inflight_table() ->
    case ets:whereis(?INFLIGHT_TABLE) of
        undefined ->
            ets:new(?INFLIGHT_TABLE, [
                named_table,
                protected,
                set,
                {read_concurrency, true}
            ]);
        _ ->
            ok
    end.

do_sampling(SessionId, SsePid, Params, Opts) ->
    Timeout = maps:get(timeout_ms, Opts, ?DEFAULT_SAMPLING_TIMEOUT),
    RequestId = generate_request_id(<<"sampling-">>),
    Ref = make_ref(),
    ok = gen_server:call(
        ?MODULE,
        {register_pending, RequestId, #pending{
            id = RequestId,
            session_id = SessionId,
            caller = self(),
            caller_ref = Ref,
            expires_at = erlang:system_time(millisecond) + Timeout,
            tag = sampling_response
        }}
    ),
    Request = #{
        <<"jsonrpc">> => <<"2.0">>,
        <<"id">> => RequestId,
        <<"method">> => <<"sampling/createMessage">>,
        <<"params">> => Params
    },
    SsePid ! {sse_send_message, Request},
    receive
        {sampling_response, Ref, #{<<"result">> := Result} = R} ->
            Usage = maps:get(<<"usage">>, Result, maps:get(usage, R, #{})),
            {ok, Result, Usage};
        {sampling_response, Ref, #{<<"error">> := Err}} ->
            {error, {client_error, Err}}
    after Timeout ->
        _ = gen_server:call(?MODULE, {discard_pending, RequestId}),
        {error, timeout}
    end.

do_elicit(SessionId, SsePid, Params, Opts) ->
    Timeout = maps:get(timeout_ms, Opts, ?DEFAULT_SAMPLING_TIMEOUT),
    RequestId = generate_request_id(<<"elicit-">>),
    Ref = make_ref(),
    ok = gen_server:call(
        ?MODULE,
        {register_pending, RequestId, #pending{
            id = RequestId,
            session_id = SessionId,
            caller = self(),
            caller_ref = Ref,
            expires_at = erlang:system_time(millisecond) + Timeout,
            tag = elicitation_response
        }}
    ),
    Request = #{
        <<"jsonrpc">> => <<"2.0">>,
        <<"id">> => RequestId,
        <<"method">> => <<"elicitation/create">>,
        <<"params">> => Params
    },
    SsePid ! {sse_send_message, Request},
    receive
        {elicitation_response, Ref, #{<<"result">> := Result}} ->
            {ok, Result};
        {elicitation_response, Ref, #{<<"error">> := Err}} ->
            {error, {client_error, Err}}
    after Timeout ->
        _ = gen_server:call(?MODULE, {discard_pending, RequestId}),
        {error, timeout}
    end.

do_roots_list(SessionId, SsePid, Opts) ->
    Timeout = maps:get(timeout_ms, Opts, ?DEFAULT_SAMPLING_TIMEOUT),
    RequestId = generate_request_id(<<"roots-">>),
    Ref = make_ref(),
    ok = gen_server:call(
        ?MODULE,
        {register_pending, RequestId, #pending{
            id = RequestId,
            session_id = SessionId,
            caller = self(),
            caller_ref = Ref,
            expires_at = erlang:system_time(millisecond) + Timeout,
            tag = roots_response
        }}
    ),
    Request = #{
        <<"jsonrpc">> => <<"2.0">>,
        <<"id">> => RequestId,
        <<"method">> => <<"roots/list">>,
        <<"params">> => #{}
    },
    SsePid ! {sse_send_message, Request},
    receive
        {roots_response, Ref, #{<<"result">> := Result}} ->
            Roots = maps:get(<<"roots">>, Result, []),
            {ok, Roots};
        {roots_response, Ref, #{<<"error">> := Err}} ->
            {error, {client_error, Err}}
    after Timeout ->
        _ = gen_server:call(?MODULE, {discard_pending, RequestId}),
        {error, timeout}
    end.

generate_request_id(Prefix) ->
    <<Prefix/binary, (integer_to_binary(erlang:unique_integer([positive])))/binary>>.

id_to_binary(Id) when is_binary(Id) -> Id;
id_to_binary(Id) when is_integer(Id) -> integer_to_binary(Id);
id_to_binary(Id) -> iolist_to_binary(io_lib:format("~p", [Id])).