]> git.netwichtig.de Git - user/henk/code/inspircd.git/blob - src/modules/extra/m_mysql.cpp
d94c75917b548fbe7176a95d2c96817f4f9b5626
[user/henk/code/inspircd.git] / src / modules / extra / m_mysql.cpp
1 /*
2  * InspIRCd -- Internet Relay Chat Daemon
3  *
4  *   Copyright (C) 2009-2010 Daniel De Graaf <danieldg@inspircd.org>
5  *   Copyright (C) 2006-2007, 2009 Dennis Friis <peavey@inspircd.org>
6  *   Copyright (C) 2006-2009 Craig Edwards <craigedwards@brainbox.cc>
7  *   Copyright (C) 2008 Robin Burchell <robin+git@viroteck.net>
8  *
9  * This file is part of InspIRCd.  InspIRCd is free software: you can
10  * redistribute it and/or modify it under the terms of the GNU General Public
11  * License as published by the Free Software Foundation, version 2.
12  *
13  * This program is distributed in the hope that it will be useful, but WITHOUT
14  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
15  * FOR A PARTICULAR PURPOSE.  See the GNU General Public License for more
16  * details.
17  *
18  * You should have received a copy of the GNU General Public License
19  * along with this program.  If not, see <http://www.gnu.org/licenses/>.
20  */
21
22 /// $CompilerFlags: execute("mysql_config --include" "MYSQL_CXXFLAGS")
23 /// $LinkerFlags: execute("mysql_config --libs_r" "MYSQL_LDFLAGS" "-lmysqlclient")
24
25 /// $PackageInfo: require_system("arch") mariadb-libs
26 /// $PackageInfo: require_system("centos" "6.0" "6.99") mysql-devel
27 /// $PackageInfo: require_system("centos" "7.0") mariadb-devel
28 /// $PackageInfo: require_system("darwin") mysql-connector-c
29 /// $PackageInfo: require_system("debian") libmysqlclient-dev
30 /// $PackageInfo: require_system("ubuntu") libmysqlclient-dev
31
32 #ifdef __GNUC__
33 # pragma GCC diagnostic push
34 #endif
35
36 // Fix warnings about the use of `long long` on C++03.
37 #if defined __clang__
38 # pragma clang diagnostic ignored "-Wc++11-long-long"
39 #elif defined __GNUC__
40 # pragma GCC diagnostic ignored "-Wlong-long"
41 #endif
42
43 #include "inspircd.h"
44 #include <mysql.h>
45 #include "modules/sql.h"
46
47 #ifdef __GNUC__
48 # pragma GCC diagnostic pop
49 #endif
50
51 #ifdef _WIN32
52 # pragma comment(lib, "libmysql.lib")
53 #endif
54
55 /* VERSION 3 API: With nonblocking (threaded) requests */
56
57 /* THE NONBLOCKING MYSQL API!
58  *
59  * MySQL provides no nonblocking (asynchronous) API of its own, and its developers recommend
60  * that instead, you should thread your program. This is what has been done here to allow for
61  * asynchronous SQL requests via mysql. The way this works is as follows:
62  *
63  * The module spawns a thread via class SocketThread, and performs its mysql queries in this thread,
64  * taking requests from a first-in-first-out queue. A mutex guards the queue so that the two
65  * threads never adjust it at the same time and crash the ircd. While the queue is empty the
66  * worker thread sleeps until it is woken up, then checks whether there is a request at its head.
67  * If there is, it processes this request, blocking the worker thread but leaving the ircd
68  * thread to go about its business as usual. During this period, the ircd thread is able
69  * to insert further pending requests into the queue.
70  *
71  * Once the processing of a request is complete, it is removed from the incoming queue and
72  * its result is appended to an outgoing queue as a 'response'. The worker thread then
73  * signals the ircd thread (via NotifyParent()) that a result is available, which causes
74  * OnNotify() to run in the ircd thread.
75  *
76  * The ircd thread then locks the queue mutex once more, takes every outbound response off
77  * the outgoing queue, and sends each one on its way to the original calling module.
78  *
79  * XXX: You might be asking "why doesn't it just send the response from within the worker thread?"
80  * The answer to this is simple. The majority of InspIRCd, and in fact most ircds, are not
81  * threadsafe. This module is designed to be threadsafe and is careful with its use of threads,
82  * however, if we were to call a module's OnResult or OnError from within a thread which was not
83  * the one the module was originally instantiated upon, there is a chance of all hell breaking loose
84  * if a module is ever put in a re-entrant state (stack corruption could occur, crashes, data
85  * corruption, and worse, so DON'T think about it until the day comes when InspIRCd is 100%
86  * guaranteed threadsafe!)
87  */
88
// Types which are referenced by the queue items and the module before they are defined.
class DispatcherThread;
class MySQLresult;
class SQLConnection;
92
93 struct QueryQueueItem
94 {
95         // An SQL database which this query is executed on.
96         SQLConnection* connection;
97
98         // An object which handles the result of the query.
99         SQL::Query* query;
100
101         // The SQL query which is to be executed.
102         std::string querystr;
103
104         QueryQueueItem(SQL::Query* q, const std::string& s, SQLConnection* c)
105                 : connection(c)
106                 , query(q)
107                 , querystr(s)
108         {
109         }
110 };
111
112 struct ResultQueueItem
113 {
114         // An object which handles the result of the query.
115         SQL::Query* query;
116
117         // The result returned from executing the MySQL query.
118         MySQLresult* result;
119
120         ResultQueueItem(SQL::Query* q, MySQLresult* r)
121                 : query(q)
122                 , result(r)
123         {
124         }
125 };
126
127 typedef insp::flat_map<std::string, SQLConnection*> ConnMap;
128 typedef std::deque<QueryQueueItem> QueryQueue;
129 typedef std::deque<ResultQueueItem> ResultQueue;
130
131 /** MySQL module
132  *  */
133 class ModuleSQL : public Module
134 {
135  public:
136         DispatcherThread* Dispatcher;
137         QueryQueue qq;       // MUST HOLD MUTEX
138         ResultQueue rq;      // MUST HOLD MUTEX
139         ConnMap connections; // main thread only
140
141         ModuleSQL();
142         void init() CXX11_OVERRIDE;
143         ~ModuleSQL();
144         void ReadConfig(ConfigStatus& status) CXX11_OVERRIDE;
145         void OnUnloadModule(Module* mod) CXX11_OVERRIDE;
146         Version GetVersion() CXX11_OVERRIDE;
147 };
148
149 class DispatcherThread : public SocketThread
150 {
151  private:
152         ModuleSQL* const Parent;
153  public:
154         DispatcherThread(ModuleSQL* CreatorModule) : Parent(CreatorModule) { }
155         ~DispatcherThread() { }
156         void Run() CXX11_OVERRIDE;
157         void OnNotify() CXX11_OVERRIDE;
158 };
159
160 /** Represents a mysql result set
161  */
162 class MySQLresult : public SQL::Result
163 {
164  public:
165         SQL::Error err;
166         int currentrow;
167         int rows;
168         std::vector<std::string> colnames;
169         std::vector<SQL::Row> fieldlists;
170
171         MySQLresult(MYSQL_RES* res, int affected_rows) : err(SQL::SUCCESS), currentrow(0), rows(0)
172         {
173                 if (affected_rows >= 1)
174                 {
175                         rows = affected_rows;
176                         fieldlists.resize(rows);
177                 }
178                 unsigned int field_count = 0;
179                 if (res)
180                 {
181                         MYSQL_ROW row;
182                         int n = 0;
183                         while ((row = mysql_fetch_row(res)))
184                         {
185                                 if (fieldlists.size() < (unsigned int)rows+1)
186                                 {
187                                         fieldlists.resize(fieldlists.size()+1);
188                                 }
189                                 field_count = 0;
190                                 MYSQL_FIELD *fields = mysql_fetch_fields(res);
191                                 if(mysql_num_fields(res) == 0)
192                                         break;
193                                 if (fields && mysql_num_fields(res))
194                                 {
195                                         colnames.clear();
196                                         while (field_count < mysql_num_fields(res))
197                                         {
198                                                 std::string a = (fields[field_count].name ? fields[field_count].name : "");
199                                                 if (row[field_count])
200                                                         fieldlists[n].push_back(SQL::Field(row[field_count]));
201                                                 else
202                                                         fieldlists[n].push_back(SQL::Field());
203                                                 colnames.push_back(a);
204                                                 field_count++;
205                                         }
206                                         n++;
207                                 }
208                                 rows++;
209                         }
210                         mysql_free_result(res);
211                 }
212         }
213
214         MySQLresult(SQL::Error& e) : err(e)
215         {
216
217         }
218
219         int Rows() CXX11_OVERRIDE
220         {
221                 return rows;
222         }
223
224         void GetCols(std::vector<std::string>& result) CXX11_OVERRIDE
225         {
226                 result.assign(colnames.begin(), colnames.end());
227         }
228
229         bool HasColumn(const std::string& column, size_t& index) CXX11_OVERRIDE
230         {
231                 for (size_t i = 0; i < colnames.size(); ++i)
232                 {
233                         if (colnames[i] == column)
234                         {
235                                 index = i;
236                                 return true;
237                         }
238                 }
239                 return false;
240         }
241
242         SQL::Field GetValue(int row, int column)
243         {
244                 if ((row >= 0) && (row < rows) && (column >= 0) && (column < (int)fieldlists[row].size()))
245                 {
246                         return fieldlists[row][column];
247                 }
248                 return SQL::Field();
249         }
250
251         bool GetRow(SQL::Row& result) CXX11_OVERRIDE
252         {
253                 if (currentrow < rows)
254                 {
255                         result.assign(fieldlists[currentrow].begin(), fieldlists[currentrow].end());
256                         currentrow++;
257                         return true;
258                 }
259                 else
260                 {
261                         result.clear();
262                         return false;
263                 }
264         }
265 };
266
267 /** Represents a connection to a mysql database
268  */
269 class SQLConnection : public SQL::Provider
270 {
271  private:
272         bool EscapeString(SQL::Query* query, const std::string& in, std::string& out)
273         {
274                 // In the worst case each character may need to be encoded as using two bytes and one
275                 // byte is the NUL terminator.
276                 std::vector<char> buffer(in.length() * 2 + 1);
277
278                 // The return value of mysql_escape_string() is either an error or the length of the
279                 // encoded string not including the NUL terminator.
280                 //
281                 // Unfortunately, someone genius decided that mysql_escape_string should return an
282                 // unsigned type even though -1 is returned on error so checking whether an error
283                 // happened is a bit cursed.
284                 unsigned long escapedsize = mysql_escape_string(&buffer[0], in.c_str(), in.length());
285                 if (escapedsize == static_cast<unsigned long>(-1))
286                 {
287                         SQL::Error err(SQL::QSEND_FAIL, InspIRCd::Format("%u: %s", mysql_errno(connection), mysql_error(connection)));
288                         query->OnError(err);
289                         return false;
290                 }
291
292                 out.append(&buffer[0], escapedsize);
293                 return true;
294         }
295
296  public:
297         reference<ConfigTag> config;
298         MYSQL *connection;
299         Mutex lock;
300
301         // This constructor creates an SQLConnection object with the given credentials, but does not connect yet.
302         SQLConnection(Module* p, ConfigTag* tag)
303                 : SQL::Provider(p, tag->getString("id"))
304                 , config(tag)
305                 , connection(NULL)
306         {
307         }
308
309         ~SQLConnection()
310         {
311                 Close();
312         }
313
314         // This method connects to the database using the credentials supplied to the constructor, and returns
315         // true upon success.
316         bool Connect()
317         {
318                 connection = mysql_init(connection);
319
320                 // Set the connection timeout.
321                 unsigned int timeout = config->getDuration("timeout", 5, 1, 30);
322                 mysql_options(connection, MYSQL_OPT_CONNECT_TIMEOUT, &timeout);
323
324                 // Attempt to connect to the database.
325                 const std::string host = config->getString("host");
326                 const std::string user = config->getString("user");
327                 const std::string pass = config->getString("pass");
328                 const std::string dbname = config->getString("name");
329                 unsigned int port = config->getUInt("port", 3306, 1, 65535);
330                 if (!mysql_real_connect(connection, host.c_str(), user.c_str(), pass.c_str(), dbname.c_str(), port, NULL, CLIENT_IGNORE_SIGPIPE))
331                 {
332                         ServerInstance->Logs->Log(MODNAME, LOG_DEFAULT, "Unable to connect to the %s MySQL server: %s",
333                                 GetId().c_str(), mysql_error(connection));
334                         return false;
335                 }
336
337                 // Set the default character set.
338                 const std::string charset = config->getString("charset");
339                 if (!charset.empty() && mysql_set_character_set(connection, charset.c_str()))
340                 {
341                         ServerInstance->Logs->Log(MODNAME, LOG_DEFAULT, "Could not set character set for %s to \"%s\": %s",
342                                 GetId().c_str(), charset.c_str(), mysql_error(connection));
343                         return false;
344                 }
345
346                 // Execute the initial SQL query.
347                 const std::string initialquery = config->getString("initialquery");
348                 if (!initialquery.empty() && mysql_real_query(connection, initialquery.data(), initialquery.length()))
349                 {
350                         ServerInstance->Logs->Log(MODNAME, LOG_DEFAULT, "Could not execute initial query \"%s\" for %s: %s",
351                                 initialquery.c_str(), name.c_str(), mysql_error(connection));
352                         return false;
353                 }
354
355                 return true;
356         }
357
358         ModuleSQL* Parent()
359         {
360                 return (ModuleSQL*)(Module*)creator;
361         }
362
363         MySQLresult* DoBlockingQuery(const std::string& query)
364         {
365
366                 /* Parse the command string and dispatch it to mysql */
367                 if (CheckConnection() && !mysql_real_query(connection, query.data(), query.length()))
368                 {
369                         /* Successfull query */
370                         MYSQL_RES* res = mysql_use_result(connection);
371                         unsigned long rows = mysql_affected_rows(connection);
372                         return new MySQLresult(res, rows);
373                 }
374                 else
375                 {
376                         /* XXX: See /usr/include/mysql/mysqld_error.h for a list of
377                          * possible error numbers and error messages */
378                         SQL::Error e(SQL::QREPLY_FAIL, InspIRCd::Format("%u: %s", mysql_errno(connection), mysql_error(connection)));
379                         return new MySQLresult(e);
380                 }
381         }
382
383         bool CheckConnection()
384         {
385                 if (!connection || mysql_ping(connection) != 0)
386                         return Connect();
387                 return true;
388         }
389
390         void Close()
391         {
392                 mysql_close(connection);
393         }
394
395         void Submit(SQL::Query* q, const std::string& qs) CXX11_OVERRIDE
396         {
397                 ServerInstance->Logs->Log(MODNAME, LOG_DEBUG, "Executing MySQL query: " + qs);
398                 Parent()->Dispatcher->LockQueue();
399                 Parent()->qq.push_back(QueryQueueItem(q, qs, this));
400                 Parent()->Dispatcher->UnlockQueueWakeup();
401         }
402
403         void Submit(SQL::Query* call, const std::string& q, const SQL::ParamList& p) CXX11_OVERRIDE
404         {
405                 std::string res;
406                 unsigned int param = 0;
407                 for(std::string::size_type i = 0; i < q.length(); i++)
408                 {
409                         if (q[i] != '?')
410                                 res.push_back(q[i]);
411                         else if (param < p.size() && !EscapeString(call, p[param++], res))
412                                 return;
413                 }
414                 Submit(call, res);
415         }
416
417         void Submit(SQL::Query* call, const std::string& q, const SQL::ParamMap& p) CXX11_OVERRIDE
418         {
419                 std::string res;
420                 for(std::string::size_type i = 0; i < q.length(); i++)
421                 {
422                         if (q[i] != '$')
423                                 res.push_back(q[i]);
424                         else
425                         {
426                                 std::string field;
427                                 i++;
428                                 while (i < q.length() && isalnum(q[i]))
429                                         field.push_back(q[i++]);
430                                 i--;
431
432                                 SQL::ParamMap::const_iterator it = p.find(field);
433                                 if (it != p.end() && !EscapeString(call, it->second, res))
434                                         return;
435                         }
436                 }
437                 Submit(call, res);
438         }
439 };
440
441 ModuleSQL::ModuleSQL()
442 {
443         Dispatcher = NULL;
444 }
445
446 void ModuleSQL::init()
447 {
448         if (mysql_library_init(0, NULL, NULL))
449                 throw ModuleException("Unable to initialise the MySQL library!");
450
451         Dispatcher = new DispatcherThread(this);
452         ServerInstance->Threads.Start(Dispatcher);
453 }
454
455 ModuleSQL::~ModuleSQL()
456 {
457         if (Dispatcher)
458         {
459                 Dispatcher->join();
460                 Dispatcher->OnNotify();
461                 delete Dispatcher;
462         }
463
464         for(ConnMap::iterator i = connections.begin(); i != connections.end(); i++)
465         {
466                 delete i->second;
467         }
468
469         mysql_library_end();
470 }
471
472 void ModuleSQL::ReadConfig(ConfigStatus& status)
473 {
474         ConnMap conns;
475         ConfigTagList tags = ServerInstance->Config->ConfTags("database");
476         for(ConfigIter i = tags.first; i != tags.second; i++)
477         {
478                 if (!stdalgo::string::equalsci(i->second->getString("module"), "mysql"))
479                         continue;
480                 std::string id = i->second->getString("id");
481                 ConnMap::iterator curr = connections.find(id);
482                 if (curr == connections.end())
483                 {
484                         SQLConnection* conn = new SQLConnection(this, i->second);
485                         conns.insert(std::make_pair(id, conn));
486                         ServerInstance->Modules->AddService(*conn);
487                 }
488                 else
489                 {
490                         conns.insert(*curr);
491                         connections.erase(curr);
492                 }
493         }
494
495         // now clean up the deleted databases
496         Dispatcher->LockQueue();
497         SQL::Error err(SQL::BAD_DBID);
498         for(ConnMap::iterator i = connections.begin(); i != connections.end(); i++)
499         {
500                 ServerInstance->Modules->DelService(*i->second);
501                 // it might be running a query on this database. Wait for that to complete
502                 i->second->lock.Lock();
503                 i->second->lock.Unlock();
504                 // now remove all active queries to this DB
505                 for (size_t j = qq.size(); j > 0; j--)
506                 {
507                         size_t k = j - 1;
508                         if (qq[k].connection == i->second)
509                         {
510                                 qq[k].query->OnError(err);
511                                 delete qq[k].query;
512                                 qq.erase(qq.begin() + k);
513                         }
514                 }
515                 // finally, nuke the connection
516                 delete i->second;
517         }
518         Dispatcher->UnlockQueue();
519         connections.swap(conns);
520 }
521
522 void ModuleSQL::OnUnloadModule(Module* mod)
523 {
524         SQL::Error err(SQL::BAD_DBID);
525         Dispatcher->LockQueue();
526         unsigned int i = qq.size();
527         while (i > 0)
528         {
529                 i--;
530                 if (qq[i].query->creator == mod)
531                 {
532                         if (i == 0)
533                         {
534                                 // need to wait until the query is done
535                                 // (the result will be discarded)
536                                 qq[i].connection->lock.Lock();
537                                 qq[i].connection->lock.Unlock();
538                         }
539                         qq[i].query->OnError(err);
540                         delete qq[i].query;
541                         qq.erase(qq.begin() + i);
542                 }
543         }
544         Dispatcher->UnlockQueue();
545         // clean up any result queue entries
546         Dispatcher->OnNotify();
547 }
548
549 Version ModuleSQL::GetVersion()
550 {
551         return Version("Provides MySQL support", VF_VENDOR);
552 }
553
554 void DispatcherThread::Run()
555 {
556         this->LockQueue();
557         while (!this->GetExitFlag())
558         {
559                 if (!Parent->qq.empty())
560                 {
561                         QueryQueueItem i = Parent->qq.front();
562                         i.connection->lock.Lock();
563                         this->UnlockQueue();
564                         MySQLresult* res = i.connection->DoBlockingQuery(i.querystr);
565                         i.connection->lock.Unlock();
566
567                         /*
568                          * At this point, the main thread could be working on:
569                          *  Rehash - delete i.connection out from under us. We don't care about that.
570                          *  UnloadModule - delete i.query and the qq item. Need to avoid reporting results.
571                          */
572
573                         this->LockQueue();
574                         if (!Parent->qq.empty() && Parent->qq.front().query == i.query)
575                         {
576                                 Parent->qq.pop_front();
577                                 Parent->rq.push_back(ResultQueueItem(i.query, res));
578                                 NotifyParent();
579                         }
580                         else
581                         {
582                                 // UnloadModule ate the query
583                                 delete res;
584                         }
585                 }
586                 else
587                 {
588                         /* We know the queue is empty, we can safely hang this thread until
589                          * something happens
590                          */
591                         this->WaitForQueue();
592                 }
593         }
594         this->UnlockQueue();
595 }
596
597 void DispatcherThread::OnNotify()
598 {
599         // this could unlock during the dispatch, but OnResult isn't expected to take that long
600         this->LockQueue();
601         for(ResultQueue::iterator i = Parent->rq.begin(); i != Parent->rq.end(); i++)
602         {
603                 MySQLresult* res = i->result;
604                 if (res->err.code == SQL::SUCCESS)
605                         i->query->OnResult(*res);
606                 else
607                         i->query->OnError(res->err);
608                 delete i->query;
609                 delete i->result;
610         }
611         Parent->rq.clear();
612         this->UnlockQueue();
613 }
614
615 MODULE_INIT(ModuleSQL)