/*
* Copyright 2010-2013 Ning, Inc.
*
* Ning licenses this file to you under the Apache License, version 2.0
* (the "License"); you may not use this file except in compliance with the
* License. You may obtain a copy of the License at:
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*/
package com.ning.billing.util.security.shiro.realm;
import java.util.Collection;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Map;
import java.util.Set;
import javax.naming.NamingEnumeration;
import javax.naming.NamingException;
import javax.naming.directory.Attribute;
import javax.naming.directory.SearchControls;
import javax.naming.directory.SearchResult;
import javax.naming.ldap.LdapContext;
import org.apache.shiro.authc.AuthenticationException;
import org.apache.shiro.authz.AuthorizationInfo;
import org.apache.shiro.authz.SimpleAuthorizationInfo;
import org.apache.shiro.config.Ini;
import org.apache.shiro.config.Ini.Section;
import org.apache.shiro.realm.ldap.JndiLdapContextFactory;
import org.apache.shiro.realm.ldap.JndiLdapRealm;
import org.apache.shiro.realm.ldap.LdapContextFactory;
import org.apache.shiro.realm.ldap.LdapUtils;
import org.apache.shiro.subject.PrincipalCollection;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import com.ning.billing.util.config.SecurityConfig;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Function;
import com.google.common.base.Predicate;
import com.google.common.base.Predicates;
import com.google.common.base.Splitter;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Iterators;
import com.google.common.collect.Maps;
import com.google.common.collect.Sets;
import com.google.inject.Inject;
public class KillBillJndiLdapRealm extends JndiLdapRealm {
private static final Logger log = LoggerFactory.getLogger(KillBillJndiLdapRealm.class);
private static final String USERDN_SUBSTITUTION_TOKEN = "{0}";
private static final SearchControls SUBTREE_SCOPE = new SearchControls();
static {
SUBTREE_SCOPE.setSearchScope(SearchControls.SUBTREE_SCOPE);
}
private static final Splitter SPLITTER = Splitter.on(',').omitEmptyStrings().trimResults();
private final String searchBase;
private final String groupSearchFilter;
private final String groupNameId;
private final Map<String, Collection<String>> permissionsByGroup = Maps.newLinkedHashMap();
@Inject
public KillBillJndiLdapRealm(final SecurityConfig securityConfig) {
super();
if (securityConfig.getShiroLDAPUserDnTemplate() != null) {
setUserDnTemplate(securityConfig.getShiroLDAPUserDnTemplate());
}
final JndiLdapContextFactory contextFactory = (JndiLdapContextFactory) getContextFactory();
if (securityConfig.disableShiroLDAPSSLCheck()) {
contextFactory.getEnvironment().put("java.naming.ldap.factory.socket", SkipSSLCheckSocketFactory.class.getName());
}
if (securityConfig.getShiroLDAPUrl() != null) {
contextFactory.setUrl(securityConfig.getShiroLDAPUrl());
}
if (securityConfig.getShiroLDAPSystemUsername() != null) {
contextFactory.setSystemUsername(securityConfig.getShiroLDAPSystemUsername());
}
if (securityConfig.getShiroLDAPSystemPassword() != null) {
contextFactory.setSystemPassword(securityConfig.getShiroLDAPSystemPassword());
}
if (securityConfig.getShiroLDAPAuthenticationMechanism() != null) {
contextFactory.setAuthenticationMechanism(securityConfig.getShiroLDAPAuthenticationMechanism());
}
setContextFactory(contextFactory);
searchBase = securityConfig.getShiroLDAPSearchBase();
groupSearchFilter = securityConfig.getShiroLDAPGroupSearchFilter();
groupNameId = securityConfig.getShiroLDAPGroupNameID();
if (securityConfig.getShiroLDAPPermissionsByGroup() != null) {
final Ini ini = new Ini();
// When passing properties on the command line, \n can be escaped
ini.load(securityConfig.getShiroLDAPPermissionsByGroup().replace("\\n", "\n"));
for (final Section section : ini.getSections()) {
for (final String role : section.keySet()) {
final Collection<String> permissions = ImmutableList.<String>copyOf(SPLITTER.split(section.get(role)));
permissionsByGroup.put(role, permissions);
}
}
}
}
@Override
protected AuthorizationInfo queryForAuthorizationInfo(final PrincipalCollection principals, final LdapContextFactory ldapContextFactory) throws NamingException {
final Set<String> userGroups = findLDAPGroupsForUser(principals, ldapContextFactory);
final SimpleAuthorizationInfo simpleAuthorizationInfo = new SimpleAuthorizationInfo(userGroups);
final Set<String> stringPermissions = groupsPermissions(userGroups);
simpleAuthorizationInfo.setStringPermissions(stringPermissions);
return simpleAuthorizationInfo;
}
private Set<String> findLDAPGroupsForUser(final PrincipalCollection principals, final LdapContextFactory ldapContextFactory) throws NamingException {
final String username = (String) getAvailablePrincipal(principals);
LdapContext systemLdapCtx = null;
try {
systemLdapCtx = ldapContextFactory.getSystemLdapContext();
return findLDAPGroupsForUser(username, systemLdapCtx);
} catch (AuthenticationException ex) {
log.info("LDAP authentication exception: " + ex.getLocalizedMessage());
return ImmutableSet.<String>of();
} finally {
LdapUtils.closeContext(systemLdapCtx);
}
}
private Set<String> findLDAPGroupsForUser(final String userName, final LdapContext ldapCtx) throws NamingException {
final NamingEnumeration<SearchResult> foundGroups = ldapCtx.search(searchBase,
groupSearchFilter.replace(USERDN_SUBSTITUTION_TOKEN, userName),
SUBTREE_SCOPE);
// Extract the name of all the groups
final Iterator<SearchResult> groupsIterator = Iterators.<SearchResult>forEnumeration(foundGroups);
final Iterator<String> groupsNameIterator = Iterators.<SearchResult, String>transform(groupsIterator,
new Function<SearchResult, String>() {
@Override
public String apply(final SearchResult groupEntry) {
return extractGroupNameFromSearchResult(groupEntry);
}
});
final Iterator<String> finalGroupsNameIterator = Iterators.<String>filter(groupsNameIterator, Predicates.notNull());
return Sets.newHashSet(finalGroupsNameIterator);
}
private String extractGroupNameFromSearchResult(final SearchResult searchResult) {
// Get all attributes for that group
final Iterator<? extends Attribute> attributesIterator = Iterators.forEnumeration(searchResult.getAttributes().getAll());
// Find the attribute representing the group name
final Iterator<? extends Attribute> groupNameAttributesIterator = Iterators.filter(attributesIterator,
new Predicate<Attribute>() {
@Override
public boolean apply(final Attribute attribute) {
return groupNameId.equalsIgnoreCase(attribute.getID());
}
});
// Extract the group name from the attribute
// Note: at this point, groupNameAttributesIterator should really contain a single element
final Iterator<String> groupNamesIterator = Iterators.transform(groupNameAttributesIterator,
new Function<Attribute, String>() {
@Override
public String apply(final Attribute groupNameAttribute) {
try {
final NamingEnumeration<?> enumeration = groupNameAttribute.getAll();
if (enumeration.hasMore()) {
return enumeration.next().toString();
} else {
return null;
}
} catch (NamingException namingException) {
log.warn("Unable to read group name", namingException);
return null;
}
}
});
final Iterator<String> finalGroupNamesIterator = Iterators.<String>filter(groupNamesIterator, Predicates.notNull());
if (finalGroupNamesIterator.hasNext()) {
return finalGroupNamesIterator.next();
} else {
log.warn("Unable to find an attribute matching {}", groupNameId);
return null;
}
}
private Set<String> groupsPermissions(final Set<String> groups) {
final Set<String> permissions = new HashSet<String>();
for (final String group : groups) {
final Collection<String> permissionsForGroup = permissionsByGroup.get(group);
if (permissionsForGroup != null) {
permissions.addAll(permissionsForGroup);
}
}
return permissions;
}
@VisibleForTesting
public Map<String, Collection<String>> getPermissionsByGroup() {
return permissionsByGroup;
}
}