]> git.netwichtig.de Git - user/henk/code/inspircd.git/blob - src/modules/extra/m_mysql.cpp
Change Windows libraries to be dynamically linked
[user/henk/code/inspircd.git] / src / modules / extra / m_mysql.cpp
1 /*
2  * InspIRCd -- Internet Relay Chat Daemon
3  *
4  *   Copyright (C) 2009-2010 Daniel De Graaf <danieldg@inspircd.org>
5  *   Copyright (C) 2006-2007, 2009 Dennis Friis <peavey@inspircd.org>
6  *   Copyright (C) 2006-2009 Craig Edwards <craigedwards@brainbox.cc>
7  *   Copyright (C) 2008 Robin Burchell <robin+git@viroteck.net>
8  *
9  * This file is part of InspIRCd.  InspIRCd is free software: you can
10  * redistribute it and/or modify it under the terms of the GNU General Public
11  * License as published by the Free Software Foundation, version 2.
12  *
13  * This program is distributed in the hope that it will be useful, but WITHOUT
14  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
15  * FOR A PARTICULAR PURPOSE.  See the GNU General Public License for more
16  * details.
17  *
18  * You should have received a copy of the GNU General Public License
19  * along with this program.  If not, see <http://www.gnu.org/licenses/>.
20  */
21
22
23 /* Stop mysql wanting to use long long */
24 #define NO_CLIENT_LONG_LONG
25
26 #include "inspircd.h"
27 #include <mysql.h>
28 #include "sql.h"
29
30 #ifdef _WIN32
31 # pragma comment(lib, "libmysql.lib")
32 #endif
33
34 /* VERSION 3 API: With nonblocking (threaded) requests */
35
36 /* $ModDesc: SQL Service Provider module for all other m_sql* modules */
37 /* $CompileFlags: exec("mysql_config --include") */
38 /* $LinkerFlags: exec("mysql_config --libs_r") rpath("mysql_config --libs_r") */
39
/* THE NONBLOCKING MYSQL API!
 *
 * MySQL provides no nonblocking (asynchronous) API of its own, and its developers recommend
 * that instead, you should thread your program. This is what I've done here to allow for
 * asynchronous SQL requests via mysql. The way this works is as follows:
 *
 * The module spawns a thread via class Thread, and performs its mysql queries in this thread,
 * using a first-in, first-out queue. A mutex guards the queue so that the two threads never
 * modify it at the same time and crash the ircd. While the queue is empty, the worker thread
 * sleeps until a new request is submitted; otherwise it takes the request at the head of its
 * queue and processes it, blocking the worker thread but leaving the ircd thread to go about
 * its business as usual. During this period, the ircd thread is able to insert further
 * pending requests into the queue.
 *
 * Once the processing of a request is complete, it is moved from the incoming queue to
 * an outgoing queue as a 'response'. The worker thread then signals the ircd thread
 * (via a loopback socket) that a result is available.
 *
 * The ircd thread then locks the queue once more, takes every pending response off the
 * outgoing queue, and delivers each one to the module that issued the original query.
 *
 * XXX: You might be asking "why doesn't he just send the response from within the worker thread?"
 * The answer to this is simple. The majority of InspIRCd, and in fact most ircds, are not
 * threadsafe. This module is designed to be threadsafe and is careful with its use of threads;
 * however, if we were to call a module's OnResult from within a thread which was not the
 * one the module was originally instantiated upon, there is a chance of all hell breaking loose
 * if a module is ever put in a re-entrant state (stack corruption could occur, crashes, data
 * corruption, and worse, so DON'T think about it until the day comes when InspIRCd is 100%
 * guaranteed threadsafe!)
 *
 * For a diagram of this system please see http://wiki.inspircd.org/Mysql2
 */
73
74 class SQLConnection;
75 class MySQLresult;
76 class DispatcherThread;
77
78 struct QQueueItem
79 {
80         SQLQuery* q;
81         std::string query;
82         SQLConnection* c;
83         QQueueItem(SQLQuery* Q, const std::string& S, SQLConnection* C) : q(Q), query(S), c(C) {}
84 };
85
86 struct RQueueItem
87 {
88         SQLQuery* q;
89         MySQLresult* r;
90         RQueueItem(SQLQuery* Q, MySQLresult* R) : q(Q), r(R) {}
91 };
92
93 typedef std::map<std::string, SQLConnection*> ConnMap;
94 typedef std::deque<QQueueItem> QueryQueue;
95 typedef std::deque<RQueueItem> ResultQueue;
96
97 /** MySQL module
98  *  */
99 class ModuleSQL : public Module
100 {
101  public:
102         DispatcherThread* Dispatcher;
103         QueryQueue qq;       // MUST HOLD MUTEX
104         ResultQueue rq;      // MUST HOLD MUTEX
105         ConnMap connections; // main thread only
106
107         ModuleSQL();
108         void init();
109         ~ModuleSQL();
110         void OnRehash(User* user);
111         void OnUnloadModule(Module* mod);
112         Version GetVersion();
113 };
114
115 class DispatcherThread : public SocketThread
116 {
117  private:
118         ModuleSQL* const Parent;
119  public:
120         DispatcherThread(ModuleSQL* CreatorModule) : Parent(CreatorModule) { }
121         ~DispatcherThread() { }
122         virtual void Run();
123         virtual void OnNotify();
124 };
125
126 #if !defined(MYSQL_VERSION_ID) || MYSQL_VERSION_ID<32224
127 #define mysql_field_count mysql_num_fields
128 #endif
129
130 /** Represents a mysql result set
131  */
132 class MySQLresult : public SQLResult
133 {
134  public:
135         SQLerror err;
136         int currentrow;
137         int rows;
138         std::vector<std::string> colnames;
139         std::vector<SQLEntries> fieldlists;
140
141         MySQLresult(MYSQL_RES* res, int affected_rows) : err(SQL_NO_ERROR), currentrow(0), rows(0)
142         {
143                 if (affected_rows >= 1)
144                 {
145                         rows = affected_rows;
146                         fieldlists.resize(rows);
147                 }
148                 unsigned int field_count = 0;
149                 if (res)
150                 {
151                         MYSQL_ROW row;
152                         int n = 0;
153                         while ((row = mysql_fetch_row(res)))
154                         {
155                                 if (fieldlists.size() < (unsigned int)rows+1)
156                                 {
157                                         fieldlists.resize(fieldlists.size()+1);
158                                 }
159                                 field_count = 0;
160                                 MYSQL_FIELD *fields = mysql_fetch_fields(res);
161                                 if(mysql_num_fields(res) == 0)
162                                         break;
163                                 if (fields && mysql_num_fields(res))
164                                 {
165                                         colnames.clear();
166                                         while (field_count < mysql_num_fields(res))
167                                         {
168                                                 std::string a = (fields[field_count].name ? fields[field_count].name : "");
169                                                 if (row[field_count])
170                                                         fieldlists[n].push_back(SQLEntry(row[field_count]));
171                                                 else
172                                                         fieldlists[n].push_back(SQLEntry());
173                                                 colnames.push_back(a);
174                                                 field_count++;
175                                         }
176                                         n++;
177                                 }
178                                 rows++;
179                         }
180                         mysql_free_result(res);
181                 }
182         }
183
184         MySQLresult(SQLerror& e) : err(e)
185         {
186
187         }
188
189         ~MySQLresult()
190         {
191         }
192
193         virtual int Rows()
194         {
195                 return rows;
196         }
197
198         virtual void GetCols(std::vector<std::string>& result)
199         {
200                 result.assign(colnames.begin(), colnames.end());
201         }
202
203         virtual SQLEntry GetValue(int row, int column)
204         {
205                 if ((row >= 0) && (row < rows) && (column >= 0) && (column < (int)fieldlists[row].size()))
206                 {
207                         return fieldlists[row][column];
208                 }
209                 return SQLEntry();
210         }
211
212         virtual bool GetRow(SQLEntries& result)
213         {
214                 if (currentrow < rows)
215                 {
216                         result.assign(fieldlists[currentrow].begin(), fieldlists[currentrow].end());
217                         currentrow++;
218                         return true;
219                 }
220                 else
221                 {
222                         result.clear();
223                         return false;
224                 }
225         }
226 };
227
228 /** Represents a connection to a mysql database
229  */
230 class SQLConnection : public SQLProvider
231 {
232  public:
233         reference<ConfigTag> config;
234         MYSQL *connection;
235         Mutex lock;
236
237         // This constructor creates an SQLConnection object with the given credentials, but does not connect yet.
238         SQLConnection(Module* p, ConfigTag* tag) : SQLProvider(p, "SQL/" + tag->getString("id")),
239                 config(tag), connection(NULL)
240         {
241         }
242
243         ~SQLConnection()
244         {
245                 Close();
246         }
247
248         // This method connects to the database using the credentials supplied to the constructor, and returns
249         // true upon success.
250         bool Connect()
251         {
252                 unsigned int timeout = 1;
253                 connection = mysql_init(connection);
254                 mysql_options(connection,MYSQL_OPT_CONNECT_TIMEOUT,(char*)&timeout);
255                 std::string host = config->getString("host");
256                 std::string user = config->getString("user");
257                 std::string pass = config->getString("pass");
258                 std::string dbname = config->getString("name");
259                 int port = config->getInt("port");
260                 bool rv = mysql_real_connect(connection, host.c_str(), user.c_str(), pass.c_str(), dbname.c_str(), port, NULL, 0);
261                 if (!rv)
262                         return rv;
263                 std::string initquery;
264                 if (config->readString("initialquery", initquery))
265                 {
266                         mysql_query(connection,initquery.c_str());
267                 }
268                 return true;
269         }
270
271         ModuleSQL* Parent()
272         {
273                 return (ModuleSQL*)(Module*)creator;
274         }
275
276         MySQLresult* DoBlockingQuery(const std::string& query)
277         {
278
279                 /* Parse the command string and dispatch it to mysql */
280                 if (CheckConnection() && !mysql_real_query(connection, query.data(), query.length()))
281                 {
282                         /* Successfull query */
283                         MYSQL_RES* res = mysql_use_result(connection);
284                         unsigned long rows = mysql_affected_rows(connection);
285                         return new MySQLresult(res, rows);
286                 }
287                 else
288                 {
289                         /* XXX: See /usr/include/mysql/mysqld_error.h for a list of
290                          * possible error numbers and error messages */
291                         SQLerror e(SQL_QREPLY_FAIL, ConvToStr(mysql_errno(connection)) + ": " + mysql_error(connection));
292                         return new MySQLresult(e);
293                 }
294         }
295
296         bool CheckConnection()
297         {
298                 if (!connection || mysql_ping(connection) != 0)
299                         return Connect();
300                 return true;
301         }
302
303         std::string GetError()
304         {
305                 return mysql_error(connection);
306         }
307
308         void Close()
309         {
310                 mysql_close(connection);
311         }
312
313         void submit(SQLQuery* q, const std::string& qs)
314         {
315                 Parent()->Dispatcher->LockQueue();
316                 Parent()->qq.push_back(QQueueItem(q, qs, this));
317                 Parent()->Dispatcher->UnlockQueueWakeup();
318         }
319
320         void submit(SQLQuery* call, const std::string& q, const ParamL& p)
321         {
322                 std::string res;
323                 unsigned int param = 0;
324                 for(std::string::size_type i = 0; i < q.length(); i++)
325                 {
326                         if (q[i] != '?')
327                                 res.push_back(q[i]);
328                         else
329                         {
330                                 if (param < p.size())
331                                 {
332                                         std::string parm = p[param++];
333                                         // In the worst case, each character may need to be encoded as using two bytes,
334                                         // and one byte is the terminating null
335                                         std::vector<char> buffer(parm.length() * 2 + 1);
336
337                                         // The return value of mysql_escape_string() is the length of the encoded string,
338                                         // not including the terminating null
339                                         unsigned long escapedsize = mysql_escape_string(&buffer[0], parm.c_str(), parm.length());
340 //                                      mysql_real_escape_string(connection, queryend, paramscopy[paramnum].c_str(), paramscopy[paramnum].length());
341                                         res.append(&buffer[0], escapedsize);
342                                 }
343                         }
344                 }
345                 submit(call, res);
346         }
347
348         void submit(SQLQuery* call, const std::string& q, const ParamM& p)
349         {
350                 std::string res;
351                 for(std::string::size_type i = 0; i < q.length(); i++)
352                 {
353                         if (q[i] != '$')
354                                 res.push_back(q[i]);
355                         else
356                         {
357                                 std::string field;
358                                 i++;
359                                 while (i < q.length() && isalnum(q[i]))
360                                         field.push_back(q[i++]);
361                                 i--;
362
363                                 ParamM::const_iterator it = p.find(field);
364                                 if (it != p.end())
365                                 {
366                                         std::string parm = it->second;
367                                         // NOTE: See above
368                                         std::vector<char> buffer(parm.length() * 2 + 1);
369                                         unsigned long escapedsize = mysql_escape_string(&buffer[0], parm.c_str(), parm.length());
370                                         res.append(&buffer[0], escapedsize);
371                                 }
372                         }
373                 }
374                 submit(call, res);
375         }
376 };
377
378 ModuleSQL::ModuleSQL()
379 {
380         Dispatcher = NULL;
381 }
382
383 void ModuleSQL::init()
384 {
385         Dispatcher = new DispatcherThread(this);
386         ServerInstance->Threads->Start(Dispatcher);
387
388         Implementation eventlist[] = { I_OnRehash, I_OnUnloadModule };
389         ServerInstance->Modules->Attach(eventlist, this, sizeof(eventlist)/sizeof(Implementation));
390
391         OnRehash(NULL);
392 }
393
394 ModuleSQL::~ModuleSQL()
395 {
396         if (Dispatcher)
397         {
398                 Dispatcher->join();
399                 Dispatcher->OnNotify();
400                 delete Dispatcher;
401         }
402         for(ConnMap::iterator i = connections.begin(); i != connections.end(); i++)
403         {
404                 delete i->second;
405         }
406 }
407
408 void ModuleSQL::OnRehash(User* user)
409 {
410         ConnMap conns;
411         ConfigTagList tags = ServerInstance->Config->ConfTags("database");
412         for(ConfigIter i = tags.first; i != tags.second; i++)
413         {
414                 if (i->second->getString("module", "mysql") != "mysql")
415                         continue;
416                 std::string id = i->second->getString("id");
417                 ConnMap::iterator curr = connections.find(id);
418                 if (curr == connections.end())
419                 {
420                         SQLConnection* conn = new SQLConnection(this, i->second);
421                         conns.insert(std::make_pair(id, conn));
422                         ServerInstance->Modules->AddService(*conn);
423                 }
424                 else
425                 {
426                         conns.insert(*curr);
427                         connections.erase(curr);
428                 }
429         }
430
431         // now clean up the deleted databases
432         Dispatcher->LockQueue();
433         SQLerror err(SQL_BAD_DBID);
434         for(ConnMap::iterator i = connections.begin(); i != connections.end(); i++)
435         {
436                 ServerInstance->Modules->DelService(*i->second);
437                 // it might be running a query on this database. Wait for that to complete
438                 i->second->lock.Lock();
439                 i->second->lock.Unlock();
440                 // now remove all active queries to this DB
441                 for (size_t j = qq.size(); j > 0; j--)
442                 {
443                         size_t k = j - 1;
444                         if (qq[k].c == i->second)
445                         {
446                                 qq[k].q->OnError(err);
447                                 delete qq[k].q;
448                                 qq.erase(qq.begin() + k);
449                         }
450                 }
451                 // finally, nuke the connection
452                 delete i->second;
453         }
454         Dispatcher->UnlockQueue();
455         connections.swap(conns);
456 }
457
458 void ModuleSQL::OnUnloadModule(Module* mod)
459 {
460         SQLerror err(SQL_BAD_DBID);
461         Dispatcher->LockQueue();
462         unsigned int i = qq.size();
463         while (i > 0)
464         {
465                 i--;
466                 if (qq[i].q->creator == mod)
467                 {
468                         if (i == 0)
469                         {
470                                 // need to wait until the query is done
471                                 // (the result will be discarded)
472                                 qq[i].c->lock.Lock();
473                                 qq[i].c->lock.Unlock();
474                         }
475                         qq[i].q->OnError(err);
476                         delete qq[i].q;
477                         qq.erase(qq.begin() + i);
478                 }
479         }
480         Dispatcher->UnlockQueue();
481         // clean up any result queue entries
482         Dispatcher->OnNotify();
483 }
484
485 Version ModuleSQL::GetVersion()
486 {
487         return Version("MySQL support", VF_VENDOR);
488 }
489
490 void DispatcherThread::Run()
491 {
492         this->LockQueue();
493         while (!this->GetExitFlag())
494         {
495                 if (!Parent->qq.empty())
496                 {
497                         QQueueItem i = Parent->qq.front();
498                         i.c->lock.Lock();
499                         this->UnlockQueue();
500                         MySQLresult* res = i.c->DoBlockingQuery(i.query);
501                         i.c->lock.Unlock();
502
503                         /*
504                          * At this point, the main thread could be working on:
505                          *  Rehash - delete i.c out from under us. We don't care about that.
506                          *  UnloadModule - delete i.q and the qq item. Need to avoid reporting results.
507                          */
508
509                         this->LockQueue();
510                         if (!Parent->qq.empty() && Parent->qq.front().q == i.q)
511                         {
512                                 Parent->qq.pop_front();
513                                 Parent->rq.push_back(RQueueItem(i.q, res));
514                                 NotifyParent();
515                         }
516                         else
517                         {
518                                 // UnloadModule ate the query
519                                 delete res;
520                         }
521                 }
522                 else
523                 {
524                         /* We know the queue is empty, we can safely hang this thread until
525                          * something happens
526                          */
527                         this->WaitForQueue();
528                 }
529         }
530         this->UnlockQueue();
531 }
532
533 void DispatcherThread::OnNotify()
534 {
535         // this could unlock during the dispatch, but OnResult isn't expected to take that long
536         this->LockQueue();
537         for(ResultQueue::iterator i = Parent->rq.begin(); i != Parent->rq.end(); i++)
538         {
539                 MySQLresult* res = i->r;
540                 if (res->err.id == SQL_NO_ERROR)
541                         i->q->OnResult(*res);
542                 else
543                         i->q->OnError(res->err);
544                 delete i->q;
545                 delete i->r;
546         }
547         Parent->rq.clear();
548         this->UnlockQueue();
549 }
550
551 MODULE_INIT(ModuleSQL)