]> git.netwichtig.de Git - user/henk/code/inspircd.git/blob - src/modules/m_websocket.cpp
12102d2151bd818f392ac82162dbb32dfa117c67
[user/henk/code/inspircd.git] / src / modules / m_websocket.cpp
1 /*
2  * InspIRCd -- Internet Relay Chat Daemon
3  *
4  *   Copyright (C) 2016 Attila Molnar <attilamolnar@hush.com>
5  *
6  * This file is part of InspIRCd.  InspIRCd is free software: you can
7  * redistribute it and/or modify it under the terms of the GNU General Public
8  * License as published by the Free Software Foundation, version 2.
9  *
10  * This program is distributed in the hope that it will be useful, but WITHOUT
11  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
12  * FOR A PARTICULAR PURPOSE.  See the GNU General Public License for more
13  * details.
14  *
15  * You should have received a copy of the GNU General Public License
16  * along with this program.  If not, see <http://www.gnu.org/licenses/>.
17  */
18
19
20 #include "inspircd.h"
21 #include "iohook.h"
22 #include "modules/hash.h"
23
24 static const char MagicGUID[] = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
25 static const char whitespace[] = " \t\r\n";
26 static dynamic_reference_nocheck<HashProvider>* sha1;
27
28 class WebSocketHookProvider : public IOHookProvider
29 {
30  public:
31         WebSocketHookProvider(Module* mod)
32                 : IOHookProvider(mod, "websocket", IOHookProvider::IOH_UNKNOWN, true)
33         {
34         }
35
36         void OnAccept(StreamSocket* sock, irc::sockets::sockaddrs* client, irc::sockets::sockaddrs* server) CXX11_OVERRIDE;
37
38         void OnConnect(StreamSocket* sock) CXX11_OVERRIDE
39         {
40         }
41 };
42
43 class WebSocketHook : public IOHookMiddle
44 {
45         class HTTPHeaderFinder
46         {
47                 std::string::size_type bpos;
48                 std::string::size_type len;
49
50          public:
51                 bool Find(const std::string& req, const char* header, std::string::size_type headerlen, std::string::size_type maxpos)
52                 {
53                         std::string::size_type keybegin = req.find(header);
54                         if ((keybegin == std::string::npos) || (keybegin > maxpos) || (keybegin == 0) || (req[keybegin-1] != '\n'))
55                                 return false;
56
57                         keybegin += headerlen;
58
59                         bpos = req.find_first_not_of(whitespace, keybegin, sizeof(whitespace)-1);
60                         if ((bpos == std::string::npos) || (bpos > maxpos))
61                                 return false;
62
63                         const std::string::size_type epos = req.find_first_of(whitespace, bpos, sizeof(whitespace)-1);
64                         len = epos - bpos;
65
66                         return true;
67                 }
68
69                 std::string ExtractValue(const std::string& req) const
70                 {
71                         return std::string(req, bpos, len);
72                 }
73         };
74
75         enum OpCode
76         {
77                 OP_CONTINUATION = 0x00,
78                 OP_TEXT = 0x01,
79                 OP_BINARY = 0x02,
80                 OP_CLOSE = 0x08,
81                 OP_PING = 0x09,
82                 OP_PONG = 0x0a
83         };
84
85         enum State
86         {
87                 STATE_HTTPREQ,
88                 STATE_ESTABLISHED
89         };
90
91         static const unsigned char WS_MASKBIT = (1 << 7);
92         static const unsigned char WS_FINBIT = (1 << 7);
93         static const unsigned char WS_PAYLOAD_LENGTH_MAGIC_LARGE = 126;
94         static const unsigned char WS_PAYLOAD_LENGTH_MAGIC_HUGE = 127;
95         static const size_t WS_MAX_PAYLOAD_LENGTH_SMALL = 125;
96         static const size_t WS_MAX_PAYLOAD_LENGTH_LARGE = 65535;
97         static const size_t MAXHEADERSIZE = sizeof(uint64_t) + 2;
98
99         // Clients sending ping or pong frames faster than this are killed
100         static const time_t MINPINGPONGDELAY = 10;
101
102         State state;
103         time_t lastpingpong;
104
105         static size_t FillHeader(unsigned char* outbuf, size_t sendlength, OpCode opcode)
106         {
107                 size_t pos = 0;
108                 outbuf[pos++] = WS_FINBIT | opcode;
109
110                 if (sendlength <= WS_MAX_PAYLOAD_LENGTH_SMALL)
111                 {
112                         outbuf[pos++] = sendlength;
113                 }
114                 else if (sendlength <= WS_MAX_PAYLOAD_LENGTH_LARGE)
115                 {
116                         outbuf[pos++] = WS_PAYLOAD_LENGTH_MAGIC_LARGE;
117                         outbuf[pos++] = (sendlength >> 8) & 0xff;
118                         outbuf[pos++] = sendlength & 0xff;
119                 }
120                 else
121                 {
122                         outbuf[pos++] = WS_PAYLOAD_LENGTH_MAGIC_HUGE;
123                         const uint64_t len = sendlength;
124                         for (int i = sizeof(uint64_t)-1; i >= 0; i--)
125                                 outbuf[pos++] = ((len >> i*8) & 0xff);
126                 }
127
128                 return pos;
129         }
130
131         static StreamSocket::SendQueue::Element PrepareSendQElem(size_t size, OpCode opcode)
132         {
133                 unsigned char header[MAXHEADERSIZE];
134                 const size_t n = FillHeader(header, size, opcode);
135
136                 return StreamSocket::SendQueue::Element(reinterpret_cast<const char*>(header), n);
137         }
138
139         int HandleAppData(StreamSocket* sock, std::string& appdataout, bool allowlarge)
140         {
141                 std::string& myrecvq = GetRecvQ();
142                 // Need 1 byte opcode, minimum 1 byte len, 4 bytes masking key
143                 if (myrecvq.length() < 6)
144                         return 0;
145
146                 const std::string& cmyrecvq = myrecvq;
147                 unsigned char len1 = (unsigned char)cmyrecvq[1];
148                 if (!(len1 & WS_MASKBIT))
149                 {
150                         sock->SetError("WebSocket protocol violation: unmasked client frame");
151                         return -1;
152                 }
153
154                 len1 &= ~WS_MASKBIT;
155
156                 // Assume the length is a single byte, if not, update values later
157                 unsigned int len = len1;
158                 unsigned int payloadstartoffset = 6;
159                 const unsigned char* maskkey = reinterpret_cast<const unsigned char*>(&cmyrecvq[2]);
160
161                 if (len1 == WS_PAYLOAD_LENGTH_MAGIC_LARGE)
162                 {
163                         // allowlarge is false for control frames according to the RFC meaning large pings, etc. are not allowed
164                         if (!allowlarge)
165                         {
166                                 sock->SetError("WebSocket protocol violation: large control frame");
167                                 return -1;
168                         }
169
170                         // Large frame, has 2 bytes len after the magic byte indicating the length
171                         // Need 1 byte opcode, 3 bytes len, 4 bytes masking key
172                         if (myrecvq.length() < 8)
173                                 return 0;
174
175                         unsigned char len2 = (unsigned char)cmyrecvq[2];
176                         unsigned char len3 = (unsigned char)cmyrecvq[3];
177                         len = (len2 << 8) | len3;
178
179                         if (len <= WS_MAX_PAYLOAD_LENGTH_SMALL)
180                         {
181                                 sock->SetError("WebSocket protocol violation: non-minimal length encoding used");
182                                 return -1;
183                         }
184
185                         maskkey += 2;
186                         payloadstartoffset += 2;
187                 }
188                 else if (len1 == WS_PAYLOAD_LENGTH_MAGIC_HUGE)
189                 {
190                         sock->SetError("WebSocket: Huge frames are not supported");
191                         return -1;
192                 }
193
194                 if (myrecvq.length() < payloadstartoffset + len)
195                         return 0;
196
197                 unsigned int maskkeypos = 0;
198                 const std::string::iterator endit = myrecvq.begin() + payloadstartoffset + len;
199                 for (std::string::const_iterator i = myrecvq.begin() + payloadstartoffset; i != endit; ++i)
200                 {
201                         const unsigned char c = (unsigned char)*i;
202                         appdataout.push_back(c ^ maskkey[maskkeypos++]);
203                         maskkeypos %= 4;
204                 }
205
206                 myrecvq.erase(myrecvq.begin(), endit);
207                 return 1;
208         }
209
210         int HandlePingPongFrame(StreamSocket* sock, bool isping)
211         {
212                 if (lastpingpong + MINPINGPONGDELAY >= ServerInstance->Time())
213                 {
214                         sock->SetError("WebSocket: Ping/pong flood");
215                         return -1;
216                 }
217
218                 lastpingpong = ServerInstance->Time();
219
220                 std::string appdata;
221                 const int result = HandleAppData(sock, appdata, false);
222                 // If it's a pong stop here regardless of the result so we won't generate a reply
223                 if ((result <= 0) || (!isping))
224                         return result;
225
226                 StreamSocket::SendQueue::Element elem = PrepareSendQElem(appdata.length(), OP_PONG);
227                 elem.append(appdata);
228                 GetSendQ().push_back(elem);
229
230                 SocketEngine::ChangeEventMask(sock, FD_ADD_TRIAL_WRITE);
231                 return 1;
232         }
233
234         int HandleWS(StreamSocket* sock, std::string& destrecvq)
235         {
236                 if (GetRecvQ().empty())
237                         return 0;
238
239                 unsigned char opcode = (unsigned char)GetRecvQ().c_str()[0];
240                 opcode &= ~WS_FINBIT;
241
242                 switch (opcode)
243                 {
244                         case OP_CONTINUATION:
245                         case OP_TEXT:
246                         case OP_BINARY:
247                         {
248                                 return HandleAppData(sock, destrecvq, true);
249                         }
250
251                         case OP_PING:
252                         {
253                                 return HandlePingPongFrame(sock, true);
254                         }
255
256                         case OP_PONG:
257                         {
258                                 // A pong frame may be sent unsolicited, so we have to handle it.
259                                 // It may carry application data which we need to remove from the recvq as well.
260                                 return HandlePingPongFrame(sock, false);
261                         }
262
263                         case OP_CLOSE:
264                         {
265                                 sock->SetError("Connection closed");
266                                 return -1;
267                         }
268
269                         default:
270                         {
271                                 sock->SetError("WebSocket: Invalid opcode");
272                                 return -1;
273                         }
274                 }
275         }
276
277         void FailHandshake(StreamSocket* sock, const char* httpreply, const char* sockerror)
278         {
279                 GetSendQ().push_back(StreamSocket::SendQueue::Element(httpreply));
280                 sock->DoWrite();
281                 sock->SetError(sockerror);
282         }
283
284         int HandleHTTPReq(StreamSocket* sock)
285         {
286                 std::string& recvq = GetRecvQ();
287                 const std::string::size_type reqend = recvq.find("\r\n\r\n");
288                 if (reqend == std::string::npos)
289                         return 0;
290
291                 HTTPHeaderFinder keyheader;
292                 if (!keyheader.Find(recvq, "Sec-WebSocket-Key:", 18, reqend))
293                 {
294                         FailHandshake(sock, "HTTP/1.1 501 Not Implemented\r\nConnection: close\r\n\r\n", "WebSocket: Received HTTP request which is not a websocket upgrade");
295                         return -1;
296                 }
297
298                 if (!*sha1)
299                 {
300                         FailHandshake(sock, "HTTP/1.1 503 Service Unavailable\r\nConnection: close\r\n\r\n", "WebSocket: SHA-1 provider missing");
301                         return -1;
302                 }
303
304                 state = STATE_ESTABLISHED;
305
306                 std::string key = keyheader.ExtractValue(recvq);
307                 key.append(MagicGUID);
308
309                 std::string reply = "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: ";
310                 reply.append(BinToBase64((*sha1)->GenerateRaw(key), NULL, '=')).append("\r\n\r\n");
311                 GetSendQ().push_back(StreamSocket::SendQueue::Element(reply));
312
313                 SocketEngine::ChangeEventMask(sock, FD_ADD_TRIAL_WRITE);
314
315                 recvq.erase(0, reqend + 4);
316
317                 return 1;
318         }
319
320  public:
321         WebSocketHook(IOHookProvider* Prov, StreamSocket* sock)
322                 : IOHookMiddle(Prov)
323                 , state(STATE_HTTPREQ)
324                 , lastpingpong(0)
325         {
326                 sock->AddIOHook(this);
327         }
328
329         int OnStreamSocketWrite(StreamSocket* sock, StreamSocket::SendQueue& uppersendq) CXX11_OVERRIDE
330         {
331                 StreamSocket::SendQueue& mysendq = GetSendQ();
332
333                 // Return 1 to allow sending back an error HTTP response
334                 if (state != STATE_ESTABLISHED)
335                         return (mysendq.empty() ? 0 : 1);
336
337                 if (!uppersendq.empty())
338                 {
339                         StreamSocket::SendQueue::Element elem = PrepareSendQElem(uppersendq.bytes(), OP_BINARY);
340                         mysendq.push_back(elem);
341                         mysendq.moveall(uppersendq);
342                 }
343
344                 return 1;
345         }
346
347         int OnStreamSocketRead(StreamSocket* sock, std::string& destrecvq) CXX11_OVERRIDE
348         {
349                 if (state == STATE_HTTPREQ)
350                 {
351                         int httpret = HandleHTTPReq(sock);
352                         if (httpret <= 0)
353                                 return httpret;
354                 }
355
356                 int wsret;
357                 do
358                 {
359                         wsret = HandleWS(sock, destrecvq);
360                 }
361                 while ((!GetRecvQ().empty()) && (wsret > 0));
362
363                 return wsret;
364         }
365
366         void OnStreamSocketClose(StreamSocket* sock) CXX11_OVERRIDE
367         {
368         }
369 };
370
371 void WebSocketHookProvider::OnAccept(StreamSocket* sock, irc::sockets::sockaddrs* client, irc::sockets::sockaddrs* server)
372 {
373         new WebSocketHook(this, sock);
374 }
375
376 class ModuleWebSocket : public Module
377 {
378         dynamic_reference_nocheck<HashProvider> hash;
379         reference<WebSocketHookProvider> hookprov;
380
381  public:
382         ModuleWebSocket()
383                 : hash(this, "hash/sha1")
384                 , hookprov(new WebSocketHookProvider(this))
385         {
386                 sha1 = &hash;
387         }
388
389         void OnCleanup(ExtensionItem::ExtensibleType type, Extensible* item) CXX11_OVERRIDE
390         {
391                 if (type != ExtensionItem::EXT_USER)
392                         return;
393
394                 LocalUser* user = IS_LOCAL(static_cast<User*>(item));
395                 if ((user) && (user->eh.GetModHook(this)))
396                         ServerInstance->Users.QuitUser(user, "WebSocket module unloading");
397         }
398
399         Version GetVersion() CXX11_OVERRIDE
400         {
401                 return Version("Provides RFC 6455 WebSocket support", VF_VENDOR);
402         }
403 };
404
405 MODULE_INIT(ModuleWebSocket)