]> git.netwichtig.de Git - user/henk/code/inspircd.git/blob - src/modules/m_websocket.cpp
ee1c00e97fe7b0bfd7473efb9687a686eded5b64
[user/henk/code/inspircd.git] / src / modules / m_websocket.cpp
1 /*
2  * InspIRCd -- Internet Relay Chat Daemon
3  *
4  *   Copyright (C) 2016 Attila Molnar <attilamolnar@hush.com>
5  *
6  * This file is part of InspIRCd.  InspIRCd is free software: you can
7  * redistribute it and/or modify it under the terms of the GNU General Public
8  * License as published by the Free Software Foundation, version 2.
9  *
10  * This program is distributed in the hope that it will be useful, but WITHOUT
11  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
12  * FOR A PARTICULAR PURPOSE.  See the GNU General Public License for more
13  * details.
14  *
15  * You should have received a copy of the GNU General Public License
16  * along with this program.  If not, see <http://www.gnu.org/licenses/>.
17  */
18
19 /// $CompilerFlags: -Ivendor_directory("utfcpp")
20
21
22 #include "inspircd.h"
23 #include "iohook.h"
24 #include "modules/hash.h"
25
26 #include <utf8.h>
27
28 typedef std::vector<std::string> OriginList;
29
30 static const char MagicGUID[] = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
31 static const char whitespace[] = " \t\r\n";
32 static dynamic_reference_nocheck<HashProvider>* sha1;
33
34 struct WebSocketConfig
35 {
36         typedef std::vector<std::string> ProxyRanges;
37
38         // The HTTP origins that can connect to the server.
39         OriginList allowedorigins;
40
41         // The IP ranges which send trustworthy X-Real-IP or X-Forwarded-For headers.
42         ProxyRanges proxyranges;
43
44         // Whether to send as UTF-8 text instead of binary data.
45         bool sendastext;
46 };
47
48 class WebSocketHookProvider : public IOHookProvider
49 {
50  public:
51         WebSocketConfig config;
52         WebSocketHookProvider(Module* mod)
53                 : IOHookProvider(mod, "websocket", IOHookProvider::IOH_UNKNOWN, true)
54         {
55         }
56
57         void OnAccept(StreamSocket* sock, irc::sockets::sockaddrs* client, irc::sockets::sockaddrs* server) CXX11_OVERRIDE;
58
59         void OnConnect(StreamSocket* sock) CXX11_OVERRIDE
60         {
61         }
62 };
63
64 class WebSocketHook : public IOHookMiddle
65 {
66         class HTTPHeaderFinder
67         {
68                 std::string::size_type bpos;
69                 std::string::size_type len;
70
71          public:
72                 bool Find(const std::string& req, const char* header, std::string::size_type headerlen, std::string::size_type maxpos)
73                 {
74                         std::string::size_type keybegin = req.find(header);
75                         if ((keybegin == std::string::npos) || (keybegin > maxpos) || (keybegin == 0) || (req[keybegin-1] != '\n'))
76                                 return false;
77
78                         keybegin += headerlen;
79
80                         bpos = req.find_first_not_of(whitespace, keybegin, sizeof(whitespace)-1);
81                         if ((bpos == std::string::npos) || (bpos > maxpos))
82                                 return false;
83
84                         const std::string::size_type epos = req.find_first_of(whitespace, bpos, sizeof(whitespace)-1);
85                         len = epos - bpos;
86
87                         return true;
88                 }
89
90                 std::string ExtractValue(const std::string& req) const
91                 {
92                         return std::string(req, bpos, len);
93                 }
94         };
95
96         enum OpCode
97         {
98                 OP_CONTINUATION = 0x00,
99                 OP_TEXT = 0x01,
100                 OP_BINARY = 0x02,
101                 OP_CLOSE = 0x08,
102                 OP_PING = 0x09,
103                 OP_PONG = 0x0a
104         };
105
106         enum State
107         {
108                 STATE_HTTPREQ,
109                 STATE_ESTABLISHED
110         };
111
112         static const unsigned char WS_MASKBIT = (1 << 7);
113         static const unsigned char WS_FINBIT = (1 << 7);
114         static const unsigned char WS_PAYLOAD_LENGTH_MAGIC_LARGE = 126;
115         static const unsigned char WS_PAYLOAD_LENGTH_MAGIC_HUGE = 127;
116         static const size_t WS_MAX_PAYLOAD_LENGTH_SMALL = 125;
117         static const size_t WS_MAX_PAYLOAD_LENGTH_LARGE = 65535;
118         static const size_t MAXHEADERSIZE = sizeof(uint64_t) + 2;
119
120         // Clients sending ping or pong frames faster than this are killed
121         static const time_t MINPINGPONGDELAY = 10;
122
123         State state;
124         time_t lastpingpong;
125         WebSocketConfig& config;
126
127         static size_t FillHeader(unsigned char* outbuf, size_t sendlength, OpCode opcode)
128         {
129                 size_t pos = 0;
130                 outbuf[pos++] = WS_FINBIT | opcode;
131
132                 if (sendlength <= WS_MAX_PAYLOAD_LENGTH_SMALL)
133                 {
134                         outbuf[pos++] = sendlength;
135                 }
136                 else if (sendlength <= WS_MAX_PAYLOAD_LENGTH_LARGE)
137                 {
138                         outbuf[pos++] = WS_PAYLOAD_LENGTH_MAGIC_LARGE;
139                         outbuf[pos++] = (sendlength >> 8) & 0xff;
140                         outbuf[pos++] = sendlength & 0xff;
141                 }
142                 else
143                 {
144                         outbuf[pos++] = WS_PAYLOAD_LENGTH_MAGIC_HUGE;
145                         const uint64_t len = sendlength;
146                         for (int i = sizeof(uint64_t)-1; i >= 0; i--)
147                                 outbuf[pos++] = ((len >> i*8) & 0xff);
148                 }
149
150                 return pos;
151         }
152
153         static StreamSocket::SendQueue::Element PrepareSendQElem(size_t size, OpCode opcode)
154         {
155                 unsigned char header[MAXHEADERSIZE];
156                 const size_t n = FillHeader(header, size, opcode);
157
158                 return StreamSocket::SendQueue::Element(reinterpret_cast<const char*>(header), n);
159         }
160
161         int HandleAppData(StreamSocket* sock, std::string& appdataout, bool allowlarge)
162         {
163                 std::string& myrecvq = GetRecvQ();
164                 // Need 1 byte opcode, minimum 1 byte len, 4 bytes masking key
165                 if (myrecvq.length() < 6)
166                         return 0;
167
168                 const std::string& cmyrecvq = myrecvq;
169                 unsigned char len1 = (unsigned char)cmyrecvq[1];
170                 if (!(len1 & WS_MASKBIT))
171                 {
172                         sock->SetError("WebSocket protocol violation: unmasked client frame");
173                         return -1;
174                 }
175
176                 len1 &= ~WS_MASKBIT;
177
178                 // Assume the length is a single byte, if not, update values later
179                 unsigned int len = len1;
180                 unsigned int payloadstartoffset = 6;
181                 const unsigned char* maskkey = reinterpret_cast<const unsigned char*>(&cmyrecvq[2]);
182
183                 if (len1 == WS_PAYLOAD_LENGTH_MAGIC_LARGE)
184                 {
185                         // allowlarge is false for control frames according to the RFC meaning large pings, etc. are not allowed
186                         if (!allowlarge)
187                         {
188                                 sock->SetError("WebSocket protocol violation: large control frame");
189                                 return -1;
190                         }
191
192                         // Large frame, has 2 bytes len after the magic byte indicating the length
193                         // Need 1 byte opcode, 3 bytes len, 4 bytes masking key
194                         if (myrecvq.length() < 8)
195                                 return 0;
196
197                         unsigned char len2 = (unsigned char)cmyrecvq[2];
198                         unsigned char len3 = (unsigned char)cmyrecvq[3];
199                         len = (len2 << 8) | len3;
200
201                         if (len <= WS_MAX_PAYLOAD_LENGTH_SMALL)
202                         {
203                                 sock->SetError("WebSocket protocol violation: non-minimal length encoding used");
204                                 return -1;
205                         }
206
207                         maskkey += 2;
208                         payloadstartoffset += 2;
209                 }
210                 else if (len1 == WS_PAYLOAD_LENGTH_MAGIC_HUGE)
211                 {
212                         sock->SetError("WebSocket: Huge frames are not supported");
213                         return -1;
214                 }
215
216                 if (myrecvq.length() < payloadstartoffset + len)
217                         return 0;
218
219                 unsigned int maskkeypos = 0;
220                 const std::string::iterator endit = myrecvq.begin() + payloadstartoffset + len;
221                 for (std::string::const_iterator i = myrecvq.begin() + payloadstartoffset; i != endit; ++i)
222                 {
223                         const unsigned char c = (unsigned char)*i;
224                         appdataout.push_back(c ^ maskkey[maskkeypos++]);
225                         maskkeypos %= 4;
226                 }
227
228                 myrecvq.erase(myrecvq.begin(), endit);
229                 return 1;
230         }
231
232         int HandlePingPongFrame(StreamSocket* sock, bool isping)
233         {
234                 if (lastpingpong + MINPINGPONGDELAY >= ServerInstance->Time())
235                 {
236                         sock->SetError("WebSocket: Ping/pong flood");
237                         return -1;
238                 }
239
240                 lastpingpong = ServerInstance->Time();
241
242                 std::string appdata;
243                 const int result = HandleAppData(sock, appdata, false);
244                 // If it's a pong stop here regardless of the result so we won't generate a reply
245                 if ((result <= 0) || (!isping))
246                         return result;
247
248                 StreamSocket::SendQueue::Element elem = PrepareSendQElem(appdata.length(), OP_PONG);
249                 elem.append(appdata);
250                 GetSendQ().push_back(elem);
251
252                 SocketEngine::ChangeEventMask(sock, FD_ADD_TRIAL_WRITE);
253                 return 1;
254         }
255
256         int HandleWS(StreamSocket* sock, std::string& destrecvq)
257         {
258                 if (GetRecvQ().empty())
259                         return 0;
260
261                 unsigned char opcode = (unsigned char)GetRecvQ().c_str()[0];
262                 switch (opcode & ~WS_FINBIT)
263                 {
264                         case OP_CONTINUATION:
265                         case OP_TEXT:
266                         case OP_BINARY:
267                         {
268                                 std::string appdata;
269                                 const int result = HandleAppData(sock, appdata, true);
270                                 if (result != 1)
271                                         return result;
272
273                                 // Strip out any CR+LF which may have been erroneously sent.
274                                 for (std::string::const_iterator iter = appdata.begin(); iter != appdata.end(); ++iter)
275                                 {
276                                         if (*iter != '\r' && *iter != '\n')
277                                                 destrecvq.push_back(*iter);
278                                 }
279
280                                 // If we are on the final message of this block append a line terminator.
281                                 if (opcode & WS_FINBIT)
282                                         destrecvq.append("\r\n");
283
284                                 return 1;
285                         }
286
287                         case OP_PING:
288                         {
289                                 return HandlePingPongFrame(sock, true);
290                         }
291
292                         case OP_PONG:
293                         {
294                                 // A pong frame may be sent unsolicited, so we have to handle it.
295                                 // It may carry application data which we need to remove from the recvq as well.
296                                 return HandlePingPongFrame(sock, false);
297                         }
298
299                         case OP_CLOSE:
300                         {
301                                 sock->SetError("Connection closed");
302                                 return -1;
303                         }
304
305                         default:
306                         {
307                                 sock->SetError("WebSocket: Invalid opcode");
308                                 return -1;
309                         }
310                 }
311         }
312
313         void FailHandshake(StreamSocket* sock, const char* httpreply, const char* sockerror)
314         {
315                 GetSendQ().push_back(StreamSocket::SendQueue::Element(httpreply));
316                 sock->DoWrite();
317                 sock->SetError(sockerror);
318         }
319
320         int HandleHTTPReq(StreamSocket* sock)
321         {
322                 std::string& recvq = GetRecvQ();
323                 const std::string::size_type reqend = recvq.find("\r\n\r\n");
324                 if (reqend == std::string::npos)
325                         return 0;
326
327                 bool allowedorigin = false;
328                 HTTPHeaderFinder originheader;
329                 if (originheader.Find(recvq, "Origin:", 7, reqend))
330                 {
331                         const std::string origin = originheader.ExtractValue(recvq);
332                         for (OriginList::const_iterator iter = config.allowedorigins.begin(); iter != config.allowedorigins.end(); ++iter)
333                         {
334                                 if (InspIRCd::Match(origin, *iter, ascii_case_insensitive_map))
335                                 {
336                                         allowedorigin = true;
337                                         break;
338                                 }
339                         }
340                 }
341
342                 if (!allowedorigin)
343                 {
344                         FailHandshake(sock, "HTTP/1.1 403 Forbidden\r\nConnection: close\r\n\r\n", "WebSocket: Received HTTP request from a non-whitelisted origin");
345                         return -1;
346                 }
347
348                 if (!config.proxyranges.empty() && sock->type == StreamSocket::SS_USER)
349                 {
350                         LocalUser* luser = static_cast<UserIOHandler*>(sock)->user;
351                         irc::sockets::sockaddrs realsa(luser->client_sa);
352
353                         HTTPHeaderFinder proxyheader;
354                         if (proxyheader.Find(recvq, "X-Real-IP:", 10, reqend)
355                                 && irc::sockets::aptosa(proxyheader.ExtractValue(recvq), realsa.port(), realsa))
356                         {
357                                 // Nothing to do here.
358                         }
359                         else if (proxyheader.Find(recvq, "X-Forwarded-For:", 16, reqend)
360                                 && irc::sockets::aptosa(proxyheader.ExtractValue(recvq), realsa.port(), realsa))
361                         {
362                                 // Nothing to do here.
363                         }
364
365                         for (WebSocketConfig::ProxyRanges::const_iterator iter = config.proxyranges.begin(); iter != config.proxyranges.end(); ++iter)
366                         {
367                                 if (InspIRCd::MatchCIDR(*iter, luser->GetIPString(), ascii_case_insensitive_map))
368                                 {
369                                         // Give the user their real IP address.
370                                         if (realsa == luser->client_sa)
371                                                 luser->SetClientIP(realsa);
372                                         break;
373                                 }
374                         }
375                 }
376
377
378                 HTTPHeaderFinder keyheader;
379                 if (!keyheader.Find(recvq, "Sec-WebSocket-Key:", 18, reqend))
380                 {
381                         FailHandshake(sock, "HTTP/1.1 501 Not Implemented\r\nConnection: close\r\n\r\n", "WebSocket: Received HTTP request which is not a websocket upgrade");
382                         return -1;
383                 }
384
385                 if (!*sha1)
386                 {
387                         FailHandshake(sock, "HTTP/1.1 503 Service Unavailable\r\nConnection: close\r\n\r\n", "WebSocket: SHA-1 provider missing");
388                         return -1;
389                 }
390
391                 state = STATE_ESTABLISHED;
392
393                 std::string key = keyheader.ExtractValue(recvq);
394                 key.append(MagicGUID);
395
396                 std::string reply = "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: ";
397                 reply.append(BinToBase64((*sha1)->GenerateRaw(key), NULL, '=')).append("\r\n\r\n");
398                 GetSendQ().push_back(StreamSocket::SendQueue::Element(reply));
399
400                 SocketEngine::ChangeEventMask(sock, FD_ADD_TRIAL_WRITE);
401
402                 recvq.erase(0, reqend + 4);
403
404                 return 1;
405         }
406
407  public:
408         WebSocketHook(IOHookProvider* Prov, StreamSocket* sock, WebSocketConfig& cfg)
409                 : IOHookMiddle(Prov)
410                 , state(STATE_HTTPREQ)
411                 , lastpingpong(0)
412                 , config(cfg)
413         {
414                 sock->AddIOHook(this);
415         }
416
417         int OnStreamSocketWrite(StreamSocket* sock, StreamSocket::SendQueue& uppersendq) CXX11_OVERRIDE
418         {
419                 StreamSocket::SendQueue& mysendq = GetSendQ();
420
421                 // Return 1 to allow sending back an error HTTP response
422                 if (state != STATE_ESTABLISHED)
423                         return (mysendq.empty() ? 0 : 1);
424
425                 std::string message;
426                 for (StreamSocket::SendQueue::const_iterator elem = uppersendq.begin(); elem != uppersendq.end(); ++elem)
427                 {
428                         for (StreamSocket::SendQueue::Element::const_iterator chr = elem->begin(); chr != elem->end(); ++chr)
429                         {
430                                 if (*chr == '\n')
431                                 {
432                                         // We have found an entire message. Send it in its own frame.
433                                         if (config.sendastext)
434                                         {
435                                                 // If we send messages as text then we need to ensure they are valid UTF-8.
436                                                 std::string encoded;
437                                                 utf8::replace_invalid(message.begin(), message.end(), std::back_inserter(encoded));
438
439                                                 mysendq.push_back(PrepareSendQElem(encoded.length(), OP_TEXT));
440                                                 mysendq.push_back(encoded);
441                                         }
442                                         else
443                                         {
444                                                 // Otherwise, send the raw message as a binary frame.
445                                                 mysendq.push_back(PrepareSendQElem(message.length(), OP_BINARY));
446                                                 mysendq.push_back(message);
447                                         }
448                                         message.clear();
449                                 }
450                                 else if (*chr != '\r')
451                                 {
452                                         message.push_back(*chr);
453                                 }
454                         }
455                 }
456
457                 // Empty the upper send queue and push whatever is left back onto it.
458                 uppersendq.clear();
459                 if (!message.empty())
460                 {
461                         uppersendq.push_back(message);
462                         return 0;
463                 }
464
465                 return 1;
466         }
467
468         int OnStreamSocketRead(StreamSocket* sock, std::string& destrecvq) CXX11_OVERRIDE
469         {
470                 if (state == STATE_HTTPREQ)
471                 {
472                         int httpret = HandleHTTPReq(sock);
473                         if (httpret <= 0)
474                                 return httpret;
475                 }
476
477                 int wsret;
478                 do
479                 {
480                         wsret = HandleWS(sock, destrecvq);
481                 }
482                 while ((!GetRecvQ().empty()) && (wsret > 0));
483
484                 return wsret;
485         }
486
487         void OnStreamSocketClose(StreamSocket* sock) CXX11_OVERRIDE
488         {
489         }
490 };
491
492 void WebSocketHookProvider::OnAccept(StreamSocket* sock, irc::sockets::sockaddrs* client, irc::sockets::sockaddrs* server)
493 {
494         new WebSocketHook(this, sock, config);
495 }
496
497 class ModuleWebSocket : public Module
498 {
499         dynamic_reference_nocheck<HashProvider> hash;
500         reference<WebSocketHookProvider> hookprov;
501
502  public:
503         ModuleWebSocket()
504                 : hash(this, "hash/sha1")
505                 , hookprov(new WebSocketHookProvider(this))
506         {
507                 sha1 = &hash;
508         }
509
510         void ReadConfig(ConfigStatus& status) CXX11_OVERRIDE
511         {
512                 ConfigTagList tags = ServerInstance->Config->ConfTags("wsorigin");
513                 if (tags.first == tags.second)
514                         throw ModuleException("You have loaded the websocket module but not configured any allowed origins!");
515
516                 WebSocketConfig config;
517                 for (ConfigIter i = tags.first; i != tags.second; ++i)
518                 {
519                         ConfigTag* tag = i->second;
520
521                         // Ensure that we have the <wsorigin:allow> parameter.
522                         const std::string allow = tag->getString("allow");
523                         if (allow.empty())
524                                 throw ModuleException("<wsorigin:allow> is a mandatory field, at " + tag->getTagLocation());
525
526                         config.allowedorigins.push_back(allow);
527                 }
528
529                 ConfigTag* tag = ServerInstance->Config->ConfValue("websocket");
530                 config.sendastext = tag->getBool("sendastext", true);
531
532                 irc::spacesepstream proxyranges(tag->getString("proxyranges"));
533                 for (std::string proxyrange; proxyranges.GetToken(proxyrange); )
534                         config.proxyranges.push_back(proxyrange);
535
536                 // Everything is okay; apply the new config.
537                 hookprov->config = config;
538         }
539
540         void OnCleanup(ExtensionItem::ExtensibleType type, Extensible* item) CXX11_OVERRIDE
541         {
542                 if (type != ExtensionItem::EXT_USER)
543                         return;
544
545                 LocalUser* user = IS_LOCAL(static_cast<User*>(item));
546                 if ((user) && (user->eh.GetModHook(this)))
547                         ServerInstance->Users.QuitUser(user, "WebSocket module unloading");
548         }
549
550         Version GetVersion() CXX11_OVERRIDE
551         {
552                 return Version("Provides RFC 6455 WebSocket support", VF_VENDOR);
553         }
554 };
555
556 MODULE_INIT(ModuleWebSocket)