Skip to main content

lib/ferric_store/client.ex

defmodule FerricStore.Client do
  @moduledoc """
  Native-protocol client.

  A client process owns one `ferric://` socket. After startup/authentication the
  default mode is request-id multiplexing: callers can issue concurrent
  `GenServer.call/3` requests, the client sends frames immediately, and replies
  are matched by native protocol request id.
  """

  use GenServer

  alias FerricStore.Error
  alias FerricStore.Protocol

  @default_url "ferric://127.0.0.1:6388"
  @default_timeout 5_000

  defstruct [
    :socket,
    :transport,
    :host,
    :port,
    :tls,
    :request_id,
    :timeout,
    multiplex: true,
    pending: %{},
    buffer: <<>>,
    chunks: %{}
  ]

  @type t :: GenServer.server()

  def start_link(opts \\ []) do
    GenServer.start_link(__MODULE__, opts)
  end

  def connect!(opts \\ []) do
    case start_link(opts) do
      {:ok, pid} -> pid
      {:error, reason} -> raise Error, message: "connect failed: #{inspect(reason)}", raw: reason
    end
  end

  def command(client, command, args \\ [], opts \\ []) do
    GenServer.call(client, {:command, command, List.wrap(args), opts}, timeout(opts))
  end

  def pipeline(client, commands, opts \\ []) when is_list(commands) do
    GenServer.call(client, {:pipeline, commands, opts}, timeout(opts))
  end

  def async_pipeline(client, commands, opts \\ []) when is_list(commands) do
    ref = make_ref()
    GenServer.cast(client, {:async_pipeline, self(), ref, commands, opts})
    ref
  end

  def native(client, opcode, payload, opts \\ []) do
    GenServer.call(client, {:native, opcode, payload, opts}, timeout(opts))
  end

  def async_native(client, opcode, payload, opts \\ []) do
    ref = make_ref()
    GenServer.cast(client, {:async_native, self(), ref, opcode, payload, opts})
    ref
  end

  def await(ref, timeout \\ @default_timeout) when is_reference(ref) do
    receive do
      {__MODULE__, ^ref, value} -> value
    after
      timeout -> {:error, %Error{message: "FerricStore async request timed out", raw: :timeout}}
    end
  end

  def yield(ref, timeout \\ 0) when is_reference(ref) do
    receive do
      {__MODULE__, ^ref, value} -> {:ok, value}
    after
      timeout -> nil
    end
  end

  def close(client) do
    GenServer.call(client, :close)
  catch
    :exit, _ -> :ok
  end

  @impl true
  def init(opts) do
    opts = normalize_opts(opts)

    multiplex = Keyword.get(opts, :multiplex, true)

    with {:ok, state} <- connect_socket(opts),
         {:ok, _value, state} <-
           request(state, Protocol.opcode(:startup), startup_payload(opts), 0),
         {:ok, state} <- maybe_auth(state, opts) do
      state
      |> Map.put(:multiplex, multiplex)
      |> activate_socket(multiplex)
    else
      {:error, %Error{} = error, _state} -> {:stop, error}
      {:error, reason, _state} -> {:stop, reason}
      {:error, reason} -> {:stop, reason}
    end
  end

  @impl true
  def handle_call({:command, command, args, _opts}, from, state) do
    payload = Protocol.command_payload(command, args)
    dispatch_call_request(state, from, Protocol.opcode(:command_exec), payload, 1)
  end

  def handle_call({:pipeline, commands, opts}, from, state) do
    payload = Protocol.pipeline_payload(commands, opts)
    dispatch_call_request(state, from, Protocol.opcode(:pipeline), payload, 1)
  end

  def handle_call({:native, opcode, payload, opts}, from, state) do
    lane_id = Keyword.get(opts, :lane_id, 1)
    dispatch_call_request(state, from, opcode, payload, lane_id)
  end

  def handle_call(:close, _from, state) do
    close_socket(state)
    {:stop, :normal, :ok, state}
  end

  @impl true
  def handle_cast({:async_pipeline, caller, ref, commands, opts}, state) do
    payload = Protocol.pipeline_payload(commands, opts)
    dispatch_async_request(state, caller, ref, Protocol.opcode(:pipeline), payload, 1)
  end

  def handle_cast({:async_native, caller, ref, opcode, payload, opts}, state) do
    lane_id = Keyword.get(opts, :lane_id, 1)
    dispatch_async_request(state, caller, ref, opcode, payload, lane_id)
  end

  defp dispatch_call_request(%{multiplex: true} = state, from, opcode, payload, lane_id) do
    request_id = next_request_id(state.request_id)
    frame = Protocol.encode_request(opcode, request_id, payload, lane_id: lane_id)

    case send_data(state, frame) do
      :ok ->
        pending = Map.put(state.pending, request_id, {:call, from, opcode})
        {:noreply, %{state | request_id: request_id, pending: pending}}

      {:error, reason} ->
        {:reply, {:error, to_error(reason)}, state}
    end
  end

  defp dispatch_call_request(state, _from, opcode, payload, lane_id) do
    case request(state, opcode, payload, lane_id) do
      {:ok, value, state} -> {:reply, value, state}
      {:error, error, state} -> {:reply, {:error, error}, state}
    end
  end

  defp dispatch_async_request(%{multiplex: true} = state, caller, ref, opcode, payload, lane_id) do
    request_id = next_request_id(state.request_id)
    frame = Protocol.encode_request(opcode, request_id, payload, lane_id: lane_id)

    case send_data(state, frame) do
      :ok ->
        pending = Map.put(state.pending, request_id, {:message, caller, ref, opcode})
        {:noreply, %{state | request_id: request_id, pending: pending}}

      {:error, reason} ->
        send(caller, {__MODULE__, ref, {:error, to_error(reason)}})
        {:noreply, state}
    end
  end

  defp dispatch_async_request(state, caller, ref, opcode, payload, lane_id) do
    value =
      case request(state, opcode, payload, lane_id) do
        {:ok, value, _state} -> value
        {:error, error, _state} -> {:error, error}
      end

    send(caller, {__MODULE__, ref, value})
    {:noreply, state}
  end

  @impl true
  def handle_info({:tcp, socket, data}, %{transport: :gen_tcp, socket: socket} = state),
    do: handle_socket_data(data, state)

  def handle_info({:ssl, socket, data}, %{transport: :ssl, socket: socket} = state),
    do: handle_socket_data(data, state)

  def handle_info({:tcp_closed, socket}, %{transport: :gen_tcp, socket: socket} = state),
    do: handle_socket_down(:closed, state)

  def handle_info({:ssl_closed, socket}, %{transport: :ssl, socket: socket} = state),
    do: handle_socket_down(:closed, state)

  def handle_info({:tcp_error, socket, reason}, %{transport: :gen_tcp, socket: socket} = state),
    do: handle_socket_down(reason, state)

  def handle_info({:ssl_error, socket, reason}, %{transport: :ssl, socket: socket} = state),
    do: handle_socket_down(reason, state)

  def handle_info(_message, state), do: {:noreply, state}

  @impl true
  def terminate(_reason, state), do: close_socket(state)

  defp normalize_opts(opts) when is_binary(opts), do: normalize_opts(url: opts)

  defp normalize_opts(opts) when is_list(opts) do
    url = Keyword.get(opts, :url, @default_url)
    uri = URI.parse(url)
    scheme = uri.scheme || "ferric"

    unless scheme in ["ferric", "ferrics"] do
      raise ArgumentError, "unsupported FerricStore URL scheme: #{inspect(scheme)}"
    end

    {username, password} = parse_userinfo(uri.userinfo)

    opts
    |> Keyword.put_new(:host, uri.host || "127.0.0.1")
    |> Keyword.put_new(:port, uri.port || default_port(scheme))
    |> Keyword.put_new(:tls, scheme == "ferrics")
    |> Keyword.put_new(:timeout, @default_timeout)
    |> Keyword.put_new(:username, username)
    |> Keyword.put_new(:password, password)
  end

  defp connect_socket(opts) do
    host = Keyword.fetch!(opts, :host)
    port = Keyword.fetch!(opts, :port)
    timeout = Keyword.fetch!(opts, :timeout)
    tls = Keyword.fetch!(opts, :tls)

    if tls do
      :ssl.start()

      case :ssl.connect(String.to_charlist(host), port, ssl_opts(opts), timeout) do
        {:ok, socket} ->
          {:ok,
           %__MODULE__{
             socket: socket,
             transport: :ssl,
             host: host,
             port: port,
             tls: true,
             request_id: 0,
             timeout: timeout
           }}

        {:error, reason} ->
          {:error, reason}
      end
    else
      socket_opts = [:binary, active: false, packet: :raw, nodelay: true]

      case :gen_tcp.connect(String.to_charlist(host), port, socket_opts, timeout) do
        {:ok, socket} ->
          {:ok,
           %__MODULE__{
             socket: socket,
             transport: :gen_tcp,
             host: host,
             port: port,
             tls: false,
             request_id: 0,
             timeout: timeout
           }}

        {:error, reason} ->
          {:error, reason}
      end
    end
  end

  defp request(state, opcode, payload, lane_id) do
    request_id = next_request_id(state.request_id)
    frame = Protocol.encode_request(opcode, request_id, payload, lane_id: lane_id)

    with :ok <- send_data(state, frame),
         {:ok, value} <- receive_response(state, request_id, opcode) do
      {:ok, value, %{state | request_id: request_id}}
    else
      {:error, %Error{} = error} -> {:error, error, state}
      {:error, reason} -> {:error, %Error{message: inspect(reason), raw: reason}, state}
    end
  end

  defp handle_socket_data(data, state) do
    state = %{state | buffer: state.buffer <> data}

    case parse_frames(state) do
      {:ok, state} ->
        case set_active_once(state) do
          :ok -> {:noreply, state}
          {:error, reason} -> handle_socket_down(reason, state)
        end

      {:error, reason, state} ->
        handle_socket_down(reason, state)
    end
  end

  defp parse_frames(%{buffer: buffer} = state) do
    case take_frame(buffer) do
      :incomplete ->
        {:ok, state}

      {:ok, header, body, rest} ->
        state
        |> Map.put(:buffer, rest)
        |> process_frame(header, body)
        |> parse_frames()

      {:error, reason} ->
        {:error, reason, state}
    end
  end

  defp take_frame(buffer) do
    header_size = Protocol.header_size()

    case buffer do
      <<header_binary::binary-size(header_size), rest::binary>> ->
        take_frame_body(header_binary, rest)

      _short ->
        :incomplete
    end
  end

  defp take_frame_body(header_binary, rest) do
    case Protocol.decode_response_header(header_binary) do
      {:ok, header} -> take_frame_body(header, rest, header.body_length)
      {:error, reason} -> {:error, reason}
    end
  end

  defp take_frame_body(_header, rest, body_length) when byte_size(rest) < body_length,
    do: :incomplete

  defp take_frame_body(header, rest, body_length) do
    <<body::binary-size(body_length), remaining::binary>> = rest
    {:ok, header, body, remaining}
  end

  defp process_frame(%{chunks: chunks} = state, %{request_id: 0}, _body),
    do: %{state | chunks: chunks}

  defp process_frame(state, header, body) do
    key = {header.request_id, header.opcode, header.lane_id}

    if Bitwise.band(header.flags, Protocol.flag_more_chunks()) != 0 do
      %{state | chunks: Map.update(state.chunks, key, [body], &[body | &1])}
    else
      {previous_chunks, chunks} = Map.pop(state.chunks, key)

      body =
        case previous_chunks do
          nil -> body
          values -> values |> then(&[body | &1]) |> Enum.reverse() |> IO.iodata_to_binary()
        end

      state
      |> Map.put(:chunks, chunks)
      |> complete_response(header, body)
    end
  end

  defp complete_response(state, header, body) do
    case Map.pop(state.pending, header.request_id) do
      {nil, pending} ->
        %{state | pending: pending}

      {{:call, from, expected_opcode}, pending} ->
        GenServer.reply(from, decode_response_reply(header, expected_opcode, body))
        %{state | pending: pending}

      {{:message, caller, ref, expected_opcode}, pending} ->
        send(caller, {__MODULE__, ref, decode_response_reply(header, expected_opcode, body)})
        %{state | pending: pending}
    end
  end

  defp decode_response_reply(header, expected_opcode, body) do
    case validate_response_header(header, header.request_id, expected_opcode) do
      :ok ->
        case Protocol.decode_response_body(header.flags, header.opcode, body) do
          {:ok, value} ->
            value

          {:error, {status, value}} ->
            {:error, %Error{message: error_message(status, value), status: status, raw: value}}

          {:error, %Error{} = error} ->
            {:error, error}

          {:error, reason} ->
            {:error, to_error(reason)}
        end

      {:error, %Error{} = error} ->
        {:error, error}
    end
  end

  defp handle_socket_down(reason, state) do
    state
    |> fail_pending(reason)
    |> close_socket()

    {:stop, reason, %{state | pending: %{}, chunks: %{}, buffer: <<>>}}
  end

  defp fail_pending(state, reason) do
    error = to_error(reason)

    Enum.each(state.pending, fn
      {_request_id, {:call, from, _opcode}} ->
        GenServer.reply(from, {:error, error})

      {_request_id, {:message, caller, ref, _opcode}} ->
        send(caller, {__MODULE__, ref, {:error, error}})
    end)

    state
  end

  defp to_error(%Error{} = error), do: error
  defp to_error(reason), do: %Error{message: inspect(reason), raw: reason}

  defp receive_response(state, expected_request_id, expected_opcode) do
    with {:ok, header_binary} <- recv_exact(state, Protocol.header_size()),
         {:ok, header} <- Protocol.decode_response_header(header_binary),
         {:ok, body} <- recv_response_body(state, header) do
      handle_response_frame(state, header, body, expected_request_id, expected_opcode)
    else
      {:error, {status, value}} ->
        {:error, %Error{message: error_message(status, value), status: status, raw: value}}

      {:error, reason} ->
        {:error, reason}
    end
  end

  defp handle_response_frame(
         state,
         %{request_id: 0},
         _body,
         expected_request_id,
         expected_opcode
       ),
       do: receive_response(state, expected_request_id, expected_opcode)

  defp handle_response_frame(_state, header, body, expected_request_id, expected_opcode) do
    with :ok <- validate_response_header(header, expected_request_id, expected_opcode) do
      decode_response_result(header, body)
    end
  end

  defp decode_response_result(header, body) do
    case Protocol.decode_response_body(header.flags, header.opcode, body) do
      {:ok, value} ->
        {:ok, value}

      {:error, {status, value}} ->
        {:error, %Error{message: error_message(status, value), status: status, raw: value}}

      {:error, reason} ->
        {:error, reason}
    end
  end

  defp recv_response_body(state, header) do
    with {:ok, first_body} <- recv_exact(state, header.body_length) do
      if Bitwise.band(header.flags, Protocol.flag_more_chunks()) == 0 do
        {:ok, first_body}
      else
        recv_chunks(state, header, [first_body])
      end
    end
  end

  defp recv_chunks(state, previous_header, chunks) do
    with {:ok, header_binary} <- recv_exact(state, Protocol.header_size()),
         {:ok, header} <- Protocol.decode_response_header(header_binary),
         :ok <- validate_chunk_header(header, previous_header),
         {:ok, body} <- recv_exact(state, header.body_length) do
      chunks = [body | chunks]

      if Bitwise.band(header.flags, Protocol.flag_more_chunks()) == 0 do
        {:ok, chunks |> Enum.reverse() |> IO.iodata_to_binary()}
      else
        recv_chunks(state, header, chunks)
      end
    end
  end

  defp send_data(%{transport: :gen_tcp, socket: socket}, data), do: :gen_tcp.send(socket, data)
  defp send_data(%{transport: :ssl, socket: socket}, data), do: :ssl.send(socket, data)

  defp recv_exact(_state, 0), do: {:ok, <<>>}
  defp recv_exact(state, size), do: recv_exact(state, size, [])
  defp recv_exact(_state, 0, acc), do: {:ok, acc |> Enum.reverse() |> IO.iodata_to_binary()}

  defp recv_exact(state, size, acc) do
    case recv(state, size) do
      {:ok, data} when byte_size(data) == size -> recv_exact(state, 0, [data | acc])
      {:ok, data} -> recv_exact(state, size - byte_size(data), [data | acc])
      {:error, reason} -> {:error, reason}
    end
  end

  defp recv(%{transport: :gen_tcp, socket: socket, timeout: timeout}, size),
    do: :gen_tcp.recv(socket, size, timeout)

  defp recv(%{transport: :ssl, socket: socket, timeout: timeout}, size),
    do: :ssl.recv(socket, size, timeout)

  defp set_active_once(%{transport: :gen_tcp, socket: socket}),
    do: :inet.setopts(socket, active: :once)

  defp set_active_once(%{transport: :ssl, socket: socket}),
    do: :ssl.setopts(socket, active: :once)

  defp validate_response_header(%{request_id: request_id, opcode: opcode}, request_id, opcode),
    do: :ok

  defp validate_response_header(header, request_id, opcode) do
    {:error,
     %Error{
       message:
         "protocol response mismatch: expected request #{request_id}/opcode #{opcode}, got #{header.request_id}/#{header.opcode}",
       raw: header
     }}
  end

  defp validate_chunk_header(header, previous) do
    if header.request_id == previous.request_id and header.opcode == previous.opcode and
         header.lane_id == previous.lane_id do
      :ok
    else
      {:error, %Error{message: "invalid protocol chunk continuation", raw: header}}
    end
  end

  defp maybe_auth(state, opts) do
    password = Keyword.get(opts, :password)

    if password in [nil, ""] do
      {:ok, state}
    else
      payload = %{"username" => Keyword.get(opts, :username), "password" => password}

      case request(state, Protocol.opcode(:auth), payload, 0) do
        {:ok, _value, state} -> {:ok, state}
        {:error, error, _state} -> {:error, error}
      end
    end
  end

  defp activate_socket(state, true) do
    case set_active_once(state) do
      :ok -> {:ok, state}
      {:error, reason} -> {:stop, reason}
    end
  end

  defp activate_socket(state, false), do: {:ok, state}

  defp startup_payload(opts) do
    payload = %{"compression" => "none", "compact_flow_responses" => true}

    case Keyword.get(opts, :client_name) do
      nil -> payload
      name -> Map.merge(payload, %{"client_name" => name, "driver_name" => name})
    end
  end

  defp next_request_id(0xFFFF_FFFF_FFFF_FFFF), do: 1
  defp next_request_id(value), do: value + 1

  defp timeout(opts), do: Keyword.get(opts, :timeout, @default_timeout)

  defp parse_userinfo(nil), do: {nil, nil}

  defp parse_userinfo(userinfo) do
    case String.split(userinfo, ":", parts: 2) do
      [username, password] -> {URI.decode(username), URI.decode(password)}
      [username] -> {URI.decode(username), nil}
    end
  end

  defp default_port("ferric"), do: 6388
  defp default_port("ferrics"), do: 6389

  defp ssl_opts(opts) do
    base = [:binary, active: false, packet: :raw]

    if Keyword.get(opts, :verify, true) do
      [{:verify, :verify_peer} | base]
    else
      [{:verify, :verify_none} | base]
    end
  end

  defp close_socket(%{transport: :gen_tcp, socket: socket}) when not is_nil(socket),
    do: :gen_tcp.close(socket)

  defp close_socket(%{transport: :ssl, socket: socket}) when not is_nil(socket),
    do: :ssl.close(socket)

  defp close_socket(_state), do: :ok

  defp error_message(status, value) when is_binary(value),
    do: "FerricStore error #{status}: #{value}"

  defp error_message(status, value), do: "FerricStore error #{status}: #{inspect(value)}"
end