##// END OF EJS Templates
Update function attribute names: replace the Python 2-only ``func.func_name`` with ``func.__name__``, which is valid on both Python 2 and Python 3.
Thomas Kluyver -
Show More
@@ -1,690 +1,690 b''
1 1 # encoding: utf-8
2 2 """Magic functions for InteractiveShell.
3 3 """
4 4 from __future__ import print_function
5 5
6 6 #-----------------------------------------------------------------------------
7 7 # Copyright (C) 2001 Janko Hauser <jhauser@zscout.de> and
8 8 # Copyright (C) 2001 Fernando Perez <fperez@colorado.edu>
9 9 # Copyright (C) 2008 The IPython Development Team
10 10
11 11 # Distributed under the terms of the BSD License. The full license is in
12 12 # the file COPYING, distributed as part of this software.
13 13 #-----------------------------------------------------------------------------
14 14
15 15 #-----------------------------------------------------------------------------
16 16 # Imports
17 17 #-----------------------------------------------------------------------------
18 18 # Stdlib
19 19 import os
20 20 import re
21 21 import sys
22 22 import types
23 23 from getopt import getopt, GetoptError
24 24
25 25 # Our own
26 26 from IPython.config.configurable import Configurable
27 27 from IPython.core import oinspect
28 28 from IPython.core.error import UsageError
29 29 from IPython.core.inputsplitter import ESC_MAGIC, ESC_MAGIC2
30 30 from IPython.external.decorator import decorator
31 31 from IPython.utils.ipstruct import Struct
32 32 from IPython.utils.process import arg_split
33 33 from IPython.utils.py3compat import string_types, iteritems
34 34 from IPython.utils.text import dedent
35 35 from IPython.utils.traitlets import Bool, Dict, Instance, MetaHasTraits
36 36 from IPython.utils.warn import error
37 37
38 38 #-----------------------------------------------------------------------------
39 39 # Globals
40 40 #-----------------------------------------------------------------------------
41 41
# A dict we'll use for each class that has magics, used as temporary storage to
# pass information between the @line/cell_magic method decorators and the
# @magics_class class decorator, because the method decorators have no
# access to the class when they run. See for more details:
# http://stackoverflow.com/questions/2366713/can-a-python-decorator-of-an-instance-method-access-the-class

magics = dict(line={}, cell={})

# The two concrete kinds of magic we support; `magic_spec` additionally
# accepts 'line_cell' for functions usable in both forms.
magic_kinds = ('line', 'cell')
magic_spec = ('line', 'cell', 'line_cell')
# Escape prefixes that invoke each kind of magic at the command line
# (e.g. '%' for line magics, '%%' for cell magics).
magic_escapes = dict(line=ESC_MAGIC, cell=ESC_MAGIC2)
53 53
54 54 #-----------------------------------------------------------------------------
55 55 # Utility classes and functions
56 56 #-----------------------------------------------------------------------------
57 57
class Bunch: pass  # Simple attribute container: instances just hold ad-hoc attributes.
59 59
60 60
def on_off(tag):
    """Return an ON/OFF string for a 1/0 input. Simple utility function."""
    # Index a constant pair: 0 -> 'OFF', 1 -> 'ON'.
    return ('OFF', 'ON')[tag]
64 64
65 65
66 66 def compress_dhist(dh):
67 67 """Compress a directory history into a new one with at most 20 entries.
68 68
69 69 Return a new list made from the first and last 10 elements of dhist after
70 70 removal of duplicates.
71 71 """
72 72 head, tail = dh[:-10], dh[-10:]
73 73
74 74 newhead = []
75 75 done = set()
76 76 for h in head:
77 77 if h in done:
78 78 continue
79 79 newhead.append(h)
80 80 done.add(h)
81 81
82 82 return newhead + tail
83 83
84 84
def needs_local_scope(func):
    """Decorator marking magic functions that need the local scope to run."""
    # The flag is inspected elsewhere at dispatch time; the function itself
    # is returned unchanged.
    setattr(func, 'needs_local_scope', True)
    return func
89 89
90 90 #-----------------------------------------------------------------------------
91 91 # Class and method decorators for registering magics
92 92 #-----------------------------------------------------------------------------
93 93
def magics_class(cls):
    """Class decorator for all subclasses of the main Magics class.

    Any class that subclasses Magics *must* also apply this decorator, to
    ensure that all the methods that have been decorated as line/cell magics
    get correctly registered in the class instance. This is necessary because
    when method decorators run, the class does not exist yet, so they
    temporarily store their information into a module global. Application of
    this class decorator copies that global data to the class instance and
    clears the global.

    Obviously, this mechanism is not thread-safe, which means that the
    *creation* of subclasses of Magic should only be done in a single-thread
    context. Instantiation of the classes has no restrictions. Given that
    these classes are typically created at IPython startup time and before user
    application code becomes active, in practice this should not pose any
    problems.
    """
    # Mark the class as processed and hand it the magics accumulated by the
    # method decorators, then reset the module-level staging area.
    cls.registered = True
    cls.magics = dict(line=magics['line'], cell=magics['cell'])
    for kind in ('line', 'cell'):
        magics[kind] = {}
    return cls
118 118
119 119
def record_magic(dct, magic_kind, magic_name, func):
    """Utility function to store a function as a magic of a specific kind.

    Parameters
    ----------
    dct : dict
        A dictionary with 'line' and 'cell' subdicts.

    magic_kind : str
        Kind of magic to be stored: 'line', 'cell', or 'line_cell' (which
        records the function under both subdicts).

    magic_name : str
        Key to store the magic as.

    func : function
        Callable object to store.
    """
    # 'line_cell' fans out to both concrete kinds; anything else targets
    # exactly one subdict.
    kinds = ('line', 'cell') if magic_kind == 'line_cell' else (magic_kind,)
    for kind in kinds:
        dct[kind][magic_name] = func
141 141
142 142
def validate_type(magic_kind):
    """Ensure that the given magic_kind is valid.

    Check that the given magic_kind is one of the accepted spec types (stored
    in the global `magic_spec`), raise ValueError otherwise.
    """
    if magic_kind not in magic_spec:
        # BUG FIX: the original applied % only to `magic_kinds` (binding
        # tighter than the comma) and passed `magic_kind` as a *second*
        # argument to ValueError, so the offending value never appeared in
        # the message.  It also listed `magic_kinds` while the membership
        # test uses `magic_spec`; report the set actually checked.
        raise ValueError('magic_kind must be one of %s, %s given' %
                         (magic_spec, magic_kind))
152 152
153 153
# The docstrings for the decorator below will be fairly similar for the two
# types (method and function), so we generate them here once and reuse the
# templates below.
# Placeholders: {0} is 'method' or 'function', {1} is the magic kind.
_docstring_template = \
"""Decorate the given {0} as {1} magic.

The decorator can be used with or without arguments, as follows.

i) without arguments: it will create a {1} magic named as the {0} being
decorated::

    @deco
    def foo(...)

will create a {1} magic named `foo`.

ii) with one string argument: which will be used as the actual name of the
resulting magic::

    @deco('bar')
    def foo(...)

will create a {1} magic named `bar`.
"""
178 178
179 179 # These two are decorator factories. While they are conceptually very similar,
180 180 # there are enough differences in the details that it's simpler to have them
181 181 # written as completely standalone functions rather than trying to share code
182 182 # and make a single one with convoluted logic.
183 183
def _method_magic_marker(magic_kind):
    """Decorator factory for methods in Magics subclasses.
    """

    validate_type(magic_kind)

    # A closure captures magic_kind; a full class would be overkill for this
    # single bit of state.
    def magic_deco(arg):
        call = lambda f, *a, **k: f(*a, **k)

        if isinstance(arg, string_types):
            # Decorator called with arguments (@foo('bar')): the string is
            # the magic's public name.
            magic_name = arg
            def mark(func, *a, **kw):
                # Record only the *method name*; the bound method is looked
                # up later, once an instance exists.
                record_magic(magics, magic_kind, magic_name, func.__name__)
                return decorator(call, func)
            return mark
        if callable(arg):
            # "Naked" decorator call (just @foo, no args): the function's
            # own name doubles as the magic name.
            wrapped = decorator(call, arg)
            record_magic(magics, magic_kind, arg.__name__, arg.__name__)
            return wrapped
        raise TypeError("Decorator can only be called with "
                        "string or function")

    # Ensure the resulting decorator has a usable docstring
    magic_deco.__doc__ = _docstring_template.format('method', magic_kind)
    return magic_deco
216 216
217 217
def _function_magic_marker(magic_kind):
    """Decorator factory for standalone functions.
    """
    validate_type(magic_kind)

    # A closure captures magic_kind; a full class would be overkill for this
    # single bit of state.
    def magic_deco(arg):
        call = lambda f, *a, **k: f(*a, **k)

        # Find get_ipython() in the caller's namespace, searching locals,
        # then globals, then builtins.
        caller = sys._getframe(1)
        get_ipython = None
        for ns_name in ('f_locals', 'f_globals', 'f_builtins'):
            get_ipython = getattr(caller, ns_name).get('get_ipython')
            if get_ipython is not None:
                break
        if get_ipython is None:
            raise NameError('Decorator can only run in context where '
                            '`get_ipython` exists')

        ip = get_ipython()

        if isinstance(arg, string_types):
            # Decorator called with arguments (@foo('bar')): the string is
            # the magic's public name.
            magic_name = arg
            def mark(func, *a, **kw):
                ip.register_magic_function(func, magic_kind, magic_name)
                return decorator(call, func)
            return mark
        if callable(arg):
            # "Naked" decorator call (just @foo, no args): register under
            # the function's own name.
            ip.register_magic_function(arg, magic_kind, arg.__name__)
            return decorator(call, arg)
        raise TypeError("Decorator can only be called with "
                        "string or function")

    # Ensure the resulting decorator has a usable docstring, including a note
    # about the runtime requirement on get_ipython().
    ds = _docstring_template.format('function', magic_kind)

    ds += dedent("""
    Note: this decorator can only be used in a context where IPython is already
    active, so that the `get_ipython()` call succeeds. You can therefore use
    it in your startup files loaded after IPython initializes, but *not* in the
    IPython configuration file itself, which is executed before IPython is
    fully up and running. Any file located in the `startup` subdirectory of
    your configuration profile will be OK in this sense.
    """)

    magic_deco.__doc__ = ds
    return magic_deco
272 272
273 273
# Create the actual decorators for public use

# These three are used to decorate methods in class definitions; they only
# record names, and binding happens later in Magics.__init__.
line_magic = _method_magic_marker('line')
cell_magic = _method_magic_marker('cell')
line_cell_magic = _method_magic_marker('line_cell')

# These three decorate standalone functions and perform the decoration
# immediately. They can only run where get_ipython() works
register_line_magic = _function_magic_marker('line')
register_cell_magic = _function_magic_marker('cell')
register_line_cell_magic = _function_magic_marker('line_cell')
286 286
287 287 #-----------------------------------------------------------------------------
288 288 # Core Magic classes
289 289 #-----------------------------------------------------------------------------
290 290
class MagicsManager(Configurable):
    """Object that handles all magic-related functionality for IPython.
    """
    # Non-configurable class attributes

    # A two-level dict, first keyed by magic type, then by magic function, and
    # holding the actual callable object as value. This is the dict used for
    # magic function dispatch
    magics = Dict

    # A registry of the original objects that we've been given holding magics.
    registry = Dict

    shell = Instance('IPython.core.interactiveshell.InteractiveShellABC')

    auto_magic = Bool(True, config=True, help=
        "Automatically call line magics without requiring explicit % prefix")

    def _auto_magic_changed(self, name, value):
        # Traitlets change handler: keep the shell's automagic flag in sync
        # with this configurable trait.
        self.shell.automagic = value

    # Human-readable status strings, indexed by the boolean auto_magic value.
    _auto_status = [
        'Automagic is OFF, % prefix IS needed for line magics.',
        'Automagic is ON, % prefix IS NOT needed for line magics.']

    user_magics = Instance('IPython.core.magics.UserMagics')

    def __init__(self, shell=None, config=None, user_magics=None, **traits):

        super(MagicsManager, self).__init__(shell=shell, config=config,
                                           user_magics=user_magics, **traits)
        self.magics = dict(line={}, cell={})
        # Let's add the user_magics to the registry for uniformity, so *all*
        # registered magic containers can be found there.
        self.registry[user_magics.__class__.__name__] = user_magics

    def auto_status(self):
        """Return descriptive string with automagic status."""
        return self._auto_status[self.auto_magic]

    def lsmagic(self):
        """Return a dict of currently available magic functions.

        The return dict has the keys 'line' and 'cell', corresponding to the
        two types of magics we support.  Each value is a dict mapping magic
        names to the objects that implement them.
        """
        return self.magics

    def lsmagic_docs(self, brief=False, missing=''):
        """Return dict of documentation of magic functions.

        The return dict has the keys 'line' and 'cell', corresponding to the
        two types of magics we support. Each value is a dict keyed by magic
        name whose value is the function docstring. If a docstring is
        unavailable, the value of `missing` is used instead.

        If brief is True, only the first line of each docstring will be returned.
        """
        docs = {}
        for m_type in self.magics:
            m_docs = {}
            for m_name, m_func in iteritems(self.magics[m_type]):
                if m_func.__doc__:
                    if brief:
                        # Keep only the summary line of the docstring.
                        m_docs[m_name] = m_func.__doc__.split('\n', 1)[0]
                    else:
                        m_docs[m_name] = m_func.__doc__.rstrip()
                else:
                    m_docs[m_name] = missing
            docs[m_type] = m_docs
        return docs

    def register(self, *magic_objects):
        """Register one or more instances of Magics.

        Take one or more classes or instances of classes that subclass the main
        `core.Magic` class, and register them with IPython to use the magic
        functions they provide.  The registration process will then ensure that
        any methods that have been decorated to provide line and/or cell magics
        will be recognized with the `%x`/`%%x` syntax as a line/cell magic
        respectively.

        If classes are given, they will be instantiated with the default
        constructor.  If your classes need a custom constructor, you should
        instantiate them first and pass the instance.

        The provided arguments can be an arbitrary mix of classes and instances.

        Parameters
        ----------
        magic_objects : one or more classes or instances
        """
        # Start by validating them to ensure they have all had their magic
        # methods registered at the instance level
        for m in magic_objects:
            if not m.registered:
                # BUG FIX: the original message never applied the % operator,
                # so '%r' appeared literally, and it named a nonexistent
                # '@register_magics' decorator; the real one is @magics_class.
                raise ValueError("Class of magics %r was constructed without "
                                 "the @magics_class class decorator" % (m,))
            if type(m) in (type, MetaHasTraits):
                # If we're given an uninstantiated class
                m = m(shell=self.shell)

            # Now that we have an instance, we can register it and update the
            # table of callables
            self.registry[m.__class__.__name__] = m
            for mtype in magic_kinds:
                self.magics[mtype].update(m.magics[mtype])

    def register_function(self, func, magic_kind='line', magic_name=None):
        """Expose a standalone function as magic function for IPython.

        This will create an IPython magic (line, cell or both) from a
        standalone function.  The functions should have the following
        signatures:

        * For line magics: `def f(line)`
        * For cell magics: `def f(line, cell)`
        * For a function that does both: `def f(line, cell=None)`

        In the latter case, the function will be called with `cell==None` when
        invoked as `%f`, and with cell as a string when invoked as `%%f`.

        Parameters
        ----------
        func : callable
            Function to be registered as a magic.

        magic_kind : str
            Kind of magic, one of 'line', 'cell' or 'line_cell'

        magic_name : optional str
            If given, the name the magic will have in the IPython namespace.  By
            default, the name of the function itself is used.
        """

        # Create the new method in the user_magics and register it in the
        # global table
        validate_type(magic_kind)
        magic_name = func.__name__ if magic_name is None else magic_name
        setattr(self.user_magics, magic_name, func)
        record_magic(self.magics, magic_kind, magic_name, func)

    def define_magic(self, name, func):
        """[Deprecated] Expose own function as magic function for IPython.

        Example::

            def foo_impl(self, parameter_s=''):
                'My very own magic!. (Use docstrings, IPython reads them).'
                print 'Magic function. Passed parameter is between < >:'
                print '<%s>' % parameter_s
                print 'The self object is:', self

            ip.define_magic('foo',foo_impl)
        """
        # Bind the bare function to the user_magics instance so it behaves
        # like a method, then register it as a line magic.
        meth = types.MethodType(func, self.user_magics)
        setattr(self.user_magics, name, meth)
        record_magic(self.magics, 'line', name, meth)

    def register_alias(self, alias_name, magic_name, magic_kind='line'):
        """Register an alias to a magic function.

        The alias is an instance of :class:`MagicAlias`, which holds the
        name and kind of the magic it should call. Binding is done at
        call time, so if the underlying magic function is changed the alias
        will call the new function.

        Parameters
        ----------
        alias_name : str
            The name of the magic to be registered.

        magic_name : str
            The name of an existing magic.

        magic_kind : str
            Kind of magic, one of 'line' or 'cell'
        """

        # `validate_type` is too permissive, as it allows 'line_cell'
        # which we do not handle.
        if magic_kind not in magic_kinds:
            # BUG FIX: parenthesize the format arguments; the original applied
            # % only to magic_kinds and passed magic_kind as a second,
            # separate argument to ValueError.
            raise ValueError('magic_kind must be one of %s, %s given' %
                             (magic_kinds, magic_kind))

        alias = MagicAlias(self.shell, magic_name, magic_kind)
        setattr(self.user_magics, alias_name, alias)
        record_magic(self.magics, magic_kind, alias_name, alias)
479 479
480 480 # Key base class that provides the central functionality for magics.
481 481
482 482
class Magics(Configurable):
    """Base class for implementing magic functions.

    Shell functions which can be reached as %function_name. All magic
    functions should accept a string, which they can parse for their own
    needs. This can make some functions easier to type, eg `%cd ../`
    vs. `%cd("../")`

    Classes providing magic functions need to subclass this class, and they
    MUST:

    - Use the method decorators `@line_magic` and `@cell_magic` to decorate
      individual methods as magic functions, AND

    - Use the class decorator `@magics_class` to ensure that the magic
      methods are properly registered at the instance level upon instance
      initialization.

    See :mod:`magic_functions` for examples of actual implementation classes.
    """
    # Dict holding all command-line options for each magic.
    options_table = None
    # Dict for the mapping of magic names to methods, set by class decorator
    magics = None
    # Flag to check that the class decorator was properly applied
    registered = False
    # Instance of IPython shell
    shell = None

    def __init__(self, shell=None, **kwargs):
        # Refuse to instantiate subclasses that skipped the @magics_class
        # decorator: without it, the class-level `magics` dict was never set.
        if not(self.__class__.registered):
            raise ValueError('Magics subclass without registration - '
                             'did you forget to apply @magics_class?')
        if shell is not None:
            if hasattr(shell, 'configurables'):
                shell.configurables.append(self)
            if hasattr(shell, 'config'):
                kwargs.setdefault('parent', shell)
        kwargs['shell'] = shell

        self.shell = shell
        self.options_table = {}
        # The method decorators are run when the instance doesn't exist yet, so
        # they can only record the names of the methods they are supposed to
        # grab. Only now, that the instance exists, can we create the proper
        # mapping to bound methods. So we read the info off the original names
        # table and replace each method name by the actual bound method.
        # But we mustn't clobber the *class* mapping, in case of multiple instances.
        class_magics = self.magics
        self.magics = {}
        for mtype in magic_kinds:
            tab = self.magics[mtype] = {}
            cls_tab = class_magics[mtype]
            for magic_name, meth_name in iteritems(cls_tab):
                if isinstance(meth_name, string_types):
                    # it's a method name, grab it
                    tab[magic_name] = getattr(self, meth_name)
                else:
                    # it's the real thing
                    tab[magic_name] = meth_name
        # Configurable **needs** to be initiated at the end or the config
        # magics get screwed up.
        super(Magics, self).__init__(**kwargs)

    def arg_err(self,func):
        """Print docstring if incorrect arguments were passed"""
        print('Error in arguments:')
        print(oinspect.getdoc(func))

    def format_latex(self, strng):
        """Format a string for latex inclusion."""

        # Characters that need to be escaped for latex:
        escape_re = re.compile(r'(%|_|\$|#|&)',re.MULTILINE)
        # Magic command names as headers:
        cmd_name_re = re.compile(r'^(%s.*?):' % ESC_MAGIC,
                                 re.MULTILINE)
        # Magic commands
        cmd_re = re.compile(r'(?P<cmd>%s.+?\b)(?!\}\}:)' % ESC_MAGIC,
                            re.MULTILINE)
        # Paragraph continue
        par_re = re.compile(r'\\$',re.MULTILINE)

        # The "\n" symbol
        newline_re = re.compile(r'\\n')

        # Now build the string for output:
        #strng = cmd_name_re.sub(r'\n\\texttt{\\textsl{\\large \1}}:',strng)
        strng = cmd_name_re.sub(r'\n\\bigskip\n\\texttt{\\textbf{ \1}}:',
                                strng)
        strng = cmd_re.sub(r'\\texttt{\g<cmd>}',strng)
        strng = par_re.sub(r'\\\\',strng)
        strng = escape_re.sub(r'\\\1',strng)
        strng = newline_re.sub(r'\\textbackslash{}n',strng)
        return strng

    def parse_options(self, arg_str, opt_str, *long_opts, **kw):
        """Parse options passed to an argument string.

        The interface is similar to that of getopt(), but it returns back a
        Struct with the options as keys and the stripped argument string still
        as a string.

        arg_str is quoted as a true sys.argv vector by using shlex.split.
        This allows us to easily expand variables, glob files, quote
        arguments, etc.

        Options:
          -mode: default 'string'. If given as 'list', the argument string is
          returned as a list (split on whitespace) instead of a string.

          -list_all: put all option values in lists. Normally only options
          appearing more than once are put in a list.

          -posix (True): whether to split the input line in POSIX mode or not,
          as per the conventions outlined in the shlex module from the
          standard library."""

        # inject default options at the beginning of the input line
        # (defaults are looked up by the *name of the calling method*).
        caller = sys._getframe(1).f_code.co_name
        arg_str = '%s %s' % (self.options_table.get(caller,''),arg_str)

        mode = kw.get('mode','string')
        if mode not in ['string','list']:
            raise ValueError('incorrect mode given: %s' % mode)
        # Get options
        list_all = kw.get('list_all',0)
        posix = kw.get('posix', os.name == 'posix')
        # `strict` is forwarded to arg_split below.
        strict = kw.get('strict', True)

        # Check if we have more than one argument to warrant extra processing:
        odict = {}  # Dictionary with options
        args = arg_str.split()
        if len(args) >= 1:
            # If the list of inputs only has 0 or 1 thing in it, there's no
            # need to look for options
            argv = arg_split(arg_str, posix, strict)
            # Do regular option processing
            try:
                opts,args = getopt(argv, opt_str, long_opts)
            except GetoptError as e:
                raise UsageError('%s ( allowed: "%s" %s)' % (e.msg,opt_str,
                                        " ".join(long_opts)))
            for o,a in opts:
                # Strip the leading '--' or '-' so odict keys are bare names.
                if o.startswith('--'):
                    o = o[2:]
                else:
                    o = o[1:]
                # EAFP accumulation: first occurrence raises KeyError and
                # stores a plain value (or a 1-element list if list_all);
                # a second occurrence raises AttributeError (strings have no
                # .append) and promotes the value to a two-element list;
                # subsequent occurrences append to the existing list.
                try:
                    odict[o].append(a)
                except AttributeError:
                    odict[o] = [odict[o],a]
                except KeyError:
                    if list_all:
                        odict[o] = [a]
                    else:
                        odict[o] = a

        # Prepare opts,args for return
        opts = Struct(odict)
        if mode == 'string':
            args = ' '.join(args)

        return opts,args

    def default_option(self, fn, optstr):
        """Make an entry in the options_table for fn, with value optstr"""

        # NOTE(review): Magics itself defines no lsmagic() in this file;
        # this appears to rely on a subclass or the managing object
        # providing it -- confirm against callers.
        if fn not in self.lsmagic():
            error("%s is not a magic function" % fn)
        self.options_table[fn] = optstr
654 654
655 655
class MagicAlias(object):
    """An alias to another magic function.

    An alias is determined by its magic name and magic kind. Lookup
    is done at call time, so if the underlying magic changes the alias
    will call the new function.

    Use the :meth:`MagicsManager.register_alias` method or the
    `%alias_magic` magic function to create and register a new alias.
    """

    def __init__(self, shell, magic_name, magic_kind):
        # Record the coordinates of the target magic; the actual function is
        # resolved on every call, never stored here.
        self.shell = shell
        self.magic_name = magic_name
        self.magic_kind = magic_kind

        escape = magic_escapes[self.magic_kind]
        self.pretty_target = '%s%s' % (escape, self.magic_name)
        self.__doc__ = "Alias for `%s`." % self.pretty_target

        # Re-entrancy guard used by __call__ to detect self-referencing
        # aliases.
        self._in_call = False

    def __call__(self, *args, **kwargs):
        """Call the magic alias."""
        target = self.shell.find_magic(self.magic_name, self.magic_kind)
        if target is None:
            raise UsageError("Magic `%s` not found." % self.pretty_target)

        # Protect against infinite recursion.
        if self._in_call:
            raise UsageError("Infinite recursion detected; "
                             "magic aliases cannot call themselves.")
        self._in_call = True
        try:
            return target(*args, **kwargs)
        finally:
            self._in_call = False
@@ -1,873 +1,873 b''
1 1 # -*- coding: utf-8 -*-
2 2 """Tools for inspecting Python objects.
3 3
4 4 Uses syntax highlighting for presenting the various information elements.
5 5
6 6 Similar in spirit to the inspect module, but all calls take a name argument to
7 7 reference the name under which an object is being read.
8 8 """
9 9
10 10 #*****************************************************************************
11 11 # Copyright (C) 2001-2004 Fernando Perez <fperez@colorado.edu>
12 12 #
13 13 # Distributed under the terms of the BSD License. The full license is in
14 14 # the file COPYING, distributed as part of this software.
15 15 #*****************************************************************************
16 16 from __future__ import print_function
17 17
18 18 __all__ = ['Inspector','InspectColors']
19 19
20 20 # stdlib modules
21 21 import inspect
22 22 import linecache
23 23 import os
24 24 import types
25 25 import io as stdlib_io
26 26
27 27 try:
28 28 from itertools import izip_longest
29 29 except ImportError:
30 30 from itertools import zip_longest as izip_longest
31 31
32 32 # IPython's own
33 33 from IPython.core import page
34 34 from IPython.testing.skipdoctest import skip_doctest_py3
35 35 from IPython.utils import PyColorize
36 36 from IPython.utils import io
37 37 from IPython.utils import openpy
38 38 from IPython.utils import py3compat
39 39 from IPython.utils.dir2 import safe_hasattr
40 40 from IPython.utils.text import indent
41 41 from IPython.utils.wildcard import list_namespace
42 42 from IPython.utils.coloransi import *
43 43 from IPython.utils.py3compat import cast_unicode, string_types
44 44
45 45 #****************************************************************************
46 46 # Builtin color schemes
47 47
48 48 Colors = TermColors # just a shorthand
49 49
50 50 # Build a few color schemes
51 51 NoColor = ColorScheme(
52 52 'NoColor',{
53 53 'header' : Colors.NoColor,
54 54 'normal' : Colors.NoColor # color off (usu. Colors.Normal)
55 55 } )
56 56
57 57 LinuxColors = ColorScheme(
58 58 'Linux',{
59 59 'header' : Colors.LightRed,
60 60 'normal' : Colors.Normal # color off (usu. Colors.Normal)
61 61 } )
62 62
63 63 LightBGColors = ColorScheme(
64 64 'LightBG',{
65 65 'header' : Colors.Red,
66 66 'normal' : Colors.Normal # color off (usu. Colors.Normal)
67 67 } )
68 68
69 69 # Build table of color schemes (needed by the parser)
70 70 InspectColors = ColorSchemeTable([NoColor,LinuxColors,LightBGColors],
71 71 'Linux')
72 72
73 73 #****************************************************************************
74 74 # Auxiliary functions and objects
75 75
76 76 # See the messaging spec for the definition of all these fields. This list
77 77 # effectively defines the order of display
78 78 info_fields = ['type_name', 'base_class', 'string_form', 'namespace',
79 79 'length', 'file', 'definition', 'docstring', 'source',
80 80 'init_definition', 'class_docstring', 'init_docstring',
81 81 'call_def', 'call_docstring',
82 82 # These won't be printed but will be used to determine how to
83 83 # format the object
84 84 'ismagic', 'isalias', 'isclass', 'argspec', 'found', 'name'
85 85 ]
86 86
87 87
88 88 def object_info(**kw):
89 89 """Make an object info dict with all fields present."""
90 90 infodict = dict(izip_longest(info_fields, [None]))
91 91 infodict.update(kw)
92 92 return infodict
93 93
94 94
95 95 def get_encoding(obj):
96 96 """Get encoding for python source file defining obj
97 97
98 98 Returns None if obj is not defined in a sourcefile.
99 99 """
100 100 ofile = find_file(obj)
101 101 # run contents of file through pager starting at line where the object
102 102 # is defined, as long as the file isn't binary and is actually on the
103 103 # filesystem.
104 104 if ofile is None:
105 105 return None
106 106 elif ofile.endswith(('.so', '.dll', '.pyd')):
107 107 return None
108 108 elif not os.path.isfile(ofile):
109 109 return None
110 110 else:
111 111 # Print only text files, not extension binaries. Note that
112 112 # getsourcelines returns lineno with 1-offset and page() uses
113 113 # 0-offset, so we must adjust.
114 114 buffer = stdlib_io.open(ofile, 'rb') # Tweaked to use io.open for Python 2
115 115 encoding, lines = openpy.detect_encoding(buffer.readline)
116 116 return encoding
117 117
118 118 def getdoc(obj):
119 119 """Stable wrapper around inspect.getdoc.
120 120
121 121 This can't crash because of attribute problems.
122 122
123 123 It also attempts to call a getdoc() method on the given object. This
124 124 allows objects which provide their docstrings via non-standard mechanisms
125 125 (like Pyro proxies) to still be inspected by ipython's ? system."""
126 126 # Allow objects to offer customized documentation via a getdoc method:
127 127 try:
128 128 ds = obj.getdoc()
129 129 except Exception:
130 130 pass
131 131 else:
132 132 # if we get extra info, we add it to the normal docstring.
133 133 if isinstance(ds, string_types):
134 134 return inspect.cleandoc(ds)
135 135
136 136 try:
137 137 docstr = inspect.getdoc(obj)
138 138 encoding = get_encoding(obj)
139 139 return py3compat.cast_unicode(docstr, encoding=encoding)
140 140 except Exception:
141 141 # Harden against an inspect failure, which can occur with
142 142 # SWIG-wrapped extensions.
143 143 raise
144 144 return None
145 145
146 146
147 147 def getsource(obj,is_binary=False):
148 148 """Wrapper around inspect.getsource.
149 149
150 150 This can be modified by other projects to provide customized source
151 151 extraction.
152 152
153 153 Inputs:
154 154
155 155 - obj: an object whose source code we will attempt to extract.
156 156
157 157 Optional inputs:
158 158
159 159 - is_binary: whether the object is known to come from a binary source.
160 160 This implementation will skip returning any output for binary objects, but
161 161 custom extractors may know how to meaningfully process them."""
162 162
163 163 if is_binary:
164 164 return None
165 165 else:
166 166 # get source if obj was decorated with @decorator
167 167 if hasattr(obj,"__wrapped__"):
168 168 obj = obj.__wrapped__
169 169 try:
170 170 src = inspect.getsource(obj)
171 171 except TypeError:
172 172 if hasattr(obj,'__class__'):
173 173 src = inspect.getsource(obj.__class__)
174 174 encoding = get_encoding(obj)
175 175 return cast_unicode(src, encoding=encoding)
176 176
177 177 def getargspec(obj):
178 178 """Get the names and default values of a function's arguments.
179 179
180 180 A tuple of four things is returned: (args, varargs, varkw, defaults).
181 181 'args' is a list of the argument names (it may contain nested lists).
182 182 'varargs' and 'varkw' are the names of the * and ** arguments or None.
183 183 'defaults' is an n-tuple of the default values of the last n arguments.
184 184
185 185 Modified version of inspect.getargspec from the Python Standard
186 186 Library."""
187 187
188 188 if inspect.isfunction(obj):
189 189 func_obj = obj
190 190 elif inspect.ismethod(obj):
191 191 func_obj = obj.im_func
192 192 elif hasattr(obj, '__call__'):
193 193 func_obj = obj.__call__
194 194 else:
195 195 raise TypeError('arg is not a Python function')
196 args, varargs, varkw = inspect.getargs(func_obj.func_code)
197 return args, varargs, varkw, func_obj.func_defaults
196 args, varargs, varkw = inspect.getargs(func_obj.__code__)
197 return args, varargs, varkw, func_obj.__defaults__
198 198
199 199
200 200 def format_argspec(argspec):
201 201 """Format argspec, a convenience wrapper around inspect's.
202 202
203 203 This takes a dict instead of ordered arguments and calls
204 204 inspect.formatargspec with the arguments in the necessary order.
205 205 """
206 206 return inspect.formatargspec(argspec['args'], argspec['varargs'],
207 207 argspec['varkw'], argspec['defaults'])
208 208
209 209
210 210 def call_tip(oinfo, format_call=True):
211 211 """Extract call tip data from an oinfo dict.
212 212
213 213 Parameters
214 214 ----------
215 215 oinfo : dict
216 216
217 217 format_call : bool, optional
218 218 If True, the call line is formatted and returned as a string. If not, a
219 219 tuple of (name, argspec) is returned.
220 220
221 221 Returns
222 222 -------
223 223 call_info : None, str or (str, dict) tuple.
224 224 When format_call is True, the whole call information is formatted as a
225 225 single string. Otherwise, the object's name and its argspec dict are
226 226 returned. If no call information is available, None is returned.
227 227
228 228 docstring : str or None
229 229 The most relevant docstring for calling purposes is returned, if
230 230 available. The priority is: call docstring for callable instances, then
231 231 constructor docstring for classes, then main object's docstring otherwise
232 232 (regular functions).
233 233 """
234 234 # Get call definition
235 235 argspec = oinfo.get('argspec')
236 236 if argspec is None:
237 237 call_line = None
238 238 else:
239 239 # Callable objects will have 'self' as their first argument, prune
240 240 # it out if it's there for clarity (since users do *not* pass an
241 241 # extra first argument explicitly).
242 242 try:
243 243 has_self = argspec['args'][0] == 'self'
244 244 except (KeyError, IndexError):
245 245 pass
246 246 else:
247 247 if has_self:
248 248 argspec['args'] = argspec['args'][1:]
249 249
250 250 call_line = oinfo['name']+format_argspec(argspec)
251 251
252 252 # Now get docstring.
253 253 # The priority is: call docstring, constructor docstring, main one.
254 254 doc = oinfo.get('call_docstring')
255 255 if doc is None:
256 256 doc = oinfo.get('init_docstring')
257 257 if doc is None:
258 258 doc = oinfo.get('docstring','')
259 259
260 260 return call_line, doc
261 261
262 262
263 263 def find_file(obj):
264 264 """Find the absolute path to the file where an object was defined.
265 265
266 266 This is essentially a robust wrapper around `inspect.getabsfile`.
267 267
268 268 Returns None if no file can be found.
269 269
270 270 Parameters
271 271 ----------
272 272 obj : any Python object
273 273
274 274 Returns
275 275 -------
276 276 fname : str
277 277 The absolute path to the file where the object was defined.
278 278 """
279 279 # get source if obj was decorated with @decorator
280 280 if safe_hasattr(obj, '__wrapped__'):
281 281 obj = obj.__wrapped__
282 282
283 283 fname = None
284 284 try:
285 285 fname = inspect.getabsfile(obj)
286 286 except TypeError:
287 287 # For an instance, the file that matters is where its class was
288 288 # declared.
289 289 if hasattr(obj, '__class__'):
290 290 try:
291 291 fname = inspect.getabsfile(obj.__class__)
292 292 except TypeError:
293 293 # Can happen for builtins
294 294 pass
295 295 except:
296 296 pass
297 297 return cast_unicode(fname)
298 298
299 299
300 300 def find_source_lines(obj):
301 301 """Find the line number in a file where an object was defined.
302 302
303 303 This is essentially a robust wrapper around `inspect.getsourcelines`.
304 304
305 305 Returns None if no file can be found.
306 306
307 307 Parameters
308 308 ----------
309 309 obj : any Python object
310 310
311 311 Returns
312 312 -------
313 313 lineno : int
314 314 The line number where the object definition starts.
315 315 """
316 316 # get source if obj was decorated with @decorator
317 317 if safe_hasattr(obj, '__wrapped__'):
318 318 obj = obj.__wrapped__
319 319
320 320 try:
321 321 try:
322 322 lineno = inspect.getsourcelines(obj)[1]
323 323 except TypeError:
324 324 # For instances, try the class object like getsource() does
325 325 if hasattr(obj, '__class__'):
326 326 lineno = inspect.getsourcelines(obj.__class__)[1]
327 327 else:
328 328 lineno = None
329 329 except:
330 330 return None
331 331
332 332 return lineno
333 333
334 334
335 335 class Inspector:
336 336 def __init__(self, color_table=InspectColors,
337 337 code_color_table=PyColorize.ANSICodeColors,
338 338 scheme='NoColor',
339 339 str_detail_level=0):
340 340 self.color_table = color_table
341 341 self.parser = PyColorize.Parser(code_color_table,out='str')
342 342 self.format = self.parser.format
343 343 self.str_detail_level = str_detail_level
344 344 self.set_active_scheme(scheme)
345 345
346 346 def _getdef(self,obj,oname=''):
347 347 """Return the call signature for any callable object.
348 348
349 349 If any exception is generated, None is returned instead and the
350 350 exception is suppressed."""
351 351
352 352 try:
353 353 hdef = oname + inspect.formatargspec(*getargspec(obj))
354 354 return cast_unicode(hdef)
355 355 except:
356 356 return None
357 357
358 358 def __head(self,h):
359 359 """Return a header string with proper colors."""
360 360 return '%s%s%s' % (self.color_table.active_colors.header,h,
361 361 self.color_table.active_colors.normal)
362 362
363 363 def set_active_scheme(self, scheme):
364 364 self.color_table.set_active_scheme(scheme)
365 365 self.parser.color_table.set_active_scheme(scheme)
366 366
367 367 def noinfo(self, msg, oname):
368 368 """Generic message when no information is found."""
369 369 print('No %s found' % msg, end=' ')
370 370 if oname:
371 371 print('for %s' % oname)
372 372 else:
373 373 print()
374 374
375 375 def pdef(self, obj, oname=''):
376 376 """Print the call signature for any callable object.
377 377
378 378 If the object is a class, print the constructor information."""
379 379
380 380 if not callable(obj):
381 381 print('Object is not callable.')
382 382 return
383 383
384 384 header = ''
385 385
386 386 if inspect.isclass(obj):
387 387 header = self.__head('Class constructor information:\n')
388 388 obj = obj.__init__
389 389 elif (not py3compat.PY3) and type(obj) is types.InstanceType:
390 390 obj = obj.__call__
391 391
392 392 output = self._getdef(obj,oname)
393 393 if output is None:
394 394 self.noinfo('definition header',oname)
395 395 else:
396 396 print(header,self.format(output), end=' ', file=io.stdout)
397 397
398 398 # In Python 3, all classes are new-style, so they all have __init__.
399 399 @skip_doctest_py3
400 400 def pdoc(self,obj,oname='',formatter = None):
401 401 """Print the docstring for any object.
402 402
403 403 Optional:
404 404 -formatter: a function to run the docstring through for specially
405 405 formatted docstrings.
406 406
407 407 Examples
408 408 --------
409 409
410 410 In [1]: class NoInit:
411 411 ...: pass
412 412
413 413 In [2]: class NoDoc:
414 414 ...: def __init__(self):
415 415 ...: pass
416 416
417 417 In [3]: %pdoc NoDoc
418 418 No documentation found for NoDoc
419 419
420 420 In [4]: %pdoc NoInit
421 421 No documentation found for NoInit
422 422
423 423 In [5]: obj = NoInit()
424 424
425 425 In [6]: %pdoc obj
426 426 No documentation found for obj
427 427
428 428 In [5]: obj2 = NoDoc()
429 429
430 430 In [6]: %pdoc obj2
431 431 No documentation found for obj2
432 432 """
433 433
434 434 head = self.__head # For convenience
435 435 lines = []
436 436 ds = getdoc(obj)
437 437 if formatter:
438 438 ds = formatter(ds)
439 439 if ds:
440 440 lines.append(head("Class Docstring:"))
441 441 lines.append(indent(ds))
442 442 if inspect.isclass(obj) and hasattr(obj, '__init__'):
443 443 init_ds = getdoc(obj.__init__)
444 444 if init_ds is not None:
445 445 lines.append(head("Constructor Docstring:"))
446 446 lines.append(indent(init_ds))
447 447 elif hasattr(obj,'__call__'):
448 448 call_ds = getdoc(obj.__call__)
449 449 if call_ds:
450 450 lines.append(head("Calling Docstring:"))
451 451 lines.append(indent(call_ds))
452 452
453 453 if not lines:
454 454 self.noinfo('documentation',oname)
455 455 else:
456 456 page.page('\n'.join(lines))
457 457
458 458 def psource(self,obj,oname=''):
459 459 """Print the source code for an object."""
460 460
461 461 # Flush the source cache because inspect can return out-of-date source
462 462 linecache.checkcache()
463 463 try:
464 464 src = getsource(obj)
465 465 except:
466 466 self.noinfo('source',oname)
467 467 else:
468 468 page.page(self.format(src))
469 469
470 470 def pfile(self, obj, oname=''):
471 471 """Show the whole file where an object was defined."""
472 472
473 473 lineno = find_source_lines(obj)
474 474 if lineno is None:
475 475 self.noinfo('file', oname)
476 476 return
477 477
478 478 ofile = find_file(obj)
479 479 # run contents of file through pager starting at line where the object
480 480 # is defined, as long as the file isn't binary and is actually on the
481 481 # filesystem.
482 482 if ofile.endswith(('.so', '.dll', '.pyd')):
483 483 print('File %r is binary, not printing.' % ofile)
484 484 elif not os.path.isfile(ofile):
485 485 print('File %r does not exist, not printing.' % ofile)
486 486 else:
487 487 # Print only text files, not extension binaries. Note that
488 488 # getsourcelines returns lineno with 1-offset and page() uses
489 489 # 0-offset, so we must adjust.
490 490 page.page(self.format(openpy.read_py_file(ofile, skip_encoding_cookie=False)), lineno - 1)
491 491
492 492 def _format_fields(self, fields, title_width=12):
493 493 """Formats a list of fields for display.
494 494
495 495 Parameters
496 496 ----------
497 497 fields : list
498 498 A list of 2-tuples: (field_title, field_content)
499 499 title_width : int
500 500 How many characters to pad titles to. Default 12.
501 501 """
502 502 out = []
503 503 header = self.__head
504 504 for title, content in fields:
505 505 if len(content.splitlines()) > 1:
506 506 title = header(title + ":") + "\n"
507 507 else:
508 508 title = header((title+":").ljust(title_width))
509 509 out.append(cast_unicode(title) + cast_unicode(content))
510 510 return "\n".join(out)
511 511
512 512 # The fields to be displayed by pinfo: (fancy_name, key_in_info_dict)
513 513 pinfo_fields1 = [("Type", "type_name"),
514 514 ]
515 515
516 516 pinfo_fields2 = [("String Form", "string_form"),
517 517 ]
518 518
519 519 pinfo_fields3 = [("Length", "length"),
520 520 ("File", "file"),
521 521 ("Definition", "definition"),
522 522 ]
523 523
524 524 pinfo_fields_obj = [("Class Docstring", "class_docstring"),
525 525 ("Constructor Docstring","init_docstring"),
526 526 ("Call def", "call_def"),
527 527 ("Call docstring", "call_docstring")]
528 528
529 529 def pinfo(self,obj,oname='',formatter=None,info=None,detail_level=0):
530 530 """Show detailed information about an object.
531 531
532 532 Optional arguments:
533 533
534 534 - oname: name of the variable pointing to the object.
535 535
536 536 - formatter: special formatter for docstrings (see pdoc)
537 537
538 538 - info: a structure with some information fields which may have been
539 539 precomputed already.
540 540
541 541 - detail_level: if set to 1, more information is given.
542 542 """
543 543 info = self.info(obj, oname=oname, formatter=formatter,
544 544 info=info, detail_level=detail_level)
545 545 displayfields = []
546 546 def add_fields(fields):
547 547 for title, key in fields:
548 548 field = info[key]
549 549 if field is not None:
550 550 displayfields.append((title, field.rstrip()))
551 551
552 552 add_fields(self.pinfo_fields1)
553 553
554 554 # Base class for old-style instances
555 555 if (not py3compat.PY3) and isinstance(obj, types.InstanceType) and info['base_class']:
556 556 displayfields.append(("Base Class", info['base_class'].rstrip()))
557 557
558 558 add_fields(self.pinfo_fields2)
559 559
560 560 # Namespace
561 561 if info['namespace'] != 'Interactive':
562 562 displayfields.append(("Namespace", info['namespace'].rstrip()))
563 563
564 564 add_fields(self.pinfo_fields3)
565 565
566 566 # Source or docstring, depending on detail level and whether
567 567 # source found.
568 568 if detail_level > 0 and info['source'] is not None:
569 569 displayfields.append(("Source",
570 570 self.format(cast_unicode(info['source']))))
571 571 elif info['docstring'] is not None:
572 572 displayfields.append(("Docstring", info["docstring"]))
573 573
574 574 # Constructor info for classes
575 575 if info['isclass']:
576 576 if info['init_definition'] or info['init_docstring']:
577 577 displayfields.append(("Constructor information", ""))
578 578 if info['init_definition'] is not None:
579 579 displayfields.append((" Definition",
580 580 info['init_definition'].rstrip()))
581 581 if info['init_docstring'] is not None:
582 582 displayfields.append((" Docstring",
583 583 indent(info['init_docstring'])))
584 584
585 585 # Info for objects:
586 586 else:
587 587 add_fields(self.pinfo_fields_obj)
588 588
589 589 # Finally send to printer/pager:
590 590 if displayfields:
591 591 page.page(self._format_fields(displayfields))
592 592
593 593 def info(self, obj, oname='', formatter=None, info=None, detail_level=0):
594 594 """Compute a dict with detailed information about an object.
595 595
596 596 Optional arguments:
597 597
598 598 - oname: name of the variable pointing to the object.
599 599
600 600 - formatter: special formatter for docstrings (see pdoc)
601 601
602 602 - info: a structure with some information fields which may have been
603 603 precomputed already.
604 604
605 605 - detail_level: if set to 1, more information is given.
606 606 """
607 607
608 608 obj_type = type(obj)
609 609
610 610 header = self.__head
611 611 if info is None:
612 612 ismagic = 0
613 613 isalias = 0
614 614 ospace = ''
615 615 else:
616 616 ismagic = info.ismagic
617 617 isalias = info.isalias
618 618 ospace = info.namespace
619 619
620 620 # Get docstring, special-casing aliases:
621 621 if isalias:
622 622 if not callable(obj):
623 623 try:
624 624 ds = "Alias to the system command:\n %s" % obj[1]
625 625 except:
626 626 ds = "Alias: " + str(obj)
627 627 else:
628 628 ds = "Alias to " + str(obj)
629 629 if obj.__doc__:
630 630 ds += "\nDocstring:\n" + obj.__doc__
631 631 else:
632 632 ds = getdoc(obj)
633 633 if ds is None:
634 634 ds = '<no docstring>'
635 635 if formatter is not None:
636 636 ds = formatter(ds)
637 637
638 638 # store output in a dict, we initialize it here and fill it as we go
639 639 out = dict(name=oname, found=True, isalias=isalias, ismagic=ismagic)
640 640
641 641 string_max = 200 # max size of strings to show (snipped if longer)
642 642 shalf = int((string_max -5)/2)
643 643
644 644 if ismagic:
645 645 obj_type_name = 'Magic function'
646 646 elif isalias:
647 647 obj_type_name = 'System alias'
648 648 else:
649 649 obj_type_name = obj_type.__name__
650 650 out['type_name'] = obj_type_name
651 651
652 652 try:
653 653 bclass = obj.__class__
654 654 out['base_class'] = str(bclass)
655 655 except: pass
656 656
657 657 # String form, but snip if too long in ? form (full in ??)
658 658 if detail_level >= self.str_detail_level:
659 659 try:
660 660 ostr = str(obj)
661 661 str_head = 'string_form'
662 662 if not detail_level and len(ostr)>string_max:
663 663 ostr = ostr[:shalf] + ' <...> ' + ostr[-shalf:]
664 664 ostr = ("\n" + " " * len(str_head.expandtabs())).\
665 665 join(q.strip() for q in ostr.split("\n"))
666 666 out[str_head] = ostr
667 667 except:
668 668 pass
669 669
670 670 if ospace:
671 671 out['namespace'] = ospace
672 672
673 673 # Length (for strings and lists)
674 674 try:
675 675 out['length'] = str(len(obj))
676 676 except: pass
677 677
678 678 # Filename where object was defined
679 679 binary_file = False
680 680 fname = find_file(obj)
681 681 if fname is None:
682 682 # if anything goes wrong, we don't want to show source, so it's as
683 683 # if the file was binary
684 684 binary_file = True
685 685 else:
686 686 if fname.endswith(('.so', '.dll', '.pyd')):
687 687 binary_file = True
688 688 elif fname.endswith('<string>'):
689 689 fname = 'Dynamically generated function. No source code available.'
690 690 out['file'] = fname
691 691
692 692 # reconstruct the function definition and print it:
693 693 defln = self._getdef(obj, oname)
694 694 if defln:
695 695 out['definition'] = self.format(defln)
696 696
697 697 # Docstrings only in detail 0 mode, since source contains them (we
698 698 # avoid repetitions). If source fails, we add them back, see below.
699 699 if ds and detail_level == 0:
700 700 out['docstring'] = ds
701 701
702 702 # Original source code for any callable
703 703 if detail_level:
704 704 # Flush the source cache because inspect can return out-of-date
705 705 # source
706 706 linecache.checkcache()
707 707 source = None
708 708 try:
709 709 try:
710 710 source = getsource(obj, binary_file)
711 711 except TypeError:
712 712 if hasattr(obj, '__class__'):
713 713 source = getsource(obj.__class__, binary_file)
714 714 if source is not None:
715 715 out['source'] = source.rstrip()
716 716 except Exception:
717 717 pass
718 718
719 719 if ds and source is None:
720 720 out['docstring'] = ds
721 721
722 722
723 723 # Constructor docstring for classes
724 724 if inspect.isclass(obj):
725 725 out['isclass'] = True
726 726 # reconstruct the function definition and print it:
727 727 try:
728 728 obj_init = obj.__init__
729 729 except AttributeError:
730 730 init_def = init_ds = None
731 731 else:
732 732 init_def = self._getdef(obj_init,oname)
733 733 init_ds = getdoc(obj_init)
734 734 # Skip Python's auto-generated docstrings
735 735 if init_ds and \
736 736 init_ds.startswith('x.__init__(...) initializes'):
737 737 init_ds = None
738 738
739 739 if init_def or init_ds:
740 740 if init_def:
741 741 out['init_definition'] = self.format(init_def)
742 742 if init_ds:
743 743 out['init_docstring'] = init_ds
744 744
745 745 # and class docstring for instances:
746 746 else:
747 747 # First, check whether the instance docstring is identical to the
748 748 # class one, and print it separately if they don't coincide. In
749 749 # most cases they will, but it's nice to print all the info for
750 750 # objects which use instance-customized docstrings.
751 751 if ds:
752 752 try:
753 753 cls = getattr(obj,'__class__')
754 754 except:
755 755 class_ds = None
756 756 else:
757 757 class_ds = getdoc(cls)
758 758 # Skip Python's auto-generated docstrings
759 759 if class_ds and \
760 760 (class_ds.startswith('function(code, globals[,') or \
761 761 class_ds.startswith('instancemethod(function, instance,') or \
762 762 class_ds.startswith('module(name[,') ):
763 763 class_ds = None
764 764 if class_ds and ds != class_ds:
765 765 out['class_docstring'] = class_ds
766 766
767 767 # Next, try to show constructor docstrings
768 768 try:
769 769 init_ds = getdoc(obj.__init__)
770 770 # Skip Python's auto-generated docstrings
771 771 if init_ds and \
772 772 init_ds.startswith('x.__init__(...) initializes'):
773 773 init_ds = None
774 774 except AttributeError:
775 775 init_ds = None
776 776 if init_ds:
777 777 out['init_docstring'] = init_ds
778 778
779 779 # Call form docstring for callable instances
780 780 if safe_hasattr(obj, '__call__'):
781 781 call_def = self._getdef(obj.__call__, oname)
782 782 if call_def is not None:
783 783 out['call_def'] = self.format(call_def)
784 784 call_ds = getdoc(obj.__call__)
785 785 # Skip Python's auto-generated docstrings
786 786 if call_ds and call_ds.startswith('x.__call__(...) <==> x(...)'):
787 787 call_ds = None
788 788 if call_ds:
789 789 out['call_docstring'] = call_ds
790 790
791 791 # Compute the object's argspec as a callable. The key is to decide
792 792 # whether to pull it from the object itself, from its __init__ or
793 793 # from its __call__ method.
794 794
795 795 if inspect.isclass(obj):
796 796 # Old-style classes need not have an __init__
797 797 callable_obj = getattr(obj, "__init__", None)
798 798 elif callable(obj):
799 799 callable_obj = obj
800 800 else:
801 801 callable_obj = None
802 802
803 803 if callable_obj:
804 804 try:
805 805 args, varargs, varkw, defaults = getargspec(callable_obj)
806 806 except (TypeError, AttributeError):
807 807 # For extensions/builtins we can't retrieve the argspec
808 808 pass
809 809 else:
810 810 out['argspec'] = dict(args=args, varargs=varargs,
811 811 varkw=varkw, defaults=defaults)
812 812
813 813 return object_info(**out)
814 814
815 815
816 816 def psearch(self,pattern,ns_table,ns_search=[],
817 817 ignore_case=False,show_all=False):
818 818 """Search namespaces with wildcards for objects.
819 819
820 820 Arguments:
821 821
822 822 - pattern: string containing shell-like wildcards to use in namespace
823 823 searches and optionally a type specification to narrow the search to
824 824 objects of that type.
825 825
826 826 - ns_table: dict of name->namespaces for search.
827 827
828 828 Optional arguments:
829 829
830 830 - ns_search: list of namespace names to include in search.
831 831
832 832 - ignore_case(False): make the search case-insensitive.
833 833
834 834 - show_all(False): show all names, including those starting with
835 835 underscores.
836 836 """
837 837 #print 'ps pattern:<%r>' % pattern # dbg
838 838
839 839 # defaults
840 840 type_pattern = 'all'
841 841 filter = ''
842 842
843 843 cmds = pattern.split()
844 844 len_cmds = len(cmds)
845 845 if len_cmds == 1:
846 846 # Only filter pattern given
847 847 filter = cmds[0]
848 848 elif len_cmds == 2:
849 849 # Both filter and type specified
850 850 filter,type_pattern = cmds
851 851 else:
852 852 raise ValueError('invalid argument string for psearch: <%s>' %
853 853 pattern)
854 854
855 855 # filter search namespaces
856 856 for name in ns_search:
857 857 if name not in ns_table:
858 858 raise ValueError('invalid namespace <%s>. Valid names: %s' %
859 859 (name,ns_table.keys()))
860 860
861 861 #print 'type_pattern:',type_pattern # dbg
862 862 search_result, namespaces_seen = set(), set()
863 863 for ns_name in ns_search:
864 864 ns = ns_table[ns_name]
865 865 # Normally, locals and globals are the same, so we just check one.
866 866 if id(ns) in namespaces_seen:
867 867 continue
868 868 namespaces_seen.add(id(ns))
869 869 tmp_res = list_namespace(ns, type_pattern, filter,
870 870 ignore_case=ignore_case, show_all=show_all)
871 871 search_result.update(tmp_res)
872 872
873 873 page.page('\n'.join(sorted(search_result)))
@@ -1,1267 +1,1267 b''
1 1 # -*- coding: utf-8 -*-
2 2 """
3 3 ultratb.py -- Spice up your tracebacks!
4 4
5 5 * ColorTB
6 6 I've always found it a bit hard to visually parse tracebacks in Python. The
7 7 ColorTB class is a solution to that problem. It colors the different parts of a
8 8 traceback in a manner similar to what you would expect from a syntax-highlighting
9 9 text editor.
10 10
11 11 Installation instructions for ColorTB::
12 12
13 13 import sys,ultratb
14 14 sys.excepthook = ultratb.ColorTB()
15 15
16 16 * VerboseTB
17 17 I've also included a port of Ka-Ping Yee's "cgitb.py" that produces all kinds
18 18 of useful info when a traceback occurs. Ping originally had it spit out HTML
19 19 and intended it for CGI programmers, but why should they have all the fun? I
20 20 altered it to spit out colored text to the terminal. It's a bit overwhelming,
21 21 but kind of neat, and maybe useful for long-running programs that you believe
22 22 are bug-free. If a crash *does* occur in that type of program you want details.
23 23 Give it a shot--you'll love it or you'll hate it.
24 24
25 25 .. note::
26 26
27 27 The Verbose mode prints the variables currently visible where the exception
28 28 happened (shortening their strings if too long). This can potentially be
29 29 very slow, if you happen to have a huge data structure whose string
30 30 representation is complex to compute. Your computer may appear to freeze for
31 31 a while with cpu usage at 100%. If this occurs, you can cancel the traceback
32 32 with Ctrl-C (maybe hitting it more than once).
33 33
34 34 If you encounter this kind of situation often, you may want to use the
35 35 Verbose_novars mode instead of the regular Verbose, which avoids formatting
36 36 variables (but otherwise includes the information and context given by
37 37 Verbose).
38 38
39 39
40 40 Installation instructions for VerboseTB::
41 41
42 42 import sys,ultratb
43 43 sys.excepthook = ultratb.VerboseTB()
44 44
45 45 Note: Much of the code in this module was lifted verbatim from the standard
46 46 library module 'traceback.py' and Ka-Ping Yee's 'cgitb.py'.
47 47
48 48 Color schemes
49 49 -------------
50 50
51 51 The colors are defined in the class TBTools through the use of the
52 52 ColorSchemeTable class. Currently the following exist:
53 53
54 54 - NoColor: allows all of this module to be used in any terminal (the color
55 55 escapes are just dummy blank strings).
56 56
57 57 - Linux: is meant to look good in a terminal like the Linux console (black
58 58 or very dark background).
59 59
60 60 - LightBG: similar to Linux but swaps dark/light colors to be more readable
61 61 in light background terminals.
62 62
63 63 You can implement other color schemes easily, the syntax is fairly
64 64 self-explanatory. Please send back new schemes you develop to the author for
65 65 possible inclusion in future releases.
66 66
67 67 Inheritance diagram:
68 68
69 69 .. inheritance-diagram:: IPython.core.ultratb
70 70 :parts: 3
71 71 """
72 72
73 73 #*****************************************************************************
74 74 # Copyright (C) 2001 Nathaniel Gray <n8gray@caltech.edu>
75 75 # Copyright (C) 2001-2004 Fernando Perez <fperez@colorado.edu>
76 76 #
77 77 # Distributed under the terms of the BSD License. The full license is in
78 78 # the file COPYING, distributed as part of this software.
79 79 #*****************************************************************************
80 80
81 81 from __future__ import unicode_literals
82 82 from __future__ import print_function
83 83
84 84 import inspect
85 85 import keyword
86 86 import linecache
87 87 import os
88 88 import pydoc
89 89 import re
90 90 import sys
91 91 import time
92 92 import tokenize
93 93 import traceback
94 94 import types
95 95
96 96 try: # Python 2
97 97 generate_tokens = tokenize.generate_tokens
98 98 except AttributeError: # Python 3
99 99 generate_tokens = tokenize.tokenize
100 100
101 101 # For purposes of monkeypatching inspect to fix a bug in it.
102 102 from inspect import getsourcefile, getfile, getmodule,\
103 103 ismodule, isclass, ismethod, isfunction, istraceback, isframe, iscode
104 104
105 105 # IPython's own modules
106 106 # Modified pdb which doesn't damage IPython's readline handling
107 107 from IPython import get_ipython
108 108 from IPython.core import debugger
109 109 from IPython.core.display_trap import DisplayTrap
110 110 from IPython.core.excolors import exception_colors
111 111 from IPython.utils import PyColorize
112 112 from IPython.utils import io
113 113 from IPython.utils import openpy
114 114 from IPython.utils import path as util_path
115 115 from IPython.utils import py3compat
116 116 from IPython.utils import ulinecache
117 117 from IPython.utils.data import uniq_stable
118 118 from IPython.utils.warn import info, error
119 119
120 120 # Globals
121 121 # amount of space to put line numbers before verbose tracebacks
122 122 INDENT_SIZE = 8
123 123
124 124 # Default color scheme. This is used, for example, by the traceback
125 125 # formatter. When running in an actual IPython instance, the user's rc.colors
126 126 # value is used, but having a module global makes this functionality available
127 127 # to users of ultratb who are NOT running inside ipython.
128 128 DEFAULT_SCHEME = 'NoColor'
129 129
130 130 #---------------------------------------------------------------------------
131 131 # Code begins
132 132
133 133 # Utility functions
def inspect_error():
    """Print a message about internal inspect errors.

    These are unfortunately quite common."""
    msg = ('Internal Python error in the inspect module.\n'
           'Below is the traceback from this internal error.\n')
    error(msg)
141 141
142 142 # This function is a monkeypatch we apply to the Python inspect module. We have
143 143 # now found when it's needed (see discussion on issue gh-1456), and we have a
144 144 # test case (IPython.core.tests.test_ultratb.ChangedPyFileTest) that fails if
145 145 # the monkeypatch is not applied. TK, Aug 2012.
def findsource(object):
    """Return the entire source file and starting line number for an object.

    The argument may be a module, class, method, function, traceback, frame,
    or code object.  The source code is returned as a list of all the lines
    in the file and the line number indexes a line in that list.  An IOError
    is raised if the source code cannot be retrieved.

    FIXED version with which we monkeypatch the stdlib to work around a bug."""

    file = getsourcefile(object) or getfile(object)
    # If the object is a frame, then trying to get the globals dict from its
    # module won't work. Instead, the frame object itself has the globals
    # dictionary.
    globals_dict = None
    if inspect.isframe(object):
        # XXX: can this ever be false?
        globals_dict = object.f_globals
    else:
        module = getmodule(object, file)
        if module:
            globals_dict = module.__dict__
    lines = linecache.getlines(file, globals_dict)
    if not lines:
        raise IOError('could not get source code')

    if ismodule(object):
        return lines, 0

    if isclass(object):
        name = object.__name__
        pat = re.compile(r'^(\s*)class\s*' + name + r'\b')
        # make some effort to find the best matching class definition:
        # use the one with the least indentation, which is the one
        # that's most probably not inside a function definition.
        candidates = []
        for i in range(len(lines)):
            match = pat.match(lines[i])
            if match:
                # if it's at toplevel, it's already the best one
                if lines[i][0] == 'c':
                    return lines, i
                # else add whitespace to candidate list
                candidates.append((match.group(1), i))
        if candidates:
            # this will sort by whitespace, and by line number,
            # less whitespace first
            candidates.sort()
            return lines, candidates[0][1]
        else:
            raise IOError('could not find class definition')

    if ismethod(object):
        # __func__ is the py2.6+/py3 spelling; im_func is py2-only and was
        # missed when func_code was updated to __code__ below.
        object = object.__func__
    if isfunction(object):
        object = object.__code__
    if istraceback(object):
        object = object.tb_frame
    if isframe(object):
        object = object.f_code
    if iscode(object):
        if not hasattr(object, 'co_firstlineno'):
            raise IOError('could not find function definition')
        pat = re.compile(r'^(\s*def\s)|(.*(?<!\w)lambda(:|\s))|^(\s*@)')
        pmatch = pat.match
        # fperez - fix: sometimes, co_firstlineno can give a number larger than
        # the length of lines, which causes an error. Safeguard against that.
        lnum = min(object.co_firstlineno,len(lines))-1
        while lnum > 0:
            if pmatch(lines[lnum]): break
            lnum -= 1

        return lines, lnum
    raise IOError('could not find code object')
220 220
# Monkeypatch inspect to apply our bugfix. This code only works with Python >= 2.5
# (see gh-1456 and the findsource docstring above for why the stdlib version
# needs replacing).
inspect.findsource = findsource
223 223
def fix_frame_records_filenames(records):
    """Try to fix the filenames in each record from inspect.getinnerframes().

    Particularly, modules loaded from within zip files have useless filenames
    attached to their code object, and inspect.getinnerframes() just uses it.
    """
    def _best_filename(frame, fallback):
        # Prefer the module-level __file__ over whatever the code object
        # carries, but only when it is a real string: __file__ may be
        # missing, None (error during import), or something stranger.
        candidate = frame.f_globals.get('__file__', None)
        return candidate if isinstance(candidate, str) else fallback

    return [(frame, _best_filename(frame, filename),
             line_no, func_name, lines, index)
            for frame, filename, line_no, func_name, lines, index in records]
242 242
243 243
def _fixed_getinnerframes(etb, context=1,tb_offset=0):
    """Return traceback frame records with repaired filenames and context.

    Wraps inspect.getinnerframes(), fixing bogus filenames (see
    fix_frame_records_filenames) and rebuilding each record's source-line
    window through ulinecache, then drops the first `tb_offset` records.
    """
    # Field positions inside each record tuple:
    # (frame, filename, lineno, function, lines, index)
    LNUM_POS, LINES_POS, INDEX_POS =  2, 4, 5

    records  = fix_frame_records_filenames(inspect.getinnerframes(etb, context))

    # If the error is at the console, don't build any context, since it would
    # otherwise produce 5 blank lines printed out (there is no file at the
    # console)
    rec_check = records[tb_offset:]
    try:
        rname = rec_check[0][1]
        if rname == '<ipython console>' or rname.endswith('<string>'):
            return rec_check
    except IndexError:
        # No records below the offset; fall through and return what we have.
        pass

    aux = traceback.extract_tb(etb)
    assert len(records) == len(aux)
    for i, (file, lnum, _, _) in zip(range(len(records)), aux):
        # Re-read a `context`-sized window of source centered on lnum via
        # ulinecache, which also sees IPython's interactively-entered code.
        maybeStart = lnum-1 - context//2
        start =  max(maybeStart, 0)
        end   = start + context
        lines = ulinecache.getlines(file)[start:end]
        buf = list(records[i])
        buf[LNUM_POS] = lnum
        buf[INDEX_POS] = lnum - 1 - start
        buf[LINES_POS] = lines
        records[i] = tuple(buf)
    return records[tb_offset:]
273 273
# Helper function -- largely belongs to VerboseTB, but we need the same
# functionality to produce a pseudo verbose TB for SyntaxErrors, so that they
# can be recognized properly by ipython.el's py-traceback-line-re
# (SyntaxErrors have to be treated specially because they have no traceback)

# Shared syntax-highlighting parser used by _format_traceback_lines below.
_parser = PyColorize.Parser()
280 280
def _format_traceback_lines(lnum, index, lines, Colors, lvals=None,scheme=None):
    """Colorize and number the source lines shown for one traceback frame.

    `lines` is the context window, `lnum` the offending line number, and
    `index` the position of that line within the window.  `lvals`, if given,
    is a preformatted block of local-variable values appended right after
    the offending line.  Returns the formatted lines as a list of strings.
    """
    numbers_width = INDENT_SIZE - 1

    # This lets us get fully syntax-highlighted tracebacks.
    # Resolve the scheme lazily so we honor a running IPython's setting.
    if scheme is None:
        ipinst = get_ipython()
        scheme = ipinst.colors if ipinst is not None else DEFAULT_SCHEME

    highlight = _parser.format2

    out = []
    for offset, raw in enumerate(lines):
        text = py3compat.cast_unicode(raw)

        colored, failed = highlight(text, 'str', scheme)
        if not failed:
            text = colored

        cur = (lnum - index) + offset
        if cur == lnum:
            # This is the line with the error: prefix it with an arrow
            # marker sized to fit the line-number gutter.
            pad = numbers_width - len(str(cur))
            if pad >= 3:
                marker = '-'*(pad-3) + '-> '
            elif pad == 2:
                marker = '> '
            elif pad == 1:
                marker = '>'
            else:
                marker = ''
            text = '%s%s%s %s%s' %(Colors.linenoEm, marker + str(cur),
                                   Colors.line, text, Colors.Normal)
        else:
            text = '%s%s%s %s' %(Colors.lineno, '%*s' % (numbers_width, cur),
                                 Colors.Normal, text)

        out.append(text)
        if lvals and cur == lnum:
            out.append(lvals + '\n')
    return out
326 326
327 327
328 328 #---------------------------------------------------------------------------
329 329 # Module classes
class TBTools(object):
    """Basic tools used by all traceback printer classes."""

    # Number of frames to skip when reporting tracebacks
    tb_offset = 0

    def __init__(self, color_scheme='NoColor', call_pdb=False, ostream=None):
        """Create the tool, selecting a color scheme and output stream.

        Parameters
        ----------
        color_scheme : str
          Name of a scheme known to exception_colors() ('NoColor', 'Linux',
          'LightBG').
        call_pdb : bool
          If True, a debugger.Pdb instance is created for post-mortem use.
        ostream : file-like or None
          Stream exceptions are written to; None means resolve io.stdout
          dynamically at write time (see the ostream property).
        """
        # Whether to call the interactive pdb debugger after printing
        # tracebacks or not
        self.call_pdb = call_pdb

        # Output stream to write to.  Note that we store the original value in
        # a private attribute and then make the public ostream a property, so
        # that we can delay accessing io.stdout until runtime.  The way
        # things are written now, the io.stdout object is dynamically managed
        # so a reference to it should NEVER be stored statically.  This
        # property approach confines this detail to a single location, and all
        # subclasses can simply access self.ostream for writing.
        self._ostream = ostream

        # Create color table
        self.color_scheme_table = exception_colors()

        self.set_colors(color_scheme)
        self.old_scheme = color_scheme  # save initial value for toggles

        if call_pdb:
            self.pdb = debugger.Pdb(self.color_scheme_table.active_scheme_name)
        else:
            self.pdb = None

    def _get_ostream(self):
        """Output stream that exceptions are written to.

        Valid values are:

        - None: the default, which means that IPython will dynamically resolve
          to io.stdout.  This ensures compatibility with most tools, including
          Windows (where plain stdout doesn't recognize ANSI escapes).

        - Any object with 'write' and 'flush' attributes.
        """
        return io.stdout if self._ostream is None else self._ostream

    def _set_ostream(self, val):
        assert val is None or (hasattr(val, 'write') and hasattr(val, 'flush'))
        self._ostream = val

    ostream = property(_get_ostream, _set_ostream)

    def set_colors(self,*args,**kw):
        """Shorthand access to the color table scheme selector method."""

        # Set own color table
        self.color_scheme_table.set_active_scheme(*args,**kw)
        # for convenience, set Colors to the active scheme
        self.Colors = self.color_scheme_table.active_colors
        # Also set colors of debugger
        if hasattr(self,'pdb') and self.pdb is not None:
            self.pdb.set_colors(*args,**kw)

    def color_toggle(self):
        """Toggle between the currently active color scheme and NoColor."""

        if self.color_scheme_table.active_scheme_name == 'NoColor':
            # Restore the scheme that was active before the last toggle.
            self.color_scheme_table.set_active_scheme(self.old_scheme)
            self.Colors = self.color_scheme_table.active_colors
        else:
            # Remember the current scheme so it can be restored later.
            self.old_scheme = self.color_scheme_table.active_scheme_name
            self.color_scheme_table.set_active_scheme('NoColor')
            self.Colors = self.color_scheme_table.active_colors

    def stb2text(self, stb):
        """Convert a structured traceback (a list) to a string."""
        return '\n'.join(stb)

    def text(self, etype, value, tb, tb_offset=None, context=5):
        """Return formatted traceback.

        Subclasses may override this if they add extra arguments.
        """
        tb_list = self.structured_traceback(etype, value, tb,
                                            tb_offset, context)
        return self.stb2text(tb_list)

    def structured_traceback(self, etype, evalue, tb, tb_offset=None,
                             context=5, mode=None):
        """Return a list of traceback frames.

        Must be implemented by each class.
        """
        raise NotImplementedError()
422 422
423 423
424 424 #---------------------------------------------------------------------------
class ListTB(TBTools):
    """Print traceback information from a traceback list, with optional color.

    Calling requires 3 arguments: (etype, evalue, elist)
    as would be obtained by::

      etype, evalue, tb = sys.exc_info()
      if tb:
        elist = traceback.extract_tb(tb)
      else:
        elist = None

    It can thus be used by programs which need to process the traceback before
    printing (such as console replacements based on the code module from the
    standard library).

    Because they are meant to be called without a full traceback (only a
    list), instances of this class can't call the interactive pdb debugger."""

    def __init__(self,color_scheme = 'NoColor', call_pdb=False, ostream=None):
        TBTools.__init__(self, color_scheme=color_scheme, call_pdb=call_pdb,
                         ostream=ostream)

    def __call__(self, etype, value, elist):
        # Format the traceback and write it to our output stream.
        self.ostream.flush()
        self.ostream.write(self.text(etype, value, elist))
        self.ostream.write('\n')

    def structured_traceback(self, etype, value, elist, tb_offset=None,
                             context=5):
        """Return a color formatted string with the traceback info.

        Parameters
        ----------
        etype : exception type
          Type of the exception raised.

        value : object
          Data stored in the exception

        elist : list
          List of frames, see class docstring for details.

        tb_offset : int, optional
          Number of frames in the traceback to skip.  If not given, the
          instance value is used (set in constructor).

        context : int, optional
          Number of lines of context information to print.

        Returns
        -------
        String with formatted exception.
        """
        tb_offset = self.tb_offset if tb_offset is None else tb_offset
        Colors = self.Colors
        out_list = []
        if elist:

            if tb_offset and len(elist) > tb_offset:
                elist = elist[tb_offset:]

            out_list.append('Traceback %s(most recent call last)%s:' %
                            (Colors.normalEm, Colors.Normal) + '\n')
            out_list.extend(self._format_list(elist))
        # The exception info should be a single entry in the list.
        lines = ''.join(self._format_exception_only(etype, value))
        out_list.append(lines)

        # Note: this code originally read:

        ## for line in lines[:-1]:
        ##     out_list.append(" "+line)
        ## out_list.append(lines[-1])

        # This means it was indenting everything but the last line by a little
        # bit.  I've disabled this for now, but if we see ugliness somewhere we
        # can restore it.

        return out_list

    def _format_list(self, extracted_list):
        """Format a list of traceback entry tuples for printing.

        Given a list of tuples as returned by extract_tb() or
        extract_stack(), return a list of strings ready for printing.
        Each string in the resulting list corresponds to the item with the
        same index in the argument list.  Each string ends in a newline;
        the strings may contain internal newlines as well, for those items
        whose source text line is not None.

        Lifted almost verbatim from traceback.py
        """

        Colors = self.Colors
        list = []
        for filename, lineno, name, line in extracted_list[:-1]:
            item = '  File %s"%s"%s, line %s%d%s, in %s%s%s\n' % \
                    (Colors.filename, filename, Colors.Normal,
                     Colors.lineno, lineno, Colors.Normal,
                     Colors.name, name, Colors.Normal)
            if line:
                item += '    %s\n' % line.strip()
            list.append(item)
        # Emphasize the last entry
        filename, lineno, name, line = extracted_list[-1]
        item = '%s  File %s"%s"%s, line %s%d%s, in %s%s%s%s\n' % \
                (Colors.normalEm,
                 Colors.filenameEm, filename, Colors.normalEm,
                 Colors.linenoEm, lineno, Colors.normalEm,
                 Colors.nameEm, name, Colors.normalEm,
                 Colors.Normal)
        if line:
            item += '%s    %s%s\n' % (Colors.line, line.strip(),
                                      Colors.Normal)
        list.append(item)
        #from pprint import pformat; print 'LISTTB', pformat(list) # dbg
        return list

    def _format_exception_only(self, etype, value):
        """Format the exception part of a traceback.

        The arguments are the exception type and value such as given by
        sys.exc_info()[:2].  The return value is a list of strings, each ending
        in a newline.  Normally, the list contains a single string; however,
        for SyntaxError exceptions, it contains several lines that (when
        printed) display detailed information about where the syntax error
        occurred.  The message indicating which exception occurred is the
        always last string in the list.

        Also lifted nearly verbatim from traceback.py
        """
        have_filedata = False
        Colors = self.Colors
        list = []
        stype = Colors.excName + etype.__name__ + Colors.Normal
        if value is None:
            # Not sure if this can still happen in Python 2.6 and above
            list.append( py3compat.cast_unicode(stype) + '\n')
        else:
            if issubclass(etype, SyntaxError):
                # SyntaxErrors carry file/line info on the value itself
                # rather than in a traceback, so format it specially.
                have_filedata = True
                #print 'filename is',filename # dbg
                if not value.filename: value.filename = "<string>"
                if value.lineno:
                    lineno = value.lineno
                    textline = ulinecache.getline(value.filename, value.lineno)
                else:
                    lineno = 'unknown'
                    textline = ''
                list.append('%s  File %s"%s"%s, line %s%s%s\n' % \
                        (Colors.normalEm,
                         Colors.filenameEm, py3compat.cast_unicode(value.filename), Colors.normalEm,
                         Colors.linenoEm, lineno, Colors.Normal  ))
                if textline == '':
                    textline = py3compat.cast_unicode(value.text, "utf-8")

                if textline is not None:
                    # Skip leading whitespace so the caret lines up with the
                    # stripped source line we print.
                    i = 0
                    while i < len(textline) and textline[i].isspace():
                        i += 1
                    list.append('%s    %s%s\n' % (Colors.line,
                                                  textline.strip(),
                                                  Colors.Normal))
                    if value.offset is not None:
                        s = '    '
                        for c in textline[i:value.offset-1]:
                            if c.isspace():
                                s += c
                            else:
                                s += ' '
                        list.append('%s%s^%s\n' % (Colors.caret, s,
                                                   Colors.Normal) )

            try:
                s = value.msg
            except Exception:
                s = self._some_str(value)
            if s:
                list.append('%s%s:%s %s\n' % (str(stype), Colors.excName,
                                              Colors.Normal, s))
            else:
                list.append('%s\n' % str(stype))

        # sync with user hooks
        if have_filedata:
            ipinst = get_ipython()
            if ipinst is not None:
                ipinst.hooks.synchronize_with_editor(value.filename, value.lineno, 0)

        return list

    def get_exception_only(self, etype, value):
        """Only print the exception type and message, without a traceback.

        Parameters
        ----------
        etype : exception type
        value : exception value
        """
        return ListTB.structured_traceback(self, etype, value, [])


    def show_exception_only(self, etype, evalue):
        """Only print the exception type and message, without a traceback.

        Parameters
        ----------
        etype : exception type
        value : exception value
        """
        # This method needs to use __call__ from *this* class, not the one from
        # a subclass whose signature or behavior may be different
        ostream = self.ostream
        ostream.flush()
        ostream.write('\n'.join(self.get_exception_only(etype, evalue)))
        ostream.flush()

    def _some_str(self, value):
        # Lifted from traceback.py
        try:
            return str(value)
        except:
            return '<unprintable %s object>' % type(value).__name__
649 649
650 650 #----------------------------------------------------------------------------
651 651 class VerboseTB(TBTools):
652 652 """A port of Ka-Ping Yee's cgitb.py module that outputs color text instead
653 653 of HTML. Requires inspect and pydoc. Crazy, man.
654 654
655 655 Modified version which optionally strips the topmost entries from the
656 656 traceback, to be used with alternate interpreters (because their own code
657 657 would appear in the traceback)."""
658 658
659 659 def __init__(self,color_scheme = 'Linux', call_pdb=False, ostream=None,
660 660 tb_offset=0, long_header=False, include_vars=True,
661 661 check_cache=None):
662 662 """Specify traceback offset, headers and color scheme.
663 663
664 664 Define how many frames to drop from the tracebacks. Calling it with
665 665 tb_offset=1 allows use of this handler in interpreters which will have
666 666 their own code at the top of the traceback (VerboseTB will first
667 667 remove that frame before printing the traceback info)."""
668 668 TBTools.__init__(self, color_scheme=color_scheme, call_pdb=call_pdb,
669 669 ostream=ostream)
670 670 self.tb_offset = tb_offset
671 671 self.long_header = long_header
672 672 self.include_vars = include_vars
673 673 # By default we use linecache.checkcache, but the user can provide a
674 674 # different check_cache implementation. This is used by the IPython
675 675 # kernel to provide tracebacks for interactive code that is cached,
676 676 # by a compiler instance that flushes the linecache but preserves its
677 677 # own code cache.
678 678 if check_cache is None:
679 679 check_cache = linecache.checkcache
680 680 self.check_cache = check_cache
681 681
682 682 def structured_traceback(self, etype, evalue, etb, tb_offset=None,
683 683 context=5):
684 684 """Return a nice text document describing the traceback."""
685 685
686 686 tb_offset = self.tb_offset if tb_offset is None else tb_offset
687 687
688 688 # some locals
689 689 try:
690 690 etype = etype.__name__
691 691 except AttributeError:
692 692 pass
693 693 Colors = self.Colors # just a shorthand + quicker name lookup
694 694 ColorsNormal = Colors.Normal # used a lot
695 695 col_scheme = self.color_scheme_table.active_scheme_name
696 696 indent = ' '*INDENT_SIZE
697 697 em_normal = '%s\n%s%s' % (Colors.valEm, indent,ColorsNormal)
698 698 undefined = '%sundefined%s' % (Colors.em, ColorsNormal)
699 699 exc = '%s%s%s' % (Colors.excName,etype,ColorsNormal)
700 700
701 701 # some internal-use functions
702 702 def text_repr(value):
703 703 """Hopefully pretty robust repr equivalent."""
704 704 # this is pretty horrible but should always return *something*
705 705 try:
706 706 return pydoc.text.repr(value)
707 707 except KeyboardInterrupt:
708 708 raise
709 709 except:
710 710 try:
711 711 return repr(value)
712 712 except KeyboardInterrupt:
713 713 raise
714 714 except:
715 715 try:
716 716 # all still in an except block so we catch
717 717 # getattr raising
718 718 name = getattr(value, '__name__', None)
719 719 if name:
720 720 # ick, recursion
721 721 return text_repr(name)
722 722 klass = getattr(value, '__class__', None)
723 723 if klass:
724 724 return '%s instance' % text_repr(klass)
725 725 except KeyboardInterrupt:
726 726 raise
727 727 except:
728 728 return 'UNRECOVERABLE REPR FAILURE'
729 729 def eqrepr(value, repr=text_repr): return '=%s' % repr(value)
730 730 def nullrepr(value, repr=text_repr): return ''
731 731
732 732 # meat of the code begins
733 733 try:
734 734 etype = etype.__name__
735 735 except AttributeError:
736 736 pass
737 737
738 738 if self.long_header:
739 739 # Header with the exception type, python version, and date
740 740 pyver = 'Python ' + sys.version.split()[0] + ': ' + sys.executable
741 741 date = time.ctime(time.time())
742 742
743 743 head = '%s%s%s\n%s%s%s\n%s' % (Colors.topline, '-'*75, ColorsNormal,
744 744 exc, ' '*(75-len(str(etype))-len(pyver)),
745 745 pyver, date.rjust(75) )
746 746 head += "\nA problem occured executing Python code. Here is the sequence of function"\
747 747 "\ncalls leading up to the error, with the most recent (innermost) call last."
748 748 else:
749 749 # Simplified header
750 750 head = '%s%s%s\n%s%s' % (Colors.topline, '-'*75, ColorsNormal,exc,
751 751 'Traceback (most recent call last)'.\
752 752 rjust(75 - len(str(etype)) ) )
753 753 frames = []
754 754 # Flush cache before calling inspect. This helps alleviate some of the
755 755 # problems with python 2.3's inspect.py.
756 756 ##self.check_cache()
757 757 # Drop topmost frames if requested
758 758 try:
759 759 # Try the default getinnerframes and Alex's: Alex's fixes some
760 760 # problems, but it generates empty tracebacks for console errors
761 761 # (5 blanks lines) where none should be returned.
762 762 #records = inspect.getinnerframes(etb, context)[tb_offset:]
763 763 #print 'python records:', records # dbg
764 764 records = _fixed_getinnerframes(etb, context, tb_offset)
765 765 #print 'alex records:', records # dbg
766 766 except:
767 767
768 768 # FIXME: I've been getting many crash reports from python 2.3
769 769 # users, traceable to inspect.py. If I can find a small test-case
770 770 # to reproduce this, I should either write a better workaround or
771 771 # file a bug report against inspect (if that's the real problem).
772 772 # So far, I haven't been able to find an isolated example to
773 773 # reproduce the problem.
774 774 inspect_error()
775 775 traceback.print_exc(file=self.ostream)
776 776 info('\nUnfortunately, your original traceback can not be constructed.\n')
777 777 return ''
778 778
779 779 # build some color string templates outside these nested loops
780 780 tpl_link = '%s%%s%s' % (Colors.filenameEm,ColorsNormal)
781 781 tpl_call = 'in %s%%s%s%%s%s' % (Colors.vName, Colors.valEm,
782 782 ColorsNormal)
783 783 tpl_call_fail = 'in %s%%s%s(***failed resolving arguments***)%s' % \
784 784 (Colors.vName, Colors.valEm, ColorsNormal)
785 785 tpl_local_var = '%s%%s%s' % (Colors.vName, ColorsNormal)
786 786 tpl_global_var = '%sglobal%s %s%%s%s' % (Colors.em, ColorsNormal,
787 787 Colors.vName, ColorsNormal)
788 788 tpl_name_val = '%%s %s= %%s%s' % (Colors.valEm, ColorsNormal)
789 789 tpl_line = '%s%%s%s %%s' % (Colors.lineno, ColorsNormal)
790 790 tpl_line_em = '%s%%s%s %%s%s' % (Colors.linenoEm,Colors.line,
791 791 ColorsNormal)
792 792
793 793 # now, loop over all records printing context and info
794 794 abspath = os.path.abspath
795 795 for frame, file, lnum, func, lines, index in records:
796 796 #print '*** record:',file,lnum,func,lines,index # dbg
797 797 if not file:
798 798 file = '?'
799 799 elif not(file.startswith(str("<")) and file.endswith(str(">"))):
800 800 # Guess that filenames like <string> aren't real filenames, so
801 801 # don't call abspath on them.
802 802 try:
803 803 file = abspath(file)
804 804 except OSError:
805 805 # Not sure if this can still happen: abspath now works with
806 806 # file names like <string>
807 807 pass
808 808 file = py3compat.cast_unicode(file, util_path.fs_encoding)
809 809 link = tpl_link % file
810 810 args, varargs, varkw, locals = inspect.getargvalues(frame)
811 811
812 812 if func == '?':
813 813 call = ''
814 814 else:
815 815 # Decide whether to include variable details or not
816 816 var_repr = self.include_vars and eqrepr or nullrepr
817 817 try:
818 818 call = tpl_call % (func,inspect.formatargvalues(args,
819 819 varargs, varkw,
820 820 locals,formatvalue=var_repr))
821 821 except KeyError:
822 822 # This happens in situations like errors inside generator
823 823 # expressions, where local variables are listed in the
824 824 # line, but can't be extracted from the frame. I'm not
825 825 # 100% sure this isn't actually a bug in inspect itself,
826 826 # but since there's no info for us to compute with, the
827 827 # best we can do is report the failure and move on. Here
828 828 # we must *not* call any traceback construction again,
829 829 # because that would mess up use of %debug later on. So we
830 830 # simply report the failure and move on. The only
831 831 # limitation will be that this frame won't have locals
832 832 # listed in the call signature. Quite subtle problem...
833 833 # I can't think of a good way to validate this in a unit
834 834 # test, but running a script consisting of:
835 835 # dict( (k,v.strip()) for (k,v) in range(10) )
836 836 # will illustrate the error, if this exception catch is
837 837 # disabled.
838 838 call = tpl_call_fail % func
839 839
840 840 # Don't attempt to tokenize binary files.
841 841 if file.endswith(('.so', '.pyd', '.dll')):
842 842 frames.append('%s %s\n' % (link,call))
843 843 continue
844 844 elif file.endswith(('.pyc','.pyo')):
845 845 # Look up the corresponding source file.
846 846 file = openpy.source_from_cache(file)
847 847
848 848 def linereader(file=file, lnum=[lnum], getline=ulinecache.getline):
849 849 line = getline(file, lnum[0])
850 850 lnum[0] += 1
851 851 return line
852 852
853 853 # Build the list of names on this line of code where the exception
854 854 # occurred.
855 855 try:
856 856 names = []
857 857 name_cont = False
858 858
859 859 for token_type, token, start, end, line in generate_tokens(linereader):
860 860 # build composite names
861 861 if token_type == tokenize.NAME and token not in keyword.kwlist:
862 862 if name_cont:
863 863 # Continuation of a dotted name
864 864 try:
865 865 names[-1].append(token)
866 866 except IndexError:
867 867 names.append([token])
868 868 name_cont = False
869 869 else:
870 870 # Regular new names. We append everything, the caller
871 871 # will be responsible for pruning the list later. It's
872 872 # very tricky to try to prune as we go, b/c composite
873 873 # names can fool us. The pruning at the end is easy
874 874 # to do (or the caller can print a list with repeated
875 875 # names if so desired.
876 876 names.append([token])
877 877 elif token == '.':
878 878 name_cont = True
879 879 elif token_type == tokenize.NEWLINE:
880 880 break
881 881
882 882 except (IndexError, UnicodeDecodeError):
883 883 # signals exit of tokenizer
884 884 pass
885 885 except tokenize.TokenError as msg:
886 886 _m = ("An unexpected error occurred while tokenizing input\n"
887 887 "The following traceback may be corrupted or invalid\n"
888 888 "The error message is: %s\n" % msg)
889 889 error(_m)
890 890
891 891 # Join composite names (e.g. "dict.fromkeys")
892 892 names = ['.'.join(n) for n in names]
893 893 # prune names list of duplicates, but keep the right order
894 894 unique_names = uniq_stable(names)
895 895
896 896 # Start loop over vars
897 897 lvals = []
898 898 if self.include_vars:
899 899 for name_full in unique_names:
900 900 name_base = name_full.split('.',1)[0]
901 901 if name_base in frame.f_code.co_varnames:
902 902 if name_base in locals:
903 903 try:
904 904 value = repr(eval(name_full,locals))
905 905 except:
906 906 value = undefined
907 907 else:
908 908 value = undefined
909 909 name = tpl_local_var % name_full
910 910 else:
911 911 if name_base in frame.f_globals:
912 912 try:
913 913 value = repr(eval(name_full,frame.f_globals))
914 914 except:
915 915 value = undefined
916 916 else:
917 917 value = undefined
918 918 name = tpl_global_var % name_full
919 919 lvals.append(tpl_name_val % (name,value))
920 920 if lvals:
921 921 lvals = '%s%s' % (indent,em_normal.join(lvals))
922 922 else:
923 923 lvals = ''
924 924
925 925 level = '%s %s\n' % (link,call)
926 926
927 927 if index is None:
928 928 frames.append(level)
929 929 else:
930 930 frames.append('%s%s' % (level,''.join(
931 931 _format_traceback_lines(lnum,index,lines,Colors,lvals,
932 932 col_scheme))))
933 933
934 934 # Get (safely) a string form of the exception info
935 935 try:
936 936 etype_str,evalue_str = map(str,(etype,evalue))
937 937 except:
938 938 # User exception is improperly defined.
939 939 etype,evalue = str,sys.exc_info()[:2]
940 940 etype_str,evalue_str = map(str,(etype,evalue))
941 941 # ... and format it
942 942 exception = ['%s%s%s: %s' % (Colors.excName, etype_str,
943 943 ColorsNormal, py3compat.cast_unicode(evalue_str))]
944 944 if (not py3compat.PY3) and type(evalue) is types.InstanceType:
945 945 try:
946 946 names = [w for w in dir(evalue) if isinstance(w, py3compat.string_types)]
947 947 except:
948 948 # Every now and then, an object with funny inernals blows up
949 949 # when dir() is called on it. We do the best we can to report
950 950 # the problem and continue
951 951 _m = '%sException reporting error (object with broken dir())%s:'
952 952 exception.append(_m % (Colors.excName,ColorsNormal))
953 953 etype_str,evalue_str = map(str,sys.exc_info()[:2])
954 954 exception.append('%s%s%s: %s' % (Colors.excName,etype_str,
955 955 ColorsNormal, py3compat.cast_unicode(evalue_str)))
956 956 names = []
957 957 for name in names:
958 958 value = text_repr(getattr(evalue, name))
959 959 exception.append('\n%s%s = %s' % (indent, name, value))
960 960
961 961 # vds: >>
962 962 if records:
963 963 filepath, lnum = records[-1][1:3]
964 964 #print "file:", str(file), "linenb", str(lnum) # dbg
965 965 filepath = os.path.abspath(filepath)
966 966 ipinst = get_ipython()
967 967 if ipinst is not None:
968 968 ipinst.hooks.synchronize_with_editor(filepath, lnum, 0)
969 969 # vds: <<
970 970
971 971 # return all our info assembled as a single string
972 972 # return '%s\n\n%s\n%s' % (head,'\n'.join(frames),''.join(exception[0]) )
973 973 return [head] + frames + [''.join(exception[0])]
974 974
    def debugger(self,force=False):
        """Call up the pdb debugger if desired, always clean up the tb
        reference.

        Keywords:

        - force(False): by default, this routine checks the instance call_pdb
        flag and does not actually invoke the debugger if the flag is false.
        The 'force' option forces the debugger to activate even if the flag
        is false.

        If the call_pdb flag is set, the pdb interactive debugger is
        invoked. In all cases, the self.tb reference to the current traceback
        is deleted to prevent lingering references which hamper memory
        management.

        Note that each call to pdb() does an 'import readline', so if your app
        requires a special setup for the readline completers, you'll have to
        fix that by hand after invoking the exception handler."""

        if force or self.call_pdb:
            # Lazily create the Pdb instance on first use.
            # NOTE(review): `debugger` here resolves to the module imported at
            # file level, not this method (methods don't shadow globals).
            if self.pdb is None:
                self.pdb = debugger.Pdb(
                    self.color_scheme_table.active_scheme_name)
            # the system displayhook may have changed, restore the original
            # for pdb
            display_trap = DisplayTrap(hook=sys.__displayhook__)
            with display_trap:
                self.pdb.reset()
                # Find the right frame so we don't pop up inside ipython itself
                if hasattr(self,'tb') and self.tb is not None:
                    etb = self.tb
                else:
                    etb = self.tb = sys.last_traceback
                # Walk self.tb to the innermost frame (where the error really
                # occurred); pdb interaction starts there.
                while self.tb is not None and self.tb.tb_next is not None:
                    self.tb = self.tb.tb_next
                # Skip one outermost frame so we don't land in IPython's own
                # machinery when a next frame exists.
                if etb and etb.tb_next:
                    etb = etb.tb_next
                self.pdb.botframe = etb.tb_frame
                self.pdb.interaction(self.tb.tb_frame, self.tb)

        # Always drop the traceback reference to avoid keeping frames alive.
        if hasattr(self,'tb'):
            del self.tb
1018 1018
1019 1019 def handler(self, info=None):
1020 1020 (etype, evalue, etb) = info or sys.exc_info()
1021 1021 self.tb = etb
1022 1022 ostream = self.ostream
1023 1023 ostream.flush()
1024 1024 ostream.write(self.text(etype, evalue, etb))
1025 1025 ostream.write('\n')
1026 1026 ostream.flush()
1027 1027
1028 1028 # Changed so an instance can just be called as VerboseTB_inst() and print
1029 1029 # out the right info on its own.
1030 1030 def __call__(self, etype=None, evalue=None, etb=None):
1031 1031 """This hook can replace sys.excepthook (for Python 2.1 or higher)."""
1032 1032 if etb is None:
1033 1033 self.handler()
1034 1034 else:
1035 1035 self.handler((etype, evalue, etb))
1036 1036 try:
1037 1037 self.debugger()
1038 1038 except KeyboardInterrupt:
1039 1039 print("\nKeyboardInterrupt")
1040 1040
#----------------------------------------------------------------------------
class FormattedTB(VerboseTB, ListTB):
    """Subclass ListTB but allow calling with a traceback.

    It can thus be used as a sys.excepthook for Python > 2.1.

    Also adds 'Context' and 'Verbose' modes, not available in ListTB.

    Allows a tb_offset to be specified. This is useful for situations where
    one needs to remove a number of topmost frames from the traceback (such as
    occurs with python programs that themselves execute other python code,
    like Python shells). """

    def __init__(self, mode='Plain', color_scheme='Linux', call_pdb=False,
                 ostream=None,
                 tb_offset=0, long_header=False, include_vars=False,
                 check_cache=None):

        # NEVER change the order of this list. Put new modes at the end:
        self.valid_modes = ['Plain','Context','Verbose']
        self.verbose_modes = self.valid_modes[1:3]

        VerboseTB.__init__(self, color_scheme=color_scheme, call_pdb=call_pdb,
                           ostream=ostream, tb_offset=tb_offset,
                           long_header=long_header, include_vars=include_vars,
                           check_cache=check_cache)

        # Different types of tracebacks are joined with different separators to
        # form a single string. They are taken from this dict
        self._join_chars = dict(Plain='', Context='\n', Verbose='\n')
        # set_mode also sets the tb_join_char attribute
        self.set_mode(mode)

    def _extract_tb(self,tb):
        # Convert a traceback object to the (filename, lineno, name, line)
        # tuples that ListTB consumes; None passes through unchanged.
        if tb:
            return traceback.extract_tb(tb)
        else:
            return None

    def structured_traceback(self, etype, value, tb, tb_offset=None, context=5):
        """Dispatch to VerboseTB or ListTB formatting depending on self.mode."""
        tb_offset = self.tb_offset if tb_offset is None else tb_offset
        mode = self.mode
        if mode in self.verbose_modes:
            # Verbose modes need a full traceback
            return VerboseTB.structured_traceback(
                self, etype, value, tb, tb_offset, context
            )
        else:
            # We must check the source cache because otherwise we can print
            # out-of-date source code.
            self.check_cache()
            # Now we can extract and format the exception
            elist = self._extract_tb(tb)
            return ListTB.structured_traceback(
                self, etype, value, elist, tb_offset, context
            )

    def stb2text(self, stb):
        """Convert a structured traceback (a list) to a string."""
        return self.tb_join_char.join(stb)


    def set_mode(self,mode=None):
        """Switch to the desired mode.

        If mode is not specified, cycles through the available modes.

        Raises ValueError if `mode` is not one of self.valid_modes."""

        if not mode:
            # Cycle: Plain -> Context -> Verbose -> Plain ...
            new_idx = ( self.valid_modes.index(self.mode) + 1 ) % \
                      len(self.valid_modes)
            self.mode = self.valid_modes[new_idx]
        elif mode not in self.valid_modes:
            raise ValueError('Unrecognized mode in FormattedTB: <'+mode+'>\n'
                             'Valid modes: '+str(self.valid_modes))
        else:
            self.mode = mode
        # include variable details only in 'Verbose' mode
        self.include_vars = (self.mode == self.valid_modes[2])
        # Set the join character for generating text tracebacks
        self.tb_join_char = self._join_chars[self.mode]

    # some convenient shortcuts
    def plain(self):
        self.set_mode(self.valid_modes[0])

    def context(self):
        self.set_mode(self.valid_modes[1])

    def verbose(self):
        self.set_mode(self.valid_modes[2])
1131 1131
#----------------------------------------------------------------------------
class AutoFormattedTB(FormattedTB):
    """A traceback printer which can be called on the fly.

    It will find out about exceptions by itself.

    A brief example::

        AutoTB = AutoFormattedTB(mode = 'Verbose',color_scheme='Linux')
        try:
          ...
        except:
          AutoTB()  # or AutoTB(out=logfile) where logfile is an open file object
    """

    def __call__(self,etype=None,evalue=None,etb=None,
                 out=None,tb_offset=None):
        """Print out a formatted exception traceback.

        Optional arguments:
          - out: an open file-like object to direct output to.

          - tb_offset: the number of frames to skip over in the stack, on a
          per-call basis (this overrides temporarily the instance's tb_offset
          given at initialization time. """


        if out is None:
            out = self.ostream
        out.flush()
        out.write(self.text(etype, evalue, etb, tb_offset))
        out.write('\n')
        out.flush()
        # FIXME: we should remove the auto pdb behavior from here and leave
        # that to the clients.
        try:
            self.debugger()
        except KeyboardInterrupt:
            print("\nKeyboardInterrupt")

    def structured_traceback(self, etype=None, value=None, tb=None,
                             tb_offset=None, context=5):
        """Like FormattedTB.structured_traceback, but defaults the exception
        info to sys.exc_info() when etype is not given."""
        if etype is None:
            etype,value,tb = sys.exc_info()
        # Keep a reference for %debug / debugger() to find later.
        self.tb = tb
        return FormattedTB.structured_traceback(
            self, etype, value, tb, tb_offset, context)
1179 1179
#---------------------------------------------------------------------------

# A simple class to preserve Nathan's original functionality.
class ColorTB(FormattedTB):
    """Shorthand to initialize a FormattedTB in Linux colors mode."""

    def __init__(self, color_scheme='Linux', call_pdb=0):
        # Pure convenience wrapper: delegate everything to FormattedTB.
        FormattedTB.__init__(self, color_scheme=color_scheme,
                             call_pdb=call_pdb)
1188 1188
1189 1189
class SyntaxTB(ListTB):
    """Extension which holds some state: the last exception value"""

    def __init__(self,color_scheme = 'NoColor'):
        # Remember the most recent SyntaxError value so callers can query it.
        ListTB.__init__(self,color_scheme)
        self.last_syntax_error = None

    def __call__(self, etype, value, elist):
        """Record the error, then print it via ListTB."""
        self.last_syntax_error = value
        ListTB.__call__(self,etype,value,elist)

    def structured_traceback(self, etype, value, elist, tb_offset=None,
                             context=5):
        """Return a list of traceback strings, refreshing the error line.

        If the source file has been edited, the line in the syntax error can
        be wrong (retrieved from an outdated cache). This replaces it with
        the current value."""
        if isinstance(value, SyntaxError) \
            and isinstance(value.filename, py3compat.string_types) \
            and isinstance(value.lineno, int):
            linecache.checkcache(value.filename)
            newtext = ulinecache.getline(value.filename, value.lineno)
            if newtext:
                value.text = newtext
        return super(SyntaxTB, self).structured_traceback(etype, value, elist,
                                                          tb_offset=tb_offset, context=context)

    def clear_err_state(self):
        """Return the current error state and clear it"""
        e = self.last_syntax_error
        self.last_syntax_error = None
        return e

    def stb2text(self, stb):
        """Convert a structured traceback (a list) to a string."""
        return ''.join(stb)
1225 1225
1226 1226
#----------------------------------------------------------------------------
# module testing (minimal)
if __name__ == "__main__":
    # A tiny call chain that deliberately fails so each traceback printer
    # below has something to format: spam(1, (2, 3)) -> foo(3, 2) ->
    # eggs(3, 3) divides by i = f - g = 0, raising ZeroDivisionError.
    def spam(c, d_e):
        (d, e) = d_e
        x = c + d
        y = c * d
        foo(x, y)

    def foo(a, b, bar=1):
        eggs(a, b + bar)

    def eggs(f, g, z=globals()):
        h = f + g
        i = f - g
        return h / i  # ZeroDivisionError when f == g

    # Baseline: the stdlib's own traceback formatting.
    print('')
    print('*** Before ***')
    try:
        print(spam(1, (2, 3)))
    except:
        traceback.print_exc()
    print('')

    # Same failure rendered by ColorTB (plain mode, Linux colors).
    handler = ColorTB()
    print('*** ColorTB ***')
    try:
        print(spam(1, (2, 3)))
    except:
        handler(*sys.exc_info())
    print('')

    # And by VerboseTB, which includes frame/variable detail.
    handler = VerboseTB()
    print('*** VerboseTB ***')
    try:
        print(spam(1, (2, 3)))
    except:
        handler(*sys.exc_info())
    print('')
@@ -1,221 +1,221 b''
1 1 ########################## LICENCE ###############################
2 2
3 3 # Copyright (c) 2005-2012, Michele Simionato
4 4 # All rights reserved.
5 5
6 6 # Redistribution and use in source and binary forms, with or without
7 7 # modification, are permitted provided that the following conditions are
8 8 # met:
9 9
10 10 # Redistributions of source code must retain the above copyright
11 11 # notice, this list of conditions and the following disclaimer.
12 12 # Redistributions in bytecode form must reproduce the above copyright
13 13 # notice, this list of conditions and the following disclaimer in
14 14 # the documentation and/or other materials provided with the
15 15 # distribution.
16 16
17 17 # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
18 18 # "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
19 19 # LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
20 20 # A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
21 21 # HOLDERS OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
22 22 # INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
23 23 # BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS
24 24 # OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
25 25 # ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR
26 26 # TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
27 27 # USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH
28 28 # DAMAGE.
29 29
30 30 """
31 31 Decorator module, see http://pypi.python.org/pypi/decorator
32 32 for the documentation.
33 33 """
34 34 from __future__ import print_function
35 35
36 36 __version__ = '3.3.3'
37 37
38 38 __all__ = ["decorator", "FunctionMaker", "partial"]
39 39
40 40 import sys, re, inspect
41 41
try:
    from functools import partial
except ImportError: # for Python version < 2.5
    class partial(object):
        "A simple replacement of functools.partial"
        def __init__(self, func, *args, **kw):
            # Freeze the positional and keyword arguments for later calls.
            self.func = func
            self.args = args
            self.keywords = kw
        def __call__(self, *otherargs, **otherkw):
            # Call-time keywords override the frozen ones, matching
            # functools.partial semantics.
            kw = self.keywords.copy()
            kw.update(otherkw)
            return self.func(*(self.args + otherargs), **kw)
55 55
if sys.version >= '3':
    from inspect import getfullargspec
else:
    class getfullargspec(object):
        "A quick and dirty replacement for getfullargspec for Python 2.X"
        def __init__(self, f):
            # Python 2 has no keyword-only arguments, so getargspec covers
            # everything; kwonly fields are filled with empty defaults.
            self.args, self.varargs, self.varkw, self.defaults = \
                inspect.getargspec(f)
            self.kwonlyargs = []
            self.kwonlydefaults = None
        def __iter__(self):
            # Iteration order mirrors the old getargspec 4-tuple, so the
            # object can be star-unpacked like one.
            yield self.args
            yield self.varargs
            yield self.varkw
            yield self.defaults
71 71
72 72 DEF = re.compile('\s*def\s*([_\w][_\w\d]*)\s*\(')
73 73
# basic functionality
class FunctionMaker(object):
    """
    An object with the ability to create functions with a given signature.
    It has attributes name, doc, module, signature, defaults, dict and
    methods update and make.
    """
    def __init__(self, func=None, name=None, signature=None,
                 defaults=None, doc=None, module=None, funcdict=None):
        self.shortsignature = signature
        if func:
            # func can be a class or a callable, but not an instance method
            self.name = func.__name__
            if self.name == '<lambda>': # small hack for lambda functions
                self.name = '_lambda_'
            self.doc = func.__doc__
            self.module = func.__module__
            if inspect.isfunction(func):
                argspec = getfullargspec(func)
                self.annotations = getattr(func, '__annotations__', {})
                # Mirror every argspec field onto this object so templates can
                # refer to them via %(name)s substitution in make().
                for a in ('args', 'varargs', 'varkw', 'defaults', 'kwonlyargs',
                          'kwonlydefaults'):
                    setattr(self, a, getattr(argspec, a))
                for i, arg in enumerate(self.args):
                    setattr(self, 'arg%d' % i, arg)
                if sys.version < '3': # easy way
                    self.shortsignature = self.signature = \
                        inspect.formatargspec(
                        formatvalue=lambda val: "", *argspec)[1:-1]
                else: # Python 3 way
                    # signature: usable in a def statement (kwonly get =None);
                    # shortsignature: usable in a call (kwonly passed as a=a).
                    self.signature = self.shortsignature = ', '.join(self.args)
                    if self.varargs:
                        self.signature += ', *' + self.varargs
                        self.shortsignature += ', *' + self.varargs
                    if self.kwonlyargs:
                        for a in self.kwonlyargs:
                            self.signature += ', %s=None' % a
                            self.shortsignature += ', %s=%s' % (a, a)
                    if self.varkw:
                        self.signature += ', **' + self.varkw
                        self.shortsignature += ', **' + self.varkw
            self.dict = func.__dict__.copy()
        # func=None happens when decorating a caller
        if name:
            self.name = name
        if signature is not None:
            self.signature = signature
        if defaults:
            self.defaults = defaults
        if doc:
            self.doc = doc
        if module:
            self.module = module
        if funcdict:
            self.dict = funcdict
        # check existence required attributes
        assert hasattr(self, 'name')
        if not hasattr(self, 'signature'):
            raise TypeError('You are decorating a non function: %s' % func)

    def update(self, func, **kw):
        "Update the signature of func with the data in self"
        func.__name__ = self.name
        func.__doc__ = getattr(self, 'doc', None)
        func.__dict__ = getattr(self, 'dict', {})
        func.__defaults__ = getattr(self, 'defaults', ())
        func.__kwdefaults__ = getattr(self, 'kwonlydefaults', None)
        func.__annotations__ = getattr(self, 'annotations', None)
        # Frame 3 is the caller of the decorator machinery; used only as a
        # fallback for __module__ when none was recorded.
        callermodule = sys._getframe(3).f_globals.get('__name__', '?')
        func.__module__ = getattr(self, 'module', callermodule)
        func.__dict__.update(kw)

    def make(self, src_templ, evaldict=None, addsource=False, **attrs):
        "Make a new function from a given template and update the signature"
        src = src_templ % vars(self) # expand name and signature
        evaldict = evaldict or {}
        mo = DEF.match(src)
        if mo is None:
            raise SyntaxError('not a valid function template\n%s' % src)
        name = mo.group(1) # extract the function name
        names = set([name] + [arg.strip(' *') for arg in
                              self.shortsignature.split(',')])
        for n in names:
            # _func_/_call_ are reserved for the decorator plumbing injected
            # into evaldict; a clash would silently shadow them.
            if n in ('_func_', '_call_'):
                raise NameError('%s is overridden in\n%s' % (n, src))
        if not src.endswith('\n'): # add a newline just for safety
            src += '\n' # this is needed in old versions of Python
        try:
            code = compile(src, '<string>', 'single')
            # print >> sys.stderr, 'Compiling %s' % src
            exec(code, evaldict)
        except:
            print('Error in generated code:', file=sys.stderr)
            print(src, file=sys.stderr)
            raise
        func = evaldict[name]
        if addsource:
            attrs['__source__'] = src
        self.update(func, **attrs)
        return func

    @classmethod
    def create(cls, obj, body, evaldict, defaults=None,
               doc=None, module=None, addsource=True, **attrs):
        """
        Create a function from the strings name, signature and body.
        evaldict is the evaluation dictionary. If addsource is true an attribute
        __source__ is added to the result. The attributes attrs are added,
        if any.
        """
        if isinstance(obj, str): # "name(signature)"
            name, rest = obj.strip().split('(', 1)
            signature = rest[:-1] #strip a right parens
            func = None
        else: # a function
            name = None
            signature = None
            func = obj
        self = cls(func, name, signature, defaults, doc, module)
        # Indent the body one level to sit inside the generated def.
        ibody = '\n'.join('    ' + line for line in body.splitlines())
        return self.make('def %(name)s(%(signature)s):\n' + ibody,
                         evaldict, addsource, **attrs)
196 196
def decorator(caller, func=None):
    """
    decorator(caller) converts a caller function into a decorator;
    decorator(caller, func) decorates a function using a caller.
    """
    if func is not None: # returns a decorated function
        # The generated wrapper simply forwards to caller(func, *args, **kw),
        # preserving func's signature via FunctionMaker.
        evaldict = func.__globals__.copy()
        evaldict['_call_'] = caller
        evaldict['_func_'] = func
        return FunctionMaker.create(
            func, "return _call_(_func_, %(shortsignature)s)",
            evaldict, undecorated=func, __wrapped__=func)
    else: # returns a decorator
        if isinstance(caller, partial):
            # A partial caller cannot be introspected; defer by currying.
            return partial(decorator, caller)
        # otherwise assume caller is a function
        first = inspect.getargspec(caller)[0][0] # first arg
        evaldict = caller.__globals__.copy()
        evaldict['_call_'] = caller
        evaldict['decorator'] = decorator
        # Build a one-argument decorator named after the caller, taking the
        # function to decorate as its single (first) parameter.
        return FunctionMaker.create(
            '%s(%s)' % (caller.__name__, first),
            'return decorator(_call_, %s)' % first,
            evaldict, undecorated=caller, __wrapped__=caller,
            doc=caller.__doc__, module=caller.__module__)
@@ -1,226 +1,226 b''
1 1 """Dependency utilities
2 2
3 3 Authors:
4 4
5 5 * Min RK
6 6 """
7 7 #-----------------------------------------------------------------------------
8 8 # Copyright (C) 2013 The IPython Development Team
9 9 #
10 10 # Distributed under the terms of the BSD License. The full license is in
11 11 # the file COPYING, distributed as part of this software.
12 12 #-----------------------------------------------------------------------------
13 13
14 14 from types import ModuleType
15 15
16 16 from IPython.parallel.client.asyncresult import AsyncResult
17 17 from IPython.parallel.error import UnmetDependency
18 18 from IPython.parallel.util import interactive
19 19 from IPython.utils import py3compat
20 20 from IPython.utils.py3compat import string_types
21 21 from IPython.utils.pickleutil import can, uncan
22 22
class depend(object):
    """Dependency decorator, for use with tasks.

    `@depend` lets you define a function for engine dependencies
    just like you use `apply` for tasks.


    Examples
    --------
    ::

        @depend(df, a,b, c=5)
        def f(m,n,p)

        view.apply(f, 1,2,3)

    will call df(a,b,c=5) on the engine, and if it returns False or
    raises an UnmetDependency error, then the task will not be run
    and another engine will be tried.
    """
    def __init__(self, f, *args, **kwargs):
        # Record the dependency-check function and the arguments it will be
        # called with on the engine.
        self.f = f
        self.args = args
        self.kwargs = kwargs

    def __call__(self, f):
        # Decorating a task function wraps it in a picklable `dependent`.
        return dependent(f, self.f, *self.args, **self.kwargs)
50 50
class dependent(object):
    """A function that depends on another function.
    This is an object to prevent the closure used
    in traditional decorators, which are not picklable.
    """

    def __init__(self, f, df, *dargs, **dkwargs):
        # f: the task function; df: the dependency-check function, called
        # with dargs/dkwargs by check_dependency().
        self.f = f
        # Store the wrapped function's name as a plain instance attribute.
        # BUG FIX: the previous code paired this assignment with a Python-2-
        # only ``__name__`` property whose getter returned ``self.__name__``
        # (infinite recursion), and which, being a data descriptor, would
        # also have blocked this very assignment on Python 2. A plain
        # attribute behaves correctly on both Python 2 and 3.
        self.__name__ = getattr(f, '__name__', 'f')
        self.df = df
        self.dargs = dargs
        self.dkwargs = dkwargs

    def check_dependency(self):
        """Raise UnmetDependency if the dependency function returns False."""
        if self.df(*self.dargs, **self.dkwargs) is False:
            raise UnmetDependency()

    def __call__(self, *args, **kwargs):
        """Invoke the wrapped task function."""
        return self.f(*args, **kwargs)
75 75
@interactive
def _require(*modules, **mapping):
    """Helper for @require decorator."""
    # Imports are local because this function's source is shipped to and
    # executed on the engine, where module globals are not available.
    from IPython.parallel.error import UnmetDependency
    from IPython.utils.pickleutil import uncan
    user_ns = globals()
    for name in modules:
        try:
            # Import by name directly into the engine's user namespace.
            exec('import %s' % name, user_ns)
        except ImportError:
            raise UnmetDependency(name)

    # Unpack canned (serialized) objects into the user namespace.
    for name, cobj in mapping.items():
        user_ns[name] = uncan(cobj, user_ns)
    return True
91 91
def require(*objects, **mapping):
    """Simple decorator for requiring local objects and modules to be available
    when the decorated function is called on the engine.

    Modules specified by name or passed directly will be imported
    prior to calling the decorated function.

    Objects other than modules will be pushed as a part of the task.
    Functions can be passed positionally,
    and will be pushed to the engine with their __name__.
    Other objects can be passed by keyword arg.

    Examples
    --------

    In [1]: @require('numpy')
       ...: def norm(a):
       ...:     return numpy.linalg.norm(a,2)

    In [2]: foo = lambda x: x*x
    In [3]: @require(foo)
       ...: def bar(a):
       ...:     return foo(1-a)
    """
    names = []
    for obj in objects:
        # A module object is required by its name, like a string would be.
        if isinstance(obj, ModuleType):
            obj = obj.__name__

        if isinstance(obj, string_types):
            names.append(obj)
        elif hasattr(obj, '__name__'):
            # Named objects (e.g. functions) are pushed under their __name__.
            mapping[obj.__name__] = obj
        else:
            raise TypeError("Objects other than modules and functions "
                "must be passed by kwarg, but got: %s" % type(obj)
            )

    # Serialize ("can") each pushed object so it can travel to the engine.
    for name, obj in mapping.items():
        mapping[name] = can(obj)
    return depend(_require, *names, **mapping)
133 133
class Dependency(set):
    """An object for representing a set of msg_id dependencies.

    Subclassed from set().

    Parameters
    ----------
    dependencies: list/set of msg_ids or AsyncResult objects or output of Dependency.as_dict()
        The msg_ids to depend on
    all : bool [default True]
        Whether the dependency should be considered met when *all* depending tasks have completed
        or only when *any* have been completed.
    success : bool [default True]
        Whether to consider successes as fulfilling dependencies.
    failure : bool [default False]
        Whether to consider failures as fulfilling dependencies.

    If `all=success=True` and `failure=False`, then the task will fail with an ImpossibleDependency
    as soon as the first depended-upon task fails.
    """

    # Class-level defaults; __init__ normally overrides them per instance.
    # BUG FIX: `failure` previously defaulted to True here, contradicting
    # both the docstring above and the __init__ signature below.
    all=True
    success=True
    failure=False

    def __init__(self, dependencies=[], all=True, success=True, failure=False):
        if isinstance(dependencies, dict):
            # load from dict (round-trip of as_dict() output)
            all = dependencies.get('all', True)
            success = dependencies.get('success', success)
            failure = dependencies.get('failure', failure)
            dependencies = dependencies.get('dependencies', [])
        ids = []

        # extract ids from various sources:
        if isinstance(dependencies, string_types + (AsyncResult,)):
            dependencies = [dependencies]
        for d in dependencies:
            if isinstance(d, string_types):
                ids.append(d)
            elif isinstance(d, AsyncResult):
                ids.extend(d.msg_ids)
            else:
                raise TypeError("invalid dependency type: %r"%type(d))

        set.__init__(self, ids)
        self.all = all
        if not (success or failure):
            raise ValueError("Must depend on at least one of successes or failures!")
        self.success=success
        self.failure = failure

    def check(self, completed, failed=None):
        """check whether our dependencies have been met."""
        if len(self) == 0:
            return True
        # Build the set of outcomes that count as fulfilling a dependency.
        against = set()
        if self.success:
            against = completed
        if failed is not None and self.failure:
            against = against.union(failed)
        if self.all:
            return self.issubset(against)
        else:
            return not self.isdisjoint(against)

    def unreachable(self, completed, failed=None):
        """return whether this dependency has become impossible."""
        if len(self) == 0:
            return False
        # Outcomes that do NOT fulfill us; any/all overlap with these means
        # the dependency can no longer be met.
        against = set()
        if not self.success:
            against = completed
        if failed is not None and not self.failure:
            against = against.union(failed)
        if self.all:
            return not self.isdisjoint(against)
        else:
            return self.issubset(against)


    def as_dict(self):
        """Represent this dependency as a dict. For json compatibility."""
        return dict(
            dependencies=list(self),
            all=self.all,
            success=self.success,
            failure=self.failure
        )
223 223
224 224
225 225 __all__ = ['depend', 'require', 'dependent', 'Dependency']
226 226
@@ -1,860 +1,860 b''
1 1 """The Python scheduler for rich scheduling.
2 2
3 3 The Pure ZMQ scheduler does not allow routing schemes other than LRU,
4 4 nor does it check msg_id DAG dependencies. For those, a slightly slower
5 5 Python Scheduler exists.
6 6
7 7 Authors:
8 8
9 9 * Min RK
10 10 """
11 11 #-----------------------------------------------------------------------------
12 12 # Copyright (C) 2010-2011 The IPython Development Team
13 13 #
14 14 # Distributed under the terms of the BSD License. The full license is in
15 15 # the file COPYING, distributed as part of this software.
16 16 #-----------------------------------------------------------------------------
17 17
18 18 #----------------------------------------------------------------------
19 19 # Imports
20 20 #----------------------------------------------------------------------
21 21
22 22 import logging
23 23 import sys
24 24 import time
25 25
26 26 from collections import deque
27 27 from datetime import datetime
28 28 from random import randint, random
29 29 from types import FunctionType
30 30
31 31 try:
32 32 import numpy
33 33 except ImportError:
34 34 numpy = None
35 35
36 36 import zmq
37 37 from zmq.eventloop import ioloop, zmqstream
38 38
39 39 # local imports
40 40 from IPython.external.decorator import decorator
41 41 from IPython.config.application import Application
42 42 from IPython.config.loader import Config
43 43 from IPython.utils.traitlets import Instance, Dict, List, Set, Integer, Enum, CBytes
44 44 from IPython.utils.py3compat import cast_bytes
45 45
46 46 from IPython.parallel import error, util
47 47 from IPython.parallel.factory import SessionFactory
48 48 from IPython.parallel.util import connect_logger, local_logger
49 49
50 50 from .dependency import Dependency
51 51
52 52 @decorator
53 53 def logged(f,self,*args,**kwargs):
54 54 # print ("#--------------------")
55 self.log.debug("scheduler::%s(*%s,**%s)", f.func_name, args, kwargs)
55 self.log.debug("scheduler::%s(*%s,**%s)", f.__name__, args, kwargs)
56 56 # print ("#--")
57 57 return f(self,*args, **kwargs)
58 58
59 59 #----------------------------------------------------------------------
60 60 # Chooser functions
61 61 #----------------------------------------------------------------------
62 62
63 63 def plainrandom(loads):
64 64 """Plain random pick."""
65 65 n = len(loads)
66 66 return randint(0,n-1)
67 67
68 68 def lru(loads):
69 69 """Always pick the front of the line.
70 70
71 71 The content of `loads` is ignored.
72 72
73 73 Assumes LRU ordering of loads, with oldest first.
74 74 """
75 75 return 0
76 76
77 77 def twobin(loads):
78 78 """Pick two at random, use the LRU of the two.
79 79
80 80 The content of loads is ignored.
81 81
82 82 Assumes LRU ordering of loads, with oldest first.
83 83 """
84 84 n = len(loads)
85 85 a = randint(0,n-1)
86 86 b = randint(0,n-1)
87 87 return min(a,b)
88 88
89 89 def weighted(loads):
90 90 """Pick two at random using inverse load as weight.
91 91
92 92 Return the less loaded of the two.
93 93 """
94 94 # weight 0 a million times more than 1:
95 95 weights = 1./(1e-6+numpy.array(loads))
96 96 sums = weights.cumsum()
97 97 t = sums[-1]
98 98 x = random()*t
99 99 y = random()*t
100 100 idx = 0
101 101 idy = 0
102 102 while sums[idx] < x:
103 103 idx += 1
104 104 while sums[idy] < y:
105 105 idy += 1
106 106 if weights[idy] > weights[idx]:
107 107 return idy
108 108 else:
109 109 return idx
110 110
111 111 def leastload(loads):
112 112 """Always choose the lowest load.
113 113
114 114 If the lowest load occurs more than once, the first
 115 115 occurrence will be used. If loads has LRU ordering, this means
116 116 the LRU of those with the lowest load is chosen.
117 117 """
118 118 return loads.index(min(loads))
119 119
120 120 #---------------------------------------------------------------------
121 121 # Classes
122 122 #---------------------------------------------------------------------
123 123
124 124
125 125 # store empty default dependency:
126 126 MET = Dependency([])
127 127
128 128
129 129 class Job(object):
130 130 """Simple container for a job"""
131 131 def __init__(self, msg_id, raw_msg, idents, msg, header, metadata,
132 132 targets, after, follow, timeout):
133 133 self.msg_id = msg_id
134 134 self.raw_msg = raw_msg
135 135 self.idents = idents
136 136 self.msg = msg
137 137 self.header = header
138 138 self.metadata = metadata
139 139 self.targets = targets
140 140 self.after = after
141 141 self.follow = follow
142 142 self.timeout = timeout
143 143
144 144 self.removed = False # used for lazy-delete from sorted queue
145 145 self.timestamp = time.time()
146 146 self.timeout_id = 0
147 147 self.blacklist = set()
148 148
149 149 def __lt__(self, other):
150 150 return self.timestamp < other.timestamp
151 151
152 152 def __cmp__(self, other):
153 153 return cmp(self.timestamp, other.timestamp)
154 154
155 155 @property
156 156 def dependents(self):
157 157 return self.follow.union(self.after)
158 158
159 159
160 160 class TaskScheduler(SessionFactory):
161 161 """Python TaskScheduler object.
162 162
163 163 This is the simplest object that supports msg_id based
164 164 DAG dependencies. *Only* task msg_ids are checked, not
165 165 msg_ids of jobs submitted via the MUX queue.
166 166
167 167 """
168 168
169 169 hwm = Integer(1, config=True,
170 170 help="""specify the High Water Mark (HWM) for the downstream
171 171 socket in the Task scheduler. This is the maximum number
172 172 of allowed outstanding tasks on each engine.
173 173
174 174 The default (1) means that only one task can be outstanding on each
175 175 engine. Setting TaskScheduler.hwm=0 means there is no limit, and the
176 176 engines continue to be assigned tasks while they are working,
177 177 effectively hiding network latency behind computation, but can result
 178 178 in an imbalance of work when submitting many heterogeneous tasks all at
179 179 once. Any positive value greater than one is a compromise between the
180 180 two.
181 181
182 182 """
183 183 )
184 184 scheme_name = Enum(('leastload', 'pure', 'lru', 'plainrandom', 'weighted', 'twobin'),
185 185 'leastload', config=True, allow_none=False,
186 186 help="""select the task scheduler scheme [default: Python LRU]
187 187 Options are: 'pure', 'lru', 'plainrandom', 'weighted', 'twobin','leastload'"""
188 188 )
189 189 def _scheme_name_changed(self, old, new):
190 190 self.log.debug("Using scheme %r"%new)
191 191 self.scheme = globals()[new]
192 192
193 193 # input arguments:
194 194 scheme = Instance(FunctionType) # function for determining the destination
195 195 def _scheme_default(self):
196 196 return leastload
197 197 client_stream = Instance(zmqstream.ZMQStream) # client-facing stream
198 198 engine_stream = Instance(zmqstream.ZMQStream) # engine-facing stream
199 199 notifier_stream = Instance(zmqstream.ZMQStream) # hub-facing sub stream
200 200 mon_stream = Instance(zmqstream.ZMQStream) # hub-facing pub stream
201 201 query_stream = Instance(zmqstream.ZMQStream) # hub-facing DEALER stream
202 202
203 203 # internals:
204 204 queue = Instance(deque) # sorted list of Jobs
205 205 def _queue_default(self):
206 206 return deque()
207 207 queue_map = Dict() # dict by msg_id of Jobs (for O(1) access to the Queue)
208 208 graph = Dict() # dict by msg_id of [ msg_ids that depend on key ]
209 209 retries = Dict() # dict by msg_id of retries remaining (non-neg ints)
210 210 # waiting = List() # list of msg_ids ready to run, but haven't due to HWM
211 211 pending = Dict() # dict by engine_uuid of submitted tasks
212 212 completed = Dict() # dict by engine_uuid of completed tasks
213 213 failed = Dict() # dict by engine_uuid of failed tasks
214 214 destinations = Dict() # dict by msg_id of engine_uuids where jobs ran (reverse of completed+failed)
215 215 clients = Dict() # dict by msg_id for who submitted the task
216 216 targets = List() # list of target IDENTs
217 217 loads = List() # list of engine loads
218 218 # full = Set() # set of IDENTs that have HWM outstanding tasks
219 219 all_completed = Set() # set of all completed tasks
220 220 all_failed = Set() # set of all failed tasks
221 221 all_done = Set() # set of all finished tasks=union(completed,failed)
222 222 all_ids = Set() # set of all submitted task IDs
223 223
224 224 ident = CBytes() # ZMQ identity. This should just be self.session.session
225 225 # but ensure Bytes
226 226 def _ident_default(self):
227 227 return self.session.bsession
228 228
229 229 def start(self):
230 230 self.query_stream.on_recv(self.dispatch_query_reply)
231 231 self.session.send(self.query_stream, "connection_request", {})
232 232
233 233 self.engine_stream.on_recv(self.dispatch_result, copy=False)
234 234 self.client_stream.on_recv(self.dispatch_submission, copy=False)
235 235
236 236 self._notification_handlers = dict(
237 237 registration_notification = self._register_engine,
238 238 unregistration_notification = self._unregister_engine
239 239 )
240 240 self.notifier_stream.on_recv(self.dispatch_notification)
241 241 self.log.info("Scheduler started [%s]" % self.scheme_name)
242 242
243 243 def resume_receiving(self):
244 244 """Resume accepting jobs."""
245 245 self.client_stream.on_recv(self.dispatch_submission, copy=False)
246 246
247 247 def stop_receiving(self):
248 248 """Stop accepting jobs while there are no engines.
249 249 Leave them in the ZMQ queue."""
250 250 self.client_stream.on_recv(None)
251 251
252 252 #-----------------------------------------------------------------------
253 253 # [Un]Registration Handling
254 254 #-----------------------------------------------------------------------
255 255
256 256
257 257 def dispatch_query_reply(self, msg):
258 258 """handle reply to our initial connection request"""
259 259 try:
260 260 idents,msg = self.session.feed_identities(msg)
261 261 except ValueError:
262 262 self.log.warn("task::Invalid Message: %r",msg)
263 263 return
264 264 try:
265 265 msg = self.session.unserialize(msg)
266 266 except ValueError:
267 267 self.log.warn("task::Unauthorized message from: %r"%idents)
268 268 return
269 269
270 270 content = msg['content']
271 271 for uuid in content.get('engines', {}).values():
272 272 self._register_engine(cast_bytes(uuid))
273 273
274 274
275 275 @util.log_errors
276 276 def dispatch_notification(self, msg):
277 277 """dispatch register/unregister events."""
278 278 try:
279 279 idents,msg = self.session.feed_identities(msg)
280 280 except ValueError:
281 281 self.log.warn("task::Invalid Message: %r",msg)
282 282 return
283 283 try:
284 284 msg = self.session.unserialize(msg)
285 285 except ValueError:
286 286 self.log.warn("task::Unauthorized message from: %r"%idents)
287 287 return
288 288
289 289 msg_type = msg['header']['msg_type']
290 290
291 291 handler = self._notification_handlers.get(msg_type, None)
292 292 if handler is None:
293 293 self.log.error("Unhandled message type: %r"%msg_type)
294 294 else:
295 295 try:
296 296 handler(cast_bytes(msg['content']['uuid']))
297 297 except Exception:
298 298 self.log.error("task::Invalid notification msg: %r", msg, exc_info=True)
299 299
300 300 def _register_engine(self, uid):
301 301 """New engine with ident `uid` became available."""
302 302 # head of the line:
303 303 self.targets.insert(0,uid)
304 304 self.loads.insert(0,0)
305 305
306 306 # initialize sets
307 307 self.completed[uid] = set()
308 308 self.failed[uid] = set()
309 309 self.pending[uid] = {}
310 310
311 311 # rescan the graph:
312 312 self.update_graph(None)
313 313
314 314 def _unregister_engine(self, uid):
315 315 """Existing engine with ident `uid` became unavailable."""
316 316 if len(self.targets) == 1:
317 317 # this was our only engine
318 318 pass
319 319
320 320 # handle any potentially finished tasks:
321 321 self.engine_stream.flush()
322 322
323 323 # don't pop destinations, because they might be used later
324 324 # map(self.destinations.pop, self.completed.pop(uid))
325 325 # map(self.destinations.pop, self.failed.pop(uid))
326 326
327 327 # prevent this engine from receiving work
328 328 idx = self.targets.index(uid)
329 329 self.targets.pop(idx)
330 330 self.loads.pop(idx)
331 331
332 332 # wait 5 seconds before cleaning up pending jobs, since the results might
333 333 # still be incoming
334 334 if self.pending[uid]:
335 335 dc = ioloop.DelayedCallback(lambda : self.handle_stranded_tasks(uid), 5000, self.loop)
336 336 dc.start()
337 337 else:
338 338 self.completed.pop(uid)
339 339 self.failed.pop(uid)
340 340
341 341
342 342 def handle_stranded_tasks(self, engine):
343 343 """Deal with jobs resident in an engine that died."""
344 344 lost = self.pending[engine]
345 345 for msg_id in lost.keys():
346 346 if msg_id not in self.pending[engine]:
347 347 # prevent double-handling of messages
348 348 continue
349 349
350 350 raw_msg = lost[msg_id].raw_msg
351 351 idents,msg = self.session.feed_identities(raw_msg, copy=False)
352 352 parent = self.session.unpack(msg[1].bytes)
353 353 idents = [engine, idents[0]]
354 354
355 355 # build fake error reply
356 356 try:
357 357 raise error.EngineError("Engine %r died while running task %r"%(engine, msg_id))
358 358 except:
359 359 content = error.wrap_exception()
360 360 # build fake metadata
361 361 md = dict(
362 362 status=u'error',
363 363 engine=engine.decode('ascii'),
364 364 date=datetime.now(),
365 365 )
366 366 msg = self.session.msg('apply_reply', content, parent=parent, metadata=md)
367 367 raw_reply = map(zmq.Message, self.session.serialize(msg, ident=idents))
368 368 # and dispatch it
369 369 self.dispatch_result(raw_reply)
370 370
371 371 # finally scrub completed/failed lists
372 372 self.completed.pop(engine)
373 373 self.failed.pop(engine)
374 374
375 375
376 376 #-----------------------------------------------------------------------
377 377 # Job Submission
378 378 #-----------------------------------------------------------------------
379 379
380 380
381 381 @util.log_errors
382 382 def dispatch_submission(self, raw_msg):
383 383 """Dispatch job submission to appropriate handlers."""
384 384 # ensure targets up to date:
385 385 self.notifier_stream.flush()
386 386 try:
387 387 idents, msg = self.session.feed_identities(raw_msg, copy=False)
388 388 msg = self.session.unserialize(msg, content=False, copy=False)
389 389 except Exception:
390 390 self.log.error("task::Invaid task msg: %r"%raw_msg, exc_info=True)
391 391 return
392 392
393 393
394 394 # send to monitor
395 395 self.mon_stream.send_multipart([b'intask']+raw_msg, copy=False)
396 396
397 397 header = msg['header']
398 398 md = msg['metadata']
399 399 msg_id = header['msg_id']
400 400 self.all_ids.add(msg_id)
401 401
402 402 # get targets as a set of bytes objects
403 403 # from a list of unicode objects
404 404 targets = md.get('targets', [])
405 405 targets = map(cast_bytes, targets)
406 406 targets = set(targets)
407 407
408 408 retries = md.get('retries', 0)
409 409 self.retries[msg_id] = retries
410 410
411 411 # time dependencies
412 412 after = md.get('after', None)
413 413 if after:
414 414 after = Dependency(after)
415 415 if after.all:
416 416 if after.success:
417 417 after = Dependency(after.difference(self.all_completed),
418 418 success=after.success,
419 419 failure=after.failure,
420 420 all=after.all,
421 421 )
422 422 if after.failure:
423 423 after = Dependency(after.difference(self.all_failed),
424 424 success=after.success,
425 425 failure=after.failure,
426 426 all=after.all,
427 427 )
428 428 if after.check(self.all_completed, self.all_failed):
429 429 # recast as empty set, if `after` already met,
430 430 # to prevent unnecessary set comparisons
431 431 after = MET
432 432 else:
433 433 after = MET
434 434
435 435 # location dependencies
436 436 follow = Dependency(md.get('follow', []))
437 437
438 438 timeout = md.get('timeout', None)
439 439 if timeout:
440 440 timeout = float(timeout)
441 441
442 442 job = Job(msg_id=msg_id, raw_msg=raw_msg, idents=idents, msg=msg,
443 443 header=header, targets=targets, after=after, follow=follow,
444 444 timeout=timeout, metadata=md,
445 445 )
446 446 # validate and reduce dependencies:
447 447 for dep in after,follow:
448 448 if not dep: # empty dependency
449 449 continue
450 450 # check valid:
451 451 if msg_id in dep or dep.difference(self.all_ids):
452 452 self.queue_map[msg_id] = job
453 453 return self.fail_unreachable(msg_id, error.InvalidDependency)
454 454 # check if unreachable:
455 455 if dep.unreachable(self.all_completed, self.all_failed):
456 456 self.queue_map[msg_id] = job
457 457 return self.fail_unreachable(msg_id)
458 458
459 459 if after.check(self.all_completed, self.all_failed):
460 460 # time deps already met, try to run
461 461 if not self.maybe_run(job):
462 462 # can't run yet
463 463 if msg_id not in self.all_failed:
464 464 # could have failed as unreachable
465 465 self.save_unmet(job)
466 466 else:
467 467 self.save_unmet(job)
468 468
469 469 def job_timeout(self, job, timeout_id):
470 470 """callback for a job's timeout.
471 471
472 472 The job may or may not have been run at this point.
473 473 """
474 474 if job.timeout_id != timeout_id:
475 475 # not the most recent call
476 476 return
477 477 now = time.time()
478 478 if job.timeout >= (now + 1):
479 479 self.log.warn("task %s timeout fired prematurely: %s > %s",
480 480 job.msg_id, job.timeout, now
481 481 )
482 482 if job.msg_id in self.queue_map:
483 483 # still waiting, but ran out of time
484 484 self.log.info("task %r timed out", job.msg_id)
485 485 self.fail_unreachable(job.msg_id, error.TaskTimeout)
486 486
487 487 def fail_unreachable(self, msg_id, why=error.ImpossibleDependency):
488 488 """a task has become unreachable, send a reply with an ImpossibleDependency
489 489 error."""
490 490 if msg_id not in self.queue_map:
491 491 self.log.error("task %r already failed!", msg_id)
492 492 return
493 493 job = self.queue_map.pop(msg_id)
494 494 # lazy-delete from the queue
495 495 job.removed = True
496 496 for mid in job.dependents:
497 497 if mid in self.graph:
498 498 self.graph[mid].remove(msg_id)
499 499
500 500 try:
501 501 raise why()
502 502 except:
503 503 content = error.wrap_exception()
504 504 self.log.debug("task %r failing as unreachable with: %s", msg_id, content['ename'])
505 505
506 506 self.all_done.add(msg_id)
507 507 self.all_failed.add(msg_id)
508 508
509 509 msg = self.session.send(self.client_stream, 'apply_reply', content,
510 510 parent=job.header, ident=job.idents)
511 511 self.session.send(self.mon_stream, msg, ident=[b'outtask']+job.idents)
512 512
513 513 self.update_graph(msg_id, success=False)
514 514
515 515 def available_engines(self):
516 516 """return a list of available engine indices based on HWM"""
517 517 if not self.hwm:
518 518 return range(len(self.targets))
519 519 available = []
520 520 for idx in range(len(self.targets)):
521 521 if self.loads[idx] < self.hwm:
522 522 available.append(idx)
523 523 return available
524 524
525 525 def maybe_run(self, job):
526 526 """check location dependencies, and run if they are met."""
527 527 msg_id = job.msg_id
528 528 self.log.debug("Attempting to assign task %s", msg_id)
529 529 available = self.available_engines()
530 530 if not available:
531 531 # no engines, definitely can't run
532 532 return False
533 533
534 534 if job.follow or job.targets or job.blacklist or self.hwm:
535 535 # we need a can_run filter
536 536 def can_run(idx):
537 537 # check hwm
538 538 if self.hwm and self.loads[idx] == self.hwm:
539 539 return False
540 540 target = self.targets[idx]
541 541 # check blacklist
542 542 if target in job.blacklist:
543 543 return False
544 544 # check targets
545 545 if job.targets and target not in job.targets:
546 546 return False
547 547 # check follow
548 548 return job.follow.check(self.completed[target], self.failed[target])
549 549
550 550 indices = filter(can_run, available)
551 551
552 552 if not indices:
553 553 # couldn't run
554 554 if job.follow.all:
555 555 # check follow for impossibility
556 556 dests = set()
557 557 relevant = set()
558 558 if job.follow.success:
559 559 relevant = self.all_completed
560 560 if job.follow.failure:
561 561 relevant = relevant.union(self.all_failed)
562 562 for m in job.follow.intersection(relevant):
563 563 dests.add(self.destinations[m])
564 564 if len(dests) > 1:
565 565 self.queue_map[msg_id] = job
566 566 self.fail_unreachable(msg_id)
567 567 return False
568 568 if job.targets:
569 569 # check blacklist+targets for impossibility
570 570 job.targets.difference_update(job.blacklist)
571 571 if not job.targets or not job.targets.intersection(self.targets):
572 572 self.queue_map[msg_id] = job
573 573 self.fail_unreachable(msg_id)
574 574 return False
575 575 return False
576 576 else:
577 577 indices = None
578 578
579 579 self.submit_task(job, indices)
580 580 return True
581 581
582 582 def save_unmet(self, job):
583 583 """Save a message for later submission when its dependencies are met."""
584 584 msg_id = job.msg_id
585 585 self.log.debug("Adding task %s to the queue", msg_id)
586 586 self.queue_map[msg_id] = job
587 587 self.queue.append(job)
588 588 # track the ids in follow or after, but not those already finished
589 589 for dep_id in job.after.union(job.follow).difference(self.all_done):
590 590 if dep_id not in self.graph:
591 591 self.graph[dep_id] = set()
592 592 self.graph[dep_id].add(msg_id)
593 593
594 594 # schedule timeout callback
595 595 if job.timeout:
596 596 timeout_id = job.timeout_id = job.timeout_id + 1
597 597 self.loop.add_timeout(time.time() + job.timeout,
598 598 lambda : self.job_timeout(job, timeout_id)
599 599 )
600 600
601 601
602 602 def submit_task(self, job, indices=None):
603 603 """Submit a task to any of a subset of our targets."""
604 604 if indices:
605 605 loads = [self.loads[i] for i in indices]
606 606 else:
607 607 loads = self.loads
608 608 idx = self.scheme(loads)
609 609 if indices:
610 610 idx = indices[idx]
611 611 target = self.targets[idx]
612 612 # print (target, map(str, msg[:3]))
613 613 # send job to the engine
614 614 self.engine_stream.send(target, flags=zmq.SNDMORE, copy=False)
615 615 self.engine_stream.send_multipart(job.raw_msg, copy=False)
616 616 # update load
617 617 self.add_job(idx)
618 618 self.pending[target][job.msg_id] = job
619 619 # notify Hub
620 620 content = dict(msg_id=job.msg_id, engine_id=target.decode('ascii'))
621 621 self.session.send(self.mon_stream, 'task_destination', content=content,
622 622 ident=[b'tracktask',self.ident])
623 623
624 624
625 625 #-----------------------------------------------------------------------
626 626 # Result Handling
627 627 #-----------------------------------------------------------------------
628 628
629 629
630 630 @util.log_errors
631 631 def dispatch_result(self, raw_msg):
632 632 """dispatch method for result replies"""
633 633 try:
634 634 idents,msg = self.session.feed_identities(raw_msg, copy=False)
635 635 msg = self.session.unserialize(msg, content=False, copy=False)
636 636 engine = idents[0]
637 637 try:
638 638 idx = self.targets.index(engine)
639 639 except ValueError:
640 640 pass # skip load-update for dead engines
641 641 else:
642 642 self.finish_job(idx)
643 643 except Exception:
644 644 self.log.error("task::Invalid result: %r", raw_msg, exc_info=True)
645 645 return
646 646
647 647 md = msg['metadata']
648 648 parent = msg['parent_header']
649 649 if md.get('dependencies_met', True):
650 650 success = (md['status'] == 'ok')
651 651 msg_id = parent['msg_id']
652 652 retries = self.retries[msg_id]
653 653 if not success and retries > 0:
654 654 # failed
655 655 self.retries[msg_id] = retries - 1
656 656 self.handle_unmet_dependency(idents, parent)
657 657 else:
658 658 del self.retries[msg_id]
659 659 # relay to client and update graph
660 660 self.handle_result(idents, parent, raw_msg, success)
661 661 # send to Hub monitor
662 662 self.mon_stream.send_multipart([b'outtask']+raw_msg, copy=False)
663 663 else:
664 664 self.handle_unmet_dependency(idents, parent)
665 665
666 666 def handle_result(self, idents, parent, raw_msg, success=True):
667 667 """handle a real task result, either success or failure"""
668 668 # first, relay result to client
669 669 engine = idents[0]
670 670 client = idents[1]
671 671 # swap_ids for ROUTER-ROUTER mirror
672 672 raw_msg[:2] = [client,engine]
673 673 # print (map(str, raw_msg[:4]))
674 674 self.client_stream.send_multipart(raw_msg, copy=False)
675 675 # now, update our data structures
676 676 msg_id = parent['msg_id']
677 677 self.pending[engine].pop(msg_id)
678 678 if success:
679 679 self.completed[engine].add(msg_id)
680 680 self.all_completed.add(msg_id)
681 681 else:
682 682 self.failed[engine].add(msg_id)
683 683 self.all_failed.add(msg_id)
684 684 self.all_done.add(msg_id)
685 685 self.destinations[msg_id] = engine
686 686
687 687 self.update_graph(msg_id, success)
688 688
689 689 def handle_unmet_dependency(self, idents, parent):
690 690 """handle an unmet dependency"""
691 691 engine = idents[0]
692 692 msg_id = parent['msg_id']
693 693
694 694 job = self.pending[engine].pop(msg_id)
695 695 job.blacklist.add(engine)
696 696
697 697 if job.blacklist == job.targets:
698 698 self.queue_map[msg_id] = job
699 699 self.fail_unreachable(msg_id)
700 700 elif not self.maybe_run(job):
701 701 # resubmit failed
702 702 if msg_id not in self.all_failed:
703 703 # put it back in our dependency tree
704 704 self.save_unmet(job)
705 705
706 706 if self.hwm:
707 707 try:
708 708 idx = self.targets.index(engine)
709 709 except ValueError:
710 710 pass # skip load-update for dead engines
711 711 else:
712 712 if self.loads[idx] == self.hwm-1:
713 713 self.update_graph(None)
714 714
715 715 def update_graph(self, dep_id=None, success=True):
716 716 """dep_id just finished. Update our dependency
717 717 graph and submit any jobs that just became runnable.
718 718
719 719 Called with dep_id=None to update entire graph for hwm, but without finishing a task.
720 720 """
721 721 # print ("\n\n***********")
722 722 # pprint (dep_id)
723 723 # pprint (self.graph)
724 724 # pprint (self.queue_map)
725 725 # pprint (self.all_completed)
726 726 # pprint (self.all_failed)
727 727 # print ("\n\n***********\n\n")
728 728 # update any jobs that depended on the dependency
729 729 msg_ids = self.graph.pop(dep_id, [])
730 730
731 731 # recheck *all* jobs if
732 732 # a) we have HWM and an engine just become no longer full
733 733 # or b) dep_id was given as None
734 734
735 735 if dep_id is None or self.hwm and any( [ load==self.hwm-1 for load in self.loads ]):
736 736 jobs = self.queue
737 737 using_queue = True
738 738 else:
739 739 using_queue = False
740 740 jobs = deque(sorted( self.queue_map[msg_id] for msg_id in msg_ids ))
741 741
742 742 to_restore = []
743 743 while jobs:
744 744 job = jobs.popleft()
745 745 if job.removed:
746 746 continue
747 747 msg_id = job.msg_id
748 748
749 749 put_it_back = True
750 750
751 751 if job.after.unreachable(self.all_completed, self.all_failed)\
752 752 or job.follow.unreachable(self.all_completed, self.all_failed):
753 753 self.fail_unreachable(msg_id)
754 754 put_it_back = False
755 755
756 756 elif job.after.check(self.all_completed, self.all_failed): # time deps met, maybe run
757 757 if self.maybe_run(job):
758 758 put_it_back = False
759 759 self.queue_map.pop(msg_id)
760 760 for mid in job.dependents:
761 761 if mid in self.graph:
762 762 self.graph[mid].remove(msg_id)
763 763
764 764 # abort the loop if we just filled up all of our engines.
765 765 # avoids an O(N) operation in situation of full queue,
766 766 # where graph update is triggered as soon as an engine becomes
767 767 # non-full, and all tasks after the first are checked,
768 768 # even though they can't run.
769 769 if not self.available_engines():
770 770 break
771 771
772 772 if using_queue and put_it_back:
773 773 # popped a job from the queue but it neither ran nor failed,
774 774 # so we need to put it back when we are done
775 775 # make sure to_restore preserves the same ordering
776 776 to_restore.append(job)
777 777
778 778 # put back any tasks we popped but didn't run
779 779 if using_queue:
780 780 self.queue.extendleft(to_restore)
781 781
782 782 #----------------------------------------------------------------------
783 783 # methods to be overridden by subclasses
784 784 #----------------------------------------------------------------------
785 785
786 786 def add_job(self, idx):
787 787 """Called after self.targets[idx] just got the job with header.
788 788 Override with subclasses. The default ordering is simple LRU.
789 789 The default loads are the number of outstanding jobs."""
790 790 self.loads[idx] += 1
791 791 for lis in (self.targets, self.loads):
792 792 lis.append(lis.pop(idx))
793 793
794 794
795 795 def finish_job(self, idx):
796 796 """Called after self.targets[idx] just finished a job.
797 797 Override with subclasses."""
798 798 self.loads[idx] -= 1
799 799
800 800
801 801
802 802 def launch_scheduler(in_addr, out_addr, mon_addr, not_addr, reg_addr, config=None,
803 803 logname='root', log_url=None, loglevel=logging.DEBUG,
804 804 identity=b'task', in_thread=False):
805 805
806 806 ZMQStream = zmqstream.ZMQStream
807 807
808 808 if config:
809 809 # unwrap dict back into Config
810 810 config = Config(config)
811 811
812 812 if in_thread:
813 813 # use instance() to get the same Context/Loop as our parent
814 814 ctx = zmq.Context.instance()
815 815 loop = ioloop.IOLoop.instance()
816 816 else:
817 817 # in a process, don't use instance()
818 818 # for safety with multiprocessing
819 819 ctx = zmq.Context()
820 820 loop = ioloop.IOLoop()
821 821 ins = ZMQStream(ctx.socket(zmq.ROUTER),loop)
822 822 util.set_hwm(ins, 0)
823 823 ins.setsockopt(zmq.IDENTITY, identity + b'_in')
824 824 ins.bind(in_addr)
825 825
826 826 outs = ZMQStream(ctx.socket(zmq.ROUTER),loop)
827 827 util.set_hwm(outs, 0)
828 828 outs.setsockopt(zmq.IDENTITY, identity + b'_out')
829 829 outs.bind(out_addr)
830 830 mons = zmqstream.ZMQStream(ctx.socket(zmq.PUB),loop)
831 831 util.set_hwm(mons, 0)
832 832 mons.connect(mon_addr)
833 833 nots = zmqstream.ZMQStream(ctx.socket(zmq.SUB),loop)
834 834 nots.setsockopt(zmq.SUBSCRIBE, b'')
835 835 nots.connect(not_addr)
836 836
837 837 querys = ZMQStream(ctx.socket(zmq.DEALER),loop)
838 838 querys.connect(reg_addr)
839 839
840 840 # setup logging.
841 841 if in_thread:
842 842 log = Application.instance().log
843 843 else:
844 844 if log_url:
845 845 log = connect_logger(logname, ctx, log_url, root="scheduler", loglevel=loglevel)
846 846 else:
847 847 log = local_logger(logname, loglevel)
848 848
849 849 scheduler = TaskScheduler(client_stream=ins, engine_stream=outs,
850 850 mon_stream=mons, notifier_stream=nots,
851 851 query_stream=querys,
852 852 loop=loop, log=log,
853 853 config=config)
854 854 scheduler.start()
855 855 if not in_thread:
856 856 try:
857 857 loop.start()
858 858 except KeyboardInterrupt:
859 859 scheduler.log.critical("Interrupted, exiting...")
860 860
@@ -1,760 +1,760 b''
1 1 """Nose Plugin that supports IPython doctests.
2 2
3 3 Limitations:
4 4
5 5 - When generating examples for use as doctests, make sure that you have
6 6 pretty-printing OFF. This can be done either by setting the
7 7 ``PlainTextFormatter.pprint`` option in your configuration file to False, or
8 8 by interactively disabling it with %Pprint. This is required so that IPython
9 9 output matches that of normal Python, which is used by doctest for internal
10 10 execution.
11 11
12 12 - Do not rely on specific prompt numbers for results (such as using
13 13 '_34==True', for example). For IPython tests run via an external process the
14 14 prompt numbers may be different, and IPython tests run as normal python code
15 15 won't even have these special _NN variables set at all.
16 16 """
17 17
18 18 #-----------------------------------------------------------------------------
19 19 # Module imports
20 20
21 21 # From the standard library
22 22 import doctest
23 23 import inspect
24 24 import logging
25 25 import os
26 26 import re
27 27 import sys
28 28 import traceback
29 29 import unittest
30 30
31 31 from inspect import getmodule
32 32 from io import StringIO
33 33
34 34 # We are overriding the default doctest runner, so we need to import a few
35 35 # things from doctest directly
36 36 from doctest import (REPORTING_FLAGS, REPORT_ONLY_FIRST_FAILURE,
37 37 _unittest_reportflags, DocTestRunner,
38 38 _extract_future_flags, pdb, _OutputRedirectingPdb,
39 39 _exception_traceback,
40 40 linecache)
41 41
42 42 # Third-party modules
43 43 import nose.core
44 44
45 45 from nose.plugins import doctests, Plugin
46 46 from nose.util import anyp, getpackage, test_address, resolve_name, tolist
47 47
48 48 # Our own imports
49 49 from IPython.utils.py3compat import builtin_mod
50 50
51 51 #-----------------------------------------------------------------------------
52 52 # Module globals and other constants
53 53 #-----------------------------------------------------------------------------
54 54
55 55 log = logging.getLogger(__name__)
56 56
57 57
58 58 #-----------------------------------------------------------------------------
59 59 # Classes and functions
60 60 #-----------------------------------------------------------------------------
61 61
def is_extension_module(filename):
    """Tell whether *filename* names a compiled extension module.

    Only the file extension is examined: ``.so`` (Unix) and ``.pyd``
    (Windows) are the two recognized forms.
    """
    ext = os.path.splitext(filename)[1]
    return ext.lower() in ('.so','.pyd')
68 68
69 69
class DocTestSkip(object):
    """Proxy wrapper that makes the wrapped object's doctest be skipped.

    All attribute access is forwarded to the wrapped object, except for
    ``__doc__``, which is replaced by a docstring whose single example
    carries the ``+SKIP`` doctest option flag.
    """

    ds_skip = """Doctest to skip.
    >>> 1 #doctest: +SKIP
    """

    def __init__(self, obj):
        self.obj = obj

    def __getattribute__(self, key):
        # Only the docstring is intercepted; everything else comes from
        # the wrapped object.  The wrapped object itself must be fetched
        # via object.__getattribute__ to avoid recursing into this method.
        if key == '__doc__':
            return DocTestSkip.ds_skip
        wrapped = object.__getattribute__(self, 'obj')
        return getattr(wrapped, key)
85 85
86 86 # Modified version of the one in the stdlib, that fixes a python bug (doctests
87 87 # not found in extension modules, http://bugs.python.org/issue3158)
class DocTestFinder(doctest.DocTestFinder):
    """Doctest finder that also descends into extension modules.

    The stdlib finder fails to locate doctests defined in extension
    modules (http://bugs.python.org/issue3158); ``_find`` re-runs the
    relevant discovery steps with that fixed.
    """

    def _from_module(self, module, object):
        """
        Return true if the given object is defined in the given
        module.
        """
        if module is None:
            return True
        elif inspect.isfunction(object):
            return module.__dict__ is object.__globals__
        elif inspect.isbuiltin(object):
            return module.__name__ == object.__module__
        elif inspect.isclass(object):
            return module.__name__ == object.__module__
        elif inspect.ismethod(object):
            # This one may be a bug in cython that fails to correctly set the
            # __module__ attribute of methods, but since the same error is easy
            # to make by extension code writers, having this safety in place
            # isn't such a bad idea.
            # __self__.__class__ replaces the Python-2-only im_class,
            # matching the __globals__ rename above.
            return module.__name__ == object.__self__.__class__.__module__
        elif inspect.getmodule(object) is not None:
            return module is inspect.getmodule(object)
        elif hasattr(object, '__module__'):
            return module.__name__ == object.__module__
        elif isinstance(object, property):
            return True # [XX] no way not be sure.
        else:
            raise ValueError("object must be a class or function")

    def _find(self, tests, obj, name, module, source_lines, globs, seen):
        """
        Find tests for the given object and any contained objects, and
        add them to `tests`.
        """
        #print '_find for:', obj, name, module  # dbg
        if hasattr(obj,"skip_doctest"):
            #print 'SKIPPING DOCTEST FOR:',obj  # dbg
            obj = DocTestSkip(obj)

        doctest.DocTestFinder._find(self,tests, obj, name, module,
                                    source_lines, globs, seen)

        # Below we re-run pieces of the above method with manual modifications,
        # because the original code is buggy and fails to correctly identify
        # doctests in extension modules.

        # Local shorthands
        from inspect import isroutine, isclass, ismodule

        # Look for tests in a module's contained objects.
        if inspect.ismodule(obj) and self._recurse:
            for valname, val in obj.__dict__.items():
                valname1 = '%s.%s' % (name, valname)
                if ( (isroutine(val) or isclass(val))
                     and self._from_module(module, val) ):

                    self._find(tests, val, valname1, module, source_lines,
                               globs, seen)

        # Look for tests in a class's contained objects.
        if inspect.isclass(obj) and self._recurse:
            #print 'RECURSE into class:',obj  # dbg
            for valname, val in obj.__dict__.items():
                # Special handling for staticmethod/classmethod.
                if isinstance(val, staticmethod):
                    val = getattr(obj, valname)
                if isinstance(val, classmethod):
                    # __func__ replaces the Python-2-only im_func.
                    val = getattr(obj, valname).__func__

                # Recurse to methods, properties, and nested classes.
                if ((inspect.isfunction(val) or inspect.isclass(val) or
                     inspect.ismethod(val) or
                     isinstance(val, property)) and
                      self._from_module(module, val)):
                    valname = '%s.%s' % (name, valname)
                    self._find(tests, val, valname, module, source_lines,
                               globs, seen)
166 166
167 167
class IPDoctestOutputChecker(doctest.OutputChecker):
    """Output checker that gives random-output tests a second chance.

    If the stock comparison fails, the expected output is scanned for a
    ``#random`` marker; when present, any actual output is accepted.
    """

    random_re = re.compile(r'#\s*random\s+')

    def check_output(self, want, got, optionflags):
        """Compare *got* against *want*, honoring the ``#random`` marker.

        The stock checker runs first, so a literal '#random' comment in
        otherwise-matching output never changes a passing result.
        """
        matched = doctest.OutputChecker.check_output(self, want, got,
                                                     optionflags)
        if matched:
            return matched
        # Default comparison failed: accept anything if '#random' appears
        # in the expected output.
        if self.random_re.search(want):
            #print >> sys.stderr, 'RANDOM OK:',want  # dbg
            return True
        return matched
192 192
193 193
class DocTestCase(doctests.DocTestCase):
    """Proxy for DocTestCase: provides an address() method that
    returns the correct address for the doctest case. Otherwise
    acts as a proxy to the test case. To provide hints for address(),
    an obj may also be passed -- this will be used as the test object
    for purposes of determining the test address, if it is provided.
    """

    # Note: this method was taken from numpy's nosetester module.

    # Subclass nose.plugins.doctests.DocTestCase to work around a bug in
    # its constructor that blocks non-default arguments from being passed
    # down into doctest.DocTestCase

    def __init__(self, test, optionflags=0, setUp=None, tearDown=None,
                 checker=None, obj=None, result_var='_'):
        self._result_var = result_var
        doctests.DocTestCase.__init__(self, test,
                                      optionflags=optionflags,
                                      setUp=setUp, tearDown=tearDown,
                                      checker=checker)
        # Now we must actually copy the original constructor from the stdlib
        # doctest class, because we can't call it directly and a bug in nose
        # means it never gets passed the right arguments.

        self._dt_optionflags = optionflags
        self._dt_checker = checker
        self._dt_test = test
        # Keep a reference to the original globs so tearDown can restore
        # them after setUp swaps in the live IPython namespace.
        self._dt_test_globs_ori = test.globs
        self._dt_setUp = setUp
        self._dt_tearDown = tearDown

        # XXX - store this runner once in the object!
        runner = IPDocTestRunner(optionflags=optionflags,
                                 checker=checker, verbose=False)
        self._dt_runner = runner


        # Each doctest should remember the directory it was loaded from, so
        # things like %run work without too many contortions
        self._ori_dir = os.path.dirname(test.filename)

    # Modified runTest from the default stdlib
    def runTest(self):
        """Run the doctest in its original directory, capturing stdout."""
        test = self._dt_test
        runner = self._dt_runner

        old = sys.stdout
        new = StringIO()
        optionflags = self._dt_optionflags

        if not (optionflags & REPORTING_FLAGS):
            # The option flags don't include any reporting flags,
            # so add the default reporting flags
            optionflags |= _unittest_reportflags

        try:
            # Save our current directory and switch out to the one where the
            # test was originally created, in case another doctest did a
            # directory change.  We'll restore this in the finally clause.
            # NOTE(review): os.getcwdu is Python-2-only; on Python 3 this
            # would need os.getcwd -- confirm which interpreters run this.
            curdir = os.getcwdu()
            #print 'runTest in dir:', self._ori_dir  # dbg
            os.chdir(self._ori_dir)

            runner.DIVIDER = "-"*70
            failures, tries = runner.run(test,out=new.write,
                                         clear_globs=False)
        finally:
            sys.stdout = old
            os.chdir(curdir)

        if failures:
            raise self.failureException(self.format_failure(new.getvalue()))

    def setUp(self):
        """Modified test setup that syncs with ipython namespace"""
        #print "setUp test", self._dt_test.examples # dbg
        if isinstance(self._dt_test.examples[0], IPExample):
            # for IPython examples *only*, we swap the globals with the ipython
            # namespace, after updating it with the globals (which doctest
            # fills with the necessary info from the module being tested).
            self.user_ns_orig = {}
            self.user_ns_orig.update(_ip.user_ns)
            _ip.user_ns.update(self._dt_test.globs)
            # We must remove the _ key in the namespace, so that Python's
            # doctest code sets it naturally
            _ip.user_ns.pop('_', None)
            _ip.user_ns['__builtins__'] = builtin_mod
            self._dt_test.globs = _ip.user_ns

        super(DocTestCase, self).setUp()

    def tearDown(self):
        """Restore the IPython namespace saved by setUp, then defer to nose."""

        # Undo the test.globs reassignment we made, so that the parent class
        # teardown doesn't destroy the ipython namespace
        if isinstance(self._dt_test.examples[0], IPExample):
            self._dt_test.globs = self._dt_test_globs_ori
            _ip.user_ns.clear()
            _ip.user_ns.update(self.user_ns_orig)

        # XXX - fperez: I am not sure if this is truly a bug in nose 0.11, but
        # it does look like one to me: its tearDown method tries to run
        #
        # delattr(builtin_mod, self._result_var)
        #
        # without checking that the attribute really is there; it implicitly
        # assumes it should have been set via displayhook.  But if the
        # displayhook was never called, this doesn't necessarily happen.  I
        # haven't been able to find a little self-contained example outside of
        # ipython that would show the problem so I can report it to the nose
        # team, but it does happen a lot in our code.
        #
        # So here, we just protect as narrowly as possible by trapping an
        # attribute error whose message would be the name of self._result_var,
        # and letting any other error propagate.
        try:
            super(DocTestCase, self).tearDown()
        except AttributeError as exc:
            if exc.args[0] != self._result_var:
                raise
315 315
316 316
317 317 # A simple subclassing of the original with a different class name, so we can
318 318 # distinguish and treat differently IPython examples from pure python ones.
319 319 class IPExample(doctest.Example): pass
320 320
321 321
class IPExternalExample(doctest.Example):
    """Doctest examples to be run in an external process."""

    def __init__(self, source, want, exc_msg=None, lineno=0, indent=0,
                 options=None):
        """Normalize via the parent constructor, then pad the source."""
        doctest.Example.__init__(self, source, want, exc_msg, lineno,
                                 indent, options)
        # An EXTRA newline is needed to prevent pexpect hangs
        self.source += '\n'
332 332
333 333
class IPDocTestParser(doctest.DocTestParser):
    """
    A class used to parse strings containing doctest examples.

    Note: This is a version modified to properly recognize IPython input and
    convert any IPython examples into valid Python ones.
    """
    # This regular expression is used to find doctest examples in a
    # string.  It defines three groups: `source` is the source code
    # (including leading indentation and prompts); `indent` is the
    # indentation of the first (PS1) line of the source code; and
    # `want` is the expected output (including leading indentation).

    # Classic Python prompts or default IPython ones
    _PS1_PY = r'>>>'
    _PS2_PY = r'\.\.\.'

    _PS1_IP = r'In\ \[\d+\]:'
    _PS2_IP = r'\ \ \ \.\.\.+:'

    # Template filled in twice below: once with the Python prompts and once
    # with the IPython prompts (whitespace is insignificant under re.VERBOSE).
    _RE_TPL = r'''
        # Source consists of a PS1 line followed by zero or more PS2 lines.
        (?P<source>
            (?:^(?P<indent> [ ]*) (?P<ps1> %s) .*)    # PS1 line
            (?:\n           [ ]*  (?P<ps2> %s) .*)*)  # PS2 lines
        \n? # a newline
        # Want consists of any non-blank lines that do not start with PS1.
        (?P<want> (?:(?![ ]*$)    # Not a blank line
                     (?![ ]*%s)   # Not a line starting with PS1
                     (?![ ]*%s)   # Not a line starting with PS2
                     .*$\n?       # But any other line
                  )*)
        '''

    _EXAMPLE_RE_PY = re.compile( _RE_TPL % (_PS1_PY,_PS2_PY,_PS1_PY,_PS2_PY),
                                 re.MULTILINE | re.VERBOSE)

    _EXAMPLE_RE_IP = re.compile( _RE_TPL % (_PS1_IP,_PS2_IP,_PS1_IP,_PS2_IP),
                                 re.MULTILINE | re.VERBOSE)

    # Mark a test as being fully random.  In this case, we simply append the
    # random marker ('#random') to each individual example's output.  This way
    # we don't need to modify any other code.
    _RANDOM_TEST = re.compile(r'#\s*all-random\s+')

    # Mark tests to be executed in an external process - currently unsupported.
    _EXTERNAL_IP = re.compile(r'#\s*ipdoctest:\s*EXTERNAL')

    def ip2py(self,source):
        """Convert input IPython source into valid Python."""
        block = _ip.input_transformer_manager.transform_cell(source)
        if len(block.splitlines()) == 1:
            # Single-line input additionally goes through the prefilter;
            # multi-line blocks are already fully transformed.
            return _ip.prefilter(block)
        else:
            return block

    def parse(self, string, name='<string>'):
        """
        Divide the given string into examples and intervening text,
        and return them as a list of alternating Examples and strings.
        Line numbers for the Examples are 0-based.  The optional
        argument `name` is a name identifying this string, and is only
        used for error messages.
        """

        #print 'Parse string:\n',string # dbg

        string = string.expandtabs()
        # If all lines begin with the same indentation, then strip it.
        min_indent = self._min_indent(string)
        if min_indent > 0:
            string = '\n'.join([l[min_indent:] for l in string.split('\n')])

        output = []
        charno, lineno = 0, 0

        # We make 'all random' tests by adding the '# random' mark to every
        # block of output in the test.
        if self._RANDOM_TEST.search(string):
            random_marker = '\n# random'
        else:
            random_marker = ''

        # Whether to convert the input from ipython to python syntax
        ip2py = False
        # Find all doctest examples in the string.  First, try them as Python
        # examples, then as IPython ones
        terms = list(self._EXAMPLE_RE_PY.finditer(string))
        if terms:
            # Normal Python example
            #print '-'*70  # dbg
            #print 'PyExample, Source:\n',string  # dbg
            #print '-'*70  # dbg
            Example = doctest.Example
        else:
            # It's an ipython example.  Note that IPExamples are run
            # in-process, so their syntax must be turned into valid python.
            # IPExternalExamples are run out-of-process (via pexpect) so they
            # don't need any filtering (a real ipython will be executing them).
            terms = list(self._EXAMPLE_RE_IP.finditer(string))
            if self._EXTERNAL_IP.search(string):
                #print '-'*70  # dbg
                #print 'IPExternalExample, Source:\n',string  # dbg
                #print '-'*70  # dbg
                Example = IPExternalExample
            else:
                #print '-'*70  # dbg
                #print 'IPExample, Source:\n',string  # dbg
                #print '-'*70  # dbg
                Example = IPExample
                ip2py = True

        for m in terms:
            # Add the pre-example text to `output`.
            output.append(string[charno:m.start()])
            # Update lineno (lines before this example)
            lineno += string.count('\n', charno, m.start())
            # Extract info from the regexp match.
            (source, options, want, exc_msg) = \
                     self._parse_example(m, name, lineno,ip2py)

            # Append the random-output marker (it defaults to empty in most
            # cases, it's only non-empty for 'all-random' tests):
            want += random_marker

            if Example is IPExternalExample:
                options[doctest.NORMALIZE_WHITESPACE] = True
                want += '\n'

            # Create an Example, and add it to the list.
            if not self._IS_BLANK_OR_COMMENT(source):
                output.append(Example(source, want, exc_msg,
                                      lineno=lineno,
                                      indent=min_indent+len(m.group('indent')),
                                      options=options))
            # Update lineno (lines inside this example)
            lineno += string.count('\n', m.start(), m.end())
            # Update charno.
            charno = m.end()
        # Add any remaining post-example text to `output`.
        output.append(string[charno:])
        return output

    def _parse_example(self, m, name, lineno,ip2py=False):
        """
        Given a regular expression match from `_EXAMPLE_RE` (`m`),
        return a pair `(source, want)`, where `source` is the matched
        example's source code (with prompts and indentation stripped);
        and `want` is the example's expected output (with indentation
        stripped).

        `name` is the string's name, and `lineno` is the line number
        where the example starts; both are used for error messages.

        Optional:
        `ip2py`: if true, filter the input via IPython to convert the syntax
        into valid python.
        """

        # Get the example's indentation level.
        indent = len(m.group('indent'))

        # Divide source into lines; check that they're properly
        # indented; and then strip their indentation & prompts.
        source_lines = m.group('source').split('\n')

        # We're using variable-length input prompts
        ps1 = m.group('ps1')
        ps2 = m.group('ps2')
        ps1_len = len(ps1)

        self._check_prompt_blank(source_lines, indent, name, lineno,ps1_len)
        if ps2:
            self._check_prefix(source_lines[1:], ' '*indent + ps2, name, lineno)

        # Strip indentation, prompt and the one blank after the prompt.
        source = '\n'.join([sl[indent+ps1_len+1:] for sl in source_lines])

        if ip2py:
            # Convert source input from IPython into valid Python syntax
            source = self.ip2py(source)

        # Divide want into lines; check that it's properly indented; and
        # then strip the indentation.  Spaces before the last newline should
        # be preserved, so plain rstrip() isn't good enough.
        want = m.group('want')
        want_lines = want.split('\n')
        if len(want_lines) > 1 and re.match(r' *$', want_lines[-1]):
            del want_lines[-1]  # forget final newline & spaces after it
        self._check_prefix(want_lines, ' '*indent, name,
                           lineno + len(source_lines))

        # Remove ipython output prompt that might be present in the first line
        want_lines[0] = re.sub(r'Out\[\d+\]: \s*?\n?','',want_lines[0])

        want = '\n'.join([wl[indent:] for wl in want_lines])

        # If `want` contains a traceback message, then extract it.
        m = self._EXCEPTION_RE.match(want)
        if m:
            exc_msg = m.group('msg')
        else:
            exc_msg = None

        # Extract options from the source.
        options = self._find_options(source, name, lineno)

        return source, options, want, exc_msg

    def _check_prompt_blank(self, lines, indent, name, lineno, ps1_len):
        """
        Given the lines of a source string (including prompts and
        leading indentation), check to make sure that every prompt is
        followed by a space character.  If any line is not followed by
        a space character, then raise ValueError.

        Note: IPython-modified version which takes the input prompt length as a
        parameter, so that prompts of variable length can be dealt with.
        """
        space_idx = indent+ps1_len
        min_len = space_idx+1
        for i, line in enumerate(lines):
            if len(line) >= min_len and line[space_idx] != ' ':
                raise ValueError('line %r of the docstring for %s '
                                 'lacks blank after %s: %r' %
                                 (lineno+i+1, name,
                                  line[indent:space_idx], line))
560 560
561 561
562 562 SKIP = doctest.register_optionflag('SKIP')
563 563
564 564
class IPDocTestRunner(doctest.DocTestRunner,object):
    """Doctest runner that keeps test globals in sync with IPython's namespace."""

    def run(self, test, compileflags=None, out=None, clear_globs=True):
        """Merge the interactive namespace into *test* globals, then run.

        Pulling ``_ip.user_ns`` into ``test.globs`` lets examples see
        variables loaded by %run; the update happens here once per test
        rather than being pushed onto every example unconditionally.
        """
        test.globs.update(_ip.user_ns)
        parent = super(IPDocTestRunner, self)
        return parent.run(test, compileflags, out, clear_globs)
584 584
585 585
586 586 class DocFileCase(doctest.DocFileCase):
587 587 """Overrides to provide filename
588 588 """
589 589 def address(self):
590 590 return (self._dt_test.filename, None, None)
591 591
592 592
class ExtensionDoctest(doctests.Doctest):
    """Nose Plugin that supports doctests in extension modules.
    """
    name = 'extdoctest'   # call nosetests with --with-extdoctest
    enabled = True

    def options(self, parser, env=os.environ):
        """Register this plugin's command-line options with nose's parser."""
        Plugin.options(self, parser, env)
        parser.add_option('--doctest-tests', action='store_true',
                          dest='doctest_tests',
                          default=env.get('NOSE_DOCTEST_TESTS',True),
                          help="Also look for doctests in test modules. "
                          "Note that classes, methods and functions should "
                          "have either doctests or non-doctest tests, "
                          "not both. [NOSE_DOCTEST_TESTS]")
        parser.add_option('--doctest-extension', action="append",
                          dest="doctestExtension",
                          help="Also look for doctests in files with "
                          "this extension [NOSE_DOCTEST_EXTENSION]")
        # Set the default as a list, if given in env; otherwise
        # an additional value set on the command line will cause
        # an error.
        env_setting = env.get('NOSE_DOCTEST_EXTENSION')
        if env_setting is not None:
            parser.set_defaults(doctestExtension=tolist(env_setting))


    def configure(self, options, config):
        """Store parsed options and build the parser/finder/checker trio."""
        Plugin.configure(self, options, config)
        # Pull standard doctest plugin out of config; we will do doctesting
        config.plugins.plugins = [p for p in config.plugins.plugins
                                  if p.name != 'doctest']
        self.doctest_tests = options.doctest_tests
        self.extension = tolist(options.doctestExtension)

        self.parser = doctest.DocTestParser()
        self.finder = DocTestFinder()
        self.checker = IPDoctestOutputChecker()
        self.globs = None
        self.extraglobs = None


    def loadTestsFromExtensionModule(self,filename):
        """Import an extension module by path and collect its doctests."""
        bpath,mod = os.path.split(filename)
        modname = os.path.splitext(mod)[0]
        try:
            # Temporarily extend sys.path so the bare __import__ can find it.
            sys.path.append(bpath)
            module = __import__(modname)
            tests = list(self.loadTestsFromModule(module))
        finally:
            sys.path.pop()
        return tests

    # NOTE: the method below is almost a copy of the original one in nose, with
    # a few modifications to control output checking.

    def loadTestsFromModule(self, module):
        """Yield a DocTestCase for every doctest found in *module*."""
        #print '*** ipdoctest - lTM',module  # dbg

        if not self.matches(module.__name__):
            log.debug("Doctest doesn't want module %s", module)
            return

        tests = self.finder.find(module,globs=self.globs,
                                 extraglobs=self.extraglobs)
        if not tests:
            return

        # always use whitespace and ellipsis options
        optionflags = doctest.NORMALIZE_WHITESPACE | doctest.ELLIPSIS

        tests.sort()
        module_file = module.__file__
        if module_file[-4:] in ('.pyc', '.pyo'):
            # Point at the .py source, not the compiled file.
            module_file = module_file[:-1]
        for test in tests:
            if not test.examples:
                continue
            if not test.filename:
                test.filename = module_file

            yield DocTestCase(test,
                              optionflags=optionflags,
                              checker=self.checker)


    def loadTestsFromFile(self, filename):
        """Yield doctest cases from *filename* (extension module or text file)."""
        #print "ipdoctest - from file", filename # dbg
        if is_extension_module(filename):
            for t in self.loadTestsFromExtensionModule(filename):
                yield t
        else:
            if self.extension and anyp(filename.endswith, self.extension):
                name = os.path.basename(filename)
                dh = open(filename)
                try:
                    doc = dh.read()
                finally:
                    dh.close()
                test = self.parser.get_doctest(
                    doc, globs={'__file__': filename}, name=name,
                    filename=filename, lineno=0)
                if test.examples:
                    #print 'FileCase:',test.examples  # dbg
                    yield DocFileCase(test)
                else:
                    yield False # no tests to load
700 700
701 701
class IPythonDoctest(ExtensionDoctest):
    """Nose Plugin that supports doctests in extension modules.
    """
    name = 'ipdoctest'   # call nosetests with --with-ipdoctest
    enabled = True

    def makeTest(self, obj, parent):
        """Look for doctests in the given object, which will be a
        function, method or class.
        """
        #print 'Plugin analyzing:', obj, parent  # dbg
        # always use whitespace and ellipsis options
        optionflags = doctest.NORMALIZE_WHITESPACE | doctest.ELLIPSIS

        doctests = self.finder.find(obj, module=getmodule(parent))
        if doctests:
            for test in doctests:
                if len(test.examples) == 0:
                    continue

                yield DocTestCase(test, obj=obj,
                                  optionflags=optionflags,
                                  checker=self.checker)

    def options(self, parser, env=os.environ):
        """Register the --ipdoctest-* command-line options with nose."""
        #print "Options for nose plugin:", self.name # dbg
        Plugin.options(self, parser, env)
        parser.add_option('--ipdoctest-tests', action='store_true',
                          dest='ipdoctest_tests',
                          default=env.get('NOSE_IPDOCTEST_TESTS',True),
                          help="Also look for doctests in test modules. "
                          "Note that classes, methods and functions should "
                          "have either doctests or non-doctest tests, "
                          "not both. [NOSE_IPDOCTEST_TESTS]")
        parser.add_option('--ipdoctest-extension', action="append",
                          dest="ipdoctest_extension",
                          help="Also look for doctests in files with "
                          "this extension [NOSE_IPDOCTEST_EXTENSION]")
        # Set the default as a list, if given in env; otherwise
        # an additional value set on the command line will cause
        # an error.
        env_setting = env.get('NOSE_IPDOCTEST_EXTENSION')
        if env_setting is not None:
            parser.set_defaults(ipdoctest_extension=tolist(env_setting))

    def configure(self, options, config):
        """Like ExtensionDoctest.configure, but with the IPython-aware parser."""
        #print "Configuring nose plugin:", self.name # dbg
        Plugin.configure(self, options, config)
        # Pull standard doctest plugin out of config; we will do doctesting
        config.plugins.plugins = [p for p in config.plugins.plugins
                                  if p.name != 'doctest']
        self.doctest_tests = options.ipdoctest_tests
        self.extension = tolist(options.ipdoctest_extension)

        self.parser = IPDocTestParser()
        self.finder = DocTestFinder(parser=self.parser)
        self.checker = IPDoctestOutputChecker()
        self.globs = None
        self.extraglobs = None
@@ -1,169 +1,169 b''
1 1 """Tests for the decorators we've created for IPython.
2 2 """
3 3 from __future__ import print_function
4 4
5 5 # Module imports
6 6 # Std lib
7 7 import inspect
8 8 import sys
9 9
10 10 # Third party
11 11 import nose.tools as nt
12 12
13 13 # Our own
14 14 from IPython.testing import decorators as dec
15 15 from IPython.testing.skipdoctest import skip_doctest
16 16
17 17 #-----------------------------------------------------------------------------
18 18 # Utilities
19 19
20 20 # Note: copied from OInspect, kept here so the testing stuff doesn't create
21 21 # circular dependencies and is easier to reuse.
def getargspec(obj):
    """Get the names and default values of a function's arguments.

    A tuple of four things is returned: (args, varargs, varkw, defaults).
    'args' is a list of the argument names (it may contain nested lists).
    'varargs' and 'varkw' are the names of the * and ** arguments or None.
    'defaults' is an n-tuple of the default values of the last n arguments.

    Modified version of inspect.getargspec from the Python Standard
    Library."""

    if inspect.isfunction(obj):
        func_obj = obj
    elif inspect.ismethod(obj):
        # Use __func__, which exists on both Python 2.6+ and Python 3,
        # instead of the Python-2-only im_func alias.  This matches the
        # __code__/__defaults__ attribute names used below.
        func_obj = obj.__func__
    else:
        raise TypeError('arg is not a Python function')
    args, varargs, varkw = inspect.getargs(func_obj.__code__)
    return args, varargs, varkw, func_obj.__defaults__
41 41
42 42 #-----------------------------------------------------------------------------
43 43 # Testing functions
44 44
# Smallest possible function exercising the as_unittest decorator.
@dec.as_unittest
def trivial():
    """A trivial test"""
    return None
49 49
50 50
@dec.skip
def test_deliberately_broken():
    """A deliberately broken test - we want to skip this one."""
    # Would raise ZeroDivisionError if the skip decorator failed to work.
    1 / 0
55 55
@dec.skip('Testing the skip decorator')
def test_deliberately_broken2():
    """Another deliberately broken test - we want to skip this one."""
    # As above: only reachable if the skip (with reason) doesn't apply.
    1 / 0
60 60
61 61
# Verify that we can correctly skip the doctest for a function at will, but
# that the docstring itself is NOT destroyed by the decorator.
# NOTE: test_skip_dt_decorator compares this docstring verbatim; do not edit it.
@skip_doctest
def doctest_bad(x,y=1,**k):
    """A function whose doctest we need to skip.

    >>> 1+1
    3
    """
    for label, value in (('x:', x), ('y:', y), ('k:', k)):
        print(label, value)
74 74
75 75
def call_doctest_bad():
    """Check that we can still call the decorated functions.

    >>> doctest_bad(3,y=4)
    x: 3
    y: 4
    k: {}
    """
    # The doctest above is the whole point; the body does nothing.
    return None
85 85
86 86
def test_skip_dt_decorator():
    """Doctest-skipping decorator should preserve the docstring.
    """
    # Careful: 'check' must be a *verbatim* copy of the doctest_bad docstring!
    check = """A function whose doctest we need to skip.

    >>> 1+1
    3
    """
    # Fetch the docstring from doctest_bad after decoration.
    actual = doctest_bad.__doc__

    nt.assert_equal(check, actual, "doctest_bad docstrings don't match")
100 100
101 101
# Doctest skipping should work for class methods too
class FooClass(object):
    """FooClass

    Example:

    >>> 1+1
    2
    """

    # The doctest below is deliberately wrong ("junk"); @skip_doctest must
    # prevent it from ever being collected and run.
    @skip_doctest
    def __init__(self,x):
        """Make a FooClass.

        Example:

        >>> f = FooClass(3)
        junk
        """
        print('Making a FooClass.')
        self.x = x

    # Likewise deliberately broken examples, skipped via the decorator.
    @skip_doctest
    def bar(self,y):
        """Example:

        >>> ff = FooClass(3)
        >>> ff.bar(0)
        boom!
        >>> 1/0
        bam!
        """
        return 1/y

    # This docstring holds a *valid* doctest and is expected to run.
    def baz(self,y):
        """Example:

        >>> ff2 = FooClass(3)
        Making a FooClass.
        >>> ff2.baz(3)
        True
        """
        return self.x==y
145 145
146 146
def test_skip_dt_decorator2():
    """Doctest-skipping decorator should preserve function signature.
    """
    # Hardcoded correct answer
    expected = (['x', 'y'], None, 'k', (1,))
    # Introspect out the value
    reconstructed = getargspec(doctest_bad)
    assert reconstructed == expected, \
        "Incorrectly reconstructed args for doctest_bad: %s" % (reconstructed,)
156 156
157 157
@dec.skip_linux
def test_linux():
    # Only ever executes off Linux; the decorator skips it there.
    nt.assert_false(sys.platform.startswith('linux'),
                    "This test can't run under linux")
161 161
@dec.skip_win32
def test_win32():
    # Only ever executes off Windows; the decorator skips it there.
    nt.assert_not_equal(sys.platform, 'win32',
                        "This test can't run under windows")
165 165
@dec.skip_osx
def test_osx():
    # Only ever executes off OS X; the decorator skips it there.
    nt.assert_not_equal(sys.platform, 'darwin',
                        "This test can't run under osx")
169 169
@@ -1,352 +1,352 b''
1 1 # encoding: utf-8
2 2
3 3 """Pickle related utilities. Perhaps this should be called 'can'."""
4 4
5 5 __docformat__ = "restructuredtext en"
6 6
7 7 #-------------------------------------------------------------------------------
8 8 # Copyright (C) 2008-2011 The IPython Development Team
9 9 #
10 10 # Distributed under the terms of the BSD License. The full license is in
11 11 # the file COPYING, distributed as part of this software.
12 12 #-------------------------------------------------------------------------------
13 13
14 14 #-------------------------------------------------------------------------------
15 15 # Imports
16 16 #-------------------------------------------------------------------------------
17 17
18 18 import copy
19 19 import logging
20 20 import sys
21 21 from types import FunctionType
22 22
23 23 try:
24 24 import cPickle as pickle
25 25 except ImportError:
26 26 import pickle
27 27
28 28 from . import codeutil # This registers a hook when it's imported
29 29 from . import py3compat
30 30 from .importstring import import_item
31 31 from .py3compat import string_types, iteritems
32 32
33 33 from IPython.config import Application
34 34
35 35 if py3compat.PY3:
36 36 buffer = memoryview
37 37 class_type = type
38 38 else:
39 39 from types import ClassType
40 40 class_type = (type, ClassType)
41 41
42 42 #-------------------------------------------------------------------------------
43 43 # Classes
44 44 #-------------------------------------------------------------------------------
45 45
46 46
class CannedObject(object):
    def __init__(self, obj, keys=None, hook=None):
        """can an object for safe pickling

        Parameters
        ==========

        obj:
            The object to be canned
        keys: list (optional)
            list of attribute names that will be explicitly canned / uncanned
        hook: callable (optional)
            An optional extra callable,
            which can do additional processing of the uncanned object.

        large data may be offloaded into the buffers list,
        used for zero-copy transfers.
        """
        # Use None instead of a mutable [] default (shared across calls);
        # an omitted `keys` behaves exactly as the old empty-list default.
        self.keys = [] if keys is None else keys
        self.obj = copy.copy(obj)
        self.hook = can(hook)
        for key in self.keys:
            setattr(self.obj, key, can(getattr(obj, key)))

        self.buffers = []

    def get_object(self, g=None):
        """Reconstruct the canned object, uncanning listed attributes.

        g : dict (optional)
            namespace used by uncan() when rebuilding contained objects.
        """
        if g is None:
            g = {}
        obj = self.obj
        for key in self.keys:
            setattr(obj, key, uncan(getattr(obj, key), g))

        if self.hook:
            # The hook itself was canned in __init__; restore before calling.
            self.hook = uncan(self.hook, g)
            self.hook(obj, g)
        return self.obj
84 84
85 85
class Reference(CannedObject):
    """object for wrapping a remote reference by name."""

    def __init__(self, name):
        # A reference is just a name to be looked up on the far side.
        if not isinstance(name, string_types):
            raise TypeError("illegal name: %r"%name)
        self.name = name
        self.buffers = []

    def __repr__(self):
        return "<Reference: %r>" % (self.name,)

    def get_object(self, g=None):
        namespace = {} if g is None else g
        # NOTE(review): evaluates the stored name in the given namespace;
        # assumes the name came from trusted code.
        return eval(self.name, namespace)
102 102
103 103
class CannedFunction(CannedObject):
    """Can a plain function so it can be pickled and rebuilt elsewhere."""

    def __init__(self, f):
        self._check_type(f)
        # Capture everything needed to rebuild the function: its code
        # object, canned default values, defining module, and name.
        self.code = f.__code__
        if f.__defaults__:
            self.defaults = [can(fd) for fd in f.__defaults__]
        else:
            self.defaults = None
        self.module = f.__module__ or '__main__'
        self.__name__ = f.__name__
        self.buffers = []

    def _check_type(self, obj):
        assert isinstance(obj, FunctionType), "Not a function type"

    def get_object(self, g=None):
        # try to load function back into its module:
        if not self.module.startswith('__'):
            __import__(self.module)
            g = sys.modules[self.module].__dict__

        if g is None:
            g = {}
        if self.defaults:
            defaults = tuple(uncan(cfd, g) for cfd in self.defaults)
        else:
            defaults = None
        return FunctionType(self.code, g, self.__name__, defaults)
134 134
class CannedClass(CannedObject):
    """Can a class (old- or new-style) for transport."""

    def __init__(self, cls):
        self._check_type(cls)
        self.name = cls.__name__
        self.old_style = not isinstance(cls, type)
        # Can every class attribute except the interpreter-managed slots.
        self._canned_dict = {}
        for attr, value in cls.__dict__.items():
            if attr not in ('__weakref__', '__dict__'):
                self._canned_dict[attr] = can(value)
        # Old-style classes have no mro(); treat them as having no bases.
        mro = [] if self.old_style else cls.mro()
        self.parents = [can(c) for c in mro[1:]]
        self.buffers = []

    def _check_type(self, obj):
        assert isinstance(obj, class_type), "Not a class type"

    def get_object(self, g=None):
        parents = tuple(uncan(p, g) for p in self.parents)
        return type(self.name, parents, uncan_dict(self._canned_dict, g=g))
159 159
class CannedArray(CannedObject):
    """Can a numpy ndarray, using a zero-copy buffer when possible."""

    def __init__(self, obj):
        from numpy import ascontiguousarray
        self.shape = obj.shape
        # Structured dtypes need the full descr; simple ones just the str.
        self.dtype = obj.dtype.descr if obj.dtype.fields else obj.dtype.str
        if sum(obj.shape) == 0:
            # Degenerate (0-d / zero-length) array: just pickle it.
            self.buffers = [pickle.dumps(obj, -1)]
        else:
            # Ensure a contiguous layout so the raw buffer is meaningful.
            contiguous = ascontiguousarray(obj, dtype=None)
            self.buffers = [buffer(contiguous)]

    def get_object(self, g=None):
        from numpy import frombuffer
        data = self.buffers[0]
        if sum(self.shape) == 0:
            # No shape: the array was pickled rather than buffered.
            return pickle.loads(data)
        return frombuffer(data, dtype=self.dtype).reshape(self.shape)
181 181
182 182
class CannedBytes(CannedObject):
    """Trivial canning of a bytes object via the buffers list."""

    # Subclasses override `wrap` to control the reconstructed type.
    wrap = bytes

    def __init__(self, obj):
        self.buffers = [obj]

    def get_object(self, g=None):
        return self.wrap(self.buffers[0])
191 191
class CannedBuffer(CannedBytes):
    """Can a buffer/memoryview, reconstructing with the buffer wrapper.

    NOTE: this was previously declared with ``def`` instead of ``class``,
    making CannedBuffer a one-argument function that returned None — so
    canning a buffer via can_map silently produced None instead of a
    canned object.
    """
    wrap = buffer
194 194
195 195 #-------------------------------------------------------------------------------
196 196 # Functions
197 197 #-------------------------------------------------------------------------------
198 198
def _logger():
    """get the logger for the current Application

    the root logger will be used if no Application is running
    """
    if Application.initialized():
        return Application.instance().log
    # No running app: fall back to the root logger, ensuring it has at
    # least a basic handler so messages are not dropped.
    logger = logging.getLogger()
    if not logger.handlers:
        logging.basicConfig()
    return logger
212 212
def _import_mapping(mapping, original=None):
    """import any string-keys in a type mapping

    Parameters
    ----------
    mapping : dict
        A can/uncan map whose keys may be dotted import strings; resolved
        keys replace the string keys in place.
    original : dict, optional
        The pristine mapping; import failures are only logged for keys the
        user added (i.e. keys not present in `original`).
    """
    log = _logger()
    log.debug("Importing canning map")
    # Iterate over a snapshot of the items: we pop and insert keys while
    # walking, and mutating a dict during .items() iteration raises
    # RuntimeError on Python 3 (dict views).
    for key, value in list(mapping.items()):
        if isinstance(key, string_types):
            try:
                cls = import_item(key)
            except Exception:
                if original and key not in original:
                    # only message on user-added classes
                    log.error("canning class not importable: %r", key, exc_info=True)
                mapping.pop(key)
            else:
                mapping[cls] = mapping.pop(key)
230 230
def istype(obj, check):
    """like isinstance(obj, check), but strict

    This won't catch subclasses.
    """
    if not isinstance(check, tuple):
        return type(obj) is check
    # Tuple of candidates: exact type match against any one of them.
    return any(type(obj) is cls for cls in check)
243 243
def can(obj):
    """prepare an object for pickling"""
    need_import = False

    for cls, canner in iteritems(can_map):
        if isinstance(cls, string_types):
            # A key is still an unimported dotted string; resolve the
            # whole map and retry below.
            need_import = True
            break
        elif istype(obj, cls):
            return canner(obj)

    if need_import:
        # perform can_map imports, then try again
        # this will usually only happen once
        _import_mapping(can_map, _original_can_map)
        return can(obj)

    return obj
263 263
def can_class(obj):
    """Can classes defined interactively in __main__; pass others through."""
    if isinstance(obj, class_type) and obj.__module__ == '__main__':
        return CannedClass(obj)
    return obj
269 269
def can_dict(obj):
    """can the *values* of a dict"""
    if not istype(obj, dict):
        return obj
    # Keys are left alone; only the values are canned.
    canned = {}
    for key, value in iteritems(obj):
        canned[key] = can(value)
    return canned
279 279
# Sequence containers whose *elements* are canned/uncanned individually.
sequence_types = (list, tuple, set)

def can_sequence(obj):
    """can the elements of a sequence"""
    if not istype(obj, sequence_types):
        return obj
    # Rebuild with the same concrete type, canning each element.
    seq_cls = type(obj)
    return seq_cls([can(item) for item in obj])
289 289
def uncan(obj, g=None):
    """invert canning"""
    need_import = False
    for cls, uncanner in iteritems(uncan_map):
        if isinstance(cls, string_types):
            # Unimported string key; resolve the map and retry below.
            need_import = True
            break
        elif isinstance(obj, cls):
            return uncanner(obj, g)

    if need_import:
        # perform uncan_map imports, then try again
        # this will usually only happen once
        _import_mapping(uncan_map, _original_uncan_map)
        return uncan(obj, g)

    return obj
308 308
def uncan_dict(obj, g=None):
    """Uncan the values of a dict, leaving keys untouched."""
    if not istype(obj, dict):
        return obj
    restored = {}
    for key, value in iteritems(obj):
        restored[key] = uncan(value, g)
    return restored
317 317
def uncan_sequence(obj, g=None):
    """Uncan each element of a list/tuple/set, preserving the type."""
    if not istype(obj, sequence_types):
        return obj
    seq_cls = type(obj)
    return seq_cls([uncan(item, g) for item in obj])
324 324
325 325 def _uncan_dependent_hook(dep, g=None):
326 326 dep.check_dependency()
327 327
def can_dependent(obj):
    """Can a parallel dependent, explicitly canning its f/df attributes."""
    return CannedObject(obj, keys=('f', 'df'), hook=_uncan_dependent_hook)
330 330
331 331 #-------------------------------------------------------------------------------
332 332 # API dictionaries
333 333 #-------------------------------------------------------------------------------
334 334
# These dicts can be extended for custom serialization of new objects

# Maps a type -- or a dotted import string resolved lazily by
# _import_mapping -- to the callable that cans instances of that type.
can_map = {
    'IPython.parallel.dependent' : can_dependent,
    'numpy.ndarray' : CannedArray,
    FunctionType : CannedFunction,
    bytes : CannedBytes,
    buffer : CannedBuffer,
    class_type : can_class,
}

# Maps canned types to the callable that reconstructs the original object.
uncan_map = {
    CannedObject : lambda obj, g: obj.get_object(g),
}

# for use in _import_mapping:
_original_can_map = can_map.copy()
_original_uncan_map = uncan_map.copy()
General Comments 0
You need to be logged in to leave comments. Login now