git.netwichtig.de Git - user/henk/code/inspircd.git/blob - src/modules/extra/m_sql.cpp
Extra m_sql field checking
[user/henk/code/inspircd.git] / src / modules / extra / m_sql.cpp
1 /*       +------------------------------------+
2  *       | Inspire Internet Relay Chat Daemon |
3  *       +------------------------------------+
4  *
5  *  Inspire is copyright (C) 2002-2004 ChatSpike-Dev.
6  *                       E-mail:
7  *                <brain@chatspike.net>
8  *                <Craig@chatspike.net>
9  *     
10  * Written by Craig Edwards, Craig McLure, and others.
11  * This program is free but copyrighted software; see
12  *            the file COPYING for details.
13  *
14  * ---------------------------------------------------
15  */
16
17 using namespace std;
18
19 #include <stdio.h>
20 #include <string>
21 #include <mysql.h>
22 #include "users.h"
23 #include "channels.h"
24 #include "modules.h"
25 #include "helperfuncs.h"
26 #include "m_sql.h"
27
28 /* $ModDesc: SQL Service Provider module for all other m_sql* modules */
29 /* $CompileFlags: -I/usr/local/include/mysql -I/usr/include/mysql -I/usr/local/include -I/usr/include */
30 /* $LinkerFlags: -L/usr/local/lib/mysql -Wl,--rpath -Wl,/usr/local/lib/mysql -L/usr/lib/mysql -Wl,--rpath -Wl,/usr/lib/mysql -lmysqlclient */
31
32 /** SQLConnection represents one mysql session.
33  * Each session has its own persistent connection to the database.
34  */
35
36 #if !defined(MYSQL_VERSION_ID) || MYSQL_VERSION_ID<32224
37 #define mysql_field_count mysql_num_fields
38 #endif
39
40 class SQLConnection
41 {
42  protected:
43
44         MYSQL connection;
45         MYSQL_RES *res;
46         MYSQL_ROW row;
47         std::string host;
48         std::string user;
49         std::string pass;
50         std::string db;
51         std::map<std::string,std::string> thisrow;
52         bool Enabled;
53         long id;
54
55  public:
56
57         // This constructor creates an SQLConnection object with the given credentials, and creates the underlying
58         // MYSQL struct, but does not connect yet.
59         SQLConnection(std::string thishost, std::string thisuser, std::string thispass, std::string thisdb, long myid)
60         {
61                 this->Enabled = true;
62                 this->host = thishost;
63                 this->user = thisuser;
64                 this->pass = thispass;
65                 this->db = thisdb;
66                 this->id = myid;
67                 unsigned int timeout = 1;
68                 mysql_init(&connection);
69                 mysql_options(&connection,MYSQL_OPT_CONNECT_TIMEOUT,(char*)&timeout);
70         }
71
72         // This method connects to the database using the credentials supplied to the constructor, and returns
73         // true upon success.
74         bool Connect()
75         {
76                 return mysql_real_connect(&connection, host.c_str(), user.c_str(), pass.c_str(), db.c_str(), 0, NULL, 0);
77         }
78
79         // This method issues a query that expects multiple rows of results. Use GetRow() and QueryDone() to retrieve
80         // multiple rows.
81         bool QueryResult(std::string query)
82         {
83                 int r = mysql_query(&connection, query.c_str());
84                 if (!r)
85                 {
86                         res = mysql_use_result(&connection);
87                 }
88                 return (!r);
89         }
90
91         // This method issues a query that just expects a number of 'effected' rows (e.g. UPDATE or DELETE FROM).
92         // the number of effected rows is returned in the return value.
93         unsigned long QueryCount(std::string query)
94         {
95                 int r = mysql_query(&connection, query.c_str());
96                 if (!r)
97                 {
98                         res = mysql_store_result(&connection);
99                         unsigned long rows = mysql_affected_rows(&connection);
100                         mysql_free_result(res);
101                         return rows;
102                 }
103                 return 0;
104         }
105
106         // This method fetches a row, if available from the database. You must issue a query
107         // using QueryResult() first! The row's values are returned as a map of std::string
108         // where each item is keyed by the column name.
109         std::map<std::string,std::string> GetRow()
110         {
111                 thisrow.clear();
112                 if (res)
113                 {
114                         row = mysql_fetch_row(res);
115                         if (row)
116                         {
117                                 unsigned int field_count = 0;
118                                 MYSQL_FIELD *fields = mysql_fetch_fields(res);
119                                 if(mysql_field_count(&connection) == 0)
120                                         return thisrow;
121                                 if (fields && mysql_field_count(&connection))
122                                 {
123                                         while (field_count < mysql_field_count(&connection))
124                                         {
125                                                 if (fields[field_count] && fields[field_count].name && row[field_count])
126                                                 {
127                                                         std::string a = (fields[field_count].name ? fields[field_count].name : "");
128                                                         std::string b = (row[field_count] ? row[field_count] : "");
129                                                         thisrow[a] = b;
130                                                         field_count++;
131                                                 }
132                                         }
133                                         return thisrow;
134                                 }
135                         }
136                 }
137                 return thisrow;
138         }
139
140         bool QueryDone()
141         {
142                 if (res)
143                 {
144                         mysql_free_result(res);
145                         return true;
146                 }
147                 else return false;
148         }
149
150         std::string GetError()
151         {
152                 return mysql_error(&connection);
153         }
154
155         long GetID()
156         {
157                 return id;
158         }
159
160         std::string GetHost()
161         {
162                 return host;
163         }
164
165         void Enable()
166         {
167                 Enabled = true;
168         }
169
170         void Disable()
171         {
172                 Enabled = false;
173         }
174
175         bool IsEnabled()
176         {
177                 return Enabled;
178         }
179
180 };
181
182 typedef std::vector<SQLConnection> ConnectionList;
183
184 class ModuleSQL : public Module
185 {
186         Server *Srv;
187         ConfigReader *Conf;
188         ConnectionList Connections;
189  
190  public:
191         void ConnectDatabases()
192         {
193                 for (ConnectionList::iterator i = Connections.begin(); i != Connections.end(); i++)
194                 {
195                         i->Enable();
196                         if (i->Connect())
197                         {
198                                 Srv->Log(DEFAULT,"SQL: Successfully connected database "+i->GetHost());
199                         }
200                         else
201                         {
202                                 Srv->Log(DEFAULT,"SQL: Failed to connect database "+i->GetHost()+": Error: "+i->GetError());
203                                 i->Disable();
204                         }
205                 }
206         }
207
208         void LoadDatabases(ConfigReader* ThisConf)
209         {
210                 Srv->Log(DEFAULT,"SQL: Loading database settings");
211                 Connections.clear();
212                 Srv->Log(DEBUG,"Cleared connections");
213                 for (int j =0; j < ThisConf->Enumerate("database"); j++)
214                 {
215                         std::string db = ThisConf->ReadValue("database","name",j);
216                         std::string user = ThisConf->ReadValue("database","username",j);
217                         std::string pass = ThisConf->ReadValue("database","password",j);
218                         std::string host = ThisConf->ReadValue("database","hostname",j);
219                         std::string id = ThisConf->ReadValue("database","id",j);
220                         Srv->Log(DEBUG,"Read database settings");
221                         if ((db != "") && (host != "") && (user != "") && (id != "") && (pass != ""))
222                         {
223                                 SQLConnection ThisSQL(host,user,pass,db,atoi(id.c_str()));
224                                 Srv->Log(DEFAULT,"Loaded database: "+ThisSQL.GetHost());
225                                 Connections.push_back(ThisSQL);
226                                 Srv->Log(DEBUG,"Pushed back connection");
227                         }
228                 }
229                 ConnectDatabases();
230         }
231
232         void ResultType(SQLRequest *r, SQLResult *res)
233         {
234                 for (ConnectionList::iterator i = Connections.begin(); i != Connections.end(); i++)
235                 {
236                         if ((i->GetID() == r->GetConnID()) && (i->IsEnabled()))
237                         {
238                                 bool xr = i->QueryResult(r->GetQuery());
239                                 if (!xr)
240                                 {
241                                         res->SetType(SQL_ERROR);
242                                         res->SetError(i->GetError());
243                                         return;
244                                 }
245                                 res->SetType(SQL_OK);
246                                 return;
247                         }
248                 }
249         }
250
251         void CountType(SQLRequest *r, SQLResult* res)
252         {
253                 for (ConnectionList::iterator i = Connections.begin(); i != Connections.end(); i++)
254                 {
255                         if ((i->GetID() == r->GetConnID()) && (i->IsEnabled()))
256                         {
257                                 res->SetType(SQL_COUNT);
258                                 res->SetCount(i->QueryCount(r->GetQuery()));
259                                 return;
260                         }
261                 }
262         }
263
264         void DoneType(SQLRequest *r, SQLResult* res)
265         {
266                 for (ConnectionList::iterator i = Connections.begin(); i != Connections.end(); i++)
267                 {
268                         if ((i->GetID() == r->GetConnID()) && (i->IsEnabled()))
269                         {
270                                 res->SetType(SQL_DONE);
271                                 if (!i->QueryDone())
272                                         res->SetType(SQL_ERROR);
273                         }
274                 }
275         }
276
277         void RowType(SQLRequest *r, SQLResult* res)
278         {
279                 for (ConnectionList::iterator i = Connections.begin(); i != Connections.end(); i++)
280                 {
281                         if ((i->GetID() == r->GetConnID()) && (i->IsEnabled()))
282                         {
283                                 log(DEBUG,"*** FOUND MATCHING ROW");
284                                 std::map<std::string,std::string> row = i->GetRow();
285                                 res->SetRow(row);
286                                 res->SetType(SQL_ROW);
287                                 if (!row.size())
288                                 {
289                                         log(DEBUG,"ROW SIZE IS 0");
290                                         res->SetType(SQL_END);
291                                 }
292                                 return;
293                         }
294                 }
295         }
296
297         void Implements(char* List)
298         {
299                 List[I_OnRehash] = List[I_OnRequest] = 1;
300         }
301
302         char* OnRequest(Request* request)
303         {
304                 if (request)
305                 {
306                         SQLResult* Result = new SQLResult();
307                         SQLRequest *r = (SQLRequest*)request->GetData();
308                         switch (r->GetQueryType())
309                         {
310                                 case SQL_RESULT:
311                                         ResultType(r,Result);
312                                 break;
313                                 case SQL_COUNT:
314                                         CountType(r,Result);
315                                 break;
316                                 case SQL_ROW:
317                                         RowType(r,Result);
318                                 break;
319                                 case SQL_DONE:
320                                         DoneType(r,Result);
321                                 break;
322                         }
323                         return (char*)Result;
324                 }
325                 return NULL;
326         }
327
328         ModuleSQL(Server* Me)
329                 : Module::Module(Me)
330         {
331                 Srv = Me;
332                 Conf = new ConfigReader();
333                 LoadDatabases(Conf);
334         }
335         
336         virtual ~ModuleSQL()
337         {
338                 Connections.clear();
339                 delete Conf;
340         }
341         
342         virtual void OnRehash(std::string parameter)
343         {
344                 delete Conf;
345                 Conf = new ConfigReader();
346                 LoadDatabases(Conf);
347         }
348         
349         virtual Version GetVersion()
350         {
351                 return Version(1,0,0,0,VF_VENDOR|VF_SERVICEPROVIDER);
352         }
353         
354 };
355
356 // stuff down here is the module-factory stuff. For basic modules you can ignore this.
357
358 class ModuleSQLFactory : public ModuleFactory
359 {
360  public:
361         ModuleSQLFactory()
362         {
363         }
364         
365         ~ModuleSQLFactory()
366         {
367         }
368         
369         virtual Module * CreateModule(Server* Me)
370         {
371                 return new ModuleSQL(Me);
372         }
373         
374 };
375
376
377 extern "C" void * init_module( void )
378 {
379         return new ModuleSQLFactory;
380 }
381