lib/cmp/comparable.ex

defprotocol Cmp.Comparable do
  @moduledoc """
  A protocol to define how elements of a type can be compared.

  The simplest way to define it for a custom struct is to use @derive
  based on semantical comparison of fields, in the given order:

      defmodule MyStruct do
        @derive {Cmp.Comparable, using: [:date, :id]}

        defstruct [:id, :date]
      end

      Cmp.sort([
        %MyStruct{date: ~D[2020-03-02], id: 100},
        %MyStruct{date: ~D[2020-03-02], id: 101},
        %MyStruct{date: ~D[2019-06-06], id: 300}
      ])
      [
        %MyStruct{date: ~D[2019-06-06], id: 300},
        %MyStruct{date: ~D[2020-03-02], id: 100},
        %MyStruct{date: ~D[2020-03-02], id: 101}
      ]

  For existing structs that are already implementing a `compare/2`
  function returning `:eq | :lt | :gt`, the protocol can simply be
  done by passing `using: :compare`.

      require Protocol
      Protocol.derive(Cmp.Comparable, MyStruct, using: :compare)

  This is already done for the following modules:
  - `Date`
  - `Time`
  - `DateTime`
  - `NaiveDateTime`
  - `Version`
  - `Decimal` (if available)

  """

  # @type t :: term()
  @typedoc """
  Type of an element implementing the `Comparable` protocol.
  """

  @doc """
  Defines how to semantically compare two elements of a given type.

  This should return:
  - `:eq` if `left == right`
  - `:lt` if `left < right`
  - `:gt` if `left > right`

  """
  @spec compare(t(), t()) :: :eq | :lt | :gt
  def compare(left, right)
end

defimpl Cmp.Comparable, for: Integer do
  def compare(left, right) when is_number(right) do
    Cmp.Util.compare_terms(left, right)
  end

  def compare(left, right), do: raise(Cmp.TypeError, left: left, right: right)
end

defimpl Cmp.Comparable, for: Float do
  def compare(left, right) when is_number(right) do
    Cmp.Util.compare_terms(left, right)
  end

  def compare(left, right), do: raise(Cmp.TypeError, left: left, right: right)
end

defimpl Cmp.Comparable, for: BitString do
  def compare(left, right) when is_binary(right) do
    Cmp.Util.compare_terms(left, right)
  end

  def compare(left, right), do: raise(Cmp.TypeError, left: left, right: right)
end

defimpl Cmp.Comparable, for: Any do
  defmacro __deriving__(module, _struct, using: :compare) do
    quote do
      defimpl Cmp.Comparable, for: unquote(module) do
        @compile {:inline, compare: 2}

        def compare(left, right) when is_struct(right, unquote(module)) do
          case unquote(module).compare(left, right) do
            result when result in [:lt, :gt, :eq] -> result
          end
        end

        def compare(left, right), do: raise(Cmp.TypeError, left: left, right: right)
      end
    end
  end

  defmacro __deriving__(module, _struct, using: fields) when is_list(fields) do
    vars = Map.new([:left, :right], &{&1, Macro.var(&1, __MODULE__)})

    quote do
      defimpl Cmp.Comparable, for: unquote(module) do
        def compare(unquote(vars.left), unquote(vars.right))
            when is_struct(unquote(vars.right), unquote(module)) do
          unquote(generate_fields_comparison(fields, vars))
        end

        def compare(left, right), do: raise(Cmp.TypeError, left: left, right: right)
      end
    end
  end

  defmacro __deriving__(module, _struct, _opts) do
    raise ArgumentError, """
    Deriving Cmp.Comparable for #{inspect(module)} needs to pass a :using option which is either:

    - `using: :compare` if you plan to define a compare/2 function
    - `using: [:field1, :field2] if you want to use a list of fields

    """
  end

  def compare(left, _right) do
    raise Protocol.UndefinedError,
      protocol: @protocol,
      value: left
  end

  defp generate_fields_comparison([field], vars) when is_atom(field) do
    quote do
      Cmp.compare(
        Map.fetch!(unquote(vars.left), unquote(field)),
        Map.fetch!(unquote(vars.right), unquote(field))
      )
    end
  end

  defp generate_fields_comparison([field | fields], vars) when is_atom(field) do
    comparison = generate_fields_comparison([field], vars)
    continue = generate_fields_comparison(fields, vars)

    quote do
      case unquote(comparison) do
        :eq -> unquote(continue)
        result when result in [:lt, :gt] -> result
      end
    end
  end
end

for struct <- Cmp.Util.comparable_structs() do
  require Protocol
  Protocol.derive(Cmp.Comparable, struct, using: :compare)
end

defimpl Cmp.Comparable, for: Tuple do
  def compare(left, right) when is_tuple(right) and tuple_size(left) == tuple_size(right) do
    compare_tuples(left, right, 1, tuple_size(left))
  end

  def compare(left, right), do: raise(Cmp.TypeError, left: left, right: right)

  defp compare_tuples(_left, _right, i, n) when i > n, do: :eq

  defp compare_tuples(left, right, i, n) do
    l = :erlang.element(i, left)
    r = :erlang.element(i, right)

    case Cmp.compare(l, r) do
      :eq ->
        compare_tuples(left, right, i + 1, n)

      result ->
        # keep checking types even if no impact on result
        keep_checking(left, right, i + 1, n)
        result
    end
  end

  defp keep_checking(_left, _right, i, n) when i > n, do: :ok

  defp keep_checking(left, right, i, n) do
    l = :erlang.element(i, left)
    r = :erlang.element(i, right)
    Cmp.compare(l, r)
    keep_checking(left, right, i + 1, n)
  end
end