Skip to main content

lib/pi/tool/dispatch.ex

defmodule Pi.Tool.Dispatch do
  @moduledoc """
  Bridge-native tool dispatch for stdio transports.

  `dispatch/2` selects a tool by name and runs it with a string-keyed argument
  map. Most tools return `{:ok, result}` / `{:error, message}`; structured
  results are JSON-encoded strings. `"pi_session_rerun"` and unstructured
  `"project_eval"` return whatever `Pi.Session.rerun/2` and `Pi.Eval.run/2`
  return, unchanged.
  """

  require Pi.Features

  alias Pi.Protocol.Tool.AST.ReplaceRequest
  alias Pi.Protocol.Tool.AST.SearchRequest
  alias Pi.Protocol.Tool.EvalRequest
  alias Pi.Protocol.Tool.OutputPart

  @doc """
  Runs the tool `name` with the string-keyed `args` map.

  Supported tools:

    * `"project_eval"` - evaluates `args["code"]` (mode taken from `args`).
    * `"project_eval_structured"` - same, but the result is a JSON-encoded payload.
    * `"project_eval_sandbox"` - `"project_eval"` with `"mode"` forced to `"sandbox"`.
    * `"ex_ast_search"` / `"ex_ast_replace"` - AST search/replace via `Pi.AST`;
      file, argument, and parse errors are returned as `{:error, message}`.
    * `"pi_session_cancel"` / `"pi_session_rerun"` - require a string `"id"`;
      gated on the `:sessions` feature.
    * `"pi_session_snapshots"` - JSON list of session snapshots, or an empty
      list when the `:sessions` feature is disabled.

  Any other name returns `{:error, "Unknown tool: <name>"}`.
  """
  def dispatch("project_eval", args), do: eval(args, structured?: false)
  def dispatch("project_eval_structured", args), do: eval(args, structured?: true)

  def dispatch("project_eval_sandbox", args),
    do: eval(Map.put(args, "mode", "sandbox"), structured?: false)

  def dispatch("ex_ast_search", args), do: safe_ast_dispatch(fn -> ast_search(args) end)

  def dispatch("ex_ast_replace", args), do: safe_ast_dispatch(fn -> ast_replace(args) end)

  def dispatch("pi_session_cancel", %{"id" => id}) when is_binary(id) do
    Pi.Features.gate :sessions do
      with {:ok, pid} <- Pi.Session.lookup(id),
           :ok <- Pi.Session.cancel(pid) do
        {:ok, "ok"}
      else
        {:error, reason} -> {:error, inspect(reason)}
      end
    end
  end

  def dispatch("pi_session_rerun", %{"id" => id} = args) when is_binary(id) do
    Pi.Features.gate :sessions do
      # Rerun timeout in ms; defaults to 60s when the caller does not supply one.
      timeout = Map.get(args, "timeout", 60_000)

      case Pi.Session.lookup(id) do
        {:ok, pid} -> Pi.Session.rerun(pid, timeout: timeout)
        {:error, reason} -> {:error, inspect(reason)}
      end
    end
  end

  # A known session tool called without a string "id" is a bad request, not an
  # unknown tool: report the missing parameter instead of falling through to
  # the catch-all clause. Gated like the clauses above so this path behaves the
  # same way as they do when the :sessions feature is disabled.
  def dispatch(tool, _args) when tool in ["pi_session_cancel", "pi_session_rerun"] do
    Pi.Features.gate :sessions do
      {:error, missing_parameters("id")}
    end
  end

  def dispatch("pi_session_snapshots", _args) do
    # Not gated: with sessions disabled this reports an empty list, not an error.
    sessions = if Pi.Features.sessions?(), do: Pi.Session.snapshots(), else: []
    {:ok, encode_payload(%{sessions: sessions})}
  end

  def dispatch(name, _args), do: {:error, "Unknown tool: #{name}"}

  # Decodes a SearchRequest and runs either a multi-pattern or single-pattern search.
  defp ast_search(args) do
    with {:ok, request} <- decode_request(SearchRequest, args, "pattern or patterns") do
      run_ast_search(request)
    end
  end

  # A non-empty `patterns` map takes precedence over a single `pattern` string.
  defp run_ast_search(%SearchRequest{patterns: patterns} = request)
       when is_map(patterns) and map_size(patterns) > 0 do
    patterns
    |> Pi.AST.search_many(ast_opts(request))
    |> encode_result()
  end

  defp run_ast_search(%SearchRequest{pattern: pattern} = request) when is_binary(pattern) do
    pattern
    |> Pi.AST.search(ast_opts(request))
    |> encode_result()
  end

  defp run_ast_search(_request), do: {:error, "Missing required parameter: pattern or patterns"}

  # Decodes a ReplaceRequest and runs the replacement, forwarding its dry_run flag.
  defp ast_replace(args) do
    with {:ok, request} <- decode_request(ReplaceRequest, args, "pattern and replacement") do
      request.pattern
      |> Pi.AST.replace(
        request.replacement,
        Keyword.put(ast_opts(request), :dry_run, request.dry_run)
      )
      |> encode_result()
    end
  end

  # Runs an AST tool and turns the exceptions it may raise on bad input
  # (unreadable files, invalid arguments, unparsable patterns) into
  # `{:error, message}` instead of raising to the caller.
  # `TokenMissingError` is listed next to `SyntaxError` because Elixir's parser
  # (`Code.string_to_quoted!/2`) raises it, not `SyntaxError`, for incomplete
  # input such as `"foo("`. NOTE(review): assumes Pi.AST parses patterns that way.
  defp safe_ast_dispatch(fun) do
    fun.()
  rescue
    exception in [File.Error, ArgumentError, RuntimeError, SyntaxError, TokenMissingError] ->
      {:error, Exception.message(exception)}
  end

  # Decodes an EvalRequest and runs it. An explicit request timeout wins over
  # the per-mode default from `eval_timeout/1`.
  defp eval(args, opts) do
    with {:ok, request} <- decode_request(EvalRequest, args, "code") do
      timeout = request.timeout || eval_timeout(request.mode)
      run_eval(request, timeout, Keyword.fetch!(opts, :structured?))
    end
  end

  # Third argument: `true` = structured (JSON-encoded payload), `false` = plain result.
  defp run_eval(%EvalRequest{mode: :sandbox, code: code}, timeout, true) do
    code |> Pi.Eval.sandbox(timeout: timeout) |> sandbox_payload() |> encode_result()
  end

  defp run_eval(%EvalRequest{mode: :sandbox, code: code}, timeout, false) do
    code |> Pi.Eval.sandbox(timeout: timeout) |> sandbox_result()
  end

  defp run_eval(%EvalRequest{code: code} = request, timeout, true) do
    code |> Pi.Eval.run_structured(eval_opts(request, timeout)) |> encode_result()
  end

  defp run_eval(%EvalRequest{code: code} = request, timeout, false) do
    Pi.Eval.run(code, eval_opts(request, timeout))
  end

  # Options for non-sandbox evaluation; session/state paths are only passed when set.
  defp eval_opts(request, timeout) do
    [timeout: timeout, reload: request.reload]
    |> maybe_put(:session_id, request.session_id)
    |> maybe_put(:state_path, request.state_path)
    |> maybe_put(:restore_path, request.restore_path)
  end

  # Default timeouts (ms) used when the request does not specify one.
  defp eval_timeout(:sandbox), do: 5_000
  defp eval_timeout(:trusted), do: 30_000

  # Builds a request struct via `module.from_map/1`. Any decode failure is
  # reported as missing the parameters named in `missing`; the decoder's own
  # reason is discarded.
  defp decode_request(module, args, missing) do
    case module.from_map(args) do
      {:ok, request} -> {:ok, request}
      {:error, _reason} -> {:error, missing_parameters(missing)}
    end
  end

  # "parameters" (plural) when `missing` names several required values joined by " and ".
  defp missing_parameters(missing) do
    label = if String.contains?(missing, " and "), do: "parameters", else: "parameter"
    "Missing required #{label}: #{missing}"
  end

  # Wraps a sandbox result in a structured Eval payload: an optional text part
  # for captured IO followed by the inspected result as Elixir source.
  defp sandbox_payload({:ok, %{stdio: stdio, inspected: inspected}}) do
    result_part = OutputPart.inspect(inspected, language: :elixir)
    parts = maybe_sandbox_io_part([], stdio) ++ [result_part]

    {:ok,
     %Pi.Protocol.Tool.Eval{
       io: stdio,
       result: inspected,
       text: sandbox_text(stdio, inspected),
       parts: parts,
       display: %Pi.Protocol.UI.Display{blocks: Enum.map(parts, &sandbox_part_block/1)}
     }}
  end

  defp sandbox_payload({:error, reason}), do: sandbox_error(reason)

  # Empty IO produces no text part.
  defp maybe_sandbox_io_part(parts, ""), do: parts

  defp maybe_sandbox_io_part(parts, stdio) do
    parts ++ [OutputPart.text(stdio)]
  end

  defp sandbox_part_block(%OutputPart{
         kind: kind,
         body: body,
         language: language
       }) do
    %Pi.Protocol.UI.Block{type: block_type(kind), text: body, language: language}
  end

  # OutputPart's :code kind is rendered as a :source UI block; other kinds map 1:1.
  defp block_type(:code), do: :source
  defp block_type(kind), do: kind

  # Plain-text rendering: just the result when there was no IO, otherwise both sections.
  defp sandbox_text("", inspected), do: inspected
  defp sandbox_text(stdio, inspected), do: "IO:\n\n#{stdio}\n\nResult:\n\n#{inspected}"

  # Unstructured sandbox result (`sandbox_text/2` already handles empty IO).
  defp sandbox_result({:ok, %{stdio: stdio, inspected: inspected}}) do
    {:ok, sandbox_text(stdio, inspected)}
  end

  defp sandbox_result({:error, reason}), do: sandbox_error(reason)

  # Shared error mapping for both sandbox result shapes.
  defp sandbox_error(:unavailable), do: {:error, "Dune sandbox is not available"}
  defp sandbox_error(message), do: {:error, message}

  # Keyword options for Pi.AST calls; unset (nil) or false fields are omitted.
  defp ast_opts(%{
         path: path,
         inside: inside,
         not_inside: not_inside,
         allow_broad: allow_broad,
         limit: limit
       }) do
    []
    |> maybe_put(:path, path)
    |> maybe_put(:inside, inside)
    |> maybe_put(:not_inside, not_inside)
    |> maybe_put(:allow_broad, allow_broad)
    |> maybe_put(:limit, limit)
  end

  # Adds `key` only when `value` is neither nil nor false.
  defp maybe_put(opts, _key, nil), do: opts
  defp maybe_put(opts, _key, false), do: opts
  defp maybe_put(opts, key, value), do: Keyword.put(opts, key, value)

  # Success payloads and struct errors are JSON-encoded; plain error messages pass through.
  defp encode_result({:ok, payload}), do: {:ok, encode_payload(payload)}

  defp encode_result({:error, payload}) when is_struct(payload),
    do: {:error, encode_payload(payload)}

  defp encode_result({:error, message}), do: {:error, message}

  # Dumps the payload via JSONCodec, normalizes nested values, and encodes to a JSON string.
  defp encode_payload(payload) when is_map(payload) do
    payload
    |> JSONCodec.dump()
    |> normalize()
    |> Jason.encode!()
  end

  defp normalize(map) when is_map(map) do
    Map.new(map, fn {key, value} -> {key, normalize_value(value)} end)
  end

  # Nested structs are dumped recursively; lists and maps are walked element-wise.
  defp normalize_value(%_module{} = value) do
    value |> JSONCodec.dump() |> normalize()
  end

  defp normalize_value(value) when is_list(value), do: Enum.map(value, &normalize_value/1)
  defp normalize_value(value) when is_map(value), do: normalize(value)

  # Atoms other than true/false/nil are stringified; JSON literals are kept as-is.
  defp normalize_value(value) when is_atom(value) and not is_boolean(value) and not is_nil(value),
    do: Atom.to_string(value)

  defp normalize_value(value), do: value
end