]> git.netwichtig.de Git - user/henk/code/inspircd.git/blobdiff - src/modules/extra/m_ssl_gnutls.cpp
m_ssl_gnutls Replace ISSL_HANDSHAKING_READ/WRITE with a single state
[user/henk/code/inspircd.git] / src / modules / extra / m_ssl_gnutls.cpp
index e4c3128f5f324588c2be35337ff2ff12c1a9d645..a684e59168d95bc7b0215176039c76a2b0e4a7d1 100644 (file)
 #include <gnutls/gnutls.h>
 #include <gnutls/x509.h>
 
-#if ((GNUTLS_VERSION_MAJOR > 2) || (GNUTLS_VERSION_MAJOR == 2 && GNUTLS_VERSION_MINOR > 9) || (GNUTLS_VERSION_MAJOR == 2 && GNUTLS_VERSION_MINOR == 9 && GNUTLS_VERSION_PATCH >= 8))
+#ifndef GNUTLS_VERSION_NUMBER
+#define GNUTLS_VERSION_NUMBER LIBGNUTLS_VERSION_NUMBER
+#endif
+
+// Check if the GnuTLS library is at least version major.minor.patch
+#define INSPIRCD_GNUTLS_HAS_VERSION(major, minor, patch) (GNUTLS_VERSION_NUMBER >= ((major << 16) | (minor << 8) | patch))
+
+#if INSPIRCD_GNUTLS_HAS_VERSION(2, 9, 8)
 #define GNUTLS_HAS_MAC_GET_ID
 #include <gnutls/crypto.h>
 #endif
 
-#if (GNUTLS_VERSION_MAJOR > 2 || GNUTLS_VERSION_MAJOR == 2 && GNUTLS_VERSION_MINOR > 12)
+#if INSPIRCD_GNUTLS_HAS_VERSION(2, 12, 0)
 # define GNUTLS_HAS_RND
 #else
 # include <gcrypt.h>
 /* $CompileFlags: pkgconfincludes("gnutls","/gnutls/gnutls.h","") eval("print `libgcrypt-config --cflags | tr -d \r` if `pkg-config --modversion gnutls 2>/dev/null | tr -d \r` lt '2.12'") */
 /* $LinkerFlags: rpath("pkg-config --libs gnutls") pkgconflibs("gnutls","/libgnutls.so","-lgnutls") eval("print `libgcrypt-config --libs | tr -d \r` if `pkg-config --modversion gnutls 2>/dev/null | tr -d \r` lt '2.12'") */
 
-#ifndef GNUTLS_VERSION_MAJOR
-#define GNUTLS_VERSION_MAJOR LIBGNUTLS_VERSION_MAJOR
-#define GNUTLS_VERSION_MINOR LIBGNUTLS_VERSION_MINOR
-#define GNUTLS_VERSION_PATCH LIBGNUTLS_VERSION_PATCH
-#endif
-
 // These don't exist in older GnuTLS versions
-#if ((GNUTLS_VERSION_MAJOR > 2) || (GNUTLS_VERSION_MAJOR == 2 && GNUTLS_VERSION_MINOR > 1) || (GNUTLS_VERSION_MAJOR == 2 && GNUTLS_VERSION_MINOR == 1 && GNUTLS_VERSION_PATCH >= 7))
+#if INSPIRCD_GNUTLS_HAS_VERSION(2, 1, 7)
 #define GNUTLS_NEW_PRIO_API
 #endif
 
-#if(GNUTLS_VERSION_MAJOR < 2)
+#if (!INSPIRCD_GNUTLS_HAS_VERSION(2, 0, 0))
 typedef gnutls_certificate_credentials_t gnutls_certificate_credentials;
 typedef gnutls_dh_params_t gnutls_dh_params;
 #endif
 
-enum issl_status { ISSL_NONE, ISSL_HANDSHAKING_READ, ISSL_HANDSHAKING_WRITE, ISSL_HANDSHAKEN, ISSL_CLOSING, ISSL_CLOSED };
+enum issl_status { ISSL_NONE, ISSL_HANDSHAKING, ISSL_HANDSHAKEN, ISSL_CLOSING, ISSL_CLOSED };
 
-#if (GNUTLS_VERSION_MAJOR > 2 || (GNUTLS_VERSION_MAJOR == 2 && GNUTLS_VERSION_MINOR >= 12))
+#if INSPIRCD_GNUTLS_HAS_VERSION(2, 12, 0)
 #define GNUTLS_NEW_CERT_CALLBACK_API
 typedef gnutls_retr2_st cert_cb_last_param_type;
 #else
 typedef gnutls_retr_st cert_cb_last_param_type;
 #endif
 
+#if INSPIRCD_GNUTLS_HAS_VERSION(3, 3, 5)
+#define INSPIRCD_GNUTLS_HAS_RECV_PACKET
+#endif
+
 class RandGen : public HandlerBase2<void, char*, size_t>
 {
  public:
@@ -450,6 +455,51 @@ namespace GnuTLS
                }
        };
 
+       class DataReader
+       {
+               int retval;
+#ifdef INSPIRCD_GNUTLS_HAS_RECV_PACKET
+               gnutls_packet_t packet;
+
+        public:
+               DataReader(gnutls_session_t sess)
+               {
+                       // Using the packet API avoids the final copy of the data which GnuTLS does if we supply
+                       // our own buffer. Instead, we get the buffer containing the data from GnuTLS and copy it
+                       // to the recvq directly from there in appendto().
+                       retval = gnutls_record_recv_packet(sess, &packet);
+               }
+
+               void appendto(std::string& recvq)
+               {
+                       // Copy data from GnuTLS buffers to recvq
+                       gnutls_datum_t datum;
+                       gnutls_packet_get(packet, &datum, NULL);
+                       recvq.append(reinterpret_cast<const char*>(datum.data), datum.size);
+
+                       gnutls_packet_deinit(packet);
+               }
+#else
+               char* const buffer;
+
+        public:
+               DataReader(gnutls_session_t sess)
+                       : buffer(ServerInstance->GetReadBuffer())
+               {
+                       // Read data from GnuTLS buffers into ReadBuffer
+                       retval = gnutls_record_recv(sess, buffer, ServerInstance->Config->NetBufferSize);
+               }
+
+               void appendto(std::string& recvq)
+               {
+                       // Copy data from ReadBuffer to recvq
+                       recvq.append(buffer, retval);
+               }
+#endif
+
+               int ret() const { return retval; }
+       };
+
        class Profile : public refcountbase
        {
                /** Name of this profile
@@ -587,17 +637,16 @@ class GnuTLSIOHook : public SSLIOHook
                        if(ret == GNUTLS_E_AGAIN || ret == GNUTLS_E_INTERRUPTED)
                        {
                                // Handshake needs resuming later, read() or write() would have blocked.
+                               this->status = ISSL_HANDSHAKING;
 
                                if (gnutls_record_get_direction(this->sess) == 0)
                                {
                                        // gnutls_handshake() wants to read() again.
-                                       this->status = ISSL_HANDSHAKING_READ;
                                        SocketEngine::ChangeEventMask(user, FD_WANT_POLL_READ | FD_WANT_NO_WRITE);
                                }
                                else
                                {
                                        // gnutls_handshake() wants to write() again.
-                                       this->status = ISSL_HANDSHAKING_WRITE;
                                        SocketEngine::ChangeEventMask(user, FD_WANT_NO_READ | FD_WANT_SINGLE_WRITE);
                                }
                        }
@@ -831,7 +880,7 @@ info_done_dealloc:
                        return -1;
                }
 
-               if (this->status == ISSL_HANDSHAKING_READ || this->status == ISSL_HANDSHAKING_WRITE)
+               if (this->status == ISSL_HANDSHAKING)
                {
                        // The handshake isn't finished, try to finish it.
 
@@ -847,12 +896,11 @@ info_done_dealloc:
 
                if (this->status == ISSL_HANDSHAKEN)
                {
-                       char* buffer = ServerInstance->GetReadBuffer();
-                       size_t bufsiz = ServerInstance->Config->NetBufferSize;
-                       int ret = gnutls_record_recv(this->sess, buffer, bufsiz);
+                       GnuTLS::DataReader reader(sess);
+                       int ret = reader.ret();
                        if (ret > 0)
                        {
-                               recvq.append(buffer, ret);
+                               reader.appendto(recvq);
                                return 1;
                        }
                        else if (ret == GNUTLS_E_AGAIN || ret == GNUTLS_E_INTERRUPTED)
@@ -887,7 +935,7 @@ info_done_dealloc:
                        return -1;
                }
 
-               if (this->status == ISSL_HANDSHAKING_WRITE || this->status == ISSL_HANDSHAKING_READ)
+               if (this->status == ISSL_HANDSHAKING)
                {
                        // The handshake isn't finished, try to finish it.
                        Handshake(user);