]> git.netwichtig.de Git - user/henk/code/inspircd.git/blob - src/modules/extra/m_sql.cpp
Review and optimize
[user/henk/code/inspircd.git] / src / modules / extra / m_sql.cpp
1 /*       +------------------------------------+
2  *       | Inspire Internet Relay Chat Daemon |
3  *       +------------------------------------+
4  *
5  *  Inspire is copyright (C) 2002-2004 ChatSpike-Dev.
6  *                       E-mail:
7  *                <brain@chatspike.net>
8  *                <Craig@chatspike.net>
9  *     
10  * Written by Craig Edwards, Craig McLure, and others.
11  * This program is free but copyrighted software; see
12  *            the file COPYING for details.
13  *
14  * ---------------------------------------------------
15  */
16
17 using namespace std;
18
19 #include <stdio.h>
20 #include <string>
21 #include <mysql.h>
22 #include "users.h"
23 #include "channels.h"
24 #include "modules.h"
25 #include "helperfuncs.h"
26 #include "m_sql.h"
27
28 /* $ModDesc: SQL Service Provider module for all other m_sql* modules */
29 /* $CompileFlags: -I/usr/local/include/mysql -I/usr/include/mysql -I/usr/local/include -I/usr/include */
30 /* $LinkerFlags: -L/usr/local/lib/mysql -Wl,--rpath -Wl,/usr/local/lib/mysql -L/usr/lib/mysql -Wl,--rpath -Wl,/usr/lib/mysql -lmysqlclient */
31
32 /** SQLConnection represents one mysql session.
33  * Each session has its own persistent connection to the database.
34  */
35
36 #if !defined(MYSQL_VERSION_ID) || MYSQL_VERSION_ID<32224
37 #define mysql_field_count mysql_num_fields
38 #endif
39
40 class SQLConnection
41 {
42  protected:
43
44         MYSQL connection;
45         MYSQL_RES *res;
46         MYSQL_ROW row;
47         std::string host;
48         std::string user;
49         std::string pass;
50         std::string db;
51         std::map<std::string,std::string> thisrow;
52         bool Enabled;
53         long id;
54
55  public:
56
57         // This constructor creates an SQLConnection object with the given credentials, and creates the underlying
58         // MYSQL struct, but does not connect yet.
59         SQLConnection(std::string thishost, std::string thisuser, std::string thispass, std::string thisdb, long myid)
60         {
61                 this->Enabled = true;
62                 this->host = thishost;
63                 this->user = thisuser;
64                 this->pass = thispass;
65                 this->db = thisdb;
66                 this->id = myid;
67                 unsigned int timeout = 1;
68                 mysql_init(&connection);
69                 mysql_options(&connection,MYSQL_OPT_CONNECT_TIMEOUT,(char*)&timeout);
70         }
71
72         // This method connects to the database using the credentials supplied to the constructor, and returns
73         // true upon success.
74         bool Connect()
75         {
76                 return mysql_real_connect(&connection, host.c_str(), user.c_str(), pass.c_str(), db.c_str(), 0, NULL, 0);
77         }
78
79         // This method issues a query that expects multiple rows of results. Use GetRow() and QueryDone() to retrieve
80         // multiple rows.
81         bool QueryResult(std::string query)
82         {
83                 int r = mysql_query(&connection, query.c_str());
84                 if (!r)
85                 {
86                         res = mysql_use_result(&connection);
87                 }
88                 return (!r);
89         }
90
91         // This method issues a query that just expects a number of 'effected' rows (e.g. UPDATE or DELETE FROM).
92         // the number of effected rows is returned in the return value.
93         unsigned long QueryCount(std::string query)
94         {
95                 int r = mysql_query(&connection, query.c_str());
96                 if (!r)
97                 {
98                         res = mysql_store_result(&connection);
99                         unsigned long rows = mysql_affected_rows(&connection);
100                         mysql_free_result(res);
101                         return rows;
102                 }
103                 return 0;
104         }
105
106         // This method fetches a row, if available from the database. You must issue a query
107         // using QueryResult() first! The row's values are returned as a map of std::string
108         // where each item is keyed by the column name.
109         std::map<std::string,std::string> GetRow()
110         {
111                 thisrow.clear();
112                 if (res)
113                 {
114                         row = mysql_fetch_row(res);
115                         if (row)
116                         {
117                                 unsigned int field_count = 0;
118                                 if(mysql_field_count(&connection) == 0)
119                                         return thisrow;
120                                 MYSQL_FIELD *fields = mysql_fetch_fields(res);
121                                 while (field_count < mysql_field_count(&connection))
122                                 {
123                                         thisrow[std::string(fields[field_count].name)] = std::string(row[field_count]);
124                                         field_count++;
125                                 }
126                                 return thisrow;
127                         }
128                 }
129                 return thisrow;
130         }
131
132         bool QueryDone()
133         {
134                 if (res)
135                 {
136                         mysql_free_result(res);
137                         return true;
138                 }
139                 else return false;
140         }
141
142         std::string GetError()
143         {
144                 return mysql_error(&connection);
145         }
146
147         long GetID()
148         {
149                 return id;
150         }
151
152         std::string GetHost()
153         {
154                 return host;
155         }
156
157         void Enable()
158         {
159                 Enabled = true;
160         }
161
162         void Disable()
163         {
164                 Enabled = false;
165         }
166
167         bool IsEnabled()
168         {
169                 return Enabled;
170         }
171
172 };
173
174 typedef std::vector<SQLConnection> ConnectionList;
175
176 class ModuleSQL : public Module
177 {
178         Server *Srv;
179         ConfigReader *Conf;
180         ConnectionList Connections;
181  
182  public:
183         void ConnectDatabases()
184         {
185                 for (ConnectionList::iterator i = Connections.begin(); i != Connections.end(); i++)
186                 {
187                         i->Enable();
188                         if (i->Connect())
189                         {
190                                 Srv->Log(DEFAULT,"SQL: Successfully connected database "+i->GetHost());
191                         }
192                         else
193                         {
194                                 Srv->Log(DEFAULT,"SQL: Failed to connect database "+i->GetHost()+": Error: "+i->GetError());
195                                 i->Disable();
196                         }
197                 }
198         }
199
200         void LoadDatabases(ConfigReader* ThisConf)
201         {
202                 Srv->Log(DEFAULT,"SQL: Loading database settings");
203                 Connections.clear();
204                 Srv->Log(DEBUG,"Cleared connections");
205                 for (int j =0; j < ThisConf->Enumerate("database"); j++)
206                 {
207                         std::string db = ThisConf->ReadValue("database","name",j);
208                         std::string user = ThisConf->ReadValue("database","username",j);
209                         std::string pass = ThisConf->ReadValue("database","password",j);
210                         std::string host = ThisConf->ReadValue("database","hostname",j);
211                         std::string id = ThisConf->ReadValue("database","id",j);
212                         Srv->Log(DEBUG,"Read database settings");
213                         if ((db != "") && (host != "") && (user != "") && (id != "") && (pass != ""))
214                         {
215                                 SQLConnection ThisSQL(host,user,pass,db,atoi(id.c_str()));
216                                 Srv->Log(DEFAULT,"Loaded database: "+ThisSQL.GetHost());
217                                 Connections.push_back(ThisSQL);
218                                 Srv->Log(DEBUG,"Pushed back connection");
219                         }
220                 }
221                 ConnectDatabases();
222         }
223
224         void ResultType(SQLRequest *r, SQLResult *res)
225         {
226                 for (ConnectionList::iterator i = Connections.begin(); i != Connections.end(); i++)
227                 {
228                         if ((i->GetID() == r->GetConnID()) && (i->IsEnabled()))
229                         {
230                                 bool xr = i->QueryResult(r->GetQuery());
231                                 if (!xr)
232                                 {
233                                         res->SetType(SQL_ERROR);
234                                         res->SetError(i->GetError());
235                                         return;
236                                 }
237                                 res->SetType(SQL_OK);
238                                 return;
239                         }
240                 }
241         }
242
243         void CountType(SQLRequest *r, SQLResult* res)
244         {
245                 for (ConnectionList::iterator i = Connections.begin(); i != Connections.end(); i++)
246                 {
247                         if ((i->GetID() == r->GetConnID()) && (i->IsEnabled()))
248                         {
249                                 res->SetType(SQL_COUNT);
250                                 res->SetCount(i->QueryCount(r->GetQuery()));
251                                 return;
252                         }
253                 }
254         }
255
256         void DoneType(SQLRequest *r, SQLResult* res)
257         {
258                 for (ConnectionList::iterator i = Connections.begin(); i != Connections.end(); i++)
259                 {
260                         if ((i->GetID() == r->GetConnID()) && (i->IsEnabled()))
261                         {
262                                 res->SetType(SQL_DONE);
263                                 if (!i->QueryDone())
264                                         res->SetType(SQL_ERROR);
265                         }
266                 }
267         }
268
269         void RowType(SQLRequest *r, SQLResult* res)
270         {
271                 for (ConnectionList::iterator i = Connections.begin(); i != Connections.end(); i++)
272                 {
273                         if ((i->GetID() == r->GetConnID()) && (i->IsEnabled()))
274                         {
275                                 log(DEBUG,"*** FOUND MATCHING ROW");
276                                 std::map<std::string,std::string> row = i->GetRow();
277                                 res->SetRow(row);
278                                 res->SetType(SQL_ROW);
279                                 if (!row.size())
280                                 {
281                                         log(DEBUG,"ROW SIZE IS 0");
282                                         res->SetType(SQL_END);
283                                 }
284                                 return;
285                         }
286                 }
287         }
288
289         char* OnRequest(Request* request)
290         {
291                 if (request)
292                 {
293                         SQLResult* Result = new SQLResult();
294                         SQLRequest *r = (SQLRequest*)request->GetData();
295                         switch (r->GetQueryType())
296                         {
297                                 case SQL_RESULT:
298                                         ResultType(r,Result);
299                                 break;
300                                 case SQL_COUNT:
301                                         CountType(r,Result);
302                                 break;
303                                 case SQL_ROW:
304                                         RowType(r,Result);
305                                 break;
306                                 case SQL_DONE:
307                                         DoneType(r,Result);
308                                 break;
309                         }
310                         return (char*)Result;
311                 }
312                 return NULL;
313         }
314
315         ModuleSQL(Server* Me)
316                 : Module::Module(Me)
317         {
318                 Srv = Me;
319                 Conf = new ConfigReader();
320                 LoadDatabases(Conf);
321         }
322         
323         virtual ~ModuleSQL()
324         {
325                 Connections.clear();
326                 delete Conf;
327         }
328         
329         virtual void OnRehash(std::string parameter)
330         {
331                 delete Conf;
332                 Conf = new ConfigReader();
333                 LoadDatabases(Conf);
334         }
335         
336         virtual Version GetVersion()
337         {
338                 return Version(1,0,0,0,VF_VENDOR|VF_SERVICEPROVIDER);
339         }
340         
341 };
342
// Module-factory boilerplate follows. For basic modules you can ignore this.
344
345 class ModuleSQLFactory : public ModuleFactory
346 {
347  public:
348         ModuleSQLFactory()
349         {
350         }
351         
352         ~ModuleSQLFactory()
353         {
354         }
355         
356         virtual Module * CreateModule(Server* Me)
357         {
358                 return new ModuleSQL(Me);
359         }
360         
361 };
362
363
364 extern "C" void * init_module( void )
365 {
366         return new ModuleSQLFactory;
367 }
368