]> git.netwichtig.de Git - user/henk/code/inspircd.git/blob - src/modules/extra/m_mysql.cpp
d1329151c20d203034fe3d9d3f1d919373a35297
[user/henk/code/inspircd.git] / src / modules / extra / m_mysql.cpp
1 /*       +------------------------------------+
2  *       | Inspire Internet Relay Chat Daemon |
3  *       +------------------------------------+
4  *
5  *  InspIRCd is copyright (C) 2002-2004 ChatSpike-Dev.
6  *                     E-mail:
7  *              <brain@chatspike.net>
8  *                <Craig@chatspike.net>
9  *     
10  * Written by Craig Edwards, Craig McLure, and others.
11  * This program is free but copyrighted software; see
12  *          the file COPYING for details.
13  *
14  * ---------------------------------------------------
15  */
16
17 using namespace std;
18
19 #include <stdio.h>
20 #include <string>
21 #include <mysql.h>
22 #include <pthread.h>
23 #include "users.h"
24 #include "channels.h"
25 #include "modules.h"
26 #include "helperfuncs.h"
27 #include "m_sqlv2.h"
28
29 /* VERSION 2 API: With nonblocking (threaded) requests */
30
31 /* $ModDesc: SQL Service Provider module for all other m_sql* modules */
32 /* $CompileFlags: `mysql_config --include` */
33 /* $LinkerFlags: `mysql_config --libs_r` `perl ../mysql_rpath.pl` */
34
35
36 class SQLConnection;
37
38 extern InspIRCd* ServerInstance;
39 typedef std::map<std::string, SQLConnection*> ConnMap;
40
41
42 #if !defined(MYSQL_VERSION_ID) || MYSQL_VERSION_ID<32224
43 #define mysql_field_count mysql_num_fields
44 #endif
45
46 class QueryQueue : public classbase
47 {
48 private:
49         typedef std::deque<SQLrequest> ReqDeque;
50
51         ReqDeque priority;      /* The priority queue */
52         ReqDeque normal;        /* The 'normal' queue */
53         enum { PRI, NOR, NON } which;   /* Which queue the currently active element is at the front of */
54
55 public:
56         QueryQueue()
57         : which(NON)
58         {
59         }
60
61         void push(const SQLrequest &q)
62         {
63                 log(DEBUG, "QueryQueue::push(): Adding %s query to queue: %s", ((q.pri) ? "priority" : "non-priority"), q.query.q.c_str());
64
65                 if(q.pri)
66                         priority.push_back(q);
67                 else
68                         normal.push_back(q);
69         }
70
71         void pop()
72         {
73                 if((which == PRI) && priority.size())
74                 {
75                         priority.pop_front();
76                 }
77                 else if((which == NOR) && normal.size())
78                 {
79                         normal.pop_front();
80                 }
81
82                 /* Reset this */
83                 which = NON;
84
85                 /* Silently do nothing if there was no element to pop() */
86         }
87
88         SQLrequest& front()
89         {
90                 switch(which)
91                 {
92                         case PRI:
93                                 return priority.front();
94                         case NOR:
95                                 return normal.front();
96                         default:
97                                 if(priority.size())
98                                 {
99                                         which = PRI;
100                                         return priority.front();
101                                 }
102
103                                 if(normal.size())
104                                 {
105                                         which = NOR;
106                                         return normal.front();
107                                 }
108
109                                 /* This will probably result in a segfault,
110                                  * but the caller should have checked totalsize()
111                                  * first so..meh - moron :p
112                                  */
113
114                                 return priority.front();
115                 }
116         }
117
118         std::pair<int, int> size()
119         {
120                 return std::make_pair(priority.size(), normal.size());
121         }
122
123         int totalsize()
124         {
125                 return priority.size() + normal.size();
126         }
127
128         void PurgeModule(Module* mod)
129         {
130                 DoPurgeModule(mod, priority);
131                 DoPurgeModule(mod, normal);
132         }
133
134 private:
135         void DoPurgeModule(Module* mod, ReqDeque& q)
136         {
137                 for(ReqDeque::iterator iter = q.begin(); iter != q.end(); iter++)
138                 {
139                         if(iter->GetSource() == mod)
140                         {
141                                 if(iter->id == front().id)
142                                 {
143                                         /* It's the currently active query.. :x */
144                                         iter->SetSource(NULL);
145                                 }
146                                 else
147                                 {
148                                         /* It hasn't been executed yet..just remove it */
149                                         iter = q.erase(iter);
150                                 }
151                         }
152                 }
153         }
154 };
155
156
157
158 class SQLConnection : public classbase
159 {
160  protected:
161
162         MYSQL connection;
163         MYSQL_RES *res;
164         MYSQL_ROW row;
165         std::string host;
166         std::string user;
167         std::string pass;
168         std::string db;
169         std::map<std::string,std::string> thisrow;
170         bool Enabled;
171         long id;
172
173  public:
174
175         // This constructor creates an SQLConnection object with the given credentials, and creates the underlying
176         // MYSQL struct, but does not connect yet.
177         SQLConnection(std::string thishost, std::string thisuser, std::string thispass, std::string thisdb, long myid)
178         {
179                 this->Enabled = true;
180                 this->host = thishost;
181                 this->user = thisuser;
182                 this->pass = thispass;
183                 this->db = thisdb;
184                 this->id = myid;
185         }
186
187         // This method connects to the database using the credentials supplied to the constructor, and returns
188         // true upon success.
189         bool Connect()
190         {
191                 unsigned int timeout = 1;
192                 mysql_init(&connection);
193                 mysql_options(&connection,MYSQL_OPT_CONNECT_TIMEOUT,(char*)&timeout);
194                 return mysql_real_connect(&connection, host.c_str(), user.c_str(), pass.c_str(), db.c_str(), 0, NULL, 0);
195         }
196
197         // This method issues a query that expects multiple rows of results. Use GetRow() and QueryDone() to retrieve
198         // multiple rows.
199         bool QueryResult(std::string query)
200         {
201                 if (!CheckConnection()) return false;
202                 
203                 int r = mysql_query(&connection, query.c_str());
204                 if (!r)
205                 {
206                         res = mysql_use_result(&connection);
207                 }
208                 return (!r);
209         }
210
211         // This method issues a query that just expects a number of 'effected' rows (e.g. UPDATE or DELETE FROM).
212         // the number of effected rows is returned in the return value.
213         long QueryCount(std::string query)
214         {
215                 /* If the connection is down, we return a negative value - New to 1.1 */
216                 if (!CheckConnection()) return -1;
217
218                 int r = mysql_query(&connection, query.c_str());
219                 if (!r)
220                 {
221                         res = mysql_store_result(&connection);
222                         unsigned long rows = mysql_affected_rows(&connection);
223                         mysql_free_result(res);
224                         return rows;
225                 }
226                 return 0;
227         }
228
229         // This method fetches a row, if available from the database. You must issue a query
230         // using QueryResult() first! The row's values are returned as a map of std::string
231         // where each item is keyed by the column name.
232         std::map<std::string,std::string> GetRow()
233         {
234                 thisrow.clear();
235                 if (res)
236                 {
237                         row = mysql_fetch_row(res);
238                         if (row)
239                         {
240                                 unsigned int field_count = 0;
241                                 MYSQL_FIELD *fields = mysql_fetch_fields(res);
242                                 if(mysql_field_count(&connection) == 0)
243                                         return thisrow;
244                                 if (fields && mysql_field_count(&connection))
245                                 {
246                                         while (field_count < mysql_field_count(&connection))
247                                         {
248                                                 std::string a = (fields[field_count].name ? fields[field_count].name : "");
249                                                 std::string b = (row[field_count] ? row[field_count] : "");
250                                                 thisrow[a] = b;
251                                                 field_count++;
252                                         }
253                                         return thisrow;
254                                 }
255                         }
256                 }
257                 return thisrow;
258         }
259
260         bool QueryDone()
261         {
262                 if (res)
263                 {
264                         mysql_free_result(res);
265                         res = NULL;
266                         return true;
267                 }
268                 else return false;
269         }
270
271         bool ConnectionLost()
272         {
273                 if (&connection) {
274                         return (mysql_ping(&connection) != 0);
275                 }
276                 else return false;
277         }
278
279         bool CheckConnection()
280         {
281                 if (ConnectionLost()) {
282                         return Connect();
283                 }
284                 else return true;
285         }
286
287         std::string GetError()
288         {
289                 return mysql_error(&connection);
290         }
291
292         long GetID()
293         {
294                 return id;
295         }
296
297         std::string GetHost()
298         {
299                 return host;
300         }
301
302         void Enable()
303         {
304                 Enabled = true;
305         }
306
307         void Disable()
308         {
309                 Enabled = false;
310         }
311
312         bool IsEnabled()
313         {
314                 return Enabled;
315         }
316
317 };
318
319 ConnMap Connections;
320
321 void ConnectDatabases(Server* Srv)
322 {
323         for (ConnMap::iterator i = Connections.begin(); i != Connections.end(); i++)
324         {
325                 i->second->Enable();
326                 if (i->second->Connect())
327                 {
328                         Srv->Log(DEFAULT,"SQL: Successfully connected database "+i->second->GetHost());
329                 }
330                 else
331                 {
332                         Srv->Log(DEFAULT,"SQL: Failed to connect database "+i->second->GetHost()+": Error: "+i->second->GetError());
333                         i->second->Disable();
334                 }
335         }
336 }
337
338
339 void LoadDatabases(ConfigReader* ThisConf, Server* Srv)
340 {
341         Srv->Log(DEFAULT,"SQL: Loading database settings");
342         Connections.clear();
343         Srv->Log(DEBUG,"Cleared connections");
344         for (int j =0; j < ThisConf->Enumerate("database"); j++)
345         {
346                 std::string db = ThisConf->ReadValue("database","name",j);
347                 std::string user = ThisConf->ReadValue("database","username",j);
348                 std::string pass = ThisConf->ReadValue("database","password",j);
349                 std::string host = ThisConf->ReadValue("database","hostname",j);
350                 std::string id = ThisConf->ReadValue("database","id",j);
351                 Srv->Log(DEBUG,"Read database settings");
352                 if ((db != "") && (host != "") && (user != "") && (id != "") && (pass != ""))
353                 {
354                         SQLConnection* ThisSQL = new SQLConnection(host,user,pass,db,atoi(id.c_str()));
355                         Srv->Log(DEFAULT,"Loaded database: "+ThisSQL->GetHost());
356                         Connections[id] = ThisSQL;
357                         Srv->Log(DEBUG,"Pushed back connection");
358                 }
359         }
360         ConnectDatabases(Srv);
361 }
362
363 void* DispatcherThread(void* arg);
364
365 class ModuleSQL : public Module
366 {
367  public:
368         Server *Srv;
369         ConfigReader *Conf;
370         pthread_t Dispatcher;
371
372         void Implements(char* List)
373         {
374                 List[I_OnRehash] = List[I_OnRequest] = 1;
375         }
376
377         char* OnRequest(Request* request)
378         {
379                 return NULL;
380         }
381
382         ModuleSQL(Server* Me)
383                 : Module::Module(Me)
384         {
385                 Srv = Me;
386                 Conf = new ConfigReader();
387                 pthread_attr_t attribs;
388                 pthread_attr_init(&attribs);
389                 pthread_attr_setdetachstate(&attribs, PTHREAD_CREATE_DETACHED);
390                 if (pthread_create(&this->Dispatcher, &attribs, DispatcherThread, (void *)this) != 0)
391                 {
392                         log(DEBUG,"m_mysql: Failed to create dispatcher thread: %s", strerror(errno));
393                 }
394         }
395         
396         virtual ~ModuleSQL()
397         {
398                 DELETE(Conf);
399         }
400         
401         virtual void OnRehash(const std::string &parameter)
402         {
403                 /* TODO: set rehash bool here, which makes the dispatcher thread rehash at next opportunity */
404         }
405         
406         virtual Version GetVersion()
407         {
408                 return Version(1,1,0,0,VF_VENDOR|VF_SERVICEPROVIDER);
409         }
410         
411 };
412
413 void* DispatcherThread(void* arg)
414 {
415         ModuleSQL* thismodule = (ModuleSQL*)arg;
416         LoadDatabases(thismodule->Conf, thismodule->Srv);
417
418         return NULL;
419 }
420
421
422 // stuff down here is the module-factory stuff. For basic modules you can ignore this.
423
424 class ModuleSQLFactory : public ModuleFactory
425 {
426  public:
427         ModuleSQLFactory()
428         {
429         }
430         
431         ~ModuleSQLFactory()
432         {
433         }
434         
435         virtual Module * CreateModule(Server* Me)
436         {
437                 return new ModuleSQL(Me);
438         }
439         
440 };
441
442
443 extern "C" void * init_module( void )
444 {
445         return new ModuleSQLFactory;
446 }
447