]> git.netwichtig.de Git - user/henk/code/inspircd.git/blob - src/modules/m_websocket.cpp
c7b7f6d4fa537695c5d55848b2b97bbff80e533c
[user/henk/code/inspircd.git] / src / modules / m_websocket.cpp
1 /*
2  * InspIRCd -- Internet Relay Chat Daemon
3  *
4  *   Copyright (C) 2019 iwalkalone <iwalkalone69@gmail.com>
5  *   Copyright (C) 2017-2020 Sadie Powell <sadie@witchery.services>
6  *   Copyright (C) 2016-2017 Attila Molnar <attilamolnar@hush.com>
7  *
8  * This file is part of InspIRCd.  InspIRCd is free software: you can
9  * redistribute it and/or modify it under the terms of the GNU General Public
10  * License as published by the Free Software Foundation, version 2.
11  *
12  * This program is distributed in the hope that it will be useful, but WITHOUT
13  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
14  * FOR A PARTICULAR PURPOSE.  See the GNU General Public License for more
15  * details.
16  *
17  * You should have received a copy of the GNU General Public License
18  * along with this program.  If not, see <http://www.gnu.org/licenses/>.
19  */
20
21 /// $CompilerFlags: -Ivendor_directory("utfcpp")
22
23
24 #include "inspircd.h"
25 #include "iohook.h"
26 #include "modules/hash.h"
27
28 #define UTF_CPP_CPLUSPLUS 199711L
29 #include <unchecked.h>
30
namespace
{
	/** Fixed GUID appended to the client's Sec-WebSocket-Key before it is SHA-1 hashed
	 * to build the Sec-WebSocket-Accept handshake reply.
	 */
	const char MagicGUID[] = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";

	/** Characters treated as whitespace when locating HTTP header values.
	 * Kept as an array so callers can use sizeof(whitespace) - 1 as its length.
	 */
	const char whitespace[] = " \t\r\n";
}
33 static dynamic_reference_nocheck<HashProvider>* sha1;
34
/** Settings for the websocket module, rebuilt from <wsorigin> and <websocket> by
 * ModuleWebSocket::ReadConfig and shared by reference with every WebSocketHook.
 */
struct WebSocketConfig
{
	typedef std::vector<std::string> OriginList;
	typedef std::vector<std::string> ProxyRanges;

	// The HTTP origins that can connect to the server.
	OriginList allowedorigins;

	// The IP ranges which send trustworthy X-Real-IP or X-Forwarded-For headers.
	ProxyRanges proxyranges;

	// Whether to send as UTF-8 text instead of binary data.
	bool sendastext;

	// Give sendastext a defined value so a default-constructed config (such as the
	// provider's before the first ReadConfig) never holds an indeterminate bool.
	// true matches the <websocket:sendastext> default applied by ReadConfig.
	WebSocketConfig()
		: sendastext(true)
	{
	}
};
49
50 class WebSocketHookProvider : public IOHookProvider
51 {
52  public:
53         WebSocketConfig config;
54         WebSocketHookProvider(Module* mod)
55                 : IOHookProvider(mod, "websocket", IOHookProvider::IOH_UNKNOWN, true)
56         {
57         }
58
59         void OnAccept(StreamSocket* sock, irc::sockets::sockaddrs* client, irc::sockets::sockaddrs* server) CXX11_OVERRIDE;
60
61         void OnConnect(StreamSocket* sock) CXX11_OVERRIDE
62         {
63         }
64 };
65
66 class WebSocketHook : public IOHookMiddle
67 {
68         class HTTPHeaderFinder
69         {
70                 std::string::size_type bpos;
71                 std::string::size_type len;
72
73          public:
74                 bool Find(const std::string& req, const char* header, std::string::size_type headerlen, std::string::size_type maxpos)
75                 {
76                         std::string::size_type keybegin = req.find(header);
77                         if ((keybegin == std::string::npos) || (keybegin > maxpos) || (keybegin == 0) || (req[keybegin-1] != '\n'))
78                                 return false;
79
80                         keybegin += headerlen;
81
82                         bpos = req.find_first_not_of(whitespace, keybegin, sizeof(whitespace)-1);
83                         if ((bpos == std::string::npos) || (bpos > maxpos))
84                                 return false;
85
86                         const std::string::size_type epos = req.find_first_of(whitespace, bpos, sizeof(whitespace)-1);
87                         len = epos - bpos;
88
89                         return true;
90                 }
91
92                 std::string ExtractValue(const std::string& req) const
93                 {
94                         return std::string(req, bpos, len);
95                 }
96         };
97
98         enum OpCode
99         {
100                 OP_CONTINUATION = 0x00,
101                 OP_TEXT = 0x01,
102                 OP_BINARY = 0x02,
103                 OP_CLOSE = 0x08,
104                 OP_PING = 0x09,
105                 OP_PONG = 0x0a
106         };
107
108         enum State
109         {
110                 STATE_HTTPREQ,
111                 STATE_ESTABLISHED
112         };
113
114         static const unsigned char WS_MASKBIT = (1 << 7);
115         static const unsigned char WS_FINBIT = (1 << 7);
116         static const unsigned char WS_PAYLOAD_LENGTH_MAGIC_LARGE = 126;
117         static const unsigned char WS_PAYLOAD_LENGTH_MAGIC_HUGE = 127;
118         static const size_t WS_MAX_PAYLOAD_LENGTH_SMALL = 125;
119         static const size_t WS_MAX_PAYLOAD_LENGTH_LARGE = 65535;
120         static const size_t MAXHEADERSIZE = sizeof(uint64_t) + 2;
121
122         // Clients sending ping or pong frames faster than this are killed
123         static const time_t MINPINGPONGDELAY = 10;
124
125         State state;
126         time_t lastpingpong;
127         WebSocketConfig& config;
128
129         static size_t FillHeader(unsigned char* outbuf, size_t sendlength, OpCode opcode)
130         {
131                 size_t pos = 0;
132                 outbuf[pos++] = WS_FINBIT | opcode;
133
134                 if (sendlength <= WS_MAX_PAYLOAD_LENGTH_SMALL)
135                 {
136                         outbuf[pos++] = sendlength;
137                 }
138                 else if (sendlength <= WS_MAX_PAYLOAD_LENGTH_LARGE)
139                 {
140                         outbuf[pos++] = WS_PAYLOAD_LENGTH_MAGIC_LARGE;
141                         outbuf[pos++] = (sendlength >> 8) & 0xff;
142                         outbuf[pos++] = sendlength & 0xff;
143                 }
144                 else
145                 {
146                         outbuf[pos++] = WS_PAYLOAD_LENGTH_MAGIC_HUGE;
147                         const uint64_t len = sendlength;
148                         for (int i = sizeof(uint64_t)-1; i >= 0; i--)
149                                 outbuf[pos++] = ((len >> i*8) & 0xff);
150                 }
151
152                 return pos;
153         }
154
155         static StreamSocket::SendQueue::Element PrepareSendQElem(size_t size, OpCode opcode)
156         {
157                 unsigned char header[MAXHEADERSIZE];
158                 const size_t n = FillHeader(header, size, opcode);
159
160                 return StreamSocket::SendQueue::Element(reinterpret_cast<const char*>(header), n);
161         }
162
163         int HandleAppData(StreamSocket* sock, std::string& appdataout, bool allowlarge)
164         {
165                 std::string& myrecvq = GetRecvQ();
166                 // Need 1 byte opcode, minimum 1 byte len, 4 bytes masking key
167                 if (myrecvq.length() < 6)
168                         return 0;
169
170                 const std::string& cmyrecvq = myrecvq;
171                 unsigned char len1 = (unsigned char)cmyrecvq[1];
172                 if (!(len1 & WS_MASKBIT))
173                 {
174                         sock->SetError("WebSocket protocol violation: unmasked client frame");
175                         return -1;
176                 }
177
178                 len1 &= ~WS_MASKBIT;
179
180                 // Assume the length is a single byte, if not, update values later
181                 unsigned int len = len1;
182                 unsigned int payloadstartoffset = 6;
183                 const unsigned char* maskkey = reinterpret_cast<const unsigned char*>(&cmyrecvq[2]);
184
185                 if (len1 == WS_PAYLOAD_LENGTH_MAGIC_LARGE)
186                 {
187                         // allowlarge is false for control frames according to the RFC meaning large pings, etc. are not allowed
188                         if (!allowlarge)
189                         {
190                                 sock->SetError("WebSocket protocol violation: large control frame");
191                                 return -1;
192                         }
193
194                         // Large frame, has 2 bytes len after the magic byte indicating the length
195                         // Need 1 byte opcode, 3 bytes len, 4 bytes masking key
196                         if (myrecvq.length() < 8)
197                                 return 0;
198
199                         unsigned char len2 = (unsigned char)cmyrecvq[2];
200                         unsigned char len3 = (unsigned char)cmyrecvq[3];
201                         len = (len2 << 8) | len3;
202
203                         if (len <= WS_MAX_PAYLOAD_LENGTH_SMALL)
204                         {
205                                 sock->SetError("WebSocket protocol violation: non-minimal length encoding used");
206                                 return -1;
207                         }
208
209                         maskkey += 2;
210                         payloadstartoffset += 2;
211                 }
212                 else if (len1 == WS_PAYLOAD_LENGTH_MAGIC_HUGE)
213                 {
214                         sock->SetError("WebSocket: Huge frames are not supported");
215                         return -1;
216                 }
217
218                 if (myrecvq.length() < payloadstartoffset + len)
219                         return 0;
220
221                 unsigned int maskkeypos = 0;
222                 const std::string::iterator endit = myrecvq.begin() + payloadstartoffset + len;
223                 for (std::string::const_iterator i = myrecvq.begin() + payloadstartoffset; i != endit; ++i)
224                 {
225                         const unsigned char c = (unsigned char)*i;
226                         appdataout.push_back(c ^ maskkey[maskkeypos++]);
227                         maskkeypos %= 4;
228                 }
229
230                 myrecvq.erase(myrecvq.begin(), endit);
231                 return 1;
232         }
233
234         int HandlePingPongFrame(StreamSocket* sock, bool isping)
235         {
236                 if (lastpingpong + MINPINGPONGDELAY >= ServerInstance->Time())
237                 {
238                         sock->SetError("WebSocket: Ping/pong flood");
239                         return -1;
240                 }
241
242                 lastpingpong = ServerInstance->Time();
243
244                 std::string appdata;
245                 const int result = HandleAppData(sock, appdata, false);
246                 // If it's a pong stop here regardless of the result so we won't generate a reply
247                 if ((result <= 0) || (!isping))
248                         return result;
249
250                 StreamSocket::SendQueue::Element elem = PrepareSendQElem(appdata.length(), OP_PONG);
251                 elem.append(appdata);
252                 GetSendQ().push_back(elem);
253
254                 SocketEngine::ChangeEventMask(sock, FD_ADD_TRIAL_WRITE);
255                 return 1;
256         }
257
258         int HandleWS(StreamSocket* sock, std::string& destrecvq)
259         {
260                 if (GetRecvQ().empty())
261                         return 0;
262
263                 unsigned char opcode = (unsigned char)GetRecvQ().c_str()[0];
264                 switch (opcode & ~WS_FINBIT)
265                 {
266                         case OP_CONTINUATION:
267                         case OP_TEXT:
268                         case OP_BINARY:
269                         {
270                                 std::string appdata;
271                                 const int result = HandleAppData(sock, appdata, true);
272                                 if (result != 1)
273                                         return result;
274
275                                 // Strip out any CR+LF which may have been erroneously sent.
276                                 for (std::string::const_iterator iter = appdata.begin(); iter != appdata.end(); ++iter)
277                                 {
278                                         if (*iter != '\r' && *iter != '\n')
279                                                 destrecvq.push_back(*iter);
280                                 }
281
282                                 // If we are on the final message of this block append a line terminator.
283                                 if (opcode & WS_FINBIT)
284                                         destrecvq.append("\r\n");
285
286                                 return 1;
287                         }
288
289                         case OP_PING:
290                         {
291                                 return HandlePingPongFrame(sock, true);
292                         }
293
294                         case OP_PONG:
295                         {
296                                 // A pong frame may be sent unsolicited, so we have to handle it.
297                                 // It may carry application data which we need to remove from the recvq as well.
298                                 return HandlePingPongFrame(sock, false);
299                         }
300
301                         case OP_CLOSE:
302                         {
303                                 sock->SetError("Connection closed");
304                                 return -1;
305                         }
306
307                         default:
308                         {
309                                 sock->SetError("WebSocket: Invalid opcode");
310                                 return -1;
311                         }
312                 }
313         }
314
315         void FailHandshake(StreamSocket* sock, const char* httpreply, const char* sockerror)
316         {
317                 GetSendQ().push_back(StreamSocket::SendQueue::Element(httpreply));
318                 sock->DoWrite();
319                 sock->SetError(sockerror);
320         }
321
322         int HandleHTTPReq(StreamSocket* sock)
323         {
324                 std::string& recvq = GetRecvQ();
325                 const std::string::size_type reqend = recvq.find("\r\n\r\n");
326                 if (reqend == std::string::npos)
327                         return 0;
328
329                 bool allowedorigin = false;
330                 HTTPHeaderFinder originheader;
331                 if (originheader.Find(recvq, "Origin:", 7, reqend))
332                 {
333                         const std::string origin = originheader.ExtractValue(recvq);
334                         for (WebSocketConfig::OriginList::const_iterator iter = config.allowedorigins.begin(); iter != config.allowedorigins.end(); ++iter)
335                         {
336                                 if (InspIRCd::Match(origin, *iter, ascii_case_insensitive_map))
337                                 {
338                                         allowedorigin = true;
339                                         break;
340                                 }
341                         }
342                 }
343                 else
344                 {
345                         FailHandshake(sock, "HTTP/1.1 400 Bad Request\r\nConnection: close\r\n\r\n", "WebSocket: Received HTTP request that did not send the Origin header");
346                         return -1;
347                 }
348
349                 if (!allowedorigin)
350                 {
351                         FailHandshake(sock, "HTTP/1.1 403 Forbidden\r\nConnection: close\r\n\r\n", "WebSocket: Received HTTP request from a non-whitelisted origin");
352                         return -1;
353                 }
354
355                 if (!config.proxyranges.empty() && sock->type == StreamSocket::SS_USER)
356                 {
357                         LocalUser* luser = static_cast<UserIOHandler*>(sock)->user;
358                         irc::sockets::sockaddrs realsa(luser->client_sa);
359
360                         HTTPHeaderFinder proxyheader;
361                         if (proxyheader.Find(recvq, "X-Real-IP:", 10, reqend)
362                                 && irc::sockets::aptosa(proxyheader.ExtractValue(recvq), realsa.port(), realsa))
363                         {
364                                 // Nothing to do here.
365                         }
366                         else if (proxyheader.Find(recvq, "X-Forwarded-For:", 16, reqend)
367                                 && irc::sockets::aptosa(proxyheader.ExtractValue(recvq), realsa.port(), realsa))
368                         {
369                                 // Nothing to do here.
370                         }
371
372                         for (WebSocketConfig::ProxyRanges::const_iterator iter = config.proxyranges.begin(); iter != config.proxyranges.end(); ++iter)
373                         {
374                                 if (InspIRCd::MatchCIDR(luser->GetIPString(), *iter, ascii_case_insensitive_map))
375                                 {
376                                         // Give the user their real IP address.
377                                         if (realsa != luser->client_sa)
378                                                 luser->SetClientIP(realsa);
379
380                                         // Error if changing their IP gets them banned.
381                                         if (luser->quitting)
382                                                 return -1;
383                                         break;
384                                 }
385                         }
386                 }
387
388
389                 HTTPHeaderFinder keyheader;
390                 if (!keyheader.Find(recvq, "Sec-WebSocket-Key:", 18, reqend))
391                 {
392                         FailHandshake(sock, "HTTP/1.1 501 Not Implemented\r\nConnection: close\r\n\r\n", "WebSocket: Received HTTP request which is not a websocket upgrade");
393                         return -1;
394                 }
395
396                 if (!*sha1)
397                 {
398                         FailHandshake(sock, "HTTP/1.1 503 Service Unavailable\r\nConnection: close\r\n\r\n", "WebSocket: SHA-1 provider missing");
399                         return -1;
400                 }
401
402                 state = STATE_ESTABLISHED;
403
404                 std::string key = keyheader.ExtractValue(recvq);
405                 key.append(MagicGUID);
406
407                 std::string reply = "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: ";
408                 reply.append(BinToBase64((*sha1)->GenerateRaw(key), NULL, '=')).append("\r\n\r\n");
409                 GetSendQ().push_back(StreamSocket::SendQueue::Element(reply));
410
411                 SocketEngine::ChangeEventMask(sock, FD_ADD_TRIAL_WRITE);
412
413                 recvq.erase(0, reqend + 4);
414
415                 return 1;
416         }
417
418  public:
419         WebSocketHook(IOHookProvider* Prov, StreamSocket* sock, WebSocketConfig& cfg)
420                 : IOHookMiddle(Prov)
421                 , state(STATE_HTTPREQ)
422                 , lastpingpong(0)
423                 , config(cfg)
424         {
425                 sock->AddIOHook(this);
426         }
427
428         int OnStreamSocketWrite(StreamSocket* sock, StreamSocket::SendQueue& uppersendq) CXX11_OVERRIDE
429         {
430                 StreamSocket::SendQueue& mysendq = GetSendQ();
431
432                 // Return 1 to allow sending back an error HTTP response
433                 if (state != STATE_ESTABLISHED)
434                         return (mysendq.empty() ? 0 : 1);
435
436                 std::string message;
437                 for (StreamSocket::SendQueue::const_iterator elem = uppersendq.begin(); elem != uppersendq.end(); ++elem)
438                 {
439                         for (StreamSocket::SendQueue::Element::const_iterator chr = elem->begin(); chr != elem->end(); ++chr)
440                         {
441                                 if (*chr == '\n')
442                                 {
443                                         // We have found an entire message. Send it in its own frame.
444                                         if (config.sendastext)
445                                         {
446                                                 // If we send messages as text then we need to ensure they are valid UTF-8.
447                                                 std::string encoded;
448                                                 utf8::unchecked::replace_invalid(message.begin(), message.end(), std::back_inserter(encoded));
449
450                                                 mysendq.push_back(PrepareSendQElem(encoded.length(), OP_TEXT));
451                                                 mysendq.push_back(encoded);
452                                         }
453                                         else
454                                         {
455                                                 // Otherwise, send the raw message as a binary frame.
456                                                 mysendq.push_back(PrepareSendQElem(message.length(), OP_BINARY));
457                                                 mysendq.push_back(message);
458                                         }
459                                         message.clear();
460                                 }
461                                 else if (*chr != '\r')
462                                 {
463                                         message.push_back(*chr);
464                                 }
465                         }
466                 }
467
468                 // Empty the upper send queue and push whatever is left back onto it.
469                 uppersendq.clear();
470                 if (!message.empty())
471                 {
472                         uppersendq.push_back(message);
473                         return 0;
474                 }
475
476                 return 1;
477         }
478
479         int OnStreamSocketRead(StreamSocket* sock, std::string& destrecvq) CXX11_OVERRIDE
480         {
481                 if (state == STATE_HTTPREQ)
482                 {
483                         int httpret = HandleHTTPReq(sock);
484                         if (httpret <= 0)
485                                 return httpret;
486                 }
487
488                 int wsret;
489                 do
490                 {
491                         wsret = HandleWS(sock, destrecvq);
492                 }
493                 while ((!GetRecvQ().empty()) && (wsret > 0));
494
495                 return wsret;
496         }
497
498         void OnStreamSocketClose(StreamSocket* sock) CXX11_OVERRIDE
499         {
500         }
501 };
502
503 void WebSocketHookProvider::OnAccept(StreamSocket* sock, irc::sockets::sockaddrs* client, irc::sockets::sockaddrs* server)
504 {
505         new WebSocketHook(this, sock, config);
506 }
507
508 class ModuleWebSocket : public Module
509 {
510         dynamic_reference_nocheck<HashProvider> hash;
511         reference<WebSocketHookProvider> hookprov;
512
513  public:
514         ModuleWebSocket()
515                 : hash(this, "hash/sha1")
516                 , hookprov(new WebSocketHookProvider(this))
517         {
518                 sha1 = &hash;
519         }
520
521         void ReadConfig(ConfigStatus& status) CXX11_OVERRIDE
522         {
523                 ConfigTagList tags = ServerInstance->Config->ConfTags("wsorigin");
524                 if (tags.first == tags.second)
525                         throw ModuleException("You have loaded the websocket module but not configured any allowed origins!");
526
527                 WebSocketConfig config;
528                 for (ConfigIter i = tags.first; i != tags.second; ++i)
529                 {
530                         ConfigTag* tag = i->second;
531
532                         // Ensure that we have the <wsorigin:allow> parameter.
533                         const std::string allow = tag->getString("allow");
534                         if (allow.empty())
535                                 throw ModuleException("<wsorigin:allow> is a mandatory field, at " + tag->getTagLocation());
536
537                         config.allowedorigins.push_back(allow);
538                 }
539
540                 ConfigTag* tag = ServerInstance->Config->ConfValue("websocket");
541                 config.sendastext = tag->getBool("sendastext", true);
542
543                 irc::spacesepstream proxyranges(tag->getString("proxyranges"));
544                 for (std::string proxyrange; proxyranges.GetToken(proxyrange); )
545                         config.proxyranges.push_back(proxyrange);
546
547                 // Everything is okay; apply the new config.
548                 hookprov->config = config;
549         }
550
551         void OnCleanup(ExtensionItem::ExtensibleType type, Extensible* item) CXX11_OVERRIDE
552         {
553                 if (type != ExtensionItem::EXT_USER)
554                         return;
555
556                 LocalUser* user = IS_LOCAL(static_cast<User*>(item));
557                 if ((user) && (user->eh.GetModHook(this)))
558                         ServerInstance->Users.QuitUser(user, "WebSocket module unloading");
559         }
560
561         Version GetVersion() CXX11_OVERRIDE
562         {
563                 return Version("Allows WebSocket clients to connect to the IRC server.", VF_VENDOR);
564         }
565 };
566
567 MODULE_INIT(ModuleWebSocket)