]> git.netwichtig.de Git - user/henk/code/inspircd.git/blob - src/modules/extra/m_pgsql.cpp
Change InspSocket's private members to only be protected, I couldn't find any other...
[user/henk/code/inspircd.git] / src / modules / extra / m_pgsql.cpp
1 /*       +------------------------------------+
2  *       | Inspire Internet Relay Chat Daemon |
3  *       +------------------------------------+
4  *
5  *  InspIRCd is copyright (C) 2002-2004 ChatSpike-Dev.
6  *                       E-mail:
7  *                <brain@chatspike.net>
8  *                <Craig@chatspike.net>
9  *                <omster@gmail.com>
10  *     
11  * Written by Craig Edwards, Craig McLure, and others.
12  * This program is free but copyrighted software; see
13  *            the file COPYING for details.
14  *
15  * ---------------------------------------------------
16  */
17
18 #include <sstream>
19 #include <string>
20 #include <map>
21 #include <libpq-fe.h>
22
23 #include "users.h"
24 #include "channels.h"
25 #include "modules.h"
26 #include "helperfuncs.h"
27 #include "inspircd.h"
28 #include "configreader.h"
29
30 #include "m_sqlv2.h"
31
32 /* $ModDesc: PostgreSQL Service Provider module for all other m_sql* modules, uses v2 of the SQL API */
33 /* $CompileFlags: -I`pg_config --includedir` */
34 /* $LinkerFlags: -L`pg_config --libdir` -lpq */
35
/* The constructors InspSocket provides don't fit this module: the PgSQL
 * socket is created by libpq (see PQsocket() in SQLConn::DoConnect()),
 * not by InspSocket itself. SQLConn therefore hooks itself into the core
 * by hand, which needs the server instance (for its socket engine,
 * ServerInstance->SE) and the socket_ref table, which maps a file
 * descriptor to the InspSocket that owns it.
 *
 * NOTE(review): socket_ref is *defined* here rather than declared extern.
 * If the core also defines this table, this line should presumably read
 * 'extern InspSocket* socket_ref[MAX_DESCRIPTORS];' so the module and the
 * core share one table - confirm against the core sources.
 */
extern InspIRCd* ServerInstance;
InspSocket* socket_ref[MAX_DESCRIPTORS];
46
/* SQLConn is only declared here so that the connection map typedef can
 * live at the top of the file alongside the other shared types.
 */
class SQLConn;

/* Maps a configured database id (the <database id="..."> value) to its connection */
typedef std::map< std::string, SQLConn* > ConnMap;

/* Connection state of an SQLConn:
 *   CREAD  - still connecting, wants a read event
 *   CWRITE - still connecting, wants a write event
 *   WREAD  - connected/working, wants a read event
 *   WWRITE - connected/working, wants a write event
 */
enum SQLstatus
{
	CREAD = 0,
	CWRITE = 1,
	WREAD = 2,
	WWRITE = 3
};
58
/** Carries the text of an SQL error. */
class SQLerror
{
public:
	std::string err; /* Human-readable error message */

	/** Store a copy of the supplied error message. */
	SQLerror(const std::string &message) : err(message) {}
};
69
70 /** SQLConn represents one SQL session.
71  * Each session has its own persistent connection to the database.
72  * This is a subclass of InspSocket so it can easily recieve read/write events from the core socket
73  * engine, unlike the original MySQL module this module does not block. Ever. It gets a mild stabbing
74  * if it dares to.
75  */
76
77 class SQLConn : public InspSocket
78 {
79 private:
80         Server* Srv;                    /* Server* for..uhm..something, maybe */
81         std::string     dbhost; /* Database server hostname */
82         unsigned int    dbport; /* Database server port */
83         std::string     dbname; /* Database name */
84         std::string     dbuser; /* Database username */
85         std::string     dbpass; /* Database password */
86         bool                    ssl;    /* If we should require SSL */
87         PGconn*                 sql;    /* PgSQL database connection handle */
88         SQLstatus               status; /* PgSQL database connection status */
89
90 public:
91
92         /* This class should only ever be created inside this module, using this constructor, so we don't have to worry about the default ones */
93
94         SQLConn(Server* srv, const std::string &h, unsigned int p, const std::string &d, const std::string &u, const std::string &pwd, bool s)
95         : InspSocket::InspSocket(), Srv(srv), dbhost(h), dbport(p), dbname(d), dbuser(u), dbpass(pwd), ssl(s), sql(NULL), status(CWRITE)
96         {
97                 log(DEBUG, "Creating new PgSQL connection to database %s on %s:%u (%s/%s)", dbname.c_str(), dbhost.c_str(), dbport, dbuser.c_str(), dbpass.c_str());
98
99                 /* Some of this could be reviewed, unsure if I need to fill 'host' etc...
100                  * just copied this over from the InspSocket constructor.
101                  */
102                 strlcpy(this->host, dbhost.c_str(), MAXBUF);
103                 this->port = dbport;
104                 
105                 this->ClosePending = false;
106                 
107                 if(!inet_aton(this->host, &this->addy))
108                 {
109                         /* Its not an ip, spawn the resolver.
110                          * PgSQL doesn't do nonblocking DNS 
111                          * lookups, so we do it for it.
112                          */
113                         
114                         log(DEBUG,"Attempting to resolve %s", this->host);
115                         
116                         this->dns.SetNS(Srv->GetConfig()->DNSServer);
117                         this->dns.ForwardLookupWithFD(this->host, fd);
118                         
119                         this->state = I_RESOLVING;
120                         socket_ref[this->fd] = this;
121                         
122                         return;
123                 }
124                 else
125                 {
126                         log(DEBUG,"No need to resolve %s", this->host);
127                         strlcpy(this->IP, this->host, MAXBUF);
128                         
129                         if(!this->DoConnect())
130                         {
131                                 throw ModuleException("Connect failed");
132                         }
133                 }
134                 
135                 exit(-1);
136         }
137         
138         bool DoResolve()
139         {       
140                 log(DEBUG, "Checking for DNS lookup result");
141                 
142                 if(this->dns.HasResult())
143                 {
144                         std::string res_ip = dns.GetResultIP();
145                         
146                         if(res_ip.length())
147                         {
148                                 log(DEBUG, "Got result: %s", res_ip.c_str());
149                                 
150                                 strlcpy(this->IP, res_ip.c_str(), MAXBUF);
151                                 dbhost = res_ip;
152                                 
153                                 socket_ref[this->fd] = NULL;
154                                 
155                                 return this->DoConnect();
156                         }
157                         else
158                         {
159                                 log(DEBUG, "DNS lookup failed, dying horribly");
160                                 DoError();
161                                 return false;
162                         }
163                 }
164                 else
165                 {
166                         log(DEBUG, "No result for lookup yet!");
167                         return true;
168                 }
169                 
170                 exit(-1);
171         }
172
173         bool DoConnect()
174         {
175                 log(DEBUG, "SQLConn::DoConnect()");
176                 
177                 if(!(sql = PQconnectStart(MkInfoStr().c_str())))
178                 {
179                         log(DEBUG, "Couldn't allocate PGconn structure, aborting: %s", PQerrorMessage(sql));
180                         DoError();
181                         return false;
182                 }
183                 
184                 if(PQstatus(sql) == CONNECTION_BAD)
185                 {
186                         log(DEBUG, "PQconnectStart failed: %s", PQerrorMessage(sql));
187                         DoError();
188                         return false;
189                 }
190                 
191                 ShowStatus();
192                 
193                 if(PQsetnonblocking(sql, 1) == -1)
194                 {
195                         log(DEBUG, "Couldn't set connection nonblocking: %s", PQerrorMessage(sql));
196                         DoError();
197                         return false;
198                 }
199                 
200                 /* OK, we've initalised the connection, now to get it hooked into the socket engine
201                  * and then start polling it.
202                  */
203                 
204                 log(DEBUG, "Old DNS socket: %d", this->fd);
205                 this->fd = PQsocket(sql);
206                 log(DEBUG, "New SQL socket: %d", this->fd);
207                 
208                 if(this->fd <= -1)
209                 {
210                         log(DEBUG, "PQsocket says we have an invalid FD: %d", this->fd);
211                         DoError();
212                         return false;
213                 }
214                 
215                 this->state = I_CONNECTING;
216                 ServerInstance->SE->AddFd(this->fd,false,X_ESTAB_MODULE);
217                 socket_ref[this->fd] = this;
218                 
219                 /* Socket all hooked into the engine, now to tell PgSQL to start connecting */
220                 
221                 return DoPoll();
222         }
223         
224         void DoError()
225         {
226                 this->fd = -1;
227                 this->state = I_ERROR;
228                 this->OnError(I_ERR_SOCKET);
229                 this->ClosePending = true;
230                 log(DEBUG,"SQLConn::DoError");
231                 
232                 if(sql)
233                 {
234                         PQfinish(sql);
235                         sql = NULL;
236                 }
237                 
238                 return;
239         }
240         
241         bool DoPoll()
242         {
243                 switch(PQconnectPoll(sql))
244                 {
245                         case PGRES_POLLING_WRITING:
246                                 log(DEBUG, "PGconnectPoll: PGRES_POLLING_WRITING");
247                                 status = CWRITE;
248                                 DoPoll();
249                                 break;
250                         case PGRES_POLLING_READING:
251                                 log(DEBUG, "PGconnectPoll: PGRES_POLLING_READING");
252                                 status = CREAD;
253                                 break;
254                         case PGRES_POLLING_FAILED:
255                                 log(DEBUG, "PGconnectPoll: PGRES_POLLING_FAILED: %s", PQerrorMessage(sql));
256                                 DoError();
257                                 return false;
258                         case PGRES_POLLING_OK:
259                                 log(DEBUG, "PGconnectPoll: PGRES_POLLING_OK");
260                                 status = WWRITE;
261                                 Query("SELECT * FROM rawr");
262                                 break;
263                         default:
264                                 log(DEBUG, "PGconnectPoll: wtf?");
265                                 break;
266                 }
267                 
268                 return true;
269         }
270         
271         void ShowStatus()
272         {
273                 switch(PQstatus(sql))
274                 {
275                         case CONNECTION_STARTED:
276                                 log(DEBUG, "PQstatus: CONNECTION_STARTED: Waiting for connection to be made.");
277                                 break;
278  
279                         case CONNECTION_MADE:
280                                 log(DEBUG, "PQstatus: CONNECTION_MADE: Connection OK; waiting to send.");
281                                 break;
282                         
283                         case CONNECTION_AWAITING_RESPONSE:
284                                 log(DEBUG, "PQstatus: CONNECTION_AWAITING_RESPONSE: Waiting for a response from the server.");
285                                 break;
286                         
287                         case CONNECTION_AUTH_OK:
288                                 log(DEBUG, "PQstatus: CONNECTION_AUTH_OK: Received authentication; waiting for backend start-up to finish.");
289                                 break;
290                         
291                         case CONNECTION_SSL_STARTUP:
292                                 log(DEBUG, "PQstatus: CONNECTION_SSL_STARTUP: Negotiating SSL encryption.");
293                                 break;
294                         
295                         case CONNECTION_SETENV:
296                                 log(DEBUG, "PQstatus: CONNECTION_SETENV: Negotiating environment-driven parameter settings.");
297                                 break;
298                         
299                         default:
300                                 log(DEBUG, "PQstatus: ???");
301                 }
302         }
303         
304         virtual void OnTimeout()
305         {
306                 /* Unused, I think */
307         }
308
309         virtual bool OnDataReady()
310         {
311                 /* Always return true here, false would close the socket - we need to do that ourselves with the pgsql API */
312                 log(DEBUG, "OnDataReady(): status = %s", StatusStr());
313                 
314                 return DoEvent();
315         }
316         
317         virtual bool OnConnected()
318         {
319                 log(DEBUG, "OnConnected(): status = %s", StatusStr());
320                 
321                 return DoEvent();
322         }
323         
324         bool DoEvent()
325         {
326                 if((status == CREAD) || (status == CWRITE))
327                 {
328                         DoPoll();
329                 }
330                 else
331                 {
332                         if(PQconsumeInput(sql))
333                         {
334                                 log(DEBUG, "PQconsumeInput succeeded");
335                                 
336                                 if(PQisBusy(sql))
337                                 {
338                                         log(DEBUG, "Still busy processing command though");
339                                 }
340                                 else
341                                 {
342                                         log(DEBUG, "Looks like we have a result to process!");
343                                         
344                                         while(PGresult* result = PQgetResult(sql))
345                                         {
346                                                 int cols = PQnfields(result);
347                                                 
348                                                 log(DEBUG, "Got result! :D");
349                                                 log(DEBUG, "%d rows, %d columns checking now what the column names are", PQntuples(result), cols);
350                                                 
351                                                 for(int i = 0; i < cols; i++)
352                                                 {
353                                                         log(DEBUG, "Column name: %s (%d)", PQfname(result, i));
354                                                 }
355                                                 
356                                                 PQclear(result);
357                                         }
358                                 }
359                         }
360                         else
361                         {
362                                 log(DEBUG, "PQconsumeInput failed: %s", PQerrorMessage(sql));
363                         }
364                 }
365
366                 return true;
367         }
368
369         virtual void OnClose()
370         {
371                 /* Close PgSQL connection */
372         }
373
374         virtual void OnError(InspSocketError e)
375         {
376                 /* Unsure if we need this, we should be reading/writing via the PgSQL API rather than the insp one... */
377         }
378         
379         std::string MkInfoStr()
380         {                       
381                 /* XXX - This needs nonblocking DNS lookups */
382                 
383                 std::ostringstream conninfo("connect_timeout = '2'");
384                 
385                 if(dbhost.length())
386                         conninfo << " hostaddr = '" << dbhost << "'";
387                 
388                 if(dbport)
389                         conninfo << " port = '" << dbport << "'";
390                 
391                 if(dbname.length())
392                         conninfo << " dbname = '" << dbname << "'";
393                 
394                 if(dbuser.length())
395                         conninfo << " user = '" << dbuser << "'";
396                 
397                 if(dbpass.length())
398                         conninfo << " password = '" << dbpass << "'";
399                 
400                 if(ssl)
401                         conninfo << " sslmode = 'require'";
402                 
403                 return conninfo.str();
404         }
405         
406         const char* StatusStr()
407         {
408                 if(status == CREAD) return "CREAD";
409                 if(status == CWRITE) return "CWRITE";
410                 if(status == WREAD) return "WREAD";
411                 if(status == WWRITE) return "WWRITE";
412         }
413         
414         bool Query(const std::string &query)
415         {
416                 if((status == WREAD) || (status == WWRITE))
417                 {
418                         if(PQsendQuery(sql, query.c_str()))
419                         {
420                                 log(DEBUG, "Dispatched query: %s", query.c_str());
421                                 return true;
422                         }
423                         else
424                         {
425                                 log(DEBUG, "Failed to dispatch query: %s", PQerrorMessage(sql));
426                                 return false;
427                         }
428                 }
429                 else
430                 {
431                         log(DEBUG, "Can't query until connection is complete");
432                         return false;
433                 }
434         }
435 };
436
437 class ModulePgSQL : public Module
438 {
439 private:
440         Server* Srv;
441         ConnMap connections;
442
443 public:
444         ModulePgSQL(Server* Me)
445         : Module::Module(Me), Srv(Me)
446         {
447                 OnRehash("");
448         }
449
450         void Implements(char* List)
451         {
452                 List[I_OnRehash] = List[I_OnUserRegister] = List[I_OnCheckReady] = List[I_OnUserDisconnect] = 1;
453         }
454
455         virtual void OnRehash(const std::string &parameter)
456         {
457                 ConfigReader conf;
458                 
459                 /* Delete all the SQLConn objects in the connection lists,
460                  * this will call their destructors where they can handle
461                  * closing connections and such.
462                  */
463                 for(ConnMap::iterator iter = connections.begin(); iter != connections.end(); iter++)
464                 {
465                         DELETE(iter->second);
466                 }
467                 
468                 /* Empty out our list of connections */
469                 connections.clear();
470
471                 for(int i = 0; i < conf.Enumerate("database"); i++)
472                 {
473                         std::string id;
474                         SQLConn* newconn;
475                         
476                         id = conf.ReadValue("database", "id", i);
477                         newconn = new SQLConn(Srv,      conf.ReadValue("database", "hostname", i),
478                                                                                 conf.ReadInteger("database", "port", i, true),
479                                                                                 conf.ReadValue("database", "name", i),
480                                                                                 conf.ReadValue("database", "username", i),
481                                                                                 conf.ReadValue("database", "password", i),
482                                                                                 conf.ReadFlag("database", "ssl", i));
483                         
484                         connections.insert(std::make_pair(id, newconn));
485                 }       
486         }
487                 
488         virtual Version GetVersion()
489         {
490                 return Version(1, 0, 0, 0, VF_VENDOR|VF_SERVICEPROVIDER);
491         }
492         
493         virtual ~ModulePgSQL()
494         {
495         }       
496 };
497
498 class ModulePgSQLFactory : public ModuleFactory
499 {
500  public:
501         ModulePgSQLFactory()
502         {
503         }
504         
505         ~ModulePgSQLFactory()
506         {
507         }
508         
509         virtual Module * CreateModule(Server* Me)
510         {
511                 return new ModulePgSQL(Me);
512         }
513 };
514
515
516 extern "C" void * init_module( void )
517 {
518         return new ModulePgSQLFactory;
519 }