package org.keycloak.keys.infinispan;
import java.util.Collections;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.FutureTask;
import org.infinispan.Cache;
import org.jboss.logging.Logger;
import org.keycloak.cluster.ClusterProvider;
import org.keycloak.common.util.Time;
import org.keycloak.crypto.KeyWrapper;
import org.keycloak.keys.PublicKeyLoader;
import org.keycloak.keys.PublicKeyStorageProvider;
import org.keycloak.models.KeycloakSession;
import org.keycloak.models.KeycloakTransaction;
/**
 * {@link PublicKeyStorageProvider} that caches {@link PublicKeysEntry} objects in an Infinispan
 * {@link Cache} keyed by model key.
 *
 * <p>When a requested key is missing from the cache, the keys are reloaded through the supplied
 * {@link PublicKeyLoader}. Reloads are throttled by {@code minTimeBetweenRequests}, compared against
 * {@link Time#currentTime()}. Concurrent reloads for the same model key are collapsed into a single
 * {@link FutureTask} registered in {@code tasksInProgress}.
 *
 * <p>Cache invalidations requested during a session are collected and applied after the session's
 * transaction completes, on both commit and rollback. Applying one removes the local cache entry and
 * sends a cluster notification to all data centers.
 */
public class InfinispanPublicKeyStorageProvider implements PublicKeyStorageProvider {

    private static final Logger log = Logger.getLogger(InfinispanPublicKeyStorageProvider.class);

    private final KeycloakSession session;

    private final Cache<String, PublicKeysEntry> keys;

    // NOTE(review): the putIfAbsent-based deduplication below assumes this map is shared between
    // provider instances and thread-safe (e.g. a ConcurrentMap) — confirm in the factory.
    private final Map<String, FutureTask<PublicKeysEntry>> tasksInProgress;

    private final int minTimeBetweenRequests;

    private final Set<String> invalidations = new HashSet<>();

    private boolean transactionEnlisted = false;

    public InfinispanPublicKeyStorageProvider(KeycloakSession session, Cache<String, PublicKeysEntry> keys, Map<String, FutureTask<PublicKeysEntry>> tasksInProgress, int minTimeBetweenRequests) {
        this.session = session;
        this.keys = keys;
        this.tasksInProgress = tasksInProgress;
        this.minTimeBetweenRequests = minTimeBetweenRequests;
    }

    /**
     * Schedules {@code cacheKey} for invalidation once the current session transaction completes.
     * The after-completion transaction is enlisted lazily, on the first call only.
     *
     * @param cacheKey the cache key to remove and broadcast as invalidated
     */
    void addInvalidation(String cacheKey) {
        if (!transactionEnlisted) {
            session.getTransactionManager().enlistAfterCompletion(getAfterTransaction());
            transactionEnlisted = true;
        }

        this.invalidations.add(cacheKey);
    }

    /**
     * Creates the after-completion transaction that runs the pending invalidations.
     * It runs them on both commit and rollback.
     *
     * @return a transaction whose commit and rollback both call {@link #runInvalidations()}
     */
    protected KeycloakTransaction getAfterTransaction() {
        return new KeycloakTransaction() {

            @Override
            public void begin() {
            }

            @Override
            public void commit() {
                runInvalidations();
            }

            @Override
            public void rollback() {
                // Invalidate even on rollback, mirroring commit().
                runInvalidations();
            }

            @Override
            public void setRollbackOnly() {
            }

            @Override
            public boolean getRollbackOnly() {
                return false;
            }

            @Override
            public boolean isActive() {
                return true;
            }
        };
    }

    /**
     * Applies every collected invalidation. For each key it removes the entry from the local cache,
     * then notifies the cluster with a {@link PublicKeyStorageInvalidationEvent} targeted at all DCs.
     */
    protected void runInvalidations() {
        ClusterProvider cluster = session.getProvider(ClusterProvider.class);
        for (String cacheKey : invalidations) {
            keys.remove(cacheKey);
            cluster.notify(InfinispanPublicKeyStorageProviderFactory.PUBLIC_KEY_STORAGE_INVALIDATION_EVENT, PublicKeyStorageInvalidationEvent.create(cacheKey), true, ClusterProvider.DCNotify.ALL_DCS);
        }
    }

    /**
     * Looks up a public key, reloading the model's keys if the key is not cached.
     *
     * <p>The cache is consulted first. On a miss, keys are reloaded only if more than
     * {@code minTimeBetweenRequests} has passed since the entry's last request time. A cache with no
     * entry counts as a last request time of {@code 0}. Concurrent callers for the same
     * {@code modelKey} share a single in-flight load.
     *
     * @param modelKey cache key identifying the model whose keys are requested
     * @param kid      the key id to look up; if {@code null}, the first available key is returned
     * @param loader   loader used to fetch keys when a reload is performed
     * @return the matching key, or {@code null} if it could not be found
     * @throws RuntimeException if loading fails, or if the calling thread is interrupted while
     *                          waiting for a load (the interrupt flag is restored in that case)
     */
    @Override
    public KeyWrapper getPublicKey(String modelKey, String kid, PublicKeyLoader loader) {
        // Check if key is in cache
        PublicKeysEntry entry = keys.get(modelKey);
        if (entry != null) {
            KeyWrapper publicKey = getPublicKey(entry.getCurrentKeys(), kid);
            if (publicKey != null) {
                return publicKey;
            }
        }

        int lastRequestTime = entry == null ? 0 : entry.getLastRequestTime();
        int currentTime = Time.currentTime();

        // Check if we are allowed to send request
        if (currentTime > lastRequestTime + minTimeBetweenRequests) {

            WrapperCallable wrapperCallable = new WrapperCallable(modelKey, loader);
            FutureTask<PublicKeysEntry> task = new FutureTask<>(wrapperCallable);
            // Only the thread that registers the task runs it; others wait on the registered task.
            FutureTask<PublicKeysEntry> existing = tasksInProgress.putIfAbsent(modelKey, task);

            if (existing == null) {
                task.run();
            } else {
                task = existing;
            }

            try {
                entry = task.get();

                // The loader task may return null (no cached entry and the reload was throttled).
                if (entry != null) {
                    // Computation finished. Let's see if key is available
                    KeyWrapper publicKey = getPublicKey(entry.getCurrentKeys(), kid);
                    if (publicKey != null) {
                        return publicKey;
                    }
                }

            } catch (ExecutionException ee) {
                throw new RuntimeException("Error when loading public keys", ee);
            } catch (InterruptedException ie) {
                // Restore the interrupt status so callers higher up can still observe it.
                Thread.currentThread().interrupt();
                throw new RuntimeException("Error. Interrupted when loading public keys", ie);
            } finally {
                // Our thread inserted the task. Let's clean up. Remove only the mapping we registered.
                if (existing == null) {
                    tasksInProgress.remove(modelKey, task);
                }
            }
        } else {
            log.warnf("Won't load the keys for model '%s' . Last request time was %d", modelKey, lastRequestTime);
        }

        Set<String> availableKids = entry == null ? Collections.emptySet() : entry.getCurrentKeys().keySet();
        log.warnf("PublicKey wasn't found in the storage. Requested kid: '%s' . Available kids: '%s'", kid, availableKids);

        return null;
    }

    /**
     * Selects a key from {@code publicKeys}.
     *
     * @param publicKeys keys indexed by kid
     * @param kid        the requested kid; {@code null} means "any key"
     * @return the key for {@code kid}, the first value in iteration order when {@code kid} is
     *         {@code null}, or {@code null} if there is no match
     */
    private KeyWrapper getPublicKey(Map<String, KeyWrapper> publicKeys, String kid) {
        if (kid == null) {
            // Avoid publicKeys.get(null): some Map implementations reject null keys.
            return publicKeys.isEmpty() ? null : publicKeys.values().iterator().next();
        }
        return publicKeys.get(kid);
    }

    @Override
    public void close() {

    }

    /**
     * Loads keys for one model key and stores them in the cache.
     *
     * <p>The throttle is re-checked inside the task, so a load that raced with a recent reload does
     * not hit the loader again. Returns the (possibly freshly stored) cache entry. It returns
     * {@code null} if nothing is cached and the reload is throttled.
     */
    private class WrapperCallable implements Callable<PublicKeysEntry> {

        private final String modelKey;
        private final PublicKeyLoader delegate;

        public WrapperCallable(String modelKey, PublicKeyLoader delegate) {
            this.modelKey = modelKey;
            this.delegate = delegate;
        }

        @Override
        public PublicKeysEntry call() throws Exception {
            PublicKeysEntry entry = keys.get(modelKey);

            int lastRequestTime = entry == null ? 0 : entry.getLastRequestTime();
            int currentTime = Time.currentTime();

            // Check again if we are allowed to send request. There is a chance other task was already finished and removed from tasksInProgress in the meantime.
            if (currentTime > lastRequestTime + minTimeBetweenRequests) {

                Map<String, KeyWrapper> publicKeys = delegate.loadKeys();

                if (log.isDebugEnabled()) {
                    log.debugf("Public keys retrieved successfully for model %s. New kids: %s", modelKey, publicKeys.keySet().toString());
                }

                entry = new PublicKeysEntry(currentTime, publicKeys);

                keys.put(modelKey, entry);
            }
            return entry;
        }
    }
}