]> git.netwichtig.de Git - user/henk/code/inspircd.git/blobdiff - src/modules/extra/m_ssl_gnutls.cpp
Merge pull request #1219 from SaberUK/master+directive
[user/henk/code/inspircd.git] / src / modules / extra / m_ssl_gnutls.cpp
index a1c989163ff0fea98012ac1913e3a89ce0964695..a42efa1ab121359f6f1587873a4484e280ec7019 100644 (file)
  * along with this program.  If not, see <http://www.gnu.org/licenses/>.
  */
 
+/// $CompilerFlags: find_compiler_flags("gnutls")
+/// $CompilerFlags: require_version("gnutls" "1.0" "2.12") execute("libgcrypt-config --cflags" "LIBGCRYPT_CXXFLAGS")
+
+/// $LinkerFlags: find_linker_flags("gnutls" "-lgnutls")
+/// $LinkerFlags: require_version("gnutls" "1.0" "2.12") execute("libgcrypt-config --libs" "LIBGCRYPT_LDFLAGS")
+
+/// $PackageInfo: require_system("darwin") gnutls pkg-config
+/// $PackageInfo: require_system("ubuntu" "1.0" "13.10") libgcrypt11-dev
+/// $PackageInfo: require_system("ubuntu" "14.04") gnutls-bin libgnutls-dev pkg-config
 
 #include "inspircd.h"
 #include "modules/ssl.h"
@@ -62,9 +71,6 @@
 # pragma comment(lib, "libgnutls-30.lib")
 #endif
 
-/* $CompileFlags: pkgconfincludes("gnutls","/gnutls/gnutls.h","") eval("print `libgcrypt-config --cflags | tr -d \r` if `pkg-config --modversion gnutls 2>/dev/null | tr -d \r` lt '2.12'") */
-/* $LinkerFlags: rpath("pkg-config --libs gnutls") pkgconflibs("gnutls","/libgnutls.so","-lgnutls") eval("print `libgcrypt-config --libs | tr -d \r` if `pkg-config --modversion gnutls 2>/dev/null | tr -d \r` lt '2.12'") */
-
 // These don't exist in older GnuTLS versions
 #if INSPIRCD_GNUTLS_HAS_VERSION(2, 1, 7)
 #define GNUTLS_NEW_PRIO_API
@@ -101,6 +107,8 @@ typedef gnutls_connection_end_t inspircd_gnutls_session_init_flags_t;
 #define INSPIRCD_GNUTLS_HAS_CORK
 #endif
 
+static Module* thismod;
+
 class RandGen : public HandlerBase2<void, char*, size_t>
 {
  public:
@@ -581,16 +589,21 @@ namespace GnuTLS
                 */
                const unsigned int outrecsize;
 
+               /** True to request a client certificate as a server
+                */
+               const bool requestclientcert;
+
                Profile(const std::string& profilename, const std::string& certstr, const std::string& keystr,
                                std::auto_ptr<DHParams>& DH, unsigned int mindh, const std::string& hashstr,
                                const std::string& priostr, std::auto_ptr<X509CertList>& CA, std::auto_ptr<X509CRL>& CRL,
-                               unsigned int recsize)
+                               unsigned int recsize, bool Requestclientcert)
                        : name(profilename)
                        , x509cred(certstr, keystr)
                        , min_dh_bits(mindh)
                        , hash(hashstr)
                        , priority(priostr)
                        , outrecsize(recsize)
+                       , requestclientcert(Requestclientcert)
                {
                        x509cred.SetDH(DH);
                        x509cred.SetCA(CA, CRL);
@@ -661,7 +674,10 @@ namespace GnuTLS
 #else
                        unsigned int outrecsize = tag->getInt("outrecsize", 2048, 512, 16384);
 #endif
-                       return new Profile(profilename, certstr, keystr, dh, mindh, hashstr, priostr, ca, crl, outrecsize);
+
+                       const bool requestclientcert = tag->getBool("requestclientcert", true);
+
+                       return new Profile(profilename, certstr, keystr, dh, mindh, hashstr, priostr, ca, crl, outrecsize, requestclientcert);
                }
 
                /** Set up the given session with the settings in this profile
@@ -672,8 +688,9 @@ namespace GnuTLS
                        x509cred.SetupSession(sess);
                        gnutls_dh_set_prime_bits(sess, min_dh_bits);
 
-                       // Request client certificate if we are a server, no-op if we're a client
-                       gnutls_certificate_server_set_request(sess, GNUTLS_CERT_REQUEST);
+                       // Request client certificate if enabled and we are a server, no-op if we're a client
+                       if (requestclientcert)
+                               gnutls_certificate_server_set_request(sess, GNUTLS_CERT_REQUEST);
                }
 
                const std::string& GetName() const { return name; }
@@ -917,7 +934,7 @@ info_done_dealloc:
        {
                StreamSocket* sock = reinterpret_cast<StreamSocket*>(session_wrap);
 #ifdef _WIN32
-               GnuTLSIOHook* session = static_cast<GnuTLSIOHook*>(sock->GetIOHook());
+               GnuTLSIOHook* session = static_cast<GnuTLSIOHook*>(sock->GetModHook(thismod));
 #endif
 
                if (sock->GetEventMask() & FD_READ_WILL_BLOCK)
@@ -954,7 +971,7 @@ info_done_dealloc:
        {
                StreamSocket* sock = reinterpret_cast<StreamSocket*>(transportptr);
 #ifdef _WIN32
-               GnuTLSIOHook* session = static_cast<GnuTLSIOHook*>(sock->GetIOHook());
+               GnuTLSIOHook* session = static_cast<GnuTLSIOHook*>(sock->GetModHook(thismod));
 #endif
 
                if (sock->GetEventMask() & FD_WRITE_WILL_BLOCK)
@@ -989,7 +1006,7 @@ info_done_dealloc:
        {
                StreamSocket* sock = reinterpret_cast<StreamSocket*>(session_wrap);
 #ifdef _WIN32
-               GnuTLSIOHook* session = static_cast<GnuTLSIOHook*>(sock->GetIOHook());
+               GnuTLSIOHook* session = static_cast<GnuTLSIOHook*>(sock->GetModHook(thismod));
 #endif
 
                if (sock->GetEventMask() & FD_WRITE_WILL_BLOCK)
@@ -1065,6 +1082,9 @@ info_done_dealloc:
                        if (ret > 0)
                        {
                                reader.appendto(recvq);
+                               // Schedule a read if there is still data in the GnuTLS buffer
+                               if (gnutls_record_check_pending(sess) > 0)
+                                       SocketEngine::ChangeEventMask(user, FD_ADD_TRIAL_READ);
                                return 1;
                        }
                        else if (ret == GNUTLS_E_AGAIN || ret == GNUTLS_E_INTERRUPTED)
@@ -1086,7 +1106,7 @@ info_done_dealloc:
                }
        }
 
-       int OnStreamSocketWrite(StreamSocket* user) CXX11_OVERRIDE
+       int OnStreamSocketWrite(StreamSocket* user, StreamSocket::SendQueue& sendq) CXX11_OVERRIDE
        {
                // Finish handshake if needed
                int prepret = PrepareIO(user);
@@ -1094,7 +1114,6 @@ info_done_dealloc:
                        return prepret;
 
                // Session is ready for transferring application data
-               StreamSocket::SendQueue& sendq = user->GetSendQ();
 
 #ifdef INSPIRCD_GNUTLS_HAS_CORK
                while (true)
@@ -1173,7 +1192,7 @@ int GnuTLS::X509Credentials::cert_callback(gnutls_session_t sess, const gnutls_d
        st->key_type = GNUTLS_PRIVKEY_X509;
 #endif
        StreamSocket* sock = reinterpret_cast<StreamSocket*>(gnutls_transport_get_ptr(sess));
-       GnuTLS::X509Credentials& cred = static_cast<GnuTLSIOHook*>(sock->GetIOHook())->GetProfile()->GetX509Credentials();
+       GnuTLS::X509Credentials& cred = static_cast<GnuTLSIOHook*>(sock->GetModHook(thismod))->GetProfile()->GetX509Credentials();
 
        st->ncerts = cred.certs.size();
        st->cert.x509 = cred.certs.raw();
@@ -1283,6 +1302,7 @@ class ModuleSSLGnuTLS : public Module
 #ifndef GNUTLS_HAS_RND
                gcry_control (GCRYCTL_INITIALIZATION_FINISHED, 0);
 #endif
+               thismod = this;
        }
 
        void init() CXX11_OVERRIDE
@@ -1318,7 +1338,7 @@ class ModuleSSLGnuTLS : public Module
                {
                        LocalUser* user = IS_LOCAL(static_cast<User*>(item));
 
-                       if (user && user->eh.GetIOHook() && user->eh.GetIOHook()->prov->creator == this)
+                       if ((user) && (user->eh.GetModHook(this)))
                        {
                                // User is using SSL, they're a local user, and they're using one of *our* SSL ports.
                                // Potentially there could be multiple SSL modules loaded at once on different ports.
@@ -1334,13 +1354,9 @@ class ModuleSSLGnuTLS : public Module
 
        ModResult OnCheckReady(LocalUser* user) CXX11_OVERRIDE
        {
-               if ((user->eh.GetIOHook()) && (user->eh.GetIOHook()->prov->creator == this))
-               {
-                       GnuTLSIOHook* iohook = static_cast<GnuTLSIOHook*>(user->eh.GetIOHook());
-                       if (!iohook->IsHandshakeDone())
-                               return MOD_RES_DENY;
-               }
-
+               const GnuTLSIOHook* const iohook = static_cast<GnuTLSIOHook*>(user->eh.GetModHook(this));
+               if ((iohook) && (!iohook->IsHandshakeDone()))
+                       return MOD_RES_DENY;
                return MOD_RES_PASSTHRU;
        }
 };