]> git.netwichtig.de Git - user/henk/code/inspircd.git/blob - src/modules/extra/m_mysql.cpp
Pass the ModeHandler to User::HasModePermission()
[user/henk/code/inspircd.git] / src / modules / extra / m_mysql.cpp
1 /*
2  * InspIRCd -- Internet Relay Chat Daemon
3  *
4  *   Copyright (C) 2009-2010 Daniel De Graaf <danieldg@inspircd.org>
5  *   Copyright (C) 2006-2007, 2009 Dennis Friis <peavey@inspircd.org>
6  *   Copyright (C) 2006-2009 Craig Edwards <craigedwards@brainbox.cc>
7  *   Copyright (C) 2008 Robin Burchell <robin+git@viroteck.net>
8  *
9  * This file is part of InspIRCd.  InspIRCd is free software: you can
10  * redistribute it and/or modify it under the terms of the GNU General Public
11  * License as published by the Free Software Foundation, version 2.
12  *
13  * This program is distributed in the hope that it will be useful, but WITHOUT
14  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
15  * FOR A PARTICULAR PURPOSE.  See the GNU General Public License for more
16  * details.
17  *
18  * You should have received a copy of the GNU General Public License
19  * along with this program.  If not, see <http://www.gnu.org/licenses/>.
20  */
21
22
23 // Fix warnings about the use of `long long` on C++03.
24 #if defined __clang__
25 # pragma clang diagnostic ignored "-Wc++11-long-long"
26 #elif defined __GNUC__
27 # pragma GCC diagnostic ignored "-Wlong-long"
28 #endif
29
30 #include "inspircd.h"
31 #include <mysql.h>
32 #include "modules/sql.h"
33
34 #ifdef _WIN32
35 # pragma comment(lib, "libmysql.lib")
36 #endif
37
38 /* VERSION 3 API: With nonblocking (threaded) requests */
39
40 /* $CompileFlags: exec("mysql_config --include") */
41 /* $LinkerFlags: exec("mysql_config --libs_r") rpath("mysql_config --libs_r") */
42
43 /* THE NONBLOCKING MYSQL API!
44  *
 * MySQL provides no nonblocking (asynchronous) API of its own, and its developers recommend
 * that instead, you should thread your program. This is what I've done here to allow for
 * asynchronous SQL requests via mysql. The way this works is as follows:
48  *
 * The module spawns a worker thread (DispatcherThread, a SocketThread) and performs its mysql
 * queries in that thread, using a first-in first-out queue. The queue is protected by a mutex
 * which prevents the two threads from adjusting it at the same time, and crashing the ircd.
 * While the queue is empty the worker thread sleeps; submitting a request wakes it up. It then
 * takes the request at the head of the queue and processes it, blocking the worker thread but
 * leaving the ircd thread to go about its business as usual. During this period, the ircd thread
 * is able to insert further pending requests into the queue.
56  *
 * Once the processing of a request is complete, it is removed from the incoming queue and its
 * result is placed on an outgoing (result) queue. The worker thread then notifies the ircd
 * thread that a result is available (via NotifyParent(), provided by SocketThread).
61  *
 * The ircd thread then locks the queue once more, takes every pending response off the result
 * queue, and delivers each one to the module that originally submitted the query.
64  *
 * XXX: You might be asking "why doesn't he just send the response from within the worker thread?"
 * The answer to this is simple. The majority of InspIRCd, and in fact most ircds, are not
 * threadsafe. This module is designed to be threadsafe and is careful with its use of threads,
 * however, if we were to call a module's OnResult even from within a thread which was not the
 * one the module was originally instantiated upon, there is a chance of all hell breaking loose
 * if a module is ever put in a re-entrant state (stack corruption could occur, crashes, data
 * corruption, and worse, so DON'T think about it until the day comes when InspIRCd is 100%
 * guaranteed threadsafe!)
73  *
74  * For a diagram of this system please see http://wiki.inspircd.org/Mysql2
75  */
76
77 class SQLConnection;
78 class MySQLresult;
79 class DispatcherThread;
80
81 struct QQueueItem
82 {
83         SQLQuery* q;
84         std::string query;
85         SQLConnection* c;
86         QQueueItem(SQLQuery* Q, const std::string& S, SQLConnection* C) : q(Q), query(S), c(C) {}
87 };
88
89 struct RQueueItem
90 {
91         SQLQuery* q;
92         MySQLresult* r;
93         RQueueItem(SQLQuery* Q, MySQLresult* R) : q(Q), r(R) {}
94 };
95
96 typedef insp::flat_map<std::string, SQLConnection*> ConnMap;
97 typedef std::deque<QQueueItem> QueryQueue;
98 typedef std::deque<RQueueItem> ResultQueue;
99
100 /** MySQL module
101  *  */
102 class ModuleSQL : public Module
103 {
104  public:
105         DispatcherThread* Dispatcher;
106         QueryQueue qq;       // MUST HOLD MUTEX
107         ResultQueue rq;      // MUST HOLD MUTEX
108         ConnMap connections; // main thread only
109
110         ModuleSQL();
111         void init() CXX11_OVERRIDE;
112         ~ModuleSQL();
113         void ReadConfig(ConfigStatus& status) CXX11_OVERRIDE;
114         void OnUnloadModule(Module* mod) CXX11_OVERRIDE;
115         Version GetVersion() CXX11_OVERRIDE;
116 };
117
118 class DispatcherThread : public SocketThread
119 {
120  private:
121         ModuleSQL* const Parent;
122  public:
123         DispatcherThread(ModuleSQL* CreatorModule) : Parent(CreatorModule) { }
124         ~DispatcherThread() { }
125         void Run();
126         void OnNotify();
127 };
128
129 #if !defined(MYSQL_VERSION_ID) || MYSQL_VERSION_ID<32224
130 #define mysql_field_count mysql_num_fields
131 #endif
132
133 /** Represents a mysql result set
134  */
135 class MySQLresult : public SQLResult
136 {
137  public:
138         SQLerror err;
139         int currentrow;
140         int rows;
141         std::vector<std::string> colnames;
142         std::vector<SQLEntries> fieldlists;
143
144         MySQLresult(MYSQL_RES* res, int affected_rows) : err(SQL_NO_ERROR), currentrow(0), rows(0)
145         {
146                 if (affected_rows >= 1)
147                 {
148                         rows = affected_rows;
149                         fieldlists.resize(rows);
150                 }
151                 unsigned int field_count = 0;
152                 if (res)
153                 {
154                         MYSQL_ROW row;
155                         int n = 0;
156                         while ((row = mysql_fetch_row(res)))
157                         {
158                                 if (fieldlists.size() < (unsigned int)rows+1)
159                                 {
160                                         fieldlists.resize(fieldlists.size()+1);
161                                 }
162                                 field_count = 0;
163                                 MYSQL_FIELD *fields = mysql_fetch_fields(res);
164                                 if(mysql_num_fields(res) == 0)
165                                         break;
166                                 if (fields && mysql_num_fields(res))
167                                 {
168                                         colnames.clear();
169                                         while (field_count < mysql_num_fields(res))
170                                         {
171                                                 std::string a = (fields[field_count].name ? fields[field_count].name : "");
172                                                 if (row[field_count])
173                                                         fieldlists[n].push_back(SQLEntry(row[field_count]));
174                                                 else
175                                                         fieldlists[n].push_back(SQLEntry());
176                                                 colnames.push_back(a);
177                                                 field_count++;
178                                         }
179                                         n++;
180                                 }
181                                 rows++;
182                         }
183                         mysql_free_result(res);
184                 }
185         }
186
187         MySQLresult(SQLerror& e) : err(e)
188         {
189
190         }
191
192         int Rows()
193         {
194                 return rows;
195         }
196
197         void GetCols(std::vector<std::string>& result)
198         {
199                 result.assign(colnames.begin(), colnames.end());
200         }
201
202         SQLEntry GetValue(int row, int column)
203         {
204                 if ((row >= 0) && (row < rows) && (column >= 0) && (column < (int)fieldlists[row].size()))
205                 {
206                         return fieldlists[row][column];
207                 }
208                 return SQLEntry();
209         }
210
211         bool GetRow(SQLEntries& result)
212         {
213                 if (currentrow < rows)
214                 {
215                         result.assign(fieldlists[currentrow].begin(), fieldlists[currentrow].end());
216                         currentrow++;
217                         return true;
218                 }
219                 else
220                 {
221                         result.clear();
222                         return false;
223                 }
224         }
225 };
226
227 /** Represents a connection to a mysql database
228  */
229 class SQLConnection : public SQLProvider
230 {
231  public:
232         reference<ConfigTag> config;
233         MYSQL *connection;
234         Mutex lock;
235
236         // This constructor creates an SQLConnection object with the given credentials, but does not connect yet.
237         SQLConnection(Module* p, ConfigTag* tag) : SQLProvider(p, "SQL/" + tag->getString("id")),
238                 config(tag), connection(NULL)
239         {
240         }
241
242         ~SQLConnection()
243         {
244                 Close();
245         }
246
247         // This method connects to the database using the credentials supplied to the constructor, and returns
248         // true upon success.
249         bool Connect()
250         {
251                 unsigned int timeout = 1;
252                 connection = mysql_init(connection);
253                 mysql_options(connection,MYSQL_OPT_CONNECT_TIMEOUT,(char*)&timeout);
254                 std::string host = config->getString("host");
255                 std::string user = config->getString("user");
256                 std::string pass = config->getString("pass");
257                 std::string dbname = config->getString("name");
258                 int port = config->getInt("port");
259                 bool rv = mysql_real_connect(connection, host.c_str(), user.c_str(), pass.c_str(), dbname.c_str(), port, NULL, 0);
260                 if (!rv)
261                         return rv;
262
263                 // Enable character set settings
264                 std::string charset = config->getString("charset");
265                 if ((!charset.empty()) && (mysql_set_character_set(connection, charset.c_str())))
266                         ServerInstance->Logs->Log(MODNAME, LOG_DEFAULT, "WARNING: Could not set character set to \"%s\"", charset.c_str());
267
268                 std::string initquery;
269                 if (config->readString("initialquery", initquery))
270                 {
271                         mysql_query(connection,initquery.c_str());
272                 }
273                 return true;
274         }
275
276         ModuleSQL* Parent()
277         {
278                 return (ModuleSQL*)(Module*)creator;
279         }
280
281         MySQLresult* DoBlockingQuery(const std::string& query)
282         {
283
284                 /* Parse the command string and dispatch it to mysql */
285                 if (CheckConnection() && !mysql_real_query(connection, query.data(), query.length()))
286                 {
287                         /* Successfull query */
288                         MYSQL_RES* res = mysql_use_result(connection);
289                         unsigned long rows = mysql_affected_rows(connection);
290                         return new MySQLresult(res, rows);
291                 }
292                 else
293                 {
294                         /* XXX: See /usr/include/mysql/mysqld_error.h for a list of
295                          * possible error numbers and error messages */
296                         SQLerror e(SQL_QREPLY_FAIL, ConvToStr(mysql_errno(connection)) + ": " + mysql_error(connection));
297                         return new MySQLresult(e);
298                 }
299         }
300
301         bool CheckConnection()
302         {
303                 if (!connection || mysql_ping(connection) != 0)
304                         return Connect();
305                 return true;
306         }
307
308         std::string GetError()
309         {
310                 return mysql_error(connection);
311         }
312
313         void Close()
314         {
315                 mysql_close(connection);
316         }
317
318         void submit(SQLQuery* q, const std::string& qs)
319         {
320                 Parent()->Dispatcher->LockQueue();
321                 Parent()->qq.push_back(QQueueItem(q, qs, this));
322                 Parent()->Dispatcher->UnlockQueueWakeup();
323         }
324
325         void submit(SQLQuery* call, const std::string& q, const ParamL& p)
326         {
327                 std::string res;
328                 unsigned int param = 0;
329                 for(std::string::size_type i = 0; i < q.length(); i++)
330                 {
331                         if (q[i] != '?')
332                                 res.push_back(q[i]);
333                         else
334                         {
335                                 if (param < p.size())
336                                 {
337                                         std::string parm = p[param++];
338                                         // In the worst case, each character may need to be encoded as using two bytes,
339                                         // and one byte is the terminating null
340                                         std::vector<char> buffer(parm.length() * 2 + 1);
341
342                                         // The return value of mysql_escape_string() is the length of the encoded string,
343                                         // not including the terminating null
344                                         unsigned long escapedsize = mysql_escape_string(&buffer[0], parm.c_str(), parm.length());
345 //                                      mysql_real_escape_string(connection, queryend, paramscopy[paramnum].c_str(), paramscopy[paramnum].length());
346                                         res.append(&buffer[0], escapedsize);
347                                 }
348                         }
349                 }
350                 submit(call, res);
351         }
352
353         void submit(SQLQuery* call, const std::string& q, const ParamM& p)
354         {
355                 std::string res;
356                 for(std::string::size_type i = 0; i < q.length(); i++)
357                 {
358                         if (q[i] != '$')
359                                 res.push_back(q[i]);
360                         else
361                         {
362                                 std::string field;
363                                 i++;
364                                 while (i < q.length() && isalnum(q[i]))
365                                         field.push_back(q[i++]);
366                                 i--;
367
368                                 ParamM::const_iterator it = p.find(field);
369                                 if (it != p.end())
370                                 {
371                                         std::string parm = it->second;
372                                         // NOTE: See above
373                                         std::vector<char> buffer(parm.length() * 2 + 1);
374                                         unsigned long escapedsize = mysql_escape_string(&buffer[0], parm.c_str(), parm.length());
375                                         res.append(&buffer[0], escapedsize);
376                                 }
377                         }
378                 }
379                 submit(call, res);
380         }
381 };
382
383 ModuleSQL::ModuleSQL()
384 {
385         Dispatcher = NULL;
386 }
387
388 void ModuleSQL::init()
389 {
390         Dispatcher = new DispatcherThread(this);
391         ServerInstance->Threads.Start(Dispatcher);
392 }
393
394 ModuleSQL::~ModuleSQL()
395 {
396         if (Dispatcher)
397         {
398                 Dispatcher->join();
399                 Dispatcher->OnNotify();
400                 delete Dispatcher;
401         }
402         for(ConnMap::iterator i = connections.begin(); i != connections.end(); i++)
403         {
404                 delete i->second;
405         }
406 }
407
408 void ModuleSQL::ReadConfig(ConfigStatus& status)
409 {
410         ConnMap conns;
411         ConfigTagList tags = ServerInstance->Config->ConfTags("database");
412         for(ConfigIter i = tags.first; i != tags.second; i++)
413         {
414                 if (i->second->getString("module", "mysql") != "mysql")
415                         continue;
416                 std::string id = i->second->getString("id");
417                 ConnMap::iterator curr = connections.find(id);
418                 if (curr == connections.end())
419                 {
420                         SQLConnection* conn = new SQLConnection(this, i->second);
421                         conns.insert(std::make_pair(id, conn));
422                         ServerInstance->Modules->AddService(*conn);
423                 }
424                 else
425                 {
426                         conns.insert(*curr);
427                         connections.erase(curr);
428                 }
429         }
430
431         // now clean up the deleted databases
432         Dispatcher->LockQueue();
433         SQLerror err(SQL_BAD_DBID);
434         for(ConnMap::iterator i = connections.begin(); i != connections.end(); i++)
435         {
436                 ServerInstance->Modules->DelService(*i->second);
437                 // it might be running a query on this database. Wait for that to complete
438                 i->second->lock.Lock();
439                 i->second->lock.Unlock();
440                 // now remove all active queries to this DB
441                 for (size_t j = qq.size(); j > 0; j--)
442                 {
443                         size_t k = j - 1;
444                         if (qq[k].c == i->second)
445                         {
446                                 qq[k].q->OnError(err);
447                                 delete qq[k].q;
448                                 qq.erase(qq.begin() + k);
449                         }
450                 }
451                 // finally, nuke the connection
452                 delete i->second;
453         }
454         Dispatcher->UnlockQueue();
455         connections.swap(conns);
456 }
457
458 void ModuleSQL::OnUnloadModule(Module* mod)
459 {
460         SQLerror err(SQL_BAD_DBID);
461         Dispatcher->LockQueue();
462         unsigned int i = qq.size();
463         while (i > 0)
464         {
465                 i--;
466                 if (qq[i].q->creator == mod)
467                 {
468                         if (i == 0)
469                         {
470                                 // need to wait until the query is done
471                                 // (the result will be discarded)
472                                 qq[i].c->lock.Lock();
473                                 qq[i].c->lock.Unlock();
474                         }
475                         qq[i].q->OnError(err);
476                         delete qq[i].q;
477                         qq.erase(qq.begin() + i);
478                 }
479         }
480         Dispatcher->UnlockQueue();
481         // clean up any result queue entries
482         Dispatcher->OnNotify();
483 }
484
485 Version ModuleSQL::GetVersion()
486 {
487         return Version("MySQL support", VF_VENDOR);
488 }
489
/** Worker-thread main loop.
 *
 * Holds the queue lock except while a query is executing. For each query at the head of
 * Parent->qq it runs the query, then (if the item is still at the head) moves it to
 * Parent->rq and notifies the main thread. Exits once the exit flag is set.
 */
void DispatcherThread::Run()
{
	this->LockQueue();
	while (!this->GetExitFlag())
	{
		if (!Parent->qq.empty())
		{
			// Copy the head item: the main thread may erase it from qq while we run it.
			QQueueItem i = Parent->qq.front();
			// Take the connection lock before releasing the queue lock. ReadConfig() and
			// OnUnloadModule() lock the queue and then this lock, so they wait for the query.
			i.c->lock.Lock();
			this->UnlockQueue();
			MySQLresult* res = i.c->DoBlockingQuery(i.query);
			i.c->lock.Unlock();

			/*
			 * At this point, the main thread could be working on:
			 *  Rehash - delete i.c out from under us. We don't care about that.
			 *  UnloadModule - delete i.q and the qq item. Need to avoid reporting results.
			 */

			this->LockQueue();
			// NOTE(review): this is a pointer comparison; a query that was deleted and replaced
			// by a new one at the same address would be mistaken for the original.
			if (!Parent->qq.empty() && Parent->qq.front().q == i.q)
			{
				Parent->qq.pop_front();
				Parent->rq.push_back(RQueueItem(i.q, res));
				NotifyParent();
			}
			else
			{
				// UnloadModule ate the query (ReadConfig also removes queries for deleted databases)
				delete res;
			}
		}
		else
		{
			/* We know the queue is empty, we can safely hang this thread until
			 * something happens
			 */
			// presumably WaitForQueue() releases the queue lock while sleeping and re-acquires
			// it when woken (submit() wakes it via UnlockQueueWakeup()) — TODO confirm
			this->WaitForQueue();
		}
	}
	this->UnlockQueue();
}
532
533 void DispatcherThread::OnNotify()
534 {
535         // this could unlock during the dispatch, but OnResult isn't expected to take that long
536         this->LockQueue();
537         for(ResultQueue::iterator i = Parent->rq.begin(); i != Parent->rq.end(); i++)
538         {
539                 MySQLresult* res = i->r;
540                 if (res->err.id == SQL_NO_ERROR)
541                         i->q->OnResult(*res);
542                 else
543                         i->q->OnError(res->err);
544                 delete i->q;
545                 delete i->r;
546         }
547         Parent->rq.clear();
548         this->UnlockQueue();
549 }
550
// NOTE(review): presumably expands to the entry point the core uses to instantiate ModuleSQL — confirm in the module headers.
MODULE_INIT(ModuleSQL)