]> git.netwichtig.de Git - user/henk/code/inspircd.git/blob - src/modules/extra/m_mysql.cpp
d0ccebc478550037a05a91f7d59463fe53ae3dd2
[user/henk/code/inspircd.git] / src / modules / extra / m_mysql.cpp
1 /*
2  * InspIRCd -- Internet Relay Chat Daemon
3  *
4  *   Copyright (C) 2009-2010 Daniel De Graaf <danieldg@inspircd.org>
5  *   Copyright (C) 2006-2007, 2009 Dennis Friis <peavey@inspircd.org>
6  *   Copyright (C) 2006-2009 Craig Edwards <craigedwards@brainbox.cc>
7  *   Copyright (C) 2008 Robin Burchell <robin+git@viroteck.net>
8  *
9  * This file is part of InspIRCd.  InspIRCd is free software: you can
10  * redistribute it and/or modify it under the terms of the GNU General Public
11  * License as published by the Free Software Foundation, version 2.
12  *
13  * This program is distributed in the hope that it will be useful, but WITHOUT
14  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
15  * FOR A PARTICULAR PURPOSE.  See the GNU General Public License for more
16  * details.
17  *
18  * You should have received a copy of the GNU General Public License
19  * along with this program.  If not, see <http://www.gnu.org/licenses/>.
20  */
21
22 /// $CompilerFlags: execute("mysql_config --include" "MYSQL_CXXFLAGS")
23 /// $LinkerFlags: execute("mysql_config --libs_r" "MYSQL_LDFLAGS" "-lmysqlclient")
24
25 /// $PackageInfo: require_system("centos" "6.0" "6.99") mysql-devel
26 /// $PackageInfo: require_system("centos" "7.0") mariadb-devel
27 /// $PackageInfo: require_system("darwin") mysql-connector-c
28 /// $PackageInfo: require_system("debian") libmysqlclient-dev
29 /// $PackageInfo: require_system("ubuntu") libmysqlclient-dev
30
31
32 // Fix warnings about the use of `long long` on C++03.
33 #if defined __clang__
34 # pragma clang diagnostic ignored "-Wc++11-long-long"
35 #elif defined __GNUC__
36 # pragma GCC diagnostic ignored "-Wlong-long"
37 #endif
38
39 #include "inspircd.h"
40 #include <mysql.h>
41 #include "modules/sql.h"
42
43 #ifdef _WIN32
44 # pragma comment(lib, "libmysql.lib")
45 #endif
46
/* VERSION 3 API: With nonblocking (threaded) requests */

/* THE NONBLOCKING MYSQL API!
 *
 * MySQL provides no nonblocking (asynchronous) API of its own, and its developers recommend
 * that instead, you should thread your program. This is what I've done here to allow for
 * asynchronous SQL requests via MySQL. The way this works is as follows:
 *
 * The module spawns a worker thread (DispatcherThread, a SocketThread) and performs its MySQL
 * queries in that thread, using a first-in, first-out queue. A mutex guards the queue so that
 * the two threads never adjust it at the same time, which would crash the ircd. While the queue
 * is empty the worker thread sleeps; when the ircd thread submits a request it wakes the worker,
 * which then processes the request at the head of the queue. This blocks the worker thread but
 * leaves the ircd thread to go about its business as usual. During this period, the ircd thread
 * is able to insert further pending requests into the queue.
 *
 * Once the processing of a request is complete, it is removed from the incoming queue and its
 * result is placed on an outgoing (result) queue. The worker thread then notifies the ircd
 * thread (via NotifyParent()) that a result is available.
 *
 * The ircd thread then locks the queue once more, takes the pending responses off the result
 * queue, and sends each one on its way to the original calling module.
 *
 * XXX: You might be asking "why doesn't it just send the response from within the worker thread?"
 * The answer to this is simple. The majority of InspIRCd, and in fact most ircds, are not
 * thread-safe. This module is designed to be thread-safe and is careful with its use of threads.
 * However, if we were to call a module's result callback (OnResult/OnError) from a thread other
 * than the one the module was originally instantiated upon, there is a chance of all hell breaking
 * loose if a module is ever put in a re-entrant state (stack corruption, crashes, data corruption,
 * and worse could occur). So DON'T think about it until the day comes when InspIRCd is 100%
 * guaranteed thread-safe!
 */
80
// Forward declarations: the queue items and the module class below refer to
// these types before they are defined.
class SQLConnection;
class MySQLresult;
class DispatcherThread;
84
85 struct QQueueItem
86 {
87         SQL::Query* q;
88         std::string query;
89         SQLConnection* c;
90         QQueueItem(SQL::Query* Q, const std::string& S, SQLConnection* C) : q(Q), query(S), c(C) {}
91 };
92
93 struct RQueueItem
94 {
95         SQL::Query* q;
96         MySQLresult* r;
97         RQueueItem(SQL::Query* Q, MySQLresult* R) : q(Q), r(R) {}
98 };
99
/** Configured MySQL connections, keyed by the id of their <database> tag. */
typedef insp::flat_map<std::string, SQLConnection*> ConnMap;
/** Queries waiting to run; the front entry is the one the dispatcher thread picks up (and may be executing). */
typedef std::deque<QQueueItem> QueryQueue;
/** Completed queries waiting to be delivered on the main thread by DispatcherThread::OnNotify(). */
typedef std::deque<RQueueItem> ResultQueue;
103
/** MySQL module.
 * Owns the dispatcher (worker) thread, the query and result queues shared with
 * it, and the connections configured via <database module="mysql"> tags.
 */
class ModuleSQL : public Module
{
 public:
	DispatcherThread* Dispatcher; // worker thread; NULL until init() creates it
	QueryQueue qq;       // MUST HOLD MUTEX (the dispatcher's queue lock): queries waiting to run
	ResultQueue rq;      // MUST HOLD MUTEX: finished queries waiting to be delivered
	ConnMap connections; // main thread only

	ModuleSQL();
	void init() CXX11_OVERRIDE;
	~ModuleSQL();
	void ReadConfig(ConfigStatus& status) CXX11_OVERRIDE;
	void OnUnloadModule(Module* mod) CXX11_OVERRIDE;
	Version GetVersion() CXX11_OVERRIDE;
};
121
/** Worker thread that executes queued MySQL queries.
 * Run() executes on the worker thread. OnNotify() delivers finished results on
 * the main thread; it is also called directly by ModuleSQL when unloading.
 */
class DispatcherThread : public SocketThread
{
 private:
	/** The module whose queues (qq/rq) this thread services. */
	ModuleSQL* const Parent;
 public:
	DispatcherThread(ModuleSQL* CreatorModule) : Parent(CreatorModule) { }
	~DispatcherThread() { }
	void Run() CXX11_OVERRIDE;
	void OnNotify() CXX11_OVERRIDE;
};
132
// Compatibility shim: for client libraries older than 3.22.24 (or that do not
// define MYSQL_VERSION_ID), map mysql_field_count onto mysql_num_fields.
// Presumably those versions lack mysql_field_count() — TODO confirm.
#if !defined(MYSQL_VERSION_ID) || MYSQL_VERSION_ID<32224
#define mysql_field_count mysql_num_fields
#endif
136
137 /** Represents a mysql result set
138  */
139 class MySQLresult : public SQL::Result
140 {
141  public:
142         SQL::Error err;
143         int currentrow;
144         int rows;
145         std::vector<std::string> colnames;
146         std::vector<SQL::Row> fieldlists;
147
148         MySQLresult(MYSQL_RES* res, int affected_rows) : err(SQL::SUCCESS), currentrow(0), rows(0)
149         {
150                 if (affected_rows >= 1)
151                 {
152                         rows = affected_rows;
153                         fieldlists.resize(rows);
154                 }
155                 unsigned int field_count = 0;
156                 if (res)
157                 {
158                         MYSQL_ROW row;
159                         int n = 0;
160                         while ((row = mysql_fetch_row(res)))
161                         {
162                                 if (fieldlists.size() < (unsigned int)rows+1)
163                                 {
164                                         fieldlists.resize(fieldlists.size()+1);
165                                 }
166                                 field_count = 0;
167                                 MYSQL_FIELD *fields = mysql_fetch_fields(res);
168                                 if(mysql_num_fields(res) == 0)
169                                         break;
170                                 if (fields && mysql_num_fields(res))
171                                 {
172                                         colnames.clear();
173                                         while (field_count < mysql_num_fields(res))
174                                         {
175                                                 std::string a = (fields[field_count].name ? fields[field_count].name : "");
176                                                 if (row[field_count])
177                                                         fieldlists[n].push_back(SQL::Field(row[field_count]));
178                                                 else
179                                                         fieldlists[n].push_back(SQL::Field());
180                                                 colnames.push_back(a);
181                                                 field_count++;
182                                         }
183                                         n++;
184                                 }
185                                 rows++;
186                         }
187                         mysql_free_result(res);
188                 }
189         }
190
191         MySQLresult(SQL::Error& e) : err(e)
192         {
193
194         }
195
196         int Rows() CXX11_OVERRIDE
197         {
198                 return rows;
199         }
200
201         void GetCols(std::vector<std::string>& result) CXX11_OVERRIDE
202         {
203                 result.assign(colnames.begin(), colnames.end());
204         }
205
206         bool HasColumn(const std::string& column, size_t& index) CXX11_OVERRIDE
207         {
208                 for (size_t i = 0; i < colnames.size(); ++i)
209                 {
210                         if (colnames[i] == column)
211                         {
212                                 index = i;
213                                 return true;
214                         }
215                 }
216                 return false;
217         }
218
219         SQL::Field GetValue(int row, int column)
220         {
221                 if ((row >= 0) && (row < rows) && (column >= 0) && (column < (int)fieldlists[row].size()))
222                 {
223                         return fieldlists[row][column];
224                 }
225                 return SQL::Field();
226         }
227
228         bool GetRow(SQL::Row& result) CXX11_OVERRIDE
229         {
230                 if (currentrow < rows)
231                 {
232                         result.assign(fieldlists[currentrow].begin(), fieldlists[currentrow].end());
233                         currentrow++;
234                         return true;
235                 }
236                 else
237                 {
238                         result.clear();
239                         return false;
240                 }
241         }
242 };
243
244 /** Represents a connection to a mysql database
245  */
246 class SQLConnection : public SQL::Provider
247 {
248  public:
249         reference<ConfigTag> config;
250         MYSQL *connection;
251         Mutex lock;
252
253         // This constructor creates an SQLConnection object with the given credentials, but does not connect yet.
254         SQLConnection(Module* p, ConfigTag* tag) : SQL::Provider(p, "SQL/" + tag->getString("id")),
255                 config(tag), connection(NULL)
256         {
257         }
258
259         ~SQLConnection()
260         {
261                 Close();
262         }
263
264         // This method connects to the database using the credentials supplied to the constructor, and returns
265         // true upon success.
266         bool Connect()
267         {
268                 unsigned int timeout = 1;
269                 connection = mysql_init(connection);
270                 mysql_options(connection,MYSQL_OPT_CONNECT_TIMEOUT,(char*)&timeout);
271                 std::string host = config->getString("host");
272                 std::string user = config->getString("user");
273                 std::string pass = config->getString("pass");
274                 std::string dbname = config->getString("name");
275                 unsigned int port = config->getUInt("port", 3306);
276                 bool rv = mysql_real_connect(connection, host.c_str(), user.c_str(), pass.c_str(), dbname.c_str(), port, NULL, 0);
277                 if (!rv)
278                         return rv;
279
280                 // Enable character set settings
281                 std::string charset = config->getString("charset");
282                 if ((!charset.empty()) && (mysql_set_character_set(connection, charset.c_str())))
283                         ServerInstance->Logs->Log(MODNAME, LOG_DEFAULT, "WARNING: Could not set character set to \"%s\"", charset.c_str());
284
285                 std::string initquery;
286                 if (config->readString("initialquery", initquery))
287                 {
288                         mysql_query(connection,initquery.c_str());
289                 }
290                 return true;
291         }
292
293         ModuleSQL* Parent()
294         {
295                 return (ModuleSQL*)(Module*)creator;
296         }
297
298         MySQLresult* DoBlockingQuery(const std::string& query)
299         {
300
301                 /* Parse the command string and dispatch it to mysql */
302                 if (CheckConnection() && !mysql_real_query(connection, query.data(), query.length()))
303                 {
304                         /* Successfull query */
305                         MYSQL_RES* res = mysql_use_result(connection);
306                         unsigned long rows = mysql_affected_rows(connection);
307                         return new MySQLresult(res, rows);
308                 }
309                 else
310                 {
311                         /* XXX: See /usr/include/mysql/mysqld_error.h for a list of
312                          * possible error numbers and error messages */
313                         SQL::Error e(SQL::QREPLY_FAIL, InspIRCd::Format("%u: %s", mysql_errno(connection), mysql_error(connection)));
314                         return new MySQLresult(e);
315                 }
316         }
317
318         bool CheckConnection()
319         {
320                 if (!connection || mysql_ping(connection) != 0)
321                         return Connect();
322                 return true;
323         }
324
325         std::string GetError()
326         {
327                 return mysql_error(connection);
328         }
329
330         void Close()
331         {
332                 mysql_close(connection);
333         }
334
335         void Submit(SQL::Query* q, const std::string& qs) CXX11_OVERRIDE
336         {
337                 ServerInstance->Logs->Log(MODNAME, LOG_DEBUG, "Executing MySQL query: " + qs);
338                 Parent()->Dispatcher->LockQueue();
339                 Parent()->qq.push_back(QQueueItem(q, qs, this));
340                 Parent()->Dispatcher->UnlockQueueWakeup();
341         }
342
343         void Submit(SQL::Query* call, const std::string& q, const SQL::ParamList& p) CXX11_OVERRIDE
344         {
345                 std::string res;
346                 unsigned int param = 0;
347                 for(std::string::size_type i = 0; i < q.length(); i++)
348                 {
349                         if (q[i] != '?')
350                                 res.push_back(q[i]);
351                         else
352                         {
353                                 if (param < p.size())
354                                 {
355                                         std::string parm = p[param++];
356                                         // In the worst case, each character may need to be encoded as using two bytes,
357                                         // and one byte is the terminating null
358                                         std::vector<char> buffer(parm.length() * 2 + 1);
359
360                                         // The return value of mysql_real_escape_string() is the length of the encoded string,
361                                         // not including the terminating null
362                                         unsigned long escapedsize = mysql_real_escape_string(connection, &buffer[0], parm.c_str(), parm.length());
363                                         res.append(&buffer[0], escapedsize);
364                                 }
365                         }
366                 }
367                 Submit(call, res);
368         }
369
370         void Submit(SQL::Query* call, const std::string& q, const SQL::ParamMap& p) CXX11_OVERRIDE
371         {
372                 std::string res;
373                 for(std::string::size_type i = 0; i < q.length(); i++)
374                 {
375                         if (q[i] != '$')
376                                 res.push_back(q[i]);
377                         else
378                         {
379                                 std::string field;
380                                 i++;
381                                 while (i < q.length() && isalnum(q[i]))
382                                         field.push_back(q[i++]);
383                                 i--;
384
385                                 SQL::ParamMap::const_iterator it = p.find(field);
386                                 if (it != p.end())
387                                 {
388                                         std::string parm = it->second;
389                                         // NOTE: See above
390                                         std::vector<char> buffer(parm.length() * 2 + 1);
391                                         unsigned long escapedsize = mysql_escape_string(&buffer[0], parm.c_str(), parm.length());
392                                         res.append(&buffer[0], escapedsize);
393                                 }
394                         }
395                 }
396                 Submit(call, res);
397         }
398 };
399
400 ModuleSQL::ModuleSQL()
401 {
402         Dispatcher = NULL;
403 }
404
405 void ModuleSQL::init()
406 {
407         Dispatcher = new DispatcherThread(this);
408         ServerInstance->Threads.Start(Dispatcher);
409 }
410
411 ModuleSQL::~ModuleSQL()
412 {
413         if (Dispatcher)
414         {
415                 Dispatcher->join();
416                 Dispatcher->OnNotify();
417                 delete Dispatcher;
418         }
419         for(ConnMap::iterator i = connections.begin(); i != connections.end(); i++)
420         {
421                 delete i->second;
422         }
423 }
424
425 void ModuleSQL::ReadConfig(ConfigStatus& status)
426 {
427         ConnMap conns;
428         ConfigTagList tags = ServerInstance->Config->ConfTags("database");
429         for(ConfigIter i = tags.first; i != tags.second; i++)
430         {
431                 if (!stdalgo::string::equalsci(i->second->getString("module"), "mysql"))
432                         continue;
433                 std::string id = i->second->getString("id");
434                 ConnMap::iterator curr = connections.find(id);
435                 if (curr == connections.end())
436                 {
437                         SQLConnection* conn = new SQLConnection(this, i->second);
438                         conns.insert(std::make_pair(id, conn));
439                         ServerInstance->Modules->AddService(*conn);
440                 }
441                 else
442                 {
443                         conns.insert(*curr);
444                         connections.erase(curr);
445                 }
446         }
447
448         // now clean up the deleted databases
449         Dispatcher->LockQueue();
450         SQL::Error err(SQL::BAD_DBID);
451         for(ConnMap::iterator i = connections.begin(); i != connections.end(); i++)
452         {
453                 ServerInstance->Modules->DelService(*i->second);
454                 // it might be running a query on this database. Wait for that to complete
455                 i->second->lock.Lock();
456                 i->second->lock.Unlock();
457                 // now remove all active queries to this DB
458                 for (size_t j = qq.size(); j > 0; j--)
459                 {
460                         size_t k = j - 1;
461                         if (qq[k].c == i->second)
462                         {
463                                 qq[k].q->OnError(err);
464                                 delete qq[k].q;
465                                 qq.erase(qq.begin() + k);
466                         }
467                 }
468                 // finally, nuke the connection
469                 delete i->second;
470         }
471         Dispatcher->UnlockQueue();
472         connections.swap(conns);
473 }
474
475 void ModuleSQL::OnUnloadModule(Module* mod)
476 {
477         SQL::Error err(SQL::BAD_DBID);
478         Dispatcher->LockQueue();
479         unsigned int i = qq.size();
480         while (i > 0)
481         {
482                 i--;
483                 if (qq[i].q->creator == mod)
484                 {
485                         if (i == 0)
486                         {
487                                 // need to wait until the query is done
488                                 // (the result will be discarded)
489                                 qq[i].c->lock.Lock();
490                                 qq[i].c->lock.Unlock();
491                         }
492                         qq[i].q->OnError(err);
493                         delete qq[i].q;
494                         qq.erase(qq.begin() + i);
495                 }
496         }
497         Dispatcher->UnlockQueue();
498         // clean up any result queue entries
499         Dispatcher->OnNotify();
500 }
501
502 Version ModuleSQL::GetVersion()
503 {
504         return Version("Provides MySQL support", VF_VENDOR);
505 }
506
/** Worker thread loop.
 * Repeatedly takes the query at the front of Parent->qq and runs it with only
 * that connection's lock held. It then moves the result onto Parent->rq and
 * notifies the main thread. The loop exits once the exit flag is set.
 */
void DispatcherThread::Run()
{
	// The queue lock is held throughout, except while a query executes and
	// while WaitForQueue() sleeps (which presumably releases it — the main
	// thread must be able to lock the queue to submit work).
	this->LockQueue();
	while (!this->GetExitFlag())
	{
		if (!Parent->qq.empty())
		{
			// Copy the item: the main thread may erase it from qq while the query runs.
			QQueueItem i = Parent->qq.front();
			// Take the connection lock before dropping the queue lock. Main-thread code
			// that holds the queue lock can then wait on it to learn when the query is done.
			i.c->lock.Lock();
			this->UnlockQueue();
			MySQLresult* res = i.c->DoBlockingQuery(i.query);
			i.c->lock.Unlock();

			/*
			 * At this point, the main thread could be working on:
			 *  Rehash - delete i.c out from under us. We don't care about that.
			 *  UnloadModule - delete i.q and the qq item. Need to avoid reporting results.
			 */

			this->LockQueue();
			// Deliver the result only if our query is still at the front, i.e. the
			// main thread did not remove it while it was running.
			if (!Parent->qq.empty() && Parent->qq.front().q == i.q)
			{
				Parent->qq.pop_front();
				Parent->rq.push_back(RQueueItem(i.q, res));
				NotifyParent();
			}
			else
			{
				// UnloadModule ate the query
				delete res;
			}
		}
		else
		{
			/* We know the queue is empty, we can safely hang this thread until
			 * something happens
			 */
			this->WaitForQueue();
		}
	}
	this->UnlockQueue();
}
549
550 void DispatcherThread::OnNotify()
551 {
552         // this could unlock during the dispatch, but OnResult isn't expected to take that long
553         this->LockQueue();
554         for(ResultQueue::iterator i = Parent->rq.begin(); i != Parent->rq.end(); i++)
555         {
556                 MySQLresult* res = i->r;
557                 if (res->err.code == SQL::SUCCESS)
558                         i->q->OnResult(*res);
559                 else
560                         i->q->OnError(res->err);
561                 delete i->q;
562                 delete i->r;
563         }
564         Parent->rq.clear();
565         this->UnlockQueue();
566 }
567
// Declares ModuleSQL as this file's module class (MODULE_INIT is a core macro,
// presumably expanding to the module's factory entry point — see the core headers).
MODULE_INIT(ModuleSQL)