]> git.netwichtig.de Git - user/henk/code/inspircd.git/commitdiff
Change SQLv3 to format queries during submission, not before
authordanieldg <danieldg@e03df62e-2008-0410-955e-edbf42e46eb7>
Sat, 13 Mar 2010 16:35:16 +0000 (16:35 +0000)
committerdanieldg <danieldg@e03df62e-2008-0410-955e-edbf42e46eb7>
Sat, 13 Mar 2010 16:35:16 +0000 (16:35 +0000)
git-svn-id: http://svn.inspircd.org/repository/trunk/inspircd@12633 e03df62e-2008-0410-955e-edbf42e46eb7

src/modules/extra/m_mysql.cpp
src/modules/extra/m_pgsql.cpp
src/modules/extra/m_sqlite3.cpp
src/modules/m_sqlauth.cpp
src/modules/m_sqloper.cpp
src/modules/sql.h

index ae219df70755b68561d90135d4474af4f01a9bdb..88c98a5e7d6f41e8505a291438d81fa7e0a0dd10 100644 (file)
@@ -70,8 +70,9 @@ class DispatcherThread;
 struct QQueueItem
 {
        SQLQuery* q;
+       std::string query;
        SQLConnection* c;
-       QQueueItem(SQLQuery* Q, SQLConnection* C) : q(Q), c(C) {}
+       QQueueItem(SQLQuery* Q, const std::string& S, SQLConnection* C) : q(Q), query(S), c(C) {}
 };
 
 struct RQueueItem
@@ -260,68 +261,16 @@ class SQLConnection : public SQLProvider
                return true;
        }
 
-       virtual std::string FormatQuery(const std::string& q, const ParamL& p)
-       {
-               std::string res;
-               unsigned int param = 0;
-               for(std::string::size_type i = 0; i < q.length(); i++)
-               {
-                       if (q[i] != '?')
-                               res.push_back(q[i]);
-                       else
-                       {
-                               // TODO numbered parameter support ('?1')
-                               if (param < p.size())
-                               {
-                                       std::string parm = p[param++];
-                                       char buffer[MAXBUF];
-                                       mysql_escape_string(buffer, parm.c_str(), parm.length());
-//                                     mysql_real_escape_string(connection, queryend, paramscopy[paramnum].c_str(), paramscopy[paramnum].length());
-                                       res.append(buffer);
-                               }
-                       }
-               }
-               return res;
-       }
-
-       std::string FormatQuery(const std::string& q, const ParamM& p)
-       {
-               std::string res;
-               for(std::string::size_type i = 0; i < q.length(); i++)
-               {
-                       if (q[i] != '$')
-                               res.push_back(q[i]);
-                       else
-                       {
-                               std::string field;
-                               i++;
-                               while (i < q.length() && isalpha(q[i]))
-                                       field.push_back(q[i++]);
-                               i--;
-
-                               ParamM::const_iterator it = p.find(field);
-                               if (it != p.end())
-                               {
-                                       std::string parm = it->second;
-                                       char buffer[MAXBUF];
-                                       mysql_escape_string(buffer, parm.c_str(), parm.length());
-                                       res.append(buffer);
-                               }
-                       }
-               }
-               return res;
-       }
-
        ModuleSQL* Parent()
        {
                return (ModuleSQL*)(Module*)creator;
        }
 
-       MySQLresult* DoBlockingQuery(SQLQuery* req)
+       MySQLresult* DoBlockingQuery(const std::string& query)
        {
 
                /* Parse the command string and dispatch it to mysql */
-               if (CheckConnection() && !mysql_real_query(connection, req->query.data(), req->query.length()))
+               if (CheckConnection() && !mysql_real_query(connection, query.data(), query.length()))
                {
                        /* Successfull query */
                        MYSQL_RES* res = mysql_use_result(connection);
@@ -356,12 +305,63 @@ class SQLConnection : public SQLProvider
                mysql_close(connection);
        }
 
-       void submit(SQLQuery* q)
+       void submit(SQLQuery* q, const std::string& qs)
        {
                Parent()->Dispatcher->LockQueue();
-               Parent()->qq.push_back(QQueueItem(q, this));
+               Parent()->qq.push_back(QQueueItem(q, qs, this));
                Parent()->Dispatcher->UnlockQueueWakeup();
        }
+
+       void submit(SQLQuery* call, const std::string& q, const ParamL& p)
+       {
+               std::string res;
+               unsigned int param = 0;
+               for(std::string::size_type i = 0; i < q.length(); i++)
+               {
+                       if (q[i] != '?')
+                               res.push_back(q[i]);
+                       else
+                       {
+                               if (param < p.size())
+                               {
+                                       std::string parm = p[param++];
+                                       char buffer[MAXBUF];
+                                       mysql_escape_string(buffer, parm.c_str(), parm.length());
+//                                     mysql_real_escape_string(connection, queryend, paramscopy[paramnum].c_str(), paramscopy[paramnum].length());
+                                       res.append(buffer);
+                               }
+                       }
+               }
+               submit(call, res);
+       }
+
+       void submit(SQLQuery* call, const std::string& q, const ParamM& p)
+       {
+               std::string res;
+               for(std::string::size_type i = 0; i < q.length(); i++)
+               {
+                       if (q[i] != '$')
+                               res.push_back(q[i]);
+                       else
+                       {
+                               std::string field;
+                               i++;
+                               while (i < q.length() && isalpha(q[i]))
+                                       field.push_back(q[i++]);
+                               i--;
+
+                               ParamM::const_iterator it = p.find(field);
+                               if (it != p.end())
+                               {
+                                       std::string parm = it->second;
+                                       char buffer[MAXBUF];
+                                       mysql_escape_string(buffer, parm.c_str(), parm.length());
+                                       res.append(buffer);
+                               }
+                       }
+               }
+               submit(call, res);
+       }
 };
 
 ModuleSQL::ModuleSQL()
@@ -481,7 +481,7 @@ void DispatcherThread::Run()
                        QQueueItem i = Parent->qq.front();
                        i.c->lock.Lock();
                        this->UnlockQueue();
-                       MySQLresult* res = i.c->DoBlockingQuery(i.q);
+                       MySQLresult* res = i.c->DoBlockingQuery(i.query);
                        i.c->lock.Unlock();
 
                        /*
index 735ca2f5aedced6f1d0b5d5535f2964c4381504a..655007ea32c81f5f7a4100c26ccce162b8db0d28 100644 (file)
@@ -55,6 +55,12 @@ class ReconnectTimer : public Timer
        virtual void Tick(time_t TIME);
 };
 
+struct QueueItem
+{
+       SQLQuery* c;
+       std::string q;
+       QueueItem(SQLQuery* C, const std::string& Q) : c(C), q(Q) {}
+};
 
 /** PgSQLresult is a subclass of the mostly-pure-virtual class SQLresult.
  * All SQL providers must create their own subclass and define it's methods using that
@@ -126,13 +132,13 @@ class SQLConn : public SQLProvider, public EventHandler
 {
  public:
        reference<ConfigTag> conf;      /* The <database> entry */
-       std::deque<SQLQuery*> queue;
+       std::deque<QueueItem> queue;
        PGconn*                 sql;            /* PgSQL database connection handle */
        SQLstatus               status;         /* PgSQL database connection status */
-       SQLQuery*               qinprog;        /* If there is currently a query in progress */
+       QueueItem               qinprog;        /* If there is currently a query in progress */
 
        SQLConn(Module* Creator, ConfigTag* tag)
-       : SQLProvider(Creator, "SQL/" + tag->getString("id")), conf(tag), sql(NULL), status(CWRITE), qinprog(NULL)
+       : SQLProvider(Creator, "SQL/" + tag->getString("id")), conf(tag), sql(NULL), status(CWRITE), qinprog(NULL, "")
        {
                if (!DoConnect())
                {
@@ -151,14 +157,14 @@ class SQLConn : public SQLProvider, public EventHandler
        ~SQLConn()
        {
                SQLerror err(SQL_BAD_DBID);
-               if (qinprog)
+               if (qinprog.c)
                {
-                       qinprog->OnError(err);
-                       delete qinprog;
+                       qinprog.c->OnError(err);
+                       delete qinprog.c;
                }
-               for(std::deque<SQLQuery*>::iterator i = queue.begin(); i != queue.end(); i++)
+               for(std::deque<QueueItem>::iterator i = queue.begin(); i != queue.end(); i++)
                {
-                       SQLQuery* q = *i;
+                       SQLQuery* q = i->c;
                        q->OnError(err);
                        delete q;
                }
@@ -262,7 +268,7 @@ class SQLConn : public SQLProvider, public EventHandler
        void DoConnectedPoll()
        {
 restart:
-               while (!qinprog && !queue.empty())
+               while (qinprog.q.empty() && !queue.empty())
                {
                        /* There's no query currently in progress, and there's queries in the queue. */
                        DoQuery(queue.front());
@@ -275,7 +281,7 @@ restart:
                        {
                                /* Nothing happens here */
                        }
-                       else if (qinprog)
+                       else if (qinprog.c)
                        {
                                /* Fetch the result.. */
                                PGresult* result = PQgetResult(sql);
@@ -301,18 +307,22 @@ restart:
                                        case PGRES_FATAL_ERROR:
                                        {
                                                SQLerror err(SQL_QREPLY_FAIL, PQresultErrorMessage(result));
-                                               qinprog->OnError(err);
+                                               qinprog.c->OnError(err);
                                                break;
                                        }
                                        default:
                                                /* Other values are not errors */
-                                               qinprog->OnResult(reply);
+                                               qinprog.c->OnResult(reply);
                                }
 
-                               delete qinprog;
-                               qinprog = NULL;
+                               delete qinprog.c;
+                               qinprog = QueueItem(NULL, "");
                                goto restart;
                        }
+                       else
+                       {
+                               qinprog.q = "";
+                       }
                }
                else
                {
@@ -366,7 +376,20 @@ restart:
                }
        }
 
-       virtual std::string FormatQuery(const std::string& q, const ParamL& p)
+       void submit(SQLQuery *req, const std::string& q)
+       {
+               if (qinprog.q.empty())
+               {
+                       DoQuery(QueueItem(req,q));
+               }
+               else
+               {
+                       // wait your turn.
+                       queue.push_back(QueueItem(req,q));
+               }
+       }
+
+       void submit(SQLQuery *req, const std::string& q, const ParamL& p)
        {
                std::string res;
                unsigned int param = 0;
@@ -376,7 +399,6 @@ restart:
                                res.push_back(q[i]);
                        else
                        {
-                               // TODO numbered parameter support ('?1')
                                if (param < p.size())
                                {
                                        std::string parm = p[param++];
@@ -393,10 +415,10 @@ restart:
                                }
                        }
                }
-               return res;
+               submit(req, res);
        }
 
-       std::string FormatQuery(const std::string& q, const ParamM& p)
+       void submit(SQLQuery *req, const std::string& q, const ParamM& p)
        {
                std::string res;
                for(std::string::size_type i = 0; i < q.length(); i++)
@@ -428,42 +450,29 @@ restart:
                                }
                        }
                }
-               return res;
+               submit(req, res);
        }
 
-       virtual void submit(SQLQuery *req)
-       {
-               if (qinprog)
-               {
-                       // wait your turn.
-                       queue.push_back(req);
-               }
-               else
-               {
-                       DoQuery(req);
-               }
-       }
-
-       void DoQuery(SQLQuery* req)
+       void DoQuery(const QueueItem& req)
        {
                if (status != WREAD && status != WWRITE)
                {
                        // whoops, not connected...
                        SQLerror err(SQL_BAD_CONN);
-                       req->OnError(err);
-                       delete req;
+                       req.c->OnError(err);
+                       delete req.c;
                        return;
                }
 
-               if(PQsendQuery(sql, req->query.c_str()))
+               if(PQsendQuery(sql, req.q.c_str()))
                {
                        qinprog = req;
                }
                else
                {
                        SQLerror err(SQL_QSEND_FAIL, PQerrorMessage(sql));
-                       req->OnError(err);
-                       delete req;
+                       req.c->OnError(err);
+                       delete req.c;
                }
        }
 
@@ -479,13 +488,6 @@ restart:
        }
 };
 
-class DummyQuery : public SQLQuery
-{
- public:
-       DummyQuery(Module* me) : SQLQuery(me, "") {}
-       void OnResult(SQLResult& result) {}
-};
-
 class ModulePgSQL : public Module
 {
  public:
@@ -558,16 +560,16 @@ class ModulePgSQL : public Module
                for(ConnMap::iterator i = connections.begin(); i != connections.end(); i++)
                {
                        SQLConn* conn = i->second;
-                       if (conn->qinprog && conn->qinprog->creator == mod)
+                       if (conn->qinprog.c && conn->qinprog.c->creator == mod)
                        {
-                               conn->qinprog->OnError(err);
-                               delete conn->qinprog;
-                               conn->qinprog = new DummyQuery(this);
+                               conn->qinprog.c->OnError(err);
+                               delete conn->qinprog.c;
+                               conn->qinprog.c = NULL;
                        }
-                       std::deque<SQLQuery*>::iterator j = conn->queue.begin();
+                       std::deque<QueueItem>::iterator j = conn->queue.begin();
                        while (j != conn->queue.end())
                        {
-                               SQLQuery* q = *j;
+                               SQLQuery* q = j->c;
                                if (q->creator == mod)
                                {
                                        q->OnError(err);
index b13d42bcabb0d517302c54d6f7122fe77471da63..b3bb5a51b1476aaf1afe5e7dad3f7df2a47f1c45 100644 (file)
@@ -88,11 +88,11 @@ class SQLConn : public SQLProvider
                sqlite3_close(conn);
        }
 
-       void Query(SQLQuery* query)
+       void Query(SQLQuery* query, const std::string& q)
        {
                SQLite3Result res;
                sqlite3_stmt *stmt;
-               int err = sqlite3_prepare_v2(conn, query->query.c_str(), query->query.length(), &stmt, NULL);
+               int err = sqlite3_prepare_v2(conn, q.c_str(), q.length(), &stmt, NULL);
                if (err != SQLITE_OK)
                {
                        SQLerror error(SQL_QSEND_FAIL, sqlite3_errmsg(conn));
@@ -136,7 +136,13 @@ class SQLConn : public SQLProvider
                sqlite3_finalize(stmt);
        }
 
-       std::string FormatQuery(const std::string& q, const ParamL& p)
+       virtual void submit(SQLQuery* query, const std::string& q)
+       {
+               Query(query, q);
+               delete query;
+       }
+
+       virtual void submit(SQLQuery* query, const std::string& q, const ParamL& p)
        {
                std::string res;
                unsigned int param = 0;
@@ -146,7 +152,6 @@ class SQLConn : public SQLProvider
                                res.push_back(q[i]);
                        else
                        {
-                               // TODO numbered parameter support ('?1')
                                if (param < p.size())
                                {
                                        char* escaped = sqlite3_mprintf("%q", p[param++].c_str());
@@ -155,10 +160,10 @@ class SQLConn : public SQLProvider
                                }
                        }
                }
-               return res;
+               submit(query, res);
        }
 
-       std::string FormatQuery(const std::string& q, const ParamM& p)
+       virtual void submit(SQLQuery* query, const std::string& q, const ParamM& p)
        {
                std::string res;
                for(std::string::size_type i = 0; i < q.length(); i++)
@@ -182,13 +187,7 @@ class SQLConn : public SQLProvider
                                }
                        }
                }
-               return res;
-       }
-       
-       virtual void submit(SQLQuery* query)
-       {
-               Query(query);
-               delete query;
+               submit(query, res);
        }
 };
 
index 52f0d5d4fea8af52ef70a1711461cded1ecde69e..29554b03176533756c3a062c041c7b6379a1700e 100644 (file)
@@ -29,10 +29,9 @@ class AuthQuery : public SQLQuery
        const std::string uid;
        LocalIntExt& pendingExt;
        bool verbose;
-       AuthQuery(Module* me, const std::string& q, const std::string& u, LocalIntExt& e, bool v)
-               : SQLQuery(me, q), uid(u), pendingExt(e), verbose(v)
+       AuthQuery(Module* me, const std::string& u, LocalIntExt& e, bool v)
+               : SQLQuery(me), uid(u), pendingExt(e), verbose(v)
        {
-               ServerInstance->Logs->Log("m_sqlauth",DEBUG, "SQLAUTH: query=\"%s\"", q.c_str());
        }
        
        void OnResult(SQLResult& res)
@@ -125,7 +124,7 @@ class ModuleSQLAuth : public Module
                if (sha256)
                        userinfo["sha256pass"] = sha256->hexsum(user->password);
 
-               SQL->submit(new AuthQuery(this, SQL->FormatQuery(freeformquery, userinfo), user->uuid, pendingExt, verbose));
+               SQL->submit(new AuthQuery(this, user->uuid, pendingExt, verbose), freeformquery, userinfo);
 
                return MOD_RES_PASSTHRU;
        }
index 307f72f3c034e70194ff67edac07b129c8699ce9..92e402811dd9a3e0ee7d8d376a10fb2c4fab4c52 100644 (file)
@@ -35,10 +35,9 @@ class OpMeQuery : public SQLQuery
 {
  public:
        const std::string uid, username, password;
-       OpMeQuery(Module* me, const std::string& q, const std::string& u, const std::string& un, const std::string& pw)
-               : SQLQuery(me, q), uid(u), username(un), password(pw)
+       OpMeQuery(Module* me, const std::string& u, const std::string& un, const std::string& pw)
+               : SQLQuery(me), uid(u), username(un), password(pw)
        {
-               ServerInstance->Logs->Log("m_sqloper",DEBUG, "SQLOPER: query=\"%s\"", q.c_str());
        }
 
        void OnResult(SQLResult& res)
@@ -173,7 +172,7 @@ public:
                userinfo["username"] = username;
                userinfo["password"] = hash ? hash->hexsum(password) : password;
 
-               SQL->submit(new OpMeQuery(this, SQL->FormatQuery(query, userinfo), user->uuid, username, password));
+               SQL->submit(new OpMeQuery(this, user->uuid, username, password), query, userinfo);
        }
 
        Version GetVersion()
index e558d9f6bdf687d415634d9b7fc5cf0fd1adfce8..0c174e987a5bb5eba41b665ca9987bd27f2a86a2 100644 (file)
@@ -127,10 +127,8 @@ class SQLQuery : public classbase
 {
  public:
        ModuleRef creator;
-       const std::string query;
 
-       SQLQuery(Module* Creator, const std::string& q)
-               : creator(Creator), query(q) {}
+       SQLQuery(Module* Creator) : creator(Creator) {}
        virtual ~SQLQuery() {}
 
        virtual void OnResult(SQLResult& result) = 0;
@@ -148,25 +146,26 @@ class SQLProvider : public DataProvider
  public:
        SQLProvider(Module* Creator, const std::string& Name) : DataProvider(Creator, Name) {}
        /** Submit an asynchronous SQL request
-        * @param dbid The database ID to apply the request to
-        * @param query The query string
-        * @param callback The callback that the result is sent to
+        * @param callback The result reporting point
+        * @param query The hardcoded query string. If you have parameters to substitute, see below.
         */
-       virtual void submit(SQLQuery* query) = 0;
+       virtual void submit(SQLQuery* callback, const std::string& query) = 0;
 
-       /** Format a parameterized query string using proper SQL escaping.
-        * @param q The query string, with '?' parameters
-        * @param p The parameters to fill in in the '?' slots
+       /** Submit an asynchronous SQL request
+        * @param callback The result reporting point
+        * @param format The simple parameterized query string ('?' parameters)
+        * @param p Parameters to fill in for the '?' entries
         */
-       virtual std::string FormatQuery(const std::string& q, const ParamL& p) = 0;
+       virtual void submit(SQLQuery* callback, const std::string& format, const ParamL& p) = 0;
 
-       /** Format a parameterized query string using proper SQL escaping.
-        * @param q The query string, with '$foo' parameters
-        * @param p The map to look up parameters in
+       /** Submit an asynchronous SQL request.
+        * @param callback The result reporting point
+        * @param format The parameterized query string ('$name' parameters)
+        * @param p Parameters to fill in for the '$name' entries
         */
-       virtual std::string FormatQuery(const std::string& q, const ParamM& p) = 0;
+       virtual void submit(SQLQuery* callback, const std::string& format, const ParamM& p) = 0;
 
-       /** Convenience function */
+       /** Convenience function to prepare a map from a User* */
        void PopulateUserInfo(User* user, ParamM& userinfo)
        {
                userinfo["nick"] = user->nick;