diff options
-rw-r--r-- | src/modules/extra/m_mysql.cpp | 95 |
1 files changed, 81 insertions, 14 deletions
diff --git a/src/modules/extra/m_mysql.cpp b/src/modules/extra/m_mysql.cpp index ea49de5ac..af9d55f6f 100644 --- a/src/modules/extra/m_mysql.cpp +++ b/src/modules/extra/m_mysql.cpp @@ -403,7 +403,7 @@ class SQLConnection : public classbase ResultQueue rq; // This constructor creates an SQLConnection object with the given credentials, but does not connect yet. - SQLConnection(const SQLhost &hi) : host(hi), Enabled(true) + SQLConnection(const SQLhost &hi) : host(hi), Enabled(false) { } @@ -580,27 +580,59 @@ class SQLConnection : public classbase mysql_close(&connection); } + const SQLhost& GetConfHost() + { + return host; + } + }; ConnMap Connections; -void ConnectDatabases(InspIRCd* ServerInstance) +bool HasHost(const SQLhost &host) { - for (ConnMap::iterator i = Connections.begin(); i != Connections.end(); i++) + for (ConnMap::iterator iter = Connections.begin(); iter != Connections.end(); iter++) { - i->second->SetEnable(true); - if (!i->second->Connect()) + if (host == iter->second->GetConfHost()) + return true; + } + return false; +} + +bool HostInConf(ConfigReader* conf, const SQLhost &h) +{ + for(int i = 0; i < conf->Enumerate("database"); i++) + { + SQLhost host; + host.id = conf->ReadValue("database", "id", i); + host.host = conf->ReadValue("database", "hostname", i); + host.port = conf->ReadInteger("database", "port", i, true); + host.name = conf->ReadValue("database", "name", i); + host.user = conf->ReadValue("database", "username", i); + host.pass = conf->ReadValue("database", "password", i); + host.ssl = conf->ReadFlag("database", "ssl", i); + if (h == host) + return true; + } + return false; +} + +void ClearOldConnections(ConfigReader* conf) +{ + ConnMap::iterator i,safei; + for (i = Connections.begin(); i != Connections.end(); i++) + { + if (!HostInConf(conf, i->second->GetConfHost())) { - /* XXX: MUTEX */ - pthread_mutex_lock(&logging_mutex); - ServerInstance->Log(DEFAULT,"SQL: Failed to connect database "+i->second->GetHost()+": Error: "+i->second->GetError()); - i->second->SetEnable(false); - pthread_mutex_unlock(&logging_mutex); + DELETE(i->second); + safei = i; + --i; + Connections.erase(safei); } } } -void ClearDatabases() +void ClearAllConnections() { ConnMap::iterator i; while ((i = Connections.begin()) != Connections.end()) @@ -610,9 +642,28 @@ void ClearDatabases() } } +void ConnectDatabases(InspIRCd* ServerInstance) +{ + for (ConnMap::iterator i = Connections.begin(); i != Connections.end(); i++) + { + if (i->second->IsEnabled()) + continue; + + i->second->SetEnable(true); + if (!i->second->Connect()) + { + /* XXX: MUTEX */ + pthread_mutex_lock(&logging_mutex); + ServerInstance->Log(DEFAULT,"SQL: Failed to connect database "+i->second->GetHost()+": Error: "+i->second->GetError()); + i->second->SetEnable(false); + pthread_mutex_unlock(&logging_mutex); + } + } +} + void LoadDatabases(ConfigReader* conf, InspIRCd* ServerInstance) { - ClearDatabases(); + ClearOldConnections(conf); for (int j =0; j < conf->Enumerate("database"); j++) { SQLhost host; @@ -622,6 +673,10 @@ void LoadDatabases(ConfigReader* conf, InspIRCd* ServerInstance) host.name = conf->ReadValue("database", "name", j); host.user = conf->ReadValue("database", "username", j); host.pass = conf->ReadValue("database", "password", j); + host.ssl = conf->ReadFlag("database", "ssl", j); + + if (HasHost(host)) + continue; if (!host.id.empty() && !host.host.empty() && !host.name.empty() && !host.user.empty() && !host.pass.empty()) { @@ -733,9 +788,10 @@ class ModuleSQL : public Module InspIRCd* PublicServerInstance; pthread_t Dispatcher; int currid; + bool rehashing; ModuleSQL(InspIRCd* Me) - : Module::Module(Me) + : Module::Module(Me), rehashing(false) { ServerInstance->UseInterface("SQLutils"); @@ -768,7 +824,7 @@ class ModuleSQL : public Module virtual ~ModuleSQL() { giveup = true; - ClearDatabases(); + ClearAllConnections(); DELETE(Conf); ServerInstance->UnpublishInterface("SQL", this); ServerInstance->UnpublishFeature("SQL"); @@ -828,6 +884,7 @@ class ModuleSQL : public Module virtual void OnRehash(userrec* user, const std::string ¶meter) { /* TODO: set rehash bool here, which makes the dispatcher thread rehash at next opportunity */ + rehashing = true; } virtual Version GetVersion() @@ -872,6 +929,16 @@ void* DispatcherThread(void* arg) while (!giveup) { + if (thismodule->rehashing) + { + /* XXX: Lock */ + pthread_mutex_lock(&queue_mutex); + thismodule->rehashing = false; + LoadDatabases(thismodule->Conf, thismodule->PublicServerInstance); + pthread_mutex_unlock(&queue_mutex); + /* XXX: Unlock */ + } + SQLConnection* conn = NULL; /* XXX: Lock here for safety */ pthread_mutex_lock(&queue_mutex); |