]> git.netwichtig.de Git - user/henk/code/inspircd.git/blob - src/modules/extra/m_sqlite3.cpp
Remove/fix unused variable warning
[user/henk/code/inspircd.git] / src / modules / extra / m_sqlite3.cpp
1 /*       +------------------------------------+
2  *       | Inspire Internet Relay Chat Daemon |
3  *       +------------------------------------+
4  *
5  *  InspIRCd: (C) 2002-2007 InspIRCd Development Team
6  * See: http://www.inspircd.org/wiki/index.php/Credits
7  *
8  * This program is free but copyrighted software; see
9  *            the file COPYING for details.
10  *
11  * ---------------------------------------------------
12  */
13
#include <string>
#include <deque>
#include <map>
#include <vector>
#include <cstring>
#include <unistd.h>
#include <sqlite3.h>

#include "users.h"
#include "channels.h"
#include "modules.h"
#include "inspircd.h"
#include "configreader.h"

#include "m_sqlv2.h"
26
27 /* $ModDesc: sqlite3 provider */
28 /* $CompileFlags: pkgconfincludes("sqlite3","/sqlite3.h","") */
29 /* $LinkerFlags: pkgconflibs("sqlite3","/libsqlite3.so","-lsqlite3") */
30 /* $ModDep: m_sqlv2.h */
31
32
33 class SQLConn;
34 class SQLite3Result;
35 class ResultNotifier;
36
37 typedef std::map<std::string, SQLConn*> ConnMap;
38 typedef std::deque<classbase*> paramlist;
39 typedef std::deque<SQLite3Result*> ResultQueue;
40
41 ResultNotifier* resultnotify = NULL;
42
43
44 class ResultNotifier : public InspSocket
45 {
46         Module* mod;
47         insp_sockaddr sock_us;
48         socklen_t uslen;
49
50  public:
51         /* Create a socket on a random port. Let the tcp stack allocate us an available port */
52 #ifdef IPV6
53         ResultNotifier(InspIRCd* SI, Module* m) : InspSocket(SI, "::1", 0, true, 3000), mod(m)
54 #else
55         ResultNotifier(InspIRCd* SI, Module* m) : InspSocket(SI, "127.0.0.1", 0, true, 3000), mod(m)
56 #endif
57         {
58                 uslen = sizeof(sock_us);
59                 if (getsockname(this->fd,(sockaddr*)&sock_us,&uslen))
60                 {
61                         throw ModuleException("Could not create random listening port on localhost");
62                 }
63         }
64
65         ResultNotifier(InspIRCd* SI, Module* m, int newfd, char* ip) : InspSocket(SI, newfd, ip), mod(m)
66         {
67                 Instance->Log(DEBUG,"Constructor of new socket");
68         }
69
70         /* Using getsockname and ntohs, we can determine which port number we were allocated */
71         int GetPort()
72         {
73 #ifdef IPV6
74                 return ntohs(sock_us.sin6_port);
75 #else
76                 return ntohs(sock_us.sin_port);
77 #endif
78         }
79
80         virtual int OnIncomingConnection(int newsock, char* ip)
81         {
82                 Instance->Log(DEBUG,"Inbound connection on fd %d!",newsock);
83                 Dispatch();
84                 return false;
85         }
86
87         void Dispatch();
88 };
89
90
91 class SQLite3Result : public SQLresult
92 {
93   private:
94         int currentrow;
95         int rows;
96         int cols;
97
98         std::vector<std::string> colnames;
99         std::vector<SQLfieldList> fieldlists;
100         SQLfieldList emptyfieldlist;
101
102         SQLfieldList* fieldlist;
103         SQLfieldMap* fieldmap;
104
105   public:
106         SQLite3Result(Module* self, Module* to, unsigned int id)
107         : SQLresult(self, to, id), currentrow(0), rows(0), cols(0), fieldlist(NULL), fieldmap(NULL)
108         {
109         }
110
111         ~SQLite3Result()
112         {
113         }
114
115         void AddRow(int colsnum, char **data, char **colname)
116         {
117                 colnames.clear();
118                 cols = colsnum;
119                 for (int i = 0; i < colsnum; i++)
120                 {
121                         fieldlists.resize(fieldlists.size()+1);
122                         colnames.push_back(colname[i]);
123                         SQLfield sf(data[i] ? data[i] : "", data[i] ? false : true);
124                         fieldlists[rows].push_back(sf);
125                 }
126                 rows++;
127         }
128
129         void UpdateAffectedCount()
130         {
131                 rows++;
132         }
133
134         virtual int Rows()
135         {
136                 return rows;
137         }
138
139         virtual int Cols()
140         {
141                 return cols;
142         }
143
144         virtual std::string ColName(int column)
145         {
146                 if (column < (int)colnames.size())
147                 {
148                         return colnames[column];
149                 }
150                 else
151                 {
152                         throw SQLbadColName();
153                 }
154                 return "";
155         }
156
157         virtual int ColNum(const std::string &column)
158         {
159                 for (unsigned int i = 0; i < colnames.size(); i++)
160                 {
161                         if (column == colnames[i])
162                                 return i;
163                 }
164                 throw SQLbadColName();
165                 return 0;
166         }
167
168         virtual SQLfield GetValue(int row, int column)
169         {
170                 if ((row >= 0) && (row < rows) && (column >= 0) && (column < Cols()))
171                 {
172                         return fieldlists[row][column];
173                 }
174
175                 throw SQLbadColName();
176
177                 /* XXX: We never actually get here because of the throw */
178                 return SQLfield("",true);
179         }
180
181         virtual SQLfieldList& GetRow()
182         {
183                 if (currentrow < rows)
184                         return fieldlists[currentrow];
185                 else
186                         return emptyfieldlist;
187         }
188
189         virtual SQLfieldMap& GetRowMap()
190         {
191                 /* In an effort to reduce overhead we don't actually allocate the map
192                  * until the first time it's needed...so...
193                  */
194                 if(fieldmap)
195                 {
196                         fieldmap->clear();
197                 }
198                 else
199                 {
200                         fieldmap = new SQLfieldMap;
201                 }
202
203                 if (currentrow < rows)
204                 {
205                         for (int i = 0; i < Cols(); i++)
206                         {
207                                 fieldmap->insert(std::make_pair(ColName(i), GetValue(currentrow, i)));
208                         }
209                         currentrow++;
210                 }
211
212                 return *fieldmap;
213         }
214
215         virtual SQLfieldList* GetRowPtr()
216         {
217                 fieldlist = new SQLfieldList();
218
219                 if (currentrow < rows)
220                 {
221                         for (int i = 0; i < Rows(); i++)
222                         {
223                                 fieldlist->push_back(fieldlists[currentrow][i]);
224                         }
225                         currentrow++;
226                 }
227                 return fieldlist;
228         }
229
230         virtual SQLfieldMap* GetRowMapPtr()
231         {
232                 fieldmap = new SQLfieldMap();
233
234                 if (currentrow < rows)
235                 {
236                         for (int i = 0; i < Cols(); i++)
237                         {
238                                 fieldmap->insert(std::make_pair(colnames[i],GetValue(currentrow, i)));
239                         }
240                         currentrow++;
241                 }
242
243                 return fieldmap;
244         }
245
246         virtual void Free(SQLfieldMap* fm)
247         {
248                 delete fm;
249         }
250
251         virtual void Free(SQLfieldList* fl)
252         {
253                 delete fl;
254         }
255
256
257 };
258
259 class SQLConn : public classbase
260 {
261   private:
262         ResultQueue results;
263         InspIRCd* Instance;
264         Module* mod;
265         SQLhost host;
266         sqlite3* conn;
267
268   public:
269         SQLConn(InspIRCd* SI, Module* m, const SQLhost& hi)
270         : Instance(SI), mod(m), host(hi)
271         {
272                 if (OpenDB() == SQLITE_OK)
273                 {
274                         Instance->Log(DEBUG, "Opened sqlite DB: " + host.host);
275                 }
276                 else
277                 {
278                         Instance->Log(DEFAULT, "WARNING: Could not open DB with id: " + host.id);
279                         CloseDB();
280                 }
281         }
282
283         ~SQLConn()
284         {
285                 CloseDB();
286         }
287
288         SQLerror Query(SQLrequest &req)
289         {
290                 /* Pointer to the buffer we screw around with substitution in */
291                 char* query;
292
293                 /* Pointer to the current end of query, where we append new stuff */
294                 char* queryend;
295
296                 /* Total length of the unescaped parameters */
297                 unsigned long paramlen;
298
299                 /* Total length of query, used for binary-safety in mysql_real_query */
300                 unsigned long querylength = 0;
301
302                 paramlen = 0;
303                 for(ParamL::iterator i = req.query.p.begin(); i != req.query.p.end(); i++)
304                 {
305                         paramlen += i->size();
306                 }
307
308                 /* To avoid a lot of allocations, allocate enough memory for the biggest the escaped query could possibly be.
309                  * sizeofquery + (totalparamlength*2) + 1
310                  *
311                  * The +1 is for null-terminating the string for mysql_real_escape_string
312                  */
313                 query = new char[req.query.q.length() + (paramlen*2) + 1];
314                 queryend = query;
315
316                 for(unsigned long i = 0; i < req.query.q.length(); i++)
317                 {
318                         if(req.query.q[i] == '?')
319                         {
320                                 if(req.query.p.size())
321                                 {
322                                         char* escaped;
323                                         escaped = sqlite3_mprintf("%q", req.query.p.front().c_str());
324                                         for (char* n = escaped; *n; n++)
325                                         {
326                                                 *queryend = *n;
327                                                 queryend++;
328                                         }
329                                         sqlite3_free(escaped);
330                                         req.query.p.pop_front();
331                                 }
332                                 else
333                                         break;
334                         }
335                         else
336                         {
337                                 *queryend = req.query.q[i];
338                                 queryend++;
339                         }
340                         querylength++;
341                 }
342                 *queryend = 0;
343                 req.query.q = query;
344
345                 SQLite3Result* res = new SQLite3Result(mod, req.GetSource(), req.id);
346                 res->dbid = host.id;
347                 res->query = req.query.q;
348                 paramlist params;
349                 params.push_back(this);
350                 params.push_back(res);
351
352                 char *errmsg = 0;
353                 sqlite3_update_hook(conn, QueryUpdateHook, &params);
354                 if (sqlite3_exec(conn, req.query.q.data(), QueryResult, &params, &errmsg) != SQLITE_OK)
355                 {
356                         std::string error(errmsg);
357                         sqlite3_free(errmsg);
358                         Instance->Log(DEBUG, "Query failed: " + ConvToStr(errmsg));
359                         delete[] query;
360                         delete res;
361                         return SQLerror(QSEND_FAIL, error);
362                 }
363                 Instance->Log(DEBUG, "Dispatched query successfully. ID: %d resulting rows %d", req.id, res->Rows());
364                 delete[] query;
365
366                 results.push_back(res);
367                 SendNotify();
368                 return SQLerror();
369         }
370
371         static int QueryResult(void *params, int argc, char **argv, char **azColName)
372         {
373                 paramlist* p = (paramlist*)params;
374                 ((SQLConn*)(*p)[0])->ResultReady(((SQLite3Result*)(*p)[1]), argc, argv, azColName);
375                 return 0;
376         }
377
378         static void QueryUpdateHook(void *params, int eventid, char const * azSQLite, char const * azColName, sqlite_int64 rowid)
379         {
380                 paramlist* p = (paramlist*)params;
381                 ((SQLConn*)(*p)[0])->AffectedReady(((SQLite3Result*)(*p)[1]));
382         }\r
383
384         void ResultReady(SQLite3Result *res, int cols, char **data, char **colnames)
385         {
386                 res->AddRow(cols, data, colnames);
387         }
388
389         void AffectedReady(SQLite3Result *res)
390         {
391                 res->UpdateAffectedCount();
392         }
393
394         int OpenDB()
395         {
396                 return sqlite3_open(host.host.c_str(), &conn);
397         }
398
399         void CloseDB()
400         {
401                 sqlite3_interrupt(conn);
402                 sqlite3_close(conn);
403                 Instance->Log(DEBUG, "Closed sqlite DB: " + host.host);
404         }
405
406         SQLhost GetConfHost()
407         {
408                 return host;
409         }
410
411         void SendResults()
412         {
413                 while (results.size())
414                 {
415                         SQLite3Result* res = results[0];
416                         if (res->GetDest())
417                         {
418                                 res->Send();
419                         }
420                         else
421                         {
422                                 /* If the client module is unloaded partway through a query then the provider will set
423                                  * the pointer to NULL. We cannot just cancel the query as the result will still come
424                                  * through at some point...and it could get messy if we play with invalid pointers...
425                                  */
426                                 Instance->Log(DEBUG, "Looks like we're handling a zombie query from a module which unloaded before it got a result..fun. ID: " + ConvToStr(res->GetId()));
427                                 delete res;
428                         }
429                         results.pop_front();
430                 }
431         }
432
433         void ClearResults()
434         {
435                 while (results.size())
436                 {
437                         SQLite3Result* res = results[0];
438                         delete res;
439                         results.pop_front();
440                 }
441         }
442
443         void SendNotify()
444         {
445                 int QueueFD;
446                 if ((QueueFD = socket(AF_FAMILY, SOCK_STREAM, 0)) == -1)
447                 {
448                         /* crap, we're out of sockets... */
449                         return;
450                 }
451
452                 insp_sockaddr addr;
453
454 #ifdef IPV6
455                 insp_aton("::1", &addr.sin6_addr);
456                 addr.sin6_family = AF_FAMILY;
457                 addr.sin6_port = htons(resultnotify->GetPort());
458 #else
459                 insp_inaddr ia;
460                 insp_aton("127.0.0.1", &ia);
461                 addr.sin_family = AF_FAMILY;
462                 addr.sin_addr = ia;
463                 addr.sin_port = htons(resultnotify->GetPort());
464 #endif
465
466                 if (connect(QueueFD, (sockaddr*)&addr,sizeof(addr)) == -1)
467                 {
468                         /* wtf, we cant connect to it, but we just created it! */
469                         return;
470                 }
471         }
472
473 };
474
475
476 class ModuleSQLite3 : public Module
477 {
478   private:
479         ConnMap connections;
480         unsigned long currid;
481
482   public:
483         ModuleSQLite3(InspIRCd* Me)
484         : Module::Module(Me), currid(0)
485         {
486                 ServerInstance->UseInterface("SQLutils");
487
488                 if (!ServerInstance->PublishFeature("SQL", this))
489                 {
490                         throw ModuleException("m_sqlite3: Unable to publish feature 'SQL'");
491                 }
492
493                 resultnotify = new ResultNotifier(ServerInstance, this);
494                 ServerInstance->Log(DEBUG,"Bound notifier to 127.0.0.1:%d",resultnotify->GetPort());
495
496                 ReadConf();
497
498                 ServerInstance->PublishInterface("SQL", this);
499         }
500
501         virtual ~ModuleSQLite3()
502         {
503                 ClearQueue();
504                 ClearAllConnections();
505                 resultnotify->SetFd(-1);
506                 resultnotify->state = I_ERROR;
507                 resultnotify->OnError(I_ERR_SOCKET);
508                 resultnotify->ClosePending = true;
509                 delete resultnotify;
510                 ServerInstance->UnpublishInterface("SQL", this);
511                 ServerInstance->UnpublishFeature("SQL");
512                 ServerInstance->DoneWithInterface("SQLutils");
513         }
514
515         void Implements(char* List)
516         {
517                 List[I_OnRequest] = List[I_OnRehash] = 1;
518         }
519
520         void SendQueue()
521         {
522                 for (ConnMap::iterator iter = connections.begin(); iter != connections.end(); iter++)
523                 {
524                         iter->second->SendResults();
525                 }
526         }
527
528         void ClearQueue()
529         {
530                 for (ConnMap::iterator iter = connections.begin(); iter != connections.end(); iter++)
531                 {
532                         iter->second->ClearResults();
533                 }
534         }
535
536         bool HasHost(const SQLhost &host)
537         {
538                 for (ConnMap::iterator iter = connections.begin(); iter != connections.end(); iter++)
539                 {
540                         if (host == iter->second->GetConfHost())
541                                 return true;
542                 }
543                 return false;
544         }
545
546         bool HostInConf(const SQLhost &h)
547         {
548                 ConfigReader conf(ServerInstance);
549                 for(int i = 0; i < conf.Enumerate("database"); i++)
550                 {
551                         SQLhost host;
552                         host.id         = conf.ReadValue("database", "id", i);
553                         host.host       = conf.ReadValue("database", "hostname", i);
554                         host.port       = conf.ReadInteger("database", "port", i, true);
555                         host.name       = conf.ReadValue("database", "name", i);
556                         host.user       = conf.ReadValue("database", "username", i);
557                         host.pass       = conf.ReadValue("database", "password", i);
558                         host.ssl        = conf.ReadFlag("database", "ssl", "0", i);
559                         if (h == host)
560                                 return true;
561                 }
562                 return false;
563         }
564
565         void ReadConf()
566         {
567                 ClearOldConnections();
568
569                 ConfigReader conf(ServerInstance);
570                 for(int i = 0; i < conf.Enumerate("database"); i++)
571                 {
572                         SQLhost host;
573
574                         host.id         = conf.ReadValue("database", "id", i);
575                         host.host       = conf.ReadValue("database", "hostname", i);
576                         host.port       = conf.ReadInteger("database", "port", i, true);
577                         host.name       = conf.ReadValue("database", "name", i);
578                         host.user       = conf.ReadValue("database", "username", i);
579                         host.pass       = conf.ReadValue("database", "password", i);
580                         host.ssl        = conf.ReadFlag("database", "ssl", "0", i);
581
582                         if (HasHost(host))
583                                 continue;
584
585                         this->AddConn(host);
586                 }
587         }
588
589         void AddConn(const SQLhost& hi)
590         {
591                 if (HasHost(hi))
592                 {
593                         ServerInstance->Log(DEFAULT, "WARNING: A sqlite connection with id: %s already exists. Aborting database open attempt.", hi.id.c_str());
594                         return;
595                 }
596
597                 SQLConn* newconn;
598
599                 newconn = new SQLConn(ServerInstance, this, hi);
600
601                 connections.insert(std::make_pair(hi.id, newconn));
602         }
603
604         void ClearOldConnections()
605         {
606                 ConnMap::iterator iter,safei;
607                 for (iter = connections.begin(); iter != connections.end(); iter++)
608                 {
609                         if (!HostInConf(iter->second->GetConfHost()))
610                         {
611                                 DELETE(iter->second);
612                                 safei = iter;
613                                 --iter;
614                                 connections.erase(safei);
615                         }
616                 }
617         }
618
619         void ClearAllConnections()
620         {
621                 ConnMap::iterator i;
622                 while ((i = connections.begin()) != connections.end())
623                 {
624                         connections.erase(i);
625                         DELETE(i->second);
626                 }
627         }
628
629         virtual void OnRehash(userrec* user, const std::string &parameter)
630         {
631                 ReadConf();
632         }
633
634         virtual char* OnRequest(Request* request)
635         {
636                 if(strcmp(SQLREQID, request->GetId()) == 0)
637                 {
638                         SQLrequest* req = (SQLrequest*)request;
639                         ConnMap::iterator iter;
640                         ServerInstance->Log(DEBUG, "Got query: '%s' with %d replacement parameters on id '%s'", req->query.q.c_str(), req->query.p.size(), req->dbid.c_str());
641                         if((iter = connections.find(req->dbid)) != connections.end())
642                         {
643                                 req->id = NewID();
644                                 req->error = iter->second->Query(*req);
645                                 return SQLSUCCESS;
646                         }
647                         else
648                         {
649                                 req->error.Id(BAD_DBID);
650                                 return NULL;
651                         }
652                 }
653                 ServerInstance->Log(DEBUG, "Got unsupported API version string: %s", request->GetId());
654                 return NULL;
655         }
656
657         unsigned long NewID()
658         {
659                 if (currid+1 == 0)
660                         currid++;
661
662                 return ++currid;
663         }
664
665         virtual Version GetVersion()
666         {
667                 return Version(1,1,0,0,VF_VENDOR|VF_SERVICEPROVIDER,API_VERSION);
668         }
669
670 };
671
672 void ResultNotifier::Dispatch()
673 {
674         ((ModuleSQLite3*)mod)->SendQueue();
675 }
676
677 class ModuleSQLite3Factory : public ModuleFactory
678 {
679   public:
680         ModuleSQLite3Factory()
681         {
682         }
683
684         ~ModuleSQLite3Factory()
685         {
686         }
687
688         virtual Module * CreateModule(InspIRCd* Me)
689         {
690                 return new ModuleSQLite3(Me);
691         }
692 };
693
694 extern "C" void * init_module( void )
695 {
696         return new ModuleSQLite3Factory;
697 }