]> git.netwichtig.de Git - user/henk/code/inspircd.git/blobdiff - src/modules/extra/m_sqlauth.cpp
pgsql should now work thx to added posibility to force a fd out of the socketengine...
[user/henk/code/inspircd.git] / src / modules / extra / m_sqlauth.cpp
index 8e2ea34f54c09c17b142dcf71d21601dfd5ed801..7130439e0537ad93e4e3d5a21486b6bb68c425fd 100644 (file)
@@ -2,13 +2,9 @@
  *       | Inspire Internet Relay Chat Daemon |
  *       +------------------------------------+
  *
- *  InspIRCd is copyright (C) 2002-2004 ChatSpike-Dev.
- *                       E-mail:
- *                <brain@chatspike.net>
- *               <Craig@chatspike.net>
- *               <omster@gmail.com>
- *     
- * Written by Craig Edwards, Craig McLure, and others.
+ *  InspIRCd: (C) 2002-2007 InspIRCd Development Team
+ * See: http://www.inspircd.org/wiki/index.php/Credits
+ *
  * This program is free but copyrighted software; see
  *            the file COPYING for details.
  *
  */
 
 #include <string>
-
 #include "users.h"
 #include "channels.h"
 #include "modules.h"
 #include "inspircd.h"
-#include "helperfuncs.h"
+
 #include "m_sqlv2.h"
 #include "m_sqlutils.h"
 
 /* $ModDesc: Allow/Deny connections based upon an arbitary SQL table */
+/* $ModDep: m_sqlv2.h m_sqlutils.h */
 
 class ModuleSQLAuth : public Module
 {
-       Server* Srv;
+       InspIRCd* Srv;
        Module* SQLutils;
+       Module* SQLprovider;
 
        std::string usertable;
        std::string userfield;
@@ -43,22 +40,23 @@ class ModuleSQLAuth : public Module
        bool verbose;
        
 public:
-       ModuleSQLAuth(Server* Me)
+       ModuleSQLAuth(InspIRCd* Me)
        : Module::Module(Me), Srv(Me)
        {
-               SQLutils = Srv->FindFeature("SQLutils");
-               
-               if(SQLutils)
-               {
-                       log(DEBUG, "Successfully got SQLutils pointer");
-               }
-               else
-               {
-                       log(DEFAULT, "ERROR: This module requires a module offering the 'SQLutils' feature (usually m_sqlutils.so). Please load it and try again.");
-                       throw ModuleException("This module requires a module offering the 'SQLutils' feature (usually m_sqlutils.so). Please load it and try again.");
-               }
-                               
-               OnRehash("");
+               ServerInstance->UseInterface("SQLutils");
+               ServerInstance->UseInterface("SQL");
+
+               SQLutils = ServerInstance->FindModule("m_sqlutils.so");
+               if (!SQLutils)
+                       throw ModuleException("Can't find m_sqlutils.so. Please load m_sqlutils.so before m_sqlauth.so.");
+
+               OnRehash(NULL,"");
+       }
+
+       virtual ~ModuleSQLAuth()
+       {
+               ServerInstance->DoneWithInterface("SQL");
+               ServerInstance->DoneWithInterface("SQLutils");
        }
 
        void Implements(char* List)
@@ -66,9 +64,9 @@ public:
                List[I_OnUserDisconnect] = List[I_OnCheckReady] = List[I_OnRequest] = List[I_OnRehash] = List[I_OnUserRegister] = 1;
        }
 
-       virtual void OnRehash(const std::string &parameter)
+       virtual void OnRehash(userrec* user, const std::string &parameter)
        {
-               ConfigReader Conf;
+               ConfigReader Conf(Srv);
                
                usertable       = Conf.ReadValue("sqlauth", "usertable", 0);    /* User table name */
                databaseid      = Conf.ReadValue("sqlauth", "dbid", 0);                 /* Database ID, given to the SQL service provider */
@@ -87,23 +85,26 @@ public:
                }
        }       
 
-       virtual void OnUserRegister(userrec* user)
+       virtual int OnUserRegister(userrec* user)
        {
                if ((allowpattern != "") && (Srv->MatchText(user->nick,allowpattern)))
-                       return;
+               {
+                       user->Extend("sqlauthed");
+                       return 0;
+               }
                
                if (!CheckCredentials(user))
                {
-                       Srv->QuitUser(user,killreason);
+                       userrec::QuitUser(Srv,user,killreason);
+                       return 1;
                }
+               return 0;
        }
 
        bool CheckCredentials(userrec* user)
        {
-               bool found;
                Module* target;
                
-               found = false;
                target = Srv->FindFeature("SQL");
                
                if(target)
@@ -115,13 +116,11 @@ public:
                                /* When we get the query response from the service provider we will be given an ID to play with,
                                 * just an ID number which is unique to this query. We need a way of associating that ID with a userrec
                                 * so we insert it into a map mapping the IDs to users.
-                                * This isn't quite enough though, as if the user quit while the query was in progress then when the result
-                                * came to be processed we'd get an invalid userrec* out of the map. Now we *could* solve this by watching
-                                * OnUserDisconnect() and iterating the map every time someone quits to make sure they didn't have any queries
-                                * in progress, but that would be relatively slow and inefficient. Instead (thanks to w00t ;p) we attach a list
-                                * of query IDs associated with it to the userrec, so in OnUserDisconnect() we can remove it immediately.
+                                * Thankfully m_sqlutils provides this, it will associate a ID with a user or channel, and if the user quits it removes the
+                                * association. This means that if the user quits during a query we will just get a failed lookup from m_sqlutils - telling
+                                * us to discard the query.
                                 */
-                               log(DEBUG, "Sent query, got given ID %lu", req.id);
+                               ServerInstance->Log(DEBUG, "Sent query, got given ID %lu", req.id);
                                
                                AssociateUser(this, SQLutils, req.id, user).Send();
                                        
@@ -129,30 +128,30 @@ public:
                        }
                        else
                        {
-                               log(DEBUG, "SQLrequest failed: %s", req.error.Str());
+                               ServerInstance->Log(DEBUG, "SQLrequest failed: %s", req.error.Str());
                        
                                if (verbose)
-                                       WriteOpers("Forbidden connection from %s!%s@%s (SQL query failed: %s)", user->nick, user->ident, user->host, req.error.Str());
+                                       Srv->WriteOpers("Forbidden connection from %s!%s@%s (SQL query failed: %s)", user->nick, user->ident, user->host, req.error.Str());
                        
                                return false;
                        }
                }
                else
                {
-                       log(SPARSE, "WARNING: Couldn't find SQL provider module. NOBODY will be allowed to connect until it comes back unless they match an exception");
+                       ServerInstance->Log(SPARSE, "WARNING: Couldn't find SQL provider module. NOBODY will be allowed to connect until it comes back unless they match an exception");
                        return false;
                }
        }
        
        virtual char* OnRequest(Request* request)
        {
-               if(strcmp(SQLRESID, request->GetData()) == 0)
+               if(strcmp(SQLRESID, request->GetId()) == 0)
                {
                        SQLresult* res;
                
                        res = static_cast<SQLresult*>(request);
                        
-                       log(DEBUG, "Got SQL result (%s) with ID %lu", res->GetData(), res->id);
+                       ServerInstance->Log(DEBUG, "Got SQL result (%s) with ID %lu", res->GetId(), res->id);
                        
                        userrec* user = GetAssocUser(this, SQLutils, res->id).S().user;
                        UnAssociate(this, SQLutils, res->id).S();
@@ -161,8 +160,8 @@ public:
                        {
                                if(res->error.Id() == NO_ERROR)
                                {                               
-                                       log(DEBUG, "Associated query ID %lu with user %s", res->id, user->nick);                        
-                                       log(DEBUG, "Got result with %d rows and %d columns", res->Rows(), res->Cols());
+                                       ServerInstance->Log(DEBUG, "Associated query ID %lu with user %s", res->id, user->nick);                        
+                                       ServerInstance->Log(DEBUG, "Got result with %d rows and %d columns", res->Rows(), res->Cols());
                        
                                        if(res->Rows())
                                        {
@@ -172,26 +171,31 @@ public:
                                        else if (verbose)
                                        {
                                                /* No rows in result, this means there was no record matching the user */
-                                               WriteOpers("Forbidden connection from %s!%s@%s (SQL query returned no matches)", user->nick, user->ident, user->host);
+                                               Srv->WriteOpers("Forbidden connection from %s!%s@%s (SQL query returned no matches)", user->nick, user->ident, user->host);
                                                user->Extend("sqlauth_failed");
                                        }
                                }
                                else if (verbose)
                                {
-                                       log(DEBUG, "Query failed: %s", res->error.Str());
-                                       WriteOpers("Forbidden connection from %s!%s@%s (SQL query failed: %s)", user->nick, user->ident, user->host, res->error.Str());
+                                       ServerInstance->Log(DEBUG, "Query failed: %s", res->error.Str());
+                                       Srv->WriteOpers("Forbidden connection from %s!%s@%s (SQL query failed: %s)", user->nick, user->ident, user->host, res->error.Str());
                                        user->Extend("sqlauth_failed");
                                }
                        }
                        else
                        {
-                               log(DEBUG, "Got query with unknown ID, this probably means the user quit while the query was in progress");
+                               ServerInstance->Log(DEBUG, "Got query with unknown ID, this probably means the user quit while the query was in progress");
+                               return NULL;
+                       }
+
+                       if (!user->GetExt("sqlauthed"))
+                       {
+                               userrec::QuitUser(Srv,user,killreason);
                        }
-               
                        return SQLSUCCESS;
                }
                
-               log(DEBUG, "Got unsupported API version string: %s", request->GetData());
+               ServerInstance->Log(DEBUG, "Got unsupported API version string: %s", request->GetId());
                
                return NULL;
        }
@@ -204,22 +208,12 @@ public:
        
        virtual bool OnCheckReady(userrec* user)
        {
-               if(user->GetExt("sqlauth_failed"))
-               {
-                       Srv->QuitUser(user,killreason);
-                       return false;
-               }
-               
                return user->GetExt("sqlauthed");
        }
 
-       virtual ~ModuleSQLAuth()
-       {
-       }
-       
        virtual Version GetVersion()
        {
-               return Version(1,0,1,0,VF_VENDOR);
+               return Version(1,1,1,0,VF_VENDOR,API_VERSION);
        }
        
 };
@@ -235,7 +229,7 @@ class ModuleSQLAuthFactory : public ModuleFactory
        {
        }
        
-       virtual Module * CreateModule(Server* Me)
+       virtual Module * CreateModule(InspIRCd* Me)
        {
                return new ModuleSQLAuth(Me);
        }