]> git.netwichtig.de Git - user/henk/code/inspircd.git/blobdiff - src/modules/extra/m_sqlauth.cpp
Try this as the ssl crash fix
[user/henk/code/inspircd.git] / src / modules / extra / m_sqlauth.cpp
index 8e2ea34f54c09c17b142dcf71d21601dfd5ed801..4d0cc6a761139b572f11118b881d19d7e184aab9 100644 (file)
  *       | Inspire Internet Relay Chat Daemon |
  *       +------------------------------------+
  *
- *  InspIRCd is copyright (C) 2002-2004 ChatSpike-Dev.
- *                       E-mail:
- *                <brain@chatspike.net>
- *               <Craig@chatspike.net>
- *               <omster@gmail.com>
- *     
- * Written by Craig Edwards, Craig McLure, and others.
+ *  InspIRCd: (C) 2002-2009 InspIRCd Development Team
+ * See: http://www.inspircd.org/wiki/index.php/Credits
+ *
  * This program is free but copyrighted software; see
  *            the file COPYING for details.
  *
  * ---------------------------------------------------
  */
 
-#include <string>
-
-#include "users.h"
-#include "channels.h"
-#include "modules.h"
 #include "inspircd.h"
-#include "helperfuncs.h"
 #include "m_sqlv2.h"
 #include "m_sqlutils.h"
+#include "m_hash.h"
 
 /* $ModDesc: Allow/Deny connections based upon an arbitary SQL table */
+/* $ModDep: m_sqlv2.h m_sqlutils.h m_hash.h */
 
 class ModuleSQLAuth : public Module
 {
-       Server* Srv;
        Module* SQLutils;
+       Module* SQLprovider;
 
-       std::string usertable;
-       std::string userfield;
-       std::string passfield;
-       std::string encryption;
+       std::string freeformquery;
        std::string killreason;
        std::string allowpattern;
        std::string databaseid;
-       
+
        bool verbose;
-       
+
 public:
-       ModuleSQLAuth(Server* Me)
-       : Module::Module(Me), Srv(Me)
+       ModuleSQLAuth(InspIRCd* Me)
+       : Module(Me)
        {
-               SQLutils = Srv->FindFeature("SQLutils");
-               
-               if(SQLutils)
-               {
-                       log(DEBUG, "Successfully got SQLutils pointer");
-               }
-               else
-               {
-                       log(DEFAULT, "ERROR: This module requires a module offering the 'SQLutils' feature (usually m_sqlutils.so). Please load it and try again.");
-                       throw ModuleException("This module requires a module offering the 'SQLutils' feature (usually m_sqlutils.so). Please load it and try again.");
-               }
-                               
-               OnRehash("");
+               ServerInstance->Modules->UseInterface("SQLutils");
+               ServerInstance->Modules->UseInterface("SQL");
+
+               SQLutils = ServerInstance->Modules->Find("m_sqlutils.so");
+               if (!SQLutils)
+                       throw ModuleException("Can't find m_sqlutils.so. Please load m_sqlutils.so before m_sqlauth.so.");
+
+               SQLprovider = ServerInstance->Modules->FindFeature("SQL");
+               if (!SQLprovider)
+                       throw ModuleException("Can't find an SQL provider module. Please load one before attempting to load m_sqlauth.");
+
+               OnRehash(NULL,"");
+               Implementation eventlist[] = { I_OnUserDisconnect, I_OnCheckReady, I_OnRequest, I_OnRehash, I_OnUserRegister };
+               ServerInstance->Modules->Attach(eventlist, this, 5);
        }
 
-       void Implements(char* List)
+       virtual ~ModuleSQLAuth()
        {
-               List[I_OnUserDisconnect] = List[I_OnCheckReady] = List[I_OnRequest] = List[I_OnRehash] = List[I_OnUserRegister] = 1;
+               ServerInstance->Modules->DoneWithInterface("SQL");
+               ServerInstance->Modules->DoneWithInterface("SQLutils");
        }
 
-       virtual void OnRehash(const std::string &parameter)
+
+       virtual void OnRehash(User* user, const std::string &parameter)
        {
-               ConfigReader Conf;
-               
-               usertable       = Conf.ReadValue("sqlauth", "usertable", 0);    /* User table name */
+               ConfigReader Conf(ServerInstance);
+
                databaseid      = Conf.ReadValue("sqlauth", "dbid", 0);                 /* Database ID, given to the SQL service provider */
-               userfield       = Conf.ReadValue("sqlauth", "userfield", 0);    /* Field name where username can be found */
-               passfield       = Conf.ReadValue("sqlauth", "passfield", 0);    /* Field name where password can be found */
+               freeformquery   = Conf.ReadValue("sqlauth", "query", 0);        /* Field name where username can be found */
                killreason      = Conf.ReadValue("sqlauth", "killreason", 0);   /* Reason to give when access is denied to a user (put your reg details here) */
-               allowpattern= Conf.ReadValue("sqlauth", "allowpattern",0 );     /* Allow nicks matching this pattern without requiring auth */
-               encryption      = Conf.ReadValue("sqlauth", "encryption", 0);   /* Name of sql function used to encrypt password, e.g. "md5" or "passwd".
-                                                                                                                                        * define, but leave blank if no encryption is to be used.
-                                                                                                                                        */
+               allowpattern    = Conf.ReadValue("sqlauth", "allowpattern",0 ); /* Allow nicks matching this pattern without requiring auth */
                verbose         = Conf.ReadFlag("sqlauth", "verbose", 0);               /* Set to true if failed connects should be reported to operators */
-               
-               if (encryption.find("(") == std::string::npos)
+       }
+
+       virtual int OnUserRegister(User* user)
+       {
+               if ((!allowpattern.empty()) && (InspIRCd::Match(user->nick,allowpattern)))
                {
-                       encryption.append("(");
+                       user->Extend("sqlauthed");
+                       return 0;
                }
-       }       
 
-       virtual void OnUserRegister(userrec* user)
-       {
-               if ((allowpattern != "") && (Srv->MatchText(user->nick,allowpattern)))
-                       return;
-               
                if (!CheckCredentials(user))
                {
-                       Srv->QuitUser(user,killreason);
+                       ServerInstance->Users->QuitUser(user, killreason);
+                       return 1;
                }
+               return 0;
        }
 
-       bool CheckCredentials(userrec* user)
+       bool CheckCredentials(User* user)
        {
-               bool found;
-               Module* target;
-               
-               found = false;
-               target = Srv->FindFeature("SQL");
-               
-               if(target)
+               std::string thisquery = freeformquery;
+               std::string safepass = user->password;
+               std::string safegecos = user->fullname;
+
+               /* Search and replace the escaped nick and escaped pass into the query */
+
+               SearchAndReplace(safepass, std::string("\""), std::string("\\\""));
+               SearchAndReplace(safegecos, std::string("\""), std::string("\\\""));
+
+               SearchAndReplace(thisquery, std::string("$nick"), user->nick);
+               SearchAndReplace(thisquery, std::string("$pass"), safepass);
+               SearchAndReplace(thisquery, std::string("$host"), user->host);
+               SearchAndReplace(thisquery, std::string("$ip"), std::string(user->GetIPString()));
+               SearchAndReplace(thisquery, std::string("$gecos"), safegecos);
+               SearchAndReplace(thisquery, std::string("$ident"), user->ident);
+               SearchAndReplace(thisquery, std::string("$server"), std::string(user->server));
+               SearchAndReplace(thisquery, std::string("$uuid"), user->uuid);
+
+               Module* HashMod = ServerInstance->Modules->Find("m_md5.so");
+
+               if (HashMod)
                {
-                       SQLrequest req = SQLreq(this, target, databaseid, "SELECT ? FROM ? WHERE ? = '?' AND ? = ?'?')", userfield, usertable, userfield, user->nick, passfield, encryption, user->password);
-                       
-                       if(req.Send())
-                       {
-                               /* When we get the query response from the service provider we will be given an ID to play with,
-                                * just an ID number which is unique to this query. We need a way of associating that ID with a userrec
-                                * so we insert it into a map mapping the IDs to users.
-                                * This isn't quite enough though, as if the user quit while the query was in progress then when the result
-                                * came to be processed we'd get an invalid userrec* out of the map. Now we *could* solve this by watching
-                                * OnUserDisconnect() and iterating the map every time someone quits to make sure they didn't have any queries
-                                * in progress, but that would be relatively slow and inefficient. Instead (thanks to w00t ;p) we attach a list
-                                * of query IDs associated with it to the userrec, so in OnUserDisconnect() we can remove it immediately.
-                                */
-                               log(DEBUG, "Sent query, got given ID %lu", req.id);
-                               
-                               AssociateUser(this, SQLutils, req.id, user).Send();
-                                       
-                               return true;
-                       }
-                       else
-                       {
-                               log(DEBUG, "SQLrequest failed: %s", req.error.Str());
-                       
-                               if (verbose)
-                                       WriteOpers("Forbidden connection from %s!%s@%s (SQL query failed: %s)", user->nick, user->ident, user->host, req.error.Str());
-                       
-                               return false;
-                       }
+                       HashResetRequest(this, HashMod).Send();
+                       SearchAndReplace(thisquery, std::string("$md5pass"), std::string(HashSumRequest(this, HashMod, user->password).Send()));
+               }
+
+               HashMod = ServerInstance->Modules->Find("m_sha256.so");
+
+               if (HashMod)
+               {
+                       HashResetRequest(this, HashMod).Send();
+                       SearchAndReplace(thisquery, std::string("$sha256pass"), std::string(HashSumRequest(this, HashMod, user->password).Send()));
+               }
+
+               /* Build the query */
+               SQLrequest req = SQLrequest(this, SQLprovider, databaseid, SQLquery(thisquery));
+
+               if(req.Send())
+               {
+                       /* When we get the query response from the service provider we will be given an ID to play with,
+                        * just an ID number which is unique to this query. We need a way of associating that ID with a User
+                        * so we insert it into a map mapping the IDs to users.
+                        * Thankfully m_sqlutils provides this, it will associate a ID with a user or channel, and if the user quits it removes the
+                        * association. This means that if the user quits during a query we will just get a failed lookup from m_sqlutils - telling
+                        * us to discard the query.
+                        */
+                       AssociateUser(this, SQLutils, req.id, user).Send();
+
+                       return true;
                }
                else
                {
-                       log(SPARSE, "WARNING: Couldn't find SQL provider module. NOBODY will be allowed to connect until it comes back unless they match an exception");
+                       if (verbose)
+                               ServerInstance->SNO->WriteToSnoMask('A', "Forbidden connection from %s!%s@%s (SQL query failed: %s)", user->nick.c_str(), user->ident.c_str(), user->host.c_str(), req.error.Str());
                        return false;
                }
        }
-       
-       virtual char* OnRequest(Request* request)
+
+       virtual const char* OnRequest(Request* request)
        {
-               if(strcmp(SQLRESID, request->GetData()) == 0)
+               if(strcmp(SQLRESID, request->GetId()) == 0)
                {
-                       SQLresult* res;
-               
-                       res = static_cast<SQLresult*>(request);
-                       
-                       log(DEBUG, "Got SQL result (%s) with ID %lu", res->GetData(), res->id);
-                       
-                       userrec* user = GetAssocUser(this, SQLutils, res->id).S().user;
+                       SQLresult* res = static_cast<SQLresult*>(request);
+
+                       User* user = GetAssocUser(this, SQLutils, res->id).S().user;
                        UnAssociate(this, SQLutils, res->id).S();
-                       
+
                        if(user)
                        {
-                               if(res->error.Id() == NO_ERROR)
-                               {                               
-                                       log(DEBUG, "Associated query ID %lu with user %s", res->id, user->nick);                        
-                                       log(DEBUG, "Got result with %d rows and %d columns", res->Rows(), res->Cols());
-                       
+                               if(res->error.Id() == SQL_NO_ERROR)
+                               {
                                        if(res->Rows())
                                        {
                                                /* We got a row in the result, this is enough really */
@@ -172,78 +166,46 @@ public:
                                        else if (verbose)
                                        {
                                                /* No rows in result, this means there was no record matching the user */
-                                               WriteOpers("Forbidden connection from %s!%s@%s (SQL query returned no matches)", user->nick, user->ident, user->host);
+                                               ServerInstance->SNO->WriteToSnoMask('A', "Forbidden connection from %s!%s@%s (SQL query returned no matches)", user->nick.c_str(), user->ident.c_str(), user->host.c_str());
                                                user->Extend("sqlauth_failed");
                                        }
                                }
                                else if (verbose)
                                {
-                                       log(DEBUG, "Query failed: %s", res->error.Str());
-                                       WriteOpers("Forbidden connection from %s!%s@%s (SQL query failed: %s)", user->nick, user->ident, user->host, res->error.Str());
+                                       ServerInstance->SNO->WriteToSnoMask('A', "Forbidden connection from %s!%s@%s (SQL query failed: %s)", user->nick.c_str(), user->ident.c_str(), user->host.c_str(), res->error.Str());
                                        user->Extend("sqlauth_failed");
                                }
                        }
                        else
                        {
-                               log(DEBUG, "Got query with unknown ID, this probably means the user quit while the query was in progress");
+                               return NULL;
+                       }
+
+                       if (!user->GetExt("sqlauthed"))
+                       {
+                               ServerInstance->Users->QuitUser(user, killreason);
                        }
-               
                        return SQLSUCCESS;
                }
-               
-               log(DEBUG, "Got unsupported API version string: %s", request->GetData());
-               
                return NULL;
        }
-       
-       virtual void OnUserDisconnect(userrec* user)
+
+       virtual void OnUserDisconnect(User* user)
        {
                user->Shrink("sqlauthed");
-               user->Shrink("sqlauth_failed");         
+               user->Shrink("sqlauth_failed");
        }
-       
-       virtual bool OnCheckReady(userrec* user)
+
+       virtual bool OnCheckReady(User* user)
        {
-               if(user->GetExt("sqlauth_failed"))
-               {
-                       Srv->QuitUser(user,killreason);
-                       return false;
-               }
-               
                return user->GetExt("sqlauthed");
        }
 
-       virtual ~ModuleSQLAuth()
-       {
-       }
-       
        virtual Version GetVersion()
        {
-               return Version(1,0,1,0,VF_VENDOR);
+               return Version("$Id$", VF_VENDOR, API_VERSION);
        }
-       
-};
 
-class ModuleSQLAuthFactory : public ModuleFactory
-{
- public:
-       ModuleSQLAuthFactory()
-       {
-       }
-       
-       ~ModuleSQLAuthFactory()
-       {
-       }
-       
-       virtual Module * CreateModule(Server* Me)
-       {
-               return new ModuleSQLAuth(Me);
-       }
-       
 };
 
-
-extern "C" void * init_module( void )
-{
-       return new ModuleSQLAuthFactory;
-}
+MODULE_INIT(ModuleSQLAuth)