JWKSStaticSetHandlerImpl.java

/*
 * Copyright (C) 2022 jtalbut
 *
 * This program is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program.  If not, see <http://www.gnu.org/licenses/>.
 */
package uk.co.spudsoft.jwtvalidatorvertx.impl;

import com.google.common.collect.ImmutableList;
import io.vertx.core.Future;
import io.vertx.core.Promise;
import io.vertx.core.json.JsonArray;
import io.vertx.core.json.JsonObject;
import io.vertx.ext.auth.impl.jose.JWK;
import io.vertx.ext.web.client.WebClient;
import java.time.Duration;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.concurrent.atomic.AtomicReference;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import uk.co.spudsoft.jwtvalidatorvertx.JsonWebKeySetKnownJwksHandler;
import uk.co.spudsoft.jwtvalidatorvertx.impl.AsyncLoadingCache.TimedObject;

/**
 * Implementation of {@link JsonWebKeySetKnownJwksHandler} that stores JWKs in a HashMap.
 * 
 * The cache is keyed by key ID (KID) and is populated by querying every configured JWKS URL.
 * All access to the cache is guarded by synchronizing on the map, and at most one refresh of
 * the cache is in flight at any time (concurrent lookups share the outcome of that refresh).
 * 
 * @author jtalbut
 */
public class JWKSStaticSetHandlerImpl implements JsonWebKeySetKnownJwksHandler {
  
  private static final Logger logger = LoggerFactory.getLogger(JWKSStaticSetHandlerImpl.class);
  
  private final List<String> jwksUrls;
  private final Map<String, TimedObject<JWK>> keys = new HashMap<>();
  /** The refresh currently in flight, or null when no refresh is running. */
  private final AtomicReference<Future<Void>> refreshFuture = new AtomicReference<>(null);
  
  private final OpenIdHelper openIdHelper;

  /**
   * Constructor.
   * 
   * With a static map of JWKs the security of the system is not compromised by allowing any issuer, though you should question why this is necessary.
   * 
   * Each JWKs endpoint must use KIDs that are globally unique.
   * 
   * When a KID is requested and cannot be found ALL the configured JWKS URLs will be queried and the single cache will be updated.
   * Entries in the cache will be retained for a duration based on either the Cache-Control max-age header of the response or, 
   * if that is not present, the defaultJwkCacheDuration.
   * Given that only positive responses are cached it is reasonable for the defaultJwkCacheDuration to be 24 hours (or more).
   * 
   * @param webClient Vertx WebClient instance, that will be used for querying the JWKS URLs.
   * @param jwksUrls Static set of URLs that will be used for obtaining JWKs.
   * @param defaultJwkCacheDuration Time to keep JWKs in cache if no cache-control: max-age header is found.
   * 
   * The JWKS URLs must be accessed via https for the environment to offer any security.
   * This is not enforced at the code level.
   * 
   */
  public JWKSStaticSetHandlerImpl(WebClient webClient, Collection<String> jwksUrls, Duration defaultJwkCacheDuration) {
    this.jwksUrls = ImmutableList.copyOf(jwksUrls);
    this.openIdHelper = new OpenIdHelper(webClient, defaultJwkCacheDuration.toSeconds());
  }
  
  /**
   * Trigger a lookup for the empty key ID, discarding the result.
   * 
   * Unless a key with an empty KID happens to be cached this starts a refresh of the cache
   * from all the configured JWKS URLs; any failure of that lookup is ignored here.
   */
  @Override
  public void optimize() {
    findJwk(null, "");
  }
  
  /**
   * Look up a key in the cache, evicting it if it has expired.
   * 
   * Callers must hold the monitor on {@link #keys}.
   * 
   * @param kid The key ID to look up.
   * @return The cached JWK, or null if there is no unexpired entry for the key ID.
   */
  private JWK findJwk(String kid) {
    TimedObject<JWK> jwk = keys.get(kid);
    long now = System.currentTimeMillis();
    if (null != jwk) {
      if (jwk.expiredBefore(now)) {
        keys.remove(kid);
      } else {
        return jwk.getValue();
      }
    }
    return null;
  }

  /**
   * Find the JWK with the given key ID, refreshing the cache from the JWKS URLs if it is not cached.
   * 
   * If a refresh is already in flight this call waits for that refresh rather than starting another.
   * 
   * @param issuer The issuer of the token, which is not used by this implementation.
   * @param kid The key ID to find.
   * @return A Future that succeeds with the JWK, or fails if the refresh fails 
   *         or the key is still not known after the refresh.
   */
  @Override
  public Future<JWK> findJwk(String issuer, String kid) {
    synchronized (keys) {
      JWK jwk = findJwk(kid);
      if (jwk != null) {
        return Future.succeededFuture(jwk);
      }
    
      Promise<Void> refreshPromise = Promise.promise();
      Future<Void> newRefreshFuture = refreshPromise.future();
      Future<Void> result = refreshFuture.compareAndExchange(null, newRefreshFuture);
      if (result == null) {
        // This call won the race, so it is responsible for running the refresh.
        result = newRefreshFuture;
        startRefresh(refreshPromise);
      } 
      return result.compose(v -> {
        synchronized (keys) {
          JWK newjwk = findJwk(kid);
          if (newjwk != null) {
            return Future.succeededFuture(newjwk);
          }
          return Future.failedFuture(new IllegalArgumentException("The key \"" + kid + "\" cannot be found."));
        }
      });
    }
  }
  
  /**
   * Run a refresh of the cache and report its outcome via the passed in promise.
   * 
   * Whatever the outcome, the in-flight marker is cleared and the promise is completed, 
   * so that a failed refresh cannot leave later lookups waiting forever and the next 
   * lookup for an unknown key is able to start a new refresh.
   * 
   * @param refreshPromise The promise to complete (or fail) when the refresh finishes.
   */
  private void startRefresh(Promise<Void> refreshPromise) {
    Future<Map<String, TimedObject<JWK>>> update;
    try {
      update = updateCache();
    } catch (RuntimeException ex) {
      // Treat a synchronous failure to start the requests the same as an asynchronous one.
      update = Future.failedFuture(ex);
    }
    update.onComplete(ar -> {
      if (ar.succeeded()) {
        synchronized (keys) {
          keys.putAll(ar.result());
        }
      }
      // Clear the marker before notifying waiters so that any lookup made from a waiter's callback can start a new refresh.
      refreshFuture.set(null);
      if (ar.succeeded()) {
        refreshPromise.complete();
      } else {
        refreshPromise.fail(ar.cause());
      }
    });
  }
  
  /**
   * Query all of the configured JWKS URLs and collect the keys that they return.
   * 
   * @return A Future that succeeds with a map from key ID to JWK when every URL has been processed successfully,
   *         and fails if no URLs are configured or if any URL cannot be retrieved or processed.
   */
  private Future<Map<String, TimedObject<JWK>>> updateCache() {
    
    if (jwksUrls.isEmpty()) {
      logger.error("Unable to validate any JWKs because no jwksUrls have been configured");
      IllegalStateException ex = new IllegalStateException("Unable to validate any JWKs because no jwksUrls have been configured");
      return Future.failedFuture(ex);
    }
    
    Map<String, TimedObject<JWK>> result = new HashMap<>();
    List<Future<Void>> futures = new ArrayList<>();
    
    for (String jwksUrl : jwksUrls) {
      futures.add(
              openIdHelper.get(jwksUrl)
                      .compose(tjo -> {
                        return addKeysToCache(jwksUrl, tjo, result);
                      })
      );
    }
    
    return Future.all(futures)
            .compose(cf -> {
              return Future.succeededFuture(result);
            });
  }
  
  /**
   * Parse the keys from a single JWKS response and add them to the result map.
   * 
   * Individual entries that cannot be parsed as a JWK are logged and skipped.
   * 
   * @param url The URL that the data was retrieved from, used for reporting.
   * @param data The JWKS response, along with its expiry time.
   * @param result The map (from key ID to JWK) that the keys are added to, synchronized on for each addition.
   * @return A succeeded Future if the response contained a "keys" array, otherwise a failed Future.
   */
  private Future<Void> addKeysToCache(String url, TimedObject<JsonObject> data, Map<String, TimedObject<JWK>> result) {
    try {
      Object keysObject = data.getValue().getValue("keys");
      if (keysObject instanceof JsonArray) {
        JsonArray ja = (JsonArray) keysObject;
        for (Object keyData : ja) {
          try {
            if (keyData instanceof JsonObject) {
              JsonObject jo = (JsonObject) keyData;
              String keyId = jo.getString("kid");
              JWK jwk = new JWK(jo);
              synchronized (result) {
                result.put(keyId, new TimedObject<>(jwk, data.getExpiryMs()));
              }
            }
          } catch (Throwable ex) {
            logger.warn("Failed to parse {} from {} as a JWK: ", keyData, url, ex);
          }
        }
      } else {
        logger.error("Failed to get JWKS from {} (returned value does not contain a keys array: {}))", url, data.getValue());
        return Future.failedFuture(new IllegalArgumentException("Failed to parse JWKS from " + url));
      }
    } catch (Throwable ex) {
      logger.error("Failed to get process JWKS from {} ({}): ", url, data.getValue(), ex);
      return Future.failedFuture(new IllegalArgumentException("Failed to process JWKS from " + url, ex));
    }
    return Future.succeededFuture();
  }
  
}