git.netwichtig.de Git - user/henk/code/inspircd.git/blob - src/modules/extra/m_sqlite3.cpp
Last cleanup. All trunk extras now build again.
[user/henk/code/inspircd.git] / src / modules / extra / m_sqlite3.cpp
1 /*               +------------------------------------+
2  *               | Inspire Internet Relay Chat Daemon |
3  *               +------------------------------------+
4  *
5  *      InspIRCd: (C) 2002-2008 InspIRCd Development Team
6  * See: http://www.inspircd.org/wiki/index.php/Credits
7  *
8  * This program is free but copyrighted software; see
9  *                        the file COPYING for details.
10  *
11  * ---------------------------------------------------
12  */
13
14 #include "inspircd.h"
15 #include <sqlite3.h>
16 #include "users.h"
17 #include "channels.h"
18 #include "modules.h"
19 #include "m_sqlv2.h"
20
21 /* $ModDesc: sqlite3 provider */
22 /* $CompileFlags: pkgconfversion("sqlite3","3.3") pkgconfincludes("sqlite3","/sqlite3.h","") */
23 /* $LinkerFlags: pkgconflibs("sqlite3","/libsqlite3.so","-lsqlite3") */
24 /* $ModDep: m_sqlv2.h */
25 /* $NoPedantic */
26
27 class SQLConn;
28 class SQLite3Result;
29 class ResultNotifier;
30 class SQLiteListener;
31 class ModuleSQLite3;
32
33 typedef std::map<std::string, SQLConn*> ConnMap;
34 typedef std::deque<classbase*> paramlist;
35 typedef std::deque<SQLite3Result*> ResultQueue;
36
37 ResultNotifier* notifier = NULL;
38 SQLiteListener* listener = NULL;
39 int QueueFD = -1;
40
41 class ResultNotifier : public BufferedSocket
42 {
43         ModuleSQLite3* mod;
44
45  public:
46         ResultNotifier(ModuleSQLite3* m, InspIRCd* SI, int newfd, char* ip) : BufferedSocket(SI, newfd, ip), mod(m)
47         {
48         }
49
50         virtual bool OnDataReady()
51         {
52                 char data = 0;
53                 if (Instance->SE->Recv(this, &data, 1, 0) > 0)
54                 {
55                         Dispatch();
56                         return true;
57                 }
58                 return false;
59         }
60
61         void Dispatch();
62 };
63
64 class SQLiteListener : public ListenSocketBase
65 {
66         ModuleSQLite3* Parent;
67         insp_sockaddr sock_us;
68         socklen_t uslen;
69         FileReader* index;
70
71  public:
72         SQLiteListener(ModuleSQLite3* P, InspIRCd* Instance, int port, const std::string &addr) : ListenSocketBase(Instance, port, addr), Parent(P)
73         {
74                 uslen = sizeof(sock_us);
75                 if (getsockname(this->fd,(sockaddr*)&sock_us,&uslen))
76                 {
77                         throw ModuleException("Could not getsockname() to find out port number for ITC port");
78                 }
79         }
80
81         virtual void OnAcceptReady(const std::string &ipconnectedto, int nfd, const std::string &incomingip)
82         {
83                 new ResultNotifier(this->Parent, this->ServerInstance, nfd, (char *)ipconnectedto.c_str()); // XXX unsafe casts suck
84         }
85
86         /* Using getsockname and ntohs, we can determine which port number we were allocated */
87         int GetPort()
88         {
89 #ifdef IPV6
90                 return ntohs(sock_us.sin6_port);
91 #else
92                 return ntohs(sock_us.sin_port);
93 #endif
94         }
95 };
96
97 class SQLite3Result : public SQLresult
98 {
99  private:
100         int currentrow;
101         int rows;
102         int cols;
103
104         std::vector<std::string> colnames;
105         std::vector<SQLfieldList> fieldlists;
106         SQLfieldList emptyfieldlist;
107
108         SQLfieldList* fieldlist;
109         SQLfieldMap* fieldmap;
110
111  public:
112         SQLite3Result(Module* self, Module* to, unsigned int rid)
113         : SQLresult(self, to, rid), currentrow(0), rows(0), cols(0), fieldlist(NULL), fieldmap(NULL)
114         {
115         }
116
117         ~SQLite3Result()
118         {
119         }
120
121         void AddRow(int colsnum, char **dat, char **colname)
122         {
123                 colnames.clear();
124                 cols = colsnum;
125                 for (int i = 0; i < colsnum; i++)
126                 {
127                         fieldlists.resize(fieldlists.size()+1);
128                         colnames.push_back(colname[i]);
129                         SQLfield sf(dat[i] ? dat[i] : "", dat[i] ? false : true);
130                         fieldlists[rows].push_back(sf);
131                 }
132                 rows++;
133         }
134
135         void UpdateAffectedCount()
136         {
137                 rows++;
138         }
139
140         virtual int Rows()
141         {
142                 return rows;
143         }
144
145         virtual int Cols()
146         {
147                 return cols;
148         }
149
150         virtual std::string ColName(int column)
151         {
152                 if (column < (int)colnames.size())
153                 {
154                         return colnames[column];
155                 }
156                 else
157                 {
158                         throw SQLbadColName();
159                 }
160                 return "";
161         }
162
163         virtual int ColNum(const std::string &column)
164         {
165                 for (unsigned int i = 0; i < colnames.size(); i++)
166                 {
167                         if (column == colnames[i])
168                                 return i;
169                 }
170                 throw SQLbadColName();
171                 return 0;
172         }
173
174         virtual SQLfield GetValue(int row, int column)
175         {
176                 if ((row >= 0) && (row < rows) && (column >= 0) && (column < Cols()))
177                 {
178                         return fieldlists[row][column];
179                 }
180
181                 throw SQLbadColName();
182
183                 /* XXX: We never actually get here because of the throw */
184                 return SQLfield("",true);
185         }
186
187         virtual SQLfieldList& GetRow()
188         {
189                 if (currentrow < rows)
190                         return fieldlists[currentrow];
191                 else
192                         return emptyfieldlist;
193         }
194
195         virtual SQLfieldMap& GetRowMap()
196         {
197                 /* In an effort to reduce overhead we don't actually allocate the map
198                  * until the first time it's needed...so...
199                  */
200                 if(fieldmap)
201                 {
202                         fieldmap->clear();
203                 }
204                 else
205                 {
206                         fieldmap = new SQLfieldMap;
207                 }
208
209                 if (currentrow < rows)
210                 {
211                         for (int i = 0; i < Cols(); i++)
212                         {
213                                 fieldmap->insert(std::make_pair(ColName(i), GetValue(currentrow, i)));
214                         }
215                         currentrow++;
216                 }
217
218                 return *fieldmap;
219         }
220
221         virtual SQLfieldList* GetRowPtr()
222         {
223                 fieldlist = new SQLfieldList();
224
225                 if (currentrow < rows)
226                 {
227                         for (int i = 0; i < Rows(); i++)
228                         {
229                                 fieldlist->push_back(fieldlists[currentrow][i]);
230                         }
231                         currentrow++;
232                 }
233                 return fieldlist;
234         }
235
236         virtual SQLfieldMap* GetRowMapPtr()
237         {
238                 fieldmap = new SQLfieldMap();
239
240                 if (currentrow < rows)
241                 {
242                         for (int i = 0; i < Cols(); i++)
243                         {
244                                 fieldmap->insert(std::make_pair(colnames[i],GetValue(currentrow, i)));
245                         }
246                         currentrow++;
247                 }
248
249                 return fieldmap;
250         }
251
252         virtual void Free(SQLfieldMap* fm)
253         {
254                 delete fm;
255         }
256
257         virtual void Free(SQLfieldList* fl)
258         {
259                 delete fl;
260         }
261
262
263 };
264
265 class SQLConn : public classbase
266 {
267  private:
268         ResultQueue results;
269         InspIRCd* Instance;
270         Module* mod;
271         SQLhost host;
272         sqlite3* conn;
273
274  public:
275         SQLConn(InspIRCd* SI, Module* m, const SQLhost& hi)
276         : Instance(SI), mod(m), host(hi)
277         {
278                 if (OpenDB() != SQLITE_OK)
279                 {
280                         Instance->Logs->Log("m_sqlite3",DEFAULT, "WARNING: Could not open DB with id: " + host.id);
281                         CloseDB();
282                 }
283         }
284
285         ~SQLConn()
286         {
287                 CloseDB();
288         }
289
290         SQLerror Query(SQLrequest &req)
291         {
292                 /* Pointer to the buffer we screw around with substitution in */
293                 char* query;
294
295                 /* Pointer to the current end of query, where we append new stuff */
296                 char* queryend;
297
298                 /* Total length of the unescaped parameters */
299                 unsigned long paramlen;
300
301                 /* Total length of query, used for binary-safety */
302                 unsigned long querylength = 0;
303
304                 paramlen = 0;
305                 for(ParamL::iterator i = req.query.p.begin(); i != req.query.p.end(); i++)
306                 {
307                         paramlen += i->size();
308                 }
309
310                 /* To avoid a lot of allocations, allocate enough memory for the biggest the escaped query could possibly be.
311                  * sizeofquery + (totalparamlength*2) + 1
312                  *
313                  * The +1 is for null-terminating the string
314                  */
315                 query = new char[req.query.q.length() + (paramlen*2) + 1];
316                 queryend = query;
317
318                 for(unsigned long i = 0; i < req.query.q.length(); i++)
319                 {
320                         if(req.query.q[i] == '?')
321                         {
322                                 if(req.query.p.size())
323                                 {
324                                         char* escaped;
325                                         escaped = sqlite3_mprintf("%q", req.query.p.front().c_str());
326                                         for (char* n = escaped; *n; n++)
327                                         {
328                                                 *queryend = *n;
329                                                 queryend++;
330                                         }
331                                         sqlite3_free(escaped);
332                                         req.query.p.pop_front();
333                                 }
334                                 else
335                                         break;
336                         }
337                         else
338                         {
339                                 *queryend = req.query.q[i];
340                                 queryend++;
341                         }
342                         querylength++;
343                 }
344                 *queryend = 0;
345                 req.query.q = query;
346
347                 SQLite3Result* res = new SQLite3Result(mod, req.GetSource(), req.id);
348                 res->dbid = host.id;
349                 res->query = req.query.q;
350                 paramlist params;
351                 params.push_back(this);
352                 params.push_back(res);
353
354                 char *errmsg = 0;
355                 sqlite3_update_hook(conn, QueryUpdateHook, &params);
356                 if (sqlite3_exec(conn, req.query.q.data(), QueryResult, &params, &errmsg) != SQLITE_OK)
357                 {
358                         std::string error(errmsg);
359                         sqlite3_free(errmsg);
360                         delete[] query;
361                         delete res;
362                         return SQLerror(SQL_QSEND_FAIL, error);
363                 }
364                 delete[] query;
365
366                 results.push_back(res);
367                 SendNotify();
368                 return SQLerror();
369         }
370
371         static int QueryResult(void *params, int argc, char **argv, char **azColName)
372         {
373                 paramlist* p = (paramlist*)params;
374                 ((SQLConn*)(*p)[0])->ResultReady(((SQLite3Result*)(*p)[1]), argc, argv, azColName);
375                 return 0;
376         }
377
378         static void QueryUpdateHook(void *params, int eventid, char const * azSQLite, char const * azColName, sqlite_int64 rowid)
379         {
380                 paramlist* p = (paramlist*)params;
381                 ((SQLConn*)(*p)[0])->AffectedReady(((SQLite3Result*)(*p)[1]));
382         }
383
384         void ResultReady(SQLite3Result *res, int cols, char **data, char **colnames)
385         {
386                 res->AddRow(cols, data, colnames);
387         }
388
389         void AffectedReady(SQLite3Result *res)
390         {
391                 res->UpdateAffectedCount();
392         }
393
394         int OpenDB()
395         {
396                 return sqlite3_open(host.host.c_str(), &conn);
397         }
398
399         void CloseDB()
400         {
401                 sqlite3_interrupt(conn);
402                 sqlite3_close(conn);
403         }
404
405         SQLhost GetConfHost()
406         {
407                 return host;
408         }
409
410         void SendResults()
411         {
412                 while (results.size())
413                 {
414                         SQLite3Result* res = results[0];
415                         if (res->GetDest())
416                         {
417                                 res->Send();
418                         }
419                         else
420                         {
421                                 /* If the client module is unloaded partway through a query then the provider will set
422                                  * the pointer to NULL. We cannot just cancel the query as the result will still come
423                                  * through at some point...and it could get messy if we play with invalid pointers...
424                                  */
425                                 delete res;
426                         }
427                         results.pop_front();
428                 }
429         }
430
431         void ClearResults()
432         {
433                 while (results.size())
434                 {
435                         SQLite3Result* res = results[0];
436                         delete res;
437                         results.pop_front();
438                 }
439         }
440
441         void SendNotify()
442         {
443                 if (QueueFD < 0)
444                 {
445                         if ((QueueFD = socket(AF_FAMILY, SOCK_STREAM, 0)) == -1)
446                         {
447                                 /* crap, we're out of sockets... */
448                                 return;
449                         }
450
451                         insp_sockaddr addr;
452
453 #ifdef IPV6
454                         insp_aton("::1", &addr.sin6_addr);
455                         addr.sin6_family = AF_FAMILY;
456                         addr.sin6_port = htons(listener->GetPort());
457 #else
458                         insp_inaddr ia;
459                         insp_aton("127.0.0.1", &ia);
460                         addr.sin_family = AF_FAMILY;
461                         addr.sin_addr = ia;
462                         addr.sin_port = htons(listener->GetPort());
463 #endif
464
465                         if (connect(QueueFD, (sockaddr*)&addr,sizeof(addr)) == -1)
466                         {
467                                 /* wtf, we cant connect to it, but we just created it! */
468                                 return;
469                         }
470                 }
471                 char id = 0;
472                 send(QueueFD, &id, 1, 0);
473         }
474
475 };
476
477
478 class ModuleSQLite3 : public Module
479 {
480  private:
481         ConnMap connections;
482         unsigned long currid;
483
484  public:
485         ModuleSQLite3(InspIRCd* Me)
486         : Module(Me), currid(0)
487         {
488                 ServerInstance->Modules->UseInterface("SQLutils");
489
490                 if (!ServerInstance->Modules->PublishFeature("SQL", this))
491                 {
492                         throw ModuleException("m_sqlite3: Unable to publish feature 'SQL'");
493                 }
494
495                 /* Create a socket on a random port. Let the tcp stack allocate us an available port */
496 #ifdef IPV6
497                 listener = new SQLiteListener(this, ServerInstance, 0, "::1");
498 #else
499                 listener = new SQLiteListener(this, ServerInstance, 0, "127.0.0.1");
500 #endif
501
502                 if (listener->GetFd() == -1)
503                 {
504                         ServerInstance->Modules->DoneWithInterface("SQLutils");
505                         throw ModuleException("m_sqlite3: unable to create ITC pipe");
506                 }
507                 else
508                 {
509                         ServerInstance->Logs->Log("m_sqlite3", DEBUG, "SQLite: Interthread comms port is %d", listener->GetPort());
510                 }
511
512                 ReadConf();
513
514                 ServerInstance->Modules->PublishInterface("SQL", this);
515                 Implementation eventlist[] = { I_OnRequest, I_OnRehash };
516                 ServerInstance->Modules->Attach(eventlist, this, 2);
517         }
518
519         virtual ~ModuleSQLite3()
520         {
521                 ClearQueue();
522                 ClearAllConnections();
523
524                 ServerInstance->SE->DelFd(listener);
525                 ServerInstance->BufferedSocketCull();
526
527                 if (QueueFD >= 0)
528                 {
529                         shutdown(QueueFD, 2);
530                         close(QueueFD);
531                 }
532
533                 if (notifier)
534                 {
535                         ServerInstance->SE->DelFd(notifier);
536                         notifier->Close();
537                         ServerInstance->BufferedSocketCull();
538                 }
539
540                 ServerInstance->Modules->UnpublishInterface("SQL", this);
541                 ServerInstance->Modules->UnpublishFeature("SQL");
542                 ServerInstance->Modules->DoneWithInterface("SQLutils");
543         }
544
545
546         void SendQueue()
547         {
548                 for (ConnMap::iterator iter = connections.begin(); iter != connections.end(); iter++)
549                 {
550                         iter->second->SendResults();
551                 }
552         }
553
554         void ClearQueue()
555         {
556                 for (ConnMap::iterator iter = connections.begin(); iter != connections.end(); iter++)
557                 {
558                         iter->second->ClearResults();
559                 }
560         }
561
562         bool HasHost(const SQLhost &host)
563         {
564                 for (ConnMap::iterator iter = connections.begin(); iter != connections.end(); iter++)
565                 {
566                         if (host == iter->second->GetConfHost())
567                                 return true;
568                 }
569                 return false;
570         }
571
572         bool HostInConf(const SQLhost &h)
573         {
574                 ConfigReader conf(ServerInstance);
575                 for(int i = 0; i < conf.Enumerate("database"); i++)
576                 {
577                         SQLhost host;
578                         host.id         = conf.ReadValue("database", "id", i);
579                         host.host       = conf.ReadValue("database", "hostname", i);
580                         host.port       = conf.ReadInteger("database", "port", i, true);
581                         host.name       = conf.ReadValue("database", "name", i);
582                         host.user       = conf.ReadValue("database", "username", i);
583                         host.pass       = conf.ReadValue("database", "password", i);
584                         if (h == host)
585                                 return true;
586                 }
587                 return false;
588         }
589
590         void ReadConf()
591         {
592                 ClearOldConnections();
593
594                 ConfigReader conf(ServerInstance);
595                 for(int i = 0; i < conf.Enumerate("database"); i++)
596                 {
597                         SQLhost host;
598
599                         host.id         = conf.ReadValue("database", "id", i);
600                         host.host       = conf.ReadValue("database", "hostname", i);
601                         host.port       = conf.ReadInteger("database", "port", i, true);
602                         host.name       = conf.ReadValue("database", "name", i);
603                         host.user       = conf.ReadValue("database", "username", i);
604                         host.pass       = conf.ReadValue("database", "password", i);
605
606                         if (HasHost(host))
607                                 continue;
608
609                         this->AddConn(host);
610                 }
611         }
612
613         void AddConn(const SQLhost& hi)
614         {
615                 if (HasHost(hi))
616                 {
617                         ServerInstance->Logs->Log("m_sqlite3",DEFAULT, "WARNING: A sqlite connection with id: %s already exists. Aborting database open attempt.", hi.id.c_str());
618                         return;
619                 }
620
621                 SQLConn* newconn;
622
623                 newconn = new SQLConn(ServerInstance, this, hi);
624
625                 connections.insert(std::make_pair(hi.id, newconn));
626         }
627
628         void ClearOldConnections()
629         {
630                 ConnMap::iterator iter,safei;
631                 for (iter = connections.begin(); iter != connections.end(); iter++)
632                 {
633                         if (!HostInConf(iter->second->GetConfHost()))
634                         {
635                                 delete iter->second;
636                                 safei = iter;
637                                 --iter;
638                                 connections.erase(safei);
639                         }
640                 }
641         }
642
643         void ClearAllConnections()
644         {
645                 ConnMap::iterator i;
646                 while ((i = connections.begin()) != connections.end())
647                 {
648                         connections.erase(i);
649                         delete i->second;
650                 }
651         }
652
653         virtual void OnRehash(User* user, const std::string &parameter)
654         {
655                 ReadConf();
656         }
657
658         virtual const char* OnRequest(Request* request)
659         {
660                 if(strcmp(SQLREQID, request->GetId()) == 0)
661                 {
662                         SQLrequest* req = (SQLrequest*)request;
663                         ConnMap::iterator iter;
664                         if((iter = connections.find(req->dbid)) != connections.end())
665                         {
666                                 req->id = NewID();
667                                 req->error = iter->second->Query(*req);
668                                 return SQLSUCCESS;
669                         }
670                         else
671                         {
672                                 req->error.Id(SQL_BAD_DBID);
673                                 return NULL;
674                         }
675                 }
676                 return NULL;
677         }
678
679         unsigned long NewID()
680         {
681                 if (currid+1 == 0)
682                         currid++;
683
684                 return ++currid;
685         }
686
687         virtual Version GetVersion()
688         {
689                 return Version("$Id$", VF_VENDOR | VF_SERVICEPROVIDER, API_VERSION);
690         }
691
692 };
693
694 void ResultNotifier::Dispatch()
695 {
696         mod->SendQueue();
697 }
698
699 MODULE_INIT(ModuleSQLite3)