]> git.netwichtig.de Git - user/henk/code/inspircd.git/blobdiff - src/modules/extra/m_sqlauth.cpp
fixed some indentation and spacing in modules
[user/henk/code/inspircd.git] / src / modules / extra / m_sqlauth.cpp
index 17381731ba5897ed709a931c661f928fe7bafe38..dcc314af99064249f8404fa1b49e446f6b2dfaa0 100644 (file)
  *       | Inspire Internet Relay Chat Daemon |
  *       +------------------------------------+
  *
- *  InspIRCd is copyright (C) 2002-2004 ChatSpike-Dev.
- *                       E-mail:
- *                <brain@chatspike.net>
- *               <Craig@chatspike.net>
- *               <omster@gmail.com>
- *     
- * Written by Craig Edwards, Craig McLure, and others.
+ *  InspIRCd: (C) 2002-2008 InspIRCd Development Team
+ * See: http://www.inspircd.org/wiki/index.php/Credits
+ *
  * This program is free but copyrighted software; see
  *            the file COPYING for details.
  *
  * ---------------------------------------------------
  */
 
-#include <string>
-#include <map>
-
+#include "inspircd.h"
 #include "users.h"
 #include "channels.h"
 #include "modules.h"
-#include "inspircd.h"
-#include "helperfuncs.h"
 #include "m_sqlv2.h"
+#include "m_sqlutils.h"
+#include "m_hash.h"
 
 /* $ModDesc: Allow/Deny connections based upon an arbitary SQL table */
-
-typedef std::map<unsigned int, userrec*> QueryUserMap;
+/* $ModDep: m_sqlv2.h m_sqlutils.h m_hash.h */
 
 class ModuleSQLAuth : public Module
 {
-       Server* Srv;
+       Module* SQLutils;
+       Module* SQLprovider;
 
-       std::string usertable;
-       std::string userfield;
-       std::string passfield;
-       std::string encryption;
+       std::string freeformquery;
        std::string killreason;
        std::string allowpattern;
        std::string databaseid;
-       
+
        bool verbose;
-       
-       QueryUserMap qumap;
-       
+
 public:
-       ModuleSQLAuth(Server* Me)
+       ModuleSQLAuth(InspIRCd* Me)
        : Module::Module(Me)
        {
-               Srv = Me;
-               OnRehash("");
+               ServerInstance->Modules->UseInterface("SQLutils");
+               ServerInstance->Modules->UseInterface("SQL");
+
+               SQLutils = ServerInstance->Modules->Find("m_sqlutils.so");
+               if (!SQLutils)
+                       throw ModuleException("Can't find m_sqlutils.so. Please load m_sqlutils.so before m_sqlauth.so.");
+
+               SQLprovider = ServerInstance->Modules->FindFeature("SQL");
+               if (!SQLprovider)
+                       throw ModuleException("Can't find an SQL provider module. Please load one before attempting to load m_sqlauth.");
+
+               OnRehash(NULL,"");
+               Implementation eventlist[] = { I_OnUserDisconnect, I_OnCheckReady, I_OnRequest, I_OnRehash, I_OnUserRegister };
+               ServerInstance->Modules->Attach(eventlist, this, 5);
        }
 
-       void Implements(char* List)
+       virtual ~ModuleSQLAuth()
        {
-               List[I_OnUserDisconnect] = List[I_OnCheckReady] = List[I_OnRequest] = List[I_OnRehash] = List[I_OnUserRegister] = 1;
+               ServerInstance->Modules->DoneWithInterface("SQL");
+               ServerInstance->Modules->DoneWithInterface("SQLutils");
        }
 
-       virtual void OnRehash(const std::string &parameter)
+
+       virtual void OnRehash(User* user, const std::string &parameter)
        {
-               ConfigReader Conf;
-               
-               usertable       = Conf.ReadValue("sqlauth", "usertable", 0);    /* User table name */
+               ConfigReader Conf(ServerInstance);
+
                databaseid      = Conf.ReadValue("sqlauth", "dbid", 0);                 /* Database ID, given to the SQL service provider */
-               userfield       = Conf.ReadValue("sqlauth", "userfield", 0);    /* Field name where username can be found */
-               passfield       = Conf.ReadValue("sqlauth", "passfield", 0);    /* Field name where password can be found */
+               freeformquery   = Conf.ReadValue("sqlauth", "query", 0);        /* Field name where username can be found */
                killreason      = Conf.ReadValue("sqlauth", "killreason", 0);   /* Reason to give when access is denied to a user (put your reg details here) */
-               allowpattern= Conf.ReadValue("sqlauth", "allowpattern",0 );     /* Allow nicks matching this pattern without requiring auth */
-               encryption      = Conf.ReadValue("sqlauth", "encryption", 0);   /* Name of sql function used to encrypt password, e.g. "md5" or "passwd".
-                                                                                                                                        * define, but leave blank if no encryption is to be used.
-                                                                                                                                        */
+               allowpattern    = Conf.ReadValue("sqlauth", "allowpattern",0 ); /* Allow nicks matching this pattern without requiring auth */
                verbose         = Conf.ReadFlag("sqlauth", "verbose", 0);               /* Set to true if failed connects should be reported to operators */
-               
-               if (encryption.find("(") == std::string::npos)
+       }
+
+       virtual int OnUserRegister(User* user)
+       {
+               if ((!allowpattern.empty()) && (ServerInstance->MatchText(user->nick,allowpattern)))
                {
-                       encryption.append("(");
+                       user->Extend("sqlauthed");
+                       return 0;
                }
-       }       
 
-       virtual void OnUserRegister(userrec* user)
-       {
-               if ((allowpattern != "") && (Srv->MatchText(user->nick,allowpattern)))
-                       return;
-               
                if (!CheckCredentials(user))
                {
-                       if (verbose)
-                               WriteOpers("Forbidden connection from %s!%s@%s (invalid login/password)",user->nick,user->ident,user->host);
-                       Srv->QuitUser(user,killreason);
+                       ServerInstance->Users->QuitUser(user, killreason);
+                       return 1;
                }
+               return 0;
        }
 
-       bool CheckCredentials(userrec* user)
+       void SearchAndReplace(std::string& newline, const std::string &find, const std::string &replace)
        {
-               bool found;
-               Module* target;
-               
-               found = false;
-               target = Srv->FindFeature("SQL");
-               
-               if(target)
+               std::string::size_type x = newline.find(find);
+               while (x != std::string::npos)
                {
-                       SQLrequest req = SQLreq(this, target, databaseid, "SELECT ? FROM ? WHERE ? = '?' AND ? = ?'?')", userfield, usertable, userfield, user->nick, passfield, encryption, user->password);
-                       
-                       if(req.Send())
-                       {
-                               /* When we get the query response from the service provider we will be given an ID to play with,
-                                * just an ID number which is unique to this query. We need a way of associating that ID with a userrec
-                                * so we insert it into a map mapping the IDs to users.
-                                * This isn't quite enough though, as if the user quit while the query was in progress then when the result
-                                * came to be processed we'd get an invalid userrec* out of the map. Now we *could* solve this by watching
-                                * OnUserDisconnect() and iterating the map every time someone quits to make sure they didn't have any queries
-                                * in progress, but that would be relatively slow and inefficient. Instead (thanks to w00t ;p) we attach a list
-                                * of query IDs associated with it to the userrec, so in OnUserDisconnect() we can remove it immediately.
-                                */
-                               log(DEBUG, "Sent query, got given ID %lu", req.id);
-                               qumap.insert(std::make_pair(req.id, user));
-                               
-                               if(!user->Extend("sqlauth_queryid", new unsigned long(req.id)))
-                               {
-                                       log(DEBUG, "BUG: user being sqlauth'd already extended with 'sqlauth_queryid' :/");
-                               }
-                               
-                               return true;
-                       }
-                       else
-                       {
-                               log(DEBUG, "SQLrequest failed: %s", req.error.Str());
-                               
-                               if (verbose)
-                                       WriteOpers("Forbidden connection from %s!%s@%s (SQL query failed: %s)", user->nick, user->ident, user->host, req.error.Str());
-                               
-                               return false;
-                       }
+                       newline.erase(x, find.length());
+                       if (!replace.empty())
+                               newline.insert(x, replace);
+                       x = newline.find(find);
+               }
+       }
+
+       bool CheckCredentials(User* user)
+       {
+               std::string thisquery = freeformquery;
+               std::string safepass = user->password;
+
+               /* Search and replace the escaped nick and escaped pass into the query */
+
+               SearchAndReplace(safepass, "\"", "");
+
+               SearchAndReplace(thisquery, "$nick", user->nick);
+               SearchAndReplace(thisquery, "$pass", safepass);
+               SearchAndReplace(thisquery, "$host", user->host);
+               SearchAndReplace(thisquery, "$ip", user->GetIPString());
+
+               Module* HashMod = ServerInstance->Modules->Find("m_md5.so");
+
+               if (HashMod)
+               {
+                       HashResetRequest(this, HashMod).Send();
+                       SearchAndReplace(thisquery, "$md5pass", HashSumRequest(this, HashMod, user->password).Send());
+               }
+
+               HashMod = ServerInstance->Modules->Find("m_sha256.so");
+
+               if (HashMod)
+               {
+                       HashResetRequest(this, HashMod).Send();
+                       SearchAndReplace(thisquery, "$sha256pass", HashSumRequest(this, HashMod, user->password).Send());
+               }
+
+               /* Build the query */
+               SQLrequest req = SQLrequest(this, SQLprovider, databaseid, SQLquery(thisquery));
+
+               if(req.Send())
+               {
+                       /* When we get the query response from the service provider we will be given an ID to play with,
+                        * just an ID number which is unique to this query. We need a way of associating that ID with a User
+                        * so we insert it into a map mapping the IDs to users.
+                        * Thankfully m_sqlutils provides this, it will associate a ID with a user or channel, and if the user quits it removes the
+                        * association. This means that if the user quits during a query we will just get a failed lookup from m_sqlutils - telling
+                        * us to discard the query.
+                        */
+                       AssociateUser(this, SQLutils, req.id, user).Send();
+
+                       return true;
                }
                else
                {
-                       log(SPARSE, "WARNING: Couldn't find SQL provider module. NOBODY will be allowed to connect until it comes back unless they match an exception");
+                       if (verbose)
+                               ServerInstance->SNO->WriteToSnoMask('A', "Forbidden connection from %s!%s@%s (SQL query failed: %s)", user->nick.c_str(), user->ident.c_str(), user->host.c_str(), req.error.Str());
                        return false;
                }
        }
-       
-       virtual char* OnRequest(Request* request)
+
+       virtual const char* OnRequest(Request* request)
        {
-               if(strcmp(SQLRESID, request->GetData()) == 0)
+               if(strcmp(SQLRESID, request->GetId()) == 0)
                {
-                       SQLresult* res;
-                       QueryUserMap::iterator iter;
-               
-                       res = static_cast<SQLresult*>(request);
-                       
-                       log(DEBUG, "Got SQL result (%s) with ID %lu", res->GetData(), res->id);
-                       
-                       iter = qumap.find(res->id);
-                       
-                       if(iter != qumap.end())
+                       SQLresult* res = static_cast<SQLresult*>(request);
+
+                       User* user = GetAssocUser(this, SQLutils, res->id).S().user;
+                       UnAssociate(this, SQLutils, res->id).S();
+
+                       if(user)
                        {
-                               userrec* user;
-                               unsigned long* id;
-                               
-                               user = iter->second;
-                               
-                               log(DEBUG, "Associated query ID %lu with user %s", res->id, user->nick);
-                               
-                               log(DEBUG, "Got result with %d rows and %d columns", res->Rows(), res->Cols());
-                       
-                               if(res->Rows())
+                               if(res->error.Id() == NO_ERROR)
                                {
-                                       /* We got a row in the result, this is enough really */
-                                       user->Extend("sqlauthed");
+                                       if(res->Rows())
+                                       {
+                                               /* We got a row in the result, this is enough really */
+                                               user->Extend("sqlauthed");
+                                       }
+                                       else if (verbose)
+                                       {
+                                               /* No rows in result, this means there was no record matching the user */
+                                               ServerInstance->SNO->WriteToSnoMask('A', "Forbidden connection from %s!%s@%s (SQL query returned no matches)", user->nick.c_str(), user->ident.c_str(), user->host.c_str());
+                                               user->Extend("sqlauth_failed");
+                                       }
                                }
                                else if (verbose)
                                {
-                                       /* No rows in result, this means there was no record matching the user */
-                                       WriteOpers("Forbidden connection from %s!%s@%s (SQL query returned no matches)", user->nick, user->ident, user->host);
-                               }
-                               
-                               /* Remove our ID from the lookup table to keep it as small and neat as possible */
-                               qumap.erase(iter);
-                               
-                               /* Cleanup the userrec, no point leaving this here */
-                               if(user->GetExt("sqlauth_queryid", id))
-                               {
-                                       user->Shrink("sqlauth_queryid");
-                                       delete id;
+                                       ServerInstance->SNO->WriteToSnoMask('A', "Forbidden connection from %s!%s@%s (SQL query failed: %s)", user->nick.c_str(), user->ident.c_str(), user->host.c_str(), res->error.Str());
+                                       user->Extend("sqlauth_failed");
                                }
                        }
                        else
                        {
-                               log(DEBUG, "Got query with unknown ID, this probably means the user quit while the query was in progress");
+                               return NULL;
+                       }
+
+                       if (!user->GetExt("sqlauthed"))
+                       {
+                               ServerInstance->Users->QuitUser(user, killreason);
                        }
-               
                        return SQLSUCCESS;
                }
-               
-               log(DEBUG, "Got unsupported API version string: %s", request->GetData());
-               
                return NULL;
        }
-       
-       virtual void OnUserDisconnect(userrec* user)
+
+       virtual void OnUserDisconnect(User* user)
        {
-               unsigned long* id;
-               
-               if(user->GetExt("sqlauth_queryid", id))
-               {
-                       QueryUserMap::iterator iter;
-                       
-                       iter = qumap.find(*id);
-                       
-                       if(iter != qumap.end())
-                       {
-                               if(iter->second == user)
-                               {
-                                       qumap.erase(iter);
-                                       
-                                       log(DEBUG, "Erased query from map associated with quitting user %s", user->nick);
-                               }
-                               else
-                               {
-                                       log(DEBUG, "BUG: ID associated with user %s doesn't have the same userrec* associated with it in the map");
-                               }               
-                       }
-                       else
-                       {
-                               log(DEBUG, "BUG: user %s was extended with sqlauth_queryid but there was nothing matching in the map", user->nick);
-                       }
-                       
-                       user->Shrink("sqlauth_queryid");
-                       delete id;
-               }                       
+               user->Shrink("sqlauthed");
+               user->Shrink("sqlauth_failed");
        }
-       
-       virtual bool OnCheckReady(userrec* user)
+
+       virtual bool OnCheckReady(User* user)
        {
                return user->GetExt("sqlauthed");
        }
 
-       virtual ~ModuleSQLAuth()
-       {
-       }
-       
        virtual Version GetVersion()
        {
-               return Version(1,0,1,0,VF_VENDOR);
+               return Version(1,2,1,0,VF_VENDOR,API_VERSION);
        }
-       
-};
 
-class ModuleSQLAuthFactory : public ModuleFactory
-{
- public:
-       ModuleSQLAuthFactory()
-       {
-       }
-       
-       ~ModuleSQLAuthFactory()
-       {
-       }
-       
-       virtual Module * CreateModule(Server* Me)
-       {
-               return new ModuleSQLAuth(Me);
-       }
-       
 };
 
-
-extern "C" void * init_module( void )
-{
-       return new ModuleSQLAuthFactory;
-}
+MODULE_INIT(ModuleSQLAuth)