]> git.netwichtig.de Git - user/henk/code/inspircd.git/blob - src/modules/extra/m_sql.cpp
Moved a ton of functions into helperfuncs.h to speed up recompiles
[user/henk/code/inspircd.git] / src / modules / extra / m_sql.cpp
1 /*       +------------------------------------+
2  *       | Inspire Internet Relay Chat Daemon |
3  *       +------------------------------------+
4  *
5  *  Inspire is copyright (C) 2002-2004 ChatSpike-Dev.
6  *                       E-mail:
7  *                <brain@chatspike.net>
8  *                <Craig@chatspike.net>
9  *     
10  * Written by Craig Edwards, Craig McLure, and others.
11  * This program is free but copyrighted software; see
12  *            the file COPYING for details.
13  *
14  * ---------------------------------------------------
15  */
16
17
18 #include <stdio.h>
19 #include <string>
20 #include <mysql.h>
21 #include "users.h"
22 #include "channels.h"
23 #include "modules.h"
24 #include "helperfuncs.h"
25 #include "m_sql.h"
26
27 /* $ModDesc: SQL Service Provider module for all other m_sql* modules */
28 /* $CompileFlags: -I/usr/local/include/mysql -I/usr/include/mysql -I/usr/local/include -I/usr/include -L/usr/local/lib/mysql -L/usr/lib/mysql -L/usr/local/lib -lmysqlclient */
29
30 /** SQLConnection represents one mysql session.
31  * Each session has its own persistent connection to the database.
32  */
33
34 #if !defined(MYSQL_VERSION_ID) || MYSQL_VERSION_ID<32224
35 #define mysql_field_count mysql_num_fields
36 #endif
37
38 class SQLConnection
39 {
40  protected:
41
42         MYSQL connection;
43         MYSQL_RES *res;
44         MYSQL_ROW row;
45         std::string host;
46         std::string user;
47         std::string pass;
48         std::string db;
49         std::map<std::string,std::string> thisrow;
50         bool Enabled;
51         long id;
52
53  public:
54
55         // This constructor creates an SQLConnection object with the given credentials, and creates the underlying
56         // MYSQL struct, but does not connect yet.
57         SQLConnection(std::string thishost, std::string thisuser, std::string thispass, std::string thisdb, long myid)
58         {
59                 this->Enabled = true;
60                 this->host = thishost;
61                 this->user = thisuser;
62                 this->pass = thispass;
63                 this->db = thisdb;
64                 this->id = myid;
65                 unsigned int timeout = 1;
66                 mysql_init(&connection);
67                 mysql_options(&connection,MYSQL_OPT_CONNECT_TIMEOUT,(char*)&timeout);
68         }
69
70         // This method connects to the database using the credentials supplied to the constructor, and returns
71         // true upon success.
72         bool Connect()
73         {
74                 return mysql_real_connect(&connection, host.c_str(), user.c_str(), pass.c_str(), db.c_str(), 0, NULL, 0);
75         }
76
77         // This method issues a query that expects multiple rows of results. Use GetRow() and QueryDone() to retrieve
78         // multiple rows.
79         bool QueryResult(std::string query)
80         {
81                 int r = mysql_query(&connection, query.c_str());
82                 if (!r)
83                 {
84                         res = mysql_use_result(&connection);
85                 }
86                 return (!r);
87         }
88
89         // This method issues a query that just expects a number of 'effected' rows (e.g. UPDATE or DELETE FROM).
90         // the number of effected rows is returned in the return value.
91         unsigned long QueryCount(std::string query)
92         {
93                 int r = mysql_query(&connection, query.c_str());
94                 if (!r)
95                 {
96                         res = mysql_store_result(&connection);
97                         unsigned long rows = mysql_affected_rows(&connection);
98                         mysql_free_result(res);
99                         return rows;
100                 }
101                 return 0;
102         }
103
104         // This method fetches a row, if available from the database. You must issue a query
105         // using QueryResult() first! The row's values are returned as a map of std::string
106         // where each item is keyed by the column name.
107         std::map<std::string,std::string> GetRow()
108         {
109                 thisrow.clear();
110                 if (res)
111                 {
112                         row = mysql_fetch_row(res);
113                         if (row)
114                         {
115                                 unsigned int field_count = 0;
116                                 if(mysql_field_count(&connection) == 0)
117                                         return thisrow;
118                                 MYSQL_FIELD *fields = mysql_fetch_fields(res);
119                                 while (field_count < mysql_field_count(&connection))
120                                 {
121                                         thisrow[std::string(fields[field_count].name)] = std::string(row[field_count]);
122                                         field_count++;
123                                 }
124                                 return thisrow;
125                         }
126                 }
127                 return thisrow;
128         }
129
130         bool QueryDone()
131         {
132                 if (res)
133                 {
134                         mysql_free_result(res);
135                         return true;
136                 }
137                 else return false;
138         }
139
140         std::string GetError()
141         {
142                 return mysql_error(&connection);
143         }
144
145         long GetID()
146         {
147                 return id;
148         }
149
150         std::string GetHost()
151         {
152                 return host;
153         }
154
155         void Enable()
156         {
157                 Enabled = true;
158         }
159
160         void Disable()
161         {
162                 Enabled = false;
163         }
164
165         bool IsEnabled()
166         {
167                 return Enabled;
168         }
169
170 };
171
172 typedef std::vector<SQLConnection> ConnectionList;
173
174 class ModuleSQL : public Module
175 {
176         Server *Srv;
177         ConfigReader *Conf;
178         ConnectionList Connections;
179  
180  public:
181         void ConnectDatabases()
182         {
183                 for (ConnectionList::iterator i = Connections.begin(); i != Connections.end(); i++)
184                 {
185                         i->Enable();
186                         if (i->Connect())
187                         {
188                                 Srv->Log(DEFAULT,"SQL: Successfully connected database "+i->GetHost());
189                         }
190                         else
191                         {
192                                 Srv->Log(DEFAULT,"SQL: Failed to connect database "+i->GetHost()+": Error: "+i->GetError());
193                                 i->Disable();
194                         }
195                 }
196         }
197
198         void LoadDatabases(ConfigReader* ThisConf)
199         {
200                 Srv->Log(DEFAULT,"SQL: Loading database settings");
201                 Connections.clear();
202                 Srv->Log(DEBUG,"Cleared connections");
203                 for (int j =0; j < ThisConf->Enumerate("database"); j++)
204                 {
205                         std::string db = ThisConf->ReadValue("database","name",j);
206                         std::string user = ThisConf->ReadValue("database","username",j);
207                         std::string pass = ThisConf->ReadValue("database","password",j);
208                         std::string host = ThisConf->ReadValue("database","hostname",j);
209                         std::string id = ThisConf->ReadValue("database","id",j);
210                         Srv->Log(DEBUG,"Read database settings");
211                         if ((db != "") && (host != "") && (user != "") && (id != "") && (pass != ""))
212                         {
213                                 SQLConnection ThisSQL(host,user,pass,db,atoi(id.c_str()));
214                                 Srv->Log(DEFAULT,"Loaded database: "+ThisSQL.GetHost());
215                                 Connections.push_back(ThisSQL);
216                                 Srv->Log(DEBUG,"Pushed back connection");
217                         }
218                 }
219                 ConnectDatabases();
220         }
221
222         void ResultType(SQLRequest *r, SQLResult *res)
223         {
224                 for (ConnectionList::iterator i = Connections.begin(); i != Connections.end(); i++)
225                 {
226                         if ((i->GetID() == r->GetConnID()) && (i->IsEnabled()))
227                         {
228                                 bool xr = i->QueryResult(r->GetQuery());
229                                 if (!xr)
230                                 {
231                                         res->SetType(SQL_ERROR);
232                                         res->SetError(i->GetError());
233                                         return;
234                                 }
235                                 res->SetType(SQL_OK);
236                                 return;
237                         }
238                 }
239         }
240
241         void CountType(SQLRequest *r, SQLResult* res)
242         {
243                 for (ConnectionList::iterator i = Connections.begin(); i != Connections.end(); i++)
244                 {
245                         if ((i->GetID() == r->GetConnID()) && (i->IsEnabled()))
246                         {
247                                 res->SetType(SQL_COUNT);
248                                 res->SetCount(i->QueryCount(r->GetQuery()));
249                                 return;
250                         }
251                 }
252         }
253
254         void DoneType(SQLRequest *r, SQLResult* res)
255         {
256                 for (ConnectionList::iterator i = Connections.begin(); i != Connections.end(); i++)
257                 {
258                         if ((i->GetID() == r->GetConnID()) && (i->IsEnabled()))
259                         {
260                                 res->SetType(SQL_DONE);
261                                 if (!i->QueryDone())
262                                         res->SetType(SQL_ERROR);
263                         }
264                 }
265         }
266
267         void RowType(SQLRequest *r, SQLResult* res)
268         {
269                 for (ConnectionList::iterator i = Connections.begin(); i != Connections.end(); i++)
270                 {
271                         if ((i->GetID() == r->GetConnID()) && (i->IsEnabled()))
272                         {
273                                 log(DEBUG,"*** FOUND MATCHING ROW");
274                                 std::map<std::string,std::string> row = i->GetRow();
275                                 res->SetRow(row);
276                                 res->SetType(SQL_ROW);
277                                 if (!row.size())
278                                 {
279                                         log(DEBUG,"ROW SIZE IS 0");
280                                         res->SetType(SQL_END);
281                                 }
282                                 return;
283                         }
284                 }
285         }
286
287         char* OnRequest(Request* request)
288         {
289                 if (request)
290                 {
291                         SQLResult* Result = new SQLResult();
292                         SQLRequest *r = (SQLRequest*)request->GetData();
293                         switch (r->GetQueryType())
294                         {
295                                 case SQL_RESULT:
296                                         ResultType(r,Result);
297                                 break;
298                                 case SQL_COUNT:
299                                         CountType(r,Result);
300                                 break;
301                                 case SQL_ROW:
302                                         RowType(r,Result);
303                                 break;
304                                 case SQL_DONE:
305                                         DoneType(r,Result);
306                                 break;
307                         }
308                         return (char*)Result;
309                 }
310                 return NULL;
311         }
312
313         ModuleSQL()
314         {
315                 Srv = new Server();
316                 Conf = new ConfigReader();
317                 LoadDatabases(Conf);
318         }
319         
320         virtual ~ModuleSQL()
321         {
322                 Connections.clear();
323                 delete Conf;
324                 delete Srv;
325         }
326         
327         virtual void OnRehash()
328         {
329                 delete Conf;
330                 Conf = new ConfigReader();
331                 LoadDatabases(Conf);
332         }
333         
334         virtual Version GetVersion()
335         {
336                 return Version(1,0,0,0,VF_VENDOR|VF_SERVICEPROVIDER);
337         }
338         
339 };
340
341 // stuff down here is the module-factory stuff. For basic modules you can ignore this.
342
343 class ModuleSQLFactory : public ModuleFactory
344 {
345  public:
346         ModuleSQLFactory()
347         {
348         }
349         
350         ~ModuleSQLFactory()
351         {
352         }
353         
354         virtual Module * CreateModule()
355         {
356                 return new ModuleSQL;
357         }
358         
359 };
360
361
362 extern "C" void * init_module( void )
363 {
364         return new ModuleSQLFactory;
365 }
366