]> git.netwichtig.de Git - user/henk/code/inspircd.git/commitdiff
Unite SSL service providers and SSL profile classes
authorAttila Molnar <attilamolnar@hush.com>
Tue, 10 Jan 2017 19:21:57 +0000 (20:21 +0100)
committerPeter Powell <petpow@saberuk.com>
Mon, 13 Nov 2017 16:38:30 +0000 (16:38 +0000)
include/iohook.h
src/modules/extra/m_ssl_gnutls.cpp
src/modules/extra/m_ssl_mbedtls.cpp
src/modules/extra/m_ssl_openssl.cpp
src/modules/m_websocket.cpp

index e99316b99dc1110c02183649a69abb653e50e988..9ca17d77e46bdcdf6aef38e081abc0f9cd7adf29 100644 (file)
@@ -21,7 +21,7 @@
 
 class StreamSocket;
 
-class IOHookProvider : public ServiceProvider
+class IOHookProvider : public refcountbase, public ServiceProvider
 {
        const bool middlehook;
 
@@ -69,7 +69,7 @@ class IOHook : public classbase
        /** The IOHookProvider for this hook, contains information about the hook,
         * such as the module providing it and the hook type.
         */
-       IOHookProvider* const prov;
+       reference<IOHookProvider> prov;
 
        /** Constructor
         * @param provider IOHookProvider that creates this object
index 50c847ee4f7e87c4df3208d0b20379c56a5649f4..534c3abbc773f450ba866b6c77dea4e3e4feb999 100644 (file)
@@ -566,7 +566,7 @@ namespace GnuTLS
                int ret() const { return retval; }
        };
 
-       class Profile : public refcountbase
+       class Profile
        {
                /** Name of this profile
                 */
@@ -596,22 +596,6 @@ namespace GnuTLS
                 */
                const bool requestclientcert;
 
-               Profile(const std::string& profilename, const std::string& certstr, const std::string& keystr,
-                               std::auto_ptr<DHParams>& DH, unsigned int mindh, const std::string& hashstr,
-                               const std::string& priostr, std::auto_ptr<X509CertList>& CA, std::auto_ptr<X509CRL>& CRL,
-                               unsigned int recsize, bool Requestclientcert)
-                       : name(profilename)
-                       , x509cred(certstr, keystr)
-                       , min_dh_bits(mindh)
-                       , hash(hashstr)
-                       , priority(priostr)
-                       , outrecsize(recsize)
-                       , requestclientcert(Requestclientcert)
-               {
-                       x509cred.SetDH(DH);
-                       x509cred.SetCA(CA, CRL);
-               }
-
                static std::string ReadFile(const std::string& filename)
                {
                        FileReader reader(filename);
@@ -647,42 +631,66 @@ namespace GnuTLS
                }
 
         public:
-               static reference<Profile> Create(const std::string& profilename, ConfigTag* tag)
+               struct Config
                {
-                       std::string certstr = ReadFile(tag->getString("certfile", "cert.pem"));
-                       std::string keystr = ReadFile(tag->getString("keyfile", "key.pem"));
+                       std::string name;
 
-                       std::auto_ptr<DHParams> dh = DHParams::Import(ReadFile(tag->getString("dhfile", "dhparams.pem")));
-
-                       std::string priostr = GetPrioStr(profilename, tag);
-                       unsigned int mindh = tag->getInt("mindhbits", 1024);
-                       std::string hashstr = tag->getString("hash", "md5");
-
-                       // Load trusted CA and revocation list, if set
                        std::auto_ptr<X509CertList> ca;
                        std::auto_ptr<X509CRL> crl;
-                       std::string filename = tag->getString("cafile");
-                       if (!filename.empty())
-                       {
-                               ca.reset(new X509CertList(ReadFile(filename)));
 
-                               filename = tag->getString("crlfile");
+                       std::string certstr;
+                       std::string keystr;
+                       std::auto_ptr<DHParams> dh;
+
+                       std::string priostr;
+                       unsigned int mindh;
+                       std::string hashstr;
+
+                       unsigned int outrecsize;
+                       bool requestclientcert;
+
+                       Config(const std::string& profilename, ConfigTag* tag)
+                               : name(profilename)
+                               , certstr(ReadFile(tag->getString("certfile", "cert.pem")))
+                               , keystr(ReadFile(tag->getString("keyfile", "key.pem")))
+                               , dh(DHParams::Import(ReadFile(tag->getString("dhfile", "dhparams.pem"))))
+                               , priostr(GetPrioStr(profilename, tag))
+                               , mindh(tag->getInt("mindhbits", 1024))
+                               , hashstr(tag->getString("hash", "md5"))
+                               , requestclientcert(tag->getBool("requestclientcert", true))
+                       {
+                               // Load trusted CA and revocation list, if set
+                               std::string filename = tag->getString("cafile");
                                if (!filename.empty())
-                                       crl.reset(new X509CRL(ReadFile(filename)));
-                       }
+                               {
+                                       ca.reset(new X509CertList(ReadFile(filename)));
+
+                                       filename = tag->getString("crlfile");
+                                       if (!filename.empty())
+                                               crl.reset(new X509CRL(ReadFile(filename)));
+                               }
 
 #ifdef INSPIRCD_GNUTLS_HAS_CORK
-                       // If cork support is available outrecsize represents the (rough) max amount of data we give GnuTLS while corked
-                       unsigned int outrecsize = tag->getInt("outrecsize", 2048, 512);
+                               // If cork support is available outrecsize represents the (rough) max amount of data we give GnuTLS while corked
+                               outrecsize = tag->getInt("outrecsize", 2048, 512);
 #else
-                       unsigned int outrecsize = tag->getInt("outrecsize", 2048, 512, 16384);
+                               outrecsize = tag->getInt("outrecsize", 2048, 512, 16384);
 #endif
+                       }
+               };
 
-                       const bool requestclientcert = tag->getBool("requestclientcert", true);
-
-                       return new Profile(profilename, certstr, keystr, dh, mindh, hashstr, priostr, ca, crl, outrecsize, requestclientcert);
+               Profile(Config& config)
+                       : name(config.name)
+                       , x509cred(config.certstr, config.keystr)
+                       , min_dh_bits(config.mindh)
+                       , hash(config.hashstr)
+                       , priority(config.priostr)
+                       , outrecsize(config.outrecsize)
+                       , requestclientcert(config.requestclientcert)
+               {
+                       x509cred.SetDH(config.dh);
+                       x509cred.SetCA(config.ca, config.crl);
                }
-
                /** Set up the given session with the settings in this profile
                 */
                void SetupSession(gnutls_session_t sess)
@@ -708,7 +716,6 @@ class GnuTLSIOHook : public SSLIOHook
  private:
        gnutls_session_t sess;
        issl_status status;
-       reference<GnuTLS::Profile> profile;
 #ifdef INSPIRCD_GNUTLS_HAS_CORK
        size_t gbuffersize;
 #endif
@@ -855,7 +862,7 @@ class GnuTLSIOHook : public SSLIOHook
                                issuer.clear();
                }
 
-               if ((ret = gnutls_x509_crt_get_fingerprint(cert, profile->GetHash(), digest, &digest_size)) < 0)
+               if ((ret = gnutls_x509_crt_get_fingerprint(cert, GetProfile().GetHash(), digest, &digest_size)) < 0)
                {
                        certinfo->error = gnutls_strerror(ret);
                }
@@ -1043,11 +1050,10 @@ info_done_dealloc:
 #endif // INSPIRCD_GNUTLS_HAS_VECTOR_PUSH
 
  public:
-       GnuTLSIOHook(IOHookProvider* hookprov, StreamSocket* sock, inspircd_gnutls_session_init_flags_t flags, const reference<GnuTLS::Profile>& sslprofile)
+       GnuTLSIOHook(IOHookProvider* hookprov, StreamSocket* sock, inspircd_gnutls_session_init_flags_t flags)
                : SSLIOHook(hookprov)
                , sess(NULL)
                , status(ISSL_NONE)
-               , profile(sslprofile)
 #ifdef INSPIRCD_GNUTLS_HAS_CORK
                , gbuffersize(0)
 #endif
@@ -1060,7 +1066,7 @@ info_done_dealloc:
                gnutls_transport_set_push_function(sess, gnutls_push_wrapper);
 #endif
                gnutls_transport_set_pull_function(sess, gnutls_pull_wrapper);
-               profile->SetupSession(sess);
+               GetProfile().SetupSession(sess);
 
                sock->AddIOHook(this);
                Handshake(sock);
@@ -1132,7 +1138,7 @@ info_done_dealloc:
 
                        // GnuTLS buffer is empty but sendq is not, begin sending data from the sendq
                        gnutls_record_cork(this->sess);
-                       while ((!sendq.empty()) && (gbuffersize < profile->GetOutgoingRecordSize()))
+                       while ((!sendq.empty()) && (gbuffersize < GetProfile().GetOutgoingRecordSize()))
                        {
                                const StreamSocket::SendQueue::Element& elem = sendq.front();
                                gbuffersize += elem.length();
@@ -1150,7 +1156,7 @@ info_done_dealloc:
 
                while (!sendq.empty())
                {
-                       FlattenSendQueue(sendq, profile->GetOutgoingRecordSize());
+                       FlattenSendQueue(sendq, GetProfile().GetOutgoingRecordSize());
                        const StreamSocket::SendQueue::Element& buffer = sendq.front();
                        ret = HandleWriteRet(user, gnutls_record_send(this->sess, buffer.data(), buffer.length()));
 
@@ -1201,7 +1207,7 @@ info_done_dealloc:
                return true;
        }
 
-       GnuTLS::Profile* GetProfile() { return profile; }
+       GnuTLS::Profile& GetProfile();
        bool IsHandshakeDone() const { return (status == ISSL_HANDSHAKEN); }
 };
 
@@ -1214,7 +1220,7 @@ int GnuTLS::X509Credentials::cert_callback(gnutls_session_t sess, const gnutls_d
        st->key_type = GNUTLS_PRIVKEY_X509;
 #endif
        StreamSocket* sock = reinterpret_cast<StreamSocket*>(gnutls_transport_get_ptr(sess));
-       GnuTLS::X509Credentials& cred = static_cast<GnuTLSIOHook*>(sock->GetModHook(thismod))->GetProfile()->GetX509Credentials();
+       GnuTLS::X509Credentials& cred = static_cast<GnuTLSIOHook*>(sock->GetModHook(thismod))->GetProfile().GetX509Credentials();
 
        st->ncerts = cred.certs.size();
        st->cert.x509 = cred.certs.raw();
@@ -1224,14 +1230,14 @@ int GnuTLS::X509Credentials::cert_callback(gnutls_session_t sess, const gnutls_d
        return 0;
 }
 
-class GnuTLSIOHookProvider : public refcountbase, public IOHookProvider
+class GnuTLSIOHookProvider : public IOHookProvider
 {
-       reference<GnuTLS::Profile> profile;
+       GnuTLS::Profile profile;
 
  public:
-       GnuTLSIOHookProvider(Module* mod, reference<GnuTLS::Profile>& prof)
-               : IOHookProvider(mod, "ssl/" + prof->GetName(), IOHookProvider::IOH_SSL)
-               , profile(prof)
+       GnuTLSIOHookProvider(Module* mod, GnuTLS::Profile::Config& config)
+               : IOHookProvider(mod, "ssl/" + config.name, IOHookProvider::IOH_SSL)
+               , profile(config)
        {
                ServerInstance->Modules->AddService(*this);
        }
@@ -1243,15 +1249,23 @@ class GnuTLSIOHookProvider : public refcountbase, public IOHookProvider
 
        void OnAccept(StreamSocket* sock, irc::sockets::sockaddrs* client, irc::sockets::sockaddrs* server) CXX11_OVERRIDE
        {
-               new GnuTLSIOHook(this, sock, GNUTLS_SERVER, profile);
+               new GnuTLSIOHook(this, sock, GNUTLS_SERVER);
        }
 
        void OnConnect(StreamSocket* sock) CXX11_OVERRIDE
        {
-               new GnuTLSIOHook(this, sock, GNUTLS_CLIENT, profile);
+               new GnuTLSIOHook(this, sock, GNUTLS_CLIENT);
        }
+
+       GnuTLS::Profile& GetProfile() { return profile; }
 };
 
+GnuTLS::Profile& GnuTLSIOHook::GetProfile()
+{
+       IOHookProvider* hookprov = prov;
+       return static_cast<GnuTLSIOHookProvider*>(hookprov)->GetProfile();
+}
+
 class ModuleSSLGnuTLS : public Module
 {
        typedef std::vector<reference<GnuTLSIOHookProvider> > ProfileList;
@@ -1278,8 +1292,8 @@ class ModuleSSLGnuTLS : public Module
 
                        try
                        {
-                               reference<GnuTLS::Profile> profile(GnuTLS::Profile::Create(defname, tag));
-                               newprofiles.push_back(new GnuTLSIOHookProvider(this, profile));
+                               GnuTLS::Profile::Config profileconfig(defname, tag);
+                               newprofiles.push_back(new GnuTLSIOHookProvider(this, profileconfig));
                        }
                        catch (CoreException& ex)
                        {
@@ -1300,21 +1314,28 @@ class ModuleSSLGnuTLS : public Module
                                continue;
                        }
 
-                       reference<GnuTLS::Profile> profile;
+                       reference<GnuTLSIOHookProvider> prov;
                        try
                        {
-                               profile = GnuTLS::Profile::Create(name, tag);
+                               GnuTLS::Profile::Config profileconfig(name, tag);
+                               prov = new GnuTLSIOHookProvider(this, profileconfig);
                        }
                        catch (CoreException& ex)
                        {
                                throw ModuleException("Error while initializing SSL profile \"" + name + "\" at " + tag->getTagLocation() + " - " + ex.GetReason());
                        }
 
-                       newprofiles.push_back(new GnuTLSIOHookProvider(this, profile));
+                       newprofiles.push_back(prov);
                }
 
                // New profiles are ok, begin using them
                // Old profiles are deleted when their refcount drops to zero
+               for (ProfileList::iterator i = profiles.begin(); i != profiles.end(); ++i)
+               {
+                       GnuTLSIOHookProvider& prov = **i;
+                       ServerInstance->Modules.DelService(prov);
+               }
+
                profiles.swap(newprofiles);
        }
 
index 4e0032fdcae3d17c9845566f4006203d7c75d0a2..8c15342f22820ceedaa95a66a61f76c0c03b418e 100644 (file)
@@ -345,7 +345,7 @@ namespace mbedTLS
                }
        };
 
-       class Profile : public refcountbase
+       class Profile
        {
                /** Name of this profile
                 */
@@ -378,29 +378,71 @@ namespace mbedTLS
                 */
                const unsigned int outrecsize;
 
-               Profile(const std::string& profilename, const std::string& certstr, const std::string& keystr,
-                               const std::string& dhstr, unsigned int mindh, const std::string& hashstr,
-                               const std::string& ciphersuitestr, const std::string& curvestr,
-                               const std::string& castr, const std::string& crlstr,
-                               unsigned int recsize,
-                               CTRDRBG& ctrdrbg,
-                               int minver, int maxver,
-                               bool requestclientcert
-                               )
-                       : name(profilename)
-                       , x509cred(certstr, keystr)
-                       , ciphersuites(ciphersuitestr)
-                       , curves(curvestr)
-                       , serverctx(ctrdrbg, MBEDTLS_SSL_IS_SERVER)
-                       , clientctx(ctrdrbg, MBEDTLS_SSL_IS_CLIENT)
-                       , cacerts(castr, true)
-                       , crl(crlstr)
-                       , hash(hashstr)
-                       , outrecsize(recsize)
+        public:
+               struct Config
+               {
+                       const std::string name;
+
+                       CTRDRBG& ctrdrbg;
+
+                       const std::string certstr;
+                       const std::string keystr;
+                       const std::string dhstr;
+
+                       const std::string ciphersuitestr;
+                       const std::string curvestr;
+                       const unsigned int mindh;
+                       const std::string hashstr;
+
+                       std::string crlstr;
+                       std::string castr;
+
+                       const int minver;
+                       const int maxver;
+                       const unsigned int outrecsize;
+                       const bool requestclientcert;
+
+                       Config(const std::string& profilename, ConfigTag* tag, CTRDRBG& ctr_drbg)
+                               : name(profilename)
+                               , ctrdrbg(ctr_drbg)
+                               , certstr(ReadFile(tag->getString("certfile", "cert.pem")))
+                               , keystr(ReadFile(tag->getString("keyfile", "key.pem")))
+                               , dhstr(ReadFile(tag->getString("dhfile", "dhparams.pem")))
+                               , ciphersuitestr(tag->getString("ciphersuites"))
+                               , curvestr(tag->getString("curves"))
+                               , mindh(tag->getInt("mindhbits", 2048))
+                               , hashstr(tag->getString("hash", "sha256"))
+                               , castr(tag->getString("cafile"))
+                               , minver(tag->getInt("minver"))
+                               , maxver(tag->getInt("maxver"))
+                               , outrecsize(tag->getInt("outrecsize", 2048, 512, 16384))
+                               , requestclientcert(tag->getBool("requestclientcert", true))
+                       {
+                               if (!castr.empty())
+                               {
+                                       castr = ReadFile(castr);
+                                       crlstr = tag->getString("crlfile");
+                                       if (!crlstr.empty())
+                                               crlstr = ReadFile(crlstr);
+                               }
+                       }
+               };
+
+               Profile(Config& config)
+                       : name(config.name)
+                       , x509cred(config.certstr, config.keystr)
+                       , ciphersuites(config.ciphersuitestr)
+                       , curves(config.curvestr)
+                       , serverctx(config.ctrdrbg, MBEDTLS_SSL_IS_SERVER)
+                       , clientctx(config.ctrdrbg, MBEDTLS_SSL_IS_CLIENT)
+                       , cacerts(config.castr, true)
+                       , crl(config.crlstr)
+                       , hash(config.hashstr)
+                       , outrecsize(config.outrecsize)
                {
                        serverctx.SetX509CertAndKey(x509cred);
                        clientctx.SetX509CertAndKey(x509cred);
-                       clientctx.SetMinDHBits(mindh);
+                       clientctx.SetMinDHBits(config.mindh);
 
                        if (!ciphersuites.empty())
                        {
@@ -414,19 +456,19 @@ namespace mbedTLS
                                clientctx.SetCurves(curves);
                        }
 
-                       serverctx.SetVersion(minver, maxver);
-                       clientctx.SetVersion(minver, maxver);
+                       serverctx.SetVersion(config.minver, config.maxver);
+                       clientctx.SetVersion(config.minver, config.maxver);
 
-                       if (!dhstr.empty())
+                       if (!config.dhstr.empty())
                        {
-                               dhparams.set(dhstr);
+                               dhparams.set(config.dhstr);
                                serverctx.SetDHParams(dhparams);
                        }
 
                        clientctx.SetOptionalVerifyCert();
                        clientctx.SetCA(cacerts, crl);
                        // The default for servers is to not request a client certificate from the peer
-                       if (requestclientcert)
+                       if (config.requestclientcert)
                        {
                                serverctx.SetOptionalVerifyCert();
                                serverctx.SetCA(cacerts, crl);
@@ -442,35 +484,6 @@ namespace mbedTLS
                        return ret;
                }
 
-        public:
-               static reference<Profile> Create(const std::string& profilename, ConfigTag* tag, CTRDRBG& ctr_drbg)
-               {
-                       const std::string certstr = ReadFile(tag->getString("certfile", "cert.pem"));
-                       const std::string keystr = ReadFile(tag->getString("keyfile", "key.pem"));
-                       const std::string dhstr = ReadFile(tag->getString("dhfile", "dhparams.pem"));
-
-                       const std::string ciphersuitestr = tag->getString("ciphersuites");
-                       const std::string curvestr = tag->getString("curves");
-                       unsigned int mindh = tag->getInt("mindhbits", 2048);
-                       std::string hashstr = tag->getString("hash", "sha256");
-
-                       std::string crlstr;
-                       std::string castr = tag->getString("cafile");
-                       if (!castr.empty())
-                       {
-                               castr = ReadFile(castr);
-                               crlstr = tag->getString("crlfile");
-                               if (!crlstr.empty())
-                                       crlstr = ReadFile(crlstr);
-                       }
-
-                       int minver = tag->getInt("minver");
-                       int maxver = tag->getInt("maxver");
-                       unsigned int outrecsize = tag->getInt("outrecsize", 2048, 512, 16384);
-                       const bool requestclientcert = tag->getBool("requestclientcert", true);
-                       return new Profile(profilename, certstr, keystr, dhstr, mindh, hashstr, ciphersuitestr, curvestr, castr, crlstr, outrecsize, ctr_drbg, minver, maxver, requestclientcert);
-               }
-
                /** Set up the given session with the settings in this profile
                 */
                void SetupClientSession(mbedtls_ssl_context* sess)
@@ -501,7 +514,6 @@ class mbedTLSIOHook : public SSLIOHook
 
        mbedtls_ssl_context sess;
        Status status;
-       reference<mbedTLS::Profile> profile;
 
        void CloseSession()
        {
@@ -575,7 +587,7 @@ class mbedTLSIOHook : public SSLIOHook
                }
 
                // If there is a certificate we can always generate a fingerprint
-               certificate->fingerprint = profile->GetHash().hash(cert->raw.p, cert->raw.len);
+               certificate->fingerprint = GetProfile().GetHash().hash(cert->raw.p, cert->raw.len);
 
                // At this point mbedTLS verified the cert already, we just need to check the results
                const uint32_t flags = mbedtls_ssl_get_verify_result(&sess);
@@ -649,16 +661,15 @@ class mbedTLSIOHook : public SSLIOHook
        }
 
  public:
-       mbedTLSIOHook(IOHookProvider* hookprov, StreamSocket* sock, bool isserver, mbedTLS::Profile* sslprofile)
+       mbedTLSIOHook(IOHookProvider* hookprov, StreamSocket* sock, bool isserver)
                : SSLIOHook(hookprov)
                , status(ISSL_NONE)
-               , profile(sslprofile)
        {
                mbedtls_ssl_init(&sess);
                if (isserver)
-                       profile->SetupServerSession(&sess);
+                       GetProfile().SetupServerSession(&sess);
                else
-                       profile->SetupClientSession(&sess);
+                       GetProfile().SetupClientSession(&sess);
 
                mbedtls_ssl_set_bio(&sess, reinterpret_cast<void*>(sock), Push, Pull, NULL);
 
@@ -725,7 +736,7 @@ class mbedTLSIOHook : public SSLIOHook
                // Session is ready for transferring application data
                while (!sendq.empty())
                {
-                       FlattenSendQueue(sendq, profile->GetOutgoingRecordSize());
+                       FlattenSendQueue(sendq, GetProfile().GetOutgoingRecordSize());
                        const StreamSocket::SendQueue::Element& buffer = sendq.front();
                        int ret = mbedtls_ssl_write(&sess, reinterpret_cast<const unsigned char*>(buffer.data()), buffer.length());
                        if (ret == (int)buffer.length())
@@ -788,17 +799,18 @@ class mbedTLSIOHook : public SSLIOHook
                return false;
        }
 
+       mbedTLS::Profile& GetProfile();
        bool IsHandshakeDone() const { return (status == ISSL_HANDSHAKEN); }
 };
 
-class mbedTLSIOHookProvider : public refcountbase, public IOHookProvider
+class mbedTLSIOHookProvider : public IOHookProvider
 {
-       reference<mbedTLS::Profile> profile;
+       mbedTLS::Profile profile;
 
  public:
-       mbedTLSIOHookProvider(Module* mod, mbedTLS::Profile* prof)
-               : IOHookProvider(mod, "ssl/" + prof->GetName(), IOHookProvider::IOH_SSL)
-               , profile(prof)
+       mbedTLSIOHookProvider(Module* mod, mbedTLS::Profile::Config& config)
+               : IOHookProvider(mod, "ssl/" + config.name, IOHookProvider::IOH_SSL)
+               , profile(config)
        {
                ServerInstance->Modules->AddService(*this);
        }
@@ -810,15 +822,23 @@ class mbedTLSIOHookProvider : public refcountbase, public IOHookProvider
 
        void OnAccept(StreamSocket* sock, irc::sockets::sockaddrs* client, irc::sockets::sockaddrs* server) CXX11_OVERRIDE
        {
-               new mbedTLSIOHook(this, sock, true, profile);
+               new mbedTLSIOHook(this, sock, true);
        }
 
        void OnConnect(StreamSocket* sock) CXX11_OVERRIDE
        {
-               new mbedTLSIOHook(this, sock, false, profile);
+               new mbedTLSIOHook(this, sock, false);
        }
+
+       mbedTLS::Profile& GetProfile() { return profile; }
 };
 
+mbedTLS::Profile& mbedTLSIOHook::GetProfile()
+{
+       IOHookProvider* hookprov = prov;
+       return static_cast<mbedTLSIOHookProvider*>(hookprov)->GetProfile();
+}
+
 class ModuleSSLmbedTLS : public Module
 {
        typedef std::vector<reference<mbedTLSIOHookProvider> > ProfileList;
@@ -844,8 +864,8 @@ class ModuleSSLmbedTLS : public Module
 
                        try
                        {
-                               reference<mbedTLS::Profile> profile(mbedTLS::Profile::Create(defname, tag, ctr_drbg));
-                               newprofiles.push_back(new mbedTLSIOHookProvider(this, profile));
+                               mbedTLS::Profile::Config profileconfig(defname, tag, ctr_drbg);
+                               newprofiles.push_back(new mbedTLSIOHookProvider(this, profileconfig));
                        }
                        catch (CoreException& ex)
                        {
@@ -866,21 +886,28 @@ class ModuleSSLmbedTLS : public Module
                                continue;
                        }
 
-                       reference<mbedTLS::Profile> profile;
+                       reference<mbedTLSIOHookProvider> prov;
                        try
                        {
-                               profile = mbedTLS::Profile::Create(name, tag, ctr_drbg);
+                               mbedTLS::Profile::Config profileconfig(name, tag, ctr_drbg);
+                               prov = new mbedTLSIOHookProvider(this, profileconfig);
                        }
                        catch (CoreException& ex)
                        {
                                throw ModuleException("Error while initializing SSL profile \"" + name + "\" at " + tag->getTagLocation() + " - " + ex.GetReason());
                        }
 
-                       newprofiles.push_back(new mbedTLSIOHookProvider(this, profile));
+                       newprofiles.push_back(prov);
                }
 
                // New profiles are ok, begin using them
                // Old profiles are deleted when their refcount drops to zero
+               for (ProfileList::iterator i = profiles.begin(); i != profiles.end(); ++i)
+               {
+                       mbedTLSIOHookProvider& prov = **i;
+                       ServerInstance->Modules.DelService(prov);
+               }
+
                profiles.swap(newprofiles);
        }
 
index 9b7e608a2b32fb4fe7dd55f33365d29dad09d06a..ae5e213b7c119e66b2dd2a611ce38e0c87a3e476 100644 (file)
@@ -240,7 +240,7 @@ namespace OpenSSL
                }
        };
 
-       class Profile : public refcountbase
+       class Profile
        {
                /** Name of this profile
                 */
@@ -459,7 +459,6 @@ class OpenSSLIOHook : public SSLIOHook
        SSL* sess;
        issl_status status;
        bool data_to_write;
-       reference<OpenSSL::Profile> profile;
 
        // Returns 1 if handshake succeeded, 0 if it is still in progress, -1 if it failed
        int Handshake(StreamSocket* user)
@@ -559,7 +558,7 @@ class OpenSSLIOHook : public SSLIOHook
                if (certinfo->issuer.find_first_of("\r\n") != std::string::npos)
                        certinfo->issuer.clear();
 
-               if (!X509_digest(cert, profile->GetDigest(), md, &n))
+               if (!X509_digest(cert, GetProfile().GetDigest(), md, &n))
                {
                        certinfo->error = "Out of memory generating fingerprint";
                }
@@ -580,7 +579,7 @@ class OpenSSLIOHook : public SSLIOHook
        {
                if ((where & SSL_CB_HANDSHAKE_START) && (status == ISSL_OPEN))
                {
-                       if (profile->AllowRenegotiation())
+                       if (GetProfile().AllowRenegotiation())
                                return;
 
                        // The other side is trying to renegotiate, kill the connection and change status
@@ -622,12 +621,11 @@ class OpenSSLIOHook : public SSLIOHook
        friend void StaticSSLInfoCallback(const SSL* ssl, int where, int rc);
 
  public:
-       OpenSSLIOHook(IOHookProvider* hookprov, StreamSocket* sock, SSL* session, const reference<OpenSSL::Profile>& sslprofile)
+       OpenSSLIOHook(IOHookProvider* hookprov, StreamSocket* sock, SSL* session)
                : SSLIOHook(hookprov)
                , sess(session)
                , status(ISSL_NONE)
                , data_to_write(false)
-               , profile(sslprofile)
        {
                // Create BIO instance and store a pointer to the socket in it which will be used by the read and write functions
 #ifdef INSPIRCD_OPENSSL_OPAQUE_BIO
@@ -721,7 +719,7 @@ class OpenSSLIOHook : public SSLIOHook
                while (!sendq.empty())
                {
                        ERR_clear_error();
-                       FlattenSendQueue(sendq, profile->GetOutgoingRecordSize());
+                       FlattenSendQueue(sendq, GetProfile().GetOutgoingRecordSize());
                        const StreamSocket::SendQueue::Element& buffer = sendq.front();
                        int ret = SSL_write(sess, buffer.data(), buffer.size());
 
@@ -790,6 +788,7 @@ class OpenSSLIOHook : public SSLIOHook
        }
 
        bool IsHandshakeDone() const { return (status == ISSL_OPEN); }
+       OpenSSL::Profile& GetProfile();
 };
 
 static void StaticSSLInfoCallback(const SSL* ssl, int where, int rc)
@@ -844,14 +843,14 @@ static int OpenSSL::BIOMethod::read(BIO* bio, char* buffer, int size)
        return ret;
 }
 
-class OpenSSLIOHookProvider : public refcountbase, public IOHookProvider
+class OpenSSLIOHookProvider : public IOHookProvider
 {
-       reference<OpenSSL::Profile> profile;
+       OpenSSL::Profile profile;
 
  public:
-       OpenSSLIOHookProvider(Module* mod, reference<OpenSSL::Profile>& prof)
-               : IOHookProvider(mod, "ssl/" + prof->GetName(), IOHookProvider::IOH_SSL)
-               , profile(prof)
+       OpenSSLIOHookProvider(Module* mod, const std::string& profilename, ConfigTag* tag)
+               : IOHookProvider(mod, "ssl/" + profilename, IOHookProvider::IOH_SSL)
+               , profile(profilename, tag)
        {
                ServerInstance->Modules->AddService(*this);
        }
@@ -863,15 +862,23 @@ class OpenSSLIOHookProvider : public refcountbase, public IOHookProvider
 
        void OnAccept(StreamSocket* sock, irc::sockets::sockaddrs* client, irc::sockets::sockaddrs* server) CXX11_OVERRIDE
        {
-               new OpenSSLIOHook(this, sock, profile->CreateServerSession(), profile);
+               new OpenSSLIOHook(this, sock, profile.CreateServerSession());
        }
 
        void OnConnect(StreamSocket* sock) CXX11_OVERRIDE
        {
-               new OpenSSLIOHook(this, sock, profile->CreateClientSession(), profile);
+               new OpenSSLIOHook(this, sock, profile.CreateClientSession());
        }
+
+       OpenSSL::Profile& GetProfile() { return profile; }
 };
 
+OpenSSL::Profile& OpenSSLIOHook::GetProfile()
+{
+       IOHookProvider* hookprov = prov;
+       return static_cast<OpenSSLIOHookProvider*>(hookprov)->GetProfile();
+}
+
 class ModuleSSLOpenSSL : public Module
 {
        typedef std::vector<reference<OpenSSLIOHookProvider> > ProfileList;
@@ -891,8 +898,7 @@ class ModuleSSLOpenSSL : public Module
 
                        try
                        {
-                               reference<OpenSSL::Profile> profile(new OpenSSL::Profile(defname, tag));
-                               newprofiles.push_back(new OpenSSLIOHookProvider(this, profile));
+                               newprofiles.push_back(new OpenSSLIOHookProvider(this, defname, tag));
                        }
                        catch (OpenSSL::Exception& ex)
                        {
@@ -913,17 +919,23 @@ class ModuleSSLOpenSSL : public Module
                                continue;
                        }
 
-                       reference<OpenSSL::Profile> profile;
+                       reference<OpenSSLIOHookProvider> prov;
                        try
                        {
-                               profile = new OpenSSL::Profile(name, tag);
+                               prov = new OpenSSLIOHookProvider(this, name, tag);
                        }
                        catch (CoreException& ex)
                        {
                                throw ModuleException("Error while initializing SSL profile \"" + name + "\" at " + tag->getTagLocation() + " - " + ex.GetReason());
                        }
 
-                       newprofiles.push_back(new OpenSSLIOHookProvider(this, profile));
+                       newprofiles.push_back(prov);
+               }
+
+               for (ProfileList::iterator i = profiles.begin(); i != profiles.end(); ++i)
+               {
+                       OpenSSLIOHookProvider& prov = **i;
+                       ServerInstance->Modules.DelService(prov);
                }
 
                profiles.swap(newprofiles);
index a7457f7886f47460e422ae64dbf2460d815461a9..12102d2151bd818f392ac82162dbb32dfa117c67 100644 (file)
@@ -376,12 +376,12 @@ void WebSocketHookProvider::OnAccept(StreamSocket* sock, irc::sockets::sockaddrs
 class ModuleWebSocket : public Module
 {
        dynamic_reference_nocheck<HashProvider> hash;
-       WebSocketHookProvider hookprov;
+       reference<WebSocketHookProvider> hookprov;
 
  public:
        ModuleWebSocket()
                : hash(this, "hash/sha1")
-               , hookprov(this)
+               , hookprov(new WebSocketHookProvider(this))
        {
                sha1 = &hash;
        }