Skip to main content

src/nhttp_compress.erl

-module(nhttp_compress).

-moduledoc """
HTTP compression utilities for nhttp.

Supports the gzip and deflate content codings for request/response bodies:

- `compress/3` encodes a body as gzip, deflate (zlib format) or identity.
- `decompress/2` and `decompress/3` decode a body while capping the size of
  the inflated output, as a defence against decompression bombs.
- `negotiate_encoding/1` picks a coding from the request's `Accept-Encoding`
  header, and `should_compress/3,4` decide whether a response should be
  compressed based on its size and Content-Type.

Compression is configurable and can be disabled. Configuration options
(`t:compress_opts/0`):
- `compression` => boolean() (default: true)
- `compression_level` => 1..9 (default: 6)
- `compression_threshold` => non_neg_integer() (default: 1024)
- `compress_mime_types` => [binary()] (default: common text, JSON, JavaScript
  and XML types; see `default_mime_types/0`)
""".

%%%-----------------------------------------------------------------------------
%% COMPRESSION
%%%-----------------------------------------------------------------------------
-export([
    compress/3,
    decompress/2,
    decompress/3
]).

%%%-----------------------------------------------------------------------------
%% NEGOTIATION
%%%-----------------------------------------------------------------------------
-export([
    default_mime_types/0,
    encoding_header/1,
    negotiate_encoding/1,
    should_compress/3,
    should_compress/4
]).

%%%-----------------------------------------------------------------------------
%% TYPES
%%%-----------------------------------------------------------------------------
-export_type([compress_opts/0, encoding/0, zlib_error/0]).

%% A content coding this module can produce or consume; `identity` means the
%% body is passed through untransformed.
-type encoding() :: gzip | deflate | identity.

%% Compression settings map; the meaning and defaults of each key are listed
%% in the module documentation.
%% NOTE(review): no function in this module takes this map directly —
%% presumably it is consumed by the callers of this module; confirm.
-type compress_opts() :: #{
    compression => boolean(),
    compression_level => 1..9,
    compression_threshold => non_neg_integer(),
    compress_mime_types => [binary()]
}.

%% Reasons returned inside `{error, Reason}` by this module.
%% `max_output_exceeded` is produced by the decompression size cap; the other
%% atoms are presumably the exception reasons raised by OTP `zlib`.
%% NOTE(review): the try/catch wrappers in this module forward any
%% `error`-class reason unchanged (e.g. `{badmatch, _}` or `{case_clause, _}`),
%% and nothing here constructs `{unknown, _}`, so reasons actually returned
%% can fall outside this type.
-type zlib_error() ::
    data_error
    | buf_error
    | stream_error
    | badarg
    | enomem
    | max_output_exceeded
    | {unknown, term()}.

%%%-----------------------------------------------------------------------------
%% MACROS
%%%-----------------------------------------------------------------------------
%% Minimum body size in bytes for `should_compress/3` to consider compressing.
-define(DEFAULT_COMPRESSION_THRESHOLD, 1024).
%% zlib windowBits for gzip framing: a 15-bit (32 KiB) window plus 16, which
%% tells zlib to use a gzip header/trailer instead of the zlib wrapper.
-define(ZLIB_GZIP_WINDOW_BITS, 16 + 15).
%% Default cap (16 MiB) on inflated output used by `decompress/2`.
-define(DEFAULT_MAX_DECOMPRESSED_SIZE, 16 * 1024 * 1024).

%%%-----------------------------------------------------------------------------
%% COMPRESSION
%%%-----------------------------------------------------------------------------
-doc """
Compress `Data` with the given content coding at zlib level `Level`.

- `gzip` produces gzip-framed output.
- `deflate` produces zlib-wrapped (RFC 1950) output, which is what the HTTP
  `deflate` content coding denotes.
- `identity` returns `Data` flattened to a binary, unchanged; `Level` is
  ignored.

Returns `{ok, CompressedData}`, or `{error, Reason}` if zlib reports a failure.
""".
-spec compress(Data :: iodata(), Encoding :: encoding(), Level :: 1..9) ->
    {ok, binary()} | {error, zlib_error()}.
compress(Data, gzip, Level) ->
    compress_gzip(Data, Level);
compress(Data, deflate, Level) ->
    compress_deflate(Data, Level);
compress(Data, identity, _Level) ->
    %% Pass-through: callers can use one code path for every negotiated coding.
    {ok, iolist_to_binary(Data)}.

-doc """
Decode `Data` that was encoded with `Encoding`. Inflation stops with
`{error, max_output_exceeded}` once the output would exceed
`?DEFAULT_MAX_DECOMPRESSED_SIZE` (16 MiB), which guards against decompression
bombs. Use `decompress/3` to pick a different limit.
""".
-spec decompress(Data :: binary(), Encoding :: encoding()) ->
    {ok, binary()} | {error, zlib_error()}.
decompress(Data, Encoding) ->
    Limit = ?DEFAULT_MAX_DECOMPRESSED_SIZE,
    decompress(Data, Encoding, Limit).

-doc """
Decode `Data` with an explicit limit on the decoded size. If decoding would
produce more than `Max` bytes, returns `{error, max_output_exceeded}`. Pass
`infinity` to remove the limit. For `identity`, the limit applies to `Data`
itself.
""".
-spec decompress(Data :: binary(), Encoding :: encoding(), Max) ->
    {ok, binary()} | {error, zlib_error()}
when
    Max :: pos_integer() | infinity.
decompress(Data, gzip, Max) ->
    decompress_gzip(Data, Max);
decompress(Data, deflate, Max) ->
    decompress_deflate(Data, Max);
decompress(Data, identity, infinity) ->
    {ok, Data};
decompress(Data, identity, Max) ->
    Size = byte_size(Data),
    if
        Size =< Max -> {ok, Data};
        true -> {error, max_output_exceeded}
    end.

%%%-----------------------------------------------------------------------------
%% NEGOTIATION
%%%-----------------------------------------------------------------------------
-doc """
List the media types that are compressed by default: common text formats,
JSON, JavaScript, XML/XHTML and SVG.
""".
-spec default_mime_types() -> [binary()].
default_mime_types() ->
    TextTypes = [
        <<"text/html">>,
        <<"text/plain">>,
        <<"text/css">>,
        <<"text/javascript">>,
        <<"text/xml">>
    ],
    ApplicationTypes = [
        <<"application/json">>,
        <<"application/javascript">>,
        <<"application/xml">>,
        <<"application/xhtml+xml">>
    ],
    ImageTypes = [<<"image/svg+xml">>],
    TextTypes ++ ApplicationTypes ++ ImageTypes.

-doc "Return the HTTP header token that names `Encoding` (e.g. for Content-Encoding).".
-spec encoding_header(encoding()) -> binary().
encoding_header(Encoding) when
    Encoding =:= gzip; Encoding =:= deflate; Encoding =:= identity
->
    %% Each token is spelled exactly like its atom.
    atom_to_binary(Encoding, utf8).

-doc """
Choose a response content coding from the request's `Accept-Encoding` header.
Returns `identity` when the header is absent, or when none of the codings it
lists with a non-zero weight is supported here. Supported codings are gzip,
x-gzip, deflate, and `*` (treated as gzip). Otherwise returns the supported
coding with the highest weight.
""".
-spec negotiate_encoding(Headers :: [{binary(), binary()}]) -> encoding().
negotiate_encoding(Headers) ->
    Value = nhttp_headers:get(<<"accept-encoding">>, Headers),
    if
        Value =:= undefined ->
            identity;
        true ->
            select_best_encoding(parse_accept_encoding(Value))
    end.

-doc """
Decide whether a response should be compressed, using the module's default
size threshold (`?DEFAULT_COMPRESSION_THRESHOLD`, 1024 bytes). Same as calling
`should_compress/4` with that threshold.
""".
-spec should_compress(
    ContentType :: binary() | undefined,
    Size :: non_neg_integer(),
    MimeTypes :: [binary()]
) -> boolean().
should_compress(ContentType, Size, MimeTypes) ->
    Threshold = ?DEFAULT_COMPRESSION_THRESHOLD,
    should_compress(ContentType, Size, MimeTypes, Threshold).

-doc """
Decide whether a response body should be compressed. Returns `true` only when
all of the following hold:

- `ContentType` is known.
- `Size` is at least `Threshold` bytes.
- The media type matches one of `MimeTypes`. Parameters such as `charset` are
  stripped and the comparison is case-insensitive.
""".
-spec should_compress(
    ContentType :: binary() | undefined,
    Size :: non_neg_integer(),
    MimeTypes :: [binary()],
    Threshold :: non_neg_integer()
) -> boolean().
should_compress(undefined, _Size, _MimeTypes, _Threshold) ->
    false;
should_compress(ContentType, Size, MimeTypes, Threshold) ->
    Size >= Threshold andalso is_compressible_type(ContentType, MimeTypes).

%%%-----------------------------------------------------------------------------
%% INTERNAL FUNCTIONS
%%%-----------------------------------------------------------------------------
%% Check whether adding the chunk `Out` to the `Total` bytes inflated so far
%% stays within `Max`. With `infinity` there is no limit.
-spec check_inflate_size(non_neg_integer(), iolist(), pos_integer() | infinity) ->
    ok | {error, max_output_exceeded}.
check_inflate_size(_Total, _Out, infinity) ->
    ok;
check_inflate_size(Total, Out, Max) ->
    Projected = Total + iolist_size(Out),
    if
        Projected =< Max -> ok;
        true -> {error, max_output_exceeded}
    end.

%% Deflate `Data` (zlib-wrapped) on a freshly opened zlib stream that is
%% always closed afterwards.
-spec compress_deflate(iodata(), 1..9) -> {ok, binary()} | {error, zlib_error()}.
compress_deflate(Data, Level) ->
    Job = fun(Stream) -> do_compress_deflate(Stream, Data, Level) end,
    with_zlib_stream(Job).

%% Gzip `Data` on a freshly opened zlib stream that is always closed
%% afterwards.
-spec compress_gzip(iodata(), 1..9) -> {ok, binary()} | {error, zlib_error()}.
compress_gzip(Data, Level) ->
    Job = fun(Stream) -> do_compress_gzip(Stream, Data, Level) end,
    with_zlib_stream(Job).

%% Inflate zlib-wrapped `Data`, capped at `Max` output bytes, on a freshly
%% opened zlib stream that is always closed afterwards.
-spec decompress_deflate(binary(), pos_integer() | infinity) ->
    {ok, binary()} | {error, zlib_error()}.
decompress_deflate(Data, Max) ->
    Job = fun(Stream) -> do_decompress_deflate(Stream, Data, Max) end,
    with_zlib_stream(Job).

%% Inflate gzip-framed `Data`, capped at `Max` output bytes, on a freshly
%% opened zlib stream that is always closed afterwards.
-spec decompress_gzip(binary(), pos_integer() | infinity) ->
    {ok, binary()} | {error, zlib_error()}.
decompress_gzip(Data, Max) ->
    Job = fun(Stream) -> do_decompress_gzip(Stream, Data, Max) end,
    with_zlib_stream(Job).

%% Run a complete zlib-format deflate (init, finish, end) on the open stream
%% `Z`. Any step result that does not match its `?=` pattern (an
%% `{error, Reason}` from a wrapper) is returned as-is.
-spec do_compress_deflate(zlib:zstream(), iodata(), 1..9) ->
    {ok, binary()} | {error, zlib_error()}.
do_compress_deflate(Z, Data, Level) ->
    maybe
        ok ?= zlib_deflate_init(Z, Level),
        {ok, Chunks} ?= zlib_deflate(Z, Data),
        Packed = iolist_to_binary(Chunks),
        ok ?= zlib_deflate_end(Z),
        {ok, Packed}
    end.

%% Run a complete gzip-framed deflate (init, finish, end) on the open stream
%% `Z`. The first failing step's `{error, Reason}` is returned as-is.
-spec do_compress_gzip(zlib:zstream(), iodata(), 1..9) ->
    {ok, binary()} | {error, zlib_error()}.
do_compress_gzip(Z, Data, Level) ->
    maybe
        ok ?= zlib_deflate_init_gzip(Z, Level),
        {ok, Chunks} ?= zlib_deflate(Z, Data),
        Packed = iolist_to_binary(Chunks),
        ok ?= zlib_deflate_end(Z),
        {ok, Packed}
    end.

%% Inflate zlib-format `Data` on the open stream `Z`, stopping once more than
%% `Max` bytes would be produced. The end step is what reports truncated input.
%% The first failing step's `{error, Reason}` is returned as-is.
-spec do_decompress_deflate(zlib:zstream(), binary(), pos_integer() | infinity) ->
    {ok, binary()} | {error, zlib_error()}.
do_decompress_deflate(Z, Data, Max) ->
    maybe
        ok ?= zlib_inflate_init(Z),
        {ok, Chunks} ?= zlib_safe_inflate(Z, Data, Max),
        Plain = iolist_to_binary(Chunks),
        ok ?= zlib_inflate_end(Z),
        {ok, Plain}
    end.

%% Inflate gzip-framed `Data` on the open stream `Z`, stopping once more than
%% `Max` bytes would be produced. The first failing step's `{error, Reason}`
%% is returned as-is.
-spec do_decompress_gzip(zlib:zstream(), binary(), pos_integer() | infinity) ->
    {ok, binary()} | {error, zlib_error()}.
do_decompress_gzip(Z, Data, Max) ->
    maybe
        ok ?= zlib_inflate_init_gzip(Z),
        {ok, Chunks} ?= zlib_safe_inflate(Z, Data, Max),
        Plain = iolist_to_binary(Chunks),
        ok ?= zlib_inflate_end(Z),
        {ok, Plain}
    end.

%% Map a lowercased Accept-Encoding coding token to the coding this module
%% would use for it. "x-gzip" and the wildcard "*" both select gzip. Anything
%% unsupported (including "identity" itself) maps to `identity`.
-spec encoding_to_atom(binary()) -> encoding().
encoding_to_atom(Coding) ->
    case Coding of
        <<"gzip">> -> gzip;
        <<"x-gzip">> -> gzip;
        <<"*">> -> gzip;
        <<"deflate">> -> deflate;
        _ -> identity
    end.

%% Strip media-type parameters from a Content-Type value and trim whitespace,
%% e.g. "text/html; charset=utf-8" -> "text/html".
%% binary:split/2 always returns at least one element, so the head match
%% cannot fail. The former `[]` branch was unreachable.
-spec extract_base_mime(binary()) -> binary().
extract_base_mime(ContentType) ->
    [BaseMime | _Params] = binary:split(ContentType, <<";">>),
    string:trim(BaseMime).

%% True when the base media type of `ContentType` (parameters dropped,
%% lowercased) matches at least one pattern in `MimeTypes`.
-spec is_compressible_type(binary(), [binary()]) -> boolean().
is_compressible_type(ContentType, MimeTypes) ->
    BaseMime = nhttp_headers:to_lower(extract_base_mime(ContentType)),
    Matches = fun(Pattern) -> mime_matches(BaseMime, Pattern) end,
    lists:any(Matches, MimeTypes).

%% Match an already-lowercased base media type against one configured pattern.
%% Patterns are lowercased before comparison and may be:
%%   - an exact media type, e.g. "application/json";
%%   - a type wildcard "Type/*", matching any subtype of Type
%%     (e.g. "text/*", "application/*", "image/*");
%%   - "*/*", matching every media type.
%% Previously only the "text/*" and "application/*" wildcards were honoured;
%% any other "Type/*" pattern could never match a real Content-Type.
-spec mime_matches(binary(), binary()) -> boolean().
mime_matches(ContentType, Pattern) ->
    LowerPattern = nhttp_headers:to_lower(Pattern),
    case binary:split(LowerPattern, <<"/">>) of
        [<<"*">>, <<"*">>] ->
            true;
        [Type, <<"*">>] ->
            %% `Type` is already bound, so this clause only matches when the
            %% content type's top-level type equals the pattern's.
            case binary:split(ContentType, <<"/">>) of
                [Type, _Subtype] -> true;
                _ -> false
            end;
        _ ->
            ContentType =:= LowerPattern
    end.

%% Split an Accept-Encoding value on "," (empty elements dropped) and parse
%% each element into `{LowerCoding, Weight}`. Elements that fail to parse
%% are skipped.
-spec parse_accept_encoding(binary()) -> [{binary(), float()}].
parse_accept_encoding(AcceptEncoding) ->
    Elements = binary:split(AcceptEncoding, <<",">>, [global, trim_all]),
    [Entry || Element <- Elements, {true, Entry} <- [parse_encoding_part(Element)]].

%% Parse one Accept-Encoding element such as "gzip", "br;q=0.8" or
%% "deflate ; q=0" into `{true, {LowerCoding, Weight}}` (weight 1.0 when
%% absent), or `false` when the weight does not parse.
%% RFC 9110's `weight` rule allows optional whitespace before ";", so the
%% coding token is trimmed as well as the parameter. Without this, "gzip ;q=1"
%% produced the unrecognised coding "gzip ".
-spec parse_encoding_part(binary()) -> {true, {binary(), float()}} | false.
parse_encoding_part(Part) ->
    %% binary:split/2 splits at the first ";" only, so Params has 0 or 1 items.
    [Coding | Params] = binary:split(string:trim(Part), <<";">>),
    Name = nhttp_headers:to_lower(string:trim(Coding)),
    case Params of
        [] ->
            {true, {Name, 1.0}};
        [Weight] ->
            case parse_quality(Weight) of
                {ok, Q} -> {true, {Name, Q}};
                {error, _} -> false
            end
    end.

%% Parse the parameter text that follows ";" in an Accept-Encoding element
%% (e.g. " q=0.5") into a weight, ignoring surrounding whitespace.
-spec parse_quality(binary()) -> {ok, float()} | {error, badarg}.
parse_quality(QValue) ->
    parse_quality_value(string:trim(QValue)).

%% Parse a weight parameter such as "q=0.5" into its numeric value.
%% RFC 9110 §12.4.2 defines the weight as `"q=" qvalue`. ABNF string literals
%% are case-insensitive, so "Q=0.5" must be accepted too. Anything else
%% (including other parameters) yields `{error, badarg}`.
-spec parse_quality_value(binary()) -> {ok, float()} | {error, badarg}.
parse_quality_value(<<Q, "=", Rest/binary>>) when Q =:= $q; Q =:= $Q ->
    parse_qvalue(Rest);
parse_quality_value(_) ->
    {error, badarg}.

%% Parse an RFC 9110 qvalue:
%%     qvalue = ( "0" [ "." 0*3DIGIT ] ) / ( "1" [ "." 0*3("0") ] )
%% Returns the weight as a float in [0.0, 1.0], or `{error, badarg}` for
%% anything outside that grammar (e.g. "1.5", "0.1234", "2").
%% The previous clause list rejected the valid forms "0.", "1.", "1.00" and
%% "1.000", which caused the whole coding to be dropped.
-spec parse_qvalue(binary()) -> {ok, float()} | {error, badarg}.
parse_qvalue(<<"0">>) ->
    {ok, 0.0};
parse_qvalue(<<"1">>) ->
    {ok, 1.0};
parse_qvalue(<<"0.", Digits/binary>>) when byte_size(Digits) =< 3 ->
    case qvalue_millis(Digits) of
        {ok, Millis} -> {ok, Millis / 1000};
        error -> {error, badarg}
    end;
parse_qvalue(<<"1.", Digits/binary>>) when byte_size(Digits) =< 3 ->
    case qvalue_millis(Digits) of
        %% "1." may only be followed by zeros.
        {ok, 0} -> {ok, 1.0};
        _ -> {error, badarg}
    end;
parse_qvalue(_) ->
    {error, badarg}.

%% Interpret up to three fractional digits as thousandths ("5" -> 500,
%% "05" -> 50, "" -> 0). Returns `error` if any byte is not an ASCII digit.
-spec qvalue_millis(binary()) -> {ok, non_neg_integer()} | error.
qvalue_millis(Digits) ->
    qvalue_millis(Digits, 100, 0).

-spec qvalue_millis(binary(), non_neg_integer(), non_neg_integer()) ->
    {ok, non_neg_integer()} | error.
qvalue_millis(<<>>, _Scale, Acc) ->
    {ok, Acc};
qvalue_millis(<<D, Rest/binary>>, Scale, Acc) when D >= $0, D =< $9 ->
    qvalue_millis(Rest, Scale div 10, Acc + (D - $0) * Scale);
qvalue_millis(_Other, _Scale, _Acc) ->
    error.

%% Pick the preferred supported coding from parsed Accept-Encoding entries.
%% Entries with q=0 are explicit refusals and are dropped. The rest are
%% ordered by descending weight.
%% lists:sort/2 requires an ordering fun that returns true when A sorts before
%% or together with B; that is what makes the sort stable. The former strict
%% `Q1 > Q2` broke this, so equally weighted codings came out reordered (e.g.
%% "gzip, deflate" selected deflate). With `>=`, the coding the client listed
%% first wins ties.
-spec select_best_encoding([{binary(), float()}]) -> encoding().
select_best_encoding(Encodings) ->
    Acceptable = [{E, Q} || {E, Q} <- Encodings, Q > 0.0],
    ByPreference = lists:sort(fun({_, Q1}, {_, Q2}) -> Q1 >= Q2 end, Acceptable),
    select_first_supported(ByPreference).

%% Walk the ranked entries and return the coding of the first one that
%% encoding_to_atom/1 maps to something other than `identity` (i.e. gzip or
%% deflate). Returns `identity` when no entry qualifies.
-spec select_first_supported([{binary(), float()}]) -> encoding().
select_first_supported(Ranked) ->
    IsSupported = fun({Coding, _Q}) -> encoding_to_atom(Coding) =/= identity end,
    case lists:search(IsSupported, Ranked) of
        {value, {Coding, _Q}} -> encoding_to_atom(Coding);
        false -> identity
    end.

%% Open a zlib stream, hand it to `Fun`, and close it afterwards even if `Fun`
%% raises. Returns `Fun`'s result, or the `{error, Reason}` from opening the
%% stream. Failures while closing are ignored.
-spec with_zlib_stream(fun((zlib:zstream()) -> {ok, binary()} | {error, zlib_error()})) ->
    {ok, binary()} | {error, zlib_error()}.
with_zlib_stream(Fun) ->
    maybe
        {ok, Stream} ?= zlib_open(),
        try
            Fun(Stream)
        after
            _ = zlib_close(Stream)
        end
    end.

%% Close the zlib stream `Z`.
-spec zlib_close(zlib:zstream()) -> ok | {error, zlib_error()}.
zlib_close(Z) ->
    zlib_try(fun() ->
        zlib:close(Z),
        ok
    end).

%% Compress all of `Data` and flush the stream with `finish`, returning the
%% produced chunks.
-spec zlib_deflate(zlib:zstream(), iodata()) -> {ok, iolist()} | {error, zlib_error()}.
zlib_deflate(Z, Data) ->
    zlib_try(fun() -> {ok, zlib:deflate(Z, Data, finish)} end).

%% End a deflate session. zlib raises if the last deflate call did not use
%% `finish`.
-spec zlib_deflate_end(zlib:zstream()) -> ok | {error, zlib_error()}.
zlib_deflate_end(Z) ->
    zlib_try(fun() -> ok = zlib:deflateEnd(Z) end).

%% Start a zlib-format deflate session at compression `Level`.
-spec zlib_deflate_init(zlib:zstream(), 1..9) -> ok | {error, zlib_error()}.
zlib_deflate_init(Z, Level) ->
    zlib_try(fun() -> ok = zlib:deflateInit(Z, Level) end).

%% Start a gzip-framed deflate session at compression `Level`: 8 is the
%% memLevel, and ?ZLIB_GZIP_WINDOW_BITS selects the gzip wrapper.
-spec zlib_deflate_init_gzip(zlib:zstream(), 1..9) -> ok | {error, zlib_error()}.
zlib_deflate_init_gzip(Z, Level) ->
    zlib_try(fun() ->
        ok = zlib:deflateInit(Z, Level, deflated, ?ZLIB_GZIP_WINDOW_BITS, 8, default)
    end).

%% End an inflate session. zlib raises `data_error` if the end of the
%% compressed stream was never reached (truncated input).
-spec zlib_inflate_end(zlib:zstream()) -> ok | {error, zlib_error()}.
zlib_inflate_end(Z) ->
    zlib_try(fun() -> ok = zlib:inflateEnd(Z) end).

%% Start a zlib-format inflate session.
-spec zlib_inflate_init(zlib:zstream()) -> ok | {error, zlib_error()}.
zlib_inflate_init(Z) ->
    zlib_try(fun() -> ok = zlib:inflateInit(Z) end).

%% Start a gzip-framed inflate session.
-spec zlib_inflate_init_gzip(zlib:zstream()) -> ok | {error, zlib_error()}.
zlib_inflate_init_gzip(Z) ->
    zlib_try(fun() -> ok = zlib:inflateInit(Z, ?ZLIB_GZIP_WINDOW_BITS) end).

%% Open a new zlib stream.
-spec zlib_open() -> {ok, zlib:zstream()} | {error, zlib_error()}.
zlib_open() ->
    zlib_try(fun() -> {ok, zlib:open()} end).

%% Run one zlib operation and turn any `error`-class exception it raises
%% (including a failed `ok = ...` match) into `{error, Reason}`. `throw` and
%% `exit` propagate unchanged.
-spec zlib_try(fun(() -> Result)) -> Result | {error, zlib_error()}.
zlib_try(Op) ->
    try
        Op()
    catch
        error:Reason -> {error, Reason}
    end.

%% Inflate `Data` in bounded chunks via zlib:safeInflate/2, failing with
%% `max_output_exceeded` once more than `Max` bytes would be produced. Any
%% `error`-class exception raised while inflating becomes `{error, Reason}`.
-spec zlib_safe_inflate(zlib:zstream(), binary(), pos_integer() | infinity) ->
    {ok, iolist()} | {error, zlib_error()}.
zlib_safe_inflate(Z, Data, Max) ->
    try zlib_safe_inflate_loop(Z, {first, Data}, Max, 0, []) of
        Result -> Result
    catch
        error:Reason -> {error, Reason}
    end.

%% Drive zlib:safeInflate/2 until all queued input has been decompressed.
%%
%% `Step` is `{first, Data}` on the first call, which queues the compressed
%% input. It is `drain` afterwards, which feeds no new input and only collects
%% pending output. `Total` counts bytes produced so far and `Acc` holds the
%% output chunks in reverse order. Each chunk is checked against `Max` before
%% it is kept, so the cap is enforced incrementally rather than after the
%% whole body has been inflated.
%%
%% safeInflate/2 can also return `{need_dictionary, Adler32, Output}` when a
%% zlib stream declares a preset dictionary. HTTP content codings have no way
%% to convey one, so such input cannot be decoded here and is reported as
%% `data_error`. Previously it escaped as an untyped `{case_clause, _}` error.
-spec zlib_safe_inflate_loop(
    zlib:zstream(),
    {first, binary()} | drain,
    pos_integer() | infinity,
    non_neg_integer(),
    [iolist()]
) ->
    {ok, iolist()} | {error, zlib_error()}.
zlib_safe_inflate_loop(Z, Step, Max, Total, Acc) ->
    Input =
        case Step of
            {first, Data} -> Data;
            drain -> <<>>
        end,
    case zlib:safeInflate(Z, Input) of
        {finished, Out} ->
            case check_inflate_size(Total, Out, Max) of
                ok -> {ok, lists:reverse([Out | Acc])};
                {error, _} = E -> E
            end;
        {continue, Out} ->
            case check_inflate_size(Total, Out, Max) of
                ok ->
                    NewTotal = Total + iolist_size(Out),
                    zlib_safe_inflate_loop(Z, drain, Max, NewTotal, [Out | Acc]);
                {error, _} = E ->
                    E
            end;
        {need_dictionary, _Adler32, _Out} ->
            {error, data_error}
    end.