]> git.netwichtig.de Git - user/henk/code/inspircd.git/blob - src/modules/extra/m_mysql.cpp
bc3586c385b2987ad83ef64c40ce0969c84b4ec0
[user/henk/code/inspircd.git] / src / modules / extra / m_mysql.cpp
1 /*       +------------------------------------+
2  *       | Inspire Internet Relay Chat Daemon |
3  *       +------------------------------------+
4  *
5  *  InspIRCd is copyright (C) 2002-2004 ChatSpike-Dev.
6  *                     E-mail:
7  *              <brain@chatspike.net>
8  *                <Craig@chatspike.net>
9  *     
10  * Written by Craig Edwards, Craig McLure, and others.
11  * This program is free but copyrighted software; see
12  *          the file COPYING for details.
13  *
14  * ---------------------------------------------------
15  */
16
17 using namespace std;
18
19 #include <stdio.h>
20 #include <string>
21 #include <mysql.h>
22 #include <pthread.h>
23 #include "users.h"
24 #include "channels.h"
25 #include "modules.h"
26 #include "helperfuncs.h"
27 #include "m_sqlv2.h"
28
29 /* VERSION 2 API: With nonblocking (threaded) requests */
30
31 /* $ModDesc: SQL Service Provider module for all other m_sql* modules */
32 /* $CompileFlags: `mysql_config --include` */
33 /* $LinkerFlags: `mysql_config --libs_r` `perl ../mysql_rpath.pl` */
34
35
36 class SQLConnection;
37
38 extern InspIRCd* ServerInstance;
39 typedef std::map<std::string, SQLConnection*> ConnMap;
40 bool giveup = false;
41
42
43 #if !defined(MYSQL_VERSION_ID) || MYSQL_VERSION_ID<32224
44 #define mysql_field_count mysql_num_fields
45 #endif
46
47 class QueryQueue : public classbase
48 {
49 private:
50         typedef std::deque<SQLrequest> ReqDeque;
51
52         ReqDeque priority;      /* The priority queue */
53         ReqDeque normal;        /* The 'normal' queue */
54         enum { PRI, NOR, NON } which;   /* Which queue the currently active element is at the front of */
55
56 public:
57         QueryQueue()
58         : which(NON)
59         {
60         }
61
62         void push(const SQLrequest &q)
63         {
64                 log(DEBUG, "QueryQueue::push(): Adding %s query to queue: %s", ((q.pri) ? "priority" : "non-priority"), q.query.q.c_str());
65
66                 if(q.pri)
67                         priority.push_back(q);
68                 else
69                         normal.push_back(q);
70         }
71
72         void pop()
73         {
74                 if((which == PRI) && priority.size())
75                 {
76                         priority.pop_front();
77                 }
78                 else if((which == NOR) && normal.size())
79                 {
80                         normal.pop_front();
81                 }
82
83                 /* Reset this */
84                 which = NON;
85
86                 /* Silently do nothing if there was no element to pop() */
87         }
88
89         SQLrequest& front()
90         {
91                 switch(which)
92                 {
93                         case PRI:
94                                 return priority.front();
95                         case NOR:
96                                 return normal.front();
97                         default:
98                                 if(priority.size())
99                                 {
100                                         which = PRI;
101                                         return priority.front();
102                                 }
103
104                                 if(normal.size())
105                                 {
106                                         which = NOR;
107                                         return normal.front();
108                                 }
109
110                                 /* This will probably result in a segfault,
111                                  * but the caller should have checked totalsize()
112                                  * first so..meh - moron :p
113                                  */
114
115                                 return priority.front();
116                 }
117         }
118
119         std::pair<int, int> size()
120         {
121                 return std::make_pair(priority.size(), normal.size());
122         }
123
124         int totalsize()
125         {
126                 return priority.size() + normal.size();
127         }
128
129         void PurgeModule(Module* mod)
130         {
131                 DoPurgeModule(mod, priority);
132                 DoPurgeModule(mod, normal);
133         }
134
135 private:
136         void DoPurgeModule(Module* mod, ReqDeque& q)
137         {
138                 for(ReqDeque::iterator iter = q.begin(); iter != q.end(); iter++)
139                 {
140                         if(iter->GetSource() == mod)
141                         {
142                                 if(iter->id == front().id)
143                                 {
144                                         /* It's the currently active query.. :x */
145                                         iter->SetSource(NULL);
146                                 }
147                                 else
148                                 {
149                                         /* It hasn't been executed yet..just remove it */
150                                         iter = q.erase(iter);
151                                 }
152                         }
153                 }
154         }
155 };
156
157 /* A mutex to wrap around queue accesses */
158 pthread_mutex_t queue_mutex = PTHREAD_MUTEX_INITIALIZER;
159
160 class SQLConnection : public classbase
161 {
162  protected:
163
164         MYSQL connection;
165         MYSQL_RES *res;
166         MYSQL_ROW row;
167         std::string host;
168         std::string user;
169         std::string pass;
170         std::string db;
171         std::map<std::string,std::string> thisrow;
172         bool Enabled;
173         long id;
174
175  public:
176
177         QueryQueue queue;
178
179         // This constructor creates an SQLConnection object with the given credentials, and creates the underlying
180         // MYSQL struct, but does not connect yet.
181         SQLConnection(std::string thishost, std::string thisuser, std::string thispass, std::string thisdb, long myid)
182         {
183                 this->Enabled = true;
184                 this->host = thishost;
185                 this->user = thisuser;
186                 this->pass = thispass;
187                 this->db = thisdb;
188                 this->id = myid;
189         }
190
191         // This method connects to the database using the credentials supplied to the constructor, and returns
192         // true upon success.
193         bool Connect()
194         {
195                 unsigned int timeout = 1;
196                 mysql_init(&connection);
197                 mysql_options(&connection,MYSQL_OPT_CONNECT_TIMEOUT,(char*)&timeout);
198                 return mysql_real_connect(&connection, host.c_str(), user.c_str(), pass.c_str(), db.c_str(), 0, NULL, 0);
199         }
200
201         void DoLeadingQuery()
202         {
203                 /* Parse the command string and dispatch it to mysql */
204                 SQLrequest& req = queue.front();
205                 log(DEBUG,"DO QUERY: %s",req.query.q.c_str());
206
207                 /* Pointer to the buffer we screw around with substitution in */
208                 char* query;
209
210                 /* Pointer to the current end of query, where we append new stuff */
211                 char* queryend;
212
213                 /* Total length of the unescaped parameters */
214                 unsigned long paramlen;
215
216                 paramlen = 0;
217
218                 for(ParamL::iterator i = req.query.p.begin(); i != req.query.p.end(); i++)
219                 {
220                         paramlen += i->size();
221                 }
222
223                 /* To avoid a lot of allocations, allocate enough memory for the biggest the escaped query could possibly be.
224                  * sizeofquery + (totalparamlength*2) + 1
225                  *
226                  * The +1 is for null-terminating the string for mysql_real_escape_string
227                  */
228
229                 query = new char[req.query.q.length() + (paramlen*2)];
230                 queryend = query;
231
232                 /* Okay, now we have a buffer large enough we need to start copying the query into it and escaping and substituting
233                  * the parameters into it...
234                  */
235
236                 for(unsigned long i = 0; i < req.query.q.length(); i++)
237                 {
238                         if(req.query.q[i] == '?')
239                         {
240                                 /* We found a place to substitute..what fun.
241                                  * use mysql calls to escape and write the
242                                  * escaped string onto the end of our query buffer,
243                                  * then we "just" need to make sure queryend is
244                                  * pointing at the right place.
245                                  */
246                                 if(req.query.p.size())
247                                 {
248                                         unsigned long len = mysql_real_escape_string(&connection, queryend, req.query.p.front().c_str(), req.query.p.front().length());
249
250                                         queryend += len;
251                                         req.query.p.pop_front();
252                                 }
253                                 else
254                                 {
255                                         log(DEBUG, "Found a substitution location but no parameter to substitute :|");
256                                         break;
257                                 }
258                         }
259                         else
260                         {
261                                 *queryend = req.query.q[i];
262                                 queryend++;
263                         }
264                 }
265
266                 *queryend = 0;
267
268                 log(DEBUG, "Attempting to dispatch query: %s", query);
269
270                 pthread_mutex_lock(&queue_mutex);
271                 req.query.q = query;
272                 pthread_mutex_unlock(&queue_mutex);
273
274                 /* TODO: Do the mysql_real_query here */
275         }
276
277         // This method issues a query that expects multiple rows of results. Use GetRow() and QueryDone() to retrieve
278         // multiple rows.
279         bool QueryResult(std::string query)
280         {
281                 if (!CheckConnection()) return false;
282                 
283                 int r = mysql_query(&connection, query.c_str());
284                 if (!r)
285                 {
286                         res = mysql_use_result(&connection);
287                 }
288                 return (!r);
289         }
290
291         // This method issues a query that just expects a number of 'effected' rows (e.g. UPDATE or DELETE FROM).
292         // the number of effected rows is returned in the return value.
293         long QueryCount(std::string query)
294         {
295                 /* If the connection is down, we return a negative value - New to 1.1 */
296                 if (!CheckConnection()) return -1;
297
298                 int r = mysql_query(&connection, query.c_str());
299                 if (!r)
300                 {
301                         res = mysql_store_result(&connection);
302                         unsigned long rows = mysql_affected_rows(&connection);
303                         mysql_free_result(res);
304                         return rows;
305                 }
306                 return 0;
307         }
308
309         // This method fetches a row, if available from the database. You must issue a query
310         // using QueryResult() first! The row's values are returned as a map of std::string
311         // where each item is keyed by the column name.
312         std::map<std::string,std::string> GetRow()
313         {
314                 thisrow.clear();
315                 if (res)
316                 {
317                         row = mysql_fetch_row(res);
318                         if (row)
319                         {
320                                 unsigned int field_count = 0;
321                                 MYSQL_FIELD *fields = mysql_fetch_fields(res);
322                                 if(mysql_field_count(&connection) == 0)
323                                         return thisrow;
324                                 if (fields && mysql_field_count(&connection))
325                                 {
326                                         while (field_count < mysql_field_count(&connection))
327                                         {
328                                                 std::string a = (fields[field_count].name ? fields[field_count].name : "");
329                                                 std::string b = (row[field_count] ? row[field_count] : "");
330                                                 thisrow[a] = b;
331                                                 field_count++;
332                                         }
333                                         return thisrow;
334                                 }
335                         }
336                 }
337                 return thisrow;
338         }
339
340         bool QueryDone()
341         {
342                 if (res)
343                 {
344                         mysql_free_result(res);
345                         res = NULL;
346                         return true;
347                 }
348                 else return false;
349         }
350
351         bool ConnectionLost()
352         {
353                 if (&connection) {
354                         return (mysql_ping(&connection) != 0);
355                 }
356                 else return false;
357         }
358
359         bool CheckConnection()
360         {
361                 if (ConnectionLost()) {
362                         return Connect();
363                 }
364                 else return true;
365         }
366
367         std::string GetError()
368         {
369                 return mysql_error(&connection);
370         }
371
372         long GetID()
373         {
374                 return id;
375         }
376
377         std::string GetHost()
378         {
379                 return host;
380         }
381
382         void Enable()
383         {
384                 Enabled = true;
385         }
386
387         void Disable()
388         {
389                 Enabled = false;
390         }
391
392         bool IsEnabled()
393         {
394                 return Enabled;
395         }
396
397 };
398
399 ConnMap Connections;
400
401 void ConnectDatabases(Server* Srv)
402 {
403         for (ConnMap::iterator i = Connections.begin(); i != Connections.end(); i++)
404         {
405                 i->second->Enable();
406                 if (i->second->Connect())
407                 {
408                         Srv->Log(DEFAULT,"SQL: Successfully connected database "+i->second->GetHost());
409                 }
410                 else
411                 {
412                         Srv->Log(DEFAULT,"SQL: Failed to connect database "+i->second->GetHost()+": Error: "+i->second->GetError());
413                         i->second->Disable();
414                 }
415         }
416 }
417
418
419 void LoadDatabases(ConfigReader* ThisConf, Server* Srv)
420 {
421         Srv->Log(DEFAULT,"SQL: Loading database settings");
422         Connections.clear();
423         Srv->Log(DEBUG,"Cleared connections");
424         for (int j =0; j < ThisConf->Enumerate("database"); j++)
425         {
426                 std::string db = ThisConf->ReadValue("database","name",j);
427                 std::string user = ThisConf->ReadValue("database","username",j);
428                 std::string pass = ThisConf->ReadValue("database","password",j);
429                 std::string host = ThisConf->ReadValue("database","hostname",j);
430                 std::string id = ThisConf->ReadValue("database","id",j);
431                 Srv->Log(DEBUG,"Read database settings");
432                 if ((db != "") && (host != "") && (user != "") && (id != "") && (pass != ""))
433                 {
434                         SQLConnection* ThisSQL = new SQLConnection(host,user,pass,db,atoi(id.c_str()));
435                         Srv->Log(DEFAULT,"Loaded database: "+ThisSQL->GetHost());
436                         Connections[id] = ThisSQL;
437                         Srv->Log(DEBUG,"Pushed back connection");
438                 }
439         }
440         ConnectDatabases(Srv);
441 }
442
443 void* DispatcherThread(void* arg);
444
445 class ModuleSQL : public Module
446 {
447  public:
448         Server *Srv;
449         ConfigReader *Conf;
450         pthread_t Dispatcher;
451         int currid;
452
453         void Implements(char* List)
454         {
455                 List[I_OnRehash] = List[I_OnRequest] = 1;
456         }
457
458         unsigned long NewID()
459         {
460                 if (currid+1 == 0)
461                         currid++;
462                 return ++currid;
463         }
464
465         char* OnRequest(Request* request)
466         {
467                 if(strcmp(SQLREQID, request->GetData()) == 0)
468                 {
469                         SQLrequest* req = (SQLrequest*)request;
470
471                         /* XXX: Lock */
472                         pthread_mutex_lock(&queue_mutex);
473
474                         ConnMap::iterator iter;
475
476                         char* returnval = NULL;
477
478                         log(DEBUG, "Got query: '%s' with %d replacement parameters on id '%s'", req->query.q.c_str(), req->query.p.size(), req->dbid.c_str());
479
480                         if((iter = Connections.find(req->dbid)) != Connections.end())
481                         {
482                                 iter->second->queue.push(*req);
483                                 req->id = NewID();
484                                 returnval = SQLSUCCESS;
485                         }
486                         else
487                         {
488                                 req->error.Id(BAD_DBID);
489                         }
490
491                         pthread_mutex_unlock(&queue_mutex);
492                         /* XXX: Unlock */
493
494                         return returnval;
495                 }
496
497                 log(DEBUG, "Got unsupported API version string: %s", request->GetData());
498
499                 return NULL;
500         }
501
502         ModuleSQL(Server* Me)
503                 : Module::Module(Me)
504         {
505                 Srv = Me;
506                 Conf = new ConfigReader();
507                 currid = 0;
508                 pthread_attr_t attribs;
509                 pthread_attr_init(&attribs);
510                 pthread_attr_setdetachstate(&attribs, PTHREAD_CREATE_DETACHED);
511                 if (pthread_create(&this->Dispatcher, &attribs, DispatcherThread, (void *)this) != 0)
512                 {
513                         throw ModuleException("m_mysql: Failed to create dispatcher thread: " + std::string(strerror(errno)));
514                 }
515                 Srv->PublishFeature("SQL", this);
516                 Srv->PublishFeature("MySQL", this);
517         }
518         
519         virtual ~ModuleSQL()
520         {
521                 DELETE(Conf);
522         }
523         
524         virtual void OnRehash(const std::string &parameter)
525         {
526                 /* TODO: set rehash bool here, which makes the dispatcher thread rehash at next opportunity */
527         }
528         
529         virtual Version GetVersion()
530         {
531                 return Version(1,1,0,0,VF_VENDOR|VF_SERVICEPROVIDER);
532         }
533         
534 };
535
536 void* DispatcherThread(void* arg)
537 {
538         ModuleSQL* thismodule = (ModuleSQL*)arg;
539         LoadDatabases(thismodule->Conf, thismodule->Srv);
540
541         while (!giveup)
542         {
543                 SQLConnection* conn = NULL;
544                 /* XXX: Lock here for safety */
545                 pthread_mutex_lock(&queue_mutex);
546                 for (ConnMap::iterator i = Connections.begin(); i != Connections.end(); i++)
547                 {
548                         if (i->second->queue.totalsize())
549                         {
550                                 conn = i->second;
551                                 break;
552                         }
553                 }
554                 pthread_mutex_unlock(&queue_mutex);
555                 /* XXX: Unlock */
556
557                 /* Theres an item! */
558                 if (conn)
559                 {
560                         conn->DoLeadingQuery();
561
562                         /* XXX: Lock */
563                         pthread_mutex_lock(&queue_mutex);
564                         conn->queue.pop();
565                         pthread_mutex_unlock(&queue_mutex);
566                         /* XXX: Unlock */
567                 }
568
569                 usleep(50);
570         }
571
572         return NULL;
573 }
574
575
576 // stuff down here is the module-factory stuff. For basic modules you can ignore this.
577
578 class ModuleSQLFactory : public ModuleFactory
579 {
580  public:
581         ModuleSQLFactory()
582         {
583         }
584         
585         ~ModuleSQLFactory()
586         {
587         }
588         
589         virtual Module * CreateModule(Server* Me)
590         {
591                 return new ModuleSQL(Me);
592         }
593         
594 };
595
596
597 extern "C" void * init_module( void )
598 {
599         return new ModuleSQLFactory;
600 }
601