%%%-------------------------------------------------------------------
%%% @doc MCP client for connecting to external MCP servers.
%%%
%%% A `gen_statem' that owns one connection to one MCP server. Two
%%% transports are supported: stdio (subprocess) and Streamable HTTP
%%% (POST + SSE GET).
%%%
%%% States:
%%% <ul>
%%% <li>`connecting' — transport is opening.</li>
%%% <li>`initializing' — `initialize' request in flight.</li>
%%% <li>`ready' — handshake complete; calls accepted.</li>
%%% <li>`closing' — owner asked to close.</li>
%%% </ul>
%%%
%%% Inbound JSON-RPC envelopes from the transport are routed by
%%% `decode_envelope/1':
%%% <ul>
%%% <li>response/error with `id' — match against the pending-request
%%% table, post the result to the waiting caller.</li>
%%% <li>request with `id' — dispatch to the configured
%%% `barrel_mcp_client_handler' module; reply (sync or async)
%%% goes back over the same transport.</li>
%%% <li>notification (no `id') — dispatch to handler; resource
%%% update notifications are also routed to subscribers.</li>
%%% </ul>
%%%
%%% Server-side host application code never sees the transport
%%% layer; it talks to this module via the API below. Whether to bind
%%% an LLM provider (Anthropic, OpenAI, Hermes-style local model) into
%%% this loop is the host's job — `barrel_mcp' itself stays a pure
%%% MCP library.
%%% @end
%%%-------------------------------------------------------------------
-module(barrel_mcp_client).
-behaviour(gen_statem).
-include("barrel_mcp.hrl").
%% Public API
-export([
start_link/1,
start/1,
close/1,
%% Tools
list_tools/1, list_tools/2,
list_tools_all/1,
call_tool/3, call_tool/4,
%% Resources
list_resources/1, list_resources/2,
list_resources_all/1,
list_resource_templates/1, list_resource_templates/2,
list_resource_templates_all/1,
read_resource/2,
subscribe/2,
unsubscribe/2,
%% Prompts
list_prompts/1, list_prompts/2,
list_prompts_all/1,
get_prompt/3,
%% Tasks (long-running operations, MCP 2025-11-25)
tasks_list/1, tasks_list/2,
tasks_list_all/1,
tasks_get/2,
tasks_cancel/2,
tasks_result/2,
%% Misc
complete/3,
set_log_level/2,
ping/1,
cancel/2,
notify_roots_list_changed/1,
reply_async/3,
%% Introspection
server_info/1,
server_capabilities/1,
protocol_version/1
]).
%% gen_statem callbacks
-export([callback_mode/0, init/1, terminate/3, code_change/4]).
-export([connecting/3, initializing/3, ready/3, closing/3]).
%% Options accepted by start_link/1 and start/1. Only `transport' is
%% mandatory; every other key falls back to a default.
-type connect_spec() ::
    #{
        transport :=
            {http, binary() | string()}
            | {stdio, #{command := string(), args => [string()]}},
        client_info => #{name => binary(), version => binary()},
        capabilities => map(),
        handler => {module(), term()},
        auth =>
            none
            | {bearer, binary()}
            | {oauth, map()}
            | {oauth_client_credentials, map()}
            | {oauth_enterprise, map()},
        %% Forwarded to the HTTP transport as its `headers' option
        %% (default `[]'). open_transport/1 reads this key, so it has
        %% to be part of the type.
        %% NOTE(review): element shape is decided by the HTTP transport
        %% module — tighten once confirmed.
        http_headers => list(),
        protocol_version => binary(),
        %% NOTE(review): nothing in this module reads `request_timeout';
        %% per-request timeouts come from the `timeout' call option.
        request_timeout => pos_integer(),
        init_timeout => pos_integer(),
        ping_interval => pos_integer() | infinity,
        ping_failure_threshold => pos_integer()
    }.
-export_type([connect_spec/0]).
%% Per-request timeout (ms) used when the caller gives no `timeout' option.
-define(DEFAULT_REQUEST_TIMEOUT, 30000).
%% Time (ms) allowed for the `initialize' handshake when the spec has
%% no `init_timeout'.
-define(DEFAULT_INIT_TIMEOUT, 30000).
%% Time (ms) allowed for a keep-alive `ping' issued by the client itself.
-define(DEFAULT_PING_TIMEOUT, 5000).
%% Number of failed keep-alive pings (the counter resets on a successful
%% one) after which the connection is closed, unless the spec sets
%% `ping_failure_threshold'.
-define(DEFAULT_PING_FAILURE_THRESHOLD, 3).
%% One request that has been sent to the server and not yet settled.
%% Stored in `#data.pending' keyed by the JSON-RPC request id.
-record(pending, {
%% Who is waiting: `init' for the handshake, `ping' for a keep-alive
%% issued by the client itself, or the gen_statem `From' of an API caller.
caller :: init | ping | {pid(), term()},
%% JSON-RPC method name of the request.
method :: binary(),
%% Expiry on the monotonic clock in milliseconds (see deadline/1), or
%% `infinity'.
deadline :: integer() | infinity,
%% Progress token carried by the request, when the caller asked for one.
progress_token :: binary() | undefined
}).
%% State carried by the gen_statem across all states.
-record(data, {
%% The connect spec given to start_link/1 or start/1.
spec :: connect_spec(),
%% `{TransportModule, TransportPid}' once open_transport/1 succeeded.
transport :: {module(), pid()} | undefined,
%% Id to use for the next outgoing request (see next_id/1).
request_id = 1 :: integer(),
%% Outgoing requests awaiting a response, keyed by request id.
pending = #{} :: #{integer() => #pending{}},
%% Handler callback module and its state.
handler_mod :: module(),
handler_state :: term(),
%% Server-initiated requests the handler answered with `{async, Tag, _}',
%% mapping the tag to the request id to reply to.
async_replies = #{} :: #{barrel_mcp_client_handler:async_tag() => integer()},
%% Resource URI -> processes that receive `{mcp_resource_updated, ...}'.
subscriptions = #{} :: #{binary() => [pid()]},
%% Progress token -> process that receives `{mcp_progress, ...}'.
progress = #{} :: #{binary() => pid()},
%% Failed keep-alive pings since the last successful one.
ping_failures = 0 :: non_neg_integer(),
%% Values taken from the server's `initialize' result.
server_capabilities :: map() | undefined,
server_info :: map() | undefined,
protocol_version :: binary() | undefined
}).
%%====================================================================
%% Public API
%%====================================================================
%% @doc Start a client linked to the calling process, which makes it
%% suitable for running under a supervisor. `Spec' is passed unchanged
%% to `init/1'.
-spec start_link(connect_spec()) -> {ok, pid()} | {error, term()}.
start_link(Spec) ->
    StartOpts = [],
    gen_statem:start_link(?MODULE, Spec, StartOpts).
%% @doc Start a client that is not linked to the calling process.
%% `Spec' is passed unchanged to `init/1'.
-spec start(connect_spec()) -> {ok, pid()} | {error, term()}.
start(Spec) ->
    StartOpts = [],
    gen_statem:start(?MODULE, Spec, StartOpts).
%% @doc Ask the client to close its transport and stop. The request is
%% a cast, so this returns `ok' immediately.
-spec close(pid()) -> ok.
close(Pid) ->
    CloseRequest = close,
    gen_statem:cast(Pid, CloseRequest).
%% @doc List tools advertised by the server, fetching a single page.
%%
%% Returns `{ok, Tools}' when the server reports no further page. When
%% the response carries a `nextCursor' the result is
%% `{ok, Tools, NextCursor}', so callers must be ready for both shapes.
%% Use {@link list_tools/2} with `#{want_cursor => true}' or
%% {@link list_tools_all/1} to walk pagination.
-spec list_tools(pid()) ->
    {ok, [map()]}
    | {ok, [map()], NextCursor :: binary()}
    | {error, term()}.
list_tools(Pid) ->
    list_tools(Pid, #{}).

%% @doc List tools with pagination control.
%%
%% `Opts' is a map that may contain:
%% <ul>
%% <li>`cursor => Cursor' — start from a previously-returned
%% `nextCursor'.</li>
%% <li>`want_cursor => true' — always return
%% `{ok, Items, NextCursor}', with `undefined' for `NextCursor' on the
%% last page.</li>
%% <li>`timeout => Ms' — override the per-request timeout.</li>
%% </ul>
%%
%% Without `want_cursor' the three-element form is still returned
%% whenever the server supplies a `nextCursor'.
-spec list_tools(pid(), map()) ->
    {ok, [map()], NextCursor :: binary() | undefined}
    | {ok, [map()]}
    | {error, term()}.
list_tools(Pid, Opts) ->
    paged(Pid, <<"tools/list">>, <<"tools">>, Opts).
%% @doc Fetch every `tools/list' page and return the combined list.
%% Each page is requested through {@link list_tools/2} with
%% `want_cursor' set.
-spec list_tools_all(pid()) -> {ok, [map()]} | {error, term()}.
list_tools_all(Pid) ->
    FetchPage = fun(Cursor) -> list_tools(Pid, page_opts(Cursor)) end,
    walk_all(FetchPage).
%% @doc Invoke the tool `Name' with default options. `Args' is sent
%% as-is in the JSON-RPC `arguments' field. On success the server's
%% `result' map is returned (it carries a `<<"content">>' list of
%% content blocks).
-spec call_tool(pid(), binary(), map()) -> {ok, map()} | {error, term()}.
call_tool(Pid, Name, Args) ->
    NoOpts = #{},
    call_tool(Pid, Name, Args, NoOpts).

%% @doc Invoke a tool with options.
%%
%% `Opts' is a map that may contain:
%% <ul>
%% <li>`progress_token => Token' — the token is sent in the request's
%% `_meta' and the calling process receives
%% `{mcp_progress, Token, Params}' messages until the request
%% settles.</li>
%% <li>`timeout => Ms' — per-request timeout; defaults to 30000 ms.</li>
%% </ul>
-spec call_tool(pid(), binary(), map(), map()) -> {ok, map()} | {error, term()}.
call_tool(Pid, Name, Args, Opts) ->
    BaseParams = #{<<"name">> => Name, <<"arguments">> => Args},
    WithToken = maybe_attach_progress_token(BaseParams, Opts),
    Timeout = request_timeout(Opts),
    request(Pid, <<"tools/call">>, WithToken, Timeout).
%% @doc List resources advertised by the server, fetching a single
%% page.
%%
%% Returns `{ok, Resources}' when the server reports no further page,
%% and `{ok, Resources, NextCursor}' when the response carries a
%% `nextCursor'.
-spec list_resources(pid()) ->
    {ok, [map()]}
    | {ok, [map()], NextCursor :: binary()}
    | {error, term()}.
list_resources(Pid) ->
    list_resources(Pid, #{}).

%% @doc List resources with pagination control. `Opts' is a map with
%% the optional keys `cursor', `want_cursor' and `timeout', exactly as
%% for {@link list_tools/2}.
-spec list_resources(pid(), map()) ->
    {ok, [map()], binary() | undefined}
    | {ok, [map()]}
    | {error, term()}.
list_resources(Pid, Opts) ->
    paged(Pid, <<"resources/list">>, <<"resources">>, Opts).
%% @doc Fetch every `resources/list' page and return the combined
%% list.
-spec list_resources_all(pid()) -> {ok, [map()]} | {error, term()}.
list_resources_all(Pid) ->
    FetchPage = fun(Cursor) -> list_resources(Pid, page_opts(Cursor)) end,
    walk_all(FetchPage).
%% @doc List resource templates advertised by the server, fetching a
%% single page.
%%
%% Returns `{ok, Templates}' when the server reports no further page,
%% and `{ok, Templates, NextCursor}' when the response carries a
%% `nextCursor'.
-spec list_resource_templates(pid()) ->
    {ok, [map()]}
    | {ok, [map()], NextCursor :: binary()}
    | {error, term()}.
list_resource_templates(Pid) ->
    list_resource_templates(Pid, #{}).

%% @doc List resource templates with pagination control. `Opts' is a
%% map with the optional keys `cursor', `want_cursor' and `timeout',
%% exactly as for {@link list_tools/2}.
-spec list_resource_templates(pid(), map()) ->
    {ok, [map()], binary() | undefined} | {ok, [map()]} | {error, term()}.
list_resource_templates(Pid, Opts) ->
    paged(Pid, <<"resources/templates/list">>, <<"resourceTemplates">>, Opts).
%% @doc Fetch every `resources/templates/list' page and return the
%% combined list.
-spec list_resource_templates_all(pid()) -> {ok, [map()]} | {error, term()}.
list_resource_templates_all(Pid) ->
    FetchPage = fun(Cursor) ->
        list_resource_templates(Pid, page_opts(Cursor))
    end,
    walk_all(FetchPage).
%% @doc Send `resources/read' for `Uri' and return the server's result
%% map.
-spec read_resource(pid(), binary()) -> {ok, map()} | {error, term()}.
read_resource(Pid, Uri) ->
    Params = #{<<"uri">> => Uri},
    request(Pid, <<"resources/read">>, Params).
%% @doc Subscribe the calling process to updates for `Uri'.
%%
%% Sends `resources/subscribe'; only when the server accepts is the
%% caller registered with the client. From then on the caller receives
%% `{mcp_resource_updated, Uri, Params}' for each inbound
%% `notifications/resources/updated' naming that URI, until it calls
%% {@link unsubscribe/2} or the client closes.
-spec subscribe(pid(), binary()) -> {ok, map()} | {error, term()}.
subscribe(Pid, Uri) ->
    Reply = request(Pid, <<"resources/subscribe">>, #{<<"uri">> => Uri}),
    case Reply of
        {ok, _} ->
            ok = gen_statem:cast(Pid, {add_subscriber, Uri, self()}),
            Reply;
        _ ->
            Reply
    end.
%% @doc Stop delivering updates for `Uri' to the calling process.
%%
%% Sends `resources/unsubscribe'; the caller is removed from the
%% client's subscriber list only when the server accepts the request.
-spec unsubscribe(pid(), binary()) -> {ok, map()} | {error, term()}.
unsubscribe(Pid, Uri) ->
    Reply = request(Pid, <<"resources/unsubscribe">>, #{<<"uri">> => Uri}),
    case Reply of
        {ok, _} ->
            ok = gen_statem:cast(Pid, {remove_subscriber, Uri, self()}),
            Reply;
        _ ->
            Reply
    end.
%% @doc List prompts advertised by the server, fetching a single page.
%%
%% Returns `{ok, Prompts}' when the server reports no further page,
%% and `{ok, Prompts, NextCursor}' when the response carries a
%% `nextCursor'.
-spec list_prompts(pid()) ->
    {ok, [map()]}
    | {ok, [map()], NextCursor :: binary()}
    | {error, term()}.
list_prompts(Pid) ->
    list_prompts(Pid, #{}).

%% @doc List prompts with pagination control. `Opts' is a map with the
%% optional keys `cursor', `want_cursor' and `timeout', exactly as for
%% {@link list_tools/2}.
-spec list_prompts(pid(), map()) ->
    {ok, [map()], binary() | undefined}
    | {ok, [map()]}
    | {error, term()}.
list_prompts(Pid, Opts) ->
    paged(Pid, <<"prompts/list">>, <<"prompts">>, Opts).
%% @doc Fetch every `prompts/list' page and return the combined list.
-spec list_prompts_all(pid()) -> {ok, [map()]} | {error, term()}.
list_prompts_all(Pid) ->
    FetchPage = fun(Cursor) -> list_prompts(Pid, page_opts(Cursor)) end,
    walk_all(FetchPage).
%% @doc Send `prompts/get' for the prompt `Name', passing `Args' as
%% its `arguments', and return the server's result map.
-spec get_prompt(pid(), binary(), map()) -> {ok, map()} | {error, term()}.
get_prompt(Pid, Name, Args) ->
    Params = #{<<"name">> => Name, <<"arguments">> => Args},
    request(Pid, <<"prompts/get">>, Params).
%% @doc Ask the server for completion suggestions via
%% `completion/complete'.
%%
%% `Ref' is the JSON-RPC `ref' map, e.g.
%% `#{<<"type">> => <<"ref/prompt">>, <<"name">> => N}'. `Argument' is
%% `#{<<"name">> => Key, <<"value">> => Partial}'. Both are sent
%% unchanged.
-spec complete(pid(), map(), map()) -> {ok, map()} | {error, term()}.
complete(Pid, Ref, Argument) ->
    Params = #{<<"ref">> => Ref, <<"argument">> => Argument},
    request(Pid, <<"completion/complete">>, Params).
%% @doc Send `logging/setLevel' with the given level. `Level' must be
%% a binary naming one of `debug', `info', `notice', `warning',
%% `error', `critical', `alert' or `emergency'; only the binary type is
%% checked here.
-spec set_log_level(pid(), binary()) -> {ok, map()} | {error, term()}.
set_log_level(Pid, Level) when is_binary(Level) ->
    Params = #{<<"level">> => Level},
    request(Pid, <<"logging/setLevel">>, Params).
%% @doc List long-running tasks owned by the connected session,
%% fetching a single page.
%%
%% Returns `{ok, Tasks}' when the server reports no further page, and
%% `{ok, Tasks, NextCursor}' when the response carries a `nextCursor'.
%% Use {@link tasks_list/2} with `#{want_cursor => true}' or
%% {@link tasks_list_all/1} to walk pagination.
-spec tasks_list(pid()) ->
    {ok, [map()]}
    | {ok, [map()], NextCursor :: binary()}
    | {error, term()}.
tasks_list(Pid) ->
    tasks_list(Pid, #{}).

%% @doc List tasks with pagination control. `Opts' is a map with the
%% optional keys `cursor', `want_cursor' and `timeout', exactly as for
%% {@link list_tools/2}.
-spec tasks_list(pid(), map()) ->
    {ok, [map()], binary() | undefined}
    | {ok, [map()]}
    | {error, term()}.
tasks_list(Pid, Opts) ->
    paged(Pid, <<"tasks/list">>, <<"tasks">>, Opts).
%% @doc Fetch every `tasks/list' page and return the combined list.
-spec tasks_list_all(pid()) -> {ok, [map()]} | {error, term()}.
tasks_list_all(Pid) ->
    FetchPage = fun(Cursor) -> tasks_list(Pid, page_opts(Cursor)) end,
    walk_all(FetchPage).
%% @doc Send `tasks/get' for `TaskId' and return the server's result
%% map.
-spec tasks_get(pid(), binary()) -> {ok, map()} | {error, term()}.
tasks_get(Pid, TaskId) ->
    Params = #{<<"taskId">> => TaskId},
    request(Pid, <<"tasks/get">>, Params).
%% @doc Send `tasks/cancel' for `TaskId'. `{ok, _}' means the server
%% accepted the cancellation; the task then moves to `cancelled'
%% status, which the server broadcasts via
%% `notifications/tasks/status'.
-spec tasks_cancel(pid(), binary()) -> {ok, map()} | {error, term()}.
tasks_cancel(Pid, TaskId) ->
    Params = #{<<"taskId">> => TaskId},
    request(Pid, <<"tasks/cancel">>, Params).
%% @doc Send `tasks/result' for `TaskId'.
%%
%% For a completed task this returns the stored `result' map. A
%% `failed' task yields `{error, {Code, Message}}', and a task that is
%% still `working' yields
%% `{error, {_, <<"Task not yet complete">>}}'.
-spec tasks_result(pid(), binary()) -> {ok, map()} | {error, term()}.
tasks_result(Pid, TaskId) ->
    Params = #{<<"taskId">> => TaskId},
    request(Pid, <<"tasks/result">>, Params).
%% @doc Send a `ping' request with empty params and block until the
%% server answers or the default request timeout expires.
-spec ping(pid()) -> {ok, map()} | {error, term()}.
ping(Pid) ->
    NoParams = #{},
    request(Pid, <<"ping">>, NoParams).
%% @doc Cancel the in-flight request with id `RequestId'.
%%
%% The client sends `notifications/cancelled' to the server and
%% answers the blocked caller with `{error, cancelled}'. An id that is
%% not pending is ignored. The request is a cast, so `ok' is returned
%% immediately.
-spec cancel(pid(), integer()) -> ok.
cancel(Pid, RequestId) ->
    CancelRequest = {cancel, RequestId},
    gen_statem:cast(Pid, CancelRequest).
%% @doc Tell the server that the host's roots list changed by sending
%% `notifications/roots/list_changed'.
%%
%% Hosts that change their roots after `initialize' (for example when
%% the user opens a new workspace) call this so the server can follow
%% up with `roots/list' instead of polling.
-spec notify_roots_list_changed(pid()) -> ok.
notify_roots_list_changed(Pid) ->
    Event = notify_roots_list_changed,
    gen_statem:cast(Pid, Event).
%% @doc Complete a server-initiated request that the handler deferred
%% by returning `{async, Tag, _}'.
%%
%% `Result' is either `{error, Code, Message}', sent as a JSON-RPC
%% error, or any other term, sent as the JSON-RPC `result'. A `Tag'
%% that is not outstanding is ignored.
-spec reply_async(
    pid(),
    term(),
    term() | {error, integer(), binary()}
) -> ok.
reply_async(Pid, Tag, Result) ->
    Deferred = {async_reply, Tag, Result},
    gen_statem:cast(Pid, Deferred).
%% @doc Return the `serverInfo' map the server reported during
%% `initialize' (with keys like `<<"name">>' and `<<"version">>').
%%
%% Before the handshake has finished the call is answered with
%% `{error, not_ready}'; in the `closing' state with
%% `{error, closing}'.
-spec server_info(pid()) -> {ok, map() | undefined} | {error, term()}.
server_info(Pid) ->
    gen_statem:call(Pid, server_info).
%% @doc Return the server capabilities map negotiated during
%% `initialize'. Useful to gate work on optional features.
%%
%% Before the handshake has finished the call is answered with
%% `{error, not_ready}'; in the `closing' state with
%% `{error, closing}'.
-spec server_capabilities(pid()) -> {ok, map() | undefined} | {error, term()}.
server_capabilities(Pid) ->
    gen_statem:call(Pid, server_capabilities).
%% @doc Return the negotiated protocol version (e.g.
%% `<<"2025-11-25">>' or `<<"2025-03-26">>' if the server downgraded).
%%
%% Before the handshake has finished the call is answered with
%% `{error, not_ready}'; in the `closing' state with
%% `{error, closing}'.
-spec protocol_version(pid()) -> {ok, binary() | undefined} | {error, term()}.
protocol_version(Pid) ->
    gen_statem:call(Pid, protocol_version).
%%====================================================================
%% gen_statem
%%====================================================================
%% One callback function per state: connecting/3, initializing/3,
%% ready/3 and closing/3.
callback_mode() ->
    state_functions.
%% Initialise the handler module and start in `connecting'. The
%% transport is opened from an internal event so init/1 itself returns
%% quickly. Exits are trapped because the transport process is linked.
init(Spec) ->
    process_flag(trap_exit, true),
    DefaultHandler = {barrel_mcp_client_handler_default, []},
    {HandlerMod, HandlerArgs} = maps:get(handler, Spec, DefaultHandler),
    case HandlerMod:init(HandlerArgs) of
        {error, _} = Failure ->
            Failure;
        {ok, HandlerState} ->
            Initial = #data{
                spec = Spec,
                handler_mod = HandlerMod,
                handler_state = HandlerState
            },
            OpenFirst = [{next_event, internal, open_transport}],
            {ok, connecting, Initial, OpenFirst}
    end.
%%-- connecting -------------------------------------------------------
%% `connecting' state. The internal `open_transport' event opens the
%% transport, sends the `initialize' request right away and moves to
%% `initializing' under a state timeout. Calls are refused with
%% `{error, not_ready}'; everything else goes to common_handler/3.
connecting(internal, open_transport, Data) ->
    case open_transport(Data) of
        {error, Reason} ->
            {stop, {transport_failed, Reason}};
        {ok, Opened} ->
            InitTimeout = maps:get(
                init_timeout, Data#data.spec, ?DEFAULT_INIT_TIMEOUT
            ),
            {Id, Numbered} = next_id(Opened),
            InitParams = build_initialize_params(Numbered),
            Handshake = barrel_mcp_protocol:encode_request(
                Id, <<"initialize">>, InitParams
            ),
            send_envelope(Numbered, Handshake),
            Entry = #pending{
                caller = init,
                method = <<"initialize">>,
                deadline = deadline(InitTimeout)
            },
            Awaiting = Numbered#data{
                pending = (Numbered#data.pending)#{Id => Entry}
            },
            Guard = {state_timeout, InitTimeout, init_timeout},
            {next_state, initializing, Awaiting, [Guard]}
    end;
connecting({call, Caller}, _Request, _Data) ->
    {keep_state_and_data, [{reply, Caller, {error, not_ready}}]};
connecting(Type, Content, Data) ->
    common_handler(Type, Content, Data).
%%-- initializing -----------------------------------------------------
%% `initializing' state: the `initialize' request is in flight. The
%% state timeout stops the client with `init_timeout'; calls are
%% refused with `{error, not_ready}'; messages from the current
%% transport pid are decoded and routed; everything else goes to
%% common_handler/3.
initializing(state_timeout, init_timeout, _Data) ->
    {stop, init_timeout};
initializing({call, Caller}, _Request, _Data) ->
    Refusal = {reply, Caller, {error, not_ready}},
    {keep_state_and_data, [Refusal]};
initializing(info, {mcp_in, TPid, Json}, #data{transport = {_, TPid}} = Data) ->
    handle_inbound(Json, initializing, Data);
initializing(Type, Content, Data) ->
    common_handler(Type, Content, Data).
%%-- ready ------------------------------------------------------------
%% `ready' state: the handshake is complete and API calls are served.
%%
%% Introspection calls are answered from the stored handshake result.
%% A `{request, ...}' call is checked against the server capabilities,
%% sent to the server and recorded as pending; the caller is answered
%% later, when the response, a cancellation or the per-request timer
%% settles it. Events not matched here go to common_handler/3.
ready({call, From}, server_info, Data) ->
    Reply = {ok, Data#data.server_info},
    {keep_state_and_data, [{reply, From, Reply}]};
ready({call, From}, server_capabilities, Data) ->
    Reply = {ok, Data#data.server_capabilities},
    {keep_state_and_data, [{reply, From, Reply}]};
ready({call, From}, protocol_version, Data) ->
    Reply = {ok, Data#data.protocol_version},
    {keep_state_and_data, [{reply, From, Reply}]};
ready({call, From}, {request, Method, Params, Timeout}, Data) ->
    case is_supported(Method, Data) of
        true ->
            {Id, Numbered} = next_id(Data),
            Request = barrel_mcp_protocol:encode_request(Id, Method, Params),
            send_envelope(Numbered, Request),
            Token = progress_token_from_params(Params),
            {CallerPid, _Tag} = From,
            %% Progress notifications for this token go to the caller.
            Progress =
                case Token of
                    undefined -> Numbered#data.progress;
                    _ -> (Numbered#data.progress)#{Token => CallerPid}
                end,
            Entry = #pending{
                caller = From,
                method = Method,
                deadline = deadline(Timeout),
                progress_token = Token
            },
            Tracked = Numbered#data{
                progress = Progress,
                pending = (Numbered#data.pending)#{Id => Entry}
            },
            %% One generic timer per request id; none for `infinity'.
            Timers =
                case Timeout of
                    infinity -> [];
                    Ms -> [{{timeout, {req, Id}}, Ms, request_timeout}]
                end,
            {keep_state, Tracked, Timers};
        false ->
            Refusal = {error, {unsupported, Method}},
            {keep_state_and_data, [{reply, From, Refusal}]}
    end;
ready(cast, {cancel, Id}, Data) ->
    do_cancel(Id, Data);
ready(cast, notify_roots_list_changed, Data) ->
    Notice = barrel_mcp_protocol:encode_notification(
        <<"notifications/roots/list_changed">>, #{}
    ),
    send_envelope(Data, Notice),
    {keep_state, Data};
ready(cast, {add_subscriber, Uri, Subscriber}, Data) ->
    {keep_state, add_sub(Uri, Subscriber, Data)};
ready(cast, {remove_subscriber, Uri, Subscriber}, Data) ->
    {keep_state, del_sub(Uri, Subscriber, Data)};
ready(cast, {async_reply, Tag, Result}, Data) ->
    {keep_state, deliver_async_reply(Tag, Result, Data)};
ready({timeout, {req, Id}}, request_timeout, Data) ->
    timeout_pending(Id, Data);
ready(state_timeout, ping_tick, Data) ->
    {Pinged, Actions} = issue_ping(Data),
    {keep_state, Pinged, Actions};
ready(info, {mcp_in, TPid, Json}, #data{transport = {_, TPid}} = Data) ->
    handle_inbound(Json, ready, Data);
ready(Type, Content, Data) ->
    common_handler(Type, Content, Data).
%%-- closing ----------------------------------------------------------
%% `closing' state: calls are refused with `{error, closing}' and
%% every other event is dropped.
closing({call, Caller}, _Request, _Data) ->
    Refusal = {reply, Caller, {error, closing}},
    {keep_state_and_data, [Refusal]};
closing(_Type, _Content, _Data) ->
    keep_state_and_data.
%%====================================================================
%% Common event handling (transport messages, casts, etc.)
%%====================================================================
%% Events handled the same way in every state that delegates here.
%%
%% - `mcp_closed' from the current transport stops the client normally.
%% - An `'EXIT'' from the current transport pid also stops the client
%%   normally. The process traps exits and links the transport, so
%%   without this clause a transport that died without announcing
%%   `mcp_closed' would leave a client that accepts requests it can
%%   never deliver.
%% - `'EXIT'' signals from any other process are ignored.
%% - The `close' cast closes the transport (best effort) and stops.
%% - Anything else is dropped.
common_handler(
    info,
    {mcp_closed, Pid, _Reason},
    #data{transport = {_, Pid}}
) ->
    {stop, normal};
common_handler(
    info,
    {'EXIT', Pid, _Reason},
    #data{transport = {_, Pid}}
) ->
    {stop, normal};
common_handler(info, {'EXIT', _, _}, _Data) ->
    keep_state_and_data;
common_handler(cast, close, Data) ->
    case Data#data.transport of
        {Mod, Pid} ->
            %% Best effort: the transport may already be gone.
            try
                Mod:close(Pid)
            catch
                _:_ -> ok
            end;
        _ ->
            ok
    end,
    {stop, normal, Data};
common_handler(_E, _C, _D) ->
    keep_state_and_data.
%%====================================================================
%% Inbound message routing
%%====================================================================
%% Decode one inbound JSON payload and route it by envelope kind:
%% server request, notification, response or error response. Payloads
%% that decode to anything else are dropped.
handle_inbound(Json, State, Data) ->
    Envelope = decode(Json),
    case Envelope of
        {response, Id, Result} ->
            handle_response(Id, Result, State, Data);
        {error, Id, Code, Message, _ErrorData} ->
            handle_error_response(Id, Code, Message, State, Data);
        {request, Id, Method, Params} ->
            handle_server_request(Id, Method, Params, State, Data);
        {notification, Method, Params} ->
            handle_server_notification(Method, Params, State, Data);
        _ ->
            keep_state_and_data
    end.
%% Parse the JSON text and classify the resulting map as an envelope.
%% Whatever the parser returns other than `{ok, Map}' is passed through
%% unchanged.
decode(Json) ->
    Parsed = barrel_mcp_protocol:decode(Json),
    case Parsed of
        {ok, Map} -> barrel_mcp_protocol:decode_envelope(Map);
        _ -> Parsed
    end.
%% Settle a successful response.
%%
%% While initializing, only the response to the pending `initialize'
%% request is acted on. Afterwards a keep-alive ping response resets
%% the failure counter, and a caller's response is forwarded as
%% `{ok, Result}'. In both cases the per-request timer is cancelled.
%% Responses with an unknown id are ignored.
handle_response(Id, Result, initializing, Data) ->
    case maps:take(Id, Data#data.pending) of
        {#pending{method = <<"initialize">>}, Remaining} ->
            handle_initialize_result(Result, Data#data{pending = Remaining});
        _ ->
            keep_state_and_data
    end;
handle_response(Id, Result, _State, Data) ->
    case maps:take(Id, Data#data.pending) of
        {#pending{caller = ping} = Entry, Remaining} ->
            Healthy = Data#data{pending = Remaining, ping_failures = 0},
            {keep_state, settle_data(Entry, Healthy), [drop_req_timeout(Id)]};
        {#pending{caller = Caller} = Entry, Remaining} when Caller =/= init ->
            gen_statem:reply(Caller, {ok, Result}),
            Settled = settle_data(Entry, Data#data{pending = Remaining}),
            {keep_state, Settled, [drop_req_timeout(Id)]};
        _ ->
            keep_state_and_data
    end.
%% Settle a JSON-RPC error response.
%%
%% Any error response received while initializing aborts the client
%% with `{init_failed, Code, Message}'. Afterwards:
%% - an error answer to a keep-alive ping counts as a ping failure and
%%   may close the connection;
%% - an error answer to a caller's request is forwarded as
%%   `{error, {Code, Message}}';
%% - an unknown id is ignored.
handle_error_response(_Id, Code, Message, initializing, _Data) ->
    {stop, {init_failed, Code, Message}};
handle_error_response(Id, Code, Message, _State, Data) ->
    case maps:take(Id, Data#data.pending) of
        {#pending{caller = ping} = P, Rest} ->
            Data1 = settle_data(P, bump_ping_failures(Data#data{pending = Rest})),
            %% The ping has been answered, so cancel its per-request
            %% timer like every other settled request instead of letting
            %% it fire later for an id that is no longer pending.
            case maybe_close_on_ping_failures(Data1, Id) of
                {keep_state, Data2, Actions} ->
                    {keep_state, Data2, [drop_req_timeout(Id) | Actions]};
                Stop ->
                    Stop
            end;
        {#pending{caller = From} = P, Rest} when From =/= init ->
            gen_statem:reply(From, {error, {Code, Message}}),
            Data1 = settle_data(P, Data#data{pending = Rest}),
            {keep_state, Data1, [drop_req_timeout(Id)]};
        _ ->
            keep_state_and_data
    end.
%% Clean up per-request bookkeeping once a request has settled; at
%% present that is only the progress-token registration.
settle_data(#pending{progress_token = Token}, Data) ->
    drop_progress(Token, Data).
%% Forget the process registered for a progress token. `undefined'
%% means the request carried no token, so the state is unchanged.
drop_progress(undefined, Data) ->
    Data;
drop_progress(Token, Data) ->
    Registered = maps:remove(Token, Data#data.progress),
    Data#data{progress = Registered}.
%% gen_statem action that cancels the generic timer of request `Id':
%% setting a named timeout to `infinity' removes it.
drop_req_timeout(Id) ->
    TimerName = {req, Id},
    {{timeout, TimerName}, infinity, request_timeout}.
%% Hand a server-initiated request to the handler module and act on
%% its verdict: reply with a result, reply with an error, or remember
%% the request id under `Tag' so reply_async/3 can answer it later.
handle_server_request(Id, Method, Params, _State, Data) ->
    #data{handler_mod = Handler, handler_state = HState} = Data,
    case Handler:handle_request(Method, Params, HState) of
        {async, Tag, NewHState} ->
            Deferred = (Data#data.async_replies)#{Tag => Id},
            {keep_state, Data#data{
                handler_state = NewHState,
                async_replies = Deferred
            }};
        {reply, Result, NewHState} ->
            Answer = barrel_mcp_protocol:encode_response(Id, Result),
            send_envelope(Data, Answer),
            {keep_state, Data#data{handler_state = NewHState}};
        {error, Code, Msg, NewHState} ->
            Failure = barrel_mcp_protocol:encode_error(Id, Code, Msg),
            send_envelope(Data, Failure),
            {keep_state, Data#data{handler_state = NewHState}}
    end.
%% Send the deferred answer for the server request remembered under
%% `Tag' and forget the tag. `{error, Code, Msg}' is sent as a JSON-RPC
%% error; any other term as the result. An unknown tag leaves the
%% state untouched.
deliver_async_reply(Tag, Result, Data) ->
    case maps:take(Tag, Data#data.async_replies) of
        error ->
            Data;
        {RequestId, Outstanding} ->
            Answer =
                case Result of
                    {error, Code, Msg} ->
                        barrel_mcp_protocol:encode_error(RequestId, Code, Msg);
                    _ ->
                        barrel_mcp_protocol:encode_response(RequestId, Result)
                end,
            send_envelope(Data, Answer),
            Data#data{async_replies = Outstanding}
    end.
%% Route a server notification. Resource-update and progress
%% notifications are first fanned out to the registered processes;
%% every notification is then passed to the handler module.
handle_server_notification(Method, Params, _State, Data) ->
    case Method of
        <<"notifications/resources/updated">> ->
            Uri = maps:get(<<"uri">>, Params, <<>>),
            notify_subscribers(Uri, Params, Data);
        <<"notifications/progress">> ->
            notify_progress(Params, Data);
        _ ->
            ok
    end,
    dispatch_notification(Method, Params, Data).
%% Pass a notification to the handler module and keep the handler
%% state it returns. Only `{ok, NewState}' is accepted from the
%% handler.
dispatch_notification(Method, Params, Data) ->
    #data{handler_mod = Handler, handler_state = HState} = Data,
    case Handler:handle_notification(Method, Params, HState) of
        {ok, NewHState} ->
            {keep_state, Data#data{handler_state = NewHState}}
    end.
%% Send `{mcp_resource_updated, Uri, Params}' to every process
%% subscribed to `Uri'. A URI without subscribers is a no-op.
notify_subscribers(Uri, Params, Data) ->
    Subscribers = maps:get(Uri, Data#data.subscriptions, []),
    Update = {mcp_resource_updated, Uri, Params},
    lists:foreach(fun(Subscriber) -> Subscriber ! Update end, Subscribers),
    ok.
%% Forward a progress notification to the process registered for its
%% `progressToken'. Notifications without a token, or with a token
%% nobody registered, are ignored.
notify_progress(Params, Data) ->
    Token = maps:get(<<"progressToken">>, Params, undefined),
    forward_progress(Token, Params, Data).

%% Deliver `{mcp_progress, Token, Params}' to the registered process.
forward_progress(undefined, _Params, _Data) ->
    ok;
forward_progress(Token, Params, Data) ->
    case maps:get(Token, Data#data.progress, undefined) of
        undefined ->
            ok;
        Listener ->
            Listener ! {mcp_progress, Token, Params},
            ok
    end.
%%====================================================================
%% Initialize handling
%%====================================================================
%% Build the params of the `initialize' request from the connect
%% spec: protocol version, client capabilities in wire form, and
%% client info with binary keys. Missing spec keys fall back to
%% defaults.
build_initialize_params(#data{spec = Spec}) ->
    DefaultInfo = #{
        <<"name">> => <<"barrel_mcp_client">>,
        <<"version">> => <<"2.2.0">>
    },
    ClientInfo = normalize_keys(maps:get(client_info, Spec, DefaultInfo)),
    Capabilities = capabilities_to_wire(maps:get(capabilities, Spec, #{})),
    Version = maps:get(protocol_version, Spec, ?MCP_CLIENT_PROTOCOL_VERSION),
    #{
        <<"protocolVersion">> => Version,
        <<"capabilities">> => Capabilities,
        <<"clientInfo">> => ClientInfo
    }.
%% Sugar -> spec-shaped wire form.
%%
%% Each capability maps to a JSON object: `true' becomes an empty
%% object and a map is passed through with binary keys. A capability
%% set to `false' is omitted altogether, because on the wire a
%% capability is declared by the presence of its key. Emitting a
%% placeholder value would both advertise the capability and put a
%% non-object into the payload.
capabilities_to_wire(Map) when is_map(Map) ->
    maps:fold(
        fun
            (_Name, false, Acc) -> Acc;
            (Name, Value, Acc) -> Acc#{cap_key(Name) => cap_value(Value)}
        end,
        #{},
        Map
    ).

%% Capability names may be given as atoms or binaries.
cap_key(K) when is_atom(K) -> atom_to_binary(K, utf8);
cap_key(K) when is_binary(K) -> K.

%% Wire value of an enabled capability. Callers filter out `false'
%% before getting here. Sub-options set to `false' are dropped, and
%% any value that is neither `true' nor a map is treated as a plain
%% "enabled" flag.
cap_value(true) ->
    #{};
cap_value(Map) when is_map(Map) ->
    maps:fold(
        fun
            (_SubKey, false, Acc) -> Acc;
            (SubKey, SubValue, Acc) -> Acc#{cap_subkey(SubKey) => cap_subvalue(SubValue)}
        end,
        #{},
        Map
    );
cap_value(_) ->
    #{}.

%% Sub-option names; `list_changed' is spelled `listChanged' on the wire.
cap_subkey(list_changed) -> <<"listChanged">>;
cap_subkey(K) when is_atom(K) -> atom_to_binary(K, utf8);
cap_subkey(K) when is_binary(K) -> K.

%% Sub-option values are sent unchanged.
cap_subvalue(true) -> true;
cap_subvalue(V) -> V.
%% Return `Map' with every atom key converted to a UTF-8 binary;
%% binary keys and all values are kept as they are.
normalize_keys(Map) when is_map(Map) ->
    ToBinaryKey = fun(Key, Value, Acc) ->
        BinKey =
            case Key of
                Atom when is_atom(Atom) -> atom_to_binary(Atom, utf8);
                Bin when is_binary(Bin) -> Bin
            end,
        Acc#{BinKey => Value}
    end,
    maps:fold(ToBinaryKey, #{}, Map).
%% Validate the server's `initialize' result. The client stops when
%% the protocol version is missing or not among the versions this
%% client supports; otherwise the handshake is completed.
handle_initialize_result(Result, Data) ->
    Offered = maps:get(<<"protocolVersion">>, Result, undefined),
    case Offered of
        undefined ->
            {stop, {init_failed, missing_protocol_version}};
        _ ->
            case lists:member(Offered, ?MCP_CLIENT_SUPPORTED_VERSIONS) of
                true ->
                    finish_initialize(Offered, Result, Data);
                false ->
                    {stop, {protocol_version, Offered, ?MCP_CLIENT_SUPPORTED_VERSIONS}}
            end
    end.
%% Complete the handshake: tell an HTTP transport the negotiated
%% version and open its event stream, send
%% `notifications/initialized', store what the server reported and
%% enter `ready' with the keep-alive timer armed.
finish_initialize(Version, Result, Data) ->
    case Data#data.transport of
        {barrel_mcp_client_http, HttpPid} ->
            barrel_mcp_client_http:set_protocol_version(HttpPid, Version),
            barrel_mcp_client_http:open_event_stream(HttpPid);
        _ ->
            ok
    end,
    Initialized = barrel_mcp_protocol:encode_notification(
        <<"notifications/initialized">>, #{}
    ),
    send_envelope(Data, Initialized),
    Ready = Data#data{
        server_capabilities = maps:get(<<"capabilities">>, Result, #{}),
        server_info = maps:get(<<"serverInfo">>, Result, #{}),
        protocol_version = Version
    },
    {next_state, ready, Ready, [arm_ping_timer(Ready)]}.
%%====================================================================
%% Transport plumbing
%%====================================================================
%% Open the transport named in the connect spec and record it in the
%% state. For HTTP the auth helper is built first, and a failure there
%% is returned as is. The transport process is linked to this process.
open_transport(#data{spec = Spec} = Data) ->
    case maps:get(transport, Spec) of
        {stdio, StdioOpts} ->
            Started = barrel_mcp_client_stdio:connect(self(), StdioOpts),
            adopt_transport(barrel_mcp_client_stdio, Started, Data);
        {http, Url} ->
            case barrel_mcp_client_auth:new(maps:get(auth, Spec, none)) of
                {error, _} = AuthFailure ->
                    AuthFailure;
                Auth ->
                    HttpOpts = #{
                        url => Url,
                        auth => Auth,
                        open_event_stream => true,
                        headers => maps:get(http_headers, Spec, [])
                    },
                    Connected = barrel_mcp_client_http:connect(self(), HttpOpts),
                    adopt_transport(barrel_mcp_client_http, Connected, Data)
            end
    end.

%% Link a freshly connected transport and store it; anything other
%% than `{ok, Pid}' is handed back to the caller unchanged.
adopt_transport(Mod, {ok, Pid}, Data) ->
    link(Pid),
    {ok, Data#data{transport = {Mod, Pid}}};
adopt_transport(_Mod, Failure, _Data) ->
    Failure.
%% Encode `Envelope' as a JSON binary and hand it to the transport.
%% Without an open transport this is a no-op returning `ok'.
send_envelope(#data{transport = {Transport, TPid}}, Envelope) ->
    Encoded = json:encode(Envelope),
    Transport:send(TPid, iolist_to_binary(Encoded));
send_envelope(_NoTransport, _Envelope) ->
    ok.
%%====================================================================
%% Helpers
%%====================================================================
%% Return the request id to use now, together with the state whose
%% counter has been advanced by one.
next_id(#data{request_id = Current} = Data) ->
    Advanced = Data#data{request_id = Current + 1},
    {Current, Advanced}.
%% Turn a timeout in milliseconds into an absolute point on the
%% monotonic clock; `infinity' stays `infinity'.
deadline(Timeout) when is_integer(Timeout) ->
    Now = erlang:monotonic_time(millisecond),
    Now + Timeout;
deadline(infinity) ->
    infinity.
%% Issue a request with the default request timeout.
request(Pid, Method, Params) ->
    request(Pid, Method, Params, ?DEFAULT_REQUEST_TIMEOUT).

%% Issue a request and block for the answer. The call itself is given
%% 5 seconds more than the request timeout, leaving the state machine
%% room to answer first once its own per-request timer has expired.
request(Pid, Method, Params, Timeout) ->
    CallLimit =
        case Timeout of
            infinity -> infinity;
            Ms when is_integer(Ms) -> Ms + 5000
        end,
    Message = {request, Method, Params, Timeout},
    gen_statem:call(Pid, Message, CallLimit).
%% Walk all pages through the shared pagination helper. `FetchPage'
%% is called with the cursor of the page to fetch.
walk_all(FetchPage) ->
    barrel_mcp_pagination:walk(FetchPage).
%% Options for fetching one page while walking: always ask for the
%% cursor, and start from `Cursor' unless it is `undefined' (the first
%% page).
page_opts(Cursor) ->
    Base = #{want_cursor => true},
    case Cursor of
        undefined -> Base;
        _ -> Base#{cursor => Cursor}
    end.
%% Fetch one page of a list method.
%%
%% The `cursor' option, when present, is sent as the request's
%% `cursor'. Items are read from `ResultKey' in the result. The reply
%% is `{ok, Items}' only when the server gave no `nextCursor' and the
%% caller did not set `want_cursor'; otherwise it is
%% `{ok, Items, NextCursor}'. Errors are returned unchanged.
paged(Pid, Method, ResultKey, Opts) ->
    Cursor = maps:get(cursor, Opts, undefined),
    Params =
        if
            Cursor =:= undefined -> #{};
            true -> #{<<"cursor">> => Cursor}
        end,
    Reply = request(Pid, Method, Params, request_timeout(Opts)),
    case Reply of
        {ok, Result} ->
            Items = maps:get(ResultKey, Result, []),
            NextCursor = maps:get(<<"nextCursor">>, Result, undefined),
            WantCursor = map_get_default(want_cursor, Opts, false),
            LastPageOnly = NextCursor =:= undefined andalso WantCursor =:= false,
            case LastPageOnly of
                true -> {ok, Items};
                false -> {ok, Items, NextCursor}
            end;
        _ ->
            Reply
    end.
%% Value stored under `Key' in `Map', or `Default' when the key is
%% absent.
map_get_default(Key, Map, Default) ->
    maps:get(Key, Map, Default).
%% Per-request timeout: the `timeout' option when given, otherwise
%% the module default.
request_timeout(Opts) ->
    Fallback = ?DEFAULT_REQUEST_TIMEOUT,
    map_get_default(timeout, Opts, Fallback).
%% Add `_meta.progressToken' to the request params when the caller
%% supplied a `progress_token' option; otherwise return the params
%% unchanged.
maybe_attach_progress_token(Params, Opts) ->
    Token = maps:get(progress_token, Opts, undefined),
    if
        Token =:= undefined ->
            Params;
        true ->
            Params#{<<"_meta">> => #{<<"progressToken">> => Token}}
    end.
%% Extract `_meta.progressToken' from request params, or `undefined'
%% when the params carry none.
progress_token_from_params(Params) ->
    case Params of
        #{<<"_meta">> := #{<<"progressToken">> := Token}} -> Token;
        _ -> undefined
    end.
%%====================================================================
%% Ping cadence
%%====================================================================
%% Keep-alive interval from the connect spec; `infinity' (the
%% default) disables the keep-alive. Any value other than `infinity'
%% or a positive integer crashes.
ping_interval(#data{spec = Spec}) ->
    Interval = maps:get(ping_interval, Spec, infinity),
    case Interval of
        infinity -> Interval;
        _ when is_integer(Interval), Interval > 0 -> Interval
    end.
%% Number of failed pings that closes the connection, taken from the
%% connect spec or the module default.
ping_failure_threshold(#data{spec = Spec}) ->
    Fallback = ?DEFAULT_PING_FAILURE_THRESHOLD,
    maps:get(ping_failure_threshold, Spec, Fallback).
%% Count one more failed keep-alive ping.
bump_ping_failures(#data{ping_failures = Failures} = Data) ->
    Incremented = Failures + 1,
    Data#data{ping_failures = Incremented}.
%% Decide what to do after a failed keep-alive ping. Once the failure
%% count reaches the threshold, the transport is closed (best effort)
%% and the client stops with `ping_failed'; below it the keep-alive
%% timer is re-armed.
maybe_close_on_ping_failures(Data, _Id) ->
    Exhausted = Data#data.ping_failures >= ping_failure_threshold(Data),
    case Exhausted of
        false ->
            {keep_state, Data, [arm_ping_timer(Data)]};
        true ->
            case Data#data.transport of
                {Transport, TPid} ->
                    %% Best effort: the transport may already be gone.
                    try
                        Transport:close(TPid)
                    catch
                        _:_ -> ok
                    end;
                _ ->
                    ok
            end,
            {stop, ping_failed}
    end.
%% State-timeout action that schedules the next `ping_tick'. With an
%% `infinity' interval the action carries `infinity' as its time.
arm_ping_timer(Data) ->
    Interval = ping_interval(Data),
    {state_timeout, Interval, ping_tick}.
%% Send a keep-alive `ping' on behalf of the client itself and record
%% it as pending. Returns the new state plus the actions that start the
%% ping's own timeout and schedule the next tick.
issue_ping(Data) ->
    {Id, Numbered} = next_id(Data),
    Ping = barrel_mcp_protocol:encode_request(Id, <<"ping">>, #{}),
    send_envelope(Numbered, Ping),
    Entry = #pending{
        caller = ping,
        method = <<"ping">>,
        deadline = deadline(?DEFAULT_PING_TIMEOUT)
    },
    Tracked = Numbered#data{pending = (Numbered#data.pending)#{Id => Entry}},
    PingTimeout = {{timeout, {req, Id}}, ?DEFAULT_PING_TIMEOUT, request_timeout},
    NextTick = arm_ping_timer(Numbered),
    {Tracked, [PingTimeout, NextTick]}.
%% Whether `Method' may be sent to the connected server.
%%
%% `initialize', `ping' and all notifications are always allowed.
%% While no capabilities are known nothing else is. Methods in the
%% `tools/', `resources/', `prompts/', `completion/', `logging/' and
%% `tasks/' families need the matching key in the server capabilities
%% (`completions' or `completion' for the completion family). Any other
%% method is allowed.
is_supported(<<"initialize">>, _Data) ->
    true;
is_supported(<<"ping">>, _Data) ->
    true;
is_supported(<<"notifications/", _/binary>>, _Data) ->
    true;
is_supported(_Method, #data{server_capabilities = undefined}) ->
    false;
is_supported(<<"tools/", _/binary>>, #data{server_capabilities = ServerCaps}) ->
    is_map_key(<<"tools">>, ServerCaps);
is_supported(<<"resources/", _/binary>>, #data{server_capabilities = ServerCaps}) ->
    is_map_key(<<"resources">>, ServerCaps);
is_supported(<<"prompts/", _/binary>>, #data{server_capabilities = ServerCaps}) ->
    is_map_key(<<"prompts">>, ServerCaps);
is_supported(<<"completion/", _/binary>>, #data{server_capabilities = ServerCaps}) ->
    is_map_key(<<"completions">>, ServerCaps) orelse
        is_map_key(<<"completion">>, ServerCaps);
is_supported(<<"logging/", _/binary>>, #data{server_capabilities = ServerCaps}) ->
    is_map_key(<<"logging">>, ServerCaps);
is_supported(<<"tasks/", _/binary>>, #data{server_capabilities = ServerCaps}) ->
    is_map_key(<<"tasks">>, ServerCaps);
is_supported(_Method, _Data) ->
    true.
%% Cancel a caller's pending request: answer the caller with
%% `{error, cancelled}', tell the server via `notifications/cancelled',
%% drop the bookkeeping and cancel the request timer. The handshake,
%% keep-alive pings and unknown ids are left alone.
do_cancel(Id, #data{pending = Pending} = Data) ->
    case maps:take(Id, Pending) of
        {#pending{caller = Caller} = Entry, Remaining} when
            Caller =/= init, Caller =/= ping
        ->
            gen_statem:reply(Caller, {error, cancelled}),
            Notice = barrel_mcp_protocol:encode_notification(
                <<"notifications/cancelled">>,
                #{<<"requestId">> => Id, <<"reason">> => <<"cancelled by client">>}
            ),
            send_envelope(Data, Notice),
            Settled = settle_data(Entry, Data#data{pending = Remaining}),
            {keep_state, Settled, [drop_req_timeout(Id)]};
        _ ->
            keep_state_and_data
    end.
%% The per-request timer of `Id' fired. In every handled case the
%% server is told through `notifications/cancelled' with reason
%% `timeout'. A timed-out keep-alive ping counts as a ping failure and
%% may close the connection; a caller's request is answered with
%% `{error, timeout}'. Ids that are no longer pending are ignored.
timeout_pending(Id, #data{pending = Pending} = Data) ->
    TellServer = fun(State) ->
        send_envelope(
            State,
            barrel_mcp_protocol:encode_notification(
                <<"notifications/cancelled">>,
                #{<<"requestId">> => Id, <<"reason">> => <<"timeout">>}
            )
        )
    end,
    case maps:take(Id, Pending) of
        {#pending{caller = ping} = Entry, Remaining} ->
            Failed = bump_ping_failures(Data#data{pending = Remaining}),
            Settled = settle_data(Entry, Failed),
            TellServer(Settled),
            maybe_close_on_ping_failures(Settled, Id);
        {#pending{caller = Caller} = Entry, Remaining} when Caller =/= init ->
            gen_statem:reply(Caller, {error, timeout}),
            TellServer(Data),
            {keep_state, settle_data(Entry, Data#data{pending = Remaining})};
        _ ->
            keep_state_and_data
    end.
%% Register `Pid' as a subscriber of `Uri'. The list is kept sorted
%% and free of duplicates.
add_sub(Uri, Pid, Data) ->
    Current = Data#data.subscriptions,
    Subscribers = lists:usort([Pid | maps:get(Uri, Current, [])]),
    Data#data{subscriptions = Current#{Uri => Subscribers}}.
%% Remove `Pid' from the subscribers of `Uri'. The URI entry is
%% dropped when its last subscriber leaves; a URI without subscribers
%% leaves the state unchanged.
del_sub(Uri, Pid, Data) ->
    Current = Data#data.subscriptions,
    case maps:get(Uri, Current, []) of
        [] ->
            Data;
        Subscribers ->
            Updated =
                case lists:delete(Pid, Subscribers) of
                    [] -> maps:remove(Uri, Current);
                    Left -> Current#{Uri => Left}
                end,
            Data#data{subscriptions = Updated}
    end.
%%====================================================================
%% Termination
%%====================================================================
%% Shut down: close the transport if one is open, then give the
%% handler module a chance to clean up through its optional
%% terminate/2. Both steps are best effort and never fail the
%% shutdown.
terminate(Reason, _State, Data) ->
    #data{handler_mod = Handler, handler_state = HState, transport = Transport} = Data,
    case Transport of
        {TMod, TPid} ->
            try
                TMod:close(TPid)
            catch
                _:_ -> ok
            end;
        _ ->
            ok
    end,
    HasTerminate = erlang:function_exported(Handler, terminate, 2),
    case HasTerminate of
        false ->
            ok;
        true ->
            try
                Handler:terminate(Reason, HState)
            catch
                _:_ -> ok
            end
    end,
    ok.
%% No state migration is needed on code upgrade: the state name and
%% data are kept as they are.
code_change(_OldVsn, StateName, StateData, _Extra) ->
    {ok, StateName, StateData}.