Skip to content

Commit 9133cc2

Browse files
20fpsjzheaux
authored andcommitted
Add Cache to NimbusJwtDecoderJwkSetUriBuilder
PR gh-8332
1 parent b7d3acc commit 9133cc2

2 files changed

Lines changed: 154 additions & 3 deletions

File tree

oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusJwtDecoder.java

Lines changed: 66 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,8 @@
3434
import com.nimbusds.jose.JOSEException;
3535
import com.nimbusds.jose.JWSAlgorithm;
3636
import com.nimbusds.jose.RemoteKeySourceException;
37+
import com.nimbusds.jose.jwk.JWKSet;
38+
import com.nimbusds.jose.jwk.source.JWKSetCache;
3739
import com.nimbusds.jose.jwk.source.JWKSource;
3840
import com.nimbusds.jose.jwk.source.RemoteJWKSet;
3941
import com.nimbusds.jose.proc.JWSKeySelector;
@@ -49,6 +51,7 @@
4951
import com.nimbusds.jwt.proc.DefaultJWTProcessor;
5052
import com.nimbusds.jwt.proc.JWTProcessor;
5153

54+
import org.springframework.cache.Cache;
5255
import org.springframework.core.convert.converter.Converter;
5356
import org.springframework.http.HttpHeaders;
5457
import org.springframework.http.HttpMethod;
@@ -68,6 +71,7 @@
6871
*
6972
* @author Josh Cummings
7073
* @author Joe Grandja
74+
* @author Mykyta Bezverkhyi
7175
* @since 5.2
7276
*/
7377
public final class NimbusJwtDecoder implements JwtDecoder {
@@ -215,6 +219,7 @@ public static final class JwkSetUriJwtDecoderBuilder {
215219
private String jwkSetUri;
216220
private Set<SignatureAlgorithm> signatureAlgorithms = new HashSet<>();
217221
private RestOperations restOperations = new RestTemplate();
222+
private Cache cache;
218223

219224
private JwkSetUriJwtDecoderBuilder(String jwkSetUri) {
220225
Assert.hasText(jwkSetUri, "jwkSetUri cannot be empty");
@@ -264,6 +269,20 @@ public JwkSetUriJwtDecoderBuilder restOperations(RestOperations restOperations)
264269
return this;
265270
}
266271

272+
/**
273+
* Use the given {@link Cache} to store
274+
* <a href="https://tools.ietf.org/html/rfc7517#section-5">JWK Set</a>.
275+
*
276+
* @param cache the {@link Cache} to be used to store JWK Set
277+
* @return a {@link JwkSetUriJwtDecoderBuilder} for further configurations
278+
* @since 5.4
279+
*/
280+
public JwkSetUriJwtDecoderBuilder cache(Cache cache) {
281+
Assert.notNull(cache, "cache cannot be null");
282+
this.cache = cache;
283+
return this;
284+
}
285+
267286
JWSKeySelector<SecurityContext> jwsKeySelector(JWKSource<SecurityContext> jwkSource) {
268287
if (this.signatureAlgorithms.isEmpty()) {
269288
return new JWSVerificationKeySelector<>(JWSAlgorithm.RS256, jwkSource);
@@ -280,9 +299,17 @@ JWSKeySelector<SecurityContext> jwsKeySelector(JWKSource<SecurityContext> jwkSou
280299
}
281300
}
282301

302+
JWKSource<SecurityContext> jwkSource(ResourceRetriever jwkSetRetriever) {
303+
if (this.cache == null) {
304+
return new RemoteJWKSet<>(toURL(this.jwkSetUri), jwkSetRetriever);
305+
}
306+
ResourceRetriever cachingJwkSetRetriever = new CachingResourceRetriever(this.cache, jwkSetRetriever);
307+
return new RemoteJWKSet<>(toURL(this.jwkSetUri), cachingJwkSetRetriever, new NoOpJwkSetCache());
308+
}
309+
283310
JWTProcessor<SecurityContext> processor() {
284311
ResourceRetriever jwkSetRetriever = new RestOperationsResourceRetriever(this.restOperations);
285-
JWKSource<SecurityContext> jwkSource = new RemoteJWKSet<>(toURL(this.jwkSetUri), jwkSetRetriever);
312+
JWKSource<SecurityContext> jwkSource = jwkSource(jwkSetRetriever);
286313
ConfigurableJWTProcessor<SecurityContext> jwtProcessor = new DefaultJWTProcessor<>();
287314
jwtProcessor.setJWSKeySelector(jwsKeySelector(jwkSource));
288315

@@ -309,6 +336,44 @@ private static URL toURL(String url) {
309336
}
310337
}
311338

339+
private static class NoOpJwkSetCache implements JWKSetCache {
340+
@Override
341+
public void put(JWKSet jwkSet) {
342+
}
343+
344+
@Override
345+
public JWKSet get() {
346+
return null;
347+
}
348+
349+
@Override
350+
public boolean requiresRefresh() {
351+
return true;
352+
}
353+
}
354+
355+
private static class CachingResourceRetriever implements ResourceRetriever {
356+
private final Cache cache;
357+
private final ResourceRetriever resourceRetriever;
358+
359+
CachingResourceRetriever(Cache cache, ResourceRetriever resourceRetriever) {
360+
this.cache = cache;
361+
this.resourceRetriever = resourceRetriever;
362+
}
363+
364+
@Override
365+
public Resource retrieveResource(URL url) throws IOException {
366+
String jwkSet;
367+
try {
368+
jwkSet = cache.get(url.toString(), () -> resourceRetriever.retrieveResource(url).getContent());
369+
} catch (Exception ex) {
370+
throw new IOException(ex);
371+
}
372+
373+
return new Resource(jwkSet, "UTF-8");
374+
}
375+
}
376+
312377
private static class RestOperationsResourceRetriever implements ResourceRetriever {
313378
private static final MediaType APPLICATION_JWK_SET_JSON = new MediaType("application", "jwk-set+json");
314379
private final RestOperations restOperations;

oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusJwtDecoderTests.java

Lines changed: 88 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
import java.util.Date;
3333
import java.util.List;
3434
import java.util.Map;
35+
import java.util.concurrent.Callable;
3536
import javax.crypto.SecretKey;
3637

3738
import com.nimbusds.jose.JWSAlgorithm;
@@ -55,6 +56,8 @@
5556
import org.junit.Test;
5657

5758
import org.mockito.ArgumentCaptor;
59+
import org.springframework.cache.Cache;
60+
import org.springframework.cache.concurrent.ConcurrentMapCache;
5861
import org.springframework.core.convert.converter.Converter;
5962
import org.springframework.http.HttpStatus;
6063
import org.springframework.http.MediaType;
@@ -66,6 +69,7 @@
6669
import org.springframework.security.oauth2.jose.TestKeys;
6770
import org.springframework.security.oauth2.jose.jws.MacAlgorithm;
6871
import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm;
72+
import org.springframework.web.client.RestClientException;
6973
import org.springframework.web.client.RestOperations;
7074

7175
import static org.assertj.core.api.Assertions.assertThat;
@@ -75,6 +79,8 @@
7579
import static org.mockito.ArgumentMatchers.eq;
7680
import static org.mockito.Mockito.mock;
7781
import static org.mockito.Mockito.verify;
82+
import static org.mockito.Mockito.verifyNoInteractions;
83+
import static org.mockito.Mockito.verifyNoMoreInteractions;
7884
import static org.mockito.Mockito.when;
7985
import static org.springframework.security.oauth2.jwt.NimbusJwtDecoder.withJwkSetUri;
8086
import static org.springframework.security.oauth2.jwt.NimbusJwtDecoder.withPublicKey;
@@ -85,6 +91,7 @@
8591
*
8692
* @author Josh Cummings
8793
* @author Joe Grandja
94+
* @author Mykyta Bezverkhyi
8895
*/
8996
public class NimbusJwtDecoderTests {
9097
private static final String JWK_SET = "{\"keys\":[{\"p\":\"49neceJFs8R6n7WamRGy45F5Tv0YM-R2ODK3eSBUSLOSH2tAqjEVKOkLE5fiNA3ygqq15NcKRadB2pTVf-Yb5ZIBuKzko8bzYIkIqYhSh_FAdEEr0vHF5fq_yWSvc6swsOJGqvBEtuqtJY027u-G2gAQasCQdhyejer68zsTn8M\",\"kty\":\"RSA\",\"q\":\"tWR-ysspjZ73B6p2vVRVyHwP3KQWL5KEQcdgcmMOE_P_cPs98vZJfLhxobXVmvzuEWBpRSiqiuyKlQnpstKt94Cy77iO8m8ISfF3C9VyLWXi9HUGAJb99irWABFl3sNDff5K2ODQ8CmuXLYM25OwN3ikbrhEJozlXg_NJFSGD4E\",\"d\":\"FkZHYZlw5KSoqQ1i2RA2kCUygSUOf1OqMt3uomtXuUmqKBm_bY7PCOhmwbvbn4xZYEeHuTR8Xix-0KpHe3NKyWrtRjkq1T_un49_1LLVUhJ0dL-9_x0xRquVjhl_XrsRXaGMEHs8G9pLTvXQ1uST585gxIfmCe0sxPZLvwoic-bXf64UZ9BGRV3lFexWJQqCZp2S21HfoU7wiz6kfLRNi-K4xiVNB1gswm_8o5lRuY7zB9bRARQ3TS2G4eW7p5sxT3CgsGiQD3_wPugU8iDplqAjgJ5ofNJXZezoj0t6JMB_qOpbrmAM1EnomIPebSLW7Ky9SugEd6KMdL5lW6AuAQ\",\"e\":\"AQAB\",\"use\":\"sig\",\"kid\":\"one\",\"qi\":\"wdkFu_tV2V1l_PWUUimG516Zvhqk2SWDw1F7uNDD-Lvrv_WNRIJVzuffZ8WYiPy8VvYQPJUrT2EXL8P0ocqwlaSTuXctrORcbjwgxDQDLsiZE0C23HYzgi0cofbScsJdhcBg7d07LAf7cdJWG0YVl1FkMCsxUlZ2wTwHfKWf-v4\",\"dp\":\"uwnPxqC-IxG4r33-SIT02kZC1IqC4aY7PWq0nePiDEQMQWpjjNH50rlq9EyLzbtdRdIouo-jyQXB01K15-XXJJ60dwrGLYNVqfsTd0eGqD1scYJGHUWG9IDgCsxyEnuG3s0AwbW2UolWVSsU2xMZGb9PurIUZECeD1XDZwMp2s0\",\"dq\":\"hra786AunB8TF35h8PpROzPoE9VJJMuLrc6Esm8eZXMwopf0yhxfN2FEAvUoTpLJu93-UH6DKenCgi16gnQ0_zt1qNNIVoRfg4rw_rjmsxCYHTVL3-RDeC8X_7TsEySxW0EgFTHh-nr6I6CQrAJjPM88T35KHtdFATZ7BCBB8AE\",\"n\":\"oXJ8OyOv_eRnce4akdanR4KYRfnC2zLV4uYNQpcFn6oHL0dj7D6kxQmsXoYgJV8ZVDn71KGmuLvolxsDncc2UrhyMBY6DVQVgMSVYaPCTgW76iYEKGgzTEw5IBRQL9w3SRJWd3VJTZZQjkXef48Ocz06PGF3lhbz4t5UEZtdF4rIe7u-977QwHuh7yRPBQ3sII-cVoOUMgaXB9SHcGF2iZCtPzL_IffDUcfhLQteGebhW8A6eUHgpD5A1PQ-JCw_G7UOzZAjjDjtNM2eqm8j-Ms_gqnm4MiCZ4E-9pDN77CAAPVN7kuX6ejs9KBXpk01z48i9fORYk9u7rAkh1HuQw\"}]}";
@@ -247,6 +254,21 @@ public void decodeWhenJwkEndpointIsUnresponsiveThenReturnsJwtException() throws
247254
}
248255
}
249256

257+
@Test
258+
public void shouldThrowJwtExceptionWhenJwkSetEndpointHasNotRespondedAndCacheIsConfigured() throws Exception {
259+
try ( MockWebServer server = new MockWebServer() ) {
260+
Cache cache = new ConcurrentMapCache("test-jwk-set-cache");
261+
String jwkSetUri = server.url("/.well-known/jwks.json").toString();
262+
NimbusJwtDecoder jwtDecoder = withJwkSetUri(jwkSetUri).cache(cache).build();
263+
264+
server.shutdown();
265+
assertThatCode(() -> jwtDecoder.decode(SIGNED_JWT))
266+
.isInstanceOf(JwtException.class)
267+
.isNotInstanceOf(BadJwtException.class)
268+
.hasMessageContaining("An error occurred while attempting to decode the Jwt");
269+
}
270+
}
271+
250272
@Test
251273
public void withJwkSetUriWhenNullOrEmptyThenThrowsException() {
252274
Assertions.assertThatCode(() -> withJwkSetUri(null)).isInstanceOf(IllegalArgumentException.class);
@@ -264,6 +286,12 @@ public void restOperationsWhenNullThenThrowsException() {
264286
Assertions.assertThatCode(() -> builder.restOperations(null)).isInstanceOf(IllegalArgumentException.class);
265287
}
266288

289+
@Test
290+
public void shouldThrowIllegalArgumentExceptionWhenJwkSetCacheIsNull() {
291+
NimbusJwtDecoder.JwkSetUriJwtDecoderBuilder builder = withJwkSetUri(JWK_SET_URI);
292+
Assertions.assertThatCode(() -> builder.cache(null)).isInstanceOf(IllegalArgumentException.class);
293+
}
294+
267295
@Test
268296
public void withPublicKeyWhenNullThenThrowsException() {
269297
assertThatThrownBy(() -> withPublicKey(null))
@@ -425,7 +453,7 @@ public void decodeWhenJwkSetRequestedThenAcceptHeaderJsonAndJwkSetJson() {
425453
RestOperations restOperations = mock(RestOperations.class);
426454
when(restOperations.exchange(any(RequestEntity.class), eq(String.class)))
427455
.thenReturn(new ResponseEntity<>(JWK_SET, HttpStatus.OK));
428-
JWTProcessor<SecurityContext> processor = withJwkSetUri("https://issuer/.well-known/jwks.json")
456+
JWTProcessor<SecurityContext> processor = withJwkSetUri(JWK_SET_URI)
429457
.restOperations(restOperations)
430458
.processor();
431459
NimbusJwtDecoder jwtDecoder = new NimbusJwtDecoder(processor);
@@ -436,6 +464,64 @@ public void decodeWhenJwkSetRequestedThenAcceptHeaderJsonAndJwkSetJson() {
436464
assertThat(acceptHeader).contains(MediaType.APPLICATION_JSON, APPLICATION_JWK_SET_JSON);
437465
}
438466

467+
@Test
468+
public void shouldStoreRetrievedJwkSetToCache() {
469+
// given
470+
Cache cache = new ConcurrentMapCache("test-jwk-set-cache");
471+
RestOperations restOperations = mock(RestOperations.class);
472+
when(restOperations.exchange(any(RequestEntity.class), eq(String.class)))
473+
.thenReturn(new ResponseEntity<>(JWK_SET, HttpStatus.OK));
474+
NimbusJwtDecoder jwtDecoder = withJwkSetUri(JWK_SET_URI)
475+
.restOperations(restOperations)
476+
.cache(cache)
477+
.build();
478+
// when
479+
jwtDecoder.decode(SIGNED_JWT);
480+
// then
481+
assertThat(cache.get(JWK_SET_URI, String.class)).isEqualTo(JWK_SET);
482+
ArgumentCaptor<RequestEntity> requestEntityCaptor = ArgumentCaptor.forClass(RequestEntity.class);
483+
verify(restOperations).exchange(requestEntityCaptor.capture(), eq(String.class));
484+
verifyNoMoreInteractions(restOperations);
485+
List<MediaType> acceptHeader = requestEntityCaptor.getValue().getHeaders().getAccept();
486+
assertThat(acceptHeader).contains(MediaType.APPLICATION_JSON, APPLICATION_JWK_SET_JSON);
487+
}
488+
489+
@Test
490+
public void shouldDecodeJwtUsingJwkSetCache() {
491+
// given
492+
RestOperations restOperations = mock(RestOperations.class);
493+
Cache cache = mock(Cache.class);
494+
when(cache.get(eq(JWK_SET_URI), any(Callable.class))).thenReturn(JWK_SET);
495+
NimbusJwtDecoder jwtDecoder = withJwkSetUri(JWK_SET_URI)
496+
.cache(cache)
497+
.restOperations(restOperations)
498+
.build();
499+
// when
500+
jwtDecoder.decode(SIGNED_JWT);
501+
// then
502+
verify(cache).get(eq(JWK_SET_URI), any(Callable.class));
503+
verifyNoMoreInteractions(cache);
504+
verifyNoInteractions(restOperations);
505+
}
506+
507+
@Test
508+
public void shouldThrowJwtExceptionWhenExceptionOccurredWhileRetrievingJwkSetInsideCachingRetriever() {
509+
// given
510+
Cache cache = new ConcurrentMapCache("test-jwk-set-cache");
511+
RestOperations restOperations = mock(RestOperations.class);
512+
when(restOperations.exchange(any(RequestEntity.class), eq(String.class)))
513+
.thenThrow(new RestClientException("Cannot retrieve JWK Set"));
514+
NimbusJwtDecoder jwtDecoder = withJwkSetUri(JWK_SET_URI)
515+
.restOperations(restOperations)
516+
.cache(cache)
517+
.build();
518+
// then
519+
assertThatCode(() -> jwtDecoder.decode(SIGNED_JWT))
520+
.isInstanceOf(JwtException.class)
521+
.isNotInstanceOf(BadJwtException.class)
522+
.hasMessageContaining("An error occurred while attempting to decode the Jwt");
523+
}
524+
439525
private RSAPublicKey key() throws InvalidKeySpecException {
440526
byte[] decoded = Base64.getDecoder().decode(VERIFY_KEY.getBytes());
441527
EncodedKeySpec spec = new X509EncodedKeySpec(decoded);
@@ -466,7 +552,7 @@ private static JWTProcessor<SecurityContext> withSigning(String jwkResponse) {
466552
RestOperations restOperations = mock(RestOperations.class);
467553
when(restOperations.exchange(any(RequestEntity.class), eq(String.class)))
468554
.thenReturn(new ResponseEntity<>(jwkResponse, HttpStatus.OK));
469-
return withJwkSetUri("https://issuer/.well-known/jwks.json")
555+
return withJwkSetUri(JWK_SET_URI)
470556
.restOperations(restOperations)
471557
.processor();
472558
}

0 commit comments

Comments
 (0)