]> git.netwichtig.de Git - user/henk/code/inspircd.git/blob - src/modules/m_websocket.cpp
Update the cloaks of connected users when their IP address changes.
[user/henk/code/inspircd.git] / src / modules / m_websocket.cpp
1 /*
2  * InspIRCd -- Internet Relay Chat Daemon
3  *
4  *   Copyright (C) 2016 Attila Molnar <attilamolnar@hush.com>
5  *
6  * This file is part of InspIRCd.  InspIRCd is free software: you can
7  * redistribute it and/or modify it under the terms of the GNU General Public
8  * License as published by the Free Software Foundation, version 2.
9  *
10  * This program is distributed in the hope that it will be useful, but WITHOUT
11  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
12  * FOR A PARTICULAR PURPOSE.  See the GNU General Public License for more
13  * details.
14  *
15  * You should have received a copy of the GNU General Public License
16  * along with this program.  If not, see <http://www.gnu.org/licenses/>.
17  */
18
19 /// $CompilerFlags: -Ivendor_directory("utfcpp")
20
21
22 #include "inspircd.h"
23 #include "iohook.h"
24 #include "modules/hash.h"
25
26 #include <utf8.h>
27
28 typedef std::vector<std::string> OriginList;
29
30 static const char MagicGUID[] = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
31 static const char whitespace[] = " \t\r\n";
32 static dynamic_reference_nocheck<HashProvider>* sha1;
33
34 class WebSocketHookProvider : public IOHookProvider
35 {
36  public:
37         OriginList allowedorigins;
38         bool sendastext;
39
40         WebSocketHookProvider(Module* mod)
41                 : IOHookProvider(mod, "websocket", IOHookProvider::IOH_UNKNOWN, true)
42         {
43         }
44
45         void OnAccept(StreamSocket* sock, irc::sockets::sockaddrs* client, irc::sockets::sockaddrs* server) CXX11_OVERRIDE;
46
47         void OnConnect(StreamSocket* sock) CXX11_OVERRIDE
48         {
49         }
50 };
51
52 class WebSocketHook : public IOHookMiddle
53 {
54         class HTTPHeaderFinder
55         {
56                 std::string::size_type bpos;
57                 std::string::size_type len;
58
59          public:
60                 bool Find(const std::string& req, const char* header, std::string::size_type headerlen, std::string::size_type maxpos)
61                 {
62                         std::string::size_type keybegin = req.find(header);
63                         if ((keybegin == std::string::npos) || (keybegin > maxpos) || (keybegin == 0) || (req[keybegin-1] != '\n'))
64                                 return false;
65
66                         keybegin += headerlen;
67
68                         bpos = req.find_first_not_of(whitespace, keybegin, sizeof(whitespace)-1);
69                         if ((bpos == std::string::npos) || (bpos > maxpos))
70                                 return false;
71
72                         const std::string::size_type epos = req.find_first_of(whitespace, bpos, sizeof(whitespace)-1);
73                         len = epos - bpos;
74
75                         return true;
76                 }
77
78                 std::string ExtractValue(const std::string& req) const
79                 {
80                         return std::string(req, bpos, len);
81                 }
82         };
83
84         enum OpCode
85         {
86                 OP_CONTINUATION = 0x00,
87                 OP_TEXT = 0x01,
88                 OP_BINARY = 0x02,
89                 OP_CLOSE = 0x08,
90                 OP_PING = 0x09,
91                 OP_PONG = 0x0a
92         };
93
94         enum State
95         {
96                 STATE_HTTPREQ,
97                 STATE_ESTABLISHED
98         };
99
100         static const unsigned char WS_MASKBIT = (1 << 7);
101         static const unsigned char WS_FINBIT = (1 << 7);
102         static const unsigned char WS_PAYLOAD_LENGTH_MAGIC_LARGE = 126;
103         static const unsigned char WS_PAYLOAD_LENGTH_MAGIC_HUGE = 127;
104         static const size_t WS_MAX_PAYLOAD_LENGTH_SMALL = 125;
105         static const size_t WS_MAX_PAYLOAD_LENGTH_LARGE = 65535;
106         static const size_t MAXHEADERSIZE = sizeof(uint64_t) + 2;
107
108         // Clients sending ping or pong frames faster than this are killed
109         static const time_t MINPINGPONGDELAY = 10;
110
111         State state;
112         time_t lastpingpong;
113         OriginList& allowedorigins;
114         bool& sendastext;
115
116         static size_t FillHeader(unsigned char* outbuf, size_t sendlength, OpCode opcode)
117         {
118                 size_t pos = 0;
119                 outbuf[pos++] = WS_FINBIT | opcode;
120
121                 if (sendlength <= WS_MAX_PAYLOAD_LENGTH_SMALL)
122                 {
123                         outbuf[pos++] = sendlength;
124                 }
125                 else if (sendlength <= WS_MAX_PAYLOAD_LENGTH_LARGE)
126                 {
127                         outbuf[pos++] = WS_PAYLOAD_LENGTH_MAGIC_LARGE;
128                         outbuf[pos++] = (sendlength >> 8) & 0xff;
129                         outbuf[pos++] = sendlength & 0xff;
130                 }
131                 else
132                 {
133                         outbuf[pos++] = WS_PAYLOAD_LENGTH_MAGIC_HUGE;
134                         const uint64_t len = sendlength;
135                         for (int i = sizeof(uint64_t)-1; i >= 0; i--)
136                                 outbuf[pos++] = ((len >> i*8) & 0xff);
137                 }
138
139                 return pos;
140         }
141
142         static StreamSocket::SendQueue::Element PrepareSendQElem(size_t size, OpCode opcode)
143         {
144                 unsigned char header[MAXHEADERSIZE];
145                 const size_t n = FillHeader(header, size, opcode);
146
147                 return StreamSocket::SendQueue::Element(reinterpret_cast<const char*>(header), n);
148         }
149
150         int HandleAppData(StreamSocket* sock, std::string& appdataout, bool allowlarge)
151         {
152                 std::string& myrecvq = GetRecvQ();
153                 // Need 1 byte opcode, minimum 1 byte len, 4 bytes masking key
154                 if (myrecvq.length() < 6)
155                         return 0;
156
157                 const std::string& cmyrecvq = myrecvq;
158                 unsigned char len1 = (unsigned char)cmyrecvq[1];
159                 if (!(len1 & WS_MASKBIT))
160                 {
161                         sock->SetError("WebSocket protocol violation: unmasked client frame");
162                         return -1;
163                 }
164
165                 len1 &= ~WS_MASKBIT;
166
167                 // Assume the length is a single byte, if not, update values later
168                 unsigned int len = len1;
169                 unsigned int payloadstartoffset = 6;
170                 const unsigned char* maskkey = reinterpret_cast<const unsigned char*>(&cmyrecvq[2]);
171
172                 if (len1 == WS_PAYLOAD_LENGTH_MAGIC_LARGE)
173                 {
174                         // allowlarge is false for control frames according to the RFC meaning large pings, etc. are not allowed
175                         if (!allowlarge)
176                         {
177                                 sock->SetError("WebSocket protocol violation: large control frame");
178                                 return -1;
179                         }
180
181                         // Large frame, has 2 bytes len after the magic byte indicating the length
182                         // Need 1 byte opcode, 3 bytes len, 4 bytes masking key
183                         if (myrecvq.length() < 8)
184                                 return 0;
185
186                         unsigned char len2 = (unsigned char)cmyrecvq[2];
187                         unsigned char len3 = (unsigned char)cmyrecvq[3];
188                         len = (len2 << 8) | len3;
189
190                         if (len <= WS_MAX_PAYLOAD_LENGTH_SMALL)
191                         {
192                                 sock->SetError("WebSocket protocol violation: non-minimal length encoding used");
193                                 return -1;
194                         }
195
196                         maskkey += 2;
197                         payloadstartoffset += 2;
198                 }
199                 else if (len1 == WS_PAYLOAD_LENGTH_MAGIC_HUGE)
200                 {
201                         sock->SetError("WebSocket: Huge frames are not supported");
202                         return -1;
203                 }
204
205                 if (myrecvq.length() < payloadstartoffset + len)
206                         return 0;
207
208                 unsigned int maskkeypos = 0;
209                 const std::string::iterator endit = myrecvq.begin() + payloadstartoffset + len;
210                 for (std::string::const_iterator i = myrecvq.begin() + payloadstartoffset; i != endit; ++i)
211                 {
212                         const unsigned char c = (unsigned char)*i;
213                         appdataout.push_back(c ^ maskkey[maskkeypos++]);
214                         maskkeypos %= 4;
215                 }
216
217                 myrecvq.erase(myrecvq.begin(), endit);
218                 return 1;
219         }
220
221         int HandlePingPongFrame(StreamSocket* sock, bool isping)
222         {
223                 if (lastpingpong + MINPINGPONGDELAY >= ServerInstance->Time())
224                 {
225                         sock->SetError("WebSocket: Ping/pong flood");
226                         return -1;
227                 }
228
229                 lastpingpong = ServerInstance->Time();
230
231                 std::string appdata;
232                 const int result = HandleAppData(sock, appdata, false);
233                 // If it's a pong stop here regardless of the result so we won't generate a reply
234                 if ((result <= 0) || (!isping))
235                         return result;
236
237                 StreamSocket::SendQueue::Element elem = PrepareSendQElem(appdata.length(), OP_PONG);
238                 elem.append(appdata);
239                 GetSendQ().push_back(elem);
240
241                 SocketEngine::ChangeEventMask(sock, FD_ADD_TRIAL_WRITE);
242                 return 1;
243         }
244
245         int HandleWS(StreamSocket* sock, std::string& destrecvq)
246         {
247                 if (GetRecvQ().empty())
248                         return 0;
249
250                 unsigned char opcode = (unsigned char)GetRecvQ().c_str()[0];
251                 switch (opcode & ~WS_FINBIT)
252                 {
253                         case OP_CONTINUATION:
254                         case OP_TEXT:
255                         case OP_BINARY:
256                         {
257                                 std::string appdata;
258                                 const int result = HandleAppData(sock, appdata, true);
259                                 if (result != 1)
260                                         return result;
261
262                                 // Strip out any CR+LF which may have been erroneously sent.
263                                 for (std::string::const_iterator iter = appdata.begin(); iter != appdata.end(); ++iter)
264                                 {
265                                         if (*iter != '\r' && *iter != '\n')
266                                                 destrecvq.push_back(*iter);
267                                 }
268
269                                 // If we are on the final message of this block append a line terminator.
270                                 if (opcode & WS_FINBIT)
271                                         destrecvq.append("\r\n");
272
273                                 return 1;
274                         }
275
276                         case OP_PING:
277                         {
278                                 return HandlePingPongFrame(sock, true);
279                         }
280
281                         case OP_PONG:
282                         {
283                                 // A pong frame may be sent unsolicited, so we have to handle it.
284                                 // It may carry application data which we need to remove from the recvq as well.
285                                 return HandlePingPongFrame(sock, false);
286                         }
287
288                         case OP_CLOSE:
289                         {
290                                 sock->SetError("Connection closed");
291                                 return -1;
292                         }
293
294                         default:
295                         {
296                                 sock->SetError("WebSocket: Invalid opcode");
297                                 return -1;
298                         }
299                 }
300         }
301
302         void FailHandshake(StreamSocket* sock, const char* httpreply, const char* sockerror)
303         {
304                 GetSendQ().push_back(StreamSocket::SendQueue::Element(httpreply));
305                 sock->DoWrite();
306                 sock->SetError(sockerror);
307         }
308
309         int HandleHTTPReq(StreamSocket* sock)
310         {
311                 std::string& recvq = GetRecvQ();
312                 const std::string::size_type reqend = recvq.find("\r\n\r\n");
313                 if (reqend == std::string::npos)
314                         return 0;
315
316                 bool allowedorigin = false;
317                 HTTPHeaderFinder originheader;
318                 if (originheader.Find(recvq, "Origin:", 7, reqend))
319                 {
320                         const std::string origin = originheader.ExtractValue(recvq);
321                         for (OriginList::const_iterator iter = allowedorigins.begin(); iter != allowedorigins.end(); ++iter)
322                         {
323                                 if (InspIRCd::Match(origin, *iter, ascii_case_insensitive_map))
324                                 {
325                                         allowedorigin = true;
326                                         break;
327                                 }
328                         }
329                 }
330
331                 if (!allowedorigin)
332                 {
333                         FailHandshake(sock, "HTTP/1.1 403 Forbidden\r\nConnection: close\r\n\r\n", "WebSocket: Received HTTP request from a non-whitelisted origin");
334                         return -1;
335                 }
336
337                 HTTPHeaderFinder keyheader;
338                 if (!keyheader.Find(recvq, "Sec-WebSocket-Key:", 18, reqend))
339                 {
340                         FailHandshake(sock, "HTTP/1.1 501 Not Implemented\r\nConnection: close\r\n\r\n", "WebSocket: Received HTTP request which is not a websocket upgrade");
341                         return -1;
342                 }
343
344                 if (!*sha1)
345                 {
346                         FailHandshake(sock, "HTTP/1.1 503 Service Unavailable\r\nConnection: close\r\n\r\n", "WebSocket: SHA-1 provider missing");
347                         return -1;
348                 }
349
350                 state = STATE_ESTABLISHED;
351
352                 std::string key = keyheader.ExtractValue(recvq);
353                 key.append(MagicGUID);
354
355                 std::string reply = "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: ";
356                 reply.append(BinToBase64((*sha1)->GenerateRaw(key), NULL, '=')).append("\r\n\r\n");
357                 GetSendQ().push_back(StreamSocket::SendQueue::Element(reply));
358
359                 SocketEngine::ChangeEventMask(sock, FD_ADD_TRIAL_WRITE);
360
361                 recvq.erase(0, reqend + 4);
362
363                 return 1;
364         }
365
366  public:
367         WebSocketHook(IOHookProvider* Prov, StreamSocket* sock, OriginList& AllowedOrigins, bool& SendAsText)
368                 : IOHookMiddle(Prov)
369                 , state(STATE_HTTPREQ)
370                 , lastpingpong(0)
371                 , allowedorigins(AllowedOrigins)
372                 , sendastext(SendAsText)
373         {
374                 sock->AddIOHook(this);
375         }
376
377         int OnStreamSocketWrite(StreamSocket* sock, StreamSocket::SendQueue& uppersendq) CXX11_OVERRIDE
378         {
379                 StreamSocket::SendQueue& mysendq = GetSendQ();
380
381                 // Return 1 to allow sending back an error HTTP response
382                 if (state != STATE_ESTABLISHED)
383                         return (mysendq.empty() ? 0 : 1);
384
385                 std::string message;
386                 for (StreamSocket::SendQueue::const_iterator elem = uppersendq.begin(); elem != uppersendq.end(); ++elem)
387                 {
388                         for (StreamSocket::SendQueue::Element::const_iterator chr = elem->begin(); chr != elem->end(); ++chr)
389                         {
390                                 if (*chr == '\n')
391                                 {
392                                         // We have found an entire message. Send it in its own frame.
393                                         if (sendastext)
394                                         {
395                                                 // If we send messages as text then we need to ensure they are valid UTF-8.
396                                                 std::string encoded;
397                                                 utf8::replace_invalid(message.begin(), message.end(), std::back_inserter(encoded));
398
399                                                 mysendq.push_back(PrepareSendQElem(encoded.length(), OP_TEXT));
400                                                 mysendq.push_back(encoded);
401                                         }
402                                         else
403                                         {
404                                                 // Otherwise, send the raw message as a binary frame.
405                                                 mysendq.push_back(PrepareSendQElem(message.length(), OP_BINARY));
406                                                 mysendq.push_back(message);
407                                         }
408                                         message.clear();
409                                 }
410                                 else if (*chr != '\r')
411                                 {
412                                         message.push_back(*chr);
413                                 }
414                         }
415                 }
416
417                 // Empty the upper send queue and push whatever is left back onto it.
418                 uppersendq.clear();
419                 if (!message.empty())
420                 {
421                         uppersendq.push_back(message);
422                         return 0;
423                 }
424
425                 return 1;
426         }
427
428         int OnStreamSocketRead(StreamSocket* sock, std::string& destrecvq) CXX11_OVERRIDE
429         {
430                 if (state == STATE_HTTPREQ)
431                 {
432                         int httpret = HandleHTTPReq(sock);
433                         if (httpret <= 0)
434                                 return httpret;
435                 }
436
437                 int wsret;
438                 do
439                 {
440                         wsret = HandleWS(sock, destrecvq);
441                 }
442                 while ((!GetRecvQ().empty()) && (wsret > 0));
443
444                 return wsret;
445         }
446
447         void OnStreamSocketClose(StreamSocket* sock) CXX11_OVERRIDE
448         {
449         }
450 };
451
452 void WebSocketHookProvider::OnAccept(StreamSocket* sock, irc::sockets::sockaddrs* client, irc::sockets::sockaddrs* server)
453 {
454         new WebSocketHook(this, sock, allowedorigins, sendastext);
455 }
456
457 class ModuleWebSocket : public Module
458 {
459         dynamic_reference_nocheck<HashProvider> hash;
460         reference<WebSocketHookProvider> hookprov;
461
462  public:
463         ModuleWebSocket()
464                 : hash(this, "hash/sha1")
465                 , hookprov(new WebSocketHookProvider(this))
466         {
467                 sha1 = &hash;
468         }
469
470         void ReadConfig(ConfigStatus& status) CXX11_OVERRIDE
471         {
472                 ConfigTagList tags = ServerInstance->Config->ConfTags("wsorigin");
473                 if (tags.first == tags.second)
474                         throw ModuleException("You have loaded the websocket module but not configured any allowed origins!");
475
476                 OriginList allowedorigins;
477                 for (ConfigIter i = tags.first; i != tags.second; ++i)
478                 {
479                         ConfigTag* tag = i->second;
480
481                         // Ensure that we have the <wsorigin:allow> parameter.
482                         const std::string allow = tag->getString("allow");
483                         if (allow.empty())
484                                 throw ModuleException("<wsorigin:allow> is a mandatory field, at " + tag->getTagLocation());
485
486                         allowedorigins.push_back(allow);
487                 }
488
489                 ConfigTag* tag = ServerInstance->Config->ConfValue("websocket");
490                 hookprov->sendastext = tag->getBool("sendastext", true);
491                 hookprov->allowedorigins.swap(allowedorigins);
492         }
493
494         void OnCleanup(ExtensionItem::ExtensibleType type, Extensible* item) CXX11_OVERRIDE
495         {
496                 if (type != ExtensionItem::EXT_USER)
497                         return;
498
499                 LocalUser* user = IS_LOCAL(static_cast<User*>(item));
500                 if ((user) && (user->eh.GetModHook(this)))
501                         ServerInstance->Users.QuitUser(user, "WebSocket module unloading");
502         }
503
504         Version GetVersion() CXX11_OVERRIDE
505         {
506                 return Version("Provides RFC 6455 WebSocket support", VF_VENDOR);
507         }
508 };
509
510 MODULE_INIT(ModuleWebSocket)