X-Git-Url: https://git.netwichtig.de/gitweb/?a=blobdiff_plain;f=src%2Fmodules%2Fm_websocket.cpp;h=5ac661ccfaefb63927695bf0bbef450ea98424dd;hb=553877f7a9eff26166dfa4d953d6f69f9420de28;hp=399b0b017f3582cb2ff34760e881a08b54eee21f;hpb=c0aba5b728b0a921d95ec120aa638dab1520b42f;p=user%2Fhenk%2Fcode%2Finspircd.git diff --git a/src/modules/m_websocket.cpp b/src/modules/m_websocket.cpp index 399b0b017..5ac661ccf 100644 --- a/src/modules/m_websocket.cpp +++ b/src/modules/m_websocket.cpp @@ -21,6 +21,8 @@ #include "iohook.h" #include "modules/hash.h" +typedef std::vector OriginList; + static const char MagicGUID[] = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; static const char whitespace[] = " \t\r\n"; static dynamic_reference_nocheck* sha1; @@ -28,6 +30,8 @@ static dynamic_reference_nocheck* sha1; class WebSocketHookProvider : public IOHookProvider { public: + OriginList allowedorigins; + WebSocketHookProvider(Module* mod) : IOHookProvider(mod, "websocket", IOHookProvider::IOH_UNKNOWN, true) { @@ -101,6 +105,7 @@ class WebSocketHook : public IOHookMiddle State state; time_t lastpingpong; + OriginList& allowedorigins; static size_t FillHeader(unsigned char* outbuf, size_t sendlength, OpCode opcode) { @@ -288,6 +293,27 @@ class WebSocketHook : public IOHookMiddle if (reqend == std::string::npos) return 0; + bool allowedorigin = false; + HTTPHeaderFinder originheader; + if (originheader.Find(recvq, "Origin:", 7, reqend)) + { + const std::string origin = originheader.ExtractValue(recvq); + for (OriginList::const_iterator iter = allowedorigins.begin(); iter != allowedorigins.end(); ++iter) + { + if (InspIRCd::Match(origin, *iter, ascii_case_insensitive_map)) + { + allowedorigin = true; + break; + } + } + } + + if (!allowedorigin) + { + FailHandshake(sock, "HTTP/1.1 403 Forbidden\r\nConnection: close\r\n\r\n", "WebSocket: Received HTTP request from a non-whitelisted origin"); + return -1; + } + HTTPHeaderFinder keyheader; if (!keyheader.Find(recvq, "Sec-WebSocket-Key:", 18, reqend)) { @@ -318,10 +344,11 @@ class WebSocketHook : public IOHookMiddle } public: - WebSocketHook(IOHookProvider* Prov, StreamSocket* sock) + WebSocketHook(IOHookProvider* Prov, StreamSocket* sock, OriginList& AllowedOrigins) : IOHookMiddle(Prov) , state(STATE_HTTPREQ) , lastpingpong(0) + , allowedorigins(AllowedOrigins) { sock->AddIOHook(this); } @@ -370,25 +397,46 @@ class WebSocketHook : public IOHookMiddle void WebSocketHookProvider::OnAccept(StreamSocket* sock, irc::sockets::sockaddrs* client, irc::sockets::sockaddrs* server) { - new WebSocketHook(this, sock); + new WebSocketHook(this, sock, allowedorigins); } class ModuleWebSocket : public Module { dynamic_reference_nocheck hash; - WebSocketHookProvider hookprov; + reference hookprov; public: ModuleWebSocket() : hash(this, "hash/sha1") - , hookprov(this) + , hookprov(new WebSocketHookProvider(this)) { sha1 = &hash; } - void OnCleanup(int target_type, void* item) CXX11_OVERRIDE + void ReadConfig(ConfigStatus& status) CXX11_OVERRIDE + { + ConfigTagList tags = ServerInstance->Config->ConfTags("wsorigin"); + if (tags.first == tags.second) + throw ModuleException("You have loaded the websocket module but not configured any allowed origins!"); + + OriginList allowedorigins; + for (ConfigIter i = tags.first; i != tags.second; ++i) + { + ConfigTag* tag = i->second; + + // Ensure that we have the parameter. + const std::string allow = tag->getString("allow"); + if (allow.empty()) + throw ModuleException(" is a mandatory field, at " + tag->getTagLocation()); + + allowedorigins.push_back(allow); + } + hookprov->allowedorigins.swap(allowedorigins); + } + + void OnCleanup(ExtensionItem::ExtensibleType type, Extensible* item) CXX11_OVERRIDE { - if (target_type != TYPE_USER) + if (type != ExtensionItem::EXT_USER) return; LocalUser* user = IS_LOCAL(static_cast(item));