#include "inspircd.h"
#include "iohook.h"
+static IOHook* GetNextHook(IOHook* hook)
+{
+ IOHookMiddle* const iohm = IOHookMiddle::ToMiddleHook(hook);
+ if (iohm)
+ return iohm->GetNextHook();
+ return NULL;
+}
+
BufferedSocket::BufferedSocket()
{
Timeout = NULL;
SocketEngine::AddFd(this, FD_WANT_FAST_READ | FD_WANT_EDGE_WRITE);
}
-void BufferedSocket::DoConnect(const std::string &ipaddr, int aport, unsigned long maxtime, const std::string &connectbindip)
+void BufferedSocket::DoConnect(const irc::sockets::sockaddrs& dest, const irc::sockets::sockaddrs& bind, unsigned int maxtime)
{
- BufferedSocketError err = BeginConnect(ipaddr, aport, maxtime, connectbindip);
+ BufferedSocketError err = BeginConnect(dest, bind, maxtime);
if (err != I_ERR_NONE)
{
state = I_ERROR;
}
}
-BufferedSocketError BufferedSocket::BeginConnect(const std::string &ipaddr, int aport, unsigned long maxtime, const std::string &connectbindip)
-{
- irc::sockets::sockaddrs addr, bind;
- if (!irc::sockets::aptosa(ipaddr, aport, addr))
- {
- ServerInstance->Logs->Log("SOCKET", LOG_DEBUG, "BUG: Hostname passed to BufferedSocket, rather than an IP address!");
- return I_ERR_CONNECT;
- }
-
- bind.sa.sa_family = 0;
- if (!connectbindip.empty())
- {
- if (!irc::sockets::aptosa(connectbindip, 0, bind))
- {
- return I_ERR_BIND;
- }
- }
-
- return BeginConnect(addr, bind, maxtime);
-}
-
-BufferedSocketError BufferedSocket::BeginConnect(const irc::sockets::sockaddrs& dest, const irc::sockets::sockaddrs& bind, unsigned long timeout)
+BufferedSocketError BufferedSocket::BeginConnect(const irc::sockets::sockaddrs& dest, const irc::sockets::sockaddrs& bind, unsigned int timeout)
{
if (fd < 0)
- fd = socket(dest.sa.sa_family, SOCK_STREAM, 0);
+ fd = socket(dest.family(), SOCK_STREAM, 0);
if (fd < 0)
return I_ERR_SOCKET;
- if (bind.sa.sa_family != 0)
+ if (bind.family() != 0)
{
if (SocketEngine::Bind(fd, bind) < 0)
return I_ERR_BIND;
SocketEngine::NonBlocking(fd);
- if (SocketEngine::Connect(this, &dest.sa, dest.sa_size()) == -1)
+ if (SocketEngine::Connect(this, dest) == -1)
{
if (errno != EINPROGRESS)
return I_ERR_CONNECT;
void StreamSocket::Close()
{
+ if (closing)
+ return;
+
+ closing = true;
if (this->fd > -1)
{
// final chance, dump as much of the sendq as we can
DoWrite();
- if (GetIOHook())
+
+ IOHook* hook = GetIOHook();
+ DelIOHook();
+ while (hook)
{
- GetIOHook()->OnStreamSocketClose(this);
- delete iohook;
- DelIOHook();
+ hook->OnStreamSocketClose(this);
+ IOHook* const nexthook = GetNextHook(hook);
+ delete hook;
+ hook = nexthook;
}
SocketEngine::Shutdown(this, 2);
SocketEngine::Close(this);
}
}
+void StreamSocket::Close(bool writeblock)
+{
+ if (getSendQSize() != 0 && writeblock)
+ closeonempty = true;
+ else
+ Close();
+}
+
CullResult StreamSocket::cull()
{
Close();
return true;
}
-void StreamSocket::DoRead()
+int StreamSocket::HookChainRead(IOHook* hook, std::string& rq)
{
- if (GetIOHook())
+ if (!hook)
+ return ReadToRecvQ(rq);
+
+ IOHookMiddle* const iohm = IOHookMiddle::ToMiddleHook(hook);
+ if (iohm)
{
- int rv = GetIOHook()->OnStreamSocketRead(this, recvq);
- if (rv > 0)
- OnDataReady();
- if (rv < 0)
- SetError("Read Error"); // will not overwrite a better error message
+ // Call the next hook to put data into the recvq of the current hook
+ const int ret = HookChainRead(iohm->GetNextHook(), iohm->GetRecvQ());
+ if (ret <= 0)
+ return ret;
}
- else
+ return hook->OnStreamSocketRead(this, rq);
+}
+
+void StreamSocket::DoRead()
+{
+ const std::string::size_type prevrecvqsize = recvq.size();
+
+ const int result = HookChainRead(GetIOHook(), recvq);
+ if (result < 0)
{
+ SetError("Read Error"); // will not overwrite a better error message
+ return;
+ }
+
+ if (recvq.size() > prevrecvqsize)
+ OnDataReady();
+}
+
+int StreamSocket::ReadToRecvQ(std::string& rq)
+{
char* ReadBuffer = ServerInstance->GetReadBuffer();
int n = SocketEngine::Recv(this, ReadBuffer, ServerInstance->Config->NetBufferSize, 0);
if (n == ServerInstance->Config->NetBufferSize)
{
SocketEngine::ChangeEventMask(this, FD_WANT_FAST_READ | FD_ADD_TRIAL_READ);
- recvq.append(ReadBuffer, n);
- OnDataReady();
+ rq.append(ReadBuffer, n);
}
else if (n > 0)
{
SocketEngine::ChangeEventMask(this, FD_WANT_FAST_READ);
- recvq.append(ReadBuffer, n);
- OnDataReady();
+ rq.append(ReadBuffer, n);
}
else if (n == 0)
{
error = "Connection closed";
SocketEngine::ChangeEventMask(this, FD_WANT_NO_READ | FD_WANT_NO_WRITE);
+ return -1;
}
else if (SocketEngine::IgnoreError())
{
SocketEngine::ChangeEventMask(this, FD_WANT_FAST_READ | FD_READ_WILL_BLOCK);
+ return 0;
}
else if (errno == EINTR)
{
SocketEngine::ChangeEventMask(this, FD_WANT_FAST_READ | FD_ADD_TRIAL_READ);
+ return 0;
}
else
{
error = SocketEngine::LastError();
SocketEngine::ChangeEventMask(this, FD_WANT_NO_READ | FD_WANT_NO_WRITE);
+ return -1;
}
- }
+ return n;
}
/* Don't try to prepare huge blobs of data to send to a blocked socket */
void StreamSocket::DoWrite()
{
- if (sendq.empty())
+ if (getSendQSize() == 0)
+ {
+ if (closeonempty)
+ Close();
+
return;
+ }
if (!error.empty() || fd < 0)
{
ServerInstance->Logs->Log("SOCKET", LOG_DEBUG, "DoWrite on errored or closed socket");
return;
}
- if (GetIOHook())
+ SendQueue* psendq = &sendq;
+ IOHook* hook = GetIOHook();
+ while (hook)
{
- int rv = GetIOHook()->OnStreamSocketWrite(this, sendq);
- if (rv < 0)
- SetError("Write Error"); // will not overwrite a better error message
+ int rv = hook->OnStreamSocketWrite(this, *psendq);
+ psendq = NULL;
// rv == 0 means the socket has blocked. Stop trying to send data.
// IOHook has requested unblock notification from the socketengine.
+ if (rv == 0)
+ break;
+
+ if (rv < 0)
+ {
+ SetError("Write Error"); // will not overwrite a better error message
+ break;
+ }
+
+ IOHookMiddle* const iohm = IOHookMiddle::ToMiddleHook(hook);
+ hook = NULL;
+ if (iohm)
+ {
+ psendq = &iohm->GetSendQ();
+ hook = iohm->GetNextHook();
+ }
}
- else
- {
- FlushSendQ(sendq);
- }
+
+ if (psendq)
+ FlushSendQ(*psendq);
+
+ if (getSendQSize() == 0 && closeonempty)
+ Close();
}
void StreamSocket::FlushSendQ(SendQueue& sq)
const SendQueue::Element& elem = *i;
iovecs[j].iov_base = const_cast<char*>(elem.data());
iovecs[j].iov_len = elem.length();
- rv_max += elem.length();
+ rv_max += iovecs[j].iov_len;
}
rv = SocketEngine::WriteV(this, iovecs, bufcount);
}
}
}
+bool StreamSocket::OnSetEndPoint(const irc::sockets::sockaddrs& local, const irc::sockets::sockaddrs& remote)
+{
+ return false;
+}
+
void StreamSocket::WriteData(const std::string &data)
{
if (fd < 0)
IOHook* StreamSocket::GetModHook(Module* mod) const
{
- if (iohook)
+ for (IOHook* curr = GetIOHook(); curr; curr = GetNextHook(curr))
{
- if (iohook->prov->creator == mod)
- return iohook;
+ if (curr->prov->creator == mod)
+ return curr;
}
return NULL;
}
+
+void StreamSocket::AddIOHook(IOHook* newhook)
+{
+ IOHook* curr = GetIOHook();
+ if (!curr)
+ {
+ iohook = newhook;
+ return;
+ }
+
+ IOHookMiddle* lasthook;
+ while (curr)
+ {
+ lasthook = IOHookMiddle::ToMiddleHook(curr);
+ if (!lasthook)
+ return;
+ curr = lasthook->GetNextHook();
+ }
+
+ lasthook->SetNextHook(newhook);
+}
+
+size_t StreamSocket::getSendQSize() const
+{
+ size_t ret = sendq.bytes();
+ IOHook* curr = GetIOHook();
+ while (curr)
+ {
+ const IOHookMiddle* const iohm = IOHookMiddle::ToMiddleHook(curr);
+ if (!iohm)
+ break;
+
+ ret += iohm->GetSendQ().bytes();
+ curr = iohm->GetNextHook();
+ }
+ return ret;
+}
+
+void StreamSocket::SwapInternals(StreamSocket& other)
+{
+ if (type != other.type)
+ return;
+
+ EventHandler::SwapInternals(other);
+ std::swap(closeonempty, other.closeonempty);
+ std::swap(closing, other.closing);
+ std::swap(error, other.error);
+ std::swap(iohook, other.iohook);
+ std::swap(recvq, other.recvq);
+ std::swap(sendq, other.sendq);
+}