killbill-uncached

util: various LDAP improvements * Implement search-then-bind

8/28/2017 2:30:21 AM

Details

diff --git a/util/src/main/java/org/killbill/billing/util/config/definition/SecurityConfig.java b/util/src/main/java/org/killbill/billing/util/config/definition/SecurityConfig.java
index 24d3c2d..f0e5c99 100644
--- a/util/src/main/java/org/killbill/billing/util/config/definition/SecurityConfig.java
+++ b/util/src/main/java/org/killbill/billing/util/config/definition/SecurityConfig.java
@@ -41,6 +41,11 @@ public interface SecurityConfig extends KillbillConfig {
     @Description("LDAP server's User DN format (e.g. uid={0},ou=users,dc=mycompany,dc=com)")
     public String getShiroLDAPUserDnTemplate();
 
+    @Config("org.killbill.security.ldap.dnSearchTemplate")
+    @DefaultNull
+    @Description("LDAP server's DN search template (e.g. sAMAccountName={0}) for search-then-bind authentication (in case a static DN format template isn't enough)")
+    public String getShiroLDAPDnSearchTemplate();
+
     @Config("org.killbill.security.ldap.searchBase")
     @DefaultNull
     @Description("LDAP search base to use")
@@ -88,6 +93,11 @@ public interface SecurityConfig extends KillbillConfig {
     @Description("Whether to ignore SSL certificates checks")
     public boolean disableShiroLDAPSSLCheck();
 
+    @Config("org.killbill.security.ldap.followReferrals")
+    @Default("false")
+    @Description("Whether to follow referrals")
+    public boolean followShiroLDAPReferrals();
+
     // Okta realm
 
     @Config("org.killbill.security.okta.url")
diff --git a/util/src/main/java/org/killbill/billing/util/security/shiro/realm/KillBillJndiLdapRealm.java b/util/src/main/java/org/killbill/billing/util/security/shiro/realm/KillBillJndiLdapRealm.java
index 7779670..ac15666 100644
--- a/util/src/main/java/org/killbill/billing/util/security/shiro/realm/KillBillJndiLdapRealm.java
+++ b/util/src/main/java/org/killbill/billing/util/security/shiro/realm/KillBillJndiLdapRealm.java
@@ -39,16 +39,17 @@ import org.apache.shiro.realm.ldap.JndiLdapRealm;
 import org.apache.shiro.realm.ldap.LdapContextFactory;
 import org.apache.shiro.realm.ldap.LdapUtils;
 import org.apache.shiro.subject.PrincipalCollection;
+import org.killbill.billing.util.config.definition.SecurityConfig;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
-import org.killbill.billing.util.config.definition.SecurityConfig;
-
 import com.google.common.annotations.VisibleForTesting;
 import com.google.common.base.Function;
+import com.google.common.base.Functions;
 import com.google.common.base.Predicate;
 import com.google.common.base.Predicates;
 import com.google.common.base.Splitter;
+import com.google.common.collect.Collections2;
 import com.google.common.collect.ImmutableList;
 import com.google.common.collect.ImmutableSet;
 import com.google.common.collect.Iterators;
@@ -74,6 +75,7 @@ public class KillBillJndiLdapRealm extends JndiLdapRealm {
     private final String groupSearchFilter;
     private final String groupNameId;
     private final Map<String, Collection<String>> permissionsByGroup = Maps.newLinkedHashMap();
+    private final String dnSearchFilter;
 
     @Inject
     public KillBillJndiLdapRealm(final SecurityConfig securityConfig) {
@@ -87,6 +89,7 @@ public class KillBillJndiLdapRealm extends JndiLdapRealm {
         if (securityConfig.disableShiroLDAPSSLCheck()) {
             contextFactory.getEnvironment().put("java.naming.ldap.factory.socket", SkipSSLCheckSocketFactory.class.getName());
         }
+        contextFactory.getEnvironment().put("java.naming.referral", securityConfig.followShiroLDAPReferrals() ? "follow" : "ignore");
         if (securityConfig.getShiroLDAPUrl() != null) {
             contextFactory.setUrl(securityConfig.getShiroLDAPUrl());
         }
@@ -101,6 +104,8 @@ public class KillBillJndiLdapRealm extends JndiLdapRealm {
         }
         setContextFactory(contextFactory);
 
+        dnSearchFilter = securityConfig.getShiroLDAPDnSearchTemplate();
+
         searchBase = securityConfig.getShiroLDAPSearchBase();
         groupSearchFilter = securityConfig.getShiroLDAPGroupSearchFilter();
         groupNameId = securityConfig.getShiroLDAPGroupNameID();
@@ -110,8 +115,10 @@ public class KillBillJndiLdapRealm extends JndiLdapRealm {
             // When passing properties on the command line, \n can be escaped
             ini.load(securityConfig.getShiroLDAPPermissionsByGroup().replace("\\n", "\n"));
             for (final Section section : ini.getSections()) {
-                for (final String role : section.keySet()) {
-                    final Collection<String> permissions = ImmutableList.<String>copyOf(SPLITTER.split(section.get(role)));
+                for (final String rawRole : section.keySet()) {
+                    // Un-escape manually = (required if the role name is a DN)
+                    final Collection<String> permissions = ImmutableList.<String>copyOf(SPLITTER.split(section.get(rawRole)));
+                    final String role = rawRole.replace("\\=", "=");
                     permissionsByGroup.put(role, permissions);
                 }
             }
@@ -119,6 +126,35 @@ public class KillBillJndiLdapRealm extends JndiLdapRealm {
     }
 
     @Override
+    protected String getUserDn(final String principal) throws IllegalArgumentException, IllegalStateException {
+        if (dnSearchFilter != null) {
+            return findUserDN(principal, getContextFactory());
+        } else {
+            // Use template
+            return super.getUserDn(principal);
+        }
+    }
+
+    private String findUserDN(final String userName, final LdapContextFactory ldapContextFactory) {
+        LdapContext systemLdapCtx = null;
+        try {
+            systemLdapCtx = ldapContextFactory.getSystemLdapContext();
+            final NamingEnumeration<SearchResult> usersFound = systemLdapCtx.search(searchBase,
+                                                                                    dnSearchFilter.replace(USERDN_SUBSTITUTION_TOKEN, userName),
+                                                                                    SUBTREE_SCOPE);
+            return usersFound.hasMore() ? usersFound.next().getNameInNamespace() : null;
+        } catch (final AuthenticationException ex) {
+            log.info("LDAP authentication exception='{}'", ex.getLocalizedMessage());
+            throw new IllegalArgumentException(ex);
+        } catch (final NamingException e) {
+            log.info("LDAP exception='{}'", e.getLocalizedMessage());
+            throw new IllegalArgumentException(e);
+        } finally {
+            LdapUtils.closeContext(systemLdapCtx);
+        }
+    }
+
+    @Override
     protected AuthorizationInfo queryForAuthorizationInfo(final PrincipalCollection principals, final LdapContextFactory ldapContextFactory) throws NamingException {
         final Set<String> userGroups = findLDAPGroupsForUser(principals, ldapContextFactory);
 
@@ -136,7 +172,7 @@ public class KillBillJndiLdapRealm extends JndiLdapRealm {
         try {
             systemLdapCtx = ldapContextFactory.getSystemLdapContext();
             return findLDAPGroupsForUser(username, systemLdapCtx);
-        } catch (AuthenticationException ex) {
+        } catch (final AuthenticationException ex) {
             log.info("LDAP authentication exception='{}'", ex.getLocalizedMessage());
             return ImmutableSet.<String>of();
         } finally {
@@ -149,21 +185,20 @@ public class KillBillJndiLdapRealm extends JndiLdapRealm {
                                                                            groupSearchFilter.replace(USERDN_SUBSTITUTION_TOKEN, userName),
                                                                            SUBTREE_SCOPE);
 
+        if (!foundGroups.hasMoreElements()) {
+            return ImmutableSet.<String>of();
+        }
+
+        // There should really only be one entry
+        final SearchResult result = foundGroups.next();
+
         // Extract the name of all the groups
-        final Iterator<SearchResult> groupsIterator = Iterators.<SearchResult>forEnumeration(foundGroups);
-        final Iterator<String> groupsNameIterator = Iterators.<SearchResult, String>transform(groupsIterator,
-                                                                                              new Function<SearchResult, String>() {
-                                                                                                  @Override
-                                                                                                  public String apply(final SearchResult groupEntry) {
-                                                                                                      return extractGroupNameFromSearchResult(groupEntry);
-                                                                                                  }
-                                                                                              });
-        final Iterator<String> finalGroupsNameIterator = Iterators.<String>filter(groupsNameIterator, Predicates.notNull());
-
-        return Sets.newHashSet(finalGroupsNameIterator);
+        final Collection<String> finalGroupsNames = Collections2.<String>filter(extractGroupNamesFromSearchResult(result), Predicates.notNull());
+
+        return Sets.newHashSet(finalGroupsNames);
     }
 
-    private String extractGroupNameFromSearchResult(final SearchResult searchResult) {
+    private Collection<String> extractGroupNamesFromSearchResult(final SearchResult searchResult) {
         // Get all attributes for that group
         final Iterator<? extends Attribute> attributesIterator = Iterators.forEnumeration(searchResult.getAttributes().getAll());
 
@@ -178,31 +213,23 @@ public class KillBillJndiLdapRealm extends JndiLdapRealm {
 
         // Extract the group name from the attribute
         // Note: at this point, groupNameAttributesIterator should really contain a single element
-        final Iterator<String> groupNamesIterator = Iterators.transform(groupNameAttributesIterator,
-                                                                        new Function<Attribute, String>() {
+        final Iterator<Iterator<?>> groupNamesIterator = Iterators.transform(groupNameAttributesIterator,
+                                                                        new Function<Attribute, Iterator<?>>() {
                                                                             @Override
-                                                                            public String apply(final Attribute groupNameAttribute) {
+                                                                            public Iterator<?> apply(final Attribute groupNameAttribute) {
                                                                                 try {
                                                                                     final NamingEnumeration<?> enumeration = groupNameAttribute.getAll();
-                                                                                    if (enumeration.hasMore()) {
-                                                                                        return enumeration.next().toString();
-                                                                                    } else {
-                                                                                        return null;
-                                                                                    }
-                                                                                } catch (NamingException namingException) {
-                                                                                    log.warn("Unable to read group name", namingException);
+                                                                                    return Iterators.forEnumeration(enumeration);
+                                                                                } catch (final NamingException namingException) {
+                                                                                    log.warn("Unable to read group name(s)", namingException);
                                                                                     return null;
                                                                                 }
                                                                             }
                                                                         });
-        final Iterator<String> finalGroupNamesIterator = Iterators.<String>filter(groupNamesIterator, Predicates.notNull());
 
-        if (finalGroupNamesIterator.hasNext()) {
-            return finalGroupNamesIterator.next();
-        } else {
-            log.warn("Unable to find an attribute matching {}", groupNameId);
-            return null;
-        }
+        final Iterator<?> finalGroupNamesIterator = Iterators.filter(Iterators.concat(groupNamesIterator), Predicates.notNull());
+
+        return ImmutableList.<String>copyOf(Iterators.transform(finalGroupNamesIterator, Functions.toStringFunction()));
     }
 
     private Set<String> groupsPermissions(final Set<String> groups) {