]> git.netwichtig.de Git - user/henk/code/inspircd.git/blob - src/modules/extra/m_mysql.cpp
Update wiki links to use HTTPS and point to the correct pages.
[user/henk/code/inspircd.git] / src / modules / extra / m_mysql.cpp
1 /*
2  * InspIRCd -- Internet Relay Chat Daemon
3  *
4  *   Copyright (C) 2009-2010 Daniel De Graaf <danieldg@inspircd.org>
5  *   Copyright (C) 2006-2007, 2009 Dennis Friis <peavey@inspircd.org>
6  *   Copyright (C) 2006-2009 Craig Edwards <craigedwards@brainbox.cc>
7  *   Copyright (C) 2008 Robin Burchell <robin+git@viroteck.net>
8  *
9  * This file is part of InspIRCd.  InspIRCd is free software: you can
10  * redistribute it and/or modify it under the terms of the GNU General Public
11  * License as published by the Free Software Foundation, version 2.
12  *
13  * This program is distributed in the hope that it will be useful, but WITHOUT
14  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
15  * FOR A PARTICULAR PURPOSE.  See the GNU General Public License for more
16  * details.
17  *
18  * You should have received a copy of the GNU General Public License
19  * along with this program.  If not, see <http://www.gnu.org/licenses/>.
20  */
21
22
23 /* Stop mysql wanting to use long long */
24 #define NO_CLIENT_LONG_LONG
25
26 #include "inspircd.h"
27 #include <mysql.h>
28 #include "sql.h"
29
30 #ifdef _WIN32
31 # pragma comment(lib, "libmysql.lib")
32 #endif
33
34 /* VERSION 3 API: With nonblocking (threaded) requests */
35
36 /* $ModDesc: SQL Service Provider module for all other m_sql* modules */
37 /* $CompileFlags: exec("mysql_config --include") */
38 /* $LinkerFlags: exec("mysql_config --libs_r") rpath("mysql_config --libs_r") */
39
/* THE NONBLOCKING MYSQL API!
 *
 * MySQL provides no nonblocking (asynchronous) API of its own, and its developers recommend
 * that instead, you should thread your program. This is what I've done here to allow for
 * asynchronous SQL requests via mysql. The way this works is as follows:
 *
 * The module spawns a thread (class DispatcherThread, a SocketThread) and performs its mysql
 * queries in this thread, using a first-in first-out queue of requests. A mutex guards the
 * queue so that the two threads never adjust it at the same time and crash the ircd. While the
 * queue is empty, the worker thread sleeps until the ircd thread adds a request and wakes it.
 * It then processes the request at the head of the queue, blocking the worker thread but
 * leaving the ircd thread to go about its business as usual. During this period, the ircd
 * thread is able to insert further pending requests into the queue.
 *
 * Once the processing of a request is complete, it is moved from the incoming queue to
 * an outgoing queue as a 'response'. The worker thread then signals the ircd thread
 * (via NotifyParent(), which uses a loopback socket) that a result is available.
 *
 * The ircd thread then locks the queue once more, takes the responses off the outgoing
 * queue, and sends each one on its way to the original calling module.
 *
 * XXX: You might be asking "why doesn't he just send the response from within the worker thread?"
 * The answer to this is simple. The majority of InspIRCd, and in fact most ircds, are not
 * threadsafe. This module is designed to be threadsafe and is careful with its use of threads;
 * however, if we were to call a module's OnResult/OnError from within a thread other than the
 * one the module was originally instantiated upon, there is a chance of all hell breaking loose
 * if a module is ever put in a re-entrant state (stack corruption could occur, crashes, data
 * corruption, and worse), so DON'T think about it until the day comes when InspIRCd is 100%
 * guaranteed threadsafe!
 */
71
// Forward declarations: the queue items and ModuleSQL below only hold pointers to these.
class DispatcherThread;
class MySQLresult;
class SQLConnection;
75
76 struct QQueueItem
77 {
78         SQLQuery* q;
79         std::string query;
80         SQLConnection* c;
81         QQueueItem(SQLQuery* Q, const std::string& S, SQLConnection* C) : q(Q), query(S), c(C) {}
82 };
83
84 struct RQueueItem
85 {
86         SQLQuery* q;
87         MySQLresult* r;
88         RQueueItem(SQLQuery* Q, MySQLresult* R) : q(Q), r(R) {}
89 };
90
91 typedef std::map<std::string, SQLConnection*> ConnMap;
92 typedef std::deque<QQueueItem> QueryQueue;
93 typedef std::deque<RQueueItem> ResultQueue;
94
95 /** MySQL module
96  *  */
97 class ModuleSQL : public Module
98 {
99  public:
100         DispatcherThread* Dispatcher;
101         QueryQueue qq;       // MUST HOLD MUTEX
102         ResultQueue rq;      // MUST HOLD MUTEX
103         ConnMap connections; // main thread only
104
105         ModuleSQL();
106         void init();
107         ~ModuleSQL();
108         void OnRehash(User* user);
109         void OnUnloadModule(Module* mod);
110         Version GetVersion();
111 };
112
113 class DispatcherThread : public SocketThread
114 {
115  private:
116         ModuleSQL* const Parent;
117  public:
118         DispatcherThread(ModuleSQL* CreatorModule) : Parent(CreatorModule) { }
119         ~DispatcherThread() { }
120         virtual void Run();
121         virtual void OnNotify();
122 };
123
/* Client libraries that do not report a version, or report one below 32224 (3.22.24),
 * get mysql_field_count mapped onto mysql_num_fields. */
#if !defined(MYSQL_VERSION_ID) || (MYSQL_VERSION_ID < 32224)
# define mysql_field_count mysql_num_fields
#endif
127
128 /** Represents a mysql result set
129  */
130 class MySQLresult : public SQLResult
131 {
132  public:
133         SQLerror err;
134         int currentrow;
135         int rows;
136         std::vector<std::string> colnames;
137         std::vector<SQLEntries> fieldlists;
138
139         MySQLresult(MYSQL_RES* res, int affected_rows) : err(SQL_NO_ERROR), currentrow(0), rows(0)
140         {
141                 if (affected_rows >= 1)
142                 {
143                         rows = affected_rows;
144                         fieldlists.resize(rows);
145                 }
146                 unsigned int field_count = 0;
147                 if (res)
148                 {
149                         MYSQL_ROW row;
150                         int n = 0;
151                         while ((row = mysql_fetch_row(res)))
152                         {
153                                 if (fieldlists.size() < (unsigned int)rows+1)
154                                 {
155                                         fieldlists.resize(fieldlists.size()+1);
156                                 }
157                                 field_count = 0;
158                                 MYSQL_FIELD *fields = mysql_fetch_fields(res);
159                                 if(mysql_num_fields(res) == 0)
160                                         break;
161                                 if (fields && mysql_num_fields(res))
162                                 {
163                                         colnames.clear();
164                                         while (field_count < mysql_num_fields(res))
165                                         {
166                                                 std::string a = (fields[field_count].name ? fields[field_count].name : "");
167                                                 if (row[field_count])
168                                                         fieldlists[n].push_back(SQLEntry(row[field_count]));
169                                                 else
170                                                         fieldlists[n].push_back(SQLEntry());
171                                                 colnames.push_back(a);
172                                                 field_count++;
173                                         }
174                                         n++;
175                                 }
176                                 rows++;
177                         }
178                         mysql_free_result(res);
179                 }
180         }
181
182         MySQLresult(SQLerror& e) : err(e)
183         {
184
185         }
186
187         ~MySQLresult()
188         {
189         }
190
191         virtual int Rows()
192         {
193                 return rows;
194         }
195
196         virtual void GetCols(std::vector<std::string>& result)
197         {
198                 result.assign(colnames.begin(), colnames.end());
199         }
200
201         virtual SQLEntry GetValue(int row, int column)
202         {
203                 if ((row >= 0) && (row < rows) && (column >= 0) && (column < (int)fieldlists[row].size()))
204                 {
205                         return fieldlists[row][column];
206                 }
207                 return SQLEntry();
208         }
209
210         virtual bool GetRow(SQLEntries& result)
211         {
212                 if (currentrow < rows)
213                 {
214                         result.assign(fieldlists[currentrow].begin(), fieldlists[currentrow].end());
215                         currentrow++;
216                         return true;
217                 }
218                 else
219                 {
220                         result.clear();
221                         return false;
222                 }
223         }
224 };
225
226 /** Represents a connection to a mysql database
227  */
228 class SQLConnection : public SQLProvider
229 {
230  public:
231         reference<ConfigTag> config;
232         MYSQL *connection;
233         Mutex lock;
234
235         // This constructor creates an SQLConnection object with the given credentials, but does not connect yet.
236         SQLConnection(Module* p, ConfigTag* tag) : SQLProvider(p, "SQL/" + tag->getString("id")),
237                 config(tag), connection(NULL)
238         {
239         }
240
241         ~SQLConnection()
242         {
243                 Close();
244         }
245
246         // This method connects to the database using the credentials supplied to the constructor, and returns
247         // true upon success.
248         bool Connect()
249         {
250                 unsigned int timeout = 1;
251                 connection = mysql_init(connection);
252                 mysql_options(connection,MYSQL_OPT_CONNECT_TIMEOUT,(char*)&timeout);
253                 std::string host = config->getString("host");
254                 std::string user = config->getString("user");
255                 std::string pass = config->getString("pass");
256                 std::string dbname = config->getString("name");
257                 int port = config->getInt("port");
258                 bool rv = mysql_real_connect(connection, host.c_str(), user.c_str(), pass.c_str(), dbname.c_str(), port, NULL, 0);
259                 if (!rv)
260                         return rv;
261                 std::string initquery;
262                 if (config->readString("initialquery", initquery))
263                 {
264                         mysql_query(connection,initquery.c_str());
265                 }
266                 return true;
267         }
268
269         ModuleSQL* Parent()
270         {
271                 return (ModuleSQL*)(Module*)creator;
272         }
273
274         MySQLresult* DoBlockingQuery(const std::string& query)
275         {
276
277                 /* Parse the command string and dispatch it to mysql */
278                 if (CheckConnection() && !mysql_real_query(connection, query.data(), query.length()))
279                 {
280                         /* Successfull query */
281                         MYSQL_RES* res = mysql_use_result(connection);
282                         unsigned long rows = mysql_affected_rows(connection);
283                         return new MySQLresult(res, rows);
284                 }
285                 else
286                 {
287                         /* XXX: See /usr/include/mysql/mysqld_error.h for a list of
288                          * possible error numbers and error messages */
289                         SQLerror e(SQL_QREPLY_FAIL, ConvToStr(mysql_errno(connection)) + ": " + mysql_error(connection));
290                         return new MySQLresult(e);
291                 }
292         }
293
294         bool CheckConnection()
295         {
296                 if (!connection || mysql_ping(connection) != 0)
297                         return Connect();
298                 return true;
299         }
300
301         std::string GetError()
302         {
303                 return mysql_error(connection);
304         }
305
306         void Close()
307         {
308                 mysql_close(connection);
309         }
310
311         void submit(SQLQuery* q, const std::string& qs)
312         {
313                 Parent()->Dispatcher->LockQueue();
314                 Parent()->qq.push_back(QQueueItem(q, qs, this));
315                 Parent()->Dispatcher->UnlockQueueWakeup();
316         }
317
318         void submit(SQLQuery* call, const std::string& q, const ParamL& p)
319         {
320                 std::string res;
321                 unsigned int param = 0;
322                 for(std::string::size_type i = 0; i < q.length(); i++)
323                 {
324                         if (q[i] != '?')
325                                 res.push_back(q[i]);
326                         else
327                         {
328                                 if (param < p.size())
329                                 {
330                                         std::string parm = p[param++];
331                                         // In the worst case, each character may need to be encoded as using two bytes,
332                                         // and one byte is the terminating null
333                                         std::vector<char> buffer(parm.length() * 2 + 1);
334
335                                         // The return value of mysql_escape_string() is the length of the encoded string,
336                                         // not including the terminating null
337                                         unsigned long escapedsize = mysql_escape_string(&buffer[0], parm.c_str(), parm.length());
338 //                                      mysql_real_escape_string(connection, queryend, paramscopy[paramnum].c_str(), paramscopy[paramnum].length());
339                                         res.append(&buffer[0], escapedsize);
340                                 }
341                         }
342                 }
343                 submit(call, res);
344         }
345
346         void submit(SQLQuery* call, const std::string& q, const ParamM& p)
347         {
348                 std::string res;
349                 for(std::string::size_type i = 0; i < q.length(); i++)
350                 {
351                         if (q[i] != '$')
352                                 res.push_back(q[i]);
353                         else
354                         {
355                                 std::string field;
356                                 i++;
357                                 while (i < q.length() && isalnum(q[i]))
358                                         field.push_back(q[i++]);
359                                 i--;
360
361                                 ParamM::const_iterator it = p.find(field);
362                                 if (it != p.end())
363                                 {
364                                         std::string parm = it->second;
365                                         // NOTE: See above
366                                         std::vector<char> buffer(parm.length() * 2 + 1);
367                                         unsigned long escapedsize = mysql_escape_string(&buffer[0], parm.c_str(), parm.length());
368                                         res.append(&buffer[0], escapedsize);
369                                 }
370                         }
371                 }
372                 submit(call, res);
373         }
374 };
375
376 ModuleSQL::ModuleSQL()
377 {
378         Dispatcher = NULL;
379 }
380
381 void ModuleSQL::init()
382 {
383         Dispatcher = new DispatcherThread(this);
384         ServerInstance->Threads->Start(Dispatcher);
385
386         Implementation eventlist[] = { I_OnRehash, I_OnUnloadModule };
387         ServerInstance->Modules->Attach(eventlist, this, sizeof(eventlist)/sizeof(Implementation));
388
389         OnRehash(NULL);
390 }
391
392 ModuleSQL::~ModuleSQL()
393 {
394         if (Dispatcher)
395         {
396                 Dispatcher->join();
397                 Dispatcher->OnNotify();
398                 delete Dispatcher;
399         }
400         for(ConnMap::iterator i = connections.begin(); i != connections.end(); i++)
401         {
402                 delete i->second;
403         }
404 }
405
406 void ModuleSQL::OnRehash(User* user)
407 {
408         ConnMap conns;
409         ConfigTagList tags = ServerInstance->Config->ConfTags("database");
410         for(ConfigIter i = tags.first; i != tags.second; i++)
411         {
412                 if (i->second->getString("module", "mysql") != "mysql")
413                         continue;
414                 std::string id = i->second->getString("id");
415                 ConnMap::iterator curr = connections.find(id);
416                 if (curr == connections.end())
417                 {
418                         SQLConnection* conn = new SQLConnection(this, i->second);
419                         conns.insert(std::make_pair(id, conn));
420                         ServerInstance->Modules->AddService(*conn);
421                 }
422                 else
423                 {
424                         conns.insert(*curr);
425                         connections.erase(curr);
426                 }
427         }
428
429         // now clean up the deleted databases
430         Dispatcher->LockQueue();
431         SQLerror err(SQL_BAD_DBID);
432         for(ConnMap::iterator i = connections.begin(); i != connections.end(); i++)
433         {
434                 ServerInstance->Modules->DelService(*i->second);
435                 // it might be running a query on this database. Wait for that to complete
436                 i->second->lock.Lock();
437                 i->second->lock.Unlock();
438                 // now remove all active queries to this DB
439                 for (size_t j = qq.size(); j > 0; j--)
440                 {
441                         size_t k = j - 1;
442                         if (qq[k].c == i->second)
443                         {
444                                 qq[k].q->OnError(err);
445                                 delete qq[k].q;
446                                 qq.erase(qq.begin() + k);
447                         }
448                 }
449                 // finally, nuke the connection
450                 delete i->second;
451         }
452         Dispatcher->UnlockQueue();
453         connections.swap(conns);
454 }
455
456 void ModuleSQL::OnUnloadModule(Module* mod)
457 {
458         SQLerror err(SQL_BAD_DBID);
459         Dispatcher->LockQueue();
460         unsigned int i = qq.size();
461         while (i > 0)
462         {
463                 i--;
464                 if (qq[i].q->creator == mod)
465                 {
466                         if (i == 0)
467                         {
468                                 // need to wait until the query is done
469                                 // (the result will be discarded)
470                                 qq[i].c->lock.Lock();
471                                 qq[i].c->lock.Unlock();
472                         }
473                         qq[i].q->OnError(err);
474                         delete qq[i].q;
475                         qq.erase(qq.begin() + i);
476                 }
477         }
478         Dispatcher->UnlockQueue();
479         // clean up any result queue entries
480         Dispatcher->OnNotify();
481 }
482
483 Version ModuleSQL::GetVersion()
484 {
485         return Version("MySQL support", VF_VENDOR);
486 }
487
488 void DispatcherThread::Run()
489 {
490         this->LockQueue();
491         while (!this->GetExitFlag())
492         {
493                 if (!Parent->qq.empty())
494                 {
495                         QQueueItem i = Parent->qq.front();
496                         i.c->lock.Lock();
497                         this->UnlockQueue();
498                         MySQLresult* res = i.c->DoBlockingQuery(i.query);
499                         i.c->lock.Unlock();
500
501                         /*
502                          * At this point, the main thread could be working on:
503                          *  Rehash - delete i.c out from under us. We don't care about that.
504                          *  UnloadModule - delete i.q and the qq item. Need to avoid reporting results.
505                          */
506
507                         this->LockQueue();
508                         if (!Parent->qq.empty() && Parent->qq.front().q == i.q)
509                         {
510                                 Parent->qq.pop_front();
511                                 Parent->rq.push_back(RQueueItem(i.q, res));
512                                 NotifyParent();
513                         }
514                         else
515                         {
516                                 // UnloadModule ate the query
517                                 delete res;
518                         }
519                 }
520                 else
521                 {
522                         /* We know the queue is empty, we can safely hang this thread until
523                          * something happens
524                          */
525                         this->WaitForQueue();
526                 }
527         }
528         this->UnlockQueue();
529 }
530
531 void DispatcherThread::OnNotify()
532 {
533         // this could unlock during the dispatch, but OnResult isn't expected to take that long
534         this->LockQueue();
535         for(ResultQueue::iterator i = Parent->rq.begin(); i != Parent->rq.end(); i++)
536         {
537                 MySQLresult* res = i->r;
538                 if (res->err.id == SQL_NO_ERROR)
539                         i->q->OnResult(*res);
540                 else
541                         i->q->OnError(res->err);
542                 delete i->q;
543                 delete i->r;
544         }
545         Parent->rq.clear();
546         this->UnlockQueue();
547 }
548
/* Register ModuleSQL as this file's module class (MODULE_INIT is presumably provided by inspircd.h). */
MODULE_INIT(ModuleSQL)