/* * InspIRCd -- Internet Relay Chat Daemon * * Copyright (C) 2016 Attila Molnar * * This file is part of InspIRCd. InspIRCd is free software: you can * redistribute it and/or modify it under the terms of the GNU General Public * License as published by the Free Software Foundation, version 2. * * This program is distributed in the hope that it will be useful, but WITHOUT * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS * FOR A PARTICULAR PURPOSE. See the GNU General Public License for more * details. * * You should have received a copy of the GNU General Public License * along with this program. If not, see . */ #include "inspircd.h" #include "iohook.h" #include "modules/hash.h" typedef std::vector OriginList; static const char MagicGUID[] = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; static const char whitespace[] = " \t\r\n"; static dynamic_reference_nocheck* sha1; class WebSocketHookProvider : public IOHookProvider { public: OriginList allowedorigins; WebSocketHookProvider(Module* mod) : IOHookProvider(mod, "websocket", IOHookProvider::IOH_UNKNOWN, true) { } void OnAccept(StreamSocket* sock, irc::sockets::sockaddrs* client, irc::sockets::sockaddrs* server) CXX11_OVERRIDE; void OnConnect(StreamSocket* sock) CXX11_OVERRIDE { } }; class WebSocketHook : public IOHookMiddle { class HTTPHeaderFinder { std::string::size_type bpos; std::string::size_type len; public: bool Find(const std::string& req, const char* header, std::string::size_type headerlen, std::string::size_type maxpos) { std::string::size_type keybegin = req.find(header); if ((keybegin == std::string::npos) || (keybegin > maxpos) || (keybegin == 0) || (req[keybegin-1] != '\n')) return false; keybegin += headerlen; bpos = req.find_first_not_of(whitespace, keybegin, sizeof(whitespace)-1); if ((bpos == std::string::npos) || (bpos > maxpos)) return false; const std::string::size_type epos = req.find_first_of(whitespace, bpos, sizeof(whitespace)-1); len = epos - bpos; return true; } std::string ExtractValue(const std::string& req) const { return std::string(req, bpos, len); } }; enum OpCode { OP_CONTINUATION = 0x00, OP_TEXT = 0x01, OP_BINARY = 0x02, OP_CLOSE = 0x08, OP_PING = 0x09, OP_PONG = 0x0a }; enum State { STATE_HTTPREQ, STATE_ESTABLISHED }; static const unsigned char WS_MASKBIT = (1 << 7); static const unsigned char WS_FINBIT = (1 << 7); static const unsigned char WS_PAYLOAD_LENGTH_MAGIC_LARGE = 126; static const unsigned char WS_PAYLOAD_LENGTH_MAGIC_HUGE = 127; static const size_t WS_MAX_PAYLOAD_LENGTH_SMALL = 125; static const size_t WS_MAX_PAYLOAD_LENGTH_LARGE = 65535; static const size_t MAXHEADERSIZE = sizeof(uint64_t) + 2; // Clients sending ping or pong frames faster than this are killed static const time_t MINPINGPONGDELAY = 10; State state; time_t lastpingpong; OriginList& allowedorigins; static size_t FillHeader(unsigned char* outbuf, size_t sendlength, OpCode opcode) { size_t pos = 0; outbuf[pos++] = WS_FINBIT | opcode; if (sendlength <= WS_MAX_PAYLOAD_LENGTH_SMALL) { outbuf[pos++] = sendlength; } else if (sendlength <= WS_MAX_PAYLOAD_LENGTH_LARGE) { outbuf[pos++] = WS_PAYLOAD_LENGTH_MAGIC_LARGE; outbuf[pos++] = (sendlength >> 8) & 0xff; outbuf[pos++] = sendlength & 0xff; } else { outbuf[pos++] = WS_PAYLOAD_LENGTH_MAGIC_HUGE; const uint64_t len = sendlength; for (int i = sizeof(uint64_t)-1; i >= 0; i--) outbuf[pos++] = ((len >> i*8) & 0xff); } return pos; } static StreamSocket::SendQueue::Element PrepareSendQElem(size_t size, OpCode opcode) { unsigned char header[MAXHEADERSIZE]; const size_t n = FillHeader(header, size, opcode); return StreamSocket::SendQueue::Element(reinterpret_cast(header), n); } int HandleAppData(StreamSocket* sock, std::string& appdataout, bool allowlarge) { std::string& myrecvq = GetRecvQ(); // Need 1 byte opcode, minimum 1 byte len, 4 bytes masking key if (myrecvq.length() < 6) return 0; const std::string& cmyrecvq = myrecvq; unsigned char len1 = (unsigned char)cmyrecvq[1]; if (!(len1 & WS_MASKBIT)) { sock->SetError("WebSocket protocol violation: unmasked client frame"); return -1; } len1 &= ~WS_MASKBIT; // Assume the length is a single byte, if not, update values later unsigned int len = len1; unsigned int payloadstartoffset = 6; const unsigned char* maskkey = reinterpret_cast(&cmyrecvq[2]); if (len1 == WS_PAYLOAD_LENGTH_MAGIC_LARGE) { // allowlarge is false for control frames according to the RFC meaning large pings, etc. are not allowed if (!allowlarge) { sock->SetError("WebSocket protocol violation: large control frame"); return -1; } // Large frame, has 2 bytes len after the magic byte indicating the length // Need 1 byte opcode, 3 bytes len, 4 bytes masking key if (myrecvq.length() < 8) return 0; unsigned char len2 = (unsigned char)cmyrecvq[2]; unsigned char len3 = (unsigned char)cmyrecvq[3]; len = (len2 << 8) | len3; if (len <= WS_MAX_PAYLOAD_LENGTH_SMALL) { sock->SetError("WebSocket protocol violation: non-minimal length encoding used"); return -1; } maskkey += 2; payloadstartoffset += 2; } else if (len1 == WS_PAYLOAD_LENGTH_MAGIC_HUGE) { sock->SetError("WebSocket: Huge frames are not supported"); return -1; } if (myrecvq.length() < payloadstartoffset + len) return 0; unsigned int maskkeypos = 0; const std::string::iterator endit = myrecvq.begin() + payloadstartoffset + len; for (std::string::const_iterator i = myrecvq.begin() + payloadstartoffset; i != endit; ++i) { const unsigned char c = (unsigned char)*i; appdataout.push_back(c ^ maskkey[maskkeypos++]); maskkeypos %= 4; } myrecvq.erase(myrecvq.begin(), endit); return 1; } int HandlePingPongFrame(StreamSocket* sock, bool isping) { if (lastpingpong + MINPINGPONGDELAY >= ServerInstance->Time()) { sock->SetError("WebSocket: Ping/pong flood"); return -1; } lastpingpong = ServerInstance->Time(); std::string appdata; const int result = HandleAppData(sock, appdata, false); // If it's a pong stop here regardless of the result so we won't generate a reply if ((result <= 0) || (!isping)) return result; StreamSocket::SendQueue::Element elem = PrepareSendQElem(appdata.length(), OP_PONG); elem.append(appdata); GetSendQ().push_back(elem); SocketEngine::ChangeEventMask(sock, FD_ADD_TRIAL_WRITE); return 1; } int HandleWS(StreamSocket* sock, std::string& destrecvq) { if (GetRecvQ().empty()) return 0; unsigned char opcode = (unsigned char)GetRecvQ().c_str()[0]; switch (opcode & ~WS_FINBIT) { case OP_CONTINUATION: case OP_TEXT: case OP_BINARY: { std::string appdata; const int result = HandleAppData(sock, appdata, true); if (result != 1) return result; // Strip out any CR+LF which may have been erroneously sent. for (std::string::const_iterator iter = appdata.begin(); iter != appdata.end(); ++iter) { if (*iter != '\r' && *iter != '\n') destrecvq.push_back(*iter); } // If we are on the final message of this block append a line terminator. if (opcode & WS_FINBIT) destrecvq.append("\r\n"); return 1; } case OP_PING: { return HandlePingPongFrame(sock, true); } case OP_PONG: { // A pong frame may be sent unsolicited, so we have to handle it. // It may carry application data which we need to remove from the recvq as well. return HandlePingPongFrame(sock, false); } case OP_CLOSE: { sock->SetError("Connection closed"); return -1; } default: { sock->SetError("WebSocket: Invalid opcode"); return -1; } } } void FailHandshake(StreamSocket* sock, const char* httpreply, const char* sockerror) { GetSendQ().push_back(StreamSocket::SendQueue::Element(httpreply)); sock->DoWrite(); sock->SetError(sockerror); } int HandleHTTPReq(StreamSocket* sock) { std::string& recvq = GetRecvQ(); const std::string::size_type reqend = recvq.find("\r\n\r\n"); if (reqend == std::string::npos) return 0; bool allowedorigin = false; HTTPHeaderFinder originheader; if (originheader.Find(recvq, "Origin:", 7, reqend)) { const std::string origin = originheader.ExtractValue(recvq); for (OriginList::const_iterator iter = allowedorigins.begin(); iter != allowedorigins.end(); ++iter) { if (InspIRCd::Match(origin, *iter, ascii_case_insensitive_map)) { allowedorigin = true; break; } } } if (!allowedorigin) { FailHandshake(sock, "HTTP/1.1 403 Forbidden\r\nConnection: close\r\n\r\n", "WebSocket: Received HTTP request from a non-whitelisted origin"); return -1; } HTTPHeaderFinder keyheader; if (!keyheader.Find(recvq, "Sec-WebSocket-Key:", 18, reqend)) { FailHandshake(sock, "HTTP/1.1 501 Not Implemented\r\nConnection: close\r\n\r\n", "WebSocket: Received HTTP request which is not a websocket upgrade"); return -1; } if (!*sha1) { FailHandshake(sock, "HTTP/1.1 503 Service Unavailable\r\nConnection: close\r\n\r\n", "WebSocket: SHA-1 provider missing"); return -1; } state = STATE_ESTABLISHED; std::string key = keyheader.ExtractValue(recvq); key.append(MagicGUID); std::string reply = "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: "; reply.append(BinToBase64((*sha1)->GenerateRaw(key), NULL, '=')).append("\r\n\r\n"); GetSendQ().push_back(StreamSocket::SendQueue::Element(reply)); SocketEngine::ChangeEventMask(sock, FD_ADD_TRIAL_WRITE); recvq.erase(0, reqend + 4); return 1; } public: WebSocketHook(IOHookProvider* Prov, StreamSocket* sock, OriginList& AllowedOrigins) : IOHookMiddle(Prov) , state(STATE_HTTPREQ) , lastpingpong(0) , allowedorigins(AllowedOrigins) { sock->AddIOHook(this); } int OnStreamSocketWrite(StreamSocket* sock, StreamSocket::SendQueue& uppersendq) CXX11_OVERRIDE { StreamSocket::SendQueue& mysendq = GetSendQ(); // Return 1 to allow sending back an error HTTP response if (state != STATE_ESTABLISHED) return (mysendq.empty() ? 0 : 1); std::string message; for (StreamSocket::SendQueue::const_iterator elem = uppersendq.begin(); elem != uppersendq.end(); ++elem) { for (StreamSocket::SendQueue::Element::const_iterator chr = elem->begin(); chr != elem->end(); ++chr) { if (*chr == '\n') { // We have found an entire message. Send it in its own frame. mysendq.push_back(PrepareSendQElem(message.length(), OP_BINARY)); mysendq.push_back(message); message.clear(); } else if (*chr != '\r') { message.push_back(*chr); } } } // Empty the upper send queue and push whatever is left back onto it. uppersendq.clear(); if (!message.empty()) { uppersendq.push_back(message); return 0; } return 1; } int OnStreamSocketRead(StreamSocket* sock, std::string& destrecvq) CXX11_OVERRIDE { if (state == STATE_HTTPREQ) { int httpret = HandleHTTPReq(sock); if (httpret <= 0) return httpret; } int wsret; do { wsret = HandleWS(sock, destrecvq); } while ((!GetRecvQ().empty()) && (wsret > 0)); return wsret; } void OnStreamSocketClose(StreamSocket* sock) CXX11_OVERRIDE { } }; void WebSocketHookProvider::OnAccept(StreamSocket* sock, irc::sockets::sockaddrs* client, irc::sockets::sockaddrs* server) { new WebSocketHook(this, sock, allowedorigins); } class ModuleWebSocket : public Module { dynamic_reference_nocheck hash; reference hookprov; public: ModuleWebSocket() : hash(this, "hash/sha1") , hookprov(new WebSocketHookProvider(this)) { sha1 = &hash; } void ReadConfig(ConfigStatus& status) CXX11_OVERRIDE { ConfigTagList tags = ServerInstance->Config->ConfTags("wsorigin"); if (tags.first == tags.second) throw ModuleException("You have loaded the websocket module but not configured any allowed origins!"); OriginList allowedorigins; for (ConfigIter i = tags.first; i != tags.second; ++i) { ConfigTag* tag = i->second; // Ensure that we have the parameter. const std::string allow = tag->getString("allow"); if (allow.empty()) throw ModuleException(" is a mandatory field, at " + tag->getTagLocation()); allowedorigins.push_back(allow); } hookprov->allowedorigins.swap(allowedorigins); } void OnCleanup(ExtensionItem::ExtensibleType type, Extensible* item) CXX11_OVERRIDE { if (type != ExtensionItem::EXT_USER) return; LocalUser* user = IS_LOCAL(static_cast(item)); if ((user) && (user->eh.GetModHook(this))) ServerInstance->Users.QuitUser(user, "WebSocket module unloading"); } Version GetVersion() CXX11_OVERRIDE { return Version("Provides RFC 6455 WebSocket support", VF_VENDOR); } }; MODULE_INIT(ModuleWebSocket)