Skip to main content

lib/quickbeam/web_socket.ex

defmodule QuickBEAM.WebSocket do
  @moduledoc false
  use GenServer

  alias Mint.HTTP

  # Mint.WebSocket.t() is @opaque — Dialyzer can't prove that new/4 returns
  # {:ok, _, _} so it considers handle_upgrade_success and its callees dead.
  @dialyzer {:nowarn_function,
             handle_response: 2, handle_upgrade_success: 3, response_header: 2, notify_open: 1}

  defstruct [
    :id,
    :owner,
    :owner_ref,
    :url,
    :protocols,
    :conn,
    :request_ref,
    :websocket,
    :upgrade_status,
    :pending_close,
    upgrade_headers: [],
    protocol: "",
    closed?: false,
    close_sent?: false
  ]

  @spec connect(args :: [String.t()], owner :: pid()) :: String.t()
  def connect([url, protocols], owner_pid) do
    id = Integer.to_string(System.unique_integer([:positive]))

    {:ok, pid} =
      GenServer.start_link(__MODULE__, %{
        id: id,
        owner: owner_pid,
        url: url,
        protocols: List.wrap(protocols)
      })

    send(owner_pid, {:websocket_started, id, pid})
    id
  end

  @spec send_frame(args :: [term()], owner :: pid()) :: nil
  def send_frame([id, [kind, payload]], owner_pid) do
    send(owner_pid, {:ws_send, id, kind, payload})
    nil
  end

  @spec close(args :: [term()], owner :: pid()) :: nil
  def close([id, code, reason], owner_pid) do
    send(owner_pid, {:ws_close, id, code, reason})
    nil
  end

  # -- GenServer --

  @impl true
  def init(%{id: id, owner: owner, url: url, protocols: protocols}) do
    owner_ref = Process.monitor(owner)
    send(self(), :connect)

    {:ok,
     %__MODULE__{
       id: id,
       owner: owner,
       owner_ref: owner_ref,
       url: url,
       protocols: protocols
     }}
  end

  @impl true
  def handle_info(:connect, state) do
    case open_connection(state) do
      {:ok, state} -> {:noreply, state}
      {:error, state, reason} -> {:stop, reason, emit_error_and_close(state, reason)}
    end
  end

  def handle_info({:DOWN, ref, :process, _pid, _reason}, %{owner_ref: ref} = state) do
    {:stop, :normal, state}
  end

  def handle_info(_message, %{conn: nil} = state) do
    {:noreply, state}
  end

  def handle_info(message, state) do
    case Mint.WebSocket.stream(state.conn, message) do
      {:ok, conn, responses} ->
        state = %{state | conn: conn}
        handle_responses(state, responses)

      {:error, conn, reason, responses} ->
        state = %{state | conn: conn}

        state =
          case handle_response_list(state, responses) do
            {:ok, state} -> state
            {:stop, state} -> state
          end

        {:stop, reason, emit_error_and_close(state, reason)}

      :unknown ->
        {:noreply, state}
    end
  end

  @impl true
  def handle_cast({:send, _kind, _payload}, %{websocket: nil} = state) do
    {:noreply, state}
  end

  def handle_cast({:send, kind, payload}, state) do
    frame =
      case kind do
        "text" -> {:text, payload}
        "binary" -> {:binary, payload}
      end

    case stream_frame(state, frame) do
      {:ok, state} -> {:noreply, state}
      {:error, state, reason} -> {:stop, reason, emit_error_and_close(state, reason)}
    end
  end

  def handle_cast({:close, _code, _reason}, %{closed?: true} = state) do
    {:noreply, state}
  end

  def handle_cast({:close, code, reason}, %{websocket: nil} = state) do
    {:noreply, %{state | pending_close: {code, reason}}}
  end

  def handle_cast({:close, code, reason}, state) do
    case do_close(state, code, reason) do
      {:ok, state} -> {:noreply, state}
      {:error, state, error} -> {:stop, error, emit_error_and_close(state, error)}
    end
  end

  @impl true
  def terminate(_reason, state) do
    if state.conn do
      try do
        HTTP.close(state.conn)
      catch
        _, _ -> :ok
      end
    end

    :ok
  end

  # -- HTTP upgrade response handling --

  defp handle_responses(state, responses) do
    case handle_response_list(state, responses) do
      {:ok, state} -> {:noreply, state}
      {:stop, state} -> {:stop, :normal, state}
    end
  end

  defp handle_response_list(state, responses) do
    Enum.reduce_while(responses, {:ok, state}, fn response, {:ok, state} ->
      case handle_response(state, response) do
        {:ok, state} -> {:cont, {:ok, state}}
        {:stop, state} -> {:halt, {:stop, state}}
      end
    end)
  end

  defp handle_response(state, {:status, ref, status}) when ref == state.request_ref do
    {:ok, %{state | upgrade_status: status}}
  end

  defp handle_response(state, {:headers, ref, headers}) when ref == state.request_ref do
    {:ok, %{state | upgrade_headers: state.upgrade_headers ++ headers}}
  end

  defp handle_response(state, {:done, ref}) when ref == state.request_ref do
    case Mint.WebSocket.new(state.conn, ref, state.upgrade_status, state.upgrade_headers) do
      {:ok, conn, websocket} ->
        handle_upgrade_success(state, conn, websocket)

      {:error, conn, reason} ->
        {:stop, emit_error_and_close(%{state | conn: conn}, reason)}
    end
  end

  defp handle_response(state, {:data, ref, data})
       when ref == state.request_ref and not is_nil(state.websocket) do
    case Mint.WebSocket.decode(state.websocket, data) do
      {:ok, websocket, frames} ->
        handle_frames(%{state | websocket: websocket}, frames)

      {:error, websocket, reason} ->
        {:stop, emit_error_and_close(%{state | websocket: websocket}, reason)}
    end
  end

  defp handle_response(state, _response), do: {:ok, state}

  defp handle_frames(state, frames) do
    Enum.reduce_while(frames, {:ok, state}, fn frame, {:ok, state} ->
      case handle_frame(state, frame) do
        {:ok, state} -> {:cont, {:ok, state}}
        {:stop, state} -> {:halt, {:stop, state}}
      end
    end)
  end

  defp handle_upgrade_success(state, conn, websocket) do
    state = %{state | conn: conn, websocket: websocket, upgrade_status: nil}

    if state.pending_close do
      {:stop, emit_close(%{state | pending_close: nil, upgrade_headers: []}, 1006, "", false)}
    else
      protocol = response_header(state.upgrade_headers, "sec-websocket-protocol") || ""

      state =
        %{state | protocol: protocol, upgrade_headers: []}
        |> notify_open()

      {:ok, state}
    end
  end

  # -- Frame handling --

  defp handle_frame(state, {:text, text}), do: {:ok, notify_text_message(state, text)}
  defp handle_frame(state, {:binary, data}), do: {:ok, notify_binary_message(state, data)}
  defp handle_frame(state, {:pong, _data}), do: {:ok, state}

  defp handle_frame(state, {:ping, data}) do
    case stream_frame(state, {:pong, data}) do
      {:ok, state} -> {:ok, state}
      {:error, state, error} -> {:stop, emit_error_and_close(state, error)}
    end
  end

  defp handle_frame(state, {:close, code, reason}) do
    state =
      if state.close_sent? do
        state
      else
        case stream_frame(state, :close) do
          {:ok, state} -> state
          {:error, state, _reason} -> state
        end
      end

    {:stop, emit_close(%{state | close_sent?: true}, code || 1005, reason || "", true)}
  end

  # -- Connection setup --

  defp open_connection(state) do
    with {:ok, %{scheme: scheme, host: host, port: port, path: path} = info} <-
           parse_url(state.url),
         {:ok, conn} <- HTTP.connect(http_scheme(scheme), host, port, connect_opts(info)),
         {:ok, conn, ref} <-
           Mint.WebSocket.upgrade(
             websocket_scheme(scheme),
             conn,
             path,
             upgrade_headers(state.protocols)
           ) do
      {:ok, %{state | conn: conn, request_ref: ref}}
    else
      {:error, %Mint.TransportError{} = reason} -> {:error, state, reason}
      {:error, %Mint.HTTPError{} = reason} -> {:error, state, reason}
      {:error, %Mint.WebSocketError{} = reason} -> {:error, state, reason}
      {:error, conn, reason} -> {:error, %{state | conn: conn}, reason}
      {:error, reason} -> {:error, state, reason}
    end
  end

  # -- Close / send --

  defp do_close(state, code, reason) do
    frame = if code == 1000 and reason == "", do: :close, else: {:close, code, reason}

    case stream_frame(state, frame) do
      {:ok, state} -> {:ok, %{state | close_sent?: true}}
      {:error, state, error} -> {:error, state, error}
    end
  end

  defp stream_frame(state, frame) do
    case Mint.WebSocket.encode(state.websocket, frame) do
      {:ok, websocket, data} ->
        case Mint.WebSocket.stream_request_body(state.conn, state.request_ref, data) do
          {:ok, conn} ->
            {:ok, %{state | conn: conn, websocket: websocket}}

          {:error, conn, reason} ->
            {:error, %{state | conn: conn, websocket: websocket}, reason}
        end

      {:error, websocket, reason} ->
        {:error, %{state | websocket: websocket}, reason}
    end
  end

  # -- URL parsing --

  defp parse_url(url) do
    uri = URI.parse(url)

    cond do
      uri.scheme not in ["ws", "wss"] ->
        {:error, ArgumentError.exception("unsupported WebSocket scheme")}

      is_nil(uri.host) or uri.host == "" ->
        {:error, ArgumentError.exception("missing WebSocket host")}

      true ->
        {:ok,
         %{
           scheme: uri.scheme,
           host: uri.host,
           port: uri.port || default_port(uri.scheme),
           path: path_with_query(uri)
         }}
    end
  end

  defp path_with_query(%URI{path: path, query: nil}) when path in [nil, ""], do: "/"
  defp path_with_query(%URI{path: path, query: nil}), do: path
  defp path_with_query(%URI{path: path, query: query}) when path in [nil, ""], do: "/?" <> query
  defp path_with_query(%URI{path: path, query: query}), do: path <> "?" <> query

  defp connect_opts(%{scheme: "ws"}), do: [protocols: [:http1]]

  defp connect_opts(%{scheme: "wss", host: host}) do
    [
      protocols: [:http1],
      transport_opts: [
        verify: :verify_peer,
        cacerts: :public_key.cacerts_get(),
        server_name_indication: String.to_charlist(host),
        customize_hostname_check: [match_fun: :public_key.pkix_verify_hostname_match_fun(:https)]
      ]
    ]
  end

  defp upgrade_headers([]), do: []
  defp upgrade_headers(protocols), do: [{"sec-websocket-protocol", Enum.join(protocols, ", ")}]

  defp http_scheme("ws"), do: :http
  defp http_scheme("wss"), do: :https

  defp websocket_scheme("ws"), do: :ws
  defp websocket_scheme("wss"), do: :wss

  defp default_port("ws"), do: 80
  defp default_port("wss"), do: 443

  # -- Event notifications --

  defp response_header(headers, name) do
    Enum.find_value(headers, fn
      {key, value} when is_binary(key) -> if String.downcase(key) == name, do: value
      _ -> nil
    end)
  end

  defp notify_open(state) do
    send(state.owner, {:websocket_event, ["__ws_open", state.id, state.protocol]})
    state
  end

  defp notify_text_message(state, payload) do
    send(state.owner, {:websocket_event, ["__ws_message", state.id, payload]})
    state
  end

  defp notify_binary_message(state, payload) do
    send(state.owner, {:websocket_event, ["__ws_message", state.id, {:bytes, payload}]})
    state
  end

  defp emit_error_and_close(state, reason) do
    state
    |> notify_error(reason)
    |> emit_close(1006, "", false)
  end

  defp notify_error(state, reason) do
    send(state.owner, {:websocket_event, ["__ws_error", state.id, Exception.message(reason)]})
    state
  end

  defp emit_close(%{closed?: true} = state, _code, _reason, _was_clean), do: state

  defp emit_close(state, code, reason, was_clean) do
    send(state.owner, {:websocket_event, ["__ws_close", state.id, code, reason, was_clean]})
    %{state | closed?: true}
  end
end