diff options
-rw-r--r-- | include/inspircd.h | 4 | ||||
-rw-r--r-- | include/testsuite.h | 2 | ||||
-rw-r--r-- | src/channels.cpp | 16 | ||||
-rw-r--r-- | src/inspsocket.cpp | 2 | ||||
-rw-r--r-- | src/modules/extra/m_ssl_gnutls.cpp | 4 | ||||
-rw-r--r-- | src/modules/m_showwhois.cpp | 3 | ||||
-rw-r--r-- | src/server.cpp | 25 | ||||
-rw-r--r-- | src/testsuite.cpp | 77 |
8 files changed, 107 insertions, 26 deletions
diff --git a/include/inspircd.h b/include/inspircd.h index cabb24aa0..9c9609530 100644 --- a/include/inspircd.h +++ b/include/inspircd.h @@ -256,6 +256,8 @@ DEFINE_HANDLER1(IsSIDHandler, bool, const std::string&); DEFINE_HANDLER1(RehashHandler, void, const std::string&); DEFINE_HANDLER3(OnCheckExemptionHandler, ModResult, User*, Channel*, const std::string&); +class TestSuite; + /** The main class of the irc server. * This class contains instances of all the other classes in this software. * Amongst other things, it contains a ModeParser, a DNS object, a CommandParser @@ -855,6 +857,8 @@ class CoreExport InspIRCd { return this->ReadBuffer; } + + friend class TestSuite; }; ENTRYPOINT; diff --git a/include/testsuite.h b/include/testsuite.h index 618615dc9..f91e508c9 100644 --- a/include/testsuite.h +++ b/include/testsuite.h @@ -21,6 +21,7 @@ class TestSuite { + bool RealGenerateUIDTests(); public: TestSuite(); ~TestSuite(); @@ -29,6 +30,7 @@ class TestSuite bool DoWildTests(); bool DoCommaSepStreamTests(); bool DoSpaceSepStreamTests(); + bool DoGenerateUIDTests(); }; #endif diff --git a/src/channels.cpp b/src/channels.cpp index 5539f4bfe..51fa74064 100644 --- a/src/channels.cpp +++ b/src/channels.cpp @@ -400,11 +400,15 @@ Channel* Channel::ForceChan(Channel* Ptr, User* user, const std::string &privs, Ptr->WriteAllExcept(user, false, 0, except_list, "JOIN :%s", Ptr->name.c_str()); /* Theyre not the first ones in here, make sure everyone else sees the modes we gave the user */ - std::string ms = memb->modes; - for(unsigned int i=0; i < memb->modes.length(); i++) - ms.append(" ").append(user->nick); - if ((Ptr->GetUserCounter() > 1) && (ms.length())) - Ptr->WriteAllExceptSender(user, ServerInstance->Config->CycleHostsFromUser, 0, "MODE %s +%s", Ptr->name.c_str(), ms.c_str()); + if ((Ptr->GetUserCounter() > 1) && (!memb->modes.empty())) + { + std::string ms = memb->modes; + for(unsigned int i=0; i < memb->modes.length(); i++) + ms.append(" ").append(user->nick); + + except_list.insert(user); + Ptr->WriteAllExcept(user, !ServerInstance->Config->CycleHostsFromUser, 0, except_list, "MODE %s +%s", Ptr->name.c_str(), ms.c_str()); + } if (IS_LOCAL(user)) { @@ -655,7 +659,7 @@ void Channel::WriteAllExcept(User* user, bool serversource, char status, CUList if (!text) return; - int offset = snprintf(textbuffer,MAXBUF,":%s ", user->GetFullHost().c_str()); + int offset = snprintf(textbuffer,MAXBUF,":%s ", serversource ? ServerInstance->Config->ServerName.c_str() : user->GetFullHost().c_str()); va_start(argsPtr, text); vsnprintf(textbuffer + offset, MAXBUF - offset, text, argsPtr); diff --git a/src/inspsocket.cpp b/src/inspsocket.cpp index 1254dc58b..3841c6147 100644 --- a/src/inspsocket.cpp +++ b/src/inspsocket.cpp @@ -194,7 +194,7 @@ void StreamSocket::DoRead() else { char* ReadBuffer = ServerInstance->GetReadBuffer(); - int n = recv(fd, ReadBuffer, ServerInstance->Config->NetBufferSize, 0); + int n = ServerInstance->SE->Recv(this, ReadBuffer, ServerInstance->Config->NetBufferSize, 0); if (n == ServerInstance->Config->NetBufferSize) { ServerInstance->SE->ChangeEventMask(this, FD_WANT_FAST_READ | FD_ADD_TRIAL_READ); diff --git a/src/modules/extra/m_ssl_gnutls.cpp b/src/modules/extra/m_ssl_gnutls.cpp index cc934ff77..22c027cfb 100644 --- a/src/modules/extra/m_ssl_gnutls.cpp +++ b/src/modules/extra/m_ssl_gnutls.cpp @@ -56,7 +56,7 @@ static ssize_t gnutls_pull_wrapper(gnutls_transport_ptr_t user_wrap, void* buffe errno = EAGAIN; return -1; } - int rv = recv(user->GetFd(), reinterpret_cast<char *>(buffer), size, 0); + int rv = ServerInstance->SE->Recv(user, reinterpret_cast<char *>(buffer), size, 0); if (rv < (int)size) ServerInstance->SE->ChangeEventMask(user, FD_READ_WILL_BLOCK); return rv; @@ -70,7 +70,7 @@ static ssize_t gnutls_push_wrapper(gnutls_transport_ptr_t user_wrap, const void* errno = EAGAIN; return -1; } - int rv = send(user->GetFd(), reinterpret_cast<const char *>(buffer), size, 0); + int rv = ServerInstance->SE->Send(user, reinterpret_cast<const char *>(buffer), size, 0); if (rv < (int)size) ServerInstance->SE->ChangeEventMask(user, FD_WRITE_WILL_BLOCK); return rv; diff --git a/src/modules/m_showwhois.cpp b/src/modules/m_showwhois.cpp index 691887429..6eec64bd5 100644 --- a/src/modules/m_showwhois.cpp +++ b/src/modules/m_showwhois.cpp @@ -76,6 +76,9 @@ class WhoisNoticeCmd : public Command CmdResult Handle(const std::vector<std::string> ¶meters, User *user) { User* dest = ServerInstance->FindNick(parameters[0]); + if (!dest) + return CMD_FAILURE; + User* source = ServerInstance->FindNick(parameters[1]); if (IS_LOCAL(dest) && source) diff --git a/src/server.cpp b/src/server.cpp index dab920bb6..adaaa7d2c 100644 --- a/src/server.cpp +++ b/src/server.cpp @@ -101,20 +101,12 @@ void InspIRCd::IncrementUID(int pos) * A again, in an iterative fashion.. so.. * AAA9 -> AABA, and so on. -- w00t */ - if (pos == 3) + if ((pos == 3) && (current_uid[3] == '9')) { // At pos 3, if we hit '9', we've run out of available UIDs, and need to reset to AAA..AAA. - if (current_uid[pos] == '9') + for (int i = 3; i < UUID_LENGTH-1; i++) { - for (int i = 3; i < (UUID_LENGTH - 1); i++) - { - current_uid[i] = 'A'; - } - } - else - { - // Buf if we haven't, just keep incrementing merrily. - current_uid[pos]++; + current_uid[i] = 'A'; } } else @@ -146,17 +138,18 @@ void InspIRCd::IncrementUID(int pos) */ std::string InspIRCd::GetUID() { - static int curindex = -1; + static bool inited = false; /* - * If -1, we're setting up. Copy SID into the first three digits, 9's to the rest, null term at the end + * If we're setting up, copy SID into the first three digits, 9's to the rest, null term at the end * Why 9? Well, we increment before we find, otherwise we have an unnecessary copy, and I want UID to start at AAA..AA * and not AA..AB. So by initialising to 99999, we force it to rollover to AAAAA on the first IncrementUID call. * Kind of silly, but I like how it looks. * -- w */ - if (curindex == -1) + if (!inited) { + inited = true; current_uid[0] = Config->sid[0]; current_uid[1] = Config->sid[1]; current_uid[2] = Config->sid[2]; @@ -164,8 +157,6 @@ std::string InspIRCd::GetUID() for (int i = 3; i < (UUID_LENGTH - 1); i++) current_uid[i] = '9'; - curindex = UUID_LENGTH - 2; // look at the end of the string now kthx, ignore null - // Null terminator. Important. current_uid[UUID_LENGTH - 1] = '\0'; } @@ -173,7 +164,7 @@ std::string InspIRCd::GetUID() while (1) { // Add one to the last UID - this->IncrementUID(curindex); + this->IncrementUID(UUID_LENGTH - 2); if (this->FindUUID(current_uid)) { diff --git a/src/testsuite.cpp b/src/testsuite.cpp index 064724392..59e6102e7 100644 --- a/src/testsuite.cpp +++ b/src/testsuite.cpp @@ -65,6 +65,7 @@ TestSuite::TestSuite() cout << "(5) Wildcard and CIDR tests\n"; cout << "(6) Comma sepstream tests\n"; cout << "(7) Space sepstream tests\n"; + cout << "(8) UID generation tests\n"; cout << endl << "(X) Exit test suite\n"; @@ -105,6 +106,9 @@ TestSuite::TestSuite() case '7': cout << (DoSpaceSepStreamTests() ? "\nSUCCESS!\n" : "\nFAILURE\n"); break; + case '8': + cout << (DoGenerateUIDTests() ? "\nSUCCESS!\n" : "\nFAILURE\n"); + break; case 'X': return; break; @@ -327,6 +331,79 @@ bool TestSuite::DoThreadTests() return true; } +bool TestSuite::DoGenerateUIDTests() +{ + bool success = RealGenerateUIDTests(); + + // Reset the UID generation state so running the tests multiple times won't mess things up + for (unsigned int i = 0; i < 3; i++) + ServerInstance->current_uid[i] = ServerInstance->Config->sid[i]; + for (unsigned int i = 3; i < UUID_LENGTH-1; i++) + ServerInstance->current_uid[i] = '9'; + + ServerInstance->current_uid[UUID_LENGTH-1] = '\0'; + + return success; +} + +bool TestSuite::RealGenerateUIDTests() +{ + std::string first_uid = ServerInstance->GetUID(); + if (first_uid.length() != UUID_LENGTH-1) + { + cout << "GENERATEUID: Generated UID is " << first_uid.length() << " characters long instead of " << UUID_LENGTH-1 << endl; + return false; + } + + if (ServerInstance->current_uid[UUID_LENGTH-1] != '\0') + { + cout << "GENERATEUID: The null terminator is missing from the end of current_uid" << endl; + return false; + } + + // The correct UID when generating one for the first time is ...AAAAAA + std::string correct_uid = ServerInstance->Config->sid + std::string(UUID_LENGTH - 4, 'A'); + if (first_uid != correct_uid) + { + cout << "GENERATEUID: Generated an invalid first UID: " << first_uid << " instead of " << correct_uid << endl; + return false; + } + + // Set current_uid to be ...Z99999 + ServerInstance->current_uid[3] = 'Z'; + for (unsigned int i = 4; i < UUID_LENGTH-1; i++) + ServerInstance->current_uid[i] = '9'; + + // Store the UID we'll be incrementing so we can display what's wrong later if necessary + std::string before_increment(ServerInstance->current_uid); + std::string generated_uid = ServerInstance->GetUID(); + + // Correct UID after incrementing ...Z99999 is ...0AAAAA + correct_uid = ServerInstance->Config->sid + "0" + std::string(UUID_LENGTH - 5, 'A'); + + if (generated_uid != correct_uid) + { + cout << "GENERATEUID: Generated an invalid UID after incrementing " << before_increment << ": " << generated_uid << " instead of " << correct_uid << endl; + return false; + } + + // Set current_uid to be ...999999 to see if it rolls over correctly + for (unsigned int i = 3; i < UUID_LENGTH-1; i++) + ServerInstance->current_uid[i] = '9'; + + before_increment.assign(ServerInstance->current_uid); + generated_uid = ServerInstance->GetUID(); + + // Correct UID after rolling over is the first UID we've generated (...AAAAAA) + if (generated_uid != first_uid) + { + cout << "GENERATEUID: Generated an invalid UID after incrementing " << before_increment << ": " << generated_uid << " instead of " << first_uid << endl; + return false; + } + + return true; +} + TestSuite::~TestSuite() { cout << "\n\n*** END OF TEST SUITE ***\n"; |