diff options
Diffstat (limited to 'src/modules/extra')
-rw-r--r-- | src/modules/extra/README | 8 | ||||
-rw-r--r-- | src/modules/extra/m_filter_pcre.cpp | 183 | ||||
-rw-r--r-- | src/modules/extra/m_httpclienttest.cpp | 82 | ||||
-rw-r--r-- | src/modules/extra/m_mysql.cpp | 890 | ||||
-rw-r--r-- | src/modules/extra/m_pgsql.cpp | 985 | ||||
-rw-r--r-- | src/modules/extra/m_sqlauth.cpp | 195 | ||||
-rw-r--r-- | src/modules/extra/m_sqlite3.cpp | 661 | ||||
-rw-r--r-- | src/modules/extra/m_sqllog.cpp | 311 | ||||
-rw-r--r-- | src/modules/extra/m_sqloper.cpp | 284 | ||||
-rw-r--r-- | src/modules/extra/m_sqlutils.cpp | 239 | ||||
-rw-r--r-- | src/modules/extra/m_sqlutils.h | 144 | ||||
-rw-r--r-- | src/modules/extra/m_sqlv2.h | 606 | ||||
-rw-r--r-- | src/modules/extra/m_ssl_gnutls.cpp | 844 | ||||
-rw-r--r-- | src/modules/extra/m_ssl_openssl.cpp | 902 | ||||
-rw-r--r-- | src/modules/extra/m_ssl_oper_cert.cpp | 181 | ||||
-rw-r--r-- | src/modules/extra/m_sslinfo.cpp | 95 | ||||
-rw-r--r-- | src/modules/extra/m_testclient.cpp | 111 | ||||
-rw-r--r-- | src/modules/extra/m_ziplink.cpp | 453 |
18 files changed, 7156 insertions, 18 deletions
diff --git a/src/modules/extra/README b/src/modules/extra/README index 4c4beef9d..7e3096b34 100644 --- a/src/modules/extra/README +++ b/src/modules/extra/README @@ -1 +1,7 @@ -This directory stores modules which require external libraries to compile.
For example, m_filter_pcre requires the PCRE libraries.
To compile any of these modules first ensure you have the required dependencies
(read the online documentation at http://www.inspircd.org/wiki/) and then cp
the .cpp file from this directory into the parent directory (src/modules/) and
re-configure your inspircd with ./configure -update to detect the new module.
\ No newline at end of file +This directory stores modules which require external libraries to compile. +For example, m_filter_pcre requires the PCRE libraries. + +To compile any of these modules first ensure you have the required dependencies +(read the online documentation at http://www.inspircd.org/wiki/) and then cp +the .cpp file from this directory into the parent directory (src/modules/) and +re-configure your inspircd with ./configure -update to detect the new module. diff --git a/src/modules/extra/m_filter_pcre.cpp b/src/modules/extra/m_filter_pcre.cpp index 0c6c05c8c..6fe79a981 100644 --- a/src/modules/extra/m_filter_pcre.cpp +++ b/src/modules/extra/m_filter_pcre.cpp @@ -1 +1,182 @@ -/* +------------------------------------+
* | Inspire Internet Relay Chat Daemon |
* +------------------------------------+
*
* InspIRCd: (C) 2002-2007 InspIRCd Development Team
* See: http://www.inspircd.org/wiki/index.php/Credits
*
* This program is free but copyrighted software; see
* the file COPYING for details.
*
* ---------------------------------------------------
*/
#include "inspircd.h"
#include <pcre.h>
#include "users.h"
#include "channels.h"
#include "modules.h"
#include "m_filter.h"
/* $ModDesc: m_filter with regexps */
/* $CompileFlags: exec("pcre-config --cflags") */
/* $LinkerFlags: exec("pcre-config --libs") rpath("pcre-config --libs") -lpcre */
/* $ModDep: m_filter.h */
#ifdef WINDOWS
#pragma comment(lib, "pcre.lib")
#endif
class PCREFilter : public FilterResult
{
public:
pcre* regexp;
PCREFilter(pcre* r, const std::string &rea, const std::string &act, long gline_time, const std::string &pat, const std::string &flags)
: FilterResult(pat, rea, act, gline_time, flags), regexp(r)
{
}
PCREFilter()
{
}
};
class ModuleFilterPCRE : public FilterBase
{
std::vector<PCREFilter> filters;
pcre *re;
const char *error;
int erroffset;
PCREFilter fr;
public:
ModuleFilterPCRE(InspIRCd* Me)
: FilterBase(Me, "m_filter_pcre.so")
{
OnRehash(NULL,"");
}
virtual ~ModuleFilterPCRE()
{
}
virtual FilterResult* FilterMatch(userrec* user, const std::string &text, int flags)
{
for (std::vector<PCREFilter>::iterator index = filters.begin(); index != filters.end(); index++)
{
/* Skip ones that dont apply to us */
if (!FilterBase::AppliesToMe(user, dynamic_cast<FilterResult*>(&(*index)), flags))
continue;
if (pcre_exec(index->regexp, NULL, text.c_str(), text.length(), 0, 0, NULL, 0) > -1)
{
fr = *index;
if (index != filters.begin())
{
filters.erase(index);
filters.insert(filters.begin(), fr);
}
return &fr;
}
}
return NULL;
}
virtual bool DeleteFilter(const std::string &freeform)
{
for (std::vector<PCREFilter>::iterator i = filters.begin(); i != filters.end(); i++)
{
if (i->freeform == freeform)
{
pcre_free((*i).regexp);
filters.erase(i);
return true;
}
}
return false;
}
virtual void SyncFilters(Module* proto, void* opaque)
{
for (std::vector<PCREFilter>::iterator i = filters.begin(); i != filters.end(); i++)
{
this->SendFilter(proto, opaque, &(*i));
}
}
virtual std::pair<bool, std::string> AddFilter(const std::string &freeform, const std::string &type, const std::string &reason, long duration, const std::string &flags)
{
for (std::vector<PCREFilter>::iterator i = filters.begin(); i != filters.end(); i++)
{
if (i->freeform == freeform)
{
return std::make_pair(false, "Filter already exists");
}
}
re = pcre_compile(freeform.c_str(),0,&error,&erroffset,NULL);
if (!re)
{
ServerInstance->Log(DEFAULT,"Error in regular expression: %s at offset %d: %s\n", freeform.c_str(), erroffset, error);
ServerInstance->Log(DEFAULT,"Regular expression %s not loaded.", freeform.c_str());
return std::make_pair(false, "Error in regular expression at offset " + ConvToStr(erroffset) + ": "+error);
}
else
{
filters.push_back(PCREFilter(re, reason, type, duration, freeform, flags));
return std::make_pair(true, "");
}
}
virtual void OnRehash(userrec* user, const std::string ¶meter)
{
ConfigReader MyConf(ServerInstance);
for (int index = 0; index < MyConf.Enumerate("keyword"); index++)
{
this->DeleteFilter(MyConf.ReadValue("keyword", "pattern", index));
std::string pattern = MyConf.ReadValue("keyword", "pattern", index);
std::string reason = MyConf.ReadValue("keyword", "reason", index);
std::string action = MyConf.ReadValue("keyword", "action", index);
std::string flags = MyConf.ReadValue("keyword", "flags", index);
long gline_time = ServerInstance->Duration(MyConf.ReadValue("keyword", "duration", index));
if (action.empty())
action = "none";
if (flags.empty())
flags = "*";
re = pcre_compile(pattern.c_str(),0,&error,&erroffset,NULL);
if (!re)
{
ServerInstance->Log(DEFAULT,"Error in regular expression: %s at offset %d: %s\n", pattern.c_str(), erroffset, error);
ServerInstance->Log(DEFAULT,"Regular expression %s not loaded.", pattern.c_str());
}
else
{
filters.push_back(PCREFilter(re, reason, action, gline_time, pattern, flags));
ServerInstance->Log(DEFAULT,"Regular expression %s loaded.", pattern.c_str());
}
}
}
virtual int OnStats(char symbol, userrec* user, string_list &results)
{
if (symbol == 's')
{
std::string sn = ServerInstance->Config->ServerName;
for (std::vector<PCREFilter>::iterator i = filters.begin(); i != filters.end(); i++)
{
results.push_back(sn+" 223 "+user->nick+" :REGEXP:"+i->freeform+" "+i->flags+" "+i->action+" "+ConvToStr(i->gline_time)+" :"+i->reason);
}
}
return 0;
}
};
MODULE_INIT(ModuleFilterPCRE);
\ No newline at end of file +/* +------------------------------------+ + * | Inspire Internet Relay Chat Daemon | + * +------------------------------------+ + * + * InspIRCd: (C) 2002-2007 InspIRCd Development Team + * See: http://www.inspircd.org/wiki/index.php/Credits + * + * This program is free but copyrighted software; see + * the file COPYING for details. + * + * --------------------------------------------------- + */ + +#include "inspircd.h" +#include <pcre.h> +#include "users.h" +#include "channels.h" +#include "modules.h" +#include "m_filter.h" + +/* $ModDesc: m_filter with regexps */ +/* $CompileFlags: exec("pcre-config --cflags") */ +/* $LinkerFlags: exec("pcre-config --libs") rpath("pcre-config --libs") -lpcre */ +/* $ModDep: m_filter.h */ + +#ifdef WINDOWS +#pragma comment(lib, "pcre.lib") +#endif + +class PCREFilter : public FilterResult +{ + public: + pcre* regexp; + + PCREFilter(pcre* r, const std::string &rea, const std::string &act, long gline_time, const std::string &pat, const std::string &flags) + : FilterResult(pat, rea, act, gline_time, flags), regexp(r) + { + } + + PCREFilter() + { + } +}; + +class ModuleFilterPCRE : public FilterBase +{ + std::vector<PCREFilter> filters; + pcre *re; + const char *error; + int erroffset; + PCREFilter fr; + + public: + ModuleFilterPCRE(InspIRCd* Me) + : FilterBase(Me, "m_filter_pcre.so") + { + OnRehash(NULL,""); + } + + virtual ~ModuleFilterPCRE() + { + } + + virtual FilterResult* FilterMatch(userrec* user, const std::string &text, int flags) + { + for (std::vector<PCREFilter>::iterator index = filters.begin(); index != filters.end(); index++) + { + /* Skip ones that dont apply to us */ + + if (!FilterBase::AppliesToMe(user, dynamic_cast<FilterResult*>(&(*index)), flags)) + continue; + + if (pcre_exec(index->regexp, NULL, text.c_str(), text.length(), 0, 0, NULL, 0) > -1) + { + fr = *index; + if (index != filters.begin()) + { + filters.erase(index); + filters.insert(filters.begin(), fr); + } + return &fr; + } + } + return NULL; + } + + virtual bool DeleteFilter(const std::string &freeform) + { + for (std::vector<PCREFilter>::iterator i = filters.begin(); i != filters.end(); i++) + { + if (i->freeform == freeform) + { + pcre_free((*i).regexp); + filters.erase(i); + return true; + } + } + return false; + } + + virtual void SyncFilters(Module* proto, void* opaque) + { + for (std::vector<PCREFilter>::iterator i = filters.begin(); i != filters.end(); i++) + { + this->SendFilter(proto, opaque, &(*i)); + } + } + + virtual std::pair<bool, std::string> AddFilter(const std::string &freeform, const std::string &type, const std::string &reason, long duration, const std::string &flags) + { + for (std::vector<PCREFilter>::iterator i = filters.begin(); i != filters.end(); i++) + { + if (i->freeform == freeform) + { + return std::make_pair(false, "Filter already exists"); + } + } + + re = pcre_compile(freeform.c_str(),0,&error,&erroffset,NULL); + + if (!re) + { + ServerInstance->Log(DEFAULT,"Error in regular expression: %s at offset %d: %s\n", freeform.c_str(), erroffset, error); + ServerInstance->Log(DEFAULT,"Regular expression %s not loaded.", freeform.c_str()); + return std::make_pair(false, "Error in regular expression at offset " + ConvToStr(erroffset) + ": "+error); + } + else + { + filters.push_back(PCREFilter(re, reason, type, duration, freeform, flags)); + return std::make_pair(true, ""); + } + } + + virtual void OnRehash(userrec* user, const std::string ¶meter) + { + ConfigReader MyConf(ServerInstance); + + for (int index = 0; index < MyConf.Enumerate("keyword"); index++) + { + this->DeleteFilter(MyConf.ReadValue("keyword", "pattern", index)); + + std::string pattern = MyConf.ReadValue("keyword", "pattern", index); + std::string reason = MyConf.ReadValue("keyword", "reason", index); + std::string action = MyConf.ReadValue("keyword", "action", index); + std::string flags = MyConf.ReadValue("keyword", "flags", index); + long gline_time = ServerInstance->Duration(MyConf.ReadValue("keyword", "duration", index)); + if (action.empty()) + action = "none"; + if (flags.empty()) + flags = "*"; + + re = pcre_compile(pattern.c_str(),0,&error,&erroffset,NULL); + + if (!re) + { + ServerInstance->Log(DEFAULT,"Error in regular expression: %s at offset %d: %s\n", pattern.c_str(), erroffset, error); + ServerInstance->Log(DEFAULT,"Regular expression %s not loaded.", pattern.c_str()); + } + else + { + filters.push_back(PCREFilter(re, reason, action, gline_time, pattern, flags)); + ServerInstance->Log(DEFAULT,"Regular expression %s loaded.", pattern.c_str()); + } + } + } + + virtual int OnStats(char symbol, userrec* user, string_list &results) + { + if (symbol == 's') + { + std::string sn = ServerInstance->Config->ServerName; + for (std::vector<PCREFilter>::iterator i = filters.begin(); i != filters.end(); i++) + { + results.push_back(sn+" 223 "+user->nick+" :REGEXP:"+i->freeform+" "+i->flags+" "+i->action+" "+ConvToStr(i->gline_time)+" :"+i->reason); + } + } + return 0; + } +}; + +MODULE_INIT(ModuleFilterPCRE); + diff --git a/src/modules/extra/m_httpclienttest.cpp b/src/modules/extra/m_httpclienttest.cpp index 3f74b549b..90e7a5159 100644 --- a/src/modules/extra/m_httpclienttest.cpp +++ b/src/modules/extra/m_httpclienttest.cpp @@ -1 +1,81 @@ -/* +------------------------------------+
* | Inspire Internet Relay Chat Daemon |
* +------------------------------------+
*
* InspIRCd: (C) 2002-2007 InspIRCd Development Team
* See: http://www.inspircd.org/wiki/index.php/Credits
*
* This program is free but copyrighted software; see
* the file COPYING for details.
*
* ---------------------------------------------------
*/
#include "inspircd.h"
#include "users.h"
#include "channels.h"
#include "modules.h"
#include "httpclient.h"
/* $ModDep: httpclient.h */
class MyModule : public Module
{
public:
MyModule(InspIRCd* Me)
: Module::Module(Me)
{
}
virtual ~MyModule()
{
}
virtual void Implements(char* List)
{
List[I_OnRequest] = List[I_OnUserJoin] = List[I_OnUserPart] = 1;
}
virtual Version GetVersion()
{
return Version(1,0,0,1,VF_VENDOR,API_VERSION);
}
virtual void OnUserJoin(userrec* user, chanrec* channel, bool &silent)
{
// method called when a user joins a channel
std::string chan = channel->name;
std::string nick = user->nick;
ServerInstance->Log(DEBUG,"User " + nick + " joined " + chan);
Module* target = ServerInstance->FindModule("m_http_client.so");
if(target)
{
HTTPClientRequest req(ServerInstance, this, target, "http://znc.in/~psychon");
req.Send();
}
else
ServerInstance->Log(DEBUG,"module not found, load it!!");
}
char* OnRequest(Request* req)
{
HTTPClientResponse* resp = (HTTPClientResponse*)req;
if(!strcmp(resp->GetId(), HTTP_CLIENT_RESPONSE))
{
ServerInstance->Log(DEBUG, resp->GetData());
}
return NULL;
}
virtual void OnUserPart(userrec* user, chanrec* channel, const std::string &partmessage, bool &silent)
{
}
};
MODULE_INIT(MyModule);
\ No newline at end of file +/* +------------------------------------+ + * | Inspire Internet Relay Chat Daemon | + * +------------------------------------+ + * + * InspIRCd: (C) 2002-2007 InspIRCd Development Team + * See: http://www.inspircd.org/wiki/index.php/Credits + * + * This program is free but copyrighted software; see + * the file COPYING for details. + * + * --------------------------------------------------- + */ + +#include "inspircd.h" +#include "users.h" +#include "channels.h" +#include "modules.h" +#include "httpclient.h" + +/* $ModDep: httpclient.h */ + +class MyModule : public Module +{ + +public: + + MyModule(InspIRCd* Me) + : Module::Module(Me) + { + } + + virtual ~MyModule() + { + } + + virtual void Implements(char* List) + { + List[I_OnRequest] = List[I_OnUserJoin] = List[I_OnUserPart] = 1; + } + + virtual Version GetVersion() + { + return Version(1,0,0,1,VF_VENDOR,API_VERSION); + } + + virtual void OnUserJoin(userrec* user, chanrec* channel, bool &silent) + { + // method called when a user joins a channel + + std::string chan = channel->name; + std::string nick = user->nick; + ServerInstance->Log(DEBUG,"User " + nick + " joined " + chan); + + Module* target = ServerInstance->FindModule("m_http_client.so"); + if(target) + { + HTTPClientRequest req(ServerInstance, this, target, "http://znc.in/~psychon"); + req.Send(); + } + else + ServerInstance->Log(DEBUG,"module not found, load it!!"); + } + + char* OnRequest(Request* req) + { + HTTPClientResponse* resp = (HTTPClientResponse*)req; + if(!strcmp(resp->GetId(), HTTP_CLIENT_RESPONSE)) + { + ServerInstance->Log(DEBUG, resp->GetData()); + } + return NULL; + } + + virtual void OnUserPart(userrec* user, chanrec* channel, const std::string &partmessage, bool &silent) + { + } + +}; + +MODULE_INIT(MyModule); + diff --git a/src/modules/extra/m_mysql.cpp b/src/modules/extra/m_mysql.cpp index eeabe5d48..6605bed3c 100644 --- a/src/modules/extra/m_mysql.cpp +++ b/src/modules/extra/m_mysql.cpp @@ -1 +1,889 @@ -/* +------------------------------------+
* | Inspire Internet Relay Chat Daemon |
* +------------------------------------+
*
* InspIRCd: (C) 2002-2007 InspIRCd Development Team
* See: http://www.inspircd.org/wiki/index.php/Credits
*
* This program is free but copyrighted software; see
* the file COPYING for details.
*
* ---------------------------------------------------
*/
#include "inspircd.h"
#include <mysql.h>
#include <pthread.h>
#include "users.h"
#include "channels.h"
#include "modules.h"
#include "m_sqlv2.h"
/* VERSION 2 API: With nonblocking (threaded) requests */
/* $ModDesc: SQL Service Provider module for all other m_sql* modules */
/* $CompileFlags: exec("mysql_config --include") */
/* $LinkerFlags: exec("mysql_config --libs_r") rpath("mysql_config --libs_r") */
/* $ModDep: m_sqlv2.h */
/* THE NONBLOCKING MYSQL API!
*
* MySQL provides no nonblocking (asyncronous) API of its own, and its developers recommend
* that instead, you should thread your program. This is what i've done here to allow for
* asyncronous SQL requests via mysql. The way this works is as follows:
*
* The module spawns a thread via pthreads, and performs its mysql queries in this thread,
* using a queue with priorities. There is a mutex on either end which prevents two threads
* adjusting the queue at the same time, and crashing the ircd. Every 50 milliseconds, the
* worker thread wakes up, and checks if there is a request at the head of its queue.
* If there is, it processes this request, blocking the worker thread but leaving the ircd
* thread to go about its business as usual. During this period, the ircd thread is able
* to insert futher pending requests into the queue.
*
* Once the processing of a request is complete, it is removed from the incoming queue to
* an outgoing queue, and initialized as a 'response'. The worker thread then signals the
* ircd thread (via a loopback socket) of the fact a result is available, by sending the
* connection ID through the connection.
*
* The ircd thread then mutexes the queue once more, reads the outbound response off the head
* of the queue, and sends it on its way to the original calling module.
*
* XXX: You might be asking "why doesnt he just send the response from within the worker thread?"
* The answer to this is simple. The majority of InspIRCd, and in fact most ircd's are not
* threadsafe. This module is designed to be threadsafe and is careful with its use of threads,
* however, if we were to call a module's OnRequest even from within a thread which was not the
* one the module was originally instantiated upon, there is a chance of all hell breaking loose
* if a module is ever put in a re-enterant state (stack corruption could occur, crashes, data
* corruption, and worse, so DONT think about it until the day comes when InspIRCd is 100%
* gauranteed threadsafe!)
*
* For a diagram of this system please see http://www.inspircd.org/wiki/Mysql2
*/
class SQLConnection;
class Notifier;
typedef std::map<std::string, SQLConnection*> ConnMap;
bool giveup = false;
static Module* SQLModule = NULL;
static Notifier* MessagePipe = NULL;
int QueueFD = -1;
#if !defined(MYSQL_VERSION_ID) || MYSQL_VERSION_ID<32224
#define mysql_field_count mysql_num_fields
#endif
typedef std::deque<SQLresult*> ResultQueue;
/* A mutex to wrap around queue accesses */
pthread_mutex_t queue_mutex = PTHREAD_MUTEX_INITIALIZER;
pthread_mutex_t results_mutex = PTHREAD_MUTEX_INITIALIZER;
pthread_mutex_t logging_mutex = PTHREAD_MUTEX_INITIALIZER;
/** Represents a mysql result set
*/
class MySQLresult : public SQLresult
{
int currentrow;
std::vector<std::string> colnames;
std::vector<SQLfieldList> fieldlists;
SQLfieldMap* fieldmap;
SQLfieldMap fieldmap2;
SQLfieldList emptyfieldlist;
int rows;
public:
MySQLresult(Module* self, Module* to, MYSQL_RES* res, int affected_rows, unsigned int id) : SQLresult(self, to, id), currentrow(0), fieldmap(NULL)
{
/* A number of affected rows from from mysql_affected_rows.
*/
fieldlists.clear();
rows = 0;
if (affected_rows >= 1)
{
rows = affected_rows;
fieldlists.resize(rows);
}
unsigned int field_count = 0;
if (res)
{
MYSQL_ROW row;
int n = 0;
while ((row = mysql_fetch_row(res)))
{
if (fieldlists.size() < (unsigned int)rows+1)
{
fieldlists.resize(fieldlists.size()+1);
}
field_count = 0;
MYSQL_FIELD *fields = mysql_fetch_fields(res);
if(mysql_num_fields(res) == 0)
break;
if (fields && mysql_num_fields(res))
{
colnames.clear();
while (field_count < mysql_num_fields(res))
{
std::string a = (fields[field_count].name ? fields[field_count].name : "");
std::string b = (row[field_count] ? row[field_count] : "");
SQLfield sqlf(b, !row[field_count]);
colnames.push_back(a);
fieldlists[n].push_back(sqlf);
field_count++;
}
n++;
}
rows++;
}
mysql_free_result(res);
}
}
MySQLresult(Module* self, Module* to, SQLerror e, unsigned int id) : SQLresult(self, to, id), currentrow(0)
{
rows = 0;
error = e;
}
~MySQLresult()
{
}
virtual int Rows()
{
return rows;
}
virtual int Cols()
{
return colnames.size();
}
virtual std::string ColName(int column)
{
if (column < (int)colnames.size())
{
return colnames[column];
}
else
{
throw SQLbadColName();
}
return "";
}
virtual int ColNum(const std::string &column)
{
for (unsigned int i = 0; i < colnames.size(); i++)
{
if (column == colnames[i])
return i;
}
throw SQLbadColName();
return 0;
}
virtual SQLfield GetValue(int row, int column)
{
if ((row >= 0) && (row < rows) && (column >= 0) && (column < Cols()))
{
return fieldlists[row][column];
}
throw SQLbadColName();
/* XXX: We never actually get here because of the throw */
return SQLfield("",true);
}
virtual SQLfieldList& GetRow()
{
if (currentrow < rows)
return fieldlists[currentrow];
else
return emptyfieldlist;
}
virtual SQLfieldMap& GetRowMap()
{
fieldmap2.clear();
if (currentrow < rows)
{
for (int i = 0; i < Cols(); i++)
{
fieldmap2.insert(std::make_pair(colnames[i],GetValue(currentrow, i)));
}
currentrow++;
}
return fieldmap2;
}
virtual SQLfieldList* GetRowPtr()
{
SQLfieldList* fieldlist = new SQLfieldList();
if (currentrow < rows)
{
for (int i = 0; i < Rows(); i++)
{
fieldlist->push_back(fieldlists[currentrow][i]);
}
currentrow++;
}
return fieldlist;
}
virtual SQLfieldMap* GetRowMapPtr()
{
fieldmap = new SQLfieldMap();
if (currentrow < rows)
{
for (int i = 0; i < Cols(); i++)
{
fieldmap->insert(std::make_pair(colnames[i],GetValue(currentrow, i)));
}
currentrow++;
}
return fieldmap;
}
virtual void Free(SQLfieldMap* fm)
{
delete fm;
}
virtual void Free(SQLfieldList* fl)
{
delete fl;
}
};
class SQLConnection;
void NotifyMainThread(SQLConnection* connection_with_new_result);
/** Represents a connection to a mysql database
*/
class SQLConnection : public classbase
{
protected:
MYSQL connection;
MYSQL_RES *res;
MYSQL_ROW row;
SQLhost host;
std::map<std::string,std::string> thisrow;
bool Enabled;
public:
QueryQueue queue;
ResultQueue rq;
// This constructor creates an SQLConnection object with the given credentials, but does not connect yet.
SQLConnection(const SQLhost &hi) : host(hi), Enabled(false)
{
}
~SQLConnection()
{
Close();
}
// This method connects to the database using the credentials supplied to the constructor, and returns
// true upon success.
bool Connect()
{
unsigned int timeout = 1;
mysql_init(&connection);
mysql_options(&connection,MYSQL_OPT_CONNECT_TIMEOUT,(char*)&timeout);
return mysql_real_connect(&connection, host.host.c_str(), host.user.c_str(), host.pass.c_str(), host.name.c_str(), host.port, NULL, 0);
}
void DoLeadingQuery()
{
if (!CheckConnection())
return;
/* Parse the command string and dispatch it to mysql */
SQLrequest& req = queue.front();
/* Pointer to the buffer we screw around with substitution in */
char* query;
/* Pointer to the current end of query, where we append new stuff */
char* queryend;
/* Total length of the unescaped parameters */
unsigned long paramlen;
/* Total length of query, used for binary-safety in mysql_real_query */
unsigned long querylength = 0;
paramlen = 0;
for(ParamL::iterator i = req.query.p.begin(); i != req.query.p.end(); i++)
{
paramlen += i->size();
}
/* To avoid a lot of allocations, allocate enough memory for the biggest the escaped query could possibly be.
* sizeofquery + (totalparamlength*2) + 1
*
* The +1 is for null-terminating the string for mysql_real_escape_string
*/
query = new char[req.query.q.length() + (paramlen*2) + 1];
queryend = query;
/* Okay, now we have a buffer large enough we need to start copying the query into it and escaping and substituting
* the parameters into it...
*/
for(unsigned long i = 0; i < req.query.q.length(); i++)
{
if(req.query.q[i] == '?')
{
/* We found a place to substitute..what fun.
* use mysql calls to escape and write the
* escaped string onto the end of our query buffer,
* then we "just" need to make sure queryend is
* pointing at the right place.
*/
if(req.query.p.size())
{
unsigned long len = mysql_real_escape_string(&connection, queryend, req.query.p.front().c_str(), req.query.p.front().length());
queryend += len;
req.query.p.pop_front();
}
else
break;
}
else
{
*queryend = req.query.q[i];
queryend++;
}
querylength++;
}
*queryend = 0;
pthread_mutex_lock(&queue_mutex);
req.query.q = query;
pthread_mutex_unlock(&queue_mutex);
if (!mysql_real_query(&connection, req.query.q.data(), req.query.q.length()))
{
/* Successfull query */
res = mysql_use_result(&connection);
unsigned long rows = mysql_affected_rows(&connection);
MySQLresult* r = new MySQLresult(SQLModule, req.GetSource(), res, rows, req.id);
r->dbid = this->GetID();
r->query = req.query.q;
/* Put this new result onto the results queue.
* XXX: Remember to mutex the queue!
*/
pthread_mutex_lock(&results_mutex);
rq.push_back(r);
pthread_mutex_unlock(&results_mutex);
}
else
{
/* XXX: See /usr/include/mysql/mysqld_error.h for a list of
* possible error numbers and error messages */
SQLerror e(QREPLY_FAIL, ConvToStr(mysql_errno(&connection)) + std::string(": ") + mysql_error(&connection));
MySQLresult* r = new MySQLresult(SQLModule, req.GetSource(), e, req.id);
r->dbid = this->GetID();
r->query = req.query.q;
pthread_mutex_lock(&results_mutex);
rq.push_back(r);
pthread_mutex_unlock(&results_mutex);
}
/* Now signal the main thread that we've got a result to process.
* Pass them this connection id as what to examine
*/
delete[] query;
NotifyMainThread(this);
}
bool ConnectionLost()
{
if (&connection) {
return (mysql_ping(&connection) != 0);
}
else return false;
}
bool CheckConnection()
{
if (ConnectionLost()) {
return Connect();
}
else return true;
}
std::string GetError()
{
return mysql_error(&connection);
}
const std::string& GetID()
{
return host.id;
}
std::string GetHost()
{
return host.host;
}
void SetEnable(bool Enable)
{
Enabled = Enable;
}
bool IsEnabled()
{
return Enabled;
}
void Close()
{
mysql_close(&connection);
}
const SQLhost& GetConfHost()
{
return host;
}
};
ConnMap Connections;
bool HasHost(const SQLhost &host)
{
for (ConnMap::iterator iter = Connections.begin(); iter != Connections.end(); iter++)
{
if (host == iter->second->GetConfHost())
return true;
}
return false;
}
bool HostInConf(ConfigReader* conf, const SQLhost &h)
{
for(int i = 0; i < conf->Enumerate("database"); i++)
{
SQLhost host;
host.id = conf->ReadValue("database", "id", i);
host.host = conf->ReadValue("database", "hostname", i);
host.port = conf->ReadInteger("database", "port", i, true);
host.name = conf->ReadValue("database", "name", i);
host.user = conf->ReadValue("database", "username", i);
host.pass = conf->ReadValue("database", "password", i);
host.ssl = conf->ReadFlag("database", "ssl", i);
if (h == host)
return true;
}
return false;
}
void ClearOldConnections(ConfigReader* conf)
{
ConnMap::iterator i,safei;
for (i = Connections.begin(); i != Connections.end(); i++)
{
if (!HostInConf(conf, i->second->GetConfHost()))
{
DELETE(i->second);
safei = i;
--i;
Connections.erase(safei);
}
}
}
void ClearAllConnections()
{
ConnMap::iterator i;
while ((i = Connections.begin()) != Connections.end())
{
Connections.erase(i);
DELETE(i->second);
}
}
void ConnectDatabases(InspIRCd* ServerInstance)
{
for (ConnMap::iterator i = Connections.begin(); i != Connections.end(); i++)
{
if (i->second->IsEnabled())
continue;
i->second->SetEnable(true);
if (!i->second->Connect())
{
/* XXX: MUTEX */
pthread_mutex_lock(&logging_mutex);
ServerInstance->Log(DEFAULT,"SQL: Failed to connect database "+i->second->GetHost()+": Error: "+i->second->GetError());
i->second->SetEnable(false);
pthread_mutex_unlock(&logging_mutex);
}
}
}
void LoadDatabases(ConfigReader* conf, InspIRCd* ServerInstance)
{
ClearOldConnections(conf);
for (int j =0; j < conf->Enumerate("database"); j++)
{
SQLhost host;
host.id = conf->ReadValue("database", "id", j);
host.host = conf->ReadValue("database", "hostname", j);
host.port = conf->ReadInteger("database", "port", j, true);
host.name = conf->ReadValue("database", "name", j);
host.user = conf->ReadValue("database", "username", j);
host.pass = conf->ReadValue("database", "password", j);
host.ssl = conf->ReadFlag("database", "ssl", j);
if (HasHost(host))
continue;
if (!host.id.empty() && !host.host.empty() && !host.name.empty() && !host.user.empty() && !host.pass.empty())
{
SQLConnection* ThisSQL = new SQLConnection(host);
Connections[host.id] = ThisSQL;
}
}
ConnectDatabases(ServerInstance);
}
char FindCharId(const std::string &id)
{
char i = 1;
for (ConnMap::iterator iter = Connections.begin(); iter != Connections.end(); ++iter, ++i)
{
if (iter->first == id)
{
return i;
}
}
return 0;
}
ConnMap::iterator GetCharId(char id)
{
char i = 1;
for (ConnMap::iterator iter = Connections.begin(); iter != Connections.end(); ++iter, ++i)
{
if (i == id)
return iter;
}
return Connections.end();
}
void NotifyMainThread(SQLConnection* connection_with_new_result)
{
/* Here we write() to the socket the main thread has open
* and we connect()ed back to before our thread became active.
* The main thread is using a nonblocking socket tied into
* the socket engine, so they wont block and they'll receive
* nearly instant notification. Because we're in a seperate
* thread, we can just use standard connect(), and we can
* block if we like. We just send the connection id of the
* connection back.
*
* NOTE: We only send a single char down the connection, this
* way we know it wont get a partial read at the other end if
* the system is especially congested (see bug #263).
* The function FindCharId translates a connection name into a
* one character id, and GetCharId translates a character id
* back into an iterator.
*/
char id = FindCharId(connection_with_new_result->GetID());
send(QueueFD, &id, 1, 0);
}
void* DispatcherThread(void* arg);
/** Used by m_mysql to notify one thread when the other has a result
*/
class Notifier : public InspSocket
{
insp_sockaddr sock_us;
socklen_t uslen;
public:
/* Create a socket on a random port. Let the tcp stack allocate us an available port */
#ifdef IPV6
Notifier(InspIRCd* SI) : InspSocket(SI, "::1", 0, true, 3000)
#else
Notifier(InspIRCd* SI) : InspSocket(SI, "127.0.0.1", 0, true, 3000)
#endif
{
uslen = sizeof(sock_us);
if (getsockname(this->fd,(sockaddr*)&sock_us,&uslen))
{
throw ModuleException("Could not create random listening port on localhost");
}
}
Notifier(InspIRCd* SI, int newfd, char* ip) : InspSocket(SI, newfd, ip)
{
}
/* Using getsockname and ntohs, we can determine which port number we were allocated */
int GetPort()
{
#ifdef IPV6
return ntohs(sock_us.sin6_port);
#else
return ntohs(sock_us.sin_port);
#endif
}
virtual int OnIncomingConnection(int newsock, char* ip)
{
Notifier* n = new Notifier(this->Instance, newsock, ip);
n = n; /* Stop bitching at me, GCC */
return true;
}
virtual bool OnDataReady()
{
char data = 0;
/* NOTE: Only a single character is read so we know we
* cant get a partial read. (We've been told that theres
* data waiting, so we wont ever get EAGAIN)
* The function GetCharId translates a single character
* back into an iterator.
*/
if (read(this->GetFd(), &data, 1) > 0)
{
ConnMap::iterator iter = GetCharId(data);
if (iter != Connections.end())
{
/* Lock the mutex, send back the data */
pthread_mutex_lock(&results_mutex);
ResultQueue::iterator n = iter->second->rq.begin();
(*n)->Send();
iter->second->rq.pop_front();
pthread_mutex_unlock(&results_mutex);
return true;
}
/* No error, but unknown id */
return true;
}
/* Erk, error on descriptor! */
return false;
}
};
/** MySQL module
*/
class ModuleSQL : public Module
{
public:
ConfigReader *Conf;
InspIRCd* PublicServerInstance;
pthread_t Dispatcher;
int currid;
bool rehashing;
ModuleSQL(InspIRCd* Me)
: Module::Module(Me), rehashing(false)
{
ServerInstance->UseInterface("SQLutils");
Conf = new ConfigReader(ServerInstance);
PublicServerInstance = ServerInstance;
currid = 0;
SQLModule = this;
MessagePipe = new Notifier(ServerInstance);
pthread_attr_t attribs;
pthread_attr_init(&attribs);
pthread_attr_setdetachstate(&attribs, PTHREAD_CREATE_DETACHED);
if (pthread_create(&this->Dispatcher, &attribs, DispatcherThread, (void *)this) != 0)
{
throw ModuleException("m_mysql: Failed to create dispatcher thread: " + std::string(strerror(errno)));
}
if (!ServerInstance->PublishFeature("SQL", this))
{
/* Tell worker thread to exit NOW */
giveup = true;
throw ModuleException("m_mysql: Unable to publish feature 'SQL'");
}
ServerInstance->PublishInterface("SQL", this);
}
virtual ~ModuleSQL()
{
giveup = true;
ClearAllConnections();
DELETE(Conf);
ServerInstance->UnpublishInterface("SQL", this);
ServerInstance->UnpublishFeature("SQL");
ServerInstance->DoneWithInterface("SQLutils");
}
void Implements(char* List)
{
List[I_OnRehash] = List[I_OnRequest] = 1;
}
unsigned long NewID()
{
if (currid+1 == 0)
currid++;
return ++currid;
}
char* OnRequest(Request* request)
{
if(strcmp(SQLREQID, request->GetId()) == 0)
{
SQLrequest* req = (SQLrequest*)request;
/* XXX: Lock */
pthread_mutex_lock(&queue_mutex);
ConnMap::iterator iter;
char* returnval = NULL;
if((iter = Connections.find(req->dbid)) != Connections.end())
{
req->id = NewID();
iter->second->queue.push(*req);
returnval = SQLSUCCESS;
}
else
{
req->error.Id(BAD_DBID);
}
pthread_mutex_unlock(&queue_mutex);
/* XXX: Unlock */
return returnval;
}
return NULL;
}
virtual void OnRehash(userrec* user, const std::string ¶meter)
{
rehashing = true;
}
virtual Version GetVersion()
{
return Version(1,1,0,0,VF_VENDOR|VF_SERVICEPROVIDER,API_VERSION);
}
};
void* DispatcherThread(void* arg)
{
ModuleSQL* thismodule = (ModuleSQL*)arg;
LoadDatabases(thismodule->Conf, thismodule->PublicServerInstance);
/* Connect back to the Notifier */
if ((QueueFD = socket(AF_FAMILY, SOCK_STREAM, 0)) == -1)
{
/* crap, we're out of sockets... */
return NULL;
}
insp_sockaddr addr;
#ifdef IPV6
insp_aton("::1", &addr.sin6_addr);
addr.sin6_family = AF_FAMILY;
addr.sin6_port = htons(MessagePipe->GetPort());
#else
insp_inaddr ia;
insp_aton("127.0.0.1", &ia);
addr.sin_family = AF_FAMILY;
addr.sin_addr = ia;
addr.sin_port = htons(MessagePipe->GetPort());
#endif
if (connect(QueueFD, (sockaddr*)&addr,sizeof(addr)) == -1)
{
/* wtf, we cant connect to it, but we just created it! */
return NULL;
}
while (!giveup)
{
if (thismodule->rehashing)
{
/* XXX: Lock */
pthread_mutex_lock(&queue_mutex);
thismodule->rehashing = false;
LoadDatabases(thismodule->Conf, thismodule->PublicServerInstance);
pthread_mutex_unlock(&queue_mutex);
/* XXX: Unlock */
}
SQLConnection* conn = NULL;
/* XXX: Lock here for safety */
pthread_mutex_lock(&queue_mutex);
for (ConnMap::iterator i = Connections.begin(); i != Connections.end(); i++)
{
if (i->second->queue.totalsize())
{
conn = i->second;
break;
}
}
pthread_mutex_unlock(&queue_mutex);
/* XXX: Unlock */
/* Theres an item! */
if (conn)
{
conn->DoLeadingQuery();
/* XXX: Lock */
pthread_mutex_lock(&queue_mutex);
conn->queue.pop();
pthread_mutex_unlock(&queue_mutex);
/* XXX: Unlock */
}
usleep(50);
}
return NULL;
}
MODULE_INIT(ModuleSQL);
\ No newline at end of file +/* +------------------------------------+ + * | Inspire Internet Relay Chat Daemon | + * +------------------------------------+ + * + * InspIRCd: (C) 2002-2007 InspIRCd Development Team + * See: http://www.inspircd.org/wiki/index.php/Credits + * + * This program is free but copyrighted software; see + * the file COPYING for details. + * + * --------------------------------------------------- + */ + +#include "inspircd.h" +#include <mysql.h> +#include <pthread.h> +#include "users.h" +#include "channels.h" +#include "modules.h" +#include "m_sqlv2.h" + +/* VERSION 2 API: With nonblocking (threaded) requests */ + +/* $ModDesc: SQL Service Provider module for all other m_sql* modules */ +/* $CompileFlags: exec("mysql_config --include") */ +/* $LinkerFlags: exec("mysql_config --libs_r") rpath("mysql_config --libs_r") */ +/* $ModDep: m_sqlv2.h */ + +/* THE NONBLOCKING MYSQL API! + * + * MySQL provides no nonblocking (asyncronous) API of its own, and its developers recommend + * that instead, you should thread your program. This is what i've done here to allow for + * asyncronous SQL requests via mysql. The way this works is as follows: + * + * The module spawns a thread via pthreads, and performs its mysql queries in this thread, + * using a queue with priorities. There is a mutex on either end which prevents two threads + * adjusting the queue at the same time, and crashing the ircd. Every 50 milliseconds, the + * worker thread wakes up, and checks if there is a request at the head of its queue. + * If there is, it processes this request, blocking the worker thread but leaving the ircd + * thread to go about its business as usual. During this period, the ircd thread is able + * to insert futher pending requests into the queue. + * + * Once the processing of a request is complete, it is removed from the incoming queue to + * an outgoing queue, and initialized as a 'response'. The worker thread then signals the + * ircd thread (via a loopback socket) of the fact a result is available, by sending the + * connection ID through the connection. + * + * The ircd thread then mutexes the queue once more, reads the outbound response off the head + * of the queue, and sends it on its way to the original calling module. + * + * XXX: You might be asking "why doesnt he just send the response from within the worker thread?" + * The answer to this is simple. The majority of InspIRCd, and in fact most ircd's are not + * threadsafe. This module is designed to be threadsafe and is careful with its use of threads, + * however, if we were to call a module's OnRequest even from within a thread which was not the + * one the module was originally instantiated upon, there is a chance of all hell breaking loose + * if a module is ever put in a re-enterant state (stack corruption could occur, crashes, data + * corruption, and worse, so DONT think about it until the day comes when InspIRCd is 100% + * gauranteed threadsafe!) + * + * For a diagram of this system please see http://www.inspircd.org/wiki/Mysql2 + */ + + +class SQLConnection; +class Notifier; + + +typedef std::map<std::string, SQLConnection*> ConnMap; +bool giveup = false; +static Module* SQLModule = NULL; +static Notifier* MessagePipe = NULL; +int QueueFD = -1; + + +#if !defined(MYSQL_VERSION_ID) || MYSQL_VERSION_ID<32224 +#define mysql_field_count mysql_num_fields +#endif + +typedef std::deque<SQLresult*> ResultQueue; + +/* A mutex to wrap around queue accesses */ +pthread_mutex_t queue_mutex = PTHREAD_MUTEX_INITIALIZER; + +pthread_mutex_t results_mutex = PTHREAD_MUTEX_INITIALIZER; + +pthread_mutex_t logging_mutex = PTHREAD_MUTEX_INITIALIZER; + +/** Represents a mysql result set + */ +class MySQLresult : public SQLresult +{ + int currentrow; + std::vector<std::string> colnames; + std::vector<SQLfieldList> fieldlists; + SQLfieldMap* fieldmap; + SQLfieldMap fieldmap2; + SQLfieldList emptyfieldlist; + int rows; + public: + + MySQLresult(Module* self, Module* to, MYSQL_RES* res, int affected_rows, unsigned int id) : SQLresult(self, to, id), currentrow(0), fieldmap(NULL) + { + /* A number of affected rows from from mysql_affected_rows. + */ + fieldlists.clear(); + rows = 0; + if (affected_rows >= 1) + { + rows = affected_rows; + fieldlists.resize(rows); + } + unsigned int field_count = 0; + if (res) + { + MYSQL_ROW row; + int n = 0; + while ((row = mysql_fetch_row(res))) + { + if (fieldlists.size() < (unsigned int)rows+1) + { + fieldlists.resize(fieldlists.size()+1); + } + field_count = 0; + MYSQL_FIELD *fields = mysql_fetch_fields(res); + if(mysql_num_fields(res) == 0) + break; + if (fields && mysql_num_fields(res)) + { + colnames.clear(); + while (field_count < mysql_num_fields(res)) + { + std::string a = (fields[field_count].name ? fields[field_count].name : ""); + std::string b = (row[field_count] ? row[field_count] : ""); + SQLfield sqlf(b, !row[field_count]); + colnames.push_back(a); + fieldlists[n].push_back(sqlf); + field_count++; + } + n++; + } + rows++; + } + mysql_free_result(res); + } + } + + MySQLresult(Module* self, Module* to, SQLerror e, unsigned int id) : SQLresult(self, to, id), currentrow(0) + { + rows = 0; + error = e; + } + + ~MySQLresult() + { + } + + virtual int Rows() + { + return rows; + } + + virtual int Cols() + { + return colnames.size(); + } + + virtual std::string ColName(int column) + { + if (column < (int)colnames.size()) + { + return colnames[column]; + } + else + { + throw SQLbadColName(); + } + return ""; + } + + virtual int ColNum(const std::string &column) + { + for (unsigned int i = 0; i < colnames.size(); i++) + { + if (column == colnames[i]) + return i; + } + throw SQLbadColName(); + return 0; + } + + virtual SQLfield GetValue(int row, int column) + { + if ((row >= 0) && (row < rows) && (column >= 0) && (column < Cols())) + { + return fieldlists[row][column]; + } + + throw SQLbadColName(); + + /* XXX: We never actually get here because of the throw */ + return SQLfield("",true); + } + + virtual SQLfieldList& GetRow() + { + if (currentrow < rows) + return fieldlists[currentrow]; + else + return emptyfieldlist; + } + + virtual SQLfieldMap& GetRowMap() + { + fieldmap2.clear(); + + if (currentrow < rows) + { + for (int i = 0; i < Cols(); i++) + { + fieldmap2.insert(std::make_pair(colnames[i],GetValue(currentrow, i))); + } + currentrow++; + } + + return fieldmap2; + } + + virtual SQLfieldList* GetRowPtr() + { + SQLfieldList* fieldlist = new SQLfieldList(); + + if (currentrow < rows) + { + for (int i = 0; i < Rows(); i++) + { + fieldlist->push_back(fieldlists[currentrow][i]); + } + currentrow++; + } + return fieldlist; + } + + virtual SQLfieldMap* GetRowMapPtr() + { + fieldmap = new SQLfieldMap(); + + if (currentrow < rows) + { + for (int i = 0; i < Cols(); i++) + { + fieldmap->insert(std::make_pair(colnames[i],GetValue(currentrow, i))); + } + currentrow++; + } + + return fieldmap; + } + + virtual void Free(SQLfieldMap* fm) + { + delete fm; + } + + virtual void Free(SQLfieldList* fl) + { + delete fl; + } +}; + +class SQLConnection; + +void NotifyMainThread(SQLConnection* connection_with_new_result); + +/** Represents a connection to a mysql database + */ +class SQLConnection : public classbase +{ + protected: + + MYSQL connection; + MYSQL_RES *res; + MYSQL_ROW row; + SQLhost host; + std::map<std::string,std::string> thisrow; + bool Enabled; + + public: + + QueryQueue queue; + ResultQueue rq; + + // This constructor creates an SQLConnection object with the given credentials, but does not connect yet. + SQLConnection(const SQLhost &hi) : host(hi), Enabled(false) + { + } + + ~SQLConnection() + { + Close(); + } + + // This method connects to the database using the credentials supplied to the constructor, and returns + // true upon success. + bool Connect() + { + unsigned int timeout = 1; + mysql_init(&connection); + mysql_options(&connection,MYSQL_OPT_CONNECT_TIMEOUT,(char*)&timeout); + return mysql_real_connect(&connection, host.host.c_str(), host.user.c_str(), host.pass.c_str(), host.name.c_str(), host.port, NULL, 0); + } + + void DoLeadingQuery() + { + if (!CheckConnection()) + return; + + /* Parse the command string and dispatch it to mysql */ + SQLrequest& req = queue.front(); + + /* Pointer to the buffer we screw around with substitution in */ + char* query; + + /* Pointer to the current end of query, where we append new stuff */ + char* queryend; + + /* Total length of the unescaped parameters */ + unsigned long paramlen; + + /* Total length of query, used for binary-safety in mysql_real_query */ + unsigned long querylength = 0; + + paramlen = 0; + + for(ParamL::iterator i = req.query.p.begin(); i != req.query.p.end(); i++) + { + paramlen += i->size(); + } + + /* To avoid a lot of allocations, allocate enough memory for the biggest the escaped query could possibly be. + * sizeofquery + (totalparamlength*2) + 1 + * + * The +1 is for null-terminating the string for mysql_real_escape_string + */ + + query = new char[req.query.q.length() + (paramlen*2) + 1]; + queryend = query; + + /* Okay, now we have a buffer large enough we need to start copying the query into it and escaping and substituting + * the parameters into it... + */ + + for(unsigned long i = 0; i < req.query.q.length(); i++) + { + if(req.query.q[i] == '?') + { + /* We found a place to substitute..what fun. + * use mysql calls to escape and write the + * escaped string onto the end of our query buffer, + * then we "just" need to make sure queryend is + * pointing at the right place. + */ + if(req.query.p.size()) + { + unsigned long len = mysql_real_escape_string(&connection, queryend, req.query.p.front().c_str(), req.query.p.front().length()); + + queryend += len; + req.query.p.pop_front(); + } + else + break; + } + else + { + *queryend = req.query.q[i]; + queryend++; + } + querylength++; + } + + *queryend = 0; + + pthread_mutex_lock(&queue_mutex); + req.query.q = query; + pthread_mutex_unlock(&queue_mutex); + + if (!mysql_real_query(&connection, req.query.q.data(), req.query.q.length())) + { + /* Successfull query */ + res = mysql_use_result(&connection); + unsigned long rows = mysql_affected_rows(&connection); + MySQLresult* r = new MySQLresult(SQLModule, req.GetSource(), res, rows, req.id); + r->dbid = this->GetID(); + r->query = req.query.q; + /* Put this new result onto the results queue. + * XXX: Remember to mutex the queue! + */ + pthread_mutex_lock(&results_mutex); + rq.push_back(r); + pthread_mutex_unlock(&results_mutex); + } + else + { + /* XXX: See /usr/include/mysql/mysqld_error.h for a list of + * possible error numbers and error messages */ + SQLerror e(QREPLY_FAIL, ConvToStr(mysql_errno(&connection)) + std::string(": ") + mysql_error(&connection)); + MySQLresult* r = new MySQLresult(SQLModule, req.GetSource(), e, req.id); + r->dbid = this->GetID(); + r->query = req.query.q; + + pthread_mutex_lock(&results_mutex); + rq.push_back(r); + pthread_mutex_unlock(&results_mutex); + } + + /* Now signal the main thread that we've got a result to process. + * Pass them this connection id as what to examine + */ + + delete[] query; + + NotifyMainThread(this); + } + + bool ConnectionLost() + { + if (&connection) { + return (mysql_ping(&connection) != 0); + } + else return false; + } + + bool CheckConnection() + { + if (ConnectionLost()) { + return Connect(); + } + else return true; + } + + std::string GetError() + { + return mysql_error(&connection); + } + + const std::string& GetID() + { + return host.id; + } + + std::string GetHost() + { + return host.host; + } + + void SetEnable(bool Enable) + { + Enabled = Enable; + } + + bool IsEnabled() + { + return Enabled; + } + + void Close() + { + mysql_close(&connection); + } + + const SQLhost& GetConfHost() + { + return host; + } + +}; + +ConnMap Connections; + +bool HasHost(const SQLhost &host) +{ + for (ConnMap::iterator iter = Connections.begin(); iter != Connections.end(); iter++) + { + if (host == iter->second->GetConfHost()) + return true; + } + return false; +} + +bool HostInConf(ConfigReader* conf, const SQLhost &h) +{ + for(int i = 0; i < conf->Enumerate("database"); i++) + { + SQLhost host; + host.id = conf->ReadValue("database", "id", i); + host.host = conf->ReadValue("database", "hostname", i); + host.port = conf->ReadInteger("database", "port", i, true); + host.name = conf->ReadValue("database", "name", i); + host.user = conf->ReadValue("database", "username", i); + host.pass = conf->ReadValue("database", "password", i); + host.ssl = conf->ReadFlag("database", "ssl", i); + if (h == host) + return true; + } + return false; +} + +void ClearOldConnections(ConfigReader* conf) +{ + ConnMap::iterator i,safei; + for (i = Connections.begin(); i != Connections.end(); i++) + { + if (!HostInConf(conf, i->second->GetConfHost())) + { + DELETE(i->second); + safei = i; + --i; + Connections.erase(safei); + } + } +} + +void ClearAllConnections() +{ + ConnMap::iterator i; + while ((i = Connections.begin()) != Connections.end()) + { + Connections.erase(i); + DELETE(i->second); + } +} + +void ConnectDatabases(InspIRCd* ServerInstance) +{ + for (ConnMap::iterator i = Connections.begin(); i != Connections.end(); i++) + { + if (i->second->IsEnabled()) + continue; + + i->second->SetEnable(true); + if (!i->second->Connect()) + { + /* XXX: MUTEX */ + pthread_mutex_lock(&logging_mutex); + ServerInstance->Log(DEFAULT,"SQL: Failed to connect database "+i->second->GetHost()+": Error: "+i->second->GetError()); + i->second->SetEnable(false); + pthread_mutex_unlock(&logging_mutex); + } + } +} + +void LoadDatabases(ConfigReader* conf, InspIRCd* ServerInstance) +{ + ClearOldConnections(conf); + for (int j =0; j < conf->Enumerate("database"); j++) + { + SQLhost host; + host.id = conf->ReadValue("database", "id", j); + host.host = conf->ReadValue("database", "hostname", j); + host.port = conf->ReadInteger("database", "port", j, true); + host.name = conf->ReadValue("database", "name", j); + host.user = conf->ReadValue("database", "username", j); + host.pass = conf->ReadValue("database", "password", j); + host.ssl = conf->ReadFlag("database", "ssl", j); + + if (HasHost(host)) + continue; + + if (!host.id.empty() && !host.host.empty() && !host.name.empty() && !host.user.empty() && !host.pass.empty()) + { + SQLConnection* ThisSQL = new SQLConnection(host); + Connections[host.id] = ThisSQL; + } + } + ConnectDatabases(ServerInstance); +} + +char FindCharId(const std::string &id) +{ + char i = 1; + for (ConnMap::iterator iter = Connections.begin(); iter != Connections.end(); ++iter, ++i) + { + if (iter->first == id) + { + return i; + } + } + return 0; +} + +ConnMap::iterator GetCharId(char id) +{ + char i = 1; + for (ConnMap::iterator iter = Connections.begin(); iter != Connections.end(); ++iter, ++i) + { + if (i == id) + return iter; + } + return Connections.end(); +} + +void NotifyMainThread(SQLConnection* connection_with_new_result) +{ + /* Here we write() to the socket the main thread has open + * and we connect()ed back to before our thread became active. + * The main thread is using a nonblocking socket tied into + * the socket engine, so they wont block and they'll receive + * nearly instant notification. Because we're in a seperate + * thread, we can just use standard connect(), and we can + * block if we like. We just send the connection id of the + * connection back. + * + * NOTE: We only send a single char down the connection, this + * way we know it wont get a partial read at the other end if + * the system is especially congested (see bug #263). + * The function FindCharId translates a connection name into a + * one character id, and GetCharId translates a character id + * back into an iterator. + */ + char id = FindCharId(connection_with_new_result->GetID()); + send(QueueFD, &id, 1, 0); +} + +void* DispatcherThread(void* arg); + +/** Used by m_mysql to notify one thread when the other has a result + */ +class Notifier : public InspSocket +{ + insp_sockaddr sock_us; + socklen_t uslen; + + + public: + + /* Create a socket on a random port. Let the tcp stack allocate us an available port */ +#ifdef IPV6 + Notifier(InspIRCd* SI) : InspSocket(SI, "::1", 0, true, 3000) +#else + Notifier(InspIRCd* SI) : InspSocket(SI, "127.0.0.1", 0, true, 3000) +#endif + { + uslen = sizeof(sock_us); + if (getsockname(this->fd,(sockaddr*)&sock_us,&uslen)) + { + throw ModuleException("Could not create random listening port on localhost"); + } + } + + Notifier(InspIRCd* SI, int newfd, char* ip) : InspSocket(SI, newfd, ip) + { + } + + /* Using getsockname and ntohs, we can determine which port number we were allocated */ + int GetPort() + { +#ifdef IPV6 + return ntohs(sock_us.sin6_port); +#else + return ntohs(sock_us.sin_port); +#endif + } + + virtual int OnIncomingConnection(int newsock, char* ip) + { + Notifier* n = new Notifier(this->Instance, newsock, ip); + n = n; /* Stop bitching at me, GCC */ + return true; + } + + virtual bool OnDataReady() + { + char data = 0; + /* NOTE: Only a single character is read so we know we + * cant get a partial read. (We've been told that theres + * data waiting, so we wont ever get EAGAIN) + * The function GetCharId translates a single character + * back into an iterator. + */ + if (read(this->GetFd(), &data, 1) > 0) + { + ConnMap::iterator iter = GetCharId(data); + if (iter != Connections.end()) + { + /* Lock the mutex, send back the data */ + pthread_mutex_lock(&results_mutex); + ResultQueue::iterator n = iter->second->rq.begin(); + (*n)->Send(); + iter->second->rq.pop_front(); + pthread_mutex_unlock(&results_mutex); + return true; + } + /* No error, but unknown id */ + return true; + } + + /* Erk, error on descriptor! */ + return false; + } +}; + +/** MySQL module + */ +class ModuleSQL : public Module +{ + public: + + ConfigReader *Conf; + InspIRCd* PublicServerInstance; + pthread_t Dispatcher; + int currid; + bool rehashing; + + ModuleSQL(InspIRCd* Me) + : Module::Module(Me), rehashing(false) + { + ServerInstance->UseInterface("SQLutils"); + + Conf = new ConfigReader(ServerInstance); + PublicServerInstance = ServerInstance; + currid = 0; + SQLModule = this; + + MessagePipe = new Notifier(ServerInstance); + + pthread_attr_t attribs; + pthread_attr_init(&attribs); + pthread_attr_setdetachstate(&attribs, PTHREAD_CREATE_DETACHED); + if (pthread_create(&this->Dispatcher, &attribs, DispatcherThread, (void *)this) != 0) + { + throw ModuleException("m_mysql: Failed to create dispatcher thread: " + std::string(strerror(errno))); + } + + if (!ServerInstance->PublishFeature("SQL", this)) + { + /* Tell worker thread to exit NOW */ + giveup = true; + throw ModuleException("m_mysql: Unable to publish feature 'SQL'"); + } + + ServerInstance->PublishInterface("SQL", this); + } + + virtual ~ModuleSQL() + { + giveup = true; + ClearAllConnections(); + DELETE(Conf); + ServerInstance->UnpublishInterface("SQL", this); + ServerInstance->UnpublishFeature("SQL"); + ServerInstance->DoneWithInterface("SQLutils"); + } + + + void Implements(char* List) + { + List[I_OnRehash] = List[I_OnRequest] = 1; + } + + unsigned long NewID() + { + if (currid+1 == 0) + currid++; + return ++currid; + } + + char* OnRequest(Request* request) + { + if(strcmp(SQLREQID, request->GetId()) == 0) + { + SQLrequest* req = (SQLrequest*)request; + + /* XXX: Lock */ + pthread_mutex_lock(&queue_mutex); + + ConnMap::iterator iter; + + char* returnval = NULL; + + if((iter = Connections.find(req->dbid)) != Connections.end()) + { + req->id = NewID(); + iter->second->queue.push(*req); + returnval = SQLSUCCESS; + } + else + { + req->error.Id(BAD_DBID); + } + + pthread_mutex_unlock(&queue_mutex); + /* XXX: Unlock */ + + return returnval; + } + + return NULL; + } + + virtual void OnRehash(userrec* user, const std::string ¶meter) + { + rehashing = true; + } + + virtual Version GetVersion() + { + return Version(1,1,0,0,VF_VENDOR|VF_SERVICEPROVIDER,API_VERSION); + } + +}; + +void* DispatcherThread(void* arg) +{ + ModuleSQL* thismodule = (ModuleSQL*)arg; + LoadDatabases(thismodule->Conf, thismodule->PublicServerInstance); + + /* Connect back to the Notifier */ + + if ((QueueFD = socket(AF_FAMILY, SOCK_STREAM, 0)) == -1) + { + /* crap, we're out of sockets... */ + return NULL; + } + + insp_sockaddr addr; + +#ifdef IPV6 + insp_aton("::1", &addr.sin6_addr); + addr.sin6_family = AF_FAMILY; + addr.sin6_port = htons(MessagePipe->GetPort()); +#else + insp_inaddr ia; + insp_aton("127.0.0.1", &ia); + addr.sin_family = AF_FAMILY; + addr.sin_addr = ia; + addr.sin_port = htons(MessagePipe->GetPort()); +#endif + + if (connect(QueueFD, (sockaddr*)&addr,sizeof(addr)) == -1) + { + /* wtf, we cant connect to it, but we just created it! */ + return NULL; + } + + while (!giveup) + { + if (thismodule->rehashing) + { + /* XXX: Lock */ + pthread_mutex_lock(&queue_mutex); + thismodule->rehashing = false; + LoadDatabases(thismodule->Conf, thismodule->PublicServerInstance); + pthread_mutex_unlock(&queue_mutex); + /* XXX: Unlock */ + } + + SQLConnection* conn = NULL; + /* XXX: Lock here for safety */ + pthread_mutex_lock(&queue_mutex); + for (ConnMap::iterator i = Connections.begin(); i != Connections.end(); i++) + { + if (i->second->queue.totalsize()) + { + conn = i->second; + break; + } + } + pthread_mutex_unlock(&queue_mutex); + /* XXX: Unlock */ + + /* Theres an item! */ + if (conn) + { + conn->DoLeadingQuery(); + + /* XXX: Lock */ + pthread_mutex_lock(&queue_mutex); + conn->queue.pop(); + pthread_mutex_unlock(&queue_mutex); + /* XXX: Unlock */ + } + + usleep(50); + } + + return NULL; +} + +MODULE_INIT(ModuleSQL); + diff --git a/src/modules/extra/m_pgsql.cpp b/src/modules/extra/m_pgsql.cpp index 9e85a40de..5d267fc1a 100644 --- a/src/modules/extra/m_pgsql.cpp +++ b/src/modules/extra/m_pgsql.cpp @@ -1 +1,984 @@ -/* +------------------------------------+
* | Inspire Internet Relay Chat Daemon |
* +------------------------------------+
*
* InspIRCd: (C) 2002-2007 InspIRCd Development Team
* See: http://www.inspircd.org/wiki/index.php/Credits
*
* This program is free but copyrighted software; see
* the file COPYING for details.
*
* ---------------------------------------------------
*/
#include "inspircd.h"
#include <cstdlib>
#include <sstream>
#include <libpq-fe.h>
#include "users.h"
#include "channels.h"
#include "modules.h"
#include "configreader.h"
#include "m_sqlv2.h"
/* $ModDesc: PostgreSQL Service Provider module for all other m_sql* modules, uses v2 of the SQL API */
/* $CompileFlags: -Iexec("pg_config --includedir") eval("my $s = `pg_config --version`;$s =~ /^.*?(\d+)\.(\d+)\.(\d+).*?$/;my $v = hex(sprintf("0x%02x%02x%02x", $1, $2, $3));print "-DPGSQL_HAS_ESCAPECONN" if(($v >= 0x080104) || ($v >= 0x07030F && $v < 0x070400) || ($v >= 0x07040D && $v < 0x080000) || ($v >= 0x080008 && $v < 0x080100));") */
/* $LinkerFlags: -Lexec("pg_config --libdir") -lpq */
/* $ModDep: m_sqlv2.h */
/* SQLConn rewritten by peavey to
* use EventHandler instead of
* InspSocket. This is much neater
* and gives total control of destroy
* and delete of resources.
*/
/* Forward declare, so we can have the typedef neatly at the top */
class SQLConn;
typedef std::map<std::string, SQLConn*> ConnMap;
/* CREAD, Connecting and wants read event
* CWRITE, Connecting and wants write event
* WREAD, Connected/Working and wants read event
* WWRITE, Connected/Working and wants write event
* RREAD, Resetting and wants read event
* RWRITE, Resetting and wants write event
*/
enum SQLstatus { CREAD, CWRITE, WREAD, WWRITE, RREAD, RWRITE };
/** SQLhost::GetDSN() - Overload to return correct DSN for PostgreSQL
*/
std::string SQLhost::GetDSN()
{
std::ostringstream conninfo("connect_timeout = '2'");
if (ip.length())
conninfo << " hostaddr = '" << ip << "'";
if (port)
conninfo << " port = '" << port << "'";
if (name.length())
conninfo << " dbname = '" << name << "'";
if (user.length())
conninfo << " user = '" << user << "'";
if (pass.length())
conninfo << " password = '" << pass << "'";
if (ssl)
{
conninfo << " sslmode = 'require'";
}
else
{
conninfo << " sslmode = 'disable'";
}
return conninfo.str();
}
class ReconnectTimer : public InspTimer
{
private:
Module* mod;
public:
ReconnectTimer(InspIRCd* SI, Module* m)
: InspTimer(5, SI->Time(), false), mod(m)
{
}
virtual void Tick(time_t TIME);
};
/** Used to resolve sql server hostnames
*/
class SQLresolver : public Resolver
{
private:
SQLhost host;
Module* mod;
public:
SQLresolver(Module* m, InspIRCd* Instance, const SQLhost& hi, bool &cached)
: Resolver(Instance, hi.host, DNS_QUERY_FORWARD, cached, (Module*)m), host(hi), mod(m)
{
}
virtual void OnLookupComplete(const std::string &result, unsigned int ttl, bool cached);
virtual void OnError(ResolverError e, const std::string &errormessage)
{
ServerInstance->Log(DEBUG, "PgSQL: DNS lookup failed (%s), dying horribly", errormessage.c_str());
}
};
/** PgSQLresult is a subclass of the mostly-pure-virtual class SQLresult.
* All SQL providers must create their own subclass and define it's methods using that
* database library's data retriveal functions. The aim is to avoid a slow and inefficient process
* of converting all data to a common format before it reaches the result structure. This way
* data is passes to the module nearly as directly as if it was using the API directly itself.
*/
class PgSQLresult : public SQLresult
{
PGresult* res;
int currentrow;
int rows;
int cols;
SQLfieldList* fieldlist;
SQLfieldMap* fieldmap;
public:
PgSQLresult(Module* self, Module* to, unsigned long id, PGresult* result)
: SQLresult(self, to, id), res(result), currentrow(0), fieldlist(NULL), fieldmap(NULL)
{
rows = PQntuples(res);
cols = PQnfields(res);
}
~PgSQLresult()
{
/* If we allocated these, free them... */
if(fieldlist)
DELETE(fieldlist);
if(fieldmap)
DELETE(fieldmap);
PQclear(res);
}
virtual int Rows()
{
if(!cols && !rows)
{
return atoi(PQcmdTuples(res));
}
else
{
return rows;
}
}
virtual int Cols()
{
return PQnfields(res);
}
virtual std::string ColName(int column)
{
char* name = PQfname(res, column);
return (name) ? name : "";
}
virtual int ColNum(const std::string &column)
{
int n = PQfnumber(res, column.c_str());
if(n == -1)
{
throw SQLbadColName();
}
else
{
return n;
}
}
virtual SQLfield GetValue(int row, int column)
{
char* v = PQgetvalue(res, row, column);
if(v)
{
return SQLfield(std::string(v, PQgetlength(res, row, column)), PQgetisnull(res, row, column));
}
else
{
throw SQLbadColName();
}
}
virtual SQLfieldList& GetRow()
{
/* In an effort to reduce overhead we don't actually allocate the list
* until the first time it's needed...so...
*/
if(fieldlist)
{
fieldlist->clear();
}
else
{
fieldlist = new SQLfieldList;
}
if(currentrow < PQntuples(res))
{
int cols = PQnfields(res);
for(int i = 0; i < cols; i++)
{
fieldlist->push_back(GetValue(currentrow, i));
}
currentrow++;
}
return *fieldlist;
}
virtual SQLfieldMap& GetRowMap()
{
/* In an effort to reduce overhead we don't actually allocate the map
* until the first time it's needed...so...
*/
if(fieldmap)
{
fieldmap->clear();
}
else
{
fieldmap = new SQLfieldMap;
}
if(currentrow < PQntuples(res))
{
int cols = PQnfields(res);
for(int i = 0; i < cols; i++)
{
fieldmap->insert(std::make_pair(ColName(i), GetValue(currentrow, i)));
}
currentrow++;
}
return *fieldmap;
}
virtual SQLfieldList* GetRowPtr()
{
SQLfieldList* fl = new SQLfieldList;
if(currentrow < PQntuples(res))
{
int cols = PQnfields(res);
for(int i = 0; i < cols; i++)
{
fl->push_back(GetValue(currentrow, i));
}
currentrow++;
}
return fl;
}
virtual SQLfieldMap* GetRowMapPtr()
{
SQLfieldMap* fm = new SQLfieldMap;
if(currentrow < PQntuples(res))
{
int cols = PQnfields(res);
for(int i = 0; i < cols; i++)
{
fm->insert(std::make_pair(ColName(i), GetValue(currentrow, i)));
}
currentrow++;
}
return fm;
}
virtual void Free(SQLfieldMap* fm)
{
DELETE(fm);
}
virtual void Free(SQLfieldList* fl)
{
DELETE(fl);
}
};
/** SQLConn represents one SQL session.
*/
class SQLConn : public EventHandler
{
private:
InspIRCd* Instance;
SQLhost confhost; /* The <database> entry */
Module* us; /* Pointer to the SQL provider itself */
PGconn* sql; /* PgSQL database connection handle */
SQLstatus status; /* PgSQL database connection status */
bool qinprog; /* If there is currently a query in progress */
QueryQueue queue; /* Queue of queries waiting to be executed on this connection */
time_t idle; /* Time we last heard from the database */
public:
SQLConn(InspIRCd* SI, Module* self, const SQLhost& hi)
: EventHandler(), Instance(SI), confhost(hi), us(self), sql(NULL), status(CWRITE), qinprog(false)
{
idle = this->Instance->Time();
if(!DoConnect())
{
Instance->Log(DEFAULT, "WARNING: Could not connect to database with id: " + ConvToStr(hi.id));
DelayReconnect();
}
}
~SQLConn()
{
Close();
}
virtual void HandleEvent(EventType et, int errornum)
{
switch (et)
{
case EVENT_READ:
OnDataReady();
break;
case EVENT_WRITE:
OnWriteReady();
break;
case EVENT_ERROR:
DelayReconnect();
break;
default:
break;
}
}
bool DoConnect()
{
if(!(sql = PQconnectStart(confhost.GetDSN().c_str())))
return false;
if(PQstatus(sql) == CONNECTION_BAD)
return false;
if(PQsetnonblocking(sql, 1) == -1)
return false;
/* OK, we've initalised the connection, now to get it hooked into the socket engine
* and then start polling it.
*/
this->fd = PQsocket(sql);
if(this->fd <= -1)
return false;
if (!this->Instance->SE->AddFd(this))
{
Instance->Log(DEBUG, "BUG: Couldn't add pgsql socket to socket engine");
return false;
}
/* Socket all hooked into the engine, now to tell PgSQL to start connecting */
return DoPoll();
}
bool DoPoll()
{
switch(PQconnectPoll(sql))
{
case PGRES_POLLING_WRITING:
Instance->SE->WantWrite(this);
status = CWRITE;
return true;
case PGRES_POLLING_READING:
status = CREAD;
return true;
case PGRES_POLLING_FAILED:
return false;
case PGRES_POLLING_OK:
status = WWRITE;
return DoConnectedPoll();
default:
return true;
}
}
bool DoConnectedPoll()
{
if(!qinprog && queue.totalsize())
{
/* There's no query currently in progress, and there's queries in the queue. */
SQLrequest& query = queue.front();
DoQuery(query);
}
if(PQconsumeInput(sql))
{
/* We just read stuff from the server, that counts as it being alive
* so update the idle-since time :p
*/
idle = this->Instance->Time();
if (PQisBusy(sql))
{
/* Nothing happens here */
}
else if (qinprog)
{
/* Grab the request we're processing */
SQLrequest& query = queue.front();
/* Get a pointer to the module we're about to return the result to */
Module* to = query.GetSource();
/* Fetch the result.. */
PGresult* result = PQgetResult(sql);
/* PgSQL would allow a query string to be sent which has multiple
* queries in it, this isn't portable across database backends and
* we don't want modules doing it. But just in case we make sure we
* drain any results there are and just use the last one.
* If the module devs are behaving there will only be one result.
*/
while (PGresult* temp = PQgetResult(sql))
{
PQclear(result);
result = temp;
}
if(to)
{
/* ..and the result */
PgSQLresult reply(us, to, query.id, result);
/* Fix by brain, make sure the original query gets sent back in the reply */
reply.query = query.query.q;
switch(PQresultStatus(result))
{
case PGRES_EMPTY_QUERY:
case PGRES_BAD_RESPONSE:
case PGRES_FATAL_ERROR:
reply.error.Id(QREPLY_FAIL);
reply.error.Str(PQresultErrorMessage(result));
default:;
/* No action, other values are not errors */
}
reply.Send();
/* PgSQLresult's destructor will free the PGresult */
}
else
{
/* If the client module is unloaded partway through a query then the provider will set
* the pointer to NULL. We cannot just cancel the query as the result will still come
* through at some point...and it could get messy if we play with invalid pointers...
*/
PQclear(result);
}
qinprog = false;
queue.pop();
DoConnectedPoll();
}
return true;
}
else
{
/* I think we'll assume this means the server died...it might not,
* but I think that any error serious enough we actually get here
* deserves to reconnect [/excuse]
* Returning true so the core doesn't try and close the connection.
*/
DelayReconnect();
return true;
}
}
bool DoResetPoll()
{
switch(PQresetPoll(sql))
{
case PGRES_POLLING_WRITING:
Instance->SE->WantWrite(this);
status = CWRITE;
return DoPoll();
case PGRES_POLLING_READING:
status = CREAD;
return true;
case PGRES_POLLING_FAILED:
return false;
case PGRES_POLLING_OK:
status = WWRITE;
return DoConnectedPoll();
default:
return true;
}
}
bool OnDataReady()
{
/* Always return true here, false would close the socket - we need to do that ourselves with the pgsql API */
return DoEvent();
}
bool OnWriteReady()
{
/* Always return true here, false would close the socket - we need to do that ourselves with the pgsql API */
return DoEvent();
}
bool OnConnected()
{
return DoEvent();
}
void DelayReconnect();
bool DoEvent()
{
bool ret;
if((status == CREAD) || (status == CWRITE))
{
ret = DoPoll();
}
else if((status == RREAD) || (status == RWRITE))
{
ret = DoResetPoll();
}
else
{
ret = DoConnectedPoll();
}
return ret;
}
SQLerror DoQuery(SQLrequest &req)
{
if((status == WREAD) || (status == WWRITE))
{
if(!qinprog)
{
/* Parse the command string and dispatch it */
/* Pointer to the buffer we screw around with substitution in */
char* query;
/* Pointer to the current end of query, where we append new stuff */
char* queryend;
/* Total length of the unescaped parameters */
unsigned int paramlen;
paramlen = 0;
for(ParamL::iterator i = req.query.p.begin(); i != req.query.p.end(); i++)
{
paramlen += i->size();
}
/* To avoid a lot of allocations, allocate enough memory for the biggest the escaped query could possibly be.
* sizeofquery + (totalparamlength*2) + 1
*
* The +1 is for null-terminating the string for PQsendQuery()
*/
query = new char[req.query.q.length() + (paramlen*2) + 1];
queryend = query;
/* Okay, now we have a buffer large enough we need to start copying the query into it and escaping and substituting
* the parameters into it...
*/
for(unsigned int i = 0; i < req.query.q.length(); i++)
{
if(req.query.q[i] == '?')
{
/* We found a place to substitute..what fun.
* Use the PgSQL calls to escape and write the
* escaped string onto the end of our query buffer,
* then we "just" need to make sure queryend is
* pointing at the right place.
*/
if(req.query.p.size())
{
int error = 0;
size_t len = 0;
#ifdef PGSQL_HAS_ESCAPECONN
len = PQescapeStringConn(sql, queryend, req.query.p.front().c_str(), req.query.p.front().length(), &error);
#else
len = PQescapeString (queryend, req.query.p.front().c_str(), req.query.p.front().length());
#endif
if(error)
{
Instance->Log(DEBUG, "BUG: Apparently PQescapeStringConn() failed somehow...don't know how or what to do...");
}
/* Incremenet queryend to the end of the newly escaped parameter */
queryend += len;
/* Remove the parameter we just substituted in */
req.query.p.pop_front();
}
else
{
Instance->Log(DEBUG, "BUG: Found a substitution location but no parameter to substitute :|");
break;
}
}
else
{
*queryend = req.query.q[i];
queryend++;
}
}
/* Null-terminate the query */
*queryend = 0;
req.query.q = query;
if(PQsendQuery(sql, query))
{
qinprog = true;
delete[] query;
return SQLerror();
}
else
{
delete[] query;
return SQLerror(QSEND_FAIL, PQerrorMessage(sql));
}
}
}
return SQLerror(BAD_CONN, "Can't query until connection is complete");
}
SQLerror Query(const SQLrequest &req)
{
queue.push(req);
if(!qinprog && queue.totalsize())
{
/* There's no query currently in progress, and there's queries in the queue. */
SQLrequest& query = queue.front();
return DoQuery(query);
}
else
{
return SQLerror();
}
}
void OnUnloadModule(Module* mod)
{
queue.PurgeModule(mod);
}
const SQLhost GetConfHost()
{
return confhost;
}
void Close() {
if (!this->Instance->SE->DelFd(this))
{
if (sql && PQstatus(sql) == CONNECTION_BAD)
{
this->Instance->SE->DelFd(this, true);
}
else
{
Instance->Log(DEBUG, "BUG: PQsocket cant be removed from socket engine!");
}
}
if(sql)
{
PQfinish(sql);
sql = NULL;
}
}
};
class ModulePgSQL : public Module
{
private:
ConnMap connections;
unsigned long currid;
char* sqlsuccess;
ReconnectTimer* retimer;
public:
ModulePgSQL(InspIRCd* Me)
: Module::Module(Me), currid(0)
{
ServerInstance->UseInterface("SQLutils");
sqlsuccess = new char[strlen(SQLSUCCESS)+1];
strlcpy(sqlsuccess, SQLSUCCESS, strlen(SQLSUCCESS));
if (!ServerInstance->PublishFeature("SQL", this))
{
throw ModuleException("BUG: PgSQL Unable to publish feature 'SQL'");
}
ReadConf();
ServerInstance->PublishInterface("SQL", this);
}
virtual ~ModulePgSQL()
{
if (retimer)
ServerInstance->Timers->DelTimer(retimer);
ClearAllConnections();
delete[] sqlsuccess;
ServerInstance->UnpublishInterface("SQL", this);
ServerInstance->UnpublishFeature("SQL");
ServerInstance->DoneWithInterface("SQLutils");
}
void Implements(char* List)
{
List[I_OnUnloadModule] = List[I_OnRequest] = List[I_OnRehash] = List[I_OnUserRegister] = List[I_OnCheckReady] = List[I_OnUserDisconnect] = 1;
}
virtual void OnRehash(userrec* user, const std::string ¶meter)
{
ReadConf();
}
bool HasHost(const SQLhost &host)
{
for (ConnMap::iterator iter = connections.begin(); iter != connections.end(); iter++)
{
if (host == iter->second->GetConfHost())
return true;
}
return false;
}
bool HostInConf(const SQLhost &h)
{
ConfigReader conf(ServerInstance);
for(int i = 0; i < conf.Enumerate("database"); i++)
{
SQLhost host;
host.id = conf.ReadValue("database", "id", i);
host.host = conf.ReadValue("database", "hostname", i);
host.port = conf.ReadInteger("database", "port", i, true);
host.name = conf.ReadValue("database", "name", i);
host.user = conf.ReadValue("database", "username", i);
host.pass = conf.ReadValue("database", "password", i);
host.ssl = conf.ReadFlag("database", "ssl", "0", i);
if (h == host)
return true;
}
return false;
}
void ReadConf()
{
ClearOldConnections();
ConfigReader conf(ServerInstance);
for(int i = 0; i < conf.Enumerate("database"); i++)
{
SQLhost host;
int ipvalid;
host.id = conf.ReadValue("database", "id", i);
host.host = conf.ReadValue("database", "hostname", i);
host.port = conf.ReadInteger("database", "port", i, true);
host.name = conf.ReadValue("database", "name", i);
host.user = conf.ReadValue("database", "username", i);
host.pass = conf.ReadValue("database", "password", i);
host.ssl = conf.ReadFlag("database", "ssl", "0", i);
if (HasHost(host))
continue;
#ifdef IPV6
if (strchr(host.host.c_str(),':'))
{
in6_addr blargle;
ipvalid = inet_pton(AF_INET6, host.host.c_str(), &blargle);
}
else
#endif
{
in_addr blargle;
ipvalid = inet_aton(host.host.c_str(), &blargle);
}
if(ipvalid > 0)
{
/* The conversion succeeded, we were given an IP and we can give it straight to SQLConn */
host.ip = host.host;
this->AddConn(host);
}
else if(ipvalid == 0)
{
/* Conversion failed, assume it's a host */
SQLresolver* resolver;
try
{
bool cached;
resolver = new SQLresolver(this, ServerInstance, host, cached);
ServerInstance->AddResolver(resolver, cached);
}
catch(...)
{
/* THE WORLD IS COMING TO AN END! */
}
}
else
{
/* Invalid address family, die horribly. */
ServerInstance->Log(DEBUG, "BUG: insp_aton failed returning -1, oh noes.");
}
}
}
void ClearOldConnections()
{
ConnMap::iterator iter,safei;
for (iter = connections.begin(); iter != connections.end(); iter++)
{
if (!HostInConf(iter->second->GetConfHost()))
{
DELETE(iter->second);
safei = iter;
--iter;
connections.erase(safei);
}
}
}
void ClearAllConnections()
{
ConnMap::iterator i;
while ((i = connections.begin()) != connections.end())
{
connections.erase(i);
DELETE(i->second);
}
}
void AddConn(const SQLhost& hi)
{
if (HasHost(hi))
{
ServerInstance->Log(DEFAULT, "WARNING: A pgsql connection with id: %s already exists, possibly due to DNS delay. Aborting connection attempt.", hi.id.c_str());
return;
}
SQLConn* newconn;
/* The conversion succeeded, we were given an IP and we can give it straight to SQLConn */
newconn = new SQLConn(ServerInstance, this, hi);
connections.insert(std::make_pair(hi.id, newconn));
}
void ReconnectConn(SQLConn* conn)
{
for (ConnMap::iterator iter = connections.begin(); iter != connections.end(); iter++)
{
if (conn == iter->second)
{
DELETE(iter->second);
connections.erase(iter);
break;
}
}
retimer = new ReconnectTimer(ServerInstance, this);
ServerInstance->Timers->AddTimer(retimer);
}
virtual char* OnRequest(Request* request)
{
if(strcmp(SQLREQID, request->GetId()) == 0)
{
SQLrequest* req = (SQLrequest*)request;
ConnMap::iterator iter;
if((iter = connections.find(req->dbid)) != connections.end())
{
/* Execute query */
req->id = NewID();
req->error = iter->second->Query(*req);
return (req->error.Id() == NO_ERROR) ? sqlsuccess : NULL;
}
else
{
req->error.Id(BAD_DBID);
return NULL;
}
}
return NULL;
}
virtual void OnUnloadModule(Module* mod, const std::string& name)
{
/* When a module unloads we have to check all the pending queries for all our connections
* and set the Module* specifying where the query came from to NULL. If the query has already
* been dispatched then when it is processed it will be dropped if the pointer is NULL.
*
* If the queries we find are not already being executed then we can simply remove them immediately.
*/
for(ConnMap::iterator iter = connections.begin(); iter != connections.end(); iter++)
{
iter->second->OnUnloadModule(mod);
}
}
unsigned long NewID()
{
if (currid+1 == 0)
currid++;
return ++currid;
}
virtual Version GetVersion()
{
return Version(1, 1, 0, 0, VF_VENDOR|VF_SERVICEPROVIDER, API_VERSION);
}
};
/* move this here to use AddConn, rather that than having the whole
* module above SQLConn, since this is buggin me right now :/
*/
void SQLresolver::OnLookupComplete(const std::string &result, unsigned int ttl, bool cached)
{
host.ip = result;
((ModulePgSQL*)mod)->AddConn(host);
((ModulePgSQL*)mod)->ClearOldConnections();
}
void ReconnectTimer::Tick(time_t time)
{
((ModulePgSQL*)mod)->ReadConf();
}
void SQLConn::DelayReconnect()
{
((ModulePgSQL*)us)->ReconnectConn(this);
}
MODULE_INIT(ModulePgSQL);
\ No newline at end of file +/* +------------------------------------+ + * | Inspire Internet Relay Chat Daemon | + * +------------------------------------+ + * + * InspIRCd: (C) 2002-2007 InspIRCd Development Team + * See: http://www.inspircd.org/wiki/index.php/Credits + * + * This program is free but copyrighted software; see + * the file COPYING for details. + * + * --------------------------------------------------- + */ + +#include "inspircd.h" +#include <cstdlib> +#include <sstream> +#include <libpq-fe.h> +#include "users.h" +#include "channels.h" +#include "modules.h" +#include "configreader.h" +#include "m_sqlv2.h" + +/* $ModDesc: PostgreSQL Service Provider module for all other m_sql* modules, uses v2 of the SQL API */ +/* $CompileFlags: -Iexec("pg_config --includedir") eval("my $s = `pg_config --version`;$s =~ /^.*?(\d+)\.(\d+)\.(\d+).*?$/;my $v = hex(sprintf("0x%02x%02x%02x", $1, $2, $3));print "-DPGSQL_HAS_ESCAPECONN" if(($v >= 0x080104) || ($v >= 0x07030F && $v < 0x070400) || ($v >= 0x07040D && $v < 0x080000) || ($v >= 0x080008 && $v < 0x080100));") */ +/* $LinkerFlags: -Lexec("pg_config --libdir") -lpq */ +/* $ModDep: m_sqlv2.h */ + + +/* SQLConn rewritten by peavey to + * use EventHandler instead of + * InspSocket. This is much neater + * and gives total control of destroy + * and delete of resources. + */ + +/* Forward declare, so we can have the typedef neatly at the top */ +class SQLConn; + +typedef std::map<std::string, SQLConn*> ConnMap; + +/* CREAD, Connecting and wants read event + * CWRITE, Connecting and wants write event + * WREAD, Connected/Working and wants read event + * WWRITE, Connected/Working and wants write event + * RREAD, Resetting and wants read event + * RWRITE, Resetting and wants write event + */ +enum SQLstatus { CREAD, CWRITE, WREAD, WWRITE, RREAD, RWRITE }; + +/** SQLhost::GetDSN() - Overload to return correct DSN for PostgreSQL + */ +std::string SQLhost::GetDSN() +{ + std::ostringstream conninfo("connect_timeout = '2'"); + + if (ip.length()) + conninfo << " hostaddr = '" << ip << "'"; + + if (port) + conninfo << " port = '" << port << "'"; + + if (name.length()) + conninfo << " dbname = '" << name << "'"; + + if (user.length()) + conninfo << " user = '" << user << "'"; + + if (pass.length()) + conninfo << " password = '" << pass << "'"; + + if (ssl) + { + conninfo << " sslmode = 'require'"; + } + else + { + conninfo << " sslmode = 'disable'"; + } + + return conninfo.str(); +} + +class ReconnectTimer : public InspTimer +{ + private: + Module* mod; + public: + ReconnectTimer(InspIRCd* SI, Module* m) + : InspTimer(5, SI->Time(), false), mod(m) + { + } + virtual void Tick(time_t TIME); +}; + + +/** Used to resolve sql server hostnames + */ +class SQLresolver : public Resolver +{ + private: + SQLhost host; + Module* mod; + public: + SQLresolver(Module* m, InspIRCd* Instance, const SQLhost& hi, bool &cached) + : Resolver(Instance, hi.host, DNS_QUERY_FORWARD, cached, (Module*)m), host(hi), mod(m) + { + } + + virtual void OnLookupComplete(const std::string &result, unsigned int ttl, bool cached); + + virtual void OnError(ResolverError e, const std::string &errormessage) + { + ServerInstance->Log(DEBUG, "PgSQL: DNS lookup failed (%s), dying horribly", errormessage.c_str()); + } +}; + +/** PgSQLresult is a subclass of the mostly-pure-virtual class SQLresult. + * All SQL providers must create their own subclass and define it's methods using that + * database library's data retriveal functions. The aim is to avoid a slow and inefficient process + * of converting all data to a common format before it reaches the result structure. This way + * data is passes to the module nearly as directly as if it was using the API directly itself. + */ + +class PgSQLresult : public SQLresult +{ + PGresult* res; + int currentrow; + int rows; + int cols; + + SQLfieldList* fieldlist; + SQLfieldMap* fieldmap; +public: + PgSQLresult(Module* self, Module* to, unsigned long id, PGresult* result) + : SQLresult(self, to, id), res(result), currentrow(0), fieldlist(NULL), fieldmap(NULL) + { + rows = PQntuples(res); + cols = PQnfields(res); + } + + ~PgSQLresult() + { + /* If we allocated these, free them... */ + if(fieldlist) + DELETE(fieldlist); + + if(fieldmap) + DELETE(fieldmap); + + PQclear(res); + } + + virtual int Rows() + { + if(!cols && !rows) + { + return atoi(PQcmdTuples(res)); + } + else + { + return rows; + } + } + + virtual int Cols() + { + return PQnfields(res); + } + + virtual std::string ColName(int column) + { + char* name = PQfname(res, column); + + return (name) ? name : ""; + } + + virtual int ColNum(const std::string &column) + { + int n = PQfnumber(res, column.c_str()); + + if(n == -1) + { + throw SQLbadColName(); + } + else + { + return n; + } + } + + virtual SQLfield GetValue(int row, int column) + { + char* v = PQgetvalue(res, row, column); + + if(v) + { + return SQLfield(std::string(v, PQgetlength(res, row, column)), PQgetisnull(res, row, column)); + } + else + { + throw SQLbadColName(); + } + } + + virtual SQLfieldList& GetRow() + { + /* In an effort to reduce overhead we don't actually allocate the list + * until the first time it's needed...so... + */ + if(fieldlist) + { + fieldlist->clear(); + } + else + { + fieldlist = new SQLfieldList; + } + + if(currentrow < PQntuples(res)) + { + int cols = PQnfields(res); + + for(int i = 0; i < cols; i++) + { + fieldlist->push_back(GetValue(currentrow, i)); + } + + currentrow++; + } + + return *fieldlist; + } + + virtual SQLfieldMap& GetRowMap() + { + /* In an effort to reduce overhead we don't actually allocate the map + * until the first time it's needed...so... + */ + if(fieldmap) + { + fieldmap->clear(); + } + else + { + fieldmap = new SQLfieldMap; + } + + if(currentrow < PQntuples(res)) + { + int cols = PQnfields(res); + + for(int i = 0; i < cols; i++) + { + fieldmap->insert(std::make_pair(ColName(i), GetValue(currentrow, i))); + } + + currentrow++; + } + + return *fieldmap; + } + + virtual SQLfieldList* GetRowPtr() + { + SQLfieldList* fl = new SQLfieldList; + + if(currentrow < PQntuples(res)) + { + int cols = PQnfields(res); + + for(int i = 0; i < cols; i++) + { + fl->push_back(GetValue(currentrow, i)); + } + + currentrow++; + } + + return fl; + } + + virtual SQLfieldMap* GetRowMapPtr() + { + SQLfieldMap* fm = new SQLfieldMap; + + if(currentrow < PQntuples(res)) + { + int cols = PQnfields(res); + + for(int i = 0; i < cols; i++) + { + fm->insert(std::make_pair(ColName(i), GetValue(currentrow, i))); + } + + currentrow++; + } + + return fm; + } + + virtual void Free(SQLfieldMap* fm) + { + DELETE(fm); + } + + virtual void Free(SQLfieldList* fl) + { + DELETE(fl); + } +}; + +/** SQLConn represents one SQL session. + */ +class SQLConn : public EventHandler +{ + private: + InspIRCd* Instance; + SQLhost confhost; /* The <database> entry */ + Module* us; /* Pointer to the SQL provider itself */ + PGconn* sql; /* PgSQL database connection handle */ + SQLstatus status; /* PgSQL database connection status */ + bool qinprog; /* If there is currently a query in progress */ + QueryQueue queue; /* Queue of queries waiting to be executed on this connection */ + time_t idle; /* Time we last heard from the database */ + + public: + SQLConn(InspIRCd* SI, Module* self, const SQLhost& hi) + : EventHandler(), Instance(SI), confhost(hi), us(self), sql(NULL), status(CWRITE), qinprog(false) + { + idle = this->Instance->Time(); + if(!DoConnect()) + { + Instance->Log(DEFAULT, "WARNING: Could not connect to database with id: " + ConvToStr(hi.id)); + DelayReconnect(); + } + } + + ~SQLConn() + { + Close(); + } + + virtual void HandleEvent(EventType et, int errornum) + { + switch (et) + { + case EVENT_READ: + OnDataReady(); + break; + + case EVENT_WRITE: + OnWriteReady(); + break; + + case EVENT_ERROR: + DelayReconnect(); + break; + + default: + break; + } + } + + bool DoConnect() + { + if(!(sql = PQconnectStart(confhost.GetDSN().c_str()))) + return false; + + if(PQstatus(sql) == CONNECTION_BAD) + return false; + + if(PQsetnonblocking(sql, 1) == -1) + return false; + + /* OK, we've initalised the connection, now to get it hooked into the socket engine + * and then start polling it. + */ + this->fd = PQsocket(sql); + + if(this->fd <= -1) + return false; + + if (!this->Instance->SE->AddFd(this)) + { + Instance->Log(DEBUG, "BUG: Couldn't add pgsql socket to socket engine"); + return false; + } + + /* Socket all hooked into the engine, now to tell PgSQL to start connecting */ + return DoPoll(); + } + + bool DoPoll() + { + switch(PQconnectPoll(sql)) + { + case PGRES_POLLING_WRITING: + Instance->SE->WantWrite(this); + status = CWRITE; + return true; + case PGRES_POLLING_READING: + status = CREAD; + return true; + case PGRES_POLLING_FAILED: + return false; + case PGRES_POLLING_OK: + status = WWRITE; + return DoConnectedPoll(); + default: + return true; + } + } + + bool DoConnectedPoll() + { + if(!qinprog && queue.totalsize()) + { + /* There's no query currently in progress, and there's queries in the queue. */ + SQLrequest& query = queue.front(); + DoQuery(query); + } + + if(PQconsumeInput(sql)) + { + /* We just read stuff from the server, that counts as it being alive + * so update the idle-since time :p + */ + idle = this->Instance->Time(); + + if (PQisBusy(sql)) + { + /* Nothing happens here */ + } + else if (qinprog) + { + /* Grab the request we're processing */ + SQLrequest& query = queue.front(); + + /* Get a pointer to the module we're about to return the result to */ + Module* to = query.GetSource(); + + /* Fetch the result.. */ + PGresult* result = PQgetResult(sql); + + /* PgSQL would allow a query string to be sent which has multiple + * queries in it, this isn't portable across database backends and + * we don't want modules doing it. But just in case we make sure we + * drain any results there are and just use the last one. + * If the module devs are behaving there will only be one result. + */ + while (PGresult* temp = PQgetResult(sql)) + { + PQclear(result); + result = temp; + } + + if(to) + { + /* ..and the result */ + PgSQLresult reply(us, to, query.id, result); + + /* Fix by brain, make sure the original query gets sent back in the reply */ + reply.query = query.query.q; + + switch(PQresultStatus(result)) + { + case PGRES_EMPTY_QUERY: + case PGRES_BAD_RESPONSE: + case PGRES_FATAL_ERROR: + reply.error.Id(QREPLY_FAIL); + reply.error.Str(PQresultErrorMessage(result)); + default:; + /* No action, other values are not errors */ + } + + reply.Send(); + + /* PgSQLresult's destructor will free the PGresult */ + } + else + { + /* If the client module is unloaded partway through a query then the provider will set + * the pointer to NULL. We cannot just cancel the query as the result will still come + * through at some point...and it could get messy if we play with invalid pointers... + */ + PQclear(result); + } + qinprog = false; + queue.pop(); + DoConnectedPoll(); + } + return true; + } + else + { + /* I think we'll assume this means the server died...it might not, + * but I think that any error serious enough we actually get here + * deserves to reconnect [/excuse] + * Returning true so the core doesn't try and close the connection. + */ + DelayReconnect(); + return true; + } + } + + bool DoResetPoll() + { + switch(PQresetPoll(sql)) + { + case PGRES_POLLING_WRITING: + Instance->SE->WantWrite(this); + status = CWRITE; + return DoPoll(); + case PGRES_POLLING_READING: + status = CREAD; + return true; + case PGRES_POLLING_FAILED: + return false; + case PGRES_POLLING_OK: + status = WWRITE; + return DoConnectedPoll(); + default: + return true; + } + } + + bool OnDataReady() + { + /* Always return true here, false would close the socket - we need to do that ourselves with the pgsql API */ + return DoEvent(); + } + + bool OnWriteReady() + { + /* Always return true here, false would close the socket - we need to do that ourselves with the pgsql API */ + return DoEvent(); + } + + bool OnConnected() + { + return DoEvent(); + } + + void DelayReconnect(); + + bool DoEvent() + { + bool ret; + + if((status == CREAD) || (status == CWRITE)) + { + ret = DoPoll(); + } + else if((status == RREAD) || (status == RWRITE)) + { + ret = DoResetPoll(); + } + else + { + ret = DoConnectedPoll(); + } + return ret; + } + + SQLerror DoQuery(SQLrequest &req) + { + if((status == WREAD) || (status == WWRITE)) + { + if(!qinprog) + { + /* Parse the command string and dispatch it */ + + /* Pointer to the buffer we screw around with substitution in */ + char* query; + /* Pointer to the current end of query, where we append new stuff */ + char* queryend; + /* Total length of the unescaped parameters */ + unsigned int paramlen; + + paramlen = 0; + + for(ParamL::iterator i = req.query.p.begin(); i != req.query.p.end(); i++) + { + paramlen += i->size(); + } + + /* To avoid a lot of allocations, allocate enough memory for the biggest the escaped query could possibly be. + * sizeofquery + (totalparamlength*2) + 1 + * + * The +1 is for null-terminating the string for PQsendQuery() + */ + + query = new char[req.query.q.length() + (paramlen*2) + 1]; + queryend = query; + + /* Okay, now we have a buffer large enough we need to start copying the query into it and escaping and substituting + * the parameters into it... + */ + + for(unsigned int i = 0; i < req.query.q.length(); i++) + { + if(req.query.q[i] == '?') + { + /* We found a place to substitute..what fun. + * Use the PgSQL calls to escape and write the + * escaped string onto the end of our query buffer, + * then we "just" need to make sure queryend is + * pointing at the right place. + */ + + if(req.query.p.size()) + { + int error = 0; + size_t len = 0; + +#ifdef PGSQL_HAS_ESCAPECONN + len = PQescapeStringConn(sql, queryend, req.query.p.front().c_str(), req.query.p.front().length(), &error); +#else + len = PQescapeString (queryend, req.query.p.front().c_str(), req.query.p.front().length()); +#endif + if(error) + { + Instance->Log(DEBUG, "BUG: Apparently PQescapeStringConn() failed somehow...don't know how or what to do..."); + } + + /* Incremenet queryend to the end of the newly escaped parameter */ + queryend += len; + + /* Remove the parameter we just substituted in */ + req.query.p.pop_front(); + } + else + { + Instance->Log(DEBUG, "BUG: Found a substitution location but no parameter to substitute :|"); + break; + } + } + else + { + *queryend = req.query.q[i]; + queryend++; + } + } + + /* Null-terminate the query */ + *queryend = 0; + req.query.q = query; + + if(PQsendQuery(sql, query)) + { + qinprog = true; + delete[] query; + return SQLerror(); + } + else + { + delete[] query; + return SQLerror(QSEND_FAIL, PQerrorMessage(sql)); + } + } + } + return SQLerror(BAD_CONN, "Can't query until connection is complete"); + } + + SQLerror Query(const SQLrequest &req) + { + queue.push(req); + + if(!qinprog && queue.totalsize()) + { + /* There's no query currently in progress, and there's queries in the queue. */ + SQLrequest& query = queue.front(); + return DoQuery(query); + } + else + { + return SQLerror(); + } + } + + void OnUnloadModule(Module* mod) + { + queue.PurgeModule(mod); + } + + const SQLhost GetConfHost() + { + return confhost; + } + + void Close() { + if (!this->Instance->SE->DelFd(this)) + { + if (sql && PQstatus(sql) == CONNECTION_BAD) + { + this->Instance->SE->DelFd(this, true); + } + else + { + Instance->Log(DEBUG, "BUG: PQsocket cant be removed from socket engine!"); + } + } + + if(sql) + { + PQfinish(sql); + sql = NULL; + } + } + +}; + +class ModulePgSQL : public Module +{ + private: + ConnMap connections; + unsigned long currid; + char* sqlsuccess; + ReconnectTimer* retimer; + + public: + ModulePgSQL(InspIRCd* Me) + : Module::Module(Me), currid(0) + { + ServerInstance->UseInterface("SQLutils"); + + sqlsuccess = new char[strlen(SQLSUCCESS)+1]; + + strlcpy(sqlsuccess, SQLSUCCESS, strlen(SQLSUCCESS)); + + if (!ServerInstance->PublishFeature("SQL", this)) + { + throw ModuleException("BUG: PgSQL Unable to publish feature 'SQL'"); + } + + ReadConf(); + + ServerInstance->PublishInterface("SQL", this); + } + + virtual ~ModulePgSQL() + { + if (retimer) + ServerInstance->Timers->DelTimer(retimer); + ClearAllConnections(); + delete[] sqlsuccess; + ServerInstance->UnpublishInterface("SQL", this); + ServerInstance->UnpublishFeature("SQL"); + ServerInstance->DoneWithInterface("SQLutils"); + } + + void Implements(char* List) + { + List[I_OnUnloadModule] = List[I_OnRequest] = List[I_OnRehash] = List[I_OnUserRegister] = List[I_OnCheckReady] = List[I_OnUserDisconnect] = 1; + } + + virtual void OnRehash(userrec* user, const std::string ¶meter) + { + ReadConf(); + } + + bool HasHost(const SQLhost &host) + { + for (ConnMap::iterator iter = connections.begin(); iter != connections.end(); iter++) + { + if (host == iter->second->GetConfHost()) + return true; + } + return false; + } + + bool HostInConf(const SQLhost &h) + { + ConfigReader conf(ServerInstance); + for(int i = 0; i < conf.Enumerate("database"); i++) + { + SQLhost host; + host.id = conf.ReadValue("database", "id", i); + host.host = conf.ReadValue("database", "hostname", i); + host.port = conf.ReadInteger("database", "port", i, true); + host.name = conf.ReadValue("database", "name", i); + host.user = conf.ReadValue("database", "username", i); + host.pass = conf.ReadValue("database", "password", i); + host.ssl = conf.ReadFlag("database", "ssl", "0", i); + if (h == host) + return true; + } + return false; + } + + void ReadConf() + { + ClearOldConnections(); + + ConfigReader conf(ServerInstance); + for(int i = 0; i < conf.Enumerate("database"); i++) + { + SQLhost host; + int ipvalid; + + host.id = conf.ReadValue("database", "id", i); + host.host = conf.ReadValue("database", "hostname", i); + host.port = conf.ReadInteger("database", "port", i, true); + host.name = conf.ReadValue("database", "name", i); + host.user = conf.ReadValue("database", "username", i); + host.pass = conf.ReadValue("database", "password", i); + host.ssl = conf.ReadFlag("database", "ssl", "0", i); + + if (HasHost(host)) + continue; + +#ifdef IPV6 + if (strchr(host.host.c_str(),':')) + { + in6_addr blargle; + ipvalid = inet_pton(AF_INET6, host.host.c_str(), &blargle); + } + else +#endif + { + in_addr blargle; + ipvalid = inet_aton(host.host.c_str(), &blargle); + } + + if(ipvalid > 0) + { + /* The conversion succeeded, we were given an IP and we can give it straight to SQLConn */ + host.ip = host.host; + this->AddConn(host); + } + else if(ipvalid == 0) + { + /* Conversion failed, assume it's a host */ + SQLresolver* resolver; + + try + { + bool cached; + resolver = new SQLresolver(this, ServerInstance, host, cached); + ServerInstance->AddResolver(resolver, cached); + } + catch(...) + { + /* THE WORLD IS COMING TO AN END! */ + } + } + else + { + /* Invalid address family, die horribly. */ + ServerInstance->Log(DEBUG, "BUG: insp_aton failed returning -1, oh noes."); + } + } + } + + void ClearOldConnections() + { + ConnMap::iterator iter,safei; + for (iter = connections.begin(); iter != connections.end(); iter++) + { + if (!HostInConf(iter->second->GetConfHost())) + { + DELETE(iter->second); + safei = iter; + --iter; + connections.erase(safei); + } + } + } + + void ClearAllConnections() + { + ConnMap::iterator i; + while ((i = connections.begin()) != connections.end()) + { + connections.erase(i); + DELETE(i->second); + } + } + + void AddConn(const SQLhost& hi) + { + if (HasHost(hi)) + { + ServerInstance->Log(DEFAULT, "WARNING: A pgsql connection with id: %s already exists, possibly due to DNS delay. Aborting connection attempt.", hi.id.c_str()); + return; + } + + SQLConn* newconn; + + /* The conversion succeeded, we were given an IP and we can give it straight to SQLConn */ + newconn = new SQLConn(ServerInstance, this, hi); + + connections.insert(std::make_pair(hi.id, newconn)); + } + + void ReconnectConn(SQLConn* conn) + { + for (ConnMap::iterator iter = connections.begin(); iter != connections.end(); iter++) + { + if (conn == iter->second) + { + DELETE(iter->second); + connections.erase(iter); + break; + } + } + retimer = new ReconnectTimer(ServerInstance, this); + ServerInstance->Timers->AddTimer(retimer); + } + + virtual char* OnRequest(Request* request) + { + if(strcmp(SQLREQID, request->GetId()) == 0) + { + SQLrequest* req = (SQLrequest*)request; + ConnMap::iterator iter; + if((iter = connections.find(req->dbid)) != connections.end()) + { + /* Execute query */ + req->id = NewID(); + req->error = iter->second->Query(*req); + + return (req->error.Id() == NO_ERROR) ? sqlsuccess : NULL; + } + else + { + req->error.Id(BAD_DBID); + return NULL; + } + } + return NULL; + } + + virtual void OnUnloadModule(Module* mod, const std::string& name) + { + /* When a module unloads we have to check all the pending queries for all our connections + * and set the Module* specifying where the query came from to NULL. If the query has already + * been dispatched then when it is processed it will be dropped if the pointer is NULL. + * + * If the queries we find are not already being executed then we can simply remove them immediately. + */ + for(ConnMap::iterator iter = connections.begin(); iter != connections.end(); iter++) + { + iter->second->OnUnloadModule(mod); + } + } + + unsigned long NewID() + { + if (currid+1 == 0) + currid++; + + return ++currid; + } + + virtual Version GetVersion() + { + return Version(1, 1, 0, 0, VF_VENDOR|VF_SERVICEPROVIDER, API_VERSION); + } +}; + +/* move this here to use AddConn, rather that than having the whole + * module above SQLConn, since this is buggin me right now :/ + */ +void SQLresolver::OnLookupComplete(const std::string &result, unsigned int ttl, bool cached) +{ + host.ip = result; + ((ModulePgSQL*)mod)->AddConn(host); + ((ModulePgSQL*)mod)->ClearOldConnections(); +} + +void ReconnectTimer::Tick(time_t time) +{ + ((ModulePgSQL*)mod)->ReadConf(); +} + +void SQLConn::DelayReconnect() +{ + ((ModulePgSQL*)us)->ReconnectConn(this); +} + +MODULE_INIT(ModulePgSQL); + diff --git a/src/modules/extra/m_sqlauth.cpp b/src/modules/extra/m_sqlauth.cpp index 862929919..6b05ee521 100644 --- a/src/modules/extra/m_sqlauth.cpp +++ b/src/modules/extra/m_sqlauth.cpp @@ -1 +1,194 @@ -/* +------------------------------------+
* | Inspire Internet Relay Chat Daemon |
* +------------------------------------+
*
* InspIRCd: (C) 2002-2007 InspIRCd Development Team
* See: http://www.inspircd.org/wiki/index.php/Credits
*
* This program is free but copyrighted software; see
* the file COPYING for details.
*
* ---------------------------------------------------
*/
#include "inspircd.h"
#include "users.h"
#include "channels.h"
#include "modules.h"
#include "m_sqlv2.h"
#include "m_sqlutils.h"
/* $ModDesc: Allow/Deny connections based upon an arbitary SQL table */
/* $ModDep: m_sqlv2.h m_sqlutils.h */
class ModuleSQLAuth : public Module
{
Module* SQLutils;
Module* SQLprovider;
std::string usertable;
std::string userfield;
std::string passfield;
std::string encryption;
std::string killreason;
std::string allowpattern;
std::string databaseid;
bool verbose;
public:
ModuleSQLAuth(InspIRCd* Me)
: Module::Module(Me)
{
ServerInstance->UseInterface("SQLutils");
ServerInstance->UseInterface("SQL");
SQLutils = ServerInstance->FindModule("m_sqlutils.so");
if (!SQLutils)
throw ModuleException("Can't find m_sqlutils.so. Please load m_sqlutils.so before m_sqlauth.so.");
SQLprovider = ServerInstance->FindFeature("SQL");
if (!SQLprovider)
throw ModuleException("Can't find an SQL provider module. Please load one before attempting to load m_sqlauth.");
OnRehash(NULL,"");
}
virtual ~ModuleSQLAuth()
{
ServerInstance->DoneWithInterface("SQL");
ServerInstance->DoneWithInterface("SQLutils");
}
void Implements(char* List)
{
List[I_OnUserDisconnect] = List[I_OnCheckReady] = List[I_OnRequest] = List[I_OnRehash] = List[I_OnUserRegister] = 1;
}
virtual void OnRehash(userrec* user, const std::string ¶meter)
{
ConfigReader Conf(ServerInstance);
usertable = Conf.ReadValue("sqlauth", "usertable", 0); /* User table name */
databaseid = Conf.ReadValue("sqlauth", "dbid", 0); /* Database ID, given to the SQL service provider */
userfield = Conf.ReadValue("sqlauth", "userfield", 0); /* Field name where username can be found */
passfield = Conf.ReadValue("sqlauth", "passfield", 0); /* Field name where password can be found */
killreason = Conf.ReadValue("sqlauth", "killreason", 0); /* Reason to give when access is denied to a user (put your reg details here) */
allowpattern= Conf.ReadValue("sqlauth", "allowpattern",0 ); /* Allow nicks matching this pattern without requiring auth */
encryption = Conf.ReadValue("sqlauth", "encryption", 0); /* Name of sql function used to encrypt password, e.g. "md5" or "passwd".
* define, but leave blank if no encryption is to be used.
*/
verbose = Conf.ReadFlag("sqlauth", "verbose", 0); /* Set to true if failed connects should be reported to operators */
if (encryption.find("(") == std::string::npos)
{
encryption.append("(");
}
}
virtual int OnUserRegister(userrec* user)
{
if ((!allowpattern.empty()) && (ServerInstance->MatchText(user->nick,allowpattern)))
{
user->Extend("sqlauthed");
return 0;
}
if (!CheckCredentials(user))
{
userrec::QuitUser(ServerInstance,user,killreason);
return 1;
}
return 0;
}
bool CheckCredentials(userrec* user)
{
SQLrequest req = SQLreq(this, SQLprovider, databaseid, "SELECT ? FROM ? WHERE ? = '?' AND ? = ?'?')", userfield, usertable, userfield, user->nick, passfield, encryption, user->password);
if(req.Send())
{
/* When we get the query response from the service provider we will be given an ID to play with,
* just an ID number which is unique to this query. We need a way of associating that ID with a userrec
* so we insert it into a map mapping the IDs to users.
* Thankfully m_sqlutils provides this, it will associate a ID with a user or channel, and if the user quits it removes the
* association. This means that if the user quits during a query we will just get a failed lookup from m_sqlutils - telling
* us to discard the query.
*/
AssociateUser(this, SQLutils, req.id, user).Send();
return true;
}
else
{
if (verbose)
ServerInstance->WriteOpers("Forbidden connection from %s!%s@%s (SQL query failed: %s)", user->nick, user->ident, user->host, req.error.Str());
return false;
}
}
virtual char* OnRequest(Request* request)
{
if(strcmp(SQLRESID, request->GetId()) == 0)
{
SQLresult* res = static_cast<SQLresult*>(request);
userrec* user = GetAssocUser(this, SQLutils, res->id).S().user;
UnAssociate(this, SQLutils, res->id).S();
if(user)
{
if(res->error.Id() == NO_ERROR)
{
if(res->Rows())
{
/* We got a row in the result, this is enough really */
user->Extend("sqlauthed");
}
else if (verbose)
{
/* No rows in result, this means there was no record matching the user */
ServerInstance->WriteOpers("Forbidden connection from %s!%s@%s (SQL query returned no matches)", user->nick, user->ident, user->host);
user->Extend("sqlauth_failed");
}
}
else if (verbose)
{
ServerInstance->WriteOpers("Forbidden connection from %s!%s@%s (SQL query failed: %s)", user->nick, user->ident, user->host, res->error.Str());
user->Extend("sqlauth_failed");
}
}
else
{
return NULL;
}
if (!user->GetExt("sqlauthed"))
{
userrec::QuitUser(ServerInstance,user,killreason);
}
return SQLSUCCESS;
}
return NULL;
}
virtual void OnUserDisconnect(userrec* user)
{
user->Shrink("sqlauthed");
user->Shrink("sqlauth_failed");
}
virtual bool OnCheckReady(userrec* user)
{
return user->GetExt("sqlauthed");
}
virtual Version GetVersion()
{
return Version(1,1,1,0,VF_VENDOR,API_VERSION);
}
};
MODULE_INIT(ModuleSQLAuth);
\ No newline at end of file +/* +------------------------------------+ + * | Inspire Internet Relay Chat Daemon | + * +------------------------------------+ + * + * InspIRCd: (C) 2002-2007 InspIRCd Development Team + * See: http://www.inspircd.org/wiki/index.php/Credits + * + * This program is free but copyrighted software; see + * the file COPYING for details. + * + * --------------------------------------------------- + */ + +#include "inspircd.h" +#include "users.h" +#include "channels.h" +#include "modules.h" +#include "m_sqlv2.h" +#include "m_sqlutils.h" + +/* $ModDesc: Allow/Deny connections based upon an arbitary SQL table */ +/* $ModDep: m_sqlv2.h m_sqlutils.h */ + +class ModuleSQLAuth : public Module +{ + Module* SQLutils; + Module* SQLprovider; + + std::string usertable; + std::string userfield; + std::string passfield; + std::string encryption; + std::string killreason; + std::string allowpattern; + std::string databaseid; + + bool verbose; + +public: + ModuleSQLAuth(InspIRCd* Me) + : Module::Module(Me) + { + ServerInstance->UseInterface("SQLutils"); + ServerInstance->UseInterface("SQL"); + + SQLutils = ServerInstance->FindModule("m_sqlutils.so"); + if (!SQLutils) + throw ModuleException("Can't find m_sqlutils.so. Please load m_sqlutils.so before m_sqlauth.so."); + + SQLprovider = ServerInstance->FindFeature("SQL"); + if (!SQLprovider) + throw ModuleException("Can't find an SQL provider module. Please load one before attempting to load m_sqlauth."); + + OnRehash(NULL,""); + } + + virtual ~ModuleSQLAuth() + { + ServerInstance->DoneWithInterface("SQL"); + ServerInstance->DoneWithInterface("SQLutils"); + } + + void Implements(char* List) + { + List[I_OnUserDisconnect] = List[I_OnCheckReady] = List[I_OnRequest] = List[I_OnRehash] = List[I_OnUserRegister] = 1; + } + + virtual void OnRehash(userrec* user, const std::string ¶meter) + { + ConfigReader Conf(ServerInstance); + + usertable = Conf.ReadValue("sqlauth", "usertable", 0); /* User table name */ + databaseid = Conf.ReadValue("sqlauth", "dbid", 0); /* Database ID, given to the SQL service provider */ + userfield = Conf.ReadValue("sqlauth", "userfield", 0); /* Field name where username can be found */ + passfield = Conf.ReadValue("sqlauth", "passfield", 0); /* Field name where password can be found */ + killreason = Conf.ReadValue("sqlauth", "killreason", 0); /* Reason to give when access is denied to a user (put your reg details here) */ + allowpattern= Conf.ReadValue("sqlauth", "allowpattern",0 ); /* Allow nicks matching this pattern without requiring auth */ + encryption = Conf.ReadValue("sqlauth", "encryption", 0); /* Name of sql function used to encrypt password, e.g. "md5" or "passwd". + * define, but leave blank if no encryption is to be used. + */ + verbose = Conf.ReadFlag("sqlauth", "verbose", 0); /* Set to true if failed connects should be reported to operators */ + + if (encryption.find("(") == std::string::npos) + { + encryption.append("("); + } + } + + virtual int OnUserRegister(userrec* user) + { + if ((!allowpattern.empty()) && (ServerInstance->MatchText(user->nick,allowpattern))) + { + user->Extend("sqlauthed"); + return 0; + } + + if (!CheckCredentials(user)) + { + userrec::QuitUser(ServerInstance,user,killreason); + return 1; + } + return 0; + } + + bool CheckCredentials(userrec* user) + { + SQLrequest req = SQLreq(this, SQLprovider, databaseid, "SELECT ? FROM ? WHERE ? = '?' AND ? = ?'?')", userfield, usertable, userfield, user->nick, passfield, encryption, user->password); + + if(req.Send()) + { + /* When we get the query response from the service provider we will be given an ID to play with, + * just an ID number which is unique to this query. We need a way of associating that ID with a userrec + * so we insert it into a map mapping the IDs to users. + * Thankfully m_sqlutils provides this, it will associate a ID with a user or channel, and if the user quits it removes the + * association. This means that if the user quits during a query we will just get a failed lookup from m_sqlutils - telling + * us to discard the query. + */ + AssociateUser(this, SQLutils, req.id, user).Send(); + + return true; + } + else + { + if (verbose) + ServerInstance->WriteOpers("Forbidden connection from %s!%s@%s (SQL query failed: %s)", user->nick, user->ident, user->host, req.error.Str()); + return false; + } + } + + virtual char* OnRequest(Request* request) + { + if(strcmp(SQLRESID, request->GetId()) == 0) + { + SQLresult* res = static_cast<SQLresult*>(request); + + userrec* user = GetAssocUser(this, SQLutils, res->id).S().user; + UnAssociate(this, SQLutils, res->id).S(); + + if(user) + { + if(res->error.Id() == NO_ERROR) + { + if(res->Rows()) + { + /* We got a row in the result, this is enough really */ + user->Extend("sqlauthed"); + } + else if (verbose) + { + /* No rows in result, this means there was no record matching the user */ + ServerInstance->WriteOpers("Forbidden connection from %s!%s@%s (SQL query returned no matches)", user->nick, user->ident, user->host); + user->Extend("sqlauth_failed"); + } + } + else if (verbose) + { + ServerInstance->WriteOpers("Forbidden connection from %s!%s@%s (SQL query failed: %s)", user->nick, user->ident, user->host, res->error.Str()); + user->Extend("sqlauth_failed"); + } + } + else + { + return NULL; + } + + if (!user->GetExt("sqlauthed")) + { + userrec::QuitUser(ServerInstance,user,killreason); + } + return SQLSUCCESS; + } + return NULL; + } + + virtual void OnUserDisconnect(userrec* user) + { + user->Shrink("sqlauthed"); + user->Shrink("sqlauth_failed"); + } + + virtual bool OnCheckReady(userrec* user) + { + return user->GetExt("sqlauthed"); + } + + virtual Version GetVersion() + { + return Version(1,1,1,0,VF_VENDOR,API_VERSION); + } + +}; + +MODULE_INIT(ModuleSQLAuth); + diff --git a/src/modules/extra/m_sqlite3.cpp b/src/modules/extra/m_sqlite3.cpp index 6741d7745..66955de07 100644 --- a/src/modules/extra/m_sqlite3.cpp +++ b/src/modules/extra/m_sqlite3.cpp @@ -1 +1,660 @@ -/* +------------------------------------+
* | Inspire Internet Relay Chat Daemon |
* +------------------------------------+
*
* InspIRCd: (C) 2002-2007 InspIRCd Development Team
* See: http://www.inspircd.org/wiki/index.php/Credits
*
* This program is free but copyrighted software; see
* the file COPYING for details.
*
* ---------------------------------------------------
*/
#include "inspircd.h"
#include <sqlite3.h>
#include "users.h"
#include "channels.h"
#include "modules.h"
#include "m_sqlv2.h"
/* $ModDesc: sqlite3 provider */
/* $CompileFlags: pkgconfversion("sqlite3","3.3") pkgconfincludes("sqlite3","/sqlite3.h","") */
/* $LinkerFlags: pkgconflibs("sqlite3","/libsqlite3.so","-lsqlite3") */
/* $ModDep: m_sqlv2.h */
class SQLConn;
class SQLite3Result;
class ResultNotifier;
typedef std::map<std::string, SQLConn*> ConnMap;
typedef std::deque<classbase*> paramlist;
typedef std::deque<SQLite3Result*> ResultQueue;
ResultNotifier* resultnotify = NULL;
class ResultNotifier : public InspSocket
{
Module* mod;
insp_sockaddr sock_us;
socklen_t uslen;
public:
/* Create a socket on a random port. Let the tcp stack allocate us an available port */
#ifdef IPV6
ResultNotifier(InspIRCd* SI, Module* m) : InspSocket(SI, "::1", 0, true, 3000), mod(m)
#else
ResultNotifier(InspIRCd* SI, Module* m) : InspSocket(SI, "127.0.0.1", 0, true, 3000), mod(m)
#endif
{
uslen = sizeof(sock_us);
if (getsockname(this->fd,(sockaddr*)&sock_us,&uslen))
{
throw ModuleException("Could not create random listening port on localhost");
}
}
ResultNotifier(InspIRCd* SI, Module* m, int newfd, char* ip) : InspSocket(SI, newfd, ip), mod(m)
{
}
/* Using getsockname and ntohs, we can determine which port number we were allocated */
int GetPort()
{
#ifdef IPV6
return ntohs(sock_us.sin6_port);
#else
return ntohs(sock_us.sin_port);
#endif
}
virtual int OnIncomingConnection(int newsock, char* ip)
{
Dispatch();
return false;
}
void Dispatch();
};
class SQLite3Result : public SQLresult
{
private:
int currentrow;
int rows;
int cols;
std::vector<std::string> colnames;
std::vector<SQLfieldList> fieldlists;
SQLfieldList emptyfieldlist;
SQLfieldList* fieldlist;
SQLfieldMap* fieldmap;
public:
SQLite3Result(Module* self, Module* to, unsigned int id)
: SQLresult(self, to, id), currentrow(0), rows(0), cols(0), fieldlist(NULL), fieldmap(NULL)
{
}
~SQLite3Result()
{
}
void AddRow(int colsnum, char **data, char **colname)
{
colnames.clear();
cols = colsnum;
for (int i = 0; i < colsnum; i++)
{
fieldlists.resize(fieldlists.size()+1);
colnames.push_back(colname[i]);
SQLfield sf(data[i] ? data[i] : "", data[i] ? false : true);
fieldlists[rows].push_back(sf);
}
rows++;
}
void UpdateAffectedCount()
{
rows++;
}
virtual int Rows()
{
return rows;
}
virtual int Cols()
{
return cols;
}
virtual std::string ColName(int column)
{
if (column < (int)colnames.size())
{
return colnames[column];
}
else
{
throw SQLbadColName();
}
return "";
}
virtual int ColNum(const std::string &column)
{
for (unsigned int i = 0; i < colnames.size(); i++)
{
if (column == colnames[i])
return i;
}
throw SQLbadColName();
return 0;
}
virtual SQLfield GetValue(int row, int column)
{
if ((row >= 0) && (row < rows) && (column >= 0) && (column < Cols()))
{
return fieldlists[row][column];
}
throw SQLbadColName();
/* XXX: We never actually get here because of the throw */
return SQLfield("",true);
}
virtual SQLfieldList& GetRow()
{
if (currentrow < rows)
return fieldlists[currentrow];
else
return emptyfieldlist;
}
virtual SQLfieldMap& GetRowMap()
{
/* In an effort to reduce overhead we don't actually allocate the map
* until the first time it's needed...so...
*/
if(fieldmap)
{
fieldmap->clear();
}
else
{
fieldmap = new SQLfieldMap;
}
if (currentrow < rows)
{
for (int i = 0; i < Cols(); i++)
{
fieldmap->insert(std::make_pair(ColName(i), GetValue(currentrow, i)));
}
currentrow++;
}
return *fieldmap;
}
virtual SQLfieldList* GetRowPtr()
{
fieldlist = new SQLfieldList();
if (currentrow < rows)
{
for (int i = 0; i < Rows(); i++)
{
fieldlist->push_back(fieldlists[currentrow][i]);
}
currentrow++;
}
return fieldlist;
}
virtual SQLfieldMap* GetRowMapPtr()
{
fieldmap = new SQLfieldMap();
if (currentrow < rows)
{
for (int i = 0; i < Cols(); i++)
{
fieldmap->insert(std::make_pair(colnames[i],GetValue(currentrow, i)));
}
currentrow++;
}
return fieldmap;
}
virtual void Free(SQLfieldMap* fm)
{
delete fm;
}
virtual void Free(SQLfieldList* fl)
{
delete fl;
}
};
class SQLConn : public classbase
{
private:
ResultQueue results;
InspIRCd* Instance;
Module* mod;
SQLhost host;
sqlite3* conn;
public:
SQLConn(InspIRCd* SI, Module* m, const SQLhost& hi)
: Instance(SI), mod(m), host(hi)
{
if (OpenDB() != SQLITE_OK)
{
Instance->Log(DEFAULT, "WARNING: Could not open DB with id: " + host.id);
CloseDB();
}
}
~SQLConn()
{
CloseDB();
}
SQLerror Query(SQLrequest &req)
{
/* Pointer to the buffer we screw around with substitution in */
char* query;
/* Pointer to the current end of query, where we append new stuff */
char* queryend;
/* Total length of the unescaped parameters */
unsigned long paramlen;
/* Total length of query, used for binary-safety in mysql_real_query */
unsigned long querylength = 0;
paramlen = 0;
for(ParamL::iterator i = req.query.p.begin(); i != req.query.p.end(); i++)
{
paramlen += i->size();
}
/* To avoid a lot of allocations, allocate enough memory for the biggest the escaped query could possibly be.
* sizeofquery + (totalparamlength*2) + 1
*
* The +1 is for null-terminating the string for mysql_real_escape_string
*/
query = new char[req.query.q.length() + (paramlen*2) + 1];
queryend = query;
for(unsigned long i = 0; i < req.query.q.length(); i++)
{
if(req.query.q[i] == '?')
{
if(req.query.p.size())
{
char* escaped;
escaped = sqlite3_mprintf("%q", req.query.p.front().c_str());
for (char* n = escaped; *n; n++)
{
*queryend = *n;
queryend++;
}
sqlite3_free(escaped);
req.query.p.pop_front();
}
else
break;
}
else
{
*queryend = req.query.q[i];
queryend++;
}
querylength++;
}
*queryend = 0;
req.query.q = query;
SQLite3Result* res = new SQLite3Result(mod, req.GetSource(), req.id);
res->dbid = host.id;
res->query = req.query.q;
paramlist params;
params.push_back(this);
params.push_back(res);
char *errmsg = 0;
sqlite3_update_hook(conn, QueryUpdateHook, ¶ms);
if (sqlite3_exec(conn, req.query.q.data(), QueryResult, ¶ms, &errmsg) != SQLITE_OK)
{
std::string error(errmsg);
sqlite3_free(errmsg);
delete[] query;
delete res;
return SQLerror(QSEND_FAIL, error);
}
delete[] query;
results.push_back(res);
SendNotify();
return SQLerror();
}
static int QueryResult(void *params, int argc, char **argv, char **azColName)
{
paramlist* p = (paramlist*)params;
((SQLConn*)(*p)[0])->ResultReady(((SQLite3Result*)(*p)[1]), argc, argv, azColName);
return 0;
}
static void QueryUpdateHook(void *params, int eventid, char const * azSQLite, char const * azColName, sqlite_int64 rowid)
{
paramlist* p = (paramlist*)params;
((SQLConn*)(*p)[0])->AffectedReady(((SQLite3Result*)(*p)[1]));
}
void ResultReady(SQLite3Result *res, int cols, char **data, char **colnames)
{
res->AddRow(cols, data, colnames);
}
void AffectedReady(SQLite3Result *res)
{
res->UpdateAffectedCount();
}
int OpenDB()
{
return sqlite3_open(host.host.c_str(), &conn);
}
void CloseDB()
{
sqlite3_interrupt(conn);
sqlite3_close(conn);
}
SQLhost GetConfHost()
{
return host;
}
void SendResults()
{
while (results.size())
{
SQLite3Result* res = results[0];
if (res->GetDest())
{
res->Send();
}
else
{
/* If the client module is unloaded partway through a query then the provider will set
* the pointer to NULL. We cannot just cancel the query as the result will still come
* through at some point...and it could get messy if we play with invalid pointers...
*/
delete res;
}
results.pop_front();
}
}
void ClearResults()
{
while (results.size())
{
SQLite3Result* res = results[0];
delete res;
results.pop_front();
}
}
void SendNotify()
{
int QueueFD;
if ((QueueFD = socket(AF_FAMILY, SOCK_STREAM, 0)) == -1)
{
/* crap, we're out of sockets... */
return;
}
insp_sockaddr addr;
#ifdef IPV6
insp_aton("::1", &addr.sin6_addr);
addr.sin6_family = AF_FAMILY;
addr.sin6_port = htons(resultnotify->GetPort());
#else
insp_inaddr ia;
insp_aton("127.0.0.1", &ia);
addr.sin_family = AF_FAMILY;
addr.sin_addr = ia;
addr.sin_port = htons(resultnotify->GetPort());
#endif
if (connect(QueueFD, (sockaddr*)&addr,sizeof(addr)) == -1)
{
/* wtf, we cant connect to it, but we just created it! */
return;
}
}
};
class ModuleSQLite3 : public Module
{
private:
ConnMap connections;
unsigned long currid;
public:
ModuleSQLite3(InspIRCd* Me)
: Module::Module(Me), currid(0)
{
ServerInstance->UseInterface("SQLutils");
if (!ServerInstance->PublishFeature("SQL", this))
{
throw ModuleException("m_sqlite3: Unable to publish feature 'SQL'");
}
resultnotify = new ResultNotifier(ServerInstance, this);
ReadConf();
ServerInstance->PublishInterface("SQL", this);
}
virtual ~ModuleSQLite3()
{
ClearQueue();
ClearAllConnections();
resultnotify->SetFd(-1);
resultnotify->state = I_ERROR;
resultnotify->OnError(I_ERR_SOCKET);
resultnotify->ClosePending = true;
delete resultnotify;
ServerInstance->UnpublishInterface("SQL", this);
ServerInstance->UnpublishFeature("SQL");
ServerInstance->DoneWithInterface("SQLutils");
}
void Implements(char* List)
{
List[I_OnRequest] = List[I_OnRehash] = 1;
}
void SendQueue()
{
for (ConnMap::iterator iter = connections.begin(); iter != connections.end(); iter++)
{
iter->second->SendResults();
}
}
void ClearQueue()
{
for (ConnMap::iterator iter = connections.begin(); iter != connections.end(); iter++)
{
iter->second->ClearResults();
}
}
bool HasHost(const SQLhost &host)
{
for (ConnMap::iterator iter = connections.begin(); iter != connections.end(); iter++)
{
if (host == iter->second->GetConfHost())
return true;
}
return false;
}
bool HostInConf(const SQLhost &h)
{
ConfigReader conf(ServerInstance);
for(int i = 0; i < conf.Enumerate("database"); i++)
{
SQLhost host;
host.id = conf.ReadValue("database", "id", i);
host.host = conf.ReadValue("database", "hostname", i);
host.port = conf.ReadInteger("database", "port", i, true);
host.name = conf.ReadValue("database", "name", i);
host.user = conf.ReadValue("database", "username", i);
host.pass = conf.ReadValue("database", "password", i);
host.ssl = conf.ReadFlag("database", "ssl", "0", i);
if (h == host)
return true;
}
return false;
}
void ReadConf()
{
ClearOldConnections();
ConfigReader conf(ServerInstance);
for(int i = 0; i < conf.Enumerate("database"); i++)
{
SQLhost host;
host.id = conf.ReadValue("database", "id", i);
host.host = conf.ReadValue("database", "hostname", i);
host.port = conf.ReadInteger("database", "port", i, true);
host.name = conf.ReadValue("database", "name", i);
host.user = conf.ReadValue("database", "username", i);
host.pass = conf.ReadValue("database", "password", i);
host.ssl = conf.ReadFlag("database", "ssl", "0", i);
if (HasHost(host))
continue;
this->AddConn(host);
}
}
void AddConn(const SQLhost& hi)
{
if (HasHost(hi))
{
ServerInstance->Log(DEFAULT, "WARNING: A sqlite connection with id: %s already exists. Aborting database open attempt.", hi.id.c_str());
return;
}
SQLConn* newconn;
newconn = new SQLConn(ServerInstance, this, hi);
connections.insert(std::make_pair(hi.id, newconn));
}
void ClearOldConnections()
{
ConnMap::iterator iter,safei;
for (iter = connections.begin(); iter != connections.end(); iter++)
{
if (!HostInConf(iter->second->GetConfHost()))
{
DELETE(iter->second);
safei = iter;
--iter;
connections.erase(safei);
}
}
}
void ClearAllConnections()
{
ConnMap::iterator i;
while ((i = connections.begin()) != connections.end())
{
connections.erase(i);
DELETE(i->second);
}
}
virtual void OnRehash(userrec* user, const std::string ¶meter)
{
ReadConf();
}
virtual char* OnRequest(Request* request)
{
if(strcmp(SQLREQID, request->GetId()) == 0)
{
SQLrequest* req = (SQLrequest*)request;
ConnMap::iterator iter;
if((iter = connections.find(req->dbid)) != connections.end())
{
req->id = NewID();
req->error = iter->second->Query(*req);
return SQLSUCCESS;
}
else
{
req->error.Id(BAD_DBID);
return NULL;
}
}
return NULL;
}
unsigned long NewID()
{
if (currid+1 == 0)
currid++;
return ++currid;
}
virtual Version GetVersion()
{
return Version(1,1,0,0,VF_VENDOR|VF_SERVICEPROVIDER,API_VERSION);
}
};
void ResultNotifier::Dispatch()
{
((ModuleSQLite3*)mod)->SendQueue();
}
MODULE_INIT(ModuleSQLite3);
\ No newline at end of file +/* +------------------------------------+ + * | Inspire Internet Relay Chat Daemon | + * +------------------------------------+ + * + * InspIRCd: (C) 2002-2007 InspIRCd Development Team + * See: http://www.inspircd.org/wiki/index.php/Credits + * + * This program is free but copyrighted software; see + * the file COPYING for details. + * + * --------------------------------------------------- + */ + +#include "inspircd.h" +#include <sqlite3.h> +#include "users.h" +#include "channels.h" +#include "modules.h" + +#include "m_sqlv2.h" + +/* $ModDesc: sqlite3 provider */ +/* $CompileFlags: pkgconfversion("sqlite3","3.3") pkgconfincludes("sqlite3","/sqlite3.h","") */ +/* $LinkerFlags: pkgconflibs("sqlite3","/libsqlite3.so","-lsqlite3") */ +/* $ModDep: m_sqlv2.h */ + + +class SQLConn; +class SQLite3Result; +class ResultNotifier; + +typedef std::map<std::string, SQLConn*> ConnMap; +typedef std::deque<classbase*> paramlist; +typedef std::deque<SQLite3Result*> ResultQueue; + +ResultNotifier* resultnotify = NULL; + + +class ResultNotifier : public InspSocket +{ + Module* mod; + insp_sockaddr sock_us; + socklen_t uslen; + + public: + /* Create a socket on a random port. Let the tcp stack allocate us an available port */ +#ifdef IPV6 + ResultNotifier(InspIRCd* SI, Module* m) : InspSocket(SI, "::1", 0, true, 3000), mod(m) +#else + ResultNotifier(InspIRCd* SI, Module* m) : InspSocket(SI, "127.0.0.1", 0, true, 3000), mod(m) +#endif + { + uslen = sizeof(sock_us); + if (getsockname(this->fd,(sockaddr*)&sock_us,&uslen)) + { + throw ModuleException("Could not create random listening port on localhost"); + } + } + + ResultNotifier(InspIRCd* SI, Module* m, int newfd, char* ip) : InspSocket(SI, newfd, ip), mod(m) + { + } + + /* Using getsockname and ntohs, we can determine which port number we were allocated */ + int GetPort() + { +#ifdef IPV6 + return ntohs(sock_us.sin6_port); +#else + return ntohs(sock_us.sin_port); +#endif + } + + virtual int OnIncomingConnection(int newsock, char* ip) + { + Dispatch(); + return false; + } + + void Dispatch(); +}; + + +class SQLite3Result : public SQLresult +{ + private: + int currentrow; + int rows; + int cols; + + std::vector<std::string> colnames; + std::vector<SQLfieldList> fieldlists; + SQLfieldList emptyfieldlist; + + SQLfieldList* fieldlist; + SQLfieldMap* fieldmap; + + public: + SQLite3Result(Module* self, Module* to, unsigned int id) + : SQLresult(self, to, id), currentrow(0), rows(0), cols(0), fieldlist(NULL), fieldmap(NULL) + { + } + + ~SQLite3Result() + { + } + + void AddRow(int colsnum, char **data, char **colname) + { + colnames.clear(); + cols = colsnum; + for (int i = 0; i < colsnum; i++) + { + fieldlists.resize(fieldlists.size()+1); + colnames.push_back(colname[i]); + SQLfield sf(data[i] ? data[i] : "", data[i] ? false : true); + fieldlists[rows].push_back(sf); + } + rows++; + } + + void UpdateAffectedCount() + { + rows++; + } + + virtual int Rows() + { + return rows; + } + + virtual int Cols() + { + return cols; + } + + virtual std::string ColName(int column) + { + if (column < (int)colnames.size()) + { + return colnames[column]; + } + else + { + throw SQLbadColName(); + } + return ""; + } + + virtual int ColNum(const std::string &column) + { + for (unsigned int i = 0; i < colnames.size(); i++) + { + if (column == colnames[i]) + return i; + } + throw SQLbadColName(); + return 0; + } + + virtual SQLfield GetValue(int row, int column) + { + if ((row >= 0) && (row < rows) && (column >= 0) && (column < Cols())) + { + return fieldlists[row][column]; + } + + throw SQLbadColName(); + + /* XXX: We never actually get here because of the throw */ + return SQLfield("",true); + } + + virtual SQLfieldList& GetRow() + { + if (currentrow < rows) + return fieldlists[currentrow]; + else + return emptyfieldlist; + } + + virtual SQLfieldMap& GetRowMap() + { + /* In an effort to reduce overhead we don't actually allocate the map + * until the first time it's needed...so... + */ + if(fieldmap) + { + fieldmap->clear(); + } + else + { + fieldmap = new SQLfieldMap; + } + + if (currentrow < rows) + { + for (int i = 0; i < Cols(); i++) + { + fieldmap->insert(std::make_pair(ColName(i), GetValue(currentrow, i))); + } + currentrow++; + } + + return *fieldmap; + } + + virtual SQLfieldList* GetRowPtr() + { + fieldlist = new SQLfieldList(); + + if (currentrow < rows) + { + for (int i = 0; i < Rows(); i++) + { + fieldlist->push_back(fieldlists[currentrow][i]); + } + currentrow++; + } + return fieldlist; + } + + virtual SQLfieldMap* GetRowMapPtr() + { + fieldmap = new SQLfieldMap(); + + if (currentrow < rows) + { + for (int i = 0; i < Cols(); i++) + { + fieldmap->insert(std::make_pair(colnames[i],GetValue(currentrow, i))); + } + currentrow++; + } + + return fieldmap; + } + + virtual void Free(SQLfieldMap* fm) + { + delete fm; + } + + virtual void Free(SQLfieldList* fl) + { + delete fl; + } + + +}; + +class SQLConn : public classbase +{ + private: + ResultQueue results; + InspIRCd* Instance; + Module* mod; + SQLhost host; + sqlite3* conn; + + public: + SQLConn(InspIRCd* SI, Module* m, const SQLhost& hi) + : Instance(SI), mod(m), host(hi) + { + if (OpenDB() != SQLITE_OK) + { + Instance->Log(DEFAULT, "WARNING: Could not open DB with id: " + host.id); + CloseDB(); + } + } + + ~SQLConn() + { + CloseDB(); + } + + SQLerror Query(SQLrequest &req) + { + /* Pointer to the buffer we screw around with substitution in */ + char* query; + + /* Pointer to the current end of query, where we append new stuff */ + char* queryend; + + /* Total length of the unescaped parameters */ + unsigned long paramlen; + + /* Total length of query, used for binary-safety in mysql_real_query */ + unsigned long querylength = 0; + + paramlen = 0; + for(ParamL::iterator i = req.query.p.begin(); i != req.query.p.end(); i++) + { + paramlen += i->size(); + } + + /* To avoid a lot of allocations, allocate enough memory for the biggest the escaped query could possibly be. + * sizeofquery + (totalparamlength*2) + 1 + * + * The +1 is for null-terminating the string for mysql_real_escape_string + */ + query = new char[req.query.q.length() + (paramlen*2) + 1]; + queryend = query; + + for(unsigned long i = 0; i < req.query.q.length(); i++) + { + if(req.query.q[i] == '?') + { + if(req.query.p.size()) + { + char* escaped; + escaped = sqlite3_mprintf("%q", req.query.p.front().c_str()); + for (char* n = escaped; *n; n++) + { + *queryend = *n; + queryend++; + } + sqlite3_free(escaped); + req.query.p.pop_front(); + } + else + break; + } + else + { + *queryend = req.query.q[i]; + queryend++; + } + querylength++; + } + *queryend = 0; + req.query.q = query; + + SQLite3Result* res = new SQLite3Result(mod, req.GetSource(), req.id); + res->dbid = host.id; + res->query = req.query.q; + paramlist params; + params.push_back(this); + params.push_back(res); + + char *errmsg = 0; + sqlite3_update_hook(conn, QueryUpdateHook, ¶ms); + if (sqlite3_exec(conn, req.query.q.data(), QueryResult, ¶ms, &errmsg) != SQLITE_OK) + { + std::string error(errmsg); + sqlite3_free(errmsg); + delete[] query; + delete res; + return SQLerror(QSEND_FAIL, error); + } + delete[] query; + + results.push_back(res); + SendNotify(); + return SQLerror(); + } + + static int QueryResult(void *params, int argc, char **argv, char **azColName) + { + paramlist* p = (paramlist*)params; + ((SQLConn*)(*p)[0])->ResultReady(((SQLite3Result*)(*p)[1]), argc, argv, azColName); + return 0; + } + + static void QueryUpdateHook(void *params, int eventid, char const * azSQLite, char const * azColName, sqlite_int64 rowid) + { + paramlist* p = (paramlist*)params; + ((SQLConn*)(*p)[0])->AffectedReady(((SQLite3Result*)(*p)[1])); + } + + void ResultReady(SQLite3Result *res, int cols, char **data, char **colnames) + { + res->AddRow(cols, data, colnames); + } + + void AffectedReady(SQLite3Result *res) + { + res->UpdateAffectedCount(); + } + + int OpenDB() + { + return sqlite3_open(host.host.c_str(), &conn); + } + + void CloseDB() + { + sqlite3_interrupt(conn); + sqlite3_close(conn); + } + + SQLhost GetConfHost() + { + return host; + } + + void SendResults() + { + while (results.size()) + { + SQLite3Result* res = results[0]; + if (res->GetDest()) + { + res->Send(); + } + else + { + /* If the client module is unloaded partway through a query then the provider will set + * the pointer to NULL. We cannot just cancel the query as the result will still come + * through at some point...and it could get messy if we play with invalid pointers... + */ + delete res; + } + results.pop_front(); + } + } + + void ClearResults() + { + while (results.size()) + { + SQLite3Result* res = results[0]; + delete res; + results.pop_front(); + } + } + + void SendNotify() + { + int QueueFD; + if ((QueueFD = socket(AF_FAMILY, SOCK_STREAM, 0)) == -1) + { + /* crap, we're out of sockets... */ + return; + } + + insp_sockaddr addr; + +#ifdef IPV6 + insp_aton("::1", &addr.sin6_addr); + addr.sin6_family = AF_FAMILY; + addr.sin6_port = htons(resultnotify->GetPort()); +#else + insp_inaddr ia; + insp_aton("127.0.0.1", &ia); + addr.sin_family = AF_FAMILY; + addr.sin_addr = ia; + addr.sin_port = htons(resultnotify->GetPort()); +#endif + + if (connect(QueueFD, (sockaddr*)&addr,sizeof(addr)) == -1) + { + /* wtf, we cant connect to it, but we just created it! */ + return; + } + } + +}; + + +class ModuleSQLite3 : public Module +{ + private: + ConnMap connections; + unsigned long currid; + + public: + ModuleSQLite3(InspIRCd* Me) + : Module::Module(Me), currid(0) + { + ServerInstance->UseInterface("SQLutils"); + + if (!ServerInstance->PublishFeature("SQL", this)) + { + throw ModuleException("m_sqlite3: Unable to publish feature 'SQL'"); + } + + resultnotify = new ResultNotifier(ServerInstance, this); + + ReadConf(); + + ServerInstance->PublishInterface("SQL", this); + } + + virtual ~ModuleSQLite3() + { + ClearQueue(); + ClearAllConnections(); + resultnotify->SetFd(-1); + resultnotify->state = I_ERROR; + resultnotify->OnError(I_ERR_SOCKET); + resultnotify->ClosePending = true; + delete resultnotify; + ServerInstance->UnpublishInterface("SQL", this); + ServerInstance->UnpublishFeature("SQL"); + ServerInstance->DoneWithInterface("SQLutils"); + } + + void Implements(char* List) + { + List[I_OnRequest] = List[I_OnRehash] = 1; + } + + void SendQueue() + { + for (ConnMap::iterator iter = connections.begin(); iter != connections.end(); iter++) + { + iter->second->SendResults(); + } + } + + void ClearQueue() + { + for (ConnMap::iterator iter = connections.begin(); iter != connections.end(); iter++) + { + iter->second->ClearResults(); + } + } + + bool HasHost(const SQLhost &host) + { + for (ConnMap::iterator iter = connections.begin(); iter != connections.end(); iter++) + { + if (host == iter->second->GetConfHost()) + return true; + } + return false; + } + + bool HostInConf(const SQLhost &h) + { + ConfigReader conf(ServerInstance); + for(int i = 0; i < conf.Enumerate("database"); i++) + { + SQLhost host; + host.id = conf.ReadValue("database", "id", i); + host.host = conf.ReadValue("database", "hostname", i); + host.port = conf.ReadInteger("database", "port", i, true); + host.name = conf.ReadValue("database", "name", i); + host.user = conf.ReadValue("database", "username", i); + host.pass = conf.ReadValue("database", "password", i); + host.ssl = conf.ReadFlag("database", "ssl", "0", i); + if (h == host) + return true; + } + return false; + } + + void ReadConf() + { + ClearOldConnections(); + + ConfigReader conf(ServerInstance); + for(int i = 0; i < conf.Enumerate("database"); i++) + { + SQLhost host; + + host.id = conf.ReadValue("database", "id", i); + host.host = conf.ReadValue("database", "hostname", i); + host.port = conf.ReadInteger("database", "port", i, true); + host.name = conf.ReadValue("database", "name", i); + host.user = conf.ReadValue("database", "username", i); + host.pass = conf.ReadValue("database", "password", i); + host.ssl = conf.ReadFlag("database", "ssl", "0", i); + + if (HasHost(host)) + continue; + + this->AddConn(host); + } + } + + void AddConn(const SQLhost& hi) + { + if (HasHost(hi)) + { + ServerInstance->Log(DEFAULT, "WARNING: A sqlite connection with id: %s already exists. Aborting database open attempt.", hi.id.c_str()); + return; + } + + SQLConn* newconn; + + newconn = new SQLConn(ServerInstance, this, hi); + + connections.insert(std::make_pair(hi.id, newconn)); + } + + void ClearOldConnections() + { + ConnMap::iterator iter,safei; + for (iter = connections.begin(); iter != connections.end(); iter++) + { + if (!HostInConf(iter->second->GetConfHost())) + { + DELETE(iter->second); + safei = iter; + --iter; + connections.erase(safei); + } + } + } + + void ClearAllConnections() + { + ConnMap::iterator i; + while ((i = connections.begin()) != connections.end()) + { + connections.erase(i); + DELETE(i->second); + } + } + + virtual void OnRehash(userrec* user, const std::string ¶meter) + { + ReadConf(); + } + + virtual char* OnRequest(Request* request) + { + if(strcmp(SQLREQID, request->GetId()) == 0) + { + SQLrequest* req = (SQLrequest*)request; + ConnMap::iterator iter; + if((iter = connections.find(req->dbid)) != connections.end()) + { + req->id = NewID(); + req->error = iter->second->Query(*req); + return SQLSUCCESS; + } + else + { + req->error.Id(BAD_DBID); + return NULL; + } + } + return NULL; + } + + unsigned long NewID() + { + if (currid+1 == 0) + currid++; + + return ++currid; + } + + virtual Version GetVersion() + { + return Version(1,1,0,0,VF_VENDOR|VF_SERVICEPROVIDER,API_VERSION); + } + +}; + +void ResultNotifier::Dispatch() +{ + ((ModuleSQLite3*)mod)->SendQueue(); +} + +MODULE_INIT(ModuleSQLite3); + diff --git a/src/modules/extra/m_sqllog.cpp b/src/modules/extra/m_sqllog.cpp index 04eb1fef1..391e4bbba 100644 --- a/src/modules/extra/m_sqllog.cpp +++ b/src/modules/extra/m_sqllog.cpp @@ -1 +1,310 @@ -/* +------------------------------------+
* | Inspire Internet Relay Chat Daemon |
* +------------------------------------+
*
* InspIRCd: (C) 2002-2007 InspIRCd Development Team
* See: http://www.inspircd.org/wiki/index.php/Credits
*
* This program is free but copyrighted software; see
* the file COPYING for details.
*
* ---------------------------------------------------
*/
#include "inspircd.h"
#include "users.h"
#include "channels.h"
#include "modules.h"
#include "configreader.h"
#include "m_sqlv2.h"
static Module* SQLModule;
static Module* MyMod;
static std::string dbid;
enum LogTypes { LT_OPER = 1, LT_KILL, LT_SERVLINK, LT_XLINE, LT_CONNECT, LT_DISCONNECT, LT_FLOOD, LT_LOADMODULE };
enum QueryState { FIND_SOURCE, FIND_NICK, FIND_HOST, DONE};
class QueryInfo;
std::map<unsigned long,QueryInfo*> active_queries;
class QueryInfo
{
public:
QueryState qs;
unsigned long id;
std::string nick;
std::string source;
std::string hostname;
int sourceid;
int nickid;
int hostid;
int category;
time_t date;
bool insert;
QueryInfo(const std::string &n, const std::string &s, const std::string &h, unsigned long i, int cat)
{
qs = FIND_SOURCE;
nick = n;
source = s;
hostname = h;
id = i;
category = cat;
sourceid = nickid = hostid = -1;
date = time(NULL);
insert = false;
}
void Go(SQLresult* res)
{
SQLrequest req = SQLreq(MyMod, SQLModule, dbid, "", "");
switch (qs)
{
case FIND_SOURCE:
if (res->Rows() && sourceid == -1 && !insert)
{
sourceid = atoi(res->GetValue(0,0).d.c_str());
req = SQLreq(MyMod, SQLModule, dbid, "SELECT id,actor FROM ircd_log_actors WHERE actor='?'", nick);
if(req.Send())
{
insert = false;
qs = FIND_NICK;
active_queries[req.id] = this;
}
}
else if (res->Rows() && sourceid == -1 && insert)
{
req = SQLreq(MyMod, SQLModule, dbid, "SELECT id,actor FROM ircd_log_actors WHERE actor='?'", source);
if(req.Send())
{
insert = false;
qs = FIND_SOURCE;
active_queries[req.id] = this;
}
}
else
{
req = SQLreq(MyMod, SQLModule, dbid, "INSERT INTO ircd_log_actors (actor) VALUES('?')", source);
if(req.Send())
{
insert = true;
qs = FIND_SOURCE;
active_queries[req.id] = this;
}
}
break;
case FIND_NICK:
if (res->Rows() && nickid == -1 && !insert)
{
nickid = atoi(res->GetValue(0,0).d.c_str());
req = SQLreq(MyMod, SQLModule, dbid, "SELECT id,hostname FROM ircd_log_hosts WHERE hostname='?'", hostname);
if(req.Send())
{
insert = false;
qs = FIND_HOST;
active_queries[req.id] = this;
}
}
else if (res->Rows() && nickid == -1 && insert)
{
req = SQLreq(MyMod, SQLModule, dbid, "SELECT id,actor FROM ircd_log_actors WHERE actor='?'", nick);
if(req.Send())
{
insert = false;
qs = FIND_NICK;
active_queries[req.id] = this;
}
}
else
{
req = SQLreq(MyMod, SQLModule, dbid, "INSERT INTO ircd_log_actors (actor) VALUES('?')",nick);
if(req.Send())
{
insert = true;
qs = FIND_NICK;
active_queries[req.id] = this;
}
}
break;
case FIND_HOST:
if (res->Rows() && hostid == -1 && !insert)
{
hostid = atoi(res->GetValue(0,0).d.c_str());
req = SQLreq(MyMod, SQLModule, dbid, "INSERT INTO ircd_log (category_id,nick,host,source,dtime) VALUES("+ConvToStr(category)+","+ConvToStr(nickid)+","+ConvToStr(hostid)+","+ConvToStr(sourceid)+","+ConvToStr(date)+")");
if(req.Send())
{
insert = true;
qs = DONE;
active_queries[req.id] = this;
}
}
else if (res->Rows() && hostid == -1 && insert)
{
req = SQLreq(MyMod, SQLModule, dbid, "SELECT id,hostname FROM ircd_log_hosts WHERE hostname='?'", hostname);
if(req.Send())
{
insert = false;
qs = FIND_HOST;
active_queries[req.id] = this;
}
}
else
{
req = SQLreq(MyMod, SQLModule, dbid, "INSERT INTO ircd_log_hosts (hostname) VALUES('?')", hostname);
if(req.Send())
{
insert = true;
qs = FIND_HOST;
active_queries[req.id] = this;
}
}
break;
case DONE:
delete active_queries[req.id];
active_queries[req.id] = NULL;
break;
}
}
};
/* $ModDesc: Logs network-wide data to an SQL database */
class ModuleSQLLog : public Module
{
ConfigReader* Conf;
public:
ModuleSQLLog(InspIRCd* Me)
: Module::Module(Me)
{
ServerInstance->UseInterface("SQLutils");
ServerInstance->UseInterface("SQL");
Module* SQLutils = ServerInstance->FindModule("m_sqlutils.so");
if (!SQLutils)
throw ModuleException("Can't find m_sqlutils.so. Please load m_sqlutils.so before m_sqlauth.so.");
SQLModule = ServerInstance->FindFeature("SQL");
OnRehash(NULL,"");
MyMod = this;
active_queries.clear();
}
virtual ~ModuleSQLLog()
{
ServerInstance->DoneWithInterface("SQL");
ServerInstance->DoneWithInterface("SQLutils");
}
void Implements(char* List)
{
List[I_OnRehash] = List[I_OnOper] = List[I_OnGlobalOper] = List[I_OnKill] = 1;
List[I_OnPreCommand] = List[I_OnUserConnect] = 1;
List[I_OnUserQuit] = List[I_OnLoadModule] = List[I_OnRequest] = 1;
}
void ReadConfig()
{
ConfigReader Conf(ServerInstance);
dbid = Conf.ReadValue("sqllog","dbid",0); // database id of a database configured in sql module
}
virtual void OnRehash(userrec* user, const std::string ¶meter)
{
ReadConfig();
}
virtual char* OnRequest(Request* request)
{
if(strcmp(SQLRESID, request->GetId()) == 0)
{
SQLresult* res;
std::map<unsigned long, QueryInfo*>::iterator n;
res = static_cast<SQLresult*>(request);
n = active_queries.find(res->id);
if (n != active_queries.end())
{
n->second->Go(res);
std::map<unsigned long, QueryInfo*>::iterator n = active_queries.find(res->id);
active_queries.erase(n);
}
return SQLSUCCESS;
}
return NULL;
}
void AddLogEntry(int category, const std::string &nick, const std::string &host, const std::string &source)
{
// is the sql module loaded? If not, we don't attempt to do anything.
if (!SQLModule)
return;
SQLrequest req = SQLreq(this, SQLModule, dbid, "SELECT id,actor FROM ircd_log_actors WHERE actor='?'", source);
if(req.Send())
{
QueryInfo* i = new QueryInfo(nick, source, host, req.id, category);
i->qs = FIND_SOURCE;
active_queries[req.id] = i;
}
}
virtual void OnOper(userrec* user, const std::string &opertype)
{
AddLogEntry(LT_OPER,user->nick,user->host,user->server);
}
virtual void OnGlobalOper(userrec* user)
{
AddLogEntry(LT_OPER,user->nick,user->host,user->server);
}
virtual int OnKill(userrec* source, userrec* dest, const std::string &reason)
{
AddLogEntry(LT_KILL,dest->nick,dest->host,source->nick);
return 0;
}
virtual int OnPreCommand(const std::string &command, const char** parameters, int pcnt, userrec *user, bool validated, const std::string &original_line)
{
if ((command == "GLINE" || command == "KLINE" || command == "ELINE" || command == "ZLINE") && validated)
{
AddLogEntry(LT_XLINE,user->nick,command[0]+std::string(":")+std::string(parameters[0]),user->server);
}
return 0;
}
virtual void OnUserConnect(userrec* user)
{
AddLogEntry(LT_CONNECT,user->nick,user->host,user->server);
}
virtual void OnUserQuit(userrec* user, const std::string &reason, const std::string &oper_message)
{
AddLogEntry(LT_DISCONNECT,user->nick,user->host,user->server);
}
virtual void OnLoadModule(Module* mod, const std::string &name)
{
AddLogEntry(LT_LOADMODULE,name,ServerInstance->Config->ServerName, ServerInstance->Config->ServerName);
}
virtual Version GetVersion()
{
return Version(1,1,0,1,VF_VENDOR,API_VERSION);
}
};
MODULE_INIT(ModuleSQLLog);
\ No newline at end of file +/* +------------------------------------+ + * | Inspire Internet Relay Chat Daemon | + * +------------------------------------+ + * + * InspIRCd: (C) 2002-2007 InspIRCd Development Team + * See: http://www.inspircd.org/wiki/index.php/Credits + * + * This program is free but copyrighted software; see + * the file COPYING for details. + * + * --------------------------------------------------- + */ + +#include "inspircd.h" +#include "users.h" +#include "channels.h" +#include "modules.h" +#include "configreader.h" +#include "m_sqlv2.h" + +static Module* SQLModule; +static Module* MyMod; +static std::string dbid; + +enum LogTypes { LT_OPER = 1, LT_KILL, LT_SERVLINK, LT_XLINE, LT_CONNECT, LT_DISCONNECT, LT_FLOOD, LT_LOADMODULE }; + +enum QueryState { FIND_SOURCE, FIND_NICK, FIND_HOST, DONE}; + +class QueryInfo; + +std::map<unsigned long,QueryInfo*> active_queries; + +class QueryInfo +{ +public: + QueryState qs; + unsigned long id; + std::string nick; + std::string source; + std::string hostname; + int sourceid; + int nickid; + int hostid; + int category; + time_t date; + bool insert; + + QueryInfo(const std::string &n, const std::string &s, const std::string &h, unsigned long i, int cat) + { + qs = FIND_SOURCE; + nick = n; + source = s; + hostname = h; + id = i; + category = cat; + sourceid = nickid = hostid = -1; + date = time(NULL); + insert = false; + } + + void Go(SQLresult* res) + { + SQLrequest req = SQLreq(MyMod, SQLModule, dbid, "", ""); + switch (qs) + { + case FIND_SOURCE: + if (res->Rows() && sourceid == -1 && !insert) + { + sourceid = atoi(res->GetValue(0,0).d.c_str()); + req = SQLreq(MyMod, SQLModule, dbid, "SELECT id,actor FROM ircd_log_actors WHERE actor='?'", nick); + if(req.Send()) + { + insert = false; + qs = FIND_NICK; + active_queries[req.id] = this; + } + } + else if (res->Rows() && sourceid == -1 && insert) + { + req = SQLreq(MyMod, SQLModule, dbid, "SELECT id,actor FROM ircd_log_actors WHERE actor='?'", source); + if(req.Send()) + { + insert = false; + qs = FIND_SOURCE; + active_queries[req.id] = this; + } + } + else + { + req = SQLreq(MyMod, SQLModule, dbid, "INSERT INTO ircd_log_actors (actor) VALUES('?')", source); + if(req.Send()) + { + insert = true; + qs = FIND_SOURCE; + active_queries[req.id] = this; + } + } + break; + + case FIND_NICK: + if (res->Rows() && nickid == -1 && !insert) + { + nickid = atoi(res->GetValue(0,0).d.c_str()); + req = SQLreq(MyMod, SQLModule, dbid, "SELECT id,hostname FROM ircd_log_hosts WHERE hostname='?'", hostname); + if(req.Send()) + { + insert = false; + qs = FIND_HOST; + active_queries[req.id] = this; + } + } + else if (res->Rows() && nickid == -1 && insert) + { + req = SQLreq(MyMod, SQLModule, dbid, "SELECT id,actor FROM ircd_log_actors WHERE actor='?'", nick); + if(req.Send()) + { + insert = false; + qs = FIND_NICK; + active_queries[req.id] = this; + } + } + else + { + req = SQLreq(MyMod, SQLModule, dbid, "INSERT INTO ircd_log_actors (actor) VALUES('?')",nick); + if(req.Send()) + { + insert = true; + qs = FIND_NICK; + active_queries[req.id] = this; + } + } + break; + + case FIND_HOST: + if (res->Rows() && hostid == -1 && !insert) + { + hostid = atoi(res->GetValue(0,0).d.c_str()); + req = SQLreq(MyMod, SQLModule, dbid, "INSERT INTO ircd_log (category_id,nick,host,source,dtime) VALUES("+ConvToStr(category)+","+ConvToStr(nickid)+","+ConvToStr(hostid)+","+ConvToStr(sourceid)+","+ConvToStr(date)+")"); + if(req.Send()) + { + insert = true; + qs = DONE; + active_queries[req.id] = this; + } + } + else if (res->Rows() && hostid == -1 && insert) + { + req = SQLreq(MyMod, SQLModule, dbid, "SELECT id,hostname FROM ircd_log_hosts WHERE hostname='?'", hostname); + if(req.Send()) + { + insert = false; + qs = FIND_HOST; + active_queries[req.id] = this; + } + } + else + { + req = SQLreq(MyMod, SQLModule, dbid, "INSERT INTO ircd_log_hosts (hostname) VALUES('?')", hostname); + if(req.Send()) + { + insert = true; + qs = FIND_HOST; + active_queries[req.id] = this; + } + } + break; + + case DONE: + delete active_queries[req.id]; + active_queries[req.id] = NULL; + break; + } + } +}; + +/* $ModDesc: Logs network-wide data to an SQL database */ + +class ModuleSQLLog : public Module +{ + ConfigReader* Conf; + + public: + ModuleSQLLog(InspIRCd* Me) + : Module::Module(Me) + { + ServerInstance->UseInterface("SQLutils"); + ServerInstance->UseInterface("SQL"); + + Module* SQLutils = ServerInstance->FindModule("m_sqlutils.so"); + if (!SQLutils) + throw ModuleException("Can't find m_sqlutils.so. Please load m_sqlutils.so before m_sqlauth.so."); + + SQLModule = ServerInstance->FindFeature("SQL"); + + OnRehash(NULL,""); + MyMod = this; + active_queries.clear(); + } + + virtual ~ModuleSQLLog() + { + ServerInstance->DoneWithInterface("SQL"); + ServerInstance->DoneWithInterface("SQLutils"); + } + + void Implements(char* List) + { + List[I_OnRehash] = List[I_OnOper] = List[I_OnGlobalOper] = List[I_OnKill] = 1; + List[I_OnPreCommand] = List[I_OnUserConnect] = 1; + List[I_OnUserQuit] = List[I_OnLoadModule] = List[I_OnRequest] = 1; + } + + void ReadConfig() + { + ConfigReader Conf(ServerInstance); + dbid = Conf.ReadValue("sqllog","dbid",0); // database id of a database configured in sql module + } + + virtual void OnRehash(userrec* user, const std::string ¶meter) + { + ReadConfig(); + } + + virtual char* OnRequest(Request* request) + { + if(strcmp(SQLRESID, request->GetId()) == 0) + { + SQLresult* res; + std::map<unsigned long, QueryInfo*>::iterator n; + + res = static_cast<SQLresult*>(request); + n = active_queries.find(res->id); + + if (n != active_queries.end()) + { + n->second->Go(res); + std::map<unsigned long, QueryInfo*>::iterator n = active_queries.find(res->id); + active_queries.erase(n); + } + + return SQLSUCCESS; + } + + return NULL; + } + + void AddLogEntry(int category, const std::string &nick, const std::string &host, const std::string &source) + { + // is the sql module loaded? If not, we don't attempt to do anything. + if (!SQLModule) + return; + + SQLrequest req = SQLreq(this, SQLModule, dbid, "SELECT id,actor FROM ircd_log_actors WHERE actor='?'", source); + if(req.Send()) + { + QueryInfo* i = new QueryInfo(nick, source, host, req.id, category); + i->qs = FIND_SOURCE; + active_queries[req.id] = i; + } + } + + virtual void OnOper(userrec* user, const std::string &opertype) + { + AddLogEntry(LT_OPER,user->nick,user->host,user->server); + } + + virtual void OnGlobalOper(userrec* user) + { + AddLogEntry(LT_OPER,user->nick,user->host,user->server); + } + + virtual int OnKill(userrec* source, userrec* dest, const std::string &reason) + { + AddLogEntry(LT_KILL,dest->nick,dest->host,source->nick); + return 0; + } + + virtual int OnPreCommand(const std::string &command, const char** parameters, int pcnt, userrec *user, bool validated, const std::string &original_line) + { + if ((command == "GLINE" || command == "KLINE" || command == "ELINE" || command == "ZLINE") && validated) + { + AddLogEntry(LT_XLINE,user->nick,command[0]+std::string(":")+std::string(parameters[0]),user->server); + } + return 0; + } + + virtual void OnUserConnect(userrec* user) + { + AddLogEntry(LT_CONNECT,user->nick,user->host,user->server); + } + + virtual void OnUserQuit(userrec* user, const std::string &reason, const std::string &oper_message) + { + AddLogEntry(LT_DISCONNECT,user->nick,user->host,user->server); + } + + virtual void OnLoadModule(Module* mod, const std::string &name) + { + AddLogEntry(LT_LOADMODULE,name,ServerInstance->Config->ServerName, ServerInstance->Config->ServerName); + } + + virtual Version GetVersion() + { + return Version(1,1,0,1,VF_VENDOR,API_VERSION); + } + +}; + +MODULE_INIT(ModuleSQLLog); + diff --git a/src/modules/extra/m_sqloper.cpp b/src/modules/extra/m_sqloper.cpp index 4b09ac26e..520869e21 100644 --- a/src/modules/extra/m_sqloper.cpp +++ b/src/modules/extra/m_sqloper.cpp @@ -1 +1,283 @@ -/* +------------------------------------+
* | Inspire Internet Relay Chat Daemon |
* +------------------------------------+
*
* InspIRCd: (C) 2002-2007 InspIRCd Development Team
* See: http://www.inspircd.org/wiki/index.php/Credits
*
* This program is free but copyrighted software; see
* the file COPYING for details.
*
* ---------------------------------------------------
*/
#include "inspircd.h"
#include "users.h"
#include "channels.h"
#include "modules.h"
#include "configreader.h"
#include "m_sqlv2.h"
#include "m_sqlutils.h"
#include "m_hash.h"
#include "commands/cmd_oper.h"
/* $ModDesc: Allows storage of oper credentials in an SQL table */
/* $ModDep: m_sqlv2.h m_sqlutils.h */
class ModuleSQLOper : public Module
{
Module* SQLutils;
Module* HashModule;
std::string databaseid;
public:
ModuleSQLOper(InspIRCd* Me)
: Module::Module(Me)
{
ServerInstance->UseInterface("SQLutils");
ServerInstance->UseInterface("SQL");
ServerInstance->UseInterface("HashRequest");
/* Attempt to locate the md5 service provider, bail if we can't find it */
HashModule = ServerInstance->FindModule("m_md5.so");
if (!HashModule)
throw ModuleException("Can't find m_md5.so. Please load m_md5.so before m_sqloper.so.");
SQLutils = ServerInstance->FindModule("m_sqlutils.so");
if (!SQLutils)
throw ModuleException("Can't find m_sqlutils.so. Please load m_sqlutils.so before m_sqloper.so.");
OnRehash(NULL,"");
}
virtual ~ModuleSQLOper()
{
ServerInstance->DoneWithInterface("SQL");
ServerInstance->DoneWithInterface("SQLutils");
ServerInstance->DoneWithInterface("HashRequest");
}
void Implements(char* List)
{
List[I_OnRequest] = List[I_OnRehash] = List[I_OnPreCommand] = 1;
}
virtual void OnRehash(userrec* user, const std::string ¶meter)
{
ConfigReader Conf(ServerInstance);
databaseid = Conf.ReadValue("sqloper", "dbid", 0); /* Database ID of a database configured for the service provider module */
}
virtual int OnPreCommand(const std::string &command, const char** parameters, int pcnt, userrec *user, bool validated, const std::string &original_line)
{
if ((validated) && (command == "OPER"))
{
if (LookupOper(user, parameters[0], parameters[1]))
{
/* Returning true here just means the query is in progress, or on it's way to being
* in progress. Nothing about the /oper actually being successful..
* If the oper lookup fails later, we pass the command to the original handler
* for /oper by calling its Handle method directly.
*/
return 1;
}
}
return 0;
}
bool LookupOper(userrec* user, const std::string &username, const std::string &password)
{
Module* target;
target = ServerInstance->FindFeature("SQL");
if (target)
{
/* Reset hash module first back to MD5 standard state */
HashResetRequest(this, HashModule).Send();
/* Make an MD5 hash of the password for using in the query */
std::string md5_pass_hash = HashSumRequest(this, HashModule, password.c_str()).Send();
/* We generate our own MD5 sum here because some database providers (e.g. SQLite) dont have a builtin md5 function,
* also hashing it in the module and only passing a remote query containing a hash is more secure.
*/
SQLrequest req = SQLreq(this, target, databaseid, "SELECT username, password, hostname, type FROM ircd_opers WHERE username = '?' AND password='?'", username, md5_pass_hash);
if (req.Send())
{
/* When we get the query response from the service provider we will be given an ID to play with,
* just an ID number which is unique to this query. We need a way of associating that ID with a userrec
* so we insert it into a map mapping the IDs to users.
* Thankfully m_sqlutils provides this, it will associate a ID with a user or channel, and if the user quits it removes the
* association. This means that if the user quits during a query we will just get a failed lookup from m_sqlutils - telling
* us to discard the query.
*/
AssociateUser(this, SQLutils, req.id, user).Send();
user->Extend("oper_user", strdup(username.c_str()));
user->Extend("oper_pass", strdup(password.c_str()));
return true;
}
else
{
return false;
}
}
else
{
ServerInstance->Log(SPARSE, "WARNING: Couldn't find SQL provider module. NOBODY will be able to oper up unless their o:line is statically configured");
return false;
}
}
virtual char* OnRequest(Request* request)
{
if (strcmp(SQLRESID, request->GetId()) == 0)
{
SQLresult* res = static_cast<SQLresult*>(request);
userrec* user = GetAssocUser(this, SQLutils, res->id).S().user;
UnAssociate(this, SQLutils, res->id).S();
char* tried_user = NULL;
char* tried_pass = NULL;
user->GetExt("oper_user", tried_user);
user->GetExt("oper_pass", tried_pass);
if (user)
{
if (res->error.Id() == NO_ERROR)
{
if (res->Rows())
{
/* We got a row in the result, this means there was a record for the oper..
* now we just need to check if their host matches, and if it does then
* oper them up.
*
* We now (previous versions of the module didn't) support multiple SQL
* rows per-oper in the same way the config file does, all rows will be tried
* until one is found which matches. This is useful to define several different
* hosts for a single oper.
*
* The for() loop works as SQLresult::GetRowMap() returns an empty map when there
* are no more rows to return.
*/
for (SQLfieldMap& row = res->GetRowMap(); row.size(); row = res->GetRowMap())
{
if (OperUser(user, row["username"].d, row["password"].d, row["hostname"].d, row["type"].d))
{
/* If/when one of the rows matches, stop checking and return */
return SQLSUCCESS;
}
if (tried_user && tried_pass)
{
LoginFail(user, tried_user, tried_pass);
free(tried_user);
free(tried_pass);
user->Shrink("oper_user");
user->Shrink("oper_pass");
}
}
}
else
{
/* No rows in result, this means there was no oper line for the user,
* we should have already checked the o:lines so now we need an
* "insufficient awesomeness" (invalid credentials) error
*/
if (tried_user && tried_pass)
{
LoginFail(user, tried_user, tried_pass);
free(tried_user);
free(tried_pass);
user->Shrink("oper_user");
user->Shrink("oper_pass");
}
}
}
else
{
/* This one shouldn't happen, the query failed for some reason.
* We have to fail the /oper request and give them the same error
* as above.
*/
if (tried_user && tried_pass)
{
LoginFail(user, tried_user, tried_pass);
free(tried_user);
free(tried_pass);
user->Shrink("oper_user");
user->Shrink("oper_pass");
}
}
}
return SQLSUCCESS;
}
return NULL;
}
void LoginFail(userrec* user, const std::string &username, const std::string &pass)
{
command_t* oper_command = ServerInstance->Parser->GetHandler("OPER");
if (oper_command)
{
const char* params[] = { username.c_str(), pass.c_str() };
oper_command->Handle(params, 2, user);
}
else
{
ServerInstance->Log(DEBUG, "BUG: WHAT?! Why do we have no OPER command?!");
}
}
bool OperUser(userrec* user, const std::string &username, const std::string &password, const std::string &pattern, const std::string &type)
{
ConfigReader Conf(ServerInstance);
for (int j = 0; j < Conf.Enumerate("type"); j++)
{
std::string tname = Conf.ReadValue("type","name",j);
std::string hostname(user->ident);
hostname.append("@").append(user->host);
if ((tname == type) && OneOfMatches(hostname.c_str(), user->GetIPString(), pattern.c_str()))
{
/* Opertype and host match, looks like this is it. */
std::string operhost = Conf.ReadValue("type", "host", j);
if (operhost.size())
user->ChangeDisplayedHost(operhost.c_str());
ServerInstance->SNO->WriteToSnoMask('o',"%s (%s@%s) is now an IRC operator of type %s", user->nick, user->ident, user->host, type.c_str());
user->WriteServ("381 %s :You are now an IRC operator of type %s", user->nick, type.c_str());
if (!user->modes[UM_OPERATOR])
user->Oper(type);
return true;
}
}
return false;
}
virtual Version GetVersion()
{
return Version(1,1,1,0,VF_VENDOR,API_VERSION);
}
};
MODULE_INIT(ModuleSQLOper);
\ No newline at end of file +/* +------------------------------------+ + * | Inspire Internet Relay Chat Daemon | + * +------------------------------------+ + * + * InspIRCd: (C) 2002-2007 InspIRCd Development Team + * See: http://www.inspircd.org/wiki/index.php/Credits + * + * This program is free but copyrighted software; see + * the file COPYING for details. + * + * --------------------------------------------------- + */ + +#include "inspircd.h" +#include "users.h" +#include "channels.h" +#include "modules.h" +#include "configreader.h" + +#include "m_sqlv2.h" +#include "m_sqlutils.h" +#include "m_hash.h" +#include "commands/cmd_oper.h" + +/* $ModDesc: Allows storage of oper credentials in an SQL table */ +/* $ModDep: m_sqlv2.h m_sqlutils.h */ + +class ModuleSQLOper : public Module +{ + Module* SQLutils; + Module* HashModule; + std::string databaseid; + +public: + ModuleSQLOper(InspIRCd* Me) + : Module::Module(Me) + { + ServerInstance->UseInterface("SQLutils"); + ServerInstance->UseInterface("SQL"); + ServerInstance->UseInterface("HashRequest"); + + /* Attempt to locate the md5 service provider, bail if we can't find it */ + HashModule = ServerInstance->FindModule("m_md5.so"); + if (!HashModule) + throw ModuleException("Can't find m_md5.so. Please load m_md5.so before m_sqloper.so."); + + SQLutils = ServerInstance->FindModule("m_sqlutils.so"); + if (!SQLutils) + throw ModuleException("Can't find m_sqlutils.so. Please load m_sqlutils.so before m_sqloper.so."); + + OnRehash(NULL,""); + } + + virtual ~ModuleSQLOper() + { + ServerInstance->DoneWithInterface("SQL"); + ServerInstance->DoneWithInterface("SQLutils"); + ServerInstance->DoneWithInterface("HashRequest"); + } + + void Implements(char* List) + { + List[I_OnRequest] = List[I_OnRehash] = List[I_OnPreCommand] = 1; + } + + virtual void OnRehash(userrec* user, const std::string ¶meter) + { + ConfigReader Conf(ServerInstance); + + databaseid = Conf.ReadValue("sqloper", "dbid", 0); /* Database ID of a database configured for the service provider module */ + } + + virtual int OnPreCommand(const std::string &command, const char** parameters, int pcnt, userrec *user, bool validated, const std::string &original_line) + { + if ((validated) && (command == "OPER")) + { + if (LookupOper(user, parameters[0], parameters[1])) + { + /* Returning true here just means the query is in progress, or on it's way to being + * in progress. Nothing about the /oper actually being successful.. + * If the oper lookup fails later, we pass the command to the original handler + * for /oper by calling its Handle method directly. + */ + return 1; + } + } + return 0; + } + + bool LookupOper(userrec* user, const std::string &username, const std::string &password) + { + Module* target; + + target = ServerInstance->FindFeature("SQL"); + + if (target) + { + /* Reset hash module first back to MD5 standard state */ + HashResetRequest(this, HashModule).Send(); + /* Make an MD5 hash of the password for using in the query */ + std::string md5_pass_hash = HashSumRequest(this, HashModule, password.c_str()).Send(); + + /* We generate our own MD5 sum here because some database providers (e.g. SQLite) dont have a builtin md5 function, + * also hashing it in the module and only passing a remote query containing a hash is more secure. + */ + + SQLrequest req = SQLreq(this, target, databaseid, "SELECT username, password, hostname, type FROM ircd_opers WHERE username = '?' AND password='?'", username, md5_pass_hash); + + if (req.Send()) + { + /* When we get the query response from the service provider we will be given an ID to play with, + * just an ID number which is unique to this query. We need a way of associating that ID with a userrec + * so we insert it into a map mapping the IDs to users. + * Thankfully m_sqlutils provides this, it will associate a ID with a user or channel, and if the user quits it removes the + * association. This means that if the user quits during a query we will just get a failed lookup from m_sqlutils - telling + * us to discard the query. + */ + AssociateUser(this, SQLutils, req.id, user).Send(); + + user->Extend("oper_user", strdup(username.c_str())); + user->Extend("oper_pass", strdup(password.c_str())); + + return true; + } + else + { + return false; + } + } + else + { + ServerInstance->Log(SPARSE, "WARNING: Couldn't find SQL provider module. NOBODY will be able to oper up unless their o:line is statically configured"); + return false; + } + } + + virtual char* OnRequest(Request* request) + { + if (strcmp(SQLRESID, request->GetId()) == 0) + { + SQLresult* res = static_cast<SQLresult*>(request); + + userrec* user = GetAssocUser(this, SQLutils, res->id).S().user; + UnAssociate(this, SQLutils, res->id).S(); + + char* tried_user = NULL; + char* tried_pass = NULL; + + user->GetExt("oper_user", tried_user); + user->GetExt("oper_pass", tried_pass); + + if (user) + { + if (res->error.Id() == NO_ERROR) + { + if (res->Rows()) + { + /* We got a row in the result, this means there was a record for the oper.. + * now we just need to check if their host matches, and if it does then + * oper them up. + * + * We now (previous versions of the module didn't) support multiple SQL + * rows per-oper in the same way the config file does, all rows will be tried + * until one is found which matches. This is useful to define several different + * hosts for a single oper. + * + * The for() loop works as SQLresult::GetRowMap() returns an empty map when there + * are no more rows to return. + */ + + for (SQLfieldMap& row = res->GetRowMap(); row.size(); row = res->GetRowMap()) + { + if (OperUser(user, row["username"].d, row["password"].d, row["hostname"].d, row["type"].d)) + { + /* If/when one of the rows matches, stop checking and return */ + return SQLSUCCESS; + } + if (tried_user && tried_pass) + { + LoginFail(user, tried_user, tried_pass); + free(tried_user); + free(tried_pass); + user->Shrink("oper_user"); + user->Shrink("oper_pass"); + } + } + } + else + { + /* No rows in result, this means there was no oper line for the user, + * we should have already checked the o:lines so now we need an + * "insufficient awesomeness" (invalid credentials) error + */ + if (tried_user && tried_pass) + { + LoginFail(user, tried_user, tried_pass); + free(tried_user); + free(tried_pass); + user->Shrink("oper_user"); + user->Shrink("oper_pass"); + } + } + } + else + { + /* This one shouldn't happen, the query failed for some reason. + * We have to fail the /oper request and give them the same error + * as above. + */ + if (tried_user && tried_pass) + { + LoginFail(user, tried_user, tried_pass); + free(tried_user); + free(tried_pass); + user->Shrink("oper_user"); + user->Shrink("oper_pass"); + } + + } + } + + return SQLSUCCESS; + } + + return NULL; + } + + void LoginFail(userrec* user, const std::string &username, const std::string &pass) + { + command_t* oper_command = ServerInstance->Parser->GetHandler("OPER"); + + if (oper_command) + { + const char* params[] = { username.c_str(), pass.c_str() }; + oper_command->Handle(params, 2, user); + } + else + { + ServerInstance->Log(DEBUG, "BUG: WHAT?! Why do we have no OPER command?!"); + } + } + + bool OperUser(userrec* user, const std::string &username, const std::string &password, const std::string &pattern, const std::string &type) + { + ConfigReader Conf(ServerInstance); + + for (int j = 0; j < Conf.Enumerate("type"); j++) + { + std::string tname = Conf.ReadValue("type","name",j); + std::string hostname(user->ident); + + hostname.append("@").append(user->host); + + if ((tname == type) && OneOfMatches(hostname.c_str(), user->GetIPString(), pattern.c_str())) + { + /* Opertype and host match, looks like this is it. */ + std::string operhost = Conf.ReadValue("type", "host", j); + + if (operhost.size()) + user->ChangeDisplayedHost(operhost.c_str()); + + ServerInstance->SNO->WriteToSnoMask('o',"%s (%s@%s) is now an IRC operator of type %s", user->nick, user->ident, user->host, type.c_str()); + user->WriteServ("381 %s :You are now an IRC operator of type %s", user->nick, type.c_str()); + + if (!user->modes[UM_OPERATOR]) + user->Oper(type); + + return true; + } + } + + return false; + } + + virtual Version GetVersion() + { + return Version(1,1,1,0,VF_VENDOR,API_VERSION); + } + +}; + +MODULE_INIT(ModuleSQLOper); + diff --git a/src/modules/extra/m_sqlutils.cpp b/src/modules/extra/m_sqlutils.cpp index 6cd09252b..b470f99af 100644 --- a/src/modules/extra/m_sqlutils.cpp +++ b/src/modules/extra/m_sqlutils.cpp @@ -1 +1,238 @@ -/* +------------------------------------+
* | Inspire Internet Relay Chat Daemon |
* +------------------------------------+
*
* InspIRCd: (C) 2002-2007 InspIRCd Development Team
* See: http://www.inspircd.org/wiki/index.php/Credits
*
* This program is free but copyrighted software; see
* the file COPYING for details.
*
* ---------------------------------------------------
*/
#include "inspircd.h"
#include <sstream>
#include <list>
#include "users.h"
#include "channels.h"
#include "modules.h"
#include "configreader.h"
#include "m_sqlutils.h"
/* $ModDesc: Provides some utilities to SQL client modules, such as mapping queries to users and channels */
/* $ModDep: m_sqlutils.h */
typedef std::map<unsigned long, userrec*> IdUserMap;
typedef std::map<unsigned long, chanrec*> IdChanMap;
typedef std::list<unsigned long> AssocIdList;
class ModuleSQLutils : public Module
{
private:
IdUserMap iduser;
IdChanMap idchan;
public:
ModuleSQLutils(InspIRCd* Me)
: Module::Module(Me)
{
ServerInstance->PublishInterface("SQLutils", this);
}
virtual ~ModuleSQLutils()
{
ServerInstance->UnpublishInterface("SQLutils", this);
}
void Implements(char* List)
{
List[I_OnChannelDelete] = List[I_OnUnloadModule] = List[I_OnRequest] = List[I_OnUserDisconnect] = 1;
}
virtual char* OnRequest(Request* request)
{
if(strcmp(SQLUTILAU, request->GetId()) == 0)
{
AssociateUser* req = (AssociateUser*)request;
iduser.insert(std::make_pair(req->id, req->user));
AttachList(req->user, req->id);
}
else if(strcmp(SQLUTILAC, request->GetId()) == 0)
{
AssociateChan* req = (AssociateChan*)request;
idchan.insert(std::make_pair(req->id, req->chan));
AttachList(req->chan, req->id);
}
else if(strcmp(SQLUTILUA, request->GetId()) == 0)
{
UnAssociate* req = (UnAssociate*)request;
/* Unassociate a given query ID with all users and channels
* it is associated with.
*/
DoUnAssociate(iduser, req->id);
DoUnAssociate(idchan, req->id);
}
else if(strcmp(SQLUTILGU, request->GetId()) == 0)
{
GetAssocUser* req = (GetAssocUser*)request;
IdUserMap::iterator iter = iduser.find(req->id);
if(iter != iduser.end())
{
req->user = iter->second;
}
}
else if(strcmp(SQLUTILGC, request->GetId()) == 0)
{
GetAssocChan* req = (GetAssocChan*)request;
IdChanMap::iterator iter = idchan.find(req->id);
if(iter != idchan.end())
{
req->chan = iter->second;
}
}
return SQLUTILSUCCESS;
}
virtual void OnUserDisconnect(userrec* user)
{
/* A user is disconnecting, first we need to check if they have a list of queries associated with them.
* Then, if they do, we need to erase each of them from our IdUserMap (iduser) so when the module that
* associated them asks to look them up then it gets a NULL result and knows to discard the query.
*/
AssocIdList* il;
if(user->GetExt("sqlutils_queryids", il))
{
for(AssocIdList::iterator listiter = il->begin(); listiter != il->end(); listiter++)
{
IdUserMap::iterator iter;
iter = iduser.find(*listiter);
if(iter != iduser.end())
{
if(iter->second != user)
{
ServerInstance->Log(DEBUG, "BUG: ID associated with user %s doesn't have the same userrec* associated with it in the map (erasing anyway)", user->nick);
}
iduser.erase(iter);
}
else
{
ServerInstance->Log(DEBUG, "BUG: user %s was extended with sqlutils_queryids but there was nothing matching in the map", user->nick);
}
}
user->Shrink("sqlutils_queryids");
delete il;
}
}
void AttachList(Extensible* obj, unsigned long id)
{
AssocIdList* il;
if(!obj->GetExt("sqlutils_queryids", il))
{
/* Doesn't already exist, create a new list and attach it. */
il = new AssocIdList;
obj->Extend("sqlutils_queryids", il);
}
/* Now either way we have a valid list in il, attached. */
il->push_back(id);
}
void RemoveFromList(Extensible* obj, unsigned long id)
{
AssocIdList* il;
if(obj->GetExt("sqlutils_queryids", il))
{
/* Only do anything if the list exists... (which it ought to) */
il->remove(id);
if(il->empty())
{
/* If we just emptied it.. */
delete il;
obj->Shrink("sqlutils_queryids");
}
}
}
template <class T> void DoUnAssociate(T &map, unsigned long id)
{
/* For each occurence of 'id' (well, only one..it's not a multimap) in 'map'
* remove it from the map, take an Extensible* value from the map and remove
* 'id' from the list of query IDs attached to it.
*/
typename T::iterator iter = map.find(id);
if(iter != map.end())
{
/* Found a value indexed by 'id', call RemoveFromList()
* on it with 'id' to remove 'id' from the list attached
* to the value.
*/
RemoveFromList(iter->second, id);
}
}
virtual void OnChannelDelete(chanrec* chan)
{
/* A channel is being destroyed, first we need to check if it has a list of queries associated with it.
* Then, if it does, we need to erase each of them from our IdChanMap (idchan) so when the module that
* associated them asks to look them up then it gets a NULL result and knows to discard the query.
*/
AssocIdList* il;
if(chan->GetExt("sqlutils_queryids", il))
{
for(AssocIdList::iterator listiter = il->begin(); listiter != il->end(); listiter++)
{
IdChanMap::iterator iter;
iter = idchan.find(*listiter);
if(iter != idchan.end())
{
if(iter->second != chan)
{
ServerInstance->Log(DEBUG, "BUG: ID associated with channel %s doesn't have the same chanrec* associated with it in the map (erasing anyway)", chan->name);
}
idchan.erase(iter);
}
else
{
ServerInstance->Log(DEBUG, "BUG: channel %s was extended with sqlutils_queryids but there was nothing matching in the map", chan->name);
}
}
chan->Shrink("sqlutils_queryids");
delete il;
}
}
virtual Version GetVersion()
{
return Version(1, 1, 0, 0, VF_VENDOR|VF_SERVICEPROVIDER, API_VERSION);
}
};
MODULE_INIT(ModuleSQLutils);
\ No newline at end of file +/* +------------------------------------+ + * | Inspire Internet Relay Chat Daemon | + * +------------------------------------+ + * + * InspIRCd: (C) 2002-2007 InspIRCd Development Team + * See: http://www.inspircd.org/wiki/index.php/Credits + * + * This program is free but copyrighted software; see + * the file COPYING for details. + * + * --------------------------------------------------- + */ + +#include "inspircd.h" +#include <sstream> +#include <list> +#include "users.h" +#include "channels.h" +#include "modules.h" +#include "configreader.h" +#include "m_sqlutils.h" + +/* $ModDesc: Provides some utilities to SQL client modules, such as mapping queries to users and channels */ +/* $ModDep: m_sqlutils.h */ + +typedef std::map<unsigned long, userrec*> IdUserMap; +typedef std::map<unsigned long, chanrec*> IdChanMap; +typedef std::list<unsigned long> AssocIdList; + +class ModuleSQLutils : public Module +{ +private: + IdUserMap iduser; + IdChanMap idchan; + +public: + ModuleSQLutils(InspIRCd* Me) + : Module::Module(Me) + { + ServerInstance->PublishInterface("SQLutils", this); + } + + virtual ~ModuleSQLutils() + { + ServerInstance->UnpublishInterface("SQLutils", this); + } + + void Implements(char* List) + { + List[I_OnChannelDelete] = List[I_OnUnloadModule] = List[I_OnRequest] = List[I_OnUserDisconnect] = 1; + } + + virtual char* OnRequest(Request* request) + { + if(strcmp(SQLUTILAU, request->GetId()) == 0) + { + AssociateUser* req = (AssociateUser*)request; + + iduser.insert(std::make_pair(req->id, req->user)); + + AttachList(req->user, req->id); + } + else if(strcmp(SQLUTILAC, request->GetId()) == 0) + { + AssociateChan* req = (AssociateChan*)request; + + idchan.insert(std::make_pair(req->id, req->chan)); + + AttachList(req->chan, req->id); + } + else if(strcmp(SQLUTILUA, request->GetId()) == 0) + { + UnAssociate* req = (UnAssociate*)request; + + /* Unassociate a given query ID with all users and channels + * it is associated with. + */ + + DoUnAssociate(iduser, req->id); + DoUnAssociate(idchan, req->id); + } + else if(strcmp(SQLUTILGU, request->GetId()) == 0) + { + GetAssocUser* req = (GetAssocUser*)request; + + IdUserMap::iterator iter = iduser.find(req->id); + + if(iter != iduser.end()) + { + req->user = iter->second; + } + } + else if(strcmp(SQLUTILGC, request->GetId()) == 0) + { + GetAssocChan* req = (GetAssocChan*)request; + + IdChanMap::iterator iter = idchan.find(req->id); + + if(iter != idchan.end()) + { + req->chan = iter->second; + } + } + + return SQLUTILSUCCESS; + } + + virtual void OnUserDisconnect(userrec* user) + { + /* A user is disconnecting, first we need to check if they have a list of queries associated with them. + * Then, if they do, we need to erase each of them from our IdUserMap (iduser) so when the module that + * associated them asks to look them up then it gets a NULL result and knows to discard the query. + */ + AssocIdList* il; + + if(user->GetExt("sqlutils_queryids", il)) + { + for(AssocIdList::iterator listiter = il->begin(); listiter != il->end(); listiter++) + { + IdUserMap::iterator iter; + + iter = iduser.find(*listiter); + + if(iter != iduser.end()) + { + if(iter->second != user) + { + ServerInstance->Log(DEBUG, "BUG: ID associated with user %s doesn't have the same userrec* associated with it in the map (erasing anyway)", user->nick); + } + + iduser.erase(iter); + } + else + { + ServerInstance->Log(DEBUG, "BUG: user %s was extended with sqlutils_queryids but there was nothing matching in the map", user->nick); + } + } + + user->Shrink("sqlutils_queryids"); + delete il; + } + } + + void AttachList(Extensible* obj, unsigned long id) + { + AssocIdList* il; + + if(!obj->GetExt("sqlutils_queryids", il)) + { + /* Doesn't already exist, create a new list and attach it. */ + il = new AssocIdList; + obj->Extend("sqlutils_queryids", il); + } + + /* Now either way we have a valid list in il, attached. */ + il->push_back(id); + } + + void RemoveFromList(Extensible* obj, unsigned long id) + { + AssocIdList* il; + + if(obj->GetExt("sqlutils_queryids", il)) + { + /* Only do anything if the list exists... (which it ought to) */ + il->remove(id); + + if(il->empty()) + { + /* If we just emptied it.. */ + delete il; + obj->Shrink("sqlutils_queryids"); + } + } + } + + template <class T> void DoUnAssociate(T &map, unsigned long id) + { + /* For each occurence of 'id' (well, only one..it's not a multimap) in 'map' + * remove it from the map, take an Extensible* value from the map and remove + * 'id' from the list of query IDs attached to it. + */ + typename T::iterator iter = map.find(id); + + if(iter != map.end()) + { + /* Found a value indexed by 'id', call RemoveFromList() + * on it with 'id' to remove 'id' from the list attached + * to the value. + */ + RemoveFromList(iter->second, id); + } + } + + virtual void OnChannelDelete(chanrec* chan) + { + /* A channel is being destroyed, first we need to check if it has a list of queries associated with it. + * Then, if it does, we need to erase each of them from our IdChanMap (idchan) so when the module that + * associated them asks to look them up then it gets a NULL result and knows to discard the query. + */ + AssocIdList* il; + + if(chan->GetExt("sqlutils_queryids", il)) + { + for(AssocIdList::iterator listiter = il->begin(); listiter != il->end(); listiter++) + { + IdChanMap::iterator iter; + + iter = idchan.find(*listiter); + + if(iter != idchan.end()) + { + if(iter->second != chan) + { + ServerInstance->Log(DEBUG, "BUG: ID associated with channel %s doesn't have the same chanrec* associated with it in the map (erasing anyway)", chan->name); + } + idchan.erase(iter); + } + else + { + ServerInstance->Log(DEBUG, "BUG: channel %s was extended with sqlutils_queryids but there was nothing matching in the map", chan->name); + } + } + + chan->Shrink("sqlutils_queryids"); + delete il; + } + } + + virtual Version GetVersion() + { + return Version(1, 1, 0, 0, VF_VENDOR|VF_SERVICEPROVIDER, API_VERSION); + } + +}; + +MODULE_INIT(ModuleSQLutils); + diff --git a/src/modules/extra/m_sqlutils.h b/src/modules/extra/m_sqlutils.h index cdde51f67..92fbdf5c7 100644 --- a/src/modules/extra/m_sqlutils.h +++ b/src/modules/extra/m_sqlutils.h @@ -1 +1,143 @@ -/* +------------------------------------+
* | Inspire Internet Relay Chat Daemon |
* +------------------------------------+
*
* InspIRCd: (C) 2002-2007 InspIRCd Development Team
* See: http://www.inspircd.org/wiki/index.php/Credits
*
* This program is free but copyrighted software; see
* the file COPYING for details.
*
* ---------------------------------------------------
*/
#ifndef INSPIRCD_SQLUTILS
#define INSPIRCD_SQLUTILS
#include "modules.h"
#define SQLUTILAU "SQLutil AssociateUser"
#define SQLUTILAC "SQLutil AssociateChan"
#define SQLUTILUA "SQLutil UnAssociate"
#define SQLUTILGU "SQLutil GetAssocUser"
#define SQLUTILGC "SQLutil GetAssocChan"
#define SQLUTILSUCCESS "You shouldn't be reading this (success)"
/** Used to associate an SQL query with a user
*/
class AssociateUser : public Request
{
public:
/** Query ID
*/
unsigned long id;
/** User
*/
userrec* user;
AssociateUser(Module* s, Module* d, unsigned long i, userrec* u)
: Request(s, d, SQLUTILAU), id(i), user(u)
{
}
AssociateUser& S()
{
Send();
return *this;
}
};
/** Used to associate an SQL query with a channel
*/
class AssociateChan : public Request
{
public:
/** Query ID
*/
unsigned long id;
/** Channel
*/
chanrec* chan;
AssociateChan(Module* s, Module* d, unsigned long i, chanrec* u)
: Request(s, d, SQLUTILAC), id(i), chan(u)
{
}
AssociateChan& S()
{
Send();
return *this;
}
};
/** Unassociate a user or class from an SQL query
*/
class UnAssociate : public Request
{
public:
/** The query ID
*/
unsigned long id;
UnAssociate(Module* s, Module* d, unsigned long i)
: Request(s, d, SQLUTILUA), id(i)
{
}
UnAssociate& S()
{
Send();
return *this;
}
};
/** Get the user associated with an SQL query ID
*/
class GetAssocUser : public Request
{
public:
/** The query id
*/
unsigned long id;
/** The user
*/
userrec* user;
GetAssocUser(Module* s, Module* d, unsigned long i)
: Request(s, d, SQLUTILGU), id(i), user(NULL)
{
}
GetAssocUser& S()
{
Send();
return *this;
}
};
/** Get the channel associated with an SQL query ID
*/
class GetAssocChan : public Request
{
public:
/** The query id
*/
unsigned long id;
/** The channel
*/
chanrec* chan;
GetAssocChan(Module* s, Module* d, unsigned long i)
: Request(s, d, SQLUTILGC), id(i), chan(NULL)
{
}
GetAssocChan& S()
{
Send();
return *this;
}
};
#endif
\ No newline at end of file +/* +------------------------------------+ + * | Inspire Internet Relay Chat Daemon | + * +------------------------------------+ + * + * InspIRCd: (C) 2002-2007 InspIRCd Development Team + * See: http://www.inspircd.org/wiki/index.php/Credits + * + * This program is free but copyrighted software; see + * the file COPYING for details. + * + * --------------------------------------------------- + */ + +#ifndef INSPIRCD_SQLUTILS +#define INSPIRCD_SQLUTILS + +#include "modules.h" + +#define SQLUTILAU "SQLutil AssociateUser" +#define SQLUTILAC "SQLutil AssociateChan" +#define SQLUTILUA "SQLutil UnAssociate" +#define SQLUTILGU "SQLutil GetAssocUser" +#define SQLUTILGC "SQLutil GetAssocChan" +#define SQLUTILSUCCESS "You shouldn't be reading this (success)" + +/** Used to associate an SQL query with a user + */ +class AssociateUser : public Request +{ +public: + /** Query ID + */ + unsigned long id; + /** User + */ + userrec* user; + + AssociateUser(Module* s, Module* d, unsigned long i, userrec* u) + : Request(s, d, SQLUTILAU), id(i), user(u) + { + } + + AssociateUser& S() + { + Send(); + return *this; + } +}; + +/** Used to associate an SQL query with a channel + */ +class AssociateChan : public Request +{ +public: + /** Query ID + */ + unsigned long id; + /** Channel + */ + chanrec* chan; + + AssociateChan(Module* s, Module* d, unsigned long i, chanrec* u) + : Request(s, d, SQLUTILAC), id(i), chan(u) + { + } + + AssociateChan& S() + { + Send(); + return *this; + } +}; + +/** Unassociate a user or class from an SQL query + */ +class UnAssociate : public Request +{ +public: + /** The query ID + */ + unsigned long id; + + UnAssociate(Module* s, Module* d, unsigned long i) + : Request(s, d, SQLUTILUA), id(i) + { + } + + UnAssociate& S() + { + Send(); + return *this; + } +}; + +/** Get the user associated with an SQL query ID + */ +class GetAssocUser : public Request +{ +public: + /** The query id + */ + unsigned long id; + /** The user + */ + userrec* user; + + GetAssocUser(Module* s, Module* d, unsigned long i) + : Request(s, d, SQLUTILGU), id(i), user(NULL) + { + } + + GetAssocUser& S() + { + Send(); + return *this; + } +}; + +/** Get the channel associated with an SQL query ID + */ +class GetAssocChan : public Request +{ +public: + /** The query id + */ + unsigned long id; + /** The channel + */ + chanrec* chan; + + GetAssocChan(Module* s, Module* d, unsigned long i) + : Request(s, d, SQLUTILGC), id(i), chan(NULL) + { + } + + GetAssocChan& S() + { + Send(); + return *this; + } +}; + +#endif diff --git a/src/modules/extra/m_sqlv2.h b/src/modules/extra/m_sqlv2.h index decac4b57..c7f6edbb9 100644 --- a/src/modules/extra/m_sqlv2.h +++ b/src/modules/extra/m_sqlv2.h @@ -1 +1,605 @@ -/* +------------------------------------+
* | Inspire Internet Relay Chat Daemon |
* +------------------------------------+
*
* InspIRCd: (C) 2002-2007 InspIRCd Development Team
* See: http://www.inspircd.org/wiki/index.php/Credits
*
* This program is free but copyrighted software; see
* the file COPYING for details.
*
* ---------------------------------------------------
*/
#ifndef INSPIRCD_SQLAPI_2
#define INSPIRCD_SQLAPI_2
#include <string>
#include <deque>
#include <map>
#include "modules.h"
/** SQLreq define.
* This is the voodoo magic which lets us pass multiple
* parameters to the SQLrequest constructor... voodoo...
*/
#define SQLreq(a, b, c, d, e...) SQLrequest(a, b, c, (SQLquery(d), ##e))
/** Identifiers used to identify Request types
*/
#define SQLREQID "SQLv2 Request"
#define SQLRESID "SQLv2 Result"
#define SQLSUCCESS "You shouldn't be reading this (success)"
/** Defines the error types which SQLerror may be set to
*/
enum SQLerrorNum { NO_ERROR, BAD_DBID, BAD_CONN, QSEND_FAIL, QREPLY_FAIL };
/** A list of format parameters for an SQLquery object.
*/
typedef std::deque<std::string> ParamL;
/** The base class of SQL exceptions
*/
class SQLexception : public ModuleException
{
public:
SQLexception(const std::string &reason) : ModuleException(reason)
{
}
SQLexception() : ModuleException("SQLv2: Undefined exception")
{
}
};
/** An exception thrown when a bad column or row name or id is requested
*/
class SQLbadColName : public SQLexception
{
public:
SQLbadColName() : SQLexception("SQLv2: Bad column name")
{
}
};
/** SQLerror holds the error state of any SQLrequest or SQLresult.
* The error string varies from database software to database software
* and should be used to display informational error messages to users.
*/
class SQLerror : public classbase
{
/** The error id
*/
SQLerrorNum id;
/** The error string
*/
std::string str;
public:
/** Initialize an SQLerror
* @param i The error ID to set
* @param s The (optional) error string to set
*/
SQLerror(SQLerrorNum i = NO_ERROR, const std::string &s = "")
: id(i), str(s)
{
}
/** Return the ID of the error
*/
SQLerrorNum Id()
{
return id;
}
/** Set the ID of an error
* @param i The new error ID to set
* @return the ID which was set
*/
SQLerrorNum Id(SQLerrorNum i)
{
id = i;
return id;
}
/** Set the error string for an error
* @param s The new error string to set
*/
void Str(const std::string &s)
{
str = s;
}
/** Return the error string for an error
*/
const char* Str()
{
if(str.length())
return str.c_str();
switch(id)
{
case NO_ERROR:
return "No error";
case BAD_DBID:
return "Invalid database ID";
case BAD_CONN:
return "Invalid connection";
case QSEND_FAIL:
return "Sending query failed";
case QREPLY_FAIL:
return "Getting query result failed";
default:
return "Unknown error";
}
}
};
/** SQLquery provides a way to represent a query string, and its parameters in a type-safe way.
* C++ has no native type-safe way of having a variable number of arguments to a function,
* the workaround for this isn't easy to describe simply, but in a nutshell what's really
* happening when - from the above example - you do this:
*
* SQLrequest foo = SQLreq(this, target, "databaseid", "SELECT (foo, bar) FROM rawr WHERE foo = '?' AND bar = ?", "Hello", "42");
*
* what's actually happening is functionally this:
*
* SQLrequest foo = SQLreq(this, target, "databaseid", query("SELECT (foo, bar) FROM rawr WHERE foo = '?' AND bar = ?").addparam("Hello").addparam("42"));
*
* with 'query()' returning a reference to an object with a 'addparam()' member function which
* in turn returns a reference to that object. There are actually four ways you can create a
* SQLrequest..all have their disadvantages and advantages. In the real implementations the
* 'query()' function is replaced by the constructor of another class 'SQLquery' which holds
* the query string and a ParamL (std::deque<std::string>) of query parameters.
* This is essentially the same as the above example except 'addparam()' is replaced by operator,(). The full syntax for this method is:
*
* SQLrequest foo = SQLrequest(this, target, "databaseid", (SQLquery("SELECT.. ?"), parameter, parameter));
*/
class SQLquery : public classbase
{
public:
/** The query 'format string'
*/
std::string q;
/** The query parameter list
* There should be one parameter for every ? character
* within the format string shown above.
*/
ParamL p;
/** Initialize an SQLquery with a given format string only
*/
SQLquery(const std::string &query)
: q(query)
{
}
/** Initialize an SQLquery with a format string and parameters.
* If you provide parameters, you must initialize the list yourself
* if you choose to do it via this method, using std::deque::push_back().
*/
SQLquery(const std::string &query, const ParamL ¶ms)
: q(query), p(params)
{
}
/** An overloaded operator for pushing parameters onto the parameter list
*/
template<typename T> SQLquery& operator,(const T &foo)
{
p.push_back(ConvToStr(foo));
return *this;
}
/** An overloaded operator for pushing parameters onto the parameter list.
* This has higher precedence than 'operator,' and can save on parenthesis.
*/
template<typename T> SQLquery& operator%(const T &foo)
{
p.push_back(ConvToStr(foo));
return *this;
}
};
/** SQLrequest is sent to the SQL API to command it to run a query and return the result.
* You must instantiate this object with a valid SQLquery object and its parameters, then
* send it using its Send() method to the module providing the 'SQL' feature. To find this
* module, use Server::FindFeature().
*/
class SQLrequest : public Request
{
public:
/** The fully parsed and expanded query string
* This is initialized from the SQLquery parameter given in the constructor.
*/
SQLquery query;
/** The database ID to apply the request to
*/
std::string dbid;
/** True if this is a priority query.
* Priority queries may 'queue jump' in the request queue.
*/
bool pri;
/** The query ID, assigned by the SQL api.
* After your request is processed, this will
* be initialized for you by the API to a valid request ID,
* except in the case of an error.
*/
unsigned long id;
/** If an error occured, error.id will be any other value than NO_ERROR.
*/
SQLerror error;
/** Initialize an SQLrequest.
* For example:
*
* SQLrequest req = SQLreq(MyMod, SQLModule, dbid, "INSERT INTO ircd_log_actors VALUES('','?')", nick);
*
* @param s A pointer to the sending module, where the result should be routed
* @param d A pointer to the receiving module, identified as implementing the 'SQL' feature
* @param databaseid The database ID to perform the query on. This must match a valid
* database ID from the configuration of the SQL module.
* @param q A properly initialized SQLquery object.
*/
SQLrequest(Module* s, Module* d, const std::string &databaseid, const SQLquery &q)
: Request(s, d, SQLREQID), query(q), dbid(databaseid), pri(false), id(0)
{
}
/** Set the priority of a request.
*/
void Priority(bool p = true)
{
pri = p;
}
/** Set the source of a request. You should not need to use this method.
*/
void SetSource(Module* mod)
{
source = mod;
}
};
/**
* This class contains a field's data plus a way to determine if the field
* is NULL or not without having to mess around with NULL pointers.
*/
class SQLfield
{
public:
/**
* The data itself
*/
std::string d;
/**
* If the field was null
*/
bool null;
/** Initialize an SQLfield
*/
SQLfield(const std::string &data = "", bool n = false)
: d(data), null(n)
{
}
};
/** A list of items which make up a row of a result or table (tuple)
* This does not include field names.
*/
typedef std::vector<SQLfield> SQLfieldList;
/** A list of items which make up a row of a result or table (tuple)
* This also includes the field names.
*/
typedef std::map<std::string, SQLfield> SQLfieldMap;
/** SQLresult is a reply to a previous query.
* If you send a query to the SQL api, the response will arrive at your
* OnRequest method of your module at some later time, depending on the
* congestion of the SQL server and complexity of the query. The ID of
* this result will match the ID assigned to your original request.
* SQLresult contains its own internal cursor (row counter) which is
* incremented with each method call which retrieves a single row.
*/
class SQLresult : public Request
{
public:
/** The original query string passed initially to the SQL API
*/
std::string query;
/** The database ID the query was executed on
*/
std::string dbid;
/**
* The error (if any) which occured.
* If an error occured the value of error.id will be any
* other value than NO_ERROR.
*/
SQLerror error;
/**
* This will match query ID you were given when sending
* the request at an earlier time.
*/
unsigned long id;
/** Used by the SQL API to instantiate an SQLrequest
*/
SQLresult(Module* s, Module* d, unsigned long i)
: Request(s, d, SQLRESID), id(i)
{
}
/**
* Return the number of rows in the result
* Note that if you have perfomed an INSERT
* or UPDATE query or other query which will
* not return rows, this will return the
* number of affected rows, and SQLresult::Cols()
* will contain 0. In this case you SHOULD NEVER
* access any of the result set rows, as there arent any!
* @returns Number of rows in the result set.
*/
virtual int Rows() = 0;
/**
* Return the number of columns in the result.
* If you performed an UPDATE or INSERT which
* does not return a dataset, this value will
* be 0.
* @returns Number of columns in the result set.
*/
virtual int Cols() = 0;
/**
* Get a string name of the column by an index number
* @param column The id number of a column
* @returns The column name associated with the given ID
*/
virtual std::string ColName(int column) = 0;
/**
* Get an index number for a column from a string name.
* An exception of type SQLbadColName will be thrown if
* the name given is invalid.
* @param column The column name to get the ID of
* @returns The ID number of the column provided
*/
virtual int ColNum(const std::string &column) = 0;
/**
* Get a string value in a given row and column
* This does not effect the internal cursor.
* @returns The value stored at [row,column] in the table
*/
virtual SQLfield GetValue(int row, int column) = 0;
/**
* Return a list of values in a row, this should
* increment an internal counter so you can repeatedly
* call it until it returns an empty vector.
* This returns a reference to an internal object,
* the same object is used for all calls to this function
* and therefore the return value is only valid until
* you call this function again. It is also invalid if
* the SQLresult object is destroyed.
* The internal cursor (row counter) is incremented by one.
* @returns A reference to the current row's SQLfieldList
*/
virtual SQLfieldList& GetRow() = 0;
/**
* As above, but return a map indexed by key name.
* The internal cursor (row counter) is incremented by one.
* @returns A reference to the current row's SQLfieldMap
*/
virtual SQLfieldMap& GetRowMap() = 0;
/**
* Like GetRow(), but returns a pointer to a dynamically
* allocated object which must be explicitly freed. For
* portability reasons this must be freed with SQLresult::Free()
* The internal cursor (row counter) is incremented by one.
* @returns A newly-allocated SQLfieldList
*/
virtual SQLfieldList* GetRowPtr() = 0;
/**
* As above, but return a map indexed by key name
* The internal cursor (row counter) is incremented by one.
* @returns A newly-allocated SQLfieldMap
*/
virtual SQLfieldMap* GetRowMapPtr() = 0;
/**
* Overloaded function for freeing the lists and maps
* returned by GetRowPtr or GetRowMapPtr.
* @param fm The SQLfieldMap to free
*/
virtual void Free(SQLfieldMap* fm) = 0;
/**
* Overloaded function for freeing the lists and maps
* returned by GetRowPtr or GetRowMapPtr.
* @param fl The SQLfieldList to free
*/
virtual void Free(SQLfieldList* fl) = 0;
};
/** SQLHost represents a <database> config line and is useful
* for storing in a map and iterating on rehash to see which
* <database> tags was added/removed/unchanged.
*/
class SQLhost
{
public:
std::string id; /* Database handle id */
std::string host; /* Database server hostname */
std::string ip; /* resolved IP, needed for at least pgsql.so */
unsigned int port; /* Database server port */
std::string name; /* Database name */
std::string user; /* Database username */
std::string pass; /* Database password */
bool ssl; /* If we should require SSL */
SQLhost()
{
}
SQLhost(const std::string& i, const std::string& h, unsigned int p, const std::string& n, const std::string& u, const std::string& pa, bool s)
: id(i), host(h), port(p), name(n), user(u), pass(pa), ssl(s)
{
}
/** Overload this to return a correct Data source Name (DSN) for
* the current SQL module.
*/
std::string GetDSN();
};
/** Overload operator== for two SQLhost objects for easy comparison.
*/
bool operator== (const SQLhost& l, const SQLhost& r)
{
return (l.id == r.id && l.host == r.host && l.port == r.port && l.name == r.name && l.user == l.user && l.pass == r.pass && l.ssl == r.ssl);
}
/** QueryQueue, a queue of queries waiting to be executed.
* This maintains two queues internally, one for 'priority'
* queries and one for less important ones. Each queue has
* new queries appended to it and ones to execute are popped
* off the front. This keeps them flowing round nicely and no
* query should ever get 'stuck' for too long. If there are
* queries in the priority queue they will be executed first,
* 'unimportant' queries will only be executed when the
* priority queue is empty.
*
* We store lists of SQLrequest's here, by value as we want to avoid storing
* any data allocated inside the client module (in case that module is unloaded
* while the query is in progress).
*
* Because we want to work on the current SQLrequest in-situ, we need a way
* of accessing the request we are currently processing, QueryQueue::front(),
* but that call needs to always return the same request until that request
* is removed from the queue, this is what the 'which' variable is. New queries are
* always added to the back of one of the two queues, but if when front()
* is first called then the priority queue is empty then front() will return
* a query from the normal queue, but if a query is then added to the priority
* queue then front() must continue to return the front of the *normal* queue
* until pop() is called.
*/
class QueryQueue : public classbase
{
private:
typedef std::deque<SQLrequest> ReqDeque;
ReqDeque priority; /* The priority queue */
ReqDeque normal; /* The 'normal' queue */
enum { PRI, NOR, NON } which; /* Which queue the currently active element is at the front of */
public:
QueryQueue()
: which(NON)
{
}
void push(const SQLrequest &q)
{
if(q.pri)
priority.push_back(q);
else
normal.push_back(q);
}
void pop()
{
if((which == PRI) && priority.size())
{
priority.pop_front();
}
else if((which == NOR) && normal.size())
{
normal.pop_front();
}
/* Reset this */
which = NON;
/* Silently do nothing if there was no element to pop() */
}
SQLrequest& front()
{
switch(which)
{
case PRI:
return priority.front();
case NOR:
return normal.front();
default:
if(priority.size())
{
which = PRI;
return priority.front();
}
if(normal.size())
{
which = NOR;
return normal.front();
}
/* This will probably result in a segfault,
* but the caller should have checked totalsize()
* first so..meh - moron :p
*/
return priority.front();
}
}
std::pair<int, int> size()
{
return std::make_pair(priority.size(), normal.size());
}
int totalsize()
{
return priority.size() + normal.size();
}
void PurgeModule(Module* mod)
{
DoPurgeModule(mod, priority);
DoPurgeModule(mod, normal);
}
private:
void DoPurgeModule(Module* mod, ReqDeque& q)
{
for(ReqDeque::iterator iter = q.begin(); iter != q.end(); iter++)
{
if(iter->GetSource() == mod)
{
if(iter->id == front().id)
{
/* It's the currently active query.. :x */
iter->SetSource(NULL);
}
else
{
/* It hasn't been executed yet..just remove it */
iter = q.erase(iter);
}
}
}
}
};
#endif
\ No newline at end of file +/* +------------------------------------+ + * | Inspire Internet Relay Chat Daemon | + * +------------------------------------+ + * + * InspIRCd: (C) 2002-2007 InspIRCd Development Team + * See: http://www.inspircd.org/wiki/index.php/Credits + * + * This program is free but copyrighted software; see + * the file COPYING for details. + * + * --------------------------------------------------- + */ + +#ifndef INSPIRCD_SQLAPI_2 +#define INSPIRCD_SQLAPI_2 + +#include <string> +#include <deque> +#include <map> +#include "modules.h" + +/** SQLreq define. + * This is the voodoo magic which lets us pass multiple + * parameters to the SQLrequest constructor... voodoo... + */ +#define SQLreq(a, b, c, d, e...) SQLrequest(a, b, c, (SQLquery(d), ##e)) + +/** Identifiers used to identify Request types + */ +#define SQLREQID "SQLv2 Request" +#define SQLRESID "SQLv2 Result" +#define SQLSUCCESS "You shouldn't be reading this (success)" + +/** Defines the error types which SQLerror may be set to + */ +enum SQLerrorNum { NO_ERROR, BAD_DBID, BAD_CONN, QSEND_FAIL, QREPLY_FAIL }; + +/** A list of format parameters for an SQLquery object. + */ +typedef std::deque<std::string> ParamL; + +/** The base class of SQL exceptions + */ +class SQLexception : public ModuleException +{ + public: + SQLexception(const std::string &reason) : ModuleException(reason) + { + } + + SQLexception() : ModuleException("SQLv2: Undefined exception") + { + } +}; + +/** An exception thrown when a bad column or row name or id is requested + */ +class SQLbadColName : public SQLexception +{ +public: + SQLbadColName() : SQLexception("SQLv2: Bad column name") + { + } +}; + +/** SQLerror holds the error state of any SQLrequest or SQLresult. + * The error string varies from database software to database software + * and should be used to display informational error messages to users. + */ +class SQLerror : public classbase +{ + /** The error id + */ + SQLerrorNum id; + /** The error string + */ + std::string str; +public: + /** Initialize an SQLerror + * @param i The error ID to set + * @param s The (optional) error string to set + */ + SQLerror(SQLerrorNum i = NO_ERROR, const std::string &s = "") + : id(i), str(s) + { + } + + /** Return the ID of the error + */ + SQLerrorNum Id() + { + return id; + } + + /** Set the ID of an error + * @param i The new error ID to set + * @return the ID which was set + */ + SQLerrorNum Id(SQLerrorNum i) + { + id = i; + return id; + } + + /** Set the error string for an error + * @param s The new error string to set + */ + void Str(const std::string &s) + { + str = s; + } + + /** Return the error string for an error + */ + const char* Str() + { + if(str.length()) + return str.c_str(); + + switch(id) + { + case NO_ERROR: + return "No error"; + case BAD_DBID: + return "Invalid database ID"; + case BAD_CONN: + return "Invalid connection"; + case QSEND_FAIL: + return "Sending query failed"; + case QREPLY_FAIL: + return "Getting query result failed"; + default: + return "Unknown error"; + } + } +}; + +/** SQLquery provides a way to represent a query string, and its parameters in a type-safe way. + * C++ has no native type-safe way of having a variable number of arguments to a function, + * the workaround for this isn't easy to describe simply, but in a nutshell what's really + * happening when - from the above example - you do this: + * + * SQLrequest foo = SQLreq(this, target, "databaseid", "SELECT (foo, bar) FROM rawr WHERE foo = '?' AND bar = ?", "Hello", "42"); + * + * what's actually happening is functionally this: + * + * SQLrequest foo = SQLreq(this, target, "databaseid", query("SELECT (foo, bar) FROM rawr WHERE foo = '?' AND bar = ?").addparam("Hello").addparam("42")); + * + * with 'query()' returning a reference to an object with a 'addparam()' member function which + * in turn returns a reference to that object. There are actually four ways you can create a + * SQLrequest..all have their disadvantages and advantages. In the real implementations the + * 'query()' function is replaced by the constructor of another class 'SQLquery' which holds + * the query string and a ParamL (std::deque<std::string>) of query parameters. + * This is essentially the same as the above example except 'addparam()' is replaced by operator,(). The full syntax for this method is: + * + * SQLrequest foo = SQLrequest(this, target, "databaseid", (SQLquery("SELECT.. ?"), parameter, parameter)); + */ +class SQLquery : public classbase +{ +public: + /** The query 'format string' + */ + std::string q; + /** The query parameter list + * There should be one parameter for every ? character + * within the format string shown above. + */ + ParamL p; + + /** Initialize an SQLquery with a given format string only + */ + SQLquery(const std::string &query) + : q(query) + { + } + + /** Initialize an SQLquery with a format string and parameters. + * If you provide parameters, you must initialize the list yourself + * if you choose to do it via this method, using std::deque::push_back(). + */ + SQLquery(const std::string &query, const ParamL ¶ms) + : q(query), p(params) + { + } + + /** An overloaded operator for pushing parameters onto the parameter list + */ + template<typename T> SQLquery& operator,(const T &foo) + { + p.push_back(ConvToStr(foo)); + return *this; + } + + /** An overloaded operator for pushing parameters onto the parameter list. + * This has higher precedence than 'operator,' and can save on parenthesis. + */ + template<typename T> SQLquery& operator%(const T &foo) + { + p.push_back(ConvToStr(foo)); + return *this; + } +}; + +/** SQLrequest is sent to the SQL API to command it to run a query and return the result. + * You must instantiate this object with a valid SQLquery object and its parameters, then + * send it using its Send() method to the module providing the 'SQL' feature. To find this + * module, use Server::FindFeature(). + */ +class SQLrequest : public Request +{ +public: + /** The fully parsed and expanded query string + * This is initialized from the SQLquery parameter given in the constructor. + */ + SQLquery query; + /** The database ID to apply the request to + */ + std::string dbid; + /** True if this is a priority query. + * Priority queries may 'queue jump' in the request queue. + */ + bool pri; + /** The query ID, assigned by the SQL api. + * After your request is processed, this will + * be initialized for you by the API to a valid request ID, + * except in the case of an error. + */ + unsigned long id; + /** If an error occured, error.id will be any other value than NO_ERROR. + */ + SQLerror error; + + /** Initialize an SQLrequest. + * For example: + * + * SQLrequest req = SQLreq(MyMod, SQLModule, dbid, "INSERT INTO ircd_log_actors VALUES('','?')", nick); + * + * @param s A pointer to the sending module, where the result should be routed + * @param d A pointer to the receiving module, identified as implementing the 'SQL' feature + * @param databaseid The database ID to perform the query on. This must match a valid + * database ID from the configuration of the SQL module. + * @param q A properly initialized SQLquery object. + */ + SQLrequest(Module* s, Module* d, const std::string &databaseid, const SQLquery &q) + : Request(s, d, SQLREQID), query(q), dbid(databaseid), pri(false), id(0) + { + } + + /** Set the priority of a request. + */ + void Priority(bool p = true) + { + pri = p; + } + + /** Set the source of a request. You should not need to use this method. + */ + void SetSource(Module* mod) + { + source = mod; + } +}; + +/** + * This class contains a field's data plus a way to determine if the field + * is NULL or not without having to mess around with NULL pointers. + */ +class SQLfield +{ +public: + /** + * The data itself + */ + std::string d; + + /** + * If the field was null + */ + bool null; + + /** Initialize an SQLfield + */ + SQLfield(const std::string &data = "", bool n = false) + : d(data), null(n) + { + + } +}; + +/** A list of items which make up a row of a result or table (tuple) + * This does not include field names. + */ +typedef std::vector<SQLfield> SQLfieldList; +/** A list of items which make up a row of a result or table (tuple) + * This also includes the field names. + */ +typedef std::map<std::string, SQLfield> SQLfieldMap; + +/** SQLresult is a reply to a previous query. + * If you send a query to the SQL api, the response will arrive at your + * OnRequest method of your module at some later time, depending on the + * congestion of the SQL server and complexity of the query. The ID of + * this result will match the ID assigned to your original request. + * SQLresult contains its own internal cursor (row counter) which is + * incremented with each method call which retrieves a single row. + */ +class SQLresult : public Request +{ +public: + /** The original query string passed initially to the SQL API + */ + std::string query; + /** The database ID the query was executed on + */ + std::string dbid; + /** + * The error (if any) which occured. + * If an error occured the value of error.id will be any + * other value than NO_ERROR. + */ + SQLerror error; + /** + * This will match query ID you were given when sending + * the request at an earlier time. + */ + unsigned long id; + + /** Used by the SQL API to instantiate an SQLrequest + */ + SQLresult(Module* s, Module* d, unsigned long i) + : Request(s, d, SQLRESID), id(i) + { + } + + /** + * Return the number of rows in the result + * Note that if you have perfomed an INSERT + * or UPDATE query or other query which will + * not return rows, this will return the + * number of affected rows, and SQLresult::Cols() + * will contain 0. In this case you SHOULD NEVER + * access any of the result set rows, as there arent any! + * @returns Number of rows in the result set. + */ + virtual int Rows() = 0; + + /** + * Return the number of columns in the result. + * If you performed an UPDATE or INSERT which + * does not return a dataset, this value will + * be 0. + * @returns Number of columns in the result set. + */ + virtual int Cols() = 0; + + /** + * Get a string name of the column by an index number + * @param column The id number of a column + * @returns The column name associated with the given ID + */ + virtual std::string ColName(int column) = 0; + + /** + * Get an index number for a column from a string name. + * An exception of type SQLbadColName will be thrown if + * the name given is invalid. + * @param column The column name to get the ID of + * @returns The ID number of the column provided + */ + virtual int ColNum(const std::string &column) = 0; + + /** + * Get a string value in a given row and column + * This does not effect the internal cursor. + * @returns The value stored at [row,column] in the table + */ + virtual SQLfield GetValue(int row, int column) = 0; + + /** + * Return a list of values in a row, this should + * increment an internal counter so you can repeatedly + * call it until it returns an empty vector. + * This returns a reference to an internal object, + * the same object is used for all calls to this function + * and therefore the return value is only valid until + * you call this function again. It is also invalid if + * the SQLresult object is destroyed. + * The internal cursor (row counter) is incremented by one. + * @returns A reference to the current row's SQLfieldList + */ + virtual SQLfieldList& GetRow() = 0; + + /** + * As above, but return a map indexed by key name. + * The internal cursor (row counter) is incremented by one. + * @returns A reference to the current row's SQLfieldMap + */ + virtual SQLfieldMap& GetRowMap() = 0; + + /** + * Like GetRow(), but returns a pointer to a dynamically + * allocated object which must be explicitly freed. For + * portability reasons this must be freed with SQLresult::Free() + * The internal cursor (row counter) is incremented by one. + * @returns A newly-allocated SQLfieldList + */ + virtual SQLfieldList* GetRowPtr() = 0; + + /** + * As above, but return a map indexed by key name + * The internal cursor (row counter) is incremented by one. + * @returns A newly-allocated SQLfieldMap + */ + virtual SQLfieldMap* GetRowMapPtr() = 0; + + /** + * Overloaded function for freeing the lists and maps + * returned by GetRowPtr or GetRowMapPtr. + * @param fm The SQLfieldMap to free + */ + virtual void Free(SQLfieldMap* fm) = 0; + + /** + * Overloaded function for freeing the lists and maps + * returned by GetRowPtr or GetRowMapPtr. + * @param fl The SQLfieldList to free + */ + virtual void Free(SQLfieldList* fl) = 0; +}; + + +/** SQLHost represents a <database> config line and is useful + * for storing in a map and iterating on rehash to see which + * <database> tags was added/removed/unchanged. + */ +class SQLhost +{ + public: + std::string id; /* Database handle id */ + std::string host; /* Database server hostname */ + std::string ip; /* resolved IP, needed for at least pgsql.so */ + unsigned int port; /* Database server port */ + std::string name; /* Database name */ + std::string user; /* Database username */ + std::string pass; /* Database password */ + bool ssl; /* If we should require SSL */ + + SQLhost() + { + } + + SQLhost(const std::string& i, const std::string& h, unsigned int p, const std::string& n, const std::string& u, const std::string& pa, bool s) + : id(i), host(h), port(p), name(n), user(u), pass(pa), ssl(s) + { + } + + /** Overload this to return a correct Data source Name (DSN) for + * the current SQL module. + */ + std::string GetDSN(); +}; + +/** Overload operator== for two SQLhost objects for easy comparison. + */ +bool operator== (const SQLhost& l, const SQLhost& r) +{ + return (l.id == r.id && l.host == r.host && l.port == r.port && l.name == r.name && l.user == l.user && l.pass == r.pass && l.ssl == r.ssl); +} + + +/** QueryQueue, a queue of queries waiting to be executed. + * This maintains two queues internally, one for 'priority' + * queries and one for less important ones. Each queue has + * new queries appended to it and ones to execute are popped + * off the front. This keeps them flowing round nicely and no + * query should ever get 'stuck' for too long. If there are + * queries in the priority queue they will be executed first, + * 'unimportant' queries will only be executed when the + * priority queue is empty. + * + * We store lists of SQLrequest's here, by value as we want to avoid storing + * any data allocated inside the client module (in case that module is unloaded + * while the query is in progress). + * + * Because we want to work on the current SQLrequest in-situ, we need a way + * of accessing the request we are currently processing, QueryQueue::front(), + * but that call needs to always return the same request until that request + * is removed from the queue, this is what the 'which' variable is. New queries are + * always added to the back of one of the two queues, but if when front() + * is first called then the priority queue is empty then front() will return + * a query from the normal queue, but if a query is then added to the priority + * queue then front() must continue to return the front of the *normal* queue + * until pop() is called. + */ + +class QueryQueue : public classbase +{ +private: + typedef std::deque<SQLrequest> ReqDeque; + + ReqDeque priority; /* The priority queue */ + ReqDeque normal; /* The 'normal' queue */ + enum { PRI, NOR, NON } which; /* Which queue the currently active element is at the front of */ + +public: + QueryQueue() + : which(NON) + { + } + + void push(const SQLrequest &q) + { + if(q.pri) + priority.push_back(q); + else + normal.push_back(q); + } + + void pop() + { + if((which == PRI) && priority.size()) + { + priority.pop_front(); + } + else if((which == NOR) && normal.size()) + { + normal.pop_front(); + } + + /* Reset this */ + which = NON; + + /* Silently do nothing if there was no element to pop() */ + } + + SQLrequest& front() + { + switch(which) + { + case PRI: + return priority.front(); + case NOR: + return normal.front(); + default: + if(priority.size()) + { + which = PRI; + return priority.front(); + } + + if(normal.size()) + { + which = NOR; + return normal.front(); + } + + /* This will probably result in a segfault, + * but the caller should have checked totalsize() + * first so..meh - moron :p + */ + + return priority.front(); + } + } + + std::pair<int, int> size() + { + return std::make_pair(priority.size(), normal.size()); + } + + int totalsize() + { + return priority.size() + normal.size(); + } + + void PurgeModule(Module* mod) + { + DoPurgeModule(mod, priority); + DoPurgeModule(mod, normal); + } + +private: + void DoPurgeModule(Module* mod, ReqDeque& q) + { + for(ReqDeque::iterator iter = q.begin(); iter != q.end(); iter++) + { + if(iter->GetSource() == mod) + { + if(iter->id == front().id) + { + /* It's the currently active query.. :x */ + iter->SetSource(NULL); + } + else + { + /* It hasn't been executed yet..just remove it */ + iter = q.erase(iter); + } + } + } + } +}; + + +#endif diff --git a/src/modules/extra/m_ssl_gnutls.cpp b/src/modules/extra/m_ssl_gnutls.cpp index 037d2cf72..fd8b12d32 100644 --- a/src/modules/extra/m_ssl_gnutls.cpp +++ b/src/modules/extra/m_ssl_gnutls.cpp @@ -1 +1,843 @@ -/* +------------------------------------+
* | Inspire Internet Relay Chat Daemon |
* +------------------------------------+
*
* InspIRCd: (C) 2002-2007 InspIRCd Development Team
* See: http://www.inspircd.org/wiki/index.php/Credits
*
* This program is free but copyrighted software; see
* the file COPYING for details.
*
* ---------------------------------------------------
*/
#include "inspircd.h"
#include <gnutls/gnutls.h>
#include <gnutls/x509.h>
#include "inspircd_config.h"
#include "configreader.h"
#include "users.h"
#include "channels.h"
#include "modules.h"
#include "socket.h"
#include "hashcomp.h"
#include "transport.h"
#ifdef WINDOWS
#pragma comment(lib, "libgnutls-13.lib")
#undef MAX_DESCRIPTORS
#define MAX_DESCRIPTORS 10000
#endif
/* $ModDesc: Provides SSL support for clients */
/* $CompileFlags: exec("libgnutls-config --cflags") */
/* $LinkerFlags: rpath("libgnutls-config --libs") exec("libgnutls-config --libs") */
/* $ModDep: transport.h */
enum issl_status { ISSL_NONE, ISSL_HANDSHAKING_READ, ISSL_HANDSHAKING_WRITE, ISSL_HANDSHAKEN, ISSL_CLOSING, ISSL_CLOSED };
bool isin(int port, const std::vector<int> &portlist)
{
for(unsigned int i = 0; i < portlist.size(); i++)
if(portlist[i] == port)
return true;
return false;
}
/** Represents an SSL user's extra data
*/
class issl_session : public classbase
{
public:
gnutls_session_t sess;
issl_status status;
std::string outbuf;
int inbufoffset;
char* inbuf;
int fd;
};
class ModuleSSLGnuTLS : public Module
{
ConfigReader* Conf;
char* dummy;
std::vector<int> listenports;
int inbufsize;
issl_session sessions[MAX_DESCRIPTORS];
gnutls_certificate_credentials x509_cred;
gnutls_dh_params dh_params;
std::string keyfile;
std::string certfile;
std::string cafile;
std::string crlfile;
std::string sslports;
int dh_bits;
int clientactive;
public:
ModuleSSLGnuTLS(InspIRCd* Me)
: Module(Me)
{
ServerInstance->PublishInterface("InspSocketHook", this);
// Not rehashable...because I cba to reduce all the sizes of existing buffers.
inbufsize = ServerInstance->Config->NetBufferSize;
gnutls_global_init(); // This must be called once in the program
if(gnutls_certificate_allocate_credentials(&x509_cred) != 0)
ServerInstance->Log(DEFAULT, "m_ssl_gnutls.so: Failed to allocate certificate credentials");
// Guessing return meaning
if(gnutls_dh_params_init(&dh_params) < 0)
ServerInstance->Log(DEFAULT, "m_ssl_gnutls.so: Failed to initialise DH parameters");
// Needs the flag as it ignores a plain /rehash
OnRehash(NULL,"ssl");
// Void return, guess we assume success
gnutls_certificate_set_dh_params(x509_cred, dh_params);
}
virtual void OnRehash(userrec* user, const std::string ¶m)
{
if(param != "ssl")
return;
Conf = new ConfigReader(ServerInstance);
for(unsigned int i = 0; i < listenports.size(); i++)
{
ServerInstance->Config->DelIOHook(listenports[i]);
}
listenports.clear();
clientactive = 0;
sslports.clear();
for(int i = 0; i < Conf->Enumerate("bind"); i++)
{
// For each <bind> tag
std::string x = Conf->ReadValue("bind", "type", i);
if(((x.empty()) || (x == "clients")) && (Conf->ReadValue("bind", "ssl", i) == "gnutls"))
{
// Get the port we're meant to be listening on with SSL
std::string port = Conf->ReadValue("bind", "port", i);
irc::portparser portrange(port, false);
long portno = -1;
while ((portno = portrange.GetToken()))
{
clientactive++;
try
{
if (ServerInstance->Config->AddIOHook(portno, this))
{
listenports.push_back(portno);
for (size_t i = 0; i < ServerInstance->Config->ports.size(); i++)
if (ServerInstance->Config->ports[i]->GetPort() == portno)
ServerInstance->Config->ports[i]->SetDescription("ssl");
ServerInstance->Log(DEFAULT, "m_ssl_gnutls.so: Enabling SSL for port %d", portno);
sslports.append("*:").append(ConvToStr(portno)).append(";");
}
else
{
ServerInstance->Log(DEFAULT, "m_ssl_gnutls.so: FAILED to enable SSL on port %d, maybe you have another ssl or similar module loaded?", portno);
}
}
catch (ModuleException &e)
{
ServerInstance->Log(DEFAULT, "m_ssl_gnutls.so: FAILED to enable SSL on port %d: %s. Maybe it's already hooked by the same port on a different IP, or you have an other SSL or similar module loaded?", portno, e.GetReason());
}
}
}
}
std::string confdir(ServerInstance->ConfigFileName);
// +1 so we the path ends with a /
confdir = confdir.substr(0, confdir.find_last_of('/') + 1);
cafile = Conf->ReadValue("gnutls", "cafile", 0);
crlfile = Conf->ReadValue("gnutls", "crlfile", 0);
certfile = Conf->ReadValue("gnutls", "certfile", 0);
keyfile = Conf->ReadValue("gnutls", "keyfile", 0);
dh_bits = Conf->ReadInteger("gnutls", "dhbits", 0, false);
// Set all the default values needed.
if (cafile.empty())
cafile = "ca.pem";
if (crlfile.empty())
crlfile = "crl.pem";
if (certfile.empty())
certfile = "cert.pem";
if (keyfile.empty())
keyfile = "key.pem";
if((dh_bits != 768) && (dh_bits != 1024) && (dh_bits != 2048) && (dh_bits != 3072) && (dh_bits != 4096))
dh_bits = 1024;
// Prepend relative paths with the path to the config directory.
if(cafile[0] != '/')
cafile = confdir + cafile;
if(crlfile[0] != '/')
crlfile = confdir + crlfile;
if(certfile[0] != '/')
certfile = confdir + certfile;
if(keyfile[0] != '/')
keyfile = confdir + keyfile;
int ret;
if((ret =gnutls_certificate_set_x509_trust_file(x509_cred, cafile.c_str(), GNUTLS_X509_FMT_PEM)) < 0)
ServerInstance->Log(DEFAULT, "m_ssl_gnutls.so: Failed to set X.509 trust file '%s': %s", cafile.c_str(), gnutls_strerror(ret));
if((ret = gnutls_certificate_set_x509_crl_file (x509_cred, crlfile.c_str(), GNUTLS_X509_FMT_PEM)) < 0)
ServerInstance->Log(DEFAULT, "m_ssl_gnutls.so: Failed to set X.509 CRL file '%s': %s", crlfile.c_str(), gnutls_strerror(ret));
if((ret = gnutls_certificate_set_x509_key_file (x509_cred, certfile.c_str(), keyfile.c_str(), GNUTLS_X509_FMT_PEM)) < 0)
{
// If this fails, no SSL port will work. At all. So, do the smart thing - throw a ModuleException
throw ModuleException("Unable to load GnuTLS server certificate: " + std::string(gnutls_strerror(ret)));
}
// This may be on a large (once a day or week) timer eventually.
GenerateDHParams();
DELETE(Conf);
}
void GenerateDHParams()
{
// Generate Diffie Hellman parameters - for use with DHE
// kx algorithms. These should be discarded and regenerated
// once a day, once a week or once a month. Depending on the
// security requirements.
int ret;
if((ret = gnutls_dh_params_generate2(dh_params, dh_bits)) < 0)
ServerInstance->Log(DEFAULT, "m_ssl_gnutls.so: Failed to generate DH parameters (%d bits): %s", dh_bits, gnutls_strerror(ret));
}
virtual ~ModuleSSLGnuTLS()
{
gnutls_dh_params_deinit(dh_params);
gnutls_certificate_free_credentials(x509_cred);
gnutls_global_deinit();
}
virtual void OnCleanup(int target_type, void* item)
{
if(target_type == TYPE_USER)
{
userrec* user = (userrec*)item;
if(user->GetExt("ssl", dummy) && isin(user->GetPort(), listenports))
{
// User is using SSL, they're a local user, and they're using one of *our* SSL ports.
// Potentially there could be multiple SSL modules loaded at once on different ports.
ServerInstance->GlobalCulls.AddItem(user, "SSL module unloading");
}
if (user->GetExt("ssl_cert", dummy) && isin(user->GetPort(), listenports))
{
ssl_cert* tofree;
user->GetExt("ssl_cert", tofree);
delete tofree;
user->Shrink("ssl_cert");
}
}
}
virtual void OnUnloadModule(Module* mod, const std::string &name)
{
if(mod == this)
{
for(unsigned int i = 0; i < listenports.size(); i++)
{
ServerInstance->Config->DelIOHook(listenports[i]);
for (size_t j = 0; j < ServerInstance->Config->ports.size(); j++)
if (ServerInstance->Config->ports[j]->GetPort() == listenports[i])
ServerInstance->Config->ports[j]->SetDescription("plaintext");
}
}
}
virtual Version GetVersion()
{
return Version(1, 1, 0, 0, VF_VENDOR, API_VERSION);
}
void Implements(char* List)
{
List[I_On005Numeric] = List[I_OnRawSocketConnect] = List[I_OnRawSocketAccept] = List[I_OnRawSocketClose] = List[I_OnRawSocketRead] = List[I_OnRawSocketWrite] = List[I_OnCleanup] = 1;
List[I_OnRequest] = List[I_OnSyncUserMetaData] = List[I_OnDecodeMetaData] = List[I_OnUnloadModule] = List[I_OnRehash] = List[I_OnWhois] = List[I_OnPostConnect] = 1;
}
virtual void On005Numeric(std::string &output)
{
output.append(" SSL=" + sslports);
}
virtual char* OnRequest(Request* request)
{
ISHRequest* ISR = (ISHRequest*)request;
if (strcmp("IS_NAME", request->GetId()) == 0)
{
return "gnutls";
}
else if (strcmp("IS_HOOK", request->GetId()) == 0)
{
char* ret = "OK";
try
{
ret = ServerInstance->Config->AddIOHook((Module*)this, (InspSocket*)ISR->Sock) ? (char*)"OK" : NULL;
}
catch (ModuleException &e)
{
return NULL;
}
return ret;
}
else if (strcmp("IS_UNHOOK", request->GetId()) == 0)
{
return ServerInstance->Config->DelIOHook((InspSocket*)ISR->Sock) ? (char*)"OK" : NULL;
}
else if (strcmp("IS_HSDONE", request->GetId()) == 0)
{
if (ISR->Sock->GetFd() < 0)
return (char*)"OK";
issl_session* session = &sessions[ISR->Sock->GetFd()];
return (session->status == ISSL_HANDSHAKING_READ || session->status == ISSL_HANDSHAKING_WRITE) ? NULL : (char*)"OK";
}
else if (strcmp("IS_ATTACH", request->GetId()) == 0)
{
if (ISR->Sock->GetFd() > -1)
{
issl_session* session = &sessions[ISR->Sock->GetFd()];
if (session->sess)
{
if ((Extensible*)ServerInstance->FindDescriptor(ISR->Sock->GetFd()) == (Extensible*)(ISR->Sock))
{
VerifyCertificate(session, (InspSocket*)ISR->Sock);
return "OK";
}
}
}
}
return NULL;
}
virtual void OnRawSocketAccept(int fd, const std::string &ip, int localport)
{
issl_session* session = &sessions[fd];
session->fd = fd;
session->inbuf = new char[inbufsize];
session->inbufoffset = 0;
gnutls_init(&session->sess, GNUTLS_SERVER);
gnutls_set_default_priority(session->sess); // Avoid calling all the priority functions, defaults are adequate.
gnutls_credentials_set(session->sess, GNUTLS_CRD_CERTIFICATE, x509_cred);
gnutls_dh_set_prime_bits(session->sess, dh_bits);
/* This is an experimental change to avoid a warning on 64bit systems about casting between integer and pointer of different sizes
* This needs testing, but it's easy enough to rollback if need be
* Old: gnutls_transport_set_ptr(session->sess, (gnutls_transport_ptr_t) fd); // Give gnutls the fd for the socket.
* New: gnutls_transport_set_ptr(session->sess, &fd); // Give gnutls the fd for the socket.
*
* With testing this seems to...not work :/
*/
gnutls_transport_set_ptr(session->sess, (gnutls_transport_ptr_t) fd); // Give gnutls the fd for the socket.
gnutls_certificate_server_set_request(session->sess, GNUTLS_CERT_REQUEST); // Request client certificate if any.
Handshake(session);
}
virtual void OnRawSocketConnect(int fd)
{
issl_session* session = &sessions[fd];
session->fd = fd;
session->inbuf = new char[inbufsize];
session->inbufoffset = 0;
gnutls_init(&session->sess, GNUTLS_CLIENT);
gnutls_set_default_priority(session->sess); // Avoid calling all the priority functions, defaults are adequate.
gnutls_credentials_set(session->sess, GNUTLS_CRD_CERTIFICATE, x509_cred);
gnutls_dh_set_prime_bits(session->sess, dh_bits);
gnutls_transport_set_ptr(session->sess, (gnutls_transport_ptr_t) fd); // Give gnutls the fd for the socket.
Handshake(session);
}
virtual void OnRawSocketClose(int fd)
{
CloseSession(&sessions[fd]);
EventHandler* user = ServerInstance->SE->GetRef(fd);
if ((user) && (user->GetExt("ssl_cert", dummy)))
{
ssl_cert* tofree;
user->GetExt("ssl_cert", tofree);
delete tofree;
user->Shrink("ssl_cert");
}
}
virtual int OnRawSocketRead(int fd, char* buffer, unsigned int count, int &readresult)
{
issl_session* session = &sessions[fd];
if (!session->sess)
{
readresult = 0;
CloseSession(session);
return 1;
}
if (session->status == ISSL_HANDSHAKING_READ)
{
// The handshake isn't finished, try to finish it.
if(!Handshake(session))
{
// Couldn't resume handshake.
return -1;
}
}
else if (session->status == ISSL_HANDSHAKING_WRITE)
{
errno = EAGAIN;
return -1;
}
// If we resumed the handshake then session->status will be ISSL_HANDSHAKEN.
if (session->status == ISSL_HANDSHAKEN)
{
// Is this right? Not sure if the unencrypted data is garaunteed to be the same length.
// Read into the inbuffer, offset from the beginning by the amount of data we have that insp hasn't taken yet.
int ret = gnutls_record_recv(session->sess, session->inbuf + session->inbufoffset, inbufsize - session->inbufoffset);
if (ret == 0)
{
// Client closed connection.
readresult = 0;
CloseSession(session);
return 1;
}
else if (ret < 0)
{
if (ret == GNUTLS_E_AGAIN || ret == GNUTLS_E_INTERRUPTED)
{
errno = EAGAIN;
return -1;
}
else
{
readresult = 0;
CloseSession(session);
}
}
else
{
// Read successfully 'ret' bytes into inbuf + inbufoffset
// There are 'ret' + 'inbufoffset' bytes of data in 'inbuf'
// 'buffer' is 'count' long
unsigned int length = ret + session->inbufoffset;
if(count <= length)
{
memcpy(buffer, session->inbuf, count);
// Move the stuff left in inbuf to the beginning of it
memcpy(session->inbuf, session->inbuf + count, (length - count));
// Now we need to set session->inbufoffset to the amount of data still waiting to be handed to insp.
session->inbufoffset = length - count;
// Insp uses readresult as the count of how much data there is in buffer, so:
readresult = count;
}
else
{
// There's not as much in the inbuf as there is space in the buffer, so just copy the whole thing.
memcpy(buffer, session->inbuf, length);
// Zero the offset, as there's nothing there..
session->inbufoffset = 0;
// As above
readresult = length;
}
}
}
else if(session->status == ISSL_CLOSING)
readresult = 0;
return 1;
}
virtual int OnRawSocketWrite(int fd, const char* buffer, int count)
{
if (!count)
return 0;
issl_session* session = &sessions[fd];
const char* sendbuffer = buffer;
if (!session->sess)
{
ServerInstance->Log(DEBUG,"No session");
CloseSession(session);
return 1;
}
session->outbuf.append(sendbuffer, count);
sendbuffer = session->outbuf.c_str();
count = session->outbuf.size();
if (session->status == ISSL_HANDSHAKING_WRITE)
{
// The handshake isn't finished, try to finish it.
ServerInstance->Log(DEBUG,"Finishing handshake");
Handshake(session);
errno = EAGAIN;
return -1;
}
int ret = 0;
if (session->status == ISSL_HANDSHAKEN)
{
ServerInstance->Log(DEBUG,"Send record");
ret = gnutls_record_send(session->sess, sendbuffer, count);
ServerInstance->Log(DEBUG,"Return: %d", ret);
if (ret == 0)
{
CloseSession(session);
}
else if (ret < 0)
{
if(ret != GNUTLS_E_AGAIN && ret != GNUTLS_E_INTERRUPTED)
{
ServerInstance->Log(DEBUG,"Not egain or interrupt, close session");
CloseSession(session);
}
else
{
ServerInstance->Log(DEBUG,"Again please");
errno = EAGAIN;
return -1;
}
}
else
{
ServerInstance->Log(DEBUG,"Trim buffer");
session->outbuf = session->outbuf.substr(ret);
}
}
/* Who's smart idea was it to return 1 when we havent written anything?
* This fucks the buffer up in InspSocket :p
*/
return ret < 1 ? 0 : ret;
}
// :kenny.chatspike.net 320 Om Epy|AFK :is a Secure Connection
virtual void OnWhois(userrec* source, userrec* dest)
{
if (!clientactive)
return;
// Bugfix, only send this numeric for *our* SSL users
if(dest->GetExt("ssl", dummy) || (IS_LOCAL(dest) && isin(dest->GetPort(), listenports)))
{
ServerInstance->SendWhoisLine(source, dest, 320, "%s %s :is using a secure connection", source->nick, dest->nick);
}
}
virtual void OnSyncUserMetaData(userrec* user, Module* proto, void* opaque, const std::string &extname, bool displayable)
{
// check if the linking module wants to know about OUR metadata
if(extname == "ssl")
{
// check if this user has an swhois field to send
if(user->GetExt(extname, dummy))
{
// call this function in the linking module, let it format the data how it
// sees fit, and send it on its way. We dont need or want to know how.
proto->ProtoSendMetaData(opaque, TYPE_USER, user, extname, displayable ? "Enabled" : "ON");
}
}
}
virtual void OnDecodeMetaData(int target_type, void* target, const std::string &extname, const std::string &extdata)
{
// check if its our metadata key, and its associated with a user
if ((target_type == TYPE_USER) && (extname == "ssl"))
{
userrec* dest = (userrec*)target;
// if they dont already have an ssl flag, accept the remote server's
if (!dest->GetExt(extname, dummy))
{
dest->Extend(extname, "ON");
}
}
}
bool Handshake(issl_session* session)
{
int ret = gnutls_handshake(session->sess);
if (ret < 0)
{
if(ret == GNUTLS_E_AGAIN || ret == GNUTLS_E_INTERRUPTED)
{
// Handshake needs resuming later, read() or write() would have blocked.
if(gnutls_record_get_direction(session->sess) == 0)
{
// gnutls_handshake() wants to read() again.
session->status = ISSL_HANDSHAKING_READ;
}
else
{
// gnutls_handshake() wants to write() again.
session->status = ISSL_HANDSHAKING_WRITE;
MakePollWrite(session);
}
}
else
{
// Handshake failed.
CloseSession(session);
session->status = ISSL_CLOSING;
}
return false;
}
else
{
// Handshake complete.
// This will do for setting the ssl flag...it could be done earlier if it's needed. But this seems neater.
userrec* extendme = ServerInstance->FindDescriptor(session->fd);
if (extendme)
{
if (!extendme->GetExt("ssl", dummy))
extendme->Extend("ssl", "ON");
}
// Change the seesion state
session->status = ISSL_HANDSHAKEN;
// Finish writing, if any left
MakePollWrite(session);
return true;
}
}
virtual void OnPostConnect(userrec* user)
{
// This occurs AFTER OnUserConnect so we can be sure the
// protocol module has propogated the NICK message.
if ((user->GetExt("ssl", dummy)) && (IS_LOCAL(user)))
{
// Tell whatever protocol module we're using that we need to inform other servers of this metadata NOW.
std::deque<std::string>* metadata = new std::deque<std::string>;
metadata->push_back(user->nick);
metadata->push_back("ssl"); // The metadata id
metadata->push_back("ON"); // The value to send
Event* event = new Event((char*)metadata,(Module*)this,"send_metadata");
event->Send(ServerInstance); // Trigger the event. We don't care what module picks it up.
DELETE(event);
DELETE(metadata);
VerifyCertificate(&sessions[user->GetFd()],user);
if (sessions[user->GetFd()].sess)
{
std::string cipher = gnutls_kx_get_name(gnutls_kx_get(sessions[user->GetFd()].sess));
cipher.append("-").append(gnutls_cipher_get_name(gnutls_cipher_get(sessions[user->GetFd()].sess))).append("-");
cipher.append(gnutls_mac_get_name(gnutls_mac_get(sessions[user->GetFd()].sess)));
user->WriteServ("NOTICE %s :*** You are connected using SSL cipher \"%s\"", user->nick, cipher.c_str());
}
}
}
void MakePollWrite(issl_session* session)
{
OnRawSocketWrite(session->fd, NULL, 0);
}
void CloseSession(issl_session* session)
{
if(session->sess)
{
gnutls_bye(session->sess, GNUTLS_SHUT_WR);
gnutls_deinit(session->sess);
}
if(session->inbuf)
{
delete[] session->inbuf;
}
session->outbuf.clear();
session->inbuf = NULL;
session->sess = NULL;
session->status = ISSL_NONE;
}
void VerifyCertificate(issl_session* session, Extensible* user)
{
if (!session->sess || !user)
return;
unsigned int status;
const gnutls_datum_t* cert_list;
int ret;
unsigned int cert_list_size;
gnutls_x509_crt_t cert;
char name[MAXBUF];
unsigned char digest[MAXBUF];
size_t digest_size = sizeof(digest);
size_t name_size = sizeof(name);
ssl_cert* certinfo = new ssl_cert;
user->Extend("ssl_cert",certinfo);
/* This verification function uses the trusted CAs in the credentials
* structure. So you must have installed one or more CA certificates.
*/
ret = gnutls_certificate_verify_peers2(session->sess, &status);
if (ret < 0)
{
certinfo->data.insert(std::make_pair("error",std::string(gnutls_strerror(ret))));
return;
}
if (status & GNUTLS_CERT_INVALID)
{
certinfo->data.insert(std::make_pair("invalid",ConvToStr(1)));
}
else
{
certinfo->data.insert(std::make_pair("invalid",ConvToStr(0)));
}
if (status & GNUTLS_CERT_SIGNER_NOT_FOUND)
{
certinfo->data.insert(std::make_pair("unknownsigner",ConvToStr(1)));
}
else
{
certinfo->data.insert(std::make_pair("unknownsigner",ConvToStr(0)));
}
if (status & GNUTLS_CERT_REVOKED)
{
certinfo->data.insert(std::make_pair("revoked",ConvToStr(1)));
}
else
{
certinfo->data.insert(std::make_pair("revoked",ConvToStr(0)));
}
if (status & GNUTLS_CERT_SIGNER_NOT_CA)
{
certinfo->data.insert(std::make_pair("trusted",ConvToStr(0)));
}
else
{
certinfo->data.insert(std::make_pair("trusted",ConvToStr(1)));
}
/* Up to here the process is the same for X.509 certificates and
* OpenPGP keys. From now on X.509 certificates are assumed. This can
* be easily extended to work with openpgp keys as well.
*/
if (gnutls_certificate_type_get(session->sess) != GNUTLS_CRT_X509)
{
certinfo->data.insert(std::make_pair("error","No X509 keys sent"));
return;
}
ret = gnutls_x509_crt_init(&cert);
if (ret < 0)
{
certinfo->data.insert(std::make_pair("error",gnutls_strerror(ret)));
return;
}
cert_list_size = 0;
cert_list = gnutls_certificate_get_peers(session->sess, &cert_list_size);
if (cert_list == NULL)
{
certinfo->data.insert(std::make_pair("error","No certificate was found"));
return;
}
/* This is not a real world example, since we only check the first
* certificate in the given chain.
*/
ret = gnutls_x509_crt_import(cert, &cert_list[0], GNUTLS_X509_FMT_DER);
if (ret < 0)
{
certinfo->data.insert(std::make_pair("error",gnutls_strerror(ret)));
return;
}
gnutls_x509_crt_get_dn(cert, name, &name_size);
certinfo->data.insert(std::make_pair("dn",name));
gnutls_x509_crt_get_issuer_dn(cert, name, &name_size);
certinfo->data.insert(std::make_pair("issuer",name));
if ((ret = gnutls_x509_crt_get_fingerprint(cert, GNUTLS_DIG_MD5, digest, &digest_size)) < 0)
{
certinfo->data.insert(std::make_pair("error",gnutls_strerror(ret)));
}
else
{
certinfo->data.insert(std::make_pair("fingerprint",irc::hex(digest, digest_size)));
}
/* Beware here we do not check for errors.
*/
if ((gnutls_x509_crt_get_expiration_time(cert) < time(0)) || (gnutls_x509_crt_get_activation_time(cert) > time(0)))
{
certinfo->data.insert(std::make_pair("error","Not activated, or expired certificate"));
}
gnutls_x509_crt_deinit(cert);
return;
}
};
MODULE_INIT(ModuleSSLGnuTLS);
\ No newline at end of file +/* +------------------------------------+ + * | Inspire Internet Relay Chat Daemon | + * +------------------------------------+ + * + * InspIRCd: (C) 2002-2007 InspIRCd Development Team + * See: http://www.inspircd.org/wiki/index.php/Credits + * + * This program is free but copyrighted software; see + * the file COPYING for details. + * + * --------------------------------------------------- + */ + +#include "inspircd.h" + +#include <gnutls/gnutls.h> +#include <gnutls/x509.h> + +#include "inspircd_config.h" +#include "configreader.h" +#include "users.h" +#include "channels.h" +#include "modules.h" +#include "socket.h" +#include "hashcomp.h" +#include "transport.h" + +#ifdef WINDOWS +#pragma comment(lib, "libgnutls-13.lib") +#undef MAX_DESCRIPTORS +#define MAX_DESCRIPTORS 10000 +#endif + +/* $ModDesc: Provides SSL support for clients */ +/* $CompileFlags: exec("libgnutls-config --cflags") */ +/* $LinkerFlags: rpath("libgnutls-config --libs") exec("libgnutls-config --libs") */ +/* $ModDep: transport.h */ + + +enum issl_status { ISSL_NONE, ISSL_HANDSHAKING_READ, ISSL_HANDSHAKING_WRITE, ISSL_HANDSHAKEN, ISSL_CLOSING, ISSL_CLOSED }; + +bool isin(int port, const std::vector<int> &portlist) +{ + for(unsigned int i = 0; i < portlist.size(); i++) + if(portlist[i] == port) + return true; + + return false; +} + +/** Represents an SSL user's extra data + */ +class issl_session : public classbase +{ +public: + gnutls_session_t sess; + issl_status status; + std::string outbuf; + int inbufoffset; + char* inbuf; + int fd; +}; + +class ModuleSSLGnuTLS : public Module +{ + + ConfigReader* Conf; + + char* dummy; + + std::vector<int> listenports; + + int inbufsize; + issl_session sessions[MAX_DESCRIPTORS]; + + gnutls_certificate_credentials x509_cred; + gnutls_dh_params dh_params; + + std::string keyfile; + std::string certfile; + std::string cafile; + std::string crlfile; + std::string sslports; + int dh_bits; + + int clientactive; + + public: + + ModuleSSLGnuTLS(InspIRCd* Me) + : Module(Me) + { + ServerInstance->PublishInterface("InspSocketHook", this); + + // Not rehashable...because I cba to reduce all the sizes of existing buffers. + inbufsize = ServerInstance->Config->NetBufferSize; + + gnutls_global_init(); // This must be called once in the program + + if(gnutls_certificate_allocate_credentials(&x509_cred) != 0) + ServerInstance->Log(DEFAULT, "m_ssl_gnutls.so: Failed to allocate certificate credentials"); + + // Guessing return meaning + if(gnutls_dh_params_init(&dh_params) < 0) + ServerInstance->Log(DEFAULT, "m_ssl_gnutls.so: Failed to initialise DH parameters"); + + // Needs the flag as it ignores a plain /rehash + OnRehash(NULL,"ssl"); + + // Void return, guess we assume success + gnutls_certificate_set_dh_params(x509_cred, dh_params); + } + + virtual void OnRehash(userrec* user, const std::string ¶m) + { + if(param != "ssl") + return; + + Conf = new ConfigReader(ServerInstance); + + for(unsigned int i = 0; i < listenports.size(); i++) + { + ServerInstance->Config->DelIOHook(listenports[i]); + } + + listenports.clear(); + clientactive = 0; + sslports.clear(); + + for(int i = 0; i < Conf->Enumerate("bind"); i++) + { + // For each <bind> tag + std::string x = Conf->ReadValue("bind", "type", i); + if(((x.empty()) || (x == "clients")) && (Conf->ReadValue("bind", "ssl", i) == "gnutls")) + { + // Get the port we're meant to be listening on with SSL + std::string port = Conf->ReadValue("bind", "port", i); + irc::portparser portrange(port, false); + long portno = -1; + while ((portno = portrange.GetToken())) + { + clientactive++; + try + { + if (ServerInstance->Config->AddIOHook(portno, this)) + { + listenports.push_back(portno); + for (size_t i = 0; i < ServerInstance->Config->ports.size(); i++) + if (ServerInstance->Config->ports[i]->GetPort() == portno) + ServerInstance->Config->ports[i]->SetDescription("ssl"); + ServerInstance->Log(DEFAULT, "m_ssl_gnutls.so: Enabling SSL for port %d", portno); + sslports.append("*:").append(ConvToStr(portno)).append(";"); + } + else + { + ServerInstance->Log(DEFAULT, "m_ssl_gnutls.so: FAILED to enable SSL on port %d, maybe you have another ssl or similar module loaded?", portno); + } + } + catch (ModuleException &e) + { + ServerInstance->Log(DEFAULT, "m_ssl_gnutls.so: FAILED to enable SSL on port %d: %s. Maybe it's already hooked by the same port on a different IP, or you have an other SSL or similar module loaded?", portno, e.GetReason()); + } + } + } + } + + std::string confdir(ServerInstance->ConfigFileName); + // +1 so we the path ends with a / + confdir = confdir.substr(0, confdir.find_last_of('/') + 1); + + cafile = Conf->ReadValue("gnutls", "cafile", 0); + crlfile = Conf->ReadValue("gnutls", "crlfile", 0); + certfile = Conf->ReadValue("gnutls", "certfile", 0); + keyfile = Conf->ReadValue("gnutls", "keyfile", 0); + dh_bits = Conf->ReadInteger("gnutls", "dhbits", 0, false); + + // Set all the default values needed. + if (cafile.empty()) + cafile = "ca.pem"; + + if (crlfile.empty()) + crlfile = "crl.pem"; + + if (certfile.empty()) + certfile = "cert.pem"; + + if (keyfile.empty()) + keyfile = "key.pem"; + + if((dh_bits != 768) && (dh_bits != 1024) && (dh_bits != 2048) && (dh_bits != 3072) && (dh_bits != 4096)) + dh_bits = 1024; + + // Prepend relative paths with the path to the config directory. + if(cafile[0] != '/') + cafile = confdir + cafile; + + if(crlfile[0] != '/') + crlfile = confdir + crlfile; + + if(certfile[0] != '/') + certfile = confdir + certfile; + + if(keyfile[0] != '/') + keyfile = confdir + keyfile; + + int ret; + + if((ret =gnutls_certificate_set_x509_trust_file(x509_cred, cafile.c_str(), GNUTLS_X509_FMT_PEM)) < 0) + ServerInstance->Log(DEFAULT, "m_ssl_gnutls.so: Failed to set X.509 trust file '%s': %s", cafile.c_str(), gnutls_strerror(ret)); + + if((ret = gnutls_certificate_set_x509_crl_file (x509_cred, crlfile.c_str(), GNUTLS_X509_FMT_PEM)) < 0) + ServerInstance->Log(DEFAULT, "m_ssl_gnutls.so: Failed to set X.509 CRL file '%s': %s", crlfile.c_str(), gnutls_strerror(ret)); + + if((ret = gnutls_certificate_set_x509_key_file (x509_cred, certfile.c_str(), keyfile.c_str(), GNUTLS_X509_FMT_PEM)) < 0) + { + // If this fails, no SSL port will work. At all. So, do the smart thing - throw a ModuleException + throw ModuleException("Unable to load GnuTLS server certificate: " + std::string(gnutls_strerror(ret))); + } + + // This may be on a large (once a day or week) timer eventually. + GenerateDHParams(); + + DELETE(Conf); + } + + void GenerateDHParams() + { + // Generate Diffie Hellman parameters - for use with DHE + // kx algorithms. These should be discarded and regenerated + // once a day, once a week or once a month. Depending on the + // security requirements. + + int ret; + + if((ret = gnutls_dh_params_generate2(dh_params, dh_bits)) < 0) + ServerInstance->Log(DEFAULT, "m_ssl_gnutls.so: Failed to generate DH parameters (%d bits): %s", dh_bits, gnutls_strerror(ret)); + } + + virtual ~ModuleSSLGnuTLS() + { + gnutls_dh_params_deinit(dh_params); + gnutls_certificate_free_credentials(x509_cred); + gnutls_global_deinit(); + } + + virtual void OnCleanup(int target_type, void* item) + { + if(target_type == TYPE_USER) + { + userrec* user = (userrec*)item; + + if(user->GetExt("ssl", dummy) && isin(user->GetPort(), listenports)) + { + // User is using SSL, they're a local user, and they're using one of *our* SSL ports. + // Potentially there could be multiple SSL modules loaded at once on different ports. + ServerInstance->GlobalCulls.AddItem(user, "SSL module unloading"); + } + if (user->GetExt("ssl_cert", dummy) && isin(user->GetPort(), listenports)) + { + ssl_cert* tofree; + user->GetExt("ssl_cert", tofree); + delete tofree; + user->Shrink("ssl_cert"); + } + } + } + + virtual void OnUnloadModule(Module* mod, const std::string &name) + { + if(mod == this) + { + for(unsigned int i = 0; i < listenports.size(); i++) + { + ServerInstance->Config->DelIOHook(listenports[i]); + for (size_t j = 0; j < ServerInstance->Config->ports.size(); j++) + if (ServerInstance->Config->ports[j]->GetPort() == listenports[i]) + ServerInstance->Config->ports[j]->SetDescription("plaintext"); + } + } + } + + virtual Version GetVersion() + { + return Version(1, 1, 0, 0, VF_VENDOR, API_VERSION); + } + + void Implements(char* List) + { + List[I_On005Numeric] = List[I_OnRawSocketConnect] = List[I_OnRawSocketAccept] = List[I_OnRawSocketClose] = List[I_OnRawSocketRead] = List[I_OnRawSocketWrite] = List[I_OnCleanup] = 1; + List[I_OnRequest] = List[I_OnSyncUserMetaData] = List[I_OnDecodeMetaData] = List[I_OnUnloadModule] = List[I_OnRehash] = List[I_OnWhois] = List[I_OnPostConnect] = 1; + } + + virtual void On005Numeric(std::string &output) + { + output.append(" SSL=" + sslports); + } + + virtual char* OnRequest(Request* request) + { + ISHRequest* ISR = (ISHRequest*)request; + if (strcmp("IS_NAME", request->GetId()) == 0) + { + return "gnutls"; + } + else if (strcmp("IS_HOOK", request->GetId()) == 0) + { + char* ret = "OK"; + try + { + ret = ServerInstance->Config->AddIOHook((Module*)this, (InspSocket*)ISR->Sock) ? (char*)"OK" : NULL; + } + catch (ModuleException &e) + { + return NULL; + } + return ret; + } + else if (strcmp("IS_UNHOOK", request->GetId()) == 0) + { + return ServerInstance->Config->DelIOHook((InspSocket*)ISR->Sock) ? (char*)"OK" : NULL; + } + else if (strcmp("IS_HSDONE", request->GetId()) == 0) + { + if (ISR->Sock->GetFd() < 0) + return (char*)"OK"; + + issl_session* session = &sessions[ISR->Sock->GetFd()]; + return (session->status == ISSL_HANDSHAKING_READ || session->status == ISSL_HANDSHAKING_WRITE) ? NULL : (char*)"OK"; + } + else if (strcmp("IS_ATTACH", request->GetId()) == 0) + { + if (ISR->Sock->GetFd() > -1) + { + issl_session* session = &sessions[ISR->Sock->GetFd()]; + if (session->sess) + { + if ((Extensible*)ServerInstance->FindDescriptor(ISR->Sock->GetFd()) == (Extensible*)(ISR->Sock)) + { + VerifyCertificate(session, (InspSocket*)ISR->Sock); + return "OK"; + } + } + } + } + return NULL; + } + + + virtual void OnRawSocketAccept(int fd, const std::string &ip, int localport) + { + issl_session* session = &sessions[fd]; + + session->fd = fd; + session->inbuf = new char[inbufsize]; + session->inbufoffset = 0; + + gnutls_init(&session->sess, GNUTLS_SERVER); + + gnutls_set_default_priority(session->sess); // Avoid calling all the priority functions, defaults are adequate. + gnutls_credentials_set(session->sess, GNUTLS_CRD_CERTIFICATE, x509_cred); + gnutls_dh_set_prime_bits(session->sess, dh_bits); + + /* This is an experimental change to avoid a warning on 64bit systems about casting between integer and pointer of different sizes + * This needs testing, but it's easy enough to rollback if need be + * Old: gnutls_transport_set_ptr(session->sess, (gnutls_transport_ptr_t) fd); // Give gnutls the fd for the socket. + * New: gnutls_transport_set_ptr(session->sess, &fd); // Give gnutls the fd for the socket. + * + * With testing this seems to...not work :/ + */ + + gnutls_transport_set_ptr(session->sess, (gnutls_transport_ptr_t) fd); // Give gnutls the fd for the socket. + + gnutls_certificate_server_set_request(session->sess, GNUTLS_CERT_REQUEST); // Request client certificate if any. + + Handshake(session); + } + + virtual void OnRawSocketConnect(int fd) + { + issl_session* session = &sessions[fd]; + + session->fd = fd; + session->inbuf = new char[inbufsize]; + session->inbufoffset = 0; + + gnutls_init(&session->sess, GNUTLS_CLIENT); + + gnutls_set_default_priority(session->sess); // Avoid calling all the priority functions, defaults are adequate. + gnutls_credentials_set(session->sess, GNUTLS_CRD_CERTIFICATE, x509_cred); + gnutls_dh_set_prime_bits(session->sess, dh_bits); + gnutls_transport_set_ptr(session->sess, (gnutls_transport_ptr_t) fd); // Give gnutls the fd for the socket. + + Handshake(session); + } + + virtual void OnRawSocketClose(int fd) + { + CloseSession(&sessions[fd]); + + EventHandler* user = ServerInstance->SE->GetRef(fd); + + if ((user) && (user->GetExt("ssl_cert", dummy))) + { + ssl_cert* tofree; + user->GetExt("ssl_cert", tofree); + delete tofree; + user->Shrink("ssl_cert"); + } + } + + virtual int OnRawSocketRead(int fd, char* buffer, unsigned int count, int &readresult) + { + issl_session* session = &sessions[fd]; + + if (!session->sess) + { + readresult = 0; + CloseSession(session); + return 1; + } + + if (session->status == ISSL_HANDSHAKING_READ) + { + // The handshake isn't finished, try to finish it. + + if(!Handshake(session)) + { + // Couldn't resume handshake. + return -1; + } + } + else if (session->status == ISSL_HANDSHAKING_WRITE) + { + errno = EAGAIN; + return -1; + } + + // If we resumed the handshake then session->status will be ISSL_HANDSHAKEN. + + if (session->status == ISSL_HANDSHAKEN) + { + // Is this right? Not sure if the unencrypted data is garaunteed to be the same length. + // Read into the inbuffer, offset from the beginning by the amount of data we have that insp hasn't taken yet. + int ret = gnutls_record_recv(session->sess, session->inbuf + session->inbufoffset, inbufsize - session->inbufoffset); + + if (ret == 0) + { + // Client closed connection. + readresult = 0; + CloseSession(session); + return 1; + } + else if (ret < 0) + { + if (ret == GNUTLS_E_AGAIN || ret == GNUTLS_E_INTERRUPTED) + { + errno = EAGAIN; + return -1; + } + else + { + readresult = 0; + CloseSession(session); + } + } + else + { + // Read successfully 'ret' bytes into inbuf + inbufoffset + // There are 'ret' + 'inbufoffset' bytes of data in 'inbuf' + // 'buffer' is 'count' long + + unsigned int length = ret + session->inbufoffset; + + if(count <= length) + { + memcpy(buffer, session->inbuf, count); + // Move the stuff left in inbuf to the beginning of it + memcpy(session->inbuf, session->inbuf + count, (length - count)); + // Now we need to set session->inbufoffset to the amount of data still waiting to be handed to insp. + session->inbufoffset = length - count; + // Insp uses readresult as the count of how much data there is in buffer, so: + readresult = count; + } + else + { + // There's not as much in the inbuf as there is space in the buffer, so just copy the whole thing. + memcpy(buffer, session->inbuf, length); + // Zero the offset, as there's nothing there.. + session->inbufoffset = 0; + // As above + readresult = length; + } + } + } + else if(session->status == ISSL_CLOSING) + readresult = 0; + + return 1; + } + + virtual int OnRawSocketWrite(int fd, const char* buffer, int count) + { + if (!count) + return 0; + + issl_session* session = &sessions[fd]; + const char* sendbuffer = buffer; + + if (!session->sess) + { + ServerInstance->Log(DEBUG,"No session"); + CloseSession(session); + return 1; + } + + session->outbuf.append(sendbuffer, count); + sendbuffer = session->outbuf.c_str(); + count = session->outbuf.size(); + + if (session->status == ISSL_HANDSHAKING_WRITE) + { + // The handshake isn't finished, try to finish it. + ServerInstance->Log(DEBUG,"Finishing handshake"); + Handshake(session); + errno = EAGAIN; + return -1; + } + + int ret = 0; + + if (session->status == ISSL_HANDSHAKEN) + { + ServerInstance->Log(DEBUG,"Send record"); + ret = gnutls_record_send(session->sess, sendbuffer, count); + ServerInstance->Log(DEBUG,"Return: %d", ret); + + if (ret == 0) + { + CloseSession(session); + } + else if (ret < 0) + { + if(ret != GNUTLS_E_AGAIN && ret != GNUTLS_E_INTERRUPTED) + { + ServerInstance->Log(DEBUG,"Not egain or interrupt, close session"); + CloseSession(session); + } + else + { + ServerInstance->Log(DEBUG,"Again please"); + errno = EAGAIN; + return -1; + } + } + else + { + ServerInstance->Log(DEBUG,"Trim buffer"); + session->outbuf = session->outbuf.substr(ret); + } + } + + /* Who's smart idea was it to return 1 when we havent written anything? + * This fucks the buffer up in InspSocket :p + */ + return ret < 1 ? 0 : ret; + } + + // :kenny.chatspike.net 320 Om Epy|AFK :is a Secure Connection + virtual void OnWhois(userrec* source, userrec* dest) + { + if (!clientactive) + return; + + // Bugfix, only send this numeric for *our* SSL users + if(dest->GetExt("ssl", dummy) || (IS_LOCAL(dest) && isin(dest->GetPort(), listenports))) + { + ServerInstance->SendWhoisLine(source, dest, 320, "%s %s :is using a secure connection", source->nick, dest->nick); + } + } + + virtual void OnSyncUserMetaData(userrec* user, Module* proto, void* opaque, const std::string &extname, bool displayable) + { + // check if the linking module wants to know about OUR metadata + if(extname == "ssl") + { + // check if this user has an swhois field to send + if(user->GetExt(extname, dummy)) + { + // call this function in the linking module, let it format the data how it + // sees fit, and send it on its way. We dont need or want to know how. + proto->ProtoSendMetaData(opaque, TYPE_USER, user, extname, displayable ? "Enabled" : "ON"); + } + } + } + + virtual void OnDecodeMetaData(int target_type, void* target, const std::string &extname, const std::string &extdata) + { + // check if its our metadata key, and its associated with a user + if ((target_type == TYPE_USER) && (extname == "ssl")) + { + userrec* dest = (userrec*)target; + // if they dont already have an ssl flag, accept the remote server's + if (!dest->GetExt(extname, dummy)) + { + dest->Extend(extname, "ON"); + } + } + } + + bool Handshake(issl_session* session) + { + int ret = gnutls_handshake(session->sess); + + if (ret < 0) + { + if(ret == GNUTLS_E_AGAIN || ret == GNUTLS_E_INTERRUPTED) + { + // Handshake needs resuming later, read() or write() would have blocked. + + if(gnutls_record_get_direction(session->sess) == 0) + { + // gnutls_handshake() wants to read() again. + session->status = ISSL_HANDSHAKING_READ; + } + else + { + // gnutls_handshake() wants to write() again. + session->status = ISSL_HANDSHAKING_WRITE; + MakePollWrite(session); + } + } + else + { + // Handshake failed. + CloseSession(session); + session->status = ISSL_CLOSING; + } + + return false; + } + else + { + // Handshake complete. + // This will do for setting the ssl flag...it could be done earlier if it's needed. But this seems neater. + userrec* extendme = ServerInstance->FindDescriptor(session->fd); + if (extendme) + { + if (!extendme->GetExt("ssl", dummy)) + extendme->Extend("ssl", "ON"); + } + + // Change the seesion state + session->status = ISSL_HANDSHAKEN; + + // Finish writing, if any left + MakePollWrite(session); + + return true; + } + } + + virtual void OnPostConnect(userrec* user) + { + // This occurs AFTER OnUserConnect so we can be sure the + // protocol module has propogated the NICK message. + if ((user->GetExt("ssl", dummy)) && (IS_LOCAL(user))) + { + // Tell whatever protocol module we're using that we need to inform other servers of this metadata NOW. + std::deque<std::string>* metadata = new std::deque<std::string>; + metadata->push_back(user->nick); + metadata->push_back("ssl"); // The metadata id + metadata->push_back("ON"); // The value to send + Event* event = new Event((char*)metadata,(Module*)this,"send_metadata"); + event->Send(ServerInstance); // Trigger the event. We don't care what module picks it up. + DELETE(event); + DELETE(metadata); + + VerifyCertificate(&sessions[user->GetFd()],user); + if (sessions[user->GetFd()].sess) + { + std::string cipher = gnutls_kx_get_name(gnutls_kx_get(sessions[user->GetFd()].sess)); + cipher.append("-").append(gnutls_cipher_get_name(gnutls_cipher_get(sessions[user->GetFd()].sess))).append("-"); + cipher.append(gnutls_mac_get_name(gnutls_mac_get(sessions[user->GetFd()].sess))); + user->WriteServ("NOTICE %s :*** You are connected using SSL cipher \"%s\"", user->nick, cipher.c_str()); + } + } + } + + void MakePollWrite(issl_session* session) + { + OnRawSocketWrite(session->fd, NULL, 0); + } + + void CloseSession(issl_session* session) + { + if(session->sess) + { + gnutls_bye(session->sess, GNUTLS_SHUT_WR); + gnutls_deinit(session->sess); + } + + if(session->inbuf) + { + delete[] session->inbuf; + } + + session->outbuf.clear(); + session->inbuf = NULL; + session->sess = NULL; + session->status = ISSL_NONE; + } + + void VerifyCertificate(issl_session* session, Extensible* user) + { + if (!session->sess || !user) + return; + + unsigned int status; + const gnutls_datum_t* cert_list; + int ret; + unsigned int cert_list_size; + gnutls_x509_crt_t cert; + char name[MAXBUF]; + unsigned char digest[MAXBUF]; + size_t digest_size = sizeof(digest); + size_t name_size = sizeof(name); + ssl_cert* certinfo = new ssl_cert; + + user->Extend("ssl_cert",certinfo); + + /* This verification function uses the trusted CAs in the credentials + * structure. So you must have installed one or more CA certificates. + */ + ret = gnutls_certificate_verify_peers2(session->sess, &status); + + if (ret < 0) + { + certinfo->data.insert(std::make_pair("error",std::string(gnutls_strerror(ret)))); + return; + } + + if (status & GNUTLS_CERT_INVALID) + { + certinfo->data.insert(std::make_pair("invalid",ConvToStr(1))); + } + else + { + certinfo->data.insert(std::make_pair("invalid",ConvToStr(0))); + } + if (status & GNUTLS_CERT_SIGNER_NOT_FOUND) + { + certinfo->data.insert(std::make_pair("unknownsigner",ConvToStr(1))); + } + else + { + certinfo->data.insert(std::make_pair("unknownsigner",ConvToStr(0))); + } + if (status & GNUTLS_CERT_REVOKED) + { + certinfo->data.insert(std::make_pair("revoked",ConvToStr(1))); + } + else + { + certinfo->data.insert(std::make_pair("revoked",ConvToStr(0))); + } + if (status & GNUTLS_CERT_SIGNER_NOT_CA) + { + certinfo->data.insert(std::make_pair("trusted",ConvToStr(0))); + } + else + { + certinfo->data.insert(std::make_pair("trusted",ConvToStr(1))); + } + + /* Up to here the process is the same for X.509 certificates and + * OpenPGP keys. From now on X.509 certificates are assumed. This can + * be easily extended to work with openpgp keys as well. + */ + if (gnutls_certificate_type_get(session->sess) != GNUTLS_CRT_X509) + { + certinfo->data.insert(std::make_pair("error","No X509 keys sent")); + return; + } + + ret = gnutls_x509_crt_init(&cert); + if (ret < 0) + { + certinfo->data.insert(std::make_pair("error",gnutls_strerror(ret))); + return; + } + + cert_list_size = 0; + cert_list = gnutls_certificate_get_peers(session->sess, &cert_list_size); + if (cert_list == NULL) + { + certinfo->data.insert(std::make_pair("error","No certificate was found")); + return; + } + + /* This is not a real world example, since we only check the first + * certificate in the given chain. + */ + + ret = gnutls_x509_crt_import(cert, &cert_list[0], GNUTLS_X509_FMT_DER); + if (ret < 0) + { + certinfo->data.insert(std::make_pair("error",gnutls_strerror(ret))); + return; + } + + gnutls_x509_crt_get_dn(cert, name, &name_size); + + certinfo->data.insert(std::make_pair("dn",name)); + + gnutls_x509_crt_get_issuer_dn(cert, name, &name_size); + + certinfo->data.insert(std::make_pair("issuer",name)); + + if ((ret = gnutls_x509_crt_get_fingerprint(cert, GNUTLS_DIG_MD5, digest, &digest_size)) < 0) + { + certinfo->data.insert(std::make_pair("error",gnutls_strerror(ret))); + } + else + { + certinfo->data.insert(std::make_pair("fingerprint",irc::hex(digest, digest_size))); + } + + /* Beware here we do not check for errors. + */ + if ((gnutls_x509_crt_get_expiration_time(cert) < time(0)) || (gnutls_x509_crt_get_activation_time(cert) > time(0))) + { + certinfo->data.insert(std::make_pair("error","Not activated, or expired certificate")); + } + + gnutls_x509_crt_deinit(cert); + + return; + } + +}; + +MODULE_INIT(ModuleSSLGnuTLS); + diff --git a/src/modules/extra/m_ssl_openssl.cpp b/src/modules/extra/m_ssl_openssl.cpp index 43dc43aea..ffd9d4032 100644 --- a/src/modules/extra/m_ssl_openssl.cpp +++ b/src/modules/extra/m_ssl_openssl.cpp @@ -1 +1,901 @@ -/* +------------------------------------+
* | Inspire Internet Relay Chat Daemon |
* +------------------------------------+
*
* InspIRCd: (C) 2002-2007 InspIRCd Development Team
* See: http://www.inspircd.org/wiki/index.php/Credits
*
* This program is free but copyrighted software; see
* the file COPYING for details.
*
* ---------------------------------------------------
*/
#include "inspircd.h"
#include <openssl/ssl.h>
#include <openssl/err.h>
#ifdef WINDOWS
#include <openssl/applink.c>
#endif
#include "configreader.h"
#include "users.h"
#include "channels.h"
#include "modules.h"
#include "socket.h"
#include "hashcomp.h"
#include "transport.h"
#ifdef WINDOWS
#pragma comment(lib, "libeay32MTd")
#pragma comment(lib, "ssleay32MTd")
#undef MAX_DESCRIPTORS
#define MAX_DESCRIPTORS 10000
#endif
/* $ModDesc: Provides SSL support for clients */
/* $CompileFlags: pkgconfversion("openssl","0.9.7") pkgconfincludes("openssl","/openssl/ssl.h","") */
/* $LinkerFlags: rpath("pkg-config --libs openssl") pkgconflibs("openssl","/libssl.so","-lssl -lcrypto -ldl") */
/* $ModDep: transport.h */
enum issl_status { ISSL_NONE, ISSL_HANDSHAKING, ISSL_OPEN };
enum issl_io_status { ISSL_WRITE, ISSL_READ };
static bool SelfSigned = false;
bool isin(int port, const std::vector<int> &portlist)
{
for(unsigned int i = 0; i < portlist.size(); i++)
if(portlist[i] == port)
return true;
return false;
}
char* get_error()
{
return ERR_error_string(ERR_get_error(), NULL);
}
static int error_callback(const char *str, size_t len, void *u);
/** Represents an SSL user's extra data
*/
class issl_session : public classbase
{
public:
SSL* sess;
issl_status status;
issl_io_status rstat;
issl_io_status wstat;
unsigned int inbufoffset;
char* inbuf; // Buffer OpenSSL reads into.
std::string outbuf; // Buffer for outgoing data that OpenSSL will not take.
int fd;
bool outbound;
issl_session()
{
outbound = false;
rstat = ISSL_READ;
wstat = ISSL_WRITE;
}
};
static int OnVerify(int preverify_ok, X509_STORE_CTX *ctx)
{
/* XXX: This will allow self signed certificates.
* In the future if we want an option to not allow this,
* we can just return preverify_ok here, and openssl
* will boot off self-signed and invalid peer certs.
*/
int ve = X509_STORE_CTX_get_error(ctx);
SelfSigned = (ve == X509_V_ERR_DEPTH_ZERO_SELF_SIGNED_CERT);
return 1;
}
class ModuleSSLOpenSSL : public Module
{
ConfigReader* Conf;
std::vector<int> listenports;
int inbufsize;
issl_session sessions[MAX_DESCRIPTORS];
SSL_CTX* ctx;
SSL_CTX* clictx;
char* dummy;
char cipher[MAXBUF];
std::string keyfile;
std::string certfile;
std::string cafile;
// std::string crlfile;
std::string dhfile;
std::string sslports;
int clientactive;
public:
InspIRCd* PublicInstance;
ModuleSSLOpenSSL(InspIRCd* Me)
: Module(Me), PublicInstance(Me)
{
ServerInstance->PublishInterface("InspSocketHook", this);
// Not rehashable...because I cba to reduce all the sizes of existing buffers.
inbufsize = ServerInstance->Config->NetBufferSize;
/* Global SSL library initialization*/
SSL_library_init();
SSL_load_error_strings();
/* Build our SSL contexts:
* NOTE: OpenSSL makes us have two contexts, one for servers and one for clients. ICK.
*/
ctx = SSL_CTX_new( SSLv23_server_method() );
clictx = SSL_CTX_new( SSLv23_client_method() );
SSL_CTX_set_verify(ctx, SSL_VERIFY_PEER | SSL_VERIFY_CLIENT_ONCE, OnVerify);
SSL_CTX_set_verify(clictx, SSL_VERIFY_PEER | SSL_VERIFY_CLIENT_ONCE, OnVerify);
// Needs the flag as it ignores a plain /rehash
OnRehash(NULL,"ssl");
}
virtual void OnRehash(userrec* user, const std::string ¶m)
{
if (param != "ssl")
return;
Conf = new ConfigReader(ServerInstance);
for (unsigned int i = 0; i < listenports.size(); i++)
{
ServerInstance->Config->DelIOHook(listenports[i]);
}
listenports.clear();
clientactive = 0;
sslports.clear();
for (int i = 0; i < Conf->Enumerate("bind"); i++)
{
// For each <bind> tag
std::string x = Conf->ReadValue("bind", "type", i);
if (((x.empty()) || (x == "clients")) && (Conf->ReadValue("bind", "ssl", i) == "openssl"))
{
// Get the port we're meant to be listening on with SSL
std::string port = Conf->ReadValue("bind", "port", i);
irc::portparser portrange(port, false);
long portno = -1;
while ((portno = portrange.GetToken()))
{
clientactive++;
try
{
if (ServerInstance->Config->AddIOHook(portno, this))
{
listenports.push_back(portno);
for (size_t i = 0; i < ServerInstance->Config->ports.size(); i++)
if (ServerInstance->Config->ports[i]->GetPort() == portno)
ServerInstance->Config->ports[i]->SetDescription("ssl");
ServerInstance->Log(DEFAULT, "m_ssl_openssl.so: Enabling SSL for port %d", portno);
sslports.append("*:").append(ConvToStr(portno)).append(";");
}
else
{
ServerInstance->Log(DEFAULT, "m_ssl_openssl.so: FAILED to enable SSL on port %d, maybe you have another ssl or similar module loaded?", portno);
}
}
catch (ModuleException &e)
{
ServerInstance->Log(DEFAULT, "m_ssl_openssl.so: FAILED to enable SSL on port %d: %s. Maybe it's already hooked by the same port on a different IP, or you have another SSL or similar module loaded?", portno, e.GetReason());
}
}
}
}
if (!sslports.empty())
sslports.erase(sslports.end() - 1);
std::string confdir(ServerInstance->ConfigFileName);
// +1 so we the path ends with a /
confdir = confdir.substr(0, confdir.find_last_of('/') + 1);
cafile = Conf->ReadValue("openssl", "cafile", 0);
certfile = Conf->ReadValue("openssl", "certfile", 0);
keyfile = Conf->ReadValue("openssl", "keyfile", 0);
dhfile = Conf->ReadValue("openssl", "dhfile", 0);
// Set all the default values needed.
if (cafile.empty())
cafile = "ca.pem";
if (certfile.empty())
certfile = "cert.pem";
if (keyfile.empty())
keyfile = "key.pem";
if (dhfile.empty())
dhfile = "dhparams.pem";
// Prepend relative paths with the path to the config directory.
if (cafile[0] != '/')
cafile = confdir + cafile;
if (certfile[0] != '/')
certfile = confdir + certfile;
if (keyfile[0] != '/')
keyfile = confdir + keyfile;
if (dhfile[0] != '/')
dhfile = confdir + dhfile;
/* Load our keys and certificates
* NOTE: OpenSSL's error logging API sucks, don't blame us for this clusterfuck.
*/
if ((!SSL_CTX_use_certificate_chain_file(ctx, certfile.c_str())) || (!SSL_CTX_use_certificate_chain_file(clictx, certfile.c_str())))
{
ServerInstance->Log(DEFAULT, "m_ssl_openssl.so: Can't read certificate file %s. %s", certfile.c_str(), strerror(errno));
ERR_print_errors_cb(error_callback, this);
}
if (((!SSL_CTX_use_PrivateKey_file(ctx, keyfile.c_str(), SSL_FILETYPE_PEM))) || (!SSL_CTX_use_PrivateKey_file(clictx, keyfile.c_str(), SSL_FILETYPE_PEM)))
{
ServerInstance->Log(DEFAULT, "m_ssl_openssl.so: Can't read key file %s. %s", keyfile.c_str(), strerror(errno));
ERR_print_errors_cb(error_callback, this);
}
/* Load the CAs we trust*/
if (((!SSL_CTX_load_verify_locations(ctx, cafile.c_str(), 0))) || (!SSL_CTX_load_verify_locations(clictx, cafile.c_str(), 0)))
{
ServerInstance->Log(DEFAULT, "m_ssl_openssl.so: Can't read CA list from %s. %s", cafile.c_str(), strerror(errno));
ERR_print_errors_cb(error_callback, this);
}
FILE* dhpfile = fopen(dhfile.c_str(), "r");
DH* ret;
if (dhpfile == NULL)
{
ServerInstance->Log(DEFAULT, "m_ssl_openssl.so Couldn't open DH file %s: %s", dhfile.c_str(), strerror(errno));
throw ModuleException("Couldn't open DH file " + dhfile + ": " + strerror(errno));
}
else
{
ret = PEM_read_DHparams(dhpfile, NULL, NULL, NULL);
if ((SSL_CTX_set_tmp_dh(ctx, ret) < 0) || (SSL_CTX_set_tmp_dh(clictx, ret) < 0))
{
ServerInstance->Log(DEFAULT, "m_ssl_openssl.so: Couldn't set DH parameters %s. SSL errors follow:", dhfile.c_str());
ERR_print_errors_cb(error_callback, this);
}
}
fclose(dhpfile);
DELETE(Conf);
}
virtual void On005Numeric(std::string &output)
{
output.append(" SSL=" + sslports);
}
virtual ~ModuleSSLOpenSSL()
{
SSL_CTX_free(ctx);
SSL_CTX_free(clictx);
}
virtual void OnCleanup(int target_type, void* item)
{
if (target_type == TYPE_USER)
{
userrec* user = (userrec*)item;
if (user->GetExt("ssl", dummy) && IS_LOCAL(user) && isin(user->GetPort(), listenports))
{
// User is using SSL, they're a local user, and they're using one of *our* SSL ports.
// Potentially there could be multiple SSL modules loaded at once on different ports.
ServerInstance->GlobalCulls.AddItem(user, "SSL module unloading");
}
if (user->GetExt("ssl_cert", dummy) && isin(user->GetPort(), listenports))
{
ssl_cert* tofree;
user->GetExt("ssl_cert", tofree);
delete tofree;
user->Shrink("ssl_cert");
}
}
}
virtual void OnUnloadModule(Module* mod, const std::string &name)
{
if (mod == this)
{
for(unsigned int i = 0; i < listenports.size(); i++)
{
ServerInstance->Config->DelIOHook(listenports[i]);
for (size_t j = 0; j < ServerInstance->Config->ports.size(); j++)
if (ServerInstance->Config->ports[j]->GetPort() == listenports[i])
ServerInstance->Config->ports[j]->SetDescription("plaintext");
}
}
}
virtual Version GetVersion()
{
return Version(1, 1, 0, 0, VF_VENDOR, API_VERSION);
}
void Implements(char* List)
{
List[I_OnRawSocketConnect] = List[I_OnRawSocketAccept] = List[I_OnRawSocketClose] = List[I_OnRawSocketRead] = List[I_OnRawSocketWrite] = List[I_OnCleanup] = List[I_On005Numeric] = 1;
List[I_OnRequest] = List[I_OnSyncUserMetaData] = List[I_OnDecodeMetaData] = List[I_OnUnloadModule] = List[I_OnRehash] = List[I_OnWhois] = List[I_OnPostConnect] = 1;
}
virtual char* OnRequest(Request* request)
{
ISHRequest* ISR = (ISHRequest*)request;
if (strcmp("IS_NAME", request->GetId()) == 0)
{
return "openssl";
}
else if (strcmp("IS_HOOK", request->GetId()) == 0)
{
char* ret = "OK";
try
{
ret = ServerInstance->Config->AddIOHook((Module*)this, (InspSocket*)ISR->Sock) ? (char*)"OK" : NULL;
}
catch (ModuleException &e)
{
return NULL;
}
return ret;
}
else if (strcmp("IS_UNHOOK", request->GetId()) == 0)
{
return ServerInstance->Config->DelIOHook((InspSocket*)ISR->Sock) ? (char*)"OK" : NULL;
}
else if (strcmp("IS_HSDONE", request->GetId()) == 0)
{
ServerInstance->Log(DEBUG,"Module checking if handshake is done");
if (ISR->Sock->GetFd() < 0)
return (char*)"OK";
issl_session* session = &sessions[ISR->Sock->GetFd()];
return (session->status == ISSL_HANDSHAKING) ? NULL : (char*)"OK";
}
else if (strcmp("IS_ATTACH", request->GetId()) == 0)
{
issl_session* session = &sessions[ISR->Sock->GetFd()];
if (session->sess)
{
VerifyCertificate(session, (InspSocket*)ISR->Sock);
return "OK";
}
}
return NULL;
}
virtual void OnRawSocketAccept(int fd, const std::string &ip, int localport)
{
issl_session* session = &sessions[fd];
session->fd = fd;
session->inbuf = new char[inbufsize];
session->inbufoffset = 0;
session->sess = SSL_new(ctx);
session->status = ISSL_NONE;
session->outbound = false;
if (session->sess == NULL)
return;
if (SSL_set_fd(session->sess, fd) == 0)
{
ServerInstance->Log(DEBUG,"BUG: Can't set fd with SSL_set_fd: %d", fd);
return;
}
Handshake(session);
}
virtual void OnRawSocketConnect(int fd)
{
ServerInstance->Log(DEBUG,"OnRawSocketConnect connecting");
issl_session* session = &sessions[fd];
session->fd = fd;
session->inbuf = new char[inbufsize];
session->inbufoffset = 0;
session->sess = SSL_new(clictx);
session->status = ISSL_NONE;
session->outbound = true;
if (session->sess == NULL)
return;
if (SSL_set_fd(session->sess, fd) == 0)
{
ServerInstance->Log(DEBUG,"BUG: Can't set fd with SSL_set_fd: %d", fd);
return;
}
Handshake(session);
ServerInstance->Log(DEBUG,"Exiting OnRawSocketConnect");
}
virtual void OnRawSocketClose(int fd)
{
CloseSession(&sessions[fd]);
EventHandler* user = ServerInstance->SE->GetRef(fd);
if ((user) && (user->GetExt("ssl_cert", dummy)))
{
ssl_cert* tofree;
user->GetExt("ssl_cert", tofree);
delete tofree;
user->Shrink("ssl_cert");
}
}
virtual int OnRawSocketRead(int fd, char* buffer, unsigned int count, int &readresult)
{
issl_session* session = &sessions[fd];
ServerInstance->Log(DEBUG,"OnRawSocketRead");
if (!session->sess)
{
ServerInstance->Log(DEBUG,"OnRawSocketRead has no session");
readresult = 0;
CloseSession(session);
return 1;
}
if (session->status == ISSL_HANDSHAKING)
{
if (session->rstat == ISSL_READ || session->wstat == ISSL_READ)
{
ServerInstance->Log(DEBUG,"Resume handshake in read");
// The handshake isn't finished and it wants to read, try to finish it.
if (!Handshake(session))
{
ServerInstance->Log(DEBUG,"Cant resume handshake in read");
// Couldn't resume handshake.
return -1;
}
}
else
{
errno = EAGAIN;
return -1;
}
}
// If we resumed the handshake then session->status will be ISSL_OPEN
if (session->status == ISSL_OPEN)
{
if (session->wstat == ISSL_READ)
{
if(DoWrite(session) == 0)
return 0;
}
if (session->rstat == ISSL_READ)
{
int ret = DoRead(session);
if (ret > 0)
{
if (count <= session->inbufoffset)
{
memcpy(buffer, session->inbuf, count);
// Move the stuff left in inbuf to the beginning of it
memcpy(session->inbuf, session->inbuf + count, (session->inbufoffset - count));
// Now we need to set session->inbufoffset to the amount of data still waiting to be handed to insp.
session->inbufoffset -= count;
// Insp uses readresult as the count of how much data there is in buffer, so:
readresult = count;
}
else
{
// There's not as much in the inbuf as there is space in the buffer, so just copy the whole thing.
memcpy(buffer, session->inbuf, session->inbufoffset);
readresult = session->inbufoffset;
// Zero the offset, as there's nothing there..
session->inbufoffset = 0;
}
return 1;
}
else
{
return ret;
}
}
}
return -1;
}
virtual int OnRawSocketWrite(int fd, const char* buffer, int count)
{
issl_session* session = &sessions[fd];
if (!session->sess)
{
ServerInstance->Log(DEBUG,"Close session missing sess");
CloseSession(session);
return -1;
}
session->outbuf.append(buffer, count);
if (session->status == ISSL_HANDSHAKING)
{
// The handshake isn't finished, try to finish it.
if (session->rstat == ISSL_WRITE || session->wstat == ISSL_WRITE)
{
ServerInstance->Log(DEBUG,"Handshake resume");
Handshake(session);
}
}
if (session->status == ISSL_OPEN)
{
if (session->rstat == ISSL_WRITE)
{
ServerInstance->Log(DEBUG,"DoRead");
DoRead(session);
}
if (session->wstat == ISSL_WRITE)
{
ServerInstance->Log(DEBUG,"DoWrite");
return DoWrite(session);
}
}
return 1;
}
int DoWrite(issl_session* session)
{
if (!session->outbuf.size())
return -1;
int ret = SSL_write(session->sess, session->outbuf.data(), session->outbuf.size());
if (ret == 0)
{
ServerInstance->Log(DEBUG,"Oops, got 0 from SSL_write");
CloseSession(session);
return 0;
}
else if (ret < 0)
{
int err = SSL_get_error(session->sess, ret);
if (err == SSL_ERROR_WANT_WRITE)
{
session->wstat = ISSL_WRITE;
return -1;
}
else if (err == SSL_ERROR_WANT_READ)
{
session->wstat = ISSL_READ;
return -1;
}
else
{
ServerInstance->Log(DEBUG,"Close due to returned -1 in SSL_Write");
CloseSession(session);
return 0;
}
}
else
{
session->outbuf = session->outbuf.substr(ret);
return ret;
}
}
int DoRead(issl_session* session)
{
// Is this right? Not sure if the unencrypted data is garaunteed to be the same length.
// Read into the inbuffer, offset from the beginning by the amount of data we have that insp hasn't taken yet.
ServerInstance->Log(DEBUG,"DoRead");
int ret = SSL_read(session->sess, session->inbuf + session->inbufoffset, inbufsize - session->inbufoffset);
if (ret == 0)
{
// Client closed connection.
ServerInstance->Log(DEBUG,"Oops, got 0 from SSL_read");
CloseSession(session);
return 0;
}
else if (ret < 0)
{
int err = SSL_get_error(session->sess, ret);
if (err == SSL_ERROR_WANT_READ)
{
session->rstat = ISSL_READ;
ServerInstance->Log(DEBUG,"Setting want_read");
return -1;
}
else if (err == SSL_ERROR_WANT_WRITE)
{
session->rstat = ISSL_WRITE;
ServerInstance->Log(DEBUG,"Setting want_write");
return -1;
}
else
{
ServerInstance->Log(DEBUG,"Closed due to returned -1 in SSL_Read");
CloseSession(session);
return 0;
}
}
else
{
// Read successfully 'ret' bytes into inbuf + inbufoffset
// There are 'ret' + 'inbufoffset' bytes of data in 'inbuf'
// 'buffer' is 'count' long
session->inbufoffset += ret;
return ret;
}
}
// :kenny.chatspike.net 320 Om Epy|AFK :is a Secure Connection
virtual void OnWhois(userrec* source, userrec* dest)
{
if (!clientactive)
return;
// Bugfix, only send this numeric for *our* SSL users
if (dest->GetExt("ssl", dummy) || (IS_LOCAL(dest) && isin(dest->GetPort(), listenports)))
{
ServerInstance->SendWhoisLine(source, dest, 320, "%s %s :is using a secure connection", source->nick, dest->nick);
}
}
virtual void OnSyncUserMetaData(userrec* user, Module* proto, void* opaque, const std::string &extname, bool displayable)
{
// check if the linking module wants to know about OUR metadata
if (extname == "ssl")
{
// check if this user has an swhois field to send
if(user->GetExt(extname, dummy))
{
// call this function in the linking module, let it format the data how it
// sees fit, and send it on its way. We dont need or want to know how.
proto->ProtoSendMetaData(opaque, TYPE_USER, user, extname, displayable ? "Enabled" : "ON");
}
}
}
virtual void OnDecodeMetaData(int target_type, void* target, const std::string &extname, const std::string &extdata)
{
// check if its our metadata key, and its associated with a user
if ((target_type == TYPE_USER) && (extname == "ssl"))
{
userrec* dest = (userrec*)target;
// if they dont already have an ssl flag, accept the remote server's
if (!dest->GetExt(extname, dummy))
{
dest->Extend(extname, "ON");
}
}
}
bool Handshake(issl_session* session)
{
ServerInstance->Log(DEBUG,"Handshake");
int ret;
if (session->outbound)
{
ServerInstance->Log(DEBUG,"SSL_connect");
ret = SSL_connect(session->sess);
}
else
ret = SSL_accept(session->sess);
if (ret < 0)
{
int err = SSL_get_error(session->sess, ret);
if (err == SSL_ERROR_WANT_READ)
{
ServerInstance->Log(DEBUG,"Want read, handshaking");
session->rstat = ISSL_READ;
session->status = ISSL_HANDSHAKING;
return true;
}
else if (err == SSL_ERROR_WANT_WRITE)
{
ServerInstance->Log(DEBUG,"Want write, handshaking");
session->wstat = ISSL_WRITE;
session->status = ISSL_HANDSHAKING;
MakePollWrite(session);
return true;
}
else
{
ServerInstance->Log(DEBUG,"Handshake failed");
CloseSession(session);
}
return false;
}
else if (ret > 0)
{
// Handshake complete.
// This will do for setting the ssl flag...it could be done earlier if it's needed. But this seems neater.
userrec* u = ServerInstance->FindDescriptor(session->fd);
if (u)
{
if (!u->GetExt("ssl", dummy))
u->Extend("ssl", "ON");
}
session->status = ISSL_OPEN;
MakePollWrite(session);
return true;
}
else if (ret == 0)
{
int ssl_err = SSL_get_error(session->sess, ret);
char buf[1024];
ERR_print_errors_fp(stderr);
ServerInstance->Log(DEBUG,"Handshake fail 2: %d: %s", ssl_err, ERR_error_string(ssl_err,buf));
CloseSession(session);
return true;
}
return true;
}
virtual void OnPostConnect(userrec* user)
{
// This occurs AFTER OnUserConnect so we can be sure the
// protocol module has propogated the NICK message.
if ((user->GetExt("ssl", dummy)) && (IS_LOCAL(user)))
{
// Tell whatever protocol module we're using that we need to inform other servers of this metadata NOW.
std::deque<std::string>* metadata = new std::deque<std::string>;
metadata->push_back(user->nick);
metadata->push_back("ssl"); // The metadata id
metadata->push_back("ON"); // The value to send
Event* event = new Event((char*)metadata,(Module*)this,"send_metadata");
event->Send(ServerInstance); // Trigger the event. We don't care what module picks it up.
DELETE(event);
DELETE(metadata);
VerifyCertificate(&sessions[user->GetFd()], user);
if (sessions[user->GetFd()].sess)
user->WriteServ("NOTICE %s :*** You are connected using SSL cipher \"%s\"", user->nick, SSL_get_cipher(sessions[user->GetFd()].sess));
}
}
void MakePollWrite(issl_session* session)
{
OnRawSocketWrite(session->fd, NULL, 0);
//EventHandler* eh = ServerInstance->FindDescriptor(session->fd);
//if (eh)
// ServerInstance->SE->WantWrite(eh);
}
void CloseSession(issl_session* session)
{
if (session->sess)
{
SSL_shutdown(session->sess);
SSL_free(session->sess);
}
if (session->inbuf)
{
delete[] session->inbuf;
}
session->outbuf.clear();
session->inbuf = NULL;
session->sess = NULL;
session->status = ISSL_NONE;
}
void VerifyCertificate(issl_session* session, Extensible* user)
{
if (!session->sess || !user)
return;
X509* cert;
ssl_cert* certinfo = new ssl_cert;
unsigned int n;
unsigned char md[EVP_MAX_MD_SIZE];
const EVP_MD *digest = EVP_md5();
user->Extend("ssl_cert",certinfo);
cert = SSL_get_peer_certificate((SSL*)session->sess);
if (!cert)
{
certinfo->data.insert(std::make_pair("error","Could not get peer certificate: "+std::string(get_error())));
return;
}
certinfo->data.insert(std::make_pair("invalid", SSL_get_verify_result(session->sess) != X509_V_OK ? ConvToStr(1) : ConvToStr(0)));
if (SelfSigned)
{
certinfo->data.insert(std::make_pair("unknownsigner",ConvToStr(0)));
certinfo->data.insert(std::make_pair("trusted",ConvToStr(1)));
}
else
{
certinfo->data.insert(std::make_pair("unknownsigner",ConvToStr(1)));
certinfo->data.insert(std::make_pair("trusted",ConvToStr(0)));
}
certinfo->data.insert(std::make_pair("dn",std::string(X509_NAME_oneline(X509_get_subject_name(cert),0,0))));
certinfo->data.insert(std::make_pair("issuer",std::string(X509_NAME_oneline(X509_get_issuer_name(cert),0,0))));
if (!X509_digest(cert, digest, md, &n))
{
certinfo->data.insert(std::make_pair("error","Out of memory generating fingerprint"));
}
else
{
certinfo->data.insert(std::make_pair("fingerprint",irc::hex(md, n)));
}
if ((ASN1_UTCTIME_cmp_time_t(X509_get_notAfter(cert), time(NULL)) == -1) || (ASN1_UTCTIME_cmp_time_t(X509_get_notBefore(cert), time(NULL)) == 0))
{
certinfo->data.insert(std::make_pair("error","Not activated, or expired certificate"));
}
X509_free(cert);
}
};
static int error_callback(const char *str, size_t len, void *u)
{
ModuleSSLOpenSSL* mssl = (ModuleSSLOpenSSL*)u;
mssl->PublicInstance->Log(DEFAULT, "SSL error: " + std::string(str, len - 1));
return 0;
}
MODULE_INIT(ModuleSSLOpenSSL);
\ No newline at end of file +/* +------------------------------------+ + * | Inspire Internet Relay Chat Daemon | + * +------------------------------------+ + * + * InspIRCd: (C) 2002-2007 InspIRCd Development Team + * See: http://www.inspircd.org/wiki/index.php/Credits + * + * This program is free but copyrighted software; see + * the file COPYING for details. + * + * --------------------------------------------------- + */ + +#include "inspircd.h" + +#include <openssl/ssl.h> +#include <openssl/err.h> + +#ifdef WINDOWS +#include <openssl/applink.c> +#endif + +#include "configreader.h" +#include "users.h" +#include "channels.h" +#include "modules.h" + +#include "socket.h" +#include "hashcomp.h" + +#include "transport.h" + +#ifdef WINDOWS +#pragma comment(lib, "libeay32MTd") +#pragma comment(lib, "ssleay32MTd") +#undef MAX_DESCRIPTORS +#define MAX_DESCRIPTORS 10000 +#endif + +/* $ModDesc: Provides SSL support for clients */ +/* $CompileFlags: pkgconfversion("openssl","0.9.7") pkgconfincludes("openssl","/openssl/ssl.h","") */ +/* $LinkerFlags: rpath("pkg-config --libs openssl") pkgconflibs("openssl","/libssl.so","-lssl -lcrypto -ldl") */ +/* $ModDep: transport.h */ + +enum issl_status { ISSL_NONE, ISSL_HANDSHAKING, ISSL_OPEN }; +enum issl_io_status { ISSL_WRITE, ISSL_READ }; + +static bool SelfSigned = false; + +bool isin(int port, const std::vector<int> &portlist) +{ + for(unsigned int i = 0; i < portlist.size(); i++) + if(portlist[i] == port) + return true; + + return false; +} + +char* get_error() +{ + return ERR_error_string(ERR_get_error(), NULL); +} + +static int error_callback(const char *str, size_t len, void *u); + +/** Represents an SSL user's extra data + */ +class issl_session : public classbase +{ +public: + SSL* sess; + issl_status status; + issl_io_status rstat; + issl_io_status wstat; + + unsigned int inbufoffset; + char* inbuf; // Buffer OpenSSL reads into. + std::string outbuf; // Buffer for outgoing data that OpenSSL will not take. + int fd; + bool outbound; + + issl_session() + { + outbound = false; + rstat = ISSL_READ; + wstat = ISSL_WRITE; + } +}; + +static int OnVerify(int preverify_ok, X509_STORE_CTX *ctx) +{ + /* XXX: This will allow self signed certificates. + * In the future if we want an option to not allow this, + * we can just return preverify_ok here, and openssl + * will boot off self-signed and invalid peer certs. + */ + int ve = X509_STORE_CTX_get_error(ctx); + + SelfSigned = (ve == X509_V_ERR_DEPTH_ZERO_SELF_SIGNED_CERT); + + return 1; +} + +class ModuleSSLOpenSSL : public Module +{ + + ConfigReader* Conf; + + std::vector<int> listenports; + + int inbufsize; + issl_session sessions[MAX_DESCRIPTORS]; + + SSL_CTX* ctx; + SSL_CTX* clictx; + + char* dummy; + char cipher[MAXBUF]; + + std::string keyfile; + std::string certfile; + std::string cafile; + // std::string crlfile; + std::string dhfile; + std::string sslports; + + int clientactive; + + public: + + InspIRCd* PublicInstance; + + ModuleSSLOpenSSL(InspIRCd* Me) + : Module(Me), PublicInstance(Me) + { + ServerInstance->PublishInterface("InspSocketHook", this); + + // Not rehashable...because I cba to reduce all the sizes of existing buffers. + inbufsize = ServerInstance->Config->NetBufferSize; + + /* Global SSL library initialization*/ + SSL_library_init(); + SSL_load_error_strings(); + + /* Build our SSL contexts: + * NOTE: OpenSSL makes us have two contexts, one for servers and one for clients. ICK. + */ + ctx = SSL_CTX_new( SSLv23_server_method() ); + clictx = SSL_CTX_new( SSLv23_client_method() ); + + SSL_CTX_set_verify(ctx, SSL_VERIFY_PEER | SSL_VERIFY_CLIENT_ONCE, OnVerify); + SSL_CTX_set_verify(clictx, SSL_VERIFY_PEER | SSL_VERIFY_CLIENT_ONCE, OnVerify); + + // Needs the flag as it ignores a plain /rehash + OnRehash(NULL,"ssl"); + } + + virtual void OnRehash(userrec* user, const std::string ¶m) + { + if (param != "ssl") + return; + + Conf = new ConfigReader(ServerInstance); + + for (unsigned int i = 0; i < listenports.size(); i++) + { + ServerInstance->Config->DelIOHook(listenports[i]); + } + + listenports.clear(); + clientactive = 0; + sslports.clear(); + + for (int i = 0; i < Conf->Enumerate("bind"); i++) + { + // For each <bind> tag + std::string x = Conf->ReadValue("bind", "type", i); + if (((x.empty()) || (x == "clients")) && (Conf->ReadValue("bind", "ssl", i) == "openssl")) + { + // Get the port we're meant to be listening on with SSL + std::string port = Conf->ReadValue("bind", "port", i); + irc::portparser portrange(port, false); + long portno = -1; + while ((portno = portrange.GetToken())) + { + clientactive++; + try + { + if (ServerInstance->Config->AddIOHook(portno, this)) + { + listenports.push_back(portno); + for (size_t i = 0; i < ServerInstance->Config->ports.size(); i++) + if (ServerInstance->Config->ports[i]->GetPort() == portno) + ServerInstance->Config->ports[i]->SetDescription("ssl"); + ServerInstance->Log(DEFAULT, "m_ssl_openssl.so: Enabling SSL for port %d", portno); + sslports.append("*:").append(ConvToStr(portno)).append(";"); + } + else + { + ServerInstance->Log(DEFAULT, "m_ssl_openssl.so: FAILED to enable SSL on port %d, maybe you have another ssl or similar module loaded?", portno); + } + } + catch (ModuleException &e) + { + ServerInstance->Log(DEFAULT, "m_ssl_openssl.so: FAILED to enable SSL on port %d: %s. Maybe it's already hooked by the same port on a different IP, or you have another SSL or similar module loaded?", portno, e.GetReason()); + } + } + } + } + + if (!sslports.empty()) + sslports.erase(sslports.end() - 1); + + std::string confdir(ServerInstance->ConfigFileName); + // +1 so we the path ends with a / + confdir = confdir.substr(0, confdir.find_last_of('/') + 1); + + cafile = Conf->ReadValue("openssl", "cafile", 0); + certfile = Conf->ReadValue("openssl", "certfile", 0); + keyfile = Conf->ReadValue("openssl", "keyfile", 0); + dhfile = Conf->ReadValue("openssl", "dhfile", 0); + + // Set all the default values needed. + if (cafile.empty()) + cafile = "ca.pem"; + + if (certfile.empty()) + certfile = "cert.pem"; + + if (keyfile.empty()) + keyfile = "key.pem"; + + if (dhfile.empty()) + dhfile = "dhparams.pem"; + + // Prepend relative paths with the path to the config directory. + if (cafile[0] != '/') + cafile = confdir + cafile; + + if (certfile[0] != '/') + certfile = confdir + certfile; + + if (keyfile[0] != '/') + keyfile = confdir + keyfile; + + if (dhfile[0] != '/') + dhfile = confdir + dhfile; + + /* Load our keys and certificates + * NOTE: OpenSSL's error logging API sucks, don't blame us for this clusterfuck. + */ + if ((!SSL_CTX_use_certificate_chain_file(ctx, certfile.c_str())) || (!SSL_CTX_use_certificate_chain_file(clictx, certfile.c_str()))) + { + ServerInstance->Log(DEFAULT, "m_ssl_openssl.so: Can't read certificate file %s. %s", certfile.c_str(), strerror(errno)); + ERR_print_errors_cb(error_callback, this); + } + + if (((!SSL_CTX_use_PrivateKey_file(ctx, keyfile.c_str(), SSL_FILETYPE_PEM))) || (!SSL_CTX_use_PrivateKey_file(clictx, keyfile.c_str(), SSL_FILETYPE_PEM))) + { + ServerInstance->Log(DEFAULT, "m_ssl_openssl.so: Can't read key file %s. %s", keyfile.c_str(), strerror(errno)); + ERR_print_errors_cb(error_callback, this); + } + + /* Load the CAs we trust*/ + if (((!SSL_CTX_load_verify_locations(ctx, cafile.c_str(), 0))) || (!SSL_CTX_load_verify_locations(clictx, cafile.c_str(), 0))) + { + ServerInstance->Log(DEFAULT, "m_ssl_openssl.so: Can't read CA list from %s. %s", cafile.c_str(), strerror(errno)); + ERR_print_errors_cb(error_callback, this); + } + + FILE* dhpfile = fopen(dhfile.c_str(), "r"); + DH* ret; + + if (dhpfile == NULL) + { + ServerInstance->Log(DEFAULT, "m_ssl_openssl.so Couldn't open DH file %s: %s", dhfile.c_str(), strerror(errno)); + throw ModuleException("Couldn't open DH file " + dhfile + ": " + strerror(errno)); + } + else + { + ret = PEM_read_DHparams(dhpfile, NULL, NULL, NULL); + if ((SSL_CTX_set_tmp_dh(ctx, ret) < 0) || (SSL_CTX_set_tmp_dh(clictx, ret) < 0)) + { + ServerInstance->Log(DEFAULT, "m_ssl_openssl.so: Couldn't set DH parameters %s. SSL errors follow:", dhfile.c_str()); + ERR_print_errors_cb(error_callback, this); + } + } + + fclose(dhpfile); + + DELETE(Conf); + } + + virtual void On005Numeric(std::string &output) + { + output.append(" SSL=" + sslports); + } + + virtual ~ModuleSSLOpenSSL() + { + SSL_CTX_free(ctx); + SSL_CTX_free(clictx); + } + + virtual void OnCleanup(int target_type, void* item) + { + if (target_type == TYPE_USER) + { + userrec* user = (userrec*)item; + + if (user->GetExt("ssl", dummy) && IS_LOCAL(user) && isin(user->GetPort(), listenports)) + { + // User is using SSL, they're a local user, and they're using one of *our* SSL ports. + // Potentially there could be multiple SSL modules loaded at once on different ports. + ServerInstance->GlobalCulls.AddItem(user, "SSL module unloading"); + } + if (user->GetExt("ssl_cert", dummy) && isin(user->GetPort(), listenports)) + { + ssl_cert* tofree; + user->GetExt("ssl_cert", tofree); + delete tofree; + user->Shrink("ssl_cert"); + } + } + } + + virtual void OnUnloadModule(Module* mod, const std::string &name) + { + if (mod == this) + { + for(unsigned int i = 0; i < listenports.size(); i++) + { + ServerInstance->Config->DelIOHook(listenports[i]); + for (size_t j = 0; j < ServerInstance->Config->ports.size(); j++) + if (ServerInstance->Config->ports[j]->GetPort() == listenports[i]) + ServerInstance->Config->ports[j]->SetDescription("plaintext"); + } + } + } + + virtual Version GetVersion() + { + return Version(1, 1, 0, 0, VF_VENDOR, API_VERSION); + } + + void Implements(char* List) + { + List[I_OnRawSocketConnect] = List[I_OnRawSocketAccept] = List[I_OnRawSocketClose] = List[I_OnRawSocketRead] = List[I_OnRawSocketWrite] = List[I_OnCleanup] = List[I_On005Numeric] = 1; + List[I_OnRequest] = List[I_OnSyncUserMetaData] = List[I_OnDecodeMetaData] = List[I_OnUnloadModule] = List[I_OnRehash] = List[I_OnWhois] = List[I_OnPostConnect] = 1; + } + + virtual char* OnRequest(Request* request) + { + ISHRequest* ISR = (ISHRequest*)request; + if (strcmp("IS_NAME", request->GetId()) == 0) + { + return "openssl"; + } + else if (strcmp("IS_HOOK", request->GetId()) == 0) + { + char* ret = "OK"; + try + { + ret = ServerInstance->Config->AddIOHook((Module*)this, (InspSocket*)ISR->Sock) ? (char*)"OK" : NULL; + } + catch (ModuleException &e) + { + return NULL; + } + + return ret; + } + else if (strcmp("IS_UNHOOK", request->GetId()) == 0) + { + return ServerInstance->Config->DelIOHook((InspSocket*)ISR->Sock) ? (char*)"OK" : NULL; + } + else if (strcmp("IS_HSDONE", request->GetId()) == 0) + { + ServerInstance->Log(DEBUG,"Module checking if handshake is done"); + if (ISR->Sock->GetFd() < 0) + return (char*)"OK"; + + issl_session* session = &sessions[ISR->Sock->GetFd()]; + return (session->status == ISSL_HANDSHAKING) ? NULL : (char*)"OK"; + } + else if (strcmp("IS_ATTACH", request->GetId()) == 0) + { + issl_session* session = &sessions[ISR->Sock->GetFd()]; + if (session->sess) + { + VerifyCertificate(session, (InspSocket*)ISR->Sock); + return "OK"; + } + } + return NULL; + } + + + virtual void OnRawSocketAccept(int fd, const std::string &ip, int localport) + { + issl_session* session = &sessions[fd]; + + session->fd = fd; + session->inbuf = new char[inbufsize]; + session->inbufoffset = 0; + session->sess = SSL_new(ctx); + session->status = ISSL_NONE; + session->outbound = false; + + if (session->sess == NULL) + return; + + if (SSL_set_fd(session->sess, fd) == 0) + { + ServerInstance->Log(DEBUG,"BUG: Can't set fd with SSL_set_fd: %d", fd); + return; + } + + Handshake(session); + } + + virtual void OnRawSocketConnect(int fd) + { + ServerInstance->Log(DEBUG,"OnRawSocketConnect connecting"); + issl_session* session = &sessions[fd]; + + session->fd = fd; + session->inbuf = new char[inbufsize]; + session->inbufoffset = 0; + session->sess = SSL_new(clictx); + session->status = ISSL_NONE; + session->outbound = true; + + if (session->sess == NULL) + return; + + if (SSL_set_fd(session->sess, fd) == 0) + { + ServerInstance->Log(DEBUG,"BUG: Can't set fd with SSL_set_fd: %d", fd); + return; + } + + Handshake(session); + ServerInstance->Log(DEBUG,"Exiting OnRawSocketConnect"); + } + + virtual void OnRawSocketClose(int fd) + { + CloseSession(&sessions[fd]); + + EventHandler* user = ServerInstance->SE->GetRef(fd); + + if ((user) && (user->GetExt("ssl_cert", dummy))) + { + ssl_cert* tofree; + user->GetExt("ssl_cert", tofree); + delete tofree; + user->Shrink("ssl_cert"); + } + } + + virtual int OnRawSocketRead(int fd, char* buffer, unsigned int count, int &readresult) + { + issl_session* session = &sessions[fd]; + + ServerInstance->Log(DEBUG,"OnRawSocketRead"); + + if (!session->sess) + { + ServerInstance->Log(DEBUG,"OnRawSocketRead has no session"); + readresult = 0; + CloseSession(session); + return 1; + } + + if (session->status == ISSL_HANDSHAKING) + { + if (session->rstat == ISSL_READ || session->wstat == ISSL_READ) + { + ServerInstance->Log(DEBUG,"Resume handshake in read"); + // The handshake isn't finished and it wants to read, try to finish it. + if (!Handshake(session)) + { + ServerInstance->Log(DEBUG,"Cant resume handshake in read"); + // Couldn't resume handshake. + return -1; + } + } + else + { + errno = EAGAIN; + return -1; + } + } + + // If we resumed the handshake then session->status will be ISSL_OPEN + + if (session->status == ISSL_OPEN) + { + if (session->wstat == ISSL_READ) + { + if(DoWrite(session) == 0) + return 0; + } + + if (session->rstat == ISSL_READ) + { + int ret = DoRead(session); + + if (ret > 0) + { + if (count <= session->inbufoffset) + { + memcpy(buffer, session->inbuf, count); + // Move the stuff left in inbuf to the beginning of it + memcpy(session->inbuf, session->inbuf + count, (session->inbufoffset - count)); + // Now we need to set session->inbufoffset to the amount of data still waiting to be handed to insp. + session->inbufoffset -= count; + // Insp uses readresult as the count of how much data there is in buffer, so: + readresult = count; + } + else + { + // There's not as much in the inbuf as there is space in the buffer, so just copy the whole thing. + memcpy(buffer, session->inbuf, session->inbufoffset); + + readresult = session->inbufoffset; + // Zero the offset, as there's nothing there.. + session->inbufoffset = 0; + } + + return 1; + } + else + { + return ret; + } + } + } + + return -1; + } + + virtual int OnRawSocketWrite(int fd, const char* buffer, int count) + { + issl_session* session = &sessions[fd]; + + if (!session->sess) + { + ServerInstance->Log(DEBUG,"Close session missing sess"); + CloseSession(session); + return -1; + } + + session->outbuf.append(buffer, count); + + if (session->status == ISSL_HANDSHAKING) + { + // The handshake isn't finished, try to finish it. + if (session->rstat == ISSL_WRITE || session->wstat == ISSL_WRITE) + { + ServerInstance->Log(DEBUG,"Handshake resume"); + Handshake(session); + } + } + + if (session->status == ISSL_OPEN) + { + if (session->rstat == ISSL_WRITE) + { + ServerInstance->Log(DEBUG,"DoRead"); + DoRead(session); + } + + if (session->wstat == ISSL_WRITE) + { + ServerInstance->Log(DEBUG,"DoWrite"); + return DoWrite(session); + } + } + + return 1; + } + + int DoWrite(issl_session* session) + { + if (!session->outbuf.size()) + return -1; + + int ret = SSL_write(session->sess, session->outbuf.data(), session->outbuf.size()); + + if (ret == 0) + { + ServerInstance->Log(DEBUG,"Oops, got 0 from SSL_write"); + CloseSession(session); + return 0; + } + else if (ret < 0) + { + int err = SSL_get_error(session->sess, ret); + + if (err == SSL_ERROR_WANT_WRITE) + { + session->wstat = ISSL_WRITE; + return -1; + } + else if (err == SSL_ERROR_WANT_READ) + { + session->wstat = ISSL_READ; + return -1; + } + else + { + ServerInstance->Log(DEBUG,"Close due to returned -1 in SSL_Write"); + CloseSession(session); + return 0; + } + } + else + { + session->outbuf = session->outbuf.substr(ret); + return ret; + } + } + + int DoRead(issl_session* session) + { + // Is this right? Not sure if the unencrypted data is garaunteed to be the same length. + // Read into the inbuffer, offset from the beginning by the amount of data we have that insp hasn't taken yet. + + ServerInstance->Log(DEBUG,"DoRead"); + + int ret = SSL_read(session->sess, session->inbuf + session->inbufoffset, inbufsize - session->inbufoffset); + + if (ret == 0) + { + // Client closed connection. + ServerInstance->Log(DEBUG,"Oops, got 0 from SSL_read"); + CloseSession(session); + return 0; + } + else if (ret < 0) + { + int err = SSL_get_error(session->sess, ret); + + if (err == SSL_ERROR_WANT_READ) + { + session->rstat = ISSL_READ; + ServerInstance->Log(DEBUG,"Setting want_read"); + return -1; + } + else if (err == SSL_ERROR_WANT_WRITE) + { + session->rstat = ISSL_WRITE; + ServerInstance->Log(DEBUG,"Setting want_write"); + return -1; + } + else + { + ServerInstance->Log(DEBUG,"Closed due to returned -1 in SSL_Read"); + CloseSession(session); + return 0; + } + } + else + { + // Read successfully 'ret' bytes into inbuf + inbufoffset + // There are 'ret' + 'inbufoffset' bytes of data in 'inbuf' + // 'buffer' is 'count' long + + session->inbufoffset += ret; + + return ret; + } + } + + // :kenny.chatspike.net 320 Om Epy|AFK :is a Secure Connection + virtual void OnWhois(userrec* source, userrec* dest) + { + if (!clientactive) + return; + + // Bugfix, only send this numeric for *our* SSL users + if (dest->GetExt("ssl", dummy) || (IS_LOCAL(dest) && isin(dest->GetPort(), listenports))) + { + ServerInstance->SendWhoisLine(source, dest, 320, "%s %s :is using a secure connection", source->nick, dest->nick); + } + } + + virtual void OnSyncUserMetaData(userrec* user, Module* proto, void* opaque, const std::string &extname, bool displayable) + { + // check if the linking module wants to know about OUR metadata + if (extname == "ssl") + { + // check if this user has an swhois field to send + if(user->GetExt(extname, dummy)) + { + // call this function in the linking module, let it format the data how it + // sees fit, and send it on its way. We dont need or want to know how. + proto->ProtoSendMetaData(opaque, TYPE_USER, user, extname, displayable ? "Enabled" : "ON"); + } + } + } + + virtual void OnDecodeMetaData(int target_type, void* target, const std::string &extname, const std::string &extdata) + { + // check if its our metadata key, and its associated with a user + if ((target_type == TYPE_USER) && (extname == "ssl")) + { + userrec* dest = (userrec*)target; + // if they dont already have an ssl flag, accept the remote server's + if (!dest->GetExt(extname, dummy)) + { + dest->Extend(extname, "ON"); + } + } + } + + bool Handshake(issl_session* session) + { + ServerInstance->Log(DEBUG,"Handshake"); + int ret; + + if (session->outbound) + { + ServerInstance->Log(DEBUG,"SSL_connect"); + ret = SSL_connect(session->sess); + } + else + ret = SSL_accept(session->sess); + + if (ret < 0) + { + int err = SSL_get_error(session->sess, ret); + + if (err == SSL_ERROR_WANT_READ) + { + ServerInstance->Log(DEBUG,"Want read, handshaking"); + session->rstat = ISSL_READ; + session->status = ISSL_HANDSHAKING; + return true; + } + else if (err == SSL_ERROR_WANT_WRITE) + { + ServerInstance->Log(DEBUG,"Want write, handshaking"); + session->wstat = ISSL_WRITE; + session->status = ISSL_HANDSHAKING; + MakePollWrite(session); + return true; + } + else + { + ServerInstance->Log(DEBUG,"Handshake failed"); + CloseSession(session); + } + + return false; + } + else if (ret > 0) + { + // Handshake complete. + // This will do for setting the ssl flag...it could be done earlier if it's needed. But this seems neater. + userrec* u = ServerInstance->FindDescriptor(session->fd); + if (u) + { + if (!u->GetExt("ssl", dummy)) + u->Extend("ssl", "ON"); + } + + session->status = ISSL_OPEN; + + MakePollWrite(session); + + return true; + } + else if (ret == 0) + { + int ssl_err = SSL_get_error(session->sess, ret); + char buf[1024]; + ERR_print_errors_fp(stderr); + ServerInstance->Log(DEBUG,"Handshake fail 2: %d: %s", ssl_err, ERR_error_string(ssl_err,buf)); + CloseSession(session); + return true; + } + + return true; + } + + virtual void OnPostConnect(userrec* user) + { + // This occurs AFTER OnUserConnect so we can be sure the + // protocol module has propogated the NICK message. + if ((user->GetExt("ssl", dummy)) && (IS_LOCAL(user))) + { + // Tell whatever protocol module we're using that we need to inform other servers of this metadata NOW. + std::deque<std::string>* metadata = new std::deque<std::string>; + metadata->push_back(user->nick); + metadata->push_back("ssl"); // The metadata id + metadata->push_back("ON"); // The value to send + Event* event = new Event((char*)metadata,(Module*)this,"send_metadata"); + event->Send(ServerInstance); // Trigger the event. We don't care what module picks it up. + DELETE(event); + DELETE(metadata); + + VerifyCertificate(&sessions[user->GetFd()], user); + if (sessions[user->GetFd()].sess) + user->WriteServ("NOTICE %s :*** You are connected using SSL cipher \"%s\"", user->nick, SSL_get_cipher(sessions[user->GetFd()].sess)); + } + } + + void MakePollWrite(issl_session* session) + { + OnRawSocketWrite(session->fd, NULL, 0); + //EventHandler* eh = ServerInstance->FindDescriptor(session->fd); + //if (eh) + // ServerInstance->SE->WantWrite(eh); + } + + void CloseSession(issl_session* session) + { + if (session->sess) + { + SSL_shutdown(session->sess); + SSL_free(session->sess); + } + + if (session->inbuf) + { + delete[] session->inbuf; + } + + session->outbuf.clear(); + session->inbuf = NULL; + session->sess = NULL; + session->status = ISSL_NONE; + } + + void VerifyCertificate(issl_session* session, Extensible* user) + { + if (!session->sess || !user) + return; + + X509* cert; + ssl_cert* certinfo = new ssl_cert; + unsigned int n; + unsigned char md[EVP_MAX_MD_SIZE]; + const EVP_MD *digest = EVP_md5(); + + user->Extend("ssl_cert",certinfo); + + cert = SSL_get_peer_certificate((SSL*)session->sess); + + if (!cert) + { + certinfo->data.insert(std::make_pair("error","Could not get peer certificate: "+std::string(get_error()))); + return; + } + + certinfo->data.insert(std::make_pair("invalid", SSL_get_verify_result(session->sess) != X509_V_OK ? ConvToStr(1) : ConvToStr(0))); + + if (SelfSigned) + { + certinfo->data.insert(std::make_pair("unknownsigner",ConvToStr(0))); + certinfo->data.insert(std::make_pair("trusted",ConvToStr(1))); + } + else + { + certinfo->data.insert(std::make_pair("unknownsigner",ConvToStr(1))); + certinfo->data.insert(std::make_pair("trusted",ConvToStr(0))); + } + + certinfo->data.insert(std::make_pair("dn",std::string(X509_NAME_oneline(X509_get_subject_name(cert),0,0)))); + certinfo->data.insert(std::make_pair("issuer",std::string(X509_NAME_oneline(X509_get_issuer_name(cert),0,0)))); + + if (!X509_digest(cert, digest, md, &n)) + { + certinfo->data.insert(std::make_pair("error","Out of memory generating fingerprint")); + } + else + { + certinfo->data.insert(std::make_pair("fingerprint",irc::hex(md, n))); + } + + if ((ASN1_UTCTIME_cmp_time_t(X509_get_notAfter(cert), time(NULL)) == -1) || (ASN1_UTCTIME_cmp_time_t(X509_get_notBefore(cert), time(NULL)) == 0)) + { + certinfo->data.insert(std::make_pair("error","Not activated, or expired certificate")); + } + + X509_free(cert); + } +}; + +static int error_callback(const char *str, size_t len, void *u) +{ + ModuleSSLOpenSSL* mssl = (ModuleSSLOpenSSL*)u; + mssl->PublicInstance->Log(DEFAULT, "SSL error: " + std::string(str, len - 1)); + return 0; +} + +MODULE_INIT(ModuleSSLOpenSSL); + diff --git a/src/modules/extra/m_ssl_oper_cert.cpp b/src/modules/extra/m_ssl_oper_cert.cpp index 7b1c90868..c67b50c8c 100644 --- a/src/modules/extra/m_ssl_oper_cert.cpp +++ b/src/modules/extra/m_ssl_oper_cert.cpp @@ -1 +1,180 @@ -/* +------------------------------------+
* | Inspire Internet Relay Chat Daemon |
* +------------------------------------+
*
* InspIRCd: (C) 2002-2007 InspIRCd Development Team
* See: http://www.inspircd.org/wiki/index.php/Credits
*
* This program is free but copyrighted software; see
* the file COPYING for details.
*
* ---------------------------------------------------
*/
/* $ModDesc: Allows for MD5 encrypted oper passwords */
/* $ModDep: transport.h */
#include "inspircd.h"
#include "inspircd_config.h"
#include "users.h"
#include "channels.h"
#include "modules.h"
#include "transport.h"
#include "wildcard.h"
/** Handle /FINGERPRINT
*/
class cmd_fingerprint : public command_t
{
public:
cmd_fingerprint (InspIRCd* Instance) : command_t(Instance,"FINGERPRINT", 0, 1)
{
this->source = "m_ssl_oper_cert.so";
syntax = "<nickname>";
}
CmdResult Handle (const char** parameters, int pcnt, userrec *user)
{
userrec* target = ServerInstance->FindNick(parameters[0]);
if (target)
{
ssl_cert* cert;
if (target->GetExt("ssl_cert",cert))
{
if (cert->GetFingerprint().length())
{
user->WriteServ("NOTICE %s :Certificate fingerprint for %s is %s",user->nick,target->nick,cert->GetFingerprint().c_str());
return CMD_SUCCESS;
}
else
{
user->WriteServ("NOTICE %s :Certificate fingerprint for %s does not exist!", user->nick,target->nick);
return CMD_FAILURE;
}
}
else
{
user->WriteServ("NOTICE %s :Certificate fingerprint for %s does not exist!", user->nick, target->nick);
return CMD_FAILURE;
}
}
else
{
user->WriteServ("401 %s %s :No such nickname", user->nick, parameters[0]);
return CMD_FAILURE;
}
}
};
class ModuleOperSSLCert : public Module
{
ssl_cert* cert;
bool HasCert;
cmd_fingerprint* mycommand;
ConfigReader* cf;
public:
ModuleOperSSLCert(InspIRCd* Me)
: Module(Me)
{
mycommand = new cmd_fingerprint(ServerInstance);
ServerInstance->AddCommand(mycommand);
cf = new ConfigReader(ServerInstance);
}
virtual ~ModuleOperSSLCert()
{
delete cf;
}
void Implements(char* List)
{
List[I_OnPreCommand] = List[I_OnRehash] = 1;
}
virtual void OnRehash(userrec* user, const std::string ¶meter)
{
delete cf;
cf = new ConfigReader(ServerInstance);
}
bool OneOfMatches(const char* host, const char* ip, const char* hostlist)
{
std::stringstream hl(hostlist);
std::string xhost;
while (hl >> xhost)
{
if (match(host,xhost.c_str()) || match(ip,xhost.c_str(),true))
{
return true;
}
}
return false;
}
virtual int OnPreCommand(const std::string &command, const char** parameters, int pcnt, userrec *user, bool validated, const std::string &original_line)
{
irc::string cmd = command.c_str();
if ((cmd == "OPER") && (validated))
{
char TheHost[MAXBUF];
char TheIP[MAXBUF];
std::string LoginName;
std::string Password;
std::string OperType;
std::string HostName;
std::string FingerPrint;
bool SSLOnly;
char* dummy;
snprintf(TheHost,MAXBUF,"%s@%s",user->ident,user->host);
snprintf(TheIP, MAXBUF,"%s@%s",user->ident,user->GetIPString());
HasCert = user->GetExt("ssl_cert",cert);
for (int i = 0; i < cf->Enumerate("oper"); i++)
{
LoginName = cf->ReadValue("oper", "name", i);
Password = cf->ReadValue("oper", "password", i);
OperType = cf->ReadValue("oper", "type", i);
HostName = cf->ReadValue("oper", "host", i);
FingerPrint = cf->ReadValue("oper", "fingerprint", i);
SSLOnly = cf->ReadFlag("oper", "sslonly", i);
if (SSLOnly || !FingerPrint.empty())
{
if ((!strcmp(LoginName.c_str(),parameters[0])) && (!ServerInstance->OperPassCompare(Password.c_str(),parameters[1],i)) && (OneOfMatches(TheHost,TheIP,HostName.c_str())))
{
if (SSLOnly && !user->GetExt("ssl", dummy))
{
user->WriteServ("491 %s :This oper login name requires an SSL connection.", user->nick);
return 1;
}
/* This oper would match */
if ((!cert) || (cert->GetFingerprint() != FingerPrint))
{
user->WriteServ("491 %s :This oper login name requires a matching key fingerprint.",user->nick);
ServerInstance->SNO->WriteToSnoMask('o',"'%s' cannot oper, does not match fingerprint", user->nick);
ServerInstance->Log(DEFAULT,"OPER: Failed oper attempt by %s!%s@%s: credentials valid, but wrong fingerprint.",user->nick,user->ident,user->host);
return 1;
}
}
}
}
}
return 0;
}
virtual Version GetVersion()
{
return Version(1,1,0,0,VF_VENDOR,API_VERSION);
}
};
MODULE_INIT(ModuleOperSSLCert);
\ No newline at end of file +/* +------------------------------------+ + * | Inspire Internet Relay Chat Daemon | + * +------------------------------------+ + * + * InspIRCd: (C) 2002-2007 InspIRCd Development Team + * See: http://www.inspircd.org/wiki/index.php/Credits + * + * This program is free but copyrighted software; see + * the file COPYING for details. + * + * --------------------------------------------------- + */ + +/* $ModDesc: Allows for MD5 encrypted oper passwords */ +/* $ModDep: transport.h */ + +#include "inspircd.h" +#include "inspircd_config.h" +#include "users.h" +#include "channels.h" +#include "modules.h" +#include "transport.h" +#include "wildcard.h" + +/** Handle /FINGERPRINT + */ +class cmd_fingerprint : public command_t +{ + public: + cmd_fingerprint (InspIRCd* Instance) : command_t(Instance,"FINGERPRINT", 0, 1) + { + this->source = "m_ssl_oper_cert.so"; + syntax = "<nickname>"; + } + + CmdResult Handle (const char** parameters, int pcnt, userrec *user) + { + userrec* target = ServerInstance->FindNick(parameters[0]); + if (target) + { + ssl_cert* cert; + if (target->GetExt("ssl_cert",cert)) + { + if (cert->GetFingerprint().length()) + { + user->WriteServ("NOTICE %s :Certificate fingerprint for %s is %s",user->nick,target->nick,cert->GetFingerprint().c_str()); + return CMD_SUCCESS; + } + else + { + user->WriteServ("NOTICE %s :Certificate fingerprint for %s does not exist!", user->nick,target->nick); + return CMD_FAILURE; + } + } + else + { + user->WriteServ("NOTICE %s :Certificate fingerprint for %s does not exist!", user->nick, target->nick); + return CMD_FAILURE; + } + } + else + { + user->WriteServ("401 %s %s :No such nickname", user->nick, parameters[0]); + return CMD_FAILURE; + } + } +}; + + + +class ModuleOperSSLCert : public Module +{ + ssl_cert* cert; + bool HasCert; + cmd_fingerprint* mycommand; + ConfigReader* cf; + public: + + ModuleOperSSLCert(InspIRCd* Me) + : Module(Me) + { + mycommand = new cmd_fingerprint(ServerInstance); + ServerInstance->AddCommand(mycommand); + cf = new ConfigReader(ServerInstance); + } + + virtual ~ModuleOperSSLCert() + { + delete cf; + } + + void Implements(char* List) + { + List[I_OnPreCommand] = List[I_OnRehash] = 1; + } + + virtual void OnRehash(userrec* user, const std::string ¶meter) + { + delete cf; + cf = new ConfigReader(ServerInstance); + } + + bool OneOfMatches(const char* host, const char* ip, const char* hostlist) + { + std::stringstream hl(hostlist); + std::string xhost; + while (hl >> xhost) + { + if (match(host,xhost.c_str()) || match(ip,xhost.c_str(),true)) + { + return true; + } + } + return false; + } + + + virtual int OnPreCommand(const std::string &command, const char** parameters, int pcnt, userrec *user, bool validated, const std::string &original_line) + { + irc::string cmd = command.c_str(); + + if ((cmd == "OPER") && (validated)) + { + char TheHost[MAXBUF]; + char TheIP[MAXBUF]; + std::string LoginName; + std::string Password; + std::string OperType; + std::string HostName; + std::string FingerPrint; + bool SSLOnly; + char* dummy; + + snprintf(TheHost,MAXBUF,"%s@%s",user->ident,user->host); + snprintf(TheIP, MAXBUF,"%s@%s",user->ident,user->GetIPString()); + + HasCert = user->GetExt("ssl_cert",cert); + + for (int i = 0; i < cf->Enumerate("oper"); i++) + { + LoginName = cf->ReadValue("oper", "name", i); + Password = cf->ReadValue("oper", "password", i); + OperType = cf->ReadValue("oper", "type", i); + HostName = cf->ReadValue("oper", "host", i); + FingerPrint = cf->ReadValue("oper", "fingerprint", i); + SSLOnly = cf->ReadFlag("oper", "sslonly", i); + + if (SSLOnly || !FingerPrint.empty()) + { + if ((!strcmp(LoginName.c_str(),parameters[0])) && (!ServerInstance->OperPassCompare(Password.c_str(),parameters[1],i)) && (OneOfMatches(TheHost,TheIP,HostName.c_str()))) + { + if (SSLOnly && !user->GetExt("ssl", dummy)) + { + user->WriteServ("491 %s :This oper login name requires an SSL connection.", user->nick); + return 1; + } + + /* This oper would match */ + if ((!cert) || (cert->GetFingerprint() != FingerPrint)) + { + user->WriteServ("491 %s :This oper login name requires a matching key fingerprint.",user->nick); + ServerInstance->SNO->WriteToSnoMask('o',"'%s' cannot oper, does not match fingerprint", user->nick); + ServerInstance->Log(DEFAULT,"OPER: Failed oper attempt by %s!%s@%s: credentials valid, but wrong fingerprint.",user->nick,user->ident,user->host); + return 1; + } + } + } + } + } + return 0; + } + + virtual Version GetVersion() + { + return Version(1,1,0,0,VF_VENDOR,API_VERSION); + } +}; + +MODULE_INIT(ModuleOperSSLCert); + diff --git a/src/modules/extra/m_sslinfo.cpp b/src/modules/extra/m_sslinfo.cpp index 83de798c8..dc9274f1e 100644 --- a/src/modules/extra/m_sslinfo.cpp +++ b/src/modules/extra/m_sslinfo.cpp @@ -1 +1,94 @@ -/* +------------------------------------+
* | Inspire Internet Relay Chat Daemon |
* +------------------------------------+
*
* InspIRCd: (C) 2002-2007 InspIRCd Development Team
* See: http://www.inspircd.org/wiki/index.php/Credits
*
* This program is free but copyrighted software; see
* the file COPYING for details.
*
* ---------------------------------------------------
*/
#include "inspircd.h"
#include "users.h"
#include "channels.h"
#include "modules.h"
#include "transport.h"
#include "wildcard.h"
#include "dns.h"
/* $ModDesc: Provides /sslinfo command used to test who a mask matches */
/* $ModDep: transport.h */
/** Handle /SSLINFO
*/
class cmd_sslinfo : public command_t
{
public:
cmd_sslinfo (InspIRCd* Instance) : command_t(Instance,"SSLINFO", 0, 1)
{
this->source = "m_sslinfo.so";
this->syntax = "<nick>";
}
CmdResult Handle (const char** parameters, int pcnt, userrec *user)
{
userrec* target = ServerInstance->FindNick(parameters[0]);
ssl_cert* cert;
if (target)
{
if (target->GetExt("ssl_cert", cert))
{
if (cert->GetError().length())
{
user->WriteServ("NOTICE %s :*** Error: %s", user->nick, cert->GetError().c_str());
}
user->WriteServ("NOTICE %s :*** Distinguised Name: %s", user->nick, cert->GetDN().c_str());
user->WriteServ("NOTICE %s :*** Issuer: %s", user->nick, cert->GetIssuer().c_str());
user->WriteServ("NOTICE %s :*** Key Fingerprint: %s", user->nick, cert->GetFingerprint().c_str());
return CMD_SUCCESS;
}
else
{
user->WriteServ("NOTICE %s :*** No SSL certificate information for this user.", user->nick);
return CMD_FAILURE;
}
}
else
user->WriteServ("401 %s %s :No such nickname", user->nick, parameters[0]);
return CMD_FAILURE;
}
};
class ModuleSSLInfo : public Module
{
cmd_sslinfo* newcommand;
public:
ModuleSSLInfo(InspIRCd* Me)
: Module(Me)
{
newcommand = new cmd_sslinfo(ServerInstance);
ServerInstance->AddCommand(newcommand);
}
void Implements(char* List)
{
}
virtual ~ModuleSSLInfo()
{
}
virtual Version GetVersion()
{
return Version(1, 1, 0, 0, VF_VENDOR, API_VERSION);
}
};
MODULE_INIT(ModuleSSLInfo);
\ No newline at end of file +/* +------------------------------------+ + * | Inspire Internet Relay Chat Daemon | + * +------------------------------------+ + * + * InspIRCd: (C) 2002-2007 InspIRCd Development Team + * See: http://www.inspircd.org/wiki/index.php/Credits + * + * This program is free but copyrighted software; see + * the file COPYING for details. + * + * --------------------------------------------------- + */ + +#include "inspircd.h" +#include "users.h" +#include "channels.h" +#include "modules.h" +#include "transport.h" +#include "wildcard.h" +#include "dns.h" + +/* $ModDesc: Provides /sslinfo command used to test who a mask matches */ +/* $ModDep: transport.h */ + +/** Handle /SSLINFO + */ +class cmd_sslinfo : public command_t +{ + public: + cmd_sslinfo (InspIRCd* Instance) : command_t(Instance,"SSLINFO", 0, 1) + { + this->source = "m_sslinfo.so"; + this->syntax = "<nick>"; + } + + CmdResult Handle (const char** parameters, int pcnt, userrec *user) + { + userrec* target = ServerInstance->FindNick(parameters[0]); + ssl_cert* cert; + + if (target) + { + if (target->GetExt("ssl_cert", cert)) + { + if (cert->GetError().length()) + { + user->WriteServ("NOTICE %s :*** Error: %s", user->nick, cert->GetError().c_str()); + } + user->WriteServ("NOTICE %s :*** Distinguised Name: %s", user->nick, cert->GetDN().c_str()); + user->WriteServ("NOTICE %s :*** Issuer: %s", user->nick, cert->GetIssuer().c_str()); + user->WriteServ("NOTICE %s :*** Key Fingerprint: %s", user->nick, cert->GetFingerprint().c_str()); + return CMD_SUCCESS; + } + else + { + user->WriteServ("NOTICE %s :*** No SSL certificate information for this user.", user->nick); + return CMD_FAILURE; + } + } + else + user->WriteServ("401 %s %s :No such nickname", user->nick, parameters[0]); + + return CMD_FAILURE; + } +}; + +class ModuleSSLInfo : public Module +{ + cmd_sslinfo* newcommand; + public: + ModuleSSLInfo(InspIRCd* Me) + : Module(Me) + { + + newcommand = new cmd_sslinfo(ServerInstance); + ServerInstance->AddCommand(newcommand); + } + + void Implements(char* List) + { + } + + virtual ~ModuleSSLInfo() + { + } + + virtual Version GetVersion() + { + return Version(1, 1, 0, 0, VF_VENDOR, API_VERSION); + } +}; + +MODULE_INIT(ModuleSSLInfo); + diff --git a/src/modules/extra/m_testclient.cpp b/src/modules/extra/m_testclient.cpp index a867dad20..f4e58b7b5 100644 --- a/src/modules/extra/m_testclient.cpp +++ b/src/modules/extra/m_testclient.cpp @@ -1 +1,110 @@ -/* +------------------------------------+
* | Inspire Internet Relay Chat Daemon |
* +------------------------------------+
*
* InspIRCd: (C) 2002-2007 InspIRCd Development Team
* See: http://www.inspircd.org/wiki/index.php/Credits
*
* This program is free but copyrighted software; see
* the file COPYING for details.
*
* ---------------------------------------------------
*/
#include "inspircd.h"
#include "users.h"
#include "channels.h"
#include "modules.h"
#include "configreader.h"
#include "m_sqlv2.h"
class ModuleTestClient : public Module
{
private:
public:
ModuleTestClient(InspIRCd* Me)
: Module::Module(Me)
{
}
void Implements(char* List)
{
List[I_OnRequest] = List[I_OnBackgroundTimer] = 1;
}
virtual Version GetVersion()
{
return Version(1, 1, 0, 0, VF_VENDOR, API_VERSION);
}
virtual void OnBackgroundTimer(time_t foo)
{
Module* target = ServerInstance->FindFeature("SQL");
if(target)
{
SQLrequest foo = SQLreq(this, target, "foo", "UPDATE rawr SET foo = '?' WHERE bar = 42", ConvToStr(time(NULL)));
if(foo.Send())
{
ServerInstance->Log(DEBUG, "Sent query, got given ID %lu", foo.id);
}
else
{
ServerInstance->Log(DEBUG, "SQLrequest failed: %s", foo.error.Str());
}
}
}
virtual char* OnRequest(Request* request)
{
if(strcmp(SQLRESID, request->GetId()) == 0)
{
ServerInstance->Log(DEBUG, "Got SQL result (%s)", request->GetId());
SQLresult* res = (SQLresult*)request;
if (res->error.Id() == NO_ERROR)
{
if(res->Cols())
{
ServerInstance->Log(DEBUG, "Got result with %d rows and %d columns", res->Rows(), res->Cols());
for (int r = 0; r < res->Rows(); r++)
{
ServerInstance->Log(DEBUG, "Row %d:", r);
for(int i = 0; i < res->Cols(); i++)
{
ServerInstance->Log(DEBUG, "\t[%s]: %s", res->ColName(i).c_str(), res->GetValue(r, i).d.c_str());
}
}
}
else
{
ServerInstance->Log(DEBUG, "%d rows affected in query", res->Rows());
}
}
else
{
ServerInstance->Log(DEBUG, "SQLrequest failed: %s", res->error.Str());
}
return SQLSUCCESS;
}
ServerInstance->Log(DEBUG, "Got unsupported API version string: %s", request->GetId());
return NULL;
}
virtual ~ModuleTestClient()
{
}
};
MODULE_INIT(ModuleTestClient);
\ No newline at end of file +/* +------------------------------------+ + * | Inspire Internet Relay Chat Daemon | + * +------------------------------------+ + * + * InspIRCd: (C) 2002-2007 InspIRCd Development Team + * See: http://www.inspircd.org/wiki/index.php/Credits + * + * This program is free but copyrighted software; see + * the file COPYING for details. + * + * --------------------------------------------------- + */ + +#include "inspircd.h" +#include "users.h" +#include "channels.h" +#include "modules.h" +#include "configreader.h" +#include "m_sqlv2.h" + +class ModuleTestClient : public Module +{ +private: + + +public: + ModuleTestClient(InspIRCd* Me) + : Module::Module(Me) + { + } + + void Implements(char* List) + { + List[I_OnRequest] = List[I_OnBackgroundTimer] = 1; + } + + virtual Version GetVersion() + { + return Version(1, 1, 0, 0, VF_VENDOR, API_VERSION); + } + + virtual void OnBackgroundTimer(time_t foo) + { + Module* target = ServerInstance->FindFeature("SQL"); + + if(target) + { + SQLrequest foo = SQLreq(this, target, "foo", "UPDATE rawr SET foo = '?' WHERE bar = 42", ConvToStr(time(NULL))); + + if(foo.Send()) + { + ServerInstance->Log(DEBUG, "Sent query, got given ID %lu", foo.id); + } + else + { + ServerInstance->Log(DEBUG, "SQLrequest failed: %s", foo.error.Str()); + } + } + } + + virtual char* OnRequest(Request* request) + { + if(strcmp(SQLRESID, request->GetId()) == 0) + { + ServerInstance->Log(DEBUG, "Got SQL result (%s)", request->GetId()); + + SQLresult* res = (SQLresult*)request; + + if (res->error.Id() == NO_ERROR) + { + if(res->Cols()) + { + ServerInstance->Log(DEBUG, "Got result with %d rows and %d columns", res->Rows(), res->Cols()); + + for (int r = 0; r < res->Rows(); r++) + { + ServerInstance->Log(DEBUG, "Row %d:", r); + + for(int i = 0; i < res->Cols(); i++) + { + ServerInstance->Log(DEBUG, "\t[%s]: %s", res->ColName(i).c_str(), res->GetValue(r, i).d.c_str()); + } + } + } + else + { + ServerInstance->Log(DEBUG, "%d rows affected in query", res->Rows()); + } + } + else + { + ServerInstance->Log(DEBUG, "SQLrequest failed: %s", res->error.Str()); + + } + + return SQLSUCCESS; + } + + ServerInstance->Log(DEBUG, "Got unsupported API version string: %s", request->GetId()); + + return NULL; + } + + virtual ~ModuleTestClient() + { + } +}; + +MODULE_INIT(ModuleTestClient); + diff --git a/src/modules/extra/m_ziplink.cpp b/src/modules/extra/m_ziplink.cpp index 2a127258d..e815d1042 100644 --- a/src/modules/extra/m_ziplink.cpp +++ b/src/modules/extra/m_ziplink.cpp @@ -1 +1,452 @@ -/* +------------------------------------+
* | Inspire Internet Relay Chat Daemon |
* +------------------------------------+
*
* InspIRCd: (C) 2002-2007 InspIRCd Development Team
* See: http://www.inspircd.org/wiki/index.php/Credits
*
* This program is free but copyrighted software; see
* the file COPYING for details.
*
* ---------------------------------------------------
*/
#include "inspircd.h"
#include <zlib.h>
#include "users.h"
#include "channels.h"
#include "modules.h"
#include "socket.h"
#include "hashcomp.h"
#include "transport.h"
/* $ModDesc: Provides zlib link support for servers */
/* $LinkerFlags: -lz */
/* $ModDep: transport.h */
/*
* Compressed data is transmitted across the link in the following format:
*
* 0 1 2 3 4 ... n
* +---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+
* | n | Z0 -> Zn |
* +---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+
*
* Where: n is the size of a frame, in network byte order, 4 bytes.
* Z0 through Zn are Zlib compressed data, n bytes in length.
*
* If the module fails to read the entire frame, then it will buffer
* the portion of the last frame it received, then attempt to read
* the next part of the frame next time a write notification arrives.
*
* ZLIB_BEST_COMPRESSION (9) is used for all sending of data with
* a flush after each frame. A frame may contain multiple lines
* and should be treated as raw binary data.
*
*/
/* Status of a connection */
enum izip_status { IZIP_OPEN, IZIP_CLOSED };
/* Maximum transfer size per read operation */
const unsigned int CHUNK = 128 * 1024;
/* This class manages a compressed chunk of data preceeded by
* a length count.
*
* It can handle having multiple chunks of data in the buffer
* at any time.
*/
class CountedBuffer : public classbase
{
std::string buffer; /* Current buffer contents */
unsigned int amount_expected; /* Amount of data expected */
public:
CountedBuffer()
{
amount_expected = 0;
}
/** Adds arbitrary compressed data to the buffer.
* - Binsry safe, of course.
*/
void AddData(unsigned char* data, int data_length)
{
buffer.append((const char*)data, data_length);
this->NextFrameSize();
}
/** Works out the size of the next compressed frame
*/
void NextFrameSize()
{
if ((!amount_expected) && (buffer.length() >= 4))
{
/* We have enough to read an int -
* Yes, this is safe, but its ugly. Give me
* a nicer way to read 4 bytes from a binary
* stream, and push them into a 32 bit int,
* and i'll consider replacing this.
*/
amount_expected = ntohl((buffer[3] << 24) | (buffer[2] << 16) | (buffer[1] << 8) | buffer[0]);
buffer = buffer.substr(4);
}
}
/** Gets the next frame and returns its size, or returns
* zero if there isnt one available yet.
* A frame can contain multiple plaintext lines.
* - Binary safe.
*/
int GetFrame(unsigned char* frame, int maxsize)
{
if (amount_expected)
{
/* We know how much we're expecting...
* Do we have enough yet?
*/
if (buffer.length() >= amount_expected)
{
int j = 0;
for (unsigned int i = 0; i < amount_expected; i++, j++)
frame[i] = buffer[i];
buffer = buffer.substr(j);
amount_expected = 0;
NextFrameSize();
return j;
}
}
/* Not enough for a frame yet, COME AGAIN! */
return 0;
}
};
/** Represents an zipped connections extra data
*/
class izip_session : public classbase
{
public:
z_stream c_stream; /* compression stream */
z_stream d_stream; /* decompress stream */
izip_status status; /* Connection status */
int fd; /* File descriptor */
CountedBuffer* inbuf; /* Holds input buffer */
std::string outbuf; /* Holds output buffer */
};
class ModuleZLib : public Module
{
izip_session sessions[MAX_DESCRIPTORS];
/* Used for stats z extensions */
float total_out_compressed;
float total_in_compressed;
float total_out_uncompressed;
float total_in_uncompressed;
public:
ModuleZLib(InspIRCd* Me)
: Module::Module(Me)
{
ServerInstance->PublishInterface("InspSocketHook", this);
total_out_compressed = total_in_compressed = 0;
total_out_uncompressed = total_out_uncompressed = 0;
}
virtual ~ModuleZLib()
{
ServerInstance->UnpublishInterface("InspSocketHook", this);
}
virtual Version GetVersion()
{
return Version(1, 1, 0, 0, VF_VENDOR, API_VERSION);
}
void Implements(char* List)
{
List[I_OnRawSocketConnect] = List[I_OnRawSocketAccept] = List[I_OnRawSocketClose] = List[I_OnRawSocketRead] = List[I_OnRawSocketWrite] = 1;
List[I_OnStats] = List[I_OnRequest] = 1;
}
/* Handle InspSocketHook API requests */
virtual char* OnRequest(Request* request)
{
ISHRequest* ISR = (ISHRequest*)request;
if (strcmp("IS_NAME", request->GetId()) == 0)
{
/* Return name */
return "zip";
}
else if (strcmp("IS_HOOK", request->GetId()) == 0)
{
/* Attach to an inspsocket */
char* ret = "OK";
try
{
ret = ServerInstance->Config->AddIOHook((Module*)this, (InspSocket*)ISR->Sock) ? (char*)"OK" : NULL;
}
catch (ModuleException& e)
{
return NULL;
}
return ret;
}
else if (strcmp("IS_UNHOOK", request->GetId()) == 0)
{
/* Detatch from an inspsocket */
return ServerInstance->Config->DelIOHook((InspSocket*)ISR->Sock) ? (char*)"OK" : NULL;
}
else if (strcmp("IS_HSDONE", request->GetId()) == 0)
{
/* Check for completion of handshake
* (actually, this module doesnt handshake)
*/
return "OK";
}
else if (strcmp("IS_ATTACH", request->GetId()) == 0)
{
/* Attach certificate data to the inspsocket
* (this module doesnt do that, either)
*/
return NULL;
}
return NULL;
}
/* Handle stats z (misc stats) */
virtual int OnStats(char symbol, userrec* user, string_list &results)
{
if (symbol == 'z')
{
std::string sn = ServerInstance->Config->ServerName;
/* Yeah yeah, i know, floats are ew.
* We used them here because we'd be casting to float anyway to do this maths,
* and also only floating point numbers can deal with the pretty large numbers
* involved in the total throughput of a server over a large period of time.
* (we dont count 64 bit ints because not all systems have 64 bit ints, and floats
* can still hold more.
*/
float outbound_r = 100 - ((total_out_compressed / (total_out_uncompressed + 0.001)) * 100);
float inbound_r = 100 - ((total_in_compressed / (total_in_uncompressed + 0.001)) * 100);
float total_compressed = total_in_compressed + total_out_compressed;
float total_uncompressed = total_in_uncompressed + total_out_uncompressed;
float total_r = 100 - ((total_compressed / (total_uncompressed + 0.001)) * 100);
char outbound_ratio[MAXBUF], inbound_ratio[MAXBUF], combined_ratio[MAXBUF];
sprintf(outbound_ratio, "%3.2f%%", outbound_r);
sprintf(inbound_ratio, "%3.2f%%", inbound_r);
sprintf(combined_ratio, "%3.2f%%", total_r);
results.push_back(sn+" 304 "+user->nick+" :ZIPSTATS outbound_compressed = "+ConvToStr(total_out_compressed));
results.push_back(sn+" 304 "+user->nick+" :ZIPSTATS inbound_compressed = "+ConvToStr(total_in_compressed));
results.push_back(sn+" 304 "+user->nick+" :ZIPSTATS outbound_uncompressed = "+ConvToStr(total_out_uncompressed));
results.push_back(sn+" 304 "+user->nick+" :ZIPSTATS inbound_uncompressed = "+ConvToStr(total_in_uncompressed));
results.push_back(sn+" 304 "+user->nick+" :ZIPSTATS outbound_ratio = "+outbound_ratio);
results.push_back(sn+" 304 "+user->nick+" :ZIPSTATS inbound_ratio = "+inbound_ratio);
results.push_back(sn+" 304 "+user->nick+" :ZIPSTATS combined_ratio = "+combined_ratio);
return 0;
}
return 0;
}
virtual void OnRawSocketAccept(int fd, const std::string &ip, int localport)
{
izip_session* session = &sessions[fd];
/* allocate state and buffers */
session->fd = fd;
session->status = IZIP_OPEN;
session->inbuf = new CountedBuffer();
session->c_stream.zalloc = (alloc_func)0;
session->c_stream.zfree = (free_func)0;
session->c_stream.opaque = (voidpf)0;
session->d_stream.zalloc = (alloc_func)0;
session->d_stream.zfree = (free_func)0;
session->d_stream.opaque = (voidpf)0;
}
virtual void OnRawSocketConnect(int fd)
{
/* Nothing special needs doing here compared to accept() */
OnRawSocketAccept(fd, "", 0);
}
virtual void OnRawSocketClose(int fd)
{
CloseSession(&sessions[fd]);
}
virtual int OnRawSocketRead(int fd, char* buffer, unsigned int count, int &readresult)
{
/* Find the sockets session */
izip_session* session = &sessions[fd];
if (session->status == IZIP_CLOSED)
return 0;
unsigned char compr[CHUNK + 4];
unsigned int offset = 0;
unsigned int total_size = 0;
/* Read CHUNK bytes at a time to the buffer (usually 128k) */
readresult = read(fd, compr, CHUNK);
/* Did we get anything? */
if (readresult > 0)
{
/* Add it to the frame queue */
session->inbuf->AddData(compr, readresult);
total_in_compressed += readresult;
/* Parse all completed frames */
int size = 0;
while ((size = session->inbuf->GetFrame(compr, CHUNK)) != 0)
{
session->d_stream.next_in = (Bytef*)compr;
session->d_stream.avail_in = 0;
session->d_stream.next_out = (Bytef*)(buffer + offset);
/* If we cant call this, well, we're boned. */
if (inflateInit(&session->d_stream) != Z_OK)
return 0;
while ((session->d_stream.total_out < count) && (session->d_stream.total_in < (unsigned int)size))
{
session->d_stream.avail_in = session->d_stream.avail_out = 1;
if (inflate(&session->d_stream, Z_NO_FLUSH) == Z_STREAM_END)
break;
}
/* Stick a fork in me, i'm done */
inflateEnd(&session->d_stream);
/* Update counters and offsets */
total_size += session->d_stream.total_out;
total_in_uncompressed += session->d_stream.total_out;
offset += session->d_stream.total_out;
}
/* Null-terminate the buffer -- this doesnt harm binary data */
buffer[total_size] = 0;
/* Set the read size to the correct total size */
readresult = total_size;
}
return (readresult > 0);
}
virtual int OnRawSocketWrite(int fd, const char* buffer, int count)
{
izip_session* session = &sessions[fd];
int ocount = count;
if (!count) /* Nothing to do! */
return 0;
if(session->status != IZIP_OPEN)
{
/* Seriously, wtf? */
CloseSession(session);
return 0;
}
unsigned char compr[CHUNK + 4];
/* Gentlemen, start your engines! */
if (deflateInit(&session->c_stream, Z_BEST_COMPRESSION) != Z_OK)
{
CloseSession(session);
return 0;
}
/* Set buffer sizes (we reserve 4 bytes at the start of the
* buffer for the length counters)
*/
session->c_stream.next_in = (Bytef*)buffer;
session->c_stream.next_out = compr + 4;
/* Compress the text */
while ((session->c_stream.total_in < (unsigned int)count) && (session->c_stream.total_out < CHUNK))
{
session->c_stream.avail_in = session->c_stream.avail_out = 1;
if (deflate(&session->c_stream, Z_NO_FLUSH) != Z_OK)
{
CloseSession(session);
return 0;
}
}
/* Finish the stream */
for (session->c_stream.avail_out = 1; deflate(&session->c_stream, Z_FINISH) != Z_STREAM_END; session->c_stream.avail_out = 1);
deflateEnd(&session->c_stream);
total_out_uncompressed += ocount;
total_out_compressed += session->c_stream.total_out;
/** Assemble the frame length onto the frame, in network byte order */
compr[0] = (session->c_stream.total_out >> 24);
compr[1] = (session->c_stream.total_out >> 16);
compr[2] = (session->c_stream.total_out >> 8);
compr[3] = (session->c_stream.total_out & 0xFF);
/* Add compressed data plus leading length to the output buffer -
* Note, we may have incomplete half-sent frames in here.
*/
session->outbuf.append((const char*)compr, session->c_stream.total_out + 4);
/* Lets see how much we can send out */
int ret = write(fd, session->outbuf.data(), session->outbuf.length());
/* Check for errors, and advance the buffer if any was sent */
if (ret > 0)
session->outbuf = session->outbuf.substr(ret);
else if (ret < 1)
{
if (ret == -1)
{
if (errno == EAGAIN)
return 0;
else
{
session->outbuf.clear();
return 0;
}
}
else
{
session->outbuf.clear();
return 0;
}
}
/* ALL LIES the lot of it, we havent really written
* this amount, but the layer above doesnt need to know.
*/
return ocount;
}
void CloseSession(izip_session* session)
{
if (session->status == IZIP_OPEN)
{
session->status = IZIP_CLOSED;
session->outbuf.clear();
delete session->inbuf;
}
}
};
MODULE_INIT(ModuleZLib);
\ No newline at end of file +/* +------------------------------------+ + * | Inspire Internet Relay Chat Daemon | + * +------------------------------------+ + * + * InspIRCd: (C) 2002-2007 InspIRCd Development Team + * See: http://www.inspircd.org/wiki/index.php/Credits + * + * This program is free but copyrighted software; see + * the file COPYING for details. + * + * --------------------------------------------------- + */ + +#include "inspircd.h" +#include <zlib.h> +#include "users.h" +#include "channels.h" +#include "modules.h" +#include "socket.h" +#include "hashcomp.h" +#include "transport.h" + +/* $ModDesc: Provides zlib link support for servers */ +/* $LinkerFlags: -lz */ +/* $ModDep: transport.h */ + +/* + * Compressed data is transmitted across the link in the following format: + * + * 0 1 2 3 4 ... n + * +---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+ + * | n | Z0 -> Zn | + * +---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+ + * + * Where: n is the size of a frame, in network byte order, 4 bytes. + * Z0 through Zn are Zlib compressed data, n bytes in length. + * + * If the module fails to read the entire frame, then it will buffer + * the portion of the last frame it received, then attempt to read + * the next part of the frame next time a write notification arrives. + * + * ZLIB_BEST_COMPRESSION (9) is used for all sending of data with + * a flush after each frame. A frame may contain multiple lines + * and should be treated as raw binary data. + * + */ + +/* Status of a connection */ +enum izip_status { IZIP_OPEN, IZIP_CLOSED }; + +/* Maximum transfer size per read operation */ +const unsigned int CHUNK = 128 * 1024; + +/* This class manages a compressed chunk of data preceeded by + * a length count. + * + * It can handle having multiple chunks of data in the buffer + * at any time. + */ +class CountedBuffer : public classbase +{ + std::string buffer; /* Current buffer contents */ + unsigned int amount_expected; /* Amount of data expected */ + public: + CountedBuffer() + { + amount_expected = 0; + } + + /** Adds arbitrary compressed data to the buffer. + * - Binsry safe, of course. + */ + void AddData(unsigned char* data, int data_length) + { + buffer.append((const char*)data, data_length); + this->NextFrameSize(); + } + + /** Works out the size of the next compressed frame + */ + void NextFrameSize() + { + if ((!amount_expected) && (buffer.length() >= 4)) + { + /* We have enough to read an int - + * Yes, this is safe, but its ugly. Give me + * a nicer way to read 4 bytes from a binary + * stream, and push them into a 32 bit int, + * and i'll consider replacing this. + */ + amount_expected = ntohl((buffer[3] << 24) | (buffer[2] << 16) | (buffer[1] << 8) | buffer[0]); + buffer = buffer.substr(4); + } + } + + /** Gets the next frame and returns its size, or returns + * zero if there isnt one available yet. + * A frame can contain multiple plaintext lines. + * - Binary safe. + */ + int GetFrame(unsigned char* frame, int maxsize) + { + if (amount_expected) + { + /* We know how much we're expecting... + * Do we have enough yet? + */ + if (buffer.length() >= amount_expected) + { + int j = 0; + for (unsigned int i = 0; i < amount_expected; i++, j++) + frame[i] = buffer[i]; + + buffer = buffer.substr(j); + amount_expected = 0; + NextFrameSize(); + return j; + } + } + /* Not enough for a frame yet, COME AGAIN! */ + return 0; + } +}; + +/** Represents an zipped connections extra data + */ +class izip_session : public classbase +{ + public: + z_stream c_stream; /* compression stream */ + z_stream d_stream; /* decompress stream */ + izip_status status; /* Connection status */ + int fd; /* File descriptor */ + CountedBuffer* inbuf; /* Holds input buffer */ + std::string outbuf; /* Holds output buffer */ +}; + +class ModuleZLib : public Module +{ + izip_session sessions[MAX_DESCRIPTORS]; + + /* Used for stats z extensions */ + float total_out_compressed; + float total_in_compressed; + float total_out_uncompressed; + float total_in_uncompressed; + + public: + + ModuleZLib(InspIRCd* Me) + : Module::Module(Me) + { + ServerInstance->PublishInterface("InspSocketHook", this); + + total_out_compressed = total_in_compressed = 0; + total_out_uncompressed = total_out_uncompressed = 0; + } + + virtual ~ModuleZLib() + { + ServerInstance->UnpublishInterface("InspSocketHook", this); + } + + virtual Version GetVersion() + { + return Version(1, 1, 0, 0, VF_VENDOR, API_VERSION); + } + + void Implements(char* List) + { + List[I_OnRawSocketConnect] = List[I_OnRawSocketAccept] = List[I_OnRawSocketClose] = List[I_OnRawSocketRead] = List[I_OnRawSocketWrite] = 1; + List[I_OnStats] = List[I_OnRequest] = 1; + } + + /* Handle InspSocketHook API requests */ + virtual char* OnRequest(Request* request) + { + ISHRequest* ISR = (ISHRequest*)request; + if (strcmp("IS_NAME", request->GetId()) == 0) + { + /* Return name */ + return "zip"; + } + else if (strcmp("IS_HOOK", request->GetId()) == 0) + { + /* Attach to an inspsocket */ + char* ret = "OK"; + try + { + ret = ServerInstance->Config->AddIOHook((Module*)this, (InspSocket*)ISR->Sock) ? (char*)"OK" : NULL; + } + catch (ModuleException& e) + { + return NULL; + } + return ret; + } + else if (strcmp("IS_UNHOOK", request->GetId()) == 0) + { + /* Detatch from an inspsocket */ + return ServerInstance->Config->DelIOHook((InspSocket*)ISR->Sock) ? (char*)"OK" : NULL; + } + else if (strcmp("IS_HSDONE", request->GetId()) == 0) + { + /* Check for completion of handshake + * (actually, this module doesnt handshake) + */ + return "OK"; + } + else if (strcmp("IS_ATTACH", request->GetId()) == 0) + { + /* Attach certificate data to the inspsocket + * (this module doesnt do that, either) + */ + return NULL; + } + return NULL; + } + + /* Handle stats z (misc stats) */ + virtual int OnStats(char symbol, userrec* user, string_list &results) + { + if (symbol == 'z') + { + std::string sn = ServerInstance->Config->ServerName; + + /* Yeah yeah, i know, floats are ew. + * We used them here because we'd be casting to float anyway to do this maths, + * and also only floating point numbers can deal with the pretty large numbers + * involved in the total throughput of a server over a large period of time. + * (we dont count 64 bit ints because not all systems have 64 bit ints, and floats + * can still hold more. + */ + float outbound_r = 100 - ((total_out_compressed / (total_out_uncompressed + 0.001)) * 100); + float inbound_r = 100 - ((total_in_compressed / (total_in_uncompressed + 0.001)) * 100); + + float total_compressed = total_in_compressed + total_out_compressed; + float total_uncompressed = total_in_uncompressed + total_out_uncompressed; + + float total_r = 100 - ((total_compressed / (total_uncompressed + 0.001)) * 100); + + char outbound_ratio[MAXBUF], inbound_ratio[MAXBUF], combined_ratio[MAXBUF]; + + sprintf(outbound_ratio, "%3.2f%%", outbound_r); + sprintf(inbound_ratio, "%3.2f%%", inbound_r); + sprintf(combined_ratio, "%3.2f%%", total_r); + + results.push_back(sn+" 304 "+user->nick+" :ZIPSTATS outbound_compressed = "+ConvToStr(total_out_compressed)); + results.push_back(sn+" 304 "+user->nick+" :ZIPSTATS inbound_compressed = "+ConvToStr(total_in_compressed)); + results.push_back(sn+" 304 "+user->nick+" :ZIPSTATS outbound_uncompressed = "+ConvToStr(total_out_uncompressed)); + results.push_back(sn+" 304 "+user->nick+" :ZIPSTATS inbound_uncompressed = "+ConvToStr(total_in_uncompressed)); + results.push_back(sn+" 304 "+user->nick+" :ZIPSTATS outbound_ratio = "+outbound_ratio); + results.push_back(sn+" 304 "+user->nick+" :ZIPSTATS inbound_ratio = "+inbound_ratio); + results.push_back(sn+" 304 "+user->nick+" :ZIPSTATS combined_ratio = "+combined_ratio); + return 0; + } + + return 0; + } + + virtual void OnRawSocketAccept(int fd, const std::string &ip, int localport) + { + izip_session* session = &sessions[fd]; + + /* allocate state and buffers */ + session->fd = fd; + session->status = IZIP_OPEN; + session->inbuf = new CountedBuffer(); + + session->c_stream.zalloc = (alloc_func)0; + session->c_stream.zfree = (free_func)0; + session->c_stream.opaque = (voidpf)0; + + session->d_stream.zalloc = (alloc_func)0; + session->d_stream.zfree = (free_func)0; + session->d_stream.opaque = (voidpf)0; + } + + virtual void OnRawSocketConnect(int fd) + { + /* Nothing special needs doing here compared to accept() */ + OnRawSocketAccept(fd, "", 0); + } + + virtual void OnRawSocketClose(int fd) + { + CloseSession(&sessions[fd]); + } + + virtual int OnRawSocketRead(int fd, char* buffer, unsigned int count, int &readresult) + { + /* Find the sockets session */ + izip_session* session = &sessions[fd]; + + if (session->status == IZIP_CLOSED) + return 0; + + unsigned char compr[CHUNK + 4]; + unsigned int offset = 0; + unsigned int total_size = 0; + + /* Read CHUNK bytes at a time to the buffer (usually 128k) */ + readresult = read(fd, compr, CHUNK); + + /* Did we get anything? */ + if (readresult > 0) + { + /* Add it to the frame queue */ + session->inbuf->AddData(compr, readresult); + total_in_compressed += readresult; + + /* Parse all completed frames */ + int size = 0; + while ((size = session->inbuf->GetFrame(compr, CHUNK)) != 0) + { + session->d_stream.next_in = (Bytef*)compr; + session->d_stream.avail_in = 0; + session->d_stream.next_out = (Bytef*)(buffer + offset); + + /* If we cant call this, well, we're boned. */ + if (inflateInit(&session->d_stream) != Z_OK) + return 0; + + while ((session->d_stream.total_out < count) && (session->d_stream.total_in < (unsigned int)size)) + { + session->d_stream.avail_in = session->d_stream.avail_out = 1; + if (inflate(&session->d_stream, Z_NO_FLUSH) == Z_STREAM_END) + break; + } + + /* Stick a fork in me, i'm done */ + inflateEnd(&session->d_stream); + + /* Update counters and offsets */ + total_size += session->d_stream.total_out; + total_in_uncompressed += session->d_stream.total_out; + offset += session->d_stream.total_out; + } + + /* Null-terminate the buffer -- this doesnt harm binary data */ + buffer[total_size] = 0; + + /* Set the read size to the correct total size */ + readresult = total_size; + + } + return (readresult > 0); + } + + virtual int OnRawSocketWrite(int fd, const char* buffer, int count) + { + izip_session* session = &sessions[fd]; + int ocount = count; + + if (!count) /* Nothing to do! */ + return 0; + + if(session->status != IZIP_OPEN) + { + /* Seriously, wtf? */ + CloseSession(session); + return 0; + } + + unsigned char compr[CHUNK + 4]; + + /* Gentlemen, start your engines! */ + if (deflateInit(&session->c_stream, Z_BEST_COMPRESSION) != Z_OK) + { + CloseSession(session); + return 0; + } + + /* Set buffer sizes (we reserve 4 bytes at the start of the + * buffer for the length counters) + */ + session->c_stream.next_in = (Bytef*)buffer; + session->c_stream.next_out = compr + 4; + + /* Compress the text */ + while ((session->c_stream.total_in < (unsigned int)count) && (session->c_stream.total_out < CHUNK)) + { + session->c_stream.avail_in = session->c_stream.avail_out = 1; + if (deflate(&session->c_stream, Z_NO_FLUSH) != Z_OK) + { + CloseSession(session); + return 0; + } + } + /* Finish the stream */ + for (session->c_stream.avail_out = 1; deflate(&session->c_stream, Z_FINISH) != Z_STREAM_END; session->c_stream.avail_out = 1); + deflateEnd(&session->c_stream); + + total_out_uncompressed += ocount; + total_out_compressed += session->c_stream.total_out; + + /** Assemble the frame length onto the frame, in network byte order */ + compr[0] = (session->c_stream.total_out >> 24); + compr[1] = (session->c_stream.total_out >> 16); + compr[2] = (session->c_stream.total_out >> 8); + compr[3] = (session->c_stream.total_out & 0xFF); + + /* Add compressed data plus leading length to the output buffer - + * Note, we may have incomplete half-sent frames in here. + */ + session->outbuf.append((const char*)compr, session->c_stream.total_out + 4); + + /* Lets see how much we can send out */ + int ret = write(fd, session->outbuf.data(), session->outbuf.length()); + + /* Check for errors, and advance the buffer if any was sent */ + if (ret > 0) + session->outbuf = session->outbuf.substr(ret); + else if (ret < 1) + { + if (ret == -1) + { + if (errno == EAGAIN) + return 0; + else + { + session->outbuf.clear(); + return 0; + } + } + else + { + session->outbuf.clear(); + return 0; + } + } + + /* ALL LIES the lot of it, we havent really written + * this amount, but the layer above doesnt need to know. + */ + return ocount; + } + + void CloseSession(izip_session* session) + { + if (session->status == IZIP_OPEN) + { + session->status = IZIP_CLOSED; + session->outbuf.clear(); + delete session->inbuf; + } + } + +}; + +MODULE_INIT(ModuleZLib); + |