package org.jboss.resteasy.security.doseta;
import java.io.IOException;
import java.io.InputStream;
import java.security.AccessController;
import java.security.KeyFactory;
import java.security.PrivateKey;
import java.security.PrivilegedActionException;
import java.security.PrivilegedExceptionAction;
import java.security.PublicKey;
import java.security.spec.X509EncodedKeySpec;
import java.util.Base64;
import java.util.Hashtable;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import javax.naming.directory.Attributes;
import javax.naming.directory.DirContext;
import javax.naming.directory.InitialDirContext;
import javax.ws.rs.core.SecurityContext;
import org.jboss.resteasy.core.ResteasyContext;
import org.jboss.resteasy.security.doseta.i18n.LogMessages;
import org.jboss.resteasy.security.doseta.i18n.Messages;
import org.jboss.resteasy.util.ParameterParser;
public class DosetaKeyRepository implements KeyRepository
{
protected class CacheEntry<T>
{
public long time = System.currentTimeMillis();
public T key;
protected CacheEntry(final T key)
{
this.key = key;
}
public boolean isStale()
{
return time + cacheTimeout >= System.currentTimeMillis();
}
}
protected ConcurrentHashMap<String, CacheEntry<PrivateKey>> privateCache = new ConcurrentHashMap<String, CacheEntry<PrivateKey>>();
protected ConcurrentHashMap<String, CacheEntry<PublicKey>> publicCache = new ConcurrentHashMap<String, CacheEntry<PublicKey>>();
protected KeyStoreKeyRepository keyStore;
protected String defaultPrivateDomain;
protected boolean useDns = false;
protected boolean userPrincipalAsPrivateSelector = false;
protected String dnsUri;
protected long cacheTimeout = 3600000L;
protected String keyStorePath;
protected String keyStoreFile;
protected String keyStorePassword;
public void start()
{
if (keyStore == null)
{
if (keyStoreFile != null)
{
try
{
keyStore = new KeyStoreKeyRepository(keyStoreFile, keyStorePassword);
}
catch (IOException e)
{
throw new RuntimeException(e);
}
}
else
{
if (keyStorePath != null)
{
ClassLoader loader = null;
if (System.getSecurityManager() == null) {
loader = Thread.currentThread().getContextClassLoader();
} else {
try {
loader = AccessController.doPrivileged(new PrivilegedExceptionAction<ClassLoader>() {
@Override
public ClassLoader run() throws Exception {
return Thread.currentThread().getContextClassLoader();
}
});
} catch (PrivilegedActionException pae) {
throw new RuntimeException(pae);
}
}
InputStream is = loader.getResourceAsStream(keyStorePath.trim());
if (is == null) throw new RuntimeException(Messages.MESSAGES.unableToFindKeyStore(keyStorePath));
keyStore = new KeyStoreKeyRepository(is, keyStorePassword);
try
{
is.close();
}
catch (IOException e)
{
throw new RuntimeException(e);
}
}
}
}
}
public String getDefaultPrivateSelector()
{
if (userPrincipalAsPrivateSelector)
{
SecurityContext securityContext = ResteasyContext.getContextData(SecurityContext.class);
if (securityContext != null)
{
return securityContext.getUserPrincipal().getName();
}
}
return null;
}
public String getKeyStorePath()
{
return keyStorePath;
}
public void setKeyStorePath(String keyStorePath)
{
this.keyStorePath = keyStorePath;
}
public String getKeyStoreFile()
{
return keyStoreFile;
}
public void setKeyStoreFile(String keyStoreFile)
{
this.keyStoreFile = keyStoreFile;
}
public String getKeyStorePassword()
{
return keyStorePassword;
}
public void setKeyStorePassword(String keyStorePassword)
{
this.keyStorePassword = keyStorePassword;
}
public KeyStoreKeyRepository getKeyStore()
{
return keyStore;
}
public void setKeyStore(KeyStoreKeyRepository keyStore)
{
this.keyStore = keyStore;
}
public String getDefaultPrivateDomain()
{
return defaultPrivateDomain;
}
public void setDefaultPrivateDomain(String defaultPrivateDomain)
{
this.defaultPrivateDomain = defaultPrivateDomain;
}
public boolean isUseDns()
{
return useDns;
}
public void setUseDns(boolean useDns)
{
this.useDns = useDns;
}
public boolean isUserPrincipalAsPrivateSelector()
{
return userPrincipalAsPrivateSelector;
}
public void setUserPrincipalAsPrivateSelector(boolean userPrincipalAsPrivateSelector)
{
this.userPrincipalAsPrivateSelector = userPrincipalAsPrivateSelector;
}
public String getDnsUri()
{
return dnsUri;
}
public void setDnsUri(String dnsUri)
{
this.dnsUri = dnsUri;
}
public long getCacheTimeout()
{
return cacheTimeout;
}
public void setCacheTimeout(long cacheTimeout)
{
this.cacheTimeout = cacheTimeout;
}
protected void addPrivate(String alias, PrivateKey key)
{
privateCache.put(alias, new CacheEntry<PrivateKey>(key));
}
protected void addPublic(String alias, PublicKey key)
{
publicCache.put(alias, new CacheEntry<PublicKey>(key));
}
protected PrivateKey getPrivateCache(String alias)
{
CacheEntry<PrivateKey> entry = privateCache.get(alias);
if (entry == null) return null;
if (entry.isStale())
{
privateCache.remove(alias, entry);
return null;
}
return entry.key;
}
protected PublicKey getPublicCache(String alias)
{
CacheEntry<PublicKey> entry = publicCache.get(alias);
if (entry == null) return null;
if (entry.isStale())
{
publicCache.remove(alias, entry);
return null;
}
return entry.key;
}
public String getAlias(DKIMSignature header)
{
StringBuffer buf = new StringBuffer();
String selector = header.getSelector();
if (selector != null) buf.append(selector.trim()).append(".");
buf.append("_domainKey.");
String domain = header.getDomain();
if (domain == null)
{
throw new RuntimeException(Messages.MESSAGES.domainAttributeIsRequired());
}
buf.append(domain);
return buf.toString();
}
public PrivateKey findPrivateKey(DKIMSignature header)
{
String alias = getAlias(header);
if (alias == null) return null;
PrivateKey key = getPrivateCache(alias);
if (key != null) return key;
if (keyStore != null)
{
key = keyStore.getPrivateKey(alias);
if (key != null) addPrivate(alias, key);
}
return key;
}
public PublicKey findPublicKey(DKIMSignature header)
{
String alias = getAlias(header);
if (alias == null) return null;
PublicKey key = getPublicCache(alias);
if (key != null) return key;
if (keyStore != null)
{
key = keyStore.getPublicKey(alias);
if (key != null) addPublic(alias, key);
}
if (useDns)
{
key = findFromDns(alias);
addPublic(alias, key);
}
return key;
}
protected PublicKey findFromDns(String alias)
{
if (LogMessages.LOGGER.isDebugEnabled()) LogMessages.LOGGER.debug(Messages.MESSAGES.checkDNS(alias));
PublicKey key;
try
{
Hashtable<String, String> env = new Hashtable<String, String>();
env.put("java.naming.factory.initial", "com.sun.jndi.dns.DnsContextFactory");
if (dnsUri != null)
{
env.put("java.naming.provider.url", dnsUri);
}
DirContext dnsContext = new InitialDirContext(env);
Attributes attrs1 = dnsContext.getAttributes(alias, new String[]{"TXT"});
javax.naming.directory.Attribute txtrecord = attrs1.get("txt");
String record = txtrecord.get().toString();
dnsContext.close();
if (LogMessages.LOGGER.isDebugEnabled()) LogMessages.LOGGER.debug(Messages.MESSAGES.dnsRecordFound(record));
ParameterParser parser = new ParameterParser();
parser.setLowerCaseNames(true);
Map<String, String> keyEntry = parser.parse(record, ';');
String type = keyEntry.get("k");
if (type != null && !type.toLowerCase().equals("rsa"))
throw new RuntimeException(Messages.MESSAGES.unsupportedKeyType(type));
String pem = keyEntry.get("p");
if (pem == null)
{
throw new RuntimeException(Messages.MESSAGES.noPEntry());
}
if (LogMessages.LOGGER.isDebugEnabled()) LogMessages.LOGGER.debug(Messages.MESSAGES.pem(pem));
byte[] der = Base64.getMimeDecoder().decode(pem);
X509EncodedKeySpec spec =
new X509EncodedKeySpec(der);
KeyFactory kf = KeyFactory.getInstance("RSA");
key = kf.generatePublic(spec);
}
catch (Exception e)
{
throw new RuntimeException(Messages.MESSAGES.failedToFindPublicKey(alias), e);
}
return key;
}
}