##// END OF EJS Templates
Fix non-ascii spaces in comment....
Thomas Kluyver -
Show More
@@ -1,390 +1,390 b''
1 1 """Some generic utilities for dealing with classes, urls, and serialization."""
2 2
3 #Copyright (c) IPython Development Team.
4 #Distributed under the terms of the Modified BSD License.
3 # Copyright (c) IPython Development Team.
4 # Distributed under the terms of the Modified BSD License.
5 5
6 6 import logging
7 7 import os
8 8 import re
9 9 import stat
10 10 import socket
11 11 import sys
12 12 import warnings
13 13 from signal import signal, SIGINT, SIGABRT, SIGTERM
14 14 try:
15 15 from signal import SIGKILL
16 16 except ImportError:
17 17 SIGKILL=None
18 18 from types import FunctionType
19 19
20 20 try:
21 21 import cPickle
22 22 pickle = cPickle
23 23 except:
24 24 cPickle = None
25 25 import pickle
26 26
27 27 import zmq
28 28 from zmq.log import handlers
29 29
30 30 from IPython.external.decorator import decorator
31 31
32 32 from IPython.config.application import Application
33 33 from IPython.utils.localinterfaces import localhost, is_public_ip, public_ips
34 34 from IPython.utils.py3compat import string_types, iteritems, itervalues
35 35 from IPython.kernel.zmq.log import EnginePUBHandler
36 36 from IPython.kernel.zmq.serialize import (
37 37 unserialize_object, serialize_object, pack_apply_message, unpack_apply_message
38 38 )
39 39
40 40 #-----------------------------------------------------------------------------
41 41 # Classes
42 42 #-----------------------------------------------------------------------------
43 43
44 44 class Namespace(dict):
45 45 """Subclass of dict for attribute access to keys."""
46 46
47 47 def __getattr__(self, key):
48 48 """getattr aliased to getitem"""
49 49 if key in self:
50 50 return self[key]
51 51 else:
52 52 raise NameError(key)
53 53
54 54 def __setattr__(self, key, value):
55 55 """setattr aliased to setitem, with strict"""
56 56 if hasattr(dict, key):
57 57 raise KeyError("Cannot override dict keys %r"%key)
58 58 self[key] = value
59 59
60 60
61 61 class ReverseDict(dict):
62 62 """simple double-keyed subset of dict methods."""
63 63
64 64 def __init__(self, *args, **kwargs):
65 65 dict.__init__(self, *args, **kwargs)
66 66 self._reverse = dict()
67 67 for key, value in iteritems(self):
68 68 self._reverse[value] = key
69 69
70 70 def __getitem__(self, key):
71 71 try:
72 72 return dict.__getitem__(self, key)
73 73 except KeyError:
74 74 return self._reverse[key]
75 75
76 76 def __setitem__(self, key, value):
77 77 if key in self._reverse:
78 78 raise KeyError("Can't have key %r on both sides!"%key)
79 79 dict.__setitem__(self, key, value)
80 80 self._reverse[value] = key
81 81
82 82 def pop(self, key):
83 83 value = dict.pop(self, key)
84 84 self._reverse.pop(value)
85 85 return value
86 86
87 87 def get(self, key, default=None):
88 88 try:
89 89 return self[key]
90 90 except KeyError:
91 91 return default
92 92
93 93 #-----------------------------------------------------------------------------
94 94 # Functions
95 95 #-----------------------------------------------------------------------------
96 96
97 97 @decorator
98 98 def log_errors(f, self, *args, **kwargs):
99 99 """decorator to log unhandled exceptions raised in a method.
100 100
101 101 For use wrapping on_recv callbacks, so that exceptions
102 102 do not cause the stream to be closed.
103 103 """
104 104 try:
105 105 return f(self, *args, **kwargs)
106 106 except Exception:
107 107 self.log.error("Uncaught exception in %r" % f, exc_info=True)
108 108
109 109
110 110 def is_url(url):
111 111 """boolean check for whether a string is a zmq url"""
112 112 if '://' not in url:
113 113 return False
114 114 proto, addr = url.split('://', 1)
115 115 if proto.lower() not in ['tcp','pgm','epgm','ipc','inproc']:
116 116 return False
117 117 return True
118 118
119 119 def validate_url(url):
120 120 """validate a url for zeromq"""
121 121 if not isinstance(url, string_types):
122 122 raise TypeError("url must be a string, not %r"%type(url))
123 123 url = url.lower()
124 124
125 125 proto_addr = url.split('://')
126 126 assert len(proto_addr) == 2, 'Invalid url: %r'%url
127 127 proto, addr = proto_addr
128 128 assert proto in ['tcp','pgm','epgm','ipc','inproc'], "Invalid protocol: %r"%proto
129 129
130 130 # domain pattern adapted from http://www.regexlib.com/REDetails.aspx?regexp_id=391
131 131 # author: Remi Sabourin
132 132 pat = re.compile(r'^([\w\d]([\w\d\-]{0,61}[\w\d])?\.)*[\w\d]([\w\d\-]{0,61}[\w\d])?$')
133 133
134 134 if proto == 'tcp':
135 135 lis = addr.split(':')
136 136 assert len(lis) == 2, 'Invalid url: %r'%url
137 137 addr,s_port = lis
138 138 try:
139 139 port = int(s_port)
140 140 except ValueError:
141 141 raise AssertionError("Invalid port %r in url: %r"%(port, url))
142 142
143 143 assert addr == '*' or pat.match(addr) is not None, 'Invalid url: %r'%url
144 144
145 145 else:
146 146 # only validate tcp urls currently
147 147 pass
148 148
149 149 return True
150 150
151 151
152 152 def validate_url_container(container):
153 153 """validate a potentially nested collection of urls."""
154 154 if isinstance(container, string_types):
155 155 url = container
156 156 return validate_url(url)
157 157 elif isinstance(container, dict):
158 158 container = itervalues(container)
159 159
160 160 for element in container:
161 161 validate_url_container(element)
162 162
163 163
164 164 def split_url(url):
165 165 """split a zmq url (tcp://ip:port) into ('tcp','ip','port')."""
166 166 proto_addr = url.split('://')
167 167 assert len(proto_addr) == 2, 'Invalid url: %r'%url
168 168 proto, addr = proto_addr
169 169 lis = addr.split(':')
170 170 assert len(lis) == 2, 'Invalid url: %r'%url
171 171 addr,s_port = lis
172 172 return proto,addr,s_port
173 173
174 174
175 175 def disambiguate_ip_address(ip, location=None):
176 176 """turn multi-ip interfaces '0.0.0.0' and '*' into a connectable address
177 177
178 178 Explicit IP addresses are returned unmodified.
179 179
180 180 Parameters
181 181 ----------
182 182
183 183 ip : IP address
184 184 An IP address, or the special values 0.0.0.0, or *
185 185 location: IP address, optional
186 186 A public IP of the target machine.
187 187 If location is an IP of the current machine,
188 188 localhost will be returned,
189 189 otherwise location will be returned.
190 190 """
191 191 if ip in {'0.0.0.0', '*'}:
192 192 if not location:
193 193 # unspecified location, localhost is the only choice
194 194 ip = localhost()
195 195 elif is_public_ip(location):
196 196 # location is a public IP on this machine, use localhost
197 197 ip = localhost()
198 198 elif not public_ips():
199 199 # this machine's public IPs cannot be determined,
200 200 # assume `location` is not this machine
201 201 warnings.warn("IPython could not determine public IPs", RuntimeWarning)
202 202 ip = location
203 203 else:
204 204 # location is not this machine, do not use loopback
205 205 ip = location
206 206 return ip
207 207
208 208
209 209 def disambiguate_url(url, location=None):
210 210 """turn multi-ip interfaces '0.0.0.0' and '*' into connectable
211 211 ones, based on the location (default interpretation is localhost).
212 212
213 213 This is for zeromq urls, such as ``tcp://*:10101``.
214 214 """
215 215 try:
216 216 proto,ip,port = split_url(url)
217 217 except AssertionError:
218 218 # probably not tcp url; could be ipc, etc.
219 219 return url
220 220
221 221 ip = disambiguate_ip_address(ip,location)
222 222
223 223 return "%s://%s:%s"%(proto,ip,port)
224 224
225 225
226 226 #--------------------------------------------------------------------------
227 227 # helpers for implementing old MEC API via view.apply
228 228 #--------------------------------------------------------------------------
229 229
230 230 def interactive(f):
231 231 """decorator for making functions appear as interactively defined.
232 232 This results in the function being linked to the user_ns as globals()
233 233 instead of the module globals().
234 234 """
235 235
236 236 # build new FunctionType, so it can have the right globals
237 237 # interactive functions never have closures, that's kind of the point
238 238 if isinstance(f, FunctionType):
239 239 mainmod = __import__('__main__')
240 240 f = FunctionType(f.__code__, mainmod.__dict__,
241 241 f.__name__, f.__defaults__,
242 242 )
243 243 # associate with __main__ for uncanning
244 244 f.__module__ = '__main__'
245 245 return f
246 246
247 247 @interactive
248 248 def _push(**ns):
249 249 """helper method for implementing `client.push` via `client.apply`"""
250 250 user_ns = globals()
251 251 tmp = '_IP_PUSH_TMP_'
252 252 while tmp in user_ns:
253 253 tmp = tmp + '_'
254 254 try:
255 255 for name, value in ns.items():
256 256 user_ns[tmp] = value
257 257 exec("%s = %s" % (name, tmp), user_ns)
258 258 finally:
259 259 user_ns.pop(tmp, None)
260 260
261 261 @interactive
262 262 def _pull(keys):
263 263 """helper method for implementing `client.pull` via `client.apply`"""
264 264 if isinstance(keys, (list,tuple, set)):
265 265 return [eval(key, globals()) for key in keys]
266 266 else:
267 267 return eval(keys, globals())
268 268
269 269 @interactive
270 270 def _execute(code):
271 271 """helper method for implementing `client.execute` via `client.apply`"""
272 272 exec(code, globals())
273 273
274 274 #--------------------------------------------------------------------------
275 275 # extra process management utilities
276 276 #--------------------------------------------------------------------------
277 277
278 278 _random_ports = set()
279 279
280 280 def select_random_ports(n):
281 281 """Selects and return n random ports that are available."""
282 282 ports = []
283 283 for i in range(n):
284 284 sock = socket.socket()
285 285 sock.bind(('', 0))
286 286 while sock.getsockname()[1] in _random_ports:
287 287 sock.close()
288 288 sock = socket.socket()
289 289 sock.bind(('', 0))
290 290 ports.append(sock)
291 291 for i, sock in enumerate(ports):
292 292 port = sock.getsockname()[1]
293 293 sock.close()
294 294 ports[i] = port
295 295 _random_ports.add(port)
296 296 return ports
297 297
298 298 def signal_children(children):
299 299 """Relay interupt/term signals to children, for more solid process cleanup."""
300 300 def terminate_children(sig, frame):
301 301 log = Application.instance().log
302 302 log.critical("Got signal %i, terminating children..."%sig)
303 303 for child in children:
304 304 child.terminate()
305 305
306 306 sys.exit(sig != SIGINT)
307 307 # sys.exit(sig)
308 308 for sig in (SIGINT, SIGABRT, SIGTERM):
309 309 signal(sig, terminate_children)
310 310
311 311 def generate_exec_key(keyfile):
312 312 import uuid
313 313 newkey = str(uuid.uuid4())
314 314 with open(keyfile, 'w') as f:
315 315 # f.write('ipython-key ')
316 316 f.write(newkey+'\n')
317 317 # set user-only RW permissions (0600)
318 318 # this will have no effect on Windows
319 319 os.chmod(keyfile, stat.S_IRUSR|stat.S_IWUSR)
320 320
321 321
322 322 def integer_loglevel(loglevel):
323 323 try:
324 324 loglevel = int(loglevel)
325 325 except ValueError:
326 326 if isinstance(loglevel, str):
327 327 loglevel = getattr(logging, loglevel)
328 328 return loglevel
329 329
330 330 def connect_logger(logname, context, iface, root="ip", loglevel=logging.DEBUG):
331 331 logger = logging.getLogger(logname)
332 332 if any([isinstance(h, handlers.PUBHandler) for h in logger.handlers]):
333 333 # don't add a second PUBHandler
334 334 return
335 335 loglevel = integer_loglevel(loglevel)
336 336 lsock = context.socket(zmq.PUB)
337 337 lsock.connect(iface)
338 338 handler = handlers.PUBHandler(lsock)
339 339 handler.setLevel(loglevel)
340 340 handler.root_topic = root
341 341 logger.addHandler(handler)
342 342 logger.setLevel(loglevel)
343 343
344 344 def connect_engine_logger(context, iface, engine, loglevel=logging.DEBUG):
345 345 logger = logging.getLogger()
346 346 if any([isinstance(h, handlers.PUBHandler) for h in logger.handlers]):
347 347 # don't add a second PUBHandler
348 348 return
349 349 loglevel = integer_loglevel(loglevel)
350 350 lsock = context.socket(zmq.PUB)
351 351 lsock.connect(iface)
352 352 handler = EnginePUBHandler(engine, lsock)
353 353 handler.setLevel(loglevel)
354 354 logger.addHandler(handler)
355 355 logger.setLevel(loglevel)
356 356 return logger
357 357
358 358 def local_logger(logname, loglevel=logging.DEBUG):
359 359 loglevel = integer_loglevel(loglevel)
360 360 logger = logging.getLogger(logname)
361 361 if any([isinstance(h, logging.StreamHandler) for h in logger.handlers]):
362 362 # don't add a second StreamHandler
363 363 return
364 364 handler = logging.StreamHandler()
365 365 handler.setLevel(loglevel)
366 366 formatter = logging.Formatter("%(asctime)s.%(msecs).03d [%(name)s] %(message)s",
367 367 datefmt="%Y-%m-%d %H:%M:%S")
368 368 handler.setFormatter(formatter)
369 369
370 370 logger.addHandler(handler)
371 371 logger.setLevel(loglevel)
372 372 return logger
373 373
374 374 def set_hwm(sock, hwm=0):
375 375 """set zmq High Water Mark on a socket
376 376
377 377 in a way that always works for various pyzmq / libzmq versions.
378 378 """
379 379 import zmq
380 380
381 381 for key in ('HWM', 'SNDHWM', 'RCVHWM'):
382 382 opt = getattr(zmq, key, None)
383 383 if opt is None:
384 384 continue
385 385 try:
386 386 sock.setsockopt(opt, hwm)
387 387 except zmq.ZMQError:
388 388 pass
389 389
390 390 No newline at end of file
General Comments 0
You need to be logged in to leave comments. Login now