]> git.netwichtig.de Git - user/henk/code/inspircd.git/blob - src/modules/m_websocket.cpp
Implement support for websocket connections via a proxy like nginx.
[user/henk/code/inspircd.git] / src / modules / m_websocket.cpp
1 /*
2  * InspIRCd -- Internet Relay Chat Daemon
3  *
4  *   Copyright (C) 2016 Attila Molnar <attilamolnar@hush.com>
5  *
6  * This file is part of InspIRCd.  InspIRCd is free software: you can
7  * redistribute it and/or modify it under the terms of the GNU General Public
8  * License as published by the Free Software Foundation, version 2.
9  *
10  * This program is distributed in the hope that it will be useful, but WITHOUT
11  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
12  * FOR A PARTICULAR PURPOSE.  See the GNU General Public License for more
13  * details.
14  *
15  * You should have received a copy of the GNU General Public License
16  * along with this program.  If not, see <http://www.gnu.org/licenses/>.
17  */
18
19 /// $CompilerFlags: -Ivendor_directory("utfcpp")
20
21
22 #include "inspircd.h"
23 #include "iohook.h"
24 #include "modules/hash.h"
25
26 #include <utf8.h>
27
28 typedef std::vector<std::string> OriginList;
29
30 static const char MagicGUID[] = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
31 static const char whitespace[] = " \t\r\n";
32 static dynamic_reference_nocheck<HashProvider>* sha1;
33
34 struct WebSocketConfig
35 {
36         // The HTTP origins that can connect to the server.
37         OriginList allowedorigins;
38
39         // Whether to trust the X-Real-IP or X-Forwarded-For headers.
40         bool behindproxy;
41
42         // Whether to send as UTF-8 text instead of binary data.
43         bool sendastext;
44 };
45
46 class WebSocketHookProvider : public IOHookProvider
47 {
48  public:
49         WebSocketConfig config;
50         WebSocketHookProvider(Module* mod)
51                 : IOHookProvider(mod, "websocket", IOHookProvider::IOH_UNKNOWN, true)
52         {
53         }
54
55         void OnAccept(StreamSocket* sock, irc::sockets::sockaddrs* client, irc::sockets::sockaddrs* server) CXX11_OVERRIDE;
56
57         void OnConnect(StreamSocket* sock) CXX11_OVERRIDE
58         {
59         }
60 };
61
62 class WebSocketHook : public IOHookMiddle
63 {
64         class HTTPHeaderFinder
65         {
66                 std::string::size_type bpos;
67                 std::string::size_type len;
68
69          public:
70                 bool Find(const std::string& req, const char* header, std::string::size_type headerlen, std::string::size_type maxpos)
71                 {
72                         std::string::size_type keybegin = req.find(header);
73                         if ((keybegin == std::string::npos) || (keybegin > maxpos) || (keybegin == 0) || (req[keybegin-1] != '\n'))
74                                 return false;
75
76                         keybegin += headerlen;
77
78                         bpos = req.find_first_not_of(whitespace, keybegin, sizeof(whitespace)-1);
79                         if ((bpos == std::string::npos) || (bpos > maxpos))
80                                 return false;
81
82                         const std::string::size_type epos = req.find_first_of(whitespace, bpos, sizeof(whitespace)-1);
83                         len = epos - bpos;
84
85                         return true;
86                 }
87
88                 std::string ExtractValue(const std::string& req) const
89                 {
90                         return std::string(req, bpos, len);
91                 }
92         };
93
94         enum OpCode
95         {
96                 OP_CONTINUATION = 0x00,
97                 OP_TEXT = 0x01,
98                 OP_BINARY = 0x02,
99                 OP_CLOSE = 0x08,
100                 OP_PING = 0x09,
101                 OP_PONG = 0x0a
102         };
103
104         enum State
105         {
106                 STATE_HTTPREQ,
107                 STATE_ESTABLISHED
108         };
109
110         static const unsigned char WS_MASKBIT = (1 << 7);
111         static const unsigned char WS_FINBIT = (1 << 7);
112         static const unsigned char WS_PAYLOAD_LENGTH_MAGIC_LARGE = 126;
113         static const unsigned char WS_PAYLOAD_LENGTH_MAGIC_HUGE = 127;
114         static const size_t WS_MAX_PAYLOAD_LENGTH_SMALL = 125;
115         static const size_t WS_MAX_PAYLOAD_LENGTH_LARGE = 65535;
116         static const size_t MAXHEADERSIZE = sizeof(uint64_t) + 2;
117
118         // Clients sending ping or pong frames faster than this are killed
119         static const time_t MINPINGPONGDELAY = 10;
120
121         State state;
122         time_t lastpingpong;
123         WebSocketConfig& config;
124
125         static size_t FillHeader(unsigned char* outbuf, size_t sendlength, OpCode opcode)
126         {
127                 size_t pos = 0;
128                 outbuf[pos++] = WS_FINBIT | opcode;
129
130                 if (sendlength <= WS_MAX_PAYLOAD_LENGTH_SMALL)
131                 {
132                         outbuf[pos++] = sendlength;
133                 }
134                 else if (sendlength <= WS_MAX_PAYLOAD_LENGTH_LARGE)
135                 {
136                         outbuf[pos++] = WS_PAYLOAD_LENGTH_MAGIC_LARGE;
137                         outbuf[pos++] = (sendlength >> 8) & 0xff;
138                         outbuf[pos++] = sendlength & 0xff;
139                 }
140                 else
141                 {
142                         outbuf[pos++] = WS_PAYLOAD_LENGTH_MAGIC_HUGE;
143                         const uint64_t len = sendlength;
144                         for (int i = sizeof(uint64_t)-1; i >= 0; i--)
145                                 outbuf[pos++] = ((len >> i*8) & 0xff);
146                 }
147
148                 return pos;
149         }
150
151         static StreamSocket::SendQueue::Element PrepareSendQElem(size_t size, OpCode opcode)
152         {
153                 unsigned char header[MAXHEADERSIZE];
154                 const size_t n = FillHeader(header, size, opcode);
155
156                 return StreamSocket::SendQueue::Element(reinterpret_cast<const char*>(header), n);
157         }
158
159         int HandleAppData(StreamSocket* sock, std::string& appdataout, bool allowlarge)
160         {
161                 std::string& myrecvq = GetRecvQ();
162                 // Need 1 byte opcode, minimum 1 byte len, 4 bytes masking key
163                 if (myrecvq.length() < 6)
164                         return 0;
165
166                 const std::string& cmyrecvq = myrecvq;
167                 unsigned char len1 = (unsigned char)cmyrecvq[1];
168                 if (!(len1 & WS_MASKBIT))
169                 {
170                         sock->SetError("WebSocket protocol violation: unmasked client frame");
171                         return -1;
172                 }
173
174                 len1 &= ~WS_MASKBIT;
175
176                 // Assume the length is a single byte, if not, update values later
177                 unsigned int len = len1;
178                 unsigned int payloadstartoffset = 6;
179                 const unsigned char* maskkey = reinterpret_cast<const unsigned char*>(&cmyrecvq[2]);
180
181                 if (len1 == WS_PAYLOAD_LENGTH_MAGIC_LARGE)
182                 {
183                         // allowlarge is false for control frames according to the RFC meaning large pings, etc. are not allowed
184                         if (!allowlarge)
185                         {
186                                 sock->SetError("WebSocket protocol violation: large control frame");
187                                 return -1;
188                         }
189
190                         // Large frame, has 2 bytes len after the magic byte indicating the length
191                         // Need 1 byte opcode, 3 bytes len, 4 bytes masking key
192                         if (myrecvq.length() < 8)
193                                 return 0;
194
195                         unsigned char len2 = (unsigned char)cmyrecvq[2];
196                         unsigned char len3 = (unsigned char)cmyrecvq[3];
197                         len = (len2 << 8) | len3;
198
199                         if (len <= WS_MAX_PAYLOAD_LENGTH_SMALL)
200                         {
201                                 sock->SetError("WebSocket protocol violation: non-minimal length encoding used");
202                                 return -1;
203                         }
204
205                         maskkey += 2;
206                         payloadstartoffset += 2;
207                 }
208                 else if (len1 == WS_PAYLOAD_LENGTH_MAGIC_HUGE)
209                 {
210                         sock->SetError("WebSocket: Huge frames are not supported");
211                         return -1;
212                 }
213
214                 if (myrecvq.length() < payloadstartoffset + len)
215                         return 0;
216
217                 unsigned int maskkeypos = 0;
218                 const std::string::iterator endit = myrecvq.begin() + payloadstartoffset + len;
219                 for (std::string::const_iterator i = myrecvq.begin() + payloadstartoffset; i != endit; ++i)
220                 {
221                         const unsigned char c = (unsigned char)*i;
222                         appdataout.push_back(c ^ maskkey[maskkeypos++]);
223                         maskkeypos %= 4;
224                 }
225
226                 myrecvq.erase(myrecvq.begin(), endit);
227                 return 1;
228         }
229
230         int HandlePingPongFrame(StreamSocket* sock, bool isping)
231         {
232                 if (lastpingpong + MINPINGPONGDELAY >= ServerInstance->Time())
233                 {
234                         sock->SetError("WebSocket: Ping/pong flood");
235                         return -1;
236                 }
237
238                 lastpingpong = ServerInstance->Time();
239
240                 std::string appdata;
241                 const int result = HandleAppData(sock, appdata, false);
242                 // If it's a pong stop here regardless of the result so we won't generate a reply
243                 if ((result <= 0) || (!isping))
244                         return result;
245
246                 StreamSocket::SendQueue::Element elem = PrepareSendQElem(appdata.length(), OP_PONG);
247                 elem.append(appdata);
248                 GetSendQ().push_back(elem);
249
250                 SocketEngine::ChangeEventMask(sock, FD_ADD_TRIAL_WRITE);
251                 return 1;
252         }
253
254         int HandleWS(StreamSocket* sock, std::string& destrecvq)
255         {
256                 if (GetRecvQ().empty())
257                         return 0;
258
259                 unsigned char opcode = (unsigned char)GetRecvQ().c_str()[0];
260                 switch (opcode & ~WS_FINBIT)
261                 {
262                         case OP_CONTINUATION:
263                         case OP_TEXT:
264                         case OP_BINARY:
265                         {
266                                 std::string appdata;
267                                 const int result = HandleAppData(sock, appdata, true);
268                                 if (result != 1)
269                                         return result;
270
271                                 // Strip out any CR+LF which may have been erroneously sent.
272                                 for (std::string::const_iterator iter = appdata.begin(); iter != appdata.end(); ++iter)
273                                 {
274                                         if (*iter != '\r' && *iter != '\n')
275                                                 destrecvq.push_back(*iter);
276                                 }
277
278                                 // If we are on the final message of this block append a line terminator.
279                                 if (opcode & WS_FINBIT)
280                                         destrecvq.append("\r\n");
281
282                                 return 1;
283                         }
284
285                         case OP_PING:
286                         {
287                                 return HandlePingPongFrame(sock, true);
288                         }
289
290                         case OP_PONG:
291                         {
292                                 // A pong frame may be sent unsolicited, so we have to handle it.
293                                 // It may carry application data which we need to remove from the recvq as well.
294                                 return HandlePingPongFrame(sock, false);
295                         }
296
297                         case OP_CLOSE:
298                         {
299                                 sock->SetError("Connection closed");
300                                 return -1;
301                         }
302
303                         default:
304                         {
305                                 sock->SetError("WebSocket: Invalid opcode");
306                                 return -1;
307                         }
308                 }
309         }
310
311         void FailHandshake(StreamSocket* sock, const char* httpreply, const char* sockerror)
312         {
313                 GetSendQ().push_back(StreamSocket::SendQueue::Element(httpreply));
314                 sock->DoWrite();
315                 sock->SetError(sockerror);
316         }
317
318         int HandleHTTPReq(StreamSocket* sock)
319         {
320                 std::string& recvq = GetRecvQ();
321                 const std::string::size_type reqend = recvq.find("\r\n\r\n");
322                 if (reqend == std::string::npos)
323                         return 0;
324
325                 bool allowedorigin = false;
326                 HTTPHeaderFinder originheader;
327                 if (originheader.Find(recvq, "Origin:", 7, reqend))
328                 {
329                         const std::string origin = originheader.ExtractValue(recvq);
330                         for (OriginList::const_iterator iter = config.allowedorigins.begin(); iter != config.allowedorigins.end(); ++iter)
331                         {
332                                 if (InspIRCd::Match(origin, *iter, ascii_case_insensitive_map))
333                                 {
334                                         allowedorigin = true;
335                                         break;
336                                 }
337                         }
338                 }
339
340                 if (!allowedorigin)
341                 {
342                         FailHandshake(sock, "HTTP/1.1 403 Forbidden\r\nConnection: close\r\n\r\n", "WebSocket: Received HTTP request from a non-whitelisted origin");
343                         return -1;
344                 }
345
346                 if (config.behindproxy && sock->type == StreamSocket::SS_USER)
347                 {
348                         LocalUser* luser = static_cast<UserIOHandler*>(sock)->user;
349                         irc::sockets::sockaddrs realsa(luser->client_sa);
350
351                         HTTPHeaderFinder proxyheader;
352                         if (proxyheader.Find(recvq, "X-Real-IP:", 10, reqend)
353                                 && irc::sockets::aptosa(proxyheader.ExtractValue(recvq), realsa.port(), realsa))
354                         {
355                                 // Nothing to do here.
356                         }
357                         else if (proxyheader.Find(recvq, "X-Forwarded-For:", 16, reqend)
358                                 && irc::sockets::aptosa(proxyheader.ExtractValue(recvq), realsa.port(), realsa))
359                         {
360                                 // Nothing to do here.
361                         }
362
363                         // Give the user their real IP address.
364                         if (realsa != luser->client_sa)
365                                 luser->SetClientIP(realsa);
366                 }
367
368
369                 HTTPHeaderFinder keyheader;
370                 if (!keyheader.Find(recvq, "Sec-WebSocket-Key:", 18, reqend))
371                 {
372                         FailHandshake(sock, "HTTP/1.1 501 Not Implemented\r\nConnection: close\r\n\r\n", "WebSocket: Received HTTP request which is not a websocket upgrade");
373                         return -1;
374                 }
375
376                 if (!*sha1)
377                 {
378                         FailHandshake(sock, "HTTP/1.1 503 Service Unavailable\r\nConnection: close\r\n\r\n", "WebSocket: SHA-1 provider missing");
379                         return -1;
380                 }
381
382                 state = STATE_ESTABLISHED;
383
384                 std::string key = keyheader.ExtractValue(recvq);
385                 key.append(MagicGUID);
386
387                 std::string reply = "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: ";
388                 reply.append(BinToBase64((*sha1)->GenerateRaw(key), NULL, '=')).append("\r\n\r\n");
389                 GetSendQ().push_back(StreamSocket::SendQueue::Element(reply));
390
391                 SocketEngine::ChangeEventMask(sock, FD_ADD_TRIAL_WRITE);
392
393                 recvq.erase(0, reqend + 4);
394
395                 return 1;
396         }
397
398  public:
399         WebSocketHook(IOHookProvider* Prov, StreamSocket* sock, WebSocketConfig& cfg)
400                 : IOHookMiddle(Prov)
401                 , state(STATE_HTTPREQ)
402                 , lastpingpong(0)
403                 , config(cfg)
404         {
405                 sock->AddIOHook(this);
406         }
407
408         int OnStreamSocketWrite(StreamSocket* sock, StreamSocket::SendQueue& uppersendq) CXX11_OVERRIDE
409         {
410                 StreamSocket::SendQueue& mysendq = GetSendQ();
411
412                 // Return 1 to allow sending back an error HTTP response
413                 if (state != STATE_ESTABLISHED)
414                         return (mysendq.empty() ? 0 : 1);
415
416                 std::string message;
417                 for (StreamSocket::SendQueue::const_iterator elem = uppersendq.begin(); elem != uppersendq.end(); ++elem)
418                 {
419                         for (StreamSocket::SendQueue::Element::const_iterator chr = elem->begin(); chr != elem->end(); ++chr)
420                         {
421                                 if (*chr == '\n')
422                                 {
423                                         // We have found an entire message. Send it in its own frame.
424                                         if (config.sendastext)
425                                         {
426                                                 // If we send messages as text then we need to ensure they are valid UTF-8.
427                                                 std::string encoded;
428                                                 utf8::replace_invalid(message.begin(), message.end(), std::back_inserter(encoded));
429
430                                                 mysendq.push_back(PrepareSendQElem(encoded.length(), OP_TEXT));
431                                                 mysendq.push_back(encoded);
432                                         }
433                                         else
434                                         {
435                                                 // Otherwise, send the raw message as a binary frame.
436                                                 mysendq.push_back(PrepareSendQElem(message.length(), OP_BINARY));
437                                                 mysendq.push_back(message);
438                                         }
439                                         message.clear();
440                                 }
441                                 else if (*chr != '\r')
442                                 {
443                                         message.push_back(*chr);
444                                 }
445                         }
446                 }
447
448                 // Empty the upper send queue and push whatever is left back onto it.
449                 uppersendq.clear();
450                 if (!message.empty())
451                 {
452                         uppersendq.push_back(message);
453                         return 0;
454                 }
455
456                 return 1;
457         }
458
459         int OnStreamSocketRead(StreamSocket* sock, std::string& destrecvq) CXX11_OVERRIDE
460         {
461                 if (state == STATE_HTTPREQ)
462                 {
463                         int httpret = HandleHTTPReq(sock);
464                         if (httpret <= 0)
465                                 return httpret;
466                 }
467
468                 int wsret;
469                 do
470                 {
471                         wsret = HandleWS(sock, destrecvq);
472                 }
473                 while ((!GetRecvQ().empty()) && (wsret > 0));
474
475                 return wsret;
476         }
477
478         void OnStreamSocketClose(StreamSocket* sock) CXX11_OVERRIDE
479         {
480         }
481 };
482
483 void WebSocketHookProvider::OnAccept(StreamSocket* sock, irc::sockets::sockaddrs* client, irc::sockets::sockaddrs* server)
484 {
485         new WebSocketHook(this, sock, config);
486 }
487
488 class ModuleWebSocket : public Module
489 {
490         dynamic_reference_nocheck<HashProvider> hash;
491         reference<WebSocketHookProvider> hookprov;
492
493  public:
494         ModuleWebSocket()
495                 : hash(this, "hash/sha1")
496                 , hookprov(new WebSocketHookProvider(this))
497         {
498                 sha1 = &hash;
499         }
500
501         void ReadConfig(ConfigStatus& status) CXX11_OVERRIDE
502         {
503                 ConfigTagList tags = ServerInstance->Config->ConfTags("wsorigin");
504                 if (tags.first == tags.second)
505                         throw ModuleException("You have loaded the websocket module but not configured any allowed origins!");
506
507                 WebSocketConfig config;
508                 for (ConfigIter i = tags.first; i != tags.second; ++i)
509                 {
510                         ConfigTag* tag = i->second;
511
512                         // Ensure that we have the <wsorigin:allow> parameter.
513                         const std::string allow = tag->getString("allow");
514                         if (allow.empty())
515                                 throw ModuleException("<wsorigin:allow> is a mandatory field, at " + tag->getTagLocation());
516
517                         config.allowedorigins.push_back(allow);
518                 }
519
520                 ConfigTag* tag = ServerInstance->Config->ConfValue("websocket");
521                 config.behindproxy = tag->getBool("behindproxy");
522                 config.sendastext = tag->getBool("sendastext", true);
523
524                 // Everything is okay; apply the new config.
525                 hookprov->config = config;
526         }
527
528         void OnCleanup(ExtensionItem::ExtensibleType type, Extensible* item) CXX11_OVERRIDE
529         {
530                 if (type != ExtensionItem::EXT_USER)
531                         return;
532
533                 LocalUser* user = IS_LOCAL(static_cast<User*>(item));
534                 if ((user) && (user->eh.GetModHook(this)))
535                         ServerInstance->Users.QuitUser(user, "WebSocket module unloading");
536         }
537
538         Version GetVersion() CXX11_OVERRIDE
539         {
540                 return Version("Provides RFC 6455 WebSocket support", VF_VENDOR);
541         }
542 };
543
544 MODULE_INIT(ModuleWebSocket)