summaryrefslogtreecommitdiff
path: root/src/modules/extra/m_pgsql.cpp
blob: b3b7d43dfd45444e46fa5001d54413e4ccdf2bfa (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
/*
 * InspIRCd -- Internet Relay Chat Daemon
 *
 *   Copyright (C) 2009-2010 Daniel De Graaf <danieldg@inspircd.org>
 *   Copyright (C) 2006-2007, 2009 Dennis Friis <peavey@inspircd.org>
 *   Copyright (C) 2006-2007, 2009 Craig Edwards <craigedwards@brainbox.cc>
 *   Copyright (C) 2008 Robin Burchell <robin+git@viroteck.net>
 *   Copyright (C) 2008 Thomas Stagner <aquanight@inspircd.org>
 *   Copyright (C) 2006 Oliver Lupton <oliverlupton@gmail.com>
 *
 * This file is part of InspIRCd.  InspIRCd is free software: you can
 * redistribute it and/or modify it under the terms of the GNU General Public
 * License as published by the Free Software Foundation, version 2.
 *
 * This program is distributed in the hope that it will be useful, but WITHOUT
 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
 * FOR A PARTICULAR PURPOSE.  See the GNU General Public License for more
 * details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program.  If not, see <http://www.gnu.org/licenses/>.
 */


#include "inspircd.h"
#include <cstdlib>
#include <sstream>
#include <libpq-fe.h>
#include "modules/sql.h"

/* $CompileFlags: -Iexec("pg_config --includedir") eval("my $s = `pg_config --version`;$s =~ /^.*?(\d+)\.(\d+)\.(\d+).*?$/;my $v = hex(sprintf("0x%02x%02x%02x", $1, $2, $3));print "-DPGSQL_HAS_ESCAPECONN" if(($v >= 0x080104) || ($v >= 0x07030F && $v < 0x070400) || ($v >= 0x07040D && $v < 0x080000) || ($v >= 0x080008 && $v < 0x080100));") */
/* $LinkerFlags: -Lexec("pg_config --libdir") -lpq */

/* SQLConn rewritten by peavey to
 * use EventHandler instead of
 * BufferedSocket. This is much neater
 * and gives total control of destroy
 * and delete of resources.
 */

/* Forward declare, so we can have the typedef neatly at the top.
 * Both classes are defined further down in this file.
 */
class SQLConn;
class ModulePgSQL;

/* Live connections owned by the module, keyed by the "id" value of the
 * <database> tag each one was created from.
 */
typedef std::map<std::string, SQLConn*> ConnMap;

/* State of an SQLConn. The first letter is the phase the connection is in,
 * the rest is the socket event it is waiting for:
 *
 * CREAD,	Connecting and wants read event
 * CWRITE,	Connecting and wants write event
 * WREAD,	Connected/Working and wants read event
 * WWRITE,	Connected/Working and wants write event
 * RREAD,	Resetting and wants read event
 * RWRITE,	Resetting and wants write event
 *
 * NOTE(review): in the visible part of this file only CREAD, CWRITE and
 * WWRITE are ever assigned to SQLConn::status; the other values are only
 * compared against.
 */
enum SQLstatus { CREAD, CWRITE, WREAD, WWRITE, RREAD, RWRITE };

/** Timer armed by SQLConn::DelayReconnect() after a connection is dropped.
 * When it fires, Tick() (defined at the bottom of this file) makes the
 * module re-read its <database> tags.
 * The Timer base is constructed with an interval argument of 5 and a final
 * argument of false (presumably: five seconds, non-repeating -- confirm
 * against the Timer declaration).
 */
class ReconnectTimer : public Timer
{
	/* Module whose configuration is re-read when the timer fires */
	ModulePgSQL* mod;

 public:
	ReconnectTimer(ModulePgSQL* owner)
		: Timer(5, ServerInstance->Time(), false)
		, mod(owner)
	{
	}

	bool Tick(time_t TIME);
};

/** One query together with the object that is told about its outcome. */
struct QueueItem
{
	/* Receiver of OnResult()/OnError(); NULL when no query is attached */
	SQLQuery* c;
	/* Query text as it is handed to PQsendQuery() */
	std::string q;

	QueueItem(SQLQuery* callback, const std::string& text)
		: c(callback)
		, q(text)
	{
	}
};

/** PgSQLresult is the PostgreSQL subclass of the SQLResult interface.
 * Every SQL provider supplies its own subclass and implements the accessors
 * with its database library's data retrieval functions. Values are read
 * from the PGresult on demand instead of being converted to a common format
 * up front, so data reaches the calling module almost as directly as if it
 * had used libpq itself.
 *
 * The object takes ownership of the PGresult and clears it on destruction.
 */
class PgSQLresult : public SQLResult
{
	PGresult* res;		/* libpq result being wrapped */
	int currentrow;		/* Index of the row the next GetRow() call returns */
	int rows;		/* Value reported by Rows() */

 public:
	/** Wrap a libpq result.
	 * The row count is the number of tuples returned; when that is zero the
	 * count libpq reports for the command itself (PQcmdTuples) is used.
	 */
	PgSQLresult(PGresult* result)
		: res(result)
		, currentrow(0)
		, rows(PQntuples(result))
	{
		if (rows == 0)
			rows = atoi(PQcmdTuples(res));
	}

	~PgSQLresult()
	{
		PQclear(res);
	}

	/** @return The row count computed in the constructor */
	int Rows()
	{
		return rows;
	}

	/** Replace the contents of `result` with the column names of this result. */
	void GetCols(std::vector<std::string>& result)
	{
		result.resize(PQnfields(res));

		const unsigned int total = result.size();
		for (unsigned int col = 0; col < total; ++col)
			result[col] = PQfname(res, col);
	}

	/** Fetch one field.
	 * @return A default-constructed SQLEntry when libpq returns no data or
	 * reports the field as NULL, otherwise an entry holding the field's bytes
	 */
	SQLEntry GetValue(int row, int column)
	{
		const char* data = PQgetvalue(res, row, column);
		if (data && !PQgetisnull(res, row, column))
			return SQLEntry(std::string(data, PQgetlength(res, row, column)));

		return SQLEntry();
	}

	/** Append every field of the current row to `result` and advance.
	 * `result` is not cleared first.
	 * @return false (leaving `result` untouched) once all rows were consumed
	 */
	bool GetRow(SQLEntries& result)
	{
		if (currentrow < PQntuples(res))
		{
			const int ncols = PQnfields(res);
			for (int col = 0; col < ncols; ++col)
				result.push_back(GetValue(currentrow, col));

			++currentrow;
			return true;
		}

		return false;
	}
};

/** SQLConn represents one SQL session.
 */
class SQLConn : public SQLProvider, public EventHandler
{
 public:
	reference<ConfigTag> conf;	/* The <database> entry */
	std::deque<QueueItem> queue;
	PGconn* 		sql;		/* PgSQL database connection handle */
	SQLstatus		status;		/* PgSQL database connection status */
	QueueItem		qinprog;	/* If there is currently a query in progress */

	SQLConn(Module* Creator, ConfigTag* tag)
	: SQLProvider(Creator, "SQL/" + tag->getString("id")), conf(tag), sql(NULL), status(CWRITE), qinprog(NULL, "")
	{
		if (!DoConnect())
		{
			ServerInstance->Logs->Log(MODNAME, LOG_DEFAULT, "WARNING: Could not connect to database " + tag->getString("id"));
			DelayReconnect();
		}
	}

	CullResult cull()
	{
		this->SQLProvider::cull();
		ServerInstance->Modules->DelService(*this);
		return this->EventHandler::cull();
	}

	~SQLConn()
	{
		SQLerror err(SQL_BAD_DBID);
		if (qinprog.c)
		{
			qinprog.c->OnError(err);
			delete qinprog.c;
		}
		for(std::deque<QueueItem>::iterator i = queue.begin(); i != queue.end(); i++)
		{
			SQLQuery* q = i->c;
			q->OnError(err);
			delete q;
		}
	}

	void HandleEvent(EventType et, int errornum)
	{
		switch (et)
		{
			case EVENT_READ:
			case EVENT_WRITE:
				DoEvent();
			break;

			case EVENT_ERROR:
				DelayReconnect();
		}
	}

	std::string GetDSN()
	{
		std::ostringstream conninfo("connect_timeout = '5'");
		std::string item;

		if (conf->readString("host", item))
			conninfo << " host = '" << item << "'";

		if (conf->readString("port", item))
			conninfo << " port = '" << item << "'";

		if (conf->readString("name", item))
			conninfo << " dbname = '" << item << "'";

		if (conf->readString("user", item))
			conninfo << " user = '" << item << "'";

		if (conf->readString("pass", item))
			conninfo << " password = '" << item << "'";

		if (conf->getBool("ssl"))
			conninfo << " sslmode = 'require'";
		else
			conninfo << " sslmode = 'disable'";

		return conninfo.str();
	}

	bool DoConnect()
	{
		sql = PQconnectStart(GetDSN().c_str());
		if (!sql)
			return false;

		if(PQstatus(sql) == CONNECTION_BAD)
			return false;

		if(PQsetnonblocking(sql, 1) == -1)
			return false;

		/* OK, we've initalised the connection, now to get it hooked into the socket engine
		* and then start polling it.
		*/
		this->fd = PQsocket(sql);

		if(this->fd <= -1)
			return false;

		if (!ServerInstance->SE->AddFd(this, FD_WANT_NO_WRITE | FD_WANT_NO_READ))
		{
			ServerInstance->Logs->Log(MODNAME, LOG_DEBUG, "BUG: Couldn't add pgsql socket to socket engine");
			return false;
		}

		/* Socket all hooked into the engine, now to tell PgSQL to start connecting */
		return DoPoll();
	}

	bool DoPoll()
	{
		switch(PQconnectPoll(sql))
		{
			case PGRES_POLLING_WRITING:
				ServerInstance->SE->ChangeEventMask(this, FD_WANT_POLL_WRITE | FD_WANT_NO_READ);
				status = CWRITE;
				return true;
			case PGRES_POLLING_READING:
				ServerInstance->SE->ChangeEventMask(this, FD_WANT_POLL_READ | FD_WANT_NO_WRITE);
				status = CREAD;
				return true;
			case PGRES_POLLING_FAILED:
				return false;
			case PGRES_POLLING_OK:
				ServerInstance->SE->ChangeEventMask(this, FD_WANT_POLL_READ | FD_WANT_NO_WRITE);
				status = WWRITE;
				DoConnectedPoll();
			default:
				return true;
		}
	}

	void DoConnectedPoll()
	{
restart:
		while (qinprog.q.empty() && !queue.empty())
		{
			/* There's no query currently in progress, and there's queries in the queue. */
			DoQuery(queue.front());
			queue.pop_front();
		}

		if (PQconsumeInput(sql))
		{
			if (PQisBusy(sql))
			{
				/* Nothing happens here */
			}
			else if (qinprog.c)
			{
				/* Fetch the result.. */
				PGresult* result = PQgetResult(sql);

				/* PgSQL would allow a query string to be sent which has multiple
				 * queries in it, this isn't portable across database backends and
				 * we don't want modules doing it. But just in case we make sure we
				 * drain any results there are and just use the last one.
				 * If the module devs are behaving there will only be one result.
				 */
				while (PGresult* temp = PQgetResult(sql))
				{
					PQclear(result);
					result = temp;
				}

				/* ..and the result */
				PgSQLresult reply(result);
				switch(PQresultStatus(result))
				{
					case PGRES_EMPTY_QUERY:
					case PGRES_BAD_RESPONSE:
					case PGRES_FATAL_ERROR:
					{
						SQLerror err(SQL_QREPLY_FAIL, PQresultErrorMessage(result));
						qinprog.c->OnError(err);
						break;
					}
					default:
						/* Other values are not errors */
						qinprog.c->OnResult(reply);
				}

				delete qinprog.c;
				qinprog = QueueItem(NULL, "");
				goto restart;
			}
			else
			{
				qinprog.q.clear();
			}
		}
		else
		{
			/* I think we'll assume this means the server died...it might not,
			 * but I think that any error serious enough we actually get here
			 * deserves to reconnect [/excuse]
			 * Returning true so the core doesn't try and close the connection.
			 */
			DelayReconnect();
		}
	}

	bool DoResetPoll()
	{
		switch(PQresetPoll(sql))
		{
			case PGRES_POLLING_WRITING:
				ServerInstance->SE->ChangeEventMask(this, FD_WANT_POLL_WRITE | FD_WANT_NO_READ);
				status = CWRITE;
				return DoPoll();
			case PGRES_POLLING_READING:
				ServerInstance->SE->ChangeEventMask(this, FD_WANT_POLL_READ | FD_WANT_NO_WRITE);
				status = CREAD;
				return true;
			case PGRES_POLLING_FAILED:
				return false;
			case PGRES_POLLING_OK:
				ServerInstance->SE->ChangeEventMask(this, FD_WANT_POLL_READ | FD_WANT_NO_WRITE);
				status = WWRITE;
				DoConnectedPoll();
			default:
				return true;
		}
	}

	void DelayReconnect();

	void DoEvent()
	{
		if((status == CREAD) || (status == CWRITE))
		{
			DoPoll();
		}
		else if((status == RREAD) || (status == RWRITE))
		{
			DoResetPoll();
		}
		else
		{
			DoConnectedPoll();
		}
	}

	void submit(SQLQuery *req, const std::string& q)
	{
		if (qinprog.q.empty())
		{
			DoQuery(QueueItem(req,q));
		}
		else
		{
			// wait your turn.
			queue.push_back(QueueItem(req,q));
		}
	}

	void submit(SQLQuery *req, const std::string& q, const ParamL& p)
	{
		std::string res;
		unsigned int param = 0;
		for(std::string::size_type i = 0; i < q.length(); i++)
		{
			if (q[i] != '?')
				res.push_back(q[i]);
			else
			{
				if (param < p.size())
				{
					std::string parm = p[param++];
					std::vector<char> buffer(parm.length() * 2 + 1);
#ifdef PGSQL_HAS_ESCAPECONN
					int error;
					size_t escapedsize = PQescapeStringConn(sql, &buffer[0], parm.data(), parm.length(), &error);
					if (error)
						ServerInstance->Logs->Log(MODNAME, LOG_DEBUG, "BUG: Apparently PQescapeStringConn() failed");
#else
					size_t escapedsize = PQescapeString(&buffer[0], parm.data(), parm.length());
#endif
					res.append(&buffer[0], escapedsize);
				}
			}
		}
		submit(req, res);
	}

	void submit(SQLQuery *req, const std::string& q, const ParamM& p)
	{
		std::string res;
		for(std::string::size_type i = 0; i < q.length(); i++)
		{
			if (q[i] != '$')
				res.push_back(q[i]);
			else
			{
				std::string field;
				i++;
				while (i < q.length() && isalnum(q[i]))
					field.push_back(q[i++]);
				i--;

				ParamM::const_iterator it = p.find(field);
				if (it != p.end())
				{
					std::string parm = it->second;
					std::vector<char> buffer(parm.length() * 2 + 1);
#ifdef PGSQL_HAS_ESCAPECONN
					int error;
					size_t escapedsize = PQescapeStringConn(sql, &buffer[0], parm.data(), parm.length(), &error);
					if (error)
						ServerInstance->Logs->Log(MODNAME, LOG_DEBUG, "BUG: Apparently PQescapeStringConn() failed");
#else
					size_t escapedsize = PQescapeString(&buffer[0], parm.data(), parm.length());
#endif
					res.append(&buffer[0], escapedsize);
				}
			}
		}
		submit(req, res);
	}

	void DoQuery(const QueueItem& req)
	{
		if (status != WREAD && status != WWRITE)
		{
			// whoops, not connected...
			SQLerror err(SQL_BAD_CONN);
			req.c->OnError(err);
			delete req.c;
			return;
		}

		if(PQsendQuery(sql, req.q.c_str()))
		{
			qinprog = req;
		}
		else
		{
			SQLerror err(SQL_QSEND_FAIL, PQerrorMessage(sql));
			req.c->OnError(err);
			delete req.c;
		}
	}

	void Close()
	{
		ServerInstance->SE->DelFd(this);

		if(sql)
		{
			PQfinish(sql);
			sql = NULL;
		}
	}
};

class ModulePgSQL : public Module
{
 public:
	ConnMap connections;
	ReconnectTimer* retimer;

	ModulePgSQL()
		: retimer(NULL)
	{
	}

	~ModulePgSQL()
	{
		delete retimer;
		ClearAllConnections();
	}

	void ReadConfig(ConfigStatus& status) CXX11_OVERRIDE
	{
		ReadConf();
	}

	void ReadConf()
	{
		ConnMap conns;
		ConfigTagList tags = ServerInstance->Config->ConfTags("database");
		for(ConfigIter i = tags.first; i != tags.second; i++)
		{
			if (i->second->getString("module", "pgsql") != "pgsql")
				continue;
			std::string id = i->second->getString("id");
			ConnMap::iterator curr = connections.find(id);
			if (curr == connections.end())
			{
				SQLConn* conn = new SQLConn(this, i->second);
				conns.insert(std::make_pair(id, conn));
				ServerInstance->Modules->AddService(*conn);
			}
			else
			{
				conns.insert(*curr);
				connections.erase(curr);
			}
		}
		ClearAllConnections();
		conns.swap(connections);
	}

	void ClearAllConnections()
	{
		for(ConnMap::iterator i = connections.begin(); i != connections.end(); i++)
		{
			i->second->cull();
			delete i->second;
		}
		connections.clear();
	}

	void OnUnloadModule(Module* mod) CXX11_OVERRIDE
	{
		SQLerror err(SQL_BAD_DBID);
		for(ConnMap::iterator i = connections.begin(); i != connections.end(); i++)
		{
			SQLConn* conn = i->second;
			if (conn->qinprog.c && conn->qinprog.c->creator == mod)
			{
				conn->qinprog.c->OnError(err);
				delete conn->qinprog.c;
				conn->qinprog.c = NULL;
			}
			std::deque<QueueItem>::iterator j = conn->queue.begin();
			while (j != conn->queue.end())
			{
				SQLQuery* q = j->c;
				if (q->creator == mod)
				{
					q->OnError(err);
					delete q;
					j = conn->queue.erase(j);
				}
				else
					j++;
			}
		}
	}

	Version GetVersion() CXX11_OVERRIDE
	{
		return Version("PostgreSQL Service Provider module for all other m_sql* modules, uses v2 of the SQL API", VF_VENDOR);
	}
};

/* Timer callback: make the module re-read its <database> tags, which
 * creates a new SQLConn for every configured id that is missing from
 * ModulePgSQL::connections. The `time` argument is not used.
 */
bool ReconnectTimer::Tick(time_t time)
{
	/* Clear the module's pointer before ReadConf(): SQLConn::DelayReconnect()
	 * only arms a new timer when retimer is NULL, and this object is about
	 * to be destroyed.
	 */
	mod->retimer = NULL;
	mod->ReadConf();
	/* The timer was allocated with new in SQLConn::DelayReconnect() */
	delete this;
	/* presumably tells the timer manager not to run this timer again -- confirm */
	return false;
}

void SQLConn::DelayReconnect()
{
	ModulePgSQL* mod = (ModulePgSQL*)(Module*)creator;
	ConnMap::iterator it = mod->connections.find(conf->getString("id"));
	if (it != mod->connections.end())
	{
		mod->connections.erase(it);
		ServerInstance->GlobalCulls.AddItem((EventHandler*)this);
		if (!mod->retimer)
		{
			mod->retimer = new ReconnectTimer(mod);
			ServerInstance->Timers->AddTimer(mod->retimer);
		}
	}
}

MODULE_INIT(ModulePgSQL)