]> git.netwichtig.de Git - user/henk/code/inspircd.git/blob - src/modules/extra/m_sql.cpp
Fixed removal of mesh linking from core
[user/henk/code/inspircd.git] / src / modules / extra / m_sql.cpp
1 /*       +------------------------------------+
2  *       | Inspire Internet Relay Chat Daemon |
3  *       +------------------------------------+
4  *
5  *  Inspire is copyright (C) 2002-2004 ChatSpike-Dev.
6  *                       E-mail:
7  *                <brain@chatspike.net>
8  *                <Craig@chatspike.net>
9  *     
10  * Written by Craig Edwards, Craig McLure, and others.
11  * This program is free but copyrighted software; see
12  *            the file COPYING for details.
13  *
14  * ---------------------------------------------------
15  */
16
17 using namespace std;
18
19 #include <stdio.h>
20 #include <string>
21 #include <mysql.h>
22 #include "users.h"
23 #include "channels.h"
24 #include "modules.h"
25 #include "helperfuncs.h"
26 #include "m_sql.h"
27
28 /* $ModDesc: SQL Service Provider module for all other m_sql* modules */
29 /* $CompileFlags: -I/usr/local/include/mysql -I/usr/include/mysql -I/usr/local/include -I/usr/include -L/usr/local/lib/mysql -L/usr/lib/mysql -L/usr/local/lib -lmysqlclient */
30
31 /** SQLConnection represents one mysql session.
32  * Each session has its own persistent connection to the database.
33  */
34
35 #if !defined(MYSQL_VERSION_ID) || MYSQL_VERSION_ID<32224
36 #define mysql_field_count mysql_num_fields
37 #endif
38
39 class SQLConnection
40 {
41  protected:
42
43         MYSQL connection;
44         MYSQL_RES *res;
45         MYSQL_ROW row;
46         std::string host;
47         std::string user;
48         std::string pass;
49         std::string db;
50         std::map<std::string,std::string> thisrow;
51         bool Enabled;
52         long id;
53
54  public:
55
56         // This constructor creates an SQLConnection object with the given credentials, and creates the underlying
57         // MYSQL struct, but does not connect yet.
58         SQLConnection(std::string thishost, std::string thisuser, std::string thispass, std::string thisdb, long myid)
59         {
60                 this->Enabled = true;
61                 this->host = thishost;
62                 this->user = thisuser;
63                 this->pass = thispass;
64                 this->db = thisdb;
65                 this->id = myid;
66                 unsigned int timeout = 1;
67                 mysql_init(&connection);
68                 mysql_options(&connection,MYSQL_OPT_CONNECT_TIMEOUT,(char*)&timeout);
69         }
70
71         // This method connects to the database using the credentials supplied to the constructor, and returns
72         // true upon success.
73         bool Connect()
74         {
75                 return mysql_real_connect(&connection, host.c_str(), user.c_str(), pass.c_str(), db.c_str(), 0, NULL, 0);
76         }
77
78         // This method issues a query that expects multiple rows of results. Use GetRow() and QueryDone() to retrieve
79         // multiple rows.
80         bool QueryResult(std::string query)
81         {
82                 int r = mysql_query(&connection, query.c_str());
83                 if (!r)
84                 {
85                         res = mysql_use_result(&connection);
86                 }
87                 return (!r);
88         }
89
90         // This method issues a query that just expects a number of 'effected' rows (e.g. UPDATE or DELETE FROM).
91         // the number of effected rows is returned in the return value.
92         unsigned long QueryCount(std::string query)
93         {
94                 int r = mysql_query(&connection, query.c_str());
95                 if (!r)
96                 {
97                         res = mysql_store_result(&connection);
98                         unsigned long rows = mysql_affected_rows(&connection);
99                         mysql_free_result(res);
100                         return rows;
101                 }
102                 return 0;
103         }
104
105         // This method fetches a row, if available from the database. You must issue a query
106         // using QueryResult() first! The row's values are returned as a map of std::string
107         // where each item is keyed by the column name.
108         std::map<std::string,std::string> GetRow()
109         {
110                 thisrow.clear();
111                 if (res)
112                 {
113                         row = mysql_fetch_row(res);
114                         if (row)
115                         {
116                                 unsigned int field_count = 0;
117                                 if(mysql_field_count(&connection) == 0)
118                                         return thisrow;
119                                 MYSQL_FIELD *fields = mysql_fetch_fields(res);
120                                 while (field_count < mysql_field_count(&connection))
121                                 {
122                                         thisrow[std::string(fields[field_count].name)] = std::string(row[field_count]);
123                                         field_count++;
124                                 }
125                                 return thisrow;
126                         }
127                 }
128                 return thisrow;
129         }
130
131         bool QueryDone()
132         {
133                 if (res)
134                 {
135                         mysql_free_result(res);
136                         return true;
137                 }
138                 else return false;
139         }
140
141         std::string GetError()
142         {
143                 return mysql_error(&connection);
144         }
145
146         long GetID()
147         {
148                 return id;
149         }
150
151         std::string GetHost()
152         {
153                 return host;
154         }
155
156         void Enable()
157         {
158                 Enabled = true;
159         }
160
161         void Disable()
162         {
163                 Enabled = false;
164         }
165
166         bool IsEnabled()
167         {
168                 return Enabled;
169         }
170
171 };
172
173 typedef std::vector<SQLConnection> ConnectionList;
174
175 class ModuleSQL : public Module
176 {
177         Server *Srv;
178         ConfigReader *Conf;
179         ConnectionList Connections;
180  
181  public:
182         void ConnectDatabases()
183         {
184                 for (ConnectionList::iterator i = Connections.begin(); i != Connections.end(); i++)
185                 {
186                         i->Enable();
187                         if (i->Connect())
188                         {
189                                 Srv->Log(DEFAULT,"SQL: Successfully connected database "+i->GetHost());
190                         }
191                         else
192                         {
193                                 Srv->Log(DEFAULT,"SQL: Failed to connect database "+i->GetHost()+": Error: "+i->GetError());
194                                 i->Disable();
195                         }
196                 }
197         }
198
199         void LoadDatabases(ConfigReader* ThisConf)
200         {
201                 Srv->Log(DEFAULT,"SQL: Loading database settings");
202                 Connections.clear();
203                 Srv->Log(DEBUG,"Cleared connections");
204                 for (int j =0; j < ThisConf->Enumerate("database"); j++)
205                 {
206                         std::string db = ThisConf->ReadValue("database","name",j);
207                         std::string user = ThisConf->ReadValue("database","username",j);
208                         std::string pass = ThisConf->ReadValue("database","password",j);
209                         std::string host = ThisConf->ReadValue("database","hostname",j);
210                         std::string id = ThisConf->ReadValue("database","id",j);
211                         Srv->Log(DEBUG,"Read database settings");
212                         if ((db != "") && (host != "") && (user != "") && (id != "") && (pass != ""))
213                         {
214                                 SQLConnection ThisSQL(host,user,pass,db,atoi(id.c_str()));
215                                 Srv->Log(DEFAULT,"Loaded database: "+ThisSQL.GetHost());
216                                 Connections.push_back(ThisSQL);
217                                 Srv->Log(DEBUG,"Pushed back connection");
218                         }
219                 }
220                 ConnectDatabases();
221         }
222
223         void ResultType(SQLRequest *r, SQLResult *res)
224         {
225                 for (ConnectionList::iterator i = Connections.begin(); i != Connections.end(); i++)
226                 {
227                         if ((i->GetID() == r->GetConnID()) && (i->IsEnabled()))
228                         {
229                                 bool xr = i->QueryResult(r->GetQuery());
230                                 if (!xr)
231                                 {
232                                         res->SetType(SQL_ERROR);
233                                         res->SetError(i->GetError());
234                                         return;
235                                 }
236                                 res->SetType(SQL_OK);
237                                 return;
238                         }
239                 }
240         }
241
242         void CountType(SQLRequest *r, SQLResult* res)
243         {
244                 for (ConnectionList::iterator i = Connections.begin(); i != Connections.end(); i++)
245                 {
246                         if ((i->GetID() == r->GetConnID()) && (i->IsEnabled()))
247                         {
248                                 res->SetType(SQL_COUNT);
249                                 res->SetCount(i->QueryCount(r->GetQuery()));
250                                 return;
251                         }
252                 }
253         }
254
255         void DoneType(SQLRequest *r, SQLResult* res)
256         {
257                 for (ConnectionList::iterator i = Connections.begin(); i != Connections.end(); i++)
258                 {
259                         if ((i->GetID() == r->GetConnID()) && (i->IsEnabled()))
260                         {
261                                 res->SetType(SQL_DONE);
262                                 if (!i->QueryDone())
263                                         res->SetType(SQL_ERROR);
264                         }
265                 }
266         }
267
268         void RowType(SQLRequest *r, SQLResult* res)
269         {
270                 for (ConnectionList::iterator i = Connections.begin(); i != Connections.end(); i++)
271                 {
272                         if ((i->GetID() == r->GetConnID()) && (i->IsEnabled()))
273                         {
274                                 log(DEBUG,"*** FOUND MATCHING ROW");
275                                 std::map<std::string,std::string> row = i->GetRow();
276                                 res->SetRow(row);
277                                 res->SetType(SQL_ROW);
278                                 if (!row.size())
279                                 {
280                                         log(DEBUG,"ROW SIZE IS 0");
281                                         res->SetType(SQL_END);
282                                 }
283                                 return;
284                         }
285                 }
286         }
287
288         char* OnRequest(Request* request)
289         {
290                 if (request)
291                 {
292                         SQLResult* Result = new SQLResult();
293                         SQLRequest *r = (SQLRequest*)request->GetData();
294                         switch (r->GetQueryType())
295                         {
296                                 case SQL_RESULT:
297                                         ResultType(r,Result);
298                                 break;
299                                 case SQL_COUNT:
300                                         CountType(r,Result);
301                                 break;
302                                 case SQL_ROW:
303                                         RowType(r,Result);
304                                 break;
305                                 case SQL_DONE:
306                                         DoneType(r,Result);
307                                 break;
308                         }
309                         return (char*)Result;
310                 }
311                 return NULL;
312         }
313
314         ModuleSQL()
315         {
316                 Srv = new Server();
317                 Conf = new ConfigReader();
318                 LoadDatabases(Conf);
319         }
320         
321         virtual ~ModuleSQL()
322         {
323                 Connections.clear();
324                 delete Conf;
325                 delete Srv;
326         }
327         
328         virtual void OnRehash()
329         {
330                 delete Conf;
331                 Conf = new ConfigReader();
332                 LoadDatabases(Conf);
333         }
334         
335         virtual Version GetVersion()
336         {
337                 return Version(1,0,0,0,VF_VENDOR|VF_SERVICEPROVIDER);
338         }
339         
340 };
341
342 // stuff down here is the module-factory stuff. For basic modules you can ignore this.
343
344 class ModuleSQLFactory : public ModuleFactory
345 {
346  public:
347         ModuleSQLFactory()
348         {
349         }
350         
351         ~ModuleSQLFactory()
352         {
353         }
354         
355         virtual Module * CreateModule()
356         {
357                 return new ModuleSQL;
358         }
359         
360 };
361
362
363 extern "C" void * init_module( void )
364 {
365         return new ModuleSQLFactory;
366 }
367