diff options
Diffstat (limited to 'src/modules/extra/m_ssl_gnutls.cpp')
-rw-r--r-- | src/modules/extra/m_ssl_gnutls.cpp | 273 |
1 files changed, 121 insertions, 152 deletions
diff --git a/src/modules/extra/m_ssl_gnutls.cpp b/src/modules/extra/m_ssl_gnutls.cpp index 7c19925dd..2add962fd 100644 --- a/src/modules/extra/m_ssl_gnutls.cpp +++ b/src/modules/extra/m_ssl_gnutls.cpp @@ -543,58 +543,41 @@ namespace GnuTLS }; } -/** Represents an SSL user's extra data - */ -class issl_session +class GnuTLSIOHook : public SSLIOHook { -public: - StreamSocket* socket; + private: gnutls_session_t sess; issl_status status; - reference<ssl_cert> cert; reference<GnuTLS::Profile> profile; - issl_session() : socket(NULL), sess(NULL) {} -}; - -class GnuTLSIOHook : public SSLIOHook -{ - private: void InitSession(StreamSocket* user, bool me_server) { - issl_session* session = &sessions[user->GetFd()]; - - gnutls_init(&session->sess, me_server ? GNUTLS_SERVER : GNUTLS_CLIENT); - session->socket = user; + gnutls_init(&sess, me_server ? GNUTLS_SERVER : GNUTLS_CLIENT); - session->profile->SetupSession(session->sess); - gnutls_transport_set_ptr(session->sess, reinterpret_cast<gnutls_transport_ptr_t>(session)); - gnutls_transport_set_push_function(session->sess, gnutls_push_wrapper); - gnutls_transport_set_pull_function(session->sess, gnutls_pull_wrapper); + profile->SetupSession(sess); + gnutls_transport_set_ptr(sess, reinterpret_cast<gnutls_transport_ptr_t>(user)); + gnutls_transport_set_push_function(sess, gnutls_push_wrapper); + gnutls_transport_set_pull_function(sess, gnutls_pull_wrapper); if (me_server) - gnutls_certificate_server_set_request(session->sess, GNUTLS_CERT_REQUEST); // Request client certificate if any. - - Handshake(session, user); + gnutls_certificate_server_set_request(sess, GNUTLS_CERT_REQUEST); // Request client certificate if any. } - void CloseSession(issl_session* session) + void CloseSession() { - if (session->sess) + if (this->sess) { - gnutls_bye(session->sess, GNUTLS_SHUT_WR); - gnutls_deinit(session->sess); + gnutls_bye(this->sess, GNUTLS_SHUT_WR); + gnutls_deinit(this->sess); } - session->socket = NULL; - session->sess = NULL; - session->cert = NULL; - session->profile = NULL; - session->status = ISSL_NONE; + sess = NULL; + certificate = NULL; + status = ISSL_NONE; } - bool Handshake(issl_session* session, StreamSocket* user) + bool Handshake(StreamSocket* user) { - int ret = gnutls_handshake(session->sess); + int ret = gnutls_handshake(this->sess); if (ret < 0) { @@ -602,24 +585,24 @@ class GnuTLSIOHook : public SSLIOHook { // Handshake needs resuming later, read() or write() would have blocked. - if(gnutls_record_get_direction(session->sess) == 0) + if (gnutls_record_get_direction(this->sess) == 0) { // gnutls_handshake() wants to read() again. - session->status = ISSL_HANDSHAKING_READ; + this->status = ISSL_HANDSHAKING_READ; ServerInstance->SE->ChangeEventMask(user, FD_WANT_POLL_READ | FD_WANT_NO_WRITE); } else { // gnutls_handshake() wants to write() again. - session->status = ISSL_HANDSHAKING_WRITE; + this->status = ISSL_HANDSHAKING_WRITE; ServerInstance->SE->ChangeEventMask(user, FD_WANT_NO_READ | FD_WANT_SINGLE_WRITE); } } else { user->SetError("Handshake Failed - " + std::string(gnutls_strerror(ret))); - CloseSession(session); - session->status = ISSL_CLOSING; + CloseSession(); + this->status = ISSL_CLOSING; } return false; @@ -627,9 +610,9 @@ class GnuTLSIOHook : public SSLIOHook else { // Change the seesion state - session->status = ISSL_HANDSHAKEN; + this->status = ISSL_HANDSHAKEN; - VerifyCertificate(session,user); + VerifyCertificate(); // Finish writing, if any left ServerInstance->SE->ChangeEventMask(user, FD_WANT_POLL_READ | FD_WANT_NO_WRITE | FD_ADD_TRIAL_WRITE); @@ -638,12 +621,9 @@ class GnuTLSIOHook : public SSLIOHook } } - void VerifyCertificate(issl_session* session, StreamSocket* user) + void VerifyCertificate() { - if (!session->sess || !user) - return; - - unsigned int status; + unsigned int certstatus; const gnutls_datum_t* cert_list; int ret; unsigned int cert_list_size; @@ -653,12 +633,12 @@ class GnuTLSIOHook : public SSLIOHook size_t digest_size = sizeof(digest); size_t name_size = sizeof(str); ssl_cert* certinfo = new ssl_cert; - session->cert = certinfo; + this->certificate = certinfo; /* This verification function uses the trusted CAs in the credentials * structure. So you must have installed one or more CA certificates. */ - ret = gnutls_certificate_verify_peers2(session->sess, &status); + ret = gnutls_certificate_verify_peers2(this->sess, &certstatus); if (ret < 0) { @@ -666,16 +646,16 @@ class GnuTLSIOHook : public SSLIOHook return; } - certinfo->invalid = (status & GNUTLS_CERT_INVALID); - certinfo->unknownsigner = (status & GNUTLS_CERT_SIGNER_NOT_FOUND); - certinfo->revoked = (status & GNUTLS_CERT_REVOKED); - certinfo->trusted = !(status & GNUTLS_CERT_SIGNER_NOT_CA); + certinfo->invalid = (certstatus & GNUTLS_CERT_INVALID); + certinfo->unknownsigner = (certstatus & GNUTLS_CERT_SIGNER_NOT_FOUND); + certinfo->revoked = (certstatus & GNUTLS_CERT_REVOKED); + certinfo->trusted = !(certstatus & GNUTLS_CERT_SIGNER_NOT_CA); /* Up to here the process is the same for X.509 certificates and * OpenPGP keys. From now on X.509 certificates are assumed. This can * be easily extended to work with openpgp keys as well. */ - if (gnutls_certificate_type_get(session->sess) != GNUTLS_CRT_X509) + if (gnutls_certificate_type_get(this->sess) != GNUTLS_CRT_X509) { certinfo->error = "No X509 keys sent"; return; @@ -689,7 +669,7 @@ class GnuTLSIOHook : public SSLIOHook } cert_list_size = 0; - cert_list = gnutls_certificate_get_peers(session->sess, &cert_list_size); + cert_list = gnutls_certificate_get_peers(this->sess, &cert_list_size); if (cert_list == NULL) { certinfo->error = "No certificate was found"; @@ -713,7 +693,7 @@ class GnuTLSIOHook : public SSLIOHook gnutls_x509_crt_get_issuer_dn(cert, str, &name_size); certinfo->issuer = str; - if ((ret = gnutls_x509_crt_get_fingerprint(cert, session->profile->GetHash(), digest, &digest_size)) < 0) + if ((ret = gnutls_x509_crt_get_fingerprint(cert, profile->GetHash(), digest, &digest_size)) < 0) { certinfo->error = gnutls_strerror(ret); } @@ -740,8 +720,12 @@ info_done_dealloc: static ssize_t gnutls_pull_wrapper(gnutls_transport_ptr_t session_wrap, void* buffer, size_t size) { - issl_session* session = reinterpret_cast<issl_session*>(session_wrap); - if (session->socket->GetEventMask() & FD_READ_WILL_BLOCK) + StreamSocket* sock = reinterpret_cast<StreamSocket*>(session_wrap); +#ifdef _WIN32 + GnuTLSIOHook* session = static_cast<GnuTLSIOHook*>(sock->GetIOHook()); +#endif + + if (sock->GetEventMask() & FD_READ_WILL_BLOCK) { #ifdef _WIN32 gnutls_transport_set_errno(session->sess, EAGAIN); @@ -751,7 +735,7 @@ info_done_dealloc: return -1; } - int rv = ServerInstance->SE->Recv(session->socket, reinterpret_cast<char *>(buffer), size, 0); + int rv = ServerInstance->SE->Recv(sock, reinterpret_cast<char *>(buffer), size, 0); #ifdef _WIN32 if (rv < 0) @@ -766,14 +750,18 @@ info_done_dealloc: #endif if (rv < (int)size) - ServerInstance->SE->ChangeEventMask(session->socket, FD_READ_WILL_BLOCK); + ServerInstance->SE->ChangeEventMask(sock, FD_READ_WILL_BLOCK); return rv; } static ssize_t gnutls_push_wrapper(gnutls_transport_ptr_t session_wrap, const void* buffer, size_t size) { - issl_session* session = reinterpret_cast<issl_session*>(session_wrap); - if (session->socket->GetEventMask() & FD_WRITE_WILL_BLOCK) + StreamSocket* sock = reinterpret_cast<StreamSocket*>(session_wrap); +#ifdef _WIN32 + GnuTLSIOHook* session = static_cast<GnuTLSIOHook*>(sock->GetIOHook()); +#endif + + if (sock->GetEventMask() & FD_WRITE_WILL_BLOCK) { #ifdef _WIN32 gnutls_transport_set_errno(session->sess, EAGAIN); @@ -783,7 +771,7 @@ info_done_dealloc: return -1; } - int rv = ServerInstance->SE->Send(session->socket, reinterpret_cast<const char *>(buffer), size, 0); + int rv = ServerInstance->SE->Send(sock, reinterpret_cast<const char *>(buffer), size, 0); #ifdef _WIN32 if (rv < 0) @@ -798,75 +786,55 @@ info_done_dealloc: #endif if (rv < (int)size) - ServerInstance->SE->ChangeEventMask(session->socket, FD_WRITE_WILL_BLOCK); + ServerInstance->SE->ChangeEventMask(sock, FD_WRITE_WILL_BLOCK); return rv; } public: - issl_session* sessions; - - GnuTLSIOHook(Module* parent) - : SSLIOHook(parent, "ssl/gnutls") - { - sessions = new issl_session[ServerInstance->SE->GetMaxFds()]; - } - - ~GnuTLSIOHook() - { - delete[] sessions; - } - - void OnStreamSocketAccept(StreamSocket* user, irc::sockets::sockaddrs* client, irc::sockets::sockaddrs* server) CXX11_OVERRIDE + GnuTLSIOHook(IOHookProvider* hookprov, StreamSocket* sock, bool outbound, const reference<GnuTLS::Profile>& sslprofile) + : SSLIOHook(hookprov) + , sess(NULL) + , status(ISSL_NONE) + , profile(sslprofile) { - issl_session* session = &sessions[user->GetFd()]; - - /* For STARTTLS: Don't try and init a session on a socket that already has a session */ - if (session->sess) - return; - - InitSession(user, true); - } - - void OnStreamSocketConnect(StreamSocket* user) CXX11_OVERRIDE - { - InitSession(user, false); + InitSession(sock, outbound); + sock->AddIOHook(this); + Handshake(sock); } void OnStreamSocketClose(StreamSocket* user) CXX11_OVERRIDE { - CloseSession(&sessions[user->GetFd()]); + CloseSession(); } int OnStreamSocketRead(StreamSocket* user, std::string& recvq) CXX11_OVERRIDE { - issl_session* session = &sessions[user->GetFd()]; - - if (!session->sess) + if (!this->sess) { - CloseSession(session); + CloseSession(); user->SetError("No SSL session"); return -1; } - if (session->status == ISSL_HANDSHAKING_READ || session->status == ISSL_HANDSHAKING_WRITE) + if (this->status == ISSL_HANDSHAKING_READ || this->status == ISSL_HANDSHAKING_WRITE) { // The handshake isn't finished, try to finish it. - if(!Handshake(session, user)) + if (!Handshake(user)) { - if (session->status != ISSL_CLOSING) + if (this->status != ISSL_CLOSING) return 0; return -1; } } - // If we resumed the handshake then session->status will be ISSL_HANDSHAKEN. + // If we resumed the handshake then this->status will be ISSL_HANDSHAKEN. - if (session->status == ISSL_HANDSHAKEN) + if (this->status == ISSL_HANDSHAKEN) { char* buffer = ServerInstance->GetReadBuffer(); size_t bufsiz = ServerInstance->Config->NetBufferSize; - int ret = gnutls_record_recv(session->sess, buffer, bufsiz); + int ret = gnutls_record_recv(this->sess, buffer, bufsiz); if (ret > 0) { recvq.append(buffer, ret); @@ -879,17 +847,17 @@ info_done_dealloc: else if (ret == 0) { user->SetError("Connection closed"); - CloseSession(session); + CloseSession(); return -1; } else { user->SetError(gnutls_strerror(ret)); - CloseSession(session); + CloseSession(); return -1; } } - else if (session->status == ISSL_CLOSING) + else if (this->status == ISSL_CLOSING) return -1; return 0; @@ -897,29 +865,27 @@ info_done_dealloc: int OnStreamSocketWrite(StreamSocket* user, std::string& sendq) CXX11_OVERRIDE { - issl_session* session = &sessions[user->GetFd()]; - - if (!session->sess) + if (!this->sess) { - CloseSession(session); + CloseSession(); user->SetError("No SSL session"); return -1; } - if (session->status == ISSL_HANDSHAKING_WRITE || session->status == ISSL_HANDSHAKING_READ) + if (this->status == ISSL_HANDSHAKING_WRITE || this->status == ISSL_HANDSHAKING_READ) { // The handshake isn't finished, try to finish it. - Handshake(session, user); - if (session->status != ISSL_CLOSING) + Handshake(user); + if (this->status != ISSL_CLOSING) return 0; return -1; } int ret = 0; - if (session->status == ISSL_HANDSHAKEN) + if (this->status == ISSL_HANDSHAKEN) { - ret = gnutls_record_send(session->sess, sendq.data(), sendq.length()); + ret = gnutls_record_send(this->sess, sendq.data(), sendq.length()); if (ret == (int)sendq.length()) { @@ -940,7 +906,7 @@ info_done_dealloc: else // (ret < 0) { user->SetError(gnutls_strerror(ret)); - CloseSession(session); + CloseSession(); return -1; } } @@ -948,16 +914,8 @@ info_done_dealloc: return 0; } - ssl_cert* GetCertificate(StreamSocket* sock) CXX11_OVERRIDE - { - int fd = sock->GetFd(); - issl_session* session = &sessions[fd]; - return session->cert; - } - void TellCiphersAndFingerprint(LocalUser* user) { - const gnutls_session_t& sess = sessions[user->eh.GetFd()].sess; if (sess) { std::string text = "*** You are connected using SSL cipher '"; @@ -966,13 +924,14 @@ info_done_dealloc: text.append("-").append(UnknownIfNULL(gnutls_cipher_get_name(gnutls_cipher_get(sess)))).append("-"); text.append(UnknownIfNULL(gnutls_mac_get_name(gnutls_mac_get(sess)))).append("'"); - ssl_cert* cert = sessions[user->eh.GetFd()].cert; - if (!cert->fingerprint.empty()) - text += " and your SSL fingerprint is " + cert->fingerprint; + if (!certificate->fingerprint.empty()) + text += " and your SSL fingerprint is " + certificate->fingerprint; user->WriteNotice(text); } } + + GnuTLS::Profile* GetProfile() { return profile; } }; int GnuTLS::X509Credentials::cert_callback(gnutls_session_t sess, const gnutls_datum_t* req_ca_rdn, int nreqs, const gnutls_pk_algorithm_t* sign_algos, int sign_algos_length, cert_cb_last_param_type* st) @@ -983,8 +942,8 @@ int GnuTLS::X509Credentials::cert_callback(gnutls_session_t sess, const gnutls_d st->cert_type = GNUTLS_CRT_X509; st->key_type = GNUTLS_PRIVKEY_X509; #endif - issl_session* session = reinterpret_cast<issl_session*>(gnutls_transport_get_ptr(sess)); - GnuTLS::X509Credentials& cred = session->profile->GetX509Credentials(); + StreamSocket* sock = reinterpret_cast<StreamSocket*>(gnutls_transport_get_ptr(sess)); + GnuTLS::X509Credentials& cred = static_cast<GnuTLSIOHook*>(sock->GetIOHook())->GetProfile()->GetX509Credentials(); st->ncerts = cred.certs.size(); st->cert.x509 = cred.certs.raw(); @@ -994,15 +953,41 @@ int GnuTLS::X509Credentials::cert_callback(gnutls_session_t sess, const gnutls_d return 0; } +class GnuTLSIOHookProvider : public refcountbase, public IOHookProvider +{ + reference<GnuTLS::Profile> profile; + + public: + GnuTLSIOHookProvider(Module* mod, reference<GnuTLS::Profile>& prof) + : IOHookProvider(mod, "ssl/" + prof->GetName(), IOHookProvider::IOH_SSL) + , profile(prof) + { + ServerInstance->Modules->AddService(*this); + } + + ~GnuTLSIOHookProvider() + { + ServerInstance->Modules->DelService(*this); + } + + void OnAccept(StreamSocket* sock, irc::sockets::sockaddrs* client, irc::sockets::sockaddrs* server) CXX11_OVERRIDE + { + new GnuTLSIOHook(this, sock, true, profile); + } + + void OnConnect(StreamSocket* sock) CXX11_OVERRIDE + { + new GnuTLSIOHook(this, sock, false, profile); + } +}; + class ModuleSSLGnuTLS : public Module { - typedef std::vector<reference<GnuTLS::Profile> > ProfileList; + typedef std::vector<reference<GnuTLSIOHookProvider> > ProfileList; // First member of the class, gets constructed first and destructed last GnuTLS::Init libinit; - GnuTLSIOHook iohook; - std::string sslports; RandGen randhandler; @@ -1026,7 +1011,7 @@ class ModuleSSLGnuTLS : public Module try { reference<GnuTLS::Profile> profile(GnuTLS::Profile::Create(defname, tag)); - newprofiles.push_back(profile); + newprofiles.push_back(new GnuTLSIOHookProvider(this, profile)); } catch (CoreException& ex) { @@ -1057,7 +1042,7 @@ class ModuleSSLGnuTLS : public Module throw ModuleException("Error while initializing SSL profile \"" + name + "\" at " + tag->getTagLocation() + " - " + ex.GetReason()); } - newprofiles.push_back(profile); + newprofiles.push_back(new GnuTLSIOHookProvider(this, profile)); } // New profiles are ok, begin using them @@ -1066,7 +1051,7 @@ class ModuleSSLGnuTLS : public Module } public: - ModuleSSLGnuTLS() : iohook(this) + ModuleSSLGnuTLS() { #ifndef GNUTLS_HAS_RND gcry_control (GCRYCTL_INITIALIZATION_FINISHED, 0); @@ -1144,7 +1129,7 @@ class ModuleSSLGnuTLS : public Module { LocalUser* user = IS_LOCAL(static_cast<User*>(item)); - if (user && user->eh.GetIOHook() == &iohook) + if (user && user->eh.GetIOHook() && user->eh.GetIOHook()->prov->creator == this) { // User is using SSL, they're a local user, and they're using one of *our* SSL ports. // Potentially there could be multiple SSL modules loaded at once on different ports. @@ -1164,27 +1149,11 @@ class ModuleSSLGnuTLS : public Module tokens["SSL"] = sslports; } - void OnHookIO(StreamSocket* user, ListenSocket* lsb) CXX11_OVERRIDE - { - if (!user->GetIOHook()) - { - std::string profilename = lsb->bind_tag->getString("ssl"); - for (ProfileList::const_iterator i = profiles.begin(); i != profiles.end(); ++i) - { - if ((*i)->GetName() == profilename) - { - iohook.sessions[user->GetFd()].profile = *i; - user->AddIOHook(&iohook); - break; - } - } - } - } - void OnUserConnect(LocalUser* user) CXX11_OVERRIDE { - if (user->eh.GetIOHook() == &iohook) - iohook.TellCiphersAndFingerprint(user); + IOHook* hook = user->eh.GetIOHook(); + if (hook && hook->prov->creator == this) + static_cast<GnuTLSIOHook*>(hook)->TellCiphersAndFingerprint(user); } }; |