public:
ConfigReader *Conf;
- InspIRCd* PublicServerInstance;
int currid;
bool rehashing;
DispatcherThread* Dispatcher;
Mutex LoggingMutex;
Mutex ConnMutex;
- ModuleSQL(InspIRCd* Me);
+ ModuleSQL();
~ModuleSQL();
unsigned long NewID();
- const char* OnRequest(Request* request);
+ void OnRequest(Request& request);
void OnRehash(User* user);
Version GetVersion();
};
std::map<std::string,std::string> thisrow;
bool Enabled;
ModuleSQL* Parent;
+ std::string initquery;
public:
if (!CheckConnection())
return;
+ if( !initquery.empty() )
+ mysql_query(connection,initquery.c_str());
+
/* Parse the command string and dispatch it to mysql */
- SQLrequest& req = queue.front();
+ SQLrequest* req = queue.front();
/* Pointer to the buffer we screw around with substitution in */
char* query;
/* The length of the longest parameter */
maxparamlen = 0;
- for(ParamL::iterator i = req.query.p.begin(); i != req.query.p.end(); i++)
+ for(ParamL::iterator i = req->query.p.begin(); i != req->query.p.end(); i++)
{
if (i->size() > maxparamlen)
maxparamlen = i->size();
}
/* How many params are there in the query? */
- paramcount = count(req.query.q.c_str(), '?');
+ paramcount = count(req->query.q.c_str(), '?');
/* This stores copy of params to be inserted with using numbered params 1;3B*/
- ParamL paramscopy(req.query.p);
+ ParamL paramscopy(req->query.p);
/* To avoid a lot of allocations, allocate enough memory for the biggest the escaped query could possibly be.
* sizeofquery + (maxtotalparamlength*2) + 1
* The +1 is for null-terminating the string for mysql_real_escape_string
*/
- query = new char[req.query.q.length() + (maxparamlen*paramcount*2) + 1];
+ query = new char[req->query.q.length() + (maxparamlen*paramcount*2) + 1];
queryend = query;
/* Okay, now we have a buffer large enough we need to start copying the query into it and escaping and substituting
* the parameters into it...
*/
- for(unsigned long i = 0; i < req.query.q.length(); i++)
+ for(unsigned long i = 0; i < req->query.q.length(); i++)
{
- if(req.query.q[i] == '?')
+ if(req->query.q[i] == '?')
{
/* We found a place to substitute..what fun.
* use mysql calls to escape and write the
/* Let's check if it's a numbered param. And also calculate it's number.
*/
- while ((i < req.query.q.length() - 1) && (req.query.q[i+1] >= '0') && (req.query.q[i+1] <= '9'))
+ while ((i < req->query.q.length() - 1) && (req->query.q[i+1] >= '0') && (req->query.q[i+1] <= '9'))
{
numbered = true;
++i;
- paramnum = paramnum * 10 + req.query.q[i] - '0';
+ paramnum = paramnum * 10 + req->query.q[i] - '0';
}
if (paramnum > paramscopy.size() - 1)
queryend += len;
}
- else if (req.query.p.size())
+ else if (req->query.p.size())
{
- unsigned long len = mysql_real_escape_string(connection, queryend, req.query.p.front().c_str(), req.query.p.front().length());
+ unsigned long len = mysql_real_escape_string(connection, queryend, req->query.p.front().c_str(), req->query.p.front().length());
queryend += len;
- req.query.p.pop_front();
+ req->query.p.pop_front();
}
else
break;
}
else
{
- *queryend = req.query.q[i];
+ *queryend = req->query.q[i];
queryend++;
}
}
*queryend = 0;
- req.query.q = query;
+ req->query.q = query;
- if (!mysql_real_query(connection, req.query.q.data(), req.query.q.length()))
+ if (!mysql_real_query(connection, req->query.q.data(), req->query.q.length()))
{
/* Successfull query */
res = mysql_use_result(connection);
unsigned long rows = mysql_affected_rows(connection);
- MySQLresult* r = new MySQLresult(Parent, req.GetSource(), res, rows, req.id);
+ MySQLresult* r = new MySQLresult(Parent, req->source, res, rows, req->id);
r->dbid = this->GetID();
- r->query = req.query.q;
+ r->query = req->query.q;
/* Put this new result onto the results queue.
* XXX: Remember to mutex the queue!
*/
/* XXX: See /usr/include/mysql/mysqld_error.h for a list of
* possible error numbers and error messages */
SQLerror e(SQL_QREPLY_FAIL, ConvToStr(mysql_errno(connection)) + std::string(": ") + mysql_error(connection));
- MySQLresult* r = new MySQLresult(Parent, req.GetSource(), e, req.id);
+ MySQLresult* r = new MySQLresult(Parent, req->source, e, req->id);
r->dbid = this->GetID();
- r->query = req.query.q;
+ r->query = req->query.q;
Parent->ResultsMutex.Lock();
rq.push_back(r);
return host.host;
}
+ void setInitialQuery(std::string init)
+ {
+ initquery = init;
+ }
+
void SetEnable(bool Enable)
{
Enabled = Enable;
}
}
-void ConnectDatabases(InspIRCd* ServerInstance, ModuleSQL* Parent)
+void ConnectDatabases(ModuleSQL* Parent)
{
for (ConnMap::iterator i = Connections.begin(); i != Connections.end(); i++)
{
}
}
-void LoadDatabases(ConfigReader* conf, InspIRCd* ServerInstance, ModuleSQL* Parent)
+void LoadDatabases(ConfigReader* conf, ModuleSQL* Parent)
{
Parent->ConnMutex.Lock();
ClearOldConnections(conf);
host.user = conf->ReadValue("database", "username", j);
host.pass = conf->ReadValue("database", "password", j);
host.ssl = conf->ReadFlag("database", "ssl", j);
+ std::string initquery = conf->ReadValue("database", "initialquery", j);
if (HasHost(host))
continue;
{
SQLConnection* ThisSQL = new SQLConnection(host, Parent);
Connections[host.id] = ThisSQL;
+
+ ThisSQL->setInitialQuery(initquery);
}
}
- ConnectDatabases(ServerInstance, Parent);
+ ConnectDatabases(Parent);
Parent->ConnMutex.Unlock();
}
class DispatcherThread : public SocketThread
{
private:
- ModuleSQL* Parent;
- InspIRCd* ServerInstance;
+ ModuleSQL* const Parent;
public:
- DispatcherThread(InspIRCd* Instance, ModuleSQL* CreatorModule) : SocketThread(Instance), Parent(CreatorModule), ServerInstance(Instance) { }
+ DispatcherThread(ModuleSQL* CreatorModule) : Parent(CreatorModule) { }
~DispatcherThread() { }
virtual void Run();
virtual void OnNotify();
};
-ModuleSQL::ModuleSQL(InspIRCd* Me) : Module(Me), rehashing(false)
+ModuleSQL::ModuleSQL() : rehashing(false)
{
ServerInstance->Modules->UseInterface("SQLutils");
- Conf = new ConfigReader(ServerInstance);
- PublicServerInstance = ServerInstance;
+ Conf = new ConfigReader;
currid = 0;
- Dispatcher = new DispatcherThread(ServerInstance, this);
+ Dispatcher = new DispatcherThread(this);
ServerInstance->Threads->Start(Dispatcher);
if (!ServerInstance->Modules->PublishFeature("SQL", this))
}
ServerInstance->Modules->PublishInterface("SQL", this);
- Implementation eventlist[] = { I_OnRehash, I_OnRequest };
- ServerInstance->Modules->Attach(eventlist, this, 2);
+ Implementation eventlist[] = { I_OnRehash };
+ ServerInstance->Modules->Attach(eventlist, this, 1);
}
ModuleSQL::~ModuleSQL()
return ++currid;
}
-const char* ModuleSQL::OnRequest(Request* request)
+void ModuleSQL::OnRequest(Request& request)
{
- if(strcmp(SQLREQID, request->GetId()) == 0)
+ if(strcmp(SQLREQID, request.id) == 0)
{
- SQLrequest* req = (SQLrequest*)request;
+ SQLrequest* req = (SQLrequest*)&request;
ConnMap::iterator iter;
- const char* returnval = NULL;
-
Dispatcher->LockQueue();
ConnMutex.Lock();
if((iter = Connections.find(req->dbid)) != Connections.end())
{
req->id = NewID();
- iter->second->queue.push(*req);
- returnval = SQLSUCCESS;
+ iter->second->queue.push(new SQLrequest(*req));
}
else
{
/* Yes, it's possible this will generate a spurious wakeup.
* That's fine, it'll just get ignored.
*/
-
- return returnval;
}
-
- return NULL;
}
void ModuleSQL::OnRehash(User* user)
Version ModuleSQL::GetVersion()
{
- return Version("$Id$", VF_VENDOR | VF_SERVICEPROVIDER, API_VERSION);
+ return Version("SQL Service Provider module for all other m_sql* modules", VF_VENDOR | VF_SERVICEPROVIDER);
}
void DispatcherThread::Run()
{
- LoadDatabases(Parent->Conf, Parent->PublicServerInstance, Parent);
+ LoadDatabases(Parent->Conf, Parent);
SQLConnection* conn = NULL;
if (Parent->rehashing)
{
Parent->rehashing = false;
- LoadDatabases(Parent->Conf, Parent->PublicServerInstance, Parent);
+ LoadDatabases(Parent->Conf, Parent);
}
conn = NULL;