]> git.netwichtig.de Git - user/henk/code/inspircd.git/blobdiff - src/modules/extra/m_ssl_gnutls.cpp
Add initial query support to m_mysql [patch by Athenon]
[user/henk/code/inspircd.git] / src / modules / extra / m_ssl_gnutls.cpp
index c81be9f7bb16f476f4f9399dac0947ccf5f7252a..f5133a1dc7a9a332cf12d42d5aba890fc7e231ce 100644 (file)
 
 enum issl_status { ISSL_NONE, ISSL_HANDSHAKING_READ, ISSL_HANDSHAKING_WRITE, ISSL_HANDSHAKEN, ISSL_CLOSING, ISSL_CLOSED };
 
-bool isin(const std::string &host, int port, const std::vector<std::string> &portlist)
-{
-       if (std::find(portlist.begin(), portlist.end(), "*:" + ConvToStr(port)) != portlist.end())
-               return true;
-
-       if (std::find(portlist.begin(), portlist.end(), ":" + ConvToStr(port)) != portlist.end())
-               return true;
-
-       return std::find(portlist.begin(), portlist.end(), host + ":" + ConvToStr(port)) != portlist.end();
-}
-
 /** Represents an SSL user's extra data
  */
 class issl_session : public classbase
@@ -58,11 +47,9 @@ public:
 
 class CommandStartTLS : public Command
 {
-       Module* Caller;
  public:
-       CommandStartTLS (InspIRCd* Instance, Module* mod) : Command(Instance,"STARTTLS", 0, 0, true), Caller(mod)
+       CommandStartTLS (InspIRCd* Instance, Module* mod) : Command(Instance, mod, "STARTTLS", 0, 0, true)
        {
-               this->source = "m_ssl_gnutls.so";
        }
 
        CmdResult Handle (const std::vector<std::string> &parameters, User *user)
@@ -80,8 +67,8 @@ class CommandStartTLS : public Command
                        if (!user->GetIOHook())
                        {
                                user->WriteNumeric(670, "%s :STARTTLS successful, go ahead with TLS handshake", user->nick.c_str());
-                               user->AddIOHook(Caller);
-                               Caller->OnRawSocketAccept(user->GetFd(), user->GetIPString(), user->GetPort());
+                               user->AddIOHook(creator);
+                               creator->OnRawSocketAccept(user->GetFd(), NULL, NULL);
                        }
                        else
                                user->WriteNumeric(691, "%s :STARTTLS failure", user->nick.c_str());
@@ -93,7 +80,7 @@ class CommandStartTLS : public Command
 
 class ModuleSSLGnuTLS : public Module
 {
-       std::vector<std::string> listenports;
+       std::set<ListenSocketBase*> listenports;
 
        issl_session* sessions;
 
@@ -107,15 +94,14 @@ class ModuleSSLGnuTLS : public Module
        std::string sslports;
        int dh_bits;
 
-       int clientactive;
        bool cred_alloc;
 
-       CommandStartTLS* starttls;
+       CommandStartTLS starttls;
 
  public:
 
        ModuleSSLGnuTLS(InspIRCd* Me)
-               : Module(Me)
+               : Module(Me), starttls(Me, this)
        {
                ServerInstance->Modules->PublishInterface("BufferedSocketHook", this);
 
@@ -131,12 +117,11 @@ class ModuleSSLGnuTLS : public Module
                gnutls_certificate_set_dh_params(x509_cred, dh_params);
                Implementation eventlist[] = { I_On005Numeric, I_OnRawSocketConnect, I_OnRawSocketAccept,
                        I_OnRawSocketClose, I_OnRawSocketRead, I_OnRawSocketWrite, I_OnCleanup,
-                       I_OnBufferFlushed, I_OnRequest, I_OnUnloadModule, I_OnRehash, I_OnModuleRehash,
-                       I_OnPostConnect, I_OnEvent, I_OnHookUserIO };
+                       I_OnBufferFlushed, I_OnRequest, I_OnRehash, I_OnModuleRehash, I_OnPostConnect,
+                       I_OnEvent, I_OnHookIO };
                ServerInstance->Modules->Attach(eventlist, this, sizeof(eventlist)/sizeof(Implementation));
 
-               starttls = new CommandStartTLS(ServerInstance, this);
-               ServerInstance->AddCommand(starttls);
+               ServerInstance->AddCommand(&starttls);
        }
 
        virtual void OnRehash(User* user)
@@ -144,51 +129,21 @@ class ModuleSSLGnuTLS : public Module
                ConfigReader Conf(ServerInstance);
 
                listenports.clear();
-               clientactive = 0;
                sslports.clear();
 
-               for(int index = 0; index < Conf.Enumerate("bind"); index++)
+               for (size_t i = 0; i < ServerInstance->ports.size(); i++)
                {
-                       // For each <bind> tag
-                       std::string x = Conf.ReadValue("bind", "type", index);
-                       if(((x.empty()) || (x == "clients")) && (Conf.ReadValue("bind", "ssl", index) == "gnutls"))
-                       {
-                               // Get the port we're meant to be listening on with SSL
-                               std::string port = Conf.ReadValue("bind", "port", index);
-                               std::string addr = Conf.ReadValue("bind", "address", index);
-
-                               if (!addr.empty())
-                               {
-                                       // normalize address, important for IPv6
-                                       int portint = 0;
-                                       irc::sockets::sockaddrs bin;
-                                       if (irc::sockets::aptosa(addr.c_str(), portint, &bin))
-                                               irc::sockets::satoap(&bin, addr, portint);
-                               }
-
-                               irc::portparser portrange(port, false);
-                               long portno = -1;
-                               while ((portno = portrange.GetToken()))
-                               {
-                                       clientactive++;
-                                       try
-                                       {
-                                               listenports.push_back(addr + ":" + ConvToStr(portno));
+                       ListenSocketBase* port = ServerInstance->ports[i];
+                       std::string desc = port->GetDescription();
+                       if (desc != "gnutls")
+                               continue;
 
-                                               for (size_t i = 0; i < ServerInstance->ports.size(); i++)
-                                                       if ((ServerInstance->ports[i]->GetPort() == portno) && (ServerInstance->ports[i]->GetIP() == addr))
-                                                               ServerInstance->ports[i]->SetDescription("ssl");
-                                               ServerInstance->Logs->Log("m_ssl_gnutls",DEFAULT, "m_ssl_gnutls.so: Enabling SSL for port %ld", portno);
+                       listenports.insert(port);
+                       std::string portid = port->GetBindDesc();
 
-                                               if (addr != "127.0.0.1")
-                                                       sslports.append((addr.empty() ? "*" : addr)).append(":").append(ConvToStr(portno)).append(";");
-                                       }
-                                       catch (ModuleException &e)
-                                       {
-                                               ServerInstance->Logs->Log("m_ssl_gnutls",DEFAULT, "m_ssl_gnutls.so: FAILED to enable SSL on port %ld: %s. Maybe it's already hooked by the same port on a different IP, or you have an other SSL or similar module loaded?", portno, e.GetReason());
-                                       }
-                               }
-                       }
+                       ServerInstance->Logs->Log("m_ssl_gnutls", DEFAULT, "m_ssl_gnutls.so: Enabling SSL for port %s", portid.c_str());
+                       if (port->GetIP() != "127.0.0.1")
+                               sslports.append(portid).append(";");
                }
 
                if (!sslports.empty())
@@ -255,16 +210,16 @@ class ModuleSSLGnuTLS : public Module
                        cred_alloc = true;
 
                if((ret = gnutls_certificate_allocate_credentials(&x509_cred)) < 0)
-                       ServerInstance->Logs->Log("m_ssl_gnutls",DEFAULT, "m_ssl_gnutls.so: Failed to allocate certificate credentials: %s", gnutls_strerror(ret));
+                       ServerInstance->Logs->Log("m_ssl_gnutls",DEBUG, "m_ssl_gnutls.so: Failed to allocate certificate credentials: %s", gnutls_strerror(ret));
 
                if((ret = gnutls_dh_params_init(&dh_params)) < 0)
-                       ServerInstance->Logs->Log("m_ssl_gnutls",DEFAULT, "m_ssl_gnutls.so: Failed to initialise DH parameters: %s", gnutls_strerror(ret));
+                       ServerInstance->Logs->Log("m_ssl_gnutls",DEBUG, "m_ssl_gnutls.so: Failed to initialise DH parameters: %s", gnutls_strerror(ret));
 
                if((ret =gnutls_certificate_set_x509_trust_file(x509_cred, cafile.c_str(), GNUTLS_X509_FMT_PEM)) < 0)
-                       ServerInstance->Logs->Log("m_ssl_gnutls",DEFAULT, "m_ssl_gnutls.so: Failed to set X.509 trust file '%s': %s", cafile.c_str(), gnutls_strerror(ret));
+                       ServerInstance->Logs->Log("m_ssl_gnutls",DEBUG, "m_ssl_gnutls.so: Failed to set X.509 trust file '%s': %s", cafile.c_str(), gnutls_strerror(ret));
 
                if((ret = gnutls_certificate_set_x509_crl_file (x509_cred, crlfile.c_str(), GNUTLS_X509_FMT_PEM)) < 0)
-                       ServerInstance->Logs->Log("m_ssl_gnutls",DEFAULT, "m_ssl_gnutls.so: Failed to set X.509 CRL file '%s': %s", crlfile.c_str(), gnutls_strerror(ret));
+                       ServerInstance->Logs->Log("m_ssl_gnutls",DEBUG, "m_ssl_gnutls.so: Failed to set X.509 CRL file '%s': %s", crlfile.c_str(), gnutls_strerror(ret));
 
                if((ret = gnutls_certificate_set_x509_key_file (x509_cred, certfile.c_str(), keyfile.c_str(), GNUTLS_X509_FMT_PEM)) < 0)
                {
@@ -302,7 +257,7 @@ class ModuleSSLGnuTLS : public Module
        {
                if(target_type == TYPE_USER)
                {
-                       User* user = (User*)item;
+                       User* user = static_cast<User*>(item);
 
                        if (user->GetIOHook() == this)
                        {
@@ -321,19 +276,6 @@ class ModuleSSLGnuTLS : public Module
                }
        }
 
-       virtual void OnUnloadModule(Module* mod, const std::string &name)
-       {
-               if(mod == this)
-               {
-                       for(unsigned int i = 0; i < listenports.size(); i++)
-                       {
-                               for (size_t j = 0; j < ServerInstance->ports.size(); j++)
-                                       if (listenports[i] == (ServerInstance->ports[j]->GetIP()+":"+ConvToStr(ServerInstance->ports[j]->GetPort())))
-                                               ServerInstance->ports[j]->SetDescription("plaintext");
-                       }
-               }
-       }
-
        virtual Version GetVersion()
        {
                return Version("$Id$", VF_VENDOR, API_VERSION);
@@ -347,9 +289,9 @@ class ModuleSSLGnuTLS : public Module
                output.append(" STARTTLS");
        }
 
-       virtual void OnHookUserIO(User* user, const std::string &targetip)
+       virtual void OnHookIO(EventHandler* user, ListenSocketBase* lsb)
        {
-               if (!user->GetIOHook() && isin(targetip,user->GetPort(),listenports))
+               if (!user->GetIOHook() && listenports.find(lsb) != listenports.end())
                {
                        /* Hook the user with our module */
                        user->AddIOHook(this);
@@ -358,7 +300,7 @@ class ModuleSSLGnuTLS : public Module
 
        virtual const char* OnRequest(Request* request)
        {
-               ISHRequest* ISR = (ISHRequest*)request;
+               ISHRequest* ISR = static_cast<ISHRequest*>(request);
                if (strcmp("IS_NAME", request->GetId()) == 0)
                {
                        return "gnutls";
@@ -368,7 +310,7 @@ class ModuleSSLGnuTLS : public Module
                        const char* ret = "OK";
                        try
                        {
-                               ret = ISR->Sock->AddIOHook((Module*)this) ? "OK" : NULL;
+                               ret = ISR->Sock->AddIOHook(this) ? "OK" : NULL;
                        }
                        catch (ModuleException &e)
                        {
@@ -395,9 +337,9 @@ class ModuleSSLGnuTLS : public Module
                                issl_session* session = &sessions[ISR->Sock->GetFd()];
                                if (session->sess)
                                {
-                                       if ((Extensible*)ServerInstance->SE->GetRef(ISR->Sock->GetFd()) == (Extensible*)(ISR->Sock))
+                                       if (static_cast<Extensible*>(ServerInstance->SE->GetRef(ISR->Sock->GetFd())) == static_cast<Extensible*>(ISR->Sock))
                                        {
-                                               VerifyCertificate(session, (BufferedSocket*)ISR->Sock);
+                                               VerifyCertificate(session, ISR->Sock);
                                                return "OK";
                                        }
                                }
@@ -421,7 +363,7 @@ class ModuleSSLGnuTLS : public Module
        }
 
 
-       virtual void OnRawSocketAccept(int fd, const std::string &ip, int localport)
+       virtual void OnRawSocketAccept(int fd, irc::sockets::sockaddrs* client, irc::sockets::sockaddrs* server)
        {
                /* Are there any possibilities of an out of range fd? Hope not, but lets be paranoid */
                if ((fd < 0) || (fd > ServerInstance->SE->GetMaxFds() - 1))
@@ -439,7 +381,7 @@ class ModuleSSLGnuTLS : public Module
                gnutls_credentials_set(session->sess, GNUTLS_CRD_CERTIFICATE, x509_cred);
                gnutls_dh_set_prime_bits(session->sess, dh_bits);
 
-               gnutls_transport_set_ptr(session->sess, (gnutls_transport_ptr_t) fd); // Give gnutls the fd for the socket.
+               gnutls_transport_set_ptr(session->sess, reinterpret_cast<gnutls_transport_ptr_t>(fd)); // Give gnutls the fd for the socket.
 
                gnutls_certificate_server_set_request(session->sess, GNUTLS_CERT_REQUEST); // Request client certificate if any.
 
@@ -459,7 +401,7 @@ class ModuleSSLGnuTLS : public Module
                gnutls_set_default_priority(session->sess); // Avoid calling all the priority functions, defaults are adequate.
                gnutls_credentials_set(session->sess, GNUTLS_CRD_CERTIFICATE, x509_cred);
                gnutls_dh_set_prime_bits(session->sess, dh_bits);
-               gnutls_transport_set_ptr(session->sess, (gnutls_transport_ptr_t) fd); // Give gnutls the fd for the socket.
+               gnutls_transport_set_ptr(session->sess, reinterpret_cast<gnutls_transport_ptr_t>(fd)); // Give gnutls the fd for the socket.
 
                Handshake(session, fd);
        }
@@ -520,34 +462,43 @@ class ModuleSSLGnuTLS : public Module
 
                if (session->status == ISSL_HANDSHAKEN)
                {
-                       int ret = gnutls_record_recv(session->sess, buffer, count);
-
-                       if (ret > 0)
+                       unsigned int len = 0;
+                       while (len < count)
                        {
-                               readresult = ret;
+                               int ret = gnutls_record_recv(session->sess, buffer + len, count - len);
+                               if (ret > 0)
+                               {
+                                       len += ret;
+                               }
+                               else if (ret == GNUTLS_E_AGAIN || ret == GNUTLS_E_INTERRUPTED)
+                               {
+                                       break;
+                               }
+                               else
+                               {
+                                       if (ret != 0)
+                                               ServerInstance->Logs->Log("m_ssl_gnutls", DEFAULT,
+                                                       "m_ssl_gnutls.so: Error while reading on fd %d: %s",
+                                                       fd, gnutls_strerror(ret));
+
+                                       // if ret == 0, client closed connection.
+                                       readresult = 0;
+                                       CloseSession(session);
+                                       return 1;
+                               }
                        }
-                       else if (ret == 0)
+                       readresult = len;
+                       if (len)
                        {
-                               // Client closed connection.
-                               readresult = 0;
-                               CloseSession(session);
                                return 1;
                        }
-                       else if (ret == GNUTLS_E_AGAIN || ret == GNUTLS_E_INTERRUPTED)
+                       else
                        {
                                errno = EAGAIN;
                                return -1;
                        }
-                       else
-                       {
-                               ServerInstance->Logs->Log("m_ssl_gnutls", DEFAULT,
-                                               "m_ssl_gnutls.so: Error while reading on fd %d: %s",
-                                               fd, gnutls_strerror(ret));
-                               readresult = 0;
-                               CloseSession(session);
-                       }
                }
-               else if(session->status == ISSL_CLOSING)
+               else if (session->status == ISSL_CLOSING)
                        readresult = 0;
 
                return 1;
@@ -610,7 +561,8 @@ class ModuleSSLGnuTLS : public Module
                        }
                }
 
-               MakePollWrite(fd);
+               if (!session->outbuf.empty())
+                       MakePollWrite(fd);
 
                /* Who's smart idea was it to return 1 when we havent written anything?
                 * This fucks the buffer up in BufferedSocket :p
@@ -659,8 +611,7 @@ class ModuleSSLGnuTLS : public Module
                        EventHandler *extendme = ServerInstance->SE->GetRef(fd);
                        if (extendme)
                        {
-                               if (!extendme->GetExt("ssl"))
-                                       extendme->Extend("ssl", "ON");
+                               extendme->Extend("ssl");
                        }
 
                        // Change the seesion state
@@ -688,9 +639,9 @@ class ModuleSSLGnuTLS : public Module
                                user->WriteServ("NOTICE %s :*** You are connected using SSL cipher \"%s\"", user->nick.c_str(), cipher.c_str());
                        }
 
-                       ServerInstance->PI->SendMetaData(user, TYPE_USER, "ssl", "ON");
+                       ServerInstance->PI->SendMetaData(user, "ssl", "ON");
                        if (certdata)
-                               ServerInstance->PI->SendMetaData(user, TYPE_USER, "ssl_cert", certdata->GetMetaLine().c_str());
+                               ServerInstance->PI->SendMetaData(user, "ssl_cert", certdata->GetMetaLine().c_str());
                }
        }