]> git.netwichtig.de Git - user/henk/code/inspircd.git/blob - src/modules/extra/m_pgsql.cpp
Merge pull request #109 from Justasic/insp20
[user/henk/code/inspircd.git] / src / modules / extra / m_pgsql.cpp
1 /*
2  * InspIRCd -- Internet Relay Chat Daemon
3  *
4  *   Copyright (C) 2009-2010 Daniel De Graaf <danieldg@inspircd.org>
5  *   Copyright (C) 2006-2007, 2009 Dennis Friis <peavey@inspircd.org>
6  *   Copyright (C) 2006-2007, 2009 Craig Edwards <craigedwards@brainbox.cc>
7  *   Copyright (C) 2008 Robin Burchell <robin+git@viroteck.net>
8  *   Copyright (C) 2008 Thomas Stagner <aquanight@inspircd.org>
9  *   Copyright (C) 2006 Oliver Lupton <oliverlupton@gmail.com>
10  *
11  * This file is part of InspIRCd.  InspIRCd is free software: you can
12  * redistribute it and/or modify it under the terms of the GNU General Public
13  * License as published by the Free Software Foundation, version 2.
14  *
15  * This program is distributed in the hope that it will be useful, but WITHOUT
16  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
17  * FOR A PARTICULAR PURPOSE.  See the GNU General Public License for more
18  * details.
19  *
20  * You should have received a copy of the GNU General Public License
21  * along with this program.  If not, see <http://www.gnu.org/licenses/>.
22  */
23
24
25 #include "inspircd.h"
26 #include <cstdlib>
27 #include <sstream>
28 #include <libpq-fe.h>
29 #include "sql.h"
30
31 /* $ModDesc: PostgreSQL Service Provider module for all other m_sql* modules, uses v2 of the SQL API */
32 /* $CompileFlags: -Iexec("pg_config --includedir") eval("my $s = `pg_config --version`;$s =~ /^.*?(\d+)\.(\d+)\.(\d+).*?$/;my $v = hex(sprintf("0x%02x%02x%02x", $1, $2, $3));print "-DPGSQL_HAS_ESCAPECONN" if(($v >= 0x080104) || ($v >= 0x07030F && $v < 0x070400) || ($v >= 0x07040D && $v < 0x080000) || ($v >= 0x080008 && $v < 0x080100));") */
33 /* $LinkerFlags: -Lexec("pg_config --libdir") -lpq */
34
35 /* SQLConn rewritten by peavey to
36  * use EventHandler instead of
37  * BufferedSocket. This is much neater
38  * and gives total control of destroy
39  * and delete of resources.
40  */
41
/* Forward declarations, so the ConnMap typedef can live up here */
class SQLConn;
class ModulePgSQL;

/** Connections owned by the module, keyed by the <database> tag's id. */
typedef std::map<std::string, SQLConn*> ConnMap;

/** Which phase an SQLConn is in, and which socket event it is waiting for. */
enum SQLstatus
{
	CREAD,	/* Connecting and wants read event */
	CWRITE,	/* Connecting and wants write event */
	WREAD,	/* Connected/Working and wants read event */
	WWRITE,	/* Connected/Working and wants write event */
	RREAD,	/* Resetting and wants read event */
	RWRITE	/* Resetting and wants write event */
};
56
57 class ReconnectTimer : public Timer
58 {
59  private:
60         ModulePgSQL* mod;
61  public:
62         ReconnectTimer(ModulePgSQL* m) : Timer(5, ServerInstance->Time(), false), mod(m)
63         {
64         }
65         virtual void Tick(time_t TIME);
66 };
67
68 struct QueueItem
69 {
70         SQLQuery* c;
71         std::string q;
72         QueueItem(SQLQuery* C, const std::string& Q) : c(C), q(Q) {}
73 };
74
75 /** PgSQLresult is a subclass of the mostly-pure-virtual class SQLresult.
76  * All SQL providers must create their own subclass and define it's methods using that
77  * database library's data retriveal functions. The aim is to avoid a slow and inefficient process
78  * of converting all data to a common format before it reaches the result structure. This way
79  * data is passes to the module nearly as directly as if it was using the API directly itself.
80  */
81
82 class PgSQLresult : public SQLResult
83 {
84         PGresult* res;
85         int currentrow;
86         int rows;
87  public:
88         PgSQLresult(PGresult* result) : res(result), currentrow(0)
89         {
90                 rows = PQntuples(res);
91                 if (!rows)
92                         rows = atoi(PQcmdTuples(res));
93         }
94
95         ~PgSQLresult()
96         {
97                 PQclear(res);
98         }
99
100         virtual int Rows()
101         {
102                 return rows;
103         }
104
105         virtual void GetCols(std::vector<std::string>& result)
106         {
107                 result.resize(PQnfields(res));
108                 for(unsigned int i=0; i < result.size(); i++)
109                 {
110                         result[i] = PQfname(res, i);
111                 }
112         }
113
114         virtual SQLEntry GetValue(int row, int column)
115         {
116                 char* v = PQgetvalue(res, row, column);
117                 if (!v || PQgetisnull(res, row, column))
118                         return SQLEntry();
119
120                 return SQLEntry(std::string(v, PQgetlength(res, row, column)));
121         }
122
123         virtual bool GetRow(SQLEntries& result)
124         {
125                 if (currentrow >= PQntuples(res))
126                         return false;
127                 int ncols = PQnfields(res);
128
129                 for(int i = 0; i < ncols; i++)
130                 {
131                         result.push_back(GetValue(currentrow, i));
132                 }
133                 currentrow++;
134
135                 return true;
136         }
137 };
138
139 /** SQLConn represents one SQL session.
140  */
141 class SQLConn : public SQLProvider, public EventHandler
142 {
143  public:
144         reference<ConfigTag> conf;      /* The <database> entry */
145         std::deque<QueueItem> queue;
146         PGconn*                 sql;            /* PgSQL database connection handle */
147         SQLstatus               status;         /* PgSQL database connection status */
148         QueueItem               qinprog;        /* If there is currently a query in progress */
149
150         SQLConn(Module* Creator, ConfigTag* tag)
151         : SQLProvider(Creator, "SQL/" + tag->getString("id")), conf(tag), sql(NULL), status(CWRITE), qinprog(NULL, "")
152         {
153                 if (!DoConnect())
154                 {
155                         ServerInstance->Logs->Log("m_pgsql",DEFAULT, "WARNING: Could not connect to database " + tag->getString("id")); 
156                         DelayReconnect();
157                 }
158         }
159
160         CullResult cull()
161         {
162                 this->SQLProvider::cull();
163                 ServerInstance->Modules->DelService(*this);
164                 return this->EventHandler::cull();
165         }
166
167         ~SQLConn()
168         {
169                 SQLerror err(SQL_BAD_DBID);
170                 if (qinprog.c)
171                 {
172                         qinprog.c->OnError(err);
173                         delete qinprog.c;
174                 }
175                 for(std::deque<QueueItem>::iterator i = queue.begin(); i != queue.end(); i++)
176                 {
177                         SQLQuery* q = i->c;
178                         q->OnError(err);
179                         delete q;
180                 }
181         }
182
183         virtual void HandleEvent(EventType et, int errornum)
184         {
185                 switch (et)
186                 {
187                         case EVENT_READ:
188                         case EVENT_WRITE:
189                                 DoEvent();
190                         break;
191
192                         case EVENT_ERROR:
193                                 DelayReconnect();
194                 }
195         }
196
197         std::string GetDSN()
198         {
199                 std::ostringstream conninfo("connect_timeout = '5'");
200                 std::string item;
201
202                 if (conf->readString("host", item))
203                         conninfo << " host = '" << item << "'";
204
205                 if (conf->readString("port", item))
206                         conninfo << " port = '" << item << "'";
207
208                 if (conf->readString("name", item))
209                         conninfo << " dbname = '" << item << "'";
210
211                 if (conf->readString("user", item))
212                         conninfo << " user = '" << item << "'";
213
214                 if (conf->readString("pass", item))
215                         conninfo << " password = '" << item << "'";
216
217                 if (conf->getBool("ssl"))
218                         conninfo << " sslmode = 'require'";
219                 else
220                         conninfo << " sslmode = 'disable'";
221
222                 return conninfo.str();
223         }
224
225         bool DoConnect()
226         {
227                 sql = PQconnectStart(GetDSN().c_str());
228                 if (!sql)
229                         return false;
230
231                 if(PQstatus(sql) == CONNECTION_BAD)
232                         return false;
233
234                 if(PQsetnonblocking(sql, 1) == -1)
235                         return false;
236
237                 /* OK, we've initalised the connection, now to get it hooked into the socket engine
238                 * and then start polling it.
239                 */
240                 this->fd = PQsocket(sql);
241
242                 if(this->fd <= -1)
243                         return false;
244
245                 if (!ServerInstance->SE->AddFd(this, FD_WANT_NO_WRITE | FD_WANT_NO_READ))
246                 {
247                         ServerInstance->Logs->Log("m_pgsql",DEBUG, "BUG: Couldn't add pgsql socket to socket engine");
248                         return false;
249                 }
250
251                 /* Socket all hooked into the engine, now to tell PgSQL to start connecting */
252                 return DoPoll();
253         }
254
255         bool DoPoll()
256         {
257                 switch(PQconnectPoll(sql))
258                 {
259                         case PGRES_POLLING_WRITING:
260                                 ServerInstance->SE->ChangeEventMask(this, FD_WANT_POLL_WRITE | FD_WANT_NO_READ);
261                                 status = CWRITE;
262                                 return true;
263                         case PGRES_POLLING_READING:
264                                 ServerInstance->SE->ChangeEventMask(this, FD_WANT_POLL_READ | FD_WANT_NO_WRITE);
265                                 status = CREAD;
266                                 return true;
267                         case PGRES_POLLING_FAILED:
268                                 return false;
269                         case PGRES_POLLING_OK:
270                                 ServerInstance->SE->ChangeEventMask(this, FD_WANT_POLL_READ | FD_WANT_NO_WRITE);
271                                 status = WWRITE;
272                                 DoConnectedPoll();
273                         default:
274                                 return true;
275                 }
276         }
277
278         void DoConnectedPoll()
279         {
280 restart:
281                 while (qinprog.q.empty() && !queue.empty())
282                 {
283                         /* There's no query currently in progress, and there's queries in the queue. */
284                         DoQuery(queue.front());
285                         queue.pop_front();
286                 }
287
288                 if (PQconsumeInput(sql))
289                 {
290                         if (PQisBusy(sql))
291                         {
292                                 /* Nothing happens here */
293                         }
294                         else if (qinprog.c)
295                         {
296                                 /* Fetch the result.. */
297                                 PGresult* result = PQgetResult(sql);
298
299                                 /* PgSQL would allow a query string to be sent which has multiple
300                                  * queries in it, this isn't portable across database backends and
301                                  * we don't want modules doing it. But just in case we make sure we
302                                  * drain any results there are and just use the last one.
303                                  * If the module devs are behaving there will only be one result.
304                                  */
305                                 while (PGresult* temp = PQgetResult(sql))
306                                 {
307                                         PQclear(result);
308                                         result = temp;
309                                 }
310
311                                 /* ..and the result */
312                                 PgSQLresult reply(result);
313                                 switch(PQresultStatus(result))
314                                 {
315                                         case PGRES_EMPTY_QUERY:
316                                         case PGRES_BAD_RESPONSE:
317                                         case PGRES_FATAL_ERROR:
318                                         {
319                                                 SQLerror err(SQL_QREPLY_FAIL, PQresultErrorMessage(result));
320                                                 qinprog.c->OnError(err);
321                                                 break;
322                                         }
323                                         default:
324                                                 /* Other values are not errors */
325                                                 qinprog.c->OnResult(reply);
326                                 }
327
328                                 delete qinprog.c;
329                                 qinprog = QueueItem(NULL, "");
330                                 goto restart;
331                         }
332                         else
333                         {
334                                 qinprog.q = "";
335                         }
336                 }
337                 else
338                 {
339                         /* I think we'll assume this means the server died...it might not,
340                          * but I think that any error serious enough we actually get here
341                          * deserves to reconnect [/excuse]
342                          * Returning true so the core doesn't try and close the connection.
343                          */
344                         DelayReconnect();
345                 }
346         }
347
348         bool DoResetPoll()
349         {
350                 switch(PQresetPoll(sql))
351                 {
352                         case PGRES_POLLING_WRITING:
353                                 ServerInstance->SE->ChangeEventMask(this, FD_WANT_POLL_WRITE | FD_WANT_NO_READ);
354                                 status = CWRITE;
355                                 return DoPoll();
356                         case PGRES_POLLING_READING:
357                                 ServerInstance->SE->ChangeEventMask(this, FD_WANT_POLL_READ | FD_WANT_NO_WRITE);
358                                 status = CREAD;
359                                 return true;
360                         case PGRES_POLLING_FAILED:
361                                 return false;
362                         case PGRES_POLLING_OK:
363                                 ServerInstance->SE->ChangeEventMask(this, FD_WANT_POLL_READ | FD_WANT_NO_WRITE);
364                                 status = WWRITE;
365                                 DoConnectedPoll();
366                         default:
367                                 return true;
368                 }
369         }
370
371         void DelayReconnect();
372
373         void DoEvent()
374         {
375                 if((status == CREAD) || (status == CWRITE))
376                 {
377                         DoPoll();
378                 }
379                 else if((status == RREAD) || (status == RWRITE))
380                 {
381                         DoResetPoll();
382                 }
383                 else
384                 {
385                         DoConnectedPoll();
386                 }
387         }
388
389         void submit(SQLQuery *req, const std::string& q)
390         {
391                 if (qinprog.q.empty())
392                 {
393                         DoQuery(QueueItem(req,q));
394                 }
395                 else
396                 {
397                         // wait your turn.
398                         queue.push_back(QueueItem(req,q));
399                 }
400         }
401
402         void submit(SQLQuery *req, const std::string& q, const ParamL& p)
403         {
404                 std::string res;
405                 unsigned int param = 0;
406                 for(std::string::size_type i = 0; i < q.length(); i++)
407                 {
408                         if (q[i] != '?')
409                                 res.push_back(q[i]);
410                         else
411                         {
412                                 if (param < p.size())
413                                 {
414                                         std::string parm = p[param++];
415                                         char buffer[MAXBUF];
416 #ifdef PGSQL_HAS_ESCAPECONN
417                                         int error;
418                                         PQescapeStringConn(sql, buffer, parm.c_str(), parm.length(), &error);
419                                         if (error)
420                                                 ServerInstance->Logs->Log("m_pgsql", DEBUG, "BUG: Apparently PQescapeStringConn() failed");
421 #else
422                                         PQescapeString         (buffer, parm.c_str(), parm.length());
423 #endif
424                                         res.append(buffer);
425                                 }
426                         }
427                 }
428                 submit(req, res);
429         }
430
431         void submit(SQLQuery *req, const std::string& q, const ParamM& p)
432         {
433                 std::string res;
434                 for(std::string::size_type i = 0; i < q.length(); i++)
435                 {
436                         if (q[i] != '$')
437                                 res.push_back(q[i]);
438                         else
439                         {
440                                 std::string field;
441                                 i++;
442                                 while (i < q.length() && isalnum(q[i]))
443                                         field.push_back(q[i++]);
444                                 i--;
445
446                                 ParamM::const_iterator it = p.find(field);
447                                 if (it != p.end())
448                                 {
449                                         std::string parm = it->second;
450                                         char buffer[MAXBUF];
451 #ifdef PGSQL_HAS_ESCAPECONN
452                                         int error;
453                                         PQescapeStringConn(sql, buffer, parm.c_str(), parm.length(), &error);
454                                         if (error)
455                                                 ServerInstance->Logs->Log("m_pgsql", DEBUG, "BUG: Apparently PQescapeStringConn() failed");
456 #else
457                                         PQescapeString         (buffer, parm.c_str(), parm.length());
458 #endif
459                                         res.append(buffer);
460                                 }
461                         }
462                 }
463                 submit(req, res);
464         }
465
466         void DoQuery(const QueueItem& req)
467         {
468                 if (status != WREAD && status != WWRITE)
469                 {
470                         // whoops, not connected...
471                         SQLerror err(SQL_BAD_CONN);
472                         req.c->OnError(err);
473                         delete req.c;
474                         return;
475                 }
476
477                 if(PQsendQuery(sql, req.q.c_str()))
478                 {
479                         qinprog = req;
480                 }
481                 else
482                 {
483                         SQLerror err(SQL_QSEND_FAIL, PQerrorMessage(sql));
484                         req.c->OnError(err);
485                         delete req.c;
486                 }
487         }
488
489         void Close()
490         {
491                 ServerInstance->SE->DelFd(this);
492
493                 if(sql)
494                 {
495                         PQfinish(sql);
496                         sql = NULL;
497                 }
498         }
499 };
500
501 class ModulePgSQL : public Module
502 {
503  public:
504         ConnMap connections;
505         ReconnectTimer* retimer;
506
507         ModulePgSQL()
508         {
509         }
510
511         void init()
512         {
513                 ReadConf();
514
515                 Implementation eventlist[] = { I_OnUnloadModule, I_OnRehash };
516                 ServerInstance->Modules->Attach(eventlist, this, 2);
517         }
518
519         virtual ~ModulePgSQL()
520         {
521                 if (retimer)
522                         ServerInstance->Timers->DelTimer(retimer);
523                 ClearAllConnections();
524         }
525
526         virtual void OnRehash(User* user)
527         {
528                 ReadConf();
529         }
530
531         void ReadConf()
532         {
533                 ConnMap conns;
534                 ConfigTagList tags = ServerInstance->Config->ConfTags("database");
535                 for(ConfigIter i = tags.first; i != tags.second; i++)
536                 {
537                         if (i->second->getString("module", "pgsql") != "pgsql")
538                                 continue;
539                         std::string id = i->second->getString("id");
540                         ConnMap::iterator curr = connections.find(id);
541                         if (curr == connections.end())
542                         {
543                                 SQLConn* conn = new SQLConn(this, i->second);
544                                 conns.insert(std::make_pair(id, conn));
545                                 ServerInstance->Modules->AddService(*conn);
546                         }
547                         else
548                         {
549                                 conns.insert(*curr);
550                                 connections.erase(curr);
551                         }
552                 }
553                 ClearAllConnections();
554                 conns.swap(connections);
555         }
556
557         void ClearAllConnections()
558         {
559                 for(ConnMap::iterator i = connections.begin(); i != connections.end(); i++)
560                 {
561                         i->second->cull();
562                         delete i->second;
563                 }
564                 connections.clear();
565         }
566
567         void OnUnloadModule(Module* mod)
568         {
569                 SQLerror err(SQL_BAD_DBID);
570                 for(ConnMap::iterator i = connections.begin(); i != connections.end(); i++)
571                 {
572                         SQLConn* conn = i->second;
573                         if (conn->qinprog.c && conn->qinprog.c->creator == mod)
574                         {
575                                 conn->qinprog.c->OnError(err);
576                                 delete conn->qinprog.c;
577                                 conn->qinprog.c = NULL;
578                         }
579                         std::deque<QueueItem>::iterator j = conn->queue.begin();
580                         while (j != conn->queue.end())
581                         {
582                                 SQLQuery* q = j->c;
583                                 if (q->creator == mod)
584                                 {
585                                         q->OnError(err);
586                                         delete q;
587                                         j = conn->queue.erase(j);
588                                 }
589                                 else
590                                         j++;
591                         }
592                 }
593         }
594
595         Version GetVersion()
596         {
597                 return Version("PostgreSQL Service Provider module for all other m_sql* modules, uses v2 of the SQL API", VF_VENDOR);
598         }
599 };
600
601 void ReconnectTimer::Tick(time_t time)
602 {
603         mod->retimer = NULL;
604         mod->ReadConf();
605 }
606
607 void SQLConn::DelayReconnect()
608 {
609         ModulePgSQL* mod = (ModulePgSQL*)(Module*)creator;
610         ConnMap::iterator it = mod->connections.find(conf->getString("id"));
611         if (it != mod->connections.end())
612         {
613                 mod->connections.erase(it);
614                 ServerInstance->GlobalCulls.AddItem((EventHandler*)this);
615                 if (!mod->retimer)
616                 {
617                         mod->retimer = new ReconnectTimer(mod);
618                         ServerInstance->Timers->AddTimer(mod->retimer);
619                 }
620         }
621 }
622
623 MODULE_INIT(ModulePgSQL)