Skip to main content

lib/scoria/eval/online_score_sampler.ex

defmodule Scoria.Eval.OnlineScoreSampler do
  @moduledoc false

  alias Scoria.Repo
  alias Scoria.Repo.Trace
  alias Scoria.Workflows.{Run, Step}

  @default_sample_reason "production_sample"
  @default_sampler_version "online-score-sampler@v1"

  @doc """
  Validates a sampled trace and schedules it for online scoring.

  `attrs` may use atom or string keys. The following are required:
  `:trace_id`, `:tenant_id`, `:workflow_run_id`, `:workflow_step_id`,
  `:sample_window` and `:scorer` (a map). The following are optional:
  `:evidence_refs`, `:promotion_snapshot`, `:sample_reason` and
  `:sampler_version`.

  Validation steps:

    * the trace must exist and the workflow run and step rows must exist;
    * the step must belong to the run;
    * the run must share the trace's session and the payload's tenant.

  Traces whose `"env"` attribute is not `"prod"` or `"production"` are
  ignored rather than rejected.

  Enqueueing is delegated to the module configured as
  `:online_scoring_module` (default `Scoria.Eval.OnlineScoring`). The call
  runs inside an unlinked `Task`, so a `:scheduled` result only means the
  task was started. It does not mean the enqueue succeeded.

  ## Options

    * `:sandbox_repo`, `:sandbox_owner`, `:notify`: used to allow the spawned
      task onto an Ecto SQL sandbox connection. `:notify` (or the caller) is
      the fallback owner.
    * All options are also forwarded to `enqueue_sampled_trace/2`.

  ## Returns

    * `{:ok, %{status: :scheduled, trace_id: trace_id}}`
    * `{:ok, %{status: :ignored, reason: :ineligible_trace}}`
    * `{:error, %ArgumentError{}}` when the payload or lineage is invalid
  """
  @spec schedule_sample(map(), keyword()) ::
          {:ok, %{status: :scheduled, trace_id: term()}}
          | {:ok, %{status: :ignored, reason: atom()}}
          | {:error, Exception.t()}
  def schedule_sample(attrs, opts \\ []) when is_map(attrs) do
    with {:ok, payload} <- normalize_payload(attrs),
         {:ok, %Trace{} = trace} <- fetch_trace(payload.trace_id),
         :ok <- validate_workflow_lineage(payload, trace),
         :ok <- ensure_production_trace(trace) do
      coordinator =
        Application.get_env(:scoria, :online_scoring_module, Scoria.Eval.OnlineScoring)

      caller = self()

      # Fire-and-forget: Task.start/1 neither links nor supervises the task,
      # and the enqueue result is discarded, so failures are not reported
      # back to the caller.
      Task.start(fn ->
        allow_test_processes(opts, caller)
        _ = coordinator.enqueue_sampled_trace(payload, opts)
      end)

      {:ok, %{status: :scheduled, trace_id: payload.trace_id}}
    else
      {:ignored, reason} -> {:ok, %{status: :ignored, reason: reason}}
      {:error, _} = error -> error
    end
  end

  # Builds the enqueue payload from `attrs`.
  #
  # Map-valued attributes are normalized to string keys. The dedupe key is
  # "tenant_id:trace_id:sample_window" and is stored both at the top level
  # and inside sampling_metadata.
  #
  # Missing required attributes, or non-map values where a map is expected,
  # raise ArgumentError. That error is returned as {:error, %ArgumentError{}}.
  # NOTE(review): a sample_window without String.Chars (e.g. a plain map)
  # would raise Protocol.UndefinedError in the interpolation below, and that
  # is not rescued. Confirm the expected type with callers.
  defp normalize_payload(attrs) do
    trace_id = fetch_attr!(attrs, :trace_id)
    tenant_id = fetch_attr!(attrs, :tenant_id)
    workflow_run_id = fetch_attr!(attrs, :workflow_run_id)
    workflow_step_id = fetch_attr!(attrs, :workflow_step_id)
    sample_window = fetch_attr!(attrs, :sample_window)

    scorer =
      attrs
      |> fetch_attr!(:scorer)
      |> normalize_map()

    evidence_refs =
      attrs
      |> fetch_attr(:evidence_refs, %{})
      |> normalize_map()

    promotion_snapshot =
      attrs
      |> fetch_attr(:promotion_snapshot, %{})
      |> normalize_map()

    sample_reason = fetch_attr(attrs, :sample_reason, @default_sample_reason)
    sampler_version = fetch_attr(attrs, :sampler_version, @default_sampler_version)
    dedupe_key = "#{tenant_id}:#{trace_id}:#{sample_window}"

    {:ok,
     %{
       trace_id: trace_id,
       tenant_id: tenant_id,
       workflow_run_id: workflow_run_id,
       workflow_step_id: workflow_step_id,
       dedupe_key: dedupe_key,
       scorer: scorer,
       evidence_refs: evidence_refs,
       promotion_snapshot: promotion_snapshot,
       sampling_metadata: %{
         "sample_reason" => sample_reason,
         "sample_window" => sample_window,
         "sampler_version" => sampler_version,
         "dedupe_key" => dedupe_key
       }
     }}
  rescue
    error in [ArgumentError] -> {:error, error}
  end

  # Loads the trace row. A missing row becomes {:error, %ArgumentError{}}.
  # NOTE(review): Repo.get/2 may raise (e.g. Ecto.Query.CastError) when
  # trace_id cannot be cast to the primary-key type. Confirm callers pass
  # well-formed ids.
  defp fetch_trace(trace_id) do
    case Repo.get(Trace, trace_id) do
      %Trace{} = trace -> {:ok, trace}
      nil -> {:error, ArgumentError.exception("trace_id must reference a persisted trace")}
    end
  end

  # Checks the payload's workflow lineage against persisted rows and the
  # trace:
  #   * run and step rows must exist (nil -> "must reference persisted ...");
  #   * the step must belong to the run;
  #   * the run must share the trace's session and the payload's tenant
  #     (false -> "must belong to ...").
  defp validate_workflow_lineage(payload, %Trace{} = trace) do
    with %Run{} = run <- Repo.get(Run, payload.workflow_run_id),
         %Step{} = step <- Repo.get(Step, payload.workflow_step_id),
         true <- step.run_id == run.id,
         true <- run.session_id == trace.session_id,
         true <- run.tenant_id == payload.tenant_id do
      :ok
    else
      nil ->
        {:error, ArgumentError.exception("workflow lineage must reference persisted run and step rows")}

      false ->
        {:error,
         ArgumentError.exception(
           "workflow lineage must belong to the sampled trace session and tenant"
         )}
    end
  end

  # Only traces whose "env" attribute is "prod" or "production" are sampled.
  # Any other value, or a missing/nil attributes map, yields
  # {:ignored, :ineligible_trace}.
  defp ensure_production_trace(%Trace{} = trace) do
    env =
      trace.attributes
      |> normalize_map()
      |> Map.get("env")

    if env in ["prod", "production"], do: :ok, else: {:ignored, :ineligible_trace}
  end

  # Lets the spawned task use the owner's Ecto SQL sandbox connection when
  # :sandbox_repo is given. The owner is :sandbox_owner, else :notify, else
  # the caller.
  # Deliberately best-effort: any exception is swallowed so that sandbox
  # wiring can never prevent the enqueue.
  defp allow_test_processes(opts, caller) do
    repo = Keyword.get(opts, :sandbox_repo)

    owner =
      Keyword.get_lazy(opts, :sandbox_owner, fn ->
        Keyword.get(opts, :notify, caller)
      end)

    if repo && owner do
      Ecto.Adapters.SQL.Sandbox.allow(repo, owner, self())
    end
  rescue
    _ -> :ok
  end

  # Like fetch_attr/3, but raises ArgumentError when the attribute is absent
  # or nil under both its atom and string key.
  defp fetch_attr!(attrs, key) do
    case fetch_attr(attrs, key) do
      nil -> raise ArgumentError, "missing required online scoring attribute #{key}"
      value -> value
    end
  end

  # Reads `key` as an atom key first, then as its string form, falling back
  # to `default`. Only nil counts as "absent", so an explicit `false` is
  # returned as-is rather than being replaced by the default.
  defp fetch_attr(attrs, key, default \\ nil) when is_map(attrs) do
    case {Map.get(attrs, key), Map.get(attrs, Atom.to_string(key))} do
      {nil, nil} -> default
      {nil, string_keyed} -> string_keyed
      {atom_keyed, _} -> atom_keyed
    end
  end

  # Recursively converts map keys to strings.
  #   * nil becomes %{}.
  #   * A top-level struct is converted with Map.from_struct/1 first, because
  #     Map.new/2 enumerates structs via Enumerable, which most structs do
  #     not implement.
  #   * Any other non-map raises ArgumentError, so normalize_payload/1 can
  #     report it as {:error, %ArgumentError{}}.
  defp normalize_map(nil), do: %{}

  defp normalize_map(%_{} = struct) do
    struct
    |> Map.from_struct()
    |> normalize_map()
  end

  defp normalize_map(map) when is_map(map) do
    Map.new(map, fn {key, value} ->
      {to_string(key), normalize_value(value)}
    end)
  end

  defp normalize_map(other) do
    raise ArgumentError,
          "expected a map for online scoring attribute, got: #{inspect(other)}"
  end

  # Nested structs (DateTime, Decimal, ...) are kept intact rather than
  # walked. Walking them with Map.new/2 would raise Protocol.UndefinedError.
  # Plain maps and lists are normalized recursively; other values pass
  # through.
  defp normalize_value(%_{} = struct), do: struct
  defp normalize_value(value) when is_map(value), do: normalize_map(value)
  defp normalize_value(value) when is_list(value), do: Enum.map(value, &normalize_value/1)
  defp normalize_value(value), do: value
end