]> git.netwichtig.de Git - user/henk/code/inspircd.git/blob - src/modules/extra/m_ssl_gnutls.cpp
151d10fd401543d1411fd56bc548f74391c422b7
[user/henk/code/inspircd.git] / src / modules / extra / m_ssl_gnutls.cpp
1 /*       +------------------------------------+
2  *       | Inspire Internet Relay Chat Daemon |
3  *       +------------------------------------+
4  *
5  *  InspIRCd: (C) 2002-2009 InspIRCd Development Team
6  * See: http://www.inspircd.org/wiki/index.php/Credits
7  *
8  * This program is free but copyrighted software; see
9  *          the file COPYING for details.
10  *
11  * ---------------------------------------------------
12  */
13
14 #include "inspircd.h"
15 #include <gnutls/gnutls.h>
16 #include <gnutls/x509.h>
17 #include "transport.h"
18 #include "m_cap.h"
19
20 #ifdef WINDOWS
21 #pragma comment(lib, "libgnutls-13.lib")
22 #endif
23
24 /* $ModDesc: Provides SSL support for clients */
25 /* $CompileFlags: exec("libgnutls-config --cflags") */
26 /* $LinkerFlags: rpath("libgnutls-config --libs") exec("libgnutls-config --libs") */
27 /* $ModDep: transport.h */
28 /* $CopyInstall: conf/key.pem $(CONPATH) */
29 /* $CopyInstall: conf/cert.pem $(CONPATH) */
30
/** State of a per-descriptor TLS session (see issl_session::status). */
enum issl_status
{
	ISSL_NONE,              // no handshake or data phase recorded
	ISSL_HANDSHAKING_READ,  // handshake paused until the socket is readable
	ISSL_HANDSHAKING_WRITE, // handshake paused until the socket is writable
	ISSL_HANDSHAKEN,        // handshake complete; application data flows
	ISSL_CLOSING,           // session shutting down; reads report 0 bytes
	ISSL_CLOSED
};
32
33 bool isin(const std::string &host, int port, const std::vector<std::string> &portlist)
34 {
35         if (std::find(portlist.begin(), portlist.end(), "*:" + ConvToStr(port)) != portlist.end())
36                 return true;
37
38         if (std::find(portlist.begin(), portlist.end(), ":" + ConvToStr(port)) != portlist.end())
39                 return true;
40
41         return std::find(portlist.begin(), portlist.end(), host + ":" + ConvToStr(port)) != portlist.end();
42 }
43
44 /** Represents an SSL user's extra data
45  */
46 class issl_session : public classbase
47 {
48 public:
49         issl_session()
50         {
51                 sess = NULL;
52         }
53
54         gnutls_session_t sess;
55         issl_status status;
56         std::string outbuf;
57         int inbufoffset;
58         char* inbuf;
59         int fd;
60 };
61
62 class CommandStartTLS : public Command
63 {
64         Module* Caller;
65  public:
66         /* Command 'dalinfo', takes no parameters and needs no special modes */
67         CommandStartTLS (InspIRCd* Instance, Module* mod) : Command(Instance,"STARTTLS", 0, 0, true), Caller(mod)
68         {
69                 this->source = "m_ssl_gnutls.so";
70         }
71
72         CmdResult Handle (const std::vector<std::string> &parameters, User *user)
73         {
74                 /* changed from == REG_ALL to catch clients sending STARTTLS
75                  * after NICK and USER but before OnUserConnect completes and
76                  * give a proper error message (see bug #645) - dz
77                  */
78                 if (user->registered != REG_NONE)
79                 {
80                         ServerInstance->Users->QuitUser(user, "STARTTLS is not permitted after client registration has started");
81                 }
82                 else
83                 {
84                         if (!user->GetIOHook())
85                         {
86                                 user->WriteNumeric(670, "%s :STARTTLS successful, go ahead with TLS handshake", user->nick.c_str());
87                                 user->AddIOHook(Caller);
88                                 Caller->OnRawSocketAccept(user->GetFd(), user->GetIPString(), user->GetPort());
89                         }
90                         else
91                                 user->WriteNumeric(691, "%s :STARTTLS failure", user->nick.c_str());
92                 }
93
94                 return CMD_FAILURE;
95         }
96 };
97
98 class ModuleSSLGnuTLS : public Module
99 {
100
	/** Configuration reader; created at the start of OnRehash() and deleted
	 * before it returns, so it is only meaningful while OnRehash() runs.
	 */
	ConfigReader* Conf;

	/** Throwaway output pointer for GetExt() calls that only test whether an
	 * extension item exists.
	 */
	char* dummy;

	/** "address:port" for every client <bind> tag with ssl="gnutls"; the
	 * address part is empty for tags that have no address.
	 */
	std::vector<std::string> listenports;

	/** Size of each session's decrypted-data buffer, copied from
	 * NetBufferSize when the module loads (not updated on rehash).
	 */
	int inbufsize;
	/** Per-descriptor session table with GetMaxFds() entries, indexed by fd. */
	issl_session* sessions;

	/** Certificate credentials handed to every session; rebuilt by OnRehash() when called with "ssl". */
	gnutls_certificate_credentials x509_cred;
	/** Diffie-Hellman parameters attached to x509_cred. */
	gnutls_dh_params dh_params;

	/** Paths read from the <gnutls> tag; relative paths are resolved against
	 * the directory of the configuration file.
	 */
	std::string keyfile;
	std::string certfile;
	std::string cafile;
	std::string crlfile;
	/** Value of the SSL= 005 token: "address:port" entries separated by ';',
	 * with "*" standing in for an empty address.
	 */
	std::string sslports;
	/** Diffie-Hellman prime size in bits (768, 1024, 2048, 3072 or 4096; default 1024). */
	int dh_bits;

	/** Number of client ports configured for GnuTLS; OnWhois() does nothing while this is 0. */
	int clientactive;
	/** True once x509_cred and dh_params have been allocated and must be freed. */
	bool cred_alloc;

	/** Handler for the STARTTLS command, created and registered by the constructor. */
	CommandStartTLS* starttls;
124
125  public:
126
127         ModuleSSLGnuTLS(InspIRCd* Me)
128                 : Module(Me)
129         {
130                 ServerInstance->Modules->PublishInterface("BufferedSocketHook", this);
131
132                 sessions = new issl_session[ServerInstance->SE->GetMaxFds()];
133
134                 // Not rehashable...because I cba to reduce all the sizes of existing buffers.
135                 inbufsize = ServerInstance->Config->NetBufferSize;
136
137                 gnutls_global_init(); // This must be called once in the program
138
139                 cred_alloc = false;
140                 // Needs the flag as it ignores a plain /rehash
141                 OnRehash(NULL,"ssl");
142
143                 // Void return, guess we assume success
144                 gnutls_certificate_set_dh_params(x509_cred, dh_params);
145                 Implementation eventlist[] = { I_On005Numeric, I_OnRawSocketConnect, I_OnRawSocketAccept, I_OnRawSocketClose, I_OnRawSocketRead, I_OnRawSocketWrite, I_OnCleanup,
146                         I_OnBufferFlushed, I_OnRequest, I_OnSyncUserMetaData, I_OnDecodeMetaData, I_OnUnloadModule, I_OnRehash, I_OnWhois, I_OnPostConnect, I_OnEvent, I_OnHookUserIO };
147                 ServerInstance->Modules->Attach(eventlist, this, 17);
148
149                 starttls = new CommandStartTLS(ServerInstance, this);
150                 ServerInstance->AddCommand(starttls);
151         }
152
153         virtual void OnRehash(User* user, const std::string &param)
154         {
155                 Conf = new ConfigReader(ServerInstance);
156
157                 listenports.clear();
158                 clientactive = 0;
159                 sslports.clear();
160
161                 for(int index = 0; index < Conf->Enumerate("bind"); index++)
162                 {
163                         // For each <bind> tag
164                         std::string x = Conf->ReadValue("bind", "type", index);
165                         if(((x.empty()) || (x == "clients")) && (Conf->ReadValue("bind", "ssl", index) == "gnutls"))
166                         {
167                                 // Get the port we're meant to be listening on with SSL
168                                 std::string port = Conf->ReadValue("bind", "port", index);
169                                 std::string addr = Conf->ReadValue("bind", "address", index);
170
171                                 irc::portparser portrange(port, false);
172                                 long portno = -1;
173                                 while ((portno = portrange.GetToken()))
174                                 {
175                                         clientactive++;
176                                         try
177                                         {
178                                                 listenports.push_back(addr + ":" + ConvToStr(portno));
179
180                                                 for (size_t i = 0; i < ServerInstance->Config->ports.size(); i++)
181                                                         if ((ServerInstance->Config->ports[i]->GetPort() == portno) && (ServerInstance->Config->ports[i]->GetIP() == addr))
182                                                                 ServerInstance->Config->ports[i]->SetDescription("ssl");
183                                                 ServerInstance->Logs->Log("m_ssl_gnutls",DEFAULT, "m_ssl_gnutls.so: Enabling SSL for port %ld", portno);
184
185                                                 sslports.append((addr.empty() ? "*" : addr)).append(":").append(ConvToStr(portno)).append(";");
186                                         }
187                                         catch (ModuleException &e)
188                                         {
189                                                 ServerInstance->Logs->Log("m_ssl_gnutls",DEFAULT, "m_ssl_gnutls.so: FAILED to enable SSL on port %ld: %s. Maybe it's already hooked by the same port on a different IP, or you have an other SSL or similar module loaded?", portno, e.GetReason());
190                                         }
191                                 }
192                         }
193                 }
194
195                 if (!sslports.empty())
196                         sslports.erase(sslports.end() - 1);
197
198                 if(param != "ssl")
199                 {
200                         delete Conf;
201                         return;
202                 }
203
204                 std::string confdir(ServerInstance->ConfigFileName);
205                 // +1 so we the path ends with a /
206                 confdir = confdir.substr(0, confdir.find_last_of('/') + 1);
207
208                 cafile  = Conf->ReadValue("gnutls", "cafile", 0);
209                 crlfile = Conf->ReadValue("gnutls", "crlfile", 0);
210                 certfile        = Conf->ReadValue("gnutls", "certfile", 0);
211                 keyfile = Conf->ReadValue("gnutls", "keyfile", 0);
212                 dh_bits = Conf->ReadInteger("gnutls", "dhbits", 0, false);
213
214                 // Set all the default values needed.
215                 if (cafile.empty())
216                         cafile = "ca.pem";
217
218                 if (crlfile.empty())
219                         crlfile = "crl.pem";
220
221                 if (certfile.empty())
222                         certfile = "cert.pem";
223
224                 if (keyfile.empty())
225                         keyfile = "key.pem";
226
227                 if((dh_bits != 768) && (dh_bits != 1024) && (dh_bits != 2048) && (dh_bits != 3072) && (dh_bits != 4096))
228                         dh_bits = 1024;
229
230                 // Prepend relative paths with the path to the config directory.
231                 if ((cafile[0] != '/') && (!ServerInstance->Config->StartsWithWindowsDriveLetter(cafile)))
232                         cafile = confdir + cafile;
233
234                 if ((crlfile[0] != '/') && (!ServerInstance->Config->StartsWithWindowsDriveLetter(crlfile)))
235                         crlfile = confdir + crlfile;
236
237                 if ((certfile[0] != '/') && (!ServerInstance->Config->StartsWithWindowsDriveLetter(certfile)))
238                         certfile = confdir + certfile;
239
240                 if ((keyfile[0] != '/') && (!ServerInstance->Config->StartsWithWindowsDriveLetter(keyfile)))
241                         keyfile = confdir + keyfile;
242
243                 int ret;
244                 
245                 if (cred_alloc)
246                 {
247                         // Deallocate the old credentials
248                         gnutls_dh_params_deinit(dh_params);
249                         gnutls_certificate_free_credentials(x509_cred);
250                 }
251                 else
252                         cred_alloc = true;
253                 
254                 if((ret = gnutls_certificate_allocate_credentials(&x509_cred)) < 0)
255                         ServerInstance->Logs->Log("m_ssl_gnutls",DEFAULT, "m_ssl_gnutls.so: Failed to allocate certificate credentials: %s", gnutls_strerror(ret));
256                 
257                 if((ret = gnutls_dh_params_init(&dh_params)) < 0)
258                         ServerInstance->Logs->Log("m_ssl_gnutls",DEFAULT, "m_ssl_gnutls.so: Failed to initialise DH parameters: %s", gnutls_strerror(ret));
259                 
260                 if((ret =gnutls_certificate_set_x509_trust_file(x509_cred, cafile.c_str(), GNUTLS_X509_FMT_PEM)) < 0)
261                         ServerInstance->Logs->Log("m_ssl_gnutls",DEFAULT, "m_ssl_gnutls.so: Failed to set X.509 trust file '%s': %s", cafile.c_str(), gnutls_strerror(ret));
262
263                 if((ret = gnutls_certificate_set_x509_crl_file (x509_cred, crlfile.c_str(), GNUTLS_X509_FMT_PEM)) < 0)
264                         ServerInstance->Logs->Log("m_ssl_gnutls",DEFAULT, "m_ssl_gnutls.so: Failed to set X.509 CRL file '%s': %s", crlfile.c_str(), gnutls_strerror(ret));
265
266                 if((ret = gnutls_certificate_set_x509_key_file (x509_cred, certfile.c_str(), keyfile.c_str(), GNUTLS_X509_FMT_PEM)) < 0)
267                 {
268                         // If this fails, no SSL port will work. At all. So, do the smart thing - throw a ModuleException
269                         throw ModuleException("Unable to load GnuTLS server certificate (" + certfile + ", key: " + keyfile + "): " + std::string(gnutls_strerror(ret)));
270                 }
271
272                 // This may be on a large (once a day or week) timer eventually.
273                 GenerateDHParams();
274
275                 delete Conf;
276         }
277
278         void GenerateDHParams()
279         {
280                 // Generate Diffie Hellman parameters - for use with DHE
281                 // kx algorithms. These should be discarded and regenerated
282                 // once a day, once a week or once a month. Depending on the
283                 // security requirements.
284
285                 int ret;
286
287                 if((ret = gnutls_dh_params_generate2(dh_params, dh_bits)) < 0)
288                         ServerInstance->Logs->Log("m_ssl_gnutls",DEFAULT, "m_ssl_gnutls.so: Failed to generate DH parameters (%d bits): %s", dh_bits, gnutls_strerror(ret));
289         }
290
291         virtual ~ModuleSSLGnuTLS()
292         {
293                 gnutls_dh_params_deinit(dh_params);
294                 gnutls_certificate_free_credentials(x509_cred);
295                 gnutls_global_deinit();
296                 ServerInstance->Modules->UnpublishInterface("BufferedSocketHook", this);
297                 delete[] sessions;
298         }
299
300         virtual void OnCleanup(int target_type, void* item)
301         {
302                 if(target_type == TYPE_USER)
303                 {
304                         User* user = (User*)item;
305
306                         if (user->GetIOHook() == this)
307                         {
308                                 // User is using SSL, they're a local user, and they're using one of *our* SSL ports.
309                                 // Potentially there could be multiple SSL modules loaded at once on different ports.
310                                 ServerInstance->Users->QuitUser(user, "SSL module unloading");
311                                 user->DelIOHook();
312                         }
313                         if (user->GetExt("ssl_cert", dummy))
314                         {
315                                 ssl_cert* tofree;
316                                 user->GetExt("ssl_cert", tofree);
317                                 delete tofree;
318                                 user->Shrink("ssl_cert");
319                         }
320                 }
321         }
322
323         virtual void OnUnloadModule(Module* mod, const std::string &name)
324         {
325                 if(mod == this)
326                 {
327                         for(unsigned int i = 0; i < listenports.size(); i++)
328                         {
329                                 for (size_t j = 0; j < ServerInstance->Config->ports.size(); j++)
330                                         if (listenports[i] == (ServerInstance->Config->ports[j]->GetIP()+":"+ConvToStr(ServerInstance->Config->ports[j]->GetPort())))
331                                                 ServerInstance->Config->ports[j]->SetDescription("plaintext");
332                         }
333                 }
334         }
335
336         virtual Version GetVersion()
337         {
338                 return Version("$Id$", VF_VENDOR, API_VERSION);
339         }
340
341
342         virtual void On005Numeric(std::string &output)
343         {
344                 output.append(" SSL=" + sslports);
345         }
346
347         virtual void OnHookUserIO(User* user, const std::string &targetip)
348         {
349                 if (!user->GetIOHook() && isin(targetip,user->GetPort(),listenports))
350                 {
351                         /* Hook the user with our module */
352                         user->AddIOHook(this);
353                 }
354         }
355
356         virtual const char* OnRequest(Request* request)
357         {
358                 ISHRequest* ISR = (ISHRequest*)request;
359                 if (strcmp("IS_NAME", request->GetId()) == 0)
360                 {
361                         return "gnutls";
362                 }
363                 else if (strcmp("IS_HOOK", request->GetId()) == 0)
364                 {
365                         const char* ret = "OK";
366                         try
367                         {
368                                 ret = ISR->Sock->AddIOHook((Module*)this) ? "OK" : NULL;
369                         }
370                         catch (ModuleException &e)
371                         {
372                                 return NULL;
373                         }
374                         return ret;
375                 }
376                 else if (strcmp("IS_UNHOOK", request->GetId()) == 0)
377                 {
378                         return ISR->Sock->DelIOHook() ? "OK" : NULL;
379                 }
380                 else if (strcmp("IS_HSDONE", request->GetId()) == 0)
381                 {
382                         if (ISR->Sock->GetFd() < 0)
383                                 return "OK";
384
385                         issl_session* session = &sessions[ISR->Sock->GetFd()];
386                         return (session->status == ISSL_HANDSHAKING_READ || session->status == ISSL_HANDSHAKING_WRITE) ? NULL : "OK";
387                 }
388                 else if (strcmp("IS_ATTACH", request->GetId()) == 0)
389                 {
390                         if (ISR->Sock->GetFd() > -1)
391                         {
392                                 issl_session* session = &sessions[ISR->Sock->GetFd()];
393                                 if (session->sess)
394                                 {
395                                         if ((Extensible*)ServerInstance->FindDescriptor(ISR->Sock->GetFd()) == (Extensible*)(ISR->Sock))
396                                         {
397                                                 VerifyCertificate(session, (BufferedSocket*)ISR->Sock);
398                                                 return "OK";
399                                         }
400                                 }
401                         }
402                 }
403                 return NULL;
404         }
405
406
407         virtual void OnRawSocketAccept(int fd, const std::string &ip, int localport)
408         {
409                 /* Are there any possibilities of an out of range fd? Hope not, but lets be paranoid */
410                 if ((fd < 0) || (fd > ServerInstance->SE->GetMaxFds() - 1))
411                         return;
412
413                 issl_session* session = &sessions[fd];
414
415                 /* For STARTTLS: Don't try and init a session on a socket that already has a session */
416                 if (session->sess)
417                         return;
418
419                 session->fd = fd;
420                 session->inbuf = new char[inbufsize];
421                 session->inbufoffset = 0;
422
423                 gnutls_init(&session->sess, GNUTLS_SERVER);
424
425                 gnutls_set_default_priority(session->sess); // Avoid calling all the priority functions, defaults are adequate.
426                 gnutls_credentials_set(session->sess, GNUTLS_CRD_CERTIFICATE, x509_cred);
427                 gnutls_dh_set_prime_bits(session->sess, dh_bits);
428
429                 /* This is an experimental change to avoid a warning on 64bit systems about casting between integer and pointer of different sizes
430                  * This needs testing, but it's easy enough to rollback if need be
431                  * Old: gnutls_transport_set_ptr(session->sess, (gnutls_transport_ptr_t) fd); // Give gnutls the fd for the socket.
432                  * New: gnutls_transport_set_ptr(session->sess, &fd); // Give gnutls the fd for the socket.
433                  *
434                  * With testing this seems to...not work :/
435                  */
436
437                 gnutls_transport_set_ptr(session->sess, (gnutls_transport_ptr_t) fd); // Give gnutls the fd for the socket.
438
439                 gnutls_certificate_server_set_request(session->sess, GNUTLS_CERT_REQUEST); // Request client certificate if any.
440
441                 Handshake(session);
442         }
443
444         virtual void OnRawSocketConnect(int fd)
445         {
446                 /* Are there any possibilities of an out of range fd? Hope not, but lets be paranoid */
447                 if ((fd < 0) || (fd > ServerInstance->SE->GetMaxFds() - 1))
448                         return;
449
450                 issl_session* session = &sessions[fd];
451
452                 session->fd = fd;
453                 session->inbuf = new char[inbufsize];
454                 session->inbufoffset = 0;
455
456                 gnutls_init(&session->sess, GNUTLS_CLIENT);
457
458                 gnutls_set_default_priority(session->sess); // Avoid calling all the priority functions, defaults are adequate.
459                 gnutls_credentials_set(session->sess, GNUTLS_CRD_CERTIFICATE, x509_cred);
460                 gnutls_dh_set_prime_bits(session->sess, dh_bits);
461                 gnutls_transport_set_ptr(session->sess, (gnutls_transport_ptr_t) fd); // Give gnutls the fd for the socket.
462
463                 Handshake(session);
464         }
465
466         virtual void OnRawSocketClose(int fd)
467         {
468                 /* Are there any possibilities of an out of range fd? Hope not, but lets be paranoid */
469                 if ((fd < 0) || (fd > ServerInstance->SE->GetMaxFds()))
470                         return;
471
472                 CloseSession(&sessions[fd]);
473
474                 EventHandler* user = ServerInstance->SE->GetRef(fd);
475
476                 if ((user) && (user->GetExt("ssl_cert", dummy)))
477                 {
478                         ssl_cert* tofree;
479                         user->GetExt("ssl_cert", tofree);
480                         delete tofree;
481                         user->Shrink("ssl_cert");
482                 }
483         }
484
485         virtual int OnRawSocketRead(int fd, char* buffer, unsigned int count, int &readresult)
486         {
487                 /* Are there any possibilities of an out of range fd? Hope not, but lets be paranoid */
488                 if ((fd < 0) || (fd > ServerInstance->SE->GetMaxFds() - 1))
489                         return 0;
490
491                 issl_session* session = &sessions[fd];
492
493                 if (!session->sess)
494                 {
495                         readresult = 0;
496                         CloseSession(session);
497                         return 1;
498                 }
499
500                 if (session->status == ISSL_HANDSHAKING_READ)
501                 {
502                         // The handshake isn't finished, try to finish it.
503
504                         if(!Handshake(session))
505                         {
506                                 // Couldn't resume handshake.
507                                 return -1;
508                         }
509                 }
510                 else if (session->status == ISSL_HANDSHAKING_WRITE)
511                 {
512                         errno = EAGAIN;
513                         MakePollWrite(session);
514                         return -1;
515                 }
516
517                 // If we resumed the handshake then session->status will be ISSL_HANDSHAKEN.
518
519                 if (session->status == ISSL_HANDSHAKEN)
520                 {
521                         // Is this right? Not sure if the unencrypted data is garaunteed to be the same length.
522                         // Read into the inbuffer, offset from the beginning by the amount of data we have that insp hasn't taken yet.
523                         int ret = gnutls_record_recv(session->sess, session->inbuf + session->inbufoffset, inbufsize - session->inbufoffset);
524
525                         if (ret == 0)
526                         {
527                                 // Client closed connection.
528                                 readresult = 0;
529                                 CloseSession(session);
530                                 return 1;
531                         }
532                         else if (ret < 0)
533                         {
534                                 if (ret == GNUTLS_E_AGAIN || ret == GNUTLS_E_INTERRUPTED)
535                                 {
536                                         errno = EAGAIN;
537                                         return -1;
538                                 }
539                                 else
540                                 {
541                                         ServerInstance->Logs->Log("m_ssl_gnutls", DEFAULT,
542                                                         "m_ssl_gnutls.so: Error while reading on fd %d: %s",
543                                                         session->fd, gnutls_strerror(ret));
544                                         readresult = 0;
545                                         CloseSession(session);
546                                 }
547                         }
548                         else
549                         {
550                                 // Read successfully 'ret' bytes into inbuf + inbufoffset
551                                 // There are 'ret' + 'inbufoffset' bytes of data in 'inbuf'
552                                 // 'buffer' is 'count' long
553
554                                 unsigned int length = ret + session->inbufoffset;
555
556                                 if(count <= length)
557                                 {
558                                         memcpy(buffer, session->inbuf, count);
559                                         // Move the stuff left in inbuf to the beginning of it
560                                         memmove(session->inbuf, session->inbuf + count, (length - count));
561                                         // Now we need to set session->inbufoffset to the amount of data still waiting to be handed to insp.
562                                         session->inbufoffset = length - count;
563                                         // Insp uses readresult as the count of how much data there is in buffer, so:
564                                         readresult = count;
565                                 }
566                                 else
567                                 {
568                                         // There's not as much in the inbuf as there is space in the buffer, so just copy the whole thing.
569                                         memcpy(buffer, session->inbuf, length);
570                                         // Zero the offset, as there's nothing there..
571                                         session->inbufoffset = 0;
572                                         // As above
573                                         readresult = length;
574                                 }
575                         }
576                 }
577                 else if(session->status == ISSL_CLOSING)
578                         readresult = 0;
579
580                 return 1;
581         }
582
583         virtual int OnRawSocketWrite(int fd, const char* buffer, int count)
584         {
585                 /* Are there any possibilities of an out of range fd? Hope not, but lets be paranoid */
586                 if ((fd < 0) || (fd > ServerInstance->SE->GetMaxFds() - 1))
587                         return 0;
588
589                 issl_session* session = &sessions[fd];
590                 const char* sendbuffer = buffer;
591
592                 if (!session->sess)
593                 {
594                         CloseSession(session);
595                         return 1;
596                 }
597
598                 session->outbuf.append(sendbuffer, count);
599                 sendbuffer = session->outbuf.c_str();
600                 count = session->outbuf.size();
601
602                 if (session->status == ISSL_HANDSHAKING_WRITE || session->status == ISSL_HANDSHAKING_READ)
603                 {
604                         // The handshake isn't finished, try to finish it.
605                         Handshake(session);
606                         errno = EAGAIN;
607                         return -1;
608                 }
609
610                 int ret = 0;
611
612                 if (session->status == ISSL_HANDSHAKEN)
613                 {
614                         ret = gnutls_record_send(session->sess, sendbuffer, count);
615
616                         if (ret == 0)
617                         {
618                                 CloseSession(session);
619                         }
620                         else if (ret < 0)
621                         {
622                                 if(ret != GNUTLS_E_AGAIN && ret != GNUTLS_E_INTERRUPTED)
623                                 {
624                                         ServerInstance->Logs->Log("m_ssl_gnutls", DEFAULT,
625                                                         "m_ssl_gnutls.so: Error while writing to fd %d: %s",
626                                                         session->fd, gnutls_strerror(ret));
627                                         CloseSession(session);
628                                 }
629                                 else
630                                 {
631                                         errno = EAGAIN;
632                                 }
633                         }
634                         else
635                         {
636                                 session->outbuf = session->outbuf.substr(ret);
637                         }
638                 }
639
640                 MakePollWrite(session);
641
642                 /* Who's smart idea was it to return 1 when we havent written anything?
643                  * This fucks the buffer up in BufferedSocket :p
644                  */
645                 return ret < 1 ? 0 : ret;
646         }
647
648         // :kenny.chatspike.net 320 Om Epy|AFK :is a Secure Connection
649         virtual void OnWhois(User* source, User* dest)
650         {
651                 if (!clientactive)
652                         return;
653
654                 // Bugfix, only send this numeric for *our* SSL users
655                 if (dest->GetExt("ssl", dummy) || ((IS_LOCAL(dest) && (dest->GetIOHook() == this))))
656                 {
657                         ServerInstance->SendWhoisLine(source, dest, 320, "%s %s :is using a secure connection", source->nick.c_str(), dest->nick.c_str());
658                 }
659         }
660
661         virtual void OnSyncUserMetaData(User* user, Module* proto, void* opaque, const std::string &extname, bool displayable)
662         {
663                 // check if the linking module wants to know about OUR metadata
664                 if(extname == "ssl")
665                 {
666                         // check if this user has an swhois field to send
667                         if(user->GetExt(extname, dummy))
668                         {
669                                 // call this function in the linking module, let it format the data how it
670                                 // sees fit, and send it on its way. We dont need or want to know how.
671                                 proto->ProtoSendMetaData(opaque, TYPE_USER, user, extname, displayable ? "Enabled" : "ON");
672                         }
673                 }
674         }
675
676         virtual void OnDecodeMetaData(int target_type, void* target, const std::string &extname, const std::string &extdata)
677         {
678                 // check if its our metadata key, and its associated with a user
679                 if ((target_type == TYPE_USER) && (extname == "ssl"))
680                 {
681                         User* dest = (User*)target;
682                         // if they dont already have an ssl flag, accept the remote server's
683                         if (!dest->GetExt(extname, dummy))
684                         {
685                                 dest->Extend(extname, "ON");
686                         }
687                 }
688         }
689
690         bool Handshake(issl_session* session)
691         {
692                 int ret = gnutls_handshake(session->sess);
693
694                 if (ret < 0)
695                 {
696                         if(ret == GNUTLS_E_AGAIN || ret == GNUTLS_E_INTERRUPTED)
697                         {
698                                 // Handshake needs resuming later, read() or write() would have blocked.
699
700                                 if(gnutls_record_get_direction(session->sess) == 0)
701                                 {
702                                         // gnutls_handshake() wants to read() again.
703                                         session->status = ISSL_HANDSHAKING_READ;
704                                 }
705                                 else
706                                 {
707                                         // gnutls_handshake() wants to write() again.
708                                         session->status = ISSL_HANDSHAKING_WRITE;
709                                         MakePollWrite(session);
710                                 }
711                         }
712                         else
713                         {
714                                 // Handshake failed.
715                                 ServerInstance->Logs->Log("m_ssl_gnutls", DEFAULT,
716                                                 "m_ssl_gnutls.so: Handshake failed on fd %d: %s",
717                                                 session->fd, gnutls_strerror(ret));
718                                 CloseSession(session);
719                                 session->status = ISSL_CLOSING;
720                         }
721
722                         return false;
723                 }
724                 else
725                 {
726                         // Handshake complete.
727                         // This will do for setting the ssl flag...it could be done earlier if it's needed. But this seems neater.
728                         User* extendme = ServerInstance->FindDescriptor(session->fd);
729                         if (extendme)
730                         {
731                                 if (!extendme->GetExt("ssl", dummy))
732                                         extendme->Extend("ssl", "ON");
733                         }
734
735                         // Change the seesion state
736                         session->status = ISSL_HANDSHAKEN;
737
738                         // Finish writing, if any left
739                         MakePollWrite(session);
740
741                         return true;
742                 }
743         }
744
745         virtual void OnPostConnect(User* user)
746         {
747                 // This occurs AFTER OnUserConnect so we can be sure the
748                 // protocol module has propagated the NICK message.
749                 if (user->GetIOHook() == this && (IS_LOCAL(user)))
750                 {
751                         // Tell whatever protocol module we're using that we need to inform other servers of this metadata NOW.
752                         ServerInstance->PI->SendMetaData(user, TYPE_USER, "SSL", "on");
753
754                         VerifyCertificate(&sessions[user->GetFd()],user);
755                         if (sessions[user->GetFd()].sess)
756                         {
757                                 std::string cipher = gnutls_kx_get_name(gnutls_kx_get(sessions[user->GetFd()].sess));
758                                 cipher.append("-").append(gnutls_cipher_get_name(gnutls_cipher_get(sessions[user->GetFd()].sess))).append("-");
759                                 cipher.append(gnutls_mac_get_name(gnutls_mac_get(sessions[user->GetFd()].sess)));
760                                 user->WriteServ("NOTICE %s :*** You are connected using SSL cipher \"%s\"", user->nick.c_str(), cipher.c_str());
761                         }
762                 }
763         }
764
765         void MakePollWrite(issl_session* session)
766         {
767                 //OnRawSocketWrite(session->fd, NULL, 0);
768                 EventHandler* eh = ServerInstance->FindDescriptor(session->fd);
769                 if (eh)
770                         ServerInstance->SE->WantWrite(eh);
771         }
772
773         virtual void OnBufferFlushed(User* user)
774         {
775                 if (user->GetIOHook() == this)
776                 {
777                         issl_session* session = &sessions[user->GetFd()];
778                         if (session && session->outbuf.size())
779                                 OnRawSocketWrite(user->GetFd(), NULL, 0);
780                 }
781         }
782
783         void CloseSession(issl_session* session)
784         {
785                 if(session->sess)
786                 {
787                         gnutls_bye(session->sess, GNUTLS_SHUT_WR);
788                         gnutls_deinit(session->sess);
789                 }
790
791                 if(session->inbuf)
792                 {
793                         delete[] session->inbuf;
794                 }
795
796                 session->outbuf.clear();
797                 session->inbuf = NULL;
798                 session->sess = NULL;
799                 session->status = ISSL_NONE;
800         }
801
802         void VerifyCertificate(issl_session* session, Extensible* user)
803         {
804                 if (!session->sess || !user)
805                         return;
806
807                 unsigned int status;
808                 const gnutls_datum_t* cert_list;
809                 int ret;
810                 unsigned int cert_list_size;
811                 gnutls_x509_crt_t cert;
812                 char name[MAXBUF];
813                 unsigned char digest[MAXBUF];
814                 size_t digest_size = sizeof(digest);
815                 size_t name_size = sizeof(name);
816                 ssl_cert* certinfo = new ssl_cert;
817
818                 user->Extend("ssl_cert",certinfo);
819
820                 /* This verification function uses the trusted CAs in the credentials
821                  * structure. So you must have installed one or more CA certificates.
822                  */
823                 ret = gnutls_certificate_verify_peers2(session->sess, &status);
824
825                 if (ret < 0)
826                 {
827                         certinfo->data.insert(std::make_pair("error",std::string(gnutls_strerror(ret))));
828                         return;
829                 }
830
831                 if (status & GNUTLS_CERT_INVALID)
832                 {
833                         certinfo->data.insert(std::make_pair("invalid",ConvToStr(1)));
834                 }
835                 else
836                 {
837                         certinfo->data.insert(std::make_pair("invalid",ConvToStr(0)));
838                 }
839                 if (status & GNUTLS_CERT_SIGNER_NOT_FOUND)
840                 {
841                         certinfo->data.insert(std::make_pair("unknownsigner",ConvToStr(1)));
842                 }
843                 else
844                 {
845                         certinfo->data.insert(std::make_pair("unknownsigner",ConvToStr(0)));
846                 }
847                 if (status & GNUTLS_CERT_REVOKED)
848                 {
849                         certinfo->data.insert(std::make_pair("revoked",ConvToStr(1)));
850                 }
851                 else
852                 {
853                         certinfo->data.insert(std::make_pair("revoked",ConvToStr(0)));
854                 }
855                 if (status & GNUTLS_CERT_SIGNER_NOT_CA)
856                 {
857                         certinfo->data.insert(std::make_pair("trusted",ConvToStr(0)));
858                 }
859                 else
860                 {
861                         certinfo->data.insert(std::make_pair("trusted",ConvToStr(1)));
862                 }
863
864                 /* Up to here the process is the same for X.509 certificates and
865                  * OpenPGP keys. From now on X.509 certificates are assumed. This can
866                  * be easily extended to work with openpgp keys as well.
867                  */
868                 if (gnutls_certificate_type_get(session->sess) != GNUTLS_CRT_X509)
869                 {
870                         certinfo->data.insert(std::make_pair("error","No X509 keys sent"));
871                         return;
872                 }
873
874                 ret = gnutls_x509_crt_init(&cert);
875                 if (ret < 0)
876                 {
877                         certinfo->data.insert(std::make_pair("error",gnutls_strerror(ret)));
878                         return;
879                 }
880
881                 cert_list_size = 0;
882                 cert_list = gnutls_certificate_get_peers(session->sess, &cert_list_size);
883                 if (cert_list == NULL)
884                 {
885                         certinfo->data.insert(std::make_pair("error","No certificate was found"));
886                         return;
887                 }
888
889                 /* This is not a real world example, since we only check the first
890                  * certificate in the given chain.
891                  */
892
893                 ret = gnutls_x509_crt_import(cert, &cert_list[0], GNUTLS_X509_FMT_DER);
894                 if (ret < 0)
895                 {
896                         certinfo->data.insert(std::make_pair("error",gnutls_strerror(ret)));
897                         return;
898                 }
899
900                 gnutls_x509_crt_get_dn(cert, name, &name_size);
901
902                 certinfo->data.insert(std::make_pair("dn",name));
903
904                 gnutls_x509_crt_get_issuer_dn(cert, name, &name_size);
905
906                 certinfo->data.insert(std::make_pair("issuer",name));
907
908                 if ((ret = gnutls_x509_crt_get_fingerprint(cert, GNUTLS_DIG_MD5, digest, &digest_size)) < 0)
909                 {
910                         certinfo->data.insert(std::make_pair("error",gnutls_strerror(ret)));
911                 }
912                 else
913                 {
914                         certinfo->data.insert(std::make_pair("fingerprint",irc::hex(digest, digest_size)));
915                 }
916
917                 /* Beware here we do not check for errors.
918                  */
919                 if ((gnutls_x509_crt_get_expiration_time(cert) < ServerInstance->Time()) || (gnutls_x509_crt_get_activation_time(cert) > ServerInstance->Time()))
920                 {
921                         certinfo->data.insert(std::make_pair("error","Not activated, or expired certificate"));
922                 }
923
924                 gnutls_x509_crt_deinit(cert);
925
926                 return;
927         }
928
        /** Handle inter-module events by delegating to GenericCapHandler with "tls"
         * passed as both of its name arguments.
         * NOTE(review): presumably this advertises/handles the "tls" client capability
         * (CAP) and the matching user extension; GenericCapHandler is defined in m_cap.h.
         * @param ev The event being dispatched.
         */
        void OnEvent(Event* ev)
        {
                GenericCapHandler(ev, "tls", "tls");
        }
933
934         void Prioritize()
935         {
936                 Module* server = ServerInstance->Modules->Find("m_spanningtree.so");
937                 ServerInstance->Modules->SetPriority(this, I_OnPostConnect, PRIO_AFTER, &server);
938         }
939 };
940
/* Module registration for ModuleSSLGnuTLS.
 * NOTE(review): MODULE_INIT is defined elsewhere (modules.h); presumably it expands
 * to the loader entry point that instantiates this class.
 */
MODULE_INIT(ModuleSSLGnuTLS)