]> git.netwichtig.de Git - user/henk/code/inspircd.git/blob - src/modules/extra/m_sqlite3.cpp
6741d77459f8db24479cb3cbb330f4532a328344
[user/henk/code/inspircd.git] / src / modules / extra / m_sqlite3.cpp
1 /*       +------------------------------------+\r *       | Inspire Internet Relay Chat Daemon |\r *       +------------------------------------+\r *\r *  InspIRCd: (C) 2002-2007 InspIRCd Development Team\r * See: http://www.inspircd.org/wiki/index.php/Credits\r *\r * This program is free but copyrighted software; see\r *            the file COPYING for details.\r *\r * ---------------------------------------------------\r */\r\r#include "inspircd.h"\r#include <sqlite3.h>\r#include "users.h"\r#include "channels.h"\r#include "modules.h"\r\r#include "m_sqlv2.h"\r\r/* $ModDesc: sqlite3 provider */\r/* $CompileFlags: pkgconfversion("sqlite3","3.3") pkgconfincludes("sqlite3","/sqlite3.h","") */\r/* $LinkerFlags: pkgconflibs("sqlite3","/libsqlite3.so","-lsqlite3") */\r/* $ModDep: m_sqlv2.h */\r\r\rclass SQLConn;\rclass SQLite3Result;\rclass ResultNotifier;\r\rtypedef std::map<std::string, SQLConn*> ConnMap;\rtypedef std::deque<classbase*> paramlist;\rtypedef std::deque<SQLite3Result*> ResultQueue;\r\rResultNotifier* resultnotify = NULL;\r\r\rclass ResultNotifier : public InspSocket\r{\r       Module* mod;\r   insp_sockaddr sock_us;\r socklen_t uslen;\r\r public:\r     /* Create a socket on a random port. Let the tcp stack allocate us an available port */\r#ifdef IPV6\r    ResultNotifier(InspIRCd* SI, Module* m) : InspSocket(SI, "::1", 0, true, 3000), mod(m)\r#else\r   ResultNotifier(InspIRCd* SI, Module* m) : InspSocket(SI, "127.0.0.1", 0, true, 3000), mod(m)\r#endif\r    {\r              uslen = sizeof(sock_us);\r               if (getsockname(this->fd,(sockaddr*)&sock_us,&uslen))\r          {\r                      throw ModuleException("Could not create random listening port on localhost");\r          }\r      }\r\r     ResultNotifier(InspIRCd* SI, Module* m, int newfd, char* ip) : InspSocket(SI, newfd, ip), mod(m)\r       {\r      }\r\r     /* Using getsockname and ntohs, we can determine which port number we were allocated */\r        int GetPort()\r  {\r#ifdef IPV6\r          return ntohs(sock_us.sin6_port);\r#else\r         return ntohs(sock_us.sin_port);\r#endif\r }\r\r     virtual int OnIncomingConnection(int newsock, char* ip)\r        {\r              Dispatch();\r            return false;\r  }\r\r     void Dispatch();\r};\r\r\rclass SQLite3Result : public SQLresult\r{\r  private:\r      int currentrow;\r        int rows;\r      int cols;\r\r     std::vector<std::string> colnames;\r     std::vector<SQLfieldList> fieldlists;\r  SQLfieldList emptyfieldlist;\r\r  SQLfieldList* fieldlist;\r       SQLfieldMap* fieldmap;\r\r  public:\r      SQLite3Result(Module* self, Module* to, unsigned int id)\r       : SQLresult(self, to, id), currentrow(0), rows(0), cols(0), fieldlist(NULL), fieldmap(NULL)\r    {\r      }\r\r     ~SQLite3Result()\r       {\r      }\r\r     void AddRow(int colsnum, char **data, char **colname)\r  {\r              colnames.clear();\r              cols = colsnum;\r                for (int i = 0; i < colsnum; i++)\r              {\r                      fieldlists.resize(fieldlists.size()+1);\r                        colnames.push_back(colname[i]);\r                        SQLfield sf(data[i] ? data[i] : "", data[i] ? false : true);\r                   fieldlists[rows].push_back(sf);\r                }\r              rows++;\r        }\r\r     void UpdateAffectedCount()\r     {\r              rows++;\r        }\r\r     virtual int Rows()\r     {\r              return rows;\r   }\r\r     virtual int Cols()\r     {\r              return cols;\r   }\r\r     virtual std::string ColName(int column)\r        {\r              if (column < (int)colnames.size())\r             {\r                      return colnames[column];\r               }\r              else\r           {\r                      throw SQLbadColName();\r         }\r              return "";\r     }\r\r     virtual int ColNum(const std::string &column)\r  {\r              for (unsigned int i = 0; i < colnames.size(); i++)\r             {\r                      if (column == colnames[i])\r                             return i;\r              }\r              throw SQLbadColName();\r         return 0;\r      }\r\r     virtual SQLfield GetValue(int row, int column)\r {\r              if ((row >= 0) && (row < rows) && (column >= 0) && (column < Cols()))\r          {\r                      return fieldlists[row][column];\r                }\r\r             throw SQLbadColName();\r\r                /* XXX: We never actually get here because of the throw */\r             return SQLfield("",true);\r      }\r\r     virtual SQLfieldList& GetRow()\r {\r              if (currentrow < rows)\r                 return fieldlists[currentrow];\r         else\r                   return emptyfieldlist;\r }\r\r     virtual SQLfieldMap& GetRowMap()\r       {\r              /* In an effort to reduce overhead we don't actually allocate the map\r           * until the first time it's needed...so...\r             */\r            if(fieldmap)\r           {\r                      fieldmap->clear();\r             }\r              else\r           {\r                      fieldmap = new SQLfieldMap;\r            }\r\r             if (currentrow < rows)\r         {\r                      for (int i = 0; i < Cols(); i++)\r                       {\r                              fieldmap->insert(std::make_pair(ColName(i), GetValue(currentrow, i)));\r                 }\r                      currentrow++;\r          }\r\r             return *fieldmap;\r      }\r\r     virtual SQLfieldList* GetRowPtr()\r      {\r              fieldlist = new SQLfieldList();\r\r               if (currentrow < rows)\r         {\r                      for (int i = 0; i < Rows(); i++)\r                       {\r                              fieldlist->push_back(fieldlists[currentrow][i]);\r                       }\r                      currentrow++;\r          }\r              return fieldlist;\r      }\r\r     virtual SQLfieldMap* GetRowMapPtr()\r    {\r              fieldmap = new SQLfieldMap();\r\r         if (currentrow < rows)\r         {\r                      for (int i = 0; i < Cols(); i++)\r                       {\r                              fieldmap->insert(std::make_pair(colnames[i],GetValue(currentrow, i)));\r                 }\r                      currentrow++;\r          }\r\r             return fieldmap;\r       }\r\r     virtual void Free(SQLfieldMap* fm)\r     {\r              delete fm;\r     }\r\r     virtual void Free(SQLfieldList* fl)\r    {\r              delete fl;\r     }\r\r\r};\r\rclass SQLConn : public classbase\r{\r  private:\r  ResultQueue results;\r   InspIRCd* Instance;\r    Module* mod;\r   SQLhost host;\r  sqlite3* conn;\r\r  public:\r      SQLConn(InspIRCd* SI, Module* m, const SQLhost& hi)\r    : Instance(SI), mod(m), host(hi)\r       {\r              if (OpenDB() != SQLITE_OK)\r             {\r                      Instance->Log(DEFAULT, "WARNING: Could not open DB with id: " + host.id);\r                      CloseDB();\r             }\r      }\r\r     ~SQLConn()\r     {\r              CloseDB();\r     }\r\r     SQLerror Query(SQLrequest &req)\r        {\r              /* Pointer to the buffer we screw around with substitution in */\r               char* query;\r\r          /* Pointer to the current end of query, where we append new stuff */\r           char* queryend;\r\r               /* Total length of the unescaped parameters */\r         unsigned long paramlen;\r\r               /* Total length of query, used for binary-safety in mysql_real_query */\r                unsigned long querylength = 0;\r\r                paramlen = 0;\r          for(ParamL::iterator i = req.query.p.begin(); i != req.query.p.end(); i++)\r             {\r                      paramlen += i->size();\r         }\r\r             /* To avoid a lot of allocations, allocate enough memory for the biggest the escaped query could possibly be.\r           * sizeofquery + (totalparamlength*2) + 1\r               *\r              * The +1 is for null-terminating the string for mysql_real_escape_string\r               */\r            query = new char[req.query.q.length() + (paramlen*2) + 1];\r             queryend = query;\r\r             for(unsigned long i = 0; i < req.query.q.length(); i++)\r                {\r                      if(req.query.q[i] == '?')\r                      {\r                              if(req.query.p.size())\r                         {\r                                      char* escaped;\r                                 escaped = sqlite3_mprintf("%q", req.query.p.front().c_str());\r                                  for (char* n = escaped; *n; n++)\r                                       {\r                                              *queryend = *n;\r                                                queryend++;\r                                    }\r                                      sqlite3_free(escaped);\r                                 req.query.p.pop_front();\r                               }\r                              else\r                                   break;\r                 }\r                      else\r                   {\r                              *queryend = req.query.q[i];\r                            queryend++;\r                    }\r                      querylength++;\r         }\r              *queryend = 0;\r         req.query.q = query;\r\r          SQLite3Result* res = new SQLite3Result(mod, req.GetSource(), req.id);\r          res->dbid = host.id;\r           res->query = req.query.q;\r              paramlist params;\r              params.push_back(this);\r                params.push_back(res);\r\r                char *errmsg = 0;\r              sqlite3_update_hook(conn, QueryUpdateHook, &params);\r           if (sqlite3_exec(conn, req.query.q.data(), QueryResult, &params, &errmsg) != SQLITE_OK)\r                {\r                      std::string error(errmsg);\r                     sqlite3_free(errmsg);\r                  delete[] query;\r                        delete res;\r                    return SQLerror(QSEND_FAIL, error);\r            }\r              delete[] query;\r\r               results.push_back(res);\r                SendNotify();\r          return SQLerror();\r     }\r\r     static int QueryResult(void *params, int argc, char **argv, char **azColName)\r  {\r              paramlist* p = (paramlist*)params;\r             ((SQLConn*)(*p)[0])->ResultReady(((SQLite3Result*)(*p)[1]), argc, argv, azColName);\r            return 0;\r      }\r\r     static void QueryUpdateHook(void *params, int eventid, char const * azSQLite, char const * azColName, sqlite_int64 rowid)\r      {\r              paramlist* p = (paramlist*)params;\r             ((SQLConn*)(*p)[0])->AffectedReady(((SQLite3Result*)(*p)[1]));\r }\r\r     void ResultReady(SQLite3Result *res, int cols, char **data, char **colnames)\r   {\r              res->AddRow(cols, data, colnames);\r     }\r\r     void AffectedReady(SQLite3Result *res)\r {\r              res->UpdateAffectedCount();\r    }\r\r     int OpenDB()\r   {\r              return sqlite3_open(host.host.c_str(), &conn);\r }\r\r     void CloseDB()\r {\r              sqlite3_interrupt(conn);\r               sqlite3_close(conn);\r   }\r\r     SQLhost GetConfHost()\r  {\r              return host;\r   }\r\r     void SendResults()\r     {\r              while (results.size())\r         {\r                      SQLite3Result* res = results[0];\r                       if (res->GetDest())\r                    {\r                              res->Send();\r                   }\r                      else\r                   {\r                              /* If the client module is unloaded partway through a query then the provider will set\r                          * the pointer to NULL. We cannot just cancel the query as the result will still come\r                           * through at some point...and it could get messy if we play with invalid pointers...\r                           */\r                            delete res;\r                    }\r                      results.pop_front();\r           }\r      }\r\r     void ClearResults()\r    {\r              while (results.size())\r         {\r                      SQLite3Result* res = results[0];\r                       delete res;\r                    results.pop_front();\r           }\r      }\r\r     void SendNotify()\r      {\r              int QueueFD;\r           if ((QueueFD = socket(AF_FAMILY, SOCK_STREAM, 0)) == -1)\r               {\r                      /* crap, we're out of sockets... */\r                    return;\r                }\r\r             insp_sockaddr addr;\r\r#ifdef IPV6\r               insp_aton("::1", &addr.sin6_addr);\r             addr.sin6_family = AF_FAMILY;\r          addr.sin6_port = htons(resultnotify->GetPort());\r#else\r         insp_inaddr ia;\r                insp_aton("127.0.0.1", &ia);\r           addr.sin_family = AF_FAMILY;\r           addr.sin_addr = ia;\r            addr.sin_port = htons(resultnotify->GetPort());\r#endif\r\r                if (connect(QueueFD, (sockaddr*)&addr,sizeof(addr)) == -1)\r             {\r                      /* wtf, we cant connect to it, but we just created it! */\r                      return;\r                }\r      }\r\r};\r\r\rclass ModuleSQLite3 : public Module\r{\r  private:\r       ConnMap connections;\r   unsigned long currid;\r\r  public:\r       ModuleSQLite3(InspIRCd* Me)\r    : Module::Module(Me), currid(0)\r        {\r              ServerInstance->UseInterface("SQLutils");\r\r             if (!ServerInstance->PublishFeature("SQL", this))\r              {\r                      throw ModuleException("m_sqlite3: Unable to publish feature 'SQL'");\r           }\r\r             resultnotify = new ResultNotifier(ServerInstance, this);\r\r              ReadConf();\r\r           ServerInstance->PublishInterface("SQL", this);\r }\r\r     virtual ~ModuleSQLite3()\r       {\r              ClearQueue();\r          ClearAllConnections();\r         resultnotify->SetFd(-1);\r               resultnotify->state = I_ERROR;\r         resultnotify->OnError(I_ERR_SOCKET);\r           resultnotify->ClosePending = true;\r             delete resultnotify;\r           ServerInstance->UnpublishInterface("SQL", this);\r               ServerInstance->UnpublishFeature("SQL");\r               ServerInstance->DoneWithInterface("SQLutils");\r }\r\r     void Implements(char* List)\r    {\r              List[I_OnRequest] = List[I_OnRehash] = 1;\r      }\r\r     void SendQueue()\r       {\r              for (ConnMap::iterator iter = connections.begin(); iter != connections.end(); iter++)\r          {\r                      iter->second->SendResults();\r           }\r      }\r\r     void ClearQueue()\r      {\r              for (ConnMap::iterator iter = connections.begin(); iter != connections.end(); iter++)\r          {\r                      iter->second->ClearResults();\r          }\r      }\r\r     bool HasHost(const SQLhost &host)\r      {\r              for (ConnMap::iterator iter = connections.begin(); iter != connections.end(); iter++)\r          {\r                      if (host == iter->second->GetConfHost())\r                               return true;\r           }\r              return false;\r  }\r\r     bool HostInConf(const SQLhost &h)\r      {\r              ConfigReader conf(ServerInstance);\r             for(int i = 0; i < conf.Enumerate("database"); i++)\r            {\r                      SQLhost host;\r                  host.id         = conf.ReadValue("database", "id", i);\r                 host.host       = conf.ReadValue("database", "hostname", i);\r                   host.port       = conf.ReadInteger("database", "port", i, true);\r                       host.name       = conf.ReadValue("database", "name", i);\r                       host.user       = conf.ReadValue("database", "username", i);\r                   host.pass       = conf.ReadValue("database", "password", i);\r                   host.ssl        = conf.ReadFlag("database", "ssl", "0", i);\r                    if (h == host)\r                         return true;\r           }\r              return false;\r  }\r\r     void ReadConf()\r        {\r              ClearOldConnections();\r\r                ConfigReader conf(ServerInstance);\r             for(int i = 0; i < conf.Enumerate("database"); i++)\r            {\r                      SQLhost host;\r\r                 host.id         = conf.ReadValue("database", "id", i);\r                 host.host       = conf.ReadValue("database", "hostname", i);\r                   host.port       = conf.ReadInteger("database", "port", i, true);\r                       host.name       = conf.ReadValue("database", "name", i);\r                       host.user       = conf.ReadValue("database", "username", i);\r                   host.pass       = conf.ReadValue("database", "password", i);\r                   host.ssl        = conf.ReadFlag("database", "ssl", "0", i);\r\r                   if (HasHost(host))\r                             continue;\r\r                     this->AddConn(host);\r           }\r      }\r\r     void AddConn(const SQLhost& hi)\r        {\r              if (HasHost(hi))\r               {\r                      ServerInstance->Log(DEFAULT, "WARNING: A sqlite connection with id: %s already exists. Aborting database open attempt.", hi.id.c_str());\r                       return;\r                }\r\r             SQLConn* newconn;\r\r             newconn = new SQLConn(ServerInstance, this, hi);\r\r              connections.insert(std::make_pair(hi.id, newconn));\r    }\r\r     void ClearOldConnections()\r     {\r              ConnMap::iterator iter,safei;\r          for (iter = connections.begin(); iter != connections.end(); iter++)\r            {\r                      if (!HostInConf(iter->second->GetConfHost()))\r                  {\r                              DELETE(iter->second);\r                          safei = iter;\r                          --iter;\r                                connections.erase(safei);\r                      }\r              }\r      }\r\r     void ClearAllConnections()\r     {\r              ConnMap::iterator i;\r           while ((i = connections.begin()) != connections.end())\r         {\r                      connections.erase(i);\r                  DELETE(i->second);\r             }\r      }\r\r     virtual void OnRehash(userrec* user, const std::string &parameter)\r     {\r              ReadConf();\r    }\r\r     virtual char* OnRequest(Request* request)\r      {\r              if(strcmp(SQLREQID, request->GetId()) == 0)\r            {\r                      SQLrequest* req = (SQLrequest*)request;\r                        ConnMap::iterator iter;\r                        if((iter = connections.find(req->dbid)) != connections.end())\r                  {\r                              req->id = NewID();\r                             req->error = iter->second->Query(*req);\r                                return SQLSUCCESS;\r                     }\r                      else\r                   {\r                              req->error.Id(BAD_DBID);\r                               return NULL;\r                   }\r              }\r              return NULL;\r   }\r\r     unsigned long NewID()\r  {\r              if (currid+1 == 0)\r                     currid++;\r\r             return ++currid;\r       }\r\r     virtual Version GetVersion()\r   {\r              return Version(1,1,0,0,VF_VENDOR|VF_SERVICEPROVIDER,API_VERSION);\r      }\r\r};\r\rvoid ResultNotifier::Dispatch()\r{\r       ((ModuleSQLite3*)mod)->SendQueue();\r}\r\rMODULE_INIT(ModuleSQLite3);\r\r