]> git.netwichtig.de Git - user/henk/code/inspircd.git/blob - src/modules/extra/m_pgsql.cpp
c44c66bc8b3e239625db87e51f42bf9657294ccb
[user/henk/code/inspircd.git] / src / modules / extra / m_pgsql.cpp
1 /*       +------------------------------------+
2  *       | Inspire Internet Relay Chat Daemon |
3  *       +------------------------------------+
4  *
5  *  InspIRCd is copyright (C) 2002-2004 ChatSpike-Dev.
6  *                       E-mail:
7  *                <brain@chatspike.net>
8  *                <Craig@chatspike.net>
9  *                <omster@gmail.com>
10  *     
11  * Written by Craig Edwards, Craig McLure, and others.
12  * This program is free but copyrighted software; see
13  *            the file COPYING for details.
14  *
15  * ---------------------------------------------------
16  */
17
18 #include <sstream>
19 #include <string>
20 #include <deque>
21 #include <map>
22 #include <libpq-fe.h>
23
24 #include "users.h"
25 #include "channels.h"
26 #include "modules.h"
27 #include "helperfuncs.h"
28 #include "inspircd.h"
29 #include "configreader.h"
30
31 #include "m_sqlv2.h"
32
33 /* $ModDesc: PostgreSQL Service Provider module for all other m_sql* modules, uses v2 of the SQL API */
34 /* $CompileFlags: -I`pg_config --includedir` */
35 /* $LinkerFlags: -L`pg_config --libdir` -lpq */
36
37 /* UGH, UGH, UGH, UGH, UGH, UGH
38  * I'm having trouble seeing how I
39  * can avoid this. The core-defined
40  * constructors for InspSocket just
41  * aren't suitable...and if I'm
42  * reimplementing them I need this so
43  * I can access the socket engine :\
44  */
45 extern InspIRCd* ServerInstance;
46 InspSocket* socket_ref[MAX_DESCRIPTORS];
47
48 /* Forward declare, so we can have the typedef neatly at the top */
49 class SQLConn;
50
51 typedef std::map<std::string, SQLConn*> ConnMap;
52
/** State of an SQLConn, as used to decide what to do on the next socket event.
 *   CREAD  - still connecting, waiting for a read event
 *   CWRITE - still connecting, waiting for a write event
 *   WREAD  - connected/working, waiting for a read event
 *   WWRITE - connected/working, waiting for a write event
 */
enum SQLstatus
{
	CREAD,
	CWRITE,
	WREAD,
	WWRITE
};
59
/** Remove and return the first element of a deque of strings.
 * An empty deque yields an empty string (and is left unchanged) instead of
 * invoking undefined behaviour through front() on an empty container.
 */
inline std::string pop_front_r(std::deque<std::string> &d)
{
	if(d.empty())
		return "";

	std::string r = d.front();
	d.pop_front();
	return r;
}
67
68 /** QueryQueue, a queue of queries waiting to be executed.
69  * This maintains two queues internally, one for 'priority'
70  * queries and one for less important ones. Each queue has
71  * new queries appended to it and ones to execute are popped
72  * off the front. This keeps them flowing round nicely and no
73  * query should ever get 'stuck' for too long. If there are
74  * queries in the priority queue they will be executed first,
75  * 'unimportant' queries will only be executed when the
76  * priority queue is empty.
77  */
78
79 class QueryQueue
80 {
81 private:
82         std::deque<std::string> priority;       /* The priority queue */
83         std::deque<std::string> normal;         /* The 'normal' queue */
84
85 public:
86         QueryQueue()
87         {
88         
89         }
90         
91         void push_back(const std::string &q, bool pri = false)
92         {
93                 log(DEBUG, "QueryQueue::push_back(): Adding %s query to queue: %s", ((pri) ? "priority" : "non-priority"), q.c_str());
94                 
95                 if(pri)
96                         priority.push_back(q);
97                 else
98                         normal.push_back(q);
99         }
100         
101         inline std::string pop_front()
102         {
103                 std::string res;
104                 
105                 if(priority.size())
106                 {
107                         return pop_front_r(priority);
108                 }
109                 else if(normal.size())
110                 {
111                         return pop_front_r(normal);
112                 }
113                 else
114                 {
115                         return "";      
116                 }
117         }
118         
119         std::pair<int, int> size()
120         {
121                 return std::make_pair(priority.size(), normal.size());
122         }
123 };
124
125 /** SQLConn represents one SQL session.
126  * Each session has its own persistent connection to the database.
127  * This is a subclass of InspSocket so it can easily recieve read/write events from the core socket
128  * engine, unlike the original MySQL module this module does not block. Ever. It gets a mild stabbing
129  * if it dares to.
130  */
131
132 class SQLConn : public InspSocket
133 {
134 private:
135         Server* Srv;                    /* Server* for..uhm..something, maybe */
136         std::string     dbhost; /* Database server hostname */
137         unsigned int    dbport; /* Database server port */
138         std::string     dbname; /* Database name */
139         std::string     dbuser; /* Database username */
140         std::string     dbpass; /* Database password */
141         bool                    ssl;    /* If we should require SSL */
142         PGconn*                 sql;    /* PgSQL database connection handle */
143         SQLstatus               status; /* PgSQL database connection status */
144
145 public:
146         QueryQueue              queue;  /* Queue of queries waiting to be executed on this connection */
147
148         /* This class should only ever be created inside this module, using this constructor, so we don't have to worry about the default ones */
149
150         SQLConn(Server* srv, const std::string &h, unsigned int p, const std::string &d, const std::string &u, const std::string &pwd, bool s)
151         : InspSocket::InspSocket(), Srv(srv), dbhost(h), dbport(p), dbname(d), dbuser(u), dbpass(pwd), ssl(s), sql(NULL), status(CWRITE)
152         {
153                 log(DEBUG, "Creating new PgSQL connection to database %s on %s:%u (%s/%s)", dbname.c_str(), dbhost.c_str(), dbport, dbuser.c_str(), dbpass.c_str());
154
155                 /* Some of this could be reviewed, unsure if I need to fill 'host' etc...
156                  * just copied this over from the InspSocket constructor.
157                  */
158                 strlcpy(this->host, dbhost.c_str(), MAXBUF);
159                 this->port = dbport;
160                 
161                 this->ClosePending = false;
162                 
163                 if(!inet_aton(this->host, &this->addy))
164                 {
165                         /* Its not an ip, spawn the resolver.
166                          * PgSQL doesn't do nonblocking DNS 
167                          * lookups, so we do it for it.
168                          */
169                         
170                         log(DEBUG,"Attempting to resolve %s", this->host);
171                         
172                         this->dns.SetNS(Srv->GetConfig()->DNSServer);
173                         this->dns.ForwardLookupWithFD(this->host, fd);
174                         
175                         this->state = I_RESOLVING;
176                         socket_ref[this->fd] = this;
177                         
178                         return;
179                 }
180                 else
181                 {
182                         log(DEBUG,"No need to resolve %s", this->host);
183                         strlcpy(this->IP, this->host, MAXBUF);
184                         
185                         if(!this->DoConnect())
186                         {
187                                 throw ModuleException("Connect failed");
188                         }
189                 }
190         }
191         
192         ~SQLConn()
193         {
194                 
195         }
196         
197         bool DoResolve()
198         {       
199                 log(DEBUG, "Checking for DNS lookup result");
200                 
201                 if(this->dns.HasResult())
202                 {
203                         std::string res_ip = dns.GetResultIP();
204                         
205                         if(res_ip.length())
206                         {
207                                 log(DEBUG, "Got result: %s", res_ip.c_str());
208                                 
209                                 strlcpy(this->IP, res_ip.c_str(), MAXBUF);
210                                 dbhost = res_ip;
211                                 
212                                 socket_ref[this->fd] = NULL;
213                                 
214                                 return this->DoConnect();
215                         }
216                         else
217                         {
218                                 log(DEBUG, "DNS lookup failed, dying horribly");
219                                 Close();
220                                 return false;
221                         }
222                 }
223                 else
224                 {
225                         log(DEBUG, "No result for lookup yet!");
226                         return true;
227                 }
228         }
229
230         bool DoConnect()
231         {
232                 log(DEBUG, "SQLConn::DoConnect()");
233                 
234                 if(!(sql = PQconnectStart(MkInfoStr().c_str())))
235                 {
236                         log(DEBUG, "Couldn't allocate PGconn structure, aborting: %s", PQerrorMessage(sql));
237                         Close();
238                         return false;
239                 }
240                 
241                 if(PQstatus(sql) == CONNECTION_BAD)
242                 {
243                         log(DEBUG, "PQconnectStart failed: %s", PQerrorMessage(sql));
244                         Close();
245                         return false;
246                 }
247                 
248                 ShowStatus();
249                 
250                 if(PQsetnonblocking(sql, 1) == -1)
251                 {
252                         log(DEBUG, "Couldn't set connection nonblocking: %s", PQerrorMessage(sql));
253                         Close();
254                         return false;
255                 }
256                 
257                 /* OK, we've initalised the connection, now to get it hooked into the socket engine
258                  * and then start polling it.
259                  */
260                 
261                 log(DEBUG, "Old DNS socket: %d", this->fd);
262                 this->fd = PQsocket(sql);
263                 log(DEBUG, "New SQL socket: %d", this->fd);
264                 
265                 if(this->fd <= -1)
266                 {
267                         log(DEBUG, "PQsocket says we have an invalid FD: %d", this->fd);
268                         Close();
269                         return false;
270                 }
271                 
272                 this->state = I_CONNECTING;
273                 ServerInstance->SE->AddFd(this->fd,false,X_ESTAB_MODULE);
274                 socket_ref[this->fd] = this;
275                 
276                 /* Socket all hooked into the engine, now to tell PgSQL to start connecting */
277                 
278                 return DoPoll();
279         }
280         
281         virtual void Close()
282         {
283                 this->fd = -1;
284                 this->state = I_ERROR;
285                 this->OnError(I_ERR_SOCKET);
286                 this->ClosePending = true;
287                 log(DEBUG,"SQLConn::Close");
288                 
289                 if(sql)
290                 {
291                         PQfinish(sql);
292                         sql = NULL;
293                 }
294                 
295                 return;
296         }
297         
298         bool DoPoll()
299         {
300                 switch(PQconnectPoll(sql))
301                 {
302                         case PGRES_POLLING_WRITING:
303                                 log(DEBUG, "PGconnectPoll: PGRES_POLLING_WRITING");
304                                 WantWrite();
305                                 status = CWRITE;
306                                 DoPoll();
307                                 break;
308                         case PGRES_POLLING_READING:
309                                 log(DEBUG, "PGconnectPoll: PGRES_POLLING_READING");
310                                 status = CREAD;
311                                 break;
312                         case PGRES_POLLING_FAILED:
313                                 log(DEBUG, "PGconnectPoll: PGRES_POLLING_FAILED: %s", PQerrorMessage(sql));
314                                 Close();
315                                 return false;
316                         case PGRES_POLLING_OK:
317                                 log(DEBUG, "PGconnectPoll: PGRES_POLLING_OK");
318                                 status = WWRITE;
319                                 break;
320                         default:
321                                 log(DEBUG, "PGconnectPoll: wtf?");
322                                 break;
323                 }
324                 
325                 return true;
326         }
327         
328         bool ProcessData()
329         {
330                 if(PQconsumeInput(sql))
331                 {
332                         log(DEBUG, "PQconsumeInput succeeded");
333                                 
334                         if(PQisBusy(sql))
335                         {
336                                 log(DEBUG, "Still busy processing command though");
337                         }
338                         else
339                         {
340                                 log(DEBUG, "Looks like we have a result to process!");
341                                 
342                                 while(PGresult* result = PQgetResult(sql))
343                                 {
344                                         int cols = PQnfields(result);
345                                         
346                                         log(DEBUG, "Got result! :D");
347                                         log(DEBUG, "%d rows, %d columns checking now what the column names are", PQntuples(result), cols);
348                                                 
349                                         for(int i = 0; i < cols; i++)
350                                         {
351                                                 log(DEBUG, "Column name: %s (%d)", PQfname(result, i));
352                                         }
353                                                 
354                                         PQclear(result);
355                                 }
356                         }
357                         
358                         return true;
359                 }
360                 
361                 log(DEBUG, "PQconsumeInput failed: %s", PQerrorMessage(sql));
362                 return false;
363         }
364         
365         void ShowStatus()
366         {
367                 switch(PQstatus(sql))
368                 {
369                         case CONNECTION_STARTED:
370                                 log(DEBUG, "PQstatus: CONNECTION_STARTED: Waiting for connection to be made.");
371                                 break;
372  
373                         case CONNECTION_MADE:
374                                 log(DEBUG, "PQstatus: CONNECTION_MADE: Connection OK; waiting to send.");
375                                 break;
376                         
377                         case CONNECTION_AWAITING_RESPONSE:
378                                 log(DEBUG, "PQstatus: CONNECTION_AWAITING_RESPONSE: Waiting for a response from the server.");
379                                 break;
380                         
381                         case CONNECTION_AUTH_OK:
382                                 log(DEBUG, "PQstatus: CONNECTION_AUTH_OK: Received authentication; waiting for backend start-up to finish.");
383                                 break;
384                         
385                         case CONNECTION_SSL_STARTUP:
386                                 log(DEBUG, "PQstatus: CONNECTION_SSL_STARTUP: Negotiating SSL encryption.");
387                                 break;
388                         
389                         case CONNECTION_SETENV:
390                                 log(DEBUG, "PQstatus: CONNECTION_SETENV: Negotiating environment-driven parameter settings.");
391                                 break;
392                         
393                         default:
394                                 log(DEBUG, "PQstatus: ???");
395                 }
396         }
397         
398         virtual bool OnDataReady()
399         {
400                 /* Always return true here, false would close the socket - we need to do that ourselves with the pgsql API */
401                 log(DEBUG, "OnDataReady(): status = %s", StatusStr());
402                 
403                 return DoEvent();
404         }
405
406         virtual bool OnWriteReady()
407         {
408                 /* Always return true here, false would close the socket - we need to do that ourselves with the pgsql API */
409                 log(DEBUG, "OnWriteReady(): status = %s", StatusStr());
410                 
411                 return DoEvent();
412         }
413         
414         virtual bool OnConnected()
415         {
416                 log(DEBUG, "OnConnected(): status = %s", StatusStr());
417                 
418                 return DoEvent();
419         }
420         
421         bool DoEvent()
422         {
423                 if((status == CREAD) || (status == CWRITE))
424                 {
425                         DoPoll();
426                 }
427                 else
428                 {
429                         ProcessData();
430                 }
431                 
432                 switch(PQflush(sql))
433                 {
434                         case -1:
435                                 log(DEBUG, "Error flushing write queue: %s", PQerrorMessage(sql));
436                                 break;
437                         case 0:
438                                 log(DEBUG, "Successfully flushed write queue (or there was nothing to write)");
439                                 break;
440                         case 1:
441                                 log(DEBUG, "Not all of the write queue written, triggering write event so we can have another go");
442                                 WantWrite();
443                                 break;
444                 }
445
446                 return true;
447         }
448         
449         std::string MkInfoStr()
450         {                       
451                 /* XXX - This needs nonblocking DNS lookups */
452                 
453                 std::ostringstream conninfo("connect_timeout = '2'");
454                 
455                 if(dbhost.length())
456                         conninfo << " hostaddr = '" << dbhost << "'";
457                 
458                 if(dbport)
459                         conninfo << " port = '" << dbport << "'";
460                 
461                 if(dbname.length())
462                         conninfo << " dbname = '" << dbname << "'";
463                 
464                 if(dbuser.length())
465                         conninfo << " user = '" << dbuser << "'";
466                 
467                 if(dbpass.length())
468                         conninfo << " password = '" << dbpass << "'";
469                 
470                 if(ssl)
471                         conninfo << " sslmode = 'require'";
472                 
473                 return conninfo.str();
474         }
475         
476         const char* StatusStr()
477         {
478                 if(status == CREAD) return "CREAD";
479                 if(status == CWRITE) return "CWRITE";
480                 if(status == WREAD) return "WREAD";
481                 if(status == WWRITE) return "WWRITE";
482                 return "Err...what, erm..BUG!";
483         }
484         
485         bool Query(const std::string &query)
486         {
487                 if((status == WREAD) || (status == WWRITE))
488                 {
489                         if(PQsendQuery(sql, query.c_str()))
490                         {
491                                 log(DEBUG, "Dispatched query: %s", query.c_str());
492                                 return true;
493                         }
494                         else
495                         {
496                                 log(DEBUG, "Failed to dispatch query: %s", PQerrorMessage(sql));
497                                 return false;
498                         }
499                 }
500
501                 log(DEBUG, "Can't query until connection is complete");
502                 return false;
503         }
504
505         virtual void OnClose()
506         {
507                 /* Close PgSQL connection */
508         }
509
510         virtual void OnError(InspSocketError e)
511         {
512                 /* Unsure if we need this, we should be reading/writing via the PgSQL API rather than the insp one... */
513         }
514         
515         virtual void OnTimeout()
516         {
517                 /* Unused, I think */
518         }
519 };
520
521 class ModulePgSQL : public Module
522 {
523 private:
524         Server* Srv;
525         ConnMap connections;
526
527 public:
528         ModulePgSQL(Server* Me)
529         : Module::Module(Me), Srv(Me)
530         {
531                 log(DEBUG, "%s 'SQL' feature", Srv->PublishFeature("SQL", this) ? "Published" : "Couldn't publish");
532                 log(DEBUG, "%s 'PgSQL' feature", Srv->PublishFeature("PgSQL", this) ? "Published" : "Couldn't publish");
533
534                 OnRehash("");
535         }
536
537         void Implements(char* List)
538         {
539                 List[I_OnRequest] = List[I_OnRehash] = List[I_OnUserRegister] = List[I_OnCheckReady] = List[I_OnUserDisconnect] = 1;
540         }
541
542         virtual void OnRehash(const std::string &parameter)
543         {
544                 ConfigReader conf;
545                 
546                 /* Delete all the SQLConn objects in the connection lists,
547                  * this will call their destructors where they can handle
548                  * closing connections and such.
549                  */
550                 for(ConnMap::iterator iter = connections.begin(); iter != connections.end(); iter++)
551                 {
552                         DELETE(iter->second);
553                 }
554                 
555                 /* Empty out our list of connections */
556                 connections.clear();
557
558                 for(int i = 0; i < conf.Enumerate("database"); i++)
559                 {
560                         std::string id;
561                         SQLConn* newconn;
562                         
563                         id = conf.ReadValue("database", "id", i);
564                         newconn = new SQLConn(Srv,      conf.ReadValue("database", "hostname", i),
565                                                                                 conf.ReadInteger("database", "port", i, true),
566                                                                                 conf.ReadValue("database", "name", i),
567                                                                                 conf.ReadValue("database", "username", i),
568                                                                                 conf.ReadValue("database", "password", i),
569                                                                                 conf.ReadFlag("database", "ssl", i));
570                         
571                         connections.insert(std::make_pair(id, newconn));
572                 }       
573         }
574         
575         virtual char* OnRequest(Request* request)
576         {
577                 if(strcmp(SQLREQID, request->GetData()) == 0)
578                 {
579                         SQLrequest* req = (SQLrequest*)request;
580                         ConnMap::iterator iter;
581                 
582                         log(DEBUG, "Got query: '%s' on id '%s'", req->query.c_str(), req->dbid.c_str());
583
584                         if((iter = connections.find(req->dbid)) != connections.end())
585                         {
586                                 /* Execute query */
587                                 iter->second->queue.push_back(req->query, req->pri);
588                                 
589                                 return SQLSUCCESS;
590                         }
591                         else
592                         {
593                                 req->error.Id(BAD_DBID);
594                                 return NULL;
595                         }
596                 }
597
598                 log(DEBUG, "Got unsupported API version string: %s", request->GetData());
599                 
600                 return NULL;
601         }
602                 
603         virtual Version GetVersion()
604         {
605                 return Version(1, 0, 0, 0, VF_VENDOR|VF_SERVICEPROVIDER);
606         }
607         
608         virtual ~ModulePgSQL()
609         {
610         }       
611 };
612
613 class ModulePgSQLFactory : public ModuleFactory
614 {
615  public:
616         ModulePgSQLFactory()
617         {
618         }
619         
620         ~ModulePgSQLFactory()
621         {
622         }
623         
624         virtual Module * CreateModule(Server* Me)
625         {
626                 return new ModulePgSQL(Me);
627         }
628 };
629
630
631 extern "C" void * init_module( void )
632 {
633         return new ModulePgSQLFactory;
634 }