]> git.netwichtig.de Git - user/henk/code/inspircd.git/blob - src/modules/extra/m_ssl_gnutls.cpp
Revert earlier time() -> SI->Time() diff for now, this causes problems with dns.cpp...
[user/henk/code/inspircd.git] / src / modules / extra / m_ssl_gnutls.cpp
1 /*       +------------------------------------+
2  *       | Inspire Internet Relay Chat Daemon |
3  *       +------------------------------------+
4  *
5  *  InspIRCd: (C) 2002-2008 InspIRCd Development Team
6  * See: http://www.inspircd.org/wiki/index.php/Credits
7  *
8  * This program is free but copyrighted software; see
9  *          the file COPYING for details.
10  *
11  * ---------------------------------------------------
12  */
13
14 #include "inspircd.h"
15 #include <gnutls/gnutls.h>
16 #include <gnutls/x509.h>
17 #include "transport.h"
18 #include "m_cap.h"
19
20 #ifdef WINDOWS
21 #pragma comment(lib, "libgnutls-13.lib")
22 #endif
23
24 /* $ModDesc: Provides SSL support for clients */
25 /* $CompileFlags: exec("libgnutls-config --cflags") */
26 /* $LinkerFlags: rpath("libgnutls-config --libs") exec("libgnutls-config --libs") */
27 /* $ModDep: transport.h */
28 /* $CopyInstall: conf/key.pem $(CONPATH) */
29 /* $CopyInstall: conf/cert.pem $(CONPATH) */
30
/** State of the TLS session kept for one file descriptor.
 * Enumerators keep their implicit values (0..5) in this order.
 */
enum issl_status
{
        ISSL_NONE,              // nothing set up yet
        ISSL_HANDSHAKING_READ,  // handshake paused, GnuTLS wants to read
        ISSL_HANDSHAKING_WRITE, // handshake paused, GnuTLS wants to write
        ISSL_HANDSHAKEN,        // handshake finished, records can flow
        ISSL_CLOSING,           // set when the handshake fails
        ISSL_CLOSED
};
32
33 bool isin(const std::string &host, int port, const std::vector<std::string> &portlist)
34 {
35         if (std::find(portlist.begin(), portlist.end(), "*:" + ConvToStr(port)) != portlist.end())
36                 return true;
37
38         if (std::find(portlist.begin(), portlist.end(), ":" + ConvToStr(port)) != portlist.end())
39                 return true;
40
41         return std::find(portlist.begin(), portlist.end(), host + ":" + ConvToStr(port)) != portlist.end();
42 }
43
44 /** Represents an SSL user's extra data
45  */
46 class issl_session : public classbase
47 {
48 public:
49         issl_session()
50         {
51                 sess = NULL;
52         }
53
54         gnutls_session_t sess;
55         issl_status status;
56         std::string outbuf;
57         int inbufoffset;
58         char* inbuf;
59         int fd;
60 };
61
62 class CommandStartTLS : public Command
63 {
64         Module* Caller;
65  public:
66         /* Command 'dalinfo', takes no parameters and needs no special modes */
67         CommandStartTLS (InspIRCd* Instance, Module* mod) : Command(Instance,"STARTTLS", 0, 0, true), Caller(mod)
68         {
69                 this->source = "m_ssl_gnutls.so";
70         }
71
72         CmdResult Handle (const std::vector<std::string> &parameters, User *user)
73         {
74                 if (user->registered == REG_ALL)
75                 {
76                         ServerInstance->Users->QuitUser(user, "STARTTLS not allowed after client registration");
77                 }
78                 else
79                 {
80                         if (!user->GetIOHook())
81                         {
82                                 user->WriteNumeric(670, "%s :STARTTLS successful, go ahead with TLS handshake", user->nick.c_str());
83                                 user->AddIOHook(Caller);
84                                 Caller->OnRawSocketAccept(user->GetFd(), user->GetIPString(), user->GetPort());
85                         }
86                         else
87                                 user->WriteNumeric(671, "%s :STARTTLS failure", user->nick.c_str());
88                 }
89
90                 return CMD_FAILURE;
91         }
92 };
93
94 class ModuleSSLGnuTLS : public Module
95 {
96
97         ConfigReader* Conf;
98
99         char* dummy;
100
101         std::vector<std::string> listenports;
102
103         int inbufsize;
104         issl_session* sessions;
105
106         gnutls_certificate_credentials x509_cred;
107         gnutls_dh_params dh_params;
108
109         std::string keyfile;
110         std::string certfile;
111         std::string cafile;
112         std::string crlfile;
113         std::string sslports;
114         int dh_bits;
115
116         int clientactive;
117         bool cred_alloc;
118
119         CommandStartTLS* starttls;
120
121  public:
122
123         ModuleSSLGnuTLS(InspIRCd* Me)
124                 : Module(Me)
125         {
126                 ServerInstance->Modules->PublishInterface("BufferedSocketHook", this);
127
128                 sessions = new issl_session[ServerInstance->SE->GetMaxFds()];
129
130                 // Not rehashable...because I cba to reduce all the sizes of existing buffers.
131                 inbufsize = ServerInstance->Config->NetBufferSize;
132
133                 gnutls_global_init(); // This must be called once in the program
134
135                 cred_alloc = false;
136                 // Needs the flag as it ignores a plain /rehash
137                 OnRehash(NULL,"ssl");
138
139                 // Void return, guess we assume success
140                 gnutls_certificate_set_dh_params(x509_cred, dh_params);
141                 Implementation eventlist[] = { I_On005Numeric, I_OnRawSocketConnect, I_OnRawSocketAccept, I_OnRawSocketClose, I_OnRawSocketRead, I_OnRawSocketWrite, I_OnCleanup,
142                         I_OnBufferFlushed, I_OnRequest, I_OnSyncUserMetaData, I_OnDecodeMetaData, I_OnUnloadModule, I_OnRehash, I_OnWhois, I_OnPostConnect, I_OnEvent, I_OnHookUserIO };
143                 ServerInstance->Modules->Attach(eventlist, this, 17);
144
145                 starttls = new CommandStartTLS(ServerInstance, this);
146                 ServerInstance->AddCommand(starttls);
147         }
148
149         virtual void OnRehash(User* user, const std::string &param)
150         {
151                 Conf = new ConfigReader(ServerInstance);
152
153                 listenports.clear();
154                 clientactive = 0;
155                 sslports.clear();
156
157                 for(int index = 0; index < Conf->Enumerate("bind"); index++)
158                 {
159                         // For each <bind> tag
160                         std::string x = Conf->ReadValue("bind", "type", index);
161                         if(((x.empty()) || (x == "clients")) && (Conf->ReadValue("bind", "ssl", index) == "gnutls"))
162                         {
163                                 // Get the port we're meant to be listening on with SSL
164                                 std::string port = Conf->ReadValue("bind", "port", index);
165                                 std::string addr = Conf->ReadValue("bind", "address", index);
166
167                                 irc::portparser portrange(port, false);
168                                 long portno = -1;
169                                 while ((portno = portrange.GetToken()))
170                                 {
171                                         clientactive++;
172                                         try
173                                         {
174                                                 listenports.push_back(addr + ":" + ConvToStr(portno));
175
176                                                 for (size_t i = 0; i < ServerInstance->Config->ports.size(); i++)
177                                                         if ((ServerInstance->Config->ports[i]->GetPort() == portno) && (ServerInstance->Config->ports[i]->GetIP() == addr))
178                                                                 ServerInstance->Config->ports[i]->SetDescription("ssl");
179                                                 ServerInstance->Logs->Log("m_ssl_gnutls",DEFAULT, "m_ssl_gnutls.so: Enabling SSL for port %ld", portno);
180
181                                                 sslports.append((addr.empty() ? "*" : addr)).append(":").append(ConvToStr(portno)).append(";");
182                                         }
183                                         catch (ModuleException &e)
184                                         {
185                                                 ServerInstance->Logs->Log("m_ssl_gnutls",DEFAULT, "m_ssl_gnutls.so: FAILED to enable SSL on port %ld: %s. Maybe it's already hooked by the same port on a different IP, or you have an other SSL or similar module loaded?", portno, e.GetReason());
186                                         }
187                                 }
188                         }
189                 }
190
191                 if (!sslports.empty())
192                         sslports.erase(sslports.end() - 1);
193
194                 if(param != "ssl")
195                 {
196                         delete Conf;
197                         return;
198                 }
199
200                 std::string confdir(ServerInstance->ConfigFileName);
201                 // +1 so we the path ends with a /
202                 confdir = confdir.substr(0, confdir.find_last_of('/') + 1);
203
204                 cafile  = Conf->ReadValue("gnutls", "cafile", 0);
205                 crlfile = Conf->ReadValue("gnutls", "crlfile", 0);
206                 certfile        = Conf->ReadValue("gnutls", "certfile", 0);
207                 keyfile = Conf->ReadValue("gnutls", "keyfile", 0);
208                 dh_bits = Conf->ReadInteger("gnutls", "dhbits", 0, false);
209
210                 // Set all the default values needed.
211                 if (cafile.empty())
212                         cafile = "ca.pem";
213
214                 if (crlfile.empty())
215                         crlfile = "crl.pem";
216
217                 if (certfile.empty())
218                         certfile = "cert.pem";
219
220                 if (keyfile.empty())
221                         keyfile = "key.pem";
222
223                 if((dh_bits != 768) && (dh_bits != 1024) && (dh_bits != 2048) && (dh_bits != 3072) && (dh_bits != 4096))
224                         dh_bits = 1024;
225
226                 // Prepend relative paths with the path to the config directory.
227                 if ((cafile[0] != '/') && (!ServerInstance->Config->StartsWithWindowsDriveLetter(cafile)))
228                         cafile = confdir + cafile;
229
230                 if ((crlfile[0] != '/') && (!ServerInstance->Config->StartsWithWindowsDriveLetter(crlfile)))
231                         crlfile = confdir + crlfile;
232
233                 if ((certfile[0] != '/') && (!ServerInstance->Config->StartsWithWindowsDriveLetter(certfile)))
234                         certfile = confdir + certfile;
235
236                 if ((keyfile[0] != '/') && (!ServerInstance->Config->StartsWithWindowsDriveLetter(keyfile)))
237                         keyfile = confdir + keyfile;
238
239                 int ret;
240                 
241                 if (cred_alloc)
242                 {
243                         // Deallocate the old credentials
244                         gnutls_dh_params_deinit(dh_params);
245                         gnutls_certificate_free_credentials(x509_cred);
246                 }
247                 else
248                         cred_alloc = true;
249                 
250                 if((ret = gnutls_certificate_allocate_credentials(&x509_cred)) < 0)
251                         ServerInstance->Logs->Log("m_ssl_gnutls",DEFAULT, "m_ssl_gnutls.so: Failed to allocate certificate credentials: %s", gnutls_strerror(ret));
252                 
253                 if((ret = gnutls_dh_params_init(&dh_params)) < 0)
254                         ServerInstance->Logs->Log("m_ssl_gnutls",DEFAULT, "m_ssl_gnutls.so: Failed to initialise DH parameters: %s", gnutls_strerror(ret));
255                 
256                 if((ret =gnutls_certificate_set_x509_trust_file(x509_cred, cafile.c_str(), GNUTLS_X509_FMT_PEM)) < 0)
257                         ServerInstance->Logs->Log("m_ssl_gnutls",DEFAULT, "m_ssl_gnutls.so: Failed to set X.509 trust file '%s': %s", cafile.c_str(), gnutls_strerror(ret));
258
259                 if((ret = gnutls_certificate_set_x509_crl_file (x509_cred, crlfile.c_str(), GNUTLS_X509_FMT_PEM)) < 0)
260                         ServerInstance->Logs->Log("m_ssl_gnutls",DEFAULT, "m_ssl_gnutls.so: Failed to set X.509 CRL file '%s': %s", crlfile.c_str(), gnutls_strerror(ret));
261
262                 if((ret = gnutls_certificate_set_x509_key_file (x509_cred, certfile.c_str(), keyfile.c_str(), GNUTLS_X509_FMT_PEM)) < 0)
263                 {
264                         // If this fails, no SSL port will work. At all. So, do the smart thing - throw a ModuleException
265                         throw ModuleException("Unable to load GnuTLS server certificate: " + std::string(gnutls_strerror(ret)));
266                 }
267
268                 // This may be on a large (once a day or week) timer eventually.
269                 GenerateDHParams();
270
271                 delete Conf;
272         }
273
274         void GenerateDHParams()
275         {
276                 // Generate Diffie Hellman parameters - for use with DHE
277                 // kx algorithms. These should be discarded and regenerated
278                 // once a day, once a week or once a month. Depending on the
279                 // security requirements.
280
281                 int ret;
282
283                 if((ret = gnutls_dh_params_generate2(dh_params, dh_bits)) < 0)
284                         ServerInstance->Logs->Log("m_ssl_gnutls",DEFAULT, "m_ssl_gnutls.so: Failed to generate DH parameters (%d bits): %s", dh_bits, gnutls_strerror(ret));
285         }
286
287         virtual ~ModuleSSLGnuTLS()
288         {
289                 gnutls_dh_params_deinit(dh_params);
290                 gnutls_certificate_free_credentials(x509_cred);
291                 gnutls_global_deinit();
292                 ServerInstance->Modules->UnpublishInterface("BufferedSocketHook", this);
293                 delete[] sessions;
294         }
295
296         virtual void OnCleanup(int target_type, void* item)
297         {
298                 if(target_type == TYPE_USER)
299                 {
300                         User* user = (User*)item;
301
302                         if (user->GetIOHook() == this)
303                         {
304                                 // User is using SSL, they're a local user, and they're using one of *our* SSL ports.
305                                 // Potentially there could be multiple SSL modules loaded at once on different ports.
306                                 ServerInstance->Users->QuitUser(user, "SSL module unloading");
307                                 user->DelIOHook();
308                         }
309                         if (user->GetExt("ssl_cert", dummy))
310                         {
311                                 ssl_cert* tofree;
312                                 user->GetExt("ssl_cert", tofree);
313                                 delete tofree;
314                                 user->Shrink("ssl_cert");
315                         }
316                 }
317         }
318
319         virtual void OnUnloadModule(Module* mod, const std::string &name)
320         {
321                 if(mod == this)
322                 {
323                         for(unsigned int i = 0; i < listenports.size(); i++)
324                         {
325                                 for (size_t j = 0; j < ServerInstance->Config->ports.size(); j++)
326                                         if (listenports[i] == (ServerInstance->Config->ports[j]->GetIP()+":"+ConvToStr(ServerInstance->Config->ports[j]->GetPort())))
327                                                 ServerInstance->Config->ports[j]->SetDescription("plaintext");
328                         }
329                 }
330         }
331
332         virtual Version GetVersion()
333         {
334                 return Version("$Id$", VF_VENDOR, API_VERSION);
335         }
336
337
338         virtual void On005Numeric(std::string &output)
339         {
340                 output.append(" SSL=" + sslports);
341         }
342
343         virtual void OnHookUserIO(User* user, const std::string &targetip)
344         {
345                 if (!user->GetIOHook() && isin(targetip,user->GetPort(),listenports))
346                 {
347                         /* Hook the user with our module */
348                         user->AddIOHook(this);
349                 }
350         }
351
352         virtual const char* OnRequest(Request* request)
353         {
354                 ISHRequest* ISR = (ISHRequest*)request;
355                 if (strcmp("IS_NAME", request->GetId()) == 0)
356                 {
357                         return "gnutls";
358                 }
359                 else if (strcmp("IS_HOOK", request->GetId()) == 0)
360                 {
361                         const char* ret = "OK";
362                         try
363                         {
364                                 ret = ISR->Sock->AddIOHook((Module*)this) ? "OK" : NULL;
365                         }
366                         catch (ModuleException &e)
367                         {
368                                 return NULL;
369                         }
370                         return ret;
371                 }
372                 else if (strcmp("IS_UNHOOK", request->GetId()) == 0)
373                 {
374                         return ISR->Sock->DelIOHook() ? "OK" : NULL;
375                 }
376                 else if (strcmp("IS_HSDONE", request->GetId()) == 0)
377                 {
378                         if (ISR->Sock->GetFd() < 0)
379                                 return "OK";
380
381                         issl_session* session = &sessions[ISR->Sock->GetFd()];
382                         return (session->status == ISSL_HANDSHAKING_READ || session->status == ISSL_HANDSHAKING_WRITE) ? NULL : "OK";
383                 }
384                 else if (strcmp("IS_ATTACH", request->GetId()) == 0)
385                 {
386                         if (ISR->Sock->GetFd() > -1)
387                         {
388                                 issl_session* session = &sessions[ISR->Sock->GetFd()];
389                                 if (session->sess)
390                                 {
391                                         if ((Extensible*)ServerInstance->FindDescriptor(ISR->Sock->GetFd()) == (Extensible*)(ISR->Sock))
392                                         {
393                                                 VerifyCertificate(session, (BufferedSocket*)ISR->Sock);
394                                                 return "OK";
395                                         }
396                                 }
397                         }
398                 }
399                 return NULL;
400         }
401
402
403         virtual void OnRawSocketAccept(int fd, const std::string &ip, int localport)
404         {
405                 /* Are there any possibilities of an out of range fd? Hope not, but lets be paranoid */
406                 if ((fd < 0) || (fd > ServerInstance->SE->GetMaxFds() - 1))
407                         return;
408
409                 issl_session* session = &sessions[fd];
410
411                 /* For STARTTLS: Don't try and init a session on a socket that already has a session */
412                 if (session->sess)
413                         return;
414
415                 session->fd = fd;
416                 session->inbuf = new char[inbufsize];
417                 session->inbufoffset = 0;
418
419                 gnutls_init(&session->sess, GNUTLS_SERVER);
420
421                 gnutls_set_default_priority(session->sess); // Avoid calling all the priority functions, defaults are adequate.
422                 gnutls_credentials_set(session->sess, GNUTLS_CRD_CERTIFICATE, x509_cred);
423                 gnutls_dh_set_prime_bits(session->sess, dh_bits);
424
425                 /* This is an experimental change to avoid a warning on 64bit systems about casting between integer and pointer of different sizes
426                  * This needs testing, but it's easy enough to rollback if need be
427                  * Old: gnutls_transport_set_ptr(session->sess, (gnutls_transport_ptr_t) fd); // Give gnutls the fd for the socket.
428                  * New: gnutls_transport_set_ptr(session->sess, &fd); // Give gnutls the fd for the socket.
429                  *
430                  * With testing this seems to...not work :/
431                  */
432
433                 gnutls_transport_set_ptr(session->sess, (gnutls_transport_ptr_t) fd); // Give gnutls the fd for the socket.
434
435                 gnutls_certificate_server_set_request(session->sess, GNUTLS_CERT_REQUEST); // Request client certificate if any.
436
437                 Handshake(session);
438         }
439
440         virtual void OnRawSocketConnect(int fd)
441         {
442                 /* Are there any possibilities of an out of range fd? Hope not, but lets be paranoid */
443                 if ((fd < 0) || (fd > ServerInstance->SE->GetMaxFds() - 1))
444                         return;
445
446                 issl_session* session = &sessions[fd];
447
448                 session->fd = fd;
449                 session->inbuf = new char[inbufsize];
450                 session->inbufoffset = 0;
451
452                 gnutls_init(&session->sess, GNUTLS_CLIENT);
453
454                 gnutls_set_default_priority(session->sess); // Avoid calling all the priority functions, defaults are adequate.
455                 gnutls_credentials_set(session->sess, GNUTLS_CRD_CERTIFICATE, x509_cred);
456                 gnutls_dh_set_prime_bits(session->sess, dh_bits);
457                 gnutls_transport_set_ptr(session->sess, (gnutls_transport_ptr_t) fd); // Give gnutls the fd for the socket.
458
459                 Handshake(session);
460         }
461
462         virtual void OnRawSocketClose(int fd)
463         {
464                 /* Are there any possibilities of an out of range fd? Hope not, but lets be paranoid */
465                 if ((fd < 0) || (fd > ServerInstance->SE->GetMaxFds()))
466                         return;
467
468                 CloseSession(&sessions[fd]);
469
470                 EventHandler* user = ServerInstance->SE->GetRef(fd);
471
472                 if ((user) && (user->GetExt("ssl_cert", dummy)))
473                 {
474                         ssl_cert* tofree;
475                         user->GetExt("ssl_cert", tofree);
476                         delete tofree;
477                         user->Shrink("ssl_cert");
478                 }
479         }
480
481         virtual int OnRawSocketRead(int fd, char* buffer, unsigned int count, int &readresult)
482         {
483                 /* Are there any possibilities of an out of range fd? Hope not, but lets be paranoid */
484                 if ((fd < 0) || (fd > ServerInstance->SE->GetMaxFds() - 1))
485                         return 0;
486
487                 issl_session* session = &sessions[fd];
488
489                 if (!session->sess)
490                 {
491                         readresult = 0;
492                         CloseSession(session);
493                         return 1;
494                 }
495
496                 if (session->status == ISSL_HANDSHAKING_READ)
497                 {
498                         // The handshake isn't finished, try to finish it.
499
500                         if(!Handshake(session))
501                         {
502                                 // Couldn't resume handshake.
503                                 return -1;
504                         }
505                 }
506                 else if (session->status == ISSL_HANDSHAKING_WRITE)
507                 {
508                         errno = EAGAIN;
509                         MakePollWrite(session);
510                         return -1;
511                 }
512
513                 // If we resumed the handshake then session->status will be ISSL_HANDSHAKEN.
514
515                 if (session->status == ISSL_HANDSHAKEN)
516                 {
517                         // Is this right? Not sure if the unencrypted data is garaunteed to be the same length.
518                         // Read into the inbuffer, offset from the beginning by the amount of data we have that insp hasn't taken yet.
519                         int ret = gnutls_record_recv(session->sess, session->inbuf + session->inbufoffset, inbufsize - session->inbufoffset);
520
521                         if (ret == 0)
522                         {
523                                 // Client closed connection.
524                                 readresult = 0;
525                                 CloseSession(session);
526                                 return 1;
527                         }
528                         else if (ret < 0)
529                         {
530                                 if (ret == GNUTLS_E_AGAIN || ret == GNUTLS_E_INTERRUPTED)
531                                 {
532                                         errno = EAGAIN;
533                                         return -1;
534                                 }
535                                 else
536                                 {
537                                         readresult = 0;
538                                         CloseSession(session);
539                                 }
540                         }
541                         else
542                         {
543                                 // Read successfully 'ret' bytes into inbuf + inbufoffset
544                                 // There are 'ret' + 'inbufoffset' bytes of data in 'inbuf'
545                                 // 'buffer' is 'count' long
546
547                                 unsigned int length = ret + session->inbufoffset;
548
549                                 if(count <= length)
550                                 {
551                                         memcpy(buffer, session->inbuf, count);
552                                         // Move the stuff left in inbuf to the beginning of it
553                                         memmove(session->inbuf, session->inbuf + count, (length - count));
554                                         // Now we need to set session->inbufoffset to the amount of data still waiting to be handed to insp.
555                                         session->inbufoffset = length - count;
556                                         // Insp uses readresult as the count of how much data there is in buffer, so:
557                                         readresult = count;
558                                 }
559                                 else
560                                 {
561                                         // There's not as much in the inbuf as there is space in the buffer, so just copy the whole thing.
562                                         memcpy(buffer, session->inbuf, length);
563                                         // Zero the offset, as there's nothing there..
564                                         session->inbufoffset = 0;
565                                         // As above
566                                         readresult = length;
567                                 }
568                         }
569                 }
570                 else if(session->status == ISSL_CLOSING)
571                         readresult = 0;
572
573                 return 1;
574         }
575
576         virtual int OnRawSocketWrite(int fd, const char* buffer, int count)
577         {
578                 /* Are there any possibilities of an out of range fd? Hope not, but lets be paranoid */
579                 if ((fd < 0) || (fd > ServerInstance->SE->GetMaxFds() - 1))
580                         return 0;
581
582                 issl_session* session = &sessions[fd];
583                 const char* sendbuffer = buffer;
584
585                 if (!session->sess)
586                 {
587                         CloseSession(session);
588                         return 1;
589                 }
590
591                 session->outbuf.append(sendbuffer, count);
592                 sendbuffer = session->outbuf.c_str();
593                 count = session->outbuf.size();
594
595                 if (session->status == ISSL_HANDSHAKING_WRITE)
596                 {
597                         // The handshake isn't finished, try to finish it.
598                         Handshake(session);
599                         errno = EAGAIN;
600                         return -1;
601                 }
602
603                 int ret = 0;
604
605                 if (session->status == ISSL_HANDSHAKEN)
606                 {
607                         ret = gnutls_record_send(session->sess, sendbuffer, count);
608
609                         if (ret == 0)
610                         {
611                                 CloseSession(session);
612                         }
613                         else if (ret < 0)
614                         {
615                                 if(ret != GNUTLS_E_AGAIN && ret != GNUTLS_E_INTERRUPTED)
616                                 {
617                                         CloseSession(session);
618                                 }
619                                 else
620                                 {
621                                         errno = EAGAIN;
622                                 }
623                         }
624                         else
625                         {
626                                 session->outbuf = session->outbuf.substr(ret);
627                         }
628                 }
629
630                 MakePollWrite(session);
631
632                 /* Who's smart idea was it to return 1 when we havent written anything?
633                  * This fucks the buffer up in BufferedSocket :p
634                  */
635                 return ret < 1 ? 0 : ret;
636         }
637
638         // :kenny.chatspike.net 320 Om Epy|AFK :is a Secure Connection
639         virtual void OnWhois(User* source, User* dest)
640         {
641                 if (!clientactive)
642                         return;
643
644                 // Bugfix, only send this numeric for *our* SSL users
645                 if (dest->GetExt("ssl", dummy) || ((IS_LOCAL(dest) && (dest->GetIOHook() == this))))
646                 {
647                         ServerInstance->SendWhoisLine(source, dest, 320, "%s %s :is using a secure connection", source->nick.c_str(), dest->nick.c_str());
648                 }
649         }
650
651         virtual void OnSyncUserMetaData(User* user, Module* proto, void* opaque, const std::string &extname, bool displayable)
652         {
653                 // check if the linking module wants to know about OUR metadata
654                 if(extname == "ssl")
655                 {
656                         // check if this user has an swhois field to send
657                         if(user->GetExt(extname, dummy))
658                         {
659                                 // call this function in the linking module, let it format the data how it
660                                 // sees fit, and send it on its way. We dont need or want to know how.
661                                 proto->ProtoSendMetaData(opaque, TYPE_USER, user, extname, displayable ? "Enabled" : "ON");
662                         }
663                 }
664         }
665
666         virtual void OnDecodeMetaData(int target_type, void* target, const std::string &extname, const std::string &extdata)
667         {
668                 // check if its our metadata key, and its associated with a user
669                 if ((target_type == TYPE_USER) && (extname == "ssl"))
670                 {
671                         User* dest = (User*)target;
672                         // if they dont already have an ssl flag, accept the remote server's
673                         if (!dest->GetExt(extname, dummy))
674                         {
675                                 dest->Extend(extname, "ON");
676                         }
677                 }
678         }
679
680         bool Handshake(issl_session* session)
681         {
682                 int ret = gnutls_handshake(session->sess);
683
684                 if (ret < 0)
685                 {
686                         if(ret == GNUTLS_E_AGAIN || ret == GNUTLS_E_INTERRUPTED)
687                         {
688                                 // Handshake needs resuming later, read() or write() would have blocked.
689
690                                 if(gnutls_record_get_direction(session->sess) == 0)
691                                 {
692                                         // gnutls_handshake() wants to read() again.
693                                         session->status = ISSL_HANDSHAKING_READ;
694                                 }
695                                 else
696                                 {
697                                         // gnutls_handshake() wants to write() again.
698                                         session->status = ISSL_HANDSHAKING_WRITE;
699                                         MakePollWrite(session);
700                                 }
701                         }
702                         else
703                         {
704                                 // Handshake failed.
705                                 CloseSession(session);
706                                 session->status = ISSL_CLOSING;
707                         }
708
709                         return false;
710                 }
711                 else
712                 {
713                         // Handshake complete.
714                         // This will do for setting the ssl flag...it could be done earlier if it's needed. But this seems neater.
715                         User* extendme = ServerInstance->FindDescriptor(session->fd);
716                         if (extendme)
717                         {
718                                 if (!extendme->GetExt("ssl", dummy))
719                                         extendme->Extend("ssl", "ON");
720                         }
721
722                         // Change the seesion state
723                         session->status = ISSL_HANDSHAKEN;
724
725                         // Finish writing, if any left
726                         MakePollWrite(session);
727
728                         return true;
729                 }
730         }
731
732         virtual void OnPostConnect(User* user)
733         {
734                 // This occurs AFTER OnUserConnect so we can be sure the
735                 // protocol module has propagated the NICK message.
736                 if ((user->GetExt("ssl", dummy)) && (IS_LOCAL(user)))
737                 {
738                         // Tell whatever protocol module we're using that we need to inform other servers of this metadata NOW.
739                         ServerInstance->PI->SendMetaData(user, TYPE_USER, "SSL", "on");
740
741                         VerifyCertificate(&sessions[user->GetFd()],user);
742                         if (sessions[user->GetFd()].sess)
743                         {
744                                 std::string cipher = gnutls_kx_get_name(gnutls_kx_get(sessions[user->GetFd()].sess));
745                                 cipher.append("-").append(gnutls_cipher_get_name(gnutls_cipher_get(sessions[user->GetFd()].sess))).append("-");
746                                 cipher.append(gnutls_mac_get_name(gnutls_mac_get(sessions[user->GetFd()].sess)));
747                                 user->WriteServ("NOTICE %s :*** You are connected using SSL cipher \"%s\"", user->nick.c_str(), cipher.c_str());
748                         }
749                 }
750         }
751
752         void MakePollWrite(issl_session* session)
753         {
754                 //OnRawSocketWrite(session->fd, NULL, 0);
755                 EventHandler* eh = ServerInstance->FindDescriptor(session->fd);
756                 if (eh)
757                         ServerInstance->SE->WantWrite(eh);
758         }
759
760         virtual void OnBufferFlushed(User* user)
761         {
762                 if (user->GetExt("ssl"))
763                 {
764                         issl_session* session = &sessions[user->GetFd()];
765                         if (session && session->outbuf.size())
766                                 OnRawSocketWrite(user->GetFd(), NULL, 0);
767                 }
768         }
769
770         void CloseSession(issl_session* session)
771         {
772                 if(session->sess)
773                 {
774                         gnutls_bye(session->sess, GNUTLS_SHUT_WR);
775                         gnutls_deinit(session->sess);
776                 }
777
778                 if(session->inbuf)
779                 {
780                         delete[] session->inbuf;
781                 }
782
783                 session->outbuf.clear();
784                 session->inbuf = NULL;
785                 session->sess = NULL;
786                 session->status = ISSL_NONE;
787         }
788
789         void VerifyCertificate(issl_session* session, Extensible* user)
790         {
791                 if (!session->sess || !user)
792                         return;
793
794                 unsigned int status;
795                 const gnutls_datum_t* cert_list;
796                 int ret;
797                 unsigned int cert_list_size;
798                 gnutls_x509_crt_t cert;
799                 char name[MAXBUF];
800                 unsigned char digest[MAXBUF];
801                 size_t digest_size = sizeof(digest);
802                 size_t name_size = sizeof(name);
803                 ssl_cert* certinfo = new ssl_cert;
804
805                 user->Extend("ssl_cert",certinfo);
806
807                 /* This verification function uses the trusted CAs in the credentials
808                  * structure. So you must have installed one or more CA certificates.
809                  */
810                 ret = gnutls_certificate_verify_peers2(session->sess, &status);
811
812                 if (ret < 0)
813                 {
814                         certinfo->data.insert(std::make_pair("error",std::string(gnutls_strerror(ret))));
815                         return;
816                 }
817
818                 if (status & GNUTLS_CERT_INVALID)
819                 {
820                         certinfo->data.insert(std::make_pair("invalid",ConvToStr(1)));
821                 }
822                 else
823                 {
824                         certinfo->data.insert(std::make_pair("invalid",ConvToStr(0)));
825                 }
826                 if (status & GNUTLS_CERT_SIGNER_NOT_FOUND)
827                 {
828                         certinfo->data.insert(std::make_pair("unknownsigner",ConvToStr(1)));
829                 }
830                 else
831                 {
832                         certinfo->data.insert(std::make_pair("unknownsigner",ConvToStr(0)));
833                 }
834                 if (status & GNUTLS_CERT_REVOKED)
835                 {
836                         certinfo->data.insert(std::make_pair("revoked",ConvToStr(1)));
837                 }
838                 else
839                 {
840                         certinfo->data.insert(std::make_pair("revoked",ConvToStr(0)));
841                 }
842                 if (status & GNUTLS_CERT_SIGNER_NOT_CA)
843                 {
844                         certinfo->data.insert(std::make_pair("trusted",ConvToStr(0)));
845                 }
846                 else
847                 {
848                         certinfo->data.insert(std::make_pair("trusted",ConvToStr(1)));
849                 }
850
851                 /* Up to here the process is the same for X.509 certificates and
852                  * OpenPGP keys. From now on X.509 certificates are assumed. This can
853                  * be easily extended to work with openpgp keys as well.
854                  */
855                 if (gnutls_certificate_type_get(session->sess) != GNUTLS_CRT_X509)
856                 {
857                         certinfo->data.insert(std::make_pair("error","No X509 keys sent"));
858                         return;
859                 }
860
861                 ret = gnutls_x509_crt_init(&cert);
862                 if (ret < 0)
863                 {
864                         certinfo->data.insert(std::make_pair("error",gnutls_strerror(ret)));
865                         return;
866                 }
867
868                 cert_list_size = 0;
869                 cert_list = gnutls_certificate_get_peers(session->sess, &cert_list_size);
870                 if (cert_list == NULL)
871                 {
872                         certinfo->data.insert(std::make_pair("error","No certificate was found"));
873                         return;
874                 }
875
876                 /* This is not a real world example, since we only check the first
877                  * certificate in the given chain.
878                  */
879
880                 ret = gnutls_x509_crt_import(cert, &cert_list[0], GNUTLS_X509_FMT_DER);
881                 if (ret < 0)
882                 {
883                         certinfo->data.insert(std::make_pair("error",gnutls_strerror(ret)));
884                         return;
885                 }
886
887                 gnutls_x509_crt_get_dn(cert, name, &name_size);
888
889                 certinfo->data.insert(std::make_pair("dn",name));
890
891                 gnutls_x509_crt_get_issuer_dn(cert, name, &name_size);
892
893                 certinfo->data.insert(std::make_pair("issuer",name));
894
895                 if ((ret = gnutls_x509_crt_get_fingerprint(cert, GNUTLS_DIG_MD5, digest, &digest_size)) < 0)
896                 {
897                         certinfo->data.insert(std::make_pair("error",gnutls_strerror(ret)));
898                 }
899                 else
900                 {
901                         certinfo->data.insert(std::make_pair("fingerprint",irc::hex(digest, digest_size)));
902                 }
903
904                 /* Beware here we do not check for errors.
905                  */
906                 if ((gnutls_x509_crt_get_expiration_time(cert) < time(0)) || (gnutls_x509_crt_get_activation_time(cert) > time(0)))
907                 {
908                         certinfo->data.insert(std::make_pair("error","Not activated, or expired certificate"));
909                 }
910
911                 gnutls_x509_crt_deinit(cert);
912
913                 return;
914         }
915
        /** Forwards every event to GenericCapHandler (from m_cap.h) with "tls" as both
         * arguments -- presumably so this module takes part in CAP negotiation for a
         * "tls" capability; NOTE(review): confirm against m_cap.h.
         */
        void OnEvent(Event* ev)
        {
                GenericCapHandler(ev, "tls", "tls");
        }
920
921         void Prioritize()
922         {
923                 Module* server = ServerInstance->Modules->Find("m_spanningtree.so");
924                 ServerInstance->Modules->SetPriority(this, I_OnPostConnect, PRIO_AFTER, &server);
925         }
926 };
927
928 MODULE_INIT(ModuleSSLGnuTLS)