##// END OF EJS Templates
make @interactive decorator friendlier with dill...
MinRK -
Show More
@@ -1,370 +1,379 b''
1 1 """some generic utilities for dealing with classes, urls, and serialization
2 2
3 3 Authors:
4 4
5 5 * Min RK
6 6 """
7 7 #-----------------------------------------------------------------------------
8 8 # Copyright (C) 2010-2011 The IPython Development Team
9 9 #
10 10 # Distributed under the terms of the BSD License. The full license is in
11 11 # the file COPYING, distributed as part of this software.
12 12 #-----------------------------------------------------------------------------
13 13
14 14 #-----------------------------------------------------------------------------
15 15 # Imports
16 16 #-----------------------------------------------------------------------------
17 17
18 18 # Standard library imports.
19 19 import logging
20 20 import os
21 21 import re
22 22 import stat
23 23 import socket
24 24 import sys
25 25 from signal import signal, SIGINT, SIGABRT, SIGTERM
26 26 try:
27 27 from signal import SIGKILL
28 28 except ImportError:
29 29 SIGKILL=None
30 from types import FunctionType
30 31
31 32 try:
32 33 import cPickle
33 34 pickle = cPickle
34 35 except:
35 36 cPickle = None
36 37 import pickle
37 38
38 39 # System library imports
39 40 import zmq
40 41 from zmq.log import handlers
41 42
42 43 from IPython.external.decorator import decorator
43 44
44 45 # IPython imports
45 46 from IPython.config.application import Application
46 47 from IPython.utils.localinterfaces import localhost, is_public_ip, public_ips
47 48 from IPython.utils.py3compat import string_types, iteritems, itervalues
48 49 from IPython.kernel.zmq.log import EnginePUBHandler
49 50 from IPython.kernel.zmq.serialize import (
50 51 unserialize_object, serialize_object, pack_apply_message, unpack_apply_message
51 52 )
52 53
53 54 #-----------------------------------------------------------------------------
54 55 # Classes
55 56 #-----------------------------------------------------------------------------
56 57
57 58 class Namespace(dict):
58 59 """Subclass of dict for attribute access to keys."""
59 60
60 61 def __getattr__(self, key):
61 62 """getattr aliased to getitem"""
62 63 if key in self:
63 64 return self[key]
64 65 else:
65 66 raise NameError(key)
66 67
67 68 def __setattr__(self, key, value):
68 69 """setattr aliased to setitem, with strict"""
69 70 if hasattr(dict, key):
70 71 raise KeyError("Cannot override dict keys %r"%key)
71 72 self[key] = value
72 73
73 74
74 75 class ReverseDict(dict):
75 76 """simple double-keyed subset of dict methods."""
76 77
77 78 def __init__(self, *args, **kwargs):
78 79 dict.__init__(self, *args, **kwargs)
79 80 self._reverse = dict()
80 81 for key, value in iteritems(self):
81 82 self._reverse[value] = key
82 83
83 84 def __getitem__(self, key):
84 85 try:
85 86 return dict.__getitem__(self, key)
86 87 except KeyError:
87 88 return self._reverse[key]
88 89
89 90 def __setitem__(self, key, value):
90 91 if key in self._reverse:
91 92 raise KeyError("Can't have key %r on both sides!"%key)
92 93 dict.__setitem__(self, key, value)
93 94 self._reverse[value] = key
94 95
95 96 def pop(self, key):
96 97 value = dict.pop(self, key)
97 98 self._reverse.pop(value)
98 99 return value
99 100
100 101 def get(self, key, default=None):
101 102 try:
102 103 return self[key]
103 104 except KeyError:
104 105 return default
105 106
106 107 #-----------------------------------------------------------------------------
107 108 # Functions
108 109 #-----------------------------------------------------------------------------
109 110
110 111 @decorator
111 112 def log_errors(f, self, *args, **kwargs):
112 113 """decorator to log unhandled exceptions raised in a method.
113 114
114 115 For use wrapping on_recv callbacks, so that exceptions
115 116 do not cause the stream to be closed.
116 117 """
117 118 try:
118 119 return f(self, *args, **kwargs)
119 120 except Exception:
120 121 self.log.error("Uncaught exception in %r" % f, exc_info=True)
121 122
122 123
123 124 def is_url(url):
124 125 """boolean check for whether a string is a zmq url"""
125 126 if '://' not in url:
126 127 return False
127 128 proto, addr = url.split('://', 1)
128 129 if proto.lower() not in ['tcp','pgm','epgm','ipc','inproc']:
129 130 return False
130 131 return True
131 132
132 133 def validate_url(url):
133 134 """validate a url for zeromq"""
134 135 if not isinstance(url, string_types):
135 136 raise TypeError("url must be a string, not %r"%type(url))
136 137 url = url.lower()
137 138
138 139 proto_addr = url.split('://')
139 140 assert len(proto_addr) == 2, 'Invalid url: %r'%url
140 141 proto, addr = proto_addr
141 142 assert proto in ['tcp','pgm','epgm','ipc','inproc'], "Invalid protocol: %r"%proto
142 143
143 144 # domain pattern adapted from http://www.regexlib.com/REDetails.aspx?regexp_id=391
144 145 # author: Remi Sabourin
145 146 pat = re.compile(r'^([\w\d]([\w\d\-]{0,61}[\w\d])?\.)*[\w\d]([\w\d\-]{0,61}[\w\d])?$')
146 147
147 148 if proto == 'tcp':
148 149 lis = addr.split(':')
149 150 assert len(lis) == 2, 'Invalid url: %r'%url
150 151 addr,s_port = lis
151 152 try:
152 153 port = int(s_port)
153 154 except ValueError:
154 155 raise AssertionError("Invalid port %r in url: %r"%(port, url))
155 156
156 157 assert addr == '*' or pat.match(addr) is not None, 'Invalid url: %r'%url
157 158
158 159 else:
159 160 # only validate tcp urls currently
160 161 pass
161 162
162 163 return True
163 164
164 165
165 166 def validate_url_container(container):
166 167 """validate a potentially nested collection of urls."""
167 168 if isinstance(container, string_types):
168 169 url = container
169 170 return validate_url(url)
170 171 elif isinstance(container, dict):
171 172 container = itervalues(container)
172 173
173 174 for element in container:
174 175 validate_url_container(element)
175 176
176 177
177 178 def split_url(url):
178 179 """split a zmq url (tcp://ip:port) into ('tcp','ip','port')."""
179 180 proto_addr = url.split('://')
180 181 assert len(proto_addr) == 2, 'Invalid url: %r'%url
181 182 proto, addr = proto_addr
182 183 lis = addr.split(':')
183 184 assert len(lis) == 2, 'Invalid url: %r'%url
184 185 addr,s_port = lis
185 186 return proto,addr,s_port
186 187
187 188 def disambiguate_ip_address(ip, location=None):
188 189 """turn multi-ip interfaces '0.0.0.0' and '*' into connectable
189 190 ones, based on the location (default interpretation of location is localhost)."""
190 191 if ip in ('0.0.0.0', '*'):
191 192 if location is None or is_public_ip(location) or not public_ips():
192 193 # If location is unspecified or cannot be determined, assume local
193 194 ip = localhost()
194 195 elif location:
195 196 return location
196 197 return ip
197 198
198 199 def disambiguate_url(url, location=None):
199 200 """turn multi-ip interfaces '0.0.0.0' and '*' into connectable
200 201 ones, based on the location (default interpretation is localhost).
201 202
202 203 This is for zeromq urls, such as ``tcp://*:10101``.
203 204 """
204 205 try:
205 206 proto,ip,port = split_url(url)
206 207 except AssertionError:
207 208 # probably not tcp url; could be ipc, etc.
208 209 return url
209 210
210 211 ip = disambiguate_ip_address(ip,location)
211 212
212 213 return "%s://%s:%s"%(proto,ip,port)
213 214
214 215
215 216 #--------------------------------------------------------------------------
216 217 # helpers for implementing old MEC API via view.apply
217 218 #--------------------------------------------------------------------------
218 219
219 220 def interactive(f):
220 221 """decorator for making functions appear as interactively defined.
221 222 This results in the function being linked to the user_ns as globals()
222 223 instead of the module globals().
223 224 """
224 f.__module__ = '__main__'
225 return f
225 mainmod = __import__('__main__')
226
227 # build new FunctionType, so it can have the right globals
228 # interactive functions never have closures, that's kind of the point
229 f2 = FunctionType(f.__code__, mainmod.__dict__,
230 f.__name__, f.__defaults__,
231 )
232 # associate with __main__ for uncanning
233 f2.__module__ = '__main__'
234 return f2
226 235
227 236 @interactive
228 237 def _push(**ns):
229 238 """helper method for implementing `client.push` via `client.apply`"""
230 239 user_ns = globals()
231 240 tmp = '_IP_PUSH_TMP_'
232 241 while tmp in user_ns:
233 242 tmp = tmp + '_'
234 243 try:
235 244 for name, value in ns.items():
236 245 user_ns[tmp] = value
237 246 exec("%s = %s" % (name, tmp), user_ns)
238 247 finally:
239 248 user_ns.pop(tmp, None)
240 249
241 250 @interactive
242 251 def _pull(keys):
243 252 """helper method for implementing `client.pull` via `client.apply`"""
244 253 if isinstance(keys, (list,tuple, set)):
245 254 return [eval(key, globals()) for key in keys]
246 255 else:
247 256 return eval(keys, globals())
248 257
249 258 @interactive
250 259 def _execute(code):
251 260 """helper method for implementing `client.execute` via `client.apply`"""
252 261 exec(code, globals())
253 262
254 263 #--------------------------------------------------------------------------
255 264 # extra process management utilities
256 265 #--------------------------------------------------------------------------
257 266
258 267 _random_ports = set()
259 268
260 269 def select_random_ports(n):
261 270 """Selects and return n random ports that are available."""
262 271 ports = []
263 272 for i in range(n):
264 273 sock = socket.socket()
265 274 sock.bind(('', 0))
266 275 while sock.getsockname()[1] in _random_ports:
267 276 sock.close()
268 277 sock = socket.socket()
269 278 sock.bind(('', 0))
270 279 ports.append(sock)
271 280 for i, sock in enumerate(ports):
272 281 port = sock.getsockname()[1]
273 282 sock.close()
274 283 ports[i] = port
275 284 _random_ports.add(port)
276 285 return ports
277 286
278 287 def signal_children(children):
279 288 """Relay interupt/term signals to children, for more solid process cleanup."""
280 289 def terminate_children(sig, frame):
281 290 log = Application.instance().log
282 291 log.critical("Got signal %i, terminating children..."%sig)
283 292 for child in children:
284 293 child.terminate()
285 294
286 295 sys.exit(sig != SIGINT)
287 296 # sys.exit(sig)
288 297 for sig in (SIGINT, SIGABRT, SIGTERM):
289 298 signal(sig, terminate_children)
290 299
291 300 def generate_exec_key(keyfile):
292 301 import uuid
293 302 newkey = str(uuid.uuid4())
294 303 with open(keyfile, 'w') as f:
295 304 # f.write('ipython-key ')
296 305 f.write(newkey+'\n')
297 306 # set user-only RW permissions (0600)
298 307 # this will have no effect on Windows
299 308 os.chmod(keyfile, stat.S_IRUSR|stat.S_IWUSR)
300 309
301 310
302 311 def integer_loglevel(loglevel):
303 312 try:
304 313 loglevel = int(loglevel)
305 314 except ValueError:
306 315 if isinstance(loglevel, str):
307 316 loglevel = getattr(logging, loglevel)
308 317 return loglevel
309 318
310 319 def connect_logger(logname, context, iface, root="ip", loglevel=logging.DEBUG):
311 320 logger = logging.getLogger(logname)
312 321 if any([isinstance(h, handlers.PUBHandler) for h in logger.handlers]):
313 322 # don't add a second PUBHandler
314 323 return
315 324 loglevel = integer_loglevel(loglevel)
316 325 lsock = context.socket(zmq.PUB)
317 326 lsock.connect(iface)
318 327 handler = handlers.PUBHandler(lsock)
319 328 handler.setLevel(loglevel)
320 329 handler.root_topic = root
321 330 logger.addHandler(handler)
322 331 logger.setLevel(loglevel)
323 332
324 333 def connect_engine_logger(context, iface, engine, loglevel=logging.DEBUG):
325 334 logger = logging.getLogger()
326 335 if any([isinstance(h, handlers.PUBHandler) for h in logger.handlers]):
327 336 # don't add a second PUBHandler
328 337 return
329 338 loglevel = integer_loglevel(loglevel)
330 339 lsock = context.socket(zmq.PUB)
331 340 lsock.connect(iface)
332 341 handler = EnginePUBHandler(engine, lsock)
333 342 handler.setLevel(loglevel)
334 343 logger.addHandler(handler)
335 344 logger.setLevel(loglevel)
336 345 return logger
337 346
338 347 def local_logger(logname, loglevel=logging.DEBUG):
339 348 loglevel = integer_loglevel(loglevel)
340 349 logger = logging.getLogger(logname)
341 350 if any([isinstance(h, logging.StreamHandler) for h in logger.handlers]):
342 351 # don't add a second StreamHandler
343 352 return
344 353 handler = logging.StreamHandler()
345 354 handler.setLevel(loglevel)
346 355 formatter = logging.Formatter("%(asctime)s.%(msecs).03d [%(name)s] %(message)s",
347 356 datefmt="%Y-%m-%d %H:%M:%S")
348 357 handler.setFormatter(formatter)
349 358
350 359 logger.addHandler(handler)
351 360 logger.setLevel(loglevel)
352 361 return logger
353 362
354 363 def set_hwm(sock, hwm=0):
355 364 """set zmq High Water Mark on a socket
356 365
357 366 in a way that always works for various pyzmq / libzmq versions.
358 367 """
359 368 import zmq
360 369
361 370 for key in ('HWM', 'SNDHWM', 'RCVHWM'):
362 371 opt = getattr(zmq, key, None)
363 372 if opt is None:
364 373 continue
365 374 try:
366 375 sock.setsockopt(opt, hwm)
367 376 except zmq.ZMQError:
368 377 pass
369 378
370 379 No newline at end of file
General Comments 0
You need to be logged in to leave comments. Login now