]> git.netwichtig.de Git - user/henk/code/inspircd.git/blob - src/modules/extra/m_ssl_mbedtls.cpp
50bf382666f5a66109ba3a468d059bf03ce0003c
[user/henk/code/inspircd.git] / src / modules / extra / m_ssl_mbedtls.cpp
1 /*
2  * InspIRCd -- Internet Relay Chat Daemon
3  *
4  *   Copyright (C) 2016 Attila Molnar <attilamolnar@hush.com>
5  *
6  * This file is part of InspIRCd.  InspIRCd is free software: you can
7  * redistribute it and/or modify it under the terms of the GNU General Public
8  * License as published by the Free Software Foundation, version 2.
9  *
10  * This program is distributed in the hope that it will be useful, but WITHOUT
11  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
12  * FOR A PARTICULAR PURPOSE.  See the GNU General Public License for more
13  * details.
14  *
15  * You should have received a copy of the GNU General Public License
16  * along with this program.  If not, see <http://www.gnu.org/licenses/>.
17  */
18
19
20 /* $LinkerFlags: -lmbedtls */
21
22 #include "inspircd.h"
23 #include "modules/ssl.h"
24
25 #include <mbedtls/ctr_drbg.h>
26 #include <mbedtls/dhm.h>
27 #include <mbedtls/ecp.h>
28 #include <mbedtls/entropy.h>
29 #include <mbedtls/error.h>
30 #include <mbedtls/md.h>
31 #include <mbedtls/pk.h>
32 #include <mbedtls/ssl.h>
33 #include <mbedtls/ssl_ciphersuites.h>
34 #include <mbedtls/version.h>
35 #include <mbedtls/x509.h>
36 #include <mbedtls/x509_crt.h>
37 #include <mbedtls/x509_crl.h>
38
39 #ifdef INSPIRCD_MBEDTLS_LIBRARY_DEBUG
40 #include <mbedtls/debug.h>
41 #endif
42
43 namespace mbedTLS
44 {
45         class Exception : public ModuleException
46         {
47          public:
48                 Exception(const std::string& reason)
49                         : ModuleException(reason) { }
50         };
51
52         std::string ErrorToString(int errcode)
53         {
54                 char buf[256];
55                 mbedtls_strerror(errcode, buf, sizeof(buf));
56                 return buf;
57         }
58
59         void ThrowOnError(int errcode, const char* msg)
60         {
61                 if (errcode != 0)
62                 {
63                         std::string reason = msg;
64                         reason.append(" :").append(ErrorToString(errcode));
65                         throw Exception(reason);
66                 }
67         }
68
69         template <typename T, void (*init)(T*), void (*deinit)(T*)>
70         class RAIIObj
71         {
72                 T obj;
73
74          public:
75                 RAIIObj()
76                 {
77                         init(&obj);
78                 }
79
80                 ~RAIIObj()
81                 {
82                         deinit(&obj);
83                 }
84
85                 T* get() { return &obj; }
86                 const T* get() const { return &obj; }
87         };
88
89         typedef RAIIObj<mbedtls_entropy_context, mbedtls_entropy_init, mbedtls_entropy_free> Entropy;
90
91         class CTRDRBG : private RAIIObj<mbedtls_ctr_drbg_context, mbedtls_ctr_drbg_init, mbedtls_ctr_drbg_free>
92         {
93          public:
94                 bool Seed(Entropy& entropy)
95                 {
96                         return (mbedtls_ctr_drbg_seed(get(), mbedtls_entropy_func, entropy.get(), NULL, 0) == 0);
97                 }
98
99                 void SetupConf(mbedtls_ssl_config* conf)
100                 {
101                         mbedtls_ssl_conf_rng(conf, mbedtls_ctr_drbg_random, get());
102                 }
103         };
104
105         class DHParams : public RAIIObj<mbedtls_dhm_context, mbedtls_dhm_init, mbedtls_dhm_free>
106         {
107          public:
108                 void set(const std::string& dhstr)
109                 {
110                         // Last parameter is buffer size, must include the terminating null
111                         int ret = mbedtls_dhm_parse_dhm(get(), reinterpret_cast<const unsigned char*>(dhstr.c_str()), dhstr.size()+1);
112                         ThrowOnError(ret, "Unable to import DH params");
113                 }
114         };
115
116         class X509Key : public RAIIObj<mbedtls_pk_context, mbedtls_pk_init, mbedtls_pk_free>
117         {
118          public:
119                 /** Import */
120                 X509Key(const std::string& keystr)
121                 {
122                         int ret = mbedtls_pk_parse_key(get(), reinterpret_cast<const unsigned char*>(keystr.c_str()), keystr.size()+1, NULL, 0);
123                         ThrowOnError(ret, "Unable to import private key");
124                 }
125         };
126
127         class Ciphersuites
128         {
129                 std::vector<int> list;
130
131          public:
132                 Ciphersuites(const std::string& str)
133                 {
134                         // mbedTLS uses the ciphersuite format "TLS-ECDHE-RSA-WITH-AES-128-GCM-SHA256" internally.
135                         // This is a bit verbose, so we make life a bit simpler for admins by not requiring them to supply the static parts.
136                         irc::sepstream ss(str, ':');
137                         for (std::string token; ss.GetToken(token); )
138                         {
139                                 // Prepend "TLS-" if not there
140                                 if (token.compare(0, 4, "TLS-", 4))
141                                         token.insert(0, "TLS-");
142
143                                 const int id = mbedtls_ssl_get_ciphersuite_id(token.c_str());
144                                 if (!id)
145                                         throw Exception("Unknown ciphersuite " + token);
146                                 list.push_back(id);
147                         }
148                         list.push_back(0);
149                 }
150
151                 const int* get() const { return &list.front(); }
152                 bool empty() const { return (list.size() <= 1); }
153         };
154
155         class Curves
156         {
157                 std::vector<mbedtls_ecp_group_id> list;
158
159          public:
160                 Curves(const std::string& str)
161                 {
162                         irc::sepstream ss(str, ':');
163                         for (std::string token; ss.GetToken(token); )
164                         {
165                                 const mbedtls_ecp_curve_info* curve = mbedtls_ecp_curve_info_from_name(token.c_str());
166                                 if (!curve)
167                                         throw Exception("Unknown curve " + token);
168                                 list.push_back(curve->grp_id);
169                         }
170                         list.push_back(MBEDTLS_ECP_DP_NONE);
171                 }
172
173                 const mbedtls_ecp_group_id* get() const { return &list.front(); }
174                 bool empty() const { return (list.size() <= 1); }
175         };
176
177         class X509CertList : public RAIIObj<mbedtls_x509_crt, mbedtls_x509_crt_init, mbedtls_x509_crt_free>
178         {
179          public:
180                 /** Import or create empty */
181                 X509CertList(const std::string& certstr, bool allowempty = false)
182                 {
183                         if ((allowempty) && (certstr.empty()))
184                                 return;
185                         int ret = mbedtls_x509_crt_parse(get(), reinterpret_cast<const unsigned char*>(certstr.c_str()), certstr.size()+1);
186                         ThrowOnError(ret, "Unable to load certificates");
187                 }
188
189                 bool empty() const { return (get()->raw.p != NULL); }
190         };
191
192         class X509CRL : public RAIIObj<mbedtls_x509_crl, mbedtls_x509_crl_init, mbedtls_x509_crl_free>
193         {
194          public:
195                 X509CRL(const std::string& crlstr)
196                 {
197                         if (crlstr.empty())
198                                 return;
199                         int ret = mbedtls_x509_crl_parse(get(), reinterpret_cast<const unsigned char*>(crlstr.c_str()), crlstr.size()+1);
200                         ThrowOnError(ret, "Unable to load CRL");
201                 }
202         };
203
204         class X509Credentials
205         {
206                 /** Private key
207                  */
208                 X509Key key;
209
210                 /** Certificate list, presented to the peer
211                  */
212                 X509CertList certs;
213
214          public:
215                 X509Credentials(const std::string& certstr, const std::string& keystr)
216                         : key(keystr)
217                         , certs(certstr)
218                 {
219                         // Verify that one of the certs match the private key
220                         bool found = false;
221                         for (mbedtls_x509_crt* cert = certs.get(); cert; cert = cert->next)
222                         {
223                                 if (mbedtls_pk_check_pair(&cert->pk, key.get()) == 0)
224                                 {
225                                         found = true;
226                                         break;
227                                 }
228                         }
229                         if (!found)
230                                 throw Exception("Public/private key pair does not match");
231                 }
232
233                 mbedtls_pk_context* getkey() { return key.get(); }
234                 mbedtls_x509_crt* getcerts() { return certs.get(); }
235         };
236
237         class Context
238         {
239                 mbedtls_ssl_config conf;
240
241 #ifdef INSPIRCD_MBEDTLS_LIBRARY_DEBUG
242                 static void DebugLogFunc(void* userptr, int level, const char* file, int line, const char* msg)
243                 {
244                         // Remove trailing \n
245                         size_t len = strlen(msg);
246                         if ((len > 0) && (msg[len-1] == '\n'))
247                                 len--;
248                         ServerInstance->Logs->Log(MODNAME, LOG_DEBUG, "%s:%d %.*s", file, line, len, msg);
249                 }
250 #endif
251
252          public:
253                 Context(CTRDRBG& ctrdrbg, unsigned int endpoint)
254                 {
255                         mbedtls_ssl_config_init(&conf);
256 #ifdef INSPIRCD_MBEDTLS_LIBRARY_DEBUG
257                         mbedtls_debug_set_threshold(INT_MAX);
258                         mbedtls_ssl_conf_dbg(&conf, DebugLogFunc, NULL);
259 #endif
260
261                         // TODO: check ret of mbedtls_ssl_config_defaults
262                         mbedtls_ssl_config_defaults(&conf, endpoint, MBEDTLS_SSL_TRANSPORT_STREAM, MBEDTLS_SSL_PRESET_DEFAULT);
263                         ctrdrbg.SetupConf(&conf);
264                 }
265
266                 ~Context()
267                 {
268                         mbedtls_ssl_config_free(&conf);
269                 }
270
271                 void SetMinDHBits(unsigned int mindh)
272                 {
273                         mbedtls_ssl_conf_dhm_min_bitlen(&conf, mindh);
274                 }
275
276                 void SetDHParams(DHParams& dh)
277                 {
278                         mbedtls_ssl_conf_dh_param_ctx(&conf, dh.get());
279                 }
280
281                 void SetX509CertAndKey(X509Credentials& x509cred)
282                 {
283                         mbedtls_ssl_conf_own_cert(&conf, x509cred.getcerts(), x509cred.getkey());
284                 }
285
286                 void SetCiphersuites(const Ciphersuites& ciphersuites)
287                 {
288                         mbedtls_ssl_conf_ciphersuites(&conf, ciphersuites.get());
289                 }
290
291                 void SetCurves(const Curves& curves)
292                 {
293                         mbedtls_ssl_conf_curves(&conf, curves.get());
294                 }
295
296                 void SetVersion(int minver, int maxver)
297                 {
298                         // SSL v3 support cannot be enabled
299                         if (minver)
300                                 mbedtls_ssl_conf_min_version(&conf, MBEDTLS_SSL_MAJOR_VERSION_3, minver);
301                         if (maxver)
302                                 mbedtls_ssl_conf_max_version(&conf, MBEDTLS_SSL_MAJOR_VERSION_3, maxver);
303                 }
304
305                 void SetCA(X509CertList& certs, X509CRL& crl)
306                 {
307                         mbedtls_ssl_conf_ca_chain(&conf, certs.get(), crl.get());
308                 }
309
310                 void SetOptionalVerifyCert()
311                 {
312                         mbedtls_ssl_conf_authmode(&conf, MBEDTLS_SSL_VERIFY_OPTIONAL);
313                 }
314
315                 const mbedtls_ssl_config* GetConf() const { return &conf; }
316         };
317
318         class Hash
319         {
320                 const mbedtls_md_info_t* md;
321
322                 /** Buffer where cert hashes are written temporarily
323                  */
324                 mutable std::vector<unsigned char> buf;
325
326          public:
327                 Hash(std::string hashstr)
328                 {
329                         std::transform(hashstr.begin(), hashstr.end(), hashstr.begin(), ::toupper);
330                         md = mbedtls_md_info_from_string(hashstr.c_str());
331                         if (!md)
332                                 throw Exception("Unknown hash: " + hashstr);
333
334                         buf.resize(mbedtls_md_get_size(md));
335                 }
336
337                 std::string hash(const unsigned char* input, size_t length) const
338                 {
339                         mbedtls_md(md, input, length, &buf.front());
340                         return BinToHex(&buf.front(), buf.size());
341                 }
342         };
343
344         class Profile : public refcountbase
345         {
346                 /** Name of this profile
347                  */
348                 const std::string name;
349
350                 X509Credentials x509cred;
351
352                 /** Ciphersuites to use
353                  */
354                 Ciphersuites ciphersuites;
355
356                 /** Curves accepted for use in ECDHE and in the peer's end-entity certificate
357                  */
358                 Curves curves;
359
360                 Context serverctx;
361                 Context clientctx;
362
363                 DHParams dhparams;
364
365                 X509CertList cacerts;
366
367                 X509CRL crl;
368
369                 /** Hashing algorithm to use when generating certificate fingerprints
370                  */
371                 Hash hash;
372
373                 /** Rough max size of records to send
374                  */
375                 const unsigned int outrecsize;
376
377                 Profile(const std::string& profilename, const std::string& certstr, const std::string& keystr,
378                                 const std::string& dhstr, unsigned int mindh, const std::string& hashstr,
379                                 const std::string& ciphersuitestr, const std::string& curvestr,
380                                 const std::string& castr, const std::string& crlstr,
381                                 unsigned int recsize,
382                                 CTRDRBG& ctrdrbg,
383                                 int minver, int maxver,
384                                 bool requestclientcert
385                                 )
386                         : name(profilename)
387                         , x509cred(certstr, keystr)
388                         , ciphersuites(ciphersuitestr)
389                         , curves(curvestr)
390                         , serverctx(ctrdrbg, MBEDTLS_SSL_IS_SERVER)
391                         , clientctx(ctrdrbg, MBEDTLS_SSL_IS_CLIENT)
392                         , cacerts(castr, true)
393                         , crl(crlstr)
394                         , hash(hashstr)
395                         , outrecsize(recsize)
396                 {
397                         serverctx.SetX509CertAndKey(x509cred);
398                         clientctx.SetX509CertAndKey(x509cred);
399                         clientctx.SetMinDHBits(mindh);
400
401                         if (!ciphersuites.empty())
402                         {
403                                 serverctx.SetCiphersuites(ciphersuites);
404                                 clientctx.SetCiphersuites(ciphersuites);
405                         }
406
407                         if (!curves.empty())
408                         {
409                                 serverctx.SetCurves(curves);
410                                 clientctx.SetCurves(curves);
411                         }
412
413                         serverctx.SetVersion(minver, maxver);
414                         clientctx.SetVersion(minver, maxver);
415
416                         if (!dhstr.empty())
417                         {
418                                 dhparams.set(dhstr);
419                                 serverctx.SetDHParams(dhparams);
420                         }
421
422                         clientctx.SetOptionalVerifyCert();
423                         clientctx.SetCA(cacerts, crl);
424                         // The default for servers is to not request a client certificate from the peer
425                         if (requestclientcert)
426                         {
427                                 serverctx.SetOptionalVerifyCert();
428                                 serverctx.SetCA(cacerts, crl);
429                         }
430                 }
431
432                 static std::string ReadFile(const std::string& filename)
433                 {
434                         FileReader reader(filename);
435                         std::string ret = reader.GetString();
436                         if (ret.empty())
437                                 throw Exception("Cannot read file " + filename);
438                         return ret;
439                 }
440
441          public:
442                 static reference<Profile> Create(const std::string& profilename, ConfigTag* tag, CTRDRBG& ctr_drbg)
443                 {
444                         const std::string certstr = ReadFile(tag->getString("certfile", "cert.pem"));
445                         const std::string keystr = ReadFile(tag->getString("keyfile", "key.pem"));
446                         const std::string dhstr = ReadFile(tag->getString("dhfile", "dhparams.pem"));
447
448                         const std::string ciphersuitestr = tag->getString("ciphersuites");
449                         const std::string curvestr = tag->getString("curves");
450                         unsigned int mindh = tag->getInt("mindhbits", 2048);
451                         std::string hashstr = tag->getString("hash", "sha256");
452
453                         std::string crlstr;
454                         std::string castr = tag->getString("cafile");
455                         if (!castr.empty())
456                         {
457                                 castr = ReadFile(castr);
458                                 crlstr = tag->getString("crlfile");
459                                 if (!crlstr.empty())
460                                         crlstr = ReadFile(crlstr);
461                         }
462
463                         int minver = tag->getInt("minver");
464                         int maxver = tag->getInt("maxver");
465                         unsigned int outrecsize = tag->getInt("outrecsize", 2048, 512, 16384);
466                         const bool requestclientcert = tag->getBool("requestclientcert", true);
467                         return new Profile(profilename, certstr, keystr, dhstr, mindh, hashstr, ciphersuitestr, curvestr, castr, crlstr, outrecsize, ctr_drbg, minver, maxver, requestclientcert);
468                 }
469
470                 /** Set up the given session with the settings in this profile
471                  */
472                 void SetupClientSession(mbedtls_ssl_context* sess)
473                 {
474                         mbedtls_ssl_setup(sess, clientctx.GetConf());
475                 }
476
477                 void SetupServerSession(mbedtls_ssl_context* sess)
478                 {
479                         mbedtls_ssl_setup(sess, serverctx.GetConf());
480                 }
481
482                 const std::string& GetName() const { return name; }
483                 X509Credentials& GetX509Credentials() { return x509cred; }
484                 unsigned int GetOutgoingRecordSize() const { return outrecsize; }
485                 const Hash& GetHash() const { return hash; }
486         };
487 }
488
489 class mbedTLSIOHook : public SSLIOHook
490 {
491         enum Status
492         {
493                 ISSL_NONE,
494                 ISSL_HANDSHAKING,
495                 ISSL_HANDSHAKEN
496         };
497
498         mbedtls_ssl_context sess;
499         Status status;
500         reference<mbedTLS::Profile> profile;
501
502         void CloseSession()
503         {
504                 if (status == ISSL_NONE)
505                         return;
506
507                 mbedtls_ssl_close_notify(&sess);
508                 mbedtls_ssl_free(&sess);
509                 certificate = NULL;
510                 status = ISSL_NONE;
511         }
512
513         // Returns 1 if handshake succeeded, 0 if it is still in progress, -1 if it failed
514         int Handshake(StreamSocket* sock)
515         {
516                 int ret = mbedtls_ssl_handshake(&sess);
517                 if (ret == 0)
518                 {
519                         // Change the seesion state
520                         this->status = ISSL_HANDSHAKEN;
521
522                         VerifyCertificate();
523
524                         // Finish writing, if any left
525                         SocketEngine::ChangeEventMask(sock, FD_WANT_POLL_READ | FD_WANT_NO_WRITE | FD_ADD_TRIAL_WRITE);
526
527                         return 1;
528                 }
529
530                 this->status = ISSL_HANDSHAKING;
531                 if (ret == MBEDTLS_ERR_SSL_WANT_READ)
532                 {
533                         SocketEngine::ChangeEventMask(sock, FD_WANT_POLL_READ | FD_WANT_NO_WRITE);
534                         return 0;
535                 }
536                 else if (ret == MBEDTLS_ERR_SSL_WANT_WRITE)
537                 {
538                         SocketEngine::ChangeEventMask(sock, FD_WANT_NO_READ | FD_WANT_SINGLE_WRITE);
539                         return 0;
540                 }
541
542                 sock->SetError("Handshake Failed - " + mbedTLS::ErrorToString(ret));
543                 CloseSession();
544                 return -1;
545         }
546
547         // Returns 1 if application I/O should proceed, 0 if it must wait for the underlying protocol to progress, -1 on fatal error
548         int PrepareIO(StreamSocket* sock)
549         {
550                 if (status == ISSL_HANDSHAKEN)
551                         return 1;
552                 else if (status == ISSL_HANDSHAKING)
553                 {
554                         // The handshake isn't finished, try to finish it
555                         return Handshake(sock);
556                 }
557
558                 CloseSession();
559                 sock->SetError("No SSL session");
560                 return -1;
561         }
562
563         void VerifyCertificate()
564         {
565                 this->certificate = new ssl_cert;
566                 const mbedtls_x509_crt* const cert = mbedtls_ssl_get_peer_cert(&sess);
567                 if (!cert)
568                 {
569                         certificate->error = "No client certificate sent";
570                         return;
571                 }
572
573                 // If there is a certificate we can always generate a fingerprint
574                 certificate->fingerprint = profile->GetHash().hash(cert->raw.p, cert->raw.len);
575
576                 // At this point mbedTLS verified the cert already, we just need to check the results
577                 const uint32_t flags = mbedtls_ssl_get_verify_result(&sess);
578                 if (flags == 0xFFFFFFFF)
579                 {
580                         certificate->error = "Internal error during verification";
581                         return;
582                 }
583
584                 if (flags == 0)
585                 {
586                         // Verification succeeded
587                         certificate->trusted = true;
588                 }
589                 else
590                 {
591                         // Verification failed
592                         certificate->trusted = false;
593                         if ((flags & MBEDTLS_X509_BADCERT_EXPIRED) || (flags & MBEDTLS_X509_BADCERT_FUTURE))
594                                 certificate->error = "Not activated, or expired certificate";
595                 }
596
597                 certificate->unknownsigner = (flags & MBEDTLS_X509_BADCERT_NOT_TRUSTED);
598                 certificate->revoked = (flags & MBEDTLS_X509_BADCERT_REVOKED);
599                 certificate->invalid = ((flags & MBEDTLS_X509_BADCERT_BAD_KEY) || (flags & MBEDTLS_X509_BADCERT_BAD_MD) || (flags & MBEDTLS_X509_BADCERT_BAD_PK));
600
601                 GetDNString(&cert->subject, certificate->dn);
602                 GetDNString(&cert->issuer, certificate->issuer);
603         }
604
605         static void GetDNString(const mbedtls_x509_name* x509name, std::string& out)
606         {
607                 char buf[512];
608                 const int ret = mbedtls_x509_dn_gets(buf, sizeof(buf), x509name);
609                 if (ret <= 0)
610                         return;
611
612                 out.assign(buf, ret);
613         }
614
615         static int Pull(void* userptr, unsigned char* buffer, size_t size)
616         {
617                 StreamSocket* const sock = reinterpret_cast<StreamSocket*>(userptr);
618                 if (sock->GetEventMask() & FD_READ_WILL_BLOCK)
619                         return MBEDTLS_ERR_SSL_WANT_READ;
620
621                 const int ret = SocketEngine::Recv(sock, reinterpret_cast<char*>(buffer), size, 0);
622                 if (ret < (int)size)
623                 {
624                         SocketEngine::ChangeEventMask(sock, FD_READ_WILL_BLOCK);
625                         if ((ret == -1) && (SocketEngine::IgnoreError()))
626                                 return MBEDTLS_ERR_SSL_WANT_READ;
627                 }
628                 return ret;
629         }
630
631         static int Push(void* userptr, const unsigned char* buffer, size_t size)
632         {
633                 StreamSocket* const sock = reinterpret_cast<StreamSocket*>(userptr);
634                 if (sock->GetEventMask() & FD_WRITE_WILL_BLOCK)
635                         return MBEDTLS_ERR_SSL_WANT_WRITE;
636
637                 const int ret = SocketEngine::Send(sock, buffer, size, 0);
638                 if (ret < (int)size)
639                 {
640                         SocketEngine::ChangeEventMask(sock, FD_WRITE_WILL_BLOCK);
641                         if ((ret == -1) && (SocketEngine::IgnoreError()))
642                                 return MBEDTLS_ERR_SSL_WANT_WRITE;
643                 }
644                 return ret;
645         }
646
647  public:
648         mbedTLSIOHook(IOHookProvider* hookprov, StreamSocket* sock, bool isserver, mbedTLS::Profile* sslprofile)
649                 : SSLIOHook(hookprov)
650                 , status(ISSL_NONE)
651                 , profile(sslprofile)
652         {
653                 mbedtls_ssl_init(&sess);
654                 if (isserver)
655                         profile->SetupServerSession(&sess);
656                 else
657                         profile->SetupClientSession(&sess);
658
659                 mbedtls_ssl_set_bio(&sess, reinterpret_cast<void*>(sock), Push, Pull, NULL);
660
661                 sock->AddIOHook(this);
662                 Handshake(sock);
663         }
664
665         void OnStreamSocketClose(StreamSocket* sock) CXX11_OVERRIDE
666         {
667                 CloseSession();
668         }
669
670         int OnStreamSocketRead(StreamSocket* sock, std::string& recvq) CXX11_OVERRIDE
671         {
672                 // Finish handshake if needed
673                 int prepret = PrepareIO(sock);
674                 if (prepret <= 0)
675                         return prepret;
676
677                 // If we resumed the handshake then this->status will be ISSL_HANDSHAKEN.
678                 char* const readbuf = ServerInstance->GetReadBuffer();
679                 const size_t readbufsize = ServerInstance->Config->NetBufferSize;
680                 int ret = mbedtls_ssl_read(&sess, reinterpret_cast<unsigned char*>(readbuf), readbufsize);
681                 if (ret > 0)
682                 {
683                         recvq.append(readbuf, ret);
684
685                         // Schedule a read if there is still data in the mbedTLS buffer
686                         if (mbedtls_ssl_get_bytes_avail(&sess) > 0)
687                                 SocketEngine::ChangeEventMask(sock, FD_ADD_TRIAL_READ);
688                         return 1;
689                 }
690                 else if (ret == MBEDTLS_ERR_SSL_WANT_READ)
691                 {
692                         SocketEngine::ChangeEventMask(sock, FD_WANT_POLL_READ);
693                         return 0;
694                 }
695                 else if (ret == MBEDTLS_ERR_SSL_WANT_WRITE)
696                 {
697                         SocketEngine::ChangeEventMask(sock, FD_WANT_NO_READ | FD_WANT_SINGLE_WRITE);
698                         return 0;
699                 }
700                 else if (ret == 0)
701                 {
702                         sock->SetError("Connection closed");
703                         CloseSession();
704                         return -1;
705                 }
706                 else // error or MBEDTLS_ERR_SSL_CLIENT_RECONNECT which we treat as an error
707                 {
708                         sock->SetError(mbedTLS::ErrorToString(ret));
709                         CloseSession();
710                         return -1;
711                 }
712         }
713
714         int OnStreamSocketWrite(StreamSocket* sock, StreamSocket::SendQueue& sendq) CXX11_OVERRIDE
715         {
716                 // Finish handshake if needed
717                 int prepret = PrepareIO(sock);
718                 if (prepret <= 0)
719                         return prepret;
720
721                 // Session is ready for transferring application data
722                 while (!sendq.empty())
723                 {
724                         FlattenSendQueue(sendq, profile->GetOutgoingRecordSize());
725                         const StreamSocket::SendQueue::Element& buffer = sendq.front();
726                         int ret = mbedtls_ssl_write(&sess, reinterpret_cast<const unsigned char*>(buffer.data()), buffer.length());
727                         if (ret == (int)buffer.length())
728                         {
729                                 // Wrote entire record, continue sending
730                                 sendq.pop_front();
731                         }
732                         else if (ret > 0)
733                         {
734                                 sendq.erase_front(ret);
735                                 SocketEngine::ChangeEventMask(sock, FD_WANT_SINGLE_WRITE);
736                                 return 0;
737                         }
738                         else if (ret == 0)
739                         {
740                                 sock->SetError("Connection closed");
741                                 CloseSession();
742                                 return -1;
743                         }
744                         else if (ret == MBEDTLS_ERR_SSL_WANT_WRITE)
745                         {
746                                 SocketEngine::ChangeEventMask(sock, FD_WANT_SINGLE_WRITE);
747                                 return 0;
748                         }
749                         else if (ret == MBEDTLS_ERR_SSL_WANT_READ)
750                         {
751                                 SocketEngine::ChangeEventMask(sock, FD_WANT_POLL_READ);
752                                 return 0;
753                         }
754                         else
755                         {
756                                 sock->SetError(mbedTLS::ErrorToString(ret));
757                                 CloseSession();
758                                 return -1;
759                         }
760                 }
761
762                 SocketEngine::ChangeEventMask(sock, FD_WANT_NO_WRITE);
763                 return 1;
764         }
765
766         void GetCiphersuite(std::string& out) const CXX11_OVERRIDE
767         {
768                 if (!IsHandshakeDone())
769                         return;
770                 out.append(mbedtls_ssl_get_version(&sess)).push_back('-');
771
772                 // All mbedTLS ciphersuite names currently begin with "TLS-" which provides no useful information so skip it, but be prepared if it changes
773                 const char* const ciphersuitestr = mbedtls_ssl_get_ciphersuite(&sess);
774                 const char prefix[] = "TLS-";
775                 unsigned int skip = sizeof(prefix)-1;
776                 if (strncmp(ciphersuitestr, prefix, sizeof(prefix)-1))
777                         skip = 0;
778                 out.append(ciphersuitestr + skip);
779         }
780
781         bool IsHandshakeDone() const { return (status == ISSL_HANDSHAKEN); }
782 };
783
784 class mbedTLSIOHookProvider : public refcountbase, public IOHookProvider
785 {
786         reference<mbedTLS::Profile> profile;
787
788  public:
789         mbedTLSIOHookProvider(Module* mod, mbedTLS::Profile* prof)
790                 : IOHookProvider(mod, "ssl/" + prof->GetName(), IOHookProvider::IOH_SSL)
791                 , profile(prof)
792         {
793                 ServerInstance->Modules->AddService(*this);
794         }
795
796         ~mbedTLSIOHookProvider()
797         {
798                 ServerInstance->Modules->DelService(*this);
799         }
800
801         void OnAccept(StreamSocket* sock, irc::sockets::sockaddrs* client, irc::sockets::sockaddrs* server) CXX11_OVERRIDE
802         {
803                 new mbedTLSIOHook(this, sock, true, profile);
804         }
805
806         void OnConnect(StreamSocket* sock) CXX11_OVERRIDE
807         {
808                 new mbedTLSIOHook(this, sock, false, profile);
809         }
810 };
811
812 class ModuleSSLmbedTLS : public Module
813 {
814         typedef std::vector<reference<mbedTLSIOHookProvider> > ProfileList;
815
816         mbedTLS::Entropy entropy;
817         mbedTLS::CTRDRBG ctr_drbg;
818         ProfileList profiles;
819
820         void ReadProfiles()
821         {
822                 // First, store all profiles in a new, temporary container. If no problems occur, swap the two
823                 // containers; this way if something goes wrong we can go back and continue using the current profiles,
824                 // avoiding unpleasant situations where no new SSL connections are possible.
825                 ProfileList newprofiles;
826
827                 ConfigTagList tags = ServerInstance->Config->ConfTags("sslprofile");
828                 if (tags.first == tags.second)
829                 {
830                         // No <sslprofile> tags found, create a profile named "mbedtls" from settings in the <mbedtls> block
831                         const std::string defname = "mbedtls";
832                         ConfigTag* tag = ServerInstance->Config->ConfValue(defname);
833                         ServerInstance->Logs->Log(MODNAME, LOG_DEFAULT, "No <sslprofile> tags found; using settings from the <mbedtls> tag");
834
835                         try
836                         {
837                                 reference<mbedTLS::Profile> profile(mbedTLS::Profile::Create(defname, tag, ctr_drbg));
838                                 newprofiles.push_back(new mbedTLSIOHookProvider(this, profile));
839                         }
840                         catch (CoreException& ex)
841                         {
842                                 throw ModuleException("Error while initializing the default SSL profile - " + ex.GetReason());
843                         }
844                 }
845
846                 for (ConfigIter i = tags.first; i != tags.second; ++i)
847                 {
848                         ConfigTag* tag = i->second;
849                         if (tag->getString("provider") != "mbedtls")
850                                 continue;
851
852                         std::string name = tag->getString("name");
853                         if (name.empty())
854                         {
855                                 ServerInstance->Logs->Log(MODNAME, LOG_DEFAULT, "Ignoring <sslprofile> tag without name at " + tag->getTagLocation());
856                                 continue;
857                         }
858
859                         reference<mbedTLS::Profile> profile;
860                         try
861                         {
862                                 profile = mbedTLS::Profile::Create(name, tag, ctr_drbg);
863                         }
864                         catch (CoreException& ex)
865                         {
866                                 throw ModuleException("Error while initializing SSL profile \"" + name + "\" at " + tag->getTagLocation() + " - " + ex.GetReason());
867                         }
868
869                         newprofiles.push_back(new mbedTLSIOHookProvider(this, profile));
870                 }
871
872                 // New profiles are ok, begin using them
873                 // Old profiles are deleted when their refcount drops to zero
874                 profiles.swap(newprofiles);
875         }
876
877  public:
878         void init() CXX11_OVERRIDE
879         {
880                 char verbuf[16]; // Should be at least 9 bytes in size
881                 mbedtls_version_get_string(verbuf);
882                 ServerInstance->Logs->Log(MODNAME, LOG_DEFAULT, "mbedTLS lib version %s module was compiled for " MBEDTLS_VERSION_STRING, verbuf);
883
884                 if (!ctr_drbg.Seed(entropy))
885                         throw ModuleException("CTR DRBG seed failed");
886                 ReadProfiles();
887         }
888
889         void OnModuleRehash(User* user, const std::string &param) CXX11_OVERRIDE
890         {
891                 if (param != "ssl")
892                         return;
893
894                 try
895                 {
896                         ReadProfiles();
897                 }
898                 catch (ModuleException& ex)
899                 {
900                         ServerInstance->Logs->Log(MODNAME, LOG_DEFAULT, ex.GetReason() + " Not applying settings.");
901                 }
902         }
903
904         void OnCleanup(int target_type, void* item) CXX11_OVERRIDE
905         {
906                 if (target_type != TYPE_USER)
907                         return;
908
909                 LocalUser* user = IS_LOCAL(static_cast<User*>(item));
910                 if ((user) && (user->eh.GetModHook(this)))
911                 {
912                         // User is using SSL, they're a local user, and they're using our IOHook.
913                         // Potentially there could be multiple SSL modules loaded at once on different ports.
914                         ServerInstance->Users.QuitUser(user, "SSL module unloading");
915                 }
916         }
917
918         ModResult OnCheckReady(LocalUser* user) CXX11_OVERRIDE
919         {
920                 const mbedTLSIOHook* const iohook = static_cast<mbedTLSIOHook*>(user->eh.GetModHook(this));
921                 if ((iohook) && (!iohook->IsHandshakeDone()))
922                         return MOD_RES_DENY;
923                 return MOD_RES_PASSTHRU;
924         }
925
926         Version GetVersion() CXX11_OVERRIDE
927         {
928                 return Version("Provides SSL support via mbedTLS (PolarSSL)", VF_VENDOR);
929         }
930 };
931
// Declares ModuleSSLmbedTLS as the module class provided by this file
// (MODULE_INIT is an InspIRCd macro defined elsewhere; presumably it emits the
// entry point the module loader uses to instantiate the class -- confirm).
MODULE_INIT(ModuleSSLmbedTLS)