lib/nsq/connection/initializer.ex

defmodule NSQ.Connection.Initializer do
  @moduledoc """
  Opens and initializes the TCP connection to a single nsqd.

  `connect/1` dials the socket, runs the V2 handshake (magic, IDENTIFY with any
  negotiated TLS/compression/AUTH, and SUB when a channel is set) and then
  spawns the linked process that receives messages.
  """

  alias NSQ.Connection, as: C
  alias NSQ.Connection.MessageHandling
  alias NSQ.Connection.Buffer
  alias NSQ.ConnInfo
  import NSQ.Protocol

  @socket_opts [as: :binary, mode: :passive, packet: :raw]

  @project ElixirNsq.Mixfile.project()
  @user_agent "#{@project[:app]}/#{@project[:version]}"

  # Listed oldest to newest; the index is what `tls_min_version` is compared
  # against in `ssl_versions/1`.
  @ssl_versions [:sslv3, :tlsv1, :"tlsv1.1", :"tlsv1.2"] |> Enum.with_index()

  @doc """
  Connects to `state.nsqd` and performs the handshake.

  On success returns `{:ok, state}` with the socket stored, `connected: true`
  and `connect_attempts` reset to 0.

  When the TCP dial fails and something can retry it (nsqlookupds are
  configured, or `max_reconnect_attempts > 0`), returns
  `{{:error, reason}, state}` with `connect_attempts` incremented.

  Exits the calling process with `:connect_failed` when `connect_attempts`
  already exceeds `max_reconnect_attempts`, or when the dial fails and
  retrying is turned off.
  """
  @spec connect(%{nsqd: C.host_with_port()}) ::
          {:ok, C.state()} | {{:error, term}, C.state()}
  def connect(%{nsqd: {host, port}} = state) do
    if should_connect?(state) do
      case Socket.TCP.connect(host, port, tcp_socket_opts(state.config)) do
        {:ok, socket} ->
          # NOTE(review): the writer buffer is also given read_timeout;
          # confirm write_timeout was not intended here.
          Buffer.setup_socket(state.reader, socket, state.config.read_timeout)
          Buffer.setup_socket(state.writer, socket, state.config.read_timeout)

          state =
            %{state | socket: socket}
            |> do_handshake!()
            |> start_receiving_messages!()
            |> reset_connects()

          {:ok, %{state | connected: true}}

        {:error, reason} ->
          handle_connect_failure(state, reason)
      end
    else
      NSQ.Logger.error("#{inspect(self())}: Failed to connect; terminating connection")
      # exit/1 always terminates; Process.exit(self(), _) merely returns `true`
      # when the process traps exits, which would make connect/1 return `true`.
      exit(:connect_failed)
    end
  end

  # TCP options for the initial dial: the static @socket_opts plus the
  # configured send (write) timeout and dial timeout.
  defp tcp_socket_opts(config) do
    Keyword.merge(@socket_opts,
      send: [{:timeout, config.write_timeout}],
      timeout: config.dial_timeout
    )
  end

  # A failed dial is handed back to the caller (with the attempt counted) when
  # a retry is possible; otherwise the connection process exits.
  defp handle_connect_failure(state, reason) do
    if retry_possible?(state.config) do
      # inspect/1 so reasons that are not String.Chars (e.g. tuples) can be logged.
      NSQ.Logger.warn(
        "(#{inspect(self())}) connect failed; #{inspect(reason)}; discovery loop should respawn"
      )

      {{:error, reason}, %{state | connect_attempts: state.connect_attempts + 1}}
    else
      NSQ.Logger.error(
        "(#{inspect(self())}) connect failed; #{inspect(reason)}; reconnect turned off; terminating connection"
      )

      exit(:connect_failed)
    end
  end

  # True when lookupd discovery is configured or reconnects are allowed.
  defp retry_possible?(config) do
    match?([_ | _], config.nsqlookupds) or config.max_reconnect_attempts > 0
  end

  @doc """
  Runs the V2 handshake on an already-connected socket.

  Immediately after connecting to the NSQ socket, both consumers and producers
  follow this protocol: send the protocol magic, IDENTIFY (applying any
  negotiated TLS, compression and AUTH), then SUB when the connection has a
  channel. Producers have no channel, so they skip SUB.
  """
  @spec do_handshake(C.state()) :: {:ok, C.state()}
  def do_handshake(conn_state) do
    send_magic_v2(conn_state)
    {:ok, conn_state} = identify(conn_state)

    if conn_state.channel, do: subscribe(conn_state)

    {:ok, conn_state}
  end

  @doc """
  Same as `do_handshake/1` but returns the updated state directly.
  """
  @spec do_handshake!(C.state()) :: C.state()
  def do_handshake!(conn_state) do
    {:ok, conn_state} = do_handshake(conn_state)
    conn_state
  end

  @spec send_magic_v2(C.state()) :: :ok
  defp send_magic_v2(conn_state) do
    NSQ.Logger.debug("(#{inspect(self())}) sending magic v2...")
    Buffer.send!(conn_state, encode(:magic_v2))
  end

  # Sends IDENTIFY, reads nsqd's JSON response and applies what it negotiated.
  @spec identify(C.state()) :: {:ok, C.state()}
  defp identify(conn_state) do
    NSQ.Logger.debug("(#{inspect(self())}) identifying...")
    Buffer.send!(conn_state, encode({:identify, identify_props(conn_state)}))
    {:response, json} = recv_nsq_response(conn_state)
    {:ok, _conn_state} = update_from_identify_response(conn_state, json)
  end

  # The IDENTIFY body, built from the connection config. Snappy and sampling
  # are always requested off; user_agent falls back to "<app>/<version>".
  @spec identify_props(C.state()) :: map
  defp identify_props(%{nsqd: {host, port}, config: config} = conn_state) do
    %{
      client_id: "#{host}:#{port} (#{inspect(conn_state.parent)})",
      hostname: to_string(:net_adm.localhost()),
      feature_negotiation: true,
      heartbeat_interval: config.heartbeat_interval,
      output_buffer: config.output_buffer_size,
      output_buffer_timeout: config.output_buffer_timeout,
      tls_v1: config.tls_v1,
      snappy: false,
      deflate: config.deflate,
      deflate_level: config.deflate_level,
      sample_rate: 0,
      user_agent: config.user_agent || @user_agent,
      msg_timeout: config.msg_timeout
    }
  end

  @doc """
  Inflates one chunk of raw-deflate `data` (window bits -15) with a
  short-lived zlib stream and returns the result of `:zlib.inflateChunk/2`.

  The zlib stream is always closed, even if `:zlib.inflateEnd/1` raises
  because the input did not contain the end of the stream.
  """
  @spec inflate(iodata) :: iodata | {:more, iodata}
  def inflate(data) do
    z = :zlib.open()

    try do
      :ok = :zlib.inflateInit(z, -15)
      inflated = :zlib.inflateChunk(z, data)
      NSQ.Logger.debug("inflated chunk?")
      NSQ.Logger.debug(inspect(inflated))
      :ok = :zlib.inflateEnd(z)
      inflated
    after
      :zlib.close(z)
    end
  end

  # Applies the IDENTIFY response, in order: max_rdy_count, msg_timeout, TLS
  # upgrade, compression, AUTH.
  @spec update_from_identify_response(C.state(), binary) :: {:ok, C.state()}
  defp update_from_identify_response(conn_state, json) do
    {:ok, parsed} = Jason.decode(json)

    # respect negotiated max_rdy_count
    if parsed["max_rdy_count"] do
      ConnInfo.update(conn_state, %{max_rdy: parsed["max_rdy_count"]})
    end

    # respect negotiated msg_timeout, falling back to the configured value
    timeout = parsed["msg_timeout"] || conn_state.config.msg_timeout

    conn_state =
      %{conn_state | msg_timeout: timeout}
      |> maybe_upgrade_tls(parsed)
      |> negotiate_compression(parsed)
      |> maybe_authenticate(parsed)

    {:ok, conn_state}
  end

  # Wraps the socket with SSL when nsqd agreed to TLS, points both buffers at
  # the SSL socket, then waits for nsqd's OK.
  defp maybe_upgrade_tls(conn_state, %{"tls_v1" => true}) do
    NSQ.Logger.debug("Upgrading to TLS...")
    config = conn_state.config

    # NOTE(review): tls_cert is passed as `cacertfile` and tls_key as `keyfile`
    # with no `certfile`; confirm this matches the intent of NSQ.Config.
    socket =
      Socket.SSL.connect!(conn_state.socket,
        cacertfile: config.tls_cert,
        keyfile: config.tls_key,
        versions: ssl_versions(config.tls_min_version),
        verify: ssl_verify_atom(config)
      )

    conn_state = %{conn_state | socket: socket}
    Buffer.setup_socket(conn_state.reader, socket, config.read_timeout)
    Buffer.setup_socket(conn_state.writer, socket, config.read_timeout)
    wait_for_ok!(conn_state)
    conn_state
  end

  defp maybe_upgrade_tls(conn_state, _parsed), do: conn_state

  # Configures both buffers for the negotiated compression. If compression is
  # enabled, we expect to receive a compressed "OK" immediately, so it is read
  # only after the buffers are set up.
  defp negotiate_compression(conn_state, parsed) do
    Buffer.setup_compression(conn_state.reader, parsed, conn_state.config)
    Buffer.setup_compression(conn_state.writer, parsed, conn_state.config)

    if parsed["deflate"] == true or parsed["snappy"] == true do
      wait_for_ok!(conn_state)
    end

    conn_state
  end

  # Sends AUTH with the configured secret when nsqd requires it; the JSON
  # response is only logged at debug level.
  defp maybe_authenticate(conn_state, %{"auth_required" => true}) do
    NSQ.Logger.debug("sending AUTH")
    Buffer.send!(conn_state, encode({:auth, conn_state.config.auth_secret}))
    {:response, auth_json} = recv_nsq_response(conn_state)
    NSQ.Logger.debug(auth_json)
    conn_state
  end

  defp maybe_authenticate(conn_state, _parsed), do: conn_state

  # Peer verification is on unless tls_insecure_skip_verify is exactly true.
  defp ssl_verify_atom(config) do
    if config.tls_insecure_skip_verify == true do
      :verify_none
    else
      :verify_peer
    end
  end

  # Sends SUB for the topic/channel and blocks until nsqd acknowledges with OK.
  @spec subscribe(C.state()) :: binary
  defp subscribe(%{topic: topic, channel: channel} = conn_state) do
    NSQ.Logger.debug("(#{inspect(self())}) subscribe to #{topic} #{channel}")
    Buffer.send!(conn_state, encode({:sub, topic, channel}))

    NSQ.Logger.debug("(#{inspect(self())}) wait for subscription acknowledgment")
    wait_for_ok!(conn_state)
  end

  # Reads one size-prefixed frame (4-byte big-endian length, then the body) and
  # asserts that it decodes as a response frame.
  @spec recv_nsq_response(C.state()) :: {:response, binary}
  defp recv_nsq_response(conn_state) do
    <<msg_size::size(32)>> = Buffer.recv!(conn_state, 4)
    raw_msg_data = Buffer.recv!(conn_state, msg_size)
    {:response, _response} = decode(raw_msg_data)
  end

  # Reads exactly the bytes of the OK frame and crashes (MatchError) if they
  # differ; returns the OK frame.
  @spec wait_for_ok!(C.state()) :: binary
  defp wait_for_ok!(state) do
    expected = ok_msg()
    ^expected = Buffer.recv!(state, byte_size(expected))
  end

  @doc """
  Returns the SSL/TLS versions to offer, newest first.

  With no minimum (`nil` or `false`), every version in
  `[:sslv3, :tlsv1, :"tlsv1.1", :"tlsv1.2"]` is returned. Otherwise only
  `tls_min_version` and newer versions are returned.

  Raises `ArgumentError` when `tls_min_version` is not in that list.
  """
  @spec ssl_versions(atom | nil) :: [atom]
  def ssl_versions(tls_min_version) do
    min_index = ssl_version_index(tls_min_version)
    for {version, index} <- Enum.reverse(@ssl_versions), index >= min_index, do: version
  end

  # Position of the minimum version in @ssl_versions; 0 when no minimum is set.
  defp ssl_version_index(nil), do: 0
  defp ssl_version_index(false), do: 0

  defp ssl_version_index(version) do
    case List.keyfind(@ssl_versions, version, 0) do
      {_, index} ->
        index

      nil ->
        supported = Enum.map_join(@ssl_versions, ", ", fn {v, _} -> inspect(v) end)

        raise ArgumentError,
              "unsupported tls_min_version #{inspect(version)}; expected one of: #{supported}"
    end
  end

  # The first attempt is always allowed; after that, up to
  # max_reconnect_attempts further attempts.
  @spec should_connect?(C.state()) :: boolean
  defp should_connect?(state) do
    state.connect_attempts == 0 ||
      state.connect_attempts <= state.config.max_reconnect_attempts
  end

  # Spawns (and links) the message-reading process, records its pid, and casts
  # :flush_cmd_queue to this process (presumably to send commands queued before
  # the connection was ready; confirm in NSQ.Connection).
  @spec start_receiving_messages(C.state()) :: {:ok, C.state()}
  defp start_receiving_messages(state) do
    reader_pid = spawn_link(MessageHandling, :recv_nsq_messages, [state, self()])
    state = %{state | reader_pid: reader_pid}
    GenServer.cast(self(), :flush_cmd_queue)
    {:ok, state}
  end

  defp start_receiving_messages!(state) do
    {:ok, state} = start_receiving_messages(state)
    state
  end

  @spec reset_connects(C.state()) :: C.state()
  defp reset_connects(state), do: %{state | connect_attempts: 0}
end