diff options
Diffstat (limited to 'python/src/tcf/channel/AbstractChannel.py')
-rw-r--r-- | python/src/tcf/channel/AbstractChannel.py | 753 |
1 files changed, 753 insertions, 0 deletions
diff --git a/python/src/tcf/channel/AbstractChannel.py b/python/src/tcf/channel/AbstractChannel.py new file mode 100644 index 000000000..f343d58ef --- /dev/null +++ b/python/src/tcf/channel/AbstractChannel.py @@ -0,0 +1,753 @@ +# ******************************************************************************* +# * Copyright (c) 2011 Wind River Systems, Inc. and others. +# * All rights reserved. This program and the accompanying materials +# * are made available under the terms of the Eclipse Public License v1.0 +# * which accompanies this distribution, and is available at +# * http://www.eclipse.org/legal/epl-v10.html +# * +# * Contributors: +# * Wind River Systems - initial API and implementation +# ******************************************************************************* + +import sys, threading, time, types +from exceptions import Exception, IOError +from tcf import protocol, transport, services, peer, errors +from tcf.services import locator +from tcf.channel import STATE_CLOSED, STATE_OPEN, STATE_OPENING +from tcf.channel import Token, fromJSONSequence, toJSONSequence + +EOS = -1 # End Of Stream +EOM = -2 # End Of Message + + +class Message(object): + def __init__(self, typeCode): + if type(typeCode) is types.IntType: + typeCode = chr(typeCode) + self.type = typeCode + self.service = None + self.name = None + self.data = None + self.is_canceled = None + self.is_sent = None + self.token = None + self.trace = () + def __str__(self): + return "%s %s %s" % (self.type, self.service, self.name) + +class ReaderThread(threading.Thread): + def __init__(self, channel, handleInput): + super(ReaderThread, self).__init__(name="TCF Reader Thread") + self.channel = channel + self.handleInput = handleInput + self.buf = bytearray() + self.eos_err_report = None + self.daemon = True + + def error(self): + raise IOError("Protocol syntax error") + + def readBytes(self, end, buf=None): + if buf is None: + buf = bytearray() + while True: + n = self.channel.read() + if n <= 0: + if n == end: break + if n == EOM: raise IOError("Unexpected end of message") + if n < 0: raise IOError("Communication channel is closed by remote peer") + buf.append(n) + return buf + + def readString(self): + del self.buf[:] + bytes = self.readBytes(0, self.buf) + return bytes.decode("UTF8") + + def run(self): + try: + while True: + n = self.channel.read() + if n == EOM: continue + if n == EOS: + try: + self.eos_err_report = self.readBytes(EOM) + reportLen = len(self.eos_err_report) + if reportLen == 0 or reportLen == 1 and self.eos_err_report[0] == 0: + self.eos_err_report = None + except: + pass + break + msg = Message(n) + if self.channel.read() != 0: self.error() + typeCode = msg.type + if typeCode == 'C': + msg.token = Token(self.readBytes(0)) + msg.service = self.readString() + msg.name = self.readString() + msg.data = self.readBytes(EOM) + elif typeCode in 'PRN': + msg.token = Token(self.readBytes(0)) + msg.data = self.readBytes(EOM) + elif typeCode == 'E': + msg.service = self.readString() + msg.name = self.readString() + msg.data = self.readBytes(EOM) + elif typeCode == 'F': + msg.data = self.readBytes(EOM) + else: + self.error() + protocol.invokeLater(self.handleInput, msg) + delay = self.channel.local_congestion_level + if delay > 0: time.sleep(delay / 1000.0) + protocol.invokeLater(self.handleEOS) + except Exception as x: + try: + x.tb = sys.exc_info()[2] + protocol.invokeLater(self.channel.terminate, x) + except: + # TCF event dispatcher has shut down + pass + + def handleEOS(self): + if not self.channel.out_tokens and not self.eos_err_report: + self.channel.close() + else: + x = IOError("Communication channel is closed by remote peer") + if self.eos_err_report: + try: + args = fromJSONSequence(self.eos_err_report) + if len(args) > 0 and args[0] is not None: + x.caused_by = Exception(errors.toErrorString(args[0])) + except IOError: + pass + self.channel.terminate(x) + + +class AbstractChannel(object): + """ + AbstractChannel implements communication link connecting two end points (peers). + The channel asynchronously transmits messages: commands, results and events. + + Clients can subclass AbstractChannel to support particular transport (wire) protocol. + Also, see StreamChannel for stream oriented transport protocols. + """ + + def __init__(self, remote_peer, local_peer=None): + self.remote_peer = remote_peer + self.local_peer = local_peer # TODO + self.inp_thread = ReaderThread(self, self.__handleInput) + self.out_thread = threading.Thread(target=self.__write_output,name="TCF Channel Transmitter") + self.out_thread.daemon = True + self.out_tokens = {} + self.out_queue = [] + self.out_lock = threading.Condition() + self.pending_command_limit = 32 + self.remote_service_by_class = {} + self.local_service_by_name = {} + self.remote_service_by_name = {} + self.channel_listeners = [] + self.event_listeners = {} + self.command_servers = {} + self.redirect_queue = [] + self.redirect_command = None + self.notifying_channel_opened = False + self.registered_with_trasport = False + self.state = STATE_OPENING + self.proxy = None + self.zero_copy = False + + self.local_congestion_level = -100 + self.remote_congestion_level = -100 + self.local_congestion_cnt = 0 + self.local_congestion_time = 0 + self.local_service_by_class = {} + self.trace_listeners = [] + + def __write_output(self): + try: + while True: + msg = None + last = False + with self.out_lock: + while len(self.out_queue) == 0: + self.out_lock.wait() + msg = self.out_queue.pop(0) + if not msg: break + last = len(self.out_queue) == 0 + if msg.is_canceled: + if last: self.flush() + continue + msg.is_sent = True + if msg.trace: + protocol.invokeLater(self.__traceMessageSent, msg) + self.write(msg.type) + self.write(0) + if msg.token: + self.write(msg.token.id) + self.write(0) + if msg.service: + self.write(msg.service.encode("UTF8")) + self.write(0) + if msg.name: + self.write(msg.name.encode("UTF8")) + self.write(0) + if msg.data: + self.write(msg.data) + self.write(EOM) + delay = 0 + level = self.remote_congestion_level + if level > 0: delay = level * 10 + if last or delay > 0: self.flush() + if delay > 0: time.sleep(delay / 1000.0) + #else yield() + self.write(EOS) + self.write(EOM) + self.flush() + except Exception as x: + try: + protocol.invokeLater(self.terminate, x) + except: + # TCF event dispatcher has shut down + pass + + def __traceMessageSent(self, m): + for l in m.trace: + try: + id = None + if m.token is not None: + id = m.token.getID() + l.onMessageSent(m.type, id, m.service, m.name, m.data) + except Exception as x: + protocol.log("Exception in channel listener", x) + + def start(self): + assert protocol.isDispatchThread() + protocol.invokeLater(self.__initServices) + self.inp_thread.start() + self.out_thread.start() + + def __initServices(self): + try: + if self.proxy: return + if self.state == STATE_CLOSED: return + services.onChannelCreated(self, self.local_service_by_name) + self.__makeServiceByClassMap(self.local_service_by_name, self.local_service_by_class) + args = self.local_service_by_name.keys() + self.sendEvent(protocol.getLocator(), "Hello", toJSONSequence((args,))) + except IOError as x: + self.terminate(x) + + def redirect_id(self, peer_id): + """ + Redirect this channel to given peer using this channel remote peer locator service as a proxy. + @param peer_id - peer that will become new remote communication endpoint of this channel + """ + map = {} + map[peer.ATTR_ID] = peer_id + self.redirect(map) + + def redirect(self, peer_attrs): + """ + Redirect this channel to given peer using this channel remote peer locator service as a proxy. + @param peer_attrs - peer that will become new remote communication endpoint of this channel + """ + channel = self + assert protocol.isDispatchThread() + if self.state == STATE_OPENING: + self.redirect_queue.append(peer_attrs) + else: + assert self.state == STATE_OPEN + assert self.redirect_command is None + try: + l = self.remote_service_by_class.get(locator.LocatorService) + if not l: raise IOError("Cannot redirect channel: peer " + + self.remote_peer.getID() + " has no locator service") + peer_id = peer_attrs.get(peer.ATTR_ID) + if peer_id and len(peer_attrs) == 1: + peer = l.getPeers().get(peer_id) + if not peer: + # Peer not found, must wait for a while until peer is discovered or time out + class Callback(object): + found = None + def __call__(self): + if self.found: return + self.channel.terminate(Exception("Peer " + peer_id + " not found")) + cb = Callback() + protocol.invokeLaterWithDelay(locator.DATA_RETENTION_PERIOD / 3, cb) + class Listener(locator.LocatorListener): + def peerAdded(self, peer): + if peer.getID() == peer_id: + cb.found = True + channel.state = STATE_OPEN + l.removeListener(self) + channel.redirect_id(peer_id) + l.addListener(Listener()) + else: + class DoneRedirect(locator.DoneRedirect): + def doneRedirect(self, token, exc): + assert channel.redirect_command is token + channel.redirect_command = None + if channel.state != STATE_OPENING: return + if exc: channel.terminate(exc) + channel.remote_peer = peer + channel.remote_service_by_class.clear() + channel.remote_service_by_name.clear() + channel.event_listeners.clear() + self.redirect_command = l.redirect(peer_id, DoneRedirect()) + else: + class TransientPeer(peer.TransientPeer): + def __init__(self, peer_attrs, parent): + super(TransientPeer, self).__init__(peer_attrs) + self.parent = parent + def openChannel(self): + c = self.parent.openChannel() + c.redirect(peer_attrs) + class DoneRedirect(locator.DoneRedirect): + def doneRedirect(self, token, exc): + assert channel.redirect_command is token + channel.redirect_command = None + if channel.state != STATE_OPENING: return + if exc: channel.terminate(exc) + parent = channel.remote_peer + channel.remote_peer = TransientPeer(peer_attrs, parent) + channel.remote_service_by_class.clear() + channel.remote_service_by_name.clear() + channel.event_listeners.clear() + self.redirect_command = l.redirect(peer_attrs, DoneRedirect()) + self.state = STATE_OPENING + except Exception as x: + self.terminate(x) + + def __makeServiceByClassMap(self, by_name, by_class): + for service in by_name.values(): + for clazz in service.__class__.__bases__: + if clazz == services.Service: continue + # TODO + # if (!IService.class.isAssignableFrom(fs)) continue + by_class[clazz] = service + + def getState(self): + return self.state + + def addChannelListener(self, listener): + assert protocol.isDispatchThread() + assert listener + self.channel_listeners.append(listener) + + def removeChannelListener(self, listener): + assert protocol.isDispatchThread() + self.channel_listeners.remove(listener) + + def addTraceListener(self, listener): + if self.trace_listeners is None: + self.trace_listeners = [] + else: + self.trace_listeners = self.trace_listeners[:] + self.trace_listeners.append(listener) + + def removeTraceListener(self, listener): + self.trace_listeners = self.trace_listeners[:] + self.trace_listeners.remove(listener) + if len(self.trace_listeners) == 0: self.trace_listeners = None + + def addEventListener(self, service, listener): + assert protocol.isDispatchThread() + svc_name = str(service) + listener.svc_name = svc_name + list = self.event_listeners.get(svc_name) or [] + list.append(listener) + self.event_listeners[svc_name] = list + + def removeEventListener(self, service, listener): + assert protocol.isDispatchThread() + svc_name = str(service) + list = self.event_listeners.get(svc_name) + if not list: return + for i in range(len(list)): + if list[i] is listener: + if len(list) == 1: + del self.event_listeners[svc_name] + else: + del list[i] + return + + def addCommandServer(self, service, listener): + assert protocol.isDispatchThread() + svc_name = str(service) + if self.command_servers.get(svc_name): + raise Exception("Only one command server per service is allowed") + self.command_servers[svc_name] = listener + + def removeCommandServer(self, service, listener): + assert protocol.isDispatchThread() + svc_name = str(service) + if self.command_servers.get(svc_name) is not listener: + raise Exception("Invalid command server") + del self.command_servers[svc_name] + + def close(self): + assert protocol.isDispatchThread() + if self.state == STATE_CLOSED: return + try: + self.__sendEndOfStream(10000) + self._close(None) + except Exception as x: + self._close(x) + + def terminate(self, error): + assert protocol.isDispatchThread() + if self.state == STATE_CLOSED: return + try: + self.__sendEndOfStream(500) + except Exception as x: + if not error: error = x + self._close(error) + + def __sendEndOfStream(self, timeout): + with self.out_lock: + del self.out_queue[:] + self.out_queue.append(None) + self.out_lock.notify() + self.out_thread.join(timeout) + + def _close(self, error): + assert self.state != STATE_CLOSED + self.state = STATE_CLOSED + # Closing channel underlying streams can block for a long time, + # so it needs to be done by a background thread. + thread = threading.Thread(target=self.stop, name="TCF Channel Cleanup") + thread.daemon = True + thread.start() + if error and isinstance(self.remote_peer, peer.AbstractPeer): + self.remote_peer.onChannelTerminated() + if self.registered_with_trasport: + self.registered_with_trasport = False + transport.channelClosed(self, error) + if self.proxy: + try: + self.proxy.onChannelClosed(error) + except Exception as x: + protocol.log("Exception in channel listener", x) + channel = self + class Runnable(object): + def __call__(self): + if channel.out_tokens: + x = None + if isinstance(error, Exception): x = error + elif error: x = Exception(error) + else: x = IOError("Channel is closed") + for msg in channel.out_tokens.values(): + try: + s = str(msg) + if len(s) > 72: s = s[:72] + "...]" + y = IOError("Command " + s + " aborted") +# y.initCause(x) + msg.token.getListener().terminated(msg.token, y) + except Exception as e: + protocol.log("Exception in command listener", e) + channel.out_tokens.clear() + if channel.channel_listeners: + for l in channel.channel_listeners: + if not l: break + try: + l.onChannelClosed(error) + except Exception as x: + protocol.log("Exception in channel listener", x) + elif error: + protocol.log("TCF channel terminated", error) + if channel.trace_listeners: + for l in channel.trace_listeners: + try: + l.onChannelClosed(error) + except Exception as x: + protocol.log("Exception in channel listener", x) + protocol.invokeLater(Runnable()) + + def getCongestion(self): + assert protocol.isDispatchThread() + level = len(self.out_tokens) * 100 / self.pending_command_limit - 100 + if self.remote_congestion_level > level: level = self.remote_congestion_level + if level > 100: level = 100 + return level + + def getLocalPeer(self): + assert protocol.isDispatchThread() + return self.local_peer + + def getRemotePeer(self): + assert protocol.isDispatchThread() + return self.remote_peer + + def getLocalServices(self): + assert protocol.isDispatchThread() + assert self.state != STATE_OPENING + return self.local_service_by_name.keys() + + def getRemoteServices(self): + assert protocol.isDispatchThread() + assert self.state != STATE_OPENING + return self.remote_service_by_name.keys() + + def getLocalService(self, cls_or_name): + assert protocol.isDispatchThread() + assert self.state != STATE_OPENING + if type(cls_or_name) == types.StringType: + return self.local_service_by_name.get(cls_or_name) + else: + return self.local_service_by_class.get(cls_or_name) + + def getRemoteService(self, cls_or_name): + assert protocol.isDispatchThread() + assert self.state != STATE_OPENING + if type(cls_or_name) == types.StringType: + return self.remote_service_by_name.get(cls_or_name) + else: + return self.remote_service_by_class.get(cls_or_name) + + def setServiceProxy(self, service_interface, service_proxy): + if not self.notifying_channel_opened: raise Exception("setServiceProxe() can be called only from channel open call-back") + if not isinstance(self.remote_service_by_name.get(service_proxy.getName()), services.GenericProxy): raise Exception("Proxy already set") + if self.remote_service_by_class.get(service_interface): raise Exception("Proxy already set") + self.remote_service_by_class[service_interface] = service_proxy + self.remote_service_by_name[service_proxy.getName()] = service_proxy + + def setProxy(self, proxy, services): + self.proxy = proxy + self.sendEvent(protocol.getLocator(), "Hello", toJSONSequence((services,))) + self.local_service_by_class.clear() + self.local_service_by_name.clear() + + def addToOutQueue(self, msg): + msg.trace = self.trace_listeners + with self.out_lock: + self.out_queue.append(msg) + self.out_lock.notify() + + def sendCommand(self, service, name, args, listener): + assert protocol.isDispatchThread() + if self.state == STATE_OPENING: raise Exception("Channel is waiting for Hello message") + if self.state == STATE_CLOSED: raise Exception("Channel is closed") + msg = Message('C') + msg.service = str(service) + msg.name = name + msg.data = args + channel = self + class CancelableToken(Token): + def __init__(self, listener): + super(CancelableToken, self).__init__(listener=listener) + def cancel(self): + assert protocol.isDispatchThread() + if channel.state != STATE_OPEN: return False + with channel.out_lock: + if msg.is_sent: return False + msg.is_canceled = True + del channel.out_tokens[msg.token.getID()] + return True + token = CancelableToken(listener) + msg.token = token + self.out_tokens[token.getID()] = msg + self.addToOutQueue(msg) + return token + + def sendProgress(self, token, results): + assert protocol.isDispatchThread() + if self.state != STATE_OPEN: raise Exception("Channel is closed") + msg = Message('P') + msg.data = results + msg.token = token + self.addToOutQueue(msg) + + def sendResult(self, token, results): + assert protocol.isDispatchThread() + if self.state != STATE_OPEN: raise Exception("Channel is closed") + msg = Message('R') + msg.data = results + msg.token = token + self.addToOutQueue(msg) + + def rejectCommand(self, token): + assert protocol.isDispatchThread() + if self.state != STATE_OPEN: raise Exception("Channel is closed") + msg = Message('N') + msg.token = token + self.addToOutQueue(msg) + + def sendEvent(self, service, name, args): + assert protocol.isDispatchThread() + if not (self.state == STATE_OPEN or self.state == STATE_OPENING and isinstance(service, locator.LocatorService)): + raise Exception("Channel is closed") + msg = Message('E') + msg.service = str(service) + msg.name = name + msg.data = args + self.addToOutQueue(msg) + + def isZeroCopySupported(self): + return self.zero_copy + + def __traceMessageReceived(self, m): + for l in self.trace_listeners: + try: + id = None + if m.token is not None: + id = m.token.getID() + l.onMessageReceived(m.type, id, m.service, m.name, m.data) + except Exception as x: + protocol.log("Exception in channel listener", x) + + def __handleInput(self, msg): + assert protocol.isDispatchThread() + if self.state == STATE_CLOSED: return + if self.trace_listeners: + self.__traceMessageReceived(msg) + try: + token = None + typeCode = msg.type + if typeCode in 'PRN': + token_id = msg.token.getID() + cmd = self.out_tokens.get(token_id) + if cmd is None: raise Exception("Invalid token received: " + token_id) + if typeCode != 'P': + del self.out_tokens[token_id] + token = cmd.token + if typeCode == 'C': + if self.state == STATE_OPENING: + raise IOError("Received command " + msg.service + "." + msg.name + " before Hello message") + if self.proxy: + self.proxy.onCommand(msg.token, msg.service, msg.name, msg.data) + else: + token = msg.token + cmds = self.command_servers.get(msg.service) + if cmds: + cmds.command(token, msg.name, msg.data) + else: + self.rejectCommand(token) + elif typeCode == 'P': + token.getListener().progress(token, msg.data) + self.__sendCongestionLevel() + elif typeCode == 'R': + token.getListener().result(token, msg.data) + self.__sendCongestionLevel() + elif typeCode == 'N': + token.getListener().terminated(token, errors.ErrorReport( + "Command is not recognized", errors.TCF_ERROR_INV_COMMAND)) + elif typeCode == 'E': + hello = msg.service == locator.NAME and msg.name == "Hello" + if hello: + self.remote_service_by_name.clear() + self.remote_service_by_class.clear() + data = fromJSONSequence(msg.data)[0] + services.onChannelOpened(self, data, self.remote_service_by_name) + self.__makeServiceByClassMap(self.remote_service_by_name, self.remote_service_by_class) + self.zero_copy = self.remote_service_by_name.has_key("ZeroCopy") + if self.proxy and self.state == STATE_OPEN: + self.proxy.onEvent(msg.service, msg.name, msg.data) + elif hello: + assert self.state == STATE_OPENING + self.state = STATE_OPEN + assert self.redirect_command is None + if self.redirect_queue: + self.redirect(self.redirect_queue.pop(0)) + else: + self.notifying_channel_opened = True + if not self.registered_with_trasport: + transport.channelOpened(self) + self.registered_with_trasport = True + for l in self.channel_listeners: + if not l: break + try: + l.onChannelOpened() + except Exception as x: + protocol.log("Exception in channel listener", x) + self.notifying_channel_opened = False + else: + list = self.event_listeners.get(msg.service) + if list: + for l in list: + l.event(msg.name, msg.data) + self.__sendCongestionLevel() + elif typeCode == 'F': + len = len(msg.data) + if len > 0 and msg.data[len - 1] == '\0': len -= 1 + self.remote_congestion_level = int(msg.data) + else: + assert False + except Exception as x: + x.tb = sys.exc_info()[2] + self.terminate(x) + + def __sendCongestionLevel(self): + self.local_congestion_cnt += 1 + if self.local_congestion_cnt < 8: return + self.local_congestion_cnt = 0 + if self.state != STATE_OPEN: return + timeVal = int(time.time()) + if timeVal - self.local_congestion_time < 500: return + assert protocol.isDispatchThread() + level = protocol.getCongestionLevel() + if level == self.local_congestion_level: return + i = (level - self.local_congestion_level) / 8 + if i != 0: level = self.local_congestion_level + i + self.local_congestion_time = timeVal + with self.out_lock: + msg = None + if self.out_queue: + msg = self.out_queue[0] + if msg is None or msg.type != 'F': + msg = Message('F') + self.out_queue.insert(0, msg) + self.out_lock.notify() + data = "%i\0" % self.local_congestion_level + msg.data = data + msg.trace = self.trace_listeners + self.local_congestion_level = level + + def read(self): + """ + Read one byte from the channel input stream. + @return next data byte or EOS (-1) if end of stream is reached, + or EOM (-2) if end of message is reached. + @raises IOError + """ + raise NotImplementedError("Abstract method") + + def writeByte(self, n): + """ + Write one byte into the channel output stream. + The method argument can be one of two special values: + EOS (-1) end of stream marker + EOM (-2) end of message marker. + The stream can put the byte into a buffer instead of transmitting it right away. + @param n - the data byte. + @raises IOError + """ + raise NotImplementedError("Abstract method") + + def flush(self): + """ + Flush the channel output stream. + All buffered data should be transmitted immediately. + @raises IOError + """ + raise NotImplementedError("Abstract method") + + def stop(self): + """ + Stop (close) channel underlying streams. + If a thread is blocked by read() or write(), it should be + resumed (or interrupted). + @raises IOError + """ + raise NotImplementedError("Abstract method") + + def write(self, buf): + """ + Write array of bytes into the channel output stream. + The stream can put bytes into a buffer instead of transmitting it right away. + @param buf + @raises IOError + """ + assert threading.currentThread() == self.out_thread + for i in buf: + self.writeByte(ord(buf[i]) & 0xff) |