]> git.netwichtig.de Git - user/henk/code/inspircd.git/blobdiff - src/modules/extra/m_ssl_mbedtls.cpp
m_ssl_mbedtls Apply dummy CA workaround for client context
[user/henk/code/inspircd.git] / src / modules / extra / m_ssl_mbedtls.cpp
index 7efcce72db88371fd4731fe65cebcda0649e6846..ffe0a71b8fd3540fea3e91df55a65b6a230fba63 100644 (file)
@@ -257,7 +257,6 @@ namespace mbedTLS
                        mbedtls_debug_set_threshold(INT_MAX);
                        mbedtls_ssl_conf_dbg(&conf, DebugLogFunc, NULL);
 #endif
-                       mbedtls_ssl_conf_authmode(&conf, MBEDTLS_SSL_VERIFY_OPTIONAL);
 
                        // TODO: check ret of mbedtls_ssl_config_defaults
                        mbedtls_ssl_config_defaults(&conf, endpoint, MBEDTLS_SSL_TRANSPORT_STREAM, MBEDTLS_SSL_PRESET_DEFAULT);
@@ -308,6 +307,11 @@ namespace mbedTLS
                        mbedtls_ssl_conf_ca_chain(&conf, certs.get(), crl.get());
                }
 
+               void SetOptionalVerifyCert()
+               {
+                       mbedtls_ssl_conf_authmode(&conf, MBEDTLS_SSL_VERIFY_OPTIONAL);
+               }
+
                const mbedtls_ssl_config* GetConf() const { return &conf; }
        };
 
@@ -376,7 +380,8 @@ namespace mbedTLS
                                const std::string& castr, const std::string& crlstr,
                                unsigned int recsize,
                                CTRDRBG& ctrdrbg,
-                               int minver, int maxver
+                               int minver, int maxver,
+                               bool requestclientcert
                                )
                        : name(profilename)
                        , x509cred(certstr, keystr)
@@ -414,7 +419,14 @@ namespace mbedTLS
                                serverctx.SetDHParams(dhparams);
                        }
 
-                       serverctx.SetCA(cacerts, crl);
+                       clientctx.SetOptionalVerifyCert();
+                       clientctx.SetCA(cacerts, crl);                  
+                       // The default for servers is to not request a client certificate from the peer
+                       if (requestclientcert)
+                       {
+                               serverctx.SetOptionalVerifyCert();
+                               serverctx.SetCA(cacerts, crl);
+                       }
                }
 
                static std::string ReadFile(const std::string& filename)
@@ -451,7 +463,8 @@ namespace mbedTLS
                        int minver = tag->getInt("minver");
                        int maxver = tag->getInt("maxver");
                        unsigned int outrecsize = tag->getInt("outrecsize", 2048, 512, 16384);
-                       return new Profile(profilename, certstr, keystr, dhstr, mindh, hashstr, ciphersuitestr, curvestr, castr, crlstr, outrecsize, ctr_drbg, minver, maxver);
+                       const bool requestclientcert = tag->getBool("requestclientcert", true);
+                       return new Profile(profilename, certstr, keystr, dhstr, mindh, hashstr, ciphersuitestr, curvestr, castr, crlstr, outrecsize, ctr_drbg, minver, maxver, requestclientcert);
                }
 
                /** Set up the given session with the settings in this profile
@@ -894,7 +907,7 @@ class ModuleSSLmbedTLS : public Module
                        return;
 
                LocalUser* user = IS_LOCAL(static_cast<User*>(item));
-               if ((user) && (user->eh.GetIOHook()) && (user->eh.GetIOHook()->prov->creator == this))
+               if ((user) && (user->eh.GetModHook(this)))
                {
                        // User is using SSL, they're a local user, and they're using our IOHook.
                        // Potentially there could be multiple SSL modules loaded at once on different ports.
@@ -904,13 +917,9 @@ class ModuleSSLmbedTLS : public Module
 
        ModResult OnCheckReady(LocalUser* user) CXX11_OVERRIDE
        {
-               if ((user->eh.GetIOHook()) && (user->eh.GetIOHook()->prov->creator == this))
-               {
-                       mbedTLSIOHook* iohook = static_cast<mbedTLSIOHook*>(user->eh.GetIOHook());
-                       if (!iohook->IsHandshakeDone())
-                               return MOD_RES_DENY;
-               }
-
+               const mbedTLSIOHook* const iohook = static_cast<mbedTLSIOHook*>(user->eh.GetModHook(this));
+               if ((iohook) && (!iohook->IsHandshakeDone()))
+                       return MOD_RES_DENY;
                return MOD_RES_PASSTHRU;
        }