]> git.netwichtig.de Git - user/henk/code/inspircd.git/blob - src/modules/extra/m_pgsql.cpp
6567aa787cc43433773e165ee5825372805e105c
[user/henk/code/inspircd.git] / src / modules / extra / m_pgsql.cpp
1 /*       +------------------------------------+
2  *       | Inspire Internet Relay Chat Daemon |
3  *       +------------------------------------+
4  *
5  *  InspIRCd is copyright (C) 2002-2004 ChatSpike-Dev.
6  *                       E-mail:
7  *                <brain@chatspike.net>
8  *                <Craig@chatspike.net>
9  *                <omster@gmail.com>
10  *     
11  * Written by Craig Edwards, Craig McLure, and others.
12  * This program is free but copyrighted software; see
13  *            the file COPYING for details.
14  *
15  * ---------------------------------------------------
16  */
17
18 #include <sstream>
19 #include <string>
20 #include <deque>
21 #include <map>
22 #include <libpq-fe.h>
23
24 #include "users.h"
25 #include "channels.h"
26 #include "modules.h"
27 #include "helperfuncs.h"
28 #include "inspircd.h"
29 #include "configreader.h"
30
31 #include "m_sqlv2.h"
32
33 /* $ModDesc: PostgreSQL Service Provider module for all other m_sql* modules, uses v2 of the SQL API */
34 /* $CompileFlags: -I`pg_config --includedir` */
35 /* $LinkerFlags: -L`pg_config --libdir` -lpq */
36
37 /* UGH, UGH, UGH, UGH, UGH, UGH
38  * I'm having trouble seeing how I
39  * can avoid this. The core-defined
40  * constructors for InspSocket just
41  * aren't suitable...and if I'm
42  * reimplementing them I need this so
43  * I can access the socket engine :\
44  */
45 extern InspIRCd* ServerInstance;
46 InspSocket* socket_ref[MAX_DESCRIPTORS];
47
48 /* Forward declare, so we can have the typedef neatly at the top */
49 class SQLConn;
50
51 typedef std::map<std::string, SQLConn*> ConnMap;
52
53 /* CREAD,       Connecting and wants read event
54  * CWRITE,      Connecting and wants write event
55  * WREAD,       Connected/Working and wants read event
56  * WWRITE,      Connected/Working and wants write event
57  */
58 enum SQLstatus { CREAD, CWRITE, WREAD, WWRITE };
59
60 /** QueryQueue, a queue of queries waiting to be executed.
61  * This maintains two queues internally, one for 'priority'
62  * queries and one for less important ones. Each queue has
63  * new queries appended to it and ones to execute are popped
64  * off the front. This keeps them flowing round nicely and no
65  * query should ever get 'stuck' for too long. If there are
66  * queries in the priority queue they will be executed first,
67  * 'unimportant' queries will only be executed when the
68  * priority queue is empty.
69  *
70  * We store lists of SQLrequest's here, by value as we want to avoid storing
71  * any data allocated inside the client module (in case that module is unloaded
72  * while the query is in progress).
73  *
74  * Because we want to work on the current SQLrequest in-situ, we need a way
75  * of accessing the request we are currently processing, QueryQueue::front(),
76  * but that call needs to always return the same request until that request
77  * is removed from the queue, this is what the 'which' variable is. New queries are
78  * always added to the back of one of the two queues, but if when front()
79  * is first called then the priority queue is empty then front() will return
80  * a query from the normal queue, but if a query is then added to the priority
81  * queue then front() must continue to return the front of the *normal* queue
82  * until pop() is called.
83  */
84
85 class QueryQueue : public classbase
86 {
87 private:
88         std::deque<SQLrequest> priority;        /* The priority queue */
89         std::deque<SQLrequest> normal;  /* The 'normal' queue */
90         enum { PRI, NOR, NON } which;   /* Which queue the currently active element is at the front of */
91
92 public:
93         QueryQueue()
94         : which(NON)
95         {
96         }
97         
98         void push(const SQLrequest &q)
99         {
100                 log(DEBUG, "QueryQueue::push(): Adding %s query to queue: %s", ((q.pri) ? "priority" : "non-priority"), q.query.c_str());
101                 
102                 if(q.pri)
103                         priority.push_back(q);
104                 else
105                         normal.push_back(q);
106         }
107         
108         void pop()
109         {
110                 if((which == PRI) && priority.size())
111                 {
112                         priority.pop_front();
113                 }
114                 else if((which == NOR) && normal.size())
115                 {
116                         normal.pop_front();
117                 }
118                 
119                 /* Reset this */
120                 which = NON;
121                 
122                 /* Silently do nothing if there was no element to pop() */
123         }
124         
125         SQLrequest& front()
126         {
127                 switch(which)
128                 {
129                         case PRI:
130                                 return priority.front();
131                         case NOR:
132                                 return normal.front();
133                         default:
134                                 if(priority.size())
135                                 {
136                                         which = PRI;
137                                         return priority.front();
138                                 }
139                                 
140                                 if(normal.size())
141                                 {
142                                         which = NOR;
143                                         return normal.front();
144                                 }
145                                 
146                                 /* This will probably result in a segfault,
147                                  * but the caller should have checked totalsize()
148                                  * first so..meh - moron :p
149                                  */
150                                 
151                                 return priority.front();
152                 }
153         }
154         
155         std::pair<int, int> size()
156         {
157                 return std::make_pair(priority.size(), normal.size());
158         }
159         
160         int totalsize()
161         {
162                 return priority.size() + normal.size();
163         }
164 };
165
166 /** SQLConn represents one SQL session.
167  * Each session has its own persistent connection to the database.
168  * This is a subclass of InspSocket so it can easily recieve read/write events from the core socket
169  * engine, unlike the original MySQL module this module does not block. Ever. It gets a mild stabbing
170  * if it dares to.
171  */
172
173 class SQLConn : public InspSocket
174 {
175 private:
176         Server* Srv;                    /* Server* for..uhm..something, maybe */
177         std::string     dbhost; /* Database server hostname */
178         unsigned int    dbport; /* Database server port */
179         std::string     dbname; /* Database name */
180         std::string     dbuser; /* Database username */
181         std::string     dbpass; /* Database password */
182         bool                    ssl;    /* If we should require SSL */
183         PGconn*                 sql;    /* PgSQL database connection handle */
184         SQLstatus               status; /* PgSQL database connection status */
185         bool                    qinprog;/* If there is currently a query in progress */
186         QueryQueue              queue;  /* Queue of queries waiting to be executed on this connection */
187
188 public:
189
190         /* This class should only ever be created inside this module, using this constructor, so we don't have to worry about the default ones */
191
192         SQLConn(Server* srv, const std::string &h, unsigned int p, const std::string &d, const std::string &u, const std::string &pwd, bool s)
193         : InspSocket::InspSocket(), Srv(srv), dbhost(h), dbport(p), dbname(d), dbuser(u), dbpass(pwd), ssl(s), sql(NULL), status(CWRITE), qinprog(false)
194         {
195                 log(DEBUG, "Creating new PgSQL connection to database %s on %s:%u (%s/%s)", dbname.c_str(), dbhost.c_str(), dbport, dbuser.c_str(), dbpass.c_str());
196
197                 /* Some of this could be reviewed, unsure if I need to fill 'host' etc...
198                  * just copied this over from the InspSocket constructor.
199                  */
200                 strlcpy(this->host, dbhost.c_str(), MAXBUF);
201                 this->port = dbport;
202                 
203                 this->ClosePending = false;
204                 
205                 if(!inet_aton(this->host, &this->addy))
206                 {
207                         /* Its not an ip, spawn the resolver.
208                          * PgSQL doesn't do nonblocking DNS 
209                          * lookups, so we do it for it.
210                          */
211                         
212                         log(DEBUG,"Attempting to resolve %s", this->host);
213                         
214                         this->dns.SetNS(Srv->GetConfig()->DNSServer);
215                         this->dns.ForwardLookupWithFD(this->host, fd);
216                         
217                         this->state = I_RESOLVING;
218                         socket_ref[this->fd] = this;
219                         
220                         return;
221                 }
222                 else
223                 {
224                         log(DEBUG,"No need to resolve %s", this->host);
225                         strlcpy(this->IP, this->host, MAXBUF);
226                         
227                         if(!this->DoConnect())
228                         {
229                                 throw ModuleException("Connect failed");
230                         }
231                 }
232         }
233         
234         ~SQLConn()
235         {
236                 Close();
237         }
238         
239         bool DoResolve()
240         {       
241                 log(DEBUG, "Checking for DNS lookup result");
242                 
243                 if(this->dns.HasResult())
244                 {
245                         std::string res_ip = dns.GetResultIP();
246                         
247                         if(res_ip.length())
248                         {
249                                 log(DEBUG, "Got result: %s", res_ip.c_str());
250                                 
251                                 strlcpy(this->IP, res_ip.c_str(), MAXBUF);
252                                 dbhost = res_ip;
253                                 
254                                 socket_ref[this->fd] = NULL;
255                                 
256                                 return this->DoConnect();
257                         }
258                         else
259                         {
260                                 log(DEBUG, "DNS lookup failed, dying horribly");
261                                 Close();
262                                 return false;
263                         }
264                 }
265                 else
266                 {
267                         log(DEBUG, "No result for lookup yet!");
268                         return true;
269                 }
270         }
271
272         bool DoConnect()
273         {
274                 log(DEBUG, "SQLConn::DoConnect()");
275                 
276                 if(!(sql = PQconnectStart(MkInfoStr().c_str())))
277                 {
278                         log(DEBUG, "Couldn't allocate PGconn structure, aborting: %s", PQerrorMessage(sql));
279                         Close();
280                         return false;
281                 }
282                 
283                 if(PQstatus(sql) == CONNECTION_BAD)
284                 {
285                         log(DEBUG, "PQconnectStart failed: %s", PQerrorMessage(sql));
286                         Close();
287                         return false;
288                 }
289                 
290                 ShowStatus();
291                 
292                 if(PQsetnonblocking(sql, 1) == -1)
293                 {
294                         log(DEBUG, "Couldn't set connection nonblocking: %s", PQerrorMessage(sql));
295                         Close();
296                         return false;
297                 }
298                 
299                 /* OK, we've initalised the connection, now to get it hooked into the socket engine
300                  * and then start polling it.
301                  */
302                 
303                 log(DEBUG, "Old DNS socket: %d", this->fd);
304                 this->fd = PQsocket(sql);
305                 log(DEBUG, "New SQL socket: %d", this->fd);
306                 
307                 if(this->fd <= -1)
308                 {
309                         log(DEBUG, "PQsocket says we have an invalid FD: %d", this->fd);
310                         Close();
311                         return false;
312                 }
313                 
314                 this->state = I_CONNECTING;
315                 ServerInstance->SE->AddFd(this->fd,false,X_ESTAB_MODULE);
316                 socket_ref[this->fd] = this;
317                 
318                 /* Socket all hooked into the engine, now to tell PgSQL to start connecting */
319                 
320                 return DoPoll();
321         }
322         
323         virtual void Close()
324         {
325                 log(DEBUG,"SQLConn::Close");
326                 
327                 if(this->fd > 01)
328                         socket_ref[this->fd] = NULL;
329                 this->fd = -1;
330                 this->state = I_ERROR;
331                 this->OnError(I_ERR_SOCKET);
332                 this->ClosePending = true;
333                 
334                 if(sql)
335                 {
336                         PQfinish(sql);
337                         sql = NULL;
338                 }
339                 
340                 return;
341         }
342         
343         bool DoPoll()
344         {
345                 switch(PQconnectPoll(sql))
346                 {
347                         case PGRES_POLLING_WRITING:
348                                 log(DEBUG, "PGconnectPoll: PGRES_POLLING_WRITING");
349                                 WantWrite();
350                                 status = CWRITE;
351                                 return DoPoll();
352                         case PGRES_POLLING_READING:
353                                 log(DEBUG, "PGconnectPoll: PGRES_POLLING_READING");
354                                 status = CREAD;
355                                 break;
356                         case PGRES_POLLING_FAILED:
357                                 log(DEBUG, "PGconnectPoll: PGRES_POLLING_FAILED: %s", PQerrorMessage(sql));
358                                 return false;
359                         case PGRES_POLLING_OK:
360                                 log(DEBUG, "PGconnectPoll: PGRES_POLLING_OK");
361                                 status = WWRITE;
362                                 return DoConnectedPoll();
363                         default:
364                                 log(DEBUG, "PGconnectPoll: wtf?");
365                                 break;
366                 }
367                 
368                 return true;
369         }
370         
371         bool DoConnectedPoll()
372         {
373                 if(!qinprog && queue.totalsize())
374                 {
375                         /* There's no query currently in progress, and there's queries in the queue. */
376                         SQLrequest& query = queue.front();
377                         DoQuery(query);
378                 }
379                 
380                 if(PQconsumeInput(sql))
381                 {
382                         log(DEBUG, "PQconsumeInput succeeded");
383                                 
384                         if(PQisBusy(sql))
385                         {
386                                 log(DEBUG, "Still busy processing command though");
387                         }
388                         else if(qinprog)
389                         {
390                                 log(DEBUG, "Looks like we have a result to process!");
391                                 
392                                 while(PGresult* result = PQgetResult(sql))
393                                 {
394                                         int cols = PQnfields(result);
395                                         
396                                         log(DEBUG, "Got result! :D");
397                                         log(DEBUG, "%d rows, %d columns checking now what the column names are", PQntuples(result), cols);
398                                                 
399                                         for(int i = 0; i < cols; i++)
400                                         {
401                                                 log(DEBUG, "Column name: %s (%d)", PQfname(result, i));
402                                         }
403                                                 
404                                         PQclear(result);
405                                 }
406                                 
407                                 qinprog = false;
408                                 queue.pop();                            
409                                 DoConnectedPoll();
410                         }
411                         
412                         return true;
413                 }
414                 
415                 log(DEBUG, "PQconsumeInput failed: %s", PQerrorMessage(sql));
416                 return false;
417         }
418         
419         void ShowStatus()
420         {
421                 switch(PQstatus(sql))
422                 {
423                         case CONNECTION_STARTED:
424                                 log(DEBUG, "PQstatus: CONNECTION_STARTED: Waiting for connection to be made.");
425                                 break;
426  
427                         case CONNECTION_MADE:
428                                 log(DEBUG, "PQstatus: CONNECTION_MADE: Connection OK; waiting to send.");
429                                 break;
430                         
431                         case CONNECTION_AWAITING_RESPONSE:
432                                 log(DEBUG, "PQstatus: CONNECTION_AWAITING_RESPONSE: Waiting for a response from the server.");
433                                 break;
434                         
435                         case CONNECTION_AUTH_OK:
436                                 log(DEBUG, "PQstatus: CONNECTION_AUTH_OK: Received authentication; waiting for backend start-up to finish.");
437                                 break;
438                         
439                         case CONNECTION_SSL_STARTUP:
440                                 log(DEBUG, "PQstatus: CONNECTION_SSL_STARTUP: Negotiating SSL encryption.");
441                                 break;
442                         
443                         case CONNECTION_SETENV:
444                                 log(DEBUG, "PQstatus: CONNECTION_SETENV: Negotiating environment-driven parameter settings.");
445                                 break;
446                         
447                         default:
448                                 log(DEBUG, "PQstatus: ???");
449                 }
450         }
451         
452         virtual bool OnDataReady()
453         {
454                 /* Always return true here, false would close the socket - we need to do that ourselves with the pgsql API */
455                 log(DEBUG, "OnDataReady(): status = %s", StatusStr());
456                 
457                 return DoEvent();
458         }
459
460         virtual bool OnWriteReady()
461         {
462                 /* Always return true here, false would close the socket - we need to do that ourselves with the pgsql API */
463                 log(DEBUG, "OnWriteReady(): status = %s", StatusStr());
464                 
465                 return DoEvent();
466         }
467         
468         virtual bool OnConnected()
469         {
470                 log(DEBUG, "OnConnected(): status = %s", StatusStr());
471                 
472                 return DoEvent();
473         }
474         
475         bool DoEvent()
476         {
477                 bool ret;
478                 
479                 if((status == CREAD) || (status == CWRITE))
480                 {
481                         ret = DoPoll();
482                 }
483                 else
484                 {
485                         ret = DoConnectedPoll();
486                 }
487                 
488                 switch(PQflush(sql))
489                 {
490                         case -1:
491                                 log(DEBUG, "Error flushing write queue: %s", PQerrorMessage(sql));
492                                 break;
493                         case 0:
494                                 log(DEBUG, "Successfully flushed write queue (or there was nothing to write)");
495                                 break;
496                         case 1:
497                                 log(DEBUG, "Not all of the write queue written, triggering write event so we can have another go");
498                                 WantWrite();
499                                 break;
500                 }
501
502                 return ret;
503         }
504         
505         std::string MkInfoStr()
506         {                       
507                 /* XXX - This needs nonblocking DNS lookups */
508                 
509                 std::ostringstream conninfo("connect_timeout = '2'");
510                 
511                 if(dbhost.length())
512                         conninfo << " hostaddr = '" << dbhost << "'";
513                 
514                 if(dbport)
515                         conninfo << " port = '" << dbport << "'";
516                 
517                 if(dbname.length())
518                         conninfo << " dbname = '" << dbname << "'";
519                 
520                 if(dbuser.length())
521                         conninfo << " user = '" << dbuser << "'";
522                 
523                 if(dbpass.length())
524                         conninfo << " password = '" << dbpass << "'";
525                 
526                 if(ssl)
527                         conninfo << " sslmode = 'require'";
528                 
529                 return conninfo.str();
530         }
531         
532         const char* StatusStr()
533         {
534                 if(status == CREAD) return "CREAD";
535                 if(status == CWRITE) return "CWRITE";
536                 if(status == WREAD) return "WREAD";
537                 if(status == WWRITE) return "WWRITE";
538                 return "Err...what, erm..BUG!";
539         }
540         
541         SQLerror DoQuery(const SQLrequest &req)
542         {
543                 if((status == WREAD) || (status == WWRITE))
544                 {
545                         if(!qinprog)
546                         {
547                                 if(PQsendQuery(sql, req.query.c_str()))
548                                 {
549                                         log(DEBUG, "Dispatched query: %s", req.query.c_str());
550                                         qinprog = true;
551                                         return SQLerror();
552                                 }
553                                 else
554                                 {
555                                         log(DEBUG, "Failed to dispatch query: %s", PQerrorMessage(sql));
556                                         return SQLerror(QSEND_FAIL, PQerrorMessage(sql));
557                                 }
558                         }
559                 }
560
561                 log(DEBUG, "Can't query until connection is complete");
562                 return SQLerror(BAD_CONN, "Can't query until connection is complete");
563         }
564         
565         SQLerror Query(const SQLrequest &req)
566         {
567                 queue.push(req);
568                 
569                 if(!qinprog && queue.totalsize())
570                 {
571                         /* There's no query currently in progress, and there's queries in the queue. */
572                         SQLrequest& query = queue.front();
573                         return DoQuery(query);
574                 }
575                 else
576                 {
577                         return SQLerror();
578                 }
579         }
580 };
581
582 class ModulePgSQL : public Module
583 {
584 private:
585         Server* Srv;
586         ConnMap connections;
587
588 public:
589         ModulePgSQL(Server* Me)
590         : Module::Module(Me), Srv(Me)
591         {
592                 log(DEBUG, "%s 'SQL' feature", Srv->PublishFeature("SQL", this) ? "Published" : "Couldn't publish");
593                 log(DEBUG, "%s 'PgSQL' feature", Srv->PublishFeature("PgSQL", this) ? "Published" : "Couldn't publish");
594
595                 OnRehash("");
596         }
597
598         void Implements(char* List)
599         {
600                 List[I_OnRequest] = List[I_OnRehash] = List[I_OnUserRegister] = List[I_OnCheckReady] = List[I_OnUserDisconnect] = 1;
601         }
602
603         virtual void OnRehash(const std::string &parameter)
604         {
605                 ConfigReader conf;
606                 
607                 /* Delete all the SQLConn objects in the connection lists,
608                  * this will call their destructors where they can handle
609                  * closing connections and such.
610                  */
611                 for(ConnMap::iterator iter = connections.begin(); iter != connections.end(); iter++)
612                 {
613                         DELETE(iter->second);
614                 }
615                 
616                 /* Empty out our list of connections */
617                 connections.clear();
618
619                 for(int i = 0; i < conf.Enumerate("database"); i++)
620                 {
621                         std::string id;
622                         SQLConn* newconn;
623                         
624                         id = conf.ReadValue("database", "id", i);
625                         newconn = new SQLConn(Srv,      conf.ReadValue("database", "hostname", i),
626                                                                                 conf.ReadInteger("database", "port", i, true),
627                                                                                 conf.ReadValue("database", "name", i),
628                                                                                 conf.ReadValue("database", "username", i),
629                                                                                 conf.ReadValue("database", "password", i),
630                                                                                 conf.ReadFlag("database", "ssl", i));
631                         
632                         connections.insert(std::make_pair(id, newconn));
633                 }       
634         }
635         
636         virtual char* OnRequest(Request* request)
637         {
638                 if(strcmp(SQLREQID, request->GetData()) == 0)
639                 {
640                         SQLrequest* req = (SQLrequest*)request;
641                         ConnMap::iterator iter;
642                 
643                         log(DEBUG, "Got query: '%s' on id '%s'", req->query.c_str(), req->dbid.c_str());
644
645                         if((iter = connections.find(req->dbid)) != connections.end())
646                         {
647                                 /* Execute query */
648                                 req->error = iter->second->Query(*req);
649                                 
650                                 return SQLSUCCESS;
651                         }
652                         else
653                         {
654                                 req->error.Id(BAD_DBID);
655                                 return NULL;
656                         }
657                 }
658
659                 log(DEBUG, "Got unsupported API version string: %s", request->GetData());
660                 
661                 return NULL;
662         }
663                 
664         virtual Version GetVersion()
665         {
666                 return Version(1, 0, 0, 0, VF_VENDOR|VF_SERVICEPROVIDER);
667         }
668         
669         virtual ~ModulePgSQL()
670         {
671         }       
672 };
673
674 class ModulePgSQLFactory : public ModuleFactory
675 {
676  public:
677         ModulePgSQLFactory()
678         {
679         }
680         
681         ~ModulePgSQLFactory()
682         {
683         }
684         
685         virtual Module * CreateModule(Server* Me)
686         {
687                 return new ModulePgSQL(Me);
688         }
689 };
690
691
692 extern "C" void * init_module( void )
693 {
694         return new ModulePgSQLFactory;
695 }