lib/kayrock/generate.ex

defmodule Kayrock.Generate do
  @moduledoc """
  Macros for generating modules for the protocol schema
  """

  require Kayrock.Serialize
  require Kayrock.Deserialize

  @doc """
  Builds the quoted definition of the `Kayrock.KafkaSchemaMetadata` module.

  `schema_module` must export `all_apis/0`, `vsn_range/1`, `api_key/1`,
  `req/2`, `rsp/2` and `ec/1`. The generated module defines:

    * `all_apis/0` - the list returned by `schema_module.all_apis/0`
    * `version_range/1` - the `{min, max}` version range of each API
    * `api_key/1` - maps an API name to its numeric key and back
    * `request_schema/2` and `response_schema/2` - one clause per API and
      version in that API's range
    * `min_known_error_code/0`, `max_known_error_code/0` and
      `error_code_to_error/1` - one clause per error code from -1 to 71
  """
  def generate_schema_metadata(schema_module) do
    all_apis = schema_module.all_apis()

    # Calls `build.(api, vsn)` for every version of every API (API order, then
    # ascending version) and returns one flat list of quoted clauses.
    for_each_version = fn build ->
      Enum.flat_map(all_apis, fn api ->
        {min_vsn, max_vsn} = schema_module.vsn_range(api)
        Enum.map(min_vsn..max_vsn, &build.(api, &1))
      end)
    end

    version_ranges =
      for api <- all_apis do
        range = schema_module.vsn_range(api)

        quote do
          def version_range(unquote(api)), do: unquote(range)
        end
      end

    # Two clauses per API so the lookup works in both directions.
    api_keys =
      for api <- all_apis do
        key = schema_module.api_key(api)

        quote do
          def api_key(unquote(api)), do: unquote(key)
          def api_key(unquote(key)), do: unquote(api)
        end
      end

    request_schemas =
      for_each_version.(fn api, vsn ->
        req_schema = schema_module.req(api, vsn)

        quote do
          def request_schema(unquote(api), unquote(vsn)) do
            unquote(req_schema)
          end
        end
      end)

    response_schemas =
      for_each_version.(fn api, vsn ->
        rsp_schema = schema_module.rsp(api, vsn)

        quote do
          def response_schema(unquote(api), unquote(vsn)) do
            unquote(rsp_schema)
          end
        end
      end)

    # Inclusive range of error codes for which an error_code_to_error/1 clause
    # is generated.
    min_known_error_code = -1
    max_known_error_code = 71

    error_codes =
      for error_code <- min_known_error_code..max_known_error_code do
        error = schema_module.ec(error_code)

        quote do
          def error_code_to_error(unquote(error_code)) do
            unquote(error)
          end
        end
      end

    quote do
      defmodule Kayrock.KafkaSchemaMetadata do
        @moduledoc false
        _ = "THIS CODE IS GENERATED BY KAYROCK"

        def all_apis do
          unquote(all_apis)
        end

        unquote_splicing(version_ranges)

        unquote_splicing(api_keys)

        unquote_splicing(request_schemas)

        unquote_splicing(response_schemas)

        def min_known_error_code, do: unquote(min_known_error_code)
        def max_known_error_code, do: unquote(max_known_error_code)

        unquote_splicing(error_codes)
      end
    end
  end

  @doc """
  Builds the quoted top-level `Kayrock.<Api>` module for `api`; the name is
  the camelized API name (e.g. `:api_versions` becomes `Kayrock.ApiVersions`).
  The module body is produced by `build_modules/3`.
  """
  def build_all(api, schema_module) do
    api_module = Module.concat(Kayrock, Macro.camelize(to_string(api)))
    body = build_modules(api, schema_module, api_module)

    quote do
      defmodule unquote(api_module) do
        @api unquote(api)
        @moduledoc """
        Kayrock-generated module for the Kafka `#{@api}` API
        """
        _ = " THIS CODE IS GENERATED BY KAYROCK"
        unquote_splicing(body)
      end
    end
  end

  @doc """
  Builds the flat list of quoted expressions that form the body of the
  `Kayrock.<Api>` module for `api`. It covers every version in
  `schema_module.vsn_range(api)`.

  In order, the list contains:

    * the `@vmin`/`@vmax` attributes
    * the per-version request modules, with their `Kayrock.Request`
      implementations
    * the `get_request_struct/1` clauses
    * the per-version response modules
    * the `deserialize/2` clauses
    * the `request_t`/`response_t` union types built from `modname`
    * `min_vsn/0` and `max_vsn/0`
  """
  def build_modules(api, schema_module, modname) do
    {vmin, vmax} = schema_module.vsn_range(api)
    versions = vmin..vmax

    attributes =
      quote do
        @vmin unquote(vmin)
        @vmax unquote(vmax)
      end

    # The @doc/@spec headers are emitted as standalone expressions placed
    # directly in front of the multi-clause functions they describe.
    getter_header =
      quote do
        @doc "Returns a request struct for this API with the given version"
        @spec get_request_struct(integer) :: request_t
      end

    deserializer_header =
      quote do
        @doc "Deserializes raw wire data for this API with the given version"
        @spec deserialize(integer, binary) :: {response_t, binary}
      end

    min_vsn_def =
      quote do
        @doc "Returns the minimum version of this API supported by Kayrock (#{@vmin})"
        @spec min_vsn :: integer
        def min_vsn, do: unquote(vmin)
      end

    max_vsn_def =
      quote do
        @doc "Returns the maximum version of this API supported by Kayrock (#{@vmax})"
        @spec max_vsn :: integer
        def max_vsn, do: unquote(vmax)
      end

    List.flatten([
      attributes,
      Enum.map(versions, &make_request_module(api, &1, schema_module)),
      getter_header,
      Enum.map(versions, &make_request_getter/1),
      Enum.map(versions, &make_response_module(api, &1, schema_module)),
      deserializer_header,
      Enum.map(versions, &make_response_deserializer/1),
      make_types(modname, vmin, vmax),
      min_vsn_def,
      max_vsn_def
    ])
  end

  @doc """
  Builds the `request_t` and `response_t` union types for the API module
  `root_module`, covering versions `vmin` through `vmax`.

  Returns a two-element list: the quoted request union (with its `@typedoc`)
  followed by the quoted response union. Each union lists
  `root_module.V<n>.Request.t()` (resp. `Response.t()`) for every version,
  highest version first.
  """
  def make_types(root_module, vmin, vmax) do
    # Folds `Root.V<n>.<kind>.t()` for each version into `a | b | ...`.
    union_for = fn kind ->
      vmin..vmax
      |> Enum.map(fn v ->
        struct_module = Module.concat([root_module, "V#{v}", kind])

        quote do
          unquote(struct_module).t()
        end
      end)
      |> Enum.reduce(fn type, union -> {:|, [], [type, union]} end)
    end

    request_union = union_for.(Request)
    response_union = union_for.(Response)

    [
      quote do
        @typedoc "Union type for all request structs for this API"
        @type request_t :: unquote(request_union)
      end,
      quote do
        @typedoc "Union type for all response structs for this API"
        @type response_t :: unquote(response_union)
      end
    ]
  end

  @doc """
  Builds the `get_request_struct/1` clause for version `vsn`, which returns
  an empty `V<vsn>.Request` struct.
  """
  def make_request_getter(vsn) do
    request_struct = modname(vsn, Request)

    quote do
      def get_request_struct(unquote(vsn)) do
        %unquote(request_struct){}
      end
    end
  end

  @doc """
  Builds the `deserialize/2` clause for version `vsn`, which delegates to
  `V<vsn>.Response.deserialize/1`.
  """
  def make_response_deserializer(vsn) do
    response_module = modname(vsn, Response)

    quote do
      def deserialize(unquote(vsn), data) do
        unquote(response_module).deserialize(data)
      end
    end
  end

  @doc """
  Builds the quoted request module for version `vsn` of `api` together with
  its `Kayrock.Request` protocol implementation. Returns a flat list of
  quoted expressions.
  """
  def make_request_module(api, vsn, schema_module) do
    request_schema = schema_module.req(api, vsn)
    request_mod = modname(vsn, Request)
    response_mod = modname(vsn, Response)

    List.flatten([
      generate_request_struct(api, vsn, request_mod, response_mod, request_schema),
      generate_serializer(request_mod)
    ])
  end

  @doc """
  Builds the quoted response module for version `vsn` of `api`. Returns a
  flat list of quoted expressions.
  """
  def make_response_module(api, vsn, schema_module) do
    response_schema = schema_module.rsp(api, vsn)
    response_mod = modname(vsn, Response)

    List.flatten([generate_response_struct(api, vsn, response_mod, response_schema)])
  end

  # Name of a per-version module, e.g. `modname(3, Request)` is `V3.Request`.
  # The result is not prefixed with the API module's name.
  defp modname(vsn, suffix) do
    Module.concat("V#{vsn}", suffix)
  end

  @doc """
  Builds the quoted request module `modname` for version `vsn` of `api`.

  The generated module defines:

    * a struct with one key per `schema` field (arrays default to `[]`,
      everything else to `nil`), followed by `correlation_id` and `client_id`
    * its `t` type
    * `api_key/0`, `api_vsn/0` and `schema/0`
    * `response_deserializer/0`, which delegates to `response_modname`
    * `serialize/1`, which writes the api key, api version, correlation id
      and length-prefixed client id, followed by every schema field in
      schema order
  """
  def generate_request_struct(api, vsn, modname, response_modname, schema) do
    fields =
      Enum.map(schema, fn {name, type} -> {name, default_val(type)} end) ++
        [correlation_id: nil, client_id: nil]

    field_serializers = Enum.map(schema, &field_serializer(&1, :struct))

    header_types = [
      correlation_id:
        quote do
          nil | integer()
        end,
      client_id:
        quote do
          nil | binary()
        end
    ]

    struct_types =
      Enum.map(schema, fn {name, type} -> {name, describe_type(api, name, type)} end) ++
        header_types

    # Some APIs have no fields to serialize; importing Kayrock.Serialize for
    # them would produce an "unused import" warning.
    imports =
      if Enum.empty?(schema) do
        []
      else
        [
          quote do
            import Elixir.Kayrock.Serialize
          end
        ]
      end

    quote do
      defmodule unquote(modname) do
        @vsn unquote(vsn)
        @api unquote(api)
        @schema unquote(schema)
        @moduledoc """
        Kayrock-generated request struct for Kafka `#{@api}` v#{@vsn} API
        messages

        The schema of this API is
        ```
        #{inspect(@schema, pretty: true)}
        ```
        """
        _ = " THIS CODE IS GENERATED BY KAYROCK"
        defstruct unquote(fields)

        unquote_splicing(imports)

        @typedoc """
        Request struct for the Kafka `#{@api}` API v#{@vsn}
        """
        @type t :: %__MODULE__{unquote_splicing(struct_types)}

        @doc "Returns the Kafka API key for this API"
        @spec api_key :: integer
        def api_key, do: Kayrock.KafkaSchemaMetadata.api_key(unquote(api))

        @doc "Returns the API version (#{@vsn}) implemented by this module"
        @spec api_vsn :: integer
        def api_vsn, do: unquote(vsn)

        @doc """
        Returns a function that can be used to deserialize the wire response from the
        broker for this message type
        """
        @spec response_deserializer :: (binary -> {unquote(response_modname).t(), binary})
        def response_deserializer, do: &unquote(response_modname).deserialize/1

        @doc """
        Returns the schema of this message

        See [above](#).
        """
        @spec schema :: term
        def schema, do: unquote(schema)

        @doc "Serialize a message to binary data for transfer to a Kafka broker"
        @spec serialize(t()) :: iodata
        def serialize(%unquote(modname){} = struct) do
          [
            <<api_key()::16, api_vsn()::16, struct.correlation_id::32,
              byte_size(struct.client_id)::16, struct.client_id::binary>>,
            unquote(field_serializers)
          ]
        end
      end
    end
  end

  @doc """
  Returns the quoted typespec for one schema field.

  `api` is the API name and `field` is the field name. Fields nested inside
  a struct or array element are identified by `{parent, name}` tuples. Types
  this function does not recognise are reported on stdout and spec'ed as
  `term()`.

  `bytes` fields named `member_assignment` or `protocol_metadata` get specs
  that also admit the structs the generated code uses for them, at any
  nesting level:

    * the request serializers accept a `Kayrock.MemberAssignment` /
      `Kayrock.GroupProtocolMetadata` struct, or a binary, for such fields
    * the response deserializers decode every `member_assignment` with
      `Kayrock.MemberAssignment.deserialize/1`

  The top-level SyncGroup `member_assignment` keeps its narrower spec.
  """
  def describe_type(:sync_group, :member_assignment, :bytes) do
    quote do
      nil | Kayrock.MemberAssignment.t()
    end
  end

  def describe_type(_, :member_assignment, :bytes), do: member_assignment_spec()
  def describe_type(_, {_parent, :member_assignment}, :bytes), do: member_assignment_spec()
  def describe_type(_, :protocol_metadata, :bytes), do: protocol_metadata_spec()
  def describe_type(_, {_parent, :protocol_metadata}, :bytes), do: protocol_metadata_spec()

  def describe_type(_, _, :bytes) do
    quote do
      nil | bitstring()
    end
  end

  def describe_type(_, _, :nullable_string) do
    quote do
      nil | binary()
    end
  end

  def describe_type(_, _, :string) do
    quote do
      nil | binary()
    end
  end

  def describe_type(_, _, :records) do
    quote do
      nil | Kayrock.MessageSet.t() | Kayrock.RecordBatch.t()
    end
  end

  # Keyword schemas describe nested structs: each child field is described
  # with a `{parent, name}` field identifier and the result is a map type.
  def describe_type(api, field, mapspec) when is_list(mapspec) do
    field_types =
      Enum.map(mapspec, fn {k, v} ->
        {k, describe_type(api, {field, k}, v)}
      end)

    quote do
      %{unquote_splicing(field_types)}
    end
  end

  def describe_type(api, field, {:array, arrayspec}) do
    inner_type = describe_type(api, field, arrayspec)

    quote do
      [unquote(inner_type)]
    end
  end

  def describe_type(_, _, t) when t in [:boolean, :int8, :int16, :int32, :int64] do
    quote do
      nil | integer()
    end
  end

  def describe_type(_, _, t) do
    IO.puts("Unhandled type: #{inspect(t)} will be spec'ed as term()")

    quote do
      term()
    end
  end

  # Spec for `member_assignment` bytes fields other than the top-level
  # SyncGroup one: a raw binary or a Kayrock.MemberAssignment.
  defp member_assignment_spec do
    quote do
      nil | bitstring() | Kayrock.MemberAssignment.t()
    end
  end

  # Spec for `protocol_metadata` bytes fields: a raw binary or a
  # Kayrock.GroupProtocolMetadata struct (any field values).
  defp protocol_metadata_spec do
    quote do
      nil | bitstring() | %Kayrock.GroupProtocolMetadata{}
    end
  end

  @doc """
  Builds the quoted response module `modname` for version `vsn` of `api`.

  The generated module defines:

    * a struct with one key per `schema` field (arrays default to `[]`,
      everything else to `nil`) plus `correlation_id`
    * its `t` type
    * `api_key/0`, `api_vsn/0` and `schema/0`
    * `deserialize/1`, which reads a signed 32-bit correlation id and then
      decodes the schema fields in order through a chain of private
      `deserialize_field/4` clauses built by `generate_field_deserializer/3`,
      returning `{struct, rest}`
  """
  def generate_response_struct(api, vsn, modname, schema) do
    fields =
      Enum.map(schema, fn {name, type} -> {name, default_val(type)} end) ++
        [correlation_id: nil]

    struct_types =
      Enum.map(schema, fn {name, type} -> {name, describe_type(api, name, type)} end) ++
        [
          correlation_id:
            quote do
              integer()
            end
        ]

    # Each generated clause knows which field follows it; deserialize/1
    # starts the chain at the first field.
    {first_field_name, fields_with_next_field} = build_field_zip(schema)

    field_deserializers =
      fields_with_next_field
      |> Enum.map(fn {field_spec, next_name} ->
        generate_field_deserializer(:root, field_spec, next_name)
      end)
      |> List.flatten()

    quote do
      defmodule unquote(modname) do
        @vsn unquote(vsn)
        @api unquote(api)
        @schema unquote(schema)
        @moduledoc """
        Kayrock-generated response struct for Kafka `#{@api}` v#{@vsn} API
        messages

        The schema of this API is
        ```
        #{inspect(@schema, pretty: true)}
        ```
        """
        _ = " THIS CODE IS GENERATED BY KAYROCK"
        defstruct unquote(fields)

        @typedoc """
        Response struct for the Kafka `#{@api}` API v#{@vsn}
        """
        @type t :: %__MODULE__{unquote_splicing(struct_types)}

        import Elixir.Kayrock.Deserialize

        @doc "Returns the Kafka API key for this API"
        @spec api_key :: integer
        def api_key, do: Kayrock.KafkaSchemaMetadata.api_key(unquote(api))

        @doc "Returns the API version (#{@vsn}) implemented by this module"
        @spec api_vsn :: integer
        def api_vsn, do: unquote(vsn)

        @doc """
        Returns the schema of this message

        See [above](#).
        """
        @spec schema :: term
        def schema, do: unquote(schema)

        @doc """
        Deserialize data for this version of this API
        """
        @spec deserialize(binary) :: {t(), binary}
        def deserialize(data) do
          <<correlation_id::32-signed, rest::binary>> = data

          deserialize_field(
            :root,
            unquote(first_field_name),
            %__MODULE__{correlation_id: correlation_id},
            rest
          )
        end

        unquote_splicing(field_deserializers)

        defp deserialize_field(_, nil, acc, rest) do
          {acc, rest}
        end
      end
    end
  end

  # Pairs each schema entry with the name of the field that follows it (the
  # last entry is paired with `nil`), so the generated deserializer can chain
  # from one field to the next. Returns
  # `{first_field_name, [{field_spec, next_field_name}, ...]}`.
  #
  # An empty schema yields `{nil, []}`. The generated code then starts at the
  # terminal `deserialize_field(_, nil, acc, rest)` clause and returns the
  # accumulator untouched. Previously the generator crashed with a MatchError
  # on `[first | rest] = []`.
  defp build_field_zip([]), do: {nil, []}

  defp build_field_zip(schema) do
    [first_field_name | rest_of_fields] = Keyword.keys(schema)
    {first_field_name, Enum.zip(schema, rest_of_fields ++ [nil])}
  end

  # Builds the quoted private `deserialize_field/4` clause(s) that decode one
  # schema field of a response. Every generated clause has the shape
  #
  #     deserialize_field(scope, field_name, acc, data)
  #
  # `scope` is `:root` for top-level fields (as passed by
  # generate_response_struct/4). For fields nested in a struct or array
  # element, it is the name of the enclosing field.
  #
  # A clause decodes its field from `data`, stores it in `acc` under
  # `field_name`, and then calls
  # `deserialize_field(scope, next_field_name, acc, rest)`. The chain ends at
  # the `deserialize_field(_, nil, acc, rest)` clause emitted by
  # generate_response_struct/4, which returns `{acc, rest}`.
  #
  # Returns one quoted clause. For nested structs and arrays of structs it
  # returns a nested list of quoted clauses, which the caller flattens.
  #
  # Clause order is significant:
  #   * The array-of-structs clause has no guard, so the primitive-array
  #     clause must come before it.
  #   * The member_assignment special case presumably has to come before the
  #     primitive clause, if `:bytes` is one of
  #     Kayrock.Deserialize.primitive_types() -- confirm.
  #
  # NOTE(review): nested scopes use only the immediate parent field name, not
  # the full path. Two same-named parent fields with different child schemas
  # in one response would generate overlapping clauses -- confirm no schema
  # does this.

  # `bytes` fields named member_assignment (in any scope) are decoded with
  # Kayrock.MemberAssignment.deserialize/1 instead of being kept as raw bytes.
  def generate_field_deserializer(scope, {:member_assignment, :bytes}, next_field_name) do
    quote do
      defp deserialize_field(unquote(scope), :member_assignment, acc, data) do
        {val, rest} = Kayrock.MemberAssignment.deserialize(data)

        deserialize_field(
          unquote(scope),
          unquote(next_field_name),
          Map.put(acc, :member_assignment, val),
          rest
        )
      end
    end
  end

  # Arrays of primitive types are decoded by deserialize_array/2, which is
  # imported from Kayrock.Deserialize in the generated module. The result is
  # reversed before it is stored.
  # NOTE(review): this assumes deserialize_array/2 returns elements in
  # reverse wire order; confirm.
  def generate_field_deserializer(scope, {field_name, {:array, type}}, next_field_name)
      when type in Kayrock.Deserialize.primitive_types() do
    quote do
      defp deserialize_field(unquote(scope), unquote(field_name), acc, data) do
        {val, rest} = deserialize_array(unquote(type), data)

        deserialize_field(
          unquote(scope),
          unquote(next_field_name),
          Map.put(acc, unquote(field_name), Enum.reverse(val)),
          rest
        )
      end
    end
  end

  # Arrays whose elements are keyword schemas. Emits clauses for the element
  # fields (scoped by this field's name) plus a clause that:
  #   * reads a signed 32-bit element count;
  #   * decodes that many elements, each as a map;
  #   * yields [] for a count <= 0, including the -1 the serializer writes
  #     for nil.
  # Elements are prepended during the reduce and reversed once at the end.
  def generate_field_deserializer(scope, {field_name, {:array, elements_schema}}, next_field_name) do
    {first_field_name, fields_with_next_field} = build_field_zip(elements_schema)

    [
      Enum.map(fields_with_next_field, fn {f, n} ->
        generate_field_deserializer(field_name, f, n)
      end),
      quote do
        defp deserialize_field(unquote(scope), unquote(field_name), acc, data) do
          <<num_elements::32-signed, rest::binary>> = data

          {vals, rest} =
            if num_elements > 0 do
              # this `acc` (the element list) shadows the outer struct
              # accumulator only inside the reducer fn
              Enum.reduce(1..num_elements, {[], rest}, fn _ix, {acc, d} ->
                {val, r} =
                  deserialize_field(unquote(field_name), unquote(first_field_name), %{}, d)

                {[val | acc], r}
              end)
            else
              {[], rest}
            end

          deserialize_field(
            unquote(scope),
            unquote(next_field_name),
            Map.put(acc, unquote(field_name), Enum.reverse(vals)),
            rest
          )
        end
      end
    ]
  end

  # Primitive fields are decoded by deserialize/2, which is imported from
  # Kayrock.Deserialize in the generated module.
  def generate_field_deserializer(scope, {field_name, type}, next_field_name)
      when type in Kayrock.Deserialize.primitive_types() do
    quote do
      defp deserialize_field(unquote(scope), unquote(field_name), acc, data) do
        {val, rest} = deserialize(unquote(type), data)

        deserialize_field(
          unquote(scope),
          unquote(next_field_name),
          Map.put(acc, unquote(field_name), val),
          rest
        )
      end
    end
  end

  # `records` fields are a signed 32-bit byte length followed by that many
  # bytes, which are handed to Kayrock.RecordBatch.deserialize/2.
  # Presumably `:records` is not in primitive_types(); otherwise the clause
  # above would shadow this one.
  # NOTE(review): a negative length would make the binary match fail with a
  # MatchError -- confirm a null record set never reaches this code.
  def generate_field_deserializer(scope, {field_name, :records}, next_field_name) do
    quote do
      defp deserialize_field(unquote(scope), unquote(field_name), acc, data) do
        <<msg_set_size::32-signed, msg_set_data::size(msg_set_size)-binary, rest::bits>> = data

        val = Elixir.Kayrock.RecordBatch.deserialize(msg_set_size, msg_set_data)

        deserialize_field(
          unquote(scope),
          unquote(next_field_name),
          Map.put(acc, unquote(field_name), val),
          rest
        )
      end
    end
  end

  # Nested struct fields (keyword schema). Emits clauses for the child fields,
  # scoped by this field's name, plus a clause that decodes the children into
  # a map stored under `field_name`.
  def generate_field_deserializer(scope, {field_name, struct_schema}, next_field_name)
      when is_list(struct_schema) do
    {first_field_name, fields_with_next_field} = build_field_zip(struct_schema)

    [
      Enum.map(fields_with_next_field, fn {f, n} ->
        generate_field_deserializer(field_name, f, n)
      end),
      quote do
        defp deserialize_field(unquote(scope), unquote(field_name), acc, data) do
          {val, rest} =
            deserialize_field(unquote(field_name), unquote(first_field_name), %{}, data)

          deserialize_field(
            unquote(scope),
            unquote(next_field_name),
            Map.put(acc, unquote(field_name), val),
            rest
          )
        end
      end
    ]
  end

  # Builds a quoted `defimpl Elixir.Kayrock.Request` for the request struct
  # module `modname`. The implementation forwards serialize/1, api_vsn/1 and
  # response_deserializer/1 to the same-named functions on `modname`.
  def generate_serializer(modname) do
    quote do
      defimpl Elixir.Kayrock.Request, for: unquote(modname) do
        # Any exception raised while serializing is re-raised as
        # Kayrock.InvalidRequestError built from `{exception, struct}`, keeping
        # the original stacktrace.
        def serialize(%unquote(modname){} = struct) do
          try do
            unquote(modname).serialize(struct)
          rescue
            e ->
              reraise(Kayrock.InvalidRequestError, {e, struct}, __STACKTRACE__)
          end
        end

        def api_vsn(%unquote(modname){}) do
          unquote(modname).api_vsn()
        end

        def response_deserializer(%unquote(modname){}) do
          unquote(modname).response_deserializer()
        end
      end
    end
  end

  ######################################################################
  # SPECIAL CASES
  #
  # field_serializer/2 builds the quoted expression that serializes one
  # request schema field.
  #
  # `varname` names the variable (created with Macro.var/2 in this module's
  # context) holding the map or struct the field is read from: `:struct` for
  # top-level request fields and `:v` for fields of nested structs and array
  # elements.
  #
  # Every clause reads the field with Map.fetch!/2, so a missing key raises
  # KeyError. Unqualified serialize/2 and serialize_array/2 calls rely on the
  # `import Elixir.Kayrock.Serialize` that generate_request_struct/5 adds to
  # the generated module.
  #
  # Clause order matters: the named special cases must come before the
  # generic clauses (presumably `:bytes` is in
  # Kayrock.Serialize.primitive_types()).

  # protocol metadata for JoinGroup request
  # this is 'bytes' in the spec but it is expected to be the serialization of a
  # ProtocolMetadata message defined in the consumer group API, so we handle
  # both cases here
  defp field_serializer({:protocol_metadata, :bytes}, varname) do
    quote do
      case Map.fetch!(unquote(Macro.var(varname, __MODULE__)), :protocol_metadata) do
        %Kayrock.GroupProtocolMetadata{} = m ->
          Kayrock.Serialize.serialize(
            :iodata_bytes,
            Kayrock.GroupProtocolMetadata.serialize(m)
          )

        b when is_binary(b) ->
          Kayrock.Serialize.serialize(:bytes, b)
      end
    end
  end

  # member assignment for SyncGroup request
  # this is 'bytes' in the spec but it is expected to be the serialization of a
  # MemberAssignment message defined in the consumer group API, so we handle
  # both cases here
  defp field_serializer({:member_assignment, :bytes}, varname) do
    quote do
      case Map.fetch!(unquote(Macro.var(varname, __MODULE__)), :member_assignment) do
        %Kayrock.MemberAssignment{} = m ->
          Kayrock.Serialize.serialize(
            :iodata_bytes,
            Kayrock.MemberAssignment.serialize(m)
          )

        b when is_binary(b) ->
          Kayrock.Serialize.serialize(:bytes, b)
      end
    end
  end

  # END SPECIAL CASES
  ######################################################################

  # Primitive types: serialize/2 applied to the field's value.
  defp field_serializer({name, type}, varname) when type in Kayrock.Serialize.primitive_types() do
    quote do
      serialize(unquote(type), Map.fetch!(unquote(Macro.var(varname, __MODULE__)), unquote(name)))
    end
  end

  # Record sets are serialized through the Kayrock.Request protocol
  # implementation of the field's value.
  defp field_serializer({name, :records}, varname) do
    quote do
      Elixir.Kayrock.Request.serialize(
        Map.fetch!(unquote(Macro.var(varname, __MODULE__)), unquote(name))
      )
    end
  end

  # Arrays of primitive types: serialize_array/2.
  defp field_serializer({name, {:array, type}}, varname)
       when type in Kayrock.Serialize.primitive_types() do
    quote do
      serialize_array(
        unquote(type),
        Map.fetch!(unquote(Macro.var(varname, __MODULE__)), unquote(name))
      )
    end
  end

  # Arrays of structs:
  #   * nil is written as length -1, and [] as length 0;
  #   * a non-empty list is a signed 32-bit count followed by each element's
  #     fields, which are read from the comprehension variable `v`.
  # Any other value raises CaseClauseError.
  defp field_serializer({name, {:array, el}}, varname) when is_list(el) do
    subfield_serializers = Enum.map(el, &field_serializer(&1, :v))

    quote do
      case Map.fetch!(unquote(Macro.var(varname, __MODULE__)), unquote(name)) do
        nil ->
          <<-1::32-signed>>

        [] ->
          <<0::32-signed>>

        vals when is_list(vals) ->
          [
            <<length(vals)::32-signed>>,
            for v <- vals do
              unquote(subfield_serializers)
            end
          ]
      end
    end
  end

  # Arrays of arrays of primitive types. This uses the same nil / [] /
  # count-prefix framing as above; each inner array goes through
  # serialize_array/2.
  defp field_serializer({name, {:array, {:array, type}}}, varname)
       when type in Kayrock.Serialize.primitive_types() do
    quote do
      case Map.fetch!(unquote(Macro.var(varname, __MODULE__)), unquote(name)) do
        nil ->
          <<-1::32-signed>>

        [] ->
          <<0::32-signed>>

        vals when is_list(vals) ->
          [
            <<length(vals)::32-signed>>,
            for v <- vals do
              serialize_array(unquote(type), v)
            end
          ]
      end
    end
  end

  # Nested structs: binds `v` to the nested value, then serializes its fields
  # from it.
  # NOTE(review): this rebinds `v` inside the enclosing iodata list. Confirm
  # that sibling fields serialized after this one still read the parent
  # value of `v`.
  defp field_serializer({name, schema}, varname) when is_list(schema) do
    subfield_serializers = Enum.map(schema, &field_serializer(&1, :v))

    quote do
      v = Map.fetch!(unquote(Macro.var(varname, __MODULE__)), unquote(name))
      unquote(subfield_serializers)
    end
  end

  # Default struct value for a schema field: `[]` for arrays, `nil` otherwise.
  defp default_val(type) do
    case type do
      {:array, _} -> []
      _ -> nil
    end
  end
end