##// END OF EJS Templates
Fix non-ascii spaces in comment....
Thomas Kluyver -
Show More
@@ -1,390 +1,390 b''
1 """Some generic utilities for dealing with classes, urls, and serialization."""
1 """Some generic utilities for dealing with classes, urls, and serialization."""
2
2
3 #Copyright (c) IPython Development Team.
3 # Copyright (c) IPython Development Team.
4 #Distributed under the terms of the Modified BSD License.
4 # Distributed under the terms of the Modified BSD License.
5
5
6 import logging
6 import logging
7 import os
7 import os
8 import re
8 import re
9 import stat
9 import stat
10 import socket
10 import socket
11 import sys
11 import sys
12 import warnings
12 import warnings
13 from signal import signal, SIGINT, SIGABRT, SIGTERM
13 from signal import signal, SIGINT, SIGABRT, SIGTERM
14 try:
14 try:
15 from signal import SIGKILL
15 from signal import SIGKILL
16 except ImportError:
16 except ImportError:
17 SIGKILL=None
17 SIGKILL=None
18 from types import FunctionType
18 from types import FunctionType
19
19
20 try:
20 try:
21 import cPickle
21 import cPickle
22 pickle = cPickle
22 pickle = cPickle
23 except:
23 except:
24 cPickle = None
24 cPickle = None
25 import pickle
25 import pickle
26
26
27 import zmq
27 import zmq
28 from zmq.log import handlers
28 from zmq.log import handlers
29
29
30 from IPython.external.decorator import decorator
30 from IPython.external.decorator import decorator
31
31
32 from IPython.config.application import Application
32 from IPython.config.application import Application
33 from IPython.utils.localinterfaces import localhost, is_public_ip, public_ips
33 from IPython.utils.localinterfaces import localhost, is_public_ip, public_ips
34 from IPython.utils.py3compat import string_types, iteritems, itervalues
34 from IPython.utils.py3compat import string_types, iteritems, itervalues
35 from IPython.kernel.zmq.log import EnginePUBHandler
35 from IPython.kernel.zmq.log import EnginePUBHandler
36 from IPython.kernel.zmq.serialize import (
36 from IPython.kernel.zmq.serialize import (
37 unserialize_object, serialize_object, pack_apply_message, unpack_apply_message
37 unserialize_object, serialize_object, pack_apply_message, unpack_apply_message
38 )
38 )
39
39
40 #-----------------------------------------------------------------------------
40 #-----------------------------------------------------------------------------
41 # Classes
41 # Classes
42 #-----------------------------------------------------------------------------
42 #-----------------------------------------------------------------------------
43
43
44 class Namespace(dict):
44 class Namespace(dict):
45 """Subclass of dict for attribute access to keys."""
45 """Subclass of dict for attribute access to keys."""
46
46
47 def __getattr__(self, key):
47 def __getattr__(self, key):
48 """getattr aliased to getitem"""
48 """getattr aliased to getitem"""
49 if key in self:
49 if key in self:
50 return self[key]
50 return self[key]
51 else:
51 else:
52 raise NameError(key)
52 raise NameError(key)
53
53
54 def __setattr__(self, key, value):
54 def __setattr__(self, key, value):
55 """setattr aliased to setitem, with strict"""
55 """setattr aliased to setitem, with strict"""
56 if hasattr(dict, key):
56 if hasattr(dict, key):
57 raise KeyError("Cannot override dict keys %r"%key)
57 raise KeyError("Cannot override dict keys %r"%key)
58 self[key] = value
58 self[key] = value
59
59
60
60
61 class ReverseDict(dict):
61 class ReverseDict(dict):
62 """simple double-keyed subset of dict methods."""
62 """simple double-keyed subset of dict methods."""
63
63
64 def __init__(self, *args, **kwargs):
64 def __init__(self, *args, **kwargs):
65 dict.__init__(self, *args, **kwargs)
65 dict.__init__(self, *args, **kwargs)
66 self._reverse = dict()
66 self._reverse = dict()
67 for key, value in iteritems(self):
67 for key, value in iteritems(self):
68 self._reverse[value] = key
68 self._reverse[value] = key
69
69
70 def __getitem__(self, key):
70 def __getitem__(self, key):
71 try:
71 try:
72 return dict.__getitem__(self, key)
72 return dict.__getitem__(self, key)
73 except KeyError:
73 except KeyError:
74 return self._reverse[key]
74 return self._reverse[key]
75
75
76 def __setitem__(self, key, value):
76 def __setitem__(self, key, value):
77 if key in self._reverse:
77 if key in self._reverse:
78 raise KeyError("Can't have key %r on both sides!"%key)
78 raise KeyError("Can't have key %r on both sides!"%key)
79 dict.__setitem__(self, key, value)
79 dict.__setitem__(self, key, value)
80 self._reverse[value] = key
80 self._reverse[value] = key
81
81
82 def pop(self, key):
82 def pop(self, key):
83 value = dict.pop(self, key)
83 value = dict.pop(self, key)
84 self._reverse.pop(value)
84 self._reverse.pop(value)
85 return value
85 return value
86
86
87 def get(self, key, default=None):
87 def get(self, key, default=None):
88 try:
88 try:
89 return self[key]
89 return self[key]
90 except KeyError:
90 except KeyError:
91 return default
91 return default
92
92
93 #-----------------------------------------------------------------------------
93 #-----------------------------------------------------------------------------
94 # Functions
94 # Functions
95 #-----------------------------------------------------------------------------
95 #-----------------------------------------------------------------------------
96
96
97 @decorator
97 @decorator
98 def log_errors(f, self, *args, **kwargs):
98 def log_errors(f, self, *args, **kwargs):
99 """decorator to log unhandled exceptions raised in a method.
99 """decorator to log unhandled exceptions raised in a method.
100
100
101 For use wrapping on_recv callbacks, so that exceptions
101 For use wrapping on_recv callbacks, so that exceptions
102 do not cause the stream to be closed.
102 do not cause the stream to be closed.
103 """
103 """
104 try:
104 try:
105 return f(self, *args, **kwargs)
105 return f(self, *args, **kwargs)
106 except Exception:
106 except Exception:
107 self.log.error("Uncaught exception in %r" % f, exc_info=True)
107 self.log.error("Uncaught exception in %r" % f, exc_info=True)
108
108
109
109
110 def is_url(url):
110 def is_url(url):
111 """boolean check for whether a string is a zmq url"""
111 """boolean check for whether a string is a zmq url"""
112 if '://' not in url:
112 if '://' not in url:
113 return False
113 return False
114 proto, addr = url.split('://', 1)
114 proto, addr = url.split('://', 1)
115 if proto.lower() not in ['tcp','pgm','epgm','ipc','inproc']:
115 if proto.lower() not in ['tcp','pgm','epgm','ipc','inproc']:
116 return False
116 return False
117 return True
117 return True
118
118
119 def validate_url(url):
119 def validate_url(url):
120 """validate a url for zeromq"""
120 """validate a url for zeromq"""
121 if not isinstance(url, string_types):
121 if not isinstance(url, string_types):
122 raise TypeError("url must be a string, not %r"%type(url))
122 raise TypeError("url must be a string, not %r"%type(url))
123 url = url.lower()
123 url = url.lower()
124
124
125 proto_addr = url.split('://')
125 proto_addr = url.split('://')
126 assert len(proto_addr) == 2, 'Invalid url: %r'%url
126 assert len(proto_addr) == 2, 'Invalid url: %r'%url
127 proto, addr = proto_addr
127 proto, addr = proto_addr
128 assert proto in ['tcp','pgm','epgm','ipc','inproc'], "Invalid protocol: %r"%proto
128 assert proto in ['tcp','pgm','epgm','ipc','inproc'], "Invalid protocol: %r"%proto
129
129
130 # domain pattern adapted from http://www.regexlib.com/REDetails.aspx?regexp_id=391
130 # domain pattern adapted from http://www.regexlib.com/REDetails.aspx?regexp_id=391
131 # author: Remi Sabourin
131 # author: Remi Sabourin
132 pat = re.compile(r'^([\w\d]([\w\d\-]{0,61}[\w\d])?\.)*[\w\d]([\w\d\-]{0,61}[\w\d])?$')
132 pat = re.compile(r'^([\w\d]([\w\d\-]{0,61}[\w\d])?\.)*[\w\d]([\w\d\-]{0,61}[\w\d])?$')
133
133
134 if proto == 'tcp':
134 if proto == 'tcp':
135 lis = addr.split(':')
135 lis = addr.split(':')
136 assert len(lis) == 2, 'Invalid url: %r'%url
136 assert len(lis) == 2, 'Invalid url: %r'%url
137 addr,s_port = lis
137 addr,s_port = lis
138 try:
138 try:
139 port = int(s_port)
139 port = int(s_port)
140 except ValueError:
140 except ValueError:
141 raise AssertionError("Invalid port %r in url: %r"%(port, url))
141 raise AssertionError("Invalid port %r in url: %r"%(port, url))
142
142
143 assert addr == '*' or pat.match(addr) is not None, 'Invalid url: %r'%url
143 assert addr == '*' or pat.match(addr) is not None, 'Invalid url: %r'%url
144
144
145 else:
145 else:
146 # only validate tcp urls currently
146 # only validate tcp urls currently
147 pass
147 pass
148
148
149 return True
149 return True
150
150
151
151
152 def validate_url_container(container):
152 def validate_url_container(container):
153 """validate a potentially nested collection of urls."""
153 """validate a potentially nested collection of urls."""
154 if isinstance(container, string_types):
154 if isinstance(container, string_types):
155 url = container
155 url = container
156 return validate_url(url)
156 return validate_url(url)
157 elif isinstance(container, dict):
157 elif isinstance(container, dict):
158 container = itervalues(container)
158 container = itervalues(container)
159
159
160 for element in container:
160 for element in container:
161 validate_url_container(element)
161 validate_url_container(element)
162
162
163
163
164 def split_url(url):
164 def split_url(url):
165 """split a zmq url (tcp://ip:port) into ('tcp','ip','port')."""
165 """split a zmq url (tcp://ip:port) into ('tcp','ip','port')."""
166 proto_addr = url.split('://')
166 proto_addr = url.split('://')
167 assert len(proto_addr) == 2, 'Invalid url: %r'%url
167 assert len(proto_addr) == 2, 'Invalid url: %r'%url
168 proto, addr = proto_addr
168 proto, addr = proto_addr
169 lis = addr.split(':')
169 lis = addr.split(':')
170 assert len(lis) == 2, 'Invalid url: %r'%url
170 assert len(lis) == 2, 'Invalid url: %r'%url
171 addr,s_port = lis
171 addr,s_port = lis
172 return proto,addr,s_port
172 return proto,addr,s_port
173
173
174
174
175 def disambiguate_ip_address(ip, location=None):
175 def disambiguate_ip_address(ip, location=None):
176 """turn multi-ip interfaces '0.0.0.0' and '*' into a connectable address
176 """turn multi-ip interfaces '0.0.0.0' and '*' into a connectable address
177
177
178 Explicit IP addresses are returned unmodified.
178 Explicit IP addresses are returned unmodified.
179
179
180 Parameters
180 Parameters
181 ----------
181 ----------
182
182
183 ip : IP address
183 ip : IP address
184 An IP address, or the special values 0.0.0.0, or *
184 An IP address, or the special values 0.0.0.0, or *
185 location: IP address, optional
185 location: IP address, optional
186 A public IP of the target machine.
186 A public IP of the target machine.
187 If location is an IP of the current machine,
187 If location is an IP of the current machine,
188 localhost will be returned,
188 localhost will be returned,
189 otherwise location will be returned.
189 otherwise location will be returned.
190 """
190 """
191 if ip in {'0.0.0.0', '*'}:
191 if ip in {'0.0.0.0', '*'}:
192 if not location:
192 if not location:
193 # unspecified location, localhost is the only choice
193 # unspecified location, localhost is the only choice
194 ip = localhost()
194 ip = localhost()
195 elif is_public_ip(location):
195 elif is_public_ip(location):
196 # location is a public IP on this machine, use localhost
196 # location is a public IP on this machine, use localhost
197 ip = localhost()
197 ip = localhost()
198 elif not public_ips():
198 elif not public_ips():
199 # this machine's public IPs cannot be determined,
199 # this machine's public IPs cannot be determined,
200 # assume `location` is not this machine
200 # assume `location` is not this machine
201 warnings.warn("IPython could not determine public IPs", RuntimeWarning)
201 warnings.warn("IPython could not determine public IPs", RuntimeWarning)
202 ip = location
202 ip = location
203 else:
203 else:
204 # location is not this machine, do not use loopback
204 # location is not this machine, do not use loopback
205 ip = location
205 ip = location
206 return ip
206 return ip
207
207
208
208
209 def disambiguate_url(url, location=None):
209 def disambiguate_url(url, location=None):
210 """turn multi-ip interfaces '0.0.0.0' and '*' into connectable
210 """turn multi-ip interfaces '0.0.0.0' and '*' into connectable
211 ones, based on the location (default interpretation is localhost).
211 ones, based on the location (default interpretation is localhost).
212
212
213 This is for zeromq urls, such as ``tcp://*:10101``.
213 This is for zeromq urls, such as ``tcp://*:10101``.
214 """
214 """
215 try:
215 try:
216 proto,ip,port = split_url(url)
216 proto,ip,port = split_url(url)
217 except AssertionError:
217 except AssertionError:
218 # probably not tcp url; could be ipc, etc.
218 # probably not tcp url; could be ipc, etc.
219 return url
219 return url
220
220
221 ip = disambiguate_ip_address(ip,location)
221 ip = disambiguate_ip_address(ip,location)
222
222
223 return "%s://%s:%s"%(proto,ip,port)
223 return "%s://%s:%s"%(proto,ip,port)
224
224
225
225
226 #--------------------------------------------------------------------------
226 #--------------------------------------------------------------------------
227 # helpers for implementing old MEC API via view.apply
227 # helpers for implementing old MEC API via view.apply
228 #--------------------------------------------------------------------------
228 #--------------------------------------------------------------------------
229
229
230 def interactive(f):
230 def interactive(f):
231 """decorator for making functions appear as interactively defined.
231 """decorator for making functions appear as interactively defined.
232 This results in the function being linked to the user_ns as globals()
232 This results in the function being linked to the user_ns as globals()
233 instead of the module globals().
233 instead of the module globals().
234 """
234 """
235
235
236 # build new FunctionType, so it can have the right globals
236 # build new FunctionType, so it can have the right globals
237 # interactive functions never have closures, that's kind of the point
237 # interactive functions never have closures, that's kind of the point
238 if isinstance(f, FunctionType):
238 if isinstance(f, FunctionType):
239 mainmod = __import__('__main__')
239 mainmod = __import__('__main__')
240 f = FunctionType(f.__code__, mainmod.__dict__,
240 f = FunctionType(f.__code__, mainmod.__dict__,
241 f.__name__, f.__defaults__,
241 f.__name__, f.__defaults__,
242 )
242 )
243 # associate with __main__ for uncanning
243 # associate with __main__ for uncanning
244 f.__module__ = '__main__'
244 f.__module__ = '__main__'
245 return f
245 return f
246
246
247 @interactive
247 @interactive
248 def _push(**ns):
248 def _push(**ns):
249 """helper method for implementing `client.push` via `client.apply`"""
249 """helper method for implementing `client.push` via `client.apply`"""
250 user_ns = globals()
250 user_ns = globals()
251 tmp = '_IP_PUSH_TMP_'
251 tmp = '_IP_PUSH_TMP_'
252 while tmp in user_ns:
252 while tmp in user_ns:
253 tmp = tmp + '_'
253 tmp = tmp + '_'
254 try:
254 try:
255 for name, value in ns.items():
255 for name, value in ns.items():
256 user_ns[tmp] = value
256 user_ns[tmp] = value
257 exec("%s = %s" % (name, tmp), user_ns)
257 exec("%s = %s" % (name, tmp), user_ns)
258 finally:
258 finally:
259 user_ns.pop(tmp, None)
259 user_ns.pop(tmp, None)
260
260
261 @interactive
261 @interactive
262 def _pull(keys):
262 def _pull(keys):
263 """helper method for implementing `client.pull` via `client.apply`"""
263 """helper method for implementing `client.pull` via `client.apply`"""
264 if isinstance(keys, (list,tuple, set)):
264 if isinstance(keys, (list,tuple, set)):
265 return [eval(key, globals()) for key in keys]
265 return [eval(key, globals()) for key in keys]
266 else:
266 else:
267 return eval(keys, globals())
267 return eval(keys, globals())
268
268
269 @interactive
269 @interactive
270 def _execute(code):
270 def _execute(code):
271 """helper method for implementing `client.execute` via `client.apply`"""
271 """helper method for implementing `client.execute` via `client.apply`"""
272 exec(code, globals())
272 exec(code, globals())
273
273
274 #--------------------------------------------------------------------------
274 #--------------------------------------------------------------------------
275 # extra process management utilities
275 # extra process management utilities
276 #--------------------------------------------------------------------------
276 #--------------------------------------------------------------------------
277
277
278 _random_ports = set()
278 _random_ports = set()
279
279
280 def select_random_ports(n):
280 def select_random_ports(n):
281 """Selects and return n random ports that are available."""
281 """Selects and return n random ports that are available."""
282 ports = []
282 ports = []
283 for i in range(n):
283 for i in range(n):
284 sock = socket.socket()
284 sock = socket.socket()
285 sock.bind(('', 0))
285 sock.bind(('', 0))
286 while sock.getsockname()[1] in _random_ports:
286 while sock.getsockname()[1] in _random_ports:
287 sock.close()
287 sock.close()
288 sock = socket.socket()
288 sock = socket.socket()
289 sock.bind(('', 0))
289 sock.bind(('', 0))
290 ports.append(sock)
290 ports.append(sock)
291 for i, sock in enumerate(ports):
291 for i, sock in enumerate(ports):
292 port = sock.getsockname()[1]
292 port = sock.getsockname()[1]
293 sock.close()
293 sock.close()
294 ports[i] = port
294 ports[i] = port
295 _random_ports.add(port)
295 _random_ports.add(port)
296 return ports
296 return ports
297
297
298 def signal_children(children):
298 def signal_children(children):
299 """Relay interupt/term signals to children, for more solid process cleanup."""
299 """Relay interupt/term signals to children, for more solid process cleanup."""
300 def terminate_children(sig, frame):
300 def terminate_children(sig, frame):
301 log = Application.instance().log
301 log = Application.instance().log
302 log.critical("Got signal %i, terminating children..."%sig)
302 log.critical("Got signal %i, terminating children..."%sig)
303 for child in children:
303 for child in children:
304 child.terminate()
304 child.terminate()
305
305
306 sys.exit(sig != SIGINT)
306 sys.exit(sig != SIGINT)
307 # sys.exit(sig)
307 # sys.exit(sig)
308 for sig in (SIGINT, SIGABRT, SIGTERM):
308 for sig in (SIGINT, SIGABRT, SIGTERM):
309 signal(sig, terminate_children)
309 signal(sig, terminate_children)
310
310
311 def generate_exec_key(keyfile):
311 def generate_exec_key(keyfile):
312 import uuid
312 import uuid
313 newkey = str(uuid.uuid4())
313 newkey = str(uuid.uuid4())
314 with open(keyfile, 'w') as f:
314 with open(keyfile, 'w') as f:
315 # f.write('ipython-key ')
315 # f.write('ipython-key ')
316 f.write(newkey+'\n')
316 f.write(newkey+'\n')
317 # set user-only RW permissions (0600)
317 # set user-only RW permissions (0600)
318 # this will have no effect on Windows
318 # this will have no effect on Windows
319 os.chmod(keyfile, stat.S_IRUSR|stat.S_IWUSR)
319 os.chmod(keyfile, stat.S_IRUSR|stat.S_IWUSR)
320
320
321
321
322 def integer_loglevel(loglevel):
322 def integer_loglevel(loglevel):
323 try:
323 try:
324 loglevel = int(loglevel)
324 loglevel = int(loglevel)
325 except ValueError:
325 except ValueError:
326 if isinstance(loglevel, str):
326 if isinstance(loglevel, str):
327 loglevel = getattr(logging, loglevel)
327 loglevel = getattr(logging, loglevel)
328 return loglevel
328 return loglevel
329
329
330 def connect_logger(logname, context, iface, root="ip", loglevel=logging.DEBUG):
330 def connect_logger(logname, context, iface, root="ip", loglevel=logging.DEBUG):
331 logger = logging.getLogger(logname)
331 logger = logging.getLogger(logname)
332 if any([isinstance(h, handlers.PUBHandler) for h in logger.handlers]):
332 if any([isinstance(h, handlers.PUBHandler) for h in logger.handlers]):
333 # don't add a second PUBHandler
333 # don't add a second PUBHandler
334 return
334 return
335 loglevel = integer_loglevel(loglevel)
335 loglevel = integer_loglevel(loglevel)
336 lsock = context.socket(zmq.PUB)
336 lsock = context.socket(zmq.PUB)
337 lsock.connect(iface)
337 lsock.connect(iface)
338 handler = handlers.PUBHandler(lsock)
338 handler = handlers.PUBHandler(lsock)
339 handler.setLevel(loglevel)
339 handler.setLevel(loglevel)
340 handler.root_topic = root
340 handler.root_topic = root
341 logger.addHandler(handler)
341 logger.addHandler(handler)
342 logger.setLevel(loglevel)
342 logger.setLevel(loglevel)
343
343
344 def connect_engine_logger(context, iface, engine, loglevel=logging.DEBUG):
344 def connect_engine_logger(context, iface, engine, loglevel=logging.DEBUG):
345 logger = logging.getLogger()
345 logger = logging.getLogger()
346 if any([isinstance(h, handlers.PUBHandler) for h in logger.handlers]):
346 if any([isinstance(h, handlers.PUBHandler) for h in logger.handlers]):
347 # don't add a second PUBHandler
347 # don't add a second PUBHandler
348 return
348 return
349 loglevel = integer_loglevel(loglevel)
349 loglevel = integer_loglevel(loglevel)
350 lsock = context.socket(zmq.PUB)
350 lsock = context.socket(zmq.PUB)
351 lsock.connect(iface)
351 lsock.connect(iface)
352 handler = EnginePUBHandler(engine, lsock)
352 handler = EnginePUBHandler(engine, lsock)
353 handler.setLevel(loglevel)
353 handler.setLevel(loglevel)
354 logger.addHandler(handler)
354 logger.addHandler(handler)
355 logger.setLevel(loglevel)
355 logger.setLevel(loglevel)
356 return logger
356 return logger
357
357
358 def local_logger(logname, loglevel=logging.DEBUG):
358 def local_logger(logname, loglevel=logging.DEBUG):
359 loglevel = integer_loglevel(loglevel)
359 loglevel = integer_loglevel(loglevel)
360 logger = logging.getLogger(logname)
360 logger = logging.getLogger(logname)
361 if any([isinstance(h, logging.StreamHandler) for h in logger.handlers]):
361 if any([isinstance(h, logging.StreamHandler) for h in logger.handlers]):
362 # don't add a second StreamHandler
362 # don't add a second StreamHandler
363 return
363 return
364 handler = logging.StreamHandler()
364 handler = logging.StreamHandler()
365 handler.setLevel(loglevel)
365 handler.setLevel(loglevel)
366 formatter = logging.Formatter("%(asctime)s.%(msecs).03d [%(name)s] %(message)s",
366 formatter = logging.Formatter("%(asctime)s.%(msecs).03d [%(name)s] %(message)s",
367 datefmt="%Y-%m-%d %H:%M:%S")
367 datefmt="%Y-%m-%d %H:%M:%S")
368 handler.setFormatter(formatter)
368 handler.setFormatter(formatter)
369
369
370 logger.addHandler(handler)
370 logger.addHandler(handler)
371 logger.setLevel(loglevel)
371 logger.setLevel(loglevel)
372 return logger
372 return logger
373
373
374 def set_hwm(sock, hwm=0):
374 def set_hwm(sock, hwm=0):
375 """set zmq High Water Mark on a socket
375 """set zmq High Water Mark on a socket
376
376
377 in a way that always works for various pyzmq / libzmq versions.
377 in a way that always works for various pyzmq / libzmq versions.
378 """
378 """
379 import zmq
379 import zmq
380
380
381 for key in ('HWM', 'SNDHWM', 'RCVHWM'):
381 for key in ('HWM', 'SNDHWM', 'RCVHWM'):
382 opt = getattr(zmq, key, None)
382 opt = getattr(zmq, key, None)
383 if opt is None:
383 if opt is None:
384 continue
384 continue
385 try:
385 try:
386 sock.setsockopt(opt, hwm)
386 sock.setsockopt(opt, hwm)
387 except zmq.ZMQError:
387 except zmq.ZMQError:
388 pass
388 pass
389
389
390 No newline at end of file
390
General Comments 0
You need to be logged in to leave comments. Login now