1 /* +------------------------------------+
2 * | Inspire Internet Relay Chat Daemon |
3 * +------------------------------------+
5 * InspIRCd is copyright (C) 2002-2004 ChatSpike-Dev.
7 * <brain@chatspike.net>
8 * <Craig@chatspike.net>
10 * Written by Craig Edwards, Craig McLure, and others.
11 * This program is free but copyrighted software; see
12 * the file COPYING for details.
14 * ---------------------------------------------------
26 #include "helperfuncs.h"
29 /* VERSION 2 API: With nonblocking (threaded) requests */
31 /* $ModDesc: SQL Service Provider module for all other m_sql* modules */
32 /* $CompileFlags: `mysql_config --include` */
33 /* $LinkerFlags: `mysql_config --libs_r` `perl ../mysql_rpath.pl` */
// ServerInstance: pointer to the global InspIRCd instance; only declared here
// (extern), the definition lives elsewhere.
38 extern InspIRCd* ServerInstance;
// ConnMap: database id string (the "id" value of a <database> config tag, see
// LoadDatabases) -> connection object. OnRequest looks requests up in it by
// SQLrequest::dbid.
39 typedef std::map<std::string, SQLConnection*> ConnMap;
// Compatibility shim: when MYSQL_VERSION_ID is undefined or below 32224, the
// name mysql_field_count is aliased to mysql_num_fields.
// NOTE(review): mysql_num_fields() takes a MYSQL_RES*, but GetRow() calls
// mysql_field_count(&connection) with a MYSQL* -- confirm the shim actually
// compiles on the old client libraries it targets. The matching #endif is not
// visible in this excerpt.
43 #if !defined(MYSQL_VERSION_ID) || MYSQL_VERSION_ID<32224
44 #define mysql_field_count mysql_num_fields
// QueryQueue: the pending SQLrequests for one connection, kept in two FIFO
// deques. push() routes requests with q.pri set to `priority` (the
// non-priority branch is not visible in this excerpt; presumably `normal`).
// `which` records which deque holds the currently active element, so that
// pop() removes that same element.
// No locking is visible inside the class; the callers in this file hold
// queue_mutex around push() and totalsize().
47 class QueryQueue : public classbase
50 typedef std::deque<SQLrequest> ReqDeque;
52 ReqDeque priority; /* The priority queue */
53 ReqDeque normal; /* The 'normal' queue */
54 enum { PRI, NOR, NON } which; /* Which queue the currently active element is at the front of */
// push(): append a copy of q, to `priority` when q.pri is set.
62 void push(const SQLrequest &q)
64 log(DEBUG, "QueryQueue::push(): Adding %s query to queue: %s", ((q.pri) ? "priority" : "non-priority"), q.query.q.c_str());
67 priority.push_back(q);
// pop(): remove the active element from the deque named by `which`, but only
// if that deque is non-empty.
74 if((which == PRI) && priority.size())
78 else if((which == NOR) && normal.size())
86 /* Silently do nothing if there was no element to pop() */
// front(): return the active request, taken from `priority` or `normal`
// depending on `which` and on which deque is non-empty (not all of the branch
// conditions are visible in this excerpt).
94 return priority.front();
96 return normal.front();
101 return priority.front();
107 return normal.front();
110 /* If both deques are empty, front() on an empty std::deque is undefined behaviour (typically a crash),
111 * but the caller should have checked totalsize()
112 * first, so this fallback should never be reached.
115 return priority.front();
// size(): (number of priority requests, number of normal requests).
119 std::pair<int, int> size()
121 return std::make_pair(priority.size(), normal.size());
// totalsize(): number of queued requests across both deques.
126 return priority.size() + normal.size();
// PurgeModule(): drop mod's requests from both deques (caller not visible here).
129 void PurgeModule(Module* mod)
131 DoPurgeModule(mod, priority);
132 DoPurgeModule(mod, normal);
// DoPurgeModule(): for every request in q whose source is mod, either clear
// its source (when it is the active request, i.e. same id as front()) or
// erase it (when it has not been executed yet).
// NOTE(review): bug -- after `iter = q.erase(iter)` the for-loop's `iter++`
// still runs. The element following each erased request is therefore skipped,
// and erasing the last element increments end(), which is undefined
// behaviour. Unless lines not shown compensate, advance only when nothing was
// erased: `for (iter = q.begin(); iter != q.end(); ) { ... iter = q.erase(iter); ... else ++iter; }`.
136 void DoPurgeModule(Module* mod, ReqDeque& q)
138 for(ReqDeque::iterator iter = q.begin(); iter != q.end(); iter++)
140 if(iter->GetSource() == mod)
142 if(iter->id == front().id)
144 /* It's the currently active query.. :x */
// Keep the in-flight request but detach it from the departing module.
145 iter->SetSource(NULL);
149 /* It hasn't been executed yet..just remove it */
150 iter = q.erase(iter);
157 /* A mutex to wrap around queue accesses */
// Statically initialised with PTHREAD_MUTEX_INITIALIZER. ModuleSQL::OnRequest
// holds it around the Connections lookup and QueryQueue::push().
// DispatcherThread holds it while scanning the queues and appears to release
// it before calling SQLConnection::DoLeadingQuery().
158 pthread_mutex_t queue_mutex = PTHREAD_MUTEX_INITIALIZER;
// SQLConnection: one configured MySQL database. It holds:
//   - the credentials read from a <database> tag;
//   - the MYSQL handle `connection`;
//   - the current result set / row used by QueryResult() and GetRow();
//   - `queue`, the pending requests routed to this database by ModuleSQL::OnRequest.
160 class SQLConnection : public classbase
171 std::map<std::string,std::string> thisrow;
179 // This constructor creates an SQLConnection object with the given credentials, and creates the underlying
180 // MYSQL struct, but does not connect yet.
181 SQLConnection(std::string thishost, std::string thisuser, std::string thispass, std::string thisdb, long myid)
183 this->Enabled = true;
184 this->host = thishost;
185 this->user = thisuser;
186 this->pass = thispass;
191 // This method connects to the database using the credentials supplied to the constructor, and returns
192 // true upon success.
// Sets a 1-second MYSQL_OPT_CONNECT_TIMEOUT before connecting.
// mysql_real_connect() returns NULL on failure, which converts to false.
// NOTE(review): mysql_init() runs unconditionally. If this is called again on a
// handle that is still open (e.g. a reconnect), the old connection is not
// mysql_close()d here -- confirm callers close it first.
195 unsigned int timeout = 1;
196 mysql_init(&connection);
197 mysql_options(&connection,MYSQL_OPT_CONNECT_TIMEOUT,(char*)&timeout);
198 return mysql_real_connect(&connection, host.c_str(), user.c_str(), pass.c_str(), db.c_str(), 0, NULL, 0);
// DoLeadingQuery(): build the final SQL text for the request at the head of
// this connection's queue. It copies req.query.q into a freshly allocated
// buffer and replaces each '?' with the next parameter from req.query.p,
// escaped with mysql_real_escape_string(). Parameters are consumed with
// pop_front(). Executing the query is still a TODO (see below).
// NOTE(review): queue.front() is read, and the request's parameter list
// mutated, without holding queue_mutex (the lock is only taken further
// down). std::deque::push_back() from OnRequest keeps references to existing
// elements valid, but a concurrent erase/pop of this element would not --
// confirm nothing can do that while this runs.
201 void DoLeadingQuery()
203 /* Parse the command string and dispatch it to mysql */
204 SQLrequest& req = queue.front();
205 log(DEBUG,"DO QUERY: %s",req.query.q.c_str());
207 /* Pointer to the buffer we build the substituted query in */
210 /* Pointer to the current end of query, where we append new stuff */
213 /* Total length of the unescaped parameters */
// NOTE(review): paramlen is not initialised on this line. It must be set to 0
// before the += loop below (lines not shown may do so -- confirm); otherwise
// the buffer size is computed from an indeterminate value.
214 unsigned long paramlen;
218 for(ParamL::iterator i = req.query.p.begin(); i != req.query.p.end(); i++)
220 paramlen += i->size();
223 /* To avoid a lot of allocations, allocate enough memory for the biggest the escaped query could possibly be.
224 * sizeofquery + (totalparamlength*2) + 1
226 * The +1 is for null-terminating the string for mysql_real_escape_string
// NOTE(review): bug -- this allocation omits the "+ 1" described just above.
// With no '?' placeholders and no parameters (e.g. a plain "SELECT 1"), the
// copied query fills the buffer exactly. That leaves no room for the NUL
// terminator that the "%s" log further down relies on. An empty query even
// allocates zero bytes.
// Should be: new char[req.query.q.length() + (paramlen*2) + 1].
// No delete[] of this buffer is visible in this excerpt either -- confirm it is freed.
229 query = new char[req.query.q.length() + (paramlen*2)];
232 /* Okay, now we have a buffer large enough we need to start copying the query into it and escaping and substituting
233 * the parameters into it...
236 for(unsigned long i = 0; i < req.query.q.length(); i++)
238 if(req.query.q[i] == '?')
240 /* We found a place to substitute..what fun.
241 * use mysql calls to escape and write the
242 * escaped string onto the end of our query buffer,
243 * then we "just" need to make sure queryend is
244 * pointing at the right place.
246 if(req.query.p.size())
// mysql_real_escape_string() writes at most 2*length bytes plus a NUL at
// queryend. It returns the number of bytes written, excluding the NUL;
// `len` is presumably used to advance queryend in lines not shown.
248 unsigned long len = mysql_real_escape_string(&connection, queryend, req.query.p.front().c_str(), req.query.p.front().length());
251 req.query.p.pop_front();
// A '?' with no parameter left is only logged here; whether the '?' is then
// copied or dropped is not visible in this excerpt.
255 log(DEBUG, "Found a substitution location but no parameter to substitute :|");
// Any non-placeholder character is copied through verbatim.
261 *queryend = req.query.q[i];
268 log(DEBUG, "Attempting to dispatch query: %s", query);
// The statements between this lock/unlock pair are not visible in this excerpt.
270 pthread_mutex_lock(&queue_mutex);
272 pthread_mutex_unlock(&queue_mutex);
274 /* TODO: Do the mysql_real_query here */
277 // This method issues a query that expects multiple rows of results. Use GetRow() and QueryDone() to retrieve
// Returns false if CheckConnection() fails. Uses mysql_use_result(), i.e.
// unbuffered retrieval: rows are pulled from the server one at a time by GetRow().
279 bool QueryResult(std::string query)
281 if (!CheckConnection()) return false;
283 int r = mysql_query(&connection, query.c_str());
286 res = mysql_use_result(&connection);
291 // This method issues a query that just expects a number of 'affected' rows (e.g. UPDATE or DELETE FROM).
292 // the number of affected rows is returned in the return value.
// Any result set is stored and freed immediately; the count comes from
// mysql_affected_rows().
// NOTE(review): mysql_affected_rows() returns an unsigned 64-bit my_ulonglong,
// and (my_ulonglong)-1 on error. Storing it in unsigned long and returning
// long can truncate on 32-bit platforms and loses the error sentinel --
// confirm the return statement (not visible) handles this.
293 long QueryCount(std::string query)
295 /* If the connection is down, we return a negative value - New to 1.1 */
296 if (!CheckConnection()) return -1;
298 int r = mysql_query(&connection, query.c_str());
301 res = mysql_store_result(&connection);
302 unsigned long rows = mysql_affected_rows(&connection);
303 mysql_free_result(res);
309 // This method fetches a row, if available from the database. You must issue a query
310 // using QueryResult() first! The row's values are returned as a map of std::string
311 // where each item is keyed by the column name.
// Column names or values that are NULL are turned into empty strings. The
// result set is freed with mysql_free_result() further down (presumably once
// no row remains).
312 std::map<std::string,std::string> GetRow()
317 row = mysql_fetch_row(res);
320 unsigned int field_count = 0;
321 MYSQL_FIELD *fields = mysql_fetch_fields(res);
322 if(mysql_field_count(&connection) == 0)
324 if (fields && mysql_field_count(&connection))
326 while (field_count < mysql_field_count(&connection))
328 std::string a = (fields[field_count].name ? fields[field_count].name : "");
329 std::string b = (row[field_count] ? row[field_count] : "");
344 mysql_free_result(res);
// True when mysql_ping() reports a failure (non-zero), i.e. the server is
// not reachable over this handle.
351 bool ConnectionLost()
354 return (mysql_ping(&connection) != 0);
// Checks ConnectionLost(); the recovery path taken when the connection is
// lost is not visible in this excerpt (presumably a reconnect via Connect()).
359 bool CheckConnection()
361 if (ConnectionLost()) {
// Text of the most recent MySQL error on this handle.
367 std::string GetError()
369 return mysql_error(&connection);
// Body not visible in this excerpt; presumably returns the configured `host`.
377 std::string GetHost()
// ConnectDatabases(): try to connect every entry in Connections and log the
// outcome. A connection that fails is logged together with its MySQL error
// text and then disabled via Disable(). The statements between the loop
// header and the Connect() call are not visible in this excerpt.
401 void ConnectDatabases(Server* Srv)
403 for (ConnMap::iterator i = Connections.begin(); i != Connections.end(); i++)
406 if (i->second->Connect())
408 Srv->Log(DEFAULT,"SQL: Successfully connected database "+i->second->GetHost());
412 Srv->Log(DEFAULT,"SQL: Failed to connect database "+i->second->GetHost()+": Error: "+i->second->GetError());
413 i->second->Disable();
// LoadDatabases(): read every <database> tag from ThisConf and create one
// SQLConnection per usable tag, stored in Connections under its "id" value.
// Afterwards ConnectDatabases() connects them all. A tag is used only if its
// name, username, password, hostname and id are all non-empty; other tags
// are skipped silently.
// NOTE(review): requiring a non-empty password means a database account with
// an empty password cannot be configured -- confirm this is intended.
// NOTE(review): `Connections[id] = ThisSQL` overwrites an existing entry for a
// duplicate id without deleting the previous SQLConnection, which leaks it.
// The "Cleared connections" step itself is not visible here; if it is a plain
// Connections.clear(), previously allocated connections leak the same way.
419 void LoadDatabases(ConfigReader* ThisConf, Server* Srv)
421 Srv->Log(DEFAULT,"SQL: Loading database settings");
423 Srv->Log(DEBUG,"Cleared connections");
424 for (int j =0; j < ThisConf->Enumerate("database"); j++)
426 std::string db = ThisConf->ReadValue("database","name",j);
427 std::string user = ThisConf->ReadValue("database","username",j);
428 std::string pass = ThisConf->ReadValue("database","password",j);
429 std::string host = ThisConf->ReadValue("database","hostname",j);
430 std::string id = ThisConf->ReadValue("database","id",j);
431 Srv->Log(DEBUG,"Read database settings");
432 if ((db != "") && (host != "") && (user != "") && (id != "") && (pass != ""))
// The map key is the id string; the constructor receives it converted with atoi().
434 SQLConnection* ThisSQL = new SQLConnection(host,user,pass,db,atoi(id.c_str()));
435 Srv->Log(DEFAULT,"Loaded database: "+ThisSQL->GetHost());
436 Connections[id] = ThisSQL;
437 Srv->Log(DEBUG,"Pushed back connection");
440 ConnectDatabases(Srv);
// Entry point of the dispatcher thread started by ModuleSQL's constructor.
// It is declared here because ModuleSQL refers to it before its definition below.
443 void* DispatcherThread(void* arg);
// ModuleSQL: the SQL service-provider module. Other modules submit SQLrequest
// objects through OnRequest(), which queues them on the target connection.
// The detached DispatcherThread does the per-connection work.
445 class ModuleSQL : public Module
450 pthread_t Dispatcher;
// Subscribes to the OnRehash and OnRequest events.
453 void Implements(char* List)
455 List[I_OnRehash] = List[I_OnRequest] = 1;
// Body not visible in this excerpt; presumably hands out request ids.
458 unsigned long NewID()
// OnRequest(): entry point for other modules. For a request whose data string
// equals SQLREQID, it looks up the target connection by req->dbid under
// queue_mutex. If found, a copy of the request is queued and SQLSUCCESS is
// returned. Otherwise the request gets a BAD_DBID error, and returnval stays
// NULL unless lines not shown change it. Requests with any other data string
// are logged as an unsupported API version.
// NOTE(review): "%d" is passed req->query.p.size(), a container size
// (presumably size_t); on LP64 platforms that is a format/argument mismatch.
// Use e.g. "%lu" with an (unsigned long) cast.
465 char* OnRequest(Request* request)
467 if(strcmp(SQLREQID, request->GetData()) == 0)
469 SQLrequest* req = (SQLrequest*)request;
472 pthread_mutex_lock(&queue_mutex);
474 ConnMap::iterator iter;
476 char* returnval = NULL;
478 log(DEBUG, "Got query: '%s' with %d replacement parameters on id '%s'", req->query.q.c_str(), req->query.p.size(), req->dbid.c_str());
480 if((iter = Connections.find(req->dbid)) != Connections.end())
482 iter->second->queue.push(*req);
484 returnval = SQLSUCCESS;
488 req->error.Id(BAD_DBID);
491 pthread_mutex_unlock(&queue_mutex);
497 log(DEBUG, "Got unsupported API version string: %s", request->GetData());
// Constructor: creates the ConfigReader, starts the dispatcher thread in the
// detached state (passing `this` as its argument), and publishes the "SQL"
// and "MySQL" features. Throws ModuleException if the thread cannot be
// created.
// NOTE(review): pthread_create() reports failure through its return value,
// not errno, so strerror(errno) below may describe an unrelated error. Keep
// the return code and pass that to strerror() instead.
// pthread_attr_destroy(&attribs) is also not visible in this excerpt.
502 ModuleSQL(Server* Me)
506 Conf = new ConfigReader();
508 pthread_attr_t attribs;
509 pthread_attr_init(&attribs);
510 pthread_attr_setdetachstate(&attribs, PTHREAD_CREATE_DETACHED);
511 if (pthread_create(&this->Dispatcher, &attribs, DispatcherThread, (void *)this) != 0)
513 throw ModuleException("m_mysql: Failed to create dispatcher thread: " + std::string(strerror(errno)));
515 Srv->PublishFeature("SQL", this);
516 Srv->PublishFeature("MySQL", this);
// Rehash handling is not implemented yet (see the TODO).
// NOTE(review): "¶meter" in the signature below is a mis-decoded HTML
// entity ("&para" rendered as the pilcrow character). The intended text is
// "&parameter"; as written, the signature does not compile.
524 virtual void OnRehash(const std::string ¶meter)
526 /* TODO: set rehash bool here, which makes the dispatcher thread rehash at next opportunity */
// Reports version 1.1.0.0, flagged VF_VENDOR | VF_SERVICEPROVIDER.
529 virtual Version GetVersion()
531 return Version(1,1,0,0,VF_VENDOR|VF_SERVICEPROVIDER);
// DispatcherThread(): body of the detached worker thread. arg is the
// ModuleSQL instance passed to pthread_create(). The thread:
//   1. loads and connects all configured databases;
//   2. scans Connections under queue_mutex for a connection with queued
//      requests (the enclosing loop header is not visible in this excerpt);
//   3. releases the lock and runs that connection's leading query;
//   4. takes the lock again briefly afterwards.
// NOTE(review): LoadDatabases() fills the global Connections map from this
// thread with no lock visible around it. Meanwhile OnRequest() on the main
// thread may already be calling Connections.find() under queue_mutex.
// Confirm this startup race cannot occur, or hold queue_mutex while loading.
536 void* DispatcherThread(void* arg)
538 ModuleSQL* thismodule = (ModuleSQL*)arg;
539 LoadDatabases(thismodule->Conf, thismodule->Srv);
543 SQLConnection* conn = NULL;
544 /* XXX: Lock here for safety */
545 pthread_mutex_lock(&queue_mutex);
546 for (ConnMap::iterator i = Connections.begin(); i != Connections.end(); i++)
// A connection with at least one queued request is presumably selected into conn here.
548 if (i->second->queue.totalsize())
554 pthread_mutex_unlock(&queue_mutex);
557 /* There's an item! */
// Runs without queue_mutex held (it was released above).
560 conn->DoLeadingQuery();
// The statements inside this lock/unlock pair are not visible in this excerpt.
563 pthread_mutex_lock(&queue_mutex);
565 pthread_mutex_unlock(&queue_mutex);
576 // Module-factory boilerplate follows: init_module() returns a ModuleSQLFactory, whose CreateModule() constructs the ModuleSQL instance. Basic modules can ignore this section.
// ModuleSQLFactory: factory object handed out by init_module().
578 class ModuleSQLFactory : public ModuleFactory
// Creates a new ModuleSQL for the given Server. The object is heap-allocated
// and returned to the caller.
589 virtual Module * CreateModule(Server* Me)
591 return new ModuleSQL(Me);
// init_module(): C-linkage entry point of this module. It returns a
// heap-allocated ModuleSQLFactory, presumably looked up by name by the module
// loader -- confirm.
597 extern "C" void * init_module( void )
599 return new ModuleSQLFactory;