lib/bandit/websocket/permessage_deflate.ex

defmodule Bandit.WebSocket.PerMessageDeflate do
  @moduledoc false
  # Support for per-message deflate extension, per RFC7692§7
  #
  # This module covers two concerns:
  #
  # * Negotiation (`negotiate/1`): picks the first acceptable `permessage-deflate` offer out of
  #   the extensions requested by the client and builds a context for it
  # * Compression (`inflate/2` & `deflate/2`): (de)compresses message payloads using the zlib
  #   streams held within a negotiated context

  @typedoc "Encapsulates the state of a WebSocket permessage-deflate context"
  @type t :: %__MODULE__{
          server_no_context_takeover: boolean(),
          client_no_context_takeover: boolean(),
          server_max_window_bits: 8..15,
          client_max_window_bits: 8..15,
          inflate_context: :zlib.zstream(),
          deflate_context: :zlib.zstream()
        }

  defstruct server_no_context_takeover: false,
            client_no_context_takeover: false,
            server_max_window_bits: 15,
            client_max_window_bits: 15,
            inflate_context: nil,
            deflate_context: nil

  # The complete set of parameter names we understand. Any offer containing a parameter outside
  # of this list is declined
  @valid_params ~w[server_no_context_takeover client_no_context_takeover server_max_window_bits client_max_window_bits]

  # Attempts to negotiate the permessage-deflate extension.
  #
  # `requested_extensions` is a proplist of `{extension_name, params}` tuples, where `params` is
  # itself a proplist of `{param_name, value}` tuples (`value` is `true` for a parameter given
  # without a value). Offers are considered in order and the first valid one wins.
  #
  # Returns `{context, ["permessage-deflate": accepted_params]}` on success, or `{nil, []}` if
  # no acceptable offer was present
  @spec negotiate(list()) :: {t() | nil, keyword()}
  def negotiate(requested_extensions) do
    :proplists.get_all_values("permessage-deflate", requested_extensions)
    |> Enum.find_value(&do_negotiate/1)
    |> case do
      nil -> {nil, []}
      params -> {init(params), "permessage-deflate": params}
    end
  end

  # Evaluates a single offer, returning the accepted parameters as a keyword list, or nil if
  # the offer is not acceptable. Note that an offer without any parameters yields `[]`, which is
  # truthy and therefore still counts as a successful match in `Enum.find_value/2`
  defp do_negotiate(params) do
    with params <- normalize_params(params),
         true <- validate_params(params) do
      resolve_params(params)
    else
      _ -> nil
    end
  end

  # Converts window size parameters to integers where possible. A bare `client_max_window_bits`
  # is treated as 15. A bare `server_max_window_bits` is passed through as `true` so that it
  # subsequently fails range validation
  defp normalize_params(params) do
    params
    |> Enum.map(fn
      {"server_max_window_bits", true} -> {"server_max_window_bits", true}
      {"server_max_window_bits", value} -> {"server_max_window_bits", parse(value)}
      {"client_max_window_bits", true} -> {"client_max_window_bits", 15}
      {"client_max_window_bits", value} -> {"client_max_window_bits", parse(value)}
      value -> value
    end)
  end

  # Returns the integer represented by `value` if (and only if) the whole string is an integer.
  # Anything else (including values with trailing garbage such as "10abc") is returned as-is so
  # that validation rejects the offer instead of this function raising on client-supplied input
  defp parse(value) do
    case Integer.parse(value) do
      {int_value, ""} -> int_value
      _ -> value
    end
  end

  # An offer is valid only if it contains no unknown parameters, no repeated parameters, and
  # every parameter which is present carries an acceptable value
  defp validate_params(params) do
    no_invalid_params = params |> :proplists.split(@valid_params) |> elem(1) == []
    no_repeat_params = params |> :proplists.get_keys() |> length() == length(params)

    # Non-integer window sizes (unparseable strings, or `true`) are never members of 8..15
    no_invalid_values =
      :proplists.get_value("server_no_context_takeover", params) in [:undefined, true] &&
        :proplists.get_value("client_no_context_takeover", params) in [:undefined, true] &&
        :proplists.get_value("server_max_window_bits", params, 15) in 8..15 &&
        :proplists.get_value("client_max_window_bits", params, 15) in 8..15

    no_invalid_params && no_repeat_params && no_invalid_values
  end

  # This is where we finally determine which parameters to accept. Note that we don't convert to
  # atoms until this stage to avoid potential atom exhaustion
  defp resolve_params(params) do
    @valid_params
    |> Enum.flat_map(fn param_name ->
      case :proplists.get_value(param_name, params) do
        :undefined -> []
        param -> [{String.to_existing_atom(param_name), param}]
      end
    end)
  end

  # Builds a context from the accepted parameters and opens one zlib stream per direction.
  # Window bits are passed as negative values, which selects raw deflate (no zlib header or
  # checksum) per the zlib docs linked below
  defp init(params) do
    instance = struct(__MODULE__, params)
    inflate_context = :zlib.open()
    :ok = :zlib.inflateInit(inflate_context, fix_bits(-instance.client_max_window_bits))
    deflate_context = :zlib.open()

    :ok =
      :zlib.deflateInit(
        deflate_context,
        :default,
        :deflated,
        fix_bits(-instance.server_max_window_bits),
        8,
        :default
      )

    %{instance | inflate_context: inflate_context, deflate_context: deflate_context}
  end

  # A raw window size of 8 bits is bumped to 9; see the note about window bits of 8 / -8 at
  # https://www.erlang.org/doc/man/zlib.html#deflateInit-6
  defp fix_bits(-8), do: -9
  defp fix_bits(other), do: other

  # Note that we pass back the context to the caller even though it is unmodified locally

  # Inflates a compressed message payload.
  #
  # The 0x00 0x00 0xFF 0xFF trailer is appended to the payload before it is handed to zlib (this
  # is the inverse of the stripping done in `deflate/2`). If `client_no_context_takeover` was
  # negotiated the stream is reset after every message.
  #
  # Returns `{:ok, inflated_data, context}`, `{:error, message}` if anything raised along the
  # way, or `{:error, :no_compress}` if no context was negotiated
  @spec inflate(binary(), t() | nil) ::
          {:ok, binary(), t()} | {:error, String.t() | :no_compress}
  def inflate(data, %__MODULE__{} = context) do
    inflated_data =
      context.inflate_context
      |> :zlib.inflate(<<data::binary, 0x00, 0x00, 0xFF, 0xFF>>)
      |> IO.iodata_to_binary()

    if context.client_no_context_takeover, do: :zlib.inflateReset(context.inflate_context)
    {:ok, inflated_data, context}
  rescue
    e -> {:error, "Error encountered #{inspect(e)}"}
  end

  def inflate(_data, nil), do: {:error, :no_compress}

  # Deflates a message payload.
  #
  # The data is compressed with a sync flush and the resulting 0x00 0x00 0xFF 0xFF trailer (if
  # present) is stripped from the output. If `server_no_context_takeover` was negotiated the
  # stream is reset after every message.
  #
  # Returns `{:ok, deflated_data, context}`, `{:error, message}` if anything raised along the
  # way, or `{:error, :no_compress}` if no context was negotiated
  @spec deflate(iodata(), t() | nil) ::
          {:ok, binary(), t()} | {:error, String.t() | :no_compress}
  def deflate(data, %__MODULE__{} = context) do
    deflated_data =
      context.deflate_context
      |> :zlib.deflate(data, :sync)
      |> IO.iodata_to_binary()

    # For output shorter than 4 bytes this size is negative, in which case the first clause
    # below simply does not match and the data is returned untouched
    deflated_size = byte_size(deflated_data) - 4

    deflated_data =
      case deflated_data do
        <<deflated_data::binary-size(deflated_size), 0x00, 0x00, 0xFF, 0xFF>> -> deflated_data
        deflated -> deflated
      end

    if context.server_no_context_takeover, do: :zlib.deflateReset(context.deflate_context)
    {:ok, deflated_data, context}
  rescue
    e -> {:error, "Error encountered #{inspect(e)}"}
  end

  def deflate(_data, nil), do: {:error, :no_compress}
end