%% src/vpn_crypto.erl

%%%-------------------------------------------------------------------
%% @doc PSK authenticated encryption for VPN frames.
%%%-------------------------------------------------------------------
-module(vpn_crypto).

-export([new/2, encode/2, decode/2]).

-define(KEY_SIZE, 32).
-define(NONCE_SIZE, 12).
-define(TAG_SIZE, 16).

%% @doc Creates the codec state for one peer.
%%
%% `Psk' must be a binary of exactly ?KEY_SIZE bytes; it is used directly as
%% the ChaCha20-Poly1305 key by encode/2 and decode/2. `PeerId' is normalised
%% to a binary by peer_id_to_binary/1, which raises `{invalid_peer_id, _}' for
%% anything that is not an atom or a binary.
%%
%% Raises `{invalid_psk, Detail}' for any other `Psk'. `Detail' describes only
%% the shape of the rejected value (`{byte_size, N}' or `not_binary'), never
%% the value itself, so key material does not leak into crash reports or logs.
-spec new(Psk :: binary(), PeerId :: atom() | binary()) ->
    #{psk := binary(), peer_id := binary()}.
new(Psk, PeerId) when is_binary(Psk), byte_size(Psk) =:= ?KEY_SIZE ->
    #{psk => Psk, peer_id => peer_id_to_binary(PeerId)};
new(Psk, _PeerId) when is_binary(Psk) ->
    %% A wrong-length binary is still probably a real (mis-sized) secret.
    erlang:error({invalid_psk, {byte_size, byte_size(Psk)}});
new(_Psk, _PeerId) ->
    erlang:error({invalid_psk, not_binary}).

%% @doc Encrypts `Frame' and returns `{ok, Packet, State1}'.
%%
%% Packet layout: Nonce (?NONCE_SIZE bytes) ++ Ciphertext ++ Tag (?TAG_SIZE
%% bytes), with empty associated data.
%%
%% The nonce is built by nonce/2 from the state's peer id and the sequence
%% number read from the frame header by frame_seq/1. It is therefore unique
%% only while every frame sealed under this PSK and peer id carries a fresh
%% sequence number. Reusing a (key, nonce) pair with ChaCha20-Poly1305 breaks
%% both confidentiality and authenticity. For that reason the last sequence
%% number sent is recorded in the returned state under `tx_seq', and any frame
%% whose sequence number is not strictly greater raises `{nonce_reuse, Seq}'.
%% NOTE: this only guards a single state lineage; a state freshly created by
%% new/2 (e.g. after a restart) starts with no record of earlier sequences.
%%
%% Raises `invalid_frame' (from frame_seq/1) for frames without the expected
%% header.
-spec encode(Frame :: binary(), State :: map()) -> {ok, binary(), map()}.
encode(Frame, State = #{psk := Psk, peer_id := PeerId}) ->
    Seq = frame_seq(Frame),
    ok = ensure_fresh_tx_seq(Seq, maps:get(tx_seq, State, none)),
    Nonce = nonce(PeerId, Seq),
    {Ciphertext, Tag} =
        crypto:crypto_one_time_aead(chacha20_poly1305,
                                    Psk,
                                    Nonce,
                                    Frame,
                                    <<>>,
                                    true),
    {ok, <<Nonce/binary, Ciphertext/binary, Tag/binary>>, State#{tx_seq => Seq}}.

%% Returns `ok' only if `Seq' is strictly above the last sequence number this
%% state encrypted with (or none has been used yet); otherwise raises.
ensure_fresh_tx_seq(_Seq, none) ->
    ok;
ensure_fresh_tx_seq(Seq, LastSeq) when Seq > LastSeq ->
    ok;
ensure_fresh_tx_seq(Seq, _LastSeq) ->
    erlang:error({nonce_reuse, Seq}).

%% Number of sequence numbers below the highest accepted one that are still
%% tracked for replay detection, per nonce prefix.
-define(REPLAY_WINDOW, 64).

%% @doc Authenticates and decrypts a packet produced by encode/2.
%%
%% Returns `{ok, Frame, State1}' on success. Otherwise returns
%% `{error, Reason, State}' with the state unchanged, where Reason is one of:
%%   truncated_encrypted_packet - shorter than nonce + tag (or not a binary);
%%   replayed_packet - the nonce's sequence number was already accepted, or
%%     is too far below the highest accepted one to be checked;
%%   authentication_failed - the tag does not verify under the PSK.
%%
%% The nonce is split as 4 prefix bytes plus a 64-bit sequence number, which
%% is the layout nonce/2 builds. Accepted sequence numbers are tracked per
%% prefix in a ?REPLAY_WINDOW-wide sliding window stored under `rx_windows'.
%% Moderate reordering is therefore tolerated, but duplicates are rejected.
%% The window is advanced only after the tag verifies, so packets forged
%% without the PSK cannot move it forward.
-spec decode(Packet :: binary(), State :: map()) ->
    {ok, binary(), map()} | {error, atom(), map()}.
decode(Packet, State = #{psk := Psk}) when byte_size(Packet) >= ?NONCE_SIZE + ?TAG_SIZE ->
    CipherSize = byte_size(Packet) - ?NONCE_SIZE - ?TAG_SIZE,
    <<Nonce:?NONCE_SIZE/binary, Ciphertext:CipherSize/binary, Tag:?TAG_SIZE/binary>> = Packet,
    <<Prefix:4/binary, Seq:64/unsigned>> = Nonce,
    Windows = maps:get(rx_windows, State, #{}),
    Window = maps:get(Prefix, Windows, undefined),
    case is_fresh_rx_seq(Seq, Window) of
        false ->
            {error, replayed_packet, State};
        true ->
            case crypto:crypto_one_time_aead(chacha20_poly1305,
                                             Psk,
                                             Nonce,
                                             Ciphertext,
                                             <<>>,
                                             Tag,
                                             false) of
                error ->
                    {error, authentication_failed, State};
                Plaintext ->
                    NewWindow = mark_rx_seq(Seq, Window),
                    {ok, Plaintext, State#{rx_windows => Windows#{Prefix => NewWindow}}}
            end
    end;
decode(_Packet, State) ->
    {error, truncated_encrypted_packet, State}.

%% True when `Seq' has not been accepted yet and still lies inside the window.
%% A window is `{Highest, Seen}' where bit N of `Seen' marks `Highest - N'.
is_fresh_rx_seq(_Seq, undefined) ->
    true;
is_fresh_rx_seq(Seq, {Highest, _Seen}) when Seq > Highest ->
    true;
is_fresh_rx_seq(Seq, {Highest, Seen}) when Highest - Seq < ?REPLAY_WINDOW ->
    (Seen band (1 bsl (Highest - Seq))) =:= 0;
is_fresh_rx_seq(_Seq, _Window) ->
    false.

%% Records `Seq' as accepted, sliding the window forward when it is new.
mark_rx_seq(Seq, undefined) ->
    {Seq, 1};
mark_rx_seq(Seq, {Highest, Seen}) when Seq > Highest ->
    %% Cap the shift so a large jump cannot build a huge bignum.
    Shift = min(Seq - Highest, ?REPLAY_WINDOW),
    {Seq, ((Seen bsl Shift) bor 1) band ((1 bsl ?REPLAY_WINDOW) - 1)};
mark_rx_seq(Seq, {Highest, Seen}) ->
    {Highest, Seen bor (1 bsl (Highest - Seq))}.

%% Reads the sequence number from a frame header. The frame must start with
%% the two bytes 1 and 1, followed by a 64-bit big-endian unsigned integer.
%% Any other term, binary or not, raises `invalid_frame'.
frame_seq(Frame) ->
    case Frame of
        <<1, 1, Number:64/unsigned-integer, _Payload/binary>> ->
            Number;
        _NotAFrame ->
            erlang:error(invalid_frame)
    end.

%% Builds the 12-byte AEAD nonce. It consists of the first 4 bytes of
%% SHA-256(PeerId) followed by `Seq' as a 64-bit big-endian unsigned integer.
%% Peer ids whose hashes share those 4 bytes share a prefix, in which case
%% nonce uniqueness rests entirely on `Seq'. A negative or non-integer `Seq'
%% fails the guard (function_clause).
nonce(PeerId, Seq) when is_integer(Seq), Seq >= 0 ->
    Digest = crypto:hash(sha256, PeerId),
    PeerPrefix = binary:part(Digest, 0, 4),
    <<PeerPrefix/binary, Seq:64/unsigned>>.

%% Normalises a peer identifier to a binary. Binaries are returned as-is and
%% atoms are converted to their UTF-8 text. Any other term raises
%% `{invalid_peer_id, PeerId}'.
peer_id_to_binary(PeerId) ->
    if
        is_binary(PeerId) ->
            PeerId;
        is_atom(PeerId) ->
            atom_to_binary(PeerId, utf8);
        true ->
            erlang:error({invalid_peer_id, PeerId})
    end.