]> git.netwichtig.de Git - user/henk/code/inspircd.git/blob - src/modules/extra/m_mysql.cpp
Change allocation of InspIRCd::Threads to be physically part of the object containing it
[user/henk/code/inspircd.git] / src / modules / extra / m_mysql.cpp
1 /*
2  * InspIRCd -- Internet Relay Chat Daemon
3  *
4  *   Copyright (C) 2009-2010 Daniel De Graaf <danieldg@inspircd.org>
5  *   Copyright (C) 2006-2007, 2009 Dennis Friis <peavey@inspircd.org>
6  *   Copyright (C) 2006-2009 Craig Edwards <craigedwards@brainbox.cc>
7  *   Copyright (C) 2008 Robin Burchell <robin+git@viroteck.net>
8  *
9  * This file is part of InspIRCd.  InspIRCd is free software: you can
10  * redistribute it and/or modify it under the terms of the GNU General Public
11  * License as published by the Free Software Foundation, version 2.
12  *
13  * This program is distributed in the hope that it will be useful, but WITHOUT
14  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
15  * FOR A PARTICULAR PURPOSE.  See the GNU General Public License for more
16  * details.
17  *
18  * You should have received a copy of the GNU General Public License
19  * along with this program.  If not, see <http://www.gnu.org/licenses/>.
20  */
21
22
23 /* Stop mysql wanting to use long long */
24 #define NO_CLIENT_LONG_LONG
25
26 #include "inspircd.h"
27 #include <mysql.h>
28 #include "modules/sql.h"
29
30 #ifdef _WIN32
31 # pragma comment(lib, "mysqlclient.lib")
32 # pragma comment(lib, "advapi32.lib")
33 # pragma comment(linker, "/NODEFAULTLIB:LIBCMT")
34 #endif
35
36 /* VERSION 3 API: With nonblocking (threaded) requests */
37
38 /* $CompileFlags: exec("mysql_config --include") */
39 /* $LinkerFlags: exec("mysql_config --libs_r") rpath("mysql_config --libs_r") */
40
41 /* THE NONBLOCKING MYSQL API!
42  *
 * MySQL provides no nonblocking (asynchronous) API of its own, and its developers recommend
 * that you thread your program instead. This is what I've done here to allow for
 * asynchronous SQL requests via MySQL. The way this works is as follows:
46  *
 * The module spawns a worker thread (DispatcherThread, a SocketThread) and performs its
 * MySQL queries in that thread, using a first-in, first-out request queue. The queue is
 * protected by a mutex which prevents the two threads from adjusting it at the same time
 * and crashing the ircd. While the queue is empty the worker thread sleeps until a new
 * request is submitted; otherwise it takes the request at the head of the queue and
 * processes it, blocking the worker thread but leaving the ircd thread to go about its
 * business as usual. During this period, the ircd thread is able to insert further
 * pending requests into the queue.
54  *
 * Once the processing of a request is complete, it is moved from the incoming queue to
 * an outgoing (result) queue together with its result. The worker thread then notifies
 * the ircd thread that a result is available (NotifyParent(), via the SocketThread
 * notification mechanism), which causes OnNotify() to run in the ircd thread.
59  *
 * The ircd thread then locks the queue once more, takes the outbound responses off the
 * result queue, and sends each one on its way to the original calling module.
62  *
 * XXX: You might be asking "why doesn't he just send the response from within the worker thread?"
 * The answer to this is simple. The majority of InspIRCd, and in fact most ircds, are not
 * threadsafe. This module is designed to be threadsafe and is careful with its use of threads;
 * however, if we were to call a module's result callbacks (OnResult/OnError) from within a
 * thread other than the one the module was originally instantiated upon, there is a chance
 * of all hell breaking loose if a module is ever put in a re-entrant state (stack corruption,
 * crashes, data corruption, and worse could occur), so DON'T think about it until the day
 * comes when InspIRCd is 100% guaranteed threadsafe!
71  *
72  * For a diagram of this system please see http://wiki.inspircd.org/Mysql2
73  */
74
// Forward declarations: the queue item structs and ModuleSQL below refer to
// these classes by pointer before their definitions appear later in the file.
class SQLConnection;
class MySQLresult;
class DispatcherThread;
78
79 struct QQueueItem
80 {
81         SQLQuery* q;
82         std::string query;
83         SQLConnection* c;
84         QQueueItem(SQLQuery* Q, const std::string& S, SQLConnection* C) : q(Q), query(S), c(C) {}
85 };
86
87 struct RQueueItem
88 {
89         SQLQuery* q;
90         MySQLresult* r;
91         RQueueItem(SQLQuery* Q, MySQLresult* R) : q(Q), r(R) {}
92 };
93
// Connections keyed by the id attribute of their <database> tag.
typedef std::map<std::string, SQLConnection*> ConnMap;
// Queries waiting for the worker thread; the worker always executes the front entry.
typedef std::deque<QQueueItem> QueryQueue;
// Executed queries (with their results) waiting to be delivered by DispatcherThread::OnNotify().
typedef std::deque<RQueueItem> ResultQueue;
97
98 /** MySQL module
99  *  */
100 class ModuleSQL : public Module
101 {
102  public:
103         DispatcherThread* Dispatcher;
104         QueryQueue qq;       // MUST HOLD MUTEX
105         ResultQueue rq;      // MUST HOLD MUTEX
106         ConnMap connections; // main thread only
107
108         ModuleSQL();
109         void init() CXX11_OVERRIDE;
110         ~ModuleSQL();
111         void ReadConfig(ConfigStatus& status) CXX11_OVERRIDE;
112         void OnUnloadModule(Module* mod) CXX11_OVERRIDE;
113         Version GetVersion() CXX11_OVERRIDE;
114 };
115
116 class DispatcherThread : public SocketThread
117 {
118  private:
119         ModuleSQL* const Parent;
120  public:
121         DispatcherThread(ModuleSQL* CreatorModule) : Parent(CreatorModule) { }
122         ~DispatcherThread() { }
123         void Run();
124         void OnNotify();
125 };
126
/* Compatibility shim for client libraries that predate 3.22.24 (MYSQL_VERSION_ID
 * 32224) or do not define MYSQL_VERSION_ID at all; such libraries presumably lack
 * mysql_field_count(). Nothing in this file calls mysql_field_count() directly. */
#if !defined(MYSQL_VERSION_ID) || MYSQL_VERSION_ID<32224
#define mysql_field_count mysql_num_fields
#endif
130
131 /** Represents a mysql result set
132  */
133 class MySQLresult : public SQLResult
134 {
135  public:
136         SQLerror err;
137         int currentrow;
138         int rows;
139         std::vector<std::string> colnames;
140         std::vector<SQLEntries> fieldlists;
141
142         MySQLresult(MYSQL_RES* res, int affected_rows) : err(SQL_NO_ERROR), currentrow(0), rows(0)
143         {
144                 if (affected_rows >= 1)
145                 {
146                         rows = affected_rows;
147                         fieldlists.resize(rows);
148                 }
149                 unsigned int field_count = 0;
150                 if (res)
151                 {
152                         MYSQL_ROW row;
153                         int n = 0;
154                         while ((row = mysql_fetch_row(res)))
155                         {
156                                 if (fieldlists.size() < (unsigned int)rows+1)
157                                 {
158                                         fieldlists.resize(fieldlists.size()+1);
159                                 }
160                                 field_count = 0;
161                                 MYSQL_FIELD *fields = mysql_fetch_fields(res);
162                                 if(mysql_num_fields(res) == 0)
163                                         break;
164                                 if (fields && mysql_num_fields(res))
165                                 {
166                                         colnames.clear();
167                                         while (field_count < mysql_num_fields(res))
168                                         {
169                                                 std::string a = (fields[field_count].name ? fields[field_count].name : "");
170                                                 if (row[field_count])
171                                                         fieldlists[n].push_back(SQLEntry(row[field_count]));
172                                                 else
173                                                         fieldlists[n].push_back(SQLEntry());
174                                                 colnames.push_back(a);
175                                                 field_count++;
176                                         }
177                                         n++;
178                                 }
179                                 rows++;
180                         }
181                         mysql_free_result(res);
182                 }
183         }
184
185         MySQLresult(SQLerror& e) : err(e)
186         {
187
188         }
189
190         int Rows()
191         {
192                 return rows;
193         }
194
195         void GetCols(std::vector<std::string>& result)
196         {
197                 result.assign(colnames.begin(), colnames.end());
198         }
199
200         SQLEntry GetValue(int row, int column)
201         {
202                 if ((row >= 0) && (row < rows) && (column >= 0) && (column < (int)fieldlists[row].size()))
203                 {
204                         return fieldlists[row][column];
205                 }
206                 return SQLEntry();
207         }
208
209         bool GetRow(SQLEntries& result)
210         {
211                 if (currentrow < rows)
212                 {
213                         result.assign(fieldlists[currentrow].begin(), fieldlists[currentrow].end());
214                         currentrow++;
215                         return true;
216                 }
217                 else
218                 {
219                         result.clear();
220                         return false;
221                 }
222         }
223 };
224
225 /** Represents a connection to a mysql database
226  */
227 class SQLConnection : public SQLProvider
228 {
229  public:
230         reference<ConfigTag> config;
231         MYSQL *connection;
232         Mutex lock;
233
234         // This constructor creates an SQLConnection object with the given credentials, but does not connect yet.
235         SQLConnection(Module* p, ConfigTag* tag) : SQLProvider(p, "SQL/" + tag->getString("id")),
236                 config(tag), connection(NULL)
237         {
238         }
239
240         ~SQLConnection()
241         {
242                 Close();
243         }
244
245         // This method connects to the database using the credentials supplied to the constructor, and returns
246         // true upon success.
247         bool Connect()
248         {
249                 unsigned int timeout = 1;
250                 connection = mysql_init(connection);
251                 mysql_options(connection,MYSQL_OPT_CONNECT_TIMEOUT,(char*)&timeout);
252                 std::string host = config->getString("host");
253                 std::string user = config->getString("user");
254                 std::string pass = config->getString("pass");
255                 std::string dbname = config->getString("name");
256                 int port = config->getInt("port");
257                 bool rv = mysql_real_connect(connection, host.c_str(), user.c_str(), pass.c_str(), dbname.c_str(), port, NULL, 0);
258                 if (!rv)
259                         return rv;
260                 std::string initquery;
261                 if (config->readString("initialquery", initquery))
262                 {
263                         mysql_query(connection,initquery.c_str());
264                 }
265                 return true;
266         }
267
268         ModuleSQL* Parent()
269         {
270                 return (ModuleSQL*)(Module*)creator;
271         }
272
273         MySQLresult* DoBlockingQuery(const std::string& query)
274         {
275
276                 /* Parse the command string and dispatch it to mysql */
277                 if (CheckConnection() && !mysql_real_query(connection, query.data(), query.length()))
278                 {
279                         /* Successfull query */
280                         MYSQL_RES* res = mysql_use_result(connection);
281                         unsigned long rows = mysql_affected_rows(connection);
282                         return new MySQLresult(res, rows);
283                 }
284                 else
285                 {
286                         /* XXX: See /usr/include/mysql/mysqld_error.h for a list of
287                          * possible error numbers and error messages */
288                         SQLerror e(SQL_QREPLY_FAIL, ConvToStr(mysql_errno(connection)) + ": " + mysql_error(connection));
289                         return new MySQLresult(e);
290                 }
291         }
292
293         bool CheckConnection()
294         {
295                 if (!connection || mysql_ping(connection) != 0)
296                         return Connect();
297                 return true;
298         }
299
300         std::string GetError()
301         {
302                 return mysql_error(connection);
303         }
304
305         void Close()
306         {
307                 mysql_close(connection);
308         }
309
310         void submit(SQLQuery* q, const std::string& qs)
311         {
312                 Parent()->Dispatcher->LockQueue();
313                 Parent()->qq.push_back(QQueueItem(q, qs, this));
314                 Parent()->Dispatcher->UnlockQueueWakeup();
315         }
316
317         void submit(SQLQuery* call, const std::string& q, const ParamL& p)
318         {
319                 std::string res;
320                 unsigned int param = 0;
321                 for(std::string::size_type i = 0; i < q.length(); i++)
322                 {
323                         if (q[i] != '?')
324                                 res.push_back(q[i]);
325                         else
326                         {
327                                 if (param < p.size())
328                                 {
329                                         std::string parm = p[param++];
330                                         // In the worst case, each character may need to be encoded as using two bytes,
331                                         // and one byte is the terminating null
332                                         std::vector<char> buffer(parm.length() * 2 + 1);
333
334                                         // The return value of mysql_escape_string() is the length of the encoded string,
335                                         // not including the terminating null
336                                         unsigned long escapedsize = mysql_escape_string(&buffer[0], parm.c_str(), parm.length());
337 //                                      mysql_real_escape_string(connection, queryend, paramscopy[paramnum].c_str(), paramscopy[paramnum].length());
338                                         res.append(&buffer[0], escapedsize);
339                                 }
340                         }
341                 }
342                 submit(call, res);
343         }
344
345         void submit(SQLQuery* call, const std::string& q, const ParamM& p)
346         {
347                 std::string res;
348                 for(std::string::size_type i = 0; i < q.length(); i++)
349                 {
350                         if (q[i] != '$')
351                                 res.push_back(q[i]);
352                         else
353                         {
354                                 std::string field;
355                                 i++;
356                                 while (i < q.length() && isalnum(q[i]))
357                                         field.push_back(q[i++]);
358                                 i--;
359
360                                 ParamM::const_iterator it = p.find(field);
361                                 if (it != p.end())
362                                 {
363                                         std::string parm = it->second;
364                                         // NOTE: See above
365                                         std::vector<char> buffer(parm.length() * 2 + 1);
366                                         unsigned long escapedsize = mysql_escape_string(&buffer[0], parm.c_str(), parm.length());
367                                         res.append(&buffer[0], escapedsize);
368                                 }
369                         }
370                 }
371                 submit(call, res);
372         }
373 };
374
375 ModuleSQL::ModuleSQL()
376 {
377         Dispatcher = NULL;
378 }
379
380 void ModuleSQL::init()
381 {
382         Dispatcher = new DispatcherThread(this);
383         ServerInstance->Threads.Start(Dispatcher);
384 }
385
386 ModuleSQL::~ModuleSQL()
387 {
388         if (Dispatcher)
389         {
390                 Dispatcher->join();
391                 Dispatcher->OnNotify();
392                 delete Dispatcher;
393         }
394         for(ConnMap::iterator i = connections.begin(); i != connections.end(); i++)
395         {
396                 delete i->second;
397         }
398 }
399
400 void ModuleSQL::ReadConfig(ConfigStatus& status)
401 {
402         ConnMap conns;
403         ConfigTagList tags = ServerInstance->Config->ConfTags("database");
404         for(ConfigIter i = tags.first; i != tags.second; i++)
405         {
406                 if (i->second->getString("module", "mysql") != "mysql")
407                         continue;
408                 std::string id = i->second->getString("id");
409                 ConnMap::iterator curr = connections.find(id);
410                 if (curr == connections.end())
411                 {
412                         SQLConnection* conn = new SQLConnection(this, i->second);
413                         conns.insert(std::make_pair(id, conn));
414                         ServerInstance->Modules->AddService(*conn);
415                 }
416                 else
417                 {
418                         conns.insert(*curr);
419                         connections.erase(curr);
420                 }
421         }
422
423         // now clean up the deleted databases
424         Dispatcher->LockQueue();
425         SQLerror err(SQL_BAD_DBID);
426         for(ConnMap::iterator i = connections.begin(); i != connections.end(); i++)
427         {
428                 ServerInstance->Modules->DelService(*i->second);
429                 // it might be running a query on this database. Wait for that to complete
430                 i->second->lock.Lock();
431                 i->second->lock.Unlock();
432                 // now remove all active queries to this DB
433                 for (size_t j = qq.size(); j > 0; j--)
434                 {
435                         size_t k = j - 1;
436                         if (qq[k].c == i->second)
437                         {
438                                 qq[k].q->OnError(err);
439                                 delete qq[k].q;
440                                 qq.erase(qq.begin() + k);
441                         }
442                 }
443                 // finally, nuke the connection
444                 delete i->second;
445         }
446         Dispatcher->UnlockQueue();
447         connections.swap(conns);
448 }
449
450 void ModuleSQL::OnUnloadModule(Module* mod)
451 {
452         SQLerror err(SQL_BAD_DBID);
453         Dispatcher->LockQueue();
454         unsigned int i = qq.size();
455         while (i > 0)
456         {
457                 i--;
458                 if (qq[i].q->creator == mod)
459                 {
460                         if (i == 0)
461                         {
462                                 // need to wait until the query is done
463                                 // (the result will be discarded)
464                                 qq[i].c->lock.Lock();
465                                 qq[i].c->lock.Unlock();
466                         }
467                         qq[i].q->OnError(err);
468                         delete qq[i].q;
469                         qq.erase(qq.begin() + i);
470                 }
471         }
472         Dispatcher->UnlockQueue();
473         // clean up any result queue entries
474         Dispatcher->OnNotify();
475 }
476
477 Version ModuleSQL::GetVersion()
478 {
479         return Version("MySQL support", VF_VENDOR);
480 }
481
/** Worker-thread main loop.
 * Repeatedly takes the query at the front of Parent->qq, runs it while holding
 * only the connection's lock, then moves it (with its result) to Parent->rq and
 * notifies the parent. Sleeps in WaitForQueue() while the queue is empty.
 * Returns once GetExitFlag() is set.
 */
void DispatcherThread::Run()
{
	this->LockQueue();
	while (!this->GetExitFlag())
	{
		if (!Parent->qq.empty())
		{
			// Copy the item: the main thread may erase it from qq while we work.
			QQueueItem i = Parent->qq.front();
			// Take the connection lock before dropping the queue lock, so the main
			// thread can wait for this query by locking/unlocking i.c->lock.
			i.c->lock.Lock();
			this->UnlockQueue();
			MySQLresult* res = i.c->DoBlockingQuery(i.query);
			i.c->lock.Unlock();

			/*
			 * At this point, the main thread could be working on:
			 *  Rehash - delete i.c out from under us. We don't care about that.
			 *  UnloadModule - delete i.q and the qq item. Need to avoid reporting results.
			 */

			this->LockQueue();
			// NOTE(review): this identifies the query by pointer only. If i.q was
			// deleted and a new SQLQuery allocated at the same address reached the
			// front, its owner would receive this result. Confirm whether that can occur.
			if (!Parent->qq.empty() && Parent->qq.front().q == i.q)
			{
				Parent->qq.pop_front();
				Parent->rq.push_back(RQueueItem(i.q, res));
				NotifyParent();
			}
			else
			{
				// UnloadModule (or a rehash removing its database) ate the query
				delete res;
			}
		}
		else
		{
			/* We know the queue is empty, we can safely hang this thread until
			 * something happens (presumably WaitForQueue() releases the queue
			 * lock while sleeping - submit() wakes us via UnlockQueueWakeup())
			 */
			this->WaitForQueue();
		}
	}
	this->UnlockQueue();
}
524
525 void DispatcherThread::OnNotify()
526 {
527         // this could unlock during the dispatch, but OnResult isn't expected to take that long
528         this->LockQueue();
529         for(ResultQueue::iterator i = Parent->rq.begin(); i != Parent->rq.end(); i++)
530         {
531                 MySQLresult* res = i->r;
532                 if (res->err.id == SQL_NO_ERROR)
533                         i->q->OnResult(*res);
534                 else
535                         i->q->OnError(res->err);
536                 delete i->q;
537                 delete i->r;
538         }
539         Parent->rq.clear();
540         this->UnlockQueue();
541 }
542
// Module entry point: presumably expands to the factory the core uses to instantiate ModuleSQL.
MODULE_INIT(ModuleSQL)