]> git.netwichtig.de Git - user/henk/code/inspircd.git/blobdiff - src/modules/m_sqloper.cpp
Delete modewatchers when unloading modules that use them to keep the server from...
[user/henk/code/inspircd.git] / src / modules / m_sqloper.cpp
index 5f0df4c1064984d568991ebf1fb549d9d3e2d13e..3ee47953b4de4beeee066fad143ebc0a0292d96e 100644 (file)
@@ -27,50 +27,22 @@ class ModuleSQLOper : public Module
        LocalStringExt saved_pass;
        Module* SQLutils;
        std::string databaseid;
-       irc::string hashtype;
-       hashymodules hashers;
-       bool diduseiface;
+       std::string hashtype;
        parameterlist names;
 
 public:
        ModuleSQLOper() : saved_user("sqloper_user", this), saved_pass("sqloper_pass", this)
        {
-               ServerInstance->Modules->UseInterface("SQLutils");
-               ServerInstance->Modules->UseInterface("SQL");
-               ServerInstance->Modules->UseInterface("HashRequest");
-
                OnRehash(NULL);
 
-               diduseiface = false;
-
-               /* Find all modules which implement the interface 'HashRequest' */
-               modulelist* ml = ServerInstance->Modules->FindInterface("HashRequest");
-
-               /* Did we find any modules? */
-               if (ml)
-               {
-                       /* Yes, enumerate them all to find out the hashing algorithm name */
-                       for (modulelist::iterator m = ml->begin(); m != ml->end(); m++)
-                       {
-                               /* Make a request to it for its name, its implementing
-                                * HashRequest so we know its safe to do this
-                                */
-                               std::string name = HashNameRequest(this, *m).Send();
-                               /* Build a map of them */
-                               hashers[name.c_str()] = *m;
-                               names.push_back(name);
-                       }
-                       /* UseInterface doesn't do anything if there are no providers, so we'll have to call it later if a module gets loaded later on. */
-                       diduseiface = true;
-                       ServerInstance->Modules->UseInterface("HashRequest");
-               }
-
                SQLutils = ServerInstance->Modules->Find("m_sqlutils.so");
                if (!SQLutils)
                        throw ModuleException("Can't find m_sqlutils.so. Please load m_sqlutils.so before m_sqloper.so.");
 
                Implementation eventlist[] = { I_OnRehash, I_OnPreCommand, I_OnLoadModule };
-               ServerInstance->Modules->Attach(eventlist, this, 4);
+               ServerInstance->Modules->Attach(eventlist, this, 3);
+               ServerInstance->Modules->AddService(saved_user);
+               ServerInstance->Modules->AddService(saved_pass);
        }
 
        bool OneOfMatches(const char* host, const char* ip, const char* hostlist)
@@ -87,37 +59,12 @@ public:
                return false;
        }
 
-       virtual void OnLoadModule(Module* mod, const std::string& name)
-       {
-               if (ServerInstance->Modules->ModuleHasInterface(mod, "HashRequest"))
-               {
-                       ServerInstance->Logs->Log("m_sqloper",DEBUG, "Post-load registering hasher: %s", name.c_str());
-                       std::string sname = HashNameRequest(this, mod).Send();
-                       hashers[sname.c_str()] = mod;
-                       names.push_back(sname);
-                       if (!diduseiface)
-                       {
-                               ServerInstance->Modules->UseInterface("HashRequest");
-                               diduseiface = true;
-                       }
-               }
-       }
-
-       virtual ~ModuleSQLOper()
-       {
-               ServerInstance->Modules->DoneWithInterface("SQL");
-               ServerInstance->Modules->DoneWithInterface("SQLutils");
-               if (diduseiface)
-                       ServerInstance->Modules->DoneWithInterface("HashRequest");
-       }
-
-
        virtual void OnRehash(User* user)
        {
                ConfigReader Conf;
 
                databaseid = Conf.ReadValue("sqloper", "dbid", 0); /* Database ID of a database configured for the service provider module */
-               hashtype = assign(Conf.ReadValue("sqloper", "hash", 0));
+               hashtype = Conf.ReadValue("sqloper", "hash", 0);
        }
 
        virtual ModResult OnPreCommand(std::string &command, std::vector<std::string> &parameters, User *user, bool validated, const std::string &original_line)
@@ -139,20 +86,14 @@ public:
 
        bool LookupOper(User* user, const std::string &username, const std::string &password)
        {
-               Module* target;
-
-               target = ServerInstance->Modules->FindFeature("SQL");
-
-               if (target)
+               ServiceProvider* prov = ServerInstance->Modules->FindService(SERVICE_DATA, "SQL");
+               if (prov)
                {
-                       hashymodules::iterator x = hashers.find(hashtype);
-                       if (x == hashers.end())
-                               return false;
+                       Module* target = prov->creator;
+                       HashProvider* hash = ServerInstance->Modules->FindDataService<HashProvider>("hash/" + hashtype);
 
-                       /* Reset hash module first back to MD5 standard state */
-                       HashResetRequest(this, x->second).Send();
                        /* Make an MD5 hash of the password for using in the query */
-                       std::string md5_pass_hash = HashSumRequest(this, x->second, password.c_str()).Send();
+                       std::string md5_pass_hash = hash ? hash->hexsum(password) : password;
 
                        /* We generate our own sum here because some database providers (e.g. SQLite) dont have a builtin md5/sha256 function,
                         * also hashing it in the module and only passing a remote query containing a hash is more secure.
@@ -160,26 +101,19 @@ public:
                        SQLrequest req = SQLrequest(this, target, databaseid,
                                        SQLquery("SELECT username, password, hostname, type FROM ircd_opers WHERE username = '?' AND password='?'") % username % md5_pass_hash);
 
-                       if (req.Send())
-                       {
-                               /* When we get the query response from the service provider we will be given an ID to play with,
-                                * just an ID number which is unique to this query. We need a way of associating that ID with a User
-                                * so we insert it into a map mapping the IDs to users.
-                                * Thankfully m_sqlutils provides this, it will associate a ID with a user or channel, and if the user quits it removes the
-                                * association. This means that if the user quits during a query we will just get a failed lookup from m_sqlutils - telling
-                                * us to discard the query.
-                                */
-                               AssociateUser(this, SQLutils, req.id, user).Send();
-
-                               saved_user.set(user, username);
-                               saved_pass.set(user, password);
+                       /* When we get the query response from the service provider we will be given an ID to play with,
+                        * just an ID number which is unique to this query. We need a way of associating that ID with a User
+                        * so we insert it into a map mapping the IDs to users.
+                        * Thankfully m_sqlutils provides this, it will associate a ID with a user or channel, and if the user quits it removes the
+                        * association. This means that if the user quits during a query we will just get a failed lookup from m_sqlutils - telling
+                        * us to discard the query.
+                        */
+                       AssociateUser(this, SQLutils, req.id, user).Send();
 
-                               return true;
-                       }
-                       else
-                       {
-                               return false;
-                       }
+                       saved_user.set(user, username);
+                       saved_pass.set(user, password);
+
+                       return true;
                }
                else
                {
@@ -188,11 +122,11 @@ public:
                }
        }
 
-       const char* OnRequest(Request* request)
+       void OnRequest(Request& request)
        {
-               if (strcmp(SQLRESID, request->GetId()) == 0)
+               if (strcmp(SQLRESID, request.id) == 0)
                {
-                       SQLresult* res = static_cast<SQLresult*>(request);
+                       SQLresult* res = static_cast<SQLresult*>(&request);
 
                        User* user = GetAssocUser(this, SQLutils, res->id).S().user;
                        UnAssociate(this, SQLutils, res->id).S();
@@ -225,7 +159,6 @@ public:
                                                                /* If/when one of the rows matches, stop checking and return */
                                                                saved_user.unset(user);
                                                                saved_pass.unset(user);
-                                                               return SQLSUCCESS;
                                                        }
                                                        if (tried_user && tried_pass)
                                                        {
@@ -264,11 +197,7 @@ public:
 
                                }
                        }
-
-                       return SQLSUCCESS;
                }
-
-               return NULL;
        }
 
        void LoginFail(User* user, const std::string &username, const std::string &pass)
@@ -290,26 +219,21 @@ public:
 
        bool OperUser(User* user, const std::string &pattern, const std::string &type)
        {
-               ConfigReader Conf;
+               OperIndex::iterator iter = ServerInstance->Config->oper_blocks.find(" " + type);
+               if (iter == ServerInstance->Config->oper_blocks.end())
+                       return false;
+               OperInfo* ifo = iter->second;
 
-               for (int j = 0; j < Conf.Enumerate("type"); j++)
-               {
-                       std::string tname = Conf.ReadValue("type","name",j);
-                       std::string hostname(user->ident);
+               std::string hostname(user->ident);
 
-                       hostname.append("@").append(user->host);
+               hostname.append("@").append(user->host);
 
-                       if ((tname == type) && OneOfMatches(hostname.c_str(), user->GetIPString(), pattern.c_str()))
-                       {
-                               /* Opertype and host match, looks like this is it. */
-                               std::string operhost = Conf.ReadValue("type", "host", j);
-
-                               if (operhost.size())
-                                       user->ChangeDisplayedHost(operhost.c_str());
+               if (OneOfMatches(hostname.c_str(), user->GetIPString(), pattern.c_str()))
+               {
+                       /* Opertype and host match, looks like this is it. */
 
-                               user->Oper(type, tname);
-                               return true;
-                       }
+                       user->Oper(ifo);
+                       return true;
                }
 
                return false;
@@ -317,7 +241,7 @@ public:
 
        Version GetVersion()
        {
-               return Version("Allows storage of oper credentials in an SQL table", VF_VENDOR, API_VERSION);
+               return Version("Allows storage of oper credentials in an SQL table", VF_VENDOR);
        }
 
 };