diff options
-rw-r--r-- | src/modules/extra/m_sqlite3.cpp | 31 | ||||
-rw-r--r-- | src/modules/m_sqlauth.cpp | 9 | ||||
-rw-r--r-- | src/modules/m_sqloper.cpp | 37 | ||||
-rw-r--r-- | src/modules/sql.h | 34 |
4 files changed, 81 insertions, 30 deletions
diff --git a/src/modules/extra/m_sqlite3.cpp b/src/modules/extra/m_sqlite3.cpp index 4a0538bc8..cc54b90bf 100644 --- a/src/modules/extra/m_sqlite3.cpp +++ b/src/modules/extra/m_sqlite3.cpp @@ -28,7 +28,8 @@ class SQLite3Result : public SQLResult public: int currentrow; int rows; - std::vector<std::vector<std::string> > fieldlists; + std::vector<std::string> columns; + std::vector<SQLEntries> fieldlists; SQLite3Result() : currentrow(0), rows(0) { @@ -43,7 +44,7 @@ class SQLite3Result : public SQLResult return rows; } - virtual bool GetRow(std::vector<std::string>& result) + virtual bool GetRow(SQLEntries& result) { if (currentrow < rows) { @@ -57,6 +58,10 @@ class SQLite3Result : public SQLResult return false; } } + virtual void GetCols(std::vector<std::string>& result) + { + result.assign(columns.begin(), columns.end()); + } }; class SQLConn : public refcountbase @@ -94,6 +99,11 @@ class SQLConn : public refcountbase return; } int cols = sqlite3_column_count(stmt); + res.columns.resize(cols); + for(int i=0; i < cols; i++) + { + res.columns[i] = sqlite3_column_name(stmt, i); + } while (1) { err = sqlite3_step(stmt); @@ -105,7 +115,8 @@ class SQLConn : public refcountbase for(int i=0; i < cols; i++) { const char* txt = (const char*)sqlite3_column_text(stmt, i); - res.fieldlists[res.rows][i] = txt ? txt : ""; + if (txt) + res.fieldlists[res.rows][i] = SQLEntry(txt); } res.rows++; } @@ -132,7 +143,7 @@ class SQLiteProvider : public SQLProvider SQLiteProvider(Module* Parent) : SQLProvider(Parent, "SQL/SQLite") {} - std::string FormatQuery(std::string q, ParamL p) + std::string FormatQuery(const std::string& q, const ParamL& p) { std::string res; unsigned int param = 0; @@ -154,7 +165,7 @@ class SQLiteProvider : public SQLProvider return res; } - std::string FormatQuery(std::string q, ParamM p) + std::string FormatQuery(const std::string& q, const ParamM& p) { std::string res; for(std::string::size_type i = 0; i < q.length(); i++) @@ -169,9 +180,13 @@ class SQLiteProvider : public SQLProvider field.push_back(q[i++]); i--; - char* escaped = sqlite3_mprintf("%q", p[field].c_str()); - res.append(escaped); - sqlite3_free(escaped); + ParamM::const_iterator it = p.find(field); + if (it != p.end()) + { + char* escaped = sqlite3_mprintf("%q", it->second.c_str()); + res.append(escaped); + sqlite3_free(escaped); + } } } return res; diff --git a/src/modules/m_sqlauth.cpp b/src/modules/m_sqlauth.cpp index 1494a5634..c7c6c61a6 100644 --- a/src/modules/m_sqlauth.cpp +++ b/src/modules/m_sqlauth.cpp @@ -113,16 +113,9 @@ class ModuleSQLAuth : public Module pendingExt.set(user, AUTH_STATE_BUSY); - std::string thisquery = freeformquery; ParamM userinfo; - userinfo["nick"] = user->nick; + SQL->PopulateUserInfo(user, userinfo); userinfo["pass"] = user->password; - userinfo["host"] = user->host; - userinfo["ip"] = user->GetIPString(); - userinfo["gecos"] = user->fullname; - userinfo["ident"] = user->ident; - userinfo["server"] = user->server; - userinfo["uuid"] = user->uuid; HashProvider* md5 = ServerInstance->Modules->FindDataService<HashProvider>("hash/md5"); if (md5) diff --git a/src/modules/m_sqloper.cpp b/src/modules/m_sqloper.cpp index 3b2f67196..66fb0550e 100644 --- a/src/modules/m_sqloper.cpp +++ b/src/modules/m_sqloper.cpp @@ -48,12 +48,25 @@ class OpMeQuery : public SQLQuery if (!user) return; - // multiple rows may exist for multiple hosts - parameterlist row; + // multiple rows may exist + SQLEntries row; while (res.GetRow(row)) { +#if 0 + parameterlist cols; + res.GetCols(cols); + + std::vector<KeyVal>* items; + reference<ConfigTag> tag = ConfigTag::create("oper", "<m_sqloper>", 0, items); + for(unsigned int i=0; i < cols.size(); i++) + { + if (!row[i].nul) + items->insert(std::make_pair(cols[i], row[i])); + } +#else if (OperUser(user, row[0], row[1])) return; +#endif } ServerInstance->Logs->Log("m_sqloper",DEBUG, "SQLOPER: no matches for %s (checked %d rows)", uid.c_str(), res.Rows()); // nobody succeeded... fall back to OPER @@ -62,6 +75,7 @@ class OpMeQuery : public SQLQuery void OnError(SQLerror& error) { + ServerInstance->Logs->Log("m_sqloper",DEFAULT, "SQLOPER: query failed (%s)", error.Str()); fallback(); } @@ -115,6 +129,7 @@ class OpMeQuery : public SQLQuery class ModuleSQLOper : public Module { std::string databaseid; + std::string query; std::string hashtype; dynamic_reference<SQLProvider> SQL; @@ -131,10 +146,11 @@ public: void OnRehash(User* user) { - ConfigReader Conf; + ConfigTag* tag = ServerInstance->Config->ConfValue("sqloper"); - databaseid = Conf.ReadValue("sqloper", "dbid", 0); /* Database ID of a database configured for the service provider module */ - hashtype = Conf.ReadValue("sqloper", "hash", 0); + databaseid = tag->getString("dbid"); + hashtype = tag->getString("hash"); + query = tag->getString("query", "SELECT hostname as host, type FROM ircd_opers WHERE username='$username' AND password='$password'"); } ModResult OnPreCommand(std::string &command, std::vector<std::string> ¶meters, LocalUser *user, bool validated, const std::string &original_line) @@ -152,13 +168,12 @@ public: { HashProvider* hash = ServerInstance->Modules->FindDataService<HashProvider>("hash/" + hashtype); - parameterlist params; - params.push_back(username); - params.push_back(hash ? hash->hexsum(password) : password); + ParamM userinfo; + SQL->PopulateUserInfo(user, userinfo); + userinfo["username"] = username; + userinfo["password"] = hash ? hash->hexsum(password) : password; - SQL->submit(new OpMeQuery(this, databaseid, SQL->FormatQuery( - "SELECT hostname, type FROM ircd_opers WHERE username = '?' AND password='?'", params - ), user->uuid, username, password)); + SQL->submit(new OpMeQuery(this, databaseid, SQL->FormatQuery(query, userinfo), user->uuid, username, password)); } Version GetVersion() diff --git a/src/modules/sql.h b/src/modules/sql.h index 9114bea88..ffe95a9cb 100644 --- a/src/modules/sql.h +++ b/src/modules/sql.h @@ -24,6 +24,18 @@ typedef std::vector<std::string> ParamL; typedef std::map<std::string, std::string> ParamM; +class SQLEntry +{ + public: + std::string value; + bool nul; + SQLEntry() : nul(true) {} + SQLEntry(const std::string& v) : value(v), nul(false) {} + inline operator std::string&() { return value; } +}; + +typedef std::vector<SQLEntry> SQLEntries; + /** * Result of an SQL query. Only valid inside OnResult */ @@ -49,7 +61,11 @@ class SQLResult : public interfacebase * @returns true if there was a row, false if no row exists (end of * iteration) */ - virtual bool GetRow(std::vector<std::string>& result) = 0; + virtual bool GetRow(SQLEntries& result) = 0; + + /** Returns column names for the items in this row + */ + virtual void GetCols(std::vector<std::string>& result) = 0; }; /** SQLerror holds the error state of a request. @@ -143,13 +159,25 @@ class SQLProvider : public DataProvider * @param q The query string, with '?' parameters * @param p The parameters to fill in in the '?' slots */ - virtual std::string FormatQuery(std::string q, ParamL p) = 0; + virtual std::string FormatQuery(const std::string& q, const ParamL& p) = 0; /** Format a parameterized query string using proper SQL escaping. * @param q The query string, with '$foo' parameters * @param p The map to look up parameters in */ - virtual std::string FormatQuery(std::string q, ParamM p) = 0; + virtual std::string FormatQuery(const std::string& q, const ParamM& p) = 0; + + /** Convenience function */ + void PopulateUserInfo(User* user, ParamM& userinfo) + { + userinfo["nick"] = user->nick; + userinfo["host"] = user->host; + userinfo["ip"] = user->GetIPString(); + userinfo["gecos"] = user->fullname; + userinfo["ident"] = user->ident; + userinfo["server"] = user->server; + userinfo["uuid"] = user->uuid; + } }; #endif |