]> git.netwichtig.de Git - user/henk/code/inspircd.git/blobdiff - src/modules/extra/m_ssl_gnutls.cpp
m_ssl_gnutls Move logic that reads data from a session into new class GnuTLS::DataReader
[user/henk/code/inspircd.git] / src / modules / extra / m_ssl_gnutls.cpp
index be642a79d65512187d49c00379662ffe1318691a..667ad56819d482d7f7925b43a441146264ec0349 100644 (file)
 #include <gnutls/gnutls.h>
 #include <gnutls/x509.h>
 
-#if ((GNUTLS_VERSION_MAJOR > 2) || (GNUTLS_VERSION_MAJOR == 2 && GNUTLS_VERSION_MINOR > 9) || (GNUTLS_VERSION_MAJOR == 2 && GNUTLS_VERSION_MINOR == 9 && GNUTLS_VERSION_PATCH >= 8))
+#ifndef GNUTLS_VERSION_NUMBER
+#define GNUTLS_VERSION_NUMBER LIBGNUTLS_VERSION_NUMBER
+#endif
+
+// Check if the GnuTLS library is at least version major.minor.patch
+#define INSPIRCD_GNUTLS_HAS_VERSION(major, minor, patch) (GNUTLS_VERSION_NUMBER >= ((major << 16) | (minor << 8) | patch))
+
+#if INSPIRCD_GNUTLS_HAS_VERSION(2, 9, 8)
 #define GNUTLS_HAS_MAC_GET_ID
 #include <gnutls/crypto.h>
 #endif
 
-#if (GNUTLS_VERSION_MAJOR > 2 || GNUTLS_VERSION_MAJOR == 2 && GNUTLS_VERSION_MINOR > 12)
+#if INSPIRCD_GNUTLS_HAS_VERSION(2, 12, 0)
 # define GNUTLS_HAS_RND
 #else
 # include <gcrypt.h>
 /* $CompileFlags: pkgconfincludes("gnutls","/gnutls/gnutls.h","") eval("print `libgcrypt-config --cflags | tr -d \r` if `pkg-config --modversion gnutls 2>/dev/null | tr -d \r` lt '2.12'") */
 /* $LinkerFlags: rpath("pkg-config --libs gnutls") pkgconflibs("gnutls","/libgnutls.so","-lgnutls") eval("print `libgcrypt-config --libs | tr -d \r` if `pkg-config --modversion gnutls 2>/dev/null | tr -d \r` lt '2.12'") */
 
-#ifndef GNUTLS_VERSION_MAJOR
-#define GNUTLS_VERSION_MAJOR LIBGNUTLS_VERSION_MAJOR
-#define GNUTLS_VERSION_MINOR LIBGNUTLS_VERSION_MINOR
-#define GNUTLS_VERSION_PATCH LIBGNUTLS_VERSION_PATCH
-#endif
-
 // These don't exist in older GnuTLS versions
-#if ((GNUTLS_VERSION_MAJOR > 2) || (GNUTLS_VERSION_MAJOR == 2 && GNUTLS_VERSION_MINOR > 1) || (GNUTLS_VERSION_MAJOR == 2 && GNUTLS_VERSION_MINOR == 1 && GNUTLS_VERSION_PATCH >= 7))
+#if INSPIRCD_GNUTLS_HAS_VERSION(2, 1, 7)
 #define GNUTLS_NEW_PRIO_API
 #endif
 
-#if(GNUTLS_VERSION_MAJOR < 2)
+#if (!INSPIRCD_GNUTLS_HAS_VERSION(2, 0, 0))
 typedef gnutls_certificate_credentials_t gnutls_certificate_credentials;
 typedef gnutls_dh_params_t gnutls_dh_params;
 #endif
 
 enum issl_status { ISSL_NONE, ISSL_HANDSHAKING_READ, ISSL_HANDSHAKING_WRITE, ISSL_HANDSHAKEN, ISSL_CLOSING, ISSL_CLOSED };
 
-#if (GNUTLS_VERSION_MAJOR > 2 || (GNUTLS_VERSION_MAJOR == 2 && GNUTLS_VERSION_MINOR >= 12))
+#if INSPIRCD_GNUTLS_HAS_VERSION(2, 12, 0)
 #define GNUTLS_NEW_CERT_CALLBACK_API
 typedef gnutls_retr2_st cert_cb_last_param_type;
 #else
@@ -450,6 +451,28 @@ namespace GnuTLS
                }
        };
 
+       class DataReader
+       {
+               int retval;
+               char* const buffer;
+
+        public:
+               DataReader(gnutls_session_t sess)
+                       : buffer(ServerInstance->GetReadBuffer())
+               {
+                       // Read data from GnuTLS buffers into ReadBuffer
+                       retval = gnutls_record_recv(sess, buffer, ServerInstance->Config->NetBufferSize);
+               }
+
+               void appendto(std::string& recvq)
+               {
+                       // Copy data from ReadBuffer to recvq
+                       recvq.append(buffer, retval);
+               }
+
+               int ret() const { return retval; }
+       };
+
        class Profile : public refcountbase
        {
                /** Name of this profile
@@ -847,12 +870,11 @@ info_done_dealloc:
 
                if (this->status == ISSL_HANDSHAKEN)
                {
-                       char* buffer = ServerInstance->GetReadBuffer();
-                       size_t bufsiz = ServerInstance->Config->NetBufferSize;
-                       int ret = gnutls_record_recv(this->sess, buffer, bufsiz);
+                       GnuTLS::DataReader reader(sess);
+                       int ret = reader.ret();
                        if (ret > 0)
                        {
-                               recvq.append(buffer, ret);
+                               reader.appendto(recvq);
                                return 1;
                        }
                        else if (ret == GNUTLS_E_AGAIN || ret == GNUTLS_E_INTERRUPTED)
@@ -934,11 +956,8 @@ info_done_dealloc:
                if (sess)
                {
                        std::string text = "*** You are connected using SSL cipher '";
-
-                       text += UnknownIfNULL(gnutls_kx_get_name(gnutls_kx_get(sess)));
-                       text.append("-").append(UnknownIfNULL(gnutls_cipher_get_name(gnutls_cipher_get(sess)))).append("-");
-                       text.append(UnknownIfNULL(gnutls_mac_get_name(gnutls_mac_get(sess)))).append("'");
-
+                       GetCiphersuite(text);
+                       text += '\'';
                        if (!certificate->fingerprint.empty())
                                text += " and your SSL certificate fingerprint is " + certificate->fingerprint;
 
@@ -946,6 +965,14 @@ info_done_dealloc:
                }
        }
 
+       void GetCiphersuite(std::string& out) const
+       {
+               out.append(UnknownIfNULL(gnutls_protocol_get_name(gnutls_protocol_get_version(sess)))).push_back('-');
+               out.append(UnknownIfNULL(gnutls_kx_get_name(gnutls_kx_get(sess)))).push_back('-');
+               out.append(UnknownIfNULL(gnutls_cipher_get_name(gnutls_cipher_get(sess)))).push_back('-');
+               out.append(UnknownIfNULL(gnutls_mac_get_name(gnutls_mac_get(sess))));
+       }
+
        GnuTLS::Profile* GetProfile() { return profile; }
 };