]> git.netwichtig.de Git - user/henk/code/inspircd.git/blob - src/modules/extra/m_mysql.cpp
Merge branch 'master+sslconnmsg'
[user/henk/code/inspircd.git] / src / modules / extra / m_mysql.cpp
1 /*
2  * InspIRCd -- Internet Relay Chat Daemon
3  *
4  *   Copyright (C) 2009-2010 Daniel De Graaf <danieldg@inspircd.org>
5  *   Copyright (C) 2006-2007, 2009 Dennis Friis <peavey@inspircd.org>
6  *   Copyright (C) 2006-2009 Craig Edwards <craigedwards@brainbox.cc>
7  *   Copyright (C) 2008 Robin Burchell <robin+git@viroteck.net>
8  *
9  * This file is part of InspIRCd.  InspIRCd is free software: you can
10  * redistribute it and/or modify it under the terms of the GNU General Public
11  * License as published by the Free Software Foundation, version 2.
12  *
13  * This program is distributed in the hope that it will be useful, but WITHOUT
14  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
15  * FOR A PARTICULAR PURPOSE.  See the GNU General Public License for more
16  * details.
17  *
18  * You should have received a copy of the GNU General Public License
19  * along with this program.  If not, see <http://www.gnu.org/licenses/>.
20  */
21
22
23 /* Stop mysql wanting to use long long */
24 #define NO_CLIENT_LONG_LONG
25
26 #include "inspircd.h"
27 #include <mysql.h>
28 #include "modules/sql.h"
29
30 #ifdef _WIN32
31 # pragma comment(lib, "libmysql.lib")
32 #endif
33
34 /* VERSION 3 API: With nonblocking (threaded) requests */
35
36 /* $CompileFlags: exec("mysql_config --include") */
37 /* $LinkerFlags: exec("mysql_config --libs_r") rpath("mysql_config --libs_r") */
38
39 /* THE NONBLOCKING MYSQL API!
40  *
 * MySQL provides no nonblocking (asynchronous) API of its own, and its developers recommend
 * that instead, you should thread your program. This is what has been done here to allow for
 * asynchronous SQL requests via mysql. The way this works is as follows:
 *
 * The module spawns a thread via class Thread, and performs its mysql queries in this thread,
 * using a first-in, first-out queue. A mutex guards the queue so that the two threads never
 * adjust it at the same time (which could crash the ircd). While the queue is empty the worker
 * thread sleeps until a new submission wakes it; it then takes the request at the head of its
 * queue and processes it, blocking the worker thread but leaving the ircd thread to go about
 * its business as usual. During this period, the ircd thread is able to insert further
 * pending requests into the queue.
52  *
 * Once the processing of a request is complete, it is removed from the incoming queue, and the
 * request together with its result is placed on an outgoing (result) queue. The worker thread
 * then signals the ircd thread (via SocketThread::NotifyParent(), which uses a loopback socket)
 * that a result is available.
57  *
 * The ircd thread then locks the queue once more, takes the finished responses off the outgoing
 * queue, and delivers each one to the original calling module via OnResult() or OnError().
60  *
 * XXX: You might be asking "why doesn't he just send the response from within the worker thread?"
 * The answer to this is simple. The majority of InspIRCd, and in fact most ircds, are not
 * threadsafe. This module is designed to be threadsafe and is careful with its use of threads;
 * however, if we were to call a module's OnResult/OnError from within a thread other than the
 * one the module was originally instantiated upon, there is a chance of all hell breaking loose
 * if a module is ever put in a re-entrant state (stack corruption could occur, crashes, data
 * corruption, and worse, so DON'T think about it until the day comes when InspIRCd is 100%
 * guaranteed threadsafe!)
69  *
70  * For a diagram of this system please see http://wiki.inspircd.org/Mysql2
71  */
72
// Forward declarations: the queue items and ModuleSQL below refer to these classes
// before their definitions appear later in this file.
class SQLConnection;
class MySQLresult;
class DispatcherThread;
76
77 struct QQueueItem
78 {
79         SQLQuery* q;
80         std::string query;
81         SQLConnection* c;
82         QQueueItem(SQLQuery* Q, const std::string& S, SQLConnection* C) : q(Q), query(S), c(C) {}
83 };
84
85 struct RQueueItem
86 {
87         SQLQuery* q;
88         MySQLresult* r;
89         RQueueItem(SQLQuery* Q, MySQLresult* R) : q(Q), r(R) {}
90 };
91
/** Configured connections, keyed by the id of their <database> tag. */
typedef insp::flat_map<std::string, SQLConnection*> ConnMap;
/** Queries awaiting execution; the head item stays in the queue while the worker runs it. */
typedef std::deque<QQueueItem> QueryQueue;
/** Executed queries whose results are waiting to be delivered on the main thread. */
typedef std::deque<RQueueItem> ResultQueue;
95
/** MySQL module.
 *
 * Owns the worker thread (Dispatcher), the queue of pending queries (qq), the queue of
 * finished results (rq), and the configured connections keyed by <database> id.
 * qq and rq are shared with the worker thread and may only be touched while holding
 * the Dispatcher's queue lock.
 */
class ModuleSQL : public Module
{
 public:
	/** Worker thread executing queued queries; NULL until init() has run. */
	DispatcherThread* Dispatcher;
	QueryQueue qq;       // MUST HOLD MUTEX
	ResultQueue rq;      // MUST HOLD MUTEX
	ConnMap connections; // main thread only

	ModuleSQL();
	void init() CXX11_OVERRIDE;
	~ModuleSQL();
	void ReadConfig(ConfigStatus& status) CXX11_OVERRIDE;
	void OnUnloadModule(Module* mod) CXX11_OVERRIDE;
	Version GetVersion() CXX11_OVERRIDE;
};
113
/** Worker thread that executes the queries queued in ModuleSQL::qq off the main thread
 * and hands the results back through ModuleSQL::rq and OnNotify().
 */
class DispatcherThread : public SocketThread
{
 private:
	/** The module whose qq/rq queues this thread services. */
	ModuleSQL* const Parent;
 public:
	DispatcherThread(ModuleSQL* CreatorModule) : Parent(CreatorModule) { }
	~DispatcherThread() { }
	/** Worker-thread main loop: runs queued queries until the exit flag is set. */
	void Run();
	/** Main-thread side: delivers every finished result in Parent->rq to its query. */
	void OnNotify();
};
124
// Compatibility shim: with very old client libraries (before 3.22.24, or when MYSQL_VERSION_ID is
// not defined) mysql_field_count() is presumably unavailable, so map it onto mysql_num_fields().
// NOTE(review): nothing visible in this file calls mysql_field_count.
#if !defined(MYSQL_VERSION_ID) || MYSQL_VERSION_ID<32224
#define mysql_field_count mysql_num_fields
#endif
128
129 /** Represents a mysql result set
130  */
131 class MySQLresult : public SQLResult
132 {
133  public:
134         SQLerror err;
135         int currentrow;
136         int rows;
137         std::vector<std::string> colnames;
138         std::vector<SQLEntries> fieldlists;
139
140         MySQLresult(MYSQL_RES* res, int affected_rows) : err(SQL_NO_ERROR), currentrow(0), rows(0)
141         {
142                 if (affected_rows >= 1)
143                 {
144                         rows = affected_rows;
145                         fieldlists.resize(rows);
146                 }
147                 unsigned int field_count = 0;
148                 if (res)
149                 {
150                         MYSQL_ROW row;
151                         int n = 0;
152                         while ((row = mysql_fetch_row(res)))
153                         {
154                                 if (fieldlists.size() < (unsigned int)rows+1)
155                                 {
156                                         fieldlists.resize(fieldlists.size()+1);
157                                 }
158                                 field_count = 0;
159                                 MYSQL_FIELD *fields = mysql_fetch_fields(res);
160                                 if(mysql_num_fields(res) == 0)
161                                         break;
162                                 if (fields && mysql_num_fields(res))
163                                 {
164                                         colnames.clear();
165                                         while (field_count < mysql_num_fields(res))
166                                         {
167                                                 std::string a = (fields[field_count].name ? fields[field_count].name : "");
168                                                 if (row[field_count])
169                                                         fieldlists[n].push_back(SQLEntry(row[field_count]));
170                                                 else
171                                                         fieldlists[n].push_back(SQLEntry());
172                                                 colnames.push_back(a);
173                                                 field_count++;
174                                         }
175                                         n++;
176                                 }
177                                 rows++;
178                         }
179                         mysql_free_result(res);
180                 }
181         }
182
183         MySQLresult(SQLerror& e) : err(e)
184         {
185
186         }
187
188         int Rows()
189         {
190                 return rows;
191         }
192
193         void GetCols(std::vector<std::string>& result)
194         {
195                 result.assign(colnames.begin(), colnames.end());
196         }
197
198         SQLEntry GetValue(int row, int column)
199         {
200                 if ((row >= 0) && (row < rows) && (column >= 0) && (column < (int)fieldlists[row].size()))
201                 {
202                         return fieldlists[row][column];
203                 }
204                 return SQLEntry();
205         }
206
207         bool GetRow(SQLEntries& result)
208         {
209                 if (currentrow < rows)
210                 {
211                         result.assign(fieldlists[currentrow].begin(), fieldlists[currentrow].end());
212                         currentrow++;
213                         return true;
214                 }
215                 else
216                 {
217                         result.clear();
218                         return false;
219                 }
220         }
221 };
222
223 /** Represents a connection to a mysql database
224  */
225 class SQLConnection : public SQLProvider
226 {
227  public:
228         reference<ConfigTag> config;
229         MYSQL *connection;
230         Mutex lock;
231
232         // This constructor creates an SQLConnection object with the given credentials, but does not connect yet.
233         SQLConnection(Module* p, ConfigTag* tag) : SQLProvider(p, "SQL/" + tag->getString("id")),
234                 config(tag), connection(NULL)
235         {
236         }
237
238         ~SQLConnection()
239         {
240                 Close();
241         }
242
243         // This method connects to the database using the credentials supplied to the constructor, and returns
244         // true upon success.
245         bool Connect()
246         {
247                 unsigned int timeout = 1;
248                 connection = mysql_init(connection);
249                 mysql_options(connection,MYSQL_OPT_CONNECT_TIMEOUT,(char*)&timeout);
250                 std::string host = config->getString("host");
251                 std::string user = config->getString("user");
252                 std::string pass = config->getString("pass");
253                 std::string dbname = config->getString("name");
254                 int port = config->getInt("port");
255                 bool rv = mysql_real_connect(connection, host.c_str(), user.c_str(), pass.c_str(), dbname.c_str(), port, NULL, 0);
256                 if (!rv)
257                         return rv;
258
259                 // Enable character set settings
260                 std::string charset = config->getString("charset");
261                 if ((!charset.empty()) && (mysql_set_character_set(connection, charset.c_str())))
262                         ServerInstance->Logs->Log(MODNAME, LOG_DEFAULT, "WARNING: Could not set character set to \"%s\"", charset.c_str());
263
264                 std::string initquery;
265                 if (config->readString("initialquery", initquery))
266                 {
267                         mysql_query(connection,initquery.c_str());
268                 }
269                 return true;
270         }
271
272         ModuleSQL* Parent()
273         {
274                 return (ModuleSQL*)(Module*)creator;
275         }
276
277         MySQLresult* DoBlockingQuery(const std::string& query)
278         {
279
280                 /* Parse the command string and dispatch it to mysql */
281                 if (CheckConnection() && !mysql_real_query(connection, query.data(), query.length()))
282                 {
283                         /* Successfull query */
284                         MYSQL_RES* res = mysql_use_result(connection);
285                         unsigned long rows = mysql_affected_rows(connection);
286                         return new MySQLresult(res, rows);
287                 }
288                 else
289                 {
290                         /* XXX: See /usr/include/mysql/mysqld_error.h for a list of
291                          * possible error numbers and error messages */
292                         SQLerror e(SQL_QREPLY_FAIL, ConvToStr(mysql_errno(connection)) + ": " + mysql_error(connection));
293                         return new MySQLresult(e);
294                 }
295         }
296
297         bool CheckConnection()
298         {
299                 if (!connection || mysql_ping(connection) != 0)
300                         return Connect();
301                 return true;
302         }
303
304         std::string GetError()
305         {
306                 return mysql_error(connection);
307         }
308
309         void Close()
310         {
311                 mysql_close(connection);
312         }
313
314         void submit(SQLQuery* q, const std::string& qs)
315         {
316                 Parent()->Dispatcher->LockQueue();
317                 Parent()->qq.push_back(QQueueItem(q, qs, this));
318                 Parent()->Dispatcher->UnlockQueueWakeup();
319         }
320
321         void submit(SQLQuery* call, const std::string& q, const ParamL& p)
322         {
323                 std::string res;
324                 unsigned int param = 0;
325                 for(std::string::size_type i = 0; i < q.length(); i++)
326                 {
327                         if (q[i] != '?')
328                                 res.push_back(q[i]);
329                         else
330                         {
331                                 if (param < p.size())
332                                 {
333                                         std::string parm = p[param++];
334                                         // In the worst case, each character may need to be encoded as using two bytes,
335                                         // and one byte is the terminating null
336                                         std::vector<char> buffer(parm.length() * 2 + 1);
337
338                                         // The return value of mysql_escape_string() is the length of the encoded string,
339                                         // not including the terminating null
340                                         unsigned long escapedsize = mysql_escape_string(&buffer[0], parm.c_str(), parm.length());
341 //                                      mysql_real_escape_string(connection, queryend, paramscopy[paramnum].c_str(), paramscopy[paramnum].length());
342                                         res.append(&buffer[0], escapedsize);
343                                 }
344                         }
345                 }
346                 submit(call, res);
347         }
348
349         void submit(SQLQuery* call, const std::string& q, const ParamM& p)
350         {
351                 std::string res;
352                 for(std::string::size_type i = 0; i < q.length(); i++)
353                 {
354                         if (q[i] != '$')
355                                 res.push_back(q[i]);
356                         else
357                         {
358                                 std::string field;
359                                 i++;
360                                 while (i < q.length() && isalnum(q[i]))
361                                         field.push_back(q[i++]);
362                                 i--;
363
364                                 ParamM::const_iterator it = p.find(field);
365                                 if (it != p.end())
366                                 {
367                                         std::string parm = it->second;
368                                         // NOTE: See above
369                                         std::vector<char> buffer(parm.length() * 2 + 1);
370                                         unsigned long escapedsize = mysql_escape_string(&buffer[0], parm.c_str(), parm.length());
371                                         res.append(&buffer[0], escapedsize);
372                                 }
373                         }
374                 }
375                 submit(call, res);
376         }
377 };
378
379 ModuleSQL::ModuleSQL()
380 {
381         Dispatcher = NULL;
382 }
383
384 void ModuleSQL::init()
385 {
386         Dispatcher = new DispatcherThread(this);
387         ServerInstance->Threads.Start(Dispatcher);
388 }
389
390 ModuleSQL::~ModuleSQL()
391 {
392         if (Dispatcher)
393         {
394                 Dispatcher->join();
395                 Dispatcher->OnNotify();
396                 delete Dispatcher;
397         }
398         for(ConnMap::iterator i = connections.begin(); i != connections.end(); i++)
399         {
400                 delete i->second;
401         }
402 }
403
404 void ModuleSQL::ReadConfig(ConfigStatus& status)
405 {
406         ConnMap conns;
407         ConfigTagList tags = ServerInstance->Config->ConfTags("database");
408         for(ConfigIter i = tags.first; i != tags.second; i++)
409         {
410                 if (i->second->getString("module", "mysql") != "mysql")
411                         continue;
412                 std::string id = i->second->getString("id");
413                 ConnMap::iterator curr = connections.find(id);
414                 if (curr == connections.end())
415                 {
416                         SQLConnection* conn = new SQLConnection(this, i->second);
417                         conns.insert(std::make_pair(id, conn));
418                         ServerInstance->Modules->AddService(*conn);
419                 }
420                 else
421                 {
422                         conns.insert(*curr);
423                         connections.erase(curr);
424                 }
425         }
426
427         // now clean up the deleted databases
428         Dispatcher->LockQueue();
429         SQLerror err(SQL_BAD_DBID);
430         for(ConnMap::iterator i = connections.begin(); i != connections.end(); i++)
431         {
432                 ServerInstance->Modules->DelService(*i->second);
433                 // it might be running a query on this database. Wait for that to complete
434                 i->second->lock.Lock();
435                 i->second->lock.Unlock();
436                 // now remove all active queries to this DB
437                 for (size_t j = qq.size(); j > 0; j--)
438                 {
439                         size_t k = j - 1;
440                         if (qq[k].c == i->second)
441                         {
442                                 qq[k].q->OnError(err);
443                                 delete qq[k].q;
444                                 qq.erase(qq.begin() + k);
445                         }
446                 }
447                 // finally, nuke the connection
448                 delete i->second;
449         }
450         Dispatcher->UnlockQueue();
451         connections.swap(conns);
452 }
453
454 void ModuleSQL::OnUnloadModule(Module* mod)
455 {
456         SQLerror err(SQL_BAD_DBID);
457         Dispatcher->LockQueue();
458         unsigned int i = qq.size();
459         while (i > 0)
460         {
461                 i--;
462                 if (qq[i].q->creator == mod)
463                 {
464                         if (i == 0)
465                         {
466                                 // need to wait until the query is done
467                                 // (the result will be discarded)
468                                 qq[i].c->lock.Lock();
469                                 qq[i].c->lock.Unlock();
470                         }
471                         qq[i].q->OnError(err);
472                         delete qq[i].q;
473                         qq.erase(qq.begin() + i);
474                 }
475         }
476         Dispatcher->UnlockQueue();
477         // clean up any result queue entries
478         Dispatcher->OnNotify();
479 }
480
481 Version ModuleSQL::GetVersion()
482 {
483         return Version("MySQL support", VF_VENDOR);
484 }
485
/** Worker thread main loop.
 *
 * Repeatedly takes the query at the head of Parent->qq and runs it on its connection. While
 * the blocking call runs, the query's connection lock is held instead of the queue lock. If
 * the query is still at the head afterwards, the loop moves it with its result to Parent->rq and
 * calls NotifyParent(). When the queue is empty it sleeps in WaitForQueue(). Exits once the
 * exit flag is set.
 */
void DispatcherThread::Run()
{
	// The queue lock is held whenever qq is inspected; WaitForQueue() presumably releases it while
	// sleeping (TODO confirm against SocketThread).
	this->LockQueue();
	while (!this->GetExitFlag())
	{
		if (!Parent->qq.empty())
		{
			// Copy the head item: the main thread may erase it from qq while the query runs.
			QQueueItem i = Parent->qq.front();
			// Take the connection lock before releasing the queue lock, so that ReadConfig and
			// OnUnloadModule can wait for this query by locking the same connection mutex.
			i.c->lock.Lock();
			this->UnlockQueue();
			MySQLresult* res = i.c->DoBlockingQuery(i.query);
			i.c->lock.Unlock();

			/*
			 * At this point, the main thread could be working on:
			 *  Rehash - delete i.c out from under us. We don't care about that.
			 *  UnloadModule - delete i.q and the qq item. Need to avoid reporting results.
			 */

			this->LockQueue();
			// Only report the result if our query is still at the head of the queue.
			if (!Parent->qq.empty() && Parent->qq.front().q == i.q)
			{
				Parent->qq.pop_front();
				Parent->rq.push_back(RQueueItem(i.q, res));
				NotifyParent();
			}
			else
			{
				// UnloadModule ate the query
				delete res;
			}
		}
		else
		{
			/* We know the queue is empty, we can safely hang this thread until
			 * something happens
			 */
			this->WaitForQueue();
		}
	}
	this->UnlockQueue();
}
528
529 void DispatcherThread::OnNotify()
530 {
531         // this could unlock during the dispatch, but OnResult isn't expected to take that long
532         this->LockQueue();
533         for(ResultQueue::iterator i = Parent->rq.begin(); i != Parent->rq.end(); i++)
534         {
535                 MySQLresult* res = i->r;
536                 if (res->err.id == SQL_NO_ERROR)
537                         i->q->OnResult(*res);
538                 else
539                         i->q->OnError(res->err);
540                 delete i->q;
541                 delete i->r;
542         }
543         Parent->rq.clear();
544         this->UnlockQueue();
545 }
546
// Declares ModuleSQL as the module class exported by this file (MODULE_INIT is an InspIRCd macro;
// its expansion is defined in the core headers).
MODULE_INIT(ModuleSQL)