Skip to main content

lib/dot_prompt/helpers.ex

defmodule DotPrompt.Helpers do
  @moduledoc false

  # Internal helpers shared by the DotPrompt compiler/runtime: prompt file
  # resolution (including versioned archives), compile cache keys, parameter
  # defaults and lookup, error-message decoration and response-contract
  # validation.

  # NOTE(review): Logger is required but not referenced in this module.
  require Logger

  alias DotPrompt.Parser.{Lexer, Parser, Validator}
  alias DotPrompt.Compiler.ResponseCollector

  # ---------------------------------------------------------------------------
  # Prompt directory and file resolution
  # ---------------------------------------------------------------------------

  # Returns the absolute prompts directory.
  #
  # Precedence: the `:prompts_dir` option, then the `:anantha_dot_prompt`
  # application env, then the first existing path among `<cwd>/prompts`,
  # `<cwd>/dot_prompt/prompts` and `<cwd>/../../prompts`. If none exist,
  # `<cwd>/prompts` is returned anyway.
  @doc false
  def prompts_dir(opts \\ []) do
    configured =
      Keyword.get(opts, :prompts_dir) || Application.get_env(:anantha_dot_prompt, :prompts_dir)

    case configured do
      nil -> default_prompts_dir()
      dir -> Path.expand(dir)
    end
  end

  # Probes the conventional locations for a `prompts` directory.
  defp default_prompts_dir do
    cwd = File.cwd!()

    candidates = [
      Path.join(cwd, "prompts"),
      Path.join([cwd, "dot_prompt", "prompts"]),
      Path.join([cwd, "..", "..", "prompts"])
    ]

    case Enum.find(candidates, &File.exists?/1) do
      nil -> Path.expand("prompts")
      found -> Path.expand(found)
    end
  end

  # Loads a prompt by name and returns `{content, mtime, full_path}`.
  #
  # The name is resolved against `current_dir` first, then against the prompts
  # root (see `resolve_prompt_path/3`). When `major` is truthy, the file's
  # declared version must have that major version. Otherwise
  # `archive/<base>_v<major>.prompt` is tried.
  #
  # Raises `RuntimeError` with a "prompt_not_found: ..." message when nothing
  # matches.
  @doc false
  def load_prompt_file_with_meta(name, major, current_dir \\ "", opts \\ []) do
    prompts_dir = prompts_dir(opts)
    name_str = to_string(name)

    case resolve_prompt_path(prompts_dir, name_str, current_dir) do
      {nil, _mtime, _path} ->
        raise "prompt_not_found: #{name}"

      found ->
        if major, do: check_major_version(major, found, prompts_dir, name_str), else: found
    end
  end

  # No usable current directory: resolve against the prompts root only.
  defp resolve_prompt_path(prompts_dir, name_str, current_dir)
       when current_dir in [nil, "", "."] do
    do_resolve(prompts_dir, name_str)
  end

  # Explicitly relative name ("./x"): resolve only inside `current_dir`.
  defp resolve_prompt_path(prompts_dir, "./" <> relative, current_dir) do
    do_resolve(prompts_dir, Path.join(current_dir, relative))
  end

  # Bare name: prefer a match inside `current_dir`, else fall back to the root.
  defp resolve_prompt_path(prompts_dir, name_str, current_dir) do
    case do_resolve(prompts_dir, Path.join(current_dir, name_str)) do
      {nil, nil, nil} -> do_resolve(prompts_dir, name_str)
      found -> found
    end
  end

  # Looks for `name_str` under `prompts_dir`. It tries, in order, the exact
  # path, `<path>.prompt` and `<path>/_index.prompt`, and only accepts regular
  # files. A leading "/" and any "../" and "./" sequences are stripped first.
  # Each candidate must also lie strictly inside `prompts_dir`. A plain
  # string-prefix check is not enough: for names like "sub/.." the path expands
  # to `prompts_dir` itself, and `path <> ".prompt"` would then be the sibling
  # file `<prompts_dir>.prompt`. Returns `{nil, nil, nil}` on a miss.
  defp do_resolve(prompts_dir, name_str) do
    safe_name =
      name_str
      |> String.trim_leading("/")
      |> String.replace("../", "")
      |> String.replace("./", "")

    path = Path.expand(Path.join(prompts_dir, safe_name))

    [path, path <> ".prompt", Path.join(path, "_index.prompt")]
    |> Enum.find(&(inside_dir?(&1, prompts_dir) and File.regular?(&1)))
    |> read_prompt_file()
  end

  # True when `path` is strictly below `dir`. Both are expected to be expanded.
  defp inside_dir?(path, dir) do
    String.starts_with?(path, String.trim_trailing(dir, "/") <> "/")
  end

  # Reads a resolved prompt file into the `{content, mtime, path}` triple.
  defp read_prompt_file(nil), do: {nil, nil, nil}
  defp read_prompt_file(path), do: {File.read!(path), File.stat!(path).mtime, path}

  # Returns `found` when its declared `def.version` has major `major`.
  # Otherwise the archive is consulted. Content that does not parse into an
  # `init.def` block counts as major version 1. For any other requested major,
  # it raises without consulting the archive.
  defp check_major_version(major, {content, _mtime, _path} = found, prompts_dir, name_str) do
    case content |> Lexer.tokenize() |> Parser.parse() do
      {:ok, %{init: %{def: def_map}}} ->
        if major_from_version(def_map[:version]) == major do
          found
        else
          find_archive_or_raise(major, prompts_dir, name_str)
        end

      _ when major == 1 ->
        found

      _ ->
        raise "prompt_not_found: #{name_str} with major version #{major}"
    end
  end

  # Looks up `<prompts_dir>/<dirname(name)>/archive/<base>_v<major>.prompt`.
  # This uses the name as given, not resolved against any current_dir. A
  # trailing ".prompt" on the name is dropped, so "foo.prompt" maps to
  # `archive/foo_v<major>.prompt`. The archive file must lie inside
  # `prompts_dir`. Raises `RuntimeError` when it does not exist.
  defp find_archive_or_raise(major, prompts_dir, name_str) do
    base = name_str |> Path.basename() |> Path.rootname(".prompt")

    archive_path =
      [prompts_dir, Path.dirname(name_str), "archive", "#{base}_v#{major}.prompt"]
      |> Path.join()
      |> Path.expand()

    if inside_dir?(archive_path, prompts_dir) and File.regular?(archive_path) do
      read_prompt_file(archive_path)
    else
      raise "prompt_not_found: #{name_str} with major version #{major}"
    end
  end

  # ---------------------------------------------------------------------------
  # Compilation helpers
  # ---------------------------------------------------------------------------

  # Builds the compile cache key `{key_string, params_hash, content_hash,
  # annotated}`. Params with `nil` values are ignored, so omitted and
  # explicitly-nil params share a key. `:inline` needs no special case because
  # `to_string(:inline)` is "inline".
  @doc false
  def cache_key_for_compile(prompt_key, params, content, annotated) do
    compile_params =
      for {key, value} <- params, not is_nil(value), into: %{}, do: {key, value}

    {to_string(prompt_key), :erlang.phash2(compile_params), :erlang.phash2(content), annotated}
  end

  # Rough token estimate: 4/3 of the whitespace-separated word count, using
  # integer division. Accepts a binary or iodata.
  @doc false
  def count_tokens(text) do
    word_count =
      text
      |> IO.iodata_to_binary()
      |> String.trim()
      |> String.split()
      |> length()

    div(word_count * 4, 3)
  end

  # Removes `[[section:...]]` opening markers (plus one following newline) and
  # `[[/section]]` closing markers (plus one preceding newline).
  @doc false
  def strip_annotations(text) do
    text
    |> String.replace(~r/\[\[section:[^\]]+\]\]\n?/, "")
    |> String.replace(~r/\n?\[\[\/section\]\]/, "")
  end

  # Puts `key => value` into `map` unless `value` is nil.
  @doc false
  def maybe_put(map, _key, nil), do: map
  def maybe_put(map, key, value), do: Map.put(map, key, value)

  # Derives a response schema from the first response block in `body`.
  # Returns nil when there is none or when `body` is not a list.
  @doc false
  def extract_response_contract(body) when is_list(body) do
    case ResponseCollector.collect_response_blocks(body) do
      [{content, _line} | _] -> ResponseCollector.derive_schema(content)
      [] -> nil
    end
  end

  def extract_response_contract(_), do: nil

  # Returns the major version. `nil` maps to 1, integers are returned as-is and
  # floats are truncated. Strings are trimmed and any leading "v"/"V" is
  # dropped before the leading integer is parsed. Unparseable strings map to 1.
  @doc false
  def major_from_version(nil), do: 1
  def major_from_version(version) when is_integer(version), do: version
  def major_from_version(version) when is_float(version), do: trunc(version)

  def major_from_version(version) when is_binary(version) do
    normalized =
      version
      |> String.trim()
      |> String.trim_leading("v")
      |> String.trim_leading("V")

    case Integer.parse(normalized) do
      {major, _rest} -> major
      :error -> 1
    end
  end

  # Prefixes every non-empty line with `indent`, except lines starting with
  # "[[" (annotation markers). Returns `content` unchanged for an empty indent,
  # and iodata otherwise.
  @doc false
  def indent_content(content, indent) when is_binary(content) do
    if indent == "" do
      content
    else
      content
      |> String.split("\n")
      |> Enum.map(fn
        "" -> ""
        "[[" <> _ = line -> line
        line -> [indent, line]
      end)
      |> Enum.intersperse("\n")
    end
  end

  def indent_content(content, indent) do
    indent_content(IO.iodata_to_binary(content), indent)
  end

  # Skips validation when there are no declarations.
  @doc false
  def validate_params_if_needed(_params, declarations) when declarations == %{}, do: :ok

  def validate_params_if_needed(params, declarations) do
    Validator.validate_params(params, declarations)
  end

  # Fills in declared defaults for params the caller did not supply. A leading
  # "@" is stripped from declaration names. Declaration names may be strings or
  # atoms. The default is stored under an existing atom when one exists, and
  # under the string name otherwise. A param already present under either form
  # wins. `nil` defaults are ignored.
  @doc false
  def apply_defaults(params, declarations) do
    Enum.reduce(declarations, params, fn {name, spec}, acc ->
      default = Map.get(spec, :default)
      clean_name = name |> to_string() |> String.trim_leading("@")
      clean_key = safe_to_atom(clean_name)

      cond do
        is_nil(default) -> acc
        Map.has_key?(acc, clean_key) or Map.has_key?(acc, clean_name) -> acc
        true -> Map.put(acc, clean_key, default)
      end
    end)
  end

  # ---------------------------------------------------------------------------
  # Error decoration
  # ---------------------------------------------------------------------------

  # Appends "(near line N)" using the last token's `:line`, unless the message
  # already mentions a line.
  @doc false
  def extract_error_with_line(message, tokens) do
    if mentions_line?(message) do
      message
    else
      case List.last(tokens) do
        %{line: line} when line not in [nil, false] -> "#{message} (near line #{line})"
        _ -> message
      end
    end
  end

  # For validation errors of the form "<kind>: @var", appends "(at line N)"
  # with the first line of `content` that references `@var`. Messages that
  # already mention a line, and non-binary `content`, are returned unchanged.
  @doc false
  def add_line_info_to_validation_error(message, content) when is_binary(content) do
    if mentions_line?(message) do
      message
    else
      with [_, _type, var_name] <-
             Regex.run(
               ~r/(unknown_variable|missing_param|invalid_type|invalid_enum|out_of_range):\s*@?(\w+)/,
               message
             ),
           line_num when is_integer(line_num) <- find_var_line_number(content, var_name) do
        "#{message} (at line #{line_num})"
      else
        _ -> message
      end
    end
  end

  def add_line_info_to_validation_error(message, _content), do: message

  # True when the message mentions "line"/"lines" as a standalone term. A bare
  # substring check would also fire on "newline", "inline" or "pipeline".
  defp mentions_line?(message) do
    Regex.match?(~r/(?<![A-Za-z])lines?(?![A-Za-z])/, message)
  end

  # 1-based number of the first line that references `@var_name` as a whole
  # word, so "@foo" does not match "@foobar". Returns nil when none does.
  defp find_var_line_number(content, var_name) do
    pattern = ~r/@#{Regex.escape(var_name)}\b/

    content
    |> String.split(["\r\n", "\n"])
    |> Enum.with_index(1)
    |> Enum.find_value(fn {line, line_num} ->
      if Regex.match?(pattern, line), do: line_num
    end)
  end

  # ---------------------------------------------------------------------------
  # Params and keys
  # ---------------------------------------------------------------------------

  # Converts a binary to an already-existing atom. Otherwise the binary is
  # returned unchanged, so no new atoms are created from input.
  @doc false
  def ensure_atom(k) when is_atom(k), do: k
  def ensure_atom(k) when is_binary(k), do: safe_to_atom(k)

  # Fetches a param by string key, falling back to the matching atom key.
  # Falsy values such as `false` are returned as stored.
  @doc false
  def get_param(params, key_str) do
    case Map.get(params, key_str) do
      nil -> get_atom_keyed_param(params, key_str)
      value -> value
    end
  end

  defp get_atom_keyed_param(params, key_str) when is_binary(key_str) do
    case ensure_atom(key_str) do
      key when is_atom(key) -> Map.get(params, key)
      _not_an_existing_atom -> nil
    end
  end

  defp get_atom_keyed_param(_params, _key), do: nil

  # ---------------------------------------------------------------------------
  # Response contracts
  # ---------------------------------------------------------------------------

  # Maps an Elixir value to the JSON-ish type name used by response contracts.
  @doc false
  def infer_type(v) when is_binary(v), do: "string"
  def infer_type(v) when is_integer(v), do: "number"
  def infer_type(v) when is_float(v), do: "number"
  def infer_type(v) when is_boolean(v), do: "boolean"
  def infer_type(v) when is_nil(v), do: "null"
  def infer_type(v) when is_list(v), do: "array"
  def infer_type(v) when is_map(v), do: "object"
  def infer_type(_), do: "unknown"

  # Collects the `@name` references found in the binary `type` strings of the
  # fragment and param specs in `init`. Non-binary types are skipped.
  @doc false
  def extract_init_vars(init) do
    init.fragments
    |> Enum.concat(init.params)
    |> Enum.flat_map(fn {_name, spec} -> type_vars(spec.type) end)
    |> MapSet.new()
  end

  defp type_vars(type) when is_binary(type) do
    ~r/@(\w+)/
    |> Regex.scan(type, capture: :all_but_first)
    |> List.flatten()
  end

  defp type_vars(_type), do: []

  # True when a value of `actual` type satisfies the `expected` type. Every
  # known type matches itself, and "string" also accepts "null".
  @doc false
  def type_matches?("string", "null"), do: true
  def type_matches?(type, type) when type in ~w(string number boolean null array object), do: true
  def type_matches?(_expected, _actual), do: false

  # Extracts non-nil defaults from metadata params, keyed by the existing atom
  # for the name. String names with no existing atom are kept as strings, and
  # atom names are kept as atoms.
  @doc false
  def extract_defaults_from_metadata(%{params: params}) when is_map(params) do
    params
    |> Enum.reject(fn {_name, spec} -> is_nil(Map.get(spec, :default)) end)
    |> Map.new(fn {name, spec} -> {safe_to_atom(name), spec.default} end)
  end

  def extract_defaults_from_metadata(_), do: %{}

  # Validates `response` against `contract`, a map of field => spec where spec
  # may carry `:required` and `:type` (default "string"). When `strict` is
  # truthy, fields not in the contract are reported as well. Returns `:ok` or
  # `{:error, "msg1; msg2"}`.
  @doc false
  def validate_response(response, contract, strict) do
    field_errors =
      Enum.flat_map(contract, fn {field, field_spec} ->
        validate_field(response, field, field_spec)
      end)

    extra_errors = if strict, do: unexpected_fields_errors(response, contract), else: []

    case field_errors ++ extra_errors do
      [] -> :ok
      errors -> {:error, Enum.join(errors, "; ")}
    end
  end

  # Returns the list of error messages for a single contract field (0 or 1).
  defp validate_field(response, field, field_spec) do
    required = Map.get(field_spec, :required, false)
    expected_type = Map.get(field_spec, :type) || "string"

    with :ok <- check_required_field(required, field, response),
         :ok <- check_field_type(field, expected_type, response) do
      []
    else
      {:error, msg} -> [msg]
    end
  end

  # Reports response keys that the contract does not declare.
  defp unexpected_fields_errors(response, contract) do
    case Map.keys(response) -- Map.keys(contract) do
      [] -> []
      extra_fields -> ["Unexpected fields: #{Enum.join(extra_fields, ", ")}"]
    end
  end

  # Only an explicit `required: true` makes a field mandatory.
  defp check_required_field(true, field, response) do
    if Map.has_key?(response, field), do: :ok, else: {:error, "Missing required field: #{field}"}
  end

  defp check_required_field(_required, _field, _response), do: :ok

  # Absent fields pass. Present fields must have a matching type.
  defp check_field_type(field, expected_type, response) do
    case Map.fetch(response, field) do
      {:ok, value} ->
        actual_type = infer_type(value)

        if type_matches?(expected_type, actual_type) do
          :ok
        else
          {:error, "Field #{field} has type #{actual_type}, expected #{expected_type}"}
        end

      :error ->
        :ok
    end
  end

  # Atoms pass through. Binaries become existing atoms when possible and stay
  # binaries otherwise. Anything else maps to nil.
  defp safe_to_atom(atom) when is_atom(atom), do: atom

  defp safe_to_atom(binary) when is_binary(binary) do
    String.to_existing_atom(binary)
  rescue
    ArgumentError -> binary
  end

  defp safe_to_atom(_), do: nil
end