git.netwichtig.de Git - user/henk/code/inspircd.git/blob - src/modules/extra/m_sqlite3.cpp
Add column names to SQLv3, allow sqloper to specify its own query string
[user/henk/code/inspircd.git] / src / modules / extra / m_sqlite3.cpp
1 /*               +------------------------------------+
2  *               | Inspire Internet Relay Chat Daemon |
3  *               +------------------------------------+
4  *
5  *      InspIRCd: (C) 2002-2010 InspIRCd Development Team
6  * See: http://wiki.inspircd.org/Credits
7  *
8  * This program is free but copyrighted software; see
9  *                        the file COPYING for details.
10  *
11  * ---------------------------------------------------
12  */
13
14 #include "inspircd.h"
15 #include <sqlite3.h>
16 #include "sql.h"
17
18 /* $ModDesc: sqlite3 provider */
19 /* $CompileFlags: pkgconfversion("sqlite3","3.3") pkgconfincludes("sqlite3","/sqlite3.h","") */
20 /* $LinkerFlags: pkgconflibs("sqlite3","/libsqlite3.so","-lsqlite3") */
21 /* $NoPedantic */
22
23 class SQLConn;
24 typedef std::map<std::string, reference<SQLConn> > ConnMap;
25
26 class SQLite3Result : public SQLResult
27 {
28  public:
29         int currentrow;
30         int rows;
31         std::vector<std::string> columns;
32         std::vector<SQLEntries> fieldlists;
33
34         SQLite3Result() : currentrow(0), rows(0)
35         {
36         }
37
38         ~SQLite3Result()
39         {
40         }
41
42         virtual int Rows()
43         {
44                 return rows;
45         }
46
47         virtual bool GetRow(SQLEntries& result)
48         {
49                 if (currentrow < rows)
50                 {
51                         result.assign(fieldlists[currentrow].begin(), fieldlists[currentrow].end());
52                         currentrow++;
53                         return true;
54                 }
55                 else
56                 {
57                         result.clear();
58                         return false;
59                 }
60         }
61         virtual void GetCols(std::vector<std::string>& result)
62         {
63                 result.assign(columns.begin(), columns.end());
64         }
65 };
66
67 class SQLConn : public refcountbase
68 {
69  private:
70         sqlite3* conn;
71         reference<ConfigTag> config;
72
73  public:
74         SQLConn(ConfigTag* tag) : config(tag)
75         {
76                 std::string host = tag->getString("hostname");
77                 if (sqlite3_open_v2(host.c_str(), &conn, SQLITE_OPEN_READWRITE, 0) != SQLITE_OK)
78                 {
79                         ServerInstance->Logs->Log("m_sqlite3",DEFAULT, "WARNING: Could not open DB with id: " + tag->getString("id"));
80                         conn = NULL;
81                 }
82         }
83
84         ~SQLConn()
85         {
86                 sqlite3_interrupt(conn);
87                 sqlite3_close(conn);
88         }
89
90         void Query(SQLQuery* query)
91         {
92                 SQLite3Result res;
93                 sqlite3_stmt *stmt;
94                 int err = sqlite3_prepare_v2(conn, query->query.c_str(), query->query.length(), &stmt, NULL);
95                 if (err != SQLITE_OK)
96                 {
97                         SQLerror error(SQL_QSEND_FAIL, sqlite3_errmsg(conn));
98                         query->OnError(error);
99                         return;
100                 }
101                 int cols = sqlite3_column_count(stmt);
102                 res.columns.resize(cols);
103                 for(int i=0; i < cols; i++)
104                 {
105                         res.columns[i] = sqlite3_column_name(stmt, i);
106                 }
107                 while (1)
108                 {
109                         err = sqlite3_step(stmt);
110                         if (err == SQLITE_ROW)
111                         {
112                                 // Add the row
113                                 res.fieldlists.resize(res.rows + 1);
114                                 res.fieldlists[res.rows].resize(cols);
115                                 for(int i=0; i < cols; i++)
116                                 {
117                                         const char* txt = (const char*)sqlite3_column_text(stmt, i);
118                                         if (txt)
119                                                 res.fieldlists[res.rows][i] = SQLEntry(txt);
120                                 }
121                                 res.rows++;
122                         }
123                         else if (err == SQLITE_DONE)
124                         {
125                                 query->OnResult(res);
126                                 break;
127                         }
128                         else
129                         {
130                                 SQLerror error(SQL_QREPLY_FAIL, sqlite3_errmsg(conn));
131                                 query->OnError(error);
132                                 break;
133                         }
134                 }
135                 sqlite3_finalize(stmt);
136         }
137 };
138
139 class SQLiteProvider : public SQLProvider
140 {
141  public:
142         ConnMap hosts;
143
144         SQLiteProvider(Module* Parent) : SQLProvider(Parent, "SQL/SQLite") {}
145
146         std::string FormatQuery(const std::string& q, const ParamL& p)
147         {
148                 std::string res;
149                 unsigned int param = 0;
150                 for(std::string::size_type i = 0; i < q.length(); i++)
151                 {
152                         if (q[i] != '?')
153                                 res.push_back(q[i]);
154                         else
155                         {
156                                 // TODO numbered parameter support ('?1')
157                                 if (param < p.size())
158                                 {
159                                         char* escaped = sqlite3_mprintf("%q", p[param++].c_str());
160                                         res.append(escaped);
161                                         sqlite3_free(escaped);
162                                 }
163                         }
164                 }
165                 return res;
166         }
167
168         std::string FormatQuery(const std::string& q, const ParamM& p)
169         {
170                 std::string res;
171                 for(std::string::size_type i = 0; i < q.length(); i++)
172                 {
173                         if (q[i] != '$')
174                                 res.push_back(q[i]);
175                         else
176                         {
177                                 std::string field;
178                                 i++;
179                                 while (i < q.length() && isalpha(q[i]))
180                                         field.push_back(q[i++]);
181                                 i--;
182
183                                 ParamM::const_iterator it = p.find(field);
184                                 if (it != p.end())
185                                 {
186                                         char* escaped = sqlite3_mprintf("%q", it->second.c_str());
187                                         res.append(escaped);
188                                         sqlite3_free(escaped);
189                                 }
190                         }
191                 }
192                 return res;
193         }
194         
195         void submit(SQLQuery* query)
196         {
197                 ConnMap::iterator iter = hosts.find(query->dbid);
198                 if (iter == hosts.end())
199                 {
200                         SQLerror err(SQL_BAD_DBID);
201                         query->OnError(err);
202                 }
203                 else
204                 {
205                         iter->second->Query(query);
206                 }
207                 delete query;
208         }
209 };
210
211 class ModuleSQLite3 : public Module
212 {
213  private:
214         SQLiteProvider sqlserv;
215
216  public:
217         ModuleSQLite3()
218         : sqlserv(this)
219         {
220         }
221
222         void init()
223         {
224                 ServerInstance->Modules->AddService(sqlserv);
225
226                 ReadConf();
227
228                 Implementation eventlist[] = { I_OnRehash };
229                 ServerInstance->Modules->Attach(eventlist, this, 1);
230         }
231
232         virtual ~ModuleSQLite3()
233         {
234         }
235
236         void ReadConf()
237         {
238                 sqlserv.hosts.clear();
239                 ConfigTagList tags = ServerInstance->Config->ConfTags("database");
240                 for(ConfigIter i = tags.first; i != tags.second; i++)
241                 {
242                         sqlserv.hosts.insert(std::make_pair(i->second->getString("id"), new SQLConn(i->second)));
243                 }
244         }
245
246         void OnRehash(User* user)
247         {
248                 ReadConf();
249         }
250
251         Version GetVersion()
252         {
253                 return Version("sqlite3 provider", VF_VENDOR);
254         }
255 };
256
/* MODULE_INIT is a macro defined outside this file (presumably in the
 * InspIRCd module headers). It likely emits the exported entry point the core
 * uses to construct ModuleSQLite3 when the module is loaded. TODO confirm
 * against the macro definition.
 */
257 MODULE_INIT(ModuleSQLite3)