%% src/livery_multipart.erl

-module(livery_multipart).
-moduledoc """
Streaming `multipart/form-data` parser (RFC 7578).

Sits on the request body reader (`livery_body`) and works over both
streamed (`{stream, _}`) and buffered (`{buffered, _}`) bodies. Pull the
parts one at a time:

```erlang
{ok, MP0} = livery_multipart:new(Req),
case livery_multipart:next_part(MP0, 5000) of
    {part, #{name := Name, filename := File}, MP1} ->
        %% drain this part's body incrementally
        drain(MP1);
    {done, _} -> ok
end.
```

`read_all/1,2` is a convenience that collects every part fully into
memory under the configured limits.

Security: the parser bounds all buffering (`max_header_bytes`,
`max_header_count`, `max_parts`, `max_body`) and never touches the
filesystem. `max_part_size` is enforced by `read_all/1,2` only; when
streaming with `read_part/2` the caller decides how much of a part to
keep. A part's `filename` is returned verbatim; a handler MUST
confine/sanitize it before using it as a path.
""".

-include("livery.hrl").

-export([
    new/1,
    new/2,
    next_part/1,
    next_part/2,
    read_part/2,
    read_all/1,
    read_all/2
]).

-export_type([mp/0, part/0, part_full/0, opts/0, reason/0]).

%% Parser state threaded through every call of the streaming API.
-record(mp, {
    %% Streaming body reader; `undefined` when there is nothing to pull from
    %% (empty or already-buffered bodies), in which case `pull/1` reports eof.
    reader :: livery_body:reader() | undefined,
    %% `--` followed by the boundary: the delimiter as it appears at the
    %% very start of the body.
    dash :: binary(),
    %% CRLF followed by `--` and the boundary: the delimiter that terminates
    %% a part body.
    crlfdash :: binary(),
    %% Bytes taken from the source that have not been handed out yet.
    buffer = <<>> :: binary(),
    %% Where the parser is positioned: before the first delimiter (`start`),
    %% right after a delimiter token (`after_boundary`), inside a part body
    %% (`body`), or past the closing delimiter (`closed`).
    state = start :: start | after_boundary | body | closed,
    %% Effective options (caller overrides merged over the defaults).
    opts :: opts(),
    %% Timeout passed to every `livery_body:read/2` call.
    timeout = 5000 :: timeout(),
    %% Total bytes taken from the source so far; compared against `max_body`.
    consumed = 0 :: non_neg_integer(),
    %% Number of parts whose headers have been parsed so far.
    parts = 0 :: non_neg_integer()
}).

%% Parser handle; callers must treat it as opaque and thread it through
%% the streaming API.
-opaque mp() :: #mp{}.

%% Parser options. Every key is optional; missing keys are filled in from
%% `default_opts/0`.
-type opts() :: #{
    %% Timeout used for each read from the body reader.
    part_timeout => timeout(),
    %% Maximum number of parts accepted in one body.
    max_parts => pos_integer(),
    %% Maximum size of one header line and of a part's whole header block.
    max_header_bytes => pos_integer(),
    %% Maximum number of header lines per part.
    max_header_count => pos_integer(),
    %% Maximum size of a single part body collected by `read_all/1,2`.
    max_part_size => pos_integer(),
    %% Maximum number of bytes taken from the request body overall.
    max_body => pos_integer()
}.

%% Metadata of a part, as returned by `next_part/1,2`. `name` and `filename`
%% come from the Content-Disposition parameters, `content_type` from the
%% part's Content-Type header; each is `undefined` when absent.
-type part() :: #{
    name := binary() | undefined,
    filename := binary() | undefined,
    content_type := binary() | undefined,
    headers := [{binary(), binary()}]
}.

%% A part together with its fully collected body, as returned by
%% `read_all/1,2`.
-type part_full() :: #{
    name := binary() | undefined,
    filename := binary() | undefined,
    content_type := binary() | undefined,
    headers := [{binary(), binary()}],
    body := binary()
}.

%% Error reasons reported while parsing: a syntactically invalid body, a
%% failure reported by the body reader, or an exceeded limit (tagged with
%% the name of the option that was exceeded).
-type reason() ::
    malformed
    | {client_reset, term()}
    | timeout
    | {limit,
        max_parts
        | max_header_bytes
        | max_header_count
        | max_part_size
        | max_body}.

-define(CRLF, <<"\r\n">>).

%%====================================================================
%% Construction
%%====================================================================

-doc "Create a parser for `Req` with the default options; the boundary is taken from its Content-Type header.".
-spec new(livery_req:req()) -> {ok, mp()} | {error, not_multipart | no_boundary}.
new(Req) ->
    NoOverrides = #{},
    new(Req, NoOverrides).

-doc "Like `new/1`, with `Opts0` overriding the default parser options.".
-spec new(livery_req:req(), opts()) ->
    {ok, mp()} | {error, not_multipart | no_boundary}.
new(Req, Opts0) ->
    Opts = maps:merge(default_opts(), Opts0),
    ContentType = livery_req:header(<<"content-type">>, Req),
    new_from_content_type(ContentType, Req, Opts).

%% Build the parser once the Content-Type header value is known.
-spec new_from_content_type(binary() | undefined, livery_req:req(), opts()) ->
    {ok, mp()} | {error, not_multipart | no_boundary}.
new_from_content_type(undefined, _Req, _Opts) ->
    %% no Content-Type at all: cannot be a multipart request
    {error, not_multipart};
new_from_content_type(ContentType, Req, Opts) ->
    case multipart_boundary(ContentType) of
        {ok, Boundary} -> {ok, init_mp(Req, Boundary, Opts)};
        {error, _} = Error -> Error
    end.

%% Assemble the initial parser state for `Req` and the given boundary.
-spec init_mp(livery_req:req(), binary(), opts()) -> mp().
init_mp(Req, Boundary, Opts) ->
    {Initial, Reader, AlreadyRead} = source(Req),
    Delimiter = <<"--", Boundary/binary>>,
    PartTimeout = maps:get(part_timeout, Opts),
    #mp{
        reader = Reader,
        dash = Delimiter,
        %% a part body ends at CRLF followed by the delimiter
        crlfdash = <<"\r\n", Delimiter/binary>>,
        buffer = Initial,
        opts = Opts,
        timeout = PartTimeout,
        consumed = AlreadyRead
    }.

%% Turn the request body into `{InitialBuffer, Reader, BytesConsumed}`.
-spec source(livery_req:req()) ->
    {binary(), livery_body:reader() | undefined, non_neg_integer()}.
source(Req) ->
    from_body(livery_req:body(Req)).

%% One clause per body representation.
-spec from_body(empty | {buffered, iodata()} | {stream, livery_body:reader()}) ->
    {binary(), livery_body:reader() | undefined, non_neg_integer()}.
from_body(empty) ->
    {<<>>, undefined, 0};
from_body({stream, Reader}) ->
    %% nothing has been read yet; bytes are pulled on demand
    {<<>>, Reader, 0};
from_body({buffered, IoData}) ->
    %% the whole body is already in memory and counts as consumed
    Bin = iolist_to_binary(IoData),
    {Bin, undefined, byte_size(Bin)}.

%%====================================================================
%% Streaming API
%%====================================================================

-doc "Move to the next part and return its parsed metadata, using the parser's current timeout.".
-spec next_part(mp()) -> {part, part(), mp()} | {done, mp()} | {error, reason(), mp()}.
next_part(#mp{timeout = Timeout} = MP) ->
    next_part(MP, Timeout).

-doc "Like `next_part/1`, but reads from the body with the given per-chunk timeout.".
-spec next_part(mp(), timeout()) ->
    {part, part(), mp()} | {done, mp()} | {error, reason(), mp()}.
next_part(#mp{} = MP0, Timeout) ->
    MP = MP0#mp{timeout = Timeout},
    next_within_limit(max_body_ok(MP), MP).

%% Only advance while the bytes consumed so far respect `max_body`.
-spec next_within_limit(boolean(), mp()) ->
    {part, part(), mp()} | {done, mp()} | {error, reason(), mp()}.
next_within_limit(true, MP) ->
    advance(MP);
next_within_limit(false, MP) ->
    {error, {limit, max_body}, MP}.

%% Move from the current parser state to the start of the next part.
-spec advance(mp()) -> {part, part(), mp()} | {done, mp()} | {error, reason(), mp()}.
advance(#mp{state = closed} = MP) ->
    {done, MP};
advance(#mp{state = after_boundary} = MP) ->
    at_boundary(MP);
advance(#mp{state = start} = MP) ->
    %% skip the preamble up to the first delimiter
    continue_at_boundary(start_scan(MP));
advance(#mp{state = body} = MP) ->
    %% discard whatever is left of the current part
    continue_at_boundary(skip_body(MP)).

%% Carry on at the delimiter on success; hand anything else back unchanged.
-spec continue_at_boundary({ok, mp()} | {done, mp()} | {error, reason(), mp()}) ->
    {part, part(), mp()} | {done, mp()} | {error, reason(), mp()}.
continue_at_boundary({ok, MP}) ->
    at_boundary(MP);
continue_at_boundary({done, _} = Done) ->
    Done;
continue_at_boundary({error, _, _} = Error) ->
    Error.

%% The buffer starts right after a `--boundary` token. A following `--`
%% marks the closing delimiter; anything else starts a new part.
-spec at_boundary(mp()) ->
    {part, part(), mp()} | {done, mp()} | {error, reason(), mp()}.
at_boundary(MP0) ->
    case ensure(MP0, 2) of
        {ok, #mp{buffer = <<"--", _/binary>>} = MP} ->
            %% closing delimiter: drop whatever follows it
            {done, MP#mp{state = closed, buffer = <<>>}};
        {ok, MP} ->
            start_part(MP);
        {eof, MP} ->
            %% the body ends right after the delimiter: treat it as closing
            {done, MP#mp{state = closed, buffer = <<>>}};
        {error, _, _} = Error ->
            Error
    end.

%% Begin a new part: enforce `max_parts`, drop the remainder of the
%% delimiter line, then parse the part headers.
-spec start_part(mp()) -> {part, part(), mp()} | {error, reason(), mp()}.
start_part(#mp{parts = Seen, opts = Opts} = MP0) ->
    Limit = maps:get(max_parts, Opts),
    case Seen + 1 > Limit of
        false ->
            %% the rest of the delimiter line is transport padding
            LineLimit = maps:get(max_header_bytes, Opts),
            case read_line(MP0, LineLimit) of
                {line, _Padding, MP1} -> read_part_headers(MP1);
                {error, _, _} = Error -> Error
            end;
        true ->
            {error, {limit, max_parts}, MP0}
    end.

%% Parse the header block of a part and switch the parser to `body`.
-spec read_part_headers(mp()) -> {part, part(), mp()} | {error, reason(), mp()}.
read_part_headers(#mp{opts = Opts} = MP0) ->
    ByteLimit = maps:get(max_header_bytes, Opts),
    CountLimit = maps:get(max_header_count, Opts),
    case headers_loop(MP0, [], 0, ByteLimit, CountLimit) of
        {error, _, _} = Error ->
            Error;
        {ok, Headers, #mp{parts = Seen} = MP1} ->
            %% the part only counts once its headers parsed successfully
            {part, build_part(Headers), MP1#mp{state = body, parts = Seen + 1}}
    end.

%% Read header lines until the blank line that ends the header block.
%% `Bytes` is the size of the header block read so far.
-spec headers_loop(mp(), [{binary(), binary()}], non_neg_integer(), pos_integer(), pos_integer()) ->
    {ok, [{binary(), binary()}], mp()} | {error, reason(), mp()}.
headers_loop(MP0, Acc, Bytes, Max, MaxCount) ->
    case read_line(MP0, Max) of
        {error, _, _} = Error ->
            Error;
        {line, <<>>, MP1} ->
            %% blank line: the headers are complete, restore their order
            {ok, lists:reverse(Acc), MP1};
        {line, Line, MP1} ->
            %% + 2 accounts for the CRLF that read_line/2 stripped
            Total = Bytes + byte_size(Line) + 2,
            check_header_line(MP1, Acc, Line, Total, Max, MaxCount)
    end.

%% Apply the header limits to one line, then parse it and keep looping.
%% The byte limit is checked before the count limit.
-spec check_header_line(
    mp(), [{binary(), binary()}], binary(), non_neg_integer(), pos_integer(), pos_integer()
) -> {ok, [{binary(), binary()}], mp()} | {error, reason(), mp()}.
check_header_line(MP, Acc, Line, NewBytes, Max, MaxCount) ->
    if
        NewBytes > Max ->
            {error, {limit, max_header_bytes}, MP};
        length(Acc) + 1 > MaxCount ->
            {error, {limit, max_header_count}, MP};
        true ->
            append_header(parse_header(Line), MP, Acc, NewBytes, Max, MaxCount)
    end.

%% Accumulate a parsed header and continue, or stop on a malformed line.
-spec append_header(
    {ok, {binary(), binary()}} | error,
    mp(),
    [{binary(), binary()}],
    non_neg_integer(),
    pos_integer(),
    pos_integer()
) -> {ok, [{binary(), binary()}], mp()} | {error, reason(), mp()}.
append_header({ok, Header}, MP, Acc, NewBytes, Max, MaxCount) ->
    headers_loop(MP, [Header | Acc], NewBytes, Max, MaxCount);
append_header(error, MP, _Acc, _NewBytes, _Max, _MaxCount) ->
    {error, malformed, MP}.

-doc "Return the next chunk of the current part's body, or `done` when no part body is being read.".
-spec read_part(mp(), timeout()) ->
    {ok, binary(), mp()} | {done, mp()} | {error, reason(), mp()}.
read_part(#mp{state = State} = MP, Timeout) ->
    case State of
        body ->
            body_chunk(MP#mp{timeout = Timeout});
        _ ->
            %% not inside a part body: nothing to read
            {done, MP}
    end.

%% Produce the next chunk of the current part body from the buffer,
%% pulling more input when the buffer cannot be split safely.
%%
%% Returns `{ok, Chunk, MP}` with body bytes, `{done, MP}` when the
%% delimiter sits at the very start of the buffer (no body bytes left),
%% or `{error, Reason, MP}`. Reaching the end of input before a delimiter
%% is reported as `malformed`.
-spec body_chunk(mp()) -> {ok, binary(), mp()} | {done, mp()} | {error, reason(), mp()}.
body_chunk(#mp{buffer = Buf, crlfdash = Needle} = MP) ->
    case binary:match(Buf, Needle) of
        {Pos, Len} ->
            %% Delimiter found: everything before it is body, the delimiter
            %% itself is dropped, and the parser is now positioned after it.
            <<Body:Pos/binary, _:Len/binary, Rest/binary>> = Buf,
            MP1 = MP#mp{buffer = Rest, state = after_boundary},
            case Pos of
                0 -> {done, MP1};
                _ -> {ok, Body, MP1}
            end;
        nomatch ->
            %% The tail of the buffer may be the beginning of a delimiter
            %% that is split across reads, so the last `HoldBack` bytes are
            %% never emitted until more input has been seen.
            HoldBack = byte_size(MP#mp.crlfdash) - 1,
            case byte_size(Buf) > HoldBack of
                true ->
                    Emit = binary:part(Buf, 0, byte_size(Buf) - HoldBack),
                    Keep = binary:part(Buf, byte_size(Buf) - HoldBack, HoldBack),
                    {ok, Emit, MP#mp{buffer = Keep}};
                false ->
                    %% Not enough buffered to emit anything safely.
                    case pull(MP) of
                        {ok, MP1} -> body_chunk(MP1);
                        {eof, MP1} -> {error, malformed, MP1};
                        {error, R, MP1} -> {error, R, MP1}
                    end
            end
    end.

%% Read and discard the remainder of the current part body.
-spec skip_body(mp()) -> {ok, mp()} | {error, reason(), mp()}.
skip_body(#mp{timeout = Timeout} = MP0) ->
    case read_part(MP0, Timeout) of
        {done, MP1} -> {ok, MP1};
        {ok, _Discarded, MP1} -> skip_body(MP1);
        {error, _, _} = Error -> Error
    end.

%%====================================================================
%% Buffered convenience
%%====================================================================

-doc """
Collect every part fully into memory under the configured limits.

Returns `{error, not_multipart}` or `{error, no_boundary}` when the request
does not carry a usable `multipart/form-data` Content-Type, and
`{error, Reason}` with a parsing `reason()` otherwise.
""".
-spec read_all(livery_req:req()) ->
    {ok, [part_full()]} | {error, reason() | not_multipart | no_boundary}.
read_all(Req) ->
    read_all(Req, #{}).

-doc """
`read_all/1` with parser options.

Each part body is limited to `max_part_size` bytes. Errors from building the
parser (`not_multipart`, `no_boundary`) are returned unchanged.
""".
-spec read_all(livery_req:req(), opts()) ->
    {ok, [part_full()]} | {error, reason() | not_multipart | no_boundary}.
read_all(Req, Opts) ->
    case new(Req, Opts) of
        {ok, MP} ->
            collect(MP, []);
        {error, _} = Error ->
            %% construction errors are not part of reason(); the spec lists
            %% them explicitly so callers can match on them
            Error
    end.

%% Walk all parts, gathering each one with its complete body.
-spec collect(mp(), [part_full()]) -> {ok, [part_full()]} | {error, reason()}.
collect(MP0, Acc) ->
    case next_part(MP0) of
        {part, Part, MP1} -> collect_part(Part, MP1, Acc);
        {done, _} -> {ok, lists:reverse(Acc)};
        {error, Reason, _} -> {error, Reason}
    end.

%% Read the body of `Part` under `max_part_size`, then continue collecting.
-spec collect_part(part(), mp(), [part_full()]) ->
    {ok, [part_full()]} | {error, reason()}.
collect_part(Part, #mp{opts = Opts} = MP, Acc) ->
    Limit = maps:get(max_part_size, Opts),
    case collect_body(MP, [], 0, Limit) of
        {ok, Body, MP1} -> collect(MP1, [Part#{body => Body} | Acc]);
        {error, Reason, _} -> {error, Reason}
    end.

%% Accumulate the chunks of the current part body, failing as soon as the
%% running size exceeds `Max`.
-spec collect_body(mp(), [binary()], non_neg_integer(), pos_integer()) ->
    {ok, binary(), mp()} | {error, reason(), mp()}.
collect_body(#mp{timeout = Timeout} = MP0, Acc, Size, Max) ->
    case read_part(MP0, Timeout) of
        {ok, Chunk, MP1} when Size + byte_size(Chunk) > Max ->
            {error, {limit, max_part_size}, MP1};
        {ok, Chunk, MP1} ->
            collect_body(MP1, [Chunk | Acc], Size + byte_size(Chunk), Max);
        {done, MP1} ->
            %% chunks were pushed newest-first
            {ok, iolist_to_binary(lists:reverse(Acc)), MP1};
        {error, _, _} = Error ->
            Error
    end.

%%====================================================================
%% Scanning primitives
%%====================================================================

%% Skip the preamble, locate the first `--boundary`, position after it.
%%
%% Returns `{ok, MP}` with the buffer starting right after the delimiter
%% token, `{done, MP}` when the source ended without yielding a single byte,
%% or `{error, Reason, MP}`. Input that ends without any delimiter after
%% some bytes were consumed is `malformed`.
-spec start_scan(mp()) -> {ok, mp()} | {done, mp()} | {error, reason(), mp()}.
start_scan(#mp{buffer = Buf, dash = Dash} = MP) ->
    case binary:match(Buf, Dash) of
        {Pos, Len} ->
            %% Drop the preamble and the delimiter token itself.
            Rest = binary:part(Buf, Pos + Len, byte_size(Buf) - Pos - Len),
            {ok, MP#mp{buffer = Rest, state = after_boundary}};
        nomatch ->
            %% Only the last `byte_size(Dash) - 1` bytes can still be the
            %% start of a delimiter split across reads; the rest of the
            %% preamble is discarded before pulling more input.
            HoldBack = byte_size(Dash) - 1,
            Keep = tail(Buf, HoldBack),
            case pull(MP#mp{buffer = Keep}) of
                {ok, MP1} ->
                    start_scan(MP1);
                {eof, MP1} ->
                    %% An entirely empty body is reported as "no parts";
                    %% anything else without a delimiter is malformed.
                    case MP1#mp.consumed of
                        0 -> {done, MP1#mp{state = closed}};
                        _ -> {error, malformed, MP1}
                    end;
                {error, R, MP1} ->
                    {error, R, MP1}
            end
    end.

%% Take one CRLF-terminated line off the buffer and return it without the
%% CRLF, pulling more input while no terminator is buffered. Gives up with
%% `max_header_bytes` once more than `MaxBytes` are buffered without one.
-spec read_line(mp(), integer()) -> {line, binary(), mp()} | {error, reason(), mp()}.
read_line(#mp{buffer = Buf} = MP, MaxBytes) ->
    case binary:split(Buf, ?CRLF) of
        [Line, Rest] ->
            {line, Line, MP#mp{buffer = Rest}};
        [_Unterminated] when byte_size(Buf) > MaxBytes ->
            {error, {limit, max_header_bytes}, MP};
        [_Unterminated] ->
            case pull(MP) of
                {ok, MP1} -> read_line(MP1, MaxBytes);
                {eof, MP1} -> {error, malformed, MP1};
                {error, _, _} = Error -> Error
            end
    end.

%% Pull from the source until at least `N` bytes are buffered; stops early
%% on end of input or on a read error.
-spec ensure(mp(), non_neg_integer()) -> {ok, mp()} | {eof, mp()} | {error, reason(), mp()}.
ensure(#mp{buffer = Buf} = MP, N) ->
    case byte_size(Buf) >= N of
        true ->
            {ok, MP};
        false ->
            case pull(MP) of
                {ok, MP1} -> ensure(MP1, N);
                {eof, _} = Eof -> Eof;
                {error, _, _} = Error -> Error
            end
    end.

%% Fetch one chunk from the body reader and append it to the buffer.
%% Without a reader there is nothing to fetch, which is reported as eof.
-spec pull(mp()) -> {ok, mp()} | {eof, mp()} | {error, reason(), mp()}.
pull(#mp{reader = undefined} = MP) ->
    {eof, MP};
pull(#mp{reader = Reader, timeout = Timeout} = MP) ->
    case livery_body:read(Reader, Timeout) of
        {ok, Chunk, Reader1} ->
            append_chunk(MP#mp{reader = Reader1}, iolist_to_binary(Chunk));
        {done, Reader1} ->
            {eof, MP#mp{reader = Reader1}};
        {error, Error, Reader1} ->
            {error, normalize_error(Error), MP#mp{reader = Reader1}}
    end.

%% Append freshly read bytes, account for them, and enforce `max_body`.
-spec append_chunk(mp(), binary()) -> {ok, mp()} | {error, reason(), mp()}.
append_chunk(#mp{buffer = Buf, consumed = Consumed} = MP, Bin) ->
    MP1 = MP#mp{
        buffer = <<Buf/binary, Bin/binary>>,
        consumed = Consumed + byte_size(Bin)
    },
    case max_body_ok(MP1) of
        true -> {ok, MP1};
        false -> {error, {limit, max_body}, MP1}
    end.

%% True while the number of bytes consumed does not exceed `max_body`.
-spec max_body_ok(mp()) -> boolean().
max_body_ok(#mp{consumed = Consumed, opts = Opts}) ->
    Limit = maps:get(max_body, Opts),
    not (Consumed > Limit).

%% Map a body-reader error onto the parser's `reason()`. Only the two
%% errors named in the spec are accepted.
-spec normalize_error(timeout | {client_reset, term()}) -> reason().
normalize_error(Error) ->
    case Error of
        timeout -> timeout;
        {client_reset, _Detail} -> Error
    end.

%% Return the last `N` bytes of `Bin`, or all of `Bin` when it is not
%% longer than `N`.
-spec tail(binary(), non_neg_integer()) -> binary().
tail(Bin, N) ->
    Size = byte_size(Bin),
    case Size =< N of
        true -> Bin;
        false -> binary:part(Bin, Size - N, N)
    end.

%%====================================================================
%% Header / part parsing
%%====================================================================

%% Parse one `Name: Value` header line into `{LowercasedName, Value}`.
%%
%% Returns `error` when the line contains a control character (other than
%% tab), has no colon, or has a name that is empty once surrounding
%% whitespace is removed. Name and value are both trimmed.
-spec parse_header(binary()) -> {ok, {binary(), binary()}} | error.
parse_header(Line) ->
    case has_control(Line) of
        true -> error;
        false -> split_header(Line)
    end.

%% Split at the first colon and validate the resulting name.
-spec split_header(binary()) -> {ok, {binary(), binary()}} | error.
split_header(Line) ->
    case binary:split(Line, <<":">>) of
        [RawName, RawValue] ->
            %% check emptiness after trimming, so a whitespace-only name
            %% such as "  : v" is rejected instead of yielding <<>>
            case trim(RawName) of
                <<>> -> error;
                Name -> {ok, {downcase(Name), trim(RawValue)}}
            end;
        _ ->
            error
    end.

%% Derive the part metadata from its parsed headers. Expects header names
%% to be lowercased already; the first matching header wins.
-spec build_part([{binary(), binary()}]) -> part().
build_part(Headers) ->
    {Name, Filename} = disposition_of(header_value(<<"content-disposition">>, Headers)),
    #{
        name => Name,
        filename => Filename,
        content_type => header_value(<<"content-type">>, Headers),
        headers => Headers
    }.

%% Value of the first header called `Key`, or `undefined`.
-spec header_value(binary(), [{binary(), binary()}]) -> binary() | undefined.
header_value(Key, Headers) ->
    case lists:keyfind(Key, 1, Headers) of
        {_, Value} -> Value;
        false -> undefined
    end.

%% `{Name, Filename}` of a Content-Disposition value; both are `undefined`
%% when the header is absent.
-spec disposition_of(binary() | undefined) ->
    {binary() | undefined, binary() | undefined}.
disposition_of(undefined) -> {undefined, undefined};
disposition_of(Value) -> disposition(Value).

%% Extract the `name` and `filename` parameters of a Content-Disposition
%% value; each is `undefined` when missing.
-spec disposition(binary()) -> {binary() | undefined, binary() | undefined}.
disposition(Value) ->
    Params = params(Value),
    Name = param(<<"name">>, Params),
    Filename = param(<<"filename">>, Params),
    {Name, Filename}.

%% Look up the first parameter called `Key`; `undefined` when absent.
-spec param(binary(), [{binary(), binary()}]) -> binary() | undefined.
param(Key, Params) ->
    value_or_undefined(lists:keyfind(Key, 1, Params)).

%% Unwrap the result of `lists:keyfind/3`.
-spec value_or_undefined({binary(), binary()} | false) -> binary() | undefined.
value_or_undefined({_Key, Value}) -> Value;
value_or_undefined(false) -> undefined.

%% Parse the `;`-separated parameters that follow the leading token,
%% unquoting double-quoted values. Keys are lowercased.
%%
%% The split is quote-aware: a `;` inside a double-quoted value (for
%% example `filename="a;b.txt"`) belongs to that value and does not start
%% a new parameter. Inside quotes a backslash keeps the following byte
%% attached, so `\"` does not end the quoted value. Segments that are not
%% of the form `key=value` are ignored.
-spec params(binary()) -> [{binary(), binary()}].
params(Value) ->
    %% the first segment is the type/disposition token, not a parameter
    [_Type | Rest] = split_params(Value, <<>>, false, []),
    lists:filtermap(fun param_pair/1, Rest).

%% Split `Bin` at every `;` that is outside double quotes. `Cur` is the
%% segment being built, `InQuotes` tracks the quoting state. Always returns
%% at least one segment, like a global `binary:split/3` would.
-spec split_params(binary(), binary(), boolean(), [binary()]) -> [binary(), ...].
split_params(<<>>, Cur, _InQuotes, Acc) ->
    lists:reverse([Cur | Acc]);
split_params(<<$\\, C, Rest/binary>>, Cur, true, Acc) ->
    %% quoted-pair: keep both bytes, the quoting state is unchanged
    split_params(Rest, <<Cur/binary, $\\, C>>, true, Acc);
split_params(<<$", Rest/binary>>, Cur, InQuotes, Acc) ->
    split_params(Rest, <<Cur/binary, $">>, not InQuotes, Acc);
split_params(<<$;, Rest/binary>>, Cur, false, Acc) ->
    split_params(Rest, <<>>, false, [Cur | Acc]);
split_params(<<C, Rest/binary>>, Cur, InQuotes, Acc) ->
    split_params(Rest, <<Cur/binary, C>>, InQuotes, Acc).

%% Turn one `key=value` segment into a `lists:filtermap/2` result: the key
%% is trimmed and lowercased, the value trimmed and unquoted. Segments
%% without `=` or with an empty key are dropped.
-spec param_pair(binary()) -> {true, {binary(), binary()}} | false.
param_pair(Part) ->
    case binary:split(trim(Part), <<"=">>) of
        [<<>>, _Value] ->
            false;
        [Key, Value] ->
            {true, {downcase(trim(Key)), unquote(trim(Value))}};
        _NoEquals ->
            false
    end.

%% Remove one pair of surrounding double quotes. A value that opens with a
%% quote but does not close with one is returned unchanged; a lone `"`
%% becomes the empty binary.
-spec unquote(binary()) -> binary().
unquote(<<$">>) ->
    <<>>;
unquote(<<$", Rest/binary>> = Quoted) ->
    InnerSize = byte_size(Rest) - 1,
    case Rest of
        <<Inner:InnerSize/binary, $">> -> Inner;
        _ -> Quoted
    end;
unquote(Unquoted) ->
    Unquoted.

%%====================================================================
%% Content-Type / boundary
%%====================================================================

%% Check that a Content-Type value is `multipart/form-data` (compared
%% case-insensitively) and extract its boundary parameter.
-spec multipart_boundary(binary()) ->
    {ok, binary()} | {error, not_multipart | no_boundary}.
multipart_boundary(Value) ->
    case media_type(Value) of
        <<"multipart/form-data">> -> boundary_param(Value);
        _ -> {error, not_multipart}
    end.

%% Lowercased media type: the text before the first `;`.
-spec media_type(binary()) -> binary() | undefined.
media_type(Value) ->
    case binary:split(Value, <<";">>) of
        [Type | _] -> downcase(trim(Type));
        [] -> undefined
    end.

%% Extract a non-empty `boundary` parameter from a Content-Type value.
-spec boundary_param(binary()) -> {ok, binary()} | {error, no_boundary}.
boundary_param(Value) ->
    case param(<<"boundary">>, params(Value)) of
        Missing when Missing =:= undefined; Missing =:= <<>> ->
            {error, no_boundary};
        Boundary ->
            {ok, Boundary}
    end.

%%====================================================================
%% Helpers
%%====================================================================

%% Option values used for every key the caller does not override.
-spec default_opts() -> opts().
default_opts() ->
    KiB = 1024,
    MiB = 1024 * KiB,
    #{
        part_timeout => 5000,
        max_parts => 1000,
        max_header_bytes => 64 * KiB,
        max_header_count => 64,
        max_part_size => 10 * MiB,
        max_body => 100 * MiB
    }.

%% True when the binary contains a byte below 32 other than horizontal tab.
-spec has_control(binary()) -> boolean().
has_control(<<C, Rest/binary>>) ->
    IsControl = C < 32 andalso C =/= $\t,
    IsControl orelse has_control(Rest);
has_control(<<>>) ->
    false.

%% Strip leading and trailing ASCII whitespace (space, \t, \n, \v, \f, \r).
%%
%% Works byte-wise on purpose: the input comes from untrusted request
%% headers and need not be valid UTF-8 (for example a latin-1 filename),
%% which the Unicode-aware `string` functions are not guaranteed to accept.
-spec trim(binary()) -> binary().
trim(Bin) ->
    trim_trailing(trim_leading(Bin)).

%% Drop whitespace bytes from the front.
-spec trim_leading(binary()) -> binary().
trim_leading(<<C, Rest/binary>>) when C =:= $\s; C >= $\t, C =< $\r ->
    trim_leading(Rest);
trim_leading(Bin) ->
    Bin.

%% Drop whitespace bytes from the end.
-spec trim_trailing(binary()) -> binary().
trim_trailing(<<>>) ->
    <<>>;
trim_trailing(Bin) ->
    InitSize = byte_size(Bin) - 1,
    case Bin of
        <<Init:InitSize/binary, C>> when C =:= $\s; C >= $\t, C =< $\r ->
            trim_trailing(Init);
        _ ->
            Bin
    end.

%% Lowercase the ASCII letters A-Z; every other byte is copied unchanged.
%%
%% Header names, parameter keys and media types are ASCII tokens, so ASCII
%% case folding is all that is needed. Working byte-wise also keeps the
%% function total for untrusted input that is not valid UTF-8.
-spec downcase(binary()) -> binary().
downcase(Bin) ->
    <<<<(ascii_lower(C))>> || <<C>> <= Bin>>.

%% Lowercase a single byte if it is an ASCII uppercase letter.
-spec ascii_lower(byte()) -> byte().
ascii_lower(C) when C >= $A, C =< $Z -> C + 32;
ascii_lower(C) -> C.