]> git.netwichtig.de Git - user/henk/code/inspircd.git/blob - src/modules/extra/m_mysql.cpp
654d841918284aebb3f70b6ba8c8496d6df2bb1f
[user/henk/code/inspircd.git] / src / modules / extra / m_mysql.cpp
1 /*       +------------------------------------+
2  *       | Inspire Internet Relay Chat Daemon |
3  *       +------------------------------------+
4  *
5  *  InspIRCd is copyright (C) 2002-2004 ChatSpike-Dev.
6  *                       E-mail:
7  *                <brain@chatspike.net>
8  *                <Craig@chatspike.net>
9  *     
10  * Written by Craig Edwards, Craig McLure, and others.
11  * This program is free but copyrighted software; see
12  *            the file COPYING for details.
13  *
14  * ---------------------------------------------------
15  */
16
17 using namespace std;
18
19 #include <stdio.h>
20 #include <string>
21 #include <mysql.h>
22 #include "users.h"
23 #include "channels.h"
24 #include "modules.h"
25 #include "helperfuncs.h"
26 #include "m_sql.h"
27
28 /* $ModDesc: SQL Service Provider module for all other m_sql* modules */
29 /* $CompileFlags: `mysql_config --include` */
30 /* $LinkerFlags: `mysql_config --libs` `perl ../mysql_rpath.pl` */
31
32 /** SQLConnection represents one mysql session.
33  * Each session has its own persistent connection to the database.
34  */
35
36 #if !defined(MYSQL_VERSION_ID) || MYSQL_VERSION_ID<32224
37 #define mysql_field_count mysql_num_fields
38 #endif
39
40 class SQLConnection : public classbase
41 {
42  protected:
43
44         MYSQL connection;
45         MYSQL_RES *res;
46         MYSQL_ROW row;
47         std::string host;
48         std::string user;
49         std::string pass;
50         std::string db;
51         std::map<std::string,std::string> thisrow;
52         bool Enabled;
53         long id;
54
55  public:
56
57         // This constructor creates an SQLConnection object with the given credentials, and creates the underlying
58         // MYSQL struct, but does not connect yet.
59         SQLConnection(std::string thishost, std::string thisuser, std::string thispass, std::string thisdb, long myid)
60         {
61                 this->Enabled = true;
62                 this->host = thishost;
63                 this->user = thisuser;
64                 this->pass = thispass;
65                 this->db = thisdb;
66                 this->id = myid;
67         }
68
69         // This method connects to the database using the credentials supplied to the constructor, and returns
70         // true upon success.
71         bool Connect()
72         {
73                 unsigned int timeout = 1;
74                 mysql_init(&connection);
75                 mysql_options(&connection,MYSQL_OPT_CONNECT_TIMEOUT,(char*)&timeout);
76                 return mysql_real_connect(&connection, host.c_str(), user.c_str(), pass.c_str(), db.c_str(), 0, NULL, 0);
77         }
78
79         // This method issues a query that expects multiple rows of results. Use GetRow() and QueryDone() to retrieve
80         // multiple rows.
81         bool QueryResult(std::string query)
82         {
83                 if (!CheckConnection()) return false;
84                 
85                 int r = mysql_query(&connection, query.c_str());
86                 if (!r)
87                 {
88                         res = mysql_use_result(&connection);
89                 }
90                 return (!r);
91         }
92
93         // This method issues a query that just expects a number of 'effected' rows (e.g. UPDATE or DELETE FROM).
94         // the number of effected rows is returned in the return value.
95         long QueryCount(std::string query)
96         {
97                 /* If the connection is down, we return a negative value - New to 1.1 */
98                 if (!CheckConnection()) return -1;
99
100                 int r = mysql_query(&connection, query.c_str());
101                 if (!r)
102                 {
103                         res = mysql_store_result(&connection);
104                         unsigned long rows = mysql_affected_rows(&connection);
105                         mysql_free_result(res);
106                         return rows;
107                 }
108                 return 0;
109         }
110
111         // This method fetches a row, if available from the database. You must issue a query
112         // using QueryResult() first! The row's values are returned as a map of std::string
113         // where each item is keyed by the column name.
114         std::map<std::string,std::string> GetRow()
115         {
116                 thisrow.clear();
117                 if (res)
118                 {
119                         row = mysql_fetch_row(res);
120                         if (row)
121                         {
122                                 unsigned int field_count = 0;
123                                 MYSQL_FIELD *fields = mysql_fetch_fields(res);
124                                 if(mysql_field_count(&connection) == 0)
125                                         return thisrow;
126                                 if (fields && mysql_field_count(&connection))
127                                 {
128                                         while (field_count < mysql_field_count(&connection))
129                                         {
130                                                 std::string a = (fields[field_count].name ? fields[field_count].name : "");
131                                                 std::string b = (row[field_count] ? row[field_count] : "");
132                                                 thisrow[a] = b;
133                                                 field_count++;
134                                         }
135                                         return thisrow;
136                                 }
137                         }
138                 }
139                 return thisrow;
140         }
141
142         bool QueryDone()
143         {
144                 if (res)
145                 {
146                         mysql_free_result(res);
147                         res = NULL;
148                         return true;
149                 }
150                 else return false;
151         }
152
153         bool ConnectionLost()
154         {
155                 if (&connection) {
156                         return (mysql_ping(&connection) != 0);
157                 }
158                 else return false;
159         }
160
161         bool CheckConnection()
162         {
163                 if (ConnectionLost()) {
164                         return Connect();
165                 }
166                 else return true;
167         }
168
169         std::string GetError()
170         {
171                 return mysql_error(&connection);
172         }
173
174         long GetID()
175         {
176                 return id;
177         }
178
179         std::string GetHost()
180         {
181                 return host;
182         }
183
184         void Enable()
185         {
186                 Enabled = true;
187         }
188
189         void Disable()
190         {
191                 Enabled = false;
192         }
193
194         bool IsEnabled()
195         {
196                 return Enabled;
197         }
198
199 };
200
201 typedef std::vector<SQLConnection> ConnectionList;
202
203 class ModuleSQL : public Module
204 {
205         Server *Srv;
206         ConfigReader *Conf;
207         ConnectionList Connections;
208  
209  public:
210         void ConnectDatabases()
211         {
212                 for (ConnectionList::iterator i = Connections.begin(); i != Connections.end(); i++)
213                 {
214                         i->Enable();
215                         if (i->Connect())
216                         {
217                                 Srv->Log(DEFAULT,"SQL: Successfully connected database "+i->GetHost());
218                         }
219                         else
220                         {
221                                 Srv->Log(DEFAULT,"SQL: Failed to connect database "+i->GetHost()+": Error: "+i->GetError());
222                                 i->Disable();
223                         }
224                 }
225         }
226
227         void LoadDatabases(ConfigReader* ThisConf)
228         {
229                 Srv->Log(DEFAULT,"SQL: Loading database settings");
230                 Connections.clear();
231                 Srv->Log(DEBUG,"Cleared connections");
232                 for (int j =0; j < ThisConf->Enumerate("database"); j++)
233                 {
234                         std::string db = ThisConf->ReadValue("database","name",j);
235                         std::string user = ThisConf->ReadValue("database","username",j);
236                         std::string pass = ThisConf->ReadValue("database","password",j);
237                         std::string host = ThisConf->ReadValue("database","hostname",j);
238                         std::string id = ThisConf->ReadValue("database","id",j);
239                         Srv->Log(DEBUG,"Read database settings");
240                         if ((db != "") && (host != "") && (user != "") && (id != "") && (pass != ""))
241                         {
242                                 SQLConnection ThisSQL(host,user,pass,db,atoi(id.c_str()));
243                                 Srv->Log(DEFAULT,"Loaded database: "+ThisSQL.GetHost());
244                                 Connections.push_back(ThisSQL);
245                                 Srv->Log(DEBUG,"Pushed back connection");
246                         }
247                 }
248                 ConnectDatabases();
249         }
250
251         void ResultType(SQLRequest *r, SQLResult *res)
252         {
253                 for (ConnectionList::iterator i = Connections.begin(); i != Connections.end(); i++)
254                 {
255                         if ((i->GetID() == r->GetConnID()) && (i->IsEnabled()))
256                         {
257                                 bool xr = i->QueryResult(r->GetQuery());
258                                 if (!xr)
259                                 {
260                                         res->SetType(SQL_ERROR);
261                                         res->SetError(i->GetError());
262                                         return;
263                                 }
264                                 res->SetType(SQL_OK);
265                                 return;
266                         }
267                 }
268         }
269
270         void CountType(SQLRequest *r, SQLResult* res)
271         {
272                 for (ConnectionList::iterator i = Connections.begin(); i != Connections.end(); i++)
273                 {
274                         if ((i->GetID() == r->GetConnID()) && (i->IsEnabled()))
275                         {
276                                 res->SetType(SQL_COUNT);
277                                 res->SetCount(i->QueryCount(r->GetQuery()));
278                                 return;
279                         }
280                 }
281         }
282
283         void DoneType(SQLRequest *r, SQLResult* res)
284         {
285                 for (ConnectionList::iterator i = Connections.begin(); i != Connections.end(); i++)
286                 {
287                         if ((i->GetID() == r->GetConnID()) && (i->IsEnabled()))
288                         {
289                                 res->SetType(SQL_DONE);
290                                 if (!i->QueryDone())
291                                         res->SetType(SQL_ERROR);
292                         }
293                 }
294         }
295
296         void RowType(SQLRequest *r, SQLResult* res)
297         {
298                 for (ConnectionList::iterator i = Connections.begin(); i != Connections.end(); i++)
299                 {
300                         if ((i->GetID() == r->GetConnID()) && (i->IsEnabled()))
301                         {
302                                 log(DEBUG,"*** FOUND MATCHING ROW");
303                                 std::map<std::string,std::string> row = i->GetRow();
304                                 res->SetRow(row);
305                                 res->SetType(SQL_ROW);
306                                 if (!row.size())
307                                 {
308                                         log(DEBUG,"ROW SIZE IS 0");
309                                         res->SetType(SQL_END);
310                                 }
311                                 return;
312                         }
313                 }
314         }
315
316         void Implements(char* List)
317         {
318                 List[I_OnRehash] = List[I_OnRequest] = 1;
319         }
320
321         char* OnRequest(Request* request)
322         {
323                 if (request)
324                 {
325                         SQLResult* Result = new SQLResult();
326                         SQLRequest *r = (SQLRequest*)request->GetData();
327                         switch (r->GetQueryType())
328                         {
329                                 case SQL_RESULT:
330                                         ResultType(r,Result);
331                                 break;
332                                 case SQL_COUNT:
333                                         CountType(r,Result);
334                                 break;
335                                 case SQL_ROW:
336                                         RowType(r,Result);
337                                 break;
338                                 case SQL_DONE:
339                                         DoneType(r,Result);
340                                 break;
341                         }
342                         return (char*)Result;
343                 }
344                 return NULL;
345         }
346
347         ModuleSQL(Server* Me)
348                 : Module::Module(Me)
349         {
350                 Srv = Me;
351                 Conf = new ConfigReader();
352                 LoadDatabases(Conf);
353         }
354         
355         virtual ~ModuleSQL()
356         {
357                 Connections.clear();
358                 DELETE(Conf);
359         }
360         
361         virtual void OnRehash(const std::string &parameter)
362         {
363                 DELETE(Conf);
364                 Conf = new ConfigReader();
365                 LoadDatabases(Conf);
366         }
367         
368         virtual Version GetVersion()
369         {
370                 return Version(1,0,0,0,VF_VENDOR|VF_SERVICEPROVIDER);
371         }
372         
373 };
374
375 // stuff down here is the module-factory stuff. For basic modules you can ignore this.
376
377 class ModuleSQLFactory : public ModuleFactory
378 {
379  public:
380         ModuleSQLFactory()
381         {
382         }
383         
384         ~ModuleSQLFactory()
385         {
386         }
387         
388         virtual Module * CreateModule(Server* Me)
389         {
390                 return new ModuleSQL(Me);
391         }
392         
393 };
394
395
396 extern "C" void * init_module( void )
397 {
398         return new ModuleSQLFactory;
399 }
400