]> git.netwichtig.de Git - user/henk/code/inspircd.git/blob - src/modules/extra/m_sql.cpp
Added header for m_sql with inherited Request class
[user/henk/code/inspircd.git] / src / modules / extra / m_sql.cpp
1 /*       +------------------------------------+
2  *       | Inspire Internet Relay Chat Daemon |
3  *       +------------------------------------+
4  *
5  *  Inspire is copyright (C) 2002-2004 ChatSpike-Dev.
6  *                       E-mail:
7  *                <brain@chatspike.net>
8  *                <Craig@chatspike.net>
9  *     
10  * Written by Craig Edwards, Craig McLure, and others.
11  * This program is free but copyrighted software; see
12  *            the file COPYING for details.
13  *
14  * ---------------------------------------------------
15  */
16
17
18 #include <stdio.h>
19 #include <string>
20 #include <mysql/mysql.h>
21 #include "users.h"
22 #include "channels.h"
23 #include "modules.h"
24 #include "m_sql.h"
25
26 /* $ModDesc: SQL service provider (MySQL backend) for other modules */
27 /* $CompileFlags: -I/usr/local/include -I/usr/include -L/usr/local/lib/mysql -L/usr/lib/mysql -lmysqlclient */
28
29 /** SQLConnection represents one mysql session.
30  * Each session has its own persistent connection to the database.
31  */
32 class SQLConnection
33 {
34  protected:
35
36         MYSQL connection;
37         MYSQL_RES *res;
38         MYSQL_ROW row;
39         std::string host;
40         std::string user;
41         std::string pass;
42         std::string db;
43         std::map<std::string,std::string> thisrow;
44         bool Enabled;
45         long id;
46
47  public:
48
49         // This constructor creates an SQLConnection object with the given credentials, and creates the underlying
50         // MYSQL struct, but does not connect yet.
51         SQLConnection(std::string thishost, std::string thisuser, std::string thispass, std::string thisdb, long myid)
52         {
53                 this->Enabled = true;
54                 this->host = thishost;
55                 this->user = thisuser;
56                 this->pass = thispass;
57                 this->db = thisdb;
58                 this->id = myid;
59                 mysql_init(&connection);
60         }
61
62         // This method connects to the database using the credentials supplied to the constructor, and returns
63         // true upon success.
64         bool Connect()
65         {
66                 return mysql_real_connect(&connection, host.c_str(), user.c_str(), pass.c_str(), db.c_str(), 0, NULL, 0);
67         }
68
69         // This method issues a query that expects multiple rows of results. Use GetRow() and QueryDone() to retrieve
70         // multiple rows.
71         bool QueryResult(std::string query)
72         {
73                 char escaped_query[query.length()+1];
74                 mysql_real_escape_string(&connection, escaped_query, query.c_str(), query.length());
75                 int r = mysql_query(&connection, escaped_query);
76                 if (!r)
77                 {
78                         res = mysql_use_result(&connection);
79                 }
80                 return (!r);
81         }
82
83         // This method issues a query that just expects a number of 'effected' rows (e.g. UPDATE or DELETE FROM).
84         // the number of effected rows is returned in the return value.
85         unsigned long QueryCount(std::string query)
86         {
87                 char escaped_query[query.length()+1];
88                 mysql_real_escape_string(&connection, escaped_query, query.c_str(), query.length());
89                 int r = mysql_query(&connection, escaped_query);
90                 if (!r)
91                 {
92                         res = mysql_store_result(&connection);
93                         unsigned long rows = mysql_affected_rows(&connection);
94                         mysql_free_result(res);
95                         return rows;
96                 }
97                 return 0;
98         }
99
100         // This method fetches a row, if available from the database. You must issue a query
101         // using QueryResult() first! The row's values are returned as a map of std::string
102         // where each item is keyed by the column name.
103         std::map<std::string,std::string> GetRow()
104         {
105                 thisrow.clear();
106                 if (res)
107                 {
108                         row = mysql_fetch_row(res);
109                         if (row)
110                         {
111                                 unsigned int field_count = 0;
112                                 if(mysql_field_count(&connection) == 0)
113                                         return thisrow;
114                                 MYSQL_FIELD *fields = mysql_fetch_fields(res);
115                                 while (field_count < mysql_field_count(&connection))
116                                 {
117                                         thisrow[std::string(fields[field_count].name)] = std::string(row[field_count]);
118                                         field_count++;
119                                 }
120                                 return thisrow;
121                         }
122                 }
123                 return thisrow;
124         }
125
126         bool QueryDone()
127         {
128                 if (res)
129                 {
130                         mysql_free_result(res);
131                         return true;
132                 }
133                 else return false;
134         }
135
136         std::string GetError()
137         {
138                 return mysql_error(&connection);
139         }
140
141         long GetID()
142         {
143                 return id;
144         }
145
146         std::string GetHost()
147         {
148                 return host;
149         }
150
151         void Enable()
152         {
153                 Enabled = true;
154         }
155
156         void Disable()
157         {
158                 Enabled = false;
159         }
160
161         bool IsEnabled()
162         {
163                 return Enabled;
164         }
165
166         ~SQLConnection()
167         {
168                 mysql_close(&connection);
169         }
170 };
171
172 typedef std::vector<SQLConnection> ConnectionList;
173
174 class ModuleSQL : public Module
175 {
176         Server *Srv;
177         ConfigReader *Conf;
178         ConnectionList Connections;
179  
180  public:
181         void ConnectDatabases()
182         {
183                 for (ConnectionList::iterator i = Connections.begin(); i != Connections.end(); i++)
184                 {
185                         i->Enable();
186                         if (i->Connect())
187                         {
188                                 Srv->Log(DEFAULT,"SQL: Successfully connected database "+i->GetHost());
189                         }
190                         else
191                         {
192                                 Srv->Log(DEFAULT,"SQL: Failed to connect database "+i->GetHost()+": Error: "+i->GetError());
193                                 i->Disable();
194                         }
195                 }
196         }
197
198         void LoadDatabases(ConfigReader* Conf)
199         {
200                 Srv->Log(DEFAULT,"SQL: Loading database settings");
201                 Connections.clear();
202                 for (int j =0; j < Conf->Enumerate("database"); j++)
203                 {
204                         std::string db = Conf->ReadValue("database","name",j);
205                         std::string user = Conf->ReadValue("database","username",j);
206                         std::string pass = Conf->ReadValue("database","password",j);
207                         std::string host = Conf->ReadValue("database","hostname",j);
208                         std::string id = Conf->ReadValue("database","id",j);
209                         if ((db != "") && (host != "") && (user != "") && (id != "") && (pass != ""))
210                         {
211                                 SQLConnection ThisSQL(host,user,pass,db,atoi(id.c_str()));
212                                 Srv->Log(DEFAULT,"Loaded database: "+ThisSQL.GetHost());
213                                 Connections.push_back(ThisSQL);
214                         }
215                 }
216                 ConnectDatabases();
217         }
218
219         void ResultType(SQLRequest *r, SQLResult *res)
220         {
221                 for (ConnectionList::iterator i = Connections.begin(); i != Connections.end(); i++)
222                 {
223                         if ((i->GetID() == r->GetConnID()) && (i->Enabled()))
224                         {
225                                 bool xr = i->QueryResult(r->GetQuery());
226                                 if (!xr)
227                                 {
228                                         res->SetType(SQL_ERROR);
229                                         res->SetError(r->GetError());
230                                         return;
231                                 }
232                         }
233                 }
234         }
235
236         void CountType(SQLRequest *r, SQLResult* res)
237         {
238                 for (ConnectionList::iterator i = Connections.begin(); i != Connections.end(); i++)
239                 {
240                         if ((i->GetID() == r->GetConnID()) && (i->Enabled()))
241                         {
242                                 res->SetType(SQL_COUNT);
243                                 res->SetCount(i->QueryCount(r->GetQuery()));
244                                 return;
245                         }
246                 }
247         }
248
249         void RowType(SQLRequest *r, SQLResult* res)
250         {
251                 for (ConnectionList::iterator i = Connections.begin(); i != Connections.end(); i++)
252                 {
253                         if ((i->GetID() == r->GetConnID()) && (i->Enabled()))
254                         {
255                                 std::map<std::string,std::string> row = i->GetRow();
256                                 res->SetRow(row);
257                                 res->SetType(SQL_ROW);
258                                 if (!row.size())
259                                         res->SetType(SQL_END);
260                                 return;
261                         }
262                 }
263         }
264
265         char* OnRequest(Request* request)
266         {
267                 SQLResult Result = new SQLResult();
268                 SQLRequest *r = (SQLRequest*)request;
269                 switch (r->GetRequest())
270                 {
271                         case SQL_RESULT:
272                                 ResultType(r,Result);
273                         break;
274                         case SQL_COUNT:
275                                 CountType(r,Result);
276                         break;
277                         case SQL_ROW:
278                                 RowType(r,Result);
279                         break;
280                 }
281                 return Result;
282         }
283
284         ModuleSQL()
285         {
286                 Srv = new Server();
287                 Conf = new ConfigReader();
288                 LoadDatabases(Conf);
289         }
290         
291         virtual ~ModuleSQL()
292         {
293                 Connections.clear();
294                 delete Conf;
295                 delete Srv;
296         }
297         
298         virtual void OnRehash()
299         {
300                 delete Conf;
301                 Conf = new ConfigReader();
302                 LoadDatabases(Conf);
303         }
304         
305         virtual Version GetVersion()
306         {
307                 return Version(1,0,0,0,VF_VENDOR|VF_SERVICEPROVIDER);
308         }
309         
310 };
311
312 // stuff down here is the module-factory stuff. For basic modules you can ignore this.
313
314 class ModuleSQLFactory : public ModuleFactory
315 {
316  public:
317         ModuleSQLFactory()
318         {
319         }
320         
321         ~ModuleSQLFactory()
322         {
323         }
324         
325         virtual Module * CreateModule()
326         {
327                 return new ModuleSQL;
328         }
329         
330 };
331
332
333 extern "C" void * init_module( void )
334 {
335         return new ModuleSQLFactory;
336 }
337