]> git.netwichtig.de Git - user/henk/code/inspircd.git/blob - src/modules/extra/m_sqlite3.cpp
Hook sqlite3_update_hook to queries to also catch affected rows on UPDATE/INSERT/DELETE.
[user/henk/code/inspircd.git] / src / modules / extra / m_sqlite3.cpp
1 /*       +------------------------------------+
2  *       | Inspire Internet Relay Chat Daemon |
3  *       +------------------------------------+
4  *
5  *  InspIRCd: (C) 2002-2007 InspIRCd Development Team
6  * See: http://www.inspircd.org/wiki/index.php/Credits
7  *
8  * This program is free but copyrighted software; see
9  *            the file COPYING for details.
10  *
11  * ---------------------------------------------------
12  */
13
14 #include <string>
15 #include <deque>
16 #include <map>
17 #include <sqlite3.h>
18
19 #include "users.h"
20 #include "channels.h"
21 #include "modules.h"
22 #include "inspircd.h"
23 #include "configreader.h"
24
25 #include "m_sqlv2.h"
26
27 /* $ModDesc: sqlite3 provider */
28 /* $CompileFlags: pkgconfincludes("sqlite3","/sqlite3.h","") */
29 /* $LinkerFlags: pkgconflibs("sqlite3","/libsqlite3.so","-lsqlite3") */
30 /* $ModDep: m_sqlv2.h */
31
32
/* Forward declarations; the classes are defined further down in this file. */
class SQLConn;
class SQLite3Result;
class ResultNotifier;

/* Open database connections; ModuleSQLite3::AddConn inserts them keyed by SQLhost::id. */
typedef std::map<std::string, SQLConn*> ConnMap;
/* Argument bundle passed as the user-data pointer to the sqlite3 callbacks
 * (SQLConn::Query stores the connection at index 0 and the result at index 1).
 */
typedef std::deque<classbase*> paramlist;
/* Results of finished queries that have not been sent to the requesting module yet. */
typedef std::deque<SQLite3Result*> ResultQueue;

/* The single notifier socket; created in the ModuleSQLite3 constructor and deleted in its destructor. */
ResultNotifier* resultnotify = NULL;
42
43
44 class ResultNotifier : public InspSocket
45 {
46         Module* mod;
47         insp_sockaddr sock_us;
48         socklen_t uslen;
49
50  public:
51         /* Create a socket on a random port. Let the tcp stack allocate us an available port */
52 #ifdef IPV6
53         ResultNotifier(InspIRCd* SI, Module* m) : InspSocket(SI, "::1", 0, true, 3000), mod(m)
54 #else
55         ResultNotifier(InspIRCd* SI, Module* m) : InspSocket(SI, "127.0.0.1", 0, true, 3000), mod(m)
56 #endif
57         {
58                 uslen = sizeof(sock_us);
59                 if (getsockname(this->fd,(sockaddr*)&sock_us,&uslen))
60                 {
61                         throw ModuleException("Could not create random listening port on localhost");
62                 }
63         }
64
65         ResultNotifier(InspIRCd* SI, Module* m, int newfd, char* ip) : InspSocket(SI, newfd, ip), mod(m)
66         {
67                 Instance->Log(DEBUG,"Constructor of new socket");
68         }
69
70         /* Using getsockname and ntohs, we can determine which port number we were allocated */
71         int GetPort()
72         {
73 #ifdef IPV6
74                 return ntohs(sock_us.sin6_port);
75 #else
76                 return ntohs(sock_us.sin_port);
77 #endif
78         }
79
80         virtual int OnIncomingConnection(int newsock, char* ip)
81         {
82                 Instance->Log(DEBUG,"Inbound connection on fd %d!",newsock);
83                 Dispatch();
84                 return false;
85         }
86
87         void Dispatch();
88 };
89
90
91 class SQLite3Result : public SQLresult
92 {
93   private:
94         int currentrow;
95         int rows;
96         int cols;
97
98         std::vector<std::string> colnames;
99         std::vector<SQLfieldList> fieldlists;
100         SQLfieldList emptyfieldlist;
101
102         SQLfieldList* fieldlist;
103         SQLfieldMap* fieldmap;
104
105   public:
106         SQLite3Result(Module* self, Module* to, unsigned int id)
107         : SQLresult(self, to, id), currentrow(0), rows(0), cols(0), fieldlist(NULL), fieldmap(NULL)
108         {
109         }
110
111         ~SQLite3Result()
112         {
113         }
114
115         void AddRow(int colsnum, char **data, char **colname)
116         {
117                 colnames.clear();
118                 cols = colsnum;
119                 for (int i = 0; i < colsnum; i++)
120                 {
121                         fieldlists.resize(fieldlists.size()+1);
122                         colnames.push_back(colname[i]);
123                         SQLfield sf(data[i] ? data[i] : "", data[i] ? false : true);
124                         fieldlists[rows].push_back(sf);
125                 }
126                 rows++;
127         }
128
129         void UpdateAffectedCount()
130         {
131                 rows++;
132         }
133
134         virtual int Rows()
135         {
136                 return rows;
137         }
138
139         virtual int Cols()
140         {
141                 return cols;
142         }
143
144         virtual std::string ColName(int column)
145         {
146                 if (column < (int)colnames.size())
147                 {
148                         return colnames[column];
149                 }
150                 else
151                 {
152                         throw SQLbadColName();
153                 }
154                 return "";
155         }
156
157         virtual int ColNum(const std::string &column)
158         {
159                 for (unsigned int i = 0; i < colnames.size(); i++)
160                 {
161                         if (column == colnames[i])
162                                 return i;
163                 }
164                 throw SQLbadColName();
165                 return 0;
166         }
167
168         virtual SQLfield GetValue(int row, int column)
169         {
170                 if ((row >= 0) && (row < rows) && (column >= 0) && (column < Cols()))
171                 {
172                         return fieldlists[row][column];
173                 }
174
175                 throw SQLbadColName();
176
177                 /* XXX: We never actually get here because of the throw */
178                 return SQLfield("",true);
179         }
180
181         virtual SQLfieldList& GetRow()
182         {
183                 if (currentrow < rows)
184                         return fieldlists[currentrow];
185                 else
186                         return emptyfieldlist;
187         }
188
189         virtual SQLfieldMap& GetRowMap()
190         {
191                 /* In an effort to reduce overhead we don't actually allocate the map
192                  * until the first time it's needed...so...
193                  */
194                 if(fieldmap)
195                 {
196                         fieldmap->clear();
197                 }
198                 else
199                 {
200                         fieldmap = new SQLfieldMap;
201                 }
202
203                 if (currentrow < rows)
204                 {
205                         for (int i = 0; i < Cols(); i++)
206                         {
207                                 fieldmap->insert(std::make_pair(ColName(i), GetValue(currentrow, i)));
208                         }
209                         currentrow++;
210                 }
211
212                 return *fieldmap;
213         }
214
215         virtual SQLfieldList* GetRowPtr()
216         {
217                 fieldlist = new SQLfieldList();
218
219                 if (currentrow < rows)
220                 {
221                         for (int i = 0; i < Rows(); i++)
222                         {
223                                 fieldlist->push_back(fieldlists[currentrow][i]);
224                         }
225                         currentrow++;
226                 }
227                 return fieldlist;
228         }
229
230         virtual SQLfieldMap* GetRowMapPtr()
231         {
232                 fieldmap = new SQLfieldMap();
233
234                 if (currentrow < rows)
235                 {
236                         for (int i = 0; i < Cols(); i++)
237                         {
238                                 fieldmap->insert(std::make_pair(colnames[i],GetValue(currentrow, i)));
239                         }
240                         currentrow++;
241                 }
242
243                 return fieldmap;
244         }
245
246         virtual void Free(SQLfieldMap* fm)
247         {
248                 delete fm;
249         }
250
251         virtual void Free(SQLfieldList* fl)
252         {
253                 delete fl;
254         }
255
256
257 };
258
259 class SQLConn : public classbase
260 {
261   private:
262         ResultQueue results;
263         InspIRCd* Instance;
264         Module* mod;
265         SQLhost host;
266         sqlite3* conn;
267
268   public:
269         SQLConn(InspIRCd* SI, Module* m, const SQLhost& hi)
270         : Instance(SI), mod(m), host(hi)
271         {
272                 int result;
273                 if ((result = OpenDB()) == SQLITE_OK)
274                 {
275                         Instance->Log(DEBUG, "Opened sqlite DB: " + host.host);
276                 }
277                 else
278                 {
279                         Instance->Log(DEFAULT, "WARNING: Could not open DB with id: " + host.id);
280                         CloseDB();
281                 }
282         }
283
284         ~SQLConn()
285         {
286                 CloseDB();
287         }
288
289         SQLerror Query(SQLrequest &req)
290         {
291                 /* Pointer to the buffer we screw around with substitution in */
292                 char* query;
293
294                 /* Pointer to the current end of query, where we append new stuff */
295                 char* queryend;
296
297                 /* Total length of the unescaped parameters */
298                 unsigned long paramlen;
299
300                 /* Total length of query, used for binary-safety in mysql_real_query */
301                 unsigned long querylength = 0;
302
303                 paramlen = 0;
304                 for(ParamL::iterator i = req.query.p.begin(); i != req.query.p.end(); i++)
305                 {
306                         paramlen += i->size();
307                 }
308
309                 /* To avoid a lot of allocations, allocate enough memory for the biggest the escaped query could possibly be.
310                  * sizeofquery + (totalparamlength*2) + 1
311                  *
312                  * The +1 is for null-terminating the string for mysql_real_escape_string
313                  */
314
315                 query = new char[req.query.q.length() + (paramlen*2) + 1];
316                 queryend = query;
317
318                 for(unsigned long i = 0; i < req.query.q.length(); i++)
319                 {
320                         if(req.query.q[i] == '?')
321                         {
322                                 if(req.query.p.size())
323                                 {
324                                         char* escaped;
325                                         escaped = sqlite3_mprintf("%q", req.query.p.front().c_str());
326                                         for (char* n = escaped; *n; n++)
327                                         {
328                                                 *queryend = *n;
329                                                 queryend++;
330                                         }
331                                         sqlite3_free(escaped);
332                                         req.query.p.pop_front();
333                                 }
334                                 else
335                                         break;
336                         }
337                         else
338                         {
339                                 *queryend = req.query.q[i];
340                                 queryend++;
341                         }
342                         querylength++;
343                 }
344                 *queryend = 0;
345                 req.query.q = query;
346
347                 //Instance->Log(DEBUG, "<******> Doing query: " + ConvToStr(req.query.q.data()));
348                 SQLite3Result* res = new SQLite3Result(mod, req.GetSource(), req.id);
349                 res->dbid = host.id;
350                 res->query = req.query.q;
351                 paramlist params;
352                 params.push_back(this);
353                 params.push_back(res);
354
355                 char *errmsg = 0;
356                 sqlite3_update_hook(conn, QueryUpdateHook, &params);
357                 if (sqlite3_exec(conn, req.query.q.data(), QueryResult, &params, &errmsg) != SQLITE_OK)
358                 {
359                         Instance->Log(DEBUG, "Query failed: " + ConvToStr(errmsg));
360                         sqlite3_free(errmsg);
361                         delete[] query;
362                         delete res;
363                         return SQLerror(QSEND_FAIL, ConvToStr(errmsg));
364                 }
365                 Instance->Log(DEBUG, "Dispatched query successfully. ID: %d resulting rows %d", req.id, res->Rows());
366                 delete[] query;
367
368                 results.push_back(res);
369                 SendNotify();
370                 return SQLerror();
371         }
372
373         static int QueryResult(void *params, int argc, char **argv, char **azColName)
374         {
375                 paramlist* p = (paramlist*)params;
376                 ((SQLConn*)(*p)[0])->ResultReady(((SQLite3Result*)(*p)[1]), argc, argv, azColName);
377                 return 0;
378         }
379
380         static void QueryUpdateHook(void *params, int eventid, char const * azSQLite, char const * azColName, sqlite_int64 rowid)
381         {
382                 paramlist* p = (paramlist*)params;
383                 ((SQLConn*)(*p)[0])->AffectedReady(((SQLite3Result*)(*p)[1]));
384         }\r
385
386         void ResultReady(SQLite3Result *res, int cols, char **data, char **colnames)
387         {
388                 res->AddRow(cols, data, colnames);
389         }
390
391         void AffectedReady(SQLite3Result *res)
392         {
393                 res->UpdateAffectedCount();
394         }
395
396         int OpenDB()
397         {
398                 return sqlite3_open(host.host.c_str(), &conn);
399         }
400
401         void CloseDB()
402         {
403                 sqlite3_interrupt(conn);
404                 sqlite3_close(conn);
405                 Instance->Log(DEBUG, "Closed sqlite DB: " + host.host);
406         }
407
408         SQLhost GetConfHost()
409         {
410                 return host;
411         }
412
413         void SendResults()
414         {
415                 while (results.size())
416                 {
417                         SQLite3Result* res = results[0];
418                         if (res->GetDest())
419                         {
420                                 res->Send();
421                         }
422                         else
423                         {
424                                 /* If the client module is unloaded partway through a query then the provider will set
425                                  * the pointer to NULL. We cannot just cancel the query as the result will still come
426                                  * through at some point...and it could get messy if we play with invalid pointers...
427                                  */
428                                 Instance->Log(DEBUG, "Looks like we're handling a zombie query from a module which unloaded before it got a result..fun. ID: " + ConvToStr(res->GetId()));
429                                 delete res;
430                         }
431                         results.pop_front();
432                 }
433         }
434
435         void ClearResults()
436         {
437                 while (results.size())
438                 {
439                         SQLite3Result* res = results[0];
440                         delete res;
441                         results.pop_front();
442                 }
443         }
444
445         void SendNotify()
446         {
447                 int QueueFD;
448                 if ((QueueFD = socket(AF_FAMILY, SOCK_STREAM, 0)) == -1)
449                 {
450                         /* crap, we're out of sockets... */
451                         return;
452                 }
453
454                 insp_sockaddr addr;
455
456 #ifdef IPV6
457                 insp_aton("::1", &addr.sin6_addr);
458                 addr.sin6_family = AF_FAMILY;
459                 addr.sin6_port = htons(resultnotify->GetPort());
460 #else
461                 insp_inaddr ia;
462                 insp_aton("127.0.0.1", &ia);
463                 addr.sin_family = AF_FAMILY;
464                 addr.sin_addr = ia;
465                 addr.sin_port = htons(resultnotify->GetPort());
466 #endif
467
468                 if (connect(QueueFD, (sockaddr*)&addr,sizeof(addr)) == -1)
469                 {
470                         /* wtf, we cant connect to it, but we just created it! */
471                         return;
472                 }
473         }
474
475 };
476
477
478 class ModuleSQLite3 : public Module
479 {
480   private:
481         ConnMap connections;
482         unsigned long currid;
483
484   public:
485         ModuleSQLite3(InspIRCd* Me)
486         : Module::Module(Me), currid(0)
487         {
488                 ServerInstance->UseInterface("SQLutils");
489
490                 if (!ServerInstance->PublishFeature("SQL", this))
491                 {
492                         throw ModuleException("m_sqlite3: Unable to publish feature 'SQL'");
493                 }
494
495                 resultnotify = new ResultNotifier(ServerInstance, this);
496                 ServerInstance->Log(DEBUG,"Bound notifier to 127.0.0.1:%d",resultnotify->GetPort());
497
498                 ReadConf();
499
500                 ServerInstance->PublishInterface("SQL", this);
501         }
502
503         virtual ~ModuleSQLite3()
504         {
505                 ClearQueue();
506                 ClearAllConnections();
507                 resultnotify->SetFd(-1);
508                 resultnotify->state = I_ERROR;
509                 resultnotify->OnError(I_ERR_SOCKET);
510                 resultnotify->ClosePending = true;
511                 delete resultnotify;
512                 ServerInstance->UnpublishInterface("SQL", this);
513                 ServerInstance->UnpublishFeature("SQL");
514                 ServerInstance->DoneWithInterface("SQLutils");
515         }
516
517         void Implements(char* List)
518         {
519                 List[I_OnRequest] = List[I_OnRehash] = 1;
520         }
521
522         void SendQueue()
523         {
524                 for (ConnMap::iterator iter = connections.begin(); iter != connections.end(); iter++)
525                 {
526                         iter->second->SendResults();
527                 }
528         }
529
530         void ClearQueue()
531         {
532                 for (ConnMap::iterator iter = connections.begin(); iter != connections.end(); iter++)
533                 {
534                         iter->second->ClearResults();
535                 }
536         }
537
538         bool HasHost(const SQLhost &host)
539         {
540                 for (ConnMap::iterator iter = connections.begin(); iter != connections.end(); iter++)
541                 {
542                         if (host == iter->second->GetConfHost())
543                                 return true;
544                 }
545                 return false;
546         }
547
548         bool HostInConf(const SQLhost &h)
549         {
550                 ConfigReader conf(ServerInstance);
551                 for(int i = 0; i < conf.Enumerate("database"); i++)
552                 {
553                         SQLhost host;
554                         host.id         = conf.ReadValue("database", "id", i);
555                         host.host       = conf.ReadValue("database", "hostname", i);
556                         host.port       = conf.ReadInteger("database", "port", i, true);
557                         host.name       = conf.ReadValue("database", "name", i);
558                         host.user       = conf.ReadValue("database", "username", i);
559                         host.pass       = conf.ReadValue("database", "password", i);
560                         host.ssl        = conf.ReadFlag("database", "ssl", "0", i);
561                         if (h == host)
562                                 return true;
563                 }
564                 return false;
565         }
566
567         void ReadConf()
568         {
569                 ClearOldConnections();
570
571                 ConfigReader conf(ServerInstance);
572                 for(int i = 0; i < conf.Enumerate("database"); i++)
573                 {
574                         SQLhost host;
575
576                         host.id         = conf.ReadValue("database", "id", i);
577                         host.host       = conf.ReadValue("database", "hostname", i);
578                         host.port       = conf.ReadInteger("database", "port", i, true);
579                         host.name       = conf.ReadValue("database", "name", i);
580                         host.user       = conf.ReadValue("database", "username", i);
581                         host.pass       = conf.ReadValue("database", "password", i);
582                         host.ssl        = conf.ReadFlag("database", "ssl", "0", i);
583
584                         if (HasHost(host))
585                                 continue;
586
587                         this->AddConn(host);
588                 }
589         }
590
591         void AddConn(const SQLhost& hi)
592         {
593                 if (HasHost(hi))
594                 {
595                         ServerInstance->Log(DEFAULT, "WARNING: A sqlite connection with id: %s already exists. Aborting database open attempt.", hi.id.c_str());
596                         return;
597                 }
598
599                 SQLConn* newconn;
600
601                 newconn = new SQLConn(ServerInstance, this, hi);
602
603                 connections.insert(std::make_pair(hi.id, newconn));
604         }
605
606         void ClearOldConnections()
607         {
608                 ConnMap::iterator iter,safei;
609                 for (iter = connections.begin(); iter != connections.end(); iter++)
610                 {
611                         if (!HostInConf(iter->second->GetConfHost()))
612                         {
613                                 DELETE(iter->second);
614                                 safei = iter;
615                                 --iter;
616                                 connections.erase(safei);
617                         }
618                 }
619         }
620
621         void ClearAllConnections()
622         {
623                 ConnMap::iterator i;
624                 while ((i = connections.begin()) != connections.end())
625                 {
626                         connections.erase(i);
627                         DELETE(i->second);
628                 }
629         }
630
631         virtual void OnRehash(userrec* user, const std::string &parameter)
632         {
633                 ReadConf();
634         }
635
636         virtual char* OnRequest(Request* request)
637         {
638                 if(strcmp(SQLREQID, request->GetId()) == 0)
639                 {
640                         SQLrequest* req = (SQLrequest*)request;
641                         ConnMap::iterator iter;
642                         ServerInstance->Log(DEBUG, "Got query: '%s' with %d replacement parameters on id '%s'", req->query.q.c_str(), req->query.p.size(), req->dbid.c_str());
643                         if((iter = connections.find(req->dbid)) != connections.end())
644                         {
645                                 req->id = NewID();
646                                 req->error = iter->second->Query(*req);
647                                 return SQLSUCCESS;
648                         }
649                         else
650                         {
651                                 req->error.Id(BAD_DBID);
652                                 return NULL;
653                         }
654                 }
655                 ServerInstance->Log(DEBUG, "Got unsupported API version string: %s", request->GetId());
656                 return NULL;
657         }
658
659         unsigned long NewID()
660         {
661                 if (currid+1 == 0)
662                         currid++;
663
664                 return ++currid;
665         }
666
667         virtual Version GetVersion()
668         {
669                 return Version(1,1,0,0,VF_VENDOR|VF_SERVICEPROVIDER,API_VERSION);
670         }
671
672 };
673
674 void ResultNotifier::Dispatch()
675 {
676         ((ModuleSQLite3*)mod)->SendQueue();
677 }
678
679 class ModuleSQLite3Factory : public ModuleFactory
680 {
681   public:
682         ModuleSQLite3Factory()
683         {
684         }
685
686         ~ModuleSQLite3Factory()
687         {
688         }
689
690         virtual Module * CreateModule(InspIRCd* Me)
691         {
692                 return new ModuleSQLite3(Me);
693         }
694 };
695
696 extern "C" void * init_module( void )
697 {
698         return new ModuleSQLite3Factory;
699 }