##// END OF EJS Templates
make @interactive decorator friendlier with dill...
MinRK -
Show More
@@ -1,370 +1,379 b''
1 """some generic utilities for dealing with classes, urls, and serialization
1 """some generic utilities for dealing with classes, urls, and serialization
2
2
3 Authors:
3 Authors:
4
4
5 * Min RK
5 * Min RK
6 """
6 """
7 #-----------------------------------------------------------------------------
7 #-----------------------------------------------------------------------------
8 # Copyright (C) 2010-2011 The IPython Development Team
8 # Copyright (C) 2010-2011 The IPython Development Team
9 #
9 #
10 # Distributed under the terms of the BSD License. The full license is in
10 # Distributed under the terms of the BSD License. The full license is in
11 # the file COPYING, distributed as part of this software.
11 # the file COPYING, distributed as part of this software.
12 #-----------------------------------------------------------------------------
12 #-----------------------------------------------------------------------------
13
13
14 #-----------------------------------------------------------------------------
14 #-----------------------------------------------------------------------------
15 # Imports
15 # Imports
16 #-----------------------------------------------------------------------------
16 #-----------------------------------------------------------------------------
17
17
18 # Standard library imports.
18 # Standard library imports.
19 import logging
19 import logging
20 import os
20 import os
21 import re
21 import re
22 import stat
22 import stat
23 import socket
23 import socket
24 import sys
24 import sys
25 from signal import signal, SIGINT, SIGABRT, SIGTERM
25 from signal import signal, SIGINT, SIGABRT, SIGTERM
26 try:
26 try:
27 from signal import SIGKILL
27 from signal import SIGKILL
28 except ImportError:
28 except ImportError:
29 SIGKILL=None
29 SIGKILL=None
30 from types import FunctionType
30
31
31 try:
32 try:
32 import cPickle
33 import cPickle
33 pickle = cPickle
34 pickle = cPickle
34 except:
35 except:
35 cPickle = None
36 cPickle = None
36 import pickle
37 import pickle
37
38
38 # System library imports
39 # System library imports
39 import zmq
40 import zmq
40 from zmq.log import handlers
41 from zmq.log import handlers
41
42
42 from IPython.external.decorator import decorator
43 from IPython.external.decorator import decorator
43
44
44 # IPython imports
45 # IPython imports
45 from IPython.config.application import Application
46 from IPython.config.application import Application
46 from IPython.utils.localinterfaces import localhost, is_public_ip, public_ips
47 from IPython.utils.localinterfaces import localhost, is_public_ip, public_ips
47 from IPython.utils.py3compat import string_types, iteritems, itervalues
48 from IPython.utils.py3compat import string_types, iteritems, itervalues
48 from IPython.kernel.zmq.log import EnginePUBHandler
49 from IPython.kernel.zmq.log import EnginePUBHandler
49 from IPython.kernel.zmq.serialize import (
50 from IPython.kernel.zmq.serialize import (
50 unserialize_object, serialize_object, pack_apply_message, unpack_apply_message
51 unserialize_object, serialize_object, pack_apply_message, unpack_apply_message
51 )
52 )
52
53
53 #-----------------------------------------------------------------------------
54 #-----------------------------------------------------------------------------
54 # Classes
55 # Classes
55 #-----------------------------------------------------------------------------
56 #-----------------------------------------------------------------------------
56
57
57 class Namespace(dict):
58 class Namespace(dict):
58 """Subclass of dict for attribute access to keys."""
59 """Subclass of dict for attribute access to keys."""
59
60
60 def __getattr__(self, key):
61 def __getattr__(self, key):
61 """getattr aliased to getitem"""
62 """getattr aliased to getitem"""
62 if key in self:
63 if key in self:
63 return self[key]
64 return self[key]
64 else:
65 else:
65 raise NameError(key)
66 raise NameError(key)
66
67
67 def __setattr__(self, key, value):
68 def __setattr__(self, key, value):
68 """setattr aliased to setitem, with strict"""
69 """setattr aliased to setitem, with strict"""
69 if hasattr(dict, key):
70 if hasattr(dict, key):
70 raise KeyError("Cannot override dict keys %r"%key)
71 raise KeyError("Cannot override dict keys %r"%key)
71 self[key] = value
72 self[key] = value
72
73
73
74
74 class ReverseDict(dict):
75 class ReverseDict(dict):
75 """simple double-keyed subset of dict methods."""
76 """simple double-keyed subset of dict methods."""
76
77
77 def __init__(self, *args, **kwargs):
78 def __init__(self, *args, **kwargs):
78 dict.__init__(self, *args, **kwargs)
79 dict.__init__(self, *args, **kwargs)
79 self._reverse = dict()
80 self._reverse = dict()
80 for key, value in iteritems(self):
81 for key, value in iteritems(self):
81 self._reverse[value] = key
82 self._reverse[value] = key
82
83
83 def __getitem__(self, key):
84 def __getitem__(self, key):
84 try:
85 try:
85 return dict.__getitem__(self, key)
86 return dict.__getitem__(self, key)
86 except KeyError:
87 except KeyError:
87 return self._reverse[key]
88 return self._reverse[key]
88
89
89 def __setitem__(self, key, value):
90 def __setitem__(self, key, value):
90 if key in self._reverse:
91 if key in self._reverse:
91 raise KeyError("Can't have key %r on both sides!"%key)
92 raise KeyError("Can't have key %r on both sides!"%key)
92 dict.__setitem__(self, key, value)
93 dict.__setitem__(self, key, value)
93 self._reverse[value] = key
94 self._reverse[value] = key
94
95
95 def pop(self, key):
96 def pop(self, key):
96 value = dict.pop(self, key)
97 value = dict.pop(self, key)
97 self._reverse.pop(value)
98 self._reverse.pop(value)
98 return value
99 return value
99
100
100 def get(self, key, default=None):
101 def get(self, key, default=None):
101 try:
102 try:
102 return self[key]
103 return self[key]
103 except KeyError:
104 except KeyError:
104 return default
105 return default
105
106
106 #-----------------------------------------------------------------------------
107 #-----------------------------------------------------------------------------
107 # Functions
108 # Functions
108 #-----------------------------------------------------------------------------
109 #-----------------------------------------------------------------------------
109
110
110 @decorator
111 @decorator
111 def log_errors(f, self, *args, **kwargs):
112 def log_errors(f, self, *args, **kwargs):
112 """decorator to log unhandled exceptions raised in a method.
113 """decorator to log unhandled exceptions raised in a method.
113
114
114 For use wrapping on_recv callbacks, so that exceptions
115 For use wrapping on_recv callbacks, so that exceptions
115 do not cause the stream to be closed.
116 do not cause the stream to be closed.
116 """
117 """
117 try:
118 try:
118 return f(self, *args, **kwargs)
119 return f(self, *args, **kwargs)
119 except Exception:
120 except Exception:
120 self.log.error("Uncaught exception in %r" % f, exc_info=True)
121 self.log.error("Uncaught exception in %r" % f, exc_info=True)
121
122
122
123
123 def is_url(url):
124 def is_url(url):
124 """boolean check for whether a string is a zmq url"""
125 """boolean check for whether a string is a zmq url"""
125 if '://' not in url:
126 if '://' not in url:
126 return False
127 return False
127 proto, addr = url.split('://', 1)
128 proto, addr = url.split('://', 1)
128 if proto.lower() not in ['tcp','pgm','epgm','ipc','inproc']:
129 if proto.lower() not in ['tcp','pgm','epgm','ipc','inproc']:
129 return False
130 return False
130 return True
131 return True
131
132
132 def validate_url(url):
133 def validate_url(url):
133 """validate a url for zeromq"""
134 """validate a url for zeromq"""
134 if not isinstance(url, string_types):
135 if not isinstance(url, string_types):
135 raise TypeError("url must be a string, not %r"%type(url))
136 raise TypeError("url must be a string, not %r"%type(url))
136 url = url.lower()
137 url = url.lower()
137
138
138 proto_addr = url.split('://')
139 proto_addr = url.split('://')
139 assert len(proto_addr) == 2, 'Invalid url: %r'%url
140 assert len(proto_addr) == 2, 'Invalid url: %r'%url
140 proto, addr = proto_addr
141 proto, addr = proto_addr
141 assert proto in ['tcp','pgm','epgm','ipc','inproc'], "Invalid protocol: %r"%proto
142 assert proto in ['tcp','pgm','epgm','ipc','inproc'], "Invalid protocol: %r"%proto
142
143
143 # domain pattern adapted from http://www.regexlib.com/REDetails.aspx?regexp_id=391
144 # domain pattern adapted from http://www.regexlib.com/REDetails.aspx?regexp_id=391
144 # author: Remi Sabourin
145 # author: Remi Sabourin
145 pat = re.compile(r'^([\w\d]([\w\d\-]{0,61}[\w\d])?\.)*[\w\d]([\w\d\-]{0,61}[\w\d])?$')
146 pat = re.compile(r'^([\w\d]([\w\d\-]{0,61}[\w\d])?\.)*[\w\d]([\w\d\-]{0,61}[\w\d])?$')
146
147
147 if proto == 'tcp':
148 if proto == 'tcp':
148 lis = addr.split(':')
149 lis = addr.split(':')
149 assert len(lis) == 2, 'Invalid url: %r'%url
150 assert len(lis) == 2, 'Invalid url: %r'%url
150 addr,s_port = lis
151 addr,s_port = lis
151 try:
152 try:
152 port = int(s_port)
153 port = int(s_port)
153 except ValueError:
154 except ValueError:
154 raise AssertionError("Invalid port %r in url: %r"%(port, url))
155 raise AssertionError("Invalid port %r in url: %r"%(port, url))
155
156
156 assert addr == '*' or pat.match(addr) is not None, 'Invalid url: %r'%url
157 assert addr == '*' or pat.match(addr) is not None, 'Invalid url: %r'%url
157
158
158 else:
159 else:
159 # only validate tcp urls currently
160 # only validate tcp urls currently
160 pass
161 pass
161
162
162 return True
163 return True
163
164
164
165
165 def validate_url_container(container):
166 def validate_url_container(container):
166 """validate a potentially nested collection of urls."""
167 """validate a potentially nested collection of urls."""
167 if isinstance(container, string_types):
168 if isinstance(container, string_types):
168 url = container
169 url = container
169 return validate_url(url)
170 return validate_url(url)
170 elif isinstance(container, dict):
171 elif isinstance(container, dict):
171 container = itervalues(container)
172 container = itervalues(container)
172
173
173 for element in container:
174 for element in container:
174 validate_url_container(element)
175 validate_url_container(element)
175
176
176
177
177 def split_url(url):
178 def split_url(url):
178 """split a zmq url (tcp://ip:port) into ('tcp','ip','port')."""
179 """split a zmq url (tcp://ip:port) into ('tcp','ip','port')."""
179 proto_addr = url.split('://')
180 proto_addr = url.split('://')
180 assert len(proto_addr) == 2, 'Invalid url: %r'%url
181 assert len(proto_addr) == 2, 'Invalid url: %r'%url
181 proto, addr = proto_addr
182 proto, addr = proto_addr
182 lis = addr.split(':')
183 lis = addr.split(':')
183 assert len(lis) == 2, 'Invalid url: %r'%url
184 assert len(lis) == 2, 'Invalid url: %r'%url
184 addr,s_port = lis
185 addr,s_port = lis
185 return proto,addr,s_port
186 return proto,addr,s_port
186
187
187 def disambiguate_ip_address(ip, location=None):
188 def disambiguate_ip_address(ip, location=None):
188 """turn multi-ip interfaces '0.0.0.0' and '*' into connectable
189 """turn multi-ip interfaces '0.0.0.0' and '*' into connectable
189 ones, based on the location (default interpretation of location is localhost)."""
190 ones, based on the location (default interpretation of location is localhost)."""
190 if ip in ('0.0.0.0', '*'):
191 if ip in ('0.0.0.0', '*'):
191 if location is None or is_public_ip(location) or not public_ips():
192 if location is None or is_public_ip(location) or not public_ips():
192 # If location is unspecified or cannot be determined, assume local
193 # If location is unspecified or cannot be determined, assume local
193 ip = localhost()
194 ip = localhost()
194 elif location:
195 elif location:
195 return location
196 return location
196 return ip
197 return ip
197
198
198 def disambiguate_url(url, location=None):
199 def disambiguate_url(url, location=None):
199 """turn multi-ip interfaces '0.0.0.0' and '*' into connectable
200 """turn multi-ip interfaces '0.0.0.0' and '*' into connectable
200 ones, based on the location (default interpretation is localhost).
201 ones, based on the location (default interpretation is localhost).
201
202
202 This is for zeromq urls, such as ``tcp://*:10101``.
203 This is for zeromq urls, such as ``tcp://*:10101``.
203 """
204 """
204 try:
205 try:
205 proto,ip,port = split_url(url)
206 proto,ip,port = split_url(url)
206 except AssertionError:
207 except AssertionError:
207 # probably not tcp url; could be ipc, etc.
208 # probably not tcp url; could be ipc, etc.
208 return url
209 return url
209
210
210 ip = disambiguate_ip_address(ip,location)
211 ip = disambiguate_ip_address(ip,location)
211
212
212 return "%s://%s:%s"%(proto,ip,port)
213 return "%s://%s:%s"%(proto,ip,port)
213
214
214
215
215 #--------------------------------------------------------------------------
216 #--------------------------------------------------------------------------
216 # helpers for implementing old MEC API via view.apply
217 # helpers for implementing old MEC API via view.apply
217 #--------------------------------------------------------------------------
218 #--------------------------------------------------------------------------
218
219
219 def interactive(f):
220 def interactive(f):
220 """decorator for making functions appear as interactively defined.
221 """decorator for making functions appear as interactively defined.
221 This results in the function being linked to the user_ns as globals()
222 This results in the function being linked to the user_ns as globals()
222 instead of the module globals().
223 instead of the module globals().
223 """
224 """
224 f.__module__ = '__main__'
225 mainmod = __import__('__main__')
225 return f
226
227 # build new FunctionType, so it can have the right globals
228 # interactive functions never have closures, that's kind of the point
229 f2 = FunctionType(f.__code__, mainmod.__dict__,
230 f.__name__, f.__defaults__,
231 )
232 # associate with __main__ for uncanning
233 f2.__module__ = '__main__'
234 return f2
226
235
227 @interactive
236 @interactive
228 def _push(**ns):
237 def _push(**ns):
229 """helper method for implementing `client.push` via `client.apply`"""
238 """helper method for implementing `client.push` via `client.apply`"""
230 user_ns = globals()
239 user_ns = globals()
231 tmp = '_IP_PUSH_TMP_'
240 tmp = '_IP_PUSH_TMP_'
232 while tmp in user_ns:
241 while tmp in user_ns:
233 tmp = tmp + '_'
242 tmp = tmp + '_'
234 try:
243 try:
235 for name, value in ns.items():
244 for name, value in ns.items():
236 user_ns[tmp] = value
245 user_ns[tmp] = value
237 exec("%s = %s" % (name, tmp), user_ns)
246 exec("%s = %s" % (name, tmp), user_ns)
238 finally:
247 finally:
239 user_ns.pop(tmp, None)
248 user_ns.pop(tmp, None)
240
249
241 @interactive
250 @interactive
242 def _pull(keys):
251 def _pull(keys):
243 """helper method for implementing `client.pull` via `client.apply`"""
252 """helper method for implementing `client.pull` via `client.apply`"""
244 if isinstance(keys, (list,tuple, set)):
253 if isinstance(keys, (list,tuple, set)):
245 return [eval(key, globals()) for key in keys]
254 return [eval(key, globals()) for key in keys]
246 else:
255 else:
247 return eval(keys, globals())
256 return eval(keys, globals())
248
257
249 @interactive
258 @interactive
250 def _execute(code):
259 def _execute(code):
251 """helper method for implementing `client.execute` via `client.apply`"""
260 """helper method for implementing `client.execute` via `client.apply`"""
252 exec(code, globals())
261 exec(code, globals())
253
262
254 #--------------------------------------------------------------------------
263 #--------------------------------------------------------------------------
255 # extra process management utilities
264 # extra process management utilities
256 #--------------------------------------------------------------------------
265 #--------------------------------------------------------------------------
257
266
258 _random_ports = set()
267 _random_ports = set()
259
268
260 def select_random_ports(n):
269 def select_random_ports(n):
261 """Selects and return n random ports that are available."""
270 """Selects and return n random ports that are available."""
262 ports = []
271 ports = []
263 for i in range(n):
272 for i in range(n):
264 sock = socket.socket()
273 sock = socket.socket()
265 sock.bind(('', 0))
274 sock.bind(('', 0))
266 while sock.getsockname()[1] in _random_ports:
275 while sock.getsockname()[1] in _random_ports:
267 sock.close()
276 sock.close()
268 sock = socket.socket()
277 sock = socket.socket()
269 sock.bind(('', 0))
278 sock.bind(('', 0))
270 ports.append(sock)
279 ports.append(sock)
271 for i, sock in enumerate(ports):
280 for i, sock in enumerate(ports):
272 port = sock.getsockname()[1]
281 port = sock.getsockname()[1]
273 sock.close()
282 sock.close()
274 ports[i] = port
283 ports[i] = port
275 _random_ports.add(port)
284 _random_ports.add(port)
276 return ports
285 return ports
277
286
278 def signal_children(children):
287 def signal_children(children):
279 """Relay interupt/term signals to children, for more solid process cleanup."""
288 """Relay interupt/term signals to children, for more solid process cleanup."""
280 def terminate_children(sig, frame):
289 def terminate_children(sig, frame):
281 log = Application.instance().log
290 log = Application.instance().log
282 log.critical("Got signal %i, terminating children..."%sig)
291 log.critical("Got signal %i, terminating children..."%sig)
283 for child in children:
292 for child in children:
284 child.terminate()
293 child.terminate()
285
294
286 sys.exit(sig != SIGINT)
295 sys.exit(sig != SIGINT)
287 # sys.exit(sig)
296 # sys.exit(sig)
288 for sig in (SIGINT, SIGABRT, SIGTERM):
297 for sig in (SIGINT, SIGABRT, SIGTERM):
289 signal(sig, terminate_children)
298 signal(sig, terminate_children)
290
299
291 def generate_exec_key(keyfile):
300 def generate_exec_key(keyfile):
292 import uuid
301 import uuid
293 newkey = str(uuid.uuid4())
302 newkey = str(uuid.uuid4())
294 with open(keyfile, 'w') as f:
303 with open(keyfile, 'w') as f:
295 # f.write('ipython-key ')
304 # f.write('ipython-key ')
296 f.write(newkey+'\n')
305 f.write(newkey+'\n')
297 # set user-only RW permissions (0600)
306 # set user-only RW permissions (0600)
298 # this will have no effect on Windows
307 # this will have no effect on Windows
299 os.chmod(keyfile, stat.S_IRUSR|stat.S_IWUSR)
308 os.chmod(keyfile, stat.S_IRUSR|stat.S_IWUSR)
300
309
301
310
302 def integer_loglevel(loglevel):
311 def integer_loglevel(loglevel):
303 try:
312 try:
304 loglevel = int(loglevel)
313 loglevel = int(loglevel)
305 except ValueError:
314 except ValueError:
306 if isinstance(loglevel, str):
315 if isinstance(loglevel, str):
307 loglevel = getattr(logging, loglevel)
316 loglevel = getattr(logging, loglevel)
308 return loglevel
317 return loglevel
309
318
310 def connect_logger(logname, context, iface, root="ip", loglevel=logging.DEBUG):
319 def connect_logger(logname, context, iface, root="ip", loglevel=logging.DEBUG):
311 logger = logging.getLogger(logname)
320 logger = logging.getLogger(logname)
312 if any([isinstance(h, handlers.PUBHandler) for h in logger.handlers]):
321 if any([isinstance(h, handlers.PUBHandler) for h in logger.handlers]):
313 # don't add a second PUBHandler
322 # don't add a second PUBHandler
314 return
323 return
315 loglevel = integer_loglevel(loglevel)
324 loglevel = integer_loglevel(loglevel)
316 lsock = context.socket(zmq.PUB)
325 lsock = context.socket(zmq.PUB)
317 lsock.connect(iface)
326 lsock.connect(iface)
318 handler = handlers.PUBHandler(lsock)
327 handler = handlers.PUBHandler(lsock)
319 handler.setLevel(loglevel)
328 handler.setLevel(loglevel)
320 handler.root_topic = root
329 handler.root_topic = root
321 logger.addHandler(handler)
330 logger.addHandler(handler)
322 logger.setLevel(loglevel)
331 logger.setLevel(loglevel)
323
332
324 def connect_engine_logger(context, iface, engine, loglevel=logging.DEBUG):
333 def connect_engine_logger(context, iface, engine, loglevel=logging.DEBUG):
325 logger = logging.getLogger()
334 logger = logging.getLogger()
326 if any([isinstance(h, handlers.PUBHandler) for h in logger.handlers]):
335 if any([isinstance(h, handlers.PUBHandler) for h in logger.handlers]):
327 # don't add a second PUBHandler
336 # don't add a second PUBHandler
328 return
337 return
329 loglevel = integer_loglevel(loglevel)
338 loglevel = integer_loglevel(loglevel)
330 lsock = context.socket(zmq.PUB)
339 lsock = context.socket(zmq.PUB)
331 lsock.connect(iface)
340 lsock.connect(iface)
332 handler = EnginePUBHandler(engine, lsock)
341 handler = EnginePUBHandler(engine, lsock)
333 handler.setLevel(loglevel)
342 handler.setLevel(loglevel)
334 logger.addHandler(handler)
343 logger.addHandler(handler)
335 logger.setLevel(loglevel)
344 logger.setLevel(loglevel)
336 return logger
345 return logger
337
346
338 def local_logger(logname, loglevel=logging.DEBUG):
347 def local_logger(logname, loglevel=logging.DEBUG):
339 loglevel = integer_loglevel(loglevel)
348 loglevel = integer_loglevel(loglevel)
340 logger = logging.getLogger(logname)
349 logger = logging.getLogger(logname)
341 if any([isinstance(h, logging.StreamHandler) for h in logger.handlers]):
350 if any([isinstance(h, logging.StreamHandler) for h in logger.handlers]):
342 # don't add a second StreamHandler
351 # don't add a second StreamHandler
343 return
352 return
344 handler = logging.StreamHandler()
353 handler = logging.StreamHandler()
345 handler.setLevel(loglevel)
354 handler.setLevel(loglevel)
346 formatter = logging.Formatter("%(asctime)s.%(msecs).03d [%(name)s] %(message)s",
355 formatter = logging.Formatter("%(asctime)s.%(msecs).03d [%(name)s] %(message)s",
347 datefmt="%Y-%m-%d %H:%M:%S")
356 datefmt="%Y-%m-%d %H:%M:%S")
348 handler.setFormatter(formatter)
357 handler.setFormatter(formatter)
349
358
350 logger.addHandler(handler)
359 logger.addHandler(handler)
351 logger.setLevel(loglevel)
360 logger.setLevel(loglevel)
352 return logger
361 return logger
353
362
354 def set_hwm(sock, hwm=0):
363 def set_hwm(sock, hwm=0):
355 """set zmq High Water Mark on a socket
364 """set zmq High Water Mark on a socket
356
365
357 in a way that always works for various pyzmq / libzmq versions.
366 in a way that always works for various pyzmq / libzmq versions.
358 """
367 """
359 import zmq
368 import zmq
360
369
361 for key in ('HWM', 'SNDHWM', 'RCVHWM'):
370 for key in ('HWM', 'SNDHWM', 'RCVHWM'):
362 opt = getattr(zmq, key, None)
371 opt = getattr(zmq, key, None)
363 if opt is None:
372 if opt is None:
364 continue
373 continue
365 try:
374 try:
366 sock.setsockopt(opt, hwm)
375 sock.setsockopt(opt, hwm)
367 except zmq.ZMQError:
376 except zmq.ZMQError:
368 pass
377 pass
369
378
370 No newline at end of file
379
General Comments 0
You need to be logged in to leave comments. Login now