]> git.netwichtig.de Git - user/henk/code/inspircd.git/blob - src/modules/extra/m_mysql.cpp
02d2a08a9699ba42c2b5f5e90ee16adf933bca26
[user/henk/code/inspircd.git] / src / modules / extra / m_mysql.cpp
1 /*
2  * InspIRCd -- Internet Relay Chat Daemon
3  *
4  *   Copyright (C) 2009-2010 Daniel De Graaf <danieldg@inspircd.org>
5  *   Copyright (C) 2006-2007, 2009 Dennis Friis <peavey@inspircd.org>
6  *   Copyright (C) 2006-2009 Craig Edwards <craigedwards@brainbox.cc>
7  *   Copyright (C) 2008 Robin Burchell <robin+git@viroteck.net>
8  *
9  * This file is part of InspIRCd.  InspIRCd is free software: you can
10  * redistribute it and/or modify it under the terms of the GNU General Public
11  * License as published by the Free Software Foundation, version 2.
12  *
13  * This program is distributed in the hope that it will be useful, but WITHOUT
14  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
15  * FOR A PARTICULAR PURPOSE.  See the GNU General Public License for more
16  * details.
17  *
18  * You should have received a copy of the GNU General Public License
19  * along with this program.  If not, see <http://www.gnu.org/licenses/>.
20  */
21
22 /// $CompilerFlags: execute("mysql_config --include" "MYSQL_CXXFLAGS")
23 /// $LinkerFlags: execute("mysql_config --libs_r" "MYSQL_LDFLAGS" "-lmysqlclient")
24
25 /// $PackageInfo: require_system("centos" "6.0" "6.99") mysql-devel
26 /// $PackageInfo: require_system("centos" "7.0") mariadb-devel
27 /// $PackageInfo: require_system("darwin") mysql-connector-c
28 /// $PackageInfo: require_system("debian") libmysqlclient-dev
29 /// $PackageInfo: require_system("ubuntu") libmysqlclient-dev
30
31 #ifdef __GNUC__
32 # pragma GCC diagnostic push
33 #endif
34
35 // Fix warnings about the use of `long long` on C++03.
36 #if defined __clang__
37 # pragma clang diagnostic ignored "-Wc++11-long-long"
38 #elif defined __GNUC__
39 # pragma GCC diagnostic ignored "-Wlong-long"
40 #endif
41
42 #include "inspircd.h"
43 #include <mysql.h>
44 #include "modules/sql.h"
45
46 #ifdef __GNUC__
47 # pragma GCC diagnostic pop
48 #endif
49
50 #ifdef _WIN32
51 # pragma comment(lib, "libmysql.lib")
52 #endif
53
54 /* VERSION 3 API: With nonblocking (threaded) requests */
55
/* THE NONBLOCKING MYSQL API!
 *
 * MySQL provides no nonblocking (asynchronous) API of its own, and its developers recommend
 * that you use threads instead. This is what is done here to allow asynchronous SQL requests
 * via MySQL. It works as follows:
 *
 * The module spawns a worker thread (DispatcherThread, a SocketThread) and performs its MySQL
 * queries in that thread, taking requests from a first-in, first-out queue. The queue is
 * protected by a single lock, which prevents the two threads from modifying it at the same
 * time and crashing the ircd. While the queue is empty the worker thread sleeps until it is
 * woken up by the submission of a new request; it then takes the request at the head of the
 * queue and processes it, blocking the worker thread but leaving the ircd thread to go about
 * its business as usual. During this period, the ircd thread is able to insert further
 * pending requests into the queue.
69  *
 * Once the processing of a request is complete, it is removed from the incoming queue and its
 * result is placed on an outgoing (result) queue. The worker thread then signals the ircd
 * thread that a result is available by calling NotifyParent(), SocketThread's notification
 * mechanism (a loopback socket).
74  *
 * The ircd thread then locks the queue once more, takes the pending results off the result
 * queue, and delivers each one to the original calling module's query object (OnResult or OnError).
77  *
 * XXX: You might be asking "why doesn't it just send the response from within the worker thread?"
 * The answer to this is simple: the majority of InspIRCd, and in fact most ircds, are not
 * threadsafe. This module is designed to be threadsafe and is careful with its use of threads.
 * However, if we were to call a module's result callback (OnResult/OnError) from within a thread
 * other than the one the module was originally instantiated on, there is a chance of all hell
 * breaking loose if a module is ever put in a re-entrant state (stack corruption, crashes, data
 * corruption, and worse could occur), so DON'T think about it until the day comes when InspIRCd
 * is 100% guaranteed threadsafe!
86  */
87
// Forward declarations: these types refer to each other before they are defined.
class DispatcherThread;
class MySQLresult;
class SQLConnection;
91
92 struct QQueueItem
93 {
94         SQL::Query* q;
95         std::string query;
96         SQLConnection* c;
97         QQueueItem(SQL::Query* Q, const std::string& S, SQLConnection* C) : q(Q), query(S), c(C) {}
98 };
99
100 struct RQueueItem
101 {
102         SQL::Query* q;
103         MySQLresult* r;
104         RQueueItem(SQL::Query* Q, MySQLresult* R) : q(Q), r(R) {}
105 };
106
107 typedef insp::flat_map<std::string, SQLConnection*> ConnMap;
108 typedef std::deque<QQueueItem> QueryQueue;
109 typedef std::deque<RQueueItem> ResultQueue;
110
111 /** MySQL module
112  *  */
113 class ModuleSQL : public Module
114 {
115  public:
116         DispatcherThread* Dispatcher;
117         QueryQueue qq;       // MUST HOLD MUTEX
118         ResultQueue rq;      // MUST HOLD MUTEX
119         ConnMap connections; // main thread only
120
121         ModuleSQL();
122         void init() CXX11_OVERRIDE;
123         ~ModuleSQL();
124         void ReadConfig(ConfigStatus& status) CXX11_OVERRIDE;
125         void OnUnloadModule(Module* mod) CXX11_OVERRIDE;
126         Version GetVersion() CXX11_OVERRIDE;
127 };
128
129 class DispatcherThread : public SocketThread
130 {
131  private:
132         ModuleSQL* const Parent;
133  public:
134         DispatcherThread(ModuleSQL* CreatorModule) : Parent(CreatorModule) { }
135         ~DispatcherThread() { }
136         void Run() CXX11_OVERRIDE;
137         void OnNotify() CXX11_OVERRIDE;
138 };
139
// Client libraries older than 3.22.24 (or of unknown version) are given
// mysql_field_count() as an alias for mysql_num_fields().
#if !defined(MYSQL_VERSION_ID) || (MYSQL_VERSION_ID < 32224)
# define mysql_field_count mysql_num_fields
#endif
143
144 /** Represents a mysql result set
145  */
146 class MySQLresult : public SQL::Result
147 {
148  public:
149         SQL::Error err;
150         int currentrow;
151         int rows;
152         std::vector<std::string> colnames;
153         std::vector<SQL::Row> fieldlists;
154
155         MySQLresult(MYSQL_RES* res, int affected_rows) : err(SQL::SUCCESS), currentrow(0), rows(0)
156         {
157                 if (affected_rows >= 1)
158                 {
159                         rows = affected_rows;
160                         fieldlists.resize(rows);
161                 }
162                 unsigned int field_count = 0;
163                 if (res)
164                 {
165                         MYSQL_ROW row;
166                         int n = 0;
167                         while ((row = mysql_fetch_row(res)))
168                         {
169                                 if (fieldlists.size() < (unsigned int)rows+1)
170                                 {
171                                         fieldlists.resize(fieldlists.size()+1);
172                                 }
173                                 field_count = 0;
174                                 MYSQL_FIELD *fields = mysql_fetch_fields(res);
175                                 if(mysql_num_fields(res) == 0)
176                                         break;
177                                 if (fields && mysql_num_fields(res))
178                                 {
179                                         colnames.clear();
180                                         while (field_count < mysql_num_fields(res))
181                                         {
182                                                 std::string a = (fields[field_count].name ? fields[field_count].name : "");
183                                                 if (row[field_count])
184                                                         fieldlists[n].push_back(SQL::Field(row[field_count]));
185                                                 else
186                                                         fieldlists[n].push_back(SQL::Field());
187                                                 colnames.push_back(a);
188                                                 field_count++;
189                                         }
190                                         n++;
191                                 }
192                                 rows++;
193                         }
194                         mysql_free_result(res);
195                 }
196         }
197
198         MySQLresult(SQL::Error& e) : err(e)
199         {
200
201         }
202
203         int Rows() CXX11_OVERRIDE
204         {
205                 return rows;
206         }
207
208         void GetCols(std::vector<std::string>& result) CXX11_OVERRIDE
209         {
210                 result.assign(colnames.begin(), colnames.end());
211         }
212
213         bool HasColumn(const std::string& column, size_t& index) CXX11_OVERRIDE
214         {
215                 for (size_t i = 0; i < colnames.size(); ++i)
216                 {
217                         if (colnames[i] == column)
218                         {
219                                 index = i;
220                                 return true;
221                         }
222                 }
223                 return false;
224         }
225
226         SQL::Field GetValue(int row, int column)
227         {
228                 if ((row >= 0) && (row < rows) && (column >= 0) && (column < (int)fieldlists[row].size()))
229                 {
230                         return fieldlists[row][column];
231                 }
232                 return SQL::Field();
233         }
234
235         bool GetRow(SQL::Row& result) CXX11_OVERRIDE
236         {
237                 if (currentrow < rows)
238                 {
239                         result.assign(fieldlists[currentrow].begin(), fieldlists[currentrow].end());
240                         currentrow++;
241                         return true;
242                 }
243                 else
244                 {
245                         result.clear();
246                         return false;
247                 }
248         }
249 };
250
251 /** Represents a connection to a mysql database
252  */
253 class SQLConnection : public SQL::Provider
254 {
255  public:
256         reference<ConfigTag> config;
257         MYSQL *connection;
258         Mutex lock;
259
260         // This constructor creates an SQLConnection object with the given credentials, but does not connect yet.
261         SQLConnection(Module* p, ConfigTag* tag) : SQL::Provider(p, "SQL/" + tag->getString("id")),
262                 config(tag), connection(NULL)
263         {
264         }
265
266         ~SQLConnection()
267         {
268                 Close();
269         }
270
271         // This method connects to the database using the credentials supplied to the constructor, and returns
272         // true upon success.
273         bool Connect()
274         {
275                 unsigned int timeout = 1;
276                 connection = mysql_init(connection);
277                 mysql_options(connection,MYSQL_OPT_CONNECT_TIMEOUT,(char*)&timeout);
278                 std::string host = config->getString("host");
279                 std::string user = config->getString("user");
280                 std::string pass = config->getString("pass");
281                 std::string dbname = config->getString("name");
282                 unsigned int port = config->getUInt("port", 3306);
283                 bool rv = mysql_real_connect(connection, host.c_str(), user.c_str(), pass.c_str(), dbname.c_str(), port, NULL, 0);
284                 if (!rv)
285                         return rv;
286
287                 // Enable character set settings
288                 std::string charset = config->getString("charset");
289                 if ((!charset.empty()) && (mysql_set_character_set(connection, charset.c_str())))
290                         ServerInstance->Logs->Log(MODNAME, LOG_DEFAULT, "WARNING: Could not set character set to \"%s\"", charset.c_str());
291
292                 std::string initquery;
293                 if (config->readString("initialquery", initquery))
294                 {
295                         mysql_query(connection,initquery.c_str());
296                 }
297                 return true;
298         }
299
300         ModuleSQL* Parent()
301         {
302                 return (ModuleSQL*)(Module*)creator;
303         }
304
305         MySQLresult* DoBlockingQuery(const std::string& query)
306         {
307
308                 /* Parse the command string and dispatch it to mysql */
309                 if (CheckConnection() && !mysql_real_query(connection, query.data(), query.length()))
310                 {
311                         /* Successfull query */
312                         MYSQL_RES* res = mysql_use_result(connection);
313                         unsigned long rows = mysql_affected_rows(connection);
314                         return new MySQLresult(res, rows);
315                 }
316                 else
317                 {
318                         /* XXX: See /usr/include/mysql/mysqld_error.h for a list of
319                          * possible error numbers and error messages */
320                         SQL::Error e(SQL::QREPLY_FAIL, InspIRCd::Format("%u: %s", mysql_errno(connection), mysql_error(connection)));
321                         return new MySQLresult(e);
322                 }
323         }
324
325         bool CheckConnection()
326         {
327                 if (!connection || mysql_ping(connection) != 0)
328                         return Connect();
329                 return true;
330         }
331
332         std::string GetError()
333         {
334                 return mysql_error(connection);
335         }
336
337         void Close()
338         {
339                 mysql_close(connection);
340         }
341
342         void Submit(SQL::Query* q, const std::string& qs) CXX11_OVERRIDE
343         {
344                 ServerInstance->Logs->Log(MODNAME, LOG_DEBUG, "Executing MySQL query: " + qs);
345                 Parent()->Dispatcher->LockQueue();
346                 Parent()->qq.push_back(QQueueItem(q, qs, this));
347                 Parent()->Dispatcher->UnlockQueueWakeup();
348         }
349
350         void Submit(SQL::Query* call, const std::string& q, const SQL::ParamList& p) CXX11_OVERRIDE
351         {
352                 std::string res;
353                 unsigned int param = 0;
354                 for(std::string::size_type i = 0; i < q.length(); i++)
355                 {
356                         if (q[i] != '?')
357                                 res.push_back(q[i]);
358                         else
359                         {
360                                 if (param < p.size())
361                                 {
362                                         std::string parm = p[param++];
363                                         // In the worst case, each character may need to be encoded as using two bytes,
364                                         // and one byte is the terminating null
365                                         std::vector<char> buffer(parm.length() * 2 + 1);
366
367                                         // The return value of mysql_real_escape_string() is the length of the encoded string,
368                                         // not including the terminating null
369                                         unsigned long escapedsize = mysql_real_escape_string(connection, &buffer[0], parm.c_str(), parm.length());
370                                         res.append(&buffer[0], escapedsize);
371                                 }
372                         }
373                 }
374                 Submit(call, res);
375         }
376
377         void Submit(SQL::Query* call, const std::string& q, const SQL::ParamMap& p) CXX11_OVERRIDE
378         {
379                 std::string res;
380                 for(std::string::size_type i = 0; i < q.length(); i++)
381                 {
382                         if (q[i] != '$')
383                                 res.push_back(q[i]);
384                         else
385                         {
386                                 std::string field;
387                                 i++;
388                                 while (i < q.length() && isalnum(q[i]))
389                                         field.push_back(q[i++]);
390                                 i--;
391
392                                 SQL::ParamMap::const_iterator it = p.find(field);
393                                 if (it != p.end())
394                                 {
395                                         std::string parm = it->second;
396                                         // NOTE: See above
397                                         std::vector<char> buffer(parm.length() * 2 + 1);
398                                         unsigned long escapedsize = mysql_escape_string(&buffer[0], parm.c_str(), parm.length());
399                                         res.append(&buffer[0], escapedsize);
400                                 }
401                         }
402                 }
403                 Submit(call, res);
404         }
405 };
406
407 ModuleSQL::ModuleSQL()
408 {
409         Dispatcher = NULL;
410 }
411
412 void ModuleSQL::init()
413 {
414         Dispatcher = new DispatcherThread(this);
415         ServerInstance->Threads.Start(Dispatcher);
416 }
417
418 ModuleSQL::~ModuleSQL()
419 {
420         if (Dispatcher)
421         {
422                 Dispatcher->join();
423                 Dispatcher->OnNotify();
424                 delete Dispatcher;
425         }
426         for(ConnMap::iterator i = connections.begin(); i != connections.end(); i++)
427         {
428                 delete i->second;
429         }
430 }
431
432 void ModuleSQL::ReadConfig(ConfigStatus& status)
433 {
434         ConnMap conns;
435         ConfigTagList tags = ServerInstance->Config->ConfTags("database");
436         for(ConfigIter i = tags.first; i != tags.second; i++)
437         {
438                 if (!stdalgo::string::equalsci(i->second->getString("module"), "mysql"))
439                         continue;
440                 std::string id = i->second->getString("id");
441                 ConnMap::iterator curr = connections.find(id);
442                 if (curr == connections.end())
443                 {
444                         SQLConnection* conn = new SQLConnection(this, i->second);
445                         conns.insert(std::make_pair(id, conn));
446                         ServerInstance->Modules->AddService(*conn);
447                 }
448                 else
449                 {
450                         conns.insert(*curr);
451                         connections.erase(curr);
452                 }
453         }
454
455         // now clean up the deleted databases
456         Dispatcher->LockQueue();
457         SQL::Error err(SQL::BAD_DBID);
458         for(ConnMap::iterator i = connections.begin(); i != connections.end(); i++)
459         {
460                 ServerInstance->Modules->DelService(*i->second);
461                 // it might be running a query on this database. Wait for that to complete
462                 i->second->lock.Lock();
463                 i->second->lock.Unlock();
464                 // now remove all active queries to this DB
465                 for (size_t j = qq.size(); j > 0; j--)
466                 {
467                         size_t k = j - 1;
468                         if (qq[k].c == i->second)
469                         {
470                                 qq[k].q->OnError(err);
471                                 delete qq[k].q;
472                                 qq.erase(qq.begin() + k);
473                         }
474                 }
475                 // finally, nuke the connection
476                 delete i->second;
477         }
478         Dispatcher->UnlockQueue();
479         connections.swap(conns);
480 }
481
482 void ModuleSQL::OnUnloadModule(Module* mod)
483 {
484         SQL::Error err(SQL::BAD_DBID);
485         Dispatcher->LockQueue();
486         unsigned int i = qq.size();
487         while (i > 0)
488         {
489                 i--;
490                 if (qq[i].q->creator == mod)
491                 {
492                         if (i == 0)
493                         {
494                                 // need to wait until the query is done
495                                 // (the result will be discarded)
496                                 qq[i].c->lock.Lock();
497                                 qq[i].c->lock.Unlock();
498                         }
499                         qq[i].q->OnError(err);
500                         delete qq[i].q;
501                         qq.erase(qq.begin() + i);
502                 }
503         }
504         Dispatcher->UnlockQueue();
505         // clean up any result queue entries
506         Dispatcher->OnNotify();
507 }
508
509 Version ModuleSQL::GetVersion()
510 {
511         return Version("Provides MySQL support", VF_VENDOR);
512 }
513
514 void DispatcherThread::Run()
515 {
516         this->LockQueue();
517         while (!this->GetExitFlag())
518         {
519                 if (!Parent->qq.empty())
520                 {
521                         QQueueItem i = Parent->qq.front();
522                         i.c->lock.Lock();
523                         this->UnlockQueue();
524                         MySQLresult* res = i.c->DoBlockingQuery(i.query);
525                         i.c->lock.Unlock();
526
527                         /*
528                          * At this point, the main thread could be working on:
529                          *  Rehash - delete i.c out from under us. We don't care about that.
530                          *  UnloadModule - delete i.q and the qq item. Need to avoid reporting results.
531                          */
532
533                         this->LockQueue();
534                         if (!Parent->qq.empty() && Parent->qq.front().q == i.q)
535                         {
536                                 Parent->qq.pop_front();
537                                 Parent->rq.push_back(RQueueItem(i.q, res));
538                                 NotifyParent();
539                         }
540                         else
541                         {
542                                 // UnloadModule ate the query
543                                 delete res;
544                         }
545                 }
546                 else
547                 {
548                         /* We know the queue is empty, we can safely hang this thread until
549                          * something happens
550                          */
551                         this->WaitForQueue();
552                 }
553         }
554         this->UnlockQueue();
555 }
556
557 void DispatcherThread::OnNotify()
558 {
559         // this could unlock during the dispatch, but OnResult isn't expected to take that long
560         this->LockQueue();
561         for(ResultQueue::iterator i = Parent->rq.begin(); i != Parent->rq.end(); i++)
562         {
563                 MySQLresult* res = i->r;
564                 if (res->err.code == SQL::SUCCESS)
565                         i->q->OnResult(*res);
566                 else
567                         i->q->OnError(res->err);
568                 delete i->q;
569                 delete i->r;
570         }
571         Parent->rq.clear();
572         this->UnlockQueue();
573 }
574
575 MODULE_INIT(ModuleSQL)