lib/assent/jwt_adapter/assent_jwt.ex

defmodule Assent.JWTAdapter.AssentJWT do
  @moduledoc """
  JWT adapter module for parsing JSON Web Tokens natively.

  See `Assent.JWTAdapter` for more.
  """
  alias Assent.{Config, JWTAdapter}

  @behaviour Assent.JWTAdapter

  defmodule Error do
    defexception [:message, :reason, data: nil]
  end

  @impl JWTAdapter
  def sign(claims, alg, secret_or_private_key, opts) do
    with {:ok, header}    <- encode_header(alg, opts),
         {:ok, claims}    <- encode_claims(claims, opts) do
      do_sign(header, claims, alg, secret_or_private_key)
    end
  end

  defp encode_header(alg, opts) do
    header =
      case Keyword.has_key?(opts, :private_key_id) do
        false -> %{"typ" => "JWT", "alg" => alg}
        true -> %{"typ" => "JWT", "alg" => alg, "kid" => Keyword.get(opts, :private_key_id)}
      end

    case encode_json_base64(header, opts) do
      {:ok, encoded_header} -> {:ok, encoded_header}
      {:error, error}       -> {:error, %Error{message: "Failed to encode header", reason: error, data: header}}
    end
  end

  defp encode_json_base64(map, opts) do
    with {:ok, json_library} <- Config.fetch(opts, :json_library),
         {:ok, json} <- json_library.encode(map) do
      {:ok, Base.url_encode64(json, padding: false)}
    end
  end

  defp encode_claims(claims, opts) do
    case encode_json_base64(claims, opts) do
      {:ok, encoded_claims} -> {:ok, encoded_claims}
      {:error, error}       -> {:error, %Error{message: "Failed to encode claims", reason: error, data: claims}}
    end
  end

  defp do_sign(header, claims, alg, secret_or_private_key) do
    message = header <> "." <> claims

    case sign_message(message, alg, secret_or_private_key) do
      {:ok, signature} -> {:ok, "#{message}.#{Base.url_encode64(signature, padding: false)}"}
      {:error, error} -> {:error, %Error{message: "Failed to sign JWT", reason: error, data: {message, alg}}}
    end
  end

  defp sign_message(message, "HS" <> sha_bit_size, secret) do
    with {:ok, sha_alg} <- sha2_alg(sha_bit_size) do
      {:ok, :crypto.mac(:hmac, sha_alg, secret, message)}
    end
  end

  defp sign_message(message, "ES" <> sha_bit_size, private_key) do
    # Per https://tools.ietf.org/html/rfc7515#appendix-A.3.1

    with {:ok, sha_alg} <- sha2_alg(sha_bit_size),
         {:ok, key}     <- decode_pem(private_key) do

      der_signature = :public_key.sign(message, sha_alg, key)
      {:'ECDSA-Sig-Value', r, s} = :public_key.der_decode(:'ECDSA-Sig-Value', der_signature)
      r_bin = sha_bit_pad(int_to_bin(r), sha_bit_size)
      s_bin = sha_bit_pad(int_to_bin(s), sha_bit_size)

      {:ok, r_bin <> s_bin}
    end
  end

  defp sign_message(message, <<_, "S", sha_bit_size :: binary>>, private_key) do
    with {:ok, sha_alg} <- sha2_alg(sha_bit_size),
         {:ok, key}     <- decode_pem(private_key) do

      {:ok, :public_key.sign(message, sha_alg, key)}
    end
  end

  defp sign_message(_message, alg, _jwk), do: {:error, "Unsupported JWT alg #{alg} or invalid JWK"}

  defp sha2_alg("256"), do: {:ok, :sha256}
  defp sha2_alg("384"), do: {:ok, :sha384}
  defp sha2_alg("512"), do: {:ok, :sha512}
  defp sha2_alg(bit_size), do: {:error, "Invalid SHA-2 algorithm bit size: #{bit_size}"}

  defp decode_pem(pem) do
    case :public_key.pem_decode(pem) do
      []      -> {:error, "Invalid private key"}
      [entry] -> {:ok, :public_key.pem_entry_decode(entry)}
      _any    -> {:error, "Private key should only have one entry"}
    end
  end

  # From erlang crypto lib
  defp int_to_bin(x) when x < 0, do: int_to_bin_neg(x, [])
  defp int_to_bin(x), do: int_to_bin_pos(x, [])

  defp int_to_bin_pos(0, [_ | _] = ds), do: :erlang.list_to_binary(ds)
  defp int_to_bin_pos(x, ds), do: int_to_bin_pos(:erlang.bsr(x, 8), [:erlang.band(x, 255) | ds])

  defp int_to_bin_neg(-1, [msb | _] = ds) when msb >= 128, do: :erlang.list_to_binary(ds)
  defp int_to_bin_neg(x, ds), do: int_to_bin_neg(:erlang.bsr(x, 8), [:erlang.band(x, 255) | ds])

  defp sha_bit_pad(binary, "256"), do: lpad_binary(binary, byte_size(binary) - 32)
  defp sha_bit_pad(binary, "384"), do: lpad_binary(binary, byte_size(binary) - 48)
  defp sha_bit_pad(binary, "512"), do: lpad_binary(binary, byte_size(binary) - 66)

  defp lpad_binary(binary, length) when length > 0 do
    :binary.copy(<<0>>, length - byte_size(binary)) <> binary
  end
  defp lpad_binary(binary, _length), do: binary

  @impl JWTAdapter
  def verify(token, secret_or_public_key, opts) do
    with {:ok, encoded_jwt} <- split(token),
         {:ok, alg, header} <- decode_header(encoded_jwt.header, opts),
         {:ok, claims}      <- decode_claims(encoded_jwt.claims, opts),
         {:ok, signature}   <- decode_signature(encoded_jwt.signature),
         {:ok, verified}    <- do_verify(encoded_jwt.header, encoded_jwt.claims, signature, alg, secret_or_public_key) do

      {:ok, %{
        header: header,
        claims: claims,
        signature: signature,
        verified?: verified
      }}
    end
  end

  defp split(token) do
    case String.split(token, ".") do
      [header, claims, signature] -> {:ok, %{header: header, claims: claims, signature: signature}}
      parts                     -> {:error, %Error{message: "JWT must have exactly three parts", reason: :invalid_format, data: parts}}
    end
  end

  defp decode_header(header, opts) do
    with {:ok, json_library} <- Config.fetch(opts, :json_library),
         {:ok, header}       <- decode_base64_url(header),
         {:ok, header}       <- decode_json(header, json_library),
         {:ok, alg}          <- fetch_alg(header) do
      {:ok, alg, header}
    else
      {:error, error} -> {:error, %Error{message: "Failed to decode header", reason: error, data: header}}
    end
  end

  defp decode_base64_url(encoded) do
    case Base.url_decode64(encoded, padding: false) do
      {:ok, decoded} -> {:ok, decoded}
      :error -> {:error, "Invalid Base64URL"}
    end
  end

  defp decode_json(encoded, json_library) do
    case json_library.decode(encoded) do
      {:ok, decoded} -> {:ok, decoded}
      {:error, error} -> {:error, error}
    end
  end

  defp fetch_alg(%{"alg" => alg}), do: {:ok, alg}
  defp fetch_alg(_header), do: {:error, "No \"alg\" found in header"}

  defp decode_claims(claims, opts) do
    with {:ok, json_library} <- Config.fetch(opts, :json_library),
         {:ok, claims}       <- decode_base64_url(claims),
         {:ok, claims}       <- decode_json(claims, json_library) do
      {:ok, claims}
    else
      {:error, error} -> {:error, %Error{message: "Failed to decode claims", reason: error, data: claims}}
    end
  end

  defp decode_signature(signature) do
    case decode_base64_url(signature) do
      {:ok, signature} -> {:ok, signature}
      {:error, error} -> {:error, %Error{message: "Failed to decode signature", reason: error, data: signature}}
    end
  end

  defp do_verify(header, claims, signature, alg, secret_or_public_key) do
    message = "#{header}.#{claims}"

    case verify_message(message, signature, alg, secret_or_public_key) do
      {:ok, verified} -> {:ok, verified}
      {:error, error} -> {:error, %Error{message: "Failed to verify signature", reason: error, data: {message, signature, alg}}}
    end
  end

  defp verify_message(_message, _signature, "none", _secret), do: {:ok, false}

  defp verify_message(_message, _signature, _alg, nil), do: {:ok, false}

  defp verify_message(message, signature_1, "HS" <> _rest = alg, secret) when is_binary(secret) do
    with {:ok, signature_2} <- sign_message(message, alg, secret) do
      {:ok, Assent.constant_time_compare(signature_2, signature_1)}
    end
  end

  defp verify_message(message, signature, "ES" <> sha_bit_size, public_key) do
    with {:ok, sha_alg} <- sha2_alg(sha_bit_size),
         {:ok, pem}     <- decode_key(public_key) do

      # Per https://tools.ietf.org/html/rfc7515#appendix-A.3.1

      size           = :erlang.byte_size(signature)
      {r_bin, s_bin} = :erlang.split_binary(signature, Integer.floor_div(size, 2))
      r              = :crypto.bytes_to_integer(r_bin)
      s              = :crypto.bytes_to_integer(s_bin)
      der_signature  = :public_key.der_encode(:'ECDSA-Sig-Value', {:'ECDSA-Sig-Value', r, s})

      {:ok, :public_key.verify(message, sha_alg, der_signature, pem)}
    end
  end

  defp verify_message(message, signature, <<_, "S", sha_bit_size :: binary>>, public_key) do
    with {:ok, sha_alg} <- sha2_alg(sha_bit_size),
         {:ok, pem}     <- decode_key(public_key) do
      {:ok, :public_key.verify(message, sha_alg, signature, pem)}
    end
  end

  defp decode_key(pem) when is_binary(pem), do: decode_pem(pem)

  defp decode_key(%{"kty" => "RSA", "n" => n, "e" => e}) do
    with {:ok, n} <- decode_base64_url(n),
         {:ok, e} <- decode_base64_url(e) do
      {:ok, {:RSAPublicKey, :crypto.bytes_to_integer(n), :crypto.bytes_to_integer(e)}}
    end
  end

  defp decode_key(jwk) when is_map(jwk), do: {:error, "Unable to decode the JWK"}
end