lib/spark/dsl.ex

defmodule Spark.Dsl do
  @using_schema [
    single_extension_kinds: [
      type: {:list, :atom},
      default: [],
      doc:
        "The extension kinds that are allowed to have a single value. For example: `[:data_layer]`"
    ],
    many_extension_kinds: [
      type: {:list, :atom},
      default: [],
      doc:
        "The extension kinds that can have multiple values. e.g `[notifiers: [Notifier1, Notifier2]]`"
    ],
    untyped_extensions?: [
      type: :boolean,
      default: true,
      doc: "Whether or not to support an `extensions` key which contains untyped extensions"
    ],
    default_extensions: [
      type: :keyword_list,
      default: [],
      doc: """
      The extensions that are included by default, e.g. `[data_layer: Default, notifiers: [Notifier1]]`.
      Defaults for single extension kinds are overwritten if the implementor specifies a value, while defaults
      for many extension kinds are appended to whatever the implementor specifies.
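
      For instance, with the defaults above (illustrative names), `use YourSpark, data_layer: Other`
      yields `data_layer: Other`, while `use YourSpark, notifiers: [Notifier2]` yields
      `notifiers: [Notifier2, Notifier1]`.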
      """
    ],
    opt_schema: [
      type: :keyword_list,
      default: [],
      doc: """
      A schema for additional options to accept when calling `use YourSpark`
      """
    ]
  ]

  @type entity :: %Spark.Dsl.Entity{}

  @type section :: %Spark.Dsl.Section{}

  @moduledoc """
  The primary entry point for adding a DSL to a module.

  To add a DSL to a module, add `use Spark.Dsl, ...options`. The options supported with `use Spark.Dsl` are:

  #{Spark.OptionsHelpers.docs(@using_schema)}

  See the callbacks defined in this module for ways to augment the behavior/compilation of modules that use the DSL.

  ## Schemas/Data Types

  Spark DSLs use a superset of `NimbleOptions` for the `schema` that defines the options of the sections/entities of the DSL.
  For more information, see `Spark.OptionsHelpers`.
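
  ## Example

  A minimal sketch of a DSL entry point. The `MyLibrary` modules and extension names below are
  illustrative, not part of Spark:

      defmodule MyLibrary do
        use Spark.Dsl,
          single_extension_kinds: [:data_layer],
          many_extension_kinds: [:notifiers],
          default_extensions: [data_layer: MyLibrary.DataLayer.Default]
      end

  Modules can then call `use MyLibrary, data_layer: SomeDataLayer, notifiers: [SomeNotifier]`,
  and the configured extensions define the DSL available in those modules.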
  """

  @type opts :: Keyword.t()
  @type t :: map()

  @doc """
  Validate and/or modify the options. The returned options are passed to `handle_opts/1` and `handle_before_compile/1`.
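
  For example, a DSL could fill in a default option (illustrative):

      @impl Spark.Dsl
      def init(opts), do: {:ok, Keyword.put_new(opts, :notifiers, [])}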
  """
  @callback init(opts) :: {:ok, opts} | {:error, String.t() | term}

  @doc """
  Explain the DSL state of a module. The returned string (or `nil`) is included in the generated documentation of modules that use the DSL.
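
  A trivial implementation (illustrative):

      @impl Spark.Dsl
      def explain(_dsl_state, _opts) do
        "See the MyLibrary guides for more on this DSL."
      end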
  """
  @callback explain(t(), opts) :: String.t() | nil

  @doc """
  Handle options in the context of the module. Must return a `quote` block.

  If you want to persist anything in the DSL persistence layer,
  use `@persist {:key, value}` in the returned code. It can be used
  multiple times to persist multiple values.
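
  A sketch of an implementation (the `:my_setting` key is illustrative):

      @impl Spark.Dsl
      def handle_opts(opts) do
        quote do
          @persist {:my_setting, unquote(opts[:my_setting])}
        end
      end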
  """
  @callback handle_opts(Keyword.t()) :: Macro.t()
  @doc """
  Handle options in the context of the module, after all extensions have been processed. Must return a `quote` block.
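
  A sketch of an implementation (the injected function is illustrative):

      @impl Spark.Dsl
      def handle_before_compile(_opts) do
        quote do
          def spark_dsl?, do: true
        end
      end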
  """
  @callback handle_before_compile(Keyword.t()) :: Macro.t()

  defmacro __using__(opts) do
    opts = Spark.OptionsHelpers.validate!(opts, @using_schema)

    their_opt_schema =
      Enum.map(opts[:single_extension_kinds], fn extension_kind ->
        {extension_kind, type: :atom, default: opts[:default_extensions][extension_kind]}
      end) ++
        Enum.map(opts[:many_extension_kinds], fn extension_kind ->
          {extension_kind, type: {:list, :atom}, default: []}
        end)

    their_opt_schema =
      if opts[:untyped_extensions?] do
        Keyword.put(their_opt_schema, :extensions, type: {:list, :atom})
      else
        their_opt_schema
      end

    their_opt_schema =
      Keyword.merge(their_opt_schema,
        otp_app: [type: :atom],
        fragments: [type: {:list, :module}]
      )

    their_opt_schema = Keyword.merge(opts[:opt_schema] || [], their_opt_schema)

    quote bind_quoted: [
            their_opt_schema: their_opt_schema,
            parent_opts: opts,
            parent: __CALLER__.module
          ],
          generated: true do
      require Spark.Dsl.Extension
      @dialyzer {:nowarn_function, handle_opts: 1, handle_before_compile: 1}
      Module.register_attribute(__MODULE__, :spark_dsl, persist: true)
      Module.register_attribute(__MODULE__, :spark_default_extensions, persist: true)
      Module.register_attribute(__MODULE__, :spark_extension_kinds, persist: true)
      @spark_dsl true
      @spark_default_extensions parent_opts[:default_extensions]
                                |> Keyword.values()
                                |> List.flatten()
      @spark_extension_kinds List.wrap(parent_opts[:many_extension_kinds]) ++
                               List.wrap(parent_opts[:single_extension_kinds])

      @behaviour Spark.Dsl

      def init(opts), do: {:ok, opts}

      def explain(_, _), do: nil

      def default_extensions, do: @spark_default_extensions
      def default_extension_kinds, do: List.wrap(unquote(parent_opts[:default_extensions]))
      def many_extension_kinds, do: List.wrap(unquote(parent_opts[:many_extension_kinds]))
      def single_extension_kinds, do: List.wrap(unquote(parent_opts[:single_extension_kinds]))

      def handle_opts(opts) do
        quote do
        end
      end

      def handle_before_compile(opts) do
        quote do
        end
      end

      defmacro __using__(opts) do
        parent = unquote(parent)
        parent_opts = unquote(parent_opts)
        their_opt_schema = unquote(their_opt_schema)
        require Spark.Dsl.Extension

        fragments =
          opts[:fragments]
          |> List.wrap()
          |> Enum.map(&Spark.Dsl.Extension.do_expand(&1, __CALLER__))

        {opts, extensions} =
          parent_opts[:default_extensions]
          |> Enum.reduce(opts, fn {key, defaults}, opts ->
            Keyword.update(opts, key, defaults, fn current_value ->
              cond do
                key in parent_opts[:single_extension_kinds] ->
                  fragments_set =
                    Enum.filter(fragments, fn fragment ->
                      fragment.opts
                      |> Spark.Dsl.expand_modules(parent_opts, __CALLER__)
                      |> elem(0)
                      |> Keyword.get(key)
                    end)

                  cond do
                    current_value && !Enum.empty?(fragments_set) ->
                      raise "#{key} is being set as an option, but is also set in fragments: #{Enum.map_join(fragments_set, ", ", &inspect/1)}"

                    Enum.count(fragments_set) > 1 ->
                      raise "#{key} is being set by multiple fragments: #{Enum.map_join(fragments_set, ", ", &inspect/1)}"

                    true ->
                      current_value || List.first(fragments_set) || defaults
                  end

                key in parent_opts[:many_extension_kinds] || key == :extensions ->
                  List.wrap(current_value) ++ List.wrap(defaults)

                true ->
                  current_value
              end
            end)
          end)
          |> Spark.Dsl.expand_modules(parent_opts, __CALLER__)

        fragment_extensions = Enum.flat_map(fragments, & &1.extensions())

        opts =
          Enum.reduce(fragments, opts, fn fragment, opts ->
            fragment.opts()
            |> Spark.Dsl.expand_modules(parent_opts, __CALLER__)
            |> elem(0)
            |> Enum.reduce(opts, fn {key, value}, opts ->
              cond do
                key in parent_opts[:single_extension_kinds] ->
                  Keyword.put(opts, key, value)

                key in parent_opts[:many_extension_kinds] ->
                  Keyword.update(opts, key, value, fn current_extensions ->
                    Enum.uniq(current_extensions ++ List.wrap(value))
                  end)

                true ->
                  opts
              end
            end)
          end)

        extensions = Enum.uniq(extensions ++ fragment_extensions)

        Module.put_attribute(__CALLER__.module, :extensions, extensions)

        body =
          quote generated: true do
            opts =
              unquote(opts)
              |> Spark.OptionsHelpers.validate!(unquote(their_opt_schema))
              |> unquote(__MODULE__).init()
              |> Spark.Dsl.unwrap()

            parent = unquote(parent)
            parent_opts = unquote(parent_opts)
            their_opt_schema = unquote(their_opt_schema)

            @opts opts
            @before_compile Spark.Dsl
            @after_compile __MODULE__

            Module.register_attribute(__MODULE__, :spark_is, persist: true)

            @spark_is parent
            @spark_parent parent

            def spark_is, do: @spark_is

            defmacro __after_compile__(_, _) do
              quote do
                # This is here because dialyzer complains
                # this is really dumb but it works so 🤷‍♂️
                if @after_compile_transformers do
                  transformers_to_run =
                    @extensions
                    |> Enum.flat_map(& &1.transformers())
                    |> Spark.Dsl.Transformer.sort()
                    |> Enum.filter(& &1.after_compile?())

                  @extensions
                  |> Enum.flat_map(& &1.verifiers())
                  |> Enum.each(fn verifier ->
                    case verifier.verify(@spark_dsl_config) do
                      :ok ->
                        :ok

                      {:warn, warnings} ->
                        warnings
                        |> List.wrap()
                        |> Enum.each(&IO.warn(&1, Macro.Env.stacktrace(__ENV__)))

                      {:error, error} ->
                        if Exception.exception?(error) do
                          raise error
                        else
                          raise "Verification error from #{inspect(verifier)}: #{inspect(error)}"
                        end
                    end
                  end)

                  __MODULE__
                  |> Spark.Dsl.Extension.run_transformers(
                    transformers_to_run,
                    @spark_dsl_config,
                    __ENV__
                  )
                end
              end
            end

            Module.register_attribute(__MODULE__, :persist, accumulate: true)

            opts
            |> @spark_parent.handle_opts()
            |> Code.eval_quoted([], __ENV__)

            if opts[:otp_app] do
              @persist {:otp_app, opts[:otp_app]}
            end

            @persist {:module, __MODULE__}
            @persist {:file, __ENV__.file}

            for single_extension_kind <- parent_opts[:single_extension_kinds] do
              @persist {single_extension_kind, opts[single_extension_kind]}
              Module.put_attribute(__MODULE__, single_extension_kind, opts[single_extension_kind])
            end

            for many_extension_kind <- parent_opts[:many_extension_kinds] do
              @persist {many_extension_kind, opts[many_extension_kind] || []}
              Module.put_attribute(
                __MODULE__,
                many_extension_kind,
                opts[many_extension_kind] || []
              )
            end
          end

        preparations = Spark.Dsl.Extension.prepare(extensions)
        [body | preparations]
      end

      defoverridable init: 1, handle_opts: 1, handle_before_compile: 1, explain: 2, __using__: 1
    end
  end

  @doc false
  def unwrap({:ok, value}), do: value
  def unwrap({:error, error}), do: raise(error)

  @doc false
  def expand_modules(opts, their_opt_schema, env) do
    Enum.reduce(opts, {[], []}, fn {key, value}, {opts, extensions} ->
      cond do
        key in their_opt_schema[:single_extension_kinds] ->
          mod = Macro.expand(value, env)

          extensions =
            if Spark.implements_behaviour?(mod, Spark.Dsl.Extension) do
              [mod | extensions]
            else
              extensions
            end

          {Keyword.put(opts, key, mod), extensions}

        key in their_opt_schema[:many_extension_kinds] || key == :extensions ->
          mods =
            value
            |> List.wrap()
            |> Enum.map(&Macro.expand(&1, env))

          extensions =
            extensions ++
              Enum.filter(mods, &Spark.implements_behaviour?(&1, Spark.Dsl.Extension))

          {Keyword.put(opts, key, mods), extensions}

        true ->
          {Keyword.put(opts, key, value), extensions}
      end
    end)
  end

  @after_verify_supported Version.match?(System.version(), ">= 1.14.0")

  defmacro __before_compile__(env) do
    parent = Module.get_attribute(env.module, :spark_parent)
    opts = Module.get_attribute(env.module, :opts)
    parent_code = parent.handle_before_compile(opts)

    # This `to_string() == "true"` bit is to appease dialyzer
    # It is very dumb but it works so 🤷‍♂️
    verify_code =
      if to_string(@after_verify_supported) == "true" do
        quote generated: true do
          @after_compile_transformers false
          @after_verify {__MODULE__, :__verify_ash_dsl__}

          @doc false
          def __verify_ash_dsl__(_module) do
            transformers_to_run =
              @extensions
              |> Enum.flat_map(& &1.transformers())
              |> Spark.Dsl.Transformer.sort()
              |> Enum.filter(& &1.after_compile?())

            @extensions
            |> Enum.flat_map(& &1.verifiers())
            |> Enum.each(fn verifier ->
              case verifier.verify(@spark_dsl_config) do
                :ok ->
                  :ok

                {:warn, warnings} ->
                  warnings
                  |> List.wrap()
                  |> Enum.each(&IO.warn(&1, Macro.Env.stacktrace(__ENV__)))

                {:error, error} ->
                  if Exception.exception?(error) do
                    raise error
                  else
                    raise "Verification error from #{inspect(verifier)}: #{inspect(error)}"
                  end
              end
            end)

            __MODULE__
            |> Spark.Dsl.Extension.run_transformers(
              transformers_to_run,
              @spark_dsl_config,
              __ENV__
            )
          end
        end
      else
        quote do
          @after_compile_transformers true
        end
      end

    code =
      quote generated: true,
            bind_quoted: [dsl: __MODULE__, parent: parent, fragments: opts[:fragments]] do
        require Spark.Dsl.Extension

        Module.register_attribute(__MODULE__, :spark_is, persist: true)
        Module.put_attribute(__MODULE__, :spark_is, @spark_is)

        Spark.Dsl.Extension.set_state(@persist)
        @spark_dsl_config Spark.Dsl.handle_fragments(@spark_dsl_config, fragments)

        for {block, bindings} <- Enum.reverse(@spark_dsl_config[:eval] || []) do
          Code.eval_quoted(block, bindings, __ENV__)
        end

        def __spark_placeholder__, do: nil

        def spark_dsl_config do
          @spark_dsl_config
        end

        cond do
          @moduledoc == false ->
            :ok

          @moduledoc ->
            @moduledoc """
            #{@moduledoc}

            #{Spark.Dsl.Extension.explain(parent, @opts, @extensions, @spark_dsl_config)}
            """

          true ->
            @moduledoc Spark.Dsl.Extension.explain(parent, @opts, @extensions, @spark_dsl_config)
        end
      end

    [code, parent_code, verify_code]
  end

  @doc """
  Returns `true` if `module` was built with the given Spark DSL `type`
  (i.e. its `spark_is/0` returns `type`), and `false` otherwise.
  """
  def is?(module, type) when is_atom(module) do
    module.spark_is() == type
  rescue
    _ -> false
  end

  def is?(_module, _type), do: false

  @doc false
  def handle_fragments(dsl_config, fragments) do
    fragments
    |> List.wrap()
    |> Enum.reduce(dsl_config, fn fragment, acc ->
      config = Map.delete(fragment.spark_dsl_config(), :persist)

      Map.merge(acc, config, fn
        key, %{entities: entities, opts: opts}, %{entities: right_entities, opts: right_opts} ->
          %{
            entities: entities ++ right_entities,
            opts: merge_with_warning(opts, right_opts, key, "fragment: #{fragment}")
          }
      end)
    end)
  end

  @doc false
  def merge_with_warning(left, right, path, overwriting_by \\ nil) do
    Keyword.merge(left, right, fn key, left, right ->
      by =
        if overwriting_by do
          " by #{overwriting_by}"
        else
          ""
        end

      IO.warn(
        "#{Enum.join(path ++ [key], ".")} is being overwritten from #{inspect(left)} to #{inspect(right)}#{by}"
      )

      right
    end)
  end
end