]> git.netwichtig.de Git - user/henk/code/inspircd.git/blob - src/coremods/core_dns.cpp
Textual improvements and fixes such as typos, casing, etc. (#1612)
[user/henk/code/inspircd.git] / src / coremods / core_dns.cpp
1 /*
2  * InspIRCd -- Internet Relay Chat Daemon
3  *
4  *   Copyright (C) 2013 Adam <Adam@anope.org>
5  *   Copyright (C) 2003-2013 Anope Team <team@anope.org>
6  *
7  * This file is part of InspIRCd.  InspIRCd is free software: you can
8  * redistribute it and/or modify it under the terms of the GNU General Public
9  * License as published by the Free Software Foundation, version 2.
10  *
11  * This program is distributed in the hope that it will be useful, but WITHOUT
12  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
13  * FOR A PARTICULAR PURPOSE.  See the GNU General Public License for more
14  * details.
15  *
16  * You should have received a copy of the GNU General Public License
17  * along with this program.  If not, see <http://www.gnu.org/licenses/>.
18  */
19
20 #include "inspircd.h"
21 #include "modules/dns.h"
22 #include <iostream>
23 #include <fstream>
24
25 #ifdef _WIN32
26 #include <Iphlpapi.h>
27 #pragma comment(lib, "Iphlpapi.lib")
28 #endif
29
namespace DNS
{
	/** Highest request id that can be placed in a DNS header.
	 * The id field is 16 bits wide, so valid ids are 0 through 65535 (0xFFFF) inclusive.
	 */
	const unsigned int MAX_REQUEST_ID = 65535;
}

using namespace DNS;
38
39 /** A full packet sent or received to/from the nameserver
40  */
41 class Packet : public Query
42 {
43         void PackName(unsigned char* output, unsigned short output_size, unsigned short& pos, const std::string& name)
44         {
45                 if (pos + name.length() + 2 > output_size)
46                         throw Exception("Unable to pack name");
47
48                 ServerInstance->Logs->Log(MODNAME, LOG_DEBUG, "Packing name " + name);
49
50                 irc::sepstream sep(name, '.');
51                 std::string token;
52
53                 while (sep.GetToken(token))
54                 {
55                         output[pos++] = token.length();
56                         memcpy(&output[pos], token.data(), token.length());
57                         pos += token.length();
58                 }
59
60                 output[pos++] = 0;
61         }
62
63         std::string UnpackName(const unsigned char* input, unsigned short input_size, unsigned short& pos)
64         {
65                 std::string name;
66                 unsigned short pos_ptr = pos, lowest_ptr = input_size;
67                 bool compressed = false;
68
69                 if (pos_ptr >= input_size)
70                         throw Exception("Unable to unpack name - no input");
71
72                 while (input[pos_ptr] > 0)
73                 {
74                         unsigned short offset = input[pos_ptr];
75
76                         if (offset & POINTER)
77                         {
78                                 if ((offset & POINTER) != POINTER)
79                                         throw Exception("Unable to unpack name - bogus compression header");
80                                 if (pos_ptr + 1 >= input_size)
81                                         throw Exception("Unable to unpack name - bogus compression header");
82
83                                 /* Place pos at the second byte of the first (farthest) compression pointer */
84                                 if (compressed == false)
85                                 {
86                                         ++pos;
87                                         compressed = true;
88                                 }
89
90                                 pos_ptr = (offset & LABEL) << 8 | input[pos_ptr + 1];
91
92                                 /* Pointers can only go back */
93                                 if (pos_ptr >= lowest_ptr)
94                                         throw Exception("Unable to unpack name - bogus compression pointer");
95                                 lowest_ptr = pos_ptr;
96                         }
97                         else
98                         {
99                                 if (pos_ptr + offset + 1 >= input_size)
100                                         throw Exception("Unable to unpack name - offset too large");
101                                 if (!name.empty())
102                                         name += ".";
103                                 for (unsigned i = 1; i <= offset; ++i)
104                                         name += input[pos_ptr + i];
105
106                                 pos_ptr += offset + 1;
107                                 if (compressed == false)
108                                         /* Move up pos */
109                                         pos = pos_ptr;
110                         }
111                 }
112
113                 /* +1 pos either to one byte after the compression pointer or one byte after the ending \0 */
114                 ++pos;
115
116                 if (name.empty())
117                         throw Exception("Unable to unpack name - no name");
118
119                 ServerInstance->Logs->Log(MODNAME, LOG_DEBUG, "Unpack name " + name);
120
121                 return name;
122         }
123
124         Question UnpackQuestion(const unsigned char* input, unsigned short input_size, unsigned short& pos)
125         {
126                 Question q;
127
128                 q.name = this->UnpackName(input, input_size, pos);
129
130                 if (pos + 4 > input_size)
131                         throw Exception("Unable to unpack question");
132
133                 q.type = static_cast<QueryType>(input[pos] << 8 | input[pos + 1]);
134                 pos += 2;
135
136                 // Skip over query class code
137                 pos += 2;
138
139                 return q;
140         }
141
142         ResourceRecord UnpackResourceRecord(const unsigned char* input, unsigned short input_size, unsigned short& pos)
143         {
144                 ResourceRecord record = static_cast<ResourceRecord>(this->UnpackQuestion(input, input_size, pos));
145
146                 if (pos + 6 > input_size)
147                         throw Exception("Unable to unpack resource record");
148
149                 record.ttl = (input[pos] << 24) | (input[pos + 1] << 16) | (input[pos + 2] << 8) | input[pos + 3];
150                 pos += 4;
151
152                 uint16_t rdlength = input[pos] << 8 | input[pos + 1];
153                 pos += 2;
154
155                 switch (record.type)
156                 {
157                         case QUERY_A:
158                         {
159                                 if (pos + 4 > input_size)
160                                         throw Exception("Unable to unpack resource record");
161
162                                 irc::sockets::sockaddrs addrs;
163                                 memset(&addrs, 0, sizeof(addrs));
164
165                                 addrs.in4.sin_family = AF_INET;
166                                 addrs.in4.sin_addr.s_addr = input[pos] | (input[pos + 1] << 8) | (input[pos + 2] << 16)  | (input[pos + 3] << 24);
167                                 pos += 4;
168
169                                 record.rdata = addrs.addr();
170                                 break;
171                         }
172                         case QUERY_AAAA:
173                         {
174                                 if (pos + 16 > input_size)
175                                         throw Exception("Unable to unpack resource record");
176
177                                 irc::sockets::sockaddrs addrs;
178                                 memset(&addrs, 0, sizeof(addrs));
179
180                                 addrs.in6.sin6_family = AF_INET6;
181                                 for (int j = 0; j < 16; ++j)
182                                         addrs.in6.sin6_addr.s6_addr[j] = input[pos + j];
183                                 pos += 16;
184
185                                 record.rdata = addrs.addr();
186
187                                 break;
188                         }
189                         case QUERY_CNAME:
190                         case QUERY_PTR:
191                         {
192                                 record.rdata = this->UnpackName(input, input_size, pos);
193                                 if (!InspIRCd::IsHost(record.rdata))
194                                         throw Exception("Invalid name"); // XXX: Causes the request to time out
195
196                                 break;
197                         }
198                         case QUERY_TXT:
199                         {
200                                 if (pos + rdlength > input_size)
201                                         throw Exception("Unable to unpack txt resource record");
202
203                                 record.rdata = std::string(reinterpret_cast<const char *>(input + pos), rdlength);
204                                 pos += rdlength;
205
206                                 if (record.rdata.find_first_of("\r\n\0", 0, 3) != std::string::npos)
207                                         throw Exception("Invalid character in txt record");
208
209                                 break;
210                         }
211                         default:
212                                 break;
213                 }
214
215                 if (!record.name.empty() && !record.rdata.empty())
216                         ServerInstance->Logs->Log(MODNAME, LOG_DEBUG, record.name + " -> " + record.rdata);
217
218                 return record;
219         }
220
221  public:
222         static const int POINTER = 0xC0;
223         static const int LABEL = 0x3F;
224         static const int HEADER_LENGTH = 12;
225
226         /* ID for this packet */
227         RequestId id;
228         /* Flags on the packet */
229         unsigned short flags;
230
231         Packet() : id(0), flags(0)
232         {
233         }
234
235         void Fill(const unsigned char* input, const unsigned short len)
236         {
237                 if (len < HEADER_LENGTH)
238                         throw Exception("Unable to fill packet");
239
240                 unsigned short packet_pos = 0;
241
242                 this->id = (input[packet_pos] << 8) | input[packet_pos + 1];
243                 packet_pos += 2;
244
245                 this->flags = (input[packet_pos] << 8) | input[packet_pos + 1];
246                 packet_pos += 2;
247
248                 unsigned short qdcount = (input[packet_pos] << 8) | input[packet_pos + 1];
249                 packet_pos += 2;
250
251                 unsigned short ancount = (input[packet_pos] << 8) | input[packet_pos + 1];
252                 packet_pos += 2;
253
254                 unsigned short nscount = (input[packet_pos] << 8) | input[packet_pos + 1];
255                 packet_pos += 2;
256
257                 unsigned short arcount = (input[packet_pos] << 8) | input[packet_pos + 1];
258                 packet_pos += 2;
259
260                 ServerInstance->Logs->Log(MODNAME, LOG_DEBUG, "qdcount: " + ConvToStr(qdcount) + " ancount: " + ConvToStr(ancount) + " nscount: " + ConvToStr(nscount) + " arcount: " + ConvToStr(arcount));
261
262                 if (qdcount != 1)
263                         throw Exception("Question count != 1 in incoming packet");
264
265                 this->question = this->UnpackQuestion(input, len, packet_pos);
266
267                 for (unsigned i = 0; i < ancount; ++i)
268                         this->answers.push_back(this->UnpackResourceRecord(input, len, packet_pos));
269         }
270
271         unsigned short Pack(unsigned char* output, unsigned short output_size)
272         {
273                 if (output_size < HEADER_LENGTH)
274                         throw Exception("Unable to pack packet");
275
276                 unsigned short pos = 0;
277
278                 output[pos++] = this->id >> 8;
279                 output[pos++] = this->id & 0xFF;
280                 output[pos++] = this->flags >> 8;
281                 output[pos++] = this->flags & 0xFF;
282                 output[pos++] = 0; // Question count, high byte
283                 output[pos++] = 1; // Question count, low byte
284                 output[pos++] = 0; // Answer count, high byte
285                 output[pos++] = 0; // Answer count, low byte
286                 output[pos++] = 0;
287                 output[pos++] = 0;
288                 output[pos++] = 0;
289                 output[pos++] = 0;
290
291                 {
292                         Question& q = this->question;
293
294                         if (q.type == QUERY_PTR)
295                         {
296                                 irc::sockets::sockaddrs ip;
297                                 irc::sockets::aptosa(q.name, 0, ip);
298
299                                 if (q.name.find(':') != std::string::npos)
300                                 {
301                                         static const char* const hex = "0123456789abcdef";
302                                         char reverse_ip[128];
303                                         unsigned reverse_ip_count = 0;
304                                         for (int j = 15; j >= 0; --j)
305                                         {
306                                                 reverse_ip[reverse_ip_count++] = hex[ip.in6.sin6_addr.s6_addr[j] & 0xF];
307                                                 reverse_ip[reverse_ip_count++] = '.';
308                                                 reverse_ip[reverse_ip_count++] = hex[ip.in6.sin6_addr.s6_addr[j] >> 4];
309                                                 reverse_ip[reverse_ip_count++] = '.';
310                                         }
311                                         reverse_ip[reverse_ip_count++] = 0;
312
313                                         q.name = reverse_ip;
314                                         q.name += "ip6.arpa";
315                                 }
316                                 else
317                                 {
318                                         unsigned long forward = ip.in4.sin_addr.s_addr;
319                                         ip.in4.sin_addr.s_addr = forward << 24 | (forward & 0xFF00) << 8 | (forward & 0xFF0000) >> 8 | forward >> 24;
320
321                                         q.name = ip.addr() + ".in-addr.arpa";
322                                 }
323                         }
324
325                         this->PackName(output, output_size, pos, q.name);
326
327                         if (pos + 4 >= output_size)
328                                 throw Exception("Unable to pack packet");
329
330                         short s = htons(q.type);
331                         memcpy(&output[pos], &s, 2);
332                         pos += 2;
333
334                         // Query class, always IN
335                         output[pos++] = 0;
336                         output[pos++] = 1;
337                 }
338
339                 return pos;
340         }
341 };
342
343 class MyManager : public Manager, public Timer, public EventHandler
344 {
345         typedef TR1NS::unordered_map<Question, Query, Question::hash> cache_map;
346         cache_map cache;
347
348         irc::sockets::sockaddrs myserver;
349         bool unloading;
350
351         /** Maximum number of entries in cache
352          */
353         static const unsigned int MAX_CACHE_SIZE = 1000;
354
355         static bool IsExpired(const Query& record, time_t now = ServerInstance->Time())
356         {
357                 const ResourceRecord& req = record.answers[0];
358                 return (req.created + static_cast<time_t>(req.ttl) < now);
359         }
360
361         /** Check the DNS cache to see if request can be handled by a cached result
362          * @return true if a cached result was found.
363          */
364         bool CheckCache(DNS::Request* req, const DNS::Question& question)
365         {
366                 ServerInstance->Logs->Log(MODNAME, LOG_SPARSE, "cache: Checking cache for " + question.name);
367
368                 cache_map::iterator it = this->cache.find(question);
369                 if (it == this->cache.end())
370                         return false;
371
372                 Query& record = it->second;
373                 if (IsExpired(record))
374                 {
375                         this->cache.erase(it);
376                         return false;
377                 }
378
379                 ServerInstance->Logs->Log(MODNAME, LOG_DEBUG, "cache: Using cached result for " + question.name);
380                 record.cached = true;
381                 req->OnLookupComplete(&record);
382                 return true;
383         }
384
385         /** Add a record to the dns cache
386          * @param r The record
387          */
388         void AddCache(Query& r)
389         {
390                 if (cache.size() >= MAX_CACHE_SIZE)
391                         cache.clear();
392
393                 // Determine the lowest TTL value and use that as the TTL of the cache entry
394                 unsigned int cachettl = UINT_MAX;
395                 for (std::vector<ResourceRecord>::const_iterator i = r.answers.begin(); i != r.answers.end(); ++i)
396                 {
397                         const ResourceRecord& rr = *i;
398                         if (rr.ttl < cachettl)
399                                 cachettl = rr.ttl;
400                 }
401
402                 cachettl = std::min(cachettl, (unsigned int)5*60);
403                 ResourceRecord& rr = r.answers.front();
404                 // Set TTL to what we've determined to be the lowest
405                 rr.ttl = cachettl;
406                 ServerInstance->Logs->Log(MODNAME, LOG_DEBUG, "cache: added cache for " + rr.name + " -> " + rr.rdata + " ttl: " + ConvToStr(rr.ttl));
407                 this->cache[r.question] = r;
408         }
409
410  public:
411         DNS::Request* requests[MAX_REQUEST_ID+1];
412
413         MyManager(Module* c) : Manager(c), Timer(5*60, true)
414                 , unloading(false)
415         {
416                 for (unsigned int i = 0; i <= MAX_REQUEST_ID; ++i)
417                         requests[i] = NULL;
418                 ServerInstance->Timers.AddTimer(this);
419         }
420
421         ~MyManager()
422         {
423                 // Ensure Process() will fail for new requests
424                 unloading = true;
425
426                 for (unsigned int i = 0; i <= MAX_REQUEST_ID; ++i)
427                 {
428                         DNS::Request* request = requests[i];
429                         if (!request)
430                                 continue;
431
432                         Query rr(request->question);
433                         rr.error = ERROR_UNKNOWN;
434                         request->OnError(&rr);
435
436                         delete request;
437                 }
438         }
439
440         void Process(DNS::Request* req) CXX11_OVERRIDE
441         {
442                 if ((unloading) || (req->creator->dying))
443                         throw Exception("Module is being unloaded");
444
445                 ServerInstance->Logs->Log(MODNAME, LOG_DEBUG, "Processing request to lookup " + req->question.name + " of type " + ConvToStr(req->question.type) + " to " + this->myserver.addr());
446
447                 /* Create an id */
448                 unsigned int tries = 0;
449                 int id;
450                 do
451                 {
452                         id = ServerInstance->GenRandomInt(DNS::MAX_REQUEST_ID+1);
453
454                         if (++tries == DNS::MAX_REQUEST_ID*5)
455                         {
456                                 // If we couldn't find an empty slot this many times, do a sequential scan as a last
457                                 // resort. If an empty slot is found that way, go on, otherwise throw an exception
458                                 id = -1;
459                                 for (unsigned int i = 0; i <= DNS::MAX_REQUEST_ID; i++)
460                                 {
461                                         if (!this->requests[i])
462                                         {
463                                                 id = i;
464                                                 break;
465                                         }
466                                 }
467
468                                 if (id == -1)
469                                         throw Exception("DNS: All ids are in use");
470
471                                 break;
472                         }
473                 }
474                 while (this->requests[id]);
475
476                 req->id = id;
477                 this->requests[req->id] = req;
478
479                 Packet p;
480                 p.flags = QUERYFLAGS_RD;
481                 p.id = req->id;
482                 p.question = req->question;
483
484                 unsigned char buffer[524];
485                 unsigned short len = p.Pack(buffer, sizeof(buffer));
486
487                 /* Note that calling Pack() above can actually change the contents of p.question.name, if the query is a PTR,
488                  * to contain the value that would be in the DNS cache, which is why this is here.
489                  */
490                 if (req->use_cache && this->CheckCache(req, p.question))
491                 {
492                         ServerInstance->Logs->Log(MODNAME, LOG_DEBUG, "Using cached result");
493                         delete req;
494                         return;
495                 }
496
497                 // Update name in the original request so question checking works for PTR queries
498                 req->question.name = p.question.name;
499
500                 if (SocketEngine::SendTo(this, buffer, len, 0, this->myserver) != len)
501                         throw Exception("DNS: Unable to send query");
502
503                 // Add timer for timeout
504                 ServerInstance->Timers.AddTimer(req);
505         }
506
507         void RemoveRequest(DNS::Request* req) CXX11_OVERRIDE
508         {
509                 if (requests[req->id] == req)
510                         requests[req->id] = NULL;
511         }
512
513         std::string GetErrorStr(Error e) CXX11_OVERRIDE
514         {
515                 switch (e)
516                 {
517                         case ERROR_UNLOADED:
518                                 return "Module is unloading";
519                         case ERROR_TIMEDOUT:
520                                 return "Request timed out";
521                         case ERROR_NOT_AN_ANSWER:
522                         case ERROR_NONSTANDARD_QUERY:
523                         case ERROR_FORMAT_ERROR:
524                         case ERROR_MALFORMED:
525                                 return "Malformed answer";
526                         case ERROR_SERVER_FAILURE:
527                         case ERROR_NOT_IMPLEMENTED:
528                         case ERROR_REFUSED:
529                         case ERROR_INVALIDTYPE:
530                                 return "Nameserver failure";
531                         case ERROR_DOMAIN_NOT_FOUND:
532                         case ERROR_NO_RECORDS:
533                                 return "Domain not found";
534                         case ERROR_NONE:
535                         case ERROR_UNKNOWN:
536                         default:
537                                 return "Unknown error";
538                 }
539         }
540
541         void OnEventHandlerError(int errcode) CXX11_OVERRIDE
542         {
543                 ServerInstance->Logs->Log(MODNAME, LOG_DEBUG, "UDP socket got an error event");
544         }
545
546         void OnEventHandlerRead() CXX11_OVERRIDE
547         {
548                 unsigned char buffer[524];
549                 irc::sockets::sockaddrs from;
550                 socklen_t x = sizeof(from);
551
552                 int length = SocketEngine::RecvFrom(this, buffer, sizeof(buffer), 0, &from.sa, &x);
553
554                 if (length < Packet::HEADER_LENGTH)
555                         return;
556
557                 if (myserver != from)
558                 {
559                         std::string server1 = from.str();
560                         std::string server2 = myserver.str();
561                         ServerInstance->Logs->Log(MODNAME, LOG_DEBUG, "Got a result from the wrong server! Bad NAT or DNS forging attempt? '%s' != '%s'",
562                                 server1.c_str(), server2.c_str());
563                         return;
564                 }
565
566                 Packet recv_packet;
567                 bool valid = false;
568
569                 try
570                 {
571                         recv_packet.Fill(buffer, length);
572                         valid = true;
573                 }
574                 catch (Exception& ex)
575                 {
576                         ServerInstance->Logs->Log(MODNAME, LOG_DEBUG, ex.GetReason());
577                 }
578
579                 // recv_packet.id must be filled in here
580                 DNS::Request* request = this->requests[recv_packet.id];
581                 if (request == NULL)
582                 {
583                         ServerInstance->Logs->Log(MODNAME, LOG_DEBUG, "Received an answer for something we didn't request");
584                         return;
585                 }
586
587                 if (request->question != recv_packet.question)
588                 {
589                         // This can happen under high latency, drop it silently, do not fail the request
590                         ServerInstance->Logs->Log(MODNAME, LOG_DEBUG, "Received an answer that isn't for a question we asked");
591                         return;
592                 }
593
594                 if (!valid)
595                 {
596                         ServerInstance->stats.DnsBad++;
597                         recv_packet.error = ERROR_MALFORMED;
598                         request->OnError(&recv_packet);
599                 }
600                 else if (recv_packet.flags & QUERYFLAGS_OPCODE)
601                 {
602                         ServerInstance->Logs->Log(MODNAME, LOG_DEBUG, "Received a nonstandard query");
603                         ServerInstance->stats.DnsBad++;
604                         recv_packet.error = ERROR_NONSTANDARD_QUERY;
605                         request->OnError(&recv_packet);
606                 }
607                 else if (!(recv_packet.flags & QUERYFLAGS_QR) || (recv_packet.flags & QUERYFLAGS_RCODE))
608                 {
609                         Error error = ERROR_UNKNOWN;
610
611                         switch (recv_packet.flags & QUERYFLAGS_RCODE)
612                         {
613                                 case 1:
614                                         ServerInstance->Logs->Log(MODNAME, LOG_DEBUG, "format error");
615                                         error = ERROR_FORMAT_ERROR;
616                                         break;
617                                 case 2:
618                                         ServerInstance->Logs->Log(MODNAME, LOG_DEBUG, "server error");
619                                         error = ERROR_SERVER_FAILURE;
620                                         break;
621                                 case 3:
622                                         ServerInstance->Logs->Log(MODNAME, LOG_DEBUG, "domain not found");
623                                         error = ERROR_DOMAIN_NOT_FOUND;
624                                         break;
625                                 case 4:
626                                         ServerInstance->Logs->Log(MODNAME, LOG_DEBUG, "not implemented");
627                                         error = ERROR_NOT_IMPLEMENTED;
628                                         break;
629                                 case 5:
630                                         ServerInstance->Logs->Log(MODNAME, LOG_DEBUG, "refused");
631                                         error = ERROR_REFUSED;
632                                         break;
633                                 default:
634                                         break;
635                         }
636
637                         ServerInstance->stats.DnsBad++;
638                         recv_packet.error = error;
639                         request->OnError(&recv_packet);
640                 }
641                 else if (recv_packet.answers.empty())
642                 {
643                         ServerInstance->Logs->Log(MODNAME, LOG_DEBUG, "No resource records returned");
644                         ServerInstance->stats.DnsBad++;
645                         recv_packet.error = ERROR_NO_RECORDS;
646                         request->OnError(&recv_packet);
647                 }
648                 else
649                 {
650                         ServerInstance->Logs->Log(MODNAME, LOG_DEBUG, "Lookup complete for " + request->question.name);
651                         ServerInstance->stats.DnsGood++;
652                         request->OnLookupComplete(&recv_packet);
653                         this->AddCache(recv_packet);
654                 }
655
656                 ServerInstance->stats.Dns++;
657
658                 /* Request's destructor removes it from the request map */
659                 delete request;
660         }
661
662         bool Tick(time_t now) CXX11_OVERRIDE
663         {
664                 ServerInstance->Logs->Log(MODNAME, LOG_DEBUG, "cache: purging DNS cache");
665
666                 for (cache_map::iterator it = this->cache.begin(); it != this->cache.end(); )
667                 {
668                         const Query& query = it->second;
669                         if (IsExpired(query, now))
670                                 this->cache.erase(it++);
671                         else
672                                 ++it;
673                 }
674                 return true;
675         }
676
677         void Rehash(const std::string& dnsserver, std::string sourceaddr, unsigned int sourceport)
678         {
679                 if (this->GetFd() > -1)
680                 {
681                         SocketEngine::Shutdown(this, 2);
682                         SocketEngine::Close(this);
683
684                         /* Remove expired entries from the cache */
685                         this->Tick(ServerInstance->Time());
686                 }
687
688                 irc::sockets::aptosa(dnsserver, DNS::PORT, myserver);
689
690                 /* Initialize mastersocket */
691                 int s = socket(myserver.family(), SOCK_DGRAM, 0);
692                 this->SetFd(s);
693
694                 /* Have we got a socket? */
695                 if (this->GetFd() != -1)
696                 {
697                         SocketEngine::SetReuse(s);
698                         SocketEngine::NonBlocking(s);
699
700                         irc::sockets::sockaddrs bindto;
701                         if (sourceaddr.empty())
702                         {
703                                 // set a sourceaddr for irc::sockets::aptosa() based on the servers af type
704                                 if (myserver.family() == AF_INET)
705                                         sourceaddr = "0.0.0.0";
706                                 else if (myserver.family() == AF_INET6)
707                                         sourceaddr = "::";
708                         }
709                         irc::sockets::aptosa(sourceaddr, sourceport, bindto);
710
711                         if (SocketEngine::Bind(this->GetFd(), bindto) < 0)
712                         {
713                                 /* Failed to bind */
714                                 ServerInstance->Logs->Log(MODNAME, LOG_SPARSE, "Error binding dns socket - hostnames will NOT resolve");
715                                 SocketEngine::Close(this->GetFd());
716                                 this->SetFd(-1);
717                         }
718                         else if (!SocketEngine::AddFd(this, FD_WANT_POLL_READ | FD_WANT_NO_WRITE))
719                         {
720                                 ServerInstance->Logs->Log(MODNAME, LOG_SPARSE, "Internal error starting DNS - hostnames will NOT resolve.");
721                                 SocketEngine::Close(this->GetFd());
722                                 this->SetFd(-1);
723                         }
724
725                         if (bindto.family() != myserver.family())
726                                 ServerInstance->Logs->Log(MODNAME, LOG_SPARSE, "Nameserver address family differs from source address family - hostnames might not resolve");
727                 }
728                 else
729                 {
730                         ServerInstance->Logs->Log(MODNAME, LOG_SPARSE, "Error creating DNS socket - hostnames will NOT resolve");
731                 }
732         }
733 };
734
735 class ModuleDNS : public Module
736 {
737         MyManager manager;
738         std::string DNSServer;
739         std::string SourceIP;
740         unsigned int SourcePort;
741
742         void FindDNSServer()
743         {
744 #ifdef _WIN32
745                 // attempt to look up their nameserver from the system
746                 ServerInstance->Logs->Log(MODNAME, LOG_DEFAULT, "WARNING: <dns:server> not defined, attempting to find a working server in the system settings...");
747
748                 PFIXED_INFO pFixedInfo;
749                 DWORD dwBufferSize = sizeof(FIXED_INFO);
750                 pFixedInfo = (PFIXED_INFO) HeapAlloc(GetProcessHeap(), 0, sizeof(FIXED_INFO));
751
752                 if (pFixedInfo)
753                 {
754                         if (GetNetworkParams(pFixedInfo, &dwBufferSize) == ERROR_BUFFER_OVERFLOW)
755                         {
756                                 HeapFree(GetProcessHeap(), 0, pFixedInfo);
757                                 pFixedInfo = (PFIXED_INFO) HeapAlloc(GetProcessHeap(), 0, dwBufferSize);
758                         }
759
760                         if (pFixedInfo)
761                         {
762                                 if (GetNetworkParams(pFixedInfo, &dwBufferSize) == NO_ERROR)
763                                         DNSServer = pFixedInfo->DnsServerList.IpAddress.String;
764
765                                 HeapFree(GetProcessHeap(), 0, pFixedInfo);
766                         }
767
768                         if (!DNSServer.empty())
769                         {
770                                 ServerInstance->Logs->Log(MODNAME, LOG_DEFAULT, "<dns:server> set to '%s' as first active resolver in the system settings.", DNSServer.c_str());
771                                 return;
772                         }
773                 }
774
775                 ServerInstance->Logs->Log(MODNAME, LOG_DEFAULT, "No viable nameserver found! Defaulting to nameserver '127.0.0.1'!");
776 #else
777                 // attempt to look up their nameserver from /etc/resolv.conf
778                 ServerInstance->Logs->Log(MODNAME, LOG_DEFAULT, "WARNING: <dns:server> not defined, attempting to find working server in /etc/resolv.conf...");
779
780                 std::ifstream resolv("/etc/resolv.conf");
781
782                 while (resolv >> DNSServer)
783                 {
784                         if (DNSServer == "nameserver")
785                         {
786                                 resolv >> DNSServer;
787                                 if (DNSServer.find_first_not_of("0123456789.") == std::string::npos || DNSServer.find_first_not_of("0123456789ABCDEFabcdef:") == std::string::npos)
788                                 {
789                                         ServerInstance->Logs->Log(MODNAME, LOG_DEFAULT, "<dns:server> set to '%s' as first resolver in /etc/resolv.conf.",DNSServer.c_str());
790                                         return;
791                                 }
792                         }
793                 }
794
795                 ServerInstance->Logs->Log(MODNAME, LOG_DEFAULT, "/etc/resolv.conf contains no viable nameserver entries! Defaulting to nameserver '127.0.0.1'!");
796 #endif
797                 DNSServer = "127.0.0.1";
798         }
799
800  public:
801         ModuleDNS() : manager(this)
802                 , SourcePort(0)
803         {
804         }
805
806         void ReadConfig(ConfigStatus& status) CXX11_OVERRIDE
807         {
808                 std::string oldserver = DNSServer;
809                 const std::string oldip = SourceIP;
810                 const unsigned int oldport = SourcePort;
811
812                 ConfigTag* tag = ServerInstance->Config->ConfValue("dns");
813                 DNSServer = tag->getString("server");
814                 SourceIP = tag->getString("sourceip");
815                 SourcePort = tag->getUInt("sourceport", 0, 0, UINT16_MAX);
816
817                 if (DNSServer.empty())
818                         FindDNSServer();
819
820                 if (oldserver != DNSServer || oldip != SourceIP || oldport != SourcePort)
821                         this->manager.Rehash(DNSServer, SourceIP, SourcePort);
822         }
823
824         void OnUnloadModule(Module* mod) CXX11_OVERRIDE
825         {
826                 for (unsigned int i = 0; i <= MAX_REQUEST_ID; ++i)
827                 {
828                         DNS::Request* req = this->manager.requests[i];
829                         if (!req)
830                                 continue;
831
832                         if (req->creator == mod)
833                         {
834                                 Query rr(req->question);
835                                 rr.error = ERROR_UNLOADED;
836                                 req->OnError(&rr);
837
838                                 delete req;
839                         }
840                 }
841         }
842
843         Version GetVersion() CXX11_OVERRIDE
844         {
845                 return Version("Provides support for DNS lookups", VF_CORE|VF_VENDOR);
846         }
847 };
848
// Module entry point for ModuleDNS. MODULE_INIT is defined outside this file;
// presumably it emits the factory the core uses to instantiate the module — confirm.
MODULE_INIT(ModuleDNS)
850