1212package com .amazon .dlic .auth .http .jwt .keybyoidc ;
1313
1414import com .google .common .base .Strings ;
15+ import com .nimbusds .jose .Algorithm ;
16+ import com .nimbusds .jose .JOSEException ;
17+ import com .nimbusds .jose .JWSVerifier ;
18+ import com .nimbusds .jose .jwk .JWK ;
19+ import com .nimbusds .jose .jwk .OctetSequenceKey ;
20+ import com .nimbusds .jose .crypto .factories .DefaultJWSVerifierFactory ;
21+ import com .nimbusds .jose .proc .SimpleSecurityContext ;
22+ import com .nimbusds .jwt .JWTClaimsSet ;
23+ import com .nimbusds .jwt .SignedJWT ;
24+ import com .nimbusds .jwt .proc .BadJWTException ;
25+ import com .nimbusds .jwt .proc .DefaultJWTClaimsVerifier ;
1526import org .apache .commons .lang3 .StringEscapeUtils ;
16- import org .apache .cxf .rs .security .jose .jwa .SignatureAlgorithm ;
17- import org .apache .cxf .rs .security .jose .jwk .JsonWebKey ;
18- import org .apache .cxf .rs .security .jose .jwk .KeyType ;
19- import org .apache .cxf .rs .security .jose .jwk .PublicKeyUse ;
20- import org .apache .cxf .rs .security .jose .jws .JwsJwtCompactConsumer ;
21- import org .apache .cxf .rs .security .jose .jws .JwsSignatureVerifier ;
22- import org .apache .cxf .rs .security .jose .jws .JwsUtils ;
23- import org .apache .cxf .rs .security .jose .jwt .JwtClaims ;
24- import org .apache .cxf .rs .security .jose .jwt .JwtException ;
25- import org .apache .cxf .rs .security .jose .jwt .JwtToken ;
26- import org .apache .cxf .rs .security .jose .jwt .JwtUtils ;
2727import org .apache .logging .log4j .LogManager ;
2828import org .apache .logging .log4j .Logger ;
2929
30+ import java .text .ParseException ;
31+ import java .util .Collections ;
32+
3033public class JwtVerifier {
3134
3235 private final static Logger log = LogManager .getLogger (JwtVerifier .class );
@@ -43,31 +46,24 @@ public JwtVerifier(KeyProvider keyProvider, int clockSkewToleranceSeconds, Strin
4346 this .requiredAudience = requiredAudience ;
4447 }
4548
46- public JwtToken getVerifiedJwtToken (String encodedJwt ) throws BadCredentialsException {
49+ public SignedJWT getVerifiedJwtToken (String encodedJwt ) throws BadCredentialsException {
4750 try {
48- JwsJwtCompactConsumer jwtConsumer = new JwsJwtCompactConsumer (encodedJwt );
49- JwtToken jwt = jwtConsumer .getJwtToken ();
51+ SignedJWT jwt = SignedJWT .parse (encodedJwt );
5052
51- String escapedKid = jwt .getJwsHeaders ().getKeyId ();
53+ String escapedKid = jwt .getHeader ().getKeyID ();
5254 String kid = escapedKid ;
5355 if (!Strings .isNullOrEmpty (kid )) {
5456 kid = StringEscapeUtils .unescapeJava (escapedKid );
5557 }
56- JsonWebKey key = keyProvider .getKey (kid );
57-
58- // Algorithm is not mandatory for the key material, so we set it to the same as the JWT
59- if (key .getAlgorithm () == null && key .getPublicKeyUse () == PublicKeyUse .SIGN && key .getKeyType () == KeyType .RSA ) {
60- key .setAlgorithm (jwt .getJwsHeaders ().getAlgorithm ());
61- }
62-
63- JwsSignatureVerifier signatureVerifier = getInitializedSignatureVerifier (key , jwt );
58+ JWK key = keyProvider .getKey (kid );
6459
65- boolean signatureValid = jwtConsumer .verifySignatureWith (signatureVerifier );
60+ JWSVerifier signatureVerifier = getInitializedSignatureVerifier (key , jwt );
61+ boolean signatureValid = jwt .verify (signatureVerifier );
6662
6763 if (!signatureValid && Strings .isNullOrEmpty (kid )) {
6864 key = keyProvider .getKeyAfterRefresh (null );
6965 signatureVerifier = getInitializedSignatureVerifier (key , jwt );
70- signatureValid = jwtConsumer . verifySignatureWith (signatureVerifier );
66+ signatureValid = jwt . verify (signatureVerifier );
7167 }
7268
7369 if (!signatureValid ) {
@@ -77,18 +73,18 @@ public JwtToken getVerifiedJwtToken(String encodedJwt) throws BadCredentialsExce
7773 validateClaims (jwt );
7874
7975 return jwt ;
80- } catch (JwtException e ) {
76+ } catch (JOSEException | ParseException | BadJWTException e ) {
8177 throw new BadCredentialsException (e .getMessage (), e );
8278 }
8379 }
8480
85- private void validateSignatureAlgorithm (JsonWebKey key , JwtToken jwt ) throws BadCredentialsException {
86- if (Strings . isNullOrEmpty ( key .getAlgorithm ()) ) {
81+ private void validateSignatureAlgorithm (JWK key , SignedJWT jwt ) throws BadCredentialsException {
82+ if (key .getAlgorithm () == null || jwt . getHeader (). getAlgorithm () == null ) {
8783 return ;
8884 }
8985
90- SignatureAlgorithm keyAlgorithm = SignatureAlgorithm . getAlgorithm ( key .getAlgorithm () );
91- SignatureAlgorithm tokenAlgorithm = SignatureAlgorithm . getAlgorithm ( jwt .getJwsHeaders ().getAlgorithm () );
86+ Algorithm keyAlgorithm = key .getAlgorithm ();
87+ Algorithm tokenAlgorithm = jwt .getHeader ().getAlgorithm ();
9288
9389 if (!keyAlgorithm .equals (tokenAlgorithm )) {
9490 throw new BadCredentialsException (
@@ -97,38 +93,48 @@ private void validateSignatureAlgorithm(JsonWebKey key, JwtToken jwt) throws Bad
9793 }
9894 }
9995
100- private JwsSignatureVerifier getInitializedSignatureVerifier (JsonWebKey key , JwtToken jwt ) throws BadCredentialsException ,
101- JwtException {
96+ private JWSVerifier getInitializedSignatureVerifier (JWK key , SignedJWT jwt ) throws BadCredentialsException , JOSEException {
10297
10398 validateSignatureAlgorithm (key , jwt );
104- JwsSignatureVerifier result = JwsUtils .getSignatureVerifier (key , jwt .getJwsHeaders ().getSignatureAlgorithm ());
99+ final JWSVerifier result ;
100+ if (key .getClass () == OctetSequenceKey .class ) {
101+ result = new DefaultJWSVerifierFactory ().createJWSVerifier (jwt .getHeader (), key .toOctetSequenceKey ().toSecretKey ());
102+ } else {
103+ result = new DefaultJWSVerifierFactory ().createJWSVerifier (jwt .getHeader (), key .toRSAKey ().toRSAPublicKey ());
104+ }
105+
105106 if (result == null ) {
106107 throw new BadCredentialsException ("Cannot verify JWT" );
107108 } else {
108109 return result ;
109110 }
110111 }
111112
112- private void validateClaims (JwtToken jwt ) throws JwtException {
113- JwtClaims claims = jwt .getClaims ();
113+ private void validateClaims (SignedJWT jwt ) throws ParseException , BadJWTException {
114+ JWTClaimsSet claims = jwt .getJWTClaimsSet ();
114115
115116 if (claims != null ) {
116- JwtUtils .validateJwtExpiry (claims , clockSkewToleranceSeconds , false );
117- JwtUtils .validateJwtNotBefore (claims , clockSkewToleranceSeconds , false );
117+ DefaultJWTClaimsVerifier <SimpleSecurityContext > claimsVerifier = new DefaultJWTClaimsVerifier <>(
118+ requiredAudience ,
119+ null ,
120+ Collections .emptySet ()
121+ );
122+ claimsVerifier .setMaxClockSkew (clockSkewToleranceSeconds );
123+ claimsVerifier .verify (claims , null );
118124 validateRequiredAudienceAndIssuer (claims );
119125 }
120126 }
121127
122- private void validateRequiredAudienceAndIssuer (JwtClaims claims ) {
123- String audience = claims .getAudience ();
128+ private void validateRequiredAudienceAndIssuer (JWTClaimsSet claims ) throws BadJWTException {
129+ String audience = claims .getAudience (). stream (). findFirst (). orElse ( "" ) ;
124130 String issuer = claims .getIssuer ();
125131
126132 if (!Strings .isNullOrEmpty (requiredAudience ) && !requiredAudience .equals (audience )) {
127- throw new JwtException ("Invalid audience" );
133+ throw new BadJWTException ("Invalid audience" );
128134 }
129135
130136 if (!Strings .isNullOrEmpty (requiredIssuer ) && !requiredIssuer .equals (issuer )) {
131- throw new JwtException ("Invalid issuer" );
137+ throw new BadJWTException ("Invalid issuer" );
132138 }
133139 }
134140}
0 commit comments