lib/handshake.ex

defmodule ExWs.Handshake.Errors do
	# Renders a complete, body-less HTTP/1.1 error response and tags it as
	# `{:invalid, response_binary}`. The short `message` is carried in a custom
	# `Error:` header so the client can see which part of the request failed.
	def build(code, message) do
		status_line = ["HTTP/1.1 ", to_string(code), " ", phrase(code), "\r\n"]
		header_lines = ["Error: ", message, "\r\n", "Content-Length: 0\r\n"]

		response = IO.iodata_to_binary([status_line, header_lines, "\r\n"])
		{:invalid, response}
	end

	# Reason phrase for the status line; only 400 has a specific phrase, every
	# other code falls back to a generic one.
	defp phrase(code) do
		case code do
			400 -> "Bad Request"
			_ -> "Server Error"
		end
	end
end

defmodule ExWs.Handshake do
	# Server side of the WebSocket opening handshake: reads the client's HTTP
	# upgrade request line-by-line, validates it, and either answers with
	# "101 Switching Protocols" (accept/2) or a canned 400 response.
	require Logger

	alias __MODULE__.Errors

	# In the test environment a fake socket module is used unless the AB
	# environment variable is "1" (NOTE(review): presumably to run against real
	# sockets for an external test suite — confirm).
	if Mix.env == :test && System.get_env("AB") != "1" do
		@inet ExWs.GenTcpFake
		@gen_tcp ExWs.GenTcpFake
	else
		@inet :inet
		@gen_tcp :gen_tcp
	end

	# Pre-rendered `{:invalid, response}` tuples, built once at compile time.
	@invalid_request_line Errors.build(400, "request_line")
	@invalid_path Errors.build(400, "path")
	@invalid_method Errors.build(400, "method")
	@invalid_proto Errors.build(400, "protocol")
	@invalid_headers Errors.build(400, "headers")

	@invalid_key Errors.build(400, "key")
	@invalid_host Errors.build(400, "host")
	@invalid_version Errors.build(400, "version")
	@invalid_upgrade Errors.build(400, "upgrade")
	@invalid_connection Errors.build(400, "connection")

	# Timeout (ms) applied to every recv and, via send_timeout, to sends.
	# Shorter outside of :prod.
	if Mix.env == :prod do
		@timeout 5000
	else
		@timeout 100
	end

	# Reads and validates the opening handshake request from `socket`.
	#
	# On success returns `{:ok, path, headers, socket}`, where `headers` is a map
	# containing only the handshake-relevant headers that were present
	# (:host, :connection, :upgrade, :key, :version), each value trimmed.
	#
	# On failure the socket is closed and `:closed` is returned. Failures that
	# map to a canned `{:invalid, response}` (including any failure to read the
	# request line) send that response before closing; other errors (e.g. a
	# socket error while reading headers) are logged.
	def read(socket) do
		@inet.setopts(socket, packet: :line)
		with {:ok, request_line} <- read_request_line(socket),
		     {:ok, path} <- verify_request_line(request_line),
		     {:ok, headers} <- read_headers(socket, %{}),
		     :ok <- validate_headers(headers)
		do
			{:ok, path, headers, socket}
		else
			err -> close(socket, err); :closed
		end
	end

	# Reads the request line; every read failure (timeout, closed socket,
	# unterminated line) is reported as an invalid request line.
	defp read_request_line(socket) do
		case read_line(socket) do
			{:ok, line} -> {:ok, line}
			_ -> @invalid_request_line
		end
	end

	# Validates "GET <path> HTTP/1.1" (method and protocol compared
	# case-insensitively) and returns `{:ok, path}`. Extra spaces before the
	# path and the protocol, and trailing spaces after the protocol, are allowed.
	defp verify_request_line(line) do
		with {:ok, line} <- ensure_method(line),
		     {:ok, path, line} <- extract_path(trim_leading(line)),
		     :ok <- ensure_protocol(trim_leading(line))
		do
			{:ok, path}
		end
	end

	# The method must be exactly three bytes followed by a space, and equal
	# "get" ignoring case. Returns the remainder of the line.
	defp ensure_method(<<method::binary-size(3), " ", line::binary>>) do
		case string_compare(method, "get") do
			true -> {:ok, line}
			false -> @invalid_method
		end
	end

	defp ensure_method(_line) do
		@invalid_method
	end

	# Splits off the path (everything up to the first space) from the rest.
	defp extract_path(line) do
		case :binary.split(line, " ") do
			[path, line] -> {:ok, path, line}
			_ -> @invalid_path
		end
	end

	defp ensure_protocol(line) do
		case string_compare(line, "http/1.1") do
			true -> :ok
			false -> @invalid_proto
		end
	end

	# Reads header lines until the blank line that ends the request, keeping
	# only the headers handled by set_header/3. A line without a ":" or an
	# unterminated line yields the "headers" 400; socket errors pass through.
	defp read_headers(socket, headers) do
		with {:ok, line} when line != "" <- read_line(socket),
		     [name, value] <- :binary.split(line, ":")
		do
			read_headers(socket, set_header(headers, header_downcase(name, []), value))
		else
			{:ok, ""} -> {:ok, headers} # An empty readline means we're done
			{:error, _} = err -> err
			_ -> @invalid_headers
		end
	end

	# One clause per interesting (lowercased) header name, storing the trimmed
	# value under an atom key. A repeated header overwrites the earlier value.
	for {key, name} <- [host: "host", connection: "connection", upgrade: "upgrade", key: "sec-websocket-key", version: "sec-websocket-version"] do
		defp set_header(headers, unquote(name), value) do
			Map.put(headers, unquote(key), String.trim(value))
		end
	end

	defp set_header(headers, _key, _value), do: headers

	# Checks the headers RFC 6455 requires of a client handshake: non-empty key
	# and host, version "13", and "websocket"/"upgrade" tokens in the
	# Upgrade/Connection headers. Returns :ok or the matching canned error.
	defp validate_headers(headers) do
		key = headers[:key]
		host = headers[:host]
		with true <- key not in [nil, ""] || @invalid_key,
		     true <- host not in [nil, ""] || @invalid_host,
		     true <- headers[:version] == "13" || @invalid_version,
		     true <- header_has_value(headers[:upgrade], "websocket") || @invalid_upgrade,
		     true <- header_has_value(headers[:connection], "upgrade") || @invalid_connection
		do
			:ok
		end
	end

	# Sends the 101 response. Sec-WebSocket-Accept is base64(SHA-1(key <> GUID))
	# with the fixed GUID from RFC 6455. Returns the result of the send call.
	def accept(socket, headers) do
		key = headers[:key]
		accept_key = :sha
		|> :crypto.hash([key, "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"])
		|> Base.encode64()

		data = [
			"HTTP/1.1 101 Switching Protocols\r\n",
			"Upgrade: websocket\r\n",
			"Connection: Upgrade\r\n",
			"Sec-WebSocket-Accept: ", accept_key, "\r\n",
			"\r\n"
		]

		@inet.setopts(socket, send_timeout: @timeout)
		@gen_tcp.send(socket, data)
	end

	# Closes the socket, first sending `err`'s response if it is an
	# `{:invalid, response}` tuple; any other error is logged.
	def reject(socket, err), do: close(socket, err)

	# Receives one line (the socket is in `packet: :line` mode) and strips its
	# terminator. CRLF is expected; a bare LF is tolerated. A line with no
	# terminator at all yields `:invalid_line`, which both callers map to a 400.
	# Socket errors are returned unchanged.
	defp read_line(socket) do
		case @gen_tcp.recv(socket, 0, @timeout) do
			{:ok, data} -> strip_line_ending(data)
			err -> err
		end
	end

	# Never assume a fixed 2-byte terminator: a lone "\n" would otherwise make
	# the prefix length negative (MatchError on client input) and an LF-only
	# line would lose its last real byte.
	defp strip_line_ending(data) do
		cond do
			String.ends_with?(data, "\r\n") -> {:ok, binary_part(data, 0, byte_size(data) - 2)}
			String.ends_with?(data, "\n") -> {:ok, binary_part(data, 0, byte_size(data) - 1)}
			true -> :invalid_line
		end
	end

	defp close(socket, {:invalid, message}) do
		@inet.setopts(socket, send_timeout: @timeout)
		@gen_tcp.send(socket, message)
		@gen_tcp.close(socket)
	end

	# Probably an error from :gen_tcp
	defp close(socket, err) do
		Logger.error("handshake: #{inspect(err)}")
		@gen_tcp.close(socket)
	end

	# Drops leading spaces (only " ", not tabs).
	defp trim_leading(<<" ", data::binary>>), do: trim_leading(data)
	defp trim_leading(data), do: data

	# True when the comma-separated header value `actual` contains a token equal
	# to `expected` (which must be lowercase), ignoring ASCII case and the
	# optional whitespace around each token. Splitting first avoids the
	# pitfalls of a byte scanner (e.g. skipping the token after a comma, or
	# giving up on a longer token instead of checking the next one).
	defp header_has_value(nil, _expected), do: false

	defp header_has_value(actual, expected) do
		actual
		|> String.split(",")
		|> Enum.any?(fn token ->
			token
			|> String.trim()
			|> header_downcase([])
			|> Kernel.==(expected)
		end)
	end

	# Compares `input` with the lowercase `target`, ignoring ASCII case in
	# `input` and allowing trailing spaces in `input`. Case folding is limited
	# to ?A..?Z so that unrelated bytes (e.g. 0x0F vs "/") never compare equal.
	defp string_compare(<<i, input::binary>>, <<t, target::binary>>) do
		case i == t or (i in ?A..?Z and i + 32 == t) do
			false -> false
			true -> string_compare(input, target)
		end
	end

	defp string_compare(<<>>, <<>>), do: true
	defp string_compare(<<" ", input::binary>>, <<>>), do: string_compare(input, <<>>)
	defp string_compare(<<_::binary>>, <<>>), do: false
	defp string_compare(<<>>, <<_::binary>>), do: false

	# ASCII-only lowercasing: the HTTP headers (and header tokens) that we care
	# about have very few legal values, all ASCII, so only A-Z is folded and
	# every other byte is kept as-is.
	defp header_downcase(<<>>, acc) do
		acc |> Enum.reverse() |> :erlang.list_to_binary()
	end

	defp header_downcase(<<c, input::binary>>, acc) do
		c = case c >= ?A && c <= ?Z do
			true -> c + 32
			false -> c
		end
		header_downcase(input, [c | acc])
	end
end