]> git.netwichtig.de Git - user/henk/code/inspircd.git/blob - src/modules/extra/m_mysql.cpp
Convert MySQL to SQLv3
[user/henk/code/inspircd.git] / src / modules / extra / m_mysql.cpp
1 /*       +------------------------------------+
2  *       | Inspire Internet Relay Chat Daemon |
3  *       +------------------------------------+
4  *
5  *  InspIRCd: (C) 2002-2010 InspIRCd Development Team
6  * See: http://wiki.inspircd.org/Credits
7  *
8  * This program is free but copyrighted software; see
9  *          the file COPYING for details.
10  *
11  * ---------------------------------------------------
12  */
13
14 /* Stop mysql wanting to use long long */
15 #define NO_CLIENT_LONG_LONG
16
17 #include "inspircd.h"
18 #include <mysql.h>
19 #include "sql.h"
20
21 #ifdef WINDOWS
22 #pragma comment(lib, "mysqlclient.lib")
23 #endif
24
25 /* VERSION 3 API: With nonblocking (threaded) requests */
26
27 /* $ModDesc: SQL Service Provider module for all other m_sql* modules */
28 /* $CompileFlags: exec("mysql_config --include") */
29 /* $LinkerFlags: exec("mysql_config --libs_r") rpath("mysql_config --libs_r") */
30 /* $ModDep: m_sqlv2.h */
31
32 /* THE NONBLOCKING MYSQL API!
33  *
34  * MySQL provides no nonblocking (asynchronous) API of its own, and its developers recommend
35  * that instead, you should thread your program. This is what I've done here to allow for
36  * asynchronous SQL requests via mysql. The way this works is as follows:
37  *
38  * The module spawns a thread via class Thread, and performs its mysql queries in this thread,
39  * using a queue with priorities. There is a mutex on either end which prevents two threads
40  * adjusting the queue at the same time, and crashing the ircd. Whenever a request is submitted, the
41  * worker thread is woken up, and checks if there is a request at the head of its queue.
42  * If there is, it processes this request, blocking the worker thread but leaving the ircd
43  * thread to go about its business as usual. During this period, the ircd thread is able
44  * to insert further pending requests into the queue.
45  *
46  * Once the processing of a request is complete, it is removed from the incoming queue to
47  * an outgoing queue, and initialized as a 'response'. The worker thread then signals the
48  * ircd thread (via a loopback socket) of the fact a result is available, by sending the
49  * connection ID through the connection.
50  *
51  * The ircd thread then mutexes the queue once more, reads the outbound response off the head
52  * of the queue, and sends it on its way to the original calling module.
53  *
54  * XXX: You might be asking "why doesn't he just send the response from within the worker thread?"
55  * The answer to this is simple. The majority of InspIRCd, and in fact most ircd's are not
56  * threadsafe. This module is designed to be threadsafe and is careful with its use of threads,
57  * however, if we were to call a module's OnRequest even from within a thread which was not the
58  * one the module was originally instantiated upon, there is a chance of all hell breaking loose
59  * if a module is ever put in a re-entrant state (stack corruption could occur, crashes, data
60  * corruption, and worse, so DON'T think about it until the day comes when InspIRCd is 100%
61  * guaranteed threadsafe!)
62  *
63  * For a diagram of this system please see http://wiki.inspircd.org/Mysql2
64  */
65
66 class SQLConnection;
67 class MySQLresult;
68 class DispatcherThread;
69
70 struct QueueItem
71 {
72         SQLQuery* q;
73         SQLConnection* c;
74         QueueItem(SQLQuery* Q, SQLConnection* C) : q(Q), c(C) {}
75 };
76
77 typedef std::map<std::string, SQLConnection*> ConnMap;
78 typedef std::deque<QueueItem> QueryQueue;
79 typedef std::deque<MySQLresult*> ResultQueue;
80
81 /** MySQL module
82  *  */
83 class ModuleSQL : public Module
84 {
85  public:
86         DispatcherThread* Dispatcher;
87         QueryQueue qq;
88         ResultQueue rq;
89         ConnMap connections;
90
91         ModuleSQL();
92         void init();
93         ~ModuleSQL();
94         void OnRehash(User* user);
95         Version GetVersion();
96 };
97
98 class DispatcherThread : public SocketThread
99 {
100  private:
101         ModuleSQL* const Parent;
102  public:
103         DispatcherThread(ModuleSQL* CreatorModule) : Parent(CreatorModule) { }
104         ~DispatcherThread() { }
105         virtual void Run();
106         virtual void OnNotify();
107 };
108
109 #if !defined(MYSQL_VERSION_ID) || MYSQL_VERSION_ID<32224
110 #define mysql_field_count mysql_num_fields
111 #endif
112
113 /** Represents a mysql result set
114  */
115 class MySQLresult : public SQLResult
116 {
117  public:
118         SQLQuery* query;
119         SQLerror err;
120         int currentrow;
121         int rows;
122         std::vector<std::string> colnames;
123         std::vector<SQLEntries> fieldlists;
124
125         MySQLresult(SQLQuery* q, MYSQL_RES* res, int affected_rows) : query(q), err(SQL_NO_ERROR), currentrow(0), rows(0)
126         {
127                 if (affected_rows >= 1)
128                 {
129                         rows = affected_rows;
130                         fieldlists.resize(rows);
131                 }
132                 unsigned int field_count = 0;
133                 if (res)
134                 {
135                         MYSQL_ROW row;
136                         int n = 0;
137                         while ((row = mysql_fetch_row(res)))
138                         {
139                                 if (fieldlists.size() < (unsigned int)rows+1)
140                                 {
141                                         fieldlists.resize(fieldlists.size()+1);
142                                 }
143                                 field_count = 0;
144                                 MYSQL_FIELD *fields = mysql_fetch_fields(res);
145                                 if(mysql_num_fields(res) == 0)
146                                         break;
147                                 if (fields && mysql_num_fields(res))
148                                 {
149                                         colnames.clear();
150                                         while (field_count < mysql_num_fields(res))
151                                         {
152                                                 std::string a = (fields[field_count].name ? fields[field_count].name : "");
153                                                 if (row[field_count])
154                                                         fieldlists[n].push_back(SQLEntry(row[field_count]));
155                                                 else
156                                                         fieldlists[n].push_back(SQLEntry());
157                                                 colnames.push_back(a);
158                                                 field_count++;
159                                         }
160                                         n++;
161                                 }
162                                 rows++;
163                         }
164                         mysql_free_result(res);
165                         res = NULL;
166                 }
167         }
168
169         MySQLresult(SQLQuery* q, SQLerror& e) : query(q), err(e)
170         {
171
172         }
173
174         ~MySQLresult()
175         {
176         }
177
178         virtual int Rows()
179         {
180                 return rows;
181         }
182
183         virtual void GetCols(std::vector<std::string>& result)
184         {
185                 result.assign(colnames.begin(), colnames.end());
186         }
187
188         virtual SQLEntry GetValue(int row, int column)
189         {
190                 if ((row >= 0) && (row < rows) && (column >= 0) && (column < (int)fieldlists[row].size()))
191                 {
192                         return fieldlists[row][column];
193                 }
194                 return SQLEntry();
195         }
196
197         virtual bool GetRow(SQLEntries& result)
198         {
199                 if (currentrow < rows)
200                 {
201                         result.assign(fieldlists[currentrow].begin(), fieldlists[currentrow].end());
202                         currentrow++;
203                         return true;
204                 }
205                 else
206                 {
207                         result.clear();
208                         return false;
209                 }
210         }
211 };
212
213 /** Represents a connection to a mysql database
214  */
215 class SQLConnection : public SQLProvider
216 {
217  public:
218         reference<ConfigTag> config;
219         MYSQL *connection;
220         bool active;
221
222         // This constructor creates an SQLConnection object with the given credentials, but does not connect yet.
223         SQLConnection(Module* p, ConfigTag* tag) : SQLProvider(p, "SQL/" + tag->getString("id")),
224                 config(tag), active(false)
225         {
226         }
227
228         ~SQLConnection()
229         {
230                 Close();
231         }
232
233         // This method connects to the database using the credentials supplied to the constructor, and returns
234         // true upon success.
235         bool Connect()
236         {
237                 unsigned int timeout = 1;
238                 connection = mysql_init(connection);
239                 mysql_options(connection,MYSQL_OPT_CONNECT_TIMEOUT,(char*)&timeout);
240                 std::string host = config->getString("host");
241                 std::string user = config->getString("user");
242                 std::string pass = config->getString("pass");
243                 std::string dbname = config->getString("name");
244                 int port = config->getInt("port");
245                 bool rv = mysql_real_connect(connection, host.c_str(), user.c_str(), pass.c_str(), dbname.c_str(), port, NULL, 0);
246                 if (!rv)
247                         return rv;
248                 std::string initquery;
249                 if (config->readString("initialquery", initquery))
250                 {
251                         mysql_query(connection,initquery.c_str());
252                 }
253                 return true;
254         }
255
256         virtual std::string FormatQuery(const std::string& q, const ParamL& p)
257         {
258                 std::string res;
259                 unsigned int param = 0;
260                 for(std::string::size_type i = 0; i < q.length(); i++)
261                 {
262                         if (q[i] != '?')
263                                 res.push_back(q[i]);
264                         else
265                         {
266                                 // TODO numbered parameter support ('?1')
267                                 if (param < p.size())
268                                 {
269                                         std::string parm = p[param++];
270                                         char buffer[MAXBUF];
271                                         mysql_escape_string(buffer, parm.c_str(), parm.length());
272 //                                      mysql_real_escape_string(connection, queryend, paramscopy[paramnum].c_str(), paramscopy[paramnum].length());
273                                         res.append(buffer);
274                                 }
275                         }
276                 }
277                 return res;
278         }
279
280         std::string FormatQuery(const std::string& q, const ParamM& p)
281         {
282                 std::string res;
283                 for(std::string::size_type i = 0; i < q.length(); i++)
284                 {
285                         if (q[i] != '$')
286                                 res.push_back(q[i]);
287                         else
288                         {
289                                 std::string field;
290                                 i++;
291                                 while (i < q.length() && isalpha(q[i]))
292                                         field.push_back(q[i++]);
293                                 i--;
294
295                                 ParamM::const_iterator it = p.find(field);
296                                 if (it != p.end())
297                                 {
298                                         std::string parm = it->second;
299                                         char buffer[MAXBUF];
300                                         mysql_escape_string(buffer, parm.c_str(), parm.length());
301                                         res.append(buffer);
302                                 }
303                         }
304                 }
305                 return res;
306         }
307
308         ModuleSQL* Parent()
309         {
310                 return (ModuleSQL*)(Module*)creator;
311         }
312
313         void DoBlockingQuery(SQLQuery* req)
314         {
315                 /* Parse the command string and dispatch it to mysql */
316                 if (CheckConnection() && !mysql_real_query(connection, req->query.data(), req->query.length()))
317                 {
318                         /* Successfull query */
319                         MYSQL_RES* res = mysql_use_result(connection);
320                         unsigned long rows = mysql_affected_rows(connection);
321                         MySQLresult* r = new MySQLresult(req, res, rows);
322                         Parent()->Dispatcher->LockQueue();
323                         Parent()->rq.push_back(r);
324                         Parent()->Dispatcher->NotifyParent();
325                         Parent()->Dispatcher->UnlockQueue();
326                 }
327                 else
328                 {
329                         /* XXX: See /usr/include/mysql/mysqld_error.h for a list of
330                          * possible error numbers and error messages */
331                         SQLerror e(SQL_QREPLY_FAIL, ConvToStr(mysql_errno(connection)) + std::string(": ") + mysql_error(connection));
332                         MySQLresult* r = new MySQLresult(req, e);
333                         Parent()->Dispatcher->LockQueue();
334                         Parent()->rq.push_back(r);
335                         Parent()->Dispatcher->NotifyParent();
336                         Parent()->Dispatcher->UnlockQueue();
337                 }
338         }
339
340         bool CheckConnection()
341         {
342                 if (mysql_ping(connection) != 0)
343                 {
344                         return Connect();
345                 }
346                 else return true;
347         }
348
349         std::string GetError()
350         {
351                 return mysql_error(connection);
352         }
353
354         void Close()
355         {
356                 mysql_close(connection);
357         }
358
359         void submit(SQLQuery* q)
360         {
361                 Parent()->Dispatcher->LockQueue();
362                 Parent()->qq.push_back(QueueItem(q, this));
363                 Parent()->Dispatcher->UnlockQueueWakeup();
364         }
365 };
366
367 ModuleSQL::ModuleSQL()
368 {
369         Dispatcher = NULL;
370 }
371
372 void ModuleSQL::init()
373 {
374         Dispatcher = new DispatcherThread(this);
375         ServerInstance->Threads->Start(Dispatcher);
376
377         Implementation eventlist[] = { I_OnRehash };
378         ServerInstance->Modules->Attach(eventlist, this, 1);
379 }
380
381 ModuleSQL::~ModuleSQL()
382 {
383         if (Dispatcher)
384         {
385                 Dispatcher->join();
386                 Dispatcher->OnNotify();
387                 delete Dispatcher;
388         }
389         for(ConnMap::iterator i = connections.begin(); i != connections.end(); i++)
390         {
391                 delete i->second;
392         }
393 }
394
395 void ModuleSQL::OnRehash(User* user)
396 {
397         Dispatcher->LockQueue();
398         ConnMap conns;
399         ConfigTagList tags = ServerInstance->Config->ConfTags("database");
400         for(ConfigIter i = tags.first; i != tags.second; i++)
401         {
402                 if (i->second->getString("module", "mysql") != "mysql")
403                         continue;
404                 std::string id = i->second->getString("id");
405                 ConnMap::iterator curr = connections.find(id);
406                 if (curr == connections.end())
407                 {
408                         SQLConnection* conn = new SQLConnection(this, i->second);
409                         conns.insert(std::make_pair(id, conn));
410                         ServerInstance->Modules->AddService(*conn);
411                 }
412                 else
413                 {
414                         conns.insert(*curr);
415                         connections.erase(curr);
416                 }
417         }
418         for(ConnMap::iterator i = connections.begin(); i != connections.end(); i++)
419         {
420                 if (i->second->active)
421                 {
422                         // can't delete it now. Next rehash will try to kill it again
423                         conns.insert(*i);
424                 }
425                 else
426                 {
427                         ServerInstance->Modules->DelService(*i->second);
428                         delete i->second;
429                 }
430         }
431         connections.swap(conns);
432         Dispatcher->UnlockQueue();
433 }
434
435 Version ModuleSQL::GetVersion()
436 {
437         return Version("MySQL support", VF_VENDOR);
438 }
439
440 void DispatcherThread::Run()
441 {
442         this->LockQueue();
443         while (!this->GetExitFlag())
444         {
445                 if (!Parent->qq.empty())
446                 {
447                         QueueItem i = Parent->qq.front();
448                         Parent->qq.pop_front();
449                         i.c->active = true;
450                         this->UnlockQueue();
451                         i.c->DoBlockingQuery(i.q);
452                         this->LockQueue();
453                         i.c->active = false;
454                 }
455                 else
456                 {
457                         /* We know the queue is empty, we can safely hang this thread until
458                          * something happens
459                          */
460                         this->WaitForQueue();
461                 }
462         }
463         this->UnlockQueue();
464 }
465
466 void DispatcherThread::OnNotify()
467 {
468         // this could unlock during the dispatch, but OnResult isn't expected to take that long
469         this->LockQueue();
470         for(ResultQueue::iterator i = Parent->rq.begin(); i != Parent->rq.end(); i++)
471         {
472                 MySQLresult* res = *i;
473                 if (res->err.id == SQL_NO_ERROR)
474                         res->query->OnResult(*res);
475                 else
476                         res->query->OnError(res->err);
477                 delete res->query;
478                 delete res;
479         }
480         Parent->rq.clear();
481         this->UnlockQueue();
482 }
483
484 MODULE_INIT(ModuleSQL)