##// END OF EJS Templates
match log format in Scheduler to rest of parallel apps
MinRK -
Show More
@@ -1,472 +1,476 b''
1 1 """some generic utilities for dealing with classes, urls, and serialization
2 2
3 3 Authors:
4 4
5 5 * Min RK
6 6 """
7 7 #-----------------------------------------------------------------------------
8 8 # Copyright (C) 2010-2011 The IPython Development Team
9 9 #
10 10 # Distributed under the terms of the BSD License. The full license is in
11 11 # the file COPYING, distributed as part of this software.
12 12 #-----------------------------------------------------------------------------
13 13
14 14 #-----------------------------------------------------------------------------
15 15 # Imports
16 16 #-----------------------------------------------------------------------------
17 17
18 18 # Standard library imports.
19 19 import logging
20 20 import os
21 21 import re
22 22 import stat
23 23 import socket
24 24 import sys
25 25 from signal import signal, SIGINT, SIGABRT, SIGTERM
26 26 try:
27 27 from signal import SIGKILL
28 28 except ImportError:
29 29 SIGKILL=None
30 30
31 31 try:
32 32 import cPickle
33 33 pickle = cPickle
34 34 except:
35 35 cPickle = None
36 36 import pickle
37 37
38 38 # System library imports
39 39 import zmq
40 40 from zmq.log import handlers
41 41
42 42 # IPython imports
43 43 from IPython.config.application import Application
44 44 from IPython.utils.pickleutil import can, uncan, canSequence, uncanSequence
45 45 from IPython.utils.newserialized import serialize, unserialize
46 46 from IPython.zmq.log import EnginePUBHandler
47 47
48 48 #-----------------------------------------------------------------------------
49 49 # Classes
50 50 #-----------------------------------------------------------------------------
51 51
class Namespace(dict):
    """Subclass of dict for attribute access to keys.

    Attribute reads/writes are aliased to item access, so ``ns.a`` is
    ``ns['a']``.  Names that collide with dict attributes (``keys``,
    ``items``, ...) are rejected on assignment.
    """

    def __getattr__(self, key):
        """getattr aliased to getitem"""
        # membership test directly on the dict is O(1); iterating keys was O(n)
        if key in self:
            return self[key]
        else:
            # NameError (not AttributeError) is this class's documented contract
            raise NameError(key)

    def __setattr__(self, key, value):
        """setattr aliased to setitem, with strict"""
        if hasattr(dict, key):
            # refuse to shadow dict's own API (keys, items, update, ...)
            raise KeyError("Cannot override dict keys %r"%key)
        self[key] = value
67 67
68 68
class ReverseDict(dict):
    """simple double-keyed subset of dict methods.

    A failed forward lookup falls back to the reverse (value -> key)
    side, so every value must itself be hashable.
    """

    def __init__(self, *args, **kwargs):
        dict.__init__(self, *args, **kwargs)
        # mirror mapping: value -> key, built from the initial contents
        self._reverse = dict([(value, key) for key, value in self.iteritems()])

    def __getitem__(self, key):
        try:
            return dict.__getitem__(self, key)
        except KeyError:
            # not a forward key; try the reverse side
            return self._reverse[key]

    def __setitem__(self, key, value):
        if key in self._reverse:
            # a name may live on only one side of the mapping
            raise KeyError("Can't have key %r on both sides!"%key)
        dict.__setitem__(self, key, value)
        self._reverse[value] = key

    def pop(self, key):
        # only forward keys can be popped; drop the mirror entry as well
        value = dict.pop(self, key)
        del self._reverse[value]
        return value

    def get(self, key, default=None):
        try:
            return self[key]
        except KeyError:
            return default
100 100
101 101 #-----------------------------------------------------------------------------
102 102 # Functions
103 103 #-----------------------------------------------------------------------------
104 104
def asbytes(s):
    """ensure that an object is ascii bytes"""
    # unicode gets encoded; anything else is passed through untouched
    if isinstance(s, unicode):
        return s.encode('ascii')
    return s
110 110
def is_url(url):
    """boolean check for whether a string is a zmq url"""
    if '://' not in url:
        return False
    # only the protocol part matters for this check
    proto = url.split('://', 1)[0]
    return proto.lower() in ('tcp', 'pgm', 'epgm', 'ipc', 'inproc')
119 119
def validate_url(url):
    """validate a url for zeromq

    Raises TypeError for non-strings and AssertionError for malformed
    urls; returns True otherwise.  Only tcp urls are fully validated.
    NOTE: assert-based validation is stripped under ``python -O``; kept
    because callers catch AssertionError.
    """
    if not isinstance(url, basestring):
        raise TypeError("url must be a string, not %r"%type(url))
    url = url.lower()

    proto_addr = url.split('://')
    assert len(proto_addr) == 2, 'Invalid url: %r'%url
    proto, addr = proto_addr
    assert proto in ['tcp','pgm','epgm','ipc','inproc'], "Invalid protocol: %r"%proto

    # domain pattern adapted from http://www.regexlib.com/REDetails.aspx?regexp_id=391
    # author: Remi Sabourin
    pat = re.compile(r'^([\w\d]([\w\d\-]{0,61}[\w\d])?\.)*[\w\d]([\w\d\-]{0,61}[\w\d])?$')

    if proto == 'tcp':
        lis = addr.split(':')
        assert len(lis) == 2, 'Invalid url: %r'%url
        addr,s_port = lis
        try:
            port = int(s_port)
        except ValueError:
            # BUG FIX: previously formatted `port`, which is unbound in this
            # branch (int() failed before assigning it) -> raised NameError
            # instead of the intended AssertionError.  Use s_port.
            raise AssertionError("Invalid port %r in url: %r"%(s_port, url))

        assert addr == '*' or pat.match(addr) is not None, 'Invalid url: %r'%url

    else:
        # only validate tcp urls currently
        pass

    return True
151 151
152 152
def validate_url_container(container):
    """validate a potentially nested collection of urls."""
    # a bare string is a single url
    if isinstance(container, basestring):
        return validate_url(container)
    if isinstance(container, dict):
        # only the values of a dict are urls
        container = container.itervalues()
    for element in container:
        validate_url_container(element)
163 163
164 164
def split_url(url):
    """split a zmq url (tcp://ip:port) into ('tcp','ip','port')."""
    parts = url.split('://')
    assert len(parts) == 2, 'Invalid url: %r'%url
    proto, addr = parts
    pieces = addr.split(':')
    assert len(pieces) == 2, 'Invalid url: %r'%url
    # port is returned as a string, not converted to int
    return proto, pieces[0], pieces[1]
174 174
def disambiguate_ip_address(ip, location=None):
    """turn multi-ip interfaces '0.0.0.0' and '*' into connectable
    ones, based on the location (default interpretation of location is localhost)."""
    # specific addresses pass through untouched
    if ip not in ('0.0.0.0', '*'):
        return ip
    try:
        external_ips = socket.gethostbyname_ex(socket.gethostname())[2]
    except (socket.gaierror, IndexError):
        # couldn't identify this machine, assume localhost
        external_ips = []
    if location is None or location in external_ips or not external_ips:
        # If location is unspecified or cannot be determined, assume local
        return '127.0.0.1'
    elif location:
        return location
    return ip
190 190
def disambiguate_url(url, location=None):
    """turn multi-ip interfaces '0.0.0.0' and '*' into connectable
    ones, based on the location (default interpretation is localhost).

    This is for zeromq urls, such as tcp://*:10101."""
    try:
        proto, ip, port = split_url(url)
    except AssertionError:
        # probably not tcp url; could be ipc, etc. -- leave it alone
        return url
    return "%s://%s:%s"%(proto, disambiguate_ip_address(ip, location), port)
205 205
def serialize_object(obj, threshold=64e-6):
    """Serialize an object into a list of sendable buffers.

    Parameters
    ----------

    obj : object
        The object to be serialized
    threshold : float
        The threshold for not double-pickling the content.


    Returns
    -------
    ('pmd', [bufs]) :
        where pmd is the pickled metadata wrapper,
        bufs is a list of data buffers
    """
    databuffers = []

    def _extract(s):
        # strip out buffers/arrays and anything over the threshold so the
        # raw data is sent separately instead of double-pickled
        if s.typeDescriptor in ('buffer', 'ndarray') or s.getDataSize() > threshold:
            databuffers.append(s.getData())
            s.data = None
        return s

    if isinstance(obj, (list, tuple)):
        slist = [_extract(s) for s in map(serialize, canSequence(obj))]
        return pickle.dumps(slist,-1), databuffers
    elif isinstance(obj, dict):
        sobj = {}
        # sorted keys give a deterministic buffer order for unserialize_object
        for k in sorted(obj.iterkeys()):
            sobj[k] = _extract(serialize(can(obj[k])))
        return pickle.dumps(sobj,-1), databuffers
    else:
        return pickle.dumps(_extract(serialize(can(obj))),-1), databuffers
248 248
249 249
def unserialize_object(bufs):
    """reconstruct an object serialized by serialize_object from data buffers."""
    bufs = list(bufs)
    sobj = pickle.loads(bufs.pop(0))
    if isinstance(sobj, (list, tuple)):
        # re-attach stripped data buffers in order, then uncan
        for s in sobj:
            if s.data is None:
                s.data = bufs.pop(0)
        return uncanSequence(map(unserialize, sobj)), bufs
    if isinstance(sobj, dict):
        newobj = {}
        # same sorted-key order that serialize_object used
        for k in sorted(sobj.iterkeys()):
            s = sobj[k]
            if s.data is None:
                s.data = bufs.pop(0)
            newobj[k] = uncan(unserialize(s))
        return newobj, bufs
    # single object
    if sobj.data is None:
        sobj.data = bufs.pop(0)
    return uncan(unserialize(sobj)), bufs
271 271
def pack_apply_message(f, args, kwargs, threshold=64e-6):
    """pack up a function, args, and kwargs to be sent over the wire
    as a series of buffers. Any object whose data is larger than `threshold`
    will not have their data copied (currently only numpy arrays support zero-copy)"""
    msg = [pickle.dumps(can(f),-1)]
    databuffers = [] # for large objects
    # args first, then kwargs; data buffers trail the metadata in the same order
    for obj in (args, kwargs):
        serialized, bufs = serialize_object(obj, threshold)
        msg.append(serialized)
        databuffers.extend(bufs)
    msg.extend(databuffers)
    return msg
286 286
def _restore_buffer(sa, bufs, copy):
    """helper for unpack_apply_message: re-attach the next raw buffer to
    serialized object `sa` if its data was stripped by serialize_object."""
    if sa.data is None:
        m = bufs.pop(0)
        if sa.getTypeDescriptor() in ('buffer', 'ndarray'):
            # always use a buffer, until memoryviews get sorted out
            sa.data = buffer(m)
        else:
            if copy:
                sa.data = m
            else:
                # zmq Message frame; take its bytes
                sa.data = m.bytes


def unpack_apply_message(bufs, g=None, copy=True):
    """unpack f,args,kwargs from buffers packed by pack_apply_message()
    Returns: original f,args,kwargs"""
    bufs = list(bufs) # allow us to pop
    # at minimum: pickled f, pickled args metadata, pickled kwargs metadata
    assert len(bufs) >= 3, "not enough buffers!"
    if not copy:
        for i in range(3):
            bufs[i] = bufs[i].bytes
    cf = pickle.loads(bufs.pop(0))
    sargs = list(pickle.loads(bufs.pop(0)))
    skwargs = dict(pickle.loads(bufs.pop(0)))
    f = uncan(cf, g)
    # args and kwargs previously duplicated this buffer-restoring stanza
    # inline; factored into _restore_buffer
    for sa in sargs:
        _restore_buffer(sa, bufs, copy)

    args = uncanSequence(map(unserialize, sargs), g)
    kwargs = {}
    # sorted keys match the buffer order produced by serialize_object
    for k in sorted(skwargs.iterkeys()):
        sa = skwargs[k]
        _restore_buffer(sa, bufs, copy)
        kwargs[k] = uncan(unserialize(sa), g)

    return f,args,kwargs
340 340
341 341 #--------------------------------------------------------------------------
342 342 # helpers for implementing old MEC API via view.apply
343 343 #--------------------------------------------------------------------------
344 344
def interactive(f):
    """decorator for making functions appear as interactively defined.

    Re-homes `f` to ``__main__`` so that when it is shipped to an engine,
    its globals resolve against the user namespace rather than this
    module's globals.
    """
    f.__module__ = '__main__'
    return f
352 352
@interactive
def _push(ns):
    """helper method for implementing `client.push` via `client.apply`"""
    # merge the pushed mapping into globals(); per @interactive's contract,
    # on an engine that is the user namespace rather than this module
    globals().update(ns)
357 357
@interactive
def _pull(keys):
    """helper method for implementing `client.pull` via `client.apply`

    `keys` may be a single name or a list/tuple/set of names; raises
    NameError for any name missing from the (user) namespace.
    """
    user_ns = globals()
    if isinstance(keys, (list,tuple, set)):
        for key in keys:
            # `in` instead of deprecated dict.has_key (same behavior)
            if key not in user_ns:
                raise NameError("name '%s' is not defined"%key)
        return map(user_ns.get, keys)
    else:
        if keys not in user_ns:
            raise NameError("name '%s' is not defined"%keys)
        return user_ns.get(keys)
371 371
@interactive
def _execute(code):
    """helper method for implementing `client.execute` via `client.apply`"""
    # function-call form: on Python 2 the (code, globals) tuple is the
    # special-cased exec-statement form, on Python 3 it is the builtin --
    # identical behavior, portable syntax
    exec(code, globals())
376 376
377 377 #--------------------------------------------------------------------------
378 378 # extra process management utilities
379 379 #--------------------------------------------------------------------------
380 380
# ports already handed out by select_random_ports in this process,
# so repeated calls never return the same port twice
_random_ports = set()

def select_random_ports(n):
    """Selects and return n random ports that are available.

    Each port is found by binding to port 0, recorded, then closed --
    so there is a small race window before the caller rebinds it.
    """
    ports = []
    # `range` instead of py2-only `xrange`: identical iteration for small n
    for i in range(n):
        sock = socket.socket()
        sock.bind(('', 0))
        while sock.getsockname()[1] in _random_ports:
            # OS gave us a port we've already handed out; try again
            sock.close()
            sock = socket.socket()
            sock.bind(('', 0))
        ports.append(sock)
    # keep all sockets open until every port is chosen, so the OS cannot
    # hand us the same port twice within a single call
    for i, sock in enumerate(ports):
        port = sock.getsockname()[1]
        sock.close()
        ports[i] = port
        _random_ports.add(port)
    return ports
400 400
def signal_children(children):
    """Relay interupt/term signals to children, for more solid process cleanup."""
    def _relay(sig, frame):
        # terminate every child, then exit ourselves:
        # exit code 0 for SIGINT (normal interrupt), 1 otherwise
        log = Application.instance().log
        log.critical("Got signal %i, terminating children..."%sig)
        for child in children:
            child.terminate()

        sys.exit(sig != SIGINT)
    for sig in (SIGINT, SIGABRT, SIGTERM):
        signal(sig, _relay)
413 413
def generate_exec_key(keyfile):
    """Write a fresh random (uuid4) execution key to `keyfile`."""
    import uuid
    newkey = str(uuid.uuid4())
    with open(keyfile, 'w') as f:
        f.write(newkey+'\n')
    # set user-only RW permissions (0600);
    # this will have no effect on Windows
    os.chmod(keyfile, stat.S_IRUSR|stat.S_IWUSR)
423 423
424 424
def integer_loglevel(loglevel):
    """Coerce `loglevel` to an int.

    Accepts ints, digit strings, and logging level names ('DEBUG', ...).
    Anything else is returned unchanged.
    """
    try:
        loglevel = int(loglevel)
    except ValueError:
        # basestring (not str) so py2 unicode level names also resolve --
        # consistent with the string checks elsewhere in this module
        if isinstance(loglevel, basestring):
            loglevel = getattr(logging, loglevel)
    return loglevel
432 432
def connect_logger(logname, context, iface, root="ip", loglevel=logging.DEBUG):
    """Attach a zmq PUBHandler to logger `logname`, publishing to `iface`."""
    logger = logging.getLogger(logname)
    if any(isinstance(h, handlers.PUBHandler) for h in logger.handlers):
        # don't add a second PUBHandler
        return
    loglevel = integer_loglevel(loglevel)
    pub_socket = context.socket(zmq.PUB)
    pub_socket.connect(iface)
    pub_handler = handlers.PUBHandler(pub_socket)
    pub_handler.setLevel(loglevel)
    pub_handler.root_topic = root
    logger.addHandler(pub_handler)
    logger.setLevel(loglevel)
446 446
def connect_engine_logger(context, iface, engine, loglevel=logging.DEBUG):
    """Attach an EnginePUBHandler for `engine` to the root logger,
    publishing to `iface`; returns the root logger."""
    logger = logging.getLogger()
    if any(isinstance(h, handlers.PUBHandler) for h in logger.handlers):
        # don't add a second PUBHandler
        return
    loglevel = integer_loglevel(loglevel)
    pub_socket = context.socket(zmq.PUB)
    pub_socket.connect(iface)
    pub_handler = EnginePUBHandler(engine, pub_socket)
    pub_handler.setLevel(loglevel)
    logger.addHandler(pub_handler)
    logger.setLevel(loglevel)
    return logger
460 460
def local_logger(logname, loglevel=logging.DEBUG):
    """Attach a StreamHandler to logger `logname` and return the logger.

    Uses the same timestamped format as the other parallel apps.
    """
    loglevel = integer_loglevel(loglevel)
    logger = logging.getLogger(logname)
    if any([isinstance(h, logging.StreamHandler) for h in logger.handlers]):
        # don't add a second StreamHandler -- but still hand back the logger;
        # previously this branch returned None, unlike the normal path below
        return logger
    handler = logging.StreamHandler()
    handler.setLevel(loglevel)
    # match log format in Scheduler to rest of parallel apps
    formatter = logging.Formatter("%(asctime)s.%(msecs).03d [%(name)s] %(message)s",
                datefmt="%Y-%m-%d %H:%M:%S")
    handler.setFormatter(formatter)

    logger.addHandler(handler)
    logger.setLevel(loglevel)
    return logger
General Comments 0
You need to be logged in to leave comments. Login now