#pragma comment(lib, "Iphlpapi.lib")
#endif
+namespace DNS
+{
+ /** Maximum value of a dns request id, 16 bits wide, 0xFFFF.
+ */
+ const unsigned int MAX_REQUEST_ID = 0xFFFF;
+}
+
using namespace DNS;
/** A full packet sent or recieved to/from the nameserver
*/
class Packet : public Query
{
+ static bool IsValidName(const std::string& name)
+ {
+ return (name.find_first_not_of("0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ.-") == std::string::npos);
+ }
+
void PackName(unsigned char* output, unsigned short output_size, unsigned short& pos, const std::string& name)
{
if (pos + name.length() + 2 > output_size)
Question UnpackQuestion(const unsigned char* input, unsigned short input_size, unsigned short& pos)
{
- Question question;
+ Question q;
- question.name = this->UnpackName(input, input_size, pos);
+ q.name = this->UnpackName(input, input_size, pos);
if (pos + 4 > input_size)
throw Exception("Unable to unpack question");
- question.type = static_cast<QueryType>(input[pos] << 8 | input[pos + 1]);
+ q.type = static_cast<QueryType>(input[pos] << 8 | input[pos + 1]);
pos += 2;
- question.qclass = input[pos] << 8 | input[pos + 1];
+ // Skip over query class code
pos += 2;
- return question;
+ return q;
}
ResourceRecord UnpackResourceRecord(const unsigned char* input, unsigned short input_size, unsigned short& pos)
case QUERY_PTR:
{
record.rdata = this->UnpackName(input, input_size, pos);
+ if (!IsValidName(record.rdata))
+ throw Exception("Invalid name"); // XXX: Causes the request to time out
+
break;
}
default:
ServerInstance->Logs->Log(MODNAME, LOG_DEBUG, "qdcount: " + ConvToStr(qdcount) + " ancount: " + ConvToStr(ancount) + " nscount: " + ConvToStr(nscount) + " arcount: " + ConvToStr(arcount));
- for (unsigned i = 0; i < qdcount; ++i)
- this->questions.push_back(this->UnpackQuestion(input, len, packet_pos));
+ if (qdcount != 1)
+ throw Exception("Question count != 1 in incoming packet");
+
+ this->question = this->UnpackQuestion(input, len, packet_pos);
for (unsigned i = 0; i < ancount; ++i)
this->answers.push_back(this->UnpackResourceRecord(input, len, packet_pos));
output[pos++] = this->id & 0xFF;
output[pos++] = this->flags >> 8;
output[pos++] = this->flags & 0xFF;
- output[pos++] = this->questions.size() >> 8;
- output[pos++] = this->questions.size() & 0xFF;
+ output[pos++] = 0; // Question count, high byte
+ output[pos++] = 1; // Question count, low byte
output[pos++] = 0; // Answer count, high byte
output[pos++] = 0; // Answer count, low byte
output[pos++] = 0;
output[pos++] = 0;
output[pos++] = 0;
- for (unsigned i = 0; i < this->questions.size(); ++i)
{
- Question& q = this->questions[i];
+ Question& q = this->question;
if (q.type == QUERY_PTR)
{
memcpy(&output[pos], &s, 2);
pos += 2;
- s = htons(q.qclass);
- memcpy(&output[pos], &s, 2);
- pos += 2;
+ // Query class, always IN
+ output[pos++] = 0;
+ output[pos++] = 1;
}
return pos;
*/
void AddCache(Query& r)
{
- const ResourceRecord& rr = r.answers[0];
+ // Determine the lowest TTL value and use that as the TTL of the cache entry
+ unsigned int cachettl = UINT_MAX;
+ for (std::vector<ResourceRecord>::const_iterator i = r.answers.begin(); i != r.answers.end(); ++i)
+ {
+ const ResourceRecord& rr = *i;
+ if (rr.ttl < cachettl)
+ cachettl = rr.ttl;
+ }
+
+ ResourceRecord& rr = r.answers.front();
+ // Set TTL to what we've determined to be the lowest
+ rr.ttl = cachettl;
ServerInstance->Logs->Log(MODNAME, LOG_DEBUG, "cache: added cache for " + rr.name + " -> " + rr.rdata + " ttl: " + ConvToStr(rr.ttl));
- this->cache[r.questions[0]] = r;
+ this->cache[r.question] = r;
}
public:
Packet p;
p.flags = QUERYFLAGS_RD;
p.id = req->id;
- p.questions.push_back(*req);
+ p.question = *req;
unsigned char buffer[524];
unsigned short len = p.Pack(buffer, sizeof(buffer));
- /* Note that calling Pack() above can actually change the contents of p.questions[0].name, if the query is a PTR,
+ /* Note that calling Pack() above can actually change the contents of p.question.name, if the query is a PTR,
* to contain the value that would be in the DNS cache, which is why this is here.
*/
- if (req->use_cache && this->CheckCache(req, p.questions[0]))
+ if (req->use_cache && this->CheckCache(req, p.question))
{
ServerInstance->Logs->Log(MODNAME, LOG_DEBUG, "Using cached result");
delete req;
return;
}
+ // Update name in the original request so question checking works for PTR queries
+ req->name = p.question.name;
+
if (SocketEngine::SendTo(this, buffer, len, 0, &this->myserver.sa, this->myserver.sa_size()) != len)
throw Exception("DNS: Unable to send query");
+
+ // Add timer for timeout
+ ServerInstance->Timers.AddTimer(req);
}
void RemoveRequest(DNS::Request* req)
{
- this->requests[req->id] = NULL;
+ if (requests[req->id] == req)
+ requests[req->id] = NULL;
}
std::string GetErrorStr(Error e)
case ERROR_NOT_AN_ANSWER:
case ERROR_NONSTANDARD_QUERY:
case ERROR_FORMAT_ERROR:
+ case ERROR_MALFORMED:
return "Malformed answer";
case ERROR_SERVER_FAILURE:
case ERROR_NOT_IMPLEMENTED:
}
Packet recv_packet;
+ bool valid = false;
try
{
recv_packet.Fill(buffer, length);
+ valid = true;
}
catch (Exception& ex)
{
ServerInstance->Logs->Log(MODNAME, LOG_DEBUG, ex.GetReason());
- return;
}
+ // recv_packet.id must be filled in here
DNS::Request* request = this->requests[recv_packet.id];
if (request == NULL)
{
return;
}
- if (recv_packet.flags & QUERYFLAGS_OPCODE)
+ if (static_cast<Question&>(*request) != recv_packet.question)
+ {
+ // This can happen under high latency, drop it silently, do not fail the request
+ ServerInstance->Logs->Log(MODNAME, LOG_DEBUG, "Received an answer that isn't for a question we asked");
+ return;
+ }
+
+ if (!valid)
+ {
+ ServerInstance->stats.DnsBad++;
+ recv_packet.error = ERROR_MALFORMED;
+ request->OnError(&recv_packet);
+ }
+ else if (recv_packet.flags & QUERYFLAGS_OPCODE)
{
ServerInstance->Logs->Log(MODNAME, LOG_DEBUG, "Received a nonstandard query");
ServerInstance->stats.DnsBad++;
recv_packet.error = ERROR_NONSTANDARD_QUERY;
request->OnError(&recv_packet);
}
- else if (recv_packet.flags & QUERYFLAGS_RCODE)
+ else if (!(recv_packet.flags & QUERYFLAGS_QR) || (recv_packet.flags & QUERYFLAGS_RCODE))
{
Error error = ERROR_UNKNOWN;
recv_packet.error = error;
request->OnError(&recv_packet);
}
- else if (recv_packet.questions.empty() || recv_packet.answers.empty())
+ else if (recv_packet.answers.empty())
{
ServerInstance->Logs->Log(MODNAME, LOG_DEBUG, "No resource records returned");
ServerInstance->stats.DnsBad++;
if (DNSServer == "nameserver")
{
resolv >> DNSServer;
- if (DNSServer.find_first_not_of("0123456789.") == std::string::npos)
+ if (DNSServer.find_first_not_of("0123456789.") == std::string::npos || DNSServer.find_first_not_of("0123456789ABCDEFabcdef:") == std::string::npos)
{
ServerInstance->Logs->Log("CONFIG", LOG_DEFAULT, "<dns:server> set to '%s' as first resolver in /etc/resolv.conf.",DNSServer.c_str());
return;