]> git.netwichtig.de Git - user/henk/code/ruby/rbot.git/blobdiff - lib/rbot/ircsocket.rb
Solve Socket vs URI IPv6 handling in Ruby
[user/henk/code/ruby/rbot.git] / lib / rbot / ircsocket.rb
index 4ee3be236e59637d5de04e399a1eb487b7d57c4e..029d1ca54026353518eab0b402e8a72a3b2670a7 100644 (file)
+#-- vim:sw=2:et
+#++
+#
+# :title: IRC Socket
+#
+# This module implements the IRC socket interface, including IRC message
+# penalty computation and the message queue system
+
+require 'monitor'
+
+class ::String
+  # Calculate the penalty which will be assigned to this message
+  # by the IRCd
+  def irc_send_penalty
+    # According to eggdrop, the initial penalty is
+    penalty = 1 + self.size/100
+    # on everything but UnderNET where it's
+    # penalty = 2 + self.size/120
+
+    cmd, pars = self.split($;,2)
+    debug "cmd: #{cmd}, pars: #{pars.inspect}"
+    case cmd.to_sym
+    when :KICK
+      chan, nick, msg = pars.split
+      chan = chan.split(',')
+      nick = nick.split(',')
+      penalty += nick.size
+      penalty *= chan.size
+    when :MODE
+      chan, modes, argument = pars.split
+      extra = 0
+      if modes
+        extra = 1
+        if argument
+          extra += modes.split(/\+|-/).size
+        else
+          extra += 3 * modes.split(/\+|-/).size
+        end
+      end
+      if argument
+        extra += 2 * argument.split.size
+      end
+      penalty += extra * chan.split.size
+    when :TOPIC
+      penalty += 1
+      penalty += 2 unless pars.split.size < 2
+    when :PRIVMSG, :NOTICE
+      dests = pars.split($;,2).first
+      penalty += dests.split(',').size
+    when :WHO
+      args = pars.split
+      if args.length > 0
+        penalty += args.inject(0){ |sum,x| sum += ((x.length > 4) ? 3 : 5) }
+      else
+        penalty += 10
+      end
+    when :PART
+      penalty += 4
+    when :AWAY, :JOIN, :VERSION, :TIME, :TRACE, :WHOIS, :DNS
+      penalty += 2
+    when :INVITE, :NICK
+      penalty += 3
+    when :ISON
+      penalty += 1
+    else # Unknown messages
+      penalty += 1
+    end
+    if penalty > 99
+      debug "Wow, more than 99 secs of penalty!"
+      penalty = 99
+    end
+    if penalty < 2
+      debug "Wow, less than 2 secs of penalty!"
+      penalty = 2
+    end
+    debug "penalty: #{penalty}"
+    return penalty
+  end
+end
+
 module Irc
 
   require 'socket'
   require 'thread'
-  require 'rbot/timer'
+
+  class QueueRing
+    # A QueueRing is implemented as an array with elements in the form
+    # [chan, [message1, message2, ...]
+    # Note that the channel +chan+ has no actual bearing with the channels
+    # to which messages will be sent
+
+    def initialize
+      @storage = Array.new
+      @last_idx = -1
+    end
+
+    def clear
+      @storage.clear
+      @last_idx = -1
+    end
+
+    def length
+      len = 0
+      @storage.each {|c|
+        len += c[1].size
+      }
+      return len
+    end
+    alias :size :length
+
+    def empty?
+      @storage.empty?
+    end
+
+    def push(mess, chan)
+      cmess = @storage.assoc(chan)
+      if cmess
+        idx = @storage.index(cmess)
+        cmess[1] << mess
+        @storage[idx] = cmess
+      else
+        @storage << [chan, [mess]]
+      end
+    end
+
+    def next
+      if empty?
+        warning "trying to access empty ring"
+        return nil
+      end
+      save_idx = @last_idx
+      @last_idx = (@last_idx + 1) % @storage.size
+      mess = @storage[@last_idx][1].first
+      @last_idx = save_idx
+      return mess
+    end
+
+    def shift
+      if empty?
+        warning "trying to access empty ring"
+        return nil
+      end
+      @last_idx = (@last_idx + 1) % @storage.size
+      mess = @storage[@last_idx][1].shift
+      @storage.delete(@storage[@last_idx]) if @storage[@last_idx][1] == []
+      return mess
+    end
+
+  end
+
+  class MessageQueue
+
+    def initialize
+      # a MessageQueue is an array of QueueRings
+      # rings have decreasing priority, so messages in ring 0
+      # are more important than messages in ring 1, and so on
+      @rings = Array.new(3) { |i|
+        if i > 0
+          QueueRing.new
+        else
+          # ring 0 is special in that if it's not empty, it will
+          # be popped. IOW, ring 0 can starve the other rings
+          # ring 0 is strictly FIFO and is therefore implemented
+          # as an array
+          Array.new
+        end
+      }
+      # the other rings are satisfied round-robin
+      @last_ring = 0
+      self.extend(MonitorMixin)
+      @non_empty = self.new_cond
+    end
+
+    def clear
+      self.synchronize do
+        @rings.each { |r| r.clear }
+        @last_ring = 0
+      end
+    end
+
+    def push(mess, chan=nil, cring=0)
+      ring = cring
+      self.synchronize do
+        if ring == 0
+          warning "message #{mess} at ring 0 has channel #{chan}: channel will be ignored" if !chan.nil?
+          @rings[0] << mess
+        else
+          error "message #{mess} at ring #{ring} must have a channel" if chan.nil?
+          @rings[ring].push mess, chan
+        end
+        @non_empty.signal
+      end
+    end
+
+    def shift(tmout = nil)
+      self.synchronize do
+        @non_empty.wait(tmout) if self.empty?
+        return unsafe_shift
+      end
+    end
+
+    protected
+
+    def empty?
+      !@rings.find { |r| !r.empty? }
+    end
+
+    def length
+      @rings.inject(0) { |s, r| s + r.size }
+    end
+    alias :size :length
+
+    def unsafe_shift
+      if !@rings[0].empty?
+        return @rings[0].shift
+      end
+      (@rings.size - 1).times do
+        @last_ring = (@last_ring % (@rings.size - 1)) + 1
+        return @rings[@last_ring].shift unless @rings[@last_ring].empty?
+      end
+      warning "trying to access an empty message queue"
+      return nil
+    end
+
+  end
 
   # wrapped TCPSocket for communication with the server.
   # emulates a subset of TCPSocket functionality
-  class IrcSocket
+  class Socket
+
+    MAX_IRC_SEND_PENALTY = 10
+
     # total number of lines sent to the irc server
     attr_reader :lines_sent
 
@@ -22,62 +245,47 @@ module Irc
     # accumulator for the throttle
     attr_reader :throttle_bytes
 
-    # byterate components
-    attr_reader :bytes_per
-    attr_reader :seconds_per
+    # an optional filter object. we call @filter.in(data) for
+    # all incoming data and @filter.out(data) for all outgoing data
+    attr_reader :filter
+
+    # normalized uri of the current server
+    attr_reader :server_uri
 
-    # delay between lines sent
-    attr_reader :sendq_delay
+    # penalty multiplier (percent)
+    attr_accessor :penalty_pct
 
-    # max lines to burst
-    attr_reader :sendq_burst
+    # default trivial filter class
+    class IdentityFilter
+        def in(x)
+            x
+        end
 
-    # server:: server to connect to
-    # port::   IRCd port
+        def out(x)
+            x
+        end
+    end
+
+    # set filter to identity, not to nil
+    def filter=(f)
+        @filter = f || IdentityFilter.new
+    end
+
+    # server_list:: list of servers to connect to
     # host::   optional local host to bind to (ruby 1.7+ required)
-    # create a new IrcSocket
-    def initialize(server, port, host, sendq_delay=2, sendq_burst=4, brt="400/2")
-      @timer = Timer::Timer.new
-      @timer.add(0.2) do
-        spool
-      end
-      @server = server.dup
-      @port = port.to_i
+    # create a new Irc::Socket
+    def initialize(server_list, host, opts={})
+      @server_list = server_list.dup
+      @server_uri = nil
+      @conn_count = 0
       @host = host
       @sock = nil
+      @filter = IdentityFilter.new
       @spooler = false
       @lines_sent = 0
       @lines_received = 0
-      if sendq_delay
-        @sendq_delay = sendq_delay.to_f
-      else
-        @sendq_delay = 2
-      end
-      @last_send = Time.new - @sendq_delay
-      @last_throttle = Time.new
-      @burst = 0
-      if sendq_burst
-        @sendq_burst = sendq_burst.to_i
-      else
-        @sendq_burst = 4
-      end
-      @bytes_per = 400
-      @seconds_per = 2
-      @throttle_bytes = 0
-      @throttle_div = 1
-      setbyterate(brt)
-    end
-
-    def setbyterate(brt)
-      if brt.match(/(\d+)\/(\d)/)
-        @bytes_per = $1.to_i
-        @seconds_per = $2.to_i
-        debug "Byterate now #{byterate}"
-        return true
-      else
-        debug "Couldn't set byterate #{brt}"
-        return false
-      end
+      @ssl = opts[:ssl]
+      @penalty_pct = opts[:penalty_pct] || 100
     end
 
     def connected?
@@ -87,174 +295,99 @@ module Irc
     # open a TCP connection to the server
     def connect
       if connected?
-        debug "reconnecting socket while connected"
-        shutdown
+        warning "reconnecting while connected"
+        return
+      end
+      srv_uri = @server_list[@conn_count % @server_list.size].dup
+      srv_uri = 'irc://' + srv_uri if !(srv_uri =~ /:\/\//)
+      @conn_count += 1
+      @server_uri = URI.parse(srv_uri)
+      @server_uri.port = 6667 if !@server_uri.port
+
+      debug "connection attempt \##{@conn_count} (#{@server_uri.host}:#{@server_uri.port})"
+
+      # if the host is a bracketed (IPv6) address, strip the brackets
+      # since Ruby doesn't like them in the Socket host parameter
+      # FIXME it would be safer to have it check for a valid
+      # IPv6 bracketed address rather than just stripping the brackets
+      srv_host = @server_uri.host
+      if srv_host.match(/\A\[(.*)\]\z/)
+        srv_host = $1
       end
+
       if(@host)
         begin
-          @sock=TCPSocket.new(@server, @port, @host)
+          sock=TCPSocket.new(srv_host, @server_uri.port, @host)
         rescue ArgumentError => e
           error "Your version of ruby does not support binding to a "
           error "specific local address, please upgrade if you wish "
           error "to use HOST = foo"
           error "(this option has been disabled in order to continue)"
-          @sock=TCPSocket.new(@server, @port)
+          sock=TCPSocket.new(srv_host, @server_uri.port)
         end
       else
-        @sock=TCPSocket.new(@server, @port)
+        sock=TCPSocket.new(srv_host, @server_uri.port)
       end
-      @qthread = false
-      @qmutex = Mutex.new
-      @sendq = Array.new
-    end
-
-    def sendq_delay=(newfreq)
-      debug "changing sendq frequency to #{newfreq}"
-      @qmutex.synchronize do
-        @sendq_delay = newfreq
-        if newfreq == 0
-          clearq
-          @timer.stop
-        else
-          @timer.start
-        end
-      end
-    end
-
-    def sendq_burst=(newburst)
-      @qmutex.synchronize do
-        @sendq_burst = newburst
-      end
-    end
-
-    def byterate
-      return "#{@bytes_per}/#{@seconds_per}"
-    end
-
-    def byterate=(newrate)
-      @qmutex.synchronize do
-        setbyterate(newrate)
+      if(@ssl)
+        require 'openssl'
+        ssl_context = OpenSSL::SSL::SSLContext.new()
+        ssl_context.verify_mode = OpenSSL::SSL::VERIFY_NONE
+        sock = OpenSSL::SSL::SSLSocket.new(sock, ssl_context)
+        sock.sync_close = true
+        sock.connect
       end
+      @sock = sock
+      @last_send = Time.new
+      @flood_send = Time.new
+      @burst = 0
+      @sock.extend(MonitorMixin)
+      @sendq = MessageQueue.new
+      @qthread = Thread.new { writer_loop }
     end
 
-    def run_throttle(more=0)
-      now = Time.new
-      if @throttle_bytes > 0
-        # If we ever reach the limit, we halve the actual allowed byterate
-        # until we manage to reset the throttle.
-        # I don't know if this is the best way, though, because the real
-        # problem is probably non-queued messages like PINGs and PONGs.
-        # A better solution would probably be to have two queues,
-        # one for priority messages and another one for normal messages.
-        # Even better, we should have:
-        # * one queue for server stuff
-        # * one for each channel
-        # * one for each private communication
-        # The server queue would have priority, everything else would be served
-        # round-robin, so that someone making the bot flood one channel wouldn't
-        # prevent the bot from working on other channels (or private communications)
-        if @throttle_bytes >= @bytes_per
-          @throttle_div = 0.5
-        end
-        delta = ((now - @last_throttle)*@throttle_div*@bytes_per/@seconds_per).floor
-        if delta > 0
-          @throttle_bytes -= delta
-          @throttle_bytes = 0 if @throttle_bytes < 0
-          @last_throttle = now
-        end
-      else
-        @throttle_div = 1
+    # used to send lines to the remote IRCd by skipping the queue
+    # message: IRC message to send
+    # it should only be used for stuff that *must not* be queued,
+    # i.e. the initial PASS, NICK and USER command
+    # or the final QUIT message
+    def emergency_puts(message, penalty = false)
+      @sock.synchronize do
+        # debug "In puts - got @sock"
+        puts_critical(message, penalty)
       end
-      @throttle_bytes += more
     end
 
-    # used to send lines to the remote IRCd
-    # message: IRC message to send
-    def puts(message)
-      @qmutex.synchronize do
-        # debug "In puts - got mutex"
-        puts_critical(message)
-      end
+    def handle_socket_error(string, e)
+      error "#{string} failed: #{e.pretty_inspect}"
+      # We assume that an error means that there are connection
+      # problems and that we should reconnect, so we
+      shutdown
+      raise SocketError.new(e.inspect)
     end
 
     # get the next line from the server (blocks)
     def gets
       if @sock.nil?
-        debug "socket get attempted while closed"
+        warning "socket get attempted while closed"
         return nil
       end
       begin
-        reply = @sock.gets
+        reply = @filter.in(@sock.gets)
         @lines_received += 1
         reply.strip! if reply
         debug "RECV: #{reply.inspect}"
         return reply
-      rescue => e
-        debug "socket get failed: #{e.inspect}"
-        return nil
-      end
-    end
-
-    def queue(msg)
-      if @sendq_delay > 0
-        @qmutex.synchronize do
-          @sendq.push msg
-        end
-        @timer.start
-      else
-        # just send it if queueing is disabled
-        self.puts(msg)
+      rescue Exception => e
+        handle_socket_error(:RECV, e)
       end
     end
 
-    # pop a message off the queue, send it
-    def spool
-      if @sendq.empty?
-        @timer.stop
-        return
-      end
-      now = Time.new
-      if (now >= (@last_send + @sendq_delay))
-        # reset burst counter after @sendq_delay has passed
-        @burst = 0
-        debug "in spool, resetting @burst"
-      elsif (@burst >= @sendq_burst)
-        # nope. can't send anything, come back to us next tick...
-        @timer.start
-        return
-      end
-      @qmutex.synchronize do
-        debug "(can send #{@sendq_burst - @burst} lines, there are #{@sendq.length} to send)"
-        (@sendq_burst - @burst).times do
-          break if @sendq.empty?
-          mess = @sendq[0]
-          if @throttle_bytes == 0 or mess.length+@throttle_bytes < @bytes_per
-            debug "(flood protection: sending message of length #{mess.length})"
-           debug "(byterate: #{byterate}, throttle bytes: #{@throttle_bytes})"
-            puts_critical(@sendq.shift)
-          else
-            debug "(flood protection: throttling message of length #{mess.length})"
-           debug "(byterate: #{byterate}, throttle bytes: #{@throttle_bytes})"
-            run_throttle
-            break
-          end
-        end
-      end
-      if @sendq.empty?
-        @timer.stop
-      end
+    def queue(msg, chan=nil, ring=0)
+      @sendq.push msg, chan, ring
     end
 
     def clearq
-      if @sock
-        unless @sendq.empty?
-          @qmutex.synchronize do
-            @sendq.clear
-          end
-        end
-      else
-        debug "Clearing socket while disconnected"
-      end
+      @sendq.clear
     end
 
     # flush the TCPSocket
@@ -269,21 +402,64 @@ module Irc
 
     # shutdown the connection to the server
     def shutdown(how=2)
-      @sock.shutdown(how) unless @sock.nil?
+      return unless connected?
+      @qthread.kill
+      @qthread = nil
+      begin
+        @sock.close
+      rescue Exception => e
+        error "error while shutting down: #{e.pretty_inspect}"
+      end
       @sock = nil
+      @server_uri = nil
+      @sendq.clear
     end
 
     private
 
-    # same as puts, but expects to be called with a mutex held on @qmutex
-    def puts_critical(message)
+    def writer_loop
+      loop do
+        begin
+          now = Time.now
+          flood_delay = @flood_send - MAX_IRC_SEND_PENALTY - now
+          delay = [flood_delay, 0].max
+          if delay > 0
+            debug "sleep(#{delay}) # (f: #{flood_delay})"
+            sleep(delay)
+          end
+          msg = @sendq.shift
+          debug "got #{msg.inspect} from queue, sending"
+          emergency_puts(msg, true)
+        rescue Exception => e
+          error "Spooling failed: #{e.pretty_inspect}"
+          debug e.backtrace.join("\n")
+          raise e
+        end
+      end
+    end
+
+    # same as puts, but expects to be called with a lock held on @sock
+    def puts_critical(message, penalty=false)
       # debug "in puts_critical"
-      debug "SEND: #{message.inspect}"
-      @sock.send(message + "\n",0)
-      @last_send = Time.new
-      @lines_sent += 1
-      @burst += 1
-      run_throttle(message.length + 1)
+      begin
+        debug "SEND: #{message.inspect}"
+        if @sock.nil?
+          error "SEND attempted on closed socket"
+        else
+          # we use Socket#syswrite() instead of Socket#puts() because
+          # the latter is racy and can cause double message output in
+          # some circumstances
+          actual = @filter.out(message) + "\n"
+          now = Time.new
+          @sock.syswrite actual
+          @last_send = now
+          @flood_send = now if @flood_send < now
+          @flood_send += message.irc_send_penalty*@penalty_pct/100.0 if penalty
+          @lines_sent += 1
+        end
+      rescue Exception => e
+        handle_socket_error(:SEND, e)
+      end
     end
 
   end