]> git.netwichtig.de Git - user/henk/code/inspircd.git/blob - src/modules/extra/m_mysql.cpp
Add the override keyword in places that it is missing.
[user/henk/code/inspircd.git] / src / modules / extra / m_mysql.cpp
1 /*
2  * InspIRCd -- Internet Relay Chat Daemon
3  *
4  *   Copyright (C) 2009-2010 Daniel De Graaf <danieldg@inspircd.org>
5  *   Copyright (C) 2006-2007, 2009 Dennis Friis <peavey@inspircd.org>
6  *   Copyright (C) 2006-2009 Craig Edwards <craigedwards@brainbox.cc>
7  *   Copyright (C) 2008 Robin Burchell <robin+git@viroteck.net>
8  *
9  * This file is part of InspIRCd.  InspIRCd is free software: you can
10  * redistribute it and/or modify it under the terms of the GNU General Public
11  * License as published by the Free Software Foundation, version 2.
12  *
13  * This program is distributed in the hope that it will be useful, but WITHOUT
14  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
15  * FOR A PARTICULAR PURPOSE.  See the GNU General Public License for more
16  * details.
17  *
18  * You should have received a copy of the GNU General Public License
19  * along with this program.  If not, see <http://www.gnu.org/licenses/>.
20  */
21
22 /// $CompilerFlags: execute("mysql_config --include" "MYSQL_CXXFLAGS")
23 /// $LinkerFlags: execute("mysql_config --libs_r" "MYSQL_LDFLAGS" "-lmysqlclient")
24
25 /// $PackageInfo: require_system("centos" "6.0" "6.99") mysql-devel
26 /// $PackageInfo: require_system("centos" "7.0") mariadb-devel
27 /// $PackageInfo: require_system("darwin") mysql-connector-c
28 /// $PackageInfo: require_system("debian") libmysqlclient-dev
29 /// $PackageInfo: require_system("ubuntu") libmysqlclient-dev
30
31
32 // Fix warnings about the use of `long long` on C++03.
33 #if defined __clang__
34 # pragma clang diagnostic ignored "-Wc++11-long-long"
35 #elif defined __GNUC__
36 # pragma GCC diagnostic ignored "-Wlong-long"
37 #endif
38
39 #include "inspircd.h"
40 #include <mysql.h>
41 #include "modules/sql.h"
42
43 #ifdef _WIN32
44 # pragma comment(lib, "libmysql.lib")
45 #endif
46
47 /* VERSION 3 API: With nonblocking (threaded) requests */
48
49 /* THE NONBLOCKING MYSQL API!
50  *
 * MySQL provides no nonblocking (asynchronous) API of its own, and its developers recommend
 * that you thread your program instead. This is what I've done here to allow for
 * asynchronous SQL requests via mysql. The way this works is as follows:
 *
 * The module spawns a thread via class SocketThread and performs its mysql queries in this
 * thread, using a first-in, first-out queue. A mutex guards the queue so that two threads
 * cannot adjust it at the same time and crash the ircd. While the queue is empty, the worker
 * thread sleeps until the ircd thread wakes it after queueing a request. It then processes the
 * request at the head of its queue, blocking the worker thread but leaving the ircd
 * thread to go about its business as usual. During this period, the ircd thread is able
 * to insert further pending requests into the queue.
62  *
 * Once the processing of a request is complete, it is moved from the incoming queue to
 * an outgoing queue together with its result. The worker thread then signals the
 * ircd thread (via NotifyParent()) that a result is available.
67  *
 * The ircd thread then locks the queue once more, takes the results off the outgoing
 * queue, and delivers each one to its query's OnResult or OnError callback.
70  *
 * XXX: You might be asking "why doesn't he just send the response from within the worker thread?"
 * The answer to this is simple. The majority of InspIRCd, and in fact most ircds, are not
 * threadsafe. This module is designed to be threadsafe and is careful with its use of threads.
 * However, if we were to call a module's result callback from within a thread other than the
 * one the module was originally instantiated upon, there is a chance of all hell breaking loose
 * if a module is ever put in a re-entrant state (stack corruption, crashes, data corruption,
 * and worse). So DON'T think about it until the day comes when InspIRCd is 100%
 * guaranteed threadsafe!
79  */
80
// Forward declarations for types that reference each other below.
class DispatcherThread;
class MySQLresult;
class SQLConnection;
84
85 struct QQueueItem
86 {
87         SQLQuery* q;
88         std::string query;
89         SQLConnection* c;
90         QQueueItem(SQLQuery* Q, const std::string& S, SQLConnection* C) : q(Q), query(S), c(C) {}
91 };
92
93 struct RQueueItem
94 {
95         SQLQuery* q;
96         MySQLresult* r;
97         RQueueItem(SQLQuery* Q, MySQLresult* R) : q(Q), r(R) {}
98 };
99
100 typedef insp::flat_map<std::string, SQLConnection*> ConnMap;
101 typedef std::deque<QQueueItem> QueryQueue;
102 typedef std::deque<RQueueItem> ResultQueue;
103
104 /** MySQL module
105  *  */
106 class ModuleSQL : public Module
107 {
108  public:
109         DispatcherThread* Dispatcher;
110         QueryQueue qq;       // MUST HOLD MUTEX
111         ResultQueue rq;      // MUST HOLD MUTEX
112         ConnMap connections; // main thread only
113
114         ModuleSQL();
115         void init() CXX11_OVERRIDE;
116         ~ModuleSQL();
117         void ReadConfig(ConfigStatus& status) CXX11_OVERRIDE;
118         void OnUnloadModule(Module* mod) CXX11_OVERRIDE;
119         Version GetVersion() CXX11_OVERRIDE;
120 };
121
122 class DispatcherThread : public SocketThread
123 {
124  private:
125         ModuleSQL* const Parent;
126  public:
127         DispatcherThread(ModuleSQL* CreatorModule) : Parent(CreatorModule) { }
128         ~DispatcherThread() { }
129         void Run() CXX11_OVERRIDE;
130         void OnNotify() CXX11_OVERRIDE;
131 };
132
// Client libraries that do not define MYSQL_VERSION_ID, or that report a version
// below 32224, presumably lack mysql_field_count(); alias it to mysql_num_fields().
#if !defined(MYSQL_VERSION_ID) || (MYSQL_VERSION_ID < 32224)
# define mysql_field_count mysql_num_fields
#endif
136
137 /** Represents a mysql result set
138  */
139 class MySQLresult : public SQLResult
140 {
141  public:
142         SQLerror err;
143         int currentrow;
144         int rows;
145         std::vector<std::string> colnames;
146         std::vector<SQLEntries> fieldlists;
147
148         MySQLresult(MYSQL_RES* res, int affected_rows) : err(SQL_NO_ERROR), currentrow(0), rows(0)
149         {
150                 if (affected_rows >= 1)
151                 {
152                         rows = affected_rows;
153                         fieldlists.resize(rows);
154                 }
155                 unsigned int field_count = 0;
156                 if (res)
157                 {
158                         MYSQL_ROW row;
159                         int n = 0;
160                         while ((row = mysql_fetch_row(res)))
161                         {
162                                 if (fieldlists.size() < (unsigned int)rows+1)
163                                 {
164                                         fieldlists.resize(fieldlists.size()+1);
165                                 }
166                                 field_count = 0;
167                                 MYSQL_FIELD *fields = mysql_fetch_fields(res);
168                                 if(mysql_num_fields(res) == 0)
169                                         break;
170                                 if (fields && mysql_num_fields(res))
171                                 {
172                                         colnames.clear();
173                                         while (field_count < mysql_num_fields(res))
174                                         {
175                                                 std::string a = (fields[field_count].name ? fields[field_count].name : "");
176                                                 if (row[field_count])
177                                                         fieldlists[n].push_back(SQLEntry(row[field_count]));
178                                                 else
179                                                         fieldlists[n].push_back(SQLEntry());
180                                                 colnames.push_back(a);
181                                                 field_count++;
182                                         }
183                                         n++;
184                                 }
185                                 rows++;
186                         }
187                         mysql_free_result(res);
188                 }
189         }
190
191         MySQLresult(SQLerror& e) : err(e)
192         {
193
194         }
195
196         int Rows() CXX11_OVERRIDE
197         {
198                 return rows;
199         }
200
201         void GetCols(std::vector<std::string>& result) CXX11_OVERRIDE
202         {
203                 result.assign(colnames.begin(), colnames.end());
204         }
205
206         SQLEntry GetValue(int row, int column)
207         {
208                 if ((row >= 0) && (row < rows) && (column >= 0) && (column < (int)fieldlists[row].size()))
209                 {
210                         return fieldlists[row][column];
211                 }
212                 return SQLEntry();
213         }
214
215         bool GetRow(SQLEntries& result) CXX11_OVERRIDE
216         {
217                 if (currentrow < rows)
218                 {
219                         result.assign(fieldlists[currentrow].begin(), fieldlists[currentrow].end());
220                         currentrow++;
221                         return true;
222                 }
223                 else
224                 {
225                         result.clear();
226                         return false;
227                 }
228         }
229 };
230
231 /** Represents a connection to a mysql database
232  */
233 class SQLConnection : public SQLProvider
234 {
235  public:
236         reference<ConfigTag> config;
237         MYSQL *connection;
238         Mutex lock;
239
240         // This constructor creates an SQLConnection object with the given credentials, but does not connect yet.
241         SQLConnection(Module* p, ConfigTag* tag) : SQLProvider(p, "SQL/" + tag->getString("id")),
242                 config(tag), connection(NULL)
243         {
244         }
245
246         ~SQLConnection()
247         {
248                 Close();
249         }
250
251         // This method connects to the database using the credentials supplied to the constructor, and returns
252         // true upon success.
253         bool Connect()
254         {
255                 unsigned int timeout = 1;
256                 connection = mysql_init(connection);
257                 mysql_options(connection,MYSQL_OPT_CONNECT_TIMEOUT,(char*)&timeout);
258                 std::string host = config->getString("host");
259                 std::string user = config->getString("user");
260                 std::string pass = config->getString("pass");
261                 std::string dbname = config->getString("name");
262                 int port = config->getInt("port");
263                 bool rv = mysql_real_connect(connection, host.c_str(), user.c_str(), pass.c_str(), dbname.c_str(), port, NULL, 0);
264                 if (!rv)
265                         return rv;
266
267                 // Enable character set settings
268                 std::string charset = config->getString("charset");
269                 if ((!charset.empty()) && (mysql_set_character_set(connection, charset.c_str())))
270                         ServerInstance->Logs->Log(MODNAME, LOG_DEFAULT, "WARNING: Could not set character set to \"%s\"", charset.c_str());
271
272                 std::string initquery;
273                 if (config->readString("initialquery", initquery))
274                 {
275                         mysql_query(connection,initquery.c_str());
276                 }
277                 return true;
278         }
279
280         ModuleSQL* Parent()
281         {
282                 return (ModuleSQL*)(Module*)creator;
283         }
284
285         MySQLresult* DoBlockingQuery(const std::string& query)
286         {
287
288                 /* Parse the command string and dispatch it to mysql */
289                 if (CheckConnection() && !mysql_real_query(connection, query.data(), query.length()))
290                 {
291                         /* Successfull query */
292                         MYSQL_RES* res = mysql_use_result(connection);
293                         unsigned long rows = mysql_affected_rows(connection);
294                         return new MySQLresult(res, rows);
295                 }
296                 else
297                 {
298                         /* XXX: See /usr/include/mysql/mysqld_error.h for a list of
299                          * possible error numbers and error messages */
300                         SQLerror e(SQL_QREPLY_FAIL, ConvToStr(mysql_errno(connection)) + ": " + mysql_error(connection));
301                         return new MySQLresult(e);
302                 }
303         }
304
305         bool CheckConnection()
306         {
307                 if (!connection || mysql_ping(connection) != 0)
308                         return Connect();
309                 return true;
310         }
311
312         std::string GetError()
313         {
314                 return mysql_error(connection);
315         }
316
317         void Close()
318         {
319                 mysql_close(connection);
320         }
321
322         void submit(SQLQuery* q, const std::string& qs) CXX11_OVERRIDE
323         {
324                 Parent()->Dispatcher->LockQueue();
325                 Parent()->qq.push_back(QQueueItem(q, qs, this));
326                 Parent()->Dispatcher->UnlockQueueWakeup();
327         }
328
329         void submit(SQLQuery* call, const std::string& q, const ParamL& p) CXX11_OVERRIDE
330         {
331                 std::string res;
332                 unsigned int param = 0;
333                 for(std::string::size_type i = 0; i < q.length(); i++)
334                 {
335                         if (q[i] != '?')
336                                 res.push_back(q[i]);
337                         else
338                         {
339                                 if (param < p.size())
340                                 {
341                                         std::string parm = p[param++];
342                                         // In the worst case, each character may need to be encoded as using two bytes,
343                                         // and one byte is the terminating null
344                                         std::vector<char> buffer(parm.length() * 2 + 1);
345
346                                         // The return value of mysql_escape_string() is the length of the encoded string,
347                                         // not including the terminating null
348                                         unsigned long escapedsize = mysql_escape_string(&buffer[0], parm.c_str(), parm.length());
349 //                                      mysql_real_escape_string(connection, queryend, paramscopy[paramnum].c_str(), paramscopy[paramnum].length());
350                                         res.append(&buffer[0], escapedsize);
351                                 }
352                         }
353                 }
354                 submit(call, res);
355         }
356
357         void submit(SQLQuery* call, const std::string& q, const ParamM& p) CXX11_OVERRIDE
358         {
359                 std::string res;
360                 for(std::string::size_type i = 0; i < q.length(); i++)
361                 {
362                         if (q[i] != '$')
363                                 res.push_back(q[i]);
364                         else
365                         {
366                                 std::string field;
367                                 i++;
368                                 while (i < q.length() && isalnum(q[i]))
369                                         field.push_back(q[i++]);
370                                 i--;
371
372                                 ParamM::const_iterator it = p.find(field);
373                                 if (it != p.end())
374                                 {
375                                         std::string parm = it->second;
376                                         // NOTE: See above
377                                         std::vector<char> buffer(parm.length() * 2 + 1);
378                                         unsigned long escapedsize = mysql_escape_string(&buffer[0], parm.c_str(), parm.length());
379                                         res.append(&buffer[0], escapedsize);
380                                 }
381                         }
382                 }
383                 submit(call, res);
384         }
385 };
386
387 ModuleSQL::ModuleSQL()
388 {
389         Dispatcher = NULL;
390 }
391
392 void ModuleSQL::init()
393 {
394         Dispatcher = new DispatcherThread(this);
395         ServerInstance->Threads.Start(Dispatcher);
396 }
397
398 ModuleSQL::~ModuleSQL()
399 {
400         if (Dispatcher)
401         {
402                 Dispatcher->join();
403                 Dispatcher->OnNotify();
404                 delete Dispatcher;
405         }
406         for(ConnMap::iterator i = connections.begin(); i != connections.end(); i++)
407         {
408                 delete i->second;
409         }
410 }
411
/** Rebuilds the connection map from the <database> tags whose module is "mysql".
 *
 * New ids get a fresh SQLConnection registered as a service. Ids that already
 * existed keep their existing SQLConnection object (and therefore the tag it was
 * created from). Connections whose id is no longer configured are unregistered.
 * Their queued queries fail with SQL_BAD_DBID, and the connections are deleted.
 */
void ModuleSQL::ReadConfig(ConfigStatus& status)
{
	ConnMap conns;
	ConfigTagList tags = ServerInstance->Config->ConfTags("database");
	for(ConfigIter i = tags.first; i != tags.second; i++)
	{
		if (i->second->getString("module", "mysql") != "mysql")
			continue;
		std::string id = i->second->getString("id");
		ConnMap::iterator curr = connections.find(id);
		if (curr == connections.end())
		{
			SQLConnection* conn = new SQLConnection(this, i->second);
			conns.insert(std::make_pair(id, conn));
			ServerInstance->Modules->AddService(*conn);
		}
		else
		{
			// Keep the existing connection; whatever remains in `connections`
			// after this loop is no longer configured.
			conns.insert(*curr);
			connections.erase(curr);
		}
	}

	// now clean up the deleted databases
	// The queue lock is held for the whole cleanup so the dispatcher thread cannot
	// pick up a query for a connection that is being removed. Note that OnError()
	// callbacks below therefore run with the queue lock held.
	Dispatcher->LockQueue();
	SQLerror err(SQL_BAD_DBID);
	for(ConnMap::iterator i = connections.begin(); i != connections.end(); i++)
	{
		ServerInstance->Modules->DelService(*i->second);
		// it might be running a query on this database. Wait for that to complete
		i->second->lock.Lock();
		i->second->lock.Unlock();
		// now remove all active queries to this DB
		// (iterate backwards so erasing does not shift indices still to be visited)
		for (size_t j = qq.size(); j > 0; j--)
		{
			size_t k = j - 1;
			if (qq[k].c == i->second)
			{
				qq[k].q->OnError(err);
				delete qq[k].q;
				qq.erase(qq.begin() + k);
			}
		}
		// finally, nuke the connection
		delete i->second;
	}
	Dispatcher->UnlockQueue();
	connections.swap(conns);
}
461
462 void ModuleSQL::OnUnloadModule(Module* mod)
463 {
464         SQLerror err(SQL_BAD_DBID);
465         Dispatcher->LockQueue();
466         unsigned int i = qq.size();
467         while (i > 0)
468         {
469                 i--;
470                 if (qq[i].q->creator == mod)
471                 {
472                         if (i == 0)
473                         {
474                                 // need to wait until the query is done
475                                 // (the result will be discarded)
476                                 qq[i].c->lock.Lock();
477                                 qq[i].c->lock.Unlock();
478                         }
479                         qq[i].q->OnError(err);
480                         delete qq[i].q;
481                         qq.erase(qq.begin() + i);
482                 }
483         }
484         Dispatcher->UnlockQueue();
485         // clean up any result queue entries
486         Dispatcher->OnNotify();
487 }
488
489 Version ModuleSQL::GetVersion()
490 {
491         return Version("MySQL support", VF_VENDOR);
492 }
493
494 void DispatcherThread::Run()
495 {
496         this->LockQueue();
497         while (!this->GetExitFlag())
498         {
499                 if (!Parent->qq.empty())
500                 {
501                         QQueueItem i = Parent->qq.front();
502                         i.c->lock.Lock();
503                         this->UnlockQueue();
504                         MySQLresult* res = i.c->DoBlockingQuery(i.query);
505                         i.c->lock.Unlock();
506
507                         /*
508                          * At this point, the main thread could be working on:
509                          *  Rehash - delete i.c out from under us. We don't care about that.
510                          *  UnloadModule - delete i.q and the qq item. Need to avoid reporting results.
511                          */
512
513                         this->LockQueue();
514                         if (!Parent->qq.empty() && Parent->qq.front().q == i.q)
515                         {
516                                 Parent->qq.pop_front();
517                                 Parent->rq.push_back(RQueueItem(i.q, res));
518                                 NotifyParent();
519                         }
520                         else
521                         {
522                                 // UnloadModule ate the query
523                                 delete res;
524                         }
525                 }
526                 else
527                 {
528                         /* We know the queue is empty, we can safely hang this thread until
529                          * something happens
530                          */
531                         this->WaitForQueue();
532                 }
533         }
534         this->UnlockQueue();
535 }
536
537 void DispatcherThread::OnNotify()
538 {
539         // this could unlock during the dispatch, but OnResult isn't expected to take that long
540         this->LockQueue();
541         for(ResultQueue::iterator i = Parent->rq.begin(); i != Parent->rq.end(); i++)
542         {
543                 MySQLresult* res = i->r;
544                 if (res->err.id == SQL_NO_ERROR)
545                         i->q->OnResult(*res);
546                 else
547                         i->q->OnError(res->err);
548                 delete i->q;
549                 delete i->r;
550         }
551         Parent->rq.clear();
552         this->UnlockQueue();
553 }
554
// Module entry point. MODULE_INIT comes from the core headers; presumably it emits
// the factory the core calls to instantiate ModuleSQL when this module is loaded.
MODULE_INIT(ModuleSQL)