defmodule Bandit.HTTP2.Stream do
  @moduledoc false
  # Carries out state management transitions per RFC9113§5.1. Anything having to do
  # with the internal state of a stream is handled in this module. Note that sending
  # of frames on behalf of a stream is a bit of a split responsibility: the stream
  # itself may update state depending on the value of the end_stream flag (this is
  # a stream concern and thus handled here), but the sending of the data over the
  # wire is a connection concern as it must be serialized properly & is subject to
  # flow control at a connection level

  require Integer
  require Logger

  alias Bandit.HTTP2.{Connection, Errors, FlowControl, StreamTask}

  defstruct stream_id: nil,
            state: nil,
            pid: nil,
            recv_window_size: nil,
            send_window_size: nil,
            pending_content_length: nil,
            span: nil

  # Exception carrying a stream-level error. When a stream process exits with one of these,
  # stream_terminated/2 logs it, stops the telemetry span and reports PROTOCOL_ERROR
  defmodule StreamError, do: defexception([:message])

  # credo:disable-for-this-file Credo.Check.Design.AliasUsage

  @typedoc "An HTTP/2 stream identifier"
  @type stream_id :: non_neg_integer()

  @typedoc "An HTTP/2 stream state"
  @type state :: :reserved_local | :idle | :open | :local_closed | :remote_closed | :closed

  @typedoc "A description of a stream error"
  @type error :: {:stream, stream_id(), Errors.error_code(), String.t()}

  @typedoc "A single HTTP/2 stream"
  @type t :: %__MODULE__{
          stream_id: stream_id(),
          state: state(),
          pid: pid() | nil,
          recv_window_size: non_neg_integer(),
          send_window_size: non_neg_integer(),
          # Bytes still expected according to the request's content-length header, or nil
          # when no length was declared. recv_data/2 decrements this unconditionally, so it
          # goes negative when the peer sends more DATA than it declared
          pending_content_length: integer() | nil,
          # Telemetry span for the request; nil until the stream has been opened
          span: Bandit.Telemetry.t() | nil
        }

  # Handles a HEADERS frame received on this stream.
  #
  # * On an :open or :local_closed stream with end_stream set, the headers are trailers.
  #   They must not contain pseudo headers (otherwise a stream PROTOCOL_ERROR is returned);
  #   valid trailers are logged and dropped.
  # * On an :idle stream the headers open the request: the stream id must be client-initiated
  #   (odd) and any content-length must be valid. Only then is a telemetry span started and a
  #   StreamTask spawned to run the plug, moving the stream to :open.
  # * In any other situation a connection-level PROTOCOL_ERROR is returned.
  @spec recv_headers(
          t(),
          Bandit.Pipeline.transport_info(),
          Plug.Conn.headers(),
          boolean,
          Bandit.plug(),
          keyword()
        ) :: {:ok, t()} | {:error, Connection.error()} | {:error, error()}
  def recv_headers(
        %__MODULE__{state: state} = stream,
        _transport_info,
        trailers,
        true,
        _plug,
        _opts
      )
      when state in [:open, :local_closed] do
    with :ok <- no_pseudo_headers(trailers, stream.stream_id) do
      # These are actually trailers, which Plug doesn't support. Log and ignore
      Logger.warning("Ignoring trailers on stream #{stream.stream_id}: #{inspect(trailers)}")
      {:ok, stream}
    end
  end

  def recv_headers(
        %__MODULE__{state: :idle} = stream,
        transport_info,
        headers,
        _end_stream,
        plug,
        opts
      ) do
    # All fallible validation happens before the telemetry span is started. On an error
    # return the span would be dropped along with the un-updated stream and never stopped,
    # leaving a start event without a matching stop.
    with :ok <- stream_id_is_valid_client(stream.stream_id),
         {:ok, content_length} <- get_content_length(headers, stream.stream_id),
         {_, _, peer, connection_span} = transport_info,
         span = start_span(connection_span, stream.stream_id),
         content_encoding = negotiate_content_encoding(headers, opts),
         req = Bandit.HTTP2.Adapter.init(self(), peer, stream.stream_id, content_encoding, opts),
         {:ok, pid} <- StreamTask.start_link(req, transport_info, headers, plug, span) do
      {:ok,
       %{stream | state: :open, pid: pid, pending_content_length: content_length, span: span}}
    end
  end

  def recv_headers(%__MODULE__{}, _transport_info, _headers, _end_stream, _plug, _opts) do
    {:error, {:connection, Errors.protocol_error(), "Received HEADERS in unexpected state"}}
  end

  # RFC9113§5.1.1 - client initiated streams must be odd
  defp stream_id_is_valid_client(stream_id) do
    if Integer.is_odd(stream_id) do
      :ok
    else
      {:error, {:connection, Errors.protocol_error(), "Received HEADERS with even stream_id"}}
    end
  end

  # Starts the per-request telemetry span, linked to the parent connection's span context
  defp start_span(connection_span, stream_id) do
    Bandit.Telemetry.start_span(:request, %{}, %{
      connection_telemetry_span_context: connection_span.telemetry_span_context,
      stream_id: stream_id
    })
  end

  # RFC9113§8.1.1 - content length must be valid
  defp get_content_length(headers, stream_id) do
    case Bandit.Headers.get_content_length(headers) do
      {:ok, content_length} -> {:ok, content_length}
      {:error, reason} -> {:error, {:stream, stream_id, Errors.protocol_error(), reason}}
    end
  end

  # Picks a response content encoding from the client's accept-encoding header; the
  # :compress option (default true) is passed through to Bandit.Compression
  defp negotiate_content_encoding(headers, opts) do
    Bandit.Compression.negotiate_content_encoding(
      Bandit.Headers.get_header(headers, "accept-encoding"),
      Keyword.get(opts, :compress, true)
    )
  end

  # RFC9113§8.1 - no pseudo headers
  defp no_pseudo_headers(headers, stream_id) do
    if Enum.any?(headers, fn {key, _value} -> String.starts_with?(key, ":") end) do
      {:error,
       {:stream, stream_id, Errors.protocol_error(), "Received trailers with pseudo headers"}}
    else
      :ok
    end
  end

  # Handles a DATA payload received on this stream. The data is forwarded to the stream
  # process, the receive window is debited via FlowControl.compute_recv_window/2 and any
  # declared content-length is decremented.
  #
  # Returns the updated stream plus the window increment computed by FlowControl
  # (presumably for the caller to advertise in a WINDOW_UPDATE — confirm in Connection).
  # DATA in any state other than :open or :local_closed is a connection PROTOCOL_ERROR.
  @spec recv_data(t(), binary()) :: {:ok, t(), non_neg_integer()} | {:error, Connection.error()}
  def recv_data(%__MODULE__{state: state} = stream, data) when state in [:open, :local_closed] do
    StreamTask.recv_data(stream.pid, data)

    {new_window, increment} =
      FlowControl.compute_recv_window(stream.recv_window_size, byte_size(data))

    pending_content_length =
      case stream.pending_content_length do
        nil -> nil
        pending_content_length -> pending_content_length - byte_size(data)
      end

    {:ok,
     %{stream | recv_window_size: new_window, pending_content_length: pending_content_length},
     increment}
  end

  def recv_data(%__MODULE__{} = stream, _data) do
    {:error, {:connection, Errors.protocol_error(), "Received DATA when in #{stream.state}"}}
  end

  # Credits the stream's send window by `increment`. An idle stream is a connection
  # PROTOCOL_ERROR; an increment that FlowControl rejects is a stream FLOW_CONTROL_ERROR.
  @spec recv_window_update(t(), non_neg_integer()) ::
          {:ok, t()} | {:error, Connection.error()} | {:error, error()}
  def recv_window_update(%__MODULE__{state: :idle}, _increment) do
    {:error, {:connection, Errors.protocol_error(), "Received WINDOW_UPDATE when in idle"}}
  end

  def recv_window_update(%__MODULE__{} = stream, increment) do
    case FlowControl.update_send_window(stream.send_window_size, increment) do
      {:ok, new_window} ->
        {:ok, %{stream | send_window_size: new_window}}

      {:error, error} ->
        {:error, {:stream, stream.stream_id, Errors.flow_control_error(), error}}
    end
  end

  # Handles RST_STREAM from the peer: notifies the stream process (if one is running) and
  # closes the stream. RST_STREAM on an idle stream is a connection PROTOCOL_ERROR.
  @spec recv_rst_stream(t(), Errors.error_code()) ::
          {:ok, t()} | {:error, Connection.error()}
  def recv_rst_stream(%__MODULE__{state: :idle}, _error_code) do
    {:error, {:connection, Errors.protocol_error(), "Received RST_STREAM when in idle"}}
  end

  def recv_rst_stream(%__MODULE__{} = stream, error_code) do
    if is_pid(stream.pid), do: StreamTask.recv_rst_stream(stream.pid, error_code)
    {:ok, %{stream | state: :closed, pid: nil}}
  end

  # Applies the end_stream flag of a received frame. When set, the declared content-length
  # (if any) must have been exactly satisfied; the stream process is then notified and the
  # stream moves :open -> :remote_closed or :local_closed -> :closed. An end_stream in any
  # other state is a connection PROTOCOL_ERROR. When the flag is clear, nothing changes.
  @spec recv_end_of_stream(t(), boolean()) ::
          {:ok, t()} | {:error, Connection.error()} | {:error, error()}
  def recv_end_of_stream(%__MODULE__{state: :open} = stream, true) do
    with :ok <- verify_content_length(stream) do
      StreamTask.recv_end_of_stream(stream.pid)
      {:ok, %{stream | state: :remote_closed}}
    end
  end

  def recv_end_of_stream(%__MODULE__{state: :local_closed} = stream, true) do
    with :ok <- verify_content_length(stream) do
      StreamTask.recv_end_of_stream(stream.pid)
      {:ok, %{stream | state: :closed, pid: nil}}
    end
  end

  def recv_end_of_stream(%__MODULE__{}, true) do
    {:error, {:connection, Errors.protocol_error(), "Received unexpected end_stream"}}
  end

  def recv_end_of_stream(%__MODULE__{} = stream, false) do
    {:ok, stream}
  end

  # RFC9113§8.1.1 - the sum of DATA payloads must equal the declared content-length. A
  # positive remainder means the body was short; a negative one means the peer sent more
  # bytes than it declared.
  defp verify_content_length(%__MODULE__{pending_content_length: nil}), do: :ok
  defp verify_content_length(%__MODULE__{pending_content_length: 0}), do: :ok

  defp verify_content_length(%__MODULE__{pending_content_length: pending} = stream)
       when pending < 0 do
    {:error,
     {:stream, stream.stream_id, Errors.protocol_error(),
      "Received end of stream with #{-pending} byte(s) in excess of content-length"}}
  end

  defp verify_content_length(%__MODULE__{} = stream) do
    {:error,
     {:stream, stream.stream_id, Errors.protocol_error(),
      "Received end of stream with #{stream.pending_content_length} byte(s) pending"}}
  end

  # Returns the number of bytes this stream may currently send
  @spec get_send_window_size(t()) :: non_neg_integer()
  def get_send_window_size(%__MODULE__{} = stream), do: stream.send_window_size

  # Checks that HEADERS may be sent locally; only :open and :remote_closed streams allow it.
  # The stream is returned unchanged — actually writing the frame is the caller's job.
  @spec send_headers(t()) :: {:ok, t()} | {:error, :invalid_state}
  def send_headers(%__MODULE__{state: state} = stream) when state in [:open, :remote_closed] do
    {:ok, stream}
  end

  def send_headers(%__MODULE__{}) do
    {:error, :invalid_state}
  end

  # Reserves `len` bytes of the send window for outgoing DATA. Zero-length sends always
  # succeed in a sendable state; larger sends fail with :insufficient_window_size when the
  # window is too small. Only :open and :remote_closed streams may send.
  @spec send_data(t(), non_neg_integer()) ::
          {:ok, t()} | {:error, :insufficient_window_size} | {:error, :invalid_state}
  def send_data(%__MODULE__{state: state} = stream, 0) when state in [:open, :remote_closed] do
    {:ok, stream}
  end

  def send_data(%__MODULE__{state: state} = stream, len) when state in [:open, :remote_closed] do
    if len <= stream.send_window_size do
      {:ok, %{stream | send_window_size: stream.send_window_size - len}}
    else
      {:error, :insufficient_window_size}
    end
  end

  def send_data(%__MODULE__{}, _len) do
    {:error, :invalid_state}
  end

  # Applies the end_stream flag of a locally sent frame: :open -> :local_closed,
  # :remote_closed -> :closed; any other state is :invalid_state. A clear flag is a no-op.
  @spec send_end_of_stream(t(), boolean()) :: {:ok, t()} | {:error, :invalid_state}
  def send_end_of_stream(%__MODULE__{state: :open} = stream, true) do
    {:ok, %{stream | state: :local_closed}}
  end

  def send_end_of_stream(%__MODULE__{state: :remote_closed} = stream, true) do
    {:ok, %{stream | state: :closed, pid: nil}}
  end

  def send_end_of_stream(%__MODULE__{}, true) do
    {:error, :invalid_state}
  end

  def send_end_of_stream(%__MODULE__{} = stream, false) do
    {:ok, stream}
  end

  # Asks the stream's process (if any) to exit with `reason`; always returns :ok
  @spec terminate_stream(t(), term()) :: :ok
  def terminate_stream(%__MODULE__{pid: pid}, reason) when is_pid(pid) do
    # Just kill the process; we will receive a call to stream_terminated once the process actually
    # dies, at which point we will transition the struct to the expected final state
    Process.exit(pid, reason)
    :ok
  end

  def terminate_stream(%__MODULE__{}, _reason) do
    :ok
  end

  # Translates the exit `reason` of the stream process into the stream's final (closed)
  # state plus an error code for the caller, or nil when none applies:
  #
  # * :normal on a :closed stream -> nil (the stream process emitted its own stop telemetry)
  # * {:bandit, reason}           -> nil, after stopping the span and logging a warning
  # * {%StreamError{}, _}         -> PROTOCOL_ERROR
  # * :normal in any other state  -> NO_ERROR (the request finished before the stream closed)
  # * anything else               -> INTERNAL_ERROR, logged as a crash
  @spec stream_terminated(t(), term()) :: {:ok, t(), Errors.error_code() | nil}
  def stream_terminated(%__MODULE__{state: :closed} = stream, :normal) do
    # In the normal case, stop telemetry is emitted by the stream process to keep the main
    # connection process unblocked. In error cases we send from here, however, since there are
    # many error cases which never involve the stream process at all
    {:ok, %{stream | state: :closed, pid: nil}, nil}
  end

  def stream_terminated(%__MODULE__{} = stream, {:bandit, reason}) do
    Bandit.Telemetry.stop_span(stream.span, %{}, %{error: reason})
    Logger.warning("Stream #{stream.stream_id} was killed by bandit (#{reason})")
    {:ok, %{stream | state: :closed, pid: nil}, nil}
  end

  def stream_terminated(%__MODULE__{} = stream, {%StreamError{} = error, _}) do
    Bandit.Telemetry.stop_span(stream.span, %{}, %{error: error.message})
    Logger.warning("Stream #{stream.stream_id} encountered a stream error (#{inspect(error)})")
    {:ok, %{stream | state: :closed, pid: nil}, Errors.protocol_error()}
  end

  def stream_terminated(%__MODULE__{} = stream, :normal) do
    Logger.warning("Stream #{stream.stream_id} completed in unexpected state #{stream.state}")
    {:ok, %{stream | state: :closed, pid: nil}, Errors.no_error()}
  end

  def stream_terminated(%__MODULE__{} = stream, reason) do
    # Exit reasons shaped like {exception, stacktrace} are reported to telemetry as an
    # exception on the request span
    case reason do
      {exception, stacktrace} ->
        Bandit.Telemetry.span_exception(stream.span, :exit, exception, stacktrace)

      _ ->
        :ok
    end

    Logger.error("Task for stream #{stream.stream_id} crashed with #{inspect(reason)}")
    {:ok, %{stream | state: :closed, pid: nil}, Errors.internal_error()}
  end
end