remove decorator from external
MinRK
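This commit swaps IPython's bundled copy of the decorator module (`IPython.external.decorator`) for the standalone `decorator` package from PyPI. The calling convention both provide, and which `catch_config_error`, `catch_format_error`, `needs_sqlite`, and `catch_corrupt_db` below all rely on, is that the wrapped function arrives as the first argument while the original signature is preserved. A minimal sketch, with an illustrative function name:

```python
from decorator import decorator

@decorator
def logged(func, *args, **kwargs):
    # `func` is the undecorated function; the caller's arguments arrive
    # unchanged, which is how catch_config_error(method, app, ...) can
    # read the application instance out of the first positional argument.
    print('calling %s' % func.__name__)
    return func(*args, **kwargs)

@logged
def add(a, b):
    return a + b

add(1, 2)  # prints 'calling add', returns 3
```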
IPython/config/application.py
@@ -1,621 +1,621 @@
1 1 # encoding: utf-8
2 2 """A base class for a configurable application."""
3 3
4 4 # Copyright (c) IPython Development Team.
5 5 # Distributed under the terms of the Modified BSD License.
6 6
7 7 from __future__ import print_function
8 8
9 9 import json
10 10 import logging
11 11 import os
12 12 import re
13 13 import sys
14 14 from copy import deepcopy
15 15 from collections import defaultdict
16 16
17 from IPython.external.decorator import decorator
17 from decorator import decorator
18 18
19 19 from IPython.config.configurable import SingletonConfigurable
20 20 from IPython.config.loader import (
21 21 KVArgParseConfigLoader, PyFileConfigLoader, Config, ArgumentError, ConfigFileNotFound, JSONFileConfigLoader
22 22 )
23 23
24 24 from IPython.utils.traitlets import (
25 25 Unicode, List, Enum, Dict, Instance, TraitError
26 26 )
27 27 from IPython.utils.importstring import import_item
28 28 from IPython.utils.text import indent, wrap_paragraphs, dedent
29 29 from IPython.utils import py3compat
30 30 from IPython.utils.py3compat import string_types, iteritems
31 31
32 32 #-----------------------------------------------------------------------------
33 33 # Descriptions for the various sections
34 34 #-----------------------------------------------------------------------------
35 35
36 36 # merge flags&aliases into options
37 37 option_description = """
38 38 Arguments that take values are actually convenience aliases to full
39 39 Configurables, whose aliases are listed on the help line. For more information
40 40 on full configurables, see '--help-all'.
41 41 """.strip() # trim newlines off front and back
42 42
43 43 keyvalue_description = """
44 44 Parameters are set from command-line arguments of the form:
45 45 `--Class.trait=value`.
46 46 This line is evaluated in Python, so simple expressions are allowed, e.g.::
47 47     `--C.a='range(3)'` for setting C.a=[0,1,2].
48 48 """.strip() # trim newlines off front and back
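As a hedged illustration of this key-value syntax (hypothetical `MyApp` subclass; values are evaluated as Python where possible, otherwise kept as strings):

```python
from IPython.config.application import Application

class MyApp(Application):
    name = u'myapp'

app = MyApp()
app.parse_command_line([
    '--MyApp.log_level=DEBUG',         # level name, resolved to logging.DEBUG
    "--MyApp.log_datefmt='%H:%M:%S'",  # a plain string value
])
print(app.log_level)  # 10
```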
49 49
50 50 # sys.argv can be missing, for example when python is embedded. See the docs
51 51 # for details: http://docs.python.org/2/c-api/intro.html#embedding-python
52 52 if not hasattr(sys, "argv"):
53 53 sys.argv = [""]
54 54
55 55 subcommand_description = """
56 56 Subcommands are launched as `{app} cmd [args]`. For information on using
57 57 subcommand 'cmd', do: `{app} cmd -h`.
58 58 """
59 59 # get running program name
60 60
61 61 #-----------------------------------------------------------------------------
62 62 # Application class
63 63 #-----------------------------------------------------------------------------
64 64
65 65 @decorator
66 66 def catch_config_error(method, app, *args, **kwargs):
67 67 """Method decorator for catching invalid config (Trait/ArgumentErrors) during init.
68 68
69 69 On a TraitError (generally caused by bad config), this will print the trait's
70 70 message, and exit the app.
71 71
72 72 For use on init methods, to prevent invoking excepthook on invalid input.
73 73 """
74 74 try:
75 75 return method(app, *args, **kwargs)
76 76 except (TraitError, ArgumentError) as e:
77 77 app.print_help()
78 78 app.log.fatal("Bad config encountered during initialization:")
79 79 app.log.fatal(str(e))
80 80 app.log.debug("Config at the time: %s", app.config)
81 81 app.exit(1)
82 82
83 83
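A short sketch of the intended use, on an app's own `initialize`-style methods (hypothetical subclass):

```python
class MyApp(Application):

    @catch_config_error
    def initialize(self, argv=None):
        # A TraitError/ArgumentError raised anywhere in here (e.g. a
        # mistyped --Class.trait value) prints the help plus the error
        # and exits with status 1, instead of dumping a traceback.
        self.parse_command_line(argv)
```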
84 84 class ApplicationError(Exception):
85 85 pass
86 86
87 87 class LevelFormatter(logging.Formatter):
88 88     """Formatter with an additional `highlevel` record field
89 89
90 90 This field is empty if log level is less than highlevel_limit,
91 91 otherwise it is formatted with self.highlevel_format.
92 92
93 93 Useful for adding 'WARNING' to warning messages,
94 94 without adding 'INFO' to info, etc.
95 95 """
96 96 highlevel_limit = logging.WARN
97 97 highlevel_format = " %(levelname)s |"
98 98
99 99 def format(self, record):
100 100 if record.levelno >= self.highlevel_limit:
101 101 record.highlevel = self.highlevel_format % record.__dict__
102 102 else:
103 103 record.highlevel = ""
104 104 return super(LevelFormatter, self).format(record)
105 105
106 106
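Standalone, the formatter behaves as follows (a sketch against the stdlib `logging` API):

```python
import logging
import sys

handler = logging.StreamHandler(sys.stderr)
handler.setFormatter(LevelFormatter(fmt="[%(name)s]%(highlevel)s %(message)s"))
log = logging.getLogger("demo")
log.addHandler(handler)

log.warning("disk full")  # -> [demo] WARNING | disk full
log.error("on fire")      # -> [demo] ERROR | on fire
# an INFO record (if enabled) renders with an empty highlevel field:
# [demo] all quiet
```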
107 107 class Application(SingletonConfigurable):
108 108 """A singleton application with full configuration support."""
109 109
110 110     # The name of the application; it will usually match the name of the
111 111     # command-line application.
112 112 name = Unicode(u'application')
113 113
114 114 # The description of the application that is printed at the beginning
115 115 # of the help.
116 116 description = Unicode(u'This is an application.')
117 117 # default section descriptions
118 118 option_description = Unicode(option_description)
119 119 keyvalue_description = Unicode(keyvalue_description)
120 120 subcommand_description = Unicode(subcommand_description)
121 121
122 122 # The usage and example string that goes at the end of the help string.
123 123 examples = Unicode()
124 124
125 125 # A sequence of Configurable subclasses whose config=True attributes will
126 126 # be exposed at the command line.
127 127 classes = []
128 128 @property
129 129 def _help_classes(self):
130 130 """Define `App.help_classes` if CLI classes should differ from config file classes"""
131 131 return getattr(self, 'help_classes', self.classes)
132 132
133 133 @property
134 134 def _config_classes(self):
135 135 """Define `App.config_classes` if config file classes should differ from CLI classes."""
136 136 return getattr(self, 'config_classes', self.classes)
137 137
138 138 # The version string of this application.
139 139 version = Unicode(u'0.0')
140 140
141 141 # the argv used to initialize the application
142 142 argv = List()
143 143
144 144 # The log level for the application
145 145 log_level = Enum((0,10,20,30,40,50,'DEBUG','INFO','WARN','ERROR','CRITICAL'),
146 146 default_value=logging.WARN,
147 147 config=True,
148 148 help="Set the log level by value or name.")
149 149 def _log_level_changed(self, name, old, new):
150 150 """Adjust the log level when log_level is set."""
151 151 if isinstance(new, string_types):
152 152 new = getattr(logging, new)
153 153 self.log_level = new
154 154 self.log.setLevel(new)
155 155
156 156 _log_formatter_cls = LevelFormatter
157 157
158 158 log_datefmt = Unicode("%Y-%m-%d %H:%M:%S", config=True,
159 159 help="The date format used by logging formatters for %(asctime)s"
160 160 )
161 161 def _log_datefmt_changed(self, name, old, new):
162 162 self._log_format_changed('log_format', self.log_format, self.log_format)
163 163
164 164 log_format = Unicode("[%(name)s]%(highlevel)s %(message)s", config=True,
165 165 help="The Logging format template",
166 166 )
167 167 def _log_format_changed(self, name, old, new):
168 168 """Change the log formatter when log_format is set."""
169 169 _log_handler = self.log.handlers[0]
170 170 _log_formatter = self._log_formatter_cls(fmt=new, datefmt=self.log_datefmt)
171 171 _log_handler.setFormatter(_log_formatter)
172 172
173 173
174 174 log = Instance(logging.Logger)
175 175 def _log_default(self):
176 176 """Start logging for this application.
177 177
178 178 The default is to log to stderr using a StreamHandler, if no default
179 179 handler already exists. The log level starts at logging.WARN, but this
180 180 can be adjusted by setting the ``log_level`` attribute.
181 181 """
182 182 log = logging.getLogger(self.__class__.__name__)
183 183 log.setLevel(self.log_level)
184 184 log.propagate = False
185 185 _log = log # copied from Logger.hasHandlers() (new in Python 3.2)
186 186 while _log:
187 187 if _log.handlers:
188 188 return log
189 189 if not _log.propagate:
190 190 break
191 191 else:
192 192 _log = _log.parent
193 193 if sys.executable.endswith('pythonw.exe'):
194 194 # this should really go to a file, but file-logging is only
195 195 # hooked up in parallel applications
196 196 _log_handler = logging.StreamHandler(open(os.devnull, 'w'))
197 197 else:
198 198 _log_handler = logging.StreamHandler()
199 199 _log_formatter = self._log_formatter_cls(fmt=self.log_format, datefmt=self.log_datefmt)
200 200 _log_handler.setFormatter(_log_formatter)
201 201 log.addHandler(_log_handler)
202 202 return log
203 203
204 204 # the alias map for configurables
205 205 aliases = Dict({'log-level' : 'Application.log_level'})
206 206
207 207 # flags for loading Configurables or store_const style flags
208 208 # flags are loaded from this dict by '--key' flags
209 209 # this must be a dict of two-tuples, the first element being the Config/dict
210 210 # and the second being the help string for the flag
211 211 flags = Dict()
212 212 def _flags_changed(self, name, old, new):
213 213 """ensure flags dict is valid"""
214 214 for key,value in iteritems(new):
215 215 assert len(value) == 2, "Bad flag: %r:%s"%(key,value)
216 216 assert isinstance(value[0], (dict, Config)), "Bad flag: %r:%s"%(key,value)
217 217 assert isinstance(value[1], string_types), "Bad flag: %r:%s"%(key,value)
218 218
219 219
220 220 # subcommands for launching other applications
221 221 # if this is not empty, this will be a parent Application
222 222 # this must be a dict of two-tuples,
223 223 # the first element being the application class/import string
224 224 # and the second being the help string for the subcommand
225 225 subcommands = Dict()
226 226 # parse_command_line will initialize a subapp, if requested
227 227 subapp = Instance('IPython.config.application.Application', allow_none=True)
228 228
229 229 # extra command-line arguments that don't set config values
230 230 extra_args = List(Unicode)
231 231
232 232
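For illustration, the two-tuple shape a parent application fills in (modeled on how IPython's own terminal app wires this up; treat the import strings as examples):

```python
subcommands = Dict(dict(
    notebook=('IPython.html.notebookapp.NotebookApp',
              "Launch the IPython HTML Notebook Server."),
    profile=('IPython.core.profileapp.ProfileApp',
             "Create and manage IPython profiles."),
))
```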
233 233 def __init__(self, **kwargs):
234 234 SingletonConfigurable.__init__(self, **kwargs)
235 235 # Ensure my class is in self.classes, so my attributes appear in command line
236 236 # options and config files.
237 237 if self.__class__ not in self.classes:
238 238 self.classes.insert(0, self.__class__)
239 239
240 240 def _config_changed(self, name, old, new):
241 241 SingletonConfigurable._config_changed(self, name, old, new)
242 242 self.log.debug('Config changed:')
243 243 self.log.debug(repr(new))
244 244
245 245 @catch_config_error
246 246 def initialize(self, argv=None):
247 247 """Do the basic steps to configure me.
248 248
249 249 Override in subclasses.
250 250 """
251 251 self.parse_command_line(argv)
252 252
253 253
254 254 def start(self):
255 255 """Start the app mainloop.
256 256
257 257 Override in subclasses.
258 258 """
259 259 if self.subapp is not None:
260 260 return self.subapp.start()
261 261
262 262 def print_alias_help(self):
263 263 """Print the alias part of the help."""
264 264 if not self.aliases:
265 265 return
266 266
267 267 lines = []
268 268 classdict = {}
269 269 for cls in self._help_classes:
270 270 # include all parents (up to, but excluding Configurable) in available names
271 271 for c in cls.mro()[:-3]:
272 272 classdict[c.__name__] = c
273 273
274 274 for alias, longname in iteritems(self.aliases):
275 275 classname, traitname = longname.split('.',1)
276 276 cls = classdict[classname]
277 277
278 278 trait = cls.class_traits(config=True)[traitname]
279 279 help = cls.class_get_trait_help(trait).splitlines()
280 280 # reformat first line
281 281 help[0] = help[0].replace(longname, alias) + ' (%s)'%longname
282 282 if len(alias) == 1:
283 283 help[0] = help[0].replace('--%s='%alias, '-%s '%alias)
284 284 lines.extend(help)
285 285 # lines.append('')
286 286 print(os.linesep.join(lines))
287 287
288 288 def print_flag_help(self):
289 289 """Print the flag part of the help."""
290 290 if not self.flags:
291 291 return
292 292
293 293 lines = []
294 294 for m, (cfg,help) in iteritems(self.flags):
295 295 prefix = '--' if len(m) > 1 else '-'
296 296 lines.append(prefix+m)
297 297 lines.append(indent(dedent(help.strip())))
298 298 # lines.append('')
299 299 print(os.linesep.join(lines))
300 300
301 301 def print_options(self):
302 302 if not self.flags and not self.aliases:
303 303 return
304 304 lines = ['Options']
305 305 lines.append('-'*len(lines[0]))
306 306 lines.append('')
307 307 for p in wrap_paragraphs(self.option_description):
308 308 lines.append(p)
309 309 lines.append('')
310 310 print(os.linesep.join(lines))
311 311 self.print_flag_help()
312 312 self.print_alias_help()
313 313 print()
314 314
315 315 def print_subcommands(self):
316 316 """Print the subcommand part of the help."""
317 317 if not self.subcommands:
318 318 return
319 319
320 320 lines = ["Subcommands"]
321 321 lines.append('-'*len(lines[0]))
322 322 lines.append('')
323 323 for p in wrap_paragraphs(self.subcommand_description.format(
324 324 app=self.name)):
325 325 lines.append(p)
326 326 lines.append('')
327 327 for subc, (cls, help) in iteritems(self.subcommands):
328 328 lines.append(subc)
329 329 if help:
330 330 lines.append(indent(dedent(help.strip())))
331 331 lines.append('')
332 332 print(os.linesep.join(lines))
333 333
334 334 def print_help(self, classes=False):
335 335 """Print the help for each Configurable class in self.classes.
336 336
337 337 If classes=False (the default), only flags and aliases are printed.
338 338 """
339 339 self.print_description()
340 340 self.print_subcommands()
341 341 self.print_options()
342 342
343 343 if classes:
344 344 help_classes = self._help_classes
345 345 if help_classes:
346 346 print("Class parameters")
347 347 print("----------------")
348 348 print()
349 349 for p in wrap_paragraphs(self.keyvalue_description):
350 350 print(p)
351 351 print()
352 352
353 353 for cls in help_classes:
354 354 cls.class_print_help()
355 355 print()
356 356 else:
357 357 print("To see all available configurables, use `--help-all`")
358 358 print()
359 359
360 360 self.print_examples()
361 361
362 362
363 363 def print_description(self):
364 364 """Print the application description."""
365 365 for p in wrap_paragraphs(self.description):
366 366 print(p)
367 367 print()
368 368
369 369 def print_examples(self):
370 370 """Print usage and examples.
371 371
372 372 This usage string goes at the end of the command line help string
373 373 and should contain examples of the application's usage.
374 374 """
375 375 if self.examples:
376 376 print("Examples")
377 377 print("--------")
378 378 print()
379 379 print(indent(dedent(self.examples.strip())))
380 380 print()
381 381
382 382 def print_version(self):
383 383 """Print the version string."""
384 384 print(self.version)
385 385
386 386 def update_config(self, config):
387 387 """Fire the traits events when the config is updated."""
388 388 # Save a copy of the current config.
389 389 newconfig = deepcopy(self.config)
390 390 # Merge the new config into the current one.
391 391 newconfig.merge(config)
392 392 # Save the combined config as self.config, which triggers the traits
393 393 # events.
394 394 self.config = newconfig
395 395
396 396 @catch_config_error
397 397 def initialize_subcommand(self, subc, argv=None):
398 398 """Initialize a subcommand with argv."""
399 399 subapp,help = self.subcommands.get(subc)
400 400
401 401 if isinstance(subapp, string_types):
402 402 subapp = import_item(subapp)
403 403
404 404 # clear existing instances
405 405 self.__class__.clear_instance()
406 406 # instantiate
407 407 self.subapp = subapp.instance(config=self.config)
408 408 # and initialize subapp
409 409 self.subapp.initialize(argv)
410 410
411 411 def flatten_flags(self):
412 412 """flatten flags and aliases, so cl-args override as expected.
413 413
414 414 This prevents issues such as an alias pointing to InteractiveShell,
415 415 but a config file setting the same trait in TerminalInteraciveShell
416 416 getting inappropriate priority over the command-line arg.
417 417
418 418         Only aliases with exactly one descendant in the class list
419 419 will be promoted.
420 420
421 421 """
422 422         # build a tree of classes in our list that inherit from a particular
423 423         # class; it will be a dict, keyed by parent classname, of classes in
424 424         # our list that are descendants
425 425 mro_tree = defaultdict(list)
426 426 for cls in self._help_classes:
427 427 clsname = cls.__name__
428 428 for parent in cls.mro()[1:-3]:
429 429 # exclude cls itself and Configurable,HasTraits,object
430 430 mro_tree[parent.__name__].append(clsname)
431 431 # flatten aliases, which have the form:
432 432 # { 'alias' : 'Class.trait' }
433 433 aliases = {}
434 434 for alias, cls_trait in iteritems(self.aliases):
435 435 cls,trait = cls_trait.split('.',1)
436 436 children = mro_tree[cls]
437 437 if len(children) == 1:
438 438                 # exactly one descendant, promote alias
439 439 cls = children[0]
440 440 aliases[alias] = '.'.join([cls,trait])
441 441
442 442 # flatten flags, which are of the form:
443 443 # { 'key' : ({'Cls' : {'trait' : value}}, 'help')}
444 444 flags = {}
445 445 for key, (flagdict, help) in iteritems(self.flags):
446 446 newflag = {}
447 447 for cls, subdict in iteritems(flagdict):
448 448 children = mro_tree[cls]
449 449                 # exactly one descendant, promote flag section
450 450 if len(children) == 1:
451 451 cls = children[0]
452 452 newflag[cls] = subdict
453 453 flags[key] = (newflag, help)
454 454 return flags, aliases
455 455
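A hedged sketch of the promotion this performs, with hypothetical classes: an alias written against a parent class is rewritten to the single subclass that actually appears in the class list.

```python
from IPython.config.application import Application
from IPython.config.configurable import Configurable
from IPython.utils.traitlets import Bool

class Shell(Configurable):
    autoindent = Bool(True, config=True)

class TerminalShell(Shell):
    pass

class PromoteApp(Application):
    classes = [TerminalShell]
    aliases = {'autoindent': 'Shell.autoindent'}

flags, aliases = PromoteApp().flatten_flags()
print(aliases)  # expected: {'autoindent': 'TerminalShell.autoindent'}
```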
456 456 @catch_config_error
457 457 def parse_command_line(self, argv=None):
458 458 """Parse the command line arguments."""
459 459 argv = sys.argv[1:] if argv is None else argv
460 460 self.argv = [ py3compat.cast_unicode(arg) for arg in argv ]
461 461
462 462 if argv and argv[0] == 'help':
463 463 # turn `ipython help notebook` into `ipython notebook -h`
464 464 argv = argv[1:] + ['-h']
465 465
466 466 if self.subcommands and len(argv) > 0:
467 467 # we have subcommands, and one may have been specified
468 468 subc, subargv = argv[0], argv[1:]
469 469 if re.match(r'^\w(\-?\w)*$', subc) and subc in self.subcommands:
470 470 # it's a subcommand, and *not* a flag or class parameter
471 471 return self.initialize_subcommand(subc, subargv)
472 472
473 473 # Arguments after a '--' argument are for the script IPython may be
474 474         # about to run, not IPython itself. For arguments parsed here (help and
475 475 # version), we want to only search the arguments up to the first
476 476 # occurrence of '--', which we're calling interpreted_argv.
477 477 try:
478 478 interpreted_argv = argv[:argv.index('--')]
479 479 except ValueError:
480 480 interpreted_argv = argv
481 481
482 482 if any(x in interpreted_argv for x in ('-h', '--help-all', '--help')):
483 483 self.print_help('--help-all' in interpreted_argv)
484 484 self.exit(0)
485 485
486 486 if '--version' in interpreted_argv or '-V' in interpreted_argv:
487 487 self.print_version()
488 488 self.exit(0)
489 489
490 490 # flatten flags&aliases, so cl-args get appropriate priority:
491 491 flags,aliases = self.flatten_flags()
492 492 loader = KVArgParseConfigLoader(argv=argv, aliases=aliases,
493 493 flags=flags, log=self.log)
494 494 config = loader.load_config()
495 495 self.update_config(config)
496 496 # store unparsed args in extra_args
497 497 self.extra_args = loader.extra_args
498 498
499 499 @classmethod
500 500 def _load_config_files(cls, basefilename, path=None, log=None):
501 501 """Load config files (py,json) by filename and path.
502 502
503 503 yield each config object in turn.
504 504 """
505 505
506 506 if not isinstance(path, list):
507 507 path = [path]
508 508 for path in path[::-1]:
509 509 # path list is in descending priority order, so load files backwards:
510 510 pyloader = PyFileConfigLoader(basefilename+'.py', path=path, log=log)
511 511 jsonloader = JSONFileConfigLoader(basefilename+'.json', path=path, log=log)
512 512 config = None
513 513 for loader in [pyloader, jsonloader]:
514 514 try:
515 515 config = loader.load_config()
516 516 except ConfigFileNotFound:
517 517 pass
518 518 except Exception:
519 519 # try to get the full filename, but it will be empty in the
520 520 # unlikely event that the error raised before filefind finished
521 521 filename = loader.full_filename or basefilename
522 522 # problem while running the file
523 523 if log:
524 524 log.error("Exception while loading config file %s",
525 525 filename, exc_info=True)
526 526 else:
527 527 if log:
528 528 log.debug("Loaded config file: %s", loader.full_filename)
529 529 if config:
530 530 yield config
531 531
532 532         return
533 533
534 534
535 535 @catch_config_error
536 536 def load_config_file(self, filename, path=None):
537 537 """Load config files by filename and path."""
538 538 filename, ext = os.path.splitext(filename)
539 539 loaded = []
540 540 for config in self._load_config_files(filename, path=path, log=self.log):
541 541 loaded.append(config)
542 542 self.update_config(config)
543 543 if len(loaded) > 1:
544 544 collisions = loaded[0].collisions(loaded[1])
545 545 if collisions:
546 546 self.log.warn("Collisions detected in {0}.py and {0}.json config files."
547 547 " {0}.json has higher priority: {1}".format(
548 548 filename, json.dumps(collisions, indent=2),
549 549 ))
550 550
551 551
552 552 def generate_config_file(self):
553 553 """generate default config file from Configurables"""
554 554 lines = ["# Configuration file for %s."%self.name]
555 555 lines.append('')
556 556 lines.append('c = get_config()')
557 557 lines.append('')
558 558 for cls in self._config_classes:
559 559 lines.append(cls.class_config_section())
560 560 return '\n'.join(lines)
561 561
562 562 def exit(self, exit_status=0):
563 563 self.log.debug("Exiting application: %s" % self.name)
564 564 sys.exit(exit_status)
565 565
566 566 @classmethod
567 567 def launch_instance(cls, argv=None, **kwargs):
568 568 """Launch a global instance of this Application
569 569
570 570 If a global instance already exists, this reinitializes and starts it
571 571 """
572 572 app = cls.instance(**kwargs)
573 573 app.initialize(argv)
574 574 app.start()
575 575
576 576 #-----------------------------------------------------------------------------
577 577 # utility functions, for convenience
578 578 #-----------------------------------------------------------------------------
579 579
580 580 def boolean_flag(name, configurable, set_help='', unset_help=''):
581 581 """Helper for building basic --trait, --no-trait flags.
582 582
583 583 Parameters
584 584 ----------
585 585
586 586 name : str
587 587 The name of the flag.
588 588 configurable : str
589 589 The 'Class.trait' string of the trait to be set/unset with the flag
590 590 set_help : unicode
591 591 help string for --name flag
592 592 unset_help : unicode
593 593 help string for --no-name flag
594 594
595 595 Returns
596 596 -------
597 597
598 598 cfg : dict
599 599 A dict with two keys: 'name', and 'no-name', for setting and unsetting
600 600 the trait, respectively.
601 601 """
602 602 # default helpstrings
603 603 set_help = set_help or "set %s=True"%configurable
604 604 unset_help = unset_help or "set %s=False"%configurable
605 605
606 606 cls,trait = configurable.split('.')
607 607
608 608 setter = {cls : {trait : True}}
609 609 unsetter = {cls : {trait : False}}
610 610 return {name : (setter, set_help), 'no-'+name : (unsetter, unset_help)}
611 611
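Typical usage (the trait here is illustrative; any `Class.trait` string works):

```python
flags = {}
flags.update(boolean_flag(
    'autoindent', 'InteractiveShell.autoindent',
    "Turn on autoindenting.",
    "Turn off autoindenting."
))
# flags['autoindent']    == ({'InteractiveShell': {'autoindent': True}},
#                            "Turn on autoindenting.")
# flags['no-autoindent'] == ({'InteractiveShell': {'autoindent': False}},
#                            "Turn off autoindenting.")
```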
612 612
613 613 def get_config():
614 614 """Get the config object for the global Application instance, if there is one
615 615
616 616 otherwise return an empty config object
617 617 """
618 618 if Application.initialized():
619 619 return Application.instance().config
620 620 else:
621 621 return Config()
IPython/core/formatters.py
@@ -1,965 +1,965 @@
1 1 # -*- coding: utf-8 -*-
2 2 """Display formatters.
3 3
4 4 Inheritance diagram:
5 5
6 6 .. inheritance-diagram:: IPython.core.formatters
7 7 :parts: 3
8 8 """
9 9
10 10 # Copyright (c) IPython Development Team.
11 11 # Distributed under the terms of the Modified BSD License.
12 12
13 13 import abc
14 14 import inspect
15 15 import json
16 16 import sys
17 17 import traceback
18 18 import warnings
19 19
20 from IPython.external.decorator import decorator
20 from decorator import decorator
21 21
22 22 from IPython.config.configurable import Configurable
23 23 from IPython.core.getipython import get_ipython
24 24 from IPython.lib import pretty
25 25 from IPython.utils.traitlets import (
26 26 Bool, Dict, Integer, Unicode, CUnicode, ObjectName, List,
27 27 ForwardDeclaredInstance,
28 28 )
29 29 from IPython.utils.py3compat import (
30 30 with_metaclass, string_types, unicode_type,
31 31 )
32 32
33 33
34 34 #-----------------------------------------------------------------------------
35 35 # The main DisplayFormatter class
36 36 #-----------------------------------------------------------------------------
37 37
38 38
39 39 def _safe_get_formatter_method(obj, name):
40 40 """Safely get a formatter method
41 41
42 42     - Classes cannot have formatter methods, only instances
43 43 - protect against proxy objects that claim to have everything
44 44 """
45 45 if inspect.isclass(obj):
46 46 # repr methods only make sense on instances, not classes
47 47 return None
48 48 method = pretty._safe_getattr(obj, name, None)
49 49 if callable(method):
50 50 # obj claims to have repr method...
51 51 if callable(pretty._safe_getattr(obj, '_ipython_canary_method_should_not_exist_', None)):
52 52 # ...but don't trust proxy objects that claim to have everything
53 53 return None
54 54 return method
55 55
56 56
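A sketch of the proxy guard with a hypothetical catch-all object:

```python
class Proxy(object):
    """Claims to have every attribute, including the canary."""
    def __getattr__(self, name):
        return lambda: '<proxied>'

_safe_get_formatter_method(Proxy(), '_repr_html_')  # -> None (canary fired)
_safe_get_formatter_method(Proxy, '_repr_html_')    # -> None (classes rejected)
```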
57 57 class DisplayFormatter(Configurable):
58 58
59 59 # When set to true only the default plain text formatter will be used.
60 60 plain_text_only = Bool(False, config=True)
61 61 def _plain_text_only_changed(self, name, old, new):
62 62 warnings.warn("""DisplayFormatter.plain_text_only is deprecated.
63 63
64 64 Use DisplayFormatter.active_types = ['text/plain']
65 65 for the same effect.
66 66 """, DeprecationWarning)
67 67 if new:
68 68 self.active_types = ['text/plain']
69 69 else:
70 70 self.active_types = self.format_types
71 71
72 72 active_types = List(Unicode, config=True,
73 73 help="""List of currently active mime-types to display.
74 74 You can use this to set a white-list for formats to display.
75 75
76 76 Most users will not need to change this value.
77 77 """)
78 78 def _active_types_default(self):
79 79 return self.format_types
80 80
81 81 def _active_types_changed(self, name, old, new):
82 82 for key, formatter in self.formatters.items():
83 83 if key in new:
84 84 formatter.enabled = True
85 85 else:
86 86 formatter.enabled = False
87 87
88 88 ipython_display_formatter = ForwardDeclaredInstance('FormatterABC')
89 89 def _ipython_display_formatter_default(self):
90 90 return IPythonDisplayFormatter(parent=self)
91 91
92 92 # A dict of formatter whose keys are format types (MIME types) and whose
93 93 # values are subclasses of BaseFormatter.
94 94 formatters = Dict()
95 95 def _formatters_default(self):
96 96 """Activate the default formatters."""
97 97 formatter_classes = [
98 98 PlainTextFormatter,
99 99 HTMLFormatter,
100 100 MarkdownFormatter,
101 101 SVGFormatter,
102 102 PNGFormatter,
103 103 PDFFormatter,
104 104 JPEGFormatter,
105 105 LatexFormatter,
106 106 JSONFormatter,
107 107 JavascriptFormatter
108 108 ]
109 109 d = {}
110 110 for cls in formatter_classes:
111 111 f = cls(parent=self)
112 112 d[f.format_type] = f
113 113 return d
114 114
115 115 def format(self, obj, include=None, exclude=None):
116 116 """Return a format data dict for an object.
117 117
118 118 By default all format types will be computed.
119 119
120 120 The following MIME types are currently implemented:
121 121
122 122 * text/plain
123 123 * text/html
124 124 * text/markdown
125 125 * text/latex
126 126 * application/json
127 127 * application/javascript
128 128 * application/pdf
129 129 * image/png
130 130 * image/jpeg
131 131 * image/svg+xml
132 132
133 133 Parameters
134 134 ----------
135 135 obj : object
136 136 The Python object whose format data will be computed.
137 137 include : list or tuple, optional
138 138 A list of format type strings (MIME types) to include in the
139 139 format data dict. If this is set *only* the format types included
140 140 in this list will be computed.
141 141 exclude : list or tuple, optional
142 142             A list of format type strings (MIME types) to exclude from the format
143 143 data dict. If this is set all format types will be computed,
144 144 except for those included in this argument.
145 145
146 146 Returns
147 147 -------
148 148 (format_dict, metadata_dict) : tuple of two dicts
149 149
150 150         format_dict is a dictionary of key/value pairs, one for each format that
151 151         was generated for the object. The keys are the format types, which
152 152         will usually be MIME type strings, and the values are JSON'able
153 153         data structures containing the raw data for the representation in
154 154         that format.
155 155
156 156 metadata_dict is a dictionary of metadata about each mime-type output.
157 157 Its keys will be a strict subset of the keys in format_dict.
158 158 """
159 159 format_dict = {}
160 160 md_dict = {}
161 161
162 162 if self.ipython_display_formatter(obj):
163 163 # object handled itself, don't proceed
164 164 return {}, {}
165 165
166 166 for format_type, formatter in self.formatters.items():
167 167 if include and format_type not in include:
168 168 continue
169 169 if exclude and format_type in exclude:
170 170 continue
171 171
172 172 md = None
173 173 try:
174 174 data = formatter(obj)
175 175 except:
176 176 # FIXME: log the exception
177 177 raise
178 178
179 179 # formatters can return raw data or (data, metadata)
180 180 if isinstance(data, tuple) and len(data) == 2:
181 181 data, md = data
182 182
183 183 if data is not None:
184 184 format_dict[format_type] = data
185 185 if md is not None:
186 186 md_dict[format_type] = md
187 187
188 188 return format_dict, md_dict
189 189
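A sketch of a call, assuming a running IPython shell:

```python
from IPython.core.getipython import get_ipython

fmt = get_ipython().display_formatter
data, metadata = fmt.format(1.5, include=['text/plain'])
# data == {'text/plain': '1.5'}; metadata == {} here
```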
190 190 @property
191 191 def format_types(self):
192 192 """Return the format types (MIME types) of the active formatters."""
193 193 return list(self.formatters.keys())
194 194
195 195
196 196 #-----------------------------------------------------------------------------
197 197 # Formatters for specific format types (text, html, svg, etc.)
198 198 #-----------------------------------------------------------------------------
199 199
200 200
201 201 def _safe_repr(obj):
202 202 """Try to return a repr of an object
203 203
204 204 always returns a string, at least.
205 205 """
206 206 try:
207 207 return repr(obj)
208 208 except Exception as e:
209 209 return "un-repr-able object (%r)" % e
210 210
211 211
212 212 class FormatterWarning(UserWarning):
213 213 """Warning class for errors in formatters"""
214 214
215 215 @decorator
216 216 def catch_format_error(method, self, *args, **kwargs):
217 217 """show traceback on failed format call"""
218 218 try:
219 219 r = method(self, *args, **kwargs)
220 220 except NotImplementedError:
221 221 # don't warn on NotImplementedErrors
222 222 return None
223 223 except Exception:
224 224 exc_info = sys.exc_info()
225 225 ip = get_ipython()
226 226 if ip is not None:
227 227 ip.showtraceback(exc_info)
228 228 else:
229 229 traceback.print_exception(*exc_info)
230 230 return None
231 231 return self._check_return(r, args[0])
232 232
233 233
234 234 class FormatterABC(with_metaclass(abc.ABCMeta, object)):
235 235 """ Abstract base class for Formatters.
236 236
237 237 A formatter is a callable class that is responsible for computing the
238 238 raw format data for a particular format type (MIME type). For example,
239 239 an HTML formatter would have a format type of `text/html` and would return
240 240 the HTML representation of the object when called.
241 241 """
242 242
243 243 # The format type of the data returned, usually a MIME type.
244 244 format_type = 'text/plain'
245 245
246 246 # Is the formatter enabled...
247 247 enabled = True
248 248
249 249 @abc.abstractmethod
250 250 def __call__(self, obj):
251 251 """Return a JSON'able representation of the object.
252 252
253 253 If the object cannot be formatted by this formatter,
254 254 warn and return None.
255 255 """
256 256 return repr(obj)
257 257
258 258
259 259 def _mod_name_key(typ):
260 260 """Return a (__module__, __name__) tuple for a type.
261 261
262 262 Used as key in Formatter.deferred_printers.
263 263 """
264 264 module = getattr(typ, '__module__', None)
265 265 name = getattr(typ, '__name__', None)
266 266 return (module, name)
267 267
268 268
269 269 def _get_type(obj):
270 270 """Return the type of an instance (old and new-style)"""
271 271 return getattr(obj, '__class__', None) or type(obj)
272 272
273 273 _raise_key_error = object()
274 274
275 275
276 276 class BaseFormatter(Configurable):
277 277 """A base formatter class that is configurable.
278 278
279 279 This formatter should usually be used as the base class of all formatters.
280 280 It is a traited :class:`Configurable` class and includes an extensible
281 281 API for users to determine how their objects are formatted. The following
282 282     logic is used to find a function to format a given object.
283 283
284 284 1. The object is introspected to see if it has a method with the name
285 285        :attr:`print_method`. If it does, the object is passed to that method
286 286 for formatting.
287 287 2. If no print method is found, three internal dictionaries are consulted
288 288        to find a print method: :attr:`singleton_printers`, :attr:`type_printers`
289 289 and :attr:`deferred_printers`.
290 290
291 291 Users should use these dictionaries to register functions that will be
292 292 used to compute the format data for their objects (if those objects don't
293 293 have the special print methods). The easiest way of using these
294 294 dictionaries is through the :meth:`for_type` and :meth:`for_type_by_name`
295 295 methods.
296 296
297 297 If no function/callable is found to compute the format data, ``None`` is
298 298 returned and this format type is not used.
299 299 """
300 300
301 301 format_type = Unicode('text/plain')
302 302 _return_type = string_types
303 303
304 304 enabled = Bool(True, config=True)
305 305
306 306 print_method = ObjectName('__repr__')
307 307
308 308 # The singleton printers.
309 309 # Maps the IDs of the builtin singleton objects to the format functions.
310 310 singleton_printers = Dict(config=True)
311 311
312 312 # The type-specific printers.
313 313 # Map type objects to the format functions.
314 314 type_printers = Dict(config=True)
315 315
316 316 # The deferred-import type-specific printers.
317 317 # Map (modulename, classname) pairs to the format functions.
318 318 deferred_printers = Dict(config=True)
319 319
320 320 @catch_format_error
321 321 def __call__(self, obj):
322 322 """Compute the format for an object."""
323 323 if self.enabled:
324 324 # lookup registered printer
325 325 try:
326 326 printer = self.lookup(obj)
327 327 except KeyError:
328 328 pass
329 329 else:
330 330 return printer(obj)
331 331 # Finally look for special method names
332 332 method = _safe_get_formatter_method(obj, self.print_method)
333 333 if method is not None:
334 334 return method()
335 335 return None
336 336 else:
337 337 return None
338 338
339 339 def __contains__(self, typ):
340 340 """map in to lookup_by_type"""
341 341 try:
342 342 self.lookup_by_type(typ)
343 343 except KeyError:
344 344 return False
345 345 else:
346 346 return True
347 347
348 348 def _check_return(self, r, obj):
349 349 """Check that a return value is appropriate
350 350
351 351 Return the value if so, None otherwise, warning if invalid.
352 352 """
353 353 if r is None or isinstance(r, self._return_type) or \
354 354 (isinstance(r, tuple) and r and isinstance(r[0], self._return_type)):
355 355 return r
356 356 else:
357 357 warnings.warn(
358 358 "%s formatter returned invalid type %s (expected %s) for object: %s" % \
359 359 (self.format_type, type(r), self._return_type, _safe_repr(obj)),
360 360 FormatterWarning
361 361 )
362 362
363 363 def lookup(self, obj):
364 364 """Look up the formatter for a given instance.
365 365
366 366 Parameters
367 367 ----------
368 368 obj : object instance
369 369
370 370 Returns
371 371 -------
372 372 f : callable
373 373 The registered formatting callable for the type.
374 374
375 375 Raises
376 376 ------
377 377 KeyError if the type has not been registered.
378 378 """
379 379 # look for singleton first
380 380 obj_id = id(obj)
381 381 if obj_id in self.singleton_printers:
382 382 return self.singleton_printers[obj_id]
383 383 # then lookup by type
384 384 return self.lookup_by_type(_get_type(obj))
385 385
386 386 def lookup_by_type(self, typ):
387 387 """Look up the registered formatter for a type.
388 388
389 389 Parameters
390 390 ----------
391 391 typ : type or '__module__.__name__' string for a type
392 392
393 393 Returns
394 394 -------
395 395 f : callable
396 396 The registered formatting callable for the type.
397 397
398 398 Raises
399 399 ------
400 400 KeyError if the type has not been registered.
401 401 """
402 402 if isinstance(typ, string_types):
403 403 typ_key = tuple(typ.rsplit('.',1))
404 404 if typ_key not in self.deferred_printers:
405 405 # We may have it cached in the type map. We will have to
406 406 # iterate over all of the types to check.
407 407 for cls in self.type_printers:
408 408 if _mod_name_key(cls) == typ_key:
409 409 return self.type_printers[cls]
410 410 else:
411 411 return self.deferred_printers[typ_key]
412 412 else:
413 413 for cls in pretty._get_mro(typ):
414 414 if cls in self.type_printers or self._in_deferred_types(cls):
415 415 return self.type_printers[cls]
416 416
417 417 # If we have reached here, the lookup failed.
418 418 raise KeyError("No registered printer for {0!r}".format(typ))
419 419
420 420 def for_type(self, typ, func=None):
421 421 """Add a format function for a given type.
422 422
423 423 Parameters
424 424         ----------
425 425 typ : type or '__module__.__name__' string for a type
426 426 The class of the object that will be formatted using `func`.
427 427 func : callable
428 428 A callable for computing the format data.
429 429 `func` will be called with the object to be formatted,
430 430 and will return the raw data in this formatter's format.
431 431 Subclasses may use a different call signature for the
432 432 `func` argument.
433 433
434 434 If `func` is None or not specified, there will be no change,
435 435 only returning the current value.
436 436
437 437 Returns
438 438 -------
439 439 oldfunc : callable
440 440 The currently registered callable.
441 441 If you are registering a new formatter,
442 442 this will be the previous value (to enable restoring later).
443 443 """
444 444 # if string given, interpret as 'pkg.module.class_name'
445 445 if isinstance(typ, string_types):
446 446 type_module, type_name = typ.rsplit('.', 1)
447 447 return self.for_type_by_name(type_module, type_name, func)
448 448
449 449 try:
450 450 oldfunc = self.lookup_by_type(typ)
451 451 except KeyError:
452 452 oldfunc = None
453 453
454 454 if func is not None:
455 455 self.type_printers[typ] = func
456 456
457 457 return oldfunc
458 458
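For example, registering (and later removing) a plain-text printer for a hypothetical class; note that `PlainTextFormatter` uses the pretty-printer call signature rather than a plain one-argument callable:

```python
class Vec(object):
    def __init__(self, x, y):
        self.x, self.y = x, y

plain = get_ipython().display_formatter.formatters['text/plain']
plain.for_type(Vec, lambda obj, p, cycle: p.text('Vec(%r, %r)' % (obj.x, obj.y)))
Vec(1, 2)        # now displays as: Vec(1, 2)
plain.pop(Vec)   # unregister again
```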
459 459 def for_type_by_name(self, type_module, type_name, func=None):
460 460 """Add a format function for a type specified by the full dotted
461 461 module and name of the type, rather than the type of the object.
462 462
463 463 Parameters
464 464 ----------
465 465 type_module : str
466 466 The full dotted name of the module the type is defined in, like
467 467 ``numpy``.
468 468 type_name : str
469 469 The name of the type (the class name), like ``dtype``
470 470 func : callable
471 471 A callable for computing the format data.
472 472 `func` will be called with the object to be formatted,
473 473 and will return the raw data in this formatter's format.
474 474 Subclasses may use a different call signature for the
475 475 `func` argument.
476 476
477 477 If `func` is None or unspecified, there will be no change,
478 478 only returning the current value.
479 479
480 480 Returns
481 481 -------
482 482 oldfunc : callable
483 483 The currently registered callable.
484 484 If you are registering a new formatter,
485 485 this will be the previous value (to enable restoring later).
486 486 """
487 487 key = (type_module, type_name)
488 488
489 489 try:
490 490 oldfunc = self.lookup_by_type("%s.%s" % key)
491 491 except KeyError:
492 492 oldfunc = None
493 493
494 494 if func is not None:
495 495 self.deferred_printers[key] = func
496 496 return oldfunc
497 497
498 498 def pop(self, typ, default=_raise_key_error):
499 499 """Pop a formatter for the given type.
500 500
501 501 Parameters
502 502 ----------
503 503 typ : type or '__module__.__name__' string for a type
504 504 default : object
505 505 value to be returned if no formatter is registered for typ.
506 506
507 507 Returns
508 508 -------
509 509 obj : object
510 510 The last registered object for the type.
511 511
512 512 Raises
513 513 ------
514 514 KeyError if the type is not registered and default is not specified.
515 515 """
516 516
517 517 if isinstance(typ, string_types):
518 518 typ_key = tuple(typ.rsplit('.',1))
519 519 if typ_key not in self.deferred_printers:
520 520 # We may have it cached in the type map. We will have to
521 521 # iterate over all of the types to check.
522 522 for cls in self.type_printers:
523 523 if _mod_name_key(cls) == typ_key:
524 524 old = self.type_printers.pop(cls)
525 525 break
526 526 else:
527 527 old = default
528 528 else:
529 529 old = self.deferred_printers.pop(typ_key)
530 530 else:
531 531 if typ in self.type_printers:
532 532 old = self.type_printers.pop(typ)
533 533 else:
534 534 old = self.deferred_printers.pop(_mod_name_key(typ), default)
535 535 if old is _raise_key_error:
536 536 raise KeyError("No registered value for {0!r}".format(typ))
537 537 return old
538 538
539 539 def _in_deferred_types(self, cls):
540 540 """
541 541 Check if the given class is specified in the deferred type registry.
542 542
543 543 Successful matches will be moved to the regular type registry for future use.
544 544 """
545 545 mod = getattr(cls, '__module__', None)
546 546 name = getattr(cls, '__name__', None)
547 547 key = (mod, name)
548 548 if key in self.deferred_printers:
549 549 # Move the printer over to the regular registry.
550 550 printer = self.deferred_printers.pop(key)
551 551 self.type_printers[cls] = printer
552 552 return True
553 553 return False
554 554
555 555
556 556 class PlainTextFormatter(BaseFormatter):
557 557 """The default pretty-printer.
558 558
559 559 This uses :mod:`IPython.lib.pretty` to compute the format data of
560 560 the object. If the object cannot be pretty printed, :func:`repr` is used.
561 561 See the documentation of :mod:`IPython.lib.pretty` for details on
562 562 how to write pretty printers. Here is a simple example::
563 563
564 564 def dtype_pprinter(obj, p, cycle):
565 565 if cycle:
566 566 return p.text('dtype(...)')
567 567 if hasattr(obj, 'fields'):
568 568 if obj.fields is None:
569 569 p.text(repr(obj))
570 570 else:
571 571 p.begin_group(7, 'dtype([')
572 572 for i, field in enumerate(obj.descr):
573 573 if i > 0:
574 574 p.text(',')
575 575 p.breakable()
576 576 p.pretty(field)
577 577 p.end_group(7, '])')
578 578 """
579 579
580 580 # The format type of data returned.
581 581 format_type = Unicode('text/plain')
582 582
583 583     # This subclass ignores this attribute as it always needs to return
584 584 # something.
585 585 enabled = Bool(True, config=False)
586 586
587 587 max_seq_length = Integer(pretty.MAX_SEQ_LENGTH, config=True,
588 588 help="""Truncate large collections (lists, dicts, tuples, sets) to this size.
589 589
590 590 Set to 0 to disable truncation.
591 591 """
592 592 )
593 593
594 594 # Look for a _repr_pretty_ methods to use for pretty printing.
595 595 print_method = ObjectName('_repr_pretty_')
596 596
597 597 # Whether to pretty-print or not.
598 598 pprint = Bool(True, config=True)
599 599
600 600 # Whether to be verbose or not.
601 601 verbose = Bool(False, config=True)
602 602
603 603 # The maximum width.
604 604 max_width = Integer(79, config=True)
605 605
606 606 # The newline character.
607 607 newline = Unicode('\n', config=True)
608 608
609 609 # format-string for pprinting floats
610 610 float_format = Unicode('%r')
611 611 # setter for float precision, either int or direct format-string
612 612 float_precision = CUnicode('', config=True)
613 613
614 614 def _float_precision_changed(self, name, old, new):
615 615 """float_precision changed, set float_format accordingly.
616 616
617 617 float_precision can be set by int or str.
618 618 This will set float_format, after interpreting input.
619 619 If numpy has been imported, numpy print precision will also be set.
620 620
621 621         integer `n` sets the format to '%.nf'; otherwise the format string is used directly.
622 622
623 623 An empty string returns to defaults (repr for float, 8 for numpy).
624 624
625 625 This parameter can be set via the '%precision' magic.
626 626 """
627 627
628 628 if '%' in new:
629 629 # got explicit format string
630 630 fmt = new
631 631 try:
632 632 fmt%3.14159
633 633 except Exception:
634 634 raise ValueError("Precision must be int or format string, not %r"%new)
635 635 elif new:
636 636 # otherwise, should be an int
637 637 try:
638 638 i = int(new)
639 639 assert i >= 0
640 640 except ValueError:
641 641 raise ValueError("Precision must be int or format string, not %r"%new)
642 642 except AssertionError:
643 643 raise ValueError("int precision must be non-negative, not %r"%i)
644 644
645 645 fmt = '%%.%if'%i
646 646 if 'numpy' in sys.modules:
647 647 # set numpy precision if it has been imported
648 648 import numpy
649 649 numpy.set_printoptions(precision=i)
650 650 else:
651 651 # default back to repr
652 652 fmt = '%r'
653 653 if 'numpy' in sys.modules:
654 654 import numpy
655 655 # numpy default is 8
656 656 numpy.set_printoptions(precision=8)
657 657 self.float_format = fmt
658 658
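A sketch of the three accepted forms (assumes `plain` is the active `PlainTextFormatter`):

```python
plain = get_ipython().display_formatter.formatters['text/plain']
plain.float_precision = 3        # int: float_format becomes '%.3f'
plain.float_precision = '%.2e'   # explicit format string, used as-is
plain.float_precision = ''       # empty: back to the default '%r'
```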
659 659 # Use the default pretty printers from IPython.lib.pretty.
660 660 def _singleton_printers_default(self):
661 661 return pretty._singleton_pprinters.copy()
662 662
663 663 def _type_printers_default(self):
664 664 d = pretty._type_pprinters.copy()
665 665 d[float] = lambda obj,p,cycle: p.text(self.float_format%obj)
666 666 return d
667 667
668 668 def _deferred_printers_default(self):
669 669 return pretty._deferred_type_pprinters.copy()
670 670
671 671 #### FormatterABC interface ####
672 672
673 673 @catch_format_error
674 674 def __call__(self, obj):
675 675 """Compute the pretty representation of the object."""
676 676 if not self.pprint:
677 677 return repr(obj)
678 678 else:
679 679 # handle str and unicode on Python 2
680 680 # io.StringIO only accepts unicode,
681 681 # cStringIO doesn't handle unicode on py2,
682 682 # StringIO allows str, unicode but only ascii str
683 683 stream = pretty.CUnicodeIO()
684 684 printer = pretty.RepresentationPrinter(stream, self.verbose,
685 685 self.max_width, self.newline,
686 686 max_seq_length=self.max_seq_length,
687 687 singleton_pprinters=self.singleton_printers,
688 688 type_pprinters=self.type_printers,
689 689 deferred_pprinters=self.deferred_printers)
690 690 printer.pretty(obj)
691 691 printer.flush()
692 692 return stream.getvalue()
693 693
694 694
695 695 class HTMLFormatter(BaseFormatter):
696 696 """An HTML formatter.
697 697
698 698 To define the callables that compute the HTML representation of your
699 699 objects, define a :meth:`_repr_html_` method or use the :meth:`for_type`
700 700 or :meth:`for_type_by_name` methods to register functions that handle
701 701 this.
702 702
703 703 The return value of this formatter should be a valid HTML snippet that
704 704 could be injected into an existing DOM. It should *not* include the
705 705     ``<html>`` or ``<body>`` tags.
706 706 """
707 707 format_type = Unicode('text/html')
708 708
709 709 print_method = ObjectName('_repr_html_')
710 710
711 711
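A minimal sketch of the `_repr_html_` route with a hypothetical class:

```python
class ColorSwatch(object):
    def __init__(self, css_color):
        self.css_color = css_color

    def _repr_html_(self):
        # the returned snippet is injected into the DOM as-is
        return ('<div style="width:2em;height:2em;background:%s"></div>'
                % self.css_color)

ColorSwatch('purple')  # renders as a colored square in the notebook
```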
712 712 class MarkdownFormatter(BaseFormatter):
713 713 """A Markdown formatter.
714 714
715 715 To define the callables that compute the Markdown representation of your
716 716 objects, define a :meth:`_repr_markdown_` method or use the :meth:`for_type`
717 717 or :meth:`for_type_by_name` methods to register functions that handle
718 718 this.
719 719
720 720     The return value of this formatter should be valid Markdown.
721 721 """
722 722 format_type = Unicode('text/markdown')
723 723
724 724 print_method = ObjectName('_repr_markdown_')
725 725
726 726 class SVGFormatter(BaseFormatter):
727 727 """An SVG formatter.
728 728
729 729 To define the callables that compute the SVG representation of your
730 730 objects, define a :meth:`_repr_svg_` method or use the :meth:`for_type`
731 731 or :meth:`for_type_by_name` methods to register functions that handle
732 732 this.
733 733
734 734     The return value of this formatter should be valid SVG enclosed in
735 735     ``<svg>`` tags that could be injected into an existing DOM. It should
736 736     *not* include the ``<html>`` or ``<body>`` tags.
737 737 """
738 738 format_type = Unicode('image/svg+xml')
739 739
740 740 print_method = ObjectName('_repr_svg_')
741 741
742 742
743 743 class PNGFormatter(BaseFormatter):
744 744 """A PNG formatter.
745 745
746 746 To define the callables that compute the PNG representation of your
747 747 objects, define a :meth:`_repr_png_` method or use the :meth:`for_type`
748 748 or :meth:`for_type_by_name` methods to register functions that handle
749 749 this.
750 750
751 751 The return value of this formatter should be raw PNG data, *not*
752 752 base64 encoded.
753 753 """
754 754 format_type = Unicode('image/png')
755 755
756 756 print_method = ObjectName('_repr_png_')
757 757
758 758 _return_type = (bytes, unicode_type)
759 759
760 760
761 761 class JPEGFormatter(BaseFormatter):
762 762 """A JPEG formatter.
763 763
764 764 To define the callables that compute the JPEG representation of your
765 765 objects, define a :meth:`_repr_jpeg_` method or use the :meth:`for_type`
766 766 or :meth:`for_type_by_name` methods to register functions that handle
767 767 this.
768 768
769 769 The return value of this formatter should be raw JPEG data, *not*
770 770 base64 encoded.
771 771 """
772 772 format_type = Unicode('image/jpeg')
773 773
774 774 print_method = ObjectName('_repr_jpeg_')
775 775
776 776 _return_type = (bytes, unicode_type)
777 777
778 778
779 779 class LatexFormatter(BaseFormatter):
780 780 """A LaTeX formatter.
781 781
782 782 To define the callables that compute the LaTeX representation of your
783 783 objects, define a :meth:`_repr_latex_` method or use the :meth:`for_type`
784 784 or :meth:`for_type_by_name` methods to register functions that handle
785 785 this.
786 786
787 787 The return value of this formatter should be a valid LaTeX equation,
788 788     enclosed in either ``$``, ``$$``, or another LaTeX equation
789 789 environment.
790 790 """
791 791 format_type = Unicode('text/latex')
792 792
793 793 print_method = ObjectName('_repr_latex_')
794 794
795 795
796 796 class JSONFormatter(BaseFormatter):
797 797 """A JSON string formatter.
798 798
799 799 To define the callables that compute the JSONable representation of
800 800 your objects, define a :meth:`_repr_json_` method or use the :meth:`for_type`
801 801 or :meth:`for_type_by_name` methods to register functions that handle
802 802 this.
803 803
804 804 The return value of this formatter should be a JSONable list or dict.
805 805 JSON scalars (None, number, string) are not allowed, only dict or list containers.
806 806 """
807 807 format_type = Unicode('application/json')
808 808 _return_type = (list, dict)
809 809
810 810 print_method = ObjectName('_repr_json_')
811 811
812 812 def _check_return(self, r, obj):
813 813 """Check that a return value is appropriate
814 814
815 815 Return the value if so, None otherwise, warning if invalid.
816 816 """
817 817 if r is None:
818 818 return
819 819 md = None
820 820 if isinstance(r, tuple):
821 821 # unpack data, metadata tuple for type checking on first element
822 822 r, md = r
823 823
824 824 # handle deprecated JSON-as-string form from IPython < 3
825 825 if isinstance(r, string_types):
826 826 warnings.warn("JSON expects JSONable list/dict containers, not JSON strings",
827 827 FormatterWarning)
828 828 r = json.loads(r)
829 829
830 830 if md is not None:
831 831 # put the tuple back together
832 832 r = (r, md)
833 833 return super(JSONFormatter, self)._check_return(r, obj)
834 834
835 835
836 836 class JavascriptFormatter(BaseFormatter):
837 837 """A Javascript formatter.
838 838
839 839 To define the callables that compute the Javascript representation of
840 840 your objects, define a :meth:`_repr_javascript_` method or use the
841 841 :meth:`for_type` or :meth:`for_type_by_name` methods to register functions
842 842 that handle this.
843 843
844 844 The return value of this formatter should be valid Javascript code and
845 845     should *not* be enclosed in ``<script>`` tags.
846 846 """
847 847 format_type = Unicode('application/javascript')
848 848
849 849 print_method = ObjectName('_repr_javascript_')
850 850
851 851
852 852 class PDFFormatter(BaseFormatter):
853 853 """A PDF formatter.
854 854
855 855 To define the callables that compute the PDF representation of your
856 856 objects, define a :meth:`_repr_pdf_` method or use the :meth:`for_type`
857 857 or :meth:`for_type_by_name` methods to register functions that handle
858 858 this.
859 859
860 860 The return value of this formatter should be raw PDF data, *not*
861 861 base64 encoded.
862 862 """
863 863 format_type = Unicode('application/pdf')
864 864
865 865 print_method = ObjectName('_repr_pdf_')
866 866
867 867 _return_type = (bytes, unicode_type)
868 868
869 869 class IPythonDisplayFormatter(BaseFormatter):
870 870 """A Formatter for objects that know how to display themselves.
871 871
872 872 To define the callables that compute the representation of your
873 873 objects, define a :meth:`_ipython_display_` method or use the :meth:`for_type`
874 874 or :meth:`for_type_by_name` methods to register functions that handle
875 875 this. Unlike mime-type displays, this method should not return anything,
876 876 instead calling any appropriate display methods itself.
877 877
878 878 This display formatter has highest priority.
879 879 If it fires, no other display formatter will be called.
880 880 """
881 881 print_method = ObjectName('_ipython_display_')
882 882 _return_type = (type(None), bool)
883 883
884 884
885 885 @catch_format_error
886 886 def __call__(self, obj):
887 887 """Compute the format for an object."""
888 888 if self.enabled:
889 889 # lookup registered printer
890 890 try:
891 891 printer = self.lookup(obj)
892 892 except KeyError:
893 893 pass
894 894 else:
895 895 printer(obj)
896 896 return True
897 897 # Finally look for special method names
898 898 method = _safe_get_formatter_method(obj, self.print_method)
899 899 if method is not None:
900 900 method()
901 901 return True
902 902
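A sketch of an object using this hook (hypothetical class; note that it calls display functions rather than returning data):

```python
from IPython.display import display, HTML

class SelfDisplaying(object):
    def _ipython_display_(self):
        # display whatever we like; no mime bundle is returned,
        # so no other formatter will run for this object
        display(HTML('<b>I display myself</b>'))
```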
903 903
904 904 FormatterABC.register(BaseFormatter)
905 905 FormatterABC.register(PlainTextFormatter)
906 906 FormatterABC.register(HTMLFormatter)
907 907 FormatterABC.register(MarkdownFormatter)
908 908 FormatterABC.register(SVGFormatter)
909 909 FormatterABC.register(PNGFormatter)
910 910 FormatterABC.register(PDFFormatter)
911 911 FormatterABC.register(JPEGFormatter)
912 912 FormatterABC.register(LatexFormatter)
913 913 FormatterABC.register(JSONFormatter)
914 914 FormatterABC.register(JavascriptFormatter)
915 915 FormatterABC.register(IPythonDisplayFormatter)
916 916
917 917
918 918 def format_display_data(obj, include=None, exclude=None):
919 919 """Return a format data dict for an object.
920 920
921 921 By default all format types will be computed.
922 922
923 923 The following MIME types are currently implemented:
924 924
925 925 * text/plain
926 926 * text/html
927 927 * text/markdown
928 928 * text/latex
929 929 * application/json
930 930 * application/javascript
931 931 * application/pdf
932 932 * image/png
933 933 * image/jpeg
934 934 * image/svg+xml
935 935
936 936     Parameters
937 937     ----------
938 938     obj : object
939 939         The Python object whose format data will be computed.
940 940     include : list or tuple, optional
941 941         A list of format type strings (MIME types) to include in the
942 942         format data dict. If this is set, *only* the format types included
943 943         in this list will be computed.
944 944     exclude : list or tuple, optional
945 945         A list of format type strings (MIME types) to exclude from the
946 946         format data dict. If this is set, all format types will be computed,
947 947         except for those included in this argument.
948 948
949 949     Returns
950 950     -------
951 951     format_dict : dict
952 952         A dictionary of key/value pairs, one for each format that was
953 953         generated for the object. The keys are the format types, which
954 954         will usually be MIME type strings, and the values are JSON'able
955 955         data structures containing the raw data for the representation in
956 956         that format.
957 957     """
958 958 from IPython.core.interactiveshell import InteractiveShell
959 959
960 960 return InteractiveShell.instance().display_formatter.format(
961 961 obj,
962 962 include,
963 963 exclude
964 964 )
965 965
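A usage sketch, assuming a running InteractiveShell instance (and the `return` added above, so the helper hands back whatever `DisplayFormatter.format` produces)::

    from IPython.core.formatters import format_display_data

    # Compute only the plain-text representation of an object:
    result = format_display_data(range(3), include=['text/plain'])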
@@ -1,870 +1,870 b''
1 1 """ History related magics and functionality """
2 2 #-----------------------------------------------------------------------------
3 3 # Copyright (C) 2010-2011 The IPython Development Team.
4 4 #
5 5 # Distributed under the terms of the BSD License.
6 6 #
7 7 # The full license is in the file COPYING.txt, distributed with this software.
8 8 #-----------------------------------------------------------------------------
9 9
10 10 #-----------------------------------------------------------------------------
11 11 # Imports
12 12 #-----------------------------------------------------------------------------
13 13 from __future__ import print_function
14 14
15 15 # Stdlib imports
16 16 import atexit
17 17 import datetime
18 18 import os
19 19 import re
20 20 try:
21 21 import sqlite3
22 22 except ImportError:
23 23 try:
24 24 from pysqlite2 import dbapi2 as sqlite3
25 25 except ImportError:
26 26 sqlite3 = None
27 27 import threading
28 28
29 29 # Our own packages
30 30 from IPython.config.configurable import Configurable
31 from IPython.external.decorator import decorator
31 from decorator import decorator
32 32 from IPython.utils.decorators import undoc
33 33 from IPython.utils.path import locate_profile
34 34 from IPython.utils import py3compat
35 35 from IPython.utils.traitlets import (
36 36 Any, Bool, Dict, Instance, Integer, List, Unicode, TraitError,
37 37 )
38 38 from IPython.utils.warn import warn
39 39
40 40 #-----------------------------------------------------------------------------
41 41 # Classes and functions
42 42 #-----------------------------------------------------------------------------
43 43
44 44 @undoc
45 45 class DummyDB(object):
46 46 """Dummy DB that will act as a black hole for history.
47 47
48 48 Only used in the absence of sqlite"""
49 49 def execute(self, *args, **kwargs):
50 50 return []
51 51
52 52 def commit(self, *args, **kwargs):
53 53 pass
54 54
55 55 def __enter__(self, *args, **kwargs):
56 56 pass
57 57
58 58 def __exit__(self, *args, **kwargs):
59 59 pass
60 60
61 61
62 62 @decorator
63 63 def needs_sqlite(f, self, *a, **kw):
64 64 """Decorator: return an empty list in the absence of sqlite."""
65 65 if sqlite3 is None or not self.enabled:
66 66 return []
67 67 else:
68 68 return f(self, *a, **kw)
69 69
70 70
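This is the calling convention of the external `decorator` package this changeset switches to: the wrapped function arrives as the first argument, followed by the original call's arguments, and the resulting wrapper preserves the original signature. A minimal self-contained sketch of the same guard pattern, with invented names::

    from decorator import decorator

    @decorator
    def needs_backend(f, self, *a, **kw):
        # Return a harmless default instead of raising when the backend is gone.
        if self.backend is None:
            return []
        return f(self, *a, **kw)

    class Store(object):
        backend = None              # pretend the import failed
        @needs_backend
        def fetch(self):
            return self.backend.query()

    print(Store().fetch())          # -> [], not an AttributeError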
71 71 if sqlite3 is not None:
72 72 DatabaseError = sqlite3.DatabaseError
73 73 else:
74 74 @undoc
75 75 class DatabaseError(Exception):
76 76 "Dummy exception when sqlite could not be imported. Should never occur."
77 77
78 78 @decorator
79 79 def catch_corrupt_db(f, self, *a, **kw):
80 80 """A decorator which wraps HistoryAccessor method calls to catch errors from
81 81 a corrupt SQLite database, move the old database out of the way, and create
82 82 a new one.
83 83 """
84 84 try:
85 85 return f(self, *a, **kw)
86 86 except DatabaseError:
87 87 if os.path.isfile(self.hist_file):
88 88 # Try to move the file out of the way
89 89 base,ext = os.path.splitext(self.hist_file)
90 90 newpath = base + '-corrupt' + ext
91 91 os.rename(self.hist_file, newpath)
92 92 self.init_db()
93 93 print("ERROR! History file wasn't a valid SQLite database.",
94 94 "It was moved to %s" % newpath, "and a new file created.")
95 95 return []
96 96
97 97 else:
98 98 # The hist_file is probably :memory: or something else.
99 99 raise
100 100
101 101 class HistoryAccessorBase(Configurable):
102 102 """An abstract class for History Accessors """
103 103
104 104 def get_tail(self, n=10, raw=True, output=False, include_latest=False):
105 105 raise NotImplementedError
106 106
107 107 def search(self, pattern="*", raw=True, search_raw=True,
108 108 output=False, n=None, unique=False):
109 109 raise NotImplementedError
110 110
111 111 def get_range(self, session, start=1, stop=None, raw=True,output=False):
112 112 raise NotImplementedError
113 113
114 114 def get_range_by_str(self, rangestr, raw=True, output=False):
115 115 raise NotImplementedError
116 116
117 117
118 118 class HistoryAccessor(HistoryAccessorBase):
119 119 """Access the history database without adding to it.
120 120
121 121 This is intended for use by standalone history tools. IPython shells use
122 122 HistoryManager, below, which is a subclass of this."""
123 123
124 124 # String holding the path to the history file
125 125 hist_file = Unicode(config=True,
126 126 help="""Path to file to use for SQLite history database.
127 127
128 128 By default, IPython will put the history database in the IPython
129 129 profile directory. If you would rather share one history among
130 130 profiles, you can set this value in each, so that they are consistent.
131 131
132 132 Due to an issue with fcntl, SQLite is known to misbehave on some NFS
133 133 mounts. If you see IPython hanging, try setting this to something on a
134 134 local disk, e.g::
135 135
136 136 ipython --HistoryManager.hist_file=/tmp/ipython_hist.sqlite
137 137
138 138 """)
139 139
140 140 enabled = Bool(True, config=True,
141 141 help="""enable the SQLite history
142 142
143 143 set enabled=False to disable the SQLite history,
144 144 in which case there will be no stored history, no SQLite connection,
145 145 and no background saving thread. This may be necessary in some
146 146 threaded environments where IPython is embedded.
147 147 """
148 148 )
149 149
150 150 connection_options = Dict(config=True,
151 151 help="""Options for configuring the SQLite connection
152 152
153 153 These options are passed as keyword args to sqlite3.connect
154 154 when establishing database connections.
155 155 """
156 156 )
157 157
158 158 # The SQLite database
159 159 db = Any()
160 160 def _db_changed(self, name, old, new):
161 161 """validate the db, since it can be an Instance of two different types"""
162 162 connection_types = (DummyDB,)
163 163 if sqlite3 is not None:
164 164 connection_types = (DummyDB, sqlite3.Connection)
165 165 if not isinstance(new, connection_types):
166 166 msg = "%s.db must be sqlite3 Connection or DummyDB, not %r" % \
167 167 (self.__class__.__name__, new)
168 168 raise TraitError(msg)
169 169
170 170 def __init__(self, profile='default', hist_file=u'', **traits):
171 171 """Create a new history accessor.
172 172
173 173 Parameters
174 174 ----------
175 175 profile : str
176 176 The name of the profile from which to open history.
177 177 hist_file : str
178 178 Path to an SQLite history database stored by IPython. If specified,
179 179 hist_file overrides profile.
180 180 config : :class:`~IPython.config.loader.Config`
181 181 Config object. hist_file can also be set through this.
182 182 """
183 183 # We need a pointer back to the shell for various tasks.
184 184 super(HistoryAccessor, self).__init__(**traits)
185 185 # defer setting hist_file from kwarg until after init,
186 186 # otherwise the default kwarg value would clobber any value
187 187 # set by config
188 188 if hist_file:
189 189 self.hist_file = hist_file
190 190
191 191 if self.hist_file == u'':
192 192 # No one has set the hist_file, yet.
193 193 self.hist_file = self._get_hist_file_name(profile)
194 194
195 195 if sqlite3 is None and self.enabled:
196 196 warn("IPython History requires SQLite; your history will not be saved")
197 197 self.enabled = False
198 198
199 199 self.init_db()
200 200
201 201 def _get_hist_file_name(self, profile='default'):
202 202 """Find the history file for the given profile name.
203 203
204 204 This is overridden by the HistoryManager subclass, to use the shell's
205 205 active profile.
206 206
207 207 Parameters
208 208 ----------
209 209 profile : str
210 210 The name of a profile which has a history file.
211 211 """
212 212 return os.path.join(locate_profile(profile), 'history.sqlite')
213 213
214 214 @catch_corrupt_db
215 215 def init_db(self):
216 216 """Connect to the database, and create tables if necessary."""
217 217 if not self.enabled:
218 218 self.db = DummyDB()
219 219 return
220 220
221 221 # use detect_types so that timestamps return datetime objects
222 222 kwargs = dict(detect_types=sqlite3.PARSE_DECLTYPES|sqlite3.PARSE_COLNAMES)
223 223 kwargs.update(self.connection_options)
224 224 self.db = sqlite3.connect(self.hist_file, **kwargs)
225 225 self.db.execute("""CREATE TABLE IF NOT EXISTS sessions (session integer
226 226 primary key autoincrement, start timestamp,
227 227 end timestamp, num_cmds integer, remark text)""")
228 228 self.db.execute("""CREATE TABLE IF NOT EXISTS history
229 229 (session integer, line integer, source text, source_raw text,
230 230 PRIMARY KEY (session, line))""")
231 231 # Output history is optional, but ensure the table's there so it can be
232 232 # enabled later.
233 233 self.db.execute("""CREATE TABLE IF NOT EXISTS output_history
234 234 (session integer, line integer, output text,
235 235 PRIMARY KEY (session, line))""")
236 236 self.db.commit()
237 237
238 238 def writeout_cache(self):
239 239 """Overridden by HistoryManager to dump the cache before certain
240 240 database lookups."""
241 241 pass
242 242
243 243 ## -------------------------------
244 244 ## Methods for retrieving history:
245 245 ## -------------------------------
246 246 def _run_sql(self, sql, params, raw=True, output=False):
247 247 """Prepares and runs an SQL query for the history database.
248 248
249 249 Parameters
250 250 ----------
251 251 sql : str
252 252 Any filtering expressions to go after SELECT ... FROM ...
253 253 params : tuple
254 254 Parameters passed to the SQL query (to replace "?")
255 255 raw, output : bool
256 256 See :meth:`get_range`
257 257
258 258 Returns
259 259 -------
260 260 Tuples as :meth:`get_range`
261 261 """
262 262 toget = 'source_raw' if raw else 'source'
263 263 sqlfrom = "history"
264 264 if output:
265 265 sqlfrom = "history LEFT JOIN output_history USING (session, line)"
266 266 toget = "history.%s, output_history.output" % toget
267 267 cur = self.db.execute("SELECT session, line, %s FROM %s " %\
268 268 (toget, sqlfrom) + sql, params)
269 269 if output: # Regroup into 3-tuples, and parse JSON
270 270 return ((ses, lin, (inp, out)) for ses, lin, inp, out in cur)
271 271 return cur
272 272
273 273 @needs_sqlite
274 274 @catch_corrupt_db
275 275 def get_session_info(self, session):
276 276 """Get info about a session.
277 277
278 278 Parameters
279 279 ----------
280 280
281 281 session : int
282 282 Session number to retrieve.
283 283
284 284 Returns
285 285 -------
286 286
287 287 session_id : int
288 288 Session ID number
289 289 start : datetime
290 290 Timestamp for the start of the session.
291 291 end : datetime
292 292 Timestamp for the end of the session, or None if IPython crashed.
293 293 num_cmds : int
294 294 Number of commands run, or None if IPython crashed.
295 295 remark : unicode
296 296 A manually set description.
297 297 """
298 298 query = "SELECT * from sessions where session == ?"
299 299 return self.db.execute(query, (session,)).fetchone()
300 300
301 301 @catch_corrupt_db
302 302 def get_last_session_id(self):
303 303 """Get the last session ID currently in the database.
304 304
305 305 Within IPython, this should be the same as the value stored in
306 306 :attr:`HistoryManager.session_number`.
307 307 """
308 308 for record in self.get_tail(n=1, include_latest=True):
309 309 return record[0]
310 310
311 311 @catch_corrupt_db
312 312 def get_tail(self, n=10, raw=True, output=False, include_latest=False):
313 313 """Get the last n lines from the history database.
314 314
315 315 Parameters
316 316 ----------
317 317 n : int
318 318 The number of lines to get
319 319 raw, output : bool
320 320 See :meth:`get_range`
321 321 include_latest : bool
322 322 If False (default), n+1 lines are fetched, and the latest one
323 323 is discarded. This is intended for use where the function is
324 324 called from a user command, so that the command itself is not returned.
325 325
326 326 Returns
327 327 -------
328 328 Tuples as :meth:`get_range`
329 329 """
330 330 self.writeout_cache()
331 331 if not include_latest:
332 332 n += 1
333 333 cur = self._run_sql("ORDER BY session DESC, line DESC LIMIT ?",
334 334 (n,), raw=raw, output=output)
335 335 if not include_latest:
336 336 return reversed(list(cur)[1:])
337 337 return reversed(list(cur))
338 338
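A usage sketch, assuming an existing history database in the default profile::

    from IPython.core.history import HistoryAccessor

    hist = HistoryAccessor()        # read-only access to the history DB
    for session, line, source in hist.get_tail(5, include_latest=True):
        print(session, line, source)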
339 339 @catch_corrupt_db
340 340 def search(self, pattern="*", raw=True, search_raw=True,
341 341 output=False, n=None, unique=False):
342 342 """Search the database using unix glob-style matching (wildcards
343 343 * and ?).
344 344
345 345 Parameters
346 346 ----------
347 347 pattern : str
348 348 The wildcarded pattern to match when searching
349 349 search_raw : bool
350 350 If True, search the raw input, otherwise, the parsed input
351 351 raw, output : bool
352 352 See :meth:`get_range`
353 353 n : None or int
354 354 If an integer is given, it defines the limit of
355 355 returned entries.
356 356 unique : bool
357 357 When it is true, return only unique entries.
358 358
359 359 Returns
360 360 -------
361 361 Tuples as :meth:`get_range`
362 362 """
363 363 tosearch = "source_raw" if search_raw else "source"
364 364 if output:
365 365 tosearch = "history." + tosearch
366 366 self.writeout_cache()
367 367 sqlform = "WHERE %s GLOB ?" % tosearch
368 368 params = (pattern,)
369 369 if unique:
370 370 sqlform += ' GROUP BY {0}'.format(tosearch)
371 371 if n is not None:
372 372 sqlform += " ORDER BY session DESC, line DESC LIMIT ?"
373 373 params += (n,)
374 374 elif unique:
375 375 sqlform += " ORDER BY session, line"
376 376 cur = self._run_sql(sqlform, params, raw=raw, output=output)
377 377 if n is not None:
378 378 return reversed(list(cur))
379 379 return cur
380 380
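Continuing the sketch above; the pattern maps directly onto SQLite's GLOB operator::

    # The 10 most recent unique inputs mentioning numpy, across all sessions:
    for session, line, source in hist.search('*numpy*', unique=True, n=10):
        print(session, line, source)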
381 381 @catch_corrupt_db
382 382 def get_range(self, session, start=1, stop=None, raw=True,output=False):
383 383 """Retrieve input by session.
384 384
385 385 Parameters
386 386 ----------
387 387 session : int
388 388 Session number to retrieve.
389 389 start : int
390 390 First line to retrieve.
391 391 stop : int
392 392 End of line range (excluded from output itself). If None, retrieve
393 393 to the end of the session.
394 394 raw : bool
395 395 If True, return untranslated input
396 396 output : bool
397 397 If True, attempt to include output. This will be 'real' Python
398 398 objects for the current session, or text reprs from previous
399 399 sessions if db_log_output was enabled at the time. Where no output
400 400 is found, None is used.
401 401
402 402 Returns
403 403 -------
404 404 entries
405 405 An iterator over the desired lines. Each line is a 3-tuple, either
406 406 (session, line, input) if output is False, or
407 407 (session, line, (input, output)) if output is True.
408 408 """
409 409 if stop:
410 410 lineclause = "line >= ? AND line < ?"
411 411 params = (session, start, stop)
412 412 else:
413 413 lineclause = "line>=?"
414 414 params = (session, start)
415 415
416 416 return self._run_sql("WHERE session==? AND %s" % lineclause,
417 417 params, raw=raw, output=output)
418 418
419 419 def get_range_by_str(self, rangestr, raw=True, output=False):
420 420 """Get lines of history from a string of ranges, as used by magic
421 421 commands %hist, %save, %macro, etc.
422 422
423 423 Parameters
424 424 ----------
425 425 rangestr : str
426 426 A string specifying ranges, e.g. "5 ~2/1-4". See
427 427 :func:`magic_history` for full details.
428 428 raw, output : bool
429 429 As :meth:`get_range`
430 430
431 431 Returns
432 432 -------
433 433 Tuples as :meth:`get_range`
434 434 """
435 435 for sess, s, e in extract_hist_ranges(rangestr):
436 436 for line in self.get_range(sess, s, e, raw=raw, output=output):
437 437 yield line
438 438
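Continuing the sketch above; the range syntax is the one the %history magic accepts::

    # Lines 1-5 of the previous session, then line 7 of the current one:
    for session, line, source in hist.get_range_by_str('~1/1-5 7'):
        print(session, line, source)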
439 439
440 440 class HistoryManager(HistoryAccessor):
441 441 """A class to organize all history-related functionality in one place.
442 442 """
443 443 # Public interface
444 444
445 445 # An instance of the IPython shell we are attached to
446 446 shell = Instance('IPython.core.interactiveshell.InteractiveShellABC')
447 447 # Lists to hold processed and raw history. These start with a blank entry
448 448 # so that we can index them starting from 1
449 449 input_hist_parsed = List([""])
450 450 input_hist_raw = List([""])
451 451 # A list of directories visited during session
452 452 dir_hist = List()
453 453 def _dir_hist_default(self):
454 454 try:
455 455 return [py3compat.getcwd()]
456 456 except OSError:
457 457 return []
458 458
459 459 # A dict of output history, keyed with ints from the shell's
460 460 # execution count.
461 461 output_hist = Dict()
462 462 # The text/plain repr of outputs.
463 463 output_hist_reprs = Dict()
464 464
465 465 # The number of the current session in the history database
466 466 session_number = Integer()
467 467
468 468 db_log_output = Bool(False, config=True,
469 469 help="Should the history database include output? (default: no)"
470 470 )
471 471 db_cache_size = Integer(0, config=True,
472 472 help="Write to database every x commands (higher values save disk access & power).\n"
473 473 "Values of 1 or less effectively disable caching."
474 474 )
475 475 # The input and output caches
476 476 db_input_cache = List()
477 477 db_output_cache = List()
478 478
479 479 # History saving in separate thread
480 480 save_thread = Instance('IPython.core.history.HistorySavingThread')
481 481 try: # Event is a function returning an instance of _Event...
482 482 save_flag = Instance(threading._Event)
483 483 except AttributeError: # ...until Python 3.3, when it's a class.
484 484 save_flag = Instance(threading.Event)
485 485
486 486 # Private interface
487 487 # Variables used to store the three last inputs from the user. On each new
488 488 # history update, we populate the user's namespace with these, shifted as
489 489 # necessary.
490 490 _i00 = Unicode(u'')
491 491 _i = Unicode(u'')
492 492 _ii = Unicode(u'')
493 493 _iii = Unicode(u'')
494 494
495 495 # A regex matching all forms of the exit command, so that we don't store
496 496 # them in the history (it's annoying to rewind the first entry and land on
497 497 # an exit call).
498 498 _exit_re = re.compile(r"(exit|quit)(\s*\(.*\))?$")
499 499
500 500 def __init__(self, shell=None, config=None, **traits):
501 501 """Create a new history manager associated with a shell instance.
502 502 """
503 503 # We need a pointer back to the shell for various tasks.
504 504 super(HistoryManager, self).__init__(shell=shell, config=config,
505 505 **traits)
506 506 self.save_flag = threading.Event()
507 507 self.db_input_cache_lock = threading.Lock()
508 508 self.db_output_cache_lock = threading.Lock()
509 509 if self.enabled and self.hist_file != ':memory:':
510 510 self.save_thread = HistorySavingThread(self)
511 511 self.save_thread.start()
512 512
513 513 self.new_session()
514 514
515 515 def _get_hist_file_name(self, profile=None):
516 516 """Get default history file name based on the Shell's profile.
517 517
518 518 The profile parameter is ignored, but must exist for compatibility with
519 519 the parent class."""
520 520 profile_dir = self.shell.profile_dir.location
521 521 return os.path.join(profile_dir, 'history.sqlite')
522 522
523 523 @needs_sqlite
524 524 def new_session(self, conn=None):
525 525 """Get a new session number."""
526 526 if conn is None:
527 527 conn = self.db
528 528
529 529 with conn:
530 530 cur = conn.execute("""INSERT INTO sessions VALUES (NULL, ?, NULL,
531 531 NULL, "") """, (datetime.datetime.now(),))
532 532 self.session_number = cur.lastrowid
533 533
534 534 def end_session(self):
535 535 """Close the database session, filling in the end time and line count."""
536 536 self.writeout_cache()
537 537 with self.db:
538 538 self.db.execute("""UPDATE sessions SET end=?, num_cmds=? WHERE
539 539 session==?""", (datetime.datetime.now(),
540 540 len(self.input_hist_parsed)-1, self.session_number))
541 541 self.session_number = 0
542 542
543 543 def name_session(self, name):
544 544 """Give the current session a name in the history database."""
545 545 with self.db:
546 546 self.db.execute("UPDATE sessions SET remark=? WHERE session==?",
547 547 (name, self.session_number))
548 548
549 549 def reset(self, new_session=True):
550 550 """Clear the session history, releasing all object references, and
551 551 optionally open a new session."""
552 552 self.output_hist.clear()
553 553 # The directory history can't be completely empty
554 554 self.dir_hist[:] = [py3compat.getcwd()]
555 555
556 556 if new_session:
557 557 if self.session_number:
558 558 self.end_session()
559 559 self.input_hist_parsed[:] = [""]
560 560 self.input_hist_raw[:] = [""]
561 561 self.new_session()
562 562
563 563 # ------------------------------
564 564 # Methods for retrieving history
565 565 # ------------------------------
566 566 def get_session_info(self, session=0):
567 567 """Get info about a session.
568 568
569 569 Parameters
570 570 ----------
571 571
572 572 session : int
573 573 Session number to retrieve. The current session is 0, and negative
574 574 numbers count back from current session, so -1 is the previous session.
575 575
576 576 Returns
577 577 -------
578 578
579 579 session_id : int
580 580 Session ID number
581 581 start : datetime
582 582 Timestamp for the start of the session.
583 583 end : datetime
584 584 Timestamp for the end of the session, or None if IPython crashed.
585 585 num_cmds : int
586 586 Number of commands run, or None if IPython crashed.
587 587 remark : unicode
588 588 A manually set description.
589 589 """
590 590 if session <= 0:
591 591 session += self.session_number
592 592
593 593 return super(HistoryManager, self).get_session_info(session=session)
594 594
595 595 def _get_range_session(self, start=1, stop=None, raw=True, output=False):
596 596 """Get input and output history from the current session. Called by
597 597 get_range, and takes similar parameters."""
598 598 input_hist = self.input_hist_raw if raw else self.input_hist_parsed
599 599
600 600 n = len(input_hist)
601 601 if start < 0:
602 602 start += n
603 603 if not stop or (stop > n):
604 604 stop = n
605 605 elif stop < 0:
606 606 stop += n
607 607
608 608 for i in range(start, stop):
609 609 if output:
610 610 line = (input_hist[i], self.output_hist_reprs.get(i))
611 611 else:
612 612 line = input_hist[i]
613 613 yield (0, i, line)
614 614
615 615 def get_range(self, session=0, start=1, stop=None, raw=True,output=False):
616 616 """Retrieve input by session.
617 617
618 618 Parameters
619 619 ----------
620 620 session : int
621 621 Session number to retrieve. The current session is 0, and negative
622 622 numbers count back from current session, so -1 is previous session.
623 623 start : int
624 624 First line to retrieve.
625 625 stop : int
626 626 End of line range (excluded from output itself). If None, retrieve
627 627 to the end of the session.
628 628 raw : bool
629 629 If True, return untranslated input
630 630 output : bool
631 631 If True, attempt to include output. This will be 'real' Python
632 632 objects for the current session, or text reprs from previous
633 633 sessions if db_log_output was enabled at the time. Where no output
634 634 is found, None is used.
635 635
636 636 Returns
637 637 -------
638 638 entries
639 639 An iterator over the desired lines. Each line is a 3-tuple, either
640 640 (session, line, input) if output is False, or
641 641 (session, line, (input, output)) if output is True.
642 642 """
643 643 if session <= 0:
644 644 session += self.session_number
645 645 if session==self.session_number: # Current session
646 646 return self._get_range_session(start, stop, raw, output)
647 647 return super(HistoryManager, self).get_range(session, start, stop, raw,
648 648 output)
649 649
650 650 ## ----------------------------
651 651 ## Methods for storing history:
652 652 ## ----------------------------
653 653 def store_inputs(self, line_num, source, source_raw=None):
654 654 """Store source and raw input in history and create input cache
655 655 variables ``_i*``.
656 656
657 657 Parameters
658 658 ----------
659 659 line_num : int
660 660 The prompt number of this input.
661 661
662 662 source : str
663 663 Python input.
664 664
665 665 source_raw : str, optional
666 666 If given, this is the raw input without any IPython transformations
667 667 applied to it. If not given, ``source`` is used.
668 668 """
669 669 if source_raw is None:
670 670 source_raw = source
671 671 source = source.rstrip('\n')
672 672 source_raw = source_raw.rstrip('\n')
673 673
674 674 # do not store exit/quit commands
675 675 if self._exit_re.match(source_raw.strip()):
676 676 return
677 677
678 678 self.input_hist_parsed.append(source)
679 679 self.input_hist_raw.append(source_raw)
680 680
681 681 with self.db_input_cache_lock:
682 682 self.db_input_cache.append((line_num, source, source_raw))
683 683 # Trigger to flush cache and write to DB.
684 684 if len(self.db_input_cache) >= self.db_cache_size:
685 685 self.save_flag.set()
686 686
687 687 # update the auto _i variables
688 688 self._iii = self._ii
689 689 self._ii = self._i
690 690 self._i = self._i00
691 691 self._i00 = source_raw
692 692
693 693 # hackish access to user namespace to create _i1,_i2... dynamically
694 694 new_i = '_i%s' % line_num
695 695 to_main = {'_i': self._i,
696 696 '_ii': self._ii,
697 697 '_iii': self._iii,
698 698 new_i : self._i00 }
699 699
700 700 if self.shell is not None:
701 701 self.shell.push(to_main, interactive=False)
702 702
703 703 def store_output(self, line_num):
704 704 """If database output logging is enabled, this saves all the
705 705 outputs from the indicated prompt number to the database. It's
706 706 called by run_cell after code has been executed.
707 707
708 708 Parameters
709 709 ----------
710 710 line_num : int
711 711 The line number from which to save outputs
712 712 """
713 713 if (not self.db_log_output) or (line_num not in self.output_hist_reprs):
714 714 return
715 715 output = self.output_hist_reprs[line_num]
716 716
717 717 with self.db_output_cache_lock:
718 718 self.db_output_cache.append((line_num, output))
719 719 if self.db_cache_size <= 1:
720 720 self.save_flag.set()
721 721
722 722 def _writeout_input_cache(self, conn):
723 723 with conn:
724 724 for line in self.db_input_cache:
725 725 conn.execute("INSERT INTO history VALUES (?, ?, ?, ?)",
726 726 (self.session_number,)+line)
727 727
728 728 def _writeout_output_cache(self, conn):
729 729 with conn:
730 730 for line in self.db_output_cache:
731 731 conn.execute("INSERT INTO output_history VALUES (?, ?, ?)",
732 732 (self.session_number,)+line)
733 733
734 734 @needs_sqlite
735 735 def writeout_cache(self, conn=None):
736 736 """Write any entries in the cache to the database."""
737 737 if conn is None:
738 738 conn = self.db
739 739
740 740 with self.db_input_cache_lock:
741 741 try:
742 742 self._writeout_input_cache(conn)
743 743 except sqlite3.IntegrityError:
744 744 self.new_session(conn)
745 745 print("ERROR! Session/line number was not unique in",
746 746 "database. History logging moved to new session",
747 747 self.session_number)
748 748 try:
749 749 # Try writing to the new session. If this fails, don't
750 750 # recurse
751 751 self._writeout_input_cache(conn)
752 752 except sqlite3.IntegrityError:
753 753 pass
754 754 finally:
755 755 self.db_input_cache = []
756 756
757 757 with self.db_output_cache_lock:
758 758 try:
759 759 self._writeout_output_cache(conn)
760 760 except sqlite3.IntegrityError:
761 761 print("!! Session/line number for output was not unique",
762 762 "in database. Output will not be stored.")
763 763 finally:
764 764 self.db_output_cache = []
765 765
766 766
767 767 class HistorySavingThread(threading.Thread):
768 768 """This thread takes care of writing history to the database, so that
769 769 the UI isn't held up while that happens.
770 770
771 771 It waits for the HistoryManager's save_flag to be set, then writes out
772 772 the history cache. The main thread is responsible for setting the flag when
773 773 the cache size reaches a defined threshold."""
774 774 daemon = True
775 775 stop_now = False
776 776 enabled = True
777 777 def __init__(self, history_manager):
778 778 super(HistorySavingThread, self).__init__(name="IPythonHistorySavingThread")
779 779 self.history_manager = history_manager
780 780 self.enabled = history_manager.enabled
781 781 atexit.register(self.stop)
782 782
783 783 @needs_sqlite
784 784 def run(self):
785 785 # We need a separate db connection per thread:
786 786 try:
787 787 self.db = sqlite3.connect(self.history_manager.hist_file,
788 788 **self.history_manager.connection_options
789 789 )
790 790 while True:
791 791 self.history_manager.save_flag.wait()
792 792 if self.stop_now:
793 793 self.db.close()
794 794 return
795 795 self.history_manager.save_flag.clear()
796 796 self.history_manager.writeout_cache(self.db)
797 797 except Exception as e:
798 798 print(("The history saving thread hit an unexpected error (%s). "
799 799 "History will not be written to the database.") % repr(e))
800 800
801 801 def stop(self):
802 802 """This can be called from the main thread to safely stop this thread.
803 803
804 804 Note that it does not attempt to write out remaining history before
805 805 exiting. That should be done by calling the HistoryManager's
806 806 end_session method."""
807 807 self.stop_now = True
808 808 self.history_manager.save_flag.set()
809 809 self.join()
810 810
811 811
812 812 # To match, e.g. ~5/8-~2/3
813 813 range_re = re.compile(r"""
814 814 ((?P<startsess>~?\d+)/)?
815 815 (?P<start>\d+)?
816 816 ((?P<sep>[\-:])
817 817 ((?P<endsess>~?\d+)/)?
818 818 (?P<end>\d+))?
819 819 $""", re.VERBOSE)
820 820
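For illustration, the named groups the pattern captures (sessions prefixed with `~` count back from the current one)::

    m = range_re.match('~2/7-~1/3')
    print(m.group('startsess'), m.group('start'),
          m.group('endsess'), m.group('end'))   # -> ~2 7 ~1 3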
821 821
822 822 def extract_hist_ranges(ranges_str):
823 823 """Turn a string of history ranges into 3-tuples of (session, start, stop).
824 824
825 825 Examples
826 826 --------
827 827 >>> list(extract_hist_ranges("~8/5-~7/4 2"))
828 828 [(-8, 5, None), (-7, 1, 5), (0, 2, 3)]
829 829 """
830 830 for range_str in ranges_str.split():
831 831 rmatch = range_re.match(range_str)
832 832 if not rmatch:
833 833 continue
834 834 start = rmatch.group("start")
835 835 if start:
836 836 start = int(start)
837 837 end = rmatch.group("end")
838 838 # If no end specified, get (a, a + 1)
839 839 end = int(end) if end else start + 1
840 840 else: # start not specified
841 841 if not rmatch.group('startsess'): # no startsess
842 842 continue
843 843 start = 1
844 844 end = None # provide the entire session hist
845 845
846 846 if rmatch.group("sep") == "-": # 1-3 == 1:4 --> [1, 2, 3]
847 847 end += 1
848 848 startsess = rmatch.group("startsess") or "0"
849 849 endsess = rmatch.group("endsess") or startsess
850 850 startsess = int(startsess.replace("~","-"))
851 851 endsess = int(endsess.replace("~","-"))
852 852 assert endsess >= startsess, "end session cannot be earlier than start session"
853 853
854 854 if endsess == startsess:
855 855 yield (startsess, start, end)
856 856 continue
857 857 # Multiple sessions in one range:
858 858 yield (startsess, start, None)
859 859 for sess in range(startsess+1, endsess):
860 860 yield (sess, 1, None)
861 861 yield (endsess, 1, end)
862 862
863 863
864 864 def _format_lineno(session, line):
865 865 """Helper function to format line numbers properly."""
866 866 if session == 0:
867 867 return str(line)
868 868 return "%s#%s" % (session, line)
869 869
870 870
@@ -1,702 +1,702 b''
1 1 # encoding: utf-8
2 2 """Magic functions for InteractiveShell.
3 3 """
4 4 from __future__ import print_function
5 5
6 6 #-----------------------------------------------------------------------------
7 7 # Copyright (C) 2001 Janko Hauser <jhauser@zscout.de> and
8 8 # Copyright (C) 2001 Fernando Perez <fperez@colorado.edu>
9 9 # Copyright (C) 2008 The IPython Development Team
10 10
11 11 # Distributed under the terms of the BSD License. The full license is in
12 12 # the file COPYING, distributed as part of this software.
13 13 #-----------------------------------------------------------------------------
14 14
15 15 #-----------------------------------------------------------------------------
16 16 # Imports
17 17 #-----------------------------------------------------------------------------
18 18 # Stdlib
19 19 import os
20 20 import re
21 21 import sys
22 22 import types
23 23 from getopt import getopt, GetoptError
24 24
25 25 # Our own
26 26 from IPython.config.configurable import Configurable
27 27 from IPython.core import oinspect
28 28 from IPython.core.error import UsageError
29 29 from IPython.core.inputsplitter import ESC_MAGIC, ESC_MAGIC2
30 from IPython.external.decorator import decorator
30 from decorator import decorator
31 31 from IPython.utils.ipstruct import Struct
32 32 from IPython.utils.process import arg_split
33 33 from IPython.utils.py3compat import string_types, iteritems
34 34 from IPython.utils.text import dedent
35 35 from IPython.utils.traitlets import Bool, Dict, Instance, MetaHasTraits
36 36 from IPython.utils.warn import error
37 37
38 38 #-----------------------------------------------------------------------------
39 39 # Globals
40 40 #-----------------------------------------------------------------------------
41 41
42 42 # A dict we'll use for each class that has magics, used as temporary storage to
43 43 # pass information between the @line/cell_magic method decorators and the
44 44 # @magics_class class decorator, because the method decorators have no
45 45 # access to the class when they run. See for more details:
46 46 # http://stackoverflow.com/questions/2366713/can-a-python-decorator-of-an-instance-method-access-the-class
47 47
48 48 magics = dict(line={}, cell={})
49 49
50 50 magic_kinds = ('line', 'cell')
51 51 magic_spec = ('line', 'cell', 'line_cell')
52 52 magic_escapes = dict(line=ESC_MAGIC, cell=ESC_MAGIC2)
53 53
54 54 #-----------------------------------------------------------------------------
55 55 # Utility classes and functions
56 56 #-----------------------------------------------------------------------------
57 57
58 58 class Bunch: pass
59 59
60 60
61 61 def on_off(tag):
62 62 """Return an ON/OFF string for a 1/0 input. Simple utility function."""
63 63 return ['OFF','ON'][tag]
64 64
65 65
66 66 def compress_dhist(dh):
67 67 """Compress a directory history into a new one with at most 20 entries.
68 68
69 69 Return a new list made from the first and last 10 elements of dhist after
70 70 removal of duplicates.
71 71 """
72 72 head, tail = dh[:-10], dh[-10:]
73 73
74 74 newhead = []
75 75 done = set()
76 76 for h in head:
77 77 if h in done:
78 78 continue
79 79 newhead.append(h)
80 80 done.add(h)
81 81
82 82 return newhead + tail
83 83
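A quick sketch of the effect, with invented paths::

    dh = ['/home', '/tmp', '/home', '/var'] * 5   # 20 visits, many repeats
    print(compress_dhist(dh))
    # -> ['/home', '/tmp', '/var'] followed by the 10 most recent entries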
84 84
85 85 def needs_local_scope(func):
86 86 """Decorator to mark magic functions which need to local scope to run."""
87 87 func.needs_local_scope = True
88 88 return func
89 89
90 90 #-----------------------------------------------------------------------------
91 91 # Class and method decorators for registering magics
92 92 #-----------------------------------------------------------------------------
93 93
94 94 def magics_class(cls):
95 95 """Class decorator for all subclasses of the main Magics class.
96 96
97 97 Any class that subclasses Magics *must* also apply this decorator, to
98 98 ensure that all the methods that have been decorated as line/cell magics
99 99 get correctly registered in the class instance. This is necessary because
100 100 when method decorators run, the class does not exist yet, so they
101 101 temporarily store their information into a module global. Application of
102 102 this class decorator copies that global data to the class instance and
103 103 clears the global.
104 104
105 105 Obviously, this mechanism is not thread-safe, which means that the
106 106 *creation* of subclasses of Magic should only be done in a single-thread
107 107 context. Instantiation of the classes has no restrictions. Given that
108 108 these classes are typically created at IPython startup time and before user
109 109 application code becomes active, in practice this should not pose any
110 110 problems.
111 111 """
112 112 cls.registered = True
113 113 cls.magics = dict(line = magics['line'],
114 114 cell = magics['cell'])
115 115 magics['line'] = {}
116 116 magics['cell'] = {}
117 117 return cls
118 118
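A minimal sketch of the pattern this enables (class and magic names are invented)::

    from IPython.core.magic import Magics, magics_class, line_magic

    @magics_class
    class ShoutMagics(Magics):
        @line_magic
        def shout(self, line):
            """Echo the line in upper case."""
            return line.upper()

    # In a running shell: get_ipython().register_magics(ShoutMagics)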
119 119
120 120 def record_magic(dct, magic_kind, magic_name, func):
121 121 """Utility function to store a function as a magic of a specific kind.
122 122
123 123 Parameters
124 124 ----------
125 125 dct : dict
126 126 A dictionary with 'line' and 'cell' subdicts.
127 127
128 128 magic_kind : str
129 129 Kind of magic to be stored.
130 130
131 131 magic_name : str
132 132 Key to store the magic as.
133 133
134 134 func : function
135 135 Callable object to store.
136 136 """
137 137 if magic_kind == 'line_cell':
138 138 dct['line'][magic_name] = dct['cell'][magic_name] = func
139 139 else:
140 140 dct[magic_kind][magic_name] = func
141 141
142 142
143 143 def validate_type(magic_kind):
144 144 """Ensure that the given magic_kind is valid.
145 145
146 146 Check that the given magic_kind is one of the accepted spec types (stored
147 147 in the global `magic_spec`), raise ValueError otherwise.
148 148 """
149 149 if magic_kind not in magic_spec:
150 150 raise ValueError('magic_kind must be one of %s, %s given' %
151 151 (magic_spec, magic_kind))
152 152
153 153
154 154 # The docstrings for the decorator below will be fairly similar for the two
155 155 # types (method and function), so we generate them here once and reuse the
156 156 # templates below.
157 157 _docstring_template = \
158 158 """Decorate the given {0} as {1} magic.
159 159
160 160 The decorator can be used with or without arguments, as follows.
161 161
162 162 i) without arguments: it will create a {1} magic named as the {0} being
163 163 decorated::
164 164
165 165 @deco
166 166 def foo(...)
167 167
168 168 will create a {1} magic named `foo`.
169 169
170 170 ii) with one string argument: which will be used as the actual name of the
171 171 resulting magic::
172 172
173 173 @deco('bar')
174 174 def foo(...)
175 175
176 176 will create a {1} magic named `bar`.
177 177 """
178 178
179 179 # These two are decorator factories. While they are conceptually very similar,
180 180 # there are enough differences in the details that it's simpler to have them
181 181 # written as completely standalone functions rather than trying to share code
182 182 # and make a single one with convoluted logic.
183 183
184 184 def _method_magic_marker(magic_kind):
185 185 """Decorator factory for methods in Magics subclasses.
186 186 """
187 187
188 188 validate_type(magic_kind)
189 189
190 190 # This is a closure to capture the magic_kind. We could also use a class,
191 191 # but it's overkill for just that one bit of state.
192 192 def magic_deco(arg):
193 193 call = lambda f, *a, **k: f(*a, **k)
194 194
195 195 if callable(arg):
196 196 # "Naked" decorator call (just @foo, no args)
197 197 func = arg
198 198 name = func.__name__
199 199 retval = decorator(call, func)
200 200 record_magic(magics, magic_kind, name, name)
201 201 elif isinstance(arg, string_types):
202 202 # Decorator called with arguments (@foo('bar'))
203 203 name = arg
204 204 def mark(func, *a, **kw):
205 205 record_magic(magics, magic_kind, name, func.__name__)
206 206 return decorator(call, func)
207 207 retval = mark
208 208 else:
209 209 raise TypeError("Decorator can only be called with "
210 210 "string or function")
211 211 return retval
212 212
213 213 # Ensure the resulting decorator has a usable docstring
214 214 magic_deco.__doc__ = _docstring_template.format('method', magic_kind)
215 215 return magic_deco
216 216
217 217
218 218 def _function_magic_marker(magic_kind):
219 219 """Decorator factory for standalone functions.
220 220 """
221 221 validate_type(magic_kind)
222 222
223 223 # This is a closure to capture the magic_kind. We could also use a class,
224 224 # but it's overkill for just that one bit of state.
225 225 def magic_deco(arg):
226 226 call = lambda f, *a, **k: f(*a, **k)
227 227
228 228 # Find get_ipython() in the caller's namespace
229 229 caller = sys._getframe(1)
230 230 for ns in ['f_locals', 'f_globals', 'f_builtins']:
231 231 get_ipython = getattr(caller, ns).get('get_ipython')
232 232 if get_ipython is not None:
233 233 break
234 234 else:
235 235 raise NameError('Decorator can only run in context where '
236 236 '`get_ipython` exists')
237 237
238 238 ip = get_ipython()
239 239
240 240 if callable(arg):
241 241 # "Naked" decorator call (just @foo, no args)
242 242 func = arg
243 243 name = func.__name__
244 244 ip.register_magic_function(func, magic_kind, name)
245 245 retval = decorator(call, func)
246 246 elif isinstance(arg, string_types):
247 247 # Decorator called with arguments (@foo('bar'))
248 248 name = arg
249 249 def mark(func, *a, **kw):
250 250 ip.register_magic_function(func, magic_kind, name)
251 251 return decorator(call, func)
252 252 retval = mark
253 253 else:
254 254 raise TypeError("Decorator can only be called with "
255 255 "string or function")
256 256 return retval
257 257
258 258 # Ensure the resulting decorator has a usable docstring
259 259 ds = _docstring_template.format('function', magic_kind)
260 260
261 261 ds += dedent("""
262 262 Note: this decorator can only be used in a context where IPython is already
263 263 active, so that the `get_ipython()` call succeeds. You can therefore use
264 264 it in your startup files loaded after IPython initializes, but *not* in the
265 265 IPython configuration file itself, which is executed before IPython is
266 266 fully up and running. Any file located in the `startup` subdirectory of
267 267 your configuration profile will be OK in this sense.
268 268 """)
269 269
270 270 magic_deco.__doc__ = ds
271 271 return magic_deco
272 272
273 273
274 274 # Create the actual decorators for public use
275 275
276 276 # These three are used to decorate methods in class definitions
277 277 line_magic = _method_magic_marker('line')
278 278 cell_magic = _method_magic_marker('cell')
279 279 line_cell_magic = _method_magic_marker('line_cell')
280 280
281 281 # These three decorate standalone functions and perform the decoration
282 282 # immediately. They can only run where get_ipython() works
283 283 register_line_magic = _function_magic_marker('line')
284 284 register_cell_magic = _function_magic_marker('cell')
285 285 register_line_cell_magic = _function_magic_marker('line_cell')
286 286
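For example, in a profile startup file (the magic name is invented; as noted above, these decorators only work where `get_ipython()` is available)::

    from IPython.core.magic import register_line_magic

    @register_line_magic
    def now(line):
        """Print the current time."""
        import datetime
        print(datetime.datetime.now())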
287 287 #-----------------------------------------------------------------------------
288 288 # Core Magic classes
289 289 #-----------------------------------------------------------------------------
290 290
291 291 class MagicsManager(Configurable):
292 292 """Object that handles all magic-related functionality for IPython.
293 293 """
294 294 # Non-configurable class attributes
295 295
296 296 # A two-level dict, first keyed by magic type, then by magic function, and
297 297 # holding the actual callable object as value. This is the dict used for
298 298 # magic function dispatch
299 299 magics = Dict
300 300
301 301 # A registry of the original objects that we've been given holding magics.
302 302 registry = Dict
303 303
304 304 shell = Instance('IPython.core.interactiveshell.InteractiveShellABC')
305 305
306 306 auto_magic = Bool(True, config=True, help=
307 307 "Automatically call line magics without requiring explicit % prefix")
308 308
309 309 def _auto_magic_changed(self, name, value):
310 310 self.shell.automagic = value
311 311
312 312 _auto_status = [
313 313 'Automagic is OFF, % prefix IS needed for line magics.',
314 314 'Automagic is ON, % prefix IS NOT needed for line magics.']
315 315
316 316 user_magics = Instance('IPython.core.magics.UserMagics')
317 317
318 318 def __init__(self, shell=None, config=None, user_magics=None, **traits):
319 319
320 320 super(MagicsManager, self).__init__(shell=shell, config=config,
321 321 user_magics=user_magics, **traits)
322 322 self.magics = dict(line={}, cell={})
323 323 # Let's add the user_magics to the registry for uniformity, so *all*
324 324 # registered magic containers can be found there.
325 325 self.registry[user_magics.__class__.__name__] = user_magics
326 326
327 327 def auto_status(self):
328 328 """Return descriptive string with automagic status."""
329 329 return self._auto_status[self.auto_magic]
330 330
331 331 def lsmagic(self):
332 332 """Return a dict of currently available magic functions.
333 333
334 334 The return dict has the keys 'line' and 'cell', corresponding to the
335 335 two types of magics we support. Each value is a list of names.
336 336 """
337 337 return self.magics
338 338
339 339 def lsmagic_docs(self, brief=False, missing=''):
340 340 """Return dict of documentation of magic functions.
341 341
342 342 The return dict has the keys 'line' and 'cell', corresponding to the
343 343 two types of magics we support. Each value is a dict keyed by magic
344 344 name whose value is the function docstring. If a docstring is
345 345 unavailable, the value of `missing` is used instead.
346 346
347 347 If brief is True, only the first line of each docstring will be returned.
348 348 """
349 349 docs = {}
350 350 for m_type in self.magics:
351 351 m_docs = {}
352 352 for m_name, m_func in iteritems(self.magics[m_type]):
353 353 if m_func.__doc__:
354 354 if brief:
355 355 m_docs[m_name] = m_func.__doc__.split('\n', 1)[0]
356 356 else:
357 357 m_docs[m_name] = m_func.__doc__.rstrip()
358 358 else:
359 359 m_docs[m_name] = missing
360 360 docs[m_type] = m_docs
361 361 return docs
362 362
363 363 def register(self, *magic_objects):
364 364 """Register one or more instances of Magics.
365 365
366 366 Take one or more classes or instances of classes that subclass the main
367 367 `core.Magic` class, and register them with IPython to use the magic
368 368 functions they provide. The registration process will then ensure that
369 369 any methods that have been decorated to provide line and/or cell magics will
370 370 be recognized with the `%x`/`%%x` syntax as a line/cell magic
371 371 respectively.
372 372
373 373 If classes are given, they will be instantiated with the default
374 374 constructor. If your classes need a custom constructor, you should
375 375 instantiate them first and pass the instance.
376 376
377 377 The provided arguments can be an arbitrary mix of classes and instances.
378 378
379 379 Parameters
380 380 ----------
381 381 magic_objects : one or more classes or instances
382 382 """
383 383 # Start by validating them to ensure they have all had their magic
384 384 # methods registered at the instance level
385 385 for m in magic_objects:
386 386 if not m.registered:
387 387 raise ValueError("Class of magics %r was constructed without "
388 388 "the @magics_class decorator" % m)
389 389 if type(m) in (type, MetaHasTraits):
390 390 # If we're given an uninstantiated class
391 391 m = m(shell=self.shell)
392 392
393 393 # Now that we have an instance, we can register it and update the
394 394 # table of callables
395 395 self.registry[m.__class__.__name__] = m
396 396 for mtype in magic_kinds:
397 397 self.magics[mtype].update(m.magics[mtype])
398 398
399 399 def register_function(self, func, magic_kind='line', magic_name=None):
400 400 """Expose a standalone function as magic function for IPython.
401 401
402 402 This will create an IPython magic (line, cell or both) from a
403 403 standalone function. The functions should have the following
404 404 signatures:
405 405
406 406 * For line magics: `def f(line)`
407 407 * For cell magics: `def f(line, cell)`
408 408 * For a function that does both: `def f(line, cell=None)`
409 409
410 410 In the latter case, the function will be called with `cell==None` when
411 411 invoked as `%f`, and with cell as a string when invoked as `%%f`.
412 412
413 413 Parameters
414 414 ----------
415 415 func : callable
416 416 Function to be registered as a magic.
417 417
418 418 magic_kind : str
419 419 Kind of magic, one of 'line', 'cell' or 'line_cell'
420 420
421 421 magic_name : optional str
422 422 If given, the name the magic will have in the IPython namespace. By
423 423 default, the name of the function itself is used.
424 424 """
425 425
426 426 # Create the new method in the user_magics and register it in the
427 427 # global table
428 428 validate_type(magic_kind)
429 429 magic_name = func.__name__ if magic_name is None else magic_name
430 430 setattr(self.user_magics, magic_name, func)
431 431 record_magic(self.magics, magic_kind, magic_name, func)
432 432
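A usage sketch, where `ip` is assumed to be the running InteractiveShell and `rev` is an invented name::

    def reverse(line):
        return line[::-1]

    ip.magics_manager.register_function(reverse, magic_kind='line',
                                        magic_name='rev')
    # %rev abc   -> 'cba'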
433 433 def define_magic(self, name, func):
434 434 """[Deprecated] Expose own function as magic function for IPython.
435 435
436 436 Example::
437 437
438 438 def foo_impl(self, parameter_s=''):
439 439 'My very own magic! (Use docstrings, IPython reads them).'
440 440 print('Magic function. Passed parameter is between < >:')
441 441 print('<%s>' % parameter_s)
442 442 print('The self object is:', self)
443 443
444 444 ip.define_magic('foo',foo_impl)
445 445 """
446 446 meth = types.MethodType(func, self.user_magics)
447 447 setattr(self.user_magics, name, meth)
448 448 record_magic(self.magics, 'line', name, meth)
449 449
450 450 def register_alias(self, alias_name, magic_name, magic_kind='line'):
451 451 """Register an alias to a magic function.
452 452
453 453 The alias is an instance of :class:`MagicAlias`, which holds the
454 454 name and kind of the magic it should call. Binding is done at
455 455 call time, so if the underlying magic function is changed the alias
456 456 will call the new function.
457 457
458 458 Parameters
459 459 ----------
460 460 alias_name : str
461 461 The name of the magic to be registered.
462 462
463 463 magic_name : str
464 464 The name of an existing magic.
465 465
466 466 magic_kind : str
467 467 Kind of magic, one of 'line' or 'cell'
468 468 """
469 469
470 470 # `validate_type` is too permissive, as it allows 'line_cell'
471 471 # which we do not handle.
472 472 if magic_kind not in magic_kinds:
473 473 raise ValueError('magic_kind must be one of %s, %s given' %
474 474 (magic_kinds, magic_kind))
475 475
476 476 alias = MagicAlias(self.shell, magic_name, magic_kind)
477 477 setattr(self.user_magics, alias_name, alias)
478 478 record_magic(self.magics, magic_kind, alias_name, alias)
479 479
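For instance, to make `%t` dispatch to `%time` (a sketch, with `ip` as above)::

    ip.magics_manager.register_alias('t', 'time')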
480 480 # Key base class that provides the central functionality for magics.
481 481
482 482
483 483 class Magics(Configurable):
484 484 """Base class for implementing magic functions.
485 485
486 486 Shell functions which can be reached as %function_name. All magic
487 487 functions should accept a string, which they can parse for their own
488 488 needs. This can make some functions easier to type, e.g. `%cd ../`
489 489 vs. `%cd("../")`
490 490
491 491 Classes providing magic functions need to subclass this class, and they
492 492 MUST:
493 493
494 494 - Use the method decorators `@line_magic` and `@cell_magic` to decorate
495 495 individual methods as magic functions, AND
496 496
497 497 - Use the class decorator `@magics_class` to ensure that the magic
498 498 methods are properly registered at the instance level upon instance
499 499 initialization.
500 500
501 501 See :mod:`magic_functions` for examples of actual implementation classes.
502 502 """
503 503 # Dict holding all command-line options for each magic.
504 504 options_table = None
505 505 # Dict for the mapping of magic names to methods, set by class decorator
506 506 magics = None
507 507 # Flag to check that the class decorator was properly applied
508 508 registered = False
509 509 # Instance of IPython shell
510 510 shell = None
511 511
512 512 def __init__(self, shell=None, **kwargs):
513 513 if not(self.__class__.registered):
514 514 raise ValueError('Magics subclass without registration - '
515 515 'did you forget to apply @magics_class?')
516 516 if shell is not None:
517 517 if hasattr(shell, 'configurables'):
518 518 shell.configurables.append(self)
519 519 if hasattr(shell, 'config'):
520 520 kwargs.setdefault('parent', shell)
521 521 kwargs['shell'] = shell
522 522
523 523 self.shell = shell
524 524 self.options_table = {}
525 525 # The method decorators are run when the instance doesn't exist yet, so
526 526 # they can only record the names of the methods they are supposed to
527 527 # grab. Only now, that the instance exists, can we create the proper
528 528 # mapping to bound methods. So we read the info off the original names
529 529 # table and replace each method name by the actual bound method.
530 530 # But we mustn't clobber the *class* mapping, in case of multiple instances.
531 531 class_magics = self.magics
532 532 self.magics = {}
533 533 for mtype in magic_kinds:
534 534 tab = self.magics[mtype] = {}
535 535 cls_tab = class_magics[mtype]
536 536 for magic_name, meth_name in iteritems(cls_tab):
537 537 if isinstance(meth_name, string_types):
538 538 # it's a method name, grab it
539 539 tab[magic_name] = getattr(self, meth_name)
540 540 else:
541 541 # it's the real thing
542 542 tab[magic_name] = meth_name
543 543 # Configurable **needs** to be initiated at the end or the config
544 544 # magics get screwed up.
545 545 super(Magics, self).__init__(**kwargs)
546 546
547 547 def arg_err(self,func):
548 548 """Print docstring if incorrect arguments were passed"""
549 549 print('Error in arguments:')
550 550 print(oinspect.getdoc(func))
551 551
552 552 def format_latex(self, strng):
553 553 """Format a string for latex inclusion."""
554 554
555 555 # Characters that need to be escaped for latex:
556 556 escape_re = re.compile(r'(%|_|\$|#|&)',re.MULTILINE)
557 557 # Magic command names as headers:
558 558 cmd_name_re = re.compile(r'^(%s.*?):' % ESC_MAGIC,
559 559 re.MULTILINE)
560 560 # Magic commands
561 561 cmd_re = re.compile(r'(?P<cmd>%s.+?\b)(?!\}\}:)' % ESC_MAGIC,
562 562 re.MULTILINE)
563 563 # Paragraph continue
564 564 par_re = re.compile(r'\\$',re.MULTILINE)
565 565
566 566 # The "\n" symbol
567 567 newline_re = re.compile(r'\\n')
568 568
569 569 # Now build the string for output:
570 570 #strng = cmd_name_re.sub(r'\n\\texttt{\\textsl{\\large \1}}:',strng)
571 571 strng = cmd_name_re.sub(r'\n\\bigskip\n\\texttt{\\textbf{ \1}}:',
572 572 strng)
573 573 strng = cmd_re.sub(r'\\texttt{\g<cmd>}',strng)
574 574 strng = par_re.sub(r'\\\\',strng)
575 575 strng = escape_re.sub(r'\\\1',strng)
576 576 strng = newline_re.sub(r'\\textbackslash{}n',strng)
577 577 return strng
578 578
579 579 def parse_options(self, arg_str, opt_str, *long_opts, **kw):
580 580 """Parse options passed to an argument string.
581 581
582 582 The interface is similar to that of :func:`getopt.getopt`, but it
583 583 returns a :class:`~IPython.utils.struct.Struct` with the options as keys
584 584 and the stripped argument string still as a string.
585 585
586 586 arg_str is split into a true sys.argv vector by using shlex.split.
587 587 This allows us to easily expand variables, glob files, quote
588 588 arguments, etc.
589 589
590 590 Parameters
591 591 ----------
592 592
593 593 arg_str : str
594 594 The arguments to parse.
595 595
596 596 opt_str : str
597 597 The options specification.
598 598
599 599 mode : str, default 'string'
600 600 If given as 'list', the argument string is returned as a list (split
601 601 on whitespace) instead of a string.
602 602
603 603 list_all : bool, default False
604 604 Put all option values in lists. Normally only options
605 605 appearing more than once are put in a list.
606 606
607 607 posix : bool, default True
608 608 Whether to split the input line in POSIX mode or not, as per the
609 609 conventions outlined in the :mod:`shlex` module from the standard
610 610 library.
611 611 """
612 612
613 613 # inject default options at the beginning of the input line
614 614 caller = sys._getframe(1).f_code.co_name
615 615 arg_str = '%s %s' % (self.options_table.get(caller,''),arg_str)
616 616
617 617 mode = kw.get('mode','string')
618 618 if mode not in ['string','list']:
619 619 raise ValueError('incorrect mode given: %s' % mode)
620 620 # Get options
621 621 list_all = kw.get('list_all',0)
622 622 posix = kw.get('posix', os.name == 'posix')
623 623 strict = kw.get('strict', True)
624 624
625 625 # Check if we have more than one argument to warrant extra processing:
626 626 odict = {} # Dictionary with options
627 627 args = arg_str.split()
628 628 if len(args) >= 1:
629 629 # If the list of inputs only has 0 or 1 thing in it, there's no
630 630 # need to look for options
631 631 argv = arg_split(arg_str, posix, strict)
632 632 # Do regular option processing
633 633 try:
634 634 opts,args = getopt(argv, opt_str, long_opts)
635 635 except GetoptError as e:
636 636 raise UsageError('%s ( allowed: "%s" %s)' % (e.msg,opt_str,
637 637 " ".join(long_opts)))
638 638 for o,a in opts:
639 639 if o.startswith('--'):
640 640 o = o[2:]
641 641 else:
642 642 o = o[1:]
643 643 try:
644 644 odict[o].append(a)
645 645 except AttributeError:
646 646 odict[o] = [odict[o],a]
647 647 except KeyError:
648 648 if list_all:
649 649 odict[o] = [a]
650 650 else:
651 651 odict[o] = a
652 652
653 653 # Prepare opts,args for return
654 654 opts = Struct(odict)
655 655 if mode == 'string':
656 656 args = ' '.join(args)
657 657
658 658 return opts,args
659 659
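# [editor's note] A hedged usage sketch (not part of the diff). Inside a
# line magic receiving, say, line = "-n 5 --verbose some args", a call like:
#
#   opts, args = self.parse_options(line, 'n:', 'verbose')
#
# yields (getopt-style: 'n:' means -n takes a value, 'verbose' is a flag):
#
#   opts.get('n')      # -> '5'
#   'verbose' in opts  # -> True (flags map to the empty string)
#   args               # -> 'some args' (a string; pass mode='list' for a list)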
660 660 def default_option(self, fn, optstr):
661 661 """Make an entry in the options_table for fn, with value optstr"""
662 662
663 663 if fn not in self.lsmagic():
664 664 error("%s is not a magic function" % fn)
665 665 self.options_table[fn] = optstr
666 666
667 667
668 668 class MagicAlias(object):
669 669 """An alias to another magic function.
670 670
671 671 An alias is determined by its magic name and magic kind. Lookup
672 672 is done at call time, so if the underlying magic changes the alias
673 673 will call the new function.
674 674
675 675 Use the :meth:`MagicsManager.register_alias` method or the
676 676 `%alias_magic` magic function to create and register a new alias.
677 677 """
678 678 def __init__(self, shell, magic_name, magic_kind):
679 679 self.shell = shell
680 680 self.magic_name = magic_name
681 681 self.magic_kind = magic_kind
682 682
683 683 self.pretty_target = '%s%s' % (magic_escapes[self.magic_kind], self.magic_name)
684 684 self.__doc__ = "Alias for `%s`." % self.pretty_target
685 685
686 686 self._in_call = False
687 687
688 688 def __call__(self, *args, **kwargs):
689 689 """Call the magic alias."""
690 690 fn = self.shell.find_magic(self.magic_name, self.magic_kind)
691 691 if fn is None:
692 692 raise UsageError("Magic `%s` not found." % self.pretty_target)
693 693
694 694 # Protect against infinite recursion.
695 695 if self._in_call:
696 696 raise UsageError("Infinite recursion detected; "
697 697 "magic aliases cannot call themselves.")
698 698 self._in_call = True
699 699 try:
700 700 return fn(*args, **kwargs)
701 701 finally:
702 702 self._in_call = False
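# [editor's note] A minimal sketch (not part of the diff) of creating such an
# alias; normally done via %alias_magic or MagicsManager.register_alias:
#
#   ip = get_ipython()
#   ip.magics_manager.register_alias('t', 'timeit', magic_kind='line')
#
# %t now resolves %timeit at call time, so if %timeit is later re-registered,
# the alias transparently picks up the new implementation.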
@@ -1,382 +1,382 b''
1 1 """Tests for the object inspection functionality.
2 2 """
3 3 #-----------------------------------------------------------------------------
4 4 # Copyright (C) 2010-2011 The IPython Development Team.
5 5 #
6 6 # Distributed under the terms of the BSD License.
7 7 #
8 8 # The full license is in the file COPYING.txt, distributed with this software.
9 9 #-----------------------------------------------------------------------------
10 10
11 11 #-----------------------------------------------------------------------------
12 12 # Imports
13 13 #-----------------------------------------------------------------------------
14 14 from __future__ import print_function
15 15
16 16 # Stdlib imports
17 17 import os
18 18 import re
19 19
20 20 # Third-party imports
21 21 import nose.tools as nt
22 22
23 23 # Our own imports
24 24 from .. import oinspect
25 25 from IPython.core.magic import (Magics, magics_class, line_magic,
26 26 cell_magic, line_cell_magic,
27 27 register_line_magic, register_cell_magic,
28 28 register_line_cell_magic)
29 from IPython.external.decorator import decorator
29 from decorator import decorator
30 30 from IPython.testing.decorators import skipif
31 31 from IPython.testing.tools import AssertPrints
32 32 from IPython.utils.path import compress_user
33 33 from IPython.utils import py3compat
34 34
35 35
36 36 #-----------------------------------------------------------------------------
37 37 # Globals and constants
38 38 #-----------------------------------------------------------------------------
39 39
40 40 inspector = oinspect.Inspector()
41 41 ip = get_ipython()
42 42
43 43 #-----------------------------------------------------------------------------
44 44 # Local utilities
45 45 #-----------------------------------------------------------------------------
46 46
47 47 # WARNING: since this test checks the line number where a function is
48 48 # defined, if any code is inserted above, the following line will need to be
49 49 # updated. Do NOT insert any whitespace between the next line and the function
50 50 # definition below.
51 51 THIS_LINE_NUMBER = 51 # Put here the actual number of this line
52 52 def test_find_source_lines():
53 53 nt.assert_equal(oinspect.find_source_lines(test_find_source_lines),
54 54 THIS_LINE_NUMBER+1)
55 55
56 56
57 57 # A couple of utilities to ensure these tests work the same from a source or a
58 58 # binary install
59 59 def pyfile(fname):
60 60 return os.path.normcase(re.sub('.py[co]$', '.py', fname))
61 61
62 62
63 63 def match_pyfiles(f1, f2):
64 64 nt.assert_equal(pyfile(f1), pyfile(f2))
65 65
66 66
67 67 def test_find_file():
68 68 match_pyfiles(oinspect.find_file(test_find_file), os.path.abspath(__file__))
69 69
70 70
71 71 def test_find_file_decorated1():
72 72
73 73 @decorator
74 74 def noop1(f):
75 75         def wrapper(*a, **kw):
76 76             return f(*a, **kw)
77 77 return wrapper
78 78
79 79 @noop1
80 80 def f(x):
81 81 "My docstring"
82 82
83 83 match_pyfiles(oinspect.find_file(f), os.path.abspath(__file__))
84 84 nt.assert_equal(f.__doc__, "My docstring")
85 85
86 86
87 87 def test_find_file_decorated2():
88 88
89 89 @decorator
90 90 def noop2(f, *a, **kw):
91 91 return f(*a, **kw)
92 92
93 93 @noop2
94 94 def f(x):
95 95 "My docstring 2"
96 96
97 97 match_pyfiles(oinspect.find_file(f), os.path.abspath(__file__))
98 98 nt.assert_equal(f.__doc__, "My docstring 2")
99 99
100 100
101 101 def test_find_file_magic():
102 102 run = ip.find_line_magic('run')
103 103 nt.assert_not_equal(oinspect.find_file(run), None)
104 104
105 105
106 106 # A few generic objects we can then inspect in the tests below
107 107
108 108 class Call(object):
109 109 """This is the class docstring."""
110 110
111 111 def __init__(self, x, y=1):
112 112 """This is the constructor docstring."""
113 113
114 114 def __call__(self, *a, **kw):
115 115 """This is the call docstring."""
116 116
117 117 def method(self, x, z=2):
118 118 """Some method's docstring"""
119 119
120 120 class SimpleClass(object):
121 121 def method(self, x, z=2):
122 122 """Some method's docstring"""
123 123
124 124
125 125 class OldStyle:
126 126 """An old-style class for testing."""
127 127 pass
128 128
129 129
130 130 def f(x, y=2, *a, **kw):
131 131 """A simple function."""
132 132
133 133
134 134 def g(y, z=3, *a, **kw):
135 135 pass # no docstring
136 136
137 137
138 138 @register_line_magic
139 139 def lmagic(line):
140 140 "A line magic"
141 141
142 142
143 143 @register_cell_magic
144 144 def cmagic(line, cell):
145 145 "A cell magic"
146 146
147 147
148 148 @register_line_cell_magic
149 149 def lcmagic(line, cell=None):
150 150 "A line/cell magic"
151 151
152 152
153 153 @magics_class
154 154 class SimpleMagics(Magics):
155 155 @line_magic
156 156 def Clmagic(self, cline):
157 157 "A class-based line magic"
158 158
159 159 @cell_magic
160 160 def Ccmagic(self, cline, ccell):
161 161 "A class-based cell magic"
162 162
163 163 @line_cell_magic
164 164 def Clcmagic(self, cline, ccell=None):
165 165 "A class-based line/cell magic"
166 166
167 167
168 168 class Awkward(object):
169 169 def __getattr__(self, name):
170 170 raise Exception(name)
171 171
172 172
173 173 def check_calltip(obj, name, call, docstring):
174 174 """Generic check pattern all calltip tests will use"""
175 175 info = inspector.info(obj, name)
176 176 call_line, ds = oinspect.call_tip(info)
177 177 nt.assert_equal(call_line, call)
178 178 nt.assert_equal(ds, docstring)
179 179
180 180 #-----------------------------------------------------------------------------
181 181 # Tests
182 182 #-----------------------------------------------------------------------------
183 183
184 184 def test_calltip_class():
185 185 check_calltip(Call, 'Call', 'Call(x, y=1)', Call.__init__.__doc__)
186 186
187 187
188 188 def test_calltip_instance():
189 189 c = Call(1)
190 190 check_calltip(c, 'c', 'c(*a, **kw)', c.__call__.__doc__)
191 191
192 192
193 193 def test_calltip_method():
194 194 c = Call(1)
195 195 check_calltip(c.method, 'c.method', 'c.method(x, z=2)', c.method.__doc__)
196 196
197 197
198 198 def test_calltip_function():
199 199 check_calltip(f, 'f', 'f(x, y=2, *a, **kw)', f.__doc__)
200 200
201 201
202 202 def test_calltip_function2():
203 203 check_calltip(g, 'g', 'g(y, z=3, *a, **kw)', '<no docstring>')
204 204
205 205
206 206 def test_calltip_builtin():
207 207 check_calltip(sum, 'sum', None, sum.__doc__)
208 208
209 209
210 210 def test_calltip_line_magic():
211 211 check_calltip(lmagic, 'lmagic', 'lmagic(line)', "A line magic")
212 212
213 213
214 214 def test_calltip_cell_magic():
215 215 check_calltip(cmagic, 'cmagic', 'cmagic(line, cell)', "A cell magic")
216 216
217 217
218 218 def test_calltip_line_cell_magic():
219 219 check_calltip(lcmagic, 'lcmagic', 'lcmagic(line, cell=None)',
220 220 "A line/cell magic")
221 221
222 222
223 223 def test_class_magics():
224 224 cm = SimpleMagics(ip)
225 225 ip.register_magics(cm)
226 226 check_calltip(cm.Clmagic, 'Clmagic', 'Clmagic(cline)',
227 227 "A class-based line magic")
228 228 check_calltip(cm.Ccmagic, 'Ccmagic', 'Ccmagic(cline, ccell)',
229 229 "A class-based cell magic")
230 230 check_calltip(cm.Clcmagic, 'Clcmagic', 'Clcmagic(cline, ccell=None)',
231 231 "A class-based line/cell magic")
232 232
233 233
234 234 def test_info():
235 235 "Check that Inspector.info fills out various fields as expected."
236 236 i = inspector.info(Call, oname='Call')
237 237 nt.assert_equal(i['type_name'], 'type')
238 238     expected_class = str(type(type))  # <class 'type'> (Python 3) or <type 'type'> (Python 2)
239 239     nt.assert_equal(i['base_class'], expected_class)
240 240 nt.assert_equal(i['string_form'], "<class 'IPython.core.tests.test_oinspect.Call'>")
241 241 fname = __file__
242 242 if fname.endswith(".pyc"):
243 243 fname = fname[:-1]
244 244 # case-insensitive comparison needed on some filesystems
245 245 # e.g. Windows:
246 246 nt.assert_equal(i['file'].lower(), compress_user(fname.lower()))
247 247 nt.assert_equal(i['definition'], None)
248 248 nt.assert_equal(i['docstring'], Call.__doc__)
249 249 nt.assert_equal(i['source'], None)
250 250 nt.assert_true(i['isclass'])
251 251 nt.assert_equal(i['init_definition'], "Call(self, x, y=1)\n")
252 252 nt.assert_equal(i['init_docstring'], Call.__init__.__doc__)
253 253
254 254 i = inspector.info(Call, detail_level=1)
255 255 nt.assert_not_equal(i['source'], None)
256 256 nt.assert_equal(i['docstring'], None)
257 257
258 258 c = Call(1)
259 259 c.__doc__ = "Modified instance docstring"
260 260 i = inspector.info(c)
261 261 nt.assert_equal(i['type_name'], 'Call')
262 262 nt.assert_equal(i['docstring'], "Modified instance docstring")
263 263 nt.assert_equal(i['class_docstring'], Call.__doc__)
264 264 nt.assert_equal(i['init_docstring'], Call.__init__.__doc__)
265 265 nt.assert_equal(i['call_docstring'], Call.__call__.__doc__)
266 266
267 267 # Test old-style classes, which for example may not have an __init__ method.
268 268 if not py3compat.PY3:
269 269 i = inspector.info(OldStyle)
270 270 nt.assert_equal(i['type_name'], 'classobj')
271 271
272 272 i = inspector.info(OldStyle())
273 273 nt.assert_equal(i['type_name'], 'instance')
274 274 nt.assert_equal(i['docstring'], OldStyle.__doc__)
275 275
276 276 def test_info_awkward():
277 277 # Just test that this doesn't throw an error.
278 278 i = inspector.info(Awkward())
279 279
280 280 def test_calldef_none():
281 281 # We should ignore __call__ for all of these.
282 282 for obj in [f, SimpleClass().method, any, str.upper]:
283 283 print(obj)
284 284 i = inspector.info(obj)
285 285 nt.assert_is(i['call_def'], None)
286 286
287 287 if py3compat.PY3:
288 288 exec("def f_kwarg(pos, *, kwonly): pass")
289 289
290 290 @skipif(not py3compat.PY3)
291 291 def test_definition_kwonlyargs():
292 292 i = inspector.info(f_kwarg, oname='f_kwarg') # analysis:ignore
293 293 nt.assert_equal(i['definition'], "f_kwarg(pos, *, kwonly)\n")
294 294
295 295 def test_getdoc():
296 296 class A(object):
297 297 """standard docstring"""
298 298 pass
299 299
300 300 class B(object):
301 301 """standard docstring"""
302 302 def getdoc(self):
303 303 return "custom docstring"
304 304
305 305 class C(object):
306 306 """standard docstring"""
307 307 def getdoc(self):
308 308 return None
309 309
310 310 a = A()
311 311 b = B()
312 312 c = C()
313 313
314 314 nt.assert_equal(oinspect.getdoc(a), "standard docstring")
315 315 nt.assert_equal(oinspect.getdoc(b), "custom docstring")
316 316 nt.assert_equal(oinspect.getdoc(c), "standard docstring")
317 317
318 318
319 319 def test_empty_property_has_no_source():
320 320 i = inspector.info(property(), detail_level=1)
321 321 nt.assert_is(i['source'], None)
322 322
323 323
324 324 def test_property_sources():
325 325 import zlib
326 326
327 327 class A(object):
328 328 @property
329 329 def foo(self):
330 330 return 'bar'
331 331
332 332 foo = foo.setter(lambda self, v: setattr(self, 'bar', v))
333 333
334 334 id = property(id)
335 335 compress = property(zlib.compress)
336 336
337 337 i = inspector.info(A.foo, detail_level=1)
338 338 nt.assert_in('def foo(self):', i['source'])
339 339 nt.assert_in('lambda self, v:', i['source'])
340 340
341 341 i = inspector.info(A.id, detail_level=1)
342 342 nt.assert_in('fget = <function id>', i['source'])
343 343
344 344 i = inspector.info(A.compress, detail_level=1)
345 345 nt.assert_in('fget = <function zlib.compress>', i['source'])
346 346
347 347
348 348 def test_property_docstring_is_in_info_for_detail_level_0():
349 349 class A(object):
350 350 @property
351 351 def foobar():
352 352 """This is `foobar` property."""
353 353 pass
354 354
355 355 ip.user_ns['a_obj'] = A()
356 356     nt.assert_equal(
357 357 'This is `foobar` property.',
358 358 ip.object_inspect('a_obj.foobar', detail_level=0)['docstring'])
359 359
360 360 ip.user_ns['a_cls'] = A
361 361     nt.assert_equal(
362 362 'This is `foobar` property.',
363 363 ip.object_inspect('a_cls.foobar', detail_level=0)['docstring'])
364 364
365 365
366 366 def test_pdef():
367 367 # See gh-1914
368 368 def foo(): pass
369 369 inspector.pdef(foo, 'foo')
370 370
371 371 def test_pinfo_nonascii():
372 372 # See gh-1177
373 373 from . import nonascii2
374 374 ip.user_ns['nonascii2'] = nonascii2
375 375 ip._inspect('pinfo', 'nonascii2', detail_level=1)
376 376
377 377 def test_pinfo_magic():
378 378 with AssertPrints('Docstring:'):
379 379 ip._inspect('pinfo', 'lsmagic', detail_level=0)
380 380
381 381 with AssertPrints('Source:'):
382 382 ip._inspect('pinfo', 'lsmagic', detail_level=1)
@@ -1,703 +1,703 b''
1 1 """AsyncResult objects for the client"""
2 2
3 3 # Copyright (c) IPython Development Team.
4 4 # Distributed under the terms of the Modified BSD License.
5 5
6 6 from __future__ import print_function
7 7
8 8 import sys
9 9 import time
10 10 from datetime import datetime
11 11
12 12 from zmq import MessageTracker
13 13
14 14 from IPython.core.display import clear_output, display, display_pretty
15 from IPython.external.decorator import decorator
15 from decorator import decorator
16 16 from IPython.parallel import error
17 17 from IPython.utils.py3compat import string_types
18 18
19 19
20 20 def _raw_text(s):
21 21 display_pretty(s, raw=True)
22 22
23 23
24 24 # global empty tracker that's always done:
25 25 finished_tracker = MessageTracker()
26 26
27 27 @decorator
28 28 def check_ready(f, self, *args, **kwargs):
29 29 """Call spin() to sync state prior to calling the method."""
30 30 self.wait(0)
31 31 if not self._ready:
32 32 raise error.TimeoutError("result not ready")
33 33 return f(self, *args, **kwargs)
34 34
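# [editor's note] An illustrative sketch (not part of the diff). Because the
# `decorator` package preserves the wrapped signature, check_ready can guard
# any AsyncResult method or property getter; `engine_uuids` is hypothetical:
#
#   @property
#   @check_ready
#   def engine_uuids(self):
#       return [md['engine_uuid'] for md in self._metadata]
#
# Accessing it before the results arrive raises error.TimeoutError instead
# of exposing incomplete metadata.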
35 35 class AsyncResult(object):
36 36 """Class for representing results of non-blocking calls.
37 37
38 38 Provides the same interface as :py:class:`multiprocessing.pool.AsyncResult`.
39 39 """
40 40
41 41 msg_ids = None
42 42 _targets = None
43 43 _tracker = None
44 44 _single_result = False
45 45     owner = False
46 46
47 47 def __init__(self, client, msg_ids, fname='unknown', targets=None, tracker=None,
48 48 owner=False,
49 49 ):
50 50 if isinstance(msg_ids, string_types):
51 51 # always a list
52 52 msg_ids = [msg_ids]
53 53 self._single_result = True
54 54 else:
55 55 self._single_result = False
56 56 if tracker is None:
57 57 # default to always done
58 58 tracker = finished_tracker
59 59 self._client = client
60 60 self.msg_ids = msg_ids
61 61 self._fname=fname
62 62 self._targets = targets
63 63 self._tracker = tracker
64 64 self.owner = owner
65 65
66 66 self._ready = False
67 67 self._outputs_ready = False
68 68 self._success = None
69 69 self._metadata = [self._client.metadata[id] for id in self.msg_ids]
70 70
71 71 def __repr__(self):
72 72 if self._ready:
73 73 return "<%s: finished>"%(self.__class__.__name__)
74 74 else:
75 75 return "<%s: %s>"%(self.__class__.__name__,self._fname)
76 76
77 77
78 78 def _reconstruct_result(self, res):
79 79 """Reconstruct our result from actual result list (always a list)
80 80
81 81 Override me in subclasses for turning a list of results
82 82 into the expected form.
83 83 """
84 84 if self._single_result:
85 85 return res[0]
86 86 else:
87 87 return res
88 88
89 89 def get(self, timeout=-1):
90 90 """Return the result when it arrives.
91 91
92 92 If `timeout` is not ``None`` and the result does not arrive within
93 93 `timeout` seconds then ``TimeoutError`` is raised. If the
94 94 remote call raised an exception then that exception will be reraised
95 95 by get() inside a `RemoteError`.
96 96 """
97 97 if not self.ready():
98 98 self.wait(timeout)
99 99
100 100 if self._ready:
101 101 if self._success:
102 102 return self._result
103 103 else:
104 104 raise self._exception
105 105 else:
106 106 raise error.TimeoutError("Result not ready.")
107 107
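# [editor's note] A hedged usage sketch (not part of the diff); `view` stands
# for any DirectView/LoadBalancedView obtained from a running Client:
#
#   ar = view.apply_async(pow, 2, 10)
#   try:
#       result = ar.get(timeout=5)    # block for at most 5 seconds
#   except error.TimeoutError:
#       print("still running, %i/%i done" % (ar.progress, len(ar)))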
108 108 def _check_ready(self):
109 109 if not self.ready():
110 110 raise error.TimeoutError("Result not ready.")
111 111
112 112 def ready(self):
113 113 """Return whether the call has completed."""
114 114 if not self._ready:
115 115 self.wait(0)
116 116 elif not self._outputs_ready:
117 117 self._wait_for_outputs(0)
118 118
119 119 return self._ready
120 120
121 121 def wait(self, timeout=-1):
122 122 """Wait until the result is available or until `timeout` seconds pass.
123 123
124 124 This method always returns None.
125 125 """
126 126 if self._ready:
127 127 self._wait_for_outputs(timeout)
128 128 return
129 129 self._ready = self._client.wait(self.msg_ids, timeout)
130 130 if self._ready:
131 131 try:
132 132 results = list(map(self._client.results.get, self.msg_ids))
133 133 self._result = results
134 134 if self._single_result:
135 135 r = results[0]
136 136 if isinstance(r, Exception):
137 137 raise r
138 138 else:
139 139 results = error.collect_exceptions(results, self._fname)
140 140 self._result = self._reconstruct_result(results)
141 141 except Exception as e:
142 142 self._exception = e
143 143 self._success = False
144 144 else:
145 145 self._success = True
146 146 finally:
147 147 if timeout is None or timeout < 0:
148 148 # cutoff infinite wait at 10s
149 149 timeout = 10
150 150 self._wait_for_outputs(timeout)
151 151
152 152 if self.owner:
153 153
154 154 self._metadata = [self._client.metadata.pop(mid) for mid in self.msg_ids]
155 155 [self._client.results.pop(mid) for mid in self.msg_ids]
156 156
157 157
158 158
159 159 def successful(self):
160 160 """Return whether the call completed without raising an exception.
161 161
162 162 Will raise ``AssertionError`` if the result is not ready.
163 163 """
164 164 assert self.ready()
165 165 return self._success
166 166
167 167 #----------------------------------------------------------------
168 168 # Extra methods not in mp.pool.AsyncResult
169 169 #----------------------------------------------------------------
170 170
171 171 def get_dict(self, timeout=-1):
172 172 """Get the results as a dict, keyed by engine_id.
173 173
174 174 timeout behavior is described in `get()`.
175 175 """
176 176
177 177 results = self.get(timeout)
178 178 if self._single_result:
179 179 results = [results]
180 180 engine_ids = [ md['engine_id'] for md in self._metadata ]
181 181
182 182
183 183 rdict = {}
184 184 for engine_id, result in zip(engine_ids, results):
185 185 if engine_id in rdict:
186 186 raise ValueError("Cannot build dict, %i jobs ran on engine #%i" % (
187 187 engine_ids.count(engine_id), engine_id)
188 188 )
189 189 else:
190 190 rdict[engine_id] = result
191 191
192 192 return rdict
193 193
194 194 @property
195 195 def result(self):
196 196 """result property wrapper for `get(timeout=-1)`."""
197 197 return self.get()
198 198
199 199 # abbreviated alias:
200 200 r = result
201 201
202 202 @property
203 203 def metadata(self):
204 204 """property for accessing execution metadata."""
205 205 if self._single_result:
206 206 return self._metadata[0]
207 207 else:
208 208 return self._metadata
209 209
210 210 @property
211 211 def result_dict(self):
212 212 """result property as a dict."""
213 213 return self.get_dict()
214 214
215 215 def __dict__(self):
216 216 return self.get_dict(0)
217 217
218 218 def abort(self):
219 219 """abort my tasks."""
220 220 assert not self.ready(), "Can't abort, I am already done!"
221 221 return self._client.abort(self.msg_ids, targets=self._targets, block=True)
222 222
223 223 @property
224 224 def sent(self):
225 225 """check whether my messages have been sent."""
226 226 return self._tracker.done
227 227
228 228 def wait_for_send(self, timeout=-1):
229 229 """wait for pyzmq send to complete.
230 230
231 231 This is necessary when sending arrays that you intend to edit in-place.
232 232 `timeout` is in seconds, and will raise TimeoutError if it is reached
233 233 before the send completes.
234 234 """
235 235 return self._tracker.wait(timeout)
236 236
237 237 #-------------------------------------
238 238 # dict-access
239 239 #-------------------------------------
240 240
241 241 def __getitem__(self, key):
242 242 """getitem returns result value(s) if keyed by int/slice, or metadata if key is str.
243 243 """
244 244 if isinstance(key, int):
245 245 self._check_ready()
246 246 return error.collect_exceptions([self._result[key]], self._fname)[0]
247 247 elif isinstance(key, slice):
248 248 self._check_ready()
249 249 return error.collect_exceptions(self._result[key], self._fname)
250 250 elif isinstance(key, string_types):
251 251 # metadata proxy *does not* require that results are done
252 252 self.wait(0)
253 253 values = [ md[key] for md in self._metadata ]
254 254 if self._single_result:
255 255 return values[0]
256 256 else:
257 257 return values
258 258 else:
259 259 raise TypeError("Invalid key type %r, must be 'int','slice', or 'str'"%type(key))
260 260
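# [editor's note] An illustrative sketch (not part of the diff) of the three
# key types accepted above, for some AsyncResult `ar`:
#
#   ar[0]             # first result value (waits until it is ready)
#   ar[:2]            # list of the first two result values
#   ar['engine_id']   # metadata value(s); available before results are
#   ar.engine_id      # the same, via the __getattr__ fallback below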
261 261 def __getattr__(self, key):
262 262 """getattr maps to getitem for convenient attr access to metadata."""
263 263 try:
264 264 return self.__getitem__(key)
265 265 except (error.TimeoutError, KeyError):
266 266 raise AttributeError("%r object has no attribute %r"%(
267 267 self.__class__.__name__, key))
268 268
269 269 # asynchronous iterator:
270 270 def __iter__(self):
271 271 if self._single_result:
272 272 raise TypeError("AsyncResults with a single result are not iterable.")
273 273 try:
274 274 rlist = self.get(0)
275 275 except error.TimeoutError:
276 276 # wait for each result individually
277 277 for msg_id in self.msg_ids:
278 278 ar = AsyncResult(self._client, msg_id, self._fname)
279 279 yield ar.get()
280 280 else:
281 281 # already done
282 282 for r in rlist:
283 283 yield r
284 284
285 285 def __len__(self):
286 286 return len(self.msg_ids)
287 287
288 288 #-------------------------------------
289 289 # Sugar methods and attributes
290 290 #-------------------------------------
291 291
292 292 def timedelta(self, start, end, start_key=min, end_key=max):
293 293 """compute the difference between two sets of timestamps
294 294
295 295 The default behavior is to use the earliest of the first
296 296 and the latest of the second list, but this can be changed
297 297         by passing different start_key/end_key functions.
298 298
299 299 Parameters
300 300 ----------
301 301
302 302 start : one or more datetime objects (e.g. ar.submitted)
303 303 end : one or more datetime objects (e.g. ar.received)
304 304 start_key : callable
305 305 Function to call on `start` to extract the relevant
306 306             entry [default: min]
307 307 end_key : callable
308 308 Function to call on `end` to extract the relevant
309 309 entry [default: max]
310 310
311 311 Returns
312 312 -------
313 313
314 314 dt : float
315 315 The time elapsed (in seconds) between the two selected timestamps.
316 316 """
317 317 if not isinstance(start, datetime):
318 318 # handle single_result AsyncResults, where ar.stamp is single object,
319 319 # not a list
320 320 start = start_key(start)
321 321 if not isinstance(end, datetime):
322 322 # handle single_result AsyncResults, where ar.stamp is single object,
323 323 # not a list
324 324 end = end_key(end)
325 325 return (end - start).total_seconds()
326 326
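# [editor's note] A hedged sketch (not part of the diff): the metadata
# timestamps pair naturally with this helper, e.g. for a finished `ar`:
#
#   queue_wait   = ar.timedelta(ar.submitted, ar.started)    # time in queue
#   compute_time = ar.timedelta(ar.started, ar.completed)    # time computing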
327 327 @property
328 328 def progress(self):
329 329 """the number of tasks which have been completed at this point.
330 330
331 331 Fractional progress would be given by 1.0 * ar.progress / len(ar)
332 332 """
333 333 self.wait(0)
334 334 return len(self) - len(set(self.msg_ids).intersection(self._client.outstanding))
335 335
336 336 @property
337 337 def elapsed(self):
338 338 """elapsed time since initial submission"""
339 339 if self.ready():
340 340 return self.wall_time
341 341
342 342 now = submitted = datetime.now()
343 343 for msg_id in self.msg_ids:
344 344 if msg_id in self._client.metadata:
345 345 stamp = self._client.metadata[msg_id]['submitted']
346 346 if stamp and stamp < submitted:
347 347 submitted = stamp
348 348 return (now-submitted).total_seconds()
349 349
350 350 @property
351 351 @check_ready
352 352 def serial_time(self):
353 353 """serial computation time of a parallel calculation
354 354
355 355 Computed as the sum of (completed-started) of each task
356 356 """
357 357 t = 0
358 358 for md in self._metadata:
359 359 t += (md['completed'] - md['started']).total_seconds()
360 360 return t
361 361
362 362 @property
363 363 @check_ready
364 364 def wall_time(self):
365 365 """actual computation time of a parallel calculation
366 366
367 367 Computed as the time between the latest `received` stamp
368 368 and the earliest `submitted`.
369 369
370 370 Only reliable if Client was spinning/waiting when the task finished, because
371 371 the `received` timestamp is created when a result is pulled off of the zmq queue,
372 372 which happens as a result of `client.spin()`.
373 373
374 374 For similar comparison of other timestamp pairs, check out AsyncResult.timedelta.
375 375
376 376 """
377 377 return self.timedelta(self.submitted, self.received)
378 378
379 379 def wait_interactive(self, interval=1., timeout=-1):
380 380 """interactive wait, printing progress at regular intervals"""
381 381 if timeout is None:
382 382 timeout = -1
383 383 N = len(self)
384 384 tic = time.time()
385 385 while not self.ready() and (timeout < 0 or time.time() - tic <= timeout):
386 386 self.wait(interval)
387 387 clear_output(wait=True)
388 388 print("%4i/%i tasks finished after %4i s" % (self.progress, N, self.elapsed), end="")
389 389 sys.stdout.flush()
390 390 print()
391 391 print("done")
392 392
393 393 def _republish_displaypub(self, content, eid):
394 394 """republish individual displaypub content dicts"""
395 395 try:
396 396 ip = get_ipython()
397 397 except NameError:
398 398 # displaypub is meaningless outside IPython
399 399 return
400 400 md = content['metadata'] or {}
401 401 md['engine'] = eid
402 402 ip.display_pub.publish(data=content['data'], metadata=md)
403 403
404 404 def _display_stream(self, text, prefix='', file=None):
405 405 if not text:
406 406 # nothing to display
407 407 return
408 408 if file is None:
409 409 file = sys.stdout
410 410 end = '' if text.endswith('\n') else '\n'
411 411
412 412 multiline = text.count('\n') > int(text.endswith('\n'))
413 413 if prefix and multiline and not text.startswith('\n'):
414 414 prefix = prefix + '\n'
415 415 print("%s%s" % (prefix, text), file=file, end=end)
416 416
417 417
418 418 def _display_single_result(self):
419 419 self._display_stream(self.stdout)
420 420 self._display_stream(self.stderr, file=sys.stderr)
421 421
422 422 try:
423 423 get_ipython()
424 424 except NameError:
425 425 # displaypub is meaningless outside IPython
426 426 return
427 427
428 428 for output in self.outputs:
429 429 self._republish_displaypub(output, self.engine_id)
430 430
431 431 if self.execute_result is not None:
432 432 display(self.get())
433 433
434 434 def _wait_for_outputs(self, timeout=-1):
435 435 """wait for the 'status=idle' message that indicates we have all outputs
436 436 """
437 437 if self._outputs_ready or not self._success:
438 438 # don't wait on errors
439 439 return
440 440
441 441 # cast None to -1 for infinite timeout
442 442 if timeout is None:
443 443 timeout = -1
444 444
445 445 tic = time.time()
446 446 while True:
447 447 self._client._flush_iopub(self._client._iopub_socket)
448 448 self._outputs_ready = all(md['outputs_ready']
449 449 for md in self._metadata)
450 450 if self._outputs_ready or \
451 451 (timeout >= 0 and time.time() > tic + timeout):
452 452 break
453 453 time.sleep(0.01)
454 454
455 455 @check_ready
456 456 def display_outputs(self, groupby="type"):
457 457 """republish the outputs of the computation
458 458
459 459 Parameters
460 460 ----------
461 461
462 462 groupby : str [default: type]
463 463 if 'type':
464 464 Group outputs by type (show all stdout, then all stderr, etc.):
465 465
466 466 [stdout:1] foo
467 467 [stdout:2] foo
468 468 [stderr:1] bar
469 469 [stderr:2] bar
470 470 if 'engine':
471 471 Display outputs for each engine before moving on to the next:
472 472
473 473 [stdout:1] foo
474 474 [stderr:1] bar
475 475 [stdout:2] foo
476 476 [stderr:2] bar
477 477
478 478 if 'order':
479 479 Like 'type', but further collate individual displaypub
480 480 outputs. This is meant for cases of each command producing
481 481 several plots, and you would like to see all of the first
482 482 plots together, then all of the second plots, and so on.
483 483 """
484 484 if self._single_result:
485 485 self._display_single_result()
486 486 return
487 487
488 488 stdouts = self.stdout
489 489 stderrs = self.stderr
490 490 execute_results = self.execute_result
491 491 output_lists = self.outputs
492 492 results = self.get()
493 493
494 494 targets = self.engine_id
495 495
496 496 if groupby == "engine":
497 497 for eid,stdout,stderr,outputs,r,execute_result in zip(
498 498 targets, stdouts, stderrs, output_lists, results, execute_results
499 499 ):
500 500 self._display_stream(stdout, '[stdout:%i] ' % eid)
501 501 self._display_stream(stderr, '[stderr:%i] ' % eid, file=sys.stderr)
502 502
503 503 try:
504 504 get_ipython()
505 505 except NameError:
506 506 # displaypub is meaningless outside IPython
507 507 return
508 508
509 509 if outputs or execute_result is not None:
510 510 _raw_text('[output:%i]' % eid)
511 511
512 512 for output in outputs:
513 513 self._republish_displaypub(output, eid)
514 514
515 515 if execute_result is not None:
516 516 display(r)
517 517
518 518 elif groupby in ('type', 'order'):
519 519 # republish stdout:
520 520 for eid,stdout in zip(targets, stdouts):
521 521 self._display_stream(stdout, '[stdout:%i] ' % eid)
522 522
523 523 # republish stderr:
524 524 for eid,stderr in zip(targets, stderrs):
525 525 self._display_stream(stderr, '[stderr:%i] ' % eid, file=sys.stderr)
526 526
527 527 try:
528 528 get_ipython()
529 529 except NameError:
530 530 # displaypub is meaningless outside IPython
531 531 return
532 532
533 533 if groupby == 'order':
534 534 output_dict = dict((eid, outputs) for eid,outputs in zip(targets, output_lists))
535 535 N = max(len(outputs) for outputs in output_lists)
536 536 for i in range(N):
537 537 for eid in targets:
538 538 outputs = output_dict[eid]
539 539                         if len(outputs) > i:
540 540 _raw_text('[output:%i]' % eid)
541 541 self._republish_displaypub(outputs[i], eid)
542 542 else:
543 543 # republish displaypub output
544 544 for eid,outputs in zip(targets, output_lists):
545 545 if outputs:
546 546 _raw_text('[output:%i]' % eid)
547 547 for output in outputs:
548 548 self._republish_displaypub(output, eid)
549 549
550 550 # finally, add execute_result:
551 551 for eid,r,execute_result in zip(targets, results, execute_results):
552 552 if execute_result is not None:
553 553 display(r)
554 554
555 555 else:
556 556             raise ValueError("groupby must be one of 'type', 'engine', 'order', not %r" % groupby)
557 557
558 558
559 559
560 560
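# [editor's note] A minimal usage sketch (not part of the diff); `dview`
# stands for a DirectView on several engines:
#
#   ar = dview.execute("print('hello')", block=False)
#   ar.get()
#   ar.display_outputs(groupby='engine')   # stdout/stderr grouped per engine
#   ar.display_outputs(groupby='order')    # nth outputs collated together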
561 561 class AsyncMapResult(AsyncResult):
562 562 """Class for representing results of non-blocking gathers.
563 563
564 564 This will properly reconstruct the gather.
565 565
566 566 This class is iterable at any time, and will wait on results as they come.
567 567
568 568 If ordered=False, then the first results to arrive will come first, otherwise
569 569 results will be yielded in the order they were submitted.
570 570
571 571 """
572 572
573 573 def __init__(self, client, msg_ids, mapObject, fname='', ordered=True):
574 574 AsyncResult.__init__(self, client, msg_ids, fname=fname)
575 575 self._mapObject = mapObject
576 576 self._single_result = False
577 577 self.ordered = ordered
578 578
579 579 def _reconstruct_result(self, res):
580 580 """Perform the gather on the actual results."""
581 581 return self._mapObject.joinPartitions(res)
582 582
583 583 # asynchronous iterator:
584 584 def __iter__(self):
585 585 it = self._ordered_iter if self.ordered else self._unordered_iter
586 586 for r in it():
587 587 yield r
588 588
589 589 # asynchronous ordered iterator:
590 590 def _ordered_iter(self):
591 591 """iterator for results *as they arrive*, preserving submission order."""
592 592 try:
593 593 rlist = self.get(0)
594 594 except error.TimeoutError:
595 595 # wait for each result individually
596 596 for msg_id in self.msg_ids:
597 597 ar = AsyncResult(self._client, msg_id, self._fname)
598 598 rlist = ar.get()
599 599 try:
600 600 for r in rlist:
601 601 yield r
602 602 except TypeError:
603 603 # flattened, not a list
604 604 # this could get broken by flattened data that returns iterables
605 605 # but most calls to map do not expose the `flatten` argument
606 606 yield rlist
607 607 else:
608 608 # already done
609 609 for r in rlist:
610 610 yield r
611 611
612 612 # asynchronous unordered iterator:
613 613 def _unordered_iter(self):
614 614 """iterator for results *as they arrive*, on FCFS basis, ignoring submission order."""
615 615 try:
616 616 rlist = self.get(0)
617 617 except error.TimeoutError:
618 618 pending = set(self.msg_ids)
619 619 while pending:
620 620 try:
621 621 self._client.wait(pending, 1e-3)
622 622 except error.TimeoutError:
623 623 # ignore timeout error, because that only means
624 624 # *some* jobs are outstanding
625 625 pass
626 626 # update ready set with those no longer outstanding:
627 627 ready = pending.difference(self._client.outstanding)
628 628 # update pending to exclude those that are finished
629 629 pending = pending.difference(ready)
630 630 while ready:
631 631 msg_id = ready.pop()
632 632 ar = AsyncResult(self._client, msg_id, self._fname)
633 633 rlist = ar.get()
634 634 try:
635 635 for r in rlist:
636 636 yield r
637 637 except TypeError:
638 638 # flattened, not a list
639 639 # this could get broken by flattened data that returns iterables
640 640 # but most calls to map do not expose the `flatten` argument
641 641 yield rlist
642 642 else:
643 643 # already done
644 644 for r in rlist:
645 645 yield r
646 646
647 647
648 648 class AsyncHubResult(AsyncResult):
649 649 """Class to wrap pending results that must be requested from the Hub.
650 650
651 651     Note that waiting/polling on these objects requires polling the Hub over the network,
652 652 so use `AsyncHubResult.wait()` sparingly.
653 653 """
654 654
655 655 def _wait_for_outputs(self, timeout=-1):
656 656 """no-op, because HubResults are never incomplete"""
657 657 self._outputs_ready = True
658 658
659 659 def wait(self, timeout=-1):
660 660 """wait for result to complete."""
661 661 start = time.time()
662 662 if self._ready:
663 663 return
664 664 local_ids = [m for m in self.msg_ids if m in self._client.outstanding]
665 665 local_ready = self._client.wait(local_ids, timeout)
666 666 if local_ready:
667 667 remote_ids = [m for m in self.msg_ids if m not in self._client.results]
668 668 if not remote_ids:
669 669 self._ready = True
670 670 else:
671 671 rdict = self._client.result_status(remote_ids, status_only=False)
672 672 pending = rdict['pending']
673 673 while pending and (timeout < 0 or time.time() < start+timeout):
674 674 rdict = self._client.result_status(remote_ids, status_only=False)
675 675 pending = rdict['pending']
676 676 if pending:
677 677 time.sleep(0.1)
678 678 if not pending:
679 679 self._ready = True
680 680 if self._ready:
681 681 try:
682 682 results = list(map(self._client.results.get, self.msg_ids))
683 683 self._result = results
684 684 if self._single_result:
685 685 r = results[0]
686 686 if isinstance(r, Exception):
687 687 raise r
688 688 else:
689 689 results = error.collect_exceptions(results, self._fname)
690 690 self._result = self._reconstruct_result(results)
691 691 except Exception as e:
692 692 self._exception = e
693 693 self._success = False
694 694 else:
695 695 self._success = True
696 696 finally:
697 697 self._metadata = [self._client.metadata[mid] for mid in self.msg_ids]
698 698 if self.owner:
699 699 [self._client.metadata.pop(mid) for mid in self.msg_ids]
700 700 [self._client.results.pop(mid) for mid in self.msg_ids]
701 701
702 702
703 703 __all__ = ['AsyncResult', 'AsyncMapResult', 'AsyncHubResult']
@@ -1,1893 +1,1893 b''
1 1 """A semi-synchronous Client for IPython parallel"""
2 2
3 3 # Copyright (c) IPython Development Team.
4 4 # Distributed under the terms of the Modified BSD License.
5 5
6 6 from __future__ import print_function
7 7
8 8 import os
9 9 import json
10 10 import sys
11 11 from threading import Thread, Event
12 12 import time
13 13 import warnings
14 14 from datetime import datetime
15 15 from getpass import getpass
16 16 from pprint import pprint
17 17
18 18 pjoin = os.path.join
19 19
20 20 import zmq
21 21
22 22 from IPython.config.configurable import MultipleInstanceError
23 23 from IPython.core.application import BaseIPythonApplication
24 24 from IPython.core.profiledir import ProfileDir, ProfileDirError
25 25
26 26 from IPython.utils.capture import RichOutput
27 27 from IPython.utils.coloransi import TermColors
28 28 from IPython.utils.jsonutil import rekey, extract_dates, parse_date
29 29 from IPython.utils.localinterfaces import localhost, is_local_ip
30 30 from IPython.utils.path import get_ipython_dir, compress_user
31 31 from IPython.utils.py3compat import cast_bytes, string_types, xrange, iteritems
32 32 from IPython.utils.traitlets import (HasTraits, Integer, Instance, Unicode,
33 33 Dict, List, Bool, Set, Any)
34 from IPython.external.decorator import decorator
34 from decorator import decorator
35 35
36 36 from IPython.parallel import Reference
37 37 from IPython.parallel import error
38 38 from IPython.parallel import util
39 39
40 40 from IPython.kernel.zmq.session import Session, Message
41 41 from IPython.kernel.zmq import serialize
42 42
43 43 from .asyncresult import AsyncResult, AsyncHubResult
44 44 from .view import DirectView, LoadBalancedView
45 45
46 46 #--------------------------------------------------------------------------
47 47 # Decorators for Client methods
48 48 #--------------------------------------------------------------------------
49 49
50 50
51 51 @decorator
52 52 def spin_first(f, self, *args, **kwargs):
53 53 """Call spin() to sync state prior to calling the method."""
54 54 self.spin()
55 55 return f(self, *args, **kwargs)
56 56
57 57
58 58 #--------------------------------------------------------------------------
59 59 # Classes
60 60 #--------------------------------------------------------------------------
61 61
62 62 _no_connection_file_msg = """
63 63 Failed to connect because no Controller could be found.
64 64 Please double-check your profile and ensure that a cluster is running.
65 65 """
66 66
67 67 class ExecuteReply(RichOutput):
68 68 """wrapper for finished Execute results"""
69 69 def __init__(self, msg_id, content, metadata):
70 70 self.msg_id = msg_id
71 71 self._content = content
72 72 self.execution_count = content['execution_count']
73 73 self.metadata = metadata
74 74
75 75 # RichOutput overrides
76 76
77 77 @property
78 78 def source(self):
79 79 execute_result = self.metadata['execute_result']
80 80 if execute_result:
81 81 return execute_result.get('source', '')
82 82
83 83 @property
84 84 def data(self):
85 85 execute_result = self.metadata['execute_result']
86 86 if execute_result:
87 87 return execute_result.get('data', {})
88 88
89 89 @property
90 90 def _metadata(self):
91 91 execute_result = self.metadata['execute_result']
92 92 if execute_result:
93 93 return execute_result.get('metadata', {})
94 94
95 95 def display(self):
96 96 from IPython.display import publish_display_data
97 97 publish_display_data(self.data, self.metadata)
98 98
99 99 def _repr_mime_(self, mime):
100 100 if mime not in self.data:
101 101 return
102 102 data = self.data[mime]
103 103 if mime in self._metadata:
104 104 return data, self._metadata[mime]
105 105 else:
106 106 return data
107 107
108 108 def __getitem__(self, key):
109 109 return self.metadata[key]
110 110
111 111 def __getattr__(self, key):
112 112 if key not in self.metadata:
113 113 raise AttributeError(key)
114 114 return self.metadata[key]
115 115
116 116 def __repr__(self):
117 117 execute_result = self.metadata['execute_result'] or {'data':{}}
118 118 text_out = execute_result['data'].get('text/plain', '')
119 119 if len(text_out) > 32:
120 120 text_out = text_out[:29] + '...'
121 121
122 122 return "<ExecuteReply[%i]: %s>" % (self.execution_count, text_out)
123 123
124 124 def _repr_pretty_(self, p, cycle):
125 125 execute_result = self.metadata['execute_result'] or {'data':{}}
126 126 text_out = execute_result['data'].get('text/plain', '')
127 127
128 128 if not text_out:
129 129 return
130 130
131 131 try:
132 132 ip = get_ipython()
133 133 except NameError:
134 134 colors = "NoColor"
135 135 else:
136 136 colors = ip.colors
137 137
138 138 if colors == "NoColor":
139 139 out = normal = ""
140 140 else:
141 141 out = TermColors.Red
142 142 normal = TermColors.Normal
143 143
144 144 if '\n' in text_out and not text_out.startswith('\n'):
145 145 # add newline for multiline reprs
146 146 text_out = '\n' + text_out
147 147
148 148 p.text(
149 149 out + u'Out[%i:%i]: ' % (
150 150 self.metadata['engine_id'], self.execution_count
151 151 ) + normal + text_out
152 152 )
153 153
154 154
155 155 class Metadata(dict):
156 156 """Subclass of dict for initializing metadata values.
157 157
158 158 Attribute access works on keys.
159 159
160 160 These objects have a strict set of keys - errors will raise if you try
161 161 to add new keys.
162 162 """
163 163 def __init__(self, *args, **kwargs):
164 164 dict.__init__(self)
165 165 md = {'msg_id' : None,
166 166 'submitted' : None,
167 167 'started' : None,
168 168 'completed' : None,
169 169 'received' : None,
170 170 'engine_uuid' : None,
171 171 'engine_id' : None,
172 172 'follow' : None,
173 173 'after' : None,
174 174 'status' : None,
175 175
176 176 'execute_input' : None,
177 177 'execute_result' : None,
178 178 'error' : None,
179 179 'stdout' : '',
180 180 'stderr' : '',
181 181 'outputs' : [],
182 182 'data': {},
183 183 'outputs_ready' : False,
184 184 }
185 185 self.update(md)
186 186 self.update(dict(*args, **kwargs))
187 187
188 188 def __getattr__(self, key):
189 189 """getattr aliased to getitem"""
190 190 if key in self:
191 191 return self[key]
192 192 else:
193 193 raise AttributeError(key)
194 194
195 195 def __setattr__(self, key, value):
196 196 """setattr aliased to setitem, with strict"""
197 197 if key in self:
198 198 self[key] = value
199 199 else:
200 200 raise AttributeError(key)
201 201
202 202 def __setitem__(self, key, value):
203 203 """strict static key enforcement"""
204 204 if key in self:
205 205 dict.__setitem__(self, key, value)
206 206 else:
207 207 raise KeyError(key)
208 208
209 209
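# [editor's note] An illustrative sketch (not part of the diff) of the strict
# key set enforced above:
#
#   md = Metadata(engine_id=0)
#   md.engine_id           # -> 0 ; attribute access maps onto the dict
#   md['status'] = 'ok'    # fine: 'status' is one of the known keys
#   md['typo'] = 1         # raises KeyError: unknown keys are rejected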
210 210 class Client(HasTraits):
211 211 """A semi-synchronous client to the IPython ZMQ cluster
212 212
213 213 Parameters
214 214 ----------
215 215
216 216 url_file : str/unicode; path to ipcontroller-client.json
217 217 This JSON file should contain all the information needed to connect to a cluster,
218 218 and is likely the only argument needed.
219 219         It carries the connection information for the Hub's registration; if a json
220 220         connector file is given, then likely no further configuration is necessary.
221 221 [Default: use profile]
222 222 profile : bytes
223 223 The name of the Cluster profile to be used to find connector information.
224 224 If run from an IPython application, the default profile will be the same
225 225 as the running application, otherwise it will be 'default'.
226 226 cluster_id : str
227 227         String id to be added to runtime files, to prevent name collisions when using
228 228 multiple clusters with a single profile simultaneously.
229 229 When set, will look for files named like: 'ipcontroller-<cluster_id>-client.json'
230 230 Since this is text inserted into filenames, typical recommendations apply:
231 231 Simple character strings are ideal, and spaces are not recommended (but
232 232 should generally work)
233 233 context : zmq.Context
234 234 Pass an existing zmq.Context instance, otherwise the client will create its own.
235 235 debug : bool
236 236 flag for lots of message printing for debug purposes
237 237 timeout : int/float
238 238 time (in seconds) to wait for connection replies from the Hub
239 239 [Default: 10]
240 240
241 241 #-------------- session related args ----------------
242 242
243 243 config : Config object
244 244 If specified, this will be relayed to the Session for configuration
245 245 username : str
246 246 set username for the session object
247 247
248 248 #-------------- ssh related args ----------------
249 249 # These are args for configuring the ssh tunnel to be used
250 250 # credentials are used to forward connections over ssh to the Controller
251 251 # Note that the ip given in `addr` needs to be relative to sshserver
252 252 # The most basic case is to leave addr as pointing to localhost (127.0.0.1),
253 253 # and set sshserver as the same machine the Controller is on. However,
254 254 # the only requirement is that sshserver is able to see the Controller
255 255 # (i.e. is within the same trusted network).
256 256
257 257 sshserver : str
258 258 A string of the form passed to ssh, i.e. 'server.tld' or 'user@server.tld:port'
259 259 If keyfile or password is specified, and this is not, it will default to
260 260 the ip given in addr.
261 261 sshkey : str; path to ssh private key file
262 262 This specifies a key to be used in ssh login, default None.
263 263 Regular default ssh keys will be used without specifying this argument.
264 264 password : str
265 265 Your ssh password to sshserver. Note that if this is left None,
266 266 you will be prompted for it if passwordless key based login is unavailable.
267 267 paramiko : bool
268 268 flag for whether to use paramiko instead of shell ssh for tunneling.
269 269 [default: True on win32, False else]
270 270
271 271
272 272 Attributes
273 273 ----------
274 274
275 275 ids : list of int engine IDs
276 276 requesting the ids attribute always synchronizes
277 277 the registration state. To request ids without synchronization,
278 278         use the semi-private _ids attribute.
279 279
280 280 history : list of msg_ids
281 281 a list of msg_ids, keeping track of all the execution
282 282 messages you have submitted in order.
283 283
284 284 outstanding : set of msg_ids
285 285 a set of msg_ids that have been submitted, but whose
286 286 results have not yet been received.
287 287
288 288 results : dict
289 289 a dict of all our results, keyed by msg_id
290 290
291 291 block : bool
292 292 determines default behavior when block not specified
293 293 in execution methods
294 294
295 295 Methods
296 296 -------
297 297
298 298 spin
299 299         flushes incoming results and registration state changes;
300 300         control methods spin, and requesting `ids` also keeps state up to date
301 301
302 302 wait
303 303 wait on one or more msg_ids
304 304
305 305 execution methods
306 306 apply
307 307 legacy: execute, run
308 308
309 309 data movement
310 310 push, pull, scatter, gather
311 311
312 312 query methods
313 313 queue_status, get_result, purge, result_status
314 314
315 315 control methods
316 316 abort, shutdown
317 317
318 318 """
319 319
320 320
321 321 block = Bool(False)
322 322 outstanding = Set()
323 323 results = Instance('collections.defaultdict', (dict,))
324 324 metadata = Instance('collections.defaultdict', (Metadata,))
325 325 history = List()
326 326 debug = Bool(False)
327 327 _spin_thread = Any()
328 328 _stop_spinning = Any()
329 329
330 330 profile=Unicode()
331 331 def _profile_default(self):
332 332 if BaseIPythonApplication.initialized():
333 333 # an IPython app *might* be running, try to get its profile
334 334 try:
335 335 return BaseIPythonApplication.instance().profile
336 336 except (AttributeError, MultipleInstanceError):
337 337 # could be a *different* subclass of config.Application,
338 338 # which would raise one of these two errors.
339 339 return u'default'
340 340 else:
341 341 return u'default'
342 342
343 343
344 344 _outstanding_dict = Instance('collections.defaultdict', (set,))
345 345 _ids = List()
346 346 _connected=Bool(False)
347 347 _ssh=Bool(False)
348 348 _context = Instance('zmq.Context')
349 349 _config = Dict()
350 350 _engines=Instance(util.ReverseDict, (), {})
351 351 # _hub_socket=Instance('zmq.Socket')
352 352 _query_socket=Instance('zmq.Socket')
353 353 _control_socket=Instance('zmq.Socket')
354 354 _iopub_socket=Instance('zmq.Socket')
355 355 _notification_socket=Instance('zmq.Socket')
356 356 _mux_socket=Instance('zmq.Socket')
357 357 _task_socket=Instance('zmq.Socket')
358 358 _task_scheme=Unicode()
359 359 _closed = False
360 360 _ignored_control_replies=Integer(0)
361 361 _ignored_hub_replies=Integer(0)
362 362
363 363     def __new__(cls, *args, **kw):
364 364         # don't raise on positional args
365 365         return HasTraits.__new__(cls, **kw)
366 366
367 367 def __init__(self, url_file=None, profile=None, profile_dir=None, ipython_dir=None,
368 368 context=None, debug=False,
369 369 sshserver=None, sshkey=None, password=None, paramiko=None,
370 370 timeout=10, cluster_id=None, **extra_args
371 371 ):
372 372 if profile:
373 373 super(Client, self).__init__(debug=debug, profile=profile)
374 374 else:
375 375 super(Client, self).__init__(debug=debug)
376 376 if context is None:
377 377 context = zmq.Context.instance()
378 378 self._context = context
379 379 self._stop_spinning = Event()
380 380
381 381 if 'url_or_file' in extra_args:
382 382 url_file = extra_args['url_or_file']
383 383 warnings.warn("url_or_file arg no longer supported, use url_file", DeprecationWarning)
384 384
385 385 if url_file and util.is_url(url_file):
386 386 raise ValueError("single urls cannot be specified, url-files must be used.")
387 387
388 388 self._setup_profile_dir(self.profile, profile_dir, ipython_dir)
389 389
390 390 no_file_msg = '\n'.join([
391 391 "You have attempted to connect to an IPython Cluster but no Controller could be found.",
392 392 "Please double-check your configuration and ensure that a cluster is running.",
393 393 ])
394 394
395 395 if self._cd is not None:
396 396 if url_file is None:
397 397 if not cluster_id:
398 398 client_json = 'ipcontroller-client.json'
399 399 else:
400 400 client_json = 'ipcontroller-%s-client.json' % cluster_id
401 401 url_file = pjoin(self._cd.security_dir, client_json)
402 402 if not os.path.exists(url_file):
403 403 msg = '\n'.join([
404 404 "Connection file %r not found." % compress_user(url_file),
405 405 no_file_msg,
406 406 ])
407 407 raise IOError(msg)
408 408 if url_file is None:
409 409 raise IOError(no_file_msg)
410 410
411 411 if not os.path.exists(url_file):
412 412 # Connection file explicitly specified, but not found
413 413 raise IOError("Connection file %r not found. Is a controller running?" % \
414 414 compress_user(url_file)
415 415 )
416 416
417 417 with open(url_file) as f:
418 418 cfg = json.load(f)
419 419
420 420 self._task_scheme = cfg['task_scheme']
421 421
422 422 # sync defaults from args, json:
423 423 if sshserver:
424 424 cfg['ssh'] = sshserver
425 425
426 426 location = cfg.setdefault('location', None)
427 427
428 428 proto,addr = cfg['interface'].split('://')
429 429 addr = util.disambiguate_ip_address(addr, location)
430 430 cfg['interface'] = "%s://%s" % (proto, addr)
431 431
432 432 # turn interface,port into full urls:
433 433 for key in ('control', 'task', 'mux', 'iopub', 'notification', 'registration'):
434 434 cfg[key] = cfg['interface'] + ':%i' % cfg[key]
435 435
436 436 url = cfg['registration']
437 437
438 438 if location is not None and addr == localhost():
439 439 # location specified, and connection is expected to be local
440 440 if not is_local_ip(location) and not sshserver:
441 441 # load ssh from JSON *only* if the controller is not on
442 442 # this machine
443 443 sshserver=cfg['ssh']
444 444 if not is_local_ip(location) and not sshserver:
445 445 # warn if no ssh specified, but SSH is probably needed
446 446 # This is only a warning, because the most likely cause
447 447 # is a local Controller on a laptop whose IP is dynamic
448 448 warnings.warn("""
449 449 Controller appears to be listening on localhost, but not on this machine.
450 450 If this is true, you should specify Client(...,sshserver='you@%s')
451 451 or instruct your controller to listen on an external IP."""%location,
452 452 RuntimeWarning)
453 453 elif not sshserver:
454 454 # otherwise sync with cfg
455 455 sshserver = cfg['ssh']
456 456
457 457 self._config = cfg
458 458
459 459 self._ssh = bool(sshserver or sshkey or password)
460 460 if self._ssh and sshserver is None:
461 461 # default to ssh via localhost
462 462 sshserver = addr
463 463 if self._ssh and password is None:
464 464 from zmq.ssh import tunnel
465 465 if tunnel.try_passwordless_ssh(sshserver, sshkey, paramiko):
466 466 password=False
467 467 else:
468 468 password = getpass("SSH Password for %s: "%sshserver)
469 469 ssh_kwargs = dict(keyfile=sshkey, password=password, paramiko=paramiko)
470 470
471 471 # configure and construct the session
472 472 try:
473 473 extra_args['packer'] = cfg['pack']
474 474 extra_args['unpacker'] = cfg['unpack']
475 475 extra_args['key'] = cast_bytes(cfg['key'])
476 476 extra_args['signature_scheme'] = cfg['signature_scheme']
477 477 except KeyError as exc:
478 478 msg = '\n'.join([
479 479 "Connection file is invalid (missing '{}'), possibly from an old version of IPython.",
480 480 "If you are reusing connection files, remove them and start ipcontroller again."
481 481 ])
482 482             raise ValueError(msg.format(exc.args[0]))
483 483
484 484 self.session = Session(**extra_args)
485 485
486 486 self._query_socket = self._context.socket(zmq.DEALER)
487 487
488 488 if self._ssh:
489 489 from zmq.ssh import tunnel
490 490 tunnel.tunnel_connection(self._query_socket, cfg['registration'], sshserver, **ssh_kwargs)
491 491 else:
492 492 self._query_socket.connect(cfg['registration'])
493 493
494 494 self.session.debug = self.debug
495 495
496 496 self._notification_handlers = {'registration_notification' : self._register_engine,
497 497 'unregistration_notification' : self._unregister_engine,
498 498 'shutdown_notification' : lambda msg: self.close(),
499 499 }
500 500 self._queue_handlers = {'execute_reply' : self._handle_execute_reply,
501 501 'apply_reply' : self._handle_apply_reply}
502 502
503 503 try:
504 504 self._connect(sshserver, ssh_kwargs, timeout)
505 505 except:
506 506 self.close(linger=0)
507 507 raise
508 508
509 509 # last step: setup magics, if we are in IPython:
510 510
511 511 try:
512 512 ip = get_ipython()
513 513 except NameError:
514 514 return
515 515 else:
516 516 if 'px' not in ip.magics_manager.magics:
517 517 # in IPython but we are the first Client.
518 518 # activate a default view for parallel magics.
519 519 self.activate()
520 520
521 521 def __del__(self):
522 522 """cleanup sockets, but _not_ context."""
523 523 self.close()
524 524
525 525 def _setup_profile_dir(self, profile, profile_dir, ipython_dir):
526 526 if ipython_dir is None:
527 527 ipython_dir = get_ipython_dir()
528 528 if profile_dir is not None:
529 529 try:
530 530 self._cd = ProfileDir.find_profile_dir(profile_dir)
531 531 return
532 532 except ProfileDirError:
533 533 pass
534 534 elif profile is not None:
535 535 try:
536 536 self._cd = ProfileDir.find_profile_dir_by_name(
537 537 ipython_dir, profile)
538 538 return
539 539 except ProfileDirError:
540 540 pass
541 541 self._cd = None
542 542
543 543 def _update_engines(self, engines):
544 544 """Update our engines dict and _ids from a dict of the form: {id:uuid}."""
545 545 for k,v in iteritems(engines):
546 546 eid = int(k)
547 547 if eid not in self._engines:
548 548 self._ids.append(eid)
549 549 self._engines[eid] = v
550 550 self._ids = sorted(self._ids)
551 551 if sorted(self._engines.keys()) != list(range(len(self._engines))) and \
552 552 self._task_scheme == 'pure' and self._task_socket:
553 553 self._stop_scheduling_tasks()
554 554
555 555 def _stop_scheduling_tasks(self):
556 556 """Stop scheduling tasks because an engine has been unregistered
557 557 from a pure ZMQ scheduler.
558 558 """
559 559 self._task_socket.close()
560 560 self._task_socket = None
561 561 msg = "An engine has been unregistered, and we are using pure " +\
562 562 "ZMQ task scheduling. Task farming will be disabled."
563 563 if self.outstanding:
564 564 msg += " If you were running tasks when this happened, " +\
565 565 "some `outstanding` msg_ids may never resolve."
566 566 warnings.warn(msg, RuntimeWarning)
567 567
568 568 def _build_targets(self, targets):
569 569 """Turn valid target IDs or 'all' into two lists:
570 570 (int_ids, uuids).
571 571 """
572 572 if not self._ids:
573 573 # flush notification socket if no engines yet, just in case
574 574 if not self.ids:
575 575 raise error.NoEnginesRegistered("Can't build targets without any engines")
576 576
577 577 if targets is None:
578 578 targets = self._ids
579 579 elif isinstance(targets, string_types):
580 580 if targets.lower() == 'all':
581 581 targets = self._ids
582 582 else:
583 583 raise TypeError("%r not valid str target, must be 'all'"%(targets))
584 584 elif isinstance(targets, int):
585 585 if targets < 0:
586 586 targets = self.ids[targets]
587 587 if targets not in self._ids:
588 588 raise IndexError("No such engine: %i"%targets)
589 589 targets = [targets]
590 590
591 591 if isinstance(targets, slice):
592 592 indices = list(range(len(self._ids))[targets])
593 593 ids = self.ids
594 594 targets = [ ids[i] for i in indices ]
595 595
596 596 if not isinstance(targets, (tuple, list, xrange)):
597 597 raise TypeError("targets by int/slice/collection of ints only, not %s"%(type(targets)))
598 598
599 599 return [cast_bytes(self._engines[t]) for t in targets], list(targets)
600 600
601 601 def _connect(self, sshserver, ssh_kwargs, timeout):
602 602 """setup all our socket connections to the cluster. This is called from
603 603 __init__."""
604 604
605 605 # Maybe allow reconnecting?
606 606 if self._connected:
607 607 return
608 608 self._connected=True
609 609
610 610 def connect_socket(s, url):
611 611 if self._ssh:
612 612 from zmq.ssh import tunnel
613 613 return tunnel.tunnel_connection(s, url, sshserver, **ssh_kwargs)
614 614 else:
615 615 return s.connect(url)
616 616
617 617 self.session.send(self._query_socket, 'connection_request')
618 618 # use Poller because zmq.select has wrong units in pyzmq 2.1.7
619 619 poller = zmq.Poller()
620 620 poller.register(self._query_socket, zmq.POLLIN)
621 621 # poll expects milliseconds, timeout is seconds
622 622 evts = poller.poll(timeout*1000)
623 623 if not evts:
624 624 raise error.TimeoutError("Hub connection request timed out")
625 625 idents,msg = self.session.recv(self._query_socket,mode=0)
626 626 if self.debug:
627 627 pprint(msg)
628 628 content = msg['content']
629 629 # self._config['registration'] = dict(content)
630 630 cfg = self._config
631 631 if content['status'] == 'ok':
632 632 self._mux_socket = self._context.socket(zmq.DEALER)
633 633 connect_socket(self._mux_socket, cfg['mux'])
634 634
635 635 self._task_socket = self._context.socket(zmq.DEALER)
636 636 connect_socket(self._task_socket, cfg['task'])
637 637
638 638 self._notification_socket = self._context.socket(zmq.SUB)
639 639 self._notification_socket.setsockopt(zmq.SUBSCRIBE, b'')
640 640 connect_socket(self._notification_socket, cfg['notification'])
641 641
642 642 self._control_socket = self._context.socket(zmq.DEALER)
643 643 connect_socket(self._control_socket, cfg['control'])
644 644
645 645 self._iopub_socket = self._context.socket(zmq.SUB)
646 646 self._iopub_socket.setsockopt(zmq.SUBSCRIBE, b'')
647 647 connect_socket(self._iopub_socket, cfg['iopub'])
648 648
649 649 self._update_engines(dict(content['engines']))
650 650 else:
651 651 self._connected = False
652 652 raise Exception("Failed to connect!")
653 653
654 654 #--------------------------------------------------------------------------
655 655 # handlers and callbacks for incoming messages
656 656 #--------------------------------------------------------------------------
657 657
658 658 def _unwrap_exception(self, content):
659 659 """unwrap exception, and remap engine_id to int."""
660 660 e = error.unwrap_exception(content)
661 661 # print e.traceback
662 662 if e.engine_info:
663 663 e_uuid = e.engine_info['engine_uuid']
664 664 eid = self._engines[e_uuid]
665 665 e.engine_info['engine_id'] = eid
666 666 return e
667 667
668 668 def _extract_metadata(self, msg):
669 669 header = msg['header']
670 670 parent = msg['parent_header']
671 671 msg_meta = msg['metadata']
672 672 content = msg['content']
673 673 md = {'msg_id' : parent['msg_id'],
674 674 'received' : datetime.now(),
675 675 'engine_uuid' : msg_meta.get('engine', None),
676 676 'follow' : msg_meta.get('follow', []),
677 677 'after' : msg_meta.get('after', []),
678 678 'status' : content['status'],
679 679 }
680 680
681 681 if md['engine_uuid'] is not None:
682 682 md['engine_id'] = self._engines.get(md['engine_uuid'], None)
683 683
684 684 if 'date' in parent:
685 685 md['submitted'] = parent['date']
686 686 if 'started' in msg_meta:
687 687 md['started'] = parse_date(msg_meta['started'])
688 688 if 'date' in header:
689 689 md['completed'] = header['date']
690 690 return md
691 691
692 692 def _register_engine(self, msg):
693 693 """Register a new engine, and update our connection info."""
694 694 content = msg['content']
695 695 eid = content['id']
696 696 d = {eid : content['uuid']}
697 697 self._update_engines(d)
698 698
699 699 def _unregister_engine(self, msg):
700 700 """Unregister an engine that has died."""
701 701 content = msg['content']
702 702 eid = int(content['id'])
703 703 if eid in self._ids:
704 704 self._ids.remove(eid)
705 705 uuid = self._engines.pop(eid)
706 706
707 707 self._handle_stranded_msgs(eid, uuid)
708 708
709 709 if self._task_socket and self._task_scheme == 'pure':
710 710 self._stop_scheduling_tasks()
711 711
712 712 def _handle_stranded_msgs(self, eid, uuid):
713 713 """Handle messages known to be on an engine when the engine unregisters.
714 714
715 715 It is possible that this will fire prematurely - that is, an engine will
716 716 go down after completing a result, and the client will be notified
717 717 of the unregistration and later receive the successful result.
718 718 """
719 719
720 720 outstanding = self._outstanding_dict[uuid]
721 721
722 722 for msg_id in list(outstanding):
723 723 if msg_id in self.results:
724 724 # we already have this result
725 725 continue
726 726 try:
727 727 raise error.EngineError("Engine %r died while running task %r"%(eid, msg_id))
728 728 except:
729 729 content = error.wrap_exception()
730 730 # build a fake message:
731 731 msg = self.session.msg('apply_reply', content=content)
732 732 msg['parent_header']['msg_id'] = msg_id
733 733 msg['metadata']['engine'] = uuid
734 734 self._handle_apply_reply(msg)
735 735
736 736 def _handle_execute_reply(self, msg):
737 737 """Save the reply to an execute_request into our results.
738 738
739 739 execute messages are never actually used. apply is used instead.
740 740 """
741 741
742 742 parent = msg['parent_header']
743 743 msg_id = parent['msg_id']
744 744 if msg_id not in self.outstanding:
745 745 if msg_id in self.history:
746 746 print("got stale result: %s"%msg_id)
747 747 else:
748 748 print("got unknown result: %s"%msg_id)
749 749 else:
750 750 self.outstanding.remove(msg_id)
751 751
752 752 content = msg['content']
753 753 header = msg['header']
754 754
755 755 # construct metadata:
756 756 md = self.metadata[msg_id]
757 757 md.update(self._extract_metadata(msg))
758 758 # is this redundant?
759 759 self.metadata[msg_id] = md
760 760
761 761 e_outstanding = self._outstanding_dict[md['engine_uuid']]
762 762 if msg_id in e_outstanding:
763 763 e_outstanding.remove(msg_id)
764 764
765 765 # construct result:
766 766 if content['status'] == 'ok':
767 767 self.results[msg_id] = ExecuteReply(msg_id, content, md)
768 768 elif content['status'] == 'aborted':
769 769 self.results[msg_id] = error.TaskAborted(msg_id)
770 770 elif content['status'] == 'resubmitted':
771 771 # TODO: handle resubmission
772 772 pass
773 773 else:
774 774 self.results[msg_id] = self._unwrap_exception(content)
775 775
776 776 def _handle_apply_reply(self, msg):
777 777 """Save the reply to an apply_request into our results."""
778 778 parent = msg['parent_header']
779 779 msg_id = parent['msg_id']
780 780 if msg_id not in self.outstanding:
781 781 if msg_id in self.history:
782 782 print("got stale result: %s"%msg_id)
783 783 print(self.results[msg_id])
784 784 print(msg)
785 785 else:
786 786 print("got unknown result: %s"%msg_id)
787 787 else:
788 788 self.outstanding.remove(msg_id)
789 789 content = msg['content']
790 790 header = msg['header']
791 791
792 792 # construct metadata:
793 793 md = self.metadata[msg_id]
794 794 md.update(self._extract_metadata(msg))
795 795 # is this redundant?
796 796 self.metadata[msg_id] = md
797 797
798 798 e_outstanding = self._outstanding_dict[md['engine_uuid']]
799 799 if msg_id in e_outstanding:
800 800 e_outstanding.remove(msg_id)
801 801
802 802 # construct result:
803 803 if content['status'] == 'ok':
804 804 self.results[msg_id] = serialize.deserialize_object(msg['buffers'])[0]
805 805 elif content['status'] == 'aborted':
806 806 self.results[msg_id] = error.TaskAborted(msg_id)
807 807 elif content['status'] == 'resubmitted':
808 808 # TODO: handle resubmission
809 809 pass
810 810 else:
811 811 self.results[msg_id] = self._unwrap_exception(content)
812 812
813 813 def _flush_notifications(self):
814 814 """Flush notifications of engine registrations waiting
815 815 in ZMQ queue."""
816 816 idents,msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
817 817 while msg is not None:
818 818 if self.debug:
819 819 pprint(msg)
820 820 msg_type = msg['header']['msg_type']
821 821 handler = self._notification_handlers.get(msg_type, None)
822 822 if handler is None:
823 823 raise Exception("Unhandled message type: %s" % msg_type)
824 824 else:
825 825 handler(msg)
826 826 idents,msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
827 827
828 828 def _flush_results(self, sock):
829 829 """Flush task or queue results waiting in ZMQ queue."""
830 830 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
831 831 while msg is not None:
832 832 if self.debug:
833 833 pprint(msg)
834 834 msg_type = msg['header']['msg_type']
835 835 handler = self._queue_handlers.get(msg_type, None)
836 836 if handler is None:
837 837 raise Exception("Unhandled message type: %s" % msg_type)
838 838 else:
839 839 handler(msg)
840 840 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
841 841
842 842 def _flush_control(self, sock):
843 843 """Flush replies from the control channel waiting
844 844 in the ZMQ queue.
845 845
846 846 Currently: ignore them."""
847 847 if self._ignored_control_replies <= 0:
848 848 return
849 849 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
850 850 while msg is not None:
851 851 self._ignored_control_replies -= 1
852 852 if self.debug:
853 853 pprint(msg)
854 854 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
855 855
856 856 def _flush_ignored_control(self):
857 857 """flush ignored control replies"""
858 858 while self._ignored_control_replies > 0:
859 859 self.session.recv(self._control_socket)
860 860 self._ignored_control_replies -= 1
861 861
862 862 def _flush_ignored_hub_replies(self):
863 863 ident,msg = self.session.recv(self._query_socket, mode=zmq.NOBLOCK)
864 864 while msg is not None:
865 865 ident,msg = self.session.recv(self._query_socket, mode=zmq.NOBLOCK)
866 866
867 867 def _flush_iopub(self, sock):
868 868 """Flush replies from the iopub channel waiting
869 869 in the ZMQ queue.
870 870 """
871 871 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
872 872 while msg is not None:
873 873 if self.debug:
874 874 pprint(msg)
875 875 parent = msg['parent_header']
876 876 if not parent or parent['session'] != self.session.session:
877 877 # ignore IOPub messages not from here
878 878 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
879 879 continue
880 880 msg_id = parent['msg_id']
881 881 content = msg['content']
882 882 header = msg['header']
883 883 msg_type = msg['header']['msg_type']
884 884
885 885 if msg_type == 'status' and msg_id not in self.metadata:
886 886 # ignore status messages if they aren't mine
887 887 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
888 888 continue
889 889
890 890 # init metadata:
891 891 md = self.metadata[msg_id]
892 892
893 893 if msg_type == 'stream':
894 894 name = content['name']
895 895 s = md[name] or ''
896 896 md[name] = s + content['text']
897 897 elif msg_type == 'error':
898 898 md.update({'error' : self._unwrap_exception(content)})
899 899 elif msg_type == 'execute_input':
900 900 md.update({'execute_input' : content['code']})
901 901 elif msg_type == 'display_data':
902 902 md['outputs'].append(content)
903 903 elif msg_type == 'execute_result':
904 904 md['execute_result'] = content
905 905 elif msg_type == 'data_message':
906 906 data, remainder = serialize.deserialize_object(msg['buffers'])
907 907 md['data'].update(data)
908 908 elif msg_type == 'status':
909 909 # idle message comes after all outputs
910 910 if content['execution_state'] == 'idle':
911 911 md['outputs_ready'] = True
912 912 else:
913 913 # unhandled msg_type (status, etc.)
914 914 pass
915 915
916 916 # redundant?
917 917 self.metadata[msg_id] = md
918 918
919 919 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
920 920
921 921 #--------------------------------------------------------------------------
922 922 # len, getitem
923 923 #--------------------------------------------------------------------------
924 924
925 925 def __len__(self):
926 926 """len(client) returns # of engines."""
927 927 return len(self.ids)
928 928
929 929 def __getitem__(self, key):
930 930 """index access returns DirectView multiplexer objects
931 931
932 932 Must be int, slice, or list/tuple/xrange of ints"""
933 933 if not isinstance(key, (int, slice, tuple, list, xrange)):
934 934 raise TypeError("key by int/slice/iterable of ints only, not %s"%(type(key)))
935 935 else:
936 936 return self.direct_view(key)
937 937
938 938 def __iter__(self):
939 939 """Since we define getitem, Client is iterable
940 940
941 941 but unless we also define __iter__, it won't work correctly unless engine IDs
942 942 start at zero and are continuous.
943 943 """
944 944 for eid in self.ids:
945 945 yield self.direct_view(eid)
946 946
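To make the indexing and iteration semantics above concrete, a minimal sketch follows; it assumes a running cluster, and names like `rc` are illustrative rather than part of this changeset:

    from IPython.parallel import Client

    rc = Client()        # connect via the default profile's connection file
    print(len(rc))       # number of registered engines
    dv = rc[:]           # DirectView on a snapshot of all current engines
    e0 = rc[0]           # DirectView on engine 0 only
    for view in rc:      # one single-engine DirectView per engine, in id order
        print(view.targets)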
947 947 #--------------------------------------------------------------------------
948 948 # Begin public methods
949 949 #--------------------------------------------------------------------------
950 950
951 951 @property
952 952 def ids(self):
953 953 """Always up-to-date ids property."""
954 954 self._flush_notifications()
955 955 # always copy:
956 956 return list(self._ids)
957 957
958 958 def activate(self, targets='all', suffix=''):
959 959 """Create a DirectView and register it with IPython magics
960 960
961 961 Defines the magics `%px, %autopx, %pxresult, %%px`
962 962
963 963 Parameters
964 964 ----------
965 965
966 966 targets: int, list of ints, or 'all'
967 967 The engines on which the view's magics will run
968 968 suffix: str [default: '']
969 969 The suffix, if any, for the magics. This allows you to have
970 970 multiple views associated with parallel magics at the same time.
971 971
972 972 e.g. ``rc.activate(targets=0, suffix='0')`` will give you
973 973 the magics ``%px0``, ``%pxresult0``, etc. for running magics just
974 974 on engine 0.
975 975 """
976 976 view = self.direct_view(targets)
977 977 view.block = True
978 978 view.activate(suffix)
979 979 return view
980 980
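A sketch of how `activate` wires up the parallel magics, assuming an interactive IPython session with a running cluster (the suffix value is arbitrary):

    rc = Client()
    rc.activate()                            # defines %px, %autopx, %pxresult, %%px
    v0 = rc.activate(targets=0, suffix='0')  # %px0 etc., running only on engine 0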
981 981 def close(self, linger=None):
982 982 """Close my zmq Sockets
983 983
984 984 If `linger`, set the zmq LINGER socket option,
985 985 which allows discarding of messages.
986 986 """
987 987 if self._closed:
988 988 return
989 989 self.stop_spin_thread()
990 990 snames = [ trait for trait in self.trait_names() if trait.endswith("socket") ]
991 991 for name in snames:
992 992 socket = getattr(self, name)
993 993 if socket is not None and not socket.closed:
994 994 if linger is not None:
995 995 socket.close(linger=linger)
996 996 else:
997 997 socket.close()
998 998 self._closed = True
999 999
1000 1000 def _spin_every(self, interval=1):
1001 1001 """target func for use in spin_thread"""
1002 1002 while True:
1003 1003 if self._stop_spinning.is_set():
1004 1004 return
1005 1005 time.sleep(interval)
1006 1006 self.spin()
1007 1007
1008 1008 def spin_thread(self, interval=1):
1009 1009 """call Client.spin() in a background thread on some regular interval
1010 1010
1011 1011 This helps ensure that messages don't pile up too much in the zmq queue
1012 1012 while you are working on other things, or just leaving an idle terminal.
1013 1013
1014 1014 It also helps limit potential padding of the `received` timestamp
1015 1015 on AsyncResult objects, used for timings.
1016 1016
1017 1017 Parameters
1018 1018 ----------
1019 1019
1020 1020 interval : float, optional
1021 1021 The interval on which to spin the client in the background thread
1022 1022 (simply passed to time.sleep).
1023 1023
1024 1024 Notes
1025 1025 -----
1026 1026
1027 1027 For precision timing, you may want to use this method to put a bound
1028 1028 on the jitter (in seconds) in `received` timestamps used
1029 1029 in AsyncResult.wall_time.
1030 1030
1031 1031 """
1032 1032 if self._spin_thread is not None:
1033 1033 self.stop_spin_thread()
1034 1034 self._stop_spinning.clear()
1035 1035 self._spin_thread = Thread(target=self._spin_every, args=(interval,))
1036 1036 self._spin_thread.daemon = True
1037 1037 self._spin_thread.start()
1038 1038
1039 1039 def stop_spin_thread(self):
1040 1040 """stop background spin_thread, if any"""
1041 1041 if self._spin_thread is not None:
1042 1042 self._stop_spinning.set()
1043 1043 self._spin_thread.join()
1044 1044 self._spin_thread = None
1045 1045
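A hedged usage sketch of the background spinner; the interval here is arbitrary:

    rc = Client()
    rc.spin_thread(interval=0.1)  # flush queues every 100 ms, bounding `received` jitter
    try:
        pass  # submit work, inspect AsyncResult.wall_time, etc.
    finally:
        rc.stop_spin_thread()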
1046 1046 def spin(self):
1047 1047 """Flush any registration notifications and execution results
1048 1048 waiting in the ZMQ queue.
1049 1049 """
1050 1050 if self._notification_socket:
1051 1051 self._flush_notifications()
1052 1052 if self._iopub_socket:
1053 1053 self._flush_iopub(self._iopub_socket)
1054 1054 if self._mux_socket:
1055 1055 self._flush_results(self._mux_socket)
1056 1056 if self._task_socket:
1057 1057 self._flush_results(self._task_socket)
1058 1058 if self._control_socket:
1059 1059 self._flush_control(self._control_socket)
1060 1060 if self._query_socket:
1061 1061 self._flush_ignored_hub_replies()
1062 1062
1063 1063 def wait(self, jobs=None, timeout=-1):
1064 1064 """waits on one or more `jobs`, for up to `timeout` seconds.
1065 1065
1066 1066 Parameters
1067 1067 ----------
1068 1068
1069 1069 jobs : int, str, or list of ints and/or strs, or one or more AsyncResult objects
1070 1070 ints are indices to self.history
1071 1071 strs are msg_ids
1072 1072 default: wait on all outstanding messages
1073 1073 timeout : float
1074 1074 a time in seconds, after which to give up.
1075 1075 default is -1, which means no timeout
1076 1076
1077 1077 Returns
1078 1078 -------
1079 1079
1080 1080 True : when all msg_ids are done
1081 1081 False : timeout reached, some msg_ids still outstanding
1082 1082 """
1083 1083 tic = time.time()
1084 1084 if jobs is None:
1085 1085 theids = self.outstanding
1086 1086 else:
1087 1087 if isinstance(jobs, string_types + (int, AsyncResult)):
1088 1088 jobs = [jobs]
1089 1089 theids = set()
1090 1090 for job in jobs:
1091 1091 if isinstance(job, int):
1092 1092 # index access
1093 1093 job = self.history[job]
1094 1094 elif isinstance(job, AsyncResult):
1095 1095 theids.update(job.msg_ids)
1096 1096 continue
1097 1097 theids.add(job)
1098 1098 if not theids.intersection(self.outstanding):
1099 1099 return True
1100 1100 self.spin()
1101 1101 while theids.intersection(self.outstanding):
1102 1102 if timeout >= 0 and ( time.time()-tic ) > timeout:
1103 1103 break
1104 1104 time.sleep(1e-3)
1105 1105 self.spin()
1106 1106 return len(theids.intersection(self.outstanding)) == 0
1107 1107
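For example, a minimal sketch of waiting on jobs; it assumes a running cluster, and `slow` is a hypothetical user function:

    import time
    from IPython.parallel import Client

    def slow(n):
        time.sleep(n)
        return n

    rc = Client()
    view = rc.load_balanced_view()
    ar = view.apply_async(slow, 2)
    if rc.wait([ar], timeout=5):  # True once all msg_ids are done
        print(ar.get())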
1108 1108 #--------------------------------------------------------------------------
1109 1109 # Control methods
1110 1110 #--------------------------------------------------------------------------
1111 1111
1112 1112 @spin_first
1113 1113 def clear(self, targets=None, block=None):
1114 1114 """Clear the namespace in target(s)."""
1115 1115 block = self.block if block is None else block
1116 1116 targets = self._build_targets(targets)[0]
1117 1117 for t in targets:
1118 1118 self.session.send(self._control_socket, 'clear_request', content={}, ident=t)
1119 1119 error = False
1120 1120 if block:
1121 1121 self._flush_ignored_control()
1122 1122 for i in range(len(targets)):
1123 1123 idents,msg = self.session.recv(self._control_socket,0)
1124 1124 if self.debug:
1125 1125 pprint(msg)
1126 1126 if msg['content']['status'] != 'ok':
1127 1127 error = self._unwrap_exception(msg['content'])
1128 1128 else:
1129 1129 self._ignored_control_replies += len(targets)
1130 1130 if error:
1131 1131 raise error
1132 1132
1133 1133
1134 1134 @spin_first
1135 1135 def abort(self, jobs=None, targets=None, block=None):
1136 1136 """Abort specific jobs from the execution queues of target(s).
1137 1137
1138 1138 This is a mechanism to prevent jobs that have already been submitted
1139 1139 from executing.
1140 1140
1141 1141 Parameters
1142 1142 ----------
1143 1143
1144 1144 jobs : msg_id, list of msg_ids, or AsyncResult
1145 1145 The jobs to be aborted
1146 1146
1147 1147 If unspecified/None: abort all outstanding jobs.
1148 1148
1149 1149 """
1150 1150 block = self.block if block is None else block
1151 1151 jobs = jobs if jobs is not None else list(self.outstanding)
1152 1152 targets = self._build_targets(targets)[0]
1153 1153
1154 1154 msg_ids = []
1155 1155 if isinstance(jobs, string_types + (AsyncResult,)):
1156 1156 jobs = [jobs]
1157 1157 bad_ids = [obj for obj in jobs if not isinstance(obj, string_types + (AsyncResult,))]
1158 1158 if bad_ids:
1159 1159 raise TypeError("Invalid msg_id type %r, expected str or AsyncResult"%bad_ids[0])
1160 1160 for j in jobs:
1161 1161 if isinstance(j, AsyncResult):
1162 1162 msg_ids.extend(j.msg_ids)
1163 1163 else:
1164 1164 msg_ids.append(j)
1165 1165 content = dict(msg_ids=msg_ids)
1166 1166 for t in targets:
1167 1167 self.session.send(self._control_socket, 'abort_request',
1168 1168 content=content, ident=t)
1169 1169 error = False
1170 1170 if block:
1171 1171 self._flush_ignored_control()
1172 1172 for i in range(len(targets)):
1173 1173 idents,msg = self.session.recv(self._control_socket,0)
1174 1174 if self.debug:
1175 1175 pprint(msg)
1176 1176 if msg['content']['status'] != 'ok':
1177 1177 error = self._unwrap_exception(msg['content'])
1178 1178 else:
1179 1179 self._ignored_control_replies += len(targets)
1180 1180 if error:
1181 1181 raise error
1182 1182
1183 1183 @spin_first
1184 1184 def shutdown(self, targets='all', restart=False, hub=False, block=None):
1185 1185 """Terminates one or more engine processes, optionally including the hub.
1186 1186
1187 1187 Parameters
1188 1188 ----------
1189 1189
1190 1190 targets: list of ints or 'all' [default: all]
1191 1191 Which engines to shutdown.
1192 1192 hub: bool [default: False]
1193 1193 Whether to include the Hub. hub=True implies targets='all'.
1194 1194 block: bool [default: self.block]
1195 1195 Whether to wait for clean shutdown replies or not.
1196 1196 restart: bool [default: False]
1197 1197 NOT IMPLEMENTED
1198 1198 whether to restart engines after shutting them down.
1199 1199 """
1200 1200 from IPython.parallel.error import NoEnginesRegistered
1201 1201 if restart:
1202 1202 raise NotImplementedError("Engine restart is not yet implemented")
1203 1203
1204 1204 block = self.block if block is None else block
1205 1205 if hub:
1206 1206 targets = 'all'
1207 1207 try:
1208 1208 targets = self._build_targets(targets)[0]
1209 1209 except NoEnginesRegistered:
1210 1210 targets = []
1211 1211 for t in targets:
1212 1212 self.session.send(self._control_socket, 'shutdown_request',
1213 1213 content={'restart':restart},ident=t)
1214 1214 error = False
1215 1215 if block or hub:
1216 1216 self._flush_ignored_control()
1217 1217 for i in range(len(targets)):
1218 1218 idents,msg = self.session.recv(self._control_socket, 0)
1219 1219 if self.debug:
1220 1220 pprint(msg)
1221 1221 if msg['content']['status'] != 'ok':
1222 1222 error = self._unwrap_exception(msg['content'])
1223 1223 else:
1224 1224 self._ignored_control_replies += len(targets)
1225 1225
1226 1226 if hub:
1227 1227 time.sleep(0.25)
1228 1228 self.session.send(self._query_socket, 'shutdown_request')
1229 1229 idents,msg = self.session.recv(self._query_socket, 0)
1230 1230 if self.debug:
1231 1231 pprint(msg)
1232 1232 if msg['content']['status'] != 'ok':
1233 1233 error = self._unwrap_exception(msg['content'])
1234 1234
1235 1235 if error:
1236 1236 raise error
1237 1237
1238 1238 #--------------------------------------------------------------------------
1239 1239 # Execution related methods
1240 1240 #--------------------------------------------------------------------------
1241 1241
1242 1242 def _maybe_raise(self, result):
1243 1243 """wrapper for maybe raising an exception if apply failed."""
1244 1244 if isinstance(result, error.RemoteError):
1245 1245 raise result
1246 1246
1247 1247 return result
1248 1248
1249 1249 def send_apply_request(self, socket, f, args=None, kwargs=None, metadata=None, track=False,
1250 1250 ident=None):
1251 1251 """construct and send an apply message via a socket.
1252 1252
1253 1253 This is the principal method with which all engine execution is performed by views.
1254 1254 """
1255 1255
1256 1256 if self._closed:
1257 1257 raise RuntimeError("Client cannot be used after its sockets have been closed")
1258 1258
1259 1259 # defaults:
1260 1260 args = args if args is not None else []
1261 1261 kwargs = kwargs if kwargs is not None else {}
1262 1262 metadata = metadata if metadata is not None else {}
1263 1263
1264 1264 # validate arguments
1265 1265 if not callable(f) and not isinstance(f, Reference):
1266 1266 raise TypeError("f must be callable, not %s"%type(f))
1267 1267 if not isinstance(args, (tuple, list)):
1268 1268 raise TypeError("args must be tuple or list, not %s"%type(args))
1269 1269 if not isinstance(kwargs, dict):
1270 1270 raise TypeError("kwargs must be dict, not %s"%type(kwargs))
1271 1271 if not isinstance(metadata, dict):
1272 1272 raise TypeError("metadata must be dict, not %s"%type(metadata))
1273 1273
1274 1274 bufs = serialize.pack_apply_message(f, args, kwargs,
1275 1275 buffer_threshold=self.session.buffer_threshold,
1276 1276 item_threshold=self.session.item_threshold,
1277 1277 )
1278 1278
1279 1279 msg = self.session.send(socket, "apply_request", buffers=bufs, ident=ident,
1280 1280 metadata=metadata, track=track)
1281 1281
1282 1282 msg_id = msg['header']['msg_id']
1283 1283 self.outstanding.add(msg_id)
1284 1284 if ident:
1285 1285 # possibly routed to a specific engine
1286 1286 if isinstance(ident, list):
1287 1287 ident = ident[-1]
1288 1288 if ident in self._engines.values():
1289 1289 # save for later, in case of engine death
1290 1290 self._outstanding_dict[ident].add(msg_id)
1291 1291 self.history.append(msg_id)
1292 1292 self.metadata[msg_id]['submitted'] = datetime.now()
1293 1293
1294 1294 return msg
1295 1295
1296 1296 def send_execute_request(self, socket, code, silent=True, metadata=None, ident=None):
1297 1297 """construct and send an execute request via a socket.
1298 1298
1299 1299 """
1300 1300
1301 1301 if self._closed:
1302 1302 raise RuntimeError("Client cannot be used after its sockets have been closed")
1303 1303
1304 1304 # defaults:
1305 1305 metadata = metadata if metadata is not None else {}
1306 1306
1307 1307 # validate arguments
1308 1308 if not isinstance(code, string_types):
1309 1309 raise TypeError("code must be text, not %s" % type(code))
1310 1310 if not isinstance(metadata, dict):
1311 1311 raise TypeError("metadata must be dict, not %s" % type(metadata))
1312 1312
1313 1313 content = dict(code=code, silent=bool(silent), user_expressions={})
1314 1314
1315 1315
1316 1316 msg = self.session.send(socket, "execute_request", content=content, ident=ident,
1317 1317 metadata=metadata)
1318 1318
1319 1319 msg_id = msg['header']['msg_id']
1320 1320 self.outstanding.add(msg_id)
1321 1321 if ident:
1322 1322 # possibly routed to a specific engine
1323 1323 if isinstance(ident, list):
1324 1324 ident = ident[-1]
1325 1325 if ident in self._engines.values():
1326 1326 # save for later, in case of engine death
1327 1327 self._outstanding_dict[ident].add(msg_id)
1328 1328 self.history.append(msg_id)
1329 1329 self.metadata[msg_id]['submitted'] = datetime.now()
1330 1330
1331 1331 return msg
1332 1332
1333 1333 #--------------------------------------------------------------------------
1334 1334 # construct a View object
1335 1335 #--------------------------------------------------------------------------
1336 1336
1337 1337 def load_balanced_view(self, targets=None):
1338 1338 """construct a DirectView object.
1339 1339
1340 1340 If no arguments are specified, create a LoadBalancedView
1341 1341 using all engines.
1342 1342
1343 1343 Parameters
1344 1344 ----------
1345 1345
1346 1346 targets: list,slice,int,etc. [default: use all engines]
1347 1347 The subset of engines across which to load-balance
1348 1348 """
1349 1349 if targets == 'all':
1350 1350 targets = None
1351 1351 if targets is not None:
1352 1352 targets = self._build_targets(targets)[1]
1353 1353 return LoadBalancedView(client=self, socket=self._task_socket, targets=targets)
1354 1354
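For instance (hedged sketch; the engine ids are illustrative):

    lview = rc.load_balanced_view()       # balance across all engines
    lsub = rc.load_balanced_view([0, 2])  # balance across engines 0 and 2 only
    ar = lview.apply_async(slow, 1)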
1355 1355 def direct_view(self, targets='all'):
1356 1356 """construct a DirectView object.
1357 1357
1358 1358 If no targets are specified, create a DirectView using all engines.
1359 1359
1360 1360 rc.direct_view('all') is distinguished from rc[:] in that 'all' will
1361 1361 evaluate the target engines at each execution, whereas rc[:] will connect to
1362 1362 all *current* engines, and that list will not change.
1363 1363
1364 1364 That is, 'all' will always use all engines, whereas rc[:] will not use
1365 1365 engines added after the DirectView is constructed.
1366 1366
1367 1367 Parameters
1368 1368 ----------
1369 1369
1370 1370 targets: list,slice,int,etc. [default: use all engines]
1371 1371 The engines to use for the View
1372 1372 """
1373 1373 single = isinstance(targets, int)
1374 1374 # allow 'all' to be lazily evaluated at each execution
1375 1375 if targets != 'all':
1376 1376 targets = self._build_targets(targets)[1]
1377 1377 if single:
1378 1378 targets = targets[0]
1379 1379 return DirectView(client=self, socket=self._mux_socket, targets=targets)
1380 1380
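The 'all'-vs-snapshot distinction in the docstring, as a sketch:

    lazy = rc.direct_view('all')  # re-resolves the engine list at each execution
    fixed = rc[:]                 # bound to the engines registered right now
    # engines added later participate in `lazy` runs, but not in `fixed` ones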
1381 1381 #--------------------------------------------------------------------------
1382 1382 # Query methods
1383 1383 #--------------------------------------------------------------------------
1384 1384
1385 1385 @spin_first
1386 1386 def get_result(self, indices_or_msg_ids=None, block=None, owner=True):
1387 1387 """Retrieve a result by msg_id or history index, wrapped in an AsyncResult object.
1388 1388
1389 1389 If the client already has the results, no request to the Hub will be made.
1390 1390
1391 1391 This is a convenient way to construct AsyncResult objects, which are wrappers
1392 1392 that include metadata about execution, and allow for awaiting results that
1393 1393 were not submitted by this Client.
1394 1394
1395 1395 It can also be a convenient way to retrieve the metadata associated with
1396 1396 blocking execution, since it always retrieves the full record, metadata included.
1397 1397
1398 1398 Examples
1399 1399 --------
1400 1400 ::
1401 1401
1402 1402 In [10]: ar = client.get_result(msg_id)
1403 1403
1404 1404 Parameters
1405 1405 ----------
1406 1406
1407 1407 indices_or_msg_ids : integer history index, str msg_id, or list of either
1408 1408 The indices or msg_ids of the results to be retrieved
1409 1409
1410 1410 block : bool
1411 1411 Whether to wait for the result to be done
1412 1412 owner : bool [default: True]
1413 1413 Whether this AsyncResult should own the result.
1414 1414 If so, calling `ar.get()` will remove data from the
1415 1415 client's result and metadata cache.
1416 1416 There should only be one owner of any given msg_id.
1417 1417
1418 1418 Returns
1419 1419 -------
1420 1420
1421 1421 AsyncResult
1422 1422 A single AsyncResult object will always be returned.
1423 1423
1424 1424 AsyncHubResult
1425 1425 A subclass of AsyncResult that retrieves results from the Hub
1426 1426
1427 1427 """
1428 1428 block = self.block if block is None else block
1429 1429 if indices_or_msg_ids is None:
1430 1430 indices_or_msg_ids = -1
1431 1431
1432 1432 single_result = False
1433 1433 if not isinstance(indices_or_msg_ids, (list,tuple)):
1434 1434 indices_or_msg_ids = [indices_or_msg_ids]
1435 1435 single_result = True
1436 1436
1437 1437 theids = []
1438 1438 for id in indices_or_msg_ids:
1439 1439 if isinstance(id, int):
1440 1440 id = self.history[id]
1441 1441 if not isinstance(id, string_types):
1442 1442 raise TypeError("indices must be str or int, not %r"%id)
1443 1443 theids.append(id)
1444 1444
1445 1445 local_ids = [msg_id for msg_id in theids if (msg_id in self.outstanding or msg_id in self.results)]
1446 1446 remote_ids = [msg_id for msg_id in theids if msg_id not in local_ids]
1447 1447
1448 1448 # given a single msg_id initially, get_result should get the result itself,
1449 1449 # not a length-one list
1450 1450 if single_result:
1451 1451 theids = theids[0]
1452 1452
1453 1453 if remote_ids:
1454 1454 ar = AsyncHubResult(self, msg_ids=theids, owner=owner)
1455 1455 else:
1456 1456 ar = AsyncResult(self, msg_ids=theids, owner=owner)
1457 1457
1458 1458 if block:
1459 1459 ar.wait()
1460 1460
1461 1461 return ar
1462 1462
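A usage sketch for `get_result` (hedged; builds on the `slow` example above):

    ar = view.apply_async(slow, 1)
    msg_id = ar.msg_ids[0]
    # reconstruct an AsyncResult later, without evicting the cached result:
    ar2 = rc.get_result(msg_id, block=True, owner=False)
    print(ar2.metadata)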
1463 1463 @spin_first
1464 1464 def resubmit(self, indices_or_msg_ids=None, metadata=None, block=None):
1465 1465 """Resubmit one or more tasks.
1466 1466
1467 1467 in-flight tasks may not be resubmitted.
1468 1468
1469 1469 Parameters
1470 1470 ----------
1471 1471
1472 1472 indices_or_msg_ids : integer history index, str msg_id, or list of either
1473 1473 The indices or msg_ids of the results to be retrieved
1474 1474
1475 1475 block : bool
1476 1476 Whether to wait for the result to be done
1477 1477
1478 1478 Returns
1479 1479 -------
1480 1480
1481 1481 AsyncHubResult
1482 1482 A subclass of AsyncResult that retrieves results from the Hub
1483 1483
1484 1484 """
1485 1485 block = self.block if block is None else block
1486 1486 if indices_or_msg_ids is None:
1487 1487 indices_or_msg_ids = -1
1488 1488
1489 1489 if not isinstance(indices_or_msg_ids, (list,tuple)):
1490 1490 indices_or_msg_ids = [indices_or_msg_ids]
1491 1491
1492 1492 theids = []
1493 1493 for id in indices_or_msg_ids:
1494 1494 if isinstance(id, int):
1495 1495 id = self.history[id]
1496 1496 if not isinstance(id, string_types):
1497 1497 raise TypeError("indices must be str or int, not %r"%id)
1498 1498 theids.append(id)
1499 1499
1500 1500 content = dict(msg_ids = theids)
1501 1501
1502 1502 self.session.send(self._query_socket, 'resubmit_request', content)
1503 1503
1504 1504 zmq.select([self._query_socket], [], [])
1505 1505 idents,msg = self.session.recv(self._query_socket, zmq.NOBLOCK)
1506 1506 if self.debug:
1507 1507 pprint(msg)
1508 1508 content = msg['content']
1509 1509 if content['status'] != 'ok':
1510 1510 raise self._unwrap_exception(content)
1511 1511 mapping = content['resubmitted']
1512 1512 new_ids = [ mapping[msg_id] for msg_id in theids ]
1513 1513
1514 1514 ar = AsyncHubResult(self, msg_ids=new_ids)
1515 1515
1516 1516 if block:
1517 1517 ar.wait()
1518 1518
1519 1519 return ar
1520 1520
1521 1521 @spin_first
1522 1522 def result_status(self, msg_ids, status_only=True):
1523 1523 """Check on the status of the result(s) of the apply request with `msg_ids`.
1524 1524
1525 1525 If status_only is False, then the actual results will be retrieved, else
1526 1526 only the status of the results will be checked.
1527 1527
1528 1528 Parameters
1529 1529 ----------
1530 1530
1531 1531 msg_ids : list of msg_ids
1532 1532 if int:
1533 1533 Passed as index to self.history for convenience.
1534 1534 status_only : bool (default: True)
1535 1535 if False:
1536 1536 Retrieve the actual results of completed tasks.
1537 1537
1538 1538 Returns
1539 1539 -------
1540 1540
1541 1541 results : dict
1542 1542 There will always be the keys 'pending' and 'completed', which will
1543 1543 be lists of msg_ids that are incomplete or complete. If `status_only`
1544 1544 is False, then completed results will be keyed by their `msg_id`.
1545 1545 """
1546 1546 if not isinstance(msg_ids, (list,tuple)):
1547 1547 msg_ids = [msg_ids]
1548 1548
1549 1549 theids = []
1550 1550 for msg_id in msg_ids:
1551 1551 if isinstance(msg_id, int):
1552 1552 msg_id = self.history[msg_id]
1553 1553 if not isinstance(msg_id, string_types):
1554 1554 raise TypeError("msg_ids must be str, not %r"%msg_id)
1555 1555 theids.append(msg_id)
1556 1556
1557 1557 completed = []
1558 1558 local_results = {}
1559 1559
1560 1560 # comment this block out to temporarily disable local shortcut:
1561 1561 for msg_id in theids:
1562 1562 if msg_id in self.results:
1563 1563 completed.append(msg_id)
1564 1564 local_results[msg_id] = self.results[msg_id]
1565 1565 theids.remove(msg_id)
1566 1566
1567 1567 if theids: # some not locally cached
1568 1568 content = dict(msg_ids=theids, status_only=status_only)
1569 1569 msg = self.session.send(self._query_socket, "result_request", content=content)
1570 1570 zmq.select([self._query_socket], [], [])
1571 1571 idents,msg = self.session.recv(self._query_socket, zmq.NOBLOCK)
1572 1572 if self.debug:
1573 1573 pprint(msg)
1574 1574 content = msg['content']
1575 1575 if content['status'] != 'ok':
1576 1576 raise self._unwrap_exception(content)
1577 1577 buffers = msg['buffers']
1578 1578 else:
1579 1579 content = dict(completed=[],pending=[])
1580 1580
1581 1581 content['completed'].extend(completed)
1582 1582
1583 1583 if status_only:
1584 1584 return content
1585 1585
1586 1586 failures = []
1587 1587 # load cached results into result:
1588 1588 content.update(local_results)
1589 1589
1590 1590 # update cache with results:
1591 1591 for msg_id in sorted(theids):
1592 1592 if msg_id in content['completed']:
1593 1593 rec = content[msg_id]
1594 1594 parent = extract_dates(rec['header'])
1595 1595 header = extract_dates(rec['result_header'])
1596 1596 rcontent = rec['result_content']
1597 1597 iodict = rec['io']
1598 1598 if isinstance(rcontent, str):
1599 1599 rcontent = self.session.unpack(rcontent)
1600 1600
1601 1601 md = self.metadata[msg_id]
1602 1602 md_msg = dict(
1603 1603 content=rcontent,
1604 1604 parent_header=parent,
1605 1605 header=header,
1606 1606 metadata=rec['result_metadata'],
1607 1607 )
1608 1608 md.update(self._extract_metadata(md_msg))
1609 1609 if rec.get('received'):
1610 1610 md['received'] = parse_date(rec['received'])
1611 1611 md.update(iodict)
1612 1612
1613 1613 if rcontent['status'] == 'ok':
1614 1614 if header['msg_type'] == 'apply_reply':
1615 1615 res,buffers = serialize.deserialize_object(buffers)
1616 1616 elif header['msg_type'] == 'execute_reply':
1617 1617 res = ExecuteReply(msg_id, rcontent, md)
1618 1618 else:
1619 1619 raise KeyError("unhandled msg type: %r" % header['msg_type'])
1620 1620 else:
1621 1621 res = self._unwrap_exception(rcontent)
1622 1622 failures.append(res)
1623 1623
1624 1624 self.results[msg_id] = res
1625 1625 content[msg_id] = res
1626 1626
1627 1627 if len(theids) == 1 and failures:
1628 1628 raise failures[0]
1629 1629
1630 1630 error.collect_exceptions(failures, "result_status")
1631 1631 return content
1632 1632
1633 1633 @spin_first
1634 1634 def queue_status(self, targets='all', verbose=False):
1635 1635 """Fetch the status of engine queues.
1636 1636
1637 1637 Parameters
1638 1638 ----------
1639 1639
1640 1640 targets : int/str/list of ints/strs
1641 1641 the engines whose states are to be queried.
1642 1642 default : all
1643 1643 verbose : bool
1644 1644 Whether to return lengths only, or lists of ids for each element
1645 1645 """
1646 1646 if targets == 'all':
1647 1647 # allow 'all' to be evaluated on the engine
1648 1648 engine_ids = None
1649 1649 else:
1650 1650 engine_ids = self._build_targets(targets)[1]
1651 1651 content = dict(targets=engine_ids, verbose=verbose)
1652 1652 self.session.send(self._query_socket, "queue_request", content=content)
1653 1653 idents,msg = self.session.recv(self._query_socket, 0)
1654 1654 if self.debug:
1655 1655 pprint(msg)
1656 1656 content = msg['content']
1657 1657 status = content.pop('status')
1658 1658 if status != 'ok':
1659 1659 raise self._unwrap_exception(content)
1660 1660 content = rekey(content)
1661 1661 if isinstance(targets, int):
1662 1662 return content[targets]
1663 1663 else:
1664 1664 return content
1665 1665
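For example (hedged; the exact payload shape comes from the Hub):

    status = rc.queue_status()                  # counts per engine id, plus 'unassigned'
    status0 = rc.queue_status(0, verbose=True)  # msg_id lists instead of counts, engine 0 only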
1666 1666 def _build_msgids_from_target(self, targets=None):
1667 1667 """Build a list of msg_ids from the list of engine targets"""
1668 1668 if not targets: # needed as _build_targets otherwise uses all engines
1669 1669 return []
1670 1670 target_ids = self._build_targets(targets)[0]
1671 1671 return [md_id for md_id in self.metadata if self.metadata[md_id]["engine_uuid"] in target_ids]
1672 1672
1673 1673 def _build_msgids_from_jobs(self, jobs=None):
1674 1674 """Build a list of msg_ids from "jobs" """
1675 1675 if not jobs:
1676 1676 return []
1677 1677 msg_ids = []
1678 1678 if isinstance(jobs, string_types + (AsyncResult,)):
1679 1679 jobs = [jobs]
1680 1680 bad_ids = [obj for obj in jobs if not isinstance(obj, string_types + (AsyncResult,))]
1681 1681 if bad_ids:
1682 1682 raise TypeError("Invalid msg_id type %r, expected str or AsyncResult"%bad_ids[0])
1683 1683 for j in jobs:
1684 1684 if isinstance(j, AsyncResult):
1685 1685 msg_ids.extend(j.msg_ids)
1686 1686 else:
1687 1687 msg_ids.append(j)
1688 1688 return msg_ids
1689 1689
1690 1690 def purge_local_results(self, jobs=[], targets=[]):
1691 1691 """Clears the client caches of results and their metadata.
1692 1692
1693 1693 Individual results can be purged by msg_id, or the entire
1694 1694 history of specific targets can be purged.
1695 1695
1696 1696 Use `purge_local_results('all')` to scrub everything from the Client's
1697 1697 results and metadata caches.
1698 1698
1699 1699 After this call all `AsyncResults` are invalid and should be discarded.
1700 1700
1701 1701 If you must "reget" the results, you can still do so by using
1702 1702 `client.get_result(msg_id)` or `client.get_result(asyncresult)`. This will
1703 1703 redownload the results from the hub if they are still available
1704 1704 (i.e. `client.purge_hub_results(...)` has not been called).
1705 1705
1706 1706 Parameters
1707 1707 ----------
1708 1708
1709 1709 jobs : str or list of str or AsyncResult objects
1710 1710 the msg_ids whose results should be purged.
1711 1711 targets : int/list of ints
1712 1712 The engines, by integer ID, whose entire result histories are to be purged.
1713 1713
1714 1714 Raises
1715 1715 ------
1716 1716
1717 1717 RuntimeError : if any of the tasks to be purged are still outstanding.
1718 1718
1719 1719 """
1720 1720 if not targets and not jobs:
1721 1721 raise ValueError("Must specify at least one of `targets` and `jobs`")
1722 1722
1723 1723 if jobs == 'all':
1724 1724 if self.outstanding:
1725 1725 raise RuntimeError("Can't purge outstanding tasks: %s" % self.outstanding)
1726 1726 self.results.clear()
1727 1727 self.metadata.clear()
1728 1728 else:
1729 1729 msg_ids = set()
1730 1730 msg_ids.update(self._build_msgids_from_target(targets))
1731 1731 msg_ids.update(self._build_msgids_from_jobs(jobs))
1732 1732 still_outstanding = self.outstanding.intersection(msg_ids)
1733 1733 if still_outstanding:
1734 1734 raise RuntimeError("Can't purge outstanding tasks: %s" % still_outstanding)
1735 1735 for mid in msg_ids:
1736 1736 self.results.pop(mid, None)
1737 1737 self.metadata.pop(mid, None)
1738 1738
1739 1739
1740 1740 @spin_first
1741 1741 def purge_hub_results(self, jobs=[], targets=[]):
1742 1742 """Tell the Hub to forget results.
1743 1743
1744 1744 Individual results can be purged by msg_id, or the entire
1745 1745 history of specific targets can be purged.
1746 1746
1747 1747 Use `purge_hub_results('all')` to scrub everything from the Hub's db.
1748 1748
1749 1749 Parameters
1750 1750 ----------
1751 1751
1752 1752 jobs : str or list of str or AsyncResult objects
1753 1753 the msg_ids whose results should be forgotten.
1754 1754 targets : int/str/list of ints/strs
1755 1755 The targets, by int_id, whose entire history is to be purged.
1756 1756
1757 1757 default : None
1758 1758 """
1759 1759 if not targets and not jobs:
1760 1760 raise ValueError("Must specify at least one of `targets` and `jobs`")
1761 1761 if targets:
1762 1762 targets = self._build_targets(targets)[1]
1763 1763
1764 1764 # construct msg_ids from jobs
1765 1765 if jobs == 'all':
1766 1766 msg_ids = jobs
1767 1767 else:
1768 1768 msg_ids = self._build_msgids_from_jobs(jobs)
1769 1769
1770 1770 content = dict(engine_ids=targets, msg_ids=msg_ids)
1771 1771 self.session.send(self._query_socket, "purge_request", content=content)
1772 1772 idents, msg = self.session.recv(self._query_socket, 0)
1773 1773 if self.debug:
1774 1774 pprint(msg)
1775 1775 content = msg['content']
1776 1776 if content['status'] != 'ok':
1777 1777 raise self._unwrap_exception(content)
1778 1778
1779 1779 def purge_results(self, jobs=[], targets=[]):
1780 1780 """Clears the cached results from both the hub and the local client
1781 1781
1782 1782 Individual results can be purged by msg_id, or the entire
1783 1783 history of specific targets can be purged.
1784 1784
1785 1785 Use `purge_results('all')` to scrub every cached result from both the Hub's and
1786 1786 the Client's db.
1787 1787
1788 1788 Equivalent to calling both `purge_hub_results()` and `purge_local_results()` with
1789 1789 the same arguments.
1790 1790
1791 1791 Parameters
1792 1792 ----------
1793 1793
1794 1794 jobs : str or list of str or AsyncResult objects
1795 1795 the msg_ids whose results should be forgotten.
1796 1796 targets : int/str/list of ints/strs
1797 1797 The targets, by int_id, whose entire history is to be purged.
1798 1798
1799 1799 default : None
1800 1800 """
1801 1801 self.purge_local_results(jobs=jobs, targets=targets)
1802 1802 self.purge_hub_results(jobs=jobs, targets=targets)
1803 1803
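Putting the three purge methods together (hedged sketch, continuing the examples above):

    ar = view.apply_async(slow, 1)
    ar.get()
    rc.purge_results(jobs=ar)      # forget this result locally and on the Hub
    rc.purge_results(targets=[0])  # forget everything engine 0 has run
    rc.purge_results('all')        # scrub both caches entirely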
1804 1804 def purge_everything(self):
1805 1805 """Clears all content from previous Tasks from both the hub and the local client
1806 1806
1807 1807 In addition to calling `purge_results("all")` it also deletes the history and
1808 1808 other bookkeeping lists.
1809 1809 """
1810 1810 self.purge_results("all")
1811 1811 self.history = []
1812 1812 self.session.digest_history.clear()
1813 1813
1814 1814 @spin_first
1815 1815 def hub_history(self):
1816 1816 """Get the Hub's history
1817 1817
1818 1818 Just like the Client, the Hub has a history, which is a list of msg_ids.
1819 1819 This will contain the history of all clients, and, depending on configuration,
1820 1820 may contain history across multiple cluster sessions.
1821 1821
1822 1822 Any msg_id returned here is a valid argument to `get_result`.
1823 1823
1824 1824 Returns
1825 1825 -------
1826 1826
1827 1827 msg_ids : list of strs
1828 1828 list of all msg_ids, ordered by task submission time.
1829 1829 """
1830 1830
1831 1831 self.session.send(self._query_socket, "history_request", content={})
1832 1832 idents, msg = self.session.recv(self._query_socket, 0)
1833 1833
1834 1834 if self.debug:
1835 1835 pprint(msg)
1836 1836 content = msg['content']
1837 1837 if content['status'] != 'ok':
1838 1838 raise self._unwrap_exception(content)
1839 1839 else:
1840 1840 return content['history']
1841 1841
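Combined with `get_result`, this allows inspecting work submitted by other clients; a sketch:

    for msg_id in rc.hub_history()[-5:]:  # five most recently submitted tasks, cluster-wide
        ar = rc.get_result(msg_id, block=False, owner=False)
        print(msg_id, ar.ready())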
1842 1842 @spin_first
1843 1843 def db_query(self, query, keys=None):
1844 1844 """Query the Hub's TaskRecord database
1845 1845
1846 1846 This will return a list of task record dicts that match `query`
1847 1847
1848 1848 Parameters
1849 1849 ----------
1850 1850
1851 1851 query : mongodb query dict
1852 1852 The search dict. See mongodb query docs for details.
1853 1853 keys : list of strs [optional]
1854 1854 The subset of keys to be returned. The default is to fetch everything but buffers.
1855 1855 'msg_id' will *always* be included.
1856 1856 """
1857 1857 if isinstance(keys, string_types):
1858 1858 keys = [keys]
1859 1859 content = dict(query=query, keys=keys)
1860 1860 self.session.send(self._query_socket, "db_request", content=content)
1861 1861 idents, msg = self.session.recv(self._query_socket, 0)
1862 1862 if self.debug:
1863 1863 pprint(msg)
1864 1864 content = msg['content']
1865 1865 if content['status'] != 'ok':
1866 1866 raise self._unwrap_exception(content)
1867 1867
1868 1868 records = content['records']
1869 1869
1870 1870 buffer_lens = content['buffer_lens']
1871 1871 result_buffer_lens = content['result_buffer_lens']
1872 1872 buffers = msg['buffers']
1873 1873 has_bufs = buffer_lens is not None
1874 1874 has_rbufs = result_buffer_lens is not None
1875 1875 for i,rec in enumerate(records):
1876 1876 # unpack datetime objects
1877 1877 for hkey in ('header', 'result_header'):
1878 1878 if hkey in rec:
1879 1879 rec[hkey] = extract_dates(rec[hkey])
1880 1880 for dtkey in ('submitted', 'started', 'completed', 'received'):
1881 1881 if dtkey in rec:
1882 1882 rec[dtkey] = parse_date(rec[dtkey])
1883 1883 # relink buffers
1884 1884 if has_bufs:
1885 1885 blen = buffer_lens[i]
1886 1886 rec['buffers'], buffers = buffers[:blen],buffers[blen:]
1887 1887 if has_rbufs:
1888 1888 blen = result_buffer_lens[i]
1889 1889 rec['result_buffers'], buffers = buffers[:blen],buffers[blen:]
1890 1890
1891 1891 return records
1892 1892
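A hedged query sketch; the operator syntax follows MongoDB, as the docstring notes, and the chosen keys are illustrative:

    recs = rc.db_query({'msg_id': {'$in': rc.hub_history()[-10:]}},
                       keys=['msg_id', 'completed', 'engine_uuid'])
    for rec in recs:
        print(rec['msg_id'], rec['completed'])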
1893 1893 __all__ = [ 'Client' ]
@@ -1,276 +1,276 b''
1 1 """Remote Functions and decorators for Views."""
2 2
3 3 # Copyright (c) IPython Development Team.
4 4 # Distributed under the terms of the Modified BSD License.
5 5
6 6 from __future__ import division
7 7
8 8 import sys
9 9 import warnings
10 10
11 from IPython.external.decorator import decorator
11 from decorator import decorator
12 12 from IPython.testing.skipdoctest import skip_doctest
13 13
14 14 from . import map as Map
15 15 from .asyncresult import AsyncMapResult
16 16
17 17 #-----------------------------------------------------------------------------
18 18 # Functions and Decorators
19 19 #-----------------------------------------------------------------------------
20 20
21 21 @skip_doctest
22 22 def remote(view, block=None, **flags):
23 23 """Turn a function into a remote function.
24 24
25 25 This method can be used for map:
26 26
27 27 In [1]: @remote(view,block=True)
28 28 ...: def func(a):
29 29 ...: pass
30 30 """
31 31
32 32 def remote_function(f):
33 33 return RemoteFunction(view, f, block=block, **flags)
34 34 return remote_function
35 35
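A concrete sketch of the decorator (hedged; `view` is any View obtained from a connected Client):

    rc = Client()
    view = rc[:]

    @remote(view, block=True)
    def hostname():
        import socket
        return socket.gethostname()

    hostname()  # executes on the view's engines, returning their results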
36 36 @skip_doctest
37 37 def parallel(view, dist='b', block=None, ordered=True, **flags):
38 38 """Turn a function into a parallel remote function.
39 39
40 40 This method can be used for map:
41 41
42 42 In [1]: @parallel(view, block=True)
43 43 ...: def func(a):
44 44 ...: pass
45 45 """
46 46
47 47 def parallel_function(f):
48 48 return ParallelFunction(view, f, dist=dist, block=block, ordered=ordered, **flags)
49 49 return parallel_function
50 50
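And the parallel counterpart; note the element-wise vs chunk-wise distinction described for ParallelFunction below (sketch, reusing `view` from the previous example):

    @parallel(view, block=True)
    def echo(x):
        return x

    echo.map(range(4))  # called once per element, distributed across engines
    echo(range(4))      # called once per sub-sequence (chunked by the mapper)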
51 51 def getname(f):
52 52 """Get the name of an object.
53 53
54 54 For use in case of callables that are not functions, and
55 55 thus may not have __name__ defined.
56 56
57 57 Order: f.__name__ > f.name > str(f)
58 58 """
59 59 try:
60 60 return f.__name__
61 61 except:
62 62 pass
63 63 try:
64 64 return f.name
65 65 except:
66 66 pass
67 67
68 68 return str(f)
69 69
70 70 @decorator
71 71 def sync_view_results(f, self, *args, **kwargs):
72 72 """sync relevant results from self.client to our results attribute.
73 73
74 74 This is a clone of view.sync_results, but for remote functions
75 75 """
76 76 view = self.view
77 77 if view._in_sync_results:
78 78 return f(self, *args, **kwargs)
79 79 view._in_sync_results = True
80 80 try:
81 81 ret = f(self, *args, **kwargs)
82 82 finally:
83 83 view._in_sync_results = False
84 84 view._sync_results()
85 85 return ret
86 86
87 87 #--------------------------------------------------------------------------
88 88 # Classes
89 89 #--------------------------------------------------------------------------
90 90
91 91 class RemoteFunction(object):
92 92 """Turn an existing function into a remote function.
93 93
94 94 Parameters
95 95 ----------
96 96
97 97 view : View instance
98 98 The view to be used for execution
99 99 f : callable
100 100 The function to be wrapped into a remote function
101 101 block : bool [default: None]
102 102 Whether to wait for results or not. The default behavior is
103 103 to use the current `block` attribute of `view`
104 104
105 105 **flags : remaining kwargs are passed to View.temp_flags
106 106 """
107 107
108 108 view = None # the remote connection
109 109 func = None # the wrapped function
110 110 block = None # whether to block
111 111 flags = None # dict of extra kwargs for temp_flags
112 112
113 113 def __init__(self, view, f, block=None, **flags):
114 114 self.view = view
115 115 self.func = f
116 116 self.block=block
117 117 self.flags=flags
118 118
119 119 def __call__(self, *args, **kwargs):
120 120 block = self.view.block if self.block is None else self.block
121 121 with self.view.temp_flags(block=block, **self.flags):
122 122 return self.view.apply(self.func, *args, **kwargs)
123 123
124 124
125 125 class ParallelFunction(RemoteFunction):
126 126 """Class for mapping a function to sequences.
127 127
128 128 This will distribute the sequences according to a mapper, and call
129 129 the function on each sub-sequence. If called via map, then the function
130 130 will be called once on each element, rather than on each sub-sequence.
131 131
132 132 Parameters
133 133 ----------
134 134
135 135 view : View instance
136 136 The view to be used for execution
137 137 f : callable
138 138 The function to be wrapped into a remote function
139 139 dist : str [default: 'b']
140 140 The key for which mapObject to use to distribute sequences
141 141 options are:
142 142
143 143 * 'b' : use contiguous chunks in order
144 144 * 'r' : use round-robin striping
145 145
146 146 block : bool [default: None]
147 147 Whether to wait for results or not. The default behavior is
148 148 to use the current `block` attribute of `view`
149 149 chunksize : int or None
150 150 The size of chunk to use when breaking up sequences in a load-balanced manner
151 151 ordered : bool [default: True]
152 152 Whether the result should be kept in order. If False,
153 153 results become available as they arrive, regardless of submission order.
154 154 **flags
155 155 remaining kwargs are passed to View.temp_flags
156 156 """
157 157
158 158 chunksize = None
159 159 ordered = None
160 160 mapObject = None
161 161 _mapping = False
162 162
163 163 def __init__(self, view, f, dist='b', block=None, chunksize=None, ordered=True, **flags):
164 164 super(ParallelFunction, self).__init__(view, f, block=block, **flags)
165 165 self.chunksize = chunksize
166 166 self.ordered = ordered
167 167
168 168 mapClass = Map.dists[dist]
169 169 self.mapObject = mapClass()
170 170
171 171 @sync_view_results
172 172 def __call__(self, *sequences):
173 173 client = self.view.client
174 174
175 175 lens = []
176 176 maxlen = minlen = -1
177 177 for i, seq in enumerate(sequences):
178 178 try:
179 179 n = len(seq)
180 180 except Exception:
181 181 seq = list(seq)
182 182 if isinstance(sequences, tuple):
183 183 # can't alter a tuple
184 184 sequences = list(sequences)
185 185 sequences[i] = seq
186 186 n = len(seq)
187 187 if n > maxlen:
188 188 maxlen = n
189 189 if minlen == -1 or n < minlen:
190 190 minlen = n
191 191 lens.append(n)
192 192
193 193 if maxlen == 0:
194 194 # nothing to iterate over
195 195 return []
196 196
197 197 # check that the length of sequences match
198 198 if not self._mapping and minlen != maxlen:
199 199 msg = 'all sequences must have equal length, but have lengths %s' % lens
200 200 raise ValueError(msg)
201 201
202 202 balanced = 'Balanced' in self.view.__class__.__name__  # LoadBalancedView?
203 203 if balanced:
204 204 if self.chunksize:
205 205 nparts = maxlen // self.chunksize + int(maxlen % self.chunksize > 0)
206 206 else:
207 207 nparts = maxlen
208 208 targets = [None]*nparts
209 209 else:
210 210 if self.chunksize:
211 211 warnings.warn("`chunksize` is ignored unless load balancing", UserWarning)
212 212 # multiplexed:
213 213 targets = self.view.targets
214 214 # 'all' is lazily evaluated at execution time, which is now:
215 215 if targets == 'all':
216 216 targets = client._build_targets(targets)[1]
217 217 elif isinstance(targets, int):
218 218 # single-engine view, targets must be iterable
219 219 targets = [targets]
220 220 nparts = len(targets)
221 221
222 222 msg_ids = []
223 223 for index, t in enumerate(targets):
224 224 args = []
225 225 for seq in sequences:
226 226 part = self.mapObject.getPartition(seq, index, nparts, maxlen)
227 227 args.append(part)
228 228
229 229 if sum(len(arg) for arg in args) == 0:
230 230 continue
231 231
232 232 if self._mapping:
233 233 if sys.version_info[0] >= 3:
234 234 f = lambda f, *sequences: list(map(f, *sequences))  # py3 map is lazy; return a list
235 235 else:
236 236 f = map
237 237 args = [self.func] + args
238 238 else:
239 239 f = self.func
240 240
241 241 view = self.view if balanced else client[t]
242 242 with view.temp_flags(block=False, **self.flags):
243 243 ar = view.apply(f, *args)
244 244
245 245 msg_ids.extend(ar.msg_ids)
246 246
247 247 r = AsyncMapResult(self.view.client, msg_ids, self.mapObject,
248 248 fname=getname(self.func),
249 249 ordered=self.ordered
250 250 )
251 251
252 252 if self.block:
253 253 try:
254 254 return r.get()
255 255 except KeyboardInterrupt:
256 256 return r
257 257 else:
258 258 return r
259 259
260 260 def map(self, *sequences):
261 261 """call a function on each element of one or more sequence(s) remotely.
262 262 This should behave very much like the builtin map, but returns an AsyncMapResult
263 263 if self.block is False.
264 264
265 265 That means it can take generators (will be cast to lists locally),
266 266 and mismatched sequence lengths will be padded with None.
267 267 """
268 268 # set _mapping as a flag for use inside self.__call__
269 269 self._mapping = True
270 270 try:
271 271 ret = self(*sequences)
272 272 finally:
273 273 self._mapping = False
274 274 return ret
275 275
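# Sketch of the call vs. map distinction above (illustrative; assumes
# `dv = rc[:]` on a two-engine cluster, so dist='b' yields two chunks):
#
#   >>> @dv.parallel(block=True)
#   ... def echo(x):
#   ...     return str(x)
#   >>> echo([0, 1, 2, 3])        # called once per *sub-sequence*
#   ['[0, 1]', '[2, 3]']
#   >>> echo.map([0, 1, 2, 3])    # called once per *element*
#   ['0', '1', '2', '3']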
276 276 __all__ = ['remote', 'parallel', 'RemoteFunction', 'ParallelFunction']
@@ -1,1125 +1,1125 b''
1 1 """Views of remote engines."""
2 2
3 3 # Copyright (c) IPython Development Team.
4 4 # Distributed under the terms of the Modified BSD License.
5 5
6 6 from __future__ import print_function
7 7
8 8 import imp
9 9 import sys
10 10 import warnings
11 11 from contextlib import contextmanager
12 12 from types import ModuleType
13 13
14 14 import zmq
15 15
16 16 from IPython.testing.skipdoctest import skip_doctest
17 17 from IPython.utils import pickleutil
18 18 from IPython.utils.traitlets import (
19 19 HasTraits, Any, Bool, List, Dict, Set, Instance, CFloat, Integer
20 20 )
21 from IPython.external.decorator import decorator
21 from decorator import decorator
22 22
23 23 from IPython.parallel import util
24 24 from IPython.parallel.controller.dependency import Dependency, dependent
25 25 from IPython.utils.py3compat import string_types, iteritems, PY3
26 26
27 27 from . import map as Map
28 28 from .asyncresult import AsyncResult, AsyncMapResult
29 29 from .remotefunction import ParallelFunction, parallel, remote, getname
30 30
31 31 #-----------------------------------------------------------------------------
32 32 # Decorators
33 33 #-----------------------------------------------------------------------------
34 34
35 35 @decorator
36 36 def save_ids(f, self, *args, **kwargs):
37 37 """Keep our history and outstanding attributes up to date after a method call."""
38 38 n_previous = len(self.client.history)
39 39 try:
40 40 ret = f(self, *args, **kwargs)
41 41 finally:
42 42 nmsgs = len(self.client.history) - n_previous
43 43 msg_ids = self.client.history[-nmsgs:]
44 44 self.history.extend(msg_ids)
45 45 self.outstanding.update(msg_ids)
46 46 return ret
47 47
48 48 @decorator
49 49 def sync_results(f, self, *args, **kwargs):
50 50 """sync relevant results from self.client to our results attribute."""
51 51 if self._in_sync_results:
52 52 return f(self, *args, **kwargs)
53 53 self._in_sync_results = True
54 54 try:
55 55 ret = f(self, *args, **kwargs)
56 56 finally:
57 57 self._in_sync_results = False
58 58 self._sync_results()
59 59 return ret
60 60
61 61 @decorator
62 62 def spin_after(f, self, *args, **kwargs):
63 63 """call spin after the method."""
64 64 ret = f(self, *args, **kwargs)
65 65 self.spin()
66 66 return ret
67 67
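# These use the `decorator` package so the wrapped methods keep their
# signatures and docstrings. The pattern, sketched: the decorated callable
# receives the original function as its first argument.
#
#   >>> from decorator import decorator
#   >>> @decorator
#   ... def traced(f, self, *args, **kwargs):
#   ...     print("calling %s" % f.__name__)
#   ...     return f(self, *args, **kwargs)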
68 68 #-----------------------------------------------------------------------------
69 69 # Classes
70 70 #-----------------------------------------------------------------------------
71 71
72 72 @skip_doctest
73 73 class View(HasTraits):
74 74 """Base View class for more convenient apply(f, *args, **kwargs) syntax via attributes.
75 75
76 76 Don't use this class, use subclasses.
77 77
78 78 Methods
79 79 -------
80 80
81 81 spin
82 82 flushes incoming results and registration state changes
83 83 control methods spin, and requesting `ids` also ensures the view is up to date
84 84
85 85 wait
86 86 wait on one or more msg_ids
87 87
88 88 execution methods
89 89 apply
90 90 legacy: execute, run
91 91
92 92 data movement
93 93 push, pull, scatter, gather
94 94
95 95 query methods
96 96 get_result, queue_status, purge_results, result_status
97 97
98 98 control methods
99 99 abort, shutdown
100 100
101 101 """
102 102 # flags
103 103 block=Bool(False)
104 104 track=Bool(True)
105 105 targets = Any()
106 106
107 107 history=List()
108 108 outstanding = Set()
109 109 results = Dict()
110 110 client = Instance('IPython.parallel.Client')
111 111
112 112 _socket = Instance('zmq.Socket')
113 113 _flag_names = List(['targets', 'block', 'track'])
114 114 _in_sync_results = Bool(False)
115 115 _targets = Any()
116 116 _idents = Any()
117 117
118 118 def __init__(self, client=None, socket=None, **flags):
119 119 super(View, self).__init__(client=client, _socket=socket)
120 120 self.results = client.results
121 121 self.block = client.block
122 122
123 123 self.set_flags(**flags)
124 124
125 125 assert not self.__class__ is View, "Don't use base View objects, use subclasses"
126 126
127 127 def __repr__(self):
128 128 strtargets = str(self.targets)
129 129 if len(strtargets) > 16:
130 130 strtargets = strtargets[:12]+'...]'
131 131 return "<%s %s>"%(self.__class__.__name__, strtargets)
132 132
133 133 def __len__(self):
134 134 if isinstance(self.targets, list):
135 135 return len(self.targets)
136 136 elif isinstance(self.targets, int):
137 137 return 1
138 138 else:
139 139 return len(self.client)
140 140
141 141 def set_flags(self, **kwargs):
142 142 """set my attribute flags by keyword.
143 143
144 144 Views determine behavior with a few attributes (`block`, `track`, etc.).
145 145 These attributes can be set all at once by name with this method.
146 146
147 147 Parameters
148 148 ----------
149 149
150 150 block : bool
151 151 whether to wait for results
152 152 track : bool
153 153 whether to create a MessageTracker to allow the user to
154 154 safely edit arrays and buffers after non-copying
155 155 sends.
156 156 """
157 157 for name, value in iteritems(kwargs):
158 158 if name not in self._flag_names:
159 159 raise KeyError("Invalid name: %r"%name)
160 160 else:
161 161 setattr(self, name, value)
162 162
163 163 @contextmanager
164 164 def temp_flags(self, **kwargs):
165 165 """temporarily set flags, for use in `with` statements.
166 166
167 167 See set_flags for permanent setting of flags
168 168
169 169 Examples
170 170 --------
171 171
172 172 >>> view.track=False
173 173 ...
174 174 >>> with view.temp_flags(track=True):
175 175 ... ar = view.apply(dostuff, my_big_array)
176 176 ... ar.tracker.wait() # wait for send to finish
177 177 >>> view.track
178 178 False
179 179
180 180 """
181 181 # preflight: save flags, and set temporaries
182 182 saved_flags = {}
183 183 for f in self._flag_names:
184 184 saved_flags[f] = getattr(self, f)
185 185 self.set_flags(**kwargs)
186 186 # yield to the with-statement block
187 187 try:
188 188 yield
189 189 finally:
190 190 # postflight: restore saved flags
191 191 self.set_flags(**saved_flags)
192 192
193 193
194 194 #----------------------------------------------------------------
195 195 # apply
196 196 #----------------------------------------------------------------
197 197
198 198 def _sync_results(self):
199 199 """to be called by @sync_results decorator
200 200
201 201 after submitting any tasks.
202 202 """
203 203 delta = self.outstanding.difference(self.client.outstanding)
204 204 completed = self.outstanding.intersection(delta)
205 205 self.outstanding = self.outstanding.difference(completed)
206 206
207 207 @sync_results
208 208 @save_ids
209 209 def _really_apply(self, f, args, kwargs, block=None, **options):
210 210 """wrapper for client.send_apply_request"""
211 211 raise NotImplementedError("Implement in subclasses")
212 212
213 213 def apply(self, f, *args, **kwargs):
214 214 """calls ``f(*args, **kwargs)`` on remote engines, returning the result.
215 215
216 216 This method sets all apply flags via this View's attributes.
217 217
218 218 Returns :class:`~IPython.parallel.client.asyncresult.AsyncResult`
219 219 instance if ``self.block`` is False, otherwise the return value of
220 220 ``f(*args, **kwargs)``.
221 221 """
222 222 return self._really_apply(f, args, kwargs)
223 223
224 224 def apply_async(self, f, *args, **kwargs):
225 225 """calls ``f(*args, **kwargs)`` on remote engines in a nonblocking manner.
226 226
227 227 Returns :class:`~IPython.parallel.client.asyncresult.AsyncResult` instance.
228 228 """
229 229 return self._really_apply(f, args, kwargs, block=False)
230 230
231 231 @spin_after
232 232 def apply_sync(self, f, *args, **kwargs):
233 233 """calls ``f(*args, **kwargs)`` on remote engines in a blocking manner,
234 234 returning the result.
235 235 """
236 236 return self._really_apply(f, args, kwargs, block=True)
237 237
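# The three apply variants above, sketched (illustrative; assumes `dv = rc[:]`):
#
#   >>> import os
#   >>> ar = dv.apply_async(os.getpid)   # returns an AsyncResult immediately
#   >>> ar.get(timeout=5)                # block later, with an optional timeout
#   >>> dv.apply_sync(os.getpid)         # blocks until the result arrives
#   >>> dv.apply(os.getpid)              # blocking behavior follows dv.block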
238 238 #----------------------------------------------------------------
239 239 # wrappers for client and control methods
240 240 #----------------------------------------------------------------
241 241 @sync_results
242 242 def spin(self):
243 243 """spin the client, and sync"""
244 244 self.client.spin()
245 245
246 246 @sync_results
247 247 def wait(self, jobs=None, timeout=-1):
248 248 """waits on one or more `jobs`, for up to `timeout` seconds.
249 249
250 250 Parameters
251 251 ----------
252 252
253 253 jobs : int, str, or list of ints and/or strs, or one or more AsyncResult objects
254 254 ints are indices to self.history
255 255 strs are msg_ids
256 256 default: wait on all outstanding messages
257 257 timeout : float
258 258 a time in seconds, after which to give up.
259 259 default is -1, which means no timeout
260 260
261 261 Returns
262 262 -------
263 263
264 264 True : when all msg_ids are done
265 265 False : timeout reached, some msg_ids still outstanding
266 266 """
267 267 if jobs is None:
268 268 jobs = self.history
269 269 return self.client.wait(jobs, timeout)
270 270
271 271 def abort(self, jobs=None, targets=None, block=None):
272 272 """Abort jobs on my engines.
273 273
274 274 Parameters
275 275 ----------
276 276
277 277 jobs : None, str, list of strs, optional
278 278 if None: abort all jobs.
279 279 else: abort specific msg_id(s).
280 280 """
281 281 block = block if block is not None else self.block
282 282 targets = targets if targets is not None else self.targets
283 283 jobs = jobs if jobs is not None else list(self.outstanding)
284 284
285 285 return self.client.abort(jobs=jobs, targets=targets, block=block)
286 286
287 287 def queue_status(self, targets=None, verbose=False):
288 288 """Fetch the Queue status of my engines"""
289 289 targets = targets if targets is not None else self.targets
290 290 return self.client.queue_status(targets=targets, verbose=verbose)
291 291
292 292 def purge_results(self, jobs=[], targets=[]):
293 293 """Instruct the controller to forget specific results."""
294 294 if targets is None or targets == 'all':
295 295 targets = self.targets
296 296 return self.client.purge_results(jobs=jobs, targets=targets)
297 297
298 298 def shutdown(self, targets=None, restart=False, hub=False, block=None):
299 299 """Terminates one or more engine processes, optionally including the hub.
300 300 """
301 301 block = self.block if block is None else block
302 302 if targets is None or targets == 'all':
303 303 targets = self.targets
304 304 return self.client.shutdown(targets=targets, restart=restart, hub=hub, block=block)
305 305
306 306 @spin_after
307 307 def get_result(self, indices_or_msg_ids=None, block=None, owner=True):
308 308 """return one or more results, specified by history index or msg_id.
309 309
310 310 See :meth:`IPython.parallel.client.client.Client.get_result` for details.
311 311 """
312 312
313 313 if indices_or_msg_ids is None:
314 314 indices_or_msg_ids = -1
315 315 if isinstance(indices_or_msg_ids, int):
316 316 indices_or_msg_ids = self.history[indices_or_msg_ids]
317 317 elif isinstance(indices_or_msg_ids, (list,tuple,set)):
318 318 indices_or_msg_ids = list(indices_or_msg_ids)
319 319 for i,index in enumerate(indices_or_msg_ids):
320 320 if isinstance(index, int):
321 321 indices_or_msg_ids[i] = self.history[index]
322 322 return self.client.get_result(indices_or_msg_ids, block=block, owner=owner)
323 323
324 324 #-------------------------------------------------------------------
325 325 # Map
326 326 #-------------------------------------------------------------------
327 327
328 328 @sync_results
329 329 def map(self, f, *sequences, **kwargs):
330 330 """override in subclasses"""
331 331 raise NotImplementedError
332 332
333 333 def map_async(self, f, *sequences, **kwargs):
334 334 """Parallel version of builtin :func:`python:map`, using this view's engines.
335 335
336 336 This is equivalent to ``map(..., block=False)``.
337 337
338 338 See `self.map` for details.
339 339 """
340 340 if 'block' in kwargs:
341 341 raise TypeError("map_async doesn't take a `block` keyword argument.")
342 342 kwargs['block'] = False
343 343 return self.map(f,*sequences,**kwargs)
344 344
345 345 def map_sync(self, f, *sequences, **kwargs):
346 346 """Parallel version of builtin :func:`python:map`, using this view's engines.
347 347
348 348 This is equivalent to ``map(..., block=True)``.
349 349
350 350 See `self.map` for details.
351 351 """
352 352 if 'block' in kwargs:
353 353 raise TypeError("map_sync doesn't take a `block` keyword argument.")
354 354 kwargs['block'] = True
355 355 return self.map(f,*sequences,**kwargs)
356 356
357 357 def imap(self, f, *sequences, **kwargs):
358 358 """Parallel version of :func:`itertools.imap`.
359 359
360 360 See `self.map` for details.
361 361
362 362 """
363 363
364 364 return iter(self.map_async(f,*sequences, **kwargs))
365 365
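# imap sketch (illustrative): results can be consumed as they arrive,
# instead of waiting for the whole map to finish:
#
#   >>> for r in view.imap(lambda x: x * x, range(32)):
#   ...     print(r)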
366 366 #-------------------------------------------------------------------
367 367 # Decorators
368 368 #-------------------------------------------------------------------
369 369
370 370 def remote(self, block=None, **flags):
371 371 """Decorator for making a RemoteFunction"""
372 372 block = self.block if block is None else block
373 373 return remote(self, block=block, **flags)
374 374
375 375 def parallel(self, dist='b', block=None, **flags):
376 376 """Decorator for making a ParallelFunction"""
377 377 block = self.block if block is None else block
378 378 return parallel(self, dist=dist, block=block, **flags)
379 379
380 380 @skip_doctest
381 381 class DirectView(View):
382 382 """Direct Multiplexer View of one or more engines.
383 383
384 384 These are created via indexed access to a client:
385 385
386 386 >>> dv_1 = client[1]
387 387 >>> dv_all = client[:]
388 388 >>> dv_even = client[::2]
389 389 >>> dv_some = client[1:3]
390 390
391 391 This object provides dictionary access to engine namespaces:
392 392
393 393 # push a=5:
394 394 >>> dv['a'] = 5
395 395 # pull 'foo':
396 396 >>> dv['foo']
397 397
398 398 """
399 399
400 400 def __init__(self, client=None, socket=None, targets=None):
401 401 super(DirectView, self).__init__(client=client, socket=socket, targets=targets)
402 402
403 403 @property
404 404 def importer(self):
405 405 """sync_imports(local=True) as a property.
406 406
407 407 See sync_imports for details.
408 408
409 409 """
410 410 return self.sync_imports(True)
411 411
412 412 @contextmanager
413 413 def sync_imports(self, local=True, quiet=False):
414 414 """Context Manager for performing simultaneous local and remote imports.
415 415
416 416 'import x as y' will *not* work. The 'as y' part will simply be ignored.
417 417
418 418 If `local=True`, then the package will also be imported locally.
419 419
420 420 If `quiet=True`, no output will be produced when attempting remote
421 421 imports.
422 422
423 423 Note that remote-only (`local=False`) imports have not been implemented.
424 424
425 425 >>> with view.sync_imports():
426 426 ... from numpy import recarray
427 427 importing recarray from numpy on engine(s)
428 428
429 429 """
430 430 from IPython.utils.py3compat import builtin_mod
431 431 local_import = builtin_mod.__import__
432 432 modules = set()
433 433 results = []
434 434 @util.interactive
435 435 def remote_import(name, fromlist, level):
436 436 """the function to be passed to apply, that actually performs the import
437 437 on the engine, and loads up the user namespace.
438 438 """
439 439 import sys
440 440 user_ns = globals()
441 441 mod = __import__(name, fromlist=fromlist, level=level)
442 442 if fromlist:
443 443 for key in fromlist:
444 444 user_ns[key] = getattr(mod, key)
445 445 else:
446 446 user_ns[name] = sys.modules[name]
447 447
448 448 def view_import(name, globals={}, locals={}, fromlist=[], level=0):
449 449 """the drop-in replacement for __import__, that optionally imports
450 450 locally as well.
451 451 """
452 452 # don't override nested imports
453 453 save_import = builtin_mod.__import__
454 454 builtin_mod.__import__ = local_import
455 455
456 456 if imp.lock_held():
457 457 # this is a side-effect import triggered inside another import;
458 458 # don't do it remotely, and don't track the local effects either
459 459 return local_import(name, globals, locals, fromlist, level)
460 460
461 461 imp.acquire_lock()
462 462 if local:
463 463 mod = local_import(name, globals, locals, fromlist, level)
464 464 else:
465 465 raise NotImplementedError("remote-only imports not yet implemented")
466 466 imp.release_lock()
467 467
468 468 key = name+':'+','.join(fromlist or [])
469 469 if level <= 0 and key not in modules:
470 470 modules.add(key)
471 471 if not quiet:
472 472 if fromlist:
473 473 print("importing %s from %s on engine(s)"%(','.join(fromlist), name))
474 474 else:
475 475 print("importing %s on engine(s)"%name)
476 476 results.append(self.apply_async(remote_import, name, fromlist, level))
477 477 # restore override
478 478 builtin_mod.__import__ = save_import
479 479
480 480 return mod
481 481
482 482 # override __import__
483 483 builtin_mod.__import__ = view_import
484 484 try:
485 485 # enter the block
486 486 yield
487 487 except ImportError:
488 488 if local:
489 489 raise
490 490 else:
491 491 # ignore import errors if not doing local imports
492 492 pass
493 493 finally:
494 494 # always restore __import__
495 495 builtin_mod.__import__ = local_import
496 496
497 497 for r in results:
498 498 # raise possible remote ImportErrors here
499 499 r.get()
500 500
501 501 def use_dill(self):
502 502 """Expand serialization support with dill
503 503
504 504 adds support for closures, etc.
505 505
506 506 This calls IPython.utils.pickleutil.use_dill() here and on each engine.
507 507 """
508 508 pickleutil.use_dill()
509 509 return self.apply(pickleutil.use_dill)
510 510
511 511 def use_cloudpickle(self):
512 512 """Expand serialization support with cloudpickle.
513 513 """
514 514 pickleutil.use_cloudpickle()
515 515 return self.apply(pickleutil.use_cloudpickle)
516 516
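# Serialization sketch: with use_dill(), closures that plain pickle rejects
# can be shipped to the engines (illustrative; assumes the `dill` package
# is installed locally and on the engines):
#
#   >>> def make_adder(n):
#   ...     def add(x):
#   ...         return x + n
#   ...     return add
#   >>> dv.use_dill()                    # enable locally and on every engine
#   >>> dv.apply_sync(make_adder(3), 4)  # closures now serialize
#   [7, 7, 7, 7]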
517 517
518 518 @sync_results
519 519 @save_ids
520 520 def _really_apply(self, f, args=None, kwargs=None, targets=None, block=None, track=None):
521 521 """calls f(*args, **kwargs) on remote engines, returning the result.
522 522
523 523 This method sets all of `apply`'s flags via this View's attributes.
524 524
525 525 Parameters
526 526 ----------
527 527
528 528 f : callable
529 529
530 530 args : list [default: empty]
531 531
532 532 kwargs : dict [default: empty]
533 533
534 534 targets : target list [default: self.targets]
535 535 where to run
536 536 block : bool [default: self.block]
537 537 whether to block
538 538 track : bool [default: self.track]
539 539 whether to ask zmq to track the message, for safe non-copying sends
540 540
541 541 Returns
542 542 -------
543 543
544 544 if self.block is False:
545 545 returns AsyncResult
546 546 else:
547 547 returns actual result of f(*args, **kwargs) on the engine(s)
548 548 This will be a list if self.targets is also a list (even length 1), or
549 549 the single result if self.targets is an integer engine id
550 550 """
551 551 args = [] if args is None else args
552 552 kwargs = {} if kwargs is None else kwargs
553 553 block = self.block if block is None else block
554 554 track = self.track if track is None else track
555 555 targets = self.targets if targets is None else targets
556 556
557 557 _idents, _targets = self.client._build_targets(targets)
558 558 msg_ids = []
559 559 trackers = []
560 560 for ident in _idents:
561 561 msg = self.client.send_apply_request(self._socket, f, args, kwargs, track=track,
562 562 ident=ident)
563 563 if track:
564 564 trackers.append(msg['tracker'])
565 565 msg_ids.append(msg['header']['msg_id'])
566 566 if isinstance(targets, int):
567 567 msg_ids = msg_ids[0]
568 568 tracker = None if track is False else zmq.MessageTracker(*trackers)
569 569 ar = AsyncResult(self.client, msg_ids, fname=getname(f), targets=_targets,
570 570 tracker=tracker, owner=True,
571 571 )
572 572 if block:
573 573 try:
574 574 return ar.get()
575 575 except KeyboardInterrupt:
576 576 pass
577 577 return ar
578 578
579 579
580 580 @sync_results
581 581 def map(self, f, *sequences, **kwargs):
582 582 """``view.map(f, *sequences, block=self.block)`` => list|AsyncMapResult
583 583
584 584 Parallel version of builtin `map`, using this View's `targets`.
585 585
586 586 There will be one task per target, so work will be chunked
587 587 if the sequences are longer than the number of `targets`.
588 588
589 589 Results can be iterated as they are ready, but will become available in chunks.
590 590
591 591 Parameters
592 592 ----------
593 593
594 594 f : callable
595 595 function to be mapped
596 596 *sequences: one or more sequences of matching length
597 597 the sequences to be distributed and passed to `f`
598 598 block : bool
599 599 whether to wait for the result or not [default self.block]
600 600
601 601 Returns
602 602 -------
603 603
604 604
605 605 If block=False
606 606 An :class:`~IPython.parallel.client.asyncresult.AsyncMapResult` instance.
607 607 An object like AsyncResult, but which reassembles the sequence of results
608 608 into a single list. AsyncMapResults can be iterated through before all
609 609 results are complete.
610 610 else
611 611 A list, the result of ``map(f,*sequences)``
612 612 """
613 613
614 614 block = kwargs.pop('block', self.block)
615 615 for k in kwargs.keys():
616 616 if k not in ['block', 'track']:
617 617 raise TypeError("invalid keyword arg, %r"%k)
618 618
619 619 assert len(sequences) > 0, "must have some sequences to map onto!"
620 620 pf = ParallelFunction(self, f, block=block, **kwargs)
621 621 return pf.map(*sequences)
622 622
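# map sketch for a DirectView (illustrative): one contiguous chunk per engine.
#
#   >>> dv = rc[:]                                  # e.g. 4 engines
#   >>> dv.map_sync(lambda x: x ** 2, range(8))     # 2 elements per engine
#   [0, 1, 4, 9, 16, 25, 36, 49]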
623 623 @sync_results
624 624 @save_ids
625 625 def execute(self, code, silent=True, targets=None, block=None):
626 626 """Executes `code` on `targets` in blocking or nonblocking manner.
627 627
628 628 ``execute`` is always `bound` (affects engine namespace)
629 629
630 630 Parameters
631 631 ----------
632 632
633 633 code : str
634 634 the code string to be executed
635 635 block : bool
636 636 whether or not to wait until done to return
637 637 default: self.block
638 638 """
639 639 block = self.block if block is None else block
640 640 targets = self.targets if targets is None else targets
641 641
642 642 _idents, _targets = self.client._build_targets(targets)
643 643 msg_ids = []
644 644 trackers = []
645 645 for ident in _idents:
646 646 msg = self.client.send_execute_request(self._socket, code, silent=silent, ident=ident)
647 647 msg_ids.append(msg['header']['msg_id'])
648 648 if isinstance(targets, int):
649 649 msg_ids = msg_ids[0]
650 650 ar = AsyncResult(self.client, msg_ids, fname='execute', targets=_targets, owner=True)
651 651 if block:
652 652 try:
653 653 ar.get()
654 654 except KeyboardInterrupt:
655 655 pass
656 656 return ar
657 657
658 658 def run(self, filename, targets=None, block=None):
659 659 """Execute contents of `filename` on my engine(s).
660 660
661 661 This simply reads the contents of the file and calls `execute`.
662 662
663 663 Parameters
664 664 ----------
665 665
666 666 filename : str
667 667 The path to the file
668 668 targets : int/str/list of ints/strs
669 669 the engines on which to execute
670 670 default : all
671 671 block : bool
672 672 whether or not to wait until done
673 673 default: self.block
674 674
675 675 """
676 676 with open(filename, 'r') as f:
677 677 # add newline in case of trailing indented whitespace
678 678 # which will cause SyntaxError
679 679 code = f.read()+'\n'
680 680 return self.execute(code, block=block, targets=targets)
681 681
682 682 def update(self, ns):
683 683 """update remote namespace with dict `ns`
684 684
685 685 See `push` for details.
686 686 """
687 687 return self.push(ns, block=self.block, track=self.track)
688 688
689 689 def push(self, ns, targets=None, block=None, track=None):
690 690 """update remote namespace with dict `ns`
691 691
692 692 Parameters
693 693 ----------
694 694
695 695 ns : dict
696 696 dict of keys with which to update engine namespace(s)
697 697 block : bool [default : self.block]
698 698 whether to wait to be notified of engine receipt
699 699
700 700 """
701 701
702 702 block = block if block is not None else self.block
703 703 track = track if track is not None else self.track
704 704 targets = targets if targets is not None else self.targets
705 705 # applier = self.apply_sync if block else self.apply_async
706 706 if not isinstance(ns, dict):
707 707 raise TypeError("Must be a dict, not %s"%type(ns))
708 708 return self._really_apply(util._push, kwargs=ns, block=block, track=track, targets=targets)
709 709
710 710 def get(self, key_s):
711 711 """get object(s) by `key_s` from remote namespace
712 712
713 713 see `pull` for details.
714 714 """
715 715 # block = block if block is not None else self.block
716 716 return self.pull(key_s, block=True)
717 717
718 718 def pull(self, names, targets=None, block=None):
719 719 """get object(s) by `name` from remote namespace
720 720
721 721 will return one object if it is a key.
722 722 can also take a list of keys, in which case it will return a list of objects.
723 723 """
724 724 block = block if block is not None else self.block
725 725 targets = targets if targets is not None else self.targets
726 726 applier = self.apply_sync if block else self.apply_async
727 727 if isinstance(names, string_types):
728 728 pass
729 729 elif isinstance(names, (list,tuple,set)):
730 730 for key in names:
731 731 if not isinstance(key, string_types):
732 732 raise TypeError("keys must be str, not type %r"%type(key))
733 733 else:
734 734 raise TypeError("names must be strs, not %r"%names)
735 735 return self._really_apply(util._pull, (names,), block=block, targets=targets)
736 736
737 737 def scatter(self, key, seq, dist='b', flatten=False, targets=None, block=None, track=None):
738 738 """
739 739 Partition a Python sequence and send the partitions to a set of engines.
740 740 """
741 741 block = block if block is not None else self.block
742 742 track = track if track is not None else self.track
743 743 targets = targets if targets is not None else self.targets
744 744
745 745 # construct integer ID list:
746 746 targets = self.client._build_targets(targets)[1]
747 747
748 748 mapObject = Map.dists[dist]()
749 749 nparts = len(targets)
750 750 msg_ids = []
751 751 trackers = []
752 752 for index, engineid in enumerate(targets):
753 753 partition = mapObject.getPartition(seq, index, nparts)
754 754 if flatten and len(partition) == 1:
755 755 ns = {key: partition[0]}
756 756 else:
757 757 ns = {key: partition}
758 758 r = self.push(ns, block=False, track=track, targets=engineid)
759 759 msg_ids.extend(r.msg_ids)
760 760 if track:
761 761 trackers.append(r._tracker)
762 762
763 763 if track:
764 764 tracker = zmq.MessageTracker(*trackers)
765 765 else:
766 766 tracker = None
767 767
768 768 r = AsyncResult(self.client, msg_ids, fname='scatter', targets=targets,
769 769 tracker=tracker, owner=True,
770 770 )
771 771 if block:
772 772 r.wait()
773 773 else:
774 774 return r
775 775
776 776 @sync_results
777 777 @save_ids
778 778 def gather(self, key, dist='b', targets=None, block=None):
779 779 """
780 780 Gather a partitioned sequence on a set of engines as a single local seq.
781 781 """
782 782 block = block if block is not None else self.block
783 783 targets = targets if targets is not None else self.targets
784 784 mapObject = Map.dists[dist]()
785 785 msg_ids = []
786 786
787 787 # construct integer ID list:
788 788 targets = self.client._build_targets(targets)[1]
789 789
790 790 for index, engineid in enumerate(targets):
791 791 msg_ids.extend(self.pull(key, block=False, targets=engineid).msg_ids)
792 792
793 793 r = AsyncMapResult(self.client, msg_ids, mapObject, fname='gather')
794 794
795 795 if block:
796 796 try:
797 797 return r.get()
798 798 except KeyboardInterrupt:
799 799 pass
800 800 return r
801 801
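# scatter/gather round-trip sketch (illustrative):
#
#   >>> dv.scatter('chunk', list(range(16)))        # contiguous slice per engine
#   >>> dv.execute('part = [x * 2 for x in chunk]')
#   >>> dv.gather('part', block=True)               # reassembled in engine order
#   [0, 2, 4, ..., 30]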
802 802 def __getitem__(self, key):
803 803 return self.get(key)
804 804
805 805 def __setitem__(self,key, value):
806 806 self.update({key:value})
807 807
808 808 def clear(self, targets=None, block=None):
809 809 """Clear the remote namespaces on my engines."""
810 810 block = block if block is not None else self.block
811 811 targets = targets if targets is not None else self.targets
812 812 return self.client.clear(targets=targets, block=block)
813 813
814 814 #----------------------------------------
815 815 # activate for %px, %autopx, etc. magics
816 816 #----------------------------------------
817 817
818 818 def activate(self, suffix=''):
819 819 """Activate IPython magics associated with this View
820 820
821 821 Defines the magics `%px, %autopx, %pxresult, %%px, %pxconfig`
822 822
823 823 Parameters
824 824 ----------
825 825
826 826 suffix: str [default: '']
827 827 The suffix, if any, for the magics. This allows you to have
828 828 multiple views associated with parallel magics at the same time.
829 829
830 830 e.g. ``rc[::2].activate(suffix='_even')`` will give you
831 831 the magics ``%px_even``, ``%pxresult_even``, etc. for running magics
832 832 on the even engines.
833 833 """
834 834
835 835 from IPython.parallel.client.magics import ParallelMagics
836 836
837 837 try:
838 838 # This is injected into __builtins__.
839 839 ip = get_ipython()
840 840 except NameError:
841 841 print("The IPython parallel magics (%px, etc.) only work within IPython.")
842 842 return
843 843
844 844 M = ParallelMagics(ip, self, suffix)
845 845 ip.magics_manager.register(M)
846 846
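# Magic sketch: after activate(), the parallel magics target this view
# (only meaningful inside an IPython session):
#
#   >>> rc[::2].activate(suffix='_even')
#   >>> %px_even import os               # runs on the even engines only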
847 847
848 848 @skip_doctest
849 849 class LoadBalancedView(View):
850 850 """A load-balancing View that only executes via the Task scheduler.
851 851
852 852 Load-balanced views can be created with the client's `view` method:
853 853
854 854 >>> v = client.load_balanced_view()
855 855
856 856 or targets can be specified, to restrict the potential destinations:
857 857
858 858 >>> v = client.load_balanced_view([1,3])
859 859
860 860 which would restrict load-balancing to engines 1 and 3.
861 861
862 862 """
863 863
864 864 follow=Any()
865 865 after=Any()
866 866 timeout=CFloat()
867 867 retries = Integer(0)
868 868
869 869 _task_scheme = Any()
870 870 _flag_names = List(['targets', 'block', 'track', 'follow', 'after', 'timeout', 'retries'])
871 871
872 872 def __init__(self, client=None, socket=None, **flags):
873 873 super(LoadBalancedView, self).__init__(client=client, socket=socket, **flags)
874 874 self._task_scheme=client._task_scheme
875 875
876 876 def _validate_dependency(self, dep):
877 877 """validate a dependency.
878 878
879 879 For use in `set_flags`.
880 880 """
881 881 if dep is None or isinstance(dep, string_types + (AsyncResult, Dependency)):
882 882 return True
883 883 elif isinstance(dep, (list,set, tuple)):
884 884 for d in dep:
885 885 if not isinstance(d, string_types + (AsyncResult,)):
886 886 return False
887 887 elif isinstance(dep, dict):
888 888 if set(dep.keys()) != set(Dependency().as_dict().keys()):
889 889 return False
890 890 if not isinstance(dep['msg_ids'], list):
891 891 return False
892 892 for d in dep['msg_ids']:
893 893 if not isinstance(d, string_types):
894 894 return False
895 895 else:
896 896 return False
897 897
898 898 return True
899 899
900 900 def _render_dependency(self, dep):
901 901 """helper for building jsonable dependencies from various input forms."""
902 902 if isinstance(dep, Dependency):
903 903 return dep.as_dict()
904 904 elif isinstance(dep, AsyncResult):
905 905 return dep.msg_ids
906 906 elif dep is None:
907 907 return []
908 908 else:
909 909 # pass to Dependency constructor
910 910 return list(Dependency(dep))
911 911
912 912 def set_flags(self, **kwargs):
913 913 """set my attribute flags by keyword.
914 914
915 915 A View is a wrapper for the Client's apply method, with attributes
916 916 that specify keyword arguments. Those attributes can be set by keyword
917 917 argument with this method.
918 918
919 919 Parameters
920 920 ----------
921 921
922 922 block : bool
923 923 whether to wait for results
924 924 track : bool
925 925 whether to create a MessageTracker to allow the user to
926 926 safely edit arrays and buffers after non-copying
927 927 sends.
928 928
929 929 after : Dependency or collection of msg_ids
930 930 Only for load-balanced execution (targets=None)
931 931 Specify a list of msg_ids as a time-based dependency.
932 932 This job will only be run *after* the dependencies
933 933 have been met.
934 934
935 935 follow : Dependency or collection of msg_ids
936 936 Only for load-balanced execution (targets=None)
937 937 Specify a list of msg_ids as a location-based dependency.
938 938 This job will only be run on an engine where this dependency
939 939 is met.
940 940
941 941 timeout : float/int or None
942 942 Only for load-balanced execution (targets=None)
943 943 Specify an amount of time (in seconds) for the scheduler to
944 944 wait for dependencies to be met before failing with a
945 945 DependencyTimeout.
946 946
947 947 retries : int
948 948 Number of times a task will be retried on failure.
949 949 """
950 950
951 951 super(LoadBalancedView, self).set_flags(**kwargs)
952 952 for name in ('follow', 'after'):
953 953 if name in kwargs:
954 954 value = kwargs[name]
955 955 if self._validate_dependency(value):
956 956 setattr(self, name, value)
957 957 else:
958 958 raise ValueError("Invalid dependency: %r"%value)
959 959 if 'timeout' in kwargs:
960 960 t = kwargs['timeout']
961 961 if not isinstance(t, (int, float, type(None))):
962 962 if PY3 or not isinstance(t, long):  # py2 also accepts long
963 963 raise TypeError("Invalid type for timeout: %r"%type(t))
964 964 if t is not None:
965 965 if t < 0:
966 966 raise ValueError("Invalid timeout: %s"%t)
967 967 self.timeout = t
968 968
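# Dependency-flag sketch (illustrative; `setup` and `consume` are
# hypothetical functions defined elsewhere):
#
#   >>> lv = rc.load_balanced_view()
#   >>> ar = lv.apply_async(setup)
#   >>> with lv.temp_flags(after=ar, retries=2, timeout=60):
#   ...     dep = lv.apply_async(consume)   # runs only after `ar` completes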
969 969 @sync_results
970 970 @save_ids
971 971 def _really_apply(self, f, args=None, kwargs=None, block=None, track=None,
972 972 after=None, follow=None, timeout=None,
973 973 targets=None, retries=None):
974 974 """calls f(*args, **kwargs) on a remote engine, returning the result.
975 975
976 976 This method temporarily sets all of `apply`'s flags for a single call.
977 977
978 978 Parameters
979 979 ----------
980 980
981 981 f : callable
982 982
983 983 args : list [default: empty]
984 984
985 985 kwargs : dict [default: empty]
986 986
987 987 block : bool [default: self.block]
988 988 whether to block
989 989 track : bool [default: self.track]
990 990 whether to ask zmq to track the message, for safe non-copying sends
991 991
992 992 !!!!!! TODO: THE REST HERE !!!!
993 993
994 994 Returns
995 995 -------
996 996
997 997 if self.block is False:
998 998 returns AsyncResult
999 999 else:
1000 1000 returns actual result of f(*args, **kwargs) on the engine(s)
1001 1001 This will be a list if self.targets is also a list (even length 1), or
1002 1002 the single result if self.targets is an integer engine id
1003 1003 """
1004 1004
1005 1005 # validate whether we can run
1006 1006 if self._socket.closed:
1007 1007 msg = "Task farming is disabled"
1008 1008 if self._task_scheme == 'pure':
1009 1009 msg += " because the pure ZMQ scheduler cannot handle"
1010 1010 msg += " disappearing engines."
1011 1011 raise RuntimeError(msg)
1012 1012
1013 1013 if self._task_scheme == 'pure':
1014 1014 # pure zmq scheme doesn't support extra features
1015 1015 msg = ("Pure ZMQ scheduler doesn't support the following flags: "
1016 1016 "follow, after, retries, targets, timeout")
1017 1017 if (follow or after or retries or targets or timeout):
1018 1018 # hard fail on Scheduler flags
1019 1019 raise RuntimeError(msg)
1020 1020 if isinstance(f, dependent):
1021 1021 # soft warn on functional dependencies
1022 1022 warnings.warn(msg, RuntimeWarning)
1023 1023
1024 1024 # build args
1025 1025 args = [] if args is None else args
1026 1026 kwargs = {} if kwargs is None else kwargs
1027 1027 block = self.block if block is None else block
1028 1028 track = self.track if track is None else track
1029 1029 after = self.after if after is None else after
1030 1030 retries = self.retries if retries is None else retries
1031 1031 follow = self.follow if follow is None else follow
1032 1032 timeout = self.timeout if timeout is None else timeout
1033 1033 targets = self.targets if targets is None else targets
1034 1034
1035 1035 if not isinstance(retries, int):
1036 1036 raise TypeError('retries must be int, not %r'%type(retries))
1037 1037
1038 1038 if targets is None:
1039 1039 idents = []
1040 1040 else:
1041 1041 idents = self.client._build_targets(targets)[0]
1042 1042 # ensure *not* bytes
1043 1043 idents = [ ident.decode() for ident in idents ]
1044 1044
1045 1045 after = self._render_dependency(after)
1046 1046 follow = self._render_dependency(follow)
1047 1047 metadata = dict(after=after, follow=follow, timeout=timeout, targets=idents, retries=retries)
1048 1048
1049 1049 msg = self.client.send_apply_request(self._socket, f, args, kwargs, track=track,
1050 1050 metadata=metadata)
1051 1051 tracker = None if track is False else msg['tracker']
1052 1052
1053 1053 ar = AsyncResult(self.client, msg['header']['msg_id'], fname=getname(f),
1054 1054 targets=None, tracker=tracker, owner=True,
1055 1055 )
1056 1056 if block:
1057 1057 try:
1058 1058 return ar.get()
1059 1059 except KeyboardInterrupt:
1060 1060 pass
1061 1061 return ar
1062 1062
1063 1063 @sync_results
1064 1064 @save_ids
1065 1065 def map(self, f, *sequences, **kwargs):
1066 1066 """``view.map(f, *sequences, block=self.block, chunksize=1, ordered=True)`` => list|AsyncMapResult
1067 1067
1068 1068 Parallel version of builtin `map`, load-balanced by this View.
1069 1069
1070 1070 `block`, `chunksize`, and `ordered` can be specified by keyword only.
1071 1071
1072 1072 Each `chunksize` elements will be a separate task, and will be
1073 1073 load-balanced. This lets individual elements be available for iteration
1074 1074 as soon as they arrive.
1075 1075
1076 1076 Parameters
1077 1077 ----------
1078 1078
1079 1079 f : callable
1080 1080 function to be mapped
1081 1081 *sequences: one or more sequences of matching length
1082 1082 the sequences to be distributed and passed to `f`
1083 1083 block : bool [default self.block]
1084 1084 whether to wait for the result or not
1085 1085 track : bool
1086 1086 whether to create a MessageTracker to allow the user to
1087 1087 safely edit arrays and buffers after non-copying
1088 1088 sends.
1089 1089 chunksize : int [default 1]
1090 1090 how many elements should be in each task.
1091 1091 ordered : bool [default True]
1092 1092 Whether the results should be gathered as they arrive, or enforce
1093 1093 the order of submission.
1094 1094
1095 1095 Only applies when iterating through AsyncMapResult as results arrive.
1096 1096 Has no effect when block=True.
1097 1097
1098 1098 Returns
1099 1099 -------
1100 1100
1101 1101 if block=False
1102 1102 An :class:`~IPython.parallel.client.asyncresult.AsyncMapResult` instance.
1103 1103 An object like AsyncResult, but which reassembles the sequence of results
1104 1104 into a single list. AsyncMapResults can be iterated through before all
1105 1105 results are complete.
1106 1106 else
1107 1107 A list, the result of ``map(f,*sequences)``
1108 1108 """
1109 1109
1110 1110 # default
1111 1111 block = kwargs.get('block', self.block)
1112 1112 chunksize = kwargs.get('chunksize', 1)
1113 1113 ordered = kwargs.get('ordered', True)
1114 1114
1115 1115 keyset = set(kwargs.keys())
1116 1116 extra_keys = keyset.difference(set(['block', 'chunksize', 'ordered']))
1117 1117 if extra_keys:
1118 1118 raise TypeError("Invalid kwargs: %s"%list(extra_keys))
1119 1119
1120 1120 assert len(sequences) > 0, "must have some sequences to map onto!"
1121 1121
1122 1122 pf = ParallelFunction(self, f, block=block, chunksize=chunksize, ordered=ordered)
1123 1123 return pf.map(*sequences)
1124 1124
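# Load-balanced map sketch (illustrative; `slow` is a hypothetical function):
#
#   >>> amr = lv.map(slow, range(100), chunksize=10, ordered=False, block=False)
#   >>> for result in amr:       # yielded as tasks finish, not in submit order
#   ...     print(result)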
1125 1125 __all__ = ['LoadBalancedView', 'DirectView']
@@ -1,849 +1,849 b''
1 1 """The Python scheduler for rich scheduling.
2 2
3 3 The Pure ZMQ scheduler does not allow routing schemes other than LRU,
4 4 nor does it check msg_id DAG dependencies. For those, a slightly slower
5 5 Python Scheduler exists.
6 6 """
7 7
8 8 # Copyright (c) IPython Development Team.
9 9 # Distributed under the terms of the Modified BSD License.
10 10
11 11 import logging
12 12 import sys
13 13 import time
14 14
15 15 from collections import deque
16 16 from datetime import datetime
17 17 from random import randint, random
18 18 from types import FunctionType
19 19
20 20 try:
21 21 import numpy
22 22 except ImportError:
23 23 numpy = None
24 24
25 25 import zmq
26 26 from zmq.eventloop import ioloop, zmqstream
27 27
28 28 # local imports
29 from IPython.external.decorator import decorator
29 from decorator import decorator
30 30 from IPython.config.application import Application
31 31 from IPython.config.loader import Config
32 32 from IPython.utils.traitlets import Instance, Dict, List, Set, Integer, Enum, CBytes
33 33 from IPython.utils.py3compat import cast_bytes
34 34
35 35 from IPython.parallel import error, util
36 36 from IPython.parallel.factory import SessionFactory
37 37 from IPython.parallel.util import connect_logger, local_logger
38 38
39 39 from .dependency import Dependency
40 40
41 41 @decorator
42 42 def logged(f,self,*args,**kwargs):
43 43 # print ("#--------------------")
44 44 self.log.debug("scheduler::%s(*%s,**%s)", f.__name__, args, kwargs)
45 45 # print ("#--")
46 46 return f(self,*args, **kwargs)
47 47
48 48 #----------------------------------------------------------------------
49 49 # Chooser functions
50 50 #----------------------------------------------------------------------
51 51
52 52 def plainrandom(loads):
53 53 """Plain random pick."""
54 54 n = len(loads)
55 55 return randint(0,n-1)
56 56
57 57 def lru(loads):
58 58 """Always pick the front of the line.
59 59
60 60 The content of `loads` is ignored.
61 61
62 62 Assumes LRU ordering of loads, with oldest first.
63 63 """
64 64 return 0
65 65
66 66 def twobin(loads):
67 67 """Pick two at random, use the LRU of the two.
68 68
69 69 The content of loads is ignored.
70 70
71 71 Assumes LRU ordering of loads, with oldest first.
72 72 """
73 73 n = len(loads)
74 74 a = randint(0,n-1)
75 75 b = randint(0,n-1)
76 76 return min(a,b)
77 77
78 78 def weighted(loads):
79 79 """Pick two at random using inverse load as weight.
80 80
81 81 Return the less loaded of the two.
82 82 """
83 83 # weight 0 a million times more than 1:
84 84 weights = 1./(1e-6+numpy.array(loads))
85 85 sums = weights.cumsum()
86 86 t = sums[-1]
87 87 x = random()*t
88 88 y = random()*t
89 89 idx = 0
90 90 idy = 0
91 91 while sums[idx] < x:
92 92 idx += 1
93 93 while sums[idy] < y:
94 94 idy += 1
95 95 if weights[idy] > weights[idx]:
96 96 return idy
97 97 else:
98 98 return idx
99 99
100 100 def leastload(loads):
101 101 """Always choose the lowest load.
102 102
103 103 If the lowest load occurs more than once, the first
104 104 occurrence will be used. If loads has LRU ordering, this means
105 105 the LRU of those with the lowest load is chosen.
106 106 """
107 107 return loads.index(min(loads))
108 108
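# The chooser functions each map an engine-load list to an engine index, e.g.:
#
#   >>> loads = [3, 0, 1, 2]
#   >>> leastload(loads)     # index of the minimum load
#   1
#   >>> lru(loads)           # always the head of the LRU-ordered list
#   0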
109 109 #---------------------------------------------------------------------
110 110 # Classes
111 111 #---------------------------------------------------------------------
112 112
113 113
114 114 # store empty default dependency:
115 115 MET = Dependency([])
116 116
117 117
118 118 class Job(object):
119 119 """Simple container for a job"""
120 120 def __init__(self, msg_id, raw_msg, idents, msg, header, metadata,
121 121 targets, after, follow, timeout):
122 122 self.msg_id = msg_id
123 123 self.raw_msg = raw_msg
124 124 self.idents = idents
125 125 self.msg = msg
126 126 self.header = header
127 127 self.metadata = metadata
128 128 self.targets = targets
129 129 self.after = after
130 130 self.follow = follow
131 131 self.timeout = timeout
132 132
133 133 self.removed = False # used for lazy-delete from sorted queue
134 134 self.timestamp = time.time()
135 135 self.timeout_id = 0
136 136 self.blacklist = set()
137 137
138 138 def __lt__(self, other):
139 139 return self.timestamp < other.timestamp
140 140
141 141 def __cmp__(self, other):
142 142 return cmp(self.timestamp, other.timestamp)  # Python 2 only; Python 3 uses __lt__
143 143
144 144 @property
145 145 def dependents(self):
146 146 return self.follow.union(self.after)
147 147
148 148
149 149 class TaskScheduler(SessionFactory):
150 150 """Python TaskScheduler object.
151 151
152 152 This is the simplest object that supports msg_id based
153 153 DAG dependencies. *Only* task msg_ids are checked, not
154 154 msg_ids of jobs submitted via the MUX queue.
155 155
156 156 """
157 157
158 158 hwm = Integer(1, config=True,
159 159 help="""specify the High Water Mark (HWM) for the downstream
160 160 socket in the Task scheduler. This is the maximum number
161 161 of allowed outstanding tasks on each engine.
162 162
163 163 The default (1) means that only one task can be outstanding on each
164 164 engine. Setting TaskScheduler.hwm=0 means there is no limit, and the
165 165 engines continue to be assigned tasks while they are working,
166 166 effectively hiding network latency behind computation, but can result
167 167 in an imbalance of work when submitting many heterogeneous tasks all at
168 168 once. Any positive value greater than one is a compromise between the
169 169 two.
170 170
171 171 """
172 172 )
173 173 scheme_name = Enum(('leastload', 'pure', 'lru', 'plainrandom', 'weighted', 'twobin'),
174 174 'leastload', config=True,
175 175 help="""select the task scheduler scheme [default: leastload]
176 176 Options are: 'pure', 'lru', 'plainrandom', 'weighted', 'twobin','leastload'"""
177 177 )
178 178 def _scheme_name_changed(self, old, new):
179 179 self.log.debug("Using scheme %r"%new)
180 180 self.scheme = globals()[new]
181 181
182 182 # input arguments:
183 183 scheme = Instance(FunctionType) # function for determining the destination
184 184 def _scheme_default(self):
185 185 return leastload
186 186 client_stream = Instance(zmqstream.ZMQStream) # client-facing stream
187 187 engine_stream = Instance(zmqstream.ZMQStream) # engine-facing stream
188 188 notifier_stream = Instance(zmqstream.ZMQStream) # hub-facing sub stream
189 189 mon_stream = Instance(zmqstream.ZMQStream) # hub-facing pub stream
190 190 query_stream = Instance(zmqstream.ZMQStream) # hub-facing DEALER stream
191 191
192 192 # internals:
193 193 queue = Instance(deque) # sorted list of Jobs
194 194 def _queue_default(self):
195 195 return deque()
196 196 queue_map = Dict() # dict by msg_id of Jobs (for O(1) access to the Queue)
197 197 graph = Dict() # dict by msg_id of [ msg_ids that depend on key ]
198 198 retries = Dict() # dict by msg_id of retries remaining (non-neg ints)
199 199 # waiting = List() # list of msg_ids ready to run, but haven't due to HWM
200 200 pending = Dict() # dict by engine_uuid of submitted tasks
201 201 completed = Dict() # dict by engine_uuid of completed tasks
202 202 failed = Dict() # dict by engine_uuid of failed tasks
203 203 destinations = Dict() # dict by msg_id of engine_uuids where jobs ran (reverse of completed+failed)
204 204 clients = Dict() # dict by msg_id for who submitted the task
205 205 targets = List() # list of target IDENTs
206 206 loads = List() # list of engine loads
207 207 # full = Set() # set of IDENTs that have HWM outstanding tasks
208 208 all_completed = Set() # set of all completed tasks
209 209 all_failed = Set() # set of all failed tasks
210 210 all_done = Set() # set of all finished tasks=union(completed,failed)
211 211 all_ids = Set() # set of all submitted task IDs
212 212
213 213 ident = CBytes() # ZMQ identity. This should just be self.session.session
214 214 # but ensure Bytes
215 215 def _ident_default(self):
216 216 return self.session.bsession
217 217
218 218 def start(self):
219 219 self.query_stream.on_recv(self.dispatch_query_reply)
220 220 self.session.send(self.query_stream, "connection_request", {})
221 221
222 222 self.engine_stream.on_recv(self.dispatch_result, copy=False)
223 223 self.client_stream.on_recv(self.dispatch_submission, copy=False)
224 224
225 225 self._notification_handlers = dict(
226 226 registration_notification = self._register_engine,
227 227 unregistration_notification = self._unregister_engine
228 228 )
229 229 self.notifier_stream.on_recv(self.dispatch_notification)
230 230 self.log.info("Scheduler started [%s]" % self.scheme_name)
231 231
232 232 def resume_receiving(self):
233 233 """Resume accepting jobs."""
234 234 self.client_stream.on_recv(self.dispatch_submission, copy=False)
235 235
236 236 def stop_receiving(self):
237 237 """Stop accepting jobs while there are no engines.
238 238 Leave them in the ZMQ queue."""
239 239 self.client_stream.on_recv(None)
240 240
241 241 #-----------------------------------------------------------------------
242 242 # [Un]Registration Handling
243 243 #-----------------------------------------------------------------------
244 244
245 245
246 246 def dispatch_query_reply(self, msg):
247 247 """handle reply to our initial connection request"""
248 248 try:
249 249 idents,msg = self.session.feed_identities(msg)
250 250 except ValueError:
251 251 self.log.warn("task::Invalid Message: %r",msg)
252 252 return
253 253 try:
254 254 msg = self.session.deserialize(msg)
255 255 except ValueError:
256 256 self.log.warn("task::Unauthorized message from: %r"%idents)
257 257 return
258 258
259 259 content = msg['content']
260 260 for uuid in content.get('engines', {}).values():
261 261 self._register_engine(cast_bytes(uuid))
262 262
263 263
264 264 @util.log_errors
265 265 def dispatch_notification(self, msg):
266 266 """dispatch register/unregister events."""
267 267 try:
268 268 idents,msg = self.session.feed_identities(msg)
269 269 except ValueError:
270 270 self.log.warn("task::Invalid Message: %r",msg)
271 271 return
272 272 try:
273 273 msg = self.session.deserialize(msg)
274 274 except ValueError:
275 275 self.log.warn("task::Unauthorized message from: %r"%idents)
276 276 return
277 277
278 278 msg_type = msg['header']['msg_type']
279 279
280 280 handler = self._notification_handlers.get(msg_type, None)
281 281 if handler is None:
282 282 self.log.error("Unhandled message type: %r"%msg_type)
283 283 else:
284 284 try:
285 285 handler(cast_bytes(msg['content']['uuid']))
286 286 except Exception:
287 287 self.log.error("task::Invalid notification msg: %r", msg, exc_info=True)
288 288
289 289 def _register_engine(self, uid):
290 290 """New engine with ident `uid` became available."""
291 291 # head of the line:
292 292 self.targets.insert(0,uid)
293 293 self.loads.insert(0,0)
294 294
295 295 # initialize sets
296 296 self.completed[uid] = set()
297 297 self.failed[uid] = set()
298 298 self.pending[uid] = {}
299 299
300 300 # rescan the graph:
301 301 self.update_graph(None)
302 302
303 303 def _unregister_engine(self, uid):
304 304 """Existing engine with ident `uid` became unavailable."""
305 305 if len(self.targets) == 1:
306 306 # this was our only engine
307 307 pass
308 308
309 309 # handle any potentially finished tasks:
310 310 self.engine_stream.flush()
311 311
312 312 # don't pop destinations, because they might be used later
313 313 # map(self.destinations.pop, self.completed.pop(uid))
314 314 # map(self.destinations.pop, self.failed.pop(uid))
315 315
316 316 # prevent this engine from receiving work
317 317 idx = self.targets.index(uid)
318 318 self.targets.pop(idx)
319 319 self.loads.pop(idx)
320 320
321 321 # wait 5 seconds before cleaning up pending jobs, since the results might
322 322 # still be incoming
323 323 if self.pending[uid]:
324 324 self.loop.add_timeout(self.loop.time() + 5,
325 325 lambda : self.handle_stranded_tasks(uid),
326 326 )
327 327 else:
328 328 self.completed.pop(uid)
329 329 self.failed.pop(uid)
330 330
331 331
332 332 def handle_stranded_tasks(self, engine):
333 333 """Deal with jobs resident in an engine that died."""
334 334 lost = self.pending[engine]
335 335 for msg_id in lost.keys():
336 336 if msg_id not in self.pending[engine]:
337 337 # prevent double-handling of messages
338 338 continue
339 339
340 340 raw_msg = lost[msg_id].raw_msg
341 341 idents,msg = self.session.feed_identities(raw_msg, copy=False)
342 342 parent = self.session.unpack(msg[1].bytes)
343 343 idents = [engine, idents[0]]
344 344
345 345 # build fake error reply
346 346 try:
347 347 raise error.EngineError("Engine %r died while running task %r"%(engine, msg_id))
348 348 except:
349 349 content = error.wrap_exception()
350 350 # build fake metadata
351 351 md = dict(
352 352 status=u'error',
353 353 engine=engine.decode('ascii'),
354 354 date=datetime.now(),
355 355 )
356 356 msg = self.session.msg('apply_reply', content, parent=parent, metadata=md)
357 357 raw_reply = list(map(zmq.Message, self.session.serialize(msg, ident=idents)))
358 358 # and dispatch it
359 359 self.dispatch_result(raw_reply)
360 360
361 361 # finally scrub completed/failed lists
362 362 self.completed.pop(engine)
363 363 self.failed.pop(engine)
364 364
365 365
366 366 #-----------------------------------------------------------------------
367 367 # Job Submission
368 368 #-----------------------------------------------------------------------
369 369
370 370
371 371 @util.log_errors
372 372 def dispatch_submission(self, raw_msg):
373 373 """Dispatch job submission to appropriate handlers."""
374 374 # ensure targets up to date:
375 375 self.notifier_stream.flush()
376 376 try:
377 377 idents, msg = self.session.feed_identities(raw_msg, copy=False)
378 378 msg = self.session.deserialize(msg, content=False, copy=False)
379 379 except Exception:
380 380 self.log.error("task::Invalid task msg: %r"%raw_msg, exc_info=True)
381 381 return
382 382
383 383
384 384 # send to monitor
385 385 self.mon_stream.send_multipart([b'intask']+raw_msg, copy=False)
386 386
387 387 header = msg['header']
388 388 md = msg['metadata']
389 389 msg_id = header['msg_id']
390 390 self.all_ids.add(msg_id)
391 391
392 392 # get targets as a set of bytes objects
393 393 # from a list of unicode objects
394 394 targets = md.get('targets', [])
395 395 targets = set(map(cast_bytes, targets))
396 396
397 397 retries = md.get('retries', 0)
398 398 self.retries[msg_id] = retries
399 399
400 400 # time dependencies
401 401 after = md.get('after', None)
402 402 if after:
403 403 after = Dependency(after)
404 404 if after.all:
405 405 if after.success:
406 406 after = Dependency(after.difference(self.all_completed),
407 407 success=after.success,
408 408 failure=after.failure,
409 409 all=after.all,
410 410 )
411 411 if after.failure:
412 412 after = Dependency(after.difference(self.all_failed),
413 413 success=after.success,
414 414 failure=after.failure,
415 415 all=after.all,
416 416 )
417 417 if after.check(self.all_completed, self.all_failed):
418 418 # recast as empty set, if `after` already met,
419 419 # to prevent unnecessary set comparisons
420 420 after = MET
421 421 else:
422 422 after = MET
423 423
424 424 # location dependencies
425 425 follow = Dependency(md.get('follow', []))
426 426
427 427 timeout = md.get('timeout', None)
428 428 if timeout:
429 429 timeout = float(timeout)
430 430
431 431 job = Job(msg_id=msg_id, raw_msg=raw_msg, idents=idents, msg=msg,
432 432 header=header, targets=targets, after=after, follow=follow,
433 433 timeout=timeout, metadata=md,
434 434 )
435 435 # validate and reduce dependencies:
436 436 for dep in after,follow:
437 437 if not dep: # empty dependency
438 438 continue
439 439 # check valid:
440 440 if msg_id in dep or dep.difference(self.all_ids):
441 441 self.queue_map[msg_id] = job
442 442 return self.fail_unreachable(msg_id, error.InvalidDependency)
443 443 # check if unreachable:
444 444 if dep.unreachable(self.all_completed, self.all_failed):
445 445 self.queue_map[msg_id] = job
446 446 return self.fail_unreachable(msg_id)
447 447
448 448 if after.check(self.all_completed, self.all_failed):
449 449 # time deps already met, try to run
450 450 if not self.maybe_run(job):
451 451 # can't run yet
452 452 if msg_id not in self.all_failed:
453 453 # could have failed as unreachable
454 454 self.save_unmet(job)
455 455 else:
456 456 self.save_unmet(job)
457 457
458 458 def job_timeout(self, job, timeout_id):
459 459 """callback for a job's timeout.
460 460
461 461 The job may or may not have been run at this point.
462 462 """
463 463 if job.timeout_id != timeout_id:
464 464 # not the most recent call
465 465 return
466 466 now = time.time()
467 467 if job.timeout >= (now + 1):
468 468 self.log.warn("task %s timeout fired prematurely: %s > %s",
469 469 job.msg_id, job.timeout, now
470 470 )
471 471 if job.msg_id in self.queue_map:
472 472 # still waiting, but ran out of time
473 473 self.log.info("task %r timed out", job.msg_id)
474 474 self.fail_unreachable(job.msg_id, error.TaskTimeout)
475 475
476 476 def fail_unreachable(self, msg_id, why=error.ImpossibleDependency):
477 477 """a task has become unreachable, send a reply with an ImpossibleDependency
478 478 error."""
479 479 if msg_id not in self.queue_map:
480 480 self.log.error("task %r already failed!", msg_id)
481 481 return
482 482 job = self.queue_map.pop(msg_id)
483 483 # lazy-delete from the queue
484 484 job.removed = True
485 485 for mid in job.dependents:
486 486 if mid in self.graph:
487 487 self.graph[mid].remove(msg_id)
488 488
489 489 try:
490 490 raise why()
491 491 except:
492 492 content = error.wrap_exception()
493 493 self.log.debug("task %r failing as unreachable with: %s", msg_id, content['ename'])
494 494
495 495 self.all_done.add(msg_id)
496 496 self.all_failed.add(msg_id)
497 497
498 498 msg = self.session.send(self.client_stream, 'apply_reply', content,
499 499 parent=job.header, ident=job.idents)
500 500 self.session.send(self.mon_stream, msg, ident=[b'outtask']+job.idents)
501 501
502 502 self.update_graph(msg_id, success=False)
503 503
504 504 def available_engines(self):
505 505 """return a list of available engine indices based on HWM"""
506 506 if not self.hwm:
507 507 return list(range(len(self.targets)))
508 508 available = []
509 509 for idx in range(len(self.targets)):
510 510 if self.loads[idx] < self.hwm:
511 511 available.append(idx)
512 512 return available
513 513
514 514 def maybe_run(self, job):
515 515 """check location dependencies, and run if they are met."""
516 516 msg_id = job.msg_id
517 517 self.log.debug("Attempting to assign task %s", msg_id)
518 518 available = self.available_engines()
519 519 if not available:
520 520 # no engines, definitely can't run
521 521 return False
522 522
523 523 if job.follow or job.targets or job.blacklist or self.hwm:
524 524 # we need a can_run filter
525 525 def can_run(idx):
526 526 # check hwm
527 527 if self.hwm and self.loads[idx] == self.hwm:
528 528 return False
529 529 target = self.targets[idx]
530 530 # check blacklist
531 531 if target in job.blacklist:
532 532 return False
533 533 # check targets
534 534 if job.targets and target not in job.targets:
535 535 return False
536 536 # check follow
537 537 return job.follow.check(self.completed[target], self.failed[target])
538 538
539 539 indices = list(filter(can_run, available))
540 540
541 541 if not indices:
542 542 # couldn't run
543 543 if job.follow.all:
544 544 # check follow for impossibility
545 545 dests = set()
546 546 relevant = set()
547 547 if job.follow.success:
548 548 relevant = self.all_completed
549 549 if job.follow.failure:
550 550 relevant = relevant.union(self.all_failed)
551 551 for m in job.follow.intersection(relevant):
552 552 dests.add(self.destinations[m])
553 553 if len(dests) > 1:
554 554 self.queue_map[msg_id] = job
555 555 self.fail_unreachable(msg_id)
556 556 return False
557 557 if job.targets:
558 558 # check blacklist+targets for impossibility
559 559 job.targets.difference_update(job.blacklist)
560 560 if not job.targets or not job.targets.intersection(self.targets):
561 561 self.queue_map[msg_id] = job
562 562 self.fail_unreachable(msg_id)
563 563 return False
564 564 return False
565 565 else:
566 566 indices = None
567 567
568 568 self.submit_task(job, indices)
569 569 return True
570 570
571 571 def save_unmet(self, job):
572 572 """Save a message for later submission when its dependencies are met."""
573 573 msg_id = job.msg_id
574 574 self.log.debug("Adding task %s to the queue", msg_id)
575 575 self.queue_map[msg_id] = job
576 576 self.queue.append(job)
577 577 # track the ids in follow or after, but not those already finished
578 578 for dep_id in job.after.union(job.follow).difference(self.all_done):
579 579 if dep_id not in self.graph:
580 580 self.graph[dep_id] = set()
581 581 self.graph[dep_id].add(msg_id)
582 582
583 583 # schedule timeout callback
584 584 if job.timeout:
585 585 timeout_id = job.timeout_id = job.timeout_id + 1
586 586 self.loop.add_timeout(time.time() + job.timeout,
587 587 lambda : self.job_timeout(job, timeout_id)
588 588 )
589 589
590 590
591 591 def submit_task(self, job, indices=None):
592 592 """Submit a task to any of a subset of our targets."""
593 593 if indices:
594 594 loads = [self.loads[i] for i in indices]
595 595 else:
596 596 loads = self.loads
597 597 idx = self.scheme(loads)
598 598 if indices:
599 599 idx = indices[idx]
600 600 target = self.targets[idx]
601 601 # print (target, map(str, msg[:3]))
602 602 # send job to the engine
603 603 self.engine_stream.send(target, flags=zmq.SNDMORE, copy=False)
604 604 self.engine_stream.send_multipart(job.raw_msg, copy=False)
605 605 # update load
606 606 self.add_job(idx)
607 607 self.pending[target][job.msg_id] = job
608 608 # notify Hub
609 609 content = dict(msg_id=job.msg_id, engine_id=target.decode('ascii'))
610 610 self.session.send(self.mon_stream, 'task_destination', content=content,
611 611 ident=[b'tracktask',self.ident])
612 612
613 613
614 614 #-----------------------------------------------------------------------
615 615 # Result Handling
616 616 #-----------------------------------------------------------------------
617 617
618 618
619 619 @util.log_errors
620 620 def dispatch_result(self, raw_msg):
621 621 """dispatch method for result replies"""
622 622 try:
623 623 idents,msg = self.session.feed_identities(raw_msg, copy=False)
624 624 msg = self.session.deserialize(msg, content=False, copy=False)
625 625 engine = idents[0]
626 626 try:
627 627 idx = self.targets.index(engine)
628 628 except ValueError:
629 629 pass # skip load-update for dead engines
630 630 else:
631 631 self.finish_job(idx)
632 632 except Exception:
633 633 self.log.error("task::Invalid result: %r", raw_msg, exc_info=True)
634 634 return
635 635
636 636 md = msg['metadata']
637 637 parent = msg['parent_header']
638 638 if md.get('dependencies_met', True):
639 639 success = (md['status'] == 'ok')
640 640 msg_id = parent['msg_id']
641 641 retries = self.retries[msg_id]
642 642 if not success and retries > 0:
643 643 # failed
644 644 self.retries[msg_id] = retries - 1
645 645 self.handle_unmet_dependency(idents, parent)
646 646 else:
647 647 del self.retries[msg_id]
648 648 # relay to client and update graph
649 649 self.handle_result(idents, parent, raw_msg, success)
650 650 # send to Hub monitor
651 651 self.mon_stream.send_multipart([b'outtask']+raw_msg, copy=False)
652 652 else:
653 653 self.handle_unmet_dependency(idents, parent)
654 654
655 655 def handle_result(self, idents, parent, raw_msg, success=True):
656 656 """handle a real task result, either success or failure"""
657 657 # first, relay result to client
658 658 engine = idents[0]
659 659 client = idents[1]
660 660 # swap_ids for ROUTER-ROUTER mirror
661 661 raw_msg[:2] = [client,engine]
662 662 # print (map(str, raw_msg[:4]))
663 663 self.client_stream.send_multipart(raw_msg, copy=False)
664 664 # now, update our data structures
665 665 msg_id = parent['msg_id']
666 666 self.pending[engine].pop(msg_id)
667 667 if success:
668 668 self.completed[engine].add(msg_id)
669 669 self.all_completed.add(msg_id)
670 670 else:
671 671 self.failed[engine].add(msg_id)
672 672 self.all_failed.add(msg_id)
673 673 self.all_done.add(msg_id)
674 674 self.destinations[msg_id] = engine
675 675
676 676 self.update_graph(msg_id, success)
677 677
678 678 def handle_unmet_dependency(self, idents, parent):
679 679 """handle an unmet dependency"""
680 680 engine = idents[0]
681 681 msg_id = parent['msg_id']
682 682
683 683 job = self.pending[engine].pop(msg_id)
684 684 job.blacklist.add(engine)
685 685
686 686 if job.blacklist == job.targets:
687 687 self.queue_map[msg_id] = job
688 688 self.fail_unreachable(msg_id)
689 689 elif not self.maybe_run(job):
690 690 # resubmit failed
691 691 if msg_id not in self.all_failed:
692 692 # put it back in our dependency tree
693 693 self.save_unmet(job)
694 694
695 695 if self.hwm:
696 696 try:
697 697 idx = self.targets.index(engine)
698 698 except ValueError:
699 699 pass # skip load-update for dead engines
700 700 else:
701 701 if self.loads[idx] == self.hwm-1:
702 702 self.update_graph(None)
703 703
704 704 def update_graph(self, dep_id=None, success=True):
705 705 """dep_id just finished. Update our dependency
706 706 graph and submit any jobs that just became runnable.
707 707
708 708 Called with dep_id=None to update entire graph for hwm, but without finishing a task.
709 709 """
710 710 # print ("\n\n***********")
711 711 # pprint (dep_id)
712 712 # pprint (self.graph)
713 713 # pprint (self.queue_map)
714 714 # pprint (self.all_completed)
715 715 # pprint (self.all_failed)
716 716 # print ("\n\n***********\n\n")
717 717 # update any jobs that depended on the dependency
718 718 msg_ids = self.graph.pop(dep_id, [])
719 719
720 720 # recheck *all* jobs if
721 721 # a) we have HWM and an engine just become no longer full
722 722 # or b) dep_id was given as None
723 723
724 724 if dep_id is None or (self.hwm and any(load == self.hwm - 1 for load in self.loads)):
725 725 jobs = self.queue
726 726 using_queue = True
727 727 else:
728 728 using_queue = False
729 729 jobs = deque(sorted( self.queue_map[msg_id] for msg_id in msg_ids ))
730 730
731 731 to_restore = []
732 732 while jobs:
733 733 job = jobs.popleft()
734 734 if job.removed:
735 735 continue
736 736 msg_id = job.msg_id
737 737
738 738 put_it_back = True
739 739
740 740 if job.after.unreachable(self.all_completed, self.all_failed)\
741 741 or job.follow.unreachable(self.all_completed, self.all_failed):
742 742 self.fail_unreachable(msg_id)
743 743 put_it_back = False
744 744
745 745 elif job.after.check(self.all_completed, self.all_failed): # time deps met, maybe run
746 746 if self.maybe_run(job):
747 747 put_it_back = False
748 748 self.queue_map.pop(msg_id)
749 749 for mid in job.dependents:
750 750 if mid in self.graph:
751 751 self.graph[mid].remove(msg_id)
752 752
753 753 # abort the loop if we just filled up all of our engines.
754 754 # avoids an O(N) operation in situation of full queue,
755 755 # where graph update is triggered as soon as an engine becomes
756 756 # non-full, and all tasks after the first are checked,
757 757 # even though they can't run.
758 758 if not self.available_engines():
759 759 break
760 760
761 761 if using_queue and put_it_back:
762 762 # popped a job from the queue but it neither ran nor failed,
763 763 # so we need to put it back when we are done
764 764 # make sure to_restore preserves the same ordering
765 765 to_restore.append(job)
766 766
767 767 # put back any tasks we popped but didn't run
768 768 if using_queue:
769 769 self.queue.extendleft(to_restore)
770 770
771 771 #----------------------------------------------------------------------
772 772 # methods to be overridden by subclasses
773 773 #----------------------------------------------------------------------
774 774
775 775 def add_job(self, idx):
776 776 """Called after self.targets[idx] just got the job with header.
777 777 Override in subclasses. The default ordering is simple LRU.
778 778 The default loads are the number of outstanding jobs."""
779 779 self.loads[idx] += 1
780 780 for lis in (self.targets, self.loads):
781 781 lis.append(lis.pop(idx))
782 782
783 783
784 784 def finish_job(self, idx):
785 785 """Called after self.targets[idx] just finished a job.
786 786 Override in subclasses."""
787 787 self.loads[idx] -= 1
788 788
789 789
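# A minimal sketch (illustration only, not part of this changeset) of
# overriding the add_job/finish_job hooks above. `WeightedScheduler` and its
# `weights` trait are hypothetical names; it charges a per-engine weight
# instead of the default flat +1/-1, and skips the LRU rotation. It assumes
# the traitlets `List` imported at the top of this module, with one positive
# weight per engine kept parallel to self.loads.

class WeightedScheduler(TaskScheduler):
    """Hypothetical scheduler charging per-engine weights."""

    weights = List()  # assumed: one float weight per engine, parallel to self.loads

    def add_job(self, idx):
        # heavier engines accumulate load faster, so they are picked less often
        self.loads[idx] += self.weights[idx]

    def finish_job(self, idx):
        self.loads[idx] -= self.weights[idx]
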
790 790
791 791 def launch_scheduler(in_addr, out_addr, mon_addr, not_addr, reg_addr, config=None,
792 792 logname='root', log_url=None, loglevel=logging.DEBUG,
793 793 identity=b'task', in_thread=False):
794 794
795 795 ZMQStream = zmqstream.ZMQStream
796 796
797 797 if config:
798 798 # unwrap dict back into Config
799 799 config = Config(config)
800 800
801 801 if in_thread:
802 802 # use instance() to get the same Context/Loop as our parent
803 803 ctx = zmq.Context.instance()
804 804 loop = ioloop.IOLoop.instance()
805 805 else:
806 806 # in a process, don't use instance()
807 807 # for safety with multiprocessing
808 808 ctx = zmq.Context()
809 809 loop = ioloop.IOLoop()
810 810 ins = ZMQStream(ctx.socket(zmq.ROUTER),loop)
811 811 util.set_hwm(ins, 0)
812 812 ins.setsockopt(zmq.IDENTITY, identity + b'_in')
813 813 ins.bind(in_addr)
814 814
815 815 outs = ZMQStream(ctx.socket(zmq.ROUTER),loop)
816 816 util.set_hwm(outs, 0)
817 817 outs.setsockopt(zmq.IDENTITY, identity + b'_out')
818 818 outs.bind(out_addr)
819 819 mons = zmqstream.ZMQStream(ctx.socket(zmq.PUB),loop)
820 820 util.set_hwm(mons, 0)
821 821 mons.connect(mon_addr)
822 822 nots = zmqstream.ZMQStream(ctx.socket(zmq.SUB),loop)
823 823 nots.setsockopt(zmq.SUBSCRIBE, b'')
824 824 nots.connect(not_addr)
825 825
826 826 querys = ZMQStream(ctx.socket(zmq.DEALER),loop)
827 827 querys.connect(reg_addr)
828 828
829 829 # setup logging.
830 830 if in_thread:
831 831 log = Application.instance().log
832 832 else:
833 833 if log_url:
834 834 log = connect_logger(logname, ctx, log_url, root="scheduler", loglevel=loglevel)
835 835 else:
836 836 log = local_logger(logname, loglevel)
837 837
838 838 scheduler = TaskScheduler(client_stream=ins, engine_stream=outs,
839 839 mon_stream=mons, notifier_stream=nots,
840 840 query_stream=querys,
841 841 loop=loop, log=log,
842 842 config=config)
843 843 scheduler.start()
844 844 if not in_thread:
845 845 try:
846 846 loop.start()
847 847 except KeyboardInterrupt:
848 848 scheduler.log.critical("Interrupted, exiting...")
849 849
@@ -1,192 +1,192 b''
1 1 """base class for parallel client tests
2 2
3 3 Authors:
4 4
5 5 * Min RK
6 6 """
7 7
8 8 #-------------------------------------------------------------------------------
9 9 # Copyright (C) 2011 The IPython Development Team
10 10 #
11 11 # Distributed under the terms of the BSD License. The full license is in
12 12 # the file COPYING, distributed as part of this software.
13 13 #-------------------------------------------------------------------------------
14 14 from __future__ import print_function
15 15
16 16 import sys
17 17 import tempfile
18 18 import time
19 19
20 20 from nose import SkipTest
21 21
22 22 import zmq
23 23 from zmq.tests import BaseZMQTestCase
24 24
25 from IPython.external.decorator import decorator
25 from decorator import decorator
26 26
27 27 from IPython.parallel import error
28 28 from IPython.parallel import Client
29 29
30 30 from IPython.parallel.tests import launchers, add_engines
31 31
32 32 # simple tasks for use in apply tests
33 33
34 34 def segfault():
35 35 """this will segfault"""
36 36 import ctypes
37 37 ctypes.memset(-1,0,1)
38 38
39 39 def crash():
40 40 """from stdlib crashers in the test suite"""
41 41 import types
42 42 if sys.platform.startswith('win'):
43 43 import ctypes
44 44 ctypes.windll.kernel32.SetErrorMode(0x0002)
45 45 args = [ 0, 0, 0, 0, b'\x04\x71\x00\x00', (), (), (), '', '', 1, b'']
46 46 if sys.version_info[0] >= 3:
47 47 # Python3 adds 'kwonlyargcount' as the second argument to Code
48 48 args.insert(1, 0)
49 49
50 50 co = types.CodeType(*args)
51 51 exec(co)
52 52
53 53 def wait(n):
54 54 """sleep for a time"""
55 55 import time
56 56 time.sleep(n)
57 57 return n
58 58
59 59 def raiser(eclass):
60 60 """raise an exception"""
61 61 raise eclass()
62 62
63 63 def generate_output():
64 64 """function for testing output
65 65
66 66 publishes two outputs of each type, and returns
67 67 a rich displayable object.
68 68 """
69 69
70 70 import sys
71 71 from IPython.core.display import display, HTML, Math
72 72
73 73 print("stdout")
74 74 print("stderr", file=sys.stderr)
75 75
76 76 display(HTML("<b>HTML</b>"))
77 77
78 78 print("stdout2")
79 79 print("stderr2", file=sys.stderr)
80 80
81 81 display(Math(r"\alpha=\beta"))
82 82
83 83 return Math("42")
84 84
85 85 # test decorator for skipping tests when libraries are unavailable
86 86 def skip_without(*names):
87 87 """skip a test if some names are not importable"""
88 88 @decorator
89 89 def skip_without_names(f, *args, **kwargs):
90 90 """decorator to skip tests when any of the named modules is missing."""
91 91 for name in names:
92 92 try:
93 93 __import__(name)
94 94 except ImportError:
95 95 raise SkipTest
96 96 return f(*args, **kwargs)
97 97 return skip_without_names
98 98
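# Illustrative use of skip_without (hypothetical test, not part of the suite):
# the body only runs when both modules import cleanly, otherwise SkipTest is raised.

@skip_without('numpy', 'scipy')
def _example_needs_numpy_scipy():
    import numpy
    import scipy
    return numpy.__version__, scipy.__version__
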
99 99 #-------------------------------------------------------------------------------
100 100 # Classes
101 101 #-------------------------------------------------------------------------------
102 102
103 103
104 104 class ClusterTestCase(BaseZMQTestCase):
105 105 timeout = 10
106 106
107 107 def add_engines(self, n=1, block=True):
108 108 """add multiple engines to our cluster"""
109 109 self.engines.extend(add_engines(n))
110 110 if block:
111 111 self.wait_on_engines()
112 112
113 113 def minimum_engines(self, n=1, block=True):
114 114 """add engines until there are at least n connected"""
115 115 self.engines.extend(add_engines(n, total=True))
116 116 if block:
117 117 self.wait_on_engines()
118 118
119 119
120 120 def wait_on_engines(self, timeout=5):
121 121 """wait for our engines to connect."""
122 122 n = len(self.engines)+self.base_engine_count
123 123 tic = time.time()
124 124 while time.time()-tic < timeout and len(self.client.ids) < n:
125 125 time.sleep(0.1)
126 126
127 127 assert len(self.client.ids) >= n, "waiting for engines timed out"
128 128
129 129 def client_wait(self, client, jobs=None, timeout=-1):
130 130 """my wait wrapper, sets a default finite timeout to avoid hangs"""
131 131 if timeout < 0:
132 132 timeout = self.timeout
133 133 return Client.wait(client, jobs, timeout)
134 134
135 135 def connect_client(self):
136 136 """connect a client with my Context, and track its sockets for cleanup"""
137 137 c = Client(profile='iptest', context=self.context)
138 138 c.wait = lambda *a, **kw: self.client_wait(c, *a, **kw)
139 139
140 140 for name in filter(lambda n:n.endswith('socket'), dir(c)):
141 141 s = getattr(c, name)
142 142 s.setsockopt(zmq.LINGER, 0)
143 143 self.sockets.append(s)
144 144 return c
145 145
146 146 def assertRaisesRemote(self, etype, f, *args, **kwargs):
147 147 try:
148 148 try:
149 149 f(*args, **kwargs)
150 150 except error.CompositeError as e:
151 151 e.raise_exception()
152 152 except error.RemoteError as e:
153 153 self.assertEqual(etype.__name__, e.ename, "Should have raised %r, but raised %r"%(etype.__name__, e.ename))
154 154 else:
155 155 self.fail("should have raised a RemoteError")
156 156
157 157 def _wait_for(self, f, timeout=10):
158 158 """wait for a condition"""
159 159 tic = time.time()
160 160 while time.time() <= tic + timeout:
161 161 if f():
162 162 return
163 163 time.sleep(0.1)
164 164 self.client.spin()
165 165 if not f():
166 166 print("Warning: awaited condition was never met")
167 167
168 168 def setUp(self):
169 169 BaseZMQTestCase.setUp(self)
170 170 self.client = self.connect_client()
171 171 # start every test with clean engine namespaces:
172 172 self.client.clear(block=True)
173 173 self.base_engine_count=len(self.client.ids)
174 174 self.engines=[]
175 175
176 176 def tearDown(self):
177 177 # self.client.clear(block=True)
178 178 # close fds:
179 179 for e in [e for e in launchers if e.poll() is not None]:
180 180 launchers.remove(e)
181 181
182 182 # allow flushing of incoming messages to prevent crash on socket close
183 183 self.client.wait(timeout=2)
184 184 # time.sleep(2)
185 185 self.client.spin()
186 186 self.client.close()
187 187 BaseZMQTestCase.tearDown(self)
188 188 # this will be redundant when pyzmq merges PR #88
189 189 # self.context.term()
190 190 # print tempfile.TemporaryFile().fileno(),
191 191 # sys.stdout.flush()
192 192
@@ -1,389 +1,389 b''
1 1 """Some generic utilities for dealing with classes, urls, and serialization."""
2 2
3 3 # Copyright (c) IPython Development Team.
4 4 # Distributed under the terms of the Modified BSD License.
5 5
6 6 import logging
7 7 import os
8 8 import re
9 9 import stat
10 10 import socket
11 11 import sys
12 12 import warnings
13 13 from signal import signal, SIGINT, SIGABRT, SIGTERM
14 14 try:
15 15 from signal import SIGKILL
16 16 except ImportError:
17 17 SIGKILL=None
18 18 from types import FunctionType
19 19
20 20 try:
21 21 import cPickle
22 22 pickle = cPickle
23 23 except ImportError:
24 24 cPickle = None
25 25 import pickle
26 26
27 27 import zmq
28 28 from zmq.log import handlers
29 29
30 30 from IPython.utils.log import get_logger
31 from IPython.external.decorator import decorator
31 from decorator import decorator
32 32
33 33 from IPython.config.application import Application
34 34 from IPython.utils.localinterfaces import localhost, is_public_ip, public_ips
35 35 from IPython.utils.py3compat import string_types, iteritems, itervalues
36 36 from IPython.kernel.zmq.log import EnginePUBHandler
37 37
38 38
39 39 #-----------------------------------------------------------------------------
40 40 # Classes
41 41 #-----------------------------------------------------------------------------
42 42
43 43 class Namespace(dict):
44 44 """Subclass of dict for attribute access to keys."""
45 45
46 46 def __getattr__(self, key):
47 47 """getattr aliased to getitem"""
48 48 if key in self:
49 49 return self[key]
50 50 else:
51 51 raise NameError(key)
52 52
53 53 def __setattr__(self, key, value):
54 54 """setattr aliased to setitem, with strict"""
55 55 if hasattr(dict, key):
56 56 raise KeyError("Cannot override dict attribute %r" % key)
57 57 self[key] = value
58 58
59 59
60 60 class ReverseDict(dict):
61 61 """simple double-keyed subset of dict methods."""
62 62
63 63 def __init__(self, *args, **kwargs):
64 64 dict.__init__(self, *args, **kwargs)
65 65 self._reverse = dict()
66 66 for key, value in iteritems(self):
67 67 self._reverse[value] = key
68 68
69 69 def __getitem__(self, key):
70 70 try:
71 71 return dict.__getitem__(self, key)
72 72 except KeyError:
73 73 return self._reverse[key]
74 74
75 75 def __setitem__(self, key, value):
76 76 if key in self._reverse:
77 77 raise KeyError("Can't have key %r on both sides!"%key)
78 78 dict.__setitem__(self, key, value)
79 79 self._reverse[value] = key
80 80
81 81 def pop(self, key):
82 82 value = dict.pop(self, key)
83 83 self._reverse.pop(value)
84 84 return value
85 85
86 86 def get(self, key, default=None):
87 87 try:
88 88 return self[key]
89 89 except KeyError:
90 90 return default
91 91
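# Illustration (hypothetical values): items can be looked up from either side,
# and pop() removes both directions at once.
#
# rd = ReverseDict()
# rd['engine-0'] = 'tcp://127.0.0.1:10101'
# rd['tcp://127.0.0.1:10101']   # -> 'engine-0', via the reverse map
# rd.pop('engine-0')            # also drops the reverse entry
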
92 92 #-----------------------------------------------------------------------------
93 93 # Functions
94 94 #-----------------------------------------------------------------------------
95 95
96 96 @decorator
97 97 def log_errors(f, self, *args, **kwargs):
98 98 """decorator to log unhandled exceptions raised in a method.
99 99
100 100 For use wrapping on_recv callbacks, so that exceptions
101 101 do not cause the stream to be closed.
102 102 """
103 103 try:
104 104 return f(self, *args, **kwargs)
105 105 except Exception:
106 106 self.log.error("Uncaught exception in %r" % f, exc_info=True)
107 107
108 108
109 109 def is_url(url):
110 110 """boolean check for whether a string is a zmq url"""
111 111 if '://' not in url:
112 112 return False
113 113 proto, addr = url.split('://', 1)
114 114 if proto.lower() not in ['tcp','pgm','epgm','ipc','inproc']:
115 115 return False
116 116 return True
117 117
118 118 def validate_url(url):
119 119 """validate a url for zeromq"""
120 120 if not isinstance(url, string_types):
121 121 raise TypeError("url must be a string, not %r"%type(url))
122 122 url = url.lower()
123 123
124 124 proto_addr = url.split('://')
125 125 assert len(proto_addr) == 2, 'Invalid url: %r'%url
126 126 proto, addr = proto_addr
127 127 assert proto in ['tcp','pgm','epgm','ipc','inproc'], "Invalid protocol: %r"%proto
128 128
129 129 # domain pattern adapted from http://www.regexlib.com/REDetails.aspx?regexp_id=391
130 130 # author: Remi Sabourin
131 131 pat = re.compile(r'^([\w\d]([\w\d\-]{0,61}[\w\d])?\.)*[\w\d]([\w\d\-]{0,61}[\w\d])?$')
132 132
133 133 if proto == 'tcp':
134 134 lis = addr.split(':')
135 135 assert len(lis) == 2, 'Invalid url: %r'%url
136 136 addr,s_port = lis
137 137 try:
138 138 port = int(s_port)
139 139 except ValueError:
140 140 raise AssertionError("Invalid port %r in url: %r" % (s_port, url))
141 141
142 142 assert addr == '*' or pat.match(addr) is not None, 'Invalid url: %r'%url
143 143
144 144 else:
145 145 # only validate tcp urls currently
146 146 pass
147 147
148 148 return True
149 149
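# Examples of the checks above (illustrative):
#
# validate_url('tcp://127.0.0.1:10101')      # -> True
# validate_url('ipc:///tmp/scheduler.sock')  # -> True (only tcp is fully checked)
# validate_url('tcp://127.0.0.1:notaport')   # raises AssertionError
# validate_url(10101)                        # raises TypeError
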
150 150
151 151 def validate_url_container(container):
152 152 """validate a potentially nested collection of urls."""
153 153 if isinstance(container, string_types):
154 154 url = container
155 155 return validate_url(url)
156 156 elif isinstance(container, dict):
157 157 container = itervalues(container)
158 158
159 159 for element in container:
160 160 validate_url_container(element)
161 161
162 162
163 163 def split_url(url):
164 164 """split a zmq url (tcp://ip:port) into ('tcp','ip','port')."""
165 165 proto_addr = url.split('://')
166 166 assert len(proto_addr) == 2, 'Invalid url: %r'%url
167 167 proto, addr = proto_addr
168 168 lis = addr.split(':')
169 169 assert len(lis) == 2, 'Invalid url: %r'%url
170 170 addr,s_port = lis
171 171 return proto,addr,s_port
172 172
173 173
174 174 def disambiguate_ip_address(ip, location=None):
175 175 """turn multi-ip interfaces '0.0.0.0' and '*' into a connectable address
176 176
177 177 Explicit IP addresses are returned unmodified.
178 178
179 179 Parameters
180 180 ----------
181 181
182 182 ip : IP address
183 183 An IP address, or the special values 0.0.0.0, or *
184 184 location: IP address, optional
185 185 A public IP of the target machine.
186 186 If location is an IP of the current machine,
187 187 localhost will be returned,
188 188 otherwise location will be returned.
189 189 """
190 190 if ip in {'0.0.0.0', '*'}:
191 191 if not location:
192 192 # unspecified location, localhost is the only choice
193 193 ip = localhost()
194 194 elif is_public_ip(location):
195 195 # location is a public IP on this machine, use localhost
196 196 ip = localhost()
197 197 elif not public_ips():
198 198 # this machine's public IPs cannot be determined,
199 199 # assume `location` is not this machine
200 200 warnings.warn("IPython could not determine public IPs", RuntimeWarning)
201 201 ip = location
202 202 else:
203 203 # location is not this machine, do not use loopback
204 204 ip = location
205 205 return ip
206 206
207 207
208 208 def disambiguate_url(url, location=None):
209 209 """turn multi-ip interfaces '0.0.0.0' and '*' into connectable
210 210 ones, based on the location (default interpretation is localhost).
211 211
212 212 This is for zeromq urls, such as ``tcp://*:10101``.
213 213 """
214 214 try:
215 215 proto,ip,port = split_url(url)
216 216 except AssertionError:
217 217 # probably not tcp url; could be ipc, etc.
218 218 return url
219 219
220 220 ip = disambiguate_ip_address(ip,location)
221 221
222 222 return "%s://%s:%s"%(proto,ip,port)
223 223
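# Example behavior (illustrative; assumes 10.0.0.5 is not one of this
# machine's own addresses, and that localhost() resolves to 127.0.0.1):
#
# disambiguate_url('tcp://*:10101')                    # -> 'tcp://127.0.0.1:10101'
# disambiguate_url('tcp://0.0.0.0:10101', '10.0.0.5')  # -> 'tcp://10.0.0.5:10101'
# disambiguate_url('ipc:///tmp/scheduler.sock')        # non-tcp urls pass through unchanged
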
224 224
225 225 #--------------------------------------------------------------------------
226 226 # helpers for implementing old MEC API via view.apply
227 227 #--------------------------------------------------------------------------
228 228
229 229 def interactive(f):
230 230 """decorator for making functions appear as interactively defined.
231 231 This results in the function being linked to the user_ns as globals()
232 232 instead of the module globals().
233 233 """
234 234
235 235 # build new FunctionType, so it can have the right globals
236 236 # interactive functions never have closures, that's kind of the point
237 237 if isinstance(f, FunctionType):
238 238 mainmod = __import__('__main__')
239 239 f = FunctionType(f.__code__, mainmod.__dict__,
240 240 f.__name__, f.__defaults__,
241 241 )
242 242 # associate with __main__ for uncanning
243 243 f.__module__ = '__main__'
244 244 return f
245 245
246 246 @interactive
247 247 def _push(**ns):
248 248 """helper method for implementing `client.push` via `client.apply`"""
249 249 user_ns = globals()
250 250 tmp = '_IP_PUSH_TMP_'
251 251 while tmp in user_ns:
252 252 tmp = tmp + '_'
253 253 try:
254 254 for name, value in ns.items():
255 255 user_ns[tmp] = value
256 256 exec("%s = %s" % (name, tmp), user_ns)
257 257 finally:
258 258 user_ns.pop(tmp, None)
259 259
260 260 @interactive
261 261 def _pull(keys):
262 262 """helper method for implementing `client.pull` via `client.apply`"""
263 263 if isinstance(keys, (list,tuple, set)):
264 264 return [eval(key, globals()) for key in keys]
265 265 else:
266 266 return eval(keys, globals())
267 267
268 268 @interactive
269 269 def _execute(code):
270 270 """helper method for implementing `client.execute` via `client.apply`"""
271 271 exec(code, globals())
272 272
273 273 #--------------------------------------------------------------------------
274 274 # extra process management utilities
275 275 #--------------------------------------------------------------------------
276 276
277 277 _random_ports = set()
278 278
279 279 def select_random_ports(n):
280 280 """Select and return n random ports that are available."""
281 281 ports = []
282 282 for i in range(n):
283 283 sock = socket.socket()
284 284 sock.bind(('', 0))
285 285 while sock.getsockname()[1] in _random_ports:
286 286 sock.close()
287 287 sock = socket.socket()
288 288 sock.bind(('', 0))
289 289 ports.append(sock)
290 290 for i, sock in enumerate(ports):
291 291 port = sock.getsockname()[1]
292 292 sock.close()
293 293 ports[i] = port
294 294 _random_ports.add(port)
295 295 return ports
296 296
297 297 def signal_children(children):
298 298 """Relay interrupt/term signals to children, for more solid process cleanup."""
299 299 def terminate_children(sig, frame):
300 300 log = get_logger()
301 301 log.critical("Got signal %i, terminating children..."%sig)
302 302 for child in children:
303 303 child.terminate()
304 304
305 305 sys.exit(sig != SIGINT)
306 306 # sys.exit(sig)
307 307 for sig in (SIGINT, SIGABRT, SIGTERM):
308 308 signal(sig, terminate_children)
309 309
310 310 def generate_exec_key(keyfile):
311 311 import uuid
312 312 newkey = str(uuid.uuid4())
313 313 with open(keyfile, 'w') as f:
314 314 # f.write('ipython-key ')
315 315 f.write(newkey+'\n')
316 316 # set user-only RW permissions (0600)
317 317 # this will have no effect on Windows
318 318 os.chmod(keyfile, stat.S_IRUSR|stat.S_IWUSR)
319 319
320 320
321 321 def integer_loglevel(loglevel):
322 322 try:
323 323 loglevel = int(loglevel)
324 324 except ValueError:
325 325 if isinstance(loglevel, str):
326 326 loglevel = getattr(logging, loglevel)
327 327 return loglevel
328 328
329 329 def connect_logger(logname, context, iface, root="ip", loglevel=logging.DEBUG):
330 330 logger = logging.getLogger(logname)
331 331 if any([isinstance(h, handlers.PUBHandler) for h in logger.handlers]):
332 332 # don't add a second PUBHandler
333 333 return
334 334 loglevel = integer_loglevel(loglevel)
335 335 lsock = context.socket(zmq.PUB)
336 336 lsock.connect(iface)
337 337 handler = handlers.PUBHandler(lsock)
338 338 handler.setLevel(loglevel)
339 339 handler.root_topic = root
340 340 logger.addHandler(handler)
341 341 logger.setLevel(loglevel)
342 342
343 343 def connect_engine_logger(context, iface, engine, loglevel=logging.DEBUG):
344 344 logger = logging.getLogger()
345 345 if any([isinstance(h, handlers.PUBHandler) for h in logger.handlers]):
346 346 # don't add a second PUBHandler
347 347 return
348 348 loglevel = integer_loglevel(loglevel)
349 349 lsock = context.socket(zmq.PUB)
350 350 lsock.connect(iface)
351 351 handler = EnginePUBHandler(engine, lsock)
352 352 handler.setLevel(loglevel)
353 353 logger.addHandler(handler)
354 354 logger.setLevel(loglevel)
355 355 return logger
356 356
357 357 def local_logger(logname, loglevel=logging.DEBUG):
358 358 loglevel = integer_loglevel(loglevel)
359 359 logger = logging.getLogger(logname)
360 360 if any([isinstance(h, logging.StreamHandler) for h in logger.handlers]):
361 361 # don't add a second StreamHandler
362 362 return
363 363 handler = logging.StreamHandler()
364 364 handler.setLevel(loglevel)
365 365 formatter = logging.Formatter("%(asctime)s.%(msecs).03d [%(name)s] %(message)s",
366 366 datefmt="%Y-%m-%d %H:%M:%S")
367 367 handler.setFormatter(formatter)
368 368
369 369 logger.addHandler(handler)
370 370 logger.setLevel(loglevel)
371 371 return logger
372 372
373 373 def set_hwm(sock, hwm=0):
374 374 """set zmq High Water Mark on a socket
375 375
376 376 in a way that always works for various pyzmq / libzmq versions.
377 377 """
378 378 import zmq
379 379
380 380 for key in ('HWM', 'SNDHWM', 'RCVHWM'):
381 381 opt = getattr(zmq, key, None)
382 382 if opt is None:
383 383 continue
384 384 try:
385 385 sock.setsockopt(opt, hwm)
386 386 except zmq.ZMQError:
387 387 pass
388 388
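# Typical call, as in launch_scheduler above: lift the buffering limit on a
# socket across pyzmq/libzmq versions (illustrative snippet).
#
# sock = ctx.socket(zmq.ROUTER)
# set_hwm(sock, 0)   # tries HWM, then SNDHWM/RCVHWM, ignoring unsupported options
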
389 389
@@ -1,400 +1,400 b''
1 1 # -*- coding: utf-8 -*-
2 2 """Decorators for labeling test objects.
3 3
4 4 Decorators that merely return a modified version of the original function
5 5 object are straightforward. Decorators that return a new function object need
6 6 to use nose.tools.make_decorator(original_function)(decorator) in returning the
7 7 decorator, in order to preserve metadata such as function name, setup and
8 8 teardown functions and so on - see nose.tools for more information.
9 9
10 10 This module provides a set of useful decorators meant to be ready to use in
11 11 your own tests. See the bottom of the file for the ready-made ones, and if you
12 12 find yourself writing a new one that may be of generic use, add it here.
13 13
14 14 Included decorators:
15 15
16 16
17 17 Lightweight testing that remains unittest-compatible.
18 18
19 19 - An @as_unittest decorator can be used to tag any normal parameter-less
20 20 function as a unittest TestCase. Then, both nose and normal unittest will
21 21 recognize it as such. This will make it easier to migrate away from Nose if
22 22 we ever need/want to while maintaining very lightweight tests.
23 23
24 24 NOTE: This file contains IPython-specific decorators. Via the machinery in
25 25 IPython.external.decorators, we import numpy.testing.decorators if numpy is
26 26 available, or fall back to equivalent code in IPython.external._decorators,
27 27 which we've copied verbatim from numpy.
28 28
29 29 Authors
30 30 -------
31 31
32 32 - Fernando Perez <Fernando.Perez@berkeley.edu>
33 33 """
34 34
35 35 #-----------------------------------------------------------------------------
36 36 # Copyright (C) 2009-2011 The IPython Development Team
37 37 #
38 38 # Distributed under the terms of the BSD License. The full license is in
39 39 # the file COPYING, distributed as part of this software.
40 40 #-----------------------------------------------------------------------------
41 41
42 42 #-----------------------------------------------------------------------------
43 43 # Imports
44 44 #-----------------------------------------------------------------------------
45 45
46 46 # Stdlib imports
47 47 import sys
48 48 import os
49 49 import tempfile
50 50 import unittest
51 51
52 52 # Third-party imports
53 53
54 54 # This is Michele Simionato's decorator module, kept verbatim.
55 from IPython.external.decorator import decorator
55 from decorator import decorator
56 56
57 57 # Expose the unittest-driven decorators
58 58 from .ipunittest import ipdoctest, ipdocstring
59 59
60 60 # Grab the numpy-specific decorators which we keep in a file that we
61 61 # occasionally update from upstream: decorators.py is a copy of
62 62 # numpy.testing.decorators, we expose all of it here.
63 63 from IPython.external.decorators import *
64 64
65 65 # For onlyif_cmd_exists decorator
66 66 from IPython.utils.process import is_cmd_found
67 67 from IPython.utils.py3compat import string_types
68 68
69 69 #-----------------------------------------------------------------------------
70 70 # Classes and functions
71 71 #-----------------------------------------------------------------------------
72 72
73 73 # Simple example of the basic idea
74 74 def as_unittest(func):
75 75 """Decorator to make a simple function into a normal test via unittest."""
76 76 class Tester(unittest.TestCase):
77 77 def test(self):
78 78 func()
79 79
80 80 Tester.__name__ = func.__name__
81 81
82 82 return Tester
83 83
84 84 # Utility functions
85 85
86 86 def apply_wrapper(wrapper,func):
87 87 """Apply a wrapper to a function for decoration.
88 88
89 89 This mixes Michele Simionato's decorator tool with nose's make_decorator,
90 90 to apply a wrapper in a decorator so that all nose attributes, as well as
91 91 function signature and other properties, survive the decoration cleanly.
92 92 This will ensure that wrapped functions can still be well introspected via
93 93 IPython, for example.
94 94 """
95 95 import nose.tools
96 96
97 97 return decorator(wrapper,nose.tools.make_decorator(func)(wrapper))
98 98
99 99
100 100 def make_label_dec(label,ds=None):
101 101 """Factory function to create a decorator that applies one or more labels.
102 102
103 103 Parameters
104 104 ----------
105 105 label : string or sequence
106 106 One or more labels that will be applied by the decorator to the functions
107 107 it decorates. Labels are attributes of the decorated function with their
108 108 value set to True.
109 109
110 110 ds : string
111 111 An optional docstring for the resulting decorator. If not given, a
112 112 default docstring is auto-generated.
113 113
114 114 Returns
115 115 -------
116 116 A decorator.
117 117
118 118 Examples
119 119 --------
120 120
121 121 A simple labeling decorator:
122 122
123 123 >>> slow = make_label_dec('slow')
124 124 >>> slow.__doc__
125 125 "Labels a test as 'slow'."
126 126
127 127 And one that uses multiple labels and a custom docstring:
128 128
129 129 >>> rare = make_label_dec(['slow','hard'],
130 130 ... "Mix labels 'slow' and 'hard' for rare tests.")
131 131 >>> rare.__doc__
132 132 "Mix labels 'slow' and 'hard' for rare tests."
133 133
134 134 Now, let's test using this one:
135 135 >>> @rare
136 136 ... def f(): pass
137 137 ...
138 138 >>>
139 139 >>> f.slow
140 140 True
141 141 >>> f.hard
142 142 True
143 143 """
144 144
145 145 if isinstance(label, string_types):
146 146 labels = [label]
147 147 else:
148 148 labels = label
149 149
150 150 # Validate that the given label(s) are OK for use in setattr() by doing a
151 151 # dry run on a dummy function.
152 152 tmp = lambda : None
153 153 for label in labels:
154 154 setattr(tmp,label,True)
155 155
156 156 # This is the actual decorator we'll return
157 157 def decor(f):
158 158 for label in labels:
159 159 setattr(f,label,True)
160 160 return f
161 161
162 162 # Apply the user's docstring, or autogenerate a basic one
163 163 if ds is None:
164 164 ds = "Labels a test as %r." % label
165 165 decor.__doc__ = ds
166 166
167 167 return decor
168 168
169 169
170 170 # Inspired by numpy's skipif, but uses the full apply_wrapper utility to
171 171 # preserve function metadata better and allows the skip condition to be a
172 172 # callable.
173 173 def skipif(skip_condition, msg=None):
174 174 ''' Make function raise SkipTest exception if skip_condition is true
175 175
176 176 Parameters
177 177 ----------
178 178
179 179 skip_condition : bool or callable
180 180 Flag to determine whether to skip test. If the condition is a
181 181 callable, it is used at runtime to dynamically make the decision. This
182 182 is useful for tests that may require costly imports, to delay the cost
183 183 until the test suite is actually executed.
184 184 msg : string
185 185 Message to give on raising a SkipTest exception.
186 186
187 187 Returns
188 188 -------
189 189 decorator : function
190 190 Decorator, which, when applied to a function, causes SkipTest
191 191 to be raised when the skip_condition was True, and the function
192 192 to be called normally otherwise.
193 193
194 194 Notes
195 195 -----
196 196 You will see from the code that we had to further decorate the
197 197 decorator with the nose.tools.make_decorator function in order to
198 198 transmit function name, and various other metadata.
199 199 '''
200 200
201 201 def skip_decorator(f):
202 202 # Local import to avoid a hard nose dependency and only incur the
203 203 # import time overhead at actual test-time.
204 204 import nose
205 205
206 206 # Allow for both boolean or callable skip conditions.
207 207 if callable(skip_condition):
208 208 skip_val = skip_condition
209 209 else:
210 210 skip_val = lambda : skip_condition
211 211
212 212 def get_msg(func,msg=None):
213 213 """Skip message with information about function being skipped."""
214 214 if msg is None: out = 'Test skipped due to test condition.'
215 215 else: out = msg
216 216 return "Skipping test: %s. %s" % (func.__name__,out)
217 217
218 218 # We need to define *two* skippers because Python doesn't allow both
219 219 # return with value and yield inside the same function.
220 220 def skipper_func(*args, **kwargs):
221 221 """Skipper for normal test functions."""
222 222 if skip_val():
223 223 raise nose.SkipTest(get_msg(f,msg))
224 224 else:
225 225 return f(*args, **kwargs)
226 226
227 227 def skipper_gen(*args, **kwargs):
228 228 """Skipper for test generators."""
229 229 if skip_val():
230 230 raise nose.SkipTest(get_msg(f,msg))
231 231 else:
232 232 for x in f(*args, **kwargs):
233 233 yield x
234 234
235 235 # Choose the right skipper to use when building the actual generator.
236 236 if nose.util.isgenerator(f):
237 237 skipper = skipper_gen
238 238 else:
239 239 skipper = skipper_func
240 240
241 241 return nose.tools.make_decorator(f)(skipper)
242 242
243 243 return skip_decorator
244 244
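# Illustrative usage, mirroring the platform skips defined below; the callable
# form defers any costly imports until the test actually runs.
#
# @skipif(sys.platform == 'win32', "This test does not run under Windows")
# def test_posix_only():
#     ...
#
# @skipif(lambda: module_not_available('numpy'), "numpy is required")
# def test_needs_numpy():
#     import numpy
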
245 245 # A version with the condition set to true, common case just to attach a message
246 246 # to a skip decorator
247 247 def skip(msg=None):
248 248 """Decorator factory - mark a test function for skipping from test suite.
249 249
250 250 Parameters
251 251 ----------
252 252 msg : string
253 253 Optional message to be added.
254 254
255 255 Returns
256 256 -------
257 257 decorator : function
258 258 Decorator, which, when applied to a function, causes SkipTest
259 259 to be raised, with the optional message added.
260 260 """
261 261
262 262 return skipif(True,msg)
263 263
264 264
265 265 def onlyif(condition, msg):
266 266 """The reverse from skipif, see skipif for details."""
267 267
268 268 if callable(condition):
269 269 skip_condition = lambda : not condition()
270 270 else:
271 271 skip_condition = lambda : not condition
272 272
273 273 return skipif(skip_condition, msg)
274 274
275 275 #-----------------------------------------------------------------------------
276 276 # Utility functions for decorators
277 277 def module_not_available(module):
278 278 """Can module be imported? Returns true if module does NOT import.
279 279
280 280 This is used to make a decorator to skip tests that require module to be
281 281 available, but delay the 'import numpy' to test execution time.
282 282 """
283 283 try:
284 284 mod = __import__(module)
285 285 mod_not_avail = False
286 286 except ImportError:
287 287 mod_not_avail = True
288 288
289 289 return mod_not_avail
290 290
291 291
292 292 def decorated_dummy(dec, name):
293 293 """Return a dummy function decorated with dec, with the given name.
294 294
295 295 Examples
296 296 --------
297 297 import IPython.testing.decorators as dec
298 298 setup = dec.decorated_dummy(dec.skip_if_no_x11, __name__)
299 299 """
300 300 dummy = lambda: None
301 301 dummy.__name__ = name
302 302 return dec(dummy)
303 303
304 304 #-----------------------------------------------------------------------------
305 305 # Decorators for public use
306 306
307 307 # Decorators to skip certain tests on specific platforms.
308 308 skip_win32 = skipif(sys.platform == 'win32',
309 309 "This test does not run under Windows")
310 310 skip_linux = skipif(sys.platform.startswith('linux'),
311 311 "This test does not run under Linux")
312 312 skip_osx = skipif(sys.platform == 'darwin',"This test does not run under OS X")
313 313
314 314
315 315 # Decorators to skip tests if not on specific platforms.
316 316 skip_if_not_win32 = skipif(sys.platform != 'win32',
317 317 "This test only runs under Windows")
318 318 skip_if_not_linux = skipif(not sys.platform.startswith('linux'),
319 319 "This test only runs under Linux")
320 320 skip_if_not_osx = skipif(sys.platform != 'darwin',
321 321 "This test only runs under OSX")
322 322
323 323
324 324 _x11_skip_cond = (sys.platform not in ('darwin', 'win32') and
325 325 os.environ.get('DISPLAY', '') == '')
326 326 _x11_skip_msg = "Skipped under *nix when X11/XOrg not available"
327 327
328 328 skip_if_no_x11 = skipif(_x11_skip_cond, _x11_skip_msg)
329 329
330 330 # not a decorator itself, returns a dummy function to be used as setup
331 331 def skip_file_no_x11(name):
332 332 return decorated_dummy(skip_if_no_x11, name) if _x11_skip_cond else None
333 333
334 334 # Other skip decorators
335 335
336 336 # generic skip without module
337 337 skip_without = lambda mod: skipif(module_not_available(mod), "This test requires %s" % mod)
338 338
339 339 skipif_not_numpy = skip_without('numpy')
340 340
341 341 skipif_not_matplotlib = skip_without('matplotlib')
342 342
343 343 skipif_not_sympy = skip_without('sympy')
344 344
345 345 skip_known_failure = knownfailureif(True,'This test is known to fail')
346 346
347 347 known_failure_py3 = knownfailureif(sys.version_info[0] >= 3,
348 348 'This test is known to fail on Python 3.')
349 349
350 350 # A null 'decorator', useful to make more readable code that needs to pick
351 351 # between different decorators based on OS or other conditions
352 352 null_deco = lambda f: f
353 353
354 354 # Some tests only run where we can use unicode paths. Note that we can't just
355 355 # check os.path.supports_unicode_filenames, which is always False on Linux.
356 356 try:
357 357 f = tempfile.NamedTemporaryFile(prefix=u"tmp€")
358 358 except UnicodeEncodeError:
359 359 unicode_paths = False
360 360 else:
361 361 unicode_paths = True
362 362 f.close()
363 363
364 364 onlyif_unicode_paths = onlyif(unicode_paths, ("This test is only applicable "
365 365 "where we can use unicode in filenames."))
366 366
367 367
368 368 def onlyif_cmds_exist(*commands):
369 369 """
370 370 Decorator to skip test when at least one of `commands` is not found.
371 371 """
372 372 for cmd in commands:
373 373 try:
374 374 if not is_cmd_found(cmd):
375 375 return skip("This test runs only if command '{0}' "
376 376 "is installed".format(cmd))
377 377 except ImportError as e:
378 378 # is_cmd_found uses pywin32 on windows, which might not be available
379 379 if sys.platform == 'win32' and 'pywin32' in str(e):
380 380 return skip("This test runs only if pywin32 and command '{0}' "
381 381 "is installed".format(cmd))
382 382 raise e
383 383 return null_deco
384 384
385 385 def onlyif_any_cmd_exists(*commands):
386 386 """
387 387 Decorator to skip test unless at least one of `commands` is found.
388 388 """
389 389 for cmd in commands:
390 390 try:
391 391 if is_cmd_found(cmd):
392 392 return null_deco
393 393 except ImportError as e:
394 394 # is_cmd_found uses pywin32 on windows, which might not be available
395 395 if sys.platform == 'win32' and 'pywin32' in str(e):
396 396 return skip("This test runs only if pywin32 and commands '{0}' "
397 397 "are installed".format(commands))
398 398 raise e
399 399 return skip("This test runs only if one of the commands {0} "
400 400 "is installed".format(commands))
@@ -1,344 +1,345 b''
1 1 #!/usr/bin/env python
2 2 # -*- coding: utf-8 -*-
3 3 """Setup script for IPython.
4 4
5 5 Under Posix environments it works like a typical setup.py script.
6 6 Under Windows, the command sdist is not supported, since IPython
7 7 requires utilities which are not available under Windows."""
8 8
9 9 #-----------------------------------------------------------------------------
10 10 # Copyright (c) 2008-2011, IPython Development Team.
11 11 # Copyright (c) 2001-2007, Fernando Perez <fernando.perez@colorado.edu>
12 12 # Copyright (c) 2001, Janko Hauser <jhauser@zscout.de>
13 13 # Copyright (c) 2001, Nathaniel Gray <n8gray@caltech.edu>
14 14 #
15 15 # Distributed under the terms of the Modified BSD License.
16 16 #
17 17 # The full license is in the file COPYING.rst, distributed with this software.
18 18 #-----------------------------------------------------------------------------
19 19
20 20 #-----------------------------------------------------------------------------
21 21 # Minimal Python version sanity check
22 22 #-----------------------------------------------------------------------------
23 23 from __future__ import print_function
24 24
25 25 import sys
26 26
27 27 # This check is also made in IPython/__init__; don't forget to update both when
28 28 # changing Python version requirements.
29 29 v = sys.version_info
30 30 if v[:2] < (2,7) or (v[0] >= 3 and v[:2] < (3,3)):
31 31 error = "ERROR: IPython requires Python version 2.7 or 3.3 or above."
32 32 print(error, file=sys.stderr)
33 33 sys.exit(1)
34 34
35 35 PY3 = (sys.version_info[0] >= 3)
36 36
37 37 # At least we're on the python version we need, move on.
38 38
39 39 #-------------------------------------------------------------------------------
40 40 # Imports
41 41 #-------------------------------------------------------------------------------
42 42
43 43 # Stdlib imports
44 44 import os
45 45 import shutil
46 46
47 47 from glob import glob
48 48
49 49 # BEFORE importing distutils, remove MANIFEST. distutils doesn't properly
50 50 # update it when the contents of directories change.
51 51 if os.path.exists('MANIFEST'): os.remove('MANIFEST')
52 52
53 53 from distutils.core import setup
54 54
55 55 # Our own imports
56 56 from setupbase import target_update
57 57
58 58 from setupbase import (
59 59 setup_args,
60 60 find_packages,
61 61 find_package_data,
62 62 check_package_data_first,
63 63 find_entry_points,
64 64 build_scripts_entrypt,
65 65 find_data_files,
66 66 check_for_dependencies,
67 67 git_prebuild,
68 68 check_submodule_status,
69 69 update_submodules,
70 70 require_submodules,
71 71 UpdateSubmodules,
72 72 get_bdist_wheel,
73 73 CompileCSS,
74 74 JavascriptVersion,
75 75 css_js_prerelease,
76 76 install_symlinked,
77 77 install_lib_symlink,
78 78 install_scripts_for_symlink,
79 79 unsymlink,
80 80 )
81 81 from setupext import setupext
82 82
83 83 isfile = os.path.isfile
84 84 pjoin = os.path.join
85 85
86 86 #-------------------------------------------------------------------------------
87 87 # Handle OS specific things
88 88 #-------------------------------------------------------------------------------
89 89
90 90 if os.name in ('nt','dos'):
91 91 os_name = 'windows'
92 92 else:
93 93 os_name = os.name
94 94
95 95 # Under Windows, 'sdist' has not been supported. Now that the docs build with
96 96 # Sphinx it might work, but let's not turn it on until someone confirms that it
97 97 # actually works.
98 98 if os_name == 'windows' and 'sdist' in sys.argv:
99 99 print('The sdist command is not available under Windows. Exiting.')
100 100 sys.exit(1)
101 101
102 102 #-------------------------------------------------------------------------------
103 103 # Make sure we aren't trying to run without submodules
104 104 #-------------------------------------------------------------------------------
105 105 here = os.path.abspath(os.path.dirname(__file__))
106 106
107 107 def require_clean_submodules():
108 108 """Check on git submodules before distutils can do anything
109 109
110 110 Since distutils cannot be trusted to update the tree
111 111 after everything has been set in motion,
112 112 this is not a distutils command.
113 113 """
114 114 # PACKAGERS: Add a return here to skip checks for git submodules
115 115
116 116 # don't do anything if nothing is actually supposed to happen
117 117 for do_nothing in ('-h', '--help', '--help-commands', 'clean', 'submodule'):
118 118 if do_nothing in sys.argv:
119 119 return
120 120
121 121 status = check_submodule_status(here)
122 122
123 123 if status == "missing":
124 124 print("checking out submodules for the first time")
125 125 update_submodules(here)
126 126 elif status == "unclean":
127 127 print('\n'.join([
128 128 "Cannot build / install IPython with unclean submodules",
129 129 "Please update submodules with",
130 130 " python setup.py submodule",
131 131 "or",
132 132 " git submodule update",
133 133 "or commit any submodule changes you have made."
134 134 ]))
135 135 sys.exit(1)
136 136
137 137 require_clean_submodules()
138 138
139 139 #-------------------------------------------------------------------------------
140 140 # Things related to the IPython documentation
141 141 #-------------------------------------------------------------------------------
142 142
143 143 # update the manuals when building a source dist
144 144 if len(sys.argv) >= 2 and sys.argv[1] in ('sdist','bdist_rpm'):
145 145
146 146 # List of things to be updated. Each entry is a triplet of args for
147 147 # target_update()
148 148 to_update = [
149 149 # FIXME - Disabled for now: we need to redo an automatic way
150 150 # of generating the magic info inside the rst.
151 151 #('docs/magic.tex',
152 152 #['IPython/Magic.py'],
153 153 #"cd doc && ./update_magic.sh" ),
154 154
155 155 ('docs/man/ipcluster.1.gz',
156 156 ['docs/man/ipcluster.1'],
157 157 'cd docs/man && gzip -9c ipcluster.1 > ipcluster.1.gz'),
158 158
159 159 ('docs/man/ipcontroller.1.gz',
160 160 ['docs/man/ipcontroller.1'],
161 161 'cd docs/man && gzip -9c ipcontroller.1 > ipcontroller.1.gz'),
162 162
163 163 ('docs/man/ipengine.1.gz',
164 164 ['docs/man/ipengine.1'],
165 165 'cd docs/man && gzip -9c ipengine.1 > ipengine.1.gz'),
166 166
167 167 ('docs/man/ipython.1.gz',
168 168 ['docs/man/ipython.1'],
169 169 'cd docs/man && gzip -9c ipython.1 > ipython.1.gz'),
170 170
171 171 ]
172 172
173 173
174 174 for t in to_update:
            target_update(*t)
175 175
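target_update comes from setupbase and implements a make-style rule: run the command whenever the target is missing or older than a dependency. A rough sketch of that pattern, under the assumption that each triplet is (target, dependencies, shell command):

import os

def target_outdated(target, deps):
    """True if target is missing or older than any of its dependencies."""
    if not os.path.exists(target):
        return True
    target_time = os.path.getmtime(target)
    return any(os.path.getmtime(dep) > target_time for dep in deps)

def target_update(target, deps, cmd):
    """Run cmd (a shell string) if target is out of date w.r.t. deps."""
    if target_outdated(target, deps):
        os.system(cmd)
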
176 176 #---------------------------------------------------------------------------
177 177 # Find all the packages, package data, and data_files
178 178 #---------------------------------------------------------------------------
179 179
180 180 packages = find_packages()
181 181 package_data = find_package_data()
182 182
183 183 data_files = find_data_files()
184 184
185 185 setup_args['packages'] = packages
186 186 setup_args['package_data'] = package_data
187 187 setup_args['data_files'] = data_files
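
These are setupbase helpers, not the setuptools find_packages. A simplified approximation of such a package scan, collecting every directory that contains an __init__.py (illustrative only, not the real helper):

import os

def find_packages_sketch(root='IPython'):
    """Collect dotted package names by looking for __init__.py files."""
    packages = []
    for directory, subdirs, files in os.walk(root):
        if '__init__.py' in files:
            # e.g. IPython/utils -> 'IPython.utils'
            packages.append(directory.replace(os.sep, '.'))
        else:
            subdirs[:] = []  # not a package: don't descend further
    return packages
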
188 188
189 189 #---------------------------------------------------------------------------
190 190 # custom distutils commands
191 191 #---------------------------------------------------------------------------
192 192 # imports here, so they are after setuptools import if there was one
193 193 from distutils.command.sdist import sdist
194 194 from distutils.command.upload import upload
195 195
196 196 class UploadWindowsInstallers(upload):
197 197
198 198 description = "Upload Windows installers to PyPI (only used from tools/release_windows.py)"
199 199 user_options = upload.user_options + [
200 200 ('files=', 'f', 'exe file (or glob) to upload')
201 201 ]
202 202 def initialize_options(self):
203 203 upload.initialize_options(self)
204 204 meta = self.distribution.metadata
205 205 base = '{name}-{version}'.format(
206 206 name=meta.get_name(),
207 207 version=meta.get_version()
208 208 )
209 209 self.files = os.path.join('dist', '%s.*.exe' % base)
210 210
211 211 def run(self):
212 212 for dist_file in glob(self.files):
213 213 self.upload_file('bdist_wininst', 'any', dist_file)
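
For illustration, the default glob built in initialize_options expands as follows (the name and version here are made up):

import os

name, version = 'ipython', '3.0.0'  # illustrative values
base = '{name}-{version}'.format(name=name, version=version)
print(os.path.join('dist', '%s.*.exe' % base))
# dist/ipython-3.0.0.*.exe -- matches per-interpreter installers such as
# dist/ipython-3.0.0.win-amd64-py2.7.exe
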
214 214
215 215 setup_args['cmdclass'] = {
216 216 'build_py': css_js_prerelease(
217 217 check_package_data_first(git_prebuild('IPython'))),
218 218 'sdist' : css_js_prerelease(git_prebuild('IPython', sdist)),
219 219 'upload_wininst' : UploadWindowsInstallers,
220 220 'submodule' : UpdateSubmodules,
221 221 'css' : CompileCSS,
222 222 'symlink': install_symlinked,
223 223 'install_lib_symlink': install_lib_symlink,
224 224 'install_scripts_sym': install_scripts_for_symlink,
225 225 'unsymlink': unsymlink,
226 226 'jsversion' : JavascriptVersion,
227 227 }
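
css_js_prerelease, check_package_data_first, and git_prebuild all follow the same wrapper pattern: take a distutils command class and return a subclass that does extra work before delegating. A generic sketch of that pattern (not the actual setupbase implementations):

def run_first(extra_step):
    """Return a class decorator that prepends extra_step to run()."""
    def wrap(command_class):
        class WrappedCommand(command_class):
            def run(self):
                extra_step(self)         # e.g. rebuild css/js, check data
                command_class.run(self)  # then the normal command work
        return WrappedCommand
    return wrap
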
228 228
229 229 #---------------------------------------------------------------------------
230 230 # Handle scripts, dependencies, and setuptools specific things
231 231 #---------------------------------------------------------------------------
232 232
233 233 # For some commands, use setuptools. Note that we do NOT list install here!
234 234 # If you want a setuptools-enhanced install, just run 'setupegg.py install'
235 235 needs_setuptools = set(('develop', 'release', 'bdist_egg', 'bdist_rpm',
236 236 'bdist', 'bdist_dumb', 'bdist_wininst', 'bdist_wheel',
237 237 'egg_info', 'easy_install', 'upload', 'install_egg_info',
238 238 ))
239 239
240 240 if len(needs_setuptools.intersection(sys.argv)) > 0:
241 241 import setuptools
242 242
243 243 # This dict is used for passing extra, setuptools-specific arguments
244 244 # to setup()
245 245 setuptools_extra_args = {}
246 246
247 247 # setuptools requirements
248 248
249 249 pyzmq = 'pyzmq>=13'
250 250
251 251 extras_require = dict(
252 252 parallel = [pyzmq],
253 253 qtconsole = [pyzmq, 'pygments'],
254 254 doc = ['Sphinx>=1.1', 'numpydoc'],
255 255 test = ['nose>=0.10.1', 'requests'],
256 256 terminal = [],
257 257 nbformat = ['jsonschema>=2.0'],
258 258 notebook = ['tornado>=4.0', pyzmq, 'jinja2', 'pygments', 'mistune>=0.5'],
259 259 nbconvert = ['pygments', 'jinja2', 'mistune>=0.3.1']
260 260 )
261 261
262 262 if not sys.platform.startswith('win'):
263 263 extras_require['notebook'].append('terminado>=0.3.3')
264 264
265 265 if sys.version_info < (3, 3):
266 266 extras_require['test'].append('mock')
267 267
268 268 extras_require['notebook'].extend(extras_require['nbformat'])
269 269 extras_require['nbconvert'].extend(extras_require['nbformat'])
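
With the extras groups assembled, optional dependencies can be pulled in by name under a setuptools-based install; the two extend() calls above guarantee the notebook and nbconvert extras always include nbformat's requirements. Illustrative usage:

# illustrative installs against the extras defined above:
#   pip install "ipython[notebook]"  # notebook deps + nbformat's jsonschema
#   pip install "ipython[all]"       # the union of every extras group
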
270 270
271 271 install_requires = [
272 'decorator',
272 273 'path.py', # required by pickleshare, remove when pickleshare is added here
273 274 ]
274 275
275 276 # add readline
276 277 if sys.platform == 'darwin':
277 278 if 'bdist_wheel' in sys.argv[1:] or not setupext.check_for_readline():
278 279 install_requires.append('gnureadline')
279 280 elif sys.platform.startswith('win'):
280 281 extras_require['terminal'].append('pyreadline>=2.0')
281 282
282 283 everything = set()
283 284 for deps in extras_require.values():
284 285 everything.update(deps)
285 286 extras_require['all'] = everything
286 287
287 288 if 'setuptools' in sys.modules:
288 289 # setup.py develop should check for submodules
289 290 from setuptools.command.develop import develop
290 291 setup_args['cmdclass']['develop'] = require_submodules(develop)
291 292 setup_args['cmdclass']['bdist_wheel'] = css_js_prerelease(get_bdist_wheel())
292 293
293 294 setuptools_extra_args['zip_safe'] = False
294 295 setuptools_extra_args['entry_points'] = {
295 296 'console_scripts': find_entry_points(),
296 297 'pygments.lexers': [
297 298 'ipythonconsole = IPython.lib.lexers:IPythonConsoleLexer',
298 299 'ipython = IPython.lib.lexers:IPythonLexer',
299 300 'ipython3 = IPython.lib.lexers:IPython3Lexer',
300 301 ],
301 302 }
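find_entry_points (from setupbase) supplies the console_scripts specs as 'name = module:function' strings; an illustrative shape of its output (the exact list is not shown in this hunk):

console_scripts = [
    'ipython = IPython:start_ipython',
    'ipython2 = IPython:start_ipython',  # suffixed per Python major version
]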
302 303 setup_args['extras_require'] = extras_require
303 304 requires = setup_args['install_requires'] = install_requires
304 305
306 307 # Script to be run by the Windows binary installer after the default setup
307 308 # routine, to add shortcuts and similar Windows-only things. Windows
308 309 # post-install scripts MUST reside in the scripts/ dir, otherwise distutils
309 310 # doesn't find them.
309 310 if 'bdist_wininst' in sys.argv:
310 311 if len(sys.argv) > 2 and \
311 312 ('sdist' in sys.argv or 'bdist_rpm' in sys.argv):
312 313 print("ERROR: bdist_wininst must be run alone. Exiting.", file=sys.stderr)
313 314 sys.exit(1)
314 315 setup_args['data_files'].append(
315 316 ['Scripts', ('scripts/ipython.ico', 'scripts/ipython_nb.ico')])
316 317 setup_args['scripts'] = [pjoin('scripts','ipython_win_post_install.py')]
317 318 setup_args['options'] = {"bdist_wininst":
318 319 {"install_script":
319 320 "ipython_win_post_install.py"}}
320 321
321 322 else:
322 323 # If we are installing without setuptools, call this function, which will
323 324 # check for dependencies and inform the user what is needed. This is
324 325 # just to make life easy for users.
325 326 for install_cmd in ('install', 'symlink'):
326 327 if install_cmd in sys.argv:
327 328 check_for_dependencies()
328 329 break
329 330 # scripts has to be a non-empty list, or install_scripts isn't called
330 331 setup_args['scripts'] = [e.split('=')[0].strip() for e in find_entry_points()]
331 332
332 333 setup_args['cmdclass']['build_scripts'] = build_scripts_entrypt
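
Without setuptools there is nothing to generate console scripts from entry points, so build_scripts_entrypt has to write real script files from the same 'name = module:function' specs. A minimal sketch of that conversion (a simplification, not the actual setupbase command):

import os

template = """#!/usr/bin/env python
from {mod} import {func}
{func}()
"""

def write_entrypoint_script(entry_point, outdir):
    """Write a launcher script for one 'name = module:function' spec."""
    name, spec = [part.strip() for part in entry_point.split('=')]
    mod, func = spec.split(':')
    path = os.path.join(outdir, name)
    with open(path, 'w') as f:
        f.write(template.format(mod=mod, func=func))
    return path
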
333 334
334 335 #---------------------------------------------------------------------------
335 336 # Do the actual setup now
336 337 #---------------------------------------------------------------------------
337 338
338 339 setup_args.update(setuptools_extra_args)
339 340
340 341 def main():
341 342 setup(**setup_args)
342 343
343 344 if __name__ == '__main__':
344 345 main()