VertxJwksHandler.java
/*
* Copyright (C) 2025 njt
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package uk.co.spudsoft.jwtvalidatorvertx.vertx;
import com.google.common.cache.Cache;
import io.vertx.core.Handler;
import io.vertx.core.Vertx;
import io.vertx.core.buffer.Buffer;
import io.vertx.core.http.HttpMethod;
import io.vertx.core.http.HttpServer;
import io.vertx.core.http.HttpServerRequest;
import io.vertx.core.json.JsonArray;
import io.vertx.core.json.JsonObject;
import io.vertx.ext.web.Router;
import io.vertx.ext.web.RoutingContext;
import java.io.IOException;
import java.net.ServerSocket;
import java.security.PublicKey;
import java.util.Base64;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import uk.co.spudsoft.jwtvalidatorvertx.AlgorithmAndKeyPair;
import uk.co.spudsoft.jwtvalidatorvertx.JsonWebAlgorithm;
import uk.co.spudsoft.jwtvalidatorvertx.JwkBuilder;
import uk.co.spudsoft.jwtvalidatorvertx.JwksHandler;
import uk.co.spudsoft.jwtvalidatorvertx.TokenBuilder;
import uk.co.spudsoft.jwtvalidatorvertx.jdk.JdkJwksHandler;
import uk.co.spudsoft.jwtvalidatorvertx.jdk.JdkTokenBuilder;
/**
* An implementation of JwksHandler as a Vertx {@link Handler}<{@link RoutingContext}>.
*
* @author njt
*/
public class VertxJwksHandler implements Handler<RoutingContext>, JwksHandler {
@SuppressWarnings("constantname")
private static final Logger logger = LoggerFactory.getLogger(JdkJwksHandler.class);
private static final Base64.Encoder BASE64 = Base64.getUrlEncoder().withoutPadding();
private final Vertx vertx;
private final HttpServer httpServer;
private final String host;
private final int port;
private final String basePath;
private final String issuer;
private final String configUrl;
private final String jwksUrl;
private final String tokenUrl;
private final boolean withTokenBuilder;
private Cache<String, AlgorithmAndKeyPair> keyCache;
private TokenBuilder tokenBuilder = null;
@Override
public void setKeyCache(Cache<String, AlgorithmAndKeyPair> keyCache) {
this.keyCache = keyCache;
if (withTokenBuilder) {
tokenBuilder = new JdkTokenBuilder(keyCache);
}
}
@Override
public String getBaseUrl() {
return "http://localhost:" + port + basePath;
}
/**
* Get the port that the handler is listening on.
* @return the port that the handler is listening on.
*/
public int getPort() {
return port;
}
/**
* Factory method to create a new VertxJwksHandler.
*
* Note that this differs from the constructor in that it creates a dedicated Vertx instance and {@link HttpServer}.
* If this factory method is used then {@link #start()} must be called, if the constructor is called directly then
* the caller may choose to control the lifetime of the HttpServer manually.
*
* The URL generated from the host, post and basePath is the issuer that will be used in tokens.
*
* @param host The hostname to use - typically this will just be localhost.
* @param port The port to use - may be 0 to choose a random available port.
* @param basePath The path to use - should being with a slash but not end with one.
* @param withTokenBuilder If true provide an endpoint for creating test tokens.
* @return A newly created (but not yet active) VertxJwksHandler.
* @throws IOException if port <= 0 and unable to find an available port.
*/
public static VertxJwksHandler create(String host, int port, String basePath, boolean withTokenBuilder) throws IOException {
if (port <= 0) {
try (ServerSocket s = new ServerSocket(0)) {
port = s.getLocalPort();
}
}
Vertx vertx = Vertx.vertx();
HttpServer httpServer = vertx.createHttpServer();
Router router = Router.router(vertx);
VertxJwksHandler handler = new VertxJwksHandler(vertx, httpServer, host, port, basePath, withTokenBuilder);
router.route(basePath + "/*").handler(handler);
httpServer.requestHandler(router);
return handler;
}
static String checkPath(String path) {
if (!path.startsWith("/")) {
path = "/" + path;
}
if (path.endsWith("/")) {
path = path.substring(0, path.length() - 1);
}
return path;
}
/**
* Constructor.
*
* The URL generated from the host, post and basePath is the issuer that will be used in tokens.
*
* @param vertx The Vertx instance to be owned by the handler.
* Pass in null if the lifetime of the Vertx instance is not managed by this handler.
* @param httpServer The Vertx server to be owned by the handler.
* Pass in null if the lifetime of the Vertx server is not managed by this handler.
* @param host The hostname to use in URLs generated and provided to clients.
* @param port The port to use in URLs generated and provided to clients.
* Also the port to listen on if the httpServer is managed by this handler.
* This must not be zero.
* @param basePath The path to use in URLs generated and provided to clients.
* @param withTokenBuilder If true provide an endpoint for creating test tokens.
*
*/
public VertxJwksHandler(Vertx vertx, HttpServer httpServer, String host, int port, String basePath, boolean withTokenBuilder) {
this.vertx = vertx;
this.httpServer = httpServer;
this.host = host;
this.port = port;
this.basePath = checkPath(basePath);
this.issuer = "http://" + host + ":" + port + basePath;
this.configUrl = "http://" + host + ":" + port + basePath + "/.well-known/openid-configuration";
this.jwksUrl = "http://" + host + ":" + port + basePath + "/jwks";
this.tokenUrl = withTokenBuilder ? "http://" + host + ":" + port + basePath + "/token" : "";
this.withTokenBuilder = withTokenBuilder;
}
@Override
public void start() {
if (httpServer != null) {
httpServer.listen(port);
}
}
@Override
public void close() throws IOException {
if (httpServer != null) {
httpServer.close();
vertx.close();
}
}
private void sendResponse(RoutingContext exchange, int responseCode, String contentType, Buffer body) {
exchange.response()
.setStatusCode(responseCode)
.putHeader("Content-Type", contentType)
.end(body);
}
@Override
public void handle(RoutingContext exchange) {
HttpServerRequest request = exchange.request();
String url = request.absoluteURI();
logger.debug("handle {} {}", request.method(), url);
if (request.method() == HttpMethod.GET) {
if (configUrl.equals(url)) {
handleConfigRequest(exchange);
} else if (jwksUrl.equals(url)) {
handleJwksRequest(exchange);
} else {
exchange.next();
}
} else if (request.method() == HttpMethod.PUT) {
if (tokenUrl.equals(url)) {
handleTokenRequest(exchange);
} else {
exchange.next();
}
} else {
exchange.next();
}
}
private void handleConfigRequest(RoutingContext exchange) {
JsonObject config = new JsonObject();
config.put("jwks_uri", jwksUrl);
sendResponse(exchange, 200, "application/json", config.toBuffer());
}
private void handleJwksRequest(RoutingContext exchange) {
JsonObject jwkSet = new JsonObject();
JsonArray jwks = new JsonArray();
jwkSet.put("keys", jwks);
synchronized (keyCache) {
keyCache.asMap().forEach((kid, akp) -> {
PublicKey key = akp.getKeyPair().getPublic();
try {
JsonObject json = JwkBuilder.get(key).toJson(kid, akp.getAlgorithm().getName(), key);
jwks.add(json);
} catch (Exception ex) {
logger.warn("Failed to add key {} to JWKS: ", kid, ex);
}
});
}
exchange.response().putHeader("cache-control", "max-age=100");
sendResponse(exchange, 200, "application/json", jwkSet.toBuffer());
}
private void handleTokenRequest(RoutingContext exchange) {
HttpServerRequest request = exchange.request();
request.body()
.andThen(ar -> {
if (ar.succeeded()) {
try {
JsonObject body = ar.result().length() > 0 ? ar.result().toJsonObject() : new JsonObject();
long nbf = (System.currentTimeMillis() / 1000);
long exp = nbf + 60 * 60;
String token = tokenBuilder.buildToken(JsonWebAlgorithm.RS512, "test", this.issuer, null, null, nbf, exp, body.getMap());
sendResponse(exchange, 200, "text/plain", Buffer.buffer(token));
} catch (Throwable ex) {
logger.debug("Failed to get request body as json: ", ar.cause());
sendResponse(exchange, 400, "text/plain", Buffer.buffer("Failed to read request body"));
}
} else {
logger.debug("Failed to get request body: ", ar.cause());
sendResponse(exchange, 500, "text/plain", Buffer.buffer("Failed to get request body"));
}
});
}
}