]> git.netwichtig.de Git - user/henk/code/inspircd.git/commitdiff
Merge branch 'master+websocket'
authorAttila Molnar <attilamolnar@hush.com>
Wed, 10 Aug 2016 13:57:04 +0000 (15:57 +0200)
committerAttila Molnar <attilamolnar@hush.com>
Wed, 10 Aug 2016 13:57:04 +0000 (15:57 +0200)
19 files changed:
docs/conf/inspircd.conf.example
docs/conf/modules.conf.example
include/compat.h
include/inspsocket.h
include/iohook.h
include/socket.h
include/typedefs.h
make/test/compiler.cpp
src/inspsocket.cpp
src/listensocket.cpp
src/modules/extra/m_ssl_gnutls.cpp
src/modules/extra/m_ssl_mbedtls.cpp
src/modules/extra/m_ssl_openssl.cpp
src/modules/m_httpd.cpp
src/modules/m_sha1.cpp [new file with mode: 0644]
src/modules/m_spanningtree/main.cpp
src/modules/m_spanningtree/treesocket1.cpp
src/modules/m_websocket.cpp [new file with mode: 0644]
src/usermanager.cpp

index 33f45535770eeddc07b5dab7a377492da65c9510..16c34cc249824d6d95972cdae10d4d79190a4183 100644 (file)
 
 <bind address="" port="6660-6669" type="clients">
 
+# Listener accepting HTML5 WebSocket connections.
+# Requires the websocket module and SHA-1 hashing support (provided by the sha1
+# module).
+#<bind address="" port="7002" type="clients" hook="websocket">
+
 # When linking servers, the OpenSSL and GnuTLS implementations are completely
 # link-compatible and can be used alongside each other
 # on each end of the link without any significant issues.
index ec58ccfb4b55b60809cbbe60eb34bea597b474f2..b0f9e8d4b377539f4a6e59387b75e79cf81c008e 100644 (file)
 # to a server matching a mask like +b s:server.mask.here from joining.
 #<module name="serverban">
 
+#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#
+# SHA1 module: Allows other modules to generate SHA1 hashes.
+# Required by the WebSocket module.
+#<module name="sha1">
+
 #-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#
 # Showfile: Provides support for showing a text file to users when    #
 # they enter a command.                                               #
 # Set the maximum number of entries on a user's watch list below.
 #<watch maxentries="32">
 
+#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#
+# WebSocket module: Adds HTML5 WebSocket support.
+# Specify hook="websocket" in a <bind> tag to make that port accept
+# WebSocket connections. Compatible with SSL/TLS.
+# Requires SHA-1 hash support available in the sha1 module.
+#<module name="websocket">
+
 #-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#
 # XLine database: Stores all *Lines (G/Z/K/R/any added by other modules)
 # in a file which is re-loaded on restart. This is useful
index e7719bcd7c258f8457fd108934c4d31d55214e62..1e6fc3d45512aecb23373b0d9d34d43d78cf921c 100644 (file)
  */
 #if defined _LIBCPP_VERSION || defined _WIN32
 # define TR1NS std
+# include <array>
 # include <unordered_map>
 # include <type_traits>
 #else
 # define TR1NS std::tr1
+# include <tr1/array>
 # include <tr1/unordered_map>
 # include <tr1/type_traits>
 #endif
index 53eca2e91d4c292c31a396161cae342ae2bc9401..751374fdf1aab978bfeae12e45c3ade7881359d2 100644 (file)
@@ -198,6 +198,13 @@ class CoreExport StreamSocket : public EventHandler
                        nbytes = 0;
                }
 
+               void moveall(SendQueue& other)
+               {
+                       nbytes += other.bytes();
+                       data.insert(data.end(), other.data.begin(), other.data.end());
+                       other.clear();
+               }
+
         private:
                /** Private send queue. Note that individual strings may be shared.
                 */
@@ -228,6 +235,28 @@ class CoreExport StreamSocket : public EventHandler
         */
        void DoRead();
 
+       /** Send as much data contained in a SendQueue object as possible.
+        * All data which successfully sent will be removed from the SendQueue.
+        * @param sq SendQueue to flush
+        */
+       void FlushSendQ(SendQueue& sq);
+
+       /** Read incoming data into a receive queue.
+        * @param rq Receive queue to put incoming data into
+        * @return < 0 on error or close, 0 if no new data is ready (but the socket is still connected), > 0 if data was read from the socket and put into the recvq
+        */
+       int ReadToRecvQ(std::string& rq);
+
+       /** Read data from a hook chain recursively, starting at 'hook'.
+        * If 'hook' is NULL, the recvq is filled with data from SocketEngine::Recv(), otherwise it is filled with data from the
+        * next hook in the chain.
+        * @param hook Next IOHook in the chain, can be NULL
+        * @param rq Receive queue to put incoming data into
+        * @return < 0 on error or close, 0 if no new data is ready (but the socket is still connected), > 0 if data was read from
+        the socket and put into the recvq
+        */
+       int HookChainRead(IOHook* hook, std::string& rq);
+
  protected:
        std::string recvq;
  public:
@@ -274,7 +303,7 @@ class CoreExport StreamSocket : public EventHandler
         */
        bool GetNextLine(std::string& line, char delim = '\n');
        /** Useful for implementing sendq exceeded */
-       size_t getSendQSize() const { return sendq.size(); }
+       size_t getSendQSize() const;
 
        SendQueue& GetSendQ() { return sendq; }
 
@@ -284,6 +313,12 @@ class CoreExport StreamSocket : public EventHandler
        virtual void Close();
        /** This ensures that close is called prior to destructor */
        virtual CullResult cull();
+
+       /** Get the IOHook of a module attached to this socket
+        * @param mod Module whose IOHook to return
+        * @return IOHook belonging to the module or NULL if the module haven't attached an IOHook to this socket
+        */
+       IOHook* GetModHook(Module* mod) const;
 };
 /**
  * BufferedSocket is an extendable socket class which modules
@@ -358,5 +393,4 @@ class CoreExport BufferedSocket : public StreamSocket
 };
 
 inline IOHook* StreamSocket::GetIOHook() const { return iohook; }
-inline void StreamSocket::AddIOHook(IOHook* hook) { iohook = hook; }
 inline void StreamSocket::DelIOHook() { iohook = NULL; }
index cf27fcb0cfc4876b7fdba0312fd3aea0d34d6141..e99316b99dc1110c02183649a69abb653e50e988 100644 (file)
@@ -23,6 +23,8 @@ class StreamSocket;
 
 class IOHookProvider : public ServiceProvider
 {
+       const bool middlehook;
+
  public:
        enum Type
        {
@@ -32,21 +34,31 @@ class IOHookProvider : public ServiceProvider
 
        const Type type;
 
-       IOHookProvider(Module* mod, const std::string& Name, Type hooktype = IOH_UNKNOWN)
-               : ServiceProvider(mod, Name, SERVICE_IOHOOK), type(hooktype) { }
+       /** Constructor
+        * @param mod Module that owns the IOHookProvider
+        * @param Name Name of the provider
+        * @param hooktype One of IOHookProvider::Type
+        * @param middle True if the IOHook instances created by this hook are subclasses of IOHookMiddle, false otherwise
+        */
+       IOHookProvider(Module* mod, const std::string& Name, Type hooktype = IOH_UNKNOWN, bool middle = false)
+               : ServiceProvider(mod, Name, SERVICE_IOHOOK), middlehook(middle), type(hooktype) { }
+
+       /** Check if the IOHook provided can appear in the non-last position of a hook chain.
+        * That is the case if and only if the IOHook instances created are subclasses of IOHookMiddle.
+        * @return True if the IOHooks provided are subclasses of IOHookMiddle
+        */
+       bool IsMiddle() const { return middlehook; }
 
-       /** Called immediately after a connection is accepted. This is intended for raw socket
-        * processing (e.g. modules which wrap the tcp connection within another library) and provides
-        * no information relating to a user record as the connection has not been assigned yet.
-        * @param sock The socket in question
-        * @param client The client IP address and port
-        * @param server The server IP address and port
+       /** Called when the provider should hook an incoming connection and act as being on the server side of the connection.
+        * This occurs when a bind block has a hook configured and the listener accepts a connection.
+        * @param sock Socket to hook
+        * @param client Client IP address and port
+        * @param server Server IP address and port
         */
        virtual void OnAccept(StreamSocket* sock, irc::sockets::sockaddrs* client, irc::sockets::sockaddrs* server) = 0;
 
-       /** Called immediately upon connection of an outbound BufferedSocket which has been hooked
-        * by a module.
-        * @param sock The socket in question
+       /** Called when the provider should hook an outgoing connection and act as being on the client side of the connection.
+        * @param sock Socket to hook
         */
        virtual void OnConnect(StreamSocket* sock) = 0;
 };
@@ -59,30 +71,97 @@ class IOHook : public classbase
         */
        IOHookProvider* const prov;
 
+       /** Constructor
+        * @param provider IOHookProvider that creates this object
+        */
        IOHook(IOHookProvider* provider)
                : prov(provider) { }
 
        /**
-        * Called when a hooked stream has data to write, or when the socket
-        * engine returns it as writable
-        * @param sock The socket in question
+        * Called when the hooked socket has data to write, or when the socket engine returns it as writable
+        * @param sock Hooked socket
+        * @param sendq Send queue to send data from
         * @return 1 if the sendq has been completely emptied, 0 if there is
         *  still data to send, and -1 if there was an error
         */
-       virtual int OnStreamSocketWrite(StreamSocket* sock) = 0;
+       virtual int OnStreamSocketWrite(StreamSocket* sock, StreamSocket::SendQueue& sendq) = 0;
 
-       /** Called immediately before any socket is closed. When this event is called, shutdown()
+       /** Called immediately before the hooked socket is closed. When this event is called, shutdown()
         * has not yet been called on the socket.
-        * @param sock The socket in question
+        * @param sock Hooked socket
         */
        virtual void OnStreamSocketClose(StreamSocket* sock) = 0;
 
        /**
-        * Called when the stream socket has data to read
-        * @param sock The socket that is ready
+        * Called when the hooked socket has data to read
+        * @param sock Hooked socket
         * @param recvq The receive queue that new data should be appended to
         * @return 1 if new data has been read, 0 if no new data is ready (but the
         *  socket is still connected), -1 if there was an error or close
         */
        virtual int OnStreamSocketRead(StreamSocket* sock, std::string& recvq) = 0;
 };
+
+class IOHookMiddle : public IOHook
+{
+       /** Data already processed by the IOHook waiting to go down the chain
+        */
+       StreamSocket::SendQueue sendq;
+
+       /** Data waiting to go up the chain
+        */
+       std::string precvq;
+
+       /** Next IOHook in the chain
+        */
+       IOHook* nexthook;
+
+ protected:
+       /** Get all queued up data which has not yet been passed up the hook chain
+        * @return RecvQ containing the data
+        */
+       std::string& GetRecvQ() { return precvq; }
+
+       /** Get all queued up data which is ready to go down the hook chain
+        * @return SendQueue containing all data waiting to go down the hook chain
+        */
+       StreamSocket::SendQueue& GetSendQ() { return sendq; }
+
+ public:
+       /** Constructor
+        * @param provider IOHookProvider that creates this object
+        */
+       IOHookMiddle(IOHookProvider* provider)
+               : IOHook(provider)
+               , nexthook(NULL)
+       {
+       }
+
+       /** Get all queued up data which is ready to go down the hook chain
+        * @return SendQueue containing all data waiting to go down the hook chain
+        */
+       const StreamSocket::SendQueue& GetSendQ() const { return sendq; }
+
+       /** Get the next IOHook in the chain
+        * @return Next hook in the chain or NULL if this is the last hook
+        */
+       IOHook* GetNextHook() const { return nexthook; }
+
+       /** Set the next hook in the chain
+        * @param hook Hook to set as the next hook in the chain
+        */
+       void SetNextHook(IOHook* hook) { nexthook = hook; }
+
+       /** Check if a hook is capable of being the non-last hook in a hook chain and if so, cast it to an IOHookMiddle object.
+        * @param hook IOHook to check
+        * @return IOHookMiddle referring to the same hook or NULL
+        */
+       static IOHookMiddle* ToMiddleHook(IOHook* hook)
+       {
+               if (hook->prov->IsMiddle())
+                       return static_cast<IOHookMiddle*>(hook);
+               return NULL;
+       }
+
+       friend class StreamSocket;
+};
index 9d69b5d22cf8e70267e9764888bcb4e9bb26e035..427ee9fe7e3f7823aad489f5b253ba81523a9149 100644 (file)
@@ -127,7 +127,6 @@ namespace irc
        }
 }
 
-#include "iohook.h"
 #include "socketengine.h"
 /** This class handles incoming connections on client ports.
  * It will create a new User for every valid connection
@@ -142,10 +141,21 @@ class CoreExport ListenSocket : public EventHandler
        /** Human-readable bind description */
        std::string bind_desc;
 
-       /** The IOHook provider which handles connections on this socket,
-        * NULL if there is none.
+       class IOHookProvRef : public dynamic_reference_nocheck<IOHookProvider>
+       {
+        public:
+               IOHookProvRef()
+                       : dynamic_reference_nocheck<IOHookProvider>(NULL, std::string())
+               {
+               }
+       };
+
+       typedef TR1NS::array<IOHookProvRef, 2> IOHookProvList;
+
+       /** IOHook providers for handling connections on this socket,
+        * may be empty.
         */
-       dynamic_reference_nocheck<IOHookProvider> iohookprov;
+       IOHookProvList iohookprovs;
 
        /** Create a new listening socket
         */
@@ -160,7 +170,6 @@ class CoreExport ListenSocket : public EventHandler
 
        /** Inspects the bind block belonging to this socket to set the name of the IO hook
         * provider which this socket will use for incoming connections.
-        * @return True if the IO hook provider was found or none was given, false otherwise.
         */
-       bool ResetIOHookProvider();
+       void ResetIOHookProvider();
 };
index 17c05d704a83b9e570e661388ad08b15c5a46a0c..48842ccf0fb442be842b7ffd8e66ec8abd05ec4b 100644 (file)
@@ -31,6 +31,7 @@ class Extensible;
 class FakeUser;
 class InspIRCd;
 class Invitation;
+class IOHookProvider;
 class LocalUser;
 class Membership;
 class Module;
index edf08b8e363c41fdd460d7094191e00fd8eff101..e2cbd9f64f0cd2ea586bcffe169166d49c61e8c9 100644 (file)
 
 #include <iostream>
 #if defined _LIBCPP_VERSION
+# include <array>
 # include <type_traits>
 # include <unordered_map>
 #else
+# include <tr1/array>
 # include <tr1/type_traits>
 # include <tr1/unordered_map>
 #endif
index 89c3a71a93a691850b70b882f5db0fddb8db8679..9bfc6a73e33104bf0ff64d1df9c4b80000aa98e4 100644 (file)
 #include "inspircd.h"
 #include "iohook.h"
 
+static IOHook* GetNextHook(IOHook* hook)
+{
+       IOHookMiddle* const iohm = IOHookMiddle::ToMiddleHook(hook);
+       if (iohm)
+               return iohm->GetNextHook();
+       return NULL;
+}
+
 BufferedSocket::BufferedSocket()
 {
        Timeout = NULL;
@@ -112,11 +120,15 @@ void StreamSocket::Close()
        {
                // final chance, dump as much of the sendq as we can
                DoWrite();
-               if (GetIOHook())
+
+               IOHook* hook = GetIOHook();
+               DelIOHook();
+               while (hook)
                {
-                       GetIOHook()->OnStreamSocketClose(this);
-                       delete iohook;
-                       DelIOHook();
+                       hook->OnStreamSocketClose(this);
+                       IOHook* const nexthook = GetNextHook(hook);
+                       delete hook;
+                       hook = nexthook;
                }
                SocketEngine::Shutdown(this, 2);
                SocketEngine::Close(this);
@@ -139,51 +151,74 @@ bool StreamSocket::GetNextLine(std::string& line, char delim)
        return true;
 }
 
-void StreamSocket::DoRead()
+int StreamSocket::HookChainRead(IOHook* hook, std::string& rq)
 {
-       if (GetIOHook())
+       if (!hook)
+               return ReadToRecvQ(rq);
+
+       IOHookMiddle* const iohm = IOHookMiddle::ToMiddleHook(hook);
+       if (iohm)
        {
-               int rv = GetIOHook()->OnStreamSocketRead(this, recvq);
-               if (rv > 0)
-                       OnDataReady();
-               if (rv < 0)
-                       SetError("Read Error"); // will not overwrite a better error message
+               // Call the next hook to put data into the recvq of the current hook
+               const int ret = HookChainRead(iohm->GetNextHook(), iohm->GetRecvQ());
+               if (ret <= 0)
+                       return ret;
        }
-       else
+       return hook->OnStreamSocketRead(this, rq);
+}
+
+void StreamSocket::DoRead()
+{
+       const std::string::size_type prevrecvqsize = recvq.size();
+
+       const int result = HookChainRead(GetIOHook(), recvq);
+       if (result < 0)
        {
+               SetError("Read Error"); // will not overwrite a better error message
+               return;
+       }
+
+       if (recvq.size() > prevrecvqsize)
+               OnDataReady();
+}
+
+int StreamSocket::ReadToRecvQ(std::string& rq)
+{
                char* ReadBuffer = ServerInstance->GetReadBuffer();
                int n = SocketEngine::Recv(this, ReadBuffer, ServerInstance->Config->NetBufferSize, 0);
                if (n == ServerInstance->Config->NetBufferSize)
                {
                        SocketEngine::ChangeEventMask(this, FD_WANT_FAST_READ | FD_ADD_TRIAL_READ);
-                       recvq.append(ReadBuffer, n);
-                       OnDataReady();
+                       rq.append(ReadBuffer, n);
                }
                else if (n > 0)
                {
                        SocketEngine::ChangeEventMask(this, FD_WANT_FAST_READ);
-                       recvq.append(ReadBuffer, n);
-                       OnDataReady();
+                       rq.append(ReadBuffer, n);
                }
                else if (n == 0)
                {
                        error = "Connection closed";
                        SocketEngine::ChangeEventMask(this, FD_WANT_NO_READ | FD_WANT_NO_WRITE);
+                       return -1;
                }
                else if (SocketEngine::IgnoreError())
                {
                        SocketEngine::ChangeEventMask(this, FD_WANT_FAST_READ | FD_READ_WILL_BLOCK);
+                       return 0;
                }
                else if (errno == EINTR)
                {
                        SocketEngine::ChangeEventMask(this, FD_WANT_FAST_READ | FD_ADD_TRIAL_READ);
+                       return 0;
                }
                else
                {
                        error = SocketEngine::LastError();
                        SocketEngine::ChangeEventMask(this, FD_WANT_NO_READ | FD_WANT_NO_WRITE);
+                       return -1;
                }
-       }
+       return n;
 }
 
 /* Don't try to prepare huge blobs of data to send to a blocked socket */
@@ -191,7 +226,7 @@ static const int MYIOV_MAX = IOV_MAX < 128 ? IOV_MAX : 128;
 
 void StreamSocket::DoWrite()
 {
-       if (sendq.empty())
+       if (getSendQSize() == 0)
                return;
        if (!error.empty() || fd < 0)
        {
@@ -199,26 +234,48 @@ void StreamSocket::DoWrite()
                return;
        }
 
-       if (GetIOHook())
+       SendQueue* psendq = &sendq;
+       IOHook* hook = GetIOHook();
+       while (hook)
        {
-               int rv = GetIOHook()->OnStreamSocketWrite(this);
-               if (rv < 0)
-                       SetError("Write Error"); // will not overwrite a better error message
+               int rv = hook->OnStreamSocketWrite(this, *psendq);
+               psendq = NULL;
 
                // rv == 0 means the socket has blocked. Stop trying to send data.
                // IOHook has requested unblock notification from the socketengine.
+               if (rv == 0)
+                       break;
+
+               if (rv < 0)
+               {
+                       SetError("Write Error"); // will not overwrite a better error message
+                       break;
+               }
+
+               IOHookMiddle* const iohm = IOHookMiddle::ToMiddleHook(hook);
+               hook = NULL;
+               if (iohm)
+               {
+                       psendq = &iohm->GetSendQ();
+                       hook = iohm->GetNextHook();
+               }
        }
-       else
-       {
+
+       if (psendq)
+               FlushSendQ(*psendq);
+}
+
+void StreamSocket::FlushSendQ(SendQueue& sq)
+{
                // don't even try if we are known to be blocking
                if (GetEventMask() & FD_WRITE_WILL_BLOCK)
                        return;
                // start out optimistic - we won't need to write any more
                int eventChange = FD_WANT_EDGE_WRITE;
-               while (error.empty() && !sendq.empty() && eventChange == FD_WANT_EDGE_WRITE)
+               while (error.empty() && !sq.empty() && eventChange == FD_WANT_EDGE_WRITE)
                {
                        // Prepare a writev() call to write all buffers efficiently
-                       int bufcount = sendq.size();
+                       int bufcount = sq.size();
 
                        // cap the number of buffers at MYIOV_MAX
                        if (bufcount > MYIOV_MAX)
@@ -231,7 +288,7 @@ void StreamSocket::DoWrite()
                        {
                                SocketEngine::IOVector iovecs[MYIOV_MAX];
                                size_t j = 0;
-                               for (SendQueue::const_iterator i = sendq.begin(), end = i+bufcount; i != end; ++i, j++)
+                               for (SendQueue::const_iterator i = sq.begin(), end = i+bufcount; i != end; ++i, j++)
                                {
                                        const SendQueue::Element& elem = *i;
                                        iovecs[j].iov_base = const_cast<char*>(elem.data());
@@ -241,11 +298,11 @@ void StreamSocket::DoWrite()
                                rv = SocketEngine::WriteV(this, iovecs, bufcount);
                        }
 
-                       if (rv == (int)sendq.bytes())
+                       if (rv == (int)sq.bytes())
                        {
                                // it's our lucky day, everything got written out. Fast cleanup.
                                // This won't ever happen if the number of buffers got capped.
-                               sendq.clear();
+                               sq.clear();
                        }
                        else if (rv > 0)
                        {
@@ -255,19 +312,19 @@ void StreamSocket::DoWrite()
                                        // it's going to block now
                                        eventChange = FD_WANT_FAST_WRITE | FD_WRITE_WILL_BLOCK;
                                }
-                               while (rv > 0 && !sendq.empty())
+                               while (rv > 0 && !sq.empty())
                                {
-                                       const SendQueue::Element& front = sendq.front();
+                                       const SendQueue::Element& front = sq.front();
                                        if (front.length() <= (size_t)rv)
                                        {
                                                // this string got fully written out
                                                rv -= front.length();
-                                               sendq.pop_front();
+                                               sq.pop_front();
                                        }
                                        else
                                        {
                                                // stopped in the middle of this string
-                                               sendq.erase_front(rv);
+                                               sq.erase_front(rv);
                                                rv = 0;
                                        }
                                }
@@ -299,7 +356,6 @@ void StreamSocket::DoWrite()
                {
                        SocketEngine::ChangeEventMask(this, eventChange);
                }
-       }
 }
 
 void StreamSocket::WriteData(const std::string &data)
@@ -434,3 +490,50 @@ void StreamSocket::CheckError(BufferedSocketError errcode)
                OnError(errcode);
        }
 }
+
+IOHook* StreamSocket::GetModHook(Module* mod) const
+{
+       for (IOHook* curr = GetIOHook(); curr; curr = GetNextHook(curr))
+       {
+               if (curr->prov->creator == mod)
+                       return curr;
+       }
+       return NULL;
+}
+
+void StreamSocket::AddIOHook(IOHook* newhook)
+{
+       IOHook* curr = GetIOHook();
+       if (!curr)
+       {
+               iohook = newhook;
+               return;
+       }
+
+       IOHookMiddle* lasthook;
+       while (curr)
+       {
+               lasthook = IOHookMiddle::ToMiddleHook(curr);
+               if (!lasthook)
+                       return;
+               curr = lasthook->GetNextHook();
+       }
+
+       lasthook->SetNextHook(newhook);
+}
+
+size_t StreamSocket::getSendQSize() const
+{
+       size_t ret = sendq.bytes();
+       IOHook* curr = GetIOHook();
+       while (curr)
+       {
+               const IOHookMiddle* const iohm = IOHookMiddle::ToMiddleHook(curr);
+               if (!iohm)
+                       break;
+
+               ret += iohm->GetSendQ().bytes();
+               curr = iohm->GetNextHook();
+       }
+       return ret;
+}
index fa43e6827240b8ed113b55003a356c6a51507b84..fb9f2a0eff6e313ef01caced913462fa1033d6e1 100644 (file)
@@ -19,6 +19,7 @@
 
 
 #include "inspircd.h"
+#include "iohook.h"
 
 #ifndef _WIN32
 #include <netinet/tcp.h>
@@ -26,7 +27,6 @@
 
 ListenSocket::ListenSocket(ConfigTag* tag, const irc::sockets::sockaddrs& bind_to)
        : bind_tag(tag)
-       , iohookprov(NULL, std::string())
 {
        irc::sockets::satoap(bind_to, bind_addr, bind_port);
        bind_desc = bind_to.str();
@@ -178,15 +178,23 @@ void ListenSocket::OnEventHandlerRead()
        }
 }
 
-bool ListenSocket::ResetIOHookProvider()
+void ListenSocket::ResetIOHookProvider()
 {
+       iohookprovs[0].SetProvider(bind_tag->getString("hook"));
+
+       // Check that all non-last hooks support being in the middle
+       for (IOHookProvList::iterator i = iohookprovs.begin(); i != iohookprovs.end()-1; ++i)
+       {
+               IOHookProvRef& curr = *i;
+               // Ignore if cannot be in the middle
+               if ((curr) && (!curr->IsMiddle()))
+                       curr.SetProvider(std::string());
+       }
+
        std::string provname = bind_tag->getString("ssl");
        if (!provname.empty())
                provname.insert(0, "ssl/");
 
-       // Set the new provider name, dynref handles the rest
-       iohookprov.SetProvider(provname);
-
-       // Return true if no provider was set, or one was set and it was also found
-       return (provname.empty() || iohookprov);
+       // SSL should be the last
+       iohookprovs.back().SetProvider(provname);
 }
index a1c989163ff0fea98012ac1913e3a89ce0964695..bda4e6a4878410b7c0ce8ee3c0c8c43b2520c4c6 100644 (file)
@@ -101,6 +101,8 @@ typedef gnutls_connection_end_t inspircd_gnutls_session_init_flags_t;
 #define INSPIRCD_GNUTLS_HAS_CORK
 #endif
 
+static Module* thismod;
+
 class RandGen : public HandlerBase2<void, char*, size_t>
 {
  public:
@@ -581,16 +583,21 @@ namespace GnuTLS
                 */
                const unsigned int outrecsize;
 
+               /** True to request a client certificate as a server
+                */
+               const bool requestclientcert;
+
                Profile(const std::string& profilename, const std::string& certstr, const std::string& keystr,
                                std::auto_ptr<DHParams>& DH, unsigned int mindh, const std::string& hashstr,
                                const std::string& priostr, std::auto_ptr<X509CertList>& CA, std::auto_ptr<X509CRL>& CRL,
-                               unsigned int recsize)
+                               unsigned int recsize, bool Requestclientcert)
                        : name(profilename)
                        , x509cred(certstr, keystr)
                        , min_dh_bits(mindh)
                        , hash(hashstr)
                        , priority(priostr)
                        , outrecsize(recsize)
+                       , requestclientcert(Requestclientcert)
                {
                        x509cred.SetDH(DH);
                        x509cred.SetCA(CA, CRL);
@@ -661,7 +668,10 @@ namespace GnuTLS
 #else
                        unsigned int outrecsize = tag->getInt("outrecsize", 2048, 512, 16384);
 #endif
-                       return new Profile(profilename, certstr, keystr, dh, mindh, hashstr, priostr, ca, crl, outrecsize);
+
+                       const bool requestclientcert = tag->getBool("requestclientcert", true);
+
+                       return new Profile(profilename, certstr, keystr, dh, mindh, hashstr, priostr, ca, crl, outrecsize, requestclientcert);
                }
 
                /** Set up the given session with the settings in this profile
@@ -672,8 +682,9 @@ namespace GnuTLS
                        x509cred.SetupSession(sess);
                        gnutls_dh_set_prime_bits(sess, min_dh_bits);
 
-                       // Request client certificate if we are a server, no-op if we're a client
-                       gnutls_certificate_server_set_request(sess, GNUTLS_CERT_REQUEST);
+                       // Request client certificate if enabled and we are a server, no-op if we're a client
+                       if (requestclientcert)
+                               gnutls_certificate_server_set_request(sess, GNUTLS_CERT_REQUEST);
                }
 
                const std::string& GetName() const { return name; }
@@ -917,7 +928,7 @@ info_done_dealloc:
        {
                StreamSocket* sock = reinterpret_cast<StreamSocket*>(session_wrap);
 #ifdef _WIN32
-               GnuTLSIOHook* session = static_cast<GnuTLSIOHook*>(sock->GetIOHook());
+               GnuTLSIOHook* session = static_cast<GnuTLSIOHook*>(sock->GetModHook(thismod));
 #endif
 
                if (sock->GetEventMask() & FD_READ_WILL_BLOCK)
@@ -954,7 +965,7 @@ info_done_dealloc:
        {
                StreamSocket* sock = reinterpret_cast<StreamSocket*>(transportptr);
 #ifdef _WIN32
-               GnuTLSIOHook* session = static_cast<GnuTLSIOHook*>(sock->GetIOHook());
+               GnuTLSIOHook* session = static_cast<GnuTLSIOHook*>(sock->GetModHook(thismod));
 #endif
 
                if (sock->GetEventMask() & FD_WRITE_WILL_BLOCK)
@@ -989,7 +1000,7 @@ info_done_dealloc:
        {
                StreamSocket* sock = reinterpret_cast<StreamSocket*>(session_wrap);
 #ifdef _WIN32
-               GnuTLSIOHook* session = static_cast<GnuTLSIOHook*>(sock->GetIOHook());
+               GnuTLSIOHook* session = static_cast<GnuTLSIOHook*>(sock->GetModHook(thismod));
 #endif
 
                if (sock->GetEventMask() & FD_WRITE_WILL_BLOCK)
@@ -1086,7 +1097,7 @@ info_done_dealloc:
                }
        }
 
-       int OnStreamSocketWrite(StreamSocket* user) CXX11_OVERRIDE
+       int OnStreamSocketWrite(StreamSocket* user, StreamSocket::SendQueue& sendq) CXX11_OVERRIDE
        {
                // Finish handshake if needed
                int prepret = PrepareIO(user);
@@ -1094,7 +1105,6 @@ info_done_dealloc:
                        return prepret;
 
                // Session is ready for transferring application data
-               StreamSocket::SendQueue& sendq = user->GetSendQ();
 
 #ifdef INSPIRCD_GNUTLS_HAS_CORK
                while (true)
@@ -1173,7 +1183,7 @@ int GnuTLS::X509Credentials::cert_callback(gnutls_session_t sess, const gnutls_d
        st->key_type = GNUTLS_PRIVKEY_X509;
 #endif
        StreamSocket* sock = reinterpret_cast<StreamSocket*>(gnutls_transport_get_ptr(sess));
-       GnuTLS::X509Credentials& cred = static_cast<GnuTLSIOHook*>(sock->GetIOHook())->GetProfile()->GetX509Credentials();
+       GnuTLS::X509Credentials& cred = static_cast<GnuTLSIOHook*>(sock->GetModHook(thismod))->GetProfile()->GetX509Credentials();
 
        st->ncerts = cred.certs.size();
        st->cert.x509 = cred.certs.raw();
@@ -1283,6 +1293,7 @@ class ModuleSSLGnuTLS : public Module
 #ifndef GNUTLS_HAS_RND
                gcry_control (GCRYCTL_INITIALIZATION_FINISHED, 0);
 #endif
+               thismod = this;
        }
 
        void init() CXX11_OVERRIDE
@@ -1318,7 +1329,7 @@ class ModuleSSLGnuTLS : public Module
                {
                        LocalUser* user = IS_LOCAL(static_cast<User*>(item));
 
-                       if (user && user->eh.GetIOHook() && user->eh.GetIOHook()->prov->creator == this)
+                       if ((user) && (user->eh.GetModHook(this)))
                        {
                                // User is using SSL, they're a local user, and they're using one of *our* SSL ports.
                                // Potentially there could be multiple SSL modules loaded at once on different ports.
@@ -1334,13 +1345,9 @@ class ModuleSSLGnuTLS : public Module
 
        ModResult OnCheckReady(LocalUser* user) CXX11_OVERRIDE
        {
-               if ((user->eh.GetIOHook()) && (user->eh.GetIOHook()->prov->creator == this))
-               {
-                       GnuTLSIOHook* iohook = static_cast<GnuTLSIOHook*>(user->eh.GetIOHook());
-                       if (!iohook->IsHandshakeDone())
-                               return MOD_RES_DENY;
-               }
-
+               const GnuTLSIOHook* const iohook = static_cast<GnuTLSIOHook*>(user->eh.GetModHook(this));
+               if ((iohook) && (!iohook->IsHandshakeDone()))
+                       return MOD_RES_DENY;
                return MOD_RES_PASSTHRU;
        }
 };
index 8578b8196e593ee5eb8b0f2dfa1e3971015a0334..a465d06eef1d9c449313ac5ca3ad051413353109 100644 (file)
@@ -257,7 +257,6 @@ namespace mbedTLS
                        mbedtls_debug_set_threshold(INT_MAX);
                        mbedtls_ssl_conf_dbg(&conf, DebugLogFunc, NULL);
 #endif
-                       mbedtls_ssl_conf_authmode(&conf, MBEDTLS_SSL_VERIFY_OPTIONAL);
 
                        // TODO: check ret of mbedtls_ssl_config_defaults
                        mbedtls_ssl_config_defaults(&conf, endpoint, MBEDTLS_SSL_TRANSPORT_STREAM, MBEDTLS_SSL_PRESET_DEFAULT);
@@ -308,6 +307,11 @@ namespace mbedTLS
                        mbedtls_ssl_conf_ca_chain(&conf, certs.get(), crl.get());
                }
 
+               void SetOptionalVerifyCert()
+               {
+                       mbedtls_ssl_conf_authmode(&conf, MBEDTLS_SSL_VERIFY_OPTIONAL);
+               }
+
                const mbedtls_ssl_config* GetConf() const { return &conf; }
        };
 
@@ -376,7 +380,8 @@ namespace mbedTLS
                                const std::string& castr, const std::string& crlstr,
                                unsigned int recsize,
                                CTRDRBG& ctrdrbg,
-                               int minver, int maxver
+                               int minver, int maxver,
+                               bool requestclientcert
                                )
                        : name(profilename)
                        , x509cred(certstr, keystr)
@@ -414,7 +419,13 @@ namespace mbedTLS
                                serverctx.SetDHParams(dhparams);
                        }
 
-                       serverctx.SetCA(cacerts, crl);
+                       clientctx.SetOptionalVerifyCert();
+                       // The default for servers is to not request a client certificate from the peer
+                       if (requestclientcert)
+                       {
+                               serverctx.SetOptionalVerifyCert();
+                               serverctx.SetCA(cacerts, crl);
+                       }
                }
 
                static std::string ReadFile(const std::string& filename)
@@ -451,7 +462,8 @@ namespace mbedTLS
                        int minver = tag->getInt("minver");
                        int maxver = tag->getInt("maxver");
                        unsigned int outrecsize = tag->getInt("outrecsize", 2048, 512, 16384);
-                       return new Profile(profilename, certstr, keystr, dhstr, mindh, hashstr, ciphersuitestr, curvestr, castr, crlstr, outrecsize, ctr_drbg, minver, maxver);
+                       const bool requestclientcert = tag->getBool("requestclientcert", true);
+                       return new Profile(profilename, certstr, keystr, dhstr, mindh, hashstr, ciphersuitestr, curvestr, castr, crlstr, outrecsize, ctr_drbg, minver, maxver, requestclientcert);
                }
 
                /** Set up the given session with the settings in this profile
@@ -698,7 +710,7 @@ class mbedTLSIOHook : public SSLIOHook
                }
        }
 
-       int OnStreamSocketWrite(StreamSocket* sock) CXX11_OVERRIDE
+       int OnStreamSocketWrite(StreamSocket* sock, StreamSocket::SendQueue& sendq) CXX11_OVERRIDE
        {
                // Finish handshake if needed
                int prepret = PrepareIO(sock);
@@ -706,7 +718,6 @@ class mbedTLSIOHook : public SSLIOHook
                        return prepret;
 
                // Session is ready for transferring application data
-               StreamSocket::SendQueue& sendq = sock->GetSendQ();
                while (!sendq.empty())
                {
                        FlattenSendQueue(sendq, profile->GetOutgoingRecordSize());
@@ -895,7 +906,7 @@ class ModuleSSLmbedTLS : public Module
                        return;
 
                LocalUser* user = IS_LOCAL(static_cast<User*>(item));
-               if ((user) && (user->eh.GetIOHook()) && (user->eh.GetIOHook()->prov->creator == this))
+               if ((user) && (user->eh.GetModHook(this)))
                {
                        // User is using SSL, they're a local user, and they're using our IOHook.
                        // Potentially there could be multiple SSL modules loaded at once on different ports.
@@ -905,13 +916,9 @@ class ModuleSSLmbedTLS : public Module
 
        ModResult OnCheckReady(LocalUser* user) CXX11_OVERRIDE
        {
-               if ((user->eh.GetIOHook()) && (user->eh.GetIOHook()->prov->creator == this))
-               {
-                       mbedTLSIOHook* iohook = static_cast<mbedTLSIOHook*>(user->eh.GetIOHook());
-                       if (!iohook->IsHandshakeDone())
-                               return MOD_RES_DENY;
-               }
-
+               const mbedTLSIOHook* const iohook = static_cast<mbedTLSIOHook*>(user->eh.GetModHook(this));
+               if ((iohook) && (!iohook->IsHandshakeDone()))
+                       return MOD_RES_DENY;
                return MOD_RES_PASSTHRU;
        }
 
index 80c9d93959116a6b21e8476aff48ac3f6c4fb632..4df0d8962e96b98367d1fc3d339bb2109eae9c5a 100644 (file)
@@ -132,7 +132,7 @@ namespace OpenSSL
                        mode |= SSL_MODE_RELEASE_BUFFERS;
 #endif
                        SSL_CTX_set_mode(ctx, mode);
-                       SSL_CTX_set_verify(ctx, SSL_VERIFY_PEER | SSL_VERIFY_CLIENT_ONCE, OnVerify);
+                       SSL_CTX_set_verify(ctx, SSL_VERIFY_NONE, NULL);
                        SSL_CTX_set_session_cache_mode(ctx, SSL_SESS_CACHE_OFF);
                        SSL_CTX_set_info_callback(ctx, StaticSSLInfoCallback);
                }
@@ -206,6 +206,11 @@ namespace OpenSSL
                        return SSL_CTX_clear_options(ctx, clearoptions);
                }
 
+               void SetVerifyCert()
+               {
+                       SSL_CTX_set_verify(ctx, SSL_VERIFY_PEER | SSL_VERIFY_CLIENT_ONCE, OnVerify);
+               }
+
                SSL* CreateServerSession()
                {
                        SSL* sess = SSL_new(ctx);
@@ -345,6 +350,10 @@ namespace OpenSSL
                                ERR_print_errors_cb(error_callback, this);
                                ServerInstance->Logs->Log(MODNAME, LOG_DEFAULT, "Can't read CA list from %s. This is only a problem if you want to verify client certificates, otherwise it's safe to ignore this message. Error: %s", filename.c_str(), lasterr.c_str());
                        }
+
+                       clictx.SetVerifyCert();
+                       if (tag->getBool("requestclientcert", true))
+                               ctx.SetVerifyCert();
                }
 
                const std::string& GetName() const { return name; }
@@ -656,7 +665,7 @@ class OpenSSLIOHook : public SSLIOHook
                }
        }
 
-       int OnStreamSocketWrite(StreamSocket* user) CXX11_OVERRIDE
+       int OnStreamSocketWrite(StreamSocket* user, StreamSocket::SendQueue& sendq) CXX11_OVERRIDE
        {
                // Finish handshake if needed
                int prepret = PrepareIO(user);
@@ -666,7 +675,6 @@ class OpenSSLIOHook : public SSLIOHook
                data_to_write = true;
 
                // Session is ready for transferring application data
-               StreamSocket::SendQueue& sendq = user->GetSendQ();
                while (!sendq.empty())
                {
                        ERR_clear_error();
@@ -910,7 +918,7 @@ class ModuleSSLOpenSSL : public Module
                {
                        LocalUser* user = IS_LOCAL((User*)item);
 
-                       if (user && user->eh.GetIOHook() && user->eh.GetIOHook()->prov->creator == this)
+                       if ((user) && (user->eh.GetModHook(this)))
                        {
                                // User is using SSL, they're a local user, and they're using one of *our* SSL ports.
                                // Potentially there could be multiple SSL modules loaded at once on different ports.
@@ -921,13 +929,9 @@ class ModuleSSLOpenSSL : public Module
 
        ModResult OnCheckReady(LocalUser* user) CXX11_OVERRIDE
        {
-               if ((user->eh.GetIOHook()) && (user->eh.GetIOHook()->prov->creator == this))
-               {
-                       OpenSSLIOHook* iohook = static_cast<OpenSSLIOHook*>(user->eh.GetIOHook());
-                       if (!iohook->IsHandshakeDone())
-                               return MOD_RES_DENY;
-               }
-
+               const OpenSSLIOHook* const iohook = static_cast<OpenSSLIOHook*>(user->eh.GetModHook(this));
+               if ((iohook) && (!iohook->IsHandshakeDone()))
+                       return MOD_RES_DENY;
                return MOD_RES_PASSTHRU;
        }
 
index 760647d47f1ae796636e8de03bedf1045cf3ae6c..64bef70d11ad3959a4323f0a1efedf5de2ab07d9 100644 (file)
@@ -78,8 +78,8 @@ class HttpServerSocket : public BufferedSocket, public Timer, public insp::intru
        {
                ServerInstance->Timers.AddTimer(this);
 
-               if (via->iohookprov)
-                       via->iohookprov->OnAccept(this, client, server);
+               if ((!via->iohookprovs.empty()) && (via->iohookprovs.back()))
+                       via->iohookprovs.back()->OnAccept(this, client, server);
        }
 
        ~HttpServerSocket()
@@ -413,7 +413,7 @@ class ModuleHttpServer : public Module
                {
                        HttpServerSocket* sock = *i;
                        ++i;
-                       if (sock->GetIOHook() && sock->GetIOHook()->prov->creator == mod)
+                       if (sock->GetModHook(mod))
                        {
                                sock->cull();
                                delete sock;
diff --git a/src/modules/m_sha1.cpp b/src/modules/m_sha1.cpp
new file mode 100644 (file)
index 0000000..5926e49
--- /dev/null
@@ -0,0 +1,199 @@
+/*
+ * InspIRCd -- Internet Relay Chat Daemon
+ *
+ *   Copyright (C) 2016 Attila Molnar <attilamolnar@hush.com>
+ *
+ * This file is part of InspIRCd.  InspIRCd is free software: you can
+ * redistribute it and/or modify it under the terms of the GNU General Public
+ * License as published by the Free Software Foundation, version 2.
+ *
+ * This program is distributed in the hope that it will be useful, but WITHOUT
+ * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
+ * FOR A PARTICULAR PURPOSE.  See the GNU General Public License for more
+ * details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program.  If not, see <http://www.gnu.org/licenses/>.
+ */
+
+/*
+SHA-1 in C
+By Steve Reid <steve@edmweb.com>
+100% Public Domain
+*/
+
+#include "inspircd.h"
+#include "modules/hash.h"
+
+union CHAR64LONG16
+{
+       unsigned char c[64];
+       uint32_t l[16];
+};
+
+inline static uint32_t rol(uint32_t value, uint32_t bits) { return (value << bits) | (value >> (32 - bits)); }
+
+// blk0() and blk() perform the initial expand.
+// I got the idea of expanding during the round function from SSLeay
+static bool big_endian;
+inline static uint32_t blk0(CHAR64LONG16& block, uint32_t i)
+{
+       if (big_endian)
+               return block.l[i];
+       else
+               return block.l[i] = (rol(block.l[i], 24) & 0xFF00FF00) | (rol(block.l[i], 8) & 0x00FF00FF);
+}
+inline static uint32_t blk(CHAR64LONG16 &block, uint32_t i) { return block.l[i & 15] = rol(block.l[(i + 13) & 15] ^ block.l[(i + 8) & 15] ^ block.l[(i + 2) & 15] ^ block.l[i & 15],1); }
+
+// (R0+R1), R2, R3, R4 are the different operations used in SHA1
+inline static void R0(CHAR64LONG16& block, uint32_t v, uint32_t &w, uint32_t x, uint32_t y, uint32_t &z, uint32_t i) { z += ((w & (x ^ y)) ^ y) + blk0(block, i) + 0x5A827999 + rol(v, 5); w = rol(w, 30); }
+inline static void R1(CHAR64LONG16& block, uint32_t v, uint32_t &w, uint32_t x, uint32_t y, uint32_t &z, uint32_t i) { z += ((w & (x ^ y)) ^ y) + blk(block, i) + 0x5A827999 + rol(v, 5); w = rol(w, 30); }
+inline static void R2(CHAR64LONG16& block, uint32_t v, uint32_t &w, uint32_t x, uint32_t y, uint32_t &z, uint32_t i) { z += (w ^ x ^ y) + blk(block, i) + 0x6ED9EBA1 + rol(v, 5); w = rol(w, 30); }
+inline static void R3(CHAR64LONG16& block, uint32_t v, uint32_t &w, uint32_t x, uint32_t y, uint32_t &z, uint32_t i) { z += (((w | x) & y) | (w & x)) + blk(block, i) + 0x8F1BBCDC + rol(v, 5); w = rol(w, 30); }
+inline static void R4(CHAR64LONG16& block, uint32_t v, uint32_t &w, uint32_t x, uint32_t y, uint32_t &z, uint32_t i) { z += (w ^ x ^ y) + blk(block, i) + 0xCA62C1D6 + rol(v, 5); w = rol(w, 30); }
+
+static const uint32_t sha1_iv[5] =
+{
+       0x67452301, 0xEFCDAB89, 0x98BADCFE, 0x10325476, 0xC3D2E1F0
+};
+
+class SHA1Context
+{
+       uint32_t state[5];
+       uint32_t count[2];
+       unsigned char buffer[64];
+       unsigned char digest[20];
+
+       void Transform(const unsigned char buf[64])
+       {
+               uint32_t a, b, c, d, e;
+
+               CHAR64LONG16 block;
+               memcpy(block.c, buf, 64);
+
+               // Copy state[] to working vars
+               a = this->state[0];
+               b = this->state[1];
+               c = this->state[2];
+               d = this->state[3];
+               e = this->state[4];
+
+               // 4 rounds of 20 operations each. Loop unrolled.
+               R0(block, a, b, c, d, e, 0); R0(block, e, a, b, c, d, 1); R0(block, d, e, a, b, c, 2); R0(block, c, d, e, a, b, 3);
+               R0(block, b, c, d, e, a, 4); R0(block, a, b, c, d, e, 5); R0(block, e, a, b, c, d, 6); R0(block, d, e, a, b, c, 7);
+               R0(block, c, d, e, a, b, 8); R0(block, b, c, d, e, a, 9); R0(block, a, b, c, d, e, 10); R0(block, e, a, b, c, d, 11);
+               R0(block, d, e, a, b, c, 12); R0(block, c, d, e, a, b, 13); R0(block, b, c, d, e, a, 14); R0(block, a, b, c, d, e, 15);
+               R1(block, e, a, b, c, d, 16); R1(block, d, e, a, b, c, 17); R1(block, c, d, e, a, b, 18); R1(block, b, c, d, e, a, 19);
+               R2(block, a, b, c, d, e, 20); R2(block, e, a, b, c, d, 21); R2(block, d, e, a, b, c, 22); R2(block, c, d, e, a, b, 23);
+               R2(block, b, c, d, e, a, 24); R2(block, a, b, c, d, e, 25); R2(block, e, a, b, c, d, 26); R2(block, d, e, a, b, c, 27);
+               R2(block, c, d, e, a, b, 28); R2(block, b, c, d, e, a, 29); R2(block, a, b, c, d, e, 30); R2(block, e, a, b, c, d, 31);
+               R2(block, d, e, a, b, c, 32); R2(block, c, d, e, a, b, 33); R2(block, b, c, d, e, a, 34); R2(block, a, b, c, d, e, 35);
+               R2(block, e, a, b, c, d, 36); R2(block, d, e, a, b, c, 37); R2(block, c, d, e, a, b, 38); R2(block, b, c, d, e, a, 39);
+               R3(block, a, b, c, d, e, 40); R3(block, e, a, b, c, d, 41); R3(block, d, e, a, b, c, 42); R3(block, c, d, e, a, b, 43);
+               R3(block, b, c, d, e, a, 44); R3(block, a, b, c, d, e, 45); R3(block, e, a, b, c, d, 46); R3(block, d, e, a, b, c, 47);
+               R3(block, c, d, e, a, b, 48); R3(block, b, c, d, e, a, 49); R3(block, a, b, c, d, e, 50); R3(block, e, a, b, c, d, 51);
+               R3(block, d, e, a, b, c, 52); R3(block, c, d, e, a, b, 53); R3(block, b, c, d, e, a, 54); R3(block, a, b, c, d, e, 55);
+               R3(block, e, a, b, c, d, 56); R3(block, d, e, a, b, c, 57); R3(block, c, d, e, a, b, 58); R3(block, b, c, d, e, a, 59);
+               R4(block, a, b, c, d, e, 60); R4(block, e, a, b, c, d, 61); R4(block, d, e, a, b, c, 62); R4(block, c, d, e, a, b, 63);
+               R4(block, b, c, d, e, a, 64); R4(block, a, b, c, d, e, 65); R4(block, e, a, b, c, d, 66); R4(block, d, e, a, b, c, 67);
+               R4(block, c, d, e, a, b, 68); R4(block, b, c, d, e, a, 69); R4(block, a, b, c, d, e, 70); R4(block, e, a, b, c, d, 71);
+               R4(block, d, e, a, b, c, 72); R4(block, c, d, e, a, b, 73); R4(block, b, c, d, e, a, 74); R4(block, a, b, c, d, e, 75);
+               R4(block, e, a, b, c, d, 76); R4(block, d, e, a, b, c, 77); R4(block, c, d, e, a, b, 78); R4(block, b, c, d, e, a, 79);
+               // Add the working vars back into state[]
+               this->state[0] += a;
+               this->state[1] += b;
+               this->state[2] += c;
+               this->state[3] += d;
+               this->state[4] += e;
+       }
+
+ public:
+       SHA1Context()
+       {
+               for (int i = 0; i < 5; ++i)
+                       this->state[i] = sha1_iv[i];
+
+               this->count[0] = this->count[1] = 0;
+               memset(this->buffer, 0, sizeof(this->buffer));
+               memset(this->digest, 0, sizeof(this->digest));
+       }
+
+       void Update(const unsigned char* data, size_t len)
+       {
+               uint32_t i, j;
+
+               j = (this->count[0] >> 3) & 63;
+               if ((this->count[0] += len << 3) < (len << 3))
+                       ++this->count[1];
+               this->count[1] += len >> 29;
+               if (j + len > 63)
+               {
+                       memcpy(&this->buffer[j], data, (i = 64 - j));
+                       this->Transform(this->buffer);
+                       for (; i + 63 < len; i += 64)
+                               this->Transform(&data[i]);
+                       j = 0;
+               }
+               else
+                       i = 0;
+               memcpy(&this->buffer[j], &data[i], len - i);
+       }
+
+       void Finalize()
+       {
+               uint32_t i;
+               unsigned char finalcount[8];
+
+               for (i = 0; i < 8; ++i)
+                       finalcount[i] = static_cast<unsigned char>((this->count[i >= 4 ? 0 : 1] >> ((3 - (i & 3)) * 8)) & 255); /* Endian independent */
+               this->Update(reinterpret_cast<const unsigned char *>("\200"), 1);
+               while ((this->count[0] & 504) != 448)
+                       this->Update(reinterpret_cast<const unsigned char *>("\0"), 1);
+               this->Update(finalcount, 8); // Should cause a SHA1Transform()
+               for (i = 0; i < 20; ++i)
+                       this->digest[i] = static_cast<unsigned char>((this->state[i>>2] >> ((3 - (i & 3)) * 8)) & 255);
+
+               this->Transform(this->buffer);
+       }
+
+       std::string GetRaw() const
+       {
+               return std::string((const char*)digest, sizeof(digest));
+       }
+};
+
+class SHA1HashProvider : public HashProvider
+{
+ public:
+       SHA1HashProvider(Module* mod)
+               : HashProvider(mod, "hash/sha1", 20, 64)
+       {
+       }
+
+       std::string GenerateRaw(const std::string& data)
+       {
+               SHA1Context ctx;
+               ctx.Update(reinterpret_cast<const unsigned char*>(data.data()), data.length());
+               ctx.Finalize();
+               return ctx.GetRaw();
+       }
+};
+
+class ModuleSHA1 : public Module
+{
+       SHA1HashProvider sha1;
+
+ public:
+       ModuleSHA1()
+               : sha1(this)
+       {
+               big_endian = (htonl(1337) == 1337);
+       }
+
+       Version GetVersion()
+       {
+               return Version("Implements SHA-1 hashing", VF_VENDOR);
+       }
+};
+
+MODULE_INIT(ModuleSHA1)
index 0b9bb65dff166cb8ecd113bfaad3f7807cc66bcd..81543b0da1693bbaf52bb004bae9e50132f3d075 100644 (file)
@@ -635,7 +635,7 @@ restart:
        for (TreeServer::ChildServers::const_iterator i = list.begin(); i != list.end(); ++i)
        {
                TreeSocket* sock = (*i)->GetSocket();
-               if (sock->GetIOHook() && sock->GetIOHook()->prov->creator == mod)
+               if (sock->GetModHook(mod))
                {
                        sock->SendError("SSL module unloaded");
                        sock->Close();
@@ -647,7 +647,7 @@ restart:
        for (SpanningTreeUtilities::TimeoutList::const_iterator i = Utils->timeoutlist.begin(); i != Utils->timeoutlist.end(); ++i)
        {
                TreeSocket* sock = i->first;
-               if (sock->GetIOHook() && sock->GetIOHook()->prov->creator == mod)
+               if (sock->GetModHook(mod))
                        sock->Close();
        }
 }
index 2198d6208c1f6927f2e386eae57dd25543047710..e1642a086041bb54bb9b0eb9d1e62c8287f40864 100644 (file)
@@ -60,8 +60,13 @@ TreeSocket::TreeSocket(int newfd, ListenSocket* via, irc::sockets::sockaddrs* cl
        capab = new CapabData;
        capab->capab_phase = 0;
 
-       if (via->iohookprov)
-               via->iohookprov->OnAccept(this, client, server);
+       for (ListenSocket::IOHookProvList::iterator i = via->iohookprovs.begin(); i != via->iohookprovs.end(); ++i)
+       {
+               ListenSocket::IOHookProvRef& iohookprovref = *i;
+               if (iohookprovref)
+                       iohookprovref->OnAccept(this, client, server);
+       }
+
        SendCapabilities(1);
 
        Utils->timeoutlist[this] = std::pair<std::string, int>(linkID, 30);
diff --git a/src/modules/m_websocket.cpp b/src/modules/m_websocket.cpp
new file mode 100644 (file)
index 0000000..399b0b0
--- /dev/null
@@ -0,0 +1,405 @@
+/*
+ * InspIRCd -- Internet Relay Chat Daemon
+ *
+ *   Copyright (C) 2016 Attila Molnar <attilamolnar@hush.com>
+ *
+ * This file is part of InspIRCd.  InspIRCd is free software: you can
+ * redistribute it and/or modify it under the terms of the GNU General Public
+ * License as published by the Free Software Foundation, version 2.
+ *
+ * This program is distributed in the hope that it will be useful, but WITHOUT
+ * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
+ * FOR A PARTICULAR PURPOSE.  See the GNU General Public License for more
+ * details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program.  If not, see <http://www.gnu.org/licenses/>.
+ */
+
+
+#include "inspircd.h"
+#include "iohook.h"
+#include "modules/hash.h"
+
+static const char MagicGUID[] = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
+static const char whitespace[] = " \t\r\n";
+static dynamic_reference_nocheck<HashProvider>* sha1;
+
+class WebSocketHookProvider : public IOHookProvider
+{
+ public:
+       WebSocketHookProvider(Module* mod)
+               : IOHookProvider(mod, "websocket", IOHookProvider::IOH_UNKNOWN, true)
+       {
+       }
+
+       void OnAccept(StreamSocket* sock, irc::sockets::sockaddrs* client, irc::sockets::sockaddrs* server) CXX11_OVERRIDE;
+
+       void OnConnect(StreamSocket* sock) CXX11_OVERRIDE
+       {
+       }
+};
+
+class WebSocketHook : public IOHookMiddle
+{
+       class HTTPHeaderFinder
+       {
+               std::string::size_type bpos;
+               std::string::size_type len;
+
+        public:
+               bool Find(const std::string& req, const char* header, std::string::size_type headerlen, std::string::size_type maxpos)
+               {
+                       std::string::size_type keybegin = req.find(header);
+                       if ((keybegin == std::string::npos) || (keybegin > maxpos) || (keybegin == 0) || (req[keybegin-1] != '\n'))
+                               return false;
+
+                       keybegin += headerlen;
+
+                       bpos = req.find_first_not_of(whitespace, keybegin, sizeof(whitespace)-1);
+                       if ((bpos == std::string::npos) || (bpos > maxpos))
+                               return false;
+
+                       const std::string::size_type epos = req.find_first_of(whitespace, bpos, sizeof(whitespace)-1);
+                       len = epos - bpos;
+
+                       return true;
+               }
+
+               std::string ExtractValue(const std::string& req) const
+               {
+                       return std::string(req, bpos, len);
+               }
+       };
+
+       enum OpCode
+       {
+               OP_CONTINUATION = 0x00,
+               OP_TEXT = 0x01,
+               OP_BINARY = 0x02,
+               OP_CLOSE = 0x08,
+               OP_PING = 0x09,
+               OP_PONG = 0x0a
+       };
+
+       enum State
+       {
+               STATE_HTTPREQ,
+               STATE_ESTABLISHED
+       };
+
+       static const unsigned char WS_MASKBIT = (1 << 7);
+       static const unsigned char WS_FINBIT = (1 << 7);
+       static const unsigned char WS_PAYLOAD_LENGTH_MAGIC_LARGE = 126;
+       static const unsigned char WS_PAYLOAD_LENGTH_MAGIC_HUGE = 127;
+       static const size_t WS_MAX_PAYLOAD_LENGTH_SMALL = 125;
+       static const size_t WS_MAX_PAYLOAD_LENGTH_LARGE = 65535;
+       static const size_t MAXHEADERSIZE = sizeof(uint64_t) + 2;
+
+       // Clients sending ping or pong frames faster than this are killed
+       static const time_t MINPINGPONGDELAY = 10;
+
+       State state;
+       time_t lastpingpong;
+
+       static size_t FillHeader(unsigned char* outbuf, size_t sendlength, OpCode opcode)
+       {
+               size_t pos = 0;
+               outbuf[pos++] = WS_FINBIT | opcode;
+
+               if (sendlength <= WS_MAX_PAYLOAD_LENGTH_SMALL)
+               {
+                       outbuf[pos++] = sendlength;
+               }
+               else if (sendlength <= WS_MAX_PAYLOAD_LENGTH_LARGE)
+               {
+                       outbuf[pos++] = WS_PAYLOAD_LENGTH_MAGIC_LARGE;
+                       outbuf[pos++] = (sendlength >> 8) & 0xff;
+                       outbuf[pos++] = sendlength & 0xff;
+               }
+               else
+               {
+                       outbuf[pos++] = WS_PAYLOAD_LENGTH_MAGIC_HUGE;
+                       const uint64_t len = sendlength;
+                       for (int i = sizeof(uint64_t)-1; i >= 0; i--)
+                               outbuf[pos++] = ((len >> i*8) & 0xff);
+               }
+
+               return pos;
+       }
+
+       static StreamSocket::SendQueue::Element PrepareSendQElem(size_t size, OpCode opcode)
+       {
+               unsigned char header[MAXHEADERSIZE];
+               const size_t n = FillHeader(header, size, opcode);
+
+               return StreamSocket::SendQueue::Element(reinterpret_cast<const char*>(header), n);
+       }
+
+       int HandleAppData(StreamSocket* sock, std::string& appdataout, bool allowlarge)
+       {
+               std::string& myrecvq = GetRecvQ();
+               // Need 1 byte opcode, minimum 1 byte len, 4 bytes masking key
+               if (myrecvq.length() < 6)
+                       return 0;
+
+               const std::string& cmyrecvq = myrecvq;
+               unsigned char len1 = (unsigned char)cmyrecvq[1];
+               if (!(len1 & WS_MASKBIT))
+               {
+                       sock->SetError("WebSocket protocol violation: unmasked client frame");
+                       return -1;
+               }
+
+               len1 &= ~WS_MASKBIT;
+
+               // Assume the length is a single byte, if not, update values later
+               unsigned int len = len1;
+               unsigned int payloadstartoffset = 6;
+               const unsigned char* maskkey = reinterpret_cast<const unsigned char*>(&cmyrecvq[2]);
+
+               if (len1 == WS_PAYLOAD_LENGTH_MAGIC_LARGE)
+               {
+                       // allowlarge is false for control frames according to the RFC meaning large pings, etc. are not allowed
+                       if (!allowlarge)
+                       {
+                               sock->SetError("WebSocket protocol violation: large control frame");
+                               return -1;
+                       }
+
+                       // Large frame, has 2 bytes len after the magic byte indicating the length
+                       // Need 1 byte opcode, 3 bytes len, 4 bytes masking key
+                       if (myrecvq.length() < 8)
+                               return 0;
+
+                       unsigned char len2 = (unsigned char)cmyrecvq[2];
+                       unsigned char len3 = (unsigned char)cmyrecvq[3];
+                       len = (len2 << 8) | len3;
+
+                       if (len <= WS_MAX_PAYLOAD_LENGTH_SMALL)
+                       {
+                               sock->SetError("WebSocket protocol violation: non-minimal length encoding used");
+                               return -1;
+                       }
+
+                       maskkey += 2;
+                       payloadstartoffset += 2;
+               }
+               else if (len1 == WS_PAYLOAD_LENGTH_MAGIC_HUGE)
+               {
+                       sock->SetError("WebSocket: Huge frames are not supported");
+                       return -1;
+               }
+
+               if (myrecvq.length() < payloadstartoffset + len)
+                       return 0;
+
+               unsigned int maskkeypos = 0;
+               const std::string::iterator endit = myrecvq.begin() + payloadstartoffset + len;
+               for (std::string::const_iterator i = myrecvq.begin() + payloadstartoffset; i != endit; ++i)
+               {
+                       const unsigned char c = (unsigned char)*i;
+                       appdataout.push_back(c ^ maskkey[maskkeypos++]);
+                       maskkeypos %= 4;
+               }
+
+               myrecvq.erase(myrecvq.begin(), endit);
+               return 1;
+       }
+
+       int HandlePingPongFrame(StreamSocket* sock, bool isping)
+       {
+               if (lastpingpong + MINPINGPONGDELAY >= ServerInstance->Time())
+               {
+                       sock->SetError("WebSocket: Ping/pong flood");
+                       return -1;
+               }
+
+               lastpingpong = ServerInstance->Time();
+
+               std::string appdata;
+               const int result = HandleAppData(sock, appdata, false);
+               // If it's a pong stop here regardless of the result so we won't generate a reply
+               if ((result <= 0) || (!isping))
+                       return result;
+
+               StreamSocket::SendQueue::Element elem = PrepareSendQElem(appdata.length(), OP_PONG);
+               elem.append(appdata);
+               GetSendQ().push_back(elem);
+
+               SocketEngine::ChangeEventMask(sock, FD_ADD_TRIAL_WRITE);
+               return 1;
+       }
+
+       int HandleWS(StreamSocket* sock, std::string& destrecvq)
+       {
+               if (GetRecvQ().empty())
+                       return 0;
+
+               unsigned char opcode = (unsigned char)GetRecvQ().c_str()[0];
+               opcode &= ~WS_FINBIT;
+
+               switch (opcode)
+               {
+                       case OP_CONTINUATION:
+                       case OP_TEXT:
+                       case OP_BINARY:
+                       {
+                               return HandleAppData(sock, destrecvq, true);
+                       }
+
+                       case OP_PING:
+                       {
+                               return HandlePingPongFrame(sock, true);
+                       }
+
+                       case OP_PONG:
+                       {
+                               // A pong frame may be sent unsolicited, so we have to handle it.
+                               // It may carry application data which we need to remove from the recvq as well.
+                               return HandlePingPongFrame(sock, false);
+                       }
+
+                       case OP_CLOSE:
+                       {
+                               sock->SetError("Connection closed");
+                               return -1;
+                       }
+
+                       default:
+                       {
+                               sock->SetError("WebSocket: Invalid opcode");
+                               return -1;
+                       }
+               }
+       }
+
+       void FailHandshake(StreamSocket* sock, const char* httpreply, const char* sockerror)
+       {
+               GetSendQ().push_back(StreamSocket::SendQueue::Element(httpreply));
+               sock->DoWrite();
+               sock->SetError(sockerror);
+       }
+
+       int HandleHTTPReq(StreamSocket* sock)
+       {
+               std::string& recvq = GetRecvQ();
+               const std::string::size_type reqend = recvq.find("\r\n\r\n");
+               if (reqend == std::string::npos)
+                       return 0;
+
+               HTTPHeaderFinder keyheader;
+               if (!keyheader.Find(recvq, "Sec-WebSocket-Key:", 18, reqend))
+               {
+                       FailHandshake(sock, "HTTP/1.1 501 Not Implemented\r\nConnection: close\r\n\r\n", "WebSocket: Received HTTP request which is not a websocket upgrade");
+                       return -1;
+               }
+
+               if (!*sha1)
+               {
+                       FailHandshake(sock, "HTTP/1.1 503 Service Unavailable\r\nConnection: close\r\n\r\n", "WebSocket: SHA-1 provider missing");
+                       return -1;
+               }
+
+               state = STATE_ESTABLISHED;
+
+               std::string key = keyheader.ExtractValue(recvq);
+               key.append(MagicGUID);
+
+               std::string reply = "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: ";
+               reply.append(BinToBase64((*sha1)->GenerateRaw(key), NULL, '=')).append("\r\n\r\n");
+               GetSendQ().push_back(StreamSocket::SendQueue::Element(reply));
+
+               SocketEngine::ChangeEventMask(sock, FD_ADD_TRIAL_WRITE);
+
+               recvq.erase(0, reqend + 4);
+
+               return 1;
+       }
+
+ public:
+       WebSocketHook(IOHookProvider* Prov, StreamSocket* sock)
+               : IOHookMiddle(Prov)
+               , state(STATE_HTTPREQ)
+               , lastpingpong(0)
+       {
+               sock->AddIOHook(this);
+       }
+
+       int OnStreamSocketWrite(StreamSocket* sock, StreamSocket::SendQueue& uppersendq) CXX11_OVERRIDE
+       {
+               StreamSocket::SendQueue& mysendq = GetSendQ();
+
+               // Return 1 to allow sending back an error HTTP response
+               if (state != STATE_ESTABLISHED)
+                       return (mysendq.empty() ? 0 : 1);
+
+               if (!uppersendq.empty())
+               {
+                       StreamSocket::SendQueue::Element elem = PrepareSendQElem(uppersendq.bytes(), OP_BINARY);
+                       mysendq.push_back(elem);
+                       mysendq.moveall(uppersendq);
+               }
+
+               return 1;
+       }
+
+       int OnStreamSocketRead(StreamSocket* sock, std::string& destrecvq) CXX11_OVERRIDE
+       {
+               if (state == STATE_HTTPREQ)
+               {
+                       int httpret = HandleHTTPReq(sock);
+                       if (httpret <= 0)
+                               return httpret;
+               }
+
+               int wsret;
+               do
+               {
+                       wsret = HandleWS(sock, destrecvq);
+               }
+               while ((!GetRecvQ().empty()) && (wsret > 0));
+
+               return wsret;
+       }
+
+       void OnStreamSocketClose(StreamSocket* sock) CXX11_OVERRIDE
+       {
+       }
+};
+
+void WebSocketHookProvider::OnAccept(StreamSocket* sock, irc::sockets::sockaddrs* client, irc::sockets::sockaddrs* server)
+{
+       new WebSocketHook(this, sock);
+}
+
+class ModuleWebSocket : public Module
+{
+       dynamic_reference_nocheck<HashProvider> hash;
+       WebSocketHookProvider hookprov;
+
+ public:
+       ModuleWebSocket()
+               : hash(this, "hash/sha1")
+               , hookprov(this)
+       {
+               sha1 = &hash;
+       }
+
+       void OnCleanup(int target_type, void* item) CXX11_OVERRIDE
+       {
+               if (target_type != TYPE_USER)
+                       return;
+
+               LocalUser* user = IS_LOCAL(static_cast<User*>(item));
+               if ((user) && (user->eh.GetModHook(this)))
+                       ServerInstance->Users.QuitUser(user, "WebSocket module unloading");
+       }
+
+       Version GetVersion() CXX11_OVERRIDE
+       {
+               return Version("Provides RFC 6455 WebSocket support", VF_VENDOR);
+       }
+};
+
+MODULE_INIT(ModuleWebSocket)
index fe052fcfc90f517c477772e8849f8aa15b8acb37..95deca00a9a356f01a301bc9b135f485d2084a83 100644 (file)
@@ -72,8 +72,12 @@ void UserManager::AddUser(int socket, ListenSocket* via, irc::sockets::sockaddrs
        UserIOHandler* eh = &New->eh;
 
        // If this listener has an IO hook provider set then tell it about the connection
-       if (via->iohookprov)
-               via->iohookprov->OnAccept(eh, client, server);
+       for (ListenSocket::IOHookProvList::iterator i = via->iohookprovs.begin(); i != via->iohookprovs.end(); ++i)
+       {
+               ListenSocket::IOHookProvRef& iohookprovref = *i;
+               if (iohookprovref)
+                       iohookprovref->OnAccept(eh, client, server);
+       }
 
        ServerInstance->Logs->Log("USERS", LOG_DEBUG, "New user fd: %d", socket);