]> git.netwichtig.de Git - user/henk/code/inspircd.git/blob - src/modules/extra/m_sql.cpp
Updated copyrights in headers etc using perl inplace edit
[user/henk/code/inspircd.git] / src / modules / extra / m_sql.cpp
1 /*       +------------------------------------+
2  *       | Inspire Internet Relay Chat Daemon |
3  *       +------------------------------------+
4  *
5  *  InspIRCd is copyright (C) 2002-2004 ChatSpike-Dev.
6  *                       E-mail:
7  *                <brain@chatspike.net>
8  *                <Craig@chatspike.net>
9  *     
10  * Written by Craig Edwards, Craig McLure, and others.
11  * This program is free but copyrighted software; see
12  *            the file COPYING for details.
13  *
14  * ---------------------------------------------------
15  */
16
17 using namespace std;
18
19 #include <stdio.h>
20 #include <string>
21 #include <mysql.h>
22 #include "users.h"
23 #include "channels.h"
24 #include "modules.h"
25 #include "helperfuncs.h"
26 #include "m_sql.h"
27
28 /* $ModDesc: SQL Service Provider module for all other m_sql* modules */
29 /* $CompileFlags: -I/usr/local/include/mysql -I/usr/include/mysql -I/usr/local/include -I/usr/include */
30 /* $LinkerFlags: -L/usr/local/lib/mysql -Wl,--rpath -Wl,/usr/local/lib/mysql -L/usr/lib/mysql -Wl,--rpath -Wl,/usr/lib/mysql -lmysqlclient */
31
32 /** SQLConnection represents one mysql session.
33  * Each session has its own persistent connection to the database.
34  */
35
36 #if !defined(MYSQL_VERSION_ID) || MYSQL_VERSION_ID<32224
37 #define mysql_field_count mysql_num_fields
38 #endif
39
40 class SQLConnection
41 {
42  protected:
43
44         MYSQL connection;
45         MYSQL_RES *res;
46         MYSQL_ROW row;
47         std::string host;
48         std::string user;
49         std::string pass;
50         std::string db;
51         std::map<std::string,std::string> thisrow;
52         bool Enabled;
53         long id;
54
55  public:
56
57         // This constructor creates an SQLConnection object with the given credentials, and creates the underlying
58         // MYSQL struct, but does not connect yet.
59         SQLConnection(std::string thishost, std::string thisuser, std::string thispass, std::string thisdb, long myid)
60         {
61                 this->Enabled = true;
62                 this->host = thishost;
63                 this->user = thisuser;
64                 this->pass = thispass;
65                 this->db = thisdb;
66                 this->id = myid;
67                 unsigned int timeout = 1;
68                 mysql_init(&connection);
69                 mysql_options(&connection,MYSQL_OPT_CONNECT_TIMEOUT,(char*)&timeout);
70         }
71
72         // This method connects to the database using the credentials supplied to the constructor, and returns
73         // true upon success.
74         bool Connect()
75         {
76                 return mysql_real_connect(&connection, host.c_str(), user.c_str(), pass.c_str(), db.c_str(), 0, NULL, 0);
77         }
78
79         // This method issues a query that expects multiple rows of results. Use GetRow() and QueryDone() to retrieve
80         // multiple rows.
81         bool QueryResult(std::string query)
82         {
83                 int r = mysql_query(&connection, query.c_str());
84                 if (!r)
85                 {
86                         res = mysql_use_result(&connection);
87                 }
88                 return (!r);
89         }
90
91         // This method issues a query that just expects a number of 'effected' rows (e.g. UPDATE or DELETE FROM).
92         // the number of effected rows is returned in the return value.
93         unsigned long QueryCount(std::string query)
94         {
95                 int r = mysql_query(&connection, query.c_str());
96                 if (!r)
97                 {
98                         res = mysql_store_result(&connection);
99                         unsigned long rows = mysql_affected_rows(&connection);
100                         mysql_free_result(res);
101                         return rows;
102                 }
103                 return 0;
104         }
105
106         // This method fetches a row, if available from the database. You must issue a query
107         // using QueryResult() first! The row's values are returned as a map of std::string
108         // where each item is keyed by the column name.
109         std::map<std::string,std::string> GetRow()
110         {
111                 thisrow.clear();
112                 if (res)
113                 {
114                         row = mysql_fetch_row(res);
115                         if (row)
116                         {
117                                 unsigned int field_count = 0;
118                                 MYSQL_FIELD *fields = mysql_fetch_fields(res);
119                                 if(mysql_field_count(&connection) == 0)
120                                         return thisrow;
121                                 if (fields && mysql_field_count(&connection))
122                                 {
123                                         while (field_count < mysql_field_count(&connection))
124                                         {
125                                                 std::string a = (fields[field_count].name ? fields[field_count].name : "");
126                                                 std::string b = (row[field_count] ? row[field_count] : "");
127                                                 thisrow[a] = b;
128                                                 field_count++;
129                                         }
130                                         return thisrow;
131                                 }
132                         }
133                 }
134                 return thisrow;
135         }
136
137         bool QueryDone()
138         {
139                 if (res)
140                 {
141                         mysql_free_result(res);
142                         return true;
143                 }
144                 else return false;
145         }
146
147         std::string GetError()
148         {
149                 return mysql_error(&connection);
150         }
151
152         long GetID()
153         {
154                 return id;
155         }
156
157         std::string GetHost()
158         {
159                 return host;
160         }
161
162         void Enable()
163         {
164                 Enabled = true;
165         }
166
167         void Disable()
168         {
169                 Enabled = false;
170         }
171
172         bool IsEnabled()
173         {
174                 return Enabled;
175         }
176
177 };
178
179 typedef std::vector<SQLConnection> ConnectionList;
180
181 class ModuleSQL : public Module
182 {
183         Server *Srv;
184         ConfigReader *Conf;
185         ConnectionList Connections;
186  
187  public:
188         void ConnectDatabases()
189         {
190                 for (ConnectionList::iterator i = Connections.begin(); i != Connections.end(); i++)
191                 {
192                         i->Enable();
193                         if (i->Connect())
194                         {
195                                 Srv->Log(DEFAULT,"SQL: Successfully connected database "+i->GetHost());
196                         }
197                         else
198                         {
199                                 Srv->Log(DEFAULT,"SQL: Failed to connect database "+i->GetHost()+": Error: "+i->GetError());
200                                 i->Disable();
201                         }
202                 }
203         }
204
205         void LoadDatabases(ConfigReader* ThisConf)
206         {
207                 Srv->Log(DEFAULT,"SQL: Loading database settings");
208                 Connections.clear();
209                 Srv->Log(DEBUG,"Cleared connections");
210                 for (int j =0; j < ThisConf->Enumerate("database"); j++)
211                 {
212                         std::string db = ThisConf->ReadValue("database","name",j);
213                         std::string user = ThisConf->ReadValue("database","username",j);
214                         std::string pass = ThisConf->ReadValue("database","password",j);
215                         std::string host = ThisConf->ReadValue("database","hostname",j);
216                         std::string id = ThisConf->ReadValue("database","id",j);
217                         Srv->Log(DEBUG,"Read database settings");
218                         if ((db != "") && (host != "") && (user != "") && (id != "") && (pass != ""))
219                         {
220                                 SQLConnection ThisSQL(host,user,pass,db,atoi(id.c_str()));
221                                 Srv->Log(DEFAULT,"Loaded database: "+ThisSQL.GetHost());
222                                 Connections.push_back(ThisSQL);
223                                 Srv->Log(DEBUG,"Pushed back connection");
224                         }
225                 }
226                 ConnectDatabases();
227         }
228
229         void ResultType(SQLRequest *r, SQLResult *res)
230         {
231                 for (ConnectionList::iterator i = Connections.begin(); i != Connections.end(); i++)
232                 {
233                         if ((i->GetID() == r->GetConnID()) && (i->IsEnabled()))
234                         {
235                                 bool xr = i->QueryResult(r->GetQuery());
236                                 if (!xr)
237                                 {
238                                         res->SetType(SQL_ERROR);
239                                         res->SetError(i->GetError());
240                                         return;
241                                 }
242                                 res->SetType(SQL_OK);
243                                 return;
244                         }
245                 }
246         }
247
248         void CountType(SQLRequest *r, SQLResult* res)
249         {
250                 for (ConnectionList::iterator i = Connections.begin(); i != Connections.end(); i++)
251                 {
252                         if ((i->GetID() == r->GetConnID()) && (i->IsEnabled()))
253                         {
254                                 res->SetType(SQL_COUNT);
255                                 res->SetCount(i->QueryCount(r->GetQuery()));
256                                 return;
257                         }
258                 }
259         }
260
261         void DoneType(SQLRequest *r, SQLResult* res)
262         {
263                 for (ConnectionList::iterator i = Connections.begin(); i != Connections.end(); i++)
264                 {
265                         if ((i->GetID() == r->GetConnID()) && (i->IsEnabled()))
266                         {
267                                 res->SetType(SQL_DONE);
268                                 if (!i->QueryDone())
269                                         res->SetType(SQL_ERROR);
270                         }
271                 }
272         }
273
274         void RowType(SQLRequest *r, SQLResult* res)
275         {
276                 for (ConnectionList::iterator i = Connections.begin(); i != Connections.end(); i++)
277                 {
278                         if ((i->GetID() == r->GetConnID()) && (i->IsEnabled()))
279                         {
280                                 log(DEBUG,"*** FOUND MATCHING ROW");
281                                 std::map<std::string,std::string> row = i->GetRow();
282                                 res->SetRow(row);
283                                 res->SetType(SQL_ROW);
284                                 if (!row.size())
285                                 {
286                                         log(DEBUG,"ROW SIZE IS 0");
287                                         res->SetType(SQL_END);
288                                 }
289                                 return;
290                         }
291                 }
292         }
293
294         void Implements(char* List)
295         {
296                 List[I_OnRehash] = List[I_OnRequest] = 1;
297         }
298
299         char* OnRequest(Request* request)
300         {
301                 if (request)
302                 {
303                         SQLResult* Result = new SQLResult();
304                         SQLRequest *r = (SQLRequest*)request->GetData();
305                         switch (r->GetQueryType())
306                         {
307                                 case SQL_RESULT:
308                                         ResultType(r,Result);
309                                 break;
310                                 case SQL_COUNT:
311                                         CountType(r,Result);
312                                 break;
313                                 case SQL_ROW:
314                                         RowType(r,Result);
315                                 break;
316                                 case SQL_DONE:
317                                         DoneType(r,Result);
318                                 break;
319                         }
320                         return (char*)Result;
321                 }
322                 return NULL;
323         }
324
325         ModuleSQL(Server* Me)
326                 : Module::Module(Me)
327         {
328                 Srv = Me;
329                 Conf = new ConfigReader();
330                 LoadDatabases(Conf);
331         }
332         
333         virtual ~ModuleSQL()
334         {
335                 Connections.clear();
336                 delete Conf;
337         }
338         
339         virtual void OnRehash(std::string parameter)
340         {
341                 delete Conf;
342                 Conf = new ConfigReader();
343                 LoadDatabases(Conf);
344         }
345         
346         virtual Version GetVersion()
347         {
348                 return Version(1,0,0,0,VF_VENDOR|VF_SERVICEPROVIDER);
349         }
350         
351 };
352
353 // stuff down here is the module-factory stuff. For basic modules you can ignore this.
354
355 class ModuleSQLFactory : public ModuleFactory
356 {
357  public:
358         ModuleSQLFactory()
359         {
360         }
361         
362         ~ModuleSQLFactory()
363         {
364         }
365         
366         virtual Module * CreateModule(Server* Me)
367         {
368                 return new ModuleSQL(Me);
369         }
370         
371 };
372
373
374 extern "C" void * init_module( void )
375 {
376         return new ModuleSQLFactory;
377 }
378