diff options
-rw-r--r-- | include/modules/ssl.h | 40 | ||||
-rw-r--r-- | src/modules/m_sslinfo.cpp | 31 | ||||
-rw-r--r-- | src/modules/m_sslmodes.cpp | 32 |
3 files changed, 72 insertions, 31 deletions
diff --git a/include/modules/ssl.h b/include/modules/ssl.h index 9830b1ca6..25076215a 100644 --- a/include/modules/ssl.h +++ b/include/modules/ssl.h @@ -20,7 +20,6 @@ #pragma once -#include <map> #include <string> #include "iohook.h" @@ -199,22 +198,43 @@ class SSLClientCert } }; -/** Get certificate from a user (requires m_sslinfo) */ -struct UserCertificateRequest : public Request +class UserCertificateAPIBase : public DataProvider { - User* const user; - ssl_cert* cert; - - UserCertificateRequest(User* u, Module* Me, Module* info = ServerInstance->Modules->Find("m_sslinfo.so")) - : Request(Me, info, "GET_USER_CERT"), user(u), cert(NULL) + public: + UserCertificateAPIBase(Module* parent) + : DataProvider(parent, "m_sslinfo_api") { - Send(); } - std::string GetFingerprint() + /** Get the SSL certificate of a user + * @param user The user whose certificate to get, user may be remote + * @return The SSL certificate of the user or NULL if the user is not using SSL + */ + virtual ssl_cert* GetCertificate(User* user) = 0; + + /** Get the key fingerprint from a user's certificate + * @param user The user whose key fingerprint to get, user may be remote + * @return The key fingerprint from the user's SSL certificate or an empty string + * if the user is not using SSL or did not provide a client certificate + */ + std::string GetFingerprint(User* user) { + ssl_cert* cert = GetCertificate(user); if (cert) return cert->GetFingerprint(); return ""; } }; + +/** API implemented by m_sslinfo that allows modules to retrive the SSL certificate + * information of local and remote users. It can also be used to find out whether a + * user is using SSL or not. + */ +class UserCertificateAPI : public dynamic_reference<UserCertificateAPIBase> +{ + public: + UserCertificateAPI(Module* parent) + : dynamic_reference<UserCertificateAPIBase>(parent, "m_sslinfo_api") + { + } +}; diff --git a/src/modules/m_sslinfo.cpp b/src/modules/m_sslinfo.cpp index 5516af7ef..edf3918e2 100644 --- a/src/modules/m_sslinfo.cpp +++ b/src/modules/m_sslinfo.cpp @@ -121,19 +121,37 @@ class CommandSSLInfo : public Command } }; +class UserCertificateAPIImpl : public UserCertificateAPIBase +{ + SSLCertExt& ext; + + public: + UserCertificateAPIImpl(Module* mod, SSLCertExt& certext) + : UserCertificateAPIBase(mod), ext(certext) + { + } + + ssl_cert* GetCertificate(User* user) CXX11_OVERRIDE + { + return ext.get(user); + } +}; + class ModuleSSLInfo : public Module { CommandSSLInfo cmd; + UserCertificateAPIImpl APIImpl; public: - ModuleSSLInfo() : cmd(this) + ModuleSSLInfo() + : cmd(this), APIImpl(this, cmd.CertExt) { } void init() CXX11_OVERRIDE { + ServerInstance->Modules->AddService(APIImpl); ServerInstance->Modules->AddService(cmd); - ServerInstance->Modules->AddService(cmd.CertExt); Implementation eventlist[] = { I_OnWhois, I_OnPreCommand, I_OnSetConnectClass, I_OnUserConnect, I_OnPostConnect }; @@ -228,15 +246,6 @@ class ModuleSSLInfo : public Module return MOD_RES_DENY; return MOD_RES_PASSTHRU; } - - void OnRequest(Request& request) CXX11_OVERRIDE - { - if (strcmp("GET_USER_CERT", request.id) == 0) - { - UserCertificateRequest& req = static_cast<UserCertificateRequest&>(request); - req.cert = cmd.CertExt.get(req.user); - } - } }; MODULE_INIT(ModuleSSLInfo) diff --git a/src/modules/m_sslmodes.cpp b/src/modules/m_sslmodes.cpp index 360f63bc9..65933cc14 100644 --- a/src/modules/m_sslmodes.cpp +++ b/src/modules/m_sslmodes.cpp @@ -31,7 +31,13 @@ class SSLMode : public ModeHandler { public: - SSLMode(Module* Creator) : ModeHandler(Creator, "sslonly", 'z', PARAM_NONE, MODETYPE_CHANNEL) { } + UserCertificateAPI API; + + SSLMode(Module* Creator) + : ModeHandler(Creator, "sslonly", 'z', PARAM_NONE, MODETYPE_CHANNEL) + , API(Creator) + { + } ModeAction OnModeChange(User* source, User* dest, Channel* channel, std::string ¶meter, bool adding) { @@ -41,12 +47,14 @@ class SSLMode : public ModeHandler { if (IS_LOCAL(source)) { + if (!API) + return MODEACTION_DENY; + const UserMembList* userlist = channel->GetUsers(); for(UserMembCIter i = userlist->begin(); i != userlist->end(); i++) { - UserCertificateRequest req(i->first, creator); - req.Send(); - if(!req.cert && !ServerInstance->ULine(i->first->server)) + ssl_cert* cert = API->GetCertificate(i->first); + if (!cert && !ServerInstance->ULine(i->first->server)) { source->WriteNumeric(ERR_ALLMUSTSSL, "%s %s :all members of the channel must be connected via SSL", source->nick.c_str(), channel->name.c_str()); return MODEACTION_DENY; @@ -96,9 +104,11 @@ class ModuleSSLModes : public Module { if(chan && chan->IsModeSet('z')) { - UserCertificateRequest req(user, this); - req.Send(); - if (req.cert) + if (!sslm.API) + return MOD_RES_DENY; + + ssl_cert* cert = sslm.API->GetCertificate(user); + if (cert) { // Let them in return MOD_RES_PASSTHRU; @@ -118,9 +128,11 @@ class ModuleSSLModes : public Module { if ((mask.length() > 2) && (mask[0] == 'z') && (mask[1] == ':')) { - UserCertificateRequest req(user, this); - req.Send(); - if (req.cert && InspIRCd::Match(req.cert->GetFingerprint(), mask.substr(2))) + if (!sslm.API) + return MOD_RES_DENY; + + ssl_cert* cert = sslm.API->GetCertificate(user); + if (cert && InspIRCd::Match(cert->GetFingerprint(), mask.substr(2))) return MOD_RES_DENY; } return MOD_RES_PASSTHRU; |