Skip to main content

src/barrel_mcp_tasks.erl

%%%-------------------------------------------------------------------
%%% @doc Long-running operation registry (MCP tasks).
%%%
%%% Tools registered with `long_running => true' return immediately
%%% with a `taskId' instead of synchronously producing a result. The
%%% worker continues in the background; clients poll via
%%% `tasks/get', enumerate via `tasks/list', and abort via
%%% `tasks/cancel'. State transitions emit
%%% `notifications/tasks/status' on the session's SSE channel.
%%%
%%% Tasks live in a `protected' ETS table keyed by
%%% `{SessionId, TaskId}'. A periodic sweep evicts terminal tasks
%%% (completed / failed / cancelled) older than `?TASK_TTL'.
%%% @end
%%%-------------------------------------------------------------------
-module(barrel_mcp_tasks).

-behaviour(gen_server).

-export([
    start_link/0,
    create/3,
    get/2,
    list/2,
    cancel/2,
    finish/3,
    fail/3,
    set_worker/3
]).

-export([
    init/1,
    handle_call/3,
    handle_cast/2,
    handle_info/2,
    terminate/2
]).

%% Name of the ETS table holding `{{SessionId, TaskId}, #task{}}'
%% entries.
-define(TABLE, barrel_mcp_tasks_table).
%% Retention of finished tasks, in milliseconds (1 hour). The sweep
%% drops every task that is no longer `working' and whose
%% `updated_at' is older than this.
-define(TASK_TTL, 3600 * 1000).
%% Delay between two sweeps, in milliseconds (1 minute).
-define(SWEEP_INTERVAL, 60 * 1000).

%% One entry of the task table. Stored as the value of
%% `{{SessionId, TaskId}, #task{}}' and rendered for clients by
%% `task_to_map/1'.
-record(task, {
    %% Task id as produced by `generate_id/0' (`task_' + hex digits).
    id :: binary(),
    %% Owning session; `undefined' means no status notification is sent.
    session_id :: binary() | undefined,
    %% Method name passed to `create/3'.
    method :: binary(),
    %% Spec vocabulary (MCP 2025-11-25):
    %%   submitted | working | completed | failed | cancelled
    %% We don't model `submitted' today (workers start immediately),
    %% so the initial state is `working' and terminal states are
    %% `completed', `failed', `cancelled'.
    %% NOTE(review): confirm the status list above against the
    %% published spec revision.
    status :: working | completed | failed | cancelled,
    %% Set by `finish/3'; left `undefined' otherwise.
    result :: term(),
    %% Set by `fail/3'; left `undefined' otherwise.
    error :: term(),
    %% Creation time, `erlang:system_time(millisecond)'.
    created_at :: integer(),
    %% Time of the last status change, same unit as `created_at'.
    updated_at :: integer(),
    %% Worker process recorded by `set_worker/3'; receives
    %% `{cancel, RequestId}' on cancellation.
    worker_pid :: pid() | undefined,
    %% Originating request id recorded by `set_worker/3', if any.
    request_id :: integer() | binary() | undefined
}).

%%====================================================================
%% Public API
%%====================================================================

%% @doc Start the task registry and register it locally under the
%% module name.
start_link() ->
    ServerName = {local, ?MODULE},
    gen_server:start_link(ServerName, ?MODULE, [], []).

%% @doc Register a new task for `Method' under `SessionId' and return
%% its generated id. The task starts in the `working' state. `Opts' is
%% currently ignored.
-spec create(
    SessionId :: binary() | undefined,
    Method :: binary(),
    Opts :: map()
) -> {ok, binary()}.
create(SessionId, Method, _Opts) ->
    Request = {create, SessionId, Method},
    gen_server:call(?MODULE, Request).

%% @doc Fetch a single task as a client-facing map. Reads the ETS
%% table directly, without going through the server process.
-spec get(SessionId :: binary() | undefined, TaskId :: binary()) ->
    {ok, map()} | {error, not_found}.
get(SessionId, TaskId) ->
    Key = {SessionId, TaskId},
    case ets:lookup(?TABLE, Key) of
        [] ->
            {error, not_found};
        [{Key, Task}] ->
            {ok, task_to_map(Task)}
    end.

%% @doc Return every task stored under `SessionId' as client-facing
%% maps. The whole table is folded over; the second argument is
%% currently ignored.
-spec list(SessionId :: binary() | undefined, map()) -> {ok, [map()]}.
list(SessionId, _Opts) ->
    %% `SessionId' is already bound here, so the case pattern below
    %% only matches entries that belong to this session.
    Collect = fun(Entry, Acc) ->
        case Entry of
            {{SessionId, _TaskId}, Task} -> [task_to_map(Task) | Acc];
            _Other -> Acc
        end
    end,
    {ok, ets:foldl(Collect, [], ?TABLE)}.

%% @doc Cancel a task: the server marks it `cancelled', notifies the
%% client, and sends `{cancel, RequestId}' to the recorded worker pid
%% (if any) so cooperative arity-2 handlers can abort.
-spec cancel(binary() | undefined, binary()) -> ok | {error, not_found}.
cancel(SessionId, TaskId) ->
    Request = {cancel, SessionId, TaskId},
    gen_server:call(?MODULE, Request).

%% @doc Attach the worker pid (and, optionally, the originating
%% request id) to an existing task so that a later `tasks/cancel' can
%% reach the worker.
-spec set_worker(
    binary() | undefined,
    binary(),
    #{
        worker := pid(),
        request_id => integer() | binary()
    }
) ->
    ok | {error, not_found}.
set_worker(SessionId, TaskId, Info) ->
    Request = {set_worker, SessionId, TaskId, Info},
    gen_server:call(?MODULE, Request).

%% @doc Report success: the result is stored on the task and a
%% notifications/tasks/status message is emitted.
-spec finish(binary() | undefined, binary(), term()) -> ok | {error, not_found}.
finish(SessionId, TaskId, Result) ->
    Request = {finish, SessionId, TaskId, Result},
    gen_server:call(?MODULE, Request).

%% @doc Report failure: the reason is stored on the task and a
%% status notification is emitted.
-spec fail(binary() | undefined, binary(), term()) -> ok | {error, not_found}.
fail(SessionId, TaskId, Reason) ->
    Request = {fail, SessionId, TaskId, Reason},
    gen_server:call(?MODULE, Request).

%%====================================================================
%% gen_server
%%====================================================================

%% Make sure the task table exists, arm the first sweep timer and
%% start with an empty map as server state.
init([]) ->
    _ = ensure_table(),
    _TimerRef = erlang:send_after(?SWEEP_INTERVAL, self(), sweep),
    InitialState = #{},
    {ok, InitialState}.

%% Synchronous request dispatcher. Every write to the task table goes
%% through one of these clauses.
%%
%%   {create, SessionId, Method}           -> {ok, TaskId}
%%   {cancel, SessionId, TaskId}           -> ok | {error, not_found}
%%   {set_worker, SessionId, TaskId, Info} -> ok | {error, not_found}
%%   {finish, SessionId, TaskId, Result}   -> ok | {error, not_found}
%%   {fail, SessionId, TaskId, Reason}     -> ok | {error, not_found}
%%   anything else                         -> {error, unknown_request}
handle_call({create, SessionId, Method}, _From, State) ->
    Now = erlang:system_time(millisecond),
    TaskId = generate_id(),
    Task = #task{
        id = TaskId,
        session_id = SessionId,
        method = Method,
        status = working,
        created_at = Now,
        updated_at = Now
    },
    true = ets:insert(?TABLE, {{SessionId, TaskId}, Task}),
    notify_changed(SessionId, Task),
    {reply, {ok, TaskId}, State};
handle_call({cancel, SessionId, TaskId}, _From, State) ->
    %% Best-effort: send the worker a cooperative cancel signal so
    %% arity-2 tool handlers can short-circuit, then transition the
    %% stored status. Arity-1 handlers run to completion; their
    %% result is dropped because the task is already in a terminal
    %% state.
    ok = signal_worker_cancel(ets:lookup(?TABLE, {SessionId, TaskId})),
    Reply = transition(SessionId, TaskId, cancelled, undefined, undefined),
    {reply, Reply, State};
handle_call({set_worker, SessionId, TaskId, Info}, _From, State) ->
    Reply =
        case ets:lookup(?TABLE, {SessionId, TaskId}) of
            [{_, #task{} = Task}] ->
                %% `worker' is mandatory (see the set_worker/3 spec);
                %% `request_id' is optional.
                Updated = Task#task{
                    worker_pid = maps:get(worker, Info),
                    request_id = maps:get(request_id, Info, undefined)
                },
                true = ets:insert(?TABLE, {{SessionId, TaskId}, Updated}),
                ok;
            [] ->
                {error, not_found}
        end,
    {reply, Reply, State};
handle_call({finish, SessionId, TaskId, Result}, _From, State) ->
    Reply = transition(SessionId, TaskId, completed, Result, undefined),
    {reply, Reply, State};
handle_call({fail, SessionId, TaskId, Reason}, _From, State) ->
    Reply = transition(SessionId, TaskId, failed, undefined, Reason),
    {reply, Reply, State};
handle_call(_, _, State) ->
    {reply, {error, unknown_request}, State}.

%% Send `{cancel, RequestId}' to the recorded worker, but only while
%% the task is still `working': a task that already reached a terminal
%% state must not deliver a stale cancel signal to whatever the pid is
%% doing now. Sending to a pid never raises, so no try/catch is needed.
signal_worker_cancel([
    {_, #task{status = working, worker_pid = Pid, request_id = ReqId}}
]) when is_pid(Pid) ->
    _ = Pid ! {cancel, ReqId},
    ok;
signal_worker_cancel(_NoLiveWorker) ->
    ok.

%% No casts are part of this server's protocol; any cast is ignored.
handle_cast(_Request, State) ->
    {noreply, State}.

%% `sweep' evicts finished tasks whose last update is older than
%% `?TASK_TTL' and re-arms the timer. Any other message is ignored.
handle_info(sweep, State) ->
    Cutoff = erlang:system_time(millisecond) - ?TASK_TTL,
    %% Collect the keys first and delete afterwards, so the table is
    %% not modified while it is being traversed.
    Expired = ets:foldl(
        fun(Entry, Acc) -> collect_expired(Entry, Cutoff, Acc) end,
        [],
        ?TABLE
    ),
    lists:foreach(fun(Key) -> ets:delete(?TABLE, Key) end, Expired),
    _TimerRef = erlang:send_after(?SWEEP_INTERVAL, self(), sweep),
    {noreply, State};
handle_info(_Other, State) ->
    {noreply, State}.

%% Fold step for the sweep: tasks still `working' are always kept;
%% any other task is selected once `updated_at' is before `Cutoff'.
collect_expired({_Key, #task{status = working}}, _Cutoff, Acc) ->
    Acc;
collect_expired({Key, #task{updated_at = UpdatedAt}}, Cutoff, Acc) when
    UpdatedAt < Cutoff
->
    [Key | Acc];
collect_expired(_Entry, _Cutoff, Acc) ->
    Acc.

%% Nothing to clean up on shutdown.
terminate(_Why, _LastState) ->
    ok.

%%====================================================================
%% Internal
%%====================================================================

%% Create the named task table unless a table with that name already
%% exists. Returns the result of `ets:new/2' when the table is
%% created, `ok' otherwise.
ensure_table() ->
    ensure_table(ets:whereis(?TABLE)).

ensure_table(undefined) ->
    Options = [named_table, protected, set, {read_concurrency, true}],
    ets:new(?TABLE, Options);
ensure_table(_ExistingTid) ->
    ok.

%% Build a fresh task id: the `task_' prefix followed by the lowercase
%% hex encoding of 16 strong random bytes.
generate_id() ->
    Suffix = binary:encode_hex(crypto:strong_rand_bytes(16), lowercase),
    <<"task_", Suffix/binary>>.

%% Move a `working' task to `Status', storing `Result' / `Reason' and
%% refreshing `updated_at', then notify the session. A task that has
%% already left `working' is not touched and `ok' is returned
%% (idempotent). An unknown task yields `{error, not_found}'.
transition(SessionId, TaskId, Status, Result, Reason) ->
    Key = {SessionId, TaskId},
    case ets:lookup(?TABLE, Key) of
        [] ->
            {error, not_found};
        [{_, #task{status = working} = Current}] ->
            Next = Current#task{
                status = Status,
                result = Result,
                error = Reason,
                updated_at = erlang:system_time(millisecond)
            },
            true = ets:insert(?TABLE, {Key, Next}),
            notify_changed(SessionId, Next),
            ok;
        [{_, _AlreadyTerminal}] ->
            ok
    end.

%% Push a `notifications/tasks/status' message for `Task' to the
%% session's SSE process, when one is registered. Tasks without a
%% session and sessions without an SSE pid are silently skipped.
notify_changed(undefined, _Task) ->
    ok;
notify_changed(SessionId, #task{} = Task) ->
    case barrel_mcp_session:get_sse_pid(SessionId) of
        {ok, SsePid} when is_pid(SsePid) ->
            Notification = #{
                <<"jsonrpc">> => <<"2.0">>,
                <<"method">> => <<"notifications/tasks/status">>,
                <<"params">> => task_to_map(Task)
            },
            _ = SsePid ! {sse_send_message, Notification},
            ok;
        _NoChannel ->
            ok
    end.

%% Render a task record as the client-facing map. `sessionId' is
%% present only for session-bound tasks, `result' only on a
%% `completed' task with a stored result, and `error' only on a
%% `failed' task with a stored reason.
task_to_map(#task{
    id = TaskId,
    session_id = SessionId,
    method = Method,
    status = Status,
    result = Result,
    error = Error,
    created_at = CreatedAt,
    updated_at = UpdatedAt
}) ->
    Required = #{
        <<"taskId">> => TaskId,
        <<"method">> => Method,
        <<"status">> => atom_to_binary(Status, utf8),
        <<"createdAt">> => to_rfc3339(CreatedAt),
        <<"lastUpdatedAt">> => to_rfc3339(UpdatedAt),
        %% `ttl' is the requested retention duration in milliseconds,
        %% null when unlimited. We don't yet honour a client-supplied
        %% TTL, so we report null.
        <<"ttl">> => null
    },
    %% Each comprehension yields zero or one pair, depending on
    %% whether its filters hold.
    Optional =
        [{<<"sessionId">>, SessionId} || SessionId =/= undefined] ++
            [{<<"result">>, Result} || Status =:= completed, Result =/= undefined] ++
            [{<<"error">>, format_error(Error)} || Status =:= failed, Error =/= undefined],
    maps:merge(Required, maps:from_list(Optional)).

%% Format a millisecond system time as an RFC 3339 binary with a `Z'
%% offset, e.g. `<<"1970-01-01T00:00:00.000Z">>'.
to_rfc3339(Ms) when is_integer(Ms) ->
    Options = [{unit, millisecond}, {offset, "Z"}],
    Timestamp = calendar:system_time_to_rfc3339(Ms, Options),
    iolist_to_binary(Timestamp).

%% Turn a failure reason into a binary for the `error' field.
%% Binaries are passed through unchanged. Any other term is
%% pretty-printed with the Unicode-aware `~tp' and encoded as UTF-8,
%% so reasons holding characters above 127 (e.g. the string "café")
%% do not end up as raw Latin-1 bytes, which would be invalid UTF-8
%% in the JSON payload.
format_error(Reason) when is_binary(Reason) ->
    Reason;
format_error(Reason) ->
    unicode:characters_to_binary(io_lib:format("~tp", [Reason])).