]> git.netwichtig.de Git - user/henk/code/inspircd.git/blob - src/modules/extra/m_sqlite3.cpp
fixed some indentation and spacing in modules
[user/henk/code/inspircd.git] / src / modules / extra / m_sqlite3.cpp
1 /*       +------------------------------------+
2  *       | Inspire Internet Relay Chat Daemon |
3  *       +------------------------------------+
4  *
5  *  InspIRCd: (C) 2002-2008 InspIRCd Development Team
6  * See: http://www.inspircd.org/wiki/index.php/Credits
7  *
8  * This program is free but copyrighted software; see
9  *            the file COPYING for details.
10  *
11  * ---------------------------------------------------
12  */
13
14 #include "inspircd.h"
15 #include <sqlite3.h>
16 #include "users.h"
17 #include "channels.h"
18 #include "modules.h"
19 #include "m_sqlv2.h"
20
21 /* $ModDesc: sqlite3 provider */
22 /* $CompileFlags: pkgconfversion("sqlite3","3.3") pkgconfincludes("sqlite3","/sqlite3.h","") */
23 /* $LinkerFlags: pkgconflibs("sqlite3","/libsqlite3.so","-lsqlite3") */
24 /* $ModDep: m_sqlv2.h */
25 /* $NoPedantic */
26
27 class SQLConn;
28 class SQLite3Result;
29 class ResultNotifier;
30
31 typedef std::map<std::string, SQLConn*> ConnMap;
32 typedef std::deque<classbase*> paramlist;
33 typedef std::deque<SQLite3Result*> ResultQueue;
34
35 ResultNotifier* resultnotify = NULL;
36 ResultNotifier* resultdispatch = NULL;
37 int QueueFD = -1;
38
39 class ResultNotifier : public BufferedSocket
40 {
41         Module* mod;
42         insp_sockaddr sock_us;
43         socklen_t uslen;
44
45  public:
46         /* Create a socket on a random port. Let the tcp stack allocate us an available port */
47 #ifdef IPV6
48         ResultNotifier(InspIRCd* SI, Module* m) : BufferedSocket(SI, "::1", 0, true, 3000), mod(m)
49 #else
50         ResultNotifier(InspIRCd* SI, Module* m) : BufferedSocket(SI, "127.0.0.1", 0, true, 3000), mod(m)
51 #endif
52         {
53                 uslen = sizeof(sock_us);
54                 if (getsockname(this->fd,(sockaddr*)&sock_us,&uslen))
55                 {
56                         throw ModuleException("Could not create random listening port on localhost");
57                 }
58         }
59
60         ResultNotifier(InspIRCd* SI, Module* m, int newfd, char* ip) : BufferedSocket(SI, newfd, ip), mod(m)
61         {
62         }
63
64         /* Using getsockname and ntohs, we can determine which port number we were allocated */
65         int GetPort()
66         {
67 #ifdef IPV6
68                 return ntohs(sock_us.sin6_port);
69 #else
70                 return ntohs(sock_us.sin_port);
71 #endif
72         }
73
74         virtual int OnIncomingConnection(int newsock, char* ip)
75         {
76                 resultdispatch = new ResultNotifier(Instance, mod, newsock, ip);
77                 return true;
78         }
79
80         virtual bool OnDataReady()
81         {
82                 char data = 0;
83                 if (Instance->SE->Recv(this, &data, 1, 0) > 0)
84                 {
85                         Dispatch();
86                         return true;
87                 }
88                 return false;
89         }
90
91         void Dispatch();
92 };
93
94
95 class SQLite3Result : public SQLresult
96 {
97  private:
98         int currentrow;
99         int rows;
100         int cols;
101
102         std::vector<std::string> colnames;
103         std::vector<SQLfieldList> fieldlists;
104         SQLfieldList emptyfieldlist;
105
106         SQLfieldList* fieldlist;
107         SQLfieldMap* fieldmap;
108
109  public:
110         SQLite3Result(Module* self, Module* to, unsigned int rid)
111         : SQLresult(self, to, rid), currentrow(0), rows(0), cols(0), fieldlist(NULL), fieldmap(NULL)
112         {
113         }
114
115         ~SQLite3Result()
116         {
117         }
118
119         void AddRow(int colsnum, char **dat, char **colname)
120         {
121                 colnames.clear();
122                 cols = colsnum;
123                 for (int i = 0; i < colsnum; i++)
124                 {
125                         fieldlists.resize(fieldlists.size()+1);
126                         colnames.push_back(colname[i]);
127                         SQLfield sf(dat[i] ? dat[i] : "", dat[i] ? false : true);
128                         fieldlists[rows].push_back(sf);
129                 }
130                 rows++;
131         }
132
133         void UpdateAffectedCount()
134         {
135                 rows++;
136         }
137
138         virtual int Rows()
139         {
140                 return rows;
141         }
142
143         virtual int Cols()
144         {
145                 return cols;
146         }
147
148         virtual std::string ColName(int column)
149         {
150                 if (column < (int)colnames.size())
151                 {
152                         return colnames[column];
153                 }
154                 else
155                 {
156                         throw SQLbadColName();
157                 }
158                 return "";
159         }
160
161         virtual int ColNum(const std::string &column)
162         {
163                 for (unsigned int i = 0; i < colnames.size(); i++)
164                 {
165                         if (column == colnames[i])
166                                 return i;
167                 }
168                 throw SQLbadColName();
169                 return 0;
170         }
171
172         virtual SQLfield GetValue(int row, int column)
173         {
174                 if ((row >= 0) && (row < rows) && (column >= 0) && (column < Cols()))
175                 {
176                         return fieldlists[row][column];
177                 }
178
179                 throw SQLbadColName();
180
181                 /* XXX: We never actually get here because of the throw */
182                 return SQLfield("",true);
183         }
184
185         virtual SQLfieldList& GetRow()
186         {
187                 if (currentrow < rows)
188                         return fieldlists[currentrow];
189                 else
190                         return emptyfieldlist;
191         }
192
193         virtual SQLfieldMap& GetRowMap()
194         {
195                 /* In an effort to reduce overhead we don't actually allocate the map
196                  * until the first time it's needed...so...
197                  */
198                 if(fieldmap)
199                 {
200                         fieldmap->clear();
201                 }
202                 else
203                 {
204                         fieldmap = new SQLfieldMap;
205                 }
206
207                 if (currentrow < rows)
208                 {
209                         for (int i = 0; i < Cols(); i++)
210                         {
211                                 fieldmap->insert(std::make_pair(ColName(i), GetValue(currentrow, i)));
212                         }
213                         currentrow++;
214                 }
215
216                 return *fieldmap;
217         }
218
219         virtual SQLfieldList* GetRowPtr()
220         {
221                 fieldlist = new SQLfieldList();
222
223                 if (currentrow < rows)
224                 {
225                         for (int i = 0; i < Rows(); i++)
226                         {
227                                 fieldlist->push_back(fieldlists[currentrow][i]);
228                         }
229                         currentrow++;
230                 }
231                 return fieldlist;
232         }
233
234         virtual SQLfieldMap* GetRowMapPtr()
235         {
236                 fieldmap = new SQLfieldMap();
237
238                 if (currentrow < rows)
239                 {
240                         for (int i = 0; i < Cols(); i++)
241                         {
242                                 fieldmap->insert(std::make_pair(colnames[i],GetValue(currentrow, i)));
243                         }
244                         currentrow++;
245                 }
246
247                 return fieldmap;
248         }
249
250         virtual void Free(SQLfieldMap* fm)
251         {
252                 delete fm;
253         }
254
255         virtual void Free(SQLfieldList* fl)
256         {
257                 delete fl;
258         }
259
260
261 };
262
263 class SQLConn : public classbase
264 {
265  private:
266         ResultQueue results;
267         InspIRCd* Instance;
268         Module* mod;
269         SQLhost host;
270         sqlite3* conn;
271
272  public:
273         SQLConn(InspIRCd* SI, Module* m, const SQLhost& hi)
274         : Instance(SI), mod(m), host(hi)
275         {
276                 if (OpenDB() != SQLITE_OK)
277                 {
278                         Instance->Logs->Log("m_sqlite3",DEFAULT, "WARNING: Could not open DB with id: " + host.id);
279                         CloseDB();
280                 }
281         }
282
283         ~SQLConn()
284         {
285                 CloseDB();
286         }
287
288         SQLerror Query(SQLrequest &req)
289         {
290                 /* Pointer to the buffer we screw around with substitution in */
291                 char* query;
292
293                 /* Pointer to the current end of query, where we append new stuff */
294                 char* queryend;
295
296                 /* Total length of the unescaped parameters */
297                 unsigned long paramlen;
298
299                 /* Total length of query, used for binary-safety in mysql_real_query */
300                 unsigned long querylength = 0;
301
302                 paramlen = 0;
303                 for(ParamL::iterator i = req.query.p.begin(); i != req.query.p.end(); i++)
304                 {
305                         paramlen += i->size();
306                 }
307
308                 /* To avoid a lot of allocations, allocate enough memory for the biggest the escaped query could possibly be.
309                  * sizeofquery + (totalparamlength*2) + 1
310                  *
311                  * The +1 is for null-terminating the string for mysql_real_escape_string
312                  */
313                 query = new char[req.query.q.length() + (paramlen*2) + 1];
314                 queryend = query;
315
316                 for(unsigned long i = 0; i < req.query.q.length(); i++)
317                 {
318                         if(req.query.q[i] == '?')
319                         {
320                                 if(req.query.p.size())
321                                 {
322                                         char* escaped;
323                                         escaped = sqlite3_mprintf("%q", req.query.p.front().c_str());
324                                         for (char* n = escaped; *n; n++)
325                                         {
326                                                 *queryend = *n;
327                                                 queryend++;
328                                         }
329                                         sqlite3_free(escaped);
330                                         req.query.p.pop_front();
331                                 }
332                                 else
333                                         break;
334                         }
335                         else
336                         {
337                                 *queryend = req.query.q[i];
338                                 queryend++;
339                         }
340                         querylength++;
341                 }
342                 *queryend = 0;
343                 req.query.q = query;
344
345                 SQLite3Result* res = new SQLite3Result(mod, req.GetSource(), req.id);
346                 res->dbid = host.id;
347                 res->query = req.query.q;
348                 paramlist params;
349                 params.push_back(this);
350                 params.push_back(res);
351
352                 char *errmsg = 0;
353                 sqlite3_update_hook(conn, QueryUpdateHook, &params);
354                 if (sqlite3_exec(conn, req.query.q.data(), QueryResult, &params, &errmsg) != SQLITE_OK)
355                 {
356                         std::string error(errmsg);
357                         sqlite3_free(errmsg);
358                         delete[] query;
359                         delete res;
360                         return SQLerror(QSEND_FAIL, error);
361                 }
362                 delete[] query;
363
364                 results.push_back(res);
365                 SendNotify();
366                 return SQLerror();
367         }
368
369         static int QueryResult(void *params, int argc, char **argv, char **azColName)
370         {
371                 paramlist* p = (paramlist*)params;
372                 ((SQLConn*)(*p)[0])->ResultReady(((SQLite3Result*)(*p)[1]), argc, argv, azColName);
373                 return 0;
374         }
375
376         static void QueryUpdateHook(void *params, int eventid, char const * azSQLite, char const * azColName, sqlite_int64 rowid)
377         {
378                 paramlist* p = (paramlist*)params;
379                 ((SQLConn*)(*p)[0])->AffectedReady(((SQLite3Result*)(*p)[1]));
380         }
381
382         void ResultReady(SQLite3Result *res, int cols, char **data, char **colnames)
383         {
384                 res->AddRow(cols, data, colnames);
385         }
386
387         void AffectedReady(SQLite3Result *res)
388         {
389                 res->UpdateAffectedCount();
390         }
391
392         int OpenDB()
393         {
394                 return sqlite3_open(host.host.c_str(), &conn);
395         }
396
397         void CloseDB()
398         {
399                 sqlite3_interrupt(conn);
400                 sqlite3_close(conn);
401         }
402
403         SQLhost GetConfHost()
404         {
405                 return host;
406         }
407
408         void SendResults()
409         {
410                 while (results.size())
411                 {
412                         SQLite3Result* res = results[0];
413                         if (res->GetDest())
414                         {
415                                 res->Send();
416                         }
417                         else
418                         {
419                                 /* If the client module is unloaded partway through a query then the provider will set
420                                  * the pointer to NULL. We cannot just cancel the query as the result will still come
421                                  * through at some point...and it could get messy if we play with invalid pointers...
422                                  */
423                                 delete res;
424                         }
425                         results.pop_front();
426                 }
427         }
428
429         void ClearResults()
430         {
431                 while (results.size())
432                 {
433                         SQLite3Result* res = results[0];
434                         delete res;
435                         results.pop_front();
436                 }
437         }
438
439         void SendNotify()
440         {
441                 if (QueueFD < 0)
442                 {
443                         if ((QueueFD = socket(AF_FAMILY, SOCK_STREAM, 0)) == -1)
444                         {
445                                 /* crap, we're out of sockets... */
446                                 return;
447                         }
448
449                         insp_sockaddr addr;
450
451 #ifdef IPV6
452                         insp_aton("::1", &addr.sin6_addr);
453                         addr.sin6_family = AF_FAMILY;
454                         addr.sin6_port = htons(resultnotify->GetPort());
455 #else
456                         insp_inaddr ia;
457                         insp_aton("127.0.0.1", &ia);
458                         addr.sin_family = AF_FAMILY;
459                         addr.sin_addr = ia;
460                         addr.sin_port = htons(resultnotify->GetPort());
461 #endif
462
463                         if (connect(QueueFD, (sockaddr*)&addr,sizeof(addr)) == -1)
464                         {
465                                 /* wtf, we cant connect to it, but we just created it! */
466                                 return;
467                         }
468                 }
469                 char id = 0;
470                 send(QueueFD, &id, 1, 0);
471         }
472
473 };
474
475
476 class ModuleSQLite3 : public Module
477 {
478  private:
479         ConnMap connections;
480         unsigned long currid;
481
482  public:
483         ModuleSQLite3(InspIRCd* Me)
484         : Module::Module(Me), currid(0)
485         {
486                 ServerInstance->Modules->UseInterface("SQLutils");
487
488                 if (!ServerInstance->Modules->PublishFeature("SQL", this))
489                 {
490                         throw ModuleException("m_sqlite3: Unable to publish feature 'SQL'");
491                 }
492
493                 resultnotify = new ResultNotifier(ServerInstance, this);
494
495                 ReadConf();
496
497                 ServerInstance->Modules->PublishInterface("SQL", this);
498                 Implementation eventlist[] = { I_OnRequest, I_OnRehash };
499                 ServerInstance->Modules->Attach(eventlist, this, 2);
500         }
501
502         virtual ~ModuleSQLite3()
503         {
504                 ClearQueue();
505                 ClearAllConnections();
506
507                 ServerInstance->SE->DelFd(resultnotify);
508                 resultnotify->Close();
509                 ServerInstance->BufferedSocketCull();
510
511                 if (QueueFD >= 0)
512                 {
513                         shutdown(QueueFD, 2);
514                         close(QueueFD);
515                 }
516
517                 if (resultdispatch)
518                 {
519                         ServerInstance->SE->DelFd(resultdispatch);
520                         resultdispatch->Close();
521                         ServerInstance->BufferedSocketCull();
522                 }
523
524                 ServerInstance->Modules->UnpublishInterface("SQL", this);
525                 ServerInstance->Modules->UnpublishFeature("SQL");
526                 ServerInstance->Modules->DoneWithInterface("SQLutils");
527         }
528
529
530         void SendQueue()
531         {
532                 for (ConnMap::iterator iter = connections.begin(); iter != connections.end(); iter++)
533                 {
534                         iter->second->SendResults();
535                 }
536         }
537
538         void ClearQueue()
539         {
540                 for (ConnMap::iterator iter = connections.begin(); iter != connections.end(); iter++)
541                 {
542                         iter->second->ClearResults();
543                 }
544         }
545
546         bool HasHost(const SQLhost &host)
547         {
548                 for (ConnMap::iterator iter = connections.begin(); iter != connections.end(); iter++)
549                 {
550                         if (host == iter->second->GetConfHost())
551                                 return true;
552                 }
553                 return false;
554         }
555
556         bool HostInConf(const SQLhost &h)
557         {
558                 ConfigReader conf(ServerInstance);
559                 for(int i = 0; i < conf.Enumerate("database"); i++)
560                 {
561                         SQLhost host;
562                         host.id         = conf.ReadValue("database", "id", i);
563                         host.host       = conf.ReadValue("database", "hostname", i);
564                         host.port       = conf.ReadInteger("database", "port", i, true);
565                         host.name       = conf.ReadValue("database", "name", i);
566                         host.user       = conf.ReadValue("database", "username", i);
567                         host.pass       = conf.ReadValue("database", "password", i);
568                         if (h == host)
569                                 return true;
570                 }
571                 return false;
572         }
573
574         void ReadConf()
575         {
576                 ClearOldConnections();
577
578                 ConfigReader conf(ServerInstance);
579                 for(int i = 0; i < conf.Enumerate("database"); i++)
580                 {
581                         SQLhost host;
582
583                         host.id         = conf.ReadValue("database", "id", i);
584                         host.host       = conf.ReadValue("database", "hostname", i);
585                         host.port       = conf.ReadInteger("database", "port", i, true);
586                         host.name       = conf.ReadValue("database", "name", i);
587                         host.user       = conf.ReadValue("database", "username", i);
588                         host.pass       = conf.ReadValue("database", "password", i);
589
590                         if (HasHost(host))
591                                 continue;
592
593                         this->AddConn(host);
594                 }
595         }
596
597         void AddConn(const SQLhost& hi)
598         {
599                 if (HasHost(hi))
600                 {
601                         ServerInstance->Logs->Log("m_sqlite3",DEFAULT, "WARNING: A sqlite connection with id: %s already exists. Aborting database open attempt.", hi.id.c_str());
602                         return;
603                 }
604
605                 SQLConn* newconn;
606
607                 newconn = new SQLConn(ServerInstance, this, hi);
608
609                 connections.insert(std::make_pair(hi.id, newconn));
610         }
611
612         void ClearOldConnections()
613         {
614                 ConnMap::iterator iter,safei;
615                 for (iter = connections.begin(); iter != connections.end(); iter++)
616                 {
617                         if (!HostInConf(iter->second->GetConfHost()))
618                         {
619                                 delete iter->second;
620                                 safei = iter;
621                                 --iter;
622                                 connections.erase(safei);
623                         }
624                 }
625         }
626
627         void ClearAllConnections()
628         {
629                 ConnMap::iterator i;
630                 while ((i = connections.begin()) != connections.end())
631                 {
632                         connections.erase(i);
633                         delete i->second;
634                 }
635         }
636
637         virtual void OnRehash(User* user, const std::string &parameter)
638         {
639                 ReadConf();
640         }
641
642         virtual const char* OnRequest(Request* request)
643         {
644                 if(strcmp(SQLREQID, request->GetId()) == 0)
645                 {
646                         SQLrequest* req = (SQLrequest*)request;
647                         ConnMap::iterator iter;
648                         if((iter = connections.find(req->dbid)) != connections.end())
649                         {
650                                 req->id = NewID();
651                                 req->error = iter->second->Query(*req);
652                                 return SQLSUCCESS;
653                         }
654                         else
655                         {
656                                 req->error.Id(BAD_DBID);
657                                 return NULL;
658                         }
659                 }
660                 return NULL;
661         }
662
663         unsigned long NewID()
664         {
665                 if (currid+1 == 0)
666                         currid++;
667
668                 return ++currid;
669         }
670
671         virtual Version GetVersion()
672         {
673                 return Version(1,2,0,0,VF_VENDOR|VF_SERVICEPROVIDER,API_VERSION);
674         }
675
676 };
677
678 void ResultNotifier::Dispatch()
679 {
680         ((ModuleSQLite3*)mod)->SendQueue();
681 }
682
/* Register ModuleSQLite3 as this module's class via the core MODULE_INIT macro
 * (defined outside this file; presumably it emits the loader entry point). */
MODULE_INIT(ModuleSQLite3)