]> git.netwichtig.de Git - user/henk/code/inspircd.git/blobdiff - src/modules/extra/m_ssl_gnutls.cpp
Read multiple GnuTLS records in one RawSocketRead operation
[user/henk/code/inspircd.git] / src / modules / extra / m_ssl_gnutls.cpp
index 4ff5a9062ad8cce90f9a7bf0e0d2f1cbde0c6db3..f05a73227044d7f08e84389bc2b146408218ce90 100644 (file)
@@ -73,7 +73,7 @@ class CommandStartTLS : public Command
                 */
                if (user->registered != REG_NONE)
                {
-                       ServerInstance->Users->QuitUser(user, "STARTTLS is not permitted after client registration has started");
+                       user->WriteNumeric(691, "%s :STARTTLS is not permitted after client registration has started", user->nick.c_str());
                }
                else
                {
@@ -81,7 +81,7 @@ class CommandStartTLS : public Command
                        {
                                user->WriteNumeric(670, "%s :STARTTLS successful, go ahead with TLS handshake", user->nick.c_str());
                                user->AddIOHook(Caller);
-                               Caller->OnRawSocketAccept(user->GetFd(), user->GetIPString(), user->GetPort());
+                               Caller->OnRawSocketAccept(user->GetFd(), NULL, NULL);
                        }
                        else
                                user->WriteNumeric(691, "%s :STARTTLS failure", user->nick.c_str());
@@ -255,16 +255,16 @@ class ModuleSSLGnuTLS : public Module
                        cred_alloc = true;
 
                if((ret = gnutls_certificate_allocate_credentials(&x509_cred)) < 0)
-                       ServerInstance->Logs->Log("m_ssl_gnutls",DEFAULT, "m_ssl_gnutls.so: Failed to allocate certificate credentials: %s", gnutls_strerror(ret));
+                       ServerInstance->Logs->Log("m_ssl_gnutls",DEBUG, "m_ssl_gnutls.so: Failed to allocate certificate credentials: %s", gnutls_strerror(ret));
 
                if((ret = gnutls_dh_params_init(&dh_params)) < 0)
-                       ServerInstance->Logs->Log("m_ssl_gnutls",DEFAULT, "m_ssl_gnutls.so: Failed to initialise DH parameters: %s", gnutls_strerror(ret));
+                       ServerInstance->Logs->Log("m_ssl_gnutls",DEBUG, "m_ssl_gnutls.so: Failed to initialise DH parameters: %s", gnutls_strerror(ret));
 
                if((ret =gnutls_certificate_set_x509_trust_file(x509_cred, cafile.c_str(), GNUTLS_X509_FMT_PEM)) < 0)
-                       ServerInstance->Logs->Log("m_ssl_gnutls",DEFAULT, "m_ssl_gnutls.so: Failed to set X.509 trust file '%s': %s", cafile.c_str(), gnutls_strerror(ret));
+                       ServerInstance->Logs->Log("m_ssl_gnutls",DEBUG, "m_ssl_gnutls.so: Failed to set X.509 trust file '%s': %s", cafile.c_str(), gnutls_strerror(ret));
 
                if((ret = gnutls_certificate_set_x509_crl_file (x509_cred, crlfile.c_str(), GNUTLS_X509_FMT_PEM)) < 0)
-                       ServerInstance->Logs->Log("m_ssl_gnutls",DEFAULT, "m_ssl_gnutls.so: Failed to set X.509 CRL file '%s': %s", crlfile.c_str(), gnutls_strerror(ret));
+                       ServerInstance->Logs->Log("m_ssl_gnutls",DEBUG, "m_ssl_gnutls.so: Failed to set X.509 CRL file '%s': %s", crlfile.c_str(), gnutls_strerror(ret));
 
                if((ret = gnutls_certificate_set_x509_key_file (x509_cred, certfile.c_str(), keyfile.c_str(), GNUTLS_X509_FMT_PEM)) < 0)
                {
@@ -347,9 +347,9 @@ class ModuleSSLGnuTLS : public Module
                output.append(" STARTTLS");
        }
 
-       virtual void OnHookUserIO(User* user, const std::string &targetip)
+       virtual void OnHookUserIO(User* user)
        {
-               if (!user->GetIOHook() && isin(targetip,user->GetPort(),listenports))
+               if (!user->GetIOHook() && isin(user->GetServerIP(),user->GetServerPort(),listenports))
                {
                        /* Hook the user with our module */
                        user->AddIOHook(this);
@@ -421,7 +421,7 @@ class ModuleSSLGnuTLS : public Module
        }
 
 
-       virtual void OnRawSocketAccept(int fd, const std::string &ip, int localport)
+       virtual void OnRawSocketAccept(int fd, irc::sockets::sockaddrs* client, irc::sockets::sockaddrs* server)
        {
                /* Are there any possibilities of an out of range fd? Hope not, but lets be paranoid */
                if ((fd < 0) || (fd > ServerInstance->SE->GetMaxFds() - 1))
@@ -520,34 +520,43 @@ class ModuleSSLGnuTLS : public Module
 
                if (session->status == ISSL_HANDSHAKEN)
                {
-                       int ret = gnutls_record_recv(session->sess, buffer, count);
-
-                       if (ret > 0)
+                       unsigned int len = 0;
+                       while (len < count)
                        {
-                               readresult = ret;
+                               int ret = gnutls_record_recv(session->sess, buffer + len, count - len);
+                               if (ret > 0)
+                               {
+                                       len += ret;
+                               }
+                               else if (ret == GNUTLS_E_AGAIN || ret == GNUTLS_E_INTERRUPTED)
+                               {
+                                       break;
+                               }
+                               else
+                               {
+                                       if (ret != 0)
+                                               ServerInstance->Logs->Log("m_ssl_gnutls", DEFAULT,
+                                                       "m_ssl_gnutls.so: Error while reading on fd %d: %s",
+                                                       fd, gnutls_strerror(ret));
+
+                                       // if ret == 0, client closed connection.
+                                       readresult = 0;
+                                       CloseSession(session);
+                                       return 1;
+                               }
                        }
-                       else if (ret == 0)
+                       readresult = len;
+                       if (len)
                        {
-                               // Client closed connection.
-                               readresult = 0;
-                               CloseSession(session);
                                return 1;
                        }
-                       else if (ret == GNUTLS_E_AGAIN || ret == GNUTLS_E_INTERRUPTED)
+                       else
                        {
                                errno = EAGAIN;
                                return -1;
                        }
-                       else
-                       {
-                               ServerInstance->Logs->Log("m_ssl_gnutls", DEFAULT,
-                                               "m_ssl_gnutls.so: Error while reading on fd %d: %s",
-                                               fd, gnutls_strerror(ret));
-                               readresult = 0;
-                               CloseSession(session);
-                       }
                }
-               else if(session->status == ISSL_CLOSING)
+               else if (session->status == ISSL_CLOSING)
                        readresult = 0;
 
                return 1;
@@ -610,7 +619,8 @@ class ModuleSSLGnuTLS : public Module
                        }
                }
 
-               MakePollWrite(fd);
+               if (!session->outbuf.empty())
+                       MakePollWrite(fd);
 
                /* Who's smart idea was it to return 1 when we havent written anything?
                 * This fucks the buffer up in BufferedSocket :p
@@ -679,10 +689,7 @@ class ModuleSSLGnuTLS : public Module
                // protocol module has propagated the NICK message.
                if (user->GetIOHook() == this && (IS_LOCAL(user)))
                {
-                       // Tell whatever protocol module we're using that we need to inform other servers of this metadata NOW.
-                       ServerInstance->PI->SendMetaData(user, TYPE_USER, "ssl", "on");
-
-                       VerifyCertificate(&sessions[user->GetFd()],user);
+                       ssl_cert* certdata = VerifyCertificate(&sessions[user->GetFd()],user);
                        if (sessions[user->GetFd()].sess)
                        {
                                std::string cipher = gnutls_kx_get_name(gnutls_kx_get(sessions[user->GetFd()].sess));
@@ -690,6 +697,10 @@ class ModuleSSLGnuTLS : public Module
                                cipher.append(gnutls_mac_get_name(gnutls_mac_get(sessions[user->GetFd()].sess)));
                                user->WriteServ("NOTICE %s :*** You are connected using SSL cipher \"%s\"", user->nick.c_str(), cipher.c_str());
                        }
+
+                       ServerInstance->PI->SendMetaData(user, TYPE_USER, "ssl", "ON");
+                       if (certdata)
+                               ServerInstance->PI->SendMetaData(user, TYPE_USER, "ssl_cert", certdata->GetMetaLine().c_str());
                }
        }
 
@@ -724,10 +735,10 @@ class ModuleSSLGnuTLS : public Module
                session->status = ISSL_NONE;
        }
 
-       void VerifyCertificate(issl_session* session, Extensible* user)
+       ssl_cert* VerifyCertificate(issl_session* session, Extensible* user)
        {
                if (!session->sess || !user)
-                       return;
+                       return NULL;
 
                unsigned int status;
                const gnutls_datum_t* cert_list;
@@ -750,7 +761,7 @@ class ModuleSSLGnuTLS : public Module
                if (ret < 0)
                {
                        certinfo->error = std::string(gnutls_strerror(ret));
-                       return;
+                       return certinfo;
                }
 
                certinfo->invalid = (status & GNUTLS_CERT_INVALID);
@@ -765,14 +776,14 @@ class ModuleSSLGnuTLS : public Module
                if (gnutls_certificate_type_get(session->sess) != GNUTLS_CRT_X509)
                {
                        certinfo->error = "No X509 keys sent";
-                       return;
+                       return certinfo;
                }
 
                ret = gnutls_x509_crt_init(&cert);
                if (ret < 0)
                {
                        certinfo->error = gnutls_strerror(ret);
-                       return;
+                       return certinfo;
                }
 
                cert_list_size = 0;
@@ -780,7 +791,7 @@ class ModuleSSLGnuTLS : public Module
                if (cert_list == NULL)
                {
                        certinfo->error = "No certificate was found";
-                       return;
+                       return certinfo;
                }
 
                /* This is not a real world example, since we only check the first
@@ -791,7 +802,7 @@ class ModuleSSLGnuTLS : public Module
                if (ret < 0)
                {
                        certinfo->error = gnutls_strerror(ret);
-                       return;
+                       return certinfo;
                }
 
                gnutls_x509_crt_get_dn(cert, name, &name_size);
@@ -818,7 +829,7 @@ class ModuleSSLGnuTLS : public Module
 
                gnutls_x509_crt_deinit(cert);
 
-               return;
+               return certinfo;
        }
 
        void OnEvent(Event* ev)