]> git.netwichtig.de Git - user/henk/code/inspircd.git/blob - src/modules/extra/m_sqlite3.cpp
Don't need to send anything on the notifier socket.
[user/henk/code/inspircd.git] / src / modules / extra / m_sqlite3.cpp
1 /*       +------------------------------------+
2  *       | Inspire Internet Relay Chat Daemon |
3  *       +------------------------------------+
4  *
5  *  InspIRCd: (C) 2002-2007 InspIRCd Development Team
6  * See: http://www.inspircd.org/wiki/index.php/Credits
7  *
8  * This program is free but copyrighted software; see
9  *            the file COPYING for details.
10  *
11  * ---------------------------------------------------
12  */
13
#include <string>
#include <deque>
#include <map>
#include <vector>
#include <cstring>
#include <unistd.h>
#include <sqlite3.h>

#include "users.h"
#include "channels.h"
#include "modules.h"
#include "inspircd.h"
#include "configreader.h"

#include "m_sqlv2.h"
26
27 /* $ModDesc: sqlite3 provider */
28 /* $CompileFlags: pkgconfincludes("sqlite3","/sqlite3.h","") */
29 /* $LinkerFlags: pkgconflibs("sqlite3","/libsqlite3.so","-lsqlite3") */
30 /* $ModDep: m_sqlv2.h */
31
32
33 class SQLConn;
34 class SQLite3Result;
35 class ResultNotifier;
36
37 typedef std::map<std::string, SQLConn*> ConnMap;
38 typedef std::deque<classbase*> paramlist;
39 typedef std::deque<SQLite3Result*> ResultQueue;
40
41 ResultNotifier* resultnotify = NULL;
42
43
44 class ResultNotifier : public InspSocket
45 {
46         Module* mod;
47         insp_sockaddr sock_us;
48         socklen_t uslen;
49
50  public:
51         /* Create a socket on a random port. Let the tcp stack allocate us an available port */
52 #ifdef IPV6
53         ResultNotifier(InspIRCd* SI, Module* m) : InspSocket(SI, "::1", 0, true, 3000), mod(m)
54 #else
55         ResultNotifier(InspIRCd* SI, Module* m) : InspSocket(SI, "127.0.0.1", 0, true, 3000), mod(m)
56 #endif
57         {
58                 uslen = sizeof(sock_us);
59                 if (getsockname(this->fd,(sockaddr*)&sock_us,&uslen))
60                 {
61                         throw ModuleException("Could not create random listening port on localhost");
62                 }
63         }
64
65         ResultNotifier(InspIRCd* SI, Module* m, int newfd, char* ip) : InspSocket(SI, newfd, ip), mod(m)
66         {
67                 Instance->Log(DEBUG,"Constructor of new socket");
68         }
69
70         /* Using getsockname and ntohs, we can determine which port number we were allocated */
71         int GetPort()
72         {
73 #ifdef IPV6
74                 return ntohs(sock_us.sin6_port);
75 #else
76                 return ntohs(sock_us.sin_port);
77 #endif
78         }
79
80         virtual int OnIncomingConnection(int newsock, char* ip)
81         {
82                 Instance->Log(DEBUG,"Inbound connection on fd %d!",newsock);
83                 Dispatch();
84                 return false;
85         }
86
87         void Dispatch();
88 };
89
90
91 class SQLite3Result : public SQLresult
92 {
93   private:
94         int currentrow;
95         int rows;
96         int cols;
97
98         std::vector<std::string> colnames;
99         std::vector<SQLfieldList> fieldlists;
100         SQLfieldList emptyfieldlist;
101
102         SQLfieldList* fieldlist;
103         SQLfieldMap* fieldmap;
104
105   public:
106         SQLite3Result(Module* self, Module* to, unsigned int id)
107         : SQLresult(self, to, id), currentrow(0), rows(0), cols(0), fieldlist(NULL), fieldmap(NULL)
108         {
109         }
110
111         ~SQLite3Result()
112         {
113         }
114
115         void AddRow(int colsnum, char **data, char **colname)
116         {
117                 colnames.clear();
118                 cols = colsnum;
119                 for (int i = 0; i < colsnum; i++)
120                 {
121                         fieldlists.resize(fieldlists.size()+1);
122                         colnames.push_back(colname[i]);
123                         SQLfield sf(data[i] ? data[i] : "", data[i] ? false : true);
124                         fieldlists[rows].push_back(sf);
125                 }
126                 rows++;
127         }
128
129         virtual int Rows()
130         {
131                 return rows;
132         }
133
134         virtual int Cols()
135         {
136                 return cols;
137         }
138
139         virtual std::string ColName(int column)
140         {
141                 if (column < (int)colnames.size())
142                 {
143                         return colnames[column];
144                 }
145                 else
146                 {
147                         throw SQLbadColName();
148                 }
149                 return "";
150         }
151
152         virtual int ColNum(const std::string &column)
153         {
154                 for (unsigned int i = 0; i < colnames.size(); i++)
155                 {
156                         if (column == colnames[i])
157                                 return i;
158                 }
159                 throw SQLbadColName();
160                 return 0;
161         }
162
163         virtual SQLfield GetValue(int row, int column)
164         {
165                 if ((row >= 0) && (row < rows) && (column >= 0) && (column < Cols()))
166                 {
167                         return fieldlists[row][column];
168                 }
169
170                 throw SQLbadColName();
171
172                 /* XXX: We never actually get here because of the throw */
173                 return SQLfield("",true);
174         }
175
176         virtual SQLfieldList& GetRow()
177         {
178                 if (currentrow < rows)
179                         return fieldlists[currentrow];
180                 else
181                         return emptyfieldlist;
182         }
183
184         virtual SQLfieldMap& GetRowMap()
185         {
186                 /* In an effort to reduce overhead we don't actually allocate the map
187                  * until the first time it's needed...so...
188                  */
189                 if(fieldmap)
190                 {
191                         fieldmap->clear();
192                 }
193                 else
194                 {
195                         fieldmap = new SQLfieldMap;
196                 }
197
198                 if (currentrow < rows)
199                 {
200                         for (int i = 0; i < Cols(); i++)
201                         {
202                                 fieldmap->insert(std::make_pair(ColName(i), GetValue(currentrow, i)));
203                         }
204                         currentrow++;
205                 }
206
207                 return *fieldmap;
208         }
209
210         virtual SQLfieldList* GetRowPtr()
211         {
212                 fieldlist = new SQLfieldList();
213
214                 if (currentrow < rows)
215                 {
216                         for (int i = 0; i < Rows(); i++)
217                         {
218                                 fieldlist->push_back(fieldlists[currentrow][i]);
219                         }
220                         currentrow++;
221                 }
222                 return fieldlist;
223         }
224
225         virtual SQLfieldMap* GetRowMapPtr()
226         {
227                 fieldmap = new SQLfieldMap();
228
229                 if (currentrow < rows)
230                 {
231                         for (int i = 0; i < Cols(); i++)
232                         {
233                                 fieldmap->insert(std::make_pair(colnames[i],GetValue(currentrow, i)));
234                         }
235                         currentrow++;
236                 }
237
238                 return fieldmap;
239         }
240
241         virtual void Free(SQLfieldMap* fm)
242         {
243                 delete fm;
244         }
245
246         virtual void Free(SQLfieldList* fl)
247         {
248                 delete fl;
249         }
250
251
252 };
253
254 class SQLConn : public classbase
255 {
256   private:
257         ResultQueue results;
258         InspIRCd* Instance;
259         Module* mod;
260         SQLhost host;
261         sqlite3* conn;
262
263   public:
264         SQLConn(InspIRCd* SI, Module* m, const SQLhost& hi)
265         : Instance(SI), mod(m), host(hi)
266         {
267                 int result;
268                 if ((result = OpenDB()) == SQLITE_OK)
269                 {
270                         Instance->Log(DEBUG, "Opened sqlite DB: " + host.host);
271                 }
272                 else
273                 {
274                         Instance->Log(DEFAULT, "WARNING: Could not open DB with id: " + host.id);
275                         CloseDB();
276                 }
277         }
278
279         ~SQLConn()
280         {
281                 CloseDB();
282         }
283
284         SQLerror Query(SQLrequest &req)
285         {
286                 /* Pointer to the buffer we screw around with substitution in */
287                 char* query;
288
289                 /* Pointer to the current end of query, where we append new stuff */
290                 char* queryend;
291
292                 /* Total length of the unescaped parameters */
293                 unsigned long paramlen;
294
295                 /* Total length of query, used for binary-safety in mysql_real_query */
296                 unsigned long querylength = 0;
297
298                 paramlen = 0;
299                 for(ParamL::iterator i = req.query.p.begin(); i != req.query.p.end(); i++)
300                 {
301                         paramlen += i->size();
302                 }
303
304                 /* To avoid a lot of allocations, allocate enough memory for the biggest the escaped query could possibly be.
305                  * sizeofquery + (totalparamlength*2) + 1
306                  *
307                  * The +1 is for null-terminating the string for mysql_real_escape_string
308                  */
309
310                 query = new char[req.query.q.length() + (paramlen*2) + 1];
311                 queryend = query;
312
313                 for(unsigned long i = 0; i < req.query.q.length(); i++)
314                 {
315                         if(req.query.q[i] == '?')
316                         {
317                                 if(req.query.p.size())
318                                 {
319                                         char* escaped;
320                                         escaped = sqlite3_mprintf("%q", req.query.p.front().c_str());
321                                         for (char* n = escaped; *n; n++)
322                                         {
323                                                 *queryend = *n;
324                                                 queryend++;
325                                         }
326                                         sqlite3_free(escaped);
327                                         req.query.p.pop_front();
328                                 }
329                                 else
330                                         break;
331                         }
332                         else
333                         {
334                                 *queryend = req.query.q[i];
335                                 queryend++;
336                         }
337                         querylength++;
338                 }
339                 *queryend = 0;
340                 req.query.q = query;
341
342 //              Instance->Log(DEBUG, "<******> Doing query: " + ConvToStr(req.query.q.data()));
343
344                 SQLite3Result* res = new SQLite3Result(mod, req.GetSource(), req.id);
345                 res->dbid = host.id;
346                 res->query = req.query.q;
347                 paramlist params;
348                 params.push_back(this);
349                 params.push_back(res);
350
351                 char *errmsg = 0;
352                 if (sqlite3_exec(conn, req.query.q.data(), QueryResult, &params, &errmsg) != SQLITE_OK)
353                 {
354                         Instance->Log(DEBUG, "Query failed: " + ConvToStr(errmsg));
355                         sqlite3_free(errmsg);
356                         delete[] query;
357                         delete res;
358                         return SQLerror(QSEND_FAIL, ConvToStr(errmsg));
359                 }
360                 Instance->Log(DEBUG, "Dispatched query successfully. ID: %d resulting rows %d", req.id, res->Rows());
361                 delete[] query;
362
363                 results.push_back(res);
364                 SendNotify();
365                 return SQLerror();
366         }
367
368         static int QueryResult(void *params, int argc, char **argv, char **azColName)
369         {
370                 paramlist* p = (paramlist*)params;
371                 ((SQLConn*)(*p)[0])->ResultReady(((SQLite3Result*)(*p)[1]), argc, argv, azColName);
372                 return 0;
373         }
374
375         void ResultReady(SQLite3Result *res, int cols, char **data, char **colnames)
376         {
377                 res->AddRow(cols, data, colnames);
378         }
379
380         int OpenDB()
381         {
382                 return sqlite3_open(host.host.c_str(), &conn);
383         }
384
385         void CloseDB()
386         {
387                 sqlite3_interrupt(conn);
388                 sqlite3_close(conn);
389         }
390
391         SQLhost GetConfHost()
392         {
393                 return host;
394         }
395
396         void SendResults()
397         {
398                 while (results.size())
399                 {
400                         SQLite3Result* res = results[0];
401                         if (res->GetDest())
402                         {
403                                 res->Send();
404                         }
405                         else
406                         {
407                                 /* If the client module is unloaded partway through a query then the provider will set
408                                  * the pointer to NULL. We cannot just cancel the query as the result will still come
409                                  * through at some point...and it could get messy if we play with invalid pointers...
410                                  */
411                                 Instance->Log(DEBUG, "Looks like we're handling a zombie query from a module which unloaded before it got a result..fun. ID: " + ConvToStr(res->GetId()));
412                                 delete res;
413                         }
414                         results.pop_front();
415                 }
416         }
417
418         void ClearResults()
419         {
420                 while (results.size())
421                 {
422                         SQLite3Result* res = results[0];
423                         delete res;
424                         results.pop_front();
425                 }
426         }
427
428         void SendNotify()
429         {
430                 int QueueFD;
431                 if ((QueueFD = socket(AF_FAMILY, SOCK_STREAM, 0)) == -1)
432                 {
433                         /* crap, we're out of sockets... */
434                         return;
435                 }
436
437                 insp_sockaddr addr;
438
439 #ifdef IPV6
440                 insp_aton("::1", &addr.sin6_addr);
441                 addr.sin6_family = AF_FAMILY;
442                 addr.sin6_port = htons(resultnotify->GetPort());
443 #else
444                 insp_inaddr ia;
445                 insp_aton("127.0.0.1", &ia);
446                 addr.sin_family = AF_FAMILY;
447                 addr.sin_addr = ia;
448                 addr.sin_port = htons(resultnotify->GetPort());
449 #endif
450
451                 if (connect(QueueFD, (sockaddr*)&addr,sizeof(addr)) == -1)
452                 {
453                         /* wtf, we cant connect to it, but we just created it! */
454                         return;
455                 }
456         }
457
458 };
459
460
461 class ModuleSQLite3 : public Module
462 {
463   private:
464         ConnMap connections;
465         unsigned long currid;
466
467   public:
468         ModuleSQLite3(InspIRCd* Me)
469         : Module::Module(Me), currid(0)
470         {
471                 ServerInstance->UseInterface("SQLutils");
472
473                 if (!ServerInstance->PublishFeature("SQL", this))
474                 {
475                         throw ModuleException("m_sqlite3: Unable to publish feature 'SQL'");
476                 }
477
478                 resultnotify = new ResultNotifier(ServerInstance, this);
479                 ServerInstance->Log(DEBUG,"Bound notifier to 127.0.0.1:%d",resultnotify->GetPort());
480
481                 ReadConf();
482
483                 ServerInstance->PublishInterface("SQL", this);
484         }
485
486         virtual ~ModuleSQLite3()
487         {
488                 ClearQueue();
489                 ClearAllConnections();
490                 resultnotify->SetFd(-1);
491                 resultnotify->state = I_ERROR;
492                 resultnotify->OnError(I_ERR_SOCKET);
493                 resultnotify->ClosePending = true;
494                 if (!ServerInstance->SE->DelFd(resultnotify))
495                 {
496                         ServerInstance->Log(DEBUG, "m_sqlite3: unable to remove notifier from socket engine!");
497                 }
498                 delete resultnotify;
499                 ServerInstance->UnpublishInterface("SQL", this);
500                 ServerInstance->UnpublishFeature("SQL");
501                 ServerInstance->DoneWithInterface("SQLutils");
502         }
503
504         void Implements(char* List)
505         {
506                 List[I_OnRequest] = List[I_OnRequest] = 1;
507         }
508
509         void SendQueue()
510         {
511                 for (ConnMap::iterator iter = connections.begin(); iter != connections.end(); iter++)
512                 {
513                         iter->second->SendResults();
514                 }
515         }
516
517         void ClearQueue()
518         {
519                 for (ConnMap::iterator iter = connections.begin(); iter != connections.end(); iter++)
520                 {
521                         iter->second->ClearResults();
522                 }
523         }
524
525         bool HasHost(const SQLhost &host)
526         {
527                 for (ConnMap::iterator iter = connections.begin(); iter != connections.end(); iter++)
528                 {
529                         if (host == iter->second->GetConfHost())
530                                 return true;
531                 }
532                 return false;
533         }
534
535         bool HostInConf(const SQLhost &h)
536         {
537                 ConfigReader conf(ServerInstance);
538                 for(int i = 0; i < conf.Enumerate("database"); i++)
539                 {
540                         SQLhost host;
541                         host.id         = conf.ReadValue("database", "id", i);
542                         host.host       = conf.ReadValue("database", "hostname", i);
543                         host.port       = conf.ReadInteger("database", "port", i, true);
544                         host.name       = conf.ReadValue("database", "name", i);
545                         host.user       = conf.ReadValue("database", "username", i);
546                         host.pass       = conf.ReadValue("database", "password", i);
547                         host.ssl        = conf.ReadFlag("database", "ssl", "0", i);
548                         if (h == host)
549                                 return true;
550                 }
551                 return false;
552         }
553
554         void ReadConf()
555         {
556                 ClearOldConnections();
557
558                 ConfigReader conf(ServerInstance);
559                 for(int i = 0; i < conf.Enumerate("database"); i++)
560                 {
561                         SQLhost host;
562
563                         host.id         = conf.ReadValue("database", "id", i);
564                         host.host       = conf.ReadValue("database", "hostname", i);
565                         host.port       = conf.ReadInteger("database", "port", i, true);
566                         host.name       = conf.ReadValue("database", "name", i);
567                         host.user       = conf.ReadValue("database", "username", i);
568                         host.pass       = conf.ReadValue("database", "password", i);
569                         host.ssl        = conf.ReadFlag("database", "ssl", "0", i);
570
571                         if (HasHost(host))
572                                 continue;
573
574                         this->AddConn(host);
575                 }
576         }
577
578         void AddConn(const SQLhost& hi)
579         {
580                 if (HasHost(hi))
581                 {
582                         ServerInstance->Log(DEFAULT, "WARNING: A sqlite connection with id: %s already exists. Aborting database open attempt.", hi.id.c_str());
583                         return;
584                 }
585
586                 SQLConn* newconn;
587
588                 newconn = new SQLConn(ServerInstance, this, hi);
589
590                 connections.insert(std::make_pair(hi.id, newconn));
591         }
592
593         void ClearOldConnections()
594         {
595                 ConnMap::iterator iter,safei;
596                 for (iter = connections.begin(); iter != connections.end(); iter++)
597                 {
598                         if (!HostInConf(iter->second->GetConfHost()))
599                         {
600                                 DELETE(iter->second);
601                                 safei = iter;
602                                 --iter;
603                                 connections.erase(safei);
604                         }
605                 }
606         }
607
608         void ClearAllConnections()
609         {
610                 ConnMap::iterator i;
611                 while ((i = connections.begin()) != connections.end())
612                 {
613                         connections.erase(i);
614                         DELETE(i->second);
615                 }
616         }
617
618         virtual void OnRehash(userrec* user, const std::string &parameter)
619         {
620                 ReadConf();
621         }
622
623         virtual char* OnRequest(Request* request)
624         {
625                 if(strcmp(SQLREQID, request->GetId()) == 0)
626                 {
627                         SQLrequest* req = (SQLrequest*)request;
628                         ConnMap::iterator iter;
629                         ServerInstance->Log(DEBUG, "Got query: '%s' with %d replacement parameters on id '%s'", req->query.q.c_str(), req->query.p.size(), req->dbid.c_str());
630                         if((iter = connections.find(req->dbid)) != connections.end())
631                         {
632                                 req->id = NewID();
633                                 req->error = iter->second->Query(*req);
634                                 return SQLSUCCESS;
635                         }
636                         else
637                         {
638                                 req->error.Id(BAD_DBID);
639                                 return NULL;
640                         }
641                 }
642                 ServerInstance->Log(DEBUG, "Got unsupported API version string: %s", request->GetId());
643                 return NULL;
644         }
645
646         unsigned long NewID()
647         {
648                 if (currid+1 == 0)
649                         currid++;
650
651                 return ++currid;
652         }
653
654         virtual Version GetVersion()
655         {
656                 return Version(1,1,0,0,VF_VENDOR|VF_SERVICEPROVIDER,API_VERSION);
657         }
658
659 };
660
661 void ResultNotifier::Dispatch()
662 {
663         ((ModuleSQLite3*)mod)->SendQueue();
664 }
665
666 class ModuleSQLite3Factory : public ModuleFactory
667 {
668   public:
669         ModuleSQLite3Factory()
670         {
671         }
672
673         ~ModuleSQLite3Factory()
674         {
675         }
676
677         virtual Module * CreateModule(InspIRCd* Me)
678         {
679                 return new ModuleSQLite3(Me);
680         }
681 };
682
683 extern "C" void * init_module( void )
684 {
685         return new ModuleSQLite3Factory;
686 }