]> git.netwichtig.de Git - user/henk/code/inspircd.git/blob - src/modules/extra/m_ssl_mbedtls.cpp
84c507cf8978ef65af8fe52329cc95bcf3ae7b08
[user/henk/code/inspircd.git] / src / modules / extra / m_ssl_mbedtls.cpp
1 /*
2  * InspIRCd -- Internet Relay Chat Daemon
3  *
4  *   Copyright (C) 2019 Matt Schatz <genius3000@g3k.solutions>
5  *   Copyright (C) 2016-2019 Sadie Powell <sadie@witchery.services>
6  *   Copyright (C) 2016-2017 Attila Molnar <attilamolnar@hush.com>
7  *
8  * This file is part of InspIRCd.  InspIRCd is free software: you can
9  * redistribute it and/or modify it under the terms of the GNU General Public
10  * License as published by the Free Software Foundation, version 2.
11  *
12  * This program is distributed in the hope that it will be useful, but WITHOUT
13  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
14  * FOR A PARTICULAR PURPOSE.  See the GNU General Public License for more
15  * details.
16  *
17  * You should have received a copy of the GNU General Public License
18  * along with this program.  If not, see <http://www.gnu.org/licenses/>.
19  */
20
21 /// $LinkerFlags: -lmbedtls
22
23 /// $PackageInfo: require_system("arch") mbedtls
24 /// $PackageInfo: require_system("darwin") mbedtls
25 /// $PackageInfo: require_system("debian" "9.0") libmbedtls-dev
26 /// $PackageInfo: require_system("ubuntu" "16.04") libmbedtls-dev
27
28
29 #include "inspircd.h"
30 #include "modules/ssl.h"
31
32 #include <mbedtls/ctr_drbg.h>
33 #include <mbedtls/dhm.h>
34 #include <mbedtls/ecp.h>
35 #include <mbedtls/entropy.h>
36 #include <mbedtls/error.h>
37 #include <mbedtls/md.h>
38 #include <mbedtls/pk.h>
39 #include <mbedtls/ssl.h>
40 #include <mbedtls/ssl_ciphersuites.h>
41 #include <mbedtls/version.h>
42 #include <mbedtls/x509.h>
43 #include <mbedtls/x509_crt.h>
44 #include <mbedtls/x509_crl.h>
45
46 #ifdef INSPIRCD_MBEDTLS_LIBRARY_DEBUG
47 #include <mbedtls/debug.h>
48 #endif
49
50 namespace mbedTLS
51 {
52         class Exception : public ModuleException
53         {
54          public:
55                 Exception(const std::string& reason)
56                         : ModuleException(reason) { }
57         };
58
59         std::string ErrorToString(int errcode)
60         {
61                 char buf[256];
62                 mbedtls_strerror(errcode, buf, sizeof(buf));
63                 return buf;
64         }
65
66         void ThrowOnError(int errcode, const char* msg)
67         {
68                 if (errcode != 0)
69                 {
70                         std::string reason = msg;
71                         reason.append(" :").append(ErrorToString(errcode));
72                         throw Exception(reason);
73                 }
74         }
75
/** Owns a plain mbedTLS context struct: calls the library's init function when
 * constructed and its free function when destroyed.
 *
 * The object is deliberately non-copyable. A copy would share the same
 * underlying mbedTLS state, and both destructors would call deinit on it,
 * freeing it twice.
 */
template <typename T, void (*init)(T*), void (*deinit)(T*)>
class RAIIObj
{
	T obj;

	// Declared but never defined: the C++03 way to forbid copying.
	RAIIObj(const RAIIObj&);
	RAIIObj& operator=(const RAIIObj&);

 public:
	RAIIObj()
	{
		init(&obj);
	}

	~RAIIObj()
	{
		deinit(&obj);
	}

	/** @return Pointer to the owned object, suitable for passing to mbedTLS functions. */
	T* get() { return &obj; }

	/** @return Const pointer to the owned object. */
	const T* get() const { return &obj; }
};
95
96         typedef RAIIObj<mbedtls_entropy_context, mbedtls_entropy_init, mbedtls_entropy_free> Entropy;
97
98         class CTRDRBG : private RAIIObj<mbedtls_ctr_drbg_context, mbedtls_ctr_drbg_init, mbedtls_ctr_drbg_free>
99         {
100          public:
101                 bool Seed(Entropy& entropy)
102                 {
103                         return (mbedtls_ctr_drbg_seed(get(), mbedtls_entropy_func, entropy.get(), NULL, 0) == 0);
104                 }
105
106                 void SetupConf(mbedtls_ssl_config* conf)
107                 {
108                         mbedtls_ssl_conf_rng(conf, mbedtls_ctr_drbg_random, get());
109                 }
110         };
111
112         class DHParams : public RAIIObj<mbedtls_dhm_context, mbedtls_dhm_init, mbedtls_dhm_free>
113         {
114          public:
115                 void set(const std::string& dhstr)
116                 {
117                         // Last parameter is buffer size, must include the terminating null
118                         int ret = mbedtls_dhm_parse_dhm(get(), reinterpret_cast<const unsigned char*>(dhstr.c_str()), dhstr.size()+1);
119                         ThrowOnError(ret, "Unable to import DH params");
120                 }
121         };
122
123         class X509Key : public RAIIObj<mbedtls_pk_context, mbedtls_pk_init, mbedtls_pk_free>
124         {
125          public:
126                 /** Import */
127                 X509Key(const std::string& keystr)
128                 {
129                         int ret = mbedtls_pk_parse_key(get(), reinterpret_cast<const unsigned char*>(keystr.c_str()), keystr.size()+1, NULL, 0);
130                         ThrowOnError(ret, "Unable to import private key");
131                 }
132         };
133
134         class Ciphersuites
135         {
136                 std::vector<int> list;
137
138          public:
139                 Ciphersuites(const std::string& str)
140                 {
141                         // mbedTLS uses the ciphersuite format "TLS-ECDHE-RSA-WITH-AES-128-GCM-SHA256" internally.
142                         // This is a bit verbose, so we make life a bit simpler for admins by not requiring them to supply the static parts.
143                         irc::sepstream ss(str, ':');
144                         for (std::string token; ss.GetToken(token); )
145                         {
146                                 // Prepend "TLS-" if not there
147                                 if (token.compare(0, 4, "TLS-", 4))
148                                         token.insert(0, "TLS-");
149
150                                 const int id = mbedtls_ssl_get_ciphersuite_id(token.c_str());
151                                 if (!id)
152                                         throw Exception("Unknown ciphersuite " + token);
153                                 list.push_back(id);
154                         }
155                         list.push_back(0);
156                 }
157
158                 const int* get() const { return &list.front(); }
159                 bool empty() const { return (list.size() <= 1); }
160         };
161
162         class Curves
163         {
164                 std::vector<mbedtls_ecp_group_id> list;
165
166          public:
167                 Curves(const std::string& str)
168                 {
169                         irc::sepstream ss(str, ':');
170                         for (std::string token; ss.GetToken(token); )
171                         {
172                                 const mbedtls_ecp_curve_info* curve = mbedtls_ecp_curve_info_from_name(token.c_str());
173                                 if (!curve)
174                                         throw Exception("Unknown curve " + token);
175                                 list.push_back(curve->grp_id);
176                         }
177                         list.push_back(MBEDTLS_ECP_DP_NONE);
178                 }
179
180                 const mbedtls_ecp_group_id* get() const { return &list.front(); }
181                 bool empty() const { return (list.size() <= 1); }
182         };
183
184         class X509CertList : public RAIIObj<mbedtls_x509_crt, mbedtls_x509_crt_init, mbedtls_x509_crt_free>
185         {
186          public:
187                 /** Import or create empty */
188                 X509CertList(const std::string& certstr, bool allowempty = false)
189                 {
190                         if ((allowempty) && (certstr.empty()))
191                                 return;
192                         int ret = mbedtls_x509_crt_parse(get(), reinterpret_cast<const unsigned char*>(certstr.c_str()), certstr.size()+1);
193                         ThrowOnError(ret, "Unable to load certificates");
194                 }
195
196                 bool empty() const { return (get()->raw.p != NULL); }
197         };
198
199         class X509CRL : public RAIIObj<mbedtls_x509_crl, mbedtls_x509_crl_init, mbedtls_x509_crl_free>
200         {
201          public:
202                 X509CRL(const std::string& crlstr)
203                 {
204                         if (crlstr.empty())
205                                 return;
206                         int ret = mbedtls_x509_crl_parse(get(), reinterpret_cast<const unsigned char*>(crlstr.c_str()), crlstr.size()+1);
207                         ThrowOnError(ret, "Unable to load CRL");
208                 }
209         };
210
211         class X509Credentials
212         {
213                 /** Private key
214                  */
215                 X509Key key;
216
217                 /** Certificate list, presented to the peer
218                  */
219                 X509CertList certs;
220
221          public:
222                 X509Credentials(const std::string& certstr, const std::string& keystr)
223                         : key(keystr)
224                         , certs(certstr)
225                 {
226                         // Verify that one of the certs match the private key
227                         bool found = false;
228                         for (mbedtls_x509_crt* cert = certs.get(); cert; cert = cert->next)
229                         {
230                                 if (mbedtls_pk_check_pair(&cert->pk, key.get()) == 0)
231                                 {
232                                         found = true;
233                                         break;
234                                 }
235                         }
236                         if (!found)
237                                 throw Exception("Public/private key pair does not match");
238                 }
239
240                 mbedtls_pk_context* getkey() { return key.get(); }
241                 mbedtls_x509_crt* getcerts() { return certs.get(); }
242         };
243
244         class Context
245         {
246                 mbedtls_ssl_config conf;
247
248 #ifdef INSPIRCD_MBEDTLS_LIBRARY_DEBUG
249                 static void DebugLogFunc(void* userptr, int level, const char* file, int line, const char* msg)
250                 {
251                         // Remove trailing \n
252                         size_t len = strlen(msg);
253                         if ((len > 0) && (msg[len-1] == '\n'))
254                                 len--;
255                         ServerInstance->Logs->Log(MODNAME, LOG_DEBUG, "%s:%d %.*s", file, line, len, msg);
256                 }
257 #endif
258
259          public:
260                 Context(CTRDRBG& ctrdrbg, unsigned int endpoint)
261                 {
262                         mbedtls_ssl_config_init(&conf);
263 #ifdef INSPIRCD_MBEDTLS_LIBRARY_DEBUG
264                         mbedtls_debug_set_threshold(INT_MAX);
265                         mbedtls_ssl_conf_dbg(&conf, DebugLogFunc, NULL);
266 #endif
267
268                         // TODO: check ret of mbedtls_ssl_config_defaults
269                         mbedtls_ssl_config_defaults(&conf, endpoint, MBEDTLS_SSL_TRANSPORT_STREAM, MBEDTLS_SSL_PRESET_DEFAULT);
270                         ctrdrbg.SetupConf(&conf);
271                 }
272
273                 ~Context()
274                 {
275                         mbedtls_ssl_config_free(&conf);
276                 }
277
278                 void SetMinDHBits(unsigned int mindh)
279                 {
280                         mbedtls_ssl_conf_dhm_min_bitlen(&conf, mindh);
281                 }
282
283                 void SetDHParams(DHParams& dh)
284                 {
285                         mbedtls_ssl_conf_dh_param_ctx(&conf, dh.get());
286                 }
287
288                 void SetX509CertAndKey(X509Credentials& x509cred)
289                 {
290                         mbedtls_ssl_conf_own_cert(&conf, x509cred.getcerts(), x509cred.getkey());
291                 }
292
293                 void SetCiphersuites(const Ciphersuites& ciphersuites)
294                 {
295                         mbedtls_ssl_conf_ciphersuites(&conf, ciphersuites.get());
296                 }
297
298                 void SetCurves(const Curves& curves)
299                 {
300                         mbedtls_ssl_conf_curves(&conf, curves.get());
301                 }
302
303                 void SetVersion(int minver, int maxver)
304                 {
305                         // SSL v3 support cannot be enabled
306                         if (minver)
307                                 mbedtls_ssl_conf_min_version(&conf, MBEDTLS_SSL_MAJOR_VERSION_3, minver);
308                         if (maxver)
309                                 mbedtls_ssl_conf_max_version(&conf, MBEDTLS_SSL_MAJOR_VERSION_3, maxver);
310                 }
311
312                 void SetCA(X509CertList& certs, X509CRL& crl)
313                 {
314                         mbedtls_ssl_conf_ca_chain(&conf, certs.get(), crl.get());
315                 }
316
317                 void SetOptionalVerifyCert()
318                 {
319                         mbedtls_ssl_conf_authmode(&conf, MBEDTLS_SSL_VERIFY_OPTIONAL);
320                 }
321
322                 const mbedtls_ssl_config* GetConf() const { return &conf; }
323         };
324
325         class Hash
326         {
327                 const mbedtls_md_info_t* md;
328
329                 /** Buffer where cert hashes are written temporarily
330                  */
331                 mutable std::vector<unsigned char> buf;
332
333          public:
334                 Hash(std::string hashstr)
335                 {
336                         std::transform(hashstr.begin(), hashstr.end(), hashstr.begin(), ::toupper);
337                         md = mbedtls_md_info_from_string(hashstr.c_str());
338                         if (!md)
339                                 throw Exception("Unknown hash: " + hashstr);
340
341                         buf.resize(mbedtls_md_get_size(md));
342                 }
343
344                 std::string hash(const unsigned char* input, size_t length) const
345                 {
346                         mbedtls_md(md, input, length, &buf.front());
347                         return BinToHex(&buf.front(), buf.size());
348                 }
349         };
350
351         class Profile
352         {
353                 /** Name of this profile
354                  */
355                 const std::string name;
356
357                 X509Credentials x509cred;
358
359                 /** Ciphersuites to use
360                  */
361                 Ciphersuites ciphersuites;
362
363                 /** Curves accepted for use in ECDHE and in the peer's end-entity certificate
364                  */
365                 Curves curves;
366
367                 Context serverctx;
368                 Context clientctx;
369
370                 DHParams dhparams;
371
372                 X509CertList cacerts;
373
374                 X509CRL crl;
375
376                 /** Hashing algorithm to use when generating certificate fingerprints
377                  */
378                 Hash hash;
379
380                 /** Rough max size of records to send
381                  */
382                 const unsigned int outrecsize;
383
384          public:
385                 struct Config
386                 {
387                         const std::string name;
388
389                         CTRDRBG& ctrdrbg;
390
391                         const std::string certstr;
392                         const std::string keystr;
393                         const std::string dhstr;
394
395                         const std::string ciphersuitestr;
396                         const std::string curvestr;
397                         const unsigned int mindh;
398                         const std::string hashstr;
399
400                         std::string crlstr;
401                         std::string castr;
402
403                         const int minver;
404                         const int maxver;
405                         const unsigned int outrecsize;
406                         const bool requestclientcert;
407
408                         Config(const std::string& profilename, ConfigTag* tag, CTRDRBG& ctr_drbg)
409                                 : name(profilename)
410                                 , ctrdrbg(ctr_drbg)
411                                 , certstr(ReadFile(tag->getString("certfile", "cert.pem")))
412                                 , keystr(ReadFile(tag->getString("keyfile", "key.pem")))
413                                 , dhstr(ReadFile(tag->getString("dhfile", "dhparams.pem")))
414                                 , ciphersuitestr(tag->getString("ciphersuites"))
415                                 , curvestr(tag->getString("curves"))
416                                 , mindh(tag->getUInt("mindhbits", 2048))
417                                 , hashstr(tag->getString("hash", "sha256"))
418                                 , castr(tag->getString("cafile"))
419                                 , minver(tag->getUInt("minver", 0))
420                                 , maxver(tag->getUInt("maxver", 0))
421                                 , outrecsize(tag->getUInt("outrecsize", 2048, 512, 16384))
422                                 , requestclientcert(tag->getBool("requestclientcert", true))
423                         {
424                                 if (!castr.empty())
425                                 {
426                                         castr = ReadFile(castr);
427                                         crlstr = tag->getString("crlfile");
428                                         if (!crlstr.empty())
429                                                 crlstr = ReadFile(crlstr);
430                                 }
431                         }
432                 };
433
434                 Profile(Config& config)
435                         : name(config.name)
436                         , x509cred(config.certstr, config.keystr)
437                         , ciphersuites(config.ciphersuitestr)
438                         , curves(config.curvestr)
439                         , serverctx(config.ctrdrbg, MBEDTLS_SSL_IS_SERVER)
440                         , clientctx(config.ctrdrbg, MBEDTLS_SSL_IS_CLIENT)
441                         , cacerts(config.castr, true)
442                         , crl(config.crlstr)
443                         , hash(config.hashstr)
444                         , outrecsize(config.outrecsize)
445                 {
446                         serverctx.SetX509CertAndKey(x509cred);
447                         clientctx.SetX509CertAndKey(x509cred);
448                         clientctx.SetMinDHBits(config.mindh);
449
450                         if (!ciphersuites.empty())
451                         {
452                                 serverctx.SetCiphersuites(ciphersuites);
453                                 clientctx.SetCiphersuites(ciphersuites);
454                         }
455
456                         if (!curves.empty())
457                         {
458                                 serverctx.SetCurves(curves);
459                                 clientctx.SetCurves(curves);
460                         }
461
462                         serverctx.SetVersion(config.minver, config.maxver);
463                         clientctx.SetVersion(config.minver, config.maxver);
464
465                         if (!config.dhstr.empty())
466                         {
467                                 dhparams.set(config.dhstr);
468                                 serverctx.SetDHParams(dhparams);
469                         }
470
471                         clientctx.SetOptionalVerifyCert();
472                         clientctx.SetCA(cacerts, crl);
473                         // The default for servers is to not request a client certificate from the peer
474                         if (config.requestclientcert)
475                         {
476                                 serverctx.SetOptionalVerifyCert();
477                                 serverctx.SetCA(cacerts, crl);
478                         }
479                 }
480
481                 static std::string ReadFile(const std::string& filename)
482                 {
483                         FileReader reader(filename);
484                         std::string ret = reader.GetString();
485                         if (ret.empty())
486                                 throw Exception("Cannot read file " + filename);
487                         return ret;
488                 }
489
490                 /** Set up the given session with the settings in this profile
491                  */
492                 void SetupClientSession(mbedtls_ssl_context* sess)
493                 {
494                         mbedtls_ssl_setup(sess, clientctx.GetConf());
495                 }
496
497                 void SetupServerSession(mbedtls_ssl_context* sess)
498                 {
499                         mbedtls_ssl_setup(sess, serverctx.GetConf());
500                 }
501
502                 const std::string& GetName() const { return name; }
503                 X509Credentials& GetX509Credentials() { return x509cred; }
504                 unsigned int GetOutgoingRecordSize() const { return outrecsize; }
505                 const Hash& GetHash() const { return hash; }
506         };
507 }
508
509 class mbedTLSIOHook : public SSLIOHook
510 {
511         enum Status
512         {
513                 ISSL_NONE,
514                 ISSL_HANDSHAKING,
515                 ISSL_HANDSHAKEN
516         };
517
518         mbedtls_ssl_context sess;
519         Status status;
520
521         void CloseSession()
522         {
523                 if (status == ISSL_NONE)
524                         return;
525
526                 mbedtls_ssl_close_notify(&sess);
527                 mbedtls_ssl_free(&sess);
528                 certificate = NULL;
529                 status = ISSL_NONE;
530         }
531
532         // Returns 1 if handshake succeeded, 0 if it is still in progress, -1 if it failed
533         int Handshake(StreamSocket* sock)
534         {
535                 int ret = mbedtls_ssl_handshake(&sess);
536                 if (ret == 0)
537                 {
538                         // Change the seesion state
539                         this->status = ISSL_HANDSHAKEN;
540
541                         VerifyCertificate();
542
543                         // Finish writing, if any left
544                         SocketEngine::ChangeEventMask(sock, FD_WANT_POLL_READ | FD_WANT_NO_WRITE | FD_ADD_TRIAL_WRITE);
545
546                         return 1;
547                 }
548
549                 this->status = ISSL_HANDSHAKING;
550                 if (ret == MBEDTLS_ERR_SSL_WANT_READ)
551                 {
552                         SocketEngine::ChangeEventMask(sock, FD_WANT_POLL_READ | FD_WANT_NO_WRITE);
553                         return 0;
554                 }
555                 else if (ret == MBEDTLS_ERR_SSL_WANT_WRITE)
556                 {
557                         SocketEngine::ChangeEventMask(sock, FD_WANT_NO_READ | FD_WANT_SINGLE_WRITE);
558                         return 0;
559                 }
560
561                 sock->SetError("Handshake Failed - " + mbedTLS::ErrorToString(ret));
562                 CloseSession();
563                 return -1;
564         }
565
566         // Returns 1 if application I/O should proceed, 0 if it must wait for the underlying protocol to progress, -1 on fatal error
567         int PrepareIO(StreamSocket* sock)
568         {
569                 if (status == ISSL_HANDSHAKEN)
570                         return 1;
571                 else if (status == ISSL_HANDSHAKING)
572                 {
573                         // The handshake isn't finished, try to finish it
574                         return Handshake(sock);
575                 }
576
577                 CloseSession();
578                 sock->SetError("No SSL session");
579                 return -1;
580         }
581
582         void VerifyCertificate()
583         {
584                 this->certificate = new ssl_cert;
585                 const mbedtls_x509_crt* const cert = mbedtls_ssl_get_peer_cert(&sess);
586                 if (!cert)
587                 {
588                         certificate->error = "No client certificate sent";
589                         return;
590                 }
591
592                 // If there is a certificate we can always generate a fingerprint
593                 certificate->fingerprint = GetProfile().GetHash().hash(cert->raw.p, cert->raw.len);
594
595                 // At this point mbedTLS verified the cert already, we just need to check the results
596                 const uint32_t flags = mbedtls_ssl_get_verify_result(&sess);
597                 if (flags == 0xFFFFFFFF)
598                 {
599                         certificate->error = "Internal error during verification";
600                         return;
601                 }
602
603                 if (flags == 0)
604                 {
605                         // Verification succeeded
606                         certificate->trusted = true;
607                 }
608                 else
609                 {
610                         // Verification failed
611                         certificate->trusted = false;
612                         if ((flags & MBEDTLS_X509_BADCERT_EXPIRED) || (flags & MBEDTLS_X509_BADCERT_FUTURE))
613                                 certificate->error = "Not activated, or expired certificate";
614                 }
615
616                 certificate->unknownsigner = (flags & MBEDTLS_X509_BADCERT_NOT_TRUSTED);
617                 certificate->revoked = (flags & MBEDTLS_X509_BADCERT_REVOKED);
618                 certificate->invalid = ((flags & MBEDTLS_X509_BADCERT_BAD_KEY) || (flags & MBEDTLS_X509_BADCERT_BAD_MD) || (flags & MBEDTLS_X509_BADCERT_BAD_PK));
619
620                 GetDNString(&cert->subject, certificate->dn);
621                 GetDNString(&cert->issuer, certificate->issuer);
622         }
623
624         static void GetDNString(const mbedtls_x509_name* x509name, std::string& out)
625         {
626                 char buf[512];
627                 const int ret = mbedtls_x509_dn_gets(buf, sizeof(buf), x509name);
628                 if (ret <= 0)
629                         return;
630
631                 out.assign(buf, ret);
632         }
633
634         static int Pull(void* userptr, unsigned char* buffer, size_t size)
635         {
636                 StreamSocket* const sock = reinterpret_cast<StreamSocket*>(userptr);
637                 if (sock->GetEventMask() & FD_READ_WILL_BLOCK)
638                         return MBEDTLS_ERR_SSL_WANT_READ;
639
640                 const int ret = SocketEngine::Recv(sock, reinterpret_cast<char*>(buffer), size, 0);
641                 if (ret < (int)size)
642                 {
643                         SocketEngine::ChangeEventMask(sock, FD_READ_WILL_BLOCK);
644                         if ((ret == -1) && (SocketEngine::IgnoreError()))
645                                 return MBEDTLS_ERR_SSL_WANT_READ;
646                 }
647                 return ret;
648         }
649
650         static int Push(void* userptr, const unsigned char* buffer, size_t size)
651         {
652                 StreamSocket* const sock = reinterpret_cast<StreamSocket*>(userptr);
653                 if (sock->GetEventMask() & FD_WRITE_WILL_BLOCK)
654                         return MBEDTLS_ERR_SSL_WANT_WRITE;
655
656                 const int ret = SocketEngine::Send(sock, buffer, size, 0);
657                 if (ret < (int)size)
658                 {
659                         SocketEngine::ChangeEventMask(sock, FD_WRITE_WILL_BLOCK);
660                         if ((ret == -1) && (SocketEngine::IgnoreError()))
661                                 return MBEDTLS_ERR_SSL_WANT_WRITE;
662                 }
663                 return ret;
664         }
665
666  public:
667         mbedTLSIOHook(IOHookProvider* hookprov, StreamSocket* sock, bool isserver)
668                 : SSLIOHook(hookprov)
669                 , status(ISSL_NONE)
670         {
671                 mbedtls_ssl_init(&sess);
672                 if (isserver)
673                         GetProfile().SetupServerSession(&sess);
674                 else
675                         GetProfile().SetupClientSession(&sess);
676
677                 mbedtls_ssl_set_bio(&sess, reinterpret_cast<void*>(sock), Push, Pull, NULL);
678
679                 sock->AddIOHook(this);
680                 Handshake(sock);
681         }
682
683         void OnStreamSocketClose(StreamSocket* sock) CXX11_OVERRIDE
684         {
685                 CloseSession();
686         }
687
688         int OnStreamSocketRead(StreamSocket* sock, std::string& recvq) CXX11_OVERRIDE
689         {
690                 // Finish handshake if needed
691                 int prepret = PrepareIO(sock);
692                 if (prepret <= 0)
693                         return prepret;
694
695                 // If we resumed the handshake then this->status will be ISSL_HANDSHAKEN.
696                 char* const readbuf = ServerInstance->GetReadBuffer();
697                 const size_t readbufsize = ServerInstance->Config->NetBufferSize;
698                 int ret = mbedtls_ssl_read(&sess, reinterpret_cast<unsigned char*>(readbuf), readbufsize);
699                 if (ret > 0)
700                 {
701                         recvq.append(readbuf, ret);
702
703                         // Schedule a read if there is still data in the mbedTLS buffer
704                         if (mbedtls_ssl_get_bytes_avail(&sess) > 0)
705                                 SocketEngine::ChangeEventMask(sock, FD_ADD_TRIAL_READ);
706                         return 1;
707                 }
708                 else if (ret == MBEDTLS_ERR_SSL_WANT_READ)
709                 {
710                         SocketEngine::ChangeEventMask(sock, FD_WANT_POLL_READ);
711                         return 0;
712                 }
713                 else if (ret == MBEDTLS_ERR_SSL_WANT_WRITE)
714                 {
715                         SocketEngine::ChangeEventMask(sock, FD_WANT_NO_READ | FD_WANT_SINGLE_WRITE);
716                         return 0;
717                 }
718                 else if (ret == 0)
719                 {
720                         sock->SetError("Connection closed");
721                         CloseSession();
722                         return -1;
723                 }
724                 else // error or MBEDTLS_ERR_SSL_CLIENT_RECONNECT which we treat as an error
725                 {
726                         sock->SetError(mbedTLS::ErrorToString(ret));
727                         CloseSession();
728                         return -1;
729                 }
730         }
731
732         int OnStreamSocketWrite(StreamSocket* sock, StreamSocket::SendQueue& sendq) CXX11_OVERRIDE
733         {
734                 // Finish handshake if needed
735                 int prepret = PrepareIO(sock);
736                 if (prepret <= 0)
737                         return prepret;
738
739                 // Session is ready for transferring application data
740                 while (!sendq.empty())
741                 {
742                         FlattenSendQueue(sendq, GetProfile().GetOutgoingRecordSize());
743                         const StreamSocket::SendQueue::Element& buffer = sendq.front();
744                         int ret = mbedtls_ssl_write(&sess, reinterpret_cast<const unsigned char*>(buffer.data()), buffer.length());
745                         if (ret == (int)buffer.length())
746                         {
747                                 // Wrote entire record, continue sending
748                                 sendq.pop_front();
749                         }
750                         else if (ret > 0)
751                         {
752                                 sendq.erase_front(ret);
753                                 SocketEngine::ChangeEventMask(sock, FD_WANT_SINGLE_WRITE);
754                                 return 0;
755                         }
756                         else if (ret == 0)
757                         {
758                                 sock->SetError("Connection closed");
759                                 CloseSession();
760                                 return -1;
761                         }
762                         else if (ret == MBEDTLS_ERR_SSL_WANT_WRITE)
763                         {
764                                 SocketEngine::ChangeEventMask(sock, FD_WANT_SINGLE_WRITE);
765                                 return 0;
766                         }
767                         else if (ret == MBEDTLS_ERR_SSL_WANT_READ)
768                         {
769                                 SocketEngine::ChangeEventMask(sock, FD_WANT_POLL_READ);
770                                 return 0;
771                         }
772                         else
773                         {
774                                 sock->SetError(mbedTLS::ErrorToString(ret));
775                                 CloseSession();
776                                 return -1;
777                         }
778                 }
779
780                 SocketEngine::ChangeEventMask(sock, FD_WANT_NO_WRITE);
781                 return 1;
782         }
783
784         void GetCiphersuite(std::string& out) const CXX11_OVERRIDE
785         {
786                 if (!IsHandshakeDone())
787                         return;
788                 out.append(mbedtls_ssl_get_version(&sess)).push_back('-');
789
790                 // All mbedTLS ciphersuite names currently begin with "TLS-" which provides no useful information so skip it, but be prepared if it changes
791                 const char* const ciphersuitestr = mbedtls_ssl_get_ciphersuite(&sess);
792                 const char prefix[] = "TLS-";
793                 unsigned int skip = sizeof(prefix)-1;
794                 if (strncmp(ciphersuitestr, prefix, sizeof(prefix)-1))
795                         skip = 0;
796                 out.append(ciphersuitestr + skip);
797         }
798
799         bool GetServerName(std::string& out) const CXX11_OVERRIDE
800         {
801                 // TODO: Implement SNI support.
802                 return false;
803         }
804
805         mbedTLS::Profile& GetProfile();
806         bool IsHandshakeDone() const { return (status == ISSL_HANDSHAKEN); }
807 };
808
809 class mbedTLSIOHookProvider : public IOHookProvider
810 {
811         mbedTLS::Profile profile;
812
813  public:
814         mbedTLSIOHookProvider(Module* mod, mbedTLS::Profile::Config& config)
815                 : IOHookProvider(mod, "ssl/" + config.name, IOHookProvider::IOH_SSL)
816                 , profile(config)
817         {
818                 ServerInstance->Modules->AddService(*this);
819         }
820
821         ~mbedTLSIOHookProvider()
822         {
823                 ServerInstance->Modules->DelService(*this);
824         }
825
826         void OnAccept(StreamSocket* sock, irc::sockets::sockaddrs* client, irc::sockets::sockaddrs* server) CXX11_OVERRIDE
827         {
828                 new mbedTLSIOHook(this, sock, true);
829         }
830
831         void OnConnect(StreamSocket* sock) CXX11_OVERRIDE
832         {
833                 new mbedTLSIOHook(this, sock, false);
834         }
835
836         mbedTLS::Profile& GetProfile() { return profile; }
837 };
838
839 mbedTLS::Profile& mbedTLSIOHook::GetProfile()
840 {
841         IOHookProvider* hookprov = prov;
842         return static_cast<mbedTLSIOHookProvider*>(hookprov)->GetProfile();
843 }
844
845 class ModuleSSLmbedTLS : public Module
846 {
847         typedef std::vector<reference<mbedTLSIOHookProvider> > ProfileList;
848
849         mbedTLS::Entropy entropy;
850         mbedTLS::CTRDRBG ctr_drbg;
851         ProfileList profiles;
852
853         void ReadProfiles()
854         {
855                 // First, store all profiles in a new, temporary container. If no problems occur, swap the two
856                 // containers; this way if something goes wrong we can go back and continue using the current profiles,
857                 // avoiding unpleasant situations where no new SSL connections are possible.
858                 ProfileList newprofiles;
859
860                 ConfigTagList tags = ServerInstance->Config->ConfTags("sslprofile");
861                 if (tags.first == tags.second)
862                 {
863                         // No <sslprofile> tags found, create a profile named "mbedtls" from settings in the <mbedtls> block
864                         const std::string defname = "mbedtls";
865                         ConfigTag* tag = ServerInstance->Config->ConfValue(defname);
866                         ServerInstance->Logs->Log(MODNAME, LOG_DEFAULT, "No <sslprofile> tags found; using settings from the <mbedtls> tag");
867
868                         try
869                         {
870                                 mbedTLS::Profile::Config profileconfig(defname, tag, ctr_drbg);
871                                 newprofiles.push_back(new mbedTLSIOHookProvider(this, profileconfig));
872                         }
873                         catch (CoreException& ex)
874                         {
875                                 throw ModuleException("Error while initializing the default SSL profile - " + ex.GetReason());
876                         }
877                 }
878
879                 for (ConfigIter i = tags.first; i != tags.second; ++i)
880                 {
881                         ConfigTag* tag = i->second;
882                         if (!stdalgo::string::equalsci(tag->getString("provider"), "mbedtls"))
883                                 continue;
884
885                         std::string name = tag->getString("name");
886                         if (name.empty())
887                         {
888                                 ServerInstance->Logs->Log(MODNAME, LOG_DEFAULT, "Ignoring <sslprofile> tag without name at " + tag->getTagLocation());
889                                 continue;
890                         }
891
892                         reference<mbedTLSIOHookProvider> prov;
893                         try
894                         {
895                                 mbedTLS::Profile::Config profileconfig(name, tag, ctr_drbg);
896                                 prov = new mbedTLSIOHookProvider(this, profileconfig);
897                         }
898                         catch (CoreException& ex)
899                         {
900                                 throw ModuleException("Error while initializing SSL profile \"" + name + "\" at " + tag->getTagLocation() + " - " + ex.GetReason());
901                         }
902
903                         newprofiles.push_back(prov);
904                 }
905
906                 // New profiles are ok, begin using them
907                 // Old profiles are deleted when their refcount drops to zero
908                 for (ProfileList::iterator i = profiles.begin(); i != profiles.end(); ++i)
909                 {
910                         mbedTLSIOHookProvider& prov = **i;
911                         ServerInstance->Modules.DelService(prov);
912                 }
913
914                 profiles.swap(newprofiles);
915         }
916
917  public:
918         void init() CXX11_OVERRIDE
919         {
920                 char verbuf[16]; // Should be at least 9 bytes in size
921                 mbedtls_version_get_string(verbuf);
922                 ServerInstance->Logs->Log(MODNAME, LOG_DEFAULT, "mbedTLS lib version %s module was compiled for " MBEDTLS_VERSION_STRING, verbuf);
923
924                 if (!ctr_drbg.Seed(entropy))
925                         throw ModuleException("CTR DRBG seed failed");
926                 ReadProfiles();
927         }
928
929         void OnModuleRehash(User* user, const std::string &param) CXX11_OVERRIDE
930         {
931                 if (!irc::equals(param, "ssl"))
932                         return;
933
934                 try
935                 {
936                         ReadProfiles();
937                         ServerInstance->SNO->WriteToSnoMask('a', "SSL module %s rehashed.", MODNAME);
938                 }
939                 catch (ModuleException& ex)
940                 {
941                         ServerInstance->Logs->Log(MODNAME, LOG_DEFAULT, ex.GetReason() + " Not applying settings.");
942                 }
943         }
944
945         void OnCleanup(ExtensionItem::ExtensibleType type, Extensible* item) CXX11_OVERRIDE
946         {
947                 if (type != ExtensionItem::EXT_USER)
948                         return;
949
950                 LocalUser* user = IS_LOCAL(static_cast<User*>(item));
951                 if ((user) && (user->eh.GetModHook(this)))
952                 {
953                         // User is using SSL, they're a local user, and they're using our IOHook.
954                         // Potentially there could be multiple SSL modules loaded at once on different ports.
955                         ServerInstance->Users.QuitUser(user, "SSL module unloading");
956                 }
957         }
958
959         ModResult OnCheckReady(LocalUser* user) CXX11_OVERRIDE
960         {
961                 const mbedTLSIOHook* const iohook = static_cast<mbedTLSIOHook*>(user->eh.GetModHook(this));
962                 if ((iohook) && (!iohook->IsHandshakeDone()))
963                         return MOD_RES_DENY;
964                 return MOD_RES_PASSTHRU;
965         }
966
967         Version GetVersion() CXX11_OVERRIDE
968         {
969                 return Version("Provides SSL support via mbedTLS (PolarSSL)", VF_VENDOR);
970         }
971 };
972
// Declares ModuleSSLmbedTLS as the module class exported by this file.
MODULE_INIT(ModuleSSLmbedTLS)