]> git.netwichtig.de Git - user/henk/code/inspircd.git/commitdiff
Create SSLIOHook interface that provides GetCertificate()
authorattilamolnar <attilamolnar@hush.com>
Fri, 24 May 2013 17:34:25 +0000 (19:34 +0200)
committerattilamolnar <attilamolnar@hush.com>
Thu, 6 Jun 2013 23:00:10 +0000 (01:00 +0200)
include/iohook.h
include/modules/ssl.h
src/modules/extra/m_ssl_gnutls.cpp
src/modules/extra/m_ssl_openssl.cpp
src/modules/m_sasl.cpp
src/modules/m_spanningtree/hmac.cpp
src/modules/m_sslinfo.cpp

index 87403681d1f530c4925ac86531064f8236eb5804..7c3a0faeef975aa064064409cecbe238291926fb 100644 (file)
@@ -24,8 +24,16 @@ class StreamSocket;
 class IOHook : public ServiceProvider
 {
  public:
-       IOHook(Module* mod, const std::string& Name)
-               : ServiceProvider(mod, Name, SERVICE_IOHOOK) { }
+       enum Type
+       {
+               IOH_UNKNOWN,
+               IOH_SSL
+       };
+
+       const Type type;
+
+       IOHook(Module* mod, const std::string& Name, Type hooktype = IOH_UNKNOWN)
+               : ServiceProvider(mod, Name, SERVICE_IOHOOK), type(hooktype) { }
 
        /** Called immediately after any connection is accepted. This is intended for raw socket
         * processing (e.g. modules which wrap the tcp connection within another library) and provides
index a4512153761af0f11f8cb0010a39db0cc0466a5b..9830b1ca6b24a8b289d03218616470be6feed47f 100644 (file)
@@ -132,20 +132,67 @@ class ssl_cert : public refcountbase
        }
 };
 
-/** Get certificate from a socket (only useful with an SSL module) */
-struct SocketCertificateRequest : public Request
+class SSLIOHook : public IOHook
 {
-       StreamSocket* const sock;
-       ssl_cert* cert;
+ public:
+       SSLIOHook(Module* mod, const std::string& Name)
+               : IOHook(mod, Name, IOHook::IOH_SSL)
+       {
+       }
+
+       /**
+        * Get the client certificate from a socket
+        * @param sock The socket to get the certificate from, must be using this IOHook
+        * @return The SSL client certificate information
+        */
+       virtual ssl_cert* GetCertificate(StreamSocket* sock) = 0;
 
-       SocketCertificateRequest(StreamSocket* ss, Module* Me)
-               : Request(Me, (ss->GetIOHook() ? (Module*)ss->GetIOHook()->creator : NULL), "GET_SSL_CERT"), sock(ss), cert(NULL)
+       /**
+        * Get the fingerprint of a client certificate from a socket
+        * @param sock The socket to get the certificate fingerprint from, must be using this IOHook
+        * @return The fingerprint of the SSL client certificate sent by the peer,
+        * empty if no cert was sent
+        */
+       std::string GetFingerprint(StreamSocket* sock)
        {
-               Send();
+               ssl_cert* cert = GetCertificate(sock);
+               if (cert)
+                       return cert->GetFingerprint();
+               return "";
        }
+};
 
-       std::string GetFingerprint()
+/** Helper functions for obtaining SSL client certificates and key fingerprints
+ * from StreamSockets
+ */
+class SSLClientCert
+{
+ public:
+       /**
+        * Get the client certificate from a socket
+        * @param sock The socket to get the certificate from, the socket does not have to use SSL
+        * @return The SSL client certificate information, NULL if the peer is not using SSL
+        */
+       static ssl_cert* GetCertificate(StreamSocket* sock)
+       {
+               IOHook* iohook = sock->GetIOHook();
+               if ((!iohook) || (iohook->type != IOHook::IOH_SSL))
+                       return NULL;
+
+               SSLIOHook* ssliohook = static_cast<SSLIOHook*>(iohook);
+               return ssliohook->GetCertificate(sock);
+       }
+
+       /**
+        * Get the fingerprint of a client certificate from a socket
+        * @param sock The socket to get the certificate fingerprint from, the
+        * socket does not have to use SSL
+        * @return The key fingerprint from the SSL certificate sent by the peer,
+        * empty if no cert was sent or the peer is not using SSL
+        */
+       static std::string GetFingerprint(StreamSocket* sock)
        {
+               ssl_cert* cert = SSLClientCert::GetCertificate(sock);
                if (cert)
                        return cert->GetFingerprint();
                return "";
index e051b34e77d9f8581567026c59d948eb8c40c0a1..3c82a5bebd8292dc02197ff90eae2b6e35c6395d 100644 (file)
@@ -100,7 +100,7 @@ public:
        issl_session() : socket(NULL), sess(NULL) {}
 };
 
-class GnuTLSIOHook : public IOHook
+class GnuTLSIOHook : public SSLIOHook
 {
  private:
        void InitSession(StreamSocket* user, bool me_server)
@@ -359,7 +359,7 @@ info_done_dealloc:
        int dh_bits;
 
        GnuTLSIOHook(Module* parent)
-               : IOHook(parent, "ssl/gnutls")
+               : SSLIOHook(parent, "ssl/gnutls")
        {
                sessions = new issl_session[ServerInstance->SE->GetMaxFds()];
        }
@@ -501,6 +501,13 @@ info_done_dealloc:
                return 0;
        }
 
+       ssl_cert* GetCertificate(StreamSocket* sock) CXX11_OVERRIDE
+       {
+               int fd = sock->GetFd();
+               issl_session* session = &sessions[fd];
+               return session->cert;
+       }
+
        void TellCiphersAndFingerprint(LocalUser* user)
        {
                const gnutls_session_t& sess = sessions[user->eh.GetFd()].sess;
@@ -895,18 +902,6 @@ class ModuleSSLGnuTLS : public Module
                }
        }
 
-       void OnRequest(Request& request) CXX11_OVERRIDE
-       {
-               if (strcmp("GET_SSL_CERT", request.id) == 0)
-               {
-                       SocketCertificateRequest& req = static_cast<SocketCertificateRequest&>(request);
-                       int fd = req.sock->GetFd();
-                       issl_session* session = &iohook.sessions[fd];
-
-                       req.cert = session->cert;
-               }
-       }
-
        void OnUserConnect(LocalUser* user) CXX11_OVERRIDE
        {
                if (user->eh.GetIOHook() == &iohook)
index 0c7362e6e09fd268d1748929aba9af0b30530d52..53c0ab8750b4102d57f0b14b4e09d85637bbd6b1 100644 (file)
@@ -101,7 +101,7 @@ static int OnVerify(int preverify_ok, X509_STORE_CTX *ctx)
        return 1;
 }
 
-class OpenSSLIOHook : public IOHook
+class OpenSSLIOHook : public SSLIOHook
 {
  private:
        bool Handshake(StreamSocket* user, issl_session* session)
@@ -229,7 +229,7 @@ class OpenSSLIOHook : public IOHook
        bool use_sha;
 
        OpenSSLIOHook(Module* mod)
-               : IOHook(mod, "ssl/openssl")
+               : SSLIOHook(mod, "ssl/openssl")
        {
                sessions = new issl_session[ServerInstance->SE->GetMaxFds()];
        }
@@ -440,6 +440,13 @@ class OpenSSLIOHook : public IOHook
                return 0;
        }
 
+       ssl_cert* GetCertificate(StreamSocket* sock) CXX11_OVERRIDE
+       {
+               int fd = sock->GetFd();
+               issl_session* session = &sessions[fd];
+               return session->cert;
+       }
+
        void TellCiphersAndFingerprint(LocalUser* user)
        {
                issl_session& s = sessions[user->eh.GetFd()];
@@ -653,18 +660,6 @@ class ModuleSSLOpenSSL : public Module
        {
                return Version("Provides SSL support for clients", VF_VENDOR);
        }
-
-       void OnRequest(Request& request) CXX11_OVERRIDE
-       {
-               if (strcmp("GET_SSL_CERT", request.id) == 0)
-               {
-                       SocketCertificateRequest& req = static_cast<SocketCertificateRequest&>(request);
-                       int fd = req.sock->GetFd();
-                       issl_session* session = &iohook.sessions[fd];
-
-                       req.cert = session->cert;
-               }
-       }
 };
 
 static int error_callback(const char *str, size_t len, void *u)
index 322a726ce2814971f16e1572d96296ff726475a5..45915ab4d1bd9ce588772d459be2b481191f7602 100644 (file)
@@ -63,10 +63,10 @@ class SaslAuthenticator
                params.push_back("S");
                params.push_back(method);
 
-               if (method == "EXTERNAL" && IS_LOCAL(user_))
+               LocalUser* localuser = IS_LOCAL(user);
+               if (method == "EXTERNAL" && localuser)
                {
-                       SocketCertificateRequest req(&((LocalUser*)user_)->eh, ServerInstance->Modules->Find("m_sasl.so"));
-                       std::string fp = req.GetFingerprint();
+                       std::string fp = SSLClientCert::GetFingerprint(&localuser->eh);
 
                        if (fp.size())
                                params.push_back(fp);
index ad632dbc7057c3c4bcf491619cfd71cff12f590f..0b96f9b26fcf88ee0b66ae2561bd2b0c459dfc03 100644 (file)
@@ -69,16 +69,6 @@ bool TreeSocket::ComparePass(const Link& link, const std::string &theirs)
        capab->auth_fingerprint = !link.Fingerprint.empty();
        capab->auth_challenge = !capab->ourchallenge.empty() && !capab->theirchallenge.empty();
 
-       std::string fp;
-       if (GetIOHook())
-       {
-               SocketCertificateRequest req(this, Utils->Creator);
-               if (req.cert)
-               {
-                       fp = req.cert->GetFingerprint();
-               }
-       }
-
        if (capab->auth_challenge)
        {
                std::string our_hmac = MakePass(link.RecvPass, capab->ourchallenge);
@@ -94,6 +84,7 @@ bool TreeSocket::ComparePass(const Link& link, const std::string &theirs)
                        return false;
        }
 
+       std::string fp = SSLClientCert::GetFingerprint(this);
        if (capab->auth_fingerprint)
        {
                /* Require fingerprint to exist and match */
index 8cdaa1cde0c0cad53ae194ae71af5b13edfcef13..5516af7ef04a7a000692e047722b4383fa966b47 100644 (file)
@@ -191,10 +191,9 @@ class ModuleSSLInfo : public Module
 
        void OnUserConnect(LocalUser* user) CXX11_OVERRIDE
        {
-               SocketCertificateRequest req(&user->eh, this);
-               if (!req.cert)
-                       return;
-               cmd.CertExt.set(user, req.cert);
+               ssl_cert* cert = SSLClientCert::GetCertificate(&user->eh);
+               if (cert)
+                       cmd.CertExt.set(user, cert);
        }
 
        void OnPostConnect(User* user) CXX11_OVERRIDE
@@ -214,15 +213,15 @@ class ModuleSSLInfo : public Module
 
        ModResult OnSetConnectClass(LocalUser* user, ConnectClass* myclass) CXX11_OVERRIDE
        {
-               SocketCertificateRequest req(&user->eh, this);
+               ssl_cert* cert = SSLClientCert::GetCertificate(&user->eh);
                bool ok = true;
                if (myclass->config->getString("requiressl") == "trusted")
                {
-                       ok = (req.cert && req.cert->IsCAVerified());
+                       ok = (cert && cert->IsCAVerified());
                }
                else if (myclass->config->getBool("requiressl"))
                {
-                       ok = (req.cert != NULL);
+                       ok = (cert != NULL);
                }
 
                if (!ok)