]> git.netwichtig.de Git - user/henk/code/inspircd.git/blob - src/modules/extra/m_sqlite3.cpp
469d5e1532f8a43553bc72f12a5ed12196ea5e32
[user/henk/code/inspircd.git] / src / modules / extra / m_sqlite3.cpp
1 /*       +------------------------------------+
2  *       | Inspire Internet Relay Chat Daemon |
3  *       +------------------------------------+
4  *
5  *  InspIRCd: (C) 2002-2007 InspIRCd Development Team
6  * See: http://www.inspircd.org/wiki/index.php/Credits
7  *
8  * This program is free but copyrighted software; see
9  *            the file COPYING for details.
10  *
11  * ---------------------------------------------------
12  */
13
#include <cstring>
#include <deque>
#include <map>
#include <string>
#include <vector>
#include <sqlite3.h>
#include <unistd.h>

#include "users.h"
#include "channels.h"
#include "modules.h"
#include "inspircd.h"
#include "configreader.h"

#include "m_sqlv2.h"
26
27 /* $ModDesc: sqlite3 provider */
28 /* $CompileFlags: pkgconfincludes("sqlite3","/sqlite3.h","") */
29 /* $LinkerFlags: pkgconflibs("sqlite3","/libsqlite3.so","-lsqlite3") */
30 /* $ModDep: m_sqlv2.h */
31
32
33 class SQLConn;
34 class SQLite3Result;
35 class ResultNotifier;
36
37 typedef std::map<std::string, SQLConn*> ConnMap;
38 typedef std::deque<classbase*> paramlist;
39 typedef std::deque<SQLite3Result*> ResultQueue;
40
41 ResultNotifier* resultnotify = NULL;
42
43
44 class ResultNotifier : public InspSocket
45 {
46         Module* mod;
47         insp_sockaddr sock_us;
48         socklen_t uslen;
49
50  public:
51         /* Create a socket on a random port. Let the tcp stack allocate us an available port */
52 #ifdef IPV6
53         ResultNotifier(InspIRCd* SI, Module* m) : InspSocket(SI, "::1", 0, true, 3000), mod(m)
54 #else
55         ResultNotifier(InspIRCd* SI, Module* m) : InspSocket(SI, "127.0.0.1", 0, true, 3000), mod(m)
56 #endif
57         {
58                 uslen = sizeof(sock_us);
59                 if (getsockname(this->fd,(sockaddr*)&sock_us,&uslen))
60                 {
61                         throw ModuleException("Could not create random listening port on localhost");
62                 }
63         }
64
65         ResultNotifier(InspIRCd* SI, Module* m, int newfd, char* ip) : InspSocket(SI, newfd, ip), mod(m)
66         {
67                 Instance->Log(DEBUG,"Constructor of new socket");
68         }
69
70         /* Using getsockname and ntohs, we can determine which port number we were allocated */
71         int GetPort()
72         {
73 #ifdef IPV6
74                 return ntohs(sock_us.sin6_port);
75 #else
76                 return ntohs(sock_us.sin_port);
77 #endif
78         }
79
80         virtual int OnIncomingConnection(int newsock, char* ip)
81         {
82                 Instance->Log(DEBUG,"Inbound connection on fd %d!",newsock);
83                 Dispatch();
84                 return false;
85         }
86
87         void Dispatch();
88 };
89
90
91 class SQLite3Result : public SQLresult
92 {
93   private:
94         int currentrow;
95         int rows;
96         int cols;
97
98         std::vector<std::string> colnames;
99         std::vector<SQLfieldList> fieldlists;
100         SQLfieldList emptyfieldlist;
101
102         SQLfieldList* fieldlist;
103         SQLfieldMap* fieldmap;
104
105   public:
106         SQLite3Result(Module* self, Module* to, unsigned int id)
107         : SQLresult(self, to, id), currentrow(0), rows(0), cols(0), fieldlist(NULL), fieldmap(NULL)
108         {
109         }
110
111         ~SQLite3Result()
112         {
113         }
114
115         void AddRow(int colsnum, char **data, char **colname)
116         {
117                 colnames.clear();
118                 cols = colsnum;
119                 for (int i = 0; i < colsnum; i++)
120                 {
121                         fieldlists.resize(fieldlists.size()+1);
122                         colnames.push_back(colname[i]);
123                         SQLfield sf(data[i] ? data[i] : "", data[i] ? false : true);
124                         fieldlists[rows].push_back(sf);
125                 }
126                 rows++;
127         }
128
129         virtual int Rows()
130         {
131                 return rows;
132         }
133
134         virtual int Cols()
135         {
136                 return cols;
137         }
138
139         virtual std::string ColName(int column)
140         {
141                 if (column < (int)colnames.size())
142                 {
143                         return colnames[column];
144                 }
145                 else
146                 {
147                         throw SQLbadColName();
148                 }
149                 return "";
150         }
151
152         virtual int ColNum(const std::string &column)
153         {
154                 for (unsigned int i = 0; i < colnames.size(); i++)
155                 {
156                         if (column == colnames[i])
157                                 return i;
158                 }
159                 throw SQLbadColName();
160                 return 0;
161         }
162
163         virtual SQLfield GetValue(int row, int column)
164         {
165                 if ((row >= 0) && (row < rows) && (column >= 0) && (column < Cols()))
166                 {
167                         return fieldlists[row][column];
168                 }
169
170                 throw SQLbadColName();
171
172                 /* XXX: We never actually get here because of the throw */
173                 return SQLfield("",true);
174         }
175
176         virtual SQLfieldList& GetRow()
177         {
178                 if (currentrow < rows)
179                         return fieldlists[currentrow];
180                 else
181                         return emptyfieldlist;
182         }
183
184         virtual SQLfieldMap& GetRowMap()
185         {
186                 /* In an effort to reduce overhead we don't actually allocate the map
187                  * until the first time it's needed...so...
188                  */
189                 if(fieldmap)
190                 {
191                         fieldmap->clear();
192                 }
193                 else
194                 {
195                         fieldmap = new SQLfieldMap;
196                 }
197
198                 if (currentrow < rows)
199                 {
200                         for (int i = 0; i < Cols(); i++)
201                         {
202                                 fieldmap->insert(std::make_pair(ColName(i), GetValue(currentrow, i)));
203                         }
204                         currentrow++;
205                 }
206
207                 return *fieldmap;
208         }
209
210         virtual SQLfieldList* GetRowPtr()
211         {
212                 fieldlist = new SQLfieldList();
213
214                 if (currentrow < rows)
215                 {
216                         for (int i = 0; i < Rows(); i++)
217                         {
218                                 fieldlist->push_back(fieldlists[currentrow][i]);
219                         }
220                         currentrow++;
221                 }
222                 return fieldlist;
223         }
224
225         virtual SQLfieldMap* GetRowMapPtr()
226         {
227                 fieldmap = new SQLfieldMap();
228
229                 if (currentrow < rows)
230                 {
231                         for (int i = 0; i < Cols(); i++)
232                         {
233                                 fieldmap->insert(std::make_pair(colnames[i],GetValue(currentrow, i)));
234                         }
235                         currentrow++;
236                 }
237
238                 return fieldmap;
239         }
240
241         virtual void Free(SQLfieldMap* fm)
242         {
243                 delete fm;
244         }
245
246         virtual void Free(SQLfieldList* fl)
247         {
248                 delete fl;
249         }
250
251
252 };
253
254 class SQLConn : public classbase
255 {
256   private:
257         ResultQueue results;
258         InspIRCd* Instance;
259         Module* mod;
260         SQLhost host;
261         sqlite3* conn;
262
263   public:
264         SQLConn(InspIRCd* SI, Module* m, const SQLhost& hi)
265         : Instance(SI), mod(m), host(hi)
266         {
267                 int result;
268                 if ((result = OpenDB()) == SQLITE_OK)
269                 {
270                         Instance->Log(DEBUG, "Opened sqlite DB: " + host.host);
271                 }
272                 else
273                 {
274                         Instance->Log(DEFAULT, "WARNING: Could not open DB with id: " + host.id);
275                         CloseDB();
276                 }
277         }
278
279         ~SQLConn()
280         {
281                 CloseDB();
282         }
283
284         SQLerror Query(SQLrequest &req)
285         {
286                 /* Pointer to the buffer we screw around with substitution in */
287                 char* query;
288
289                 /* Pointer to the current end of query, where we append new stuff */
290                 char* queryend;
291
292                 /* Total length of the unescaped parameters */
293                 unsigned long paramlen;
294
295                 /* Total length of query, used for binary-safety in mysql_real_query */
296                 unsigned long querylength = 0;
297
298                 paramlen = 0;
299                 for(ParamL::iterator i = req.query.p.begin(); i != req.query.p.end(); i++)
300                 {
301                         paramlen += i->size();
302                 }
303
304                 /* To avoid a lot of allocations, allocate enough memory for the biggest the escaped query could possibly be.
305                  * sizeofquery + (totalparamlength*2) + 1
306                  *
307                  * The +1 is for null-terminating the string for mysql_real_escape_string
308                  */
309
310                 query = new char[req.query.q.length() + (paramlen*2) + 1];
311                 queryend = query;
312
313                 for(unsigned long i = 0; i < req.query.q.length(); i++)
314                 {
315                         if(req.query.q[i] == '?')
316                         {
317                                 if(req.query.p.size())
318                                 {
319                                         char* escaped;
320                                         escaped = sqlite3_mprintf("%q", req.query.p.front().c_str());
321                                         for (char* n = escaped; *n; n++)
322                                         {
323                                                 *queryend = *n;
324                                                 queryend++;
325                                         }
326                                         sqlite3_free(escaped);
327                                         req.query.p.pop_front();
328                                 }
329                                 else
330                                         break;
331                         }
332                         else
333                         {
334                                 *queryend = req.query.q[i];
335                                 queryend++;
336                         }
337                         querylength++;
338                 }
339                 *queryend = 0;
340                 req.query.q = query;
341
342 //              Instance->Log(DEBUG, "<******> Doing query: " + ConvToStr(req.query.q.data()));
343
344                 SQLite3Result* res = new SQLite3Result(mod, req.GetSource(), req.id);
345                 res->dbid = host.id;
346                 res->query = req.query.q;
347                 paramlist params;
348                 params.push_back(this);
349                 params.push_back(res);
350
351                 char *errmsg = 0;
352                 if (sqlite3_exec(conn, req.query.q.data(), QueryResult, &params, &errmsg) != SQLITE_OK)
353                 {
354                         Instance->Log(DEBUG, "Query failed: " + ConvToStr(errmsg));
355                         sqlite3_free(errmsg);
356                         delete[] query;
357                         delete res;
358                         return SQLerror(QSEND_FAIL, ConvToStr(errmsg));
359                 }
360                 Instance->Log(DEBUG, "Dispatched query successfully. ID: %d resulting rows %d", req.id, res->Rows());
361                 delete[] query;
362
363                 results.push_back(res);
364                 SendNotify();
365                 return SQLerror();
366         }
367
368         static int QueryResult(void *params, int argc, char **argv, char **azColName)
369         {
370                 paramlist* p = (paramlist*)params;
371                 ((SQLConn*)(*p)[0])->ResultReady(((SQLite3Result*)(*p)[1]), argc, argv, azColName);
372                 return 0;
373         }
374
375         void ResultReady(SQLite3Result *res, int cols, char **data, char **colnames)
376         {
377                 res->AddRow(cols, data, colnames);
378         }
379
380         int OpenDB()
381         {
382                 return sqlite3_open(host.host.c_str(), &conn);
383         }
384
385         void CloseDB()
386         {
387                 sqlite3_interrupt(conn);
388                 sqlite3_close(conn);
389                 Instance->Log(DEBUG, "Closed sqlite DB: " + host.host);
390         }
391
392         SQLhost GetConfHost()
393         {
394                 return host;
395         }
396
397         void SendResults()
398         {
399                 while (results.size())
400                 {
401                         SQLite3Result* res = results[0];
402                         if (res->GetDest())
403                         {
404                                 res->Send();
405                         }
406                         else
407                         {
408                                 /* If the client module is unloaded partway through a query then the provider will set
409                                  * the pointer to NULL. We cannot just cancel the query as the result will still come
410                                  * through at some point...and it could get messy if we play with invalid pointers...
411                                  */
412                                 Instance->Log(DEBUG, "Looks like we're handling a zombie query from a module which unloaded before it got a result..fun. ID: " + ConvToStr(res->GetId()));
413                                 delete res;
414                         }
415                         results.pop_front();
416                 }
417         }
418
419         void ClearResults()
420         {
421                 while (results.size())
422                 {
423                         SQLite3Result* res = results[0];
424                         delete res;
425                         results.pop_front();
426                 }
427         }
428
429         void SendNotify()
430         {
431                 int QueueFD;
432                 if ((QueueFD = socket(AF_FAMILY, SOCK_STREAM, 0)) == -1)
433                 {
434                         /* crap, we're out of sockets... */
435                         return;
436                 }
437
438                 insp_sockaddr addr;
439
440 #ifdef IPV6
441                 insp_aton("::1", &addr.sin6_addr);
442                 addr.sin6_family = AF_FAMILY;
443                 addr.sin6_port = htons(resultnotify->GetPort());
444 #else
445                 insp_inaddr ia;
446                 insp_aton("127.0.0.1", &ia);
447                 addr.sin_family = AF_FAMILY;
448                 addr.sin_addr = ia;
449                 addr.sin_port = htons(resultnotify->GetPort());
450 #endif
451
452                 if (connect(QueueFD, (sockaddr*)&addr,sizeof(addr)) == -1)
453                 {
454                         /* wtf, we cant connect to it, but we just created it! */
455                         return;
456                 }
457         }
458
459 };
460
461
462 class ModuleSQLite3 : public Module
463 {
464   private:
465         ConnMap connections;
466         unsigned long currid;
467
468   public:
469         ModuleSQLite3(InspIRCd* Me)
470         : Module::Module(Me), currid(0)
471         {
472                 ServerInstance->UseInterface("SQLutils");
473
474                 if (!ServerInstance->PublishFeature("SQL", this))
475                 {
476                         throw ModuleException("m_sqlite3: Unable to publish feature 'SQL'");
477                 }
478
479                 resultnotify = new ResultNotifier(ServerInstance, this);
480                 ServerInstance->Log(DEBUG,"Bound notifier to 127.0.0.1:%d",resultnotify->GetPort());
481
482                 ReadConf();
483
484                 ServerInstance->PublishInterface("SQL", this);
485         }
486
487         virtual ~ModuleSQLite3()
488         {
489                 ClearQueue();
490                 ClearAllConnections();
491                 resultnotify->SetFd(-1);
492                 resultnotify->state = I_ERROR;
493                 resultnotify->OnError(I_ERR_SOCKET);
494                 resultnotify->ClosePending = true;
495                 delete resultnotify;
496                 ServerInstance->UnpublishInterface("SQL", this);
497                 ServerInstance->UnpublishFeature("SQL");
498                 ServerInstance->DoneWithInterface("SQLutils");
499         }
500
501         void Implements(char* List)
502         {
503                 List[I_OnRequest] = List[I_OnRehash] = 1;
504         }
505
506         void SendQueue()
507         {
508                 for (ConnMap::iterator iter = connections.begin(); iter != connections.end(); iter++)
509                 {
510                         iter->second->SendResults();
511                 }
512         }
513
514         void ClearQueue()
515         {
516                 for (ConnMap::iterator iter = connections.begin(); iter != connections.end(); iter++)
517                 {
518                         iter->second->ClearResults();
519                 }
520         }
521
522         bool HasHost(const SQLhost &host)
523         {
524                 for (ConnMap::iterator iter = connections.begin(); iter != connections.end(); iter++)
525                 {
526                         if (host == iter->second->GetConfHost())
527                                 return true;
528                 }
529                 return false;
530         }
531
532         bool HostInConf(const SQLhost &h)
533         {
534                 ConfigReader conf(ServerInstance);
535                 for(int i = 0; i < conf.Enumerate("database"); i++)
536                 {
537                         SQLhost host;
538                         host.id         = conf.ReadValue("database", "id", i);
539                         host.host       = conf.ReadValue("database", "hostname", i);
540                         host.port       = conf.ReadInteger("database", "port", i, true);
541                         host.name       = conf.ReadValue("database", "name", i);
542                         host.user       = conf.ReadValue("database", "username", i);
543                         host.pass       = conf.ReadValue("database", "password", i);
544                         host.ssl        = conf.ReadFlag("database", "ssl", "0", i);
545                         if (h == host)
546                                 return true;
547                 }
548                 return false;
549         }
550
551         void ReadConf()
552         {
553                 ClearOldConnections();
554
555                 ConfigReader conf(ServerInstance);
556                 for(int i = 0; i < conf.Enumerate("database"); i++)
557                 {
558                         SQLhost host;
559
560                         host.id         = conf.ReadValue("database", "id", i);
561                         host.host       = conf.ReadValue("database", "hostname", i);
562                         host.port       = conf.ReadInteger("database", "port", i, true);
563                         host.name       = conf.ReadValue("database", "name", i);
564                         host.user       = conf.ReadValue("database", "username", i);
565                         host.pass       = conf.ReadValue("database", "password", i);
566                         host.ssl        = conf.ReadFlag("database", "ssl", "0", i);
567
568                         if (HasHost(host))
569                                 continue;
570
571                         this->AddConn(host);
572                 }
573         }
574
575         void AddConn(const SQLhost& hi)
576         {
577                 if (HasHost(hi))
578                 {
579                         ServerInstance->Log(DEFAULT, "WARNING: A sqlite connection with id: %s already exists. Aborting database open attempt.", hi.id.c_str());
580                         return;
581                 }
582
583                 SQLConn* newconn;
584
585                 newconn = new SQLConn(ServerInstance, this, hi);
586
587                 connections.insert(std::make_pair(hi.id, newconn));
588         }
589
590         void ClearOldConnections()
591         {
592                 ConnMap::iterator iter,safei;
593                 for (iter = connections.begin(); iter != connections.end(); iter++)
594                 {
595                         if (!HostInConf(iter->second->GetConfHost()))
596                         {
597                                 DELETE(iter->second);
598                                 safei = iter;
599                                 --iter;
600                                 connections.erase(safei);
601                         }
602                 }
603         }
604
605         void ClearAllConnections()
606         {
607                 ConnMap::iterator i;
608                 while ((i = connections.begin()) != connections.end())
609                 {
610                         connections.erase(i);
611                         DELETE(i->second);
612                 }
613         }
614
615         virtual void OnRehash(userrec* user, const std::string &parameter)
616         {
617                 ReadConf();
618         }
619
620         virtual char* OnRequest(Request* request)
621         {
622                 if(strcmp(SQLREQID, request->GetId()) == 0)
623                 {
624                         SQLrequest* req = (SQLrequest*)request;
625                         ConnMap::iterator iter;
626                         ServerInstance->Log(DEBUG, "Got query: '%s' with %d replacement parameters on id '%s'", req->query.q.c_str(), req->query.p.size(), req->dbid.c_str());
627                         if((iter = connections.find(req->dbid)) != connections.end())
628                         {
629                                 req->id = NewID();
630                                 req->error = iter->second->Query(*req);
631                                 return SQLSUCCESS;
632                         }
633                         else
634                         {
635                                 req->error.Id(BAD_DBID);
636                                 return NULL;
637                         }
638                 }
639                 ServerInstance->Log(DEBUG, "Got unsupported API version string: %s", request->GetId());
640                 return NULL;
641         }
642
643         unsigned long NewID()
644         {
645                 if (currid+1 == 0)
646                         currid++;
647
648                 return ++currid;
649         }
650
651         virtual Version GetVersion()
652         {
653                 return Version(1,1,0,0,VF_VENDOR|VF_SERVICEPROVIDER,API_VERSION);
654         }
655
656 };
657
658 void ResultNotifier::Dispatch()
659 {
660         ((ModuleSQLite3*)mod)->SendQueue();
661 }
662
663 class ModuleSQLite3Factory : public ModuleFactory
664 {
665   public:
666         ModuleSQLite3Factory()
667         {
668         }
669
670         ~ModuleSQLite3Factory()
671         {
672         }
673
674         virtual Module * CreateModule(InspIRCd* Me)
675         {
676                 return new ModuleSQLite3(Me);
677         }
678 };
679
680 extern "C" void * init_module( void )
681 {
682         return new ModuleSQLite3Factory;
683 }