]> git.netwichtig.de Git - user/henk/code/inspircd.git/blob - src/modules/extra/m_pgsql.cpp
655007ea32c81f5f7a4100c26ccce162b8db0d28
[user/henk/code/inspircd.git] / src / modules / extra / m_pgsql.cpp
1 /*       +------------------------------------+
2  *       | Inspire Internet Relay Chat Daemon |
3  *       +------------------------------------+
4  *
5  *  InspIRCd: (C) 2002-2010 InspIRCd Development Team
6  * See: http://wiki.inspircd.org/Credits
7  *
8  * This program is free but copyrighted software; see
9  *            the file COPYING for details.
10  *
11  * ---------------------------------------------------
12  */
13
14 #include "inspircd.h"
15 #include <cstdlib>
16 #include <sstream>
17 #include <libpq-fe.h>
18 #include "sql.h"
19
20 /* $ModDesc: PostgreSQL Service Provider module for all other m_sql* modules, uses v2 of the SQL API */
21 /* $CompileFlags: -Iexec("pg_config --includedir") eval("my $s = `pg_config --version`;$s =~ /^.*?(\d+)\.(\d+)\.(\d+).*?$/;my $v = hex(sprintf("0x%02x%02x%02x", $1, $2, $3));print "-DPGSQL_HAS_ESCAPECONN" if(($v >= 0x080104) || ($v >= 0x07030F && $v < 0x070400) || ($v >= 0x07040D && $v < 0x080000) || ($v >= 0x080008 && $v < 0x080100));") */
22 /* $LinkerFlags: -Lexec("pg_config --libdir") -lpq */
23 /* $ModDep: m_sqlv2.h */
24
25 /* SQLConn rewritten by peavey to
26  * use EventHandler instead of
27  * BufferedSocket. This is much neater
28  * and gives total control of destroy
29  * and delete of resources.
30  */
31
/* Forward declarations, so the ConnMap typedef can sit at the top of the
 * file before either class has been defined.
 */
class ModulePgSQL;
class SQLConn;

/* Connections keyed by string; the module looks entries up using the "id"
 * value of each <database> tag.
 */
typedef std::map<std::string, SQLConn*> ConnMap;
37
/* State of a connection. The leading letter names the phase (C = connecting,
 * W = connected/working, R = resetting) and the rest names the socket event
 * the connection is waiting for.
 */
enum SQLstatus
{
        CREAD,          /* Connecting and wants read event */
        CWRITE,         /* Connecting and wants write event */
        WREAD,          /* Connected/Working and wants read event */
        WWRITE,         /* Connected/Working and wants write event */
        RREAD,          /* Resetting and wants read event */
        RWRITE          /* Resetting and wants write event */
};
46
47 class ReconnectTimer : public Timer
48 {
49  private:
50         ModulePgSQL* mod;
51  public:
52         ReconnectTimer(ModulePgSQL* m) : Timer(5, ServerInstance->Time(), false), mod(m)
53         {
54         }
55         virtual void Tick(time_t TIME);
56 };
57
58 struct QueueItem
59 {
60         SQLQuery* c;
61         std::string q;
62         QueueItem(SQLQuery* C, const std::string& Q) : c(C), q(Q) {}
63 };
64
65 /** PgSQLresult is a subclass of the mostly-pure-virtual class SQLresult.
66  * All SQL providers must create their own subclass and define it's methods using that
67  * database library's data retriveal functions. The aim is to avoid a slow and inefficient process
68  * of converting all data to a common format before it reaches the result structure. This way
69  * data is passes to the module nearly as directly as if it was using the API directly itself.
70  */
71
72 class PgSQLresult : public SQLResult
73 {
74         PGresult* res;
75         int currentrow;
76         int rows;
77  public:
78         PgSQLresult(PGresult* result) : res(result), currentrow(0)
79         {
80                 rows = PQntuples(res);
81                 if (!rows)
82                         rows = atoi(PQcmdTuples(res));
83         }
84
85         ~PgSQLresult()
86         {
87                 PQclear(res);
88         }
89
90         virtual int Rows()
91         {
92                 return rows;
93         }
94
95         virtual void GetCols(std::vector<std::string>& result)
96         {
97                 result.resize(PQnfields(res));
98                 for(unsigned int i=0; i < result.size(); i++)
99                 {
100                         result[i] = PQfname(res, i);
101                 }
102         }
103
104         virtual SQLEntry GetValue(int row, int column)
105         {
106                 char* v = PQgetvalue(res, row, column);
107                 if (!v || PQgetisnull(res, row, column))
108                         return SQLEntry();
109
110                 return SQLEntry(std::string(v, PQgetlength(res, row, column)));
111         }
112
113         virtual bool GetRow(SQLEntries& result)
114         {
115                 if (currentrow >= PQntuples(res))
116                         return false;
117                 int ncols = PQnfields(res);
118
119                 for(int i = 0; i < ncols; i++)
120                 {
121                         result.push_back(GetValue(currentrow, i));
122                 }
123                 currentrow++;
124
125                 return true;
126         }
127 };
128
129 /** SQLConn represents one SQL session.
130  */
131 class SQLConn : public SQLProvider, public EventHandler
132 {
133  public:
134         reference<ConfigTag> conf;      /* The <database> entry */
135         std::deque<QueueItem> queue;
136         PGconn*                 sql;            /* PgSQL database connection handle */
137         SQLstatus               status;         /* PgSQL database connection status */
138         QueueItem               qinprog;        /* If there is currently a query in progress */
139
140         SQLConn(Module* Creator, ConfigTag* tag)
141         : SQLProvider(Creator, "SQL/" + tag->getString("id")), conf(tag), sql(NULL), status(CWRITE), qinprog(NULL, "")
142         {
143                 if (!DoConnect())
144                 {
145                         ServerInstance->Logs->Log("m_pgsql",DEFAULT, "WARNING: Could not connect to database " + tag->getString("id")); 
146                         DelayReconnect();
147                 }
148         }
149
150         CullResult cull()
151         {
152                 this->SQLProvider::cull();
153                 ServerInstance->Modules->DelService(*this);
154                 return this->EventHandler::cull();
155         }
156
157         ~SQLConn()
158         {
159                 SQLerror err(SQL_BAD_DBID);
160                 if (qinprog.c)
161                 {
162                         qinprog.c->OnError(err);
163                         delete qinprog.c;
164                 }
165                 for(std::deque<QueueItem>::iterator i = queue.begin(); i != queue.end(); i++)
166                 {
167                         SQLQuery* q = i->c;
168                         q->OnError(err);
169                         delete q;
170                 }
171         }
172
173         virtual void HandleEvent(EventType et, int errornum)
174         {
175                 switch (et)
176                 {
177                         case EVENT_READ:
178                         case EVENT_WRITE:
179                                 DoEvent();
180                         break;
181
182                         case EVENT_ERROR:
183                                 DelayReconnect();
184                 }
185         }
186
187         std::string GetDSN()
188         {
189                 std::ostringstream conninfo("connect_timeout = '5'");
190                 std::string item;
191
192                 if (conf->readString("host", item))
193                         conninfo << " host = '" << item << "'";
194
195                 if (conf->readString("port", item))
196                         conninfo << " port = '" << item << "'";
197
198                 if (conf->readString("name", item))
199                         conninfo << " dbname = '" << item << "'";
200
201                 if (conf->readString("user", item))
202                         conninfo << " user = '" << item << "'";
203
204                 if (conf->readString("pass", item))
205                         conninfo << " password = '" << item << "'";
206
207                 if (conf->getBool("ssl"))
208                         conninfo << " sslmode = 'require'";
209                 else
210                         conninfo << " sslmode = 'disable'";
211
212                 return conninfo.str();
213         }
214
215         bool DoConnect()
216         {
217                 sql = PQconnectStart(GetDSN().c_str());
218                 if (!sql)
219                         return false;
220
221                 if(PQstatus(sql) == CONNECTION_BAD)
222                         return false;
223
224                 if(PQsetnonblocking(sql, 1) == -1)
225                         return false;
226
227                 /* OK, we've initalised the connection, now to get it hooked into the socket engine
228                 * and then start polling it.
229                 */
230                 this->fd = PQsocket(sql);
231
232                 if(this->fd <= -1)
233                         return false;
234
235                 if (!ServerInstance->SE->AddFd(this, FD_WANT_NO_WRITE | FD_WANT_NO_READ))
236                 {
237                         ServerInstance->Logs->Log("m_pgsql",DEBUG, "BUG: Couldn't add pgsql socket to socket engine");
238                         return false;
239                 }
240
241                 /* Socket all hooked into the engine, now to tell PgSQL to start connecting */
242                 return DoPoll();
243         }
244
245         bool DoPoll()
246         {
247                 switch(PQconnectPoll(sql))
248                 {
249                         case PGRES_POLLING_WRITING:
250                                 ServerInstance->SE->ChangeEventMask(this, FD_WANT_POLL_WRITE | FD_WANT_NO_READ);
251                                 status = CWRITE;
252                                 return true;
253                         case PGRES_POLLING_READING:
254                                 ServerInstance->SE->ChangeEventMask(this, FD_WANT_POLL_READ | FD_WANT_NO_WRITE);
255                                 status = CREAD;
256                                 return true;
257                         case PGRES_POLLING_FAILED:
258                                 return false;
259                         case PGRES_POLLING_OK:
260                                 ServerInstance->SE->ChangeEventMask(this, FD_WANT_POLL_READ | FD_WANT_NO_WRITE);
261                                 status = WWRITE;
262                                 DoConnectedPoll();
263                         default:
264                                 return true;
265                 }
266         }
267
268         void DoConnectedPoll()
269         {
270 restart:
271                 while (qinprog.q.empty() && !queue.empty())
272                 {
273                         /* There's no query currently in progress, and there's queries in the queue. */
274                         DoQuery(queue.front());
275                         queue.pop_front();
276                 }
277
278                 if (PQconsumeInput(sql))
279                 {
280                         if (PQisBusy(sql))
281                         {
282                                 /* Nothing happens here */
283                         }
284                         else if (qinprog.c)
285                         {
286                                 /* Fetch the result.. */
287                                 PGresult* result = PQgetResult(sql);
288
289                                 /* PgSQL would allow a query string to be sent which has multiple
290                                  * queries in it, this isn't portable across database backends and
291                                  * we don't want modules doing it. But just in case we make sure we
292                                  * drain any results there are and just use the last one.
293                                  * If the module devs are behaving there will only be one result.
294                                  */
295                                 while (PGresult* temp = PQgetResult(sql))
296                                 {
297                                         PQclear(result);
298                                         result = temp;
299                                 }
300
301                                 /* ..and the result */
302                                 PgSQLresult reply(result);
303                                 switch(PQresultStatus(result))
304                                 {
305                                         case PGRES_EMPTY_QUERY:
306                                         case PGRES_BAD_RESPONSE:
307                                         case PGRES_FATAL_ERROR:
308                                         {
309                                                 SQLerror err(SQL_QREPLY_FAIL, PQresultErrorMessage(result));
310                                                 qinprog.c->OnError(err);
311                                                 break;
312                                         }
313                                         default:
314                                                 /* Other values are not errors */
315                                                 qinprog.c->OnResult(reply);
316                                 }
317
318                                 delete qinprog.c;
319                                 qinprog = QueueItem(NULL, "");
320                                 goto restart;
321                         }
322                         else
323                         {
324                                 qinprog.q = "";
325                         }
326                 }
327                 else
328                 {
329                         /* I think we'll assume this means the server died...it might not,
330                          * but I think that any error serious enough we actually get here
331                          * deserves to reconnect [/excuse]
332                          * Returning true so the core doesn't try and close the connection.
333                          */
334                         DelayReconnect();
335                 }
336         }
337
338         bool DoResetPoll()
339         {
340                 switch(PQresetPoll(sql))
341                 {
342                         case PGRES_POLLING_WRITING:
343                                 ServerInstance->SE->ChangeEventMask(this, FD_WANT_POLL_WRITE | FD_WANT_NO_READ);
344                                 status = CWRITE;
345                                 return DoPoll();
346                         case PGRES_POLLING_READING:
347                                 ServerInstance->SE->ChangeEventMask(this, FD_WANT_POLL_READ | FD_WANT_NO_WRITE);
348                                 status = CREAD;
349                                 return true;
350                         case PGRES_POLLING_FAILED:
351                                 return false;
352                         case PGRES_POLLING_OK:
353                                 ServerInstance->SE->ChangeEventMask(this, FD_WANT_POLL_READ | FD_WANT_NO_WRITE);
354                                 status = WWRITE;
355                                 DoConnectedPoll();
356                         default:
357                                 return true;
358                 }
359         }
360
361         void DelayReconnect();
362
363         void DoEvent()
364         {
365                 if((status == CREAD) || (status == CWRITE))
366                 {
367                         DoPoll();
368                 }
369                 else if((status == RREAD) || (status == RWRITE))
370                 {
371                         DoResetPoll();
372                 }
373                 else
374                 {
375                         DoConnectedPoll();
376                 }
377         }
378
379         void submit(SQLQuery *req, const std::string& q)
380         {
381                 if (qinprog.q.empty())
382                 {
383                         DoQuery(QueueItem(req,q));
384                 }
385                 else
386                 {
387                         // wait your turn.
388                         queue.push_back(QueueItem(req,q));
389                 }
390         }
391
392         void submit(SQLQuery *req, const std::string& q, const ParamL& p)
393         {
394                 std::string res;
395                 unsigned int param = 0;
396                 for(std::string::size_type i = 0; i < q.length(); i++)
397                 {
398                         if (q[i] != '?')
399                                 res.push_back(q[i]);
400                         else
401                         {
402                                 if (param < p.size())
403                                 {
404                                         std::string parm = p[param++];
405                                         char buffer[MAXBUF];
406 #ifdef PGSQL_HAS_ESCAPECONN
407                                         int error;
408                                         PQescapeStringConn(sql, buffer, parm.c_str(), parm.length(), &error);
409                                         if (error)
410                                                 ServerInstance->Logs->Log("m_pgsql", DEBUG, "BUG: Apparently PQescapeStringConn() failed");
411 #else
412                                         PQescapeString         (buffer, parm.c_str(), parm.length());
413 #endif
414                                         res.append(buffer);
415                                 }
416                         }
417                 }
418                 submit(req, res);
419         }
420
421         void submit(SQLQuery *req, const std::string& q, const ParamM& p)
422         {
423                 std::string res;
424                 for(std::string::size_type i = 0; i < q.length(); i++)
425                 {
426                         if (q[i] != '$')
427                                 res.push_back(q[i]);
428                         else
429                         {
430                                 std::string field;
431                                 i++;
432                                 while (i < q.length() && isalpha(q[i]))
433                                         field.push_back(q[i++]);
434                                 i--;
435
436                                 ParamM::const_iterator it = p.find(field);
437                                 if (it != p.end())
438                                 {
439                                         std::string parm = it->second;
440                                         char buffer[MAXBUF];
441 #ifdef PGSQL_HAS_ESCAPECONN
442                                         int error;
443                                         PQescapeStringConn(sql, buffer, parm.c_str(), parm.length(), &error);
444                                         if (error)
445                                                 ServerInstance->Logs->Log("m_pgsql", DEBUG, "BUG: Apparently PQescapeStringConn() failed");
446 #else
447                                         PQescapeString         (buffer, parm.c_str(), parm.length());
448 #endif
449                                         res.append(buffer);
450                                 }
451                         }
452                 }
453                 submit(req, res);
454         }
455
456         void DoQuery(const QueueItem& req)
457         {
458                 if (status != WREAD && status != WWRITE)
459                 {
460                         // whoops, not connected...
461                         SQLerror err(SQL_BAD_CONN);
462                         req.c->OnError(err);
463                         delete req.c;
464                         return;
465                 }
466
467                 if(PQsendQuery(sql, req.q.c_str()))
468                 {
469                         qinprog = req;
470                 }
471                 else
472                 {
473                         SQLerror err(SQL_QSEND_FAIL, PQerrorMessage(sql));
474                         req.c->OnError(err);
475                         delete req.c;
476                 }
477         }
478
479         void Close()
480         {
481                 ServerInstance->SE->DelFd(this);
482
483                 if(sql)
484                 {
485                         PQfinish(sql);
486                         sql = NULL;
487                 }
488         }
489 };
490
491 class ModulePgSQL : public Module
492 {
493  public:
494         ConnMap connections;
495         ReconnectTimer* retimer;
496
497         ModulePgSQL()
498         {
499         }
500
501         void init()
502         {
503                 ReadConf();
504
505                 Implementation eventlist[] = { I_OnUnloadModule, I_OnRehash };
506                 ServerInstance->Modules->Attach(eventlist, this, 2);
507         }
508
509         virtual ~ModulePgSQL()
510         {
511                 if (retimer)
512                         ServerInstance->Timers->DelTimer(retimer);
513                 ClearAllConnections();
514         }
515
516         virtual void OnRehash(User* user)
517         {
518                 ReadConf();
519         }
520
521         void ReadConf()
522         {
523                 ConnMap conns;
524                 ConfigTagList tags = ServerInstance->Config->ConfTags("database");
525                 for(ConfigIter i = tags.first; i != tags.second; i++)
526                 {
527                         if (i->second->getString("module", "pgsql") != "pgsql")
528                                 continue;
529                         std::string id = i->second->getString("id");
530                         ConnMap::iterator curr = connections.find(id);
531                         if (curr == connections.end())
532                         {
533                                 SQLConn* conn = new SQLConn(this, i->second);
534                                 conns.insert(std::make_pair(id, conn));
535                                 ServerInstance->Modules->AddService(*conn);
536                         }
537                         else
538                         {
539                                 conns.insert(*curr);
540                                 connections.erase(curr);
541                         }
542                 }
543                 ClearAllConnections();
544                 conns.swap(connections);
545         }
546
547         void ClearAllConnections()
548         {
549                 for(ConnMap::iterator i = connections.begin(); i != connections.end(); i++)
550                 {
551                         i->second->cull();
552                         delete i->second;
553                 }
554                 connections.clear();
555         }
556
557         void OnUnloadModule(Module* mod)
558         {
559                 SQLerror err(SQL_BAD_DBID);
560                 for(ConnMap::iterator i = connections.begin(); i != connections.end(); i++)
561                 {
562                         SQLConn* conn = i->second;
563                         if (conn->qinprog.c && conn->qinprog.c->creator == mod)
564                         {
565                                 conn->qinprog.c->OnError(err);
566                                 delete conn->qinprog.c;
567                                 conn->qinprog.c = NULL;
568                         }
569                         std::deque<QueueItem>::iterator j = conn->queue.begin();
570                         while (j != conn->queue.end())
571                         {
572                                 SQLQuery* q = j->c;
573                                 if (q->creator == mod)
574                                 {
575                                         q->OnError(err);
576                                         delete q;
577                                         j = conn->queue.erase(j);
578                                 }
579                                 else
580                                         j++;
581                         }
582                 }
583         }
584
585         Version GetVersion()
586         {
587                 return Version("PostgreSQL Service Provider module for all other m_sql* modules, uses v2 of the SQL API", VF_VENDOR);
588         }
589 };
590
591 void ReconnectTimer::Tick(time_t time)
592 {
593         mod->retimer = NULL;
594         mod->ReadConf();
595 }
596
597 void SQLConn::DelayReconnect()
598 {
599         ModulePgSQL* mod = (ModulePgSQL*)(Module*)creator;
600         ConnMap::iterator it = mod->connections.find(conf->getString("id"));
601         if (it != mod->connections.end())
602         {
603                 mod->connections.erase(it);
604                 ServerInstance->GlobalCulls.AddItem((EventHandler*)this);
605                 if (!mod->retimer)
606                 {
607                         mod->retimer = new ReconnectTimer(mod);
608                         ServerInstance->Timers->AddTimer(mod->retimer);
609                 }
610         }
611 }
612
/* Module entry point. MODULE_INIT comes from the InspIRCd headers;
 * presumably it generates the factory the module loader uses to create a
 * ModulePgSQL instance - confirm against the macro's definition.
 */
MODULE_INIT(ModulePgSQL)