diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/modules/extra/m_mysql.cpp | 52 |
1 files changed, 29 insertions, 23 deletions
diff --git a/src/modules/extra/m_mysql.cpp b/src/modules/extra/m_mysql.cpp index 7b6e2906d..dcdbe0004 100644 --- a/src/modules/extra/m_mysql.cpp +++ b/src/modules/extra/m_mysql.cpp @@ -253,6 +253,31 @@ class MySQLresult : public SQL::Result */ class SQLConnection : public SQL::Provider { + private: + bool EscapeString(SQL::Query* query, const std::string& in, std::string& out) + { + // In the worst case each character may need to be encoded as using two bytes and one + // byte is the NUL terminator. + std::vector<char> buffer(in.length() * 2 + 1); + + // The return value of mysql_escape_string() is either an error or the length of the + // encoded string not including the NUL terminator. + // + // Unfortunately, someone genius decided that mysql_escape_string should return an + // unsigned type even though -1 is returned on error so checking whether an error + // happened is a bit cursed. + unsigned long escapedsize = mysql_escape_string(&buffer[0], in.c_str(), in.length()); + if (escapedsize == static_cast<unsigned long>(-1)) + { + SQL::Error err(SQL::QSEND_FAIL, InspIRCd::Format("%u: %s", mysql_errno(connection), mysql_error(connection))); + query->OnError(err); + return false; + } + + out.append(&buffer[0], escapedsize); + return true; + } + public: reference<ConfigTag> config; MYSQL *connection; @@ -356,21 +381,8 @@ class SQLConnection : public SQL::Provider { if (q[i] != '?') res.push_back(q[i]); - else - { - if (param < p.size()) - { - std::string parm = p[param++]; - // In the worst case, each character may need to be encoded as using two bytes, - // and one byte is the terminating null - std::vector<char> buffer(parm.length() * 2 + 1); - - // The return value of mysql_real_escape_string() is the length of the encoded string, - // not including the terminating null - unsigned long escapedsize = mysql_real_escape_string(connection, &buffer[0], parm.c_str(), parm.length()); - res.append(&buffer[0], escapedsize); - } - } + else if (param < p.size() && !EscapeString(call, p[param++], res)) + return; } Submit(call, res); } @@ -391,14 +403,8 @@ class SQLConnection : public SQL::Provider i--; SQL::ParamMap::const_iterator it = p.find(field); - if (it != p.end()) - { - std::string parm = it->second; - // NOTE: See above - std::vector<char> buffer(parm.length() * 2 + 1); - unsigned long escapedsize = mysql_escape_string(&buffer[0], parm.c_str(), parm.length()); - res.append(&buffer[0], escapedsize); - } + if (it != p.end() && !EscapeString(call, it->second, res)) + return; } } Submit(call, res); |