Skip to main content

lib/pb/wire.ex

defmodule PB.Wire do
  @moduledoc false

  # Low-level protobuf wire-format helpers: tag/varint coding, raw field
  # reading and skipping, group (start/end tag) handling and scalar
  # reinterpretation. Every failure is reported through malformed/2 as
  # `{:error, {:malformed, reason, details}}`.

  import Bitwise

  @max_field_number 0x1FFFFFFF

  # Mirrors the embedded-message cap in PB.Runtime.Decoder.Scan (@max_decode_depth):
  # nested groups recurse here at the wire layer before the decoder's depth check
  # ever runs, so adversarially nested start-group tags need their own guard.
  @max_group_depth 100

  # Wire type used to encode a descriptor field type:
  # 0 = varint, 1 = 8 bytes, 2 = length-delimited, 3 = start-group, 5 = 4 bytes.
  def wire_type(:TYPE_DOUBLE), do: 1
  def wire_type(:TYPE_FLOAT), do: 5
  def wire_type(:TYPE_INT64), do: 0
  def wire_type(:TYPE_UINT64), do: 0
  def wire_type(:TYPE_INT32), do: 0
  def wire_type(:TYPE_FIXED64), do: 1
  def wire_type(:TYPE_FIXED32), do: 5
  def wire_type(:TYPE_BOOL), do: 0
  def wire_type(:TYPE_STRING), do: 2
  def wire_type(:TYPE_MESSAGE), do: 2
  def wire_type(:TYPE_BYTES), do: 2
  def wire_type(:TYPE_UINT32), do: 0
  def wire_type(:TYPE_ENUM), do: 0
  def wire_type(:TYPE_SFIXED32), do: 5
  def wire_type(:TYPE_SFIXED64), do: 1
  def wire_type(:TYPE_SINT32), do: 0
  def wire_type(:TYPE_SINT64), do: 0

  def wire_type(:TYPE_GROUP), do: 3

  # True for every type whose wire type is 0, 1 or 5, i.e. all types that are
  # neither length-delimited nor groups.
  def packable_type?(t),
    do:
      t in [
        :TYPE_DOUBLE,
        :TYPE_FLOAT,
        :TYPE_INT64,
        :TYPE_UINT64,
        :TYPE_INT32,
        :TYPE_UINT32,
        :TYPE_FIXED64,
        :TYPE_FIXED32,
        :TYPE_SFIXED64,
        :TYPE_SFIXED32,
        :TYPE_BOOL,
        :TYPE_ENUM,
        :TYPE_SINT32,
        :TYPE_SINT64
      ]

  # :ok for field numbers in 1..0x1FFFFFFF, otherwise :invalid_field_number.
  def validate_field_number(fnum) when fnum >= 1 and fnum <= @max_field_number, do: :ok
  def validate_field_number(_fnum), do: malformed(:invalid_field_number)

  # :ok for wire types 0..5, otherwise :invalid_wire_type.
  def validate_wire_type(wtype) when wtype in 0..5, do: :ok
  def validate_wire_type(_wtype), do: malformed(:invalid_wire_type)

  # Tag = field number shifted left by 3, with the wire type in the low 3 bits.
  def encode_tag(fnum, wtype), do: encode_varint(fnum <<< 3 ||| wtype)

  # Base-128 varint, least-significant 7-bit group first. Negative integers
  # match no clause.
  def encode_varint(v) when v >= 0 and v < 128, do: <<v::8>>

  def encode_varint(v) when v >= 128 do
    <<1::1, v::7, encode_varint(v >>> 7)::binary>>
  end

  # Decodes a varint of at most 10 bytes whose value fits in 64 bits and
  # returns {:ok, value, rest}. Over-long (non-canonical) encodings are accepted.
  def decode_varint(bin), do: decode_varint_acc(bin, 0, 0)

  # Same as decode_varint/1, but also rejects encodings that differ from the
  # canonical (shortest) varint of the decoded value.
  def decode_tag(bin) do
    with {:ok, value, rest} <- decode_varint(bin),
         :ok <- validate_canonical_varint(consumed(bin, rest), value) do
      {:ok, value, rest}
    end
  end

  defp decode_varint_acc(<<0::1, b::7, rest::binary>>, acc, shift) when shift < 63,
    do: {:ok, acc ||| b <<< shift, rest}

  # The 10th byte (shift 63) may only contribute bit 63, so its payload must be 0 or 1.
  defp decode_varint_acc(<<0::1, b::7, rest::binary>>, acc, 63) when b <= 1,
    do: {:ok, acc ||| b <<< 63, rest}

  defp decode_varint_acc(<<0::1, _b::7, _rest::binary>>, _acc, 63),
    do: malformed(:varint_overflow)

  defp decode_varint_acc(<<1::1, b::7, rest::binary>>, acc, shift) when shift < 63,
    do: decode_varint_acc(rest, acc ||| b <<< shift, shift + 7)

  # A continuation bit on the 10th byte would need an 11th byte.
  defp decode_varint_acc(<<1::1, _b::7, _rest::binary>>, _acc, _shift),
    do: malformed(:varint_overflow)

  defp decode_varint_acc(<<>>, _acc, _shift), do: malformed(:truncated_varint)

  # Reads one field value for `wtype` and returns {:ok, value, rest}. Varints are
  # decoded to integers. Fixed-width values and length-delimited payloads are
  # returned as raw bytes (length prefix stripped). Groups are not supported here.
  def read_wire(0, bin), do: decode_varint(bin)
  def read_wire(1, <<v::binary-size(8), r::binary>>), do: {:ok, v, r}
  def read_wire(1, _bin), do: malformed(:truncated_fixed64)

  def read_wire(2, bin) do
    with {:ok, len, rest} <- decode_varint(bin) do
      read_length_delimited(len, rest)
    end
  end

  def read_wire(3, _bin),
    do: malformed(:unsupported_group)

  def read_wire(4, _bin),
    do: malformed(:unsupported_group)

  def read_wire(5, <<v::binary-size(4), r::binary>>), do: {:ok, v, r}
  def read_wire(5, _bin), do: malformed(:truncated_fixed32)
  def read_wire(_wtype, _bin), do: malformed(:invalid_wire_type)

  # A varint is canonical when it is byte-identical to the re-encoding of its value.
  defp validate_canonical_varint(raw, value) do
    if raw == encode_varint(value) do
      :ok
    else
      malformed(:noncanonical_varint)
    end
  end

  # Captures an unknown field's encoded value bytes verbatim (groups included,
  # since the field number is known).
  def read_unknown_field(field_number, wire_type, bin) do
    read_wire_raw(field_number, wire_type, bin)
  end

  # Reads a message-typed field and returns {:ok, body, raw, rest}.
  # :length_prefixed -> raw is the length prefix plus body.
  # :delimited (group) -> raw is the body plus the matching end-group tag.
  def read_message_field(_field_number, :length_prefixed, bin) do
    with {:ok, raw_body, rest} <- read_wire_raw(2, bin),
         {:ok, body, <<>>} <- read_wire(2, raw_body) do
      {:ok, body, raw_body, rest}
    end
  end

  def read_message_field(field_number, :delimited, bin)
      when is_integer(field_number) and field_number > 0 do
    read_group_field(field_number, bin)
  end

  # Encodes a message body as iodata. :delimited appends only the end-group
  # tag (the start tag is presumably written by the caller — confirm);
  # :length_prefixed prepends the varint byte length.
  def encode_message_value(field_number, :delimited, body) do
    [body, encode_tag(field_number, 4)]
  end

  def encode_message_value(_field_number, :length_prefixed, body) do
    [encode_varint(IO.iodata_length(body)), body]
  end

  # Like read_wire, but returns the value's encoded bytes exactly as they appear
  # in the input: {:ok, raw, rest}. Groups (wire type 3) require a field number
  # so the matching end tag can be found.
  def read_wire_raw(wtype, bin), do: read_wire_raw(nil, wtype, bin)

  def read_wire_raw(_field_number, 0, bin) do
    with {:ok, _value, rest} <- decode_varint(bin) do
      {:ok, consumed(bin, rest), rest}
    end
  end

  def read_wire_raw(_field_number, 1, <<raw::binary-size(8), rest::binary>>),
    do: {:ok, raw, rest}

  def read_wire_raw(_field_number, 1, _bin), do: malformed(:truncated_fixed64)

  def read_wire_raw(_field_number, 2, bin) do
    with {:ok, len, rest_after_len} <- decode_varint(bin),
         {:ok, _value, rest} <- read_length_delimited(len, rest_after_len) do
      {:ok, consumed(bin, rest), rest}
    end
  end

  def read_wire_raw(field_number, 3, bin) when is_integer(field_number) do
    with {:ok, _body, raw, rest} <- read_group_field(field_number, bin) do
      {:ok, raw, rest}
    end
  end

  def read_wire_raw(_field_number, 3, _bin),
    do: malformed(:unsupported_group)

  def read_wire_raw(_field_number, 4, _bin),
    do: malformed(:unexpected_end_group)

  def read_wire_raw(_field_number, 5, <<raw::binary-size(4), rest::binary>>),
    do: {:ok, raw, rest}

  def read_wire_raw(_field_number, 5, _bin), do: malformed(:truncated_fixed32)
  def read_wire_raw(_field_number, _wtype, _bin), do: malformed(:invalid_wire_type)

  # Skips one field value and returns {:ok, rest}. Groups are not supported here.
  def skip_wire(0, bin) do
    with {:ok, _value, rest} <- decode_varint(bin), do: {:ok, rest}
  end

  def skip_wire(1, <<_::64, r::binary>>), do: {:ok, r}
  def skip_wire(1, _bin), do: malformed(:truncated_fixed64)

  def skip_wire(2, bin) do
    with {:ok, len, rest} <- decode_varint(bin),
         {:ok, _value, rest} <- read_length_delimited(len, rest) do
      {:ok, rest}
    end
  end

  def skip_wire(3, _bin),
    do: malformed(:unsupported_group)

  def skip_wire(4, _bin),
    do: malformed(:unsupported_group)

  def skip_wire(5, <<_::32, r::binary>>), do: {:ok, r}
  def skip_wire(5, _bin), do: malformed(:truncated_fixed32)
  def skip_wire(_wtype, _bin), do: malformed(:invalid_wire_type)

  # Reads a group whose start tag has already been consumed. `bin` begins at
  # the first byte of the group body. Returns {:ok, body, raw, rest}, where raw
  # is the body followed by the matching end-group tag.
  defp read_group_field(field_number, bin), do: read_group_field(field_number, bin, 0)

  defp read_group_field(_field_number, _bin, depth) when depth >= @max_group_depth do
    malformed(:max_group_depth_exceeded)
  end

  defp read_group_field(field_number, bin, depth) do
    read_group_body(field_number, bin, bin, depth)
  end

  # `start` is the input at the first body byte; `bin` is what is still unread.
  # A group's bytes are contiguous in the input, so body/raw are sliced from
  # `start` once the end tag is found. The slices are not rebuilt from the
  # pieces, which would copy the data again at every nesting level.
  defp read_group_body(_field_number, _start, <<>>, _depth), do: malformed(:truncated_group)

  defp read_group_body(field_number, start, bin, depth) do
    with {:ok, tag_val, rest_after_tag} <- decode_tag(bin),
         fnum = tag_val >>> 3,
         wtype = tag_val &&& 0x07,
         :ok <- validate_field_number(fnum),
         :ok <- validate_wire_type(wtype) do
      cond do
        wtype == 4 and fnum == field_number ->
          {:ok, consumed(start, bin), consumed(start, rest_after_tag), rest_after_tag}

        wtype == 4 ->
          malformed(:unexpected_end_group, %{
            expected_field_number: field_number,
            got_field_number: fnum
          })

        wtype == 3 ->
          with {:ok, _body, _raw, rest} <- read_group_field(fnum, rest_after_tag, depth + 1) do
            read_group_body(field_number, start, rest, depth)
          end

        true ->
          with {:ok, _raw, rest} <- read_wire_raw(fnum, wtype, rest_after_tag) do
            read_group_body(field_number, start, rest, depth)
          end
      end
    end
  end

  # Splits `len` bytes off the front of `rest`: {:ok, value, remaining}.
  def read_length_delimited(len, rest) when len <= byte_size(rest) do
    <<value::binary-size(len), rest::binary>> = rest
    {:ok, value, rest}
  end

  def read_length_delimited(_len, _rest), do: malformed(:truncated_length_delimited)

  # Little-endian IEEE-754 encoding. :infinity, :negative_infinity and :nan
  # stand in for the non-finite values the bit syntax cannot produce.
  def encode_float64(:infinity), do: <<0, 0, 0, 0, 0, 0, 240, 127>>
  def encode_float64(:negative_infinity), do: <<0, 0, 0, 0, 0, 0, 240, 255>>
  def encode_float64(:nan), do: <<0, 0, 0, 0, 0, 0, 248, 127>>
  def encode_float64(v), do: <<v::little-float-64>>

  def encode_float32(:infinity), do: <<0, 0, 128, 127>>
  def encode_float32(:negative_infinity), do: <<0, 0, 128, 255>>
  def encode_float32(:nan), do: <<0, 0, 192, 127>>
  def encode_float32(v), do: <<v::little-float-32>>

  # Inverse of encode_float64/1. Any NaN bit pattern (all-ones exponent,
  # non-zero fraction) decodes to :nan.
  def decode_float64(<<0, 0, 0, 0, 0, 0, 240, 127>>), do: :infinity
  def decode_float64(<<0, 0, 0, 0, 0, 0, 240, 255>>), do: :negative_infinity

  def decode_float64(<<_::binary-size(8)>> = raw) do
    <<i::little-unsigned-64>> = raw
    exp = i >>> 52 &&& 0x7FF
    frac = i &&& 0xFFFFFFFFFFFFF

    if exp == 0x7FF and frac != 0 do
      :nan
    else
      <<v::little-float-64>> = raw
      v
    end
  end

  # Inverse of encode_float32/1, with the same NaN handling.
  def decode_float32(<<0, 0, 128, 127>>), do: :infinity
  def decode_float32(<<0, 0, 128, 255>>), do: :negative_infinity

  def decode_float32(<<_::binary-size(4)>> = raw) do
    <<i::little-unsigned-32>> = raw
    exp = i >>> 23 &&& 0xFF
    frac = i &&& 0x7FFFFF

    if exp == 0xFF and frac != 0 do
      :nan
    else
      <<v::little-float-32>> = raw
      v
    end
  end

  # Reinterprets the low 32 / 64 bits of an integer as two's-complement signed.
  def signed32(raw) do
    <<v::signed-32>> = <<raw::32>>
    v
  end

  def signed64(raw) do
    <<v::signed-64>> = <<raw::64>>
    v
  end

  # Decoders for raw little-endian fixed-width payloads.
  def fixed32_u(<<v::little-unsigned-32>>), do: v
  def fixed32_s(<<v::little-signed-32>>), do: v
  def fixed64_u(<<v::little-unsigned-64>>), do: v
  def fixed64_s(<<v::little-signed-64>>), do: v

  def uint32(raw), do: raw &&& 0xFFFFFFFF
  # Non-negative values pass through unchanged (not masked to 64 bits).
  def to_unsigned64(v) when v >= 0, do: v
  def to_unsigned64(v), do: v &&& 0xFFFFFFFFFFFFFFFF
  # ZigZag decode: 0 -> 0, 1 -> -1, 2 -> 1, 3 -> -2, ...
  def zigzag(raw), do: bxor(raw >>> 1, -(raw &&& 1))
  def zigzag32(raw), do: raw |> uint32() |> zigzag()

  def malformed(reason, details \\ %{}), do: {:error, {:malformed, reason, details}}

  # Prefix of `bin` consumed before reaching `rest`, where `rest` must be a
  # suffix of `bin`. Returned as a sub-binary of `bin`.
  defp consumed(bin, rest), do: binary_part(bin, 0, byte_size(bin) - byte_size(rest))
end