diff --git a/mercurial/util.py b/mercurial/util.py
--- a/mercurial/util.py
+++ b/mercurial/util.py
@@ -689,6 +689,132 @@ class observedbufferedinputpipe(buffered
 
         return res
 
+# Socket methods that ``socketproxy`` intercepts to notify its observer.
+# All other attribute access is forwarded directly to the wrapped socket.
+PROXIED_SOCKET_METHODS = {
+    r'makefile',
+    r'recv',
+    r'recvfrom',
+    r'recvfrom_into',
+    r'recv_into',
+    r'send',
+    r'sendall',
+    r'sendto',
+    r'setblocking',
+    r'settimeout',
+    r'gettimeout',
+    r'setsockopt',
+}
+
+class socketproxy(object):
+    """A proxy around a socket that tells a watcher when events occur.
+
+    This is like ``fileobjectproxy`` except for sockets.
+
+    This type is intended to only be used for testing purposes. Think hard
+    before using it in important code.
+    """
+    __slots__ = (
+        r'_orig',
+        r'_observer',
+    )
+
+    def __init__(self, sock, observer):
+        object.__setattr__(self, r'_orig', sock)
+        object.__setattr__(self, r'_observer', observer)
+
+    def __getattribute__(self, name):
+        if name in PROXIED_SOCKET_METHODS:
+            return object.__getattribute__(self, name)
+
+        return getattr(object.__getattribute__(self, r'_orig'), name)
+
+    def __delattr__(self, name):
+        return delattr(object.__getattribute__(self, r'_orig'), name)
+
+    def __setattr__(self, name, value):
+        return setattr(object.__getattribute__(self, r'_orig'), name, value)
+
+    def __nonzero__(self):
+        return bool(object.__getattribute__(self, r'_orig'))
+
+    __bool__ = __nonzero__
+
+    def _observedcall(self, name, *args, **kwargs):
+        """Call ``name`` on the wrapped socket, then notify the observer.
+
+        The observer's method of the same name, if it has one, is called
+        with the result followed by the original arguments.
+        """
+        # Call the original object.
+        orig = object.__getattribute__(self, r'_orig')
+        res = getattr(orig, name)(*args, **kwargs)
+
+        # Call a method on the observer of the same name with arguments
+        # so it can react, log, etc.
+        observer = object.__getattribute__(self, r'_observer')
+        fn = getattr(observer, name, None)
+        if fn:
+            fn(res, *args, **kwargs)
+
+        return res
+
+    def makefile(self, *args, **kwargs):
+        res = object.__getattribute__(self, r'_observedcall')(
+            r'makefile', *args, **kwargs)
+
+        # The file object may be used for I/O. So we turn it into a
+        # proxy using our observer.
+        observer = object.__getattribute__(self, r'_observer')
+        return makeloggingfileobject(observer.fh, res, observer.name,
+                                     reads=observer.reads,
+                                     writes=observer.writes,
+                                     logdata=observer.logdata)
+
+    def recv(self, *args, **kwargs):
+        return object.__getattribute__(self, r'_observedcall')(
+            r'recv', *args, **kwargs)
+
+    def recvfrom(self, *args, **kwargs):
+        return object.__getattribute__(self, r'_observedcall')(
+            r'recvfrom', *args, **kwargs)
+
+    def recvfrom_into(self, *args, **kwargs):
+        return object.__getattribute__(self, r'_observedcall')(
+            r'recvfrom_into', *args, **kwargs)
+
+    def recv_into(self, *args, **kwargs):
+        return object.__getattribute__(self, r'_observedcall')(
+            r'recv_into', *args, **kwargs)
+
+    def send(self, *args, **kwargs):
+        return object.__getattribute__(self, r'_observedcall')(
+            r'send', *args, **kwargs)
+
+    def sendall(self, *args, **kwargs):
+        return object.__getattribute__(self, r'_observedcall')(
+            r'sendall', *args, **kwargs)
+
+    def sendto(self, *args, **kwargs):
+        return object.__getattribute__(self, r'_observedcall')(
+            r'sendto', *args, **kwargs)
+
+    def setblocking(self, *args, **kwargs):
+        return object.__getattribute__(self, r'_observedcall')(
+            r'setblocking', *args, **kwargs)
+
+    def settimeout(self, *args, **kwargs):
+        return object.__getattribute__(self, r'_observedcall')(
+            r'settimeout', *args, **kwargs)
+
+    def gettimeout(self, *args, **kwargs):
+        return object.__getattribute__(self, r'_observedcall')(
+            r'gettimeout', *args, **kwargs)
+
+    def setsockopt(self, *args, **kwargs):
+        return object.__getattribute__(self, r'_observedcall')(
+            r'setsockopt', *args, **kwargs)
+
 DATA_ESCAPE_MAP = {pycompat.bytechr(i): br'\x%02x' % i for i in range(256)}
 DATA_ESCAPE_MAP.update({
     b'\\': b'\\\\',
@@ -703,15 +829,11 @@ def escapedata(s):
 
     return DATA_ESCAPE_RE.sub(lambda m: DATA_ESCAPE_MAP[m.group(0)], s)
 
-class fileobjectobserver(object):
-    """Logs file object activity."""
-    def __init__(self, fh, name, reads=True, writes=True, logdata=False):
-        self.fh = fh
-        self.name = name
-        self.logdata = logdata
-        self.reads = reads
-        self.writes = writes
-
+class baseproxyobserver(object):
+    """Base class for observers that log proxied I/O activity.
+
+    Subclasses are expected to set ``fh``, ``name`` and ``logdata``.
+    """
     def _writedata(self, data):
         if not self.logdata:
             self.fh.write('\n')
@@ -731,6 +853,15 @@ class fileobjectobserver(object):
             self.fh.write('%s> %s\n' % (self.name, escapedata(line)))
         self.fh.flush()
 
+class fileobjectobserver(baseproxyobserver):
+    """Logs file object activity."""
+    def __init__(self, fh, name, reads=True, writes=True, logdata=False):
+        self.fh = fh
+        self.name = name
+        self.logdata = logdata
+        self.reads = reads
+        self.writes = writes
+
     def read(self, res, size=-1):
         if not self.reads:
             return
@@ -793,6 +924,130 @@ def makeloggingfileobject(logh, fh, name
                                   logdata=logdata)
     return fileobjectproxy(fh, observer)
 
+class socketobserver(baseproxyobserver):
+    """Logs socket activity.
+
+    Each method receives the result of the proxied socket call followed by
+    the arguments that call was made with. ``reads``, ``writes`` and
+    ``states`` toggle logging of receive calls, send calls and all other
+    socket calls, respectively.
+    """
+    def __init__(self, fh, name, reads=True, writes=True, states=True,
+                 logdata=False):
+        self.fh = fh
+        self.name = name
+        self.reads = reads
+        self.writes = writes
+        self.states = states
+        self.logdata = logdata
+
+    def makefile(self, res, mode=None, bufsize=None):
+        if not self.states:
+            return
+
+        self.fh.write('%s> makefile(%r, %r)\n' % (
+            self.name, mode, bufsize))
+
+    def recv(self, res, size, flags=0):
+        if not self.reads:
+            return
+
+        self.fh.write('%s> recv(%d, %d) -> %d' % (
+            self.name, size, flags, len(res)))
+        self._writedata(res)
+
+    def recvfrom(self, res, size, flags=0):
+        if not self.reads:
+            return
+
+        self.fh.write('%s> recvfrom(%d, %d) -> %d' % (
+            self.name, size, flags, len(res[0])))
+        self._writedata(res[0])
+
+    def recvfrom_into(self, res, buf, size=0, flags=0):
+        if not self.reads:
+            return
+
+        # res is a (nbytes, address) tuple.
+        self.fh.write('%s> recvfrom_into(%d, %d) -> %d' % (
+            self.name, size, flags, res[0]))
+        self._writedata(buf[0:res[0]])
+
+    def recv_into(self, res, buf, size=0, flags=0):
+        if not self.reads:
+            return
+
+        self.fh.write('%s> recv_into(%d, %d) -> %d' % (
+            self.name, size, flags, res))
+        self._writedata(buf[0:res])
+
+    def send(self, res, data, flags=0):
+        if not self.writes:
+            return
+
+        # send() returns the number of bytes sent, not the data.
+        self.fh.write('%s> send(%d, %d) -> %d' % (
+            self.name, len(data), flags, res))
+        self._writedata(data)
+
+    def sendall(self, res, data, flags=0):
+        if not self.writes:
+            return
+
+        # Returns None on success. So don't bother reporting return value.
+        self.fh.write('%s> sendall(%d, %d)' % (
+            self.name, len(data), flags))
+        self._writedata(data)
+
+    def sendto(self, res, data, flagsoraddress, address=None):
+        if not self.writes:
+            return
+
+        if address is not None:
+            flags = flagsoraddress
+        else:
+            # Called as sendto(data, address): there are no flags.
+            flags = 0
+            address = flagsoraddress
+
+        self.fh.write('%s> sendto(%d, %d, %r) -> %d' % (
+            self.name, len(data), flags, address, res))
+        self._writedata(data)
+
+    def setblocking(self, res, flag):
+        if not self.states:
+            return
+
+        self.fh.write('%s> setblocking(%r)\n' % (self.name, flag))
+
+    def settimeout(self, res, value):
+        if not self.states:
+            return
+
+        self.fh.write('%s> settimeout(%r)\n' % (self.name, value))
+
+    def gettimeout(self, res):
+        if not self.states:
+            return
+
+        # res is None when no timeout is set, which %f cannot format.
+        self.fh.write('%s> gettimeout() -> %r\n' % (self.name, res))
+
+    def setsockopt(self, res, level, optname, value):
+        if not self.states:
+            return
+
+        self.fh.write('%s> setsockopt(%r, %r, %r) -> %r\n' % (
+            self.name, level, optname, value, res))
+
+def makeloggingsocket(logh, fh, name, reads=True, writes=True, states=True,
+                      logdata=False):
+    """Turn a socket into a logging socket."""
+
+    observer = socketobserver(logh, name, reads=reads, writes=writes,
+                              states=states, logdata=logdata)
+    return socketproxy(fh, observer)
+
 def version():
     """Return version information if available."""
     try: