]> git.netwichtig.de Git - user/henk/code/inspircd.git/blob - src/modules/extra/m_pgsql.cpp
e16f38e5990fc5aa508c22d8d3866841ac9b99ae
[user/henk/code/inspircd.git] / src / modules / extra / m_pgsql.cpp
1 /*       +------------------------------------+
2  *       | Inspire Internet Relay Chat Daemon |
3  *       +------------------------------------+
4  *
5  *  InspIRCd is copyright (C) 2002-2004 ChatSpike-Dev.
6  *                       E-mail:
7  *                <brain@chatspike.net>
8  *                <Craig@chatspike.net>
9  *                <omster@gmail.com>
10  *     
11  * Written by Craig Edwards, Craig McLure, and others.
12  * This program is free but copyrighted software; see
13  *            the file COPYING for details.
14  *
15  * ---------------------------------------------------
16  */
17
18 #include <sstream>
19 #include <string>
20 #include <map>
21 #include <libpq-fe.h>
22
23 #include "users.h"
24 #include "channels.h"
25 #include "modules.h"
26 #include "helperfuncs.h"
27 #include "inspircd.h"
28 #include "configreader.h"
29
30 #include "m_sqlv2.h"
31
32 /* $ModDesc: PostgreSQL Service Provider module for all other m_sql* modules, uses v2 of the SQL API */
33 /* $CompileFlags: -I`pg_config --includedir` */
34 /* $LinkerFlags: -L`pg_config --libdir` -lpq */
35
/* Module-level access to core state.
 * The core-defined InspSocket constructors aren't suitable for a socket whose
 * descriptor is created by libpq, so SQLConn reimplements connection setup
 * itself. Doing that needs direct access to the socket engine
 * (ServerInstance->SE) and to the descriptor lookup table below.
 */
extern InspIRCd* ServerInstance;
/* Descriptor -> InspSocket table, indexed by fd. SQLConn stores itself here
 * while its DNS lookup is pending and again once libpq's socket is known.
 * NOTE(review): this is a definition, not an `extern` declaration. If the core
 * defines its own socket_ref and dispatches events through it, this module's
 * entries may end up in a separate array. Confirm whether this should be extern.
 */
InspSocket* socket_ref[MAX_DESCRIPTORS];
46
/* Forward declaration, so the ConnMap typedef can sit here at the top of the file */
class SQLConn;

/* Maps the id of each <database> config tag to its SQLConn; ModulePgSQL keeps one of these */
typedef std::map<std::string, SQLConn*> ConnMap;
51
/** Which phase an SQLConn is in, and which socket event it is waiting for.
 * The C* values cover the libpq connection handshake; the W* values cover
 * normal query traffic once connected.
 */
enum SQLstatus
{
	CREAD = 0,	/* Connecting, wants a read event */
	CWRITE = 1,	/* Connecting, wants a write event */
	WREAD = 2,	/* Connected/working, wants a read event */
	WWRITE = 3	/* Connected/working, wants a write event */
};
58
/** A simple error value carrying a human-readable message in `err`. */
class SQLerror
{
public:
	/** Build an error holding a copy of the given message.
	 * @param message The error text.
	 */
	SQLerror(const std::string &message)
	: err(message)
	{
	}

	/** The error text supplied at construction. */
	std::string err;
};
69
70 /** SQLConn represents one SQL session.
71  * Each session has its own persistent connection to the database.
72  * This is a subclass of InspSocket so it can easily recieve read/write events from the core socket
73  * engine, unlike the original MySQL module this module does not block. Ever. It gets a mild stabbing
74  * if it dares to.
75  */
76
77 class SQLConn : public InspSocket
78 {
79 private:
80         Server* Srv;                    /* Server* for..uhm..something, maybe */
81         std::string     dbhost; /* Database server hostname */
82         unsigned int    dbport; /* Database server port */
83         std::string     dbname; /* Database name */
84         std::string     dbuser; /* Database username */
85         std::string     dbpass; /* Database password */
86         bool                    ssl;    /* If we should require SSL */
87         PGconn*                 sql;    /* PgSQL database connection handle */
88         SQLstatus               status; /* PgSQL database connection status */
89
90 public:
91
92         /* This class should only ever be created inside this module, using this constructor, so we don't have to worry about the default ones */
93
94         SQLConn(Server* srv, const std::string &h, unsigned int p, const std::string &d, const std::string &u, const std::string &pwd, bool s)
95         : InspSocket::InspSocket(), Srv(srv), dbhost(h), dbport(p), dbname(d), dbuser(u), dbpass(pwd), ssl(s), sql(NULL), status(CWRITE)
96         {
97                 log(DEBUG, "Creating new PgSQL connection to database %s on %s:%u (%s/%s)", dbname.c_str(), dbhost.c_str(), dbport, dbuser.c_str(), dbpass.c_str());
98
99                 /* Some of this could be reviewed, unsure if I need to fill 'host' etc...
100                  * just copied this over from the InspSocket constructor.
101                  */
102                 strlcpy(this->host, dbhost.c_str(), MAXBUF);
103                 this->port = dbport;
104                 
105                 this->ClosePending = false;
106                 
107                 if(!inet_aton(this->host, &this->addy))
108                 {
109                         /* Its not an ip, spawn the resolver.
110                          * PgSQL doesn't do nonblocking DNS 
111                          * lookups, so we do it for it.
112                          */
113                         
114                         log(DEBUG,"Attempting to resolve %s", this->host);
115                         
116                         this->dns.SetNS(Srv->GetConfig()->DNSServer);
117                         this->dns.ForwardLookupWithFD(this->host, fd);
118                         
119                         this->state = I_RESOLVING;
120                         socket_ref[this->fd] = this;
121                         
122                         return;
123                 }
124                 else
125                 {
126                         log(DEBUG,"No need to resolve %s", this->host);
127                         strlcpy(this->IP, this->host, MAXBUF);
128                         
129                         if(!this->DoConnect())
130                         {
131                                 throw ModuleException("Connect failed");
132                         }
133                 }
134                 
135                 exit(-1);
136         }
137         
138         ~SQLConn()
139         {
140                 
141         }
142         
143         bool DoResolve()
144         {       
145                 log(DEBUG, "Checking for DNS lookup result");
146                 
147                 if(this->dns.HasResult())
148                 {
149                         std::string res_ip = dns.GetResultIP();
150                         
151                         if(res_ip.length())
152                         {
153                                 log(DEBUG, "Got result: %s", res_ip.c_str());
154                                 
155                                 strlcpy(this->IP, res_ip.c_str(), MAXBUF);
156                                 dbhost = res_ip;
157                                 
158                                 socket_ref[this->fd] = NULL;
159                                 
160                                 return this->DoConnect();
161                         }
162                         else
163                         {
164                                 log(DEBUG, "DNS lookup failed, dying horribly");
165                                 Close();
166                                 return false;
167                         }
168                 }
169                 else
170                 {
171                         log(DEBUG, "No result for lookup yet!");
172                         return true;
173                 }
174                 
175                 exit(-1);
176         }
177
178         bool DoConnect()
179         {
180                 log(DEBUG, "SQLConn::DoConnect()");
181                 
182                 if(!(sql = PQconnectStart(MkInfoStr().c_str())))
183                 {
184                         log(DEBUG, "Couldn't allocate PGconn structure, aborting: %s", PQerrorMessage(sql));
185                         Close();
186                         return false;
187                 }
188                 
189                 if(PQstatus(sql) == CONNECTION_BAD)
190                 {
191                         log(DEBUG, "PQconnectStart failed: %s", PQerrorMessage(sql));
192                         Close();
193                         return false;
194                 }
195                 
196                 ShowStatus();
197                 
198                 if(PQsetnonblocking(sql, 1) == -1)
199                 {
200                         log(DEBUG, "Couldn't set connection nonblocking: %s", PQerrorMessage(sql));
201                         Close();
202                         return false;
203                 }
204                 
205                 /* OK, we've initalised the connection, now to get it hooked into the socket engine
206                  * and then start polling it.
207                  */
208                 
209                 log(DEBUG, "Old DNS socket: %d", this->fd);
210                 this->fd = PQsocket(sql);
211                 log(DEBUG, "New SQL socket: %d", this->fd);
212                 
213                 if(this->fd <= -1)
214                 {
215                         log(DEBUG, "PQsocket says we have an invalid FD: %d", this->fd);
216                         Close();
217                         return false;
218                 }
219                 
220                 this->state = I_CONNECTING;
221                 ServerInstance->SE->AddFd(this->fd,false,X_ESTAB_MODULE);
222                 socket_ref[this->fd] = this;
223                 
224                 /* Socket all hooked into the engine, now to tell PgSQL to start connecting */
225                 
226                 return DoPoll();
227         }
228         
229         virtual void Close()
230         {
231                 this->fd = -1;
232                 this->state = I_ERROR;
233                 this->OnError(I_ERR_SOCKET);
234                 this->ClosePending = true;
235                 log(DEBUG,"SQLConn::Close");
236                 
237                 if(sql)
238                 {
239                         PQfinish(sql);
240                         sql = NULL;
241                 }
242                 
243                 return;
244         }
245         
246         bool DoPoll()
247         {
248                 switch(PQconnectPoll(sql))
249                 {
250                         case PGRES_POLLING_WRITING:
251                                 log(DEBUG, "PGconnectPoll: PGRES_POLLING_WRITING");
252                                 status = CWRITE;
253                                 DoPoll();
254                                 break;
255                         case PGRES_POLLING_READING:
256                                 log(DEBUG, "PGconnectPoll: PGRES_POLLING_READING");
257                                 status = CREAD;
258                                 break;
259                         case PGRES_POLLING_FAILED:
260                                 log(DEBUG, "PGconnectPoll: PGRES_POLLING_FAILED: %s", PQerrorMessage(sql));
261                                 Close();
262                                 return false;
263                         case PGRES_POLLING_OK:
264                                 log(DEBUG, "PGconnectPoll: PGRES_POLLING_OK");
265                                 status = WWRITE;
266                                 Query("SELECT * FROM rawr");
267                                 break;
268                         default:
269                                 log(DEBUG, "PGconnectPoll: wtf?");
270                                 break;
271                 }
272                 
273                 return true;
274         }
275         
276         void ShowStatus()
277         {
278                 switch(PQstatus(sql))
279                 {
280                         case CONNECTION_STARTED:
281                                 log(DEBUG, "PQstatus: CONNECTION_STARTED: Waiting for connection to be made.");
282                                 break;
283  
284                         case CONNECTION_MADE:
285                                 log(DEBUG, "PQstatus: CONNECTION_MADE: Connection OK; waiting to send.");
286                                 break;
287                         
288                         case CONNECTION_AWAITING_RESPONSE:
289                                 log(DEBUG, "PQstatus: CONNECTION_AWAITING_RESPONSE: Waiting for a response from the server.");
290                                 break;
291                         
292                         case CONNECTION_AUTH_OK:
293                                 log(DEBUG, "PQstatus: CONNECTION_AUTH_OK: Received authentication; waiting for backend start-up to finish.");
294                                 break;
295                         
296                         case CONNECTION_SSL_STARTUP:
297                                 log(DEBUG, "PQstatus: CONNECTION_SSL_STARTUP: Negotiating SSL encryption.");
298                                 break;
299                         
300                         case CONNECTION_SETENV:
301                                 log(DEBUG, "PQstatus: CONNECTION_SETENV: Negotiating environment-driven parameter settings.");
302                                 break;
303                         
304                         default:
305                                 log(DEBUG, "PQstatus: ???");
306                 }
307         }
308         
309         virtual bool OnDataReady()
310         {
311                 /* Always return true here, false would close the socket - we need to do that ourselves with the pgsql API */
312                 log(DEBUG, "OnDataReady(): status = %s", StatusStr());
313                 
314                 return DoEvent();
315         }
316         
317         virtual bool OnConnected()
318         {
319                 log(DEBUG, "OnConnected(): status = %s", StatusStr());
320                 
321                 return DoEvent();
322         }
323         
324         bool DoEvent()
325         {
326                 if((status == CREAD) || (status == CWRITE))
327                 {
328                         DoPoll();
329                 }
330                 else
331                 {
332                         if(PQconsumeInput(sql))
333                         {
334                                 log(DEBUG, "PQconsumeInput succeeded");
335                                 
336                                 if(PQisBusy(sql))
337                                 {
338                                         log(DEBUG, "Still busy processing command though");
339                                 }
340                                 else
341                                 {
342                                         log(DEBUG, "Looks like we have a result to process!");
343                                         
344                                         while(PGresult* result = PQgetResult(sql))
345                                         {
346                                                 int cols = PQnfields(result);
347                                                 
348                                                 log(DEBUG, "Got result! :D");
349                                                 log(DEBUG, "%d rows, %d columns checking now what the column names are", PQntuples(result), cols);
350                                                 
351                                                 for(int i = 0; i < cols; i++)
352                                                 {
353                                                         log(DEBUG, "Column name: %s (%d)", PQfname(result, i));
354                                                 }
355                                                 
356                                                 PQclear(result);
357                                         }
358                                 }
359                         }
360                         else
361                         {
362                                 log(DEBUG, "PQconsumeInput failed: %s", PQerrorMessage(sql));
363                         }
364                 }
365
366                 return true;
367         }
368         
369         std::string MkInfoStr()
370         {                       
371                 /* XXX - This needs nonblocking DNS lookups */
372                 
373                 std::ostringstream conninfo("connect_timeout = '2'");
374                 
375                 if(dbhost.length())
376                         conninfo << " hostaddr = '" << dbhost << "'";
377                 
378                 if(dbport)
379                         conninfo << " port = '" << dbport << "'";
380                 
381                 if(dbname.length())
382                         conninfo << " dbname = '" << dbname << "'";
383                 
384                 if(dbuser.length())
385                         conninfo << " user = '" << dbuser << "'";
386                 
387                 if(dbpass.length())
388                         conninfo << " password = '" << dbpass << "'";
389                 
390                 if(ssl)
391                         conninfo << " sslmode = 'require'";
392                 
393                 return conninfo.str();
394         }
395         
396         const char* StatusStr()
397         {
398                 if(status == CREAD) return "CREAD";
399                 if(status == CWRITE) return "CWRITE";
400                 if(status == WREAD) return "WREAD";
401                 if(status == WWRITE) return "WWRITE";
402                 return "Err...what, erm..BUG!";
403         }
404         
405         bool Query(const std::string &query)
406         {
407                 if((status == WREAD) || (status == WWRITE))
408                 {
409                         if(PQsendQuery(sql, query.c_str()))
410                         {
411                                 log(DEBUG, "Dispatched query: %s", query.c_str());
412                                 return true;
413                         }
414                         else
415                         {
416                                 log(DEBUG, "Failed to dispatch query: %s", PQerrorMessage(sql));
417                                 return false;
418                         }
419                 }
420
421                 log(DEBUG, "Can't query until connection is complete");
422                 return false;
423         }
424
425         virtual void OnClose()
426         {
427                 /* Close PgSQL connection */
428         }
429
430         virtual void OnError(InspSocketError e)
431         {
432                 /* Unsure if we need this, we should be reading/writing via the PgSQL API rather than the insp one... */
433         }
434         
435         virtual void OnTimeout()
436         {
437                 /* Unused, I think */
438         }
439         
440 };
441
442 class ModulePgSQL : public Module
443 {
444 private:
445         Server* Srv;
446         ConnMap connections;
447
448 public:
449         ModulePgSQL(Server* Me)
450         : Module::Module(Me), Srv(Me)
451         {
452                 log(DEBUG, "%s 'SQL' feature", Srv->PublishFeature("SQL", this) ? "Published" : "Couldn't publish");
453                 log(DEBUG, "%s 'PgSQL' feature", Srv->PublishFeature("PgSQL", this) ? "Published" : "Couldn't publish");
454                 
455                 OnRehash("");
456         }
457
458         void Implements(char* List)
459         {
460                 List[I_OnRehash] = List[I_OnUserRegister] = List[I_OnCheckReady] = List[I_OnUserDisconnect] = 1;
461         }
462
463         virtual void OnRehash(const std::string &parameter)
464         {
465                 ConfigReader conf;
466                 
467                 /* Delete all the SQLConn objects in the connection lists,
468                  * this will call their destructors where they can handle
469                  * closing connections and such.
470                  */
471                 for(ConnMap::iterator iter = connections.begin(); iter != connections.end(); iter++)
472                 {
473                         DELETE(iter->second);
474                 }
475                 
476                 /* Empty out our list of connections */
477                 connections.clear();
478
479                 for(int i = 0; i < conf.Enumerate("database"); i++)
480                 {
481                         std::string id;
482                         SQLConn* newconn;
483                         
484                         id = conf.ReadValue("database", "id", i);
485                         newconn = new SQLConn(Srv,      conf.ReadValue("database", "hostname", i),
486                                                                                 conf.ReadInteger("database", "port", i, true),
487                                                                                 conf.ReadValue("database", "name", i),
488                                                                                 conf.ReadValue("database", "username", i),
489                                                                                 conf.ReadValue("database", "password", i),
490                                                                                 conf.ReadFlag("database", "ssl", i));
491                         
492                         connections.insert(std::make_pair(id, newconn));
493                 }       
494         }
495                 
496         virtual Version GetVersion()
497         {
498                 return Version(1, 0, 0, 0, VF_VENDOR|VF_SERVICEPROVIDER);
499         }
500         
501         virtual ~ModulePgSQL()
502         {
503         }       
504 };
505
506 class ModulePgSQLFactory : public ModuleFactory
507 {
508  public:
509         ModulePgSQLFactory()
510         {
511         }
512         
513         ~ModulePgSQLFactory()
514         {
515         }
516         
517         virtual Module * CreateModule(Server* Me)
518         {
519                 return new ModulePgSQL(Me);
520         }
521 };
522
523
524 extern "C" void * init_module( void )
525 {
526         return new ModulePgSQLFactory;
527 }