]> git.netwichtig.de Git - user/henk/code/ruby/rbot.git/blob - lib/rbot/ircsocket.rb
allow to verify ssl connections against a CA.
[user/henk/code/ruby/rbot.git] / lib / rbot / ircsocket.rb
1 #-- vim:sw=2:et
2 #++
3 #
4 # :title: IRC Socket
5 #
6 # This module implements the IRC socket interface, including IRC message
7 # penalty computation and the message queue system
8
9 require 'monitor'
10
11 class ::String
12   # Calculate the penalty which will be assigned to this message
13   # by the IRCd
14   def irc_send_penalty
15     # According to eggdrop, the initial penalty is
16     penalty = 1 + self.size/100
17     # on everything but UnderNET where it's
18     # penalty = 2 + self.size/120
19
20     cmd, pars = self.split($;,2)
21     debug "cmd: #{cmd}, pars: #{pars.inspect}"
22     case cmd.to_sym
23     when :KICK
24       chan, nick, msg = pars.split
25       chan = chan.split(',')
26       nick = nick.split(',')
27       penalty += nick.size
28       penalty *= chan.size
29     when :MODE
30       chan, modes, argument = pars.split
31       extra = 0
32       if modes
33         extra = 1
34         if argument
35           extra += modes.split(/\+|-/).size
36         else
37           extra += 3 * modes.split(/\+|-/).size
38         end
39       end
40       if argument
41         extra += 2 * argument.split.size
42       end
43       penalty += extra * chan.split.size
44     when :TOPIC
45       penalty += 1
46       penalty += 2 unless pars.split.size < 2
47     when :PRIVMSG, :NOTICE
48       dests = pars.split($;,2).first
49       penalty += dests.split(',').size
50     when :WHO
51       args = pars.split
52       if args.length > 0
53         penalty += args.inject(0){ |sum,x| sum += ((x.length > 4) ? 3 : 5) }
54       else
55         penalty += 10
56       end
57     when :PART
58       penalty += 4
59     when :AWAY, :JOIN, :VERSION, :TIME, :TRACE, :WHOIS, :DNS
60       penalty += 2
61     when :INVITE, :NICK
62       penalty += 3
63     when :ISON
64       penalty += 1
65     else # Unknown messages
66       penalty += 1
67     end
68     if penalty > 99
69       debug "Wow, more than 99 secs of penalty!"
70       penalty = 99
71     end
72     if penalty < 2
73       debug "Wow, less than 2 secs of penalty!"
74       penalty = 2
75     end
76     debug "penalty: #{penalty}"
77     return penalty
78   end
79 end
80
81 module Irc
82
83   require 'socket'
84   require 'thread'
85
86   class QueueRing
87     # A QueueRing is implemented as an array with elements in the form
88     # [chan, [message1, message2, ...]
89     # Note that the channel +chan+ has no actual bearing with the channels
90     # to which messages will be sent
91
92     def initialize
93       @storage = Array.new
94       @last_idx = -1
95     end
96
97     def clear
98       @storage.clear
99       @last_idx = -1
100     end
101
102     def length
103       len = 0
104       @storage.each {|c|
105         len += c[1].size
106       }
107       return len
108     end
109     alias :size :length
110
111     def empty?
112       @storage.empty?
113     end
114
115     def push(mess, chan)
116       cmess = @storage.assoc(chan)
117       if cmess
118         idx = @storage.index(cmess)
119         cmess[1] << mess
120         @storage[idx] = cmess
121       else
122         @storage << [chan, [mess]]
123       end
124     end
125
126     def next
127       if empty?
128         warning "trying to access empty ring"
129         return nil
130       end
131       save_idx = @last_idx
132       @last_idx = (@last_idx + 1) % @storage.size
133       mess = @storage[@last_idx][1].first
134       @last_idx = save_idx
135       return mess
136     end
137
138     def shift
139       if empty?
140         warning "trying to access empty ring"
141         return nil
142       end
143       @last_idx = (@last_idx + 1) % @storage.size
144       mess = @storage[@last_idx][1].shift
145       @storage.delete(@storage[@last_idx]) if @storage[@last_idx][1] == []
146       return mess
147     end
148
149   end
150
151   class MessageQueue
152
153     def initialize
154       # a MessageQueue is an array of QueueRings
155       # rings have decreasing priority, so messages in ring 0
156       # are more important than messages in ring 1, and so on
157       @rings = Array.new(3) { |i|
158         if i > 0
159           QueueRing.new
160         else
161           # ring 0 is special in that if it's not empty, it will
162           # be popped. IOW, ring 0 can starve the other rings
163           # ring 0 is strictly FIFO and is therefore implemented
164           # as an array
165           Array.new
166         end
167       }
168       # the other rings are satisfied round-robin
169       @last_ring = 0
170       self.extend(MonitorMixin)
171       @non_empty = self.new_cond
172     end
173
174     def clear
175       self.synchronize do
176         @rings.each { |r| r.clear }
177         @last_ring = 0
178       end
179     end
180
181     def push(mess, chan=nil, cring=0)
182       ring = cring
183       self.synchronize do
184         if ring == 0
185           warning "message #{mess} at ring 0 has channel #{chan}: channel will be ignored" if !chan.nil?
186           @rings[0] << mess
187         else
188           error "message #{mess} at ring #{ring} must have a channel" if chan.nil?
189           @rings[ring].push mess, chan
190         end
191         @non_empty.signal
192       end
193     end
194
195     def shift(tmout = nil)
196       self.synchronize do
197         @non_empty.wait(tmout) if self.empty?
198         return unsafe_shift
199       end
200     end
201
202     protected
203
204     def empty?
205       !@rings.find { |r| !r.empty? }
206     end
207
208     def length
209       @rings.inject(0) { |s, r| s + r.size }
210     end
211     alias :size :length
212
213     def unsafe_shift
214       if !@rings[0].empty?
215         return @rings[0].shift
216       end
217       (@rings.size - 1).times do
218         @last_ring = (@last_ring % (@rings.size - 1)) + 1
219         return @rings[@last_ring].shift unless @rings[@last_ring].empty?
220       end
221       warning "trying to access an empty message queue"
222       return nil
223     end
224
225   end
226
227   # wrapped TCPSocket for communication with the server.
228   # emulates a subset of TCPSocket functionality
229   class Socket
230
231     MAX_IRC_SEND_PENALTY = 10
232
233     # total number of lines sent to the irc server
234     attr_reader :lines_sent
235
236     # total number of lines received from the irc server
237     attr_reader :lines_received
238
239     # total number of bytes sent to the irc server
240     attr_reader :bytes_sent
241
242     # total number of bytes received from the irc server
243     attr_reader :bytes_received
244
245     # accumulator for the throttle
246     attr_reader :throttle_bytes
247
248     # an optional filter object. we call @filter.in(data) for
249     # all incoming data and @filter.out(data) for all outgoing data
250     attr_reader :filter
251
252     # normalized uri of the current server
253     attr_reader :server_uri
254
255     # penalty multiplier (percent)
256     attr_accessor :penalty_pct
257
258     # default trivial filter class
259     class IdentityFilter
260         def in(x)
261             x
262         end
263
264         def out(x)
265             x
266         end
267     end
268
269     # set filter to identity, not to nil
270     def filter=(f)
271         @filter = f || IdentityFilter.new
272     end
273
274     # server_list:: list of servers to connect to
275     # host::   optional local host to bind to (ruby 1.7+ required)
276     # create a new Irc::Socket
277     def initialize(server_list, host, opts={})
278       @server_list = server_list.dup
279       @server_uri = nil
280       @conn_count = 0
281       @host = host
282       @sock = nil
283       @filter = IdentityFilter.new
284       @spooler = false
285       @lines_sent = 0
286       @lines_received = 0
287       @ssl = opts[:ssl]
288       @ssl_verify = opts[:ssl_verify]
289       @ssl_ca_file = opts[:ssl_ca_file]
290       @ssl_ca_path = opts[:ssl_ca_path]
291       @penalty_pct = opts[:penalty_pct] || 100
292     end
293
294     def connected?
295       !@sock.nil?
296     end
297
298     # open a TCP connection to the server
299     def connect
300       if connected?
301         warning "reconnecting while connected"
302         return
303       end
304       srv_uri = @server_list[@conn_count % @server_list.size].dup
305       srv_uri = 'irc://' + srv_uri if !(srv_uri =~ /:\/\//)
306       @conn_count += 1
307       @server_uri = URI.parse(srv_uri)
308       @server_uri.port = 6667 if !@server_uri.port
309
310       debug "connection attempt \##{@conn_count} (#{@server_uri.host}:#{@server_uri.port})"
311
312       # if the host is a bracketed (IPv6) address, strip the brackets
313       # since Ruby doesn't like them in the Socket host parameter
314       # FIXME it would be safer to have it check for a valid
315       # IPv6 bracketed address rather than just stripping the brackets
316       srv_host = @server_uri.host
317       if srv_host.match(/\A\[(.*)\]\z/)
318         srv_host = $1
319       end
320
321       if(@host)
322         begin
323           sock=TCPSocket.new(srv_host, @server_uri.port, @host)
324         rescue ArgumentError => e
325           error "Your version of ruby does not support binding to a "
326           error "specific local address, please upgrade if you wish "
327           error "to use HOST = foo"
328           error "(this option has been disabled in order to continue)"
329           sock=TCPSocket.new(srv_host, @server_uri.port)
330         end
331       else
332         sock=TCPSocket.new(srv_host, @server_uri.port)
333       end
334       if(@ssl)
335         require 'openssl'
336         ssl_context = OpenSSL::SSL::SSLContext.new()
337         if @ssl_verify
338           ssl_context.ca_file = @ssl_ca_file if @ssl_ca_file and not @ssl_ca_file.empty?
339           ssl_context.ca_path = @ssl_ca_path if @ssl_ca_path and not @ssl_ca_path.empty?
340           ssl_context.verify_mode = OpenSSL::SSL::VERIFY_PEER 
341         else
342           ssl_context.verify_mode = OpenSSL::SSL::VERIFY_NONE
343         end
344         sock = OpenSSL::SSL::SSLSocket.new(sock, ssl_context)
345         sock.sync_close = true
346         sock.connect
347       end
348       @sock = sock
349       @last_send = Time.new
350       @flood_send = Time.new
351       @burst = 0
352       @sock.extend(MonitorMixin)
353       @sendq = MessageQueue.new
354       @qthread = Thread.new { writer_loop }
355     end
356
357     # used to send lines to the remote IRCd by skipping the queue
358     # message: IRC message to send
359     # it should only be used for stuff that *must not* be queued,
360     # i.e. the initial PASS, NICK and USER command
361     # or the final QUIT message
362     def emergency_puts(message, penalty = false)
363       @sock.synchronize do
364         # debug "In puts - got @sock"
365         puts_critical(message, penalty)
366       end
367     end
368
369     def handle_socket_error(string, e)
370       error "#{string} failed: #{e.pretty_inspect}"
371       # We assume that an error means that there are connection
372       # problems and that we should reconnect, so we
373       shutdown
374       raise SocketError.new(e.inspect)
375     end
376
377     # get the next line from the server (blocks)
378     def gets
379       if @sock.nil?
380         warning "socket get attempted while closed"
381         return nil
382       end
383       begin
384         reply = @filter.in(@sock.gets)
385         @lines_received += 1
386         reply.strip! if reply
387         debug "RECV: #{reply.inspect}"
388         return reply
389       rescue Exception => e
390         handle_socket_error(:RECV, e)
391       end
392     end
393
394     def queue(msg, chan=nil, ring=0)
395       @sendq.push msg, chan, ring
396     end
397
398     def clearq
399       @sendq.clear
400     end
401
402     # flush the TCPSocket
403     def flush
404       @sock.flush
405     end
406
407     # Wraps Kernel.select on the socket
408     def select(timeout=nil)
409       Kernel.select([@sock], nil, nil, timeout)
410     end
411
412     # shutdown the connection to the server
413     def shutdown(how=2)
414       return unless connected?
415       @qthread.kill
416       @qthread = nil
417       begin
418         @sock.close
419       rescue Exception => e
420         error "error while shutting down: #{e.pretty_inspect}"
421       end
422       @sock = nil
423       @server_uri = nil
424       @sendq.clear
425     end
426
427     private
428
429     def writer_loop
430       loop do
431         begin
432           now = Time.now
433           flood_delay = @flood_send - MAX_IRC_SEND_PENALTY - now
434           delay = [flood_delay, 0].max
435           if delay > 0
436             debug "sleep(#{delay}) # (f: #{flood_delay})"
437             sleep(delay)
438           end
439           msg = @sendq.shift
440           debug "got #{msg.inspect} from queue, sending"
441           emergency_puts(msg, true)
442         rescue Exception => e
443           error "Spooling failed: #{e.pretty_inspect}"
444           debug e.backtrace.join("\n")
445           raise e
446         end
447       end
448     end
449
450     # same as puts, but expects to be called with a lock held on @sock
451     def puts_critical(message, penalty=false)
452       # debug "in puts_critical"
453       begin
454         debug "SEND: #{message.inspect}"
455         if @sock.nil?
456           error "SEND attempted on closed socket"
457         else
458           # we use Socket#syswrite() instead of Socket#puts() because
459           # the latter is racy and can cause double message output in
460           # some circumstances
461           actual = @filter.out(message) + "\n"
462           now = Time.new
463           @sock.syswrite actual
464           @last_send = now
465           @flood_send = now if @flood_send < now
466           @flood_send += message.irc_send_penalty*@penalty_pct/100.0 if penalty
467           @lines_sent += 1
468         end
469       rescue Exception => e
470         handle_socket_error(:SEND, e)
471       end
472     end
473
474   end
475
476 end