From eef55acb1dbb2ae6c0202fec54e12506c064f892 Mon Sep 17 00:00:00 2001 From: Attila Molnar Date: Mon, 8 Aug 2016 14:31:49 +0200 Subject: [PATCH] Add StreamSocket::GetModHook() for obtaining the IOHook belonging to a given module Use it to simplify logic in all modules using or providing IOHooks --- include/inspsocket.h | 6 ++++++ src/inspsocket.cpp | 10 ++++++++++ src/modules/extra/m_ssl_gnutls.cpp | 23 +++++++++++------------ src/modules/extra/m_ssl_mbedtls.cpp | 12 ++++-------- src/modules/extra/m_ssl_openssl.cpp | 12 ++++-------- src/modules/m_httpd.cpp | 2 +- src/modules/m_spanningtree/main.cpp | 4 ++-- 7 files changed, 38 insertions(+), 31 deletions(-) diff --git a/include/inspsocket.h b/include/inspsocket.h index 53eca2e91..72fb03d58 100644 --- a/include/inspsocket.h +++ b/include/inspsocket.h @@ -284,6 +284,12 @@ class CoreExport StreamSocket : public EventHandler virtual void Close(); /** This ensures that close is called prior to destructor */ virtual CullResult cull(); + + /** Get the IOHook of a module attached to this socket + * @param mod Module whose IOHook to return + * @return IOHook belonging to the module or NULL if the module haven't attached an IOHook to this socket + */ + IOHook* GetModHook(Module* mod) const; }; /** * BufferedSocket is an extendable socket class which modules diff --git a/src/inspsocket.cpp b/src/inspsocket.cpp index dcc455482..0b0507f7c 100644 --- a/src/inspsocket.cpp +++ b/src/inspsocket.cpp @@ -434,3 +434,13 @@ void StreamSocket::CheckError(BufferedSocketError errcode) OnError(errcode); } } + +IOHook* StreamSocket::GetModHook(Module* mod) const +{ + if (iohook) + { + if (iohook->prov->creator == mod) + return iohook; + } + return NULL; +} diff --git a/src/modules/extra/m_ssl_gnutls.cpp b/src/modules/extra/m_ssl_gnutls.cpp index 44a49d895..dfd3b47dd 100644 --- a/src/modules/extra/m_ssl_gnutls.cpp +++ b/src/modules/extra/m_ssl_gnutls.cpp @@ -101,6 +101,8 @@ typedef gnutls_connection_end_t inspircd_gnutls_session_init_flags_t; #define INSPIRCD_GNUTLS_HAS_CORK #endif +static Module* thismod; + class RandGen : public HandlerBase2 { public: @@ -917,7 +919,7 @@ info_done_dealloc: { StreamSocket* sock = reinterpret_cast(session_wrap); #ifdef _WIN32 - GnuTLSIOHook* session = static_cast(sock->GetIOHook()); + GnuTLSIOHook* session = static_cast(sock->GetModHook(thismod)); #endif if (sock->GetEventMask() & FD_READ_WILL_BLOCK) @@ -954,7 +956,7 @@ info_done_dealloc: { StreamSocket* sock = reinterpret_cast(transportptr); #ifdef _WIN32 - GnuTLSIOHook* session = static_cast(sock->GetIOHook()); + GnuTLSIOHook* session = static_cast(sock->GetModHook(thismod)); #endif if (sock->GetEventMask() & FD_WRITE_WILL_BLOCK) @@ -989,7 +991,7 @@ info_done_dealloc: { StreamSocket* sock = reinterpret_cast(session_wrap); #ifdef _WIN32 - GnuTLSIOHook* session = static_cast(sock->GetIOHook()); + GnuTLSIOHook* session = static_cast(sock->GetModHook(thismod)); #endif if (sock->GetEventMask() & FD_WRITE_WILL_BLOCK) @@ -1172,7 +1174,7 @@ int GnuTLS::X509Credentials::cert_callback(gnutls_session_t sess, const gnutls_d st->key_type = GNUTLS_PRIVKEY_X509; #endif StreamSocket* sock = reinterpret_cast(gnutls_transport_get_ptr(sess)); - GnuTLS::X509Credentials& cred = static_cast(sock->GetIOHook())->GetProfile()->GetX509Credentials(); + GnuTLS::X509Credentials& cred = static_cast(sock->GetModHook(thismod))->GetProfile()->GetX509Credentials(); st->ncerts = cred.certs.size(); st->cert.x509 = cred.certs.raw(); @@ -1282,6 +1284,7 @@ class ModuleSSLGnuTLS : public Module #ifndef GNUTLS_HAS_RND gcry_control (GCRYCTL_INITIALIZATION_FINISHED, 0); #endif + thismod = this; } void init() CXX11_OVERRIDE @@ -1317,7 +1320,7 @@ class ModuleSSLGnuTLS : public Module { LocalUser* user = IS_LOCAL(static_cast(item)); - if (user && user->eh.GetIOHook() && user->eh.GetIOHook()->prov->creator == this) + if ((user) && (user->eh.GetModHook(this))) { // User is using SSL, they're a local user, and they're using one of *our* SSL ports. // Potentially there could be multiple SSL modules loaded at once on different ports. @@ -1333,13 +1336,9 @@ class ModuleSSLGnuTLS : public Module ModResult OnCheckReady(LocalUser* user) CXX11_OVERRIDE { - if ((user->eh.GetIOHook()) && (user->eh.GetIOHook()->prov->creator == this)) - { - GnuTLSIOHook* iohook = static_cast(user->eh.GetIOHook()); - if (!iohook->IsHandshakeDone()) - return MOD_RES_DENY; - } - + const GnuTLSIOHook* const iohook = static_cast(user->eh.GetModHook(this)); + if ((iohook) && (!iohook->IsHandshakeDone())) + return MOD_RES_DENY; return MOD_RES_PASSTHRU; } }; diff --git a/src/modules/extra/m_ssl_mbedtls.cpp b/src/modules/extra/m_ssl_mbedtls.cpp index 7efcce72d..845d02aa3 100644 --- a/src/modules/extra/m_ssl_mbedtls.cpp +++ b/src/modules/extra/m_ssl_mbedtls.cpp @@ -894,7 +894,7 @@ class ModuleSSLmbedTLS : public Module return; LocalUser* user = IS_LOCAL(static_cast(item)); - if ((user) && (user->eh.GetIOHook()) && (user->eh.GetIOHook()->prov->creator == this)) + if ((user) && (user->eh.GetModHook(this))) { // User is using SSL, they're a local user, and they're using our IOHook. // Potentially there could be multiple SSL modules loaded at once on different ports. @@ -904,13 +904,9 @@ class ModuleSSLmbedTLS : public Module ModResult OnCheckReady(LocalUser* user) CXX11_OVERRIDE { - if ((user->eh.GetIOHook()) && (user->eh.GetIOHook()->prov->creator == this)) - { - mbedTLSIOHook* iohook = static_cast(user->eh.GetIOHook()); - if (!iohook->IsHandshakeDone()) - return MOD_RES_DENY; - } - + const mbedTLSIOHook* const iohook = static_cast(user->eh.GetModHook(this)); + if ((iohook) && (!iohook->IsHandshakeDone())) + return MOD_RES_DENY; return MOD_RES_PASSTHRU; } diff --git a/src/modules/extra/m_ssl_openssl.cpp b/src/modules/extra/m_ssl_openssl.cpp index 5587f323a..4ad556438 100644 --- a/src/modules/extra/m_ssl_openssl.cpp +++ b/src/modules/extra/m_ssl_openssl.cpp @@ -909,7 +909,7 @@ class ModuleSSLOpenSSL : public Module { LocalUser* user = IS_LOCAL((User*)item); - if (user && user->eh.GetIOHook() && user->eh.GetIOHook()->prov->creator == this) + if ((user) && (user->eh.GetModHook(this))) { // User is using SSL, they're a local user, and they're using one of *our* SSL ports. // Potentially there could be multiple SSL modules loaded at once on different ports. @@ -920,13 +920,9 @@ class ModuleSSLOpenSSL : public Module ModResult OnCheckReady(LocalUser* user) CXX11_OVERRIDE { - if ((user->eh.GetIOHook()) && (user->eh.GetIOHook()->prov->creator == this)) - { - OpenSSLIOHook* iohook = static_cast(user->eh.GetIOHook()); - if (!iohook->IsHandshakeDone()) - return MOD_RES_DENY; - } - + const OpenSSLIOHook* const iohook = static_cast(user->eh.GetModHook(this)); + if ((iohook) && (!iohook->IsHandshakeDone())) + return MOD_RES_DENY; return MOD_RES_PASSTHRU; } diff --git a/src/modules/m_httpd.cpp b/src/modules/m_httpd.cpp index 760647d47..0b6b2e32b 100644 --- a/src/modules/m_httpd.cpp +++ b/src/modules/m_httpd.cpp @@ -413,7 +413,7 @@ class ModuleHttpServer : public Module { HttpServerSocket* sock = *i; ++i; - if (sock->GetIOHook() && sock->GetIOHook()->prov->creator == mod) + if (sock->GetModHook(mod)) { sock->cull(); delete sock; diff --git a/src/modules/m_spanningtree/main.cpp b/src/modules/m_spanningtree/main.cpp index 0b9bb65df..81543b0da 100644 --- a/src/modules/m_spanningtree/main.cpp +++ b/src/modules/m_spanningtree/main.cpp @@ -635,7 +635,7 @@ restart: for (TreeServer::ChildServers::const_iterator i = list.begin(); i != list.end(); ++i) { TreeSocket* sock = (*i)->GetSocket(); - if (sock->GetIOHook() && sock->GetIOHook()->prov->creator == mod) + if (sock->GetModHook(mod)) { sock->SendError("SSL module unloaded"); sock->Close(); @@ -647,7 +647,7 @@ restart: for (SpanningTreeUtilities::TimeoutList::const_iterator i = Utils->timeoutlist.begin(); i != Utils->timeoutlist.end(); ++i) { TreeSocket* sock = i->first; - if (sock->GetIOHook() && sock->GetIOHook()->prov->creator == mod) + if (sock->GetModHook(mod)) sock->Close(); } } -- 2.39.5