summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--include/inspsocket.h13
-rw-r--r--include/iohook.h76
-rw-r--r--src/inspsocket.cpp128
3 files changed, 186 insertions, 31 deletions
diff --git a/include/inspsocket.h b/include/inspsocket.h
index 0c5f3b3af..77d79bd95 100644
--- a/include/inspsocket.h
+++ b/include/inspsocket.h
@@ -240,6 +240,16 @@ class CoreExport StreamSocket : public EventHandler
*/
int ReadToRecvQ(std::string& rq);
+ /** Read data from a hook chain recursively, starting at 'hook'.
+ * If 'hook' is NULL, the recvq is filled with data from SocketEngine::Recv(), otherwise it is filled with data from the
+ * next hook in the chain.
+ * @param hook Next IOHook in the chain, can be NULL
+ * @param rq Receive queue to put incoming data into
+ * @return < 0 on error or close, 0 if no new data is ready (but the socket is still connected), > 0 if data was read from
+ the socket and put into the recvq
+ */
+ int HookChainRead(IOHook* hook, std::string& rq);
+
protected:
std::string recvq;
public:
@@ -286,7 +296,7 @@ class CoreExport StreamSocket : public EventHandler
*/
bool GetNextLine(std::string& line, char delim = '\n');
/** Useful for implementing sendq exceeded */
- size_t getSendQSize() const { return sendq.size(); }
+ size_t getSendQSize() const;
SendQueue& GetSendQ() { return sendq; }
@@ -376,5 +386,4 @@ class CoreExport BufferedSocket : public StreamSocket
};
inline IOHook* StreamSocket::GetIOHook() const { return iohook; }
-inline void StreamSocket::AddIOHook(IOHook* hook) { iohook = hook; }
inline void StreamSocket::DelIOHook() { iohook = NULL; }
diff --git a/include/iohook.h b/include/iohook.h
index 576307963..f23648658 100644
--- a/include/iohook.h
+++ b/include/iohook.h
@@ -23,6 +23,8 @@ class StreamSocket;
class IOHookProvider : public ServiceProvider
{
+ const bool middlehook;
+
public:
enum Type
{
@@ -32,8 +34,14 @@ class IOHookProvider : public ServiceProvider
const Type type;
- IOHookProvider(Module* mod, const std::string& Name, Type hooktype = IOH_UNKNOWN)
- : ServiceProvider(mod, Name, SERVICE_IOHOOK), type(hooktype) { }
+ IOHookProvider(Module* mod, const std::string& Name, Type hooktype = IOH_UNKNOWN, bool middle = false)
+ : ServiceProvider(mod, Name, SERVICE_IOHOOK), middlehook(middle), type(hooktype) { }
+
+ /** Check if the IOHook provided can appear in the non-last position of a hook chain.
+ * That is the case if and only if the IOHook instances created are subclasses of IOHookMiddle.
+ * @return True if the IOHooks provided are subclasses of IOHookMiddle
+ */
+ bool IsMiddle() const { return middlehook; }
/** Called immediately after a connection is accepted. This is intended for raw socket
* processing (e.g. modules which wrap the tcp connection within another library) and provides
@@ -87,3 +95,67 @@ class IOHook : public classbase
*/
virtual int OnStreamSocketRead(StreamSocket* sock, std::string& recvq) = 0;
};
+
+class IOHookMiddle : public IOHook
+{
+ /** Data already processed by the IOHook waiting to go down the chain
+ */
+ StreamSocket::SendQueue sendq;
+
+ /** Data waiting to go up the chain
+ */
+ std::string precvq;
+
+ /** Next IOHook in the chain
+ */
+ IOHook* nexthook;
+
+ protected:
+ /** Get all queued up data which has not yet been passed up the hook chain
+ * @return RecvQ containing the data
+ */
+ std::string& GetRecvQ() { return precvq; }
+
+ /** Get all queued up data which is ready to go down the hook chain
+ * @return SendQueue containing all data waiting to go down the hook chain
+ */
+ StreamSocket::SendQueue& GetSendQ() { return sendq; }
+
+ public:
+ /** Constructor
+ * @param provider IOHookProvider that creates this object
+ */
+ IOHookMiddle(IOHookProvider* provider)
+ : IOHook(provider)
+ , nexthook(NULL)
+ {
+ }
+
+ /** Get all queued up data which is ready to go down the hook chain
+ * @return SendQueue containing all data waiting to go down the hook chain
+ */
+ const StreamSocket::SendQueue& GetSendQ() const { return sendq; }
+
+ /** Get the next IOHook in the chain
+ * @return Next hook in the chain or NULL if this is the last hook
+ */
+ IOHook* GetNextHook() const { return nexthook; }
+
+ /** Set the next hook in the chain
+ * @param hook Hook to set as the next hook in the chain
+ */
+ void SetNextHook(IOHook* hook) { nexthook = hook; }
+
+ /** Check if a hook is capable of being the non-last hook in a hook chain and if so, cast it to an IOHookMiddle object.
+ * @param hook IOHook to check
+ * @return IOHookMiddle referring to the same hook or NULL
+ */
+ static IOHookMiddle* ToMiddleHook(IOHook* hook)
+ {
+ if (hook->prov->IsMiddle())
+ return static_cast<IOHookMiddle*>(hook);
+ return NULL;
+ }
+
+ friend class StreamSocket;
+};
diff --git a/src/inspsocket.cpp b/src/inspsocket.cpp
index 89144dee0..9bfc6a73e 100644
--- a/src/inspsocket.cpp
+++ b/src/inspsocket.cpp
@@ -25,6 +25,14 @@
#include "inspircd.h"
#include "iohook.h"
+static IOHook* GetNextHook(IOHook* hook)
+{
+ IOHookMiddle* const iohm = IOHookMiddle::ToMiddleHook(hook);
+ if (iohm)
+ return iohm->GetNextHook();
+ return NULL;
+}
+
BufferedSocket::BufferedSocket()
{
Timeout = NULL;
@@ -112,11 +120,15 @@ void StreamSocket::Close()
{
// final chance, dump as much of the sendq as we can
DoWrite();
- if (GetIOHook())
+
+ IOHook* hook = GetIOHook();
+ DelIOHook();
+ while (hook)
{
- GetIOHook()->OnStreamSocketClose(this);
- delete iohook;
- DelIOHook();
+ hook->OnStreamSocketClose(this);
+ IOHook* const nexthook = GetNextHook(hook);
+ delete hook;
+ hook = nexthook;
}
SocketEngine::Shutdown(this, 2);
SocketEngine::Close(this);
@@ -139,22 +151,31 @@ bool StreamSocket::GetNextLine(std::string& line, char delim)
return true;
}
-void StreamSocket::DoRead()
+int StreamSocket::HookChainRead(IOHook* hook, std::string& rq)
{
- const std::string::size_type prevrecvqsize = recvq.size();
+ if (!hook)
+ return ReadToRecvQ(rq);
- if (GetIOHook())
+ IOHookMiddle* const iohm = IOHookMiddle::ToMiddleHook(hook);
+ if (iohm)
{
- int rv = GetIOHook()->OnStreamSocketRead(this, recvq);
- if (rv < 0)
- {
- SetError("Read Error"); // will not overwrite a better error message
- return;
- }
+ // Call the next hook to put data into the recvq of the current hook
+ const int ret = HookChainRead(iohm->GetNextHook(), iohm->GetRecvQ());
+ if (ret <= 0)
+ return ret;
}
- else
+ return hook->OnStreamSocketRead(this, rq);
+}
+
+void StreamSocket::DoRead()
+{
+ const std::string::size_type prevrecvqsize = recvq.size();
+
+ const int result = HookChainRead(GetIOHook(), recvq);
+ if (result < 0)
{
- ReadToRecvQ(recvq);
+ SetError("Read Error"); // will not overwrite a better error message
+ return;
}
if (recvq.size() > prevrecvqsize)
@@ -205,7 +226,7 @@ static const int MYIOV_MAX = IOV_MAX < 128 ? IOV_MAX : 128;
void StreamSocket::DoWrite()
{
- if (sendq.empty())
+ if (getSendQSize() == 0)
return;
if (!error.empty() || fd < 0)
{
@@ -213,19 +234,35 @@ void StreamSocket::DoWrite()
return;
}
- if (GetIOHook())
+ SendQueue* psendq = &sendq;
+ IOHook* hook = GetIOHook();
+ while (hook)
{
- int rv = GetIOHook()->OnStreamSocketWrite(this, sendq);
- if (rv < 0)
- SetError("Write Error"); // will not overwrite a better error message
+ int rv = hook->OnStreamSocketWrite(this, *psendq);
+ psendq = NULL;
// rv == 0 means the socket has blocked. Stop trying to send data.
// IOHook has requested unblock notification from the socketengine.
+ if (rv == 0)
+ break;
+
+ if (rv < 0)
+ {
+ SetError("Write Error"); // will not overwrite a better error message
+ break;
+ }
+
+ IOHookMiddle* const iohm = IOHookMiddle::ToMiddleHook(hook);
+ hook = NULL;
+ if (iohm)
+ {
+ psendq = &iohm->GetSendQ();
+ hook = iohm->GetNextHook();
+ }
}
- else
- {
- FlushSendQ(sendq);
- }
+
+ if (psendq)
+ FlushSendQ(*psendq);
}
void StreamSocket::FlushSendQ(SendQueue& sq)
@@ -456,10 +493,47 @@ void StreamSocket::CheckError(BufferedSocketError errcode)
IOHook* StreamSocket::GetModHook(Module* mod) const
{
- if (iohook)
+ for (IOHook* curr = GetIOHook(); curr; curr = GetNextHook(curr))
{
- if (iohook->prov->creator == mod)
- return iohook;
+ if (curr->prov->creator == mod)
+ return curr;
}
return NULL;
}
+
+void StreamSocket::AddIOHook(IOHook* newhook)
+{
+ IOHook* curr = GetIOHook();
+ if (!curr)
+ {
+ iohook = newhook;
+ return;
+ }
+
+ IOHookMiddle* lasthook;
+ while (curr)
+ {
+ lasthook = IOHookMiddle::ToMiddleHook(curr);
+ if (!lasthook)
+ return;
+ curr = lasthook->GetNextHook();
+ }
+
+ lasthook->SetNextHook(newhook);
+}
+
+size_t StreamSocket::getSendQSize() const
+{
+ size_t ret = sendq.bytes();
+ IOHook* curr = GetIOHook();
+ while (curr)
+ {
+ const IOHookMiddle* const iohm = IOHookMiddle::ToMiddleHook(curr);
+ if (!iohm)
+ break;
+
+ ret += iohm->GetSendQ().bytes();
+ curr = iohm->GetNextHook();
+ }
+ return ret;
+}