git.netwichtig.de Git - user/henk/code/inspircd.git/blob - src/modules/extra/m_sqlite3.cpp
d51bc16953845bb8647c571fcbbdcc63d0e1254f
[user/henk/code/inspircd.git] / src / modules / extra / m_sqlite3.cpp
1 /*               +------------------------------------+
2  *               | Inspire Internet Relay Chat Daemon |
3  *               +------------------------------------+
4  *
5  *      InspIRCd: (C) 2002-2008 InspIRCd Development Team
6  * See: http://www.inspircd.org/wiki/index.php/Credits
7  *
8  * This program is free but copyrighted software; see
9  *                        the file COPYING for details.
10  *
11  * ---------------------------------------------------
12  */
13
14 #include "inspircd.h"
15 #include <sqlite3.h>
16 #include "m_sqlv2.h"
17
18 /* $ModDesc: sqlite3 provider */
19 /* $CompileFlags: pkgconfversion("sqlite3","3.3") pkgconfincludes("sqlite3","/sqlite3.h","") */
20 /* $LinkerFlags: pkgconflibs("sqlite3","/libsqlite3.so","-lsqlite3") */
21 /* $ModDep: m_sqlv2.h */
22 /* $NoPedantic */
23
24 class SQLConn;
25 class SQLite3Result;
26 class ResultNotifier;
27 class SQLiteListener;
28 class ModuleSQLite3;
29
30 typedef std::map<std::string, SQLConn*> ConnMap;
31 typedef std::deque<classbase*> paramlist;
32 typedef std::deque<SQLite3Result*> ResultQueue;
33
34 ResultNotifier* notifier = NULL;
35 SQLiteListener* listener = NULL;
36 int QueueFD = -1;
37
38 class ResultNotifier : public BufferedSocket
39 {
40         ModuleSQLite3* mod;
41
42  public:
43         ResultNotifier(ModuleSQLite3* m, InspIRCd* SI, int newfd, char* ip) : BufferedSocket(SI, newfd, ip), mod(m)
44         {
45         }
46
47         virtual bool OnDataReady()
48         {
49                 char data = 0;
50                 if (ServerInstance->SE->Recv(this, &data, 1, 0) > 0)
51                 {
52                         Dispatch();
53                         return true;
54                 }
55                 return false;
56         }
57
58         void Dispatch();
59 };
60
61 class SQLiteListener : public ListenSocketBase
62 {
63         ModuleSQLite3* Parent;
64         irc::sockets::insp_sockaddr sock_us;
65         socklen_t uslen;
66         FileReader* index;
67
68  public:
69         SQLiteListener(ModuleSQLite3* P, InspIRCd* Instance, int port, const std::string &addr) : ListenSocketBase(Instance, port, addr), Parent(P)
70         {
71                 uslen = sizeof(sock_us);
72                 if (getsockname(this->fd,(sockaddr*)&sock_us,&uslen))
73                 {
74                         throw ModuleException("Could not getsockname() to find out port number for ITC port");
75                 }
76         }
77
78         virtual void OnAcceptReady(const std::string &ipconnectedto, int nfd, const std::string &incomingip)
79         {
80                 new ResultNotifier(this->Parent, this->ServerInstance, nfd, (char *)ipconnectedto.c_str()); // XXX unsafe casts suck
81         }
82
83         /* Using getsockname and ntohs, we can determine which port number we were allocated */
84         int GetPort()
85         {
86 #ifdef IPV6
87                 return ntohs(sock_us.sin6_port);
88 #else
89                 return ntohs(sock_us.sin_port);
90 #endif
91         }
92 };
93
94 class SQLite3Result : public SQLresult
95 {
96  private:
97         int currentrow;
98         int rows;
99         int cols;
100
101         std::vector<std::string> colnames;
102         std::vector<SQLfieldList> fieldlists;
103         SQLfieldList emptyfieldlist;
104
105         SQLfieldList* fieldlist;
106         SQLfieldMap* fieldmap;
107
108  public:
109         SQLite3Result(Module* self, Module* to, unsigned int rid)
110         : SQLresult(self, to, rid), currentrow(0), rows(0), cols(0), fieldlist(NULL), fieldmap(NULL)
111         {
112         }
113
114         ~SQLite3Result()
115         {
116         }
117
118         void AddRow(int colsnum, char **dat, char **colname)
119         {
120                 colnames.clear();
121                 cols = colsnum;
122                 for (int i = 0; i < colsnum; i++)
123                 {
124                         fieldlists.resize(fieldlists.size()+1);
125                         colnames.push_back(colname[i]);
126                         SQLfield sf(dat[i] ? dat[i] : "", dat[i] ? false : true);
127                         fieldlists[rows].push_back(sf);
128                 }
129                 rows++;
130         }
131
132         void UpdateAffectedCount()
133         {
134                 rows++;
135         }
136
137         virtual int Rows()
138         {
139                 return rows;
140         }
141
142         virtual int Cols()
143         {
144                 return cols;
145         }
146
147         virtual std::string ColName(int column)
148         {
149                 if (column < (int)colnames.size())
150                 {
151                         return colnames[column];
152                 }
153                 else
154                 {
155                         throw SQLbadColName();
156                 }
157                 return "";
158         }
159
160         virtual int ColNum(const std::string &column)
161         {
162                 for (unsigned int i = 0; i < colnames.size(); i++)
163                 {
164                         if (column == colnames[i])
165                                 return i;
166                 }
167                 throw SQLbadColName();
168                 return 0;
169         }
170
171         virtual SQLfield GetValue(int row, int column)
172         {
173                 if ((row >= 0) && (row < rows) && (column >= 0) && (column < Cols()))
174                 {
175                         return fieldlists[row][column];
176                 }
177
178                 throw SQLbadColName();
179
180                 /* XXX: We never actually get here because of the throw */
181                 return SQLfield("",true);
182         }
183
184         virtual SQLfieldList& GetRow()
185         {
186                 if (currentrow < rows)
187                         return fieldlists[currentrow];
188                 else
189                         return emptyfieldlist;
190         }
191
192         virtual SQLfieldMap& GetRowMap()
193         {
194                 /* In an effort to reduce overhead we don't actually allocate the map
195                  * until the first time it's needed...so...
196                  */
197                 if(fieldmap)
198                 {
199                         fieldmap->clear();
200                 }
201                 else
202                 {
203                         fieldmap = new SQLfieldMap;
204                 }
205
206                 if (currentrow < rows)
207                 {
208                         for (int i = 0; i < Cols(); i++)
209                         {
210                                 fieldmap->insert(std::make_pair(ColName(i), GetValue(currentrow, i)));
211                         }
212                         currentrow++;
213                 }
214
215                 return *fieldmap;
216         }
217
218         virtual SQLfieldList* GetRowPtr()
219         {
220                 fieldlist = new SQLfieldList();
221
222                 if (currentrow < rows)
223                 {
224                         for (int i = 0; i < Rows(); i++)
225                         {
226                                 fieldlist->push_back(fieldlists[currentrow][i]);
227                         }
228                         currentrow++;
229                 }
230                 return fieldlist;
231         }
232
233         virtual SQLfieldMap* GetRowMapPtr()
234         {
235                 fieldmap = new SQLfieldMap();
236
237                 if (currentrow < rows)
238                 {
239                         for (int i = 0; i < Cols(); i++)
240                         {
241                                 fieldmap->insert(std::make_pair(colnames[i],GetValue(currentrow, i)));
242                         }
243                         currentrow++;
244                 }
245
246                 return fieldmap;
247         }
248
249         virtual void Free(SQLfieldMap* fm)
250         {
251                 delete fm;
252         }
253
254         virtual void Free(SQLfieldList* fl)
255         {
256                 delete fl;
257         }
258
259
260 };
261
262 class SQLConn : public classbase
263 {
264  private:
265         ResultQueue results;
266         InspIRCd* ServerInstance;
267         Module* mod;
268         SQLhost host;
269         sqlite3* conn;
270
271  public:
272         SQLConn(InspIRCd* SI, Module* m, const SQLhost& hi)
273         : ServerInstance(SI), mod(m), host(hi)
274         {
275                 if (OpenDB() != SQLITE_OK)
276                 {
277                         ServerInstance->Logs->Log("m_sqlite3",DEFAULT, "WARNING: Could not open DB with id: " + host.id);
278                         CloseDB();
279                 }
280         }
281
282         ~SQLConn()
283         {
284                 CloseDB();
285         }
286
287         SQLerror Query(SQLrequest &req)
288         {
289                 /* Pointer to the buffer we screw around with substitution in */
290                 char* query;
291
292                 /* Pointer to the current end of query, where we append new stuff */
293                 char* queryend;
294
295                 /* Total length of the unescaped parameters */
296                 unsigned long paramlen;
297
298                 /* Total length of query, used for binary-safety */
299                 unsigned long querylength = 0;
300
301                 paramlen = 0;
302                 for(ParamL::iterator i = req.query.p.begin(); i != req.query.p.end(); i++)
303                 {
304                         paramlen += i->size();
305                 }
306
307                 /* To avoid a lot of allocations, allocate enough memory for the biggest the escaped query could possibly be.
308                  * sizeofquery + (totalparamlength*2) + 1
309                  *
310                  * The +1 is for null-terminating the string
311                  */
312                 query = new char[req.query.q.length() + (paramlen*2) + 1];
313                 queryend = query;
314
315                 for(unsigned long i = 0; i < req.query.q.length(); i++)
316                 {
317                         if(req.query.q[i] == '?')
318                         {
319                                 if(req.query.p.size())
320                                 {
321                                         char* escaped;
322                                         escaped = sqlite3_mprintf("%q", req.query.p.front().c_str());
323                                         for (char* n = escaped; *n; n++)
324                                         {
325                                                 *queryend = *n;
326                                                 queryend++;
327                                         }
328                                         sqlite3_free(escaped);
329                                         req.query.p.pop_front();
330                                 }
331                                 else
332                                         break;
333                         }
334                         else
335                         {
336                                 *queryend = req.query.q[i];
337                                 queryend++;
338                         }
339                         querylength++;
340                 }
341                 *queryend = 0;
342                 req.query.q = query;
343
344                 SQLite3Result* res = new SQLite3Result(mod, req.GetSource(), req.id);
345                 res->dbid = host.id;
346                 res->query = req.query.q;
347                 paramlist params;
348                 params.push_back(this);
349                 params.push_back(res);
350
351                 char *errmsg = 0;
352                 sqlite3_update_hook(conn, QueryUpdateHook, &params);
353                 if (sqlite3_exec(conn, req.query.q.data(), QueryResult, &params, &errmsg) != SQLITE_OK)
354                 {
355                         std::string error(errmsg);
356                         sqlite3_free(errmsg);
357                         delete[] query;
358                         delete res;
359                         return SQLerror(SQL_QSEND_FAIL, error);
360                 }
361                 delete[] query;
362
363                 results.push_back(res);
364                 SendNotify();
365                 return SQLerror();
366         }
367
368         static int QueryResult(void *params, int argc, char **argv, char **azColName)
369         {
370                 paramlist* p = (paramlist*)params;
371                 ((SQLConn*)(*p)[0])->ResultReady(((SQLite3Result*)(*p)[1]), argc, argv, azColName);
372                 return 0;
373         }
374
375         static void QueryUpdateHook(void *params, int eventid, char const * azSQLite, char const * azColName, sqlite_int64 rowid)
376         {
377                 paramlist* p = (paramlist*)params;
378                 ((SQLConn*)(*p)[0])->AffectedReady(((SQLite3Result*)(*p)[1]));
379         }
380
381         void ResultReady(SQLite3Result *res, int cols, char **data, char **colnames)
382         {
383                 res->AddRow(cols, data, colnames);
384         }
385
386         void AffectedReady(SQLite3Result *res)
387         {
388                 res->UpdateAffectedCount();
389         }
390
391         int OpenDB()
392         {
393                 return sqlite3_open(host.host.c_str(), &conn);
394         }
395
396         void CloseDB()
397         {
398                 sqlite3_interrupt(conn);
399                 sqlite3_close(conn);
400         }
401
402         SQLhost GetConfHost()
403         {
404                 return host;
405         }
406
407         void SendResults()
408         {
409                 while (results.size())
410                 {
411                         SQLite3Result* res = results[0];
412                         if (res->GetDest())
413                         {
414                                 res->Send();
415                         }
416                         else
417                         {
418                                 /* If the client module is unloaded partway through a query then the provider will set
419                                  * the pointer to NULL. We cannot just cancel the query as the result will still come
420                                  * through at some point...and it could get messy if we play with invalid pointers...
421                                  */
422                                 delete res;
423                         }
424                         results.pop_front();
425                 }
426         }
427
428         void ClearResults()
429         {
430                 while (results.size())
431                 {
432                         SQLite3Result* res = results[0];
433                         delete res;
434                         results.pop_front();
435                 }
436         }
437
438         void SendNotify()
439         {
440                 if (QueueFD < 0)
441                 {
442                         if ((QueueFD = socket(AF_FAMILY, SOCK_STREAM, 0)) == -1)
443                         {
444                                 /* crap, we're out of sockets... */
445                                 return;
446                         }
447
448                         irc::sockets::insp_sockaddr addr;
449
450 #ifdef IPV6
451                         irc::sockets::insp_aton("::1", &addr.sin6_addr);
452                         addr.sin6_family = AF_FAMILY;
453                         addr.sin6_port = htons(listener->GetPort());
454 #else
455                         irc::sockets::insp_inaddr ia;
456                         irc::sockets::insp_aton("127.0.0.1", &ia);
457                         addr.sin_family = AF_FAMILY;
458                         addr.sin_addr = ia;
459                         addr.sin_port = htons(listener->GetPort());
460 #endif
461
462                         if (connect(QueueFD, (sockaddr*)&addr,sizeof(addr)) == -1)
463                         {
464                                 /* wtf, we cant connect to it, but we just created it! */
465                                 return;
466                         }
467                 }
468                 char id = 0;
469                 send(QueueFD, &id, 1, 0);
470         }
471
472 };
473
474
475 class ModuleSQLite3 : public Module
476 {
477  private:
478         ConnMap connections;
479         unsigned long currid;
480
481  public:
482         ModuleSQLite3(InspIRCd* Me)
483         : Module(Me), currid(0)
484         {
485                 ServerInstance->Modules->UseInterface("SQLutils");
486
487                 if (!ServerInstance->Modules->PublishFeature("SQL", this))
488                 {
489                         throw ModuleException("m_sqlite3: Unable to publish feature 'SQL'");
490                 }
491
492                 /* Create a socket on a random port. Let the tcp stack allocate us an available port */
493 #ifdef IPV6
494                 listener = new SQLiteListener(this, ServerInstance, 0, "::1");
495 #else
496                 listener = new SQLiteListener(this, ServerInstance, 0, "127.0.0.1");
497 #endif
498
499                 if (listener->GetFd() == -1)
500                 {
501                         ServerInstance->Modules->DoneWithInterface("SQLutils");
502                         throw ModuleException("m_sqlite3: unable to create ITC pipe");
503                 }
504                 else
505                 {
506                         ServerInstance->Logs->Log("m_sqlite3", DEBUG, "SQLite: Interthread comms port is %d", listener->GetPort());
507                 }
508
509                 ReadConf();
510
511                 ServerInstance->Modules->PublishInterface("SQL", this);
512                 Implementation eventlist[] = { I_OnRequest, I_OnRehash };
513                 ServerInstance->Modules->Attach(eventlist, this, 2);
514         }
515
516         virtual ~ModuleSQLite3()
517         {
518                 ClearQueue();
519                 ClearAllConnections();
520
521                 ServerInstance->SE->DelFd(listener);
522                 ServerInstance->BufferedSocketCull();
523
524                 if (QueueFD >= 0)
525                 {
526                         shutdown(QueueFD, 2);
527                         close(QueueFD);
528                 }
529
530                 if (notifier)
531                 {
532                         ServerInstance->SE->DelFd(notifier);
533                         notifier->Close();
534                         ServerInstance->BufferedSocketCull();
535                 }
536
537                 ServerInstance->Modules->UnpublishInterface("SQL", this);
538                 ServerInstance->Modules->UnpublishFeature("SQL");
539                 ServerInstance->Modules->DoneWithInterface("SQLutils");
540         }
541
542
543         void SendQueue()
544         {
545                 for (ConnMap::iterator iter = connections.begin(); iter != connections.end(); iter++)
546                 {
547                         iter->second->SendResults();
548                 }
549         }
550
551         void ClearQueue()
552         {
553                 for (ConnMap::iterator iter = connections.begin(); iter != connections.end(); iter++)
554                 {
555                         iter->second->ClearResults();
556                 }
557         }
558
559         bool HasHost(const SQLhost &host)
560         {
561                 for (ConnMap::iterator iter = connections.begin(); iter != connections.end(); iter++)
562                 {
563                         if (host == iter->second->GetConfHost())
564                                 return true;
565                 }
566                 return false;
567         }
568
569         bool HostInConf(const SQLhost &h)
570         {
571                 ConfigReader conf(ServerInstance);
572                 for(int i = 0; i < conf.Enumerate("database"); i++)
573                 {
574                         SQLhost host;
575                         host.id         = conf.ReadValue("database", "id", i);
576                         host.host       = conf.ReadValue("database", "hostname", i);
577                         host.port       = conf.ReadInteger("database", "port", i, true);
578                         host.name       = conf.ReadValue("database", "name", i);
579                         host.user       = conf.ReadValue("database", "username", i);
580                         host.pass       = conf.ReadValue("database", "password", i);
581                         if (h == host)
582                                 return true;
583                 }
584                 return false;
585         }
586
587         void ReadConf()
588         {
589                 ClearOldConnections();
590
591                 ConfigReader conf(ServerInstance);
592                 for(int i = 0; i < conf.Enumerate("database"); i++)
593                 {
594                         SQLhost host;
595
596                         host.id         = conf.ReadValue("database", "id", i);
597                         host.host       = conf.ReadValue("database", "hostname", i);
598                         host.port       = conf.ReadInteger("database", "port", i, true);
599                         host.name       = conf.ReadValue("database", "name", i);
600                         host.user       = conf.ReadValue("database", "username", i);
601                         host.pass       = conf.ReadValue("database", "password", i);
602
603                         if (HasHost(host))
604                                 continue;
605
606                         this->AddConn(host);
607                 }
608         }
609
610         void AddConn(const SQLhost& hi)
611         {
612                 if (HasHost(hi))
613                 {
614                         ServerInstance->Logs->Log("m_sqlite3",DEFAULT, "WARNING: A sqlite connection with id: %s already exists. Aborting database open attempt.", hi.id.c_str());
615                         return;
616                 }
617
618                 SQLConn* newconn;
619
620                 newconn = new SQLConn(ServerInstance, this, hi);
621
622                 connections.insert(std::make_pair(hi.id, newconn));
623         }
624
625         void ClearOldConnections()
626         {
627                 ConnMap::iterator iter,safei;
628                 for (iter = connections.begin(); iter != connections.end(); iter++)
629                 {
630                         if (!HostInConf(iter->second->GetConfHost()))
631                         {
632                                 delete iter->second;
633                                 safei = iter;
634                                 --iter;
635                                 connections.erase(safei);
636                         }
637                 }
638         }
639
640         void ClearAllConnections()
641         {
642                 ConnMap::iterator i;
643                 while ((i = connections.begin()) != connections.end())
644                 {
645                         connections.erase(i);
646                         delete i->second;
647                 }
648         }
649
650         virtual void OnRehash(User* user, const std::string &parameter)
651         {
652                 ReadConf();
653         }
654
655         virtual const char* OnRequest(Request* request)
656         {
657                 if(strcmp(SQLREQID, request->GetId()) == 0)
658                 {
659                         SQLrequest* req = (SQLrequest*)request;
660                         ConnMap::iterator iter;
661                         if((iter = connections.find(req->dbid)) != connections.end())
662                         {
663                                 req->id = NewID();
664                                 req->error = iter->second->Query(*req);
665                                 return SQLSUCCESS;
666                         }
667                         else
668                         {
669                                 req->error.Id(SQL_BAD_DBID);
670                                 return NULL;
671                         }
672                 }
673                 return NULL;
674         }
675
676         unsigned long NewID()
677         {
678                 if (currid+1 == 0)
679                         currid++;
680
681                 return ++currid;
682         }
683
684         virtual Version GetVersion()
685         {
686                 return Version("$Id$", VF_VENDOR | VF_SERVICEPROVIDER, API_VERSION);
687         }
688
689 };
690
691 void ResultNotifier::Dispatch()
692 {
693         mod->SendQueue();
694 }
695
696 MODULE_INIT(ModuleSQLite3)