diff --git a/util/src/main/java/org/killbill/billing/util/config/definition/SecurityConfig.java b/util/src/main/java/org/killbill/billing/util/config/definition/SecurityConfig.java
index 24d3c2d..f0e5c99 100644
--- a/util/src/main/java/org/killbill/billing/util/config/definition/SecurityConfig.java
+++ b/util/src/main/java/org/killbill/billing/util/config/definition/SecurityConfig.java
@@ -41,6 +41,11 @@ public interface SecurityConfig extends KillbillConfig {
@Description("LDAP server's User DN format (e.g. uid={0},ou=users,dc=mycompany,dc=com)")
public String getShiroLDAPUserDnTemplate();
+ @Config("org.killbill.security.ldap.dnSearchTemplate")
+ @DefaultNull
+ @Description("LDAP server's DN search template (e.g. sAMAccountName={0}) for search-then-bind authentication (in case a static DN format template isn't enough)")
+ public String getShiroLDAPDnSearchTemplate();
+
@Config("org.killbill.security.ldap.searchBase")
@DefaultNull
@Description("LDAP search base to use")
@@ -88,6 +93,11 @@ public interface SecurityConfig extends KillbillConfig {
@Description("Whether to ignore SSL certificates checks")
public boolean disableShiroLDAPSSLCheck();
+ @Config("org.killbill.security.ldap.followReferrals")
+ @Default("false")
+ @Description("Whether to follow referrals")
+ public boolean followShiroLDAPReferrals();
+
// Okta realm
@Config("org.killbill.security.okta.url")
diff --git a/util/src/main/java/org/killbill/billing/util/security/shiro/realm/KillBillJndiLdapRealm.java b/util/src/main/java/org/killbill/billing/util/security/shiro/realm/KillBillJndiLdapRealm.java
index 7779670..ac15666 100644
--- a/util/src/main/java/org/killbill/billing/util/security/shiro/realm/KillBillJndiLdapRealm.java
+++ b/util/src/main/java/org/killbill/billing/util/security/shiro/realm/KillBillJndiLdapRealm.java
@@ -39,16 +39,17 @@ import org.apache.shiro.realm.ldap.JndiLdapRealm;
import org.apache.shiro.realm.ldap.LdapContextFactory;
import org.apache.shiro.realm.ldap.LdapUtils;
import org.apache.shiro.subject.PrincipalCollection;
+import org.killbill.billing.util.config.definition.SecurityConfig;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
-import org.killbill.billing.util.config.definition.SecurityConfig;
-
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Function;
+import com.google.common.base.Functions;
import com.google.common.base.Predicate;
import com.google.common.base.Predicates;
import com.google.common.base.Splitter;
+import com.google.common.collect.Collections2;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Iterators;
@@ -74,6 +75,7 @@ public class KillBillJndiLdapRealm extends JndiLdapRealm {
private final String groupSearchFilter;
private final String groupNameId;
private final Map<String, Collection<String>> permissionsByGroup = Maps.newLinkedHashMap();
+ private final String dnSearchFilter;
@Inject
public KillBillJndiLdapRealm(final SecurityConfig securityConfig) {
@@ -87,6 +89,7 @@ public class KillBillJndiLdapRealm extends JndiLdapRealm {
if (securityConfig.disableShiroLDAPSSLCheck()) {
contextFactory.getEnvironment().put("java.naming.ldap.factory.socket", SkipSSLCheckSocketFactory.class.getName());
}
+ contextFactory.getEnvironment().put("java.naming.referral", securityConfig.followShiroLDAPReferrals() ? "follow" : "ignore");
if (securityConfig.getShiroLDAPUrl() != null) {
contextFactory.setUrl(securityConfig.getShiroLDAPUrl());
}
@@ -101,6 +104,8 @@ public class KillBillJndiLdapRealm extends JndiLdapRealm {
}
setContextFactory(contextFactory);
+ dnSearchFilter = securityConfig.getShiroLDAPDnSearchTemplate();
+
searchBase = securityConfig.getShiroLDAPSearchBase();
groupSearchFilter = securityConfig.getShiroLDAPGroupSearchFilter();
groupNameId = securityConfig.getShiroLDAPGroupNameID();
@@ -110,8 +115,10 @@ public class KillBillJndiLdapRealm extends JndiLdapRealm {
// When passing properties on the command line, \n can be escaped
ini.load(securityConfig.getShiroLDAPPermissionsByGroup().replace("\\n", "\n"));
for (final Section section : ini.getSections()) {
- for (final String role : section.keySet()) {
- final Collection<String> permissions = ImmutableList.<String>copyOf(SPLITTER.split(section.get(role)));
+ for (final String rawRole : section.keySet()) {
+ // Un-escape manually = (required if the role name is a DN)
+ final Collection<String> permissions = ImmutableList.<String>copyOf(SPLITTER.split(section.get(rawRole)));
+ final String role = rawRole.replace("\\=", "=");
permissionsByGroup.put(role, permissions);
}
}
@@ -119,6 +126,35 @@ public class KillBillJndiLdapRealm extends JndiLdapRealm {
}
@Override
+ protected String getUserDn(final String principal) throws IllegalArgumentException, IllegalStateException {
+ if (dnSearchFilter != null) {
+ return findUserDN(principal, getContextFactory());
+ } else {
+ // Use template
+ return super.getUserDn(principal);
+ }
+ }
+
+ private String findUserDN(final String userName, final LdapContextFactory ldapContextFactory) {
+ LdapContext systemLdapCtx = null;
+ try {
+ systemLdapCtx = ldapContextFactory.getSystemLdapContext();
+ final NamingEnumeration<SearchResult> usersFound = systemLdapCtx.search(searchBase,
+ dnSearchFilter.replace(USERDN_SUBSTITUTION_TOKEN, userName),
+ SUBTREE_SCOPE);
+ return usersFound.hasMore() ? usersFound.next().getNameInNamespace() : null;
+ } catch (final AuthenticationException ex) {
+ log.info("LDAP authentication exception='{}'", ex.getLocalizedMessage());
+ throw new IllegalArgumentException(ex);
+ } catch (final NamingException e) {
+ log.info("LDAP exception='{}'", e.getLocalizedMessage());
+ throw new IllegalArgumentException(e);
+ } finally {
+ LdapUtils.closeContext(systemLdapCtx);
+ }
+ }
+
+ @Override
protected AuthorizationInfo queryForAuthorizationInfo(final PrincipalCollection principals, final LdapContextFactory ldapContextFactory) throws NamingException {
final Set<String> userGroups = findLDAPGroupsForUser(principals, ldapContextFactory);
@@ -136,7 +172,7 @@ public class KillBillJndiLdapRealm extends JndiLdapRealm {
try {
systemLdapCtx = ldapContextFactory.getSystemLdapContext();
return findLDAPGroupsForUser(username, systemLdapCtx);
- } catch (AuthenticationException ex) {
+ } catch (final AuthenticationException ex) {
log.info("LDAP authentication exception='{}'", ex.getLocalizedMessage());
return ImmutableSet.<String>of();
} finally {
@@ -149,21 +185,20 @@ public class KillBillJndiLdapRealm extends JndiLdapRealm {
groupSearchFilter.replace(USERDN_SUBSTITUTION_TOKEN, userName),
SUBTREE_SCOPE);
+ if (!foundGroups.hasMoreElements()) {
+ return ImmutableSet.<String>of();
+ }
+
+ // There should really only be one entry
+ final SearchResult result = foundGroups.next();
+
// Extract the name of all the groups
- final Iterator<SearchResult> groupsIterator = Iterators.<SearchResult>forEnumeration(foundGroups);
- final Iterator<String> groupsNameIterator = Iterators.<SearchResult, String>transform(groupsIterator,
- new Function<SearchResult, String>() {
- @Override
- public String apply(final SearchResult groupEntry) {
- return extractGroupNameFromSearchResult(groupEntry);
- }
- });
- final Iterator<String> finalGroupsNameIterator = Iterators.<String>filter(groupsNameIterator, Predicates.notNull());
-
- return Sets.newHashSet(finalGroupsNameIterator);
+ final Collection<String> finalGroupsNames = Collections2.<String>filter(extractGroupNamesFromSearchResult(result), Predicates.notNull());
+
+ return Sets.newHashSet(finalGroupsNames);
}
- private String extractGroupNameFromSearchResult(final SearchResult searchResult) {
+ private Collection<String> extractGroupNamesFromSearchResult(final SearchResult searchResult) {
// Get all attributes for that group
final Iterator<? extends Attribute> attributesIterator = Iterators.forEnumeration(searchResult.getAttributes().getAll());
@@ -178,31 +213,23 @@ public class KillBillJndiLdapRealm extends JndiLdapRealm {
// Extract the group name from the attribute
// Note: at this point, groupNameAttributesIterator should really contain a single element
- final Iterator<String> groupNamesIterator = Iterators.transform(groupNameAttributesIterator,
- new Function<Attribute, String>() {
+ final Iterator<Iterator<?>> groupNamesIterator = Iterators.transform(groupNameAttributesIterator,
+ new Function<Attribute, Iterator<?>>() {
@Override
- public String apply(final Attribute groupNameAttribute) {
+ public Iterator<?> apply(final Attribute groupNameAttribute) {
try {
final NamingEnumeration<?> enumeration = groupNameAttribute.getAll();
- if (enumeration.hasMore()) {
- return enumeration.next().toString();
- } else {
- return null;
- }
- } catch (NamingException namingException) {
- log.warn("Unable to read group name", namingException);
+ return Iterators.forEnumeration(enumeration);
+ } catch (final NamingException namingException) {
+ log.warn("Unable to read group name(s)", namingException);
return null;
}
}
});
- final Iterator<String> finalGroupNamesIterator = Iterators.<String>filter(groupNamesIterator, Predicates.notNull());
- if (finalGroupNamesIterator.hasNext()) {
- return finalGroupNamesIterator.next();
- } else {
- log.warn("Unable to find an attribute matching {}", groupNameId);
- return null;
- }
+ final Iterator<?> finalGroupNamesIterator = Iterators.filter(Iterators.concat(groupNamesIterator), Predicates.notNull());
+
+ return ImmutableList.<String>copyOf(Iterators.transform(finalGroupNamesIterator, Functions.toStringFunction()));
}
private Set<String> groupsPermissions(final Set<String> groups) {