]> git.netwichtig.de Git - user/henk/code/inspircd.git/blob - src/modules/extra/m_sqlite3.cpp
b0a0afbe5787b2f6ed186f3d55e8125051fc2503
[user/henk/code/inspircd.git] / src / modules / extra / m_sqlite3.cpp
1 /*       +------------------------------------+
2  *       | Inspire Internet Relay Chat Daemon |
3  *       +------------------------------------+
4  *
5  *  InspIRCd: (C) 2002-2007 InspIRCd Development Team
6  * See: http://www.inspircd.org/wiki/index.php/Credits
7  *
8  * This program is free but copyrighted software; see
9  *            the file COPYING for details.
10  *
11  * ---------------------------------------------------
12  */
13
14 #include <string>
15 #include <deque>
16 #include <map>
17 #include <sqlite3.h>
18
19 #include "users.h"
20 #include "channels.h"
21 #include "modules.h"
22 #include "inspircd.h"
23 #include "configreader.h"
24
25 #include "m_sqlv2.h"
26
27 /* $ModDesc: sqlite3 provider */
28 /* $CompileFlags: pkgconfincludes("sqlite3","/sqlite3.h","") */
29 /* $LinkerFlags: pkgconflibs("sqlite3","/libsqlite3.so","-lsqlite3") */
30 /* $ModDep: m_sqlv2.h */
31
32
33 class SQLConn;
34 class SQLite3Result;
35
36 typedef std::map<std::string, SQLConn*> ConnMap;
37 typedef std::deque<classbase*> paramlist;
38 typedef std::deque<SQLite3Result*> ResultQueue;
39
40 class SQLite3Result : public SQLresult
41 {
42   private:
43         int currentrow;
44         int rows;
45         int cols;
46
47         std::vector<std::string> colnames;
48         std::vector<SQLfieldList> fieldlists;
49         SQLfieldList emptyfieldlist;
50
51         SQLfieldList* fieldlist;
52         SQLfieldMap* fieldmap;
53
54   public:
55         SQLite3Result(Module* self, Module* to, unsigned int id)
56         : SQLresult(self, to, id), currentrow(0), rows(0), cols(0)
57         {
58         }
59
60         ~SQLite3Result()
61         {
62         }
63
64         void AddRow(int colsnum, char **data, char **colname)
65         {
66                 colnames.clear();
67                 cols = colsnum;
68                 for (int i = 0; i < colsnum; i++)
69                 {
70                         fieldlists.resize(fieldlists.size()+1);
71                         colnames.push_back(colname[i]);
72                         SQLfield sf(data[i] ? data[i] : "", data[i] ? false : true);
73                         fieldlists[rows].push_back(sf);\r
74                 }
75                 rows++;
76         }
77
78         virtual int Rows()
79         {
80                 return rows;
81         }
82
83         virtual int Cols()
84         {
85                 return cols;
86         }
87
88         virtual std::string ColName(int column)
89         {
90                 if (column < (int)colnames.size())
91                 {
92                         return colnames[column];
93                 }
94                 else
95                 {
96                         throw SQLbadColName();
97                 }
98                 return "";
99         }
100
101         virtual int ColNum(const std::string &column)
102         {
103                 for (unsigned int i = 0; i < colnames.size(); i++)
104                 {
105                         if (column == colnames[i])
106                                 return i;
107                 }
108                 throw SQLbadColName();
109                 return 0;
110         }
111
112         virtual SQLfield GetValue(int row, int column)
113         {
114                 if ((row >= 0) && (row < rows) && (column >= 0) && (column < Cols()))
115                 {
116                         return fieldlists[row][column];
117                 }
118
119                 throw SQLbadColName();
120
121                 /* XXX: We never actually get here because of the throw */
122                 return SQLfield("",true);
123         }
124
125         virtual SQLfieldList& GetRow()
126         {
127                 if (currentrow < rows)
128                         return fieldlists[currentrow];
129                 else
130                         return emptyfieldlist;
131         }
132
133         virtual SQLfieldMap& GetRowMap()
134         {
135                 /* In an effort to reduce overhead we don't actually allocate the map
136                  * until the first time it's needed...so...
137                  */
138                 if(fieldmap)
139                 {
140                         fieldmap->clear();
141                 }
142                 else
143                 {
144                         fieldmap = new SQLfieldMap;
145                 }
146
147                 if (currentrow < rows)
148                 {
149                         for (int i = 0; i < Cols(); i++)
150                         {
151                                 fieldmap->insert(std::make_pair(ColName(i), GetValue(currentrow, i)));
152                         }
153                         currentrow++;
154                 }
155
156                 return *fieldmap;
157         }
158
159         virtual SQLfieldList* GetRowPtr()
160         {
161                 fieldlist = new SQLfieldList();
162
163                 if (currentrow < rows)
164                 {
165                         for (int i = 0; i < Rows(); i++)
166                         {
167                                 fieldlist->push_back(fieldlists[currentrow][i]);
168                         }
169                         currentrow++;
170                 }
171                 return fieldlist;
172         }
173
174         virtual SQLfieldMap* GetRowMapPtr()
175         {
176                 fieldmap = new SQLfieldMap();
177
178                 if (currentrow < rows)
179                 {
180                         for (int i = 0; i < Cols(); i++)
181                         {
182                                 fieldmap->insert(std::make_pair(colnames[i],GetValue(currentrow, i)));
183                         }
184                         currentrow++;
185                 }
186
187                 return fieldmap;
188         }
189
190         virtual void Free(SQLfieldMap* fm)
191         {
192                 delete fm;
193         }
194
195         virtual void Free(SQLfieldList* fl)
196         {
197                 delete fl;
198         }
199
200
201 };
202
203 class SQLConn : public classbase
204 {
205   private:
206         ResultQueue results;
207         InspIRCd* Instance;
208         Module* mod;
209         SQLhost host;
210         sqlite3* conn;
211
212   public:
213         SQLConn(InspIRCd* SI, Module* m, const SQLhost& hi)
214         : Instance(SI), mod(m), host(hi)
215         {
216                 int result;
217                 if ((result = OpenDB()) == SQLITE_OK)
218                 {
219                         Instance->Log(DEBUG, "Opened sqlite DB: " + host.host);
220                 }
221                 else
222                 {
223                         Instance->Log(DEFAULT, "WARNING: Could not open DB with id: " + host.id);
224                         CloseDB();
225                 }
226         }
227
228         SQLerror Query(SQLrequest &req)
229         {
230                 /* Pointer to the buffer we screw around with substitution in */
231                 char* query;
232
233                 /* Pointer to the current end of query, where we append new stuff */
234                 char* queryend;
235
236                 /* Total length of the unescaped parameters */
237                 unsigned long paramlen;
238
239                 /* Total length of query, used for binary-safety in mysql_real_query */
240                 unsigned long querylength = 0;
241
242                 paramlen = 0;
243                 for(ParamL::iterator i = req.query.p.begin(); i != req.query.p.end(); i++)
244                 {
245                         paramlen += i->size();
246                 }
247
248                 /* To avoid a lot of allocations, allocate enough memory for the biggest the escaped query could possibly be.
249                  * sizeofquery + (totalparamlength*2) + 1
250                  *
251                  * The +1 is for null-terminating the string for mysql_real_escape_string
252                  */
253
254                 query = new char[req.query.q.length() + (paramlen*2) + 1];
255                 queryend = query;
256
257                 for(unsigned long i = 0; i < req.query.q.length(); i++)
258                 {
259                         if(req.query.q[i] == '?')
260                         {
261                                 if(req.query.p.size())
262                                 {
263                                         char* escaped;
264                                         escaped = sqlite3_mprintf("%q", req.query.p.front().c_str());
265                                         for (char* n = escaped; *n; n++)
266                                         {
267                                                 *queryend = *n;
268                                                 queryend++;
269                                         }
270                                         sqlite3_free(escaped);
271                                         req.query.p.pop_front();
272                                 }
273                                 else
274                                         break;
275                         }
276                         else
277                         {
278                                 *queryend = req.query.q[i];
279                                 queryend++;
280                         }
281                         querylength++;
282                 }
283                 *queryend = 0;
284                 req.query.q = query;
285
286 //              Instance->Log(DEBUG, "<******> Doing query: " + ConvToStr(req.query.q.data()));
287
288                 SQLite3Result* res = new SQLite3Result(mod, req.GetSource(), req.id);
289                 res->query = req.query.q;
290                 paramlist params;
291                 params.push_back(this);
292                 params.push_back(res);
293
294                 char *errmsg = 0;
295                 if (sqlite3_exec(conn, req.query.q.data(), QueryResult, &params, &errmsg) != SQLITE_OK)
296                 {
297                         Instance->Log(DEBUG, "Query failed: " + ConvToStr(errmsg));
298                         sqlite3_free(errmsg);
299                         delete[] query;
300                         delete res;
301                         return SQLerror(QSEND_FAIL, ConvToStr(errmsg));
302                 }
303                 Instance->Log(DEBUG, "Dispatched query successfully. ID: %d resulting rows %d", req.id, res->Rows());
304                 delete[] query;
305
306                 results.push_back(res);
307
308                 return SQLerror();
309         }
310
311         static int QueryResult(void *params, int argc, char **argv, char **azColName)
312         {
313                 paramlist* p = (paramlist*)params;
314                 ((SQLConn*)(*p)[0])->ResultReady(((SQLite3Result*)(*p)[1]), argc, argv, azColName);
315                 return 0;
316         }
317
318         void ResultReady(SQLite3Result *res, int cols, char **data, char **colnames)
319         {
320                 res->AddRow(cols, data, colnames);
321         }
322
323         void QueryDone(SQLrequest* req, int rows)
324         {
325                 SQLite3Result* r = new SQLite3Result(mod, req->GetSource(), req->id);
326                 r->dbid = host.id;
327                 r->query = req->query.q;
328         }
329
330         int OpenDB()
331         {
332                 return sqlite3_open(host.host.c_str(), &conn);
333         }
334
335         void CloseDB()
336         {
337                 sqlite3_interrupt(conn);
338                 sqlite3_close(conn);
339         }
340
341         SQLhost GetConfHost()
342         {
343                 return host;
344         }
345
346         void SendResults()
347         {
348                 if (results.size())
349                 {
350                         SQLite3Result* res = results[0];
351                         if (res->GetDest())
352                         {
353                                 res->Send();
354                         }
355                         else
356                         {
357                                 /* If the client module is unloaded partway through a query then the provider will set
358                                  * the pointer to NULL. We cannot just cancel the query as the result will still come
359                                  * through at some point...and it could get messy if we play with invalid pointers...
360                                  */
361                                 Instance->Log(DEBUG, "Looks like we're handling a zombie query from a module which unloaded before it got a result..fun. ID: " + ConvToStr(res->GetId()));
362                                 delete res;
363                         }
364                         results.pop_front();
365                 }
366         }
367
368 };
369
370
371 class ModuleSQLite3 : public Module
372 {
373   private:
374         ConnMap connections;
375         unsigned long currid;
376
377   public:
378         ModuleSQLite3(InspIRCd* Me)
379         : Module::Module(Me), currid(0)
380         {
381                 ServerInstance->UseInterface("SQLutils");
382
383                 if (!ServerInstance->PublishFeature("SQL", this))
384                 {
385                         throw ModuleException("m_mysql: Unable to publish feature 'SQL'");
386                 }
387
388                 ReadConf();
389
390                 ServerInstance->PublishInterface("SQL", this);
391         }
392
393         virtual ~ModuleSQLite3()
394         {
395                 ServerInstance->UnpublishInterface("SQL", this);
396                 ServerInstance->UnpublishFeature("SQL");
397                 ServerInstance->DoneWithInterface("SQLutils");
398         }
399
400         void Implements(char* List)
401         {
402                 List[I_OnRequest] = 1;
403         }
404
405         bool HasHost(const SQLhost &host)
406         {
407                 for (ConnMap::iterator iter = connections.begin(); iter != connections.end(); iter++)
408                 {
409                         if (host == iter->second->GetConfHost())
410                                 return true;
411                 }
412                 return false;
413         }
414
415         bool HostInConf(const SQLhost &h)
416         {
417                 ConfigReader conf(ServerInstance);
418                 for(int i = 0; i < conf.Enumerate("database"); i++)
419                 {
420                         SQLhost host;
421                         host.id         = conf.ReadValue("database", "id", i);
422                         host.host       = conf.ReadValue("database", "hostname", i);
423                         host.port       = conf.ReadInteger("database", "port", i, true);
424                         host.name       = conf.ReadValue("database", "name", i);
425                         host.user       = conf.ReadValue("database", "username", i);
426                         host.pass       = conf.ReadValue("database", "password", i);
427                         host.ssl        = conf.ReadFlag("database", "ssl", "0", i);
428                         if (h == host)
429                                 return true;
430                 }
431                 return false;
432         }
433
434         void ReadConf()
435         {
436                 //ClearOldConnections();
437
438                 ConfigReader conf(ServerInstance);
439                 for(int i = 0; i < conf.Enumerate("database"); i++)
440                 {
441                         SQLhost host;
442
443                         host.id         = conf.ReadValue("database", "id", i);
444                         host.host       = conf.ReadValue("database", "hostname", i);
445                         host.port       = conf.ReadInteger("database", "port", i, true);
446                         host.name       = conf.ReadValue("database", "name", i);
447                         host.user       = conf.ReadValue("database", "username", i);
448                         host.pass       = conf.ReadValue("database", "password", i);
449                         host.ssl        = conf.ReadFlag("database", "ssl", "0", i);
450
451                         if (HasHost(host))
452                                 continue;
453
454                         this->AddConn(host);
455                 }
456         }
457
458         void AddConn(const SQLhost& hi)
459         {
460                 if (HasHost(hi))
461                 {
462                         ServerInstance->Log(DEFAULT, "WARNING: A sqlite connection with id: %s already exists. Aborting database open attempt.", hi.id.c_str());
463                         return;
464                 }
465
466                 SQLConn* newconn;
467
468                 newconn = new SQLConn(ServerInstance, this, hi);
469
470                 connections.insert(std::make_pair(hi.id, newconn));
471         }
472
473         virtual char* OnRequest(Request* request)
474         {
475                 if(strcmp(SQLREQID, request->GetId()) == 0)
476                 {
477                         SQLrequest* req = (SQLrequest*)request;
478                         ConnMap::iterator iter;
479                         ServerInstance->Log(DEBUG, "Got query: '%s' with %d replacement parameters on id '%s'", req->query.q.c_str(), req->query.p.size(), req->dbid.c_str());
480                         if((iter = connections.find(req->dbid)) != connections.end())
481                         {
482                                 req->id = NewID();
483                                 req->error = iter->second->Query(*req);
484                                 return SQLSUCCESS;
485                         }
486                         else
487                         {
488                                 req->error.Id(BAD_DBID);
489                                 return NULL;
490                         }
491                 }
492                 ServerInstance->Log(DEBUG, "Got unsupported API version string: %s", request->GetId());
493                 return NULL;
494         }
495
496         unsigned long NewID()
497         {
498                 if (currid+1 == 0)
499                         currid++;
500
501                 return ++currid;
502         }
503
504         virtual Version GetVersion()
505         {
506                 return Version(1,1,0,0,VF_VENDOR|VF_SERVICEPROVIDER,API_VERSION);
507         }
508
509 };
510
511
512 class ModuleSQLite3Factory : public ModuleFactory
513 {
514   public:
515         ModuleSQLite3Factory()
516         {
517         }
518
519         ~ModuleSQLite3Factory()
520         {
521         }
522
523         virtual Module * CreateModule(InspIRCd* Me)
524         {
525                 return new ModuleSQLite3(Me);
526         }
527 };
528
529 extern "C" void * init_module( void )
530 {
531         return new ModuleSQLite3Factory;
532 }