Skip to main content

lib/http/web_socket/handshake.ex

defmodule HTTP.WebSocket.Handshake do
  @moduledoc false

  @guid "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"
  @max_head_bytes 64 * 1024
  @required_headers ~w(host upgrade connection sec-websocket-key sec-websocket-version)

  @spec build_request(URI.t(), [String.t()], [{String.t(), String.t()}], String.t()) ::
          {:ok, binary()}
  def build_request(%URI{} = uri, protocols, headers, key) when is_binary(key) do
    headers =
      headers
      |> reject_required_headers()
      |> Kernel.++(required_headers(uri, key))
      |> maybe_add_protocols(protocols)

    request =
      [
        "GET ",
        request_target(uri),
        " HTTP/1.1\r\n",
        Enum.map(headers, fn {name, value} -> [name, ": ", value, "\r\n"] end),
        "\r\n"
      ]
      |> IO.iodata_to_binary()

    {:ok, request}
  end

  @spec parse_response(binary()) ::
          {:ok, non_neg_integer(), HTTP.Headers.t(), binary()}
          | {:more, binary()}
          | {:error, term()}
  def parse_response(buffer) when is_binary(buffer) do
    case :binary.match(buffer, "\r\n\r\n") do
      {index, 4} when index <= @max_head_bytes ->
        head = binary_part(buffer, 0, index)
        rest = binary_part(buffer, index + 4, byte_size(buffer) - index - 4)

        with {:ok, status, headers} <- parse_head(head) do
          {:ok, status, headers, rest}
        end

      {index, 4} when index > @max_head_bytes ->
        {:error, :headers_too_large}

      :nomatch when byte_size(buffer) > @max_head_bytes ->
        {:error, :headers_too_large}

      :nomatch ->
        {:more, buffer}
    end
  end

  @spec accept_key(String.t()) :: String.t()
  def accept_key(key) when is_binary(key) do
    :sha
    |> :crypto.hash(key <> @guid)
    |> Base.encode64()
  end

  @spec validate_response(
          non_neg_integer(),
          HTTP.Headers.t() | [{String.t(), String.t()}],
          String.t(),
          [
            String.t()
          ]
        ) ::
          {:ok, %{protocol: String.t(), extensions: String.t()}} | {:error, term()}
  def validate_response(status, headers, key, requested_protocols) do
    headers = to_headers(headers)

    with :ok <- validate_status(status),
         :ok <- validate_header_token(headers, "upgrade", "websocket"),
         :ok <- validate_header_token(headers, "connection", "upgrade"),
         :ok <- validate_accept(headers, key),
         {:ok, protocol} <- validate_protocol(headers, requested_protocols),
         {:ok, extensions} <- validate_extensions(headers) do
      {:ok, %{protocol: protocol, extensions: extensions}}
    end
  end

  defp request_target(%URI{path: path, query: nil}), do: path_or_root(path)
  defp request_target(%URI{path: path, query: query}), do: path_or_root(path) <> "?" <> query

  defp path_or_root(nil), do: "/"
  defp path_or_root(""), do: "/"
  defp path_or_root(path), do: path

  defp required_headers(uri, key) do
    [
      {"Host", host_header(uri)},
      {"Upgrade", "websocket"},
      {"Connection", "Upgrade"},
      {"Sec-WebSocket-Key", key},
      {"Sec-WebSocket-Version", "13"}
    ]
  end

  defp host_header(%URI{host: host, port: nil}), do: host

  defp host_header(%URI{host: host, port: port, scheme: scheme}) do
    if default_port?(scheme, port), do: host, else: host <> ":" <> Integer.to_string(port)
  end

  defp default_port?("ws", 80), do: true
  defp default_port?("wss", 443), do: true
  defp default_port?(_scheme, _port), do: false

  defp reject_required_headers(headers) do
    Enum.reject(headers, fn {name, _value} ->
      String.downcase(name) in @required_headers
    end)
  end

  defp maybe_add_protocols(headers, []), do: headers

  defp maybe_add_protocols(headers, protocols) do
    headers ++ [{"Sec-WebSocket-Protocol", Enum.join(protocols, ", ")}]
  end

  defp parse_head(head) do
    case String.split(head, "\r\n") do
      [status_line | header_lines] ->
        with {:ok, status} <- parse_status_line(status_line),
             {:ok, headers} <- parse_headers(header_lines) do
          {:ok, status, headers}
        end

      _ ->
        {:error, :invalid_http_response}
    end
  end

  defp parse_status_line(line) do
    case String.split(line, " ", parts: 3) do
      ["HTTP/" <> _version, status_code | _reason] ->
        case Integer.parse(status_code) do
          {status, ""} -> {:ok, status}
          _ -> {:error, :invalid_status_line}
        end

      _ ->
        {:error, :invalid_status_line}
    end
  end

  defp parse_headers(lines) do
    lines
    |> Enum.reduce_while({:ok, []}, fn line, {:ok, acc} ->
      case String.split(line, ":", parts: 2) do
        [name, value] -> {:cont, {:ok, [{String.trim(name), String.trim(value)} | acc]}}
        _ -> {:halt, {:error, :invalid_header}}
      end
    end)
    |> case do
      {:ok, headers} -> {:ok, headers |> Enum.reverse() |> HTTP.Headers.new()}
      {:error, _reason} = error -> error
    end
  end

  defp validate_status(101), do: :ok
  defp validate_status(status), do: {:error, {:unexpected_status, status}}

  defp validate_header_token(headers, name, expected) do
    values =
      headers
      |> HTTP.Headers.get_all(name)
      |> Enum.flat_map(&split_header_tokens/1)

    if expected in values do
      :ok
    else
      {:error, {:missing_header_token, name, expected}}
    end
  end

  defp split_header_tokens(value) do
    value
    |> String.split(",")
    |> Enum.map(&String.trim/1)
    |> Enum.map(&String.downcase/1)
  end

  defp validate_accept(headers, key) do
    case HTTP.Headers.get(headers, "sec-websocket-accept") do
      nil -> {:error, :missing_accept}
      value -> if value == accept_key(key), do: :ok, else: {:error, :invalid_accept}
    end
  end

  defp validate_protocol(headers, requested_protocols) do
    selected = HTTP.Headers.get(headers, "sec-websocket-protocol")

    cond do
      is_nil(selected) ->
        {:ok, ""}

      selected in requested_protocols ->
        {:ok, selected}

      true ->
        {:error, {:unexpected_protocol, selected}}
    end
  end

  defp validate_extensions(headers) do
    case HTTP.Headers.get(headers, "sec-websocket-extensions") do
      nil -> {:ok, ""}
      "" -> {:ok, ""}
      value -> {:error, {:unsupported_extensions, value}}
    end
  end

  defp to_headers(%HTTP.Headers{} = headers), do: headers
  defp to_headers(headers) when is_list(headers), do: HTTP.Headers.new(headers)
end