]> git.netwichtig.de Git - user/henk/code/inspircd.git/blob - src/modules/extra/m_ssl_mbedtls.cpp
Automatically apply +P to all permanent channels.
[user/henk/code/inspircd.git] / src / modules / extra / m_ssl_mbedtls.cpp
1 /*
2  * InspIRCd -- Internet Relay Chat Daemon
3  *
4  *   Copyright (C) 2016 Attila Molnar <attilamolnar@hush.com>
5  *
6  * This file is part of InspIRCd.  InspIRCd is free software: you can
7  * redistribute it and/or modify it under the terms of the GNU General Public
8  * License as published by the Free Software Foundation, version 2.
9  *
10  * This program is distributed in the hope that it will be useful, but WITHOUT
11  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
12  * FOR A PARTICULAR PURPOSE.  See the GNU General Public License for more
13  * details.
14  *
15  * You should have received a copy of the GNU General Public License
16  * along with this program.  If not, see <http://www.gnu.org/licenses/>.
17  */
18
19 /// $LinkerFlags: -lmbedtls
20
21 /// $PackageInfo: require_system("darwin") mbedtls
22 /// $PackageInfo: require_system("ubuntu" "16.04") libmbedtls-dev
23
24
25 #include "inspircd.h"
26 #include "modules/ssl.h"
27
28 #include <mbedtls/ctr_drbg.h>
29 #include <mbedtls/dhm.h>
30 #include <mbedtls/ecp.h>
31 #include <mbedtls/entropy.h>
32 #include <mbedtls/error.h>
33 #include <mbedtls/md.h>
34 #include <mbedtls/pk.h>
35 #include <mbedtls/ssl.h>
36 #include <mbedtls/ssl_ciphersuites.h>
37 #include <mbedtls/version.h>
38 #include <mbedtls/x509.h>
39 #include <mbedtls/x509_crt.h>
40 #include <mbedtls/x509_crl.h>
41
42 #ifdef INSPIRCD_MBEDTLS_LIBRARY_DEBUG
43 #include <mbedtls/debug.h>
44 #endif
45
46 namespace mbedTLS
47 {
48         class Exception : public ModuleException
49         {
50          public:
51                 Exception(const std::string& reason)
52                         : ModuleException(reason) { }
53         };
54
55         std::string ErrorToString(int errcode)
56         {
57                 char buf[256];
58                 mbedtls_strerror(errcode, buf, sizeof(buf));
59                 return buf;
60         }
61
62         void ThrowOnError(int errcode, const char* msg)
63         {
64                 if (errcode != 0)
65                 {
66                         std::string reason = msg;
67                         reason.append(" :").append(ErrorToString(errcode));
68                         throw Exception(reason);
69                 }
70         }
71
72         template <typename T, void (*init)(T*), void (*deinit)(T*)>
73         class RAIIObj
74         {
75                 T obj;
76
77          public:
78                 RAIIObj()
79                 {
80                         init(&obj);
81                 }
82
83                 ~RAIIObj()
84                 {
85                         deinit(&obj);
86                 }
87
88                 T* get() { return &obj; }
89                 const T* get() const { return &obj; }
90         };
91
92         typedef RAIIObj<mbedtls_entropy_context, mbedtls_entropy_init, mbedtls_entropy_free> Entropy;
93
94         class CTRDRBG : private RAIIObj<mbedtls_ctr_drbg_context, mbedtls_ctr_drbg_init, mbedtls_ctr_drbg_free>
95         {
96          public:
97                 bool Seed(Entropy& entropy)
98                 {
99                         return (mbedtls_ctr_drbg_seed(get(), mbedtls_entropy_func, entropy.get(), NULL, 0) == 0);
100                 }
101
102                 void SetupConf(mbedtls_ssl_config* conf)
103                 {
104                         mbedtls_ssl_conf_rng(conf, mbedtls_ctr_drbg_random, get());
105                 }
106         };
107
108         class DHParams : public RAIIObj<mbedtls_dhm_context, mbedtls_dhm_init, mbedtls_dhm_free>
109         {
110          public:
111                 void set(const std::string& dhstr)
112                 {
113                         // Last parameter is buffer size, must include the terminating null
114                         int ret = mbedtls_dhm_parse_dhm(get(), reinterpret_cast<const unsigned char*>(dhstr.c_str()), dhstr.size()+1);
115                         ThrowOnError(ret, "Unable to import DH params");
116                 }
117         };
118
119         class X509Key : public RAIIObj<mbedtls_pk_context, mbedtls_pk_init, mbedtls_pk_free>
120         {
121          public:
122                 /** Import */
123                 X509Key(const std::string& keystr)
124                 {
125                         int ret = mbedtls_pk_parse_key(get(), reinterpret_cast<const unsigned char*>(keystr.c_str()), keystr.size()+1, NULL, 0);
126                         ThrowOnError(ret, "Unable to import private key");
127                 }
128         };
129
130         class Ciphersuites
131         {
132                 std::vector<int> list;
133
134          public:
135                 Ciphersuites(const std::string& str)
136                 {
137                         // mbedTLS uses the ciphersuite format "TLS-ECDHE-RSA-WITH-AES-128-GCM-SHA256" internally.
138                         // This is a bit verbose, so we make life a bit simpler for admins by not requiring them to supply the static parts.
139                         irc::sepstream ss(str, ':');
140                         for (std::string token; ss.GetToken(token); )
141                         {
142                                 // Prepend "TLS-" if not there
143                                 if (token.compare(0, 4, "TLS-", 4))
144                                         token.insert(0, "TLS-");
145
146                                 const int id = mbedtls_ssl_get_ciphersuite_id(token.c_str());
147                                 if (!id)
148                                         throw Exception("Unknown ciphersuite " + token);
149                                 list.push_back(id);
150                         }
151                         list.push_back(0);
152                 }
153
154                 const int* get() const { return &list.front(); }
155                 bool empty() const { return (list.size() <= 1); }
156         };
157
158         class Curves
159         {
160                 std::vector<mbedtls_ecp_group_id> list;
161
162          public:
163                 Curves(const std::string& str)
164                 {
165                         irc::sepstream ss(str, ':');
166                         for (std::string token; ss.GetToken(token); )
167                         {
168                                 const mbedtls_ecp_curve_info* curve = mbedtls_ecp_curve_info_from_name(token.c_str());
169                                 if (!curve)
170                                         throw Exception("Unknown curve " + token);
171                                 list.push_back(curve->grp_id);
172                         }
173                         list.push_back(MBEDTLS_ECP_DP_NONE);
174                 }
175
176                 const mbedtls_ecp_group_id* get() const { return &list.front(); }
177                 bool empty() const { return (list.size() <= 1); }
178         };
179
180         class X509CertList : public RAIIObj<mbedtls_x509_crt, mbedtls_x509_crt_init, mbedtls_x509_crt_free>
181         {
182          public:
183                 /** Import or create empty */
184                 X509CertList(const std::string& certstr, bool allowempty = false)
185                 {
186                         if ((allowempty) && (certstr.empty()))
187                                 return;
188                         int ret = mbedtls_x509_crt_parse(get(), reinterpret_cast<const unsigned char*>(certstr.c_str()), certstr.size()+1);
189                         ThrowOnError(ret, "Unable to load certificates");
190                 }
191
192                 bool empty() const { return (get()->raw.p != NULL); }
193         };
194
195         class X509CRL : public RAIIObj<mbedtls_x509_crl, mbedtls_x509_crl_init, mbedtls_x509_crl_free>
196         {
197          public:
198                 X509CRL(const std::string& crlstr)
199                 {
200                         if (crlstr.empty())
201                                 return;
202                         int ret = mbedtls_x509_crl_parse(get(), reinterpret_cast<const unsigned char*>(crlstr.c_str()), crlstr.size()+1);
203                         ThrowOnError(ret, "Unable to load CRL");
204                 }
205         };
206
207         class X509Credentials
208         {
209                 /** Private key
210                  */
211                 X509Key key;
212
213                 /** Certificate list, presented to the peer
214                  */
215                 X509CertList certs;
216
217          public:
218                 X509Credentials(const std::string& certstr, const std::string& keystr)
219                         : key(keystr)
220                         , certs(certstr)
221                 {
222                         // Verify that one of the certs match the private key
223                         bool found = false;
224                         for (mbedtls_x509_crt* cert = certs.get(); cert; cert = cert->next)
225                         {
226                                 if (mbedtls_pk_check_pair(&cert->pk, key.get()) == 0)
227                                 {
228                                         found = true;
229                                         break;
230                                 }
231                         }
232                         if (!found)
233                                 throw Exception("Public/private key pair does not match");
234                 }
235
236                 mbedtls_pk_context* getkey() { return key.get(); }
237                 mbedtls_x509_crt* getcerts() { return certs.get(); }
238         };
239
240         class Context
241         {
242                 mbedtls_ssl_config conf;
243
244 #ifdef INSPIRCD_MBEDTLS_LIBRARY_DEBUG
245                 static void DebugLogFunc(void* userptr, int level, const char* file, int line, const char* msg)
246                 {
247                         // Remove trailing \n
248                         size_t len = strlen(msg);
249                         if ((len > 0) && (msg[len-1] == '\n'))
250                                 len--;
251                         ServerInstance->Logs->Log(MODNAME, LOG_DEBUG, "%s:%d %.*s", file, line, len, msg);
252                 }
253 #endif
254
255          public:
256                 Context(CTRDRBG& ctrdrbg, unsigned int endpoint)
257                 {
258                         mbedtls_ssl_config_init(&conf);
259 #ifdef INSPIRCD_MBEDTLS_LIBRARY_DEBUG
260                         mbedtls_debug_set_threshold(INT_MAX);
261                         mbedtls_ssl_conf_dbg(&conf, DebugLogFunc, NULL);
262 #endif
263
264                         // TODO: check ret of mbedtls_ssl_config_defaults
265                         mbedtls_ssl_config_defaults(&conf, endpoint, MBEDTLS_SSL_TRANSPORT_STREAM, MBEDTLS_SSL_PRESET_DEFAULT);
266                         ctrdrbg.SetupConf(&conf);
267                 }
268
269                 ~Context()
270                 {
271                         mbedtls_ssl_config_free(&conf);
272                 }
273
274                 void SetMinDHBits(unsigned int mindh)
275                 {
276                         mbedtls_ssl_conf_dhm_min_bitlen(&conf, mindh);
277                 }
278
279                 void SetDHParams(DHParams& dh)
280                 {
281                         mbedtls_ssl_conf_dh_param_ctx(&conf, dh.get());
282                 }
283
284                 void SetX509CertAndKey(X509Credentials& x509cred)
285                 {
286                         mbedtls_ssl_conf_own_cert(&conf, x509cred.getcerts(), x509cred.getkey());
287                 }
288
289                 void SetCiphersuites(const Ciphersuites& ciphersuites)
290                 {
291                         mbedtls_ssl_conf_ciphersuites(&conf, ciphersuites.get());
292                 }
293
294                 void SetCurves(const Curves& curves)
295                 {
296                         mbedtls_ssl_conf_curves(&conf, curves.get());
297                 }
298
299                 void SetVersion(int minver, int maxver)
300                 {
301                         // SSL v3 support cannot be enabled
302                         if (minver)
303                                 mbedtls_ssl_conf_min_version(&conf, MBEDTLS_SSL_MAJOR_VERSION_3, minver);
304                         if (maxver)
305                                 mbedtls_ssl_conf_max_version(&conf, MBEDTLS_SSL_MAJOR_VERSION_3, maxver);
306                 }
307
308                 void SetCA(X509CertList& certs, X509CRL& crl)
309                 {
310                         mbedtls_ssl_conf_ca_chain(&conf, certs.get(), crl.get());
311                 }
312
313                 void SetOptionalVerifyCert()
314                 {
315                         mbedtls_ssl_conf_authmode(&conf, MBEDTLS_SSL_VERIFY_OPTIONAL);
316                 }
317
318                 const mbedtls_ssl_config* GetConf() const { return &conf; }
319         };
320
321         class Hash
322         {
323                 const mbedtls_md_info_t* md;
324
325                 /** Buffer where cert hashes are written temporarily
326                  */
327                 mutable std::vector<unsigned char> buf;
328
329          public:
330                 Hash(std::string hashstr)
331                 {
332                         std::transform(hashstr.begin(), hashstr.end(), hashstr.begin(), ::toupper);
333                         md = mbedtls_md_info_from_string(hashstr.c_str());
334                         if (!md)
335                                 throw Exception("Unknown hash: " + hashstr);
336
337                         buf.resize(mbedtls_md_get_size(md));
338                 }
339
340                 std::string hash(const unsigned char* input, size_t length) const
341                 {
342                         mbedtls_md(md, input, length, &buf.front());
343                         return BinToHex(&buf.front(), buf.size());
344                 }
345         };
346
347         class Profile : public refcountbase
348         {
349                 /** Name of this profile
350                  */
351                 const std::string name;
352
353                 X509Credentials x509cred;
354
355                 /** Ciphersuites to use
356                  */
357                 Ciphersuites ciphersuites;
358
359                 /** Curves accepted for use in ECDHE and in the peer's end-entity certificate
360                  */
361                 Curves curves;
362
363                 Context serverctx;
364                 Context clientctx;
365
366                 DHParams dhparams;
367
368                 X509CertList cacerts;
369
370                 X509CRL crl;
371
372                 /** Hashing algorithm to use when generating certificate fingerprints
373                  */
374                 Hash hash;
375
376                 /** Rough max size of records to send
377                  */
378                 const unsigned int outrecsize;
379
380                 Profile(const std::string& profilename, const std::string& certstr, const std::string& keystr,
381                                 const std::string& dhstr, unsigned int mindh, const std::string& hashstr,
382                                 const std::string& ciphersuitestr, const std::string& curvestr,
383                                 const std::string& castr, const std::string& crlstr,
384                                 unsigned int recsize,
385                                 CTRDRBG& ctrdrbg,
386                                 int minver, int maxver,
387                                 bool requestclientcert
388                                 )
389                         : name(profilename)
390                         , x509cred(certstr, keystr)
391                         , ciphersuites(ciphersuitestr)
392                         , curves(curvestr)
393                         , serverctx(ctrdrbg, MBEDTLS_SSL_IS_SERVER)
394                         , clientctx(ctrdrbg, MBEDTLS_SSL_IS_CLIENT)
395                         , cacerts(castr, true)
396                         , crl(crlstr)
397                         , hash(hashstr)
398                         , outrecsize(recsize)
399                 {
400                         serverctx.SetX509CertAndKey(x509cred);
401                         clientctx.SetX509CertAndKey(x509cred);
402                         clientctx.SetMinDHBits(mindh);
403
404                         if (!ciphersuites.empty())
405                         {
406                                 serverctx.SetCiphersuites(ciphersuites);
407                                 clientctx.SetCiphersuites(ciphersuites);
408                         }
409
410                         if (!curves.empty())
411                         {
412                                 serverctx.SetCurves(curves);
413                                 clientctx.SetCurves(curves);
414                         }
415
416                         serverctx.SetVersion(minver, maxver);
417                         clientctx.SetVersion(minver, maxver);
418
419                         if (!dhstr.empty())
420                         {
421                                 dhparams.set(dhstr);
422                                 serverctx.SetDHParams(dhparams);
423                         }
424
425                         clientctx.SetOptionalVerifyCert();
426                         clientctx.SetCA(cacerts, crl);
427                         // The default for servers is to not request a client certificate from the peer
428                         if (requestclientcert)
429                         {
430                                 serverctx.SetOptionalVerifyCert();
431                                 serverctx.SetCA(cacerts, crl);
432                         }
433                 }
434
435                 static std::string ReadFile(const std::string& filename)
436                 {
437                         FileReader reader(filename);
438                         std::string ret = reader.GetString();
439                         if (ret.empty())
440                                 throw Exception("Cannot read file " + filename);
441                         return ret;
442                 }
443
444          public:
445                 static reference<Profile> Create(const std::string& profilename, ConfigTag* tag, CTRDRBG& ctr_drbg)
446                 {
447                         const std::string certstr = ReadFile(tag->getString("certfile", "cert.pem"));
448                         const std::string keystr = ReadFile(tag->getString("keyfile", "key.pem"));
449                         const std::string dhstr = ReadFile(tag->getString("dhfile", "dhparams.pem"));
450
451                         const std::string ciphersuitestr = tag->getString("ciphersuites");
452                         const std::string curvestr = tag->getString("curves");
453                         unsigned int mindh = tag->getInt("mindhbits", 2048);
454                         std::string hashstr = tag->getString("hash", "sha256");
455
456                         std::string crlstr;
457                         std::string castr = tag->getString("cafile");
458                         if (!castr.empty())
459                         {
460                                 castr = ReadFile(castr);
461                                 crlstr = tag->getString("crlfile");
462                                 if (!crlstr.empty())
463                                         crlstr = ReadFile(crlstr);
464                         }
465
466                         int minver = tag->getInt("minver");
467                         int maxver = tag->getInt("maxver");
468                         unsigned int outrecsize = tag->getInt("outrecsize", 2048, 512, 16384);
469                         const bool requestclientcert = tag->getBool("requestclientcert", true);
470                         return new Profile(profilename, certstr, keystr, dhstr, mindh, hashstr, ciphersuitestr, curvestr, castr, crlstr, outrecsize, ctr_drbg, minver, maxver, requestclientcert);
471                 }
472
473                 /** Set up the given session with the settings in this profile
474                  */
475                 void SetupClientSession(mbedtls_ssl_context* sess)
476                 {
477                         mbedtls_ssl_setup(sess, clientctx.GetConf());
478                 }
479
480                 void SetupServerSession(mbedtls_ssl_context* sess)
481                 {
482                         mbedtls_ssl_setup(sess, serverctx.GetConf());
483                 }
484
485                 const std::string& GetName() const { return name; }
486                 X509Credentials& GetX509Credentials() { return x509cred; }
487                 unsigned int GetOutgoingRecordSize() const { return outrecsize; }
488                 const Hash& GetHash() const { return hash; }
489         };
490 }
491
492 class mbedTLSIOHook : public SSLIOHook
493 {
494         enum Status
495         {
496                 ISSL_NONE,
497                 ISSL_HANDSHAKING,
498                 ISSL_HANDSHAKEN
499         };
500
501         mbedtls_ssl_context sess;
502         Status status;
503         reference<mbedTLS::Profile> profile;
504
505         void CloseSession()
506         {
507                 if (status == ISSL_NONE)
508                         return;
509
510                 mbedtls_ssl_close_notify(&sess);
511                 mbedtls_ssl_free(&sess);
512                 certificate = NULL;
513                 status = ISSL_NONE;
514         }
515
516         // Returns 1 if handshake succeeded, 0 if it is still in progress, -1 if it failed
517         int Handshake(StreamSocket* sock)
518         {
519                 int ret = mbedtls_ssl_handshake(&sess);
520                 if (ret == 0)
521                 {
522                         // Change the seesion state
523                         this->status = ISSL_HANDSHAKEN;
524
525                         VerifyCertificate();
526
527                         // Finish writing, if any left
528                         SocketEngine::ChangeEventMask(sock, FD_WANT_POLL_READ | FD_WANT_NO_WRITE | FD_ADD_TRIAL_WRITE);
529
530                         return 1;
531                 }
532
533                 this->status = ISSL_HANDSHAKING;
534                 if (ret == MBEDTLS_ERR_SSL_WANT_READ)
535                 {
536                         SocketEngine::ChangeEventMask(sock, FD_WANT_POLL_READ | FD_WANT_NO_WRITE);
537                         return 0;
538                 }
539                 else if (ret == MBEDTLS_ERR_SSL_WANT_WRITE)
540                 {
541                         SocketEngine::ChangeEventMask(sock, FD_WANT_NO_READ | FD_WANT_SINGLE_WRITE);
542                         return 0;
543                 }
544
545                 sock->SetError("Handshake Failed - " + mbedTLS::ErrorToString(ret));
546                 CloseSession();
547                 return -1;
548         }
549
550         // Returns 1 if application I/O should proceed, 0 if it must wait for the underlying protocol to progress, -1 on fatal error
551         int PrepareIO(StreamSocket* sock)
552         {
553                 if (status == ISSL_HANDSHAKEN)
554                         return 1;
555                 else if (status == ISSL_HANDSHAKING)
556                 {
557                         // The handshake isn't finished, try to finish it
558                         return Handshake(sock);
559                 }
560
561                 CloseSession();
562                 sock->SetError("No SSL session");
563                 return -1;
564         }
565
566         void VerifyCertificate()
567         {
568                 this->certificate = new ssl_cert;
569                 const mbedtls_x509_crt* const cert = mbedtls_ssl_get_peer_cert(&sess);
570                 if (!cert)
571                 {
572                         certificate->error = "No client certificate sent";
573                         return;
574                 }
575
576                 // If there is a certificate we can always generate a fingerprint
577                 certificate->fingerprint = profile->GetHash().hash(cert->raw.p, cert->raw.len);
578
579                 // At this point mbedTLS verified the cert already, we just need to check the results
580                 const uint32_t flags = mbedtls_ssl_get_verify_result(&sess);
581                 if (flags == 0xFFFFFFFF)
582                 {
583                         certificate->error = "Internal error during verification";
584                         return;
585                 }
586
587                 if (flags == 0)
588                 {
589                         // Verification succeeded
590                         certificate->trusted = true;
591                 }
592                 else
593                 {
594                         // Verification failed
595                         certificate->trusted = false;
596                         if ((flags & MBEDTLS_X509_BADCERT_EXPIRED) || (flags & MBEDTLS_X509_BADCERT_FUTURE))
597                                 certificate->error = "Not activated, or expired certificate";
598                 }
599
600                 certificate->unknownsigner = (flags & MBEDTLS_X509_BADCERT_NOT_TRUSTED);
601                 certificate->revoked = (flags & MBEDTLS_X509_BADCERT_REVOKED);
602                 certificate->invalid = ((flags & MBEDTLS_X509_BADCERT_BAD_KEY) || (flags & MBEDTLS_X509_BADCERT_BAD_MD) || (flags & MBEDTLS_X509_BADCERT_BAD_PK));
603
604                 GetDNString(&cert->subject, certificate->dn);
605                 GetDNString(&cert->issuer, certificate->issuer);
606         }
607
608         static void GetDNString(const mbedtls_x509_name* x509name, std::string& out)
609         {
610                 char buf[512];
611                 const int ret = mbedtls_x509_dn_gets(buf, sizeof(buf), x509name);
612                 if (ret <= 0)
613                         return;
614
615                 out.assign(buf, ret);
616         }
617
618         static int Pull(void* userptr, unsigned char* buffer, size_t size)
619         {
620                 StreamSocket* const sock = reinterpret_cast<StreamSocket*>(userptr);
621                 if (sock->GetEventMask() & FD_READ_WILL_BLOCK)
622                         return MBEDTLS_ERR_SSL_WANT_READ;
623
624                 const int ret = SocketEngine::Recv(sock, reinterpret_cast<char*>(buffer), size, 0);
625                 if (ret < (int)size)
626                 {
627                         SocketEngine::ChangeEventMask(sock, FD_READ_WILL_BLOCK);
628                         if ((ret == -1) && (SocketEngine::IgnoreError()))
629                                 return MBEDTLS_ERR_SSL_WANT_READ;
630                 }
631                 return ret;
632         }
633
634         static int Push(void* userptr, const unsigned char* buffer, size_t size)
635         {
636                 StreamSocket* const sock = reinterpret_cast<StreamSocket*>(userptr);
637                 if (sock->GetEventMask() & FD_WRITE_WILL_BLOCK)
638                         return MBEDTLS_ERR_SSL_WANT_WRITE;
639
640                 const int ret = SocketEngine::Send(sock, buffer, size, 0);
641                 if (ret < (int)size)
642                 {
643                         SocketEngine::ChangeEventMask(sock, FD_WRITE_WILL_BLOCK);
644                         if ((ret == -1) && (SocketEngine::IgnoreError()))
645                                 return MBEDTLS_ERR_SSL_WANT_WRITE;
646                 }
647                 return ret;
648         }
649
650  public:
651         mbedTLSIOHook(IOHookProvider* hookprov, StreamSocket* sock, bool isserver, mbedTLS::Profile* sslprofile)
652                 : SSLIOHook(hookprov)
653                 , status(ISSL_NONE)
654                 , profile(sslprofile)
655         {
656                 mbedtls_ssl_init(&sess);
657                 if (isserver)
658                         profile->SetupServerSession(&sess);
659                 else
660                         profile->SetupClientSession(&sess);
661
662                 mbedtls_ssl_set_bio(&sess, reinterpret_cast<void*>(sock), Push, Pull, NULL);
663
664                 sock->AddIOHook(this);
665                 Handshake(sock);
666         }
667
668         void OnStreamSocketClose(StreamSocket* sock) CXX11_OVERRIDE
669         {
670                 CloseSession();
671         }
672
673         int OnStreamSocketRead(StreamSocket* sock, std::string& recvq) CXX11_OVERRIDE
674         {
675                 // Finish handshake if needed
676                 int prepret = PrepareIO(sock);
677                 if (prepret <= 0)
678                         return prepret;
679
680                 // If we resumed the handshake then this->status will be ISSL_HANDSHAKEN.
681                 char* const readbuf = ServerInstance->GetReadBuffer();
682                 const size_t readbufsize = ServerInstance->Config->NetBufferSize;
683                 int ret = mbedtls_ssl_read(&sess, reinterpret_cast<unsigned char*>(readbuf), readbufsize);
684                 if (ret > 0)
685                 {
686                         recvq.append(readbuf, ret);
687
688                         // Schedule a read if there is still data in the mbedTLS buffer
689                         if (mbedtls_ssl_get_bytes_avail(&sess) > 0)
690                                 SocketEngine::ChangeEventMask(sock, FD_ADD_TRIAL_READ);
691                         return 1;
692                 }
693                 else if (ret == MBEDTLS_ERR_SSL_WANT_READ)
694                 {
695                         SocketEngine::ChangeEventMask(sock, FD_WANT_POLL_READ);
696                         return 0;
697                 }
698                 else if (ret == MBEDTLS_ERR_SSL_WANT_WRITE)
699                 {
700                         SocketEngine::ChangeEventMask(sock, FD_WANT_NO_READ | FD_WANT_SINGLE_WRITE);
701                         return 0;
702                 }
703                 else if (ret == 0)
704                 {
705                         sock->SetError("Connection closed");
706                         CloseSession();
707                         return -1;
708                 }
709                 else // error or MBEDTLS_ERR_SSL_CLIENT_RECONNECT which we treat as an error
710                 {
711                         sock->SetError(mbedTLS::ErrorToString(ret));
712                         CloseSession();
713                         return -1;
714                 }
715         }
716
717         int OnStreamSocketWrite(StreamSocket* sock, StreamSocket::SendQueue& sendq) CXX11_OVERRIDE
718         {
719                 // Finish handshake if needed
720                 int prepret = PrepareIO(sock);
721                 if (prepret <= 0)
722                         return prepret;
723
724                 // Session is ready for transferring application data
725                 while (!sendq.empty())
726                 {
727                         FlattenSendQueue(sendq, profile->GetOutgoingRecordSize());
728                         const StreamSocket::SendQueue::Element& buffer = sendq.front();
729                         int ret = mbedtls_ssl_write(&sess, reinterpret_cast<const unsigned char*>(buffer.data()), buffer.length());
730                         if (ret == (int)buffer.length())
731                         {
732                                 // Wrote entire record, continue sending
733                                 sendq.pop_front();
734                         }
735                         else if (ret > 0)
736                         {
737                                 sendq.erase_front(ret);
738                                 SocketEngine::ChangeEventMask(sock, FD_WANT_SINGLE_WRITE);
739                                 return 0;
740                         }
741                         else if (ret == 0)
742                         {
743                                 sock->SetError("Connection closed");
744                                 CloseSession();
745                                 return -1;
746                         }
747                         else if (ret == MBEDTLS_ERR_SSL_WANT_WRITE)
748                         {
749                                 SocketEngine::ChangeEventMask(sock, FD_WANT_SINGLE_WRITE);
750                                 return 0;
751                         }
752                         else if (ret == MBEDTLS_ERR_SSL_WANT_READ)
753                         {
754                                 SocketEngine::ChangeEventMask(sock, FD_WANT_POLL_READ);
755                                 return 0;
756                         }
757                         else
758                         {
759                                 sock->SetError(mbedTLS::ErrorToString(ret));
760                                 CloseSession();
761                                 return -1;
762                         }
763                 }
764
765                 SocketEngine::ChangeEventMask(sock, FD_WANT_NO_WRITE);
766                 return 1;
767         }
768
769         void GetCiphersuite(std::string& out) const CXX11_OVERRIDE
770         {
771                 if (!IsHandshakeDone())
772                         return;
773                 out.append(mbedtls_ssl_get_version(&sess)).push_back('-');
774
775                 // All mbedTLS ciphersuite names currently begin with "TLS-" which provides no useful information so skip it, but be prepared if it changes
776                 const char* const ciphersuitestr = mbedtls_ssl_get_ciphersuite(&sess);
777                 const char prefix[] = "TLS-";
778                 unsigned int skip = sizeof(prefix)-1;
779                 if (strncmp(ciphersuitestr, prefix, sizeof(prefix)-1))
780                         skip = 0;
781                 out.append(ciphersuitestr + skip);
782         }
783
784         bool IsHandshakeDone() const { return (status == ISSL_HANDSHAKEN); }
785 };
786
787 class mbedTLSIOHookProvider : public refcountbase, public IOHookProvider
788 {
789         reference<mbedTLS::Profile> profile;
790
791  public:
792         mbedTLSIOHookProvider(Module* mod, mbedTLS::Profile* prof)
793                 : IOHookProvider(mod, "ssl/" + prof->GetName(), IOHookProvider::IOH_SSL)
794                 , profile(prof)
795         {
796                 ServerInstance->Modules->AddService(*this);
797         }
798
799         ~mbedTLSIOHookProvider()
800         {
801                 ServerInstance->Modules->DelService(*this);
802         }
803
804         void OnAccept(StreamSocket* sock, irc::sockets::sockaddrs* client, irc::sockets::sockaddrs* server) CXX11_OVERRIDE
805         {
806                 new mbedTLSIOHook(this, sock, true, profile);
807         }
808
809         void OnConnect(StreamSocket* sock) CXX11_OVERRIDE
810         {
811                 new mbedTLSIOHook(this, sock, false, profile);
812         }
813 };
814
815 class ModuleSSLmbedTLS : public Module
816 {
817         typedef std::vector<reference<mbedTLSIOHookProvider> > ProfileList;
818
819         mbedTLS::Entropy entropy;
820         mbedTLS::CTRDRBG ctr_drbg;
821         ProfileList profiles;
822
823         void ReadProfiles()
824         {
825                 // First, store all profiles in a new, temporary container. If no problems occur, swap the two
826                 // containers; this way if something goes wrong we can go back and continue using the current profiles,
827                 // avoiding unpleasant situations where no new SSL connections are possible.
828                 ProfileList newprofiles;
829
830                 ConfigTagList tags = ServerInstance->Config->ConfTags("sslprofile");
831                 if (tags.first == tags.second)
832                 {
833                         // No <sslprofile> tags found, create a profile named "mbedtls" from settings in the <mbedtls> block
834                         const std::string defname = "mbedtls";
835                         ConfigTag* tag = ServerInstance->Config->ConfValue(defname);
836                         ServerInstance->Logs->Log(MODNAME, LOG_DEFAULT, "No <sslprofile> tags found; using settings from the <mbedtls> tag");
837
838                         try
839                         {
840                                 reference<mbedTLS::Profile> profile(mbedTLS::Profile::Create(defname, tag, ctr_drbg));
841                                 newprofiles.push_back(new mbedTLSIOHookProvider(this, profile));
842                         }
843                         catch (CoreException& ex)
844                         {
845                                 throw ModuleException("Error while initializing the default SSL profile - " + ex.GetReason());
846                         }
847                 }
848
849                 for (ConfigIter i = tags.first; i != tags.second; ++i)
850                 {
851                         ConfigTag* tag = i->second;
852                         if (tag->getString("provider") != "mbedtls")
853                                 continue;
854
855                         std::string name = tag->getString("name");
856                         if (name.empty())
857                         {
858                                 ServerInstance->Logs->Log(MODNAME, LOG_DEFAULT, "Ignoring <sslprofile> tag without name at " + tag->getTagLocation());
859                                 continue;
860                         }
861
862                         reference<mbedTLS::Profile> profile;
863                         try
864                         {
865                                 profile = mbedTLS::Profile::Create(name, tag, ctr_drbg);
866                         }
867                         catch (CoreException& ex)
868                         {
869                                 throw ModuleException("Error while initializing SSL profile \"" + name + "\" at " + tag->getTagLocation() + " - " + ex.GetReason());
870                         }
871
872                         newprofiles.push_back(new mbedTLSIOHookProvider(this, profile));
873                 }
874
875                 // New profiles are ok, begin using them
876                 // Old profiles are deleted when their refcount drops to zero
877                 profiles.swap(newprofiles);
878         }
879
880  public:
881         void init() CXX11_OVERRIDE
882         {
883                 char verbuf[16]; // Should be at least 9 bytes in size
884                 mbedtls_version_get_string(verbuf);
885                 ServerInstance->Logs->Log(MODNAME, LOG_DEFAULT, "mbedTLS lib version %s module was compiled for " MBEDTLS_VERSION_STRING, verbuf);
886
887                 if (!ctr_drbg.Seed(entropy))
888                         throw ModuleException("CTR DRBG seed failed");
889                 ReadProfiles();
890         }
891
892         void OnModuleRehash(User* user, const std::string &param) CXX11_OVERRIDE
893         {
894                 if (param != "ssl")
895                         return;
896
897                 try
898                 {
899                         ReadProfiles();
900                 }
901                 catch (ModuleException& ex)
902                 {
903                         ServerInstance->Logs->Log(MODNAME, LOG_DEFAULT, ex.GetReason() + " Not applying settings.");
904                 }
905         }
906
907         void OnCleanup(int target_type, void* item) CXX11_OVERRIDE
908         {
909                 if (target_type != TYPE_USER)
910                         return;
911
912                 LocalUser* user = IS_LOCAL(static_cast<User*>(item));
913                 if ((user) && (user->eh.GetModHook(this)))
914                 {
915                         // User is using SSL, they're a local user, and they're using our IOHook.
916                         // Potentially there could be multiple SSL modules loaded at once on different ports.
917                         ServerInstance->Users.QuitUser(user, "SSL module unloading");
918                 }
919         }
920
921         ModResult OnCheckReady(LocalUser* user) CXX11_OVERRIDE
922         {
923                 const mbedTLSIOHook* const iohook = static_cast<mbedTLSIOHook*>(user->eh.GetModHook(this));
924                 if ((iohook) && (!iohook->IsHandshakeDone()))
925                         return MOD_RES_DENY;
926                 return MOD_RES_PASSTHRU;
927         }
928
929         Version GetVersion() CXX11_OVERRIDE
930         {
931                 return Version("Provides SSL support via mbedTLS (PolarSSL)", VF_VENDOR);
932         }
933 };
934
// MODULE_INIT is defined in the core headers, not in this file. It presumably exports the entry point
// the module loader uses to instantiate ModuleSSLmbedTLS; confirm against inspircd.h.
MODULE_INIT(ModuleSSLmbedTLS)