]> git.netwichtig.de Git - user/henk/code/inspircd.git/blob - src/modules/extra/m_ssl_gnutls.cpp
cmd_who Hide +i users when listing users on a server and hidewhois is off
[user/henk/code/inspircd.git] / src / modules / extra / m_ssl_gnutls.cpp
1 /*
2  * InspIRCd -- Internet Relay Chat Daemon
3  *
4  *   Copyright (C) 2009-2010 Daniel De Graaf <danieldg@inspircd.org>
5  *   Copyright (C) 2008 John Brooks <john.brooks@dereferenced.net>
6  *   Copyright (C) 2006-2008 Craig Edwards <craigedwards@brainbox.cc>
7  *   Copyright (C) 2007 Dennis Friis <peavey@inspircd.org>
8  *   Copyright (C) 2006 Oliver Lupton <oliverlupton@gmail.com>
9  *
10  * This file is part of InspIRCd.  InspIRCd is free software: you can
11  * redistribute it and/or modify it under the terms of the GNU General Public
12  * License as published by the Free Software Foundation, version 2.
13  *
14  * This program is distributed in the hope that it will be useful, but WITHOUT
15  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
16  * FOR A PARTICULAR PURPOSE.  See the GNU General Public License for more
17  * details.
18  *
19  * You should have received a copy of the GNU General Public License
20  * along with this program.  If not, see <http://www.gnu.org/licenses/>.
21  */
22
23
24 #include "inspircd.h"
25 #include <gcrypt.h>
26 #include <gnutls/gnutls.h>
27 #include <gnutls/x509.h>
28 #include "ssl.h"
29 #include "m_cap.h"
30
31 #ifdef _WIN32
32 # pragma comment(lib, "libgnutls.lib")
33 # pragma comment(lib, "libgcrypt.lib")
34 # pragma comment(lib, "libgpg-error.lib")
35 # pragma comment(lib, "user32.lib")
36 # pragma comment(lib, "advapi32.lib")
37 # pragma comment(lib, "libgcc.lib")
38 # pragma comment(lib, "libmingwex.lib")
39 # pragma comment(lib, "gdi32.lib")
40 #endif
41
42 /* $ModDesc: Provides SSL support for clients */
43 /* $CompileFlags: pkgconfincludes("gnutls","/gnutls/gnutls.h","") exec("libgcrypt-config --cflags") */
44 /* $LinkerFlags: rpath("pkg-config --libs gnutls") pkgconflibs("gnutls","/libgnutls.so","-lgnutls") exec("libgcrypt-config --libs") */
45 /* $NoPedantic */
46
// GNUTLS_NEW_PRIO_API is defined for GnuTLS 2.1.7 and newer (see the version test below); only
// those versions provide the gnutls_priority_t API (gnutls_priority_init() etc.) used in this file.
#if ((GNUTLS_VERSION_MAJOR > 2) || (GNUTLS_VERSION_MAJOR == 2 && GNUTLS_VERSION_MINOR > 1) || (GNUTLS_VERSION_MAJOR == 2 && GNUTLS_VERSION_MINOR == 1 && GNUTLS_VERSION_MICRO >= 7))
#define GNUTLS_NEW_PRIO_API
#endif

// Compatibility type aliases for GnuTLS releases older than 2.x.
// NOTE(review): these define the non-_t names in terms of the _t names; confirm which spelling the
// old GnuTLS headers actually provide, as the rest of this file only uses the _t names.
#if(GNUTLS_VERSION_MAJOR < 2)
typedef gnutls_certificate_credentials_t gnutls_certificate_credentials;
typedef gnutls_dh_params_t gnutls_dh_params;
#endif
56
/** Lifecycle state of a GnuTLS session attached to a socket. */
enum issl_status
{
        ISSL_NONE,                // no session / slot unused
        ISSL_HANDSHAKING_READ,    // handshake in progress, waiting to read
        ISSL_HANDSHAKING_WRITE,   // handshake in progress, waiting to write
        ISSL_HANDSHAKEN,          // handshake complete, application data may flow
        ISSL_CLOSING,             // session is being torn down
        ISSL_CLOSED               // session has been closed
};
58
59 static std::vector<gnutls_x509_crt_t> x509_certs;
60 static gnutls_x509_privkey_t x509_key;
61 #if(GNUTLS_VERSION_MAJOR < 2 || ( GNUTLS_VERSION_MAJOR == 2 && GNUTLS_VERSION_MINOR < 12 ) )
62 static int cert_callback (gnutls_session_t session, const gnutls_datum_t * req_ca_rdn, int nreqs,
63         const gnutls_pk_algorithm_t * sign_algos, int sign_algos_length, gnutls_retr_st * st) {
64
65         st->type = GNUTLS_CRT_X509;
66 #else
67 static int cert_callback (gnutls_session_t session, const gnutls_datum_t * req_ca_rdn, int nreqs,
68         const gnutls_pk_algorithm_t * sign_algos, int sign_algos_length, gnutls_retr2_st * st) {
69         st->cert_type = GNUTLS_CRT_X509;
70         st->key_type = GNUTLS_PRIVKEY_X509;
71 #endif
72         st->ncerts = x509_certs.size();
73         st->cert.x509 = &x509_certs[0];
74         st->key.x509 = x509_key;
75         st->deinit_all = 0;
76
77         return 0;
78 }
79
80 class RandGen : public HandlerBase2<void, char*, size_t>
81 {
82  public:
83         RandGen() {}
84         void Call(char* buffer, size_t len)
85         {
86                 gcry_randomize(buffer, len, GCRY_STRONG_RANDOM);
87         }
88 };
89
90 /** Represents an SSL user's extra data
91  */
92 class issl_session
93 {
94 public:
95         StreamSocket* socket;
96         gnutls_session_t sess;
97         issl_status status;
98         reference<ssl_cert> cert;
99
100         issl_session() : socket(NULL), sess(NULL) {}
101 };
102
103 class CommandStartTLS : public SplitCommand
104 {
105  public:
106         bool enabled;
107         CommandStartTLS (Module* mod) : SplitCommand(mod, "STARTTLS")
108         {
109                 enabled = true;
110                 works_before_reg = true;
111         }
112
113         CmdResult HandleLocal(const std::vector<std::string> &parameters, LocalUser *user)
114         {
115                 if (!enabled)
116                 {
117                         user->WriteNumeric(691, "%s :STARTTLS is not enabled", user->nick.c_str());
118                         return CMD_FAILURE;
119                 }
120
121                 if (user->registered == REG_ALL)
122                 {
123                         user->WriteNumeric(691, "%s :STARTTLS is not permitted after client registration is complete", user->nick.c_str());
124                 }
125                 else
126                 {
127                         if (!user->eh.GetIOHook())
128                         {
129                                 user->WriteNumeric(670, "%s :STARTTLS successful, go ahead with TLS handshake", user->nick.c_str());
130                                 /* We need to flush the write buffer prior to adding the IOHook,
131                                  * otherwise we'll be sending this line inside the SSL session - which
132                                  * won't start its handshake until the client gets this line. Currently,
133                                  * we assume the write will not block here; this is usually safe, as
134                                  * STARTTLS is sent very early on in the registration phase, where the
135                                  * user hasn't built up much sendq. Handling a blocked write here would
136                                  * be very annoying.
137                                  */
138                                 user->eh.DoWrite();
139                                 user->eh.AddIOHook(creator);
140                                 creator->OnStreamSocketAccept(&user->eh, NULL, NULL);
141                         }
142                         else
143                                 user->WriteNumeric(691, "%s :STARTTLS failure", user->nick.c_str());
144                 }
145
146                 return CMD_FAILURE;
147         }
148 };
149
150 class ModuleSSLGnuTLS : public Module
151 {
        /** Per-fd session table, allocated in the constructor with GetMaxFds() entries and indexed by fd */
        issl_session* sessions;

        /** Certificate credentials shared by all sessions; only valid while cred_alloc is true */
        gnutls_certificate_credentials_t x509_cred;
        /** Diffie-Hellman parameters; only valid while dh_alloc is true */
        gnutls_dh_params_t dh_params;
        /** Digest algorithm selected by <gnutls:hash> ("md5" or "sha1") */
        gnutls_digest_algorithm_t hash;
        #ifdef GNUTLS_NEW_PRIO_API
        /** Priority cache built from <gnutls:priority> (falls back to "NORMAL") */
        gnutls_priority_t priority;
        #endif

        /** Value appended as SSL= to the 005 numeric; empty when nothing is advertised */
        std::string sslports;
        /** DH prime size from <gnutls:dhbits>; reset to 1024 unless 768/1024/2048/3072/4096 */
        int dh_bits;

        /** Whether x509_cred currently holds allocated credentials */
        bool cred_alloc;
        /** Whether dh_params currently holds initialised parameters */
        bool dh_alloc;

        /** Random number generator installed as ServerInstance->GenRandom while loaded */
        RandGen randhandler;
        /** The STARTTLS command handler */
        CommandStartTLS starttls;

        /** Capability handler registered under the name "tls" */
        GenericCap capHandler;
        /** The "ssl/gnutls" I/O hook service */
        ServiceProvider iohook;
172
173         inline static const char* UnknownIfNULL(const char* str)
174         {
175                 return str ? str : "UNKNOWN";
176         }
177
178         static ssize_t gnutls_pull_wrapper(gnutls_transport_ptr_t session_wrap, void* buffer, size_t size)
179         {
180                 issl_session* session = reinterpret_cast<issl_session*>(session_wrap);
181                 if (session->socket->GetEventMask() & FD_READ_WILL_BLOCK)
182                 {
183 #ifdef _WIN32
184                         gnutls_transport_set_errno(session->sess, EAGAIN);
185 #else
186                         errno = EAGAIN;
187 #endif
188                         return -1;
189                 }
190
191                 int rv = ServerInstance->SE->Recv(session->socket, reinterpret_cast<char *>(buffer), size, 0);
192
193 #ifdef _WIN32
194                 if (rv < 0)
195                 {
196                         /* Windows doesn't use errno, but gnutls does, so check SocketEngine::IgnoreError()
197                          * and then set errno appropriately.
198                          * The gnutls library may also have a different errno variable than us, see
199                          * gnutls_transport_set_errno(3).
200                          */
201                         gnutls_transport_set_errno(session->sess, SocketEngine::IgnoreError() ? EAGAIN : errno);
202                 }
203 #endif
204
205                 if (rv < (int)size)
206                         ServerInstance->SE->ChangeEventMask(session->socket, FD_READ_WILL_BLOCK);
207                 return rv;
208         }
209
210         static ssize_t gnutls_push_wrapper(gnutls_transport_ptr_t session_wrap, const void* buffer, size_t size)
211         {
212                 issl_session* session = reinterpret_cast<issl_session*>(session_wrap);
213                 if (session->socket->GetEventMask() & FD_WRITE_WILL_BLOCK)
214                 {
215 #ifdef _WIN32
216                         gnutls_transport_set_errno(session->sess, EAGAIN);
217 #else
218                         errno = EAGAIN;
219 #endif
220                         return -1;
221                 }
222
223                 int rv = ServerInstance->SE->Send(session->socket, reinterpret_cast<const char *>(buffer), size, 0);
224
225 #ifdef _WIN32
226                 if (rv < 0)
227                 {
228                         /* Windows doesn't use errno, but gnutls does, so check SocketEngine::IgnoreError()
229                          * and then set errno appropriately.
230                          * The gnutls library may also have a different errno variable than us, see
231                          * gnutls_transport_set_errno(3).
232                          */
233                         gnutls_transport_set_errno(session->sess, SocketEngine::IgnoreError() ? EAGAIN : errno);
234                 }
235 #endif
236
237                 if (rv < (int)size)
238                         ServerInstance->SE->ChangeEventMask(session->socket, FD_WRITE_WILL_BLOCK);
239                 return rv;
240         }
241
242  public:
243
244         ModuleSSLGnuTLS()
245                 : starttls(this), capHandler(this, "tls"), iohook(this, "ssl/gnutls", SERVICE_IOHOOK)
246         {
247                 gcry_control (GCRYCTL_INITIALIZATION_FINISHED, 0);
248
249                 sessions = new issl_session[ServerInstance->SE->GetMaxFds()];
250
251                 gnutls_global_init(); // This must be called once in the program
252                 gnutls_x509_privkey_init(&x509_key);
253
254                 #ifdef GNUTLS_NEW_PRIO_API
255                 // Init this here so it's always initialized, avoids an extra boolean
256                 gnutls_priority_init(&priority, "NORMAL", NULL);
257                 #endif
258
259                 cred_alloc = false;
260                 dh_alloc = false;
261         }
262
263         void init()
264         {
265                 // Needs the flag as it ignores a plain /rehash
266                 OnModuleRehash(NULL,"ssl");
267
268                 ServerInstance->GenRandom = &randhandler;
269
270                 // Void return, guess we assume success
271                 gnutls_certificate_set_dh_params(x509_cred, dh_params);
272                 Implementation eventlist[] = { I_On005Numeric, I_OnRehash, I_OnModuleRehash, I_OnUserConnect,
273                         I_OnEvent, I_OnHookIO };
274                 ServerInstance->Modules->Attach(eventlist, this, sizeof(eventlist)/sizeof(Implementation));
275
276                 ServerInstance->Modules->AddService(iohook);
277                 ServerInstance->Modules->AddService(starttls);
278         }
279
280         void OnRehash(User* user)
281         {
282                 sslports.clear();
283
284                 ConfigTag* Conf = ServerInstance->Config->ConfValue("gnutls");
285                 starttls.enabled = Conf->getBool("starttls", true);
286
287                 if (Conf->getBool("showports", true))
288                 {
289                         sslports = Conf->getString("advertisedports");
290                         if (!sslports.empty())
291                                 return;
292
293                         for (size_t i = 0; i < ServerInstance->ports.size(); i++)
294                         {
295                                 ListenSocket* port = ServerInstance->ports[i];
296                                 if (port->bind_tag->getString("ssl") != "gnutls")
297                                         continue;
298
299                                 const std::string& portid = port->bind_desc;
300                                 ServerInstance->Logs->Log("m_ssl_gnutls", DEFAULT, "m_ssl_gnutls.so: Enabling SSL for port %s", portid.c_str());
301
302                                 if (port->bind_tag->getString("type", "clients") == "clients" && port->bind_addr != "127.0.0.1")
303                                 {
304                                         /*
305                                          * Found an SSL port for clients that is not bound to 127.0.0.1 and handled by us, display
306                                          * the IP:port in ISUPPORT.
307                                          *
308                                          * We used to advertise all ports seperated by a ';' char that matched the above criteria,
309                                          * but this resulted in too long ISUPPORT lines if there were lots of ports to be displayed.
310                                          * To solve this by default we now only display the first IP:port found and let the user
311                                          * configure the exact value for the 005 token, if necessary.
312                                          */
313                                         sslports = portid;
314                                         break;
315                                 }
316                         }
317                 }
318         }
319
320         void OnModuleRehash(User* user, const std::string &param)
321         {
322                 if(param != "ssl")
323                         return;
324
325                 std::string keyfile;
326                 std::string certfile;
327                 std::string cafile;
328                 std::string crlfile;
329                 OnRehash(user);
330
331                 ConfigTag* Conf = ServerInstance->Config->ConfValue("gnutls");
332
333                 cafile = Conf->getString("cafile", CONFIG_PATH "/ca.pem");
334                 crlfile = Conf->getString("crlfile", CONFIG_PATH "/crl.pem");
335                 certfile = Conf->getString("certfile", CONFIG_PATH "/cert.pem");
336                 keyfile = Conf->getString("keyfile", CONFIG_PATH "/key.pem");
337                 dh_bits = Conf->getInt("dhbits");
338                 std::string hashname = Conf->getString("hash", "md5");
339
340                 // The GnuTLS manual states that the gnutls_set_default_priority()
341                 // call we used previously when initializing the session is the same
342                 // as setting the "NORMAL" priority string.
343                 // Thus if the setting below is not in the config we will behave exactly
344                 // the same as before, when the priority setting wasn't available.
345                 std::string priorities = Conf->getString("priority", "NORMAL");
346
347                 if((dh_bits != 768) && (dh_bits != 1024) && (dh_bits != 2048) && (dh_bits != 3072) && (dh_bits != 4096))
348                         dh_bits = 1024;
349
350                 if (hashname == "md5")
351                         hash = GNUTLS_DIG_MD5;
352                 else if (hashname == "sha1")
353                         hash = GNUTLS_DIG_SHA1;
354                 else
355                         throw ModuleException("Unknown hash type " + hashname);
356
357
358                 int ret;
359
360                 if (dh_alloc)
361                 {
362                         gnutls_dh_params_deinit(dh_params);
363                         dh_alloc = false;
364                         dh_params = NULL;
365                 }
366
367                 if (cred_alloc)
368                 {
369                         // Deallocate the old credentials
370                         gnutls_certificate_free_credentials(x509_cred);
371
372                         for(unsigned int i=0; i < x509_certs.size(); i++)
373                                 gnutls_x509_crt_deinit(x509_certs[i]);
374                         x509_certs.clear();
375                 }
376
377                 ret = gnutls_certificate_allocate_credentials(&x509_cred);
378                 cred_alloc = (ret >= 0);
379                 if (!cred_alloc)
380                         ServerInstance->Logs->Log("m_ssl_gnutls",DEBUG, "m_ssl_gnutls.so: Failed to allocate certificate credentials: %s", gnutls_strerror(ret));
381
382                 if((ret =gnutls_certificate_set_x509_trust_file(x509_cred, cafile.c_str(), GNUTLS_X509_FMT_PEM)) < 0)
383                         ServerInstance->Logs->Log("m_ssl_gnutls",DEBUG, "m_ssl_gnutls.so: Failed to set X.509 trust file '%s': %s", cafile.c_str(), gnutls_strerror(ret));
384
385                 if((ret = gnutls_certificate_set_x509_crl_file (x509_cred, crlfile.c_str(), GNUTLS_X509_FMT_PEM)) < 0)
386                         ServerInstance->Logs->Log("m_ssl_gnutls",DEBUG, "m_ssl_gnutls.so: Failed to set X.509 CRL file '%s': %s", crlfile.c_str(), gnutls_strerror(ret));
387
388                 FileReader reader;
389
390                 reader.LoadFile(certfile);
391                 std::string cert_string = reader.Contents();
392                 gnutls_datum_t cert_datum = { (unsigned char*)cert_string.data(), static_cast<unsigned int>(cert_string.length()) };
393
394                 reader.LoadFile(keyfile);
395                 std::string key_string = reader.Contents();
396                 gnutls_datum_t key_datum = { (unsigned char*)key_string.data(), static_cast<unsigned int>(key_string.length()) };
397
398                 // If this fails, no SSL port will work. At all. So, do the smart thing - throw a ModuleException
399                 unsigned int certcount = 3;
400                 x509_certs.resize(certcount);
401                 ret = gnutls_x509_crt_list_import(&x509_certs[0], &certcount, &cert_datum, GNUTLS_X509_FMT_PEM, GNUTLS_X509_CRT_LIST_IMPORT_FAIL_IF_EXCEED);
402                 if (ret == GNUTLS_E_SHORT_MEMORY_BUFFER)
403                 {
404                         // the buffer wasn't big enough to hold all certs but gnutls updated certcount to the number of available certs, try again with a bigger buffer
405                         x509_certs.resize(certcount);
406                         ret = gnutls_x509_crt_list_import(&x509_certs[0], &certcount, &cert_datum, GNUTLS_X509_FMT_PEM, GNUTLS_X509_CRT_LIST_IMPORT_FAIL_IF_EXCEED);
407                 }
408
409                 if (ret <= 0)
410                 {
411                         // clear the vector so we won't call gnutls_x509_crt_deinit() on the (uninited) certs later
412                         x509_certs.clear();
413                         throw ModuleException("Unable to load GnuTLS server certificate (" + certfile + "): " + ((ret < 0) ? (std::string(gnutls_strerror(ret))) : "No certs could be read"));
414                 }
415                 x509_certs.resize(ret);
416
417                 if((ret = gnutls_x509_privkey_import(x509_key, &key_datum, GNUTLS_X509_FMT_PEM)) < 0)
418                         throw ModuleException("Unable to load GnuTLS server private key (" + keyfile + "): " + std::string(gnutls_strerror(ret)));
419
420                 if((ret = gnutls_certificate_set_x509_key(x509_cred, &x509_certs[0], certcount, x509_key)) < 0)
421                         throw ModuleException("Unable to set GnuTLS cert/key pair: " + std::string(gnutls_strerror(ret)));
422
423                 #ifdef GNUTLS_NEW_PRIO_API
424                 // It's safe to call this every time as we cannot have this uninitialized, see constructor and below.
425                 gnutls_priority_deinit(priority);
426
427                 // Try to set the priorities for ciphers, kex methods etc. to the user supplied string
428                 // If the user did not supply anything then the string is already set to "NORMAL"
429                 const char* priocstr = priorities.c_str();
430                 const char* prioerror;
431
432                 if ((ret = gnutls_priority_init(&priority, priocstr, &prioerror)) < 0)
433                 {
434                         // gnutls did not understand the user supplied string, log and fall back to the default priorities
435                         ServerInstance->Logs->Log("m_ssl_gnutls",DEFAULT, "m_ssl_gnutls.so: Failed to set priorities to \"%s\": %s Syntax error at position %u, falling back to default (NORMAL)", priorities.c_str(), gnutls_strerror(ret), (unsigned int) (prioerror - priocstr));
436                         gnutls_priority_init(&priority, "NORMAL", NULL);
437                 }
438
439                 #else
440                 if (priorities != "NORMAL")
441                         ServerInstance->Logs->Log("m_ssl_gnutls",DEFAULT, "m_ssl_gnutls.so: You've set <gnutls:priority> to a value other than the default, but this is only supported with GnuTLS v2.1.7 or newer. Your GnuTLS version is older than that so the option will have no effect.");
442                 #endif
443
444                 #if(GNUTLS_VERSION_MAJOR < 2 || ( GNUTLS_VERSION_MAJOR == 2 && GNUTLS_VERSION_MINOR < 12 ) )
445                 gnutls_certificate_client_set_retrieve_function (x509_cred, cert_callback);
446                 #else
447                 gnutls_certificate_set_retrieve_function (x509_cred, cert_callback);
448                 #endif
449                 ret = gnutls_dh_params_init(&dh_params);
450                 dh_alloc = (ret >= 0);
451                 if (!dh_alloc)
452                 {
453                         ServerInstance->Logs->Log("m_ssl_gnutls",DEFAULT, "m_ssl_gnutls.so: Failed to initialise DH parameters: %s", gnutls_strerror(ret));
454                         return;
455                 }
456
457                 std::string dhfile = Conf->getString("dhfile");
458                 if (!dhfile.empty())
459                 {
460                         // Try to load DH params from file
461                         reader.LoadFile(dhfile);
462                         std::string dhstring = reader.Contents();
463                         gnutls_datum_t dh_datum = { (unsigned char*)dhstring.data(), static_cast<unsigned int>(dhstring.length()) };
464
465                         if ((ret = gnutls_dh_params_import_pkcs3(dh_params, &dh_datum, GNUTLS_X509_FMT_PEM)) < 0)
466                         {
467                                 // File unreadable or GnuTLS was unhappy with the contents, generate the DH primes now
468                                 ServerInstance->Logs->Log("m_ssl_gnutls", DEFAULT, "m_ssl_gnutls.so: Generating DH parameters because I failed to load them from file '%s': %s", dhfile.c_str(), gnutls_strerror(ret));
469                                 GenerateDHParams();
470                         }
471                 }
472                 else
473                 {
474                         GenerateDHParams();
475                 }
476         }
477
478         void GenerateDHParams()
479         {
480                 // Generate Diffie Hellman parameters - for use with DHE
481                 // kx algorithms. These should be discarded and regenerated
482                 // once a day, once a week or once a month. Depending on the
483                 // security requirements.
484
485                 if (!dh_alloc)
486                         return;
487
488                 int ret;
489
490                 if((ret = gnutls_dh_params_generate2(dh_params, dh_bits)) < 0)
491                         ServerInstance->Logs->Log("m_ssl_gnutls",DEFAULT, "m_ssl_gnutls.so: Failed to generate DH parameters (%d bits): %s", dh_bits, gnutls_strerror(ret));
492         }
493
494         ~ModuleSSLGnuTLS()
495         {
496                 for(unsigned int i=0; i < x509_certs.size(); i++)
497                         gnutls_x509_crt_deinit(x509_certs[i]);
498
499                 gnutls_x509_privkey_deinit(x509_key);
500                 #ifdef GNUTLS_NEW_PRIO_API
501                 gnutls_priority_deinit(priority);
502                 #endif
503
504                 if (dh_alloc)
505                         gnutls_dh_params_deinit(dh_params);
506                 if (cred_alloc)
507                         gnutls_certificate_free_credentials(x509_cred);
508
509                 gnutls_global_deinit();
510                 delete[] sessions;
511                 ServerInstance->GenRandom = &ServerInstance->HandleGenRandom;
512         }
513
514         void OnCleanup(int target_type, void* item)
515         {
516                 if(target_type == TYPE_USER)
517                 {
518                         LocalUser* user = IS_LOCAL(static_cast<User*>(item));
519
520                         if (user && user->eh.GetIOHook() == this)
521                         {
522                                 // User is using SSL, they're a local user, and they're using one of *our* SSL ports.
523                                 // Potentially there could be multiple SSL modules loaded at once on different ports.
524                                 ServerInstance->Users->QuitUser(user, "SSL module unloading");
525                         }
526                 }
527         }
528
529         Version GetVersion()
530         {
531                 return Version("Provides SSL support for clients", VF_VENDOR);
532         }
533
534
535         void On005Numeric(std::string &output)
536         {
537                 if (!sslports.empty())
538                         output.append(" SSL=" + sslports);
539                 if (starttls.enabled)
540                         output.append(" STARTTLS");
541         }
542
543         void OnHookIO(StreamSocket* user, ListenSocket* lsb)
544         {
545                 if (!user->GetIOHook() && lsb->bind_tag->getString("ssl") == "gnutls")
546                 {
547                         /* Hook the user with our module */
548                         user->AddIOHook(this);
549                 }
550         }
551
552         void OnRequest(Request& request)
553         {
554                 if (strcmp("GET_SSL_CERT", request.id) == 0)
555                 {
556                         SocketCertificateRequest& req = static_cast<SocketCertificateRequest&>(request);
557                         int fd = req.sock->GetFd();
558                         issl_session* session = &sessions[fd];
559
560                         req.cert = session->cert;
561                 }
562         }
563
564         void InitSession(StreamSocket* user, bool me_server)
565         {
566                 issl_session* session = &sessions[user->GetFd()];
567
568                 gnutls_init(&session->sess, me_server ? GNUTLS_SERVER : GNUTLS_CLIENT);
569                 session->socket = user;
570
571                 #ifdef GNUTLS_NEW_PRIO_API
572                 gnutls_priority_set(session->sess, priority);
573                 #endif
574                 gnutls_credentials_set(session->sess, GNUTLS_CRD_CERTIFICATE, x509_cred);
575                 gnutls_dh_set_prime_bits(session->sess, dh_bits);
576                 gnutls_transport_set_ptr(session->sess, reinterpret_cast<gnutls_transport_ptr_t>(session));
577                 gnutls_transport_set_push_function(session->sess, gnutls_push_wrapper);
578                 gnutls_transport_set_pull_function(session->sess, gnutls_pull_wrapper);
579
580                 if (me_server)
581                         gnutls_certificate_server_set_request(session->sess, GNUTLS_CERT_REQUEST); // Request client certificate if any.
582
583                 Handshake(session, user);
584         }
585
586         void OnStreamSocketAccept(StreamSocket* user, irc::sockets::sockaddrs* client, irc::sockets::sockaddrs* server)
587         {
588                 issl_session* session = &sessions[user->GetFd()];
589
590                 /* For STARTTLS: Don't try and init a session on a socket that already has a session */
591                 if (session->sess)
592                         return;
593
594                 InitSession(user, true);
595         }
596
597         void OnStreamSocketConnect(StreamSocket* user)
598         {
599                 InitSession(user, false);
600         }
601
602         void OnStreamSocketClose(StreamSocket* user)
603         {
604                 CloseSession(&sessions[user->GetFd()]);
605         }
606
607         int OnStreamSocketRead(StreamSocket* user, std::string& recvq)
608         {
609                 issl_session* session = &sessions[user->GetFd()];
610
611                 if (!session->sess)
612                 {
613                         CloseSession(session);
614                         user->SetError("No SSL session");
615                         return -1;
616                 }
617
618                 if (session->status == ISSL_HANDSHAKING_READ || session->status == ISSL_HANDSHAKING_WRITE)
619                 {
620                         // The handshake isn't finished, try to finish it.
621
622                         if(!Handshake(session, user))
623                         {
624                                 if (session->status != ISSL_CLOSING)
625                                         return 0;
626                                 return -1;
627                         }
628                 }
629
630                 // If we resumed the handshake then session->status will be ISSL_HANDSHAKEN.
631
632                 if (session->status == ISSL_HANDSHAKEN)
633                 {
634                         char* buffer = ServerInstance->GetReadBuffer();
635                         size_t bufsiz = ServerInstance->Config->NetBufferSize;
636                         int ret = gnutls_record_recv(session->sess, buffer, bufsiz);
637                         if (ret > 0)
638                         {
639                                 recvq.append(buffer, ret);
640                                 return 1;
641                         }
642                         else if (ret == GNUTLS_E_AGAIN || ret == GNUTLS_E_INTERRUPTED)
643                         {
644                                 return 0;
645                         }
646                         else if (ret == 0)
647                         {
648                                 user->SetError("Connection closed");
649                                 CloseSession(session);
650                                 return -1;
651                         }
652                         else
653                         {
654                                 user->SetError(gnutls_strerror(ret));
655                                 CloseSession(session);
656                                 return -1;
657                         }
658                 }
659                 else if (session->status == ISSL_CLOSING)
660                         return -1;
661
662                 return 0;
663         }
664
665         int OnStreamSocketWrite(StreamSocket* user, std::string& sendq)
666         {
667                 issl_session* session = &sessions[user->GetFd()];
668
669                 if (!session->sess)
670                 {
671                         CloseSession(session);
672                         user->SetError("No SSL session");
673                         return -1;
674                 }
675
676                 if (session->status == ISSL_HANDSHAKING_WRITE || session->status == ISSL_HANDSHAKING_READ)
677                 {
678                         // The handshake isn't finished, try to finish it.
679                         Handshake(session, user);
680                         if (session->status != ISSL_CLOSING)
681                                 return 0;
682                         return -1;
683                 }
684
685                 int ret = 0;
686
687                 if (session->status == ISSL_HANDSHAKEN)
688                 {
689                         ret = gnutls_record_send(session->sess, sendq.data(), sendq.length());
690
691                         if (ret == (int)sendq.length())
692                         {
693                                 ServerInstance->SE->ChangeEventMask(user, FD_WANT_NO_WRITE);
694                                 return 1;
695                         }
696                         else if (ret > 0)
697                         {
698                                 sendq = sendq.substr(ret);
699                                 ServerInstance->SE->ChangeEventMask(user, FD_WANT_SINGLE_WRITE);
700                                 return 0;
701                         }
702                         else if (ret == GNUTLS_E_AGAIN || ret == GNUTLS_E_INTERRUPTED || ret == 0)
703                         {
704                                 ServerInstance->SE->ChangeEventMask(user, FD_WANT_SINGLE_WRITE);
705                                 return 0;
706                         }
707                         else // (ret < 0)
708                         {
709                                 user->SetError(gnutls_strerror(ret));
710                                 CloseSession(session);
711                                 return -1;
712                         }
713                 }
714
715                 return 0;
716         }
717
718         bool Handshake(issl_session* session, StreamSocket* user)
719         {
720                 int ret = gnutls_handshake(session->sess);
721
722                 if (ret < 0)
723                 {
724                         if(ret == GNUTLS_E_AGAIN || ret == GNUTLS_E_INTERRUPTED)
725                         {
726                                 // Handshake needs resuming later, read() or write() would have blocked.
727
728                                 if(gnutls_record_get_direction(session->sess) == 0)
729                                 {
730                                         // gnutls_handshake() wants to read() again.
731                                         session->status = ISSL_HANDSHAKING_READ;
732                                         ServerInstance->SE->ChangeEventMask(user, FD_WANT_POLL_READ | FD_WANT_NO_WRITE);
733                                 }
734                                 else
735                                 {
736                                         // gnutls_handshake() wants to write() again.
737                                         session->status = ISSL_HANDSHAKING_WRITE;
738                                         ServerInstance->SE->ChangeEventMask(user, FD_WANT_NO_READ | FD_WANT_SINGLE_WRITE);
739                                 }
740                         }
741                         else
742                         {
743                                 user->SetError("Handshake Failed - " + std::string(gnutls_strerror(ret)));
744                                 CloseSession(session);
745                                 session->status = ISSL_CLOSING;
746                         }
747
748                         return false;
749                 }
750                 else
751                 {
752                         // Change the seesion state
753                         session->status = ISSL_HANDSHAKEN;
754
755                         VerifyCertificate(session,user);
756
757                         // Finish writing, if any left
758                         ServerInstance->SE->ChangeEventMask(user, FD_WANT_POLL_READ | FD_WANT_NO_WRITE | FD_ADD_TRIAL_WRITE);
759
760                         return true;
761                 }
762         }
763
764         void OnUserConnect(LocalUser* user)
765         {
766                 if (user->eh.GetIOHook() == this)
767                 {
768                         if (sessions[user->eh.GetFd()].sess)
769                         {
770                                 const gnutls_session_t& sess = sessions[user->eh.GetFd()].sess;
771                                 std::string cipher = UnknownIfNULL(gnutls_kx_get_name(gnutls_kx_get(sess)));
772                                 cipher.append("-").append(UnknownIfNULL(gnutls_cipher_get_name(gnutls_cipher_get(sess)))).append("-");
773                                 cipher.append(UnknownIfNULL(gnutls_mac_get_name(gnutls_mac_get(sess))));
774
775                                 ssl_cert* cert = sessions[user->eh.GetFd()].cert;
776                                 if (cert->fingerprint.empty())
777                                         user->WriteServ("NOTICE %s :*** You are connected using SSL cipher \"%s\"", user->nick.c_str(), cipher.c_str());
778                                 else
779                                         user->WriteServ("NOTICE %s :*** You are connected using SSL cipher \"%s\""
780                                                 " and your SSL fingerprint is %s", user->nick.c_str(), cipher.c_str(), cert->fingerprint.c_str());
781                         }
782                 }
783         }
784
785         void CloseSession(issl_session* session)
786         {
787                 if (session->sess)
788                 {
789                         gnutls_bye(session->sess, GNUTLS_SHUT_WR);
790                         gnutls_deinit(session->sess);
791                 }
792                 session->socket = NULL;
793                 session->sess = NULL;
794                 session->cert = NULL;
795                 session->status = ISSL_NONE;
796         }
797
798         void VerifyCertificate(issl_session* session, StreamSocket* user)
799         {
800                 if (!session->sess || !user)
801                         return;
802
803                 unsigned int status;
804                 const gnutls_datum_t* cert_list;
805                 int ret;
806                 unsigned int cert_list_size;
807                 gnutls_x509_crt_t cert;
808                 char name[MAXBUF];
809                 unsigned char digest[MAXBUF];
810                 size_t digest_size = sizeof(digest);
811                 size_t name_size = sizeof(name);
812                 ssl_cert* certinfo = new ssl_cert;
813                 session->cert = certinfo;
814
815                 /* This verification function uses the trusted CAs in the credentials
816                  * structure. So you must have installed one or more CA certificates.
817                  */
818                 ret = gnutls_certificate_verify_peers2(session->sess, &status);
819
820                 if (ret < 0)
821                 {
822                         certinfo->error = std::string(gnutls_strerror(ret));
823                         return;
824                 }
825
826                 certinfo->invalid = (status & GNUTLS_CERT_INVALID);
827                 certinfo->unknownsigner = (status & GNUTLS_CERT_SIGNER_NOT_FOUND);
828                 certinfo->revoked = (status & GNUTLS_CERT_REVOKED);
829                 certinfo->trusted = !(status & GNUTLS_CERT_SIGNER_NOT_CA);
830
831                 /* Up to here the process is the same for X.509 certificates and
832                  * OpenPGP keys. From now on X.509 certificates are assumed. This can
833                  * be easily extended to work with openpgp keys as well.
834                  */
835                 if (gnutls_certificate_type_get(session->sess) != GNUTLS_CRT_X509)
836                 {
837                         certinfo->error = "No X509 keys sent";
838                         return;
839                 }
840
841                 ret = gnutls_x509_crt_init(&cert);
842                 if (ret < 0)
843                 {
844                         certinfo->error = gnutls_strerror(ret);
845                         return;
846                 }
847
848                 cert_list_size = 0;
849                 cert_list = gnutls_certificate_get_peers(session->sess, &cert_list_size);
850                 if (cert_list == NULL)
851                 {
852                         certinfo->error = "No certificate was found";
853                         goto info_done_dealloc;
854                 }
855
856                 /* This is not a real world example, since we only check the first
857                  * certificate in the given chain.
858                  */
859
860                 ret = gnutls_x509_crt_import(cert, &cert_list[0], GNUTLS_X509_FMT_DER);
861                 if (ret < 0)
862                 {
863                         certinfo->error = gnutls_strerror(ret);
864                         goto info_done_dealloc;
865                 }
866
867                 gnutls_x509_crt_get_dn(cert, name, &name_size);
868                 certinfo->dn = name;
869
870                 gnutls_x509_crt_get_issuer_dn(cert, name, &name_size);
871                 certinfo->issuer = name;
872
873                 if ((ret = gnutls_x509_crt_get_fingerprint(cert, hash, digest, &digest_size)) < 0)
874                 {
875                         certinfo->error = gnutls_strerror(ret);
876                 }
877                 else
878                 {
879                         certinfo->fingerprint = irc::hex(digest, digest_size);
880                 }
881
882                 /* Beware here we do not check for errors.
883                  */
884                 if ((gnutls_x509_crt_get_expiration_time(cert) < ServerInstance->Time()) || (gnutls_x509_crt_get_activation_time(cert) > ServerInstance->Time()))
885                 {
886                         certinfo->error = "Not activated, or expired certificate";
887                 }
888
889 info_done_dealloc:
890                 gnutls_x509_crt_deinit(cert);
891         }
892
893         void OnEvent(Event& ev)
894         {
895                 if (starttls.enabled)
896                         capHandler.HandleEvent(ev);
897         }
898 };
899
// Module entry point. MODULE_INIT is a macro from the InspIRCd headers (not
// shown here); it presumably emits the factory that instantiates ModuleSSLGnuTLS.
MODULE_INIT(ModuleSSLGnuTLS)