lib/nx/heatmap.ex

defmodule Nx.Heatmap do
  @moduledoc """
  Provides a heatmap that is printed using ANSI colors
  in the terminal.
  """

  @doc false
  defstruct [:tensor, opts: []]

  @behaviour Access

  @impl true
  def fetch(%Nx.Heatmap{tensor: tensor} = hm, value) do
    case Access.fetch(tensor, value) do
      {:ok, %Nx.Tensor{shape: {}} = tensor} -> {:ok, tensor}
      {:ok, tensor} -> {:ok, put_in(hm.tensor, tensor)}
      :error -> :error
    end
  end

  @impl true
  def get_and_update(hm, key, fun) do
    {get, tensor} = Access.get_and_update(hm.tensor, key, fun)
    {get, put_in(hm.tensor, tensor)}
  end

  @impl true
  def pop(hm, key) do
    {pop, tensor} = Access.pop(hm.tensor, key)
    {pop, put_in(hm.tensor, tensor)}
  end

  defimpl Inspect do
    import Inspect.Algebra

    @min_mono256 232
    @max_mono256 255
    @mono256 @min_mono256..@max_mono256

    def inspect(%{tensor: tensor, opts: heatmap_opts}, opts) do
      %{shape: shape, names: names, type: type} = tensor

      open = color("[", :list, opts)
      sep = color(",", :list, opts)
      close = color("]", :list, opts)

      data = data(tensor, heatmap_opts, opts, {open, sep, close})
      type = color(Nx.Type.to_string(type), :atom, opts)
      shape = Nx.Shape.to_algebra(shape, names, open, close)

      color("#Nx.Heatmap<", :map, opts)
      |> concat(nest(concat([line(), type, shape, line(), data]), 2))
      |> concat(color("\n>", :map, opts))
    end

    defp data(tensor, heatmap_opts, opts, doc) do
      whitespace = Keyword.get(heatmap_opts, :ansi_whitespace, "\u3000")

      {entry_fun, line_fun, nf_fun} =
        if Keyword.get_lazy(heatmap_opts, :ansi_enabled, &IO.ANSI.enabled?/0) do
          scale = Enum.count(@mono256) - 1

          entry_fun = fn range ->
            index = range |> Kernel.*(scale) |> round()
            color = Enum.fetch!(@mono256, index)
            [IO.ANSI.color_background(color), whitespace]
          end

          {entry_fun, &IO.iodata_to_binary([&1 | IO.ANSI.reset()]),
           &non_finite_ansi(&1, whitespace)}
        else
          {&Integer.to_string(&1 |> Kernel.*(9) |> round()), &IO.iodata_to_binary/1,
           &non_finite_ascii/1}
        end

      render(tensor, opts, doc, entry_fun, line_fun, nf_fun)
    end

    defp non_finite_ascii(:infinity), do: ?+
    defp non_finite_ascii(:neg_infinity), do: ?-
    defp non_finite_ascii(:nan), do: ?x

    defp non_finite_ansi(:infinity, _ws), do: [IO.ANSI.color_background(@max_mono256), "∞"]
    defp non_finite_ansi(:neg_infinity, _ws), do: [IO.ANSI.color_background(@min_mono256), "∞"]
    defp non_finite_ansi(:nan, ws), do: [IO.ANSI.red_background(), ws]

    defp render(%{shape: {size}, type: type} = tensor, _opts, _doc, entry_fun, line_fun, nf_fun) do
      data = Nx.to_flat_list(tensor)
      min = type |> Nx.Constants.min_finite() |> Nx.to_number()
      max = type |> Nx.Constants.max_finite() |> Nx.to_number()
      {data, [], min, max} = take_min_max(data, size, max, min, [])
      base = if max == min, do: 1, else: max - min

      data
      |> Enum.map(fn
        elem when is_number(elem) -> entry_fun.((elem - min) / base)
        elem when is_atom(elem) -> nf_fun.(elem)
      end)
      |> line_fun.()
    end

    defp render(%{shape: shape, type: type} = tensor, opts, doc, entry_fun, line_fun, nf_fun) do
      {dims, [rows, cols]} = shape |> Tuple.to_list() |> Enum.split(-2)

      limit = opts.limit
      list_opts = if limit == :infinity, do: [], else: [limit: rows * cols * limit + 1]
      data = Nx.to_flat_list(tensor, list_opts)

      min = type |> Nx.Constants.min_finite() |> Nx.to_number()
      max = type |> Nx.Constants.max_finite() |> Nx.to_number()
      state = {rows, cols, entry_fun, line_fun, nf_fun, min, max}
      {data, _rest, _limit} = chunk(dims, data, limit, state, doc)
      data
    end

    defp take_min_max(rest, 0, min, max, acc),
      do: {Enum.reverse(acc), rest, min, max}

    defp take_min_max([head | tail], count, min, max, acc) when is_atom(head),
      do: take_min_max(tail, count - 1, min, max, [head | acc])

    defp take_min_max([head | tail], count, min, max, acc),
      do: take_min_max(tail, count - 1, min(min, head), max(max, head), [head | acc])

    defp chunk([], acc, limit, {rows, cols, entry_fun, line_fun, nf_fun, min, max}, _docs) do
      {acc, rest, min, max} = take_min_max(acc, rows * cols, max, min, [])
      base = if max == min, do: 1, else: max - min

      {[], doc} =
        Enum.reduce(1..rows, {acc, empty()}, fn _, {acc, doc} ->
          {line, acc} =
            Enum.map_reduce(1..cols, acc, fn
              _, [elem | acc] when is_number(elem) -> {entry_fun.((elem - min) / base), acc}
              _, [elem | acc] when is_atom(elem) -> {nf_fun.(elem), acc}
            end)

          doc = concat(doc, concat(line(), line_fun.(line)))
          {acc, doc}
        end)

      if limit == :infinity, do: {doc, rest, limit}, else: {doc, rest, limit - 1}
    end

    defp chunk([dim | dims], data, limit, rcw, {open, sep, close} = docs) do
      {acc, rest, limit} =
        chunk_each(dim, data, [], limit, fn chunk, limit ->
          chunk(dims, chunk, limit, rcw, docs)
        end)

      doc =
        if(dims == [], do: open, else: concat(open, line()))
        |> concat(concat(Enum.intersperse(acc, concat(sep, line()))))
        |> nest(2)
        |> concat(line())
        |> concat(close)

      {doc, rest, limit}
    end

    defp chunk_each(0, data, acc, limit, _fun) do
      {Enum.reverse(acc), data, limit}
    end

    defp chunk_each(_dim, data, acc, 0, _fun) do
      {Enum.reverse(["..." | acc]), data, 0}
    end

    defp chunk_each(dim, data, acc, limit, fun) do
      {doc, rest, limit} = fun.(data, limit)
      chunk_each(dim - 1, rest, [doc | acc], limit, fun)
    end
  end
end