2 * InspIRCd -- Internet Relay Chat Daemon
4 * Copyright (C) 2016 Attila Molnar <attilamolnar@hush.com>
6 * This file is part of InspIRCd. InspIRCd is free software: you can
7 * redistribute it and/or modify it under the terms of the GNU General Public
8 * License as published by the Free Software Foundation, version 2.
10 * This program is distributed in the hope that it will be useful, but WITHOUT
11 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
12 * FOR A PARTICULAR PURPOSE. See the GNU General Public License for more
15 * You should have received a copy of the GNU General Public License
16 * along with this program. If not, see <http://www.gnu.org/licenses/>.
22 #include "modules/hash.h"
// Fixed GUID string that HandleHTTPReq() appends to the client supplied
// Sec-WebSocket-Key value before hashing it for the Sec-WebSocket-Accept reply.
24 static const char MagicGUID[] = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
// Characters treated as whitespace by HTTPHeaderFinder::Find() when locating
// the beginning and the end of a header value.
25 static const char whitespace[] = " \t\r\n";
// Pointer to the hash provider reference used to compute the handshake digest.
// NOTE(review): the assignment is not visible in this excerpt; presumably it is
// pointed at ModuleWebSocket::hash ("hash/sha1") - confirm.
26 static dynamic_reference_nocheck<HashProvider>* sha1;
// I/O hook provider for WebSocket connections.
// OnAccept() is only declared here; its definition appears after the
// WebSocketHook class.
28 class WebSocketHookProvider : public IOHookProvider
// Registers the provider with the IOHookProvider base under the name
// "websocket" with type IOH_UNKNOWN.
// NOTE(review): the meaning of the trailing `true` argument is defined by
// IOHookProvider, which is not visible here - confirm.
31 WebSocketHookProvider(Module* mod)
32 : IOHookProvider(mod, "websocket", IOHookProvider::IOH_UNKNOWN, true)
// Called for an accepted socket; see the out-of-class definition.
36 void OnAccept(StreamSocket* sock, irc::sockets::sockaddrs* client, irc::sockets::sockaddrs* server) CXX11_OVERRIDE;
// Override for sockets that connect outwards; the body is not visible in this excerpt.
38 void OnConnect(StreamSocket* sock) CXX11_OVERRIDE
// Per-socket I/O hook: handles the HTTP upgrade handshake first and, once the
// connection is established, converts between WebSocket frames and the data
// exchanged with the layer above. Instances are created by
// WebSocketHookProvider::OnAccept().
43 class WebSocketHook : public IOHookMiddle
// Helper that locates the value of a single header inside a raw HTTP request
// and remembers where it is so that it can be extracted later.
45 class HTTPHeaderFinder
// Offset of the first character of the header value within the request.
47 std::string::size_type bpos;
// Length of the header value.
// NOTE(review): the statement assigning `len` is not visible in this excerpt;
// presumably it is derived from `epos - bpos` in Find() - confirm.
48 std::string::size_type len;
// Searches `req` for `header` (which is `headerlen` characters long).
// The header only counts as found when it begins at or before `maxpos`, is not
// at the very start of the request and directly follows a '\n'.
// NOTE(review): std::string::find() returns the first occurrence only, so an
// earlier appearance of the same text that is not at the start of a line makes
// the lookup fail even if a proper header line follows. The comparison is also
// case-sensitive. Confirm whether either matters for supported clients.
51 bool Find(const std::string& req, const char* header, std::string::size_type headerlen, std::string::size_type maxpos)
53 std::string::size_type keybegin = req.find(header);
54 if ((keybegin == std::string::npos) || (keybegin > maxpos) || (keybegin == 0) || (req[keybegin-1] != '\n'))
// Step over the header name so that keybegin refers to what follows it.
57 keybegin += headerlen;
// The value starts at the first non-whitespace character after the name; it
// must exist and must not lie beyond maxpos.
59 bpos = req.find_first_not_of(whitespace, keybegin, sizeof(whitespace)-1);
60 if ((bpos == std::string::npos) || (bpos > maxpos))
// The value ends at the next whitespace character.
63 const std::string::size_type epos = req.find_first_of(whitespace, bpos, sizeof(whitespace)-1);
// Returns a copy of the `len` characters of `req` starting at `bpos`.
// `req` is expected to be the same string that was passed to Find().
69 std::string ExtractValue(const std::string& req) const
71 return std::string(req, bpos, len);
// Opcode value 0; one of the opcodes dispatched on in HandleWS().
// The remaining enumerators (e.g. OP_BINARY, OP_PONG) are not visible in this excerpt.
77 OP_CONTINUATION = 0x00,
// Top bit of the second header byte; HandleAppData() rejects frames where it is not set.
91 static const unsigned char WS_MASKBIT = (1 << 7);
// Top bit of the first header byte; set by FillHeader() and cleared by
// HandleWS() before the opcode is examined.
92 static const unsigned char WS_FINBIT = (1 << 7);
// Length byte value announcing that a 2 byte length follows.
93 static const unsigned char WS_PAYLOAD_LENGTH_MAGIC_LARGE = 126;
// Length byte value announcing that an 8 byte length follows.
94 static const unsigned char WS_PAYLOAD_LENGTH_MAGIC_HUGE = 127;
// Largest payload length that fits directly into the length byte.
95 static const size_t WS_MAX_PAYLOAD_LENGTH_SMALL = 125;
// Largest payload length that fits into the 2 byte length form.
96 static const size_t WS_MAX_PAYLOAD_LENGTH_LARGE = 65535;
// Size of the buffer handed to FillHeader(): first byte, length byte and up to
// sizeof(uint64_t) bytes of extended length.
97 static const size_t MAXHEADERSIZE = sizeof(uint64_t) + 2;
99 // Clients sending ping or pong frames faster than this are killed
100 static const time_t MINPINGPONGDELAY = 10;
// Writes the header of an outgoing frame into `outbuf`.
// `sendlength` is the payload length to encode and `opcode` the frame opcode.
// `outbuf` must have room for the longest form (the caller passes a buffer of
// MAXHEADERSIZE bytes). The caller uses the return value as the number of
// header bytes written.
// `pos` is the write position; its declaration is not visible in this excerpt.
105 static size_t FillHeader(unsigned char* outbuf, size_t sendlength, OpCode opcode)
// First byte: the opcode with the FIN bit set.
108 outbuf[pos++] = WS_FINBIT | opcode;
// Lengths up to 125 are stored directly in the second byte. The mask bit is
// never set by this function.
110 if (sendlength <= WS_MAX_PAYLOAD_LENGTH_SMALL)
112 outbuf[pos++] = sendlength;
// Lengths up to 65535: marker byte 126 followed by the length as 2 bytes, most
// significant byte first.
114 else if (sendlength <= WS_MAX_PAYLOAD_LENGTH_LARGE)
116 outbuf[pos++] = WS_PAYLOAD_LENGTH_MAGIC_LARGE;
117 outbuf[pos++] = (sendlength >> 8) & 0xff;
118 outbuf[pos++] = sendlength & 0xff;
// Anything longer: marker byte 127 followed by the length as 8 bytes, most
// significant byte first (i counts down from 7 to 0).
122 outbuf[pos++] = WS_PAYLOAD_LENGTH_MAGIC_HUGE;
123 const uint64_t len = sendlength;
124 for (int i = sizeof(uint64_t)-1; i >= 0; i--)
125 outbuf[pos++] = ((len >> i*8) & 0xff);
// Builds a send queue element containing only the frame header for a payload
// of `size` bytes with the given `opcode`. The payload itself is added by the
// callers.
131 static StreamSocket::SendQueue::Element PrepareSendQElem(size_t size, OpCode opcode)
// Temporary buffer large enough for the longest header FillHeader() can produce.
133 unsigned char header[MAXHEADERSIZE];
134 const size_t n = FillHeader(header, size, opcode);
// Only the first n bytes of the buffer are copied into the element.
136 return StreamSocket::SendQueue::Element(reinterpret_cast<const char*>(header), n);
// Parses one client frame from the front of this hook's receive queue, appends
// the unmasked payload to `appdataout` and removes the frame from the queue.
// `allowlarge` states whether the 2 byte length form is acceptable.
// Protocol violations are reported through sock->SetError().
// NOTE(review): the return statements are not visible in this excerpt. Callers
// treat a value > 0 as "frame consumed" and a value <= 0 as "stop" - confirm
// the exact meaning of 0 and negative values.
139 int HandleAppData(StreamSocket* sock, std::string& appdataout, bool allowlarge)
141 std::string& myrecvq = GetRecvQ();
142 // Need 1 byte opcode, minimum 1 byte len, 4 bytes masking key
143 if (myrecvq.length() < 6)
// Const alias of the receive queue, used for read-only element access below.
146 const std::string& cmyrecvq = myrecvq;
// Second header byte: the mask bit plus the 7 bit length field.
147 unsigned char len1 = (unsigned char)cmyrecvq[1];
148 if (!(len1 & WS_MASKBIT))
150 sock->SetError("WebSocket protocol violation: unmasked client frame");
// NOTE(review): `len` is initialised from len1 while no line visible here
// clears the mask bit from len1 first. With the bit still set the comparisons
// against 126/127 below could never match - confirm that it is cleared on a
// line missing from this excerpt.
156 // Assume the length is a single byte, if not, update values later
157 unsigned int len = len1;
158 unsigned int payloadstartoffset = 6;
// With a single length byte the 4 byte masking key starts at offset 2.
159 const unsigned char* maskkey = reinterpret_cast<const unsigned char*>(&cmyrecvq[2]);
161 if (len1 == WS_PAYLOAD_LENGTH_MAGIC_LARGE)
163 // allowlarge is false for control frames according to the RFC meaning large pings, etc. are not allowed
166 sock->SetError("WebSocket protocol violation: large control frame");
170 // Large frame, has 2 bytes len after the magic byte indicating the length
171 // Need 1 byte opcode, 3 bytes len, 4 bytes masking key
172 if (myrecvq.length() < 8)
// The 16 bit length is stored most significant byte first.
175 unsigned char len2 = (unsigned char)cmyrecvq[2];
176 unsigned char len3 = (unsigned char)cmyrecvq[3];
177 len = (len2 << 8) | len3;
// A length that would have fit into the single byte form must not use the
// 2 byte form.
179 if (len <= WS_MAX_PAYLOAD_LENGTH_SMALL)
181 sock->SetError("WebSocket protocol violation: non-minimal length encoding used");
// NOTE(review): the payload offset moves by the 2 extra length bytes. In this
// form the masking key starts at offset 4, but no adjustment of `maskkey` is
// visible in this excerpt - confirm that it is advanced by 2 as well.
186 payloadstartoffset += 2;
188 else if (len1 == WS_PAYLOAD_LENGTH_MAGIC_HUGE)
190 sock->SetError("WebSocket: Huge frames are not supported");
// The complete payload must already be buffered before it can be unmasked.
194 if (myrecvq.length() < payloadstartoffset + len)
// Unmask by XOR-ing every payload byte with the next byte of the masking key.
// NOTE(review): maskkeypos is only incremented in the lines visible here; it
// must wrap around after 4 bytes to stay within the key - confirm.
197 unsigned int maskkeypos = 0;
198 const std::string::iterator endit = myrecvq.begin() + payloadstartoffset + len;
199 for (std::string::const_iterator i = myrecvq.begin() + payloadstartoffset; i != endit; ++i)
201 const unsigned char c = (unsigned char)*i;
202 appdataout.push_back(c ^ maskkey[maskkeypos++]);
// Remove the header and the payload of the processed frame from the queue.
206 myrecvq.erase(myrecvq.begin(), endit);
// Handles a ping (`isping` true) or pong (`isping` false) frame at the front
// of the receive queue. A ping is answered with a pong carrying the same
// application data; a pong is only consumed.
210 int HandlePingPongFrame(StreamSocket* sock, bool isping)
// Rate limit: a ping/pong arriving within MINPINGPONGDELAY seconds of the
// previous one is treated as a flood.
212 if (lastpingpong + MINPINGPONGDELAY >= ServerInstance->Time())
214 sock->SetError("WebSocket: Ping/pong flood");
218 lastpingpong = ServerInstance->Time();
// Extract the frame's payload into `appdata` (declared on a line not visible
// in this excerpt). Large frames are not allowed here.
221 const int result = HandleAppData(sock, appdata, false);
222 // If it's a pong stop here regardless of the result so we won't generate a reply
223 if ((result <= 0) || (!isping))
// Queue the pong reply: header first, then the echoed application data.
226 StreamSocket::SendQueue::Element elem = PrepareSendQElem(appdata.length(), OP_PONG);
227 elem.append(appdata);
228 GetSendQ().push_back(elem);
// Request a write attempt so that the queued reply gets sent.
230 SocketEngine::ChangeEventMask(sock, FD_ADD_TRIAL_WRITE);
// Processes the frame at the front of the receive queue according to its
// opcode. Payload of data frames is appended to `destrecvq`.
// The switch statement and most case labels are not visible in this excerpt.
234 int HandleWS(StreamSocket* sock, std::string& destrecvq)
236 if (GetRecvQ().empty())
// The first byte of the frame carries the FIN bit and the opcode. The FIN bit
// is cleared so that only the remaining bits are compared against the opcodes.
239 unsigned char opcode = (unsigned char)GetRecvQ().c_str()[0];
240 opcode &= ~WS_FINBIT;
// Data frames: hand the payload to the upper layer; the 2 byte length form is allowed.
244 case OP_CONTINUATION:
248 return HandleAppData(sock, destrecvq, true);
// Ping: consume the frame and reply to it.
253 return HandlePingPongFrame(sock, true);
258 // A pong frame may be sent unsolicited, so we have to handle it.
259 // It may carry application data which we need to remove from the recvq as well.
260 return HandlePingPongFrame(sock, false);
// NOTE(review): presumably the close opcode branch - confirm against the full file.
265 sock->SetError("Connection closed");
// NOTE(review): presumably the default branch for unrecognised opcodes - confirm.
271 sock->SetError("WebSocket: Invalid opcode");
// Aborts the handshake: queues `httpreply` as the response to the client and
// sets `sockerror` as the socket's error.
277 void FailHandshake(StreamSocket* sock, const char* httpreply, const char* sockerror)
279 GetSendQ().push_back(StreamSocket::SendQueue::Element(httpreply));
281 sock->SetError(sockerror);
// Handles the HTTP upgrade request that starts a WebSocket connection.
// Once the complete request is buffered it either replies with an error via
// FailHandshake() or switches to STATE_ESTABLISHED and queues the
// "101 Switching Protocols" response.
// NOTE(review): the return statements are not visible in this excerpt.
284 int HandleHTTPReq(StreamSocket* sock)
286 std::string& recvq = GetRecvQ();
// The request is complete once the empty line terminating the headers has arrived.
287 const std::string::size_type reqend = recvq.find("\r\n\r\n");
288 if (reqend == std::string::npos)
// Only headers located before the end of the request are considered; 18 is the
// length of "Sec-WebSocket-Key:".
291 HTTPHeaderFinder keyheader;
292 if (!keyheader.Find(recvq, "Sec-WebSocket-Key:", 18, reqend))
294 FailHandshake(sock, "HTTP/1.1 501 Not Implemented\r\nConnection: close\r\n\r\n", "WebSocket: Received HTTP request which is not a websocket upgrade");
// The condition guarding this failure is not visible in this excerpt; per the
// error text it covers the SHA-1 provider being unavailable.
300 FailHandshake(sock, "HTTP/1.1 503 Service Unavailable\r\nConnection: close\r\n\r\n", "WebSocket: SHA-1 provider missing");
304 state = STATE_ESTABLISHED;
// The accept token is the base64 encoded raw SHA-1 digest of the client's key
// with MagicGUID appended.
306 std::string key = keyheader.ExtractValue(recvq);
307 key.append(MagicGUID);
309 std::string reply = "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: ";
310 reply.append(BinToBase64((*sha1)->GenerateRaw(key), NULL, '=')).append("\r\n\r\n");
311 GetSendQ().push_back(StreamSocket::SendQueue::Element(reply));
// Request a write attempt so that the queued response gets sent.
313 SocketEngine::ChangeEventMask(sock, FD_ADD_TRIAL_WRITE);
// Drop the request including its terminating "\r\n\r\n"; anything after it stays queued.
315 recvq.erase(0, reqend + 4);
// Creates the hook for `sock`. The hook starts out waiting for the HTTP
// upgrade request and attaches itself to the socket.
// Part of the initialiser list is not visible in this excerpt.
321 WebSocketHook(IOHookProvider* Prov, StreamSocket* sock)
323 , state(STATE_HTTPREQ)
326 sock->AddIOHook(this);
// Write hook: wraps the data waiting in `uppersendq` into a single frame and
// moves it into this hook's own send queue.
329 int OnStreamSocketWrite(StreamSocket* sock, StreamSocket::SendQueue& uppersendq) CXX11_OVERRIDE
331 StreamSocket::SendQueue& mysendq = GetSendQ();
333 // Return 1 to allow sending back an error HTTP response
334 if (state != STATE_ESTABLISHED)
335 return (mysendq.empty() ? 0 : 1);
// Everything currently queued by the upper layer becomes the payload of one
// OP_BINARY frame: the header element is queued first, then all upper elements
// are moved behind it.
337 if (!uppersendq.empty())
339 StreamSocket::SendQueue::Element elem = PrepareSendQElem(uppersendq.bytes(), OP_BINARY);
340 mysendq.push_back(elem);
341 mysendq.moveall(uppersendq);
// Read hook: completes the HTTP handshake first, then decodes buffered frames
// and appends their payload to `destrecvq`.
// Several lines of this method are not visible in this excerpt, including the
// declaration of `wsret`, the start of the loop and the return statements.
347 int OnStreamSocketRead(StreamSocket* sock, std::string& destrecvq) CXX11_OVERRIDE
349 if (state == STATE_HTTPREQ)
351 int httpret = HandleHTTPReq(sock);
// Keep processing frames for as long as data remains buffered and the previous
// frame was handled with a positive result.
359 wsret = HandleWS(sock, destrecvq);
361 while ((!GetRecvQ().empty()) && (wsret > 0));
// Close hook; the body is not visible in this excerpt.
366 void OnStreamSocketClose(StreamSocket* sock) CXX11_OVERRIDE
// Creates a WebSocketHook for every accepted socket. The returned pointer is
// intentionally not stored: the hook's constructor attaches it to `sock` via
// AddIOHook().
// NOTE(review): where the hook is eventually deleted is not visible here - confirm.
371 void WebSocketHookProvider::OnAccept(StreamSocket* sock, irc::sockets::sockaddrs* client, irc::sockets::sockaddrs* server)
373 new WebSocketHook(this, sock);
// The module object: owns the SHA-1 provider reference and the WebSocket hook provider.
376 class ModuleWebSocket : public Module
// Reference to the "hash/sha1" provider.
378 dynamic_reference_nocheck<HashProvider> hash;
// The hook provider instance created in the constructor.
379 reference<WebSocketHookProvider> hookprov;
// Constructor initialiser list; the constructor's declaration line and body
// are not visible in this excerpt.
383 : hash(this, "hash/sha1")
384 , hookprov(new WebSocketHookProvider(this))
// Cleanup callback: only user items are of interest. A local user whose socket
// has a hook belonging to this module is disconnected with an explanatory
// quit message.
389 void OnCleanup(ExtensionItem::ExtensibleType type, Extensible* item) CXX11_OVERRIDE
391 if (type != ExtensionItem::EXT_USER)
// IS_LOCAL() yields a null pointer for users that are not local, which the
// check below handles.
394 LocalUser* user = IS_LOCAL(static_cast<User*>(item));
395 if ((user) && (user->eh.GetModHook(this)))
396 ServerInstance->Users.QuitUser(user, "WebSocket module unloading");
// Returns the module description and flags.
399 Version GetVersion() CXX11_OVERRIDE
401 return Version("Provides RFC 6455 WebSocket support", VF_VENDOR);
// Module entry point macro invoked with ModuleWebSocket as the module class.
405 MODULE_INIT(ModuleWebSocket)