]> git.netwichtig.de Git - user/henk/code/inspircd.git/blob - src/modules/extra/m_sql.cpp
Test fix for 3.23 and below
[user/henk/code/inspircd.git] / src / modules / extra / m_sql.cpp
1 /*       +------------------------------------+
2  *       | Inspire Internet Relay Chat Daemon |
3  *       +------------------------------------+
4  *
5  *  Inspire is copyright (C) 2002-2004 ChatSpike-Dev.
6  *                       E-mail:
7  *                <brain@chatspike.net>
8  *                <Craig@chatspike.net>
9  *     
10  * Written by Craig Edwards, Craig McLure, and others.
11  * This program is free but copyrighted software; see
12  *            the file COPYING for details.
13  *
14  * ---------------------------------------------------
15  */
16
17
18 #include <stdio.h>
19 #include <string>
20 #include <mysql.h>
21 #include "users.h"
22 #include "channels.h"
23 #include "modules.h"
24 #include "m_sql.h"
25
26 /* $ModDesc: SQL Service Provider module for all other m_sql* modules */
27 /* $CompileFlags: -I/usr/local/include/mysql -I/usr/include/mysql -I/usr/local/include -I/usr/include -L/usr/local/lib/mysql -L/usr/lib/mysql -L/usr/local/lib -lmysqlclient */
28
29 /** SQLConnection represents one mysql session.
30  * Each session has its own persistent connection to the database.
31  */
32
/* Client headers that report version 32224 or newer are trusted to provide
 * mysql_field_count(). For anything older, or when MYSQL_VERSION_ID is not
 * defined at all, the name is aliased to mysql_num_fields instead. */
#if defined(MYSQL_VERSION_ID) && (MYSQL_VERSION_ID >= 32224)
/* nothing to do: mysql_field_count() comes from the client library */
#else
#define mysql_field_count mysql_num_fields
#endif
36
37 class SQLConnection
38 {
39  protected:
40
41         MYSQL connection;
42         MYSQL_RES *res;
43         MYSQL_ROW row;
44         std::string host;
45         std::string user;
46         std::string pass;
47         std::string db;
48         std::map<std::string,std::string> thisrow;
49         bool Enabled;
50         long id;
51
52  public:
53
54         // This constructor creates an SQLConnection object with the given credentials, and creates the underlying
55         // MYSQL struct, but does not connect yet.
56         SQLConnection(std::string thishost, std::string thisuser, std::string thispass, std::string thisdb, long myid)
57         {
58                 this->Enabled = true;
59                 this->host = thishost;
60                 this->user = thisuser;
61                 this->pass = thispass;
62                 this->db = thisdb;
63                 this->id = myid;
64                 unsigned int timeout = 1;
65                 mysql_init(&connection);
66                 mysql_options(&connection,MYSQL_OPT_CONNECT_TIMEOUT,(char*)&timeout);
67         }
68
69         // This method connects to the database using the credentials supplied to the constructor, and returns
70         // true upon success.
71         bool Connect()
72         {
73                 return mysql_real_connect(&connection, host.c_str(), user.c_str(), pass.c_str(), db.c_str(), 0, NULL, 0);
74         }
75
76         // This method issues a query that expects multiple rows of results. Use GetRow() and QueryDone() to retrieve
77         // multiple rows.
78         bool QueryResult(std::string query)
79         {
80                 int r = mysql_query(&connection, query.c_str());
81                 if (!r)
82                 {
83                         res = mysql_use_result(&connection);
84                 }
85                 return (!r);
86         }
87
88         // This method issues a query that just expects a number of 'effected' rows (e.g. UPDATE or DELETE FROM).
89         // the number of effected rows is returned in the return value.
90         unsigned long QueryCount(std::string query)
91         {
92                 int r = mysql_query(&connection, query.c_str());
93                 if (!r)
94                 {
95                         res = mysql_store_result(&connection);
96                         unsigned long rows = mysql_affected_rows(&connection);
97                         mysql_free_result(res);
98                         return rows;
99                 }
100                 return 0;
101         }
102
103         // This method fetches a row, if available from the database. You must issue a query
104         // using QueryResult() first! The row's values are returned as a map of std::string
105         // where each item is keyed by the column name.
106         std::map<std::string,std::string> GetRow()
107         {
108                 thisrow.clear();
109                 if (res)
110                 {
111                         row = mysql_fetch_row(res);
112                         if (row)
113                         {
114                                 unsigned int field_count = 0;
115                                 if(mysql_field_count(&connection) == 0)
116                                         return thisrow;
117                                 MYSQL_FIELD *fields = mysql_fetch_fields(res);
118                                 while (field_count < mysql_field_count(&connection))
119                                 {
120                                         thisrow[std::string(fields[field_count].name)] = std::string(row[field_count]);
121                                         field_count++;
122                                 }
123                                 return thisrow;
124                         }
125                 }
126                 return thisrow;
127         }
128
129         bool QueryDone()
130         {
131                 if (res)
132                 {
133                         mysql_free_result(res);
134                         return true;
135                 }
136                 else return false;
137         }
138
139         std::string GetError()
140         {
141                 return mysql_error(&connection);
142         }
143
144         long GetID()
145         {
146                 return id;
147         }
148
149         std::string GetHost()
150         {
151                 return host;
152         }
153
154         void Enable()
155         {
156                 Enabled = true;
157         }
158
159         void Disable()
160         {
161                 Enabled = false;
162         }
163
164         bool IsEnabled()
165         {
166                 return Enabled;
167         }
168
169 };
170
171 typedef std::vector<SQLConnection> ConnectionList;
172
173 class ModuleSQL : public Module
174 {
175         Server *Srv;
176         ConfigReader *Conf;
177         ConnectionList Connections;
178  
179  public:
180         void ConnectDatabases()
181         {
182                 for (ConnectionList::iterator i = Connections.begin(); i != Connections.end(); i++)
183                 {
184                         i->Enable();
185                         if (i->Connect())
186                         {
187                                 Srv->Log(DEFAULT,"SQL: Successfully connected database "+i->GetHost());
188                         }
189                         else
190                         {
191                                 Srv->Log(DEFAULT,"SQL: Failed to connect database "+i->GetHost()+": Error: "+i->GetError());
192                                 i->Disable();
193                         }
194                 }
195         }
196
197         void LoadDatabases(ConfigReader* ThisConf)
198         {
199                 Srv->Log(DEFAULT,"SQL: Loading database settings");
200                 Connections.clear();
201                 Srv->Log(DEBUG,"Cleared connections");
202                 for (int j =0; j < ThisConf->Enumerate("database"); j++)
203                 {
204                         std::string db = ThisConf->ReadValue("database","name",j);
205                         std::string user = ThisConf->ReadValue("database","username",j);
206                         std::string pass = ThisConf->ReadValue("database","password",j);
207                         std::string host = ThisConf->ReadValue("database","hostname",j);
208                         std::string id = ThisConf->ReadValue("database","id",j);
209                         Srv->Log(DEBUG,"Read database settings");
210                         if ((db != "") && (host != "") && (user != "") && (id != "") && (pass != ""))
211                         {
212                                 SQLConnection ThisSQL(host,user,pass,db,atoi(id.c_str()));
213                                 Srv->Log(DEFAULT,"Loaded database: "+ThisSQL.GetHost());
214                                 Connections.push_back(ThisSQL);
215                                 Srv->Log(DEBUG,"Pushed back connection");
216                         }
217                 }
218                 ConnectDatabases();
219         }
220
221         void ResultType(SQLRequest *r, SQLResult *res)
222         {
223                 for (ConnectionList::iterator i = Connections.begin(); i != Connections.end(); i++)
224                 {
225                         if ((i->GetID() == r->GetConnID()) && (i->IsEnabled()))
226                         {
227                                 bool xr = i->QueryResult(r->GetQuery());
228                                 if (!xr)
229                                 {
230                                         res->SetType(SQL_ERROR);
231                                         res->SetError(i->GetError());
232                                         return;
233                                 }
234                                 res->SetType(SQL_OK);
235                                 return;
236                         }
237                 }
238         }
239
240         void CountType(SQLRequest *r, SQLResult* res)
241         {
242                 for (ConnectionList::iterator i = Connections.begin(); i != Connections.end(); i++)
243                 {
244                         if ((i->GetID() == r->GetConnID()) && (i->IsEnabled()))
245                         {
246                                 res->SetType(SQL_COUNT);
247                                 res->SetCount(i->QueryCount(r->GetQuery()));
248                                 return;
249                         }
250                 }
251         }
252
253         void DoneType(SQLRequest *r, SQLResult* res)
254         {
255                 for (ConnectionList::iterator i = Connections.begin(); i != Connections.end(); i++)
256                 {
257                         if ((i->GetID() == r->GetConnID()) && (i->IsEnabled()))
258                         {
259                                 res->SetType(SQL_DONE);
260                                 if (!i->QueryDone())
261                                         res->SetType(SQL_ERROR);
262                         }
263                 }
264         }
265
266         void RowType(SQLRequest *r, SQLResult* res)
267         {
268                 for (ConnectionList::iterator i = Connections.begin(); i != Connections.end(); i++)
269                 {
270                         if ((i->GetID() == r->GetConnID()) && (i->IsEnabled()))
271                         {
272                                 log(DEBUG,"*** FOUND MATCHING ROW");
273                                 std::map<std::string,std::string> row = i->GetRow();
274                                 res->SetRow(row);
275                                 res->SetType(SQL_ROW);
276                                 if (!row.size())
277                                 {
278                                         log(DEBUG,"ROW SIZE IS 0");
279                                         res->SetType(SQL_END);
280                                 }
281                                 return;
282                         }
283                 }
284         }
285
286         char* OnRequest(Request* request)
287         {
288                 if (request)
289                 {
290                         SQLResult* Result = new SQLResult();
291                         SQLRequest *r = (SQLRequest*)request->GetData();
292                         switch (r->GetQueryType())
293                         {
294                                 case SQL_RESULT:
295                                         ResultType(r,Result);
296                                 break;
297                                 case SQL_COUNT:
298                                         CountType(r,Result);
299                                 break;
300                                 case SQL_ROW:
301                                         RowType(r,Result);
302                                 break;
303                                 case SQL_DONE:
304                                         DoneType(r,Result);
305                                 break;
306                         }
307                         return (char*)Result;
308                 }
309                 return NULL;
310         }
311
312         ModuleSQL()
313         {
314                 Srv = new Server();
315                 Conf = new ConfigReader();
316                 LoadDatabases(Conf);
317         }
318         
319         virtual ~ModuleSQL()
320         {
321                 Connections.clear();
322                 delete Conf;
323                 delete Srv;
324         }
325         
326         virtual void OnRehash()
327         {
328                 delete Conf;
329                 Conf = new ConfigReader();
330                 LoadDatabases(Conf);
331         }
332         
333         virtual Version GetVersion()
334         {
335                 return Version(1,0,0,0,VF_VENDOR|VF_SERVICEPROVIDER);
336         }
337         
338 };
339
340 // stuff down here is the module-factory stuff. For basic modules you can ignore this.
341
342 class ModuleSQLFactory : public ModuleFactory
343 {
344  public:
345         ModuleSQLFactory()
346         {
347         }
348         
349         ~ModuleSQLFactory()
350         {
351         }
352         
353         virtual Module * CreateModule()
354         {
355                 return new ModuleSQL;
356         }
357         
358 };
359
360
361 extern "C" void * init_module( void )
362 {
363         return new ModuleSQLFactory;
364 }
365