X-Git-Url: https://git.netwichtig.de/gitweb/?a=blobdiff_plain;f=src%2Fmodules%2Fextra%2Fm_ssl_gnutls.cpp;h=1ebec075bd7400bebfff7c729951e6808de7e747;hb=b5965b08c23e3e89404b481386f2e56ce7cb7ce2;hp=27c466573c88d7c7a1d02c6b3d683067cf45a0d9;hpb=e2af2347fc035d702e45f12e772223a8d578410d;p=user%2Fhenk%2Fcode%2Finspircd.git diff --git a/src/modules/extra/m_ssl_gnutls.cpp b/src/modules/extra/m_ssl_gnutls.cpp index 27c466573..1ebec075b 100644 --- a/src/modules/extra/m_ssl_gnutls.cpp +++ b/src/modules/extra/m_ssl_gnutls.cpp @@ -2,7 +2,7 @@ * | Inspire Internet Relay Chat Daemon | * +------------------------------------+ * - * InspIRCd: (C) 2002-2009 InspIRCd Development Team + * InspIRCd: (C) 2002-2010 InspIRCd Development Team * See: http://wiki.inspircd.org/Credits * * This program is free but copyrighted software; see @@ -14,7 +14,7 @@ #include "inspircd.h" #include #include -#include "transport.h" +#include "ssl.h" #include "m_cap.h" #ifdef WINDOWS @@ -24,9 +24,8 @@ /* $ModDesc: Provides SSL support for clients */ /* $CompileFlags: pkgconfincludes("gnutls","/gnutls/gnutls.h","") */ /* $LinkerFlags: rpath("pkg-config --libs gnutls") pkgconflibs("gnutls","/libgnutls.so","-lgnutls") */ -/* $ModDep: transport.h */ -/* $CopyInstall: conf/key.pem $(CONPATH) */ -/* $CopyInstall: conf/cert.pem $(CONPATH) */ +/* $CopyInstall: conf/key.pem $(CONPATH) -m 0400 -o $(INSTUID) */ +/* $CopyInstall: conf/cert.pem $(CONPATH) -m 0444 */ enum issl_status { ISSL_NONE, ISSL_HANDSHAKING_READ, ISSL_HANDSHAKING_WRITE, ISSL_HANDSHAKEN, ISSL_CLOSING, ISSL_CLOSED }; @@ -44,29 +43,54 @@ static int cert_callback (gnutls_session_t session, const gnutls_datum_t * req_c return 0; } -/** Represents an SSL user's extra data - */ -class issl_session : public classbase +static ssize_t gnutls_pull_wrapper(gnutls_transport_ptr_t user_wrap, void* buffer, size_t size) { -public: - issl_session() + StreamSocket* user = reinterpret_cast(user_wrap); + if (user->GetEventMask() & FD_READ_WILL_BLOCK) { - sess = NULL; + errno = EAGAIN; + return -1; } + int rv = recv(user->GetFd(), buffer, size, 0); + if (rv < (int)size) + ServerInstance->SE->ChangeEventMask(user, FD_READ_WILL_BLOCK); + return rv; +} +static ssize_t gnutls_push_wrapper(gnutls_transport_ptr_t user_wrap, const void* buffer, size_t size) +{ + StreamSocket* user = reinterpret_cast(user_wrap); + if (user->GetEventMask() & FD_WRITE_WILL_BLOCK) + { + errno = EAGAIN; + return -1; + } + int rv = send(user->GetFd(), buffer, size, 0); + if (rv < (int)size) + ServerInstance->SE->ChangeEventMask(user, FD_WRITE_WILL_BLOCK); + return rv; +} + +/** Represents an SSL user's extra data + */ +class issl_session +{ +public: gnutls_session_t sess; issl_status status; + reference cert; + issl_session() : sess(NULL) {} }; -class CommandStartTLS : public Command +class CommandStartTLS : public SplitCommand { public: - CommandStartTLS (Module* mod) : Command(mod, "STARTTLS") + CommandStartTLS (Module* mod) : SplitCommand(mod, "STARTTLS") { works_before_reg = true; } - CmdResult Handle (const std::vector ¶meters, User *user) + CmdResult HandleLocal(const std::vector ¶meters, LocalUser *user) { /* changed from == REG_ALL to catch clients sending STARTTLS * after NICK and USER but before OnUserConnect completes and @@ -78,11 +102,11 @@ class CommandStartTLS : public Command } else { - if (!user->GetIOHook()) + if (!user->eh.GetIOHook()) { user->WriteNumeric(670, "%s :STARTTLS successful, go ahead with TLS handshake", user->nick.c_str()); - user->AddIOHook(creator); - creator->OnStreamSocketAccept(user, NULL, NULL); + user->eh.AddIOHook(creator); + creator->OnStreamSocketAccept(&user->eh, NULL, NULL); } else user->WriteNumeric(691, "%s :STARTTLS failure", user->nick.c_str()); @@ -94,8 +118,6 @@ class CommandStartTLS : public Command class ModuleSSLGnuTLS : public Module { - std::set listenports; - issl_session* sessions; gnutls_certificate_credentials x509_cred; @@ -103,6 +125,7 @@ class ModuleSSLGnuTLS : public Module std::string keyfile; std::string certfile; + std::string cafile; std::string crlfile; std::string sslports; @@ -113,13 +136,12 @@ class ModuleSSLGnuTLS : public Module CommandStartTLS starttls; GenericCap capHandler; + ServiceProvider iohook; public: - ModuleSSLGnuTLS(InspIRCd* Me) - : Module(Me), starttls(this), capHandler(this, "tls") + ModuleSSLGnuTLS() + : starttls(this), capHandler(this, "tls"), iohook(this, "ssl/gnutls", SERVICE_IOHOOK) { - ServerInstance->Modules->PublishInterface("BufferedSocketHook", this); - sessions = new issl_session[ServerInstance->SE->GetMaxFds()]; gnutls_global_init(); // This must be called once in the program @@ -127,37 +149,39 @@ class ModuleSSLGnuTLS : public Module gnutls_x509_privkey_init(&x509_key); cred_alloc = false; + } + + void init() + { // Needs the flag as it ignores a plain /rehash OnModuleRehash(NULL,"ssl"); // Void return, guess we assume success gnutls_certificate_set_dh_params(x509_cred, dh_params); - Implementation eventlist[] = { I_On005Numeric, I_OnRequest, I_OnRehash, I_OnModuleRehash, I_OnPostConnect, + Implementation eventlist[] = { I_On005Numeric, I_OnRehash, I_OnModuleRehash, I_OnUserConnect, I_OnEvent, I_OnHookIO }; ServerInstance->Modules->Attach(eventlist, this, sizeof(eventlist)/sizeof(Implementation)); + ServerInstance->Modules->AddService(iohook); ServerInstance->AddCommand(&starttls); } void OnRehash(User* user) { - ConfigReader Conf(ServerInstance); + ConfigReader Conf; - listenports.clear(); sslports.clear(); for (size_t i = 0; i < ServerInstance->ports.size(); i++) { - ListenSocketBase* port = ServerInstance->ports[i]; - std::string desc = port->GetDescription(); - if (desc != "gnutls") + ListenSocket* port = ServerInstance->ports[i]; + if (port->bind_tag->getString("ssl") != "gnutls") continue; - listenports.insert(port); - std::string portid = port->GetBindDesc(); - + const std::string& portid = port->bind_desc; ServerInstance->Logs->Log("m_ssl_gnutls", DEFAULT, "m_ssl_gnutls.so: Enabling SSL for port %s", portid.c_str()); - if (port->GetIP() != "127.0.0.1") + + if (port->bind_tag->getString("type", "clients") == "clients" && port->bind_addr != "127.0.0.1") sslports.append(portid).append(";"); } @@ -172,11 +196,7 @@ class ModuleSSLGnuTLS : public Module OnRehash(user); - ConfigReader Conf(ServerInstance); - - std::string confdir(ServerInstance->ConfigFileName); - // +1 so we the path ends with a / - confdir = confdir.substr(0, confdir.find_last_of('/') + 1); + ConfigReader Conf; cafile = Conf.ReadValue("gnutls", "cafile", 0); crlfile = Conf.ReadValue("gnutls", "crlfile", 0); @@ -186,33 +206,20 @@ class ModuleSSLGnuTLS : public Module // Set all the default values needed. if (cafile.empty()) - cafile = "ca.pem"; + cafile = "conf/ca.pem"; if (crlfile.empty()) - crlfile = "crl.pem"; + crlfile = "conf/crl.pem"; if (certfile.empty()) - certfile = "cert.pem"; + certfile = "conf/cert.pem"; if (keyfile.empty()) - keyfile = "key.pem"; + keyfile = "conf/key.pem"; if((dh_bits != 768) && (dh_bits != 1024) && (dh_bits != 2048) && (dh_bits != 3072) && (dh_bits != 4096)) dh_bits = 1024; - // Prepend relative paths with the path to the config directory. - if ((cafile[0] != '/') && (!ServerInstance->Config->StartsWithWindowsDriveLetter(cafile))) - cafile = confdir + cafile; - - if ((crlfile[0] != '/') && (!ServerInstance->Config->StartsWithWindowsDriveLetter(crlfile))) - crlfile = confdir + crlfile; - - if ((certfile[0] != '/') && (!ServerInstance->Config->StartsWithWindowsDriveLetter(certfile))) - certfile = confdir + certfile; - - if ((keyfile[0] != '/') && (!ServerInstance->Config->StartsWithWindowsDriveLetter(keyfile))) - keyfile = confdir + keyfile; - int ret; if (cred_alloc) @@ -233,7 +240,7 @@ class ModuleSSLGnuTLS : public Module if((ret = gnutls_certificate_set_x509_crl_file (x509_cred, crlfile.c_str(), GNUTLS_X509_FMT_PEM)) < 0) ServerInstance->Logs->Log("m_ssl_gnutls",DEBUG, "m_ssl_gnutls.so: Failed to set X.509 CRL file '%s': %s", crlfile.c_str(), gnutls_strerror(ret)); - FileReader reader(ServerInstance); + FileReader reader; reader.LoadFile(certfile); std::string cert_string = reader.Contents(); @@ -279,10 +286,12 @@ class ModuleSSLGnuTLS : public Module { gnutls_x509_crt_deinit(x509_cert); gnutls_x509_privkey_deinit(x509_key); - gnutls_dh_params_deinit(dh_params); - gnutls_certificate_free_credentials(x509_cred); + if (cred_alloc) + { + gnutls_dh_params_deinit(dh_params); + gnutls_certificate_free_credentials(x509_cred); + } gnutls_global_deinit(); - ServerInstance->Modules->UnpublishInterface("BufferedSocketHook", this); delete[] sessions; } @@ -290,21 +299,20 @@ class ModuleSSLGnuTLS : public Module { if(target_type == TYPE_USER) { - User* user = static_cast(item); + LocalUser* user = IS_LOCAL(static_cast(item)); - if (user->GetIOHook() == this) + if (user && user->eh.GetIOHook() == this) { // User is using SSL, they're a local user, and they're using one of *our* SSL ports. // Potentially there could be multiple SSL modules loaded at once on different ports. ServerInstance->Users->QuitUser(user, "SSL module unloading"); - user->DelIOHook(); } } } Version GetVersion() { - return Version("$Id$", VF_VENDOR, API_VERSION); + return Version("Provides SSL support for clients", VF_VENDOR); } @@ -315,64 +323,27 @@ class ModuleSSLGnuTLS : public Module output.append(" STARTTLS"); } - void OnHookIO(StreamSocket* user, ListenSocketBase* lsb) + void OnHookIO(StreamSocket* user, ListenSocket* lsb) { - if (!user->GetIOHook() && listenports.find(lsb) != listenports.end()) + if (!user->GetIOHook() && lsb->bind_tag->getString("ssl") == "gnutls") { /* Hook the user with our module */ user->AddIOHook(this); } } - const char* OnRequest(Request* request) + void OnRequest(Request& request) { - ISHRequest* ISR = static_cast(request); - if (strcmp("IS_NAME", request->GetId()) == 0) + if (strcmp("GET_SSL_CERT", request.id) == 0) { - return "gnutls"; - } - else if (strcmp("IS_HOOK", request->GetId()) == 0) - { - ISR->Sock->AddIOHook(this); - return "OK"; - } - else if (strcmp("IS_UNHOOK", request->GetId()) == 0) - { - ISR->Sock->DelIOHook(); - return "OK"; - } - else if (strcmp("IS_HSDONE", request->GetId()) == 0) - { - if (ISR->Sock->GetFd() < 0) - return "OK"; + SocketCertificateRequest& req = static_cast(request); + int fd = req.sock->GetFd(); + issl_session* session = &sessions[fd]; - issl_session* session = &sessions[ISR->Sock->GetFd()]; - return (session->status == ISSL_HANDSHAKING_READ || session->status == ISSL_HANDSHAKING_WRITE) ? NULL : "OK"; + req.cert = session->cert; } - else if (strcmp("IS_ATTACH", request->GetId()) == 0) - { - if (ISR->Sock->GetFd() > -1) - { - issl_session* session = &sessions[ISR->Sock->GetFd()]; - if (session->sess) - { - if (static_cast(ServerInstance->SE->GetRef(ISR->Sock->GetFd())) == static_cast(ISR->Sock)) - { - return "OK"; - } - } - } - } - else if (strcmp("GET_CERT", request->GetId()) == 0) - { - Module* sslinfo = ServerInstance->Modules->Find("m_sslinfo.so"); - if (sslinfo) - return sslinfo->OnRequest(request); - } - return NULL; } - void OnStreamSocketAccept(StreamSocket* user, irc::sockets::sockaddrs* client, irc::sockets::sockaddrs* server) { int fd = user->GetFd(); @@ -388,7 +359,9 @@ class ModuleSSLGnuTLS : public Module gnutls_credentials_set(session->sess, GNUTLS_CRD_CERTIFICATE, x509_cred); gnutls_dh_set_prime_bits(session->sess, dh_bits); - gnutls_transport_set_ptr(session->sess, reinterpret_cast(fd)); // Give gnutls the fd for the socket. + gnutls_transport_set_ptr(session->sess, reinterpret_cast(user)); + gnutls_transport_set_push_function(session->sess, gnutls_push_wrapper); + gnutls_transport_set_pull_function(session->sess, gnutls_pull_wrapper); gnutls_certificate_server_set_request(session->sess, GNUTLS_CERT_REQUEST); // Request client certificate if any. @@ -404,7 +377,9 @@ class ModuleSSLGnuTLS : public Module gnutls_set_default_priority(session->sess); // Avoid calling all the priority functions, defaults are adequate. gnutls_credentials_set(session->sess, GNUTLS_CRD_CERTIFICATE, x509_cred); gnutls_dh_set_prime_bits(session->sess, dh_bits); - gnutls_transport_set_ptr(session->sess, reinterpret_cast(user->GetFd())); + gnutls_transport_set_ptr(session->sess, reinterpret_cast(user)); + gnutls_transport_set_push_function(session->sess, gnutls_push_wrapper); + gnutls_transport_set_pull_function(session->sess, gnutls_pull_wrapper); Handshake(session, user); } @@ -425,7 +400,7 @@ class ModuleSSLGnuTLS : public Module return -1; } - if (session->status == ISSL_HANDSHAKING_READ) + if (session->status == ISSL_HANDSHAKING_READ || session->status == ISSL_HANDSHAKING_WRITE) { // The handshake isn't finished, try to finish it. @@ -433,15 +408,9 @@ class ModuleSSLGnuTLS : public Module { if (session->status != ISSL_CLOSING) return 0; - user->SetError("Handshake Failed"); return -1; } } - else if (session->status == ISSL_HANDSHAKING_WRITE) - { - MakePollWrite(user); - return 0; - } // If we resumed the handshake then session->status will be ISSL_HANDSHAKEN. @@ -449,35 +418,27 @@ class ModuleSSLGnuTLS : public Module { char* buffer = ServerInstance->GetReadBuffer(); size_t bufsiz = ServerInstance->Config->NetBufferSize; - size_t len = 0; - while (len < bufsiz) + int ret = gnutls_record_recv(session->sess, buffer, bufsiz); + if (ret > 0) { - int ret = gnutls_record_recv(session->sess, buffer + len, bufsiz - len); - if (ret > 0) - { - len += ret; - } - else if (ret == GNUTLS_E_AGAIN || ret == GNUTLS_E_INTERRUPTED) - { - break; - } - else if (ret == 0) - { - user->SetError("SSL Connection closed"); - CloseSession(session); - return -1; - } - else - { - user->SetError(gnutls_strerror(ret)); - CloseSession(session); - return -1; - } + recvq.append(buffer, ret); + return 1; } - if (len) + else if (ret == GNUTLS_E_AGAIN || ret == GNUTLS_E_INTERRUPTED) { - recvq.append(buffer, len); - return 1; + return 0; + } + else if (ret == 0) + { + user->SetError("SSL Connection closed"); + CloseSession(session); + return -1; + } + else + { + user->SetError(gnutls_strerror(ret)); + CloseSession(session); + return -1; } } else if (session->status == ISSL_CLOSING) @@ -503,7 +464,6 @@ class ModuleSSLGnuTLS : public Module Handshake(session, user); if (session->status != ISSL_CLOSING) return 0; - user->SetError("Handshake Failed"); return -1; } @@ -515,17 +475,18 @@ class ModuleSSLGnuTLS : public Module if (ret == (int)sendq.length()) { + ServerInstance->SE->ChangeEventMask(user, FD_WANT_NO_WRITE); return 1; } else if (ret > 0) { sendq = sendq.substr(ret); - MakePollWrite(user); + ServerInstance->SE->ChangeEventMask(user, FD_WANT_SINGLE_WRITE); return 0; } else if (ret == GNUTLS_E_AGAIN || ret == GNUTLS_E_INTERRUPTED) { - MakePollWrite(user); + ServerInstance->SE->ChangeEventMask(user, FD_WANT_SINGLE_WRITE); return 0; } else if (ret == 0) @@ -545,7 +506,7 @@ class ModuleSSLGnuTLS : public Module return 0; } - bool Handshake(issl_session* session, EventHandler* user) + bool Handshake(issl_session* session, StreamSocket* user) { int ret = gnutls_handshake(session->sess); @@ -559,16 +520,18 @@ class ModuleSSLGnuTLS : public Module { // gnutls_handshake() wants to read() again. session->status = ISSL_HANDSHAKING_READ; + ServerInstance->SE->ChangeEventMask(user, FD_WANT_POLL_READ | FD_WANT_NO_WRITE); } else { // gnutls_handshake() wants to write() again. session->status = ISSL_HANDSHAKING_WRITE; - MakePollWrite(user); + ServerInstance->SE->ChangeEventMask(user, FD_WANT_NO_READ | FD_WANT_SINGLE_WRITE); } } else { + user->SetError(std::string("Handshake Failed - ") + gnutls_strerror(ret)); CloseSession(session); session->status = ISSL_CLOSING; } @@ -583,20 +546,19 @@ class ModuleSSLGnuTLS : public Module VerifyCertificate(session,user); // Finish writing, if any left - MakePollWrite(user); + ServerInstance->SE->ChangeEventMask(user, FD_WANT_POLL_READ | FD_WANT_NO_WRITE | FD_ADD_TRIAL_WRITE); return true; } } - void OnPostConnect(User* user) + void OnUserConnect(LocalUser* user) { - // This occurs AFTER OnUserConnect so we can be sure the - // protocol module has propagated the NICK message. - if (user->GetIOHook() == this && (IS_LOCAL(user))) + if (user->eh.GetIOHook() == this) { if (sessions[user->GetFd()].sess) { + SSLCertSubmission(user, this, ServerInstance->Modules->Find("m_sslinfo.so"), sessions[user->GetFd()].cert); std::string cipher = gnutls_kx_get_name(gnutls_kx_get(sessions[user->GetFd()].sess)); cipher.append("-").append(gnutls_cipher_get_name(gnutls_cipher_get(sessions[user->GetFd()].sess))).append("-"); cipher.append(gnutls_mac_get_name(gnutls_mac_get(sessions[user->GetFd()].sess))); @@ -605,30 +567,21 @@ class ModuleSSLGnuTLS : public Module } } - void MakePollWrite(EventHandler* eh) - { - ServerInstance->SE->WantWrite(eh); - } - void CloseSession(issl_session* session) { - if(session->sess) + if (session->sess) { gnutls_bye(session->sess, GNUTLS_SHUT_WR); gnutls_deinit(session->sess); } - session->sess = NULL; + session->cert = NULL; session->status = ISSL_NONE; } - void VerifyCertificate(issl_session* session, Extensible* user) + void VerifyCertificate(issl_session* session, StreamSocket* user) { - if (!session->sess || !user) - return; - - Module* sslinfo = ServerInstance->Modules->Find("m_sslinfo.so"); - if (!sslinfo) + if (!session->sess || !user || session->cert) return; unsigned int status; @@ -641,6 +594,7 @@ class ModuleSSLGnuTLS : public Module size_t digest_size = sizeof(digest); size_t name_size = sizeof(name); ssl_cert* certinfo = new ssl_cert; + session->cert = certinfo; /* This verification function uses the trusted CAs in the credentials * structure. So you must have installed one or more CA certificates. @@ -650,7 +604,7 @@ class ModuleSSLGnuTLS : public Module if (ret < 0) { certinfo->error = std::string(gnutls_strerror(ret)); - goto info_done; + return; } certinfo->invalid = (status & GNUTLS_CERT_INVALID); @@ -665,14 +619,14 @@ class ModuleSSLGnuTLS : public Module if (gnutls_certificate_type_get(session->sess) != GNUTLS_CRT_X509) { certinfo->error = "No X509 keys sent"; - goto info_done; + return; } ret = gnutls_x509_crt_init(&cert); if (ret < 0) { certinfo->error = gnutls_strerror(ret); - goto info_done; + return; } cert_list_size = 0; @@ -718,11 +672,9 @@ class ModuleSSLGnuTLS : public Module info_done_dealloc: gnutls_x509_crt_deinit(cert); -info_done: - BufferedSocketFingerprintSubmission(user, this, sslinfo, certinfo).Send(); } - void OnEvent(Event* ev) + void OnEvent(Event& ev) { capHandler.HandleEvent(ev); }