]> git.netwichtig.de Git - user/henk/code/inspircd.git/blob - src/modules/extra/m_mysql.cpp
Merge pull request #1185 from SaberUK/master+lockserv
[user/henk/code/inspircd.git] / src / modules / extra / m_mysql.cpp
1 /*
2  * InspIRCd -- Internet Relay Chat Daemon
3  *
4  *   Copyright (C) 2009-2010 Daniel De Graaf <danieldg@inspircd.org>
5  *   Copyright (C) 2006-2007, 2009 Dennis Friis <peavey@inspircd.org>
6  *   Copyright (C) 2006-2009 Craig Edwards <craigedwards@brainbox.cc>
7  *   Copyright (C) 2008 Robin Burchell <robin+git@viroteck.net>
8  *
9  * This file is part of InspIRCd.  InspIRCd is free software: you can
10  * redistribute it and/or modify it under the terms of the GNU General Public
11  * License as published by the Free Software Foundation, version 2.
12  *
13  * This program is distributed in the hope that it will be useful, but WITHOUT
14  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
15  * FOR A PARTICULAR PURPOSE.  See the GNU General Public License for more
16  * details.
17  *
18  * You should have received a copy of the GNU General Public License
19  * along with this program.  If not, see <http://www.gnu.org/licenses/>.
20  */
21
22 /// $CompilerFlags: execute("mysql_config --include" "MYSQL_CXXFLAGS")
23 /// $LinkerFlags: execute("mysql_config --libs_r" "MYSQL_LDFLAGS" "-lmysqlclient")
24
25 /// $PackageInfo: require_system("centos" "6.0" "6.99") mysql-devel
26 /// $PackageInfo: require_system("centos" "7.0") mariadb-devel
27 /// $PackageInfo: require_system("darwin") mysql-connector-c
28 /// $PackageInfo: require_system("ubuntu") libmysqlclient-dev
29
30
31 // Fix warnings about the use of `long long` on C++03.
32 #if defined __clang__
33 # pragma clang diagnostic ignored "-Wc++11-long-long"
34 #elif defined __GNUC__
35 # pragma GCC diagnostic ignored "-Wlong-long"
36 #endif
37
38 #include "inspircd.h"
39 #include <mysql.h>
40 #include "modules/sql.h"
41
42 #ifdef _WIN32
43 # pragma comment(lib, "libmysql.lib")
44 #endif
45
/* VERSION 3 API: With nonblocking (threaded) requests */

/* THE NONBLOCKING MYSQL API!
 *
 * MySQL provides no nonblocking (asynchronous) API of its own, and its developers recommend
 * that you thread your program instead. That is what this module does to allow asynchronous
 * SQL requests via MySQL. It works as follows:
 *
 * The module spawns a worker thread (DispatcherThread, a SocketThread) and performs its MySQL
 * queries in that thread, using a first-in, first-out queue of pending queries. A mutex
 * protects the queues so that the two threads never modify them at the same time, which would
 * crash the ircd. When the queue is empty, the worker thread sleeps until submitting a new query
 * wakes it up. It then processes the request at the head of its queue, blocking the worker thread
 * but leaving the ircd thread to go about its business as usual. During this period, the ircd
 * thread is able to insert further pending requests into the queue.
 *
 * Once the processing of a request is complete, it is moved from the incoming queue to an
 * outgoing (result) queue as a 'response'. The worker thread then signals the ircd thread
 * (through the SocketThread notification mechanism) that a result is available.
 *
 * The ircd thread then locks the queue once more, takes the pending responses off the result
 * queue, and delivers each one to the query that requested it.
 *
 * XXX: You might be asking "why doesn't he just send the response from within the worker thread?"
 * The answer to this is simple. The majority of InspIRCd, and in fact most ircds, are not
 * threadsafe. This module is designed to be threadsafe and is careful with its use of threads;
 * however, if we were to call a query's OnResult/OnError callback from any thread other than the
 * one the module was originally instantiated on, there is a chance of all hell breaking loose
 * if a module is ever put in a re-entrant state (stack corruption could occur, crashes, data
 * corruption, and worse), so DON'T think about it until the day comes when InspIRCd is 100%
 * guaranteed threadsafe!
 *
 * For a diagram of this system please see http://wiki.inspircd.org/Mysql2
 */
81
// Forward declarations: the queue item types and ModuleSQL below only hold
// pointers to these classes, which are defined further down in this file.
class SQLConnection;
class MySQLresult;
class DispatcherThread;
85
86 struct QQueueItem
87 {
88         SQLQuery* q;
89         std::string query;
90         SQLConnection* c;
91         QQueueItem(SQLQuery* Q, const std::string& S, SQLConnection* C) : q(Q), query(S), c(C) {}
92 };
93
94 struct RQueueItem
95 {
96         SQLQuery* q;
97         MySQLresult* r;
98         RQueueItem(SQLQuery* Q, MySQLresult* R) : q(Q), r(R) {}
99 };
100
/** Maps the id attribute of a <database> tag to its connection object. */
typedef insp::flat_map<std::string, SQLConnection*> ConnMap;
/** Queries waiting for the dispatcher thread, oldest first; the front entry stays queued while it runs. */
typedef std::deque<QQueueItem> QueryQueue;
/** Completed queries waiting for OnNotify() to deliver them. */
typedef std::deque<RQueueItem> ResultQueue;
104
/** MySQL module.
 *
 * Owns the dispatcher thread, the shared query/result queues, and one
 * SQLConnection per <database> tag whose module attribute is "mysql" (the default).
 */
class ModuleSQL : public Module
{
 public:
	/** Worker thread that executes queued queries; created and started in init(). */
	DispatcherThread* Dispatcher;
	QueryQueue qq;       // MUST HOLD MUTEX (Dispatcher->LockQueue())
	ResultQueue rq;      // MUST HOLD MUTEX (Dispatcher->LockQueue())
	ConnMap connections; // main thread only; keyed by <database id="...">

	ModuleSQL();
	void init() CXX11_OVERRIDE;
	~ModuleSQL();
	/** Creates connections for new <database> tags and destroys those whose tag was removed. */
	void ReadConfig(ConfigStatus& status) CXX11_OVERRIDE;
	/** Fails and discards every queued query created by the module being unloaded. */
	void OnUnloadModule(Module* mod) CXX11_OVERRIDE;
	Version GetVersion() CXX11_OVERRIDE;
};
122
123 class DispatcherThread : public SocketThread
124 {
125  private:
126         ModuleSQL* const Parent;
127  public:
128         DispatcherThread(ModuleSQL* CreatorModule) : Parent(CreatorModule) { }
129         ~DispatcherThread() { }
130         void Run();
131         void OnNotify();
132 };
133
// Compatibility shim for client libraries whose MYSQL_VERSION_ID is below 32224
// (3.22.24), presumably because they lack mysql_field_count(); it is mapped onto
// mysql_num_fields(). NOTE(review): nothing in this file currently calls
// mysql_field_count(), so this may be removable.
#if !defined(MYSQL_VERSION_ID) || MYSQL_VERSION_ID<32224
#define mysql_field_count mysql_num_fields
#endif
137
138 /** Represents a mysql result set
139  */
140 class MySQLresult : public SQLResult
141 {
142  public:
143         SQLerror err;
144         int currentrow;
145         int rows;
146         std::vector<std::string> colnames;
147         std::vector<SQLEntries> fieldlists;
148
149         MySQLresult(MYSQL_RES* res, int affected_rows) : err(SQL_NO_ERROR), currentrow(0), rows(0)
150         {
151                 if (affected_rows >= 1)
152                 {
153                         rows = affected_rows;
154                         fieldlists.resize(rows);
155                 }
156                 unsigned int field_count = 0;
157                 if (res)
158                 {
159                         MYSQL_ROW row;
160                         int n = 0;
161                         while ((row = mysql_fetch_row(res)))
162                         {
163                                 if (fieldlists.size() < (unsigned int)rows+1)
164                                 {
165                                         fieldlists.resize(fieldlists.size()+1);
166                                 }
167                                 field_count = 0;
168                                 MYSQL_FIELD *fields = mysql_fetch_fields(res);
169                                 if(mysql_num_fields(res) == 0)
170                                         break;
171                                 if (fields && mysql_num_fields(res))
172                                 {
173                                         colnames.clear();
174                                         while (field_count < mysql_num_fields(res))
175                                         {
176                                                 std::string a = (fields[field_count].name ? fields[field_count].name : "");
177                                                 if (row[field_count])
178                                                         fieldlists[n].push_back(SQLEntry(row[field_count]));
179                                                 else
180                                                         fieldlists[n].push_back(SQLEntry());
181                                                 colnames.push_back(a);
182                                                 field_count++;
183                                         }
184                                         n++;
185                                 }
186                                 rows++;
187                         }
188                         mysql_free_result(res);
189                 }
190         }
191
192         MySQLresult(SQLerror& e) : err(e)
193         {
194
195         }
196
197         int Rows()
198         {
199                 return rows;
200         }
201
202         void GetCols(std::vector<std::string>& result)
203         {
204                 result.assign(colnames.begin(), colnames.end());
205         }
206
207         SQLEntry GetValue(int row, int column)
208         {
209                 if ((row >= 0) && (row < rows) && (column >= 0) && (column < (int)fieldlists[row].size()))
210                 {
211                         return fieldlists[row][column];
212                 }
213                 return SQLEntry();
214         }
215
216         bool GetRow(SQLEntries& result)
217         {
218                 if (currentrow < rows)
219                 {
220                         result.assign(fieldlists[currentrow].begin(), fieldlists[currentrow].end());
221                         currentrow++;
222                         return true;
223                 }
224                 else
225                 {
226                         result.clear();
227                         return false;
228                 }
229         }
230 };
231
232 /** Represents a connection to a mysql database
233  */
234 class SQLConnection : public SQLProvider
235 {
236  public:
237         reference<ConfigTag> config;
238         MYSQL *connection;
239         Mutex lock;
240
241         // This constructor creates an SQLConnection object with the given credentials, but does not connect yet.
242         SQLConnection(Module* p, ConfigTag* tag) : SQLProvider(p, "SQL/" + tag->getString("id")),
243                 config(tag), connection(NULL)
244         {
245         }
246
247         ~SQLConnection()
248         {
249                 Close();
250         }
251
252         // This method connects to the database using the credentials supplied to the constructor, and returns
253         // true upon success.
254         bool Connect()
255         {
256                 unsigned int timeout = 1;
257                 connection = mysql_init(connection);
258                 mysql_options(connection,MYSQL_OPT_CONNECT_TIMEOUT,(char*)&timeout);
259                 std::string host = config->getString("host");
260                 std::string user = config->getString("user");
261                 std::string pass = config->getString("pass");
262                 std::string dbname = config->getString("name");
263                 int port = config->getInt("port");
264                 bool rv = mysql_real_connect(connection, host.c_str(), user.c_str(), pass.c_str(), dbname.c_str(), port, NULL, 0);
265                 if (!rv)
266                         return rv;
267
268                 // Enable character set settings
269                 std::string charset = config->getString("charset");
270                 if ((!charset.empty()) && (mysql_set_character_set(connection, charset.c_str())))
271                         ServerInstance->Logs->Log(MODNAME, LOG_DEFAULT, "WARNING: Could not set character set to \"%s\"", charset.c_str());
272
273                 std::string initquery;
274                 if (config->readString("initialquery", initquery))
275                 {
276                         mysql_query(connection,initquery.c_str());
277                 }
278                 return true;
279         }
280
281         ModuleSQL* Parent()
282         {
283                 return (ModuleSQL*)(Module*)creator;
284         }
285
286         MySQLresult* DoBlockingQuery(const std::string& query)
287         {
288
289                 /* Parse the command string and dispatch it to mysql */
290                 if (CheckConnection() && !mysql_real_query(connection, query.data(), query.length()))
291                 {
292                         /* Successfull query */
293                         MYSQL_RES* res = mysql_use_result(connection);
294                         unsigned long rows = mysql_affected_rows(connection);
295                         return new MySQLresult(res, rows);
296                 }
297                 else
298                 {
299                         /* XXX: See /usr/include/mysql/mysqld_error.h for a list of
300                          * possible error numbers and error messages */
301                         SQLerror e(SQL_QREPLY_FAIL, ConvToStr(mysql_errno(connection)) + ": " + mysql_error(connection));
302                         return new MySQLresult(e);
303                 }
304         }
305
306         bool CheckConnection()
307         {
308                 if (!connection || mysql_ping(connection) != 0)
309                         return Connect();
310                 return true;
311         }
312
313         std::string GetError()
314         {
315                 return mysql_error(connection);
316         }
317
318         void Close()
319         {
320                 mysql_close(connection);
321         }
322
323         void submit(SQLQuery* q, const std::string& qs)
324         {
325                 Parent()->Dispatcher->LockQueue();
326                 Parent()->qq.push_back(QQueueItem(q, qs, this));
327                 Parent()->Dispatcher->UnlockQueueWakeup();
328         }
329
330         void submit(SQLQuery* call, const std::string& q, const ParamL& p)
331         {
332                 std::string res;
333                 unsigned int param = 0;
334                 for(std::string::size_type i = 0; i < q.length(); i++)
335                 {
336                         if (q[i] != '?')
337                                 res.push_back(q[i]);
338                         else
339                         {
340                                 if (param < p.size())
341                                 {
342                                         std::string parm = p[param++];
343                                         // In the worst case, each character may need to be encoded as using two bytes,
344                                         // and one byte is the terminating null
345                                         std::vector<char> buffer(parm.length() * 2 + 1);
346
347                                         // The return value of mysql_escape_string() is the length of the encoded string,
348                                         // not including the terminating null
349                                         unsigned long escapedsize = mysql_escape_string(&buffer[0], parm.c_str(), parm.length());
350 //                                      mysql_real_escape_string(connection, queryend, paramscopy[paramnum].c_str(), paramscopy[paramnum].length());
351                                         res.append(&buffer[0], escapedsize);
352                                 }
353                         }
354                 }
355                 submit(call, res);
356         }
357
358         void submit(SQLQuery* call, const std::string& q, const ParamM& p)
359         {
360                 std::string res;
361                 for(std::string::size_type i = 0; i < q.length(); i++)
362                 {
363                         if (q[i] != '$')
364                                 res.push_back(q[i]);
365                         else
366                         {
367                                 std::string field;
368                                 i++;
369                                 while (i < q.length() && isalnum(q[i]))
370                                         field.push_back(q[i++]);
371                                 i--;
372
373                                 ParamM::const_iterator it = p.find(field);
374                                 if (it != p.end())
375                                 {
376                                         std::string parm = it->second;
377                                         // NOTE: See above
378                                         std::vector<char> buffer(parm.length() * 2 + 1);
379                                         unsigned long escapedsize = mysql_escape_string(&buffer[0], parm.c_str(), parm.length());
380                                         res.append(&buffer[0], escapedsize);
381                                 }
382                         }
383                 }
384                 submit(call, res);
385         }
386 };
387
388 ModuleSQL::ModuleSQL()
389 {
390         Dispatcher = NULL;
391 }
392
393 void ModuleSQL::init()
394 {
395         Dispatcher = new DispatcherThread(this);
396         ServerInstance->Threads.Start(Dispatcher);
397 }
398
399 ModuleSQL::~ModuleSQL()
400 {
401         if (Dispatcher)
402         {
403                 Dispatcher->join();
404                 Dispatcher->OnNotify();
405                 delete Dispatcher;
406         }
407         for(ConnMap::iterator i = connections.begin(); i != connections.end(); i++)
408         {
409                 delete i->second;
410         }
411 }
412
413 void ModuleSQL::ReadConfig(ConfigStatus& status)
414 {
415         ConnMap conns;
416         ConfigTagList tags = ServerInstance->Config->ConfTags("database");
417         for(ConfigIter i = tags.first; i != tags.second; i++)
418         {
419                 if (i->second->getString("module", "mysql") != "mysql")
420                         continue;
421                 std::string id = i->second->getString("id");
422                 ConnMap::iterator curr = connections.find(id);
423                 if (curr == connections.end())
424                 {
425                         SQLConnection* conn = new SQLConnection(this, i->second);
426                         conns.insert(std::make_pair(id, conn));
427                         ServerInstance->Modules->AddService(*conn);
428                 }
429                 else
430                 {
431                         conns.insert(*curr);
432                         connections.erase(curr);
433                 }
434         }
435
436         // now clean up the deleted databases
437         Dispatcher->LockQueue();
438         SQLerror err(SQL_BAD_DBID);
439         for(ConnMap::iterator i = connections.begin(); i != connections.end(); i++)
440         {
441                 ServerInstance->Modules->DelService(*i->second);
442                 // it might be running a query on this database. Wait for that to complete
443                 i->second->lock.Lock();
444                 i->second->lock.Unlock();
445                 // now remove all active queries to this DB
446                 for (size_t j = qq.size(); j > 0; j--)
447                 {
448                         size_t k = j - 1;
449                         if (qq[k].c == i->second)
450                         {
451                                 qq[k].q->OnError(err);
452                                 delete qq[k].q;
453                                 qq.erase(qq.begin() + k);
454                         }
455                 }
456                 // finally, nuke the connection
457                 delete i->second;
458         }
459         Dispatcher->UnlockQueue();
460         connections.swap(conns);
461 }
462
463 void ModuleSQL::OnUnloadModule(Module* mod)
464 {
465         SQLerror err(SQL_BAD_DBID);
466         Dispatcher->LockQueue();
467         unsigned int i = qq.size();
468         while (i > 0)
469         {
470                 i--;
471                 if (qq[i].q->creator == mod)
472                 {
473                         if (i == 0)
474                         {
475                                 // need to wait until the query is done
476                                 // (the result will be discarded)
477                                 qq[i].c->lock.Lock();
478                                 qq[i].c->lock.Unlock();
479                         }
480                         qq[i].q->OnError(err);
481                         delete qq[i].q;
482                         qq.erase(qq.begin() + i);
483                 }
484         }
485         Dispatcher->UnlockQueue();
486         // clean up any result queue entries
487         Dispatcher->OnNotify();
488 }
489
490 Version ModuleSQL::GetVersion()
491 {
492         return Version("MySQL support", VF_VENDOR);
493 }
494
/** Dispatcher thread body.
 *
 * Runs the query at the front of Parent->qq, one at a time, until the exit flag is set.
 * The queue lock is held except while a query executes and while waiting for work.
 * The front item is left in the queue while it runs, so the main thread can tell it
 * is in flight (OnUnloadModule/ReadConfig wait on its connection's lock).
 */
void DispatcherThread::Run()
{
	this->LockQueue();
	while (!this->GetExitFlag())
	{
		if (!Parent->qq.empty())
		{
			// Copy the item: the main thread may erase it from the queue while we run it.
			QQueueItem i = Parent->qq.front();
			// Take the connection lock before releasing the queue lock, so the main
			// thread cannot delete the connection between these two steps.
			i.c->lock.Lock();
			this->UnlockQueue();
			MySQLresult* res = i.c->DoBlockingQuery(i.query);
			i.c->lock.Unlock();

			/*
			 * At this point, the main thread could be working on:
			 *  Rehash - delete i.c out from under us. We don't care about that.
			 *  UnloadModule - delete i.q and the qq item. Need to avoid reporting results.
			 */

			this->LockQueue();
			// Only report the result if our query is still at the front; otherwise the
			// main thread has already failed and deleted it.
			if (!Parent->qq.empty() && Parent->qq.front().q == i.q)
			{
				Parent->qq.pop_front();
				Parent->rq.push_back(RQueueItem(i.q, res));
				NotifyParent();
			}
			else
			{
				// UnloadModule ate the query (or a rehash removed its database)
				delete res;
			}
		}
		else
		{
			/* We know the queue is empty, we can safely hang this thread until
			 * something happens
			 */
			// submit() wakes this via UnlockQueueWakeup(); since submit() must take the
			// queue lock to push, the wait has to release the lock while sleeping.
			this->WaitForQueue();
		}
	}
	this->UnlockQueue();
}
537
538 void DispatcherThread::OnNotify()
539 {
540         // this could unlock during the dispatch, but OnResult isn't expected to take that long
541         this->LockQueue();
542         for(ResultQueue::iterator i = Parent->rq.begin(); i != Parent->rq.end(); i++)
543         {
544                 MySQLresult* res = i->r;
545                 if (res->err.id == SQL_NO_ERROR)
546                         i->q->OnResult(*res);
547                 else
548                         i->q->OnError(res->err);
549                 delete i->q;
550                 delete i->r;
551         }
552         Parent->rq.clear();
553         this->UnlockQueue();
554 }
555
// Declares ModuleSQL as this file's module class via InspIRCd's MODULE_INIT macro
// (presumably generating the entry point the module loader calls).
MODULE_INIT(ModuleSQL)