diff options
Diffstat (limited to 'src/modules/extra')
-rw-r--r-- | src/modules/extra/m_geoip.cpp | 5 | ||||
-rw-r--r-- | src/modules/extra/m_mysql.cpp | 30 | ||||
-rw-r--r-- | src/modules/extra/m_pgsql.cpp | 16 | ||||
-rw-r--r-- | src/modules/extra/m_ssl_gnutls.cpp | 112 |
4 files changed, 99 insertions, 64 deletions
diff --git a/src/modules/extra/m_geoip.cpp b/src/modules/extra/m_geoip.cpp index 7696146e8..ddc4e9a38 100644 --- a/src/modules/extra/m_geoip.cpp +++ b/src/modules/extra/m_geoip.cpp @@ -35,7 +35,7 @@ class ModuleGeoIP : public Module LocalStringExt ext; GeoIP* gi; - void SetExt(LocalUser* user) + std::string* SetExt(LocalUser* user) { const char* c = GeoIP_country_code_by_addr(gi, user->GetIPString().c_str()); if (!c) @@ -43,6 +43,7 @@ class ModuleGeoIP : public Module std::string* cc = new std::string(c); ext.set(user, cc); + return cc; } public: @@ -85,7 +86,7 @@ class ModuleGeoIP : public Module { std::string* cc = ext.get(user); if (!cc) - SetExt(user); + cc = SetExt(user); std::string geoip = myclass->config->getString("geoip"); if (geoip.empty()) diff --git a/src/modules/extra/m_mysql.cpp b/src/modules/extra/m_mysql.cpp index e5d8d379c..bb8f1b573 100644 --- a/src/modules/extra/m_mysql.cpp +++ b/src/modules/extra/m_mysql.cpp @@ -180,7 +180,6 @@ class MySQLresult : public SQLResult rows++; } mysql_free_result(res); - res = NULL; } } @@ -329,10 +328,15 @@ class SQLConnection : public SQLProvider if (param < p.size()) { std::string parm = p[param++]; - char buffer[MAXBUF]; - mysql_escape_string(buffer, parm.c_str(), parm.length()); + // In the worst case, each character may need to be encoded as using two bytes, + // and one byte is the terminating null + std::vector<char> buffer(parm.length() * 2 + 1); + + // The return value of mysql_escape_string() is the length of the encoded string, + // not including the terminating null + unsigned long escapedsize = mysql_escape_string(&buffer[0], parm.c_str(), parm.length()); // mysql_real_escape_string(connection, queryend, paramscopy[paramnum].c_str(), paramscopy[paramnum].length()); - res.append(buffer); + res.append(&buffer[0], escapedsize); } } } @@ -358,9 +362,10 @@ class SQLConnection : public SQLProvider if (it != p.end()) { std::string parm = it->second; - char buffer[MAXBUF]; - mysql_escape_string(buffer, parm.c_str(), parm.length()); - res.append(buffer); + // NOTE: See above + std::vector<char> buffer(parm.length() * 2 + 1); + unsigned long escapedsize = mysql_escape_string(&buffer[0], parm.c_str(), parm.length()); + res.append(&buffer[0], escapedsize); } } } @@ -431,13 +436,14 @@ void ModuleSQL::OnRehash(User* user) i->second->lock.Lock(); i->second->lock.Unlock(); // now remove all active queries to this DB - for(unsigned int j = qq.size() - 1; j >= 0; j--) + for (size_t j = qq.size(); j > 0; j--) { - if (qq[j].c == i->second) + size_t k = j - 1; + if (qq[k].c == i->second) { - qq[j].q->OnError(err); - delete qq[j].q; - qq.erase(qq.begin() + j); + qq[k].q->OnError(err); + delete qq[k].q; + qq.erase(qq.begin() + k); } } // finally, nuke the connection diff --git a/src/modules/extra/m_pgsql.cpp b/src/modules/extra/m_pgsql.cpp index 3df6b91bd..61e8d5bbb 100644 --- a/src/modules/extra/m_pgsql.cpp +++ b/src/modules/extra/m_pgsql.cpp @@ -412,16 +412,16 @@ restart: if (param < p.size()) { std::string parm = p[param++]; - char buffer[MAXBUF]; + std::vector<char> buffer(parm.length() * 2 + 1); #ifdef PGSQL_HAS_ESCAPECONN int error; - PQescapeStringConn(sql, buffer, parm.c_str(), parm.length(), &error); + size_t escapedsize = PQescapeStringConn(sql, &buffer[0], parm.data(), parm.length(), &error); if (error) ServerInstance->Logs->Log("m_pgsql", LOG_DEBUG, "BUG: Apparently PQescapeStringConn() failed"); #else - PQescapeString (buffer, parm.c_str(), parm.length()); + size_t escapedsize = PQescapeString(&buffer[0], parm.data(), parm.length()); #endif - res.append(buffer); + res.append(&buffer[0], escapedsize); } } } @@ -447,16 +447,16 @@ restart: if (it != p.end()) { std::string parm = it->second; - char buffer[MAXBUF]; + std::vector<char> buffer(parm.length() * 2 + 1); #ifdef PGSQL_HAS_ESCAPECONN int error; - PQescapeStringConn(sql, buffer, parm.c_str(), parm.length(), &error); + size_t escapedsize = PQescapeStringConn(sql, &buffer[0], parm.data(), parm.length(), &error); if (error) ServerInstance->Logs->Log("m_pgsql", LOG_DEBUG, "BUG: Apparently PQescapeStringConn() failed"); #else - PQescapeString (buffer, parm.c_str(), parm.length()); + size_t escapedsize = PQescapeString(&buffer[0], parm.data(), parm.length()); #endif - res.append(buffer); + res.append(&buffer[0], escapedsize); } } } diff --git a/src/modules/extra/m_ssl_gnutls.cpp b/src/modules/extra/m_ssl_gnutls.cpp index 0659f631c..8faee2da7 100644 --- a/src/modules/extra/m_ssl_gnutls.cpp +++ b/src/modules/extra/m_ssl_gnutls.cpp @@ -77,46 +77,6 @@ static int cert_callback (gnutls_session_t session, const gnutls_datum_t * req_c return 0; } -static ssize_t gnutls_pull_wrapper(gnutls_transport_ptr_t user_wrap, void* buffer, size_t size) -{ - StreamSocket* user = reinterpret_cast<StreamSocket*>(user_wrap); - if (user->GetEventMask() & FD_READ_WILL_BLOCK) - { - errno = EAGAIN; - return -1; - } - int rv = ServerInstance->SE->Recv(user, reinterpret_cast<char *>(buffer), size, 0); - if (rv < 0) - { - /* On Windows we need to set errno for gnutls */ - if (SocketEngine::IgnoreError()) - errno = EAGAIN; - } - if (rv < (int)size) - ServerInstance->SE->ChangeEventMask(user, FD_READ_WILL_BLOCK); - return rv; -} - -static ssize_t gnutls_push_wrapper(gnutls_transport_ptr_t user_wrap, const void* buffer, size_t size) -{ - StreamSocket* user = reinterpret_cast<StreamSocket*>(user_wrap); - if (user->GetEventMask() & FD_WRITE_WILL_BLOCK) - { - errno = EAGAIN; - return -1; - } - int rv = ServerInstance->SE->Send(user, reinterpret_cast<const char *>(buffer), size, 0); - if (rv < 0) - { - /* On Windows we need to set errno for gnutls */ - if (SocketEngine::IgnoreError()) - errno = EAGAIN; - } - if (rv < (int)size) - ServerInstance->SE->ChangeEventMask(user, FD_WRITE_WILL_BLOCK); - return rv; -} - class RandGen : public HandlerBase2<void, char*, size_t> { public: @@ -132,10 +92,12 @@ class RandGen : public HandlerBase2<void, char*, size_t> class issl_session { public: + StreamSocket* socket; gnutls_session_t sess; issl_status status; reference<ssl_cert> cert; - issl_session() : sess(NULL) {} + + issl_session() : socket(NULL), sess(NULL) {} }; class CommandStartTLS : public SplitCommand @@ -213,6 +175,70 @@ class ModuleSSLGnuTLS : public Module return str ? str : "UNKNOWN"; } + static ssize_t gnutls_pull_wrapper(gnutls_transport_ptr_t session_wrap, void* buffer, size_t size) + { + issl_session* session = reinterpret_cast<issl_session*>(session_wrap); + if (session->socket->GetEventMask() & FD_READ_WILL_BLOCK) + { +#ifdef _WIN32 + gnutls_transport_set_errno(session->sess, EAGAIN); +#else + errno = EAGAIN; +#endif + return -1; + } + + int rv = ServerInstance->SE->Recv(session->socket, reinterpret_cast<char *>(buffer), size, 0); + +#ifdef _WIN32 + if (rv < 0) + { + /* Windows doesn't use errno, but gnutls does, so check SocketEngine::IgnoreError() + * and then set errno appropriately. + * The gnutls library may also have a different errno variable than us, see + * gnutls_transport_set_errno(3). + */ + gnutls_transport_set_errno(session->sess, SocketEngine::IgnoreError() ? EAGAIN : errno); + } +#endif + + if (rv < (int)size) + ServerInstance->SE->ChangeEventMask(session->socket, FD_READ_WILL_BLOCK); + return rv; + } + + static ssize_t gnutls_push_wrapper(gnutls_transport_ptr_t session_wrap, const void* buffer, size_t size) + { + issl_session* session = reinterpret_cast<issl_session*>(session_wrap); + if (session->socket->GetEventMask() & FD_WRITE_WILL_BLOCK) + { +#ifdef _WIN32 + gnutls_transport_set_errno(session->sess, EAGAIN); +#else + errno = EAGAIN; +#endif + return -1; + } + + int rv = ServerInstance->SE->Send(session->socket, reinterpret_cast<const char *>(buffer), size, 0); + +#ifdef _WIN32 + if (rv < 0) + { + /* Windows doesn't use errno, but gnutls does, so check SocketEngine::IgnoreError() + * and then set errno appropriately. + * The gnutls library may also have a different errno variable than us, see + * gnutls_transport_set_errno(3). + */ + gnutls_transport_set_errno(session->sess, SocketEngine::IgnoreError() ? EAGAIN : errno); + } +#endif + + if (rv < (int)size) + ServerInstance->SE->ChangeEventMask(session->socket, FD_WRITE_WILL_BLOCK); + return rv; + } + public: ModuleSSLGnuTLS() : starttls(this), capHandler(this, "tls"), iohook(this, "ssl/gnutls", SERVICE_IOHOOK) @@ -538,13 +564,14 @@ class ModuleSSLGnuTLS : public Module issl_session* session = &sessions[user->GetFd()]; gnutls_init(&session->sess, me_server ? GNUTLS_SERVER : GNUTLS_CLIENT); + session->socket = user; #ifdef GNUTLS_NEW_PRIO_API gnutls_priority_set(session->sess, priority); #endif gnutls_credentials_set(session->sess, GNUTLS_CRD_CERTIFICATE, x509_cred); gnutls_dh_set_prime_bits(session->sess, dh_bits); - gnutls_transport_set_ptr(session->sess, reinterpret_cast<gnutls_transport_ptr_t>(user)); + gnutls_transport_set_ptr(session->sess, reinterpret_cast<gnutls_transport_ptr_t>(session)); gnutls_transport_set_push_function(session->sess, gnutls_push_wrapper); gnutls_transport_set_pull_function(session->sess, gnutls_pull_wrapper); @@ -760,6 +787,7 @@ class ModuleSSLGnuTLS : public Module gnutls_bye(session->sess, GNUTLS_SHUT_WR); gnutls_deinit(session->sess); } + session->socket = NULL; session->sess = NULL; session->cert = NULL; session->status = ISSL_NONE; |