]> git.netwichtig.de Git - user/henk/code/inspircd.git/blob - src/modules/extra/m_sqlite3.cpp
902355a2340718fe8708a0a3280ae2e92d6d72c6
[user/henk/code/inspircd.git] / src / modules / extra / m_sqlite3.cpp
1 /*               +------------------------------------+
2  *               | Inspire Internet Relay Chat Daemon |
3  *               +------------------------------------+
4  *
5  *      InspIRCd: (C) 2002-2010 InspIRCd Development Team
6  * See: http://wiki.inspircd.org/Credits
7  *
8  * This program is free but copyrighted software; see
9  *                        the file COPYING for details.
10  *
11  * ---------------------------------------------------
12  */
13
14 #include "inspircd.h"
15 #include <sqlite3.h>
16 #include "m_sqlv2.h"
17
18 /* $ModDesc: sqlite3 provider */
19 /* $CompileFlags: pkgconfversion("sqlite3","3.3") pkgconfincludes("sqlite3","/sqlite3.h","") */
20 /* $LinkerFlags: pkgconflibs("sqlite3","/libsqlite3.so","-lsqlite3") */
21 /* $ModDep: m_sqlv2.h */
22 /* $NoPedantic */
23
24 class SQLConn;
25 class SQLite3Result;
26 class ResultNotifier;
27 class SQLiteListener;
28 class ModuleSQLite3;
29
30 typedef std::map<std::string, SQLConn*> ConnMap;
31 typedef std::deque<classbase*> paramlist;
32 typedef std::deque<SQLite3Result*> ResultQueue;
33
/** Count how many times a character occurs in a C string.
 * @param str NUL-terminated string to scan; a NULL pointer is treated as empty
 * @param a The character to look for
 * @return The number of occurrences of a in str
 */
static unsigned long count(const char * const str, char a)
{
        unsigned long n = 0;
        if (!str)
                return n;

        for (const char *p = str; *p; ++p)
        {
                /* Compare against the requested character, not a fixed '?' */
                if (*p == a)
                        ++n;
        }
        return n;
}
44
45 class SQLite3Result : public SQLresult
46 {
47  private:
48         int currentrow;
49         int rows;
50         int cols;
51
52         std::vector<std::string> colnames;
53         std::vector<SQLfieldList> fieldlists;
54         SQLfieldList emptyfieldlist;
55
56         SQLfieldList* fieldlist;
57         SQLfieldMap* fieldmap;
58
59  public:
60         SQLite3Result(Module* self, Module* to, unsigned int rid)
61         : SQLresult(self, to, rid), currentrow(0), rows(0), cols(0), fieldlist(NULL), fieldmap(NULL)
62         {
63         }
64
65         ~SQLite3Result()
66         {
67         }
68
69         void AddRow(int colsnum, char **dat, char **colname)
70         {
71                 colnames.clear();
72                 cols = colsnum;
73                 for (int i = 0; i < colsnum; i++)
74                 {
75                         fieldlists.resize(fieldlists.size()+1);
76                         colnames.push_back(colname[i]);
77                         SQLfield sf(dat[i] ? dat[i] : "", dat[i] ? false : true);
78                         fieldlists[rows].push_back(sf);
79                 }
80                 rows++;
81         }
82
83         void UpdateAffectedCount()
84         {
85                 rows++;
86         }
87
88         virtual int Rows()
89         {
90                 return rows;
91         }
92
93         virtual int Cols()
94         {
95                 return cols;
96         }
97
98         virtual std::string ColName(int column)
99         {
100                 if (column < (int)colnames.size())
101                 {
102                         return colnames[column];
103                 }
104                 else
105                 {
106                         throw SQLbadColName();
107                 }
108                 return "";
109         }
110
111         virtual int ColNum(const std::string &column)
112         {
113                 for (unsigned int i = 0; i < colnames.size(); i++)
114                 {
115                         if (column == colnames[i])
116                                 return i;
117                 }
118                 throw SQLbadColName();
119                 return 0;
120         }
121
122         virtual SQLfield GetValue(int row, int column)
123         {
124                 if ((row >= 0) && (row < rows) && (column >= 0) && (column < Cols()))
125                 {
126                         return fieldlists[row][column];
127                 }
128
129                 throw SQLbadColName();
130
131                 /* XXX: We never actually get here because of the throw */
132                 return SQLfield("",true);
133         }
134
135         virtual SQLfieldList& GetRow()
136         {
137                 if (currentrow < rows)
138                         return fieldlists[currentrow];
139                 else
140                         return emptyfieldlist;
141         }
142
143         virtual SQLfieldMap& GetRowMap()
144         {
145                 /* In an effort to reduce overhead we don't actually allocate the map
146                  * until the first time it's needed...so...
147                  */
148                 if(fieldmap)
149                 {
150                         fieldmap->clear();
151                 }
152                 else
153                 {
154                         fieldmap = new SQLfieldMap;
155                 }
156
157                 if (currentrow < rows)
158                 {
159                         for (int i = 0; i < Cols(); i++)
160                         {
161                                 fieldmap->insert(std::make_pair(ColName(i), GetValue(currentrow, i)));
162                         }
163                         currentrow++;
164                 }
165
166                 return *fieldmap;
167         }
168
169         virtual SQLfieldList* GetRowPtr()
170         {
171                 fieldlist = new SQLfieldList();
172
173                 if (currentrow < rows)
174                 {
175                         for (int i = 0; i < Rows(); i++)
176                         {
177                                 fieldlist->push_back(fieldlists[currentrow][i]);
178                         }
179                         currentrow++;
180                 }
181                 return fieldlist;
182         }
183
184         virtual SQLfieldMap* GetRowMapPtr()
185         {
186                 fieldmap = new SQLfieldMap();
187
188                 if (currentrow < rows)
189                 {
190                         for (int i = 0; i < Cols(); i++)
191                         {
192                                 fieldmap->insert(std::make_pair(colnames[i],GetValue(currentrow, i)));
193                         }
194                         currentrow++;
195                 }
196
197                 return fieldmap;
198         }
199
200         virtual void Free(SQLfieldMap* fm)
201         {
202                 delete fm;
203         }
204
205         virtual void Free(SQLfieldList* fl)
206         {
207                 delete fl;
208         }
209
210
211 };
212
213 class SQLConn : public classbase
214 {
215  private:
216         ResultQueue results;
217         Module* mod;
218         SQLhost host;
219         sqlite3* conn;
220
221  public:
222         SQLConn(Module* m, const SQLhost& hi)
223         : mod(m), host(hi)
224         {
225                 if (OpenDB() != SQLITE_OK)
226                 {
227                         ServerInstance->Logs->Log("m_sqlite3",DEFAULT, "WARNING: Could not open DB with id: " + host.id);
228                         CloseDB();
229                 }
230         }
231
232         ~SQLConn()
233         {
234                 CloseDB();
235         }
236
237         SQLerror Query(SQLrequest &req)
238         {
239                 /* Pointer to the buffer we screw around with substitution in */
240                 char* query;
241
242                 /* Pointer to the current end of query, where we append new stuff */
243                 char* queryend;
244
245                 /* Total length of the unescaped parameters */
246                 unsigned long maxparamlen, paramcount;
247
248                 /* The length of the longest parameter */
249                 maxparamlen = 0;
250
251                 for(ParamL::iterator i = req.query.p.begin(); i != req.query.p.end(); i++)
252                 {
253                         if (i->size() > maxparamlen)
254                                 maxparamlen = i->size();
255                 }
256
257                 /* How many params are there in the query? */
258                 paramcount = count(req.query.q.c_str(), '?');
259
260                 /* This stores copy of params to be inserted with using numbered params 1;3B*/
261                 ParamL paramscopy(req.query.p);
262
263                 /* To avoid a lot of allocations, allocate enough memory for the biggest the escaped query could possibly be.
264                  * sizeofquery + (maxtotalparamlength*2) + 1
265                  *
266                  * The +1 is for null-terminating the string
267                  */
268
269                 query = new char[req.query.q.length() + (maxparamlen*paramcount*2) + 1];
270                 queryend = query;
271
272                 for(unsigned long i = 0; i < req.query.q.length(); i++)
273                 {
274                         if(req.query.q[i] == '?')
275                         {
276                                 /* We found a place to substitute..what fun.
277                                  * use sqlite calls to escape and write the
278                                  * escaped string onto the end of our query buffer,
279                                  * then we "just" need to make sure queryend is
280                                  * pointing at the right place.
281                                  */
282
283                                 /* Is it numbered parameter?
284                                  */
285
286                                 bool numbered;
287                                 numbered = false;
288
289                                 /* Numbered parameter number :|
290                                  */
291                                 unsigned int paramnum;
292                                 paramnum = 0;
293
294                                 /* Let's check if it's a numbered param. And also calculate it's number.
295                                  */
296
297                                 while ((i < req.query.q.length() - 1) && (req.query.q[i+1] >= '0') && (req.query.q[i+1] <= '9'))
298                                 {
299                                         numbered = true;
300                                         ++i;
301                                         paramnum = paramnum * 10 + req.query.q[i] - '0';
302                                 }
303
304                                 if (paramnum > paramscopy.size() - 1)
305                                 {
306                                         /* index is out of range!
307                                          */
308                                         numbered = false;
309                                 }
310
311
312                                 if (numbered)
313                                 {
314                                         char* escaped;
315                                         escaped = sqlite3_mprintf("%q", paramscopy[paramnum].c_str());
316                                         for (char* n = escaped; *n; n++)
317                                         {
318                                                 *queryend = *n;
319                                                 queryend++;
320                                         }
321                                         sqlite3_free(escaped);
322                                 }
323                                 else if (req.query.p.size())
324                                 {
325                                         char* escaped;
326                                         escaped = sqlite3_mprintf("%q", req.query.p.front().c_str());
327                                         for (char* n = escaped; *n; n++)
328                                         {
329                                                 *queryend = *n;
330                                                 queryend++;
331                                         }
332                                         sqlite3_free(escaped);
333                                         req.query.p.pop_front();
334                                 }
335                                 else
336                                         break;
337                         }
338                         else
339                         {
340                                 *queryend = req.query.q[i];
341                                 queryend++;
342                         }
343                 }
344                 *queryend = 0;
345                 req.query.q = query;
346
347                 SQLite3Result* res = new SQLite3Result(mod, req.source, req.id);
348                 res->dbid = host.id;
349                 res->query = req.query.q;
350                 paramlist params;
351                 params.push_back(this);
352                 params.push_back(res);
353
354                 char *errmsg = 0;
355                 sqlite3_update_hook(conn, QueryUpdateHook, &params);
356                 if (sqlite3_exec(conn, req.query.q.data(), QueryResult, &params, &errmsg) != SQLITE_OK)
357                 {
358                         std::string error(errmsg);
359                         sqlite3_free(errmsg);
360                         delete[] query;
361                         delete res;
362                         return SQLerror(SQL_QSEND_FAIL, error);
363                 }
364                 delete[] query;
365
366                 results.push_back(res);
367                 SendResults();
368                 return SQLerror();
369         }
370
371         static int QueryResult(void *params, int argc, char **argv, char **azColName)
372         {
373                 paramlist* p = (paramlist*)params;
374                 ((SQLConn*)(*p)[0])->ResultReady(((SQLite3Result*)(*p)[1]), argc, argv, azColName);
375                 return 0;
376         }
377
378         static void QueryUpdateHook(void *params, int eventid, char const * azSQLite, char const * azColName, sqlite_int64 rowid)
379         {
380                 paramlist* p = (paramlist*)params;
381                 ((SQLConn*)(*p)[0])->AffectedReady(((SQLite3Result*)(*p)[1]));
382         }
383
384         void ResultReady(SQLite3Result *res, int cols, char **data, char **colnames)
385         {
386                 res->AddRow(cols, data, colnames);
387         }
388
389         void AffectedReady(SQLite3Result *res)
390         {
391                 res->UpdateAffectedCount();
392         }
393
394         int OpenDB()
395         {
396                 return sqlite3_open_v2(host.host.c_str(), &conn, SQLITE_OPEN_READWRITE, 0);
397         }
398
399         void CloseDB()
400         {
401                 sqlite3_interrupt(conn);
402                 sqlite3_close(conn);
403         }
404
405         SQLhost GetConfHost()
406         {
407                 return host;
408         }
409
410         void SendResults()
411         {
412                 while (results.size())
413                 {
414                         SQLite3Result* res = results[0];
415                         if (res->dest)
416                         {
417                                 res->Send();
418                         }
419                         else
420                         {
421                                 /* If the client module is unloaded partway through a query then the provider will set
422                                  * the pointer to NULL. We cannot just cancel the query as the result will still come
423                                  * through at some point...and it could get messy if we play with invalid pointers...
424                                  */
425                                 delete res;
426                         }
427                         results.pop_front();
428                 }
429         }
430
431         void ClearResults()
432         {
433                 while (results.size())
434                 {
435                         SQLite3Result* res = results[0];
436                         delete res;
437                         results.pop_front();
438                 }
439         }
440
441 };
442
443
444 class ModuleSQLite3 : public Module
445 {
446  private:
447         ConnMap connections;
448         unsigned long currid;
449         ServiceProvider sqlserv;
450
451  public:
452         ModuleSQLite3()
453         : currid(0), sqlserv(this, "SQL/sqlite", SERVICE_DATA)
454         {
455         }
456
457         void init()
458         {
459                 ServerInstance->Modules->AddService(sqlserv);
460
461                 ReadConf();
462
463                 Implementation eventlist[] = { I_OnRehash };
464                 ServerInstance->Modules->Attach(eventlist, this, 1);
465         }
466
467         virtual ~ModuleSQLite3()
468         {
469                 ClearQueue();
470                 ClearAllConnections();
471         }
472
473         void ClearQueue()
474         {
475                 for (ConnMap::iterator iter = connections.begin(); iter != connections.end(); iter++)
476                 {
477                         iter->second->ClearResults();
478                 }
479         }
480
481         bool HasHost(const SQLhost &host)
482         {
483                 for (ConnMap::iterator iter = connections.begin(); iter != connections.end(); iter++)
484                 {
485                         if (host == iter->second->GetConfHost())
486                                 return true;
487                 }
488                 return false;
489         }
490
491         bool HostInConf(const SQLhost &h)
492         {
493                 ConfigReader conf;
494                 for(int i = 0; i < conf.Enumerate("database"); i++)
495                 {
496                         SQLhost host;
497                         host.id         = conf.ReadValue("database", "id", i);
498                         host.host       = conf.ReadValue("database", "hostname", i);
499                         host.port       = conf.ReadInteger("database", "port", i, true);
500                         host.name       = conf.ReadValue("database", "name", i);
501                         host.user       = conf.ReadValue("database", "username", i);
502                         host.pass       = conf.ReadValue("database", "password", i);
503                         if (h == host)
504                                 return true;
505                 }
506                 return false;
507         }
508
509         void ReadConf()
510         {
511                 ClearOldConnections();
512
513                 ConfigReader conf;
514                 for(int i = 0; i < conf.Enumerate("database"); i++)
515                 {
516                         SQLhost host;
517
518                         host.id         = conf.ReadValue("database", "id", i);
519                         host.host       = conf.ReadValue("database", "hostname", i);
520                         host.port       = conf.ReadInteger("database", "port", i, true);
521                         host.name       = conf.ReadValue("database", "name", i);
522                         host.user       = conf.ReadValue("database", "username", i);
523                         host.pass       = conf.ReadValue("database", "password", i);
524
525                         if (HasHost(host))
526                                 continue;
527
528                         this->AddConn(host);
529                 }
530         }
531
532         void AddConn(const SQLhost& hi)
533         {
534                 if (HasHost(hi))
535                 {
536                         ServerInstance->Logs->Log("m_sqlite3",DEFAULT, "WARNING: A sqlite connection with id: %s already exists. Aborting database open attempt.", hi.id.c_str());
537                         return;
538                 }
539
540                 SQLConn* newconn;
541
542                 newconn = new SQLConn(this, hi);
543
544                 connections.insert(std::make_pair(hi.id, newconn));
545         }
546
547         void ClearOldConnections()
548         {
549                 ConnMap::iterator iter,safei;
550                 for (iter = connections.begin(); iter != connections.end(); iter++)
551                 {
552                         if (!HostInConf(iter->second->GetConfHost()))
553                         {
554                                 delete iter->second;
555                                 safei = iter;
556                                 --iter;
557                                 connections.erase(safei);
558                         }
559                 }
560         }
561
562         void ClearAllConnections()
563         {
564                 ConnMap::iterator i;
565                 while ((i = connections.begin()) != connections.end())
566                 {
567                         connections.erase(i);
568                         delete i->second;
569                 }
570         }
571
572         virtual void OnRehash(User* user)
573         {
574                 ReadConf();
575         }
576
577         void OnRequest(Request& request)
578         {
579                 if(strcmp(SQLREQID, request.id) == 0)
580                 {
581                         SQLrequest* req = (SQLrequest*)&request;
582                         ConnMap::iterator iter;
583                         if((iter = connections.find(req->dbid)) != connections.end())
584                         {
585                                 req->id = NewID();
586                                 req->error = iter->second->Query(*req);
587                         }
588                         else
589                         {
590                                 req->error.Id(SQL_BAD_DBID);
591                         }
592                 }
593         }
594
595         unsigned long NewID()
596         {
597                 if (currid+1 == 0)
598                         currid++;
599
600                 return ++currid;
601         }
602
603         virtual Version GetVersion()
604         {
605                 return Version("sqlite3 provider", VF_VENDOR);
606         }
607
608 };
609
610 MODULE_INIT(ModuleSQLite3)