X-Git-Url: https://git.netwichtig.de/gitweb/?a=blobdiff_plain;f=src%2Fmodules%2Fextra%2Fm_pgsql.cpp;h=51217bb715d5a95ef59e644d42c249721c8a7b44;hb=48262da087538c38b91bf3a1a51ffaa5e61e502f;hp=c7feb2a2217d1aeb8aea84bf547196900bb6f578;hpb=95ac8e2fd305798bdaa6d0e1720fb36e3f954b18;p=user%2Fhenk%2Fcode%2Finspircd.git diff --git a/src/modules/extra/m_pgsql.cpp b/src/modules/extra/m_pgsql.cpp index c7feb2a22..51217bb71 100644 --- a/src/modules/extra/m_pgsql.cpp +++ b/src/modules/extra/m_pgsql.cpp @@ -15,6 +15,7 @@ * --------------------------------------------------- */ +#include #include #include #include @@ -43,7 +44,7 @@ * I can access the socket engine :\ */ extern InspIRCd* ServerInstance; -InspSocket* socket_ref[MAX_DESCRIPTORS]; +extern time_t TIME; /* Forward declare, so we can have the typedef neatly at the top */ class SQLConn; @@ -55,9 +56,61 @@ typedef std::map ConnMap; /* CREAD, Connecting and wants read event * CWRITE, Connecting and wants write event * WREAD, Connected/Working and wants read event - * WWRITE, Connected/Working and wants write event + * WWRITE, Connected/Working and wants write event + * RREAD, Resetting and wants read event + * RWRITE, Resetting and wants write event */ -enum SQLstatus { CREAD, CWRITE, WREAD, WWRITE }; +enum SQLstatus { CREAD, CWRITE, WREAD, WWRITE, RREAD, RWRITE }; + +/** SQLhost, simple structure to store information about a SQL-connection-to-be + * We use this struct simply to make it neater passing around host information + * when we're creating connections and resolving hosts. + * Rather than giving SQLresolver a parameter for every field here so it can in + * turn call SQLConn's constructor with them all, both can simply use a SQLhost. + */ +class SQLhost +{ + public: + std::string id; /* Database handle id */ + std::string host; /* Database server hostname */ + unsigned int port; /* Database server port */ + std::string name; /* Database name */ + std::string user; /* Database username */ + std::string pass; /* Database password */ + bool ssl; /* If we should require SSL */ + + SQLhost() + { + } + + SQLhost(const std::string& i, const std::string& h, unsigned int p, const std::string& n, const std::string& u, const std::string& pa, bool s) + : id(i), host(h), port(p), name(n), user(u), pass(pa), ssl(s) + { + } +}; + +class SQLresolver : public Resolver +{ + private: + SQLhost host; + ModulePgSQL* mod; + public: + SQLresolver(ModulePgSQL* m, Server* srv, const SQLhost& hi) + : Resolver(ServerInstance, hi.host, DNS_QUERY_FORWARD), host(hi), mod(m) + { + } + + virtual void OnLookupComplete(const std::string &result); + + virtual void OnError(ResolverError e, const std::string &errormessage) + { + log(DEBUG, "DNS lookup failed (%s), dying horribly", errormessage.c_str()); + } + + virtual ~SQLresolver() + { + } +}; /** QueryQueue, a queue of queries waiting to be executed. * This maintains two queues internally, one for 'priority' @@ -205,27 +258,43 @@ class PgSQLresult : public SQLresult { PGresult* res; int currentrow; + int rows; + int cols; SQLfieldList* fieldlist; SQLfieldMap* fieldmap; public: - PgSQLresult(Module* self, Module* to, PGresult* result) - : SQLresult(self, to), res(result), currentrow(0), fieldlist(NULL), fieldmap(NULL) + PgSQLresult(Module* self, Module* to, unsigned long id, PGresult* result) + : SQLresult(self, to, id), res(result), currentrow(0), fieldlist(NULL), fieldmap(NULL) { - int rows = PQntuples(res); - int cols = PQnfields(res); + rows = PQntuples(res); + cols = PQnfields(res); - log(DEBUG, "Created new PgSQL result; %d rows, %d columns", rows, cols); + log(DEBUG, "Created new PgSQL result; %d rows, %d columns, %s affected", rows, cols, PQcmdTuples(res)); } ~PgSQLresult() { + /* If we allocated these, free them... */ + if(fieldlist) + DELETE(fieldlist); + + if(fieldmap) + DELETE(fieldmap); + PQclear(res); } virtual int Rows() { - return PQntuples(res); + if(!cols && !rows) + { + return atoi(PQcmdTuples(res)); + } + else + { + return rows; + } } virtual int Cols() @@ -398,17 +467,16 @@ private: SQLstatus status; /* PgSQL database connection status */ bool qinprog;/* If there is currently a query in progress */ QueryQueue queue; /* Queue of queries waiting to be executed on this connection */ + time_t idle; /* Time we last heard from the database */ public: /* This class should only ever be created inside this module, using this constructor, so we don't have to worry about the default ones */ - SQLConn(ModulePgSQL* self, Server* srv, const std::string &h, unsigned int p, const std::string &d, const std::string &u, const std::string &pwd, bool s); + SQLConn(InspIRCd* SI, ModulePgSQL* self, Server* srv, const SQLhost& hostinfo); ~SQLConn(); - bool DoResolve(); - bool DoConnect(); virtual void Close(); @@ -416,6 +484,8 @@ public: bool DoPoll(); bool DoConnectedPoll(); + + bool DoResetPoll(); void ShowStatus(); @@ -427,11 +497,13 @@ public: bool DoEvent(); + bool Reconnect(); + std::string MkInfoStr(); const char* StatusStr(); - SQLerror DoQuery(const SQLrequest &req); + SQLerror DoQuery(SQLrequest &req); SQLerror Query(const SQLrequest &req); @@ -451,7 +523,6 @@ public: : Module::Module(Me), Srv(Me), currid(0) { log(DEBUG, "%s 'SQL' feature", Srv->PublishFeature("SQL", this) ? "Published" : "Couldn't publish"); - log(DEBUG, "%s 'PgSQL' feature", Srv->PublishFeature("PgSQL", this) ? "Published" : "Couldn't publish"); sqlsuccess = new char[strlen(SQLSUCCESS)+1]; @@ -483,25 +554,55 @@ public: for(int i = 0; i < conf.Enumerate("database"); i++) { - std::string id; - SQLConn* newconn; + SQLhost host; + int ipvalid; + insp_inaddr blargle; + + host.id = conf.ReadValue("database", "id", i); + host.host = conf.ReadValue("database", "hostname", i); + host.port = conf.ReadInteger("database", "port", i, true); + host.name = conf.ReadValue("database", "name", i); + host.user = conf.ReadValue("database", "username", i); + host.pass = conf.ReadValue("database", "password", i); + host.ssl = conf.ReadFlag("database", "ssl", i); - id = conf.ReadValue("database", "id", i); - newconn = new SQLConn(this, Srv, - conf.ReadValue("database", "hostname", i), - conf.ReadInteger("database", "port", i, true), - conf.ReadValue("database", "name", i), - conf.ReadValue("database", "username", i), - conf.ReadValue("database", "password", i), - conf.ReadFlag("database", "ssl", i)); + ipvalid = insp_aton(host.host.c_str(), &blargle); - connections.insert(std::make_pair(id, newconn)); + if(ipvalid > 0) + { + /* The conversion succeeded, we were given an IP and we can give it straight to SQLConn */ + this->AddConn(host); + } + else if(ipvalid == 0) + { + /* Conversion failed, assume it's a host */ + SQLresolver* resolver; + + resolver = new SQLresolver(this, Srv, host); + + Srv->AddResolver(resolver); + } + else + { + /* Invalid address family, die horribly. */ + log(DEBUG, "insp_aton failed returning -1, oh noes."); + } } } + void AddConn(const SQLhost& hi) + { + SQLConn* newconn; + + /* The conversion succeeded, we were given an IP and we can give it straight to SQLConn */ + newconn = new SQLConn(ServerInstance, this, Srv, hi); + + connections.insert(std::make_pair(hi.id, newconn)); + } + virtual char* OnRequest(Request* request) { - if(strcmp(SQLREQID, request->GetData()) == 0) + if(strcmp(SQLREQID, request->GetId()) == 0) { SQLrequest* req = (SQLrequest*)request; ConnMap::iterator iter; @@ -511,8 +612,8 @@ public: if((iter = connections.find(req->dbid)) != connections.end()) { /* Execute query */ - req->error = iter->second->Query(*req); req->id = NewID(); + req->error = iter->second->Query(*req); return (req->error.Id() == NO_ERROR) ? sqlsuccess : NULL; } @@ -523,7 +624,7 @@ public: } } - log(DEBUG, "Got unsupported API version string: %s", request->GetData()); + log(DEBUG, "Got unsupported API version string: %s", request->GetId()); return NULL; } @@ -538,7 +639,7 @@ public: */ for(ConnMap::iterator iter = connections.begin(); iter != connections.end(); iter++) { - + iter->second->OnUnloadModule(mod); } } @@ -561,8 +662,8 @@ public: } }; -SQLConn::SQLConn(ModulePgSQL* self, Server* srv, const std::string &h, unsigned int p, const std::string &d, const std::string &u, const std::string &pwd, bool s) -: InspSocket::InspSocket(), us(self), Srv(srv), dbhost(h), dbport(p), dbname(d), dbuser(u), dbpass(pwd), ssl(s), sql(NULL), status(CWRITE), qinprog(false) +SQLConn::SQLConn(InspIRCd* SI, ModulePgSQL* self, Server* srv, const SQLhost& hi) +: InspSocket::InspSocket(SI), us(self), Srv(srv), dbhost(hi.host), dbport(hi.port), dbname(hi.name), dbuser(hi.user), dbpass(hi.pass), ssl(hi.ssl), sql(NULL), status(CWRITE), qinprog(false) { log(DEBUG, "Creating new PgSQL connection to database %s on %s:%u (%s/%s)", dbname.c_str(), dbhost.c_str(), dbport, dbuser.c_str(), dbpass.c_str()); @@ -570,36 +671,18 @@ SQLConn::SQLConn(ModulePgSQL* self, Server* srv, const std::string &h, unsigned * just copied this over from the InspSocket constructor. */ strlcpy(this->host, dbhost.c_str(), MAXBUF); + strlcpy(this->IP, dbhost.c_str(), MAXBUF); this->port = dbport; + idle = TIME; this->ClosePending = false; + + log(DEBUG,"No need to resolve %s", this->host); - if(!inet_aton(this->host, &this->addy)) - { - /* Its not an ip, spawn the resolver. - * PgSQL doesn't do nonblocking DNS - * lookups, so we do it for it. - */ - - log(DEBUG,"Attempting to resolve %s", this->host); - this->dns.SetNS(Srv->GetConfig()->DNSServer); - this->dns.ForwardLookupWithFD(this->host, fd); - - this->state = I_RESOLVING; - socket_ref[this->fd] = this; - - return; - } - else + if(!this->DoConnect()) { - log(DEBUG,"No need to resolve %s", this->host); - strlcpy(this->IP, this->host, MAXBUF); - - if(!this->DoConnect()) - { - throw ModuleException("Connect failed"); - } + throw ModuleException("Connect failed"); } } @@ -608,39 +691,6 @@ SQLConn::~SQLConn() Close(); } -bool SQLConn::DoResolve() -{ - log(DEBUG, "Checking for DNS lookup result"); - - if(this->dns.HasResult()) - { - std::string res_ip = dns.GetResultIP(); - - if(res_ip.length()) - { - log(DEBUG, "Got result: %s", res_ip.c_str()); - - strlcpy(this->IP, res_ip.c_str(), MAXBUF); - dbhost = res_ip; - - socket_ref[this->fd] = NULL; - - return this->DoConnect(); - } - else - { - log(DEBUG, "DNS lookup failed, dying horribly"); - Close(); - return false; - } - } - else - { - log(DEBUG, "No result for lookup yet!"); - return true; - } -} - bool SQLConn::DoConnect() { log(DEBUG, "SQLConn::DoConnect()"); @@ -684,8 +734,13 @@ bool SQLConn::DoConnect() } this->state = I_CONNECTING; - ServerInstance->SE->AddFd(this->fd,false,X_ESTAB_MODULE); - socket_ref[this->fd] = this; + if (!ServerInstance->SE->AddFd(this->fd,false,X_ESTAB_MODULE)) + { + log(DEBUG, "A PQsocket cant be added to the socket engine!"); + Close(); + return false; + } + Instance->socket_ref[this->fd] = this; /* Socket all hooked into the engine, now to tell PgSQL to start connecting */ @@ -697,7 +752,7 @@ void SQLConn::Close() log(DEBUG,"SQLConn::Close"); if(this->fd > 01) - socket_ref[this->fd] = NULL; + Instance->socket_ref[this->fd] = NULL; this->fd = -1; this->state = I_ERROR; this->OnError(I_ERR_SOCKET); @@ -724,7 +779,7 @@ bool SQLConn::DoPoll() case PGRES_POLLING_READING: log(DEBUG, "PGconnectPoll: PGRES_POLLING_READING"); status = CREAD; - break; + return true; case PGRES_POLLING_FAILED: log(DEBUG, "PGconnectPoll: PGRES_POLLING_FAILED: %s", PQerrorMessage(sql)); return false; @@ -734,10 +789,8 @@ bool SQLConn::DoPoll() return DoConnectedPoll(); default: log(DEBUG, "PGconnectPoll: wtf?"); - break; + return true; } - - return true; } bool SQLConn::DoConnectedPoll() @@ -752,6 +805,11 @@ bool SQLConn::DoConnectedPoll() if(PQconsumeInput(sql)) { log(DEBUG, "PQconsumeInput succeeded"); + + /* We just read stuff from the server, that counts as it being alive + * so update the idle-since time :p + */ + idle = TIME; if(PQisBusy(sql)) { @@ -764,6 +822,8 @@ bool SQLConn::DoConnectedPoll() /* Grab the request we're processing */ SQLrequest& query = queue.front(); + log(DEBUG, "ID is %lu", query.id); + /* Get a pointer to the module we're about to return the result to */ Module* to = query.GetSource(); @@ -785,9 +845,20 @@ bool SQLConn::DoConnectedPoll() if(to) { /* ..and the result */ - log(DEBUG, "Got result, status code: %s; error message: %s", PQresStatus(PQresultStatus(result)), PQresultErrorMessage(result)); - - PgSQLresult reply(us, to, result); + PgSQLresult reply(us, to, query.id, result); + + log(DEBUG, "Got result, status code: %s; error message: %s", PQresStatus(PQresultStatus(result)), PQresultErrorMessage(result)); + + switch(PQresultStatus(result)) + { + case PGRES_EMPTY_QUERY: + case PGRES_BAD_RESPONSE: + case PGRES_FATAL_ERROR: + reply.error.Id(QREPLY_FAIL); + reply.error.Str(PQresultErrorMessage(result)); + default:; + /* No action, other values are not errors */ + } reply.Send(); @@ -807,12 +878,50 @@ bool SQLConn::DoConnectedPoll() queue.pop(); DoConnectedPoll(); } + else + { + log(DEBUG, "Eh!? We just got a read event, and connection isn't busy..but no result :("); + } return true; } - - log(DEBUG, "PQconsumeInput failed: %s", PQerrorMessage(sql)); - return false; + else + { + /* I think we'll assume this means the server died...it might not, + * but I think that any error serious enough we actually get here + * deserves to reconnect [/excuse] + * Returning true so the core doesn't try and close the connection. + */ + log(DEBUG, "PQconsumeInput failed: %s", PQerrorMessage(sql)); + Reconnect(); + return true; + } +} + +bool SQLConn::DoResetPoll() +{ + switch(PQresetPoll(sql)) + { + case PGRES_POLLING_WRITING: + log(DEBUG, "PGresetPoll: PGRES_POLLING_WRITING"); + WantWrite(); + status = CWRITE; + return DoPoll(); + case PGRES_POLLING_READING: + log(DEBUG, "PGresetPoll: PGRES_POLLING_READING"); + status = CREAD; + return true; + case PGRES_POLLING_FAILED: + log(DEBUG, "PGresetPoll: PGRES_POLLING_FAILED: %s", PQerrorMessage(sql)); + return false; + case PGRES_POLLING_OK: + log(DEBUG, "PGresetPoll: PGRES_POLLING_OK"); + status = WWRITE; + return DoConnectedPoll(); + default: + log(DEBUG, "PGresetPoll: wtf?"); + return true; + } } void SQLConn::ShowStatus() @@ -871,6 +980,26 @@ bool SQLConn::OnConnected() return DoEvent(); } +bool SQLConn::Reconnect() +{ + log(DEBUG, "Initiating reconnect"); + + if(PQresetStart(sql)) + { + /* Successfully initiatied database reconnect, + * set flags so PQresetPoll() will be called appropriately + */ + status = RWRITE; + qinprog = false; + return true; + } + else + { + log(DEBUG, "Failed to initiate reconnect...fun"); + return false; + } +} + bool SQLConn::DoEvent() { bool ret; @@ -879,6 +1008,10 @@ bool SQLConn::DoEvent() { ret = DoPoll(); } + else if((status == RREAD) || (status == RWRITE)) + { + ret = DoResetPoll(); + } else { ret = DoConnectedPoll(); @@ -935,7 +1068,7 @@ const char* SQLConn::StatusStr() return "Err...what, erm..BUG!"; } -SQLerror SQLConn::DoQuery(const SQLrequest &req) +SQLerror SQLConn::DoQuery(SQLrequest &req) { if((status == WREAD) || (status == WWRITE)) { @@ -943,45 +1076,101 @@ SQLerror SQLConn::DoQuery(const SQLrequest &req) { /* Parse the command string and dispatch it */ - /* A list of offsets into the original string of the '?' characters we're substituting */ - std::vector insertlocs; + /* Pointer to the buffer we screw around with substitution in */ + char* query; + /* Pointer to the current end of query, where we append new stuff */ + char* queryend; + /* Total length of the unescaped parameters */ + unsigned int paramlen; + + paramlen = 0; + + for(ParamL::iterator i = req.query.p.begin(); i != req.query.p.end(); i++) + { + paramlen += i->size(); + } + + /* To avoid a lot of allocations, allocate enough memory for the biggest the escaped query could possibly be. + * sizeofquery + (totalparamlength*2) + 1 + * + * The +1 is for null-terminating the string for PQsendQuery() + */ + + query = new char[req.query.q.length() + (paramlen*2)]; + queryend = query; + + /* Okay, now we have a buffer large enough we need to start copying the query into it and escaping and substituting + * the parameters into it... + */ for(unsigned int i = 0; i < req.query.q.length(); i++) { if(req.query.q[i] == '?') { - insertlocs.push_back(i); - } - } - - char* query = new char[(req.query.q.length()*2)+1]; - int error = 0; + /* We found a place to substitute..what fun. + * Use the PgSQL calls to escape and write the + * escaped string onto the end of our query buffer, + * then we "just" need to make sure queryend is + * pointing at the right place. + */ + + if(req.query.p.size()) + { + int error = 0; + size_t len = 0; #ifdef PGSQL_HAS_ESCAPECONN - PQescapeStringConn(sql, query, req.query.q.c_str(), req.query.q.length(), &error); + len = PQescapeStringConn(sql, queryend, req.query.p.front().c_str(), req.query.p.front().length(), &error); #else - PQescapeString(query, req.query.q.c_str(), req.query.q.length()); - error = 0; + len = PQescapeStringConn(queryend, req.query.p.front().c_str(), req.query.p.front().length()); + error = 0; #endif - - if(error == 0) - { - if(PQsendQuery(sql, query)) - { - log(DEBUG, "Dispatched query: %s", query); - qinprog = true; - return SQLerror(); + + if(error) + { + log(DEBUG, "Apparently PQescapeStringConn() failed somehow...don't know how or what to do..."); + } + + log(DEBUG, "Appended %d bytes of escaped string onto the query", len); + + /* Incremenet queryend to the end of the newly escaped parameter */ + queryend += len; + + /* Remove the parameter we just substituted in */ + req.query.p.pop_front(); + } + else + { + log(DEBUG, "Found a substitution location but no parameter to substitute :|"); + break; + } } else { - log(DEBUG, "Failed to dispatch query: %s", PQerrorMessage(sql)); - return SQLerror(QSEND_FAIL, PQerrorMessage(sql)); + *queryend = req.query.q[i]; + queryend++; } } + + /* Null-terminate the query */ + *queryend = 0; + + log(DEBUG, "Attempting to dispatch query: %s", query); + + req.query.q = query; + + if(PQsendQuery(sql, query)) + { + log(DEBUG, "Dispatched query successfully"); + qinprog = true; + DELETE(query); + return SQLerror(); + } else { - log(DEBUG, "Failed to escape query string"); - return SQLerror(QSEND_FAIL, "Couldn't escape query string"); + log(DEBUG, "Failed to dispatch query: %s", PQerrorMessage(sql)); + DELETE(query); + return SQLerror(QSEND_FAIL, PQerrorMessage(sql)); } } } @@ -1011,6 +1200,12 @@ void SQLConn::OnUnloadModule(Module* mod) queue.PurgeModule(mod); } +void SQLresolver::OnLookupComplete(const std::string &result) +{ + host.host = result; + mod->AddConn(host); +} + class ModulePgSQLFactory : public ModuleFactory { public: