#!/usr/bin/python

"""Multiplex various sockets to a single command channel."""
# (C) Copyright IBM Corp. 2008-2009
# Licensed under the GPLv2.
import cPickle as pickle
import errno
import io
import os
import select
import socket
import struct
import sys
import thread
import traceback

def res_unavailable_error(err):
	"""Determine if the error is due to lack of resource availability.

	Returns True when *err* is a socket.error whose error code is EAGAIN
	or EWOULDBLOCK, i.e. the "would block" result of a non-blocking socket
	operation, which the callers in this module treat as "try again later"
	rather than as a failure.  Returns False for anything else.
	"""
	if not isinstance(err, socket.error):
		return False
	# For socket.error(errno, msg), args[0] is the error code.  Compare it
	# against the errno constants instead of a hard-coded 11 (EAGAIN only
	# on Linux), and treat an error raised without arguments as "not a
	# would-block error" instead of raising IndexError.
	code = err.args[0] if err.args else None
	return code in (errno.EAGAIN, errno.EWOULDBLOCK)

# Intended size, in bytes, of each recv() from a client socket.
# NOTE(review): confirm the sockmux reader actually honours this value.
READ_BUFFER_SIZE = 1
class sockmux:
	"""Multiplex a controller with any number of TCP client sockets.

	A background thread (started by the constructor) accepts connections,
	sends every byte string handed to write() to all connected clients,
	and unpickles objects arriving from clients, passing each one to the
	controller.

	The controller must provide:
	  connect(sock, queue) -- return a true value to accept the client;
	      *queue* is the list of byte strings this mux sends to it.
	  command(obj) -- called once for every object received from a client.
	"""

	def __init__(self, controller, listen_port):
		"""Listen on *listen_port* on all interfaces and start the mux thread."""
		self.server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
		self.server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
		print("Listening on port %d" % listen_port)
		self.server_socket.bind(('0.0.0.0', listen_port))
		self.server_socket.listen(5)
		self.controller = controller
		self.should_close = False
		# Outgoing queues (lists of byte strings), keyed by client socket.
		self.queues = {}
		# Received bytes not yet decoded into an object, keyed by client socket.
		self.in_buffers = {}
		# Self-pipe: writing a byte to it wakes the select() in run_once().
		self.pipe = os.pipe()
		self.pipe_read = os.fdopen(self.pipe[0], "r")
		self.pipe_write = os.fdopen(self.pipe[1], "w")

		# This MUST come last!  The new thread uses the attributes above
		# immediately.
		thread.start_new_thread(self.run, ())

	def run(self):
		"""Service sockets until shut_down() is called, then close everything.

		Exceptions from a single iteration are printed and the loop goes on.
		"""
		# should_close is deliberately not reset here: __init__ already set
		# it, and resetting it would discard a shut_down() issued before
		# this thread got to run.
		while not self.should_close:
			try:
				self.run_once()
			except Exception as e:
				print(e)

		# shutdown() fails on the (unconnected) listening socket and on
		# dead peers; _close_socket tolerates that so everything gets closed.
		self._close_socket(self.server_socket)
		for cs in list(self.queues):
			self._close_socket(cs)

		self.pipe_read.close()
		self.pipe_write.close()

	def shut_down(self):
		"""Deactivate the mux: ask its thread to stop and wake it up."""
		self.should_close = True
		try:
			self._wake()
		except (ValueError, IOError, OSError):
			# The pipe is already closed, so run() has already finished.
			pass

	def write(self, object):
		"""Queue the byte string *object* for every connected client.

		The mux thread is woken so the data is sent promptly.
		"""
		for queue in list(self.queues.values()):
			queue.append(object)

		# Wake up select()
		self._wake()

	def kill_socket(self, cs):
		"""Forget client socket *cs* and close it; safe to call repeatedly."""
		self.queues.pop(cs, None)
		self.in_buffers.pop(cs, None)
		self._close_socket(cs)

	def _close_socket(self, sock):
		"""Shut down and close *sock*, even if shutdown() itself fails."""
		try:
			sock.shutdown(socket.SHUT_RDWR)
		except socket.error:
			# e.g. ENOTCONN for a dead peer or a listening socket; still
			# close the descriptor below so it doesn't leak.
			pass
		sock.close()

	def _wake(self):
		"""Write one byte to the self-pipe so a blocked select() returns."""
		self.pipe_write.write("0")
		self.pipe_write.flush()

	def _drain_wakeups(self):
		"""Discard the wake-up bytes that made the self-pipe readable."""
		# Read the raw descriptor: one call clears several pending wake-ups.
		os.read(self.pipe_read.fileno(), 512)

	def _accept(self):
		"""Accept one pending connection; keep it if the controller agrees."""
		try:
			(cs, addr) = self.server_socket.accept()
			print((cs, addr))
			cs.setblocking(0)
			queue = []
			if self.controller.connect(cs, queue):
				self.queues[cs] = queue
				self.in_buffers[cs] = b""
			else:
				# The controller refused this client.
				self._close_socket(cs)
		except socket.error:
			# e.g. the peer went away between select() and accept().
			pass

	def _write_socket(self, cs, queue):
		"""Send as much of *queue* to *cs* as it accepts without blocking.

		Fully sent entries are removed; a partially sent entry is replaced
		by its unsent tail.  Errors other than "would block" kill *cs*.
		"""
		try:
			while queue:
				obj = queue[0]
				sent = cs.send(obj)
				if sent == 0:
					return
				if sent == len(obj):
					del queue[0]
				else:
					queue[0] = obj[sent:]
		except Exception as e:
			# Don't fault on -EAGAIN
			if res_unavailable_error(e):
				return
			print(e)
			traceback.print_exc()
			self.kill_socket(cs)

	def _read_socket(self, cs):
		"""Read everything currently available on client socket *cs*.

		A readable socket that returns no data has hit EOF and is killed,
		as is one failing with anything other than "would block".
		"""
		try:
			data = cs.recv(READ_BUFFER_SIZE)
			# Readable but no data == EOF
			if not data:
				raise EOFError()
			while data:
				self._consume_input(cs, data)
				data = cs.recv(READ_BUFFER_SIZE)
		except Exception as e:
			# Don't kill socket if -EAGAIN
			if res_unavailable_error(e):
				return
			print(e)
			traceback.print_exc(file=sys.stdout)
			self.kill_socket(cs)

	def _consume_input(self, cs, data):
		"""Append *data* to the input buffer of *cs* and dispatch its objects.

		Each complete pickle at the front of the buffer is unpickled and
		passed to controller.command() exactly once, in arrival order; an
		incomplete trailing pickle stays buffered for the next call.
		"""
		# NOTE(security): this unpickles bytes read from the network, and
		# the server listens on all interfaces.  Anyone who can connect can
		# execute arbitrary code in this process.  Only expose the port to
		# trusted peers.
		buf = self.in_buffers[cs] + data
		self.in_buffers[cs] = buf
		while buf:
			stream = io.BytesIO(buf)
			try:
				obj = pickle.Unpickler(stream).load()
			except Exception:
				# A truncated pickle can fail with many exception types
				# (EOFError, UnpicklingError, ValueError, ...), so any
				# failure here means "wait for more bytes".
				break
			# Drop exactly the bytes this object used; whatever follows
			# belongs to the next object.  Record that before dispatching
			# so the object can never be decoded and dispatched twice.
			buf = buf[stream.tell():]
			self.in_buffers[cs] = buf
			try:
				self.controller.command(obj)
			except Exception:
				# A failing command must not drop the connection or
				# corrupt the input stream; report it and carry on.
				traceback.print_exc(file=sys.stdout)

	def run_once(self):
		"""Shovel objects between controller and sockets (one select() round).

		Blocks until a client or the listening socket is readable, a client
		with queued data is writable, or the self-pipe has been written.
		"""
		readers = [self.pipe_read, self.server_socket]
		writers = []
		exceptions = []

		# Nominate all sockets with pending writes for select, and
		# all sockets for reads
		for cs, queue in list(self.queues.items()):
			readers.append(cs)
			exceptions.append(cs)
			if queue:
				writers.append(cs)

		# Find sockets that aren't blocked
		(r, w, x) = select.select(readers, writers, exceptions)

		# NOTE(review): exceptional conditions (e.g. TCP out-of-band data)
		# are not expected; the assert surfaces one via run()'s print.
		assert len(x) == 0

		# If someone connects, tell the controller and add the
		# socket to our list.
		if self.server_socket in r:
			self._accept()

		# For all writers that can be written, write queued data
		for cs in w:
			queue = self.queues.get(cs)
			if queue is not None:
				self._write_socket(cs, queue)

		# Service readable client sockets and drain the wake-up pipe.
		for cs in r:
			if cs is self.pipe_read:
				self._drain_wakeups()
			elif cs is self.server_socket:
				continue
			elif cs in self.queues:
				# Skips sockets already killed by a failed write above.
				self._read_socket(cs)
