]> git.netwichtig.de Git - user/henk/code/inspircd.git/commitdiff
Add support for multiple IOHooks per StreamSocket
authorAttila Molnar <attilamolnar@hush.com>
Mon, 8 Aug 2016 13:02:28 +0000 (15:02 +0200)
committerAttila Molnar <attilamolnar@hush.com>
Mon, 8 Aug 2016 13:02:28 +0000 (15:02 +0200)
include/inspsocket.h
include/iohook.h
src/inspsocket.cpp

index 0c5f3b3af8a140c81e7c292e10ffd5ab9ceab57a..77d79bd955ad7118ed1185f7aec77b15592ff4bf 100644 (file)
@@ -240,6 +240,16 @@ class CoreExport StreamSocket : public EventHandler
         */
        int ReadToRecvQ(std::string& rq);
 
+       /** Read data from a hook chain recursively, starting at 'hook'.
+        * If 'hook' is NULL, the recvq is filled with data from SocketEngine::Recv(), otherwise it is filled with data from the
+        * next hook in the chain.
+        * @param hook Next IOHook in the chain, can be NULL
+        * @param rq Receive queue to put incoming data into
+        * @return < 0 on error or close, 0 if no new data is ready (but the socket is still connected), > 0 if data was read from
+        the socket and put into the recvq
+        */
+       int HookChainRead(IOHook* hook, std::string& rq);
+
  protected:
        std::string recvq;
  public:
@@ -286,7 +296,7 @@ class CoreExport StreamSocket : public EventHandler
         */
        bool GetNextLine(std::string& line, char delim = '\n');
        /** Useful for implementing sendq exceeded */
-       size_t getSendQSize() const { return sendq.size(); }
+       size_t getSendQSize() const;
 
        SendQueue& GetSendQ() { return sendq; }
 
@@ -376,5 +386,4 @@ class CoreExport BufferedSocket : public StreamSocket
 };
 
 inline IOHook* StreamSocket::GetIOHook() const { return iohook; }
-inline void StreamSocket::AddIOHook(IOHook* hook) { iohook = hook; }
 inline void StreamSocket::DelIOHook() { iohook = NULL; }
index 57630796397485934253787ee022741bbff2cf32..f236486586091bd4e903e8518746a4f4f842a834 100644 (file)
@@ -23,6 +23,8 @@ class StreamSocket;
 
 class IOHookProvider : public ServiceProvider
 {
+       const bool middlehook;
+
  public:
        enum Type
        {
@@ -32,8 +34,14 @@ class IOHookProvider : public ServiceProvider
 
        const Type type;
 
-       IOHookProvider(Module* mod, const std::string& Name, Type hooktype = IOH_UNKNOWN)
-               : ServiceProvider(mod, Name, SERVICE_IOHOOK), type(hooktype) { }
+       IOHookProvider(Module* mod, const std::string& Name, Type hooktype = IOH_UNKNOWN, bool middle = false)
+               : ServiceProvider(mod, Name, SERVICE_IOHOOK), middlehook(middle), type(hooktype) { }
+
+       /** Check if the IOHook provided can appear in the non-last position of a hook chain.
+        * That is the case if and only if the IOHook instances created are subclasses of IOHookMiddle.
+        * @return True if the IOHooks provided are subclasses of IOHookMiddle
+        */
+       bool IsMiddle() const { return middlehook; }
 
        /** Called immediately after a connection is accepted. This is intended for raw socket
         * processing (e.g. modules which wrap the tcp connection within another library) and provides
@@ -87,3 +95,67 @@ class IOHook : public classbase
         */
        virtual int OnStreamSocketRead(StreamSocket* sock, std::string& recvq) = 0;
 };
+
+class IOHookMiddle : public IOHook
+{
+       /** Data already processed by the IOHook waiting to go down the chain
+        */
+       StreamSocket::SendQueue sendq;
+
+       /** Data waiting to go up the chain
+        */
+       std::string precvq;
+
+       /** Next IOHook in the chain
+        */
+       IOHook* nexthook;
+
+ protected:
+       /** Get all queued up data which has not yet been passed up the hook chain
+        * @return RecvQ containing the data
+        */
+       std::string& GetRecvQ() { return precvq; }
+
+       /** Get all queued up data which is ready to go down the hook chain
+        * @return SendQueue containing all data waiting to go down the hook chain
+        */
+       StreamSocket::SendQueue& GetSendQ() { return sendq; }
+
+ public:
+       /** Constructor
+        * @param provider IOHookProvider that creates this object
+        */
+       IOHookMiddle(IOHookProvider* provider)
+               : IOHook(provider)
+               , nexthook(NULL)
+       {
+       }
+
+       /** Get all queued up data which is ready to go down the hook chain
+        * @return SendQueue containing all data waiting to go down the hook chain
+        */
+       const StreamSocket::SendQueue& GetSendQ() const { return sendq; }
+
+       /** Get the next IOHook in the chain
+        * @return Next hook in the chain or NULL if this is the last hook
+        */
+       IOHook* GetNextHook() const { return nexthook; }
+
+       /** Set the next hook in the chain
+        * @param hook Hook to set as the next hook in the chain
+        */
+       void SetNextHook(IOHook* hook) { nexthook = hook; }
+
+       /** Check if a hook is capable of being the non-last hook in a hook chain and if so, cast it to an IOHookMiddle object.
+        * @param hook IOHook to check
+        * @return IOHookMiddle referring to the same hook or NULL
+        */
+       static IOHookMiddle* ToMiddleHook(IOHook* hook)
+       {
+               if (hook->prov->IsMiddle())
+                       return static_cast<IOHookMiddle*>(hook);
+               return NULL;
+       }
+
+       friend class StreamSocket;
+};
index 89144dee000b4293324713ba74ea96a15999dc08..9bfc6a73e33104bf0ff64d1df9c4b80000aa98e4 100644 (file)
 #include "inspircd.h"
 #include "iohook.h"
 
+static IOHook* GetNextHook(IOHook* hook)
+{
+       IOHookMiddle* const iohm = IOHookMiddle::ToMiddleHook(hook);
+       if (iohm)
+               return iohm->GetNextHook();
+       return NULL;
+}
+
 BufferedSocket::BufferedSocket()
 {
        Timeout = NULL;
@@ -112,11 +120,15 @@ void StreamSocket::Close()
        {
                // final chance, dump as much of the sendq as we can
                DoWrite();
-               if (GetIOHook())
+
+               IOHook* hook = GetIOHook();
+               DelIOHook();
+               while (hook)
                {
-                       GetIOHook()->OnStreamSocketClose(this);
-                       delete iohook;
-                       DelIOHook();
+                       hook->OnStreamSocketClose(this);
+                       IOHook* const nexthook = GetNextHook(hook);
+                       delete hook;
+                       hook = nexthook;
                }
                SocketEngine::Shutdown(this, 2);
                SocketEngine::Close(this);
@@ -139,22 +151,31 @@ bool StreamSocket::GetNextLine(std::string& line, char delim)
        return true;
 }
 
-void StreamSocket::DoRead()
+int StreamSocket::HookChainRead(IOHook* hook, std::string& rq)
 {
-       const std::string::size_type prevrecvqsize = recvq.size();
+       if (!hook)
+               return ReadToRecvQ(rq);
 
-       if (GetIOHook())
+       IOHookMiddle* const iohm = IOHookMiddle::ToMiddleHook(hook);
+       if (iohm)
        {
-               int rv = GetIOHook()->OnStreamSocketRead(this, recvq);
-               if (rv < 0)
-               {
-                       SetError("Read Error"); // will not overwrite a better error message
-                       return;
-               }
+               // Call the next hook to put data into the recvq of the current hook
+               const int ret = HookChainRead(iohm->GetNextHook(), iohm->GetRecvQ());
+               if (ret <= 0)
+                       return ret;
        }
-       else
+       return hook->OnStreamSocketRead(this, rq);
+}
+
+void StreamSocket::DoRead()
+{
+       const std::string::size_type prevrecvqsize = recvq.size();
+
+       const int result = HookChainRead(GetIOHook(), recvq);
+       if (result < 0)
        {
-               ReadToRecvQ(recvq);
+               SetError("Read Error"); // will not overwrite a better error message
+               return;
        }
 
        if (recvq.size() > prevrecvqsize)
@@ -205,7 +226,7 @@ static const int MYIOV_MAX = IOV_MAX < 128 ? IOV_MAX : 128;
 
 void StreamSocket::DoWrite()
 {
-       if (sendq.empty())
+       if (getSendQSize() == 0)
                return;
        if (!error.empty() || fd < 0)
        {
@@ -213,19 +234,35 @@ void StreamSocket::DoWrite()
                return;
        }
 
-       if (GetIOHook())
+       SendQueue* psendq = &sendq;
+       IOHook* hook = GetIOHook();
+       while (hook)
        {
-               int rv = GetIOHook()->OnStreamSocketWrite(this, sendq);
-               if (rv < 0)
-                       SetError("Write Error"); // will not overwrite a better error message
+               int rv = hook->OnStreamSocketWrite(this, *psendq);
+               psendq = NULL;
 
                // rv == 0 means the socket has blocked. Stop trying to send data.
                // IOHook has requested unblock notification from the socketengine.
+               if (rv == 0)
+                       break;
+
+               if (rv < 0)
+               {
+                       SetError("Write Error"); // will not overwrite a better error message
+                       break;
+               }
+
+               IOHookMiddle* const iohm = IOHookMiddle::ToMiddleHook(hook);
+               hook = NULL;
+               if (iohm)
+               {
+                       psendq = &iohm->GetSendQ();
+                       hook = iohm->GetNextHook();
+               }
        }
-       else
-       {
-               FlushSendQ(sendq);
-       }
+
+       if (psendq)
+               FlushSendQ(*psendq);
 }
 
 void StreamSocket::FlushSendQ(SendQueue& sq)
@@ -456,10 +493,47 @@ void StreamSocket::CheckError(BufferedSocketError errcode)
 
 IOHook* StreamSocket::GetModHook(Module* mod) const
 {
-       if (iohook)
+       for (IOHook* curr = GetIOHook(); curr; curr = GetNextHook(curr))
        {
-               if (iohook->prov->creator == mod)
-                       return iohook;
+               if (curr->prov->creator == mod)
+                       return curr;
        }
        return NULL;
 }
+
+void StreamSocket::AddIOHook(IOHook* newhook)
+{
+       IOHook* curr = GetIOHook();
+       if (!curr)
+       {
+               iohook = newhook;
+               return;
+       }
+
+       IOHookMiddle* lasthook;
+       while (curr)
+       {
+               lasthook = IOHookMiddle::ToMiddleHook(curr);
+               if (!lasthook)
+                       return;
+               curr = lasthook->GetNextHook();
+       }
+
+       lasthook->SetNextHook(newhook);
+}
+
+size_t StreamSocket::getSendQSize() const
+{
+       size_t ret = sendq.bytes();
+       IOHook* curr = GetIOHook();
+       while (curr)
+       {
+               const IOHookMiddle* const iohm = IOHookMiddle::ToMiddleHook(curr);
+               if (!iohm)
+                       break;
+
+               ret += iohm->GetSendQ().bytes();
+               curr = iohm->GetNextHook();
+       }
+       return ret;
+}