##// END OF EJS Templates
Fix references to xrange
Thomas Kluyver -
Show More
@@ -1,649 +1,649 b''
1 1 """Base classes to manage a Client's interaction with a running kernel
2 2 """
3 3
4 4 #-----------------------------------------------------------------------------
5 5 # Copyright (C) 2013 The IPython Development Team
6 6 #
7 7 # Distributed under the terms of the BSD License. The full license is in
8 8 # the file COPYING, distributed as part of this software.
9 9 #-----------------------------------------------------------------------------
10 10
11 11 #-----------------------------------------------------------------------------
12 12 # Imports
13 13 #-----------------------------------------------------------------------------
14 14
15 15 from __future__ import absolute_import
16 16
17 17 # Standard library imports
18 18 import atexit
19 19 import errno
20 20 from threading import Thread
21 21 import time
22 22
23 23 import zmq
24 24 # import ZMQError in top-level namespace, to avoid ugly attribute-error messages
25 25 # during garbage collection of threads at exit:
26 26 from zmq import ZMQError
27 27 from zmq.eventloop import ioloop, zmqstream
28 28
29 29 # Local imports
30 30 from .channelsabc import (
31 31 ShellChannelABC, IOPubChannelABC,
32 32 HBChannelABC, StdInChannelABC,
33 33 )
34 34 from IPython.utils.py3compat import string_types
35 35
36 36 #-----------------------------------------------------------------------------
37 37 # Constants and exceptions
38 38 #-----------------------------------------------------------------------------
39 39
class InvalidPortNumber(Exception):
    """Raised when a channel is created with an unusable port number."""
42 42
43 43 #-----------------------------------------------------------------------------
44 44 # Utility functions
45 45 #-----------------------------------------------------------------------------
46 46
47 47 # some utilities to validate message structure, these might get moved elsewhere
48 48 # if they prove to have more generic utility
49 49
def validate_string_list(lst):
    """Validate that the input is a list of strings.

    Parameters
    ----------
    lst : object
        The value to check.

    Raises
    ------
    ValueError
        If `lst` is not a list, or if any element is not a string.
    """
    if not isinstance(lst, list):
        # Wrap the operand in a 1-tuple: a bare tuple on the right of ``%``
        # is unpacked as the argument list, which raised TypeError here
        # instead of the documented ValueError.
        raise ValueError('input %r must be a list' % (lst,))
    for x in lst:
        if not isinstance(x, string_types):
            raise ValueError('element %r in list must be a string' % (x,))
59 59
60 60
def validate_string_dict(dct):
    """Validate that the input is a dict with string keys and values.

    Parameters
    ----------
    dct : dict
        The mapping to check.

    Raises
    ------
    ValueError
        If any key or any value is not a string.
    """
    # ``items()`` exists on both Python 2 and 3; ``iteritems()`` is py2-only.
    for k, v in dct.items():
        if not isinstance(k, string_types):
            # 1-tuple operand so tuple keys/values are formatted, not unpacked
            raise ValueError('key %r in dict must be a string' % (k,))
        if not isinstance(v, string_types):
            raise ValueError('value %r in dict must be a string' % (v,))
70 70
71 71
72 72 #-----------------------------------------------------------------------------
73 73 # ZMQ Socket Channel classes
74 74 #-----------------------------------------------------------------------------
75 75
class ZMQSocketChannel(Thread):
    """Common base for the channels that talk over a ZMQ socket.

    Every channel is a daemon thread.  Subclasses are expected to provide
    ``run`` and ``call_handlers`` and to populate ``ioloop``, ``socket`` and
    ``stream``.
    """
    context = None
    session = None
    socket = None
    ioloop = None
    stream = None
    _address = None
    _exiting = False
    proxy_methods = []

    def __init__(self, context, session, address):
        """Create a channel.

        Parameters
        ----------
        context : :class:`zmq.Context`
            The ZMQ context to use.
        session : :class:`session.Session`
            The session to use.
        address : zmq url
            Either a ready-made url string, or an (ip, port) tuple that is
            converted to a ``tcp://ip:port`` url.

        Raises
        ------
        InvalidPortNumber
            If `address` is a tuple whose port is 0.
        """
        super(ZMQSocketChannel, self).__init__()
        self.daemon = True

        self.context = context
        self.session = session
        if isinstance(address, tuple):
            if address[1] == 0:
                raise InvalidPortNumber(
                    'The port number for a channel cannot be 0.')
            address = "tcp://%s:%i" % address
        self._address = address
        # remember that the interpreter is exiting, so that errors raised
        # during shutdown can be ignored by the loops
        atexit.register(self._notice_exit)

    def _notice_exit(self):
        """atexit callback: flag that the interpreter is shutting down."""
        self._exiting = True

    def _run_loop(self):
        """Run the ioloop until it stops, restarting it after EINTR."""
        while True:
            try:
                self.ioloop.start()
            except ZMQError as err:
                if err.errno != errno.EINTR:
                    raise
                # interrupted system call: just start the loop again
            except Exception:
                if not self._exiting:
                    raise
                # errors while the interpreter is exiting are ignored
                return
            else:
                return

    def stop(self):
        """Join the channel's thread.

        This calls :method:`Thread.join` and returns when the thread
        terminates. :class:`RuntimeError` will be raised if
        :method:`self.start` is called again.
        """
        self.join()

    @property
    def address(self):
        """The channel's address as a zmq url string.

        These urls have the form: 'tcp://127.0.0.1:5555'.
        """
        return self._address

    def _queue_send(self, msg):
        """Schedule a message to be sent from the IOLoop's thread.

        Parameters
        ----------
        msg : message to send

        The actual send is handed to the loop with IOLoop.add_callback, so
        it is executed by the loop's thread rather than the caller's.
        """
        def send_from_loop():
            self.session.send(self.stream, msg)
        self.ioloop.add_callback(send_from_loop)

    def _handle_recv(self, msg):
        """Callback for stream.on_recv.

        Splits off the routing identities, unserializes the rest and passes
        the resulting message to :meth:`call_handlers`.
        """
        idents, payload = self.session.feed_identities(msg)
        unpacked = self.session.unserialize(payload)
        self.call_handlers(unpacked)
171 171
172 172
173 173
class ShellChannel(ZMQSocketChannel):
    """The shell channel for issuing request/replies to the kernel."""

    command_queue = None
    # whether execute requests may make the kernel call raw_input
    allow_stdin = True
    proxy_methods = [
        'execute',
        'complete',
        'object_info',
        'history',
        'kernel_info',
        'shutdown',
    ]

    def __init__(self, context, session, address):
        """Create the channel and give it its own IOLoop."""
        super(ShellChannel, self).__init__(context, session, address)
        self.ioloop = ioloop.IOLoop()

    def run(self):
        """The thread's main activity. Call start() instead."""
        sock = self.context.socket(zmq.DEALER)
        self.socket = sock
        sock.setsockopt(zmq.IDENTITY, self.session.bsession)
        sock.connect(self.address)
        self.stream = zmqstream.ZMQStream(sock, self.ioloop)
        self.stream.on_recv(self._handle_recv)
        self._run_loop()
        # best-effort cleanup once the loop has finished
        try:
            self.socket.close()
        except:
            pass

    def stop(self):
        """Stop the channel's event loop and join its thread."""
        self.ioloop.stop()
        super(ShellChannel, self).stop()

    def call_handlers(self, msg):
        """Hook invoked from the ioloop thread for each incoming message.

        Subclasses must override this.  Because it runs in the channel's
        thread, an implementation has to arrange for application level
        handlers to be called in the application thread.
        """
        raise NotImplementedError('call_handlers must be defined in a subclass.')

    def execute(self, code, silent=False, store_history=True,
                user_variables=None, user_expressions=None, allow_stdin=None):
        """Execute code in the kernel.

        Parameters
        ----------
        code : str
            A string of Python code.

        silent : bool, optional (default False)
            If set, the kernel will execute the code as quietly possible, and
            will force store_history to be False.

        store_history : bool, optional (default True)
            If set, the kernel will store command history.  This is forced
            to be False if silent is True.

        user_variables : list, optional
            A list of variable names to pull from the user's namespace.  They
            will come back as a dict with these names as keys and their
            :func:`repr` as values.

        user_expressions : dict, optional
            A dict mapping names to expressions to be evaluated in the user's
            dict. The expression values are returned as strings formatted using
            :func:`repr`.

        allow_stdin : bool, optional (default self.allow_stdin)
            Flag for whether the kernel can send stdin requests to frontends.

            Some frontends (e.g. the Notebook) do not support stdin requests.
            If raw_input is called from code executed from such a frontend, a
            StdinNotImplementedError will be raised.

        Returns
        -------
        The msg_id of the message sent.

        Raises
        ------
        ValueError
            If `code`, `user_variables` or `user_expressions` fail
            validation; nothing is sent in that case.
        """
        user_variables = [] if user_variables is None else user_variables
        user_expressions = {} if user_expressions is None else user_expressions
        if allow_stdin is None:
            allow_stdin = self.allow_stdin

        # Don't waste network traffic if inputs are invalid
        if not isinstance(code, string_types):
            raise ValueError('code %r must be a string' % code)
        validate_string_list(user_variables)
        validate_string_dict(user_expressions)

        # Create class for content/msg creation. Related to, but possibly
        # not in Session.
        content = {
            'code': code,
            'silent': silent,
            'store_history': store_history,
            'user_variables': user_variables,
            'user_expressions': user_expressions,
            'allow_stdin': allow_stdin,
        }
        request = self.session.msg('execute_request', content)
        self._queue_send(request)
        return request['header']['msg_id']

    def complete(self, text, line, cursor_pos, block=None):
        """Tab complete text in the kernel's namespace.

        Parameters
        ----------
        text : str
            The text to complete.
        line : str
            The full line of text that is the surrounding context for the
            text to complete.
        cursor_pos : int
            The position of the cursor in the line where the completion was
            requested.
        block : str, optional
            The full block of code in which the completion is being requested.

        Returns
        -------
        The msg_id of the message sent.
        """
        content = {
            'text': text,
            'line': line,
            'block': block,
            'cursor_pos': cursor_pos,
        }
        request = self.session.msg('complete_request', content)
        self._queue_send(request)
        return request['header']['msg_id']

    def object_info(self, oname, detail_level=0):
        """Get metadata information about an object in the kernel's namespace.

        Parameters
        ----------
        oname : str
            A string specifying the object name.
        detail_level : int, optional
            The level of detail for the introspection (0-2)

        Returns
        -------
        The msg_id of the message sent.
        """
        content = {'oname': oname, 'detail_level': detail_level}
        request = self.session.msg('object_info_request', content)
        self._queue_send(request)
        return request['header']['msg_id']

    def history(self, raw=True, output=False, hist_access_type='range', **kwargs):
        """Get entries from the kernel's history list.

        Parameters
        ----------
        raw : bool
            If True, return the raw input.
        output : bool
            If True, then return the output as well.
        hist_access_type : str
            'range' (fill in session, start and stop params), 'tail' (fill in n)
             or 'search' (fill in pattern param).

        session : int
            For a range request, the session from which to get lines. Session
            numbers are positive integers; negative ones count back from the
            current session.
        start : int
            The first line number of a history range.
        stop : int
            The final (excluded) line number of a history range.

        n : int
            The number of lines of history to get for a tail request.

        pattern : str
            The glob-syntax pattern for a search request.

        Returns
        -------
        The msg_id of the message sent.
        """
        content = {
            'raw': raw,
            'output': output,
            'hist_access_type': hist_access_type,
        }
        # the remaining keyword arguments are forwarded verbatim
        content.update(kwargs)
        request = self.session.msg('history_request', content)
        self._queue_send(request)
        return request['header']['msg_id']

    def kernel_info(self):
        """Request kernel info.

        Returns
        -------
        The msg_id of the message sent.
        """
        request = self.session.msg('kernel_info_request')
        self._queue_send(request)
        return request['header']['msg_id']

    def shutdown(self, restart=False):
        """Request an immediate kernel shutdown.

        Upon receipt of the (empty) reply, client code can safely assume that
        the kernel has shut down and it's safe to forcefully terminate it if
        it's still alive.

        The kernel will send the reply via a function registered with Python's
        atexit module, ensuring it's truly done as the kernel is done with all
        normal operation.

        Returns
        -------
        The msg_id of the message sent.
        """
        # Send quit message to kernel. Once we implement kernel-side setattr,
        # this should probably be done that way, but for now this will do.
        request = self.session.msg('shutdown_request', {'restart':restart})
        self._queue_send(request)
        return request['header']['msg_id']
388 388
389 389
390 390
class IOPubChannel(ZMQSocketChannel):
    """The iopub channel which listens for messages that the kernel publishes.

    This channel is where all output is published to frontends.
    """

    def __init__(self, context, session, address):
        """Create the channel and give it its own IOLoop."""
        super(IOPubChannel, self).__init__(context, session, address)
        self.ioloop = ioloop.IOLoop()

    def run(self):
        """The thread's main activity. Call start() instead."""
        self.socket = self.context.socket(zmq.SUB)
        # empty prefix: subscribe to every published message
        self.socket.setsockopt(zmq.SUBSCRIBE,b'')
        self.socket.setsockopt(zmq.IDENTITY, self.session.bsession)
        self.socket.connect(self.address)
        self.stream = zmqstream.ZMQStream(self.socket, self.ioloop)
        self.stream.on_recv(self._handle_recv)
        self._run_loop()
        # best-effort cleanup once the loop has finished
        try:
            self.socket.close()
        except:
            pass

    def stop(self):
        """Stop the channel's event loop and join its thread."""
        self.ioloop.stop()
        super(IOPubChannel, self).stop()

    def call_handlers(self, msg):
        """This method is called in the ioloop thread when a message arrives.

        Subclasses should override this method to handle incoming messages.
        It is important to remember that this method is called in the thread
        so that some logic must be done to ensure that the application level
        handlers are called in the application thread.
        """
        raise NotImplementedError('call_handlers must be defined in a subclass.')

    def flush(self, timeout=1.0):
        """Immediately processes all pending messages on the iopub channel.

        Callers should use this method to ensure that :method:`call_handlers`
        has been called for all messages that have been received on the
        0MQ SUB socket of this channel.

        The flush itself is scheduled on the IOLoop with ``add_callback``;
        the caller only polls a flag until it has run or the time is up.

        Parameters
        ----------
        timeout : float, optional
            The maximum amount of time to spend flushing, in seconds.  The
            default is one second.
        """
        # We do the IOLoop callback process twice to ensure that the IOLoop
        # gets to perform at least one full poll.
        stop_time = time.time() + timeout
        # ``range`` (not the py2-only ``xrange``) keeps this py2/py3 neutral
        for i in range(2):
            self._flushed = False
            self.ioloop.add_callback(self._flush)
            while not self._flushed and time.time() < stop_time:
                time.sleep(0.01)

    def _flush(self):
        """Callback for :method:`self.flush`."""
        self.stream.flush()
        self._flushed = True
458 458
459 459
class StdInChannel(ZMQSocketChannel):
    """The stdin channel to handle raw_input requests that the kernel makes."""

    msg_queue = None
    proxy_methods = ['input']

    def __init__(self, context, session, address):
        """Create the channel and give it its own IOLoop."""
        super(StdInChannel, self).__init__(context, session, address)
        self.ioloop = ioloop.IOLoop()

    def run(self):
        """The thread's main activity. Call start() instead."""
        sock = self.context.socket(zmq.DEALER)
        self.socket = sock
        sock.setsockopt(zmq.IDENTITY, self.session.bsession)
        sock.connect(self.address)
        self.stream = zmqstream.ZMQStream(sock, self.ioloop)
        self.stream.on_recv(self._handle_recv)
        self._run_loop()
        # best-effort cleanup once the loop has finished
        try:
            self.socket.close()
        except:
            pass

    def stop(self):
        """Stop the channel's event loop and join its thread."""
        self.ioloop.stop()
        super(StdInChannel, self).stop()

    def call_handlers(self, msg):
        """Hook invoked from the ioloop thread for each incoming message.

        Subclasses must override this.  Because it runs in the channel's
        thread, an implementation has to arrange for application level
        handlers to be called in the application thread.
        """
        raise NotImplementedError('call_handlers must be defined in a subclass.')

    def input(self, string):
        """Send a string of raw input to the kernel."""
        reply = self.session.msg('input_reply', {'value': string})
        self._queue_send(reply)
503 503
504 504
class HBChannel(ZMQSocketChannel):
    """The heartbeat channel which monitors the kernel heartbeat.

    Note that the heartbeat channel is paused by default. As long as you start
    this channel, the kernel manager will ensure that it is paused and un-paused
    as appropriate.
    """

    # seconds to wait for a reply before the heart is considered stopped;
    # also the length of one ping cycle
    time_to_dead = 3.0
    socket = None
    poller = None
    _running = None
    _pause = None
    _beating = None

    def __init__(self, context, session, address):
        """Create the channel in the paused, not-running state."""
        super(HBChannel, self).__init__(context, session, address)
        self._running = False
        self._pause = True
        self.poller = zmq.Poller()

    def _create_socket(self):
        """(Re)create the REQ socket and register it with the poller."""
        old = self.socket
        if old is not None:
            # close previous socket, before opening a new one
            self.poller.unregister(old)
            old.close()
        self.socket = self.context.socket(zmq.REQ)
        self.socket.setsockopt(zmq.LINGER, 0)
        self.socket.connect(self.address)

        self.poller.register(self.socket, zmq.POLLIN)

    def _poll(self, start_time):
        """Poll for a heartbeat reply until self.time_to_dead has elapsed.

        Ignores interrupts, and returns the result of poll(), which
        will be an empty list if no messages arrived before the timeout,
        or the event tuple if there is a message to receive.
        """
        def time_left():
            # never less than a millisecond, to ensure we poll at least once
            return max(self.time_to_dead - (time.time() - start_time), 1e-3)

        until_dead = time_left()
        while True:
            try:
                return self.poller.poll(1000 * until_dead)
            except ZMQError as err:
                if err.errno != errno.EINTR:
                    raise
                # ignore interrupts during heartbeat
                # this may never actually happen
                until_dead = time_left()
            except Exception:
                if not self._exiting:
                    raise
                return []

    def run(self):
        """The thread's main activity. Call start() instead."""
        self._create_socket()
        self._running = True
        self._beating = True

        while self._running:
            if self._pause:
                # just sleep, and skip the rest of the loop
                time.sleep(self.time_to_dead)
                continue

            # no need to catch EFSM here, because the previous event was
            # either a recv or connect, which cannot be followed by EFSM
            self.socket.send(b'ping')
            request_time = time.time()
            if self._poll(request_time):
                self._beating = True
                # the poll above guarantees we have something to recv
                self.socket.recv()
                # sleep the remainder of the cycle
                remainder = self.time_to_dead - (time.time() - request_time)
                if remainder > 0:
                    time.sleep(remainder)
            else:
                # nothing was received within the time limit, signal heart failure
                self._beating = False
                self.call_handlers(time.time() - request_time)
                # and close/reopen the socket, because the REQ/REP cycle has been broken
                self._create_socket()
        # best-effort cleanup once the loop has finished
        try:
            self.socket.close()
        except:
            pass

    def pause(self):
        """Pause the heartbeat."""
        self._pause = True

    def unpause(self):
        """Unpause the heartbeat."""
        self._pause = False

    def is_beating(self):
        """Is the heartbeat running and responsive (and not paused)."""
        return bool(self.is_alive() and not self._pause and self._beating)

    def stop(self):
        """Ask the heartbeat loop to finish and join its thread."""
        self._running = False
        super(HBChannel, self).stop()

    def call_handlers(self, since_last_heartbeat):
        """Hook invoked from the channel's thread when a heartbeat is missed.

        `since_last_heartbeat` is the number of seconds that elapsed since
        the unanswered ping was sent.

        Subclasses must override this.  Because it runs in the channel's
        thread, an implementation has to arrange for application level
        handlers to be called in the application thread.
        """
        raise NotImplementedError('call_handlers must be defined in a subclass.')
640 640
641 641
642 642 #-----------------------------------------------------------------------------
643 643 # ABC Registration
644 644 #-----------------------------------------------------------------------------
645 645
# Register each concrete channel class with its abstract counterpart imported
# from .channelsabc (presumably abc-style virtual-subclass registration --
# confirm against channelsabc).
ShellChannelABC.register(ShellChannel)
IOPubChannelABC.register(IOPubChannel)
HBChannelABC.register(HBChannel)
StdInChannelABC.register(StdInChannel)
@@ -1,339 +1,339 b''
1 1 # -*- coding: utf-8 -*-
2 2 """
3 3 A module to change reload() so that it acts recursively.
4 4 To enable it type::
5 5
6 6 import __builtin__, deepreload
7 7 __builtin__.reload = deepreload.reload
8 8
9 9 You can then disable it with::
10 10
11 11 __builtin__.reload = deepreload.original_reload
12 12
13 13 Alternatively, you can add a dreload builtin alongside normal reload with::
14 14
15 15 __builtin__.dreload = deepreload.reload
16 16
17 17 This code is almost entirely based on knee.py, which is a Python
18 18 re-implementation of hierarchical module import.
19 19 """
20 20 from __future__ import print_function
21 21 #*****************************************************************************
22 22 # Copyright (C) 2001 Nathaniel Gray <n8gray@caltech.edu>
23 23 #
24 24 # Distributed under the terms of the BSD License. The full license is in
25 25 # the file COPYING, distributed as part of this software.
26 26 #*****************************************************************************
27 27
28 28 from contextlib import contextmanager
29 29 import imp
30 30 import sys
31 31
32 32 from types import ModuleType
33 33 from warnings import warn
34 34
35 35 from IPython.utils.py3compat import builtin_mod, builtin_mod_name
36 36
# The import hook that was installed when this module was loaded; it is
# reinstated temporarily around imp.find_module() calls.
original_import = builtin_mod.__import__
38 38
@contextmanager
def replace_import_hook(new_import):
    """Context manager installing `new_import` as ``builtin_mod.__import__``.

    The hook that was active on entry is put back on exit, even if the body
    of the ``with`` block raises.
    """
    previous_hook = builtin_mod.__import__
    builtin_mod.__import__ = new_import
    try:
        yield
    finally:
        builtin_mod.__import__ = previous_hook
47 47
def get_parent(globals, level):
    """
    parent, name = get_parent(globals, level)

    Return the package that an import is being performed in.  If globals comes
    from the module foo.bar.bat (not itself a package), this returns the
    sys.modules entry for foo.bar.  If globals is from a package's __init__.py,
    the package's entry in sys.modules is returned.

    If globals doesn't come from a package or a module in a package, the
    pair ``(None, '')`` is returned.  As a side effect, ``__package__`` is
    stored in `globals` when it was not set.

    Raises
    ------
    ValueError
        If ``__package__`` is not string-like, or a relative import is
        attempted outside a package or beyond the top-level package.
    SystemError
        If the parent package of a relative import (level >= 1) is not in
        sys.modules.
    """
    orig_level = level

    if not level or not isinstance(globals, dict):
        return None, ''

    pkgname = globals.get('__package__', None)

    if pkgname is not None:
        # __package__ is set, so use it
        if not hasattr(pkgname, 'rindex'):
            raise ValueError('__package__ set to non-string')
        if len(pkgname) == 0:
            if level > 0:
                raise ValueError('Attempted relative import in non-package')
            return None, ''
        name = pkgname
    else:
        # __package__ not set, so figure it out and set it
        if '__name__' not in globals:
            return None, ''
        modname = globals['__name__']

        if '__path__' in globals:
            # __path__ is set, so modname is already the package name
            globals['__package__'] = name = modname
        else:
            # Normal module, so work out the package name if any
            lastdot = modname.rfind('.')
            if lastdot < 0 and level > 0:
                raise ValueError("Attempted relative import in non-package")
            if lastdot < 0:
                globals['__package__'] = None
                return None, ''
            globals['__package__'] = name = modname[:lastdot]

    # Strip one trailing package component per level above 1.
    dot = len(name)
    for x in range(level, 1, -1):
        try:
            dot = name.rindex('.', 0, dot)
        except ValueError:
            raise ValueError("attempted relative import beyond top-level "
                             "package")
        name = name[:dot]

    try:
        parent = sys.modules[name]
    except KeyError:
        # Only a missing entry is expected here; anything else (e.g.
        # KeyboardInterrupt) must propagate.
        if orig_level < 1:
            warn("Parent module '%.200s' not found while handling absolute "
                 "import" % name)
            parent = None
        else:
            raise SystemError("Parent module '%.200s' not loaded, cannot "
                              "perform relative import" % name)

    # We expect, but can't guarantee, if parent != None, that:
    # - parent.__name__ == name
    # - parent.__dict__ is globals
    # If this is violated...  Who cares?
    return parent, name
120 120
def load_next(mod, altmod, name, buf):
    """
    mod, name, buf = load_next(mod, altmod, name, buf)

    Import the first dotted component of `name` and return the resulting
    module, the remaining part of the name (None when nothing is left) and
    the dotted name built up so far.

    altmod is either None or same as mod
    """
    if len(name) == 0:
        # completely empty module name should only happen in
        # 'from . import' (or '__import__("")')
        return mod, None, buf

    subname, sep, remainder = name.partition('.')
    if subname == '':
        # the name starts with a dot
        raise ValueError('Empty module name')
    rest = remainder if sep else None

    buf = subname if buf == '' else buf + '.' + subname

    result = import_submodule(mod, subname, buf)
    if result is None and mod != altmod:
        # not found relative to mod: retry relative to altmod
        result = import_submodule(altmod, subname, subname)
        if result is not None:
            buf = subname

    if result is None:
        raise ImportError("No module named %.200s" % name)

    return result, rest, buf
158 158
# Need to keep track of what we've already reloaded to prevent cyclic evil.
# Maps fully qualified module names to 1.
found_now = {}
161 161
def import_submodule(mod, subname, fullname):
    """m = import_submodule(mod, subname, fullname)

    Load (or reload) the module `fullname` and attach it to `mod` as
    `subname`.  Returns the module, or None when `mod` is not a package or
    the module cannot be found.
    """
    # Require:
    # if mod == None: subname == fullname
    # else: mod.__name__ + "." + subname == fullname

    global found_now
    if fullname in found_now and fullname in sys.modules:
        # already handled during this reload: reuse it
        known = sys.modules[fullname]
        add_submodule(mod, known, fullname, subname)
        return known

    print('Reloading', fullname)
    found_now[fullname] = 1
    previous = sys.modules.get(fullname, None)

    if mod is None:
        search_path = None
    elif hasattr(mod, '__path__'):
        search_path = mod.__path__
    else:
        # mod is not a package, so it cannot contain submodules
        return None

    try:
        # This appears to be necessary on Python 3, because imp.find_module()
        # tries to import standard libraries (like io) itself, and we don't
        # want them to be processed by our deep_import_hook.
        with replace_import_hook(original_import):
            fp, filename, stuff = imp.find_module(subname, search_path)
    except ImportError:
        return None

    try:
        loaded = imp.load_module(fullname, fp, filename, stuff)
    except:
        # load_module probably removed name from modules because of
        # the error.  Put back the original module object.
        if previous:
            sys.modules[fullname] = previous
        raise
    finally:
        if fp: fp.close()

    add_submodule(mod, loaded, fullname, subname)
    return loaded
206 206
def add_submodule(mod, submod, fullname, subname):
    """mod.{subname} = submod

    Does nothing when `mod` is None.  When `submod` is None, the module
    registered in sys.modules under `fullname` is used instead.
    """
    if mod is not None:
        if submod is None:
            submod = sys.modules[fullname]
        setattr(mod, subname, submod)
218 218
def ensure_fromlist(mod, fromlist, buf, recursive):
    """Handle 'from module import a, b, c' imports.

    For a package `mod`, import as a submodule every name in `fromlist`
    that is not already an attribute of `mod`.  A ``'*'`` entry is expanded
    through ``mod.__all__`` when that exists.

    Parameters
    ----------
    mod : module
        The module being imported from; nothing is done unless it has a
        ``__path__`` attribute.
    fromlist : iterable of str
        The names requested by the import statement.
    buf : str
        The dotted name of `mod`.
    recursive : bool
        True while expanding ``__all__``, so that a ``'*'`` inside
        ``__all__`` is not expanded again.

    Raises
    ------
    TypeError
        If an item of `fromlist` is not string-like.
    """
    if not hasattr(mod, '__path__'):
        return
    for item in fromlist:
        if not hasattr(item, 'rindex'):
            raise TypeError("Item in ``from list'' not a string")
        if item == '*':
            if recursive:
                continue  # avoid endless recursion
            try:
                exported = mod.__all__
            except AttributeError:
                pass
            else:
                # Failures surface as exceptions, so there is no status to
                # check.  The old ``if not ret: return 0`` always fired
                # (the call returns None) and skipped the rest of fromlist.
                ensure_fromlist(mod, exported, buf, 1)
        elif not hasattr(mod, item):
            import_submodule(mod, item, buf + '.' + item)
239 239
def deep_import_hook(name, globals=None, locals=None, fromlist=None, level=-1):
    """Replacement for __import__()

    Returns the top-level module of the dotted `name` when `fromlist` is
    empty, otherwise the innermost module after making sure the names in
    `fromlist` are available on it.
    """
    parent, buf = get_parent(globals, level)

    # with a negative level there is no alternative module to fall back on
    altmod = parent if not level < 0 else None
    head, name, buf = load_next(parent, altmod, name, buf)

    # walk down the remaining dotted components
    tail = head
    while name:
        tail, name, buf = load_next(tail, tail, name, buf)

    # If tail is None, both get_parent and load_next found
    # an empty module name: someone called __import__("") or
    # doctored faulty bytecode
    if tail is None:
        raise ValueError('Empty module name')

    if fromlist:
        ensure_fromlist(tail, fromlist, buf, 0)
        return tail
    return head
261 261
# Modules currently being reloaded by deep_reload_hook, keyed by module name.
modules_reloading = {}
263 263
def deep_reload_hook(m):
    """Replacement for reload().

    Parameters
    ----------
    m : module
        The module to reload; it must be present in sys.modules.

    Returns
    -------
    The reloaded module, or the module already recorded in
    ``modules_reloading`` under the same name.

    Raises
    ------
    TypeError
        If `m` is not a module.
    ImportError
        If `m` or its parent package is not in sys.modules, or the module
        cannot be found.
    """
    if not isinstance(m, ModuleType):
        raise TypeError("reload() argument must be module")

    name = m.__name__

    if name not in sys.modules:
        raise ImportError("reload(): module %.200s not in sys.modules" % name)

    global modules_reloading
    try:
        return modules_reloading[name]
    except KeyError:
        # Not being reloaded yet: record it.  Only the missing-key case is
        # handled, so KeyboardInterrupt and friends are no longer swallowed.
        modules_reloading[name] = m

    dot = name.rfind('.')
    if dot < 0:
        subname = name
        path = None
    else:
        try:
            parent = sys.modules[name[:dot]]
        except KeyError:
            modules_reloading.clear()
            raise ImportError("reload(): parent %.200s not in sys.modules" % name[:dot])
        subname = name[dot+1:]
        path = getattr(parent, "__path__", None)

    try:
        # This appears to be necessary on Python 3, because imp.find_module()
        # tries to import standard libraries (like io) itself, and we don't
        # want them to be processed by our deep_import_hook.
        with replace_import_hook(original_import):
            fp, filename, stuff = imp.find_module(subname, path)
    finally:
        # NOTE(review): this clears the bookkeeping whether or not
        # find_module() succeeded, i.e. also before load_module() runs --
        # confirm that is intended.
        modules_reloading.clear()

    try:
        newm = imp.load_module(name, fp, filename, stuff)
    except:
        # load_module probably removed name from modules because of
        # the error.  Put back the original module object.
        sys.modules[name] = m
        raise
    finally:
        if fp: fp.close()

    modules_reloading.clear()
    return newm
314 314
# Save the original hooks
# Use the reload attribute of builtin_mod when it exists, otherwise fall
# back to imp.reload.
try:
    original_reload = builtin_mod.reload
except AttributeError:
    original_reload = imp.reload    # Python 3
320 320
321 321 # Replacement for reload()
def reload(module, exclude=['sys', 'os.path', builtin_mod_name, '__main__']):
    """Recursively reload all modules used in the given module.

    `exclude` optionally lists modules that must not be reloaded.  The
    default list contains sys, os.path, the builtin module and __main__, to
    prevent, e.g., resetting display, exception, and io hooks.

    Returns the result of reloading `module`.
    """
    global found_now
    # mark the excluded modules as already handled so they are left alone
    found_now.update((skipped, 1) for skipped in exclude)
    try:
        with replace_import_hook(deep_import_hook):
            return deep_reload_hook(module)
    finally:
        # start from a clean slate on the next call
        found_now = {}
336 336
337 337 # Uncomment the following to automatically activate deep reloading whenever
338 338 # this module is imported
339 339 #builtin_mod.reload = reload
@@ -1,790 +1,790 b''
1 1 # -*- coding: utf-8 -*-
2 2 """
3 3 Python advanced pretty printer. This pretty printer is intended to
4 4 replace the old `pprint` python module which does not allow developers
5 5 to provide their own pretty print callbacks.
6 6
7 7 This module is based on ruby's `prettyprint.rb` library by `Tanaka Akira`.
8 8
9 9
10 10 Example Usage
11 11 -------------
12 12
13 13 To directly print the representation of an object use `pprint`::
14 14
15 15 from pretty import pprint
16 16 pprint(complex_object)
17 17
18 18 To get a string of the output use `pretty`::
19 19
20 20 from pretty import pretty
21 21 string = pretty(complex_object)
22 22
23 23
24 24 Extending
25 25 ---------
26 26
27 27 The pretty library allows developers to add pretty printing rules for their
28 28 own objects. This process is straightforward. All you have to do is to
29 29 add a `_repr_pretty_` method to your object and call the methods on the
30 30 pretty printer passed::
31 31
32 32 class MyObject(object):
33 33
34 34 def _repr_pretty_(self, p, cycle):
35 35 ...
36 36
37 37 Depending on the python version you want to support you have two
38 38 possibilities. The following list shows the python 2.5 version and the
39 39 compatibility one.
40 40
41 41
42 42 Here the example implementation of a `_repr_pretty_` method for a list
43 43 subclass for python 2.5 and higher (python 2.5 requires the with statement
44 44 __future__ import)::
45 45
46 46 class MyList(list):
47 47
48 48 def _repr_pretty_(self, p, cycle):
49 49 if cycle:
50 50 p.text('MyList(...)')
51 51 else:
52 52 with p.group(8, 'MyList([', '])'):
53 53 for idx, item in enumerate(self):
54 54 if idx:
55 55 p.text(',')
56 56 p.breakable()
57 57 p.pretty(item)
58 58
59 59 The `cycle` parameter is `True` if pretty detected a cycle. You *have* to
60 60 react to that or the result is an infinite loop. `p.text()` just adds
61 61 non breaking text to the output, `p.breakable()` either adds a whitespace
62 62 or breaks here. If you pass it an argument it's used instead of the
63 63 default space. `p.pretty` prettyprints another object using the pretty print
64 64 method.
65 65
66 66 The first parameter to the `group` function specifies the extra indentation
67 67 of the next line. In this example the next item will either not be
68 68 broken (if the items are short enough) or aligned with the right edge of
69 69 the opening bracket of `MyList`.
70 70
71 71 If you want to support python 2.4 and lower you can use this code::
72 72
73 73 class MyList(list):
74 74
75 75 def _repr_pretty_(self, p, cycle):
76 76 if cycle:
77 77 p.text('MyList(...)')
78 78 else:
79 79 p.begin_group(8, 'MyList([')
80 80 for idx, item in enumerate(self):
81 81 if idx:
82 82 p.text(',')
83 83 p.breakable()
84 84 p.pretty(item)
85 85 p.end_group(8, '])')
86 86
87 87 If you just want to indent something you can use the group function
88 88 without open / close parameters. Under python 2.5 you can also use this
89 89 code::
90 90
91 91 with p.indent(2):
92 92 ...
93 93
94 94 Or under python2.4 you might want to modify ``p.indentation`` by hand but
95 95 this is rather ugly.
96 96
97 97 Inheritance diagram:
98 98
99 99 .. inheritance-diagram:: IPython.lib.pretty
100 100 :parts: 3
101 101
102 102 :copyright: 2007 by Armin Ronacher.
103 103 Portions (c) 2009 by Robert Kern.
104 104 :license: BSD License.
105 105 """
106 106 from __future__ import print_function
107 107 from contextlib import contextmanager
108 108 import sys
109 109 import types
110 110 import re
111 111 import datetime
112 112 from io import StringIO
113 113 from collections import deque
114 114
115 115
116 116 __all__ = ['pretty', 'pprint', 'PrettyPrinter', 'RepresentationPrinter',
117 117 'for_type', 'for_type_by_name']
118 118
119 119
120 120 _re_pattern_type = type(re.compile(''))
121 121
122 122
def pretty(obj, verbose=False, max_width=79, newline='\n'):
    """
    Return the pretty-printed representation of `obj` as a string.

    `verbose`, `max_width` and `newline` are forwarded unchanged to
    `RepresentationPrinter`.
    """
    buf = StringIO()
    rp = RepresentationPrinter(buf, verbose, max_width, newline)
    rp.pretty(obj)
    # Anything still buffered must reach `buf` before it is read back.
    rp.flush()
    return buf.getvalue()
132 132
133 133
def pprint(obj, verbose=False, max_width=79, newline='\n'):
    """
    Pretty print `obj` to ``sys.stdout``, followed by `newline`.

    Takes the same arguments as `pretty`.
    """
    rp = RepresentationPrinter(sys.stdout, verbose, max_width, newline)
    rp.pretty(obj)
    rp.flush()
    # Terminate the output and push it out of stdout's own buffer.
    sys.stdout.write(newline)
    sys.stdout.flush()
143 143
class _PrettyPrinterBase(object):
    """Context-manager helpers shared by the pretty printers.

    Subclasses provide an ``indentation`` attribute and the
    ``begin_group`` / ``end_group`` methods used here.
    """

    @contextmanager
    def indent(self, indent):
        """Raise ``self.indentation`` by `indent` for the ``with`` body."""
        self.indentation = self.indentation + indent
        try:
            yield
        finally:
            # Undo the indent even if the body raised.
            self.indentation = self.indentation - indent

    @contextmanager
    def group(self, indent=0, open='', close=''):
        """``with``-statement form of begin_group / end_group."""
        self.begin_group(indent, open)
        try:
            yield
        finally:
            # Close the group even if the body raised.
            self.end_group(indent, close)
163 163
class PrettyPrinter(_PrettyPrinterBase):
    """
    Baseclass for the `RepresentationPrinter` prettyprinter that is used to
    generate pretty reprs of objects.  Contrary to the `RepresentationPrinter`
    this printer knows nothing about the default pprinters or the
    `_repr_pretty_` callback method.
    """

    def __init__(self, output, max_width=79, newline='\n'):
        self.output = output
        self.max_width = max_width
        self.newline = newline
        # Width already written to `output` since the last newline.
        self.output_width = 0
        # Combined width of the items waiting in `buffer`.
        self.buffer_width = 0
        self.buffer = deque()

        root_group = Group(0)
        self.group_stack = [root_group]
        self.group_queue = GroupQueue(root_group)
        self.indentation = 0

    def _break_outer_groups(self):
        """Break queued groups until written + buffered text fits max_width."""
        while self.max_width < self.output_width + self.buffer_width:
            group = self.group_queue.deq()
            if not group:
                return
            # Emit buffered items until every breakable of the dequeued
            # group has been written out.
            while group.breakables:
                x = self.buffer.popleft()
                self.output_width = x.output(self.output, self.output_width)
                self.buffer_width -= x.width
            # Plain text at the head of the buffer can be written right away.
            while self.buffer and isinstance(self.buffer[0], Text):
                x = self.buffer.popleft()
                self.output_width = x.output(self.output, self.output_width)
                self.buffer_width -= x.width

    def text(self, obj):
        """Add literal text to the output."""
        width = len(obj)
        if self.buffer:
            # Something is pending: append to the trailing Text item (or
            # start a new one) so ordering with breakables is preserved.
            text = self.buffer[-1]
            if not isinstance(text, Text):
                text = Text()
                self.buffer.append(text)
            text.add(obj, width)
            self.buffer_width += width
            self._break_outer_groups()
        else:
            self.output.write(obj)
            self.output_width += width

    def breakable(self, sep=' '):
        """
        Add a breakable separator to the output.  This does not mean that it
        will automatically break here.  If no breaking on this position takes
        place the `sep` is inserted which default to one space.
        """
        width = len(sep)
        group = self.group_stack[-1]
        if group.want_break:
            # The current group is already broken: emit a newline now.
            self.flush()
            self.output.write(self.newline)
            self.output.write(' ' * self.indentation)
            self.output_width = self.indentation
            self.buffer_width = 0
        else:
            self.buffer.append(Breakable(sep, width, self))
            self.buffer_width += width
            self._break_outer_groups()

    def break_(self):
        """
        Explicitly insert a newline into the output, maintaining correct
        indentation.
        """
        self.flush()
        self.output.write(self.newline)
        self.output.write(' ' * self.indentation)
        self.output_width = self.indentation
        self.buffer_width = 0

    def begin_group(self, indent=0, open=''):
        """
        Begin a group.  If you want support for python < 2.5 which doesn't
        have the with statement this is the preferred way:

            p.begin_group(1, '{')
            ...
            p.end_group(1, '}')

        The python 2.5 expression would be this:

            with p.group(1, '{', '}'):
                ...

        The first parameter specifies the indentation for the next line
        (usually the width of the opening text), the second the opening
        text.  All parameters are optional.
        """
        if open:
            self.text(open)
        group = Group(self.group_stack[-1].depth + 1)
        self.group_stack.append(group)
        self.group_queue.enq(group)
        self.indentation += indent

    def end_group(self, dedent=0, close=''):
        """End a group. See `begin_group` for more details."""
        self.indentation -= dedent
        group = self.group_stack.pop()
        if not group.breakables:
            # Nothing left to break in this group, so it need not be queued.
            self.group_queue.remove(group)
        if close:
            self.text(close)

    def flush(self):
        """Flush data that is left in the buffer."""
        for data in self.buffer:
            # output() returns the new absolute width, so assign it (as
            # _break_outer_groups does) rather than adding it on top.
            self.output_width = data.output(self.output, self.output_width)
        self.buffer.clear()
        self.buffer_width = 0
284 284
285 285
def _get_mro(obj_class):
    """ Get a reasonable method resolution order of a class and its
    superclasses for both old-style and new-style classes.
    """
    if hasattr(obj_class, '__mro__'):
        return obj_class.__mro__
    # Old-style class.  Mix in object to make a fake new-style class.
    try:
        obj_class = type(obj_class.__name__, (obj_class, object), {})
    except TypeError:
        # Old-style extension type that does not descend from object.
        # FIXME: try to construct a more thorough MRO.
        return [obj_class]
    # Drop the fake class itself (first) and the mixed-in object (last).
    return obj_class.__mro__[1:-1]
303 303
304 304
class RepresentationPrinter(PrettyPrinter):
    """
    Pretty printer with a `pretty` method that selects and calls the
    pretty printing function for an arbitrary python object.

    Processing state is kept on the instance, so an instance must never be
    shared between threads without locking; create a new one instead.

    Instances also carry a `verbose` flag that callbacks can consult.  The
    default instance repr, for example, prints all attributes that are not
    prefixed by an underscore when the printer is in verbose mode.
    """

    def __init__(self, output, verbose=False, max_width=79, newline='\n',
        singleton_pprinters=None, type_pprinters=None, deferred_pprinters=None):

        PrettyPrinter.__init__(self, output, max_width, newline)
        self.verbose = verbose
        # ids of the objects currently being printed, for cycle detection.
        self.stack = []
        # Each registry defaults to a private copy of the module-level one.
        if singleton_pprinters is None:
            self.singleton_pprinters = _singleton_pprinters.copy()
        else:
            self.singleton_pprinters = singleton_pprinters
        if type_pprinters is None:
            self.type_pprinters = _type_pprinters.copy()
        else:
            self.type_pprinters = type_pprinters
        if deferred_pprinters is None:
            self.deferred_pprinters = _deferred_type_pprinters.copy()
        else:
            self.deferred_pprinters = deferred_pprinters

    def pretty(self, obj):
        """Pretty print the given object."""
        obj_id = id(obj)
        # An object already on the stack means we are inside a cycle.
        cycle = obj_id in self.stack
        self.stack.append(obj_id)
        self.begin_group()
        try:
            obj_class = getattr(obj, '__class__', None) or type(obj)
            # Printers registered for this exact object (by id) win.
            try:
                singleton_printer = self.singleton_pprinters[obj_id]
            except (TypeError, KeyError):
                pass
            else:
                return singleton_printer(obj, self, cycle)
            # Otherwise walk the mro; for each class try, in order:
            #   1) a printer registered by type
            #   2) a printer registered by (module, name)
            #   3) a _repr_pretty_ method defined on that class
            for cls in _get_mro(obj_class):
                if cls in self.type_pprinters:
                    return self.type_pprinters[cls](obj, self, cycle)
                deferred = self._in_deferred_types(cls)
                if deferred is not None:
                    return deferred(obj, self, cycle)
                # Some objects automatically create any requested
                # attribute.  Try to ignore most of them by looking in the
                # class __dict__ and checking for callability.
                if '_repr_pretty_' in cls.__dict__:
                    meth = cls._repr_pretty_
                    if callable(meth):
                        return meth(obj, self, cycle)
            return _default_pprint(obj, self, cycle)
        finally:
            self.end_group()
            self.stack.pop()

    def _in_deferred_types(self, cls):
        """
        Look `cls` up in the deferred (by-name) type registry.

        Returns the registered printer, or None when there is none.  A
        match is moved into the regular type registry for future use.
        """
        key = (getattr(cls, '__module__', None),
               getattr(cls, '__name__', None))
        if key not in self.deferred_pprinters:
            return None
        # Move the printer over to the regular registry.
        printer = self.deferred_pprinters.pop(key)
        self.type_pprinters[cls] = printer
        return printer
394 394
395 395
class Printable(object):
    """Base class of the items held in the printer's buffer."""

    def output(self, stream, output_width):
        """Write nothing to `stream`; return `output_width` unchanged."""
        return output_width
400 400
401 401
class Text(Printable):
    """A run of literal text fragments and their combined width."""

    def __init__(self):
        self.objs = []
        self.width = 0

    def output(self, stream, output_width):
        """Write every fragment to `stream`; return the width afterwards."""
        for fragment in self.objs:
            stream.write(fragment)
        return output_width + self.width

    def add(self, obj, width):
        """Append fragment `obj`, which occupies `width` columns."""
        self.objs.append(obj)
        self.width += width
416 416
417 417
class Breakable(Printable):
    """A separator that is written either as itself or as a line break."""

    def __init__(self, seq, width, pretty):
        self.obj = seq
        self.width = width
        self.pretty = pretty
        # Capture the printer's indentation and innermost group as they are
        # at creation time.
        self.indentation = pretty.indentation
        self.group = pretty.group_stack[-1]
        self.group.breakables.append(self)

    def output(self, stream, output_width):
        """Write the separator or a newline; return the width afterwards."""
        grp = self.group
        grp.breakables.popleft()
        if grp.want_break:
            # The group is broken: start a fresh, indented line.
            stream.write(self.pretty.newline)
            stream.write(' ' * self.indentation)
            return self.indentation
        if not grp.breakables:
            # No breakables left, so the group no longer needs to be queued.
            self.pretty.group_queue.remove(grp)
        stream.write(self.obj)
        return output_width + self.width
438 438
439 439
class Group(Printable):
    """Bookkeeping record for one nesting level of the pretty printer."""

    def __init__(self, depth):
        self.depth = depth
        # Set to True once the group has been chosen for breaking.
        self.want_break = False
        # Pending Breakable items that belong to this group.
        self.breakables = deque()
446 446
447 447
class GroupQueue(object):
    """Groups bucketed by depth: ``queue[d]`` lists the groups of depth d."""

    def __init__(self, *groups):
        self.queue = []
        for group in groups:
            self.enq(group)

    def enq(self, group):
        """Add `group` to the bucket for its depth, growing the queue."""
        depth = group.depth
        while depth > len(self.queue) - 1:
            self.queue.append([])
        self.queue[depth].append(group)

    def deq(self):
        """Remove and return the next group to break, or None.

        Buckets are visited from the shallowest depth.  Within a bucket the
        most recently added group that still has breakables is removed,
        flagged with ``want_break`` and returned.  If a bucket has no such
        group, all its groups are flagged and the bucket is emptied before
        moving on to the next depth.
        """
        for stack in self.queue:
            # Walk real indices from the end, so the deleted element is the
            # one that was inspected (an index into reversed(stack) is not
            # an index into stack).
            for idx in range(len(stack) - 1, -1, -1):
                group = stack[idx]
                if group.breakables:
                    del stack[idx]
                    group.want_break = True
                    return group
            for group in stack:
                group.want_break = True
            del stack[:]
        return None

    def remove(self, group):
        """Drop `group` from its depth bucket; do nothing if it is absent."""
        try:
            self.queue[group.depth].remove(group)
        except ValueError:
            pass
477 477
# The __repr__ implementations that count as "not user-provided":
# object's, plus that of old-style instances where the types module has
# InstanceType (Python 2).
_baseclass_reprs = (object.__repr__,)
if hasattr(types, 'InstanceType'):
    _baseclass_reprs = _baseclass_reprs + (types.InstanceType.__repr__,)
482 482
483 483
def _default_pprint(obj, p, cycle):
    """
    Fallback print function, used when an object provides no printer of
    its own and is none of the builtin objects.
    """
    klass = getattr(obj, '__class__', None) or type(obj)
    if getattr(klass, '__repr__', None) not in _baseclass_reprs:
        # A user-provided repr.  Emit it line by line, turning each newline
        # into p.break_().
        for lineno, output_line in enumerate(repr(obj).splitlines()):
            if lineno:
                p.break_()
            p.text(output_line)
        return

    p.begin_group(1, '<')
    p.pretty(klass)
    p.text(' at 0x%x' % id(obj))
    if cycle:
        p.text(' ...')
    elif p.verbose:
        first = True
        for key in dir(obj):
            if key.startswith('_'):
                continue
            try:
                value = getattr(obj, key)
            except AttributeError:
                continue
            if isinstance(value, types.MethodType):
                continue
            if not first:
                p.text(',')
            p.breakable()
            p.text(key)
            p.text('=')
            # Indent continuation lines of the value past "key=".
            step = len(key) + 1
            p.indentation += step
            p.pretty(value)
            p.indentation -= step
            first = False
    p.end_group(1, '>')
524 524
525 525
def _seq_pprinter_factory(start, end, basetype):
    """
    Build a pprint function for sequences delimited by `start` and `end`.
    Used by the default pprint for tuples and lists.
    """
    def inner(obj, p, cycle):
        typ = type(obj)
        has_own_repr = (basetype is not None and typ is not basetype
                        and typ.__repr__ != basetype.__repr__)
        if has_own_repr:
            # If the subclass provides its own repr, use it instead.
            return p.text(typ.__repr__(obj))

        if cycle:
            return p.text(start + '...' + end)

        step = len(start)
        p.begin_group(step, start)
        for position, element in enumerate(obj):
            if position:
                p.text(',')
                p.breakable()
            p.pretty(element)
        if typ is tuple and len(obj) == 1:
            # Special case for 1-item tuples.
            p.text(',')
        p.end_group(step, end)
    return inner
551 551
552 552
def _set_pprinter_factory(start, end, basetype):
    """
    Build a pprint function for sets and frozensets delimited by `start`
    and `end`.
    """
    def inner(obj, p, cycle):
        typ = type(obj)
        has_own_repr = (basetype is not None and typ is not basetype
                        and typ.__repr__ != basetype.__repr__)
        if has_own_repr:
            # If the subclass provides its own repr, use it instead.
            return p.text(typ.__repr__(obj))

        if cycle:
            return p.text(start + '...' + end)

        if len(obj) == 0:
            # Special case: an empty set is written as a constructor call.
            p.text(basetype.__name__ + '()')
            return

        step = len(start)
        p.begin_group(step, start)
        # Like dictionary keys, we will try to sort the items.
        members = list(obj)
        try:
            members.sort()
        except Exception:
            # Sometimes the items don't sort.
            pass
        for position, member in enumerate(members):
            if position:
                p.text(',')
                p.breakable()
            p.pretty(member)
        p.end_group(step, end)
    return inner
585 585
586 586
def _dict_pprinter_factory(start, end, basetype=None):
    """
    Factory that returns a pprint function used by the default pprint of
    dicts and dict proxies.

    `start` and `end` are the opening and closing delimiters.  When
    `basetype` is given, subclasses of it that define their own
    ``__repr__`` are printed with that repr instead.
    """
    def inner(obj, p, cycle):
        typ = type(obj)
        if basetype is not None and typ is not basetype and typ.__repr__ != basetype.__repr__:
            # If the subclass provides its own repr, use it instead.
            return p.text(typ.__repr__(obj))

        if cycle:
            return p.text('{...}')
        p.begin_group(1, start)
        keys = obj.keys()
        # sorted() works on Python 2 lists and Python 3 key views alike;
        # calling .sort() on a view raised AttributeError, which was
        # swallowed, so keys were never sorted on Python 3.
        try:
            keys = sorted(keys)
        except Exception:
            # Sometimes the keys don't sort; keep the mapping's own order.
            pass
        for idx, key in enumerate(keys):
            if idx:
                p.text(',')
                p.breakable()
            p.pretty(key)
            p.text(': ')
            p.pretty(obj[key])
        p.end_group(1, end)
    return inner
616 616
617 617
def _super_pprint(obj, p, cycle):
    """Pretty print a ``super`` object as ``<super: cls, instance>``."""
    opening = '<super: '
    width = len(opening)  # 8
    p.begin_group(width, opening)
    bound_class = obj.__self_class__
    p.pretty(bound_class)
    p.text(',')
    p.breakable()
    bound_instance = obj.__self__
    p.pretty(bound_instance)
    p.end_group(width, '>')
626 626
627 627
def _re_pattern_pprint(obj, p, cycle):
    """The pprint function for regular expression patterns.

    Writes ``re.compile(r'<pattern>', re.FLAG|re.FLAG...)``, listing only
    the flags that are set on `obj`.
    """
    p.text('re.compile(')
    pattern = repr(obj.pattern)
    if pattern[:1] in 'uU':
        # Python 2 unicode pattern: keep the u prefix in front of the r.
        pattern = pattern[1:]
        prefix = 'ur'
    else:
        prefix = 'r'
    # Shown as a raw string, so undo repr()'s doubling of backslashes.
    pattern = prefix + pattern.replace('\\\\', '\\')
    p.text(pattern)
    if obj.flags:
        p.text(',')
        p.breakable()
        done_one = False
        for flag in ('TEMPLATE', 'IGNORECASE', 'LOCALE', 'MULTILINE', 'DOTALL',
                     'UNICODE', 'VERBOSE', 'DEBUG'):
            # Not every flag exists in every Python version; a missing one
            # counts as 0 so it is simply never reported.
            if obj.flags & getattr(re, flag, 0):
                if done_one:
                    p.text('|')
                p.text('re.' + flag)
                done_one = True
    p.text(')')
651 651
652 652
def _type_pprint(obj, p, cycle):
    """Pretty print a class or type as its (module-qualified) name."""
    mod = getattr(obj, '__module__', None)
    if mod is None:
        # Heap allocated types might not have the module attribute,
        # and others may set it to None.
        return p.text(obj.__name__)

    # Builtins and builtin exceptions are shown without a module prefix.
    is_builtin = mod in ('__builtin__', 'builtins', 'exceptions')
    name = obj.__name__ if is_builtin else mod + '.' + obj.__name__
    p.text(name)
666 666
667 667
def _repr_pprint(obj, p, cycle):
    """Pretty print `obj` by writing its ordinary ``repr()`` as text."""
    representation = repr(obj)
    p.text(representation)
671 671
672 672
def _function_pprint(obj, p, cycle):
    """Base pprint for all functions and builtin functions."""
    mod = obj.__module__
    # Builtins, and functions without a module, are shown by bare name.
    if mod in ('__builtin__', 'builtins', 'exceptions') or not mod:
        qualified = obj.__name__
    else:
        qualified = mod + '.' + obj.__name__
    p.text('<function %s>' % qualified)
680 680
681 681
def _exception_pprint(obj, p, cycle):
    """Base pprint for all exceptions: ``Name(arg, arg, ...)``."""
    exc_class = obj.__class__
    if exc_class.__module__ in ('exceptions', 'builtins'):
        # Builtin exceptions are shown without a module prefix.
        name = exc_class.__name__
    else:
        name = '%s.%s' % (exc_class.__module__, exc_class.__name__)
    opening = name + '('
    step = len(opening)
    p.begin_group(step, opening)
    for position, arg in enumerate(getattr(obj, 'args', ())):
        if position:
            p.text(',')
            p.breakable()
        p.pretty(arg)
    p.end_group(step, ')')
699 699
700 700
701 701 #: the exception base
# Use BaseException as the exception base where the name exists, otherwise
# fall back to Exception.
try:
    BaseException
except NameError:
    _exception_base = Exception
else:
    _exception_base = BaseException
706 706
707 707
708 708 #: printers for builtin types
_type_pprinters = {
    int:                        _repr_pprint,
    float:                      _repr_pprint,
    str:                        _repr_pprint,
    tuple:                      _seq_pprinter_factory('(', ')', tuple),
    list:                       _seq_pprinter_factory('[', ']', list),
    dict:                       _dict_pprinter_factory('{', '}', dict),

    set:                        _set_pprinter_factory('{', '}', set),
    frozenset:                  _set_pprinter_factory('frozenset({', '})', frozenset),
    super:                      _super_pprint,
    _re_pattern_type:           _re_pattern_pprint,
    type:                       _type_pprint,
    types.FunctionType:         _function_pprint,
    types.BuiltinFunctionType:  _function_pprint,
    # The builtin is used because types.SliceType does not exist on
    # Python 3 (on Python 2 it is the same object as slice).
    slice:                      _repr_pprint,
    types.MethodType:           _repr_pprint,

    datetime.datetime:          _repr_pprint,
    datetime.timedelta:         _repr_pprint,
    _exception_base:            _exception_pprint
}

# Types that only exist on Python 2.  The first missing name aborts the
# block, and the Python 3 only bytes type is registered instead.
try:
    _type_pprinters[types.DictProxyType] = _dict_pprinter_factory('<dictproxy {', '}>')
    _type_pprinters[types.ClassType] = _type_pprint
    _type_pprinters[long] = _repr_pprint
    _type_pprinters[unicode] = _repr_pprint
except (AttributeError, NameError):  # Python 3
    _type_pprinters[bytes] = _repr_pprint

try:
    _type_pprinters[xrange] = _repr_pprint
except NameError:
    _type_pprinters[range] = _repr_pprint

#: printers for types specified by name
_deferred_type_pprinters = {
}
748 748
749 749 def for_type(typ, func):
750 750 """
751 751 Add a pretty printer for a given type.
752 752 """
753 753 oldfunc = _type_pprinters.get(typ, None)
754 754 if func is not None:
755 755 # To support easy restoration of old pprinters, we need to ignore Nones.
756 756 _type_pprinters[typ] = func
757 757 return oldfunc
758 758
759 759 def for_type_by_name(type_module, type_name, func):
760 760 """
761 761 Add a pretty printer for a type specified by the module and name of a type
762 762 rather than the type object itself.
763 763 """
764 764 key = (type_module, type_name)
765 765 oldfunc = _deferred_type_pprinters.get(key, None)
766 766 if func is not None:
767 767 # To support easy restoration of old pprinters, we need to ignore Nones.
768 768 _deferred_type_pprinters[key] = func
769 769 return oldfunc
770 770
771 771
772 772 #: printers for the default singletons
# Keyed by id() of each singleton; every one is printed with its plain repr.
_singleton_pprinters = dict(
    (id(singleton), _repr_pprint)
    for singleton in [None, True, False, Ellipsis, NotImplemented]
)
775 775
776 776
if __name__ == '__main__':
    # Small demo: pretty print an object holding assorted value types,
    # including a reference back to itself.
    from random import randrange

    class Foo(object):
        def __init__(self):
            self.foo = 1
            self.bar = re.compile(r'\s+')
            self.blub = dict.fromkeys(range(30), randrange(1, 40))
            self.hehe = 23424.234234
            # The list refers back to the instance, forming a cycle.
            self.list = ["blub", "blah", self]

        def get_foo(self):
            print("foo")

    demo = Foo()
    pprint(demo, verbose=True)
@@ -1,1859 +1,1854 b''
1 1 """A semi-synchronous Client for the ZMQ cluster
2 2
3 3 Authors:
4 4
5 5 * MinRK
6 6 """
7 7 from __future__ import print_function
8 8 #-----------------------------------------------------------------------------
9 9 # Copyright (C) 2010-2011 The IPython Development Team
10 10 #
11 11 # Distributed under the terms of the BSD License. The full license is in
12 12 # the file COPYING, distributed as part of this software.
13 13 #-----------------------------------------------------------------------------
14 14
15 15 #-----------------------------------------------------------------------------
16 16 # Imports
17 17 #-----------------------------------------------------------------------------
18 18
19 19 import os
20 20 import json
21 21 import sys
22 22 from threading import Thread, Event
23 23 import time
24 24 import warnings
25 25 from datetime import datetime
26 26 from getpass import getpass
27 27 from pprint import pprint
28 28
29 29 pjoin = os.path.join
30 30
31 31 import zmq
32 32 # from zmq.eventloop import ioloop, zmqstream
33 33
34 34 from IPython.config.configurable import MultipleInstanceError
35 35 from IPython.core.application import BaseIPythonApplication
36 36 from IPython.core.profiledir import ProfileDir, ProfileDirError
37 37
38 38 from IPython.utils.capture import RichOutput
39 39 from IPython.utils.coloransi import TermColors
40 40 from IPython.utils.jsonutil import rekey
41 41 from IPython.utils.localinterfaces import localhost, is_local_ip
42 42 from IPython.utils.path import get_ipython_dir
43 from IPython.utils.py3compat import cast_bytes, string_types
43 from IPython.utils.py3compat import cast_bytes, string_types, xrange
44 44 from IPython.utils.traitlets import (HasTraits, Integer, Instance, Unicode,
45 45 Dict, List, Bool, Set, Any)
46 46 from IPython.external.decorator import decorator
47 47 from IPython.external.ssh import tunnel
48 48
49 49 from IPython.parallel import Reference
50 50 from IPython.parallel import error
51 51 from IPython.parallel import util
52 52
53 53 from IPython.kernel.zmq.session import Session, Message
54 54 from IPython.kernel.zmq import serialize
55 55
56 56 from .asyncresult import AsyncResult, AsyncHubResult
57 57 from .view import DirectView, LoadBalancedView
58 58
59 if sys.version_info[0] >= 3:
60 # xrange is used in a couple 'isinstance' tests in py2
61 # should be just 'range' in 3k
62 xrange = range
63
64 59 #--------------------------------------------------------------------------
65 60 # Decorators for Client methods
66 61 #--------------------------------------------------------------------------
67 62
@decorator
def spin_first(f, self, *args, **kwargs):
    """Call self.spin() to sync state, then call the wrapped method."""
    self.spin()
    result = f(self, *args, **kwargs)
    return result
73 68
74 69
75 70 #--------------------------------------------------------------------------
76 71 # Classes
77 72 #--------------------------------------------------------------------------
78 73
79 74
class ExecuteReply(RichOutput):
    """wrapper for finished Execute results

    Unknown attributes and item lookups are served from `metadata`.
    """
    def __init__(self, msg_id, content, metadata):
        self.msg_id = msg_id
        self._content = content
        self.execution_count = content['execution_count']
        self.metadata = metadata

    # RichOutput overrides

    @property
    def source(self):
        """The 'source' of the pyout message; None when there is no pyout."""
        pyout = self.metadata['pyout']
        if pyout:
            return pyout.get('source', '')

    @property
    def data(self):
        """The pyout mime bundle; an empty dict when there is no pyout."""
        pyout = self.metadata['pyout']
        if pyout:
            return pyout.get('data', {})
        # Always hand back a mapping so ``mime in self.data`` cannot fail.
        return {}

    @property
    def _metadata(self):
        """The pyout display metadata; an empty dict when there is no pyout."""
        pyout = self.metadata['pyout']
        if pyout:
            return pyout.get('metadata', {})
        return {}

    def display(self):
        """Publish the result via IPython's display machinery."""
        from IPython.display import publish_display_data
        # NOTE(review): this passes the task-level `metadata`, not the
        # pyout `_metadata` -- confirm which one publish_display_data
        # expects.
        publish_display_data(self.source, self.data, self.metadata)

    def _repr_mime_(self, mime):
        """Return the data for `mime`, with its metadata if there is any."""
        if mime not in self.data:
            return
        data = self.data[mime]
        if mime in self._metadata:
            return data, self._metadata[mime]
        else:
            return data

    def __getitem__(self, key):
        return self.metadata[key]

    def __getattr__(self, key):
        # Only called when normal lookup fails.  If `metadata` itself is
        # missing (instance not initialised yet), looking it up below
        # would re-enter this method forever.
        if key == 'metadata':
            raise AttributeError(key)
        if key not in self.metadata:
            raise AttributeError(key)
        return self.metadata[key]

    def __repr__(self):
        pyout = self.metadata['pyout'] or {'data':{}}
        text_out = pyout['data'].get('text/plain', '')
        # Keep the repr short: at most 32 characters of output.
        if len(text_out) > 32:
            text_out = text_out[:29] + '...'

        return "<ExecuteReply[%i]: %s>" % (self.execution_count, text_out)

    def _repr_pretty_(self, p, cycle):
        pyout = self.metadata['pyout'] or {'data':{}}
        text_out = pyout['data'].get('text/plain', '')

        if not text_out:
            return

        # get_ipython is only defined when running inside IPython.
        try:
            ip = get_ipython()
        except NameError:
            colors = "NoColor"
        else:
            colors = ip.colors

        if colors == "NoColor":
            out = normal = ""
        else:
            out = TermColors.Red
            normal = TermColors.Normal

        if '\n' in text_out and not text_out.startswith('\n'):
            # add newline for multiline reprs
            text_out = '\n' + text_out

        p.text(
            out + u'Out[%i:%i]: ' % (
                self.metadata['engine_id'], self.execution_count
            ) + normal + text_out
        )
166 161
167 162
class Metadata(dict):
    """Subclass of dict for initializing metadata values.

    Attribute access works on keys.

    These objects have a strict set of keys - errors will raise if you try
    to add new keys.
    """
    def __init__(self, *args, **kwargs):
        dict.__init__(self)
        # Built afresh for every instance, so the mutable defaults
        # ('outputs', 'data') are never shared between instances.
        md = {'msg_id' : None,
              'submitted' : None,
              'started' : None,
              'completed' : None,
              'received' : None,
              'engine_uuid' : None,
              'engine_id' : None,
              'follow' : None,
              'after' : None,
              'status' : None,

              'pyin' : None,
              'pyout' : None,
              'pyerr' : None,
              'stdout' : '',
              'stderr' : '',
              'outputs' : [],
              'data': {},
              'outputs_ready' : False,
            }
        self.update(md)
        self.update(dict(*args, **kwargs))

    # The methods below test membership with ``key in self`` rather than
    # ``self.iterkeys()``: dict has no iterkeys on Python 3, and looking it
    # up would re-enter __getattr__ without end.

    def __getattr__(self, key):
        """getattr aliased to getitem"""
        if key in self:
            return self[key]
        raise AttributeError(key)

    def __setattr__(self, key, value):
        """setattr aliased to setitem, with strict"""
        if key in self:
            self[key] = value
        else:
            raise AttributeError(key)

    def __setitem__(self, key, value):
        """strict static key enforcement"""
        if key in self:
            dict.__setitem__(self, key, value)
        else:
            raise KeyError(key)
221 216
222 217
223 218 class Client(HasTraits):
224 219 """A semi-synchronous client to the IPython ZMQ cluster
225 220
226 221 Parameters
227 222 ----------
228 223
229 224 url_file : str/unicode; path to ipcontroller-client.json
230 225 This JSON file should contain all the information needed to connect to a cluster,
231 226 and is likely the only argument needed.
232 227 Connection information for the Hub's registration. If a json connector
233 228 file is given, then likely no further configuration is necessary.
234 229 [Default: use profile]
235 230 profile : bytes
236 231 The name of the Cluster profile to be used to find connector information.
237 232 If run from an IPython application, the default profile will be the same
238 233 as the running application, otherwise it will be 'default'.
239 234 cluster_id : str
240 235 String id to be added to runtime files, to prevent name collisions when using
241 236 multiple clusters with a single profile simultaneously.
242 237 When set, will look for files named like: 'ipcontroller-<cluster_id>-client.json'
243 238 Since this is text inserted into filenames, typical recommendations apply:
244 239 Simple character strings are ideal, and spaces are not recommended (but
245 240 should generally work)
246 241 context : zmq.Context
247 242 Pass an existing zmq.Context instance, otherwise the client will create its own.
248 243 debug : bool
249 244 flag for lots of message printing for debug purposes
250 245 timeout : int/float
251 246 time (in seconds) to wait for connection replies from the Hub
252 247 [Default: 10]
253 248
254 249 #-------------- session related args ----------------
255 250
256 251 config : Config object
257 252 If specified, this will be relayed to the Session for configuration
258 253 username : str
259 254 set username for the session object
260 255
261 256 #-------------- ssh related args ----------------
262 257 # These are args for configuring the ssh tunnel to be used
263 258 # credentials are used to forward connections over ssh to the Controller
264 259 # Note that the ip given in `addr` needs to be relative to sshserver
265 260 # The most basic case is to leave addr as pointing to localhost (127.0.0.1),
266 261 # and set sshserver as the same machine the Controller is on. However,
267 262 # the only requirement is that sshserver is able to see the Controller
268 263 # (i.e. is within the same trusted network).
269 264
270 265 sshserver : str
271 266 A string of the form passed to ssh, i.e. 'server.tld' or 'user@server.tld:port'
272 267 If keyfile or password is specified, and this is not, it will default to
273 268 the ip given in addr.
274 269 sshkey : str; path to ssh private key file
275 270 This specifies a key to be used in ssh login, default None.
276 271 Regular default ssh keys will be used without specifying this argument.
277 272 password : str
278 273 Your ssh password to sshserver. Note that if this is left None,
279 274 you will be prompted for it if passwordless key based login is unavailable.
280 275 paramiko : bool
281 276 flag for whether to use paramiko instead of shell ssh for tunneling.
282 277 [default: True on win32, False else]
283 278
284 279
285 280 Attributes
286 281 ----------
287 282
288 283 ids : list of int engine IDs
289 284 requesting the ids attribute always synchronizes
290 285 the registration state. To request ids without synchronization,
291 286 use semi-private _ids attributes.
292 287
293 288 history : list of msg_ids
294 289 a list of msg_ids, keeping track of all the execution
295 290 messages you have submitted in order.
296 291
297 292 outstanding : set of msg_ids
298 293 a set of msg_ids that have been submitted, but whose
299 294 results have not yet been received.
300 295
301 296 results : dict
302 297 a dict of all our results, keyed by msg_id
303 298
304 299 block : bool
305 300 determines default behavior when block not specified
306 301 in execution methods
307 302
308 303 Methods
309 304 -------
310 305
311 306 spin
312 307 flushes incoming results and registration state changes
313 308 control methods spin, and requesting `ids` also ensures up to date
314 309
315 310 wait
316 311 wait on one or more msg_ids
317 312
318 313 execution methods
319 314 apply
320 315 legacy: execute, run
321 316
322 317 data movement
323 318 push, pull, scatter, gather
324 319
325 320 query methods
326 321 queue_status, get_result, purge, result_status
327 322
328 323 control methods
329 324 abort, shutdown
330 325
331 326 """
332 327
333 328
334 329 block = Bool(False)
335 330 outstanding = Set()
336 331 results = Instance('collections.defaultdict', (dict,))
337 332 metadata = Instance('collections.defaultdict', (Metadata,))
338 333 history = List()
339 334 debug = Bool(False)
340 335 _spin_thread = Any()
341 336 _stop_spinning = Any()
342 337
343 338 profile=Unicode()
def _profile_default(self):
    """Default for the ``profile`` trait.

    Uses the profile of the running IPython application when one is
    initialized and reachable, otherwise ``u'default'``.
    """
    if not BaseIPythonApplication.initialized():
        return u'default'
    # an IPython app *might* be running, try to get its profile
    try:
        return BaseIPythonApplication.instance().profile
    except (AttributeError, MultipleInstanceError):
        # could be a *different* subclass of config.Application,
        # which would raise one of these two errors.
        return u'default'
355 350
356 351
357 352 _outstanding_dict = Instance('collections.defaultdict', (set,))
358 353 _ids = List()
359 354 _connected=Bool(False)
360 355 _ssh=Bool(False)
361 356 _context = Instance('zmq.Context')
362 357 _config = Dict()
363 358 _engines=Instance(util.ReverseDict, (), {})
364 359 # _hub_socket=Instance('zmq.Socket')
365 360 _query_socket=Instance('zmq.Socket')
366 361 _control_socket=Instance('zmq.Socket')
367 362 _iopub_socket=Instance('zmq.Socket')
368 363 _notification_socket=Instance('zmq.Socket')
369 364 _mux_socket=Instance('zmq.Socket')
370 365 _task_socket=Instance('zmq.Socket')
371 366 _task_scheme=Unicode()
372 367 _closed = False
373 368 _ignored_control_replies=Integer(0)
374 369 _ignored_hub_replies=Integer(0)
375 370
def __new__(cls, *args, **kw):
    """Create the instance, forwarding only keyword arguments to HasTraits.

    Positional arguments are accepted and ignored here so that passing them
    to the constructor does not raise at allocation time.
    """
    # don't raise on positional args
    instance = HasTraits.__new__(cls, **kw)
    return instance
379 374
def __init__(self, url_file=None, profile=None, profile_dir=None, ipython_dir=None,
        context=None, debug=False,
        sshserver=None, sshkey=None, password=None, paramiko=None,
        timeout=10, cluster_id=None, **extra_args
        ):
    """Load the connection file, build the Session and connect to the Hub.

    Raises
    ------
    ValueError
        If a bare url is given instead of a connection file, if no connection
        file can be located, or if the connection file lacks a required key.
    """
    if profile:
        super(Client, self).__init__(debug=debug, profile=profile)
    else:
        super(Client, self).__init__(debug=debug)
    if context is None:
        context = zmq.Context.instance()
    self._context = context
    self._stop_spinning = Event()

    if 'url_or_file' in extra_args:
        url_file = extra_args['url_or_file']
        warnings.warn("url_or_file arg no longer supported, use url_file", DeprecationWarning)

    if url_file and util.is_url(url_file):
        raise ValueError("single urls cannot be specified, url-files must be used.")

    self._setup_profile_dir(self.profile, profile_dir, ipython_dir)

    if self._cd is not None:
        if url_file is None:
            if not cluster_id:
                client_json = 'ipcontroller-client.json'
            else:
                client_json = 'ipcontroller-%s-client.json' % cluster_id
            url_file = pjoin(self._cd.security_dir, client_json)
    if url_file is None:
        raise ValueError(
            "I can't find enough information to connect to a hub!"
            " Please specify at least one of url_file or profile."
        )

    with open(url_file) as f:
        cfg = json.load(f)

    self._task_scheme = cfg['task_scheme']

    # sync defaults from args, json:
    if sshserver:
        cfg['ssh'] = sshserver

    location = cfg.setdefault('location', None)

    proto,addr = cfg['interface'].split('://')
    addr = util.disambiguate_ip_address(addr, location)
    cfg['interface'] = "%s://%s" % (proto, addr)

    # turn interface,port into full urls:
    for key in ('control', 'task', 'mux', 'iopub', 'notification', 'registration'):
        cfg[key] = cfg['interface'] + ':%i' % cfg[key]

    if location is not None and addr == localhost():
        # location specified, and connection is expected to be local
        if not is_local_ip(location) and not sshserver:
            # load ssh from JSON *only* if the controller is not on
            # this machine
            sshserver=cfg['ssh']
        if not is_local_ip(location) and not sshserver:
            # warn if no ssh specified, but SSH is probably needed
            # This is only a warning, because the most likely cause
            # is a local Controller on a laptop whose IP is dynamic
            warnings.warn("""
            Controller appears to be listening on localhost, but not on this machine.
            If this is true, you should specify Client(...,sshserver='you@%s')
            or instruct your controller to listen on an external IP."""%location,
            RuntimeWarning)
    elif not sshserver:
        # otherwise sync with cfg
        sshserver = cfg['ssh']

    self._config = cfg

    self._ssh = bool(sshserver or sshkey or password)
    if self._ssh and sshserver is None:
        # default to ssh via localhost
        sshserver = addr
    if self._ssh and password is None:
        if tunnel.try_passwordless_ssh(sshserver, sshkey, paramiko):
            password=False
        else:
            password = getpass("SSH Password for %s: "%sshserver)
    ssh_kwargs = dict(keyfile=sshkey, password=password, paramiko=paramiko)

    # configure and construct the session
    try:
        extra_args['packer'] = cfg['pack']
        extra_args['unpacker'] = cfg['unpack']
        extra_args['key'] = cast_bytes(cfg['key'])
        extra_args['signature_scheme'] = cfg['signature_scheme']
    except KeyError as exc:
        msg = '\n'.join([
            "Connection file is invalid (missing '{}'), possibly from an old version of IPython.",
            "If you are reusing connection files, remove them and start ipcontroller again."
        ])
        # exc.args[0] is the missing key; BaseException.message does not
        # exist on Python 3 and would mask this error with AttributeError
        raise ValueError(msg.format(exc.args[0]))

    self.session = Session(**extra_args)

    self._query_socket = self._context.socket(zmq.DEALER)

    if self._ssh:
        tunnel.tunnel_connection(self._query_socket, cfg['registration'], sshserver, **ssh_kwargs)
    else:
        self._query_socket.connect(cfg['registration'])

    self.session.debug = self.debug

    self._notification_handlers = {'registration_notification' : self._register_engine,
                                'unregistration_notification' : self._unregister_engine,
                                'shutdown_notification' : lambda msg: self.close(),
                                }
    self._queue_handlers = {'execute_reply' : self._handle_execute_reply,
                            'apply_reply' : self._handle_apply_reply}

    try:
        self._connect(sshserver, ssh_kwargs, timeout)
    except:
        # release every socket opened so far, then propagate the failure
        self.close(linger=0)
        raise

    # last step: setup magics, if we are in IPython:

    try:
        ip = get_ipython()
    except NameError:
        return
    else:
        if 'px' not in ip.magics_manager.magics:
            # in IPython but we are the first Client.
            # activate a default view for parallel magics.
            self.activate()
517 512
518 513 def __del__(self):
519 514 """cleanup sockets, but _not_ context."""
520 515 self.close()
521 516
def _setup_profile_dir(self, profile, profile_dir, ipython_dir):
    """Resolve the profile directory and store it on ``self._cd``.

    An explicit ``profile_dir`` takes precedence over a ``profile`` name.
    ``self._cd`` ends up as ``None`` when neither is given or the lookup
    raises ProfileDirError.
    """
    if ipython_dir is None:
        ipython_dir = get_ipython_dir()

    located = None
    try:
        if profile_dir is not None:
            located = ProfileDir.find_profile_dir(profile_dir)
        elif profile is not None:
            located = ProfileDir.find_profile_dir_by_name(ipython_dir, profile)
    except ProfileDirError:
        # lookup failed: fall back to "no profile directory"
        located = None
    self._cd = located
539 534
540 535 def _update_engines(self, engines):
541 536 """Update our engines dict and _ids from a dict of the form: {id:uuid}."""
542 537 for k,v in engines.iteritems():
543 538 eid = int(k)
544 539 if eid not in self._engines:
545 540 self._ids.append(eid)
546 541 self._engines[eid] = v
547 542 self._ids = sorted(self._ids)
548 543 if sorted(self._engines.keys()) != range(len(self._engines)) and \
549 544 self._task_scheme == 'pure' and self._task_socket:
550 545 self._stop_scheduling_tasks()
551 546
552 547 def _stop_scheduling_tasks(self):
553 548 """Stop scheduling tasks because an engine has been unregistered
554 549 from a pure ZMQ scheduler.
555 550 """
556 551 self._task_socket.close()
557 552 self._task_socket = None
558 553 msg = "An engine has been unregistered, and we are using pure " +\
559 554 "ZMQ task scheduling. Task farming will be disabled."
560 555 if self.outstanding:
561 556 msg += " If you were running tasks when this happened, " +\
562 557 "some `outstanding` msg_ids may never resolve."
563 558 warnings.warn(msg, RuntimeWarning)
564 559
def _build_targets(self, targets):
    """Turn valid target IDs or 'all' into two lists:
    (engine uuids as bytes, int engine ids).

    Accepts None or 'all' (every known engine), a single int (negative
    values index from the end of the id list), a slice, or a
    tuple/list/xrange of ints.
    """
    if not self._ids:
        # flush notification socket if no engines yet, just in case
        if not self.ids:
            raise error.NoEnginesRegistered("Can't build targets without any engines")

    if targets is None:
        targets = self._ids
    elif isinstance(targets, string_types):
        if targets.lower() != 'all':
            raise TypeError("%r not valid str target, must be 'all'"%(targets))
        targets = self._ids
    elif isinstance(targets, int):
        if targets < 0:
            # negative ints index into the (freshly synced) id list
            targets = self.ids[targets]
        if targets not in self._ids:
            raise IndexError("No such engine: %i"%targets)
        targets = [targets]

    if isinstance(targets, slice):
        positions = range(len(self._ids))[targets]
        current_ids = self.ids
        targets = [current_ids[pos] for pos in positions]

    if not isinstance(targets, (tuple, list, xrange)):
        raise TypeError("targets by int/slice/collection of ints only, not %s"%(type(targets)))

    uuids = [cast_bytes(self._engines[t]) for t in targets]
    return uuids, list(targets)
597 592
def _connect(self, sshserver, ssh_kwargs, timeout):
    """Set up all our socket connections to the cluster.  Called from __init__.

    Sends a connection request on the query socket, waits up to ``timeout``
    seconds for the reply, then opens the mux, task, notification, control
    and iopub sockets.

    Raises error.TimeoutError when no reply arrives in time, and Exception
    when the reply's status is not 'ok'.
    """
    # Maybe allow reconnecting?
    if self._connected:
        return
    self._connected = True

    def connect_socket(s, url):
        # go through an ssh tunnel when ssh is configured, else connect directly
        if self._ssh:
            return tunnel.tunnel_connection(s, url, sshserver, **ssh_kwargs)
        return s.connect(url)

    self.session.send(self._query_socket, 'connection_request')
    # use Poller because zmq.select has wrong units in pyzmq 2.1.7
    poller = zmq.Poller()
    poller.register(self._query_socket, zmq.POLLIN)
    # poll expects milliseconds, timeout is seconds
    if not poller.poll(timeout*1000):
        raise error.TimeoutError("Hub connection request timed out")

    idents, msg = self.session.recv(self._query_socket, mode=0)
    if self.debug:
        pprint(msg)
    content = msg['content']
    cfg = self._config

    if content['status'] != 'ok':
        self._connected = False
        raise Exception("Failed to connect!")

    # (attribute, zmq socket type, key of the url in cfg), in connection order
    channels = (
        ('_mux_socket', zmq.DEALER, 'mux'),
        ('_task_socket', zmq.DEALER, 'task'),
        ('_notification_socket', zmq.SUB, 'notification'),
        ('_control_socket', zmq.DEALER, 'control'),
        ('_iopub_socket', zmq.SUB, 'iopub'),
    )
    for attr, sock_type, cfg_key in channels:
        setattr(self, attr, self._context.socket(sock_type))
        sock = getattr(self, attr)
        if sock_type == zmq.SUB:
            # empty prefix: subscribe to every message
            sock.setsockopt(zmq.SUBSCRIBE, b'')
        connect_socket(sock, cfg[cfg_key])

    self._update_engines(dict(content['engines']))
649 644
650 645 #--------------------------------------------------------------------------
651 646 # handlers and callbacks for incoming messages
652 647 #--------------------------------------------------------------------------
653 648
def _unwrap_exception(self, content):
    """Unwrap a remote exception and add the int engine_id to its engine_info."""
    exc = error.unwrap_exception(content)
    info = exc.engine_info
    if info:
        # translate the engine's uuid back to its integer id
        info['engine_id'] = self._engines[info['engine_uuid']]
    return exc
663 658
664 659 def _extract_metadata(self, msg):
665 660 header = msg['header']
666 661 parent = msg['parent_header']
667 662 msg_meta = msg['metadata']
668 663 content = msg['content']
669 664 md = {'msg_id' : parent['msg_id'],
670 665 'received' : datetime.now(),
671 666 'engine_uuid' : msg_meta.get('engine', None),
672 667 'follow' : msg_meta.get('follow', []),
673 668 'after' : msg_meta.get('after', []),
674 669 'status' : content['status'],
675 670 }
676 671
677 672 if md['engine_uuid'] is not None:
678 673 md['engine_id'] = self._engines.get(md['engine_uuid'], None)
679 674
680 675 if 'date' in parent:
681 676 md['submitted'] = parent['date']
682 677 if 'started' in msg_meta:
683 678 md['started'] = msg_meta['started']
684 679 if 'date' in header:
685 680 md['completed'] = header['date']
686 681 return md
687 682
688 683 def _register_engine(self, msg):
689 684 """Register a new engine, and update our connection info."""
690 685 content = msg['content']
691 686 eid = content['id']
692 687 d = {eid : content['uuid']}
693 688 self._update_engines(d)
694 689
695 690 def _unregister_engine(self, msg):
696 691 """Unregister an engine that has died."""
697 692 content = msg['content']
698 693 eid = int(content['id'])
699 694 if eid in self._ids:
700 695 self._ids.remove(eid)
701 696 uuid = self._engines.pop(eid)
702 697
703 698 self._handle_stranded_msgs(eid, uuid)
704 699
705 700 if self._task_socket and self._task_scheme == 'pure':
706 701 self._stop_scheduling_tasks()
707 702
708 703 def _handle_stranded_msgs(self, eid, uuid):
709 704 """Handle messages known to be on an engine when the engine unregisters.
710 705
711 706 It is possible that this will fire prematurely - that is, an engine will
712 707 go down after completing a result, and the client will be notified
713 708 of the unregistration and later receive the successful result.
714 709 """
715 710
716 711 outstanding = self._outstanding_dict[uuid]
717 712
718 713 for msg_id in list(outstanding):
719 714 if msg_id in self.results:
720 715 # we already
721 716 continue
722 717 try:
723 718 raise error.EngineError("Engine %r died while running task %r"%(eid, msg_id))
724 719 except:
725 720 content = error.wrap_exception()
726 721 # build a fake message:
727 722 msg = self.session.msg('apply_reply', content=content)
728 723 msg['parent_header']['msg_id'] = msg_id
729 724 msg['metadata']['engine'] = uuid
730 725 self._handle_apply_reply(msg)
731 726
732 727 def _handle_execute_reply(self, msg):
733 728 """Save the reply to an execute_request into our results.
734 729
735 730 execute messages are never actually used. apply is used instead.
736 731 """
737 732
738 733 parent = msg['parent_header']
739 734 msg_id = parent['msg_id']
740 735 if msg_id not in self.outstanding:
741 736 if msg_id in self.history:
742 737 print(("got stale result: %s"%msg_id))
743 738 else:
744 739 print(("got unknown result: %s"%msg_id))
745 740 else:
746 741 self.outstanding.remove(msg_id)
747 742
748 743 content = msg['content']
749 744 header = msg['header']
750 745
751 746 # construct metadata:
752 747 md = self.metadata[msg_id]
753 748 md.update(self._extract_metadata(msg))
754 749 # is this redundant?
755 750 self.metadata[msg_id] = md
756 751
757 752 e_outstanding = self._outstanding_dict[md['engine_uuid']]
758 753 if msg_id in e_outstanding:
759 754 e_outstanding.remove(msg_id)
760 755
761 756 # construct result:
762 757 if content['status'] == 'ok':
763 758 self.results[msg_id] = ExecuteReply(msg_id, content, md)
764 759 elif content['status'] == 'aborted':
765 760 self.results[msg_id] = error.TaskAborted(msg_id)
766 761 elif content['status'] == 'resubmitted':
767 762 # TODO: handle resubmission
768 763 pass
769 764 else:
770 765 self.results[msg_id] = self._unwrap_exception(content)
771 766
772 767 def _handle_apply_reply(self, msg):
773 768 """Save the reply to an apply_request into our results."""
774 769 parent = msg['parent_header']
775 770 msg_id = parent['msg_id']
776 771 if msg_id not in self.outstanding:
777 772 if msg_id in self.history:
778 773 print(("got stale result: %s"%msg_id))
779 774 print(self.results[msg_id])
780 775 print(msg)
781 776 else:
782 777 print(("got unknown result: %s"%msg_id))
783 778 else:
784 779 self.outstanding.remove(msg_id)
785 780 content = msg['content']
786 781 header = msg['header']
787 782
788 783 # construct metadata:
789 784 md = self.metadata[msg_id]
790 785 md.update(self._extract_metadata(msg))
791 786 # is this redundant?
792 787 self.metadata[msg_id] = md
793 788
794 789 e_outstanding = self._outstanding_dict[md['engine_uuid']]
795 790 if msg_id in e_outstanding:
796 791 e_outstanding.remove(msg_id)
797 792
798 793 # construct result:
799 794 if content['status'] == 'ok':
800 795 self.results[msg_id] = serialize.unserialize_object(msg['buffers'])[0]
801 796 elif content['status'] == 'aborted':
802 797 self.results[msg_id] = error.TaskAborted(msg_id)
803 798 elif content['status'] == 'resubmitted':
804 799 # TODO: handle resubmission
805 800 pass
806 801 else:
807 802 self.results[msg_id] = self._unwrap_exception(content)
808 803
def _flush_notifications(self):
    """Flush notifications of engine registrations waiting
    in ZMQ queue.

    Raises Exception for a message type without a registered handler.
    """
    while True:
        idents, msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
        if msg is None:
            # queue drained
            return
        if self.debug:
            pprint(msg)
        msg_type = msg['header']['msg_type']
        handler = self._notification_handlers.get(msg_type, None)
        if handler is None:
            raise Exception("Unhandled message type: %s" % msg_type)
        handler(msg)
823 818
def _flush_results(self, sock):
    """Flush task or queue results waiting in ZMQ queue.

    Raises Exception for a message type without a registered handler.
    """
    while True:
        idents, msg = self.session.recv(sock, mode=zmq.NOBLOCK)
        if msg is None:
            # queue drained
            return
        if self.debug:
            pprint(msg)
        msg_type = msg['header']['msg_type']
        handler = self._queue_handlers.get(msg_type, None)
        if handler is None:
            raise Exception("Unhandled message type: %s" % msg_type)
        handler(msg)
837 832
def _flush_control(self, sock):
    """Flush replies from the control channel waiting
    in the ZMQ queue.

    Currently: ignore them.  Each drained reply decrements the count of
    ignored control replies still expected.
    """
    if self._ignored_control_replies <= 0:
        return
    while True:
        idents, msg = self.session.recv(sock, mode=zmq.NOBLOCK)
        if msg is None:
            break
        self._ignored_control_replies -= 1
        if self.debug:
            pprint(msg)
851 846
852 847 def _flush_ignored_control(self):
853 848 """flush ignored control replies"""
854 849 while self._ignored_control_replies > 0:
855 850 self.session.recv(self._control_socket)
856 851 self._ignored_control_replies -= 1
857 852
def _flush_ignored_hub_replies(self):
    """Drain and discard every reply currently queued on the query socket."""
    while True:
        ident, msg = self.session.recv(self._query_socket, mode=zmq.NOBLOCK)
        if msg is None:
            break
862 857
def _flush_iopub(self, sock):
    """Flush replies from the iopub channel waiting
    in the ZMQ queue.

    Each message is folded into the metadata record of its parent msg_id.
    """
    while True:
        idents, msg = self.session.recv(sock, mode=zmq.NOBLOCK)
        if msg is None:
            break
        if self.debug:
            pprint(msg)
        parent = msg['parent_header']
        # ignore IOPub messages with no parent.
        # Caused by print statements or warnings from before the first execution.
        if not parent:
            continue
        msg_id = parent['msg_id']
        content = msg['content']
        msg_type = msg['header']['msg_type']

        # init metadata:
        md = self.metadata[msg_id]

        if msg_type == 'stream':
            # append to the accumulated text of the named stream
            name = content['name']
            md[name] = (md[name] or '') + content['data']
        elif msg_type == 'pyerr':
            md.update(pyerr=self._unwrap_exception(content))
        elif msg_type == 'pyin':
            md.update(pyin=content['code'])
        elif msg_type == 'display_data':
            md['outputs'].append(content)
        elif msg_type == 'pyout':
            md['pyout'] = content
        elif msg_type == 'data_message':
            data, remainder = serialize.unserialize_object(msg['buffers'])
            md['data'].update(data)
        elif msg_type == 'status':
            # idle message comes after all outputs
            if content['execution_state'] == 'idle':
                md['outputs_ready'] = True
        # any other msg_type is left unhandled

        # redundant?
        self.metadata[msg_id] = md
912 907
913 908 #--------------------------------------------------------------------------
914 909 # len, getitem
915 910 #--------------------------------------------------------------------------
916 911
917 912 def __len__(self):
918 913 """len(client) returns # of engines."""
919 914 return len(self.ids)
920 915
def __getitem__(self, key):
    """index access returns DirectView multiplexer objects

    Must be int, slice, or list/tuple/xrange of ints"""
    if isinstance(key, (int, slice, tuple, list, xrange)):
        return self.direct_view(key)
    raise TypeError("key by int/slice/iterable of ints only, not %s"%(type(key)))
929 924
930 925 #--------------------------------------------------------------------------
931 926 # Begin public methods
932 927 #--------------------------------------------------------------------------
933 928
@property
def ids(self):
    """Always up-to-date ids property.

    Flushes pending registration notifications, then returns a copy of the
    engine id list.
    """
    self._flush_notifications()
    # always copy:
    return self._ids[:] if isinstance(self._ids, list) else list(self._ids)
940 935
def activate(self, targets='all', suffix=''):
    """Create a DirectView and register it with IPython magics

    Defines the magics `%px, %autopx, %pxresult, %%px`

    Parameters
    ----------

    targets: int, list of ints, or 'all'
        The engines on which the view's magics will run
    suffix: str [default: '']
        The suffix, if any, for the magics.  This allows you to have
        multiple views associated with parallel magics at the same time.

        e.g. ``rc.activate(targets=0, suffix='0')`` will give you
        the magics ``%px0``, ``%pxresult0``, etc. for running magics just
        on engine 0.

    Returns
    -------

    The blocking DirectView that was created and activated.
    """
    magics_view = self.direct_view(targets)
    # the view is switched to blocking mode before registration
    magics_view.block = True
    magics_view.activate(suffix)
    return magics_view
963 958
def close(self, linger=None):
    """Close my zmq Sockets

    If `linger`, set the zmq LINGER socket option,
    which allows discarding of messages.

    Stops the background spin thread first.  Calling close() again after
    the first call is a no-op.
    """
    if self._closed:
        return
    self.stop_spin_thread()

    # only pass linger through when the caller actually supplied it
    close_kwargs = {} if linger is None else {'linger': linger}
    socket_traits = [name for name in self.trait_names() if name.endswith("socket")]
    for name in socket_traits:
        sock = getattr(self, name)
        if sock is None or sock.closed:
            continue
        sock.close(**close_kwargs)
    self._closed = True
982 977
983 978 def _spin_every(self, interval=1):
984 979 """target func for use in spin_thread"""
985 980 while True:
986 981 if self._stop_spinning.is_set():
987 982 return
988 983 time.sleep(interval)
989 984 self.spin()
990 985
def spin_thread(self, interval=1):
    """call Client.spin() in a background thread on some regular interval

    This helps ensure that messages don't pile up too much in the zmq queue
    while you are working on other things, or just leaving an idle terminal.

    It also helps limit potential padding of the `received` timestamp
    on AsyncResult objects, used for timings.

    Any spin thread that is already running is stopped first.

    Parameters
    ----------

    interval : float, optional
        The interval on which to spin the client in the background thread
        (simply passed to time.sleep).

    Notes
    -----

    For precision timing, you may want to use this method to put a bound
    on the jitter (in seconds) in `received` timestamps used
    in AsyncResult.wall_time.

    """
    if self._spin_thread is not None:
        self.stop_spin_thread()
    self._stop_spinning.clear()

    worker = Thread(target=self._spin_every, args=(interval,))
    # daemon thread: it must not keep the interpreter alive at exit
    worker.daemon = True
    self._spin_thread = worker
    worker.start()
1021 1016
def stop_spin_thread(self):
    """stop background spin_thread, if any

    Safe to call from the spin thread itself; in that case the thread is
    only signalled to stop, not joined.
    """
    from threading import current_thread

    thread = self._spin_thread
    if thread is None:
        return
    self._stop_spinning.set()
    # close() calls this method, and close() can run on the spin thread
    # (a shutdown notification handled during spin()).  A thread cannot
    # join itself -- Thread.join raises RuntimeError -- so skip the join.
    if thread is not current_thread():
        thread.join()
    self._spin_thread = None
1028 1023
def spin(self):
    """Flush any registration notifications and execution results
    waiting in the ZMQ queue.

    Channels whose socket is unset are skipped.
    """
    if self._notification_socket:
        self._flush_notifications()

    # Sockets are looked up one at a time, right before their flush:
    # handling an earlier channel may drop a later socket.
    per_socket_flushes = (
        ('_iopub_socket', self._flush_iopub),
        ('_mux_socket', self._flush_results),
        ('_task_socket', self._flush_results),
        ('_control_socket', self._flush_control),
    )
    for attr, flush in per_socket_flushes:
        if getattr(self, attr):
            flush(getattr(self, attr))

    if self._query_socket:
        self._flush_ignored_hub_replies()
1045 1040
def wait(self, jobs=None, timeout=-1):
    """waits on one or more `jobs`, for up to `timeout` seconds.

    Parameters
    ----------

    jobs : int, str, or list of ints and/or strs, or one or more AsyncResult objects
            ints are indices to self.history
            strs are msg_ids
            default: wait on all outstanding messages
    timeout : float
            a time in seconds, after which to give up.
            default is -1, which means no timeout

    Returns
    -------

    True : when all msg_ids are done
    False : timeout reached, some msg_ids still outstanding
    """
    tic = time.time()
    if jobs is None:
        theids = self.outstanding
    else:
        if isinstance(jobs, string_types + (int, AsyncResult)):
            jobs = [jobs]
        theids = set()
        for job in jobs:
            if isinstance(job, int):
                # index access
                job = self.history[job]
            elif isinstance(job, AsyncResult):
                # set.update consumes the ids eagerly; map() is lazy on
                # Python 3, so using it for its side effect added nothing
                theids.update(job.msg_ids)
                continue
            theids.add(job)
    if not theids.intersection(self.outstanding):
        return True
    self.spin()
    # poll: flush incoming results roughly every millisecond until done
    while theids.intersection(self.outstanding):
        if timeout >= 0 and ( time.time()-tic ) > timeout:
            break
        time.sleep(1e-3)
        self.spin()
    return len(theids.intersection(self.outstanding)) == 0
1090 1085
1091 1086 #--------------------------------------------------------------------------
1092 1087 # Control methods
1093 1088 #--------------------------------------------------------------------------
1094 1089
@spin_first
def clear(self, targets=None, block=None):
    """Send a 'clear_request' to each target engine.

    Parameters
    ----------

    targets : passed to `self._build_targets` to select the engines
    block : bool [default: self.block]
        If true, read one reply per target from the control socket and
        raise the unwrapped exception of the last non-'ok' reply.
        Otherwise only count the replies that were left unread.
    """
    if block is None:
        block = self.block
    engine_idents = self._build_targets(targets)[0]
    for ident in engine_idents:
        self.session.send(self._control_socket, 'clear_request', content={}, ident=ident)

    if not block:
        # remember how many control replies we did not read here
        self._ignored_control_replies += len(engine_idents)
        return

    failure = False
    self._flush_ignored_control()
    for _ in range(len(engine_idents)):
        idents, msg = self.session.recv(self._control_socket, 0)
        if self.debug:
            pprint(msg)
        reply = msg['content']
        if reply['status'] != 'ok':
            failure = self._unwrap_exception(reply)
    if failure:
        raise failure
1115 1110
1116 1111
@spin_first
def abort(self, jobs=None, targets=None, block=None):
    """Abort specific jobs from the execution queues of target(s).

    This is a mechanism to prevent jobs that have already been submitted
    from executing.

    Parameters
    ----------

    jobs : msg_id, list of msg_ids, or AsyncResult
        The jobs to be aborted

        If unspecified/None: abort all outstanding jobs.
    targets : passed to `self._build_targets` to select the engines
    block : bool [default: self.block]
        If true, read one reply per target and raise the unwrapped
        exception of the last non-'ok' reply.

    Raises
    ------

    TypeError : if any element of `jobs` is neither a str nor an AsyncResult
    """
    block = self.block if block is None else block
    jobs = jobs if jobs is not None else list(self.outstanding)
    targets = self._build_targets(targets)[0]

    valid_types = string_types + (AsyncResult,)
    if isinstance(jobs, valid_types):
        jobs = [jobs]
    # A real list is required: on Python 3 a filter() object is always truthy
    # and cannot be indexed.
    bad_ids = [obj for obj in jobs if not isinstance(obj, valid_types)]
    if bad_ids:
        raise TypeError("Invalid msg_id type %r, expected str or AsyncResult"%bad_ids[0])

    msg_ids = []
    for j in jobs:
        if isinstance(j, AsyncResult):
            msg_ids.extend(j.msg_ids)
        else:
            msg_ids.append(j)
    content = dict(msg_ids=msg_ids)
    for t in targets:
        self.session.send(self._control_socket, 'abort_request',
                content=content, ident=t)
    error = False
    if block:
        self._flush_ignored_control()
        for _ in range(len(targets)):
            idents, msg = self.session.recv(self._control_socket, 0)
            if self.debug:
                pprint(msg)
            if msg['content']['status'] != 'ok':
                error = self._unwrap_exception(msg['content'])
    else:
        # remember how many control replies we did not read here
        self._ignored_control_replies += len(targets)
    if error:
        raise error
1165 1160
@spin_first
def shutdown(self, targets='all', restart=False, hub=False, block=None):
    """Terminates one or more engine processes, optionally including the hub.

    Parameters
    ----------

    targets: list of ints or 'all' [default: all]
        Which engines to shutdown.
    hub: bool [default: False]
        Whether to include the Hub.  hub=True implies targets='all'
        and also waits for the engine replies.
    block: bool [default: self.block]
        Whether to wait for clean shutdown replies or not.
    restart: bool [default: False]
        NOT IMPLEMENTED
        whether to restart engines after shutting them down.

    Raises
    ------

    NotImplementedError : if `restart` is true
    """
    from IPython.parallel.error import NoEnginesRegistered
    if restart:
        raise NotImplementedError("Engine restart is not yet implemented")

    if block is None:
        block = self.block
    if hub:
        targets = 'all'
    try:
        engine_idents = self._build_targets(targets)[0]
    except NoEnginesRegistered:
        # no engines to stop; the hub may still be shut down below
        engine_idents = []

    for ident in engine_idents:
        self.session.send(self._control_socket, 'shutdown_request',
                          content={'restart': restart}, ident=ident)

    failure = False
    if block or hub:
        self._flush_ignored_control()
        for _ in range(len(engine_idents)):
            idents, msg = self.session.recv(self._control_socket, 0)
            if self.debug:
                pprint(msg)
            reply = msg['content']
            if reply['status'] != 'ok':
                failure = self._unwrap_exception(reply)
    else:
        # remember how many control replies we did not read here
        self._ignored_control_replies += len(engine_idents)

    if hub:
        time.sleep(0.25)
        self.session.send(self._query_socket, 'shutdown_request')
        idents, msg = self.session.recv(self._query_socket, 0)
        if self.debug:
            pprint(msg)
        hub_reply = msg['content']
        if hub_reply['status'] != 'ok':
            failure = self._unwrap_exception(hub_reply)

    if failure:
        raise failure
1220 1215
1221 1216 #--------------------------------------------------------------------------
1222 1217 # Execution related methods
1223 1218 #--------------------------------------------------------------------------
1224 1219
def _maybe_raise(self, result):
    """Raise `result` if it is an `error.RemoteError`, else return it unchanged."""
    if not isinstance(result, error.RemoteError):
        return result
    raise result
1231 1226
def send_apply_request(self, socket, f, args=None, kwargs=None, metadata=None, track=False,
                            ident=None):
    """Build an "apply_request" message for `f` and send it on `socket`.

    This is the principal method with which all engine execution is performed by views.

    The message id is recorded in `self.outstanding` and `self.history`,
    and the submission time is stored in `self.metadata`.

    Returns
    -------

    msg : the message returned by `self.session.send`

    Raises
    ------

    RuntimeError : if the client has been closed
    TypeError : if `f`, `args`, `kwargs` or `metadata` has the wrong type
    """
    if self._closed:
        raise RuntimeError("Client cannot be used after its sockets have been closed")

    # fill in defaults
    if args is None:
        args = []
    if kwargs is None:
        kwargs = {}
    if metadata is None:
        metadata = {}

    # validate arguments
    if not (callable(f) or isinstance(f, Reference)):
        raise TypeError("f must be callable, not %s"%type(f))
    if not isinstance(args, (tuple, list)):
        raise TypeError("args must be tuple or list, not %s"%type(args))
    if not isinstance(kwargs, dict):
        raise TypeError("kwargs must be dict, not %s"%type(kwargs))
    if not isinstance(metadata, dict):
        raise TypeError("metadata must be dict, not %s"%type(metadata))

    session = self.session
    bufs = serialize.pack_apply_message(
        f, args, kwargs,
        buffer_threshold=session.buffer_threshold,
        item_threshold=session.item_threshold,
    )
    msg = session.send(socket, "apply_request", buffers=bufs, ident=ident,
                       metadata=metadata, track=track)

    msg_id = msg['header']['msg_id']
    self.outstanding.add(msg_id)
    if ident:
        # possibly routed to a specific engine; the last element is the one checked
        engine_ident = ident[-1] if isinstance(ident, list) else ident
        if engine_ident in self._engines.values():
            # save for later, in case of engine death
            self._outstanding_dict[engine_ident].add(msg_id)
    self.history.append(msg_id)
    self.metadata[msg_id]['submitted'] = datetime.now()

    return msg
1278 1273
def send_execute_request(self, socket, code, silent=True, metadata=None, ident=None):
    """Build an "execute_request" message for `code` and send it on `socket`.

    The message id is recorded in `self.outstanding` and `self.history`,
    and the submission time is stored in `self.metadata`.

    Returns
    -------

    msg : the message returned by `self.session.send`

    Raises
    ------

    RuntimeError : if the client has been closed
    TypeError : if `code` is not text or `metadata` is not a dict
    """
    if self._closed:
        raise RuntimeError("Client cannot be used after its sockets have been closed")

    # fill in defaults
    if metadata is None:
        metadata = {}

    # validate arguments
    if not isinstance(code, string_types):
        raise TypeError("code must be text, not %s" % type(code))
    if not isinstance(metadata, dict):
        raise TypeError("metadata must be dict, not %s" % type(metadata))

    content = {
        'code': code,
        'silent': bool(silent),
        'user_variables': [],
        'user_expressions': {},
    }
    msg = self.session.send(socket, "execute_request", content=content, ident=ident,
                            metadata=metadata)

    msg_id = msg['header']['msg_id']
    self.outstanding.add(msg_id)
    if ident:
        # possibly routed to a specific engine; the last element is the one checked
        engine_ident = ident[-1] if isinstance(ident, list) else ident
        if engine_ident in self._engines.values():
            # save for later, in case of engine death
            self._outstanding_dict[engine_ident].add(msg_id)
    self.history.append(msg_id)
    self.metadata[msg_id]['submitted'] = datetime.now()

    return msg
1315 1310
1316 1311 #--------------------------------------------------------------------------
1317 1312 # construct a View object
1318 1313 #--------------------------------------------------------------------------
1319 1314
def load_balanced_view(self, targets=None):
    """construct a LoadBalancedView object.

    If no arguments are specified, create a LoadBalancedView
    using all engines.

    Parameters
    ----------

    targets: list,slice,int,etc. [default: use all engines]
        The subset of engines across which to load-balance.
        'all' is treated the same as None.
    """
    if targets is None or targets == 'all':
        engine_ids = None
    else:
        engine_ids = self._build_targets(targets)[1]
    return LoadBalancedView(client=self, socket=self._task_socket, targets=engine_ids)
1337 1332
def direct_view(self, targets='all'):
    """construct a DirectView object.

    If no targets are specified, create a DirectView using all engines.

    rc.direct_view('all') is distinguished from rc[:] in that 'all' will
    evaluate the target engines at each execution, whereas rc[:] will connect to
    all *current* engines, and that list will not change.

    That is, 'all' will always use all engines, whereas rc[:] will not use
    engines added after the DirectView is constructed.

    Parameters
    ----------

    targets: list,slice,int,etc. [default: use all engines]
        The engines to use for the View.  A single int yields a view
        on that one engine rather than on a one-element list.
    """
    wants_single = isinstance(targets, int)
    view_targets = targets
    # 'all' is passed through untouched so it can be evaluated lazily at each execution
    if targets != 'all':
        view_targets = self._build_targets(targets)[1]
    if wants_single:
        view_targets = view_targets[0]
    return DirectView(client=self, socket=self._mux_socket, targets=view_targets)
1363 1358
1364 1359 #--------------------------------------------------------------------------
1365 1360 # Query methods
1366 1361 #--------------------------------------------------------------------------
1367 1362
@spin_first
def get_result(self, indices_or_msg_ids=None, block=None):
    """Retrieve a result by msg_id or history index, wrapped in an AsyncResult object.

    If the client already has the results, no request to the Hub will be made.

    This is a convenient way to construct AsyncResult objects, which are wrappers
    that include metadata about execution, and allow for awaiting results that
    were not submitted by this Client.

    It can also be a convenient way to retrieve the metadata associated with
    blocking execution, since the result is always wrapped in an AsyncResult.

    Examples
    --------
    ::

        In [10]: r = client.apply()

    Parameters
    ----------

    indices_or_msg_ids : integer history index, str msg_id, or list of either
        The indices or msg_ids of indices to be retrieved.
        Default: the most recent entry in self.history.

    block : bool
        Whether to wait for the result to be done

    Returns
    -------

    AsyncResult
        A single AsyncResult object will always be returned.

    AsyncHubResult
        A subclass of AsyncResult that retrieves results from the Hub

    Raises
    ------

    TypeError : if an entry is neither an int index nor a str msg_id
    """
    block = self.block if block is None else block
    if indices_or_msg_ids is None:
        indices_or_msg_ids = -1

    single_result = False
    if not isinstance(indices_or_msg_ids, (list, tuple)):
        indices_or_msg_ids = [indices_or_msg_ids]
        single_result = True

    theids = []
    for item in indices_or_msg_ids:
        if isinstance(item, int):
            item = self.history[item]
        if not isinstance(item, string_types):
            raise TypeError("indices must be str or int, not %r"%item)
        theids.append(item)

    # A msg_id is local when it is outstanding or already cached; anything
    # else must come from the Hub.  Computed eagerly because a filter()
    # object is always truthy on Python 3.
    needs_hub = any(
        msg_id not in self.outstanding and msg_id not in self.results
        for msg_id in theids
    )

    # given single msg_id initially, get_result should get the result itself,
    # not a length-one list
    if single_result:
        theids = theids[0]

    if needs_hub:
        ar = AsyncHubResult(self, msg_ids=theids)
    else:
        ar = AsyncResult(self, msg_ids=theids)

    if block:
        ar.wait()

    return ar
1440 1435
@spin_first
def resubmit(self, indices_or_msg_ids=None, metadata=None, block=None):
    """Resubmit one or more tasks.

    in-flight tasks may not be resubmitted.

    Parameters
    ----------

    indices_or_msg_ids : integer history index, str msg_id, or list of either
        The indices or msg_ids of the tasks to be resubmitted.
        Default: the most recent entry in self.history.

    metadata : not referenced by this method's body

    block : bool
        Whether to wait for the result to be done

    Returns
    -------

    AsyncHubResult
        A subclass of AsyncResult that retrieves results from the Hub

    Raises
    ------

    TypeError : if an entry is neither an int index nor a str msg_id
    """
    if block is None:
        block = self.block
    if indices_or_msg_ids is None:
        indices_or_msg_ids = -1
    if not isinstance(indices_or_msg_ids, (list, tuple)):
        indices_or_msg_ids = [indices_or_msg_ids]

    theids = []
    for entry in indices_or_msg_ids:
        if isinstance(entry, int):
            entry = self.history[entry]
        if not isinstance(entry, string_types):
            raise TypeError("indices must be str or int, not %r"%entry)
        theids.append(entry)

    self.session.send(self._query_socket, 'resubmit_request', dict(msg_ids=theids))

    # wait until the query socket is readable, then read without blocking
    zmq.select([self._query_socket], [], [])
    idents, msg = self.session.recv(self._query_socket, zmq.NOBLOCK)
    if self.debug:
        pprint(msg)
    reply = msg['content']
    if reply['status'] != 'ok':
        raise self._unwrap_exception(reply)

    # the reply maps each original msg_id to the msg_id of its resubmission
    mapping = reply['resubmitted']
    new_ids = [mapping[msg_id] for msg_id in theids]

    ar = AsyncHubResult(self, msg_ids=new_ids)
    if block:
        ar.wait()
    return ar
1498 1493
@spin_first
def result_status(self, msg_ids, status_only=True):
    """Check on the status of the result(s) of the apply request with `msg_ids`.

    If status_only is False, then the actual results will be retrieved, else
    only the status of the results will be checked.

    Results already cached in `self.results` are reported as completed
    without asking the Hub; only the remaining msg_ids are requested.

    Parameters
    ----------

    msg_ids : list of msg_ids
        if int:
            Passed as index to self.history for convenience.
    status_only : bool (default: True)
        if False:
            Retrieve the actual results of completed tasks.

    Returns
    -------

    results : dict
        There will always be the keys 'pending' and 'completed', which will
        be lists of msg_ids that are incomplete or complete. If `status_only`
        is False, then completed results will be keyed by their `msg_id`.

    Raises
    ------

    TypeError : if an entry is neither an int index nor a str msg_id
    """
    if not isinstance(msg_ids, (list, tuple)):
        msg_ids = [msg_ids]

    requested = []
    for msg_id in msg_ids:
        if isinstance(msg_id, int):
            msg_id = self.history[msg_id]
        if not isinstance(msg_id, string_types):
            raise TypeError("msg_ids must be str, not %r"%msg_id)
        requested.append(msg_id)

    completed = []
    local_results = {}

    # Local shortcut: split into locally cached ids and ids needing the Hub.
    # Built as a new list; removing from the list being iterated skips
    # the element after each removal.
    theids = []
    for msg_id in requested:
        if msg_id in self.results:
            completed.append(msg_id)
            local_results[msg_id] = self.results[msg_id]
        else:
            theids.append(msg_id)

    if theids: # some not locally cached
        content = dict(msg_ids=theids, status_only=status_only)
        self.session.send(self._query_socket, "result_request", content=content)
        # wait until the query socket is readable, then read without blocking
        zmq.select([self._query_socket], [], [])
        idents, msg = self.session.recv(self._query_socket, zmq.NOBLOCK)
        if self.debug:
            pprint(msg)
        content = msg['content']
        if content['status'] != 'ok':
            raise self._unwrap_exception(content)
        buffers = msg['buffers']
    else:
        content = dict(completed=[], pending=[])

    content['completed'].extend(completed)

    if status_only:
        return content

    failures = []
    # load cached results into result:
    content.update(local_results)

    # update cache with results:
    for msg_id in sorted(theids):
        if msg_id in content['completed']:
            rec = content[msg_id]
            parent = rec['header']
            header = rec['result_header']
            rcontent = rec['result_content']
            iodict = rec['io']
            if isinstance(rcontent, str):
                rcontent = self.session.unpack(rcontent)

            md = self.metadata[msg_id]
            md_msg = dict(
                content=rcontent,
                parent_header=parent,
                header=header,
                metadata=rec['result_metadata'],
            )
            md.update(self._extract_metadata(md_msg))
            if rec.get('received'):
                md['received'] = rec['received']
            md.update(iodict)

            if rcontent['status'] == 'ok':
                if header['msg_type'] == 'apply_reply':
                    # unserialize_object also returns the buffers it did not
                    # use, which carry over to the next record
                    res, buffers = serialize.unserialize_object(buffers)
                elif header['msg_type'] == 'execute_reply':
                    res = ExecuteReply(msg_id, rcontent, md)
                else:
                    raise KeyError("unhandled msg type: %r" % header['msg_type'])
            else:
                res = self._unwrap_exception(rcontent)
                failures.append(res)

            self.results[msg_id] = res
            content[msg_id] = res

    if len(theids) == 1 and failures:
        raise failures[0]

    error.collect_exceptions(failures, "result_status")
    return content
1610 1605
@spin_first
def queue_status(self, targets='all', verbose=False):
    """Fetch the status of engine queues.

    Parameters
    ----------

    targets : int/str/list of ints/strs
        the engines whose states are to be queried.
        default : all
    verbose : bool
        Whether to return lengths only, or lists of ids for each element

    Returns
    -------

    The rekeyed reply content; when `targets` is a single int, only the
    entry for that engine.
    """
    # None lets 'all' be evaluated on the receiving side
    engine_ids = None if targets == 'all' else self._build_targets(targets)[1]
    request = dict(targets=engine_ids, verbose=verbose)
    self.session.send(self._query_socket, "queue_request", content=request)
    idents, msg = self.session.recv(self._query_socket, 0)
    if self.debug:
        pprint(msg)

    reply = msg['content']
    if reply.pop('status') != 'ok':
        raise self._unwrap_exception(reply)

    statuses = rekey(reply)
    if isinstance(targets, int):
        return statuses[targets]
    return statuses
1643 1638
def _build_msgids_from_target(self, targets=None):
    """Build a list of msg_ids from the list of engine targets

    Returns the keys of `self.metadata` whose "engine_uuid" is among the
    identities resolved from `targets`; an empty list when `targets` is falsy.
    """
    if not targets: # needed as _build_targets otherwise uses all engines
        return []
    target_ids = self._build_targets(targets)[0]
    # a real list, as documented; filter() returns a lazy iterator on Python 3
    return [md_id for md_id in self.metadata
            if self.metadata[md_id]["engine_uuid"] in target_ids]
1650 1645
def _build_msgids_from_jobs(self, jobs=None):
    """Build a list of msg_ids from "jobs"

    `jobs` may be a single msg_id str, a single AsyncResult, or a list of
    either.  AsyncResults contribute all of their msg_ids.

    Raises
    ------

    TypeError : if any element is neither a str nor an AsyncResult
    """
    if not jobs:
        return []
    # the trailing comma matters: (AsyncResult) is just the class, and
    # tuple + class raises TypeError
    valid_types = string_types + (AsyncResult,)
    if isinstance(jobs, valid_types):
        jobs = [jobs]
    # a real list is required: a filter() object is always truthy on Python 3
    # and cannot be indexed
    bad_ids = [obj for obj in jobs if not isinstance(obj, valid_types)]
    if bad_ids:
        raise TypeError("Invalid msg_id type %r, expected str or AsyncResult"%bad_ids[0])
    msg_ids = []
    for j in jobs:
        if isinstance(j, AsyncResult):
            msg_ids.extend(j.msg_ids)
        else:
            msg_ids.append(j)
    return msg_ids
1667 1662
def purge_local_results(self, jobs=[], targets=[]):
    """Clears the client caches of results and frees such memory.

    Individual results can be purged by msg_id, or the entire
    history of specific targets can be purged.

    Use `purge_local_results('all')` to scrub everything from the Clients's db.

    The client must have no outstanding tasks before purging the caches.
    Raises `AssertionError` if there are still outstanding tasks.

    After this call all `AsyncResults` are invalid and should be discarded.

    If you must "reget" the results, you can still do so by using
    `client.get_result(msg_id)` or `client.get_result(asyncresult)`. This will
    redownload the results from the hub if they are still available
    (i.e `client.purge_hub_results(...)` has not been called).

    Parameters
    ----------

    jobs : str or list of str or AsyncResult objects
            the msg_ids whose results should be purged.
    targets : int/str/list of ints/strs
            The targets, by int_id, whose entire results are to be purged.

            default : None

    Raises
    ------

    AssertionError : if there are outstanding tasks
    ValueError : if neither `jobs` nor `targets` is given
    KeyError : if a selected msg_id is missing from a cache
    """
    assert not self.outstanding, "Can't purge a client with outstanding tasks!"

    if not targets and not jobs:
        raise ValueError("Must specify at least one of `targets` and `jobs`")

    if jobs == 'all':
        self.results.clear()
        self.metadata.clear()
        return

    msg_ids = []
    msg_ids.extend(self._build_msgids_from_target(targets))
    msg_ids.extend(self._build_msgids_from_jobs(jobs))
    # explicit loops: map() is lazy on Python 3 and would pop nothing
    for msg_id in msg_ids:
        self.results.pop(msg_id)
    for msg_id in msg_ids:
        self.metadata.pop(msg_id)
1711 1706
1712 1707
@spin_first
def purge_hub_results(self, jobs=[], targets=[]):
    """Tell the Hub to forget results.

    Individual results can be purged by msg_id, or the entire
    history of specific targets can be purged.

    Use `purge_hub_results('all')` to scrub everything from the Hub's db.

    Parameters
    ----------

    jobs : str or list of str or AsyncResult objects
            the msg_ids whose results should be forgotten.
    targets : int/str/list of ints/strs
            The targets, by int_id, whose entire history is to be purged.

            default : None

    Raises
    ------

    ValueError : if neither `jobs` nor `targets` is given
    """
    if not (targets or jobs):
        raise ValueError("Must specify at least one of `targets` and `jobs`")
    if targets:
        targets = self._build_targets(targets)[1]

    # 'all' is forwarded as-is; anything else is expanded to msg_ids
    msg_ids = jobs if jobs == 'all' else self._build_msgids_from_jobs(jobs)

    request = dict(engine_ids=targets, msg_ids=msg_ids)
    self.session.send(self._query_socket, "purge_request", content=request)
    idents, msg = self.session.recv(self._query_socket, 0)
    if self.debug:
        pprint(msg)
    reply = msg['content']
    if reply['status'] != 'ok':
        raise self._unwrap_exception(reply)
def purge_results(self, jobs=[], targets=[]):
    """Clears the cached results from both the hub and the local client

    Individual results can be purged by msg_id, or the entire
    history of specific targets can be purged.

    Use `purge_results('all')` to scrub every cached result from both the Hub's and
    the Client's db.

    Equivalent to calling `purge_local_results()` and then `purge_hub_results()`
    with the same arguments.

    Parameters
    ----------

    jobs : str or list of str or AsyncResult objects
            the msg_ids whose results should be forgotten.
    targets : int/str/list of ints/strs
            The targets, by int_id, whose entire history is to be purged.

            default : None
    """
    selection = dict(jobs=jobs, targets=targets)
    # local caches first, then the Hub
    self.purge_local_results(**selection)
    self.purge_hub_results(**selection)
1776 1771
def purge_everything(self):
    """Clears all content from previous Tasks from both the hub and the local client

    Calls `purge_results("all")`, then replaces `self.history` with a new
    empty list and clears the session's `digest_history`.
    """
    self.purge_results("all")
    self.history = list()
    digests = self.session.digest_history
    digests.clear()
1786 1781
@spin_first
def hub_history(self):
    """Get the Hub's history

    Just like the Client, the Hub has a history, which is a list of msg_ids.
    This will contain the history of all clients, and, depending on configuration,
    may contain history across multiple cluster sessions.

    Any msg_id returned here is a valid argument to `get_result`.

    Returns
    -------

    msg_ids : list of strs
            list of all msg_ids, ordered by task submission time.
    """
    self.session.send(self._query_socket, "history_request", content={})
    idents, msg = self.session.recv(self._query_socket, 0)
    if self.debug:
        pprint(msg)

    reply = msg['content']
    if reply['status'] == 'ok':
        return reply['history']
    raise self._unwrap_exception(reply)
1814 1809
@spin_first
def db_query(self, query, keys=None):
    """Query the Hub's TaskRecord database

    This will return a list of task record dicts that match `query`

    Parameters
    ----------

    query : mongodb query dict
        The search dict. See mongodb query docs for details.
    keys : list of strs [optional]
        The subset of keys to be returned.  The default is to fetch everything but buffers.
        'msg_id' will *always* be included.
        A single str is wrapped in a list.

    Returns
    -------

    records : the reply's 'records', with 'buffers' / 'result_buffers'
        attached to each record when the reply carries the matching lengths.
    """
    if isinstance(keys, string_types):
        keys = [keys]
    self.session.send(self._query_socket, "db_request",
                      content=dict(query=query, keys=keys))
    idents, msg = self.session.recv(self._query_socket, 0)
    if self.debug:
        pprint(msg)
    reply = msg['content']
    if reply['status'] != 'ok':
        raise self._unwrap_exception(reply)

    records = reply['records']
    buffer_lens = reply['buffer_lens']
    result_buffer_lens = reply['result_buffer_lens']
    remaining = msg['buffers']

    # relink buffers: the flat buffer list is consumed from the front,
    # record by record, using the per-record lengths from the reply
    for idx, rec in enumerate(records):
        if buffer_lens is not None:
            count = buffer_lens[idx]
            rec['buffers'] = remaining[:count]
            remaining = remaining[count:]
        if result_buffer_lens is not None:
            count = result_buffer_lens[idx]
            rec['result_buffers'] = remaining[:count]
            remaining = remaining[count:]

    return records
1858 1853
# Only the Client class is exported by `from ... import *`.
__all__ = ['Client']
@@ -1,369 +1,369 b''
1 1 """some generic utilities for dealing with classes, urls, and serialization
2 2
3 3 Authors:
4 4
5 5 * Min RK
6 6 """
7 7 #-----------------------------------------------------------------------------
8 8 # Copyright (C) 2010-2011 The IPython Development Team
9 9 #
10 10 # Distributed under the terms of the BSD License. The full license is in
11 11 # the file COPYING, distributed as part of this software.
12 12 #-----------------------------------------------------------------------------
13 13
14 14 #-----------------------------------------------------------------------------
15 15 # Imports
16 16 #-----------------------------------------------------------------------------
17 17
18 18 # Standard library imports.
19 19 import logging
20 20 import os
21 21 import re
22 22 import stat
23 23 import socket
24 24 import sys
25 25 from signal import signal, SIGINT, SIGABRT, SIGTERM
26 26 try:
27 27 from signal import SIGKILL
28 28 except ImportError:
29 29 SIGKILL=None
30 30
31 31 try:
32 32 import cPickle
33 33 pickle = cPickle
34 34 except:
35 35 cPickle = None
36 36 import pickle
37 37
38 38 # System library imports
39 39 import zmq
40 40 from zmq.log import handlers
41 41
42 42 from IPython.external.decorator import decorator
43 43
44 44 # IPython imports
45 45 from IPython.config.application import Application
46 46 from IPython.utils.localinterfaces import localhost, is_public_ip, public_ips
47 47 from IPython.utils.py3compat import string_types
48 48 from IPython.kernel.zmq.log import EnginePUBHandler
49 49 from IPython.kernel.zmq.serialize import (
50 50 unserialize_object, serialize_object, pack_apply_message, unpack_apply_message
51 51 )
52 52
53 53 #-----------------------------------------------------------------------------
54 54 # Classes
55 55 #-----------------------------------------------------------------------------
56 56
class Namespace(dict):
    """Subclass of dict for attribute access to keys.

    ``ns.key`` reads ``ns['key']`` and ``ns.key = value`` writes it.
    Names that already exist on ``dict`` (``keys``, ``items``, ...) cannot
    be assigned through attribute access.
    """

    def __getattr__(self, key):
        """getattr aliased to getitem

        Raises
        ------
        NameError
            If `key` is not a key of the mapping.
        """
        # Plain membership works on Python 2 and 3.  Going through
        # ``self.iterkeys()`` would itself trigger __getattr__ on Python 3
        # (dict has no iterkeys there) and recurse without end.
        if key in self:
            return self[key]
        else:
            raise NameError(key)

    def __setattr__(self, key, value):
        """setattr aliased to setitem, with strict

        Raises
        ------
        KeyError
            If `key` is the name of an existing ``dict`` attribute.
        """
        if hasattr(dict, key):
            raise KeyError("Cannot override dict keys %r"%key)
        self[key] = value
72 72
73 73
class ReverseDict(dict):
    """simple double-keyed subset of dict methods.

    Every ``key -> value`` pair stored through ``__init__`` or
    ``__setitem__`` is also recorded as ``value -> key`` in
    ``self._reverse``, and lookups fall back to that reverse mapping.
    Only the methods defined here keep the reverse mapping in sync.
    """

    def __init__(self, *args, **kwargs):
        dict.__init__(self, *args, **kwargs)
        self._reverse = dict()
        # items() exists on both Python 2 and 3; iteritems() is 2-only.
        for key, value in self.items():
            self._reverse[value] = key

    def __getitem__(self, key):
        """Look `key` up as a forward key first, then as a stored value."""
        try:
            return dict.__getitem__(self, key)
        except KeyError:
            return self._reverse[key]

    def __setitem__(self, key, value):
        """Store ``key -> value`` and ``value -> key``.

        Raises KeyError if `key` is already in use as a value.
        """
        if key in self._reverse:
            raise KeyError("Can't have key %r on both sides!"%key)
        dict.__setitem__(self, key, value)
        self._reverse[value] = key

    def pop(self, key):
        """Remove `key` and its reverse entry; return the stored value."""
        value = dict.pop(self, key)
        self._reverse.pop(value)
        return value

    def get(self, key, default=None):
        """Two-way lookup that returns `default` instead of raising."""
        try:
            return self[key]
        except KeyError:
            return default
105 105
106 106 #-----------------------------------------------------------------------------
107 107 # Functions
108 108 #-----------------------------------------------------------------------------
109 109
@decorator
def log_errors(f, self, *args, **kwargs):
    """decorator to log unhandled exceptions raised in a method.

    Intended for on_recv callbacks, so that an exception is logged via
    ``self.log`` rather than propagating out of the callback.  When the
    wrapped call raises, the wrapper returns None.
    """
    try:
        result = f(self, *args, **kwargs)
    except Exception:
        self.log.error("Uncaught exception in %r" % f, exc_info=True)
    else:
        return result
121 121
122 122
def is_url(url):
    """boolean check for whether a string is a zmq url

    True when `url` contains ``://`` and the part before the first
    ``://`` is, ignoring case, one of tcp, pgm, epgm, ipc or inproc.
    """
    if '://' not in url:
        return False
    scheme = url.split('://', 1)[0].lower()
    return scheme in ['tcp','pgm','epgm','ipc','inproc']
131 131
def validate_url(url):
    """validate a url for zeromq

    Only the structure of ``tcp`` urls is checked in detail; urls with
    any other accepted protocol pass once the protocol is recognised.

    Parameters
    ----------
    url : str
        The url to check, e.g. ``'tcp://127.0.0.1:10101'``.  It is
        lower-cased before being checked.

    Returns
    -------
    True if the url is accepted.

    Raises
    ------
    TypeError
        If `url` is not a string.
    AssertionError
        If the url is malformed.  Raised explicitly rather than through
        ``assert`` statements, so validation still happens under ``-O``.
    """
    if not isinstance(url, string_types):
        raise TypeError("url must be a string, not %r"%type(url))
    url = url.lower()

    proto_addr = url.split('://')
    if len(proto_addr) != 2:
        raise AssertionError('Invalid url: %r'%url)
    proto, addr = proto_addr
    if proto not in ['tcp','pgm','epgm','ipc','inproc']:
        raise AssertionError("Invalid protocol: %r"%proto)

    if proto == 'tcp':
        # domain pattern adapted from http://www.regexlib.com/REDetails.aspx?regexp_id=391
        # author: Remi Sabourin
        pat = re.compile(r'^([\w\d]([\w\d\-]{0,61}[\w\d])?\.)*[\w\d]([\w\d\-]{0,61}[\w\d])?$')

        lis = addr.split(':')
        if len(lis) != 2:
            raise AssertionError('Invalid url: %r'%url)
        addr,s_port = lis
        try:
            int(s_port)
        except ValueError:
            # Report the offending text: no integer port exists at this
            # point, so the message has to use s_port.
            raise AssertionError("Invalid port %r in url: %r"%(s_port, url))

        if addr != '*' and pat.match(addr) is None:
            raise AssertionError('Invalid url: %r'%url)

    else:
        # only validate tcp urls currently
        pass

    return True
163 163
164 164
def validate_url_container(container):
    """validate a potentially nested collection of urls.

    A string is validated directly with `validate_url` and its result is
    returned.  For a dict the values are validated; any other container
    is iterated and each element validated recursively, in which case the
    function returns None.  Errors raised by `validate_url` propagate.
    """
    if isinstance(container, string_types):
        url = container
        return validate_url(url)
    elif isinstance(container, dict):
        # values() works on Python 2 and 3; itervalues() is 2-only.
        container = container.values()

    for element in container:
        validate_url_container(element)
175 175
176 176
def split_url(url):
    """split a zmq url (tcp://ip:port) into ('tcp','ip','port').

    Raises
    ------
    AssertionError
        If `url` does not have exactly one ``://`` separator, or the
        address part does not have exactly one ``:``.  Raised explicitly
        rather than through ``assert`` statements, so callers that catch
        AssertionError keep working when Python runs with ``-O``.
    """
    proto_addr = url.split('://')
    if len(proto_addr) != 2:
        raise AssertionError('Invalid url: %r'%url)
    proto, addr = proto_addr
    lis = addr.split(':')
    if len(lis) != 2:
        raise AssertionError('Invalid url: %r'%url)
    addr,s_port = lis
    return proto,addr,s_port
186 186
def disambiguate_ip_address(ip, location=None):
    """turn multi-ip interfaces '0.0.0.0' and '*' into connectable
    ones, based on the location (default interpretation of location is localhost).

    Any other `ip` is returned unchanged.  For a wildcard `ip`, the result
    is ``localhost()`` when `location` is None, when ``is_public_ip(location)``
    is true, or when ``public_ips()`` is empty; otherwise a non-empty
    `location` is returned, and failing that the wildcard itself.
    """
    if ip not in ('0.0.0.0', '*'):
        return ip
    if location is None or is_public_ip(location) or not public_ips():
        # If location is unspecified or cannot be determined, assume local
        return localhost()
    if location:
        return location
    return ip
197 197
def disambiguate_url(url, location=None):
    """turn multi-ip interfaces '0.0.0.0' and '*' into connectable
    ones, based on the location (default interpretation is localhost).

    This is for zeromq urls, such as tcp://*:10101.  A url that
    `split_url` rejects is returned unchanged.
    """
    try:
        parts = split_url(url)
    except AssertionError:
        # probably not tcp url; could be ipc, etc.
        return url

    proto, ip, port = parts
    return "%s://%s:%s"%(proto, disambiguate_ip_address(ip,location), port)
212 212
213 213
214 214 #--------------------------------------------------------------------------
215 215 # helpers for implementing old MEC API via view.apply
216 216 #--------------------------------------------------------------------------
217 217
def interactive(f):
    """decorator for making functions appear as interactively defined.

    Sets ``f.__module__`` to ``'__main__'`` and returns the same function
    object.  The stated purpose is for the function to be linked to the
    user_ns as globals() instead of the module globals().
    """
    setattr(f, '__module__', '__main__')
    return f
225 225
@interactive
def _push(**ns):
    """helper method for implementing `client.push` via `client.apply`

    Binds every keyword argument as a name in ``globals()``.  Each value
    is first stored under a temporary name that is not already in use and
    then assigned with ``exec``; the temporary name is always removed
    afterwards.
    """
    user_ns = globals()
    tmp = '_IP_PUSH_TMP_'
    while tmp in user_ns:
        tmp = tmp + '_'
    try:
        # items() works on Python 2 and 3; iteritems() is 2-only.
        for name, value in ns.items():
            user_ns[tmp] = value
            exec("%s = %s" % (name, tmp), user_ns)
    finally:
        user_ns.pop(tmp, None)
239 239
@interactive
def _pull(keys):
    """helper method for implementing `client.pull` via `client.apply`

    Evaluates `keys` in ``globals()``.  A list, tuple or set of
    expressions yields a list of values; anything else is evaluated as a
    single expression and its value returned.
    """
    user_ns = globals()
    if isinstance(keys, (list,tuple, set)):
        # Build a real list: map() is lazy on Python 3, so callers would
        # get an iterator object instead of the values.
        return [eval(key, user_ns) for key in keys]
    else:
        return eval(keys, user_ns)
247 247
@interactive
def _execute(code):
    """helper method for implementing `client.execute` via `client.apply`

    Runs `code` with ``globals()`` as its namespace.
    """
    user_ns = globals()
    exec(code, user_ns)
252 252
253 253 #--------------------------------------------------------------------------
254 254 # extra process management utilities
255 255 #--------------------------------------------------------------------------
256 256
257 257 _random_ports = set()
258 258
259 259 def select_random_ports(n):
260 260 """Selects and return n random ports that are available."""
261 261 ports = []
262 for i in xrange(n):
262 for i in range(n):
263 263 sock = socket.socket()
264 264 sock.bind(('', 0))
265 265 while sock.getsockname()[1] in _random_ports:
266 266 sock.close()
267 267 sock = socket.socket()
268 268 sock.bind(('', 0))
269 269 ports.append(sock)
270 270 for i, sock in enumerate(ports):
271 271 port = sock.getsockname()[1]
272 272 sock.close()
273 273 ports[i] = port
274 274 _random_ports.add(port)
275 275 return ports
276 276
def signal_children(children):
    """Relay interrupt/term signals to children, for more solid process cleanup.

    Installs a handler for SIGINT, SIGABRT and SIGTERM that logs the
    signal, calls ``terminate()`` on every item of `children` and then
    exits the process: exit status 0 for SIGINT, 1 otherwise.
    """
    def relay(sig, frame):
        app_log = Application.instance().log
        app_log.critical("Got signal %i, terminating children..."%sig)
        for proc in children:
            proc.terminate()

        sys.exit(sig != SIGINT)

    for signum in (SIGINT, SIGABRT, SIGTERM):
        signal(signum, relay)
289 289
def generate_exec_key(keyfile):
    """Write a freshly generated key to `keyfile`.

    The key is the string form of ``uuid.uuid4()`` followed by a newline.
    Any existing content of `keyfile` is replaced.

    Parameters
    ----------
    keyfile : str
        Path of the file to (over)write.
    """
    import uuid
    newkey = str(uuid.uuid4())
    user_rw = stat.S_IRUSR|stat.S_IWUSR
    # Create the file with user-only permissions from the start, so the
    # key is never on disk under the default (possibly wider) mode.
    fd = os.open(keyfile, os.O_WRONLY|os.O_CREAT|os.O_TRUNC, user_rw)
    with os.fdopen(fd, 'w') as f:
        # f.write('ipython-key ')
        f.write(newkey+'\n')
    # set user-only RW permissions (0600)
    # this will have no effect on Windows
    # Still needed here: the mode passed to os.open only applies when the
    # file is newly created, not when an existing file is truncated.
    os.chmod(keyfile, user_rw)
299 299
300 300
def integer_loglevel(loglevel):
    """Convert `loglevel` to an int where possible.

    Anything ``int()`` accepts is converted.  When ``int()`` raises
    ValueError and `loglevel` is a ``str``, it is looked up as an
    attribute of the ``logging`` module (e.g. ``'DEBUG'``).  In the
    remaining ValueError case the argument is returned unchanged.
    """
    try:
        return int(loglevel)
    except ValueError:
        if isinstance(loglevel, str):
            return getattr(logging, loglevel)
    return loglevel
308 308
def connect_logger(logname, context, iface, root="ip", loglevel=logging.DEBUG):
    """Attach a zmq PUBHandler to the logger named `logname`.

    Does nothing if that logger already has a PUBHandler.  Otherwise a
    PUB socket is created from `context` and connected to `iface`, and a
    handler publishing on it is added with `root` as its root topic.
    Handler and logger are both set to `loglevel`.
    """
    logger = logging.getLogger(logname)
    for existing in logger.handlers:
        if isinstance(existing, handlers.PUBHandler):
            # don't add a second PUBHandler
            return
    level = integer_loglevel(loglevel)
    pub = context.socket(zmq.PUB)
    pub.connect(iface)
    pub_handler = handlers.PUBHandler(pub)
    pub_handler.setLevel(level)
    pub_handler.root_topic = root
    logger.addHandler(pub_handler)
    logger.setLevel(level)
322 322
def connect_engine_logger(context, iface, engine, loglevel=logging.DEBUG):
    """Attach an EnginePUBHandler to the root logger.

    Returns None without changes if the root logger already has a
    PUBHandler.  Otherwise a PUB socket is created from `context` and
    connected to `iface`, an ``EnginePUBHandler(engine, socket)`` is
    added, handler and logger are set to `loglevel`, and the root logger
    is returned.
    """
    root_logger = logging.getLogger()
    for existing in root_logger.handlers:
        if isinstance(existing, handlers.PUBHandler):
            # don't add a second PUBHandler
            return
    level = integer_loglevel(loglevel)
    pub = context.socket(zmq.PUB)
    pub.connect(iface)
    engine_handler = EnginePUBHandler(engine, pub)
    engine_handler.setLevel(level)
    root_logger.addHandler(engine_handler)
    root_logger.setLevel(level)
    return root_logger
336 336
def local_logger(logname, loglevel=logging.DEBUG):
    """Attach a timestamped StreamHandler to the logger named `logname`.

    Returns None without changes if that logger already has a
    StreamHandler.  Otherwise the new handler and the logger are set to
    `loglevel` and the logger is returned.
    """
    level = integer_loglevel(loglevel)
    logger = logging.getLogger(logname)
    for existing in logger.handlers:
        if isinstance(existing, logging.StreamHandler):
            # don't add a second StreamHandler
            return
    stream_handler = logging.StreamHandler()
    stream_handler.setLevel(level)
    stream_handler.setFormatter(
        logging.Formatter("%(asctime)s.%(msecs).03d [%(name)s] %(message)s",
                          datefmt="%Y-%m-%d %H:%M:%S"))

    logger.addHandler(stream_handler)
    logger.setLevel(level)
    return logger
352 352
def set_hwm(sock, hwm=0):
    """set zmq High Water Mark on a socket

    in a way that always works for various pyzmq / libzmq versions.

    Each of the HWM, SNDHWM and RCVHWM options that the installed ``zmq``
    module defines is tried in turn; a ZMQError from setting one of them
    is ignored.
    """
    import zmq

    for name in ('HWM', 'SNDHWM', 'RCVHWM'):
        option = getattr(zmq, name, None)
        if option is not None:
            try:
                sock.setsockopt(option, hwm)
            except zmq.ZMQError:
                pass
368 368
369 369 No newline at end of file
@@ -1,378 +1,378 b''
1 1 """ Utilities for processing ANSI escape codes and special ASCII characters.
2 2 """
3 3 #-----------------------------------------------------------------------------
4 4 # Imports
5 5 #-----------------------------------------------------------------------------
6 6
7 7 # Standard library imports
8 8 from collections import namedtuple
9 9 import re
10 10
11 11 # System library imports
12 12 from IPython.external.qt import QtGui
13 13
14 14 # Local imports
15 15 from IPython.utils.py3compat import string_types
16 16
17 17 #-----------------------------------------------------------------------------
18 18 # Constants and datatypes
19 19 #-----------------------------------------------------------------------------
20 20
# An action for erase requests (ED and EL commands).
EraseAction = namedtuple('EraseAction', ['action', 'area', 'erase_to'])

# An action for cursor move requests (CUU, CUD, CUF, CUB, CNL, CPL, CHA, CUP,
# and HVP commands).
# FIXME: Not implemented in AnsiCodeProcessor.
MoveAction = namedtuple('MoveAction', ['action', 'dir', 'unit', 'count'])

# An action for scroll requests (SU and ST) and form feeds.
ScrollAction = namedtuple('ScrollAction', ['action', 'dir', 'unit', 'count'])

# An action for the carriage return character
CarriageReturnAction = namedtuple('CarriageReturnAction', ['action'])

# An action for the \n character
NewLineAction = namedtuple('NewLineAction', ['action'])

# An action for the beep character
BeepAction = namedtuple('BeepAction', ['action'])

# An action for backspace
BackSpaceAction = namedtuple('BackSpaceAction', ['action'])

# Regular expressions.
# NOTE: these are deliberately NOT raw strings.  '\a', '\b', '\r', '\n',
# '\f', '\x07' and '\x1b' must reach the regex engine as the literal control
# characters (a raw r'\b' would mean "word boundary").  Only the regex
# escapes for '[' and ']' need a doubled backslash.
CSI_COMMANDS = 'ABCDEFGHJKSTfmnsu'
CSI_SUBPATTERN = '\\[(.*?)([%s])' % CSI_COMMANDS
OSC_SUBPATTERN = '\\](.*?)[\x07\x1b]'
ANSI_PATTERN = ('\x01?\x1b(%s|%s)\x02?' % \
                (CSI_SUBPATTERN, OSC_SUBPATTERN))
ANSI_OR_SPECIAL_PATTERN = re.compile('(\a|\b|\r(?!\n)|\r?\n)|(?:%s)' % ANSI_PATTERN)
SPECIAL_PATTERN = re.compile('([\f])')

#-----------------------------------------------------------------------------
# Classes
#-----------------------------------------------------------------------------

class AnsiCodeProcessor(object):
    """ Translates special ASCII characters and ANSI escape codes into readable
        attributes. It also supports a few non-standard, xterm-specific codes.
    """

    # Whether to increase intensity or set boldness for SGR code 1.
    # (Different terminals handle this in different ways.)
    bold_text_enabled = False

    # We provide an empty default color map because subclasses will likely want
    # to use a custom color format.
    default_color_map = {}

    #---------------------------------------------------------------------------
    # AnsiCodeProcessor interface
    #---------------------------------------------------------------------------

    def __init__(self):
        self.actions = []
        self.color_map = self.default_color_map.copy()
        self.reset_sgr()

    def reset_sgr(self):
        """ Reset graphics attributes to their default values.
        """
        self.intensity = 0
        self.italic = False
        self.bold = False
        self.underline = False
        self.foreground_color = None
        self.background_color = None

    def split_string(self, string):
        """ Yields substrings for which the same escape code applies.

        This is a generator.  Whenever it yields, ``self.actions`` holds
        the actions that belong to the yielded item and the SGR attributes
        on ``self`` reflect the escape codes processed so far.  ``None``
        is yielded for beep, carriage return and backspace, which carry an
        action but no text.
        """
        self.actions = []
        start = 0

        # A single trailing '\n' is removed here and emitted as the final
        # newline at the end, so that a '\r' directly in front of it is
        # processed as a carriage return on its own.
        last_char = '\n' if len(string) > 0 and string[-1] == '\n' else None
        string = string[:-1] if last_char is not None else string

        for match in ANSI_OR_SPECIAL_PATTERN.finditer(string):
            raw = string[start:match.start()]
            substring = SPECIAL_PATTERN.sub(self._replace_special, raw)
            if substring or self.actions:
                yield substring
                self.actions = []
            start = match.end()

            # A real list is required: the groups are indexed below, and
            # filter() returns a non-subscriptable iterator on Python 3.
            groups = [g for g in match.groups() if g is not None]
            g0 = groups[0]
            if g0 == '\a':
                self.actions.append(BeepAction('beep'))
                yield None
                self.actions = []
            elif g0 == '\r':
                self.actions.append(CarriageReturnAction('carriage-return'))
                yield None
                self.actions = []
            elif g0 == '\b':
                self.actions.append(BackSpaceAction('backspace'))
                yield None
                self.actions = []
            elif g0 == '\n' or g0 == '\r\n':
                self.actions.append(NewLineAction('newline'))
                yield g0
                self.actions = []
            else:
                params = [ param for param in groups[1].split(';') if param ]
                if g0.startswith('['):
                    # Case 1: CSI code.
                    try:
                        # Convert eagerly so that a malformed parameter
                        # raises here (map() is lazy on Python 3) and so
                        # that the result supports pop() and truth tests.
                        params = [int(param) for param in params]
                    except ValueError:
                        # Silently discard badly formed codes.
                        pass
                    else:
                        self.set_csi_code(groups[2], params)

                elif g0.startswith(']'):
                    # Case 2: OSC code.
                    self.set_osc_code(params)

        raw = string[start:]
        substring = SPECIAL_PATTERN.sub(self._replace_special, raw)
        if substring or self.actions:
            yield substring

        if last_char is not None:
            self.actions.append(NewLineAction('newline'))
            yield last_char

    def set_csi_code(self, command, params=None):
        """ Set attributes based on CSI (Control Sequence Introducer) code.

        Parameters
        ----------
        command : str
            The code identifier, i.e. the final character in the sequence.

        params : sequence of integers, optional
            The parameter codes for the command.
        """
        if params is None:
            params = []

        if command == 'm':   # SGR - Select Graphic Rendition
            if params:
                self.set_sgr_code(params)
            else:
                self.set_sgr_code([0])

        elif (command == 'J' or # ED - Erase Data
              command == 'K'):  # EL - Erase in Line
            code = params[0] if params else 0
            if 0 <= code <= 2:
                area = 'screen' if command == 'J' else 'line'
                if code == 0:
                    erase_to = 'end'
                elif code == 1:
                    erase_to = 'start'
                elif code == 2:
                    erase_to = 'all'
                self.actions.append(EraseAction('erase', area, erase_to))

        elif (command == 'S' or # SU - Scroll Up
              command == 'T'):  # SD - Scroll Down
            dir = 'up' if command == 'S' else 'down'
            count = params[0] if params else 1
            self.actions.append(ScrollAction('scroll', dir, 'line', count))

    def set_osc_code(self, params):
        """ Set attributes based on OSC (Operating System Command) parameters.

        Parameters
        ----------
        params : sequence of str
            The parameters for the command.  Consumed parameters are
            popped from this list.
        """
        try:
            command = int(params.pop(0))
        except (IndexError, ValueError):
            return

        if command == 4:
            # xterm-specific: set color number to color spec.
            try:
                color = int(params.pop(0))
                spec = params.pop(0)
                self.color_map[color] = self._parse_xterm_color_spec(spec)
            except (IndexError, ValueError):
                pass

    def set_sgr_code(self, params):
        """ Set attributes based on SGR (Select Graphic Rendition) codes.

        Parameters
        ----------
        params : sequence of ints
            A list of SGR codes for one or more SGR commands. Usually this
            sequence will have one element per command, although certain
            xterm-specific commands requires multiple elements.  The list
            is consumed (emptied) in place.
        """
        # Process one command per iteration until every parameter has been
        # consumed; unrecognised codes are skipped.
        while params:
            code = params.pop(0)

            if code == 0:
                self.reset_sgr()
            elif code == 1:
                if self.bold_text_enabled:
                    self.bold = True
                else:
                    self.intensity = 1
            elif code == 2:
                self.intensity = 0
            elif code == 3:
                self.italic = True
            elif code == 4:
                self.underline = True
            elif code == 22:
                self.intensity = 0
                self.bold = False
            elif code == 23:
                self.italic = False
            elif code == 24:
                self.underline = False
            elif code >= 30 and code <= 37:
                self.foreground_color = code - 30
            elif code == 38 and params and params.pop(0) == 5:
                # xterm-specific: 256 color support.
                if params:
                    self.foreground_color = params.pop(0)
            elif code == 39:
                self.foreground_color = None
            elif code >= 40 and code <= 47:
                self.background_color = code - 40
            elif code == 48 and params and params.pop(0) == 5:
                # xterm-specific: 256 color support.
                if params:
                    self.background_color = params.pop(0)
            elif code == 49:
                self.background_color = None

    #---------------------------------------------------------------------------
    # Protected interface
    #---------------------------------------------------------------------------

    def _parse_xterm_color_spec(self, spec):
        """ Convert an xterm color spec into a color_map value.

        ``rgb:`` and ``rgbi:`` specs become tuples of ints, ``'?'`` raises
        ValueError, and anything else is returned unchanged.
        """
        if spec.startswith('rgb:'):
            return tuple(map(lambda x: int(x, 16), spec[4:].split('/')))
        elif spec.startswith('rgbi:'):
            return tuple(map(lambda x: int(float(x) * 255),
                             spec[5:].split('/')))
        elif spec == '?':
            raise ValueError('Unsupported xterm color spec')
        return spec

    def _replace_special(self, match):
        """ ``re.sub`` callback: record a page scroll for a form feed and
        drop the matched character from the text.
        """
        special = match.group(1)
        if special == '\f':
            self.actions.append(ScrollAction('scroll', 'down', 'page', 1))
        return ''
284 284
285 285
class QtAnsiCodeProcessor(AnsiCodeProcessor):
    """ Translates ANSI escape codes into QTextCharFormats.
    """

    # A map from ANSI color codes to SVG color names or RGB(A) tuples.
    darkbg_color_map = {
        0  : 'black',       # black
        1  : 'darkred',     # red
        2  : 'darkgreen',   # green
        3  : 'brown',       # yellow
        4  : 'darkblue',    # blue
        5  : 'darkviolet',  # magenta
        6  : 'steelblue',   # cyan
        7  : 'grey',        # white
        8  : 'grey',        # black (bright)
        9  : 'red',         # red (bright)
        10 : 'lime',        # green (bright)
        11 : 'yellow',      # yellow (bright)
        12 : 'deepskyblue', # blue (bright)
        13 : 'magenta',     # magenta (bright)
        14 : 'cyan',        # cyan (bright)
        15 : 'white' }      # white (bright)

    # Set the default color map for super class.
    default_color_map = darkbg_color_map.copy()

    def get_color(self, color, intensity=0):
        """ Returns a QColor for a given color code, or None if one cannot be
            constructed.
        """
        if color is None:
            return None

        # Adjust for intensity, if possible.
        if color < 8 and intensity > 0:
            color += 8

        spec = self.color_map.get(color, None)
        if isinstance(spec, string_types):
            # If this is an X11 color name, we just hope there is a close SVG
            # color name. We could use QColor's static method
            # 'setAllowX11ColorNames()', but this is global and only available
            # on X11. It seems cleaner to aim for uniformity of behavior.
            return QtGui.QColor(spec)

        if isinstance(spec, (tuple, list)):
            return QtGui.QColor(*spec)

        return None

    def get_format(self):
        """ Returns a QTextCharFormat that encodes the current style attributes.
        """
        char_format = QtGui.QTextCharFormat()

        # Set foreground color
        foreground = self.get_color(self.foreground_color, self.intensity)
        if foreground is not None:
            char_format.setForeground(foreground)

        # Set background color
        background = self.get_color(self.background_color, self.intensity)
        if background is not None:
            char_format.setBackground(background)

        # Set font weight/style options
        weight = QtGui.QFont.Bold if self.bold else QtGui.QFont.Normal
        char_format.setFontWeight(weight)
        char_format.setFontItalic(self.italic)
        char_format.setFontUnderline(self.underline)

        return char_format

    def set_background_color(self, color):
        """ Given a background color (a QColor), attempt to set a color map
            that will be aesthetically pleasing.
        """
        # Set a new default color map.
        palette = self.default_color_map = self.darkbg_color_map.copy()

        if color.value() >= 127:
            # Colors appropriate for a terminal with a light background. For
            # now, only use non-bright colors...
            for index in range(8):
                palette[index + 8] = palette[index]

            # ...and replace white with black.
            palette[7] = palette[15] = 'black'

        # Update the current color map with the new defaults.
        self.color_map.update(palette)
@@ -1,35 +1,37 b''
1 1 # encoding: utf-8
2 2 """Utilities for working with data structures like lists, dicts and tuples.
3 3 """
4 4
5 5 #-----------------------------------------------------------------------------
6 6 # Copyright (C) 2008-2011 The IPython Development Team
7 7 #
8 8 # Distributed under the terms of the BSD License. The full license is in
9 9 # the file COPYING, distributed as part of this software.
10 10 #-----------------------------------------------------------------------------
11 11
12 from .py3compat import xrange
13
def uniq_stable(elems):
    """uniq_stable(elems) -> list

    Return a list holding each distinct element of the iterable `elems`
    once, in the order of first appearance.

    Note: All elements in the input must be hashable for this routine
    to work, as it internally uses a set for efficiency reasons.
    """
    seen = set()
    unique = []
    for item in elems:
        if item not in seen:
            seen.add(item)
            unique.append(item)
    return unique
23 25
24 26
def flatten(seq):
    """Flatten a list of lists (NOT recursive, only works for 2d lists)."""
    flat = []
    for subseq in seq:
        flat.extend(subseq)
    return flat
29 31
30 32
def chop(seq, size):
    """Chop a sequence into chunks of the given size.

    Returns a list of consecutive slices of `seq`; the last one is
    shorter when ``len(seq)`` is not a multiple of `size`.
    """
    chunks = []
    for start in xrange(0,len(seq),size):
        chunks.append(seq[start:start+size])
    return chunks
34 36
35 37
@@ -1,207 +1,210 b''
1 1 # coding: utf-8
2 2 """Compatibility tricks for Python 3. Mainly to do with unicode."""
3 3 import functools
4 4 import sys
5 5 import re
6 6 import types
7 7
8 8 from .encoding import DEFAULT_ENCODING
9 9
10 10 orig_open = open
11 11
def no_code(x, encoding=None):
    """Return `x` unchanged; `encoding` is accepted and ignored."""
    return x
14 14
def decode(s, encoding=None):
    """Decode `s` with `encoding`, replacing undecodable bytes.

    A missing or empty `encoding` falls back to DEFAULT_ENCODING.
    """
    if not encoding:
        encoding = DEFAULT_ENCODING
    return s.decode(encoding, "replace")
18 18
def encode(u, encoding=None):
    """Encode `u` with `encoding`, replacing unencodable characters.

    A missing or empty `encoding` falls back to DEFAULT_ENCODING.
    """
    if not encoding:
        encoding = DEFAULT_ENCODING
    return u.encode(encoding, "replace")
22 22
23 23
def cast_unicode(s, encoding=None):
    """Return `s` decoded via `decode` if it is ``bytes``, else unchanged."""
    return decode(s, encoding) if isinstance(s, bytes) else s
28 28
def cast_bytes(s, encoding=None):
    """Return `s` unchanged if it is ``bytes``, else encoded via `encode`."""
    return s if isinstance(s, bytes) else encode(s, encoding)
33 33
34 34 def _modify_str_or_docstring(str_change_func):
35 35 @functools.wraps(str_change_func)
36 36 def wrapper(func_or_str):
37 37 if isinstance(func_or_str, string_types):
38 38 func = None
39 39 doc = func_or_str
40 40 else:
41 41 func = func_or_str
42 42 doc = func.__doc__
43 43
44 44 doc = str_change_func(doc)
45 45
46 46 if func:
47 47 func.__doc__ = doc
48 48 return func
49 49 return doc
50 50 return wrapper
51 51
def safe_unicode(e):
    """unicode(e) with various fallbacks. Used for exceptions, which may not be
    safe to call unicode() on.

    Tries ``unicode_type(e)``, then ``str(e)`` and finally ``repr(e)``
    (both converted with ``str_to_unicode``), moving on only when a
    UnicodeError is raised.  If all three fail, a fixed placeholder
    string is returned.
    """
    attempts = (
        lambda: unicode_type(e),
        lambda: str_to_unicode(str(e)),
        lambda: str_to_unicode(repr(e)),
    )
    for attempt in attempts:
        try:
            return attempt()
        except UnicodeError:
            continue

    return u'Unrecoverably corrupt evalue'
72 72
# Everything below defines the same set of names for Python 3 and for
# Python 2, so that other modules can import them from here without
# caring which interpreter is running.
if sys.version_info[0] >= 3:
    PY3 = True

    input = input
    builtin_mod_name = "builtins"
    import builtins as builtin_mod

    str_to_unicode = no_code
    unicode_to_str = no_code
    str_to_bytes = encode
    bytes_to_str = decode
    cast_bytes_py2 = no_code

    string_types = (str,)
    unicode_type = str

    def isidentifier(s, dotted=False):
        """Is `s` a valid identifier (or dotted name, if `dotted` is true)?"""
        if dotted:
            return all(isidentifier(a) for a in s.split("."))
        return s.isidentifier()

    open = orig_open
    xrange = range

    MethodType = types.MethodType

    def execfile(fname, glob, loc=None):
        """Compile and execute the file `fname` in the given namespaces."""
        loc = loc if (loc is not None) else glob
        with open(fname, 'rb') as f:
            exec(compile(f.read(), fname, 'exec'), glob, loc)

    # Refactor print statements in doctests.
    _print_statement_re = re.compile(r"\bprint (?P<expr>.*)$", re.MULTILINE)
    def _print_statement_sub(match):
        # group('expr') returns the captured text.  groups('expr') would
        # return a tuple of all groups, using 'expr' as a default value.
        expr = match.group('expr')
        return "print(%s)" % expr

    @_modify_str_or_docstring
    def doctest_refactor_print(doc):
        """Refactor 'print x' statements in a doctest to print(x) style. 2to3
        unfortunately doesn't pick up on our doctests.

        Can accept a string or a function, so it can be used as a decorator."""
        return _print_statement_re.sub(_print_statement_sub, doc)

    # Abstract u'abc' syntax:
    @_modify_str_or_docstring
    def u_format(s):
        """"{u}'abc'" --> "'abc'" (Python 3)

        Accepts a string or a function, so it can be used as a decorator."""
        return s.format(u='')

else:
    PY3 = False

    input = raw_input
    builtin_mod_name = "__builtin__"
    import __builtin__ as builtin_mod

    str_to_unicode = decode
    unicode_to_str = encode
    str_to_bytes = no_code
    bytes_to_str = no_code
    cast_bytes_py2 = cast_bytes

    string_types = (str, unicode)
    unicode_type = unicode

    import re
    _name_re = re.compile(r"[a-zA-Z_][a-zA-Z0-9_]*$")
    def isidentifier(s, dotted=False):
        """Is `s` a valid identifier (or dotted name, if `dotted` is true)?"""
        if dotted:
            return all(isidentifier(a) for a in s.split("."))
        return bool(_name_re.match(s))

    class open(object):
        """Wrapper providing key part of Python 3 open() interface."""
        def __init__(self, fname, mode="r", encoding="utf-8"):
            self.f = orig_open(fname, mode)
            self.enc = encoding

        def write(self, s):
            return self.f.write(s.encode(self.enc))

        def read(self, size=-1):
            return self.f.read(size).decode(self.enc)

        def close(self):
            return self.f.close()

        def __enter__(self):
            return self

        def __exit__(self, etype, value, traceback):
            self.f.close()

    xrange = xrange

    def MethodType(func, instance):
        """Bind `func` to `instance`, using the instance's type as class."""
        return types.MethodType(func, instance, type(instance))

    # don't override system execfile on 2.x:
    execfile = execfile

    def doctest_refactor_print(func_or_str):
        """No-op on Python 2: return the argument unchanged."""
        return func_or_str


    # Abstract u'abc' syntax:
    @_modify_str_or_docstring
    def u_format(s):
        """"{u}'abc'" --> "u'abc'" (Python 2)

        Accepts a string or a function, so it can be used as a decorator."""
        return s.format(u='u')

    if sys.platform == 'win32':
        def execfile(fname, glob=None, loc=None):
            loc = loc if (loc is not None) else glob
            # The rstrip() is necessary b/c trailing whitespace in files will
            # cause an IndentationError in Python 2.6 (this was fixed in 2.7,
            # but we still support 2.6).  See issue 1027.
            scripttext = builtin_mod.open(fname).read().rstrip() + '\n'
            # compile converts unicode filename to str assuming
            # ascii. Let's do the conversion before calling compile
            if isinstance(fname, unicode):
                filename = unicode_to_str(fname)
            else:
                filename = fname
            exec(compile(scripttext, filename, 'exec'), glob, loc)
    else:
        def execfile(fname, *where):
            if isinstance(fname, unicode):
                filename = fname.encode(sys.getfilesystemencoding())
            else:
                filename = fname
            builtin_mod.execfile(filename, *where)
@@ -1,759 +1,758 b''
1 1 # encoding: utf-8
2 2 """
3 3 Utilities for working with strings and text.
4 4
5 5 Inheritance diagram:
6 6
7 7 .. inheritance-diagram:: IPython.utils.text
8 8 :parts: 3
9 9 """
10 10
11 11 #-----------------------------------------------------------------------------
12 12 # Copyright (C) 2008-2011 The IPython Development Team
13 13 #
14 14 # Distributed under the terms of the BSD License. The full license is in
15 15 # the file COPYING, distributed as part of this software.
16 16 #-----------------------------------------------------------------------------
17 17
18 18 #-----------------------------------------------------------------------------
19 19 # Imports
20 20 #-----------------------------------------------------------------------------
21 21
22 22 import os
23 23 import re
24 24 import sys
25 25 import textwrap
26 26 from string import Formatter
27 27
28 28 from IPython.external.path import path
29 29 from IPython.testing.skipdoctest import skip_doctest_py3, skip_doctest
30 30 from IPython.utils import py3compat
31 31
32
33 32 #-----------------------------------------------------------------------------
34 33 # Declarations
35 34 #-----------------------------------------------------------------------------
36 35
# datetime.strftime date format for ipython
# NOTE(review): the win32 branch uses the zero-padded "%d" instead of "%-d",
# presumably because that platform's strftime lacks the "-" flag -- confirm.
if sys.platform == 'win32':
    date_format = "%B %d, %Y"
else:
    date_format = "%B %-d, %Y"
42 41
43 42
44 43 #-----------------------------------------------------------------------------
45 44 # Code
46 45 #-----------------------------------------------------------------------------
47 46
class LSString(str):
    """String subclass with extra shell-friendly views of its value.

    An LSString behaves like a normal string and adds these attributes:

        .l (or .list) : the value split on newlines, as a list.
        .n (or .nlstr): the string itself.
        .s (or .spstr): the value with newlines replaced by spaces.
        .p (or .paths): path objects for the lines that name existing paths.

    Views that need a transformation are computed on first access and then
    cached on the instance.

    This is handy when talking to the shell, which typically only
    understands whitespace-separated options for commands."""

    def get_list(self):
        """Return the value split on newlines (cached after first use)."""
        try:
            cached = self.__list
        except AttributeError:
            cached = self.__list = self.split('\n')
        return cached

    l = list = property(get_list)

    def get_spstr(self):
        """Return the value with every newline replaced by a space (cached)."""
        try:
            cached = self.__spstr
        except AttributeError:
            cached = self.__spstr = self.replace('\n',' ')
        return cached

    s = spstr = property(get_spstr)

    def get_nlstr(self):
        """Return the string itself."""
        return self

    n = nlstr = property(get_nlstr)

    def get_paths(self):
        """Return path objects for the lines that exist on disk (cached)."""
        try:
            cached = self.__paths
        except AttributeError:
            existing = [path(p) for p in self.split('\n') if os.path.exists(p)]
            cached = self.__paths = existing
        return cached

    p = paths = property(get_paths)
95 94
96 95 # FIXME: We need to reimplement type specific displayhook and then add this
97 96 # back as a custom printer. This should also be moved outside utils into the
98 97 # core.
99 98
100 99 # def print_lsstring(arg):
101 100 # """ Prettier (non-repr-like) and more informative printer for LSString """
102 101 # print "LSString (.p, .n, .l, .s available). Value:"
103 102 # print arg
104 103 #
105 104 #
106 105 # print_lsstring = result_display.when_type(LSString)(print_lsstring)
107 106
108 107
class SList(list):
    """List subclass with extra shell-friendly views and helpers.

    An SList behaves like a normal list and adds these attributes:

        .l (or .list) : the list itself.
        .n (or .nlstr): the items joined with newlines.
        .s (or .spstr): the items joined with spaces.
        .p (or .paths): path objects for the items that name existing paths.

    Views that need a transformation are computed on first access and then
    cached on the instance."""

    def get_list(self):
        """Return the list itself."""
        return self

    l = list = property(get_list)

    def get_spstr(self):
        """Return the items joined with single spaces (cached)."""
        try:
            cached = self.__spstr
        except AttributeError:
            cached = self.__spstr = ' '.join(self)
        return cached

    s = spstr = property(get_spstr)

    def get_nlstr(self):
        """Return the items joined with newlines (cached)."""
        try:
            cached = self.__nlstr
        except AttributeError:
            cached = self.__nlstr = '\n'.join(self)
        return cached

    n = nlstr = property(get_nlstr)

    def get_paths(self):
        """Return path objects for the items that exist on disk (cached)."""
        try:
            cached = self.__paths
        except AttributeError:
            cached = self.__paths = [path(p) for p in self if os.path.exists(p)]
        return cached

    p = paths = property(get_paths)

    def grep(self, pattern, prune = False, field = None):
        """ Return the items matching 'pattern' (a regex string or a callable)

        String patterns are searched case-insensitively.  If prune is true,
        the items NOT matching the pattern are returned instead.

        If field is given, only that whitespace-separated field of each item
        is tested (a missing field is tested as the empty string).

        Examples::

            a.grep( lambda x: x.startswith('C') )
            a.grep('Cha.*log', prune=1)
            a.grep('chm', field=-1)
        """

        def target_of(line):
            # the text the predicate is applied to for one item
            if field is None:
                return line
            words = line.split()
            try:
                return words[field]
            except IndexError:
                return ""

        if isinstance(pattern, py3compat.string_types):
            def matches(text):
                return re.search(pattern, text, re.IGNORECASE)
        else:
            matches = pattern

        wanted = not prune
        return SList([el for el in self
                      if bool(matches(target_of(el))) == wanted])

    def fields(self, *fields):
        """ Collect whitespace-separated fields from the items

        Allows quick awk-like usage of string lists.

        Example data (in var a, created by 'a = !ls -l')::
            -rwxrwxrwx 1 ville None 18 Dec 14 2006 ChangeLog
            drwxrwxrwx+ 6 ville None 0 Oct 24 18:05 IPython

        a.fields(0) is ['-rwxrwxrwx', 'drwxrwxrwx+']
        a.fields(1,0) is ['1 -rwxrwxrwx', '6 drwxrwxrwx+']
        (note the joining by space).
        a.fields(-1) is ['ChangeLog', 'IPython']

        Field indexes that do not exist in an item are skipped, and an item
        that yields no fields at all is left out of the result.

        Without args, fields() just split()'s every item and returns a plain
        list of lists.
        """
        if not fields:
            return [line.split() for line in self]

        collected = SList()
        for line in self:
            words = line.split()
            picked = []
            for idx in fields:
                try:
                    picked.append(words[idx])
                except IndexError:
                    continue
            if picked:
                collected.append(" ".join(picked))
        return collected

    def sort(self,field= None, nums = False):
        """ Return a new SList sorted by the given field (see fields())

        Example::
            a.sort(1, nums = True)

        Sorts a by second field, in numerical order (so that 21 > 3)
        """
        # decorate, sort, undecorate
        if field is None:
            decorated = [[line, line] for line in self]
        else:
            decorated = [[SList([line]).fields(field), line] for line in self]

        if nums:
            for entry in decorated:
                digits = "".join([ch for ch in entry[0] if ch.isdigit()])
                try:
                    entry[0] = int(digits)
                except ValueError:
                    # no digits at all: sort as zero
                    entry[0] = 0

        decorated.sort()
        return SList([pair[1] for pair in decorated])
251 250
252 251
253 252 # FIXME: We need to reimplement type specific displayhook and then add this
254 253 # back as a custom printer. This should also be moved outside utils into the
255 254 # core.
256 255
257 256 # def print_slist(arg):
258 257 # """ Prettier (non-repr-like) and more informative printer for SList """
259 258 # print "SList (.p, .n, .l, .s, .grep(), .fields(), sort() available):"
260 259 # if hasattr(arg, 'hideonce') and arg.hideonce:
261 260 # arg.hideonce = False
262 261 # return
263 262 #
264 263 # nlprint(arg) # This was a nested list printer, now removed.
265 264 #
266 265 # print_slist = result_display.when_type(SList)(print_slist)
267 266
268 267
def indent(instr,nspaces=4, ntabs=0, flatten=False):
    """Indent a string a given number of spaces or tabstops.

    indent(str,nspaces=4,ntabs=0) -> indent str by ntabs+nspaces.

    Parameters
    ----------

    instr : basestring
        The string to be indented.
    nspaces : int (default: 4)
        The number of spaces to be indented.
    ntabs : int (default: 0)
        The number of tabs to be indented.
    flatten : bool (default: False)
        Whether to scrub existing indentation.  If True, all lines will be
        aligned to the same indentation.  If False, existing indentation will
        be strictly increased.

    Returns
    -------

    str|unicode : string indented by ntabs and nspaces, or None when
    `instr` is None.

    """
    if instr is None:
        return
    ind = '\t'*ntabs+' '*nspaces
    if flatten:
        pat = re.compile(r'^\s*', re.MULTILINE)
    else:
        pat = re.compile(r'^', re.MULTILINE)
    outstr = re.sub(pat, ind, instr)
    # The pattern also matches after a trailing newline, which leaves a
    # dangling indent at the very end; drop it.  The `ind` guard matters:
    # with an empty indent, outstr[:-0] would return an empty string.
    if ind and outstr.endswith(os.linesep+ind):
        return outstr[:-len(ind)]
    else:
        return outstr
306 305
307 306
def list_strings(arg):
    """Return a list of strings for either a single string or a list of them.

    A string is wrapped in a one-element list; anything else is returned
    unchanged.

    :Examples:

        In [7]: list_strings('A single string')
        Out[7]: ['A single string']

        In [8]: list_strings(['A single string in a list'])
        Out[8]: ['A single string in a list']

        In [9]: list_strings(['A','list','of','strings'])
        Out[9]: ['A', 'list', 'of', 'strings']
    """
    if isinstance(arg, py3compat.string_types):
        return [arg]
    return arg
326 325
327 326
328 327 def marquee(txt='',width=78,mark='*'):
329 328 """Return the input string centered in a 'marquee'.
330 329
331 330 :Examples:
332 331
333 332 In [16]: marquee('A test',40)
334 333 Out[16]: '**************** A test ****************'
335 334
336 335 In [17]: marquee('A test',40,'-')
337 336 Out[17]: '---------------- A test ----------------'
338 337
339 338 In [18]: marquee('A test',40,' ')
340 339 Out[18]: ' A test '
341 340
342 341 """
343 342 if not txt:
344 343 return (mark*width)[:width]
345 344 nmark = (width-len(txt)-2)//len(mark)//2
346 345 if nmark < 0: nmark =0
347 346 marks = mark*nmark
348 347 return '%s %s %s' % (marks,txt,marks)
349 348
350 349
ini_spaces_re = re.compile(r'^(\s+)')

def num_ini_spaces(strng):
    """Return the length of the leading whitespace run in a string (0 if none)"""

    found = ini_spaces_re.match(strng)
    return found.end() if found else 0
361 360
362 361
def format_screen(strng):
    """Format a string for screen printing.

    This strips the backslash found at the end of a line (a latex-type
    paragraph-continue code)."""
    # Paragraph continue
    continuation = re.compile(r'\\$',re.MULTILINE)
    return continuation.sub('',strng)
371 370
372 371
def dedent(text):
    """Equivalent of textwrap.dedent that ignores unindented first line.

    This means it will still dedent strings like:
    '''foo
    is a bar
    '''

    For use in wrap_paragraphs.
    """
    first, newline, rest = text.partition('\n')
    if not newline or not first:
        # either a single line, or the text starts with a blank line:
        # the first line needs no special treatment
        return textwrap.dedent(text)

    # dedent everything but the first line
    return '\n'.join([first, textwrap.dedent(rest)])
398 397
399 398
def wrap_paragraphs(text, ncols=80):
    """Wrap multiple paragraphs to fit a specified width.

    This is equivalent to textwrap.wrap, but with support for multiple
    paragraphs, as separated by empty lines.

    Returns
    -------

    list of complete paragraphs, wrapped to fill `ncols` columns.
    """
    paragraph_re = re.compile(r'\n(\s*\n)+', re.MULTILINE)
    indent_re = re.compile(r'\n\s+', re.MULTILINE)

    stripped = dedent(text).strip()
    # the split keeps the captured separator, so every other entry is space
    paragraphs = paragraph_re.split(stripped)[::2]

    # presume indentation that survives dedent is meaningful formatting,
    # so only fill paragraphs whose text is flush.
    return [textwrap.fill(p, ncols) if indent_re.search(p) is None else p
            for p in paragraphs]
424 423
425 424
def long_substr(data):
    """Return the longest common substring in a list of strings.

    A single-element list returns that element; an empty list, or one whose
    first string is empty, returns ''.

    Credit: http://stackoverflow.com/questions/2892931/longest-common-substring-from-more-than-two-strings-python
    """
    if len(data) == 1:
        return data[0]
    if not (len(data) > 1 and len(data[0]) > 0):
        return ''

    reference = data[0]
    best = ''
    for start in range(len(reference)):
        # only candidates strictly longer than the current best can win
        for length in range(len(best) + 1, len(reference) - start + 1):
            candidate = reference[start:start + length]
            if all(candidate in other for other in data):
                best = candidate
    return best
440 439
441 440
def strip_email_quotes(text):
    """Strip leading email quotation characters ('>').

    Removes any combination of leading '>' interspersed with whitespace that
    appears *identically* in all lines of the input text.

    Parameters
    ----------
    text : str

    Examples
    --------

    Simple uses::

        In [2]: strip_email_quotes('> > text')
        Out[2]: 'text'

        In [3]: strip_email_quotes('> > text\\n> > more')
        Out[3]: 'text\\nmore'

    Note how only the common prefix that appears in all lines is stripped::

        In [4]: strip_email_quotes('> > text\\n> > more\\n> more...')
        Out[4]: '> text\\n> more\\nmore...'

    So if any line has no quote marks ('>') , then none are stripped from any
    of them ::

        In [5]: strip_email_quotes('> > text\\n> > more\\nlast different')
        Out[5]: '> > text\\n> > more\\nlast different'
    """
    quote_re = re.compile(r'^(\s*>[ >]*)')
    lines = text.splitlines()

    prefixes = set()
    for line in lines:
        found = quote_re.match(line)
        if found is None:
            # one unquoted line means nothing gets stripped at all
            return text
        prefixes.add(found.group(1))

    common = long_substr(list(prefixes))
    if common:
        cut = len(common)
        text = '\n'.join([ln[cut:] for ln in lines])
    return text
488 487
489 488
class EvalFormatter(Formatter):
    """A String Formatter that allows evaluation of simple expressions.

    Each replacement field is evaluated with ``eval`` in the namespace of the
    keyword arguments, so templates must come from a trusted source.

    Note that this version interprets a : as specifying a format string (as per
    standard string formatting), so if slicing is required, you must explicitly
    create a slice.

    This is to be used in templating cases, such as the parallel batch
    script templates, where simple arithmetic on arguments is useful.

    Examples
    --------

    In [1]: f = EvalFormatter()
    In [2]: f.format('{n//4}', n=8)
    Out [2]: '2'

    In [3]: f.format("{greeting[slice(2,4)]}", greeting="Hello")
    Out [3]: 'll'
    """
    def get_field(self, name, args, kwargs):
        """Evaluate the field `name` in `kwargs`; return (value, name)."""
        return eval(name, kwargs), name
513 512
514 513
@skip_doctest_py3
class FullEvalFormatter(Formatter):
    """A String Formatter that allows evaluation of simple expressions.

    Any time a format key is not found in the kwargs,
    it will be tried as an expression in the kwargs namespace.

    Fields are evaluated with ``eval``, so templates must come from a
    trusted source.

    Note that this version allows slicing using [1:2], so you cannot specify
    a format string. Use :class:`EvalFormatter` to permit format strings.

    Examples
    --------

    In [1]: f = FullEvalFormatter()
    In [2]: f.format('{n//4}', n=8)
    Out[2]: u'2'

    In [3]: f.format('{list(range(5))[2:4]}')
    Out[3]: u'[2, 3]'

    In [4]: f.format('{3*2}')
    Out[4]: u'6'
    """
    def vformat(self, format_string, args, kwargs):
        """Format `format_string` using the eval-based `_vformat` below.

        Overridden because `_vformat` is a private hook of the standard
        library: newer Python 3 releases call it with a different contract
        (they unpack a tuple from its result), which the plain-string return
        of the override below does not satisfy.  Calling it directly keeps
        the public `format()` entry point independent of that detail.
        """
        used_args = set()
        result = self._vformat(format_string, args, kwargs, used_args, 2)
        self.check_unused_args(used_args, args, kwargs)
        return result

    # copied from Formatter._vformat with minor changes to allow eval
    # and replace the format_spec code with slicing
    def _vformat(self, format_string, args, kwargs, used_args, recursion_depth):
        """Build the formatted unicode string, evaluating every field."""
        if recursion_depth < 0:
            raise ValueError('Max string recursion exceeded')
        result = []
        for literal_text, field_name, format_spec, conversion in \
                self.parse(format_string):

            # output the literal text
            if literal_text:
                result.append(literal_text)

            # if there's a field, output it
            if field_name is not None:
                # this is some markup, find the object and do
                # the formatting

                if format_spec:
                    # override format spec, to allow slicing:
                    field_name = ':'.join([field_name, format_spec])

                # eval the contents of the field for the object
                # to be formatted
                obj = eval(field_name, kwargs)

                # do any conversion on the resulting object
                obj = self.convert_field(obj, conversion)

                # format the object and append to the result
                result.append(self.format_field(obj, ''))

        return u''.join(py3compat.cast_unicode(s) for s in result)
571 570
572 571
@skip_doctest_py3
class DollarFormatter(FullEvalFormatter):
    """Formatter allowing Itpl style $foo replacement, for names and attribute
    access only. Standard {foo} replacement also works, and allows full
    evaluation of its arguments.

    A doubled ``$$foo`` is emitted as the literal text ``$foo``.

    Examples
    --------
    In [1]: f = DollarFormatter()
    In [2]: f.format('{n//4}', n=8)
    Out[2]: u'2'

    In [3]: f.format('23 * 76 is $result', result=23*76)
    Out[3]: u'23 * 76 is 1748'

    In [4]: f.format('$a or {b}', a=1, b=2)
    Out[4]: u'1 or 2'
    """
    # Raw string: the pattern text is unchanged, but the backslashes are no
    # longer (invalid) string escape sequences.
    _dollar_pattern = re.compile(r"(.*?)\$(\$?[\w\.]+)")

    def parse(self, fmt_string):
        """Yield (literal, field, spec, conversion) tuples for `fmt_string`.

        Works like `Formatter.parse`, but additionally splits every
        ``$name`` found in the literal text into its own field.
        """
        for literal_txt, field_name, format_spec, conversion \
                in Formatter.parse(self, fmt_string):

            # Find $foo patterns in the literal text.
            continue_from = 0
            txt = ""
            for m in self._dollar_pattern.finditer(literal_txt):
                new_txt, new_field = m.group(1,2)
                # $$foo --> $foo
                if new_field.startswith("$"):
                    txt += new_txt + new_field
                else:
                    yield (txt + new_txt, new_field, "", None)
                    txt = ""
                continue_from = m.end()

            # Re-yield the {foo} style pattern
            yield (txt + literal_txt[continue_from:], field_name, format_spec, conversion)
611 610
612 611 #-----------------------------------------------------------------------------
613 612 # Utils to columnize a list of string
614 613 #-----------------------------------------------------------------------------
615 614
def _chunks(l, n):
    """Yield successive n-sized chunks from l (the last one may be shorter)."""
    total = len(l)
    for start in py3compat.xrange(0, total, n):
        yield l[start:start+n]
620 619
621 620
def _find_optimal(rlist , separator_size=2 , displaywidth=80):
    """Calculate optimal info to columnize a list of string

    Parameters
    ----------
    rlist : iterable of int
        Length of every item, in display order.
    separator_size : int (default=2)
        Number of characters between two columns.
    displaywidth : int (default=80)
        Width the columns have to fit in.

    Returns
    -------
    dict with the keys 'columns_numbers', 'rows_numbers', 'columns_width'
    (list of column widths) and 'optimal_separator_width'.  Tries row counts
    from 1 upwards and stops at the first layout that fits; if none fits,
    the single-column layout is returned.  All values are zero/empty for an
    empty `rlist`.
    """
    # Materialize first: callers may pass a lazy iterable such as the
    # result of map() on Python 3, which supports neither len() nor reuse.
    rlist = list(rlist)
    if not rlist:
        return {'columns_numbers' : 0,
                'optimal_separator_width': 0,
                'rows_numbers' : 0,
                'columns_width' : []
               }
    for nrow in range(1, len(rlist)+1) :
        # width of each column = widest item in that column
        chk = [max(column) for column in _chunks(rlist, nrow)]
        sumlength = sum(chk)
        ncols = len(chk)
        if sumlength+separator_size*(ncols-1) <= displaywidth :
            break
    # floor division keeps the separator width an integer on Python 3 too
    return {'columns_numbers' : ncols,
            'optimal_separator_width':(displaywidth - sumlength)//(ncols-1) if (ncols -1) else 0,
            'rows_numbers' : nrow,
            'columns_width' : chk
           }
635 634
636 635
637 636 def _get_or_default(mylist, i, default=None):
638 637 """return list item number, or default if don't exist"""
639 638 if i >= len(mylist):
640 639 return default
641 640 else :
642 641 return mylist[i]
643 642
644 643
@skip_doctest
def compute_item_matrix(items, empty=None, *args, **kwargs) :
    """Returns a nested list, and info to columnize items

    Parameters
    ----------

    items :
        list of strings to columnize
    empty : (default None)
        default value to fill list if needed
    separator_size : int (default=2)
        How many characters will be used as a separation between each column.
    displaywidth : int (default=80)
        The width of the area into which the columns should fit

    Returns
    -------

    Returns a tuple of (strings_matrix, dict_info)

    strings_matrix :

        nested list of strings, the outermost list contains as many lists as
        rows, the innermost lists each have as many elements as columns. If
        the total number of elements in `items` does not equal the product of
        rows*columns, the last elements of some lists are filled with `empty`.

    dict_info :
        some info to make columnize easier:

        columns_numbers : number of columns
        rows_numbers : number of rows
        columns_width : list of width of each column
        optimal_separator_width : best separator width between columns

    Examples
    --------

    In [1]: l = ['aaa','b','cc','d','eeeee','f','g','h','i','j','k','l']
       ...: compute_item_matrix(l,displaywidth=12)
    Out[1]:
        ([['aaa', 'f', 'k'],
        ['b', 'g', 'l'],
        ['cc', 'h', None],
        ['d', 'i', None],
        ['eeeee', 'j', None]],
        {'columns_numbers': 3,
        'columns_width': [5, 1, 1],
        'optimal_separator_width': 2,
        'rows_numbers': 5})

    """
    # Pass a real list: the helper calls len() on it, which a lazy map()
    # object (Python 3) does not support.
    lengths = [len(item) for item in items]
    info = _find_optimal(lengths, *args, **kwargs)
    nrow, ncol = info['rows_numbers'], info['columns_numbers']
    # items are laid out column by column: column c holds items
    # c*nrow .. c*nrow + nrow - 1
    matrix = [[_get_or_default(items, c*nrow+i, default=empty) for c in range(ncol)]
              for i in range(nrow)]
    return (matrix, info)
701 700
702 701
def columnize(items, separator='  ', displaywidth=80):
    """ Transform a list of strings into a single string with columns.

    Parameters
    ----------
    items : sequence of strings
        The strings to process.

    separator : str, optional [default is two spaces]
        The string that separates columns.

    displaywidth : int, optional [default is 80]
        Width of the display in number of characters.

    Returns
    -------
    The formatted string, one text line per row, ending with a newline.
    An empty `items` gives a single newline.
    """
    if not items :
        return '\n'
    matrix, info = compute_item_matrix(items, separator_size=len(separator), displaywidth=displaywidth)
    widths = info['columns_width']
    lines = []
    for row in matrix:
        # drop the padding entries (and any other empty cell) of short rows
        cells = [cell for cell in row if cell]
        lines.append(separator.join([cell.ljust(w, ' ') for cell, w in zip(cells, widths)]))
    return '\n'.join(lines)+'\n'
727 726
728 727
def get_text_list(list_, last_sep=' and ', sep=", ", wrap_item_with=""):
    """
    Return a string with a natural enumeration of items

    All items but the last are joined with `sep`; the last one is attached
    with `last_sep`.  When `wrap_item_with` is given, it is put on both
    sides of every item.

    >>> get_text_list(['a', 'b', 'c', 'd'])
    'a, b, c and d'
    >>> get_text_list(['a', 'b', 'c'], ' or ')
    'a, b or c'
    >>> get_text_list(['a', 'b', 'c'], ', ')
    'a, b, c'
    >>> get_text_list(['a', 'b'], ' or ')
    'a or b'
    >>> get_text_list(['a'])
    'a'
    >>> get_text_list([])
    ''
    >>> get_text_list(['a', 'b'], wrap_item_with="`")
    '`a` and `b`'
    >>> get_text_list(['a', 'b', 'c', 'd'], " = ", sep=" + ")
    'a + b + c = d'
    """
    if len(list_) == 0:
        return ''

    entries = list_
    if wrap_item_with:
        entries = ['%s%s%s' % (wrap_item_with, entry, wrap_item_with)
                   for entry in list_]

    if len(entries) == 1:
        return entries[0]

    leading = sep.join(entry for entry in entries[:-1])
    return '%s%s%s' % (leading, last_sep, entries[-1])
@@ -1,116 +1,118 b''
1 1 # encoding: utf-8
2 2 """
3 3 Utilities for timing code execution.
4 4 """
5 5
6 6 #-----------------------------------------------------------------------------
7 7 # Copyright (C) 2008-2011 The IPython Development Team
8 8 #
9 9 # Distributed under the terms of the BSD License. The full license is in
10 10 # the file COPYING, distributed as part of this software.
11 11 #-----------------------------------------------------------------------------
12 12
13 13 #-----------------------------------------------------------------------------
14 14 # Imports
15 15 #-----------------------------------------------------------------------------
16 16
17 17 import time
18 18
19 from .py3compat import xrange
20
19 21 #-----------------------------------------------------------------------------
20 22 # Code
21 23 #-----------------------------------------------------------------------------
22 24
# If possible (Unix), use the resource module instead of time.clock()
try:
    import resource
    def clocku():
        """clocku() -> floating point number

        Return the *USER* CPU time in seconds since the start of the process.
        This is done via a call to resource.getrusage, so it avoids the
        wraparound problems in time.clock()."""

        return resource.getrusage(resource.RUSAGE_SELF)[0]

    def clocks():
        """clocks() -> floating point number

        Return the *SYSTEM* CPU time in seconds since the start of the process.
        This is done via a call to resource.getrusage, so it avoids the
        wraparound problems in time.clock()."""

        return resource.getrusage(resource.RUSAGE_SELF)[1]

    def clock():
        """clock() -> floating point number

        Return the *TOTAL USER+SYSTEM* CPU time in seconds since the start of
        the process.  This is done via a call to resource.getrusage, so it
        avoids the wraparound problems in time.clock()."""

        u,s = resource.getrusage(resource.RUSAGE_SELF)[:2]
        return u+s

    def clock2():
        """clock2() -> (t_user,t_system)

        Similar to clock(), but return a tuple of user/system times."""
        return resource.getrusage(resource.RUSAGE_SELF)[:2]
except ImportError:
    # There is no distinction of user/system time under windows, so we just use
    # one wall-clock timer for everything.  time.clock() no longer exists in
    # recent Python 3 releases, so prefer time.perf_counter() when the
    # interpreter has it and only fall back to time.clock() otherwise.
    _wall_clock = getattr(time, 'perf_counter', None) or time.clock
    clocku = clocks = clock = _wall_clock
    def clock2():
        """Under windows, system CPU time can't be measured.

        This just returns clock() and zero."""
        return _wall_clock(),0.0
68 70
69 71
def timings_out(reps,func,*args,**kw):
    """timings_out(reps,func,*args,**kw) -> (t_total,t_per_call,output)

    Execute a function reps times, return a tuple with the elapsed total
    CPU time in seconds, the time per call and the function's output (the
    value returned by the last call).

    Under Unix, the return value is the sum of user+system time consumed by
    the process, computed via the resource module.  This prevents problems
    related to the wraparound effect which the time.clock() function has.

    Under Windows the return value is in wall clock seconds. See the
    documentation for the time module for more details."""

    reps = int(reps)
    assert reps >=1, 'reps must be >= 1'
    # every call but the last one discards its result; the range is built
    # before the clock starts so that it is not part of the measurement
    warm_calls = xrange(reps-1)
    start = clock()
    for _ in warm_calls:
        func(*args,**kw)
    out = func(*args,**kw)  # the last call provides the output
    tot_time = clock()-start
    return tot_time, tot_time / reps, out
97 99
98 100
def timings(reps,func,*args,**kw):
    """timings(reps,func,*args,**kw) -> (t_total,t_per_call)

    Execute a function reps times, return a tuple with the elapsed total CPU
    time in seconds and the time per call. These are just the first two values
    in timings_out()."""

    tot_time, av_time, _ = timings_out(reps,func,*args,**kw)
    return tot_time, av_time
107 109
108 110
def timing(func,*args,**kw):
    """timing(func,*args,**kw) -> t_total

    Execute a function once, return the elapsed total CPU time in
    seconds. This is just the first value in timings_out()."""

    tot_time, _, _ = timings_out(1,func,*args,**kw)
    return tot_time
116 118
General Comments 0
You need to be logged in to leave comments. Login now