]> git.netwichtig.de Git - user/henk/code/inspircd.git/blob - src/modules/extra/m_ssl_mbedtls.cpp
9c2a535ae0c5da095ba18279ce1ea1161cd969d4
[user/henk/code/inspircd.git] / src / modules / extra / m_ssl_mbedtls.cpp
1 /*
2  * InspIRCd -- Internet Relay Chat Daemon
3  *
4  *   Copyright (C) 2020 Matt Schatz <genius3000@g3k.solutions>
5  *   Copyright (C) 2016-2021 Sadie Powell <sadie@witchery.services>
6  *   Copyright (C) 2016-2017 Attila Molnar <attilamolnar@hush.com>
7  *
8  * This file is part of InspIRCd.  InspIRCd is free software: you can
9  * redistribute it and/or modify it under the terms of the GNU General Public
10  * License as published by the Free Software Foundation, version 2.
11  *
12  * This program is distributed in the hope that it will be useful, but WITHOUT
13  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
14  * FOR A PARTICULAR PURPOSE.  See the GNU General Public License for more
15  * details.
16  *
17  * You should have received a copy of the GNU General Public License
18  * along with this program.  If not, see <http://www.gnu.org/licenses/>.
19  */
20
21 /// $LinkerFlags: -lmbedtls
22
23 /// $PackageInfo: require_system("arch") mbedtls
24 /// $PackageInfo: require_system("darwin") mbedtls
25 /// $PackageInfo: require_system("debian" "9.0") libmbedtls-dev
26 /// $PackageInfo: require_system("ubuntu" "16.04") libmbedtls-dev
27
28
29 #include "inspircd.h"
30 #include "modules/ssl.h"
31
32 // Fix warnings about the use of commas at end of enumerator lists on C++03.
33 #if defined __clang__
34 # pragma clang diagnostic ignored "-Wc++11-extensions"
35 #elif defined __GNUC__
36 # if (__GNUC__ > 4) || ((__GNUC__ == 4) && (__GNUC_MINOR__ >= 8))
37 #  pragma GCC diagnostic ignored "-Wpedantic"
38 # else
39 #  pragma GCC diagnostic ignored "-pedantic"
40 # endif
41 #endif
42
43 #include <mbedtls/ctr_drbg.h>
44 #include <mbedtls/dhm.h>
45 #include <mbedtls/ecp.h>
46 #include <mbedtls/entropy.h>
47 #include <mbedtls/error.h>
48 #include <mbedtls/md.h>
49 #include <mbedtls/pk.h>
50 #include <mbedtls/ssl.h>
51 #include <mbedtls/ssl_ciphersuites.h>
52 #include <mbedtls/version.h>
53 #include <mbedtls/x509.h>
54 #include <mbedtls/x509_crt.h>
55 #include <mbedtls/x509_crl.h>
56
57 #ifdef INSPIRCD_MBEDTLS_LIBRARY_DEBUG
58 #include <mbedtls/debug.h>
59 #endif
60
61 namespace mbedTLS
62 {
63         class Exception : public ModuleException
64         {
65          public:
66                 Exception(const std::string& reason)
67                         : ModuleException(reason) { }
68         };
69
70         std::string ErrorToString(int errcode)
71         {
72                 char buf[256];
73                 mbedtls_strerror(errcode, buf, sizeof(buf));
74                 return buf;
75         }
76
77         void ThrowOnError(int errcode, const char* msg)
78         {
79                 if (errcode != 0)
80                 {
81                         std::string reason = msg;
82                         reason.append(" :").append(ErrorToString(errcode));
83                         throw Exception(reason);
84                 }
85         }
86
87         template <typename T, void (*init)(T*), void (*deinit)(T*)>
88         class RAIIObj
89         {
90                 T obj;
91
92          public:
93                 RAIIObj()
94                 {
95                         init(&obj);
96                 }
97
98                 ~RAIIObj()
99                 {
100                         deinit(&obj);
101                 }
102
103                 T* get() { return &obj; }
104                 const T* get() const { return &obj; }
105         };
106
107         typedef RAIIObj<mbedtls_entropy_context, mbedtls_entropy_init, mbedtls_entropy_free> Entropy;
108
109         class CTRDRBG : private RAIIObj<mbedtls_ctr_drbg_context, mbedtls_ctr_drbg_init, mbedtls_ctr_drbg_free>
110         {
111          public:
112                 bool Seed(Entropy& entropy)
113                 {
114                         return (mbedtls_ctr_drbg_seed(get(), mbedtls_entropy_func, entropy.get(), NULL, 0) == 0);
115                 }
116
117                 void SetupConf(mbedtls_ssl_config* conf)
118                 {
119                         mbedtls_ssl_conf_rng(conf, mbedtls_ctr_drbg_random, get());
120                 }
121         };
122
123         class DHParams : public RAIIObj<mbedtls_dhm_context, mbedtls_dhm_init, mbedtls_dhm_free>
124         {
125          public:
126                 void set(const std::string& dhstr)
127                 {
128                         // Last parameter is buffer size, must include the terminating null
129                         int ret = mbedtls_dhm_parse_dhm(get(), reinterpret_cast<const unsigned char*>(dhstr.c_str()), dhstr.size()+1);
130                         ThrowOnError(ret, "Unable to import DH params");
131                 }
132         };
133
134         class X509Key : public RAIIObj<mbedtls_pk_context, mbedtls_pk_init, mbedtls_pk_free>
135         {
136          public:
137                 /** Import */
138                 X509Key(const std::string& keystr)
139                 {
140                         int ret = mbedtls_pk_parse_key(get(), reinterpret_cast<const unsigned char*>(keystr.c_str()), keystr.size()+1, NULL, 0);
141                         ThrowOnError(ret, "Unable to import private key");
142                 }
143         };
144
145         class Ciphersuites
146         {
147                 std::vector<int> list;
148
149          public:
150                 Ciphersuites(const std::string& str)
151                 {
152                         // mbedTLS uses the ciphersuite format "TLS-ECDHE-RSA-WITH-AES-128-GCM-SHA256" internally.
153                         // This is a bit verbose, so we make life a bit simpler for admins by not requiring them to supply the static parts.
154                         irc::sepstream ss(str, ':');
155                         for (std::string token; ss.GetToken(token); )
156                         {
157                                 // Prepend "TLS-" if not there
158                                 if (token.compare(0, 4, "TLS-", 4))
159                                         token.insert(0, "TLS-");
160
161                                 const int id = mbedtls_ssl_get_ciphersuite_id(token.c_str());
162                                 if (!id)
163                                         throw Exception("Unknown ciphersuite " + token);
164                                 list.push_back(id);
165                         }
166                         list.push_back(0);
167                 }
168
169                 const int* get() const { return &list.front(); }
170                 bool empty() const { return (list.size() <= 1); }
171         };
172
173         class Curves
174         {
175                 std::vector<mbedtls_ecp_group_id> list;
176
177          public:
178                 Curves(const std::string& str)
179                 {
180                         irc::sepstream ss(str, ':');
181                         for (std::string token; ss.GetToken(token); )
182                         {
183                                 const mbedtls_ecp_curve_info* curve = mbedtls_ecp_curve_info_from_name(token.c_str());
184                                 if (!curve)
185                                         throw Exception("Unknown curve " + token);
186                                 list.push_back(curve->grp_id);
187                         }
188                         list.push_back(MBEDTLS_ECP_DP_NONE);
189                 }
190
191                 const mbedtls_ecp_group_id* get() const { return &list.front(); }
192                 bool empty() const { return (list.size() <= 1); }
193         };
194
195         class X509CertList : public RAIIObj<mbedtls_x509_crt, mbedtls_x509_crt_init, mbedtls_x509_crt_free>
196         {
197          public:
198                 /** Import or create empty */
199                 X509CertList(const std::string& certstr, bool allowempty = false)
200                 {
201                         if ((allowempty) && (certstr.empty()))
202                                 return;
203                         int ret = mbedtls_x509_crt_parse(get(), reinterpret_cast<const unsigned char*>(certstr.c_str()), certstr.size()+1);
204                         ThrowOnError(ret, "Unable to load certificates");
205                 }
206
207                 bool empty() const { return (get()->raw.p != NULL); }
208         };
209
210         class X509CRL : public RAIIObj<mbedtls_x509_crl, mbedtls_x509_crl_init, mbedtls_x509_crl_free>
211         {
212          public:
213                 X509CRL(const std::string& crlstr)
214                 {
215                         if (crlstr.empty())
216                                 return;
217                         int ret = mbedtls_x509_crl_parse(get(), reinterpret_cast<const unsigned char*>(crlstr.c_str()), crlstr.size()+1);
218                         ThrowOnError(ret, "Unable to load CRL");
219                 }
220         };
221
222         class X509Credentials
223         {
224                 /** Private key
225                  */
226                 X509Key key;
227
228                 /** Certificate list, presented to the peer
229                  */
230                 X509CertList certs;
231
232          public:
233                 X509Credentials(const std::string& certstr, const std::string& keystr)
234                         : key(keystr)
235                         , certs(certstr)
236                 {
237                         // Verify that one of the certs match the private key
238                         bool found = false;
239                         for (mbedtls_x509_crt* cert = certs.get(); cert; cert = cert->next)
240                         {
241                                 if (mbedtls_pk_check_pair(&cert->pk, key.get()) == 0)
242                                 {
243                                         found = true;
244                                         break;
245                                 }
246                         }
247                         if (!found)
248                                 throw Exception("Public/private key pair does not match");
249                 }
250
251                 mbedtls_pk_context* getkey() { return key.get(); }
252                 mbedtls_x509_crt* getcerts() { return certs.get(); }
253         };
254
255         class Context
256         {
257                 mbedtls_ssl_config conf;
258
259 #ifdef INSPIRCD_MBEDTLS_LIBRARY_DEBUG
260                 static void DebugLogFunc(void* userptr, int level, const char* file, int line, const char* msg)
261                 {
262                         // Remove trailing \n
263                         size_t len = strlen(msg);
264                         if ((len > 0) && (msg[len-1] == '\n'))
265                                 len--;
266                         ServerInstance->Logs->Log(MODNAME, LOG_DEBUG, "%s:%d %.*s", file, line, len, msg);
267                 }
268 #endif
269
270          public:
271                 Context(CTRDRBG& ctrdrbg, unsigned int endpoint)
272                 {
273                         mbedtls_ssl_config_init(&conf);
274 #ifdef INSPIRCD_MBEDTLS_LIBRARY_DEBUG
275                         mbedtls_debug_set_threshold(INT_MAX);
276                         mbedtls_ssl_conf_dbg(&conf, DebugLogFunc, NULL);
277 #endif
278
279                         // TODO: check ret of mbedtls_ssl_config_defaults
280                         mbedtls_ssl_config_defaults(&conf, endpoint, MBEDTLS_SSL_TRANSPORT_STREAM, MBEDTLS_SSL_PRESET_DEFAULT);
281                         ctrdrbg.SetupConf(&conf);
282                 }
283
284                 ~Context()
285                 {
286                         mbedtls_ssl_config_free(&conf);
287                 }
288
289                 void SetMinDHBits(unsigned int mindh)
290                 {
291                         mbedtls_ssl_conf_dhm_min_bitlen(&conf, mindh);
292                 }
293
294                 void SetDHParams(DHParams& dh)
295                 {
296                         mbedtls_ssl_conf_dh_param_ctx(&conf, dh.get());
297                 }
298
299                 void SetX509CertAndKey(X509Credentials& x509cred)
300                 {
301                         mbedtls_ssl_conf_own_cert(&conf, x509cred.getcerts(), x509cred.getkey());
302                 }
303
304                 void SetCiphersuites(const Ciphersuites& ciphersuites)
305                 {
306                         mbedtls_ssl_conf_ciphersuites(&conf, ciphersuites.get());
307                 }
308
309                 void SetCurves(const Curves& curves)
310                 {
311                         mbedtls_ssl_conf_curves(&conf, curves.get());
312                 }
313
314                 void SetVersion(int minver, int maxver)
315                 {
316                         // SSL v3 support cannot be enabled
317                         if (minver)
318                                 mbedtls_ssl_conf_min_version(&conf, MBEDTLS_SSL_MAJOR_VERSION_3, minver);
319                         if (maxver)
320                                 mbedtls_ssl_conf_max_version(&conf, MBEDTLS_SSL_MAJOR_VERSION_3, maxver);
321                 }
322
323                 void SetCA(X509CertList& certs, X509CRL& crl)
324                 {
325                         mbedtls_ssl_conf_ca_chain(&conf, certs.get(), crl.get());
326                 }
327
328                 void SetOptionalVerifyCert()
329                 {
330                         mbedtls_ssl_conf_authmode(&conf, MBEDTLS_SSL_VERIFY_OPTIONAL);
331                 }
332
333                 const mbedtls_ssl_config* GetConf() const { return &conf; }
334         };
335
336         class Hash
337         {
338                 const mbedtls_md_info_t* md;
339
340                 /** Buffer where cert hashes are written temporarily
341                  */
342                 mutable std::vector<unsigned char> buf;
343
344          public:
345                 Hash(std::string hashstr)
346                 {
347                         std::transform(hashstr.begin(), hashstr.end(), hashstr.begin(), ::toupper);
348                         md = mbedtls_md_info_from_string(hashstr.c_str());
349                         if (!md)
350                                 throw Exception("Unknown hash: " + hashstr);
351
352                         buf.resize(mbedtls_md_get_size(md));
353                 }
354
355                 std::string hash(const unsigned char* input, size_t length) const
356                 {
357                         mbedtls_md(md, input, length, &buf.front());
358                         return BinToHex(&buf.front(), buf.size());
359                 }
360         };
361
362         class Profile
363         {
364                 /** Name of this profile
365                  */
366                 const std::string name;
367
368                 X509Credentials x509cred;
369
370                 /** Ciphersuites to use
371                  */
372                 Ciphersuites ciphersuites;
373
374                 /** Curves accepted for use in ECDHE and in the peer's end-entity certificate
375                  */
376                 Curves curves;
377
378                 Context serverctx;
379                 Context clientctx;
380
381                 DHParams dhparams;
382
383                 X509CertList cacerts;
384
385                 X509CRL crl;
386
387                 /** Hashing algorithm to use when generating certificate fingerprints
388                  */
389                 Hash hash;
390
391                 /** Rough max size of records to send
392                  */
393                 const unsigned int outrecsize;
394
395          public:
396                 struct Config
397                 {
398                         const std::string name;
399
400                         CTRDRBG& ctrdrbg;
401
402                         const std::string certstr;
403                         const std::string keystr;
404                         const std::string dhstr;
405
406                         const std::string ciphersuitestr;
407                         const std::string curvestr;
408                         const unsigned int mindh;
409                         const std::string hashstr;
410
411                         std::string crlstr;
412                         std::string castr;
413
414                         const int minver;
415                         const int maxver;
416                         const unsigned int outrecsize;
417                         const bool requestclientcert;
418
419                         Config(const std::string& profilename, ConfigTag* tag, CTRDRBG& ctr_drbg)
420                                 : name(profilename)
421                                 , ctrdrbg(ctr_drbg)
422                                 , certstr(ReadFile(tag->getString("certfile", "cert.pem", 1)))
423                                 , keystr(ReadFile(tag->getString("keyfile", "key.pem", 1)))
424                                 , dhstr(ReadFile(tag->getString("dhfile", "dhparams.pem", 1)))
425                                 , ciphersuitestr(tag->getString("ciphersuites"))
426                                 , curvestr(tag->getString("curves"))
427                                 , mindh(tag->getUInt("mindhbits", 2048))
428                                 , hashstr(tag->getString("hash", "sha256", 1))
429                                 , castr(tag->getString("cafile"))
430                                 , minver(tag->getUInt("minver", 0))
431                                 , maxver(tag->getUInt("maxver", 0))
432                                 , outrecsize(tag->getUInt("outrecsize", 2048, 512, 16384))
433                                 , requestclientcert(tag->getBool("requestclientcert", true))
434                         {
435                                 if (!castr.empty())
436                                 {
437                                         castr = ReadFile(castr);
438                                         crlstr = tag->getString("crlfile");
439                                         if (!crlstr.empty())
440                                                 crlstr = ReadFile(crlstr);
441                                 }
442                         }
443                 };
444
445                 Profile(Config& config)
446                         : name(config.name)
447                         , x509cred(config.certstr, config.keystr)
448                         , ciphersuites(config.ciphersuitestr)
449                         , curves(config.curvestr)
450                         , serverctx(config.ctrdrbg, MBEDTLS_SSL_IS_SERVER)
451                         , clientctx(config.ctrdrbg, MBEDTLS_SSL_IS_CLIENT)
452                         , cacerts(config.castr, true)
453                         , crl(config.crlstr)
454                         , hash(config.hashstr)
455                         , outrecsize(config.outrecsize)
456                 {
457                         serverctx.SetX509CertAndKey(x509cred);
458                         clientctx.SetX509CertAndKey(x509cred);
459                         clientctx.SetMinDHBits(config.mindh);
460
461                         if (!ciphersuites.empty())
462                         {
463                                 serverctx.SetCiphersuites(ciphersuites);
464                                 clientctx.SetCiphersuites(ciphersuites);
465                         }
466
467                         if (!curves.empty())
468                         {
469                                 serverctx.SetCurves(curves);
470                                 clientctx.SetCurves(curves);
471                         }
472
473                         serverctx.SetVersion(config.minver, config.maxver);
474                         clientctx.SetVersion(config.minver, config.maxver);
475
476                         if (!config.dhstr.empty())
477                         {
478                                 dhparams.set(config.dhstr);
479                                 serverctx.SetDHParams(dhparams);
480                         }
481
482                         clientctx.SetOptionalVerifyCert();
483                         clientctx.SetCA(cacerts, crl);
484                         // The default for servers is to not request a client certificate from the peer
485                         if (config.requestclientcert)
486                         {
487                                 serverctx.SetOptionalVerifyCert();
488                                 serverctx.SetCA(cacerts, crl);
489                         }
490                 }
491
492                 static std::string ReadFile(const std::string& filename)
493                 {
494                         FileReader reader(filename);
495                         std::string ret = reader.GetString();
496                         if (ret.empty())
497                                 throw Exception("Cannot read file " + filename);
498                         return ret;
499                 }
500
501                 /** Set up the given session with the settings in this profile
502                  */
503                 void SetupClientSession(mbedtls_ssl_context* sess)
504                 {
505                         mbedtls_ssl_setup(sess, clientctx.GetConf());
506                 }
507
508                 void SetupServerSession(mbedtls_ssl_context* sess)
509                 {
510                         mbedtls_ssl_setup(sess, serverctx.GetConf());
511                 }
512
513                 const std::string& GetName() const { return name; }
514                 X509Credentials& GetX509Credentials() { return x509cred; }
515                 unsigned int GetOutgoingRecordSize() const { return outrecsize; }
516                 const Hash& GetHash() const { return hash; }
517         };
518 }
519
520 class mbedTLSIOHook : public SSLIOHook
521 {
522         enum Status
523         {
524                 ISSL_NONE,
525                 ISSL_HANDSHAKING,
526                 ISSL_HANDSHAKEN
527         };
528
529         mbedtls_ssl_context sess;
530         Status status;
531
532         void CloseSession()
533         {
534                 if (status == ISSL_NONE)
535                         return;
536
537                 mbedtls_ssl_close_notify(&sess);
538                 mbedtls_ssl_free(&sess);
539                 certificate = NULL;
540                 status = ISSL_NONE;
541         }
542
543         // Returns 1 if handshake succeeded, 0 if it is still in progress, -1 if it failed
544         int Handshake(StreamSocket* sock)
545         {
546                 int ret = mbedtls_ssl_handshake(&sess);
547                 if (ret == 0)
548                 {
549                         // Change the seesion state
550                         this->status = ISSL_HANDSHAKEN;
551
552                         VerifyCertificate();
553
554                         // Finish writing, if any left
555                         SocketEngine::ChangeEventMask(sock, FD_WANT_POLL_READ | FD_WANT_NO_WRITE | FD_ADD_TRIAL_WRITE);
556
557                         return 1;
558                 }
559
560                 this->status = ISSL_HANDSHAKING;
561                 if (ret == MBEDTLS_ERR_SSL_WANT_READ)
562                 {
563                         SocketEngine::ChangeEventMask(sock, FD_WANT_POLL_READ | FD_WANT_NO_WRITE);
564                         return 0;
565                 }
566                 else if (ret == MBEDTLS_ERR_SSL_WANT_WRITE)
567                 {
568                         SocketEngine::ChangeEventMask(sock, FD_WANT_NO_READ | FD_WANT_SINGLE_WRITE);
569                         return 0;
570                 }
571
572                 sock->SetError("Handshake Failed - " + mbedTLS::ErrorToString(ret));
573                 CloseSession();
574                 return -1;
575         }
576
577         // Returns 1 if application I/O should proceed, 0 if it must wait for the underlying protocol to progress, -1 on fatal error
578         int PrepareIO(StreamSocket* sock)
579         {
580                 if (status == ISSL_HANDSHAKEN)
581                         return 1;
582                 else if (status == ISSL_HANDSHAKING)
583                 {
584                         // The handshake isn't finished, try to finish it
585                         return Handshake(sock);
586                 }
587
588                 CloseSession();
589                 sock->SetError("No TLS (SSL) session");
590                 return -1;
591         }
592
593         void VerifyCertificate()
594         {
595                 this->certificate = new ssl_cert;
596                 const mbedtls_x509_crt* const cert = mbedtls_ssl_get_peer_cert(&sess);
597                 if (!cert)
598                 {
599                         certificate->error = "No client certificate sent";
600                         return;
601                 }
602
603                 // If there is a certificate we can always generate a fingerprint
604                 certificate->fingerprint = GetProfile().GetHash().hash(cert->raw.p, cert->raw.len);
605
606                 // At this point mbedTLS verified the cert already, we just need to check the results
607                 const uint32_t flags = mbedtls_ssl_get_verify_result(&sess);
608                 if (flags == 0xFFFFFFFF)
609                 {
610                         certificate->error = "Internal error during verification";
611                         return;
612                 }
613
614                 if (flags == 0)
615                 {
616                         // Verification succeeded
617                         certificate->trusted = true;
618                 }
619                 else
620                 {
621                         // Verification failed
622                         certificate->trusted = false;
623                         if ((flags & MBEDTLS_X509_BADCERT_EXPIRED) || (flags & MBEDTLS_X509_BADCERT_FUTURE))
624                                 certificate->error = "Not activated, or expired certificate";
625                 }
626
627                 certificate->unknownsigner = (flags & MBEDTLS_X509_BADCERT_NOT_TRUSTED);
628                 certificate->revoked = (flags & MBEDTLS_X509_BADCERT_REVOKED);
629                 certificate->invalid = ((flags & MBEDTLS_X509_BADCERT_BAD_KEY) || (flags & MBEDTLS_X509_BADCERT_BAD_MD) || (flags & MBEDTLS_X509_BADCERT_BAD_PK));
630
631                 GetDNString(&cert->subject, certificate->dn);
632                 GetDNString(&cert->issuer, certificate->issuer);
633         }
634
635         static void GetDNString(const mbedtls_x509_name* x509name, std::string& out)
636         {
637                 char buf[512];
638                 const int ret = mbedtls_x509_dn_gets(buf, sizeof(buf), x509name);
639                 if (ret <= 0)
640                         return;
641
642                 out.assign(buf, ret);
643         }
644
645         static int Pull(void* userptr, unsigned char* buffer, size_t size)
646         {
647                 StreamSocket* const sock = reinterpret_cast<StreamSocket*>(userptr);
648                 if (sock->GetEventMask() & FD_READ_WILL_BLOCK)
649                         return MBEDTLS_ERR_SSL_WANT_READ;
650
651                 const int ret = SocketEngine::Recv(sock, reinterpret_cast<char*>(buffer), size, 0);
652                 if (ret < (int)size)
653                 {
654                         SocketEngine::ChangeEventMask(sock, FD_READ_WILL_BLOCK);
655                         if ((ret == -1) && (SocketEngine::IgnoreError()))
656                                 return MBEDTLS_ERR_SSL_WANT_READ;
657                 }
658                 return ret;
659         }
660
661         static int Push(void* userptr, const unsigned char* buffer, size_t size)
662         {
663                 StreamSocket* const sock = reinterpret_cast<StreamSocket*>(userptr);
664                 if (sock->GetEventMask() & FD_WRITE_WILL_BLOCK)
665                         return MBEDTLS_ERR_SSL_WANT_WRITE;
666
667                 const int ret = SocketEngine::Send(sock, buffer, size, 0);
668                 if (ret < (int)size)
669                 {
670                         SocketEngine::ChangeEventMask(sock, FD_WRITE_WILL_BLOCK);
671                         if ((ret == -1) && (SocketEngine::IgnoreError()))
672                                 return MBEDTLS_ERR_SSL_WANT_WRITE;
673                 }
674                 return ret;
675         }
676
677  public:
678         mbedTLSIOHook(IOHookProvider* hookprov, StreamSocket* sock, bool isserver)
679                 : SSLIOHook(hookprov)
680                 , status(ISSL_NONE)
681         {
682                 mbedtls_ssl_init(&sess);
683                 if (isserver)
684                         GetProfile().SetupServerSession(&sess);
685                 else
686                         GetProfile().SetupClientSession(&sess);
687
688                 mbedtls_ssl_set_bio(&sess, reinterpret_cast<void*>(sock), Push, Pull, NULL);
689
690                 sock->AddIOHook(this);
691                 Handshake(sock);
692         }
693
694         void OnStreamSocketClose(StreamSocket* sock) CXX11_OVERRIDE
695         {
696                 CloseSession();
697         }
698
699         int OnStreamSocketRead(StreamSocket* sock, std::string& recvq) CXX11_OVERRIDE
700         {
701                 // Finish handshake if needed
702                 int prepret = PrepareIO(sock);
703                 if (prepret <= 0)
704                         return prepret;
705
706                 // If we resumed the handshake then this->status will be ISSL_HANDSHAKEN.
707                 char* const readbuf = ServerInstance->GetReadBuffer();
708                 const size_t readbufsize = ServerInstance->Config->NetBufferSize;
709                 int ret = mbedtls_ssl_read(&sess, reinterpret_cast<unsigned char*>(readbuf), readbufsize);
710                 if (ret > 0)
711                 {
712                         recvq.append(readbuf, ret);
713
714                         // Schedule a read if there is still data in the mbedTLS buffer
715                         if (mbedtls_ssl_get_bytes_avail(&sess) > 0)
716                                 SocketEngine::ChangeEventMask(sock, FD_ADD_TRIAL_READ);
717                         return 1;
718                 }
719                 else if (ret == MBEDTLS_ERR_SSL_WANT_READ)
720                 {
721                         SocketEngine::ChangeEventMask(sock, FD_WANT_POLL_READ);
722                         return 0;
723                 }
724                 else if (ret == MBEDTLS_ERR_SSL_WANT_WRITE)
725                 {
726                         SocketEngine::ChangeEventMask(sock, FD_WANT_NO_READ | FD_WANT_SINGLE_WRITE);
727                         return 0;
728                 }
729                 else if (ret == 0)
730                 {
731                         sock->SetError("Connection closed");
732                         CloseSession();
733                         return -1;
734                 }
735                 else // error or MBEDTLS_ERR_SSL_CLIENT_RECONNECT which we treat as an error
736                 {
737                         sock->SetError(mbedTLS::ErrorToString(ret));
738                         CloseSession();
739                         return -1;
740                 }
741         }
742
743         int OnStreamSocketWrite(StreamSocket* sock, StreamSocket::SendQueue& sendq) CXX11_OVERRIDE
744         {
745                 // Finish handshake if needed
746                 int prepret = PrepareIO(sock);
747                 if (prepret <= 0)
748                         return prepret;
749
750                 // Session is ready for transferring application data
751                 while (!sendq.empty())
752                 {
753                         FlattenSendQueue(sendq, GetProfile().GetOutgoingRecordSize());
754                         const StreamSocket::SendQueue::Element& buffer = sendq.front();
755                         int ret = mbedtls_ssl_write(&sess, reinterpret_cast<const unsigned char*>(buffer.data()), buffer.length());
756                         if (ret == (int)buffer.length())
757                         {
758                                 // Wrote entire record, continue sending
759                                 sendq.pop_front();
760                         }
761                         else if (ret > 0)
762                         {
763                                 sendq.erase_front(ret);
764                                 SocketEngine::ChangeEventMask(sock, FD_WANT_SINGLE_WRITE);
765                                 return 0;
766                         }
767                         else if (ret == 0)
768                         {
769                                 sock->SetError("Connection closed");
770                                 CloseSession();
771                                 return -1;
772                         }
773                         else if (ret == MBEDTLS_ERR_SSL_WANT_WRITE)
774                         {
775                                 SocketEngine::ChangeEventMask(sock, FD_WANT_SINGLE_WRITE);
776                                 return 0;
777                         }
778                         else if (ret == MBEDTLS_ERR_SSL_WANT_READ)
779                         {
780                                 SocketEngine::ChangeEventMask(sock, FD_WANT_POLL_READ);
781                                 return 0;
782                         }
783                         else
784                         {
785                                 sock->SetError(mbedTLS::ErrorToString(ret));
786                                 CloseSession();
787                                 return -1;
788                         }
789                 }
790
791                 SocketEngine::ChangeEventMask(sock, FD_WANT_NO_WRITE);
792                 return 1;
793         }
794
795         void GetCiphersuite(std::string& out) const CXX11_OVERRIDE
796         {
797                 if (!IsHandshakeDone())
798                         return;
799                 out.append(mbedtls_ssl_get_version(&sess)).push_back('-');
800
801                 // All mbedTLS ciphersuite names currently begin with "TLS-" which provides no useful information so skip it, but be prepared if it changes
802                 const char* const ciphersuitestr = mbedtls_ssl_get_ciphersuite(&sess);
803                 const char prefix[] = "TLS-";
804                 unsigned int skip = sizeof(prefix)-1;
805                 if (strncmp(ciphersuitestr, prefix, sizeof(prefix)-1))
806                         skip = 0;
807                 out.append(ciphersuitestr + skip);
808         }
809
810         bool GetServerName(std::string& out) const CXX11_OVERRIDE
811         {
812                 // TODO: Implement SNI support.
813                 return false;
814         }
815
816         mbedTLS::Profile& GetProfile();
817         bool IsHandshakeDone() const { return (status == ISSL_HANDSHAKEN); }
818 };
819
820 class mbedTLSIOHookProvider : public IOHookProvider
821 {
822         mbedTLS::Profile profile;
823
824  public:
825         mbedTLSIOHookProvider(Module* mod, mbedTLS::Profile::Config& config)
826                 : IOHookProvider(mod, "ssl/" + config.name, IOHookProvider::IOH_SSL)
827                 , profile(config)
828         {
829                 ServerInstance->Modules->AddService(*this);
830         }
831
832         ~mbedTLSIOHookProvider()
833         {
834                 ServerInstance->Modules->DelService(*this);
835         }
836
837         void OnAccept(StreamSocket* sock, irc::sockets::sockaddrs* client, irc::sockets::sockaddrs* server) CXX11_OVERRIDE
838         {
839                 new mbedTLSIOHook(this, sock, true);
840         }
841
842         void OnConnect(StreamSocket* sock) CXX11_OVERRIDE
843         {
844                 new mbedTLSIOHook(this, sock, false);
845         }
846
847         mbedTLS::Profile& GetProfile() { return profile; }
848 };
849
850 mbedTLS::Profile& mbedTLSIOHook::GetProfile()
851 {
852         IOHookProvider* hookprov = prov;
853         return static_cast<mbedTLSIOHookProvider*>(hookprov)->GetProfile();
854 }
855
856 class ModuleSSLmbedTLS : public Module
857 {
858         typedef std::vector<reference<mbedTLSIOHookProvider> > ProfileList;
859
860         mbedTLS::Entropy entropy;
861         mbedTLS::CTRDRBG ctr_drbg;
862         ProfileList profiles;
863
864         void ReadProfiles()
865         {
866                 // First, store all profiles in a new, temporary container. If no problems occur, swap the two
867                 // containers; this way if something goes wrong we can go back and continue using the current profiles,
868                 // avoiding unpleasant situations where no new TLS (SSL) connections are possible.
869                 ProfileList newprofiles;
870
871                 ConfigTagList tags = ServerInstance->Config->ConfTags("sslprofile");
872                 if (tags.first == tags.second)
873                 {
874                         // No <sslprofile> tags found, create a profile named "mbedtls" from settings in the <mbedtls> block
875                         const std::string defname = "mbedtls";
876                         ConfigTag* tag = ServerInstance->Config->ConfValue(defname);
877                         ServerInstance->Logs->Log(MODNAME, LOG_DEFAULT, "No <sslprofile> tags found; using settings from the deprecated <mbedtls> tag");
878
879                         try
880                         {
881                                 mbedTLS::Profile::Config profileconfig(defname, tag, ctr_drbg);
882                                 newprofiles.push_back(new mbedTLSIOHookProvider(this, profileconfig));
883                         }
884                         catch (CoreException& ex)
885                         {
886                                 throw ModuleException("Error while initializing the default TLS (SSL) profile - " + ex.GetReason());
887                         }
888                 }
889                 else
890                 {
891                         ServerInstance->Logs->Log(MODNAME, LOG_DEFAULT, "You have defined an <sslprofile> tag; you should use this in place of \"mbedtls\" when configuring TLS (SSL) connections in <bind:ssl> or <link:ssl>");
892                         for (ConfigIter i = tags.first; i != tags.second; ++i)
893                         {
894                                 ConfigTag* tag = i->second;
895                                 if (!stdalgo::string::equalsci(tag->getString("provider"), "mbedtls"))
896                                 {
897                                         ServerInstance->Logs->Log(MODNAME, LOG_DEFAULT, "Ignoring non-mbedTLS <sslprofile> tag at " + tag->getTagLocation());
898                                         continue;
899                                 }
900
901                                 std::string name = tag->getString("name");
902                                 if (name.empty())
903                                 {
904                                         ServerInstance->Logs->Log(MODNAME, LOG_DEFAULT, "Ignoring <sslprofile> tag without name at " + tag->getTagLocation());
905                                         continue;
906                                 }
907
908                                 reference<mbedTLSIOHookProvider> prov;
909                                 try
910                                 {
911                                         mbedTLS::Profile::Config profileconfig(name, tag, ctr_drbg);
912                                         prov = new mbedTLSIOHookProvider(this, profileconfig);
913                                 }
914                                 catch (CoreException& ex)
915                                 {
916                                         throw ModuleException("Error while initializing TLS (SSL) profile \"" + name + "\" at " + tag->getTagLocation() + " - " + ex.GetReason());
917                                 }
918
919                                 newprofiles.push_back(prov);
920                         }
921                 }
922
923                 // New profiles are ok, begin using them
924                 // Old profiles are deleted when their refcount drops to zero
925                 for (ProfileList::iterator i = profiles.begin(); i != profiles.end(); ++i)
926                 {
927                         mbedTLSIOHookProvider& prov = **i;
928                         ServerInstance->Modules.DelService(prov);
929                 }
930
931                 profiles.swap(newprofiles);
932         }
933
934  public:
935         void init() CXX11_OVERRIDE
936         {
937                 char verbuf[16]; // Should be at least 9 bytes in size
938                 mbedtls_version_get_string(verbuf);
939                 ServerInstance->Logs->Log(MODNAME, LOG_DEFAULT, "mbedTLS lib version %s module was compiled for " MBEDTLS_VERSION_STRING, verbuf);
940
941                 if (!ctr_drbg.Seed(entropy))
942                         throw ModuleException("CTR DRBG seed failed");
943                 ReadProfiles();
944         }
945
946         void OnModuleRehash(User* user, const std::string &param) CXX11_OVERRIDE
947         {
948                 if (!irc::equals(param, "tls") && !irc::equals(param, "ssl"))
949                         return;
950
951                 try
952                 {
953                         ReadProfiles();
954                         ServerInstance->SNO->WriteToSnoMask('a', "mbedTLS TLS (SSL) profiles have been reloaded.");
955                 }
956                 catch (ModuleException& ex)
957                 {
958                         ServerInstance->Logs->Log(MODNAME, LOG_DEFAULT, ex.GetReason() + " Not applying settings.");
959                 }
960         }
961
962         void OnCleanup(ExtensionItem::ExtensibleType type, Extensible* item) CXX11_OVERRIDE
963         {
964                 if (type != ExtensionItem::EXT_USER)
965                         return;
966
967                 LocalUser* user = IS_LOCAL(static_cast<User*>(item));
968                 if ((user) && (user->eh.GetModHook(this)))
969                 {
970                         // User is using TLS (SSL), they're a local user, and they're using our IOHook.
971                         // Potentially there could be multiple TLS (SSL) modules loaded at once on different ports.
972                         ServerInstance->Users.QuitUser(user, "mbedTLS module unloading");
973                 }
974         }
975
976         ModResult OnCheckReady(LocalUser* user) CXX11_OVERRIDE
977         {
978                 const mbedTLSIOHook* const iohook = static_cast<mbedTLSIOHook*>(user->eh.GetModHook(this));
979                 if ((iohook) && (!iohook->IsHandshakeDone()))
980                         return MOD_RES_DENY;
981                 return MOD_RES_PASSTHRU;
982         }
983
984         Version GetVersion() CXX11_OVERRIDE
985         {
986                 return Version("Allows TLS (SSL) encrypted connections using the mbedTLS library.", VF_VENDOR);
987         }
988 };
989
// Expands the core MODULE_INIT macro for ModuleSSLmbedTLS; presumably this exports the
// factory the module loader uses to instantiate the module (macro defined in core headers).
MODULE_INIT(ModuleSSLmbedTLS)