]> git.netwichtig.de Git - user/henk/code/inspircd.git/blob - src/modules/extra/m_sql.cpp
1b7862281f74c764483b08a9056a6b040886c2d2
[user/henk/code/inspircd.git] / src / modules / extra / m_sql.cpp
1 /*       +------------------------------------+
2  *       | Inspire Internet Relay Chat Daemon |
3  *       +------------------------------------+
4  *
5  *  Inspire is copyright (C) 2002-2004 ChatSpike-Dev.
6  *                       E-mail:
7  *                <brain@chatspike.net>
8  *                <Craig@chatspike.net>
9  *     
10  * Written by Craig Edwards, Craig McLure, and others.
11  * This program is free but copyrighted software; see
12  *            the file COPYING for details.
13  *
14  * ---------------------------------------------------
15  */
16
17
18 #include <stdio.h>
19 #include <string>
20 #include <mysql/mysql.h>
21 #include "users.h"
22 #include "channels.h"
23 #include "modules.h"
24 #include "m_sql.h"
25
/* $ModDesc: MySQL service provider for other SQL-using modules */
27 /* $CompileFlags: -I/usr/local/include -I/usr/include -L/usr/local/lib/mysql -L/usr/lib/mysql -lmysqlclient */
28
29 /** SQLConnection represents one mysql session.
30  * Each session has its own persistent connection to the database.
31  */
32 class SQLConnection
33 {
34  protected:
35
36         MYSQL connection;
37         MYSQL_RES *res;
38         MYSQL_ROW row;
39         std::string host;
40         std::string user;
41         std::string pass;
42         std::string db;
43         std::map<std::string,std::string> thisrow;
44         bool Enabled;
45         long id;
46
47  public:
48
49         // This constructor creates an SQLConnection object with the given credentials, and creates the underlying
50         // MYSQL struct, but does not connect yet.
51         SQLConnection(std::string thishost, std::string thisuser, std::string thispass, std::string thisdb, long myid)
52         {
53                 this->Enabled = true;
54                 this->host = thishost;
55                 this->user = thisuser;
56                 this->pass = thispass;
57                 this->db = thisdb;
58                 this->id = myid;
59                 mysql_init(&connection);
60         }
61
62         // This method connects to the database using the credentials supplied to the constructor, and returns
63         // true upon success.
64         bool Connect()
65         {
66                 return mysql_real_connect(&connection, host.c_str(), user.c_str(), pass.c_str(), db.c_str(), 0, NULL, 0);
67         }
68
69         // This method issues a query that expects multiple rows of results. Use GetRow() and QueryDone() to retrieve
70         // multiple rows.
71         bool QueryResult(std::string query)
72         {
73                 char escaped_query[query.length()+1];
74                 mysql_real_escape_string(&connection, escaped_query, query.c_str(), query.length());
75                 int r = mysql_query(&connection, escaped_query);
76                 if (!r)
77                 {
78                         res = mysql_use_result(&connection);
79                 }
80                 return (!r);
81         }
82
83         // This method issues a query that just expects a number of 'effected' rows (e.g. UPDATE or DELETE FROM).
84         // the number of effected rows is returned in the return value.
85         unsigned long QueryCount(std::string query)
86         {
87                 char escaped_query[query.length()+1];
88                 mysql_real_escape_string(&connection, escaped_query, query.c_str(), query.length());
89                 int r = mysql_query(&connection, escaped_query);
90                 if (!r)
91                 {
92                         res = mysql_store_result(&connection);
93                         unsigned long rows = mysql_affected_rows(&connection);
94                         mysql_free_result(res);
95                         return rows;
96                 }
97                 return 0;
98         }
99
100         // This method fetches a row, if available from the database. You must issue a query
101         // using QueryResult() first! The row's values are returned as a map of std::string
102         // where each item is keyed by the column name.
103         std::map<std::string,std::string> GetRow()
104         {
105                 thisrow.clear();
106                 if (res)
107                 {
108                         row = mysql_fetch_row(res);
109                         if (row)
110                         {
111                                 unsigned int field_count = 0;
112                                 if(mysql_field_count(&connection) == 0)
113                                         return thisrow;
114                                 MYSQL_FIELD *fields = mysql_fetch_fields(res);
115                                 while (field_count < mysql_field_count(&connection))
116                                 {
117                                         thisrow[std::string(fields[field_count].name)] = std::string(row[field_count]);
118                                         field_count++;
119                                 }
120                                 return thisrow;
121                         }
122                 }
123                 return thisrow;
124         }
125
126         bool QueryDone()
127         {
128                 if (res)
129                 {
130                         mysql_free_result(res);
131                         return true;
132                 }
133                 else return false;
134         }
135
136         std::string GetError()
137         {
138                 return mysql_error(&connection);
139         }
140
141         long GetID()
142         {
143                 return id;
144         }
145
146         std::string GetHost()
147         {
148                 return host;
149         }
150
151         void Enable()
152         {
153                 Enabled = true;
154         }
155
156         void Disable()
157         {
158                 Enabled = false;
159         }
160
161         bool IsEnabled()
162         {
163                 return Enabled;
164         }
165
166         ~SQLConnection()
167         {
168                 mysql_close(&connection);
169         }
170 };
171
172 typedef std::vector<SQLConnection> ConnectionList;
173
174 class ModuleSQL : public Module
175 {
176         Server *Srv;
177         ConfigReader *Conf;
178         ConnectionList Connections;
179  
180  public:
181         void ConnectDatabases()
182         {
183                 for (ConnectionList::iterator i = Connections.begin(); i != Connections.end(); i++)
184                 {
185                         i->Enable();
186                         if (i->Connect())
187                         {
188                                 Srv->Log(DEFAULT,"SQL: Successfully connected database "+i->GetHost());
189                         }
190                         else
191                         {
192                                 Srv->Log(DEFAULT,"SQL: Failed to connect database "+i->GetHost()+": Error: "+i->GetError());
193                                 i->Disable();
194                         }
195                 }
196         }
197
198         void LoadDatabases(ConfigReader* Conf)
199         {
200                 Srv->Log(DEFAULT,"SQL: Loading database settings");
201                 Connections.clear();
202                 for (int j =0; j < Conf->Enumerate("database"); j++)
203                 {
204                         std::string db = Conf->ReadValue("database","name",j);
205                         std::string user = Conf->ReadValue("database","username",j);
206                         std::string pass = Conf->ReadValue("database","password",j);
207                         std::string host = Conf->ReadValue("database","hostname",j);
208                         std::string id = Conf->ReadValue("database","id",j);
209                         if ((db != "") && (host != "") && (user != "") && (id != "") && (pass != ""))
210                         {
211                                 SQLConnection ThisSQL(host,user,pass,db,atoi(id.c_str()));
212                                 Srv->Log(DEFAULT,"Loaded database: "+ThisSQL.GetHost());
213                                 Connections.push_back(ThisSQL);
214                         }
215                 }
216                 ConnectDatabases();
217         }
218
219         void ResultType(SQLRequest *r, SQLResult *res)
220         {
221                 for (ConnectionList::iterator i = Connections.begin(); i != Connections.end(); i++)
222                 {
223                         if ((i->GetID() == r->GetConnID()) && (i->IsEnabled()))
224                         {
225                                 bool xr = i->QueryResult(r->GetQuery());
226                                 if (!xr)
227                                 {
228                                         res->SetType(SQL_ERROR);
229                                         res->SetError(i->GetError());
230                                         return;
231                                 }
232                                 res->SetType(SQL_OK);
233                                 return;
234                         }
235                 }
236         }
237
238         void CountType(SQLRequest *r, SQLResult* res)
239         {
240                 for (ConnectionList::iterator i = Connections.begin(); i != Connections.end(); i++)
241                 {
242                         if ((i->GetID() == r->GetConnID()) && (i->IsEnabled()))
243                         {
244                                 res->SetType(SQL_COUNT);
245                                 res->SetCount(i->QueryCount(r->GetQuery()));
246                                 return;
247                         }
248                 }
249         }
250
251         void DoneType(SQLRequest *r, SQLResult* res)
252         {
253                 for (ConnectionList::iterator i = Connections.begin(); i != Connections.end(); i++)
254                 {
255                         if ((i->GetID() == r->GetConnID()) && (i->IsEnabled()))
256                         {
257                                 res->SetType(SQL_DONE);
258                                 if (!i->QueryDone())
259                                         res->SetType(SQL_ERROR);
260                         }
261                 }
262         }
263
264         void RowType(SQLRequest *r, SQLResult* res)
265         {
266                 for (ConnectionList::iterator i = Connections.begin(); i != Connections.end(); i++)
267                 {
268                         if ((i->GetID() == r->GetConnID()) && (i->IsEnabled()))
269                         {
270                                 log(DEBUG,"*** FOUND MATCHING ROW");
271                                 std::map<std::string,std::string> row = i->GetRow();
272                                 res->SetRow(row);
273                                 res->SetType(SQL_ROW);
274                                 if (!row.size())
275                                 {
276                                         log(DEBUG,"ROW SIZE IS 0");
277                                         res->SetType(SQL_END);
278                                 }
279                                 return;
280                         }
281                 }
282         }
283
284         char* OnRequest(Request* request)
285         {
286                 if (request)
287                 {
288                         SQLResult* Result = new SQLResult();
289                         SQLRequest *r = (SQLRequest*)request->GetData();
290                         switch (r->GetQueryType())
291                         {
292                                 case SQL_RESULT:
293                                         ResultType(r,Result);
294                                 break;
295                                 case SQL_COUNT:
296                                         CountType(r,Result);
297                                 break;
298                                 case SQL_ROW:
299                                         RowType(r,Result);
300                                 break;
301                                 case SQL_DONE:
302                                         DoneType(r,Result);
303                                 break;
304                         }
305                         return (char*)Result;
306                 }
307                 return NULL;
308         }
309
310         ModuleSQL()
311         {
312                 Srv = new Server();
313                 Conf = new ConfigReader();
314                 LoadDatabases(Conf);
315         }
316         
317         virtual ~ModuleSQL()
318         {
319                 Connections.clear();
320                 delete Conf;
321                 delete Srv;
322         }
323         
324         virtual void OnRehash()
325         {
326                 delete Conf;
327                 Conf = new ConfigReader();
328                 LoadDatabases(Conf);
329         }
330         
331         virtual Version GetVersion()
332         {
333                 return Version(1,0,0,0,VF_VENDOR|VF_SERVICEPROVIDER);
334         }
335         
336 };
337
338 // stuff down here is the module-factory stuff. For basic modules you can ignore this.
339
340 class ModuleSQLFactory : public ModuleFactory
341 {
342  public:
343         ModuleSQLFactory()
344         {
345         }
346         
347         ~ModuleSQLFactory()
348         {
349         }
350         
351         virtual Module * CreateModule()
352         {
353                 return new ModuleSQL;
354         }
355         
356 };
357
358
359 extern "C" void * init_module( void )
360 {
361         return new ModuleSQLFactory;
362 }
363