JWKSAwsElbHandlerImpl.java
/*
* Copyright (C) 2025 jtalbut
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package uk.co.spudsoft.jwtvalidatorvertx.impl;
import com.google.common.collect.ImmutableList;
import io.vertx.core.Future;
import io.vertx.core.Promise;
import io.vertx.core.buffer.Buffer;
import io.vertx.ext.auth.PubSecKeyOptions;
import io.vertx.ext.auth.impl.jose.JWK;
import io.vertx.ext.web.client.WebClient;
import java.time.Duration;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import uk.co.spudsoft.jwtvalidatorvertx.JsonWebKeySetAwsElbHandler;
import uk.co.spudsoft.jwtvalidatorvertx.impl.AsyncLoadingCache.TimedObject;
/**
* Implementation of {@link JsonWebKeySetAwsElbHandler} that stores JWKs in a HashMap.
*
* @author jtalbut
*/
public class JWKSAwsElbHandlerImpl implements JsonWebKeySetAwsElbHandler {
private static final Logger logger = LoggerFactory.getLogger(JWKSOpenIdDiscoveryHandlerImpl.class);
private final List<String> keyBaseUrls;
private final WebClient webClient;
private final long cacheDurationMillis;
private final Map<String, TimedObject<JWK>> keys = new HashMap<>();
/**
* Constructor.
*
* With a static map of JWKs the security of the system is not compromised by allowing any issuer, though you should question
* why this is necessary.
*
* Each JWKs endpoint must use KIDs that are globally unique.
*
* When a KID is requested and cannot be found ALL the configured JWKS URLs will be queried and the single cache will be
* updated. Entries in the cache will be retained for a duration based on either the Cache-Control max-age header of the
* response or, if that is not present, the defaultJwkCacheDuration. Given that only positive responses are cached it is
* reasonable for the defaultJwkCacheDuration to be 24 hours (or more).
*
* @param webClient Vertx WebClient instance, that will be used for querying the JWKS URLs.
* @param keyBaseUrls Static set of base URLs that will be used for constructing the URLs to the AWS keys.
* @param defaultJwkCacheDuration Time to keep JWKs in cache if no cache-control: max-age header is found.
*
* The JWKS URLs must be accessed via https for the environment to offer any security. This is not enforced at the code level.
*
* @see <a href="https://docs.aws.amazon.com/elasticloadbalancing/latest/application/listener-authenticate-users.html#user-claims-encoding">listener-authenticate-users.html#user-claims-encoding</a>
*/
public JWKSAwsElbHandlerImpl(WebClient webClient, Collection<String> keyBaseUrls, Duration defaultJwkCacheDuration) {
this.webClient = webClient;
this.cacheDurationMillis = defaultJwkCacheDuration.toMillis();
this.keyBaseUrls = keyBaseUrls.stream().map(url -> url.endsWith("/") ? url : url + "/").collect(ImmutableList.toImmutableList());
}
@Override
public void optimize() {
}
private JWK findJwk(String kid) {
synchronized (keys) {
TimedObject<JWK> jwk = keys.get(kid);
long now = System.currentTimeMillis();
if (null != jwk) {
if (jwk.expiredBefore(now)) {
keys.remove(kid);
} else {
return jwk.getValue();
}
}
return null;
}
}
@Override
public Future<JWK> findJwk(String issuer, String kid) {
if (!kid.matches("^[A-Za-z0-9._~-]*$")) {
logger.error("The kid \"{}\" is not a valid AWS ELB kid", kid);
throw new IllegalArgumentException("The kid is not a valid AWS ELB kid.");
}
JWK foundJwk = findJwk(kid);
if (foundJwk != null) {
return Future.succeededFuture(foundJwk);
}
Promise<JWK> resultPromise = Promise.promise();
List<Future<Void>> trackingFutures = new ArrayList<>();
for (String baseUrl : this.keyBaseUrls) {
String awsKeyUrl = baseUrl + kid;
Future<Void> future = webClient.getAbs(awsKeyUrl)
.send()
.compose(response -> {
if (response.statusCode() >= 200 && response.statusCode() < 300) {
JWK jwk;
Buffer body = response.body();
try {
jwk = pemToJwk(kid, body);
} catch (Throwable ex) {
logger.warn("From {} failed to parse body ({}) as JWKRequest: ", awsKeyUrl, body, response.body());
return Future.<Void>succeededFuture();
}
synchronized (keys) {
keys.put(kid, new TimedObject<>(jwk, System.currentTimeMillis() + cacheDurationMillis));
}
resultPromise.tryComplete(jwk);
} else {
logger.warn("Request to {} returned {}: {}", awsKeyUrl, response.statusCode(), response.body());
}
return Future.<Void>succeededFuture();
})
.recover(ex -> {
logger.warn("Failed request to {}: ", awsKeyUrl, ex);
return Future.<Void>succeededFuture();
});
trackingFutures.add(future);
}
// After all requests finish, fail the promise if none succeeded
Future.all(trackingFutures).onComplete(ar -> {
if (!resultPromise.future().isComplete()) {
resultPromise.fail("No valid response found");
}
});
return resultPromise.future();
}
private static JWK pemToJwk(String kid, Buffer pem) {
PubSecKeyOptions keyOptions = new PubSecKeyOptions()
.setAlgorithm("ES256")
.setBuffer(pem)
.setId(kid);
return new JWK(keyOptions);
}
}