]> git.netwichtig.de Git - user/henk/code/inspircd.git/commitdiff
Split IOHook into IOHook and IOHookProvider
authorAttila Molnar <attilamolnar@hush.com>
Tue, 24 Sep 2013 18:40:20 +0000 (20:40 +0200)
committerAttila Molnar <attilamolnar@hush.com>
Wed, 22 Jan 2014 18:10:01 +0000 (19:10 +0100)
Create one IOHook instance for each hooked socket which contains all the
hook specific data and read/write/close functions, removing the need for
the "issl_session" array in SSL modules.

Register instances of the IOHookProvider class in the core and use them to
create specialized IOHook instances (OnConnect/OnAccept).

Remove the OnHookIO hook, add a dynamic reference to ListenSocket that
points to the hook provider (if any) to use for incoming connections on
that socket.

For outgoing connections modules still have to find the IOHookProvider
they want to use themselves but instead of calling AddIOHook(hookprov),
now they have to call IOHookProvider::OnConnect() after the connection
has been established.

15 files changed:
include/iohook.h
include/modules.h
include/modules/ssl.h
include/socket.h
src/inspsocket.cpp
src/listensocket.cpp
src/modules.cpp
src/modules/extra/m_ssl_gnutls.cpp
src/modules/extra/m_ssl_openssl.cpp
src/modules/m_httpd.cpp
src/modules/m_spanningtree/main.cpp
src/modules/m_spanningtree/treesocket1.cpp
src/modules/m_starttls.cpp
src/socket.cpp
src/usermanager.cpp

index 7c3a0faeef975aa064064409cecbe238291926fb..ce7ca2a1be305d347af9e185ad134fbffcfdb596 100644 (file)
@@ -21,7 +21,7 @@
 
 class StreamSocket;
 
-class IOHook : public ServiceProvider
+class IOHookProvider : public ServiceProvider
 {
  public:
        enum Type
@@ -32,19 +32,35 @@ class IOHook : public ServiceProvider
 
        const Type type;
 
-       IOHook(Module* mod, const std::string& Name, Type hooktype = IOH_UNKNOWN)
+       IOHookProvider(Module* mod, const std::string& Name, Type hooktype = IOH_UNKNOWN)
                : ServiceProvider(mod, Name, SERVICE_IOHOOK), type(hooktype) { }
 
-       /** Called immediately after any connection is accepted. This is intended for raw socket
+       /** Called immediately after a connection is accepted. This is intended for raw socket
         * processing (e.g. modules which wrap the tcp connection within another library) and provides
         * no information relating to a user record as the connection has not been assigned yet.
-        * There are no return values from this call as all modules get an opportunity if required to
-        * process the connection.
         * @param sock The socket in question
         * @param client The client IP address and port
         * @param server The server IP address and port
         */
-       virtual void OnStreamSocketAccept(StreamSocket* sock, irc::sockets::sockaddrs* client, irc::sockets::sockaddrs* server) = 0;
+       virtual void OnAccept(StreamSocket* sock, irc::sockets::sockaddrs* client, irc::sockets::sockaddrs* server) = 0;
+
+       /** Called immediately upon connection of an outbound BufferedSocket which has been hooked
+        * by a module.
+        * @param sock The socket in question
+        */
+       virtual void OnConnect(StreamSocket* sock) = 0;
+};
+
+class IOHook : public classbase
+{
+ public:
+       /** The IOHookProvider for this hook, contains information about the hook,
+        * such as the module providing it and the hook type.
+        */
+       IOHookProvider* const prov;
+
+       IOHook(IOHookProvider* provider)
+               : prov(provider) { }
 
        /**
         * Called when a hooked stream has data to write, or when the socket
@@ -62,12 +78,6 @@ class IOHook : public ServiceProvider
         */
        virtual void OnStreamSocketClose(StreamSocket* sock) = 0;
 
-       /** Called immediately upon connection of an outbound BufferedSocket which has been hooked
-        * by a module.
-        * @param sock The socket in question
-        */
-       virtual void OnStreamSocketConnect(StreamSocket* sock) = 0;
-
        /**
         * Called when the stream socket has data to read
         * @param sock The socket that is ready
index 0be1ea294189b490c5aa66764b67ad7ca731e03a..7223f6b9d61d139cb87f0952b16cf2fd6f6a1f1f 100644 (file)
@@ -264,7 +264,7 @@ enum Implementation
        I_OnChangeLocalUserGECOS, I_OnUserRegister, I_OnChannelPreDelete, I_OnChannelDelete,
        I_OnPostOper, I_OnSyncNetwork, I_OnSetAway, I_OnPostCommand, I_OnPostJoin,
        I_OnWhoisLine, I_OnBuildNeighborList, I_OnGarbageCollect, I_OnSetConnectClass,
-       I_OnText, I_OnPassCompare, I_OnRunTestSuite, I_OnNamesListItem, I_OnNumeric, I_OnHookIO,
+       I_OnText, I_OnPassCompare, I_OnRunTestSuite, I_OnNamesListItem, I_OnNumeric,
        I_OnPreRehash, I_OnModuleRehash, I_OnSendWhoLine, I_OnChangeIdent, I_OnSetUserIP,
        I_END
 };
@@ -989,12 +989,6 @@ class CoreExport Module : public classbase, public usecountbase
         */
        virtual void OnPostConnect(User* user);
 
-       /** Called to install an I/O hook on an event handler
-        * @param user The socket to possibly install the I/O hook on
-        * @param via The port that the user connected on
-        */
-       virtual void OnHookIO(StreamSocket* user, ListenSocket* via);
-
        /** Called when a port accepts a connection
         * Return MOD_RES_ACCEPT if you have used the file descriptor.
         * @param fd The file descriptor returned from accept()
index 25076215ac5b4e5e5913e2dfbb3ba52c6c372715..0f58e0b7bfce4bf43d6c3357b052e6e900ae420c 100644 (file)
@@ -133,28 +133,34 @@ class ssl_cert : public refcountbase
 
 class SSLIOHook : public IOHook
 {
+ protected:
+       /** Peer SSL certificate, set by the SSL module
+        */
+       reference<ssl_cert> certificate;
+
  public:
-       SSLIOHook(Module* mod, const std::string& Name)
-               : IOHook(mod, Name, IOHook::IOH_SSL)
+       SSLIOHook(IOHookProvider* hookprov)
+               : IOHook(hookprov)
        {
        }
 
        /**
-        * Get the client certificate from a socket
-        * @param sock The socket to get the certificate from, must be using this IOHook
-        * @return The SSL client certificate information
+        * Get the certificate sent by this peer
+        * @return The SSL certificate sent by the peer, NULL if no cert was sent
         */
-       virtual ssl_cert* GetCertificate(StreamSocket* sock) = 0;
+       ssl_cert* GetCertificate() const
+       {
+               return certificate;
+       }
 
        /**
-        * Get the fingerprint of a client certificate from a socket
-        * @param sock The socket to get the certificate fingerprint from, must be using this IOHook
+        * Get the fingerprint of the peer's certificate
         * @return The fingerprint of the SSL client certificate sent by the peer,
         * empty if no cert was sent
         */
-       std::string GetFingerprint(StreamSocket* sock)
+       std::string GetFingerprint() const
        {
-               ssl_cert* cert = GetCertificate(sock);
+               ssl_cert* cert = GetCertificate();
                if (cert)
                        return cert->GetFingerprint();
                return "";
@@ -175,11 +181,11 @@ class SSLClientCert
        static ssl_cert* GetCertificate(StreamSocket* sock)
        {
                IOHook* iohook = sock->GetIOHook();
-               if ((!iohook) || (iohook->type != IOHook::IOH_SSL))
+               if ((!iohook) || (iohook->prov->type != IOHookProvider::IOH_SSL))
                        return NULL;
 
                SSLIOHook* ssliohook = static_cast<SSLIOHook*>(iohook);
-               return ssliohook->GetCertificate(sock);
+               return ssliohook->GetCertificate();
        }
 
        /**
index c54517a76c9f8eb68a7d9d141ef415d1d62e3b9f..c292b7010b915262d3ad03920f39acdfecf39d81 100644 (file)
@@ -127,6 +127,7 @@ namespace irc
        }
 }
 
+#include "iohook.h"
 #include "socketengine.h"
 /** This class handles incoming connections on client ports.
  * It will create a new User for every valid connection
@@ -140,6 +141,12 @@ class CoreExport ListenSocket : public EventHandler
        int bind_port;
        /** Human-readable bind description */
        std::string bind_desc;
+
+       /** The IOHook provider which handles connections on this socket,
+        * NULL if there is none.
+        */
+       dynamic_reference_nocheck<IOHookProvider> iohookprov;
+
        /** Create a new listening socket
         */
        ListenSocket(ConfigTag* tag, const irc::sockets::sockaddrs& bind_to);
@@ -153,4 +160,10 @@ class CoreExport ListenSocket : public EventHandler
        /** Handles sockets internals crap of a connection, convenience wrapper really
         */
        void AcceptInternal();
+
+       /** Inspects the bind block belonging to this socket to set the name of the IO hook
+        * provider which this socket will use for incoming connections.
+        * @return True if the IO hook provider was found or none was given, false otherwise.
+        */
+       bool ResetIOHookProvider();
 };
index 8822f69f82f5538e3aa2f2579a628bfe464538ee..ea09a8b1d5ca2c6fa819acdf978996770a26f3a3 100644 (file)
@@ -134,6 +134,7 @@ void StreamSocket::Close()
                                ServerInstance->Logs->Log("SOCKET", LOG_DEFAULT, "%s threw an exception: %s",
                                        modexcept.GetSource().c_str(), modexcept.GetReason().c_str());
                        }
+                       delete iohook;
                        DelIOHook();
                }
                ServerInstance->SE->Shutdown(this, 2);
@@ -467,9 +468,7 @@ void BufferedSocket::DoWrite()
        {
                state = I_CONNECTED;
                this->OnConnected();
-               if (GetIOHook())
-                       GetIOHook()->OnStreamSocketConnect(this);
-               else
+               if (!GetIOHook())
                        ServerInstance->SE->ChangeEventMask(this, FD_WANT_FAST_READ | FD_WANT_EDGE_WRITE);
        }
        this->StreamSocket::DoWrite();
index 108466ae34548993b86c2c1ab18db4e78cb38920..01bc36cc5ea6c08d6c7874b6a140c9493ebe6033 100644 (file)
@@ -28,6 +28,7 @@
 
 ListenSocket::ListenSocket(ConfigTag* tag, const irc::sockets::sockaddrs& bind_to)
        : bind_tag(tag)
+       , iohookprov(NULL, std::string())
 {
        irc::sockets::satoap(bind_to, bind_addr, bind_port);
        bind_desc = bind_to.str();
@@ -85,6 +86,8 @@ ListenSocket::ListenSocket(ConfigTag* tag, const irc::sockets::sockaddrs& bind_t
        {
                ServerInstance->SE->NonBlocking(this->fd);
                ServerInstance->SE->AddFd(this, FD_WANT_POLL_READ | FD_WANT_NO_WRITE);
+
+               this->ResetIOHookProvider();
        }
 }
 
@@ -214,3 +217,16 @@ void ListenSocket::HandleEvent(EventType e, int err)
                        break;
        }
 }
+
+bool ListenSocket::ResetIOHookProvider()
+{
+       std::string provname = bind_tag->getString("ssl");
+       if (!provname.empty())
+               provname.insert(0, "ssl/");
+
+       // Set the new provider name, dynref handles the rest
+       iohookprov.SetProvider(provname);
+
+       // Return true if no provider was set, or one was set and it was also found
+       return (provname.empty() || iohookprov);
+}
index 23aceb3e11e442c524e5eb8f15ce44e2ef619eba..c70a99d779de5d6c2dc559ac008bbe71d84646af 100644 (file)
@@ -154,7 +154,6 @@ void                Module::OnText(User*, void*, int, const std::string&, char, CUList&) { De
 void           Module::OnRunTestSuite() { DetachEvent(I_OnRunTestSuite); }
 void           Module::OnNamesListItem(User*, Membership*, std::string&, std::string&) { DetachEvent(I_OnNamesListItem); }
 ModResult      Module::OnNumeric(User*, unsigned int, const std::string&) { DetachEvent(I_OnNumeric); return MOD_RES_PASSTHRU; }
-void           Module::OnHookIO(StreamSocket*, ListenSocket*) { DetachEvent(I_OnHookIO); }
 ModResult   Module::OnAcceptConnection(int, ListenSocket*, irc::sockets::sockaddrs*, irc::sockets::sockaddrs*) { DetachEvent(I_OnAcceptConnection); return MOD_RES_PASSTHRU; }
 void           Module::OnSendWhoLine(User*, const std::vector<std::string>&, User*, std::string&) { DetachEvent(I_OnSendWhoLine); }
 void           Module::OnSetUserIP(LocalUser*) { DetachEvent(I_OnSetUserIP); }
index 7c19925ddabbe1f64d61c6e0268b2e7ac80f6d03..2add962fd05d1157e4a93ed2b6469a8c07f2f51f 100644 (file)
@@ -543,58 +543,41 @@ namespace GnuTLS
        };
 }
 
-/** Represents an SSL user's extra data
- */
-class issl_session
+class GnuTLSIOHook : public SSLIOHook
 {
-public:
-       StreamSocket* socket;
+ private:
        gnutls_session_t sess;
        issl_status status;
-       reference<ssl_cert> cert;
        reference<GnuTLS::Profile> profile;
 
-       issl_session() : socket(NULL), sess(NULL) {}
-};
-
-class GnuTLSIOHook : public SSLIOHook
-{
- private:
        void InitSession(StreamSocket* user, bool me_server)
        {
-               issl_session* session = &sessions[user->GetFd()];
-
-               gnutls_init(&session->sess, me_server ? GNUTLS_SERVER : GNUTLS_CLIENT);
-               session->socket = user;
+               gnutls_init(&sess, me_server ? GNUTLS_SERVER : GNUTLS_CLIENT);
 
-               session->profile->SetupSession(session->sess);
-               gnutls_transport_set_ptr(session->sess, reinterpret_cast<gnutls_transport_ptr_t>(session));
-               gnutls_transport_set_push_function(session->sess, gnutls_push_wrapper);
-               gnutls_transport_set_pull_function(session->sess, gnutls_pull_wrapper);
+               profile->SetupSession(sess);
+               gnutls_transport_set_ptr(sess, reinterpret_cast<gnutls_transport_ptr_t>(user));
+               gnutls_transport_set_push_function(sess, gnutls_push_wrapper);
+               gnutls_transport_set_pull_function(sess, gnutls_pull_wrapper);
 
                if (me_server)
-                       gnutls_certificate_server_set_request(session->sess, GNUTLS_CERT_REQUEST); // Request client certificate if any.
-
-               Handshake(session, user);
+                       gnutls_certificate_server_set_request(sess, GNUTLS_CERT_REQUEST); // Request client certificate if any.
        }
 
-       void CloseSession(issl_session* session)
+       void CloseSession()
        {
-               if (session->sess)
+               if (this->sess)
                {
-                       gnutls_bye(session->sess, GNUTLS_SHUT_WR);
-                       gnutls_deinit(session->sess);
+                       gnutls_bye(this->sess, GNUTLS_SHUT_WR);
+                       gnutls_deinit(this->sess);
                }
-               session->socket = NULL;
-               session->sess = NULL;
-               session->cert = NULL;
-               session->profile = NULL;
-               session->status = ISSL_NONE;
+               sess = NULL;
+               certificate = NULL;
+               status = ISSL_NONE;
        }
 
-       bool Handshake(issl_session* session, StreamSocket* user)
+       bool Handshake(StreamSocket* user)
        {
-               int ret = gnutls_handshake(session->sess);
+               int ret = gnutls_handshake(this->sess);
 
                if (ret < 0)
                {
@@ -602,24 +585,24 @@ class GnuTLSIOHook : public SSLIOHook
                        {
                                // Handshake needs resuming later, read() or write() would have blocked.
 
-                               if(gnutls_record_get_direction(session->sess) == 0)
+                               if (gnutls_record_get_direction(this->sess) == 0)
                                {
                                        // gnutls_handshake() wants to read() again.
-                                       session->status = ISSL_HANDSHAKING_READ;
+                                       this->status = ISSL_HANDSHAKING_READ;
                                        ServerInstance->SE->ChangeEventMask(user, FD_WANT_POLL_READ | FD_WANT_NO_WRITE);
                                }
                                else
                                {
                                        // gnutls_handshake() wants to write() again.
-                                       session->status = ISSL_HANDSHAKING_WRITE;
+                                       this->status = ISSL_HANDSHAKING_WRITE;
                                        ServerInstance->SE->ChangeEventMask(user, FD_WANT_NO_READ | FD_WANT_SINGLE_WRITE);
                                }
                        }
                        else
                        {
                                user->SetError("Handshake Failed - " + std::string(gnutls_strerror(ret)));
-                               CloseSession(session);
-                               session->status = ISSL_CLOSING;
+                               CloseSession();
+                               this->status = ISSL_CLOSING;
                        }
 
                        return false;
@@ -627,9 +610,9 @@ class GnuTLSIOHook : public SSLIOHook
                else
                {
                        // Change the seesion state
-                       session->status = ISSL_HANDSHAKEN;
+                       this->status = ISSL_HANDSHAKEN;
 
-                       VerifyCertificate(session,user);
+                       VerifyCertificate();
 
                        // Finish writing, if any left
                        ServerInstance->SE->ChangeEventMask(user, FD_WANT_POLL_READ | FD_WANT_NO_WRITE | FD_ADD_TRIAL_WRITE);
@@ -638,12 +621,9 @@ class GnuTLSIOHook : public SSLIOHook
                }
        }
 
-       void VerifyCertificate(issl_session* session, StreamSocket* user)
+       void VerifyCertificate()
        {
-               if (!session->sess || !user)
-                       return;
-
-               unsigned int status;
+               unsigned int certstatus;
                const gnutls_datum_t* cert_list;
                int ret;
                unsigned int cert_list_size;
@@ -653,12 +633,12 @@ class GnuTLSIOHook : public SSLIOHook
                size_t digest_size = sizeof(digest);
                size_t name_size = sizeof(str);
                ssl_cert* certinfo = new ssl_cert;
-               session->cert = certinfo;
+               this->certificate = certinfo;
 
                /* This verification function uses the trusted CAs in the credentials
                 * structure. So you must have installed one or more CA certificates.
                 */
-               ret = gnutls_certificate_verify_peers2(session->sess, &status);
+               ret = gnutls_certificate_verify_peers2(this->sess, &certstatus);
 
                if (ret < 0)
                {
@@ -666,16 +646,16 @@ class GnuTLSIOHook : public SSLIOHook
                        return;
                }
 
-               certinfo->invalid = (status & GNUTLS_CERT_INVALID);
-               certinfo->unknownsigner = (status & GNUTLS_CERT_SIGNER_NOT_FOUND);
-               certinfo->revoked = (status & GNUTLS_CERT_REVOKED);
-               certinfo->trusted = !(status & GNUTLS_CERT_SIGNER_NOT_CA);
+               certinfo->invalid = (certstatus & GNUTLS_CERT_INVALID);
+               certinfo->unknownsigner = (certstatus & GNUTLS_CERT_SIGNER_NOT_FOUND);
+               certinfo->revoked = (certstatus & GNUTLS_CERT_REVOKED);
+               certinfo->trusted = !(certstatus & GNUTLS_CERT_SIGNER_NOT_CA);
 
                /* Up to here the process is the same for X.509 certificates and
                 * OpenPGP keys. From now on X.509 certificates are assumed. This can
                 * be easily extended to work with openpgp keys as well.
                 */
-               if (gnutls_certificate_type_get(session->sess) != GNUTLS_CRT_X509)
+               if (gnutls_certificate_type_get(this->sess) != GNUTLS_CRT_X509)
                {
                        certinfo->error = "No X509 keys sent";
                        return;
@@ -689,7 +669,7 @@ class GnuTLSIOHook : public SSLIOHook
                }
 
                cert_list_size = 0;
-               cert_list = gnutls_certificate_get_peers(session->sess, &cert_list_size);
+               cert_list = gnutls_certificate_get_peers(this->sess, &cert_list_size);
                if (cert_list == NULL)
                {
                        certinfo->error = "No certificate was found";
@@ -713,7 +693,7 @@ class GnuTLSIOHook : public SSLIOHook
                gnutls_x509_crt_get_issuer_dn(cert, str, &name_size);
                certinfo->issuer = str;
 
-               if ((ret = gnutls_x509_crt_get_fingerprint(cert, session->profile->GetHash(), digest, &digest_size)) < 0)
+               if ((ret = gnutls_x509_crt_get_fingerprint(cert, profile->GetHash(), digest, &digest_size)) < 0)
                {
                        certinfo->error = gnutls_strerror(ret);
                }
@@ -740,8 +720,12 @@ info_done_dealloc:
 
        static ssize_t gnutls_pull_wrapper(gnutls_transport_ptr_t session_wrap, void* buffer, size_t size)
        {
-               issl_session* session = reinterpret_cast<issl_session*>(session_wrap);
-               if (session->socket->GetEventMask() & FD_READ_WILL_BLOCK)
+               StreamSocket* sock = reinterpret_cast<StreamSocket*>(session_wrap);
+#ifdef _WIN32
+               GnuTLSIOHook* session = static_cast<GnuTLSIOHook*>(sock->GetIOHook());
+#endif
+
+               if (sock->GetEventMask() & FD_READ_WILL_BLOCK)
                {
 #ifdef _WIN32
                        gnutls_transport_set_errno(session->sess, EAGAIN);
@@ -751,7 +735,7 @@ info_done_dealloc:
                        return -1;
                }
 
-               int rv = ServerInstance->SE->Recv(session->socket, reinterpret_cast<char *>(buffer), size, 0);
+               int rv = ServerInstance->SE->Recv(sock, reinterpret_cast<char *>(buffer), size, 0);
 
 #ifdef _WIN32
                if (rv < 0)
@@ -766,14 +750,18 @@ info_done_dealloc:
 #endif
 
                if (rv < (int)size)
-                       ServerInstance->SE->ChangeEventMask(session->socket, FD_READ_WILL_BLOCK);
+                       ServerInstance->SE->ChangeEventMask(sock, FD_READ_WILL_BLOCK);
                return rv;
        }
 
        static ssize_t gnutls_push_wrapper(gnutls_transport_ptr_t session_wrap, const void* buffer, size_t size)
        {
-               issl_session* session = reinterpret_cast<issl_session*>(session_wrap);
-               if (session->socket->GetEventMask() & FD_WRITE_WILL_BLOCK)
+               StreamSocket* sock = reinterpret_cast<StreamSocket*>(session_wrap);
+#ifdef _WIN32
+               GnuTLSIOHook* session = static_cast<GnuTLSIOHook*>(sock->GetIOHook());
+#endif
+
+               if (sock->GetEventMask() & FD_WRITE_WILL_BLOCK)
                {
 #ifdef _WIN32
                        gnutls_transport_set_errno(session->sess, EAGAIN);
@@ -783,7 +771,7 @@ info_done_dealloc:
                        return -1;
                }
 
-               int rv = ServerInstance->SE->Send(session->socket, reinterpret_cast<const char *>(buffer), size, 0);
+               int rv = ServerInstance->SE->Send(sock, reinterpret_cast<const char *>(buffer), size, 0);
 
 #ifdef _WIN32
                if (rv < 0)
@@ -798,75 +786,55 @@ info_done_dealloc:
 #endif
 
                if (rv < (int)size)
-                       ServerInstance->SE->ChangeEventMask(session->socket, FD_WRITE_WILL_BLOCK);
+                       ServerInstance->SE->ChangeEventMask(sock, FD_WRITE_WILL_BLOCK);
                return rv;
        }
 
  public:
-       issl_session* sessions;
-
-       GnuTLSIOHook(Module* parent)
-               : SSLIOHook(parent, "ssl/gnutls")
-       {
-               sessions = new issl_session[ServerInstance->SE->GetMaxFds()];
-       }
-
-       ~GnuTLSIOHook()
-       {
-               delete[] sessions;
-       }
-
-       void OnStreamSocketAccept(StreamSocket* user, irc::sockets::sockaddrs* client, irc::sockets::sockaddrs* server) CXX11_OVERRIDE
+       GnuTLSIOHook(IOHookProvider* hookprov, StreamSocket* sock, bool outbound, const reference<GnuTLS::Profile>& sslprofile)
+               : SSLIOHook(hookprov)
+               , sess(NULL)
+               , status(ISSL_NONE)
+               , profile(sslprofile)
        {
-               issl_session* session = &sessions[user->GetFd()];
-
-               /* For STARTTLS: Don't try and init a session on a socket that already has a session */
-               if (session->sess)
-                       return;
-
-               InitSession(user, true);
-       }
-
-       void OnStreamSocketConnect(StreamSocket* user) CXX11_OVERRIDE
-       {
-               InitSession(user, false);
+               InitSession(sock, outbound);
+               sock->AddIOHook(this);
+               Handshake(sock);
        }
 
        void OnStreamSocketClose(StreamSocket* user) CXX11_OVERRIDE
        {
-               CloseSession(&sessions[user->GetFd()]);
+               CloseSession();
        }
 
        int OnStreamSocketRead(StreamSocket* user, std::string& recvq) CXX11_OVERRIDE
        {
-               issl_session* session = &sessions[user->GetFd()];
-
-               if (!session->sess)
+               if (!this->sess)
                {
-                       CloseSession(session);
+                       CloseSession();
                        user->SetError("No SSL session");
                        return -1;
                }
 
-               if (session->status == ISSL_HANDSHAKING_READ || session->status == ISSL_HANDSHAKING_WRITE)
+               if (this->status == ISSL_HANDSHAKING_READ || this->status == ISSL_HANDSHAKING_WRITE)
                {
                        // The handshake isn't finished, try to finish it.
 
-                       if(!Handshake(session, user))
+                       if (!Handshake(user))
                        {
-                               if (session->status != ISSL_CLOSING)
+                               if (this->status != ISSL_CLOSING)
                                        return 0;
                                return -1;
                        }
                }
 
-               // If we resumed the handshake then session->status will be ISSL_HANDSHAKEN.
+               // If we resumed the handshake then this->status will be ISSL_HANDSHAKEN.
 
-               if (session->status == ISSL_HANDSHAKEN)
+               if (this->status == ISSL_HANDSHAKEN)
                {
                        char* buffer = ServerInstance->GetReadBuffer();
                        size_t bufsiz = ServerInstance->Config->NetBufferSize;
-                       int ret = gnutls_record_recv(session->sess, buffer, bufsiz);
+                       int ret = gnutls_record_recv(this->sess, buffer, bufsiz);
                        if (ret > 0)
                        {
                                recvq.append(buffer, ret);
@@ -879,17 +847,17 @@ info_done_dealloc:
                        else if (ret == 0)
                        {
                                user->SetError("Connection closed");
-                               CloseSession(session);
+                               CloseSession();
                                return -1;
                        }
                        else
                        {
                                user->SetError(gnutls_strerror(ret));
-                               CloseSession(session);
+                               CloseSession();
                                return -1;
                        }
                }
-               else if (session->status == ISSL_CLOSING)
+               else if (this->status == ISSL_CLOSING)
                        return -1;
 
                return 0;
@@ -897,29 +865,27 @@ info_done_dealloc:
 
        int OnStreamSocketWrite(StreamSocket* user, std::string& sendq) CXX11_OVERRIDE
        {
-               issl_session* session = &sessions[user->GetFd()];
-
-               if (!session->sess)
+               if (!this->sess)
                {
-                       CloseSession(session);
+                       CloseSession();
                        user->SetError("No SSL session");
                        return -1;
                }
 
-               if (session->status == ISSL_HANDSHAKING_WRITE || session->status == ISSL_HANDSHAKING_READ)
+               if (this->status == ISSL_HANDSHAKING_WRITE || this->status == ISSL_HANDSHAKING_READ)
                {
                        // The handshake isn't finished, try to finish it.
-                       Handshake(session, user);
-                       if (session->status != ISSL_CLOSING)
+                       Handshake(user);
+                       if (this->status != ISSL_CLOSING)
                                return 0;
                        return -1;
                }
 
                int ret = 0;
 
-               if (session->status == ISSL_HANDSHAKEN)
+               if (this->status == ISSL_HANDSHAKEN)
                {
-                       ret = gnutls_record_send(session->sess, sendq.data(), sendq.length());
+                       ret = gnutls_record_send(this->sess, sendq.data(), sendq.length());
 
                        if (ret == (int)sendq.length())
                        {
@@ -940,7 +906,7 @@ info_done_dealloc:
                        else // (ret < 0)
                        {
                                user->SetError(gnutls_strerror(ret));
-                               CloseSession(session);
+                               CloseSession();
                                return -1;
                        }
                }
@@ -948,16 +914,8 @@ info_done_dealloc:
                return 0;
        }
 
-       ssl_cert* GetCertificate(StreamSocket* sock) CXX11_OVERRIDE
-       {
-               int fd = sock->GetFd();
-               issl_session* session = &sessions[fd];
-               return session->cert;
-       }
-
        void TellCiphersAndFingerprint(LocalUser* user)
        {
-               const gnutls_session_t& sess = sessions[user->eh.GetFd()].sess;
                if (sess)
                {
                        std::string text = "*** You are connected using SSL cipher '";
@@ -966,13 +924,14 @@ info_done_dealloc:
                        text.append("-").append(UnknownIfNULL(gnutls_cipher_get_name(gnutls_cipher_get(sess)))).append("-");
                        text.append(UnknownIfNULL(gnutls_mac_get_name(gnutls_mac_get(sess)))).append("'");
 
-                       ssl_cert* cert = sessions[user->eh.GetFd()].cert;
-                       if (!cert->fingerprint.empty())
-                               text += " and your SSL fingerprint is " + cert->fingerprint;
+                       if (!certificate->fingerprint.empty())
+                               text += " and your SSL fingerprint is " + certificate->fingerprint;
 
                        user->WriteNotice(text);
                }
        }
+
+       GnuTLS::Profile* GetProfile() { return profile; }
 };
 
 int GnuTLS::X509Credentials::cert_callback(gnutls_session_t sess, const gnutls_datum_t* req_ca_rdn, int nreqs, const gnutls_pk_algorithm_t* sign_algos, int sign_algos_length, cert_cb_last_param_type* st)
@@ -983,8 +942,8 @@ int GnuTLS::X509Credentials::cert_callback(gnutls_session_t sess, const gnutls_d
        st->cert_type = GNUTLS_CRT_X509;
        st->key_type = GNUTLS_PRIVKEY_X509;
 #endif
-       issl_session* session = reinterpret_cast<issl_session*>(gnutls_transport_get_ptr(sess));
-       GnuTLS::X509Credentials& cred = session->profile->GetX509Credentials();
+       StreamSocket* sock = reinterpret_cast<StreamSocket*>(gnutls_transport_get_ptr(sess));
+       GnuTLS::X509Credentials& cred = static_cast<GnuTLSIOHook*>(sock->GetIOHook())->GetProfile()->GetX509Credentials();
 
        st->ncerts = cred.certs.size();
        st->cert.x509 = cred.certs.raw();
@@ -994,15 +953,41 @@ int GnuTLS::X509Credentials::cert_callback(gnutls_session_t sess, const gnutls_d
        return 0;
 }
 
+class GnuTLSIOHookProvider : public refcountbase, public IOHookProvider
+{
+       reference<GnuTLS::Profile> profile;
+
+ public:
+       GnuTLSIOHookProvider(Module* mod, reference<GnuTLS::Profile>& prof)
+               : IOHookProvider(mod, "ssl/" + prof->GetName(), IOHookProvider::IOH_SSL)
+               , profile(prof)
+       {
+               ServerInstance->Modules->AddService(*this);
+       }
+
+       ~GnuTLSIOHookProvider()
+       {
+               ServerInstance->Modules->DelService(*this);
+       }
+
+       void OnAccept(StreamSocket* sock, irc::sockets::sockaddrs* client, irc::sockets::sockaddrs* server) CXX11_OVERRIDE
+       {
+               new GnuTLSIOHook(this, sock, true, profile);
+       }
+
+       void OnConnect(StreamSocket* sock) CXX11_OVERRIDE
+       {
+               new GnuTLSIOHook(this, sock, false, profile);
+       }
+};
+
 class ModuleSSLGnuTLS : public Module
 {
-       typedef std::vector<reference<GnuTLS::Profile> > ProfileList;
+       typedef std::vector<reference<GnuTLSIOHookProvider> > ProfileList;
 
        // First member of the class, gets constructed first and destructed last
        GnuTLS::Init libinit;
 
-       GnuTLSIOHook iohook;
-
        std::string sslports;
 
        RandGen randhandler;
@@ -1026,7 +1011,7 @@ class ModuleSSLGnuTLS : public Module
                        try
                        {
                                reference<GnuTLS::Profile> profile(GnuTLS::Profile::Create(defname, tag));
-                               newprofiles.push_back(profile);
+                               newprofiles.push_back(new GnuTLSIOHookProvider(this, profile));
                        }
                        catch (CoreException& ex)
                        {
@@ -1057,7 +1042,7 @@ class ModuleSSLGnuTLS : public Module
                                throw ModuleException("Error while initializing SSL profile \"" + name + "\" at " + tag->getTagLocation() + " - " + ex.GetReason());
                        }
 
-                       newprofiles.push_back(profile);
+                       newprofiles.push_back(new GnuTLSIOHookProvider(this, profile));
                }
 
                // New profiles are ok, begin using them
@@ -1066,7 +1051,7 @@ class ModuleSSLGnuTLS : public Module
        }
 
  public:
-       ModuleSSLGnuTLS() : iohook(this)
+       ModuleSSLGnuTLS()
        {
 #ifndef GNUTLS_HAS_RND
                gcry_control (GCRYCTL_INITIALIZATION_FINISHED, 0);
@@ -1144,7 +1129,7 @@ class ModuleSSLGnuTLS : public Module
                {
                        LocalUser* user = IS_LOCAL(static_cast<User*>(item));
 
-                       if (user && user->eh.GetIOHook() == &iohook)
+                       if (user && user->eh.GetIOHook() && user->eh.GetIOHook()->prov->creator == this)
                        {
                                // User is using SSL, they're a local user, and they're using one of *our* SSL ports.
                                // Potentially there could be multiple SSL modules loaded at once on different ports.
@@ -1164,27 +1149,11 @@ class ModuleSSLGnuTLS : public Module
                        tokens["SSL"] = sslports;
        }
 
-       void OnHookIO(StreamSocket* user, ListenSocket* lsb) CXX11_OVERRIDE
-       {
-               if (!user->GetIOHook())
-               {
-                       std::string profilename = lsb->bind_tag->getString("ssl");
-                       for (ProfileList::const_iterator i = profiles.begin(); i != profiles.end(); ++i)
-                       {
-                               if ((*i)->GetName() == profilename)
-                               {
-                                       iohook.sessions[user->GetFd()].profile = *i;
-                                       user->AddIOHook(&iohook);
-                                       break;
-                               }
-                       }
-               }
-       }
-
        void OnUserConnect(LocalUser* user) CXX11_OVERRIDE
        {
-               if (user->eh.GetIOHook() == &iohook)
-                       iohook.TellCiphersAndFingerprint(user);
+               IOHook* hook = user->eh.GetIOHook();
+               if (hook && hook->prov->creator == this)
+                       static_cast<GnuTLSIOHook*>(hook)->TellCiphersAndFingerprint(user);
        }
 };
 
index 11f4a365e64965459d61f8046789f31795087798..962350e1c38cf6adfdbc7b40f76c9276f070df74 100644 (file)
@@ -235,26 +235,6 @@ namespace OpenSSL
        };
 }
 
-/** Represents an SSL user's extra data
- */
-class issl_session
-{
-public:
-       SSL* sess;
-       issl_status status;
-       reference<ssl_cert> cert;
-       reference<OpenSSL::Profile> profile;
-
-       bool outbound;
-       bool data_to_write;
-
-       issl_session()
-       {
-               outbound = false;
-               data_to_write = false;
-       }
-};
-
 static int OnVerify(int preverify_ok, X509_STORE_CTX *ctx)
 {
        /* XXX: This will allow self signed certificates.
@@ -272,34 +252,40 @@ static int OnVerify(int preverify_ok, X509_STORE_CTX *ctx)
 class OpenSSLIOHook : public SSLIOHook
 {
  private:
-       bool Handshake(StreamSocket* user, issl_session* session)
+       SSL* sess;
+       issl_status status;
+       const bool outbound;
+       bool data_to_write;
+       reference<OpenSSL::Profile> profile;
+
+       bool Handshake(StreamSocket* user)
        {
                int ret;
 
-               if (session->outbound)
-                       ret = SSL_connect(session->sess);
+               if (outbound)
+                       ret = SSL_connect(sess);
                else
-                       ret = SSL_accept(session->sess);
+                       ret = SSL_accept(sess);
 
                if (ret < 0)
                {
-                       int err = SSL_get_error(session->sess, ret);
+                       int err = SSL_get_error(sess, ret);
 
                        if (err == SSL_ERROR_WANT_READ)
                        {
                                ServerInstance->SE->ChangeEventMask(user, FD_WANT_POLL_READ | FD_WANT_NO_WRITE);
-                               session->status = ISSL_HANDSHAKING;
+                               this->status = ISSL_HANDSHAKING;
                                return true;
                        }
                        else if (err == SSL_ERROR_WANT_WRITE)
                        {
                                ServerInstance->SE->ChangeEventMask(user, FD_WANT_NO_READ | FD_WANT_SINGLE_WRITE);
-                               session->status = ISSL_HANDSHAKING;
+                               this->status = ISSL_HANDSHAKING;
                                return true;
                        }
                        else
                        {
-                               CloseSession(session);
+                               CloseSession();
                        }
 
                        return false;
@@ -307,9 +293,9 @@ class OpenSSLIOHook : public SSLIOHook
                else if (ret > 0)
                {
                        // Handshake complete.
-                       VerifyCertificate(session, user);
+                       VerifyCertificate();
 
-                       session->status = ISSL_OPEN;
+                       status = ISSL_OPEN;
 
                        ServerInstance->SE->ChangeEventMask(user, FD_WANT_POLL_READ | FD_WANT_NO_WRITE | FD_ADD_TRIAL_WRITE);
 
@@ -317,38 +303,35 @@ class OpenSSLIOHook : public SSLIOHook
                }
                else if (ret == 0)
                {
-                       CloseSession(session);
+                       CloseSession();
                        return true;
                }
 
                return true;
        }
 
-       void CloseSession(issl_session* session)
+       void CloseSession()
        {
-               if (session->sess)
+               if (sess)
                {
-                       SSL_shutdown(session->sess);
-                       SSL_free(session->sess);
+                       SSL_shutdown(sess);
+                       SSL_free(sess);
                }
-
-               session->sess = NULL;
-               session->status = ISSL_NONE;
+               sess = NULL;
+               certificate = NULL;
+               status = ISSL_NONE;
                errno = EIO;
        }
 
-       void VerifyCertificate(issl_session* session, StreamSocket* user)
+       void VerifyCertificate()
        {
-               if (!session->sess || !user)
-                       return;
-
                X509* cert;
                ssl_cert* certinfo = new ssl_cert;
-               session->cert = certinfo;
+               this->certificate = certinfo;
                unsigned int n;
                unsigned char md[EVP_MAX_MD_SIZE];
 
-               cert = SSL_get_peer_certificate((SSL*)session->sess);
+               cert = SSL_get_peer_certificate(sess);
 
                if (!cert)
                {
@@ -356,7 +339,7 @@ class OpenSSLIOHook : public SSLIOHook
                        return;
                }
 
-               certinfo->invalid = (SSL_get_verify_result(session->sess) != X509_V_OK);
+               certinfo->invalid = (SSL_get_verify_result(sess) != X509_V_OK);
 
                if (!SelfSigned)
                {
@@ -372,7 +355,7 @@ class OpenSSLIOHook : public SSLIOHook
                certinfo->dn = X509_NAME_oneline(X509_get_subject_name(cert),0,0);
                certinfo->issuer = X509_NAME_oneline(X509_get_issuer_name(cert),0,0);
 
-               if (!X509_digest(cert, session->profile->GetDigest(), md, &n))
+               if (!X509_digest(cert, profile->GetDigest(), md, &n))
                {
                        certinfo->error = "Out of memory generating fingerprint";
                }
@@ -390,129 +373,73 @@ class OpenSSLIOHook : public SSLIOHook
        }
 
  public:
-       issl_session* sessions;
-
-       OpenSSLIOHook(Module* mod)
-               : SSLIOHook(mod, "ssl/openssl")
+       OpenSSLIOHook(IOHookProvider* hookprov, StreamSocket* sock, bool is_outbound, SSL* session, const reference<OpenSSL::Profile>& sslprofile)
+               : SSLIOHook(hookprov)
+               , sess(session)
+               , status(ISSL_NONE)
+               , outbound(is_outbound)
+               , data_to_write(false)
+               , profile(sslprofile)
        {
-               sessions = new issl_session[ServerInstance->SE->GetMaxFds()];
-       }
-
-       ~OpenSSLIOHook()
-       {
-               delete[] sessions;
-       }
-
-       void OnStreamSocketAccept(StreamSocket* user, irc::sockets::sockaddrs* client, irc::sockets::sockaddrs* server) CXX11_OVERRIDE
-       {
-               int fd = user->GetFd();
-
-               issl_session* session = &sessions[fd];
-
-               session->sess = session->profile->CreateServerSession();
-               session->status = ISSL_NONE;
-               session->outbound = false;
-               session->cert = NULL;
-
-               if (session->sess == NULL)
+               if (sess == NULL)
                        return;
+               if (SSL_set_fd(sess, sock->GetFd()) == 0)
+                       throw ModuleException("Can't set fd with SSL_set_fd: " + ConvToStr(sock->GetFd()));
 
-               if (SSL_set_fd(session->sess, fd) == 0)
-               {
-                       ServerInstance->Logs->Log(MODNAME, LOG_DEBUG, "BUG: Can't set fd with SSL_set_fd: %d", fd);
-                       return;
-               }
-
-               Handshake(user, session);
-       }
-
-       void OnStreamSocketConnect(StreamSocket* user) CXX11_OVERRIDE
-       {
-               int fd = user->GetFd();
-               /* Are there any possibilities of an out of range fd? Hope not, but lets be paranoid */
-               if ((fd < 0) || (fd > ServerInstance->SE->GetMaxFds() -1))
-                       return;
-
-               issl_session* session = &sessions[fd];
-
-               session->sess = session->profile->CreateClientSession();
-               session->status = ISSL_NONE;
-               session->outbound = true;
-
-               if (session->sess == NULL)
-                       return;
-
-               if (SSL_set_fd(session->sess, fd) == 0)
-               {
-                       ServerInstance->Logs->Log(MODNAME, LOG_DEBUG, "BUG: Can't set fd with SSL_set_fd: %d", fd);
-                       return;
-               }
-
-               Handshake(user, session);
+               sock->AddIOHook(this);
+               Handshake(sock);
        }
 
        void OnStreamSocketClose(StreamSocket* user) CXX11_OVERRIDE
        {
-               int fd = user->GetFd();
-               /* Are there any possibilities of an out of range fd? Hope not, but lets be paranoid */
-               if ((fd < 0) || (fd > ServerInstance->SE->GetMaxFds() - 1))
-                       return;
-
-               CloseSession(&sessions[fd]);
+               CloseSession();
        }
 
        int OnStreamSocketRead(StreamSocket* user, std::string& recvq) CXX11_OVERRIDE
        {
-               int fd = user->GetFd();
-               /* Are there any possibilities of an out of range fd? Hope not, but lets be paranoid */
-               if ((fd < 0) || (fd > ServerInstance->SE->GetMaxFds() - 1))
-                       return -1;
-
-               issl_session* session = &sessions[fd];
-
-               if (!session->sess)
+               if (!sess)
                {
-                       CloseSession(session);
+                       CloseSession();
                        return -1;
                }
 
-               if (session->status == ISSL_HANDSHAKING)
+               if (status == ISSL_HANDSHAKING)
                {
                        // The handshake isn't finished and it wants to read, try to finish it.
-                       if (!Handshake(user, session))
+                       if (!Handshake(user))
                        {
                                // Couldn't resume handshake.
-                               if (session->status == ISSL_NONE)
+                               if (status == ISSL_NONE)
                                        return -1;
                                return 0;
                        }
                }
 
-               // If we resumed the handshake then session->status will be ISSL_OPEN
+               // If we resumed the handshake then this->status will be ISSL_OPEN
 
-               if (session->status == ISSL_OPEN)
+               if (status == ISSL_OPEN)
                {
                        char* buffer = ServerInstance->GetReadBuffer();
                        size_t bufsiz = ServerInstance->Config->NetBufferSize;
-                       int ret = SSL_read(session->sess, buffer, bufsiz);
+                       int ret = SSL_read(sess, buffer, bufsiz);
 
                        if (ret > 0)
                        {
                                recvq.append(buffer, ret);
-                               if (session->data_to_write)
+                               if (data_to_write)
                                        ServerInstance->SE->ChangeEventMask(user, FD_WANT_POLL_READ | FD_WANT_SINGLE_WRITE);
                                return 1;
                        }
                        else if (ret == 0)
                        {
                                // Client closed connection.
-                               CloseSession(session);
+                               CloseSession();
                                user->SetError("Connection closed");
                                return -1;
                        }
                        else if (ret < 0)
                        {
-                               int err = SSL_get_error(session->sess, ret);
+                               int err = SSL_get_error(sess, ret);
 
                                if (err == SSL_ERROR_WANT_READ)
                                {
@@ -526,7 +453,7 @@ class OpenSSLIOHook : public SSLIOHook
                                }
                                else
                                {
-                                       CloseSession(session);
+                                       CloseSession();
                                        return -1;
                                }
                        }
@@ -537,35 +464,31 @@ class OpenSSLIOHook : public SSLIOHook
 
        int OnStreamSocketWrite(StreamSocket* user, std::string& buffer) CXX11_OVERRIDE
        {
-               int fd = user->GetFd();
-
-               issl_session* session = &sessions[fd];
-
-               if (!session->sess)
+               if (!sess)
                {
-                       CloseSession(session);
+                       CloseSession();
                        return -1;
                }
 
-               session->data_to_write = true;
+               data_to_write = true;
 
-               if (session->status == ISSL_HANDSHAKING)
+               if (status == ISSL_HANDSHAKING)
                {
-                       if (!Handshake(user, session))
+                       if (!Handshake(user))
                        {
                                // Couldn't resume handshake.
-                               if (session->status == ISSL_NONE)
+                               if (status == ISSL_NONE)
                                        return -1;
                                return 0;
                        }
                }
 
-               if (session->status == ISSL_OPEN)
+               if (status == ISSL_OPEN)
                {
-                       int ret = SSL_write(session->sess, buffer.data(), buffer.size());
+                       int ret = SSL_write(sess, buffer.data(), buffer.size());
                        if (ret == (int)buffer.length())
                        {
-                               session->data_to_write = false;
+                               data_to_write = false;
                                ServerInstance->SE->ChangeEventMask(user, FD_WANT_POLL_READ | FD_WANT_NO_WRITE);
                                return 1;
                        }
@@ -577,12 +500,12 @@ class OpenSSLIOHook : public SSLIOHook
                        }
                        else if (ret == 0)
                        {
-                               CloseSession(session);
+                               CloseSession();
                                return -1;
                        }
                        else if (ret < 0)
                        {
-                               int err = SSL_get_error(session->sess, ret);
+                               int err = SSL_get_error(sess, ret);
 
                                if (err == SSL_ERROR_WANT_WRITE)
                                {
@@ -596,7 +519,7 @@ class OpenSSLIOHook : public SSLIOHook
                                }
                                else
                                {
-                                       CloseSession(session);
+                                       CloseSession();
                                        return -1;
                                }
                        }
@@ -604,20 +527,12 @@ class OpenSSLIOHook : public SSLIOHook
                return 0;
        }
 
-       ssl_cert* GetCertificate(StreamSocket* sock) CXX11_OVERRIDE
-       {
-               int fd = sock->GetFd();
-               issl_session* session = &sessions[fd];
-               return session->cert;
-       }
-
        void TellCiphersAndFingerprint(LocalUser* user)
        {
-               issl_session& s = sessions[user->eh.GetFd()];
-               if (s.sess)
+               if (sess)
                {
-                       std::string text = "*** You are connected using SSL cipher '" + std::string(SSL_get_cipher(s.sess)) + "'";
-                       const std::string& fingerprint = s.cert->fingerprint;
+                       std::string text = "*** You are connected using SSL cipher '" + std::string(SSL_get_cipher(sess)) + "'";
+                       const std::string& fingerprint = certificate->fingerprint;
                        if (!fingerprint.empty())
                                text += " and your SSL fingerprint is " + fingerprint;
 
@@ -626,12 +541,39 @@ class OpenSSLIOHook : public SSLIOHook
        }
 };
 
+class OpenSSLIOHookProvider : public refcountbase, public IOHookProvider
+{
+       reference<OpenSSL::Profile> profile;
+
+ public:
+       OpenSSLIOHookProvider(Module* mod, reference<OpenSSL::Profile>& prof)
+               : IOHookProvider(mod, "ssl/" + prof->GetName(), IOHookProvider::IOH_SSL)
+               , profile(prof)
+       {
+               ServerInstance->Modules->AddService(*this);
+       }
+
+       ~OpenSSLIOHookProvider()
+       {
+               ServerInstance->Modules->DelService(*this);
+       }
+
+       void OnAccept(StreamSocket* sock, irc::sockets::sockaddrs* client, irc::sockets::sockaddrs* server) CXX11_OVERRIDE
+       {
+               new OpenSSLIOHook(this, sock, false, profile->CreateServerSession(), profile);
+       }
+
+       void OnConnect(StreamSocket* sock) CXX11_OVERRIDE
+       {
+               new OpenSSLIOHook(this, sock, true, profile->CreateClientSession(), profile);
+       }
+};
+
 class ModuleSSLOpenSSL : public Module
 {
-       typedef std::vector<reference<OpenSSL::Profile> > ProfileList;
+       typedef std::vector<reference<OpenSSLIOHookProvider> > ProfileList;
 
        std::string sslports;
-       OpenSSLIOHook iohook;
        ProfileList profiles;
 
        void ReadProfiles()
@@ -648,7 +590,7 @@ class ModuleSSLOpenSSL : public Module
                        try
                        {
                                reference<OpenSSL::Profile> profile(new OpenSSL::Profile(defname, tag));
-                               newprofiles.push_back(profile);
+                               newprofiles.push_back(new OpenSSLIOHookProvider(this, profile));
                        }
                        catch (OpenSSL::Exception& ex)
                        {
@@ -679,14 +621,14 @@ class ModuleSSLOpenSSL : public Module
                                throw ModuleException("Error while initializing SSL profile \"" + name + "\" at " + tag->getTagLocation() + " - " + ex.GetReason());
                        }
 
-                       newprofiles.push_back(profile);
+                       newprofiles.push_back(new OpenSSLIOHookProvider(this, profile));
                }
 
                profiles.swap(newprofiles);
        }
 
  public:
-       ModuleSSLOpenSSL() : iohook(this)
+       ModuleSSLOpenSSL()
        {
                // Initialize OpenSSL
                SSL_library_init();
@@ -698,24 +640,6 @@ class ModuleSSLOpenSSL : public Module
                ReadProfiles();
        }
 
-       void OnHookIO(StreamSocket* user, ListenSocket* lsb) CXX11_OVERRIDE
-       {
-               if (user->GetIOHook())
-                       return;
-
-               ConfigTag* tag = lsb->bind_tag;
-               std::string profilename = tag->getString("ssl");
-               for (ProfileList::const_iterator i = profiles.begin(); i != profiles.end(); ++i)
-               {
-                       if ((*i)->GetName() == profilename)
-                       {
-                               iohook.sessions[user->GetFd()].profile = *i;
-                               user->AddIOHook(&iohook);
-                               break;
-                       }
-               }
-       }
-
        void ReadConfig(ConfigStatus& status) CXX11_OVERRIDE
        {
                sslports.clear();
@@ -778,8 +702,9 @@ class ModuleSSLOpenSSL : public Module
 
        void OnUserConnect(LocalUser* user) CXX11_OVERRIDE
        {
-               if (user->eh.GetIOHook() == &iohook)
-                       iohook.TellCiphersAndFingerprint(user);
+               IOHook* hook = user->eh.GetIOHook();
+               if (hook && hook->prov->creator == this)
+                       static_cast<OpenSSLIOHook*>(hook)->TellCiphersAndFingerprint(user);
        }
 
        void OnCleanup(int target_type, void* item) CXX11_OVERRIDE
@@ -788,7 +713,7 @@ class ModuleSSLOpenSSL : public Module
                {
                        LocalUser* user = IS_LOCAL((User*)item);
 
-                       if (user && user->eh.GetIOHook() == &iohook)
+                       if (user && user->eh.GetIOHook() && user->eh.GetIOHook()->prov->creator == this)
                        {
                                // User is using SSL, they're a local user, and they're using one of *our* SSL ports.
                                // Potentially there could be multiple SSL modules loaded at once on different ports.
index 735551dffa3ba6dd184a014fa72200b99560d961..d0291b8cc0f98c251db1cec7ac0f16ebcaa05358 100644 (file)
@@ -65,9 +65,8 @@ class HttpServerSocket : public BufferedSocket
        {
                InternalState = HTTP_SERVE_WAIT_REQUEST;
 
-               FOREACH_MOD(OnHookIO, (this, via));
-               if (GetIOHook())
-                       GetIOHook()->OnStreamSocketAccept(this, client, server);
+               if (via->iohookprov)
+                       via->iohookprov->OnAccept(this, client, server);
        }
 
        ~HttpServerSocket()
index 671e102691af8a0bd7a3f3762cd135a4acf7fad7..1782f7e2a47e54c244ffcf72abdf30c9c9955db7 100644 (file)
@@ -677,7 +677,7 @@ void ModuleSpanningTree::OnUnloadModule(Module* mod)
        for (TreeServer::ChildServers::const_iterator i = list.begin(); i != list.end(); ++i)
        {
                TreeSocket* sock = (*i)->GetSocket();
-               if (sock && sock->GetIOHook() && sock->GetIOHook()->creator == mod)
+               if (sock->GetIOHook() && sock->GetIOHook()->prov->creator == mod)
                {
                        sock->SendError("SSL module unloaded");
                        sock->Close();
@@ -687,7 +687,7 @@ void ModuleSpanningTree::OnUnloadModule(Module* mod)
        for (SpanningTreeUtilities::TimeoutList::const_iterator i = Utils->timeoutlist.begin(); i != Utils->timeoutlist.end(); ++i)
        {
                TreeSocket* sock = i->first;
-               if (sock->GetIOHook() && sock->GetIOHook()->creator == mod)
+               if (sock->GetIOHook() && sock->GetIOHook()->prov->creator == mod)
                        sock->Close();
        }
 }
index fa8a94f72bab89b080a8c98b91db50dba35967f1..9c262f1ea469b3aca0881b1ee5ce626692ea8557 100644 (file)
@@ -44,16 +44,7 @@ TreeSocket::TreeSocket(Link* link, Autoconnect* myac, const std::string& ipaddr)
        capab->link = link;
        capab->ac = myac;
        capab->capab_phase = 0;
-       if (!link->Hook.empty())
-       {
-               ServiceProvider* prov = ServerInstance->Modules->FindService(SERVICE_IOHOOK, link->Hook);
-               if (!prov)
-               {
-                       SetError("Could not find hook '" + link->Hook + "' for connection to " + linkID);
-                       return;
-               }
-               AddIOHook(static_cast<IOHook*>(prov));
-       }
+
        DoConnect(ipaddr, link->Port, link->Timeout, link->Bind);
        Utils->timeoutlist[this] = std::pair<std::string, int>(linkID, link->Timeout);
        SendCapabilities(1);
@@ -71,9 +62,8 @@ TreeSocket::TreeSocket(int newfd, ListenSocket* via, irc::sockets::sockaddrs* cl
        capab = new CapabData;
        capab->capab_phase = 0;
 
-       FOREACH_MOD(OnHookIO, (this, via));
-       if (GetIOHook())
-               GetIOHook()->OnStreamSocketAccept(this, client, server);
+       if (via->iohookprov)
+               via->iohookprov->OnAccept(this, client, server);
        SendCapabilities(1);
 
        Utils->timeoutlist[this] = std::pair<std::string, int>(linkID, 30);
@@ -116,6 +106,17 @@ void TreeSocket::OnConnected()
 {
        if (this->LinkState == CONNECTING)
        {
+               if (!capab->link->Hook.empty())
+               {
+                       ServiceProvider* prov = ServerInstance->Modules->FindService(SERVICE_IOHOOK, capab->link->Hook);
+                       if (!prov)
+                       {
+                               SetError("Could not find hook '" + capab->link->Hook + "' for connection to " + linkID);
+                               return;
+                       }
+                       static_cast<IOHookProvider*>(prov)->OnConnect(this);
+               }
+
                ServerInstance->SNO->WriteGlobalSno('l', "Connection to \2%s\2[%s] started.", linkID.c_str(),
                        (capab->link->HiddenFromStats ? "<hidden>" : capab->link->IPAddr.c_str()));
                this->SendCapabilities(1);
index 09c9b4f0fff1b07a33a8949933ecda9fb8d6e96d..d591eed5528264bb5966eaa46358dbc5b76b053d 100644 (file)
@@ -30,10 +30,10 @@ enum
 
 class CommandStartTLS : public SplitCommand
 {
-       dynamic_reference_nocheck<IOHook>& ssl;
+       dynamic_reference_nocheck<IOHookProvider>& ssl;
 
  public:
-       CommandStartTLS(Module* mod, dynamic_reference_nocheck<IOHook>& s)
+       CommandStartTLS(Module* mod, dynamic_reference_nocheck<IOHookProvider>& s)
                : SplitCommand(mod, "STARTTLS")
                , ssl(s)
        {
@@ -71,8 +71,7 @@ class CommandStartTLS : public SplitCommand
                 */
                user->eh.DoWrite();
 
-               user->eh.AddIOHook(*ssl);
-               ssl->OnStreamSocketAccept(&user->eh, NULL, NULL);
+               ssl->OnAccept(&user->eh, NULL, NULL);
 
                return CMD_SUCCESS;
        }
@@ -82,7 +81,7 @@ class ModuleStartTLS : public Module
 {
        CommandStartTLS starttls;
        GenericCap tls;
-       dynamic_reference_nocheck<IOHook> ssl;
+       dynamic_reference_nocheck<IOHookProvider> ssl;
 
  public:
        ModuleStartTLS()
index 4ebed1ccd7117684f61c747729c790ae6899d2a7..c65cd5b27faa870523897695f9b026e685c6f888 100644 (file)
@@ -106,6 +106,8 @@ int InspIRCd::BindPorts(FailedPortList &failed_ports)
                                if ((**n).bind_desc == bind_readable)
                                {
                                        (*n)->bind_tag = tag; // Replace tag, we know addr and port match, but other info (type, ssl) may not
+                                       (*n)->ResetIOHookProvider();
+
                                        skip = true;
                                        old_ports.erase(n);
                                        break;
index 745934fd4467efc78db9c0b325d068aa7f98cc83..29d1f737003c88bf1c8576015a90ae69ae4e3eda 100644 (file)
@@ -62,20 +62,9 @@ void UserManager::AddUser(int socket, ListenSocket* via, irc::sockets::sockaddrs
        }
        UserIOHandler* eh = &New->eh;
 
-       /* Give each of the modules an attempt to hook the user for I/O */
-       FOREACH_MOD(OnHookIO, (eh, via));
-
-       if (eh->GetIOHook())
-       {
-               try
-               {
-                       eh->GetIOHook()->OnStreamSocketAccept(eh, client, server);
-               }
-               catch (CoreException& modexcept)
-               {
-                       ServerInstance->Logs->Log("SOCKET", LOG_DEBUG, "%s threw an exception: %s", modexcept.GetSource().c_str(), modexcept.GetReason().c_str());
-               }
-       }
+       // If this listener has an IO hook provider set then tell it about the connection
+       if (via->iohookprov)
+               via->iohookprov->OnAccept(eh, client, server);
 
        ServerInstance->Logs->Log("USERS", LOG_DEBUG, "New user fd: %d", socket);