#include "modules/account.h"
#include "modules/sasl.h"
#include "modules/ssl.h"
+#include "modules/spanningtree.h"
+
+static std::string sasl_target;
+
+class ServerTracker : public SpanningTreeEventListener
+{
+ bool online;
+
+ void Update(const Server* server, bool linked)
+ {
+ if (sasl_target == "*")
+ return;
+
+ if (InspIRCd::Match(server->GetName(), sasl_target))
+ {
+ ServerInstance->Logs->Log(MODNAME, LOG_VERBOSE, "SASL target server \"%s\" %s", sasl_target.c_str(), (linked ? "came online" : "went offline"));
+ online = linked;
+ }
+ }
+
+ void OnServerLink(const Server* server) CXX11_OVERRIDE
+ {
+ Update(server, true);
+ }
+
+ void OnServerSplit(const Server* server) CXX11_OVERRIDE
+ {
+ Update(server, false);
+ }
+
+ public:
+ ServerTracker(Module* mod)
+ : SpanningTreeEventListener(mod)
+ {
+ Reset();
+ }
+
+ void Reset()
+ {
+ if (sasl_target == "*")
+ {
+ online = true;
+ return;
+ }
+
+ online = false;
+
+ ProtocolInterface::ServerList servers;
+ ServerInstance->PI->GetServerList(servers);
+ for (ProtocolInterface::ServerList::const_iterator i = servers.begin(); i != servers.end(); ++i)
+ {
+ const ProtocolInterface::ServerInfo& server = *i;
+ if (InspIRCd::Match(server.servername, sasl_target))
+ {
+ online = true;
+ break;
+ }
+ }
+ }
+
+ bool IsOnline() const { return online; }
+};
class SASLCap : public Cap::Capability
{
std::string mechlist;
+ const ServerTracker& servertracker;
bool OnRequest(LocalUser* user, bool adding) CXX11_OVERRIDE
{
return (user->registered != REG_ALL);
}
+ bool OnList(LocalUser* user) CXX11_OVERRIDE
+ {
+ return servertracker.IsOnline();
+ }
+
const std::string* GetValue(LocalUser* user) const CXX11_OVERRIDE
{
return &mechlist;
}
public:
- SASLCap(Module* mod)
+ SASLCap(Module* mod, const ServerTracker& tracker)
: Cap::Capability(mod, "sasl")
+ , servertracker(tracker)
{
}
enum SaslState { SASL_INIT, SASL_COMM, SASL_DONE };
enum SaslResult { SASL_OK, SASL_FAIL, SASL_ABORT };
-static std::string sasl_target = "*";
static Events::ModuleEventProvider* saslevprov;
static void SendSASL(const parameterlist& params)
this->result = this->GetSaslResult(msg[3]);
}
else if (msg[2] == "M")
- this->user->WriteNumeric(908, "%s :are available SASL mechanisms", msg[3].c_str());
+ this->user->WriteNumeric(908, msg[3], "are available SASL mechanisms");
else
ServerInstance->Logs->Log(MODNAME, LOG_DEFAULT, "Services sent an unknown SASL message \"%s\" \"%s\"", msg[2].c_str(), msg[3].c_str());
switch (this->result)
{
case SASL_OK:
- this->user->WriteNumeric(903, ":SASL authentication successful");
+ this->user->WriteNumeric(903, "SASL authentication successful");
break;
case SASL_ABORT:
- this->user->WriteNumeric(906, ":SASL authentication aborted");
+ this->user->WriteNumeric(906, "SASL authentication aborted");
break;
case SASL_FAIL:
- this->user->WriteNumeric(904, ":SASL authentication failed");
+ this->user->WriteNumeric(904, "SASL authentication failed");
break;
default:
break;
CmdResult Handle (const std::vector<std::string>& parameters, User *user)
{
- /* Only allow AUTHENTICATE on unregistered clients */
- if (user->registered != REG_ALL)
{
if (!cap.get(user))
return CMD_FAILURE;
class ModuleSASL : public Module
{
SimpleExtItem<SaslAuthenticator> authExt;
+ ServerTracker servertracker;
SASLCap cap;
CommandAuthenticate auth;
CommandSASL sasl;
public:
ModuleSASL()
: authExt("sasl_auth", ExtensionItem::EXT_USER, this)
- , cap(this)
+ , servertracker(this)
+ , cap(this, servertracker)
, auth(this, authExt, cap)
, sasl(this, authExt)
, sasleventprov(this, "event/sasl")
void ReadConfig(ConfigStatus& status) CXX11_OVERRIDE
{
sasl_target = ServerInstance->Config->ConfValue("sasl")->getString("target", "*");
+ servertracker.Reset();
}
ModResult OnUserRegister(LocalUser *user) CXX11_OVERRIDE