]> git.netwichtig.de Git - user/henk/code/inspircd.git/blob - src/modules/extra/m_sqlite3.cpp
4a46108ef5fb6063a790f72ba93f73e0941a60a2
[user/henk/code/inspircd.git] / src / modules / extra / m_sqlite3.cpp
1 /*       +------------------------------------+
2  *       | Inspire Internet Relay Chat Daemon |
3  *       +------------------------------------+
4  *
5  *  InspIRCd: (C) 2002-2007 InspIRCd Development Team
6  * See: http://www.inspircd.org/wiki/index.php/Credits
7  *
8  * This program is free but copyrighted software; see
9  *            the file COPYING for details.
10  *
11  * ---------------------------------------------------
12  */
13
#include <string>
#include <deque>
#include <map>
#include <vector>
#include <cstring>
#include <sqlite3.h>
#ifndef _WIN32
#include <unistd.h>
#endif

#include "users.h"
#include "channels.h"
#include "modules.h"
#include "inspircd.h"
#include "configreader.h"

#include "m_sqlv2.h"
26
27 /* $ModDesc: sqlite3 provider */
28 /* $CompileFlags: pkgconfincludes("sqlite3","/sqlite3.h","") */
29 /* $LinkerFlags: pkgconflibs("sqlite3","/libsqlite3.so","-lsqlite3") */
30 /* $ModDep: m_sqlv2.h */
31
32
33 class SQLConn;
34 class SQLite3Result;
35 class ResultNotifier;
36
37 typedef std::map<std::string, SQLConn*> ConnMap;
38 typedef std::deque<classbase*> paramlist;
39 typedef std::deque<SQLite3Result*> ResultQueue;
40
41 ResultNotifier* resultnotify = NULL;
42
43
44 class ResultNotifier : public InspSocket
45 {
46         Module* mod;
47         insp_sockaddr sock_us;
48         socklen_t uslen;
49
50  public:
51         /* Create a socket on a random port. Let the tcp stack allocate us an available port */
52 #ifdef IPV6
53         ResultNotifier(InspIRCd* SI, Module* m) : InspSocket(SI, "::1", 0, true, 3000), mod(m)
54 #else
55         ResultNotifier(InspIRCd* SI, Module* m) : InspSocket(SI, "127.0.0.1", 0, true, 3000), mod(m)
56 #endif
57         {
58                 uslen = sizeof(sock_us);
59                 if (getsockname(this->fd,(sockaddr*)&sock_us,&uslen))
60                 {
61                         throw ModuleException("Could not create random listening port on localhost");
62                 }
63         }
64
65         ResultNotifier(InspIRCd* SI, Module* m, int newfd, char* ip) : InspSocket(SI, newfd, ip), mod(m)
66         {
67         }
68
69         /* Using getsockname and ntohs, we can determine which port number we were allocated */
70         int GetPort()
71         {
72 #ifdef IPV6
73                 return ntohs(sock_us.sin6_port);
74 #else
75                 return ntohs(sock_us.sin_port);
76 #endif
77         }
78
79         virtual int OnIncomingConnection(int newsock, char* ip)
80         {
81                 Dispatch();
82                 return false;
83         }
84
85         void Dispatch();
86 };
87
88
89 class SQLite3Result : public SQLresult
90 {
91   private:
92         int currentrow;
93         int rows;
94         int cols;
95
96         std::vector<std::string> colnames;
97         std::vector<SQLfieldList> fieldlists;
98         SQLfieldList emptyfieldlist;
99
100         SQLfieldList* fieldlist;
101         SQLfieldMap* fieldmap;
102
103   public:
104         SQLite3Result(Module* self, Module* to, unsigned int id)
105         : SQLresult(self, to, id), currentrow(0), rows(0), cols(0), fieldlist(NULL), fieldmap(NULL)
106         {
107         }
108
109         ~SQLite3Result()
110         {
111         }
112
113         void AddRow(int colsnum, char **data, char **colname)
114         {
115                 colnames.clear();
116                 cols = colsnum;
117                 for (int i = 0; i < colsnum; i++)
118                 {
119                         fieldlists.resize(fieldlists.size()+1);
120                         colnames.push_back(colname[i]);
121                         SQLfield sf(data[i] ? data[i] : "", data[i] ? false : true);
122                         fieldlists[rows].push_back(sf);
123                 }
124                 rows++;
125         }
126
127         void UpdateAffectedCount()
128         {
129                 rows++;
130         }
131
132         virtual int Rows()
133         {
134                 return rows;
135         }
136
137         virtual int Cols()
138         {
139                 return cols;
140         }
141
142         virtual std::string ColName(int column)
143         {
144                 if (column < (int)colnames.size())
145                 {
146                         return colnames[column];
147                 }
148                 else
149                 {
150                         throw SQLbadColName();
151                 }
152                 return "";
153         }
154
155         virtual int ColNum(const std::string &column)
156         {
157                 for (unsigned int i = 0; i < colnames.size(); i++)
158                 {
159                         if (column == colnames[i])
160                                 return i;
161                 }
162                 throw SQLbadColName();
163                 return 0;
164         }
165
166         virtual SQLfield GetValue(int row, int column)
167         {
168                 if ((row >= 0) && (row < rows) && (column >= 0) && (column < Cols()))
169                 {
170                         return fieldlists[row][column];
171                 }
172
173                 throw SQLbadColName();
174
175                 /* XXX: We never actually get here because of the throw */
176                 return SQLfield("",true);
177         }
178
179         virtual SQLfieldList& GetRow()
180         {
181                 if (currentrow < rows)
182                         return fieldlists[currentrow];
183                 else
184                         return emptyfieldlist;
185         }
186
187         virtual SQLfieldMap& GetRowMap()
188         {
189                 /* In an effort to reduce overhead we don't actually allocate the map
190                  * until the first time it's needed...so...
191                  */
192                 if(fieldmap)
193                 {
194                         fieldmap->clear();
195                 }
196                 else
197                 {
198                         fieldmap = new SQLfieldMap;
199                 }
200
201                 if (currentrow < rows)
202                 {
203                         for (int i = 0; i < Cols(); i++)
204                         {
205                                 fieldmap->insert(std::make_pair(ColName(i), GetValue(currentrow, i)));
206                         }
207                         currentrow++;
208                 }
209
210                 return *fieldmap;
211         }
212
213         virtual SQLfieldList* GetRowPtr()
214         {
215                 fieldlist = new SQLfieldList();
216
217                 if (currentrow < rows)
218                 {
219                         for (int i = 0; i < Rows(); i++)
220                         {
221                                 fieldlist->push_back(fieldlists[currentrow][i]);
222                         }
223                         currentrow++;
224                 }
225                 return fieldlist;
226         }
227
228         virtual SQLfieldMap* GetRowMapPtr()
229         {
230                 fieldmap = new SQLfieldMap();
231
232                 if (currentrow < rows)
233                 {
234                         for (int i = 0; i < Cols(); i++)
235                         {
236                                 fieldmap->insert(std::make_pair(colnames[i],GetValue(currentrow, i)));
237                         }
238                         currentrow++;
239                 }
240
241                 return fieldmap;
242         }
243
244         virtual void Free(SQLfieldMap* fm)
245         {
246                 delete fm;
247         }
248
249         virtual void Free(SQLfieldList* fl)
250         {
251                 delete fl;
252         }
253
254
255 };
256
257 class SQLConn : public classbase
258 {
259   private:
260         ResultQueue results;
261         InspIRCd* Instance;
262         Module* mod;
263         SQLhost host;
264         sqlite3* conn;
265
266   public:
267         SQLConn(InspIRCd* SI, Module* m, const SQLhost& hi)
268         : Instance(SI), mod(m), host(hi)
269         {
270                 if (OpenDB() != SQLITE_OK)
271                 {
272                         Instance->Log(DEFAULT, "WARNING: Could not open DB with id: " + host.id);
273                         CloseDB();
274                 }
275         }
276
277         ~SQLConn()
278         {
279                 CloseDB();
280         }
281
282         SQLerror Query(SQLrequest &req)
283         {
284                 /* Pointer to the buffer we screw around with substitution in */
285                 char* query;
286
287                 /* Pointer to the current end of query, where we append new stuff */
288                 char* queryend;
289
290                 /* Total length of the unescaped parameters */
291                 unsigned long paramlen;
292
293                 /* Total length of query, used for binary-safety in mysql_real_query */
294                 unsigned long querylength = 0;
295
296                 paramlen = 0;
297                 for(ParamL::iterator i = req.query.p.begin(); i != req.query.p.end(); i++)
298                 {
299                         paramlen += i->size();
300                 }
301
302                 /* To avoid a lot of allocations, allocate enough memory for the biggest the escaped query could possibly be.
303                  * sizeofquery + (totalparamlength*2) + 1
304                  *
305                  * The +1 is for null-terminating the string for mysql_real_escape_string
306                  */
307                 query = new char[req.query.q.length() + (paramlen*2) + 1];
308                 queryend = query;
309
310                 for(unsigned long i = 0; i < req.query.q.length(); i++)
311                 {
312                         if(req.query.q[i] == '?')
313                         {
314                                 if(req.query.p.size())
315                                 {
316                                         char* escaped;
317                                         escaped = sqlite3_mprintf("%q", req.query.p.front().c_str());
318                                         for (char* n = escaped; *n; n++)
319                                         {
320                                                 *queryend = *n;
321                                                 queryend++;
322                                         }
323                                         sqlite3_free(escaped);
324                                         req.query.p.pop_front();
325                                 }
326                                 else
327                                         break;
328                         }
329                         else
330                         {
331                                 *queryend = req.query.q[i];
332                                 queryend++;
333                         }
334                         querylength++;
335                 }
336                 *queryend = 0;
337                 req.query.q = query;
338
339                 SQLite3Result* res = new SQLite3Result(mod, req.GetSource(), req.id);
340                 res->dbid = host.id;
341                 res->query = req.query.q;
342                 paramlist params;
343                 params.push_back(this);
344                 params.push_back(res);
345
346                 char *errmsg = 0;
347                 sqlite3_update_hook(conn, QueryUpdateHook, &params);
348                 if (sqlite3_exec(conn, req.query.q.data(), QueryResult, &params, &errmsg) != SQLITE_OK)
349                 {
350                         std::string error(errmsg);
351                         sqlite3_free(errmsg);
352                         delete[] query;
353                         delete res;
354                         return SQLerror(QSEND_FAIL, error);
355                 }
356                 delete[] query;
357
358                 results.push_back(res);
359                 SendNotify();
360                 return SQLerror();
361         }
362
363         static int QueryResult(void *params, int argc, char **argv, char **azColName)
364         {
365                 paramlist* p = (paramlist*)params;
366                 ((SQLConn*)(*p)[0])->ResultReady(((SQLite3Result*)(*p)[1]), argc, argv, azColName);
367                 return 0;
368         }
369
370         static void QueryUpdateHook(void *params, int eventid, char const * azSQLite, char const * azColName, sqlite_int64 rowid)
371         {
372                 paramlist* p = (paramlist*)params;
373                 ((SQLConn*)(*p)[0])->AffectedReady(((SQLite3Result*)(*p)[1]));
374         }\r
375
376         void ResultReady(SQLite3Result *res, int cols, char **data, char **colnames)
377         {
378                 res->AddRow(cols, data, colnames);
379         }
380
381         void AffectedReady(SQLite3Result *res)
382         {
383                 res->UpdateAffectedCount();
384         }
385
386         int OpenDB()
387         {
388                 return sqlite3_open(host.host.c_str(), &conn);
389         }
390
391         void CloseDB()
392         {
393                 sqlite3_interrupt(conn);
394                 sqlite3_close(conn);
395         }
396
397         SQLhost GetConfHost()
398         {
399                 return host;
400         }
401
402         void SendResults()
403         {
404                 while (results.size())
405                 {
406                         SQLite3Result* res = results[0];
407                         if (res->GetDest())
408                         {
409                                 res->Send();
410                         }
411                         else
412                         {
413                                 /* If the client module is unloaded partway through a query then the provider will set
414                                  * the pointer to NULL. We cannot just cancel the query as the result will still come
415                                  * through at some point...and it could get messy if we play with invalid pointers...
416                                  */
417                                 delete res;
418                         }
419                         results.pop_front();
420                 }
421         }
422
423         void ClearResults()
424         {
425                 while (results.size())
426                 {
427                         SQLite3Result* res = results[0];
428                         delete res;
429                         results.pop_front();
430                 }
431         }
432
433         void SendNotify()
434         {
435                 int QueueFD;
436                 if ((QueueFD = socket(AF_FAMILY, SOCK_STREAM, 0)) == -1)
437                 {
438                         /* crap, we're out of sockets... */
439                         return;
440                 }
441
442                 insp_sockaddr addr;
443
444 #ifdef IPV6
445                 insp_aton("::1", &addr.sin6_addr);
446                 addr.sin6_family = AF_FAMILY;
447                 addr.sin6_port = htons(resultnotify->GetPort());
448 #else
449                 insp_inaddr ia;
450                 insp_aton("127.0.0.1", &ia);
451                 addr.sin_family = AF_FAMILY;
452                 addr.sin_addr = ia;
453                 addr.sin_port = htons(resultnotify->GetPort());
454 #endif
455
456                 if (connect(QueueFD, (sockaddr*)&addr,sizeof(addr)) == -1)
457                 {
458                         /* wtf, we cant connect to it, but we just created it! */
459                         return;
460                 }
461         }
462
463 };
464
465
466 class ModuleSQLite3 : public Module
467 {
468   private:
469         ConnMap connections;
470         unsigned long currid;
471
472   public:
473         ModuleSQLite3(InspIRCd* Me)
474         : Module::Module(Me), currid(0)
475         {
476                 ServerInstance->UseInterface("SQLutils");
477
478                 if (!ServerInstance->PublishFeature("SQL", this))
479                 {
480                         throw ModuleException("m_sqlite3: Unable to publish feature 'SQL'");
481                 }
482
483                 resultnotify = new ResultNotifier(ServerInstance, this);
484
485                 ReadConf();
486
487                 ServerInstance->PublishInterface("SQL", this);
488         }
489
490         virtual ~ModuleSQLite3()
491         {
492                 ClearQueue();
493                 ClearAllConnections();
494                 resultnotify->SetFd(-1);
495                 resultnotify->state = I_ERROR;
496                 resultnotify->OnError(I_ERR_SOCKET);
497                 resultnotify->ClosePending = true;
498                 delete resultnotify;
499                 ServerInstance->UnpublishInterface("SQL", this);
500                 ServerInstance->UnpublishFeature("SQL");
501                 ServerInstance->DoneWithInterface("SQLutils");
502         }
503
504         void Implements(char* List)
505         {
506                 List[I_OnRequest] = List[I_OnRehash] = 1;
507         }
508
509         void SendQueue()
510         {
511                 for (ConnMap::iterator iter = connections.begin(); iter != connections.end(); iter++)
512                 {
513                         iter->second->SendResults();
514                 }
515         }
516
517         void ClearQueue()
518         {
519                 for (ConnMap::iterator iter = connections.begin(); iter != connections.end(); iter++)
520                 {
521                         iter->second->ClearResults();
522                 }
523         }
524
525         bool HasHost(const SQLhost &host)
526         {
527                 for (ConnMap::iterator iter = connections.begin(); iter != connections.end(); iter++)
528                 {
529                         if (host == iter->second->GetConfHost())
530                                 return true;
531                 }
532                 return false;
533         }
534
535         bool HostInConf(const SQLhost &h)
536         {
537                 ConfigReader conf(ServerInstance);
538                 for(int i = 0; i < conf.Enumerate("database"); i++)
539                 {
540                         SQLhost host;
541                         host.id         = conf.ReadValue("database", "id", i);
542                         host.host       = conf.ReadValue("database", "hostname", i);
543                         host.port       = conf.ReadInteger("database", "port", i, true);
544                         host.name       = conf.ReadValue("database", "name", i);
545                         host.user       = conf.ReadValue("database", "username", i);
546                         host.pass       = conf.ReadValue("database", "password", i);
547                         host.ssl        = conf.ReadFlag("database", "ssl", "0", i);
548                         if (h == host)
549                                 return true;
550                 }
551                 return false;
552         }
553
554         void ReadConf()
555         {
556                 ClearOldConnections();
557
558                 ConfigReader conf(ServerInstance);
559                 for(int i = 0; i < conf.Enumerate("database"); i++)
560                 {
561                         SQLhost host;
562
563                         host.id         = conf.ReadValue("database", "id", i);
564                         host.host       = conf.ReadValue("database", "hostname", i);
565                         host.port       = conf.ReadInteger("database", "port", i, true);
566                         host.name       = conf.ReadValue("database", "name", i);
567                         host.user       = conf.ReadValue("database", "username", i);
568                         host.pass       = conf.ReadValue("database", "password", i);
569                         host.ssl        = conf.ReadFlag("database", "ssl", "0", i);
570
571                         if (HasHost(host))
572                                 continue;
573
574                         this->AddConn(host);
575                 }
576         }
577
578         void AddConn(const SQLhost& hi)
579         {
580                 if (HasHost(hi))
581                 {
582                         ServerInstance->Log(DEFAULT, "WARNING: A sqlite connection with id: %s already exists. Aborting database open attempt.", hi.id.c_str());
583                         return;
584                 }
585
586                 SQLConn* newconn;
587
588                 newconn = new SQLConn(ServerInstance, this, hi);
589
590                 connections.insert(std::make_pair(hi.id, newconn));
591         }
592
593         void ClearOldConnections()
594         {
595                 ConnMap::iterator iter,safei;
596                 for (iter = connections.begin(); iter != connections.end(); iter++)
597                 {
598                         if (!HostInConf(iter->second->GetConfHost()))
599                         {
600                                 DELETE(iter->second);
601                                 safei = iter;
602                                 --iter;
603                                 connections.erase(safei);
604                         }
605                 }
606         }
607
608         void ClearAllConnections()
609         {
610                 ConnMap::iterator i;
611                 while ((i = connections.begin()) != connections.end())
612                 {
613                         connections.erase(i);
614                         DELETE(i->second);
615                 }
616         }
617
618         virtual void OnRehash(userrec* user, const std::string &parameter)
619         {
620                 ReadConf();
621         }
622
623         virtual char* OnRequest(Request* request)
624         {
625                 if(strcmp(SQLREQID, request->GetId()) == 0)
626                 {
627                         SQLrequest* req = (SQLrequest*)request;
628                         ConnMap::iterator iter;
629                         if((iter = connections.find(req->dbid)) != connections.end())
630                         {
631                                 req->id = NewID();
632                                 req->error = iter->second->Query(*req);
633                                 return SQLSUCCESS;
634                         }
635                         else
636                         {
637                                 req->error.Id(BAD_DBID);
638                                 return NULL;
639                         }
640                 }
641                 return NULL;
642         }
643
644         unsigned long NewID()
645         {
646                 if (currid+1 == 0)
647                         currid++;
648
649                 return ++currid;
650         }
651
652         virtual Version GetVersion()
653         {
654                 return Version(1,1,0,0,VF_VENDOR|VF_SERVICEPROVIDER,API_VERSION);
655         }
656
657 };
658
659 void ResultNotifier::Dispatch()
660 {
661         ((ModuleSQLite3*)mod)->SendQueue();
662 }
663
664 class ModuleSQLite3Factory : public ModuleFactory
665 {
666   public:
667         ModuleSQLite3Factory()
668         {
669         }
670
671         ~ModuleSQLite3Factory()
672         {
673         }
674
675         virtual Module * CreateModule(InspIRCd* Me)
676         {
677                 return new ModuleSQLite3(Me);
678         }
679 };
680
681 extern "C" void * init_module( void )
682 {
683         return new ModuleSQLite3Factory;
684 }