]> git.netwichtig.de Git - user/henk/code/inspircd.git/blob - src/modules/extra/m_sql.cpp
Fixed crash-n-burn on rehash
[user/henk/code/inspircd.git] / src / modules / extra / m_sql.cpp
1 /*       +------------------------------------+
2  *       | Inspire Internet Relay Chat Daemon |
3  *       +------------------------------------+
4  *
5  *  Inspire is copyright (C) 2002-2004 ChatSpike-Dev.
6  *                       E-mail:
7  *                <brain@chatspike.net>
8  *                <Craig@chatspike.net>
9  *     
10  * Written by Craig Edwards, Craig McLure, and others.
11  * This program is free but copyrighted software; see
12  *            the file COPYING for details.
13  *
14  * ---------------------------------------------------
15  */
16
17
18 #include <stdio.h>
19 #include <string>
20 #include <mysql/mysql.h>
21 #include "users.h"
22 #include "channels.h"
23 #include "modules.h"
24 #include "m_sql.h"
25
26 /* $ModDesc: SQL Service Provider module for all other m_sql* modules */
27 /* $CompileFlags: -I/usr/local/include -I/usr/include -L/usr/local/lib/mysql -L/usr/lib/mysql -lmysqlclient */
28
29 /** SQLConnection represents one mysql session.
30  * Each session has its own persistent connection to the database.
31  */
32 class SQLConnection
33 {
34  protected:
35
36         MYSQL connection;
37         MYSQL_RES *res;
38         MYSQL_ROW row;
39         std::string host;
40         std::string user;
41         std::string pass;
42         std::string db;
43         std::map<std::string,std::string> thisrow;
44         bool Enabled;
45         long id;
46
47  public:
48
49         // This constructor creates an SQLConnection object with the given credentials, and creates the underlying
50         // MYSQL struct, but does not connect yet.
51         SQLConnection(std::string thishost, std::string thisuser, std::string thispass, std::string thisdb, long myid)
52         {
53                 this->Enabled = true;
54                 this->host = thishost;
55                 this->user = thisuser;
56                 this->pass = thispass;
57                 this->db = thisdb;
58                 this->id = myid;
59                 mysql_init(&connection);
60         }
61
62         // This method connects to the database using the credentials supplied to the constructor, and returns
63         // true upon success.
64         bool Connect()
65         {
66                 return mysql_real_connect(&connection, host.c_str(), user.c_str(), pass.c_str(), db.c_str(), 0, NULL, 0);
67         }
68
69         // This method issues a query that expects multiple rows of results. Use GetRow() and QueryDone() to retrieve
70         // multiple rows.
71         bool QueryResult(std::string query)
72         {
73                 int r = mysql_query(&connection, query.c_str());
74                 if (!r)
75                 {
76                         res = mysql_use_result(&connection);
77                 }
78                 return (!r);
79         }
80
81         // This method issues a query that just expects a number of 'effected' rows (e.g. UPDATE or DELETE FROM).
82         // the number of effected rows is returned in the return value.
83         unsigned long QueryCount(std::string query)
84         {
85                 int r = mysql_query(&connection, query.c_str());
86                 if (!r)
87                 {
88                         res = mysql_store_result(&connection);
89                         unsigned long rows = mysql_affected_rows(&connection);
90                         mysql_free_result(res);
91                         return rows;
92                 }
93                 return 0;
94         }
95
96         // This method fetches a row, if available from the database. You must issue a query
97         // using QueryResult() first! The row's values are returned as a map of std::string
98         // where each item is keyed by the column name.
99         std::map<std::string,std::string> GetRow()
100         {
101                 thisrow.clear();
102                 if (res)
103                 {
104                         row = mysql_fetch_row(res);
105                         if (row)
106                         {
107                                 unsigned int field_count = 0;
108                                 if(mysql_field_count(&connection) == 0)
109                                         return thisrow;
110                                 MYSQL_FIELD *fields = mysql_fetch_fields(res);
111                                 while (field_count < mysql_field_count(&connection))
112                                 {
113                                         thisrow[std::string(fields[field_count].name)] = std::string(row[field_count]);
114                                         field_count++;
115                                 }
116                                 return thisrow;
117                         }
118                 }
119                 return thisrow;
120         }
121
122         bool QueryDone()
123         {
124                 if (res)
125                 {
126                         mysql_free_result(res);
127                         return true;
128                 }
129                 else return false;
130         }
131
132         std::string GetError()
133         {
134                 return mysql_error(&connection);
135         }
136
137         long GetID()
138         {
139                 return id;
140         }
141
142         std::string GetHost()
143         {
144                 return host;
145         }
146
147         void Enable()
148         {
149                 Enabled = true;
150         }
151
152         void Disable()
153         {
154                 Enabled = false;
155         }
156
157         bool IsEnabled()
158         {
159                 return Enabled;
160         }
161
162 };
163
// A list of SQLConnection objects, held by value.
typedef std::vector<SQLConnection> ConnectionList;
165
166 class ModuleSQL : public Module
167 {
168         Server *Srv;
169         ConfigReader *Conf;
170         ConnectionList Connections;
171  
172  public:
173         void ConnectDatabases()
174         {
175                 for (ConnectionList::iterator i = Connections.begin(); i != Connections.end(); i++)
176                 {
177                         i->Enable();
178                         if (i->Connect())
179                         {
180                                 Srv->Log(DEFAULT,"SQL: Successfully connected database "+i->GetHost());
181                         }
182                         else
183                         {
184                                 Srv->Log(DEFAULT,"SQL: Failed to connect database "+i->GetHost()+": Error: "+i->GetError());
185                                 i->Disable();
186                         }
187                 }
188         }
189
190         void LoadDatabases(ConfigReader* Conf)
191         {
192                 Srv->Log(DEFAULT,"SQL: Loading database settings");
193                 Connections.clear();
194                 Srv->Log(DEBUG,"Cleared connections");
195                 for (int j =0; j < Conf->Enumerate("database"); j++)
196                 {
197                         std::string db = Conf->ReadValue("database","name",j);
198                         std::string user = Conf->ReadValue("database","username",j);
199                         std::string pass = Conf->ReadValue("database","password",j);
200                         std::string host = Conf->ReadValue("database","hostname",j);
201                         std::string id = Conf->ReadValue("database","id",j);
202                         Srv->Log(DEBUG,"Read database settings");
203                         if ((db != "") && (host != "") && (user != "") && (id != "") && (pass != ""))
204                         {
205                                 SQLConnection ThisSQL(host,user,pass,db,atoi(id.c_str()));
206                                 Srv->Log(DEFAULT,"Loaded database: "+ThisSQL.GetHost());
207                                 Connections.push_back(ThisSQL);
208                                 Srv->Log(DEBUG,"Pushed back connection");
209                         }
210                 }
211                 ConnectDatabases();
212         }
213
214         void ResultType(SQLRequest *r, SQLResult *res)
215         {
216                 for (ConnectionList::iterator i = Connections.begin(); i != Connections.end(); i++)
217                 {
218                         if ((i->GetID() == r->GetConnID()) && (i->IsEnabled()))
219                         {
220                                 bool xr = i->QueryResult(r->GetQuery());
221                                 if (!xr)
222                                 {
223                                         res->SetType(SQL_ERROR);
224                                         res->SetError(i->GetError());
225                                         return;
226                                 }
227                                 res->SetType(SQL_OK);
228                                 return;
229                         }
230                 }
231         }
232
233         void CountType(SQLRequest *r, SQLResult* res)
234         {
235                 for (ConnectionList::iterator i = Connections.begin(); i != Connections.end(); i++)
236                 {
237                         if ((i->GetID() == r->GetConnID()) && (i->IsEnabled()))
238                         {
239                                 res->SetType(SQL_COUNT);
240                                 res->SetCount(i->QueryCount(r->GetQuery()));
241                                 return;
242                         }
243                 }
244         }
245
246         void DoneType(SQLRequest *r, SQLResult* res)
247         {
248                 for (ConnectionList::iterator i = Connections.begin(); i != Connections.end(); i++)
249                 {
250                         if ((i->GetID() == r->GetConnID()) && (i->IsEnabled()))
251                         {
252                                 res->SetType(SQL_DONE);
253                                 if (!i->QueryDone())
254                                         res->SetType(SQL_ERROR);
255                         }
256                 }
257         }
258
259         void RowType(SQLRequest *r, SQLResult* res)
260         {
261                 for (ConnectionList::iterator i = Connections.begin(); i != Connections.end(); i++)
262                 {
263                         if ((i->GetID() == r->GetConnID()) && (i->IsEnabled()))
264                         {
265                                 log(DEBUG,"*** FOUND MATCHING ROW");
266                                 std::map<std::string,std::string> row = i->GetRow();
267                                 res->SetRow(row);
268                                 res->SetType(SQL_ROW);
269                                 if (!row.size())
270                                 {
271                                         log(DEBUG,"ROW SIZE IS 0");
272                                         res->SetType(SQL_END);
273                                 }
274                                 return;
275                         }
276                 }
277         }
278
279         char* OnRequest(Request* request)
280         {
281                 if (request)
282                 {
283                         SQLResult* Result = new SQLResult();
284                         SQLRequest *r = (SQLRequest*)request->GetData();
285                         switch (r->GetQueryType())
286                         {
287                                 case SQL_RESULT:
288                                         ResultType(r,Result);
289                                 break;
290                                 case SQL_COUNT:
291                                         CountType(r,Result);
292                                 break;
293                                 case SQL_ROW:
294                                         RowType(r,Result);
295                                 break;
296                                 case SQL_DONE:
297                                         DoneType(r,Result);
298                                 break;
299                         }
300                         return (char*)Result;
301                 }
302                 return NULL;
303         }
304
305         ModuleSQL()
306         {
307                 Srv = new Server();
308                 Conf = new ConfigReader();
309                 LoadDatabases(Conf);
310         }
311         
312         virtual ~ModuleSQL()
313         {
314                 Connections.clear();
315                 delete Conf;
316                 delete Srv;
317         }
318         
319         virtual void OnRehash()
320         {
321                 delete Conf;
322                 Conf = new ConfigReader();
323                 LoadDatabases(Conf);
324         }
325         
326         virtual Version GetVersion()
327         {
328                 return Version(1,0,0,0,VF_VENDOR|VF_SERVICEPROVIDER);
329         }
330         
331 };
332
333 // stuff down here is the module-factory stuff. For basic modules you can ignore this.
334
335 class ModuleSQLFactory : public ModuleFactory
336 {
337  public:
338         ModuleSQLFactory()
339         {
340         }
341         
342         ~ModuleSQLFactory()
343         {
344         }
345         
346         virtual Module * CreateModule()
347         {
348                 return new ModuleSQL;
349         }
350         
351 };
352
353
354 extern "C" void * init_module( void )
355 {
356         return new ModuleSQLFactory;
357 }
358