typedef std::map<std::string, SQLConnection*> ConnMap;
typedef std::deque<SQLresult*> ResultQueue;
-unsigned long count(const char * const str, char a)
+static unsigned long count(const char * const str, char a)
{
unsigned long n = 0;
- for (const char *p = reinterpret_cast<const char *>(str); *p; ++p)
+ for (const char *p = str; *p; ++p)
{
if (*p == '?')
++n;
class ModuleSQL : public Module
{
public:
-
- ConfigReader *Conf;
- InspIRCd* PublicServerInstance;
int currid;
bool rehashing;
DispatcherThread* Dispatcher;
Mutex ResultsMutex;
Mutex LoggingMutex;
Mutex ConnMutex;
+ ServiceProvider sqlserv;
- ModuleSQL(InspIRCd* Me);
+ ModuleSQL();
~ModuleSQL();
unsigned long NewID();
- const char* OnRequest(Request* request);
- void OnRehash(User* user, const std::string ¶meter);
+ void OnRequest(Request& request);
+ void OnRehash(User* user);
Version GetVersion();
};
std::map<std::string,std::string> thisrow;
bool Enabled;
ModuleSQL* Parent;
+ std::string initquery;
public:
if (!CheckConnection())
return;
+ if( !initquery.empty() )
+ mysql_query(connection,initquery.c_str());
+
/* Parse the command string and dispatch it to mysql */
- SQLrequest& req = queue.front();
+ SQLrequest* req = queue.front();
/* Pointer to the buffer we screw around with substitution in */
char* query;
/* The length of the longest parameter */
maxparamlen = 0;
- for(ParamL::iterator i = req.query.p.begin(); i != req.query.p.end(); i++)
+ for(ParamL::iterator i = req->query.p.begin(); i != req->query.p.end(); i++)
{
if (i->size() > maxparamlen)
maxparamlen = i->size();
}
/* How many params are there in the query? */
- paramcount = count(req.query.q.c_str(), '?');
+ paramcount = count(req->query.q.c_str(), '?');
/* This stores copy of params to be inserted with using numbered params 1;3B*/
- ParamL paramscopy(req.query.p);
+ ParamL paramscopy(req->query.p);
/* To avoid a lot of allocations, allocate enough memory for the biggest the escaped query could possibly be.
* sizeofquery + (maxtotalparamlength*2) + 1
* The +1 is for null-terminating the string for mysql_real_escape_string
*/
- query = new char[req.query.q.length() + (maxparamlen*paramcount*2) + 1];
+ query = new char[req->query.q.length() + (maxparamlen*paramcount*2) + 1];
queryend = query;
/* Okay, now we have a buffer large enough we need to start copying the query into it and escaping and substituting
* the parameters into it...
*/
- for(unsigned long i = 0; i < req.query.q.length(); i++)
+ for(unsigned long i = 0; i < req->query.q.length(); i++)
{
- if(req.query.q[i] == '?')
+ if(req->query.q[i] == '?')
{
/* We found a place to substitute..what fun.
* use mysql calls to escape and write the
/* Let's check if it's a numbered param. And also calculate it's number.
*/
- while ((i < req.query.q.length() - 1) && (req.query.q[i+1] >= '0') && (req.query.q[i+1] <= '9'))
+ while ((i < req->query.q.length() - 1) && (req->query.q[i+1] >= '0') && (req->query.q[i+1] <= '9'))
{
numbered = true;
++i;
- paramnum = paramnum * 10 + req.query.q[i] - '0';
+ paramnum = paramnum * 10 + req->query.q[i] - '0';
}
if (paramnum > paramscopy.size() - 1)
queryend += len;
}
- else if (req.query.p.size())
+ else if (req->query.p.size())
{
- unsigned long len = mysql_real_escape_string(connection, queryend, req.query.p.front().c_str(), req.query.p.front().length());
+ unsigned long len = mysql_real_escape_string(connection, queryend, req->query.p.front().c_str(), req->query.p.front().length());
queryend += len;
- req.query.p.pop_front();
+ req->query.p.pop_front();
}
else
break;
}
else
{
- *queryend = req.query.q[i];
+ *queryend = req->query.q[i];
queryend++;
}
}
*queryend = 0;
- req.query.q = query;
+ req->query.q = query;
- if (!mysql_real_query(connection, req.query.q.data(), req.query.q.length()))
+ if (!mysql_real_query(connection, req->query.q.data(), req->query.q.length()))
{
/* Successfull query */
res = mysql_use_result(connection);
unsigned long rows = mysql_affected_rows(connection);
- MySQLresult* r = new MySQLresult(Parent, req.GetSource(), res, rows, req.id);
+ MySQLresult* r = new MySQLresult(Parent, req->source, res, rows, req->id);
r->dbid = this->GetID();
- r->query = req.query.q;
+ r->query = req->query.q;
/* Put this new result onto the results queue.
* XXX: Remember to mutex the queue!
*/
/* XXX: See /usr/include/mysql/mysqld_error.h for a list of
* possible error numbers and error messages */
SQLerror e(SQL_QREPLY_FAIL, ConvToStr(mysql_errno(connection)) + std::string(": ") + mysql_error(connection));
- MySQLresult* r = new MySQLresult(Parent, req.GetSource(), e, req.id);
+ MySQLresult* r = new MySQLresult(Parent, req->source, e, req->id);
r->dbid = this->GetID();
- r->query = req.query.q;
+ r->query = req->query.q;
Parent->ResultsMutex.Lock();
rq.push_back(r);
return host.host;
}
+ void setInitialQuery(std::string init)
+ {
+ initquery = init;
+ }
+
void SetEnable(bool Enable)
{
Enabled = Enable;
return false;
}
-bool HostInConf(ConfigReader* conf, const SQLhost &h)
+bool HostInConf(const SQLhost &h)
{
- for(int i = 0; i < conf->Enumerate("database"); i++)
+ ConfigReader conf;
+ for(int i = 0; i < conf.Enumerate("database"); i++)
{
SQLhost host;
- host.id = conf->ReadValue("database", "id", i);
- host.host = conf->ReadValue("database", "hostname", i);
- host.port = conf->ReadInteger("database", "port", i, true);
- host.name = conf->ReadValue("database", "name", i);
- host.user = conf->ReadValue("database", "username", i);
- host.pass = conf->ReadValue("database", "password", i);
- host.ssl = conf->ReadFlag("database", "ssl", i);
+ host.id = conf.ReadValue("database", "id", i);
+ host.host = conf.ReadValue("database", "hostname", i);
+ host.port = conf.ReadInteger("database", "port", i, true);
+ host.name = conf.ReadValue("database", "name", i);
+ host.user = conf.ReadValue("database", "username", i);
+ host.pass = conf.ReadValue("database", "password", i);
+ host.ssl = conf.ReadFlag("database", "ssl", i);
if (h == host)
return true;
}
return false;
}
-void ClearOldConnections(ConfigReader* conf)
+void ClearOldConnections()
{
ConnMap::iterator i,safei;
for (i = Connections.begin(); i != Connections.end(); i++)
{
- if (!HostInConf(conf, i->second->GetConfHost()))
+ if (!HostInConf(i->second->GetConfHost()))
{
delete i->second;
safei = i;
}
}
-void ConnectDatabases(InspIRCd* ServerInstance, ModuleSQL* Parent)
+void ConnectDatabases(ModuleSQL* Parent)
{
for (ConnMap::iterator i = Connections.begin(); i != Connections.end(); i++)
{
}
}
-void LoadDatabases(ConfigReader* conf, InspIRCd* ServerInstance, ModuleSQL* Parent)
+void LoadDatabases(ModuleSQL* Parent)
{
+ ConfigReader conf;
Parent->ConnMutex.Lock();
- ClearOldConnections(conf);
- for (int j =0; j < conf->Enumerate("database"); j++)
+ ClearOldConnections();
+ for (int j =0; j < conf.Enumerate("database"); j++)
{
SQLhost host;
- host.id = conf->ReadValue("database", "id", j);
- host.host = conf->ReadValue("database", "hostname", j);
- host.port = conf->ReadInteger("database", "port", j, true);
- host.name = conf->ReadValue("database", "name", j);
- host.user = conf->ReadValue("database", "username", j);
- host.pass = conf->ReadValue("database", "password", j);
- host.ssl = conf->ReadFlag("database", "ssl", j);
+ host.id = conf.ReadValue("database", "id", j);
+ host.host = conf.ReadValue("database", "hostname", j);
+ host.port = conf.ReadInteger("database", "port", j, true);
+ host.name = conf.ReadValue("database", "name", j);
+ host.user = conf.ReadValue("database", "username", j);
+ host.pass = conf.ReadValue("database", "password", j);
+ host.ssl = conf.ReadFlag("database", "ssl", j);
+ std::string initquery = conf.ReadValue("database", "initialquery", j);
if (HasHost(host))
continue;
{
SQLConnection* ThisSQL = new SQLConnection(host, Parent);
Connections[host.id] = ThisSQL;
+
+ ThisSQL->setInitialQuery(initquery);
}
}
- ConnectDatabases(ServerInstance, Parent);
+ ConnectDatabases(Parent);
Parent->ConnMutex.Unlock();
}
return Connections.end();
}
-class ModuleSQL;
-
class DispatcherThread : public SocketThread
{
private:
- ModuleSQL* Parent;
- InspIRCd* ServerInstance;
+ ModuleSQL* const Parent;
public:
- DispatcherThread(InspIRCd* Instance, ModuleSQL* CreatorModule) : SocketThread(Instance), Parent(CreatorModule), ServerInstance(Instance) { }
+ DispatcherThread(ModuleSQL* CreatorModule) : Parent(CreatorModule) { }
~DispatcherThread() { }
virtual void Run();
virtual void OnNotify();
};
-ModuleSQL::ModuleSQL(InspIRCd* Me) : Module(Me), rehashing(false)
+ModuleSQL::ModuleSQL() : rehashing(false), sqlserv(this, "SQL/mysql", SERVICE_DATA)
{
- ServerInstance->Modules->UseInterface("SQLutils");
-
- Conf = new ConfigReader(ServerInstance);
- PublicServerInstance = ServerInstance;
currid = 0;
- Dispatcher = new DispatcherThread(ServerInstance, this);
+ Dispatcher = new DispatcherThread(this);
ServerInstance->Threads->Start(Dispatcher);
- if (!ServerInstance->Modules->PublishFeature("SQL", this))
- {
- /* Tell worker thread to exit NOW,
- * Automatically joins */
- delete Dispatcher;
- ServerInstance->Modules->DoneWithInterface("SQLutils");
- throw ModuleException("m_mysql: Unable to publish feature 'SQL'");
- }
-
- ServerInstance->Modules->PublishInterface("SQL", this);
- Implementation eventlist[] = { I_OnRehash, I_OnRequest };
- ServerInstance->Modules->Attach(eventlist, this, 2);
+ Implementation eventlist[] = { I_OnRehash };
+ ServerInstance->Modules->Attach(eventlist, this, 1);
}
ModuleSQL::~ModuleSQL()
{
delete Dispatcher;
ClearAllConnections();
- delete Conf;
- ServerInstance->Modules->UnpublishInterface("SQL", this);
- ServerInstance->Modules->UnpublishFeature("SQL");
- ServerInstance->Modules->DoneWithInterface("SQLutils");
}
unsigned long ModuleSQL::NewID()
return ++currid;
}
-const char* ModuleSQL::OnRequest(Request* request)
+void ModuleSQL::OnRequest(Request& request)
{
- if(strcmp(SQLREQID, request->GetId()) == 0)
+ if(strcmp(SQLREQID, request.id) == 0)
{
- SQLrequest* req = (SQLrequest*)request;
+ SQLrequest* req = (SQLrequest*)&request;
ConnMap::iterator iter;
- const char* returnval = NULL;
-
Dispatcher->LockQueue();
ConnMutex.Lock();
if((iter = Connections.find(req->dbid)) != Connections.end())
{
req->id = NewID();
- iter->second->queue.push(*req);
- returnval = SQLSUCCESS;
+ iter->second->queue.push(new SQLrequest(*req));
}
else
{
/* Yes, it's possible this will generate a spurious wakeup.
* That's fine, it'll just get ignored.
*/
-
- return returnval;
}
-
- return NULL;
}
-void ModuleSQL::OnRehash(User* user, const std::string ¶meter)
+void ModuleSQL::OnRehash(User* user)
{
Dispatcher->LockQueue();
rehashing = true;
Version ModuleSQL::GetVersion()
{
- return Version("$Id$", VF_VENDOR | VF_SERVICEPROVIDER, API_VERSION);
+ return Version("SQL Service Provider module for all other m_sql* modules", VF_VENDOR);
}
void DispatcherThread::Run()
{
- LoadDatabases(Parent->Conf, Parent->PublicServerInstance, Parent);
+ LoadDatabases(Parent);
SQLConnection* conn = NULL;
if (Parent->rehashing)
{
Parent->rehashing = false;
- LoadDatabases(Parent->Conf, Parent->PublicServerInstance, Parent);
+ LoadDatabases(Parent);
}
+ conn = NULL;
Parent->ConnMutex.Lock();
for (ConnMap::iterator i = Connections.begin(); i != Connections.end(); i++)
{
void DispatcherThread::OnNotify()
{
+ SQLConnection* conn;
while (1)
{
- SQLConnection* conn = NULL;
+ conn = NULL;
Parent->ConnMutex.Lock();
for (ConnMap::iterator iter = Connections.begin(); iter != Connections.end(); iter++)
{