Skip to main content

src/nhttp_proxy_protocol.erl

-module(nhttp_proxy_protocol).

-moduledoc """
PROXY protocol v1 and v2 parser.

The PROXY protocol preserves the original client IP/port across a
connection-terminating intermediary (load balancer, proxy). Both v1
(ASCII line) and v2 (binary frame) are supported.

The 12-byte v2 signature `\\r\\n\\r\\n\\x00\\r\\nQUIT\\n` is unambiguous
and shares no prefix with `"PROXY "`, so callers wishing to accept both
versions can probe each shape from the first byte without ambiguity.

## Usage

```erlang
case nhttp_proxy_protocol:parse(Buffer, #{version => both}) of
    {ok, Header, Consumed} ->
        Rest = binary:part(Buffer, Consumed, byte_size(Buffer) - Consumed),
        handle(Header, Rest);
    {more, _MinBytes} ->
        recv_more();
    {error, Reason} ->
        reject(Reason)
end.
```

The returned `t:header/0` map carries the parsed source/destination
addresses for the `proxy` command. The `local` command (v2 only) signals
a health-check-style connection from the proxy itself; callers should
keep the original L4 peer in that case.
""".

%%%-----------------------------------------------------------------------------
%% API
%%%-----------------------------------------------------------------------------
-export([
    parse/1,
    parse/2
]).

%%%-----------------------------------------------------------------------------
%% TYPES
%%%-----------------------------------------------------------------------------
-export_type([
    command/0,
    family/0,
    header/0,
    opts/0,
    parse_error/0,
    parse_result/0,
    transport/0,
    version/0
]).

-type version() :: v1 | v2.

-type command() :: proxy | local.

-type family() :: inet | inet6 | unix | unspec.

-type transport() :: stream | dgram | unspec.

-type accepted() :: v1 | v2 | both.

-type tlv() :: {Type :: 0..255, Value :: binary()}.

-type header() :: #{
    version := version(),
    command := command(),
    family := family(),
    transport := transport(),
    src_addr => inet:ip_address() | binary(),
    dst_addr => inet:ip_address() | binary(),
    src_port => inet:port_number(),
    dst_port => inet:port_number(),
    tlvs => [tlv()]
}.

-type opts() :: #{version => accepted()}.

-type parse_error() ::
    bad_signature
    | bad_v1_header
    | bad_v1_proto
    | bad_v1_addr
    | bad_v1_port
    | bad_v1_too_long
    | bad_v2_version
    | bad_v2_command
    | bad_v2_family
    | bad_v2_transport
    | bad_v2_payload
    | unsupported_version.

-type parse_result() ::
    {ok, header(), Consumed :: pos_integer()}
    | {more, MinBytes :: pos_integer()}
    | {error, parse_error()}.

%%%-----------------------------------------------------------------------------
%% MACROS
%%%-----------------------------------------------------------------------------
-define(V2_SIGNATURE, <<13, 10, 13, 10, 0, 13, 10, "QUIT", 10>>).
-define(V2_SIG_LEN, 12).
-define(V2_HEADER_LEN, 16).
-define(V1_PREFIX, <<"PROXY ">>).
-define(V1_PREFIX_LEN, 6).
-define(V1_MAX_LEN, 107).

%%%-----------------------------------------------------------------------------
%% API
%%%-----------------------------------------------------------------------------
-doc "Equivalent to `parse(Bin, #{version => both})`.".
-spec parse(binary()) -> parse_result().
parse(<<Bin/binary>>) ->
    parse(Bin, #{}).

-doc """
Parse a PROXY protocol header from `Bin`.
`Opts` may carry `version => v1 | v2 | both` (default `both`) to restrict
the accepted protocol versions. When restricted, a header that matches a
disallowed version returns `{error, unsupported_version}`.
Returns `{ok, Header, Consumed}` on success, where `Consumed` is the
number of bytes occupied by the PROXY header (including any v1 CRLF or
v2 length prefix). Use `binary:part/3` on the original buffer to recover
the bytes that follow.
""".
-spec parse(binary(), opts()) -> parse_result().
parse(<<Bin/binary>>, Opts) ->
    Accepted = maps:get(version, Opts, both),
    classify(Bin, Accepted).

%%%-----------------------------------------------------------------------------
%% INTERNAL FUNCTIONS - DISPATCH
%%%-----------------------------------------------------------------------------
-spec allows_v1(accepted()) -> boolean().
allows_v1(v1) -> true;
allows_v1(both) -> true;
allows_v1(_) -> false.

-spec allows_v2(accepted()) -> boolean().
allows_v2(v2) -> true;
allows_v2(both) -> true;
allows_v2(_) -> false.

-spec classify(binary(), accepted()) -> parse_result().
classify(<<>>, _Accepted) ->
    {more, 1};
classify(<<13, _/binary>> = Bin, Accepted) ->
    case allows_v2(Accepted) of
        true -> match_v2(Bin);
        false -> {error, unsupported_version}
    end;
classify(<<$P, _/binary>> = Bin, Accepted) ->
    case allows_v1(Accepted) of
        true -> match_v1(Bin);
        false -> {error, unsupported_version}
    end;
classify(_Bin, _Accepted) ->
    {error, bad_signature}.

%%%-----------------------------------------------------------------------------
%% INTERNAL FUNCTIONS - V2
%%%-----------------------------------------------------------------------------
-spec build_v2_header(command(), family(), transport(), binary()) -> parse_result().
build_v2_header(local, _Family, _Transport, Payload) ->
    Header = #{
        version => v2,
        command => local,
        family => unspec,
        transport => unspec
    },
    {ok, Header, ?V2_HEADER_LEN + byte_size(Payload)};
build_v2_header(proxy, unspec, Transport, Payload) ->
    Header = #{
        version => v2,
        command => proxy,
        family => unspec,
        transport => Transport
    },
    {ok, Header, ?V2_HEADER_LEN + byte_size(Payload)};
build_v2_header(proxy, inet, Transport, Payload) ->
    case Payload of
        <<S1:8, S2:8, S3:8, S4:8, D1:8, D2:8, D3:8, D4:8, SP:16/big, DP:16/big, Tail/binary>> ->
            maybe
                {ok, Tlvs} ?= parse_tlvs(Tail, []),
                Header = #{
                    version => v2,
                    command => proxy,
                    family => inet,
                    transport => Transport,
                    src_addr => {S1, S2, S3, S4},
                    dst_addr => {D1, D2, D3, D4},
                    src_port => SP,
                    dst_port => DP,
                    tlvs => Tlvs
                },
                {ok, Header, ?V2_HEADER_LEN + byte_size(Payload)}
            end;
        _ ->
            {error, bad_v2_payload}
    end;
build_v2_header(proxy, inet6, Transport, Payload) ->
    case Payload of
        <<S1:16/big, S2:16/big, S3:16/big, S4:16/big, S5:16/big, S6:16/big, S7:16/big, S8:16/big,
            D1:16/big, D2:16/big, D3:16/big, D4:16/big, D5:16/big, D6:16/big, D7:16/big, D8:16/big,
            SP:16/big, DP:16/big, Tail/binary>> ->
            maybe
                {ok, Tlvs} ?= parse_tlvs(Tail, []),
                Header = #{
                    version => v2,
                    command => proxy,
                    family => inet6,
                    transport => Transport,
                    src_addr => {S1, S2, S3, S4, S5, S6, S7, S8},
                    dst_addr => {D1, D2, D3, D4, D5, D6, D7, D8},
                    src_port => SP,
                    dst_port => DP,
                    tlvs => Tlvs
                },
                {ok, Header, ?V2_HEADER_LEN + byte_size(Payload)}
            end;
        _ ->
            {error, bad_v2_payload}
    end;
build_v2_header(proxy, unix, Transport, Payload) ->
    case Payload of
        <<Src:108/binary, Dst:108/binary, Tail/binary>> ->
            maybe
                {ok, Tlvs} ?= parse_tlvs(Tail, []),
                Header = #{
                    version => v2,
                    command => proxy,
                    family => unix,
                    transport => Transport,
                    src_addr => trim_nul(Src),
                    dst_addr => trim_nul(Dst),
                    tlvs => Tlvs
                },
                {ok, Header, ?V2_HEADER_LEN + byte_size(Payload)}
            end;
        _ ->
            {error, bad_v2_payload}
    end.

-spec decode_v2(byte(), byte(), binary()) -> parse_result().
decode_v2(VerCmd, FamProto, Payload) ->
    Version = VerCmd bsr 4,
    Command = VerCmd band 16#0F,
    Family = FamProto bsr 4,
    Transport = FamProto band 16#0F,
    case Version of
        2 ->
            maybe
                {ok, CommandAtom} ?= decode_v2_command(Command),
                {ok, FamilyAtom} ?= decode_v2_family(Family),
                {ok, TransportAtom} ?= decode_v2_transport(Transport),
                build_v2_header(CommandAtom, FamilyAtom, TransportAtom, Payload)
            end;
        _ ->
            {error, bad_v2_version}
    end.

-spec decode_v2_command(byte()) -> {ok, command()} | {error, bad_v2_command}.
decode_v2_command(0) -> {ok, local};
decode_v2_command(1) -> {ok, proxy};
decode_v2_command(_) -> {error, bad_v2_command}.

-spec decode_v2_family(byte()) -> {ok, family()} | {error, bad_v2_family}.
decode_v2_family(0) -> {ok, unspec};
decode_v2_family(1) -> {ok, inet};
decode_v2_family(2) -> {ok, inet6};
decode_v2_family(3) -> {ok, unix};
decode_v2_family(_) -> {error, bad_v2_family}.

-spec decode_v2_transport(byte()) -> {ok, transport()} | {error, bad_v2_transport}.
decode_v2_transport(0) -> {ok, unspec};
decode_v2_transport(1) -> {ok, stream};
decode_v2_transport(2) -> {ok, dgram};
decode_v2_transport(_) -> {error, bad_v2_transport}.

-spec match_v2(binary()) -> parse_result().
match_v2(<<13, 10, 13, 10, 0, 13, 10, "QUIT", 10, Rest/binary>>) ->
    parse_v2(Rest);
match_v2(<<Bin/binary>>) when byte_size(Bin) < ?V2_SIG_LEN ->
    Size = byte_size(Bin),
    Prefix = binary:part(?V2_SIGNATURE, 0, Size),
    case Bin of
        Prefix -> {more, ?V2_HEADER_LEN - Size};
        _ -> {error, bad_signature}
    end;
match_v2(_Bin) ->
    {error, bad_signature}.

-spec parse_tlvs(binary(), [tlv()]) -> {ok, [tlv()]} | {error, bad_v2_payload}.
parse_tlvs(<<>>, Acc) ->
    {ok, lists:reverse(Acc)};
parse_tlvs(<<Type:8, Len:16/big, Value:Len/binary, Rest/binary>>, Acc) ->
    parse_tlvs(Rest, [{Type, Value} | Acc]);
parse_tlvs(_Bin, _Acc) ->
    {error, bad_v2_payload}.

-spec parse_v2(binary()) -> parse_result().
parse_v2(<<VerCmd:8, FamProto:8, Len:16/big, Rest/binary>>) ->
    case Rest of
        <<Payload:Len/binary, _/binary>> ->
            decode_v2(VerCmd, FamProto, Payload);
        _ ->
            {more, Len - byte_size(Rest)}
    end;
parse_v2(<<Bin/binary>>) ->
    {more, (?V2_HEADER_LEN - ?V2_SIG_LEN) - byte_size(Bin)}.

-spec trim_nul(binary()) -> binary().
trim_nul(<<Bin/binary>>) ->
    case binary:match(Bin, <<0>>) of
        {Pos, 1} -> binary:part(Bin, 0, Pos);
        nomatch -> Bin
    end.

%%%-----------------------------------------------------------------------------
%% INTERNAL FUNCTIONS - V1
%%%-----------------------------------------------------------------------------
-spec all_digits(binary()) -> boolean().
all_digits(<<>>) -> true;
all_digits(<<C, Rest/binary>>) when C >= $0, C =< $9 -> all_digits(Rest);
all_digits(_) -> false.

-spec decode_v1_addrs(binary(), inet | inet6, pos_integer()) -> parse_result().
decode_v1_addrs(Bin, Family, Consumed) ->
    case binary:split(Bin, <<" ">>, [global]) of
        [SrcAddr, DstAddr, SrcPort, DstPort] ->
            maybe
                {ok, SrcIp} ?= parse_v1_ip(SrcAddr, Family),
                {ok, DstIp} ?= parse_v1_ip(DstAddr, Family),
                {ok, SrcP} ?= parse_v1_port(SrcPort),
                {ok, DstP} ?= parse_v1_port(DstPort),
                Header = #{
                    version => v1,
                    command => proxy,
                    family => Family,
                    transport => stream,
                    src_addr => SrcIp,
                    dst_addr => DstIp,
                    src_port => SrcP,
                    dst_port => DstP
                },
                {ok, Header, Consumed}
            end;
        _ ->
            {error, bad_v1_header}
    end.

-spec decode_v1_line(binary(), pos_integer()) -> parse_result().
decode_v1_line(<<"UNKNOWN">>, Consumed) ->
    {ok, unknown_v1_header(), Consumed};
decode_v1_line(<<"UNKNOWN ", _Tail/binary>>, Consumed) ->
    {ok, unknown_v1_header(), Consumed};
decode_v1_line(<<"TCP4 ", Rest/binary>>, Consumed) ->
    decode_v1_addrs(Rest, inet, Consumed);
decode_v1_line(<<"TCP6 ", Rest/binary>>, Consumed) ->
    decode_v1_addrs(Rest, inet6, Consumed);
decode_v1_line(_Line, _Consumed) ->
    {error, bad_v1_proto}.

-spec match_v1(binary()) -> parse_result().
match_v1(<<"PROXY ", Rest/binary>>) ->
    parse_v1(Rest);
match_v1(<<Bin/binary>>) when byte_size(Bin) < ?V1_PREFIX_LEN ->
    Size = byte_size(Bin),
    Prefix = binary:part(?V1_PREFIX, 0, Size),
    case Bin of
        Prefix -> {more, ?V1_PREFIX_LEN - Size};
        _ -> {error, bad_signature}
    end;
match_v1(_Bin) ->
    {error, bad_signature}.

-spec parse_v1(binary()) -> parse_result().
parse_v1(<<Rest/binary>>) ->
    MaxBodyLen = ?V1_MAX_LEN - ?V1_PREFIX_LEN,
    Available = min(MaxBodyLen, byte_size(Rest)),
    Window = binary:part(Rest, 0, Available),
    case binary:match(Window, <<"\r\n">>) of
        {Pos, 2} ->
            <<Line:Pos/binary, "\r\n", _/binary>> = Rest,
            decode_v1_line(Line, ?V1_PREFIX_LEN + Pos + 2);
        nomatch when byte_size(Rest) >= MaxBodyLen ->
            {error, bad_v1_too_long};
        nomatch ->
            {more, 1}
    end.

-spec parse_v1_ip(binary(), inet | inet6) -> {ok, inet:ip_address()} | {error, bad_v1_addr}.
parse_v1_ip(IpBin, Family) ->
    case inet:parse_strict_address(binary_to_list(IpBin)) of
        {ok, IP} when Family =:= inet, tuple_size(IP) =:= 4 ->
            {ok, IP};
        {ok, IP} when Family =:= inet6, tuple_size(IP) =:= 8 ->
            {ok, IP};
        _ ->
            {error, bad_v1_addr}
    end.

-spec parse_v1_port(binary()) -> {ok, inet:port_number()} | {error, bad_v1_port}.
parse_v1_port(<<>>) ->
    {error, bad_v1_port};
parse_v1_port(PortBin) ->
    case all_digits(PortBin) of
        true ->
            Port = binary_to_integer(PortBin),
            case Port =< 65535 of
                true -> {ok, Port};
                false -> {error, bad_v1_port}
            end;
        false ->
            {error, bad_v1_port}
    end.

-spec unknown_v1_header() -> header().
unknown_v1_header() ->
    #{
        version => v1,
        command => proxy,
        family => unspec,
        transport => unspec
    }.