X-Git-Url: https://git.netwichtig.de/gitweb/?a=blobdiff_plain;f=src%2Fcoremods%2Fcore_dns.cpp;h=1e8bb753389272725e5aa5bd80c4c0dd615beca9;hb=dbe5a1fc6f9e18765863f332a3e79d7c918d3e65;hp=f4559c08fe672cca508bb43f9305aaf7e55ada39;hpb=8950aeda9a42ec795ce4edf87b40135047f55f6d;p=user%2Fhenk%2Fcode%2Finspircd.git diff --git a/src/coremods/core_dns.cpp b/src/coremods/core_dns.cpp index f4559c08f..1e8bb7533 100644 --- a/src/coremods/core_dns.cpp +++ b/src/coremods/core_dns.cpp @@ -27,12 +27,24 @@ #pragma comment(lib, "Iphlpapi.lib") #endif +namespace DNS +{ + /** Maximum value of a dns request id, 16 bits wide, 0xFFFF. + */ + const unsigned int MAX_REQUEST_ID = 0xFFFF; +} + using namespace DNS; /** A full packet sent or recieved to/from the nameserver */ class Packet : public Query { + static bool IsValidName(const std::string& name) + { + return (name.find_first_not_of("0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ.-") == std::string::npos); + } + void PackName(unsigned char* output, unsigned short output_size, unsigned short& pos, const std::string& name) { if (pos + name.length() + 2 > output_size) @@ -116,20 +128,20 @@ class Packet : public Query Question UnpackQuestion(const unsigned char* input, unsigned short input_size, unsigned short& pos) { - Question question; + Question q; - question.name = this->UnpackName(input, input_size, pos); + q.name = this->UnpackName(input, input_size, pos); if (pos + 4 > input_size) throw Exception("Unable to unpack question"); - question.type = static_cast(input[pos] << 8 | input[pos + 1]); + q.type = static_cast(input[pos] << 8 | input[pos + 1]); pos += 2; - question.qclass = input[pos] << 8 | input[pos + 1]; + // Skip over query class code pos += 2; - return question; + return q; } ResourceRecord UnpackResourceRecord(const unsigned char* input, unsigned short input_size, unsigned short& pos) @@ -183,6 +195,9 @@ class Packet : public Query case QUERY_PTR: { record.rdata = this->UnpackName(input, input_size, pos); + if (!IsValidName(record.rdata)) + throw Exception("Invalid name"); // XXX: Causes the request to time out + break; } default: @@ -201,7 +216,7 @@ class Packet : public Query static const int HEADER_LENGTH = 12; /* ID for this packet */ - unsigned short id; + RequestId id; /* Flags on the packet */ unsigned short flags; @@ -219,9 +234,6 @@ class Packet : public Query this->id = (input[packet_pos] << 8) | input[packet_pos + 1]; packet_pos += 2; - if (this->id >= MAX_REQUEST_ID) - throw Exception("Query ID too large?"); - this->flags = (input[packet_pos] << 8) | input[packet_pos + 1]; packet_pos += 2; @@ -239,8 +251,10 @@ class Packet : public Query ServerInstance->Logs->Log(MODNAME, LOG_DEBUG, "qdcount: " + ConvToStr(qdcount) + " ancount: " + ConvToStr(ancount) + " nscount: " + ConvToStr(nscount) + " arcount: " + ConvToStr(arcount)); - for (unsigned i = 0; i < qdcount; ++i) - this->questions.push_back(this->UnpackQuestion(input, len, packet_pos)); + if (qdcount != 1) + throw Exception("Question count != 1 in incoming packet"); + + this->question = this->UnpackQuestion(input, len, packet_pos); for (unsigned i = 0; i < ancount; ++i) this->answers.push_back(this->UnpackResourceRecord(input, len, packet_pos)); @@ -257,18 +271,17 @@ class Packet : public Query output[pos++] = this->id & 0xFF; output[pos++] = this->flags >> 8; output[pos++] = this->flags & 0xFF; - output[pos++] = this->questions.size() >> 8; - output[pos++] = this->questions.size() & 0xFF; - output[pos++] = this->answers.size() >> 8; - output[pos++] = this->answers.size() & 0xFF; + output[pos++] = 0; // Question count, high byte + output[pos++] = 1; // Question count, low byte + output[pos++] = 0; // Answer count, high byte + output[pos++] = 0; // Answer count, low byte output[pos++] = 0; output[pos++] = 0; output[pos++] = 0; output[pos++] = 0; - for (unsigned i = 0; i < this->questions.size(); ++i) { - Question& q = this->questions[i]; + Question& q = this->question; if (q.type == QUERY_PTR) { @@ -310,84 +323,9 @@ class Packet : public Query memcpy(&output[pos], &s, 2); pos += 2; - s = htons(q.qclass); - memcpy(&output[pos], &s, 2); - pos += 2; - } - - for (unsigned int i = 0; i < answers.size(); i++) - { - ResourceRecord& rr = answers[i]; - - this->PackName(output, output_size, pos, rr.name); - - if (pos + 8 >= output_size) - throw Exception("Unable to pack packet"); - - short s = htons(rr.type); - memcpy(&output[pos], &s, 2); - pos += 2; - - s = htons(rr.qclass); - memcpy(&output[pos], &s, 2); - pos += 2; - - long l = htonl(rr.ttl); - memcpy(&output[pos], &l, 4); - pos += 4; - - switch (rr.type) - { - case QUERY_A: - { - if (pos + 6 > output_size) - throw Exception("Unable to pack packet"); - - irc::sockets::sockaddrs a; - irc::sockets::aptosa(rr.rdata, 0, a); - - s = htons(4); - memcpy(&output[pos], &s, 2); - pos += 2; - - memcpy(&output[pos], &a.in4.sin_addr, 4); - pos += 4; - break; - } - case QUERY_AAAA: - { - if (pos + 18 > output_size) - throw Exception("Unable to pack packet"); - - irc::sockets::sockaddrs a; - irc::sockets::aptosa(rr.rdata, 0, a); - - s = htons(16); - memcpy(&output[pos], &s, 2); - pos += 2; - - memcpy(&output[pos], &a.in6.sin6_addr, 16); - pos += 16; - break; - } - case QUERY_CNAME: - case QUERY_PTR: - { - if (pos + 2 >= output_size) - throw Exception("Unable to pack packet"); - - unsigned short packet_pos_save = pos; - pos += 2; - - this->PackName(output, output_size, pos, rr.rdata); - - s = htons(pos - packet_pos_save - 2); - memcpy(&output[packet_pos_save], &s, 2); - break; - } - default: - break; - } + // Query class, always IN + output[pos++] = 0; + output[pos++] = 1; } return pos; @@ -401,6 +339,10 @@ class MyManager : public Manager, public Timer, public EventHandler irc::sockets::sockaddrs myserver; + /** Maximum number of entries in cache + */ + static const unsigned int MAX_CACHE_SIZE = 1000; + static bool IsExpired(const Query& record, time_t now = ServerInstance->Time()) { const ResourceRecord& req = record.answers[0]; @@ -436,24 +378,39 @@ class MyManager : public Manager, public Timer, public EventHandler */ void AddCache(Query& r) { - const ResourceRecord& rr = r.answers[0]; + if (cache.size() >= MAX_CACHE_SIZE) + cache.clear(); + + // Determine the lowest TTL value and use that as the TTL of the cache entry + unsigned int cachettl = UINT_MAX; + for (std::vector::const_iterator i = r.answers.begin(); i != r.answers.end(); ++i) + { + const ResourceRecord& rr = *i; + if (rr.ttl < cachettl) + cachettl = rr.ttl; + } + + cachettl = std::min(cachettl, (unsigned int)5*60); + ResourceRecord& rr = r.answers.front(); + // Set TTL to what we've determined to be the lowest + rr.ttl = cachettl; ServerInstance->Logs->Log(MODNAME, LOG_DEBUG, "cache: added cache for " + rr.name + " -> " + rr.rdata + " ttl: " + ConvToStr(rr.ttl)); - this->cache[r.questions[0]] = r; + this->cache[r.question] = r; } public: - DNS::Request* requests[MAX_REQUEST_ID]; + DNS::Request* requests[MAX_REQUEST_ID+1]; - MyManager(Module* c) : Manager(c), Timer(3600, true) + MyManager(Module* c) : Manager(c), Timer(5*60, true) { - for (int i = 0; i < MAX_REQUEST_ID; ++i) + for (unsigned int i = 0; i <= MAX_REQUEST_ID; ++i) requests[i] = NULL; ServerInstance->Timers.AddTimer(this); } ~MyManager() { - for (int i = 0; i < MAX_REQUEST_ID; ++i) + for (unsigned int i = 0; i <= MAX_REQUEST_ID; ++i) { DNS::Request* request = requests[i]; if (!request) @@ -476,14 +433,14 @@ class MyManager : public Manager, public Timer, public EventHandler int id; do { - id = ServerInstance->GenRandomInt(DNS::MAX_REQUEST_ID); + id = ServerInstance->GenRandomInt(DNS::MAX_REQUEST_ID+1); if (++tries == DNS::MAX_REQUEST_ID*5) { // If we couldn't find an empty slot this many times, do a sequential scan as a last // resort. If an empty slot is found that way, go on, otherwise throw an exception id = -1; - for (unsigned int i = 0; i < DNS::MAX_REQUEST_ID; i++) + for (unsigned int i = 0; i <= DNS::MAX_REQUEST_ID; i++) { if (!this->requests[i]) { @@ -506,28 +463,35 @@ class MyManager : public Manager, public Timer, public EventHandler Packet p; p.flags = QUERYFLAGS_RD; p.id = req->id; - p.questions.push_back(*req); + p.question = *req; unsigned char buffer[524]; unsigned short len = p.Pack(buffer, sizeof(buffer)); - /* Note that calling Pack() above can actually change the contents of p.questions[0].name, if the query is a PTR, + /* Note that calling Pack() above can actually change the contents of p.question.name, if the query is a PTR, * to contain the value that would be in the DNS cache, which is why this is here. */ - if (req->use_cache && this->CheckCache(req, p.questions[0])) + if (req->use_cache && this->CheckCache(req, p.question)) { ServerInstance->Logs->Log(MODNAME, LOG_DEBUG, "Using cached result"); delete req; return; } + // Update name in the original request so question checking works for PTR queries + req->name = p.question.name; + if (SocketEngine::SendTo(this, buffer, len, 0, &this->myserver.sa, this->myserver.sa_size()) != len) throw Exception("DNS: Unable to send query"); + + // Add timer for timeout + ServerInstance->Timers.AddTimer(req); } void RemoveRequest(DNS::Request* req) { - this->requests[req->id] = NULL; + if (requests[req->id] == req) + requests[req->id] = NULL; } std::string GetErrorStr(Error e) @@ -541,6 +505,7 @@ class MyManager : public Manager, public Timer, public EventHandler case ERROR_NOT_AN_ANSWER: case ERROR_NONSTANDARD_QUERY: case ERROR_FORMAT_ERROR: + case ERROR_MALFORMED: return "Malformed answer"; case ERROR_SERVER_FAILURE: case ERROR_NOT_IMPLEMENTED: @@ -583,17 +548,19 @@ class MyManager : public Manager, public Timer, public EventHandler } Packet recv_packet; + bool valid = false; try { recv_packet.Fill(buffer, length); + valid = true; } catch (Exception& ex) { ServerInstance->Logs->Log(MODNAME, LOG_DEBUG, ex.GetReason()); - return; } + // recv_packet.id must be filled in here DNS::Request* request = this->requests[recv_packet.id]; if (request == NULL) { @@ -601,14 +568,27 @@ class MyManager : public Manager, public Timer, public EventHandler return; } - if (recv_packet.flags & QUERYFLAGS_OPCODE) + if (static_cast(*request) != recv_packet.question) + { + // This can happen under high latency, drop it silently, do not fail the request + ServerInstance->Logs->Log(MODNAME, LOG_DEBUG, "Received an answer that isn't for a question we asked"); + return; + } + + if (!valid) + { + ServerInstance->stats.DnsBad++; + recv_packet.error = ERROR_MALFORMED; + request->OnError(&recv_packet); + } + else if (recv_packet.flags & QUERYFLAGS_OPCODE) { ServerInstance->Logs->Log(MODNAME, LOG_DEBUG, "Received a nonstandard query"); ServerInstance->stats.DnsBad++; recv_packet.error = ERROR_NONSTANDARD_QUERY; request->OnError(&recv_packet); } - else if (recv_packet.flags & QUERYFLAGS_RCODE) + else if (!(recv_packet.flags & QUERYFLAGS_QR) || (recv_packet.flags & QUERYFLAGS_RCODE)) { Error error = ERROR_UNKNOWN; @@ -642,7 +622,7 @@ class MyManager : public Manager, public Timer, public EventHandler recv_packet.error = error; request->OnError(&recv_packet); } - else if (recv_packet.questions.empty() || recv_packet.answers.empty()) + else if (recv_packet.answers.empty()) { ServerInstance->Logs->Log(MODNAME, LOG_DEBUG, "No resource records returned"); ServerInstance->stats.DnsBad++; @@ -776,7 +756,7 @@ class ModuleDNS : public Module if (DNSServer == "nameserver") { resolv >> DNSServer; - if (DNSServer.find_first_not_of("0123456789.") == std::string::npos) + if (DNSServer.find_first_not_of("0123456789.") == std::string::npos || DNSServer.find_first_not_of("0123456789ABCDEFabcdef:") == std::string::npos) { ServerInstance->Logs->Log("CONFIG", LOG_DEFAULT, " set to '%s' as first resolver in /etc/resolv.conf.",DNSServer.c_str()); return; @@ -807,7 +787,7 @@ class ModuleDNS : public Module void OnUnloadModule(Module* mod) { - for (int i = 0; i < MAX_REQUEST_ID; ++i) + for (unsigned int i = 0; i <= MAX_REQUEST_ID; ++i) { DNS::Request* req = this->manager.requests[i]; if (!req)