]> git.netwichtig.de Git - user/henk/code/inspircd.git/blobdiff - src/modules/extra/m_ssl_mbedtls.cpp
Update copyright headers.
[user/henk/code/inspircd.git] / src / modules / extra / m_ssl_mbedtls.cpp
index 8578b8196e593ee5eb8b0f2dfa1e3971015a0334..84c507cf8978ef65af8fe52329cc95bcf3ae7b08 100644 (file)
@@ -1,7 +1,9 @@
 /*
  * InspIRCd -- Internet Relay Chat Daemon
  *
- *   Copyright (C) 2016 Attila Molnar <attilamolnar@hush.com>
+ *   Copyright (C) 2019 Matt Schatz <genius3000@g3k.solutions>
+ *   Copyright (C) 2016-2019 Sadie Powell <sadie@witchery.services>
+ *   Copyright (C) 2016-2017 Attila Molnar <attilamolnar@hush.com>
  *
  * This file is part of InspIRCd.  InspIRCd is free software: you can
  * redistribute it and/or modify it under the terms of the GNU General Public
  * along with this program.  If not, see <http://www.gnu.org/licenses/>.
  */
 
+/// $LinkerFlags: -lmbedtls
+
+/// $PackageInfo: require_system("arch") mbedtls
+/// $PackageInfo: require_system("darwin") mbedtls
+/// $PackageInfo: require_system("debian" "9.0") libmbedtls-dev
+/// $PackageInfo: require_system("ubuntu" "16.04") libmbedtls-dev
 
-/* $LinkerFlags: -lmbedtls */
 
 #include "inspircd.h"
 #include "modules/ssl.h"
@@ -257,7 +264,6 @@ namespace mbedTLS
                        mbedtls_debug_set_threshold(INT_MAX);
                        mbedtls_ssl_conf_dbg(&conf, DebugLogFunc, NULL);
 #endif
-                       mbedtls_ssl_conf_authmode(&conf, MBEDTLS_SSL_VERIFY_OPTIONAL);
 
                        // TODO: check ret of mbedtls_ssl_config_defaults
                        mbedtls_ssl_config_defaults(&conf, endpoint, MBEDTLS_SSL_TRANSPORT_STREAM, MBEDTLS_SSL_PRESET_DEFAULT);
@@ -308,6 +314,11 @@ namespace mbedTLS
                        mbedtls_ssl_conf_ca_chain(&conf, certs.get(), crl.get());
                }
 
+               void SetOptionalVerifyCert()
+               {
+                       mbedtls_ssl_conf_authmode(&conf, MBEDTLS_SSL_VERIFY_OPTIONAL);
+               }
+
                const mbedtls_ssl_config* GetConf() const { return &conf; }
        };
 
@@ -337,7 +348,7 @@ namespace mbedTLS
                }
        };
 
-       class Profile : public refcountbase
+       class Profile
        {
                /** Name of this profile
                 */
@@ -370,28 +381,71 @@ namespace mbedTLS
                 */
                const unsigned int outrecsize;
 
-               Profile(const std::string& profilename, const std::string& certstr, const std::string& keystr,
-                               const std::string& dhstr, unsigned int mindh, const std::string& hashstr,
-                               const std::string& ciphersuitestr, const std::string& curvestr,
-                               const std::string& castr, const std::string& crlstr,
-                               unsigned int recsize,
-                               CTRDRBG& ctrdrbg,
-                               int minver, int maxver
-                               )
-                       : name(profilename)
-                       , x509cred(certstr, keystr)
-                       , ciphersuites(ciphersuitestr)
-                       , curves(curvestr)
-                       , serverctx(ctrdrbg, MBEDTLS_SSL_IS_SERVER)
-                       , clientctx(ctrdrbg, MBEDTLS_SSL_IS_CLIENT)
-                       , cacerts(castr, true)
-                       , crl(crlstr)
-                       , hash(hashstr)
-                       , outrecsize(recsize)
+        public:
+               struct Config
+               {
+                       const std::string name;
+
+                       CTRDRBG& ctrdrbg;
+
+                       const std::string certstr;
+                       const std::string keystr;
+                       const std::string dhstr;
+
+                       const std::string ciphersuitestr;
+                       const std::string curvestr;
+                       const unsigned int mindh;
+                       const std::string hashstr;
+
+                       std::string crlstr;
+                       std::string castr;
+
+                       const int minver;
+                       const int maxver;
+                       const unsigned int outrecsize;
+                       const bool requestclientcert;
+
+                       Config(const std::string& profilename, ConfigTag* tag, CTRDRBG& ctr_drbg)
+                               : name(profilename)
+                               , ctrdrbg(ctr_drbg)
+                               , certstr(ReadFile(tag->getString("certfile", "cert.pem")))
+                               , keystr(ReadFile(tag->getString("keyfile", "key.pem")))
+                               , dhstr(ReadFile(tag->getString("dhfile", "dhparams.pem")))
+                               , ciphersuitestr(tag->getString("ciphersuites"))
+                               , curvestr(tag->getString("curves"))
+                               , mindh(tag->getUInt("mindhbits", 2048))
+                               , hashstr(tag->getString("hash", "sha256"))
+                               , castr(tag->getString("cafile"))
+                               , minver(tag->getUInt("minver", 0))
+                               , maxver(tag->getUInt("maxver", 0))
+                               , outrecsize(tag->getUInt("outrecsize", 2048, 512, 16384))
+                               , requestclientcert(tag->getBool("requestclientcert", true))
+                       {
+                               if (!castr.empty())
+                               {
+                                       castr = ReadFile(castr);
+                                       crlstr = tag->getString("crlfile");
+                                       if (!crlstr.empty())
+                                               crlstr = ReadFile(crlstr);
+                               }
+                       }
+               };
+
+               Profile(Config& config)
+                       : name(config.name)
+                       , x509cred(config.certstr, config.keystr)
+                       , ciphersuites(config.ciphersuitestr)
+                       , curves(config.curvestr)
+                       , serverctx(config.ctrdrbg, MBEDTLS_SSL_IS_SERVER)
+                       , clientctx(config.ctrdrbg, MBEDTLS_SSL_IS_CLIENT)
+                       , cacerts(config.castr, true)
+                       , crl(config.crlstr)
+                       , hash(config.hashstr)
+                       , outrecsize(config.outrecsize)
                {
                        serverctx.SetX509CertAndKey(x509cred);
                        clientctx.SetX509CertAndKey(x509cred);
-                       clientctx.SetMinDHBits(mindh);
+                       clientctx.SetMinDHBits(config.mindh);
 
                        if (!ciphersuites.empty())
                        {
@@ -405,16 +459,23 @@ namespace mbedTLS
                                clientctx.SetCurves(curves);
                        }
 
-                       serverctx.SetVersion(minver, maxver);
-                       clientctx.SetVersion(minver, maxver);
+                       serverctx.SetVersion(config.minver, config.maxver);
+                       clientctx.SetVersion(config.minver, config.maxver);
 
-                       if (!dhstr.empty())
+                       if (!config.dhstr.empty())
                        {
-                               dhparams.set(dhstr);
+                               dhparams.set(config.dhstr);
                                serverctx.SetDHParams(dhparams);
                        }
 
-                       serverctx.SetCA(cacerts, crl);
+                       clientctx.SetOptionalVerifyCert();
+                       clientctx.SetCA(cacerts, crl);
+                       // The default for servers is to not request a client certificate from the peer
+                       if (config.requestclientcert)
+                       {
+                               serverctx.SetOptionalVerifyCert();
+                               serverctx.SetCA(cacerts, crl);
+                       }
                }
 
                static std::string ReadFile(const std::string& filename)
@@ -426,34 +487,6 @@ namespace mbedTLS
                        return ret;
                }
 
-        public:
-               static reference<Profile> Create(const std::string& profilename, ConfigTag* tag, CTRDRBG& ctr_drbg)
-               {
-                       const std::string certstr = ReadFile(tag->getString("certfile", "cert.pem"));
-                       const std::string keystr = ReadFile(tag->getString("keyfile", "key.pem"));
-                       const std::string dhstr = ReadFile(tag->getString("dhfile", "dhparams.pem"));
-
-                       const std::string ciphersuitestr = tag->getString("ciphersuites");
-                       const std::string curvestr = tag->getString("curves");
-                       unsigned int mindh = tag->getInt("mindhbits", 2048);
-                       std::string hashstr = tag->getString("hash", "sha256");
-
-                       std::string crlstr;
-                       std::string castr = tag->getString("cafile");
-                       if (!castr.empty())
-                       {
-                               castr = ReadFile(castr);
-                               crlstr = tag->getString("crlfile");
-                               if (!crlstr.empty())
-                                       crlstr = ReadFile(crlstr);
-                       }
-
-                       int minver = tag->getInt("minver");
-                       int maxver = tag->getInt("maxver");
-                       unsigned int outrecsize = tag->getInt("outrecsize", 2048, 512, 16384);
-                       return new Profile(profilename, certstr, keystr, dhstr, mindh, hashstr, ciphersuitestr, curvestr, castr, crlstr, outrecsize, ctr_drbg, minver, maxver);
-               }
-
                /** Set up the given session with the settings in this profile
                 */
                void SetupClientSession(mbedtls_ssl_context* sess)
@@ -484,7 +517,6 @@ class mbedTLSIOHook : public SSLIOHook
 
        mbedtls_ssl_context sess;
        Status status;
-       reference<mbedTLS::Profile> profile;
 
        void CloseSession()
        {
@@ -558,7 +590,7 @@ class mbedTLSIOHook : public SSLIOHook
                }
 
                // If there is a certificate we can always generate a fingerprint
-               certificate->fingerprint = profile->GetHash().hash(cert->raw.p, cert->raw.len);
+               certificate->fingerprint = GetProfile().GetHash().hash(cert->raw.p, cert->raw.len);
 
                // At this point mbedTLS verified the cert already, we just need to check the results
                const uint32_t flags = mbedtls_ssl_get_verify_result(&sess);
@@ -632,16 +664,15 @@ class mbedTLSIOHook : public SSLIOHook
        }
 
  public:
-       mbedTLSIOHook(IOHookProvider* hookprov, StreamSocket* sock, bool isserver, mbedTLS::Profile* sslprofile)
+       mbedTLSIOHook(IOHookProvider* hookprov, StreamSocket* sock, bool isserver)
                : SSLIOHook(hookprov)
                , status(ISSL_NONE)
-               , profile(sslprofile)
        {
                mbedtls_ssl_init(&sess);
                if (isserver)
-                       profile->SetupServerSession(&sess);
+                       GetProfile().SetupServerSession(&sess);
                else
-                       profile->SetupClientSession(&sess);
+                       GetProfile().SetupClientSession(&sess);
 
                mbedtls_ssl_set_bio(&sess, reinterpret_cast<void*>(sock), Push, Pull, NULL);
 
@@ -698,7 +729,7 @@ class mbedTLSIOHook : public SSLIOHook
                }
        }
 
-       int OnStreamSocketWrite(StreamSocket* sock) CXX11_OVERRIDE
+       int OnStreamSocketWrite(StreamSocket* sock, StreamSocket::SendQueue& sendq) CXX11_OVERRIDE
        {
                // Finish handshake if needed
                int prepret = PrepareIO(sock);
@@ -706,10 +737,9 @@ class mbedTLSIOHook : public SSLIOHook
                        return prepret;
 
                // Session is ready for transferring application data
-               StreamSocket::SendQueue& sendq = sock->GetSendQ();
                while (!sendq.empty())
                {
-                       FlattenSendQueue(sendq, profile->GetOutgoingRecordSize());
+                       FlattenSendQueue(sendq, GetProfile().GetOutgoingRecordSize());
                        const StreamSocket::SendQueue::Element& buffer = sendq.front();
                        int ret = mbedtls_ssl_write(&sess, reinterpret_cast<const unsigned char*>(buffer.data()), buffer.length());
                        if (ret == (int)buffer.length())
@@ -766,17 +796,24 @@ class mbedTLSIOHook : public SSLIOHook
                out.append(ciphersuitestr + skip);
        }
 
+       bool GetServerName(std::string& out) const CXX11_OVERRIDE
+       {
+               // TODO: Implement SNI support.
+               return false;
+       }
+
+       mbedTLS::Profile& GetProfile();
        bool IsHandshakeDone() const { return (status == ISSL_HANDSHAKEN); }
 };
 
-class mbedTLSIOHookProvider : public refcountbase, public IOHookProvider
+class mbedTLSIOHookProvider : public IOHookProvider
 {
-       reference<mbedTLS::Profile> profile;
+       mbedTLS::Profile profile;
 
  public:
-       mbedTLSIOHookProvider(Module* mod, mbedTLS::Profile* prof)
-               : IOHookProvider(mod, "ssl/" + prof->GetName(), IOHookProvider::IOH_SSL)
-               , profile(prof)
+       mbedTLSIOHookProvider(Module* mod, mbedTLS::Profile::Config& config)
+               : IOHookProvider(mod, "ssl/" + config.name, IOHookProvider::IOH_SSL)
+               , profile(config)
        {
                ServerInstance->Modules->AddService(*this);
        }
@@ -788,15 +825,23 @@ class mbedTLSIOHookProvider : public refcountbase, public IOHookProvider
 
        void OnAccept(StreamSocket* sock, irc::sockets::sockaddrs* client, irc::sockets::sockaddrs* server) CXX11_OVERRIDE
        {
-               new mbedTLSIOHook(this, sock, true, profile);
+               new mbedTLSIOHook(this, sock, true);
        }
 
        void OnConnect(StreamSocket* sock) CXX11_OVERRIDE
        {
-               new mbedTLSIOHook(this, sock, false, profile);
+               new mbedTLSIOHook(this, sock, false);
        }
+
+       mbedTLS::Profile& GetProfile() { return profile; }
 };
 
+mbedTLS::Profile& mbedTLSIOHook::GetProfile()
+{
+       IOHookProvider* hookprov = prov;
+       return static_cast<mbedTLSIOHookProvider*>(hookprov)->GetProfile();
+}
+
 class ModuleSSLmbedTLS : public Module
 {
        typedef std::vector<reference<mbedTLSIOHookProvider> > ProfileList;
@@ -822,8 +867,8 @@ class ModuleSSLmbedTLS : public Module
 
                        try
                        {
-                               reference<mbedTLS::Profile> profile(mbedTLS::Profile::Create(defname, tag, ctr_drbg));
-                               newprofiles.push_back(new mbedTLSIOHookProvider(this, profile));
+                               mbedTLS::Profile::Config profileconfig(defname, tag, ctr_drbg);
+                               newprofiles.push_back(new mbedTLSIOHookProvider(this, profileconfig));
                        }
                        catch (CoreException& ex)
                        {
@@ -834,7 +879,7 @@ class ModuleSSLmbedTLS : public Module
                for (ConfigIter i = tags.first; i != tags.second; ++i)
                {
                        ConfigTag* tag = i->second;
-                       if (tag->getString("provider") != "mbedtls")
+                       if (!stdalgo::string::equalsci(tag->getString("provider"), "mbedtls"))
                                continue;
 
                        std::string name = tag->getString("name");
@@ -844,21 +889,28 @@ class ModuleSSLmbedTLS : public Module
                                continue;
                        }
 
-                       reference<mbedTLS::Profile> profile;
+                       reference<mbedTLSIOHookProvider> prov;
                        try
                        {
-                               profile = mbedTLS::Profile::Create(name, tag, ctr_drbg);
+                               mbedTLS::Profile::Config profileconfig(name, tag, ctr_drbg);
+                               prov = new mbedTLSIOHookProvider(this, profileconfig);
                        }
                        catch (CoreException& ex)
                        {
                                throw ModuleException("Error while initializing SSL profile \"" + name + "\" at " + tag->getTagLocation() + " - " + ex.GetReason());
                        }
 
-                       newprofiles.push_back(new mbedTLSIOHookProvider(this, profile));
+                       newprofiles.push_back(prov);
                }
 
                // New profiles are ok, begin using them
                // Old profiles are deleted when their refcount drops to zero
+               for (ProfileList::iterator i = profiles.begin(); i != profiles.end(); ++i)
+               {
+                       mbedTLSIOHookProvider& prov = **i;
+                       ServerInstance->Modules.DelService(prov);
+               }
+
                profiles.swap(newprofiles);
        }
 
@@ -876,12 +928,13 @@ class ModuleSSLmbedTLS : public Module
 
        void OnModuleRehash(User* user, const std::string &param) CXX11_OVERRIDE
        {
-               if (param != "ssl")
+               if (!irc::equals(param, "ssl"))
                        return;
 
                try
                {
                        ReadProfiles();
+                       ServerInstance->SNO->WriteToSnoMask('a', "SSL module %s rehashed.", MODNAME);
                }
                catch (ModuleException& ex)
                {
@@ -889,13 +942,13 @@ class ModuleSSLmbedTLS : public Module
                }
        }
 
-       void OnCleanup(int target_type, void* item) CXX11_OVERRIDE
+       void OnCleanup(ExtensionItem::ExtensibleType type, Extensible* item) CXX11_OVERRIDE
        {
-               if (target_type != TYPE_USER)
+               if (type != ExtensionItem::EXT_USER)
                        return;
 
                LocalUser* user = IS_LOCAL(static_cast<User*>(item));
-               if ((user) && (user->eh.GetIOHook()) && (user->eh.GetIOHook()->prov->creator == this))
+               if ((user) && (user->eh.GetModHook(this)))
                {
                        // User is using SSL, they're a local user, and they're using our IOHook.
                        // Potentially there could be multiple SSL modules loaded at once on different ports.
@@ -905,13 +958,9 @@ class ModuleSSLmbedTLS : public Module
 
        ModResult OnCheckReady(LocalUser* user) CXX11_OVERRIDE
        {
-               if ((user->eh.GetIOHook()) && (user->eh.GetIOHook()->prov->creator == this))
-               {
-                       mbedTLSIOHook* iohook = static_cast<mbedTLSIOHook*>(user->eh.GetIOHook());
-                       if (!iohook->IsHandshakeDone())
-                               return MOD_RES_DENY;
-               }
-
+               const mbedTLSIOHook* const iohook = static_cast<mbedTLSIOHook*>(user->eh.GetModHook(this));
+               if ((iohook) && (!iohook->IsHandshakeDone()))
+                       return MOD_RES_DENY;
                return MOD_RES_PASSTHRU;
        }