]> git.netwichtig.de Git - user/henk/code/inspircd.git/blob - src/modules/extra/m_sqlauth.cpp
9ca4473b82fb38475d2f89a35a0c8bc99b5f2c04
[user/henk/code/inspircd.git] / src / modules / extra / m_sqlauth.cpp
1 /*       +------------------------------------+
2  *       | Inspire Internet Relay Chat Daemon |
3  *       +------------------------------------+
4  *
5  *  InspIRCd: (C) 2002-2008 InspIRCd Development Team
6  * See: http://www.inspircd.org/wiki/index.php/Credits
7  *
8  * This program is free but copyrighted software; see
9  *            the file COPYING for details.
10  *
11  * ---------------------------------------------------
12  */
13
#include "inspircd.h"
#include "users.h"
#include "channels.h"
#include "modules.h"
#include "m_sqlv2.h"
#include "m_sqlutils.h"
#include "m_hash.h"

#include <string>
#include <utility>
#include <vector>
21
/* $ModDesc: Allow/Deny connections based upon an arbitrary SQL table */
23 /* $ModDep: m_sqlv2.h m_sqlutils.h m_hash.h */
24
25 class ModuleSQLAuth : public Module
26 {
27         Module* SQLutils;
28         Module* SQLprovider;
29
30         std::string freeformquery;
31         std::string killreason;
32         std::string allowpattern;
33         std::string databaseid;
34         
35         bool verbose;
36         
37 public:
38         ModuleSQLAuth(InspIRCd* Me)
39         : Module::Module(Me)
40         {
41                 ServerInstance->Modules->UseInterface("SQLutils");
42                 ServerInstance->Modules->UseInterface("SQL");
43
44                 SQLutils = ServerInstance->Modules->Find("m_sqlutils.so");
45                 if (!SQLutils)
46                         throw ModuleException("Can't find m_sqlutils.so. Please load m_sqlutils.so before m_sqlauth.so.");
47
48                 SQLprovider = ServerInstance->Modules->FindFeature("SQL");
49                 if (!SQLprovider)
50                         throw ModuleException("Can't find an SQL provider module. Please load one before attempting to load m_sqlauth.");
51
52                 OnRehash(NULL,"");
53                 Implementation eventlist[] = { I_OnUserDisconnect, I_OnCheckReady, I_OnRequest, I_OnRehash, I_OnUserRegister };
54                 ServerInstance->Modules->Attach(eventlist, this, 5);
55         }
56
57         virtual ~ModuleSQLAuth()
58         {
59                 ServerInstance->Modules->DoneWithInterface("SQL");
60                 ServerInstance->Modules->DoneWithInterface("SQLutils");
61         }
62
63
64         virtual void OnRehash(User* user, const std::string &parameter)
65         {
66                 ConfigReader Conf(ServerInstance);
67
68                 databaseid      = Conf.ReadValue("sqlauth", "dbid", 0);                 /* Database ID, given to the SQL service provider */
69                 freeformquery   = Conf.ReadValue("sqlauth", "query", 0);        /* Field name where username can be found */
70                 killreason      = Conf.ReadValue("sqlauth", "killreason", 0);   /* Reason to give when access is denied to a user (put your reg details here) */
71                 allowpattern    = Conf.ReadValue("sqlauth", "allowpattern",0 ); /* Allow nicks matching this pattern without requiring auth */
72                 verbose         = Conf.ReadFlag("sqlauth", "verbose", 0);               /* Set to true if failed connects should be reported to operators */
73         }
74
75         virtual int OnUserRegister(User* user)
76         {
77                 if ((!allowpattern.empty()) && (ServerInstance->MatchText(user->nick,allowpattern)))
78                 {
79                         user->Extend("sqlauthed");
80                         return 0;
81                 }
82                 
83                 if (!CheckCredentials(user))
84                 {
85                         ServerInstance->Users->QuitUser(user, killreason);
86                         return 1;
87                 }
88                 return 0;
89         }
90
91         void SearchAndReplace(std::string& newline, const std::string &find, const std::string &replace)
92         {
93                 std::string::size_type x = newline.find(find);
94                 while (x != std::string::npos)
95                 {
96                         newline.erase(x, find.length());
97                         if (!replace.empty())
98                                 newline.insert(x, replace);
99                         x = newline.find(find);
100                 }
101         }
102
103         bool CheckCredentials(User* user)
104         {
105                 std::string thisquery = freeformquery;
106                 std::string safepass = user->password;
107                 
108                 /* Search and replace the escaped nick and escaped pass into the query */
109
110                 SearchAndReplace(safepass, "\"", "");
111
112                 SearchAndReplace(thisquery, "$nick", user->nick);
113                 SearchAndReplace(thisquery, "$pass", safepass);
114                 SearchAndReplace(thisquery, "$host", user->host);
115                 SearchAndReplace(thisquery, "$ip", user->GetIPString());
116
117                 Module* HashMod = ServerInstance->Modules->Find("m_md5.so");
118
119                 if (HashMod)
120                 {
121                         HashResetRequest(this, HashMod).Send();
122                         SearchAndReplace(thisquery, "$md5pass", HashSumRequest(this, HashMod, user->password).Send());
123                 }
124
125                 HashMod = ServerInstance->Modules->Find("m_sha256.so");
126
127                 if (HashMod)
128                 {
129                         HashResetRequest(this, HashMod).Send();
130                         SearchAndReplace(thisquery, "$sha256pass", HashSumRequest(this, HashMod, user->password).Send());
131                 }
132
133                 /* Build the query */
134                 SQLrequest req = SQLrequest(this, SQLprovider, databaseid, SQLquery(thisquery));
135                         
136                 if(req.Send())
137                 {
138                         /* When we get the query response from the service provider we will be given an ID to play with,
139                          * just an ID number which is unique to this query. We need a way of associating that ID with a User
140                          * so we insert it into a map mapping the IDs to users.
141                          * Thankfully m_sqlutils provides this, it will associate a ID with a user or channel, and if the user quits it removes the
142                          * association. This means that if the user quits during a query we will just get a failed lookup from m_sqlutils - telling
143                          * us to discard the query.
144                          */
145                         AssociateUser(this, SQLutils, req.id, user).Send();
146                                 
147                         return true;
148                 }
149                 else
150                 {
151                         if (verbose)
152                                 ServerInstance->SNO->WriteToSnoMask('A', "Forbidden connection from %s!%s@%s (SQL query failed: %s)", user->nick.c_str(), user->ident.c_str(), user->host, req.error.Str());
153                         return false;
154                 }
155         }
156         
157         virtual const char* OnRequest(Request* request)
158         {
159                 if(strcmp(SQLRESID, request->GetId()) == 0)
160                 {
161                         SQLresult* res = static_cast<SQLresult*>(request);
162
163                         User* user = GetAssocUser(this, SQLutils, res->id).S().user;
164                         UnAssociate(this, SQLutils, res->id).S();
165                         
166                         if(user)
167                         {
168                                 if(res->error.Id() == NO_ERROR)
169                                 {
170                                         if(res->Rows())
171                                         {
172                                                 /* We got a row in the result, this is enough really */
173                                                 user->Extend("sqlauthed");
174                                         }
175                                         else if (verbose)
176                                         {
177                                                 /* No rows in result, this means there was no record matching the user */
178                                                 ServerInstance->SNO->WriteToSnoMask('A', "Forbidden connection from %s!%s@%s (SQL query returned no matches)", user->nick.c_str(), user->ident.c_str(), user->host);
179                                                 user->Extend("sqlauth_failed");
180                                         }
181                                 }
182                                 else if (verbose)
183                                 {
184                                         ServerInstance->SNO->WriteToSnoMask('A', "Forbidden connection from %s!%s@%s (SQL query failed: %s)", user->nick.c_str(), user->ident.c_str(), user->host, res->error.Str());
185                                         user->Extend("sqlauth_failed");
186                                 }
187                         }
188                         else
189                         {
190                                 return NULL;
191                         }
192
193                         if (!user->GetExt("sqlauthed"))
194                         {
195                                 ServerInstance->Users->QuitUser(user, killreason);
196                         }
197                         return SQLSUCCESS;
198                 }               
199                 return NULL;
200         }
201         
202         virtual void OnUserDisconnect(User* user)
203         {
204                 user->Shrink("sqlauthed");
205                 user->Shrink("sqlauth_failed");         
206         }
207         
208         virtual bool OnCheckReady(User* user)
209         {
210                 return user->GetExt("sqlauthed");
211         }
212
213         virtual Version GetVersion()
214         {
215                 return Version(1,2,1,0,VF_VENDOR,API_VERSION);
216         }
217         
218 };
219
220 MODULE_INIT(ModuleSQLAuth)