git.netwichtig.de Git - user/henk/code/inspircd.git/blob - src/modules/extra/m_ssl_mbedtls.cpp
Add package names for ArchLinux.
[user/henk/code/inspircd.git] / src / modules / extra / m_ssl_mbedtls.cpp
1 /*
2  * InspIRCd -- Internet Relay Chat Daemon
3  *
4  *   Copyright (C) 2016 Attila Molnar <attilamolnar@hush.com>
5  *
6  * This file is part of InspIRCd.  InspIRCd is free software: you can
7  * redistribute it and/or modify it under the terms of the GNU General Public
8  * License as published by the Free Software Foundation, version 2.
9  *
10  * This program is distributed in the hope that it will be useful, but WITHOUT
11  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
12  * FOR A PARTICULAR PURPOSE.  See the GNU General Public License for more
13  * details.
14  *
15  * You should have received a copy of the GNU General Public License
16  * along with this program.  If not, see <http://www.gnu.org/licenses/>.
17  */
18
19 /// $LinkerFlags: -lmbedtls
20
21 /// $PackageInfo: require_system("arch") mbedtls
22 /// $PackageInfo: require_system("darwin") mbedtls
23 /// $PackageInfo: require_system("debian" "9.0") libmbedtls-dev
24 /// $PackageInfo: require_system("ubuntu" "16.04") libmbedtls-dev
25
26
27 #include "inspircd.h"
28 #include "modules/ssl.h"
29
30 #include <mbedtls/ctr_drbg.h>
31 #include <mbedtls/dhm.h>
32 #include <mbedtls/ecp.h>
33 #include <mbedtls/entropy.h>
34 #include <mbedtls/error.h>
35 #include <mbedtls/md.h>
36 #include <mbedtls/pk.h>
37 #include <mbedtls/ssl.h>
38 #include <mbedtls/ssl_ciphersuites.h>
39 #include <mbedtls/version.h>
40 #include <mbedtls/x509.h>
41 #include <mbedtls/x509_crt.h>
42 #include <mbedtls/x509_crl.h>
43
44 #ifdef INSPIRCD_MBEDTLS_LIBRARY_DEBUG
45 #include <mbedtls/debug.h>
46 #endif
47
48 namespace mbedTLS
49 {
50         class Exception : public ModuleException
51         {
52          public:
53                 Exception(const std::string& reason)
54                         : ModuleException(reason) { }
55         };
56
57         std::string ErrorToString(int errcode)
58         {
59                 char buf[256];
60                 mbedtls_strerror(errcode, buf, sizeof(buf));
61                 return buf;
62         }
63
64         void ThrowOnError(int errcode, const char* msg)
65         {
66                 if (errcode != 0)
67                 {
68                         std::string reason = msg;
69                         reason.append(" :").append(ErrorToString(errcode));
70                         throw Exception(reason);
71                 }
72         }
73
74         template <typename T, void (*init)(T*), void (*deinit)(T*)>
75         class RAIIObj
76         {
77                 T obj;
78
79          public:
80                 RAIIObj()
81                 {
82                         init(&obj);
83                 }
84
85                 ~RAIIObj()
86                 {
87                         deinit(&obj);
88                 }
89
90                 T* get() { return &obj; }
91                 const T* get() const { return &obj; }
92         };
93
94         typedef RAIIObj<mbedtls_entropy_context, mbedtls_entropy_init, mbedtls_entropy_free> Entropy;
95
96         class CTRDRBG : private RAIIObj<mbedtls_ctr_drbg_context, mbedtls_ctr_drbg_init, mbedtls_ctr_drbg_free>
97         {
98          public:
99                 bool Seed(Entropy& entropy)
100                 {
101                         return (mbedtls_ctr_drbg_seed(get(), mbedtls_entropy_func, entropy.get(), NULL, 0) == 0);
102                 }
103
104                 void SetupConf(mbedtls_ssl_config* conf)
105                 {
106                         mbedtls_ssl_conf_rng(conf, mbedtls_ctr_drbg_random, get());
107                 }
108         };
109
110         class DHParams : public RAIIObj<mbedtls_dhm_context, mbedtls_dhm_init, mbedtls_dhm_free>
111         {
112          public:
113                 void set(const std::string& dhstr)
114                 {
115                         // Last parameter is buffer size, must include the terminating null
116                         int ret = mbedtls_dhm_parse_dhm(get(), reinterpret_cast<const unsigned char*>(dhstr.c_str()), dhstr.size()+1);
117                         ThrowOnError(ret, "Unable to import DH params");
118                 }
119         };
120
121         class X509Key : public RAIIObj<mbedtls_pk_context, mbedtls_pk_init, mbedtls_pk_free>
122         {
123          public:
124                 /** Import */
125                 X509Key(const std::string& keystr)
126                 {
127                         int ret = mbedtls_pk_parse_key(get(), reinterpret_cast<const unsigned char*>(keystr.c_str()), keystr.size()+1, NULL, 0);
128                         ThrowOnError(ret, "Unable to import private key");
129                 }
130         };
131
132         class Ciphersuites
133         {
134                 std::vector<int> list;
135
136          public:
137                 Ciphersuites(const std::string& str)
138                 {
139                         // mbedTLS uses the ciphersuite format "TLS-ECDHE-RSA-WITH-AES-128-GCM-SHA256" internally.
140                         // This is a bit verbose, so we make life a bit simpler for admins by not requiring them to supply the static parts.
141                         irc::sepstream ss(str, ':');
142                         for (std::string token; ss.GetToken(token); )
143                         {
144                                 // Prepend "TLS-" if not there
145                                 if (token.compare(0, 4, "TLS-", 4))
146                                         token.insert(0, "TLS-");
147
148                                 const int id = mbedtls_ssl_get_ciphersuite_id(token.c_str());
149                                 if (!id)
150                                         throw Exception("Unknown ciphersuite " + token);
151                                 list.push_back(id);
152                         }
153                         list.push_back(0);
154                 }
155
156                 const int* get() const { return &list.front(); }
157                 bool empty() const { return (list.size() <= 1); }
158         };
159
160         class Curves
161         {
162                 std::vector<mbedtls_ecp_group_id> list;
163
164          public:
165                 Curves(const std::string& str)
166                 {
167                         irc::sepstream ss(str, ':');
168                         for (std::string token; ss.GetToken(token); )
169                         {
170                                 const mbedtls_ecp_curve_info* curve = mbedtls_ecp_curve_info_from_name(token.c_str());
171                                 if (!curve)
172                                         throw Exception("Unknown curve " + token);
173                                 list.push_back(curve->grp_id);
174                         }
175                         list.push_back(MBEDTLS_ECP_DP_NONE);
176                 }
177
178                 const mbedtls_ecp_group_id* get() const { return &list.front(); }
179                 bool empty() const { return (list.size() <= 1); }
180         };
181
182         class X509CertList : public RAIIObj<mbedtls_x509_crt, mbedtls_x509_crt_init, mbedtls_x509_crt_free>
183         {
184          public:
185                 /** Import or create empty */
186                 X509CertList(const std::string& certstr, bool allowempty = false)
187                 {
188                         if ((allowempty) && (certstr.empty()))
189                                 return;
190                         int ret = mbedtls_x509_crt_parse(get(), reinterpret_cast<const unsigned char*>(certstr.c_str()), certstr.size()+1);
191                         ThrowOnError(ret, "Unable to load certificates");
192                 }
193
194                 bool empty() const { return (get()->raw.p != NULL); }
195         };
196
197         class X509CRL : public RAIIObj<mbedtls_x509_crl, mbedtls_x509_crl_init, mbedtls_x509_crl_free>
198         {
199          public:
200                 X509CRL(const std::string& crlstr)
201                 {
202                         if (crlstr.empty())
203                                 return;
204                         int ret = mbedtls_x509_crl_parse(get(), reinterpret_cast<const unsigned char*>(crlstr.c_str()), crlstr.size()+1);
205                         ThrowOnError(ret, "Unable to load CRL");
206                 }
207         };
208
209         class X509Credentials
210         {
211                 /** Private key
212                  */
213                 X509Key key;
214
215                 /** Certificate list, presented to the peer
216                  */
217                 X509CertList certs;
218
219          public:
220                 X509Credentials(const std::string& certstr, const std::string& keystr)
221                         : key(keystr)
222                         , certs(certstr)
223                 {
224                         // Verify that one of the certs match the private key
225                         bool found = false;
226                         for (mbedtls_x509_crt* cert = certs.get(); cert; cert = cert->next)
227                         {
228                                 if (mbedtls_pk_check_pair(&cert->pk, key.get()) == 0)
229                                 {
230                                         found = true;
231                                         break;
232                                 }
233                         }
234                         if (!found)
235                                 throw Exception("Public/private key pair does not match");
236                 }
237
238                 mbedtls_pk_context* getkey() { return key.get(); }
239                 mbedtls_x509_crt* getcerts() { return certs.get(); }
240         };
241
242         class Context
243         {
244                 mbedtls_ssl_config conf;
245
246 #ifdef INSPIRCD_MBEDTLS_LIBRARY_DEBUG
247                 static void DebugLogFunc(void* userptr, int level, const char* file, int line, const char* msg)
248                 {
249                         // Remove trailing \n
250                         size_t len = strlen(msg);
251                         if ((len > 0) && (msg[len-1] == '\n'))
252                                 len--;
253                         ServerInstance->Logs->Log(MODNAME, LOG_DEBUG, "%s:%d %.*s", file, line, len, msg);
254                 }
255 #endif
256
257          public:
258                 Context(CTRDRBG& ctrdrbg, unsigned int endpoint)
259                 {
260                         mbedtls_ssl_config_init(&conf);
261 #ifdef INSPIRCD_MBEDTLS_LIBRARY_DEBUG
262                         mbedtls_debug_set_threshold(INT_MAX);
263                         mbedtls_ssl_conf_dbg(&conf, DebugLogFunc, NULL);
264 #endif
265
266                         // TODO: check ret of mbedtls_ssl_config_defaults
267                         mbedtls_ssl_config_defaults(&conf, endpoint, MBEDTLS_SSL_TRANSPORT_STREAM, MBEDTLS_SSL_PRESET_DEFAULT);
268                         ctrdrbg.SetupConf(&conf);
269                 }
270
271                 ~Context()
272                 {
273                         mbedtls_ssl_config_free(&conf);
274                 }
275
276                 void SetMinDHBits(unsigned int mindh)
277                 {
278                         mbedtls_ssl_conf_dhm_min_bitlen(&conf, mindh);
279                 }
280
281                 void SetDHParams(DHParams& dh)
282                 {
283                         mbedtls_ssl_conf_dh_param_ctx(&conf, dh.get());
284                 }
285
286                 void SetX509CertAndKey(X509Credentials& x509cred)
287                 {
288                         mbedtls_ssl_conf_own_cert(&conf, x509cred.getcerts(), x509cred.getkey());
289                 }
290
291                 void SetCiphersuites(const Ciphersuites& ciphersuites)
292                 {
293                         mbedtls_ssl_conf_ciphersuites(&conf, ciphersuites.get());
294                 }
295
296                 void SetCurves(const Curves& curves)
297                 {
298                         mbedtls_ssl_conf_curves(&conf, curves.get());
299                 }
300
301                 void SetVersion(int minver, int maxver)
302                 {
303                         // SSL v3 support cannot be enabled
304                         if (minver)
305                                 mbedtls_ssl_conf_min_version(&conf, MBEDTLS_SSL_MAJOR_VERSION_3, minver);
306                         if (maxver)
307                                 mbedtls_ssl_conf_max_version(&conf, MBEDTLS_SSL_MAJOR_VERSION_3, maxver);
308                 }
309
310                 void SetCA(X509CertList& certs, X509CRL& crl)
311                 {
312                         mbedtls_ssl_conf_ca_chain(&conf, certs.get(), crl.get());
313                 }
314
315                 void SetOptionalVerifyCert()
316                 {
317                         mbedtls_ssl_conf_authmode(&conf, MBEDTLS_SSL_VERIFY_OPTIONAL);
318                 }
319
320                 const mbedtls_ssl_config* GetConf() const { return &conf; }
321         };
322
323         class Hash
324         {
325                 const mbedtls_md_info_t* md;
326
327                 /** Buffer where cert hashes are written temporarily
328                  */
329                 mutable std::vector<unsigned char> buf;
330
331          public:
332                 Hash(std::string hashstr)
333                 {
334                         std::transform(hashstr.begin(), hashstr.end(), hashstr.begin(), ::toupper);
335                         md = mbedtls_md_info_from_string(hashstr.c_str());
336                         if (!md)
337                                 throw Exception("Unknown hash: " + hashstr);
338
339                         buf.resize(mbedtls_md_get_size(md));
340                 }
341
342                 std::string hash(const unsigned char* input, size_t length) const
343                 {
344                         mbedtls_md(md, input, length, &buf.front());
345                         return BinToHex(&buf.front(), buf.size());
346                 }
347         };
348
349         class Profile
350         {
351                 /** Name of this profile
352                  */
353                 const std::string name;
354
355                 X509Credentials x509cred;
356
357                 /** Ciphersuites to use
358                  */
359                 Ciphersuites ciphersuites;
360
361                 /** Curves accepted for use in ECDHE and in the peer's end-entity certificate
362                  */
363                 Curves curves;
364
365                 Context serverctx;
366                 Context clientctx;
367
368                 DHParams dhparams;
369
370                 X509CertList cacerts;
371
372                 X509CRL crl;
373
374                 /** Hashing algorithm to use when generating certificate fingerprints
375                  */
376                 Hash hash;
377
378                 /** Rough max size of records to send
379                  */
380                 const unsigned int outrecsize;
381
382          public:
383                 struct Config
384                 {
385                         const std::string name;
386
387                         CTRDRBG& ctrdrbg;
388
389                         const std::string certstr;
390                         const std::string keystr;
391                         const std::string dhstr;
392
393                         const std::string ciphersuitestr;
394                         const std::string curvestr;
395                         const unsigned int mindh;
396                         const std::string hashstr;
397
398                         std::string crlstr;
399                         std::string castr;
400
401                         const int minver;
402                         const int maxver;
403                         const unsigned int outrecsize;
404                         const bool requestclientcert;
405
406                         Config(const std::string& profilename, ConfigTag* tag, CTRDRBG& ctr_drbg)
407                                 : name(profilename)
408                                 , ctrdrbg(ctr_drbg)
409                                 , certstr(ReadFile(tag->getString("certfile", "cert.pem")))
410                                 , keystr(ReadFile(tag->getString("keyfile", "key.pem")))
411                                 , dhstr(ReadFile(tag->getString("dhfile", "dhparams.pem")))
412                                 , ciphersuitestr(tag->getString("ciphersuites"))
413                                 , curvestr(tag->getString("curves"))
414                                 , mindh(tag->getUInt("mindhbits", 2048))
415                                 , hashstr(tag->getString("hash", "sha256"))
416                                 , castr(tag->getString("cafile"))
417                                 , minver(tag->getUInt("minver", 0))
418                                 , maxver(tag->getUInt("maxver", 0))
419                                 , outrecsize(tag->getUInt("outrecsize", 2048, 512, 16384))
420                                 , requestclientcert(tag->getBool("requestclientcert", true))
421                         {
422                                 if (!castr.empty())
423                                 {
424                                         castr = ReadFile(castr);
425                                         crlstr = tag->getString("crlfile");
426                                         if (!crlstr.empty())
427                                                 crlstr = ReadFile(crlstr);
428                                 }
429                         }
430                 };
431
432                 Profile(Config& config)
433                         : name(config.name)
434                         , x509cred(config.certstr, config.keystr)
435                         , ciphersuites(config.ciphersuitestr)
436                         , curves(config.curvestr)
437                         , serverctx(config.ctrdrbg, MBEDTLS_SSL_IS_SERVER)
438                         , clientctx(config.ctrdrbg, MBEDTLS_SSL_IS_CLIENT)
439                         , cacerts(config.castr, true)
440                         , crl(config.crlstr)
441                         , hash(config.hashstr)
442                         , outrecsize(config.outrecsize)
443                 {
444                         serverctx.SetX509CertAndKey(x509cred);
445                         clientctx.SetX509CertAndKey(x509cred);
446                         clientctx.SetMinDHBits(config.mindh);
447
448                         if (!ciphersuites.empty())
449                         {
450                                 serverctx.SetCiphersuites(ciphersuites);
451                                 clientctx.SetCiphersuites(ciphersuites);
452                         }
453
454                         if (!curves.empty())
455                         {
456                                 serverctx.SetCurves(curves);
457                                 clientctx.SetCurves(curves);
458                         }
459
460                         serverctx.SetVersion(config.minver, config.maxver);
461                         clientctx.SetVersion(config.minver, config.maxver);
462
463                         if (!config.dhstr.empty())
464                         {
465                                 dhparams.set(config.dhstr);
466                                 serverctx.SetDHParams(dhparams);
467                         }
468
469                         clientctx.SetOptionalVerifyCert();
470                         clientctx.SetCA(cacerts, crl);
471                         // The default for servers is to not request a client certificate from the peer
472                         if (config.requestclientcert)
473                         {
474                                 serverctx.SetOptionalVerifyCert();
475                                 serverctx.SetCA(cacerts, crl);
476                         }
477                 }
478
479                 static std::string ReadFile(const std::string& filename)
480                 {
481                         FileReader reader(filename);
482                         std::string ret = reader.GetString();
483                         if (ret.empty())
484                                 throw Exception("Cannot read file " + filename);
485                         return ret;
486                 }
487
488                 /** Set up the given session with the settings in this profile
489                  */
490                 void SetupClientSession(mbedtls_ssl_context* sess)
491                 {
492                         mbedtls_ssl_setup(sess, clientctx.GetConf());
493                 }
494
495                 void SetupServerSession(mbedtls_ssl_context* sess)
496                 {
497                         mbedtls_ssl_setup(sess, serverctx.GetConf());
498                 }
499
500                 const std::string& GetName() const { return name; }
501                 X509Credentials& GetX509Credentials() { return x509cred; }
502                 unsigned int GetOutgoingRecordSize() const { return outrecsize; }
503                 const Hash& GetHash() const { return hash; }
504         };
505 }
506
507 class mbedTLSIOHook : public SSLIOHook
508 {
509         enum Status
510         {
511                 ISSL_NONE,
512                 ISSL_HANDSHAKING,
513                 ISSL_HANDSHAKEN
514         };
515
516         mbedtls_ssl_context sess;
517         Status status;
518
519         void CloseSession()
520         {
521                 if (status == ISSL_NONE)
522                         return;
523
524                 mbedtls_ssl_close_notify(&sess);
525                 mbedtls_ssl_free(&sess);
526                 certificate = NULL;
527                 status = ISSL_NONE;
528         }
529
530         // Returns 1 if handshake succeeded, 0 if it is still in progress, -1 if it failed
531         int Handshake(StreamSocket* sock)
532         {
533                 int ret = mbedtls_ssl_handshake(&sess);
534                 if (ret == 0)
535                 {
536                         // Change the seesion state
537                         this->status = ISSL_HANDSHAKEN;
538
539                         VerifyCertificate();
540
541                         // Finish writing, if any left
542                         SocketEngine::ChangeEventMask(sock, FD_WANT_POLL_READ | FD_WANT_NO_WRITE | FD_ADD_TRIAL_WRITE);
543
544                         return 1;
545                 }
546
547                 this->status = ISSL_HANDSHAKING;
548                 if (ret == MBEDTLS_ERR_SSL_WANT_READ)
549                 {
550                         SocketEngine::ChangeEventMask(sock, FD_WANT_POLL_READ | FD_WANT_NO_WRITE);
551                         return 0;
552                 }
553                 else if (ret == MBEDTLS_ERR_SSL_WANT_WRITE)
554                 {
555                         SocketEngine::ChangeEventMask(sock, FD_WANT_NO_READ | FD_WANT_SINGLE_WRITE);
556                         return 0;
557                 }
558
559                 sock->SetError("Handshake Failed - " + mbedTLS::ErrorToString(ret));
560                 CloseSession();
561                 return -1;
562         }
563
564         // Returns 1 if application I/O should proceed, 0 if it must wait for the underlying protocol to progress, -1 on fatal error
565         int PrepareIO(StreamSocket* sock)
566         {
567                 if (status == ISSL_HANDSHAKEN)
568                         return 1;
569                 else if (status == ISSL_HANDSHAKING)
570                 {
571                         // The handshake isn't finished, try to finish it
572                         return Handshake(sock);
573                 }
574
575                 CloseSession();
576                 sock->SetError("No SSL session");
577                 return -1;
578         }
579
580         void VerifyCertificate()
581         {
582                 this->certificate = new ssl_cert;
583                 const mbedtls_x509_crt* const cert = mbedtls_ssl_get_peer_cert(&sess);
584                 if (!cert)
585                 {
586                         certificate->error = "No client certificate sent";
587                         return;
588                 }
589
590                 // If there is a certificate we can always generate a fingerprint
591                 certificate->fingerprint = GetProfile().GetHash().hash(cert->raw.p, cert->raw.len);
592
593                 // At this point mbedTLS verified the cert already, we just need to check the results
594                 const uint32_t flags = mbedtls_ssl_get_verify_result(&sess);
595                 if (flags == 0xFFFFFFFF)
596                 {
597                         certificate->error = "Internal error during verification";
598                         return;
599                 }
600
601                 if (flags == 0)
602                 {
603                         // Verification succeeded
604                         certificate->trusted = true;
605                 }
606                 else
607                 {
608                         // Verification failed
609                         certificate->trusted = false;
610                         if ((flags & MBEDTLS_X509_BADCERT_EXPIRED) || (flags & MBEDTLS_X509_BADCERT_FUTURE))
611                                 certificate->error = "Not activated, or expired certificate";
612                 }
613
614                 certificate->unknownsigner = (flags & MBEDTLS_X509_BADCERT_NOT_TRUSTED);
615                 certificate->revoked = (flags & MBEDTLS_X509_BADCERT_REVOKED);
616                 certificate->invalid = ((flags & MBEDTLS_X509_BADCERT_BAD_KEY) || (flags & MBEDTLS_X509_BADCERT_BAD_MD) || (flags & MBEDTLS_X509_BADCERT_BAD_PK));
617
618                 GetDNString(&cert->subject, certificate->dn);
619                 GetDNString(&cert->issuer, certificate->issuer);
620         }
621
622         static void GetDNString(const mbedtls_x509_name* x509name, std::string& out)
623         {
624                 char buf[512];
625                 const int ret = mbedtls_x509_dn_gets(buf, sizeof(buf), x509name);
626                 if (ret <= 0)
627                         return;
628
629                 out.assign(buf, ret);
630         }
631
632         static int Pull(void* userptr, unsigned char* buffer, size_t size)
633         {
634                 StreamSocket* const sock = reinterpret_cast<StreamSocket*>(userptr);
635                 if (sock->GetEventMask() & FD_READ_WILL_BLOCK)
636                         return MBEDTLS_ERR_SSL_WANT_READ;
637
638                 const int ret = SocketEngine::Recv(sock, reinterpret_cast<char*>(buffer), size, 0);
639                 if (ret < (int)size)
640                 {
641                         SocketEngine::ChangeEventMask(sock, FD_READ_WILL_BLOCK);
642                         if ((ret == -1) && (SocketEngine::IgnoreError()))
643                                 return MBEDTLS_ERR_SSL_WANT_READ;
644                 }
645                 return ret;
646         }
647
648         static int Push(void* userptr, const unsigned char* buffer, size_t size)
649         {
650                 StreamSocket* const sock = reinterpret_cast<StreamSocket*>(userptr);
651                 if (sock->GetEventMask() & FD_WRITE_WILL_BLOCK)
652                         return MBEDTLS_ERR_SSL_WANT_WRITE;
653
654                 const int ret = SocketEngine::Send(sock, buffer, size, 0);
655                 if (ret < (int)size)
656                 {
657                         SocketEngine::ChangeEventMask(sock, FD_WRITE_WILL_BLOCK);
658                         if ((ret == -1) && (SocketEngine::IgnoreError()))
659                                 return MBEDTLS_ERR_SSL_WANT_WRITE;
660                 }
661                 return ret;
662         }
663
664  public:
665         mbedTLSIOHook(IOHookProvider* hookprov, StreamSocket* sock, bool isserver)
666                 : SSLIOHook(hookprov)
667                 , status(ISSL_NONE)
668         {
669                 mbedtls_ssl_init(&sess);
670                 if (isserver)
671                         GetProfile().SetupServerSession(&sess);
672                 else
673                         GetProfile().SetupClientSession(&sess);
674
675                 mbedtls_ssl_set_bio(&sess, reinterpret_cast<void*>(sock), Push, Pull, NULL);
676
677                 sock->AddIOHook(this);
678                 Handshake(sock);
679         }
680
681         void OnStreamSocketClose(StreamSocket* sock) CXX11_OVERRIDE
682         {
683                 CloseSession();
684         }
685
686         int OnStreamSocketRead(StreamSocket* sock, std::string& recvq) CXX11_OVERRIDE
687         {
688                 // Finish handshake if needed
689                 int prepret = PrepareIO(sock);
690                 if (prepret <= 0)
691                         return prepret;
692
693                 // If we resumed the handshake then this->status will be ISSL_HANDSHAKEN.
694                 char* const readbuf = ServerInstance->GetReadBuffer();
695                 const size_t readbufsize = ServerInstance->Config->NetBufferSize;
696                 int ret = mbedtls_ssl_read(&sess, reinterpret_cast<unsigned char*>(readbuf), readbufsize);
697                 if (ret > 0)
698                 {
699                         recvq.append(readbuf, ret);
700
701                         // Schedule a read if there is still data in the mbedTLS buffer
702                         if (mbedtls_ssl_get_bytes_avail(&sess) > 0)
703                                 SocketEngine::ChangeEventMask(sock, FD_ADD_TRIAL_READ);
704                         return 1;
705                 }
706                 else if (ret == MBEDTLS_ERR_SSL_WANT_READ)
707                 {
708                         SocketEngine::ChangeEventMask(sock, FD_WANT_POLL_READ);
709                         return 0;
710                 }
711                 else if (ret == MBEDTLS_ERR_SSL_WANT_WRITE)
712                 {
713                         SocketEngine::ChangeEventMask(sock, FD_WANT_NO_READ | FD_WANT_SINGLE_WRITE);
714                         return 0;
715                 }
716                 else if (ret == 0)
717                 {
718                         sock->SetError("Connection closed");
719                         CloseSession();
720                         return -1;
721                 }
722                 else // error or MBEDTLS_ERR_SSL_CLIENT_RECONNECT which we treat as an error
723                 {
724                         sock->SetError(mbedTLS::ErrorToString(ret));
725                         CloseSession();
726                         return -1;
727                 }
728         }
729
730         int OnStreamSocketWrite(StreamSocket* sock, StreamSocket::SendQueue& sendq) CXX11_OVERRIDE
731         {
732                 // Finish handshake if needed
733                 int prepret = PrepareIO(sock);
734                 if (prepret <= 0)
735                         return prepret;
736
737                 // Session is ready for transferring application data
738                 while (!sendq.empty())
739                 {
740                         FlattenSendQueue(sendq, GetProfile().GetOutgoingRecordSize());
741                         const StreamSocket::SendQueue::Element& buffer = sendq.front();
742                         int ret = mbedtls_ssl_write(&sess, reinterpret_cast<const unsigned char*>(buffer.data()), buffer.length());
743                         if (ret == (int)buffer.length())
744                         {
745                                 // Wrote entire record, continue sending
746                                 sendq.pop_front();
747                         }
748                         else if (ret > 0)
749                         {
750                                 sendq.erase_front(ret);
751                                 SocketEngine::ChangeEventMask(sock, FD_WANT_SINGLE_WRITE);
752                                 return 0;
753                         }
754                         else if (ret == 0)
755                         {
756                                 sock->SetError("Connection closed");
757                                 CloseSession();
758                                 return -1;
759                         }
760                         else if (ret == MBEDTLS_ERR_SSL_WANT_WRITE)
761                         {
762                                 SocketEngine::ChangeEventMask(sock, FD_WANT_SINGLE_WRITE);
763                                 return 0;
764                         }
765                         else if (ret == MBEDTLS_ERR_SSL_WANT_READ)
766                         {
767                                 SocketEngine::ChangeEventMask(sock, FD_WANT_POLL_READ);
768                                 return 0;
769                         }
770                         else
771                         {
772                                 sock->SetError(mbedTLS::ErrorToString(ret));
773                                 CloseSession();
774                                 return -1;
775                         }
776                 }
777
778                 SocketEngine::ChangeEventMask(sock, FD_WANT_NO_WRITE);
779                 return 1;
780         }
781
782         void GetCiphersuite(std::string& out) const CXX11_OVERRIDE
783         {
784                 if (!IsHandshakeDone())
785                         return;
786                 out.append(mbedtls_ssl_get_version(&sess)).push_back('-');
787
788                 // All mbedTLS ciphersuite names currently begin with "TLS-" which provides no useful information so skip it, but be prepared if it changes
789                 const char* const ciphersuitestr = mbedtls_ssl_get_ciphersuite(&sess);
790                 const char prefix[] = "TLS-";
791                 unsigned int skip = sizeof(prefix)-1;
792                 if (strncmp(ciphersuitestr, prefix, sizeof(prefix)-1))
793                         skip = 0;
794                 out.append(ciphersuitestr + skip);
795         }
796
797         bool GetServerName(std::string& out) const CXX11_OVERRIDE
798         {
799                 // TODO: Implement SNI support.
800                 return false;
801         }
802
803         mbedTLS::Profile& GetProfile();
804         bool IsHandshakeDone() const { return (status == ISSL_HANDSHAKEN); }
805 };
806
807 class mbedTLSIOHookProvider : public IOHookProvider
808 {
809         mbedTLS::Profile profile;
810
811  public:
812         mbedTLSIOHookProvider(Module* mod, mbedTLS::Profile::Config& config)
813                 : IOHookProvider(mod, "ssl/" + config.name, IOHookProvider::IOH_SSL)
814                 , profile(config)
815         {
816                 ServerInstance->Modules->AddService(*this);
817         }
818
819         ~mbedTLSIOHookProvider()
820         {
821                 ServerInstance->Modules->DelService(*this);
822         }
823
824         void OnAccept(StreamSocket* sock, irc::sockets::sockaddrs* client, irc::sockets::sockaddrs* server) CXX11_OVERRIDE
825         {
826                 new mbedTLSIOHook(this, sock, true);
827         }
828
829         void OnConnect(StreamSocket* sock) CXX11_OVERRIDE
830         {
831                 new mbedTLSIOHook(this, sock, false);
832         }
833
834         mbedTLS::Profile& GetProfile() { return profile; }
835 };
836
837 mbedTLS::Profile& mbedTLSIOHook::GetProfile()
838 {
839         IOHookProvider* hookprov = prov;
840         return static_cast<mbedTLSIOHookProvider*>(hookprov)->GetProfile();
841 }
842
843 class ModuleSSLmbedTLS : public Module
844 {
845         typedef std::vector<reference<mbedTLSIOHookProvider> > ProfileList;
846
847         mbedTLS::Entropy entropy;
848         mbedTLS::CTRDRBG ctr_drbg;
849         ProfileList profiles;
850
851         void ReadProfiles()
852         {
853                 // First, store all profiles in a new, temporary container. If no problems occur, swap the two
854                 // containers; this way if something goes wrong we can go back and continue using the current profiles,
855                 // avoiding unpleasant situations where no new SSL connections are possible.
856                 ProfileList newprofiles;
857
858                 ConfigTagList tags = ServerInstance->Config->ConfTags("sslprofile");
859                 if (tags.first == tags.second)
860                 {
861                         // No <sslprofile> tags found, create a profile named "mbedtls" from settings in the <mbedtls> block
862                         const std::string defname = "mbedtls";
863                         ConfigTag* tag = ServerInstance->Config->ConfValue(defname);
864                         ServerInstance->Logs->Log(MODNAME, LOG_DEFAULT, "No <sslprofile> tags found; using settings from the <mbedtls> tag");
865
866                         try
867                         {
868                                 mbedTLS::Profile::Config profileconfig(defname, tag, ctr_drbg);
869                                 newprofiles.push_back(new mbedTLSIOHookProvider(this, profileconfig));
870                         }
871                         catch (CoreException& ex)
872                         {
873                                 throw ModuleException("Error while initializing the default SSL profile - " + ex.GetReason());
874                         }
875                 }
876
877                 for (ConfigIter i = tags.first; i != tags.second; ++i)
878                 {
879                         ConfigTag* tag = i->second;
880                         if (!stdalgo::string::equalsci(tag->getString("provider"), "mbedtls"))
881                                 continue;
882
883                         std::string name = tag->getString("name");
884                         if (name.empty())
885                         {
886                                 ServerInstance->Logs->Log(MODNAME, LOG_DEFAULT, "Ignoring <sslprofile> tag without name at " + tag->getTagLocation());
887                                 continue;
888                         }
889
890                         reference<mbedTLSIOHookProvider> prov;
891                         try
892                         {
893                                 mbedTLS::Profile::Config profileconfig(name, tag, ctr_drbg);
894                                 prov = new mbedTLSIOHookProvider(this, profileconfig);
895                         }
896                         catch (CoreException& ex)
897                         {
898                                 throw ModuleException("Error while initializing SSL profile \"" + name + "\" at " + tag->getTagLocation() + " - " + ex.GetReason());
899                         }
900
901                         newprofiles.push_back(prov);
902                 }
903
904                 // New profiles are ok, begin using them
905                 // Old profiles are deleted when their refcount drops to zero
906                 for (ProfileList::iterator i = profiles.begin(); i != profiles.end(); ++i)
907                 {
908                         mbedTLSIOHookProvider& prov = **i;
909                         ServerInstance->Modules.DelService(prov);
910                 }
911
912                 profiles.swap(newprofiles);
913         }
914
915  public:
916         void init() CXX11_OVERRIDE
917         {
918                 char verbuf[16]; // Should be at least 9 bytes in size
919                 mbedtls_version_get_string(verbuf);
920                 ServerInstance->Logs->Log(MODNAME, LOG_DEFAULT, "mbedTLS lib version %s module was compiled for " MBEDTLS_VERSION_STRING, verbuf);
921
922                 if (!ctr_drbg.Seed(entropy))
923                         throw ModuleException("CTR DRBG seed failed");
924                 ReadProfiles();
925         }
926
927         void OnModuleRehash(User* user, const std::string &param) CXX11_OVERRIDE
928         {
929                 if (param != "ssl")
930                         return;
931
932                 try
933                 {
934                         ReadProfiles();
935                 }
936                 catch (ModuleException& ex)
937                 {
938                         ServerInstance->Logs->Log(MODNAME, LOG_DEFAULT, ex.GetReason() + " Not applying settings.");
939                 }
940         }
941
942         void OnCleanup(ExtensionItem::ExtensibleType type, Extensible* item) CXX11_OVERRIDE
943         {
944                 if (type != ExtensionItem::EXT_USER)
945                         return;
946
947                 LocalUser* user = IS_LOCAL(static_cast<User*>(item));
948                 if ((user) && (user->eh.GetModHook(this)))
949                 {
950                         // User is using SSL, they're a local user, and they're using our IOHook.
951                         // Potentially there could be multiple SSL modules loaded at once on different ports.
952                         ServerInstance->Users.QuitUser(user, "SSL module unloading");
953                 }
954         }
955
956         ModResult OnCheckReady(LocalUser* user) CXX11_OVERRIDE
957         {
958                 const mbedTLSIOHook* const iohook = static_cast<mbedTLSIOHook*>(user->eh.GetModHook(this));
959                 if ((iohook) && (!iohook->IsHandshakeDone()))
960                         return MOD_RES_DENY;
961                 return MOD_RES_PASSTHRU;
962         }
963
964         Version GetVersion() CXX11_OVERRIDE
965         {
966                 return Version("Provides SSL support via mbedTLS (PolarSSL)", VF_VENDOR);
967         }
968 };
969
// Module entry point: the MODULE_INIT macro generates the hook InspIRCd uses to create ModuleSSLmbedTLS when this module is loaded.
MODULE_INIT(ModuleSSLmbedTLS)