]> git.netwichtig.de Git - user/henk/code/inspircd.git/blob - src/modules/extra/m_mysql.cpp
22cf5f3f4f1a5e006560abb1968c9db207e08474
[user/henk/code/inspircd.git] / src / modules / extra / m_mysql.cpp
1 /*
2  * InspIRCd -- Internet Relay Chat Daemon
3  *
4  *   Copyright (C) 2009-2010 Daniel De Graaf <danieldg@inspircd.org>
5  *   Copyright (C) 2006-2007, 2009 Dennis Friis <peavey@inspircd.org>
6  *   Copyright (C) 2006-2009 Craig Edwards <craigedwards@brainbox.cc>
7  *   Copyright (C) 2008 Robin Burchell <robin+git@viroteck.net>
8  *
9  * This file is part of InspIRCd.  InspIRCd is free software: you can
10  * redistribute it and/or modify it under the terms of the GNU General Public
11  * License as published by the Free Software Foundation, version 2.
12  *
13  * This program is distributed in the hope that it will be useful, but WITHOUT
14  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
15  * FOR A PARTICULAR PURPOSE.  See the GNU General Public License for more
16  * details.
17  *
18  * You should have received a copy of the GNU General Public License
19  * along with this program.  If not, see <http://www.gnu.org/licenses/>.
20  */
21
22
23 /* Stop mysql wanting to use long long */
24 #define NO_CLIENT_LONG_LONG
25
26 #include "inspircd.h"
27 #include <mysql.h>
28 #include "sql.h"
29
30 #ifdef _WIN32
31 # pragma comment(lib, "mysqlclient.lib")
32 # pragma comment(lib, "advapi32.lib")
33 # pragma comment(linker, "/NODEFAULTLIB:LIBCMT")
34 #endif
35
36 /* VERSION 3 API: With nonblocking (threaded) requests */
37
38 /* $ModDesc: SQL Service Provider module for all other m_sql* modules */
39 /* $CompileFlags: exec("mysql_config --include") */
40 /* $LinkerFlags: exec("mysql_config --libs_r") rpath("mysql_config --libs_r") */
41
/* THE NONBLOCKING MYSQL API!
 *
 * MySQL provides no nonblocking (asynchronous) API of its own, and its developers recommend
 * threading your program instead. That is what this module does to allow for
 * asynchronous SQL requests via MySQL. The way this works is as follows:
 *
 * The module spawns a worker thread (DispatcherThread, a SocketThread) and performs its MySQL
 * queries in that thread, using a first-in, first-out queue of pending requests. A single
 * queue lock prevents the two threads from adjusting the queue at the same time. While the
 * queue is empty, the worker thread sleeps until the ircd thread queues a request and wakes
 * it. The worker then takes the request at the head of the queue and processes it. This blocks
 * the worker thread but leaves the ircd thread to go about its business as usual. During this
 * period, the ircd thread is able to insert further pending requests into the queue.
55  *
 * Once the processing of a request is complete, it is removed from the incoming (query) queue
 * and placed, together with its result, on an outgoing (result) queue. The worker thread then
 * notifies the ircd thread that a result is available.
 *
 * The ircd thread then locks the queue once more, takes the pending responses off the result
 * queue, and sends each one on its way to the original calling module.
63  *
 * XXX: You might be asking "why doesn't he just send the response from within the worker thread?"
 * The answer to this is simple. The majority of InspIRCd, and in fact most ircds, are not
 * threadsafe. This module is designed to be threadsafe and is careful with its use of threads.
 * However, if we were to call a module's result callback (OnResult/OnError) from a thread other
 * than the one the module was originally instantiated upon, there is a chance of all hell
 * breaking loose if a module is ever put in a re-entrant state. Stack corruption, crashes,
 * data corruption and worse could occur, so DON'T think about it until the day comes when
 * InspIRCd is 100% guaranteed threadsafe!
72  *
73  * For a diagram of this system please see http://wiki.inspircd.org/Mysql2
74  */
75
76 class SQLConnection;
77 class MySQLresult;
78 class DispatcherThread;
79
80 struct QQueueItem
81 {
82         SQLQuery* q;
83         std::string query;
84         SQLConnection* c;
85         QQueueItem(SQLQuery* Q, const std::string& S, SQLConnection* C) : q(Q), query(S), c(C) {}
86 };
87
88 struct RQueueItem
89 {
90         SQLQuery* q;
91         MySQLresult* r;
92         RQueueItem(SQLQuery* Q, MySQLresult* R) : q(Q), r(R) {}
93 };
94
95 typedef std::map<std::string, SQLConnection*> ConnMap;
96 typedef std::deque<QQueueItem> QueryQueue;
97 typedef std::deque<RQueueItem> ResultQueue;
98
99 /** MySQL module
100  *  */
101 class ModuleSQL : public Module
102 {
103  public:
104         DispatcherThread* Dispatcher;
105         QueryQueue qq;       // MUST HOLD MUTEX
106         ResultQueue rq;      // MUST HOLD MUTEX
107         ConnMap connections; // main thread only
108
109         ModuleSQL();
110         void init();
111         ~ModuleSQL();
112         void OnRehash(User* user);
113         void OnUnloadModule(Module* mod);
114         Version GetVersion();
115 };
116
117 class DispatcherThread : public SocketThread
118 {
119  private:
120         ModuleSQL* const Parent;
121  public:
122         DispatcherThread(ModuleSQL* CreatorModule) : Parent(CreatorModule) { }
123         ~DispatcherThread() { }
124         virtual void Run();
125         virtual void OnNotify();
126 };
127
128 #if !defined(MYSQL_VERSION_ID) || MYSQL_VERSION_ID<32224
129 #define mysql_field_count mysql_num_fields
130 #endif
131
132 /** Represents a mysql result set
133  */
134 class MySQLresult : public SQLResult
135 {
136  public:
137         SQLerror err;
138         int currentrow;
139         int rows;
140         std::vector<std::string> colnames;
141         std::vector<SQLEntries> fieldlists;
142
143         MySQLresult(MYSQL_RES* res, int affected_rows) : err(SQL_NO_ERROR), currentrow(0), rows(0)
144         {
145                 if (affected_rows >= 1)
146                 {
147                         rows = affected_rows;
148                         fieldlists.resize(rows);
149                 }
150                 unsigned int field_count = 0;
151                 if (res)
152                 {
153                         MYSQL_ROW row;
154                         int n = 0;
155                         while ((row = mysql_fetch_row(res)))
156                         {
157                                 if (fieldlists.size() < (unsigned int)rows+1)
158                                 {
159                                         fieldlists.resize(fieldlists.size()+1);
160                                 }
161                                 field_count = 0;
162                                 MYSQL_FIELD *fields = mysql_fetch_fields(res);
163                                 if(mysql_num_fields(res) == 0)
164                                         break;
165                                 if (fields && mysql_num_fields(res))
166                                 {
167                                         colnames.clear();
168                                         while (field_count < mysql_num_fields(res))
169                                         {
170                                                 std::string a = (fields[field_count].name ? fields[field_count].name : "");
171                                                 if (row[field_count])
172                                                         fieldlists[n].push_back(SQLEntry(row[field_count]));
173                                                 else
174                                                         fieldlists[n].push_back(SQLEntry());
175                                                 colnames.push_back(a);
176                                                 field_count++;
177                                         }
178                                         n++;
179                                 }
180                                 rows++;
181                         }
182                         mysql_free_result(res);
183                 }
184         }
185
186         MySQLresult(SQLerror& e) : err(e)
187         {
188
189         }
190
191         ~MySQLresult()
192         {
193         }
194
195         virtual int Rows()
196         {
197                 return rows;
198         }
199
200         virtual void GetCols(std::vector<std::string>& result)
201         {
202                 result.assign(colnames.begin(), colnames.end());
203         }
204
205         virtual SQLEntry GetValue(int row, int column)
206         {
207                 if ((row >= 0) && (row < rows) && (column >= 0) && (column < (int)fieldlists[row].size()))
208                 {
209                         return fieldlists[row][column];
210                 }
211                 return SQLEntry();
212         }
213
214         virtual bool GetRow(SQLEntries& result)
215         {
216                 if (currentrow < rows)
217                 {
218                         result.assign(fieldlists[currentrow].begin(), fieldlists[currentrow].end());
219                         currentrow++;
220                         return true;
221                 }
222                 else
223                 {
224                         result.clear();
225                         return false;
226                 }
227         }
228 };
229
230 /** Represents a connection to a mysql database
231  */
232 class SQLConnection : public SQLProvider
233 {
234  public:
235         reference<ConfigTag> config;
236         MYSQL *connection;
237         Mutex lock;
238
239         // This constructor creates an SQLConnection object with the given credentials, but does not connect yet.
240         SQLConnection(Module* p, ConfigTag* tag) : SQLProvider(p, "SQL/" + tag->getString("id")),
241                 config(tag), connection(NULL)
242         {
243         }
244
245         ~SQLConnection()
246         {
247                 Close();
248         }
249
250         // This method connects to the database using the credentials supplied to the constructor, and returns
251         // true upon success.
252         bool Connect()
253         {
254                 unsigned int timeout = 1;
255                 connection = mysql_init(connection);
256                 mysql_options(connection,MYSQL_OPT_CONNECT_TIMEOUT,(char*)&timeout);
257                 std::string host = config->getString("host");
258                 std::string user = config->getString("user");
259                 std::string pass = config->getString("pass");
260                 std::string dbname = config->getString("name");
261                 int port = config->getInt("port");
262                 bool rv = mysql_real_connect(connection, host.c_str(), user.c_str(), pass.c_str(), dbname.c_str(), port, NULL, 0);
263                 if (!rv)
264                         return rv;
265                 std::string initquery;
266                 if (config->readString("initialquery", initquery))
267                 {
268                         mysql_query(connection,initquery.c_str());
269                 }
270                 return true;
271         }
272
273         ModuleSQL* Parent()
274         {
275                 return (ModuleSQL*)(Module*)creator;
276         }
277
278         MySQLresult* DoBlockingQuery(const std::string& query)
279         {
280
281                 /* Parse the command string and dispatch it to mysql */
282                 if (CheckConnection() && !mysql_real_query(connection, query.data(), query.length()))
283                 {
284                         /* Successfull query */
285                         MYSQL_RES* res = mysql_use_result(connection);
286                         unsigned long rows = mysql_affected_rows(connection);
287                         return new MySQLresult(res, rows);
288                 }
289                 else
290                 {
291                         /* XXX: See /usr/include/mysql/mysqld_error.h for a list of
292                          * possible error numbers and error messages */
293                         SQLerror e(SQL_QREPLY_FAIL, ConvToStr(mysql_errno(connection)) + ": " + mysql_error(connection));
294                         return new MySQLresult(e);
295                 }
296         }
297
298         bool CheckConnection()
299         {
300                 if (!connection || mysql_ping(connection) != 0)
301                         return Connect();
302                 return true;
303         }
304
305         std::string GetError()
306         {
307                 return mysql_error(connection);
308         }
309
310         void Close()
311         {
312                 mysql_close(connection);
313         }
314
315         void submit(SQLQuery* q, const std::string& qs)
316         {
317                 Parent()->Dispatcher->LockQueue();
318                 Parent()->qq.push_back(QQueueItem(q, qs, this));
319                 Parent()->Dispatcher->UnlockQueueWakeup();
320         }
321
322         void submit(SQLQuery* call, const std::string& q, const ParamL& p)
323         {
324                 std::string res;
325                 unsigned int param = 0;
326                 for(std::string::size_type i = 0; i < q.length(); i++)
327                 {
328                         if (q[i] != '?')
329                                 res.push_back(q[i]);
330                         else
331                         {
332                                 if (param < p.size())
333                                 {
334                                         std::string parm = p[param++];
335                                         // In the worst case, each character may need to be encoded as using two bytes,
336                                         // and one byte is the terminating null
337                                         std::vector<char> buffer(parm.length() * 2 + 1);
338
339                                         // The return value of mysql_escape_string() is the length of the encoded string,
340                                         // not including the terminating null
341                                         unsigned long escapedsize = mysql_escape_string(&buffer[0], parm.c_str(), parm.length());
342 //                                      mysql_real_escape_string(connection, queryend, paramscopy[paramnum].c_str(), paramscopy[paramnum].length());
343                                         res.append(&buffer[0], escapedsize);
344                                 }
345                         }
346                 }
347                 submit(call, res);
348         }
349
350         void submit(SQLQuery* call, const std::string& q, const ParamM& p)
351         {
352                 std::string res;
353                 for(std::string::size_type i = 0; i < q.length(); i++)
354                 {
355                         if (q[i] != '$')
356                                 res.push_back(q[i]);
357                         else
358                         {
359                                 std::string field;
360                                 i++;
361                                 while (i < q.length() && isalnum(q[i]))
362                                         field.push_back(q[i++]);
363                                 i--;
364
365                                 ParamM::const_iterator it = p.find(field);
366                                 if (it != p.end())
367                                 {
368                                         std::string parm = it->second;
369                                         // NOTE: See above
370                                         std::vector<char> buffer(parm.length() * 2 + 1);
371                                         unsigned long escapedsize = mysql_escape_string(&buffer[0], parm.c_str(), parm.length());
372                                         res.append(&buffer[0], escapedsize);
373                                 }
374                         }
375                 }
376                 submit(call, res);
377         }
378 };
379
380 ModuleSQL::ModuleSQL()
381 {
382         Dispatcher = NULL;
383 }
384
385 void ModuleSQL::init()
386 {
387         Dispatcher = new DispatcherThread(this);
388         ServerInstance->Threads->Start(Dispatcher);
389
390         Implementation eventlist[] = { I_OnRehash, I_OnUnloadModule };
391         ServerInstance->Modules->Attach(eventlist, this, sizeof(eventlist)/sizeof(Implementation));
392
393         OnRehash(NULL);
394 }
395
396 ModuleSQL::~ModuleSQL()
397 {
398         if (Dispatcher)
399         {
400                 Dispatcher->join();
401                 Dispatcher->OnNotify();
402                 delete Dispatcher;
403         }
404         for(ConnMap::iterator i = connections.begin(); i != connections.end(); i++)
405         {
406                 delete i->second;
407         }
408 }
409
410 void ModuleSQL::OnRehash(User* user)
411 {
412         ConnMap conns;
413         ConfigTagList tags = ServerInstance->Config->ConfTags("database");
414         for(ConfigIter i = tags.first; i != tags.second; i++)
415         {
416                 if (i->second->getString("module", "mysql") != "mysql")
417                         continue;
418                 std::string id = i->second->getString("id");
419                 ConnMap::iterator curr = connections.find(id);
420                 if (curr == connections.end())
421                 {
422                         SQLConnection* conn = new SQLConnection(this, i->second);
423                         conns.insert(std::make_pair(id, conn));
424                         ServerInstance->Modules->AddService(*conn);
425                 }
426                 else
427                 {
428                         conns.insert(*curr);
429                         connections.erase(curr);
430                 }
431         }
432
433         // now clean up the deleted databases
434         Dispatcher->LockQueue();
435         SQLerror err(SQL_BAD_DBID);
436         for(ConnMap::iterator i = connections.begin(); i != connections.end(); i++)
437         {
438                 ServerInstance->Modules->DelService(*i->second);
439                 // it might be running a query on this database. Wait for that to complete
440                 i->second->lock.Lock();
441                 i->second->lock.Unlock();
442                 // now remove all active queries to this DB
443                 for (size_t j = qq.size(); j > 0; j--)
444                 {
445                         size_t k = j - 1;
446                         if (qq[k].c == i->second)
447                         {
448                                 qq[k].q->OnError(err);
449                                 delete qq[k].q;
450                                 qq.erase(qq.begin() + k);
451                         }
452                 }
453                 // finally, nuke the connection
454                 delete i->second;
455         }
456         Dispatcher->UnlockQueue();
457         connections.swap(conns);
458 }
459
460 void ModuleSQL::OnUnloadModule(Module* mod)
461 {
462         SQLerror err(SQL_BAD_DBID);
463         Dispatcher->LockQueue();
464         unsigned int i = qq.size();
465         while (i > 0)
466         {
467                 i--;
468                 if (qq[i].q->creator == mod)
469                 {
470                         if (i == 0)
471                         {
472                                 // need to wait until the query is done
473                                 // (the result will be discarded)
474                                 qq[i].c->lock.Lock();
475                                 qq[i].c->lock.Unlock();
476                         }
477                         qq[i].q->OnError(err);
478                         delete qq[i].q;
479                         qq.erase(qq.begin() + i);
480                 }
481         }
482         Dispatcher->UnlockQueue();
483         // clean up any result queue entries
484         Dispatcher->OnNotify();
485 }
486
487 Version ModuleSQL::GetVersion()
488 {
489         return Version("MySQL support", VF_VENDOR);
490 }
491
492 void DispatcherThread::Run()
493 {
494         this->LockQueue();
495         while (!this->GetExitFlag())
496         {
497                 if (!Parent->qq.empty())
498                 {
499                         QQueueItem i = Parent->qq.front();
500                         i.c->lock.Lock();
501                         this->UnlockQueue();
502                         MySQLresult* res = i.c->DoBlockingQuery(i.query);
503                         i.c->lock.Unlock();
504
505                         /*
506                          * At this point, the main thread could be working on:
507                          *  Rehash - delete i.c out from under us. We don't care about that.
508                          *  UnloadModule - delete i.q and the qq item. Need to avoid reporting results.
509                          */
510
511                         this->LockQueue();
512                         if (!Parent->qq.empty() && Parent->qq.front().q == i.q)
513                         {
514                                 Parent->qq.pop_front();
515                                 Parent->rq.push_back(RQueueItem(i.q, res));
516                                 NotifyParent();
517                         }
518                         else
519                         {
520                                 // UnloadModule ate the query
521                                 delete res;
522                         }
523                 }
524                 else
525                 {
526                         /* We know the queue is empty, we can safely hang this thread until
527                          * something happens
528                          */
529                         this->WaitForQueue();
530                 }
531         }
532         this->UnlockQueue();
533 }
534
535 void DispatcherThread::OnNotify()
536 {
537         // this could unlock during the dispatch, but OnResult isn't expected to take that long
538         this->LockQueue();
539         for(ResultQueue::iterator i = Parent->rq.begin(); i != Parent->rq.end(); i++)
540         {
541                 MySQLresult* res = i->r;
542                 if (res->err.id == SQL_NO_ERROR)
543                         i->q->OnResult(*res);
544                 else
545                         i->q->OnError(res->err);
546                 delete i->q;
547                 delete i->r;
548         }
549         Parent->rq.clear();
550         this->UnlockQueue();
551 }
552
553 MODULE_INIT(ModuleSQL)