]> git.netwichtig.de Git - user/henk/code/inspircd.git/blob - src/modules/extra/m_sql.cpp
Major code tidyup (-W) - expect a few belches
[user/henk/code/inspircd.git] / src / modules / extra / m_sql.cpp
1 /*       +------------------------------------+
2  *       | Inspire Internet Relay Chat Daemon |
3  *       +------------------------------------+
4  *
5  *  Inspire is copyright (C) 2002-2004 ChatSpike-Dev.
6  *                       E-mail:
7  *                <brain@chatspike.net>
8  *                <Craig@chatspike.net>
9  *     
10  * Written by Craig Edwards, Craig McLure, and others.
11  * This program is free but copyrighted software; see
12  *            the file COPYING for details.
13  *
14  * ---------------------------------------------------
15  */
16
17
18 #include <stdio.h>
19 #include <string>
20 #include <mysql.h>
21 #include "users.h"
22 #include "channels.h"
23 #include "modules.h"
24 #include "m_sql.h"
25
26 /* $ModDesc: SQL Service Provider module for all other m_sql* modules */
27 /* $CompileFlags: -I/usr/local/include/mysql -I/usr/include/mysql -I/usr/local/include -I/usr/include -L/usr/local/lib/mysql -L/usr/lib/mysql -L/usr/local/lib -lmysqlclient */
28
29 /** SQLConnection represents one mysql session.
30  * Each session has its own persistent connection to the database.
31  */
32 class SQLConnection
33 {
34  protected:
35
36         MYSQL connection;
37         MYSQL_RES *res;
38         MYSQL_ROW row;
39         std::string host;
40         std::string user;
41         std::string pass;
42         std::string db;
43         std::map<std::string,std::string> thisrow;
44         bool Enabled;
45         long id;
46
47  public:
48
49         // This constructor creates an SQLConnection object with the given credentials, and creates the underlying
50         // MYSQL struct, but does not connect yet.
51         SQLConnection(std::string thishost, std::string thisuser, std::string thispass, std::string thisdb, long myid)
52         {
53                 this->Enabled = true;
54                 this->host = thishost;
55                 this->user = thisuser;
56                 this->pass = thispass;
57                 this->db = thisdb;
58                 this->id = myid;
59                 unsigned int timeout = 1;
60                 mysql_init(&connection);
61                 mysql_options(&connection,MYSQL_OPT_CONNECT_TIMEOUT,(char*)&timeout);
62         }
63
64         // This method connects to the database using the credentials supplied to the constructor, and returns
65         // true upon success.
66         bool Connect()
67         {
68                 return mysql_real_connect(&connection, host.c_str(), user.c_str(), pass.c_str(), db.c_str(), 0, NULL, 0);
69         }
70
71         // This method issues a query that expects multiple rows of results. Use GetRow() and QueryDone() to retrieve
72         // multiple rows.
73         bool QueryResult(std::string query)
74         {
75                 int r = mysql_query(&connection, query.c_str());
76                 if (!r)
77                 {
78                         res = mysql_use_result(&connection);
79                 }
80                 return (!r);
81         }
82
83         // This method issues a query that just expects a number of 'effected' rows (e.g. UPDATE or DELETE FROM).
84         // the number of effected rows is returned in the return value.
85         unsigned long QueryCount(std::string query)
86         {
87                 int r = mysql_query(&connection, query.c_str());
88                 if (!r)
89                 {
90                         res = mysql_store_result(&connection);
91                         unsigned long rows = mysql_affected_rows(&connection);
92                         mysql_free_result(res);
93                         return rows;
94                 }
95                 return 0;
96         }
97
98         // This method fetches a row, if available from the database. You must issue a query
99         // using QueryResult() first! The row's values are returned as a map of std::string
100         // where each item is keyed by the column name.
101         std::map<std::string,std::string> GetRow()
102         {
103                 thisrow.clear();
104                 if (res)
105                 {
106                         row = mysql_fetch_row(res);
107                         if (row)
108                         {
109                                 unsigned int field_count = 0;
110                                 if(mysql_field_count(&connection) == 0)
111                                         return thisrow;
112                                 MYSQL_FIELD *fields = mysql_fetch_fields(res);
113                                 while (field_count < mysql_field_count(&connection))
114                                 {
115                                         thisrow[std::string(fields[field_count].name)] = std::string(row[field_count]);
116                                         field_count++;
117                                 }
118                                 return thisrow;
119                         }
120                 }
121                 return thisrow;
122         }
123
124         bool QueryDone()
125         {
126                 if (res)
127                 {
128                         mysql_free_result(res);
129                         return true;
130                 }
131                 else return false;
132         }
133
134         std::string GetError()
135         {
136                 return mysql_error(&connection);
137         }
138
139         long GetID()
140         {
141                 return id;
142         }
143
144         std::string GetHost()
145         {
146                 return host;
147         }
148
149         void Enable()
150         {
151                 Enabled = true;
152         }
153
154         void Disable()
155         {
156                 Enabled = false;
157         }
158
159         bool IsEnabled()
160         {
161                 return Enabled;
162         }
163
164 };
165
166 typedef std::vector<SQLConnection> ConnectionList;
167
168 class ModuleSQL : public Module
169 {
170         Server *Srv;
171         ConfigReader *Conf;
172         ConnectionList Connections;
173  
174  public:
175         void ConnectDatabases()
176         {
177                 for (ConnectionList::iterator i = Connections.begin(); i != Connections.end(); i++)
178                 {
179                         i->Enable();
180                         if (i->Connect())
181                         {
182                                 Srv->Log(DEFAULT,"SQL: Successfully connected database "+i->GetHost());
183                         }
184                         else
185                         {
186                                 Srv->Log(DEFAULT,"SQL: Failed to connect database "+i->GetHost()+": Error: "+i->GetError());
187                                 i->Disable();
188                         }
189                 }
190         }
191
192         void LoadDatabases(ConfigReader* ThisConf)
193         {
194                 Srv->Log(DEFAULT,"SQL: Loading database settings");
195                 Connections.clear();
196                 Srv->Log(DEBUG,"Cleared connections");
197                 for (int j =0; j < ThisConf->Enumerate("database"); j++)
198                 {
199                         std::string db = ThisConf->ReadValue("database","name",j);
200                         std::string user = ThisConf->ReadValue("database","username",j);
201                         std::string pass = ThisConf->ReadValue("database","password",j);
202                         std::string host = ThisConf->ReadValue("database","hostname",j);
203                         std::string id = ThisConf->ReadValue("database","id",j);
204                         Srv->Log(DEBUG,"Read database settings");
205                         if ((db != "") && (host != "") && (user != "") && (id != "") && (pass != ""))
206                         {
207                                 SQLConnection ThisSQL(host,user,pass,db,atoi(id.c_str()));
208                                 Srv->Log(DEFAULT,"Loaded database: "+ThisSQL.GetHost());
209                                 Connections.push_back(ThisSQL);
210                                 Srv->Log(DEBUG,"Pushed back connection");
211                         }
212                 }
213                 ConnectDatabases();
214         }
215
216         void ResultType(SQLRequest *r, SQLResult *res)
217         {
218                 for (ConnectionList::iterator i = Connections.begin(); i != Connections.end(); i++)
219                 {
220                         if ((i->GetID() == r->GetConnID()) && (i->IsEnabled()))
221                         {
222                                 bool xr = i->QueryResult(r->GetQuery());
223                                 if (!xr)
224                                 {
225                                         res->SetType(SQL_ERROR);
226                                         res->SetError(i->GetError());
227                                         return;
228                                 }
229                                 res->SetType(SQL_OK);
230                                 return;
231                         }
232                 }
233         }
234
235         void CountType(SQLRequest *r, SQLResult* res)
236         {
237                 for (ConnectionList::iterator i = Connections.begin(); i != Connections.end(); i++)
238                 {
239                         if ((i->GetID() == r->GetConnID()) && (i->IsEnabled()))
240                         {
241                                 res->SetType(SQL_COUNT);
242                                 res->SetCount(i->QueryCount(r->GetQuery()));
243                                 return;
244                         }
245                 }
246         }
247
248         void DoneType(SQLRequest *r, SQLResult* res)
249         {
250                 for (ConnectionList::iterator i = Connections.begin(); i != Connections.end(); i++)
251                 {
252                         if ((i->GetID() == r->GetConnID()) && (i->IsEnabled()))
253                         {
254                                 res->SetType(SQL_DONE);
255                                 if (!i->QueryDone())
256                                         res->SetType(SQL_ERROR);
257                         }
258                 }
259         }
260
261         void RowType(SQLRequest *r, SQLResult* res)
262         {
263                 for (ConnectionList::iterator i = Connections.begin(); i != Connections.end(); i++)
264                 {
265                         if ((i->GetID() == r->GetConnID()) && (i->IsEnabled()))
266                         {
267                                 log(DEBUG,"*** FOUND MATCHING ROW");
268                                 std::map<std::string,std::string> row = i->GetRow();
269                                 res->SetRow(row);
270                                 res->SetType(SQL_ROW);
271                                 if (!row.size())
272                                 {
273                                         log(DEBUG,"ROW SIZE IS 0");
274                                         res->SetType(SQL_END);
275                                 }
276                                 return;
277                         }
278                 }
279         }
280
281         char* OnRequest(Request* request)
282         {
283                 if (request)
284                 {
285                         SQLResult* Result = new SQLResult();
286                         SQLRequest *r = (SQLRequest*)request->GetData();
287                         switch (r->GetQueryType())
288                         {
289                                 case SQL_RESULT:
290                                         ResultType(r,Result);
291                                 break;
292                                 case SQL_COUNT:
293                                         CountType(r,Result);
294                                 break;
295                                 case SQL_ROW:
296                                         RowType(r,Result);
297                                 break;
298                                 case SQL_DONE:
299                                         DoneType(r,Result);
300                                 break;
301                         }
302                         return (char*)Result;
303                 }
304                 return NULL;
305         }
306
307         ModuleSQL()
308         {
309                 Srv = new Server();
310                 Conf = new ConfigReader();
311                 LoadDatabases(Conf);
312         }
313         
314         virtual ~ModuleSQL()
315         {
316                 Connections.clear();
317                 delete Conf;
318                 delete Srv;
319         }
320         
321         virtual void OnRehash()
322         {
323                 delete Conf;
324                 Conf = new ConfigReader();
325                 LoadDatabases(Conf);
326         }
327         
328         virtual Version GetVersion()
329         {
330                 return Version(1,0,0,0,VF_VENDOR|VF_SERVICEPROVIDER);
331         }
332         
333 };
334
335 // stuff down here is the module-factory stuff. For basic modules you can ignore this.
336
337 class ModuleSQLFactory : public ModuleFactory
338 {
339  public:
340         ModuleSQLFactory()
341         {
342         }
343         
344         ~ModuleSQLFactory()
345         {
346         }
347         
348         virtual Module * CreateModule()
349         {
350                 return new ModuleSQL;
351         }
352         
353 };
354
355
356 extern "C" void * init_module( void )
357 {
358         return new ModuleSQLFactory;
359 }
360