]> git.netwichtig.de Git - user/henk/code/inspircd.git/blob - src/modules/m_websocket.cpp
8ec8968470c61cf99e160b53ff134c9ff2ed40a1
[user/henk/code/inspircd.git] / src / modules / m_websocket.cpp
1 /*
2  * InspIRCd -- Internet Relay Chat Daemon
3  *
4  *   Copyright (C) 2016 Attila Molnar <attilamolnar@hush.com>
5  *
6  * This file is part of InspIRCd.  InspIRCd is free software: you can
7  * redistribute it and/or modify it under the terms of the GNU General Public
8  * License as published by the Free Software Foundation, version 2.
9  *
10  * This program is distributed in the hope that it will be useful, but WITHOUT
11  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
12  * FOR A PARTICULAR PURPOSE.  See the GNU General Public License for more
13  * details.
14  *
15  * You should have received a copy of the GNU General Public License
16  * along with this program.  If not, see <http://www.gnu.org/licenses/>.
17  */
18
19 /// $CompilerFlags: -Ivendor_directory("utfcpp")
20
21
22 #include "inspircd.h"
23 #include "iohook.h"
24 #include "modules/hash.h"
25
26 #include <utf8.h>
27
28 static const char MagicGUID[] = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
29 static const char whitespace[] = " \t\r\n";
30 static dynamic_reference_nocheck<HashProvider>* sha1;
31
struct WebSocketConfig
{
	typedef std::vector<std::string> OriginList;
	typedef std::vector<std::string> ProxyRanges;

	// The HTTP origins that can connect to the server.
	OriginList allowedorigins;

	// The IP ranges which send trustworthy X-Real-IP or X-Forwarded-For headers.
	ProxyRanges proxyranges;

	// Whether to send as UTF-8 text instead of binary data.
	bool sendastext;

	/** Creates an empty configuration.
	 * sendastext defaults to true, which is the same default ModuleWebSocket::ReadConfig
	 * uses for <websocket:sendastext>. This means a default-constructed instance never
	 * holds an uninitialised flag.
	 */
	WebSocketConfig()
		: sendastext(true)
	{
	}
};
46
47 class WebSocketHookProvider : public IOHookProvider
48 {
49  public:
50         WebSocketConfig config;
51         WebSocketHookProvider(Module* mod)
52                 : IOHookProvider(mod, "websocket", IOHookProvider::IOH_UNKNOWN, true)
53         {
54         }
55
56         void OnAccept(StreamSocket* sock, irc::sockets::sockaddrs* client, irc::sockets::sockaddrs* server) CXX11_OVERRIDE;
57
58         void OnConnect(StreamSocket* sock) CXX11_OVERRIDE
59         {
60         }
61 };
62
63 class WebSocketHook : public IOHookMiddle
64 {
65         class HTTPHeaderFinder
66         {
67                 std::string::size_type bpos;
68                 std::string::size_type len;
69
70          public:
71                 bool Find(const std::string& req, const char* header, std::string::size_type headerlen, std::string::size_type maxpos)
72                 {
73                         std::string::size_type keybegin = req.find(header);
74                         if ((keybegin == std::string::npos) || (keybegin > maxpos) || (keybegin == 0) || (req[keybegin-1] != '\n'))
75                                 return false;
76
77                         keybegin += headerlen;
78
79                         bpos = req.find_first_not_of(whitespace, keybegin, sizeof(whitespace)-1);
80                         if ((bpos == std::string::npos) || (bpos > maxpos))
81                                 return false;
82
83                         const std::string::size_type epos = req.find_first_of(whitespace, bpos, sizeof(whitespace)-1);
84                         len = epos - bpos;
85
86                         return true;
87                 }
88
89                 std::string ExtractValue(const std::string& req) const
90                 {
91                         return std::string(req, bpos, len);
92                 }
93         };
94
95         enum OpCode
96         {
97                 OP_CONTINUATION = 0x00,
98                 OP_TEXT = 0x01,
99                 OP_BINARY = 0x02,
100                 OP_CLOSE = 0x08,
101                 OP_PING = 0x09,
102                 OP_PONG = 0x0a
103         };
104
105         enum State
106         {
107                 STATE_HTTPREQ,
108                 STATE_ESTABLISHED
109         };
110
111         static const unsigned char WS_MASKBIT = (1 << 7);
112         static const unsigned char WS_FINBIT = (1 << 7);
113         static const unsigned char WS_PAYLOAD_LENGTH_MAGIC_LARGE = 126;
114         static const unsigned char WS_PAYLOAD_LENGTH_MAGIC_HUGE = 127;
115         static const size_t WS_MAX_PAYLOAD_LENGTH_SMALL = 125;
116         static const size_t WS_MAX_PAYLOAD_LENGTH_LARGE = 65535;
117         static const size_t MAXHEADERSIZE = sizeof(uint64_t) + 2;
118
119         // Clients sending ping or pong frames faster than this are killed
120         static const time_t MINPINGPONGDELAY = 10;
121
122         State state;
123         time_t lastpingpong;
124         WebSocketConfig& config;
125
126         static size_t FillHeader(unsigned char* outbuf, size_t sendlength, OpCode opcode)
127         {
128                 size_t pos = 0;
129                 outbuf[pos++] = WS_FINBIT | opcode;
130
131                 if (sendlength <= WS_MAX_PAYLOAD_LENGTH_SMALL)
132                 {
133                         outbuf[pos++] = sendlength;
134                 }
135                 else if (sendlength <= WS_MAX_PAYLOAD_LENGTH_LARGE)
136                 {
137                         outbuf[pos++] = WS_PAYLOAD_LENGTH_MAGIC_LARGE;
138                         outbuf[pos++] = (sendlength >> 8) & 0xff;
139                         outbuf[pos++] = sendlength & 0xff;
140                 }
141                 else
142                 {
143                         outbuf[pos++] = WS_PAYLOAD_LENGTH_MAGIC_HUGE;
144                         const uint64_t len = sendlength;
145                         for (int i = sizeof(uint64_t)-1; i >= 0; i--)
146                                 outbuf[pos++] = ((len >> i*8) & 0xff);
147                 }
148
149                 return pos;
150         }
151
152         static StreamSocket::SendQueue::Element PrepareSendQElem(size_t size, OpCode opcode)
153         {
154                 unsigned char header[MAXHEADERSIZE];
155                 const size_t n = FillHeader(header, size, opcode);
156
157                 return StreamSocket::SendQueue::Element(reinterpret_cast<const char*>(header), n);
158         }
159
160         int HandleAppData(StreamSocket* sock, std::string& appdataout, bool allowlarge)
161         {
162                 std::string& myrecvq = GetRecvQ();
163                 // Need 1 byte opcode, minimum 1 byte len, 4 bytes masking key
164                 if (myrecvq.length() < 6)
165                         return 0;
166
167                 const std::string& cmyrecvq = myrecvq;
168                 unsigned char len1 = (unsigned char)cmyrecvq[1];
169                 if (!(len1 & WS_MASKBIT))
170                 {
171                         sock->SetError("WebSocket protocol violation: unmasked client frame");
172                         return -1;
173                 }
174
175                 len1 &= ~WS_MASKBIT;
176
177                 // Assume the length is a single byte, if not, update values later
178                 unsigned int len = len1;
179                 unsigned int payloadstartoffset = 6;
180                 const unsigned char* maskkey = reinterpret_cast<const unsigned char*>(&cmyrecvq[2]);
181
182                 if (len1 == WS_PAYLOAD_LENGTH_MAGIC_LARGE)
183                 {
184                         // allowlarge is false for control frames according to the RFC meaning large pings, etc. are not allowed
185                         if (!allowlarge)
186                         {
187                                 sock->SetError("WebSocket protocol violation: large control frame");
188                                 return -1;
189                         }
190
191                         // Large frame, has 2 bytes len after the magic byte indicating the length
192                         // Need 1 byte opcode, 3 bytes len, 4 bytes masking key
193                         if (myrecvq.length() < 8)
194                                 return 0;
195
196                         unsigned char len2 = (unsigned char)cmyrecvq[2];
197                         unsigned char len3 = (unsigned char)cmyrecvq[3];
198                         len = (len2 << 8) | len3;
199
200                         if (len <= WS_MAX_PAYLOAD_LENGTH_SMALL)
201                         {
202                                 sock->SetError("WebSocket protocol violation: non-minimal length encoding used");
203                                 return -1;
204                         }
205
206                         maskkey += 2;
207                         payloadstartoffset += 2;
208                 }
209                 else if (len1 == WS_PAYLOAD_LENGTH_MAGIC_HUGE)
210                 {
211                         sock->SetError("WebSocket: Huge frames are not supported");
212                         return -1;
213                 }
214
215                 if (myrecvq.length() < payloadstartoffset + len)
216                         return 0;
217
218                 unsigned int maskkeypos = 0;
219                 const std::string::iterator endit = myrecvq.begin() + payloadstartoffset + len;
220                 for (std::string::const_iterator i = myrecvq.begin() + payloadstartoffset; i != endit; ++i)
221                 {
222                         const unsigned char c = (unsigned char)*i;
223                         appdataout.push_back(c ^ maskkey[maskkeypos++]);
224                         maskkeypos %= 4;
225                 }
226
227                 myrecvq.erase(myrecvq.begin(), endit);
228                 return 1;
229         }
230
231         int HandlePingPongFrame(StreamSocket* sock, bool isping)
232         {
233                 if (lastpingpong + MINPINGPONGDELAY >= ServerInstance->Time())
234                 {
235                         sock->SetError("WebSocket: Ping/pong flood");
236                         return -1;
237                 }
238
239                 lastpingpong = ServerInstance->Time();
240
241                 std::string appdata;
242                 const int result = HandleAppData(sock, appdata, false);
243                 // If it's a pong stop here regardless of the result so we won't generate a reply
244                 if ((result <= 0) || (!isping))
245                         return result;
246
247                 StreamSocket::SendQueue::Element elem = PrepareSendQElem(appdata.length(), OP_PONG);
248                 elem.append(appdata);
249                 GetSendQ().push_back(elem);
250
251                 SocketEngine::ChangeEventMask(sock, FD_ADD_TRIAL_WRITE);
252                 return 1;
253         }
254
255         int HandleWS(StreamSocket* sock, std::string& destrecvq)
256         {
257                 if (GetRecvQ().empty())
258                         return 0;
259
260                 unsigned char opcode = (unsigned char)GetRecvQ().c_str()[0];
261                 switch (opcode & ~WS_FINBIT)
262                 {
263                         case OP_CONTINUATION:
264                         case OP_TEXT:
265                         case OP_BINARY:
266                         {
267                                 std::string appdata;
268                                 const int result = HandleAppData(sock, appdata, true);
269                                 if (result != 1)
270                                         return result;
271
272                                 // Strip out any CR+LF which may have been erroneously sent.
273                                 for (std::string::const_iterator iter = appdata.begin(); iter != appdata.end(); ++iter)
274                                 {
275                                         if (*iter != '\r' && *iter != '\n')
276                                                 destrecvq.push_back(*iter);
277                                 }
278
279                                 // If we are on the final message of this block append a line terminator.
280                                 if (opcode & WS_FINBIT)
281                                         destrecvq.append("\r\n");
282
283                                 return 1;
284                         }
285
286                         case OP_PING:
287                         {
288                                 return HandlePingPongFrame(sock, true);
289                         }
290
291                         case OP_PONG:
292                         {
293                                 // A pong frame may be sent unsolicited, so we have to handle it.
294                                 // It may carry application data which we need to remove from the recvq as well.
295                                 return HandlePingPongFrame(sock, false);
296                         }
297
298                         case OP_CLOSE:
299                         {
300                                 sock->SetError("Connection closed");
301                                 return -1;
302                         }
303
304                         default:
305                         {
306                                 sock->SetError("WebSocket: Invalid opcode");
307                                 return -1;
308                         }
309                 }
310         }
311
312         void FailHandshake(StreamSocket* sock, const char* httpreply, const char* sockerror)
313         {
314                 GetSendQ().push_back(StreamSocket::SendQueue::Element(httpreply));
315                 sock->DoWrite();
316                 sock->SetError(sockerror);
317         }
318
319         int HandleHTTPReq(StreamSocket* sock)
320         {
321                 std::string& recvq = GetRecvQ();
322                 const std::string::size_type reqend = recvq.find("\r\n\r\n");
323                 if (reqend == std::string::npos)
324                         return 0;
325
326                 bool allowedorigin = false;
327                 HTTPHeaderFinder originheader;
328                 if (originheader.Find(recvq, "Origin:", 7, reqend))
329                 {
330                         const std::string origin = originheader.ExtractValue(recvq);
331                         for (WebSocketConfig::OriginList::const_iterator iter = config.allowedorigins.begin(); iter != config.allowedorigins.end(); ++iter)
332                         {
333                                 if (InspIRCd::Match(origin, *iter, ascii_case_insensitive_map))
334                                 {
335                                         allowedorigin = true;
336                                         break;
337                                 }
338                         }
339                 }
340
341                 if (!allowedorigin)
342                 {
343                         FailHandshake(sock, "HTTP/1.1 403 Forbidden\r\nConnection: close\r\n\r\n", "WebSocket: Received HTTP request from a non-whitelisted origin");
344                         return -1;
345                 }
346
347                 if (!config.proxyranges.empty() && sock->type == StreamSocket::SS_USER)
348                 {
349                         LocalUser* luser = static_cast<UserIOHandler*>(sock)->user;
350                         irc::sockets::sockaddrs realsa(luser->client_sa);
351
352                         HTTPHeaderFinder proxyheader;
353                         if (proxyheader.Find(recvq, "X-Real-IP:", 10, reqend)
354                                 && irc::sockets::aptosa(proxyheader.ExtractValue(recvq), realsa.port(), realsa))
355                         {
356                                 // Nothing to do here.
357                         }
358                         else if (proxyheader.Find(recvq, "X-Forwarded-For:", 16, reqend)
359                                 && irc::sockets::aptosa(proxyheader.ExtractValue(recvq), realsa.port(), realsa))
360                         {
361                                 // Nothing to do here.
362                         }
363
364                         for (WebSocketConfig::ProxyRanges::const_iterator iter = config.proxyranges.begin(); iter != config.proxyranges.end(); ++iter)
365                         {
366                                 if (InspIRCd::MatchCIDR(luser->GetIPString(), *iter, ascii_case_insensitive_map))
367                                 {
368                                         // Give the user their real IP address.
369                                         if (realsa != luser->client_sa)
370                                                 luser->SetClientIP(realsa);
371                                         break;
372                                 }
373                         }
374                 }
375
376
377                 HTTPHeaderFinder keyheader;
378                 if (!keyheader.Find(recvq, "Sec-WebSocket-Key:", 18, reqend))
379                 {
380                         FailHandshake(sock, "HTTP/1.1 501 Not Implemented\r\nConnection: close\r\n\r\n", "WebSocket: Received HTTP request which is not a websocket upgrade");
381                         return -1;
382                 }
383
384                 if (!*sha1)
385                 {
386                         FailHandshake(sock, "HTTP/1.1 503 Service Unavailable\r\nConnection: close\r\n\r\n", "WebSocket: SHA-1 provider missing");
387                         return -1;
388                 }
389
390                 state = STATE_ESTABLISHED;
391
392                 std::string key = keyheader.ExtractValue(recvq);
393                 key.append(MagicGUID);
394
395                 std::string reply = "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: ";
396                 reply.append(BinToBase64((*sha1)->GenerateRaw(key), NULL, '=')).append("\r\n\r\n");
397                 GetSendQ().push_back(StreamSocket::SendQueue::Element(reply));
398
399                 SocketEngine::ChangeEventMask(sock, FD_ADD_TRIAL_WRITE);
400
401                 recvq.erase(0, reqend + 4);
402
403                 return 1;
404         }
405
406  public:
407         WebSocketHook(IOHookProvider* Prov, StreamSocket* sock, WebSocketConfig& cfg)
408                 : IOHookMiddle(Prov)
409                 , state(STATE_HTTPREQ)
410                 , lastpingpong(0)
411                 , config(cfg)
412         {
413                 sock->AddIOHook(this);
414         }
415
416         int OnStreamSocketWrite(StreamSocket* sock, StreamSocket::SendQueue& uppersendq) CXX11_OVERRIDE
417         {
418                 StreamSocket::SendQueue& mysendq = GetSendQ();
419
420                 // Return 1 to allow sending back an error HTTP response
421                 if (state != STATE_ESTABLISHED)
422                         return (mysendq.empty() ? 0 : 1);
423
424                 std::string message;
425                 for (StreamSocket::SendQueue::const_iterator elem = uppersendq.begin(); elem != uppersendq.end(); ++elem)
426                 {
427                         for (StreamSocket::SendQueue::Element::const_iterator chr = elem->begin(); chr != elem->end(); ++chr)
428                         {
429                                 if (*chr == '\n')
430                                 {
431                                         // We have found an entire message. Send it in its own frame.
432                                         if (config.sendastext)
433                                         {
434                                                 // If we send messages as text then we need to ensure they are valid UTF-8.
435                                                 std::string encoded;
436                                                 utf8::replace_invalid(message.begin(), message.end(), std::back_inserter(encoded));
437
438                                                 mysendq.push_back(PrepareSendQElem(encoded.length(), OP_TEXT));
439                                                 mysendq.push_back(encoded);
440                                         }
441                                         else
442                                         {
443                                                 // Otherwise, send the raw message as a binary frame.
444                                                 mysendq.push_back(PrepareSendQElem(message.length(), OP_BINARY));
445                                                 mysendq.push_back(message);
446                                         }
447                                         message.clear();
448                                 }
449                                 else if (*chr != '\r')
450                                 {
451                                         message.push_back(*chr);
452                                 }
453                         }
454                 }
455
456                 // Empty the upper send queue and push whatever is left back onto it.
457                 uppersendq.clear();
458                 if (!message.empty())
459                 {
460                         uppersendq.push_back(message);
461                         return 0;
462                 }
463
464                 return 1;
465         }
466
467         int OnStreamSocketRead(StreamSocket* sock, std::string& destrecvq) CXX11_OVERRIDE
468         {
469                 if (state == STATE_HTTPREQ)
470                 {
471                         int httpret = HandleHTTPReq(sock);
472                         if (httpret <= 0)
473                                 return httpret;
474                 }
475
476                 int wsret;
477                 do
478                 {
479                         wsret = HandleWS(sock, destrecvq);
480                 }
481                 while ((!GetRecvQ().empty()) && (wsret > 0));
482
483                 return wsret;
484         }
485
486         void OnStreamSocketClose(StreamSocket* sock) CXX11_OVERRIDE
487         {
488         }
489 };
490
491 void WebSocketHookProvider::OnAccept(StreamSocket* sock, irc::sockets::sockaddrs* client, irc::sockets::sockaddrs* server)
492 {
493         new WebSocketHook(this, sock, config);
494 }
495
496 class ModuleWebSocket : public Module
497 {
498         dynamic_reference_nocheck<HashProvider> hash;
499         reference<WebSocketHookProvider> hookprov;
500
501  public:
502         ModuleWebSocket()
503                 : hash(this, "hash/sha1")
504                 , hookprov(new WebSocketHookProvider(this))
505         {
506                 sha1 = &hash;
507         }
508
509         void ReadConfig(ConfigStatus& status) CXX11_OVERRIDE
510         {
511                 ConfigTagList tags = ServerInstance->Config->ConfTags("wsorigin");
512                 if (tags.first == tags.second)
513                         throw ModuleException("You have loaded the websocket module but not configured any allowed origins!");
514
515                 WebSocketConfig config;
516                 for (ConfigIter i = tags.first; i != tags.second; ++i)
517                 {
518                         ConfigTag* tag = i->second;
519
520                         // Ensure that we have the <wsorigin:allow> parameter.
521                         const std::string allow = tag->getString("allow");
522                         if (allow.empty())
523                                 throw ModuleException("<wsorigin:allow> is a mandatory field, at " + tag->getTagLocation());
524
525                         config.allowedorigins.push_back(allow);
526                 }
527
528                 ConfigTag* tag = ServerInstance->Config->ConfValue("websocket");
529                 config.sendastext = tag->getBool("sendastext", true);
530
531                 irc::spacesepstream proxyranges(tag->getString("proxyranges"));
532                 for (std::string proxyrange; proxyranges.GetToken(proxyrange); )
533                         config.proxyranges.push_back(proxyrange);
534
535                 // Everything is okay; apply the new config.
536                 hookprov->config = config;
537         }
538
539         void OnCleanup(ExtensionItem::ExtensibleType type, Extensible* item) CXX11_OVERRIDE
540         {
541                 if (type != ExtensionItem::EXT_USER)
542                         return;
543
544                 LocalUser* user = IS_LOCAL(static_cast<User*>(item));
545                 if ((user) && (user->eh.GetModHook(this)))
546                         ServerInstance->Users.QuitUser(user, "WebSocket module unloading");
547         }
548
549         Version GetVersion() CXX11_OVERRIDE
550         {
551                 return Version("Provides RFC 6455 WebSocket support", VF_VENDOR);
552         }
553 };
554
555 MODULE_INIT(ModuleWebSocket)