lib/nx/defn/token.ex

defmodule Nx.Defn.Token do
  @moduledoc """
  A `defn` token used by hooks.

  ## Documentation for compilers

  The token has a `hooks` field as a list of maps of the shape:

      %{
        expr: Nx.Tensor.t | Nx.Container.t,
        name: atom(),
        callback: (Nx.Tensor.t | Nx.Container.t -> term()) | nil
      }

  Hooks are stored in reverse declaration order: the most recently
  added hook is at the head of the list and the first hook declared
  is at the end.

  The `hooks` field must only be accessed by `defn` compilers.
  """

  # NOTE(review): `expr` is typed as term() to avoid depending on remote
  # types here; per the moduledoc it is an Nx tensor or container.
  @type t :: %__MODULE__{
          hooks: [%{expr: term(), name: atom(), callback: (term() -> term()) | nil}]
        }

  defstruct hooks: []

  @doc false
  # Returns a token with no hooks attached.
  @spec new() :: t()
  def new, do: %__MODULE__{}

  @doc false
  # Attaches a hook named `name` over `expr` and returns the updated token.
  #
  # `callback` must be a 1-arity function (matching the documented
  # `callback` shape) or `nil`. Functions of any other arity are rejected
  # here with a FunctionClauseError instead of being stored silently.
  @spec add_hook(t(), term(), atom(), (term() -> term()) | nil) :: t()
  def add_hook(%__MODULE__{hooks: hooks} = token, expr, name, callback)
      when is_atom(name) and (is_function(callback, 1) or is_nil(callback)) do
    hook = %{expr: expr, name: name, callback: callback}
    # Prepending is O(1) and produces the reverse declaration order
    # described in the moduledoc.
    %{token | hooks: [hook | hooks]}
  end

  defimpl Inspect do
    import Inspect.Algebra

    # Renders as `#Nx.Defn.Token<[names]>`, listing hook names in storage
    # order (most recently added first).
    def inspect(%{hooks: hooks}, opts) do
      concat([
        color("#Nx.Defn.Token<", :map, opts),
        to_doc(Enum.map(hooks, & &1.name), opts),
        color(">", :map, opts)
      ])
    end
  end
end