]> git.netwichtig.de Git - user/henk/code/inspircd.git/blobdiff - src/inspsocket.cpp
Add support for blocking tag messages with the deaf mode.
[user/henk/code/inspircd.git] / src / inspsocket.cpp
index f6ca531647b582a769c1ceeee5602024c21bb7e1..e00429bbafc0a36ab90e349a1751a62f41155b1a 100644 (file)
@@ -1,12 +1,19 @@
 /*
  * InspIRCd -- Internet Relay Chat Daemon
  *
- *   Copyright (C) 2009 Daniel De Graaf <danieldg@inspircd.org>
- *   Copyright (C) 2007-2009 Robin Burchell <robin+git@viroteck.net>
- *   Copyright (C) 2008 Thomas Stagner <aquanight@inspircd.org>
- *   Copyright (C) 2006-2007 Craig Edwards <craigedwards@brainbox.cc>
+ *   Copyright (C) 2020 Matt Schatz <genius3000@g3k.solutions>
+ *   Copyright (C) 2019 linuxdaemon <linuxdaemon.irc@gmail.com>
+ *   Copyright (C) 2018 Dylan Frank <b00mx0r@aureus.pw>
+ *   Copyright (C) 2013-2016 Attila Molnar <attilamolnar@hush.com>
+ *   Copyright (C) 2013, 2017-2020 Sadie Powell <sadie@witchery.services>
+ *   Copyright (C) 2013 Adam <Adam@anope.org>
+ *   Copyright (C) 2012 Robby <robby@chatbelgie.be>
+ *   Copyright (C) 2009-2010 Daniel De Graaf <danieldg@inspircd.org>
+ *   Copyright (C) 2007-2008 Robin Burchell <robin+git@viroteck.net>
+ *   Copyright (C) 2007 John Brooks <special@inspircd.org>
  *   Copyright (C) 2007 Dennis Friis <peavey@inspircd.org>
- *   Copyright (C) 2006 Oliver Lupton <oliverlupton@gmail.com>
+ *   Copyright (C) 2006-2007 Craig Edwards <brain@inspircd.org>
+ *   Copyright (C) 2006 Oliver Lupton <om@inspircd.org>
  *
  * This file is part of InspIRCd.  InspIRCd is free software: you can
  * redistribute it and/or modify it under the terms of the GNU General Public
 #include "inspircd.h"
 #include "iohook.h"
 
+static IOHook* GetNextHook(IOHook* hook)
+{
+       IOHookMiddle* const iohm = IOHookMiddle::ToMiddleHook(hook);
+       if (iohm)
+               return iohm->GetNextHook();
+       return NULL;
+}
+
 BufferedSocket::BufferedSocket()
 {
        Timeout = NULL;
@@ -36,13 +51,13 @@ BufferedSocket::BufferedSocket(int newfd)
        Timeout = NULL;
        this->fd = newfd;
        this->state = I_CONNECTED;
-       if (fd > -1)
+       if (HasFd())
                SocketEngine::AddFd(this, FD_WANT_FAST_READ | FD_WANT_EDGE_WRITE);
 }
 
-void BufferedSocket::DoConnect(const std::string &ipaddr, int aport, unsigned long maxtime, const std::string &connectbindip)
+void BufferedSocket::DoConnect(const irc::sockets::sockaddrs& dest, const irc::sockets::sockaddrs& bind, unsigned int maxtime)
 {
-       BufferedSocketError err = BeginConnect(ipaddr, aport, maxtime, connectbindip);
+       BufferedSocketError err = BeginConnect(dest, bind, maxtime);
        if (err != I_ERR_NONE)
        {
                state = I_ERROR;
@@ -51,36 +66,15 @@ void BufferedSocket::DoConnect(const std::string &ipaddr, int aport, unsigned lo
        }
 }
 
-BufferedSocketError BufferedSocket::BeginConnect(const std::string &ipaddr, int aport, unsigned long maxtime, const std::string &connectbindip)
-{
-       irc::sockets::sockaddrs addr, bind;
-       if (!irc::sockets::aptosa(ipaddr, aport, addr))
-       {
-               ServerInstance->Logs->Log("SOCKET", LOG_DEBUG, "BUG: Hostname passed to BufferedSocket, rather than an IP address!");
-               return I_ERR_CONNECT;
-       }
-
-       bind.sa.sa_family = 0;
-       if (!connectbindip.empty())
-       {
-               if (!irc::sockets::aptosa(connectbindip, 0, bind))
-               {
-                       return I_ERR_BIND;
-               }
-       }
-
-       return BeginConnect(addr, bind, maxtime);
-}
-
-BufferedSocketError BufferedSocket::BeginConnect(const irc::sockets::sockaddrs& dest, const irc::sockets::sockaddrs& bind, unsigned long timeout)
+BufferedSocketError BufferedSocket::BeginConnect(const irc::sockets::sockaddrs& dest, const irc::sockets::sockaddrs& bind, unsigned int timeout)
 {
-       if (fd < 0)
-               fd = socket(dest.sa.sa_family, SOCK_STREAM, 0);
+       if (!HasFd())
+               fd = socket(dest.family(), SOCK_STREAM, 0);
 
-       if (fd < 0)
+       if (!HasFd())
                return I_ERR_SOCKET;
 
-       if (bind.sa.sa_family != 0)
+       if (bind.family() != 0)
        {
                if (SocketEngine::Bind(fd, bind) < 0)
                        return I_ERR_BIND;
@@ -88,7 +82,7 @@ BufferedSocketError BufferedSocket::BeginConnect(const irc::sockets::sockaddrs&
 
        SocketEngine::NonBlocking(fd);
 
-       if (SocketEngine::Connect(this, &dest.sa, dest.sa_size()) == -1)
+       if (SocketEngine::Connect(this, dest) == -1)
        {
                if (errno != EINPROGRESS)
                        return I_ERR_CONNECT;
@@ -108,21 +102,37 @@ BufferedSocketError BufferedSocket::BeginConnect(const irc::sockets::sockaddrs&
 
 void StreamSocket::Close()
 {
-       if (this->fd > -1)
+       if (closing)
+               return;
+
+       closing = true;
+       if (HasFd())
        {
                // final chance, dump as much of the sendq as we can
                DoWrite();
-               if (GetIOHook())
+
+               IOHook* hook = GetIOHook();
+               DelIOHook();
+               while (hook)
                {
-                       GetIOHook()->OnStreamSocketClose(this);
-                       delete iohook;
-                       DelIOHook();
+                       hook->OnStreamSocketClose(this);
+                       IOHook* const nexthook = GetNextHook(hook);
+                       delete hook;
+                       hook = nexthook;
                }
                SocketEngine::Shutdown(this, 2);
                SocketEngine::Close(this);
        }
 }
 
+void StreamSocket::Close(bool writeblock)
+{
+       if (getSendQSize() != 0 && writeblock)
+               closeonempty = true;
+       else
+               Close();
+}
+
 CullResult StreamSocket::cull()
 {
        Close();
@@ -139,20 +149,35 @@ bool StreamSocket::GetNextLine(std::string& line, char delim)
        return true;
 }
 
-void StreamSocket::DoRead()
+int StreamSocket::HookChainRead(IOHook* hook, std::string& rq)
 {
-       if (GetIOHook())
+       if (!hook)
+               return ReadToRecvQ(rq);
+
+       IOHookMiddle* const iohm = IOHookMiddle::ToMiddleHook(hook);
+       if (iohm)
        {
-               int rv = GetIOHook()->OnStreamSocketRead(this, recvq);
-               if (rv > 0)
-                       OnDataReady();
-               if (rv < 0)
-                       SetError("Read Error"); // will not overwrite a better error message
+               // Call the next hook to put data into the recvq of the current hook
+               const int ret = HookChainRead(iohm->GetNextHook(), iohm->GetRecvQ());
+               if (ret <= 0)
+                       return ret;
        }
-       else
+       return hook->OnStreamSocketRead(this, rq);
+}
+
+void StreamSocket::DoRead()
+{
+       const std::string::size_type prevrecvqsize = recvq.size();
+
+       const int result = HookChainRead(GetIOHook(), recvq);
+       if (result < 0)
        {
-               ReadToRecvQ(recvq);
+               SetError("Read Error"); // will not overwrite a better error message
+               return;
        }
+
+       if (recvq.size() > prevrecvqsize)
+               OnDataReady();
 }
 
 int StreamSocket::ReadToRecvQ(std::string& rq)
@@ -163,13 +188,11 @@ int StreamSocket::ReadToRecvQ(std::string& rq)
                {
                        SocketEngine::ChangeEventMask(this, FD_WANT_FAST_READ | FD_ADD_TRIAL_READ);
                        rq.append(ReadBuffer, n);
-                       OnDataReady();
                }
                else if (n > 0)
                {
                        SocketEngine::ChangeEventMask(this, FD_WANT_FAST_READ);
                        rq.append(ReadBuffer, n);
-                       OnDataReady();
                }
                else if (n == 0)
                {
@@ -201,27 +224,51 @@ static const int MYIOV_MAX = IOV_MAX < 128 ? IOV_MAX : 128;
 
 void StreamSocket::DoWrite()
 {
-       if (sendq.empty())
+       if (getSendQSize() == 0)
+       {
+               if (closeonempty)
+                       Close();
+
                return;
-       if (!error.empty() || fd < 0)
+       }
+       if (!error.empty() || !HasFd())
        {
                ServerInstance->Logs->Log("SOCKET", LOG_DEBUG, "DoWrite on errored or closed socket");
                return;
        }
 
-       if (GetIOHook())
+       SendQueue* psendq = &sendq;
+       IOHook* hook = GetIOHook();
+       while (hook)
        {
-               int rv = GetIOHook()->OnStreamSocketWrite(this, sendq);
-               if (rv < 0)
-                       SetError("Write Error"); // will not overwrite a better error message
+               int rv = hook->OnStreamSocketWrite(this, *psendq);
+               psendq = NULL;
 
                // rv == 0 means the socket has blocked. Stop trying to send data.
                // IOHook has requested unblock notification from the socketengine.
+               if (rv == 0)
+                       break;
+
+               if (rv < 0)
+               {
+                       SetError("Write Error"); // will not overwrite a better error message
+                       break;
+               }
+
+               IOHookMiddle* const iohm = IOHookMiddle::ToMiddleHook(hook);
+               hook = NULL;
+               if (iohm)
+               {
+                       psendq = &iohm->GetSendQ();
+                       hook = iohm->GetNextHook();
+               }
        }
-       else
-       {
-               FlushSendQ(sendq);
-       }
+
+       if (psendq)
+               FlushSendQ(*psendq);
+
+       if (getSendQSize() == 0 && closeonempty)
+               Close();
 }
 
 void StreamSocket::FlushSendQ(SendQueue& sq)
@@ -252,7 +299,7 @@ void StreamSocket::FlushSendQ(SendQueue& sq)
                                        const SendQueue::Element& elem = *i;
                                        iovecs[j].iov_base = const_cast<char*>(elem.data());
                                        iovecs[j].iov_len = elem.length();
-                                       rv_max += elem.length();
+                                       rv_max += iovecs[j].iov_len;
                                }
                                rv = SocketEngine::WriteV(this, iovecs, bufcount);
                        }
@@ -317,9 +364,14 @@ void StreamSocket::FlushSendQ(SendQueue& sq)
                }
 }
 
+bool StreamSocket::OnSetEndPoint(const irc::sockets::sockaddrs& local, const irc::sockets::sockaddrs& remote)
+{
+       return false;
+}
+
 void StreamSocket::WriteData(const std::string &data)
 {
-       if (fd < 0)
+       if (!HasFd())
        {
                ServerInstance->Logs->Log("SOCKET", LOG_DEBUG, "Attempt to write data to dead socket: %s",
                        data.c_str());
@@ -452,10 +504,73 @@ void StreamSocket::CheckError(BufferedSocketError errcode)
 
 IOHook* StreamSocket::GetModHook(Module* mod) const
 {
-       if (iohook)
+       for (IOHook* curr = GetIOHook(); curr; curr = GetNextHook(curr))
        {
-               if (iohook->prov->creator == mod)
-                       return iohook;
+               if (curr->prov->creator == mod)
+                       return curr;
        }
        return NULL;
 }
+
+IOHook* StreamSocket::GetLastHook() const
+{
+       IOHook* curr = GetIOHook();
+       IOHook* last = curr;
+
+       for (; curr; curr = GetNextHook(curr))
+               last = curr;
+
+       return last;
+}
+
+
+void StreamSocket::AddIOHook(IOHook* newhook)
+{
+       IOHook* curr = GetIOHook();
+       if (!curr)
+       {
+               iohook = newhook;
+               return;
+       }
+
+       IOHookMiddle* lasthook;
+       while (curr)
+       {
+               lasthook = IOHookMiddle::ToMiddleHook(curr);
+               if (!lasthook)
+                       return;
+               curr = lasthook->GetNextHook();
+       }
+
+       lasthook->SetNextHook(newhook);
+}
+
+size_t StreamSocket::getSendQSize() const
+{
+       size_t ret = sendq.bytes();
+       IOHook* curr = GetIOHook();
+       while (curr)
+       {
+               const IOHookMiddle* const iohm = IOHookMiddle::ToMiddleHook(curr);
+               if (!iohm)
+                       break;
+
+               ret += iohm->GetSendQ().bytes();
+               curr = iohm->GetNextHook();
+       }
+       return ret;
+}
+
+void StreamSocket::SwapInternals(StreamSocket& other)
+{
+       if (type != other.type)
+               return;
+
+       EventHandler::SwapInternals(other);
+       std::swap(closeonempty, other.closeonempty);
+       std::swap(closing, other.closing);
+       std::swap(error, other.error);
+       std::swap(iohook, other.iohook);
+       std::swap(recvq, other.recvq);
+       std::swap(sendq, other.sendq);
+}