X-Git-Url: https://git.netwichtig.de/gitweb/?a=blobdiff_plain;f=src%2Finspsocket.cpp;h=ebbff448ed15bfd791aee08ccb49c3d13b1db6b9;hb=0a1f9bc59494a532a91bc9c8afcecb31ece656ee;hp=f6ca531647b582a769c1ceeee5602024c21bb7e1;hpb=6466e3093a4e5e996f2c9f3c4fd9f6eb1ac0e7b9;p=user%2Fhenk%2Fcode%2Finspircd.git diff --git a/src/inspsocket.cpp b/src/inspsocket.cpp index f6ca53164..ebbff448e 100644 --- a/src/inspsocket.cpp +++ b/src/inspsocket.cpp @@ -25,6 +25,14 @@ #include "inspircd.h" #include "iohook.h" +static IOHook* GetNextHook(IOHook* hook) +{ + IOHookMiddle* const iohm = IOHookMiddle::ToMiddleHook(hook); + if (iohm) + return iohm->GetNextHook(); + return NULL; +} + BufferedSocket::BufferedSocket() { Timeout = NULL; @@ -40,7 +48,7 @@ BufferedSocket::BufferedSocket(int newfd) SocketEngine::AddFd(this, FD_WANT_FAST_READ | FD_WANT_EDGE_WRITE); } -void BufferedSocket::DoConnect(const std::string &ipaddr, int aport, unsigned long maxtime, const std::string &connectbindip) +void BufferedSocket::DoConnect(const std::string& ipaddr, int aport, unsigned int maxtime, const std::string& connectbindip) { BufferedSocketError err = BeginConnect(ipaddr, aport, maxtime, connectbindip); if (err != I_ERR_NONE) @@ -51,7 +59,7 @@ void BufferedSocket::DoConnect(const std::string &ipaddr, int aport, unsigned lo } } -BufferedSocketError BufferedSocket::BeginConnect(const std::string &ipaddr, int aport, unsigned long maxtime, const std::string &connectbindip) +BufferedSocketError BufferedSocket::BeginConnect(const std::string& ipaddr, int aport, unsigned int maxtime, const std::string& connectbindip) { irc::sockets::sockaddrs addr, bind; if (!irc::sockets::aptosa(ipaddr, aport, addr)) @@ -72,15 +80,15 @@ BufferedSocketError BufferedSocket::BeginConnect(const std::string &ipaddr, int return BeginConnect(addr, bind, maxtime); } -BufferedSocketError BufferedSocket::BeginConnect(const irc::sockets::sockaddrs& dest, const irc::sockets::sockaddrs& bind, unsigned long timeout) +BufferedSocketError BufferedSocket::BeginConnect(const irc::sockets::sockaddrs& dest, const irc::sockets::sockaddrs& bind, unsigned int timeout) { if (fd < 0) - fd = socket(dest.sa.sa_family, SOCK_STREAM, 0); + fd = socket(dest.family(), SOCK_STREAM, 0); if (fd < 0) return I_ERR_SOCKET; - if (bind.sa.sa_family != 0) + if (bind.family() != 0) { if (SocketEngine::Bind(fd, bind) < 0) return I_ERR_BIND; @@ -88,7 +96,7 @@ BufferedSocketError BufferedSocket::BeginConnect(const irc::sockets::sockaddrs& SocketEngine::NonBlocking(fd); - if (SocketEngine::Connect(this, &dest.sa, dest.sa_size()) == -1) + if (SocketEngine::Connect(this, dest) == -1) { if (errno != EINPROGRESS) return I_ERR_CONNECT; @@ -112,11 +120,15 @@ void StreamSocket::Close() { // final chance, dump as much of the sendq as we can DoWrite(); - if (GetIOHook()) + + IOHook* hook = GetIOHook(); + DelIOHook(); + while (hook) { - GetIOHook()->OnStreamSocketClose(this); - delete iohook; - DelIOHook(); + hook->OnStreamSocketClose(this); + IOHook* const nexthook = GetNextHook(hook); + delete hook; + hook = nexthook; } SocketEngine::Shutdown(this, 2); SocketEngine::Close(this); @@ -139,20 +151,35 @@ bool StreamSocket::GetNextLine(std::string& line, char delim) return true; } -void StreamSocket::DoRead() +int StreamSocket::HookChainRead(IOHook* hook, std::string& rq) { - if (GetIOHook()) + if (!hook) + return ReadToRecvQ(rq); + + IOHookMiddle* const iohm = IOHookMiddle::ToMiddleHook(hook); + if (iohm) { - int rv = GetIOHook()->OnStreamSocketRead(this, recvq); - if (rv > 0) - OnDataReady(); - if (rv < 0) - SetError("Read Error"); // will not overwrite a better error message + // Call the next hook to put data into the recvq of the current hook + const int ret = HookChainRead(iohm->GetNextHook(), iohm->GetRecvQ()); + if (ret <= 0) + return ret; } - else + return hook->OnStreamSocketRead(this, rq); +} + +void StreamSocket::DoRead() +{ + const std::string::size_type prevrecvqsize = recvq.size(); + + const int result = HookChainRead(GetIOHook(), recvq); + if (result < 0) { - ReadToRecvQ(recvq); + SetError("Read Error"); // will not overwrite a better error message + return; } + + if (recvq.size() > prevrecvqsize) + OnDataReady(); } int StreamSocket::ReadToRecvQ(std::string& rq) @@ -163,13 +190,11 @@ int StreamSocket::ReadToRecvQ(std::string& rq) { SocketEngine::ChangeEventMask(this, FD_WANT_FAST_READ | FD_ADD_TRIAL_READ); rq.append(ReadBuffer, n); - OnDataReady(); } else if (n > 0) { SocketEngine::ChangeEventMask(this, FD_WANT_FAST_READ); rq.append(ReadBuffer, n); - OnDataReady(); } else if (n == 0) { @@ -201,7 +226,7 @@ static const int MYIOV_MAX = IOV_MAX < 128 ? IOV_MAX : 128; void StreamSocket::DoWrite() { - if (sendq.empty()) + if (getSendQSize() == 0) return; if (!error.empty() || fd < 0) { @@ -209,19 +234,35 @@ void StreamSocket::DoWrite() return; } - if (GetIOHook()) + SendQueue* psendq = &sendq; + IOHook* hook = GetIOHook(); + while (hook) { - int rv = GetIOHook()->OnStreamSocketWrite(this, sendq); - if (rv < 0) - SetError("Write Error"); // will not overwrite a better error message + int rv = hook->OnStreamSocketWrite(this, *psendq); + psendq = NULL; // rv == 0 means the socket has blocked. Stop trying to send data. // IOHook has requested unblock notification from the socketengine. + if (rv == 0) + break; + + if (rv < 0) + { + SetError("Write Error"); // will not overwrite a better error message + break; + } + + IOHookMiddle* const iohm = IOHookMiddle::ToMiddleHook(hook); + hook = NULL; + if (iohm) + { + psendq = &iohm->GetSendQ(); + hook = iohm->GetNextHook(); + } } - else - { - FlushSendQ(sendq); - } + + if (psendq) + FlushSendQ(*psendq); } void StreamSocket::FlushSendQ(SendQueue& sq) @@ -252,7 +293,7 @@ void StreamSocket::FlushSendQ(SendQueue& sq) const SendQueue::Element& elem = *i; iovecs[j].iov_base = const_cast(elem.data()); iovecs[j].iov_len = elem.length(); - rv_max += elem.length(); + rv_max += iovecs[j].iov_len; } rv = SocketEngine::WriteV(this, iovecs, bufcount); } @@ -317,6 +358,11 @@ void StreamSocket::FlushSendQ(SendQueue& sq) } } +bool StreamSocket::OnSetEndPoint(const irc::sockets::sockaddrs& local, const irc::sockets::sockaddrs& remote) +{ + return false; +} + void StreamSocket::WriteData(const std::string &data) { if (fd < 0) @@ -452,10 +498,47 @@ void StreamSocket::CheckError(BufferedSocketError errcode) IOHook* StreamSocket::GetModHook(Module* mod) const { - if (iohook) + for (IOHook* curr = GetIOHook(); curr; curr = GetNextHook(curr)) { - if (iohook->prov->creator == mod) - return iohook; + if (curr->prov->creator == mod) + return curr; } return NULL; } + +void StreamSocket::AddIOHook(IOHook* newhook) +{ + IOHook* curr = GetIOHook(); + if (!curr) + { + iohook = newhook; + return; + } + + IOHookMiddle* lasthook; + while (curr) + { + lasthook = IOHookMiddle::ToMiddleHook(curr); + if (!lasthook) + return; + curr = lasthook->GetNextHook(); + } + + lasthook->SetNextHook(newhook); +} + +size_t StreamSocket::getSendQSize() const +{ + size_t ret = sendq.bytes(); + IOHook* curr = GetIOHook(); + while (curr) + { + const IOHookMiddle* const iohm = IOHookMiddle::ToMiddleHook(curr); + if (!iohm) + break; + + ret += iohm->GetSendQ().bytes(); + curr = iohm->GetNextHook(); + } + return ret; +}