]> git.netwichtig.de Git - user/henk/code/inspircd.git/blob - src/modules/extra/m_mysql.cpp
Merge branch 'master+listmode'
[user/henk/code/inspircd.git] / src / modules / extra / m_mysql.cpp
1 /*
2  * InspIRCd -- Internet Relay Chat Daemon
3  *
4  *   Copyright (C) 2009-2010 Daniel De Graaf <danieldg@inspircd.org>
5  *   Copyright (C) 2006-2007, 2009 Dennis Friis <peavey@inspircd.org>
6  *   Copyright (C) 2006-2009 Craig Edwards <craigedwards@brainbox.cc>
7  *   Copyright (C) 2008 Robin Burchell <robin+git@viroteck.net>
8  *
9  * This file is part of InspIRCd.  InspIRCd is free software: you can
10  * redistribute it and/or modify it under the terms of the GNU General Public
11  * License as published by the Free Software Foundation, version 2.
12  *
13  * This program is distributed in the hope that it will be useful, but WITHOUT
14  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
15  * FOR A PARTICULAR PURPOSE.  See the GNU General Public License for more
16  * details.
17  *
18  * You should have received a copy of the GNU General Public License
19  * along with this program.  If not, see <http://www.gnu.org/licenses/>.
20  */
21
22
23 /* Stop mysql wanting to use long long */
24 #define NO_CLIENT_LONG_LONG
25
26 #include "inspircd.h"
27 #include <mysql.h>
28 #include "modules/sql.h"
29
30 #ifdef _WIN32
31 # pragma comment(lib, "libmysql.lib")
32 #endif
33
34 /* VERSION 3 API: With nonblocking (threaded) requests */
35
36 /* $CompileFlags: exec("mysql_config --include") */
37 /* $LinkerFlags: exec("mysql_config --libs_r") rpath("mysql_config --libs_r") */
38
39 /* THE NONBLOCKING MYSQL API!
40  *
 * MySQL provides no nonblocking (asynchronous) API of its own, and its developers recommend
 * that you thread your program instead. That is what is done here to allow for
 * asynchronous SQL requests via MySQL. The way this works is as follows:
44  *
 * The module spawns a worker thread (DispatcherThread, a SocketThread) and performs its MySQL
 * queries in that thread, using a first-in, first-out queue. A mutex guards the queue so that
 * the two threads never adjust it at the same time, which would crash the ircd. When the queue
 * is empty, the worker thread sleeps until the ircd thread wakes it after queueing a request.
 * The worker then processes the request at the head of the queue, blocking the worker thread
 * but leaving the ircd thread to go about its business as usual. During this period, the ircd
 * thread is able to insert further pending requests into the queue.
52  *
 * Once the processing of a request is complete, it is moved from the incoming queue to
 * an outgoing queue, paired with its result as a 'response'. The worker thread then signals
 * the ircd thread (via SocketThread's NotifyParent()) that a result is available.
57  *
 * The ircd thread then locks the queue once more, takes the pending responses off the
 * outgoing queue, and sends each on its way to the original calling module.
60  *
 * XXX: You might be asking "why doesn't it just send the response from within the worker thread?"
 * The answer is simple: the majority of InspIRCd, and in fact most ircds, are not
 * thread-safe. This module is designed to be thread-safe and is careful with its use of threads;
 * however, if we were to call a module's result callback (OnResult/OnError) from within a thread
 * other than the one the module was originally instantiated on, there is a chance of all hell
 * breaking loose if a module is ever put in a re-entrant state (stack corruption, crashes, data
 * corruption, and worse could occur), so DON'T think about it until the day comes when InspIRCd
 * is 100% guaranteed thread-safe!
69  *
70  * For a diagram of this system please see http://wiki.inspircd.org/Mysql2
71  */
72
73 class SQLConnection;
74 class MySQLresult;
75 class DispatcherThread;
76
77 struct QQueueItem
78 {
79         SQLQuery* q;
80         std::string query;
81         SQLConnection* c;
82         QQueueItem(SQLQuery* Q, const std::string& S, SQLConnection* C) : q(Q), query(S), c(C) {}
83 };
84
85 struct RQueueItem
86 {
87         SQLQuery* q;
88         MySQLresult* r;
89         RQueueItem(SQLQuery* Q, MySQLresult* R) : q(Q), r(R) {}
90 };
91
92 typedef insp::flat_map<std::string, SQLConnection*> ConnMap;
93 typedef std::deque<QQueueItem> QueryQueue;
94 typedef std::deque<RQueueItem> ResultQueue;
95
/** MySQL module
 *
 * Owns the worker thread, the query/result queues shared with it, and one
 * SQLConnection (registered as service "SQL/<id>") per <database module="mysql"> tag.
 */
class ModuleSQL : public Module
{
 public:
        /** Worker thread executing queued queries; created in init(), NULL before that. */
        DispatcherThread* Dispatcher;
        QueryQueue qq;       // MUST HOLD MUTEX (Dispatcher->LockQueue()); queries waiting to run
        ResultQueue rq;      // MUST HOLD MUTEX (Dispatcher->LockQueue()); results waiting for delivery
        ConnMap connections; // main thread only

        ModuleSQL();
        void init() CXX11_OVERRIDE;
        ~ModuleSQL();
        void ReadConfig(ConfigStatus& status) CXX11_OVERRIDE;
        void OnUnloadModule(Module* mod) CXX11_OVERRIDE;
        Version GetVersion() CXX11_OVERRIDE;
};
113
/** Worker thread: executes queued queries (Run) and hands finished results back (OnNotify).
 */
class DispatcherThread : public SocketThread
{
 private:
        /** Module owning the query/result queues this thread works on. */
        ModuleSQL* const Parent;
 public:
        DispatcherThread(ModuleSQL* CreatorModule) : Parent(CreatorModule) { }
        ~DispatcherThread() { }
        /** Thread body: executes items from Parent->qq until the exit flag is set. */
        void Run();
        /** Delivers every result in Parent->rq to its query's OnResult/OnError and deletes both. */
        void OnNotify();
};
124
/* When building against a client library older than 3.22.24 (or one that does not define
 * MYSQL_VERSION_ID), alias mysql_field_count to mysql_num_fields -- presumably because such
 * libraries do not provide mysql_field_count(). */
#if !defined(MYSQL_VERSION_ID) || MYSQL_VERSION_ID<32224
#define mysql_field_count mysql_num_fields
#endif
128
129 /** Represents a mysql result set
130  */
131 class MySQLresult : public SQLResult
132 {
133  public:
134         SQLerror err;
135         int currentrow;
136         int rows;
137         std::vector<std::string> colnames;
138         std::vector<SQLEntries> fieldlists;
139
140         MySQLresult(MYSQL_RES* res, int affected_rows) : err(SQL_NO_ERROR), currentrow(0), rows(0)
141         {
142                 if (affected_rows >= 1)
143                 {
144                         rows = affected_rows;
145                         fieldlists.resize(rows);
146                 }
147                 unsigned int field_count = 0;
148                 if (res)
149                 {
150                         MYSQL_ROW row;
151                         int n = 0;
152                         while ((row = mysql_fetch_row(res)))
153                         {
154                                 if (fieldlists.size() < (unsigned int)rows+1)
155                                 {
156                                         fieldlists.resize(fieldlists.size()+1);
157                                 }
158                                 field_count = 0;
159                                 MYSQL_FIELD *fields = mysql_fetch_fields(res);
160                                 if(mysql_num_fields(res) == 0)
161                                         break;
162                                 if (fields && mysql_num_fields(res))
163                                 {
164                                         colnames.clear();
165                                         while (field_count < mysql_num_fields(res))
166                                         {
167                                                 std::string a = (fields[field_count].name ? fields[field_count].name : "");
168                                                 if (row[field_count])
169                                                         fieldlists[n].push_back(SQLEntry(row[field_count]));
170                                                 else
171                                                         fieldlists[n].push_back(SQLEntry());
172                                                 colnames.push_back(a);
173                                                 field_count++;
174                                         }
175                                         n++;
176                                 }
177                                 rows++;
178                         }
179                         mysql_free_result(res);
180                 }
181         }
182
183         MySQLresult(SQLerror& e) : err(e)
184         {
185
186         }
187
188         int Rows()
189         {
190                 return rows;
191         }
192
193         void GetCols(std::vector<std::string>& result)
194         {
195                 result.assign(colnames.begin(), colnames.end());
196         }
197
198         SQLEntry GetValue(int row, int column)
199         {
200                 if ((row >= 0) && (row < rows) && (column >= 0) && (column < (int)fieldlists[row].size()))
201                 {
202                         return fieldlists[row][column];
203                 }
204                 return SQLEntry();
205         }
206
207         bool GetRow(SQLEntries& result)
208         {
209                 if (currentrow < rows)
210                 {
211                         result.assign(fieldlists[currentrow].begin(), fieldlists[currentrow].end());
212                         currentrow++;
213                         return true;
214                 }
215                 else
216                 {
217                         result.clear();
218                         return false;
219                 }
220         }
221 };
222
223 /** Represents a connection to a mysql database
224  */
225 class SQLConnection : public SQLProvider
226 {
227  public:
228         reference<ConfigTag> config;
229         MYSQL *connection;
230         Mutex lock;
231
232         // This constructor creates an SQLConnection object with the given credentials, but does not connect yet.
233         SQLConnection(Module* p, ConfigTag* tag) : SQLProvider(p, "SQL/" + tag->getString("id")),
234                 config(tag), connection(NULL)
235         {
236         }
237
238         ~SQLConnection()
239         {
240                 Close();
241         }
242
243         // This method connects to the database using the credentials supplied to the constructor, and returns
244         // true upon success.
245         bool Connect()
246         {
247                 unsigned int timeout = 1;
248                 connection = mysql_init(connection);
249                 mysql_options(connection,MYSQL_OPT_CONNECT_TIMEOUT,(char*)&timeout);
250                 std::string host = config->getString("host");
251                 std::string user = config->getString("user");
252                 std::string pass = config->getString("pass");
253                 std::string dbname = config->getString("name");
254                 int port = config->getInt("port");
255                 bool rv = mysql_real_connect(connection, host.c_str(), user.c_str(), pass.c_str(), dbname.c_str(), port, NULL, 0);
256                 if (!rv)
257                         return rv;
258                 std::string initquery;
259                 if (config->readString("initialquery", initquery))
260                 {
261                         mysql_query(connection,initquery.c_str());
262                 }
263                 return true;
264         }
265
266         ModuleSQL* Parent()
267         {
268                 return (ModuleSQL*)(Module*)creator;
269         }
270
271         MySQLresult* DoBlockingQuery(const std::string& query)
272         {
273
274                 /* Parse the command string and dispatch it to mysql */
275                 if (CheckConnection() && !mysql_real_query(connection, query.data(), query.length()))
276                 {
277                         /* Successfull query */
278                         MYSQL_RES* res = mysql_use_result(connection);
279                         unsigned long rows = mysql_affected_rows(connection);
280                         return new MySQLresult(res, rows);
281                 }
282                 else
283                 {
284                         /* XXX: See /usr/include/mysql/mysqld_error.h for a list of
285                          * possible error numbers and error messages */
286                         SQLerror e(SQL_QREPLY_FAIL, ConvToStr(mysql_errno(connection)) + ": " + mysql_error(connection));
287                         return new MySQLresult(e);
288                 }
289         }
290
291         bool CheckConnection()
292         {
293                 if (!connection || mysql_ping(connection) != 0)
294                         return Connect();
295                 return true;
296         }
297
298         std::string GetError()
299         {
300                 return mysql_error(connection);
301         }
302
303         void Close()
304         {
305                 mysql_close(connection);
306         }
307
308         void submit(SQLQuery* q, const std::string& qs)
309         {
310                 Parent()->Dispatcher->LockQueue();
311                 Parent()->qq.push_back(QQueueItem(q, qs, this));
312                 Parent()->Dispatcher->UnlockQueueWakeup();
313         }
314
315         void submit(SQLQuery* call, const std::string& q, const ParamL& p)
316         {
317                 std::string res;
318                 unsigned int param = 0;
319                 for(std::string::size_type i = 0; i < q.length(); i++)
320                 {
321                         if (q[i] != '?')
322                                 res.push_back(q[i]);
323                         else
324                         {
325                                 if (param < p.size())
326                                 {
327                                         std::string parm = p[param++];
328                                         // In the worst case, each character may need to be encoded as using two bytes,
329                                         // and one byte is the terminating null
330                                         std::vector<char> buffer(parm.length() * 2 + 1);
331
332                                         // The return value of mysql_escape_string() is the length of the encoded string,
333                                         // not including the terminating null
334                                         unsigned long escapedsize = mysql_escape_string(&buffer[0], parm.c_str(), parm.length());
335 //                                      mysql_real_escape_string(connection, queryend, paramscopy[paramnum].c_str(), paramscopy[paramnum].length());
336                                         res.append(&buffer[0], escapedsize);
337                                 }
338                         }
339                 }
340                 submit(call, res);
341         }
342
343         void submit(SQLQuery* call, const std::string& q, const ParamM& p)
344         {
345                 std::string res;
346                 for(std::string::size_type i = 0; i < q.length(); i++)
347                 {
348                         if (q[i] != '$')
349                                 res.push_back(q[i]);
350                         else
351                         {
352                                 std::string field;
353                                 i++;
354                                 while (i < q.length() && isalnum(q[i]))
355                                         field.push_back(q[i++]);
356                                 i--;
357
358                                 ParamM::const_iterator it = p.find(field);
359                                 if (it != p.end())
360                                 {
361                                         std::string parm = it->second;
362                                         // NOTE: See above
363                                         std::vector<char> buffer(parm.length() * 2 + 1);
364                                         unsigned long escapedsize = mysql_escape_string(&buffer[0], parm.c_str(), parm.length());
365                                         res.append(&buffer[0], escapedsize);
366                                 }
367                         }
368                 }
369                 submit(call, res);
370         }
371 };
372
373 ModuleSQL::ModuleSQL()
374 {
375         Dispatcher = NULL;
376 }
377
378 void ModuleSQL::init()
379 {
380         Dispatcher = new DispatcherThread(this);
381         ServerInstance->Threads.Start(Dispatcher);
382 }
383
384 ModuleSQL::~ModuleSQL()
385 {
386         if (Dispatcher)
387         {
388                 Dispatcher->join();
389                 Dispatcher->OnNotify();
390                 delete Dispatcher;
391         }
392         for(ConnMap::iterator i = connections.begin(); i != connections.end(); i++)
393         {
394                 delete i->second;
395         }
396 }
397
398 void ModuleSQL::ReadConfig(ConfigStatus& status)
399 {
400         ConnMap conns;
401         ConfigTagList tags = ServerInstance->Config->ConfTags("database");
402         for(ConfigIter i = tags.first; i != tags.second; i++)
403         {
404                 if (i->second->getString("module", "mysql") != "mysql")
405                         continue;
406                 std::string id = i->second->getString("id");
407                 ConnMap::iterator curr = connections.find(id);
408                 if (curr == connections.end())
409                 {
410                         SQLConnection* conn = new SQLConnection(this, i->second);
411                         conns.insert(std::make_pair(id, conn));
412                         ServerInstance->Modules->AddService(*conn);
413                 }
414                 else
415                 {
416                         conns.insert(*curr);
417                         connections.erase(curr);
418                 }
419         }
420
421         // now clean up the deleted databases
422         Dispatcher->LockQueue();
423         SQLerror err(SQL_BAD_DBID);
424         for(ConnMap::iterator i = connections.begin(); i != connections.end(); i++)
425         {
426                 ServerInstance->Modules->DelService(*i->second);
427                 // it might be running a query on this database. Wait for that to complete
428                 i->second->lock.Lock();
429                 i->second->lock.Unlock();
430                 // now remove all active queries to this DB
431                 for (size_t j = qq.size(); j > 0; j--)
432                 {
433                         size_t k = j - 1;
434                         if (qq[k].c == i->second)
435                         {
436                                 qq[k].q->OnError(err);
437                                 delete qq[k].q;
438                                 qq.erase(qq.begin() + k);
439                         }
440                 }
441                 // finally, nuke the connection
442                 delete i->second;
443         }
444         Dispatcher->UnlockQueue();
445         connections.swap(conns);
446 }
447
448 void ModuleSQL::OnUnloadModule(Module* mod)
449 {
450         SQLerror err(SQL_BAD_DBID);
451         Dispatcher->LockQueue();
452         unsigned int i = qq.size();
453         while (i > 0)
454         {
455                 i--;
456                 if (qq[i].q->creator == mod)
457                 {
458                         if (i == 0)
459                         {
460                                 // need to wait until the query is done
461                                 // (the result will be discarded)
462                                 qq[i].c->lock.Lock();
463                                 qq[i].c->lock.Unlock();
464                         }
465                         qq[i].q->OnError(err);
466                         delete qq[i].q;
467                         qq.erase(qq.begin() + i);
468                 }
469         }
470         Dispatcher->UnlockQueue();
471         // clean up any result queue entries
472         Dispatcher->OnNotify();
473 }
474
475 Version ModuleSQL::GetVersion()
476 {
477         return Version("MySQL support", VF_VENDOR);
478 }
479
/** Worker thread main loop.
 *
 * The queue lock is held at all times except while a query is executing (and, presumably,
 * while WaitForQueue() sleeps). For the item at the head of qq, the connection's lock is taken
 * before the queue lock is released, so the main thread (ReadConfig, OnUnloadModule) can wait
 * for an in-flight query by locking that connection lock while holding the queue lock.
 * The item is popped only after the query finishes; if it is no longer at the head by then,
 * the main thread removed it meanwhile and the result is discarded.
 */
void DispatcherThread::Run()
{
        this->LockQueue();
        while (!this->GetExitFlag())
        {
                if (!Parent->qq.empty())
                {
                        // Copy the head item; it stays queued while the query runs.
                        QQueueItem i = Parent->qq.front();
                        // Take the connection lock before dropping the queue lock, so the main thread
                        // cannot delete the connection between these two steps.
                        i.c->lock.Lock();
                        this->UnlockQueue();
                        MySQLresult* res = i.c->DoBlockingQuery(i.query);
                        i.c->lock.Unlock();

                        /*
                         * At this point, the main thread could be working on:
                         *  Rehash - delete i.c out from under us. We don't care about that.
                         *  UnloadModule - delete i.q and the qq item. Need to avoid reporting results.
                         */

                        this->LockQueue();
                        // NOTE(review): this is a pointer-identity check. If i.q was freed and a new SQLQuery
                        // happened to be allocated at the same address and reached the head of qq, it would be
                        // mistaken for the finished query.
                        if (!Parent->qq.empty() && Parent->qq.front().q == i.q)
                        {
                                Parent->qq.pop_front();
                                Parent->rq.push_back(RQueueItem(i.q, res));
                                NotifyParent();
                        }
                        else
                        {
                                // UnloadModule ate the query (ReadConfig can also remove it)
                                delete res;
                        }
                }
                else
                {
                        /* We know the queue is empty, we can safely hang this thread until
                         * something happens
                         */
                        this->WaitForQueue();
                }
        }
        this->UnlockQueue();
}
522
523 void DispatcherThread::OnNotify()
524 {
525         // this could unlock during the dispatch, but OnResult isn't expected to take that long
526         this->LockQueue();
527         for(ResultQueue::iterator i = Parent->rq.begin(); i != Parent->rq.end(); i++)
528         {
529                 MySQLresult* res = i->r;
530                 if (res->err.id == SQL_NO_ERROR)
531                         i->q->OnResult(*res);
532                 else
533                         i->q->OnError(res->err);
534                 delete i->q;
535                 delete i->r;
536         }
537         Parent->rq.clear();
538         this->UnlockQueue();
539 }
540
// InspIRCd module entry point macro; presumably exports the factory the core uses to instantiate ModuleSQL.
MODULE_INIT(ModuleSQL)