JwtValidatorVertxImpl.java

package uk.co.spudsoft.jwtvalidatorvertx.impl;

import com.google.common.base.Strings;
import io.vertx.core.Future;
import java.nio.charset.StandardCharsets;
import java.time.LocalDateTime;
import java.time.ZoneOffset;
import java.util.Base64;
import java.util.List;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import uk.co.spudsoft.jwtvalidatorvertx.JWK;
import uk.co.spudsoft.jwtvalidatorvertx.JWT;
import uk.co.spudsoft.jwtvalidatorvertx.JsonWebAlgorithm;

import java.util.EnumSet;
import uk.co.spudsoft.jwtvalidatorvertx.JsonWebKeySetHandler;
import uk.co.spudsoft.jwtvalidatorvertx.JwtValidatorVertx;

/**
 * Token validation for vertx - implementation of {@link uk.co.spudsoft.jwtvalidatorvertx.JwtValidatorVertx}.
 * @author Jim Talbut
 */
public class JwtValidatorVertxImpl implements JwtValidatorVertx {

  @SuppressWarnings("constantname")
  private static final Logger logger = LoggerFactory.getLogger(JwtValidatorVertxImpl.class);

  private static final Base64.Decoder DECODER = Base64.getUrlDecoder();

  private static final EnumSet<JsonWebAlgorithm> DEFAULT_PERMITTED_ALGS = EnumSet.of(
          JsonWebAlgorithm.RS256, JsonWebAlgorithm.RS384, JsonWebAlgorithm.RS512
  );
  
  private EnumSet<JsonWebAlgorithm> permittedAlgs;
  
  private boolean requireExp = true;
  private boolean requireNbf = true;
  
  private long timeLeewaySeconds = 0;
  
  private final JsonWebKeySetHandler openIdDiscoveryHandler;
  
  /**
   * Constructor.
   * @param openIdDiscoveryHandler         Handler for obtaining OpenIdDiscovery data and JWKs
   */
  public JwtValidatorVertxImpl(JsonWebKeySetHandler openIdDiscoveryHandler) {
    this.openIdDiscoveryHandler = openIdDiscoveryHandler;
    this.permittedAlgs = EnumSet.copyOf(DEFAULT_PERMITTED_ALGS);
  }

  @Override
  public EnumSet<JsonWebAlgorithm> getPermittedAlgorithms() {
    return EnumSet.copyOf(permittedAlgs);
  }

  @Override
  public JwtValidatorVertx setPermittedAlgorithms(EnumSet<JsonWebAlgorithm> algorithms) {
    this.permittedAlgs = EnumSet.copyOf(algorithms);
    return this;
  }

  @Override
  public JwtValidatorVertx addPermittedAlgorithm(JsonWebAlgorithm algorithm) {
    this.permittedAlgs.add(algorithm);
    return this;
  }
  
  /**
   * Set the maximum amount of time that can pass between the exp and now.
   * @param timeLeewaySeconds the maximum amount of time that can pass between the exp and now.
   */
  @Override
  public JwtValidatorVertx setTimeLeewaySeconds(long timeLeewaySeconds) {
    this.timeLeewaySeconds = timeLeewaySeconds;
    return this;
  }

  /**
   * Set to true if the token is required to have an exp claim.
   * @param requireExp true if the token is required to have an exp claim.
   */
  @Override
  public JwtValidatorVertx setRequireExp(boolean requireExp) {
    this.requireExp = requireExp;
    return this;
  }

  /**
   * Set to true if the token is required to have an nbf claim.
   * @param requireNbf true if the token is required to have an nbf claim.
   */
  @Override
  public JwtValidatorVertx setRequireNbf(boolean requireNbf) {
    this.requireNbf = requireNbf;
    return this;
  }
  
  /**
   * Validate the token and either throw an exception or return it's constituent parts.
   * @param token             The token.
   * @param requiredAudList   List of audiences, all of which must be claimed by the token. If null the defaultRequiredAud is used.
   * @param ignoreRequiredAud Do not check for required audiences.
   * @return The token's parts.
   */
  @Override
  public Future<JWT> validateToken(
          String token
          , List<String> requiredAudList
          , boolean ignoreRequiredAud
  ) {
    
    JWT jwt;
    try {
      jwt = JWT.parseJws(token);
    } catch (Throwable ex) {
      logger.error("Parse of JWT failed: ", ex);
      return Future.failedFuture(new IllegalArgumentException("Parse of signed JWT failed", ex));
    }

    try {
      JsonWebAlgorithm jwa = validateAlgorithm(jwt.getAlgorithm());
      jwt.getKid();

      if (jwt.getPayloadSize() == 0) {
        logger.error("No payload claims found in JWT");
        return Future.failedFuture(new IllegalArgumentException("Parse of signed JWT failed"));
      }
      String issuer = jwt.getIssuer();

      openIdDiscoveryHandler.validateIssuer(issuer);

      return jwt.getJwk(openIdDiscoveryHandler)
              .compose(jwk -> {
                try {
                  verify(jwa, jwk, jwt);

                  long nowSeconds = System.currentTimeMillis() / 1000;
                  validateNbf(jwt, nowSeconds);
                  validateExp(jwt, nowSeconds);
                  validateAud(jwt, requiredAudList, ignoreRequiredAud);
                  validateSub(jwt);

                  return Future.succeededFuture(jwt);
                } catch (Throwable ex) {
                  logger.info("Validation of {} token failed: ", jwt.getAlgorithm(), ex);
                  return Future.failedFuture(new IllegalArgumentException("Validation of " + jwt.getAlgorithm() + " signed JWT failed", ex));
                }
              });
    } catch (Throwable ex) {
      logger.error("Failed to process token: ", ex);
      return Future.failedFuture(ex);
    }
  }

  private void verify(JsonWebAlgorithm jwa, JWK jwk, JWT jwt) throws IllegalArgumentException {

    // empty signature is never allowed
    if (Strings.isNullOrEmpty(jwt.getSignature())) {
      throw new IllegalStateException("No signature in token.");
    }

    // if we only allow secure alg, then none is not a valid option
    if (JsonWebAlgorithm.none == jwa) {
      throw new IllegalStateException("Algorithm \"none\" not allowed");
    }

    byte[] payloadInput = Base64.getUrlDecoder().decode(jwt.getSignature());

    byte[] signingInput = jwt.getSignatureBase().getBytes(StandardCharsets.UTF_8);

    try {
      if (!jwk.verify(jwa, payloadInput, signingInput)) {
        throw new IllegalArgumentException("Signature verification failed");
      }
    } catch (Throwable ex) {
      logger.warn("Signature verification failed: ", ex);
      throw new IllegalArgumentException("Signature verification failed", ex);
    }
  }

  private void validateSub(JWT jwt) throws IllegalArgumentException {
    if (Strings.isNullOrEmpty(jwt.getSubject())) {
      throw new IllegalArgumentException("No subject specified in token");
    }
  }

  private void validateAud(JWT jwt, List<String> requiredAudList, boolean ignoreRequiredAud) throws IllegalArgumentException {
    if ((requiredAudList == null) || (!ignoreRequiredAud && requiredAudList.isEmpty())) {
      throw new IllegalStateException("Required audience not set");
    }
    if (jwt.getAudience() == null) {
      throw new IllegalArgumentException("Token does not include aud claim");
    }
    for (String aud : jwt.getAudience()) {
      for (String requiredAud : requiredAudList) {
        if (requiredAud.equals(aud)) {
          return;
        }
      }
    }
    if (!ignoreRequiredAud) {
      if (requiredAudList.size() == 1) {
        logger.warn("Required audience ({}) not found in token aud claim: {}", requiredAudList.get(0), jwt.getAudience());
      } else {
        logger.warn("None of the required audiences ({}) found in token aud claim: {}", requiredAudList, jwt.getAudience());
      }
      throw new IllegalArgumentException("Required audience not found in token");
    }
  }

  private void validateExp(JWT jwt, long nowSeconds) throws IllegalArgumentException {
    if (jwt.getExpiration() != null) {
      long target = nowSeconds - timeLeewaySeconds;
      if (jwt.getExpiration() < target) {
        logger.warn("Token exp = {} ({}), now = {} ({}), target = {} ({})", jwt.getExpiration(), jwt.getExpirationLocalDateTime(), nowSeconds, LocalDateTime.ofEpochSecond(nowSeconds, 0, ZoneOffset.UTC), target, LocalDateTime.ofEpochSecond(target, 0, ZoneOffset.UTC));
        throw new IllegalArgumentException("Token is not valid after " + jwt.getExpirationLocalDateTime());
      }
    } else if (requireExp) {
      throw new IllegalArgumentException("Token does not specify exp");
    }
  }

  private void validateNbf(JWT jwt, long nowSeconds) throws IllegalArgumentException {
    if (jwt.getNotBefore() != null) {
      long target = nowSeconds + timeLeewaySeconds;
      if (jwt.getNotBefore() > target) {
        throw new IllegalArgumentException("Token is not valid until " + jwt.getNotBeforeLocalDateTime());
      }
    } else if (requireNbf) {
      throw new IllegalArgumentException("Token does not specify exp");
    }
  }

  private JsonWebAlgorithm validateAlgorithm(String algorithm) {
    if (algorithm == null) {
      logger.warn("No signature algorithm in token.");
      throw new IllegalArgumentException("Parse of signed JWT failed");
    }
    JsonWebAlgorithm jwa;
    try {
      jwa = JsonWebAlgorithm.valueOf(algorithm);
    } catch (Throwable ex) {
      logger.warn("Failed to parse algorithm \"{}\"", algorithm);
      throw new IllegalArgumentException("Parse of signed JWT failed");
    }
    if (!permittedAlgs.contains(jwa)) {
      logger.warn("Failed to find algorithm \"{}\" in {}", algorithm, permittedAlgs);
      throw new IllegalArgumentException("Parse of signed JWT failed");
    }
    return jwa;
  }

}