]> git.netwichtig.de Git - user/henk/code/inspircd.git/blob - src/modules/extra/m_sqlite3.cpp
4a0538bc8e3d4ad8d8f92836706c9cc6567547e9
[user/henk/code/inspircd.git] / src / modules / extra / m_sqlite3.cpp
1 /*               +------------------------------------+
2  *               | Inspire Internet Relay Chat Daemon |
3  *               +------------------------------------+
4  *
5  *      InspIRCd: (C) 2002-2010 InspIRCd Development Team
6  * See: http://wiki.inspircd.org/Credits
7  *
8  * This program is free but copyrighted software; see
9  *                        the file COPYING for details.
10  *
11  * ---------------------------------------------------
12  */
13
14 #include "inspircd.h"
15 #include <sqlite3.h>
16 #include "sql.h"
17
18 /* $ModDesc: sqlite3 provider */
19 /* $CompileFlags: pkgconfversion("sqlite3","3.3") pkgconfincludes("sqlite3","/sqlite3.h","") */
20 /* $LinkerFlags: pkgconflibs("sqlite3","/libsqlite3.so","-lsqlite3") */
21 /* $NoPedantic */
22
23 class SQLConn;
24 typedef std::map<std::string, reference<SQLConn> > ConnMap;
25
26 class SQLite3Result : public SQLResult
27 {
28  public:
29         int currentrow;
30         int rows;
31         std::vector<std::vector<std::string> > fieldlists;
32
33         SQLite3Result() : currentrow(0), rows(0)
34         {
35         }
36
37         ~SQLite3Result()
38         {
39         }
40
41         virtual int Rows()
42         {
43                 return rows;
44         }
45
46         virtual bool GetRow(std::vector<std::string>& result)
47         {
48                 if (currentrow < rows)
49                 {
50                         result.assign(fieldlists[currentrow].begin(), fieldlists[currentrow].end());
51                         currentrow++;
52                         return true;
53                 }
54                 else
55                 {
56                         result.clear();
57                         return false;
58                 }
59         }
60 };
61
62 class SQLConn : public refcountbase
63 {
64  private:
65         sqlite3* conn;
66         reference<ConfigTag> config;
67
68  public:
69         SQLConn(ConfigTag* tag) : config(tag)
70         {
71                 std::string host = tag->getString("hostname");
72                 if (sqlite3_open_v2(host.c_str(), &conn, SQLITE_OPEN_READWRITE, 0) != SQLITE_OK)
73                 {
74                         ServerInstance->Logs->Log("m_sqlite3",DEFAULT, "WARNING: Could not open DB with id: " + tag->getString("id"));
75                         conn = NULL;
76                 }
77         }
78
79         ~SQLConn()
80         {
81                 sqlite3_interrupt(conn);
82                 sqlite3_close(conn);
83         }
84
85         void Query(SQLQuery* query)
86         {
87                 SQLite3Result res;
88                 sqlite3_stmt *stmt;
89                 int err = sqlite3_prepare_v2(conn, query->query.c_str(), query->query.length(), &stmt, NULL);
90                 if (err != SQLITE_OK)
91                 {
92                         SQLerror error(SQL_QSEND_FAIL, sqlite3_errmsg(conn));
93                         query->OnError(error);
94                         return;
95                 }
96                 int cols = sqlite3_column_count(stmt);
97                 while (1)
98                 {
99                         err = sqlite3_step(stmt);
100                         if (err == SQLITE_ROW)
101                         {
102                                 // Add the row
103                                 res.fieldlists.resize(res.rows + 1);
104                                 res.fieldlists[res.rows].resize(cols);
105                                 for(int i=0; i < cols; i++)
106                                 {
107                                         const char* txt = (const char*)sqlite3_column_text(stmt, i);
108                                         res.fieldlists[res.rows][i] = txt ? txt : "";
109                                 }
110                                 res.rows++;
111                         }
112                         else if (err == SQLITE_DONE)
113                         {
114                                 query->OnResult(res);
115                                 break;
116                         }
117                         else
118                         {
119                                 SQLerror error(SQL_QREPLY_FAIL, sqlite3_errmsg(conn));
120                                 query->OnError(error);
121                                 break;
122                         }
123                 }
124                 sqlite3_finalize(stmt);
125         }
126 };
127
128 class SQLiteProvider : public SQLProvider
129 {
130  public:
131         ConnMap hosts;
132
133         SQLiteProvider(Module* Parent) : SQLProvider(Parent, "SQL/SQLite") {}
134
135         std::string FormatQuery(std::string q, ParamL p)
136         {
137                 std::string res;
138                 unsigned int param = 0;
139                 for(std::string::size_type i = 0; i < q.length(); i++)
140                 {
141                         if (q[i] != '?')
142                                 res.push_back(q[i]);
143                         else
144                         {
145                                 // TODO numbered parameter support ('?1')
146                                 if (param < p.size())
147                                 {
148                                         char* escaped = sqlite3_mprintf("%q", p[param++].c_str());
149                                         res.append(escaped);
150                                         sqlite3_free(escaped);
151                                 }
152                         }
153                 }
154                 return res;
155         }
156
157         std::string FormatQuery(std::string q, ParamM p)
158         {
159                 std::string res;
160                 for(std::string::size_type i = 0; i < q.length(); i++)
161                 {
162                         if (q[i] != '$')
163                                 res.push_back(q[i]);
164                         else
165                         {
166                                 std::string field;
167                                 i++;
168                                 while (i < q.length() && isalpha(q[i]))
169                                         field.push_back(q[i++]);
170                                 i--;
171
172                                 char* escaped = sqlite3_mprintf("%q", p[field].c_str());
173                                 res.append(escaped);
174                                 sqlite3_free(escaped);
175                         }
176                 }
177                 return res;
178         }
179         
180         void submit(SQLQuery* query)
181         {
182                 ConnMap::iterator iter = hosts.find(query->dbid);
183                 if (iter == hosts.end())
184                 {
185                         SQLerror err(SQL_BAD_DBID);
186                         query->OnError(err);
187                 }
188                 else
189                 {
190                         iter->second->Query(query);
191                 }
192                 delete query;
193         }
194 };
195
196 class ModuleSQLite3 : public Module
197 {
198  private:
199         SQLiteProvider sqlserv;
200
201  public:
202         ModuleSQLite3()
203         : sqlserv(this)
204         {
205         }
206
207         void init()
208         {
209                 ServerInstance->Modules->AddService(sqlserv);
210
211                 ReadConf();
212
213                 Implementation eventlist[] = { I_OnRehash };
214                 ServerInstance->Modules->Attach(eventlist, this, 1);
215         }
216
217         virtual ~ModuleSQLite3()
218         {
219         }
220
221         void ReadConf()
222         {
223                 sqlserv.hosts.clear();
224                 ConfigTagList tags = ServerInstance->Config->ConfTags("database");
225                 for(ConfigIter i = tags.first; i != tags.second; i++)
226                 {
227                         sqlserv.hosts.insert(std::make_pair(i->second->getString("id"), new SQLConn(i->second)));
228                 }
229         }
230
231         void OnRehash(User* user)
232         {
233                 ReadConf();
234         }
235
236         Version GetVersion()
237         {
238                 return Version("sqlite3 provider", VF_VENDOR);
239         }
240 };
241
242 MODULE_INIT(ModuleSQLite3)