]> git.netwichtig.de Git - user/henk/code/inspircd.git/blob - src/modules/extra/m_mysql.cpp
Change includes, use --libs_r rather than mysql_config --libs, we want re-enterant...
[user/henk/code/inspircd.git] / src / modules / extra / m_mysql.cpp
1 /*       +------------------------------------+
2  *       | Inspire Internet Relay Chat Daemon |
3  *       +------------------------------------+
4  *
5  *  InspIRCd is copyright (C) 2002-2004 ChatSpike-Dev.
6  *                       E-mail:
7  *                <brain@chatspike.net>
8  *                <Craig@chatspike.net>
9  *     
10  * Written by Craig Edwards, Craig McLure, and others.
11  * This program is free but copyrighted software; see
12  *            the file COPYING for details.
13  *
14  * ---------------------------------------------------
15  */
16
17 using namespace std;
18
19 #include <stdio.h>
20 #include <string>
21 #include <mysql.h>
22 #include <pthread.h>
23 #include "users.h"
24 #include "channels.h"
25 #include "modules.h"
26 #include "helperfuncs.h"
27 #include "m_sqlv2.h"
28
29 /* VERSION 2 API: With nonblocking (threaded) requests */
30
31 /* $ModDesc: SQL Service Provider module for all other m_sql* modules */
32 /* $CompileFlags: -pthread `mysql_config --include` */
33 /* $LinkerFlags: -pthread `mysql_config --libs_r` `perl ../mysql_rpath.pl` */
34
35 /** SQLConnection represents one mysql session.
36  * Each session has its own persistent connection to the database.
37  */
38
/* Compatibility shim: when MYSQL_VERSION_ID is undefined or lower than 32224,
 * every use of mysql_field_count below is rewritten to mysql_num_fields.
 */
#if !defined(MYSQL_VERSION_ID) || MYSQL_VERSION_ID<32224
#define mysql_field_count mysql_num_fields
#endif
42
43 class SQLConnection : public classbase
44 {
45  protected:
46
47         MYSQL connection;
48         MYSQL_RES *res;
49         MYSQL_ROW row;
50         std::string host;
51         std::string user;
52         std::string pass;
53         std::string db;
54         std::map<std::string,std::string> thisrow;
55         bool Enabled;
56         long id;
57
58  public:
59
60         // This constructor creates an SQLConnection object with the given credentials, and creates the underlying
61         // MYSQL struct, but does not connect yet.
62         SQLConnection(std::string thishost, std::string thisuser, std::string thispass, std::string thisdb, long myid)
63         {
64                 this->Enabled = true;
65                 this->host = thishost;
66                 this->user = thisuser;
67                 this->pass = thispass;
68                 this->db = thisdb;
69                 this->id = myid;
70         }
71
72         // This method connects to the database using the credentials supplied to the constructor, and returns
73         // true upon success.
74         bool Connect()
75         {
76                 unsigned int timeout = 1;
77                 mysql_init(&connection);
78                 mysql_options(&connection,MYSQL_OPT_CONNECT_TIMEOUT,(char*)&timeout);
79                 return mysql_real_connect(&connection, host.c_str(), user.c_str(), pass.c_str(), db.c_str(), 0, NULL, 0);
80         }
81
82         // This method issues a query that expects multiple rows of results. Use GetRow() and QueryDone() to retrieve
83         // multiple rows.
84         bool QueryResult(std::string query)
85         {
86                 if (!CheckConnection()) return false;
87                 
88                 int r = mysql_query(&connection, query.c_str());
89                 if (!r)
90                 {
91                         res = mysql_use_result(&connection);
92                 }
93                 return (!r);
94         }
95
96         // This method issues a query that just expects a number of 'effected' rows (e.g. UPDATE or DELETE FROM).
97         // the number of effected rows is returned in the return value.
98         long QueryCount(std::string query)
99         {
100                 /* If the connection is down, we return a negative value - New to 1.1 */
101                 if (!CheckConnection()) return -1;
102
103                 int r = mysql_query(&connection, query.c_str());
104                 if (!r)
105                 {
106                         res = mysql_store_result(&connection);
107                         unsigned long rows = mysql_affected_rows(&connection);
108                         mysql_free_result(res);
109                         return rows;
110                 }
111                 return 0;
112         }
113
114         // This method fetches a row, if available from the database. You must issue a query
115         // using QueryResult() first! The row's values are returned as a map of std::string
116         // where each item is keyed by the column name.
117         std::map<std::string,std::string> GetRow()
118         {
119                 thisrow.clear();
120                 if (res)
121                 {
122                         row = mysql_fetch_row(res);
123                         if (row)
124                         {
125                                 unsigned int field_count = 0;
126                                 MYSQL_FIELD *fields = mysql_fetch_fields(res);
127                                 if(mysql_field_count(&connection) == 0)
128                                         return thisrow;
129                                 if (fields && mysql_field_count(&connection))
130                                 {
131                                         while (field_count < mysql_field_count(&connection))
132                                         {
133                                                 std::string a = (fields[field_count].name ? fields[field_count].name : "");
134                                                 std::string b = (row[field_count] ? row[field_count] : "");
135                                                 thisrow[a] = b;
136                                                 field_count++;
137                                         }
138                                         return thisrow;
139                                 }
140                         }
141                 }
142                 return thisrow;
143         }
144
145         bool QueryDone()
146         {
147                 if (res)
148                 {
149                         mysql_free_result(res);
150                         res = NULL;
151                         return true;
152                 }
153                 else return false;
154         }
155
156         bool ConnectionLost()
157         {
158                 if (&connection) {
159                         return (mysql_ping(&connection) != 0);
160                 }
161                 else return false;
162         }
163
164         bool CheckConnection()
165         {
166                 if (ConnectionLost()) {
167                         return Connect();
168                 }
169                 else return true;
170         }
171
172         std::string GetError()
173         {
174                 return mysql_error(&connection);
175         }
176
177         long GetID()
178         {
179                 return id;
180         }
181
182         std::string GetHost()
183         {
184                 return host;
185         }
186
187         void Enable()
188         {
189                 Enabled = true;
190         }
191
192         void Disable()
193         {
194                 Enabled = false;
195         }
196
197         bool IsEnabled()
198         {
199                 return Enabled;
200         }
201
202 };
203
204 typedef std::vector<SQLConnection> ConnectionList;
205
206 class ModuleSQL : public Module
207 {
208         Server *Srv;
209         ConfigReader *Conf;
210         ConnectionList Connections;
211  
212  public:
213         void ConnectDatabases()
214         {
215                 for (ConnectionList::iterator i = Connections.begin(); i != Connections.end(); i++)
216                 {
217                         i->Enable();
218                         if (i->Connect())
219                         {
220                                 Srv->Log(DEFAULT,"SQL: Successfully connected database "+i->GetHost());
221                         }
222                         else
223                         {
224                                 Srv->Log(DEFAULT,"SQL: Failed to connect database "+i->GetHost()+": Error: "+i->GetError());
225                                 i->Disable();
226                         }
227                 }
228         }
229
230         void LoadDatabases(ConfigReader* ThisConf)
231         {
232                 Srv->Log(DEFAULT,"SQL: Loading database settings");
233                 Connections.clear();
234                 Srv->Log(DEBUG,"Cleared connections");
235                 for (int j =0; j < ThisConf->Enumerate("database"); j++)
236                 {
237                         std::string db = ThisConf->ReadValue("database","name",j);
238                         std::string user = ThisConf->ReadValue("database","username",j);
239                         std::string pass = ThisConf->ReadValue("database","password",j);
240                         std::string host = ThisConf->ReadValue("database","hostname",j);
241                         std::string id = ThisConf->ReadValue("database","id",j);
242                         Srv->Log(DEBUG,"Read database settings");
243                         if ((db != "") && (host != "") && (user != "") && (id != "") && (pass != ""))
244                         {
245                                 SQLConnection ThisSQL(host,user,pass,db,atoi(id.c_str()));
246                                 Srv->Log(DEFAULT,"Loaded database: "+ThisSQL.GetHost());
247                                 Connections.push_back(ThisSQL);
248                                 Srv->Log(DEBUG,"Pushed back connection");
249                         }
250                 }
251                 ConnectDatabases();
252         }
253
254         void ResultType(SQLRequest *r, SQLResult *res)
255         {
256                 for (ConnectionList::iterator i = Connections.begin(); i != Connections.end(); i++)
257                 {
258                         if ((i->GetID() == r->GetConnID()) && (i->IsEnabled()))
259                         {
260                                 bool xr = i->QueryResult(r->GetQuery());
261                                 if (!xr)
262                                 {
263                                         res->SetType(SQL_ERROR);
264                                         res->SetError(i->GetError());
265                                         return;
266                                 }
267                                 res->SetType(SQL_OK);
268                                 return;
269                         }
270                 }
271         }
272
273         void CountType(SQLRequest *r, SQLResult* res)
274         {
275                 for (ConnectionList::iterator i = Connections.begin(); i != Connections.end(); i++)
276                 {
277                         if ((i->GetID() == r->GetConnID()) && (i->IsEnabled()))
278                         {
279                                 res->SetType(SQL_COUNT);
280                                 res->SetCount(i->QueryCount(r->GetQuery()));
281                                 return;
282                         }
283                 }
284         }
285
286         void DoneType(SQLRequest *r, SQLResult* res)
287         {
288                 for (ConnectionList::iterator i = Connections.begin(); i != Connections.end(); i++)
289                 {
290                         if ((i->GetID() == r->GetConnID()) && (i->IsEnabled()))
291                         {
292                                 res->SetType(SQL_DONE);
293                                 if (!i->QueryDone())
294                                         res->SetType(SQL_ERROR);
295                         }
296                 }
297         }
298
299         void RowType(SQLRequest *r, SQLResult* res)
300         {
301                 for (ConnectionList::iterator i = Connections.begin(); i != Connections.end(); i++)
302                 {
303                         if ((i->GetID() == r->GetConnID()) && (i->IsEnabled()))
304                         {
305                                 log(DEBUG,"*** FOUND MATCHING ROW");
306                                 std::map<std::string,std::string> row = i->GetRow();
307                                 res->SetRow(row);
308                                 res->SetType(SQL_ROW);
309                                 if (!row.size())
310                                 {
311                                         log(DEBUG,"ROW SIZE IS 0");
312                                         res->SetType(SQL_END);
313                                 }
314                                 return;
315                         }
316                 }
317         }
318
319         void Implements(char* List)
320         {
321                 List[I_OnRehash] = List[I_OnRequest] = 1;
322         }
323
324         char* OnRequest(Request* request)
325         {
326                 if (request)
327                 {
328                         SQLResult* Result = new SQLResult();
329                         SQLRequest *r = (SQLRequest*)request->GetData();
330                         switch (r->GetQueryType())
331                         {
332                                 case SQL_RESULT:
333                                         ResultType(r,Result);
334                                 break;
335                                 case SQL_COUNT:
336                                         CountType(r,Result);
337                                 break;
338                                 case SQL_ROW:
339                                         RowType(r,Result);
340                                 break;
341                                 case SQL_DONE:
342                                         DoneType(r,Result);
343                                 break;
344                         }
345                         return (char*)Result;
346                 }
347                 return NULL;
348         }
349
350         ModuleSQL(Server* Me)
351                 : Module::Module(Me)
352         {
353                 Srv = Me;
354                 Conf = new ConfigReader();
355                 LoadDatabases(Conf);
356         }
357         
358         virtual ~ModuleSQL()
359         {
360                 Connections.clear();
361                 DELETE(Conf);
362         }
363         
364         virtual void OnRehash(const std::string &parameter)
365         {
366                 DELETE(Conf);
367                 Conf = new ConfigReader();
368                 LoadDatabases(Conf);
369         }
370         
371         virtual Version GetVersion()
372         {
373                 return Version(1,0,0,0,VF_VENDOR|VF_SERVICEPROVIDER);
374         }
375         
376 };
377
378 // stuff down here is the module-factory stuff. For basic modules you can ignore this.
379
380 class ModuleSQLFactory : public ModuleFactory
381 {
382  public:
383         ModuleSQLFactory()
384         {
385         }
386         
387         ~ModuleSQLFactory()
388         {
389         }
390         
391         virtual Module * CreateModule(Server* Me)
392         {
393                 return new ModuleSQL(Me);
394         }
395         
396 };
397
398
399 extern "C" void * init_module( void )
400 {
401         return new ModuleSQLFactory;
402 }
403