]> git.netwichtig.de Git - user/henk/code/inspircd.git/blob - src/modules/extra/m_mysql.cpp
Seems to work to a point (don't use it; it won't actually execute a query yet)
[user/henk/code/inspircd.git] / src / modules / extra / m_mysql.cpp
1 /*       +------------------------------------+
2  *       | Inspire Internet Relay Chat Daemon |
3  *       +------------------------------------+
4  *
5  *  InspIRCd is copyright (C) 2002-2004 ChatSpike-Dev.
6  *                     E-mail:
7  *              <brain@chatspike.net>
8  *                <Craig@chatspike.net>
9  *     
10  * Written by Craig Edwards, Craig McLure, and others.
11  * This program is free but copyrighted software; see
12  *          the file COPYING for details.
13  *
14  * ---------------------------------------------------
15  */
16
17 using namespace std;
18
19 #include <stdio.h>
20 #include <string>
21 #include <mysql.h>
22 #include <pthread.h>
23 #include "users.h"
24 #include "channels.h"
25 #include "modules.h"
26 #include "helperfuncs.h"
27 #include "m_sqlv2.h"
28
29 /* VERSION 2 API: With nonblocking (threaded) requests */
30
31 /* $ModDesc: SQL Service Provider module for all other m_sql* modules */
32 /* $CompileFlags: `mysql_config --include` */
33 /* $LinkerFlags: `mysql_config --libs_r` `perl ../mysql_rpath.pl` */
34
35
36 class SQLConnection;
37
38 extern InspIRCd* ServerInstance;
39 typedef std::map<std::string, SQLConnection*> ConnMap;
40 bool giveup = false;
41
42
43 #if !defined(MYSQL_VERSION_ID) || MYSQL_VERSION_ID<32224
44 #define mysql_field_count mysql_num_fields
45 #endif
46
47 class QueryQueue : public classbase
48 {
49 private:
50         typedef std::deque<SQLrequest> ReqDeque;
51
52         ReqDeque priority;      /* The priority queue */
53         ReqDeque normal;        /* The 'normal' queue */
54         enum { PRI, NOR, NON } which;   /* Which queue the currently active element is at the front of */
55
56 public:
57         QueryQueue()
58         : which(NON)
59         {
60         }
61
62         void push(const SQLrequest &q)
63         {
64                 log(DEBUG, "QueryQueue::push(): Adding %s query to queue: %s", ((q.pri) ? "priority" : "non-priority"), q.query.q.c_str());
65
66                 if(q.pri)
67                         priority.push_back(q);
68                 else
69                         normal.push_back(q);
70         }
71
72         void pop()
73         {
74                 if((which == PRI) && priority.size())
75                 {
76                         priority.pop_front();
77                 }
78                 else if((which == NOR) && normal.size())
79                 {
80                         normal.pop_front();
81                 }
82
83                 /* Reset this */
84                 which = NON;
85
86                 /* Silently do nothing if there was no element to pop() */
87         }
88
89         SQLrequest& front()
90         {
91                 switch(which)
92                 {
93                         case PRI:
94                                 return priority.front();
95                         case NOR:
96                                 return normal.front();
97                         default:
98                                 if(priority.size())
99                                 {
100                                         which = PRI;
101                                         return priority.front();
102                                 }
103
104                                 if(normal.size())
105                                 {
106                                         which = NOR;
107                                         return normal.front();
108                                 }
109
110                                 /* This will probably result in a segfault,
111                                  * but the caller should have checked totalsize()
112                                  * first so..meh - moron :p
113                                  */
114
115                                 return priority.front();
116                 }
117         }
118
119         std::pair<int, int> size()
120         {
121                 return std::make_pair(priority.size(), normal.size());
122         }
123
124         int totalsize()
125         {
126                 return priority.size() + normal.size();
127         }
128
129         void PurgeModule(Module* mod)
130         {
131                 DoPurgeModule(mod, priority);
132                 DoPurgeModule(mod, normal);
133         }
134
135 private:
136         void DoPurgeModule(Module* mod, ReqDeque& q)
137         {
138                 for(ReqDeque::iterator iter = q.begin(); iter != q.end(); iter++)
139                 {
140                         if(iter->GetSource() == mod)
141                         {
142                                 if(iter->id == front().id)
143                                 {
144                                         /* It's the currently active query.. :x */
145                                         iter->SetSource(NULL);
146                                 }
147                                 else
148                                 {
149                                         /* It hasn't been executed yet..just remove it */
150                                         iter = q.erase(iter);
151                                 }
152                         }
153                 }
154         }
155 };
156
157 /* A mutex to wrap around queue accesses */
158 pthread_mutex_t queue_mutex = PTHREAD_MUTEX_INITIALIZER;
159
160 class SQLConnection : public classbase
161 {
162  protected:
163
164         MYSQL connection;
165         MYSQL_RES *res;
166         MYSQL_ROW row;
167         std::string host;
168         std::string user;
169         std::string pass;
170         std::string db;
171         std::map<std::string,std::string> thisrow;
172         bool Enabled;
173         long id;
174
175  public:
176
177         QueryQueue queue;
178
179         // This constructor creates an SQLConnection object with the given credentials, and creates the underlying
180         // MYSQL struct, but does not connect yet.
181         SQLConnection(std::string thishost, std::string thisuser, std::string thispass, std::string thisdb, long myid)
182         {
183                 this->Enabled = true;
184                 this->host = thishost;
185                 this->user = thisuser;
186                 this->pass = thispass;
187                 this->db = thisdb;
188                 this->id = myid;
189         }
190
191         // This method connects to the database using the credentials supplied to the constructor, and returns
192         // true upon success.
193         bool Connect()
194         {
195                 unsigned int timeout = 1;
196                 mysql_init(&connection);
197                 mysql_options(&connection,MYSQL_OPT_CONNECT_TIMEOUT,(char*)&timeout);
198                 return mysql_real_connect(&connection, host.c_str(), user.c_str(), pass.c_str(), db.c_str(), 0, NULL, 0);
199         }
200
201         void DoLeadingQuery()
202         {
203                 SQLrequest& query = queue.front();
204                 log(DEBUG,"DO QUERY: %s",query.query.q.c_str());
205         }
206
207         // This method issues a query that expects multiple rows of results. Use GetRow() and QueryDone() to retrieve
208         // multiple rows.
209         bool QueryResult(std::string query)
210         {
211                 if (!CheckConnection()) return false;
212                 
213                 int r = mysql_query(&connection, query.c_str());
214                 if (!r)
215                 {
216                         res = mysql_use_result(&connection);
217                 }
218                 return (!r);
219         }
220
221         // This method issues a query that just expects a number of 'effected' rows (e.g. UPDATE or DELETE FROM).
222         // the number of effected rows is returned in the return value.
223         long QueryCount(std::string query)
224         {
225                 /* If the connection is down, we return a negative value - New to 1.1 */
226                 if (!CheckConnection()) return -1;
227
228                 int r = mysql_query(&connection, query.c_str());
229                 if (!r)
230                 {
231                         res = mysql_store_result(&connection);
232                         unsigned long rows = mysql_affected_rows(&connection);
233                         mysql_free_result(res);
234                         return rows;
235                 }
236                 return 0;
237         }
238
239         // This method fetches a row, if available from the database. You must issue a query
240         // using QueryResult() first! The row's values are returned as a map of std::string
241         // where each item is keyed by the column name.
242         std::map<std::string,std::string> GetRow()
243         {
244                 thisrow.clear();
245                 if (res)
246                 {
247                         row = mysql_fetch_row(res);
248                         if (row)
249                         {
250                                 unsigned int field_count = 0;
251                                 MYSQL_FIELD *fields = mysql_fetch_fields(res);
252                                 if(mysql_field_count(&connection) == 0)
253                                         return thisrow;
254                                 if (fields && mysql_field_count(&connection))
255                                 {
256                                         while (field_count < mysql_field_count(&connection))
257                                         {
258                                                 std::string a = (fields[field_count].name ? fields[field_count].name : "");
259                                                 std::string b = (row[field_count] ? row[field_count] : "");
260                                                 thisrow[a] = b;
261                                                 field_count++;
262                                         }
263                                         return thisrow;
264                                 }
265                         }
266                 }
267                 return thisrow;
268         }
269
270         bool QueryDone()
271         {
272                 if (res)
273                 {
274                         mysql_free_result(res);
275                         res = NULL;
276                         return true;
277                 }
278                 else return false;
279         }
280
281         bool ConnectionLost()
282         {
283                 if (&connection) {
284                         return (mysql_ping(&connection) != 0);
285                 }
286                 else return false;
287         }
288
289         bool CheckConnection()
290         {
291                 if (ConnectionLost()) {
292                         return Connect();
293                 }
294                 else return true;
295         }
296
297         std::string GetError()
298         {
299                 return mysql_error(&connection);
300         }
301
302         long GetID()
303         {
304                 return id;
305         }
306
307         std::string GetHost()
308         {
309                 return host;
310         }
311
312         void Enable()
313         {
314                 Enabled = true;
315         }
316
317         void Disable()
318         {
319                 Enabled = false;
320         }
321
322         bool IsEnabled()
323         {
324                 return Enabled;
325         }
326
327 };
328
329 ConnMap Connections;
330
331 void ConnectDatabases(Server* Srv)
332 {
333         for (ConnMap::iterator i = Connections.begin(); i != Connections.end(); i++)
334         {
335                 i->second->Enable();
336                 if (i->second->Connect())
337                 {
338                         Srv->Log(DEFAULT,"SQL: Successfully connected database "+i->second->GetHost());
339                 }
340                 else
341                 {
342                         Srv->Log(DEFAULT,"SQL: Failed to connect database "+i->second->GetHost()+": Error: "+i->second->GetError());
343                         i->second->Disable();
344                 }
345         }
346 }
347
348
349 void LoadDatabases(ConfigReader* ThisConf, Server* Srv)
350 {
351         Srv->Log(DEFAULT,"SQL: Loading database settings");
352         Connections.clear();
353         Srv->Log(DEBUG,"Cleared connections");
354         for (int j =0; j < ThisConf->Enumerate("database"); j++)
355         {
356                 std::string db = ThisConf->ReadValue("database","name",j);
357                 std::string user = ThisConf->ReadValue("database","username",j);
358                 std::string pass = ThisConf->ReadValue("database","password",j);
359                 std::string host = ThisConf->ReadValue("database","hostname",j);
360                 std::string id = ThisConf->ReadValue("database","id",j);
361                 Srv->Log(DEBUG,"Read database settings");
362                 if ((db != "") && (host != "") && (user != "") && (id != "") && (pass != ""))
363                 {
364                         SQLConnection* ThisSQL = new SQLConnection(host,user,pass,db,atoi(id.c_str()));
365                         Srv->Log(DEFAULT,"Loaded database: "+ThisSQL->GetHost());
366                         Connections[id] = ThisSQL;
367                         Srv->Log(DEBUG,"Pushed back connection");
368                 }
369         }
370         ConnectDatabases(Srv);
371 }
372
373 void* DispatcherThread(void* arg);
374
375 class ModuleSQL : public Module
376 {
377  public:
378         Server *Srv;
379         ConfigReader *Conf;
380         pthread_t Dispatcher;
381         int currid;
382
383         void Implements(char* List)
384         {
385                 List[I_OnRehash] = List[I_OnRequest] = 1;
386         }
387
388         unsigned long NewID()
389         {
390                 if (currid+1 == 0)
391                         currid++;
392                 return ++currid;
393         }
394
395         char* OnRequest(Request* request)
396         {
397                 if(strcmp(SQLREQID, request->GetData()) == 0)
398                 {
399                         SQLrequest* req = (SQLrequest*)request;
400
401                         /* XXX: Lock */
402                         pthread_mutex_lock(&queue_mutex);
403
404                         ConnMap::iterator iter;
405
406                         char* returnval = NULL;
407
408                         log(DEBUG, "Got query: '%s' with %d replacement parameters on id '%s'", req->query.q.c_str(), req->query.p.size(), req->dbid.c_str());
409
410                         if((iter = Connections.find(req->dbid)) != Connections.end())
411                         {
412                                 iter->second->queue.push(*req);
413                                 req->id = NewID();
414                                 returnval = SQLSUCCESS;
415                         }
416                         else
417                         {
418                                 req->error.Id(BAD_DBID);
419                         }
420
421                         pthread_mutex_unlock(&queue_mutex);
422                         /* XXX: Unlock */
423
424                         return returnval;
425                 }
426
427                 log(DEBUG, "Got unsupported API version string: %s", request->GetData());
428
429                 return NULL;
430         }
431
432         ModuleSQL(Server* Me)
433                 : Module::Module(Me)
434         {
435                 Srv = Me;
436                 Conf = new ConfigReader();
437                 currid = 0;
438                 pthread_attr_t attribs;
439                 pthread_attr_init(&attribs);
440                 pthread_attr_setdetachstate(&attribs, PTHREAD_CREATE_DETACHED);
441                 if (pthread_create(&this->Dispatcher, &attribs, DispatcherThread, (void *)this) != 0)
442                 {
443                         throw ModuleException("m_mysql: Failed to create dispatcher thread: " + std::string(strerror(errno)));
444                 }
445                 Srv->PublishFeature("SQL", this);
446                 Srv->PublishFeature("MySQL", this);
447         }
448         
449         virtual ~ModuleSQL()
450         {
451                 DELETE(Conf);
452         }
453         
454         virtual void OnRehash(const std::string &parameter)
455         {
456                 /* TODO: set rehash bool here, which makes the dispatcher thread rehash at next opportunity */
457         }
458         
459         virtual Version GetVersion()
460         {
461                 return Version(1,1,0,0,VF_VENDOR|VF_SERVICEPROVIDER);
462         }
463         
464 };
465
466 void* DispatcherThread(void* arg)
467 {
468         ModuleSQL* thismodule = (ModuleSQL*)arg;
469         LoadDatabases(thismodule->Conf, thismodule->Srv);
470
471         while (!giveup)
472         {
473                 SQLConnection* conn = NULL;
474                 /* XXX: Lock here for safety */
475                 pthread_mutex_lock(&queue_mutex);
476                 for (ConnMap::iterator i = Connections.begin(); i != Connections.end(); i++)
477                 {
478                         if (i->second->queue.totalsize())
479                         {
480                                 conn = i->second;
481                                 break;
482                         }
483                 }
484                 pthread_mutex_unlock(&queue_mutex);
485                 /* XXX: Unlock */
486
487                 /* Theres an item! */
488                 if (conn)
489                 {
490                         conn->DoLeadingQuery();
491
492                         /* XXX: Lock */
493                         pthread_mutex_lock(&queue_mutex);
494                         conn->queue.pop();
495                         pthread_mutex_unlock(&queue_mutex);
496                         /* XXX: Unlock */
497                 }
498
499                 usleep(50);
500         }
501
502         return NULL;
503 }
504
505
506 // stuff down here is the module-factory stuff. For basic modules you can ignore this.
507
508 class ModuleSQLFactory : public ModuleFactory
509 {
510  public:
511         ModuleSQLFactory()
512         {
513         }
514         
515         ~ModuleSQLFactory()
516         {
517         }
518         
519         virtual Module * CreateModule(Server* Me)
520         {
521                 return new ModuleSQL(Me);
522         }
523         
524 };
525
526
527 extern "C" void * init_module( void )
528 {
529         return new ModuleSQLFactory;
530 }
531