]> git.netwichtig.de Git - user/henk/code/inspircd.git/blob - src/modules/extra/m_sqlite3.cpp
Now with SQLite3 support. Fully functional and (hopefully) working.
[user/henk/code/inspircd.git] / src / modules / extra / m_sqlite3.cpp
1 /*       +------------------------------------+
2  *       | Inspire Internet Relay Chat Daemon |
3  *       +------------------------------------+
4  *
5  *  InspIRCd: (C) 2002-2007 InspIRCd Development Team
6  * See: http://www.inspircd.org/wiki/index.php/Credits
7  *
8  * This program is free but copyrighted software; see
9  *            the file COPYING for details.
10  *
11  * ---------------------------------------------------
12  */
13
#include <string>
#include <deque>
#include <map>
#include <unistd.h>
#include <sqlite3.h>

#include "users.h"
#include "channels.h"
#include "modules.h"
#include "inspircd.h"
#include "configreader.h"

#include "m_sqlv2.h"
26
27 /* $ModDesc: sqlite3 provider */
28 /* $CompileFlags: pkgconfincludes("sqlite3","/sqlite3.h","") */
29 /* $LinkerFlags: pkgconflibs("sqlite3","/libsqlite3.so","-lsqlite3") */
30 /* $ModDep: m_sqlv2.h */
31
32
33 class SQLConn;
34 class SQLite3Result;
35 class ResultNotifier;
36
37 typedef std::map<std::string, SQLConn*> ConnMap;
38 typedef std::deque<classbase*> paramlist;
39 typedef std::deque<SQLite3Result*> ResultQueue;
40
41 ResultNotifier* resultnotify = NULL;
42
43
44 class ResultNotifier : public InspSocket
45 {
46         Module* mod;
47         insp_sockaddr sock_us;
48         socklen_t uslen;
49
50  public:
51         /* Create a socket on a random port. Let the tcp stack allocate us an available port */
52 #ifdef IPV6
53         ResultNotifier(InspIRCd* SI, Module* m) : InspSocket(SI, "::1", 0, true, 3000), mod(m)
54 #else
55         ResultNotifier(InspIRCd* SI, Module* m) : InspSocket(SI, "127.0.0.1", 0, true, 3000), mod(m)
56 #endif
57         {
58                 uslen = sizeof(sock_us);
59                 if (getsockname(this->fd,(sockaddr*)&sock_us,&uslen))
60                 {
61                         throw ModuleException("Could not create random listening port on localhost");
62                 }
63         }
64
65         ResultNotifier(InspIRCd* SI, Module* m, int newfd, char* ip) : InspSocket(SI, newfd, ip), mod(m)
66         {
67                 Instance->Log(DEBUG,"Constructor of new socket");
68         }
69
70         /* Using getsockname and ntohs, we can determine which port number we were allocated */
71         int GetPort()
72         {
73 #ifdef IPV6
74                 return ntohs(sock_us.sin6_port);
75 #else
76                 return ntohs(sock_us.sin_port);
77 #endif
78         }
79
80         virtual int OnIncomingConnection(int newsock, char* ip)
81         {
82                 Instance->Log(DEBUG,"Inbound connection on fd %d!",newsock);
83                 Dispatch();
84                 return false;
85         }
86
87         void Dispatch();
88 };
89
90
91 class SQLite3Result : public SQLresult
92 {
93   private:
94         int currentrow;
95         int rows;
96         int cols;
97
98         std::vector<std::string> colnames;
99         std::vector<SQLfieldList> fieldlists;
100         SQLfieldList emptyfieldlist;
101
102         SQLfieldList* fieldlist;
103         SQLfieldMap* fieldmap;
104
105   public:
106         SQLite3Result(Module* self, Module* to, unsigned int id)
107         : SQLresult(self, to, id), currentrow(0), rows(0), cols(0), fieldlist(NULL), fieldmap(NULL)
108         {
109         }
110
111         ~SQLite3Result()
112         {
113         }
114
115         void AddRow(int colsnum, char **data, char **colname)
116         {
117                 colnames.clear();
118                 cols = colsnum;
119                 for (int i = 0; i < colsnum; i++)
120                 {
121                         fieldlists.resize(fieldlists.size()+1);
122                         colnames.push_back(colname[i]);
123                         SQLfield sf(data[i] ? data[i] : "", data[i] ? false : true);
124                         fieldlists[rows].push_back(sf);\r
125                 }
126                 rows++;
127         }
128
129         virtual int Rows()
130         {
131                 return rows;
132         }
133
134         virtual int Cols()
135         {
136                 return cols;
137         }
138
139         virtual std::string ColName(int column)
140         {
141                 if (column < (int)colnames.size())
142                 {
143                         return colnames[column];
144                 }
145                 else
146                 {
147                         throw SQLbadColName();
148                 }
149                 return "";
150         }
151
152         virtual int ColNum(const std::string &column)
153         {
154                 for (unsigned int i = 0; i < colnames.size(); i++)
155                 {
156                         if (column == colnames[i])
157                                 return i;
158                 }
159                 throw SQLbadColName();
160                 return 0;
161         }
162
163         virtual SQLfield GetValue(int row, int column)
164         {
165                 if ((row >= 0) && (row < rows) && (column >= 0) && (column < Cols()))
166                 {
167                         return fieldlists[row][column];
168                 }
169
170                 throw SQLbadColName();
171
172                 /* XXX: We never actually get here because of the throw */
173                 return SQLfield("",true);
174         }
175
176         virtual SQLfieldList& GetRow()
177         {
178                 if (currentrow < rows)
179                         return fieldlists[currentrow];
180                 else
181                         return emptyfieldlist;
182         }
183
184         virtual SQLfieldMap& GetRowMap()
185         {
186                 /* In an effort to reduce overhead we don't actually allocate the map
187                  * until the first time it's needed...so...
188                  */
189                 if(fieldmap)
190                 {
191                         fieldmap->clear();
192                 }
193                 else
194                 {
195                         fieldmap = new SQLfieldMap;
196                 }
197
198                 if (currentrow < rows)
199                 {
200                         for (int i = 0; i < Cols(); i++)
201                         {
202                                 fieldmap->insert(std::make_pair(ColName(i), GetValue(currentrow, i)));
203                         }
204                         currentrow++;
205                 }
206
207                 return *fieldmap;
208         }
209
210         virtual SQLfieldList* GetRowPtr()
211         {
212                 fieldlist = new SQLfieldList();
213
214                 if (currentrow < rows)
215                 {
216                         for (int i = 0; i < Rows(); i++)
217                         {
218                                 fieldlist->push_back(fieldlists[currentrow][i]);
219                         }
220                         currentrow++;
221                 }
222                 return fieldlist;
223         }
224
225         virtual SQLfieldMap* GetRowMapPtr()
226         {
227                 fieldmap = new SQLfieldMap();
228
229                 if (currentrow < rows)
230                 {
231                         for (int i = 0; i < Cols(); i++)
232                         {
233                                 fieldmap->insert(std::make_pair(colnames[i],GetValue(currentrow, i)));
234                         }
235                         currentrow++;
236                 }
237
238                 return fieldmap;
239         }
240
241         virtual void Free(SQLfieldMap* fm)
242         {
243                 delete fm;
244         }
245
246         virtual void Free(SQLfieldList* fl)
247         {
248                 delete fl;
249         }
250
251
252 };
253
254 class SQLConn : public classbase
255 {
256   private:
257         ResultQueue results;
258         InspIRCd* Instance;
259         Module* mod;
260         SQLhost host;
261         sqlite3* conn;
262
263   public:
264         SQLConn(InspIRCd* SI, Module* m, const SQLhost& hi)
265         : Instance(SI), mod(m), host(hi)
266         {
267                 int result;
268                 if ((result = OpenDB()) == SQLITE_OK)
269                 {
270                         Instance->Log(DEBUG, "Opened sqlite DB: " + host.host);
271                 }
272                 else
273                 {
274                         Instance->Log(DEFAULT, "WARNING: Could not open DB with id: " + host.id);
275                         CloseDB();
276                 }
277         }
278
279         ~SQLConn()
280         {
281                 CloseDB();
282         }
283
284         SQLerror Query(SQLrequest &req)
285         {
286                 /* Pointer to the buffer we screw around with substitution in */
287                 char* query;
288
289                 /* Pointer to the current end of query, where we append new stuff */
290                 char* queryend;
291
292                 /* Total length of the unescaped parameters */
293                 unsigned long paramlen;
294
295                 /* Total length of query, used for binary-safety in mysql_real_query */
296                 unsigned long querylength = 0;
297
298                 paramlen = 0;
299                 for(ParamL::iterator i = req.query.p.begin(); i != req.query.p.end(); i++)
300                 {
301                         paramlen += i->size();
302                 }
303
304                 /* To avoid a lot of allocations, allocate enough memory for the biggest the escaped query could possibly be.
305                  * sizeofquery + (totalparamlength*2) + 1
306                  *
307                  * The +1 is for null-terminating the string for mysql_real_escape_string
308                  */
309
310                 query = new char[req.query.q.length() + (paramlen*2) + 1];
311                 queryend = query;
312
313                 for(unsigned long i = 0; i < req.query.q.length(); i++)
314                 {
315                         if(req.query.q[i] == '?')
316                         {
317                                 if(req.query.p.size())
318                                 {
319                                         char* escaped;
320                                         escaped = sqlite3_mprintf("%q", req.query.p.front().c_str());
321                                         for (char* n = escaped; *n; n++)
322                                         {
323                                                 *queryend = *n;
324                                                 queryend++;
325                                         }
326                                         sqlite3_free(escaped);
327                                         req.query.p.pop_front();
328                                 }
329                                 else
330                                         break;
331                         }
332                         else
333                         {
334                                 *queryend = req.query.q[i];
335                                 queryend++;
336                         }
337                         querylength++;
338                 }
339                 *queryend = 0;
340                 req.query.q = query;
341
342 //              Instance->Log(DEBUG, "<******> Doing query: " + ConvToStr(req.query.q.data()));
343
344                 SQLite3Result* res = new SQLite3Result(mod, req.GetSource(), req.id);
345                 res->dbid = host.id;
346                 res->query = req.query.q;
347                 paramlist params;
348                 params.push_back(this);
349                 params.push_back(res);
350
351                 char *errmsg = 0;
352                 if (sqlite3_exec(conn, req.query.q.data(), QueryResult, &params, &errmsg) != SQLITE_OK)
353                 {
354                         Instance->Log(DEBUG, "Query failed: " + ConvToStr(errmsg));
355                         sqlite3_free(errmsg);
356                         delete[] query;
357                         delete res;
358                         return SQLerror(QSEND_FAIL, ConvToStr(errmsg));
359                 }
360                 Instance->Log(DEBUG, "Dispatched query successfully. ID: %d resulting rows %d", req.id, res->Rows());
361                 delete[] query;
362
363                 results.push_back(res);
364                 SendNotify();
365                 return SQLerror();
366         }
367
368         static int QueryResult(void *params, int argc, char **argv, char **azColName)
369         {
370                 paramlist* p = (paramlist*)params;
371                 ((SQLConn*)(*p)[0])->ResultReady(((SQLite3Result*)(*p)[1]), argc, argv, azColName);
372                 return 0;
373         }
374
375         void ResultReady(SQLite3Result *res, int cols, char **data, char **colnames)
376         {
377                 res->AddRow(cols, data, colnames);
378         }
379
380         int OpenDB()
381         {
382                 return sqlite3_open(host.host.c_str(), &conn);
383         }
384
385         void CloseDB()
386         {
387                 sqlite3_interrupt(conn);
388                 sqlite3_close(conn);
389         }
390
391         SQLhost GetConfHost()
392         {
393                 return host;
394         }
395
396         void SendResults()
397         {
398                 while (results.size())
399                 {
400                         SQLite3Result* res = results[0];
401                         if (res->GetDest())
402                         {
403                                 res->Send();
404                         }
405                         else
406                         {
407                                 /* If the client module is unloaded partway through a query then the provider will set
408                                  * the pointer to NULL. We cannot just cancel the query as the result will still come
409                                  * through at some point...and it could get messy if we play with invalid pointers...
410                                  */
411                                 Instance->Log(DEBUG, "Looks like we're handling a zombie query from a module which unloaded before it got a result..fun. ID: " + ConvToStr(res->GetId()));
412                                 delete res;
413                         }
414                         results.pop_front();
415                 }
416         }
417
418         void ClearResults()
419         {
420                 while (results.size())
421                 {
422                         SQLite3Result* res = results[0];
423                         delete res;
424                         results.pop_front();
425                 }
426         }
427
428         void SendNotify()
429         {
430                 int QueueFD;
431                 if ((QueueFD = socket(AF_FAMILY, SOCK_STREAM, 0)) == -1)
432                 {
433                         /* crap, we're out of sockets... */
434                         return;
435                 }
436
437                 insp_sockaddr addr;
438
439 #ifdef IPV6
440                 insp_aton("::1", &addr.sin6_addr);
441                 addr.sin6_family = AF_FAMILY;
442                 addr.sin6_port = htons(resultnotify->GetPort());
443 #else
444                 insp_inaddr ia;
445                 insp_aton("127.0.0.1", &ia);
446                 addr.sin_family = AF_FAMILY;
447                 addr.sin_addr = ia;
448                 addr.sin_port = htons(resultnotify->GetPort());
449 #endif
450
451                 if (connect(QueueFD, (sockaddr*)&addr,sizeof(addr)) == -1)
452                 {
453                         /* wtf, we cant connect to it, but we just created it! */
454                         return;
455                 }
456                 send(QueueFD, "\n", 2, 0);
457         }
458
459 };
460
461
462 class ModuleSQLite3 : public Module
463 {
464   private:
465         ConnMap connections;
466         unsigned long currid;
467
468   public:
469         ModuleSQLite3(InspIRCd* Me)
470         : Module::Module(Me), currid(0)
471         {
472                 ServerInstance->UseInterface("SQLutils");
473
474                 if (!ServerInstance->PublishFeature("SQL", this))
475                 {
476                         throw ModuleException("m_sqlite3: Unable to publish feature 'SQL'");
477                 }
478
479                 resultnotify = new ResultNotifier(ServerInstance, this);
480                 ServerInstance->Log(DEBUG,"Bound notifier to 127.0.0.1:%d",resultnotify->GetPort());
481
482                 ReadConf();
483
484                 ServerInstance->PublishInterface("SQL", this);
485         }
486
487         virtual ~ModuleSQLite3()
488         {
489                 ClearQueue();
490                 ClearAllConnections();
491                 resultnotify->SetFd(-1);
492                 resultnotify->state = I_ERROR;
493                 resultnotify->OnError(I_ERR_SOCKET);
494                 resultnotify->ClosePending = true;
495                 if (!ServerInstance->SE->DelFd(resultnotify))
496                 {
497                         ServerInstance->Log(DEBUG, "m_sqlite3: unable to remove notifier from socket engine!");
498                 }
499                 delete resultnotify;
500                 ServerInstance->UnpublishInterface("SQL", this);
501                 ServerInstance->UnpublishFeature("SQL");
502                 ServerInstance->DoneWithInterface("SQLutils");
503         }
504
505         void Implements(char* List)
506         {
507                 List[I_OnRequest] = List[I_OnRequest] = 1;
508         }
509
510         void SendQueue()
511         {
512                 for (ConnMap::iterator iter = connections.begin(); iter != connections.end(); iter++)
513                 {
514                         iter->second->SendResults();
515                 }
516         }
517
518         void ClearQueue()
519         {
520                 for (ConnMap::iterator iter = connections.begin(); iter != connections.end(); iter++)
521                 {
522                         iter->second->ClearResults();
523                 }
524         }
525
526         bool HasHost(const SQLhost &host)
527         {
528                 for (ConnMap::iterator iter = connections.begin(); iter != connections.end(); iter++)
529                 {
530                         if (host == iter->second->GetConfHost())
531                                 return true;
532                 }
533                 return false;
534         }
535
536         bool HostInConf(const SQLhost &h)
537         {
538                 ConfigReader conf(ServerInstance);
539                 for(int i = 0; i < conf.Enumerate("database"); i++)
540                 {
541                         SQLhost host;
542                         host.id         = conf.ReadValue("database", "id", i);
543                         host.host       = conf.ReadValue("database", "hostname", i);
544                         host.port       = conf.ReadInteger("database", "port", i, true);
545                         host.name       = conf.ReadValue("database", "name", i);
546                         host.user       = conf.ReadValue("database", "username", i);
547                         host.pass       = conf.ReadValue("database", "password", i);
548                         host.ssl        = conf.ReadFlag("database", "ssl", "0", i);
549                         if (h == host)
550                                 return true;
551                 }
552                 return false;
553         }
554
555         void ReadConf()
556         {
557                 ClearOldConnections();
558
559                 ConfigReader conf(ServerInstance);
560                 for(int i = 0; i < conf.Enumerate("database"); i++)
561                 {
562                         SQLhost host;
563
564                         host.id         = conf.ReadValue("database", "id", i);
565                         host.host       = conf.ReadValue("database", "hostname", i);
566                         host.port       = conf.ReadInteger("database", "port", i, true);
567                         host.name       = conf.ReadValue("database", "name", i);
568                         host.user       = conf.ReadValue("database", "username", i);
569                         host.pass       = conf.ReadValue("database", "password", i);
570                         host.ssl        = conf.ReadFlag("database", "ssl", "0", i);
571
572                         if (HasHost(host))
573                                 continue;
574
575                         this->AddConn(host);
576                 }
577         }
578
579         void AddConn(const SQLhost& hi)
580         {
581                 if (HasHost(hi))
582                 {
583                         ServerInstance->Log(DEFAULT, "WARNING: A sqlite connection with id: %s already exists. Aborting database open attempt.", hi.id.c_str());
584                         return;
585                 }
586
587                 SQLConn* newconn;
588
589                 newconn = new SQLConn(ServerInstance, this, hi);
590
591                 connections.insert(std::make_pair(hi.id, newconn));
592         }
593
594         void ClearOldConnections()
595         {
596                 ConnMap::iterator iter,safei;
597                 for (iter = connections.begin(); iter != connections.end(); iter++)
598                 {
599                         if (!HostInConf(iter->second->GetConfHost()))
600                         {
601                                 DELETE(iter->second);
602                                 safei = iter;
603                                 --iter;
604                                 connections.erase(safei);
605                         }
606                 }
607         }
608
609         void ClearAllConnections()
610         {
611                 ConnMap::iterator i;
612                 while ((i = connections.begin()) != connections.end())
613                 {
614                         connections.erase(i);
615                         DELETE(i->second);
616                 }
617         }
618
619         virtual void OnRehash(userrec* user, const std::string &parameter)
620         {
621                 ReadConf();
622         }
623
624         virtual char* OnRequest(Request* request)
625         {
626                 if(strcmp(SQLREQID, request->GetId()) == 0)
627                 {
628                         SQLrequest* req = (SQLrequest*)request;
629                         ConnMap::iterator iter;
630                         ServerInstance->Log(DEBUG, "Got query: '%s' with %d replacement parameters on id '%s'", req->query.q.c_str(), req->query.p.size(), req->dbid.c_str());
631                         if((iter = connections.find(req->dbid)) != connections.end())
632                         {
633                                 req->id = NewID();
634                                 req->error = iter->second->Query(*req);
635                                 return SQLSUCCESS;
636                         }
637                         else
638                         {
639                                 req->error.Id(BAD_DBID);
640                                 return NULL;
641                         }
642                 }
643                 ServerInstance->Log(DEBUG, "Got unsupported API version string: %s", request->GetId());
644                 return NULL;
645         }
646
647         unsigned long NewID()
648         {
649                 if (currid+1 == 0)
650                         currid++;
651
652                 return ++currid;
653         }
654
655         virtual Version GetVersion()
656         {
657                 return Version(1,1,0,0,VF_VENDOR|VF_SERVICEPROVIDER,API_VERSION);
658         }
659
660 };
661
662 void ResultNotifier::Dispatch()
663 {
664         ((ModuleSQLite3*)mod)->SendQueue();
665 }
666
667 class ModuleSQLite3Factory : public ModuleFactory
668 {
669   public:
670         ModuleSQLite3Factory()
671         {
672         }
673
674         ~ModuleSQLite3Factory()
675         {
676         }
677
678         virtual Module * CreateModule(InspIRCd* Me)
679         {
680                 return new ModuleSQLite3(Me);
681         }
682 };
683
684 extern "C" void * init_module( void )
685 {
686         return new ModuleSQLite3Factory;
687 }