]> git.netwichtig.de Git - user/henk/code/inspircd.git/blobdiff - src/modules/extra/m_ldapauth.cpp
m_ldapauth: fix providing username in PASS
[user/henk/code/inspircd.git] / src / modules / extra / m_ldapauth.cpp
index 2102b7492e55f11a511c6b59712a0d7ea15df35f..405bab082bef6df6522de1327408505ec533c9ef 100644 (file)
 #include <ldap.h>
 
 #ifdef _WIN32
-# pragma comment(lib, "ldap.lib")
-# pragma comment(lib, "lber.lib")
+# pragma comment(lib, "libldap.lib")
+# pragma comment(lib, "liblber.lib")
 #endif
 
 /* $ModDesc: Allow/Deny connections based upon answer from LDAP server */
 /* $LinkerFlags: -lldap */
 
+struct RAIILDAPString
+{
+       char *str;
+
+       RAIILDAPString(char *Str)
+               : str(Str)
+       {
+       }
+
+       ~RAIILDAPString()
+       {
+               ldap_memfree(str);
+       }
+
+       operator char*()
+       {
+               return str;
+       }
+
+       operator std::string()
+       {
+               return str;
+       }
+};
+
+struct RAIILDAPMessage
+{
+       RAIILDAPMessage()
+       {
+       }
+
+       ~RAIILDAPMessage()
+       {
+               dealloc();
+       }
+
+       void dealloc()
+       {
+               ldap_msgfree(msg);
+       }
+
+       operator LDAPMessage*()
+       {
+               return msg;
+       }
+
+       LDAPMessage **operator &()
+       {
+               return &msg;
+       }
+
+       LDAPMessage *msg;
+};
+
 class ModuleLDAPAuth : public Module
 {
        LocalIntExt ldapAuthed;
+       LocalStringExt ldapVhost;
        std::string base;
        std::string attribute;
        std::string ldapserver;
@@ -48,6 +103,7 @@ class ModuleLDAPAuth : public Module
        std::string killreason;
        std::string username;
        std::string password;
+       std::string vhost;
        std::vector<std::string> whitelistedcidrs;
        std::vector<std::pair<std::string, std::string> > requiredattributes;
        int searchscope;
@@ -56,15 +112,19 @@ class ModuleLDAPAuth : public Module
        LDAP *conn;
 
 public:
-       ModuleLDAPAuth() : ldapAuthed("ldapauth", this)
+       ModuleLDAPAuth()
+               : ldapAuthed("ldapauth", this)
+               , ldapVhost("ldapauth_vhost", this)
        {
                conn = NULL;
        }
 
        void init()
        {
-               Implementation eventlist[] = { I_OnCheckReady, I_OnRehash, I_OnUserRegister };
-               ServerInstance->Modules->Attach(eventlist, this, 3);
+               ServerInstance->Modules->AddService(ldapAuthed);
+               ServerInstance->Modules->AddService(ldapVhost);
+               Implementation eventlist[] = { I_OnCheckReady, I_OnRehash,I_OnUserRegister, I_OnUserConnect };
+               ServerInstance->Modules->Attach(eventlist, this, sizeof(eventlist)/sizeof(Implementation));
                OnRehash(NULL);
        }
 
@@ -88,6 +148,7 @@ public:
                std::string scope       = tag->getString("searchscope");
                username                = tag->getString("binddn");
                password                = tag->getString("bindauth");
+               vhost                   = tag->getString("host");
                verbose                 = tag->getBool("verbose");              /* Set to true if failed connects should be reported to operators */
                useusername             = tag->getBool("userfield");
 
@@ -147,6 +208,42 @@ public:
                return true;
        }
 
+       std::string SafeReplace(const std::string &text, std::map<std::string,
+                       std::string> &replacements)
+       {
+               std::string result;
+               result.reserve(MAXBUF);
+
+               for (unsigned int i = 0; i < text.length(); ++i) {
+                       char c = text[i];
+                       if (c == '$') {
+                               // find the first nonalpha
+                               i++;
+                               unsigned int start = i;
+
+                               while (i < text.length() - 1 && isalpha(text[i + 1]))
+                                       ++i;
+
+                               std::string key = text.substr(start, (i - start) + 1);
+                               result.append(replacements[key]);
+                       } else {
+                               result.push_back(c);
+                       }
+               }
+
+          return result;
+       }
+
+       virtual void OnUserConnect(LocalUser *user)
+       {
+               std::string* cc = ldapVhost.get(user);
+               if (cc)
+               {
+                       user->ChangeDisplayedHost(cc->c_str());
+                       ldapVhost.unset(user);
+               }
+       }
+
        ModResult OnUserRegister(LocalUser* user)
        {
                if ((!allowpattern.empty()) && (InspIRCd::Match(user->nick,allowpattern)))
@@ -212,87 +309,102 @@ public:
                        }
                }
 
-               LDAPMessage *msg, *entry;
-               std::string what = (attribute + "=" + (useusername ? user->ident : user->nick));
-               if ((res = ldap_search_ext_s(conn, base.c_str(), searchscope, what.c_str(), NULL, 0, NULL, NULL, NULL, 0, &msg)) != LDAP_SUCCESS)
+               RAIILDAPMessage msg;
+               std::string what;
+               std::string::size_type pos = user->password.find(':');
+               // If a username is provided in PASS, use it, othewrise user their nick or ident
+               if (pos != std::string::npos)
                {
-                       // Do a second search, based on password, if it contains a :
-                       // That is, PASS <user>:<password> will work.
-                       size_t pos = user->password.find(":");
-                       if (pos != std::string::npos)
-                       {
-                               std::string cutpassword = user->password.substr(0, pos);
-                               res = ldap_search_ext_s(conn, base.c_str(), searchscope, cutpassword.c_str(), NULL, 0, NULL, NULL, NULL, 0, &msg);
-
-                               if (res == LDAP_SUCCESS)
-                               {
-                                       // Trim the user: prefix, leaving just 'pass' for later password check
-                                       user->password = user->password.substr(pos + 1);
-                               }
-                       }
+                       what = (attribute + "=" + user->password.substr(0, pos));
 
-                       // It may have found based on user:pass check above.
-                       if (res != LDAP_SUCCESS)
-                       {
-                               if (verbose)
-                                       ServerInstance->SNO->WriteToSnoMask('c', "Forbidden connection from %s (LDAP search failed: %s)", user->GetFullRealHost().c_str(), ldap_err2string(res));
-                               return false;
-                       }
+                       // Trim the user: prefix, leaving just 'pass' for later password check
+                       user->password = user->password.substr(pos + 1);
+               }
+               else
+               {
+                       what = (attribute + "=" + (useusername ? user->ident : user->nick));
+               }
+               if ((res = ldap_search_ext_s(conn, base.c_str(), searchscope, what.c_str(), NULL, 0, NULL, NULL, NULL, 0, &msg)) != LDAP_SUCCESS)
+               {
+                       if (verbose)
+                               ServerInstance->SNO->WriteToSnoMask('c', "Forbidden connection from %s (LDAP search failed: %s)", user->GetFullRealHost().c_str(), ldap_err2string(res));
+                       return false;
                }
                if (ldap_count_entries(conn, msg) > 1)
                {
                        if (verbose)
                                ServerInstance->SNO->WriteToSnoMask('c', "Forbidden connection from %s (LDAP search returned more than one result: %s)", user->GetFullRealHost().c_str(), ldap_err2string(res));
-                       ldap_msgfree(msg);
                        return false;
                }
+
+               LDAPMessage *entry;
                if ((entry = ldap_first_entry(conn, msg)) == NULL)
                {
                        if (verbose)
                                ServerInstance->SNO->WriteToSnoMask('c', "Forbidden connection from %s (LDAP search returned no results: %s)", user->GetFullRealHost().c_str(), ldap_err2string(res));
-                       ldap_msgfree(msg);
                        return false;
                }
                cred.bv_val = (char*)user->password.data();
                cred.bv_len = user->password.length();
-               if ((res = ldap_sasl_bind_s(conn, ldap_get_dn(conn, entry), LDAP_SASL_SIMPLE, &cred, NULL, NULL, NULL)) != LDAP_SUCCESS)
+               RAIILDAPString DN(ldap_get_dn(conn, entry));
+               if ((res = ldap_sasl_bind_s(conn, DN, LDAP_SASL_SIMPLE, &cred, NULL, NULL, NULL)) != LDAP_SUCCESS)
                {
                        if (verbose)
                                ServerInstance->SNO->WriteToSnoMask('c', "Forbidden connection from %s (%s)", user->GetFullRealHost().c_str(), ldap_err2string(res));
-                       ldap_msgfree(msg);
                        return false;
                }
 
-               if (requiredattributes.empty())
+               if (!requiredattributes.empty())
                {
-                       ldap_msgfree(msg);
-                       ldapAuthed.set(user,1);
-                       return true;
-               }
+                       bool authed = false;
 
-               bool authed = false;
+                       for (std::vector<std::pair<std::string, std::string> >::const_iterator it = requiredattributes.begin(); it != requiredattributes.end(); ++it)
+                       {
+                               const std::string &attr = it->first;
+                               const std::string &val = it->second;
 
-               for (std::vector<std::pair<std::string, std::string> >::const_iterator it = requiredattributes.begin(); it != requiredattributes.end(); ++it)
-               {
-                       const std::string &attr = it->first;
-                       const std::string &val = it->second;
+                               struct berval attr_value;
+                               attr_value.bv_val = const_cast<char*>(val.c_str());
+                               attr_value.bv_len = val.length();
 
-                       struct berval attr_value;
-                       attr_value.bv_val = const_cast<char*>(val.c_str());
-                       attr_value.bv_len = val.length();
+                               ServerInstance->Logs->Log("m_ldapauth", DEBUG, "LDAP compare: %s=%s", attr.c_str(), val.c_str());
 
-                       ServerInstance->Logs->Log("m_ldapauth", DEBUG, "LDAP compare: %s=%s", attr.c_str(), val.c_str());
+                               authed = (ldap_compare_ext_s(conn, DN, attr.c_str(), &attr_value, NULL, NULL) == LDAP_COMPARE_TRUE);
 
-                       authed = (ldap_compare_ext_s(conn, ldap_get_dn(conn, entry), attr.c_str(), &attr_value, NULL, NULL) == LDAP_COMPARE_TRUE);
+                               if (authed)
+                                       break;
+                       }
 
-                       if (authed)
-                               break;
+                       if (!authed)
+                       {
+                               if (verbose)
+                                       ServerInstance->SNO->WriteToSnoMask('c', "Forbidden connection from %s (Lacks required LDAP attributes)", user->GetFullRealHost().c_str());
+                               return false;
+                       }
                }
 
-               ldap_msgfree(msg);
+               if (!vhost.empty())
+               {
+                       irc::commasepstream stream(DN);
 
-               if (!authed)
-                       return false;
+                       // mashed map of key:value parts of the DN
+                       std::map<std::string, std::string> dnParts;
+
+                       std::string dnPart;
+                       while (stream.GetToken(dnPart))
+                       {
+                               pos = dnPart.find('=');
+                               if (pos == std::string::npos) // malformed
+                                       continue;
+
+                               std::string key = dnPart.substr(0, pos);
+                               std::string value = dnPart.substr(pos + 1, dnPart.length() - pos + 1); // +1s to skip the = itself
+                               dnParts[key] = value;
+                       }
+
+                       // change host according to config key
+                       ldapVhost.set(user, SafeReplace(vhost, dnParts));
+               }
 
                ldapAuthed.set(user,1);
                return true;