]> git.netwichtig.de Git - user/henk/code/inspircd.git/blob - src/modules/extra/m_ssl_mbedtls.cpp
Merge branch 'insp20' into master.
[user/henk/code/inspircd.git] / src / modules / extra / m_ssl_mbedtls.cpp
1 /*
2  * InspIRCd -- Internet Relay Chat Daemon
3  *
4  *   Copyright (C) 2016 Attila Molnar <attilamolnar@hush.com>
5  *
6  * This file is part of InspIRCd.  InspIRCd is free software: you can
7  * redistribute it and/or modify it under the terms of the GNU General Public
8  * License as published by the Free Software Foundation, version 2.
9  *
10  * This program is distributed in the hope that it will be useful, but WITHOUT
11  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
12  * FOR A PARTICULAR PURPOSE.  See the GNU General Public License for more
13  * details.
14  *
15  * You should have received a copy of the GNU General Public License
16  * along with this program.  If not, see <http://www.gnu.org/licenses/>.
17  */
18
19 /// $LinkerFlags: -lmbedtls
20
21 /// $PackageInfo: require_system("darwin") mbedtls
22 /// $PackageInfo: require_system("debian" "9.0") libmbedtls-dev
23 /// $PackageInfo: require_system("ubuntu" "16.04") libmbedtls-dev
24
25
26 #include "inspircd.h"
27 #include "modules/ssl.h"
28
29 #include <mbedtls/ctr_drbg.h>
30 #include <mbedtls/dhm.h>
31 #include <mbedtls/ecp.h>
32 #include <mbedtls/entropy.h>
33 #include <mbedtls/error.h>
34 #include <mbedtls/md.h>
35 #include <mbedtls/pk.h>
36 #include <mbedtls/ssl.h>
37 #include <mbedtls/ssl_ciphersuites.h>
38 #include <mbedtls/version.h>
39 #include <mbedtls/x509.h>
40 #include <mbedtls/x509_crt.h>
41 #include <mbedtls/x509_crl.h>
42
43 #ifdef INSPIRCD_MBEDTLS_LIBRARY_DEBUG
44 #include <mbedtls/debug.h>
45 #endif
46
47 namespace mbedTLS
48 {
49         class Exception : public ModuleException
50         {
51          public:
52                 Exception(const std::string& reason)
53                         : ModuleException(reason) { }
54         };
55
56         std::string ErrorToString(int errcode)
57         {
58                 char buf[256];
59                 mbedtls_strerror(errcode, buf, sizeof(buf));
60                 return buf;
61         }
62
63         void ThrowOnError(int errcode, const char* msg)
64         {
65                 if (errcode != 0)
66                 {
67                         std::string reason = msg;
68                         reason.append(" :").append(ErrorToString(errcode));
69                         throw Exception(reason);
70                 }
71         }
72
73         template <typename T, void (*init)(T*), void (*deinit)(T*)>
74         class RAIIObj
75         {
76                 T obj;
77
78          public:
79                 RAIIObj()
80                 {
81                         init(&obj);
82                 }
83
84                 ~RAIIObj()
85                 {
86                         deinit(&obj);
87                 }
88
89                 T* get() { return &obj; }
90                 const T* get() const { return &obj; }
91         };
92
93         typedef RAIIObj<mbedtls_entropy_context, mbedtls_entropy_init, mbedtls_entropy_free> Entropy;
94
95         class CTRDRBG : private RAIIObj<mbedtls_ctr_drbg_context, mbedtls_ctr_drbg_init, mbedtls_ctr_drbg_free>
96         {
97          public:
98                 bool Seed(Entropy& entropy)
99                 {
100                         return (mbedtls_ctr_drbg_seed(get(), mbedtls_entropy_func, entropy.get(), NULL, 0) == 0);
101                 }
102
103                 void SetupConf(mbedtls_ssl_config* conf)
104                 {
105                         mbedtls_ssl_conf_rng(conf, mbedtls_ctr_drbg_random, get());
106                 }
107         };
108
109         class DHParams : public RAIIObj<mbedtls_dhm_context, mbedtls_dhm_init, mbedtls_dhm_free>
110         {
111          public:
112                 void set(const std::string& dhstr)
113                 {
114                         // Last parameter is buffer size, must include the terminating null
115                         int ret = mbedtls_dhm_parse_dhm(get(), reinterpret_cast<const unsigned char*>(dhstr.c_str()), dhstr.size()+1);
116                         ThrowOnError(ret, "Unable to import DH params");
117                 }
118         };
119
120         class X509Key : public RAIIObj<mbedtls_pk_context, mbedtls_pk_init, mbedtls_pk_free>
121         {
122          public:
123                 /** Import */
124                 X509Key(const std::string& keystr)
125                 {
126                         int ret = mbedtls_pk_parse_key(get(), reinterpret_cast<const unsigned char*>(keystr.c_str()), keystr.size()+1, NULL, 0);
127                         ThrowOnError(ret, "Unable to import private key");
128                 }
129         };
130
131         class Ciphersuites
132         {
133                 std::vector<int> list;
134
135          public:
136                 Ciphersuites(const std::string& str)
137                 {
138                         // mbedTLS uses the ciphersuite format "TLS-ECDHE-RSA-WITH-AES-128-GCM-SHA256" internally.
139                         // This is a bit verbose, so we make life a bit simpler for admins by not requiring them to supply the static parts.
140                         irc::sepstream ss(str, ':');
141                         for (std::string token; ss.GetToken(token); )
142                         {
143                                 // Prepend "TLS-" if not there
144                                 if (token.compare(0, 4, "TLS-", 4))
145                                         token.insert(0, "TLS-");
146
147                                 const int id = mbedtls_ssl_get_ciphersuite_id(token.c_str());
148                                 if (!id)
149                                         throw Exception("Unknown ciphersuite " + token);
150                                 list.push_back(id);
151                         }
152                         list.push_back(0);
153                 }
154
155                 const int* get() const { return &list.front(); }
156                 bool empty() const { return (list.size() <= 1); }
157         };
158
159         class Curves
160         {
161                 std::vector<mbedtls_ecp_group_id> list;
162
163          public:
164                 Curves(const std::string& str)
165                 {
166                         irc::sepstream ss(str, ':');
167                         for (std::string token; ss.GetToken(token); )
168                         {
169                                 const mbedtls_ecp_curve_info* curve = mbedtls_ecp_curve_info_from_name(token.c_str());
170                                 if (!curve)
171                                         throw Exception("Unknown curve " + token);
172                                 list.push_back(curve->grp_id);
173                         }
174                         list.push_back(MBEDTLS_ECP_DP_NONE);
175                 }
176
177                 const mbedtls_ecp_group_id* get() const { return &list.front(); }
178                 bool empty() const { return (list.size() <= 1); }
179         };
180
181         class X509CertList : public RAIIObj<mbedtls_x509_crt, mbedtls_x509_crt_init, mbedtls_x509_crt_free>
182         {
183          public:
184                 /** Import or create empty */
185                 X509CertList(const std::string& certstr, bool allowempty = false)
186                 {
187                         if ((allowempty) && (certstr.empty()))
188                                 return;
189                         int ret = mbedtls_x509_crt_parse(get(), reinterpret_cast<const unsigned char*>(certstr.c_str()), certstr.size()+1);
190                         ThrowOnError(ret, "Unable to load certificates");
191                 }
192
193                 bool empty() const { return (get()->raw.p != NULL); }
194         };
195
196         class X509CRL : public RAIIObj<mbedtls_x509_crl, mbedtls_x509_crl_init, mbedtls_x509_crl_free>
197         {
198          public:
199                 X509CRL(const std::string& crlstr)
200                 {
201                         if (crlstr.empty())
202                                 return;
203                         int ret = mbedtls_x509_crl_parse(get(), reinterpret_cast<const unsigned char*>(crlstr.c_str()), crlstr.size()+1);
204                         ThrowOnError(ret, "Unable to load CRL");
205                 }
206         };
207
208         class X509Credentials
209         {
210                 /** Private key
211                  */
212                 X509Key key;
213
214                 /** Certificate list, presented to the peer
215                  */
216                 X509CertList certs;
217
218          public:
219                 X509Credentials(const std::string& certstr, const std::string& keystr)
220                         : key(keystr)
221                         , certs(certstr)
222                 {
223                         // Verify that one of the certs match the private key
224                         bool found = false;
225                         for (mbedtls_x509_crt* cert = certs.get(); cert; cert = cert->next)
226                         {
227                                 if (mbedtls_pk_check_pair(&cert->pk, key.get()) == 0)
228                                 {
229                                         found = true;
230                                         break;
231                                 }
232                         }
233                         if (!found)
234                                 throw Exception("Public/private key pair does not match");
235                 }
236
237                 mbedtls_pk_context* getkey() { return key.get(); }
238                 mbedtls_x509_crt* getcerts() { return certs.get(); }
239         };
240
241         class Context
242         {
243                 mbedtls_ssl_config conf;
244
245 #ifdef INSPIRCD_MBEDTLS_LIBRARY_DEBUG
246                 static void DebugLogFunc(void* userptr, int level, const char* file, int line, const char* msg)
247                 {
248                         // Remove trailing \n
249                         size_t len = strlen(msg);
250                         if ((len > 0) && (msg[len-1] == '\n'))
251                                 len--;
252                         ServerInstance->Logs->Log(MODNAME, LOG_DEBUG, "%s:%d %.*s", file, line, len, msg);
253                 }
254 #endif
255
256          public:
257                 Context(CTRDRBG& ctrdrbg, unsigned int endpoint)
258                 {
259                         mbedtls_ssl_config_init(&conf);
260 #ifdef INSPIRCD_MBEDTLS_LIBRARY_DEBUG
261                         mbedtls_debug_set_threshold(INT_MAX);
262                         mbedtls_ssl_conf_dbg(&conf, DebugLogFunc, NULL);
263 #endif
264
265                         // TODO: check ret of mbedtls_ssl_config_defaults
266                         mbedtls_ssl_config_defaults(&conf, endpoint, MBEDTLS_SSL_TRANSPORT_STREAM, MBEDTLS_SSL_PRESET_DEFAULT);
267                         ctrdrbg.SetupConf(&conf);
268                 }
269
270                 ~Context()
271                 {
272                         mbedtls_ssl_config_free(&conf);
273                 }
274
275                 void SetMinDHBits(unsigned int mindh)
276                 {
277                         mbedtls_ssl_conf_dhm_min_bitlen(&conf, mindh);
278                 }
279
280                 void SetDHParams(DHParams& dh)
281                 {
282                         mbedtls_ssl_conf_dh_param_ctx(&conf, dh.get());
283                 }
284
285                 void SetX509CertAndKey(X509Credentials& x509cred)
286                 {
287                         mbedtls_ssl_conf_own_cert(&conf, x509cred.getcerts(), x509cred.getkey());
288                 }
289
290                 void SetCiphersuites(const Ciphersuites& ciphersuites)
291                 {
292                         mbedtls_ssl_conf_ciphersuites(&conf, ciphersuites.get());
293                 }
294
295                 void SetCurves(const Curves& curves)
296                 {
297                         mbedtls_ssl_conf_curves(&conf, curves.get());
298                 }
299
300                 void SetVersion(int minver, int maxver)
301                 {
302                         // SSL v3 support cannot be enabled
303                         if (minver)
304                                 mbedtls_ssl_conf_min_version(&conf, MBEDTLS_SSL_MAJOR_VERSION_3, minver);
305                         if (maxver)
306                                 mbedtls_ssl_conf_max_version(&conf, MBEDTLS_SSL_MAJOR_VERSION_3, maxver);
307                 }
308
309                 void SetCA(X509CertList& certs, X509CRL& crl)
310                 {
311                         mbedtls_ssl_conf_ca_chain(&conf, certs.get(), crl.get());
312                 }
313
314                 void SetOptionalVerifyCert()
315                 {
316                         mbedtls_ssl_conf_authmode(&conf, MBEDTLS_SSL_VERIFY_OPTIONAL);
317                 }
318
319                 const mbedtls_ssl_config* GetConf() const { return &conf; }
320         };
321
322         class Hash
323         {
324                 const mbedtls_md_info_t* md;
325
326                 /** Buffer where cert hashes are written temporarily
327                  */
328                 mutable std::vector<unsigned char> buf;
329
330          public:
331                 Hash(std::string hashstr)
332                 {
333                         std::transform(hashstr.begin(), hashstr.end(), hashstr.begin(), ::toupper);
334                         md = mbedtls_md_info_from_string(hashstr.c_str());
335                         if (!md)
336                                 throw Exception("Unknown hash: " + hashstr);
337
338                         buf.resize(mbedtls_md_get_size(md));
339                 }
340
341                 std::string hash(const unsigned char* input, size_t length) const
342                 {
343                         mbedtls_md(md, input, length, &buf.front());
344                         return BinToHex(&buf.front(), buf.size());
345                 }
346         };
347
348         class Profile
349         {
350                 /** Name of this profile
351                  */
352                 const std::string name;
353
354                 X509Credentials x509cred;
355
356                 /** Ciphersuites to use
357                  */
358                 Ciphersuites ciphersuites;
359
360                 /** Curves accepted for use in ECDHE and in the peer's end-entity certificate
361                  */
362                 Curves curves;
363
364                 Context serverctx;
365                 Context clientctx;
366
367                 DHParams dhparams;
368
369                 X509CertList cacerts;
370
371                 X509CRL crl;
372
373                 /** Hashing algorithm to use when generating certificate fingerprints
374                  */
375                 Hash hash;
376
377                 /** Rough max size of records to send
378                  */
379                 const unsigned int outrecsize;
380
381          public:
382                 struct Config
383                 {
384                         const std::string name;
385
386                         CTRDRBG& ctrdrbg;
387
388                         const std::string certstr;
389                         const std::string keystr;
390                         const std::string dhstr;
391
392                         const std::string ciphersuitestr;
393                         const std::string curvestr;
394                         const unsigned int mindh;
395                         const std::string hashstr;
396
397                         std::string crlstr;
398                         std::string castr;
399
400                         const int minver;
401                         const int maxver;
402                         const unsigned int outrecsize;
403                         const bool requestclientcert;
404
405                         Config(const std::string& profilename, ConfigTag* tag, CTRDRBG& ctr_drbg)
406                                 : name(profilename)
407                                 , ctrdrbg(ctr_drbg)
408                                 , certstr(ReadFile(tag->getString("certfile", "cert.pem")))
409                                 , keystr(ReadFile(tag->getString("keyfile", "key.pem")))
410                                 , dhstr(ReadFile(tag->getString("dhfile", "dhparams.pem")))
411                                 , ciphersuitestr(tag->getString("ciphersuites"))
412                                 , curvestr(tag->getString("curves"))
413                                 , mindh(tag->getUInt("mindhbits", 2048))
414                                 , hashstr(tag->getString("hash", "sha256"))
415                                 , castr(tag->getString("cafile"))
416                                 , minver(tag->getUInt("minver", 0))
417                                 , maxver(tag->getUInt("maxver", 0))
418                                 , outrecsize(tag->getUInt("outrecsize", 2048, 512, 16384))
419                                 , requestclientcert(tag->getBool("requestclientcert", true))
420                         {
421                                 if (!castr.empty())
422                                 {
423                                         castr = ReadFile(castr);
424                                         crlstr = tag->getString("crlfile");
425                                         if (!crlstr.empty())
426                                                 crlstr = ReadFile(crlstr);
427                                 }
428                         }
429                 };
430
431                 Profile(Config& config)
432                         : name(config.name)
433                         , x509cred(config.certstr, config.keystr)
434                         , ciphersuites(config.ciphersuitestr)
435                         , curves(config.curvestr)
436                         , serverctx(config.ctrdrbg, MBEDTLS_SSL_IS_SERVER)
437                         , clientctx(config.ctrdrbg, MBEDTLS_SSL_IS_CLIENT)
438                         , cacerts(config.castr, true)
439                         , crl(config.crlstr)
440                         , hash(config.hashstr)
441                         , outrecsize(config.outrecsize)
442                 {
443                         serverctx.SetX509CertAndKey(x509cred);
444                         clientctx.SetX509CertAndKey(x509cred);
445                         clientctx.SetMinDHBits(config.mindh);
446
447                         if (!ciphersuites.empty())
448                         {
449                                 serverctx.SetCiphersuites(ciphersuites);
450                                 clientctx.SetCiphersuites(ciphersuites);
451                         }
452
453                         if (!curves.empty())
454                         {
455                                 serverctx.SetCurves(curves);
456                                 clientctx.SetCurves(curves);
457                         }
458
459                         serverctx.SetVersion(config.minver, config.maxver);
460                         clientctx.SetVersion(config.minver, config.maxver);
461
462                         if (!config.dhstr.empty())
463                         {
464                                 dhparams.set(config.dhstr);
465                                 serverctx.SetDHParams(dhparams);
466                         }
467
468                         clientctx.SetOptionalVerifyCert();
469                         clientctx.SetCA(cacerts, crl);
470                         // The default for servers is to not request a client certificate from the peer
471                         if (config.requestclientcert)
472                         {
473                                 serverctx.SetOptionalVerifyCert();
474                                 serverctx.SetCA(cacerts, crl);
475                         }
476                 }
477
478                 static std::string ReadFile(const std::string& filename)
479                 {
480                         FileReader reader(filename);
481                         std::string ret = reader.GetString();
482                         if (ret.empty())
483                                 throw Exception("Cannot read file " + filename);
484                         return ret;
485                 }
486
487                 /** Set up the given session with the settings in this profile
488                  */
489                 void SetupClientSession(mbedtls_ssl_context* sess)
490                 {
491                         mbedtls_ssl_setup(sess, clientctx.GetConf());
492                 }
493
494                 void SetupServerSession(mbedtls_ssl_context* sess)
495                 {
496                         mbedtls_ssl_setup(sess, serverctx.GetConf());
497                 }
498
499                 const std::string& GetName() const { return name; }
500                 X509Credentials& GetX509Credentials() { return x509cred; }
501                 unsigned int GetOutgoingRecordSize() const { return outrecsize; }
502                 const Hash& GetHash() const { return hash; }
503         };
504 }
505
506 class mbedTLSIOHook : public SSLIOHook
507 {
508         enum Status
509         {
510                 ISSL_NONE,
511                 ISSL_HANDSHAKING,
512                 ISSL_HANDSHAKEN
513         };
514
515         mbedtls_ssl_context sess;
516         Status status;
517
518         void CloseSession()
519         {
520                 if (status == ISSL_NONE)
521                         return;
522
523                 mbedtls_ssl_close_notify(&sess);
524                 mbedtls_ssl_free(&sess);
525                 certificate = NULL;
526                 status = ISSL_NONE;
527         }
528
529         // Returns 1 if handshake succeeded, 0 if it is still in progress, -1 if it failed
530         int Handshake(StreamSocket* sock)
531         {
532                 int ret = mbedtls_ssl_handshake(&sess);
533                 if (ret == 0)
534                 {
535                         // Change the seesion state
536                         this->status = ISSL_HANDSHAKEN;
537
538                         VerifyCertificate();
539
540                         // Finish writing, if any left
541                         SocketEngine::ChangeEventMask(sock, FD_WANT_POLL_READ | FD_WANT_NO_WRITE | FD_ADD_TRIAL_WRITE);
542
543                         return 1;
544                 }
545
546                 this->status = ISSL_HANDSHAKING;
547                 if (ret == MBEDTLS_ERR_SSL_WANT_READ)
548                 {
549                         SocketEngine::ChangeEventMask(sock, FD_WANT_POLL_READ | FD_WANT_NO_WRITE);
550                         return 0;
551                 }
552                 else if (ret == MBEDTLS_ERR_SSL_WANT_WRITE)
553                 {
554                         SocketEngine::ChangeEventMask(sock, FD_WANT_NO_READ | FD_WANT_SINGLE_WRITE);
555                         return 0;
556                 }
557
558                 sock->SetError("Handshake Failed - " + mbedTLS::ErrorToString(ret));
559                 CloseSession();
560                 return -1;
561         }
562
563         // Returns 1 if application I/O should proceed, 0 if it must wait for the underlying protocol to progress, -1 on fatal error
564         int PrepareIO(StreamSocket* sock)
565         {
566                 if (status == ISSL_HANDSHAKEN)
567                         return 1;
568                 else if (status == ISSL_HANDSHAKING)
569                 {
570                         // The handshake isn't finished, try to finish it
571                         return Handshake(sock);
572                 }
573
574                 CloseSession();
575                 sock->SetError("No SSL session");
576                 return -1;
577         }
578
579         void VerifyCertificate()
580         {
581                 this->certificate = new ssl_cert;
582                 const mbedtls_x509_crt* const cert = mbedtls_ssl_get_peer_cert(&sess);
583                 if (!cert)
584                 {
585                         certificate->error = "No client certificate sent";
586                         return;
587                 }
588
589                 // If there is a certificate we can always generate a fingerprint
590                 certificate->fingerprint = GetProfile().GetHash().hash(cert->raw.p, cert->raw.len);
591
592                 // At this point mbedTLS verified the cert already, we just need to check the results
593                 const uint32_t flags = mbedtls_ssl_get_verify_result(&sess);
594                 if (flags == 0xFFFFFFFF)
595                 {
596                         certificate->error = "Internal error during verification";
597                         return;
598                 }
599
600                 if (flags == 0)
601                 {
602                         // Verification succeeded
603                         certificate->trusted = true;
604                 }
605                 else
606                 {
607                         // Verification failed
608                         certificate->trusted = false;
609                         if ((flags & MBEDTLS_X509_BADCERT_EXPIRED) || (flags & MBEDTLS_X509_BADCERT_FUTURE))
610                                 certificate->error = "Not activated, or expired certificate";
611                 }
612
613                 certificate->unknownsigner = (flags & MBEDTLS_X509_BADCERT_NOT_TRUSTED);
614                 certificate->revoked = (flags & MBEDTLS_X509_BADCERT_REVOKED);
615                 certificate->invalid = ((flags & MBEDTLS_X509_BADCERT_BAD_KEY) || (flags & MBEDTLS_X509_BADCERT_BAD_MD) || (flags & MBEDTLS_X509_BADCERT_BAD_PK));
616
617                 GetDNString(&cert->subject, certificate->dn);
618                 GetDNString(&cert->issuer, certificate->issuer);
619         }
620
621         static void GetDNString(const mbedtls_x509_name* x509name, std::string& out)
622         {
623                 char buf[512];
624                 const int ret = mbedtls_x509_dn_gets(buf, sizeof(buf), x509name);
625                 if (ret <= 0)
626                         return;
627
628                 out.assign(buf, ret);
629         }
630
631         static int Pull(void* userptr, unsigned char* buffer, size_t size)
632         {
633                 StreamSocket* const sock = reinterpret_cast<StreamSocket*>(userptr);
634                 if (sock->GetEventMask() & FD_READ_WILL_BLOCK)
635                         return MBEDTLS_ERR_SSL_WANT_READ;
636
637                 const int ret = SocketEngine::Recv(sock, reinterpret_cast<char*>(buffer), size, 0);
638                 if (ret < (int)size)
639                 {
640                         SocketEngine::ChangeEventMask(sock, FD_READ_WILL_BLOCK);
641                         if ((ret == -1) && (SocketEngine::IgnoreError()))
642                                 return MBEDTLS_ERR_SSL_WANT_READ;
643                 }
644                 return ret;
645         }
646
647         static int Push(void* userptr, const unsigned char* buffer, size_t size)
648         {
649                 StreamSocket* const sock = reinterpret_cast<StreamSocket*>(userptr);
650                 if (sock->GetEventMask() & FD_WRITE_WILL_BLOCK)
651                         return MBEDTLS_ERR_SSL_WANT_WRITE;
652
653                 const int ret = SocketEngine::Send(sock, buffer, size, 0);
654                 if (ret < (int)size)
655                 {
656                         SocketEngine::ChangeEventMask(sock, FD_WRITE_WILL_BLOCK);
657                         if ((ret == -1) && (SocketEngine::IgnoreError()))
658                                 return MBEDTLS_ERR_SSL_WANT_WRITE;
659                 }
660                 return ret;
661         }
662
663  public:
664         mbedTLSIOHook(IOHookProvider* hookprov, StreamSocket* sock, bool isserver)
665                 : SSLIOHook(hookprov)
666                 , status(ISSL_NONE)
667         {
668                 mbedtls_ssl_init(&sess);
669                 if (isserver)
670                         GetProfile().SetupServerSession(&sess);
671                 else
672                         GetProfile().SetupClientSession(&sess);
673
674                 mbedtls_ssl_set_bio(&sess, reinterpret_cast<void*>(sock), Push, Pull, NULL);
675
676                 sock->AddIOHook(this);
677                 Handshake(sock);
678         }
679
680         void OnStreamSocketClose(StreamSocket* sock) CXX11_OVERRIDE
681         {
682                 CloseSession();
683         }
684
685         int OnStreamSocketRead(StreamSocket* sock, std::string& recvq) CXX11_OVERRIDE
686         {
687                 // Finish handshake if needed
688                 int prepret = PrepareIO(sock);
689                 if (prepret <= 0)
690                         return prepret;
691
692                 // If we resumed the handshake then this->status will be ISSL_HANDSHAKEN.
693                 char* const readbuf = ServerInstance->GetReadBuffer();
694                 const size_t readbufsize = ServerInstance->Config->NetBufferSize;
695                 int ret = mbedtls_ssl_read(&sess, reinterpret_cast<unsigned char*>(readbuf), readbufsize);
696                 if (ret > 0)
697                 {
698                         recvq.append(readbuf, ret);
699
700                         // Schedule a read if there is still data in the mbedTLS buffer
701                         if (mbedtls_ssl_get_bytes_avail(&sess) > 0)
702                                 SocketEngine::ChangeEventMask(sock, FD_ADD_TRIAL_READ);
703                         return 1;
704                 }
705                 else if (ret == MBEDTLS_ERR_SSL_WANT_READ)
706                 {
707                         SocketEngine::ChangeEventMask(sock, FD_WANT_POLL_READ);
708                         return 0;
709                 }
710                 else if (ret == MBEDTLS_ERR_SSL_WANT_WRITE)
711                 {
712                         SocketEngine::ChangeEventMask(sock, FD_WANT_NO_READ | FD_WANT_SINGLE_WRITE);
713                         return 0;
714                 }
715                 else if (ret == 0)
716                 {
717                         sock->SetError("Connection closed");
718                         CloseSession();
719                         return -1;
720                 }
721                 else // error or MBEDTLS_ERR_SSL_CLIENT_RECONNECT which we treat as an error
722                 {
723                         sock->SetError(mbedTLS::ErrorToString(ret));
724                         CloseSession();
725                         return -1;
726                 }
727         }
728
729         int OnStreamSocketWrite(StreamSocket* sock, StreamSocket::SendQueue& sendq) CXX11_OVERRIDE
730         {
731                 // Finish handshake if needed
732                 int prepret = PrepareIO(sock);
733                 if (prepret <= 0)
734                         return prepret;
735
736                 // Session is ready for transferring application data
737                 while (!sendq.empty())
738                 {
739                         FlattenSendQueue(sendq, GetProfile().GetOutgoingRecordSize());
740                         const StreamSocket::SendQueue::Element& buffer = sendq.front();
741                         int ret = mbedtls_ssl_write(&sess, reinterpret_cast<const unsigned char*>(buffer.data()), buffer.length());
742                         if (ret == (int)buffer.length())
743                         {
744                                 // Wrote entire record, continue sending
745                                 sendq.pop_front();
746                         }
747                         else if (ret > 0)
748                         {
749                                 sendq.erase_front(ret);
750                                 SocketEngine::ChangeEventMask(sock, FD_WANT_SINGLE_WRITE);
751                                 return 0;
752                         }
753                         else if (ret == 0)
754                         {
755                                 sock->SetError("Connection closed");
756                                 CloseSession();
757                                 return -1;
758                         }
759                         else if (ret == MBEDTLS_ERR_SSL_WANT_WRITE)
760                         {
761                                 SocketEngine::ChangeEventMask(sock, FD_WANT_SINGLE_WRITE);
762                                 return 0;
763                         }
764                         else if (ret == MBEDTLS_ERR_SSL_WANT_READ)
765                         {
766                                 SocketEngine::ChangeEventMask(sock, FD_WANT_POLL_READ);
767                                 return 0;
768                         }
769                         else
770                         {
771                                 sock->SetError(mbedTLS::ErrorToString(ret));
772                                 CloseSession();
773                                 return -1;
774                         }
775                 }
776
777                 SocketEngine::ChangeEventMask(sock, FD_WANT_NO_WRITE);
778                 return 1;
779         }
780
781         void GetCiphersuite(std::string& out) const CXX11_OVERRIDE
782         {
783                 if (!IsHandshakeDone())
784                         return;
785                 out.append(mbedtls_ssl_get_version(&sess)).push_back('-');
786
787                 // All mbedTLS ciphersuite names currently begin with "TLS-" which provides no useful information so skip it, but be prepared if it changes
788                 const char* const ciphersuitestr = mbedtls_ssl_get_ciphersuite(&sess);
789                 const char prefix[] = "TLS-";
790                 unsigned int skip = sizeof(prefix)-1;
791                 if (strncmp(ciphersuitestr, prefix, sizeof(prefix)-1))
792                         skip = 0;
793                 out.append(ciphersuitestr + skip);
794         }
795
796         bool GetServerName(std::string& out) const CXX11_OVERRIDE
797         {
798                 // TODO: Implement SNI support.
799                 return false;
800         }
801
802         mbedTLS::Profile& GetProfile();
803         bool IsHandshakeDone() const { return (status == ISSL_HANDSHAKEN); }
804 };
805
806 class mbedTLSIOHookProvider : public IOHookProvider
807 {
808         mbedTLS::Profile profile;
809
810  public:
811         mbedTLSIOHookProvider(Module* mod, mbedTLS::Profile::Config& config)
812                 : IOHookProvider(mod, "ssl/" + config.name, IOHookProvider::IOH_SSL)
813                 , profile(config)
814         {
815                 ServerInstance->Modules->AddService(*this);
816         }
817
818         ~mbedTLSIOHookProvider()
819         {
820                 ServerInstance->Modules->DelService(*this);
821         }
822
823         void OnAccept(StreamSocket* sock, irc::sockets::sockaddrs* client, irc::sockets::sockaddrs* server) CXX11_OVERRIDE
824         {
825                 new mbedTLSIOHook(this, sock, true);
826         }
827
828         void OnConnect(StreamSocket* sock) CXX11_OVERRIDE
829         {
830                 new mbedTLSIOHook(this, sock, false);
831         }
832
833         mbedTLS::Profile& GetProfile() { return profile; }
834 };
835
836 mbedTLS::Profile& mbedTLSIOHook::GetProfile()
837 {
838         IOHookProvider* hookprov = prov;
839         return static_cast<mbedTLSIOHookProvider*>(hookprov)->GetProfile();
840 }
841
842 class ModuleSSLmbedTLS : public Module
843 {
844         typedef std::vector<reference<mbedTLSIOHookProvider> > ProfileList;
845
846         mbedTLS::Entropy entropy;
847         mbedTLS::CTRDRBG ctr_drbg;
848         ProfileList profiles;
849
850         void ReadProfiles()
851         {
852                 // First, store all profiles in a new, temporary container. If no problems occur, swap the two
853                 // containers; this way if something goes wrong we can go back and continue using the current profiles,
854                 // avoiding unpleasant situations where no new SSL connections are possible.
855                 ProfileList newprofiles;
856
857                 ConfigTagList tags = ServerInstance->Config->ConfTags("sslprofile");
858                 if (tags.first == tags.second)
859                 {
860                         // No <sslprofile> tags found, create a profile named "mbedtls" from settings in the <mbedtls> block
861                         const std::string defname = "mbedtls";
862                         ConfigTag* tag = ServerInstance->Config->ConfValue(defname);
863                         ServerInstance->Logs->Log(MODNAME, LOG_DEFAULT, "No <sslprofile> tags found; using settings from the <mbedtls> tag");
864
865                         try
866                         {
867                                 mbedTLS::Profile::Config profileconfig(defname, tag, ctr_drbg);
868                                 newprofiles.push_back(new mbedTLSIOHookProvider(this, profileconfig));
869                         }
870                         catch (CoreException& ex)
871                         {
872                                 throw ModuleException("Error while initializing the default SSL profile - " + ex.GetReason());
873                         }
874                 }
875
876                 for (ConfigIter i = tags.first; i != tags.second; ++i)
877                 {
878                         ConfigTag* tag = i->second;
879                         if (tag->getString("provider") != "mbedtls")
880                                 continue;
881
882                         std::string name = tag->getString("name");
883                         if (name.empty())
884                         {
885                                 ServerInstance->Logs->Log(MODNAME, LOG_DEFAULT, "Ignoring <sslprofile> tag without name at " + tag->getTagLocation());
886                                 continue;
887                         }
888
889                         reference<mbedTLSIOHookProvider> prov;
890                         try
891                         {
892                                 mbedTLS::Profile::Config profileconfig(name, tag, ctr_drbg);
893                                 prov = new mbedTLSIOHookProvider(this, profileconfig);
894                         }
895                         catch (CoreException& ex)
896                         {
897                                 throw ModuleException("Error while initializing SSL profile \"" + name + "\" at " + tag->getTagLocation() + " - " + ex.GetReason());
898                         }
899
900                         newprofiles.push_back(prov);
901                 }
902
903                 // New profiles are ok, begin using them
904                 // Old profiles are deleted when their refcount drops to zero
905                 for (ProfileList::iterator i = profiles.begin(); i != profiles.end(); ++i)
906                 {
907                         mbedTLSIOHookProvider& prov = **i;
908                         ServerInstance->Modules.DelService(prov);
909                 }
910
911                 profiles.swap(newprofiles);
912         }
913
914  public:
915         void init() CXX11_OVERRIDE
916         {
917                 char verbuf[16]; // Should be at least 9 bytes in size
918                 mbedtls_version_get_string(verbuf);
919                 ServerInstance->Logs->Log(MODNAME, LOG_DEFAULT, "mbedTLS lib version %s module was compiled for " MBEDTLS_VERSION_STRING, verbuf);
920
921                 if (!ctr_drbg.Seed(entropy))
922                         throw ModuleException("CTR DRBG seed failed");
923                 ReadProfiles();
924         }
925
926         void OnModuleRehash(User* user, const std::string &param) CXX11_OVERRIDE
927         {
928                 if (param != "ssl")
929                         return;
930
931                 try
932                 {
933                         ReadProfiles();
934                 }
935                 catch (ModuleException& ex)
936                 {
937                         ServerInstance->Logs->Log(MODNAME, LOG_DEFAULT, ex.GetReason() + " Not applying settings.");
938                 }
939         }
940
941         void OnCleanup(ExtensionItem::ExtensibleType type, Extensible* item) CXX11_OVERRIDE
942         {
943                 if (type != ExtensionItem::EXT_USER)
944                         return;
945
946                 LocalUser* user = IS_LOCAL(static_cast<User*>(item));
947                 if ((user) && (user->eh.GetModHook(this)))
948                 {
949                         // User is using SSL, they're a local user, and they're using our IOHook.
950                         // Potentially there could be multiple SSL modules loaded at once on different ports.
951                         ServerInstance->Users.QuitUser(user, "SSL module unloading");
952                 }
953         }
954
955         ModResult OnCheckReady(LocalUser* user) CXX11_OVERRIDE
956         {
957                 const mbedTLSIOHook* const iohook = static_cast<mbedTLSIOHook*>(user->eh.GetModHook(this));
958                 if ((iohook) && (!iohook->IsHandshakeDone()))
959                         return MOD_RES_DENY;
960                 return MOD_RES_PASSTHRU;
961         }
962
963         Version GetVersion() CXX11_OVERRIDE
964         {
965                 return Version("Provides SSL support via mbedTLS (PolarSSL)", VF_VENDOR);
966         }
967 };
968
969 MODULE_INIT(ModuleSSLmbedTLS)