git.netwichtig.de Git - user/henk/code/inspircd.git/blob - src/modules/extra/m_ssl_gnutls.cpp
Fix IO hooking modules to use the new (not old) hooking call
[user/henk/code/inspircd.git] / src / modules / extra / m_ssl_gnutls.cpp
1 /*       +------------------------------------+
2  *       | Inspire Internet Relay Chat Daemon |
3  *       +------------------------------------+
4  *
5  *  InspIRCd: (C) 2002-2008 InspIRCd Development Team
6  * See: http://www.inspircd.org/wiki/index.php/Credits
7  *
8  * This program is free but copyrighted software; see
9  *          the file COPYING for details.
10  *
11  * ---------------------------------------------------
12  */
13
14 #include "inspircd.h"
15
16 #include <gnutls/gnutls.h>
17 #include <gnutls/x509.h>
18
19 #include "inspircd_config.h"
20 #include "configreader.h"
21 #include "users.h"
22 #include "channels.h"
23 #include "modules.h"
24 #include "socket.h"
25 #include "hashcomp.h"
26 #include "transport.h"
27 #include "m_cap.h"
28
29 #ifdef WINDOWS
30 #pragma comment(lib, "libgnutls-13.lib")
31 #endif
32
33 /* $ModDesc: Provides SSL support for clients */
34 /* $CompileFlags: exec("libgnutls-config --cflags") */
35 /* $LinkerFlags: rpath("libgnutls-config --libs") exec("libgnutls-config --libs") */
36 /* $ModDep: transport.h */
37 /* $CopyInstall: conf/key.pem $(CONPATH) */
38 /* $CopyInstall: conf/cert.pem $(CONPATH) */
39
/** Where a per-descriptor GnuTLS session is in its lifecycle. */
enum issl_status
{
	ISSL_NONE,              // initial state
	ISSL_HANDSHAKING_READ,  // handshake paused until the socket is readable again
	ISSL_HANDSHAKING_WRITE, // handshake paused until the socket is writable again
	ISSL_HANDSHAKEN,        // handshake complete; records are exchanged with gnutls_record_*
	ISSL_CLOSING,           // set after a failed handshake
	ISSL_CLOSED
};
41
42 bool isin(const std::string &host, int port, const std::vector<std::string> &portlist)
43 {
44         if (std::find(portlist.begin(), portlist.end(), "*:" + ConvToStr(port)) != portlist.end())
45                 return true;
46
47         if (std::find(portlist.begin(), portlist.end(), ":" + ConvToStr(port)) != portlist.end())
48                 return true;
49
50         return std::find(portlist.begin(), portlist.end(), host + ":" + ConvToStr(port)) != portlist.end();
51 }
52
53 /** Represents an SSL user's extra data
54  */
55 class issl_session : public classbase
56 {
57 public:
58         gnutls_session_t sess;
59         issl_status status;
60         std::string outbuf;
61         int inbufoffset;
62         char* inbuf;
63         int fd;
64 };
65
66 class CommandStartTLS : public Command
67 {
68         Module* Caller;
69  public:
70         /* Command 'dalinfo', takes no parameters and needs no special modes */
71         CommandStartTLS (InspIRCd* Instance, Module* mod) : Command(Instance,"STARTTLS", 0, 0, true), Caller(mod)
72         {
73                 this->source = "m_ssl_gnutls.so";
74         }
75
76         CmdResult Handle (const std::vector<std::string> &parameters, User *user)
77         {
78                 if (user->registered == REG_ALL)
79                 {
80                         ServerInstance->Users->QuitUser(user, "STARTTLS not allowed after client registration");
81                 }
82                 else
83                 {
84                         if (!user->io)
85                         {
86                                 user->WriteNumeric(670, "%s :STARTTLS successful, go ahead with TLS handshake", user->nick.c_str());
87                                 user->io = Caller;
88                                 Caller->OnRawSocketAccept(user->GetFd(), user->GetIPString(), user->GetPort());
89                         }
90                         else
91                                 user->WriteNumeric(671, "%s :STARTTLS failure", user->nick.c_str());
92                 }
93
94                 return CMD_FAILURE;
95         }
96 };
97
98 class ModuleSSLGnuTLS : public Module
99 {
100
101         ConfigReader* Conf;
102
103         char* dummy;
104
105         std::vector<std::string> listenports;
106
107         int inbufsize;
108         issl_session* sessions;
109
110         gnutls_certificate_credentials x509_cred;
111         gnutls_dh_params dh_params;
112
113         std::string keyfile;
114         std::string certfile;
115         std::string cafile;
116         std::string crlfile;
117         std::string sslports;
118         int dh_bits;
119
120         int clientactive;
121         bool cred_alloc;
122
123         CommandStartTLS* starttls;
124
125  public:
126
127         ModuleSSLGnuTLS(InspIRCd* Me)
128                 : Module(Me)
129         {
130                 ServerInstance->Modules->PublishInterface("BufferedSocketHook", this);
131
132                 sessions = new issl_session[ServerInstance->SE->GetMaxFds()];
133
134                 // Not rehashable...because I cba to reduce all the sizes of existing buffers.
135                 inbufsize = ServerInstance->Config->NetBufferSize;
136
137                 gnutls_global_init(); // This must be called once in the program
138
139                 cred_alloc = false;
140                 // Needs the flag as it ignores a plain /rehash
141                 OnRehash(NULL,"ssl");
142
143                 // Void return, guess we assume success
144                 gnutls_certificate_set_dh_params(x509_cred, dh_params);
145                 Implementation eventlist[] = { I_On005Numeric, I_OnRawSocketConnect, I_OnRawSocketAccept, I_OnRawSocketClose, I_OnRawSocketRead, I_OnRawSocketWrite, I_OnCleanup,
146                         I_OnBufferFlushed, I_OnRequest, I_OnSyncUserMetaData, I_OnDecodeMetaData, I_OnUnloadModule, I_OnRehash, I_OnWhois, I_OnPostConnect, I_OnEvent, I_OnHookUserIO };
147                 ServerInstance->Modules->Attach(eventlist, this, 17);
148
149                 starttls = new CommandStartTLS(ServerInstance, this);
150                 ServerInstance->AddCommand(starttls);
151         }
152
153         virtual void OnRehash(User* user, const std::string &param)
154         {
155                 Conf = new ConfigReader(ServerInstance);
156
157                 listenports.clear();
158                 clientactive = 0;
159                 sslports.clear();
160
161                 for(int index = 0; index < Conf->Enumerate("bind"); index++)
162                 {
163                         // For each <bind> tag
164                         std::string x = Conf->ReadValue("bind", "type", index);
165                         if(((x.empty()) || (x == "clients")) && (Conf->ReadValue("bind", "ssl", index) == "gnutls"))
166                         {
167                                 // Get the port we're meant to be listening on with SSL
168                                 std::string port = Conf->ReadValue("bind", "port", index);
169                                 std::string addr = Conf->ReadValue("bind", "address", index);
170
171                                 irc::portparser portrange(port, false);
172                                 long portno = -1;
173                                 while ((portno = portrange.GetToken()))
174                                 {
175                                         clientactive++;
176                                         try
177                                         {
178                                                 listenports.push_back(addr + ":" + ConvToStr(portno));
179
180                                                 for (size_t i = 0; i < ServerInstance->Config->ports.size(); i++)
181                                                         if ((ServerInstance->Config->ports[i]->GetPort() == portno) && (ServerInstance->Config->ports[i]->GetIP() == addr))
182                                                                 ServerInstance->Config->ports[i]->SetDescription("ssl");
183                                                 ServerInstance->Logs->Log("m_ssl_gnutls",DEFAULT, "m_ssl_gnutls.so: Enabling SSL for port %ld", portno);
184
185                                                 sslports.append((addr.empty() ? "*" : addr)).append(":").append(ConvToStr(portno)).append(";");
186                                         }
187                                         catch (ModuleException &e)
188                                         {
189                                                 ServerInstance->Logs->Log("m_ssl_gnutls",DEFAULT, "m_ssl_gnutls.so: FAILED to enable SSL on port %ld: %s. Maybe it's already hooked by the same port on a different IP, or you have an other SSL or similar module loaded?", portno, e.GetReason());
190                                         }
191                                 }
192                         }
193                 }
194
195                 if (!sslports.empty())
196                         sslports.erase(sslports.end() - 1);
197
198                 if(param != "ssl")
199                 {
200                         delete Conf;
201                         return;
202                 }
203
204                 std::string confdir(ServerInstance->ConfigFileName);
205                 // +1 so we the path ends with a /
206                 confdir = confdir.substr(0, confdir.find_last_of('/') + 1);
207
208                 cafile  = Conf->ReadValue("gnutls", "cafile", 0);
209                 crlfile = Conf->ReadValue("gnutls", "crlfile", 0);
210                 certfile        = Conf->ReadValue("gnutls", "certfile", 0);
211                 keyfile = Conf->ReadValue("gnutls", "keyfile", 0);
212                 dh_bits = Conf->ReadInteger("gnutls", "dhbits", 0, false);
213
214                 // Set all the default values needed.
215                 if (cafile.empty())
216                         cafile = "ca.pem";
217
218                 if (crlfile.empty())
219                         crlfile = "crl.pem";
220
221                 if (certfile.empty())
222                         certfile = "cert.pem";
223
224                 if (keyfile.empty())
225                         keyfile = "key.pem";
226
227                 if((dh_bits != 768) && (dh_bits != 1024) && (dh_bits != 2048) && (dh_bits != 3072) && (dh_bits != 4096))
228                         dh_bits = 1024;
229
230                 // Prepend relative paths with the path to the config directory.
231                 if ((cafile[0] != '/') && (!ServerInstance->Config->StartsWithWindowsDriveLetter(cafile)))
232                         cafile = confdir + cafile;
233
234                 if ((crlfile[0] != '/') && (!ServerInstance->Config->StartsWithWindowsDriveLetter(crlfile)))
235                         crlfile = confdir + crlfile;
236
237                 if ((certfile[0] != '/') && (!ServerInstance->Config->StartsWithWindowsDriveLetter(certfile)))
238                         certfile = confdir + certfile;
239
240                 if ((keyfile[0] != '/') && (!ServerInstance->Config->StartsWithWindowsDriveLetter(keyfile)))
241                         keyfile = confdir + keyfile;
242
243                 int ret;
244                 
245                 if (cred_alloc)
246                 {
247                         // Deallocate the old credentials
248                         gnutls_dh_params_deinit(dh_params);
249                         gnutls_certificate_free_credentials(x509_cred);
250                 }
251                 else
252                         cred_alloc = true;
253                 
254                 if((ret = gnutls_certificate_allocate_credentials(&x509_cred)) < 0)
255                         ServerInstance->Logs->Log("m_ssl_gnutls",DEFAULT, "m_ssl_gnutls.so: Failed to allocate certificate credentials: %s", gnutls_strerror(ret));
256                 
257                 if((ret = gnutls_dh_params_init(&dh_params)) < 0)
258                         ServerInstance->Logs->Log("m_ssl_gnutls",DEFAULT, "m_ssl_gnutls.so: Failed to initialise DH parameters: %s", gnutls_strerror(ret));
259                 
260                 if((ret =gnutls_certificate_set_x509_trust_file(x509_cred, cafile.c_str(), GNUTLS_X509_FMT_PEM)) < 0)
261                         ServerInstance->Logs->Log("m_ssl_gnutls",DEFAULT, "m_ssl_gnutls.so: Failed to set X.509 trust file '%s': %s", cafile.c_str(), gnutls_strerror(ret));
262
263                 if((ret = gnutls_certificate_set_x509_crl_file (x509_cred, crlfile.c_str(), GNUTLS_X509_FMT_PEM)) < 0)
264                         ServerInstance->Logs->Log("m_ssl_gnutls",DEFAULT, "m_ssl_gnutls.so: Failed to set X.509 CRL file '%s': %s", crlfile.c_str(), gnutls_strerror(ret));
265
266                 if((ret = gnutls_certificate_set_x509_key_file (x509_cred, certfile.c_str(), keyfile.c_str(), GNUTLS_X509_FMT_PEM)) < 0)
267                 {
268                         // If this fails, no SSL port will work. At all. So, do the smart thing - throw a ModuleException
269                         throw ModuleException("Unable to load GnuTLS server certificate: " + std::string(gnutls_strerror(ret)));
270                 }
271
272                 // This may be on a large (once a day or week) timer eventually.
273                 GenerateDHParams();
274
275                 delete Conf;
276         }
277
278         void GenerateDHParams()
279         {
280                 // Generate Diffie Hellman parameters - for use with DHE
281                 // kx algorithms. These should be discarded and regenerated
282                 // once a day, once a week or once a month. Depending on the
283                 // security requirements.
284
285                 int ret;
286
287                 if((ret = gnutls_dh_params_generate2(dh_params, dh_bits)) < 0)
288                         ServerInstance->Logs->Log("m_ssl_gnutls",DEFAULT, "m_ssl_gnutls.so: Failed to generate DH parameters (%d bits): %s", dh_bits, gnutls_strerror(ret));
289         }
290
291         virtual ~ModuleSSLGnuTLS()
292         {
293                 gnutls_dh_params_deinit(dh_params);
294                 gnutls_certificate_free_credentials(x509_cred);
295                 gnutls_global_deinit();
296                 ServerInstance->Modules->UnpublishInterface("BufferedSocketHook", this);
297                 delete[] sessions;
298         }
299
300         virtual void OnCleanup(int target_type, void* item)
301         {
302                 if(target_type == TYPE_USER)
303                 {
304                         User* user = (User*)item;
305
306                         if (user->io == this)
307                         {
308                                 // User is using SSL, they're a local user, and they're using one of *our* SSL ports.
309                                 // Potentially there could be multiple SSL modules loaded at once on different ports.
310                                 ServerInstance->Users->QuitUser(user, "SSL module unloading");
311                                 user->io = NULL;
312                         }
313                         if (user->GetExt("ssl_cert", dummy))
314                         {
315                                 ssl_cert* tofree;
316                                 user->GetExt("ssl_cert", tofree);
317                                 delete tofree;
318                                 user->Shrink("ssl_cert");
319                         }
320                 }
321         }
322
323         virtual void OnUnloadModule(Module* mod, const std::string &name)
324         {
325                 if(mod == this)
326                 {
327                         for(unsigned int i = 0; i < listenports.size(); i++)
328                         {
329                                 for (size_t j = 0; j < ServerInstance->Config->ports.size(); j++)
330                                         if (listenports[i] == (ServerInstance->Config->ports[j]->GetIP()+":"+ConvToStr(ServerInstance->Config->ports[j]->GetPort())))
331                                                 ServerInstance->Config->ports[j]->SetDescription("plaintext");
332                         }
333                 }
334         }
335
336         virtual Version GetVersion()
337         {
338                 return Version("$Id$", VF_VENDOR, API_VERSION);
339         }
340
341
342         virtual void On005Numeric(std::string &output)
343         {
344                 output.append(" SSL=" + sslports);
345         }
346
347         virtual void OnHookUserIO(User* user, const std::string &targetip)
348         {
349                 if (!user->io && isin(targetip,user->GetPort(),listenports))
350                 {
351                         /* Hook the user with our module */
352                         user->io = this;
353                 }
354         }
355
356         virtual const char* OnRequest(Request* request)
357         {
358                 ISHRequest* ISR = (ISHRequest*)request;
359                 if (strcmp("IS_NAME", request->GetId()) == 0)
360                 {
361                         return "gnutls";
362                 }
363                 else if (strcmp("IS_HOOK", request->GetId()) == 0)
364                 {
365                         const char* ret = "OK";
366                         try
367                         {
368                                 ret = ISR->Sock->AddIOHook((Module*)this) ? "OK" : NULL;
369                         }
370                         catch (ModuleException &e)
371                         {
372                                 return NULL;
373                         }
374                         return ret;
375                 }
376                 else if (strcmp("IS_UNHOOK", request->GetId()) == 0)
377                 {
378                         return ISR->Sock->DelIOHook() ? "OK" : NULL;
379                 }
380                 else if (strcmp("IS_HSDONE", request->GetId()) == 0)
381                 {
382                         if (ISR->Sock->GetFd() < 0)
383                                 return "OK";
384
385                         issl_session* session = &sessions[ISR->Sock->GetFd()];
386                         return (session->status == ISSL_HANDSHAKING_READ || session->status == ISSL_HANDSHAKING_WRITE) ? NULL : "OK";
387                 }
388                 else if (strcmp("IS_ATTACH", request->GetId()) == 0)
389                 {
390                         if (ISR->Sock->GetFd() > -1)
391                         {
392                                 issl_session* session = &sessions[ISR->Sock->GetFd()];
393                                 if (session->sess)
394                                 {
395                                         if ((Extensible*)ServerInstance->FindDescriptor(ISR->Sock->GetFd()) == (Extensible*)(ISR->Sock))
396                                         {
397                                                 VerifyCertificate(session, (BufferedSocket*)ISR->Sock);
398                                                 return "OK";
399                                         }
400                                 }
401                         }
402                 }
403                 return NULL;
404         }
405
406
407         virtual void OnRawSocketAccept(int fd, const std::string &ip, int localport)
408         {
409                 /* Are there any possibilities of an out of range fd? Hope not, but lets be paranoid */
410                 if ((fd < 0) || (fd > ServerInstance->SE->GetMaxFds() - 1))
411                         return;
412
413                 issl_session* session = &sessions[fd];
414
415                 /* For STARTTLS: Don't try and init a session on a socket that already has a session */
416                 if (session->sess)
417                         return;
418
419                 session->fd = fd;
420                 session->inbuf = new char[inbufsize];
421                 session->inbufoffset = 0;
422
423                 gnutls_init(&session->sess, GNUTLS_SERVER);
424
425                 gnutls_set_default_priority(session->sess); // Avoid calling all the priority functions, defaults are adequate.
426                 gnutls_credentials_set(session->sess, GNUTLS_CRD_CERTIFICATE, x509_cred);
427                 gnutls_dh_set_prime_bits(session->sess, dh_bits);
428
429                 /* This is an experimental change to avoid a warning on 64bit systems about casting between integer and pointer of different sizes
430                  * This needs testing, but it's easy enough to rollback if need be
431                  * Old: gnutls_transport_set_ptr(session->sess, (gnutls_transport_ptr_t) fd); // Give gnutls the fd for the socket.
432                  * New: gnutls_transport_set_ptr(session->sess, &fd); // Give gnutls the fd for the socket.
433                  *
434                  * With testing this seems to...not work :/
435                  */
436
437                 gnutls_transport_set_ptr(session->sess, (gnutls_transport_ptr_t) fd); // Give gnutls the fd for the socket.
438
439                 gnutls_certificate_server_set_request(session->sess, GNUTLS_CERT_REQUEST); // Request client certificate if any.
440
441                 Handshake(session);
442         }
443
444         virtual void OnRawSocketConnect(int fd)
445         {
446                 /* Are there any possibilities of an out of range fd? Hope not, but lets be paranoid */
447                 if ((fd < 0) || (fd > ServerInstance->SE->GetMaxFds() - 1))
448                         return;
449
450                 issl_session* session = &sessions[fd];
451
452                 session->fd = fd;
453                 session->inbuf = new char[inbufsize];
454                 session->inbufoffset = 0;
455
456                 gnutls_init(&session->sess, GNUTLS_CLIENT);
457
458                 gnutls_set_default_priority(session->sess); // Avoid calling all the priority functions, defaults are adequate.
459                 gnutls_credentials_set(session->sess, GNUTLS_CRD_CERTIFICATE, x509_cred);
460                 gnutls_dh_set_prime_bits(session->sess, dh_bits);
461                 gnutls_transport_set_ptr(session->sess, (gnutls_transport_ptr_t) fd); // Give gnutls the fd for the socket.
462
463                 Handshake(session);
464         }
465
466         virtual void OnRawSocketClose(int fd)
467         {
468                 /* Are there any possibilities of an out of range fd? Hope not, but lets be paranoid */
469                 if ((fd < 0) || (fd > ServerInstance->SE->GetMaxFds()))
470                         return;
471
472                 CloseSession(&sessions[fd]);
473
474                 EventHandler* user = ServerInstance->SE->GetRef(fd);
475
476                 if ((user) && (user->GetExt("ssl_cert", dummy)))
477                 {
478                         ssl_cert* tofree;
479                         user->GetExt("ssl_cert", tofree);
480                         delete tofree;
481                         user->Shrink("ssl_cert");
482                 }
483         }
484
485         virtual int OnRawSocketRead(int fd, char* buffer, unsigned int count, int &readresult)
486         {
487                 /* Are there any possibilities of an out of range fd? Hope not, but lets be paranoid */
488                 if ((fd < 0) || (fd > ServerInstance->SE->GetMaxFds() - 1))
489                         return 0;
490
491                 issl_session* session = &sessions[fd];
492
493                 if (!session->sess)
494                 {
495                         readresult = 0;
496                         CloseSession(session);
497                         return 1;
498                 }
499
500                 if (session->status == ISSL_HANDSHAKING_READ)
501                 {
502                         // The handshake isn't finished, try to finish it.
503
504                         if(!Handshake(session))
505                         {
506                                 // Couldn't resume handshake.
507                                 return -1;
508                         }
509                 }
510                 else if (session->status == ISSL_HANDSHAKING_WRITE)
511                 {
512                         errno = EAGAIN;
513                         MakePollWrite(session);
514                         return -1;
515                 }
516
517                 // If we resumed the handshake then session->status will be ISSL_HANDSHAKEN.
518
519                 if (session->status == ISSL_HANDSHAKEN)
520                 {
521                         // Is this right? Not sure if the unencrypted data is garaunteed to be the same length.
522                         // Read into the inbuffer, offset from the beginning by the amount of data we have that insp hasn't taken yet.
523                         int ret = gnutls_record_recv(session->sess, session->inbuf + session->inbufoffset, inbufsize - session->inbufoffset);
524
525                         if (ret == 0)
526                         {
527                                 // Client closed connection.
528                                 readresult = 0;
529                                 CloseSession(session);
530                                 return 1;
531                         }
532                         else if (ret < 0)
533                         {
534                                 if (ret == GNUTLS_E_AGAIN || ret == GNUTLS_E_INTERRUPTED)
535                                 {
536                                         errno = EAGAIN;
537                                         return -1;
538                                 }
539                                 else
540                                 {
541                                         readresult = 0;
542                                         CloseSession(session);
543                                 }
544                         }
545                         else
546                         {
547                                 // Read successfully 'ret' bytes into inbuf + inbufoffset
548                                 // There are 'ret' + 'inbufoffset' bytes of data in 'inbuf'
549                                 // 'buffer' is 'count' long
550
551                                 unsigned int length = ret + session->inbufoffset;
552
553                                 if(count <= length)
554                                 {
555                                         memcpy(buffer, session->inbuf, count);
556                                         // Move the stuff left in inbuf to the beginning of it
557                                         memmove(session->inbuf, session->inbuf + count, (length - count));
558                                         // Now we need to set session->inbufoffset to the amount of data still waiting to be handed to insp.
559                                         session->inbufoffset = length - count;
560                                         // Insp uses readresult as the count of how much data there is in buffer, so:
561                                         readresult = count;
562                                 }
563                                 else
564                                 {
565                                         // There's not as much in the inbuf as there is space in the buffer, so just copy the whole thing.
566                                         memcpy(buffer, session->inbuf, length);
567                                         // Zero the offset, as there's nothing there..
568                                         session->inbufoffset = 0;
569                                         // As above
570                                         readresult = length;
571                                 }
572                         }
573                 }
574                 else if(session->status == ISSL_CLOSING)
575                         readresult = 0;
576
577                 return 1;
578         }
579
580         virtual int OnRawSocketWrite(int fd, const char* buffer, int count)
581         {
582                 /* Are there any possibilities of an out of range fd? Hope not, but lets be paranoid */
583                 if ((fd < 0) || (fd > ServerInstance->SE->GetMaxFds() - 1))
584                         return 0;
585
586                 issl_session* session = &sessions[fd];
587                 const char* sendbuffer = buffer;
588
589                 if (!session->sess)
590                 {
591                         CloseSession(session);
592                         return 1;
593                 }
594
595                 session->outbuf.append(sendbuffer, count);
596                 sendbuffer = session->outbuf.c_str();
597                 count = session->outbuf.size();
598
599                 if (session->status == ISSL_HANDSHAKING_WRITE)
600                 {
601                         // The handshake isn't finished, try to finish it.
602                         Handshake(session);
603                         errno = EAGAIN;
604                         return -1;
605                 }
606
607                 int ret = 0;
608
609                 if (session->status == ISSL_HANDSHAKEN)
610                 {
611                         ret = gnutls_record_send(session->sess, sendbuffer, count);
612
613                         if (ret == 0)
614                         {
615                                 CloseSession(session);
616                         }
617                         else if (ret < 0)
618                         {
619                                 if(ret != GNUTLS_E_AGAIN && ret != GNUTLS_E_INTERRUPTED)
620                                 {
621                                         CloseSession(session);
622                                 }
623                                 else
624                                 {
625                                         errno = EAGAIN;
626                                 }
627                         }
628                         else
629                         {
630                                 session->outbuf = session->outbuf.substr(ret);
631                         }
632                 }
633
634                 MakePollWrite(session);
635
636                 /* Who's smart idea was it to return 1 when we havent written anything?
637                  * This fucks the buffer up in BufferedSocket :p
638                  */
639                 return ret < 1 ? 0 : ret;
640         }
641
642         // :kenny.chatspike.net 320 Om Epy|AFK :is a Secure Connection
643         virtual void OnWhois(User* source, User* dest)
644         {
645                 if (!clientactive)
646                         return;
647
648                 // Bugfix, only send this numeric for *our* SSL users
649                 if (dest->GetExt("ssl", dummy) || ((IS_LOCAL(dest) && (dest->io == this))))
650                 {
651                         ServerInstance->SendWhoisLine(source, dest, 320, "%s %s :is using a secure connection", source->nick.c_str(), dest->nick.c_str());
652                 }
653         }
654
655         virtual void OnSyncUserMetaData(User* user, Module* proto, void* opaque, const std::string &extname, bool displayable)
656         {
657                 // check if the linking module wants to know about OUR metadata
658                 if(extname == "ssl")
659                 {
660                         // check if this user has an swhois field to send
661                         if(user->GetExt(extname, dummy))
662                         {
663                                 // call this function in the linking module, let it format the data how it
664                                 // sees fit, and send it on its way. We dont need or want to know how.
665                                 proto->ProtoSendMetaData(opaque, TYPE_USER, user, extname, displayable ? "Enabled" : "ON");
666                         }
667                 }
668         }
669
670         virtual void OnDecodeMetaData(int target_type, void* target, const std::string &extname, const std::string &extdata)
671         {
672                 // check if its our metadata key, and its associated with a user
673                 if ((target_type == TYPE_USER) && (extname == "ssl"))
674                 {
675                         User* dest = (User*)target;
676                         // if they dont already have an ssl flag, accept the remote server's
677                         if (!dest->GetExt(extname, dummy))
678                         {
679                                 dest->Extend(extname, "ON");
680                         }
681                 }
682         }
683
684         bool Handshake(issl_session* session)
685         {
686                 int ret = gnutls_handshake(session->sess);
687
688                 if (ret < 0)
689                 {
690                         if(ret == GNUTLS_E_AGAIN || ret == GNUTLS_E_INTERRUPTED)
691                         {
692                                 // Handshake needs resuming later, read() or write() would have blocked.
693
694                                 if(gnutls_record_get_direction(session->sess) == 0)
695                                 {
696                                         // gnutls_handshake() wants to read() again.
697                                         session->status = ISSL_HANDSHAKING_READ;
698                                 }
699                                 else
700                                 {
701                                         // gnutls_handshake() wants to write() again.
702                                         session->status = ISSL_HANDSHAKING_WRITE;
703                                         MakePollWrite(session);
704                                 }
705                         }
706                         else
707                         {
708                                 // Handshake failed.
709                                 CloseSession(session);
710                                 session->status = ISSL_CLOSING;
711                         }
712
713                         return false;
714                 }
715                 else
716                 {
717                         // Handshake complete.
718                         // This will do for setting the ssl flag...it could be done earlier if it's needed. But this seems neater.
719                         User* extendme = ServerInstance->FindDescriptor(session->fd);
720                         if (extendme)
721                         {
722                                 if (!extendme->GetExt("ssl", dummy))
723                                         extendme->Extend("ssl", "ON");
724                         }
725
726                         // Change the seesion state
727                         session->status = ISSL_HANDSHAKEN;
728
729                         // Finish writing, if any left
730                         MakePollWrite(session);
731
732                         return true;
733                 }
734         }
735
736         virtual void OnPostConnect(User* user)
737         {
738                 // This occurs AFTER OnUserConnect so we can be sure the
739                 // protocol module has propagated the NICK message.
740                 if ((user->GetExt("ssl", dummy)) && (IS_LOCAL(user)))
741                 {
742                         // Tell whatever protocol module we're using that we need to inform other servers of this metadata NOW.
743                         ServerInstance->PI->SendMetaData(user, TYPE_USER, "SSL", "on");
744
745                         VerifyCertificate(&sessions[user->GetFd()],user);
746                         if (sessions[user->GetFd()].sess)
747                         {
748                                 std::string cipher = gnutls_kx_get_name(gnutls_kx_get(sessions[user->GetFd()].sess));
749                                 cipher.append("-").append(gnutls_cipher_get_name(gnutls_cipher_get(sessions[user->GetFd()].sess))).append("-");
750                                 cipher.append(gnutls_mac_get_name(gnutls_mac_get(sessions[user->GetFd()].sess)));
751                                 user->WriteServ("NOTICE %s :*** You are connected using SSL cipher \"%s\"", user->nick.c_str(), cipher.c_str());
752                         }
753                 }
754         }
755
756         void MakePollWrite(issl_session* session)
757         {
758                 //OnRawSocketWrite(session->fd, NULL, 0);
759                 EventHandler* eh = ServerInstance->FindDescriptor(session->fd);
760                 if (eh)
761                         ServerInstance->SE->WantWrite(eh);
762         }
763
764         virtual void OnBufferFlushed(User* user)
765         {
766                 if (user->GetExt("ssl"))
767                 {
768                         issl_session* session = &sessions[user->GetFd()];
769                         if (session && session->outbuf.size())
770                                 OnRawSocketWrite(user->GetFd(), NULL, 0);
771                 }
772         }
773
774         void CloseSession(issl_session* session)
775         {
776                 if(session->sess)
777                 {
778                         gnutls_bye(session->sess, GNUTLS_SHUT_WR);
779                         gnutls_deinit(session->sess);
780                 }
781
782                 if(session->inbuf)
783                 {
784                         delete[] session->inbuf;
785                 }
786
787                 session->outbuf.clear();
788                 session->inbuf = NULL;
789                 session->sess = NULL;
790                 session->status = ISSL_NONE;
791         }
792
793         void VerifyCertificate(issl_session* session, Extensible* user)
794         {
795                 if (!session->sess || !user)
796                         return;
797
798                 unsigned int status;
799                 const gnutls_datum_t* cert_list;
800                 int ret;
801                 unsigned int cert_list_size;
802                 gnutls_x509_crt_t cert;
803                 char name[MAXBUF];
804                 unsigned char digest[MAXBUF];
805                 size_t digest_size = sizeof(digest);
806                 size_t name_size = sizeof(name);
807                 ssl_cert* certinfo = new ssl_cert;
808
809                 user->Extend("ssl_cert",certinfo);
810
811                 /* This verification function uses the trusted CAs in the credentials
812                  * structure. So you must have installed one or more CA certificates.
813                  */
814                 ret = gnutls_certificate_verify_peers2(session->sess, &status);
815
816                 if (ret < 0)
817                 {
818                         certinfo->data.insert(std::make_pair("error",std::string(gnutls_strerror(ret))));
819                         return;
820                 }
821
822                 if (status & GNUTLS_CERT_INVALID)
823                 {
824                         certinfo->data.insert(std::make_pair("invalid",ConvToStr(1)));
825                 }
826                 else
827                 {
828                         certinfo->data.insert(std::make_pair("invalid",ConvToStr(0)));
829                 }
830                 if (status & GNUTLS_CERT_SIGNER_NOT_FOUND)
831                 {
832                         certinfo->data.insert(std::make_pair("unknownsigner",ConvToStr(1)));
833                 }
834                 else
835                 {
836                         certinfo->data.insert(std::make_pair("unknownsigner",ConvToStr(0)));
837                 }
838                 if (status & GNUTLS_CERT_REVOKED)
839                 {
840                         certinfo->data.insert(std::make_pair("revoked",ConvToStr(1)));
841                 }
842                 else
843                 {
844                         certinfo->data.insert(std::make_pair("revoked",ConvToStr(0)));
845                 }
846                 if (status & GNUTLS_CERT_SIGNER_NOT_CA)
847                 {
848                         certinfo->data.insert(std::make_pair("trusted",ConvToStr(0)));
849                 }
850                 else
851                 {
852                         certinfo->data.insert(std::make_pair("trusted",ConvToStr(1)));
853                 }
854
855                 /* Up to here the process is the same for X.509 certificates and
856                  * OpenPGP keys. From now on X.509 certificates are assumed. This can
857                  * be easily extended to work with openpgp keys as well.
858                  */
859                 if (gnutls_certificate_type_get(session->sess) != GNUTLS_CRT_X509)
860                 {
861                         certinfo->data.insert(std::make_pair("error","No X509 keys sent"));
862                         return;
863                 }
864
865                 ret = gnutls_x509_crt_init(&cert);
866                 if (ret < 0)
867                 {
868                         certinfo->data.insert(std::make_pair("error",gnutls_strerror(ret)));
869                         return;
870                 }
871
872                 cert_list_size = 0;
873                 cert_list = gnutls_certificate_get_peers(session->sess, &cert_list_size);
874                 if (cert_list == NULL)
875                 {
876                         certinfo->data.insert(std::make_pair("error","No certificate was found"));
877                         return;
878                 }
879
880                 /* This is not a real world example, since we only check the first
881                  * certificate in the given chain.
882                  */
883
884                 ret = gnutls_x509_crt_import(cert, &cert_list[0], GNUTLS_X509_FMT_DER);
885                 if (ret < 0)
886                 {
887                         certinfo->data.insert(std::make_pair("error",gnutls_strerror(ret)));
888                         return;
889                 }
890
891                 gnutls_x509_crt_get_dn(cert, name, &name_size);
892
893                 certinfo->data.insert(std::make_pair("dn",name));
894
895                 gnutls_x509_crt_get_issuer_dn(cert, name, &name_size);
896
897                 certinfo->data.insert(std::make_pair("issuer",name));
898
899                 if ((ret = gnutls_x509_crt_get_fingerprint(cert, GNUTLS_DIG_MD5, digest, &digest_size)) < 0)
900                 {
901                         certinfo->data.insert(std::make_pair("error",gnutls_strerror(ret)));
902                 }
903                 else
904                 {
905                         certinfo->data.insert(std::make_pair("fingerprint",irc::hex(digest, digest_size)));
906                 }
907
908                 /* Beware here we do not check for errors.
909                  */
910                 if ((gnutls_x509_crt_get_expiration_time(cert) < time(0)) || (gnutls_x509_crt_get_activation_time(cert) > time(0)))
911                 {
912                         certinfo->data.insert(std::make_pair("error","Not activated, or expired certificate"));
913                 }
914
915                 gnutls_x509_crt_deinit(cert);
916
917                 return;
918         }
919
920         void OnEvent(Event* ev)
921         {
922                 GenericCapHandler(ev, "tls", "tls");
923                 if (ev->GetEventID() == "cap_req")
924                 {
925                         /* GenericCapHandler() Extends("tls") a user if it does
926                          * CAP REQ tls. Check if this was done.
927                          */
928                         CapData *data = (CapData *) ev->GetData();
929                         if (data->user->Shrink("tls"))
930                         {
931                                 /* Not in our spec?!?! */
932                                 data->user->io = this;
933                                 OnRawSocketAccept(data->user->GetFd(), data->user->GetIPString(),
934                                                 data->user->GetPort());
935                         }
936                 }
937         }
938
939         void Prioritize()
940         {
941                 Module* server = ServerInstance->Modules->Find("m_spanningtree.so");
942                 ServerInstance->Modules->SetPriority(this, I_OnPostConnect, PRIO_AFTER, &server);
943         }
944 };
945
/* Module entry point: MODULE_INIT (defined in the InspIRCd core headers) exposes
 * ModuleSSLGnuTLS to the module loader. Presumably it emits the factory function
 * the loader resolves when loading this shared object; confirm against the macro's
 * definition. */
946 MODULE_INIT(ModuleSSLGnuTLS)