]> git.netwichtig.de Git - user/henk/code/inspircd.git/blob - src/modules/extra/m_sql.cpp
Changed description comment
[user/henk/code/inspircd.git] / src / modules / extra / m_sql.cpp
1 /*       +------------------------------------+
2  *       | Inspire Internet Relay Chat Daemon |
3  *       +------------------------------------+
4  *
5  *  Inspire is copyright (C) 2002-2004 ChatSpike-Dev.
6  *                       E-mail:
7  *                <brain@chatspike.net>
8  *                <Craig@chatspike.net>
9  *     
10  * Written by Craig Edwards, Craig McLure, and others.
11  * This program is free but copyrighted software; see
12  *            the file COPYING for details.
13  *
14  * ---------------------------------------------------
15  */
16
17
18 #include <stdio.h>
19 #include <string>
20 #include <mysql/mysql.h>
21 #include "users.h"
22 #include "channels.h"
23 #include "modules.h"
24 #include "m_sql.h"
25
26 /* $ModDesc: SQL Service Provider module for all other m_sql* modules */
27 /* $CompileFlags: -I/usr/local/include -I/usr/include -L/usr/local/lib/mysql -L/usr/lib/mysql -lmysqlclient */
28
29 /** SQLConnection represents one mysql session.
30  * Each session has its own persistent connection to the database.
31  */
32 class SQLConnection
33 {
34  protected:
35
36         MYSQL connection;
37         MYSQL_RES *res;
38         MYSQL_ROW row;
39         std::string host;
40         std::string user;
41         std::string pass;
42         std::string db;
43         std::map<std::string,std::string> thisrow;
44         bool Enabled;
45         long id;
46
47  public:
48
49         // This constructor creates an SQLConnection object with the given credentials, and creates the underlying
50         // MYSQL struct, but does not connect yet.
51         SQLConnection(std::string thishost, std::string thisuser, std::string thispass, std::string thisdb, long myid)
52         {
53                 this->Enabled = true;
54                 this->host = thishost;
55                 this->user = thisuser;
56                 this->pass = thispass;
57                 this->db = thisdb;
58                 this->id = myid;
59                 mysql_init(&connection);
60         }
61
62         // This method connects to the database using the credentials supplied to the constructor, and returns
63         // true upon success.
64         bool Connect()
65         {
66                 return mysql_real_connect(&connection, host.c_str(), user.c_str(), pass.c_str(), db.c_str(), 0, NULL, 0);
67         }
68
69         // This method issues a query that expects multiple rows of results. Use GetRow() and QueryDone() to retrieve
70         // multiple rows.
71         bool QueryResult(std::string query)
72         {
73                 int r = mysql_query(&connection, query.c_str());
74                 if (!r)
75                 {
76                         res = mysql_use_result(&connection);
77                 }
78                 return (!r);
79         }
80
81         // This method issues a query that just expects a number of 'effected' rows (e.g. UPDATE or DELETE FROM).
82         // the number of effected rows is returned in the return value.
83         unsigned long QueryCount(std::string query)
84         {
85                 int r = mysql_query(&connection, query.c_str());
86                 if (!r)
87                 {
88                         res = mysql_store_result(&connection);
89                         unsigned long rows = mysql_affected_rows(&connection);
90                         mysql_free_result(res);
91                         return rows;
92                 }
93                 return 0;
94         }
95
96         // This method fetches a row, if available from the database. You must issue a query
97         // using QueryResult() first! The row's values are returned as a map of std::string
98         // where each item is keyed by the column name.
99         std::map<std::string,std::string> GetRow()
100         {
101                 thisrow.clear();
102                 if (res)
103                 {
104                         row = mysql_fetch_row(res);
105                         if (row)
106                         {
107                                 unsigned int field_count = 0;
108                                 if(mysql_field_count(&connection) == 0)
109                                         return thisrow;
110                                 MYSQL_FIELD *fields = mysql_fetch_fields(res);
111                                 while (field_count < mysql_field_count(&connection))
112                                 {
113                                         thisrow[std::string(fields[field_count].name)] = std::string(row[field_count]);
114                                         field_count++;
115                                 }
116                                 return thisrow;
117                         }
118                 }
119                 return thisrow;
120         }
121
122         bool QueryDone()
123         {
124                 if (res)
125                 {
126                         mysql_free_result(res);
127                         return true;
128                 }
129                 else return false;
130         }
131
132         std::string GetError()
133         {
134                 return mysql_error(&connection);
135         }
136
137         long GetID()
138         {
139                 return id;
140         }
141
142         std::string GetHost()
143         {
144                 return host;
145         }
146
147         void Enable()
148         {
149                 Enabled = true;
150         }
151
152         void Disable()
153         {
154                 Enabled = false;
155         }
156
157         bool IsEnabled()
158         {
159                 return Enabled;
160         }
161
162         ~SQLConnection()
163         {
164                 mysql_close(&connection);
165         }
166 };
167
168 typedef std::vector<SQLConnection> ConnectionList;
169
170 class ModuleSQL : public Module
171 {
172         Server *Srv;
173         ConfigReader *Conf;
174         ConnectionList Connections;
175  
176  public:
177         void ConnectDatabases()
178         {
179                 for (ConnectionList::iterator i = Connections.begin(); i != Connections.end(); i++)
180                 {
181                         i->Enable();
182                         if (i->Connect())
183                         {
184                                 Srv->Log(DEFAULT,"SQL: Successfully connected database "+i->GetHost());
185                         }
186                         else
187                         {
188                                 Srv->Log(DEFAULT,"SQL: Failed to connect database "+i->GetHost()+": Error: "+i->GetError());
189                                 i->Disable();
190                         }
191                 }
192         }
193
194         void LoadDatabases(ConfigReader* Conf)
195         {
196                 Srv->Log(DEFAULT,"SQL: Loading database settings");
197                 Connections.clear();
198                 for (int j =0; j < Conf->Enumerate("database"); j++)
199                 {
200                         std::string db = Conf->ReadValue("database","name",j);
201                         std::string user = Conf->ReadValue("database","username",j);
202                         std::string pass = Conf->ReadValue("database","password",j);
203                         std::string host = Conf->ReadValue("database","hostname",j);
204                         std::string id = Conf->ReadValue("database","id",j);
205                         if ((db != "") && (host != "") && (user != "") && (id != "") && (pass != ""))
206                         {
207                                 SQLConnection ThisSQL(host,user,pass,db,atoi(id.c_str()));
208                                 Srv->Log(DEFAULT,"Loaded database: "+ThisSQL.GetHost());
209                                 Connections.push_back(ThisSQL);
210                         }
211                 }
212                 ConnectDatabases();
213         }
214
215         void ResultType(SQLRequest *r, SQLResult *res)
216         {
217                 for (ConnectionList::iterator i = Connections.begin(); i != Connections.end(); i++)
218                 {
219                         if ((i->GetID() == r->GetConnID()) && (i->IsEnabled()))
220                         {
221                                 bool xr = i->QueryResult(r->GetQuery());
222                                 if (!xr)
223                                 {
224                                         res->SetType(SQL_ERROR);
225                                         res->SetError(i->GetError());
226                                         return;
227                                 }
228                                 res->SetType(SQL_OK);
229                                 return;
230                         }
231                 }
232         }
233
234         void CountType(SQLRequest *r, SQLResult* res)
235         {
236                 for (ConnectionList::iterator i = Connections.begin(); i != Connections.end(); i++)
237                 {
238                         if ((i->GetID() == r->GetConnID()) && (i->IsEnabled()))
239                         {
240                                 res->SetType(SQL_COUNT);
241                                 res->SetCount(i->QueryCount(r->GetQuery()));
242                                 return;
243                         }
244                 }
245         }
246
247         void DoneType(SQLRequest *r, SQLResult* res)
248         {
249                 for (ConnectionList::iterator i = Connections.begin(); i != Connections.end(); i++)
250                 {
251                         if ((i->GetID() == r->GetConnID()) && (i->IsEnabled()))
252                         {
253                                 res->SetType(SQL_DONE);
254                                 if (!i->QueryDone())
255                                         res->SetType(SQL_ERROR);
256                         }
257                 }
258         }
259
260         void RowType(SQLRequest *r, SQLResult* res)
261         {
262                 for (ConnectionList::iterator i = Connections.begin(); i != Connections.end(); i++)
263                 {
264                         if ((i->GetID() == r->GetConnID()) && (i->IsEnabled()))
265                         {
266                                 log(DEBUG,"*** FOUND MATCHING ROW");
267                                 std::map<std::string,std::string> row = i->GetRow();
268                                 res->SetRow(row);
269                                 res->SetType(SQL_ROW);
270                                 if (!row.size())
271                                 {
272                                         log(DEBUG,"ROW SIZE IS 0");
273                                         res->SetType(SQL_END);
274                                 }
275                                 return;
276                         }
277                 }
278         }
279
280         char* OnRequest(Request* request)
281         {
282                 if (request)
283                 {
284                         SQLResult* Result = new SQLResult();
285                         SQLRequest *r = (SQLRequest*)request->GetData();
286                         switch (r->GetQueryType())
287                         {
288                                 case SQL_RESULT:
289                                         ResultType(r,Result);
290                                 break;
291                                 case SQL_COUNT:
292                                         CountType(r,Result);
293                                 break;
294                                 case SQL_ROW:
295                                         RowType(r,Result);
296                                 break;
297                                 case SQL_DONE:
298                                         DoneType(r,Result);
299                                 break;
300                         }
301                         return (char*)Result;
302                 }
303                 return NULL;
304         }
305
306         ModuleSQL()
307         {
308                 Srv = new Server();
309                 Conf = new ConfigReader();
310                 LoadDatabases(Conf);
311         }
312         
313         virtual ~ModuleSQL()
314         {
315                 Connections.clear();
316                 delete Conf;
317                 delete Srv;
318         }
319         
320         virtual void OnRehash()
321         {
322                 delete Conf;
323                 Conf = new ConfigReader();
324                 LoadDatabases(Conf);
325         }
326         
327         virtual Version GetVersion()
328         {
329                 return Version(1,0,0,0,VF_VENDOR|VF_SERVICEPROVIDER);
330         }
331         
332 };
333
334 // stuff down here is the module-factory stuff. For basic modules you can ignore this.
335
336 class ModuleSQLFactory : public ModuleFactory
337 {
338  public:
339         ModuleSQLFactory()
340         {
341         }
342         
343         ~ModuleSQLFactory()
344         {
345         }
346         
347         virtual Module * CreateModule()
348         {
349                 return new ModuleSQL;
350         }
351         
352 };
353
354
355 extern "C" void * init_module( void )
356 {
357         return new ModuleSQLFactory;
358 }
359