/*
* Copyright 2010-2013 Ning, Inc.
*
* Ning licenses this file to you under the Apache License, version 2.0
* (the "License"); you may not use this file except in compliance with the
* License. You may obtain a copy of the License at:
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*/
package org.killbill.billing.util.security.shiro.realm;
import java.util.Collection;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Map;
import java.util.Set;
import javax.naming.NamingEnumeration;
import javax.naming.NamingException;
import javax.naming.directory.Attribute;
import javax.naming.directory.SearchControls;
import javax.naming.directory.SearchResult;
import javax.naming.ldap.LdapContext;
import org.apache.shiro.authc.AuthenticationException;
import org.apache.shiro.authz.AuthorizationInfo;
import org.apache.shiro.authz.SimpleAuthorizationInfo;
import org.apache.shiro.config.Ini;
import org.apache.shiro.config.Ini.Section;
import org.apache.shiro.realm.ldap.JndiLdapContextFactory;
import org.apache.shiro.realm.ldap.JndiLdapRealm;
import org.apache.shiro.realm.ldap.LdapContextFactory;
import org.apache.shiro.realm.ldap.LdapUtils;
import org.apache.shiro.subject.PrincipalCollection;
import org.killbill.billing.util.config.definition.SecurityConfig;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Function;
import com.google.common.base.Functions;
import com.google.common.base.Predicate;
import com.google.common.base.Predicates;
import com.google.common.base.Splitter;
import com.google.common.collect.Collections2;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Iterators;
import com.google.common.collect.Maps;
import com.google.common.collect.Sets;
import com.google.inject.Inject;
/**
 * Shiro LDAP realm configured from {@link SecurityConfig}.
 * <p>
 * Authentication binds with a user DN that is either built from the configured
 * user DN template (parent class behavior) or, when a DN search template is
 * configured, looked up in the directory using the system LDAP context.
 * <p>
 * Authorization searches the directory with the configured group search filter,
 * reads the group names from the attribute identified by the configured group
 * name id, exposes them as roles and maps them to permissions using the
 * "permissions by group" configuration.
 */
public class KillBillJndiLdapRealm extends JndiLdapRealm {

    private static final Logger log = LoggerFactory.getLogger(KillBillJndiLdapRealm.class);

    /** Placeholder replaced by the (escaped) user name in the configured search filters. */
    private static final String USERDN_SUBSTITUTION_TOKEN = "{0}";

    /** Shared search controls: all searches issued by this realm cover the whole subtree. */
    private static final SearchControls SUBTREE_SCOPE = new SearchControls();

    static {
        SUBTREE_SCOPE.setSearchScope(SearchControls.SUBTREE_SCOPE);
    }

    /** Splits a comma separated list of permissions, dropping blanks and surrounding whitespace. */
    private static final Splitter SPLITTER = Splitter.on(',').omitEmptyStrings().trimResults();

    private final String searchBase;
    private final String groupSearchFilter;
    private final String groupNameId;
    // Group (role) name -> permissions, in the order in which they were declared in the configuration
    private final Map<String, Collection<String>> permissionsByGroup = Maps.newLinkedHashMap();
    private final String dnSearchFilter;

    /**
     * Configures the realm and its JNDI context factory from the security configuration.
     *
     * @param securityConfig source of the LDAP connection settings, search templates and
     *                       group to permissions mapping
     */
    @Inject
    public KillBillJndiLdapRealm(final SecurityConfig securityConfig) {
        super();

        if (securityConfig.getShiroLDAPUserDnTemplate() != null) {
            setUserDnTemplate(securityConfig.getShiroLDAPUserDnTemplate());
        }

        final JndiLdapContextFactory contextFactory = (JndiLdapContextFactory) getContextFactory();
        if (securityConfig.disableShiroLDAPSSLCheck()) {
            contextFactory.getEnvironment().put("java.naming.ldap.factory.socket", SkipSSLCheckSocketFactory.class.getName());
        }
        contextFactory.getEnvironment().put("java.naming.referral", securityConfig.followShiroLDAPReferrals() ? "follow" : "ignore");
        if (securityConfig.getShiroLDAPUrl() != null) {
            contextFactory.setUrl(securityConfig.getShiroLDAPUrl());
        }
        if (securityConfig.getShiroLDAPSystemUsername() != null) {
            contextFactory.setSystemUsername(securityConfig.getShiroLDAPSystemUsername());
        }
        if (securityConfig.getShiroLDAPSystemPassword() != null) {
            contextFactory.setSystemPassword(securityConfig.getShiroLDAPSystemPassword());
        }
        if (securityConfig.getShiroLDAPAuthenticationMechanism() != null) {
            contextFactory.setAuthenticationMechanism(securityConfig.getShiroLDAPAuthenticationMechanism());
        }
        setContextFactory(contextFactory);

        dnSearchFilter = securityConfig.getShiroLDAPDnSearchTemplate();
        searchBase = securityConfig.getShiroLDAPSearchBase();
        groupSearchFilter = securityConfig.getShiroLDAPGroupSearchFilter();
        groupNameId = securityConfig.getShiroLDAPGroupNameID();

        if (securityConfig.getShiroLDAPPermissionsByGroup() != null) {
            final Ini ini = new Ini();
            // When passing properties on the command line, \n can be escaped
            ini.load(securityConfig.getShiroLDAPPermissionsByGroup().replace("\\n", "\n"));
            for (final Section section : ini.getSections()) {
                for (final String rawRole : section.keySet()) {
                    final Collection<String> permissions = ImmutableList.<String>copyOf(SPLITTER.split(section.get(rawRole)));
                    // Un-escape manually = (required if the role name is a DN)
                    final String role = rawRole.replace("\\=", "=");
                    permissionsByGroup.put(role, permissions);
                }
            }
        }
    }

    /**
     * Resolves the DN used to bind as the given principal.
     *
     * @param principal user name supplied at login
     * @return the DN found through the DN search template when one is configured,
     *         otherwise the DN built by the parent class
     * @throws IllegalArgumentException if the DN lookup fails or matches no entry
     */
    @Override
    protected String getUserDn(final String principal) throws IllegalArgumentException, IllegalStateException {
        if (dnSearchFilter != null) {
            return findUserDN(principal, getContextFactory());
        } else {
            // Use template
            return super.getUserDn(principal);
        }
    }

    /**
     * Searches the directory, using the system context, for the entry matching the user name.
     *
     * @param userName           user name substituted (escaped) into the DN search template
     * @param ldapContextFactory factory providing the system LDAP context
     * @return the full DN of the first matching entry, never null
     * @throws IllegalArgumentException if the search fails or no entry matches
     */
    private String findUserDN(final String userName, final LdapContextFactory ldapContextFactory) {
        LdapContext systemLdapCtx = null;
        NamingEnumeration<SearchResult> usersFound = null;
        try {
            systemLdapCtx = ldapContextFactory.getSystemLdapContext();
            usersFound = systemLdapCtx.search(searchBase,
                                              buildSearchFilter(dnSearchFilter, userName),
                                              SUBTREE_SCOPE);
            if (!usersFound.hasMore()) {
                // Never hand a null DN back to the caller: it would be used as the bind principal,
                // and whether a bind without a principal is rejected depends on the LDAP server.
                throw new IllegalArgumentException("No LDAP entry found for the supplied user name");
            }
            return usersFound.next().getNameInNamespace();
        } catch (final AuthenticationException ex) {
            log.info("LDAP authentication exception='{}'", ex.getLocalizedMessage());
            throw new IllegalArgumentException(ex);
        } catch (final NamingException e) {
            log.info("LDAP exception='{}'", e.getLocalizedMessage());
            throw new IllegalArgumentException(e);
        } finally {
            closeQuietly(usersFound);
            LdapUtils.closeContext(systemLdapCtx);
        }
    }

    /**
     * Builds the authorization info for the principal: its LDAP groups become roles and the
     * permissions configured for these groups become its permissions.
     */
    @Override
    protected AuthorizationInfo queryForAuthorizationInfo(final PrincipalCollection principals, final LdapContextFactory ldapContextFactory) throws NamingException {
        final Set<String> userGroups = findLDAPGroupsForUser(principals, ldapContextFactory);

        final SimpleAuthorizationInfo simpleAuthorizationInfo = new SimpleAuthorizationInfo(userGroups);
        final Set<String> stringPermissions = groupsPermissions(userGroups);
        simpleAuthorizationInfo.setStringPermissions(stringPermissions);

        return simpleAuthorizationInfo;
    }

    /**
     * Looks up the groups of the available principal using the system LDAP context.
     *
     * @return the group names, or an empty set if a Shiro authentication exception is raised
     * @throws NamingException if the LDAP search fails
     */
    private Set<String> findLDAPGroupsForUser(final PrincipalCollection principals, final LdapContextFactory ldapContextFactory) throws NamingException {
        final String username = (String) getAvailablePrincipal(principals);

        LdapContext systemLdapCtx = null;
        try {
            systemLdapCtx = ldapContextFactory.getSystemLdapContext();
            return findLDAPGroupsForUser(username, systemLdapCtx);
        } catch (final AuthenticationException ex) {
            log.info("LDAP authentication exception='{}'", ex.getLocalizedMessage());
            return ImmutableSet.<String>of();
        } finally {
            LdapUtils.closeContext(systemLdapCtx);
        }
    }

    /**
     * Runs the group search for the user and extracts the group names from the first entry found.
     *
     * @param userName user name substituted (escaped) into the group search filter
     * @param ldapCtx  open LDAP context; it is not closed by this method
     * @return the non-null group names of the first matching entry, empty if nothing matches
     * @throws NamingException if the LDAP search fails
     */
    private Set<String> findLDAPGroupsForUser(final String userName, final LdapContext ldapCtx) throws NamingException {
        final NamingEnumeration<SearchResult> foundGroups = ldapCtx.search(searchBase,
                                                                           buildSearchFilter(groupSearchFilter, userName),
                                                                           SUBTREE_SCOPE);
        try {
            // NOTE(review): hasMoreElements() is kept on purpose instead of hasMore() so that the
            // behavior on errors raised while enumerating is unchanged -- confirm before switching
            if (!foundGroups.hasMoreElements()) {
                return ImmutableSet.<String>of();
            }

            // There should really only be one entry
            final SearchResult result = foundGroups.next();

            // Extract the name of all the groups
            final Collection<String> finalGroupsNames = Collections2.<String>filter(extractGroupNamesFromSearchResult(result), Predicates.notNull());
            // Materialize the lazy view before the search results are released
            return Sets.newHashSet(finalGroupsNames);
        } finally {
            closeQuietly(foundGroups);
        }
    }

    /**
     * Collects, as strings, all the values of the attribute(s) of the entry whose id matches the
     * configured group name id (case insensitive). Attributes which cannot be read are logged
     * and skipped.
     */
    private Collection<String> extractGroupNamesFromSearchResult(final SearchResult searchResult) {
        // Get all attributes for that group
        final Iterator<? extends Attribute> attributesIterator = Iterators.forEnumeration(searchResult.getAttributes().getAll());

        // Find the attribute representing the group name
        final Iterator<? extends Attribute> groupNameAttributesIterator = Iterators.filter(attributesIterator,
                                                                                           new Predicate<Attribute>() {
                                                                                               @Override
                                                                                               public boolean apply(final Attribute attribute) {
                                                                                                   return groupNameId.equalsIgnoreCase(attribute.getID());
                                                                                               }
                                                                                           });

        // Extract the group name from the attribute
        // Note: at this point, groupNameAttributesIterator should really contain a single element
        final Iterator<Iterator<?>> groupNamesIterator = Iterators.transform(groupNameAttributesIterator,
                                                                             new Function<Attribute, Iterator<?>>() {
                                                                                 @Override
                                                                                 public Iterator<?> apply(final Attribute groupNameAttribute) {
                                                                                     try {
                                                                                         final NamingEnumeration<?> enumeration = groupNameAttribute.getAll();
                                                                                         return Iterators.forEnumeration(enumeration);
                                                                                     } catch (final NamingException namingException) {
                                                                                         log.warn("Unable to read group name(s)", namingException);
                                                                                         // Iterators.concat does not accept null inputs: skip this attribute instead
                                                                                         return ImmutableList.<Object>of().iterator();
                                                                                     }
                                                                                 }
                                                                             });
        final Iterator<?> finalGroupNamesIterator = Iterators.filter(Iterators.concat(groupNamesIterator), Predicates.notNull());

        return ImmutableList.<String>copyOf(Iterators.transform(finalGroupNamesIterator, Functions.toStringFunction()));
    }

    /**
     * Returns the union of the permissions configured for the given groups; groups without
     * configured permissions are ignored.
     */
    private Set<String> groupsPermissions(final Set<String> groups) {
        final Set<String> permissions = new HashSet<String>();
        for (final String group : groups) {
            final Collection<String> permissionsForGroup = permissionsByGroup.get(group);
            if (permissionsForGroup != null) {
                permissions.addAll(permissionsForGroup);
            }
        }
        return permissions;
    }

    /**
     * Substitutes the user name into a search filter template, escaping it first so that it can
     * only ever be interpreted as a literal value by the directory.
     */
    private static String buildSearchFilter(final String filterTemplate, final String userName) {
        return filterTemplate.replace(USERDN_SUBSTITUTION_TOKEN, escapeLdapFilterValue(userName));
    }

    /**
     * Escapes the characters which have a special meaning in an LDAP search filter
     * (backslash, asterisk, parentheses and NUL) as a backslash followed by their hexadecimal code.
     */
    private static String escapeLdapFilterValue(final String value) {
        final StringBuilder escaped = new StringBuilder(value.length());
        for (int i = 0; i < value.length(); i++) {
            final char c = value.charAt(i);
            switch (c) {
                case '\\':
                    escaped.append("\\5c");
                    break;
                case '*':
                    escaped.append("\\2a");
                    break;
                case '(':
                    escaped.append("\\28");
                    break;
                case ')':
                    escaped.append("\\29");
                    break;
                case '\0':
                    escaped.append("\\00");
                    break;
                default:
                    escaped.append(c);
                    break;
            }
        }
        return escaped.toString();
    }

    /**
     * Releases the resources held by search results; failures are logged and otherwise ignored
     * since the results have already been consumed at that point.
     */
    private static void closeQuietly(final NamingEnumeration<?> enumeration) {
        if (enumeration == null) {
            return;
        }
        try {
            enumeration.close();
        } catch (final NamingException e) {
            log.debug("Unable to close LDAP search results", e);
        }
    }

    /**
     * @return the live (mutable) mapping of group names to permissions built from the configuration
     */
    @VisibleForTesting
    public Map<String, Collection<String>> getPermissionsByGroup() {
        return permissionsByGroup;
    }
}