]> git.netwichtig.de Git - user/henk/code/inspircd.git/blob - src/modules/extra/m_pgsql.cpp
More stuff to return empty lists and maps when there are no more rows in the dataset
[user/henk/code/inspircd.git] / src / modules / extra / m_pgsql.cpp
1 /*       +------------------------------------+
2  *       | Inspire Internet Relay Chat Daemon |
3  *       +------------------------------------+
4  *
5  *  InspIRCd is copyright (C) 2002-2004 ChatSpike-Dev.
6  *                       E-mail:
7  *                <brain@chatspike.net>
8  *                <Craig@chatspike.net>
9  *                <omster@gmail.com>
10  *     
11  * Written by Craig Edwards, Craig McLure, and others.
12  * This program is free but copyrighted software; see
13  *            the file COPYING for details.
14  *
15  * ---------------------------------------------------
16  */
17
18 #include <cstdlib>
19 #include <sstream>
20 #include <string>
21 #include <deque>
22 #include <map>
23 #include <libpq-fe.h>
24
25 #include "users.h"
26 #include "channels.h"
27 #include "modules.h"
28 #include "helperfuncs.h"
29 #include "inspircd.h"
30 #include "configreader.h"
31
32 #include "m_sqlv2.h"
33
34 /* $ModDesc: PostgreSQL Service Provider module for all other m_sql* modules, uses v2 of the SQL API */
35 /* $CompileFlags: -I`pg_config --includedir` `perl extra/pgsql_config.pl` */
36 /* $LinkerFlags: -L`pg_config --libdir` -lpq */
37
/* UGH, UGH, UGH, UGH, UGH, UGH
 * The core-defined constructors for InspSocket aren't suitable for a socket
 * whose descriptor comes from libpq, so SQLConn reimplements parts of them.
 * Doing that needs direct access to the socket engine (ServerInstance->SE),
 * which is why this module reaches for the core instance here :\
 */
extern InspIRCd* ServerInstance;
/* Table of InspSocket pointers indexed by file descriptor; SQLConn's
 * constructor, DoResolve(), DoConnect() and Close() store/clear entries in it.
 * NOTE(review): this is a definition rather than an extern declaration —
 * confirm it does not clash with a socket_ref table owned by the core.
 */
InspSocket* socket_ref[MAX_DESCRIPTORS];

/* Forward declaration, so we can have the typedef neatly at the top */
class SQLConn;
/* Also needs a forward declaration, as it's used inside SQLConn */
class ModulePgSQL;

/* Open connections, keyed by the id="" value of their <database> config tag */
typedef std::map<std::string, SQLConn*> ConnMap;

/* Connection state of an SQLConn:
 * CREAD,       Connecting and wants read event
 * CWRITE,      Connecting and wants write event
 * WREAD,       Connected/Working and wants read event
 * WWRITE,      Connected/Working and wants write event
 */
enum SQLstatus { CREAD, CWRITE, WREAD, WWRITE };
62
63 /** QueryQueue, a queue of queries waiting to be executed.
64  * This maintains two queues internally, one for 'priority'
65  * queries and one for less important ones. Each queue has
66  * new queries appended to it and ones to execute are popped
67  * off the front. This keeps them flowing round nicely and no
68  * query should ever get 'stuck' for too long. If there are
69  * queries in the priority queue they will be executed first,
70  * 'unimportant' queries will only be executed when the
71  * priority queue is empty.
72  *
73  * We store lists of SQLrequest's here, by value as we want to avoid storing
74  * any data allocated inside the client module (in case that module is unloaded
75  * while the query is in progress).
76  *
77  * Because we want to work on the current SQLrequest in-situ, we need a way
78  * of accessing the request we are currently processing, QueryQueue::front(),
79  * but that call needs to always return the same request until that request
80  * is removed from the queue, this is what the 'which' variable is. New queries are
81  * always added to the back of one of the two queues, but if when front()
82  * is first called then the priority queue is empty then front() will return
83  * a query from the normal queue, but if a query is then added to the priority
84  * queue then front() must continue to return the front of the *normal* queue
85  * until pop() is called.
86  */
87
88 class QueryQueue : public classbase
89 {
90 private:
91         typedef std::deque<SQLrequest> ReqDeque;        
92
93         ReqDeque priority;      /* The priority queue */
94         ReqDeque normal;        /* The 'normal' queue */
95         enum { PRI, NOR, NON } which;   /* Which queue the currently active element is at the front of */
96
97 public:
98         QueryQueue()
99         : which(NON)
100         {
101         }
102         
103         void push(const SQLrequest &q)
104         {
105                 log(DEBUG, "QueryQueue::push(): Adding %s query to queue: %s", ((q.pri) ? "priority" : "non-priority"), q.query.q.c_str());
106                 
107                 if(q.pri)
108                         priority.push_back(q);
109                 else
110                         normal.push_back(q);
111         }
112         
113         void pop()
114         {
115                 if((which == PRI) && priority.size())
116                 {
117                         priority.pop_front();
118                 }
119                 else if((which == NOR) && normal.size())
120                 {
121                         normal.pop_front();
122                 }
123                 
124                 /* Reset this */
125                 which = NON;
126                 
127                 /* Silently do nothing if there was no element to pop() */
128         }
129         
130         SQLrequest& front()
131         {
132                 switch(which)
133                 {
134                         case PRI:
135                                 return priority.front();
136                         case NOR:
137                                 return normal.front();
138                         default:
139                                 if(priority.size())
140                                 {
141                                         which = PRI;
142                                         return priority.front();
143                                 }
144                                 
145                                 if(normal.size())
146                                 {
147                                         which = NOR;
148                                         return normal.front();
149                                 }
150                                 
151                                 /* This will probably result in a segfault,
152                                  * but the caller should have checked totalsize()
153                                  * first so..meh - moron :p
154                                  */
155                                 
156                                 return priority.front();
157                 }
158         }
159         
160         std::pair<int, int> size()
161         {
162                 return std::make_pair(priority.size(), normal.size());
163         }
164         
165         int totalsize()
166         {
167                 return priority.size() + normal.size();
168         }
169         
170         void PurgeModule(Module* mod)
171         {
172                 DoPurgeModule(mod, priority);
173                 DoPurgeModule(mod, normal);
174         }
175         
176 private:
177         void DoPurgeModule(Module* mod, ReqDeque& q)
178         {
179                 for(ReqDeque::iterator iter = q.begin(); iter != q.end(); iter++)
180                 {
181                         if(iter->GetSource() == mod)
182                         {
183                                 if(iter->id == front().id)
184                                 {
185                                         /* It's the currently active query.. :x */
186                                         iter->SetSource(NULL);
187                                 }
188                                 else
189                                 {
190                                         /* It hasn't been executed yet..just remove it */
191                                         iter = q.erase(iter);
192                                 }
193                         }
194                 }
195         }
196 };
197
198 /** PgSQLresult is a subclass of the mostly-pure-virtual class SQLresult.
199  * All SQL providers must create their own subclass and define it's methods using that
200  * database library's data retriveal functions. The aim is to avoid a slow and inefficient process
201  * of converting all data to a common format before it reaches the result structure. This way
202  * data is passes to the module nearly as directly as if it was using the API directly itself.
203  */
204
205 class PgSQLresult : public SQLresult
206 {
207         PGresult* res;
208         int currentrow;
209         int rows;
210         int cols;
211         
212         SQLfieldList* fieldlist;
213         SQLfieldMap* fieldmap;
214 public:
215         PgSQLresult(Module* self, Module* to, unsigned long id, PGresult* result)
216         : SQLresult(self, to, id), res(result), currentrow(0), fieldlist(NULL), fieldmap(NULL)
217         {
218                 rows = PQntuples(res);
219                 cols = PQnfields(res);
220                 
221                 log(DEBUG, "Created new PgSQL result; %d rows, %d columns, %s affected", rows, cols, PQcmdTuples(res));
222         }
223         
224         ~PgSQLresult()
225         {
226                 PQclear(res);
227         }
228         
229         virtual int Rows()
230         {
231                 if(!cols && !rows)
232                 {
233                         return atoi(PQcmdTuples(res));
234                 }
235                 else
236                 {
237                         return rows;
238                 }
239         }
240         
241         virtual int Cols()
242         {
243                 return PQnfields(res);
244         }
245         
246         virtual std::string ColName(int column)
247         {
248                 char* name = PQfname(res, column);
249                 
250                 return (name) ? name : "";
251         }
252         
253         virtual int ColNum(const std::string &column)
254         {
255                 int n = PQfnumber(res, column.c_str());
256                 
257                 if(n == -1)
258                 {
259                         throw SQLbadColName();
260                 }
261                 else
262                 {
263                         return n;
264                 }
265         }
266         
267         virtual SQLfield GetValue(int row, int column)
268         {
269                 char* v = PQgetvalue(res, row, column);
270                 
271                 if(v)
272                 {
273                         return SQLfield(std::string(v, PQgetlength(res, row, column)), PQgetisnull(res, row, column));
274                 }
275                 else
276                 {
277                         log(DEBUG, "PQgetvalue returned a null pointer..nobody wants to tell us what this means");
278                         throw SQLbadColName();
279                 }
280         }
281         
282         virtual SQLfieldList& GetRow()
283         {
284                 /* In an effort to reduce overhead we don't actually allocate the list
285                  * until the first time it's needed...so...
286                  */
287                 if(fieldlist)
288                 {
289                         fieldlist->clear();
290                 }
291                 else
292                 {
293                         fieldlist = new SQLfieldList;
294                 }
295                 
296                 if(currentrow < PQntuples(res))
297                 {
298                         int cols = PQnfields(res);
299                         
300                         for(int i = 0; i < cols; i++)
301                         {
302                                 fieldlist->push_back(GetValue(currentrow, i));
303                         }
304                         
305                         currentrow++;
306                 }
307                 
308                 return *fieldlist;
309         }
310         
311         virtual SQLfieldMap& GetRowMap()
312         {
313                 /* In an effort to reduce overhead we don't actually allocate the map
314                  * until the first time it's needed...so...
315                  */
316                 if(fieldmap)
317                 {
318                         fieldmap->clear();
319                 }
320                 else
321                 {
322                         fieldmap = new SQLfieldMap;
323                 }
324                 
325                 if(currentrow < PQntuples(res))
326                 {
327                         int cols = PQnfields(res);
328                         
329                         for(int i = 0; i < cols; i++)
330                         {
331                                 fieldmap->insert(std::make_pair(ColName(i), GetValue(currentrow, i)));
332                         }
333                         
334                         currentrow++;
335                 }
336                 
337                 return *fieldmap;
338         }
339         
340         virtual SQLfieldList* GetRowPtr()
341         {
342                 SQLfieldList* fl = new SQLfieldList;
343                 
344                 if(currentrow < PQntuples(res))
345                 {
346                         int cols = PQnfields(res);
347                         
348                         for(int i = 0; i < cols; i++)
349                         {
350                                 fl->push_back(GetValue(currentrow, i));
351                         }
352                         
353                         currentrow++;
354                 }
355                 
356                 return fl;
357         }
358         
359         virtual SQLfieldMap* GetRowMapPtr()
360         {
361                 SQLfieldMap* fm = new SQLfieldMap;
362                 
363                 if(currentrow < PQntuples(res))
364                 {
365                         int cols = PQnfields(res);
366                         
367                         for(int i = 0; i < cols; i++)
368                         {
369                                 fm->insert(std::make_pair(ColName(i), GetValue(currentrow, i)));
370                         }
371                         
372                         currentrow++;
373                 }
374                 
375                 return fm;
376         }
377         
378         virtual void Free(SQLfieldMap* fm)
379         {
380                 DELETE(fm);
381         }
382         
383         virtual void Free(SQLfieldList* fl)
384         {
385                 DELETE(fl);
386         }
387 };
388
/** SQLConn represents one SQL session.
 * Each session has its own persistent connection to the database.
 * This is a subclass of InspSocket so it can easily receive read/write events from the core socket
 * engine. Unlike the original MySQL module, this module does not block. Ever. It gets a mild stabbing
 * if it dares to.
 */

class SQLConn : public InspSocket
{
private:
	/* NOTE: the constructor's initializer list relies on this declaration order */
	ModulePgSQL* us;		/* Pointer to the SQL provider itself */
	Server* Srv;			/* Server interface; used by the constructor to find the configured DNS server */
	std::string	dbhost;	/* Database server hostname (replaced with the resolved IP by DoResolve()) */
	unsigned int	dbport;	/* Database server port */
	std::string	dbname;	/* Database name */
	std::string	dbuser;	/* Database username */
	std::string	dbpass;	/* Database password */
	bool			ssl;	/* If we should require SSL */
	PGconn*			sql;	/* PgSQL database connection handle */
	SQLstatus		status;	/* PgSQL database connection status */
	bool			qinprog;/* If there is currently a query in progress */
	QueryQueue		queue;	/* Queue of queries waiting to be executed on this connection */

public:

	/* This class should only ever be created inside this module, using this constructor, so we don't have to worry about the default ones */

	SQLConn(ModulePgSQL* self, Server* srv, const std::string &h, unsigned int p, const std::string &d, const std::string &u, const std::string &pwd, bool s);

	/* Closes the connection via Close() */
	~SQLConn();

	/* Checks for a finished DNS lookup and starts connecting once an address is known */
	bool DoResolve();

	/* Starts a nonblocking libpq connection and hooks its socket into the socket engine */
	bool DoConnect();

	/* Releases the socket_ref slot and the libpq connection handle */
	virtual void Close();
	
	/* Advances the connection handshake via PQconnectPoll() */
	bool DoPoll();
	
	/* Dispatches queued queries and delivers any completed result */
	bool DoConnectedPoll();
	
	/* Logs the current PQstatus() of the connection */
	void ShowStatus();	
	
	virtual bool OnDataReady();

	virtual bool OnWriteReady();
	
	virtual bool OnConnected();
	
	bool DoEvent();
	
	/* Builds the libpq connection info string passed to PQconnectStart() */
	std::string MkInfoStr();
	
	const char* StatusStr();
	
	/* Called by DoConnectedPoll() on the front of the queue; presumably sends it to the server */
	SQLerror DoQuery(SQLrequest &req);
	
	/* Entry point used by ModulePgSQL::OnRequest(); presumably queues the request */
	SQLerror Query(const SQLrequest &req);
	
	/* Called for every connection when any module unloads, to drop that module's queries */
	void OnUnloadModule(Module* mod);
};
450
451 class ModulePgSQL : public Module
452 {
453 private:
454         Server* Srv;
455         ConnMap connections;
456         unsigned long currid;
457         char* sqlsuccess;
458
459 public:
460         ModulePgSQL(Server* Me)
461         : Module::Module(Me), Srv(Me), currid(0)
462         {
463                 log(DEBUG, "%s 'SQL' feature", Srv->PublishFeature("SQL", this) ? "Published" : "Couldn't publish");
464                 
465                 sqlsuccess = new char[strlen(SQLSUCCESS)+1];
466                 
467                 strcpy(sqlsuccess, SQLSUCCESS);
468
469                 OnRehash("");
470         }
471
472         void Implements(char* List)
473         {
474                 List[I_OnUnloadModule] = List[I_OnRequest] = List[I_OnRehash] = List[I_OnUserRegister] = List[I_OnCheckReady] = List[I_OnUserDisconnect] = 1;
475         }
476
477         virtual void OnRehash(const std::string &parameter)
478         {
479                 ConfigReader conf;
480                 
481                 /* Delete all the SQLConn objects in the connection lists,
482                  * this will call their destructors where they can handle
483                  * closing connections and such.
484                  */
485                 for(ConnMap::iterator iter = connections.begin(); iter != connections.end(); iter++)
486                 {
487                         DELETE(iter->second);
488                 }
489                 
490                 /* Empty out our list of connections */
491                 connections.clear();
492
493                 for(int i = 0; i < conf.Enumerate("database"); i++)
494                 {
495                         std::string id;
496                         SQLConn* newconn;
497                         
498                         id = conf.ReadValue("database", "id", i);
499                         newconn = new SQLConn(this, Srv,
500                                                                                 conf.ReadValue("database", "hostname", i),
501                                                                                 conf.ReadInteger("database", "port", i, true),
502                                                                                 conf.ReadValue("database", "name", i),
503                                                                                 conf.ReadValue("database", "username", i),
504                                                                                 conf.ReadValue("database", "password", i),
505                                                                                 conf.ReadFlag("database", "ssl", i));
506                         
507                         connections.insert(std::make_pair(id, newconn));
508                 }       
509         }
510         
511         virtual char* OnRequest(Request* request)
512         {
513                 if(strcmp(SQLREQID, request->GetData()) == 0)
514                 {
515                         SQLrequest* req = (SQLrequest*)request;
516                         ConnMap::iterator iter;
517                 
518                         log(DEBUG, "Got query: '%s' with %d replacement parameters on id '%s'", req->query.q.c_str(), req->query.p.size(), req->dbid.c_str());
519
520                         if((iter = connections.find(req->dbid)) != connections.end())
521                         {
522                                 /* Execute query */
523                                 req->id = NewID();
524                                 req->error = iter->second->Query(*req);
525                                 
526                                 return (req->error.Id() == NO_ERROR) ? sqlsuccess : NULL;
527                         }
528                         else
529                         {
530                                 req->error.Id(BAD_DBID);
531                                 return NULL;
532                         }
533                 }
534
535                 log(DEBUG, "Got unsupported API version string: %s", request->GetData());
536                 
537                 return NULL;
538         }
539         
540         virtual void OnUnloadModule(Module* mod, const std::string&     name)
541         {
542                 /* When a module unloads we have to check all the pending queries for all our connections
543                  * and set the Module* specifying where the query came from to NULL. If the query has already
544                  * been dispatched then when it is processed it will be dropped if the pointer is NULL.
545                  *
546                  * If the queries we find are not already being executed then we can simply remove them immediately.
547                  */
548                 for(ConnMap::iterator iter = connections.begin(); iter != connections.end(); iter++)
549                 {
550                         iter->second->OnUnloadModule(mod);
551                 }
552         }
553
554         unsigned long NewID()
555         {
556                 if (currid+1 == 0)
557                         currid++;
558                 
559                 return ++currid;
560         }
561                 
562         virtual Version GetVersion()
563         {
564                 return Version(1, 0, 0, 0, VF_VENDOR|VF_SERVICEPROVIDER);
565         }
566         
567         virtual ~ModulePgSQL()
568         {
569                 DELETE(sqlsuccess);
570         }       
571 };
572
573 SQLConn::SQLConn(ModulePgSQL* self, Server* srv, const std::string &h, unsigned int p, const std::string &d, const std::string &u, const std::string &pwd, bool s)
574 : InspSocket::InspSocket(), us(self), Srv(srv), dbhost(h), dbport(p), dbname(d), dbuser(u), dbpass(pwd), ssl(s), sql(NULL), status(CWRITE), qinprog(false)
575 {
576         log(DEBUG, "Creating new PgSQL connection to database %s on %s:%u (%s/%s)", dbname.c_str(), dbhost.c_str(), dbport, dbuser.c_str(), dbpass.c_str());
577
578         /* Some of this could be reviewed, unsure if I need to fill 'host' etc...
579          * just copied this over from the InspSocket constructor.
580          */
581         strlcpy(this->host, dbhost.c_str(), MAXBUF);
582         this->port = dbport;
583         
584         this->ClosePending = false;
585         
586         if(!inet_aton(this->host, &this->addy))
587         {
588                 /* Its not an ip, spawn the resolver.
589                  * PgSQL doesn't do nonblocking DNS 
590                  * lookups, so we do it for it.
591                  */
592                 
593                 log(DEBUG,"Attempting to resolve %s", this->host);
594                 
595                 this->dns.SetNS(Srv->GetConfig()->DNSServer);
596                 this->dns.ForwardLookupWithFD(this->host, fd);
597                 
598                 this->state = I_RESOLVING;
599                 socket_ref[this->fd] = this;
600                 
601                 return;
602         }
603         else
604         {
605                 log(DEBUG,"No need to resolve %s", this->host);
606                 strlcpy(this->IP, this->host, MAXBUF);
607                 
608                 if(!this->DoConnect())
609                 {
610                         throw ModuleException("Connect failed");
611                 }
612         }
613 }
614
615 SQLConn::~SQLConn()
616 {
617         Close();
618 }
619
620 bool SQLConn::DoResolve()
621 {       
622         log(DEBUG, "Checking for DNS lookup result");
623         
624         if(this->dns.HasResult())
625         {
626                 std::string res_ip = dns.GetResultIP();
627                 
628                 if(res_ip.length())
629                 {
630                         log(DEBUG, "Got result: %s", res_ip.c_str());
631                         
632                         strlcpy(this->IP, res_ip.c_str(), MAXBUF);
633                         dbhost = res_ip;
634                         
635                         socket_ref[this->fd] = NULL;
636                         
637                         return this->DoConnect();
638                 }
639                 else
640                 {
641                         log(DEBUG, "DNS lookup failed, dying horribly");
642                         Close();
643                         return false;
644                 }
645         }
646         else
647         {
648                 log(DEBUG, "No result for lookup yet!");
649                 return true;
650         }
651 }
652
653 bool SQLConn::DoConnect()
654 {
655         log(DEBUG, "SQLConn::DoConnect()");
656         
657         if(!(sql = PQconnectStart(MkInfoStr().c_str())))
658         {
659                 log(DEBUG, "Couldn't allocate PGconn structure, aborting: %s", PQerrorMessage(sql));
660                 Close();
661                 return false;
662         }
663         
664         if(PQstatus(sql) == CONNECTION_BAD)
665         {
666                 log(DEBUG, "PQconnectStart failed: %s", PQerrorMessage(sql));
667                 Close();
668                 return false;
669         }
670         
671         ShowStatus();
672         
673         if(PQsetnonblocking(sql, 1) == -1)
674         {
675                 log(DEBUG, "Couldn't set connection nonblocking: %s", PQerrorMessage(sql));
676                 Close();
677                 return false;
678         }
679         
680         /* OK, we've initalised the connection, now to get it hooked into the socket engine
681          * and then start polling it.
682          */
683         
684         log(DEBUG, "Old DNS socket: %d", this->fd);
685         this->fd = PQsocket(sql);
686         log(DEBUG, "New SQL socket: %d", this->fd);
687         
688         if(this->fd <= -1)
689         {
690                 log(DEBUG, "PQsocket says we have an invalid FD: %d", this->fd);
691                 Close();
692                 return false;
693         }
694         
695         this->state = I_CONNECTING;
696         ServerInstance->SE->AddFd(this->fd,false,X_ESTAB_MODULE);
697         socket_ref[this->fd] = this;
698         
699         /* Socket all hooked into the engine, now to tell PgSQL to start connecting */
700         
701         return DoPoll();
702 }
703
704 void SQLConn::Close()
705 {
706         log(DEBUG,"SQLConn::Close");
707         
708         if(this->fd > 01)
709                 socket_ref[this->fd] = NULL;
710         this->fd = -1;
711         this->state = I_ERROR;
712         this->OnError(I_ERR_SOCKET);
713         this->ClosePending = true;
714         
715         if(sql)
716         {
717                 PQfinish(sql);
718                 sql = NULL;
719         }
720         
721         return;
722 }
723
724 bool SQLConn::DoPoll()
725 {
726         switch(PQconnectPoll(sql))
727         {
728                 case PGRES_POLLING_WRITING:
729                         log(DEBUG, "PGconnectPoll: PGRES_POLLING_WRITING");
730                         WantWrite();
731                         status = CWRITE;
732                         return DoPoll();
733                 case PGRES_POLLING_READING:
734                         log(DEBUG, "PGconnectPoll: PGRES_POLLING_READING");
735                         status = CREAD;
736                         break;
737                 case PGRES_POLLING_FAILED:
738                         log(DEBUG, "PGconnectPoll: PGRES_POLLING_FAILED: %s", PQerrorMessage(sql));
739                         return false;
740                 case PGRES_POLLING_OK:
741                         log(DEBUG, "PGconnectPoll: PGRES_POLLING_OK");
742                         status = WWRITE;
743                         return DoConnectedPoll();
744                 default:
745                         log(DEBUG, "PGconnectPoll: wtf?");
746                         break;
747         }
748         
749         return true;
750 }
751
/** Drive an established connection.
 * If no query is running and the queue is non-empty, the front request is handed to DoQuery().
 * Pending input is then read with PQconsumeInput(). If a query was in progress and libpq is
 * no longer busy, its result is delivered to the requesting module (or discarded if that
 * module has unloaded). The request is then popped, and this function recurses to start the next one.
 * @return false if PQconsumeInput() fails, true otherwise.
 */
bool SQLConn::DoConnectedPoll()
{
	if(!qinprog && queue.totalsize())
	{
		/* There's no query currently in progress, and there are queries in the queue. */
		SQLrequest& query = queue.front();
		/* NOTE(review): DoQuery()'s SQLerror is discarded here — confirm failures are
		 * reported to the requesting module elsewhere, or the request may stay at the front.
		 */
		DoQuery(query);
	}
	
	if(PQconsumeInput(sql))
	{
		log(DEBUG, "PQconsumeInput succeeded");
			
		if(PQisBusy(sql))
		{
			log(DEBUG, "Still busy processing command though");
		}
		else if(qinprog)
		{
			log(DEBUG, "Looks like we have a result to process!");
			
			/* Grab the request we're processing */
			SQLrequest& query = queue.front();
			
			log(DEBUG, "ID is %lu", query.id);
			
			/* Get a pointer to the module we're about to return the result to
			 * (NULL if that module was unloaded while the query ran)
			 */
			Module* to = query.GetSource();
			
			/* Fetch the result.. */
			PGresult* result = PQgetResult(sql);
			
			/* PgSQL would allow a query string to be sent which has multiple
			 * queries in it. This isn't portable across database backends and
			 * we don't want modules doing it, but just in case we make sure we
			 * drain any results there are and just use the last one.
			 * If the module devs are behaving there will only be one result.
			 * NOTE(review): per the libpq docs PQgetResult() can block if further
			 * results have not been read yet — only a concern for multi-statement queries.
			 */
			while (PGresult* temp = PQgetResult(sql))
			{
				PQclear(result);
				result = temp;
			}
			
			if(to)
			{
				/* ..and the result */
				PgSQLresult reply(us, to, query.id, result);
				
				log(DEBUG, "Got result, status code: %s; error message: %s", PQresStatus(PQresultStatus(result)), PQresultErrorMessage(result));	
				
				switch(PQresultStatus(result))
				{
					case PGRES_EMPTY_QUERY:
					case PGRES_BAD_RESPONSE:
					case PGRES_FATAL_ERROR:
						reply.error.Id(QREPLY_FAIL);
						reply.error.Str(PQresultErrorMessage(result));
						/* Falls through into the empty default, which is harmless */
					default:;
						/* No action, other values are not errors */
				}
				
				reply.Send();
				
				/* PgSQLresult's destructor will free the PGresult */
			}
			else
			{
				/* If the client module is unloaded partway through a query then the provider will set
				 * the pointer to NULL. We cannot just cancel the query as the result will still come
				 * through at some point...and it could get messy if we play with invalid pointers...
				 */
				log(DEBUG, "Looks like we're handling a zombie query from a module which unloaded before it got a result..fun. ID: %lu", query.id);
				PQclear(result);
			}
			
			qinprog = false;
			queue.pop();				
			/* Start the next queued query, if any; the return value is not checked */
			DoConnectedPoll();
		}
		
		return true;
	}
	
	log(DEBUG, "PQconsumeInput failed: %s", PQerrorMessage(sql));
	return false;
}
839
840 void SQLConn::ShowStatus()
841 {
842         switch(PQstatus(sql))
843         {
844                 case CONNECTION_STARTED:
845                         log(DEBUG, "PQstatus: CONNECTION_STARTED: Waiting for connection to be made.");
846                         break;
847
848                 case CONNECTION_MADE:
849                         log(DEBUG, "PQstatus: CONNECTION_MADE: Connection OK; waiting to send.");
850                         break;
851                 
852                 case CONNECTION_AWAITING_RESPONSE:
853                         log(DEBUG, "PQstatus: CONNECTION_AWAITING_RESPONSE: Waiting for a response from the server.");
854                         break;
855                 
856                 case CONNECTION_AUTH_OK:
857                         log(DEBUG, "PQstatus: CONNECTION_AUTH_OK: Received authentication; waiting for backend start-up to finish.");
858                         break;
859                 
860                 case CONNECTION_SSL_STARTUP:
861                         log(DEBUG, "PQstatus: CONNECTION_SSL_STARTUP: Negotiating SSL encryption.");
862                         break;
863                 
864                 case CONNECTION_SETENV:
865                         log(DEBUG, "PQstatus: CONNECTION_SETENV: Negotiating environment-driven parameter settings.");
866                         break;
867                 
868                 default:
869                         log(DEBUG, "PQstatus: ???");
870         }
871 }
872
873 bool SQLConn::OnDataReady()
874 {
875         /* Always return true here, false would close the socket - we need to do that ourselves with the pgsql API */
876         log(DEBUG, "OnDataReady(): status = %s", StatusStr());
877         
878         return DoEvent();
879 }
880
881 bool SQLConn::OnWriteReady()
882 {
883         /* Always return true here, false would close the socket - we need to do that ourselves with the pgsql API */
884         log(DEBUG, "OnWriteReady(): status = %s", StatusStr());
885         
886         return DoEvent();
887 }
888
889 bool SQLConn::OnConnected()
890 {
891         log(DEBUG, "OnConnected(): status = %s", StatusStr());
892         
893         return DoEvent();
894 }
895
896 bool SQLConn::DoEvent()
897 {
898         bool ret;
899         
900         if((status == CREAD) || (status == CWRITE))
901         {
902                 ret = DoPoll();
903         }
904         else
905         {
906                 ret = DoConnectedPoll();
907         }
908         
909         switch(PQflush(sql))
910         {
911                 case -1:
912                         log(DEBUG, "Error flushing write queue: %s", PQerrorMessage(sql));
913                         break;
914                 case 0:
915                         log(DEBUG, "Successfully flushed write queue (or there was nothing to write)");
916                         break;
917                 case 1:
918                         log(DEBUG, "Not all of the write queue written, triggering write event so we can have another go");
919                         WantWrite();
920                         break;
921         }
922
923         return ret;
924 }
925
926 std::string SQLConn::MkInfoStr()
927 {                       
928         std::ostringstream conninfo("connect_timeout = '2'");
929         
930         if(dbhost.length())
931                 conninfo << " hostaddr = '" << dbhost << "'";
932         
933         if(dbport)
934                 conninfo << " port = '" << dbport << "'";
935         
936         if(dbname.length())
937                 conninfo << " dbname = '" << dbname << "'";
938         
939         if(dbuser.length())
940                 conninfo << " user = '" << dbuser << "'";
941         
942         if(dbpass.length())
943                 conninfo << " password = '" << dbpass << "'";
944         
945         if(ssl)
946                 conninfo << " sslmode = 'require'";
947         
948         return conninfo.str();
949 }
950
951 const char* SQLConn::StatusStr()
952 {
953         if(status == CREAD) return "CREAD";
954         if(status == CWRITE) return "CWRITE";
955         if(status == WREAD) return "WREAD";
956         if(status == WWRITE) return "WWRITE";
957         return "Err...what, erm..BUG!";
958 }
959
960 SQLerror SQLConn::DoQuery(SQLrequest &req)
961 {
962         if((status == WREAD) || (status == WWRITE))
963         {
964                 if(!qinprog)
965                 {
966                         /* Parse the command string and dispatch it */
967                         
968                         /* Pointer to the buffer we screw around with substitution in */
969                         char* query;
970                         /* Pointer to the current end of query, where we append new stuff */
971                         char* queryend;
972                         /* Total length of the unescaped parameters */
973                         unsigned int paramlen;
974                         
975                         paramlen = 0;
976                         
977                         for(ParamL::iterator i = req.query.p.begin(); i != req.query.p.end(); i++)
978                         {
979                                 paramlen += i->size();
980                         }
981                         
982                         /* To avoid a lot of allocations, allocate enough memory for the biggest the escaped query could possibly be.
983                          * sizeofquery + (totalparamlength*2) + 1
984                          * 
985                          * The +1 is for null-terminating the string for PQsendQuery()
986                          */
987                         
988                         query = new char[req.query.q.length() + (paramlen*2)];
989                         queryend = query;
990                         
991                         /* Okay, now we have a buffer large enough we need to start copying the query into it and escaping and substituting
992                          * the parameters into it...
993                          */
994                         
995                         for(unsigned int i = 0; i < req.query.q.length(); i++)
996                         {
997                                 if(req.query.q[i] == '?')
998                                 {
999                                         /* We found a place to substitute..what fun.
1000                                          * Use the PgSQL calls to escape and write the
1001                                          * escaped string onto the end of our query buffer,
1002                                          * then we "just" need to make sure queryend is
1003                                          * pointing at the right place.
1004                                          */
1005                                         
1006                                         if(req.query.p.size())
1007                                         {
1008                                                 int error = 0;
1009                                                 size_t len = 0;
1010
1011 #ifdef PGSQL_HAS_ESCAPECONN
1012                                                 len = PQescapeStringConn(sql, queryend, req.query.p.front().c_str(), req.query.p.front().length(), &error);
1013 #else
1014                                                 len = PQescapeStringConn(queryend, req.query.p.front().c_str(), req.query.p.front().length());
1015                                                 error = 0;
1016 #endif
1017                                                 
1018                                                 if(error)
1019                                                 {
1020                                                         log(DEBUG, "Apparently PQescapeStringConn() failed somehow...don't know how or what to do...");
1021                                                 }
1022                                                 
1023                                                 log(DEBUG, "Appended %d bytes of escaped string onto the query", len);
1024                                                 
1025                                                 /* Incremenet queryend to the end of the newly escaped parameter */
1026                                                 queryend += len;
1027                                                 
1028                                                 /* Remove the parameter we just substituted in */
1029                                                 req.query.p.pop_front();
1030                                         }
1031                                         else
1032                                         {
1033                                                 log(DEBUG, "Found a substitution location but no parameter to substitute :|");
1034                                                 break;
1035                                         }
1036                                 }
1037                                 else
1038                                 {
1039                                         *queryend = req.query.q[i];
1040                                         queryend++;
1041                                 }
1042                         }
1043                         
1044                         /* Null-terminate the query */
1045                         *queryend = 0;
1046         
1047                         log(DEBUG, "Attempting to dispatch query: %s", query);
1048                         
1049                         req.query.q = query;
1050
1051                         if(PQsendQuery(sql, query))
1052                         {
1053                                 log(DEBUG, "Dispatched query successfully");
1054                                 qinprog = true;
1055                                 DELETE(query);
1056                                 return SQLerror();
1057                         }
1058                         else
1059                         {
1060                                 log(DEBUG, "Failed to dispatch query: %s", PQerrorMessage(sql));
1061                                 DELETE(query);
1062                                 return SQLerror(QSEND_FAIL, PQerrorMessage(sql));
1063                         }
1064                 }
1065         }
1066
1067         log(DEBUG, "Can't query until connection is complete");
1068         return SQLerror(BAD_CONN, "Can't query until connection is complete");
1069 }
1070
1071 SQLerror SQLConn::Query(const SQLrequest &req)
1072 {
1073         queue.push(req);
1074         
1075         if(!qinprog && queue.totalsize())
1076         {
1077                 /* There's no query currently in progress, and there's queries in the queue. */
1078                 SQLrequest& query = queue.front();
1079                 return DoQuery(query);
1080         }
1081         else
1082         {
1083                 return SQLerror();
1084         }
1085 }
1086
1087 void SQLConn::OnUnloadModule(Module* mod)
1088 {
1089         queue.PurgeModule(mod);
1090 }
1091
1092 class ModulePgSQLFactory : public ModuleFactory
1093 {
1094  public:
1095         ModulePgSQLFactory()
1096         {
1097         }
1098         
1099         ~ModulePgSQLFactory()
1100         {
1101         }
1102         
1103         virtual Module * CreateModule(Server* Me)
1104         {
1105                 return new ModulePgSQL(Me);
1106         }
1107 };
1108
1109
1110 extern "C" void * init_module( void )
1111 {
1112         return new ModulePgSQLFactory;
1113 }