]> git.netwichtig.de Git - user/henk/code/inspircd.git/blob - src/modules/m_websocket.cpp
7458e56b699a113ca40ca0b47c2a239c8e9198ef
[user/henk/code/inspircd.git] / src / modules / m_websocket.cpp
1 /*
2  * InspIRCd -- Internet Relay Chat Daemon
3  *
4  *   Copyright (C) 2016 Attila Molnar <attilamolnar@hush.com>
5  *
6  * This file is part of InspIRCd.  InspIRCd is free software: you can
7  * redistribute it and/or modify it under the terms of the GNU General Public
8  * License as published by the Free Software Foundation, version 2.
9  *
10  * This program is distributed in the hope that it will be useful, but WITHOUT
11  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
12  * FOR A PARTICULAR PURPOSE.  See the GNU General Public License for more
13  * details.
14  *
15  * You should have received a copy of the GNU General Public License
16  * along with this program.  If not, see <http://www.gnu.org/licenses/>.
17  */
18
19
20 #include "inspircd.h"
21 #include "iohook.h"
22 #include "modules/hash.h"
23
24 typedef std::vector<std::string> OriginList;
25
26 static const char MagicGUID[] = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
27 static const char whitespace[] = " \t\r\n";
28 static dynamic_reference_nocheck<HashProvider>* sha1;
29
30 class WebSocketHookProvider : public IOHookProvider
31 {
32  public:
33         OriginList allowedorigins;
34
35         WebSocketHookProvider(Module* mod)
36                 : IOHookProvider(mod, "websocket", IOHookProvider::IOH_UNKNOWN, true)
37         {
38         }
39
40         void OnAccept(StreamSocket* sock, irc::sockets::sockaddrs* client, irc::sockets::sockaddrs* server) CXX11_OVERRIDE;
41
42         void OnConnect(StreamSocket* sock) CXX11_OVERRIDE
43         {
44         }
45 };
46
47 class WebSocketHook : public IOHookMiddle
48 {
49         class HTTPHeaderFinder
50         {
51                 std::string::size_type bpos;
52                 std::string::size_type len;
53
54          public:
55                 bool Find(const std::string& req, const char* header, std::string::size_type headerlen, std::string::size_type maxpos)
56                 {
57                         std::string::size_type keybegin = req.find(header);
58                         if ((keybegin == std::string::npos) || (keybegin > maxpos) || (keybegin == 0) || (req[keybegin-1] != '\n'))
59                                 return false;
60
61                         keybegin += headerlen;
62
63                         bpos = req.find_first_not_of(whitespace, keybegin, sizeof(whitespace)-1);
64                         if ((bpos == std::string::npos) || (bpos > maxpos))
65                                 return false;
66
67                         const std::string::size_type epos = req.find_first_of(whitespace, bpos, sizeof(whitespace)-1);
68                         len = epos - bpos;
69
70                         return true;
71                 }
72
73                 std::string ExtractValue(const std::string& req) const
74                 {
75                         return std::string(req, bpos, len);
76                 }
77         };
78
79         enum OpCode
80         {
81                 OP_CONTINUATION = 0x00,
82                 OP_TEXT = 0x01,
83                 OP_BINARY = 0x02,
84                 OP_CLOSE = 0x08,
85                 OP_PING = 0x09,
86                 OP_PONG = 0x0a
87         };
88
89         enum State
90         {
91                 STATE_HTTPREQ,
92                 STATE_ESTABLISHED
93         };
94
95         static const unsigned char WS_MASKBIT = (1 << 7);
96         static const unsigned char WS_FINBIT = (1 << 7);
97         static const unsigned char WS_PAYLOAD_LENGTH_MAGIC_LARGE = 126;
98         static const unsigned char WS_PAYLOAD_LENGTH_MAGIC_HUGE = 127;
99         static const size_t WS_MAX_PAYLOAD_LENGTH_SMALL = 125;
100         static const size_t WS_MAX_PAYLOAD_LENGTH_LARGE = 65535;
101         static const size_t MAXHEADERSIZE = sizeof(uint64_t) + 2;
102
103         // Clients sending ping or pong frames faster than this are killed
104         static const time_t MINPINGPONGDELAY = 10;
105
106         State state;
107         time_t lastpingpong;
108         OriginList& allowedorigins;
109
110         static size_t FillHeader(unsigned char* outbuf, size_t sendlength, OpCode opcode)
111         {
112                 size_t pos = 0;
113                 outbuf[pos++] = WS_FINBIT | opcode;
114
115                 if (sendlength <= WS_MAX_PAYLOAD_LENGTH_SMALL)
116                 {
117                         outbuf[pos++] = sendlength;
118                 }
119                 else if (sendlength <= WS_MAX_PAYLOAD_LENGTH_LARGE)
120                 {
121                         outbuf[pos++] = WS_PAYLOAD_LENGTH_MAGIC_LARGE;
122                         outbuf[pos++] = (sendlength >> 8) & 0xff;
123                         outbuf[pos++] = sendlength & 0xff;
124                 }
125                 else
126                 {
127                         outbuf[pos++] = WS_PAYLOAD_LENGTH_MAGIC_HUGE;
128                         const uint64_t len = sendlength;
129                         for (int i = sizeof(uint64_t)-1; i >= 0; i--)
130                                 outbuf[pos++] = ((len >> i*8) & 0xff);
131                 }
132
133                 return pos;
134         }
135
136         static StreamSocket::SendQueue::Element PrepareSendQElem(size_t size, OpCode opcode)
137         {
138                 unsigned char header[MAXHEADERSIZE];
139                 const size_t n = FillHeader(header, size, opcode);
140
141                 return StreamSocket::SendQueue::Element(reinterpret_cast<const char*>(header), n);
142         }
143
144         int HandleAppData(StreamSocket* sock, std::string& appdataout, bool allowlarge)
145         {
146                 std::string& myrecvq = GetRecvQ();
147                 // Need 1 byte opcode, minimum 1 byte len, 4 bytes masking key
148                 if (myrecvq.length() < 6)
149                         return 0;
150
151                 const std::string& cmyrecvq = myrecvq;
152                 unsigned char len1 = (unsigned char)cmyrecvq[1];
153                 if (!(len1 & WS_MASKBIT))
154                 {
155                         sock->SetError("WebSocket protocol violation: unmasked client frame");
156                         return -1;
157                 }
158
159                 len1 &= ~WS_MASKBIT;
160
161                 // Assume the length is a single byte, if not, update values later
162                 unsigned int len = len1;
163                 unsigned int payloadstartoffset = 6;
164                 const unsigned char* maskkey = reinterpret_cast<const unsigned char*>(&cmyrecvq[2]);
165
166                 if (len1 == WS_PAYLOAD_LENGTH_MAGIC_LARGE)
167                 {
168                         // allowlarge is false for control frames according to the RFC meaning large pings, etc. are not allowed
169                         if (!allowlarge)
170                         {
171                                 sock->SetError("WebSocket protocol violation: large control frame");
172                                 return -1;
173                         }
174
175                         // Large frame, has 2 bytes len after the magic byte indicating the length
176                         // Need 1 byte opcode, 3 bytes len, 4 bytes masking key
177                         if (myrecvq.length() < 8)
178                                 return 0;
179
180                         unsigned char len2 = (unsigned char)cmyrecvq[2];
181                         unsigned char len3 = (unsigned char)cmyrecvq[3];
182                         len = (len2 << 8) | len3;
183
184                         if (len <= WS_MAX_PAYLOAD_LENGTH_SMALL)
185                         {
186                                 sock->SetError("WebSocket protocol violation: non-minimal length encoding used");
187                                 return -1;
188                         }
189
190                         maskkey += 2;
191                         payloadstartoffset += 2;
192                 }
193                 else if (len1 == WS_PAYLOAD_LENGTH_MAGIC_HUGE)
194                 {
195                         sock->SetError("WebSocket: Huge frames are not supported");
196                         return -1;
197                 }
198
199                 if (myrecvq.length() < payloadstartoffset + len)
200                         return 0;
201
202                 unsigned int maskkeypos = 0;
203                 const std::string::iterator endit = myrecvq.begin() + payloadstartoffset + len;
204                 for (std::string::const_iterator i = myrecvq.begin() + payloadstartoffset; i != endit; ++i)
205                 {
206                         const unsigned char c = (unsigned char)*i;
207                         appdataout.push_back(c ^ maskkey[maskkeypos++]);
208                         maskkeypos %= 4;
209                 }
210
211                 myrecvq.erase(myrecvq.begin(), endit);
212                 return 1;
213         }
214
215         int HandlePingPongFrame(StreamSocket* sock, bool isping)
216         {
217                 if (lastpingpong + MINPINGPONGDELAY >= ServerInstance->Time())
218                 {
219                         sock->SetError("WebSocket: Ping/pong flood");
220                         return -1;
221                 }
222
223                 lastpingpong = ServerInstance->Time();
224
225                 std::string appdata;
226                 const int result = HandleAppData(sock, appdata, false);
227                 // If it's a pong stop here regardless of the result so we won't generate a reply
228                 if ((result <= 0) || (!isping))
229                         return result;
230
231                 StreamSocket::SendQueue::Element elem = PrepareSendQElem(appdata.length(), OP_PONG);
232                 elem.append(appdata);
233                 GetSendQ().push_back(elem);
234
235                 SocketEngine::ChangeEventMask(sock, FD_ADD_TRIAL_WRITE);
236                 return 1;
237         }
238
239         int HandleWS(StreamSocket* sock, std::string& destrecvq)
240         {
241                 if (GetRecvQ().empty())
242                         return 0;
243
244                 unsigned char opcode = (unsigned char)GetRecvQ().c_str()[0];
245                 switch (opcode & ~WS_FINBIT)
246                 {
247                         case OP_CONTINUATION:
248                         case OP_TEXT:
249                         case OP_BINARY:
250                         {
251                                 std::string appdata;
252                                 const int result = HandleAppData(sock, appdata, true);
253                                 if (result != 1)
254                                         return result;
255
256                                 // Strip out any CR+LF which may have been erroneously sent.
257                                 for (std::string::const_iterator iter = appdata.begin(); iter != appdata.end(); ++iter)
258                                 {
259                                         if (*iter != '\r' && *iter != '\n')
260                                                 destrecvq.push_back(*iter);
261                                 }
262
263                                 // If we are on the final message of this block append a line terminator.
264                                 if (opcode & WS_FINBIT)
265                                         destrecvq.append("\r\n");
266
267                                 return 1;
268                         }
269
270                         case OP_PING:
271                         {
272                                 return HandlePingPongFrame(sock, true);
273                         }
274
275                         case OP_PONG:
276                         {
277                                 // A pong frame may be sent unsolicited, so we have to handle it.
278                                 // It may carry application data which we need to remove from the recvq as well.
279                                 return HandlePingPongFrame(sock, false);
280                         }
281
282                         case OP_CLOSE:
283                         {
284                                 sock->SetError("Connection closed");
285                                 return -1;
286                         }
287
288                         default:
289                         {
290                                 sock->SetError("WebSocket: Invalid opcode");
291                                 return -1;
292                         }
293                 }
294         }
295
296         void FailHandshake(StreamSocket* sock, const char* httpreply, const char* sockerror)
297         {
298                 GetSendQ().push_back(StreamSocket::SendQueue::Element(httpreply));
299                 sock->DoWrite();
300                 sock->SetError(sockerror);
301         }
302
303         int HandleHTTPReq(StreamSocket* sock)
304         {
305                 std::string& recvq = GetRecvQ();
306                 const std::string::size_type reqend = recvq.find("\r\n\r\n");
307                 if (reqend == std::string::npos)
308                         return 0;
309
310                 bool allowedorigin = false;
311                 HTTPHeaderFinder originheader;
312                 if (originheader.Find(recvq, "Origin:", 7, reqend))
313                 {
314                         const std::string origin = originheader.ExtractValue(recvq);
315                         for (OriginList::const_iterator iter = allowedorigins.begin(); iter != allowedorigins.end(); ++iter)
316                         {
317                                 if (InspIRCd::Match(origin, *iter, ascii_case_insensitive_map))
318                                 {
319                                         allowedorigin = true;
320                                         break;
321                                 }
322                         }
323                 }
324
325                 if (!allowedorigin)
326                 {
327                         FailHandshake(sock, "HTTP/1.1 403 Forbidden\r\nConnection: close\r\n\r\n", "WebSocket: Received HTTP request from a non-whitelisted origin");
328                         return -1;
329                 }
330
331                 HTTPHeaderFinder keyheader;
332                 if (!keyheader.Find(recvq, "Sec-WebSocket-Key:", 18, reqend))
333                 {
334                         FailHandshake(sock, "HTTP/1.1 501 Not Implemented\r\nConnection: close\r\n\r\n", "WebSocket: Received HTTP request which is not a websocket upgrade");
335                         return -1;
336                 }
337
338                 if (!*sha1)
339                 {
340                         FailHandshake(sock, "HTTP/1.1 503 Service Unavailable\r\nConnection: close\r\n\r\n", "WebSocket: SHA-1 provider missing");
341                         return -1;
342                 }
343
344                 state = STATE_ESTABLISHED;
345
346                 std::string key = keyheader.ExtractValue(recvq);
347                 key.append(MagicGUID);
348
349                 std::string reply = "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: ";
350                 reply.append(BinToBase64((*sha1)->GenerateRaw(key), NULL, '=')).append("\r\n\r\n");
351                 GetSendQ().push_back(StreamSocket::SendQueue::Element(reply));
352
353                 SocketEngine::ChangeEventMask(sock, FD_ADD_TRIAL_WRITE);
354
355                 recvq.erase(0, reqend + 4);
356
357                 return 1;
358         }
359
360  public:
361         WebSocketHook(IOHookProvider* Prov, StreamSocket* sock, OriginList& AllowedOrigins)
362                 : IOHookMiddle(Prov)
363                 , state(STATE_HTTPREQ)
364                 , lastpingpong(0)
365                 , allowedorigins(AllowedOrigins)
366         {
367                 sock->AddIOHook(this);
368         }
369
370         int OnStreamSocketWrite(StreamSocket* sock, StreamSocket::SendQueue& uppersendq) CXX11_OVERRIDE
371         {
372                 StreamSocket::SendQueue& mysendq = GetSendQ();
373
374                 // Return 1 to allow sending back an error HTTP response
375                 if (state != STATE_ESTABLISHED)
376                         return (mysendq.empty() ? 0 : 1);
377
378                 std::string message;
379                 for (StreamSocket::SendQueue::const_iterator elem = uppersendq.begin(); elem != uppersendq.end(); ++elem)
380                 {
381                         for (StreamSocket::SendQueue::Element::const_iterator chr = elem->begin(); chr != elem->end(); ++chr)
382                         {
383                                 if (*chr == '\n')
384                                 {
385                                         // We have found an entire message. Send it in its own frame.
386                                         mysendq.push_back(PrepareSendQElem(message.length(), OP_BINARY));
387                                         mysendq.push_back(message);
388                                         message.clear();
389                                 }
390                                 else if (*chr != '\r')
391                                 {
392                                         message.push_back(*chr);
393                                 }
394                         }
395                 }
396
397                 // Empty the upper send queue and push whatever is left back onto it.
398                 uppersendq.clear();
399                 if (!message.empty())
400                 {
401                         uppersendq.push_back(message);
402                         return 0;
403                 }
404
405                 return 1;
406         }
407
408         int OnStreamSocketRead(StreamSocket* sock, std::string& destrecvq) CXX11_OVERRIDE
409         {
410                 if (state == STATE_HTTPREQ)
411                 {
412                         int httpret = HandleHTTPReq(sock);
413                         if (httpret <= 0)
414                                 return httpret;
415                 }
416
417                 int wsret;
418                 do
419                 {
420                         wsret = HandleWS(sock, destrecvq);
421                 }
422                 while ((!GetRecvQ().empty()) && (wsret > 0));
423
424                 return wsret;
425         }
426
427         void OnStreamSocketClose(StreamSocket* sock) CXX11_OVERRIDE
428         {
429         }
430 };
431
432 void WebSocketHookProvider::OnAccept(StreamSocket* sock, irc::sockets::sockaddrs* client, irc::sockets::sockaddrs* server)
433 {
434         new WebSocketHook(this, sock, allowedorigins);
435 }
436
437 class ModuleWebSocket : public Module
438 {
439         dynamic_reference_nocheck<HashProvider> hash;
440         reference<WebSocketHookProvider> hookprov;
441
442  public:
443         ModuleWebSocket()
444                 : hash(this, "hash/sha1")
445                 , hookprov(new WebSocketHookProvider(this))
446         {
447                 sha1 = &hash;
448         }
449
450         void ReadConfig(ConfigStatus& status) CXX11_OVERRIDE
451         {
452                 ConfigTagList tags = ServerInstance->Config->ConfTags("wsorigin");
453                 if (tags.first == tags.second)
454                         throw ModuleException("You have loaded the websocket module but not configured any allowed origins!");
455
456                 OriginList allowedorigins;
457                 for (ConfigIter i = tags.first; i != tags.second; ++i)
458                 {
459                         ConfigTag* tag = i->second;
460
461                         // Ensure that we have the <wsorigin:allow> parameter.
462                         const std::string allow = tag->getString("allow");
463                         if (allow.empty())
464                                 throw ModuleException("<wsorigin:allow> is a mandatory field, at " + tag->getTagLocation());
465
466                         allowedorigins.push_back(allow);
467                 }
468                 hookprov->allowedorigins.swap(allowedorigins);
469         }
470
471         void OnCleanup(ExtensionItem::ExtensibleType type, Extensible* item) CXX11_OVERRIDE
472         {
473                 if (type != ExtensionItem::EXT_USER)
474                         return;
475
476                 LocalUser* user = IS_LOCAL(static_cast<User*>(item));
477                 if ((user) && (user->eh.GetModHook(this)))
478                         ServerInstance->Users.QuitUser(user, "WebSocket module unloading");
479         }
480
481         Version GetVersion() CXX11_OVERRIDE
482         {
483                 return Version("Provides RFC 6455 WebSocket support", VF_VENDOR);
484         }
485 };
486
487 MODULE_INIT(ModuleWebSocket)