src/hund.erl

-module(hund).

-include_lib("public_key/include/public_key.hrl").
-include_lib("../include/hund.hrl").

-type contact() :: #saml_contact{}.
-type org() :: #saml_org{}.
-type idp_metadata() :: #saml_idp_metadata{}.
-type sp_metadata() :: #saml_sp_metadata{}.
-type subject() :: #saml_subject{}.
-type assertion() :: #saml_assertion{}.
-type authnreq() :: #saml_authnreq{}.
-type authn() :: #saml_authn{}.
-type logout_request() :: #saml_logout_request{}.
-type logout_response() :: #saml_logout_response{}.
-type saml_record() :: contact()
                     | org()
                     | idp_metadata()
                     | sp_metadata()
                     | subject()
                     | assertion()
                     | authn()
                     | authnreq()
                     | logout_request()
                     | logout_response().

-export_type(
  [
    contact/0,
    org/0,
    idp_metadata/0,
    sp_metadata/0,
    saml_record/0,
    authnreq/0,
    assertion/0,
    logout_request/0,
    logout_response/0
  ]
).

-type localized_string() :: string() | [{Locale :: atom(), LocalString :: string()}].
-type name_format() :: email | x509 | windows | krb | persistent | transient | unknown.
-type status_code() :: success
                     | request_error
                     | response_error
                     | bad_version
                     | authn_failed
                     | bad_attr
                     | denied
                     | bad_binding
                     | unknown.
-type version() :: string().
-type datetime() :: string() | binary().
-type condition() :: #saml_condition{}.
-type subject_method() :: bearer | holder_of_key | sender_vouches.
-type authn_class() :: password
                     | password_protected_transport
                     | internet_protocol
                     | internet_protocol_password
                     | mobile_one_factor_contract
                     | mobile_two_factor_contract
                     | previous_session
                     | unspecified.

-export_type(
  [
    localized_string/0,
    name_format/0,
    status_code/0,
    version/0,
    datetime/0,
    condition/0,
    subject_method/0,
    authn_class/0
  ]
).

-export([nameid_map/1, rev_nameid_map/1, status_code_map/1, datetime_to_saml/1]).
-export([threaduntil/2]).
-export([map_if/2, map_if/1]).
-export([rev_subject_method_map/1, rev_map_authn_class/1, rev_status_code_map/1, map_authn_class/1]).
-export([saml_to_datetime/1, date_to_saml/1]).
-export([start_ets/0, check_dupe_ets/2]).
-export(
  [
    convert_fingerprints/1,
    load_private_key/1,
    import_private_key/2,
    load_certificate/1,
    import_certificate/2,
    load_metadata/1,
    load_metadata/2,
    unique_id/0
  ]
).

%% @doc Threads an accumulator through a list of funs, stopping early.
%%
%% Each fun is applied to the current accumulator inside an old-style
%% `catch`, and its result decides what happens next:
%%   {error, Reason}  - stop and return {error, Reason};
%%   {stop, Acc}      - stop and return {ok, Acc};
%%   {'EXIT', R}      - (what `catch` yields for a raised error/exit)
%%                      stop and return {error, R};
%%   anything else    - becomes the accumulator for the next fun.
%% When every fun has run, returns {ok, FinalAcc}.
%% NOTE(review): because `catch` is used, a value raised with throw/1 is
%% treated as an ordinary result, not as a failure.
-spec threaduntil(
  [fun((Acc :: term()) -> {error, term()} | {stop, term()} | term())],
  InitAcc :: term()
) ->
  {error, term()} | {ok, term()}.
threaduntil([], Acc) ->
  {ok, Acc};
threaduntil([Step | Remaining], Acc) ->
  Outcome = (catch Step(Acc)),
  case Outcome of
    {'EXIT', Why} -> {error, Why};
    {error, _Why} = Failure -> Failure;
    {stop, Final} -> {ok, Final};
    Next -> threaduntil(Remaining, Next)
  end.


%% @doc Maps a SAML NameID format URN (a string) to a short atom.
%% Any other string, including the "unspecified" format, maps to `unknown`;
%% non-string input raises function_clause.
-spec nameid_map(string()) -> name_format().
nameid_map(Format) when is_list(Format) ->
  case Format of
    "urn:oasis:names:tc:SAML:1.1:nameid-format:emailAddress" -> email;
    "urn:oasis:names:tc:SAML:1.1:nameid-format:X509SubjectName" -> x509;
    "urn:oasis:names:tc:SAML:1.1:nameid-format:WindowsDomainQualifiedName" -> windows;
    "urn:oasis:names:tc:SAML:2.0:nameid-format:kerberos" -> krb;
    "urn:oasis:names:tc:SAML:2.0:nameid-format:persistent" -> persistent;
    "urn:oasis:names:tc:SAML:2.0:nameid-format:transient" -> transient;
    _Other -> unknown
  end.

%% @doc Maps a NameID format atom (see name_format()) back to its SAML URN.
%% Any other value, including `unknown`, yields the SAML 1.1 "unspecified"
%% format URN. Must stay the inverse of nameid_map/1 for known formats.
-spec rev_nameid_map(atom()) -> string().
rev_nameid_map(email) -> "urn:oasis:names:tc:SAML:1.1:nameid-format:emailAddress";
rev_nameid_map(x509) -> "urn:oasis:names:tc:SAML:1.1:nameid-format:X509SubjectName";
rev_nameid_map(windows) -> "urn:oasis:names:tc:SAML:1.1:nameid-format:WindowsDomainQualifiedName";
rev_nameid_map(krb) -> "urn:oasis:names:tc:SAML:2.0:nameid-format:kerberos";
%% persistent previously returned the kerberos URN by mistake.
rev_nameid_map(persistent) -> "urn:oasis:names:tc:SAML:2.0:nameid-format:persistent";
rev_nameid_map(transient) -> "urn:oasis:names:tc:SAML:2.0:nameid-format:transient";
rev_nameid_map(_) -> "urn:oasis:names:tc:SAML:1.1:nameid-format:unspecified".

%% @doc Maps a SAML 2.0 status-code URN to an atom.
%%
%% Well-known codes map to the short atoms of status_code(). Any other
%% "urn:..." value maps to the atom named by its last ':'-separated segment
%% (e.g. 'NoPassive'), but only if that atom already exists; otherwise, and
%% for non-URN strings, the result is `unknown`.
%%
%% Atoms are never garbage-collected, and status URNs presumably arrive in
%% messages from a remote party (confirm callers), so creating atoms from
%% them with list_to_atom/1 would let a peer exhaust the atom table. Any
%% atom that loaded code actually matches on already exists, so callers
%% comparing against e.g. 'Requester' keep working.
-spec status_code_map(string()) -> status_code() | atom().
status_code_map("urn:oasis:names:tc:SAML:2.0:status:Success") -> success;
status_code_map("urn:oasis:names:tc:SAML:2.0:status:VersionMismatch") -> bad_version;
status_code_map("urn:oasis:names:tc:SAML:2.0:status:AuthnFailed") -> authn_failed;
status_code_map("urn:oasis:names:tc:SAML:2.0:status:InvalidAttrNameOrValue") -> bad_attr;
status_code_map("urn:oasis:names:tc:SAML:2.0:status:RequestDenied") -> denied;
status_code_map("urn:oasis:names:tc:SAML:2.0:status:UnsupportedBinding") -> bad_binding;
status_code_map(Urn = "urn:" ++ _) ->
  try
    list_to_existing_atom(lists:last(string:tokens(Urn, ":")))
  catch
    error:badarg -> unknown
  end;
status_code_map(S) when is_list(S) -> unknown.

%% @doc Maps a status atom back to its SAML 2.0 status-code URN.
%%
%% request_error and response_error (members of status_code()) map to the
%% SAML top-level codes "Requester" and "Responder"; without explicit
%% clauses they would fall through to the generic clause and produce the
%% non-existent ".../RequestError" and ".../ResponseError" URNs.
%% Any other atom is converted snake_case -> PascalCase and appended to the
%% status URN prefix, so atoms such as 'NoPassive' produced by
%% status_code_map/1 round-trip unchanged.
-spec rev_status_code_map(status_code() | atom()) -> string().
rev_status_code_map(success) -> "urn:oasis:names:tc:SAML:2.0:status:Success";
rev_status_code_map(request_error) -> "urn:oasis:names:tc:SAML:2.0:status:Requester";
rev_status_code_map(response_error) -> "urn:oasis:names:tc:SAML:2.0:status:Responder";
rev_status_code_map(bad_version) -> "urn:oasis:names:tc:SAML:2.0:status:VersionMismatch";
rev_status_code_map(authn_failed) -> "urn:oasis:names:tc:SAML:2.0:status:AuthnFailed";
rev_status_code_map(bad_attr) -> "urn:oasis:names:tc:SAML:2.0:status:InvalidAttrNameOrValue";
rev_status_code_map(denied) -> "urn:oasis:names:tc:SAML:2.0:status:RequestDenied";
rev_status_code_map(bad_binding) -> "urn:oasis:names:tc:SAML:2.0:status:UnsupportedBinding";

rev_status_code_map(Status) when is_atom(Status) ->
  "urn:oasis:names:tc:SAML:2.0:status:" ++ pascal_case(atom_to_list(Status)).

%% @doc Returns the SAML 2.0 subject-confirmation method URN for a
%% subject_method() atom. Any other value raises function_clause.
-spec rev_subject_method_map(subject_method()) -> string().
rev_subject_method_map(bearer) ->
  "urn:oasis:names:tc:SAML:2.0:cm:bearer";
rev_subject_method_map(holder_of_key) ->
  "urn:oasis:names:tc:SAML:2.0:cm:holder-of-key";
rev_subject_method_map(sender_vouches) ->
  "urn:oasis:names:tc:SAML:2.0:cm:sender-vouches".

%% @doc Returns the SAML 2.0 authentication-context class URN for an
%% authn_class() atom; any other term yields the "unspecified" class URN.
-spec rev_map_authn_class(Context :: atom()) -> string().
rev_map_authn_class(Context) ->
  case Context of
    password ->
      "urn:oasis:names:tc:SAML:2.0:ac:classes:Password";
    password_protected_transport ->
      "urn:oasis:names:tc:SAML:2.0:ac:classes:PasswordProtectedTransport";
    internet_protocol ->
      "urn:oasis:names:tc:SAML:2.0:ac:classes:InternetProtocol";
    internet_protocol_password ->
      "urn:oasis:names:tc:SAML:2.0:ac:classes:InternetProtocolPassword";
    mobile_one_factor_contract ->
      "urn:oasis:names:tc:SAML:2.0:ac:classes:MobileOneFactorContract";
    mobile_two_factor_contract ->
      "urn:oasis:names:tc:SAML:2.0:ac:classes:MobileTwoFactorContract";
    previous_session ->
      "urn:oasis:names:tc:SAML:2.0:ac:classes:PreviousSession";
    _Unrecognised ->
      "urn:oasis:names:tc:SAML:2.0:ac:classes:unspecified"
  end.

%% @doc Maps a SAML 2.0 authentication-context class URN to an
%% authn_class() atom; any unrecognised term yields `unspecified`.
-spec map_authn_class(AuthnClass :: string()) -> authn_class().
map_authn_class(AuthnClass) ->
  case AuthnClass of
    "urn:oasis:names:tc:SAML:2.0:ac:classes:Password" ->
      password;
    "urn:oasis:names:tc:SAML:2.0:ac:classes:PasswordProtectedTransport" ->
      password_protected_transport;
    "urn:oasis:names:tc:SAML:2.0:ac:classes:InternetProtocol" ->
      internet_protocol;
    "urn:oasis:names:tc:SAML:2.0:ac:classes:InternetProtocolPassword" ->
      internet_protocol_password;
    "urn:oasis:names:tc:SAML:2.0:ac:classes:MobileOneFactorContract" ->
      mobile_one_factor_contract;
    "urn:oasis:names:tc:SAML:2.0:ac:classes:MobileTwoFactorContract" ->
      mobile_two_factor_contract;
    "urn:oasis:names:tc:SAML:2.0:ac:classes:PreviousSession" ->
      previous_session;
    _Unrecognised ->
      unspecified
  end.

%% @doc Converts a calendar:datetime() into a SAML time string of the form
%% "YYYY-MM-DDTHH:MM:SSZ" (zero-padded fields, literal "Z" suffix).
%% Non-tuple input yields ""; a tuple of any other shape raises badmatch.

-spec datetime_to_saml(calendar:datetime()) -> datetime().
datetime_to_saml(Time) when not is_tuple(Time) ->
  "";
datetime_to_saml(Time) ->
  {{Year, Month, Day}, {Hour, Minute, Second}} = Time,
  Fields = [Year, Month, Day, Hour, Minute, Second],
  Formatted =
    io_lib:format("~4.10.0B-~2.10.0B-~2.10.0BT~2.10.0B:~2.10.0B:~2.10.0BZ", Fields),
  lists:flatten(Formatted).


%% @doc Formats a calendar:date() as a zero-padded "YYYY-MM-DD" string.
%%
%% The original called list:flatten/1 (a nonexistent module), so every call
%% crashed with undef; lists:flatten/1 is the intended function.
-spec date_to_saml(calendar:date()) -> string().
date_to_saml({Year, Month, Day}) ->
  lists:flatten(io_lib:format("~4.10.0B-~2.10.0B-~2.10.0B", [Year, Month, Day])).

%% @doc Converts a SAML time string into a calendar:datetime()
%%
%% Inverse of datetime_to_saml/1. Accepts a string or binary of the form
%% "YYYY-MM-DDTHH:MM:SS...Z"; whatever sits between the seconds and the
%% trailing "Z" (e.g. fractional seconds) is ignored. Only UTC ("Z")
%% timestamps are accepted because timezone offsets are not converted;
%% malformed input or a missing "Z" raises badmatch, non-numeric fields
%% raise badarg.
%% The spec spells out string() | binary() (this module's datetime() type)
%% instead of referring to esaml:datetime().

-spec saml_to_datetime(string() | binary()) -> calendar:datetime().
saml_to_datetime(Stamp) ->
  StampBin =
    if
      is_list(Stamp) -> list_to_binary(Stamp);
      true -> Stamp
    end,
  <<
    YBin:4/binary,
    "-",
    MoBin:2/binary,
    "-",
    DBin:2/binary,
    "T",
    HBin:2/binary,
    ":",
    MiBin:2/binary,
    ":",
    SBin:2/binary,
    Rest/binary
  >> = StampBin,
  %% Require a trailing "Z" (UTC). Matching the suffix instead of calling
  %% binary:last/1 makes an empty Rest fail with badmatch rather than badarg.
  ZPos = byte_size(Rest) - 1,
  <<_:ZPos/binary, $Z>> = Rest,
  {
    {binary_to_integer(YBin), binary_to_integer(MoBin), binary_to_integer(DBin)},
    {binary_to_integer(HBin), binary_to_integer(MiBin), binary_to_integer(SBin)}
  }.


%% @doc Wraps a value in a list, treating "" and undefined as "no value"
%% (both give []). Since "" is the empty list, map_if([]) is [] as well.
-spec map_if(term()) -> [term()].
map_if(Value) when Value =:= undefined; Value =:= "" ->
  [];
map_if(Value) ->
  [Value].

%% @doc Looks Key up in a proplist and always returns a list: [] when the
%% key is absent or its value is undefined, the value itself when it is a
%% list, and [Value] otherwise. Input that is not a list starting with an
%% {Atom, _} pair yields [].
-spec map_if(atom(), term()) -> [term()].
map_if(Key, [{FirstKey, _} | _] = Props) when is_atom(FirstKey) ->
  Value = proplists:get_value(Key, Props),
  if
    Value =:= undefined -> [];
    is_list(Value) -> Value;
    true -> [Value]
  end;
map_if(_Key, _NotAProplist) ->
  [].


%% @doc Converts a snake_case string to PascalCase,
%% e.g. "no_authn_context" -> "NoAuthnContext".
-spec pascal_case(String :: string()) -> string().
pascal_case(String) -> pascal_case(String, "_").

%% @doc Splits String at every occurrence of Sep, upper-cases the first
%% character of each chunk and concatenates the chunks.
%%
%% string:split/2 only splits at the first separator (turning
%% "no_authn_context" into "NoAuthn_context"); the `all` option splits at
%% every one.
-spec pascal_case(String :: string(), Sep :: string()) -> string().
pascal_case(String, Sep) ->
  Chunks = string:split(String, Sep, all),
  lists:append([string:titlecase(Chunk) || Chunk <- Chunks]).

%% @doc Converts various ascii hex/base64 fingerprint formats to binary
%%
%% FPs is combined with the `trusted_fingerprints` value returned by
%% esaml:config/2 (default []). Each entry may be:
%%   - a binary, passed through unchanged;
%%   - "Type:Base64", where Type is sha, sha1, md5, sha256, sha384 or sha512
%%     (case-insensitive), returned as {HashAtom, DecodedBinary};
%%   - colon-separated hex octets ("AB:CD:..."), returned as a binary.
%% Any other entry raises error("unknown fingerprint format").

-spec convert_fingerprints([string() | binary()]) -> [binary() | {atom(), binary()}].
convert_fingerprints(FPs) ->
  FPSources = FPs ++ esaml:config(trusted_fingerprints, []),
  [convert_fingerprint(Print) || Print <- FPSources].

%% @private Converts a single fingerprint entry; see convert_fingerprints/1.
convert_fingerprint(Print) when is_binary(Print) ->
  Print;
convert_fingerprint(Print) when is_list(Print) ->
  case string:tokens(Print, ":") of
    [Type, Base64] -> {fingerprint_hash(Type), base64:decode(Base64)};
    [_] -> error("unknown fingerprint format");
    HexParts -> list_to_binary([list_to_integer(P, 16) || P <- HexParts])
  end;
convert_fingerprint(_Print) ->
  error("unknown fingerprint format").

%% @private Maps a case-insensitive hash name to its hash atom; unknown
%% names raise error("unknown fingerprint format") instead of case_clause.
fingerprint_hash(Type) ->
  case string:lowercase(Type) of
    "sha" -> sha;
    "sha1" -> sha;
    "md5" -> md5;
    "sha256" -> sha256;
    "sha384" -> sha384;
    "sha512" -> sha512;
    _ -> error("unknown fingerprint format")
  end.

%% @private
%% @doc Ensures the hund ETS tables exist on this node and returns
%% {ok, OwnerPid}.
%%
%% If a process is registered as hund_ets_table_owner it is pinged with
%% {self(), check_ready} and its pid is returned once it answers; otherwise
%% a new owner is created with create_tables/0. The existing owner is
%% monitored while waiting, so if it exits before replying (for example
%% after receiving `stop`) we start over instead of blocking forever in
%% receive.

start_ets() ->
  case erlang:whereis(hund_ets_table_owner) of
    undefined -> create_tables();

    Pid ->
      Ref = erlang:monitor(process, Pid),
      Pid ! {self(), check_ready},
      receive
        {Pid, ready} ->
          erlang:demonitor(Ref, [flush]),
          {ok, Pid};

        {'DOWN', Ref, process, Pid, _Reason} ->
          start_ets()
      end
  end.

%% @private
%% Spawns a process linked to the caller that registers itself as
%% hund_ets_table_owner, creates the public named set tables
%% hund_assertion_seen, hund_privkey_cache, hund_certbin_cache and
%% hund_idp_meta_cache, and then runs ets_table_owner/0 to keep them alive.
%% Blocks until the tables exist and returns {ok, OwnerPid}.
%% NOTE(review): because of the link, the tables are deleted if the caller
%% crashes; and if register/2 fails (an owner is already registered) the
%% new owner crashes, taking the caller down with it (or leaving it waiting
%% forever if the caller traps exits).

create_tables() ->
  Parent = self(),
  Tables = [hund_assertion_seen, hund_privkey_cache, hund_certbin_cache, hund_idp_meta_cache],
  Owner =
    spawn_link(
      fun () ->
        register(hund_ets_table_owner, self()),
        lists:foreach(fun (Table) -> ets:new(Table, [set, public, named_table]) end, Tables),
        Parent ! {self(), ready},
        ets_table_owner()
      end
    ),
  receive
    {Owner, ready} -> {ok, Owner}
  end.

%% @private
%% Receive loop of the ETS table owner process; it exists only to keep the
%% tables it created alive. It answers {Caller, check_ready} with
%% {self(), ready}, returns ok on `stop` (so the process exits and its
%% tables are deleted), and silently discards every other message.

ets_table_owner() ->
  KeepRunning =
    receive
      stop ->
        false;
      {From, check_ready} ->
        From ! {self(), ready},
        true;
      _Unexpected ->
        true
    end,
  case KeepRunning of
    true -> ets_table_owner();
    false -> ok
  end.


%% @doc Loads a private key from a PEM file on disk, or returns the copy
%% cached in ETS (hund_privkey_cache, keyed by Path) from an earlier load.
%% Crashes with badmatch if the file cannot be read.
-spec load_private_key(Path :: string()) -> #'RSAPrivateKey'{}.
load_private_key(Path) ->
  Lookup = ets:lookup(hund_privkey_cache, Path),
  case Lookup of
    [{_CachedPath, CachedKey}] ->
      CachedKey;
    _NotCached ->
      {ok, PemBin} = file:read_file(Path),
      do_import_private_key(PemBin, Path)
  end.


%% @doc Decodes a PEM-encoded private key and caches it in ETS
%% (hund_privkey_cache) under Identifier. If Identifier is already cached,
%% the cached key is returned and EncodedKey is not decoded at all.
-spec import_private_key(EncodedKey :: string(), Identifier :: term()) -> #'RSAPrivateKey'{}.
import_private_key(EncodedKey, Identifier) ->
  Lookup = ets:lookup(hund_privkey_cache, Identifier),
  case Lookup of
    [{_CachedId, CachedKey}] -> CachedKey;
    _NotCached -> do_import_private_key(EncodedKey, Identifier)
  end.


%% @private
%% Decodes a PEM private key, caches it in hund_privkey_cache under
%% Identifier and returns it.
%%
%% EncodedKey may be a binary or a string/iolist. public_key:pem_decode/1
%% only accepts binaries, so string input (as advertised by the
%% import_private_key/2 spec) used to crash with function_clause.
%% The PEM data must hold exactly one entry (badmatch otherwise). A
%% 'PrivateKeyInfo' result (presumably what some public_key versions return
%% for PKCS#8 data - confirm) is unwrapped and decoded as an RSAPrivateKey;
%% any other decoded key is returned as-is.
do_import_private_key(EncodedKey, Identifier) ->
  [KeyEntry] = public_key:pem_decode(iolist_to_binary(EncodedKey)),
  Key =
    case public_key:pem_entry_decode(KeyEntry) of
      #'PrivateKeyInfo'{privateKey = KeyData} ->
        %% privateKey may be a byte list or a binary; der_decode needs a binary.
        public_key:der_decode('RSAPrivateKey', iolist_to_binary(KeyData));

      Other -> Other
    end,
  ets:insert(hund_privkey_cache, {Identifier, Key}),
  Key.


%% @doc Loads a single DER certificate from a PEM file on disk (or the ETS
%% cache). Crashes with badmatch unless the file yields exactly one
%% certificate.
-spec load_certificate(Path :: string()) -> binary().
load_certificate(Path) ->
  [OnlyCert] = load_certificate_chain(Path),
  OnlyCert.


%% @doc Imports a single DER certificate from PEM-encoded data (or the ETS
%% cache entry for Identifier). Crashes with badmatch unless the data yields
%% exactly one certificate.
-spec import_certificate(EncodedCert :: string(), Identifier :: term()) -> binary().
import_certificate(EncodedCert, Identifier) ->
  [OnlyCert] = import_certificate_chain(EncodedCert, Identifier),
  OnlyCert.

%% @doc Loads a certificate chain (a list of DER binaries) from a PEM file
%% on disk, or returns the copy cached in ETS (hund_certbin_cache, keyed by
%% Path) from an earlier load. Crashes with badmatch if the file cannot be
%% read.

-spec load_certificate_chain(Path :: string()) -> [binary()].
load_certificate_chain(Path) ->
  Lookup = ets:lookup(hund_certbin_cache, Path),
  case Lookup of
    [{_CachedPath, Chain}] ->
      Chain;
    _NotCached ->
      {ok, PemBin} = file:read_file(Path),
      do_import_certificate_chain(PemBin, Path)
  end.

%% @doc Imports a certificate chain (a list of DER binaries) from PEM-encoded
%% data, or returns the copy cached in ETS (hund_certbin_cache) under
%% Identifier. When Identifier is already cached, EncodedCerts is not
%% decoded at all. Identifier is any term, matching import_certificate/2.

-spec import_certificate_chain(EncodedCerts :: binary() | string(), Identifier :: term()) ->
  [binary()].
import_certificate_chain(EncodedCerts, Identifier) ->
  case ets:lookup(hund_certbin_cache, Identifier) of
    [{_, CertChain}] -> CertChain;
    _ -> do_import_certificate_chain(EncodedCerts, Identifier)
  end.


%% @private
%% Decodes every unencrypted 'Certificate' PEM entry in EncodedCerts, caches
%% the resulting list of DER binaries in hund_certbin_cache under Identifier
%% and returns it. Other and encrypted entries are skipped.
%% EncodedCerts may be a binary or a string/iolist: public_key:pem_decode/1
%% only accepts binaries, so string input used to crash with function_clause.
do_import_certificate_chain(EncodedCerts, Identifier) ->
  PemBin = iolist_to_binary(EncodedCerts),
  CertChain =
    [CertBin || {'Certificate', CertBin, not_encrypted} <- public_key:pem_decode(PemBin)],
  ets:insert(hund_certbin_cache, {Identifier, CertChain}),
  CertChain.

%% @doc Reads IDP metadata from a URL (or ETS memory cache) and validates
%% its XML signature against the given fingerprints.
%%
%% Fingerprints are normalised with convert_fingerprints/1 on every call,
%% including cache hits. On a cache miss the document is fetched with httpc
%% (redirects followed, 3000 ms timeout; any status other than 200 crashes
%% with badmatch), parsed, checked with xmerl_dsig:verify/2 (a non-ok result
%% is raised with error/1), decoded, and cached under Url in
%% hund_idp_meta_cache.
%% NOTE(review): the document is decoded with hund_xml:decode_sp_metadata/1
%% and must yield a #saml_sp_metadata{} record, although the doc and spec
%% speak of IdP metadata; confirm which was intended.

-spec load_metadata(Url :: string(), Fingerprints :: [string() | binary()]) -> esaml:idp_metadata().
load_metadata(Url, FPs) ->
  Trusted = convert_fingerprints(FPs),
  case ets:lookup(hund_idp_meta_cache, Url) of
    [{Url, Cached}] ->
      Cached;
    _Miss ->
      HttpOpts = [{autoredirect, true}, {timeout, 3000}],
      {ok, {{_Ver, 200, _}, _Headers, Body}} = httpc:request(get, {Url, []}, HttpOpts, []),
      {Doc, _Tail} = xmerl_scan:string(Body, [{namespace_conformant, true}]),
      case xmerl_dsig:verify(Doc, Trusted) of
        ok -> ok;
        Failure -> error(Failure)
      end,
      {ok, Decoded = #saml_sp_metadata{}} = hund_xml:decode_sp_metadata(Doc),
      ets:insert(hund_idp_meta_cache, {Url, Decoded}),
      Decoded
  end.

%% @doc Reads IDP metadata from a URL (or ETS memory cache) WITHOUT
%% verifying its signature.
%%
%% Entries stored under the bare Url key (written by load_metadata/2 after
%% signature verification) are returned as-is. Metadata fetched here is
%% unverified, so it is cached under {unverified, Url} instead. Otherwise a
%% later load_metadata(Url, Fingerprints) call would find it in the cache
%% and return it without ever checking the signature.
%%
%% The HTTP timeout comes from the `esaml` application env key
%% `load_metadata_timeout` (default 15000 ms). Decoding uses
%% hund_xml:decode_sp_metadata/1; `hund` itself does not define
%% decode_sp_metadata/1, so the old hund:decode_sp_metadata/1 call failed
%% with undef.
%% NOTE(review): the result must be a #saml_sp_metadata{} record even
%% though the doc/spec speak of IdP metadata; confirm which was intended.

-spec load_metadata(Url :: string()) -> esaml:idp_metadata().
load_metadata(Url) ->
  case ets:lookup(hund_idp_meta_cache, Url) of
    [{Url, Verified}] -> Verified;

    _ ->
      case ets:lookup(hund_idp_meta_cache, {unverified, Url}) of
        [{_, Unverified}] -> Unverified;
        _ -> fetch_unverified_metadata(Url)
      end
  end.

%% @private Fetches and decodes metadata from Url without signature
%% verification, caching it under {unverified, Url}.
fetch_unverified_metadata(Url) ->
  Timeout = application:get_env(esaml, load_metadata_timeout, 15000),
  {ok, {{_Ver, 200, _}, _Headers, Body}} =
    httpc:request(get, {Url, []}, [{autoredirect, true}, {timeout, Timeout}], []),
  {Xml, _} = xmerl_scan:string(Body, [{namespace_conformant, true}]),
  {ok, Meta = #saml_sp_metadata{}} = hund_xml:decode_sp_metadata(Xml),
  ets:insert(hund_idp_meta_cache, {{unverified, Url}, Meta}),
  Meta.

%% @doc Checks for a duplicate assertion using ETS tables in memory on all available nodes.
%%
%% This is a helper to be used as a DuplicateFun with hund_sp:validate_assertion/3.
%% If you aren't using standard erlang distribution for your app, you probably don't
%% want to use this.
%%
%% Digest is looked up in hund_assertion_seen on every node reached by
%% rpc:multicall/3; nodes that fail to answer are ignored. If any node
%% reports it, {error, duplicate_assertion} is returned. Otherwise Digest is
%% recorded on every reachable node and scheduled for deletion one second
%% after esaml:stale_time(A), and ok is returned. stale_time is presumably
%% in gregorian seconds, since it is compared with the current UTC time in
%% gregorian seconds; confirm.

-spec check_dupe_ets(esaml:assertion(), Digest :: binary()) -> ok | {error, duplicate_assertion}.
check_dupe_ets(A, Digest) ->
  %% Use universaltime/0 directly: converting localtime/0 can be ambiguous
  %% around DST transitions.
  NowSecs = calendar:datetime_to_gregorian_seconds(erlang:universaltime()),
  DeathSecs = esaml:stale_time(A),
  {ResL, _BadNodes} = rpc:multicall(erlang, apply, [fun () -> seen_locally(Digest) end, []]),
  case lists:member(seen, ResL) of
    true -> {error, duplicate_assertion};

    false ->
      %% timer:apply_after/4 rejects negative times with {error, badarg}.
      %% Without the clamp, an already-stale assertion would crash the remote
      %% call after the insert, and its digest would never be deleted.
      TtlSecs = max(0, DeathSecs - NowSecs + 1),
      rpc:multicall(erlang, apply, [fun () -> mark_seen_locally(Digest, TtlSecs) end, []]),
      ok
  end.

%% @private
%% Returns `seen` if Digest is recorded in hund_assertion_seen on this node,
%% `ok` otherwise, including when the table does not exist here
%% (ets:lookup/2 then raises badarg).
seen_locally(Digest) ->
  try ets:lookup(hund_assertion_seen, Digest) of
    [{Digest, seen} | _] -> seen;
    _ -> ok
  catch
    error:badarg -> ok
  end.

%% @private
%% Records Digest in hund_assertion_seen on this node, creating the tables
%% first if they are missing, and schedules its deletion after TtlSecs.
mark_seen_locally(Digest, TtlSecs) ->
  case ets:info(hund_assertion_seen) of
    undefined -> spawn_seen_table_owner();
    _ -> ok
  end,
  ets:insert(hund_assertion_seen, {Digest, seen}),
  {ok, _} = timer:apply_after(TtlSecs * 1000, ets, delete, [hund_assertion_seen, Digest]),
  ok.

%% @private
%% Starts an owner process for the hund ETS tables on this node and waits
%% until they exist. It creates the same four tables as create_tables/0,
%% including hund_idp_meta_cache, so load_metadata/1,2 also work on this
%% node. The owner is not linked to the caller (here presumably a
%% short-lived rpc process), so the tables outlive the call. It is monitored
%% while we wait: if it dies first (e.g. register/2 fails because another
%% owner is already registered), this exits instead of blocking forever.
spawn_seen_table_owner() ->
  Parent = self(),
  {Owner, Ref} =
    spawn_monitor(
      fun () ->
        register(hund_ets_table_owner, self()),
        lists:foreach(
          fun (Table) -> ets:new(Table, [set, public, named_table]) end,
          [hund_assertion_seen, hund_privkey_cache, hund_certbin_cache, hund_idp_meta_cache]
        ),
        Parent ! {self(), ready},
        ets_table_owner()
      end
    ),
  receive
    {Owner, ready} ->
      erlang:demonitor(Ref, [flush]),
      ok;

    {'DOWN', Ref, process, Owner, Reason} ->
      exit({hund_ets_table_owner_failed, Reason})
  end.


% TODO: switch to uuid_erl hex pkg
%% @doc Returns an identifier string: "id" followed by the decimal
%% erlang:system_time() (native units) and a positive
%% erlang:unique_integer/1, with no separator in between.
-spec unique_id() -> string().
unique_id() ->
  lists:concat(["id", erlang:system_time(), erlang:unique_integer([positive])]).