diff options
-rw-r--r-- | include/inspsocket.h | 13 | ||||
-rw-r--r-- | include/iohook.h | 76 | ||||
-rw-r--r-- | src/inspsocket.cpp | 128 |
3 files changed, 186 insertions, 31 deletions
diff --git a/include/inspsocket.h b/include/inspsocket.h index 0c5f3b3af..77d79bd95 100644 --- a/include/inspsocket.h +++ b/include/inspsocket.h @@ -240,6 +240,16 @@ class CoreExport StreamSocket : public EventHandler */ int ReadToRecvQ(std::string& rq); + /** Read data from a hook chain recursively, starting at 'hook'. + * If 'hook' is NULL, the recvq is filled with data from SocketEngine::Recv(), otherwise it is filled with data from the + * next hook in the chain. + * @param hook Next IOHook in the chain, can be NULL + * @param rq Receive queue to put incoming data into + * @return < 0 on error or close, 0 if no new data is ready (but the socket is still connected), > 0 if data was read from + the socket and put into the recvq + */ + int HookChainRead(IOHook* hook, std::string& rq); + protected: std::string recvq; public: @@ -286,7 +296,7 @@ class CoreExport StreamSocket : public EventHandler */ bool GetNextLine(std::string& line, char delim = '\n'); /** Useful for implementing sendq exceeded */ - size_t getSendQSize() const { return sendq.size(); } + size_t getSendQSize() const; SendQueue& GetSendQ() { return sendq; } @@ -376,5 +386,4 @@ class CoreExport BufferedSocket : public StreamSocket }; inline IOHook* StreamSocket::GetIOHook() const { return iohook; } -inline void StreamSocket::AddIOHook(IOHook* hook) { iohook = hook; } inline void StreamSocket::DelIOHook() { iohook = NULL; } diff --git a/include/iohook.h b/include/iohook.h index 576307963..f23648658 100644 --- a/include/iohook.h +++ b/include/iohook.h @@ -23,6 +23,8 @@ class StreamSocket; class IOHookProvider : public ServiceProvider { + const bool middlehook; + public: enum Type { @@ -32,8 +34,14 @@ class IOHookProvider : public ServiceProvider const Type type; - IOHookProvider(Module* mod, const std::string& Name, Type hooktype = IOH_UNKNOWN) - : ServiceProvider(mod, Name, SERVICE_IOHOOK), type(hooktype) { } + IOHookProvider(Module* mod, const std::string& Name, Type hooktype = IOH_UNKNOWN, bool middle = false) + : ServiceProvider(mod, Name, SERVICE_IOHOOK), middlehook(middle), type(hooktype) { } + + /** Check if the IOHook provided can appear in the non-last position of a hook chain. + * That is the case if and only if the IOHook instances created are subclasses of IOHookMiddle. + * @return True if the IOHooks provided are subclasses of IOHookMiddle + */ + bool IsMiddle() const { return middlehook; } /** Called immediately after a connection is accepted. This is intended for raw socket * processing (e.g. modules which wrap the tcp connection within another library) and provides @@ -87,3 +95,67 @@ class IOHook : public classbase */ virtual int OnStreamSocketRead(StreamSocket* sock, std::string& recvq) = 0; }; + +class IOHookMiddle : public IOHook +{ + /** Data already processed by the IOHook waiting to go down the chain + */ + StreamSocket::SendQueue sendq; + + /** Data waiting to go up the chain + */ + std::string precvq; + + /** Next IOHook in the chain + */ + IOHook* nexthook; + + protected: + /** Get all queued up data which has not yet been passed up the hook chain + * @return RecvQ containing the data + */ + std::string& GetRecvQ() { return precvq; } + + /** Get all queued up data which is ready to go down the hook chain + * @return SendQueue containing all data waiting to go down the hook chain + */ + StreamSocket::SendQueue& GetSendQ() { return sendq; } + + public: + /** Constructor + * @param provider IOHookProvider that creates this object + */ + IOHookMiddle(IOHookProvider* provider) + : IOHook(provider) + , nexthook(NULL) + { + } + + /** Get all queued up data which is ready to go down the hook chain + * @return SendQueue containing all data waiting to go down the hook chain + */ + const StreamSocket::SendQueue& GetSendQ() const { return sendq; } + + /** Get the next IOHook in the chain + * @return Next hook in the chain or NULL if this is the last hook + */ + IOHook* GetNextHook() const { return nexthook; } + + /** Set the next hook in the chain + * @param hook Hook to set as the next hook in the chain + */ + void SetNextHook(IOHook* hook) { nexthook = hook; } + + /** Check if a hook is capable of being the non-last hook in a hook chain and if so, cast it to an IOHookMiddle object. + * @param hook IOHook to check + * @return IOHookMiddle referring to the same hook or NULL + */ + static IOHookMiddle* ToMiddleHook(IOHook* hook) + { + if (hook->prov->IsMiddle()) + return static_cast<IOHookMiddle*>(hook); + return NULL; + } + + friend class StreamSocket; +}; diff --git a/src/inspsocket.cpp b/src/inspsocket.cpp index 89144dee0..9bfc6a73e 100644 --- a/src/inspsocket.cpp +++ b/src/inspsocket.cpp @@ -25,6 +25,14 @@ #include "inspircd.h" #include "iohook.h" +static IOHook* GetNextHook(IOHook* hook) +{ + IOHookMiddle* const iohm = IOHookMiddle::ToMiddleHook(hook); + if (iohm) + return iohm->GetNextHook(); + return NULL; +} + BufferedSocket::BufferedSocket() { Timeout = NULL; @@ -112,11 +120,15 @@ void StreamSocket::Close() { // final chance, dump as much of the sendq as we can DoWrite(); - if (GetIOHook()) + + IOHook* hook = GetIOHook(); + DelIOHook(); + while (hook) { - GetIOHook()->OnStreamSocketClose(this); - delete iohook; - DelIOHook(); + hook->OnStreamSocketClose(this); + IOHook* const nexthook = GetNextHook(hook); + delete hook; + hook = nexthook; } SocketEngine::Shutdown(this, 2); SocketEngine::Close(this); @@ -139,22 +151,31 @@ bool StreamSocket::GetNextLine(std::string& line, char delim) return true; } -void StreamSocket::DoRead() +int StreamSocket::HookChainRead(IOHook* hook, std::string& rq) { - const std::string::size_type prevrecvqsize = recvq.size(); + if (!hook) + return ReadToRecvQ(rq); - if (GetIOHook()) + IOHookMiddle* const iohm = IOHookMiddle::ToMiddleHook(hook); + if (iohm) { - int rv = GetIOHook()->OnStreamSocketRead(this, recvq); - if (rv < 0) - { - SetError("Read Error"); // will not overwrite a better error message - return; - } + // Call the next hook to put data into the recvq of the current hook + const int ret = HookChainRead(iohm->GetNextHook(), iohm->GetRecvQ()); + if (ret <= 0) + return ret; } - else + return hook->OnStreamSocketRead(this, rq); +} + +void StreamSocket::DoRead() +{ + const std::string::size_type prevrecvqsize = recvq.size(); + + const int result = HookChainRead(GetIOHook(), recvq); + if (result < 0) { - ReadToRecvQ(recvq); + SetError("Read Error"); // will not overwrite a better error message + return; } if (recvq.size() > prevrecvqsize) @@ -205,7 +226,7 @@ static const int MYIOV_MAX = IOV_MAX < 128 ? IOV_MAX : 128; void StreamSocket::DoWrite() { - if (sendq.empty()) + if (getSendQSize() == 0) return; if (!error.empty() || fd < 0) { @@ -213,19 +234,35 @@ void StreamSocket::DoWrite() return; } - if (GetIOHook()) + SendQueue* psendq = &sendq; + IOHook* hook = GetIOHook(); + while (hook) { - int rv = GetIOHook()->OnStreamSocketWrite(this, sendq); - if (rv < 0) - SetError("Write Error"); // will not overwrite a better error message + int rv = hook->OnStreamSocketWrite(this, *psendq); + psendq = NULL; // rv == 0 means the socket has blocked. Stop trying to send data. // IOHook has requested unblock notification from the socketengine. + if (rv == 0) + break; + + if (rv < 0) + { + SetError("Write Error"); // will not overwrite a better error message + break; + } + + IOHookMiddle* const iohm = IOHookMiddle::ToMiddleHook(hook); + hook = NULL; + if (iohm) + { + psendq = &iohm->GetSendQ(); + hook = iohm->GetNextHook(); + } } - else - { - FlushSendQ(sendq); - } + + if (psendq) + FlushSendQ(*psendq); } void StreamSocket::FlushSendQ(SendQueue& sq) @@ -456,10 +493,47 @@ void StreamSocket::CheckError(BufferedSocketError errcode) IOHook* StreamSocket::GetModHook(Module* mod) const { - if (iohook) + for (IOHook* curr = GetIOHook(); curr; curr = GetNextHook(curr)) { - if (iohook->prov->creator == mod) - return iohook; + if (curr->prov->creator == mod) + return curr; } return NULL; } + +void StreamSocket::AddIOHook(IOHook* newhook) +{ + IOHook* curr = GetIOHook(); + if (!curr) + { + iohook = newhook; + return; + } + + IOHookMiddle* lasthook; + while (curr) + { + lasthook = IOHookMiddle::ToMiddleHook(curr); + if (!lasthook) + return; + curr = lasthook->GetNextHook(); + } + + lasthook->SetNextHook(newhook); +} + +size_t StreamSocket::getSendQSize() const +{ + size_t ret = sendq.bytes(); + IOHook* curr = GetIOHook(); + while (curr) + { + const IOHookMiddle* const iohm = IOHookMiddle::ToMiddleHook(curr); + if (!iohm) + break; + + ret += iohm->GetSendQ().bytes(); + curr = iohm->GetNextHook(); + } + return ret; +} |