# lib/adt.ex

defmodule ADT do
  @moduledoc """
  Pseudo-ADT definition generator
  """

  @doc """
  Defines a pseudo-ADT made of one struct module per alternative.

  Alternatives are separated with `|`. Each alternative is written like a
  function call whose single argument is the keyword list of struct fields
  (with their defaults). The generated modules are nested under the calling
  module, which also receives a `variants/0` function and a `case/2` macro.

  Use it like this:

      > ADT.define foo(a: "default") | bar(b: "value")
      > %Foo{a:"value"}
      %Foo{a: "value"}
  """
  defmacro define(parts) do
    alternatives = format_parts(parts, __CALLER__)
    generate_code(alternatives)
  end

  # Assembles everything `define/1` injects into the calling module, in order:
  # the registration of the accumulating `@variants` attribute, one struct
  # module per alternative, the `variants/0` reader and the `case/2` macro.
  defp generate_code(values) do
    register_variants =
      quote do
        Module.register_attribute(__MODULE__, :variants, accumulate: true)
      end

    struct_modules = Enum.map(values, &generate_defmodule/1)

    # Must come after the struct modules so that @variants is fully
    # accumulated by the time the function body reads it.
    read_variants =
      quote do
        def variants do
          @variants
        end
      end

    [register_variants | struct_modules] ++ [read_variants, case_macro(values)]
  end

  # Builds the AST of the `case/2` macro injected into the module that calls
  # `ADT.define`.
  #
  # The generated macro takes the value to dispatch on (`a`) and a keyword
  # list of `VariantName: handler_fun` pairs, optionally ending with a
  # `_: handler_fun` catch-all. At expansion time it raises unless the
  # handled names are exactly the defined variants or a catch-all is present.
  # It expands to a `cond` that compares the short module name of the value
  # against each handled name and calls the matching handler with the value.
  defp case_macro(values) do
    # Only the short variant names are needed by the generated macro. They are
    # computed here so that the user-written field definitions (arbitrary AST)
    # are never spliced into, and evaluated by, the macro body.
    possible_variants =
      values
      |> Enum.map(fn {variant, _fields} -> _shorten(variant) end)
      |> Enum.sort()

    quote do
      defmacro case(a, statements \\ []) do
        possible_variants = unquote(possible_variants)

        given_variants =
          statements
          |> Enum.map(fn {k, v} -> {to_string(k), v} end)
          |> Enum.sort()

        given_variant_names = Enum.map(given_variants, fn {k, _} -> k end)

        # Checked on the unsorted names: the catch-all has to be written last.
        catch_all =
          statements
          |> Enum.map(fn {k, _} -> to_string(k) end)
          |> ADT._has_catch_all()

        if possible_variants != given_variant_names and not catch_all do
          raise ADT._exhaustive_error(possible_variants, given_variant_names)
        end

        # Bind the subject once to a hygienic variable; every clause refers to
        # that variable, so an expression with side effects passed as `a` is
        # evaluated a single time instead of once per condition and handler.
        subject = Macro.var(:adt_case_subject, ADT)
        rules = Enum.flat_map(given_variants, &ADT._compile_clause(subject, &1))

        quote do
          unquote(subject) = unquote(a)
          cond do: unquote(rules)
        end
      end
    end
  end

  # Maps a module like Foo.Bar.Baz, or a struct of that module, to the short
  # string "Baz".
  #
  # The name is taken from the module atom itself rather than from the
  # `inspect/1` output, so top-level modules (no dot in the name) work and the
  # struct's data is never stringified. Raises ArgumentError (from
  # `Module.split/1`) when given an atom that is not an Elixir module name.
  def _shorten(%{__struct__: module}) when is_atom(module) do
    _shorten(module)
  end
  def _shorten(module) when is_atom(module) do
    module |> Module.split() |> List.last()
  end

  # Builds one `cond` clause, returned as a one-element list of `->` AST
  # nodes, for a `{name, handler}` pair of the `case` macro. `a` is the AST of
  # the value being dispatched on. The "_" name yields an always-true clause;
  # any other name yields a clause that compares it with the short name of the
  # value. In both cases the clause body calls the handler with the value.
  def _compile_clause(a, {"_", handler}) do
    handler_call = quote do: unquote(handler).(unquote(a))

    quote do
      true -> unquote(handler_call)
    end
  end
  def _compile_clause(a, {variant_name, handler}) do
    handler_call = quote do: unquote(handler).(unquote(a))

    quote do
      unquote(variant_name) == ADT._shorten(unquote(a)) -> unquote(handler_call)
    end
  end

  # Builds the error message raised by the `case` macro when the handled
  # variant names do not match the defined ones.
  #
  # `possible_variants` and `given_variants` are lists of short variant names.
  # The message lists the unhandled names, the names that do not belong to the
  # ADT ("extra"), or, when the only problem is that a variant is handled more
  # than once, the duplicated names.
  def _exhaustive_error(possible_variants, []) do
    "case macro needs to handle these cases: #{inspect(possible_variants)}."
  end
  def _exhaustive_error(possible_variants, given_variants) do
    set_possible = MapSet.new(possible_variants)
    set_given = MapSet.new(given_variants)
    unhandled_cases = Enum.into(MapSet.difference(set_possible, set_given), [])
    extra_cases = Enum.into(MapSet.difference(set_given, set_possible), [])
    # Names left over once one occurrence of each has been removed.
    duplicate_cases = Enum.uniq(given_variants -- Enum.uniq(given_variants))

    case {unhandled_cases, extra_cases, duplicate_cases} do
      {[], [], [_ | _]} ->
        # Every variant is covered and nothing is foreign: the mismatch comes
        # from repeated clauses, so say so instead of "Unhandled cases: []".
        "case macro contains duplicate clauses: #{inspect(duplicate_cases)}."
      {_, [], _} ->
        "case macro not exhaustive.\nUnhandled cases: #{inspect(unhandled_cases)}."
      {[], _, _} ->
        "case macro not exhaustive.\nExtra cases: #{inspect(extra_cases)}."
      _ ->
        "case macro not exhaustive.\nUnhandled cases: #{inspect(unhandled_cases)}.\nExtra cases: #{inspect(extra_cases)}."
    end
  end

  # Tells whether the clause names given to the `case` macro end with the
  # catch-all name "_".
  #
  # Returns false when there is no "_" at all and true when the only "_" is
  # the last name. Raises a RuntimeError when a single "_" is not in last
  # position or when "_" appears more than once.
  def _has_catch_all(names) do
    case Enum.count(names, fn n -> n == "_" end) do
      0 ->
        false
      1 ->
        if List.last(names) == "_" do
          true
        else
          raise "case macro only accepts a catch all at the last position."
        end
      count ->
        raise "case macro contains #{count} catch all clauses."
    end
  end

  # Flattens "one | two | three" (nested as "one | (two | three)" in the AST)
  # into "[one, two, three]". Each entry is a {full_module_name, fields}
  # tuple: an alternative written with exactly one argument keeps that
  # argument as its fields, any other form gets an empty field list.
  defp format_parts({:|, _meta, [left, right]}, caller) do
    Enum.flat_map([left, right], &format_parts(&1, caller))
  end
  defp format_parts({name, _meta, [fields]}, caller) do
    [{create_full_name(name, caller), fields}]
  end
  defp format_parts({name, _meta, _args}, caller) do
    [{create_full_name(name, caller), []}]
  end

  # Generates the AST for one ADT alternative, like:
  #
  #  foo(a: "default")
  #
  # The result records the module name in the accumulating @variants
  # attribute and defines the module with `fields` as its struct definition.
  defp generate_defmodule({name, fields}) do
    struct_module =
      quote do
        defmodule unquote(name) do
          defstruct unquote(fields)
        end
      end

    quote do
      @variants unquote(name)
      unquote(struct_module)
    end
  end

  # Turns the name of an alternative (a lowercase atom such as :foo_bar) into
  # a module name nested under the module that invoked the macro.
  defp create_full_name(name, caller) do
    # Convert the atom to a string and camel-case it ("foo_bar" -> "FooBar").
    short_name = format_module_name(to_string(name))
    # Module.concat turns the pieces back into a module atom.
    Module.concat([caller.module, short_name])
  end

  # Helper code that really shouldn't be here.
  # Transforms "foo_bar" to "FooBar": the string is capitalized first, then
  # every underscore followed by a lowercase letter is replaced by that letter
  # in upper case.
  defp format_module_name(str) do
    capitalized = String.capitalize(str)
    upcase_after_underscore = fn _match, letter -> String.upcase(letter) end
    Regex.replace(~r/_([a-z])/, capitalized, upcase_after_underscore)
  end
end