lib/postgrex/type_module.ex

defmodule Postgrex.TypeModule do
  @moduledoc false

  alias Postgrex.TypeInfo

  def define(module, extensions, opts) do
    opts =
      opts
      |> Keyword.put_new(:decode_binary, :copy)

    config = configure(extensions, opts)
    define_inline(module, config, opts)
  end

  ## Helpers

  defp directives(config, opts) do
    requires =
      for {extension, _} <- config do
        quote do: require(unquote(extension))
      end

    preludes =
      for {extension, {state, _, _}} <- config,
          function_exported?(extension, :prelude, 1),
          do: extension.prelude(state)

    null = Keyword.get(opts, :null)

    moduledoc = Keyword.get(opts, :moduledoc, false)

    quote do
      @moduledoc unquote(moduledoc)
      import Postgrex.BinaryUtils
      require unquote(__MODULE__)
      unquote(requires)
      unquote(preludes)
      unquote(bin_opt_info(opts))
      @compile {:inline, [encode_value: 2]}
      @dialyzer {:no_opaque, [decode_tuple: 5]}
      @null unquote(Macro.escape(null))
    end
  end

  defp bin_opt_info(opts) do
    if Keyword.get(opts, :bin_opt_info) do
      quote do: @compile(:bin_opt_info)
    else
      []
    end
  end

  @anno [generated: true]

  defp find(config) do
    clauses = Enum.flat_map(config, &find_clauses/1)
    clauses = clauses ++ quote do: (_ -> nil)

    quote @anno do
      @doc false
      def find(type_info, formats) do
        case {type_info, formats} do
          unquote(clauses)
        end
      end
    end
  end

  defp find_clauses({extension, {opts, matching, format}}) do
    for {key, value} <- matching do
      [clause] = find_clause(extension, opts, key, value, format)
      clause
    end
  end

  defp find_clause(extension, opts, key, value, :super_binary) do
    quote do
      {%{unquote(key) => unquote(value)} = type_info, formats}
      when formats in [:any, :binary] ->
        oids = unquote(extension).oids(type_info, unquote(opts))
        {:super_binary, unquote(extension), oids}
    end
  end

  defp find_clause(extension, _opts, key, value, format) do
    quote do
      {%{unquote(key) => unquote(value)}, formats}
      when formats in [:any, unquote(format)] ->
        {unquote(format), unquote(extension)}
    end
  end

  defp maybe_rewrite(ast, extension, cases, opts) do
    if Postgrex.Utils.default_extension?(extension) and
         not Keyword.get(opts, :debug_defaults, false) do
      ast
    else
      rewrite(ast, cases)
    end
  end

  defp rewrite(ast, [{:->, clause_meta, _} | _original]) do
    Macro.prewalk(ast, fn
      {kind, meta, [{fun, _, args}, block]} when kind in [:def, :defp] and is_list(args) ->
        {kind, meta, [{fun, clause_meta, args}, block]}

      other ->
        other
    end)
  end

  defp encode(config, define_opts) do
    encodes =
      for {extension, {opts, [_ | _], format}} <- config do
        encode = extension.encode(opts)

        clauses =
          for clause <- encode do
            encode_type(extension, format, clause)
          end

        clauses = [encode_null(extension, format) | clauses]

        quote do
          unquote(encode_value(extension, format))

          unquote(encode_inline(extension, format))

          unquote(clauses |> maybe_rewrite(extension, encode, define_opts))
        end
      end

    quote location: :keep do
      unquote(encodes)

      @doc false
      def encode_params(params, types) do
        encode_params(params, types, [])
      end

      defp encode_params([param | params], [type | types], encoded) do
        encode_params(params, types, [encode_value(param, type) | encoded])
      end

      defp encode_params([], [], encoded), do: Enum.reverse(encoded)
      defp encode_params(params, _, _) when is_list(params), do: :error

      @doc false
      def encode_tuple(tuple, nil, _types) do
        raise DBConnection.EncodeError, """
        cannot encode anonymous tuple #{inspect(tuple)}. \
        Please define a custom Postgrex extension that matches on its underlying type:

            use Postgrex.BinaryExtension, type: "typeinthedb"
        """
      end

      def encode_tuple(tuple, oids, types) do
        encode_tuple(tuple, 1, oids, types, [])
      end

      defp encode_tuple(tuple, n, [oid | oids], [type | types], acc) do
        param = :erlang.element(n, tuple)
        acc = [acc, <<oid::uint32()>> | encode_value(param, type)]
        encode_tuple(tuple, n + 1, oids, types, acc)
      end

      defp encode_tuple(tuple, n, [], [], acc) when tuple_size(tuple) < n do
        acc
      end

      defp encode_tuple(tuple, n, [], [], _) when is_tuple(tuple) do
        raise DBConnection.EncodeError,
              "expected a tuple of size #{n - 1}, got: #{inspect(tuple)}"
      end

      @doc false
      def encode_list(list, type) do
        encode_list(list, type, [])
      end

      defp encode_list([value | rest], type, acc) do
        encode_list(rest, type, [acc | encode_value(value, type)])
      end

      defp encode_list([], _, acc) do
        acc
      end
    end
  end

  defp encode_type(extension, :super_binary, clause) do
    encode_super(extension, clause)
  end

  defp encode_type(extension, _, clause) do
    encode_extension(extension, clause)
  end

  defp encode_extension(extension, clause) do
    case split_extension(clause) do
      {pattern, guard, body} ->
        encode_extension(extension, pattern, guard, body)

      {pattern, body} ->
        encode_extension(extension, pattern, body)
    end
  end

  defp encode_extension(extension, pattern, guard, body) do
    quote do
      defp unquote(extension)(unquote(pattern)) when unquote(guard) do
        unquote(body)
      end
    end
  end

  defp encode_extension(extension, pattern, body) do
    quote do
      defp unquote(extension)(unquote(pattern)) do
        unquote(body)
      end
    end
  end

  defp encode_super(extension, clause) do
    case split_super(clause) do
      {pattern, sub_oids, sub_types, guard, body} ->
        encode_super(extension, pattern, sub_oids, sub_types, guard, body)

      {pattern, sub_oids, sub_types, body} ->
        encode_super(extension, pattern, sub_oids, sub_types, body)
    end
  end

  defp encode_super(extension, pattern, sub_oids, sub_types, guard, body) do
    quote do
      defp unquote(extension)(unquote(pattern), unquote(sub_oids), unquote(sub_types))
           when unquote(guard) do
        unquote(body)
      end
    end
  end

  defp encode_super(extension, pattern, sub_oids, sub_types, body) do
    quote do
      defp unquote(extension)(unquote(pattern), unquote(sub_oids), unquote(sub_types)) do
        unquote(body)
      end
    end
  end

  defp encode_inline(extension, :super_binary) do
    quote do
      @compile {:inline, [{unquote(extension), 3}]}
    end
  end

  defp encode_inline(extension, _) do
    quote do
      @compile {:inline, [{unquote(extension), 1}]}
    end
  end

  defp encode_null(extension, :super_binary) do
    quote do
      defp unquote(extension)(@null, _sub_oids, _sub_types), do: <<-1::int32()>>
    end
  end

  defp encode_null(extension, _) do
    quote do
      defp unquote(extension)(@null), do: <<-1::int32()>>
    end
  end

  defp encode_value(extension, :super_binary) do
    quote do
      @doc false
      def encode_value(value, {unquote(extension), sub_oids, sub_types}) do
        unquote(extension)(value, sub_oids, sub_types)
      end
    end
  end

  defp encode_value(extension, _) do
    quote do
      @doc false
      def encode_value(value, unquote(extension)) do
        unquote(extension)(value)
      end
    end
  end

  defp decode(config, define_opts) do
    rest = quote do: rest
    acc = quote do: acc
    rem = quote do: rem
    full = quote do: full
    rows = quote do: rows

    row_dispatch =
      for {extension, {_, [_ | _], format}} <- config do
        decode_row_dispatch(extension, format, rest, acc, rem, full, rows)
      end

    next_dispatch = decode_rows_dispatch(rest, acc, rem, full, rows)
    row_dispatch = row_dispatch ++ next_dispatch

    decodes =
      for {extension, {opts, [_ | _], format}} <- config do
        decode = extension.decode(opts)

        clauses =
          for clause <- decode do
            decode_type(extension, format, clause, row_dispatch, rest, acc, rem, full, rows)
          end

        null_clauses = decode_null(extension, format, row_dispatch, rest, acc, rem, full, rows)

        quote location: :keep do
          unquote(clauses |> maybe_rewrite(extension, decode, define_opts))

          unquote(null_clauses)
        end
      end

    quote location: :keep do
      unquote(decode_rows(row_dispatch, rest, acc, rem, full, rows))

      unquote(decode_simple())

      unquote(decode_list(config))

      unquote(decode_tuple(config))

      unquote(decodes)
    end
  end

  defp decode_rows(dispatch, rest, acc, rem, full, rows) do
    quote location: :keep, generated: true do
      @doc false
      def decode_rows(binary, types, rows) do
        decode_rows(binary, byte_size(binary), types, rows)
      end

      defp decode_rows(
             <<?D, size::int32(), _::int16(), unquote(rest)::binary>>,
             rem,
             unquote(full),
             unquote(rows)
           )
           when rem > size do
        unquote(rem) = rem - (1 + size)
        unquote(acc) = []

        case unquote(full) do
          unquote(dispatch)
        end
      end

      defp decode_rows(<<?D, size::int32(), rest::binary>>, rem, _, rows) do
        more = size + 1 - rem
        {:more, [?D, <<size::int32()>> | rest], rows, more}
      end

      defp decode_rows(<<?D, rest::binary>>, _, _, rows) do
        {:more, [?D | rest], rows, 0}
      end

      defp decode_rows(<<rest::binary-size(0)>>, _, _, rows) do
        {:more, [], rows, 0}
      end

      defp decode_rows(<<rest::binary>>, _, _, rows) do
        {:ok, rows, rest}
      end
    end
  end

  defp decode_row_dispatch(extension, :super_binary, rest, acc, rem, full, rows) do
    [clause] =
      quote do
        [{unquote(extension), sub_oids, sub_types} | types] ->
          unquote(extension)(
            unquote(rest),
            sub_oids,
            sub_types,
            types,
            unquote(acc),
            unquote(rem),
            unquote(full),
            unquote(rows)
          )
      end

    clause
  end

  defp decode_row_dispatch(extension, _, rest, acc, rem, full, rows) do
    [clause] =
      quote do
        [unquote(extension) | types2] ->
          unquote(extension)(
            unquote(rest),
            types2,
            unquote(acc),
            unquote(rem),
            unquote(full),
            unquote(rows)
          )
      end

    clause
  end

  defp decode_rows_dispatch(rest, acc, rem, full, rows) do
    quote do
      [] ->
        rows = [Enum.reverse(unquote(acc)) | unquote(rows)]
        decode_rows(unquote(rest), unquote(rem), unquote(full), rows)
    end
  end

  defp decode_simple() do
    rest = quote do: rest
    acc = quote do: acc

    dispatch = decode_simple_dispatch(Postgrex.Extensions.Raw, rest, acc)

    quote do
      @doc false
      def decode_simple(binary) do
        decode_simple(binary, [])
      end

      defp decode_simple(<<>>, unquote(acc)), do: Enum.reverse(acc)
      defp decode_simple(<<unquote(rest)::binary>>, unquote(acc)), do: unquote(dispatch)
    end
  end

  defp decode_simple_dispatch(extension, rest, acc) do
    quote do
      unquote(extension)(unquote(rest), unquote(acc), &decode_simple/2)
    end
  end

  defp decode_list(config) do
    rest = quote do: rest

    dispatch =
      for {extension, {_, [_ | _], format}} <- config do
        decode_list_dispatch(extension, format, rest)
      end

    quote do
      @doc false
      def decode_list(<<unquote(rest)::binary>>, type) do
        case type do
          unquote(dispatch)
        end
      end
    end
  end

  defp decode_list_dispatch(extension, :super_binary, rest) do
    [clause] =
      quote do
        {unquote(extension), sub_oids, sub_types} ->
          unquote(extension)(unquote(rest), sub_oids, sub_types, [])
      end

    clause
  end

  defp decode_list_dispatch(extension, _, rest) do
    [clause] =
      quote do
        unquote(extension) ->
          unquote(extension)(unquote(rest), [])
      end

    clause
  end

  defp decode_tuple(config) do
    rest = quote do: rest
    oids = quote do: oids
    n = quote do: n
    acc = quote do: acc

    dispatch =
      for {extension, {_, [_ | _], format}} <- config do
        decode_tuple_dispatch(extension, format, rest, oids, n, acc)
      end

    quote generated: true do
      @doc false
      def decode_tuple(<<rest::binary>>, count, types) when is_integer(count) do
        decode_tuple(rest, count, types, 0, [])
      end

      def decode_tuple(<<rest::binary>>, oids, types) do
        decode_tuple(rest, oids, types, 0, [])
      end

      defp decode_tuple(
             <<oid::int32(), unquote(rest)::binary>>,
             [oid | unquote(oids)],
             types,
             unquote(n),
             unquote(acc)
           ) do
        case types do
          unquote(dispatch)
        end
      end

      defp decode_tuple(<<>>, [], [], n, acc) do
        :erlang.make_tuple(n, @null, acc)
      end

      defp decode_tuple(
             <<oid::int32(), unquote(rest)::binary>>,
             rem,
             types,
             unquote(n),
             unquote(acc)
           )
           when rem > 0 do
        case Postgrex.Types.fetch(oid, types) do
          {:ok, {:binary, type}} ->
            unquote(oids) = rem - 1

            case [type | types] do
              unquote(dispatch)
            end

          {:ok, {:text, _}} ->
            msg =
              "oid `#{oid}` was bootstrapped in text format and can not " <>
                "be decoded inside an anonymous record"

            raise RuntimeError, msg

          {:error, %TypeInfo{type: pg_type}, _mod} ->
            msg = "type `#{pg_type}` can not be handled by the configured extensions"
            raise RuntimeError, msg

          {:error, nil, _mod} ->
            msg = "oid `#{oid}` was not bootstrapped and lacks type information"
            raise RuntimeError, msg
        end
      end

      defp decode_tuple(<<>>, 0, _types, n, acc) do
        :erlang.make_tuple(n, @null, acc)
      end
    end
  end

  defp decode_tuple_dispatch(extension, :super_binary, rest, oids, n, acc) do
    [clause] =
      quote do
        [{unquote(extension), sub_oids, sub_types} | types] ->
          unquote(extension)(
            unquote(rest),
            sub_oids,
            sub_types,
            unquote(oids),
            types,
            unquote(n) + 1,
            unquote(acc)
          )
      end

    clause
  end

  defp decode_tuple_dispatch(extension, _, rest, oids, n, acc) do
    [clause] =
      quote do
        [unquote(extension) | types] ->
          unquote(extension)(unquote(rest), unquote(oids), types, unquote(n) + 1, unquote(acc))
      end

    clause
  end

  defp decode_type(extension, :super_binary, clause, dispatch, rest, acc, rem, full, rows) do
    decode_super(extension, clause, dispatch, rest, acc, rem, full, rows)
  end

  defp decode_type(extension, _, clause, dispatch, rest, acc, rem, full, rows) do
    decode_extension(extension, clause, dispatch, rest, acc, rem, full, rows)
  end

  defp decode_null(extension, :super_binary, dispatch, rest, acc, rem, full, rows) do
    decode_super_null(extension, dispatch, rest, acc, rem, full, rows)
  end

  defp decode_null(extension, _, dispatch, rest, acc, rem, full, rows) do
    decode_extension_null(extension, dispatch, rest, acc, rem, full, rows)
  end

  defp decode_extension(extension, clause, dispatch, rest, acc, rem, full, rows) do
    case split_extension(clause) do
      {pattern, guard, body} ->
        decode_extension(extension, pattern, guard, body, dispatch, rest, acc, rem, full, rows)

      {pattern, body} ->
        decode_extension(extension, pattern, body, dispatch, rest, acc, rem, full, rows)
    end
  end

  defp decode_extension(extension, pattern, guard, body, dispatch, rest, acc, rem, full, rows) do
    quote do
      defp unquote(extension)(
             <<unquote(pattern), unquote(rest)::binary>>,
             types,
             acc,
             unquote(rem),
             unquote(full),
             unquote(rows)
           )
           when unquote(guard) do
        unquote(acc) = [unquote(body) | acc]

        case types do
          unquote(dispatch)
        end
      end

      defp unquote(extension)(<<unquote(pattern), rest::binary>>, acc)
           when unquote(guard) do
        unquote(extension)(rest, [unquote(body) | acc])
      end

      defp unquote(extension)(<<unquote(pattern), rest::binary>>, acc, callback)
           when unquote(guard) do
        unquote(extension)(rest, [unquote(body) | acc], callback)
      end

      defp unquote(extension)(<<unquote(pattern), rest::binary>>, oids, types, n, acc)
           when unquote(guard) do
        decode_tuple(rest, oids, types, n, [{n, unquote(body)} | acc])
      end
    end
  end

  defp decode_extension(extension, pattern, body, dispatch, rest, acc, rem, full, rows) do
    quote do
      defp unquote(extension)(
             <<unquote(pattern), unquote(rest)::binary>>,
             types,
             acc,
             unquote(rem),
             unquote(full),
             unquote(rows)
           ) do
        unquote(acc) = [unquote(body) | acc]

        case types do
          unquote(dispatch)
        end
      end

      defp unquote(extension)(<<unquote(pattern), rest::binary>>, acc) do
        decoded = unquote(body)
        unquote(extension)(rest, [decoded | acc])
      end

      defp unquote(extension)(<<unquote(pattern), rest::binary>>, acc, callback) do
        decoded = unquote(body)
        unquote(extension)(rest, [decoded | acc], callback)
      end

      defp unquote(extension)(<<unquote(pattern), rest::binary>>, oids, types, n, acc) do
        decode_tuple(rest, oids, types, n, [{n, unquote(body)} | acc])
      end
    end
  end

  defp decode_extension_null(extension, dispatch, rest, acc, rem, full, rows) do
    quote do
      defp unquote(extension)(
             <<-1::int32(), unquote(rest)::binary>>,
             types,
             acc,
             unquote(rem),
             unquote(full),
             unquote(rows)
           ) do
        unquote(acc) = [@null | acc]

        case types do
          unquote(dispatch)
        end
      end

      defp unquote(extension)(<<-1::int32(), rest::binary>>, acc) do
        unquote(extension)(rest, [@null | acc])
      end

      defp unquote(extension)(<<>>, acc) do
        acc
      end

      defp unquote(extension)(<<-1::int32(), rest::binary>>, acc, callback) do
        unquote(extension)(rest, [@null | acc], callback)
      end

      defp unquote(extension)(<<rest::binary-size(0)>>, acc, callback) do
        callback.(rest, acc)
      end

      defp unquote(extension)(<<-1::int32(), rest::binary>>, oids, types, n, acc) do
        decode_tuple(rest, oids, types, n, acc)
      end
    end
  end

  defp split_extension({:->, _, [head, body]}) do
    case head do
      [{:when, _, [pattern, guard]}] ->
        {pattern, guard, body}

      [pattern] ->
        {pattern, body}
    end
  end

  defp decode_super(extension, clause, dispatch, rest, acc, rem, full, rows) do
    case split_super(clause) do
      {pattern, oids, types, guard, body} ->
        decode_super(
          extension,
          pattern,
          oids,
          types,
          guard,
          body,
          dispatch,
          rest,
          acc,
          rem,
          full,
          rows
        )

      {pattern, oids, types, body} ->
        decode_super(extension, pattern, oids, types, body, dispatch, rest, acc, rem, full, rows)
    end
  end

  defp decode_super(
         extension,
         pattern,
         sub_oids,
         sub_types,
         guard,
         body,
         dispatch,
         rest,
         acc,
         rem,
         full,
         rows
       ) do
    quote do
      defp unquote(extension)(
             <<unquote(pattern), unquote(rest)::binary>>,
             unquote(sub_oids),
             unquote(sub_types),
             types,
             acc,
             unquote(rem),
             unquote(full),
             unquote(rows)
           )
           when unquote(guard) do
        unquote(acc) = [unquote(body) | acc]

        case types do
          unquote(dispatch)
        end
      end

      defp unquote(extension)(
             <<unquote(pattern), rest::binary>>,
             unquote(sub_oids),
             unquote(sub_types),
             acc
           )
           when unquote(guard) do
        acc = [unquote(body) | acc]
        unquote(extension)(rest, unquote(sub_oids), unquote(sub_types), acc)
      end

      defp unquote(extension)(
             <<unquote(pattern), rest::binary>>,
             unquote(sub_oids),
             unquote(sub_types),
             oids,
             types,
             n,
             acc
           )
           when unquote(guard) do
        decode_tuple(rest, oids, types, n, [{n, unquote(body)} | acc])
      end
    end
  end

  defp decode_super(
         extension,
         pattern,
         sub_oids,
         sub_types,
         body,
         dispatch,
         rest,
         acc,
         rem,
         full,
         rows
       ) do
    quote do
      defp unquote(extension)(
             <<unquote(pattern), unquote(rest)::binary>>,
             unquote(sub_oids),
             unquote(sub_types),
             types,
             acc,
             unquote(rem),
             unquote(full),
             unquote(rows)
           ) do
        unquote(acc) = [unquote(body) | acc]

        case types do
          unquote(dispatch)
        end
      end

      defp unquote(extension)(
             <<unquote(pattern), rest::binary>>,
             unquote(sub_oids),
             unquote(sub_types),
             acc
           ) do
        acc = [unquote(body) | acc]
        unquote(extension)(rest, unquote(sub_oids), unquote(sub_types), acc)
      end

      defp unquote(extension)(
             <<unquote(pattern), rest::binary>>,
             unquote(sub_oids),
             unquote(sub_types),
             oids,
             types,
             n,
             acc
           ) do
        acc = [{n, unquote(body)} | acc]
        decode_tuple(rest, oids, types, n, acc)
      end
    end
  end

  defp decode_super_null(extension, dispatch, rest, acc, rem, full, rows) do
    quote do
      defp unquote(extension)(
             <<-1::int32(), unquote(rest)::binary>>,
             _sub_oids,
             _sub_types,
             types,
             acc,
             unquote(rem),
             unquote(full),
             unquote(rows)
           ) do
        unquote(acc) = [@null | acc]

        case types do
          unquote(dispatch)
        end
      end

      defp unquote(extension)(<<-1::int32(), rest::binary>>, sub_oids, sub_types, acc) do
        unquote(extension)(rest, sub_oids, sub_types, [@null | acc])
      end

      defp unquote(extension)(<<>>, _sub_oid, _sub_types, acc) do
        acc
      end

      defp unquote(extension)(
             <<-1::int32(), rest::binary>>,
             _sub_oids,
             _sub_types,
             oids,
             types,
             n,
             acc
           ) do
        decode_tuple(rest, oids, types, n, acc)
      end
    end
  end

  defp split_super({:->, _, [head, body]}) do
    case head do
      [{:when, _, [pattern, sub_oids, sub_types, guard]}] ->
        {pattern, sub_oids, sub_types, guard, body}

      [pattern, sub_oids, sub_types] ->
        {pattern, sub_oids, sub_types, body}
    end
  end

  defp configure(extensions, opts) do
    defaults = Postgrex.Utils.default_extensions(opts)
    Enum.map(extensions ++ defaults, &configure/1)
  end

  defp configure({extension, opts}) do
    state = extension.init(opts)
    matching = extension.matching(state)
    format = extension.format(state)
    {extension, {state, matching, format}}
  end

  defp configure(extension) do
    configure({extension, []})
  end

  defp define_inline(module, config, opts) do
    quoted = [
      directives(config, opts),
      find(config),
      encode(config, opts),
      decode(config, opts)
    ]

    Module.create(module, quoted, Macro.Env.location(__ENV__))
  end
end