defmodule Jido.Chat.Markdown do
@moduledoc """
Canonical Markdown AST and formatting helpers.
"""
alias Jido.Chat.Markdown.Node
alias Jido.Chat.Wire
@schema Zoi.struct(
__MODULE__,
%{
root: Zoi.struct(Node),
metadata: Zoi.map() |> Zoi.default(%{})
},
coerce: true
)
@type t :: unquote(Zoi.type_spec(@schema))
@type markdown_node :: Node.t()
@type node_input :: markdown_node() | map() | String.t()
@enforce_keys Zoi.Struct.enforce_keys(@schema)
defstruct Zoi.Struct.struct_fields(@schema)
@doc "Returns the schema for Markdown documents."
def schema, do: @schema
@doc "Creates a canonical Markdown document."
@spec new(t() | map() | String.t() | [node_input()]) :: t()
def new(%__MODULE__{} = markdown), do: markdown
def new(value) when is_binary(value), do: parse(value)
def new(value) when is_list(value), do: root(value)
def new(attrs) when is_map(attrs) do
attrs
|> normalize_root()
|> then(&Jido.Chat.Schema.parse!(__MODULE__, @schema, &1))
end
@doc "Builds a Markdown document with a root node."
@spec root([node_input()], keyword() | map()) :: t()
def root(children, opts \\ []) when is_list(children) do
opts = normalize_opts(opts)
new(%{
root: Node.new(%{type: :root, children: children}),
metadata: opts[:metadata] || opts["metadata"] || %{}
})
end
@doc "Builds a paragraph node."
@spec paragraph([node_input()] | String.t()) :: markdown_node()
def paragraph(value), do: Node.new(%{type: :paragraph, children: normalize_children(value)})
@doc "Builds a text node."
@spec text(String.t()) :: markdown_node()
def text(value) when is_binary(value), do: Node.new(%{type: :text, text: value})
@doc "Builds a strong node."
@spec strong([node_input()] | String.t()) :: markdown_node()
def strong(value), do: Node.new(%{type: :strong, children: normalize_children(value)})
@doc "Builds an emphasis node."
@spec emphasis([node_input()] | String.t()) :: markdown_node()
def emphasis(value), do: Node.new(%{type: :emphasis, children: normalize_children(value)})
@doc "Builds a link node."
@spec link([node_input()] | String.t(), String.t(), keyword() | map()) :: markdown_node()
def link(label, url, opts \\ []) when is_binary(url) do
opts = normalize_opts(opts)
Node.new(%{
type: :link,
url: url,
children: normalize_children(label),
metadata: opts[:metadata] || opts["metadata"] || %{}
})
end
@doc "Builds an inline code node."
@spec code(String.t()) :: markdown_node()
def code(value) when is_binary(value), do: Node.new(%{type: :code, text: value})
@doc "Builds a fenced code block node."
@spec code_block(String.t(), String.t() | nil, keyword() | map()) :: markdown_node()
def code_block(value, language \\ nil, opts \\ []) when is_binary(value) do
opts = normalize_opts(opts)
Node.new(%{
type: :code_block,
text: value,
language: language,
metadata: opts[:metadata] || opts["metadata"] || %{}
})
end
@doc "Builds a heading node."
@spec heading(pos_integer(), [node_input()] | String.t()) :: markdown_node()
def heading(level, value) when is_integer(level) and level >= 1 and level <= 6 do
Node.new(%{type: :heading, level: level, children: normalize_children(value)})
end
@doc "Builds a list node."
@spec list([node_input()], keyword() | map()) :: markdown_node()
def list(items, opts \\ []) when is_list(items) do
opts = normalize_opts(opts)
children =
Enum.map(items, fn
%Node{type: :list_item} = item -> item
item -> list_item(item)
end)
Node.new(%{
type: :list,
ordered: opts[:ordered] || opts["ordered"] || false,
start: opts[:start] || opts["start"],
children: children,
metadata: opts[:metadata] || opts["metadata"] || %{}
})
end
@doc "Builds a list item node."
@spec list_item(node_input() | [node_input()]) :: markdown_node()
def list_item(value), do: Node.new(%{type: :list_item, children: normalize_children(value)})
@doc "Builds a blockquote node."
@spec blockquote([node_input()] | String.t()) :: markdown_node()
def blockquote(value), do: Node.new(%{type: :blockquote, children: normalize_children(value)})
@doc "Builds a table node."
@spec table([node_input()], keyword() | map()) :: markdown_node()
def table(rows, opts \\ []) when is_list(rows) do
opts = normalize_opts(opts)
Node.new(%{
type: :table,
children: Enum.map(rows, &normalize_row/1),
metadata: opts[:metadata] || opts["metadata"] || %{}
})
end
@doc "Builds a table row node."
@spec row([node_input()]) :: markdown_node()
def row(cells) when is_list(cells) do
Node.new(%{type: :table_row, children: Enum.map(cells, &normalize_cell/1)})
end
@doc "Builds a table cell node."
@spec cell([node_input()] | String.t()) :: markdown_node()
def cell(value), do: Node.new(%{type: :table_cell, children: normalize_children(value)})
@doc "Builds a divider node."
@spec divider() :: markdown_node()
def divider, do: Node.new(%{type: :divider})
@doc "Parses plain Markdown text into a canonical AST."
@spec parse(String.t()) :: t()
def parse(markdown) when is_binary(markdown) do
markdown
|> String.split("\n", trim: false)
|> parse_lines([])
|> Enum.reverse()
|> root(metadata: %{source: :parse})
end
@doc "Stringifies a Markdown document or node back to Markdown text."
@spec stringify(t() | markdown_node() | [node_input()] | String.t() | nil) :: String.t()
def stringify(nil), do: ""
def stringify(value) when is_binary(value), do: value
def stringify(%__MODULE__{root: root}), do: render_children(root.children, "\n\n")
def stringify(%Node{} = node), do: render_node(node)
def stringify(nodes) when is_list(nodes),
do: nodes |> Enum.map_join("\n\n", &(&1 |> normalize_node() |> render_node()))
@doc "Extracts plain text from a Markdown document or node."
@spec plain_text(t() | markdown_node() | [node_input()] | String.t() | nil) :: String.t()
def plain_text(nil), do: ""
def plain_text(value) when is_binary(value), do: value
def plain_text(%__MODULE__{root: root}), do: render_plain_children(root.children, "\n\n")
def plain_text(%Node{} = node), do: render_plain_node(node)
def plain_text(nodes) when is_list(nodes) do
nodes
|> Enum.map_join("\n\n", &(&1 |> normalize_node() |> render_plain_node()))
end
@doc "Walks and transforms every node in the Markdown AST."
@spec walk(t() | markdown_node(), (markdown_node() -> markdown_node())) ::
t() | markdown_node()
def walk(%__MODULE__{root: root} = markdown, fun) when is_function(fun, 1) do
%{markdown | root: walk_node(root, fun)}
end
def walk(%Node{} = node, fun) when is_function(fun, 1), do: walk_node(node, fun)
@doc "Renders the first table in a Markdown document, or a given table node, as ASCII."
@spec table_to_ascii(t() | markdown_node()) :: String.t()
def table_to_ascii(%__MODULE__{root: root}), do: root |> first_table() |> table_to_ascii()
def table_to_ascii(nil), do: ""
def table_to_ascii(%Node{type: :table, children: rows}) do
rows =
Enum.map(rows, fn %Node{children: cells} ->
Enum.map(cells, &(&1 |> render_plain_node() |> String.trim()))
end)
widths =
rows
|> Enum.zip_with(fn column ->
column
|> Enum.map(&String.length/1)
|> Enum.max(fn -> 0 end)
end)
case rows do
[] ->
""
[header | body] ->
divider =
widths
|> Enum.map_join("-+-", &String.duplicate("-", max(&1, 1)))
[
render_ascii_row(header, widths),
divider
| Enum.map(body, &render_ascii_row(&1, widths))
]
|> Enum.join("\n")
end
end
def table_to_ascii(%Node{} = node), do: node |> first_table() |> table_to_ascii()
@doc "Serializes Markdown into a plain map with type markers."
@spec to_map(t()) :: map()
def to_map(%__MODULE__{} = markdown) do
markdown
|> Map.from_struct()
|> Map.update!(:root, &Node.to_map/1)
|> Wire.to_plain()
|> Map.put("__type__", "markdown")
end
@doc "Builds Markdown from serialized map data."
@spec from_map(map()) :: t()
def from_map(map) when is_map(map), do: map |> Map.drop(["__type__", :__type__]) |> new()
defp normalize_root(attrs) do
root =
case attrs[:root] || attrs["root"] do
%Node{} = root ->
root
%{} = root ->
Node.new(root)
nil ->
children = attrs[:nodes] || attrs["nodes"] || []
Node.new(%{type: :root, children: children})
end
attrs
|> Map.delete("root")
|> Map.put(:root, root)
end
defp parse_lines([], acc), do: acc
defp parse_lines(["" | rest], acc), do: parse_lines(rest, acc)
defp parse_lines([line | rest], acc) do
cond do
String.starts_with?(line, "```") ->
{node, remainder} = parse_code_block(rest, String.trim_leading(line, "```"))
parse_lines(remainder, [node | acc])
heading_line?(line) ->
{level, text} = parse_heading(line)
parse_lines(rest, [heading(level, text) | acc])
table_header?(line, rest) ->
{node, remainder} = parse_table([line | rest])
parse_lines(remainder, [node | acc])
list_line?(line) ->
{node, remainder} = parse_list([line | rest])
parse_lines(remainder, [node | acc])
String.starts_with?(String.trim_leading(line), "> ") ->
{node, remainder} = parse_blockquote([line | rest])
parse_lines(remainder, [node | acc])
String.trim(line) == "---" ->
parse_lines(rest, [divider() | acc])
true ->
{node, remainder} = parse_paragraph([line | rest])
parse_lines(remainder, [node | acc])
end
end
defp parse_code_block(lines, language) do
{body, remainder} = Enum.split_while(lines, &(not String.starts_with?(&1, "```")))
remainder = if remainder == [], do: [], else: tl(remainder)
{code_block(Enum.join(body, "\n"), blank_to_nil(String.trim(language))), remainder}
end
defp parse_heading(line) do
trimmed = String.trim_leading(line)
marks = trimmed |> String.graphemes() |> Enum.take_while(&(&1 == "#")) |> length()
{marks, trimmed |> String.trim_leading("#") |> String.trim()}
end
defp parse_table([header, _separator | rest]) do
{rows, remainder} =
rest
|> Enum.split_while(fn line -> String.contains?(line, "|") and String.trim(line) != "" end)
header_row = header |> split_table_row() |> row()
body_rows = Enum.map(rows, &(&1 |> split_table_row() |> row()))
{table([header_row | body_rows]), remainder}
end
defp parse_list(lines) do
{items, remainder} = Enum.split_while(lines, &list_line?/1)
ordered? =
case items do
[first | _] -> Regex.match?(~r/^\s*\d+\.\s+/, first)
_ -> false
end
items =
Enum.map(items, fn line ->
line
|> String.trim()
|> String.replace(~r/^([-*]|\d+\.)\s+/, "")
|> list_item()
end)
{list(items, ordered: ordered?), remainder}
end
defp parse_blockquote(lines) do
{quoted, remainder} =
Enum.split_while(lines, fn line ->
trimmed = String.trim_leading(line)
String.starts_with?(trimmed, "> ")
end)
text =
quoted
|> Enum.map(fn line -> line |> String.trim_leading() |> String.trim_leading("> ") end)
|> Enum.join("\n")
{blockquote([paragraph(text)]), remainder}
end
defp parse_paragraph(lines) do
{paragraph_lines, remainder} =
Enum.split_while(lines, fn line ->
trimmed = String.trim(line)
trimmed != "" and not heading_line?(line) and not list_line?(line) and
not String.starts_with?(trimmed, "> ") and not String.starts_with?(trimmed, "```") and
not table_header?(line, remainder_preview(lines, line))
end)
text =
paragraph_lines
|> Enum.map(&String.trim/1)
|> Enum.join(" ")
{paragraph(text), remainder}
end
defp render_node(%Node{type: :root, children: children}), do: render_children(children, "\n\n")
defp render_node(%Node{type: :paragraph, children: children}), do: render_children(children, "")
defp render_node(%Node{type: :text, text: text}), do: text || ""
defp render_node(%Node{type: :strong, children: children}),
do: "**#{render_children(children, "")}**"
defp render_node(%Node{type: :emphasis, children: children}),
do: "_#{render_children(children, "")}_"
defp render_node(%Node{type: :link, url: url, children: children}),
do: "[#{render_children(children, "")}](#{url})"
defp render_node(%Node{type: :code, text: text}), do: "`#{text || ""}`"
defp render_node(%Node{type: :code_block, text: text, language: language}) do
"```#{language || ""}\n#{text || ""}\n```"
end
defp render_node(%Node{type: :heading, level: level, children: children}) do
"#{String.duplicate("#", level || 1)} #{render_children(children, "")}"
end
defp render_node(%Node{type: :list, ordered: ordered?, start: start, children: children}) do
start = start || 1
children
|> Enum.with_index(start)
|> Enum.map_join("\n", fn {%Node{} = child, index} ->
marker = if ordered?, do: "#{index}. ", else: "- "
marker <> render_list_item(child)
end)
end
defp render_node(%Node{type: :list_item} = node), do: render_list_item(node)
defp render_node(%Node{type: :blockquote, children: children}) do
children
|> render_children("\n")
|> String.split("\n")
|> Enum.map_join("\n", &("> " <> &1))
end
defp render_node(%Node{type: :table} = node), do: render_markdown_table(node)
defp render_node(%Node{type: :table_row, children: children}) do
children
|> Enum.map_join(" | ", &render_plain_node/1)
end
defp render_node(%Node{type: :table_cell, children: children}),
do: render_children(children, "")
defp render_node(%Node{type: :divider}), do: "---"
defp render_plain_node(%Node{type: :root, children: children}),
do: render_plain_children(children, "\n\n")
defp render_plain_node(%Node{type: :paragraph, children: children}),
do: render_plain_children(children, "")
defp render_plain_node(%Node{type: :text, text: text}), do: text || ""
defp render_plain_node(%Node{type: :strong, children: children}),
do: render_plain_children(children, "")
defp render_plain_node(%Node{type: :emphasis, children: children}),
do: render_plain_children(children, "")
defp render_plain_node(%Node{type: :link, children: children}),
do: render_plain_children(children, "")
defp render_plain_node(%Node{type: :code, text: text}), do: text || ""
defp render_plain_node(%Node{type: :code_block, text: text}), do: text || ""
defp render_plain_node(%Node{type: :heading, children: children}),
do: render_plain_children(children, "")
defp render_plain_node(%Node{type: :list, children: children}) do
children
|> Enum.map_join("\n", &render_plain_node/1)
end
defp render_plain_node(%Node{type: :list_item, children: children}),
do: render_plain_children(children, " ")
defp render_plain_node(%Node{type: :blockquote, children: children}),
do: render_plain_children(children, "\n")
defp render_plain_node(%Node{type: :table} = node), do: table_to_ascii(node)
defp render_plain_node(%Node{type: :table_row, children: children}) do
children
|> Enum.map_join(" | ", &render_plain_node/1)
end
defp render_plain_node(%Node{type: :table_cell, children: children}),
do: render_plain_children(children, "")
defp render_plain_node(%Node{type: :divider}), do: "---"
defp render_children(children, separator),
do: children |> Enum.map_join(separator, &(&1 |> normalize_node() |> render_node()))
defp render_plain_children(children, separator) do
children
|> Enum.map_join(separator, &(&1 |> normalize_node() |> render_plain_node()))
|> String.trim()
end
defp render_list_item(%Node{children: children}) do
children
|> Enum.map_join(" ", fn child -> child |> normalize_node() |> render_node() end)
|> String.replace("\n", "\n ")
end
defp render_markdown_table(%Node{children: []}), do: ""
defp render_markdown_table(%Node{children: [%Node{} = header | body]}) do
header_cells = Enum.map(header.children, &render_plain_node/1)
divider = Enum.map_join(header_cells, " | ", fn _ -> "---" end)
body_rows =
Enum.map_join(body, "\n", fn %Node{children: cells} ->
"| " <> Enum.map_join(cells, " | ", &render_plain_node/1) <> " |"
end)
[
"| " <> Enum.join(header_cells, " | ") <> " |",
"| " <> divider <> " |",
body_rows
]
|> Enum.reject(&(&1 == ""))
|> Enum.join("\n")
end
defp walk_node(%Node{} = node, fun) do
children = Enum.map(node.children, &walk_node(normalize_node(&1), fun))
node |> Map.put(:children, children) |> fun.()
end
defp first_table(%Node{type: :table} = node), do: node
defp first_table(%Node{children: children}) do
Enum.find_value(children, &first_table(normalize_node(&1)))
end
defp normalize_node(%Node{} = node), do: node
defp normalize_node(node), do: Node.normalize(node)
defp normalize_children(value) when is_binary(value), do: [text(value)]
defp normalize_children(value) when is_list(value), do: Enum.map(value, &normalize_node/1)
defp normalize_children(value), do: [normalize_node(value)]
defp normalize_row(%Node{type: :table_row} = row), do: row
defp normalize_row(values) when is_list(values), do: row(values)
defp normalize_row(value), do: row([value])
defp normalize_cell(%Node{type: :table_cell} = cell), do: cell
defp normalize_cell(value), do: cell(value)
defp split_table_row(line) do
line
|> String.trim()
|> String.trim("|")
|> String.split("|")
|> Enum.map(&String.trim/1)
end
defp render_ascii_row(values, widths) do
values
|> Enum.zip(widths)
|> Enum.map_join(" | ", fn {value, width} -> String.pad_trailing(value, width) end)
end
defp heading_line?(line), do: Regex.match?(~r/^\s{0,3}\#{1,6}\s+.+$/, line)
defp list_line?(line) do
Regex.match?(~r/^\s*([-*]|\d+\.)\s+.+$/, line)
end
defp table_header?(line, [separator | _rest]) do
String.contains?(line, "|") and Regex.match?(~r/^\s*\|?[\s:-|]+\|?\s*$/, separator)
end
defp table_header?(_line, _rest), do: false
defp remainder_preview(lines, current) do
case Enum.drop_while(lines, &(&1 != current)) do
[_current | rest] -> rest
_ -> []
end
end
defp blank_to_nil(""), do: nil
defp blank_to_nil(value), do: value
defp normalize_opts(opts) when is_list(opts), do: Map.new(opts)
defp normalize_opts(opts) when is_map(opts), do: opts
end