##// END OF EJS Templates
still allow @interactive on Classes
MinRK -
Show More
@@ -1,379 +1,380 b''
1 1 """some generic utilities for dealing with classes, urls, and serialization
2 2
3 3 Authors:
4 4
5 5 * Min RK
6 6 """
7 7 #-----------------------------------------------------------------------------
8 8 # Copyright (C) 2010-2011 The IPython Development Team
9 9 #
10 10 # Distributed under the terms of the BSD License. The full license is in
11 11 # the file COPYING, distributed as part of this software.
12 12 #-----------------------------------------------------------------------------
13 13
14 14 #-----------------------------------------------------------------------------
15 15 # Imports
16 16 #-----------------------------------------------------------------------------
17 17
18 18 # Standard library imports.
19 19 import logging
20 20 import os
21 21 import re
22 22 import stat
23 23 import socket
24 24 import sys
25 25 from signal import signal, SIGINT, SIGABRT, SIGTERM
26 26 try:
27 27 from signal import SIGKILL
28 28 except ImportError:
29 29 SIGKILL=None
30 30 from types import FunctionType
31 31
32 32 try:
33 33 import cPickle
34 34 pickle = cPickle
35 35 except:
36 36 cPickle = None
37 37 import pickle
38 38
39 39 # System library imports
40 40 import zmq
41 41 from zmq.log import handlers
42 42
43 43 from IPython.external.decorator import decorator
44 44
45 45 # IPython imports
46 46 from IPython.config.application import Application
47 47 from IPython.utils.localinterfaces import localhost, is_public_ip, public_ips
48 48 from IPython.utils.py3compat import string_types, iteritems, itervalues
49 49 from IPython.kernel.zmq.log import EnginePUBHandler
50 50 from IPython.kernel.zmq.serialize import (
51 51 unserialize_object, serialize_object, pack_apply_message, unpack_apply_message
52 52 )
53 53
54 54 #-----------------------------------------------------------------------------
55 55 # Classes
56 56 #-----------------------------------------------------------------------------
57 57
58 58 class Namespace(dict):
59 59 """Subclass of dict for attribute access to keys."""
60 60
61 61 def __getattr__(self, key):
62 62 """getattr aliased to getitem"""
63 63 if key in self:
64 64 return self[key]
65 65 else:
66 66 raise NameError(key)
67 67
68 68 def __setattr__(self, key, value):
69 69 """setattr aliased to setitem, with strict"""
70 70 if hasattr(dict, key):
71 71 raise KeyError("Cannot override dict keys %r"%key)
72 72 self[key] = value
73 73
74 74
75 75 class ReverseDict(dict):
76 76 """simple double-keyed subset of dict methods."""
77 77
78 78 def __init__(self, *args, **kwargs):
79 79 dict.__init__(self, *args, **kwargs)
80 80 self._reverse = dict()
81 81 for key, value in iteritems(self):
82 82 self._reverse[value] = key
83 83
84 84 def __getitem__(self, key):
85 85 try:
86 86 return dict.__getitem__(self, key)
87 87 except KeyError:
88 88 return self._reverse[key]
89 89
90 90 def __setitem__(self, key, value):
91 91 if key in self._reverse:
92 92 raise KeyError("Can't have key %r on both sides!"%key)
93 93 dict.__setitem__(self, key, value)
94 94 self._reverse[value] = key
95 95
96 96 def pop(self, key):
97 97 value = dict.pop(self, key)
98 98 self._reverse.pop(value)
99 99 return value
100 100
101 101 def get(self, key, default=None):
102 102 try:
103 103 return self[key]
104 104 except KeyError:
105 105 return default
106 106
107 107 #-----------------------------------------------------------------------------
108 108 # Functions
109 109 #-----------------------------------------------------------------------------
110 110
111 111 @decorator
112 112 def log_errors(f, self, *args, **kwargs):
113 113 """decorator to log unhandled exceptions raised in a method.
114 114
115 115 For use wrapping on_recv callbacks, so that exceptions
116 116 do not cause the stream to be closed.
117 117 """
118 118 try:
119 119 return f(self, *args, **kwargs)
120 120 except Exception:
121 121 self.log.error("Uncaught exception in %r" % f, exc_info=True)
122 122
123 123
124 124 def is_url(url):
125 125 """boolean check for whether a string is a zmq url"""
126 126 if '://' not in url:
127 127 return False
128 128 proto, addr = url.split('://', 1)
129 129 if proto.lower() not in ['tcp','pgm','epgm','ipc','inproc']:
130 130 return False
131 131 return True
132 132
133 133 def validate_url(url):
134 134 """validate a url for zeromq"""
135 135 if not isinstance(url, string_types):
136 136 raise TypeError("url must be a string, not %r"%type(url))
137 137 url = url.lower()
138 138
139 139 proto_addr = url.split('://')
140 140 assert len(proto_addr) == 2, 'Invalid url: %r'%url
141 141 proto, addr = proto_addr
142 142 assert proto in ['tcp','pgm','epgm','ipc','inproc'], "Invalid protocol: %r"%proto
143 143
144 144 # domain pattern adapted from http://www.regexlib.com/REDetails.aspx?regexp_id=391
145 145 # author: Remi Sabourin
146 146 pat = re.compile(r'^([\w\d]([\w\d\-]{0,61}[\w\d])?\.)*[\w\d]([\w\d\-]{0,61}[\w\d])?$')
147 147
148 148 if proto == 'tcp':
149 149 lis = addr.split(':')
150 150 assert len(lis) == 2, 'Invalid url: %r'%url
151 151 addr,s_port = lis
152 152 try:
153 153 port = int(s_port)
154 154 except ValueError:
155 155 raise AssertionError("Invalid port %r in url: %r"%(port, url))
156 156
157 157 assert addr == '*' or pat.match(addr) is not None, 'Invalid url: %r'%url
158 158
159 159 else:
160 160 # only validate tcp urls currently
161 161 pass
162 162
163 163 return True
164 164
165 165
166 166 def validate_url_container(container):
167 167 """validate a potentially nested collection of urls."""
168 168 if isinstance(container, string_types):
169 169 url = container
170 170 return validate_url(url)
171 171 elif isinstance(container, dict):
172 172 container = itervalues(container)
173 173
174 174 for element in container:
175 175 validate_url_container(element)
176 176
177 177
178 178 def split_url(url):
179 179 """split a zmq url (tcp://ip:port) into ('tcp','ip','port')."""
180 180 proto_addr = url.split('://')
181 181 assert len(proto_addr) == 2, 'Invalid url: %r'%url
182 182 proto, addr = proto_addr
183 183 lis = addr.split(':')
184 184 assert len(lis) == 2, 'Invalid url: %r'%url
185 185 addr,s_port = lis
186 186 return proto,addr,s_port
187 187
188 188 def disambiguate_ip_address(ip, location=None):
189 189 """turn multi-ip interfaces '0.0.0.0' and '*' into connectable
190 190 ones, based on the location (default interpretation of location is localhost)."""
191 191 if ip in ('0.0.0.0', '*'):
192 192 if location is None or is_public_ip(location) or not public_ips():
193 193 # If location is unspecified or cannot be determined, assume local
194 194 ip = localhost()
195 195 elif location:
196 196 return location
197 197 return ip
198 198
199 199 def disambiguate_url(url, location=None):
200 200 """turn multi-ip interfaces '0.0.0.0' and '*' into connectable
201 201 ones, based on the location (default interpretation is localhost).
202 202
203 203 This is for zeromq urls, such as ``tcp://*:10101``.
204 204 """
205 205 try:
206 206 proto,ip,port = split_url(url)
207 207 except AssertionError:
208 208 # probably not tcp url; could be ipc, etc.
209 209 return url
210 210
211 211 ip = disambiguate_ip_address(ip,location)
212 212
213 213 return "%s://%s:%s"%(proto,ip,port)
214 214
215 215
216 216 #--------------------------------------------------------------------------
217 217 # helpers for implementing old MEC API via view.apply
218 218 #--------------------------------------------------------------------------
219 219
220 220 def interactive(f):
221 221 """decorator for making functions appear as interactively defined.
222 222 This results in the function being linked to the user_ns as globals()
223 223 instead of the module globals().
224 224 """
225 mainmod = __import__('__main__')
226 225
227 226 # build new FunctionType, so it can have the right globals
228 227 # interactive functions never have closures, that's kind of the point
229 f2 = FunctionType(f.__code__, mainmod.__dict__,
230 f.__name__, f.__defaults__,
231 )
228 if isinstance(f, FunctionType):
229 mainmod = __import__('__main__')
230 f = FunctionType(f.__code__, mainmod.__dict__,
231 f.__name__, f.__defaults__,
232 )
232 233 # associate with __main__ for uncanning
233 f2.__module__ = '__main__'
234 return f2
234 f.__module__ = '__main__'
235 return f
235 236
236 237 @interactive
237 238 def _push(**ns):
238 239 """helper method for implementing `client.push` via `client.apply`"""
239 240 user_ns = globals()
240 241 tmp = '_IP_PUSH_TMP_'
241 242 while tmp in user_ns:
242 243 tmp = tmp + '_'
243 244 try:
244 245 for name, value in ns.items():
245 246 user_ns[tmp] = value
246 247 exec("%s = %s" % (name, tmp), user_ns)
247 248 finally:
248 249 user_ns.pop(tmp, None)
249 250
250 251 @interactive
251 252 def _pull(keys):
252 253 """helper method for implementing `client.pull` via `client.apply`"""
253 254 if isinstance(keys, (list,tuple, set)):
254 255 return [eval(key, globals()) for key in keys]
255 256 else:
256 257 return eval(keys, globals())
257 258
258 259 @interactive
259 260 def _execute(code):
260 261 """helper method for implementing `client.execute` via `client.apply`"""
261 262 exec(code, globals())
262 263
263 264 #--------------------------------------------------------------------------
264 265 # extra process management utilities
265 266 #--------------------------------------------------------------------------
266 267
267 268 _random_ports = set()
268 269
269 270 def select_random_ports(n):
270 271 """Selects and return n random ports that are available."""
271 272 ports = []
272 273 for i in range(n):
273 274 sock = socket.socket()
274 275 sock.bind(('', 0))
275 276 while sock.getsockname()[1] in _random_ports:
276 277 sock.close()
277 278 sock = socket.socket()
278 279 sock.bind(('', 0))
279 280 ports.append(sock)
280 281 for i, sock in enumerate(ports):
281 282 port = sock.getsockname()[1]
282 283 sock.close()
283 284 ports[i] = port
284 285 _random_ports.add(port)
285 286 return ports
286 287
287 288 def signal_children(children):
288 289 """Relay interupt/term signals to children, for more solid process cleanup."""
289 290 def terminate_children(sig, frame):
290 291 log = Application.instance().log
291 292 log.critical("Got signal %i, terminating children..."%sig)
292 293 for child in children:
293 294 child.terminate()
294 295
295 296 sys.exit(sig != SIGINT)
296 297 # sys.exit(sig)
297 298 for sig in (SIGINT, SIGABRT, SIGTERM):
298 299 signal(sig, terminate_children)
299 300
300 301 def generate_exec_key(keyfile):
301 302 import uuid
302 303 newkey = str(uuid.uuid4())
303 304 with open(keyfile, 'w') as f:
304 305 # f.write('ipython-key ')
305 306 f.write(newkey+'\n')
306 307 # set user-only RW permissions (0600)
307 308 # this will have no effect on Windows
308 309 os.chmod(keyfile, stat.S_IRUSR|stat.S_IWUSR)
309 310
310 311
311 312 def integer_loglevel(loglevel):
312 313 try:
313 314 loglevel = int(loglevel)
314 315 except ValueError:
315 316 if isinstance(loglevel, str):
316 317 loglevel = getattr(logging, loglevel)
317 318 return loglevel
318 319
319 320 def connect_logger(logname, context, iface, root="ip", loglevel=logging.DEBUG):
320 321 logger = logging.getLogger(logname)
321 322 if any([isinstance(h, handlers.PUBHandler) for h in logger.handlers]):
322 323 # don't add a second PUBHandler
323 324 return
324 325 loglevel = integer_loglevel(loglevel)
325 326 lsock = context.socket(zmq.PUB)
326 327 lsock.connect(iface)
327 328 handler = handlers.PUBHandler(lsock)
328 329 handler.setLevel(loglevel)
329 330 handler.root_topic = root
330 331 logger.addHandler(handler)
331 332 logger.setLevel(loglevel)
332 333
333 334 def connect_engine_logger(context, iface, engine, loglevel=logging.DEBUG):
334 335 logger = logging.getLogger()
335 336 if any([isinstance(h, handlers.PUBHandler) for h in logger.handlers]):
336 337 # don't add a second PUBHandler
337 338 return
338 339 loglevel = integer_loglevel(loglevel)
339 340 lsock = context.socket(zmq.PUB)
340 341 lsock.connect(iface)
341 342 handler = EnginePUBHandler(engine, lsock)
342 343 handler.setLevel(loglevel)
343 344 logger.addHandler(handler)
344 345 logger.setLevel(loglevel)
345 346 return logger
346 347
347 348 def local_logger(logname, loglevel=logging.DEBUG):
348 349 loglevel = integer_loglevel(loglevel)
349 350 logger = logging.getLogger(logname)
350 351 if any([isinstance(h, logging.StreamHandler) for h in logger.handlers]):
351 352 # don't add a second StreamHandler
352 353 return
353 354 handler = logging.StreamHandler()
354 355 handler.setLevel(loglevel)
355 356 formatter = logging.Formatter("%(asctime)s.%(msecs).03d [%(name)s] %(message)s",
356 357 datefmt="%Y-%m-%d %H:%M:%S")
357 358 handler.setFormatter(formatter)
358 359
359 360 logger.addHandler(handler)
360 361 logger.setLevel(loglevel)
361 362 return logger
362 363
363 364 def set_hwm(sock, hwm=0):
364 365 """set zmq High Water Mark on a socket
365 366
366 367 in a way that always works for various pyzmq / libzmq versions.
367 368 """
368 369 import zmq
369 370
370 371 for key in ('HWM', 'SNDHWM', 'RCVHWM'):
371 372 opt = getattr(zmq, key, None)
372 373 if opt is None:
373 374 continue
374 375 try:
375 376 sock.setsockopt(opt, hwm)
376 377 except zmq.ZMQError:
377 378 pass
378 379
379 380 No newline at end of file
General Comments 0
You need to be logged in to leave comments. Login now