OpenIdHelper.java
/*
* Copyright (C) 2023 jtalbut
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package uk.co.spudsoft.jwtvalidatorvertx.impl;
import io.vertx.core.Future;
import io.vertx.core.http.HttpHeaders;
import io.vertx.core.json.JsonObject;
import io.vertx.ext.web.client.HttpResponse;
import io.vertx.ext.web.client.WebClient;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import uk.co.spudsoft.jwtvalidatorvertx.impl.AsyncLoadingCache.TimedObject;
/**
 * Helper class for performing OpenID Discovery and JWKS requests.
 * <p>
 * Each successful request yields the parsed JSON body together with an expiry time derived from
 * the {@code Cache-Control: max-age} response header, falling back to a configurable default
 * duration when no usable max-age directive is present.
 *
 * @author njt
 */
public class OpenIdHelper {

  private static final Logger logger = LoggerFactory.getLogger(OpenIdHelper.class);

  /** Client used to issue the HTTP GET requests. */
  private final WebClient webClient;

  /** Cache duration, in seconds, used when a response carries no usable max-age directive. */
  private final long defaultCacheDurationS;

  /**
   * Constructor.
   * @param webClient The Vert.x WebClient to use for making HTTP requests.
   * @param defaultCacheDurationS The default time that the caller should cache results.
   */
  public OpenIdHelper(WebClient webClient, long defaultCacheDurationS) {
    this.webClient = webClient;
    this.defaultCacheDurationS = defaultCacheDurationS;
  }

  /**
   * Check whether an HTTP status code indicates success.
   * @param statusCode The HTTP status code.
   * @return true if the status code is in the 2xx range.
   */
  private static boolean succeeded(int statusCode) {
    return statusCode >= 200 && statusCode < 300;
  }

  /**
   * Add a duration in seconds to a time in milliseconds, saturating instead of overflowing.
   * @param startMs The start time, in milliseconds since the epoch.
   * @param durationS The duration to add, in seconds.
   * @return startMs + durationS * 1000, clamped to the range of a long.
   */
  private static long addSecondsSaturated(long startMs, long durationS) {
    try {
      return Math.addExact(startMs, Math.multiplyExact(durationS, 1000L));
    } catch (ArithmeticException ex) {
      // A huge max-age (or default) means "effectively forever"; never wrap to a bogus value.
      return durationS > 0 ? Long.MAX_VALUE : Long.MIN_VALUE;
    }
  }

  /**
   * Calculate when a response should expire from the cache.
   * <p>
   * The smallest positive {@code max-age} value found in any {@code Cache-Control} header is
   * used; if there is none, {@link #defaultCacheDurationS} is used instead.
   *
   * @param requestTimeMsSinceEpoch The time at which the request was made, in ms since the epoch.
   * @param response The HTTP response whose headers should be examined.
   * @return The expiry time, in milliseconds since the epoch.
   */
  private long calculateExpiry(long requestTimeMsSinceEpoch, HttpResponse<?> response) {
    // Sentinel meaning "no usable max-age directive seen yet".
    long maxAgeSeconds = Long.MAX_VALUE;
    for (String header : response.headers().getAll(HttpHeaders.CACHE_CONTROL)) {
      for (String headerDirective : header.split(",")) {
        String[] directiveParts = headerDirective.split("=", 2);
        // Cache directive names are compared case-insensitively (RFC 9111 section 5.2).
        if (!"max-age".equalsIgnoreCase(directiveParts[0].trim())) {
          continue;
        }
        if (directiveParts.length < 2) {
          // A bare "max-age" with no value is malformed; ignore it rather than failing the request.
          logger.warn("Invalid max-age cache-control directive without a value: {}", header);
          continue;
        }
        try {
          long value = Long.parseLong(directiveParts[1].replace("\"", "").trim());
          if (value > 0 && value < maxAgeSeconds) {
            maxAgeSeconds = value;
          }
        } catch (NumberFormatException e) {
          logger.warn("Invalid max-age cache-control directive ({}): ", directiveParts[1], e);
        }
      }
    }
    // No usable max-age directive was found, so fall back to the configured default duration.
    if (maxAgeSeconds == Long.MAX_VALUE) {
      maxAgeSeconds = defaultCacheDurationS;
    }
    return addSecondsSaturated(requestTimeMsSinceEpoch, maxAgeSeconds);
  }

  /**
   * Get a JsonObject from a URL and return it as Future with an expiry time.
   * <p>
   * The returned Future fails with:
   * <ul>
   * <li>an {@link IllegalArgumentException} if the request could not be created (e.g. the URL is invalid);
   * <li>an {@link IllegalStateException} if the response status is not 2xx or the response has no body;
   * <li>whatever exception {@link JsonObject#JsonObject(String)} throws if the body is not a JSON object.
   * </ul>
   *
   * @param url The URL to be got.
   * @return A TimedObject containing JSON from the URL and an expiry time based on the Cache-Control max-age header.
   */
  public Future<TimedObject<JsonObject>> get(String url) {
    // Captured before the request so that network latency does not extend the cache lifetime.
    long requestTime = System.currentTimeMillis();
    try {
      return webClient.getAbs(url)
          .send()
          .map(response -> {
            int statusCode = response.statusCode();
            if (!succeeded(statusCode)) {
              logger.debug("Request to {} returned {}: {}", url, statusCode, response.bodyAsString());
              throw new IllegalStateException("Request to " + url + " returned " + statusCode);
            }
            String body = response.bodyAsString();
            if (body == null) {
              throw new IllegalStateException("Request to " + url + " returned " + statusCode + " with no body");
            }
            return new TimedObject<>(new JsonObject(body), calculateExpiry(requestTime, response));
          });
    } catch (Exception ex) {
      logger.error("The URL ({}) is not valid: ", url, ex);
      return Future.failedFuture(new IllegalArgumentException("The URL (" + url + ") is not valid", ex));
    }
  }
}