]> git.netwichtig.de Git - user/henk/code/inspircd.git/blobdiff - src/modules/extra/m_ssl_gnutls.cpp
Fix m_mysql warning about use of C++11 features on C++03.
[user/henk/code/inspircd.git] / src / modules / extra / m_ssl_gnutls.cpp
index a1c989163ff0fea98012ac1913e3a89ce0964695..e5cb8ee9000ab3c1040db3d0a670ebce9f0c0035 100644 (file)
@@ -101,6 +101,8 @@ typedef gnutls_connection_end_t inspircd_gnutls_session_init_flags_t;
 #define INSPIRCD_GNUTLS_HAS_CORK
 #endif
 
+static Module* thismod;
+
 class RandGen : public HandlerBase2<void, char*, size_t>
 {
  public:
@@ -581,16 +583,21 @@ namespace GnuTLS
                 */
                const unsigned int outrecsize;
 
+               /** True to request a client certificate as a server
+                */
+               const bool requestclientcert;
+
                Profile(const std::string& profilename, const std::string& certstr, const std::string& keystr,
                                std::auto_ptr<DHParams>& DH, unsigned int mindh, const std::string& hashstr,
                                const std::string& priostr, std::auto_ptr<X509CertList>& CA, std::auto_ptr<X509CRL>& CRL,
-                               unsigned int recsize)
+                               unsigned int recsize, bool Requestclientcert)
                        : name(profilename)
                        , x509cred(certstr, keystr)
                        , min_dh_bits(mindh)
                        , hash(hashstr)
                        , priority(priostr)
                        , outrecsize(recsize)
+                       , requestclientcert(Requestclientcert)
                {
                        x509cred.SetDH(DH);
                        x509cred.SetCA(CA, CRL);
@@ -661,7 +668,10 @@ namespace GnuTLS
 #else
                        unsigned int outrecsize = tag->getInt("outrecsize", 2048, 512, 16384);
 #endif
-                       return new Profile(profilename, certstr, keystr, dh, mindh, hashstr, priostr, ca, crl, outrecsize);
+
+                       const bool requestclientcert = tag->getBool("requestclientcert", true);
+
+                       return new Profile(profilename, certstr, keystr, dh, mindh, hashstr, priostr, ca, crl, outrecsize, requestclientcert);
                }
 
                /** Set up the given session with the settings in this profile
@@ -672,8 +682,9 @@ namespace GnuTLS
                        x509cred.SetupSession(sess);
                        gnutls_dh_set_prime_bits(sess, min_dh_bits);
 
-                       // Request client certificate if we are a server, no-op if we're a client
-                       gnutls_certificate_server_set_request(sess, GNUTLS_CERT_REQUEST);
+                       // Request client certificate if enabled and we are a server, no-op if we're a client
+                       if (requestclientcert)
+                               gnutls_certificate_server_set_request(sess, GNUTLS_CERT_REQUEST);
                }
 
                const std::string& GetName() const { return name; }
@@ -917,7 +928,7 @@ info_done_dealloc:
        {
                StreamSocket* sock = reinterpret_cast<StreamSocket*>(session_wrap);
 #ifdef _WIN32
-               GnuTLSIOHook* session = static_cast<GnuTLSIOHook*>(sock->GetIOHook());
+               GnuTLSIOHook* session = static_cast<GnuTLSIOHook*>(sock->GetModHook(thismod));
 #endif
 
                if (sock->GetEventMask() & FD_READ_WILL_BLOCK)
@@ -954,7 +965,7 @@ info_done_dealloc:
        {
                StreamSocket* sock = reinterpret_cast<StreamSocket*>(transportptr);
 #ifdef _WIN32
-               GnuTLSIOHook* session = static_cast<GnuTLSIOHook*>(sock->GetIOHook());
+               GnuTLSIOHook* session = static_cast<GnuTLSIOHook*>(sock->GetModHook(thismod));
 #endif
 
                if (sock->GetEventMask() & FD_WRITE_WILL_BLOCK)
@@ -989,7 +1000,7 @@ info_done_dealloc:
        {
                StreamSocket* sock = reinterpret_cast<StreamSocket*>(session_wrap);
 #ifdef _WIN32
-               GnuTLSIOHook* session = static_cast<GnuTLSIOHook*>(sock->GetIOHook());
+               GnuTLSIOHook* session = static_cast<GnuTLSIOHook*>(sock->GetModHook(thismod));
 #endif
 
                if (sock->GetEventMask() & FD_WRITE_WILL_BLOCK)
@@ -1065,6 +1076,9 @@ info_done_dealloc:
                        if (ret > 0)
                        {
                                reader.appendto(recvq);
+                               // Schedule a read if there is still data in the GnuTLS buffer
+                               if (gnutls_record_check_pending(sess) > 0)
+                                       SocketEngine::ChangeEventMask(user, FD_ADD_TRIAL_READ);
                                return 1;
                        }
                        else if (ret == GNUTLS_E_AGAIN || ret == GNUTLS_E_INTERRUPTED)
@@ -1086,7 +1100,7 @@ info_done_dealloc:
                }
        }
 
-       int OnStreamSocketWrite(StreamSocket* user) CXX11_OVERRIDE
+       int OnStreamSocketWrite(StreamSocket* user, StreamSocket::SendQueue& sendq) CXX11_OVERRIDE
        {
                // Finish handshake if needed
                int prepret = PrepareIO(user);
@@ -1094,7 +1108,6 @@ info_done_dealloc:
                        return prepret;
 
                // Session is ready for transferring application data
-               StreamSocket::SendQueue& sendq = user->GetSendQ();
 
 #ifdef INSPIRCD_GNUTLS_HAS_CORK
                while (true)
@@ -1173,7 +1186,7 @@ int GnuTLS::X509Credentials::cert_callback(gnutls_session_t sess, const gnutls_d
        st->key_type = GNUTLS_PRIVKEY_X509;
 #endif
        StreamSocket* sock = reinterpret_cast<StreamSocket*>(gnutls_transport_get_ptr(sess));
-       GnuTLS::X509Credentials& cred = static_cast<GnuTLSIOHook*>(sock->GetIOHook())->GetProfile()->GetX509Credentials();
+       GnuTLS::X509Credentials& cred = static_cast<GnuTLSIOHook*>(sock->GetModHook(thismod))->GetProfile()->GetX509Credentials();
 
        st->ncerts = cred.certs.size();
        st->cert.x509 = cred.certs.raw();
@@ -1283,6 +1296,7 @@ class ModuleSSLGnuTLS : public Module
 #ifndef GNUTLS_HAS_RND
                gcry_control (GCRYCTL_INITIALIZATION_FINISHED, 0);
 #endif
+               thismod = this;
        }
 
        void init() CXX11_OVERRIDE
@@ -1318,7 +1332,7 @@ class ModuleSSLGnuTLS : public Module
                {
                        LocalUser* user = IS_LOCAL(static_cast<User*>(item));
 
-                       if (user && user->eh.GetIOHook() && user->eh.GetIOHook()->prov->creator == this)
+                       if ((user) && (user->eh.GetModHook(this)))
                        {
                                // User is using SSL, they're a local user, and they're using one of *our* SSL ports.
                                // Potentially there could be multiple SSL modules loaded at once on different ports.
@@ -1334,13 +1348,9 @@ class ModuleSSLGnuTLS : public Module
 
        ModResult OnCheckReady(LocalUser* user) CXX11_OVERRIDE
        {
-               if ((user->eh.GetIOHook()) && (user->eh.GetIOHook()->prov->creator == this))
-               {
-                       GnuTLSIOHook* iohook = static_cast<GnuTLSIOHook*>(user->eh.GetIOHook());
-                       if (!iohook->IsHandshakeDone())
-                               return MOD_RES_DENY;
-               }
-
+               const GnuTLSIOHook* const iohook = static_cast<GnuTLSIOHook*>(user->eh.GetModHook(this));
+               if ((iohook) && (!iohook->IsHandshakeDone()))
+                       return MOD_RES_DENY;
                return MOD_RES_PASSTHRU;
        }
 };