]> git.netwichtig.de Git - user/henk/code/inspircd.git/blob - src/modules/extra/m_pgsql.cpp
Implement OnUnloadModule for postgres
[user/henk/code/inspircd.git] / src / modules / extra / m_pgsql.cpp
1 /*       +------------------------------------+
2  *       | Inspire Internet Relay Chat Daemon |
3  *       +------------------------------------+
4  *
5  *  InspIRCd: (C) 2002-2010 InspIRCd Development Team
6  * See: http://wiki.inspircd.org/Credits
7  *
8  * This program is free but copyrighted software; see
9  *            the file COPYING for details.
10  *
11  * ---------------------------------------------------
12  */
13
14 #include "inspircd.h"
15 #include <cstdlib>
16 #include <sstream>
17 #include <libpq-fe.h>
18 #include "sql.h"
19
20 /* $ModDesc: PostgreSQL Service Provider module for all other m_sql* modules, uses v2 of the SQL API */
21 /* $CompileFlags: -Iexec("pg_config --includedir") eval("my $s = `pg_config --version`;$s =~ /^.*?(\d+)\.(\d+)\.(\d+).*?$/;my $v = hex(sprintf("0x%02x%02x%02x", $1, $2, $3));print "-DPGSQL_HAS_ESCAPECONN" if(($v >= 0x080104) || ($v >= 0x07030F && $v < 0x070400) || ($v >= 0x07040D && $v < 0x080000) || ($v >= 0x080008 && $v < 0x080100));") */
22 /* $LinkerFlags: -Lexec("pg_config --libdir") -lpq */
23 /* $ModDep: m_sqlv2.h */
24
25 /* SQLConn rewritten by peavey to
26  * use EventHandler instead of
27  * BufferedSocket. This is much neater
28  * and gives total control of destroy
29  * and delete of resources.
30  */
31
/* Forward declarations, so the connection map typedef can live at the top of the file */
class ModulePgSQL;
class SQLConn;

/* Connections keyed by the "id" value of their <database> tag */
typedef std::map<std::string, SQLConn*> ConnMap;

/* State of one connection; SQLConn::DoEvent() uses it to pick the poll routine to run */
enum SQLstatus
{
	CREAD,	/* Connecting and wants read event */
	CWRITE,	/* Connecting and wants write event */
	WREAD,	/* Connected/Working and wants read event */
	WWRITE,	/* Connected/Working and wants write event */
	RREAD,	/* Resetting and wants read event */
	RWRITE	/* Resetting and wants write event */
};
46
47 class ReconnectTimer : public Timer
48 {
49  private:
50         ModulePgSQL* mod;
51  public:
52         ReconnectTimer(ModulePgSQL* m) : Timer(5, ServerInstance->Time(), false), mod(m)
53         {
54         }
55         virtual void Tick(time_t TIME);
56 };
57
58
/** PgSQLresult is a subclass of the mostly-pure-virtual class SQLResult.
 * All SQL providers must create their own subclass and define its methods using that
 * database library's data retrieval functions. The aim is to avoid a slow and inefficient process
 * of converting all data to a common format before it reaches the result structure. This way
 * data is passed to the module nearly as directly as if it was using the API directly itself.
 */
65
66 class PgSQLresult : public SQLResult
67 {
68         PGresult* res;
69         int currentrow;
70         int rows;
71  public:
72         PgSQLresult(PGresult* result) : res(result), currentrow(0)
73         {
74                 rows = PQntuples(res);
75                 if (!rows)
76                         rows = atoi(PQcmdTuples(res));
77         }
78
79         ~PgSQLresult()
80         {
81                 PQclear(res);
82         }
83
84         virtual int Rows()
85         {
86                 return rows;
87         }
88
89         virtual void GetCols(std::vector<std::string>& result)
90         {
91                 result.resize(PQnfields(res));
92                 for(unsigned int i=0; i < result.size(); i++)
93                 {
94                         result[i] = PQfname(res, i);
95                 }
96         }
97
98         virtual SQLEntry GetValue(int row, int column)
99         {
100                 char* v = PQgetvalue(res, row, column);
101                 if (!v || PQgetisnull(res, row, column))
102                         return SQLEntry();
103
104                 return SQLEntry(std::string(v, PQgetlength(res, row, column)));
105         }
106
107         virtual bool GetRow(SQLEntries& result)
108         {
109                 if (currentrow >= PQntuples(res))
110                         return false;
111                 int ncols = PQnfields(res);
112
113                 for(int i = 0; i < ncols; i++)
114                 {
115                         result.push_back(GetValue(currentrow, i));
116                 }
117                 currentrow++;
118
119                 return true;
120         }
121 };
122
123 /** SQLConn represents one SQL session.
124  */
125 class SQLConn : public SQLProvider, public EventHandler
126 {
127  public:
128         reference<ConfigTag> conf;      /* The <database> entry */
129         std::deque<SQLQuery*> queue;
130         PGconn*                 sql;            /* PgSQL database connection handle */
131         SQLstatus               status;         /* PgSQL database connection status */
132         SQLQuery*               qinprog;        /* If there is currently a query in progress */
133
134         SQLConn(Module* Creator, ConfigTag* tag)
135         : SQLProvider(Creator, "SQL/" + tag->getString("id")), conf(tag), sql(NULL), status(CWRITE), qinprog(NULL)
136         {
137                 if (!DoConnect())
138                 {
139                         ServerInstance->Logs->Log("m_pgsql",DEFAULT, "WARNING: Could not connect to database " + tag->getString("id")); 
140                         DelayReconnect();
141                 }
142         }
143
144         CullResult cull()
145         {
146                 this->SQLProvider::cull();
147                 ServerInstance->Modules->DelService(*this);
148                 return this->EventHandler::cull();
149         }
150
151         ~SQLConn()
152         {
153                 SQLerror err(SQL_BAD_DBID);
154                 if (qinprog)
155                 {
156                         qinprog->OnError(err);
157                         delete qinprog;
158                 }
159                 for(std::deque<SQLQuery*>::iterator i = queue.begin(); i != queue.end(); i++)
160                 {
161                         SQLQuery* q = *i;
162                         q->OnError(err);
163                         delete q;
164                 }
165         }
166
167         virtual void HandleEvent(EventType et, int errornum)
168         {
169                 switch (et)
170                 {
171                         case EVENT_READ:
172                         case EVENT_WRITE:
173                                 DoEvent();
174                         break;
175
176                         case EVENT_ERROR:
177                                 DelayReconnect();
178                 }
179         }
180
181         std::string GetDSN()
182         {
183                 std::ostringstream conninfo("connect_timeout = '5'");
184                 std::string item;
185
186                 if (conf->readString("host", item))
187                         conninfo << " host = '" << item << "'";
188
189                 if (conf->readString("port", item))
190                         conninfo << " port = '" << item << "'";
191
192                 if (conf->readString("name", item))
193                         conninfo << " dbname = '" << item << "'";
194
195                 if (conf->readString("user", item))
196                         conninfo << " user = '" << item << "'";
197
198                 if (conf->readString("pass", item))
199                         conninfo << " password = '" << item << "'";
200
201                 if (conf->getBool("ssl"))
202                         conninfo << " sslmode = 'require'";
203                 else
204                         conninfo << " sslmode = 'disable'";
205
206                 return conninfo.str();
207         }
208
209         bool DoConnect()
210         {
211                 sql = PQconnectStart(GetDSN().c_str());
212                 if (!sql)
213                         return false;
214
215                 if(PQstatus(sql) == CONNECTION_BAD)
216                         return false;
217
218                 if(PQsetnonblocking(sql, 1) == -1)
219                         return false;
220
221                 /* OK, we've initalised the connection, now to get it hooked into the socket engine
222                 * and then start polling it.
223                 */
224                 this->fd = PQsocket(sql);
225
226                 if(this->fd <= -1)
227                         return false;
228
229                 if (!ServerInstance->SE->AddFd(this, FD_WANT_NO_WRITE | FD_WANT_NO_READ))
230                 {
231                         ServerInstance->Logs->Log("m_pgsql",DEBUG, "BUG: Couldn't add pgsql socket to socket engine");
232                         return false;
233                 }
234
235                 /* Socket all hooked into the engine, now to tell PgSQL to start connecting */
236                 return DoPoll();
237         }
238
239         bool DoPoll()
240         {
241                 switch(PQconnectPoll(sql))
242                 {
243                         case PGRES_POLLING_WRITING:
244                                 ServerInstance->SE->ChangeEventMask(this, FD_WANT_POLL_WRITE | FD_WANT_NO_READ);
245                                 status = CWRITE;
246                                 return true;
247                         case PGRES_POLLING_READING:
248                                 ServerInstance->SE->ChangeEventMask(this, FD_WANT_POLL_READ | FD_WANT_NO_WRITE);
249                                 status = CREAD;
250                                 return true;
251                         case PGRES_POLLING_FAILED:
252                                 return false;
253                         case PGRES_POLLING_OK:
254                                 ServerInstance->SE->ChangeEventMask(this, FD_WANT_POLL_READ | FD_WANT_NO_WRITE);
255                                 status = WWRITE;
256                                 DoConnectedPoll();
257                         default:
258                                 return true;
259                 }
260         }
261
262         void DoConnectedPoll()
263         {
264 restart:
265                 while (!qinprog && !queue.empty())
266                 {
267                         /* There's no query currently in progress, and there's queries in the queue. */
268                         DoQuery(queue.front());
269                         queue.pop_front();
270                 }
271
272                 if (PQconsumeInput(sql))
273                 {
274                         if (PQisBusy(sql))
275                         {
276                                 /* Nothing happens here */
277                         }
278                         else if (qinprog)
279                         {
280                                 /* Fetch the result.. */
281                                 PGresult* result = PQgetResult(sql);
282
283                                 /* PgSQL would allow a query string to be sent which has multiple
284                                  * queries in it, this isn't portable across database backends and
285                                  * we don't want modules doing it. But just in case we make sure we
286                                  * drain any results there are and just use the last one.
287                                  * If the module devs are behaving there will only be one result.
288                                  */
289                                 while (PGresult* temp = PQgetResult(sql))
290                                 {
291                                         PQclear(result);
292                                         result = temp;
293                                 }
294
295                                 /* ..and the result */
296                                 PgSQLresult reply(result);
297                                 switch(PQresultStatus(result))
298                                 {
299                                         case PGRES_EMPTY_QUERY:
300                                         case PGRES_BAD_RESPONSE:
301                                         case PGRES_FATAL_ERROR:
302                                         {
303                                                 SQLerror err(SQL_QREPLY_FAIL, PQresultErrorMessage(result));
304                                                 qinprog->OnError(err);
305                                                 break;
306                                         }
307                                         default:
308                                                 /* Other values are not errors */
309                                                 qinprog->OnResult(reply);
310                                 }
311
312                                 delete qinprog;
313                                 qinprog = NULL;
314                                 goto restart;
315                         }
316                 }
317                 else
318                 {
319                         /* I think we'll assume this means the server died...it might not,
320                          * but I think that any error serious enough we actually get here
321                          * deserves to reconnect [/excuse]
322                          * Returning true so the core doesn't try and close the connection.
323                          */
324                         DelayReconnect();
325                 }
326         }
327
328         bool DoResetPoll()
329         {
330                 switch(PQresetPoll(sql))
331                 {
332                         case PGRES_POLLING_WRITING:
333                                 ServerInstance->SE->ChangeEventMask(this, FD_WANT_POLL_WRITE | FD_WANT_NO_READ);
334                                 status = CWRITE;
335                                 return DoPoll();
336                         case PGRES_POLLING_READING:
337                                 ServerInstance->SE->ChangeEventMask(this, FD_WANT_POLL_READ | FD_WANT_NO_WRITE);
338                                 status = CREAD;
339                                 return true;
340                         case PGRES_POLLING_FAILED:
341                                 return false;
342                         case PGRES_POLLING_OK:
343                                 ServerInstance->SE->ChangeEventMask(this, FD_WANT_POLL_READ | FD_WANT_NO_WRITE);
344                                 status = WWRITE;
345                                 DoConnectedPoll();
346                         default:
347                                 return true;
348                 }
349         }
350
351         void DelayReconnect();
352
353         void DoEvent()
354         {
355                 if((status == CREAD) || (status == CWRITE))
356                 {
357                         DoPoll();
358                 }
359                 else if((status == RREAD) || (status == RWRITE))
360                 {
361                         DoResetPoll();
362                 }
363                 else
364                 {
365                         DoConnectedPoll();
366                 }
367         }
368
369         virtual std::string FormatQuery(const std::string& q, const ParamL& p)
370         {
371                 std::string res;
372                 unsigned int param = 0;
373                 for(std::string::size_type i = 0; i < q.length(); i++)
374                 {
375                         if (q[i] != '?')
376                                 res.push_back(q[i]);
377                         else
378                         {
379                                 // TODO numbered parameter support ('?1')
380                                 if (param < p.size())
381                                 {
382                                         std::string parm = p[param++];
383                                         char buffer[MAXBUF];
384 #ifdef PGSQL_HAS_ESCAPECONN
385                                         int error;
386                                         PQescapeStringConn(sql, buffer, parm.c_str(), parm.length(), &error);
387                                         if (error)
388                                                 ServerInstance->Logs->Log("m_pgsql", DEBUG, "BUG: Apparently PQescapeStringConn() failed");
389 #else
390                                         PQescapeString         (buffer, parm.c_str(), parm.length());
391 #endif
392                                         res.append(buffer);
393                                 }
394                         }
395                 }
396                 return res;
397         }
398
399         std::string FormatQuery(const std::string& q, const ParamM& p)
400         {
401                 std::string res;
402                 for(std::string::size_type i = 0; i < q.length(); i++)
403                 {
404                         if (q[i] != '$')
405                                 res.push_back(q[i]);
406                         else
407                         {
408                                 std::string field;
409                                 i++;
410                                 while (i < q.length() && isalpha(q[i]))
411                                         field.push_back(q[i++]);
412                                 i--;
413
414                                 ParamM::const_iterator it = p.find(field);
415                                 if (it != p.end())
416                                 {
417                                         std::string parm = it->second;
418                                         char buffer[MAXBUF];
419 #ifdef PGSQL_HAS_ESCAPECONN
420                                         int error;
421                                         PQescapeStringConn(sql, buffer, parm.c_str(), parm.length(), &error);
422                                         if (error)
423                                                 ServerInstance->Logs->Log("m_pgsql", DEBUG, "BUG: Apparently PQescapeStringConn() failed");
424 #else
425                                         PQescapeString         (buffer, parm.c_str(), parm.length());
426 #endif
427                                         res.append(buffer);
428                                 }
429                         }
430                 }
431                 return res;
432         }
433
434         virtual void submit(SQLQuery *req)
435         {
436                 if (qinprog)
437                 {
438                         // wait your turn.
439                         queue.push_back(req);
440                 }
441                 else
442                 {
443                         DoQuery(req);
444                 }
445         }
446
447         void DoQuery(SQLQuery* req)
448         {
449                 if (status != WREAD && status != WWRITE)
450                 {
451                         // whoops, not connected...
452                         SQLerror err(SQL_BAD_CONN);
453                         req->OnError(err);
454                         delete req;
455                         return;
456                 }
457
458                 if(PQsendQuery(sql, req->query.c_str()))
459                 {
460                         qinprog = req;
461                 }
462                 else
463                 {
464                         SQLerror err(SQL_QSEND_FAIL, PQerrorMessage(sql));
465                         req->OnError(err);
466                         delete req;
467                 }
468         }
469
470         void Close()
471         {
472                 ServerInstance->SE->DelFd(this);
473
474                 if(sql)
475                 {
476                         PQfinish(sql);
477                         sql = NULL;
478                 }
479         }
480 };
481
482 class DummyQuery : public SQLQuery
483 {
484  public:
485         DummyQuery(Module* me) : SQLQuery(me, "") {}
486         void OnResult(SQLResult& result) {}
487 };
488
489 class ModulePgSQL : public Module
490 {
491  public:
492         ConnMap connections;
493         ReconnectTimer* retimer;
494
495         ModulePgSQL()
496         {
497         }
498
499         void init()
500         {
501                 ReadConf();
502
503                 Implementation eventlist[] = { I_OnUnloadModule, I_OnRehash };
504                 ServerInstance->Modules->Attach(eventlist, this, 2);
505         }
506
507         virtual ~ModulePgSQL()
508         {
509                 if (retimer)
510                         ServerInstance->Timers->DelTimer(retimer);
511                 ClearAllConnections();
512         }
513
514         virtual void OnRehash(User* user)
515         {
516                 ReadConf();
517         }
518
519         void ReadConf()
520         {
521                 ConnMap conns;
522                 ConfigTagList tags = ServerInstance->Config->ConfTags("database");
523                 for(ConfigIter i = tags.first; i != tags.second; i++)
524                 {
525                         if (i->second->getString("module", "pgsql") != "pgsql")
526                                 continue;
527                         std::string id = i->second->getString("id");
528                         ConnMap::iterator curr = connections.find(id);
529                         if (curr == connections.end())
530                         {
531                                 SQLConn* conn = new SQLConn(this, i->second);
532                                 conns.insert(std::make_pair(id, conn));
533                                 ServerInstance->Modules->AddService(*conn);
534                         }
535                         else
536                         {
537                                 conns.insert(*curr);
538                                 connections.erase(curr);
539                         }
540                 }
541                 ClearAllConnections();
542                 conns.swap(connections);
543         }
544
545         void ClearAllConnections()
546         {
547                 for(ConnMap::iterator i = connections.begin(); i != connections.end(); i++)
548                 {
549                         i->second->cull();
550                         delete i->second;
551                 }
552                 connections.clear();
553         }
554
555         void OnUnloadModule(Module* mod)
556         {
557                 SQLerror err(SQL_BAD_DBID);
558                 for(ConnMap::iterator i = connections.begin(); i != connections.end(); i++)
559                 {
560                         SQLConn* conn = i->second;
561                         if (conn->qinprog && conn->qinprog->creator == mod)
562                         {
563                                 conn->qinprog->OnError(err);
564                                 delete conn->qinprog;
565                                 conn->qinprog = new DummyQuery(this);
566                         }
567                         std::deque<SQLQuery*>::iterator j = conn->queue.begin();
568                         while (j != conn->queue.end())
569                         {
570                                 SQLQuery* q = *j;
571                                 if (q->creator == mod)
572                                 {
573                                         q->OnError(err);
574                                         delete q;
575                                         j = conn->queue.erase(j);
576                                 }
577                                 else
578                                         j++;
579                         }
580                 }
581         }
582
583         Version GetVersion()
584         {
585                 return Version("PostgreSQL Service Provider module for all other m_sql* modules, uses v2 of the SQL API", VF_VENDOR);
586         }
587 };
588
589 void ReconnectTimer::Tick(time_t time)
590 {
591         mod->retimer = NULL;
592         mod->ReadConf();
593 }
594
595 void SQLConn::DelayReconnect()
596 {
597         ModulePgSQL* mod = (ModulePgSQL*)(Module*)creator;
598         ConnMap::iterator it = mod->connections.find(conf->getString("id"));
599         if (it != mod->connections.end())
600         {
601                 mod->connections.erase(it);
602                 ServerInstance->GlobalCulls.AddItem((EventHandler*)this);
603                 if (!mod->retimer)
604                 {
605                         mod->retimer = new ReconnectTimer(mod);
606                         ServerInstance->Timers->AddTimer(mod->retimer);
607                 }
608         }
609 }
610
/* Hands ModulePgSQL to the MODULE_INIT macro, which is defined outside this
 * file - presumably it generates the entry point the module loader looks for.
 */
MODULE_INIT(ModulePgSQL)