]> git.netwichtig.de Git - user/henk/code/inspircd.git/blob - src/modules/extra/m_mysql.cpp
fe9bb4ceccb7f90c202bfd131ad41de8f98e14b9
[user/henk/code/inspircd.git] / src / modules / extra / m_mysql.cpp
1 /*
2  * InspIRCd -- Internet Relay Chat Daemon
3  *
4  *   Copyright (C) 2009-2010 Daniel De Graaf <danieldg@inspircd.org>
5  *   Copyright (C) 2006-2007, 2009 Dennis Friis <peavey@inspircd.org>
6  *   Copyright (C) 2006-2009 Craig Edwards <craigedwards@brainbox.cc>
7  *   Copyright (C) 2008 Robin Burchell <robin+git@viroteck.net>
8  *
9  * This file is part of InspIRCd.  InspIRCd is free software: you can
10  * redistribute it and/or modify it under the terms of the GNU General Public
11  * License as published by the Free Software Foundation, version 2.
12  *
13  * This program is distributed in the hope that it will be useful, but WITHOUT
14  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
15  * FOR A PARTICULAR PURPOSE.  See the GNU General Public License for more
16  * details.
17  *
18  * You should have received a copy of the GNU General Public License
19  * along with this program.  If not, see <http://www.gnu.org/licenses/>.
20  */
21
22 /// $CompilerFlags: execute("mysql_config --include" "MYSQL_CXXFLAGS")
23 /// $LinkerFlags: execute("mysql_config --libs_r" "MYSQL_LDFLAGS" "-lmysqlclient")
24
25 /// $PackageInfo: require_system("arch") mariadb-libs
26 /// $PackageInfo: require_system("centos" "6.0" "6.99") mysql-devel
27 /// $PackageInfo: require_system("centos" "7.0") mariadb-devel
28 /// $PackageInfo: require_system("darwin") mysql-connector-c
29 /// $PackageInfo: require_system("debian") libmysqlclient-dev
30 /// $PackageInfo: require_system("ubuntu") libmysqlclient-dev
31
32 #ifdef __GNUC__
33 # pragma GCC diagnostic push
34 #endif
35
36 // Fix warnings about the use of `long long` on C++03.
37 #if defined __clang__
38 # pragma clang diagnostic ignored "-Wc++11-long-long"
39 #elif defined __GNUC__
40 # pragma GCC diagnostic ignored "-Wlong-long"
41 #endif
42
43 #include "inspircd.h"
44 #include <mysql.h>
45 #include "modules/sql.h"
46
47 #ifdef __GNUC__
48 # pragma GCC diagnostic pop
49 #endif
50
51 #ifdef _WIN32
52 # pragma comment(lib, "libmysql.lib")
53 #endif
54
/* VERSION 3 API: With nonblocking (threaded) requests */

/* THE NONBLOCKING MYSQL API!
 *
 * MySQL provides no nonblocking (asynchronous) API of its own, and its developers recommend
 * that instead, you should thread your program. This is what has been done here to allow for
 * asynchronous SQL requests via MySQL. The way this works is as follows:
 *
 * The module spawns a worker thread (DispatcherThread) and performs its MySQL queries in that
 * thread, using a first-in first-out queue. The queue is protected by a mutex which prevents
 * both threads from adjusting it at the same time and crashing the ircd. When the queue is
 * empty the worker thread sleeps until the ircd thread queues a new request and wakes it up;
 * otherwise it processes the request at the head of the queue, blocking the worker thread but
 * leaving the ircd thread to go about its business as usual. During this period, the ircd
 * thread is able to insert further pending requests into the queue.
 *
 * Once the processing of a request is complete, it is moved from the incoming (query) queue
 * to an outgoing (result) queue. The worker thread then signals the ircd thread via
 * NotifyParent() that a result is available.
 *
 * The ircd thread then locks the queue once more, reads the pending results off the result
 * queue, and passes each one on to the module which made the request.
 *
 * XXX: You might be asking "why doesn't it just send the response from within the worker thread?"
 * The answer to this is simple. The majority of InspIRCd, and in fact most ircds, are not
 * threadsafe. This module is designed to be threadsafe and is careful with its use of threads;
 * however, if we were to call a module's OnResult or OnError from within a thread which was not
 * the one the module was originally instantiated upon, there is a chance of all hell breaking
 * loose if a module is ever put in a re-entrant state (stack corruption, crashes, data
 * corruption, and worse could occur), so DON'T think about it until the day comes when InspIRCd
 * is 100% guaranteed threadsafe!
 */
88
// These classes refer to each other before they are fully defined.
class DispatcherThread;
class MySQLresult;
class SQLConnection;
92
93 struct QQueueItem
94 {
95         SQL::Query* q;
96         std::string query;
97         SQLConnection* c;
98         QQueueItem(SQL::Query* Q, const std::string& S, SQLConnection* C) : q(Q), query(S), c(C) {}
99 };
100
101 struct RQueueItem
102 {
103         SQL::Query* q;
104         MySQLresult* r;
105         RQueueItem(SQL::Query* Q, MySQLresult* R) : q(Q), r(R) {}
106 };
107
108 typedef insp::flat_map<std::string, SQLConnection*> ConnMap;
109 typedef std::deque<QQueueItem> QueryQueue;
110 typedef std::deque<RQueueItem> ResultQueue;
111
112 /** MySQL module
113  *  */
114 class ModuleSQL : public Module
115 {
116  public:
117         DispatcherThread* Dispatcher;
118         QueryQueue qq;       // MUST HOLD MUTEX
119         ResultQueue rq;      // MUST HOLD MUTEX
120         ConnMap connections; // main thread only
121
122         ModuleSQL();
123         void init() CXX11_OVERRIDE;
124         ~ModuleSQL();
125         void ReadConfig(ConfigStatus& status) CXX11_OVERRIDE;
126         void OnUnloadModule(Module* mod) CXX11_OVERRIDE;
127         Version GetVersion() CXX11_OVERRIDE;
128 };
129
130 class DispatcherThread : public SocketThread
131 {
132  private:
133         ModuleSQL* const Parent;
134  public:
135         DispatcherThread(ModuleSQL* CreatorModule) : Parent(CreatorModule) { }
136         ~DispatcherThread() { }
137         void Run() CXX11_OVERRIDE;
138         void OnNotify() CXX11_OVERRIDE;
139 };
140
// Client libraries older than 3.22.24 (or which do not define MYSQL_VERSION_ID) are treated
// as lacking mysql_field_count(), so mysql_num_fields() is used in its place.
#if !defined MYSQL_VERSION_ID || (MYSQL_VERSION_ID < 32224)
# define mysql_field_count mysql_num_fields
#endif
144
145 /** Represents a mysql result set
146  */
147 class MySQLresult : public SQL::Result
148 {
149  public:
150         SQL::Error err;
151         int currentrow;
152         int rows;
153         std::vector<std::string> colnames;
154         std::vector<SQL::Row> fieldlists;
155
156         MySQLresult(MYSQL_RES* res, int affected_rows) : err(SQL::SUCCESS), currentrow(0), rows(0)
157         {
158                 if (affected_rows >= 1)
159                 {
160                         rows = affected_rows;
161                         fieldlists.resize(rows);
162                 }
163                 unsigned int field_count = 0;
164                 if (res)
165                 {
166                         MYSQL_ROW row;
167                         int n = 0;
168                         while ((row = mysql_fetch_row(res)))
169                         {
170                                 if (fieldlists.size() < (unsigned int)rows+1)
171                                 {
172                                         fieldlists.resize(fieldlists.size()+1);
173                                 }
174                                 field_count = 0;
175                                 MYSQL_FIELD *fields = mysql_fetch_fields(res);
176                                 if(mysql_num_fields(res) == 0)
177                                         break;
178                                 if (fields && mysql_num_fields(res))
179                                 {
180                                         colnames.clear();
181                                         while (field_count < mysql_num_fields(res))
182                                         {
183                                                 std::string a = (fields[field_count].name ? fields[field_count].name : "");
184                                                 if (row[field_count])
185                                                         fieldlists[n].push_back(SQL::Field(row[field_count]));
186                                                 else
187                                                         fieldlists[n].push_back(SQL::Field());
188                                                 colnames.push_back(a);
189                                                 field_count++;
190                                         }
191                                         n++;
192                                 }
193                                 rows++;
194                         }
195                         mysql_free_result(res);
196                 }
197         }
198
199         MySQLresult(SQL::Error& e) : err(e)
200         {
201
202         }
203
204         int Rows() CXX11_OVERRIDE
205         {
206                 return rows;
207         }
208
209         void GetCols(std::vector<std::string>& result) CXX11_OVERRIDE
210         {
211                 result.assign(colnames.begin(), colnames.end());
212         }
213
214         bool HasColumn(const std::string& column, size_t& index) CXX11_OVERRIDE
215         {
216                 for (size_t i = 0; i < colnames.size(); ++i)
217                 {
218                         if (colnames[i] == column)
219                         {
220                                 index = i;
221                                 return true;
222                         }
223                 }
224                 return false;
225         }
226
227         SQL::Field GetValue(int row, int column)
228         {
229                 if ((row >= 0) && (row < rows) && (column >= 0) && (column < (int)fieldlists[row].size()))
230                 {
231                         return fieldlists[row][column];
232                 }
233                 return SQL::Field();
234         }
235
236         bool GetRow(SQL::Row& result) CXX11_OVERRIDE
237         {
238                 if (currentrow < rows)
239                 {
240                         result.assign(fieldlists[currentrow].begin(), fieldlists[currentrow].end());
241                         currentrow++;
242                         return true;
243                 }
244                 else
245                 {
246                         result.clear();
247                         return false;
248                 }
249         }
250 };
251
252 /** Represents a connection to a mysql database
253  */
254 class SQLConnection : public SQL::Provider
255 {
256  private:
257         bool EscapeString(SQL::Query* query, const std::string& in, std::string& out)
258         {
259                 // In the worst case each character may need to be encoded as using two bytes and one
260                 // byte is the NUL terminator.
261                 std::vector<char> buffer(in.length() * 2 + 1);
262
263                 // The return value of mysql_escape_string() is either an error or the length of the
264                 // encoded string not including the NUL terminator.
265                 //
266                 // Unfortunately, someone genius decided that mysql_escape_string should return an
267                 // unsigned type even though -1 is returned on error so checking whether an error
268                 // happened is a bit cursed.
269                 unsigned long escapedsize = mysql_escape_string(&buffer[0], in.c_str(), in.length());
270                 if (escapedsize == static_cast<unsigned long>(-1))
271                 {
272                         SQL::Error err(SQL::QSEND_FAIL, InspIRCd::Format("%u: %s", mysql_errno(connection), mysql_error(connection)));
273                         query->OnError(err);
274                         return false;
275                 }
276
277                 out.append(&buffer[0], escapedsize);
278                 return true;
279         }
280
281  public:
282         reference<ConfigTag> config;
283         MYSQL *connection;
284         Mutex lock;
285
286         // This constructor creates an SQLConnection object with the given credentials, but does not connect yet.
287         SQLConnection(Module* p, ConfigTag* tag)
288                 : SQL::Provider(p, tag->getString("id"))
289                 , config(tag)
290                 , connection(NULL)
291         {
292         }
293
294         ~SQLConnection()
295         {
296                 Close();
297         }
298
299         // This method connects to the database using the credentials supplied to the constructor, and returns
300         // true upon success.
301         bool Connect()
302         {
303                 connection = mysql_init(connection);
304
305                 // Set the connection timeout.
306                 unsigned int timeout = config->getDuration("timeout", 5, 1, 30);
307                 mysql_options(connection, MYSQL_OPT_CONNECT_TIMEOUT, &timeout);
308
309                 // Attempt to connect to the database.
310                 const std::string host = config->getString("host");
311                 const std::string user = config->getString("user");
312                 const std::string pass = config->getString("pass");
313                 const std::string dbname = config->getString("name");
314                 unsigned int port = config->getUInt("port", 3306, 1, 65535);
315                 if (!mysql_real_connect(connection, host.c_str(), user.c_str(), pass.c_str(), dbname.c_str(), port, NULL, CLIENT_IGNORE_SIGPIPE))
316                 {
317                         ServerInstance->Logs->Log(MODNAME, LOG_DEFAULT, "Unable to connect to the %s MySQL server: %s",
318                                 GetId().c_str(), mysql_error(connection));
319                         return false;
320                 }
321
322                 // Set the default character set.
323                 const std::string charset = config->getString("charset");
324                 if (!charset.empty() && mysql_set_character_set(connection, charset.c_str()))
325                 {
326                         ServerInstance->Logs->Log(MODNAME, LOG_DEFAULT, "Could not set character set for %s to \"%s\": %s",
327                                 GetId().c_str(), charset.c_str(), mysql_error(connection));
328                         return false;
329                 }
330
331                 // Execute the initial SQL query.
332                 const std::string initialquery = config->getString("initialquery");
333                 if (!initialquery.empty() && mysql_real_query(connection, initialquery.data(), initialquery.length()))
334                 {
335                         ServerInstance->Logs->Log(MODNAME, LOG_DEFAULT, "Could not execute initial query \"%s\" for %s: %s",
336                                 initialquery.c_str(), name.c_str(), mysql_error(connection));
337                         return false;
338                 }
339
340                 return true;
341         }
342
343         ModuleSQL* Parent()
344         {
345                 return (ModuleSQL*)(Module*)creator;
346         }
347
348         MySQLresult* DoBlockingQuery(const std::string& query)
349         {
350
351                 /* Parse the command string and dispatch it to mysql */
352                 if (CheckConnection() && !mysql_real_query(connection, query.data(), query.length()))
353                 {
354                         /* Successfull query */
355                         MYSQL_RES* res = mysql_use_result(connection);
356                         unsigned long rows = mysql_affected_rows(connection);
357                         return new MySQLresult(res, rows);
358                 }
359                 else
360                 {
361                         /* XXX: See /usr/include/mysql/mysqld_error.h for a list of
362                          * possible error numbers and error messages */
363                         SQL::Error e(SQL::QREPLY_FAIL, InspIRCd::Format("%u: %s", mysql_errno(connection), mysql_error(connection)));
364                         return new MySQLresult(e);
365                 }
366         }
367
368         bool CheckConnection()
369         {
370                 if (!connection || mysql_ping(connection) != 0)
371                         return Connect();
372                 return true;
373         }
374
375         std::string GetError()
376         {
377                 return mysql_error(connection);
378         }
379
380         void Close()
381         {
382                 mysql_close(connection);
383         }
384
385         void Submit(SQL::Query* q, const std::string& qs) CXX11_OVERRIDE
386         {
387                 ServerInstance->Logs->Log(MODNAME, LOG_DEBUG, "Executing MySQL query: " + qs);
388                 Parent()->Dispatcher->LockQueue();
389                 Parent()->qq.push_back(QQueueItem(q, qs, this));
390                 Parent()->Dispatcher->UnlockQueueWakeup();
391         }
392
393         void Submit(SQL::Query* call, const std::string& q, const SQL::ParamList& p) CXX11_OVERRIDE
394         {
395                 std::string res;
396                 unsigned int param = 0;
397                 for(std::string::size_type i = 0; i < q.length(); i++)
398                 {
399                         if (q[i] != '?')
400                                 res.push_back(q[i]);
401                         else if (param < p.size() && !EscapeString(call, p[param++], res))
402                                 return;
403                 }
404                 Submit(call, res);
405         }
406
407         void Submit(SQL::Query* call, const std::string& q, const SQL::ParamMap& p) CXX11_OVERRIDE
408         {
409                 std::string res;
410                 for(std::string::size_type i = 0; i < q.length(); i++)
411                 {
412                         if (q[i] != '$')
413                                 res.push_back(q[i]);
414                         else
415                         {
416                                 std::string field;
417                                 i++;
418                                 while (i < q.length() && isalnum(q[i]))
419                                         field.push_back(q[i++]);
420                                 i--;
421
422                                 SQL::ParamMap::const_iterator it = p.find(field);
423                                 if (it != p.end() && !EscapeString(call, it->second, res))
424                                         return;
425                         }
426                 }
427                 Submit(call, res);
428         }
429 };
430
431 ModuleSQL::ModuleSQL()
432 {
433         Dispatcher = NULL;
434 }
435
436 void ModuleSQL::init()
437 {
438         if (mysql_library_init(0, NULL, NULL))
439                 throw ModuleException("Unable to initialise the MySQL library!");
440
441         Dispatcher = new DispatcherThread(this);
442         ServerInstance->Threads.Start(Dispatcher);
443 }
444
445 ModuleSQL::~ModuleSQL()
446 {
447         if (Dispatcher)
448         {
449                 Dispatcher->join();
450                 Dispatcher->OnNotify();
451                 delete Dispatcher;
452         }
453
454         for(ConnMap::iterator i = connections.begin(); i != connections.end(); i++)
455         {
456                 delete i->second;
457         }
458
459         mysql_library_end();
460 }
461
462 void ModuleSQL::ReadConfig(ConfigStatus& status)
463 {
464         ConnMap conns;
465         ConfigTagList tags = ServerInstance->Config->ConfTags("database");
466         for(ConfigIter i = tags.first; i != tags.second; i++)
467         {
468                 if (!stdalgo::string::equalsci(i->second->getString("module"), "mysql"))
469                         continue;
470                 std::string id = i->second->getString("id");
471                 ConnMap::iterator curr = connections.find(id);
472                 if (curr == connections.end())
473                 {
474                         SQLConnection* conn = new SQLConnection(this, i->second);
475                         conns.insert(std::make_pair(id, conn));
476                         ServerInstance->Modules->AddService(*conn);
477                 }
478                 else
479                 {
480                         conns.insert(*curr);
481                         connections.erase(curr);
482                 }
483         }
484
485         // now clean up the deleted databases
486         Dispatcher->LockQueue();
487         SQL::Error err(SQL::BAD_DBID);
488         for(ConnMap::iterator i = connections.begin(); i != connections.end(); i++)
489         {
490                 ServerInstance->Modules->DelService(*i->second);
491                 // it might be running a query on this database. Wait for that to complete
492                 i->second->lock.Lock();
493                 i->second->lock.Unlock();
494                 // now remove all active queries to this DB
495                 for (size_t j = qq.size(); j > 0; j--)
496                 {
497                         size_t k = j - 1;
498                         if (qq[k].c == i->second)
499                         {
500                                 qq[k].q->OnError(err);
501                                 delete qq[k].q;
502                                 qq.erase(qq.begin() + k);
503                         }
504                 }
505                 // finally, nuke the connection
506                 delete i->second;
507         }
508         Dispatcher->UnlockQueue();
509         connections.swap(conns);
510 }
511
512 void ModuleSQL::OnUnloadModule(Module* mod)
513 {
514         SQL::Error err(SQL::BAD_DBID);
515         Dispatcher->LockQueue();
516         unsigned int i = qq.size();
517         while (i > 0)
518         {
519                 i--;
520                 if (qq[i].q->creator == mod)
521                 {
522                         if (i == 0)
523                         {
524                                 // need to wait until the query is done
525                                 // (the result will be discarded)
526                                 qq[i].c->lock.Lock();
527                                 qq[i].c->lock.Unlock();
528                         }
529                         qq[i].q->OnError(err);
530                         delete qq[i].q;
531                         qq.erase(qq.begin() + i);
532                 }
533         }
534         Dispatcher->UnlockQueue();
535         // clean up any result queue entries
536         Dispatcher->OnNotify();
537 }
538
539 Version ModuleSQL::GetVersion()
540 {
541         return Version("Provides MySQL support", VF_VENDOR);
542 }
543
544 void DispatcherThread::Run()
545 {
546         this->LockQueue();
547         while (!this->GetExitFlag())
548         {
549                 if (!Parent->qq.empty())
550                 {
551                         QQueueItem i = Parent->qq.front();
552                         i.c->lock.Lock();
553                         this->UnlockQueue();
554                         MySQLresult* res = i.c->DoBlockingQuery(i.query);
555                         i.c->lock.Unlock();
556
557                         /*
558                          * At this point, the main thread could be working on:
559                          *  Rehash - delete i.c out from under us. We don't care about that.
560                          *  UnloadModule - delete i.q and the qq item. Need to avoid reporting results.
561                          */
562
563                         this->LockQueue();
564                         if (!Parent->qq.empty() && Parent->qq.front().q == i.q)
565                         {
566                                 Parent->qq.pop_front();
567                                 Parent->rq.push_back(RQueueItem(i.q, res));
568                                 NotifyParent();
569                         }
570                         else
571                         {
572                                 // UnloadModule ate the query
573                                 delete res;
574                         }
575                 }
576                 else
577                 {
578                         /* We know the queue is empty, we can safely hang this thread until
579                          * something happens
580                          */
581                         this->WaitForQueue();
582                 }
583         }
584         this->UnlockQueue();
585 }
586
587 void DispatcherThread::OnNotify()
588 {
589         // this could unlock during the dispatch, but OnResult isn't expected to take that long
590         this->LockQueue();
591         for(ResultQueue::iterator i = Parent->rq.begin(); i != Parent->rq.end(); i++)
592         {
593                 MySQLresult* res = i->r;
594                 if (res->err.code == SQL::SUCCESS)
595                         i->q->OnResult(*res);
596                 else
597                         i->q->OnError(res->err);
598                 delete i->q;
599                 delete i->r;
600         }
601         Parent->rq.clear();
602         this->UnlockQueue();
603 }
604
// Module boilerplate: MODULE_INIT is given the module class defined above.
MODULE_INIT(ModuleSQL)