JwtValidatorVertxImpl.java
package uk.co.spudsoft.jwtvalidatorvertx.impl;
import com.google.common.base.Strings;
import com.google.common.collect.ImmutableSet;
import io.vertx.core.Future;
import io.vertx.ext.auth.impl.jose.JWK;
import io.vertx.ext.auth.impl.jose.JWS;
import java.nio.charset.StandardCharsets;
import java.security.NoSuchAlgorithmException;
import java.time.Duration;
import java.time.Instant;
import java.time.LocalDateTime;
import java.time.ZoneOffset;
import java.util.Base64;
import java.util.List;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.util.HashSet;
import static java.util.Objects.requireNonNull;
import java.util.Set;
import uk.co.spudsoft.jwtvalidatorvertx.IssuerAcceptabilityHandler;
import uk.co.spudsoft.jwtvalidatorvertx.JsonWebKeySetHandler;
import uk.co.spudsoft.jwtvalidatorvertx.Jwt;
import uk.co.spudsoft.jwtvalidatorvertx.JwtValidator;
/**
 * Token validation for vertx - implementation of {@link uk.co.spudsoft.jwtvalidatorvertx.JwtValidator}.
 * <p>
 * A token is accepted only if it parses as a JWS, uses a permitted algorithm, carries payload claims,
 * has a signature that verifies against the JWK returned by the {@link JsonWebKeySetHandler},
 * and passes the iss, nbf, exp, aud and sub checks (performed in that order).
 *
 * @author Jim Talbut
 */
public class JwtValidatorVertxImpl implements JwtValidator {

  @SuppressWarnings("constantname")
  private static final Logger logger = LoggerFactory.getLogger(JwtValidatorVertxImpl.class);

  private static final Base64.Decoder B64DECODER = Base64.getUrlDecoder();

  /**
   * The full set of algorithms that this validator can ever be configured to accept.
   * It is also the initial value of the permitted set for a new instance.
   */
  private static final Set<String> DEFAULT_PERMITTED_ALGS = ImmutableSet.of(
          JWS.EdDSA,
          JWS.ES256,
          JWS.ES384,
          JWS.ES512,
          JWS.PS256,
          JWS.PS384,
          JWS.PS512,
          JWS.ES256K,
          JWS.RS256,
          JWS.RS384,
          JWS.RS512
  );

  private Set<String> permittedAlgs;
  private boolean requireExp = true;
  private boolean requireNbf = true;
  private long timeLeewayMilliseconds = 0;

  private final JsonWebKeySetHandler jsonWebKeySetHandler;
  private final IssuerAcceptabilityHandler issuerAcceptabilityHandler;

  /**
   * Constructor.
   * @param jsonWebKeySetHandler Handler for obtaining JWKs
   * @param issuerAcceptabilityHandler Handler for validating issuers found in the JWT.
   */
  public JwtValidatorVertxImpl(JsonWebKeySetHandler jsonWebKeySetHandler, IssuerAcceptabilityHandler issuerAcceptabilityHandler) {
    this.jsonWebKeySetHandler = jsonWebKeySetHandler;
    this.issuerAcceptabilityHandler = issuerAcceptabilityHandler;
    this.permittedAlgs = new HashSet<>(DEFAULT_PERMITTED_ALGS);
  }

  /**
   * Get an immutable snapshot of the algorithms currently permitted.
   * @return an immutable copy of the permitted algorithms.
   */
  @Override
  public Set<String> getPermittedAlgorithms() {
    return ImmutableSet.copyOf(permittedAlgs);
  }

  /**
   * Replace the set of permitted algorithms.
   * <p>
   * The current set is left untouched if any of the supplied algorithms is rejected.
   *
   * @param algorithms the algorithms to permit, each of which must be one of the supported algorithms.
   * @return this, so that the method may be used in a fluent manner.
   * @throws NoSuchAlgorithmException if any of the algorithms is not supported by this validator.
   */
  @Override
  public JwtValidator setPermittedAlgorithms(Set<String> algorithms) throws NoSuchAlgorithmException {
    requireNonNull(algorithms, "algorithms");
    // Build the replacement separately so that a failure part way through does not alter the current configuration.
    Set<String> copy = new HashSet<>();
    for (String alg : algorithms) {
      if (!DEFAULT_PERMITTED_ALGS.contains(alg)) {
        throw new NoSuchAlgorithmException("Algorithm \"" + alg + "\" is not supported");
      }
      copy.add(alg);
    }
    this.permittedAlgs = copy;
    return this;
  }

  /**
   * Add a single algorithm to the set of permitted algorithms.
   *
   * @param algorithm the algorithm to permit, which must be one of the supported algorithms.
   * @return this, so that the method may be used in a fluent manner.
   * @throws NoSuchAlgorithmException if the algorithm is not supported by this validator.
   */
  @Override
  public JwtValidator addPermittedAlgorithm(String algorithm) throws NoSuchAlgorithmException {
    if (!DEFAULT_PERMITTED_ALGS.contains(algorithm)) {
      throw new NoSuchAlgorithmException("Algorithm \"" + algorithm + "\" is not supported");
    }
    permittedAlgs.add(algorithm);
    return this;
  }

  /**
   * Set the tolerance applied when comparing the exp and nbf claims with the current time.
   * <p>
   * A token is treated as expired only when its exp is earlier than (now - leeway),
   * and as not yet valid only when its nbf is later than (now + leeway).
   *
   * @param timeLeeway the tolerance to apply to the exp and nbf checks.
   * @return this, so that the method may be used in a fluent manner.
   */
  @Override
  public JwtValidator setTimeLeeway(Duration timeLeeway) {
    requireNonNull(timeLeeway, "timeLeeway");
    this.timeLeewayMilliseconds = timeLeeway.toMillis();
    return this;
  }

  /**
   * Set to true if the token is required to have an exp claim.
   * @param requireExp true if the token is required to have an exp claim.
   * @return this, so that the method may be used in a fluent manner.
   */
  @Override
  public JwtValidator setRequireExp(boolean requireExp) {
    this.requireExp = requireExp;
    return this;
  }

  /**
   * Set to true if the token is required to have an nbf claim.
   * @param requireNbf true if the token is required to have an nbf claim.
   * @return this, so that the method may be used in a fluent manner.
   */
  @Override
  public JwtValidator setRequireNbf(boolean requireNbf) {
    this.requireNbf = requireNbf;
    return this;
  }

  /**
   * Validate the token and either fail the returned Future or complete it with the token's constituent parts.
   *
   * @param issuer The issuer used to look up the JWK. If not null the iss claim in the token must equal it.
   * @param token The token.
   * @param requiredAudList List of audiences, at least one of which must be claimed by the token.
   *        Must not be null, and must not be empty unless ignoreRequiredAud is true.
   * @param ignoreRequiredAud Do not check for required audiences (the token must still have an aud claim).
   * @return A Future that is completed with the token's parts if every check passes, and is failed otherwise.
   */
  @Override
  public Future<Jwt> validateToken(
          String issuer,
          String token,
          List<String> requiredAudList,
          boolean ignoreRequiredAud
  ) {

    Jwt jwt;
    try {
      jwt = Jwt.parseJws(token);
    } catch (Throwable ex) {
      // The raw token is only written to the log when trace logging has been enabled.
      if (logger.isTraceEnabled()) {
        logger.error("Parse of JWT ({}) failed: ", token, ex);
      } else {
        logger.error("Parse of JWT failed: ", ex);
      }
      return Future.failedFuture(new IllegalArgumentException("Parse of signed JWT failed", ex));
    }

    try {
      // Reject unacceptable algorithms before doing any key lookup.
      validateAlgorithm(jwt.getAlgorithm());
      String kid = jwt.getKid();

      if (jwt.getPayloadSize() == 0) {
        logger.error("No payload claims found in JWT");
        return Future.failedFuture(new IllegalArgumentException("Parse of signed JWT failed"));
      }

      return jsonWebKeySetHandler.findJwk(issuer, kid)
              .onFailure(ex -> {
                logger.warn("Failed to find JWK for {} ({}): ", kid, issuer, ex);
              })
              .compose(jwk -> {
                try {
                  // The signature is verified before any of the claims are examined.
                  verify(jwk, jwt);

                  long now = System.currentTimeMillis();
                  validateIssuer(jwt, issuer);
                  validateNbf(jwt, now);
                  validateExp(jwt, now);
                  validateAud(jwt, requiredAudList, ignoreRequiredAud);
                  validateSub(jwt);

                  return Future.succeededFuture(jwt);
                } catch (Throwable ex) {
                  logger.info("Validation of {} token failed: ", jwt.getAlgorithm(), ex);
                  return Future.failedFuture(new IllegalArgumentException("Validation of " + jwt.getAlgorithm() + " signed JWT failed", ex));
                }
              });
    } catch (Throwable ex) {
      logger.error("Failed to process token: ", ex);
      return Future.failedFuture(ex);
    }
  }

  /**
   * Check that the token has an issuer, that the issuer is acceptable, and that it matches the expected issuer (if any).
   *
   * @param jwt The token being validated.
   * @param externalIssuer The issuer that the caller expects, or null if the caller has no expectation.
   * @throws IllegalStateException if any of the issuer checks fail.
   */
  private void validateIssuer(Jwt jwt, String externalIssuer) {
    String tokenIssuer = jwt.getIssuer();
    // empty issuer is never allowed
    if (Strings.isNullOrEmpty(tokenIssuer)) {
      throw new IllegalStateException("No issuer in token.");
    }
    if (!issuerAcceptabilityHandler.isAcceptable(tokenIssuer)) {
      throw new IllegalStateException("Issuer from token (" + tokenIssuer + ") is not acceptable.");
    }
    if (externalIssuer != null && !externalIssuer.equals(tokenIssuer)) {
      throw new IllegalStateException("Issuer from token (" + tokenIssuer + ") does not match expected issuer (" + externalIssuer + ").");
    }
  }

  /**
   * Verify the signature of the token using the given JWK.
   *
   * @param jwk The key to verify the signature with.
   * @param jwt The token being validated.
   * @throws IllegalStateException if the token has no signature or the JWK is for the "none" algorithm.
   * @throws NullPointerException if the JWK is null.
   * @throws IllegalArgumentException if the signature is not valid, or cannot be checked.
   */
  private void verify(JWK jwk, Jwt jwt) throws IllegalArgumentException {
    // empty signature is never allowed
    if (Strings.isNullOrEmpty(jwt.getSignature())) {
      throw new IllegalStateException("No signature in token.");
    }
    requireNonNull(jwk, "JWK not set");
    // if we only allow secure alg, then none is not a valid option
    if ("none".equals(jwk.getAlgorithm())) {
      throw new IllegalStateException("Algorithm \"none\" not allowed");
    }

    byte[] signatureBytes = B64DECODER.decode(jwt.getSignature());
    byte[] signingInput = jwt.getSignatureBase().getBytes(StandardCharsets.UTF_8);

    boolean verified;
    try {
      verified = new JWS(jwk).verify(signatureBytes, signingInput);
    } catch (Throwable ex) {
      logger.warn("Signature verification failed: ", ex);
      throw new IllegalArgumentException("Signature verification failed", ex);
    }
    // Reported outside the try block so that a simple mismatch is not caught and wrapped in a second, identical, exception.
    if (!verified) {
      logger.warn("Signature verification failed: signature does not match");
      throw new IllegalArgumentException("Signature verification failed");
    }
  }

  /**
   * Check that the token has a non-empty sub claim.
   *
   * @param jwt The token being validated.
   * @throws IllegalArgumentException if the token has no subject.
   */
  private void validateSub(Jwt jwt) throws IllegalArgumentException {
    if (Strings.isNullOrEmpty(jwt.getSubject())) {
      throw new IllegalArgumentException("No subject specified in token");
    }
  }

  /**
   * Check that the token has an aud claim and, unless told to ignore them, that it claims one of the required audiences.
   *
   * @param jwt The token being validated.
   * @param requiredAudList The audiences, at least one of which must appear in the aud claim of the token.
   * @param ignoreRequiredAud If true the contents of the aud claim are not compared with requiredAudList.
   * @throws IllegalStateException if requiredAudList is null, or is empty when it is not being ignored.
   * @throws IllegalArgumentException if the token has no aud claim, or does not claim any required audience.
   */
  private void validateAud(Jwt jwt, List<String> requiredAudList, boolean ignoreRequiredAud) throws IllegalArgumentException {
    if ((requiredAudList == null) || (!ignoreRequiredAud && requiredAudList.isEmpty())) {
      throw new IllegalStateException("Required audience not set");
    }
    if (jwt.getAudience() == null) {
      throw new IllegalArgumentException("Token does not include aud claim");
    }
    if (ignoreRequiredAud) {
      // Nothing further can cause a failure, so there is no point in searching for a match.
      return;
    }
    for (String aud : jwt.getAudience()) {
      for (String requiredAud : requiredAudList) {
        if (requiredAud.equals(aud)) {
          return;
        }
      }
    }
    if (requiredAudList.size() == 1) {
      logger.warn("Required audience ({}) not found in token aud claim: {}", requiredAudList.get(0), jwt.getAudience());
    } else {
      logger.warn("None of the required audiences ({}) found in token aud claim: {}", requiredAudList, jwt.getAudience());
    }
    throw new IllegalArgumentException("Required audience not found in token");
  }

  /**
   * Check that the token has not expired, allowing for the configured leeway.
   *
   * @param jwt The token being validated.
   * @param now The current time, in milliseconds since the epoch.
   * @throws IllegalArgumentException if the token has expired, or has no exp claim when one is required.
   */
  private void validateExp(Jwt jwt, long now) throws IllegalArgumentException {
    if (jwt.getExpiration() != null) {
      long targetMs = now - timeLeewayMilliseconds;
      long expMs = secondsToMillis(jwt.getExpiration());
      if (expMs < targetMs) {
        logger.warn("Token exp = {} ({}), now = {} ({}), target = {} ({})", jwt.getExpiration(), jwt.getExpirationLocalDateTime(), now, LocalDateTime.ofInstant(Instant.ofEpochMilli(now), ZoneOffset.UTC), targetMs, LocalDateTime.ofInstant(Instant.ofEpochMilli(targetMs), ZoneOffset.UTC));
        throw new IllegalArgumentException("Token is not valid after " + jwt.getExpirationLocalDateTime());
      }
    } else if (requireExp) {
      throw new IllegalArgumentException("Token does not specify exp");
    }
  }

  /**
   * Check that the token is already valid, allowing for the configured leeway.
   *
   * @param jwt The token being validated.
   * @param now The current time, in milliseconds since the epoch.
   * @throws IllegalArgumentException if the token is not yet valid, or has no nbf claim when one is required.
   */
  private void validateNbf(Jwt jwt, long now) throws IllegalArgumentException {
    if (jwt.getNotBefore() != null) {
      long targetMs = now + timeLeewayMilliseconds;
      long nbfMs = secondsToMillis(jwt.getNotBefore());
      if (nbfMs > targetMs) {
        logger.warn("Token nbf = {} ({}), now = {} ({}), target = {} ({})", jwt.getNotBefore(), jwt.getNotBeforeLocalDateTime(), now, LocalDateTime.ofInstant(Instant.ofEpochMilli(now), ZoneOffset.UTC), targetMs, LocalDateTime.ofInstant(Instant.ofEpochMilli(targetMs), ZoneOffset.UTC));
        throw new IllegalArgumentException("Token is not valid until " + jwt.getNotBeforeLocalDateTime());
      }
    } else if (requireNbf) {
      throw new IllegalArgumentException("Token does not specify nbf");
    }
  }

  /**
   * Convert a time in seconds to milliseconds, clamping to the range of a long rather than overflowing.
   * <p>
   * The claims come from the token, so a very large value must not be allowed to wrap around
   * and end up on the wrong side of the comparison with the current time.
   *
   * @param seconds The time in seconds.
   * @return The time in milliseconds, or Long.MAX_VALUE/Long.MIN_VALUE if it cannot be represented.
   */
  private static long secondsToMillis(long seconds) {
    if (seconds > Long.MAX_VALUE / 1000) {
      return Long.MAX_VALUE;
    }
    if (seconds < Long.MIN_VALUE / 1000) {
      return Long.MIN_VALUE;
    }
    return seconds * 1000;
  }

  /**
   * Check that the algorithm from the token is one of those currently permitted.
   *
   * @param algorithm The algorithm specified in the token.
   * @throws IllegalArgumentException if the algorithm is null or is not permitted.
   */
  private void validateAlgorithm(String algorithm) throws IllegalArgumentException {
    if (algorithm == null) {
      logger.warn("No signature algorithm in token.");
      throw new IllegalArgumentException("Parse of signed JWT failed");
    }
    if (!permittedAlgs.contains(algorithm)) {
      logger.warn("Failed to find algorithm \"{}\" in {}", algorithm, permittedAlgs);
      throw new IllegalArgumentException("Parse of signed JWT failed");
    }
  }

}