##// END OF EJS Templates
remove decorator from external
MinRK -
Show More
@@ -1,621 +1,621 b''
1 # encoding: utf-8
1 # encoding: utf-8
2 """A base class for a configurable application."""
2 """A base class for a configurable application."""
3
3
4 # Copyright (c) IPython Development Team.
4 # Copyright (c) IPython Development Team.
5 # Distributed under the terms of the Modified BSD License.
5 # Distributed under the terms of the Modified BSD License.
6
6
7 from __future__ import print_function
7 from __future__ import print_function
8
8
9 import json
9 import json
10 import logging
10 import logging
11 import os
11 import os
12 import re
12 import re
13 import sys
13 import sys
14 from copy import deepcopy
14 from copy import deepcopy
15 from collections import defaultdict
15 from collections import defaultdict
16
16
17 from IPython.external.decorator import decorator
17 from decorator import decorator
18
18
19 from IPython.config.configurable import SingletonConfigurable
19 from IPython.config.configurable import SingletonConfigurable
20 from IPython.config.loader import (
20 from IPython.config.loader import (
21 KVArgParseConfigLoader, PyFileConfigLoader, Config, ArgumentError, ConfigFileNotFound, JSONFileConfigLoader
21 KVArgParseConfigLoader, PyFileConfigLoader, Config, ArgumentError, ConfigFileNotFound, JSONFileConfigLoader
22 )
22 )
23
23
24 from IPython.utils.traitlets import (
24 from IPython.utils.traitlets import (
25 Unicode, List, Enum, Dict, Instance, TraitError
25 Unicode, List, Enum, Dict, Instance, TraitError
26 )
26 )
27 from IPython.utils.importstring import import_item
27 from IPython.utils.importstring import import_item
28 from IPython.utils.text import indent, wrap_paragraphs, dedent
28 from IPython.utils.text import indent, wrap_paragraphs, dedent
29 from IPython.utils import py3compat
29 from IPython.utils import py3compat
30 from IPython.utils.py3compat import string_types, iteritems
30 from IPython.utils.py3compat import string_types, iteritems
31
31
32 #-----------------------------------------------------------------------------
32 #-----------------------------------------------------------------------------
33 # Descriptions for the various sections
33 # Descriptions for the various sections
34 #-----------------------------------------------------------------------------
34 #-----------------------------------------------------------------------------
35
35
36 # merge flags&aliases into options
36 # merge flags&aliases into options
37 option_description = """
37 option_description = """
38 Arguments that take values are actually convenience aliases to full
38 Arguments that take values are actually convenience aliases to full
39 Configurables, whose aliases are listed on the help line. For more information
39 Configurables, whose aliases are listed on the help line. For more information
40 on full configurables, see '--help-all'.
40 on full configurables, see '--help-all'.
41 """.strip() # trim newlines of front and back
41 """.strip() # trim newlines of front and back
42
42
43 keyvalue_description = """
43 keyvalue_description = """
44 Parameters are set from command-line arguments of the form:
44 Parameters are set from command-line arguments of the form:
45 `--Class.trait=value`.
45 `--Class.trait=value`.
46 This line is evaluated in Python, so simple expressions are allowed, e.g.::
46 This line is evaluated in Python, so simple expressions are allowed, e.g.::
47 `--C.a='range(3)'` For setting C.a=[0,1,2].
47 `--C.a='range(3)'` For setting C.a=[0,1,2].
48 """.strip() # trim newlines of front and back
48 """.strip() # trim newlines of front and back
49
49
50 # sys.argv can be missing, for example when python is embedded. See the docs
50 # sys.argv can be missing, for example when python is embedded. See the docs
51 # for details: http://docs.python.org/2/c-api/intro.html#embedding-python
51 # for details: http://docs.python.org/2/c-api/intro.html#embedding-python
52 if not hasattr(sys, "argv"):
52 if not hasattr(sys, "argv"):
53 sys.argv = [""]
53 sys.argv = [""]
54
54
55 subcommand_description = """
55 subcommand_description = """
56 Subcommands are launched as `{app} cmd [args]`. For information on using
56 Subcommands are launched as `{app} cmd [args]`. For information on using
57 subcommand 'cmd', do: `{app} cmd -h`.
57 subcommand 'cmd', do: `{app} cmd -h`.
58 """
58 """
59 # get running program name
59 # get running program name
60
60
61 #-----------------------------------------------------------------------------
61 #-----------------------------------------------------------------------------
62 # Application class
62 # Application class
63 #-----------------------------------------------------------------------------
63 #-----------------------------------------------------------------------------
64
64
65 @decorator
65 @decorator
66 def catch_config_error(method, app, *args, **kwargs):
66 def catch_config_error(method, app, *args, **kwargs):
67 """Method decorator for catching invalid config (Trait/ArgumentErrors) during init.
67 """Method decorator for catching invalid config (Trait/ArgumentErrors) during init.
68
68
69 On a TraitError (generally caused by bad config), this will print the trait's
69 On a TraitError (generally caused by bad config), this will print the trait's
70 message, and exit the app.
70 message, and exit the app.
71
71
72 For use on init methods, to prevent invoking excepthook on invalid input.
72 For use on init methods, to prevent invoking excepthook on invalid input.
73 """
73 """
74 try:
74 try:
75 return method(app, *args, **kwargs)
75 return method(app, *args, **kwargs)
76 except (TraitError, ArgumentError) as e:
76 except (TraitError, ArgumentError) as e:
77 app.print_help()
77 app.print_help()
78 app.log.fatal("Bad config encountered during initialization:")
78 app.log.fatal("Bad config encountered during initialization:")
79 app.log.fatal(str(e))
79 app.log.fatal(str(e))
80 app.log.debug("Config at the time: %s", app.config)
80 app.log.debug("Config at the time: %s", app.config)
81 app.exit(1)
81 app.exit(1)
82
82
83
83
84 class ApplicationError(Exception):
84 class ApplicationError(Exception):
85 pass
85 pass
86
86
87 class LevelFormatter(logging.Formatter):
87 class LevelFormatter(logging.Formatter):
88 """Formatter with additional `highlevel` record
88 """Formatter with additional `highlevel` record
89
89
90 This field is empty if log level is less than highlevel_limit,
90 This field is empty if log level is less than highlevel_limit,
91 otherwise it is formatted with self.highlevel_format.
91 otherwise it is formatted with self.highlevel_format.
92
92
93 Useful for adding 'WARNING' to warning messages,
93 Useful for adding 'WARNING' to warning messages,
94 without adding 'INFO' to info, etc.
94 without adding 'INFO' to info, etc.
95 """
95 """
96 highlevel_limit = logging.WARN
96 highlevel_limit = logging.WARN
97 highlevel_format = " %(levelname)s |"
97 highlevel_format = " %(levelname)s |"
98
98
99 def format(self, record):
99 def format(self, record):
100 if record.levelno >= self.highlevel_limit:
100 if record.levelno >= self.highlevel_limit:
101 record.highlevel = self.highlevel_format % record.__dict__
101 record.highlevel = self.highlevel_format % record.__dict__
102 else:
102 else:
103 record.highlevel = ""
103 record.highlevel = ""
104 return super(LevelFormatter, self).format(record)
104 return super(LevelFormatter, self).format(record)
105
105
106
106
107 class Application(SingletonConfigurable):
107 class Application(SingletonConfigurable):
108 """A singleton application with full configuration support."""
108 """A singleton application with full configuration support."""
109
109
110 # The name of the application, will usually match the name of the command
110 # The name of the application, will usually match the name of the command
111 # line application
111 # line application
112 name = Unicode(u'application')
112 name = Unicode(u'application')
113
113
114 # The description of the application that is printed at the beginning
114 # The description of the application that is printed at the beginning
115 # of the help.
115 # of the help.
116 description = Unicode(u'This is an application.')
116 description = Unicode(u'This is an application.')
117 # default section descriptions
117 # default section descriptions
118 option_description = Unicode(option_description)
118 option_description = Unicode(option_description)
119 keyvalue_description = Unicode(keyvalue_description)
119 keyvalue_description = Unicode(keyvalue_description)
120 subcommand_description = Unicode(subcommand_description)
120 subcommand_description = Unicode(subcommand_description)
121
121
122 # The usage and example string that goes at the end of the help string.
122 # The usage and example string that goes at the end of the help string.
123 examples = Unicode()
123 examples = Unicode()
124
124
125 # A sequence of Configurable subclasses whose config=True attributes will
125 # A sequence of Configurable subclasses whose config=True attributes will
126 # be exposed at the command line.
126 # be exposed at the command line.
127 classes = []
127 classes = []
128 @property
128 @property
129 def _help_classes(self):
129 def _help_classes(self):
130 """Define `App.help_classes` if CLI classes should differ from config file classes"""
130 """Define `App.help_classes` if CLI classes should differ from config file classes"""
131 return getattr(self, 'help_classes', self.classes)
131 return getattr(self, 'help_classes', self.classes)
132
132
133 @property
133 @property
134 def _config_classes(self):
134 def _config_classes(self):
135 """Define `App.config_classes` if config file classes should differ from CLI classes."""
135 """Define `App.config_classes` if config file classes should differ from CLI classes."""
136 return getattr(self, 'config_classes', self.classes)
136 return getattr(self, 'config_classes', self.classes)
137
137
138 # The version string of this application.
138 # The version string of this application.
139 version = Unicode(u'0.0')
139 version = Unicode(u'0.0')
140
140
141 # the argv used to initialize the application
141 # the argv used to initialize the application
142 argv = List()
142 argv = List()
143
143
144 # The log level for the application
144 # The log level for the application
145 log_level = Enum((0,10,20,30,40,50,'DEBUG','INFO','WARN','ERROR','CRITICAL'),
145 log_level = Enum((0,10,20,30,40,50,'DEBUG','INFO','WARN','ERROR','CRITICAL'),
146 default_value=logging.WARN,
146 default_value=logging.WARN,
147 config=True,
147 config=True,
148 help="Set the log level by value or name.")
148 help="Set the log level by value or name.")
149 def _log_level_changed(self, name, old, new):
149 def _log_level_changed(self, name, old, new):
150 """Adjust the log level when log_level is set."""
150 """Adjust the log level when log_level is set."""
151 if isinstance(new, string_types):
151 if isinstance(new, string_types):
152 new = getattr(logging, new)
152 new = getattr(logging, new)
153 self.log_level = new
153 self.log_level = new
154 self.log.setLevel(new)
154 self.log.setLevel(new)
155
155
156 _log_formatter_cls = LevelFormatter
156 _log_formatter_cls = LevelFormatter
157
157
158 log_datefmt = Unicode("%Y-%m-%d %H:%M:%S", config=True,
158 log_datefmt = Unicode("%Y-%m-%d %H:%M:%S", config=True,
159 help="The date format used by logging formatters for %(asctime)s"
159 help="The date format used by logging formatters for %(asctime)s"
160 )
160 )
161 def _log_datefmt_changed(self, name, old, new):
161 def _log_datefmt_changed(self, name, old, new):
162 self._log_format_changed('log_format', self.log_format, self.log_format)
162 self._log_format_changed('log_format', self.log_format, self.log_format)
163
163
164 log_format = Unicode("[%(name)s]%(highlevel)s %(message)s", config=True,
164 log_format = Unicode("[%(name)s]%(highlevel)s %(message)s", config=True,
165 help="The Logging format template",
165 help="The Logging format template",
166 )
166 )
167 def _log_format_changed(self, name, old, new):
167 def _log_format_changed(self, name, old, new):
168 """Change the log formatter when log_format is set."""
168 """Change the log formatter when log_format is set."""
169 _log_handler = self.log.handlers[0]
169 _log_handler = self.log.handlers[0]
170 _log_formatter = self._log_formatter_cls(fmt=new, datefmt=self.log_datefmt)
170 _log_formatter = self._log_formatter_cls(fmt=new, datefmt=self.log_datefmt)
171 _log_handler.setFormatter(_log_formatter)
171 _log_handler.setFormatter(_log_formatter)
172
172
173
173
174 log = Instance(logging.Logger)
174 log = Instance(logging.Logger)
175 def _log_default(self):
175 def _log_default(self):
176 """Start logging for this application.
176 """Start logging for this application.
177
177
178 The default is to log to stderr using a StreamHandler, if no default
178 The default is to log to stderr using a StreamHandler, if no default
179 handler already exists. The log level starts at logging.WARN, but this
179 handler already exists. The log level starts at logging.WARN, but this
180 can be adjusted by setting the ``log_level`` attribute.
180 can be adjusted by setting the ``log_level`` attribute.
181 """
181 """
182 log = logging.getLogger(self.__class__.__name__)
182 log = logging.getLogger(self.__class__.__name__)
183 log.setLevel(self.log_level)
183 log.setLevel(self.log_level)
184 log.propagate = False
184 log.propagate = False
185 _log = log # copied from Logger.hasHandlers() (new in Python 3.2)
185 _log = log # copied from Logger.hasHandlers() (new in Python 3.2)
186 while _log:
186 while _log:
187 if _log.handlers:
187 if _log.handlers:
188 return log
188 return log
189 if not _log.propagate:
189 if not _log.propagate:
190 break
190 break
191 else:
191 else:
192 _log = _log.parent
192 _log = _log.parent
193 if sys.executable.endswith('pythonw.exe'):
193 if sys.executable.endswith('pythonw.exe'):
194 # this should really go to a file, but file-logging is only
194 # this should really go to a file, but file-logging is only
195 # hooked up in parallel applications
195 # hooked up in parallel applications
196 _log_handler = logging.StreamHandler(open(os.devnull, 'w'))
196 _log_handler = logging.StreamHandler(open(os.devnull, 'w'))
197 else:
197 else:
198 _log_handler = logging.StreamHandler()
198 _log_handler = logging.StreamHandler()
199 _log_formatter = self._log_formatter_cls(fmt=self.log_format, datefmt=self.log_datefmt)
199 _log_formatter = self._log_formatter_cls(fmt=self.log_format, datefmt=self.log_datefmt)
200 _log_handler.setFormatter(_log_formatter)
200 _log_handler.setFormatter(_log_formatter)
201 log.addHandler(_log_handler)
201 log.addHandler(_log_handler)
202 return log
202 return log
203
203
204 # the alias map for configurables
204 # the alias map for configurables
205 aliases = Dict({'log-level' : 'Application.log_level'})
205 aliases = Dict({'log-level' : 'Application.log_level'})
206
206
207 # flags for loading Configurables or store_const style flags
207 # flags for loading Configurables or store_const style flags
208 # flags are loaded from this dict by '--key' flags
208 # flags are loaded from this dict by '--key' flags
209 # this must be a dict of two-tuples, the first element being the Config/dict
209 # this must be a dict of two-tuples, the first element being the Config/dict
210 # and the second being the help string for the flag
210 # and the second being the help string for the flag
211 flags = Dict()
211 flags = Dict()
212 def _flags_changed(self, name, old, new):
212 def _flags_changed(self, name, old, new):
213 """ensure flags dict is valid"""
213 """ensure flags dict is valid"""
214 for key,value in iteritems(new):
214 for key,value in iteritems(new):
215 assert len(value) == 2, "Bad flag: %r:%s"%(key,value)
215 assert len(value) == 2, "Bad flag: %r:%s"%(key,value)
216 assert isinstance(value[0], (dict, Config)), "Bad flag: %r:%s"%(key,value)
216 assert isinstance(value[0], (dict, Config)), "Bad flag: %r:%s"%(key,value)
217 assert isinstance(value[1], string_types), "Bad flag: %r:%s"%(key,value)
217 assert isinstance(value[1], string_types), "Bad flag: %r:%s"%(key,value)
218
218
219
219
220 # subcommands for launching other applications
220 # subcommands for launching other applications
221 # if this is not empty, this will be a parent Application
221 # if this is not empty, this will be a parent Application
222 # this must be a dict of two-tuples,
222 # this must be a dict of two-tuples,
223 # the first element being the application class/import string
223 # the first element being the application class/import string
224 # and the second being the help string for the subcommand
224 # and the second being the help string for the subcommand
225 subcommands = Dict()
225 subcommands = Dict()
226 # parse_command_line will initialize a subapp, if requested
226 # parse_command_line will initialize a subapp, if requested
227 subapp = Instance('IPython.config.application.Application', allow_none=True)
227 subapp = Instance('IPython.config.application.Application', allow_none=True)
228
228
229 # extra command-line arguments that don't set config values
229 # extra command-line arguments that don't set config values
230 extra_args = List(Unicode)
230 extra_args = List(Unicode)
231
231
232
232
233 def __init__(self, **kwargs):
233 def __init__(self, **kwargs):
234 SingletonConfigurable.__init__(self, **kwargs)
234 SingletonConfigurable.__init__(self, **kwargs)
235 # Ensure my class is in self.classes, so my attributes appear in command line
235 # Ensure my class is in self.classes, so my attributes appear in command line
236 # options and config files.
236 # options and config files.
237 if self.__class__ not in self.classes:
237 if self.__class__ not in self.classes:
238 self.classes.insert(0, self.__class__)
238 self.classes.insert(0, self.__class__)
239
239
240 def _config_changed(self, name, old, new):
240 def _config_changed(self, name, old, new):
241 SingletonConfigurable._config_changed(self, name, old, new)
241 SingletonConfigurable._config_changed(self, name, old, new)
242 self.log.debug('Config changed:')
242 self.log.debug('Config changed:')
243 self.log.debug(repr(new))
243 self.log.debug(repr(new))
244
244
245 @catch_config_error
245 @catch_config_error
246 def initialize(self, argv=None):
246 def initialize(self, argv=None):
247 """Do the basic steps to configure me.
247 """Do the basic steps to configure me.
248
248
249 Override in subclasses.
249 Override in subclasses.
250 """
250 """
251 self.parse_command_line(argv)
251 self.parse_command_line(argv)
252
252
253
253
254 def start(self):
254 def start(self):
255 """Start the app mainloop.
255 """Start the app mainloop.
256
256
257 Override in subclasses.
257 Override in subclasses.
258 """
258 """
259 if self.subapp is not None:
259 if self.subapp is not None:
260 return self.subapp.start()
260 return self.subapp.start()
261
261
262 def print_alias_help(self):
262 def print_alias_help(self):
263 """Print the alias part of the help."""
263 """Print the alias part of the help."""
264 if not self.aliases:
264 if not self.aliases:
265 return
265 return
266
266
267 lines = []
267 lines = []
268 classdict = {}
268 classdict = {}
269 for cls in self._help_classes:
269 for cls in self._help_classes:
270 # include all parents (up to, but excluding Configurable) in available names
270 # include all parents (up to, but excluding Configurable) in available names
271 for c in cls.mro()[:-3]:
271 for c in cls.mro()[:-3]:
272 classdict[c.__name__] = c
272 classdict[c.__name__] = c
273
273
274 for alias, longname in iteritems(self.aliases):
274 for alias, longname in iteritems(self.aliases):
275 classname, traitname = longname.split('.',1)
275 classname, traitname = longname.split('.',1)
276 cls = classdict[classname]
276 cls = classdict[classname]
277
277
278 trait = cls.class_traits(config=True)[traitname]
278 trait = cls.class_traits(config=True)[traitname]
279 help = cls.class_get_trait_help(trait).splitlines()
279 help = cls.class_get_trait_help(trait).splitlines()
280 # reformat first line
280 # reformat first line
281 help[0] = help[0].replace(longname, alias) + ' (%s)'%longname
281 help[0] = help[0].replace(longname, alias) + ' (%s)'%longname
282 if len(alias) == 1:
282 if len(alias) == 1:
283 help[0] = help[0].replace('--%s='%alias, '-%s '%alias)
283 help[0] = help[0].replace('--%s='%alias, '-%s '%alias)
284 lines.extend(help)
284 lines.extend(help)
285 # lines.append('')
285 # lines.append('')
286 print(os.linesep.join(lines))
286 print(os.linesep.join(lines))
287
287
288 def print_flag_help(self):
288 def print_flag_help(self):
289 """Print the flag part of the help."""
289 """Print the flag part of the help."""
290 if not self.flags:
290 if not self.flags:
291 return
291 return
292
292
293 lines = []
293 lines = []
294 for m, (cfg,help) in iteritems(self.flags):
294 for m, (cfg,help) in iteritems(self.flags):
295 prefix = '--' if len(m) > 1 else '-'
295 prefix = '--' if len(m) > 1 else '-'
296 lines.append(prefix+m)
296 lines.append(prefix+m)
297 lines.append(indent(dedent(help.strip())))
297 lines.append(indent(dedent(help.strip())))
298 # lines.append('')
298 # lines.append('')
299 print(os.linesep.join(lines))
299 print(os.linesep.join(lines))
300
300
301 def print_options(self):
301 def print_options(self):
302 if not self.flags and not self.aliases:
302 if not self.flags and not self.aliases:
303 return
303 return
304 lines = ['Options']
304 lines = ['Options']
305 lines.append('-'*len(lines[0]))
305 lines.append('-'*len(lines[0]))
306 lines.append('')
306 lines.append('')
307 for p in wrap_paragraphs(self.option_description):
307 for p in wrap_paragraphs(self.option_description):
308 lines.append(p)
308 lines.append(p)
309 lines.append('')
309 lines.append('')
310 print(os.linesep.join(lines))
310 print(os.linesep.join(lines))
311 self.print_flag_help()
311 self.print_flag_help()
312 self.print_alias_help()
312 self.print_alias_help()
313 print()
313 print()
314
314
315 def print_subcommands(self):
315 def print_subcommands(self):
316 """Print the subcommand part of the help."""
316 """Print the subcommand part of the help."""
317 if not self.subcommands:
317 if not self.subcommands:
318 return
318 return
319
319
320 lines = ["Subcommands"]
320 lines = ["Subcommands"]
321 lines.append('-'*len(lines[0]))
321 lines.append('-'*len(lines[0]))
322 lines.append('')
322 lines.append('')
323 for p in wrap_paragraphs(self.subcommand_description.format(
323 for p in wrap_paragraphs(self.subcommand_description.format(
324 app=self.name)):
324 app=self.name)):
325 lines.append(p)
325 lines.append(p)
326 lines.append('')
326 lines.append('')
327 for subc, (cls, help) in iteritems(self.subcommands):
327 for subc, (cls, help) in iteritems(self.subcommands):
328 lines.append(subc)
328 lines.append(subc)
329 if help:
329 if help:
330 lines.append(indent(dedent(help.strip())))
330 lines.append(indent(dedent(help.strip())))
331 lines.append('')
331 lines.append('')
332 print(os.linesep.join(lines))
332 print(os.linesep.join(lines))
333
333
334 def print_help(self, classes=False):
334 def print_help(self, classes=False):
335 """Print the help for each Configurable class in self.classes.
335 """Print the help for each Configurable class in self.classes.
336
336
337 If classes=False (the default), only flags and aliases are printed.
337 If classes=False (the default), only flags and aliases are printed.
338 """
338 """
339 self.print_description()
339 self.print_description()
340 self.print_subcommands()
340 self.print_subcommands()
341 self.print_options()
341 self.print_options()
342
342
343 if classes:
343 if classes:
344 help_classes = self._help_classes
344 help_classes = self._help_classes
345 if help_classes:
345 if help_classes:
346 print("Class parameters")
346 print("Class parameters")
347 print("----------------")
347 print("----------------")
348 print()
348 print()
349 for p in wrap_paragraphs(self.keyvalue_description):
349 for p in wrap_paragraphs(self.keyvalue_description):
350 print(p)
350 print(p)
351 print()
351 print()
352
352
353 for cls in help_classes:
353 for cls in help_classes:
354 cls.class_print_help()
354 cls.class_print_help()
355 print()
355 print()
356 else:
356 else:
357 print("To see all available configurables, use `--help-all`")
357 print("To see all available configurables, use `--help-all`")
358 print()
358 print()
359
359
360 self.print_examples()
360 self.print_examples()
361
361
362
362
363 def print_description(self):
363 def print_description(self):
364 """Print the application description."""
364 """Print the application description."""
365 for p in wrap_paragraphs(self.description):
365 for p in wrap_paragraphs(self.description):
366 print(p)
366 print(p)
367 print()
367 print()
368
368
369 def print_examples(self):
369 def print_examples(self):
370 """Print usage and examples.
370 """Print usage and examples.
371
371
372 This usage string goes at the end of the command line help string
372 This usage string goes at the end of the command line help string
373 and should contain examples of the application's usage.
373 and should contain examples of the application's usage.
374 """
374 """
375 if self.examples:
375 if self.examples:
376 print("Examples")
376 print("Examples")
377 print("--------")
377 print("--------")
378 print()
378 print()
379 print(indent(dedent(self.examples.strip())))
379 print(indent(dedent(self.examples.strip())))
380 print()
380 print()
381
381
382 def print_version(self):
382 def print_version(self):
383 """Print the version string."""
383 """Print the version string."""
384 print(self.version)
384 print(self.version)
385
385
386 def update_config(self, config):
386 def update_config(self, config):
387 """Fire the traits events when the config is updated."""
387 """Fire the traits events when the config is updated."""
388 # Save a copy of the current config.
388 # Save a copy of the current config.
389 newconfig = deepcopy(self.config)
389 newconfig = deepcopy(self.config)
390 # Merge the new config into the current one.
390 # Merge the new config into the current one.
391 newconfig.merge(config)
391 newconfig.merge(config)
392 # Save the combined config as self.config, which triggers the traits
392 # Save the combined config as self.config, which triggers the traits
393 # events.
393 # events.
394 self.config = newconfig
394 self.config = newconfig
395
395
396 @catch_config_error
396 @catch_config_error
397 def initialize_subcommand(self, subc, argv=None):
397 def initialize_subcommand(self, subc, argv=None):
398 """Initialize a subcommand with argv."""
398 """Initialize a subcommand with argv."""
399 subapp,help = self.subcommands.get(subc)
399 subapp,help = self.subcommands.get(subc)
400
400
401 if isinstance(subapp, string_types):
401 if isinstance(subapp, string_types):
402 subapp = import_item(subapp)
402 subapp = import_item(subapp)
403
403
404 # clear existing instances
404 # clear existing instances
405 self.__class__.clear_instance()
405 self.__class__.clear_instance()
406 # instantiate
406 # instantiate
407 self.subapp = subapp.instance(config=self.config)
407 self.subapp = subapp.instance(config=self.config)
408 # and initialize subapp
408 # and initialize subapp
409 self.subapp.initialize(argv)
409 self.subapp.initialize(argv)
410
410
411 def flatten_flags(self):
411 def flatten_flags(self):
412 """flatten flags and aliases, so cl-args override as expected.
412 """flatten flags and aliases, so cl-args override as expected.
413
413
414 This prevents issues such as an alias pointing to InteractiveShell,
414 This prevents issues such as an alias pointing to InteractiveShell,
415 but a config file setting the same trait in TerminalInteraciveShell
415 but a config file setting the same trait in TerminalInteraciveShell
416 getting inappropriate priority over the command-line arg.
416 getting inappropriate priority over the command-line arg.
417
417
418 Only aliases with exactly one descendent in the class list
418 Only aliases with exactly one descendent in the class list
419 will be promoted.
419 will be promoted.
420
420
421 """
421 """
422 # build a tree of classes in our list that inherit from a particular
422 # build a tree of classes in our list that inherit from a particular
423 # it will be a dict by parent classname of classes in our list
423 # it will be a dict by parent classname of classes in our list
424 # that are descendents
424 # that are descendents
425 mro_tree = defaultdict(list)
425 mro_tree = defaultdict(list)
426 for cls in self._help_classes:
426 for cls in self._help_classes:
427 clsname = cls.__name__
427 clsname = cls.__name__
428 for parent in cls.mro()[1:-3]:
428 for parent in cls.mro()[1:-3]:
429 # exclude cls itself and Configurable,HasTraits,object
429 # exclude cls itself and Configurable,HasTraits,object
430 mro_tree[parent.__name__].append(clsname)
430 mro_tree[parent.__name__].append(clsname)
431 # flatten aliases, which have the form:
431 # flatten aliases, which have the form:
432 # { 'alias' : 'Class.trait' }
432 # { 'alias' : 'Class.trait' }
433 aliases = {}
433 aliases = {}
434 for alias, cls_trait in iteritems(self.aliases):
434 for alias, cls_trait in iteritems(self.aliases):
435 cls,trait = cls_trait.split('.',1)
435 cls,trait = cls_trait.split('.',1)
436 children = mro_tree[cls]
436 children = mro_tree[cls]
437 if len(children) == 1:
437 if len(children) == 1:
438 # exactly one descendent, promote alias
438 # exactly one descendent, promote alias
439 cls = children[0]
439 cls = children[0]
440 aliases[alias] = '.'.join([cls,trait])
440 aliases[alias] = '.'.join([cls,trait])
441
441
442 # flatten flags, which are of the form:
442 # flatten flags, which are of the form:
443 # { 'key' : ({'Cls' : {'trait' : value}}, 'help')}
443 # { 'key' : ({'Cls' : {'trait' : value}}, 'help')}
444 flags = {}
444 flags = {}
445 for key, (flagdict, help) in iteritems(self.flags):
445 for key, (flagdict, help) in iteritems(self.flags):
446 newflag = {}
446 newflag = {}
447 for cls, subdict in iteritems(flagdict):
447 for cls, subdict in iteritems(flagdict):
448 children = mro_tree[cls]
448 children = mro_tree[cls]
449 # exactly one descendent, promote flag section
449 # exactly one descendent, promote flag section
450 if len(children) == 1:
450 if len(children) == 1:
451 cls = children[0]
451 cls = children[0]
452 newflag[cls] = subdict
452 newflag[cls] = subdict
453 flags[key] = (newflag, help)
453 flags[key] = (newflag, help)
454 return flags, aliases
454 return flags, aliases
455
455
456 @catch_config_error
456 @catch_config_error
457 def parse_command_line(self, argv=None):
457 def parse_command_line(self, argv=None):
458 """Parse the command line arguments."""
458 """Parse the command line arguments."""
459 argv = sys.argv[1:] if argv is None else argv
459 argv = sys.argv[1:] if argv is None else argv
460 self.argv = [ py3compat.cast_unicode(arg) for arg in argv ]
460 self.argv = [ py3compat.cast_unicode(arg) for arg in argv ]
461
461
462 if argv and argv[0] == 'help':
462 if argv and argv[0] == 'help':
463 # turn `ipython help notebook` into `ipython notebook -h`
463 # turn `ipython help notebook` into `ipython notebook -h`
464 argv = argv[1:] + ['-h']
464 argv = argv[1:] + ['-h']
465
465
466 if self.subcommands and len(argv) > 0:
466 if self.subcommands and len(argv) > 0:
467 # we have subcommands, and one may have been specified
467 # we have subcommands, and one may have been specified
468 subc, subargv = argv[0], argv[1:]
468 subc, subargv = argv[0], argv[1:]
469 if re.match(r'^\w(\-?\w)*$', subc) and subc in self.subcommands:
469 if re.match(r'^\w(\-?\w)*$', subc) and subc in self.subcommands:
470 # it's a subcommand, and *not* a flag or class parameter
470 # it's a subcommand, and *not* a flag or class parameter
471 return self.initialize_subcommand(subc, subargv)
471 return self.initialize_subcommand(subc, subargv)
472
472
473 # Arguments after a '--' argument are for the script IPython may be
473 # Arguments after a '--' argument are for the script IPython may be
474 # about to run, not IPython iteslf. For arguments parsed here (help and
474 # about to run, not IPython iteslf. For arguments parsed here (help and
475 # version), we want to only search the arguments up to the first
475 # version), we want to only search the arguments up to the first
476 # occurrence of '--', which we're calling interpreted_argv.
476 # occurrence of '--', which we're calling interpreted_argv.
477 try:
477 try:
478 interpreted_argv = argv[:argv.index('--')]
478 interpreted_argv = argv[:argv.index('--')]
479 except ValueError:
479 except ValueError:
480 interpreted_argv = argv
480 interpreted_argv = argv
481
481
482 if any(x in interpreted_argv for x in ('-h', '--help-all', '--help')):
482 if any(x in interpreted_argv for x in ('-h', '--help-all', '--help')):
483 self.print_help('--help-all' in interpreted_argv)
483 self.print_help('--help-all' in interpreted_argv)
484 self.exit(0)
484 self.exit(0)
485
485
486 if '--version' in interpreted_argv or '-V' in interpreted_argv:
486 if '--version' in interpreted_argv or '-V' in interpreted_argv:
487 self.print_version()
487 self.print_version()
488 self.exit(0)
488 self.exit(0)
489
489
490 # flatten flags&aliases, so cl-args get appropriate priority:
490 # flatten flags&aliases, so cl-args get appropriate priority:
491 flags,aliases = self.flatten_flags()
491 flags,aliases = self.flatten_flags()
492 loader = KVArgParseConfigLoader(argv=argv, aliases=aliases,
492 loader = KVArgParseConfigLoader(argv=argv, aliases=aliases,
493 flags=flags, log=self.log)
493 flags=flags, log=self.log)
494 config = loader.load_config()
494 config = loader.load_config()
495 self.update_config(config)
495 self.update_config(config)
496 # store unparsed args in extra_args
496 # store unparsed args in extra_args
497 self.extra_args = loader.extra_args
497 self.extra_args = loader.extra_args
498
498
499 @classmethod
499 @classmethod
500 def _load_config_files(cls, basefilename, path=None, log=None):
500 def _load_config_files(cls, basefilename, path=None, log=None):
501 """Load config files (py,json) by filename and path.
501 """Load config files (py,json) by filename and path.
502
502
503 yield each config object in turn.
503 yield each config object in turn.
504 """
504 """
505
505
506 if not isinstance(path, list):
506 if not isinstance(path, list):
507 path = [path]
507 path = [path]
508 for path in path[::-1]:
508 for path in path[::-1]:
509 # path list is in descending priority order, so load files backwards:
509 # path list is in descending priority order, so load files backwards:
510 pyloader = PyFileConfigLoader(basefilename+'.py', path=path, log=log)
510 pyloader = PyFileConfigLoader(basefilename+'.py', path=path, log=log)
511 jsonloader = JSONFileConfigLoader(basefilename+'.json', path=path, log=log)
511 jsonloader = JSONFileConfigLoader(basefilename+'.json', path=path, log=log)
512 config = None
512 config = None
513 for loader in [pyloader, jsonloader]:
513 for loader in [pyloader, jsonloader]:
514 try:
514 try:
515 config = loader.load_config()
515 config = loader.load_config()
516 except ConfigFileNotFound:
516 except ConfigFileNotFound:
517 pass
517 pass
518 except Exception:
518 except Exception:
519 # try to get the full filename, but it will be empty in the
519 # try to get the full filename, but it will be empty in the
520 # unlikely event that the error raised before filefind finished
520 # unlikely event that the error raised before filefind finished
521 filename = loader.full_filename or basefilename
521 filename = loader.full_filename or basefilename
522 # problem while running the file
522 # problem while running the file
523 if log:
523 if log:
524 log.error("Exception while loading config file %s",
524 log.error("Exception while loading config file %s",
525 filename, exc_info=True)
525 filename, exc_info=True)
526 else:
526 else:
527 if log:
527 if log:
528 log.debug("Loaded config file: %s", loader.full_filename)
528 log.debug("Loaded config file: %s", loader.full_filename)
529 if config:
529 if config:
530 yield config
530 yield config
531
531
532 raise StopIteration
532 raise StopIteration
533
533
534
534
535 @catch_config_error
535 @catch_config_error
536 def load_config_file(self, filename, path=None):
536 def load_config_file(self, filename, path=None):
537 """Load config files by filename and path."""
537 """Load config files by filename and path."""
538 filename, ext = os.path.splitext(filename)
538 filename, ext = os.path.splitext(filename)
539 loaded = []
539 loaded = []
540 for config in self._load_config_files(filename, path=path, log=self.log):
540 for config in self._load_config_files(filename, path=path, log=self.log):
541 loaded.append(config)
541 loaded.append(config)
542 self.update_config(config)
542 self.update_config(config)
543 if len(loaded) > 1:
543 if len(loaded) > 1:
544 collisions = loaded[0].collisions(loaded[1])
544 collisions = loaded[0].collisions(loaded[1])
545 if collisions:
545 if collisions:
546 self.log.warn("Collisions detected in {0}.py and {0}.json config files."
546 self.log.warn("Collisions detected in {0}.py and {0}.json config files."
547 " {0}.json has higher priority: {1}".format(
547 " {0}.json has higher priority: {1}".format(
548 filename, json.dumps(collisions, indent=2),
548 filename, json.dumps(collisions, indent=2),
549 ))
549 ))
550
550
551
551
552 def generate_config_file(self):
552 def generate_config_file(self):
553 """generate default config file from Configurables"""
553 """generate default config file from Configurables"""
554 lines = ["# Configuration file for %s."%self.name]
554 lines = ["# Configuration file for %s."%self.name]
555 lines.append('')
555 lines.append('')
556 lines.append('c = get_config()')
556 lines.append('c = get_config()')
557 lines.append('')
557 lines.append('')
558 for cls in self._config_classes:
558 for cls in self._config_classes:
559 lines.append(cls.class_config_section())
559 lines.append(cls.class_config_section())
560 return '\n'.join(lines)
560 return '\n'.join(lines)
561
561
562 def exit(self, exit_status=0):
562 def exit(self, exit_status=0):
563 self.log.debug("Exiting application: %s" % self.name)
563 self.log.debug("Exiting application: %s" % self.name)
564 sys.exit(exit_status)
564 sys.exit(exit_status)
565
565
566 @classmethod
566 @classmethod
567 def launch_instance(cls, argv=None, **kwargs):
567 def launch_instance(cls, argv=None, **kwargs):
568 """Launch a global instance of this Application
568 """Launch a global instance of this Application
569
569
570 If a global instance already exists, this reinitializes and starts it
570 If a global instance already exists, this reinitializes and starts it
571 """
571 """
572 app = cls.instance(**kwargs)
572 app = cls.instance(**kwargs)
573 app.initialize(argv)
573 app.initialize(argv)
574 app.start()
574 app.start()
575
575
576 #-----------------------------------------------------------------------------
576 #-----------------------------------------------------------------------------
577 # utility functions, for convenience
577 # utility functions, for convenience
578 #-----------------------------------------------------------------------------
578 #-----------------------------------------------------------------------------
579
579
580 def boolean_flag(name, configurable, set_help='', unset_help=''):
580 def boolean_flag(name, configurable, set_help='', unset_help=''):
581 """Helper for building basic --trait, --no-trait flags.
581 """Helper for building basic --trait, --no-trait flags.
582
582
583 Parameters
583 Parameters
584 ----------
584 ----------
585
585
586 name : str
586 name : str
587 The name of the flag.
587 The name of the flag.
588 configurable : str
588 configurable : str
589 The 'Class.trait' string of the trait to be set/unset with the flag
589 The 'Class.trait' string of the trait to be set/unset with the flag
590 set_help : unicode
590 set_help : unicode
591 help string for --name flag
591 help string for --name flag
592 unset_help : unicode
592 unset_help : unicode
593 help string for --no-name flag
593 help string for --no-name flag
594
594
595 Returns
595 Returns
596 -------
596 -------
597
597
598 cfg : dict
598 cfg : dict
599 A dict with two keys: 'name', and 'no-name', for setting and unsetting
599 A dict with two keys: 'name', and 'no-name', for setting and unsetting
600 the trait, respectively.
600 the trait, respectively.
601 """
601 """
602 # default helpstrings
602 # default helpstrings
603 set_help = set_help or "set %s=True"%configurable
603 set_help = set_help or "set %s=True"%configurable
604 unset_help = unset_help or "set %s=False"%configurable
604 unset_help = unset_help or "set %s=False"%configurable
605
605
606 cls,trait = configurable.split('.')
606 cls,trait = configurable.split('.')
607
607
608 setter = {cls : {trait : True}}
608 setter = {cls : {trait : True}}
609 unsetter = {cls : {trait : False}}
609 unsetter = {cls : {trait : False}}
610 return {name : (setter, set_help), 'no-'+name : (unsetter, unset_help)}
610 return {name : (setter, set_help), 'no-'+name : (unsetter, unset_help)}
611
611
612
612
613 def get_config():
613 def get_config():
614 """Get the config object for the global Application instance, if there is one
614 """Get the config object for the global Application instance, if there is one
615
615
616 otherwise return an empty config object
616 otherwise return an empty config object
617 """
617 """
618 if Application.initialized():
618 if Application.initialized():
619 return Application.instance().config
619 return Application.instance().config
620 else:
620 else:
621 return Config()
621 return Config()
@@ -1,965 +1,965 b''
1 # -*- coding: utf-8 -*-
1 # -*- coding: utf-8 -*-
2 """Display formatters.
2 """Display formatters.
3
3
4 Inheritance diagram:
4 Inheritance diagram:
5
5
6 .. inheritance-diagram:: IPython.core.formatters
6 .. inheritance-diagram:: IPython.core.formatters
7 :parts: 3
7 :parts: 3
8 """
8 """
9
9
10 # Copyright (c) IPython Development Team.
10 # Copyright (c) IPython Development Team.
11 # Distributed under the terms of the Modified BSD License.
11 # Distributed under the terms of the Modified BSD License.
12
12
13 import abc
13 import abc
14 import inspect
14 import inspect
15 import json
15 import json
16 import sys
16 import sys
17 import traceback
17 import traceback
18 import warnings
18 import warnings
19
19
20 from IPython.external.decorator import decorator
20 from decorator import decorator
21
21
22 from IPython.config.configurable import Configurable
22 from IPython.config.configurable import Configurable
23 from IPython.core.getipython import get_ipython
23 from IPython.core.getipython import get_ipython
24 from IPython.lib import pretty
24 from IPython.lib import pretty
25 from IPython.utils.traitlets import (
25 from IPython.utils.traitlets import (
26 Bool, Dict, Integer, Unicode, CUnicode, ObjectName, List,
26 Bool, Dict, Integer, Unicode, CUnicode, ObjectName, List,
27 ForwardDeclaredInstance,
27 ForwardDeclaredInstance,
28 )
28 )
29 from IPython.utils.py3compat import (
29 from IPython.utils.py3compat import (
30 with_metaclass, string_types, unicode_type,
30 with_metaclass, string_types, unicode_type,
31 )
31 )
32
32
33
33
34 #-----------------------------------------------------------------------------
34 #-----------------------------------------------------------------------------
35 # The main DisplayFormatter class
35 # The main DisplayFormatter class
36 #-----------------------------------------------------------------------------
36 #-----------------------------------------------------------------------------
37
37
38
38
39 def _safe_get_formatter_method(obj, name):
39 def _safe_get_formatter_method(obj, name):
40 """Safely get a formatter method
40 """Safely get a formatter method
41
41
42 - Classes cannot have formatter methods, only instance
42 - Classes cannot have formatter methods, only instance
43 - protect against proxy objects that claim to have everything
43 - protect against proxy objects that claim to have everything
44 """
44 """
45 if inspect.isclass(obj):
45 if inspect.isclass(obj):
46 # repr methods only make sense on instances, not classes
46 # repr methods only make sense on instances, not classes
47 return None
47 return None
48 method = pretty._safe_getattr(obj, name, None)
48 method = pretty._safe_getattr(obj, name, None)
49 if callable(method):
49 if callable(method):
50 # obj claims to have repr method...
50 # obj claims to have repr method...
51 if callable(pretty._safe_getattr(obj, '_ipython_canary_method_should_not_exist_', None)):
51 if callable(pretty._safe_getattr(obj, '_ipython_canary_method_should_not_exist_', None)):
52 # ...but don't trust proxy objects that claim to have everything
52 # ...but don't trust proxy objects that claim to have everything
53 return None
53 return None
54 return method
54 return method
55
55
56
56
57 class DisplayFormatter(Configurable):
57 class DisplayFormatter(Configurable):
58
58
59 # When set to true only the default plain text formatter will be used.
59 # When set to true only the default plain text formatter will be used.
60 plain_text_only = Bool(False, config=True)
60 plain_text_only = Bool(False, config=True)
61 def _plain_text_only_changed(self, name, old, new):
61 def _plain_text_only_changed(self, name, old, new):
62 warnings.warn("""DisplayFormatter.plain_text_only is deprecated.
62 warnings.warn("""DisplayFormatter.plain_text_only is deprecated.
63
63
64 Use DisplayFormatter.active_types = ['text/plain']
64 Use DisplayFormatter.active_types = ['text/plain']
65 for the same effect.
65 for the same effect.
66 """, DeprecationWarning)
66 """, DeprecationWarning)
67 if new:
67 if new:
68 self.active_types = ['text/plain']
68 self.active_types = ['text/plain']
69 else:
69 else:
70 self.active_types = self.format_types
70 self.active_types = self.format_types
71
71
72 active_types = List(Unicode, config=True,
72 active_types = List(Unicode, config=True,
73 help="""List of currently active mime-types to display.
73 help="""List of currently active mime-types to display.
74 You can use this to set a white-list for formats to display.
74 You can use this to set a white-list for formats to display.
75
75
76 Most users will not need to change this value.
76 Most users will not need to change this value.
77 """)
77 """)
78 def _active_types_default(self):
78 def _active_types_default(self):
79 return self.format_types
79 return self.format_types
80
80
81 def _active_types_changed(self, name, old, new):
81 def _active_types_changed(self, name, old, new):
82 for key, formatter in self.formatters.items():
82 for key, formatter in self.formatters.items():
83 if key in new:
83 if key in new:
84 formatter.enabled = True
84 formatter.enabled = True
85 else:
85 else:
86 formatter.enabled = False
86 formatter.enabled = False
87
87
88 ipython_display_formatter = ForwardDeclaredInstance('FormatterABC')
88 ipython_display_formatter = ForwardDeclaredInstance('FormatterABC')
89 def _ipython_display_formatter_default(self):
89 def _ipython_display_formatter_default(self):
90 return IPythonDisplayFormatter(parent=self)
90 return IPythonDisplayFormatter(parent=self)
91
91
92 # A dict of formatter whose keys are format types (MIME types) and whose
92 # A dict of formatter whose keys are format types (MIME types) and whose
93 # values are subclasses of BaseFormatter.
93 # values are subclasses of BaseFormatter.
94 formatters = Dict()
94 formatters = Dict()
95 def _formatters_default(self):
95 def _formatters_default(self):
96 """Activate the default formatters."""
96 """Activate the default formatters."""
97 formatter_classes = [
97 formatter_classes = [
98 PlainTextFormatter,
98 PlainTextFormatter,
99 HTMLFormatter,
99 HTMLFormatter,
100 MarkdownFormatter,
100 MarkdownFormatter,
101 SVGFormatter,
101 SVGFormatter,
102 PNGFormatter,
102 PNGFormatter,
103 PDFFormatter,
103 PDFFormatter,
104 JPEGFormatter,
104 JPEGFormatter,
105 LatexFormatter,
105 LatexFormatter,
106 JSONFormatter,
106 JSONFormatter,
107 JavascriptFormatter
107 JavascriptFormatter
108 ]
108 ]
109 d = {}
109 d = {}
110 for cls in formatter_classes:
110 for cls in formatter_classes:
111 f = cls(parent=self)
111 f = cls(parent=self)
112 d[f.format_type] = f
112 d[f.format_type] = f
113 return d
113 return d
114
114
115 def format(self, obj, include=None, exclude=None):
115 def format(self, obj, include=None, exclude=None):
116 """Return a format data dict for an object.
116 """Return a format data dict for an object.
117
117
118 By default all format types will be computed.
118 By default all format types will be computed.
119
119
120 The following MIME types are currently implemented:
120 The following MIME types are currently implemented:
121
121
122 * text/plain
122 * text/plain
123 * text/html
123 * text/html
124 * text/markdown
124 * text/markdown
125 * text/latex
125 * text/latex
126 * application/json
126 * application/json
127 * application/javascript
127 * application/javascript
128 * application/pdf
128 * application/pdf
129 * image/png
129 * image/png
130 * image/jpeg
130 * image/jpeg
131 * image/svg+xml
131 * image/svg+xml
132
132
133 Parameters
133 Parameters
134 ----------
134 ----------
135 obj : object
135 obj : object
136 The Python object whose format data will be computed.
136 The Python object whose format data will be computed.
137 include : list or tuple, optional
137 include : list or tuple, optional
138 A list of format type strings (MIME types) to include in the
138 A list of format type strings (MIME types) to include in the
139 format data dict. If this is set *only* the format types included
139 format data dict. If this is set *only* the format types included
140 in this list will be computed.
140 in this list will be computed.
141 exclude : list or tuple, optional
141 exclude : list or tuple, optional
142 A list of format type string (MIME types) to exclude in the format
142 A list of format type string (MIME types) to exclude in the format
143 data dict. If this is set all format types will be computed,
143 data dict. If this is set all format types will be computed,
144 except for those included in this argument.
144 except for those included in this argument.
145
145
146 Returns
146 Returns
147 -------
147 -------
148 (format_dict, metadata_dict) : tuple of two dicts
148 (format_dict, metadata_dict) : tuple of two dicts
149
149
150 format_dict is a dictionary of key/value pairs, one of each format that was
150 format_dict is a dictionary of key/value pairs, one of each format that was
151 generated for the object. The keys are the format types, which
151 generated for the object. The keys are the format types, which
152 will usually be MIME type strings and the values and JSON'able
152 will usually be MIME type strings and the values and JSON'able
153 data structure containing the raw data for the representation in
153 data structure containing the raw data for the representation in
154 that format.
154 that format.
155
155
156 metadata_dict is a dictionary of metadata about each mime-type output.
156 metadata_dict is a dictionary of metadata about each mime-type output.
157 Its keys will be a strict subset of the keys in format_dict.
157 Its keys will be a strict subset of the keys in format_dict.
158 """
158 """
159 format_dict = {}
159 format_dict = {}
160 md_dict = {}
160 md_dict = {}
161
161
162 if self.ipython_display_formatter(obj):
162 if self.ipython_display_formatter(obj):
163 # object handled itself, don't proceed
163 # object handled itself, don't proceed
164 return {}, {}
164 return {}, {}
165
165
166 for format_type, formatter in self.formatters.items():
166 for format_type, formatter in self.formatters.items():
167 if include and format_type not in include:
167 if include and format_type not in include:
168 continue
168 continue
169 if exclude and format_type in exclude:
169 if exclude and format_type in exclude:
170 continue
170 continue
171
171
172 md = None
172 md = None
173 try:
173 try:
174 data = formatter(obj)
174 data = formatter(obj)
175 except:
175 except:
176 # FIXME: log the exception
176 # FIXME: log the exception
177 raise
177 raise
178
178
179 # formatters can return raw data or (data, metadata)
179 # formatters can return raw data or (data, metadata)
180 if isinstance(data, tuple) and len(data) == 2:
180 if isinstance(data, tuple) and len(data) == 2:
181 data, md = data
181 data, md = data
182
182
183 if data is not None:
183 if data is not None:
184 format_dict[format_type] = data
184 format_dict[format_type] = data
185 if md is not None:
185 if md is not None:
186 md_dict[format_type] = md
186 md_dict[format_type] = md
187
187
188 return format_dict, md_dict
188 return format_dict, md_dict
189
189
190 @property
190 @property
191 def format_types(self):
191 def format_types(self):
192 """Return the format types (MIME types) of the active formatters."""
192 """Return the format types (MIME types) of the active formatters."""
193 return list(self.formatters.keys())
193 return list(self.formatters.keys())
194
194
195
195
196 #-----------------------------------------------------------------------------
196 #-----------------------------------------------------------------------------
197 # Formatters for specific format types (text, html, svg, etc.)
197 # Formatters for specific format types (text, html, svg, etc.)
198 #-----------------------------------------------------------------------------
198 #-----------------------------------------------------------------------------
199
199
200
200
201 def _safe_repr(obj):
201 def _safe_repr(obj):
202 """Try to return a repr of an object
202 """Try to return a repr of an object
203
203
204 always returns a string, at least.
204 always returns a string, at least.
205 """
205 """
206 try:
206 try:
207 return repr(obj)
207 return repr(obj)
208 except Exception as e:
208 except Exception as e:
209 return "un-repr-able object (%r)" % e
209 return "un-repr-able object (%r)" % e
210
210
211
211
212 class FormatterWarning(UserWarning):
212 class FormatterWarning(UserWarning):
213 """Warning class for errors in formatters"""
213 """Warning class for errors in formatters"""
214
214
215 @decorator
215 @decorator
216 def catch_format_error(method, self, *args, **kwargs):
216 def catch_format_error(method, self, *args, **kwargs):
217 """show traceback on failed format call"""
217 """show traceback on failed format call"""
218 try:
218 try:
219 r = method(self, *args, **kwargs)
219 r = method(self, *args, **kwargs)
220 except NotImplementedError:
220 except NotImplementedError:
221 # don't warn on NotImplementedErrors
221 # don't warn on NotImplementedErrors
222 return None
222 return None
223 except Exception:
223 except Exception:
224 exc_info = sys.exc_info()
224 exc_info = sys.exc_info()
225 ip = get_ipython()
225 ip = get_ipython()
226 if ip is not None:
226 if ip is not None:
227 ip.showtraceback(exc_info)
227 ip.showtraceback(exc_info)
228 else:
228 else:
229 traceback.print_exception(*exc_info)
229 traceback.print_exception(*exc_info)
230 return None
230 return None
231 return self._check_return(r, args[0])
231 return self._check_return(r, args[0])
232
232
233
233
234 class FormatterABC(with_metaclass(abc.ABCMeta, object)):
234 class FormatterABC(with_metaclass(abc.ABCMeta, object)):
235 """ Abstract base class for Formatters.
235 """ Abstract base class for Formatters.
236
236
237 A formatter is a callable class that is responsible for computing the
237 A formatter is a callable class that is responsible for computing the
238 raw format data for a particular format type (MIME type). For example,
238 raw format data for a particular format type (MIME type). For example,
239 an HTML formatter would have a format type of `text/html` and would return
239 an HTML formatter would have a format type of `text/html` and would return
240 the HTML representation of the object when called.
240 the HTML representation of the object when called.
241 """
241 """
242
242
243 # The format type of the data returned, usually a MIME type.
243 # The format type of the data returned, usually a MIME type.
244 format_type = 'text/plain'
244 format_type = 'text/plain'
245
245
246 # Is the formatter enabled...
246 # Is the formatter enabled...
247 enabled = True
247 enabled = True
248
248
249 @abc.abstractmethod
249 @abc.abstractmethod
250 def __call__(self, obj):
250 def __call__(self, obj):
251 """Return a JSON'able representation of the object.
251 """Return a JSON'able representation of the object.
252
252
253 If the object cannot be formatted by this formatter,
253 If the object cannot be formatted by this formatter,
254 warn and return None.
254 warn and return None.
255 """
255 """
256 return repr(obj)
256 return repr(obj)
257
257
258
258
259 def _mod_name_key(typ):
259 def _mod_name_key(typ):
260 """Return a (__module__, __name__) tuple for a type.
260 """Return a (__module__, __name__) tuple for a type.
261
261
262 Used as key in Formatter.deferred_printers.
262 Used as key in Formatter.deferred_printers.
263 """
263 """
264 module = getattr(typ, '__module__', None)
264 module = getattr(typ, '__module__', None)
265 name = getattr(typ, '__name__', None)
265 name = getattr(typ, '__name__', None)
266 return (module, name)
266 return (module, name)
267
267
268
268
269 def _get_type(obj):
269 def _get_type(obj):
270 """Return the type of an instance (old and new-style)"""
270 """Return the type of an instance (old and new-style)"""
271 return getattr(obj, '__class__', None) or type(obj)
271 return getattr(obj, '__class__', None) or type(obj)
272
272
273 _raise_key_error = object()
273 _raise_key_error = object()
274
274
275
275
276 class BaseFormatter(Configurable):
276 class BaseFormatter(Configurable):
277 """A base formatter class that is configurable.
277 """A base formatter class that is configurable.
278
278
279 This formatter should usually be used as the base class of all formatters.
279 This formatter should usually be used as the base class of all formatters.
280 It is a traited :class:`Configurable` class and includes an extensible
280 It is a traited :class:`Configurable` class and includes an extensible
281 API for users to determine how their objects are formatted. The following
281 API for users to determine how their objects are formatted. The following
282 logic is used to find a function to format an given object.
282 logic is used to find a function to format an given object.
283
283
284 1. The object is introspected to see if it has a method with the name
284 1. The object is introspected to see if it has a method with the name
285 :attr:`print_method`. If is does, that object is passed to that method
285 :attr:`print_method`. If is does, that object is passed to that method
286 for formatting.
286 for formatting.
287 2. If no print method is found, three internal dictionaries are consulted
287 2. If no print method is found, three internal dictionaries are consulted
288 to find print method: :attr:`singleton_printers`, :attr:`type_printers`
288 to find print method: :attr:`singleton_printers`, :attr:`type_printers`
289 and :attr:`deferred_printers`.
289 and :attr:`deferred_printers`.
290
290
291 Users should use these dictionaries to register functions that will be
291 Users should use these dictionaries to register functions that will be
292 used to compute the format data for their objects (if those objects don't
292 used to compute the format data for their objects (if those objects don't
293 have the special print methods). The easiest way of using these
293 have the special print methods). The easiest way of using these
294 dictionaries is through the :meth:`for_type` and :meth:`for_type_by_name`
294 dictionaries is through the :meth:`for_type` and :meth:`for_type_by_name`
295 methods.
295 methods.
296
296
297 If no function/callable is found to compute the format data, ``None`` is
297 If no function/callable is found to compute the format data, ``None`` is
298 returned and this format type is not used.
298 returned and this format type is not used.
299 """
299 """
300
300
301 format_type = Unicode('text/plain')
301 format_type = Unicode('text/plain')
302 _return_type = string_types
302 _return_type = string_types
303
303
304 enabled = Bool(True, config=True)
304 enabled = Bool(True, config=True)
305
305
306 print_method = ObjectName('__repr__')
306 print_method = ObjectName('__repr__')
307
307
308 # The singleton printers.
308 # The singleton printers.
309 # Maps the IDs of the builtin singleton objects to the format functions.
309 # Maps the IDs of the builtin singleton objects to the format functions.
310 singleton_printers = Dict(config=True)
310 singleton_printers = Dict(config=True)
311
311
312 # The type-specific printers.
312 # The type-specific printers.
313 # Map type objects to the format functions.
313 # Map type objects to the format functions.
314 type_printers = Dict(config=True)
314 type_printers = Dict(config=True)
315
315
316 # The deferred-import type-specific printers.
316 # The deferred-import type-specific printers.
317 # Map (modulename, classname) pairs to the format functions.
317 # Map (modulename, classname) pairs to the format functions.
318 deferred_printers = Dict(config=True)
318 deferred_printers = Dict(config=True)
319
319
320 @catch_format_error
320 @catch_format_error
321 def __call__(self, obj):
321 def __call__(self, obj):
322 """Compute the format for an object."""
322 """Compute the format for an object."""
323 if self.enabled:
323 if self.enabled:
324 # lookup registered printer
324 # lookup registered printer
325 try:
325 try:
326 printer = self.lookup(obj)
326 printer = self.lookup(obj)
327 except KeyError:
327 except KeyError:
328 pass
328 pass
329 else:
329 else:
330 return printer(obj)
330 return printer(obj)
331 # Finally look for special method names
331 # Finally look for special method names
332 method = _safe_get_formatter_method(obj, self.print_method)
332 method = _safe_get_formatter_method(obj, self.print_method)
333 if method is not None:
333 if method is not None:
334 return method()
334 return method()
335 return None
335 return None
336 else:
336 else:
337 return None
337 return None
338
338
339 def __contains__(self, typ):
339 def __contains__(self, typ):
340 """map in to lookup_by_type"""
340 """map in to lookup_by_type"""
341 try:
341 try:
342 self.lookup_by_type(typ)
342 self.lookup_by_type(typ)
343 except KeyError:
343 except KeyError:
344 return False
344 return False
345 else:
345 else:
346 return True
346 return True
347
347
348 def _check_return(self, r, obj):
348 def _check_return(self, r, obj):
349 """Check that a return value is appropriate
349 """Check that a return value is appropriate
350
350
351 Return the value if so, None otherwise, warning if invalid.
351 Return the value if so, None otherwise, warning if invalid.
352 """
352 """
353 if r is None or isinstance(r, self._return_type) or \
353 if r is None or isinstance(r, self._return_type) or \
354 (isinstance(r, tuple) and r and isinstance(r[0], self._return_type)):
354 (isinstance(r, tuple) and r and isinstance(r[0], self._return_type)):
355 return r
355 return r
356 else:
356 else:
357 warnings.warn(
357 warnings.warn(
358 "%s formatter returned invalid type %s (expected %s) for object: %s" % \
358 "%s formatter returned invalid type %s (expected %s) for object: %s" % \
359 (self.format_type, type(r), self._return_type, _safe_repr(obj)),
359 (self.format_type, type(r), self._return_type, _safe_repr(obj)),
360 FormatterWarning
360 FormatterWarning
361 )
361 )
362
362
363 def lookup(self, obj):
363 def lookup(self, obj):
364 """Look up the formatter for a given instance.
364 """Look up the formatter for a given instance.
365
365
366 Parameters
366 Parameters
367 ----------
367 ----------
368 obj : object instance
368 obj : object instance
369
369
370 Returns
370 Returns
371 -------
371 -------
372 f : callable
372 f : callable
373 The registered formatting callable for the type.
373 The registered formatting callable for the type.
374
374
375 Raises
375 Raises
376 ------
376 ------
377 KeyError if the type has not been registered.
377 KeyError if the type has not been registered.
378 """
378 """
379 # look for singleton first
379 # look for singleton first
380 obj_id = id(obj)
380 obj_id = id(obj)
381 if obj_id in self.singleton_printers:
381 if obj_id in self.singleton_printers:
382 return self.singleton_printers[obj_id]
382 return self.singleton_printers[obj_id]
383 # then lookup by type
383 # then lookup by type
384 return self.lookup_by_type(_get_type(obj))
384 return self.lookup_by_type(_get_type(obj))
385
385
386 def lookup_by_type(self, typ):
386 def lookup_by_type(self, typ):
387 """Look up the registered formatter for a type.
387 """Look up the registered formatter for a type.
388
388
389 Parameters
389 Parameters
390 ----------
390 ----------
391 typ : type or '__module__.__name__' string for a type
391 typ : type or '__module__.__name__' string for a type
392
392
393 Returns
393 Returns
394 -------
394 -------
395 f : callable
395 f : callable
396 The registered formatting callable for the type.
396 The registered formatting callable for the type.
397
397
398 Raises
398 Raises
399 ------
399 ------
400 KeyError if the type has not been registered.
400 KeyError if the type has not been registered.
401 """
401 """
402 if isinstance(typ, string_types):
402 if isinstance(typ, string_types):
403 typ_key = tuple(typ.rsplit('.',1))
403 typ_key = tuple(typ.rsplit('.',1))
404 if typ_key not in self.deferred_printers:
404 if typ_key not in self.deferred_printers:
405 # We may have it cached in the type map. We will have to
405 # We may have it cached in the type map. We will have to
406 # iterate over all of the types to check.
406 # iterate over all of the types to check.
407 for cls in self.type_printers:
407 for cls in self.type_printers:
408 if _mod_name_key(cls) == typ_key:
408 if _mod_name_key(cls) == typ_key:
409 return self.type_printers[cls]
409 return self.type_printers[cls]
410 else:
410 else:
411 return self.deferred_printers[typ_key]
411 return self.deferred_printers[typ_key]
412 else:
412 else:
413 for cls in pretty._get_mro(typ):
413 for cls in pretty._get_mro(typ):
414 if cls in self.type_printers or self._in_deferred_types(cls):
414 if cls in self.type_printers or self._in_deferred_types(cls):
415 return self.type_printers[cls]
415 return self.type_printers[cls]
416
416
417 # If we have reached here, the lookup failed.
417 # If we have reached here, the lookup failed.
418 raise KeyError("No registered printer for {0!r}".format(typ))
418 raise KeyError("No registered printer for {0!r}".format(typ))
419
419
420 def for_type(self, typ, func=None):
420 def for_type(self, typ, func=None):
421 """Add a format function for a given type.
421 """Add a format function for a given type.
422
422
423 Parameters
423 Parameters
424 -----------
424 -----------
425 typ : type or '__module__.__name__' string for a type
425 typ : type or '__module__.__name__' string for a type
426 The class of the object that will be formatted using `func`.
426 The class of the object that will be formatted using `func`.
427 func : callable
427 func : callable
428 A callable for computing the format data.
428 A callable for computing the format data.
429 `func` will be called with the object to be formatted,
429 `func` will be called with the object to be formatted,
430 and will return the raw data in this formatter's format.
430 and will return the raw data in this formatter's format.
431 Subclasses may use a different call signature for the
431 Subclasses may use a different call signature for the
432 `func` argument.
432 `func` argument.
433
433
434 If `func` is None or not specified, there will be no change,
434 If `func` is None or not specified, there will be no change,
435 only returning the current value.
435 only returning the current value.
436
436
437 Returns
437 Returns
438 -------
438 -------
439 oldfunc : callable
439 oldfunc : callable
440 The currently registered callable.
440 The currently registered callable.
441 If you are registering a new formatter,
441 If you are registering a new formatter,
442 this will be the previous value (to enable restoring later).
442 this will be the previous value (to enable restoring later).
443 """
443 """
444 # if string given, interpret as 'pkg.module.class_name'
444 # if string given, interpret as 'pkg.module.class_name'
445 if isinstance(typ, string_types):
445 if isinstance(typ, string_types):
446 type_module, type_name = typ.rsplit('.', 1)
446 type_module, type_name = typ.rsplit('.', 1)
447 return self.for_type_by_name(type_module, type_name, func)
447 return self.for_type_by_name(type_module, type_name, func)
448
448
449 try:
449 try:
450 oldfunc = self.lookup_by_type(typ)
450 oldfunc = self.lookup_by_type(typ)
451 except KeyError:
451 except KeyError:
452 oldfunc = None
452 oldfunc = None
453
453
454 if func is not None:
454 if func is not None:
455 self.type_printers[typ] = func
455 self.type_printers[typ] = func
456
456
457 return oldfunc
457 return oldfunc
458
458
459 def for_type_by_name(self, type_module, type_name, func=None):
459 def for_type_by_name(self, type_module, type_name, func=None):
460 """Add a format function for a type specified by the full dotted
460 """Add a format function for a type specified by the full dotted
461 module and name of the type, rather than the type of the object.
461 module and name of the type, rather than the type of the object.
462
462
463 Parameters
463 Parameters
464 ----------
464 ----------
465 type_module : str
465 type_module : str
466 The full dotted name of the module the type is defined in, like
466 The full dotted name of the module the type is defined in, like
467 ``numpy``.
467 ``numpy``.
468 type_name : str
468 type_name : str
469 The name of the type (the class name), like ``dtype``
469 The name of the type (the class name), like ``dtype``
470 func : callable
470 func : callable
471 A callable for computing the format data.
471 A callable for computing the format data.
472 `func` will be called with the object to be formatted,
472 `func` will be called with the object to be formatted,
473 and will return the raw data in this formatter's format.
473 and will return the raw data in this formatter's format.
474 Subclasses may use a different call signature for the
474 Subclasses may use a different call signature for the
475 `func` argument.
475 `func` argument.
476
476
477 If `func` is None or unspecified, there will be no change,
477 If `func` is None or unspecified, there will be no change,
478 only returning the current value.
478 only returning the current value.
479
479
480 Returns
480 Returns
481 -------
481 -------
482 oldfunc : callable
482 oldfunc : callable
483 The currently registered callable.
483 The currently registered callable.
484 If you are registering a new formatter,
484 If you are registering a new formatter,
485 this will be the previous value (to enable restoring later).
485 this will be the previous value (to enable restoring later).
486 """
486 """
487 key = (type_module, type_name)
487 key = (type_module, type_name)
488
488
489 try:
489 try:
490 oldfunc = self.lookup_by_type("%s.%s" % key)
490 oldfunc = self.lookup_by_type("%s.%s" % key)
491 except KeyError:
491 except KeyError:
492 oldfunc = None
492 oldfunc = None
493
493
494 if func is not None:
494 if func is not None:
495 self.deferred_printers[key] = func
495 self.deferred_printers[key] = func
496 return oldfunc
496 return oldfunc
497
497
498 def pop(self, typ, default=_raise_key_error):
498 def pop(self, typ, default=_raise_key_error):
499 """Pop a formatter for the given type.
499 """Pop a formatter for the given type.
500
500
501 Parameters
501 Parameters
502 ----------
502 ----------
503 typ : type or '__module__.__name__' string for a type
503 typ : type or '__module__.__name__' string for a type
504 default : object
504 default : object
505 value to be returned if no formatter is registered for typ.
505 value to be returned if no formatter is registered for typ.
506
506
507 Returns
507 Returns
508 -------
508 -------
509 obj : object
509 obj : object
510 The last registered object for the type.
510 The last registered object for the type.
511
511
512 Raises
512 Raises
513 ------
513 ------
514 KeyError if the type is not registered and default is not specified.
514 KeyError if the type is not registered and default is not specified.
515 """
515 """
516
516
517 if isinstance(typ, string_types):
517 if isinstance(typ, string_types):
518 typ_key = tuple(typ.rsplit('.',1))
518 typ_key = tuple(typ.rsplit('.',1))
519 if typ_key not in self.deferred_printers:
519 if typ_key not in self.deferred_printers:
520 # We may have it cached in the type map. We will have to
520 # We may have it cached in the type map. We will have to
521 # iterate over all of the types to check.
521 # iterate over all of the types to check.
522 for cls in self.type_printers:
522 for cls in self.type_printers:
523 if _mod_name_key(cls) == typ_key:
523 if _mod_name_key(cls) == typ_key:
524 old = self.type_printers.pop(cls)
524 old = self.type_printers.pop(cls)
525 break
525 break
526 else:
526 else:
527 old = default
527 old = default
528 else:
528 else:
529 old = self.deferred_printers.pop(typ_key)
529 old = self.deferred_printers.pop(typ_key)
530 else:
530 else:
531 if typ in self.type_printers:
531 if typ in self.type_printers:
532 old = self.type_printers.pop(typ)
532 old = self.type_printers.pop(typ)
533 else:
533 else:
534 old = self.deferred_printers.pop(_mod_name_key(typ), default)
534 old = self.deferred_printers.pop(_mod_name_key(typ), default)
535 if old is _raise_key_error:
535 if old is _raise_key_error:
536 raise KeyError("No registered value for {0!r}".format(typ))
536 raise KeyError("No registered value for {0!r}".format(typ))
537 return old
537 return old
538
538
539 def _in_deferred_types(self, cls):
539 def _in_deferred_types(self, cls):
540 """
540 """
541 Check if the given class is specified in the deferred type registry.
541 Check if the given class is specified in the deferred type registry.
542
542
543 Successful matches will be moved to the regular type registry for future use.
543 Successful matches will be moved to the regular type registry for future use.
544 """
544 """
545 mod = getattr(cls, '__module__', None)
545 mod = getattr(cls, '__module__', None)
546 name = getattr(cls, '__name__', None)
546 name = getattr(cls, '__name__', None)
547 key = (mod, name)
547 key = (mod, name)
548 if key in self.deferred_printers:
548 if key in self.deferred_printers:
549 # Move the printer over to the regular registry.
549 # Move the printer over to the regular registry.
550 printer = self.deferred_printers.pop(key)
550 printer = self.deferred_printers.pop(key)
551 self.type_printers[cls] = printer
551 self.type_printers[cls] = printer
552 return True
552 return True
553 return False
553 return False
554
554
555
555
556 class PlainTextFormatter(BaseFormatter):
556 class PlainTextFormatter(BaseFormatter):
557 """The default pretty-printer.
557 """The default pretty-printer.
558
558
559 This uses :mod:`IPython.lib.pretty` to compute the format data of
559 This uses :mod:`IPython.lib.pretty` to compute the format data of
560 the object. If the object cannot be pretty printed, :func:`repr` is used.
560 the object. If the object cannot be pretty printed, :func:`repr` is used.
561 See the documentation of :mod:`IPython.lib.pretty` for details on
561 See the documentation of :mod:`IPython.lib.pretty` for details on
562 how to write pretty printers. Here is a simple example::
562 how to write pretty printers. Here is a simple example::
563
563
564 def dtype_pprinter(obj, p, cycle):
564 def dtype_pprinter(obj, p, cycle):
565 if cycle:
565 if cycle:
566 return p.text('dtype(...)')
566 return p.text('dtype(...)')
567 if hasattr(obj, 'fields'):
567 if hasattr(obj, 'fields'):
568 if obj.fields is None:
568 if obj.fields is None:
569 p.text(repr(obj))
569 p.text(repr(obj))
570 else:
570 else:
571 p.begin_group(7, 'dtype([')
571 p.begin_group(7, 'dtype([')
572 for i, field in enumerate(obj.descr):
572 for i, field in enumerate(obj.descr):
573 if i > 0:
573 if i > 0:
574 p.text(',')
574 p.text(',')
575 p.breakable()
575 p.breakable()
576 p.pretty(field)
576 p.pretty(field)
577 p.end_group(7, '])')
577 p.end_group(7, '])')
578 """
578 """
579
579
580 # The format type of data returned.
580 # The format type of data returned.
581 format_type = Unicode('text/plain')
581 format_type = Unicode('text/plain')
582
582
583 # This subclass ignores this attribute as it always need to return
583 # This subclass ignores this attribute as it always need to return
584 # something.
584 # something.
585 enabled = Bool(True, config=False)
585 enabled = Bool(True, config=False)
586
586
587 max_seq_length = Integer(pretty.MAX_SEQ_LENGTH, config=True,
587 max_seq_length = Integer(pretty.MAX_SEQ_LENGTH, config=True,
588 help="""Truncate large collections (lists, dicts, tuples, sets) to this size.
588 help="""Truncate large collections (lists, dicts, tuples, sets) to this size.
589
589
590 Set to 0 to disable truncation.
590 Set to 0 to disable truncation.
591 """
591 """
592 )
592 )
593
593
594 # Look for a _repr_pretty_ methods to use for pretty printing.
594 # Look for a _repr_pretty_ methods to use for pretty printing.
595 print_method = ObjectName('_repr_pretty_')
595 print_method = ObjectName('_repr_pretty_')
596
596
597 # Whether to pretty-print or not.
597 # Whether to pretty-print or not.
598 pprint = Bool(True, config=True)
598 pprint = Bool(True, config=True)
599
599
600 # Whether to be verbose or not.
600 # Whether to be verbose or not.
601 verbose = Bool(False, config=True)
601 verbose = Bool(False, config=True)
602
602
603 # The maximum width.
603 # The maximum width.
604 max_width = Integer(79, config=True)
604 max_width = Integer(79, config=True)
605
605
606 # The newline character.
606 # The newline character.
607 newline = Unicode('\n', config=True)
607 newline = Unicode('\n', config=True)
608
608
609 # format-string for pprinting floats
609 # format-string for pprinting floats
610 float_format = Unicode('%r')
610 float_format = Unicode('%r')
611 # setter for float precision, either int or direct format-string
611 # setter for float precision, either int or direct format-string
612 float_precision = CUnicode('', config=True)
612 float_precision = CUnicode('', config=True)
613
613
614 def _float_precision_changed(self, name, old, new):
614 def _float_precision_changed(self, name, old, new):
615 """float_precision changed, set float_format accordingly.
615 """float_precision changed, set float_format accordingly.
616
616
617 float_precision can be set by int or str.
617 float_precision can be set by int or str.
618 This will set float_format, after interpreting input.
618 This will set float_format, after interpreting input.
619 If numpy has been imported, numpy print precision will also be set.
619 If numpy has been imported, numpy print precision will also be set.
620
620
621 integer `n` sets format to '%.nf', otherwise, format set directly.
621 integer `n` sets format to '%.nf', otherwise, format set directly.
622
622
623 An empty string returns to defaults (repr for float, 8 for numpy).
623 An empty string returns to defaults (repr for float, 8 for numpy).
624
624
625 This parameter can be set via the '%precision' magic.
625 This parameter can be set via the '%precision' magic.
626 """
626 """
627
627
628 if '%' in new:
628 if '%' in new:
629 # got explicit format string
629 # got explicit format string
630 fmt = new
630 fmt = new
631 try:
631 try:
632 fmt%3.14159
632 fmt%3.14159
633 except Exception:
633 except Exception:
634 raise ValueError("Precision must be int or format string, not %r"%new)
634 raise ValueError("Precision must be int or format string, not %r"%new)
635 elif new:
635 elif new:
636 # otherwise, should be an int
636 # otherwise, should be an int
637 try:
637 try:
638 i = int(new)
638 i = int(new)
639 assert i >= 0
639 assert i >= 0
640 except ValueError:
640 except ValueError:
641 raise ValueError("Precision must be int or format string, not %r"%new)
641 raise ValueError("Precision must be int or format string, not %r"%new)
642 except AssertionError:
642 except AssertionError:
643 raise ValueError("int precision must be non-negative, not %r"%i)
643 raise ValueError("int precision must be non-negative, not %r"%i)
644
644
645 fmt = '%%.%if'%i
645 fmt = '%%.%if'%i
646 if 'numpy' in sys.modules:
646 if 'numpy' in sys.modules:
647 # set numpy precision if it has been imported
647 # set numpy precision if it has been imported
648 import numpy
648 import numpy
649 numpy.set_printoptions(precision=i)
649 numpy.set_printoptions(precision=i)
650 else:
650 else:
651 # default back to repr
651 # default back to repr
652 fmt = '%r'
652 fmt = '%r'
653 if 'numpy' in sys.modules:
653 if 'numpy' in sys.modules:
654 import numpy
654 import numpy
655 # numpy default is 8
655 # numpy default is 8
656 numpy.set_printoptions(precision=8)
656 numpy.set_printoptions(precision=8)
657 self.float_format = fmt
657 self.float_format = fmt
658
658
659 # Use the default pretty printers from IPython.lib.pretty.
659 # Use the default pretty printers from IPython.lib.pretty.
660 def _singleton_printers_default(self):
660 def _singleton_printers_default(self):
661 return pretty._singleton_pprinters.copy()
661 return pretty._singleton_pprinters.copy()
662
662
663 def _type_printers_default(self):
663 def _type_printers_default(self):
664 d = pretty._type_pprinters.copy()
664 d = pretty._type_pprinters.copy()
665 d[float] = lambda obj,p,cycle: p.text(self.float_format%obj)
665 d[float] = lambda obj,p,cycle: p.text(self.float_format%obj)
666 return d
666 return d
667
667
668 def _deferred_printers_default(self):
668 def _deferred_printers_default(self):
669 return pretty._deferred_type_pprinters.copy()
669 return pretty._deferred_type_pprinters.copy()
670
670
671 #### FormatterABC interface ####
671 #### FormatterABC interface ####
672
672
673 @catch_format_error
673 @catch_format_error
674 def __call__(self, obj):
674 def __call__(self, obj):
675 """Compute the pretty representation of the object."""
675 """Compute the pretty representation of the object."""
676 if not self.pprint:
676 if not self.pprint:
677 return repr(obj)
677 return repr(obj)
678 else:
678 else:
679 # handle str and unicode on Python 2
679 # handle str and unicode on Python 2
680 # io.StringIO only accepts unicode,
680 # io.StringIO only accepts unicode,
681 # cStringIO doesn't handle unicode on py2,
681 # cStringIO doesn't handle unicode on py2,
682 # StringIO allows str, unicode but only ascii str
682 # StringIO allows str, unicode but only ascii str
683 stream = pretty.CUnicodeIO()
683 stream = pretty.CUnicodeIO()
684 printer = pretty.RepresentationPrinter(stream, self.verbose,
684 printer = pretty.RepresentationPrinter(stream, self.verbose,
685 self.max_width, self.newline,
685 self.max_width, self.newline,
686 max_seq_length=self.max_seq_length,
686 max_seq_length=self.max_seq_length,
687 singleton_pprinters=self.singleton_printers,
687 singleton_pprinters=self.singleton_printers,
688 type_pprinters=self.type_printers,
688 type_pprinters=self.type_printers,
689 deferred_pprinters=self.deferred_printers)
689 deferred_pprinters=self.deferred_printers)
690 printer.pretty(obj)
690 printer.pretty(obj)
691 printer.flush()
691 printer.flush()
692 return stream.getvalue()
692 return stream.getvalue()
693
693
694
694
695 class HTMLFormatter(BaseFormatter):
695 class HTMLFormatter(BaseFormatter):
696 """An HTML formatter.
696 """An HTML formatter.
697
697
698 To define the callables that compute the HTML representation of your
698 To define the callables that compute the HTML representation of your
699 objects, define a :meth:`_repr_html_` method or use the :meth:`for_type`
699 objects, define a :meth:`_repr_html_` method or use the :meth:`for_type`
700 or :meth:`for_type_by_name` methods to register functions that handle
700 or :meth:`for_type_by_name` methods to register functions that handle
701 this.
701 this.
702
702
703 The return value of this formatter should be a valid HTML snippet that
703 The return value of this formatter should be a valid HTML snippet that
704 could be injected into an existing DOM. It should *not* include the
704 could be injected into an existing DOM. It should *not* include the
705 ```<html>`` or ```<body>`` tags.
705 ```<html>`` or ```<body>`` tags.
706 """
706 """
707 format_type = Unicode('text/html')
707 format_type = Unicode('text/html')
708
708
709 print_method = ObjectName('_repr_html_')
709 print_method = ObjectName('_repr_html_')
710
710
711
711
712 class MarkdownFormatter(BaseFormatter):
712 class MarkdownFormatter(BaseFormatter):
713 """A Markdown formatter.
713 """A Markdown formatter.
714
714
715 To define the callables that compute the Markdown representation of your
715 To define the callables that compute the Markdown representation of your
716 objects, define a :meth:`_repr_markdown_` method or use the :meth:`for_type`
716 objects, define a :meth:`_repr_markdown_` method or use the :meth:`for_type`
717 or :meth:`for_type_by_name` methods to register functions that handle
717 or :meth:`for_type_by_name` methods to register functions that handle
718 this.
718 this.
719
719
720 The return value of this formatter should be a valid Markdown.
720 The return value of this formatter should be a valid Markdown.
721 """
721 """
722 format_type = Unicode('text/markdown')
722 format_type = Unicode('text/markdown')
723
723
724 print_method = ObjectName('_repr_markdown_')
724 print_method = ObjectName('_repr_markdown_')
725
725
726 class SVGFormatter(BaseFormatter):
726 class SVGFormatter(BaseFormatter):
727 """An SVG formatter.
727 """An SVG formatter.
728
728
729 To define the callables that compute the SVG representation of your
729 To define the callables that compute the SVG representation of your
730 objects, define a :meth:`_repr_svg_` method or use the :meth:`for_type`
730 objects, define a :meth:`_repr_svg_` method or use the :meth:`for_type`
731 or :meth:`for_type_by_name` methods to register functions that handle
731 or :meth:`for_type_by_name` methods to register functions that handle
732 this.
732 this.
733
733
734 The return value of this formatter should be valid SVG enclosed in
734 The return value of this formatter should be valid SVG enclosed in
735 ```<svg>``` tags, that could be injected into an existing DOM. It should
735 ```<svg>``` tags, that could be injected into an existing DOM. It should
736 *not* include the ```<html>`` or ```<body>`` tags.
736 *not* include the ```<html>`` or ```<body>`` tags.
737 """
737 """
738 format_type = Unicode('image/svg+xml')
738 format_type = Unicode('image/svg+xml')
739
739
740 print_method = ObjectName('_repr_svg_')
740 print_method = ObjectName('_repr_svg_')
741
741
742
742
743 class PNGFormatter(BaseFormatter):
743 class PNGFormatter(BaseFormatter):
744 """A PNG formatter.
744 """A PNG formatter.
745
745
746 To define the callables that compute the PNG representation of your
746 To define the callables that compute the PNG representation of your
747 objects, define a :meth:`_repr_png_` method or use the :meth:`for_type`
747 objects, define a :meth:`_repr_png_` method or use the :meth:`for_type`
748 or :meth:`for_type_by_name` methods to register functions that handle
748 or :meth:`for_type_by_name` methods to register functions that handle
749 this.
749 this.
750
750
751 The return value of this formatter should be raw PNG data, *not*
751 The return value of this formatter should be raw PNG data, *not*
752 base64 encoded.
752 base64 encoded.
753 """
753 """
754 format_type = Unicode('image/png')
754 format_type = Unicode('image/png')
755
755
756 print_method = ObjectName('_repr_png_')
756 print_method = ObjectName('_repr_png_')
757
757
758 _return_type = (bytes, unicode_type)
758 _return_type = (bytes, unicode_type)
759
759
760
760
761 class JPEGFormatter(BaseFormatter):
761 class JPEGFormatter(BaseFormatter):
762 """A JPEG formatter.
762 """A JPEG formatter.
763
763
764 To define the callables that compute the JPEG representation of your
764 To define the callables that compute the JPEG representation of your
765 objects, define a :meth:`_repr_jpeg_` method or use the :meth:`for_type`
765 objects, define a :meth:`_repr_jpeg_` method or use the :meth:`for_type`
766 or :meth:`for_type_by_name` methods to register functions that handle
766 or :meth:`for_type_by_name` methods to register functions that handle
767 this.
767 this.
768
768
769 The return value of this formatter should be raw JPEG data, *not*
769 The return value of this formatter should be raw JPEG data, *not*
770 base64 encoded.
770 base64 encoded.
771 """
771 """
772 format_type = Unicode('image/jpeg')
772 format_type = Unicode('image/jpeg')
773
773
774 print_method = ObjectName('_repr_jpeg_')
774 print_method = ObjectName('_repr_jpeg_')
775
775
776 _return_type = (bytes, unicode_type)
776 _return_type = (bytes, unicode_type)
777
777
778
778
779 class LatexFormatter(BaseFormatter):
779 class LatexFormatter(BaseFormatter):
780 """A LaTeX formatter.
780 """A LaTeX formatter.
781
781
782 To define the callables that compute the LaTeX representation of your
782 To define the callables that compute the LaTeX representation of your
783 objects, define a :meth:`_repr_latex_` method or use the :meth:`for_type`
783 objects, define a :meth:`_repr_latex_` method or use the :meth:`for_type`
784 or :meth:`for_type_by_name` methods to register functions that handle
784 or :meth:`for_type_by_name` methods to register functions that handle
785 this.
785 this.
786
786
787 The return value of this formatter should be a valid LaTeX equation,
787 The return value of this formatter should be a valid LaTeX equation,
788 enclosed in either ```$```, ```$$``` or another LaTeX equation
788 enclosed in either ```$```, ```$$``` or another LaTeX equation
789 environment.
789 environment.
790 """
790 """
791 format_type = Unicode('text/latex')
791 format_type = Unicode('text/latex')
792
792
793 print_method = ObjectName('_repr_latex_')
793 print_method = ObjectName('_repr_latex_')
794
794
795
795
796 class JSONFormatter(BaseFormatter):
796 class JSONFormatter(BaseFormatter):
797 """A JSON string formatter.
797 """A JSON string formatter.
798
798
799 To define the callables that compute the JSONable representation of
799 To define the callables that compute the JSONable representation of
800 your objects, define a :meth:`_repr_json_` method or use the :meth:`for_type`
800 your objects, define a :meth:`_repr_json_` method or use the :meth:`for_type`
801 or :meth:`for_type_by_name` methods to register functions that handle
801 or :meth:`for_type_by_name` methods to register functions that handle
802 this.
802 this.
803
803
804 The return value of this formatter should be a JSONable list or dict.
804 The return value of this formatter should be a JSONable list or dict.
805 JSON scalars (None, number, string) are not allowed, only dict or list containers.
805 JSON scalars (None, number, string) are not allowed, only dict or list containers.
806 """
806 """
807 format_type = Unicode('application/json')
807 format_type = Unicode('application/json')
808 _return_type = (list, dict)
808 _return_type = (list, dict)
809
809
810 print_method = ObjectName('_repr_json_')
810 print_method = ObjectName('_repr_json_')
811
811
812 def _check_return(self, r, obj):
812 def _check_return(self, r, obj):
813 """Check that a return value is appropriate
813 """Check that a return value is appropriate
814
814
815 Return the value if so, None otherwise, warning if invalid.
815 Return the value if so, None otherwise, warning if invalid.
816 """
816 """
817 if r is None:
817 if r is None:
818 return
818 return
819 md = None
819 md = None
820 if isinstance(r, tuple):
820 if isinstance(r, tuple):
821 # unpack data, metadata tuple for type checking on first element
821 # unpack data, metadata tuple for type checking on first element
822 r, md = r
822 r, md = r
823
823
824 # handle deprecated JSON-as-string form from IPython < 3
824 # handle deprecated JSON-as-string form from IPython < 3
825 if isinstance(r, string_types):
825 if isinstance(r, string_types):
826 warnings.warn("JSON expects JSONable list/dict containers, not JSON strings",
826 warnings.warn("JSON expects JSONable list/dict containers, not JSON strings",
827 FormatterWarning)
827 FormatterWarning)
828 r = json.loads(r)
828 r = json.loads(r)
829
829
830 if md is not None:
830 if md is not None:
831 # put the tuple back together
831 # put the tuple back together
832 r = (r, md)
832 r = (r, md)
833 return super(JSONFormatter, self)._check_return(r, obj)
833 return super(JSONFormatter, self)._check_return(r, obj)
834
834
835
835
836 class JavascriptFormatter(BaseFormatter):
836 class JavascriptFormatter(BaseFormatter):
837 """A Javascript formatter.
837 """A Javascript formatter.
838
838
839 To define the callables that compute the Javascript representation of
839 To define the callables that compute the Javascript representation of
840 your objects, define a :meth:`_repr_javascript_` method or use the
840 your objects, define a :meth:`_repr_javascript_` method or use the
841 :meth:`for_type` or :meth:`for_type_by_name` methods to register functions
841 :meth:`for_type` or :meth:`for_type_by_name` methods to register functions
842 that handle this.
842 that handle this.
843
843
844 The return value of this formatter should be valid Javascript code and
844 The return value of this formatter should be valid Javascript code and
845 should *not* be enclosed in ```<script>``` tags.
845 should *not* be enclosed in ```<script>``` tags.
846 """
846 """
847 format_type = Unicode('application/javascript')
847 format_type = Unicode('application/javascript')
848
848
849 print_method = ObjectName('_repr_javascript_')
849 print_method = ObjectName('_repr_javascript_')
850
850
851
851
852 class PDFFormatter(BaseFormatter):
852 class PDFFormatter(BaseFormatter):
853 """A PDF formatter.
853 """A PDF formatter.
854
854
855 To define the callables that compute the PDF representation of your
855 To define the callables that compute the PDF representation of your
856 objects, define a :meth:`_repr_pdf_` method or use the :meth:`for_type`
856 objects, define a :meth:`_repr_pdf_` method or use the :meth:`for_type`
857 or :meth:`for_type_by_name` methods to register functions that handle
857 or :meth:`for_type_by_name` methods to register functions that handle
858 this.
858 this.
859
859
860 The return value of this formatter should be raw PDF data, *not*
860 The return value of this formatter should be raw PDF data, *not*
861 base64 encoded.
861 base64 encoded.
862 """
862 """
863 format_type = Unicode('application/pdf')
863 format_type = Unicode('application/pdf')
864
864
865 print_method = ObjectName('_repr_pdf_')
865 print_method = ObjectName('_repr_pdf_')
866
866
867 _return_type = (bytes, unicode_type)
867 _return_type = (bytes, unicode_type)
868
868
869 class IPythonDisplayFormatter(BaseFormatter):
869 class IPythonDisplayFormatter(BaseFormatter):
870 """A Formatter for objects that know how to display themselves.
870 """A Formatter for objects that know how to display themselves.
871
871
872 To define the callables that compute the representation of your
872 To define the callables that compute the representation of your
873 objects, define a :meth:`_ipython_display_` method or use the :meth:`for_type`
873 objects, define a :meth:`_ipython_display_` method or use the :meth:`for_type`
874 or :meth:`for_type_by_name` methods to register functions that handle
874 or :meth:`for_type_by_name` methods to register functions that handle
875 this. Unlike mime-type displays, this method should not return anything,
875 this. Unlike mime-type displays, this method should not return anything,
876 instead calling any appropriate display methods itself.
876 instead calling any appropriate display methods itself.
877
877
878 This display formatter has highest priority.
878 This display formatter has highest priority.
879 If it fires, no other display formatter will be called.
879 If it fires, no other display formatter will be called.
880 """
880 """
881 print_method = ObjectName('_ipython_display_')
881 print_method = ObjectName('_ipython_display_')
882 _return_type = (type(None), bool)
882 _return_type = (type(None), bool)
883
883
884
884
885 @catch_format_error
885 @catch_format_error
886 def __call__(self, obj):
886 def __call__(self, obj):
887 """Compute the format for an object."""
887 """Compute the format for an object."""
888 if self.enabled:
888 if self.enabled:
889 # lookup registered printer
889 # lookup registered printer
890 try:
890 try:
891 printer = self.lookup(obj)
891 printer = self.lookup(obj)
892 except KeyError:
892 except KeyError:
893 pass
893 pass
894 else:
894 else:
895 printer(obj)
895 printer(obj)
896 return True
896 return True
897 # Finally look for special method names
897 # Finally look for special method names
898 method = _safe_get_formatter_method(obj, self.print_method)
898 method = _safe_get_formatter_method(obj, self.print_method)
899 if method is not None:
899 if method is not None:
900 method()
900 method()
901 return True
901 return True
902
902
903
903
904 FormatterABC.register(BaseFormatter)
904 FormatterABC.register(BaseFormatter)
905 FormatterABC.register(PlainTextFormatter)
905 FormatterABC.register(PlainTextFormatter)
906 FormatterABC.register(HTMLFormatter)
906 FormatterABC.register(HTMLFormatter)
907 FormatterABC.register(MarkdownFormatter)
907 FormatterABC.register(MarkdownFormatter)
908 FormatterABC.register(SVGFormatter)
908 FormatterABC.register(SVGFormatter)
909 FormatterABC.register(PNGFormatter)
909 FormatterABC.register(PNGFormatter)
910 FormatterABC.register(PDFFormatter)
910 FormatterABC.register(PDFFormatter)
911 FormatterABC.register(JPEGFormatter)
911 FormatterABC.register(JPEGFormatter)
912 FormatterABC.register(LatexFormatter)
912 FormatterABC.register(LatexFormatter)
913 FormatterABC.register(JSONFormatter)
913 FormatterABC.register(JSONFormatter)
914 FormatterABC.register(JavascriptFormatter)
914 FormatterABC.register(JavascriptFormatter)
915 FormatterABC.register(IPythonDisplayFormatter)
915 FormatterABC.register(IPythonDisplayFormatter)
916
916
917
917
918 def format_display_data(obj, include=None, exclude=None):
918 def format_display_data(obj, include=None, exclude=None):
919 """Return a format data dict for an object.
919 """Return a format data dict for an object.
920
920
921 By default all format types will be computed.
921 By default all format types will be computed.
922
922
923 The following MIME types are currently implemented:
923 The following MIME types are currently implemented:
924
924
925 * text/plain
925 * text/plain
926 * text/html
926 * text/html
927 * text/markdown
927 * text/markdown
928 * text/latex
928 * text/latex
929 * application/json
929 * application/json
930 * application/javascript
930 * application/javascript
931 * application/pdf
931 * application/pdf
932 * image/png
932 * image/png
933 * image/jpeg
933 * image/jpeg
934 * image/svg+xml
934 * image/svg+xml
935
935
936 Parameters
936 Parameters
937 ----------
937 ----------
938 obj : object
938 obj : object
939 The Python object whose format data will be computed.
939 The Python object whose format data will be computed.
940
940
941 Returns
941 Returns
942 -------
942 -------
943 format_dict : dict
943 format_dict : dict
944 A dictionary of key/value pairs, one or each format that was
944 A dictionary of key/value pairs, one or each format that was
945 generated for the object. The keys are the format types, which
945 generated for the object. The keys are the format types, which
946 will usually be MIME type strings and the values and JSON'able
946 will usually be MIME type strings and the values and JSON'able
947 data structure containing the raw data for the representation in
947 data structure containing the raw data for the representation in
948 that format.
948 that format.
949 include : list or tuple, optional
949 include : list or tuple, optional
950 A list of format type strings (MIME types) to include in the
950 A list of format type strings (MIME types) to include in the
951 format data dict. If this is set *only* the format types included
951 format data dict. If this is set *only* the format types included
952 in this list will be computed.
952 in this list will be computed.
953 exclude : list or tuple, optional
953 exclude : list or tuple, optional
954 A list of format type string (MIME types) to exclue in the format
954 A list of format type string (MIME types) to exclue in the format
955 data dict. If this is set all format types will be computed,
955 data dict. If this is set all format types will be computed,
956 except for those included in this argument.
956 except for those included in this argument.
957 """
957 """
958 from IPython.core.interactiveshell import InteractiveShell
958 from IPython.core.interactiveshell import InteractiveShell
959
959
960 InteractiveShell.instance().display_formatter.format(
960 InteractiveShell.instance().display_formatter.format(
961 obj,
961 obj,
962 include,
962 include,
963 exclude
963 exclude
964 )
964 )
965
965
@@ -1,870 +1,870 b''
1 """ History related magics and functionality """
1 """ History related magics and functionality """
2 #-----------------------------------------------------------------------------
2 #-----------------------------------------------------------------------------
3 # Copyright (C) 2010-2011 The IPython Development Team.
3 # Copyright (C) 2010-2011 The IPython Development Team.
4 #
4 #
5 # Distributed under the terms of the BSD License.
5 # Distributed under the terms of the BSD License.
6 #
6 #
7 # The full license is in the file COPYING.txt, distributed with this software.
7 # The full license is in the file COPYING.txt, distributed with this software.
8 #-----------------------------------------------------------------------------
8 #-----------------------------------------------------------------------------
9
9
10 #-----------------------------------------------------------------------------
10 #-----------------------------------------------------------------------------
11 # Imports
11 # Imports
12 #-----------------------------------------------------------------------------
12 #-----------------------------------------------------------------------------
13 from __future__ import print_function
13 from __future__ import print_function
14
14
15 # Stdlib imports
15 # Stdlib imports
16 import atexit
16 import atexit
17 import datetime
17 import datetime
18 import os
18 import os
19 import re
19 import re
20 try:
20 try:
21 import sqlite3
21 import sqlite3
22 except ImportError:
22 except ImportError:
23 try:
23 try:
24 from pysqlite2 import dbapi2 as sqlite3
24 from pysqlite2 import dbapi2 as sqlite3
25 except ImportError:
25 except ImportError:
26 sqlite3 = None
26 sqlite3 = None
27 import threading
27 import threading
28
28
29 # Our own packages
29 # Our own packages
30 from IPython.config.configurable import Configurable
30 from IPython.config.configurable import Configurable
31 from IPython.external.decorator import decorator
31 from decorator import decorator
32 from IPython.utils.decorators import undoc
32 from IPython.utils.decorators import undoc
33 from IPython.utils.path import locate_profile
33 from IPython.utils.path import locate_profile
34 from IPython.utils import py3compat
34 from IPython.utils import py3compat
35 from IPython.utils.traitlets import (
35 from IPython.utils.traitlets import (
36 Any, Bool, Dict, Instance, Integer, List, Unicode, TraitError,
36 Any, Bool, Dict, Instance, Integer, List, Unicode, TraitError,
37 )
37 )
38 from IPython.utils.warn import warn
38 from IPython.utils.warn import warn
39
39
40 #-----------------------------------------------------------------------------
40 #-----------------------------------------------------------------------------
41 # Classes and functions
41 # Classes and functions
42 #-----------------------------------------------------------------------------
42 #-----------------------------------------------------------------------------
43
43
44 @undoc
44 @undoc
45 class DummyDB(object):
45 class DummyDB(object):
46 """Dummy DB that will act as a black hole for history.
46 """Dummy DB that will act as a black hole for history.
47
47
48 Only used in the absence of sqlite"""
48 Only used in the absence of sqlite"""
49 def execute(*args, **kwargs):
49 def execute(*args, **kwargs):
50 return []
50 return []
51
51
52 def commit(self, *args, **kwargs):
52 def commit(self, *args, **kwargs):
53 pass
53 pass
54
54
55 def __enter__(self, *args, **kwargs):
55 def __enter__(self, *args, **kwargs):
56 pass
56 pass
57
57
58 def __exit__(self, *args, **kwargs):
58 def __exit__(self, *args, **kwargs):
59 pass
59 pass
60
60
61
61
62 @decorator
62 @decorator
63 def needs_sqlite(f, self, *a, **kw):
63 def needs_sqlite(f, self, *a, **kw):
64 """Decorator: return an empty list in the absence of sqlite."""
64 """Decorator: return an empty list in the absence of sqlite."""
65 if sqlite3 is None or not self.enabled:
65 if sqlite3 is None or not self.enabled:
66 return []
66 return []
67 else:
67 else:
68 return f(self, *a, **kw)
68 return f(self, *a, **kw)
69
69
70
70
71 if sqlite3 is not None:
71 if sqlite3 is not None:
72 DatabaseError = sqlite3.DatabaseError
72 DatabaseError = sqlite3.DatabaseError
73 else:
73 else:
74 @undoc
74 @undoc
75 class DatabaseError(Exception):
75 class DatabaseError(Exception):
76 "Dummy exception when sqlite could not be imported. Should never occur."
76 "Dummy exception when sqlite could not be imported. Should never occur."
77
77
78 @decorator
78 @decorator
79 def catch_corrupt_db(f, self, *a, **kw):
79 def catch_corrupt_db(f, self, *a, **kw):
80 """A decorator which wraps HistoryAccessor method calls to catch errors from
80 """A decorator which wraps HistoryAccessor method calls to catch errors from
81 a corrupt SQLite database, move the old database out of the way, and create
81 a corrupt SQLite database, move the old database out of the way, and create
82 a new one.
82 a new one.
83 """
83 """
84 try:
84 try:
85 return f(self, *a, **kw)
85 return f(self, *a, **kw)
86 except DatabaseError:
86 except DatabaseError:
87 if os.path.isfile(self.hist_file):
87 if os.path.isfile(self.hist_file):
88 # Try to move the file out of the way
88 # Try to move the file out of the way
89 base,ext = os.path.splitext(self.hist_file)
89 base,ext = os.path.splitext(self.hist_file)
90 newpath = base + '-corrupt' + ext
90 newpath = base + '-corrupt' + ext
91 os.rename(self.hist_file, newpath)
91 os.rename(self.hist_file, newpath)
92 self.init_db()
92 self.init_db()
93 print("ERROR! History file wasn't a valid SQLite database.",
93 print("ERROR! History file wasn't a valid SQLite database.",
94 "It was moved to %s" % newpath, "and a new file created.")
94 "It was moved to %s" % newpath, "and a new file created.")
95 return []
95 return []
96
96
97 else:
97 else:
98 # The hist_file is probably :memory: or something else.
98 # The hist_file is probably :memory: or something else.
99 raise
99 raise
100
100
101 class HistoryAccessorBase(Configurable):
101 class HistoryAccessorBase(Configurable):
102 """An abstract class for History Accessors """
102 """An abstract class for History Accessors """
103
103
104 def get_tail(self, n=10, raw=True, output=False, include_latest=False):
104 def get_tail(self, n=10, raw=True, output=False, include_latest=False):
105 raise NotImplementedError
105 raise NotImplementedError
106
106
107 def search(self, pattern="*", raw=True, search_raw=True,
107 def search(self, pattern="*", raw=True, search_raw=True,
108 output=False, n=None, unique=False):
108 output=False, n=None, unique=False):
109 raise NotImplementedError
109 raise NotImplementedError
110
110
111 def get_range(self, session, start=1, stop=None, raw=True,output=False):
111 def get_range(self, session, start=1, stop=None, raw=True,output=False):
112 raise NotImplementedError
112 raise NotImplementedError
113
113
114 def get_range_by_str(self, rangestr, raw=True, output=False):
114 def get_range_by_str(self, rangestr, raw=True, output=False):
115 raise NotImplementedError
115 raise NotImplementedError
116
116
117
117
118 class HistoryAccessor(HistoryAccessorBase):
118 class HistoryAccessor(HistoryAccessorBase):
119 """Access the history database without adding to it.
119 """Access the history database without adding to it.
120
120
121 This is intended for use by standalone history tools. IPython shells use
121 This is intended for use by standalone history tools. IPython shells use
122 HistoryManager, below, which is a subclass of this."""
122 HistoryManager, below, which is a subclass of this."""
123
123
124 # String holding the path to the history file
124 # String holding the path to the history file
125 hist_file = Unicode(config=True,
125 hist_file = Unicode(config=True,
126 help="""Path to file to use for SQLite history database.
126 help="""Path to file to use for SQLite history database.
127
127
128 By default, IPython will put the history database in the IPython
128 By default, IPython will put the history database in the IPython
129 profile directory. If you would rather share one history among
129 profile directory. If you would rather share one history among
130 profiles, you can set this value in each, so that they are consistent.
130 profiles, you can set this value in each, so that they are consistent.
131
131
132 Due to an issue with fcntl, SQLite is known to misbehave on some NFS
132 Due to an issue with fcntl, SQLite is known to misbehave on some NFS
133 mounts. If you see IPython hanging, try setting this to something on a
133 mounts. If you see IPython hanging, try setting this to something on a
134 local disk, e.g::
134 local disk, e.g::
135
135
136 ipython --HistoryManager.hist_file=/tmp/ipython_hist.sqlite
136 ipython --HistoryManager.hist_file=/tmp/ipython_hist.sqlite
137
137
138 """)
138 """)
139
139
140 enabled = Bool(True, config=True,
140 enabled = Bool(True, config=True,
141 help="""enable the SQLite history
141 help="""enable the SQLite history
142
142
143 set enabled=False to disable the SQLite history,
143 set enabled=False to disable the SQLite history,
144 in which case there will be no stored history, no SQLite connection,
144 in which case there will be no stored history, no SQLite connection,
145 and no background saving thread. This may be necessary in some
145 and no background saving thread. This may be necessary in some
146 threaded environments where IPython is embedded.
146 threaded environments where IPython is embedded.
147 """
147 """
148 )
148 )
149
149
150 connection_options = Dict(config=True,
150 connection_options = Dict(config=True,
151 help="""Options for configuring the SQLite connection
151 help="""Options for configuring the SQLite connection
152
152
153 These options are passed as keyword args to sqlite3.connect
153 These options are passed as keyword args to sqlite3.connect
154 when establishing database conenctions.
154 when establishing database conenctions.
155 """
155 """
156 )
156 )
157
157
158 # The SQLite database
158 # The SQLite database
159 db = Any()
159 db = Any()
160 def _db_changed(self, name, old, new):
160 def _db_changed(self, name, old, new):
161 """validate the db, since it can be an Instance of two different types"""
161 """validate the db, since it can be an Instance of two different types"""
162 connection_types = (DummyDB,)
162 connection_types = (DummyDB,)
163 if sqlite3 is not None:
163 if sqlite3 is not None:
164 connection_types = (DummyDB, sqlite3.Connection)
164 connection_types = (DummyDB, sqlite3.Connection)
165 if not isinstance(new, connection_types):
165 if not isinstance(new, connection_types):
166 msg = "%s.db must be sqlite3 Connection or DummyDB, not %r" % \
166 msg = "%s.db must be sqlite3 Connection or DummyDB, not %r" % \
167 (self.__class__.__name__, new)
167 (self.__class__.__name__, new)
168 raise TraitError(msg)
168 raise TraitError(msg)
169
169
170 def __init__(self, profile='default', hist_file=u'', **traits):
170 def __init__(self, profile='default', hist_file=u'', **traits):
171 """Create a new history accessor.
171 """Create a new history accessor.
172
172
173 Parameters
173 Parameters
174 ----------
174 ----------
175 profile : str
175 profile : str
176 The name of the profile from which to open history.
176 The name of the profile from which to open history.
177 hist_file : str
177 hist_file : str
178 Path to an SQLite history database stored by IPython. If specified,
178 Path to an SQLite history database stored by IPython. If specified,
179 hist_file overrides profile.
179 hist_file overrides profile.
180 config : :class:`~IPython.config.loader.Config`
180 config : :class:`~IPython.config.loader.Config`
181 Config object. hist_file can also be set through this.
181 Config object. hist_file can also be set through this.
182 """
182 """
183 # We need a pointer back to the shell for various tasks.
183 # We need a pointer back to the shell for various tasks.
184 super(HistoryAccessor, self).__init__(**traits)
184 super(HistoryAccessor, self).__init__(**traits)
185 # defer setting hist_file from kwarg until after init,
185 # defer setting hist_file from kwarg until after init,
186 # otherwise the default kwarg value would clobber any value
186 # otherwise the default kwarg value would clobber any value
187 # set by config
187 # set by config
188 if hist_file:
188 if hist_file:
189 self.hist_file = hist_file
189 self.hist_file = hist_file
190
190
191 if self.hist_file == u'':
191 if self.hist_file == u'':
192 # No one has set the hist_file, yet.
192 # No one has set the hist_file, yet.
193 self.hist_file = self._get_hist_file_name(profile)
193 self.hist_file = self._get_hist_file_name(profile)
194
194
195 if sqlite3 is None and self.enabled:
195 if sqlite3 is None and self.enabled:
196 warn("IPython History requires SQLite, your history will not be saved")
196 warn("IPython History requires SQLite, your history will not be saved")
197 self.enabled = False
197 self.enabled = False
198
198
199 self.init_db()
199 self.init_db()
200
200
201 def _get_hist_file_name(self, profile='default'):
201 def _get_hist_file_name(self, profile='default'):
202 """Find the history file for the given profile name.
202 """Find the history file for the given profile name.
203
203
204 This is overridden by the HistoryManager subclass, to use the shell's
204 This is overridden by the HistoryManager subclass, to use the shell's
205 active profile.
205 active profile.
206
206
207 Parameters
207 Parameters
208 ----------
208 ----------
209 profile : str
209 profile : str
210 The name of a profile which has a history file.
210 The name of a profile which has a history file.
211 """
211 """
212 return os.path.join(locate_profile(profile), 'history.sqlite')
212 return os.path.join(locate_profile(profile), 'history.sqlite')
213
213
214 @catch_corrupt_db
214 @catch_corrupt_db
215 def init_db(self):
215 def init_db(self):
216 """Connect to the database, and create tables if necessary."""
216 """Connect to the database, and create tables if necessary."""
217 if not self.enabled:
217 if not self.enabled:
218 self.db = DummyDB()
218 self.db = DummyDB()
219 return
219 return
220
220
221 # use detect_types so that timestamps return datetime objects
221 # use detect_types so that timestamps return datetime objects
222 kwargs = dict(detect_types=sqlite3.PARSE_DECLTYPES|sqlite3.PARSE_COLNAMES)
222 kwargs = dict(detect_types=sqlite3.PARSE_DECLTYPES|sqlite3.PARSE_COLNAMES)
223 kwargs.update(self.connection_options)
223 kwargs.update(self.connection_options)
224 self.db = sqlite3.connect(self.hist_file, **kwargs)
224 self.db = sqlite3.connect(self.hist_file, **kwargs)
225 self.db.execute("""CREATE TABLE IF NOT EXISTS sessions (session integer
225 self.db.execute("""CREATE TABLE IF NOT EXISTS sessions (session integer
226 primary key autoincrement, start timestamp,
226 primary key autoincrement, start timestamp,
227 end timestamp, num_cmds integer, remark text)""")
227 end timestamp, num_cmds integer, remark text)""")
228 self.db.execute("""CREATE TABLE IF NOT EXISTS history
228 self.db.execute("""CREATE TABLE IF NOT EXISTS history
229 (session integer, line integer, source text, source_raw text,
229 (session integer, line integer, source text, source_raw text,
230 PRIMARY KEY (session, line))""")
230 PRIMARY KEY (session, line))""")
231 # Output history is optional, but ensure the table's there so it can be
231 # Output history is optional, but ensure the table's there so it can be
232 # enabled later.
232 # enabled later.
233 self.db.execute("""CREATE TABLE IF NOT EXISTS output_history
233 self.db.execute("""CREATE TABLE IF NOT EXISTS output_history
234 (session integer, line integer, output text,
234 (session integer, line integer, output text,
235 PRIMARY KEY (session, line))""")
235 PRIMARY KEY (session, line))""")
236 self.db.commit()
236 self.db.commit()
237
237
238 def writeout_cache(self):
238 def writeout_cache(self):
239 """Overridden by HistoryManager to dump the cache before certain
239 """Overridden by HistoryManager to dump the cache before certain
240 database lookups."""
240 database lookups."""
241 pass
241 pass
242
242
243 ## -------------------------------
243 ## -------------------------------
244 ## Methods for retrieving history:
244 ## Methods for retrieving history:
245 ## -------------------------------
245 ## -------------------------------
246 def _run_sql(self, sql, params, raw=True, output=False):
246 def _run_sql(self, sql, params, raw=True, output=False):
247 """Prepares and runs an SQL query for the history database.
247 """Prepares and runs an SQL query for the history database.
248
248
249 Parameters
249 Parameters
250 ----------
250 ----------
251 sql : str
251 sql : str
252 Any filtering expressions to go after SELECT ... FROM ...
252 Any filtering expressions to go after SELECT ... FROM ...
253 params : tuple
253 params : tuple
254 Parameters passed to the SQL query (to replace "?")
254 Parameters passed to the SQL query (to replace "?")
255 raw, output : bool
255 raw, output : bool
256 See :meth:`get_range`
256 See :meth:`get_range`
257
257
258 Returns
258 Returns
259 -------
259 -------
260 Tuples as :meth:`get_range`
260 Tuples as :meth:`get_range`
261 """
261 """
262 toget = 'source_raw' if raw else 'source'
262 toget = 'source_raw' if raw else 'source'
263 sqlfrom = "history"
263 sqlfrom = "history"
264 if output:
264 if output:
265 sqlfrom = "history LEFT JOIN output_history USING (session, line)"
265 sqlfrom = "history LEFT JOIN output_history USING (session, line)"
266 toget = "history.%s, output_history.output" % toget
266 toget = "history.%s, output_history.output" % toget
267 cur = self.db.execute("SELECT session, line, %s FROM %s " %\
267 cur = self.db.execute("SELECT session, line, %s FROM %s " %\
268 (toget, sqlfrom) + sql, params)
268 (toget, sqlfrom) + sql, params)
269 if output: # Regroup into 3-tuples, and parse JSON
269 if output: # Regroup into 3-tuples, and parse JSON
270 return ((ses, lin, (inp, out)) for ses, lin, inp, out in cur)
270 return ((ses, lin, (inp, out)) for ses, lin, inp, out in cur)
271 return cur
271 return cur
272
272
273 @needs_sqlite
273 @needs_sqlite
274 @catch_corrupt_db
274 @catch_corrupt_db
275 def get_session_info(self, session):
275 def get_session_info(self, session):
276 """Get info about a session.
276 """Get info about a session.
277
277
278 Parameters
278 Parameters
279 ----------
279 ----------
280
280
281 session : int
281 session : int
282 Session number to retrieve.
282 Session number to retrieve.
283
283
284 Returns
284 Returns
285 -------
285 -------
286
286
287 session_id : int
287 session_id : int
288 Session ID number
288 Session ID number
289 start : datetime
289 start : datetime
290 Timestamp for the start of the session.
290 Timestamp for the start of the session.
291 end : datetime
291 end : datetime
292 Timestamp for the end of the session, or None if IPython crashed.
292 Timestamp for the end of the session, or None if IPython crashed.
293 num_cmds : int
293 num_cmds : int
294 Number of commands run, or None if IPython crashed.
294 Number of commands run, or None if IPython crashed.
295 remark : unicode
295 remark : unicode
296 A manually set description.
296 A manually set description.
297 """
297 """
298 query = "SELECT * from sessions where session == ?"
298 query = "SELECT * from sessions where session == ?"
299 return self.db.execute(query, (session,)).fetchone()
299 return self.db.execute(query, (session,)).fetchone()
300
300
301 @catch_corrupt_db
301 @catch_corrupt_db
302 def get_last_session_id(self):
302 def get_last_session_id(self):
303 """Get the last session ID currently in the database.
303 """Get the last session ID currently in the database.
304
304
305 Within IPython, this should be the same as the value stored in
305 Within IPython, this should be the same as the value stored in
306 :attr:`HistoryManager.session_number`.
306 :attr:`HistoryManager.session_number`.
307 """
307 """
308 for record in self.get_tail(n=1, include_latest=True):
308 for record in self.get_tail(n=1, include_latest=True):
309 return record[0]
309 return record[0]
310
310
311 @catch_corrupt_db
311 @catch_corrupt_db
312 def get_tail(self, n=10, raw=True, output=False, include_latest=False):
312 def get_tail(self, n=10, raw=True, output=False, include_latest=False):
313 """Get the last n lines from the history database.
313 """Get the last n lines from the history database.
314
314
315 Parameters
315 Parameters
316 ----------
316 ----------
317 n : int
317 n : int
318 The number of lines to get
318 The number of lines to get
319 raw, output : bool
319 raw, output : bool
320 See :meth:`get_range`
320 See :meth:`get_range`
321 include_latest : bool
321 include_latest : bool
322 If False (default), n+1 lines are fetched, and the latest one
322 If False (default), n+1 lines are fetched, and the latest one
323 is discarded. This is intended to be used where the function
323 is discarded. This is intended to be used where the function
324 is called by a user command, which it should not return.
324 is called by a user command, which it should not return.
325
325
326 Returns
326 Returns
327 -------
327 -------
328 Tuples as :meth:`get_range`
328 Tuples as :meth:`get_range`
329 """
329 """
330 self.writeout_cache()
330 self.writeout_cache()
331 if not include_latest:
331 if not include_latest:
332 n += 1
332 n += 1
333 cur = self._run_sql("ORDER BY session DESC, line DESC LIMIT ?",
333 cur = self._run_sql("ORDER BY session DESC, line DESC LIMIT ?",
334 (n,), raw=raw, output=output)
334 (n,), raw=raw, output=output)
335 if not include_latest:
335 if not include_latest:
336 return reversed(list(cur)[1:])
336 return reversed(list(cur)[1:])
337 return reversed(list(cur))
337 return reversed(list(cur))
338
338
339 @catch_corrupt_db
339 @catch_corrupt_db
340 def search(self, pattern="*", raw=True, search_raw=True,
340 def search(self, pattern="*", raw=True, search_raw=True,
341 output=False, n=None, unique=False):
341 output=False, n=None, unique=False):
342 """Search the database using unix glob-style matching (wildcards
342 """Search the database using unix glob-style matching (wildcards
343 * and ?).
343 * and ?).
344
344
345 Parameters
345 Parameters
346 ----------
346 ----------
347 pattern : str
347 pattern : str
348 The wildcarded pattern to match when searching
348 The wildcarded pattern to match when searching
349 search_raw : bool
349 search_raw : bool
350 If True, search the raw input, otherwise, the parsed input
350 If True, search the raw input, otherwise, the parsed input
351 raw, output : bool
351 raw, output : bool
352 See :meth:`get_range`
352 See :meth:`get_range`
353 n : None or int
353 n : None or int
354 If an integer is given, it defines the limit of
354 If an integer is given, it defines the limit of
355 returned entries.
355 returned entries.
356 unique : bool
356 unique : bool
357 When it is true, return only unique entries.
357 When it is true, return only unique entries.
358
358
359 Returns
359 Returns
360 -------
360 -------
361 Tuples as :meth:`get_range`
361 Tuples as :meth:`get_range`
362 """
362 """
363 tosearch = "source_raw" if search_raw else "source"
363 tosearch = "source_raw" if search_raw else "source"
364 if output:
364 if output:
365 tosearch = "history." + tosearch
365 tosearch = "history." + tosearch
366 self.writeout_cache()
366 self.writeout_cache()
367 sqlform = "WHERE %s GLOB ?" % tosearch
367 sqlform = "WHERE %s GLOB ?" % tosearch
368 params = (pattern,)
368 params = (pattern,)
369 if unique:
369 if unique:
370 sqlform += ' GROUP BY {0}'.format(tosearch)
370 sqlform += ' GROUP BY {0}'.format(tosearch)
371 if n is not None:
371 if n is not None:
372 sqlform += " ORDER BY session DESC, line DESC LIMIT ?"
372 sqlform += " ORDER BY session DESC, line DESC LIMIT ?"
373 params += (n,)
373 params += (n,)
374 elif unique:
374 elif unique:
375 sqlform += " ORDER BY session, line"
375 sqlform += " ORDER BY session, line"
376 cur = self._run_sql(sqlform, params, raw=raw, output=output)
376 cur = self._run_sql(sqlform, params, raw=raw, output=output)
377 if n is not None:
377 if n is not None:
378 return reversed(list(cur))
378 return reversed(list(cur))
379 return cur
379 return cur
380
380
381 @catch_corrupt_db
381 @catch_corrupt_db
382 def get_range(self, session, start=1, stop=None, raw=True,output=False):
382 def get_range(self, session, start=1, stop=None, raw=True,output=False):
383 """Retrieve input by session.
383 """Retrieve input by session.
384
384
385 Parameters
385 Parameters
386 ----------
386 ----------
387 session : int
387 session : int
388 Session number to retrieve.
388 Session number to retrieve.
389 start : int
389 start : int
390 First line to retrieve.
390 First line to retrieve.
391 stop : int
391 stop : int
392 End of line range (excluded from output itself). If None, retrieve
392 End of line range (excluded from output itself). If None, retrieve
393 to the end of the session.
393 to the end of the session.
394 raw : bool
394 raw : bool
395 If True, return untranslated input
395 If True, return untranslated input
396 output : bool
396 output : bool
397 If True, attempt to include output. This will be 'real' Python
397 If True, attempt to include output. This will be 'real' Python
398 objects for the current session, or text reprs from previous
398 objects for the current session, or text reprs from previous
399 sessions if db_log_output was enabled at the time. Where no output
399 sessions if db_log_output was enabled at the time. Where no output
400 is found, None is used.
400 is found, None is used.
401
401
402 Returns
402 Returns
403 -------
403 -------
404 entries
404 entries
405 An iterator over the desired lines. Each line is a 3-tuple, either
405 An iterator over the desired lines. Each line is a 3-tuple, either
406 (session, line, input) if output is False, or
406 (session, line, input) if output is False, or
407 (session, line, (input, output)) if output is True.
407 (session, line, (input, output)) if output is True.
408 """
408 """
409 if stop:
409 if stop:
410 lineclause = "line >= ? AND line < ?"
410 lineclause = "line >= ? AND line < ?"
411 params = (session, start, stop)
411 params = (session, start, stop)
412 else:
412 else:
413 lineclause = "line>=?"
413 lineclause = "line>=?"
414 params = (session, start)
414 params = (session, start)
415
415
416 return self._run_sql("WHERE session==? AND %s" % lineclause,
416 return self._run_sql("WHERE session==? AND %s" % lineclause,
417 params, raw=raw, output=output)
417 params, raw=raw, output=output)
418
418
419 def get_range_by_str(self, rangestr, raw=True, output=False):
419 def get_range_by_str(self, rangestr, raw=True, output=False):
420 """Get lines of history from a string of ranges, as used by magic
420 """Get lines of history from a string of ranges, as used by magic
421 commands %hist, %save, %macro, etc.
421 commands %hist, %save, %macro, etc.
422
422
423 Parameters
423 Parameters
424 ----------
424 ----------
425 rangestr : str
425 rangestr : str
426 A string specifying ranges, e.g. "5 ~2/1-4". See
426 A string specifying ranges, e.g. "5 ~2/1-4". See
427 :func:`magic_history` for full details.
427 :func:`magic_history` for full details.
428 raw, output : bool
428 raw, output : bool
429 As :meth:`get_range`
429 As :meth:`get_range`
430
430
431 Returns
431 Returns
432 -------
432 -------
433 Tuples as :meth:`get_range`
433 Tuples as :meth:`get_range`
434 """
434 """
435 for sess, s, e in extract_hist_ranges(rangestr):
435 for sess, s, e in extract_hist_ranges(rangestr):
436 for line in self.get_range(sess, s, e, raw=raw, output=output):
436 for line in self.get_range(sess, s, e, raw=raw, output=output):
437 yield line
437 yield line
438
438
439
439
440 class HistoryManager(HistoryAccessor):
440 class HistoryManager(HistoryAccessor):
441 """A class to organize all history-related functionality in one place.
441 """A class to organize all history-related functionality in one place.
442 """
442 """
443 # Public interface
443 # Public interface
444
444
445 # An instance of the IPython shell we are attached to
445 # An instance of the IPython shell we are attached to
446 shell = Instance('IPython.core.interactiveshell.InteractiveShellABC')
446 shell = Instance('IPython.core.interactiveshell.InteractiveShellABC')
447 # Lists to hold processed and raw history. These start with a blank entry
447 # Lists to hold processed and raw history. These start with a blank entry
448 # so that we can index them starting from 1
448 # so that we can index them starting from 1
449 input_hist_parsed = List([""])
449 input_hist_parsed = List([""])
450 input_hist_raw = List([""])
450 input_hist_raw = List([""])
451 # A list of directories visited during session
451 # A list of directories visited during session
452 dir_hist = List()
452 dir_hist = List()
453 def _dir_hist_default(self):
453 def _dir_hist_default(self):
454 try:
454 try:
455 return [py3compat.getcwd()]
455 return [py3compat.getcwd()]
456 except OSError:
456 except OSError:
457 return []
457 return []
458
458
459 # A dict of output history, keyed with ints from the shell's
459 # A dict of output history, keyed with ints from the shell's
460 # execution count.
460 # execution count.
461 output_hist = Dict()
461 output_hist = Dict()
462 # The text/plain repr of outputs.
462 # The text/plain repr of outputs.
463 output_hist_reprs = Dict()
463 output_hist_reprs = Dict()
464
464
465 # The number of the current session in the history database
465 # The number of the current session in the history database
466 session_number = Integer()
466 session_number = Integer()
467
467
468 db_log_output = Bool(False, config=True,
468 db_log_output = Bool(False, config=True,
469 help="Should the history database include output? (default: no)"
469 help="Should the history database include output? (default: no)"
470 )
470 )
471 db_cache_size = Integer(0, config=True,
471 db_cache_size = Integer(0, config=True,
472 help="Write to database every x commands (higher values save disk access & power).\n"
472 help="Write to database every x commands (higher values save disk access & power).\n"
473 "Values of 1 or less effectively disable caching."
473 "Values of 1 or less effectively disable caching."
474 )
474 )
475 # The input and output caches
475 # The input and output caches
476 db_input_cache = List()
476 db_input_cache = List()
477 db_output_cache = List()
477 db_output_cache = List()
478
478
479 # History saving in separate thread
479 # History saving in separate thread
480 save_thread = Instance('IPython.core.history.HistorySavingThread')
480 save_thread = Instance('IPython.core.history.HistorySavingThread')
481 try: # Event is a function returning an instance of _Event...
481 try: # Event is a function returning an instance of _Event...
482 save_flag = Instance(threading._Event)
482 save_flag = Instance(threading._Event)
483 except AttributeError: # ...until Python 3.3, when it's a class.
483 except AttributeError: # ...until Python 3.3, when it's a class.
484 save_flag = Instance(threading.Event)
484 save_flag = Instance(threading.Event)
485
485
486 # Private interface
486 # Private interface
487 # Variables used to store the three last inputs from the user. On each new
487 # Variables used to store the three last inputs from the user. On each new
488 # history update, we populate the user's namespace with these, shifted as
488 # history update, we populate the user's namespace with these, shifted as
489 # necessary.
489 # necessary.
490 _i00 = Unicode(u'')
490 _i00 = Unicode(u'')
491 _i = Unicode(u'')
491 _i = Unicode(u'')
492 _ii = Unicode(u'')
492 _ii = Unicode(u'')
493 _iii = Unicode(u'')
493 _iii = Unicode(u'')
494
494
495 # A regex matching all forms of the exit command, so that we don't store
495 # A regex matching all forms of the exit command, so that we don't store
496 # them in the history (it's annoying to rewind the first entry and land on
496 # them in the history (it's annoying to rewind the first entry and land on
497 # an exit call).
497 # an exit call).
498 _exit_re = re.compile(r"(exit|quit)(\s*\(.*\))?$")
498 _exit_re = re.compile(r"(exit|quit)(\s*\(.*\))?$")
499
499
500 def __init__(self, shell=None, config=None, **traits):
500 def __init__(self, shell=None, config=None, **traits):
501 """Create a new history manager associated with a shell instance.
501 """Create a new history manager associated with a shell instance.
502 """
502 """
503 # We need a pointer back to the shell for various tasks.
503 # We need a pointer back to the shell for various tasks.
504 super(HistoryManager, self).__init__(shell=shell, config=config,
504 super(HistoryManager, self).__init__(shell=shell, config=config,
505 **traits)
505 **traits)
506 self.save_flag = threading.Event()
506 self.save_flag = threading.Event()
507 self.db_input_cache_lock = threading.Lock()
507 self.db_input_cache_lock = threading.Lock()
508 self.db_output_cache_lock = threading.Lock()
508 self.db_output_cache_lock = threading.Lock()
509 if self.enabled and self.hist_file != ':memory:':
509 if self.enabled and self.hist_file != ':memory:':
510 self.save_thread = HistorySavingThread(self)
510 self.save_thread = HistorySavingThread(self)
511 self.save_thread.start()
511 self.save_thread.start()
512
512
513 self.new_session()
513 self.new_session()
514
514
515 def _get_hist_file_name(self, profile=None):
515 def _get_hist_file_name(self, profile=None):
516 """Get default history file name based on the Shell's profile.
516 """Get default history file name based on the Shell's profile.
517
517
518 The profile parameter is ignored, but must exist for compatibility with
518 The profile parameter is ignored, but must exist for compatibility with
519 the parent class."""
519 the parent class."""
520 profile_dir = self.shell.profile_dir.location
520 profile_dir = self.shell.profile_dir.location
521 return os.path.join(profile_dir, 'history.sqlite')
521 return os.path.join(profile_dir, 'history.sqlite')
522
522
523 @needs_sqlite
523 @needs_sqlite
524 def new_session(self, conn=None):
524 def new_session(self, conn=None):
525 """Get a new session number."""
525 """Get a new session number."""
526 if conn is None:
526 if conn is None:
527 conn = self.db
527 conn = self.db
528
528
529 with conn:
529 with conn:
530 cur = conn.execute("""INSERT INTO sessions VALUES (NULL, ?, NULL,
530 cur = conn.execute("""INSERT INTO sessions VALUES (NULL, ?, NULL,
531 NULL, "") """, (datetime.datetime.now(),))
531 NULL, "") """, (datetime.datetime.now(),))
532 self.session_number = cur.lastrowid
532 self.session_number = cur.lastrowid
533
533
534 def end_session(self):
534 def end_session(self):
535 """Close the database session, filling in the end time and line count."""
535 """Close the database session, filling in the end time and line count."""
536 self.writeout_cache()
536 self.writeout_cache()
537 with self.db:
537 with self.db:
538 self.db.execute("""UPDATE sessions SET end=?, num_cmds=? WHERE
538 self.db.execute("""UPDATE sessions SET end=?, num_cmds=? WHERE
539 session==?""", (datetime.datetime.now(),
539 session==?""", (datetime.datetime.now(),
540 len(self.input_hist_parsed)-1, self.session_number))
540 len(self.input_hist_parsed)-1, self.session_number))
541 self.session_number = 0
541 self.session_number = 0
542
542
543 def name_session(self, name):
543 def name_session(self, name):
544 """Give the current session a name in the history database."""
544 """Give the current session a name in the history database."""
545 with self.db:
545 with self.db:
546 self.db.execute("UPDATE sessions SET remark=? WHERE session==?",
546 self.db.execute("UPDATE sessions SET remark=? WHERE session==?",
547 (name, self.session_number))
547 (name, self.session_number))
548
548
549 def reset(self, new_session=True):
549 def reset(self, new_session=True):
550 """Clear the session history, releasing all object references, and
550 """Clear the session history, releasing all object references, and
551 optionally open a new session."""
551 optionally open a new session."""
552 self.output_hist.clear()
552 self.output_hist.clear()
553 # The directory history can't be completely empty
553 # The directory history can't be completely empty
554 self.dir_hist[:] = [py3compat.getcwd()]
554 self.dir_hist[:] = [py3compat.getcwd()]
555
555
556 if new_session:
556 if new_session:
557 if self.session_number:
557 if self.session_number:
558 self.end_session()
558 self.end_session()
559 self.input_hist_parsed[:] = [""]
559 self.input_hist_parsed[:] = [""]
560 self.input_hist_raw[:] = [""]
560 self.input_hist_raw[:] = [""]
561 self.new_session()
561 self.new_session()
562
562
563 # ------------------------------
563 # ------------------------------
564 # Methods for retrieving history
564 # Methods for retrieving history
565 # ------------------------------
565 # ------------------------------
566 def get_session_info(self, session=0):
566 def get_session_info(self, session=0):
567 """Get info about a session.
567 """Get info about a session.
568
568
569 Parameters
569 Parameters
570 ----------
570 ----------
571
571
572 session : int
572 session : int
573 Session number to retrieve. The current session is 0, and negative
573 Session number to retrieve. The current session is 0, and negative
574 numbers count back from current session, so -1 is the previous session.
574 numbers count back from current session, so -1 is the previous session.
575
575
576 Returns
576 Returns
577 -------
577 -------
578
578
579 session_id : int
579 session_id : int
580 Session ID number
580 Session ID number
581 start : datetime
581 start : datetime
582 Timestamp for the start of the session.
582 Timestamp for the start of the session.
583 end : datetime
583 end : datetime
584 Timestamp for the end of the session, or None if IPython crashed.
584 Timestamp for the end of the session, or None if IPython crashed.
585 num_cmds : int
585 num_cmds : int
586 Number of commands run, or None if IPython crashed.
586 Number of commands run, or None if IPython crashed.
587 remark : unicode
587 remark : unicode
588 A manually set description.
588 A manually set description.
589 """
589 """
590 if session <= 0:
590 if session <= 0:
591 session += self.session_number
591 session += self.session_number
592
592
593 return super(HistoryManager, self).get_session_info(session=session)
593 return super(HistoryManager, self).get_session_info(session=session)
594
594
595 def _get_range_session(self, start=1, stop=None, raw=True, output=False):
595 def _get_range_session(self, start=1, stop=None, raw=True, output=False):
596 """Get input and output history from the current session. Called by
596 """Get input and output history from the current session. Called by
597 get_range, and takes similar parameters."""
597 get_range, and takes similar parameters."""
598 input_hist = self.input_hist_raw if raw else self.input_hist_parsed
598 input_hist = self.input_hist_raw if raw else self.input_hist_parsed
599
599
600 n = len(input_hist)
600 n = len(input_hist)
601 if start < 0:
601 if start < 0:
602 start += n
602 start += n
603 if not stop or (stop > n):
603 if not stop or (stop > n):
604 stop = n
604 stop = n
605 elif stop < 0:
605 elif stop < 0:
606 stop += n
606 stop += n
607
607
608 for i in range(start, stop):
608 for i in range(start, stop):
609 if output:
609 if output:
610 line = (input_hist[i], self.output_hist_reprs.get(i))
610 line = (input_hist[i], self.output_hist_reprs.get(i))
611 else:
611 else:
612 line = input_hist[i]
612 line = input_hist[i]
613 yield (0, i, line)
613 yield (0, i, line)
614
614
615 def get_range(self, session=0, start=1, stop=None, raw=True,output=False):
615 def get_range(self, session=0, start=1, stop=None, raw=True,output=False):
616 """Retrieve input by session.
616 """Retrieve input by session.
617
617
618 Parameters
618 Parameters
619 ----------
619 ----------
620 session : int
620 session : int
621 Session number to retrieve. The current session is 0, and negative
621 Session number to retrieve. The current session is 0, and negative
622 numbers count back from current session, so -1 is previous session.
622 numbers count back from current session, so -1 is previous session.
623 start : int
623 start : int
624 First line to retrieve.
624 First line to retrieve.
625 stop : int
625 stop : int
626 End of line range (excluded from output itself). If None, retrieve
626 End of line range (excluded from output itself). If None, retrieve
627 to the end of the session.
627 to the end of the session.
628 raw : bool
628 raw : bool
629 If True, return untranslated input
629 If True, return untranslated input
630 output : bool
630 output : bool
631 If True, attempt to include output. This will be 'real' Python
631 If True, attempt to include output. This will be 'real' Python
632 objects for the current session, or text reprs from previous
632 objects for the current session, or text reprs from previous
633 sessions if db_log_output was enabled at the time. Where no output
633 sessions if db_log_output was enabled at the time. Where no output
634 is found, None is used.
634 is found, None is used.
635
635
636 Returns
636 Returns
637 -------
637 -------
638 entries
638 entries
639 An iterator over the desired lines. Each line is a 3-tuple, either
639 An iterator over the desired lines. Each line is a 3-tuple, either
640 (session, line, input) if output is False, or
640 (session, line, input) if output is False, or
641 (session, line, (input, output)) if output is True.
641 (session, line, (input, output)) if output is True.
642 """
642 """
643 if session <= 0:
643 if session <= 0:
644 session += self.session_number
644 session += self.session_number
645 if session==self.session_number: # Current session
645 if session==self.session_number: # Current session
646 return self._get_range_session(start, stop, raw, output)
646 return self._get_range_session(start, stop, raw, output)
647 return super(HistoryManager, self).get_range(session, start, stop, raw,
647 return super(HistoryManager, self).get_range(session, start, stop, raw,
648 output)
648 output)
649
649
650 ## ----------------------------
650 ## ----------------------------
651 ## Methods for storing history:
651 ## Methods for storing history:
652 ## ----------------------------
652 ## ----------------------------
653 def store_inputs(self, line_num, source, source_raw=None):
653 def store_inputs(self, line_num, source, source_raw=None):
654 """Store source and raw input in history and create input cache
654 """Store source and raw input in history and create input cache
655 variables ``_i*``.
655 variables ``_i*``.
656
656
657 Parameters
657 Parameters
658 ----------
658 ----------
659 line_num : int
659 line_num : int
660 The prompt number of this input.
660 The prompt number of this input.
661
661
662 source : str
662 source : str
663 Python input.
663 Python input.
664
664
665 source_raw : str, optional
665 source_raw : str, optional
666 If given, this is the raw input without any IPython transformations
666 If given, this is the raw input without any IPython transformations
667 applied to it. If not given, ``source`` is used.
667 applied to it. If not given, ``source`` is used.
668 """
668 """
669 if source_raw is None:
669 if source_raw is None:
670 source_raw = source
670 source_raw = source
671 source = source.rstrip('\n')
671 source = source.rstrip('\n')
672 source_raw = source_raw.rstrip('\n')
672 source_raw = source_raw.rstrip('\n')
673
673
674 # do not store exit/quit commands
674 # do not store exit/quit commands
675 if self._exit_re.match(source_raw.strip()):
675 if self._exit_re.match(source_raw.strip()):
676 return
676 return
677
677
678 self.input_hist_parsed.append(source)
678 self.input_hist_parsed.append(source)
679 self.input_hist_raw.append(source_raw)
679 self.input_hist_raw.append(source_raw)
680
680
681 with self.db_input_cache_lock:
681 with self.db_input_cache_lock:
682 self.db_input_cache.append((line_num, source, source_raw))
682 self.db_input_cache.append((line_num, source, source_raw))
683 # Trigger to flush cache and write to DB.
683 # Trigger to flush cache and write to DB.
684 if len(self.db_input_cache) >= self.db_cache_size:
684 if len(self.db_input_cache) >= self.db_cache_size:
685 self.save_flag.set()
685 self.save_flag.set()
686
686
687 # update the auto _i variables
687 # update the auto _i variables
688 self._iii = self._ii
688 self._iii = self._ii
689 self._ii = self._i
689 self._ii = self._i
690 self._i = self._i00
690 self._i = self._i00
691 self._i00 = source_raw
691 self._i00 = source_raw
692
692
693 # hackish access to user namespace to create _i1,_i2... dynamically
693 # hackish access to user namespace to create _i1,_i2... dynamically
694 new_i = '_i%s' % line_num
694 new_i = '_i%s' % line_num
695 to_main = {'_i': self._i,
695 to_main = {'_i': self._i,
696 '_ii': self._ii,
696 '_ii': self._ii,
697 '_iii': self._iii,
697 '_iii': self._iii,
698 new_i : self._i00 }
698 new_i : self._i00 }
699
699
700 if self.shell is not None:
700 if self.shell is not None:
701 self.shell.push(to_main, interactive=False)
701 self.shell.push(to_main, interactive=False)
702
702
703 def store_output(self, line_num):
703 def store_output(self, line_num):
704 """If database output logging is enabled, this saves all the
704 """If database output logging is enabled, this saves all the
705 outputs from the indicated prompt number to the database. It's
705 outputs from the indicated prompt number to the database. It's
706 called by run_cell after code has been executed.
706 called by run_cell after code has been executed.
707
707
708 Parameters
708 Parameters
709 ----------
709 ----------
710 line_num : int
710 line_num : int
711 The line number from which to save outputs
711 The line number from which to save outputs
712 """
712 """
713 if (not self.db_log_output) or (line_num not in self.output_hist_reprs):
713 if (not self.db_log_output) or (line_num not in self.output_hist_reprs):
714 return
714 return
715 output = self.output_hist_reprs[line_num]
715 output = self.output_hist_reprs[line_num]
716
716
717 with self.db_output_cache_lock:
717 with self.db_output_cache_lock:
718 self.db_output_cache.append((line_num, output))
718 self.db_output_cache.append((line_num, output))
719 if self.db_cache_size <= 1:
719 if self.db_cache_size <= 1:
720 self.save_flag.set()
720 self.save_flag.set()
721
721
722 def _writeout_input_cache(self, conn):
722 def _writeout_input_cache(self, conn):
723 with conn:
723 with conn:
724 for line in self.db_input_cache:
724 for line in self.db_input_cache:
725 conn.execute("INSERT INTO history VALUES (?, ?, ?, ?)",
725 conn.execute("INSERT INTO history VALUES (?, ?, ?, ?)",
726 (self.session_number,)+line)
726 (self.session_number,)+line)
727
727
728 def _writeout_output_cache(self, conn):
728 def _writeout_output_cache(self, conn):
729 with conn:
729 with conn:
730 for line in self.db_output_cache:
730 for line in self.db_output_cache:
731 conn.execute("INSERT INTO output_history VALUES (?, ?, ?)",
731 conn.execute("INSERT INTO output_history VALUES (?, ?, ?)",
732 (self.session_number,)+line)
732 (self.session_number,)+line)
733
733
734 @needs_sqlite
734 @needs_sqlite
735 def writeout_cache(self, conn=None):
735 def writeout_cache(self, conn=None):
736 """Write any entries in the cache to the database."""
736 """Write any entries in the cache to the database."""
737 if conn is None:
737 if conn is None:
738 conn = self.db
738 conn = self.db
739
739
740 with self.db_input_cache_lock:
740 with self.db_input_cache_lock:
741 try:
741 try:
742 self._writeout_input_cache(conn)
742 self._writeout_input_cache(conn)
743 except sqlite3.IntegrityError:
743 except sqlite3.IntegrityError:
744 self.new_session(conn)
744 self.new_session(conn)
745 print("ERROR! Session/line number was not unique in",
745 print("ERROR! Session/line number was not unique in",
746 "database. History logging moved to new session",
746 "database. History logging moved to new session",
747 self.session_number)
747 self.session_number)
748 try:
748 try:
749 # Try writing to the new session. If this fails, don't
749 # Try writing to the new session. If this fails, don't
750 # recurse
750 # recurse
751 self._writeout_input_cache(conn)
751 self._writeout_input_cache(conn)
752 except sqlite3.IntegrityError:
752 except sqlite3.IntegrityError:
753 pass
753 pass
754 finally:
754 finally:
755 self.db_input_cache = []
755 self.db_input_cache = []
756
756
757 with self.db_output_cache_lock:
757 with self.db_output_cache_lock:
758 try:
758 try:
759 self._writeout_output_cache(conn)
759 self._writeout_output_cache(conn)
760 except sqlite3.IntegrityError:
760 except sqlite3.IntegrityError:
761 print("!! Session/line number for output was not unique",
761 print("!! Session/line number for output was not unique",
762 "in database. Output will not be stored.")
762 "in database. Output will not be stored.")
763 finally:
763 finally:
764 self.db_output_cache = []
764 self.db_output_cache = []
765
765
766
766
767 class HistorySavingThread(threading.Thread):
767 class HistorySavingThread(threading.Thread):
768 """This thread takes care of writing history to the database, so that
768 """This thread takes care of writing history to the database, so that
769 the UI isn't held up while that happens.
769 the UI isn't held up while that happens.
770
770
771 It waits for the HistoryManager's save_flag to be set, then writes out
771 It waits for the HistoryManager's save_flag to be set, then writes out
772 the history cache. The main thread is responsible for setting the flag when
772 the history cache. The main thread is responsible for setting the flag when
773 the cache size reaches a defined threshold."""
773 the cache size reaches a defined threshold."""
774 daemon = True
774 daemon = True
775 stop_now = False
775 stop_now = False
776 enabled = True
776 enabled = True
777 def __init__(self, history_manager):
777 def __init__(self, history_manager):
778 super(HistorySavingThread, self).__init__(name="IPythonHistorySavingThread")
778 super(HistorySavingThread, self).__init__(name="IPythonHistorySavingThread")
779 self.history_manager = history_manager
779 self.history_manager = history_manager
780 self.enabled = history_manager.enabled
780 self.enabled = history_manager.enabled
781 atexit.register(self.stop)
781 atexit.register(self.stop)
782
782
783 @needs_sqlite
783 @needs_sqlite
784 def run(self):
784 def run(self):
785 # We need a separate db connection per thread:
785 # We need a separate db connection per thread:
786 try:
786 try:
787 self.db = sqlite3.connect(self.history_manager.hist_file,
787 self.db = sqlite3.connect(self.history_manager.hist_file,
788 **self.history_manager.connection_options
788 **self.history_manager.connection_options
789 )
789 )
790 while True:
790 while True:
791 self.history_manager.save_flag.wait()
791 self.history_manager.save_flag.wait()
792 if self.stop_now:
792 if self.stop_now:
793 self.db.close()
793 self.db.close()
794 return
794 return
795 self.history_manager.save_flag.clear()
795 self.history_manager.save_flag.clear()
796 self.history_manager.writeout_cache(self.db)
796 self.history_manager.writeout_cache(self.db)
797 except Exception as e:
797 except Exception as e:
798 print(("The history saving thread hit an unexpected error (%s)."
798 print(("The history saving thread hit an unexpected error (%s)."
799 "History will not be written to the database.") % repr(e))
799 "History will not be written to the database.") % repr(e))
800
800
801 def stop(self):
801 def stop(self):
802 """This can be called from the main thread to safely stop this thread.
802 """This can be called from the main thread to safely stop this thread.
803
803
804 Note that it does not attempt to write out remaining history before
804 Note that it does not attempt to write out remaining history before
805 exiting. That should be done by calling the HistoryManager's
805 exiting. That should be done by calling the HistoryManager's
806 end_session method."""
806 end_session method."""
807 self.stop_now = True
807 self.stop_now = True
808 self.history_manager.save_flag.set()
808 self.history_manager.save_flag.set()
809 self.join()
809 self.join()
810
810
811
811
812 # To match, e.g. ~5/8-~2/3
812 # To match, e.g. ~5/8-~2/3
813 range_re = re.compile(r"""
813 range_re = re.compile(r"""
814 ((?P<startsess>~?\d+)/)?
814 ((?P<startsess>~?\d+)/)?
815 (?P<start>\d+)?
815 (?P<start>\d+)?
816 ((?P<sep>[\-:])
816 ((?P<sep>[\-:])
817 ((?P<endsess>~?\d+)/)?
817 ((?P<endsess>~?\d+)/)?
818 (?P<end>\d+))?
818 (?P<end>\d+))?
819 $""", re.VERBOSE)
819 $""", re.VERBOSE)
820
820
821
821
822 def extract_hist_ranges(ranges_str):
822 def extract_hist_ranges(ranges_str):
823 """Turn a string of history ranges into 3-tuples of (session, start, stop).
823 """Turn a string of history ranges into 3-tuples of (session, start, stop).
824
824
825 Examples
825 Examples
826 --------
826 --------
827 >>> list(extract_hist_ranges("~8/5-~7/4 2"))
827 >>> list(extract_hist_ranges("~8/5-~7/4 2"))
828 [(-8, 5, None), (-7, 1, 5), (0, 2, 3)]
828 [(-8, 5, None), (-7, 1, 5), (0, 2, 3)]
829 """
829 """
830 for range_str in ranges_str.split():
830 for range_str in ranges_str.split():
831 rmatch = range_re.match(range_str)
831 rmatch = range_re.match(range_str)
832 if not rmatch:
832 if not rmatch:
833 continue
833 continue
834 start = rmatch.group("start")
834 start = rmatch.group("start")
835 if start:
835 if start:
836 start = int(start)
836 start = int(start)
837 end = rmatch.group("end")
837 end = rmatch.group("end")
838 # If no end specified, get (a, a + 1)
838 # If no end specified, get (a, a + 1)
839 end = int(end) if end else start + 1
839 end = int(end) if end else start + 1
840 else: # start not specified
840 else: # start not specified
841 if not rmatch.group('startsess'): # no startsess
841 if not rmatch.group('startsess'): # no startsess
842 continue
842 continue
843 start = 1
843 start = 1
844 end = None # provide the entire session hist
844 end = None # provide the entire session hist
845
845
846 if rmatch.group("sep") == "-": # 1-3 == 1:4 --> [1, 2, 3]
846 if rmatch.group("sep") == "-": # 1-3 == 1:4 --> [1, 2, 3]
847 end += 1
847 end += 1
848 startsess = rmatch.group("startsess") or "0"
848 startsess = rmatch.group("startsess") or "0"
849 endsess = rmatch.group("endsess") or startsess
849 endsess = rmatch.group("endsess") or startsess
850 startsess = int(startsess.replace("~","-"))
850 startsess = int(startsess.replace("~","-"))
851 endsess = int(endsess.replace("~","-"))
851 endsess = int(endsess.replace("~","-"))
852 assert endsess >= startsess, "start session must be earlier than end session"
852 assert endsess >= startsess, "start session must be earlier than end session"
853
853
854 if endsess == startsess:
854 if endsess == startsess:
855 yield (startsess, start, end)
855 yield (startsess, start, end)
856 continue
856 continue
857 # Multiple sessions in one range:
857 # Multiple sessions in one range:
858 yield (startsess, start, None)
858 yield (startsess, start, None)
859 for sess in range(startsess+1, endsess):
859 for sess in range(startsess+1, endsess):
860 yield (sess, 1, None)
860 yield (sess, 1, None)
861 yield (endsess, 1, end)
861 yield (endsess, 1, end)
862
862
863
863
864 def _format_lineno(session, line):
864 def _format_lineno(session, line):
865 """Helper function to format line numbers properly."""
865 """Helper function to format line numbers properly."""
866 if session == 0:
866 if session == 0:
867 return str(line)
867 return str(line)
868 return "%s#%s" % (session, line)
868 return "%s#%s" % (session, line)
869
869
870
870
@@ -1,702 +1,702 b''
1 # encoding: utf-8
1 # encoding: utf-8
2 """Magic functions for InteractiveShell.
2 """Magic functions for InteractiveShell.
3 """
3 """
4 from __future__ import print_function
4 from __future__ import print_function
5
5
6 #-----------------------------------------------------------------------------
6 #-----------------------------------------------------------------------------
7 # Copyright (C) 2001 Janko Hauser <jhauser@zscout.de> and
7 # Copyright (C) 2001 Janko Hauser <jhauser@zscout.de> and
8 # Copyright (C) 2001 Fernando Perez <fperez@colorado.edu>
8 # Copyright (C) 2001 Fernando Perez <fperez@colorado.edu>
9 # Copyright (C) 2008 The IPython Development Team
9 # Copyright (C) 2008 The IPython Development Team
10
10
11 # Distributed under the terms of the BSD License. The full license is in
11 # Distributed under the terms of the BSD License. The full license is in
12 # the file COPYING, distributed as part of this software.
12 # the file COPYING, distributed as part of this software.
13 #-----------------------------------------------------------------------------
13 #-----------------------------------------------------------------------------
14
14
15 #-----------------------------------------------------------------------------
15 #-----------------------------------------------------------------------------
16 # Imports
16 # Imports
17 #-----------------------------------------------------------------------------
17 #-----------------------------------------------------------------------------
18 # Stdlib
18 # Stdlib
19 import os
19 import os
20 import re
20 import re
21 import sys
21 import sys
22 import types
22 import types
23 from getopt import getopt, GetoptError
23 from getopt import getopt, GetoptError
24
24
25 # Our own
25 # Our own
26 from IPython.config.configurable import Configurable
26 from IPython.config.configurable import Configurable
27 from IPython.core import oinspect
27 from IPython.core import oinspect
28 from IPython.core.error import UsageError
28 from IPython.core.error import UsageError
29 from IPython.core.inputsplitter import ESC_MAGIC, ESC_MAGIC2
29 from IPython.core.inputsplitter import ESC_MAGIC, ESC_MAGIC2
30 from IPython.external.decorator import decorator
30 from decorator import decorator
31 from IPython.utils.ipstruct import Struct
31 from IPython.utils.ipstruct import Struct
32 from IPython.utils.process import arg_split
32 from IPython.utils.process import arg_split
33 from IPython.utils.py3compat import string_types, iteritems
33 from IPython.utils.py3compat import string_types, iteritems
34 from IPython.utils.text import dedent
34 from IPython.utils.text import dedent
35 from IPython.utils.traitlets import Bool, Dict, Instance, MetaHasTraits
35 from IPython.utils.traitlets import Bool, Dict, Instance, MetaHasTraits
36 from IPython.utils.warn import error
36 from IPython.utils.warn import error
37
37
38 #-----------------------------------------------------------------------------
38 #-----------------------------------------------------------------------------
39 # Globals
39 # Globals
40 #-----------------------------------------------------------------------------
40 #-----------------------------------------------------------------------------
41
41
42 # A dict we'll use for each class that has magics, used as temporary storage to
42 # A dict we'll use for each class that has magics, used as temporary storage to
43 # pass information between the @line/cell_magic method decorators and the
43 # pass information between the @line/cell_magic method decorators and the
44 # @magics_class class decorator, because the method decorators have no
44 # @magics_class class decorator, because the method decorators have no
45 # access to the class when they run. See for more details:
45 # access to the class when they run. See for more details:
46 # http://stackoverflow.com/questions/2366713/can-a-python-decorator-of-an-instance-method-access-the-class
46 # http://stackoverflow.com/questions/2366713/can-a-python-decorator-of-an-instance-method-access-the-class
47
47
48 magics = dict(line={}, cell={})
48 magics = dict(line={}, cell={})
49
49
50 magic_kinds = ('line', 'cell')
50 magic_kinds = ('line', 'cell')
51 magic_spec = ('line', 'cell', 'line_cell')
51 magic_spec = ('line', 'cell', 'line_cell')
52 magic_escapes = dict(line=ESC_MAGIC, cell=ESC_MAGIC2)
52 magic_escapes = dict(line=ESC_MAGIC, cell=ESC_MAGIC2)
53
53
54 #-----------------------------------------------------------------------------
54 #-----------------------------------------------------------------------------
55 # Utility classes and functions
55 # Utility classes and functions
56 #-----------------------------------------------------------------------------
56 #-----------------------------------------------------------------------------
57
57
58 class Bunch: pass
58 class Bunch: pass
59
59
60
60
61 def on_off(tag):
61 def on_off(tag):
62 """Return an ON/OFF string for a 1/0 input. Simple utility function."""
62 """Return an ON/OFF string for a 1/0 input. Simple utility function."""
63 return ['OFF','ON'][tag]
63 return ['OFF','ON'][tag]
64
64
65
65
66 def compress_dhist(dh):
66 def compress_dhist(dh):
67 """Compress a directory history into a new one with at most 20 entries.
67 """Compress a directory history into a new one with at most 20 entries.
68
68
69 Return a new list made from the first and last 10 elements of dhist after
69 Return a new list made from the first and last 10 elements of dhist after
70 removal of duplicates.
70 removal of duplicates.
71 """
71 """
72 head, tail = dh[:-10], dh[-10:]
72 head, tail = dh[:-10], dh[-10:]
73
73
74 newhead = []
74 newhead = []
75 done = set()
75 done = set()
76 for h in head:
76 for h in head:
77 if h in done:
77 if h in done:
78 continue
78 continue
79 newhead.append(h)
79 newhead.append(h)
80 done.add(h)
80 done.add(h)
81
81
82 return newhead + tail
82 return newhead + tail
83
83
84
84
85 def needs_local_scope(func):
85 def needs_local_scope(func):
86 """Decorator to mark magic functions which need to local scope to run."""
86 """Decorator to mark magic functions which need to local scope to run."""
87 func.needs_local_scope = True
87 func.needs_local_scope = True
88 return func
88 return func
89
89
90 #-----------------------------------------------------------------------------
90 #-----------------------------------------------------------------------------
91 # Class and method decorators for registering magics
91 # Class and method decorators for registering magics
92 #-----------------------------------------------------------------------------
92 #-----------------------------------------------------------------------------
93
93
94 def magics_class(cls):
94 def magics_class(cls):
95 """Class decorator for all subclasses of the main Magics class.
95 """Class decorator for all subclasses of the main Magics class.
96
96
97 Any class that subclasses Magics *must* also apply this decorator, to
97 Any class that subclasses Magics *must* also apply this decorator, to
98 ensure that all the methods that have been decorated as line/cell magics
98 ensure that all the methods that have been decorated as line/cell magics
99 get correctly registered in the class instance. This is necessary because
99 get correctly registered in the class instance. This is necessary because
100 when method decorators run, the class does not exist yet, so they
100 when method decorators run, the class does not exist yet, so they
101 temporarily store their information into a module global. Application of
101 temporarily store their information into a module global. Application of
102 this class decorator copies that global data to the class instance and
102 this class decorator copies that global data to the class instance and
103 clears the global.
103 clears the global.
104
104
105 Obviously, this mechanism is not thread-safe, which means that the
105 Obviously, this mechanism is not thread-safe, which means that the
106 *creation* of subclasses of Magic should only be done in a single-thread
106 *creation* of subclasses of Magic should only be done in a single-thread
107 context. Instantiation of the classes has no restrictions. Given that
107 context. Instantiation of the classes has no restrictions. Given that
108 these classes are typically created at IPython startup time and before user
108 these classes are typically created at IPython startup time and before user
109 application code becomes active, in practice this should not pose any
109 application code becomes active, in practice this should not pose any
110 problems.
110 problems.
111 """
111 """
112 cls.registered = True
112 cls.registered = True
113 cls.magics = dict(line = magics['line'],
113 cls.magics = dict(line = magics['line'],
114 cell = magics['cell'])
114 cell = magics['cell'])
115 magics['line'] = {}
115 magics['line'] = {}
116 magics['cell'] = {}
116 magics['cell'] = {}
117 return cls
117 return cls
118
118
119
119
120 def record_magic(dct, magic_kind, magic_name, func):
120 def record_magic(dct, magic_kind, magic_name, func):
121 """Utility function to store a function as a magic of a specific kind.
121 """Utility function to store a function as a magic of a specific kind.
122
122
123 Parameters
123 Parameters
124 ----------
124 ----------
125 dct : dict
125 dct : dict
126 A dictionary with 'line' and 'cell' subdicts.
126 A dictionary with 'line' and 'cell' subdicts.
127
127
128 magic_kind : str
128 magic_kind : str
129 Kind of magic to be stored.
129 Kind of magic to be stored.
130
130
131 magic_name : str
131 magic_name : str
132 Key to store the magic as.
132 Key to store the magic as.
133
133
134 func : function
134 func : function
135 Callable object to store.
135 Callable object to store.
136 """
136 """
137 if magic_kind == 'line_cell':
137 if magic_kind == 'line_cell':
138 dct['line'][magic_name] = dct['cell'][magic_name] = func
138 dct['line'][magic_name] = dct['cell'][magic_name] = func
139 else:
139 else:
140 dct[magic_kind][magic_name] = func
140 dct[magic_kind][magic_name] = func
141
141
142
142
143 def validate_type(magic_kind):
143 def validate_type(magic_kind):
144 """Ensure that the given magic_kind is valid.
144 """Ensure that the given magic_kind is valid.
145
145
146 Check that the given magic_kind is one of the accepted spec types (stored
146 Check that the given magic_kind is one of the accepted spec types (stored
147 in the global `magic_spec`), raise ValueError otherwise.
147 in the global `magic_spec`), raise ValueError otherwise.
148 """
148 """
149 if magic_kind not in magic_spec:
149 if magic_kind not in magic_spec:
150 raise ValueError('magic_kind must be one of %s, %s given' %
150 raise ValueError('magic_kind must be one of %s, %s given' %
151 magic_kinds, magic_kind)
151 magic_kinds, magic_kind)
152
152
153
153
154 # The docstrings for the decorator below will be fairly similar for the two
154 # The docstrings for the decorator below will be fairly similar for the two
155 # types (method and function), so we generate them here once and reuse the
155 # types (method and function), so we generate them here once and reuse the
156 # templates below.
156 # templates below.
157 _docstring_template = \
157 _docstring_template = \
158 """Decorate the given {0} as {1} magic.
158 """Decorate the given {0} as {1} magic.
159
159
160 The decorator can be used with or without arguments, as follows.
160 The decorator can be used with or without arguments, as follows.
161
161
162 i) without arguments: it will create a {1} magic named as the {0} being
162 i) without arguments: it will create a {1} magic named as the {0} being
163 decorated::
163 decorated::
164
164
165 @deco
165 @deco
166 def foo(...)
166 def foo(...)
167
167
168 will create a {1} magic named `foo`.
168 will create a {1} magic named `foo`.
169
169
170 ii) with one string argument: which will be used as the actual name of the
170 ii) with one string argument: which will be used as the actual name of the
171 resulting magic::
171 resulting magic::
172
172
173 @deco('bar')
173 @deco('bar')
174 def foo(...)
174 def foo(...)
175
175
176 will create a {1} magic named `bar`.
176 will create a {1} magic named `bar`.
177 """
177 """
178
178
179 # These two are decorator factories. While they are conceptually very similar,
179 # These two are decorator factories. While they are conceptually very similar,
180 # there are enough differences in the details that it's simpler to have them
180 # there are enough differences in the details that it's simpler to have them
181 # written as completely standalone functions rather than trying to share code
181 # written as completely standalone functions rather than trying to share code
182 # and make a single one with convoluted logic.
182 # and make a single one with convoluted logic.
183
183
184 def _method_magic_marker(magic_kind):
184 def _method_magic_marker(magic_kind):
185 """Decorator factory for methods in Magics subclasses.
185 """Decorator factory for methods in Magics subclasses.
186 """
186 """
187
187
188 validate_type(magic_kind)
188 validate_type(magic_kind)
189
189
190 # This is a closure to capture the magic_kind. We could also use a class,
190 # This is a closure to capture the magic_kind. We could also use a class,
191 # but it's overkill for just that one bit of state.
191 # but it's overkill for just that one bit of state.
192 def magic_deco(arg):
192 def magic_deco(arg):
193 call = lambda f, *a, **k: f(*a, **k)
193 call = lambda f, *a, **k: f(*a, **k)
194
194
195 if callable(arg):
195 if callable(arg):
196 # "Naked" decorator call (just @foo, no args)
196 # "Naked" decorator call (just @foo, no args)
197 func = arg
197 func = arg
198 name = func.__name__
198 name = func.__name__
199 retval = decorator(call, func)
199 retval = decorator(call, func)
200 record_magic(magics, magic_kind, name, name)
200 record_magic(magics, magic_kind, name, name)
201 elif isinstance(arg, string_types):
201 elif isinstance(arg, string_types):
202 # Decorator called with arguments (@foo('bar'))
202 # Decorator called with arguments (@foo('bar'))
203 name = arg
203 name = arg
204 def mark(func, *a, **kw):
204 def mark(func, *a, **kw):
205 record_magic(magics, magic_kind, name, func.__name__)
205 record_magic(magics, magic_kind, name, func.__name__)
206 return decorator(call, func)
206 return decorator(call, func)
207 retval = mark
207 retval = mark
208 else:
208 else:
209 raise TypeError("Decorator can only be called with "
209 raise TypeError("Decorator can only be called with "
210 "string or function")
210 "string or function")
211 return retval
211 return retval
212
212
213 # Ensure the resulting decorator has a usable docstring
213 # Ensure the resulting decorator has a usable docstring
214 magic_deco.__doc__ = _docstring_template.format('method', magic_kind)
214 magic_deco.__doc__ = _docstring_template.format('method', magic_kind)
215 return magic_deco
215 return magic_deco
216
216
217
217
218 def _function_magic_marker(magic_kind):
218 def _function_magic_marker(magic_kind):
219 """Decorator factory for standalone functions.
219 """Decorator factory for standalone functions.
220 """
220 """
221 validate_type(magic_kind)
221 validate_type(magic_kind)
222
222
223 # This is a closure to capture the magic_kind. We could also use a class,
223 # This is a closure to capture the magic_kind. We could also use a class,
224 # but it's overkill for just that one bit of state.
224 # but it's overkill for just that one bit of state.
225 def magic_deco(arg):
225 def magic_deco(arg):
226 call = lambda f, *a, **k: f(*a, **k)
226 call = lambda f, *a, **k: f(*a, **k)
227
227
228 # Find get_ipython() in the caller's namespace
228 # Find get_ipython() in the caller's namespace
229 caller = sys._getframe(1)
229 caller = sys._getframe(1)
230 for ns in ['f_locals', 'f_globals', 'f_builtins']:
230 for ns in ['f_locals', 'f_globals', 'f_builtins']:
231 get_ipython = getattr(caller, ns).get('get_ipython')
231 get_ipython = getattr(caller, ns).get('get_ipython')
232 if get_ipython is not None:
232 if get_ipython is not None:
233 break
233 break
234 else:
234 else:
235 raise NameError('Decorator can only run in context where '
235 raise NameError('Decorator can only run in context where '
236 '`get_ipython` exists')
236 '`get_ipython` exists')
237
237
238 ip = get_ipython()
238 ip = get_ipython()
239
239
240 if callable(arg):
240 if callable(arg):
241 # "Naked" decorator call (just @foo, no args)
241 # "Naked" decorator call (just @foo, no args)
242 func = arg
242 func = arg
243 name = func.__name__
243 name = func.__name__
244 ip.register_magic_function(func, magic_kind, name)
244 ip.register_magic_function(func, magic_kind, name)
245 retval = decorator(call, func)
245 retval = decorator(call, func)
246 elif isinstance(arg, string_types):
246 elif isinstance(arg, string_types):
247 # Decorator called with arguments (@foo('bar'))
247 # Decorator called with arguments (@foo('bar'))
248 name = arg
248 name = arg
249 def mark(func, *a, **kw):
249 def mark(func, *a, **kw):
250 ip.register_magic_function(func, magic_kind, name)
250 ip.register_magic_function(func, magic_kind, name)
251 return decorator(call, func)
251 return decorator(call, func)
252 retval = mark
252 retval = mark
253 else:
253 else:
254 raise TypeError("Decorator can only be called with "
254 raise TypeError("Decorator can only be called with "
255 "string or function")
255 "string or function")
256 return retval
256 return retval
257
257
258 # Ensure the resulting decorator has a usable docstring
258 # Ensure the resulting decorator has a usable docstring
259 ds = _docstring_template.format('function', magic_kind)
259 ds = _docstring_template.format('function', magic_kind)
260
260
261 ds += dedent("""
261 ds += dedent("""
262 Note: this decorator can only be used in a context where IPython is already
262 Note: this decorator can only be used in a context where IPython is already
263 active, so that the `get_ipython()` call succeeds. You can therefore use
263 active, so that the `get_ipython()` call succeeds. You can therefore use
264 it in your startup files loaded after IPython initializes, but *not* in the
264 it in your startup files loaded after IPython initializes, but *not* in the
265 IPython configuration file itself, which is executed before IPython is
265 IPython configuration file itself, which is executed before IPython is
266 fully up and running. Any file located in the `startup` subdirectory of
266 fully up and running. Any file located in the `startup` subdirectory of
267 your configuration profile will be OK in this sense.
267 your configuration profile will be OK in this sense.
268 """)
268 """)
269
269
270 magic_deco.__doc__ = ds
270 magic_deco.__doc__ = ds
271 return magic_deco
271 return magic_deco
272
272
273
273
274 # Create the actual decorators for public use
274 # Create the actual decorators for public use
275
275
276 # These three are used to decorate methods in class definitions
276 # These three are used to decorate methods in class definitions
277 line_magic = _method_magic_marker('line')
277 line_magic = _method_magic_marker('line')
278 cell_magic = _method_magic_marker('cell')
278 cell_magic = _method_magic_marker('cell')
279 line_cell_magic = _method_magic_marker('line_cell')
279 line_cell_magic = _method_magic_marker('line_cell')
280
280
281 # These three decorate standalone functions and perform the decoration
281 # These three decorate standalone functions and perform the decoration
282 # immediately. They can only run where get_ipython() works
282 # immediately. They can only run where get_ipython() works
283 register_line_magic = _function_magic_marker('line')
283 register_line_magic = _function_magic_marker('line')
284 register_cell_magic = _function_magic_marker('cell')
284 register_cell_magic = _function_magic_marker('cell')
285 register_line_cell_magic = _function_magic_marker('line_cell')
285 register_line_cell_magic = _function_magic_marker('line_cell')
286
286
287 #-----------------------------------------------------------------------------
287 #-----------------------------------------------------------------------------
288 # Core Magic classes
288 # Core Magic classes
289 #-----------------------------------------------------------------------------
289 #-----------------------------------------------------------------------------
290
290
291 class MagicsManager(Configurable):
291 class MagicsManager(Configurable):
292 """Object that handles all magic-related functionality for IPython.
292 """Object that handles all magic-related functionality for IPython.
293 """
293 """
294 # Non-configurable class attributes
294 # Non-configurable class attributes
295
295
296 # A two-level dict, first keyed by magic type, then by magic function, and
296 # A two-level dict, first keyed by magic type, then by magic function, and
297 # holding the actual callable object as value. This is the dict used for
297 # holding the actual callable object as value. This is the dict used for
298 # magic function dispatch
298 # magic function dispatch
299 magics = Dict
299 magics = Dict
300
300
301 # A registry of the original objects that we've been given holding magics.
301 # A registry of the original objects that we've been given holding magics.
302 registry = Dict
302 registry = Dict
303
303
304 shell = Instance('IPython.core.interactiveshell.InteractiveShellABC')
304 shell = Instance('IPython.core.interactiveshell.InteractiveShellABC')
305
305
306 auto_magic = Bool(True, config=True, help=
306 auto_magic = Bool(True, config=True, help=
307 "Automatically call line magics without requiring explicit % prefix")
307 "Automatically call line magics without requiring explicit % prefix")
308
308
309 def _auto_magic_changed(self, name, value):
309 def _auto_magic_changed(self, name, value):
310 self.shell.automagic = value
310 self.shell.automagic = value
311
311
312 _auto_status = [
312 _auto_status = [
313 'Automagic is OFF, % prefix IS needed for line magics.',
313 'Automagic is OFF, % prefix IS needed for line magics.',
314 'Automagic is ON, % prefix IS NOT needed for line magics.']
314 'Automagic is ON, % prefix IS NOT needed for line magics.']
315
315
316 user_magics = Instance('IPython.core.magics.UserMagics')
316 user_magics = Instance('IPython.core.magics.UserMagics')
317
317
318 def __init__(self, shell=None, config=None, user_magics=None, **traits):
318 def __init__(self, shell=None, config=None, user_magics=None, **traits):
319
319
320 super(MagicsManager, self).__init__(shell=shell, config=config,
320 super(MagicsManager, self).__init__(shell=shell, config=config,
321 user_magics=user_magics, **traits)
321 user_magics=user_magics, **traits)
322 self.magics = dict(line={}, cell={})
322 self.magics = dict(line={}, cell={})
323 # Let's add the user_magics to the registry for uniformity, so *all*
323 # Let's add the user_magics to the registry for uniformity, so *all*
324 # registered magic containers can be found there.
324 # registered magic containers can be found there.
325 self.registry[user_magics.__class__.__name__] = user_magics
325 self.registry[user_magics.__class__.__name__] = user_magics
326
326
327 def auto_status(self):
327 def auto_status(self):
328 """Return descriptive string with automagic status."""
328 """Return descriptive string with automagic status."""
329 return self._auto_status[self.auto_magic]
329 return self._auto_status[self.auto_magic]
330
330
331 def lsmagic(self):
331 def lsmagic(self):
332 """Return a dict of currently available magic functions.
332 """Return a dict of currently available magic functions.
333
333
334 The return dict has the keys 'line' and 'cell', corresponding to the
334 The return dict has the keys 'line' and 'cell', corresponding to the
335 two types of magics we support. Each value is a list of names.
335 two types of magics we support. Each value is a list of names.
336 """
336 """
337 return self.magics
337 return self.magics
338
338
339 def lsmagic_docs(self, brief=False, missing=''):
339 def lsmagic_docs(self, brief=False, missing=''):
340 """Return dict of documentation of magic functions.
340 """Return dict of documentation of magic functions.
341
341
342 The return dict has the keys 'line' and 'cell', corresponding to the
342 The return dict has the keys 'line' and 'cell', corresponding to the
343 two types of magics we support. Each value is a dict keyed by magic
343 two types of magics we support. Each value is a dict keyed by magic
344 name whose value is the function docstring. If a docstring is
344 name whose value is the function docstring. If a docstring is
345 unavailable, the value of `missing` is used instead.
345 unavailable, the value of `missing` is used instead.
346
346
347 If brief is True, only the first line of each docstring will be returned.
347 If brief is True, only the first line of each docstring will be returned.
348 """
348 """
349 docs = {}
349 docs = {}
350 for m_type in self.magics:
350 for m_type in self.magics:
351 m_docs = {}
351 m_docs = {}
352 for m_name, m_func in iteritems(self.magics[m_type]):
352 for m_name, m_func in iteritems(self.magics[m_type]):
353 if m_func.__doc__:
353 if m_func.__doc__:
354 if brief:
354 if brief:
355 m_docs[m_name] = m_func.__doc__.split('\n', 1)[0]
355 m_docs[m_name] = m_func.__doc__.split('\n', 1)[0]
356 else:
356 else:
357 m_docs[m_name] = m_func.__doc__.rstrip()
357 m_docs[m_name] = m_func.__doc__.rstrip()
358 else:
358 else:
359 m_docs[m_name] = missing
359 m_docs[m_name] = missing
360 docs[m_type] = m_docs
360 docs[m_type] = m_docs
361 return docs
361 return docs
362
362
363 def register(self, *magic_objects):
363 def register(self, *magic_objects):
364 """Register one or more instances of Magics.
364 """Register one or more instances of Magics.
365
365
366 Take one or more classes or instances of classes that subclass the main
366 Take one or more classes or instances of classes that subclass the main
367 `core.Magic` class, and register them with IPython to use the magic
367 `core.Magic` class, and register them with IPython to use the magic
368 functions they provide. The registration process will then ensure that
368 functions they provide. The registration process will then ensure that
369 any methods that have decorated to provide line and/or cell magics will
369 any methods that have decorated to provide line and/or cell magics will
370 be recognized with the `%x`/`%%x` syntax as a line/cell magic
370 be recognized with the `%x`/`%%x` syntax as a line/cell magic
371 respectively.
371 respectively.
372
372
373 If classes are given, they will be instantiated with the default
373 If classes are given, they will be instantiated with the default
374 constructor. If your classes need a custom constructor, you should
374 constructor. If your classes need a custom constructor, you should
375 instanitate them first and pass the instance.
375 instanitate them first and pass the instance.
376
376
377 The provided arguments can be an arbitrary mix of classes and instances.
377 The provided arguments can be an arbitrary mix of classes and instances.
378
378
379 Parameters
379 Parameters
380 ----------
380 ----------
381 magic_objects : one or more classes or instances
381 magic_objects : one or more classes or instances
382 """
382 """
383 # Start by validating them to ensure they have all had their magic
383 # Start by validating them to ensure they have all had their magic
384 # methods registered at the instance level
384 # methods registered at the instance level
385 for m in magic_objects:
385 for m in magic_objects:
386 if not m.registered:
386 if not m.registered:
387 raise ValueError("Class of magics %r was constructed without "
387 raise ValueError("Class of magics %r was constructed without "
388 "the @register_magics class decorator")
388 "the @register_magics class decorator")
389 if type(m) in (type, MetaHasTraits):
389 if type(m) in (type, MetaHasTraits):
390 # If we're given an uninstantiated class
390 # If we're given an uninstantiated class
391 m = m(shell=self.shell)
391 m = m(shell=self.shell)
392
392
393 # Now that we have an instance, we can register it and update the
393 # Now that we have an instance, we can register it and update the
394 # table of callables
394 # table of callables
395 self.registry[m.__class__.__name__] = m
395 self.registry[m.__class__.__name__] = m
396 for mtype in magic_kinds:
396 for mtype in magic_kinds:
397 self.magics[mtype].update(m.magics[mtype])
397 self.magics[mtype].update(m.magics[mtype])
398
398
399 def register_function(self, func, magic_kind='line', magic_name=None):
399 def register_function(self, func, magic_kind='line', magic_name=None):
400 """Expose a standalone function as magic function for IPython.
400 """Expose a standalone function as magic function for IPython.
401
401
402 This will create an IPython magic (line, cell or both) from a
402 This will create an IPython magic (line, cell or both) from a
403 standalone function. The functions should have the following
403 standalone function. The functions should have the following
404 signatures:
404 signatures:
405
405
406 * For line magics: `def f(line)`
406 * For line magics: `def f(line)`
407 * For cell magics: `def f(line, cell)`
407 * For cell magics: `def f(line, cell)`
408 * For a function that does both: `def f(line, cell=None)`
408 * For a function that does both: `def f(line, cell=None)`
409
409
410 In the latter case, the function will be called with `cell==None` when
410 In the latter case, the function will be called with `cell==None` when
411 invoked as `%f`, and with cell as a string when invoked as `%%f`.
411 invoked as `%f`, and with cell as a string when invoked as `%%f`.
412
412
413 Parameters
413 Parameters
414 ----------
414 ----------
415 func : callable
415 func : callable
416 Function to be registered as a magic.
416 Function to be registered as a magic.
417
417
418 magic_kind : str
418 magic_kind : str
419 Kind of magic, one of 'line', 'cell' or 'line_cell'
419 Kind of magic, one of 'line', 'cell' or 'line_cell'
420
420
421 magic_name : optional str
421 magic_name : optional str
422 If given, the name the magic will have in the IPython namespace. By
422 If given, the name the magic will have in the IPython namespace. By
423 default, the name of the function itself is used.
423 default, the name of the function itself is used.
424 """
424 """
425
425
426 # Create the new method in the user_magics and register it in the
426 # Create the new method in the user_magics and register it in the
427 # global table
427 # global table
428 validate_type(magic_kind)
428 validate_type(magic_kind)
429 magic_name = func.__name__ if magic_name is None else magic_name
429 magic_name = func.__name__ if magic_name is None else magic_name
430 setattr(self.user_magics, magic_name, func)
430 setattr(self.user_magics, magic_name, func)
431 record_magic(self.magics, magic_kind, magic_name, func)
431 record_magic(self.magics, magic_kind, magic_name, func)
432
432
433 def define_magic(self, name, func):
433 def define_magic(self, name, func):
434 """[Deprecated] Expose own function as magic function for IPython.
434 """[Deprecated] Expose own function as magic function for IPython.
435
435
436 Example::
436 Example::
437
437
438 def foo_impl(self, parameter_s=''):
438 def foo_impl(self, parameter_s=''):
439 'My very own magic!. (Use docstrings, IPython reads them).'
439 'My very own magic!. (Use docstrings, IPython reads them).'
440 print 'Magic function. Passed parameter is between < >:'
440 print 'Magic function. Passed parameter is between < >:'
441 print '<%s>' % parameter_s
441 print '<%s>' % parameter_s
442 print 'The self object is:', self
442 print 'The self object is:', self
443
443
444 ip.define_magic('foo',foo_impl)
444 ip.define_magic('foo',foo_impl)
445 """
445 """
446 meth = types.MethodType(func, self.user_magics)
446 meth = types.MethodType(func, self.user_magics)
447 setattr(self.user_magics, name, meth)
447 setattr(self.user_magics, name, meth)
448 record_magic(self.magics, 'line', name, meth)
448 record_magic(self.magics, 'line', name, meth)
449
449
450 def register_alias(self, alias_name, magic_name, magic_kind='line'):
450 def register_alias(self, alias_name, magic_name, magic_kind='line'):
451 """Register an alias to a magic function.
451 """Register an alias to a magic function.
452
452
453 The alias is an instance of :class:`MagicAlias`, which holds the
453 The alias is an instance of :class:`MagicAlias`, which holds the
454 name and kind of the magic it should call. Binding is done at
454 name and kind of the magic it should call. Binding is done at
455 call time, so if the underlying magic function is changed the alias
455 call time, so if the underlying magic function is changed the alias
456 will call the new function.
456 will call the new function.
457
457
458 Parameters
458 Parameters
459 ----------
459 ----------
460 alias_name : str
460 alias_name : str
461 The name of the magic to be registered.
461 The name of the magic to be registered.
462
462
463 magic_name : str
463 magic_name : str
464 The name of an existing magic.
464 The name of an existing magic.
465
465
466 magic_kind : str
466 magic_kind : str
467 Kind of magic, one of 'line' or 'cell'
467 Kind of magic, one of 'line' or 'cell'
468 """
468 """
469
469
470 # `validate_type` is too permissive, as it allows 'line_cell'
470 # `validate_type` is too permissive, as it allows 'line_cell'
471 # which we do not handle.
471 # which we do not handle.
472 if magic_kind not in magic_kinds:
472 if magic_kind not in magic_kinds:
473 raise ValueError('magic_kind must be one of %s, %s given' %
473 raise ValueError('magic_kind must be one of %s, %s given' %
474 magic_kinds, magic_kind)
474 magic_kinds, magic_kind)
475
475
476 alias = MagicAlias(self.shell, magic_name, magic_kind)
476 alias = MagicAlias(self.shell, magic_name, magic_kind)
477 setattr(self.user_magics, alias_name, alias)
477 setattr(self.user_magics, alias_name, alias)
478 record_magic(self.magics, magic_kind, alias_name, alias)
478 record_magic(self.magics, magic_kind, alias_name, alias)
479
479
480 # Key base class that provides the central functionality for magics.
480 # Key base class that provides the central functionality for magics.
481
481
482
482
483 class Magics(Configurable):
483 class Magics(Configurable):
484 """Base class for implementing magic functions.
484 """Base class for implementing magic functions.
485
485
486 Shell functions which can be reached as %function_name. All magic
486 Shell functions which can be reached as %function_name. All magic
487 functions should accept a string, which they can parse for their own
487 functions should accept a string, which they can parse for their own
488 needs. This can make some functions easier to type, eg `%cd ../`
488 needs. This can make some functions easier to type, eg `%cd ../`
489 vs. `%cd("../")`
489 vs. `%cd("../")`
490
490
491 Classes providing magic functions need to subclass this class, and they
491 Classes providing magic functions need to subclass this class, and they
492 MUST:
492 MUST:
493
493
494 - Use the method decorators `@line_magic` and `@cell_magic` to decorate
494 - Use the method decorators `@line_magic` and `@cell_magic` to decorate
495 individual methods as magic functions, AND
495 individual methods as magic functions, AND
496
496
497 - Use the class decorator `@magics_class` to ensure that the magic
497 - Use the class decorator `@magics_class` to ensure that the magic
498 methods are properly registered at the instance level upon instance
498 methods are properly registered at the instance level upon instance
499 initialization.
499 initialization.
500
500
501 See :mod:`magic_functions` for examples of actual implementation classes.
501 See :mod:`magic_functions` for examples of actual implementation classes.
502 """
502 """
503 # Dict holding all command-line options for each magic.
503 # Dict holding all command-line options for each magic.
504 options_table = None
504 options_table = None
505 # Dict for the mapping of magic names to methods, set by class decorator
505 # Dict for the mapping of magic names to methods, set by class decorator
506 magics = None
506 magics = None
507 # Flag to check that the class decorator was properly applied
507 # Flag to check that the class decorator was properly applied
508 registered = False
508 registered = False
509 # Instance of IPython shell
509 # Instance of IPython shell
510 shell = None
510 shell = None
511
511
512 def __init__(self, shell=None, **kwargs):
512 def __init__(self, shell=None, **kwargs):
513 if not(self.__class__.registered):
513 if not(self.__class__.registered):
514 raise ValueError('Magics subclass without registration - '
514 raise ValueError('Magics subclass without registration - '
515 'did you forget to apply @magics_class?')
515 'did you forget to apply @magics_class?')
516 if shell is not None:
516 if shell is not None:
517 if hasattr(shell, 'configurables'):
517 if hasattr(shell, 'configurables'):
518 shell.configurables.append(self)
518 shell.configurables.append(self)
519 if hasattr(shell, 'config'):
519 if hasattr(shell, 'config'):
520 kwargs.setdefault('parent', shell)
520 kwargs.setdefault('parent', shell)
521 kwargs['shell'] = shell
521 kwargs['shell'] = shell
522
522
523 self.shell = shell
523 self.shell = shell
524 self.options_table = {}
524 self.options_table = {}
525 # The method decorators are run when the instance doesn't exist yet, so
525 # The method decorators are run when the instance doesn't exist yet, so
526 # they can only record the names of the methods they are supposed to
526 # they can only record the names of the methods they are supposed to
527 # grab. Only now, that the instance exists, can we create the proper
527 # grab. Only now, that the instance exists, can we create the proper
528 # mapping to bound methods. So we read the info off the original names
528 # mapping to bound methods. So we read the info off the original names
529 # table and replace each method name by the actual bound method.
529 # table and replace each method name by the actual bound method.
530 # But we mustn't clobber the *class* mapping, in case of multiple instances.
530 # But we mustn't clobber the *class* mapping, in case of multiple instances.
531 class_magics = self.magics
531 class_magics = self.magics
532 self.magics = {}
532 self.magics = {}
533 for mtype in magic_kinds:
533 for mtype in magic_kinds:
534 tab = self.magics[mtype] = {}
534 tab = self.magics[mtype] = {}
535 cls_tab = class_magics[mtype]
535 cls_tab = class_magics[mtype]
536 for magic_name, meth_name in iteritems(cls_tab):
536 for magic_name, meth_name in iteritems(cls_tab):
537 if isinstance(meth_name, string_types):
537 if isinstance(meth_name, string_types):
538 # it's a method name, grab it
538 # it's a method name, grab it
539 tab[magic_name] = getattr(self, meth_name)
539 tab[magic_name] = getattr(self, meth_name)
540 else:
540 else:
541 # it's the real thing
541 # it's the real thing
542 tab[magic_name] = meth_name
542 tab[magic_name] = meth_name
543 # Configurable **needs** to be initiated at the end or the config
543 # Configurable **needs** to be initiated at the end or the config
544 # magics get screwed up.
544 # magics get screwed up.
545 super(Magics, self).__init__(**kwargs)
545 super(Magics, self).__init__(**kwargs)
546
546
547 def arg_err(self,func):
547 def arg_err(self,func):
548 """Print docstring if incorrect arguments were passed"""
548 """Print docstring if incorrect arguments were passed"""
549 print('Error in arguments:')
549 print('Error in arguments:')
550 print(oinspect.getdoc(func))
550 print(oinspect.getdoc(func))
551
551
552 def format_latex(self, strng):
552 def format_latex(self, strng):
553 """Format a string for latex inclusion."""
553 """Format a string for latex inclusion."""
554
554
555 # Characters that need to be escaped for latex:
555 # Characters that need to be escaped for latex:
556 escape_re = re.compile(r'(%|_|\$|#|&)',re.MULTILINE)
556 escape_re = re.compile(r'(%|_|\$|#|&)',re.MULTILINE)
557 # Magic command names as headers:
557 # Magic command names as headers:
558 cmd_name_re = re.compile(r'^(%s.*?):' % ESC_MAGIC,
558 cmd_name_re = re.compile(r'^(%s.*?):' % ESC_MAGIC,
559 re.MULTILINE)
559 re.MULTILINE)
560 # Magic commands
560 # Magic commands
561 cmd_re = re.compile(r'(?P<cmd>%s.+?\b)(?!\}\}:)' % ESC_MAGIC,
561 cmd_re = re.compile(r'(?P<cmd>%s.+?\b)(?!\}\}:)' % ESC_MAGIC,
562 re.MULTILINE)
562 re.MULTILINE)
563 # Paragraph continue
563 # Paragraph continue
564 par_re = re.compile(r'\\$',re.MULTILINE)
564 par_re = re.compile(r'\\$',re.MULTILINE)
565
565
566 # The "\n" symbol
566 # The "\n" symbol
567 newline_re = re.compile(r'\\n')
567 newline_re = re.compile(r'\\n')
568
568
569 # Now build the string for output:
569 # Now build the string for output:
570 #strng = cmd_name_re.sub(r'\n\\texttt{\\textsl{\\large \1}}:',strng)
570 #strng = cmd_name_re.sub(r'\n\\texttt{\\textsl{\\large \1}}:',strng)
571 strng = cmd_name_re.sub(r'\n\\bigskip\n\\texttt{\\textbf{ \1}}:',
571 strng = cmd_name_re.sub(r'\n\\bigskip\n\\texttt{\\textbf{ \1}}:',
572 strng)
572 strng)
573 strng = cmd_re.sub(r'\\texttt{\g<cmd>}',strng)
573 strng = cmd_re.sub(r'\\texttt{\g<cmd>}',strng)
574 strng = par_re.sub(r'\\\\',strng)
574 strng = par_re.sub(r'\\\\',strng)
575 strng = escape_re.sub(r'\\\1',strng)
575 strng = escape_re.sub(r'\\\1',strng)
576 strng = newline_re.sub(r'\\textbackslash{}n',strng)
576 strng = newline_re.sub(r'\\textbackslash{}n',strng)
577 return strng
577 return strng
578
578
579 def parse_options(self, arg_str, opt_str, *long_opts, **kw):
579 def parse_options(self, arg_str, opt_str, *long_opts, **kw):
580 """Parse options passed to an argument string.
580 """Parse options passed to an argument string.
581
581
582 The interface is similar to that of :func:`getopt.getopt`, but it
582 The interface is similar to that of :func:`getopt.getopt`, but it
583 returns a :class:`~IPython.utils.struct.Struct` with the options as keys
583 returns a :class:`~IPython.utils.struct.Struct` with the options as keys
584 and the stripped argument string still as a string.
584 and the stripped argument string still as a string.
585
585
586 arg_str is quoted as a true sys.argv vector by using shlex.split.
586 arg_str is quoted as a true sys.argv vector by using shlex.split.
587 This allows us to easily expand variables, glob files, quote
587 This allows us to easily expand variables, glob files, quote
588 arguments, etc.
588 arguments, etc.
589
589
590 Parameters
590 Parameters
591 ----------
591 ----------
592
592
593 arg_str : str
593 arg_str : str
594 The arguments to parse.
594 The arguments to parse.
595
595
596 opt_str : str
596 opt_str : str
597 The options specification.
597 The options specification.
598
598
599 mode : str, default 'string'
599 mode : str, default 'string'
600 If given as 'list', the argument string is returned as a list (split
600 If given as 'list', the argument string is returned as a list (split
601 on whitespace) instead of a string.
601 on whitespace) instead of a string.
602
602
603 list_all : bool, default False
603 list_all : bool, default False
604 Put all option values in lists. Normally only options
604 Put all option values in lists. Normally only options
605 appearing more than once are put in a list.
605 appearing more than once are put in a list.
606
606
607 posix : bool, default True
607 posix : bool, default True
608 Whether to split the input line in POSIX mode or not, as per the
608 Whether to split the input line in POSIX mode or not, as per the
609 conventions outlined in the :mod:`shlex` module from the standard
609 conventions outlined in the :mod:`shlex` module from the standard
610 library.
610 library.
611 """
611 """
612
612
613 # inject default options at the beginning of the input line
613 # inject default options at the beginning of the input line
614 caller = sys._getframe(1).f_code.co_name
614 caller = sys._getframe(1).f_code.co_name
615 arg_str = '%s %s' % (self.options_table.get(caller,''),arg_str)
615 arg_str = '%s %s' % (self.options_table.get(caller,''),arg_str)
616
616
617 mode = kw.get('mode','string')
617 mode = kw.get('mode','string')
618 if mode not in ['string','list']:
618 if mode not in ['string','list']:
619 raise ValueError('incorrect mode given: %s' % mode)
619 raise ValueError('incorrect mode given: %s' % mode)
620 # Get options
620 # Get options
621 list_all = kw.get('list_all',0)
621 list_all = kw.get('list_all',0)
622 posix = kw.get('posix', os.name == 'posix')
622 posix = kw.get('posix', os.name == 'posix')
623 strict = kw.get('strict', True)
623 strict = kw.get('strict', True)
624
624
625 # Check if we have more than one argument to warrant extra processing:
625 # Check if we have more than one argument to warrant extra processing:
626 odict = {} # Dictionary with options
626 odict = {} # Dictionary with options
627 args = arg_str.split()
627 args = arg_str.split()
628 if len(args) >= 1:
628 if len(args) >= 1:
629 # If the list of inputs only has 0 or 1 thing in it, there's no
629 # If the list of inputs only has 0 or 1 thing in it, there's no
630 # need to look for options
630 # need to look for options
631 argv = arg_split(arg_str, posix, strict)
631 argv = arg_split(arg_str, posix, strict)
632 # Do regular option processing
632 # Do regular option processing
633 try:
633 try:
634 opts,args = getopt(argv, opt_str, long_opts)
634 opts,args = getopt(argv, opt_str, long_opts)
635 except GetoptError as e:
635 except GetoptError as e:
636 raise UsageError('%s ( allowed: "%s" %s)' % (e.msg,opt_str,
636 raise UsageError('%s ( allowed: "%s" %s)' % (e.msg,opt_str,
637 " ".join(long_opts)))
637 " ".join(long_opts)))
638 for o,a in opts:
638 for o,a in opts:
639 if o.startswith('--'):
639 if o.startswith('--'):
640 o = o[2:]
640 o = o[2:]
641 else:
641 else:
642 o = o[1:]
642 o = o[1:]
643 try:
643 try:
644 odict[o].append(a)
644 odict[o].append(a)
645 except AttributeError:
645 except AttributeError:
646 odict[o] = [odict[o],a]
646 odict[o] = [odict[o],a]
647 except KeyError:
647 except KeyError:
648 if list_all:
648 if list_all:
649 odict[o] = [a]
649 odict[o] = [a]
650 else:
650 else:
651 odict[o] = a
651 odict[o] = a
652
652
653 # Prepare opts,args for return
653 # Prepare opts,args for return
654 opts = Struct(odict)
654 opts = Struct(odict)
655 if mode == 'string':
655 if mode == 'string':
656 args = ' '.join(args)
656 args = ' '.join(args)
657
657
658 return opts,args
658 return opts,args
659
659
660 def default_option(self, fn, optstr):
660 def default_option(self, fn, optstr):
661 """Make an entry in the options_table for fn, with value optstr"""
661 """Make an entry in the options_table for fn, with value optstr"""
662
662
663 if fn not in self.lsmagic():
663 if fn not in self.lsmagic():
664 error("%s is not a magic function" % fn)
664 error("%s is not a magic function" % fn)
665 self.options_table[fn] = optstr
665 self.options_table[fn] = optstr
666
666
667
667
668 class MagicAlias(object):
668 class MagicAlias(object):
669 """An alias to another magic function.
669 """An alias to another magic function.
670
670
671 An alias is determined by its magic name and magic kind. Lookup
671 An alias is determined by its magic name and magic kind. Lookup
672 is done at call time, so if the underlying magic changes the alias
672 is done at call time, so if the underlying magic changes the alias
673 will call the new function.
673 will call the new function.
674
674
675 Use the :meth:`MagicsManager.register_alias` method or the
675 Use the :meth:`MagicsManager.register_alias` method or the
676 `%alias_magic` magic function to create and register a new alias.
676 `%alias_magic` magic function to create and register a new alias.
677 """
677 """
678 def __init__(self, shell, magic_name, magic_kind):
678 def __init__(self, shell, magic_name, magic_kind):
679 self.shell = shell
679 self.shell = shell
680 self.magic_name = magic_name
680 self.magic_name = magic_name
681 self.magic_kind = magic_kind
681 self.magic_kind = magic_kind
682
682
683 self.pretty_target = '%s%s' % (magic_escapes[self.magic_kind], self.magic_name)
683 self.pretty_target = '%s%s' % (magic_escapes[self.magic_kind], self.magic_name)
684 self.__doc__ = "Alias for `%s`." % self.pretty_target
684 self.__doc__ = "Alias for `%s`." % self.pretty_target
685
685
686 self._in_call = False
686 self._in_call = False
687
687
688 def __call__(self, *args, **kwargs):
688 def __call__(self, *args, **kwargs):
689 """Call the magic alias."""
689 """Call the magic alias."""
690 fn = self.shell.find_magic(self.magic_name, self.magic_kind)
690 fn = self.shell.find_magic(self.magic_name, self.magic_kind)
691 if fn is None:
691 if fn is None:
692 raise UsageError("Magic `%s` not found." % self.pretty_target)
692 raise UsageError("Magic `%s` not found." % self.pretty_target)
693
693
694 # Protect against infinite recursion.
694 # Protect against infinite recursion.
695 if self._in_call:
695 if self._in_call:
696 raise UsageError("Infinite recursion detected; "
696 raise UsageError("Infinite recursion detected; "
697 "magic aliases cannot call themselves.")
697 "magic aliases cannot call themselves.")
698 self._in_call = True
698 self._in_call = True
699 try:
699 try:
700 return fn(*args, **kwargs)
700 return fn(*args, **kwargs)
701 finally:
701 finally:
702 self._in_call = False
702 self._in_call = False
@@ -1,382 +1,382 b''
1 """Tests for the object inspection functionality.
1 """Tests for the object inspection functionality.
2 """
2 """
3 #-----------------------------------------------------------------------------
3 #-----------------------------------------------------------------------------
4 # Copyright (C) 2010-2011 The IPython Development Team.
4 # Copyright (C) 2010-2011 The IPython Development Team.
5 #
5 #
6 # Distributed under the terms of the BSD License.
6 # Distributed under the terms of the BSD License.
7 #
7 #
8 # The full license is in the file COPYING.txt, distributed with this software.
8 # The full license is in the file COPYING.txt, distributed with this software.
9 #-----------------------------------------------------------------------------
9 #-----------------------------------------------------------------------------
10
10
11 #-----------------------------------------------------------------------------
11 #-----------------------------------------------------------------------------
12 # Imports
12 # Imports
13 #-----------------------------------------------------------------------------
13 #-----------------------------------------------------------------------------
14 from __future__ import print_function
14 from __future__ import print_function
15
15
16 # Stdlib imports
16 # Stdlib imports
17 import os
17 import os
18 import re
18 import re
19
19
20 # Third-party imports
20 # Third-party imports
21 import nose.tools as nt
21 import nose.tools as nt
22
22
23 # Our own imports
23 # Our own imports
24 from .. import oinspect
24 from .. import oinspect
25 from IPython.core.magic import (Magics, magics_class, line_magic,
25 from IPython.core.magic import (Magics, magics_class, line_magic,
26 cell_magic, line_cell_magic,
26 cell_magic, line_cell_magic,
27 register_line_magic, register_cell_magic,
27 register_line_magic, register_cell_magic,
28 register_line_cell_magic)
28 register_line_cell_magic)
29 from IPython.external.decorator import decorator
29 from decorator import decorator
30 from IPython.testing.decorators import skipif
30 from IPython.testing.decorators import skipif
31 from IPython.testing.tools import AssertPrints
31 from IPython.testing.tools import AssertPrints
32 from IPython.utils.path import compress_user
32 from IPython.utils.path import compress_user
33 from IPython.utils import py3compat
33 from IPython.utils import py3compat
34
34
35
35
36 #-----------------------------------------------------------------------------
36 #-----------------------------------------------------------------------------
37 # Globals and constants
37 # Globals and constants
38 #-----------------------------------------------------------------------------
38 #-----------------------------------------------------------------------------
39
39
40 inspector = oinspect.Inspector()
40 inspector = oinspect.Inspector()
41 ip = get_ipython()
41 ip = get_ipython()
42
42
43 #-----------------------------------------------------------------------------
43 #-----------------------------------------------------------------------------
44 # Local utilities
44 # Local utilities
45 #-----------------------------------------------------------------------------
45 #-----------------------------------------------------------------------------
46
46
47 # WARNING: since this test checks the line number where a function is
47 # WARNING: since this test checks the line number where a function is
48 # defined, if any code is inserted above, the following line will need to be
48 # defined, if any code is inserted above, the following line will need to be
49 # updated. Do NOT insert any whitespace between the next line and the function
49 # updated. Do NOT insert any whitespace between the next line and the function
50 # definition below.
50 # definition below.
51 THIS_LINE_NUMBER = 51 # Put here the actual number of this line
51 THIS_LINE_NUMBER = 51 # Put here the actual number of this line
52 def test_find_source_lines():
52 def test_find_source_lines():
53 nt.assert_equal(oinspect.find_source_lines(test_find_source_lines),
53 nt.assert_equal(oinspect.find_source_lines(test_find_source_lines),
54 THIS_LINE_NUMBER+1)
54 THIS_LINE_NUMBER+1)
55
55
56
56
57 # A couple of utilities to ensure these tests work the same from a source or a
57 # A couple of utilities to ensure these tests work the same from a source or a
58 # binary install
58 # binary install
59 def pyfile(fname):
59 def pyfile(fname):
60 return os.path.normcase(re.sub('.py[co]$', '.py', fname))
60 return os.path.normcase(re.sub('.py[co]$', '.py', fname))
61
61
62
62
63 def match_pyfiles(f1, f2):
63 def match_pyfiles(f1, f2):
64 nt.assert_equal(pyfile(f1), pyfile(f2))
64 nt.assert_equal(pyfile(f1), pyfile(f2))
65
65
66
66
67 def test_find_file():
67 def test_find_file():
68 match_pyfiles(oinspect.find_file(test_find_file), os.path.abspath(__file__))
68 match_pyfiles(oinspect.find_file(test_find_file), os.path.abspath(__file__))
69
69
70
70
71 def test_find_file_decorated1():
71 def test_find_file_decorated1():
72
72
73 @decorator
73 @decorator
74 def noop1(f):
74 def noop1(f):
75 def wrapper():
75 def wrapper():
76 return f(*a, **kw)
76 return f(*a, **kw)
77 return wrapper
77 return wrapper
78
78
79 @noop1
79 @noop1
80 def f(x):
80 def f(x):
81 "My docstring"
81 "My docstring"
82
82
83 match_pyfiles(oinspect.find_file(f), os.path.abspath(__file__))
83 match_pyfiles(oinspect.find_file(f), os.path.abspath(__file__))
84 nt.assert_equal(f.__doc__, "My docstring")
84 nt.assert_equal(f.__doc__, "My docstring")
85
85
86
86
87 def test_find_file_decorated2():
87 def test_find_file_decorated2():
88
88
89 @decorator
89 @decorator
90 def noop2(f, *a, **kw):
90 def noop2(f, *a, **kw):
91 return f(*a, **kw)
91 return f(*a, **kw)
92
92
93 @noop2
93 @noop2
94 def f(x):
94 def f(x):
95 "My docstring 2"
95 "My docstring 2"
96
96
97 match_pyfiles(oinspect.find_file(f), os.path.abspath(__file__))
97 match_pyfiles(oinspect.find_file(f), os.path.abspath(__file__))
98 nt.assert_equal(f.__doc__, "My docstring 2")
98 nt.assert_equal(f.__doc__, "My docstring 2")
99
99
100
100
101 def test_find_file_magic():
101 def test_find_file_magic():
102 run = ip.find_line_magic('run')
102 run = ip.find_line_magic('run')
103 nt.assert_not_equal(oinspect.find_file(run), None)
103 nt.assert_not_equal(oinspect.find_file(run), None)
104
104
105
105
106 # A few generic objects we can then inspect in the tests below
106 # A few generic objects we can then inspect in the tests below
107
107
108 class Call(object):
108 class Call(object):
109 """This is the class docstring."""
109 """This is the class docstring."""
110
110
111 def __init__(self, x, y=1):
111 def __init__(self, x, y=1):
112 """This is the constructor docstring."""
112 """This is the constructor docstring."""
113
113
114 def __call__(self, *a, **kw):
114 def __call__(self, *a, **kw):
115 """This is the call docstring."""
115 """This is the call docstring."""
116
116
117 def method(self, x, z=2):
117 def method(self, x, z=2):
118 """Some method's docstring"""
118 """Some method's docstring"""
119
119
120 class SimpleClass(object):
120 class SimpleClass(object):
121 def method(self, x, z=2):
121 def method(self, x, z=2):
122 """Some method's docstring"""
122 """Some method's docstring"""
123
123
124
124
125 class OldStyle:
125 class OldStyle:
126 """An old-style class for testing."""
126 """An old-style class for testing."""
127 pass
127 pass
128
128
129
129
130 def f(x, y=2, *a, **kw):
130 def f(x, y=2, *a, **kw):
131 """A simple function."""
131 """A simple function."""
132
132
133
133
134 def g(y, z=3, *a, **kw):
134 def g(y, z=3, *a, **kw):
135 pass # no docstring
135 pass # no docstring
136
136
137
137
138 @register_line_magic
138 @register_line_magic
139 def lmagic(line):
139 def lmagic(line):
140 "A line magic"
140 "A line magic"
141
141
142
142
143 @register_cell_magic
143 @register_cell_magic
144 def cmagic(line, cell):
144 def cmagic(line, cell):
145 "A cell magic"
145 "A cell magic"
146
146
147
147
148 @register_line_cell_magic
148 @register_line_cell_magic
149 def lcmagic(line, cell=None):
149 def lcmagic(line, cell=None):
150 "A line/cell magic"
150 "A line/cell magic"
151
151
152
152
153 @magics_class
153 @magics_class
154 class SimpleMagics(Magics):
154 class SimpleMagics(Magics):
155 @line_magic
155 @line_magic
156 def Clmagic(self, cline):
156 def Clmagic(self, cline):
157 "A class-based line magic"
157 "A class-based line magic"
158
158
159 @cell_magic
159 @cell_magic
160 def Ccmagic(self, cline, ccell):
160 def Ccmagic(self, cline, ccell):
161 "A class-based cell magic"
161 "A class-based cell magic"
162
162
163 @line_cell_magic
163 @line_cell_magic
164 def Clcmagic(self, cline, ccell=None):
164 def Clcmagic(self, cline, ccell=None):
165 "A class-based line/cell magic"
165 "A class-based line/cell magic"
166
166
167
167
168 class Awkward(object):
168 class Awkward(object):
169 def __getattr__(self, name):
169 def __getattr__(self, name):
170 raise Exception(name)
170 raise Exception(name)
171
171
172
172
173 def check_calltip(obj, name, call, docstring):
173 def check_calltip(obj, name, call, docstring):
174 """Generic check pattern all calltip tests will use"""
174 """Generic check pattern all calltip tests will use"""
175 info = inspector.info(obj, name)
175 info = inspector.info(obj, name)
176 call_line, ds = oinspect.call_tip(info)
176 call_line, ds = oinspect.call_tip(info)
177 nt.assert_equal(call_line, call)
177 nt.assert_equal(call_line, call)
178 nt.assert_equal(ds, docstring)
178 nt.assert_equal(ds, docstring)
179
179
180 #-----------------------------------------------------------------------------
180 #-----------------------------------------------------------------------------
181 # Tests
181 # Tests
182 #-----------------------------------------------------------------------------
182 #-----------------------------------------------------------------------------
183
183
184 def test_calltip_class():
184 def test_calltip_class():
185 check_calltip(Call, 'Call', 'Call(x, y=1)', Call.__init__.__doc__)
185 check_calltip(Call, 'Call', 'Call(x, y=1)', Call.__init__.__doc__)
186
186
187
187
188 def test_calltip_instance():
188 def test_calltip_instance():
189 c = Call(1)
189 c = Call(1)
190 check_calltip(c, 'c', 'c(*a, **kw)', c.__call__.__doc__)
190 check_calltip(c, 'c', 'c(*a, **kw)', c.__call__.__doc__)
191
191
192
192
193 def test_calltip_method():
193 def test_calltip_method():
194 c = Call(1)
194 c = Call(1)
195 check_calltip(c.method, 'c.method', 'c.method(x, z=2)', c.method.__doc__)
195 check_calltip(c.method, 'c.method', 'c.method(x, z=2)', c.method.__doc__)
196
196
197
197
198 def test_calltip_function():
198 def test_calltip_function():
199 check_calltip(f, 'f', 'f(x, y=2, *a, **kw)', f.__doc__)
199 check_calltip(f, 'f', 'f(x, y=2, *a, **kw)', f.__doc__)
200
200
201
201
202 def test_calltip_function2():
202 def test_calltip_function2():
203 check_calltip(g, 'g', 'g(y, z=3, *a, **kw)', '<no docstring>')
203 check_calltip(g, 'g', 'g(y, z=3, *a, **kw)', '<no docstring>')
204
204
205
205
206 def test_calltip_builtin():
206 def test_calltip_builtin():
207 check_calltip(sum, 'sum', None, sum.__doc__)
207 check_calltip(sum, 'sum', None, sum.__doc__)
208
208
209
209
210 def test_calltip_line_magic():
210 def test_calltip_line_magic():
211 check_calltip(lmagic, 'lmagic', 'lmagic(line)', "A line magic")
211 check_calltip(lmagic, 'lmagic', 'lmagic(line)', "A line magic")
212
212
213
213
214 def test_calltip_cell_magic():
214 def test_calltip_cell_magic():
215 check_calltip(cmagic, 'cmagic', 'cmagic(line, cell)', "A cell magic")
215 check_calltip(cmagic, 'cmagic', 'cmagic(line, cell)', "A cell magic")
216
216
217
217
218 def test_calltip_line_cell_magic():
218 def test_calltip_line_cell_magic():
219 check_calltip(lcmagic, 'lcmagic', 'lcmagic(line, cell=None)',
219 check_calltip(lcmagic, 'lcmagic', 'lcmagic(line, cell=None)',
220 "A line/cell magic")
220 "A line/cell magic")
221
221
222
222
223 def test_class_magics():
223 def test_class_magics():
224 cm = SimpleMagics(ip)
224 cm = SimpleMagics(ip)
225 ip.register_magics(cm)
225 ip.register_magics(cm)
226 check_calltip(cm.Clmagic, 'Clmagic', 'Clmagic(cline)',
226 check_calltip(cm.Clmagic, 'Clmagic', 'Clmagic(cline)',
227 "A class-based line magic")
227 "A class-based line magic")
228 check_calltip(cm.Ccmagic, 'Ccmagic', 'Ccmagic(cline, ccell)',
228 check_calltip(cm.Ccmagic, 'Ccmagic', 'Ccmagic(cline, ccell)',
229 "A class-based cell magic")
229 "A class-based cell magic")
230 check_calltip(cm.Clcmagic, 'Clcmagic', 'Clcmagic(cline, ccell=None)',
230 check_calltip(cm.Clcmagic, 'Clcmagic', 'Clcmagic(cline, ccell=None)',
231 "A class-based line/cell magic")
231 "A class-based line/cell magic")
232
232
233
233
234 def test_info():
234 def test_info():
235 "Check that Inspector.info fills out various fields as expected."
235 "Check that Inspector.info fills out various fields as expected."
236 i = inspector.info(Call, oname='Call')
236 i = inspector.info(Call, oname='Call')
237 nt.assert_equal(i['type_name'], 'type')
237 nt.assert_equal(i['type_name'], 'type')
238 expted_class = str(type(type)) # <class 'type'> (Python 3) or <type 'type'>
238 expted_class = str(type(type)) # <class 'type'> (Python 3) or <type 'type'>
239 nt.assert_equal(i['base_class'], expted_class)
239 nt.assert_equal(i['base_class'], expted_class)
240 nt.assert_equal(i['string_form'], "<class 'IPython.core.tests.test_oinspect.Call'>")
240 nt.assert_equal(i['string_form'], "<class 'IPython.core.tests.test_oinspect.Call'>")
241 fname = __file__
241 fname = __file__
242 if fname.endswith(".pyc"):
242 if fname.endswith(".pyc"):
243 fname = fname[:-1]
243 fname = fname[:-1]
244 # case-insensitive comparison needed on some filesystems
244 # case-insensitive comparison needed on some filesystems
245 # e.g. Windows:
245 # e.g. Windows:
246 nt.assert_equal(i['file'].lower(), compress_user(fname.lower()))
246 nt.assert_equal(i['file'].lower(), compress_user(fname.lower()))
247 nt.assert_equal(i['definition'], None)
247 nt.assert_equal(i['definition'], None)
248 nt.assert_equal(i['docstring'], Call.__doc__)
248 nt.assert_equal(i['docstring'], Call.__doc__)
249 nt.assert_equal(i['source'], None)
249 nt.assert_equal(i['source'], None)
250 nt.assert_true(i['isclass'])
250 nt.assert_true(i['isclass'])
251 nt.assert_equal(i['init_definition'], "Call(self, x, y=1)\n")
251 nt.assert_equal(i['init_definition'], "Call(self, x, y=1)\n")
252 nt.assert_equal(i['init_docstring'], Call.__init__.__doc__)
252 nt.assert_equal(i['init_docstring'], Call.__init__.__doc__)
253
253
254 i = inspector.info(Call, detail_level=1)
254 i = inspector.info(Call, detail_level=1)
255 nt.assert_not_equal(i['source'], None)
255 nt.assert_not_equal(i['source'], None)
256 nt.assert_equal(i['docstring'], None)
256 nt.assert_equal(i['docstring'], None)
257
257
258 c = Call(1)
258 c = Call(1)
259 c.__doc__ = "Modified instance docstring"
259 c.__doc__ = "Modified instance docstring"
260 i = inspector.info(c)
260 i = inspector.info(c)
261 nt.assert_equal(i['type_name'], 'Call')
261 nt.assert_equal(i['type_name'], 'Call')
262 nt.assert_equal(i['docstring'], "Modified instance docstring")
262 nt.assert_equal(i['docstring'], "Modified instance docstring")
263 nt.assert_equal(i['class_docstring'], Call.__doc__)
263 nt.assert_equal(i['class_docstring'], Call.__doc__)
264 nt.assert_equal(i['init_docstring'], Call.__init__.__doc__)
264 nt.assert_equal(i['init_docstring'], Call.__init__.__doc__)
265 nt.assert_equal(i['call_docstring'], Call.__call__.__doc__)
265 nt.assert_equal(i['call_docstring'], Call.__call__.__doc__)
266
266
267 # Test old-style classes, which for example may not have an __init__ method.
267 # Test old-style classes, which for example may not have an __init__ method.
268 if not py3compat.PY3:
268 if not py3compat.PY3:
269 i = inspector.info(OldStyle)
269 i = inspector.info(OldStyle)
270 nt.assert_equal(i['type_name'], 'classobj')
270 nt.assert_equal(i['type_name'], 'classobj')
271
271
272 i = inspector.info(OldStyle())
272 i = inspector.info(OldStyle())
273 nt.assert_equal(i['type_name'], 'instance')
273 nt.assert_equal(i['type_name'], 'instance')
274 nt.assert_equal(i['docstring'], OldStyle.__doc__)
274 nt.assert_equal(i['docstring'], OldStyle.__doc__)
275
275
276 def test_info_awkward():
276 def test_info_awkward():
277 # Just test that this doesn't throw an error.
277 # Just test that this doesn't throw an error.
278 i = inspector.info(Awkward())
278 i = inspector.info(Awkward())
279
279
280 def test_calldef_none():
280 def test_calldef_none():
281 # We should ignore __call__ for all of these.
281 # We should ignore __call__ for all of these.
282 for obj in [f, SimpleClass().method, any, str.upper]:
282 for obj in [f, SimpleClass().method, any, str.upper]:
283 print(obj)
283 print(obj)
284 i = inspector.info(obj)
284 i = inspector.info(obj)
285 nt.assert_is(i['call_def'], None)
285 nt.assert_is(i['call_def'], None)
286
286
287 if py3compat.PY3:
287 if py3compat.PY3:
288 exec("def f_kwarg(pos, *, kwonly): pass")
288 exec("def f_kwarg(pos, *, kwonly): pass")
289
289
290 @skipif(not py3compat.PY3)
290 @skipif(not py3compat.PY3)
291 def test_definition_kwonlyargs():
291 def test_definition_kwonlyargs():
292 i = inspector.info(f_kwarg, oname='f_kwarg') # analysis:ignore
292 i = inspector.info(f_kwarg, oname='f_kwarg') # analysis:ignore
293 nt.assert_equal(i['definition'], "f_kwarg(pos, *, kwonly)\n")
293 nt.assert_equal(i['definition'], "f_kwarg(pos, *, kwonly)\n")
294
294
295 def test_getdoc():
295 def test_getdoc():
296 class A(object):
296 class A(object):
297 """standard docstring"""
297 """standard docstring"""
298 pass
298 pass
299
299
300 class B(object):
300 class B(object):
301 """standard docstring"""
301 """standard docstring"""
302 def getdoc(self):
302 def getdoc(self):
303 return "custom docstring"
303 return "custom docstring"
304
304
305 class C(object):
305 class C(object):
306 """standard docstring"""
306 """standard docstring"""
307 def getdoc(self):
307 def getdoc(self):
308 return None
308 return None
309
309
310 a = A()
310 a = A()
311 b = B()
311 b = B()
312 c = C()
312 c = C()
313
313
314 nt.assert_equal(oinspect.getdoc(a), "standard docstring")
314 nt.assert_equal(oinspect.getdoc(a), "standard docstring")
315 nt.assert_equal(oinspect.getdoc(b), "custom docstring")
315 nt.assert_equal(oinspect.getdoc(b), "custom docstring")
316 nt.assert_equal(oinspect.getdoc(c), "standard docstring")
316 nt.assert_equal(oinspect.getdoc(c), "standard docstring")
317
317
318
318
319 def test_empty_property_has_no_source():
319 def test_empty_property_has_no_source():
320 i = inspector.info(property(), detail_level=1)
320 i = inspector.info(property(), detail_level=1)
321 nt.assert_is(i['source'], None)
321 nt.assert_is(i['source'], None)
322
322
323
323
324 def test_property_sources():
324 def test_property_sources():
325 import zlib
325 import zlib
326
326
327 class A(object):
327 class A(object):
328 @property
328 @property
329 def foo(self):
329 def foo(self):
330 return 'bar'
330 return 'bar'
331
331
332 foo = foo.setter(lambda self, v: setattr(self, 'bar', v))
332 foo = foo.setter(lambda self, v: setattr(self, 'bar', v))
333
333
334 id = property(id)
334 id = property(id)
335 compress = property(zlib.compress)
335 compress = property(zlib.compress)
336
336
337 i = inspector.info(A.foo, detail_level=1)
337 i = inspector.info(A.foo, detail_level=1)
338 nt.assert_in('def foo(self):', i['source'])
338 nt.assert_in('def foo(self):', i['source'])
339 nt.assert_in('lambda self, v:', i['source'])
339 nt.assert_in('lambda self, v:', i['source'])
340
340
341 i = inspector.info(A.id, detail_level=1)
341 i = inspector.info(A.id, detail_level=1)
342 nt.assert_in('fget = <function id>', i['source'])
342 nt.assert_in('fget = <function id>', i['source'])
343
343
344 i = inspector.info(A.compress, detail_level=1)
344 i = inspector.info(A.compress, detail_level=1)
345 nt.assert_in('fget = <function zlib.compress>', i['source'])
345 nt.assert_in('fget = <function zlib.compress>', i['source'])
346
346
347
347
348 def test_property_docstring_is_in_info_for_detail_level_0():
348 def test_property_docstring_is_in_info_for_detail_level_0():
349 class A(object):
349 class A(object):
350 @property
350 @property
351 def foobar():
351 def foobar():
352 """This is `foobar` property."""
352 """This is `foobar` property."""
353 pass
353 pass
354
354
355 ip.user_ns['a_obj'] = A()
355 ip.user_ns['a_obj'] = A()
356 nt.assert_equals(
356 nt.assert_equals(
357 'This is `foobar` property.',
357 'This is `foobar` property.',
358 ip.object_inspect('a_obj.foobar', detail_level=0)['docstring'])
358 ip.object_inspect('a_obj.foobar', detail_level=0)['docstring'])
359
359
360 ip.user_ns['a_cls'] = A
360 ip.user_ns['a_cls'] = A
361 nt.assert_equals(
361 nt.assert_equals(
362 'This is `foobar` property.',
362 'This is `foobar` property.',
363 ip.object_inspect('a_cls.foobar', detail_level=0)['docstring'])
363 ip.object_inspect('a_cls.foobar', detail_level=0)['docstring'])
364
364
365
365
366 def test_pdef():
366 def test_pdef():
367 # See gh-1914
367 # See gh-1914
368 def foo(): pass
368 def foo(): pass
369 inspector.pdef(foo, 'foo')
369 inspector.pdef(foo, 'foo')
370
370
371 def test_pinfo_nonascii():
371 def test_pinfo_nonascii():
372 # See gh-1177
372 # See gh-1177
373 from . import nonascii2
373 from . import nonascii2
374 ip.user_ns['nonascii2'] = nonascii2
374 ip.user_ns['nonascii2'] = nonascii2
375 ip._inspect('pinfo', 'nonascii2', detail_level=1)
375 ip._inspect('pinfo', 'nonascii2', detail_level=1)
376
376
377 def test_pinfo_magic():
377 def test_pinfo_magic():
378 with AssertPrints('Docstring:'):
378 with AssertPrints('Docstring:'):
379 ip._inspect('pinfo', 'lsmagic', detail_level=0)
379 ip._inspect('pinfo', 'lsmagic', detail_level=0)
380
380
381 with AssertPrints('Source:'):
381 with AssertPrints('Source:'):
382 ip._inspect('pinfo', 'lsmagic', detail_level=1)
382 ip._inspect('pinfo', 'lsmagic', detail_level=1)
@@ -1,703 +1,703 b''
1 """AsyncResult objects for the client"""
1 """AsyncResult objects for the client"""
2
2
3 # Copyright (c) IPython Development Team.
3 # Copyright (c) IPython Development Team.
4 # Distributed under the terms of the Modified BSD License.
4 # Distributed under the terms of the Modified BSD License.
5
5
6 from __future__ import print_function
6 from __future__ import print_function
7
7
8 import sys
8 import sys
9 import time
9 import time
10 from datetime import datetime
10 from datetime import datetime
11
11
12 from zmq import MessageTracker
12 from zmq import MessageTracker
13
13
14 from IPython.core.display import clear_output, display, display_pretty
14 from IPython.core.display import clear_output, display, display_pretty
15 from IPython.external.decorator import decorator
15 from decorator import decorator
16 from IPython.parallel import error
16 from IPython.parallel import error
17 from IPython.utils.py3compat import string_types
17 from IPython.utils.py3compat import string_types
18
18
19
19
20 def _raw_text(s):
20 def _raw_text(s):
21 display_pretty(s, raw=True)
21 display_pretty(s, raw=True)
22
22
23
23
24 # global empty tracker that's always done:
24 # global empty tracker that's always done:
25 finished_tracker = MessageTracker()
25 finished_tracker = MessageTracker()
26
26
27 @decorator
27 @decorator
28 def check_ready(f, self, *args, **kwargs):
28 def check_ready(f, self, *args, **kwargs):
29 """Call spin() to sync state prior to calling the method."""
29 """Call spin() to sync state prior to calling the method."""
30 self.wait(0)
30 self.wait(0)
31 if not self._ready:
31 if not self._ready:
32 raise error.TimeoutError("result not ready")
32 raise error.TimeoutError("result not ready")
33 return f(self, *args, **kwargs)
33 return f(self, *args, **kwargs)
34
34
35 class AsyncResult(object):
35 class AsyncResult(object):
36 """Class for representing results of non-blocking calls.
36 """Class for representing results of non-blocking calls.
37
37
38 Provides the same interface as :py:class:`multiprocessing.pool.AsyncResult`.
38 Provides the same interface as :py:class:`multiprocessing.pool.AsyncResult`.
39 """
39 """
40
40
41 msg_ids = None
41 msg_ids = None
42 _targets = None
42 _targets = None
43 _tracker = None
43 _tracker = None
44 _single_result = False
44 _single_result = False
45 owner = False,
45 owner = False,
46
46
47 def __init__(self, client, msg_ids, fname='unknown', targets=None, tracker=None,
47 def __init__(self, client, msg_ids, fname='unknown', targets=None, tracker=None,
48 owner=False,
48 owner=False,
49 ):
49 ):
50 if isinstance(msg_ids, string_types):
50 if isinstance(msg_ids, string_types):
51 # always a list
51 # always a list
52 msg_ids = [msg_ids]
52 msg_ids = [msg_ids]
53 self._single_result = True
53 self._single_result = True
54 else:
54 else:
55 self._single_result = False
55 self._single_result = False
56 if tracker is None:
56 if tracker is None:
57 # default to always done
57 # default to always done
58 tracker = finished_tracker
58 tracker = finished_tracker
59 self._client = client
59 self._client = client
60 self.msg_ids = msg_ids
60 self.msg_ids = msg_ids
61 self._fname=fname
61 self._fname=fname
62 self._targets = targets
62 self._targets = targets
63 self._tracker = tracker
63 self._tracker = tracker
64 self.owner = owner
64 self.owner = owner
65
65
66 self._ready = False
66 self._ready = False
67 self._outputs_ready = False
67 self._outputs_ready = False
68 self._success = None
68 self._success = None
69 self._metadata = [self._client.metadata[id] for id in self.msg_ids]
69 self._metadata = [self._client.metadata[id] for id in self.msg_ids]
70
70
71 def __repr__(self):
71 def __repr__(self):
72 if self._ready:
72 if self._ready:
73 return "<%s: finished>"%(self.__class__.__name__)
73 return "<%s: finished>"%(self.__class__.__name__)
74 else:
74 else:
75 return "<%s: %s>"%(self.__class__.__name__,self._fname)
75 return "<%s: %s>"%(self.__class__.__name__,self._fname)
76
76
77
77
78 def _reconstruct_result(self, res):
78 def _reconstruct_result(self, res):
79 """Reconstruct our result from actual result list (always a list)
79 """Reconstruct our result from actual result list (always a list)
80
80
81 Override me in subclasses for turning a list of results
81 Override me in subclasses for turning a list of results
82 into the expected form.
82 into the expected form.
83 """
83 """
84 if self._single_result:
84 if self._single_result:
85 return res[0]
85 return res[0]
86 else:
86 else:
87 return res
87 return res
88
88
89 def get(self, timeout=-1):
89 def get(self, timeout=-1):
90 """Return the result when it arrives.
90 """Return the result when it arrives.
91
91
92 If `timeout` is not ``None`` and the result does not arrive within
92 If `timeout` is not ``None`` and the result does not arrive within
93 `timeout` seconds then ``TimeoutError`` is raised. If the
93 `timeout` seconds then ``TimeoutError`` is raised. If the
94 remote call raised an exception then that exception will be reraised
94 remote call raised an exception then that exception will be reraised
95 by get() inside a `RemoteError`.
95 by get() inside a `RemoteError`.
96 """
96 """
97 if not self.ready():
97 if not self.ready():
98 self.wait(timeout)
98 self.wait(timeout)
99
99
100 if self._ready:
100 if self._ready:
101 if self._success:
101 if self._success:
102 return self._result
102 return self._result
103 else:
103 else:
104 raise self._exception
104 raise self._exception
105 else:
105 else:
106 raise error.TimeoutError("Result not ready.")
106 raise error.TimeoutError("Result not ready.")
107
107
108 def _check_ready(self):
108 def _check_ready(self):
109 if not self.ready():
109 if not self.ready():
110 raise error.TimeoutError("Result not ready.")
110 raise error.TimeoutError("Result not ready.")
111
111
112 def ready(self):
112 def ready(self):
113 """Return whether the call has completed."""
113 """Return whether the call has completed."""
114 if not self._ready:
114 if not self._ready:
115 self.wait(0)
115 self.wait(0)
116 elif not self._outputs_ready:
116 elif not self._outputs_ready:
117 self._wait_for_outputs(0)
117 self._wait_for_outputs(0)
118
118
119 return self._ready
119 return self._ready
120
120
121 def wait(self, timeout=-1):
121 def wait(self, timeout=-1):
122 """Wait until the result is available or until `timeout` seconds pass.
122 """Wait until the result is available or until `timeout` seconds pass.
123
123
124 This method always returns None.
124 This method always returns None.
125 """
125 """
126 if self._ready:
126 if self._ready:
127 self._wait_for_outputs(timeout)
127 self._wait_for_outputs(timeout)
128 return
128 return
129 self._ready = self._client.wait(self.msg_ids, timeout)
129 self._ready = self._client.wait(self.msg_ids, timeout)
130 if self._ready:
130 if self._ready:
131 try:
131 try:
132 results = list(map(self._client.results.get, self.msg_ids))
132 results = list(map(self._client.results.get, self.msg_ids))
133 self._result = results
133 self._result = results
134 if self._single_result:
134 if self._single_result:
135 r = results[0]
135 r = results[0]
136 if isinstance(r, Exception):
136 if isinstance(r, Exception):
137 raise r
137 raise r
138 else:
138 else:
139 results = error.collect_exceptions(results, self._fname)
139 results = error.collect_exceptions(results, self._fname)
140 self._result = self._reconstruct_result(results)
140 self._result = self._reconstruct_result(results)
141 except Exception as e:
141 except Exception as e:
142 self._exception = e
142 self._exception = e
143 self._success = False
143 self._success = False
144 else:
144 else:
145 self._success = True
145 self._success = True
146 finally:
146 finally:
147 if timeout is None or timeout < 0:
147 if timeout is None or timeout < 0:
148 # cutoff infinite wait at 10s
148 # cutoff infinite wait at 10s
149 timeout = 10
149 timeout = 10
150 self._wait_for_outputs(timeout)
150 self._wait_for_outputs(timeout)
151
151
152 if self.owner:
152 if self.owner:
153
153
154 self._metadata = [self._client.metadata.pop(mid) for mid in self.msg_ids]
154 self._metadata = [self._client.metadata.pop(mid) for mid in self.msg_ids]
155 [self._client.results.pop(mid) for mid in self.msg_ids]
155 [self._client.results.pop(mid) for mid in self.msg_ids]
156
156
157
157
158
158
159 def successful(self):
159 def successful(self):
160 """Return whether the call completed without raising an exception.
160 """Return whether the call completed without raising an exception.
161
161
162 Will raise ``AssertionError`` if the result is not ready.
162 Will raise ``AssertionError`` if the result is not ready.
163 """
163 """
164 assert self.ready()
164 assert self.ready()
165 return self._success
165 return self._success
166
166
167 #----------------------------------------------------------------
167 #----------------------------------------------------------------
168 # Extra methods not in mp.pool.AsyncResult
168 # Extra methods not in mp.pool.AsyncResult
169 #----------------------------------------------------------------
169 #----------------------------------------------------------------
170
170
171 def get_dict(self, timeout=-1):
171 def get_dict(self, timeout=-1):
172 """Get the results as a dict, keyed by engine_id.
172 """Get the results as a dict, keyed by engine_id.
173
173
174 timeout behavior is described in `get()`.
174 timeout behavior is described in `get()`.
175 """
175 """
176
176
177 results = self.get(timeout)
177 results = self.get(timeout)
178 if self._single_result:
178 if self._single_result:
179 results = [results]
179 results = [results]
180 engine_ids = [ md['engine_id'] for md in self._metadata ]
180 engine_ids = [ md['engine_id'] for md in self._metadata ]
181
181
182
182
183 rdict = {}
183 rdict = {}
184 for engine_id, result in zip(engine_ids, results):
184 for engine_id, result in zip(engine_ids, results):
185 if engine_id in rdict:
185 if engine_id in rdict:
186 raise ValueError("Cannot build dict, %i jobs ran on engine #%i" % (
186 raise ValueError("Cannot build dict, %i jobs ran on engine #%i" % (
187 engine_ids.count(engine_id), engine_id)
187 engine_ids.count(engine_id), engine_id)
188 )
188 )
189 else:
189 else:
190 rdict[engine_id] = result
190 rdict[engine_id] = result
191
191
192 return rdict
192 return rdict
193
193
194 @property
194 @property
195 def result(self):
195 def result(self):
196 """result property wrapper for `get(timeout=-1)`."""
196 """result property wrapper for `get(timeout=-1)`."""
197 return self.get()
197 return self.get()
198
198
199 # abbreviated alias:
199 # abbreviated alias:
200 r = result
200 r = result
201
201
202 @property
202 @property
203 def metadata(self):
203 def metadata(self):
204 """property for accessing execution metadata."""
204 """property for accessing execution metadata."""
205 if self._single_result:
205 if self._single_result:
206 return self._metadata[0]
206 return self._metadata[0]
207 else:
207 else:
208 return self._metadata
208 return self._metadata
209
209
210 @property
210 @property
211 def result_dict(self):
211 def result_dict(self):
212 """result property as a dict."""
212 """result property as a dict."""
213 return self.get_dict()
213 return self.get_dict()
214
214
215 def __dict__(self):
215 def __dict__(self):
216 return self.get_dict(0)
216 return self.get_dict(0)
217
217
218 def abort(self):
218 def abort(self):
219 """abort my tasks."""
219 """abort my tasks."""
220 assert not self.ready(), "Can't abort, I am already done!"
220 assert not self.ready(), "Can't abort, I am already done!"
221 return self._client.abort(self.msg_ids, targets=self._targets, block=True)
221 return self._client.abort(self.msg_ids, targets=self._targets, block=True)
222
222
223 @property
223 @property
224 def sent(self):
224 def sent(self):
225 """check whether my messages have been sent."""
225 """check whether my messages have been sent."""
226 return self._tracker.done
226 return self._tracker.done
227
227
228 def wait_for_send(self, timeout=-1):
228 def wait_for_send(self, timeout=-1):
229 """wait for pyzmq send to complete.
229 """wait for pyzmq send to complete.
230
230
231 This is necessary when sending arrays that you intend to edit in-place.
231 This is necessary when sending arrays that you intend to edit in-place.
232 `timeout` is in seconds, and will raise TimeoutError if it is reached
232 `timeout` is in seconds, and will raise TimeoutError if it is reached
233 before the send completes.
233 before the send completes.
234 """
234 """
235 return self._tracker.wait(timeout)
235 return self._tracker.wait(timeout)
236
236
237 #-------------------------------------
237 #-------------------------------------
238 # dict-access
238 # dict-access
239 #-------------------------------------
239 #-------------------------------------
240
240
241 def __getitem__(self, key):
241 def __getitem__(self, key):
242 """getitem returns result value(s) if keyed by int/slice, or metadata if key is str.
242 """getitem returns result value(s) if keyed by int/slice, or metadata if key is str.
243 """
243 """
244 if isinstance(key, int):
244 if isinstance(key, int):
245 self._check_ready()
245 self._check_ready()
246 return error.collect_exceptions([self._result[key]], self._fname)[0]
246 return error.collect_exceptions([self._result[key]], self._fname)[0]
247 elif isinstance(key, slice):
247 elif isinstance(key, slice):
248 self._check_ready()
248 self._check_ready()
249 return error.collect_exceptions(self._result[key], self._fname)
249 return error.collect_exceptions(self._result[key], self._fname)
250 elif isinstance(key, string_types):
250 elif isinstance(key, string_types):
251 # metadata proxy *does not* require that results are done
251 # metadata proxy *does not* require that results are done
252 self.wait(0)
252 self.wait(0)
253 values = [ md[key] for md in self._metadata ]
253 values = [ md[key] for md in self._metadata ]
254 if self._single_result:
254 if self._single_result:
255 return values[0]
255 return values[0]
256 else:
256 else:
257 return values
257 return values
258 else:
258 else:
259 raise TypeError("Invalid key type %r, must be 'int','slice', or 'str'"%type(key))
259 raise TypeError("Invalid key type %r, must be 'int','slice', or 'str'"%type(key))
260
260
261 def __getattr__(self, key):
261 def __getattr__(self, key):
262 """getattr maps to getitem for convenient attr access to metadata."""
262 """getattr maps to getitem for convenient attr access to metadata."""
263 try:
263 try:
264 return self.__getitem__(key)
264 return self.__getitem__(key)
265 except (error.TimeoutError, KeyError):
265 except (error.TimeoutError, KeyError):
266 raise AttributeError("%r object has no attribute %r"%(
266 raise AttributeError("%r object has no attribute %r"%(
267 self.__class__.__name__, key))
267 self.__class__.__name__, key))
268
268
269 # asynchronous iterator:
269 # asynchronous iterator:
270 def __iter__(self):
270 def __iter__(self):
271 if self._single_result:
271 if self._single_result:
272 raise TypeError("AsyncResults with a single result are not iterable.")
272 raise TypeError("AsyncResults with a single result are not iterable.")
273 try:
273 try:
274 rlist = self.get(0)
274 rlist = self.get(0)
275 except error.TimeoutError:
275 except error.TimeoutError:
276 # wait for each result individually
276 # wait for each result individually
277 for msg_id in self.msg_ids:
277 for msg_id in self.msg_ids:
278 ar = AsyncResult(self._client, msg_id, self._fname)
278 ar = AsyncResult(self._client, msg_id, self._fname)
279 yield ar.get()
279 yield ar.get()
280 else:
280 else:
281 # already done
281 # already done
282 for r in rlist:
282 for r in rlist:
283 yield r
283 yield r
284
284
285 def __len__(self):
285 def __len__(self):
286 return len(self.msg_ids)
286 return len(self.msg_ids)
287
287
288 #-------------------------------------
288 #-------------------------------------
289 # Sugar methods and attributes
289 # Sugar methods and attributes
290 #-------------------------------------
290 #-------------------------------------
291
291
292 def timedelta(self, start, end, start_key=min, end_key=max):
292 def timedelta(self, start, end, start_key=min, end_key=max):
293 """compute the difference between two sets of timestamps
293 """compute the difference between two sets of timestamps
294
294
295 The default behavior is to use the earliest of the first
295 The default behavior is to use the earliest of the first
296 and the latest of the second list, but this can be changed
296 and the latest of the second list, but this can be changed
297 by passing a different
297 by passing a different
298
298
299 Parameters
299 Parameters
300 ----------
300 ----------
301
301
302 start : one or more datetime objects (e.g. ar.submitted)
302 start : one or more datetime objects (e.g. ar.submitted)
303 end : one or more datetime objects (e.g. ar.received)
303 end : one or more datetime objects (e.g. ar.received)
304 start_key : callable
304 start_key : callable
305 Function to call on `start` to extract the relevant
305 Function to call on `start` to extract the relevant
306 entry [defalt: min]
306 entry [defalt: min]
307 end_key : callable
307 end_key : callable
308 Function to call on `end` to extract the relevant
308 Function to call on `end` to extract the relevant
309 entry [default: max]
309 entry [default: max]
310
310
311 Returns
311 Returns
312 -------
312 -------
313
313
314 dt : float
314 dt : float
315 The time elapsed (in seconds) between the two selected timestamps.
315 The time elapsed (in seconds) between the two selected timestamps.
316 """
316 """
317 if not isinstance(start, datetime):
317 if not isinstance(start, datetime):
318 # handle single_result AsyncResults, where ar.stamp is single object,
318 # handle single_result AsyncResults, where ar.stamp is single object,
319 # not a list
319 # not a list
320 start = start_key(start)
320 start = start_key(start)
321 if not isinstance(end, datetime):
321 if not isinstance(end, datetime):
322 # handle single_result AsyncResults, where ar.stamp is single object,
322 # handle single_result AsyncResults, where ar.stamp is single object,
323 # not a list
323 # not a list
324 end = end_key(end)
324 end = end_key(end)
325 return (end - start).total_seconds()
325 return (end - start).total_seconds()
326
326
327 @property
327 @property
328 def progress(self):
328 def progress(self):
329 """the number of tasks which have been completed at this point.
329 """the number of tasks which have been completed at this point.
330
330
331 Fractional progress would be given by 1.0 * ar.progress / len(ar)
331 Fractional progress would be given by 1.0 * ar.progress / len(ar)
332 """
332 """
333 self.wait(0)
333 self.wait(0)
334 return len(self) - len(set(self.msg_ids).intersection(self._client.outstanding))
334 return len(self) - len(set(self.msg_ids).intersection(self._client.outstanding))
335
335
336 @property
336 @property
337 def elapsed(self):
337 def elapsed(self):
338 """elapsed time since initial submission"""
338 """elapsed time since initial submission"""
339 if self.ready():
339 if self.ready():
340 return self.wall_time
340 return self.wall_time
341
341
342 now = submitted = datetime.now()
342 now = submitted = datetime.now()
343 for msg_id in self.msg_ids:
343 for msg_id in self.msg_ids:
344 if msg_id in self._client.metadata:
344 if msg_id in self._client.metadata:
345 stamp = self._client.metadata[msg_id]['submitted']
345 stamp = self._client.metadata[msg_id]['submitted']
346 if stamp and stamp < submitted:
346 if stamp and stamp < submitted:
347 submitted = stamp
347 submitted = stamp
348 return (now-submitted).total_seconds()
348 return (now-submitted).total_seconds()
349
349
350 @property
350 @property
351 @check_ready
351 @check_ready
352 def serial_time(self):
352 def serial_time(self):
353 """serial computation time of a parallel calculation
353 """serial computation time of a parallel calculation
354
354
355 Computed as the sum of (completed-started) of each task
355 Computed as the sum of (completed-started) of each task
356 """
356 """
357 t = 0
357 t = 0
358 for md in self._metadata:
358 for md in self._metadata:
359 t += (md['completed'] - md['started']).total_seconds()
359 t += (md['completed'] - md['started']).total_seconds()
360 return t
360 return t
361
361
362 @property
362 @property
363 @check_ready
363 @check_ready
364 def wall_time(self):
364 def wall_time(self):
365 """actual computation time of a parallel calculation
365 """actual computation time of a parallel calculation
366
366
367 Computed as the time between the latest `received` stamp
367 Computed as the time between the latest `received` stamp
368 and the earliest `submitted`.
368 and the earliest `submitted`.
369
369
370 Only reliable if Client was spinning/waiting when the task finished, because
370 Only reliable if Client was spinning/waiting when the task finished, because
371 the `received` timestamp is created when a result is pulled off of the zmq queue,
371 the `received` timestamp is created when a result is pulled off of the zmq queue,
372 which happens as a result of `client.spin()`.
372 which happens as a result of `client.spin()`.
373
373
374 For similar comparison of other timestamp pairs, check out AsyncResult.timedelta.
374 For similar comparison of other timestamp pairs, check out AsyncResult.timedelta.
375
375
376 """
376 """
377 return self.timedelta(self.submitted, self.received)
377 return self.timedelta(self.submitted, self.received)
378
378
379 def wait_interactive(self, interval=1., timeout=-1):
379 def wait_interactive(self, interval=1., timeout=-1):
380 """interactive wait, printing progress at regular intervals"""
380 """interactive wait, printing progress at regular intervals"""
381 if timeout is None:
381 if timeout is None:
382 timeout = -1
382 timeout = -1
383 N = len(self)
383 N = len(self)
384 tic = time.time()
384 tic = time.time()
385 while not self.ready() and (timeout < 0 or time.time() - tic <= timeout):
385 while not self.ready() and (timeout < 0 or time.time() - tic <= timeout):
386 self.wait(interval)
386 self.wait(interval)
387 clear_output(wait=True)
387 clear_output(wait=True)
388 print("%4i/%i tasks finished after %4i s" % (self.progress, N, self.elapsed), end="")
388 print("%4i/%i tasks finished after %4i s" % (self.progress, N, self.elapsed), end="")
389 sys.stdout.flush()
389 sys.stdout.flush()
390 print()
390 print()
391 print("done")
391 print("done")
392
392
393 def _republish_displaypub(self, content, eid):
393 def _republish_displaypub(self, content, eid):
394 """republish individual displaypub content dicts"""
394 """republish individual displaypub content dicts"""
395 try:
395 try:
396 ip = get_ipython()
396 ip = get_ipython()
397 except NameError:
397 except NameError:
398 # displaypub is meaningless outside IPython
398 # displaypub is meaningless outside IPython
399 return
399 return
400 md = content['metadata'] or {}
400 md = content['metadata'] or {}
401 md['engine'] = eid
401 md['engine'] = eid
402 ip.display_pub.publish(data=content['data'], metadata=md)
402 ip.display_pub.publish(data=content['data'], metadata=md)
403
403
404 def _display_stream(self, text, prefix='', file=None):
404 def _display_stream(self, text, prefix='', file=None):
405 if not text:
405 if not text:
406 # nothing to display
406 # nothing to display
407 return
407 return
408 if file is None:
408 if file is None:
409 file = sys.stdout
409 file = sys.stdout
410 end = '' if text.endswith('\n') else '\n'
410 end = '' if text.endswith('\n') else '\n'
411
411
412 multiline = text.count('\n') > int(text.endswith('\n'))
412 multiline = text.count('\n') > int(text.endswith('\n'))
413 if prefix and multiline and not text.startswith('\n'):
413 if prefix and multiline and not text.startswith('\n'):
414 prefix = prefix + '\n'
414 prefix = prefix + '\n'
415 print("%s%s" % (prefix, text), file=file, end=end)
415 print("%s%s" % (prefix, text), file=file, end=end)
416
416
417
417
418 def _display_single_result(self):
418 def _display_single_result(self):
419 self._display_stream(self.stdout)
419 self._display_stream(self.stdout)
420 self._display_stream(self.stderr, file=sys.stderr)
420 self._display_stream(self.stderr, file=sys.stderr)
421
421
422 try:
422 try:
423 get_ipython()
423 get_ipython()
424 except NameError:
424 except NameError:
425 # displaypub is meaningless outside IPython
425 # displaypub is meaningless outside IPython
426 return
426 return
427
427
428 for output in self.outputs:
428 for output in self.outputs:
429 self._republish_displaypub(output, self.engine_id)
429 self._republish_displaypub(output, self.engine_id)
430
430
431 if self.execute_result is not None:
431 if self.execute_result is not None:
432 display(self.get())
432 display(self.get())
433
433
434 def _wait_for_outputs(self, timeout=-1):
434 def _wait_for_outputs(self, timeout=-1):
435 """wait for the 'status=idle' message that indicates we have all outputs
435 """wait for the 'status=idle' message that indicates we have all outputs
436 """
436 """
437 if self._outputs_ready or not self._success:
437 if self._outputs_ready or not self._success:
438 # don't wait on errors
438 # don't wait on errors
439 return
439 return
440
440
441 # cast None to -1 for infinite timeout
441 # cast None to -1 for infinite timeout
442 if timeout is None:
442 if timeout is None:
443 timeout = -1
443 timeout = -1
444
444
445 tic = time.time()
445 tic = time.time()
446 while True:
446 while True:
447 self._client._flush_iopub(self._client._iopub_socket)
447 self._client._flush_iopub(self._client._iopub_socket)
448 self._outputs_ready = all(md['outputs_ready']
448 self._outputs_ready = all(md['outputs_ready']
449 for md in self._metadata)
449 for md in self._metadata)
450 if self._outputs_ready or \
450 if self._outputs_ready or \
451 (timeout >= 0 and time.time() > tic + timeout):
451 (timeout >= 0 and time.time() > tic + timeout):
452 break
452 break
453 time.sleep(0.01)
453 time.sleep(0.01)
454
454
455 @check_ready
455 @check_ready
456 def display_outputs(self, groupby="type"):
456 def display_outputs(self, groupby="type"):
457 """republish the outputs of the computation
457 """republish the outputs of the computation
458
458
459 Parameters
459 Parameters
460 ----------
460 ----------
461
461
462 groupby : str [default: type]
462 groupby : str [default: type]
463 if 'type':
463 if 'type':
464 Group outputs by type (show all stdout, then all stderr, etc.):
464 Group outputs by type (show all stdout, then all stderr, etc.):
465
465
466 [stdout:1] foo
466 [stdout:1] foo
467 [stdout:2] foo
467 [stdout:2] foo
468 [stderr:1] bar
468 [stderr:1] bar
469 [stderr:2] bar
469 [stderr:2] bar
470 if 'engine':
470 if 'engine':
471 Display outputs for each engine before moving on to the next:
471 Display outputs for each engine before moving on to the next:
472
472
473 [stdout:1] foo
473 [stdout:1] foo
474 [stderr:1] bar
474 [stderr:1] bar
475 [stdout:2] foo
475 [stdout:2] foo
476 [stderr:2] bar
476 [stderr:2] bar
477
477
478 if 'order':
478 if 'order':
479 Like 'type', but further collate individual displaypub
479 Like 'type', but further collate individual displaypub
480 outputs. This is meant for cases of each command producing
480 outputs. This is meant for cases of each command producing
481 several plots, and you would like to see all of the first
481 several plots, and you would like to see all of the first
482 plots together, then all of the second plots, and so on.
482 plots together, then all of the second plots, and so on.
483 """
483 """
484 if self._single_result:
484 if self._single_result:
485 self._display_single_result()
485 self._display_single_result()
486 return
486 return
487
487
488 stdouts = self.stdout
488 stdouts = self.stdout
489 stderrs = self.stderr
489 stderrs = self.stderr
490 execute_results = self.execute_result
490 execute_results = self.execute_result
491 output_lists = self.outputs
491 output_lists = self.outputs
492 results = self.get()
492 results = self.get()
493
493
494 targets = self.engine_id
494 targets = self.engine_id
495
495
496 if groupby == "engine":
496 if groupby == "engine":
497 for eid,stdout,stderr,outputs,r,execute_result in zip(
497 for eid,stdout,stderr,outputs,r,execute_result in zip(
498 targets, stdouts, stderrs, output_lists, results, execute_results
498 targets, stdouts, stderrs, output_lists, results, execute_results
499 ):
499 ):
500 self._display_stream(stdout, '[stdout:%i] ' % eid)
500 self._display_stream(stdout, '[stdout:%i] ' % eid)
501 self._display_stream(stderr, '[stderr:%i] ' % eid, file=sys.stderr)
501 self._display_stream(stderr, '[stderr:%i] ' % eid, file=sys.stderr)
502
502
503 try:
503 try:
504 get_ipython()
504 get_ipython()
505 except NameError:
505 except NameError:
506 # displaypub is meaningless outside IPython
506 # displaypub is meaningless outside IPython
507 return
507 return
508
508
509 if outputs or execute_result is not None:
509 if outputs or execute_result is not None:
510 _raw_text('[output:%i]' % eid)
510 _raw_text('[output:%i]' % eid)
511
511
512 for output in outputs:
512 for output in outputs:
513 self._republish_displaypub(output, eid)
513 self._republish_displaypub(output, eid)
514
514
515 if execute_result is not None:
515 if execute_result is not None:
516 display(r)
516 display(r)
517
517
518 elif groupby in ('type', 'order'):
518 elif groupby in ('type', 'order'):
519 # republish stdout:
519 # republish stdout:
520 for eid,stdout in zip(targets, stdouts):
520 for eid,stdout in zip(targets, stdouts):
521 self._display_stream(stdout, '[stdout:%i] ' % eid)
521 self._display_stream(stdout, '[stdout:%i] ' % eid)
522
522
523 # republish stderr:
523 # republish stderr:
524 for eid,stderr in zip(targets, stderrs):
524 for eid,stderr in zip(targets, stderrs):
525 self._display_stream(stderr, '[stderr:%i] ' % eid, file=sys.stderr)
525 self._display_stream(stderr, '[stderr:%i] ' % eid, file=sys.stderr)
526
526
527 try:
527 try:
528 get_ipython()
528 get_ipython()
529 except NameError:
529 except NameError:
530 # displaypub is meaningless outside IPython
530 # displaypub is meaningless outside IPython
531 return
531 return
532
532
533 if groupby == 'order':
533 if groupby == 'order':
534 output_dict = dict((eid, outputs) for eid,outputs in zip(targets, output_lists))
534 output_dict = dict((eid, outputs) for eid,outputs in zip(targets, output_lists))
535 N = max(len(outputs) for outputs in output_lists)
535 N = max(len(outputs) for outputs in output_lists)
536 for i in range(N):
536 for i in range(N):
537 for eid in targets:
537 for eid in targets:
538 outputs = output_dict[eid]
538 outputs = output_dict[eid]
539 if len(outputs) >= N:
539 if len(outputs) >= N:
540 _raw_text('[output:%i]' % eid)
540 _raw_text('[output:%i]' % eid)
541 self._republish_displaypub(outputs[i], eid)
541 self._republish_displaypub(outputs[i], eid)
542 else:
542 else:
543 # republish displaypub output
543 # republish displaypub output
544 for eid,outputs in zip(targets, output_lists):
544 for eid,outputs in zip(targets, output_lists):
545 if outputs:
545 if outputs:
546 _raw_text('[output:%i]' % eid)
546 _raw_text('[output:%i]' % eid)
547 for output in outputs:
547 for output in outputs:
548 self._republish_displaypub(output, eid)
548 self._republish_displaypub(output, eid)
549
549
550 # finally, add execute_result:
550 # finally, add execute_result:
551 for eid,r,execute_result in zip(targets, results, execute_results):
551 for eid,r,execute_result in zip(targets, results, execute_results):
552 if execute_result is not None:
552 if execute_result is not None:
553 display(r)
553 display(r)
554
554
555 else:
555 else:
556 raise ValueError("groupby must be one of 'type', 'engine', 'collate', not %r" % groupby)
556 raise ValueError("groupby must be one of 'type', 'engine', 'collate', not %r" % groupby)
557
557
558
558
559
559
560
560
561 class AsyncMapResult(AsyncResult):
561 class AsyncMapResult(AsyncResult):
562 """Class for representing results of non-blocking gathers.
562 """Class for representing results of non-blocking gathers.
563
563
564 This will properly reconstruct the gather.
564 This will properly reconstruct the gather.
565
565
566 This class is iterable at any time, and will wait on results as they come.
566 This class is iterable at any time, and will wait on results as they come.
567
567
568 If ordered=False, then the first results to arrive will come first, otherwise
568 If ordered=False, then the first results to arrive will come first, otherwise
569 results will be yielded in the order they were submitted.
569 results will be yielded in the order they were submitted.
570
570
571 """
571 """
572
572
573 def __init__(self, client, msg_ids, mapObject, fname='', ordered=True):
573 def __init__(self, client, msg_ids, mapObject, fname='', ordered=True):
574 AsyncResult.__init__(self, client, msg_ids, fname=fname)
574 AsyncResult.__init__(self, client, msg_ids, fname=fname)
575 self._mapObject = mapObject
575 self._mapObject = mapObject
576 self._single_result = False
576 self._single_result = False
577 self.ordered = ordered
577 self.ordered = ordered
578
578
579 def _reconstruct_result(self, res):
579 def _reconstruct_result(self, res):
580 """Perform the gather on the actual results."""
580 """Perform the gather on the actual results."""
581 return self._mapObject.joinPartitions(res)
581 return self._mapObject.joinPartitions(res)
582
582
583 # asynchronous iterator:
583 # asynchronous iterator:
584 def __iter__(self):
584 def __iter__(self):
585 it = self._ordered_iter if self.ordered else self._unordered_iter
585 it = self._ordered_iter if self.ordered else self._unordered_iter
586 for r in it():
586 for r in it():
587 yield r
587 yield r
588
588
589 # asynchronous ordered iterator:
589 # asynchronous ordered iterator:
590 def _ordered_iter(self):
590 def _ordered_iter(self):
591 """iterator for results *as they arrive*, preserving submission order."""
591 """iterator for results *as they arrive*, preserving submission order."""
592 try:
592 try:
593 rlist = self.get(0)
593 rlist = self.get(0)
594 except error.TimeoutError:
594 except error.TimeoutError:
595 # wait for each result individually
595 # wait for each result individually
596 for msg_id in self.msg_ids:
596 for msg_id in self.msg_ids:
597 ar = AsyncResult(self._client, msg_id, self._fname)
597 ar = AsyncResult(self._client, msg_id, self._fname)
598 rlist = ar.get()
598 rlist = ar.get()
599 try:
599 try:
600 for r in rlist:
600 for r in rlist:
601 yield r
601 yield r
602 except TypeError:
602 except TypeError:
603 # flattened, not a list
603 # flattened, not a list
604 # this could get broken by flattened data that returns iterables
604 # this could get broken by flattened data that returns iterables
605 # but most calls to map do not expose the `flatten` argument
605 # but most calls to map do not expose the `flatten` argument
606 yield rlist
606 yield rlist
607 else:
607 else:
608 # already done
608 # already done
609 for r in rlist:
609 for r in rlist:
610 yield r
610 yield r
611
611
612 # asynchronous unordered iterator:
612 # asynchronous unordered iterator:
613 def _unordered_iter(self):
613 def _unordered_iter(self):
614 """iterator for results *as they arrive*, on FCFS basis, ignoring submission order."""
614 """iterator for results *as they arrive*, on FCFS basis, ignoring submission order."""
615 try:
615 try:
616 rlist = self.get(0)
616 rlist = self.get(0)
617 except error.TimeoutError:
617 except error.TimeoutError:
618 pending = set(self.msg_ids)
618 pending = set(self.msg_ids)
619 while pending:
619 while pending:
620 try:
620 try:
621 self._client.wait(pending, 1e-3)
621 self._client.wait(pending, 1e-3)
622 except error.TimeoutError:
622 except error.TimeoutError:
623 # ignore timeout error, because that only means
623 # ignore timeout error, because that only means
624 # *some* jobs are outstanding
624 # *some* jobs are outstanding
625 pass
625 pass
626 # update ready set with those no longer outstanding:
626 # update ready set with those no longer outstanding:
627 ready = pending.difference(self._client.outstanding)
627 ready = pending.difference(self._client.outstanding)
628 # update pending to exclude those that are finished
628 # update pending to exclude those that are finished
629 pending = pending.difference(ready)
629 pending = pending.difference(ready)
630 while ready:
630 while ready:
631 msg_id = ready.pop()
631 msg_id = ready.pop()
632 ar = AsyncResult(self._client, msg_id, self._fname)
632 ar = AsyncResult(self._client, msg_id, self._fname)
633 rlist = ar.get()
633 rlist = ar.get()
634 try:
634 try:
635 for r in rlist:
635 for r in rlist:
636 yield r
636 yield r
637 except TypeError:
637 except TypeError:
638 # flattened, not a list
638 # flattened, not a list
639 # this could get broken by flattened data that returns iterables
639 # this could get broken by flattened data that returns iterables
640 # but most calls to map do not expose the `flatten` argument
640 # but most calls to map do not expose the `flatten` argument
641 yield rlist
641 yield rlist
642 else:
642 else:
643 # already done
643 # already done
644 for r in rlist:
644 for r in rlist:
645 yield r
645 yield r
646
646
647
647
648 class AsyncHubResult(AsyncResult):
648 class AsyncHubResult(AsyncResult):
649 """Class to wrap pending results that must be requested from the Hub.
649 """Class to wrap pending results that must be requested from the Hub.
650
650
651 Note that waiting/polling on these objects requires polling the Hubover the network,
651 Note that waiting/polling on these objects requires polling the Hubover the network,
652 so use `AsyncHubResult.wait()` sparingly.
652 so use `AsyncHubResult.wait()` sparingly.
653 """
653 """
654
654
655 def _wait_for_outputs(self, timeout=-1):
655 def _wait_for_outputs(self, timeout=-1):
656 """no-op, because HubResults are never incomplete"""
656 """no-op, because HubResults are never incomplete"""
657 self._outputs_ready = True
657 self._outputs_ready = True
658
658
659 def wait(self, timeout=-1):
659 def wait(self, timeout=-1):
660 """wait for result to complete."""
660 """wait for result to complete."""
661 start = time.time()
661 start = time.time()
662 if self._ready:
662 if self._ready:
663 return
663 return
664 local_ids = [m for m in self.msg_ids if m in self._client.outstanding]
664 local_ids = [m for m in self.msg_ids if m in self._client.outstanding]
665 local_ready = self._client.wait(local_ids, timeout)
665 local_ready = self._client.wait(local_ids, timeout)
666 if local_ready:
666 if local_ready:
667 remote_ids = [m for m in self.msg_ids if m not in self._client.results]
667 remote_ids = [m for m in self.msg_ids if m not in self._client.results]
668 if not remote_ids:
668 if not remote_ids:
669 self._ready = True
669 self._ready = True
670 else:
670 else:
671 rdict = self._client.result_status(remote_ids, status_only=False)
671 rdict = self._client.result_status(remote_ids, status_only=False)
672 pending = rdict['pending']
672 pending = rdict['pending']
673 while pending and (timeout < 0 or time.time() < start+timeout):
673 while pending and (timeout < 0 or time.time() < start+timeout):
674 rdict = self._client.result_status(remote_ids, status_only=False)
674 rdict = self._client.result_status(remote_ids, status_only=False)
675 pending = rdict['pending']
675 pending = rdict['pending']
676 if pending:
676 if pending:
677 time.sleep(0.1)
677 time.sleep(0.1)
678 if not pending:
678 if not pending:
679 self._ready = True
679 self._ready = True
680 if self._ready:
680 if self._ready:
681 try:
681 try:
682 results = list(map(self._client.results.get, self.msg_ids))
682 results = list(map(self._client.results.get, self.msg_ids))
683 self._result = results
683 self._result = results
684 if self._single_result:
684 if self._single_result:
685 r = results[0]
685 r = results[0]
686 if isinstance(r, Exception):
686 if isinstance(r, Exception):
687 raise r
687 raise r
688 else:
688 else:
689 results = error.collect_exceptions(results, self._fname)
689 results = error.collect_exceptions(results, self._fname)
690 self._result = self._reconstruct_result(results)
690 self._result = self._reconstruct_result(results)
691 except Exception as e:
691 except Exception as e:
692 self._exception = e
692 self._exception = e
693 self._success = False
693 self._success = False
694 else:
694 else:
695 self._success = True
695 self._success = True
696 finally:
696 finally:
697 self._metadata = [self._client.metadata[mid] for mid in self.msg_ids]
697 self._metadata = [self._client.metadata[mid] for mid in self.msg_ids]
698 if self.owner:
698 if self.owner:
699 [self._client.metadata.pop(mid) for mid in self.msg_ids]
699 [self._client.metadata.pop(mid) for mid in self.msg_ids]
700 [self._client.results.pop(mid) for mid in self.msg_ids]
700 [self._client.results.pop(mid) for mid in self.msg_ids]
701
701
702
702
703 __all__ = ['AsyncResult', 'AsyncMapResult', 'AsyncHubResult']
703 __all__ = ['AsyncResult', 'AsyncMapResult', 'AsyncHubResult']
@@ -1,1893 +1,1893 b''
1 """A semi-synchronous Client for IPython parallel"""
1 """A semi-synchronous Client for IPython parallel"""
2
2
3 # Copyright (c) IPython Development Team.
3 # Copyright (c) IPython Development Team.
4 # Distributed under the terms of the Modified BSD License.
4 # Distributed under the terms of the Modified BSD License.
5
5
6 from __future__ import print_function
6 from __future__ import print_function
7
7
8 import os
8 import os
9 import json
9 import json
10 import sys
10 import sys
11 from threading import Thread, Event
11 from threading import Thread, Event
12 import time
12 import time
13 import warnings
13 import warnings
14 from datetime import datetime
14 from datetime import datetime
15 from getpass import getpass
15 from getpass import getpass
16 from pprint import pprint
16 from pprint import pprint
17
17
18 pjoin = os.path.join
18 pjoin = os.path.join
19
19
20 import zmq
20 import zmq
21
21
22 from IPython.config.configurable import MultipleInstanceError
22 from IPython.config.configurable import MultipleInstanceError
23 from IPython.core.application import BaseIPythonApplication
23 from IPython.core.application import BaseIPythonApplication
24 from IPython.core.profiledir import ProfileDir, ProfileDirError
24 from IPython.core.profiledir import ProfileDir, ProfileDirError
25
25
26 from IPython.utils.capture import RichOutput
26 from IPython.utils.capture import RichOutput
27 from IPython.utils.coloransi import TermColors
27 from IPython.utils.coloransi import TermColors
28 from IPython.utils.jsonutil import rekey, extract_dates, parse_date
28 from IPython.utils.jsonutil import rekey, extract_dates, parse_date
29 from IPython.utils.localinterfaces import localhost, is_local_ip
29 from IPython.utils.localinterfaces import localhost, is_local_ip
30 from IPython.utils.path import get_ipython_dir, compress_user
30 from IPython.utils.path import get_ipython_dir, compress_user
31 from IPython.utils.py3compat import cast_bytes, string_types, xrange, iteritems
31 from IPython.utils.py3compat import cast_bytes, string_types, xrange, iteritems
32 from IPython.utils.traitlets import (HasTraits, Integer, Instance, Unicode,
32 from IPython.utils.traitlets import (HasTraits, Integer, Instance, Unicode,
33 Dict, List, Bool, Set, Any)
33 Dict, List, Bool, Set, Any)
34 from IPython.external.decorator import decorator
34 from decorator import decorator
35
35
36 from IPython.parallel import Reference
36 from IPython.parallel import Reference
37 from IPython.parallel import error
37 from IPython.parallel import error
38 from IPython.parallel import util
38 from IPython.parallel import util
39
39
40 from IPython.kernel.zmq.session import Session, Message
40 from IPython.kernel.zmq.session import Session, Message
41 from IPython.kernel.zmq import serialize
41 from IPython.kernel.zmq import serialize
42
42
43 from .asyncresult import AsyncResult, AsyncHubResult
43 from .asyncresult import AsyncResult, AsyncHubResult
44 from .view import DirectView, LoadBalancedView
44 from .view import DirectView, LoadBalancedView
45
45
46 #--------------------------------------------------------------------------
46 #--------------------------------------------------------------------------
47 # Decorators for Client methods
47 # Decorators for Client methods
48 #--------------------------------------------------------------------------
48 #--------------------------------------------------------------------------
49
49
50
50
51 @decorator
51 @decorator
52 def spin_first(f, self, *args, **kwargs):
52 def spin_first(f, self, *args, **kwargs):
53 """Call spin() to sync state prior to calling the method."""
53 """Call spin() to sync state prior to calling the method."""
54 self.spin()
54 self.spin()
55 return f(self, *args, **kwargs)
55 return f(self, *args, **kwargs)
56
56
57
57
58 #--------------------------------------------------------------------------
58 #--------------------------------------------------------------------------
59 # Classes
59 # Classes
60 #--------------------------------------------------------------------------
60 #--------------------------------------------------------------------------
61
61
62 _no_connection_file_msg = """
62 _no_connection_file_msg = """
63 Failed to connect because no Controller could be found.
63 Failed to connect because no Controller could be found.
64 Please double-check your profile and ensure that a cluster is running.
64 Please double-check your profile and ensure that a cluster is running.
65 """
65 """
66
66
67 class ExecuteReply(RichOutput):
67 class ExecuteReply(RichOutput):
68 """wrapper for finished Execute results"""
68 """wrapper for finished Execute results"""
69 def __init__(self, msg_id, content, metadata):
69 def __init__(self, msg_id, content, metadata):
70 self.msg_id = msg_id
70 self.msg_id = msg_id
71 self._content = content
71 self._content = content
72 self.execution_count = content['execution_count']
72 self.execution_count = content['execution_count']
73 self.metadata = metadata
73 self.metadata = metadata
74
74
75 # RichOutput overrides
75 # RichOutput overrides
76
76
77 @property
77 @property
78 def source(self):
78 def source(self):
79 execute_result = self.metadata['execute_result']
79 execute_result = self.metadata['execute_result']
80 if execute_result:
80 if execute_result:
81 return execute_result.get('source', '')
81 return execute_result.get('source', '')
82
82
83 @property
83 @property
84 def data(self):
84 def data(self):
85 execute_result = self.metadata['execute_result']
85 execute_result = self.metadata['execute_result']
86 if execute_result:
86 if execute_result:
87 return execute_result.get('data', {})
87 return execute_result.get('data', {})
88
88
89 @property
89 @property
90 def _metadata(self):
90 def _metadata(self):
91 execute_result = self.metadata['execute_result']
91 execute_result = self.metadata['execute_result']
92 if execute_result:
92 if execute_result:
93 return execute_result.get('metadata', {})
93 return execute_result.get('metadata', {})
94
94
95 def display(self):
95 def display(self):
96 from IPython.display import publish_display_data
96 from IPython.display import publish_display_data
97 publish_display_data(self.data, self.metadata)
97 publish_display_data(self.data, self.metadata)
98
98
99 def _repr_mime_(self, mime):
99 def _repr_mime_(self, mime):
100 if mime not in self.data:
100 if mime not in self.data:
101 return
101 return
102 data = self.data[mime]
102 data = self.data[mime]
103 if mime in self._metadata:
103 if mime in self._metadata:
104 return data, self._metadata[mime]
104 return data, self._metadata[mime]
105 else:
105 else:
106 return data
106 return data
107
107
108 def __getitem__(self, key):
108 def __getitem__(self, key):
109 return self.metadata[key]
109 return self.metadata[key]
110
110
111 def __getattr__(self, key):
111 def __getattr__(self, key):
112 if key not in self.metadata:
112 if key not in self.metadata:
113 raise AttributeError(key)
113 raise AttributeError(key)
114 return self.metadata[key]
114 return self.metadata[key]
115
115
116 def __repr__(self):
116 def __repr__(self):
117 execute_result = self.metadata['execute_result'] or {'data':{}}
117 execute_result = self.metadata['execute_result'] or {'data':{}}
118 text_out = execute_result['data'].get('text/plain', '')
118 text_out = execute_result['data'].get('text/plain', '')
119 if len(text_out) > 32:
119 if len(text_out) > 32:
120 text_out = text_out[:29] + '...'
120 text_out = text_out[:29] + '...'
121
121
122 return "<ExecuteReply[%i]: %s>" % (self.execution_count, text_out)
122 return "<ExecuteReply[%i]: %s>" % (self.execution_count, text_out)
123
123
124 def _repr_pretty_(self, p, cycle):
124 def _repr_pretty_(self, p, cycle):
125 execute_result = self.metadata['execute_result'] or {'data':{}}
125 execute_result = self.metadata['execute_result'] or {'data':{}}
126 text_out = execute_result['data'].get('text/plain', '')
126 text_out = execute_result['data'].get('text/plain', '')
127
127
128 if not text_out:
128 if not text_out:
129 return
129 return
130
130
131 try:
131 try:
132 ip = get_ipython()
132 ip = get_ipython()
133 except NameError:
133 except NameError:
134 colors = "NoColor"
134 colors = "NoColor"
135 else:
135 else:
136 colors = ip.colors
136 colors = ip.colors
137
137
138 if colors == "NoColor":
138 if colors == "NoColor":
139 out = normal = ""
139 out = normal = ""
140 else:
140 else:
141 out = TermColors.Red
141 out = TermColors.Red
142 normal = TermColors.Normal
142 normal = TermColors.Normal
143
143
144 if '\n' in text_out and not text_out.startswith('\n'):
144 if '\n' in text_out and not text_out.startswith('\n'):
145 # add newline for multiline reprs
145 # add newline for multiline reprs
146 text_out = '\n' + text_out
146 text_out = '\n' + text_out
147
147
148 p.text(
148 p.text(
149 out + u'Out[%i:%i]: ' % (
149 out + u'Out[%i:%i]: ' % (
150 self.metadata['engine_id'], self.execution_count
150 self.metadata['engine_id'], self.execution_count
151 ) + normal + text_out
151 ) + normal + text_out
152 )
152 )
153
153
154
154
155 class Metadata(dict):
155 class Metadata(dict):
156 """Subclass of dict for initializing metadata values.
156 """Subclass of dict for initializing metadata values.
157
157
158 Attribute access works on keys.
158 Attribute access works on keys.
159
159
160 These objects have a strict set of keys - errors will raise if you try
160 These objects have a strict set of keys - errors will raise if you try
161 to add new keys.
161 to add new keys.
162 """
162 """
163 def __init__(self, *args, **kwargs):
163 def __init__(self, *args, **kwargs):
164 dict.__init__(self)
164 dict.__init__(self)
165 md = {'msg_id' : None,
165 md = {'msg_id' : None,
166 'submitted' : None,
166 'submitted' : None,
167 'started' : None,
167 'started' : None,
168 'completed' : None,
168 'completed' : None,
169 'received' : None,
169 'received' : None,
170 'engine_uuid' : None,
170 'engine_uuid' : None,
171 'engine_id' : None,
171 'engine_id' : None,
172 'follow' : None,
172 'follow' : None,
173 'after' : None,
173 'after' : None,
174 'status' : None,
174 'status' : None,
175
175
176 'execute_input' : None,
176 'execute_input' : None,
177 'execute_result' : None,
177 'execute_result' : None,
178 'error' : None,
178 'error' : None,
179 'stdout' : '',
179 'stdout' : '',
180 'stderr' : '',
180 'stderr' : '',
181 'outputs' : [],
181 'outputs' : [],
182 'data': {},
182 'data': {},
183 'outputs_ready' : False,
183 'outputs_ready' : False,
184 }
184 }
185 self.update(md)
185 self.update(md)
186 self.update(dict(*args, **kwargs))
186 self.update(dict(*args, **kwargs))
187
187
188 def __getattr__(self, key):
188 def __getattr__(self, key):
189 """getattr aliased to getitem"""
189 """getattr aliased to getitem"""
190 if key in self:
190 if key in self:
191 return self[key]
191 return self[key]
192 else:
192 else:
193 raise AttributeError(key)
193 raise AttributeError(key)
194
194
195 def __setattr__(self, key, value):
195 def __setattr__(self, key, value):
196 """setattr aliased to setitem, with strict"""
196 """setattr aliased to setitem, with strict"""
197 if key in self:
197 if key in self:
198 self[key] = value
198 self[key] = value
199 else:
199 else:
200 raise AttributeError(key)
200 raise AttributeError(key)
201
201
202 def __setitem__(self, key, value):
202 def __setitem__(self, key, value):
203 """strict static key enforcement"""
203 """strict static key enforcement"""
204 if key in self:
204 if key in self:
205 dict.__setitem__(self, key, value)
205 dict.__setitem__(self, key, value)
206 else:
206 else:
207 raise KeyError(key)
207 raise KeyError(key)
208
208
209
209
210 class Client(HasTraits):
210 class Client(HasTraits):
211 """A semi-synchronous client to the IPython ZMQ cluster
211 """A semi-synchronous client to the IPython ZMQ cluster
212
212
213 Parameters
213 Parameters
214 ----------
214 ----------
215
215
216 url_file : str/unicode; path to ipcontroller-client.json
216 url_file : str/unicode; path to ipcontroller-client.json
217 This JSON file should contain all the information needed to connect to a cluster,
217 This JSON file should contain all the information needed to connect to a cluster,
218 and is likely the only argument needed.
218 and is likely the only argument needed.
219 Connection information for the Hub's registration. If a json connector
219 Connection information for the Hub's registration. If a json connector
220 file is given, then likely no further configuration is necessary.
220 file is given, then likely no further configuration is necessary.
221 [Default: use profile]
221 [Default: use profile]
222 profile : bytes
222 profile : bytes
223 The name of the Cluster profile to be used to find connector information.
223 The name of the Cluster profile to be used to find connector information.
224 If run from an IPython application, the default profile will be the same
224 If run from an IPython application, the default profile will be the same
225 as the running application, otherwise it will be 'default'.
225 as the running application, otherwise it will be 'default'.
226 cluster_id : str
226 cluster_id : str
227 String id to added to runtime files, to prevent name collisions when using
227 String id to added to runtime files, to prevent name collisions when using
228 multiple clusters with a single profile simultaneously.
228 multiple clusters with a single profile simultaneously.
229 When set, will look for files named like: 'ipcontroller-<cluster_id>-client.json'
229 When set, will look for files named like: 'ipcontroller-<cluster_id>-client.json'
230 Since this is text inserted into filenames, typical recommendations apply:
230 Since this is text inserted into filenames, typical recommendations apply:
231 Simple character strings are ideal, and spaces are not recommended (but
231 Simple character strings are ideal, and spaces are not recommended (but
232 should generally work)
232 should generally work)
233 context : zmq.Context
233 context : zmq.Context
234 Pass an existing zmq.Context instance, otherwise the client will create its own.
234 Pass an existing zmq.Context instance, otherwise the client will create its own.
235 debug : bool
235 debug : bool
236 flag for lots of message printing for debug purposes
236 flag for lots of message printing for debug purposes
237 timeout : int/float
237 timeout : int/float
238 time (in seconds) to wait for connection replies from the Hub
238 time (in seconds) to wait for connection replies from the Hub
239 [Default: 10]
239 [Default: 10]
240
240
241 #-------------- session related args ----------------
241 #-------------- session related args ----------------
242
242
243 config : Config object
243 config : Config object
244 If specified, this will be relayed to the Session for configuration
244 If specified, this will be relayed to the Session for configuration
245 username : str
245 username : str
246 set username for the session object
246 set username for the session object
247
247
248 #-------------- ssh related args ----------------
248 #-------------- ssh related args ----------------
249 # These are args for configuring the ssh tunnel to be used
249 # These are args for configuring the ssh tunnel to be used
250 # credentials are used to forward connections over ssh to the Controller
250 # credentials are used to forward connections over ssh to the Controller
251 # Note that the ip given in `addr` needs to be relative to sshserver
251 # Note that the ip given in `addr` needs to be relative to sshserver
252 # The most basic case is to leave addr as pointing to localhost (127.0.0.1),
252 # The most basic case is to leave addr as pointing to localhost (127.0.0.1),
253 # and set sshserver as the same machine the Controller is on. However,
253 # and set sshserver as the same machine the Controller is on. However,
254 # the only requirement is that sshserver is able to see the Controller
254 # the only requirement is that sshserver is able to see the Controller
255 # (i.e. is within the same trusted network).
255 # (i.e. is within the same trusted network).
256
256
257 sshserver : str
257 sshserver : str
258 A string of the form passed to ssh, i.e. 'server.tld' or 'user@server.tld:port'
258 A string of the form passed to ssh, i.e. 'server.tld' or 'user@server.tld:port'
259 If keyfile or password is specified, and this is not, it will default to
259 If keyfile or password is specified, and this is not, it will default to
260 the ip given in addr.
260 the ip given in addr.
261 sshkey : str; path to ssh private key file
261 sshkey : str; path to ssh private key file
262 This specifies a key to be used in ssh login, default None.
262 This specifies a key to be used in ssh login, default None.
263 Regular default ssh keys will be used without specifying this argument.
263 Regular default ssh keys will be used without specifying this argument.
264 password : str
264 password : str
265 Your ssh password to sshserver. Note that if this is left None,
265 Your ssh password to sshserver. Note that if this is left None,
266 you will be prompted for it if passwordless key based login is unavailable.
266 you will be prompted for it if passwordless key based login is unavailable.
267 paramiko : bool
267 paramiko : bool
268 flag for whether to use paramiko instead of shell ssh for tunneling.
268 flag for whether to use paramiko instead of shell ssh for tunneling.
269 [default: True on win32, False else]
269 [default: True on win32, False else]
270
270
271
271
272 Attributes
272 Attributes
273 ----------
273 ----------
274
274
275 ids : list of int engine IDs
275 ids : list of int engine IDs
276 requesting the ids attribute always synchronizes
276 requesting the ids attribute always synchronizes
277 the registration state. To request ids without synchronization,
277 the registration state. To request ids without synchronization,
278 use semi-private _ids attributes.
278 use semi-private _ids attributes.
279
279
280 history : list of msg_ids
280 history : list of msg_ids
281 a list of msg_ids, keeping track of all the execution
281 a list of msg_ids, keeping track of all the execution
282 messages you have submitted in order.
282 messages you have submitted in order.
283
283
284 outstanding : set of msg_ids
284 outstanding : set of msg_ids
285 a set of msg_ids that have been submitted, but whose
285 a set of msg_ids that have been submitted, but whose
286 results have not yet been received.
286 results have not yet been received.
287
287
288 results : dict
288 results : dict
289 a dict of all our results, keyed by msg_id
289 a dict of all our results, keyed by msg_id
290
290
291 block : bool
291 block : bool
292 determines default behavior when block not specified
292 determines default behavior when block not specified
293 in execution methods
293 in execution methods
294
294
295 Methods
295 Methods
296 -------
296 -------
297
297
298 spin
298 spin
299 flushes incoming results and registration state changes
299 flushes incoming results and registration state changes
300 control methods spin, and requesting `ids` also ensures up to date
300 control methods spin, and requesting `ids` also ensures up to date
301
301
302 wait
302 wait
303 wait on one or more msg_ids
303 wait on one or more msg_ids
304
304
305 execution methods
305 execution methods
306 apply
306 apply
307 legacy: execute, run
307 legacy: execute, run
308
308
309 data movement
309 data movement
310 push, pull, scatter, gather
310 push, pull, scatter, gather
311
311
312 query methods
312 query methods
313 queue_status, get_result, purge, result_status
313 queue_status, get_result, purge, result_status
314
314
315 control methods
315 control methods
316 abort, shutdown
316 abort, shutdown
317
317
318 """
318 """
319
319
320
320
321 block = Bool(False)
321 block = Bool(False)
322 outstanding = Set()
322 outstanding = Set()
323 results = Instance('collections.defaultdict', (dict,))
323 results = Instance('collections.defaultdict', (dict,))
324 metadata = Instance('collections.defaultdict', (Metadata,))
324 metadata = Instance('collections.defaultdict', (Metadata,))
325 history = List()
325 history = List()
326 debug = Bool(False)
326 debug = Bool(False)
327 _spin_thread = Any()
327 _spin_thread = Any()
328 _stop_spinning = Any()
328 _stop_spinning = Any()
329
329
330 profile=Unicode()
330 profile=Unicode()
331 def _profile_default(self):
331 def _profile_default(self):
332 if BaseIPythonApplication.initialized():
332 if BaseIPythonApplication.initialized():
333 # an IPython app *might* be running, try to get its profile
333 # an IPython app *might* be running, try to get its profile
334 try:
334 try:
335 return BaseIPythonApplication.instance().profile
335 return BaseIPythonApplication.instance().profile
336 except (AttributeError, MultipleInstanceError):
336 except (AttributeError, MultipleInstanceError):
337 # could be a *different* subclass of config.Application,
337 # could be a *different* subclass of config.Application,
338 # which would raise one of these two errors.
338 # which would raise one of these two errors.
339 return u'default'
339 return u'default'
340 else:
340 else:
341 return u'default'
341 return u'default'
342
342
343
343
344 _outstanding_dict = Instance('collections.defaultdict', (set,))
344 _outstanding_dict = Instance('collections.defaultdict', (set,))
345 _ids = List()
345 _ids = List()
346 _connected=Bool(False)
346 _connected=Bool(False)
347 _ssh=Bool(False)
347 _ssh=Bool(False)
348 _context = Instance('zmq.Context')
348 _context = Instance('zmq.Context')
349 _config = Dict()
349 _config = Dict()
350 _engines=Instance(util.ReverseDict, (), {})
350 _engines=Instance(util.ReverseDict, (), {})
351 # _hub_socket=Instance('zmq.Socket')
351 # _hub_socket=Instance('zmq.Socket')
352 _query_socket=Instance('zmq.Socket')
352 _query_socket=Instance('zmq.Socket')
353 _control_socket=Instance('zmq.Socket')
353 _control_socket=Instance('zmq.Socket')
354 _iopub_socket=Instance('zmq.Socket')
354 _iopub_socket=Instance('zmq.Socket')
355 _notification_socket=Instance('zmq.Socket')
355 _notification_socket=Instance('zmq.Socket')
356 _mux_socket=Instance('zmq.Socket')
356 _mux_socket=Instance('zmq.Socket')
357 _task_socket=Instance('zmq.Socket')
357 _task_socket=Instance('zmq.Socket')
358 _task_scheme=Unicode()
358 _task_scheme=Unicode()
359 _closed = False
359 _closed = False
360 _ignored_control_replies=Integer(0)
360 _ignored_control_replies=Integer(0)
361 _ignored_hub_replies=Integer(0)
361 _ignored_hub_replies=Integer(0)
362
362
363 def __new__(self, *args, **kw):
363 def __new__(self, *args, **kw):
364 # don't raise on positional args
364 # don't raise on positional args
365 return HasTraits.__new__(self, **kw)
365 return HasTraits.__new__(self, **kw)
366
366
367 def __init__(self, url_file=None, profile=None, profile_dir=None, ipython_dir=None,
367 def __init__(self, url_file=None, profile=None, profile_dir=None, ipython_dir=None,
368 context=None, debug=False,
368 context=None, debug=False,
369 sshserver=None, sshkey=None, password=None, paramiko=None,
369 sshserver=None, sshkey=None, password=None, paramiko=None,
370 timeout=10, cluster_id=None, **extra_args
370 timeout=10, cluster_id=None, **extra_args
371 ):
371 ):
372 if profile:
372 if profile:
373 super(Client, self).__init__(debug=debug, profile=profile)
373 super(Client, self).__init__(debug=debug, profile=profile)
374 else:
374 else:
375 super(Client, self).__init__(debug=debug)
375 super(Client, self).__init__(debug=debug)
376 if context is None:
376 if context is None:
377 context = zmq.Context.instance()
377 context = zmq.Context.instance()
378 self._context = context
378 self._context = context
379 self._stop_spinning = Event()
379 self._stop_spinning = Event()
380
380
381 if 'url_or_file' in extra_args:
381 if 'url_or_file' in extra_args:
382 url_file = extra_args['url_or_file']
382 url_file = extra_args['url_or_file']
383 warnings.warn("url_or_file arg no longer supported, use url_file", DeprecationWarning)
383 warnings.warn("url_or_file arg no longer supported, use url_file", DeprecationWarning)
384
384
385 if url_file and util.is_url(url_file):
385 if url_file and util.is_url(url_file):
386 raise ValueError("single urls cannot be specified, url-files must be used.")
386 raise ValueError("single urls cannot be specified, url-files must be used.")
387
387
388 self._setup_profile_dir(self.profile, profile_dir, ipython_dir)
388 self._setup_profile_dir(self.profile, profile_dir, ipython_dir)
389
389
390 no_file_msg = '\n'.join([
390 no_file_msg = '\n'.join([
391 "You have attempted to connect to an IPython Cluster but no Controller could be found.",
391 "You have attempted to connect to an IPython Cluster but no Controller could be found.",
392 "Please double-check your configuration and ensure that a cluster is running.",
392 "Please double-check your configuration and ensure that a cluster is running.",
393 ])
393 ])
394
394
395 if self._cd is not None:
395 if self._cd is not None:
396 if url_file is None:
396 if url_file is None:
397 if not cluster_id:
397 if not cluster_id:
398 client_json = 'ipcontroller-client.json'
398 client_json = 'ipcontroller-client.json'
399 else:
399 else:
400 client_json = 'ipcontroller-%s-client.json' % cluster_id
400 client_json = 'ipcontroller-%s-client.json' % cluster_id
401 url_file = pjoin(self._cd.security_dir, client_json)
401 url_file = pjoin(self._cd.security_dir, client_json)
402 if not os.path.exists(url_file):
402 if not os.path.exists(url_file):
403 msg = '\n'.join([
403 msg = '\n'.join([
404 "Connection file %r not found." % compress_user(url_file),
404 "Connection file %r not found." % compress_user(url_file),
405 no_file_msg,
405 no_file_msg,
406 ])
406 ])
407 raise IOError(msg)
407 raise IOError(msg)
408 if url_file is None:
408 if url_file is None:
409 raise IOError(no_file_msg)
409 raise IOError(no_file_msg)
410
410
411 if not os.path.exists(url_file):
411 if not os.path.exists(url_file):
412 # Connection file explicitly specified, but not found
412 # Connection file explicitly specified, but not found
413 raise IOError("Connection file %r not found. Is a controller running?" % \
413 raise IOError("Connection file %r not found. Is a controller running?" % \
414 compress_user(url_file)
414 compress_user(url_file)
415 )
415 )
416
416
417 with open(url_file) as f:
417 with open(url_file) as f:
418 cfg = json.load(f)
418 cfg = json.load(f)
419
419
420 self._task_scheme = cfg['task_scheme']
420 self._task_scheme = cfg['task_scheme']
421
421
422 # sync defaults from args, json:
422 # sync defaults from args, json:
423 if sshserver:
423 if sshserver:
424 cfg['ssh'] = sshserver
424 cfg['ssh'] = sshserver
425
425
426 location = cfg.setdefault('location', None)
426 location = cfg.setdefault('location', None)
427
427
428 proto,addr = cfg['interface'].split('://')
428 proto,addr = cfg['interface'].split('://')
429 addr = util.disambiguate_ip_address(addr, location)
429 addr = util.disambiguate_ip_address(addr, location)
430 cfg['interface'] = "%s://%s" % (proto, addr)
430 cfg['interface'] = "%s://%s" % (proto, addr)
431
431
432 # turn interface,port into full urls:
432 # turn interface,port into full urls:
433 for key in ('control', 'task', 'mux', 'iopub', 'notification', 'registration'):
433 for key in ('control', 'task', 'mux', 'iopub', 'notification', 'registration'):
434 cfg[key] = cfg['interface'] + ':%i' % cfg[key]
434 cfg[key] = cfg['interface'] + ':%i' % cfg[key]
435
435
436 url = cfg['registration']
436 url = cfg['registration']
437
437
438 if location is not None and addr == localhost():
438 if location is not None and addr == localhost():
439 # location specified, and connection is expected to be local
439 # location specified, and connection is expected to be local
440 if not is_local_ip(location) and not sshserver:
440 if not is_local_ip(location) and not sshserver:
441 # load ssh from JSON *only* if the controller is not on
441 # load ssh from JSON *only* if the controller is not on
442 # this machine
442 # this machine
443 sshserver=cfg['ssh']
443 sshserver=cfg['ssh']
444 if not is_local_ip(location) and not sshserver:
444 if not is_local_ip(location) and not sshserver:
445 # warn if no ssh specified, but SSH is probably needed
445 # warn if no ssh specified, but SSH is probably needed
446 # This is only a warning, because the most likely cause
446 # This is only a warning, because the most likely cause
447 # is a local Controller on a laptop whose IP is dynamic
447 # is a local Controller on a laptop whose IP is dynamic
448 warnings.warn("""
448 warnings.warn("""
449 Controller appears to be listening on localhost, but not on this machine.
449 Controller appears to be listening on localhost, but not on this machine.
450 If this is true, you should specify Client(...,sshserver='you@%s')
450 If this is true, you should specify Client(...,sshserver='you@%s')
451 or instruct your controller to listen on an external IP."""%location,
451 or instruct your controller to listen on an external IP."""%location,
452 RuntimeWarning)
452 RuntimeWarning)
453 elif not sshserver:
453 elif not sshserver:
454 # otherwise sync with cfg
454 # otherwise sync with cfg
455 sshserver = cfg['ssh']
455 sshserver = cfg['ssh']
456
456
457 self._config = cfg
457 self._config = cfg
458
458
459 self._ssh = bool(sshserver or sshkey or password)
459 self._ssh = bool(sshserver or sshkey or password)
460 if self._ssh and sshserver is None:
460 if self._ssh and sshserver is None:
461 # default to ssh via localhost
461 # default to ssh via localhost
462 sshserver = addr
462 sshserver = addr
463 if self._ssh and password is None:
463 if self._ssh and password is None:
464 from zmq.ssh import tunnel
464 from zmq.ssh import tunnel
465 if tunnel.try_passwordless_ssh(sshserver, sshkey, paramiko):
465 if tunnel.try_passwordless_ssh(sshserver, sshkey, paramiko):
466 password=False
466 password=False
467 else:
467 else:
468 password = getpass("SSH Password for %s: "%sshserver)
468 password = getpass("SSH Password for %s: "%sshserver)
469 ssh_kwargs = dict(keyfile=sshkey, password=password, paramiko=paramiko)
469 ssh_kwargs = dict(keyfile=sshkey, password=password, paramiko=paramiko)
470
470
471 # configure and construct the session
471 # configure and construct the session
472 try:
472 try:
473 extra_args['packer'] = cfg['pack']
473 extra_args['packer'] = cfg['pack']
474 extra_args['unpacker'] = cfg['unpack']
474 extra_args['unpacker'] = cfg['unpack']
475 extra_args['key'] = cast_bytes(cfg['key'])
475 extra_args['key'] = cast_bytes(cfg['key'])
476 extra_args['signature_scheme'] = cfg['signature_scheme']
476 extra_args['signature_scheme'] = cfg['signature_scheme']
477 except KeyError as exc:
477 except KeyError as exc:
478 msg = '\n'.join([
478 msg = '\n'.join([
479 "Connection file is invalid (missing '{}'), possibly from an old version of IPython.",
479 "Connection file is invalid (missing '{}'), possibly from an old version of IPython.",
480 "If you are reusing connection files, remove them and start ipcontroller again."
480 "If you are reusing connection files, remove them and start ipcontroller again."
481 ])
481 ])
482 raise ValueError(msg.format(exc.message))
482 raise ValueError(msg.format(exc.message))
483
483
484 self.session = Session(**extra_args)
484 self.session = Session(**extra_args)
485
485
486 self._query_socket = self._context.socket(zmq.DEALER)
486 self._query_socket = self._context.socket(zmq.DEALER)
487
487
488 if self._ssh:
488 if self._ssh:
489 from zmq.ssh import tunnel
489 from zmq.ssh import tunnel
490 tunnel.tunnel_connection(self._query_socket, cfg['registration'], sshserver, **ssh_kwargs)
490 tunnel.tunnel_connection(self._query_socket, cfg['registration'], sshserver, **ssh_kwargs)
491 else:
491 else:
492 self._query_socket.connect(cfg['registration'])
492 self._query_socket.connect(cfg['registration'])
493
493
494 self.session.debug = self.debug
494 self.session.debug = self.debug
495
495
496 self._notification_handlers = {'registration_notification' : self._register_engine,
496 self._notification_handlers = {'registration_notification' : self._register_engine,
497 'unregistration_notification' : self._unregister_engine,
497 'unregistration_notification' : self._unregister_engine,
498 'shutdown_notification' : lambda msg: self.close(),
498 'shutdown_notification' : lambda msg: self.close(),
499 }
499 }
500 self._queue_handlers = {'execute_reply' : self._handle_execute_reply,
500 self._queue_handlers = {'execute_reply' : self._handle_execute_reply,
501 'apply_reply' : self._handle_apply_reply}
501 'apply_reply' : self._handle_apply_reply}
502
502
503 try:
503 try:
504 self._connect(sshserver, ssh_kwargs, timeout)
504 self._connect(sshserver, ssh_kwargs, timeout)
505 except:
505 except:
506 self.close(linger=0)
506 self.close(linger=0)
507 raise
507 raise
508
508
509 # last step: setup magics, if we are in IPython:
509 # last step: setup magics, if we are in IPython:
510
510
511 try:
511 try:
512 ip = get_ipython()
512 ip = get_ipython()
513 except NameError:
513 except NameError:
514 return
514 return
515 else:
515 else:
516 if 'px' not in ip.magics_manager.magics:
516 if 'px' not in ip.magics_manager.magics:
517 # in IPython but we are the first Client.
517 # in IPython but we are the first Client.
518 # activate a default view for parallel magics.
518 # activate a default view for parallel magics.
519 self.activate()
519 self.activate()
520
520
521 def __del__(self):
521 def __del__(self):
522 """cleanup sockets, but _not_ context."""
522 """cleanup sockets, but _not_ context."""
523 self.close()
523 self.close()
524
524
525 def _setup_profile_dir(self, profile, profile_dir, ipython_dir):
525 def _setup_profile_dir(self, profile, profile_dir, ipython_dir):
526 if ipython_dir is None:
526 if ipython_dir is None:
527 ipython_dir = get_ipython_dir()
527 ipython_dir = get_ipython_dir()
528 if profile_dir is not None:
528 if profile_dir is not None:
529 try:
529 try:
530 self._cd = ProfileDir.find_profile_dir(profile_dir)
530 self._cd = ProfileDir.find_profile_dir(profile_dir)
531 return
531 return
532 except ProfileDirError:
532 except ProfileDirError:
533 pass
533 pass
534 elif profile is not None:
534 elif profile is not None:
535 try:
535 try:
536 self._cd = ProfileDir.find_profile_dir_by_name(
536 self._cd = ProfileDir.find_profile_dir_by_name(
537 ipython_dir, profile)
537 ipython_dir, profile)
538 return
538 return
539 except ProfileDirError:
539 except ProfileDirError:
540 pass
540 pass
541 self._cd = None
541 self._cd = None
542
542
543 def _update_engines(self, engines):
543 def _update_engines(self, engines):
544 """Update our engines dict and _ids from a dict of the form: {id:uuid}."""
544 """Update our engines dict and _ids from a dict of the form: {id:uuid}."""
545 for k,v in iteritems(engines):
545 for k,v in iteritems(engines):
546 eid = int(k)
546 eid = int(k)
547 if eid not in self._engines:
547 if eid not in self._engines:
548 self._ids.append(eid)
548 self._ids.append(eid)
549 self._engines[eid] = v
549 self._engines[eid] = v
550 self._ids = sorted(self._ids)
550 self._ids = sorted(self._ids)
551 if sorted(self._engines.keys()) != list(range(len(self._engines))) and \
551 if sorted(self._engines.keys()) != list(range(len(self._engines))) and \
552 self._task_scheme == 'pure' and self._task_socket:
552 self._task_scheme == 'pure' and self._task_socket:
553 self._stop_scheduling_tasks()
553 self._stop_scheduling_tasks()
554
554
555 def _stop_scheduling_tasks(self):
555 def _stop_scheduling_tasks(self):
556 """Stop scheduling tasks because an engine has been unregistered
556 """Stop scheduling tasks because an engine has been unregistered
557 from a pure ZMQ scheduler.
557 from a pure ZMQ scheduler.
558 """
558 """
559 self._task_socket.close()
559 self._task_socket.close()
560 self._task_socket = None
560 self._task_socket = None
561 msg = "An engine has been unregistered, and we are using pure " +\
561 msg = "An engine has been unregistered, and we are using pure " +\
562 "ZMQ task scheduling. Task farming will be disabled."
562 "ZMQ task scheduling. Task farming will be disabled."
563 if self.outstanding:
563 if self.outstanding:
564 msg += " If you were running tasks when this happened, " +\
564 msg += " If you were running tasks when this happened, " +\
565 "some `outstanding` msg_ids may never resolve."
565 "some `outstanding` msg_ids may never resolve."
566 warnings.warn(msg, RuntimeWarning)
566 warnings.warn(msg, RuntimeWarning)
567
567
568 def _build_targets(self, targets):
568 def _build_targets(self, targets):
569 """Turn valid target IDs or 'all' into two lists:
569 """Turn valid target IDs or 'all' into two lists:
570 (int_ids, uuids).
570 (int_ids, uuids).
571 """
571 """
572 if not self._ids:
572 if not self._ids:
573 # flush notification socket if no engines yet, just in case
573 # flush notification socket if no engines yet, just in case
574 if not self.ids:
574 if not self.ids:
575 raise error.NoEnginesRegistered("Can't build targets without any engines")
575 raise error.NoEnginesRegistered("Can't build targets without any engines")
576
576
577 if targets is None:
577 if targets is None:
578 targets = self._ids
578 targets = self._ids
579 elif isinstance(targets, string_types):
579 elif isinstance(targets, string_types):
580 if targets.lower() == 'all':
580 if targets.lower() == 'all':
581 targets = self._ids
581 targets = self._ids
582 else:
582 else:
583 raise TypeError("%r not valid str target, must be 'all'"%(targets))
583 raise TypeError("%r not valid str target, must be 'all'"%(targets))
584 elif isinstance(targets, int):
584 elif isinstance(targets, int):
585 if targets < 0:
585 if targets < 0:
586 targets = self.ids[targets]
586 targets = self.ids[targets]
587 if targets not in self._ids:
587 if targets not in self._ids:
588 raise IndexError("No such engine: %i"%targets)
588 raise IndexError("No such engine: %i"%targets)
589 targets = [targets]
589 targets = [targets]
590
590
591 if isinstance(targets, slice):
591 if isinstance(targets, slice):
592 indices = list(range(len(self._ids))[targets])
592 indices = list(range(len(self._ids))[targets])
593 ids = self.ids
593 ids = self.ids
594 targets = [ ids[i] for i in indices ]
594 targets = [ ids[i] for i in indices ]
595
595
596 if not isinstance(targets, (tuple, list, xrange)):
596 if not isinstance(targets, (tuple, list, xrange)):
597 raise TypeError("targets by int/slice/collection of ints only, not %s"%(type(targets)))
597 raise TypeError("targets by int/slice/collection of ints only, not %s"%(type(targets)))
598
598
599 return [cast_bytes(self._engines[t]) for t in targets], list(targets)
599 return [cast_bytes(self._engines[t]) for t in targets], list(targets)
600
600
601 def _connect(self, sshserver, ssh_kwargs, timeout):
601 def _connect(self, sshserver, ssh_kwargs, timeout):
602 """setup all our socket connections to the cluster. This is called from
602 """setup all our socket connections to the cluster. This is called from
603 __init__."""
603 __init__."""
604
604
605 # Maybe allow reconnecting?
605 # Maybe allow reconnecting?
606 if self._connected:
606 if self._connected:
607 return
607 return
608 self._connected=True
608 self._connected=True
609
609
610 def connect_socket(s, url):
610 def connect_socket(s, url):
611 if self._ssh:
611 if self._ssh:
612 from zmq.ssh import tunnel
612 from zmq.ssh import tunnel
613 return tunnel.tunnel_connection(s, url, sshserver, **ssh_kwargs)
613 return tunnel.tunnel_connection(s, url, sshserver, **ssh_kwargs)
614 else:
614 else:
615 return s.connect(url)
615 return s.connect(url)
616
616
617 self.session.send(self._query_socket, 'connection_request')
617 self.session.send(self._query_socket, 'connection_request')
618 # use Poller because zmq.select has wrong units in pyzmq 2.1.7
618 # use Poller because zmq.select has wrong units in pyzmq 2.1.7
619 poller = zmq.Poller()
619 poller = zmq.Poller()
620 poller.register(self._query_socket, zmq.POLLIN)
620 poller.register(self._query_socket, zmq.POLLIN)
621 # poll expects milliseconds, timeout is seconds
621 # poll expects milliseconds, timeout is seconds
622 evts = poller.poll(timeout*1000)
622 evts = poller.poll(timeout*1000)
623 if not evts:
623 if not evts:
624 raise error.TimeoutError("Hub connection request timed out")
624 raise error.TimeoutError("Hub connection request timed out")
625 idents,msg = self.session.recv(self._query_socket,mode=0)
625 idents,msg = self.session.recv(self._query_socket,mode=0)
626 if self.debug:
626 if self.debug:
627 pprint(msg)
627 pprint(msg)
628 content = msg['content']
628 content = msg['content']
629 # self._config['registration'] = dict(content)
629 # self._config['registration'] = dict(content)
630 cfg = self._config
630 cfg = self._config
631 if content['status'] == 'ok':
631 if content['status'] == 'ok':
632 self._mux_socket = self._context.socket(zmq.DEALER)
632 self._mux_socket = self._context.socket(zmq.DEALER)
633 connect_socket(self._mux_socket, cfg['mux'])
633 connect_socket(self._mux_socket, cfg['mux'])
634
634
635 self._task_socket = self._context.socket(zmq.DEALER)
635 self._task_socket = self._context.socket(zmq.DEALER)
636 connect_socket(self._task_socket, cfg['task'])
636 connect_socket(self._task_socket, cfg['task'])
637
637
638 self._notification_socket = self._context.socket(zmq.SUB)
638 self._notification_socket = self._context.socket(zmq.SUB)
639 self._notification_socket.setsockopt(zmq.SUBSCRIBE, b'')
639 self._notification_socket.setsockopt(zmq.SUBSCRIBE, b'')
640 connect_socket(self._notification_socket, cfg['notification'])
640 connect_socket(self._notification_socket, cfg['notification'])
641
641
642 self._control_socket = self._context.socket(zmq.DEALER)
642 self._control_socket = self._context.socket(zmq.DEALER)
643 connect_socket(self._control_socket, cfg['control'])
643 connect_socket(self._control_socket, cfg['control'])
644
644
645 self._iopub_socket = self._context.socket(zmq.SUB)
645 self._iopub_socket = self._context.socket(zmq.SUB)
646 self._iopub_socket.setsockopt(zmq.SUBSCRIBE, b'')
646 self._iopub_socket.setsockopt(zmq.SUBSCRIBE, b'')
647 connect_socket(self._iopub_socket, cfg['iopub'])
647 connect_socket(self._iopub_socket, cfg['iopub'])
648
648
649 self._update_engines(dict(content['engines']))
649 self._update_engines(dict(content['engines']))
650 else:
650 else:
651 self._connected = False
651 self._connected = False
652 raise Exception("Failed to connect!")
652 raise Exception("Failed to connect!")
653
653
654 #--------------------------------------------------------------------------
654 #--------------------------------------------------------------------------
655 # handlers and callbacks for incoming messages
655 # handlers and callbacks for incoming messages
656 #--------------------------------------------------------------------------
656 #--------------------------------------------------------------------------
657
657
658 def _unwrap_exception(self, content):
658 def _unwrap_exception(self, content):
659 """unwrap exception, and remap engine_id to int."""
659 """unwrap exception, and remap engine_id to int."""
660 e = error.unwrap_exception(content)
660 e = error.unwrap_exception(content)
661 # print e.traceback
661 # print e.traceback
662 if e.engine_info:
662 if e.engine_info:
663 e_uuid = e.engine_info['engine_uuid']
663 e_uuid = e.engine_info['engine_uuid']
664 eid = self._engines[e_uuid]
664 eid = self._engines[e_uuid]
665 e.engine_info['engine_id'] = eid
665 e.engine_info['engine_id'] = eid
666 return e
666 return e
667
667
668 def _extract_metadata(self, msg):
668 def _extract_metadata(self, msg):
669 header = msg['header']
669 header = msg['header']
670 parent = msg['parent_header']
670 parent = msg['parent_header']
671 msg_meta = msg['metadata']
671 msg_meta = msg['metadata']
672 content = msg['content']
672 content = msg['content']
673 md = {'msg_id' : parent['msg_id'],
673 md = {'msg_id' : parent['msg_id'],
674 'received' : datetime.now(),
674 'received' : datetime.now(),
675 'engine_uuid' : msg_meta.get('engine', None),
675 'engine_uuid' : msg_meta.get('engine', None),
676 'follow' : msg_meta.get('follow', []),
676 'follow' : msg_meta.get('follow', []),
677 'after' : msg_meta.get('after', []),
677 'after' : msg_meta.get('after', []),
678 'status' : content['status'],
678 'status' : content['status'],
679 }
679 }
680
680
681 if md['engine_uuid'] is not None:
681 if md['engine_uuid'] is not None:
682 md['engine_id'] = self._engines.get(md['engine_uuid'], None)
682 md['engine_id'] = self._engines.get(md['engine_uuid'], None)
683
683
684 if 'date' in parent:
684 if 'date' in parent:
685 md['submitted'] = parent['date']
685 md['submitted'] = parent['date']
686 if 'started' in msg_meta:
686 if 'started' in msg_meta:
687 md['started'] = parse_date(msg_meta['started'])
687 md['started'] = parse_date(msg_meta['started'])
688 if 'date' in header:
688 if 'date' in header:
689 md['completed'] = header['date']
689 md['completed'] = header['date']
690 return md
690 return md
691
691
692 def _register_engine(self, msg):
692 def _register_engine(self, msg):
693 """Register a new engine, and update our connection info."""
693 """Register a new engine, and update our connection info."""
694 content = msg['content']
694 content = msg['content']
695 eid = content['id']
695 eid = content['id']
696 d = {eid : content['uuid']}
696 d = {eid : content['uuid']}
697 self._update_engines(d)
697 self._update_engines(d)
698
698
699 def _unregister_engine(self, msg):
699 def _unregister_engine(self, msg):
700 """Unregister an engine that has died."""
700 """Unregister an engine that has died."""
701 content = msg['content']
701 content = msg['content']
702 eid = int(content['id'])
702 eid = int(content['id'])
703 if eid in self._ids:
703 if eid in self._ids:
704 self._ids.remove(eid)
704 self._ids.remove(eid)
705 uuid = self._engines.pop(eid)
705 uuid = self._engines.pop(eid)
706
706
707 self._handle_stranded_msgs(eid, uuid)
707 self._handle_stranded_msgs(eid, uuid)
708
708
709 if self._task_socket and self._task_scheme == 'pure':
709 if self._task_socket and self._task_scheme == 'pure':
710 self._stop_scheduling_tasks()
710 self._stop_scheduling_tasks()
711
711
712 def _handle_stranded_msgs(self, eid, uuid):
712 def _handle_stranded_msgs(self, eid, uuid):
713 """Handle messages known to be on an engine when the engine unregisters.
713 """Handle messages known to be on an engine when the engine unregisters.
714
714
715 It is possible that this will fire prematurely - that is, an engine will
715 It is possible that this will fire prematurely - that is, an engine will
716 go down after completing a result, and the client will be notified
716 go down after completing a result, and the client will be notified
717 of the unregistration and later receive the successful result.
717 of the unregistration and later receive the successful result.
718 """
718 """
719
719
720 outstanding = self._outstanding_dict[uuid]
720 outstanding = self._outstanding_dict[uuid]
721
721
722 for msg_id in list(outstanding):
722 for msg_id in list(outstanding):
723 if msg_id in self.results:
723 if msg_id in self.results:
724 # we already
724 # we already
725 continue
725 continue
726 try:
726 try:
727 raise error.EngineError("Engine %r died while running task %r"%(eid, msg_id))
727 raise error.EngineError("Engine %r died while running task %r"%(eid, msg_id))
728 except:
728 except:
729 content = error.wrap_exception()
729 content = error.wrap_exception()
730 # build a fake message:
730 # build a fake message:
731 msg = self.session.msg('apply_reply', content=content)
731 msg = self.session.msg('apply_reply', content=content)
732 msg['parent_header']['msg_id'] = msg_id
732 msg['parent_header']['msg_id'] = msg_id
733 msg['metadata']['engine'] = uuid
733 msg['metadata']['engine'] = uuid
734 self._handle_apply_reply(msg)
734 self._handle_apply_reply(msg)
735
735
736 def _handle_execute_reply(self, msg):
736 def _handle_execute_reply(self, msg):
737 """Save the reply to an execute_request into our results.
737 """Save the reply to an execute_request into our results.
738
738
739 execute messages are never actually used. apply is used instead.
739 execute messages are never actually used. apply is used instead.
740 """
740 """
741
741
742 parent = msg['parent_header']
742 parent = msg['parent_header']
743 msg_id = parent['msg_id']
743 msg_id = parent['msg_id']
744 if msg_id not in self.outstanding:
744 if msg_id not in self.outstanding:
745 if msg_id in self.history:
745 if msg_id in self.history:
746 print("got stale result: %s"%msg_id)
746 print("got stale result: %s"%msg_id)
747 else:
747 else:
748 print("got unknown result: %s"%msg_id)
748 print("got unknown result: %s"%msg_id)
749 else:
749 else:
750 self.outstanding.remove(msg_id)
750 self.outstanding.remove(msg_id)
751
751
752 content = msg['content']
752 content = msg['content']
753 header = msg['header']
753 header = msg['header']
754
754
755 # construct metadata:
755 # construct metadata:
756 md = self.metadata[msg_id]
756 md = self.metadata[msg_id]
757 md.update(self._extract_metadata(msg))
757 md.update(self._extract_metadata(msg))
758 # is this redundant?
758 # is this redundant?
759 self.metadata[msg_id] = md
759 self.metadata[msg_id] = md
760
760
761 e_outstanding = self._outstanding_dict[md['engine_uuid']]
761 e_outstanding = self._outstanding_dict[md['engine_uuid']]
762 if msg_id in e_outstanding:
762 if msg_id in e_outstanding:
763 e_outstanding.remove(msg_id)
763 e_outstanding.remove(msg_id)
764
764
765 # construct result:
765 # construct result:
766 if content['status'] == 'ok':
766 if content['status'] == 'ok':
767 self.results[msg_id] = ExecuteReply(msg_id, content, md)
767 self.results[msg_id] = ExecuteReply(msg_id, content, md)
768 elif content['status'] == 'aborted':
768 elif content['status'] == 'aborted':
769 self.results[msg_id] = error.TaskAborted(msg_id)
769 self.results[msg_id] = error.TaskAborted(msg_id)
770 elif content['status'] == 'resubmitted':
770 elif content['status'] == 'resubmitted':
771 # TODO: handle resubmission
771 # TODO: handle resubmission
772 pass
772 pass
773 else:
773 else:
774 self.results[msg_id] = self._unwrap_exception(content)
774 self.results[msg_id] = self._unwrap_exception(content)
775
775
776 def _handle_apply_reply(self, msg):
776 def _handle_apply_reply(self, msg):
777 """Save the reply to an apply_request into our results."""
777 """Save the reply to an apply_request into our results."""
778 parent = msg['parent_header']
778 parent = msg['parent_header']
779 msg_id = parent['msg_id']
779 msg_id = parent['msg_id']
780 if msg_id not in self.outstanding:
780 if msg_id not in self.outstanding:
781 if msg_id in self.history:
781 if msg_id in self.history:
782 print("got stale result: %s"%msg_id)
782 print("got stale result: %s"%msg_id)
783 print(self.results[msg_id])
783 print(self.results[msg_id])
784 print(msg)
784 print(msg)
785 else:
785 else:
786 print("got unknown result: %s"%msg_id)
786 print("got unknown result: %s"%msg_id)
787 else:
787 else:
788 self.outstanding.remove(msg_id)
788 self.outstanding.remove(msg_id)
789 content = msg['content']
789 content = msg['content']
790 header = msg['header']
790 header = msg['header']
791
791
792 # construct metadata:
792 # construct metadata:
793 md = self.metadata[msg_id]
793 md = self.metadata[msg_id]
794 md.update(self._extract_metadata(msg))
794 md.update(self._extract_metadata(msg))
795 # is this redundant?
795 # is this redundant?
796 self.metadata[msg_id] = md
796 self.metadata[msg_id] = md
797
797
798 e_outstanding = self._outstanding_dict[md['engine_uuid']]
798 e_outstanding = self._outstanding_dict[md['engine_uuid']]
799 if msg_id in e_outstanding:
799 if msg_id in e_outstanding:
800 e_outstanding.remove(msg_id)
800 e_outstanding.remove(msg_id)
801
801
802 # construct result:
802 # construct result:
803 if content['status'] == 'ok':
803 if content['status'] == 'ok':
804 self.results[msg_id] = serialize.deserialize_object(msg['buffers'])[0]
804 self.results[msg_id] = serialize.deserialize_object(msg['buffers'])[0]
805 elif content['status'] == 'aborted':
805 elif content['status'] == 'aborted':
806 self.results[msg_id] = error.TaskAborted(msg_id)
806 self.results[msg_id] = error.TaskAborted(msg_id)
807 elif content['status'] == 'resubmitted':
807 elif content['status'] == 'resubmitted':
808 # TODO: handle resubmission
808 # TODO: handle resubmission
809 pass
809 pass
810 else:
810 else:
811 self.results[msg_id] = self._unwrap_exception(content)
811 self.results[msg_id] = self._unwrap_exception(content)
812
812
813 def _flush_notifications(self):
813 def _flush_notifications(self):
814 """Flush notifications of engine registrations waiting
814 """Flush notifications of engine registrations waiting
815 in ZMQ queue."""
815 in ZMQ queue."""
816 idents,msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
816 idents,msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
817 while msg is not None:
817 while msg is not None:
818 if self.debug:
818 if self.debug:
819 pprint(msg)
819 pprint(msg)
820 msg_type = msg['header']['msg_type']
820 msg_type = msg['header']['msg_type']
821 handler = self._notification_handlers.get(msg_type, None)
821 handler = self._notification_handlers.get(msg_type, None)
822 if handler is None:
822 if handler is None:
823 raise Exception("Unhandled message type: %s" % msg_type)
823 raise Exception("Unhandled message type: %s" % msg_type)
824 else:
824 else:
825 handler(msg)
825 handler(msg)
826 idents,msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
826 idents,msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
827
827
828 def _flush_results(self, sock):
828 def _flush_results(self, sock):
829 """Flush task or queue results waiting in ZMQ queue."""
829 """Flush task or queue results waiting in ZMQ queue."""
830 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
830 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
831 while msg is not None:
831 while msg is not None:
832 if self.debug:
832 if self.debug:
833 pprint(msg)
833 pprint(msg)
834 msg_type = msg['header']['msg_type']
834 msg_type = msg['header']['msg_type']
835 handler = self._queue_handlers.get(msg_type, None)
835 handler = self._queue_handlers.get(msg_type, None)
836 if handler is None:
836 if handler is None:
837 raise Exception("Unhandled message type: %s" % msg_type)
837 raise Exception("Unhandled message type: %s" % msg_type)
838 else:
838 else:
839 handler(msg)
839 handler(msg)
840 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
840 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
841
841
842 def _flush_control(self, sock):
842 def _flush_control(self, sock):
843 """Flush replies from the control channel waiting
843 """Flush replies from the control channel waiting
844 in the ZMQ queue.
844 in the ZMQ queue.
845
845
846 Currently: ignore them."""
846 Currently: ignore them."""
847 if self._ignored_control_replies <= 0:
847 if self._ignored_control_replies <= 0:
848 return
848 return
849 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
849 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
850 while msg is not None:
850 while msg is not None:
851 self._ignored_control_replies -= 1
851 self._ignored_control_replies -= 1
852 if self.debug:
852 if self.debug:
853 pprint(msg)
853 pprint(msg)
854 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
854 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
855
855
856 def _flush_ignored_control(self):
856 def _flush_ignored_control(self):
857 """flush ignored control replies"""
857 """flush ignored control replies"""
858 while self._ignored_control_replies > 0:
858 while self._ignored_control_replies > 0:
859 self.session.recv(self._control_socket)
859 self.session.recv(self._control_socket)
860 self._ignored_control_replies -= 1
860 self._ignored_control_replies -= 1
861
861
862 def _flush_ignored_hub_replies(self):
862 def _flush_ignored_hub_replies(self):
863 ident,msg = self.session.recv(self._query_socket, mode=zmq.NOBLOCK)
863 ident,msg = self.session.recv(self._query_socket, mode=zmq.NOBLOCK)
864 while msg is not None:
864 while msg is not None:
865 ident,msg = self.session.recv(self._query_socket, mode=zmq.NOBLOCK)
865 ident,msg = self.session.recv(self._query_socket, mode=zmq.NOBLOCK)
866
866
867 def _flush_iopub(self, sock):
867 def _flush_iopub(self, sock):
868 """Flush replies from the iopub channel waiting
868 """Flush replies from the iopub channel waiting
869 in the ZMQ queue.
869 in the ZMQ queue.
870 """
870 """
871 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
871 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
872 while msg is not None:
872 while msg is not None:
873 if self.debug:
873 if self.debug:
874 pprint(msg)
874 pprint(msg)
875 parent = msg['parent_header']
875 parent = msg['parent_header']
876 if not parent or parent['session'] != self.session.session:
876 if not parent or parent['session'] != self.session.session:
877 # ignore IOPub messages not from here
877 # ignore IOPub messages not from here
878 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
878 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
879 continue
879 continue
880 msg_id = parent['msg_id']
880 msg_id = parent['msg_id']
881 content = msg['content']
881 content = msg['content']
882 header = msg['header']
882 header = msg['header']
883 msg_type = msg['header']['msg_type']
883 msg_type = msg['header']['msg_type']
884
884
885 if msg_type == 'status' and msg_id not in self.metadata:
885 if msg_type == 'status' and msg_id not in self.metadata:
886 # ignore status messages if they aren't mine
886 # ignore status messages if they aren't mine
887 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
887 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
888 continue
888 continue
889
889
890 # init metadata:
890 # init metadata:
891 md = self.metadata[msg_id]
891 md = self.metadata[msg_id]
892
892
893 if msg_type == 'stream':
893 if msg_type == 'stream':
894 name = content['name']
894 name = content['name']
895 s = md[name] or ''
895 s = md[name] or ''
896 md[name] = s + content['text']
896 md[name] = s + content['text']
897 elif msg_type == 'error':
897 elif msg_type == 'error':
898 md.update({'error' : self._unwrap_exception(content)})
898 md.update({'error' : self._unwrap_exception(content)})
899 elif msg_type == 'execute_input':
899 elif msg_type == 'execute_input':
900 md.update({'execute_input' : content['code']})
900 md.update({'execute_input' : content['code']})
901 elif msg_type == 'display_data':
901 elif msg_type == 'display_data':
902 md['outputs'].append(content)
902 md['outputs'].append(content)
903 elif msg_type == 'execute_result':
903 elif msg_type == 'execute_result':
904 md['execute_result'] = content
904 md['execute_result'] = content
905 elif msg_type == 'data_message':
905 elif msg_type == 'data_message':
906 data, remainder = serialize.deserialize_object(msg['buffers'])
906 data, remainder = serialize.deserialize_object(msg['buffers'])
907 md['data'].update(data)
907 md['data'].update(data)
908 elif msg_type == 'status':
908 elif msg_type == 'status':
909 # idle message comes after all outputs
909 # idle message comes after all outputs
910 if content['execution_state'] == 'idle':
910 if content['execution_state'] == 'idle':
911 md['outputs_ready'] = True
911 md['outputs_ready'] = True
912 else:
912 else:
913 # unhandled msg_type (status, etc.)
913 # unhandled msg_type (status, etc.)
914 pass
914 pass
915
915
916 # reduntant?
916 # reduntant?
917 self.metadata[msg_id] = md
917 self.metadata[msg_id] = md
918
918
919 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
919 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
920
920
921 #--------------------------------------------------------------------------
921 #--------------------------------------------------------------------------
922 # len, getitem
922 # len, getitem
923 #--------------------------------------------------------------------------
923 #--------------------------------------------------------------------------
924
924
925 def __len__(self):
925 def __len__(self):
926 """len(client) returns # of engines."""
926 """len(client) returns # of engines."""
927 return len(self.ids)
927 return len(self.ids)
928
928
929 def __getitem__(self, key):
929 def __getitem__(self, key):
930 """index access returns DirectView multiplexer objects
930 """index access returns DirectView multiplexer objects
931
931
932 Must be int, slice, or list/tuple/xrange of ints"""
932 Must be int, slice, or list/tuple/xrange of ints"""
933 if not isinstance(key, (int, slice, tuple, list, xrange)):
933 if not isinstance(key, (int, slice, tuple, list, xrange)):
934 raise TypeError("key by int/slice/iterable of ints only, not %s"%(type(key)))
934 raise TypeError("key by int/slice/iterable of ints only, not %s"%(type(key)))
935 else:
935 else:
936 return self.direct_view(key)
936 return self.direct_view(key)
937
937
938 def __iter__(self):
938 def __iter__(self):
939 """Since we define getitem, Client is iterable
939 """Since we define getitem, Client is iterable
940
940
941 but unless we also define __iter__, it won't work correctly unless engine IDs
941 but unless we also define __iter__, it won't work correctly unless engine IDs
942 start at zero and are continuous.
942 start at zero and are continuous.
943 """
943 """
944 for eid in self.ids:
944 for eid in self.ids:
945 yield self.direct_view(eid)
945 yield self.direct_view(eid)
946
946
947 #--------------------------------------------------------------------------
947 #--------------------------------------------------------------------------
948 # Begin public methods
948 # Begin public methods
949 #--------------------------------------------------------------------------
949 #--------------------------------------------------------------------------
950
950
951 @property
951 @property
952 def ids(self):
952 def ids(self):
953 """Always up-to-date ids property."""
953 """Always up-to-date ids property."""
954 self._flush_notifications()
954 self._flush_notifications()
955 # always copy:
955 # always copy:
956 return list(self._ids)
956 return list(self._ids)
957
957
958 def activate(self, targets='all', suffix=''):
958 def activate(self, targets='all', suffix=''):
959 """Create a DirectView and register it with IPython magics
959 """Create a DirectView and register it with IPython magics
960
960
961 Defines the magics `%px, %autopx, %pxresult, %%px`
961 Defines the magics `%px, %autopx, %pxresult, %%px`
962
962
963 Parameters
963 Parameters
964 ----------
964 ----------
965
965
966 targets: int, list of ints, or 'all'
966 targets: int, list of ints, or 'all'
967 The engines on which the view's magics will run
967 The engines on which the view's magics will run
968 suffix: str [default: '']
968 suffix: str [default: '']
969 The suffix, if any, for the magics. This allows you to have
969 The suffix, if any, for the magics. This allows you to have
970 multiple views associated with parallel magics at the same time.
970 multiple views associated with parallel magics at the same time.
971
971
972 e.g. ``rc.activate(targets=0, suffix='0')`` will give you
972 e.g. ``rc.activate(targets=0, suffix='0')`` will give you
973 the magics ``%px0``, ``%pxresult0``, etc. for running magics just
973 the magics ``%px0``, ``%pxresult0``, etc. for running magics just
974 on engine 0.
974 on engine 0.
975 """
975 """
976 view = self.direct_view(targets)
976 view = self.direct_view(targets)
977 view.block = True
977 view.block = True
978 view.activate(suffix)
978 view.activate(suffix)
979 return view
979 return view
980
980
981 def close(self, linger=None):
981 def close(self, linger=None):
982 """Close my zmq Sockets
982 """Close my zmq Sockets
983
983
984 If `linger`, set the zmq LINGER socket option,
984 If `linger`, set the zmq LINGER socket option,
985 which allows discarding of messages.
985 which allows discarding of messages.
986 """
986 """
987 if self._closed:
987 if self._closed:
988 return
988 return
989 self.stop_spin_thread()
989 self.stop_spin_thread()
990 snames = [ trait for trait in self.trait_names() if trait.endswith("socket") ]
990 snames = [ trait for trait in self.trait_names() if trait.endswith("socket") ]
991 for name in snames:
991 for name in snames:
992 socket = getattr(self, name)
992 socket = getattr(self, name)
993 if socket is not None and not socket.closed:
993 if socket is not None and not socket.closed:
994 if linger is not None:
994 if linger is not None:
995 socket.close(linger=linger)
995 socket.close(linger=linger)
996 else:
996 else:
997 socket.close()
997 socket.close()
998 self._closed = True
998 self._closed = True
999
999
1000 def _spin_every(self, interval=1):
1000 def _spin_every(self, interval=1):
1001 """target func for use in spin_thread"""
1001 """target func for use in spin_thread"""
1002 while True:
1002 while True:
1003 if self._stop_spinning.is_set():
1003 if self._stop_spinning.is_set():
1004 return
1004 return
1005 time.sleep(interval)
1005 time.sleep(interval)
1006 self.spin()
1006 self.spin()
1007
1007
1008 def spin_thread(self, interval=1):
1008 def spin_thread(self, interval=1):
1009 """call Client.spin() in a background thread on some regular interval
1009 """call Client.spin() in a background thread on some regular interval
1010
1010
1011 This helps ensure that messages don't pile up too much in the zmq queue
1011 This helps ensure that messages don't pile up too much in the zmq queue
1012 while you are working on other things, or just leaving an idle terminal.
1012 while you are working on other things, or just leaving an idle terminal.
1013
1013
1014 It also helps limit potential padding of the `received` timestamp
1014 It also helps limit potential padding of the `received` timestamp
1015 on AsyncResult objects, used for timings.
1015 on AsyncResult objects, used for timings.
1016
1016
1017 Parameters
1017 Parameters
1018 ----------
1018 ----------
1019
1019
1020 interval : float, optional
1020 interval : float, optional
1021 The interval on which to spin the client in the background thread
1021 The interval on which to spin the client in the background thread
1022 (simply passed to time.sleep).
1022 (simply passed to time.sleep).
1023
1023
1024 Notes
1024 Notes
1025 -----
1025 -----
1026
1026
1027 For precision timing, you may want to use this method to put a bound
1027 For precision timing, you may want to use this method to put a bound
1028 on the jitter (in seconds) in `received` timestamps used
1028 on the jitter (in seconds) in `received` timestamps used
1029 in AsyncResult.wall_time.
1029 in AsyncResult.wall_time.
1030
1030
1031 """
1031 """
1032 if self._spin_thread is not None:
1032 if self._spin_thread is not None:
1033 self.stop_spin_thread()
1033 self.stop_spin_thread()
1034 self._stop_spinning.clear()
1034 self._stop_spinning.clear()
1035 self._spin_thread = Thread(target=self._spin_every, args=(interval,))
1035 self._spin_thread = Thread(target=self._spin_every, args=(interval,))
1036 self._spin_thread.daemon = True
1036 self._spin_thread.daemon = True
1037 self._spin_thread.start()
1037 self._spin_thread.start()
1038
1038
1039 def stop_spin_thread(self):
1039 def stop_spin_thread(self):
1040 """stop background spin_thread, if any"""
1040 """stop background spin_thread, if any"""
1041 if self._spin_thread is not None:
1041 if self._spin_thread is not None:
1042 self._stop_spinning.set()
1042 self._stop_spinning.set()
1043 self._spin_thread.join()
1043 self._spin_thread.join()
1044 self._spin_thread = None
1044 self._spin_thread = None
1045
1045
1046 def spin(self):
1046 def spin(self):
1047 """Flush any registration notifications and execution results
1047 """Flush any registration notifications and execution results
1048 waiting in the ZMQ queue.
1048 waiting in the ZMQ queue.
1049 """
1049 """
1050 if self._notification_socket:
1050 if self._notification_socket:
1051 self._flush_notifications()
1051 self._flush_notifications()
1052 if self._iopub_socket:
1052 if self._iopub_socket:
1053 self._flush_iopub(self._iopub_socket)
1053 self._flush_iopub(self._iopub_socket)
1054 if self._mux_socket:
1054 if self._mux_socket:
1055 self._flush_results(self._mux_socket)
1055 self._flush_results(self._mux_socket)
1056 if self._task_socket:
1056 if self._task_socket:
1057 self._flush_results(self._task_socket)
1057 self._flush_results(self._task_socket)
1058 if self._control_socket:
1058 if self._control_socket:
1059 self._flush_control(self._control_socket)
1059 self._flush_control(self._control_socket)
1060 if self._query_socket:
1060 if self._query_socket:
1061 self._flush_ignored_hub_replies()
1061 self._flush_ignored_hub_replies()
1062
1062
1063 def wait(self, jobs=None, timeout=-1):
1063 def wait(self, jobs=None, timeout=-1):
1064 """waits on one or more `jobs`, for up to `timeout` seconds.
1064 """waits on one or more `jobs`, for up to `timeout` seconds.
1065
1065
1066 Parameters
1066 Parameters
1067 ----------
1067 ----------
1068
1068
1069 jobs : int, str, or list of ints and/or strs, or one or more AsyncResult objects
1069 jobs : int, str, or list of ints and/or strs, or one or more AsyncResult objects
1070 ints are indices to self.history
1070 ints are indices to self.history
1071 strs are msg_ids
1071 strs are msg_ids
1072 default: wait on all outstanding messages
1072 default: wait on all outstanding messages
1073 timeout : float
1073 timeout : float
1074 a time in seconds, after which to give up.
1074 a time in seconds, after which to give up.
1075 default is -1, which means no timeout
1075 default is -1, which means no timeout
1076
1076
1077 Returns
1077 Returns
1078 -------
1078 -------
1079
1079
1080 True : when all msg_ids are done
1080 True : when all msg_ids are done
1081 False : timeout reached, some msg_ids still outstanding
1081 False : timeout reached, some msg_ids still outstanding
1082 """
1082 """
1083 tic = time.time()
1083 tic = time.time()
1084 if jobs is None:
1084 if jobs is None:
1085 theids = self.outstanding
1085 theids = self.outstanding
1086 else:
1086 else:
1087 if isinstance(jobs, string_types + (int, AsyncResult)):
1087 if isinstance(jobs, string_types + (int, AsyncResult)):
1088 jobs = [jobs]
1088 jobs = [jobs]
1089 theids = set()
1089 theids = set()
1090 for job in jobs:
1090 for job in jobs:
1091 if isinstance(job, int):
1091 if isinstance(job, int):
1092 # index access
1092 # index access
1093 job = self.history[job]
1093 job = self.history[job]
1094 elif isinstance(job, AsyncResult):
1094 elif isinstance(job, AsyncResult):
1095 theids.update(job.msg_ids)
1095 theids.update(job.msg_ids)
1096 continue
1096 continue
1097 theids.add(job)
1097 theids.add(job)
1098 if not theids.intersection(self.outstanding):
1098 if not theids.intersection(self.outstanding):
1099 return True
1099 return True
1100 self.spin()
1100 self.spin()
1101 while theids.intersection(self.outstanding):
1101 while theids.intersection(self.outstanding):
1102 if timeout >= 0 and ( time.time()-tic ) > timeout:
1102 if timeout >= 0 and ( time.time()-tic ) > timeout:
1103 break
1103 break
1104 time.sleep(1e-3)
1104 time.sleep(1e-3)
1105 self.spin()
1105 self.spin()
1106 return len(theids.intersection(self.outstanding)) == 0
1106 return len(theids.intersection(self.outstanding)) == 0
1107
1107
1108 #--------------------------------------------------------------------------
1108 #--------------------------------------------------------------------------
1109 # Control methods
1109 # Control methods
1110 #--------------------------------------------------------------------------
1110 #--------------------------------------------------------------------------
1111
1111
1112 @spin_first
1112 @spin_first
1113 def clear(self, targets=None, block=None):
1113 def clear(self, targets=None, block=None):
1114 """Clear the namespace in target(s)."""
1114 """Clear the namespace in target(s)."""
1115 block = self.block if block is None else block
1115 block = self.block if block is None else block
1116 targets = self._build_targets(targets)[0]
1116 targets = self._build_targets(targets)[0]
1117 for t in targets:
1117 for t in targets:
1118 self.session.send(self._control_socket, 'clear_request', content={}, ident=t)
1118 self.session.send(self._control_socket, 'clear_request', content={}, ident=t)
1119 error = False
1119 error = False
1120 if block:
1120 if block:
1121 self._flush_ignored_control()
1121 self._flush_ignored_control()
1122 for i in range(len(targets)):
1122 for i in range(len(targets)):
1123 idents,msg = self.session.recv(self._control_socket,0)
1123 idents,msg = self.session.recv(self._control_socket,0)
1124 if self.debug:
1124 if self.debug:
1125 pprint(msg)
1125 pprint(msg)
1126 if msg['content']['status'] != 'ok':
1126 if msg['content']['status'] != 'ok':
1127 error = self._unwrap_exception(msg['content'])
1127 error = self._unwrap_exception(msg['content'])
1128 else:
1128 else:
1129 self._ignored_control_replies += len(targets)
1129 self._ignored_control_replies += len(targets)
1130 if error:
1130 if error:
1131 raise error
1131 raise error
1132
1132
1133
1133
1134 @spin_first
1134 @spin_first
1135 def abort(self, jobs=None, targets=None, block=None):
1135 def abort(self, jobs=None, targets=None, block=None):
1136 """Abort specific jobs from the execution queues of target(s).
1136 """Abort specific jobs from the execution queues of target(s).
1137
1137
1138 This is a mechanism to prevent jobs that have already been submitted
1138 This is a mechanism to prevent jobs that have already been submitted
1139 from executing.
1139 from executing.
1140
1140
1141 Parameters
1141 Parameters
1142 ----------
1142 ----------
1143
1143
1144 jobs : msg_id, list of msg_ids, or AsyncResult
1144 jobs : msg_id, list of msg_ids, or AsyncResult
1145 The jobs to be aborted
1145 The jobs to be aborted
1146
1146
1147 If unspecified/None: abort all outstanding jobs.
1147 If unspecified/None: abort all outstanding jobs.
1148
1148
1149 """
1149 """
1150 block = self.block if block is None else block
1150 block = self.block if block is None else block
1151 jobs = jobs if jobs is not None else list(self.outstanding)
1151 jobs = jobs if jobs is not None else list(self.outstanding)
1152 targets = self._build_targets(targets)[0]
1152 targets = self._build_targets(targets)[0]
1153
1153
1154 msg_ids = []
1154 msg_ids = []
1155 if isinstance(jobs, string_types + (AsyncResult,)):
1155 if isinstance(jobs, string_types + (AsyncResult,)):
1156 jobs = [jobs]
1156 jobs = [jobs]
1157 bad_ids = [obj for obj in jobs if not isinstance(obj, string_types + (AsyncResult,))]
1157 bad_ids = [obj for obj in jobs if not isinstance(obj, string_types + (AsyncResult,))]
1158 if bad_ids:
1158 if bad_ids:
1159 raise TypeError("Invalid msg_id type %r, expected str or AsyncResult"%bad_ids[0])
1159 raise TypeError("Invalid msg_id type %r, expected str or AsyncResult"%bad_ids[0])
1160 for j in jobs:
1160 for j in jobs:
1161 if isinstance(j, AsyncResult):
1161 if isinstance(j, AsyncResult):
1162 msg_ids.extend(j.msg_ids)
1162 msg_ids.extend(j.msg_ids)
1163 else:
1163 else:
1164 msg_ids.append(j)
1164 msg_ids.append(j)
1165 content = dict(msg_ids=msg_ids)
1165 content = dict(msg_ids=msg_ids)
1166 for t in targets:
1166 for t in targets:
1167 self.session.send(self._control_socket, 'abort_request',
1167 self.session.send(self._control_socket, 'abort_request',
1168 content=content, ident=t)
1168 content=content, ident=t)
1169 error = False
1169 error = False
1170 if block:
1170 if block:
1171 self._flush_ignored_control()
1171 self._flush_ignored_control()
1172 for i in range(len(targets)):
1172 for i in range(len(targets)):
1173 idents,msg = self.session.recv(self._control_socket,0)
1173 idents,msg = self.session.recv(self._control_socket,0)
1174 if self.debug:
1174 if self.debug:
1175 pprint(msg)
1175 pprint(msg)
1176 if msg['content']['status'] != 'ok':
1176 if msg['content']['status'] != 'ok':
1177 error = self._unwrap_exception(msg['content'])
1177 error = self._unwrap_exception(msg['content'])
1178 else:
1178 else:
1179 self._ignored_control_replies += len(targets)
1179 self._ignored_control_replies += len(targets)
1180 if error:
1180 if error:
1181 raise error
1181 raise error
1182
1182
1183 @spin_first
1183 @spin_first
1184 def shutdown(self, targets='all', restart=False, hub=False, block=None):
1184 def shutdown(self, targets='all', restart=False, hub=False, block=None):
1185 """Terminates one or more engine processes, optionally including the hub.
1185 """Terminates one or more engine processes, optionally including the hub.
1186
1186
1187 Parameters
1187 Parameters
1188 ----------
1188 ----------
1189
1189
1190 targets: list of ints or 'all' [default: all]
1190 targets: list of ints or 'all' [default: all]
1191 Which engines to shutdown.
1191 Which engines to shutdown.
1192 hub: bool [default: False]
1192 hub: bool [default: False]
1193 Whether to include the Hub. hub=True implies targets='all'.
1193 Whether to include the Hub. hub=True implies targets='all'.
1194 block: bool [default: self.block]
1194 block: bool [default: self.block]
1195 Whether to wait for clean shutdown replies or not.
1195 Whether to wait for clean shutdown replies or not.
1196 restart: bool [default: False]
1196 restart: bool [default: False]
1197 NOT IMPLEMENTED
1197 NOT IMPLEMENTED
1198 whether to restart engines after shutting them down.
1198 whether to restart engines after shutting them down.
1199 """
1199 """
1200 from IPython.parallel.error import NoEnginesRegistered
1200 from IPython.parallel.error import NoEnginesRegistered
1201 if restart:
1201 if restart:
1202 raise NotImplementedError("Engine restart is not yet implemented")
1202 raise NotImplementedError("Engine restart is not yet implemented")
1203
1203
1204 block = self.block if block is None else block
1204 block = self.block if block is None else block
1205 if hub:
1205 if hub:
1206 targets = 'all'
1206 targets = 'all'
1207 try:
1207 try:
1208 targets = self._build_targets(targets)[0]
1208 targets = self._build_targets(targets)[0]
1209 except NoEnginesRegistered:
1209 except NoEnginesRegistered:
1210 targets = []
1210 targets = []
1211 for t in targets:
1211 for t in targets:
1212 self.session.send(self._control_socket, 'shutdown_request',
1212 self.session.send(self._control_socket, 'shutdown_request',
1213 content={'restart':restart},ident=t)
1213 content={'restart':restart},ident=t)
1214 error = False
1214 error = False
1215 if block or hub:
1215 if block or hub:
1216 self._flush_ignored_control()
1216 self._flush_ignored_control()
1217 for i in range(len(targets)):
1217 for i in range(len(targets)):
1218 idents,msg = self.session.recv(self._control_socket, 0)
1218 idents,msg = self.session.recv(self._control_socket, 0)
1219 if self.debug:
1219 if self.debug:
1220 pprint(msg)
1220 pprint(msg)
1221 if msg['content']['status'] != 'ok':
1221 if msg['content']['status'] != 'ok':
1222 error = self._unwrap_exception(msg['content'])
1222 error = self._unwrap_exception(msg['content'])
1223 else:
1223 else:
1224 self._ignored_control_replies += len(targets)
1224 self._ignored_control_replies += len(targets)
1225
1225
1226 if hub:
1226 if hub:
1227 time.sleep(0.25)
1227 time.sleep(0.25)
1228 self.session.send(self._query_socket, 'shutdown_request')
1228 self.session.send(self._query_socket, 'shutdown_request')
1229 idents,msg = self.session.recv(self._query_socket, 0)
1229 idents,msg = self.session.recv(self._query_socket, 0)
1230 if self.debug:
1230 if self.debug:
1231 pprint(msg)
1231 pprint(msg)
1232 if msg['content']['status'] != 'ok':
1232 if msg['content']['status'] != 'ok':
1233 error = self._unwrap_exception(msg['content'])
1233 error = self._unwrap_exception(msg['content'])
1234
1234
1235 if error:
1235 if error:
1236 raise error
1236 raise error
1237
1237
1238 #--------------------------------------------------------------------------
1238 #--------------------------------------------------------------------------
1239 # Execution related methods
1239 # Execution related methods
1240 #--------------------------------------------------------------------------
1240 #--------------------------------------------------------------------------
1241
1241
1242 def _maybe_raise(self, result):
1242 def _maybe_raise(self, result):
1243 """wrapper for maybe raising an exception if apply failed."""
1243 """wrapper for maybe raising an exception if apply failed."""
1244 if isinstance(result, error.RemoteError):
1244 if isinstance(result, error.RemoteError):
1245 raise result
1245 raise result
1246
1246
1247 return result
1247 return result
1248
1248
1249 def send_apply_request(self, socket, f, args=None, kwargs=None, metadata=None, track=False,
1249 def send_apply_request(self, socket, f, args=None, kwargs=None, metadata=None, track=False,
1250 ident=None):
1250 ident=None):
1251 """construct and send an apply message via a socket.
1251 """construct and send an apply message via a socket.
1252
1252
1253 This is the principal method with which all engine execution is performed by views.
1253 This is the principal method with which all engine execution is performed by views.
1254 """
1254 """
1255
1255
1256 if self._closed:
1256 if self._closed:
1257 raise RuntimeError("Client cannot be used after its sockets have been closed")
1257 raise RuntimeError("Client cannot be used after its sockets have been closed")
1258
1258
1259 # defaults:
1259 # defaults:
1260 args = args if args is not None else []
1260 args = args if args is not None else []
1261 kwargs = kwargs if kwargs is not None else {}
1261 kwargs = kwargs if kwargs is not None else {}
1262 metadata = metadata if metadata is not None else {}
1262 metadata = metadata if metadata is not None else {}
1263
1263
1264 # validate arguments
1264 # validate arguments
1265 if not callable(f) and not isinstance(f, Reference):
1265 if not callable(f) and not isinstance(f, Reference):
1266 raise TypeError("f must be callable, not %s"%type(f))
1266 raise TypeError("f must be callable, not %s"%type(f))
1267 if not isinstance(args, (tuple, list)):
1267 if not isinstance(args, (tuple, list)):
1268 raise TypeError("args must be tuple or list, not %s"%type(args))
1268 raise TypeError("args must be tuple or list, not %s"%type(args))
1269 if not isinstance(kwargs, dict):
1269 if not isinstance(kwargs, dict):
1270 raise TypeError("kwargs must be dict, not %s"%type(kwargs))
1270 raise TypeError("kwargs must be dict, not %s"%type(kwargs))
1271 if not isinstance(metadata, dict):
1271 if not isinstance(metadata, dict):
1272 raise TypeError("metadata must be dict, not %s"%type(metadata))
1272 raise TypeError("metadata must be dict, not %s"%type(metadata))
1273
1273
1274 bufs = serialize.pack_apply_message(f, args, kwargs,
1274 bufs = serialize.pack_apply_message(f, args, kwargs,
1275 buffer_threshold=self.session.buffer_threshold,
1275 buffer_threshold=self.session.buffer_threshold,
1276 item_threshold=self.session.item_threshold,
1276 item_threshold=self.session.item_threshold,
1277 )
1277 )
1278
1278
1279 msg = self.session.send(socket, "apply_request", buffers=bufs, ident=ident,
1279 msg = self.session.send(socket, "apply_request", buffers=bufs, ident=ident,
1280 metadata=metadata, track=track)
1280 metadata=metadata, track=track)
1281
1281
1282 msg_id = msg['header']['msg_id']
1282 msg_id = msg['header']['msg_id']
1283 self.outstanding.add(msg_id)
1283 self.outstanding.add(msg_id)
1284 if ident:
1284 if ident:
1285 # possibly routed to a specific engine
1285 # possibly routed to a specific engine
1286 if isinstance(ident, list):
1286 if isinstance(ident, list):
1287 ident = ident[-1]
1287 ident = ident[-1]
1288 if ident in self._engines.values():
1288 if ident in self._engines.values():
1289 # save for later, in case of engine death
1289 # save for later, in case of engine death
1290 self._outstanding_dict[ident].add(msg_id)
1290 self._outstanding_dict[ident].add(msg_id)
1291 self.history.append(msg_id)
1291 self.history.append(msg_id)
1292 self.metadata[msg_id]['submitted'] = datetime.now()
1292 self.metadata[msg_id]['submitted'] = datetime.now()
1293
1293
1294 return msg
1294 return msg
1295
1295
1296 def send_execute_request(self, socket, code, silent=True, metadata=None, ident=None):
1296 def send_execute_request(self, socket, code, silent=True, metadata=None, ident=None):
1297 """construct and send an execute request via a socket.
1297 """construct and send an execute request via a socket.
1298
1298
1299 """
1299 """
1300
1300
1301 if self._closed:
1301 if self._closed:
1302 raise RuntimeError("Client cannot be used after its sockets have been closed")
1302 raise RuntimeError("Client cannot be used after its sockets have been closed")
1303
1303
1304 # defaults:
1304 # defaults:
1305 metadata = metadata if metadata is not None else {}
1305 metadata = metadata if metadata is not None else {}
1306
1306
1307 # validate arguments
1307 # validate arguments
1308 if not isinstance(code, string_types):
1308 if not isinstance(code, string_types):
1309 raise TypeError("code must be text, not %s" % type(code))
1309 raise TypeError("code must be text, not %s" % type(code))
1310 if not isinstance(metadata, dict):
1310 if not isinstance(metadata, dict):
1311 raise TypeError("metadata must be dict, not %s" % type(metadata))
1311 raise TypeError("metadata must be dict, not %s" % type(metadata))
1312
1312
1313 content = dict(code=code, silent=bool(silent), user_expressions={})
1313 content = dict(code=code, silent=bool(silent), user_expressions={})
1314
1314
1315
1315
1316 msg = self.session.send(socket, "execute_request", content=content, ident=ident,
1316 msg = self.session.send(socket, "execute_request", content=content, ident=ident,
1317 metadata=metadata)
1317 metadata=metadata)
1318
1318
1319 msg_id = msg['header']['msg_id']
1319 msg_id = msg['header']['msg_id']
1320 self.outstanding.add(msg_id)
1320 self.outstanding.add(msg_id)
1321 if ident:
1321 if ident:
1322 # possibly routed to a specific engine
1322 # possibly routed to a specific engine
1323 if isinstance(ident, list):
1323 if isinstance(ident, list):
1324 ident = ident[-1]
1324 ident = ident[-1]
1325 if ident in self._engines.values():
1325 if ident in self._engines.values():
1326 # save for later, in case of engine death
1326 # save for later, in case of engine death
1327 self._outstanding_dict[ident].add(msg_id)
1327 self._outstanding_dict[ident].add(msg_id)
1328 self.history.append(msg_id)
1328 self.history.append(msg_id)
1329 self.metadata[msg_id]['submitted'] = datetime.now()
1329 self.metadata[msg_id]['submitted'] = datetime.now()
1330
1330
1331 return msg
1331 return msg
1332
1332
1333 #--------------------------------------------------------------------------
1333 #--------------------------------------------------------------------------
1334 # construct a View object
1334 # construct a View object
1335 #--------------------------------------------------------------------------
1335 #--------------------------------------------------------------------------
1336
1336
1337 def load_balanced_view(self, targets=None):
1337 def load_balanced_view(self, targets=None):
1338 """construct a DirectView object.
1338 """construct a DirectView object.
1339
1339
1340 If no arguments are specified, create a LoadBalancedView
1340 If no arguments are specified, create a LoadBalancedView
1341 using all engines.
1341 using all engines.
1342
1342
1343 Parameters
1343 Parameters
1344 ----------
1344 ----------
1345
1345
1346 targets: list,slice,int,etc. [default: use all engines]
1346 targets: list,slice,int,etc. [default: use all engines]
1347 The subset of engines across which to load-balance
1347 The subset of engines across which to load-balance
1348 """
1348 """
1349 if targets == 'all':
1349 if targets == 'all':
1350 targets = None
1350 targets = None
1351 if targets is not None:
1351 if targets is not None:
1352 targets = self._build_targets(targets)[1]
1352 targets = self._build_targets(targets)[1]
1353 return LoadBalancedView(client=self, socket=self._task_socket, targets=targets)
1353 return LoadBalancedView(client=self, socket=self._task_socket, targets=targets)
1354
1354
1355 def direct_view(self, targets='all'):
1355 def direct_view(self, targets='all'):
1356 """construct a DirectView object.
1356 """construct a DirectView object.
1357
1357
1358 If no targets are specified, create a DirectView using all engines.
1358 If no targets are specified, create a DirectView using all engines.
1359
1359
1360 rc.direct_view('all') is distinguished from rc[:] in that 'all' will
1360 rc.direct_view('all') is distinguished from rc[:] in that 'all' will
1361 evaluate the target engines at each execution, whereas rc[:] will connect to
1361 evaluate the target engines at each execution, whereas rc[:] will connect to
1362 all *current* engines, and that list will not change.
1362 all *current* engines, and that list will not change.
1363
1363
1364 That is, 'all' will always use all engines, whereas rc[:] will not use
1364 That is, 'all' will always use all engines, whereas rc[:] will not use
1365 engines added after the DirectView is constructed.
1365 engines added after the DirectView is constructed.
1366
1366
1367 Parameters
1367 Parameters
1368 ----------
1368 ----------
1369
1369
1370 targets: list,slice,int,etc. [default: use all engines]
1370 targets: list,slice,int,etc. [default: use all engines]
1371 The engines to use for the View
1371 The engines to use for the View
1372 """
1372 """
1373 single = isinstance(targets, int)
1373 single = isinstance(targets, int)
1374 # allow 'all' to be lazily evaluated at each execution
1374 # allow 'all' to be lazily evaluated at each execution
1375 if targets != 'all':
1375 if targets != 'all':
1376 targets = self._build_targets(targets)[1]
1376 targets = self._build_targets(targets)[1]
1377 if single:
1377 if single:
1378 targets = targets[0]
1378 targets = targets[0]
1379 return DirectView(client=self, socket=self._mux_socket, targets=targets)
1379 return DirectView(client=self, socket=self._mux_socket, targets=targets)
1380
1380
1381 #--------------------------------------------------------------------------
1381 #--------------------------------------------------------------------------
1382 # Query methods
1382 # Query methods
1383 #--------------------------------------------------------------------------
1383 #--------------------------------------------------------------------------
1384
1384
1385 @spin_first
1385 @spin_first
1386 def get_result(self, indices_or_msg_ids=None, block=None, owner=True):
1386 def get_result(self, indices_or_msg_ids=None, block=None, owner=True):
1387 """Retrieve a result by msg_id or history index, wrapped in an AsyncResult object.
1387 """Retrieve a result by msg_id or history index, wrapped in an AsyncResult object.
1388
1388
1389 If the client already has the results, no request to the Hub will be made.
1389 If the client already has the results, no request to the Hub will be made.
1390
1390
1391 This is a convenient way to construct AsyncResult objects, which are wrappers
1391 This is a convenient way to construct AsyncResult objects, which are wrappers
1392 that include metadata about execution, and allow for awaiting results that
1392 that include metadata about execution, and allow for awaiting results that
1393 were not submitted by this Client.
1393 were not submitted by this Client.
1394
1394
1395 It can also be a convenient way to retrieve the metadata associated with
1395 It can also be a convenient way to retrieve the metadata associated with
1396 blocking execution, since it always retrieves
1396 blocking execution, since it always retrieves
1397
1397
1398 Examples
1398 Examples
1399 --------
1399 --------
1400 ::
1400 ::
1401
1401
1402 In [10]: r = client.apply()
1402 In [10]: r = client.apply()
1403
1403
1404 Parameters
1404 Parameters
1405 ----------
1405 ----------
1406
1406
1407 indices_or_msg_ids : integer history index, str msg_id, or list of either
1407 indices_or_msg_ids : integer history index, str msg_id, or list of either
1408 The indices or msg_ids of indices to be retrieved
1408 The indices or msg_ids of indices to be retrieved
1409
1409
1410 block : bool
1410 block : bool
1411 Whether to wait for the result to be done
1411 Whether to wait for the result to be done
1412 owner : bool [default: True]
1412 owner : bool [default: True]
1413 Whether this AsyncResult should own the result.
1413 Whether this AsyncResult should own the result.
1414 If so, calling `ar.get()` will remove data from the
1414 If so, calling `ar.get()` will remove data from the
1415 client's result and metadata cache.
1415 client's result and metadata cache.
1416 There should only be one owner of any given msg_id.
1416 There should only be one owner of any given msg_id.
1417
1417
1418 Returns
1418 Returns
1419 -------
1419 -------
1420
1420
1421 AsyncResult
1421 AsyncResult
1422 A single AsyncResult object will always be returned.
1422 A single AsyncResult object will always be returned.
1423
1423
1424 AsyncHubResult
1424 AsyncHubResult
1425 A subclass of AsyncResult that retrieves results from the Hub
1425 A subclass of AsyncResult that retrieves results from the Hub
1426
1426
1427 """
1427 """
1428 block = self.block if block is None else block
1428 block = self.block if block is None else block
1429 if indices_or_msg_ids is None:
1429 if indices_or_msg_ids is None:
1430 indices_or_msg_ids = -1
1430 indices_or_msg_ids = -1
1431
1431
1432 single_result = False
1432 single_result = False
1433 if not isinstance(indices_or_msg_ids, (list,tuple)):
1433 if not isinstance(indices_or_msg_ids, (list,tuple)):
1434 indices_or_msg_ids = [indices_or_msg_ids]
1434 indices_or_msg_ids = [indices_or_msg_ids]
1435 single_result = True
1435 single_result = True
1436
1436
1437 theids = []
1437 theids = []
1438 for id in indices_or_msg_ids:
1438 for id in indices_or_msg_ids:
1439 if isinstance(id, int):
1439 if isinstance(id, int):
1440 id = self.history[id]
1440 id = self.history[id]
1441 if not isinstance(id, string_types):
1441 if not isinstance(id, string_types):
1442 raise TypeError("indices must be str or int, not %r"%id)
1442 raise TypeError("indices must be str or int, not %r"%id)
1443 theids.append(id)
1443 theids.append(id)
1444
1444
1445 local_ids = [msg_id for msg_id in theids if (msg_id in self.outstanding or msg_id in self.results)]
1445 local_ids = [msg_id for msg_id in theids if (msg_id in self.outstanding or msg_id in self.results)]
1446 remote_ids = [msg_id for msg_id in theids if msg_id not in local_ids]
1446 remote_ids = [msg_id for msg_id in theids if msg_id not in local_ids]
1447
1447
1448 # given single msg_id initially, get_result shot get the result itself,
1448 # given single msg_id initially, get_result shot get the result itself,
1449 # not a length-one list
1449 # not a length-one list
1450 if single_result:
1450 if single_result:
1451 theids = theids[0]
1451 theids = theids[0]
1452
1452
1453 if remote_ids:
1453 if remote_ids:
1454 ar = AsyncHubResult(self, msg_ids=theids, owner=owner)
1454 ar = AsyncHubResult(self, msg_ids=theids, owner=owner)
1455 else:
1455 else:
1456 ar = AsyncResult(self, msg_ids=theids, owner=owner)
1456 ar = AsyncResult(self, msg_ids=theids, owner=owner)
1457
1457
1458 if block:
1458 if block:
1459 ar.wait()
1459 ar.wait()
1460
1460
1461 return ar
1461 return ar
1462
1462
1463 @spin_first
1463 @spin_first
1464 def resubmit(self, indices_or_msg_ids=None, metadata=None, block=None):
1464 def resubmit(self, indices_or_msg_ids=None, metadata=None, block=None):
1465 """Resubmit one or more tasks.
1465 """Resubmit one or more tasks.
1466
1466
1467 in-flight tasks may not be resubmitted.
1467 in-flight tasks may not be resubmitted.
1468
1468
1469 Parameters
1469 Parameters
1470 ----------
1470 ----------
1471
1471
1472 indices_or_msg_ids : integer history index, str msg_id, or list of either
1472 indices_or_msg_ids : integer history index, str msg_id, or list of either
1473 The indices or msg_ids of indices to be retrieved
1473 The indices or msg_ids of indices to be retrieved
1474
1474
1475 block : bool
1475 block : bool
1476 Whether to wait for the result to be done
1476 Whether to wait for the result to be done
1477
1477
1478 Returns
1478 Returns
1479 -------
1479 -------
1480
1480
1481 AsyncHubResult
1481 AsyncHubResult
1482 A subclass of AsyncResult that retrieves results from the Hub
1482 A subclass of AsyncResult that retrieves results from the Hub
1483
1483
1484 """
1484 """
1485 block = self.block if block is None else block
1485 block = self.block if block is None else block
1486 if indices_or_msg_ids is None:
1486 if indices_or_msg_ids is None:
1487 indices_or_msg_ids = -1
1487 indices_or_msg_ids = -1
1488
1488
1489 if not isinstance(indices_or_msg_ids, (list,tuple)):
1489 if not isinstance(indices_or_msg_ids, (list,tuple)):
1490 indices_or_msg_ids = [indices_or_msg_ids]
1490 indices_or_msg_ids = [indices_or_msg_ids]
1491
1491
1492 theids = []
1492 theids = []
1493 for id in indices_or_msg_ids:
1493 for id in indices_or_msg_ids:
1494 if isinstance(id, int):
1494 if isinstance(id, int):
1495 id = self.history[id]
1495 id = self.history[id]
1496 if not isinstance(id, string_types):
1496 if not isinstance(id, string_types):
1497 raise TypeError("indices must be str or int, not %r"%id)
1497 raise TypeError("indices must be str or int, not %r"%id)
1498 theids.append(id)
1498 theids.append(id)
1499
1499
1500 content = dict(msg_ids = theids)
1500 content = dict(msg_ids = theids)
1501
1501
1502 self.session.send(self._query_socket, 'resubmit_request', content)
1502 self.session.send(self._query_socket, 'resubmit_request', content)
1503
1503
1504 zmq.select([self._query_socket], [], [])
1504 zmq.select([self._query_socket], [], [])
1505 idents,msg = self.session.recv(self._query_socket, zmq.NOBLOCK)
1505 idents,msg = self.session.recv(self._query_socket, zmq.NOBLOCK)
1506 if self.debug:
1506 if self.debug:
1507 pprint(msg)
1507 pprint(msg)
1508 content = msg['content']
1508 content = msg['content']
1509 if content['status'] != 'ok':
1509 if content['status'] != 'ok':
1510 raise self._unwrap_exception(content)
1510 raise self._unwrap_exception(content)
1511 mapping = content['resubmitted']
1511 mapping = content['resubmitted']
1512 new_ids = [ mapping[msg_id] for msg_id in theids ]
1512 new_ids = [ mapping[msg_id] for msg_id in theids ]
1513
1513
1514 ar = AsyncHubResult(self, msg_ids=new_ids)
1514 ar = AsyncHubResult(self, msg_ids=new_ids)
1515
1515
1516 if block:
1516 if block:
1517 ar.wait()
1517 ar.wait()
1518
1518
1519 return ar
1519 return ar
1520
1520
1521 @spin_first
1521 @spin_first
1522 def result_status(self, msg_ids, status_only=True):
1522 def result_status(self, msg_ids, status_only=True):
1523 """Check on the status of the result(s) of the apply request with `msg_ids`.
1523 """Check on the status of the result(s) of the apply request with `msg_ids`.
1524
1524
1525 If status_only is False, then the actual results will be retrieved, else
1525 If status_only is False, then the actual results will be retrieved, else
1526 only the status of the results will be checked.
1526 only the status of the results will be checked.
1527
1527
1528 Parameters
1528 Parameters
1529 ----------
1529 ----------
1530
1530
1531 msg_ids : list of msg_ids
1531 msg_ids : list of msg_ids
1532 if int:
1532 if int:
1533 Passed as index to self.history for convenience.
1533 Passed as index to self.history for convenience.
1534 status_only : bool (default: True)
1534 status_only : bool (default: True)
1535 if False:
1535 if False:
1536 Retrieve the actual results of completed tasks.
1536 Retrieve the actual results of completed tasks.
1537
1537
1538 Returns
1538 Returns
1539 -------
1539 -------
1540
1540
1541 results : dict
1541 results : dict
1542 There will always be the keys 'pending' and 'completed', which will
1542 There will always be the keys 'pending' and 'completed', which will
1543 be lists of msg_ids that are incomplete or complete. If `status_only`
1543 be lists of msg_ids that are incomplete or complete. If `status_only`
1544 is False, then completed results will be keyed by their `msg_id`.
1544 is False, then completed results will be keyed by their `msg_id`.
1545 """
1545 """
1546 if not isinstance(msg_ids, (list,tuple)):
1546 if not isinstance(msg_ids, (list,tuple)):
1547 msg_ids = [msg_ids]
1547 msg_ids = [msg_ids]
1548
1548
1549 theids = []
1549 theids = []
1550 for msg_id in msg_ids:
1550 for msg_id in msg_ids:
1551 if isinstance(msg_id, int):
1551 if isinstance(msg_id, int):
1552 msg_id = self.history[msg_id]
1552 msg_id = self.history[msg_id]
1553 if not isinstance(msg_id, string_types):
1553 if not isinstance(msg_id, string_types):
1554 raise TypeError("msg_ids must be str, not %r"%msg_id)
1554 raise TypeError("msg_ids must be str, not %r"%msg_id)
1555 theids.append(msg_id)
1555 theids.append(msg_id)
1556
1556
1557 completed = []
1557 completed = []
1558 local_results = {}
1558 local_results = {}
1559
1559
1560 # comment this block out to temporarily disable local shortcut:
1560 # comment this block out to temporarily disable local shortcut:
1561 for msg_id in theids:
1561 for msg_id in theids:
1562 if msg_id in self.results:
1562 if msg_id in self.results:
1563 completed.append(msg_id)
1563 completed.append(msg_id)
1564 local_results[msg_id] = self.results[msg_id]
1564 local_results[msg_id] = self.results[msg_id]
1565 theids.remove(msg_id)
1565 theids.remove(msg_id)
1566
1566
1567 if theids: # some not locally cached
1567 if theids: # some not locally cached
1568 content = dict(msg_ids=theids, status_only=status_only)
1568 content = dict(msg_ids=theids, status_only=status_only)
1569 msg = self.session.send(self._query_socket, "result_request", content=content)
1569 msg = self.session.send(self._query_socket, "result_request", content=content)
1570 zmq.select([self._query_socket], [], [])
1570 zmq.select([self._query_socket], [], [])
1571 idents,msg = self.session.recv(self._query_socket, zmq.NOBLOCK)
1571 idents,msg = self.session.recv(self._query_socket, zmq.NOBLOCK)
1572 if self.debug:
1572 if self.debug:
1573 pprint(msg)
1573 pprint(msg)
1574 content = msg['content']
1574 content = msg['content']
1575 if content['status'] != 'ok':
1575 if content['status'] != 'ok':
1576 raise self._unwrap_exception(content)
1576 raise self._unwrap_exception(content)
1577 buffers = msg['buffers']
1577 buffers = msg['buffers']
1578 else:
1578 else:
1579 content = dict(completed=[],pending=[])
1579 content = dict(completed=[],pending=[])
1580
1580
1581 content['completed'].extend(completed)
1581 content['completed'].extend(completed)
1582
1582
1583 if status_only:
1583 if status_only:
1584 return content
1584 return content
1585
1585
1586 failures = []
1586 failures = []
1587 # load cached results into result:
1587 # load cached results into result:
1588 content.update(local_results)
1588 content.update(local_results)
1589
1589
1590 # update cache with results:
1590 # update cache with results:
1591 for msg_id in sorted(theids):
1591 for msg_id in sorted(theids):
1592 if msg_id in content['completed']:
1592 if msg_id in content['completed']:
1593 rec = content[msg_id]
1593 rec = content[msg_id]
1594 parent = extract_dates(rec['header'])
1594 parent = extract_dates(rec['header'])
1595 header = extract_dates(rec['result_header'])
1595 header = extract_dates(rec['result_header'])
1596 rcontent = rec['result_content']
1596 rcontent = rec['result_content']
1597 iodict = rec['io']
1597 iodict = rec['io']
1598 if isinstance(rcontent, str):
1598 if isinstance(rcontent, str):
1599 rcontent = self.session.unpack(rcontent)
1599 rcontent = self.session.unpack(rcontent)
1600
1600
1601 md = self.metadata[msg_id]
1601 md = self.metadata[msg_id]
1602 md_msg = dict(
1602 md_msg = dict(
1603 content=rcontent,
1603 content=rcontent,
1604 parent_header=parent,
1604 parent_header=parent,
1605 header=header,
1605 header=header,
1606 metadata=rec['result_metadata'],
1606 metadata=rec['result_metadata'],
1607 )
1607 )
1608 md.update(self._extract_metadata(md_msg))
1608 md.update(self._extract_metadata(md_msg))
1609 if rec.get('received'):
1609 if rec.get('received'):
1610 md['received'] = parse_date(rec['received'])
1610 md['received'] = parse_date(rec['received'])
1611 md.update(iodict)
1611 md.update(iodict)
1612
1612
1613 if rcontent['status'] == 'ok':
1613 if rcontent['status'] == 'ok':
1614 if header['msg_type'] == 'apply_reply':
1614 if header['msg_type'] == 'apply_reply':
1615 res,buffers = serialize.deserialize_object(buffers)
1615 res,buffers = serialize.deserialize_object(buffers)
1616 elif header['msg_type'] == 'execute_reply':
1616 elif header['msg_type'] == 'execute_reply':
1617 res = ExecuteReply(msg_id, rcontent, md)
1617 res = ExecuteReply(msg_id, rcontent, md)
1618 else:
1618 else:
1619 raise KeyError("unhandled msg type: %r" % header['msg_type'])
1619 raise KeyError("unhandled msg type: %r" % header['msg_type'])
1620 else:
1620 else:
1621 res = self._unwrap_exception(rcontent)
1621 res = self._unwrap_exception(rcontent)
1622 failures.append(res)
1622 failures.append(res)
1623
1623
1624 self.results[msg_id] = res
1624 self.results[msg_id] = res
1625 content[msg_id] = res
1625 content[msg_id] = res
1626
1626
1627 if len(theids) == 1 and failures:
1627 if len(theids) == 1 and failures:
1628 raise failures[0]
1628 raise failures[0]
1629
1629
1630 error.collect_exceptions(failures, "result_status")
1630 error.collect_exceptions(failures, "result_status")
1631 return content
1631 return content
1632
1632
1633 @spin_first
1633 @spin_first
1634 def queue_status(self, targets='all', verbose=False):
1634 def queue_status(self, targets='all', verbose=False):
1635 """Fetch the status of engine queues.
1635 """Fetch the status of engine queues.
1636
1636
1637 Parameters
1637 Parameters
1638 ----------
1638 ----------
1639
1639
1640 targets : int/str/list of ints/strs
1640 targets : int/str/list of ints/strs
1641 the engines whose states are to be queried.
1641 the engines whose states are to be queried.
1642 default : all
1642 default : all
1643 verbose : bool
1643 verbose : bool
1644 Whether to return lengths only, or lists of ids for each element
1644 Whether to return lengths only, or lists of ids for each element
1645 """
1645 """
1646 if targets == 'all':
1646 if targets == 'all':
1647 # allow 'all' to be evaluated on the engine
1647 # allow 'all' to be evaluated on the engine
1648 engine_ids = None
1648 engine_ids = None
1649 else:
1649 else:
1650 engine_ids = self._build_targets(targets)[1]
1650 engine_ids = self._build_targets(targets)[1]
1651 content = dict(targets=engine_ids, verbose=verbose)
1651 content = dict(targets=engine_ids, verbose=verbose)
1652 self.session.send(self._query_socket, "queue_request", content=content)
1652 self.session.send(self._query_socket, "queue_request", content=content)
1653 idents,msg = self.session.recv(self._query_socket, 0)
1653 idents,msg = self.session.recv(self._query_socket, 0)
1654 if self.debug:
1654 if self.debug:
1655 pprint(msg)
1655 pprint(msg)
1656 content = msg['content']
1656 content = msg['content']
1657 status = content.pop('status')
1657 status = content.pop('status')
1658 if status != 'ok':
1658 if status != 'ok':
1659 raise self._unwrap_exception(content)
1659 raise self._unwrap_exception(content)
1660 content = rekey(content)
1660 content = rekey(content)
1661 if isinstance(targets, int):
1661 if isinstance(targets, int):
1662 return content[targets]
1662 return content[targets]
1663 else:
1663 else:
1664 return content
1664 return content
1665
1665
1666 def _build_msgids_from_target(self, targets=None):
1666 def _build_msgids_from_target(self, targets=None):
1667 """Build a list of msg_ids from the list of engine targets"""
1667 """Build a list of msg_ids from the list of engine targets"""
1668 if not targets: # needed as _build_targets otherwise uses all engines
1668 if not targets: # needed as _build_targets otherwise uses all engines
1669 return []
1669 return []
1670 target_ids = self._build_targets(targets)[0]
1670 target_ids = self._build_targets(targets)[0]
1671 return [md_id for md_id in self.metadata if self.metadata[md_id]["engine_uuid"] in target_ids]
1671 return [md_id for md_id in self.metadata if self.metadata[md_id]["engine_uuid"] in target_ids]
1672
1672
1673 def _build_msgids_from_jobs(self, jobs=None):
1673 def _build_msgids_from_jobs(self, jobs=None):
1674 """Build a list of msg_ids from "jobs" """
1674 """Build a list of msg_ids from "jobs" """
1675 if not jobs:
1675 if not jobs:
1676 return []
1676 return []
1677 msg_ids = []
1677 msg_ids = []
1678 if isinstance(jobs, string_types + (AsyncResult,)):
1678 if isinstance(jobs, string_types + (AsyncResult,)):
1679 jobs = [jobs]
1679 jobs = [jobs]
1680 bad_ids = [obj for obj in jobs if not isinstance(obj, string_types + (AsyncResult,))]
1680 bad_ids = [obj for obj in jobs if not isinstance(obj, string_types + (AsyncResult,))]
1681 if bad_ids:
1681 if bad_ids:
1682 raise TypeError("Invalid msg_id type %r, expected str or AsyncResult"%bad_ids[0])
1682 raise TypeError("Invalid msg_id type %r, expected str or AsyncResult"%bad_ids[0])
1683 for j in jobs:
1683 for j in jobs:
1684 if isinstance(j, AsyncResult):
1684 if isinstance(j, AsyncResult):
1685 msg_ids.extend(j.msg_ids)
1685 msg_ids.extend(j.msg_ids)
1686 else:
1686 else:
1687 msg_ids.append(j)
1687 msg_ids.append(j)
1688 return msg_ids
1688 return msg_ids
1689
1689
1690 def purge_local_results(self, jobs=[], targets=[]):
1690 def purge_local_results(self, jobs=[], targets=[]):
1691 """Clears the client caches of results and their metadata.
1691 """Clears the client caches of results and their metadata.
1692
1692
1693 Individual results can be purged by msg_id, or the entire
1693 Individual results can be purged by msg_id, or the entire
1694 history of specific targets can be purged.
1694 history of specific targets can be purged.
1695
1695
1696 Use `purge_local_results('all')` to scrub everything from the Clients's
1696 Use `purge_local_results('all')` to scrub everything from the Clients's
1697 results and metadata caches.
1697 results and metadata caches.
1698
1698
1699 After this call all `AsyncResults` are invalid and should be discarded.
1699 After this call all `AsyncResults` are invalid and should be discarded.
1700
1700
1701 If you must "reget" the results, you can still do so by using
1701 If you must "reget" the results, you can still do so by using
1702 `client.get_result(msg_id)` or `client.get_result(asyncresult)`. This will
1702 `client.get_result(msg_id)` or `client.get_result(asyncresult)`. This will
1703 redownload the results from the hub if they are still available
1703 redownload the results from the hub if they are still available
1704 (i.e `client.purge_hub_results(...)` has not been called.
1704 (i.e `client.purge_hub_results(...)` has not been called.
1705
1705
1706 Parameters
1706 Parameters
1707 ----------
1707 ----------
1708
1708
1709 jobs : str or list of str or AsyncResult objects
1709 jobs : str or list of str or AsyncResult objects
1710 the msg_ids whose results should be purged.
1710 the msg_ids whose results should be purged.
1711 targets : int/list of ints
1711 targets : int/list of ints
1712 The engines, by integer ID, whose entire result histories are to be purged.
1712 The engines, by integer ID, whose entire result histories are to be purged.
1713
1713
1714 Raises
1714 Raises
1715 ------
1715 ------
1716
1716
1717 RuntimeError : if any of the tasks to be purged are still outstanding.
1717 RuntimeError : if any of the tasks to be purged are still outstanding.
1718
1718
1719 """
1719 """
1720 if not targets and not jobs:
1720 if not targets and not jobs:
1721 raise ValueError("Must specify at least one of `targets` and `jobs`")
1721 raise ValueError("Must specify at least one of `targets` and `jobs`")
1722
1722
1723 if jobs == 'all':
1723 if jobs == 'all':
1724 if self.outstanding:
1724 if self.outstanding:
1725 raise RuntimeError("Can't purge outstanding tasks: %s" % self.outstanding)
1725 raise RuntimeError("Can't purge outstanding tasks: %s" % self.outstanding)
1726 self.results.clear()
1726 self.results.clear()
1727 self.metadata.clear()
1727 self.metadata.clear()
1728 else:
1728 else:
1729 msg_ids = set()
1729 msg_ids = set()
1730 msg_ids.update(self._build_msgids_from_target(targets))
1730 msg_ids.update(self._build_msgids_from_target(targets))
1731 msg_ids.update(self._build_msgids_from_jobs(jobs))
1731 msg_ids.update(self._build_msgids_from_jobs(jobs))
1732 still_outstanding = self.outstanding.intersection(msg_ids)
1732 still_outstanding = self.outstanding.intersection(msg_ids)
1733 if still_outstanding:
1733 if still_outstanding:
1734 raise RuntimeError("Can't purge outstanding tasks: %s" % still_outstanding)
1734 raise RuntimeError("Can't purge outstanding tasks: %s" % still_outstanding)
1735 for mid in msg_ids:
1735 for mid in msg_ids:
1736 self.results.pop(mid, None)
1736 self.results.pop(mid, None)
1737 self.metadata.pop(mid, None)
1737 self.metadata.pop(mid, None)
1738
1738
1739
1739
1740 @spin_first
1740 @spin_first
1741 def purge_hub_results(self, jobs=[], targets=[]):
1741 def purge_hub_results(self, jobs=[], targets=[]):
1742 """Tell the Hub to forget results.
1742 """Tell the Hub to forget results.
1743
1743
1744 Individual results can be purged by msg_id, or the entire
1744 Individual results can be purged by msg_id, or the entire
1745 history of specific targets can be purged.
1745 history of specific targets can be purged.
1746
1746
1747 Use `purge_results('all')` to scrub everything from the Hub's db.
1747 Use `purge_results('all')` to scrub everything from the Hub's db.
1748
1748
1749 Parameters
1749 Parameters
1750 ----------
1750 ----------
1751
1751
1752 jobs : str or list of str or AsyncResult objects
1752 jobs : str or list of str or AsyncResult objects
1753 the msg_ids whose results should be forgotten.
1753 the msg_ids whose results should be forgotten.
1754 targets : int/str/list of ints/strs
1754 targets : int/str/list of ints/strs
1755 The targets, by int_id, whose entire history is to be purged.
1755 The targets, by int_id, whose entire history is to be purged.
1756
1756
1757 default : None
1757 default : None
1758 """
1758 """
1759 if not targets and not jobs:
1759 if not targets and not jobs:
1760 raise ValueError("Must specify at least one of `targets` and `jobs`")
1760 raise ValueError("Must specify at least one of `targets` and `jobs`")
1761 if targets:
1761 if targets:
1762 targets = self._build_targets(targets)[1]
1762 targets = self._build_targets(targets)[1]
1763
1763
1764 # construct msg_ids from jobs
1764 # construct msg_ids from jobs
1765 if jobs == 'all':
1765 if jobs == 'all':
1766 msg_ids = jobs
1766 msg_ids = jobs
1767 else:
1767 else:
1768 msg_ids = self._build_msgids_from_jobs(jobs)
1768 msg_ids = self._build_msgids_from_jobs(jobs)
1769
1769
1770 content = dict(engine_ids=targets, msg_ids=msg_ids)
1770 content = dict(engine_ids=targets, msg_ids=msg_ids)
1771 self.session.send(self._query_socket, "purge_request", content=content)
1771 self.session.send(self._query_socket, "purge_request", content=content)
1772 idents, msg = self.session.recv(self._query_socket, 0)
1772 idents, msg = self.session.recv(self._query_socket, 0)
1773 if self.debug:
1773 if self.debug:
1774 pprint(msg)
1774 pprint(msg)
1775 content = msg['content']
1775 content = msg['content']
1776 if content['status'] != 'ok':
1776 if content['status'] != 'ok':
1777 raise self._unwrap_exception(content)
1777 raise self._unwrap_exception(content)
1778
1778
1779 def purge_results(self, jobs=[], targets=[]):
1779 def purge_results(self, jobs=[], targets=[]):
1780 """Clears the cached results from both the hub and the local client
1780 """Clears the cached results from both the hub and the local client
1781
1781
1782 Individual results can be purged by msg_id, or the entire
1782 Individual results can be purged by msg_id, or the entire
1783 history of specific targets can be purged.
1783 history of specific targets can be purged.
1784
1784
1785 Use `purge_results('all')` to scrub every cached result from both the Hub's and
1785 Use `purge_results('all')` to scrub every cached result from both the Hub's and
1786 the Client's db.
1786 the Client's db.
1787
1787
1788 Equivalent to calling both `purge_hub_results()` and `purge_client_results()` with
1788 Equivalent to calling both `purge_hub_results()` and `purge_client_results()` with
1789 the same arguments.
1789 the same arguments.
1790
1790
1791 Parameters
1791 Parameters
1792 ----------
1792 ----------
1793
1793
1794 jobs : str or list of str or AsyncResult objects
1794 jobs : str or list of str or AsyncResult objects
1795 the msg_ids whose results should be forgotten.
1795 the msg_ids whose results should be forgotten.
1796 targets : int/str/list of ints/strs
1796 targets : int/str/list of ints/strs
1797 The targets, by int_id, whose entire history is to be purged.
1797 The targets, by int_id, whose entire history is to be purged.
1798
1798
1799 default : None
1799 default : None
1800 """
1800 """
1801 self.purge_local_results(jobs=jobs, targets=targets)
1801 self.purge_local_results(jobs=jobs, targets=targets)
1802 self.purge_hub_results(jobs=jobs, targets=targets)
1802 self.purge_hub_results(jobs=jobs, targets=targets)
1803
1803
1804 def purge_everything(self):
1804 def purge_everything(self):
1805 """Clears all content from previous Tasks from both the hub and the local client
1805 """Clears all content from previous Tasks from both the hub and the local client
1806
1806
1807 In addition to calling `purge_results("all")` it also deletes the history and
1807 In addition to calling `purge_results("all")` it also deletes the history and
1808 other bookkeeping lists.
1808 other bookkeeping lists.
1809 """
1809 """
1810 self.purge_results("all")
1810 self.purge_results("all")
1811 self.history = []
1811 self.history = []
1812 self.session.digest_history.clear()
1812 self.session.digest_history.clear()
1813
1813
1814 @spin_first
1814 @spin_first
1815 def hub_history(self):
1815 def hub_history(self):
1816 """Get the Hub's history
1816 """Get the Hub's history
1817
1817
1818 Just like the Client, the Hub has a history, which is a list of msg_ids.
1818 Just like the Client, the Hub has a history, which is a list of msg_ids.
1819 This will contain the history of all clients, and, depending on configuration,
1819 This will contain the history of all clients, and, depending on configuration,
1820 may contain history across multiple cluster sessions.
1820 may contain history across multiple cluster sessions.
1821
1821
1822 Any msg_id returned here is a valid argument to `get_result`.
1822 Any msg_id returned here is a valid argument to `get_result`.
1823
1823
1824 Returns
1824 Returns
1825 -------
1825 -------
1826
1826
1827 msg_ids : list of strs
1827 msg_ids : list of strs
1828 list of all msg_ids, ordered by task submission time.
1828 list of all msg_ids, ordered by task submission time.
1829 """
1829 """
1830
1830
1831 self.session.send(self._query_socket, "history_request", content={})
1831 self.session.send(self._query_socket, "history_request", content={})
1832 idents, msg = self.session.recv(self._query_socket, 0)
1832 idents, msg = self.session.recv(self._query_socket, 0)
1833
1833
1834 if self.debug:
1834 if self.debug:
1835 pprint(msg)
1835 pprint(msg)
1836 content = msg['content']
1836 content = msg['content']
1837 if content['status'] != 'ok':
1837 if content['status'] != 'ok':
1838 raise self._unwrap_exception(content)
1838 raise self._unwrap_exception(content)
1839 else:
1839 else:
1840 return content['history']
1840 return content['history']
1841
1841
1842 @spin_first
1842 @spin_first
1843 def db_query(self, query, keys=None):
1843 def db_query(self, query, keys=None):
1844 """Query the Hub's TaskRecord database
1844 """Query the Hub's TaskRecord database
1845
1845
1846 This will return a list of task record dicts that match `query`
1846 This will return a list of task record dicts that match `query`
1847
1847
1848 Parameters
1848 Parameters
1849 ----------
1849 ----------
1850
1850
1851 query : mongodb query dict
1851 query : mongodb query dict
1852 The search dict. See mongodb query docs for details.
1852 The search dict. See mongodb query docs for details.
1853 keys : list of strs [optional]
1853 keys : list of strs [optional]
1854 The subset of keys to be returned. The default is to fetch everything but buffers.
1854 The subset of keys to be returned. The default is to fetch everything but buffers.
1855 'msg_id' will *always* be included.
1855 'msg_id' will *always* be included.
1856 """
1856 """
1857 if isinstance(keys, string_types):
1857 if isinstance(keys, string_types):
1858 keys = [keys]
1858 keys = [keys]
1859 content = dict(query=query, keys=keys)
1859 content = dict(query=query, keys=keys)
1860 self.session.send(self._query_socket, "db_request", content=content)
1860 self.session.send(self._query_socket, "db_request", content=content)
1861 idents, msg = self.session.recv(self._query_socket, 0)
1861 idents, msg = self.session.recv(self._query_socket, 0)
1862 if self.debug:
1862 if self.debug:
1863 pprint(msg)
1863 pprint(msg)
1864 content = msg['content']
1864 content = msg['content']
1865 if content['status'] != 'ok':
1865 if content['status'] != 'ok':
1866 raise self._unwrap_exception(content)
1866 raise self._unwrap_exception(content)
1867
1867
1868 records = content['records']
1868 records = content['records']
1869
1869
1870 buffer_lens = content['buffer_lens']
1870 buffer_lens = content['buffer_lens']
1871 result_buffer_lens = content['result_buffer_lens']
1871 result_buffer_lens = content['result_buffer_lens']
1872 buffers = msg['buffers']
1872 buffers = msg['buffers']
1873 has_bufs = buffer_lens is not None
1873 has_bufs = buffer_lens is not None
1874 has_rbufs = result_buffer_lens is not None
1874 has_rbufs = result_buffer_lens is not None
1875 for i,rec in enumerate(records):
1875 for i,rec in enumerate(records):
1876 # unpack datetime objects
1876 # unpack datetime objects
1877 for hkey in ('header', 'result_header'):
1877 for hkey in ('header', 'result_header'):
1878 if hkey in rec:
1878 if hkey in rec:
1879 rec[hkey] = extract_dates(rec[hkey])
1879 rec[hkey] = extract_dates(rec[hkey])
1880 for dtkey in ('submitted', 'started', 'completed', 'received'):
1880 for dtkey in ('submitted', 'started', 'completed', 'received'):
1881 if dtkey in rec:
1881 if dtkey in rec:
1882 rec[dtkey] = parse_date(rec[dtkey])
1882 rec[dtkey] = parse_date(rec[dtkey])
1883 # relink buffers
1883 # relink buffers
1884 if has_bufs:
1884 if has_bufs:
1885 blen = buffer_lens[i]
1885 blen = buffer_lens[i]
1886 rec['buffers'], buffers = buffers[:blen],buffers[blen:]
1886 rec['buffers'], buffers = buffers[:blen],buffers[blen:]
1887 if has_rbufs:
1887 if has_rbufs:
1888 blen = result_buffer_lens[i]
1888 blen = result_buffer_lens[i]
1889 rec['result_buffers'], buffers = buffers[:blen],buffers[blen:]
1889 rec['result_buffers'], buffers = buffers[:blen],buffers[blen:]
1890
1890
1891 return records
1891 return records
1892
1892
1893 __all__ = [ 'Client' ]
1893 __all__ = [ 'Client' ]
@@ -1,276 +1,276 b''
1 """Remote Functions and decorators for Views."""
1 """Remote Functions and decorators for Views."""
2
2
3 # Copyright (c) IPython Development Team.
3 # Copyright (c) IPython Development Team.
4 # Distributed under the terms of the Modified BSD License.
4 # Distributed under the terms of the Modified BSD License.
5
5
6 from __future__ import division
6 from __future__ import division
7
7
8 import sys
8 import sys
9 import warnings
9 import warnings
10
10
11 from IPython.external.decorator import decorator
11 from decorator import decorator
12 from IPython.testing.skipdoctest import skip_doctest
12 from IPython.testing.skipdoctest import skip_doctest
13
13
14 from . import map as Map
14 from . import map as Map
15 from .asyncresult import AsyncMapResult
15 from .asyncresult import AsyncMapResult
16
16
17 #-----------------------------------------------------------------------------
17 #-----------------------------------------------------------------------------
18 # Functions and Decorators
18 # Functions and Decorators
19 #-----------------------------------------------------------------------------
19 #-----------------------------------------------------------------------------
20
20
21 @skip_doctest
21 @skip_doctest
22 def remote(view, block=None, **flags):
22 def remote(view, block=None, **flags):
23 """Turn a function into a remote function.
23 """Turn a function into a remote function.
24
24
25 This method can be used for map:
25 This method can be used for map:
26
26
27 In [1]: @remote(view,block=True)
27 In [1]: @remote(view,block=True)
28 ...: def func(a):
28 ...: def func(a):
29 ...: pass
29 ...: pass
30 """
30 """
31
31
32 def remote_function(f):
32 def remote_function(f):
33 return RemoteFunction(view, f, block=block, **flags)
33 return RemoteFunction(view, f, block=block, **flags)
34 return remote_function
34 return remote_function
35
35
36 @skip_doctest
36 @skip_doctest
37 def parallel(view, dist='b', block=None, ordered=True, **flags):
37 def parallel(view, dist='b', block=None, ordered=True, **flags):
38 """Turn a function into a parallel remote function.
38 """Turn a function into a parallel remote function.
39
39
40 This method can be used for map:
40 This method can be used for map:
41
41
42 In [1]: @parallel(view, block=True)
42 In [1]: @parallel(view, block=True)
43 ...: def func(a):
43 ...: def func(a):
44 ...: pass
44 ...: pass
45 """
45 """
46
46
47 def parallel_function(f):
47 def parallel_function(f):
48 return ParallelFunction(view, f, dist=dist, block=block, ordered=ordered, **flags)
48 return ParallelFunction(view, f, dist=dist, block=block, ordered=ordered, **flags)
49 return parallel_function
49 return parallel_function
50
50
51 def getname(f):
51 def getname(f):
52 """Get the name of an object.
52 """Get the name of an object.
53
53
54 For use in case of callables that are not functions, and
54 For use in case of callables that are not functions, and
55 thus may not have __name__ defined.
55 thus may not have __name__ defined.
56
56
57 Order: f.__name__ > f.name > str(f)
57 Order: f.__name__ > f.name > str(f)
58 """
58 """
59 try:
59 try:
60 return f.__name__
60 return f.__name__
61 except:
61 except:
62 pass
62 pass
63 try:
63 try:
64 return f.name
64 return f.name
65 except:
65 except:
66 pass
66 pass
67
67
68 return str(f)
68 return str(f)
69
69
70 @decorator
70 @decorator
71 def sync_view_results(f, self, *args, **kwargs):
71 def sync_view_results(f, self, *args, **kwargs):
72 """sync relevant results from self.client to our results attribute.
72 """sync relevant results from self.client to our results attribute.
73
73
74 This is a clone of view.sync_results, but for remote functions
74 This is a clone of view.sync_results, but for remote functions
75 """
75 """
76 view = self.view
76 view = self.view
77 if view._in_sync_results:
77 if view._in_sync_results:
78 return f(self, *args, **kwargs)
78 return f(self, *args, **kwargs)
79 view._in_sync_results = True
79 view._in_sync_results = True
80 try:
80 try:
81 ret = f(self, *args, **kwargs)
81 ret = f(self, *args, **kwargs)
82 finally:
82 finally:
83 view._in_sync_results = False
83 view._in_sync_results = False
84 view._sync_results()
84 view._sync_results()
85 return ret
85 return ret
86
86
87 #--------------------------------------------------------------------------
87 #--------------------------------------------------------------------------
88 # Classes
88 # Classes
89 #--------------------------------------------------------------------------
89 #--------------------------------------------------------------------------
90
90
91 class RemoteFunction(object):
91 class RemoteFunction(object):
92 """Turn an existing function into a remote function.
92 """Turn an existing function into a remote function.
93
93
94 Parameters
94 Parameters
95 ----------
95 ----------
96
96
97 view : View instance
97 view : View instance
98 The view to be used for execution
98 The view to be used for execution
99 f : callable
99 f : callable
100 The function to be wrapped into a remote function
100 The function to be wrapped into a remote function
101 block : bool [default: None]
101 block : bool [default: None]
102 Whether to wait for results or not. The default behavior is
102 Whether to wait for results or not. The default behavior is
103 to use the current `block` attribute of `view`
103 to use the current `block` attribute of `view`
104
104
105 **flags : remaining kwargs are passed to View.temp_flags
105 **flags : remaining kwargs are passed to View.temp_flags
106 """
106 """
107
107
108 view = None # the remote connection
108 view = None # the remote connection
109 func = None # the wrapped function
109 func = None # the wrapped function
110 block = None # whether to block
110 block = None # whether to block
111 flags = None # dict of extra kwargs for temp_flags
111 flags = None # dict of extra kwargs for temp_flags
112
112
113 def __init__(self, view, f, block=None, **flags):
113 def __init__(self, view, f, block=None, **flags):
114 self.view = view
114 self.view = view
115 self.func = f
115 self.func = f
116 self.block=block
116 self.block=block
117 self.flags=flags
117 self.flags=flags
118
118
119 def __call__(self, *args, **kwargs):
119 def __call__(self, *args, **kwargs):
120 block = self.view.block if self.block is None else self.block
120 block = self.view.block if self.block is None else self.block
121 with self.view.temp_flags(block=block, **self.flags):
121 with self.view.temp_flags(block=block, **self.flags):
122 return self.view.apply(self.func, *args, **kwargs)
122 return self.view.apply(self.func, *args, **kwargs)
123
123
124
124
125 class ParallelFunction(RemoteFunction):
125 class ParallelFunction(RemoteFunction):
126 """Class for mapping a function to sequences.
126 """Class for mapping a function to sequences.
127
127
128 This will distribute the sequences according the a mapper, and call
128 This will distribute the sequences according the a mapper, and call
129 the function on each sub-sequence. If called via map, then the function
129 the function on each sub-sequence. If called via map, then the function
130 will be called once on each element, rather that each sub-sequence.
130 will be called once on each element, rather that each sub-sequence.
131
131
132 Parameters
132 Parameters
133 ----------
133 ----------
134
134
135 view : View instance
135 view : View instance
136 The view to be used for execution
136 The view to be used for execution
137 f : callable
137 f : callable
138 The function to be wrapped into a remote function
138 The function to be wrapped into a remote function
139 dist : str [default: 'b']
139 dist : str [default: 'b']
140 The key for which mapObject to use to distribute sequences
140 The key for which mapObject to use to distribute sequences
141 options are:
141 options are:
142
142
143 * 'b' : use contiguous chunks in order
143 * 'b' : use contiguous chunks in order
144 * 'r' : use round-robin striping
144 * 'r' : use round-robin striping
145
145
146 block : bool [default: None]
146 block : bool [default: None]
147 Whether to wait for results or not. The default behavior is
147 Whether to wait for results or not. The default behavior is
148 to use the current `block` attribute of `view`
148 to use the current `block` attribute of `view`
149 chunksize : int or None
149 chunksize : int or None
150 The size of chunk to use when breaking up sequences in a load-balanced manner
150 The size of chunk to use when breaking up sequences in a load-balanced manner
151 ordered : bool [default: True]
151 ordered : bool [default: True]
152 Whether the result should be kept in order. If False,
152 Whether the result should be kept in order. If False,
153 results become available as they arrive, regardless of submission order.
153 results become available as they arrive, regardless of submission order.
154 **flags
154 **flags
155 remaining kwargs are passed to View.temp_flags
155 remaining kwargs are passed to View.temp_flags
156 """
156 """
157
157
158 chunksize = None
158 chunksize = None
159 ordered = None
159 ordered = None
160 mapObject = None
160 mapObject = None
161 _mapping = False
161 _mapping = False
162
162
163 def __init__(self, view, f, dist='b', block=None, chunksize=None, ordered=True, **flags):
163 def __init__(self, view, f, dist='b', block=None, chunksize=None, ordered=True, **flags):
164 super(ParallelFunction, self).__init__(view, f, block=block, **flags)
164 super(ParallelFunction, self).__init__(view, f, block=block, **flags)
165 self.chunksize = chunksize
165 self.chunksize = chunksize
166 self.ordered = ordered
166 self.ordered = ordered
167
167
168 mapClass = Map.dists[dist]
168 mapClass = Map.dists[dist]
169 self.mapObject = mapClass()
169 self.mapObject = mapClass()
170
170
171 @sync_view_results
171 @sync_view_results
172 def __call__(self, *sequences):
172 def __call__(self, *sequences):
173 client = self.view.client
173 client = self.view.client
174
174
175 lens = []
175 lens = []
176 maxlen = minlen = -1
176 maxlen = minlen = -1
177 for i, seq in enumerate(sequences):
177 for i, seq in enumerate(sequences):
178 try:
178 try:
179 n = len(seq)
179 n = len(seq)
180 except Exception:
180 except Exception:
181 seq = list(seq)
181 seq = list(seq)
182 if isinstance(sequences, tuple):
182 if isinstance(sequences, tuple):
183 # can't alter a tuple
183 # can't alter a tuple
184 sequences = list(sequences)
184 sequences = list(sequences)
185 sequences[i] = seq
185 sequences[i] = seq
186 n = len(seq)
186 n = len(seq)
187 if n > maxlen:
187 if n > maxlen:
188 maxlen = n
188 maxlen = n
189 if minlen == -1 or n < minlen:
189 if minlen == -1 or n < minlen:
190 minlen = n
190 minlen = n
191 lens.append(n)
191 lens.append(n)
192
192
193 if maxlen == 0:
193 if maxlen == 0:
194 # nothing to iterate over
194 # nothing to iterate over
195 return []
195 return []
196
196
197 # check that the length of sequences match
197 # check that the length of sequences match
198 if not self._mapping and minlen != maxlen:
198 if not self._mapping and minlen != maxlen:
199 msg = 'all sequences must have equal length, but have %s' % lens
199 msg = 'all sequences must have equal length, but have %s' % lens
200 raise ValueError(msg)
200 raise ValueError(msg)
201
201
202 balanced = 'Balanced' in self.view.__class__.__name__
202 balanced = 'Balanced' in self.view.__class__.__name__
203 if balanced:
203 if balanced:
204 if self.chunksize:
204 if self.chunksize:
205 nparts = maxlen // self.chunksize + int(maxlen % self.chunksize > 0)
205 nparts = maxlen // self.chunksize + int(maxlen % self.chunksize > 0)
206 else:
206 else:
207 nparts = maxlen
207 nparts = maxlen
208 targets = [None]*nparts
208 targets = [None]*nparts
209 else:
209 else:
210 if self.chunksize:
210 if self.chunksize:
211 warnings.warn("`chunksize` is ignored unless load balancing", UserWarning)
211 warnings.warn("`chunksize` is ignored unless load balancing", UserWarning)
212 # multiplexed:
212 # multiplexed:
213 targets = self.view.targets
213 targets = self.view.targets
214 # 'all' is lazily evaluated at execution time, which is now:
214 # 'all' is lazily evaluated at execution time, which is now:
215 if targets == 'all':
215 if targets == 'all':
216 targets = client._build_targets(targets)[1]
216 targets = client._build_targets(targets)[1]
217 elif isinstance(targets, int):
217 elif isinstance(targets, int):
218 # single-engine view, targets must be iterable
218 # single-engine view, targets must be iterable
219 targets = [targets]
219 targets = [targets]
220 nparts = len(targets)
220 nparts = len(targets)
221
221
222 msg_ids = []
222 msg_ids = []
223 for index, t in enumerate(targets):
223 for index, t in enumerate(targets):
224 args = []
224 args = []
225 for seq in sequences:
225 for seq in sequences:
226 part = self.mapObject.getPartition(seq, index, nparts, maxlen)
226 part = self.mapObject.getPartition(seq, index, nparts, maxlen)
227 args.append(part)
227 args.append(part)
228
228
229 if sum([len(arg) for arg in args]) == 0:
229 if sum([len(arg) for arg in args]) == 0:
230 continue
230 continue
231
231
232 if self._mapping:
232 if self._mapping:
233 if sys.version_info[0] >= 3:
233 if sys.version_info[0] >= 3:
234 f = lambda f, *sequences: list(map(f, *sequences))
234 f = lambda f, *sequences: list(map(f, *sequences))
235 else:
235 else:
236 f = map
236 f = map
237 args = [self.func] + args
237 args = [self.func] + args
238 else:
238 else:
239 f=self.func
239 f=self.func
240
240
241 view = self.view if balanced else client[t]
241 view = self.view if balanced else client[t]
242 with view.temp_flags(block=False, **self.flags):
242 with view.temp_flags(block=False, **self.flags):
243 ar = view.apply(f, *args)
243 ar = view.apply(f, *args)
244
244
245 msg_ids.extend(ar.msg_ids)
245 msg_ids.extend(ar.msg_ids)
246
246
247 r = AsyncMapResult(self.view.client, msg_ids, self.mapObject,
247 r = AsyncMapResult(self.view.client, msg_ids, self.mapObject,
248 fname=getname(self.func),
248 fname=getname(self.func),
249 ordered=self.ordered
249 ordered=self.ordered
250 )
250 )
251
251
252 if self.block:
252 if self.block:
253 try:
253 try:
254 return r.get()
254 return r.get()
255 except KeyboardInterrupt:
255 except KeyboardInterrupt:
256 return r
256 return r
257 else:
257 else:
258 return r
258 return r
259
259
260 def map(self, *sequences):
260 def map(self, *sequences):
261 """call a function on each element of one or more sequence(s) remotely.
261 """call a function on each element of one or more sequence(s) remotely.
262 This should behave very much like the builtin map, but return an AsyncMapResult
262 This should behave very much like the builtin map, but return an AsyncMapResult
263 if self.block is False.
263 if self.block is False.
264
264
265 That means it can take generators (will be cast to lists locally),
265 That means it can take generators (will be cast to lists locally),
266 and mismatched sequence lengths will be padded with None.
266 and mismatched sequence lengths will be padded with None.
267 """
267 """
268 # set _mapping as a flag for use inside self.__call__
268 # set _mapping as a flag for use inside self.__call__
269 self._mapping = True
269 self._mapping = True
270 try:
270 try:
271 ret = self(*sequences)
271 ret = self(*sequences)
272 finally:
272 finally:
273 self._mapping = False
273 self._mapping = False
274 return ret
274 return ret
275
275
276 __all__ = ['remote', 'parallel', 'RemoteFunction', 'ParallelFunction']
276 __all__ = ['remote', 'parallel', 'RemoteFunction', 'ParallelFunction']
@@ -1,1125 +1,1125 b''
1 """Views of remote engines."""
1 """Views of remote engines."""
2
2
3 # Copyright (c) IPython Development Team.
3 # Copyright (c) IPython Development Team.
4 # Distributed under the terms of the Modified BSD License.
4 # Distributed under the terms of the Modified BSD License.
5
5
6 from __future__ import print_function
6 from __future__ import print_function
7
7
8 import imp
8 import imp
9 import sys
9 import sys
10 import warnings
10 import warnings
11 from contextlib import contextmanager
11 from contextlib import contextmanager
12 from types import ModuleType
12 from types import ModuleType
13
13
14 import zmq
14 import zmq
15
15
16 from IPython.testing.skipdoctest import skip_doctest
16 from IPython.testing.skipdoctest import skip_doctest
17 from IPython.utils import pickleutil
17 from IPython.utils import pickleutil
18 from IPython.utils.traitlets import (
18 from IPython.utils.traitlets import (
19 HasTraits, Any, Bool, List, Dict, Set, Instance, CFloat, Integer
19 HasTraits, Any, Bool, List, Dict, Set, Instance, CFloat, Integer
20 )
20 )
21 from IPython.external.decorator import decorator
21 from decorator import decorator
22
22
23 from IPython.parallel import util
23 from IPython.parallel import util
24 from IPython.parallel.controller.dependency import Dependency, dependent
24 from IPython.parallel.controller.dependency import Dependency, dependent
25 from IPython.utils.py3compat import string_types, iteritems, PY3
25 from IPython.utils.py3compat import string_types, iteritems, PY3
26
26
27 from . import map as Map
27 from . import map as Map
28 from .asyncresult import AsyncResult, AsyncMapResult
28 from .asyncresult import AsyncResult, AsyncMapResult
29 from .remotefunction import ParallelFunction, parallel, remote, getname
29 from .remotefunction import ParallelFunction, parallel, remote, getname
30
30
31 #-----------------------------------------------------------------------------
31 #-----------------------------------------------------------------------------
32 # Decorators
32 # Decorators
33 #-----------------------------------------------------------------------------
33 #-----------------------------------------------------------------------------
34
34
35 @decorator
35 @decorator
36 def save_ids(f, self, *args, **kwargs):
36 def save_ids(f, self, *args, **kwargs):
37 """Keep our history and outstanding attributes up to date after a method call."""
37 """Keep our history and outstanding attributes up to date after a method call."""
38 n_previous = len(self.client.history)
38 n_previous = len(self.client.history)
39 try:
39 try:
40 ret = f(self, *args, **kwargs)
40 ret = f(self, *args, **kwargs)
41 finally:
41 finally:
42 nmsgs = len(self.client.history) - n_previous
42 nmsgs = len(self.client.history) - n_previous
43 msg_ids = self.client.history[-nmsgs:]
43 msg_ids = self.client.history[-nmsgs:]
44 self.history.extend(msg_ids)
44 self.history.extend(msg_ids)
45 self.outstanding.update(msg_ids)
45 self.outstanding.update(msg_ids)
46 return ret
46 return ret
47
47
48 @decorator
48 @decorator
49 def sync_results(f, self, *args, **kwargs):
49 def sync_results(f, self, *args, **kwargs):
50 """sync relevant results from self.client to our results attribute."""
50 """sync relevant results from self.client to our results attribute."""
51 if self._in_sync_results:
51 if self._in_sync_results:
52 return f(self, *args, **kwargs)
52 return f(self, *args, **kwargs)
53 self._in_sync_results = True
53 self._in_sync_results = True
54 try:
54 try:
55 ret = f(self, *args, **kwargs)
55 ret = f(self, *args, **kwargs)
56 finally:
56 finally:
57 self._in_sync_results = False
57 self._in_sync_results = False
58 self._sync_results()
58 self._sync_results()
59 return ret
59 return ret
60
60
61 @decorator
61 @decorator
62 def spin_after(f, self, *args, **kwargs):
62 def spin_after(f, self, *args, **kwargs):
63 """call spin after the method."""
63 """call spin after the method."""
64 ret = f(self, *args, **kwargs)
64 ret = f(self, *args, **kwargs)
65 self.spin()
65 self.spin()
66 return ret
66 return ret
67
67
68 #-----------------------------------------------------------------------------
68 #-----------------------------------------------------------------------------
69 # Classes
69 # Classes
70 #-----------------------------------------------------------------------------
70 #-----------------------------------------------------------------------------
71
71
72 @skip_doctest
72 @skip_doctest
73 class View(HasTraits):
73 class View(HasTraits):
74 """Base View class for more convenint apply(f,*args,**kwargs) syntax via attributes.
74 """Base View class for more convenint apply(f,*args,**kwargs) syntax via attributes.
75
75
76 Don't use this class, use subclasses.
76 Don't use this class, use subclasses.
77
77
78 Methods
78 Methods
79 -------
79 -------
80
80
81 spin
81 spin
82 flushes incoming results and registration state changes
82 flushes incoming results and registration state changes
83 control methods spin, and requesting `ids` also ensures up to date
83 control methods spin, and requesting `ids` also ensures up to date
84
84
85 wait
85 wait
86 wait on one or more msg_ids
86 wait on one or more msg_ids
87
87
88 execution methods
88 execution methods
89 apply
89 apply
90 legacy: execute, run
90 legacy: execute, run
91
91
92 data movement
92 data movement
93 push, pull, scatter, gather
93 push, pull, scatter, gather
94
94
95 query methods
95 query methods
96 get_result, queue_status, purge_results, result_status
96 get_result, queue_status, purge_results, result_status
97
97
98 control methods
98 control methods
99 abort, shutdown
99 abort, shutdown
100
100
101 """
101 """
102 # flags
102 # flags
103 block=Bool(False)
103 block=Bool(False)
104 track=Bool(True)
104 track=Bool(True)
105 targets = Any()
105 targets = Any()
106
106
107 history=List()
107 history=List()
108 outstanding = Set()
108 outstanding = Set()
109 results = Dict()
109 results = Dict()
110 client = Instance('IPython.parallel.Client')
110 client = Instance('IPython.parallel.Client')
111
111
112 _socket = Instance('zmq.Socket')
112 _socket = Instance('zmq.Socket')
113 _flag_names = List(['targets', 'block', 'track'])
113 _flag_names = List(['targets', 'block', 'track'])
114 _in_sync_results = Bool(False)
114 _in_sync_results = Bool(False)
115 _targets = Any()
115 _targets = Any()
116 _idents = Any()
116 _idents = Any()
117
117
118 def __init__(self, client=None, socket=None, **flags):
118 def __init__(self, client=None, socket=None, **flags):
119 super(View, self).__init__(client=client, _socket=socket)
119 super(View, self).__init__(client=client, _socket=socket)
120 self.results = client.results
120 self.results = client.results
121 self.block = client.block
121 self.block = client.block
122
122
123 self.set_flags(**flags)
123 self.set_flags(**flags)
124
124
125 assert not self.__class__ is View, "Don't use base View objects, use subclasses"
125 assert not self.__class__ is View, "Don't use base View objects, use subclasses"
126
126
127 def __repr__(self):
127 def __repr__(self):
128 strtargets = str(self.targets)
128 strtargets = str(self.targets)
129 if len(strtargets) > 16:
129 if len(strtargets) > 16:
130 strtargets = strtargets[:12]+'...]'
130 strtargets = strtargets[:12]+'...]'
131 return "<%s %s>"%(self.__class__.__name__, strtargets)
131 return "<%s %s>"%(self.__class__.__name__, strtargets)
132
132
133 def __len__(self):
133 def __len__(self):
134 if isinstance(self.targets, list):
134 if isinstance(self.targets, list):
135 return len(self.targets)
135 return len(self.targets)
136 elif isinstance(self.targets, int):
136 elif isinstance(self.targets, int):
137 return 1
137 return 1
138 else:
138 else:
139 return len(self.client)
139 return len(self.client)
140
140
141 def set_flags(self, **kwargs):
141 def set_flags(self, **kwargs):
142 """set my attribute flags by keyword.
142 """set my attribute flags by keyword.
143
143
144 Views determine behavior with a few attributes (`block`, `track`, etc.).
144 Views determine behavior with a few attributes (`block`, `track`, etc.).
145 These attributes can be set all at once by name with this method.
145 These attributes can be set all at once by name with this method.
146
146
147 Parameters
147 Parameters
148 ----------
148 ----------
149
149
150 block : bool
150 block : bool
151 whether to wait for results
151 whether to wait for results
152 track : bool
152 track : bool
153 whether to create a MessageTracker to allow the user to
153 whether to create a MessageTracker to allow the user to
154 safely edit after arrays and buffers during non-copying
154 safely edit after arrays and buffers during non-copying
155 sends.
155 sends.
156 """
156 """
157 for name, value in iteritems(kwargs):
157 for name, value in iteritems(kwargs):
158 if name not in self._flag_names:
158 if name not in self._flag_names:
159 raise KeyError("Invalid name: %r"%name)
159 raise KeyError("Invalid name: %r"%name)
160 else:
160 else:
161 setattr(self, name, value)
161 setattr(self, name, value)
162
162
163 @contextmanager
163 @contextmanager
164 def temp_flags(self, **kwargs):
164 def temp_flags(self, **kwargs):
165 """temporarily set flags, for use in `with` statements.
165 """temporarily set flags, for use in `with` statements.
166
166
167 See set_flags for permanent setting of flags
167 See set_flags for permanent setting of flags
168
168
169 Examples
169 Examples
170 --------
170 --------
171
171
172 >>> view.track=False
172 >>> view.track=False
173 ...
173 ...
174 >>> with view.temp_flags(track=True):
174 >>> with view.temp_flags(track=True):
175 ... ar = view.apply(dostuff, my_big_array)
175 ... ar = view.apply(dostuff, my_big_array)
176 ... ar.tracker.wait() # wait for send to finish
176 ... ar.tracker.wait() # wait for send to finish
177 >>> view.track
177 >>> view.track
178 False
178 False
179
179
180 """
180 """
181 # preflight: save flags, and set temporaries
181 # preflight: save flags, and set temporaries
182 saved_flags = {}
182 saved_flags = {}
183 for f in self._flag_names:
183 for f in self._flag_names:
184 saved_flags[f] = getattr(self, f)
184 saved_flags[f] = getattr(self, f)
185 self.set_flags(**kwargs)
185 self.set_flags(**kwargs)
186 # yield to the with-statement block
186 # yield to the with-statement block
187 try:
187 try:
188 yield
188 yield
189 finally:
189 finally:
190 # postflight: restore saved flags
190 # postflight: restore saved flags
191 self.set_flags(**saved_flags)
191 self.set_flags(**saved_flags)
192
192
193
193
194 #----------------------------------------------------------------
194 #----------------------------------------------------------------
195 # apply
195 # apply
196 #----------------------------------------------------------------
196 #----------------------------------------------------------------
197
197
198 def _sync_results(self):
198 def _sync_results(self):
199 """to be called by @sync_results decorator
199 """to be called by @sync_results decorator
200
200
201 after submitting any tasks.
201 after submitting any tasks.
202 """
202 """
203 delta = self.outstanding.difference(self.client.outstanding)
203 delta = self.outstanding.difference(self.client.outstanding)
204 completed = self.outstanding.intersection(delta)
204 completed = self.outstanding.intersection(delta)
205 self.outstanding = self.outstanding.difference(completed)
205 self.outstanding = self.outstanding.difference(completed)
206
206
207 @sync_results
207 @sync_results
208 @save_ids
208 @save_ids
209 def _really_apply(self, f, args, kwargs, block=None, **options):
209 def _really_apply(self, f, args, kwargs, block=None, **options):
210 """wrapper for client.send_apply_request"""
210 """wrapper for client.send_apply_request"""
211 raise NotImplementedError("Implement in subclasses")
211 raise NotImplementedError("Implement in subclasses")
212
212
213 def apply(self, f, *args, **kwargs):
213 def apply(self, f, *args, **kwargs):
214 """calls ``f(*args, **kwargs)`` on remote engines, returning the result.
214 """calls ``f(*args, **kwargs)`` on remote engines, returning the result.
215
215
216 This method sets all apply flags via this View's attributes.
216 This method sets all apply flags via this View's attributes.
217
217
218 Returns :class:`~IPython.parallel.client.asyncresult.AsyncResult`
218 Returns :class:`~IPython.parallel.client.asyncresult.AsyncResult`
219 instance if ``self.block`` is False, otherwise the return value of
219 instance if ``self.block`` is False, otherwise the return value of
220 ``f(*args, **kwargs)``.
220 ``f(*args, **kwargs)``.
221 """
221 """
222 return self._really_apply(f, args, kwargs)
222 return self._really_apply(f, args, kwargs)
223
223
224 def apply_async(self, f, *args, **kwargs):
224 def apply_async(self, f, *args, **kwargs):
225 """calls ``f(*args, **kwargs)`` on remote engines in a nonblocking manner.
225 """calls ``f(*args, **kwargs)`` on remote engines in a nonblocking manner.
226
226
227 Returns :class:`~IPython.parallel.client.asyncresult.AsyncResult` instance.
227 Returns :class:`~IPython.parallel.client.asyncresult.AsyncResult` instance.
228 """
228 """
229 return self._really_apply(f, args, kwargs, block=False)
229 return self._really_apply(f, args, kwargs, block=False)
230
230
231 @spin_after
231 @spin_after
232 def apply_sync(self, f, *args, **kwargs):
232 def apply_sync(self, f, *args, **kwargs):
233 """calls ``f(*args, **kwargs)`` on remote engines in a blocking manner,
233 """calls ``f(*args, **kwargs)`` on remote engines in a blocking manner,
234 returning the result.
234 returning the result.
235 """
235 """
236 return self._really_apply(f, args, kwargs, block=True)
236 return self._really_apply(f, args, kwargs, block=True)
237
237
238 #----------------------------------------------------------------
238 #----------------------------------------------------------------
239 # wrappers for client and control methods
239 # wrappers for client and control methods
240 #----------------------------------------------------------------
240 #----------------------------------------------------------------
241 @sync_results
241 @sync_results
242 def spin(self):
242 def spin(self):
243 """spin the client, and sync"""
243 """spin the client, and sync"""
244 self.client.spin()
244 self.client.spin()
245
245
246 @sync_results
246 @sync_results
247 def wait(self, jobs=None, timeout=-1):
247 def wait(self, jobs=None, timeout=-1):
248 """waits on one or more `jobs`, for up to `timeout` seconds.
248 """waits on one or more `jobs`, for up to `timeout` seconds.
249
249
250 Parameters
250 Parameters
251 ----------
251 ----------
252
252
253 jobs : int, str, or list of ints and/or strs, or one or more AsyncResult objects
253 jobs : int, str, or list of ints and/or strs, or one or more AsyncResult objects
254 ints are indices to self.history
254 ints are indices to self.history
255 strs are msg_ids
255 strs are msg_ids
256 default: wait on all outstanding messages
256 default: wait on all outstanding messages
257 timeout : float
257 timeout : float
258 a time in seconds, after which to give up.
258 a time in seconds, after which to give up.
259 default is -1, which means no timeout
259 default is -1, which means no timeout
260
260
261 Returns
261 Returns
262 -------
262 -------
263
263
264 True : when all msg_ids are done
264 True : when all msg_ids are done
265 False : timeout reached, some msg_ids still outstanding
265 False : timeout reached, some msg_ids still outstanding
266 """
266 """
267 if jobs is None:
267 if jobs is None:
268 jobs = self.history
268 jobs = self.history
269 return self.client.wait(jobs, timeout)
269 return self.client.wait(jobs, timeout)
270
270
271 def abort(self, jobs=None, targets=None, block=None):
271 def abort(self, jobs=None, targets=None, block=None):
272 """Abort jobs on my engines.
272 """Abort jobs on my engines.
273
273
274 Parameters
274 Parameters
275 ----------
275 ----------
276
276
277 jobs : None, str, list of strs, optional
277 jobs : None, str, list of strs, optional
278 if None: abort all jobs.
278 if None: abort all jobs.
279 else: abort specific msg_id(s).
279 else: abort specific msg_id(s).
280 """
280 """
281 block = block if block is not None else self.block
281 block = block if block is not None else self.block
282 targets = targets if targets is not None else self.targets
282 targets = targets if targets is not None else self.targets
283 jobs = jobs if jobs is not None else list(self.outstanding)
283 jobs = jobs if jobs is not None else list(self.outstanding)
284
284
285 return self.client.abort(jobs=jobs, targets=targets, block=block)
285 return self.client.abort(jobs=jobs, targets=targets, block=block)
286
286
287 def queue_status(self, targets=None, verbose=False):
287 def queue_status(self, targets=None, verbose=False):
288 """Fetch the Queue status of my engines"""
288 """Fetch the Queue status of my engines"""
289 targets = targets if targets is not None else self.targets
289 targets = targets if targets is not None else self.targets
290 return self.client.queue_status(targets=targets, verbose=verbose)
290 return self.client.queue_status(targets=targets, verbose=verbose)
291
291
292 def purge_results(self, jobs=[], targets=[]):
292 def purge_results(self, jobs=[], targets=[]):
293 """Instruct the controller to forget specific results."""
293 """Instruct the controller to forget specific results."""
294 if targets is None or targets == 'all':
294 if targets is None or targets == 'all':
295 targets = self.targets
295 targets = self.targets
296 return self.client.purge_results(jobs=jobs, targets=targets)
296 return self.client.purge_results(jobs=jobs, targets=targets)
297
297
298 def shutdown(self, targets=None, restart=False, hub=False, block=None):
298 def shutdown(self, targets=None, restart=False, hub=False, block=None):
299 """Terminates one or more engine processes, optionally including the hub.
299 """Terminates one or more engine processes, optionally including the hub.
300 """
300 """
301 block = self.block if block is None else block
301 block = self.block if block is None else block
302 if targets is None or targets == 'all':
302 if targets is None or targets == 'all':
303 targets = self.targets
303 targets = self.targets
304 return self.client.shutdown(targets=targets, restart=restart, hub=hub, block=block)
304 return self.client.shutdown(targets=targets, restart=restart, hub=hub, block=block)
305
305
306 @spin_after
306 @spin_after
307 def get_result(self, indices_or_msg_ids=None, block=None, owner=True):
307 def get_result(self, indices_or_msg_ids=None, block=None, owner=True):
308 """return one or more results, specified by history index or msg_id.
308 """return one or more results, specified by history index or msg_id.
309
309
310 See :meth:`IPython.parallel.client.client.Client.get_result` for details.
310 See :meth:`IPython.parallel.client.client.Client.get_result` for details.
311 """
311 """
312
312
313 if indices_or_msg_ids is None:
313 if indices_or_msg_ids is None:
314 indices_or_msg_ids = -1
314 indices_or_msg_ids = -1
315 if isinstance(indices_or_msg_ids, int):
315 if isinstance(indices_or_msg_ids, int):
316 indices_or_msg_ids = self.history[indices_or_msg_ids]
316 indices_or_msg_ids = self.history[indices_or_msg_ids]
317 elif isinstance(indices_or_msg_ids, (list,tuple,set)):
317 elif isinstance(indices_or_msg_ids, (list,tuple,set)):
318 indices_or_msg_ids = list(indices_or_msg_ids)
318 indices_or_msg_ids = list(indices_or_msg_ids)
319 for i,index in enumerate(indices_or_msg_ids):
319 for i,index in enumerate(indices_or_msg_ids):
320 if isinstance(index, int):
320 if isinstance(index, int):
321 indices_or_msg_ids[i] = self.history[index]
321 indices_or_msg_ids[i] = self.history[index]
322 return self.client.get_result(indices_or_msg_ids, block=block, owner=owner)
322 return self.client.get_result(indices_or_msg_ids, block=block, owner=owner)
323
323
324 #-------------------------------------------------------------------
324 #-------------------------------------------------------------------
325 # Map
325 # Map
326 #-------------------------------------------------------------------
326 #-------------------------------------------------------------------
327
327
328 @sync_results
328 @sync_results
329 def map(self, f, *sequences, **kwargs):
329 def map(self, f, *sequences, **kwargs):
330 """override in subclasses"""
330 """override in subclasses"""
331 raise NotImplementedError
331 raise NotImplementedError
332
332
333 def map_async(self, f, *sequences, **kwargs):
333 def map_async(self, f, *sequences, **kwargs):
334 """Parallel version of builtin :func:`python:map`, using this view's engines.
334 """Parallel version of builtin :func:`python:map`, using this view's engines.
335
335
336 This is equivalent to ``map(...block=False)``.
336 This is equivalent to ``map(...block=False)``.
337
337
338 See `self.map` for details.
338 See `self.map` for details.
339 """
339 """
340 if 'block' in kwargs:
340 if 'block' in kwargs:
341 raise TypeError("map_async doesn't take a `block` keyword argument.")
341 raise TypeError("map_async doesn't take a `block` keyword argument.")
342 kwargs['block'] = False
342 kwargs['block'] = False
343 return self.map(f,*sequences,**kwargs)
343 return self.map(f,*sequences,**kwargs)
344
344
345 def map_sync(self, f, *sequences, **kwargs):
345 def map_sync(self, f, *sequences, **kwargs):
346 """Parallel version of builtin :func:`python:map`, using this view's engines.
346 """Parallel version of builtin :func:`python:map`, using this view's engines.
347
347
348 This is equivalent to ``map(...block=True)``.
348 This is equivalent to ``map(...block=True)``.
349
349
350 See `self.map` for details.
350 See `self.map` for details.
351 """
351 """
352 if 'block' in kwargs:
352 if 'block' in kwargs:
353 raise TypeError("map_sync doesn't take a `block` keyword argument.")
353 raise TypeError("map_sync doesn't take a `block` keyword argument.")
354 kwargs['block'] = True
354 kwargs['block'] = True
355 return self.map(f,*sequences,**kwargs)
355 return self.map(f,*sequences,**kwargs)
356
356
357 def imap(self, f, *sequences, **kwargs):
357 def imap(self, f, *sequences, **kwargs):
358 """Parallel version of :func:`itertools.imap`.
358 """Parallel version of :func:`itertools.imap`.
359
359
360 See `self.map` for details.
360 See `self.map` for details.
361
361
362 """
362 """
363
363
364 return iter(self.map_async(f,*sequences, **kwargs))
364 return iter(self.map_async(f,*sequences, **kwargs))
365
365
366 #-------------------------------------------------------------------
366 #-------------------------------------------------------------------
367 # Decorators
367 # Decorators
368 #-------------------------------------------------------------------
368 #-------------------------------------------------------------------
369
369
370 def remote(self, block=None, **flags):
370 def remote(self, block=None, **flags):
371 """Decorator for making a RemoteFunction"""
371 """Decorator for making a RemoteFunction"""
372 block = self.block if block is None else block
372 block = self.block if block is None else block
373 return remote(self, block=block, **flags)
373 return remote(self, block=block, **flags)
374
374
375 def parallel(self, dist='b', block=None, **flags):
375 def parallel(self, dist='b', block=None, **flags):
376 """Decorator for making a ParallelFunction"""
376 """Decorator for making a ParallelFunction"""
377 block = self.block if block is None else block
377 block = self.block if block is None else block
378 return parallel(self, dist=dist, block=block, **flags)
378 return parallel(self, dist=dist, block=block, **flags)
379
379
380 @skip_doctest
380 @skip_doctest
381 class DirectView(View):
381 class DirectView(View):
382 """Direct Multiplexer View of one or more engines.
382 """Direct Multiplexer View of one or more engines.
383
383
384 These are created via indexed access to a client:
384 These are created via indexed access to a client:
385
385
386 >>> dv_1 = client[1]
386 >>> dv_1 = client[1]
387 >>> dv_all = client[:]
387 >>> dv_all = client[:]
388 >>> dv_even = client[::2]
388 >>> dv_even = client[::2]
389 >>> dv_some = client[1:3]
389 >>> dv_some = client[1:3]
390
390
391 This object provides dictionary access to engine namespaces:
391 This object provides dictionary access to engine namespaces:
392
392
393 # push a=5:
393 # push a=5:
394 >>> dv['a'] = 5
394 >>> dv['a'] = 5
395 # pull 'foo':
395 # pull 'foo':
396 >>> dv['foo']
396 >>> dv['foo']
397
397
398 """
398 """
399
399
400 def __init__(self, client=None, socket=None, targets=None):
400 def __init__(self, client=None, socket=None, targets=None):
401 super(DirectView, self).__init__(client=client, socket=socket, targets=targets)
401 super(DirectView, self).__init__(client=client, socket=socket, targets=targets)
402
402
403 @property
403 @property
404 def importer(self):
404 def importer(self):
405 """sync_imports(local=True) as a property.
405 """sync_imports(local=True) as a property.
406
406
407 See sync_imports for details.
407 See sync_imports for details.
408
408
409 """
409 """
410 return self.sync_imports(True)
410 return self.sync_imports(True)
411
411
412 @contextmanager
412 @contextmanager
413 def sync_imports(self, local=True, quiet=False):
413 def sync_imports(self, local=True, quiet=False):
414 """Context Manager for performing simultaneous local and remote imports.
414 """Context Manager for performing simultaneous local and remote imports.
415
415
416 'import x as y' will *not* work. The 'as y' part will simply be ignored.
416 'import x as y' will *not* work. The 'as y' part will simply be ignored.
417
417
418 If `local=True`, then the package will also be imported locally.
418 If `local=True`, then the package will also be imported locally.
419
419
420 If `quiet=True`, no output will be produced when attempting remote
420 If `quiet=True`, no output will be produced when attempting remote
421 imports.
421 imports.
422
422
423 Note that remote-only (`local=False`) imports have not been implemented.
423 Note that remote-only (`local=False`) imports have not been implemented.
424
424
425 >>> with view.sync_imports():
425 >>> with view.sync_imports():
426 ... from numpy import recarray
426 ... from numpy import recarray
427 importing recarray from numpy on engine(s)
427 importing recarray from numpy on engine(s)
428
428
429 """
429 """
430 from IPython.utils.py3compat import builtin_mod
430 from IPython.utils.py3compat import builtin_mod
431 local_import = builtin_mod.__import__
431 local_import = builtin_mod.__import__
432 modules = set()
432 modules = set()
433 results = []
433 results = []
434 @util.interactive
434 @util.interactive
435 def remote_import(name, fromlist, level):
435 def remote_import(name, fromlist, level):
436 """the function to be passed to apply, that actually performs the import
436 """the function to be passed to apply, that actually performs the import
437 on the engine, and loads up the user namespace.
437 on the engine, and loads up the user namespace.
438 """
438 """
439 import sys
439 import sys
440 user_ns = globals()
440 user_ns = globals()
441 mod = __import__(name, fromlist=fromlist, level=level)
441 mod = __import__(name, fromlist=fromlist, level=level)
442 if fromlist:
442 if fromlist:
443 for key in fromlist:
443 for key in fromlist:
444 user_ns[key] = getattr(mod, key)
444 user_ns[key] = getattr(mod, key)
445 else:
445 else:
446 user_ns[name] = sys.modules[name]
446 user_ns[name] = sys.modules[name]
447
447
448 def view_import(name, globals={}, locals={}, fromlist=[], level=0):
448 def view_import(name, globals={}, locals={}, fromlist=[], level=0):
449 """the drop-in replacement for __import__, that optionally imports
449 """the drop-in replacement for __import__, that optionally imports
450 locally as well.
450 locally as well.
451 """
451 """
452 # don't override nested imports
452 # don't override nested imports
453 save_import = builtin_mod.__import__
453 save_import = builtin_mod.__import__
454 builtin_mod.__import__ = local_import
454 builtin_mod.__import__ = local_import
455
455
456 if imp.lock_held():
456 if imp.lock_held():
457 # this is a side-effect import, don't do it remotely, or even
457 # this is a side-effect import, don't do it remotely, or even
458 # ignore the local effects
458 # ignore the local effects
459 return local_import(name, globals, locals, fromlist, level)
459 return local_import(name, globals, locals, fromlist, level)
460
460
461 imp.acquire_lock()
461 imp.acquire_lock()
462 if local:
462 if local:
463 mod = local_import(name, globals, locals, fromlist, level)
463 mod = local_import(name, globals, locals, fromlist, level)
464 else:
464 else:
465 raise NotImplementedError("remote-only imports not yet implemented")
465 raise NotImplementedError("remote-only imports not yet implemented")
466 imp.release_lock()
466 imp.release_lock()
467
467
468 key = name+':'+','.join(fromlist or [])
468 key = name+':'+','.join(fromlist or [])
469 if level <= 0 and key not in modules:
469 if level <= 0 and key not in modules:
470 modules.add(key)
470 modules.add(key)
471 if not quiet:
471 if not quiet:
472 if fromlist:
472 if fromlist:
473 print("importing %s from %s on engine(s)"%(','.join(fromlist), name))
473 print("importing %s from %s on engine(s)"%(','.join(fromlist), name))
474 else:
474 else:
475 print("importing %s on engine(s)"%name)
475 print("importing %s on engine(s)"%name)
476 results.append(self.apply_async(remote_import, name, fromlist, level))
476 results.append(self.apply_async(remote_import, name, fromlist, level))
477 # restore override
477 # restore override
478 builtin_mod.__import__ = save_import
478 builtin_mod.__import__ = save_import
479
479
480 return mod
480 return mod
481
481
482 # override __import__
482 # override __import__
483 builtin_mod.__import__ = view_import
483 builtin_mod.__import__ = view_import
484 try:
484 try:
485 # enter the block
485 # enter the block
486 yield
486 yield
487 except ImportError:
487 except ImportError:
488 if local:
488 if local:
489 raise
489 raise
490 else:
490 else:
491 # ignore import errors if not doing local imports
491 # ignore import errors if not doing local imports
492 pass
492 pass
493 finally:
493 finally:
494 # always restore __import__
494 # always restore __import__
495 builtin_mod.__import__ = local_import
495 builtin_mod.__import__ = local_import
496
496
497 for r in results:
497 for r in results:
498 # raise possible remote ImportErrors here
498 # raise possible remote ImportErrors here
499 r.get()
499 r.get()
500
500
501 def use_dill(self):
501 def use_dill(self):
502 """Expand serialization support with dill
502 """Expand serialization support with dill
503
503
504 adds support for closures, etc.
504 adds support for closures, etc.
505
505
506 This calls IPython.utils.pickleutil.use_dill() here and on each engine.
506 This calls IPython.utils.pickleutil.use_dill() here and on each engine.
507 """
507 """
508 pickleutil.use_dill()
508 pickleutil.use_dill()
509 return self.apply(pickleutil.use_dill)
509 return self.apply(pickleutil.use_dill)
510
510
511 def use_cloudpickle(self):
511 def use_cloudpickle(self):
512 """Expand serialization support with cloudpickle.
512 """Expand serialization support with cloudpickle.
513 """
513 """
514 pickleutil.use_cloudpickle()
514 pickleutil.use_cloudpickle()
515 return self.apply(pickleutil.use_cloudpickle)
515 return self.apply(pickleutil.use_cloudpickle)
516
516
517
517
518 @sync_results
518 @sync_results
519 @save_ids
519 @save_ids
520 def _really_apply(self, f, args=None, kwargs=None, targets=None, block=None, track=None):
520 def _really_apply(self, f, args=None, kwargs=None, targets=None, block=None, track=None):
521 """calls f(*args, **kwargs) on remote engines, returning the result.
521 """calls f(*args, **kwargs) on remote engines, returning the result.
522
522
523 This method sets all of `apply`'s flags via this View's attributes.
523 This method sets all of `apply`'s flags via this View's attributes.
524
524
525 Parameters
525 Parameters
526 ----------
526 ----------
527
527
528 f : callable
528 f : callable
529
529
530 args : list [default: empty]
530 args : list [default: empty]
531
531
532 kwargs : dict [default: empty]
532 kwargs : dict [default: empty]
533
533
534 targets : target list [default: self.targets]
534 targets : target list [default: self.targets]
535 where to run
535 where to run
536 block : bool [default: self.block]
536 block : bool [default: self.block]
537 whether to block
537 whether to block
538 track : bool [default: self.track]
538 track : bool [default: self.track]
539 whether to ask zmq to track the message, for safe non-copying sends
539 whether to ask zmq to track the message, for safe non-copying sends
540
540
541 Returns
541 Returns
542 -------
542 -------
543
543
544 if self.block is False:
544 if self.block is False:
545 returns AsyncResult
545 returns AsyncResult
546 else:
546 else:
547 returns actual result of f(*args, **kwargs) on the engine(s)
547 returns actual result of f(*args, **kwargs) on the engine(s)
548 This will be a list of self.targets is also a list (even length 1), or
548 This will be a list of self.targets is also a list (even length 1), or
549 the single result if self.targets is an integer engine id
549 the single result if self.targets is an integer engine id
550 """
550 """
551 args = [] if args is None else args
551 args = [] if args is None else args
552 kwargs = {} if kwargs is None else kwargs
552 kwargs = {} if kwargs is None else kwargs
553 block = self.block if block is None else block
553 block = self.block if block is None else block
554 track = self.track if track is None else track
554 track = self.track if track is None else track
555 targets = self.targets if targets is None else targets
555 targets = self.targets if targets is None else targets
556
556
557 _idents, _targets = self.client._build_targets(targets)
557 _idents, _targets = self.client._build_targets(targets)
558 msg_ids = []
558 msg_ids = []
559 trackers = []
559 trackers = []
560 for ident in _idents:
560 for ident in _idents:
561 msg = self.client.send_apply_request(self._socket, f, args, kwargs, track=track,
561 msg = self.client.send_apply_request(self._socket, f, args, kwargs, track=track,
562 ident=ident)
562 ident=ident)
563 if track:
563 if track:
564 trackers.append(msg['tracker'])
564 trackers.append(msg['tracker'])
565 msg_ids.append(msg['header']['msg_id'])
565 msg_ids.append(msg['header']['msg_id'])
566 if isinstance(targets, int):
566 if isinstance(targets, int):
567 msg_ids = msg_ids[0]
567 msg_ids = msg_ids[0]
568 tracker = None if track is False else zmq.MessageTracker(*trackers)
568 tracker = None if track is False else zmq.MessageTracker(*trackers)
569 ar = AsyncResult(self.client, msg_ids, fname=getname(f), targets=_targets,
569 ar = AsyncResult(self.client, msg_ids, fname=getname(f), targets=_targets,
570 tracker=tracker, owner=True,
570 tracker=tracker, owner=True,
571 )
571 )
572 if block:
572 if block:
573 try:
573 try:
574 return ar.get()
574 return ar.get()
575 except KeyboardInterrupt:
575 except KeyboardInterrupt:
576 pass
576 pass
577 return ar
577 return ar
578
578
579
579
580 @sync_results
580 @sync_results
581 def map(self, f, *sequences, **kwargs):
581 def map(self, f, *sequences, **kwargs):
582 """``view.map(f, *sequences, block=self.block)`` => list|AsyncMapResult
582 """``view.map(f, *sequences, block=self.block)`` => list|AsyncMapResult
583
583
584 Parallel version of builtin `map`, using this View's `targets`.
584 Parallel version of builtin `map`, using this View's `targets`.
585
585
586 There will be one task per target, so work will be chunked
586 There will be one task per target, so work will be chunked
587 if the sequences are longer than `targets`.
587 if the sequences are longer than `targets`.
588
588
589 Results can be iterated as they are ready, but will become available in chunks.
589 Results can be iterated as they are ready, but will become available in chunks.
590
590
591 Parameters
591 Parameters
592 ----------
592 ----------
593
593
594 f : callable
594 f : callable
595 function to be mapped
595 function to be mapped
596 *sequences: one or more sequences of matching length
596 *sequences: one or more sequences of matching length
597 the sequences to be distributed and passed to `f`
597 the sequences to be distributed and passed to `f`
598 block : bool
598 block : bool
599 whether to wait for the result or not [default self.block]
599 whether to wait for the result or not [default self.block]
600
600
601 Returns
601 Returns
602 -------
602 -------
603
603
604
604
605 If block=False
605 If block=False
606 An :class:`~IPython.parallel.client.asyncresult.AsyncMapResult` instance.
606 An :class:`~IPython.parallel.client.asyncresult.AsyncMapResult` instance.
607 An object like AsyncResult, but which reassembles the sequence of results
607 An object like AsyncResult, but which reassembles the sequence of results
608 into a single list. AsyncMapResults can be iterated through before all
608 into a single list. AsyncMapResults can be iterated through before all
609 results are complete.
609 results are complete.
610 else
610 else
611 A list, the result of ``map(f,*sequences)``
611 A list, the result of ``map(f,*sequences)``
612 """
612 """
613
613
614 block = kwargs.pop('block', self.block)
614 block = kwargs.pop('block', self.block)
615 for k in kwargs.keys():
615 for k in kwargs.keys():
616 if k not in ['block', 'track']:
616 if k not in ['block', 'track']:
617 raise TypeError("invalid keyword arg, %r"%k)
617 raise TypeError("invalid keyword arg, %r"%k)
618
618
619 assert len(sequences) > 0, "must have some sequences to map onto!"
619 assert len(sequences) > 0, "must have some sequences to map onto!"
620 pf = ParallelFunction(self, f, block=block, **kwargs)
620 pf = ParallelFunction(self, f, block=block, **kwargs)
621 return pf.map(*sequences)
621 return pf.map(*sequences)
622
622
623 @sync_results
623 @sync_results
624 @save_ids
624 @save_ids
625 def execute(self, code, silent=True, targets=None, block=None):
625 def execute(self, code, silent=True, targets=None, block=None):
626 """Executes `code` on `targets` in blocking or nonblocking manner.
626 """Executes `code` on `targets` in blocking or nonblocking manner.
627
627
628 ``execute`` is always `bound` (affects engine namespace)
628 ``execute`` is always `bound` (affects engine namespace)
629
629
630 Parameters
630 Parameters
631 ----------
631 ----------
632
632
633 code : str
633 code : str
634 the code string to be executed
634 the code string to be executed
635 block : bool
635 block : bool
636 whether or not to wait until done to return
636 whether or not to wait until done to return
637 default: self.block
637 default: self.block
638 """
638 """
639 block = self.block if block is None else block
639 block = self.block if block is None else block
640 targets = self.targets if targets is None else targets
640 targets = self.targets if targets is None else targets
641
641
642 _idents, _targets = self.client._build_targets(targets)
642 _idents, _targets = self.client._build_targets(targets)
643 msg_ids = []
643 msg_ids = []
644 trackers = []
644 trackers = []
645 for ident in _idents:
645 for ident in _idents:
646 msg = self.client.send_execute_request(self._socket, code, silent=silent, ident=ident)
646 msg = self.client.send_execute_request(self._socket, code, silent=silent, ident=ident)
647 msg_ids.append(msg['header']['msg_id'])
647 msg_ids.append(msg['header']['msg_id'])
648 if isinstance(targets, int):
648 if isinstance(targets, int):
649 msg_ids = msg_ids[0]
649 msg_ids = msg_ids[0]
650 ar = AsyncResult(self.client, msg_ids, fname='execute', targets=_targets, owner=True)
650 ar = AsyncResult(self.client, msg_ids, fname='execute', targets=_targets, owner=True)
651 if block:
651 if block:
652 try:
652 try:
653 ar.get()
653 ar.get()
654 except KeyboardInterrupt:
654 except KeyboardInterrupt:
655 pass
655 pass
656 return ar
656 return ar
657
657
658 def run(self, filename, targets=None, block=None):
658 def run(self, filename, targets=None, block=None):
659 """Execute contents of `filename` on my engine(s).
659 """Execute contents of `filename` on my engine(s).
660
660
661 This simply reads the contents of the file and calls `execute`.
661 This simply reads the contents of the file and calls `execute`.
662
662
663 Parameters
663 Parameters
664 ----------
664 ----------
665
665
666 filename : str
666 filename : str
667 The path to the file
667 The path to the file
668 targets : int/str/list of ints/strs
668 targets : int/str/list of ints/strs
669 the engines on which to execute
669 the engines on which to execute
670 default : all
670 default : all
671 block : bool
671 block : bool
672 whether or not to wait until done
672 whether or not to wait until done
673 default: self.block
673 default: self.block
674
674
675 """
675 """
676 with open(filename, 'r') as f:
676 with open(filename, 'r') as f:
677 # add newline in case of trailing indented whitespace
677 # add newline in case of trailing indented whitespace
678 # which will cause SyntaxError
678 # which will cause SyntaxError
679 code = f.read()+'\n'
679 code = f.read()+'\n'
680 return self.execute(code, block=block, targets=targets)
680 return self.execute(code, block=block, targets=targets)
681
681
682 def update(self, ns):
682 def update(self, ns):
683 """update remote namespace with dict `ns`
683 """update remote namespace with dict `ns`
684
684
685 See `push` for details.
685 See `push` for details.
686 """
686 """
687 return self.push(ns, block=self.block, track=self.track)
687 return self.push(ns, block=self.block, track=self.track)
688
688
689 def push(self, ns, targets=None, block=None, track=None):
689 def push(self, ns, targets=None, block=None, track=None):
690 """update remote namespace with dict `ns`
690 """update remote namespace with dict `ns`
691
691
692 Parameters
692 Parameters
693 ----------
693 ----------
694
694
695 ns : dict
695 ns : dict
696 dict of keys with which to update engine namespace(s)
696 dict of keys with which to update engine namespace(s)
697 block : bool [default : self.block]
697 block : bool [default : self.block]
698 whether to wait to be notified of engine receipt
698 whether to wait to be notified of engine receipt
699
699
700 """
700 """
701
701
702 block = block if block is not None else self.block
702 block = block if block is not None else self.block
703 track = track if track is not None else self.track
703 track = track if track is not None else self.track
704 targets = targets if targets is not None else self.targets
704 targets = targets if targets is not None else self.targets
705 # applier = self.apply_sync if block else self.apply_async
705 # applier = self.apply_sync if block else self.apply_async
706 if not isinstance(ns, dict):
706 if not isinstance(ns, dict):
707 raise TypeError("Must be a dict, not %s"%type(ns))
707 raise TypeError("Must be a dict, not %s"%type(ns))
708 return self._really_apply(util._push, kwargs=ns, block=block, track=track, targets=targets)
708 return self._really_apply(util._push, kwargs=ns, block=block, track=track, targets=targets)
709
709
710 def get(self, key_s):
710 def get(self, key_s):
711 """get object(s) by `key_s` from remote namespace
711 """get object(s) by `key_s` from remote namespace
712
712
713 see `pull` for details.
713 see `pull` for details.
714 """
714 """
715 # block = block if block is not None else self.block
715 # block = block if block is not None else self.block
716 return self.pull(key_s, block=True)
716 return self.pull(key_s, block=True)
717
717
718 def pull(self, names, targets=None, block=None):
718 def pull(self, names, targets=None, block=None):
719 """get object(s) by `name` from remote namespace
719 """get object(s) by `name` from remote namespace
720
720
721 will return one object if it is a key.
721 will return one object if it is a key.
722 can also take a list of keys, in which case it will return a list of objects.
722 can also take a list of keys, in which case it will return a list of objects.
723 """
723 """
724 block = block if block is not None else self.block
724 block = block if block is not None else self.block
725 targets = targets if targets is not None else self.targets
725 targets = targets if targets is not None else self.targets
726 applier = self.apply_sync if block else self.apply_async
726 applier = self.apply_sync if block else self.apply_async
727 if isinstance(names, string_types):
727 if isinstance(names, string_types):
728 pass
728 pass
729 elif isinstance(names, (list,tuple,set)):
729 elif isinstance(names, (list,tuple,set)):
730 for key in names:
730 for key in names:
731 if not isinstance(key, string_types):
731 if not isinstance(key, string_types):
732 raise TypeError("keys must be str, not type %r"%type(key))
732 raise TypeError("keys must be str, not type %r"%type(key))
733 else:
733 else:
734 raise TypeError("names must be strs, not %r"%names)
734 raise TypeError("names must be strs, not %r"%names)
735 return self._really_apply(util._pull, (names,), block=block, targets=targets)
735 return self._really_apply(util._pull, (names,), block=block, targets=targets)
736
736
737 def scatter(self, key, seq, dist='b', flatten=False, targets=None, block=None, track=None):
737 def scatter(self, key, seq, dist='b', flatten=False, targets=None, block=None, track=None):
738 """
738 """
739 Partition a Python sequence and send the partitions to a set of engines.
739 Partition a Python sequence and send the partitions to a set of engines.
740 """
740 """
741 block = block if block is not None else self.block
741 block = block if block is not None else self.block
742 track = track if track is not None else self.track
742 track = track if track is not None else self.track
743 targets = targets if targets is not None else self.targets
743 targets = targets if targets is not None else self.targets
744
744
745 # construct integer ID list:
745 # construct integer ID list:
746 targets = self.client._build_targets(targets)[1]
746 targets = self.client._build_targets(targets)[1]
747
747
748 mapObject = Map.dists[dist]()
748 mapObject = Map.dists[dist]()
749 nparts = len(targets)
749 nparts = len(targets)
750 msg_ids = []
750 msg_ids = []
751 trackers = []
751 trackers = []
752 for index, engineid in enumerate(targets):
752 for index, engineid in enumerate(targets):
753 partition = mapObject.getPartition(seq, index, nparts)
753 partition = mapObject.getPartition(seq, index, nparts)
754 if flatten and len(partition) == 1:
754 if flatten and len(partition) == 1:
755 ns = {key: partition[0]}
755 ns = {key: partition[0]}
756 else:
756 else:
757 ns = {key: partition}
757 ns = {key: partition}
758 r = self.push(ns, block=False, track=track, targets=engineid)
758 r = self.push(ns, block=False, track=track, targets=engineid)
759 msg_ids.extend(r.msg_ids)
759 msg_ids.extend(r.msg_ids)
760 if track:
760 if track:
761 trackers.append(r._tracker)
761 trackers.append(r._tracker)
762
762
763 if track:
763 if track:
764 tracker = zmq.MessageTracker(*trackers)
764 tracker = zmq.MessageTracker(*trackers)
765 else:
765 else:
766 tracker = None
766 tracker = None
767
767
768 r = AsyncResult(self.client, msg_ids, fname='scatter', targets=targets,
768 r = AsyncResult(self.client, msg_ids, fname='scatter', targets=targets,
769 tracker=tracker, owner=True,
769 tracker=tracker, owner=True,
770 )
770 )
771 if block:
771 if block:
772 r.wait()
772 r.wait()
773 else:
773 else:
774 return r
774 return r
775
775
776 @sync_results
776 @sync_results
777 @save_ids
777 @save_ids
778 def gather(self, key, dist='b', targets=None, block=None):
778 def gather(self, key, dist='b', targets=None, block=None):
779 """
779 """
780 Gather a partitioned sequence on a set of engines as a single local seq.
780 Gather a partitioned sequence on a set of engines as a single local seq.
781 """
781 """
782 block = block if block is not None else self.block
782 block = block if block is not None else self.block
783 targets = targets if targets is not None else self.targets
783 targets = targets if targets is not None else self.targets
784 mapObject = Map.dists[dist]()
784 mapObject = Map.dists[dist]()
785 msg_ids = []
785 msg_ids = []
786
786
787 # construct integer ID list:
787 # construct integer ID list:
788 targets = self.client._build_targets(targets)[1]
788 targets = self.client._build_targets(targets)[1]
789
789
790 for index, engineid in enumerate(targets):
790 for index, engineid in enumerate(targets):
791 msg_ids.extend(self.pull(key, block=False, targets=engineid).msg_ids)
791 msg_ids.extend(self.pull(key, block=False, targets=engineid).msg_ids)
792
792
793 r = AsyncMapResult(self.client, msg_ids, mapObject, fname='gather')
793 r = AsyncMapResult(self.client, msg_ids, mapObject, fname='gather')
794
794
795 if block:
795 if block:
796 try:
796 try:
797 return r.get()
797 return r.get()
798 except KeyboardInterrupt:
798 except KeyboardInterrupt:
799 pass
799 pass
800 return r
800 return r
801
801
802 def __getitem__(self, key):
802 def __getitem__(self, key):
803 return self.get(key)
803 return self.get(key)
804
804
805 def __setitem__(self,key, value):
805 def __setitem__(self,key, value):
806 self.update({key:value})
806 self.update({key:value})
807
807
808 def clear(self, targets=None, block=None):
808 def clear(self, targets=None, block=None):
809 """Clear the remote namespaces on my engines."""
809 """Clear the remote namespaces on my engines."""
810 block = block if block is not None else self.block
810 block = block if block is not None else self.block
811 targets = targets if targets is not None else self.targets
811 targets = targets if targets is not None else self.targets
812 return self.client.clear(targets=targets, block=block)
812 return self.client.clear(targets=targets, block=block)
813
813
814 #----------------------------------------
814 #----------------------------------------
815 # activate for %px, %autopx, etc. magics
815 # activate for %px, %autopx, etc. magics
816 #----------------------------------------
816 #----------------------------------------
817
817
818 def activate(self, suffix=''):
818 def activate(self, suffix=''):
819 """Activate IPython magics associated with this View
819 """Activate IPython magics associated with this View
820
820
821 Defines the magics `%px, %autopx, %pxresult, %%px, %pxconfig`
821 Defines the magics `%px, %autopx, %pxresult, %%px, %pxconfig`
822
822
823 Parameters
823 Parameters
824 ----------
824 ----------
825
825
826 suffix: str [default: '']
826 suffix: str [default: '']
827 The suffix, if any, for the magics. This allows you to have
827 The suffix, if any, for the magics. This allows you to have
828 multiple views associated with parallel magics at the same time.
828 multiple views associated with parallel magics at the same time.
829
829
830 e.g. ``rc[::2].activate(suffix='_even')`` will give you
830 e.g. ``rc[::2].activate(suffix='_even')`` will give you
831 the magics ``%px_even``, ``%pxresult_even``, etc. for running magics
831 the magics ``%px_even``, ``%pxresult_even``, etc. for running magics
832 on the even engines.
832 on the even engines.
833 """
833 """
834
834
835 from IPython.parallel.client.magics import ParallelMagics
835 from IPython.parallel.client.magics import ParallelMagics
836
836
837 try:
837 try:
838 # This is injected into __builtins__.
838 # This is injected into __builtins__.
839 ip = get_ipython()
839 ip = get_ipython()
840 except NameError:
840 except NameError:
841 print("The IPython parallel magics (%px, etc.) only work within IPython.")
841 print("The IPython parallel magics (%px, etc.) only work within IPython.")
842 return
842 return
843
843
844 M = ParallelMagics(ip, self, suffix)
844 M = ParallelMagics(ip, self, suffix)
845 ip.magics_manager.register(M)
845 ip.magics_manager.register(M)
846
846
847
847
848 @skip_doctest
848 @skip_doctest
849 class LoadBalancedView(View):
849 class LoadBalancedView(View):
850 """An load-balancing View that only executes via the Task scheduler.
850 """An load-balancing View that only executes via the Task scheduler.
851
851
852 Load-balanced views can be created with the client's `view` method:
852 Load-balanced views can be created with the client's `view` method:
853
853
854 >>> v = client.load_balanced_view()
854 >>> v = client.load_balanced_view()
855
855
856 or targets can be specified, to restrict the potential destinations:
856 or targets can be specified, to restrict the potential destinations:
857
857
858 >>> v = client.load_balanced_view([1,3])
858 >>> v = client.load_balanced_view([1,3])
859
859
860 which would restrict loadbalancing to between engines 1 and 3.
860 which would restrict loadbalancing to between engines 1 and 3.
861
861
862 """
862 """
863
863
864 follow=Any()
864 follow=Any()
865 after=Any()
865 after=Any()
866 timeout=CFloat()
866 timeout=CFloat()
867 retries = Integer(0)
867 retries = Integer(0)
868
868
869 _task_scheme = Any()
869 _task_scheme = Any()
870 _flag_names = List(['targets', 'block', 'track', 'follow', 'after', 'timeout', 'retries'])
870 _flag_names = List(['targets', 'block', 'track', 'follow', 'after', 'timeout', 'retries'])
871
871
872 def __init__(self, client=None, socket=None, **flags):
872 def __init__(self, client=None, socket=None, **flags):
873 super(LoadBalancedView, self).__init__(client=client, socket=socket, **flags)
873 super(LoadBalancedView, self).__init__(client=client, socket=socket, **flags)
874 self._task_scheme=client._task_scheme
874 self._task_scheme=client._task_scheme
875
875
876 def _validate_dependency(self, dep):
876 def _validate_dependency(self, dep):
877 """validate a dependency.
877 """validate a dependency.
878
878
879 For use in `set_flags`.
879 For use in `set_flags`.
880 """
880 """
881 if dep is None or isinstance(dep, string_types + (AsyncResult, Dependency)):
881 if dep is None or isinstance(dep, string_types + (AsyncResult, Dependency)):
882 return True
882 return True
883 elif isinstance(dep, (list,set, tuple)):
883 elif isinstance(dep, (list,set, tuple)):
884 for d in dep:
884 for d in dep:
885 if not isinstance(d, string_types + (AsyncResult,)):
885 if not isinstance(d, string_types + (AsyncResult,)):
886 return False
886 return False
887 elif isinstance(dep, dict):
887 elif isinstance(dep, dict):
888 if set(dep.keys()) != set(Dependency().as_dict().keys()):
888 if set(dep.keys()) != set(Dependency().as_dict().keys()):
889 return False
889 return False
890 if not isinstance(dep['msg_ids'], list):
890 if not isinstance(dep['msg_ids'], list):
891 return False
891 return False
892 for d in dep['msg_ids']:
892 for d in dep['msg_ids']:
893 if not isinstance(d, string_types):
893 if not isinstance(d, string_types):
894 return False
894 return False
895 else:
895 else:
896 return False
896 return False
897
897
898 return True
898 return True
899
899
900 def _render_dependency(self, dep):
900 def _render_dependency(self, dep):
901 """helper for building jsonable dependencies from various input forms."""
901 """helper for building jsonable dependencies from various input forms."""
902 if isinstance(dep, Dependency):
902 if isinstance(dep, Dependency):
903 return dep.as_dict()
903 return dep.as_dict()
904 elif isinstance(dep, AsyncResult):
904 elif isinstance(dep, AsyncResult):
905 return dep.msg_ids
905 return dep.msg_ids
906 elif dep is None:
906 elif dep is None:
907 return []
907 return []
908 else:
908 else:
909 # pass to Dependency constructor
909 # pass to Dependency constructor
910 return list(Dependency(dep))
910 return list(Dependency(dep))
911
911
912 def set_flags(self, **kwargs):
912 def set_flags(self, **kwargs):
913 """set my attribute flags by keyword.
913 """set my attribute flags by keyword.
914
914
915 A View is a wrapper for the Client's apply method, but with attributes
915 A View is a wrapper for the Client's apply method, but with attributes
916 that specify keyword arguments, those attributes can be set by keyword
916 that specify keyword arguments, those attributes can be set by keyword
917 argument with this method.
917 argument with this method.
918
918
919 Parameters
919 Parameters
920 ----------
920 ----------
921
921
922 block : bool
922 block : bool
923 whether to wait for results
923 whether to wait for results
924 track : bool
924 track : bool
925 whether to create a MessageTracker to allow the user to
925 whether to create a MessageTracker to allow the user to
926 safely edit after arrays and buffers during non-copying
926 safely edit after arrays and buffers during non-copying
927 sends.
927 sends.
928
928
929 after : Dependency or collection of msg_ids
929 after : Dependency or collection of msg_ids
930 Only for load-balanced execution (targets=None)
930 Only for load-balanced execution (targets=None)
931 Specify a list of msg_ids as a time-based dependency.
931 Specify a list of msg_ids as a time-based dependency.
932 This job will only be run *after* the dependencies
932 This job will only be run *after* the dependencies
933 have been met.
933 have been met.
934
934
935 follow : Dependency or collection of msg_ids
935 follow : Dependency or collection of msg_ids
936 Only for load-balanced execution (targets=None)
936 Only for load-balanced execution (targets=None)
937 Specify a list of msg_ids as a location-based dependency.
937 Specify a list of msg_ids as a location-based dependency.
938 This job will only be run on an engine where this dependency
938 This job will only be run on an engine where this dependency
939 is met.
939 is met.
940
940
941 timeout : float/int or None
941 timeout : float/int or None
942 Only for load-balanced execution (targets=None)
942 Only for load-balanced execution (targets=None)
943 Specify an amount of time (in seconds) for the scheduler to
943 Specify an amount of time (in seconds) for the scheduler to
944 wait for dependencies to be met before failing with a
944 wait for dependencies to be met before failing with a
945 DependencyTimeout.
945 DependencyTimeout.
946
946
947 retries : int
947 retries : int
948 Number of times a task will be retried on failure.
948 Number of times a task will be retried on failure.
949 """
949 """
950
950
951 super(LoadBalancedView, self).set_flags(**kwargs)
951 super(LoadBalancedView, self).set_flags(**kwargs)
952 for name in ('follow', 'after'):
952 for name in ('follow', 'after'):
953 if name in kwargs:
953 if name in kwargs:
954 value = kwargs[name]
954 value = kwargs[name]
955 if self._validate_dependency(value):
955 if self._validate_dependency(value):
956 setattr(self, name, value)
956 setattr(self, name, value)
957 else:
957 else:
958 raise ValueError("Invalid dependency: %r"%value)
958 raise ValueError("Invalid dependency: %r"%value)
959 if 'timeout' in kwargs:
959 if 'timeout' in kwargs:
960 t = kwargs['timeout']
960 t = kwargs['timeout']
961 if not isinstance(t, (int, float, type(None))):
961 if not isinstance(t, (int, float, type(None))):
962 if (not PY3) and (not isinstance(t, long)):
962 if (not PY3) and (not isinstance(t, long)):
963 raise TypeError("Invalid type for timeout: %r"%type(t))
963 raise TypeError("Invalid type for timeout: %r"%type(t))
964 if t is not None:
964 if t is not None:
965 if t < 0:
965 if t < 0:
966 raise ValueError("Invalid timeout: %s"%t)
966 raise ValueError("Invalid timeout: %s"%t)
967 self.timeout = t
967 self.timeout = t
968
968
969 @sync_results
969 @sync_results
970 @save_ids
970 @save_ids
971 def _really_apply(self, f, args=None, kwargs=None, block=None, track=None,
971 def _really_apply(self, f, args=None, kwargs=None, block=None, track=None,
972 after=None, follow=None, timeout=None,
972 after=None, follow=None, timeout=None,
973 targets=None, retries=None):
973 targets=None, retries=None):
974 """calls f(*args, **kwargs) on a remote engine, returning the result.
974 """calls f(*args, **kwargs) on a remote engine, returning the result.
975
975
976 This method temporarily sets all of `apply`'s flags for a single call.
976 This method temporarily sets all of `apply`'s flags for a single call.
977
977
978 Parameters
978 Parameters
979 ----------
979 ----------
980
980
981 f : callable
981 f : callable
982
982
983 args : list [default: empty]
983 args : list [default: empty]
984
984
985 kwargs : dict [default: empty]
985 kwargs : dict [default: empty]
986
986
987 block : bool [default: self.block]
987 block : bool [default: self.block]
988 whether to block
988 whether to block
989 track : bool [default: self.track]
989 track : bool [default: self.track]
990 whether to ask zmq to track the message, for safe non-copying sends
990 whether to ask zmq to track the message, for safe non-copying sends
991
991
992 !!!!!! TODO: THE REST HERE !!!!
992 !!!!!! TODO: THE REST HERE !!!!
993
993
994 Returns
994 Returns
995 -------
995 -------
996
996
997 if self.block is False:
997 if self.block is False:
998 returns AsyncResult
998 returns AsyncResult
999 else:
999 else:
1000 returns actual result of f(*args, **kwargs) on the engine(s)
1000 returns actual result of f(*args, **kwargs) on the engine(s)
1001 This will be a list of self.targets is also a list (even length 1), or
1001 This will be a list of self.targets is also a list (even length 1), or
1002 the single result if self.targets is an integer engine id
1002 the single result if self.targets is an integer engine id
1003 """
1003 """
1004
1004
1005 # validate whether we can run
1005 # validate whether we can run
1006 if self._socket.closed:
1006 if self._socket.closed:
1007 msg = "Task farming is disabled"
1007 msg = "Task farming is disabled"
1008 if self._task_scheme == 'pure':
1008 if self._task_scheme == 'pure':
1009 msg += " because the pure ZMQ scheduler cannot handle"
1009 msg += " because the pure ZMQ scheduler cannot handle"
1010 msg += " disappearing engines."
1010 msg += " disappearing engines."
1011 raise RuntimeError(msg)
1011 raise RuntimeError(msg)
1012
1012
1013 if self._task_scheme == 'pure':
1013 if self._task_scheme == 'pure':
1014 # pure zmq scheme doesn't support extra features
1014 # pure zmq scheme doesn't support extra features
1015 msg = "Pure ZMQ scheduler doesn't support the following flags:"
1015 msg = "Pure ZMQ scheduler doesn't support the following flags:"
1016 "follow, after, retries, targets, timeout"
1016 "follow, after, retries, targets, timeout"
1017 if (follow or after or retries or targets or timeout):
1017 if (follow or after or retries or targets or timeout):
1018 # hard fail on Scheduler flags
1018 # hard fail on Scheduler flags
1019 raise RuntimeError(msg)
1019 raise RuntimeError(msg)
1020 if isinstance(f, dependent):
1020 if isinstance(f, dependent):
1021 # soft warn on functional dependencies
1021 # soft warn on functional dependencies
1022 warnings.warn(msg, RuntimeWarning)
1022 warnings.warn(msg, RuntimeWarning)
1023
1023
1024 # build args
1024 # build args
1025 args = [] if args is None else args
1025 args = [] if args is None else args
1026 kwargs = {} if kwargs is None else kwargs
1026 kwargs = {} if kwargs is None else kwargs
1027 block = self.block if block is None else block
1027 block = self.block if block is None else block
1028 track = self.track if track is None else track
1028 track = self.track if track is None else track
1029 after = self.after if after is None else after
1029 after = self.after if after is None else after
1030 retries = self.retries if retries is None else retries
1030 retries = self.retries if retries is None else retries
1031 follow = self.follow if follow is None else follow
1031 follow = self.follow if follow is None else follow
1032 timeout = self.timeout if timeout is None else timeout
1032 timeout = self.timeout if timeout is None else timeout
1033 targets = self.targets if targets is None else targets
1033 targets = self.targets if targets is None else targets
1034
1034
1035 if not isinstance(retries, int):
1035 if not isinstance(retries, int):
1036 raise TypeError('retries must be int, not %r'%type(retries))
1036 raise TypeError('retries must be int, not %r'%type(retries))
1037
1037
1038 if targets is None:
1038 if targets is None:
1039 idents = []
1039 idents = []
1040 else:
1040 else:
1041 idents = self.client._build_targets(targets)[0]
1041 idents = self.client._build_targets(targets)[0]
1042 # ensure *not* bytes
1042 # ensure *not* bytes
1043 idents = [ ident.decode() for ident in idents ]
1043 idents = [ ident.decode() for ident in idents ]
1044
1044
1045 after = self._render_dependency(after)
1045 after = self._render_dependency(after)
1046 follow = self._render_dependency(follow)
1046 follow = self._render_dependency(follow)
1047 metadata = dict(after=after, follow=follow, timeout=timeout, targets=idents, retries=retries)
1047 metadata = dict(after=after, follow=follow, timeout=timeout, targets=idents, retries=retries)
1048
1048
1049 msg = self.client.send_apply_request(self._socket, f, args, kwargs, track=track,
1049 msg = self.client.send_apply_request(self._socket, f, args, kwargs, track=track,
1050 metadata=metadata)
1050 metadata=metadata)
1051 tracker = None if track is False else msg['tracker']
1051 tracker = None if track is False else msg['tracker']
1052
1052
1053 ar = AsyncResult(self.client, msg['header']['msg_id'], fname=getname(f),
1053 ar = AsyncResult(self.client, msg['header']['msg_id'], fname=getname(f),
1054 targets=None, tracker=tracker, owner=True,
1054 targets=None, tracker=tracker, owner=True,
1055 )
1055 )
1056 if block:
1056 if block:
1057 try:
1057 try:
1058 return ar.get()
1058 return ar.get()
1059 except KeyboardInterrupt:
1059 except KeyboardInterrupt:
1060 pass
1060 pass
1061 return ar
1061 return ar
1062
1062
1063 @sync_results
1063 @sync_results
1064 @save_ids
1064 @save_ids
1065 def map(self, f, *sequences, **kwargs):
1065 def map(self, f, *sequences, **kwargs):
1066 """``view.map(f, *sequences, block=self.block, chunksize=1, ordered=True)`` => list|AsyncMapResult
1066 """``view.map(f, *sequences, block=self.block, chunksize=1, ordered=True)`` => list|AsyncMapResult
1067
1067
1068 Parallel version of builtin `map`, load-balanced by this View.
1068 Parallel version of builtin `map`, load-balanced by this View.
1069
1069
1070 `block`, and `chunksize` can be specified by keyword only.
1070 `block`, and `chunksize` can be specified by keyword only.
1071
1071
1072 Each `chunksize` elements will be a separate task, and will be
1072 Each `chunksize` elements will be a separate task, and will be
1073 load-balanced. This lets individual elements be available for iteration
1073 load-balanced. This lets individual elements be available for iteration
1074 as soon as they arrive.
1074 as soon as they arrive.
1075
1075
1076 Parameters
1076 Parameters
1077 ----------
1077 ----------
1078
1078
1079 f : callable
1079 f : callable
1080 function to be mapped
1080 function to be mapped
1081 *sequences: one or more sequences of matching length
1081 *sequences: one or more sequences of matching length
1082 the sequences to be distributed and passed to `f`
1082 the sequences to be distributed and passed to `f`
1083 block : bool [default self.block]
1083 block : bool [default self.block]
1084 whether to wait for the result or not
1084 whether to wait for the result or not
1085 track : bool
1085 track : bool
1086 whether to create a MessageTracker to allow the user to
1086 whether to create a MessageTracker to allow the user to
1087 safely edit after arrays and buffers during non-copying
1087 safely edit after arrays and buffers during non-copying
1088 sends.
1088 sends.
1089 chunksize : int [default 1]
1089 chunksize : int [default 1]
1090 how many elements should be in each task.
1090 how many elements should be in each task.
1091 ordered : bool [default True]
1091 ordered : bool [default True]
1092 Whether the results should be gathered as they arrive, or enforce
1092 Whether the results should be gathered as they arrive, or enforce
1093 the order of submission.
1093 the order of submission.
1094
1094
1095 Only applies when iterating through AsyncMapResult as results arrive.
1095 Only applies when iterating through AsyncMapResult as results arrive.
1096 Has no effect when block=True.
1096 Has no effect when block=True.
1097
1097
1098 Returns
1098 Returns
1099 -------
1099 -------
1100
1100
1101 if block=False
1101 if block=False
1102 An :class:`~IPython.parallel.client.asyncresult.AsyncMapResult` instance.
1102 An :class:`~IPython.parallel.client.asyncresult.AsyncMapResult` instance.
1103 An object like AsyncResult, but which reassembles the sequence of results
1103 An object like AsyncResult, but which reassembles the sequence of results
1104 into a single list. AsyncMapResults can be iterated through before all
1104 into a single list. AsyncMapResults can be iterated through before all
1105 results are complete.
1105 results are complete.
1106 else
1106 else
1107 A list, the result of ``map(f,*sequences)``
1107 A list, the result of ``map(f,*sequences)``
1108 """
1108 """
1109
1109
1110 # default
1110 # default
1111 block = kwargs.get('block', self.block)
1111 block = kwargs.get('block', self.block)
1112 chunksize = kwargs.get('chunksize', 1)
1112 chunksize = kwargs.get('chunksize', 1)
1113 ordered = kwargs.get('ordered', True)
1113 ordered = kwargs.get('ordered', True)
1114
1114
1115 keyset = set(kwargs.keys())
1115 keyset = set(kwargs.keys())
1116 extra_keys = keyset.difference_update(set(['block', 'chunksize']))
1116 extra_keys = keyset.difference_update(set(['block', 'chunksize']))
1117 if extra_keys:
1117 if extra_keys:
1118 raise TypeError("Invalid kwargs: %s"%list(extra_keys))
1118 raise TypeError("Invalid kwargs: %s"%list(extra_keys))
1119
1119
1120 assert len(sequences) > 0, "must have some sequences to map onto!"
1120 assert len(sequences) > 0, "must have some sequences to map onto!"
1121
1121
1122 pf = ParallelFunction(self, f, block=block, chunksize=chunksize, ordered=ordered)
1122 pf = ParallelFunction(self, f, block=block, chunksize=chunksize, ordered=ordered)
1123 return pf.map(*sequences)
1123 return pf.map(*sequences)
1124
1124
1125 __all__ = ['LoadBalancedView', 'DirectView']
1125 __all__ = ['LoadBalancedView', 'DirectView']
@@ -1,849 +1,849 b''
1 """The Python scheduler for rich scheduling.
1 """The Python scheduler for rich scheduling.
2
2
3 The Pure ZMQ scheduler does not allow routing schemes other than LRU,
3 The Pure ZMQ scheduler does not allow routing schemes other than LRU,
4 nor does it check msg_id DAG dependencies. For those, a slightly slower
4 nor does it check msg_id DAG dependencies. For those, a slightly slower
5 Python Scheduler exists.
5 Python Scheduler exists.
6 """
6 """
7
7
8 # Copyright (c) IPython Development Team.
8 # Copyright (c) IPython Development Team.
9 # Distributed under the terms of the Modified BSD License.
9 # Distributed under the terms of the Modified BSD License.
10
10
11 import logging
11 import logging
12 import sys
12 import sys
13 import time
13 import time
14
14
15 from collections import deque
15 from collections import deque
16 from datetime import datetime
16 from datetime import datetime
17 from random import randint, random
17 from random import randint, random
18 from types import FunctionType
18 from types import FunctionType
19
19
20 try:
20 try:
21 import numpy
21 import numpy
22 except ImportError:
22 except ImportError:
23 numpy = None
23 numpy = None
24
24
25 import zmq
25 import zmq
26 from zmq.eventloop import ioloop, zmqstream
26 from zmq.eventloop import ioloop, zmqstream
27
27
28 # local imports
28 # local imports
29 from IPython.external.decorator import decorator
29 from decorator import decorator
30 from IPython.config.application import Application
30 from IPython.config.application import Application
31 from IPython.config.loader import Config
31 from IPython.config.loader import Config
32 from IPython.utils.traitlets import Instance, Dict, List, Set, Integer, Enum, CBytes
32 from IPython.utils.traitlets import Instance, Dict, List, Set, Integer, Enum, CBytes
33 from IPython.utils.py3compat import cast_bytes
33 from IPython.utils.py3compat import cast_bytes
34
34
35 from IPython.parallel import error, util
35 from IPython.parallel import error, util
36 from IPython.parallel.factory import SessionFactory
36 from IPython.parallel.factory import SessionFactory
37 from IPython.parallel.util import connect_logger, local_logger
37 from IPython.parallel.util import connect_logger, local_logger
38
38
39 from .dependency import Dependency
39 from .dependency import Dependency
40
40
41 @decorator
41 @decorator
42 def logged(f,self,*args,**kwargs):
42 def logged(f,self,*args,**kwargs):
43 # print ("#--------------------")
43 # print ("#--------------------")
44 self.log.debug("scheduler::%s(*%s,**%s)", f.__name__, args, kwargs)
44 self.log.debug("scheduler::%s(*%s,**%s)", f.__name__, args, kwargs)
45 # print ("#--")
45 # print ("#--")
46 return f(self,*args, **kwargs)
46 return f(self,*args, **kwargs)
47
47
48 #----------------------------------------------------------------------
48 #----------------------------------------------------------------------
49 # Chooser functions
49 # Chooser functions
50 #----------------------------------------------------------------------
50 #----------------------------------------------------------------------
51
51
52 def plainrandom(loads):
52 def plainrandom(loads):
53 """Plain random pick."""
53 """Plain random pick."""
54 n = len(loads)
54 n = len(loads)
55 return randint(0,n-1)
55 return randint(0,n-1)
56
56
57 def lru(loads):
57 def lru(loads):
58 """Always pick the front of the line.
58 """Always pick the front of the line.
59
59
60 The content of `loads` is ignored.
60 The content of `loads` is ignored.
61
61
62 Assumes LRU ordering of loads, with oldest first.
62 Assumes LRU ordering of loads, with oldest first.
63 """
63 """
64 return 0
64 return 0
65
65
66 def twobin(loads):
66 def twobin(loads):
67 """Pick two at random, use the LRU of the two.
67 """Pick two at random, use the LRU of the two.
68
68
69 The content of loads is ignored.
69 The content of loads is ignored.
70
70
71 Assumes LRU ordering of loads, with oldest first.
71 Assumes LRU ordering of loads, with oldest first.
72 """
72 """
73 n = len(loads)
73 n = len(loads)
74 a = randint(0,n-1)
74 a = randint(0,n-1)
75 b = randint(0,n-1)
75 b = randint(0,n-1)
76 return min(a,b)
76 return min(a,b)
77
77
78 def weighted(loads):
78 def weighted(loads):
79 """Pick two at random using inverse load as weight.
79 """Pick two at random using inverse load as weight.
80
80
81 Return the less loaded of the two.
81 Return the less loaded of the two.
82 """
82 """
83 # weight 0 a million times more than 1:
83 # weight 0 a million times more than 1:
84 weights = 1./(1e-6+numpy.array(loads))
84 weights = 1./(1e-6+numpy.array(loads))
85 sums = weights.cumsum()
85 sums = weights.cumsum()
86 t = sums[-1]
86 t = sums[-1]
87 x = random()*t
87 x = random()*t
88 y = random()*t
88 y = random()*t
89 idx = 0
89 idx = 0
90 idy = 0
90 idy = 0
91 while sums[idx] < x:
91 while sums[idx] < x:
92 idx += 1
92 idx += 1
93 while sums[idy] < y:
93 while sums[idy] < y:
94 idy += 1
94 idy += 1
95 if weights[idy] > weights[idx]:
95 if weights[idy] > weights[idx]:
96 return idy
96 return idy
97 else:
97 else:
98 return idx
98 return idx
99
99
100 def leastload(loads):
100 def leastload(loads):
101 """Always choose the lowest load.
101 """Always choose the lowest load.
102
102
103 If the lowest load occurs more than once, the first
103 If the lowest load occurs more than once, the first
104 occurance will be used. If loads has LRU ordering, this means
104 occurance will be used. If loads has LRU ordering, this means
105 the LRU of those with the lowest load is chosen.
105 the LRU of those with the lowest load is chosen.
106 """
106 """
107 return loads.index(min(loads))
107 return loads.index(min(loads))
108
108
109 #---------------------------------------------------------------------
109 #---------------------------------------------------------------------
110 # Classes
110 # Classes
111 #---------------------------------------------------------------------
111 #---------------------------------------------------------------------
112
112
113
113
114 # store empty default dependency:
114 # store empty default dependency:
115 MET = Dependency([])
115 MET = Dependency([])
116
116
117
117
118 class Job(object):
118 class Job(object):
119 """Simple container for a job"""
119 """Simple container for a job"""
120 def __init__(self, msg_id, raw_msg, idents, msg, header, metadata,
120 def __init__(self, msg_id, raw_msg, idents, msg, header, metadata,
121 targets, after, follow, timeout):
121 targets, after, follow, timeout):
122 self.msg_id = msg_id
122 self.msg_id = msg_id
123 self.raw_msg = raw_msg
123 self.raw_msg = raw_msg
124 self.idents = idents
124 self.idents = idents
125 self.msg = msg
125 self.msg = msg
126 self.header = header
126 self.header = header
127 self.metadata = metadata
127 self.metadata = metadata
128 self.targets = targets
128 self.targets = targets
129 self.after = after
129 self.after = after
130 self.follow = follow
130 self.follow = follow
131 self.timeout = timeout
131 self.timeout = timeout
132
132
133 self.removed = False # used for lazy-delete from sorted queue
133 self.removed = False # used for lazy-delete from sorted queue
134 self.timestamp = time.time()
134 self.timestamp = time.time()
135 self.timeout_id = 0
135 self.timeout_id = 0
136 self.blacklist = set()
136 self.blacklist = set()
137
137
138 def __lt__(self, other):
138 def __lt__(self, other):
139 return self.timestamp < other.timestamp
139 return self.timestamp < other.timestamp
140
140
141 def __cmp__(self, other):
141 def __cmp__(self, other):
142 return cmp(self.timestamp, other.timestamp)
142 return cmp(self.timestamp, other.timestamp)
143
143
144 @property
144 @property
145 def dependents(self):
145 def dependents(self):
146 return self.follow.union(self.after)
146 return self.follow.union(self.after)
147
147
148
148
149 class TaskScheduler(SessionFactory):
149 class TaskScheduler(SessionFactory):
150 """Python TaskScheduler object.
150 """Python TaskScheduler object.
151
151
152 This is the simplest object that supports msg_id based
152 This is the simplest object that supports msg_id based
153 DAG dependencies. *Only* task msg_ids are checked, not
153 DAG dependencies. *Only* task msg_ids are checked, not
154 msg_ids of jobs submitted via the MUX queue.
154 msg_ids of jobs submitted via the MUX queue.
155
155
156 """
156 """
157
157
158 hwm = Integer(1, config=True,
158 hwm = Integer(1, config=True,
159 help="""specify the High Water Mark (HWM) for the downstream
159 help="""specify the High Water Mark (HWM) for the downstream
160 socket in the Task scheduler. This is the maximum number
160 socket in the Task scheduler. This is the maximum number
161 of allowed outstanding tasks on each engine.
161 of allowed outstanding tasks on each engine.
162
162
163 The default (1) means that only one task can be outstanding on each
163 The default (1) means that only one task can be outstanding on each
164 engine. Setting TaskScheduler.hwm=0 means there is no limit, and the
164 engine. Setting TaskScheduler.hwm=0 means there is no limit, and the
165 engines continue to be assigned tasks while they are working,
165 engines continue to be assigned tasks while they are working,
166 effectively hiding network latency behind computation, but can result
166 effectively hiding network latency behind computation, but can result
167 in an imbalance of work when submitting many heterogenous tasks all at
167 in an imbalance of work when submitting many heterogenous tasks all at
168 once. Any positive value greater than one is a compromise between the
168 once. Any positive value greater than one is a compromise between the
169 two.
169 two.
170
170
171 """
171 """
172 )
172 )
173 scheme_name = Enum(('leastload', 'pure', 'lru', 'plainrandom', 'weighted', 'twobin'),
173 scheme_name = Enum(('leastload', 'pure', 'lru', 'plainrandom', 'weighted', 'twobin'),
174 'leastload', config=True,
174 'leastload', config=True,
175 help="""select the task scheduler scheme [default: Python LRU]
175 help="""select the task scheduler scheme [default: Python LRU]
176 Options are: 'pure', 'lru', 'plainrandom', 'weighted', 'twobin','leastload'"""
176 Options are: 'pure', 'lru', 'plainrandom', 'weighted', 'twobin','leastload'"""
177 )
177 )
178 def _scheme_name_changed(self, old, new):
178 def _scheme_name_changed(self, old, new):
179 self.log.debug("Using scheme %r"%new)
179 self.log.debug("Using scheme %r"%new)
180 self.scheme = globals()[new]
180 self.scheme = globals()[new]
181
181
182 # input arguments:
182 # input arguments:
183 scheme = Instance(FunctionType) # function for determining the destination
183 scheme = Instance(FunctionType) # function for determining the destination
184 def _scheme_default(self):
184 def _scheme_default(self):
185 return leastload
185 return leastload
186 client_stream = Instance(zmqstream.ZMQStream) # client-facing stream
186 client_stream = Instance(zmqstream.ZMQStream) # client-facing stream
187 engine_stream = Instance(zmqstream.ZMQStream) # engine-facing stream
187 engine_stream = Instance(zmqstream.ZMQStream) # engine-facing stream
188 notifier_stream = Instance(zmqstream.ZMQStream) # hub-facing sub stream
188 notifier_stream = Instance(zmqstream.ZMQStream) # hub-facing sub stream
189 mon_stream = Instance(zmqstream.ZMQStream) # hub-facing pub stream
189 mon_stream = Instance(zmqstream.ZMQStream) # hub-facing pub stream
190 query_stream = Instance(zmqstream.ZMQStream) # hub-facing DEALER stream
190 query_stream = Instance(zmqstream.ZMQStream) # hub-facing DEALER stream
191
191
192 # internals:
192 # internals:
193 queue = Instance(deque) # sorted list of Jobs
193 queue = Instance(deque) # sorted list of Jobs
194 def _queue_default(self):
194 def _queue_default(self):
195 return deque()
195 return deque()
196 queue_map = Dict() # dict by msg_id of Jobs (for O(1) access to the Queue)
196 queue_map = Dict() # dict by msg_id of Jobs (for O(1) access to the Queue)
197 graph = Dict() # dict by msg_id of [ msg_ids that depend on key ]
197 graph = Dict() # dict by msg_id of [ msg_ids that depend on key ]
198 retries = Dict() # dict by msg_id of retries remaining (non-neg ints)
198 retries = Dict() # dict by msg_id of retries remaining (non-neg ints)
199 # waiting = List() # list of msg_ids ready to run, but haven't due to HWM
199 # waiting = List() # list of msg_ids ready to run, but haven't due to HWM
200 pending = Dict() # dict by engine_uuid of submitted tasks
200 pending = Dict() # dict by engine_uuid of submitted tasks
201 completed = Dict() # dict by engine_uuid of completed tasks
201 completed = Dict() # dict by engine_uuid of completed tasks
202 failed = Dict() # dict by engine_uuid of failed tasks
202 failed = Dict() # dict by engine_uuid of failed tasks
203 destinations = Dict() # dict by msg_id of engine_uuids where jobs ran (reverse of completed+failed)
203 destinations = Dict() # dict by msg_id of engine_uuids where jobs ran (reverse of completed+failed)
204 clients = Dict() # dict by msg_id for who submitted the task
204 clients = Dict() # dict by msg_id for who submitted the task
205 targets = List() # list of target IDENTs
205 targets = List() # list of target IDENTs
206 loads = List() # list of engine loads
206 loads = List() # list of engine loads
207 # full = Set() # set of IDENTs that have HWM outstanding tasks
207 # full = Set() # set of IDENTs that have HWM outstanding tasks
208 all_completed = Set() # set of all completed tasks
208 all_completed = Set() # set of all completed tasks
209 all_failed = Set() # set of all failed tasks
209 all_failed = Set() # set of all failed tasks
210 all_done = Set() # set of all finished tasks=union(completed,failed)
210 all_done = Set() # set of all finished tasks=union(completed,failed)
211 all_ids = Set() # set of all submitted task IDs
211 all_ids = Set() # set of all submitted task IDs
212
212
213 ident = CBytes() # ZMQ identity. This should just be self.session.session
213 ident = CBytes() # ZMQ identity. This should just be self.session.session
214 # but ensure Bytes
214 # but ensure Bytes
215 def _ident_default(self):
215 def _ident_default(self):
216 return self.session.bsession
216 return self.session.bsession
217
217
218 def start(self):
218 def start(self):
219 self.query_stream.on_recv(self.dispatch_query_reply)
219 self.query_stream.on_recv(self.dispatch_query_reply)
220 self.session.send(self.query_stream, "connection_request", {})
220 self.session.send(self.query_stream, "connection_request", {})
221
221
222 self.engine_stream.on_recv(self.dispatch_result, copy=False)
222 self.engine_stream.on_recv(self.dispatch_result, copy=False)
223 self.client_stream.on_recv(self.dispatch_submission, copy=False)
223 self.client_stream.on_recv(self.dispatch_submission, copy=False)
224
224
225 self._notification_handlers = dict(
225 self._notification_handlers = dict(
226 registration_notification = self._register_engine,
226 registration_notification = self._register_engine,
227 unregistration_notification = self._unregister_engine
227 unregistration_notification = self._unregister_engine
228 )
228 )
229 self.notifier_stream.on_recv(self.dispatch_notification)
229 self.notifier_stream.on_recv(self.dispatch_notification)
230 self.log.info("Scheduler started [%s]" % self.scheme_name)
230 self.log.info("Scheduler started [%s]" % self.scheme_name)
231
231
232 def resume_receiving(self):
232 def resume_receiving(self):
233 """Resume accepting jobs."""
233 """Resume accepting jobs."""
234 self.client_stream.on_recv(self.dispatch_submission, copy=False)
234 self.client_stream.on_recv(self.dispatch_submission, copy=False)
235
235
236 def stop_receiving(self):
236 def stop_receiving(self):
237 """Stop accepting jobs while there are no engines.
237 """Stop accepting jobs while there are no engines.
238 Leave them in the ZMQ queue."""
238 Leave them in the ZMQ queue."""
239 self.client_stream.on_recv(None)
239 self.client_stream.on_recv(None)
240
240
241 #-----------------------------------------------------------------------
241 #-----------------------------------------------------------------------
242 # [Un]Registration Handling
242 # [Un]Registration Handling
243 #-----------------------------------------------------------------------
243 #-----------------------------------------------------------------------
244
244
245
245
246 def dispatch_query_reply(self, msg):
246 def dispatch_query_reply(self, msg):
247 """handle reply to our initial connection request"""
247 """handle reply to our initial connection request"""
248 try:
248 try:
249 idents,msg = self.session.feed_identities(msg)
249 idents,msg = self.session.feed_identities(msg)
250 except ValueError:
250 except ValueError:
251 self.log.warn("task::Invalid Message: %r",msg)
251 self.log.warn("task::Invalid Message: %r",msg)
252 return
252 return
253 try:
253 try:
254 msg = self.session.deserialize(msg)
254 msg = self.session.deserialize(msg)
255 except ValueError:
255 except ValueError:
256 self.log.warn("task::Unauthorized message from: %r"%idents)
256 self.log.warn("task::Unauthorized message from: %r"%idents)
257 return
257 return
258
258
259 content = msg['content']
259 content = msg['content']
260 for uuid in content.get('engines', {}).values():
260 for uuid in content.get('engines', {}).values():
261 self._register_engine(cast_bytes(uuid))
261 self._register_engine(cast_bytes(uuid))
262
262
263
263
264 @util.log_errors
264 @util.log_errors
265 def dispatch_notification(self, msg):
265 def dispatch_notification(self, msg):
266 """dispatch register/unregister events."""
266 """dispatch register/unregister events."""
267 try:
267 try:
268 idents,msg = self.session.feed_identities(msg)
268 idents,msg = self.session.feed_identities(msg)
269 except ValueError:
269 except ValueError:
270 self.log.warn("task::Invalid Message: %r",msg)
270 self.log.warn("task::Invalid Message: %r",msg)
271 return
271 return
272 try:
272 try:
273 msg = self.session.deserialize(msg)
273 msg = self.session.deserialize(msg)
274 except ValueError:
274 except ValueError:
275 self.log.warn("task::Unauthorized message from: %r"%idents)
275 self.log.warn("task::Unauthorized message from: %r"%idents)
276 return
276 return
277
277
278 msg_type = msg['header']['msg_type']
278 msg_type = msg['header']['msg_type']
279
279
280 handler = self._notification_handlers.get(msg_type, None)
280 handler = self._notification_handlers.get(msg_type, None)
281 if handler is None:
281 if handler is None:
282 self.log.error("Unhandled message type: %r"%msg_type)
282 self.log.error("Unhandled message type: %r"%msg_type)
283 else:
283 else:
284 try:
284 try:
285 handler(cast_bytes(msg['content']['uuid']))
285 handler(cast_bytes(msg['content']['uuid']))
286 except Exception:
286 except Exception:
287 self.log.error("task::Invalid notification msg: %r", msg, exc_info=True)
287 self.log.error("task::Invalid notification msg: %r", msg, exc_info=True)
288
288
289 def _register_engine(self, uid):
289 def _register_engine(self, uid):
290 """New engine with ident `uid` became available."""
290 """New engine with ident `uid` became available."""
291 # head of the line:
291 # head of the line:
292 self.targets.insert(0,uid)
292 self.targets.insert(0,uid)
293 self.loads.insert(0,0)
293 self.loads.insert(0,0)
294
294
295 # initialize sets
295 # initialize sets
296 self.completed[uid] = set()
296 self.completed[uid] = set()
297 self.failed[uid] = set()
297 self.failed[uid] = set()
298 self.pending[uid] = {}
298 self.pending[uid] = {}
299
299
300 # rescan the graph:
300 # rescan the graph:
301 self.update_graph(None)
301 self.update_graph(None)
302
302
303 def _unregister_engine(self, uid):
303 def _unregister_engine(self, uid):
304 """Existing engine with ident `uid` became unavailable."""
304 """Existing engine with ident `uid` became unavailable."""
305 if len(self.targets) == 1:
305 if len(self.targets) == 1:
306 # this was our only engine
306 # this was our only engine
307 pass
307 pass
308
308
309 # handle any potentially finished tasks:
309 # handle any potentially finished tasks:
310 self.engine_stream.flush()
310 self.engine_stream.flush()
311
311
312 # don't pop destinations, because they might be used later
312 # don't pop destinations, because they might be used later
313 # map(self.destinations.pop, self.completed.pop(uid))
313 # map(self.destinations.pop, self.completed.pop(uid))
314 # map(self.destinations.pop, self.failed.pop(uid))
314 # map(self.destinations.pop, self.failed.pop(uid))
315
315
316 # prevent this engine from receiving work
316 # prevent this engine from receiving work
317 idx = self.targets.index(uid)
317 idx = self.targets.index(uid)
318 self.targets.pop(idx)
318 self.targets.pop(idx)
319 self.loads.pop(idx)
319 self.loads.pop(idx)
320
320
321 # wait 5 seconds before cleaning up pending jobs, since the results might
321 # wait 5 seconds before cleaning up pending jobs, since the results might
322 # still be incoming
322 # still be incoming
323 if self.pending[uid]:
323 if self.pending[uid]:
324 self.loop.add_timeout(self.loop.time() + 5,
324 self.loop.add_timeout(self.loop.time() + 5,
325 lambda : self.handle_stranded_tasks(uid),
325 lambda : self.handle_stranded_tasks(uid),
326 )
326 )
327 else:
327 else:
328 self.completed.pop(uid)
328 self.completed.pop(uid)
329 self.failed.pop(uid)
329 self.failed.pop(uid)
330
330
331
331
332 def handle_stranded_tasks(self, engine):
332 def handle_stranded_tasks(self, engine):
333 """Deal with jobs resident in an engine that died."""
333 """Deal with jobs resident in an engine that died."""
334 lost = self.pending[engine]
334 lost = self.pending[engine]
335 for msg_id in lost.keys():
335 for msg_id in lost.keys():
336 if msg_id not in self.pending[engine]:
336 if msg_id not in self.pending[engine]:
337 # prevent double-handling of messages
337 # prevent double-handling of messages
338 continue
338 continue
339
339
340 raw_msg = lost[msg_id].raw_msg
340 raw_msg = lost[msg_id].raw_msg
341 idents,msg = self.session.feed_identities(raw_msg, copy=False)
341 idents,msg = self.session.feed_identities(raw_msg, copy=False)
342 parent = self.session.unpack(msg[1].bytes)
342 parent = self.session.unpack(msg[1].bytes)
343 idents = [engine, idents[0]]
343 idents = [engine, idents[0]]
344
344
345 # build fake error reply
345 # build fake error reply
346 try:
346 try:
347 raise error.EngineError("Engine %r died while running task %r"%(engine, msg_id))
347 raise error.EngineError("Engine %r died while running task %r"%(engine, msg_id))
348 except:
348 except:
349 content = error.wrap_exception()
349 content = error.wrap_exception()
350 # build fake metadata
350 # build fake metadata
351 md = dict(
351 md = dict(
352 status=u'error',
352 status=u'error',
353 engine=engine.decode('ascii'),
353 engine=engine.decode('ascii'),
354 date=datetime.now(),
354 date=datetime.now(),
355 )
355 )
356 msg = self.session.msg('apply_reply', content, parent=parent, metadata=md)
356 msg = self.session.msg('apply_reply', content, parent=parent, metadata=md)
357 raw_reply = list(map(zmq.Message, self.session.serialize(msg, ident=idents)))
357 raw_reply = list(map(zmq.Message, self.session.serialize(msg, ident=idents)))
358 # and dispatch it
358 # and dispatch it
359 self.dispatch_result(raw_reply)
359 self.dispatch_result(raw_reply)
360
360
361 # finally scrub completed/failed lists
361 # finally scrub completed/failed lists
362 self.completed.pop(engine)
362 self.completed.pop(engine)
363 self.failed.pop(engine)
363 self.failed.pop(engine)
364
364
365
365
366 #-----------------------------------------------------------------------
366 #-----------------------------------------------------------------------
367 # Job Submission
367 # Job Submission
368 #-----------------------------------------------------------------------
368 #-----------------------------------------------------------------------
369
369
370
370
371 @util.log_errors
371 @util.log_errors
372 def dispatch_submission(self, raw_msg):
372 def dispatch_submission(self, raw_msg):
373 """Dispatch job submission to appropriate handlers."""
373 """Dispatch job submission to appropriate handlers."""
374 # ensure targets up to date:
374 # ensure targets up to date:
375 self.notifier_stream.flush()
375 self.notifier_stream.flush()
376 try:
376 try:
377 idents, msg = self.session.feed_identities(raw_msg, copy=False)
377 idents, msg = self.session.feed_identities(raw_msg, copy=False)
378 msg = self.session.deserialize(msg, content=False, copy=False)
378 msg = self.session.deserialize(msg, content=False, copy=False)
379 except Exception:
379 except Exception:
380 self.log.error("task::Invaid task msg: %r"%raw_msg, exc_info=True)
380 self.log.error("task::Invaid task msg: %r"%raw_msg, exc_info=True)
381 return
381 return
382
382
383
383
384 # send to monitor
384 # send to monitor
385 self.mon_stream.send_multipart([b'intask']+raw_msg, copy=False)
385 self.mon_stream.send_multipart([b'intask']+raw_msg, copy=False)
386
386
387 header = msg['header']
387 header = msg['header']
388 md = msg['metadata']
388 md = msg['metadata']
389 msg_id = header['msg_id']
389 msg_id = header['msg_id']
390 self.all_ids.add(msg_id)
390 self.all_ids.add(msg_id)
391
391
392 # get targets as a set of bytes objects
392 # get targets as a set of bytes objects
393 # from a list of unicode objects
393 # from a list of unicode objects
394 targets = md.get('targets', [])
394 targets = md.get('targets', [])
395 targets = set(map(cast_bytes, targets))
395 targets = set(map(cast_bytes, targets))
396
396
397 retries = md.get('retries', 0)
397 retries = md.get('retries', 0)
398 self.retries[msg_id] = retries
398 self.retries[msg_id] = retries
399
399
400 # time dependencies
400 # time dependencies
401 after = md.get('after', None)
401 after = md.get('after', None)
402 if after:
402 if after:
403 after = Dependency(after)
403 after = Dependency(after)
404 if after.all:
404 if after.all:
405 if after.success:
405 if after.success:
406 after = Dependency(after.difference(self.all_completed),
406 after = Dependency(after.difference(self.all_completed),
407 success=after.success,
407 success=after.success,
408 failure=after.failure,
408 failure=after.failure,
409 all=after.all,
409 all=after.all,
410 )
410 )
411 if after.failure:
411 if after.failure:
412 after = Dependency(after.difference(self.all_failed),
412 after = Dependency(after.difference(self.all_failed),
413 success=after.success,
413 success=after.success,
414 failure=after.failure,
414 failure=after.failure,
415 all=after.all,
415 all=after.all,
416 )
416 )
417 if after.check(self.all_completed, self.all_failed):
417 if after.check(self.all_completed, self.all_failed):
418 # recast as empty set, if `after` already met,
418 # recast as empty set, if `after` already met,
419 # to prevent unnecessary set comparisons
419 # to prevent unnecessary set comparisons
420 after = MET
420 after = MET
421 else:
421 else:
422 after = MET
422 after = MET
423
423
424 # location dependencies
424 # location dependencies
425 follow = Dependency(md.get('follow', []))
425 follow = Dependency(md.get('follow', []))
426
426
427 timeout = md.get('timeout', None)
427 timeout = md.get('timeout', None)
428 if timeout:
428 if timeout:
429 timeout = float(timeout)
429 timeout = float(timeout)
430
430
431 job = Job(msg_id=msg_id, raw_msg=raw_msg, idents=idents, msg=msg,
431 job = Job(msg_id=msg_id, raw_msg=raw_msg, idents=idents, msg=msg,
432 header=header, targets=targets, after=after, follow=follow,
432 header=header, targets=targets, after=after, follow=follow,
433 timeout=timeout, metadata=md,
433 timeout=timeout, metadata=md,
434 )
434 )
435 # validate and reduce dependencies:
435 # validate and reduce dependencies:
436 for dep in after,follow:
436 for dep in after,follow:
437 if not dep: # empty dependency
437 if not dep: # empty dependency
438 continue
438 continue
439 # check valid:
439 # check valid:
440 if msg_id in dep or dep.difference(self.all_ids):
440 if msg_id in dep or dep.difference(self.all_ids):
441 self.queue_map[msg_id] = job
441 self.queue_map[msg_id] = job
442 return self.fail_unreachable(msg_id, error.InvalidDependency)
442 return self.fail_unreachable(msg_id, error.InvalidDependency)
443 # check if unreachable:
443 # check if unreachable:
444 if dep.unreachable(self.all_completed, self.all_failed):
444 if dep.unreachable(self.all_completed, self.all_failed):
445 self.queue_map[msg_id] = job
445 self.queue_map[msg_id] = job
446 return self.fail_unreachable(msg_id)
446 return self.fail_unreachable(msg_id)
447
447
448 if after.check(self.all_completed, self.all_failed):
448 if after.check(self.all_completed, self.all_failed):
449 # time deps already met, try to run
449 # time deps already met, try to run
450 if not self.maybe_run(job):
450 if not self.maybe_run(job):
451 # can't run yet
451 # can't run yet
452 if msg_id not in self.all_failed:
452 if msg_id not in self.all_failed:
453 # could have failed as unreachable
453 # could have failed as unreachable
454 self.save_unmet(job)
454 self.save_unmet(job)
455 else:
455 else:
456 self.save_unmet(job)
456 self.save_unmet(job)
457
457
458 def job_timeout(self, job, timeout_id):
458 def job_timeout(self, job, timeout_id):
459 """callback for a job's timeout.
459 """callback for a job's timeout.
460
460
461 The job may or may not have been run at this point.
461 The job may or may not have been run at this point.
462 """
462 """
463 if job.timeout_id != timeout_id:
463 if job.timeout_id != timeout_id:
464 # not the most recent call
464 # not the most recent call
465 return
465 return
466 now = time.time()
466 now = time.time()
467 if job.timeout >= (now + 1):
467 if job.timeout >= (now + 1):
468 self.log.warn("task %s timeout fired prematurely: %s > %s",
468 self.log.warn("task %s timeout fired prematurely: %s > %s",
469 job.msg_id, job.timeout, now
469 job.msg_id, job.timeout, now
470 )
470 )
471 if job.msg_id in self.queue_map:
471 if job.msg_id in self.queue_map:
472 # still waiting, but ran out of time
472 # still waiting, but ran out of time
473 self.log.info("task %r timed out", job.msg_id)
473 self.log.info("task %r timed out", job.msg_id)
474 self.fail_unreachable(job.msg_id, error.TaskTimeout)
474 self.fail_unreachable(job.msg_id, error.TaskTimeout)
475
475
476 def fail_unreachable(self, msg_id, why=error.ImpossibleDependency):
476 def fail_unreachable(self, msg_id, why=error.ImpossibleDependency):
477 """a task has become unreachable, send a reply with an ImpossibleDependency
477 """a task has become unreachable, send a reply with an ImpossibleDependency
478 error."""
478 error."""
479 if msg_id not in self.queue_map:
479 if msg_id not in self.queue_map:
480 self.log.error("task %r already failed!", msg_id)
480 self.log.error("task %r already failed!", msg_id)
481 return
481 return
482 job = self.queue_map.pop(msg_id)
482 job = self.queue_map.pop(msg_id)
483 # lazy-delete from the queue
483 # lazy-delete from the queue
484 job.removed = True
484 job.removed = True
485 for mid in job.dependents:
485 for mid in job.dependents:
486 if mid in self.graph:
486 if mid in self.graph:
487 self.graph[mid].remove(msg_id)
487 self.graph[mid].remove(msg_id)
488
488
489 try:
489 try:
490 raise why()
490 raise why()
491 except:
491 except:
492 content = error.wrap_exception()
492 content = error.wrap_exception()
493 self.log.debug("task %r failing as unreachable with: %s", msg_id, content['ename'])
493 self.log.debug("task %r failing as unreachable with: %s", msg_id, content['ename'])
494
494
495 self.all_done.add(msg_id)
495 self.all_done.add(msg_id)
496 self.all_failed.add(msg_id)
496 self.all_failed.add(msg_id)
497
497
498 msg = self.session.send(self.client_stream, 'apply_reply', content,
498 msg = self.session.send(self.client_stream, 'apply_reply', content,
499 parent=job.header, ident=job.idents)
499 parent=job.header, ident=job.idents)
500 self.session.send(self.mon_stream, msg, ident=[b'outtask']+job.idents)
500 self.session.send(self.mon_stream, msg, ident=[b'outtask']+job.idents)
501
501
502 self.update_graph(msg_id, success=False)
502 self.update_graph(msg_id, success=False)
503
503
504 def available_engines(self):
504 def available_engines(self):
505 """return a list of available engine indices based on HWM"""
505 """return a list of available engine indices based on HWM"""
506 if not self.hwm:
506 if not self.hwm:
507 return list(range(len(self.targets)))
507 return list(range(len(self.targets)))
508 available = []
508 available = []
509 for idx in range(len(self.targets)):
509 for idx in range(len(self.targets)):
510 if self.loads[idx] < self.hwm:
510 if self.loads[idx] < self.hwm:
511 available.append(idx)
511 available.append(idx)
512 return available
512 return available
513
513
514 def maybe_run(self, job):
514 def maybe_run(self, job):
515 """check location dependencies, and run if they are met."""
515 """check location dependencies, and run if they are met."""
516 msg_id = job.msg_id
516 msg_id = job.msg_id
517 self.log.debug("Attempting to assign task %s", msg_id)
517 self.log.debug("Attempting to assign task %s", msg_id)
518 available = self.available_engines()
518 available = self.available_engines()
519 if not available:
519 if not available:
520 # no engines, definitely can't run
520 # no engines, definitely can't run
521 return False
521 return False
522
522
523 if job.follow or job.targets or job.blacklist or self.hwm:
523 if job.follow or job.targets or job.blacklist or self.hwm:
524 # we need a can_run filter
524 # we need a can_run filter
525 def can_run(idx):
525 def can_run(idx):
526 # check hwm
526 # check hwm
527 if self.hwm and self.loads[idx] == self.hwm:
527 if self.hwm and self.loads[idx] == self.hwm:
528 return False
528 return False
529 target = self.targets[idx]
529 target = self.targets[idx]
530 # check blacklist
530 # check blacklist
531 if target in job.blacklist:
531 if target in job.blacklist:
532 return False
532 return False
533 # check targets
533 # check targets
534 if job.targets and target not in job.targets:
534 if job.targets and target not in job.targets:
535 return False
535 return False
536 # check follow
536 # check follow
537 return job.follow.check(self.completed[target], self.failed[target])
537 return job.follow.check(self.completed[target], self.failed[target])
538
538
539 indices = list(filter(can_run, available))
539 indices = list(filter(can_run, available))
540
540
541 if not indices:
541 if not indices:
542 # couldn't run
542 # couldn't run
543 if job.follow.all:
543 if job.follow.all:
544 # check follow for impossibility
544 # check follow for impossibility
545 dests = set()
545 dests = set()
546 relevant = set()
546 relevant = set()
547 if job.follow.success:
547 if job.follow.success:
548 relevant = self.all_completed
548 relevant = self.all_completed
549 if job.follow.failure:
549 if job.follow.failure:
550 relevant = relevant.union(self.all_failed)
550 relevant = relevant.union(self.all_failed)
551 for m in job.follow.intersection(relevant):
551 for m in job.follow.intersection(relevant):
552 dests.add(self.destinations[m])
552 dests.add(self.destinations[m])
553 if len(dests) > 1:
553 if len(dests) > 1:
554 self.queue_map[msg_id] = job
554 self.queue_map[msg_id] = job
555 self.fail_unreachable(msg_id)
555 self.fail_unreachable(msg_id)
556 return False
556 return False
557 if job.targets:
557 if job.targets:
558 # check blacklist+targets for impossibility
558 # check blacklist+targets for impossibility
559 job.targets.difference_update(job.blacklist)
559 job.targets.difference_update(job.blacklist)
560 if not job.targets or not job.targets.intersection(self.targets):
560 if not job.targets or not job.targets.intersection(self.targets):
561 self.queue_map[msg_id] = job
561 self.queue_map[msg_id] = job
562 self.fail_unreachable(msg_id)
562 self.fail_unreachable(msg_id)
563 return False
563 return False
564 return False
564 return False
565 else:
565 else:
566 indices = None
566 indices = None
567
567
568 self.submit_task(job, indices)
568 self.submit_task(job, indices)
569 return True
569 return True
570
570
571 def save_unmet(self, job):
571 def save_unmet(self, job):
572 """Save a message for later submission when its dependencies are met."""
572 """Save a message for later submission when its dependencies are met."""
573 msg_id = job.msg_id
573 msg_id = job.msg_id
574 self.log.debug("Adding task %s to the queue", msg_id)
574 self.log.debug("Adding task %s to the queue", msg_id)
575 self.queue_map[msg_id] = job
575 self.queue_map[msg_id] = job
576 self.queue.append(job)
576 self.queue.append(job)
577 # track the ids in follow or after, but not those already finished
577 # track the ids in follow or after, but not those already finished
578 for dep_id in job.after.union(job.follow).difference(self.all_done):
578 for dep_id in job.after.union(job.follow).difference(self.all_done):
579 if dep_id not in self.graph:
579 if dep_id not in self.graph:
580 self.graph[dep_id] = set()
580 self.graph[dep_id] = set()
581 self.graph[dep_id].add(msg_id)
581 self.graph[dep_id].add(msg_id)
582
582
583 # schedule timeout callback
583 # schedule timeout callback
584 if job.timeout:
584 if job.timeout:
585 timeout_id = job.timeout_id = job.timeout_id + 1
585 timeout_id = job.timeout_id = job.timeout_id + 1
586 self.loop.add_timeout(time.time() + job.timeout,
586 self.loop.add_timeout(time.time() + job.timeout,
587 lambda : self.job_timeout(job, timeout_id)
587 lambda : self.job_timeout(job, timeout_id)
588 )
588 )
589
589
590
590
591 def submit_task(self, job, indices=None):
591 def submit_task(self, job, indices=None):
592 """Submit a task to any of a subset of our targets."""
592 """Submit a task to any of a subset of our targets."""
593 if indices:
593 if indices:
594 loads = [self.loads[i] for i in indices]
594 loads = [self.loads[i] for i in indices]
595 else:
595 else:
596 loads = self.loads
596 loads = self.loads
597 idx = self.scheme(loads)
597 idx = self.scheme(loads)
598 if indices:
598 if indices:
599 idx = indices[idx]
599 idx = indices[idx]
600 target = self.targets[idx]
600 target = self.targets[idx]
601 # print (target, map(str, msg[:3]))
601 # print (target, map(str, msg[:3]))
602 # send job to the engine
602 # send job to the engine
603 self.engine_stream.send(target, flags=zmq.SNDMORE, copy=False)
603 self.engine_stream.send(target, flags=zmq.SNDMORE, copy=False)
604 self.engine_stream.send_multipart(job.raw_msg, copy=False)
604 self.engine_stream.send_multipart(job.raw_msg, copy=False)
605 # update load
605 # update load
606 self.add_job(idx)
606 self.add_job(idx)
607 self.pending[target][job.msg_id] = job
607 self.pending[target][job.msg_id] = job
608 # notify Hub
608 # notify Hub
609 content = dict(msg_id=job.msg_id, engine_id=target.decode('ascii'))
609 content = dict(msg_id=job.msg_id, engine_id=target.decode('ascii'))
610 self.session.send(self.mon_stream, 'task_destination', content=content,
610 self.session.send(self.mon_stream, 'task_destination', content=content,
611 ident=[b'tracktask',self.ident])
611 ident=[b'tracktask',self.ident])
612
612
613
613
614 #-----------------------------------------------------------------------
614 #-----------------------------------------------------------------------
615 # Result Handling
615 # Result Handling
616 #-----------------------------------------------------------------------
616 #-----------------------------------------------------------------------
617
617
618
618
619 @util.log_errors
619 @util.log_errors
620 def dispatch_result(self, raw_msg):
620 def dispatch_result(self, raw_msg):
621 """dispatch method for result replies"""
621 """dispatch method for result replies"""
622 try:
622 try:
623 idents,msg = self.session.feed_identities(raw_msg, copy=False)
623 idents,msg = self.session.feed_identities(raw_msg, copy=False)
624 msg = self.session.deserialize(msg, content=False, copy=False)
624 msg = self.session.deserialize(msg, content=False, copy=False)
625 engine = idents[0]
625 engine = idents[0]
626 try:
626 try:
627 idx = self.targets.index(engine)
627 idx = self.targets.index(engine)
628 except ValueError:
628 except ValueError:
629 pass # skip load-update for dead engines
629 pass # skip load-update for dead engines
630 else:
630 else:
631 self.finish_job(idx)
631 self.finish_job(idx)
632 except Exception:
632 except Exception:
633 self.log.error("task::Invalid result: %r", raw_msg, exc_info=True)
633 self.log.error("task::Invalid result: %r", raw_msg, exc_info=True)
634 return
634 return
635
635
636 md = msg['metadata']
636 md = msg['metadata']
637 parent = msg['parent_header']
637 parent = msg['parent_header']
638 if md.get('dependencies_met', True):
638 if md.get('dependencies_met', True):
639 success = (md['status'] == 'ok')
639 success = (md['status'] == 'ok')
640 msg_id = parent['msg_id']
640 msg_id = parent['msg_id']
641 retries = self.retries[msg_id]
641 retries = self.retries[msg_id]
642 if not success and retries > 0:
642 if not success and retries > 0:
643 # failed
643 # failed
644 self.retries[msg_id] = retries - 1
644 self.retries[msg_id] = retries - 1
645 self.handle_unmet_dependency(idents, parent)
645 self.handle_unmet_dependency(idents, parent)
646 else:
646 else:
647 del self.retries[msg_id]
647 del self.retries[msg_id]
648 # relay to client and update graph
648 # relay to client and update graph
649 self.handle_result(idents, parent, raw_msg, success)
649 self.handle_result(idents, parent, raw_msg, success)
650 # send to Hub monitor
650 # send to Hub monitor
651 self.mon_stream.send_multipart([b'outtask']+raw_msg, copy=False)
651 self.mon_stream.send_multipart([b'outtask']+raw_msg, copy=False)
652 else:
652 else:
653 self.handle_unmet_dependency(idents, parent)
653 self.handle_unmet_dependency(idents, parent)
654
654
655 def handle_result(self, idents, parent, raw_msg, success=True):
655 def handle_result(self, idents, parent, raw_msg, success=True):
656 """handle a real task result, either success or failure"""
656 """handle a real task result, either success or failure"""
657 # first, relay result to client
657 # first, relay result to client
658 engine = idents[0]
658 engine = idents[0]
659 client = idents[1]
659 client = idents[1]
660 # swap_ids for ROUTER-ROUTER mirror
660 # swap_ids for ROUTER-ROUTER mirror
661 raw_msg[:2] = [client,engine]
661 raw_msg[:2] = [client,engine]
662 # print (map(str, raw_msg[:4]))
662 # print (map(str, raw_msg[:4]))
663 self.client_stream.send_multipart(raw_msg, copy=False)
663 self.client_stream.send_multipart(raw_msg, copy=False)
664 # now, update our data structures
664 # now, update our data structures
665 msg_id = parent['msg_id']
665 msg_id = parent['msg_id']
666 self.pending[engine].pop(msg_id)
666 self.pending[engine].pop(msg_id)
667 if success:
667 if success:
668 self.completed[engine].add(msg_id)
668 self.completed[engine].add(msg_id)
669 self.all_completed.add(msg_id)
669 self.all_completed.add(msg_id)
670 else:
670 else:
671 self.failed[engine].add(msg_id)
671 self.failed[engine].add(msg_id)
672 self.all_failed.add(msg_id)
672 self.all_failed.add(msg_id)
673 self.all_done.add(msg_id)
673 self.all_done.add(msg_id)
674 self.destinations[msg_id] = engine
674 self.destinations[msg_id] = engine
675
675
676 self.update_graph(msg_id, success)
676 self.update_graph(msg_id, success)
677
677
678 def handle_unmet_dependency(self, idents, parent):
678 def handle_unmet_dependency(self, idents, parent):
679 """handle an unmet dependency"""
679 """handle an unmet dependency"""
680 engine = idents[0]
680 engine = idents[0]
681 msg_id = parent['msg_id']
681 msg_id = parent['msg_id']
682
682
683 job = self.pending[engine].pop(msg_id)
683 job = self.pending[engine].pop(msg_id)
684 job.blacklist.add(engine)
684 job.blacklist.add(engine)
685
685
686 if job.blacklist == job.targets:
686 if job.blacklist == job.targets:
687 self.queue_map[msg_id] = job
687 self.queue_map[msg_id] = job
688 self.fail_unreachable(msg_id)
688 self.fail_unreachable(msg_id)
689 elif not self.maybe_run(job):
689 elif not self.maybe_run(job):
690 # resubmit failed
690 # resubmit failed
691 if msg_id not in self.all_failed:
691 if msg_id not in self.all_failed:
692 # put it back in our dependency tree
692 # put it back in our dependency tree
693 self.save_unmet(job)
693 self.save_unmet(job)
694
694
695 if self.hwm:
695 if self.hwm:
696 try:
696 try:
697 idx = self.targets.index(engine)
697 idx = self.targets.index(engine)
698 except ValueError:
698 except ValueError:
699 pass # skip load-update for dead engines
699 pass # skip load-update for dead engines
700 else:
700 else:
701 if self.loads[idx] == self.hwm-1:
701 if self.loads[idx] == self.hwm-1:
702 self.update_graph(None)
702 self.update_graph(None)
703
703
704 def update_graph(self, dep_id=None, success=True):
704 def update_graph(self, dep_id=None, success=True):
705 """dep_id just finished. Update our dependency
705 """dep_id just finished. Update our dependency
706 graph and submit any jobs that just became runnable.
706 graph and submit any jobs that just became runnable.
707
707
708 Called with dep_id=None to update entire graph for hwm, but without finishing a task.
708 Called with dep_id=None to update entire graph for hwm, but without finishing a task.
709 """
709 """
710 # print ("\n\n***********")
710 # print ("\n\n***********")
711 # pprint (dep_id)
711 # pprint (dep_id)
712 # pprint (self.graph)
712 # pprint (self.graph)
713 # pprint (self.queue_map)
713 # pprint (self.queue_map)
714 # pprint (self.all_completed)
714 # pprint (self.all_completed)
715 # pprint (self.all_failed)
715 # pprint (self.all_failed)
716 # print ("\n\n***********\n\n")
716 # print ("\n\n***********\n\n")
717 # update any jobs that depended on the dependency
717 # update any jobs that depended on the dependency
718 msg_ids = self.graph.pop(dep_id, [])
718 msg_ids = self.graph.pop(dep_id, [])
719
719
720 # recheck *all* jobs if
720 # recheck *all* jobs if
721 # a) we have HWM and an engine just become no longer full
721 # a) we have HWM and an engine just become no longer full
722 # or b) dep_id was given as None
722 # or b) dep_id was given as None
723
723
724 if dep_id is None or self.hwm and any( [ load==self.hwm-1 for load in self.loads ]):
724 if dep_id is None or self.hwm and any( [ load==self.hwm-1 for load in self.loads ]):
725 jobs = self.queue
725 jobs = self.queue
726 using_queue = True
726 using_queue = True
727 else:
727 else:
728 using_queue = False
728 using_queue = False
729 jobs = deque(sorted( self.queue_map[msg_id] for msg_id in msg_ids ))
729 jobs = deque(sorted( self.queue_map[msg_id] for msg_id in msg_ids ))
730
730
731 to_restore = []
731 to_restore = []
732 while jobs:
732 while jobs:
733 job = jobs.popleft()
733 job = jobs.popleft()
734 if job.removed:
734 if job.removed:
735 continue
735 continue
736 msg_id = job.msg_id
736 msg_id = job.msg_id
737
737
738 put_it_back = True
738 put_it_back = True
739
739
740 if job.after.unreachable(self.all_completed, self.all_failed)\
740 if job.after.unreachable(self.all_completed, self.all_failed)\
741 or job.follow.unreachable(self.all_completed, self.all_failed):
741 or job.follow.unreachable(self.all_completed, self.all_failed):
742 self.fail_unreachable(msg_id)
742 self.fail_unreachable(msg_id)
743 put_it_back = False
743 put_it_back = False
744
744
745 elif job.after.check(self.all_completed, self.all_failed): # time deps met, maybe run
745 elif job.after.check(self.all_completed, self.all_failed): # time deps met, maybe run
746 if self.maybe_run(job):
746 if self.maybe_run(job):
747 put_it_back = False
747 put_it_back = False
748 self.queue_map.pop(msg_id)
748 self.queue_map.pop(msg_id)
749 for mid in job.dependents:
749 for mid in job.dependents:
750 if mid in self.graph:
750 if mid in self.graph:
751 self.graph[mid].remove(msg_id)
751 self.graph[mid].remove(msg_id)
752
752
753 # abort the loop if we just filled up all of our engines.
753 # abort the loop if we just filled up all of our engines.
754 # avoids an O(N) operation in situation of full queue,
754 # avoids an O(N) operation in situation of full queue,
755 # where graph update is triggered as soon as an engine becomes
755 # where graph update is triggered as soon as an engine becomes
756 # non-full, and all tasks after the first are checked,
756 # non-full, and all tasks after the first are checked,
757 # even though they can't run.
757 # even though they can't run.
758 if not self.available_engines():
758 if not self.available_engines():
759 break
759 break
760
760
761 if using_queue and put_it_back:
761 if using_queue and put_it_back:
762 # popped a job from the queue but it neither ran nor failed,
762 # popped a job from the queue but it neither ran nor failed,
763 # so we need to put it back when we are done
763 # so we need to put it back when we are done
764 # make sure to_restore preserves the same ordering
764 # make sure to_restore preserves the same ordering
765 to_restore.append(job)
765 to_restore.append(job)
766
766
767 # put back any tasks we popped but didn't run
767 # put back any tasks we popped but didn't run
768 if using_queue:
768 if using_queue:
769 self.queue.extendleft(to_restore)
769 self.queue.extendleft(to_restore)
770
770
771 #----------------------------------------------------------------------
771 #----------------------------------------------------------------------
772 # methods to be overridden by subclasses
772 # methods to be overridden by subclasses
773 #----------------------------------------------------------------------
773 #----------------------------------------------------------------------
774
774
775 def add_job(self, idx):
775 def add_job(self, idx):
776 """Called after self.targets[idx] just got the job with header.
776 """Called after self.targets[idx] just got the job with header.
777 Override with subclasses. The default ordering is simple LRU.
777 Override with subclasses. The default ordering is simple LRU.
778 The default loads are the number of outstanding jobs."""
778 The default loads are the number of outstanding jobs."""
779 self.loads[idx] += 1
779 self.loads[idx] += 1
780 for lis in (self.targets, self.loads):
780 for lis in (self.targets, self.loads):
781 lis.append(lis.pop(idx))
781 lis.append(lis.pop(idx))
782
782
783
783
784 def finish_job(self, idx):
784 def finish_job(self, idx):
785 """Called after self.targets[idx] just finished a job.
785 """Called after self.targets[idx] just finished a job.
786 Override with subclasses."""
786 Override with subclasses."""
787 self.loads[idx] -= 1
787 self.loads[idx] -= 1
788
788
789
789
790
790
791 def launch_scheduler(in_addr, out_addr, mon_addr, not_addr, reg_addr, config=None,
791 def launch_scheduler(in_addr, out_addr, mon_addr, not_addr, reg_addr, config=None,
792 logname='root', log_url=None, loglevel=logging.DEBUG,
792 logname='root', log_url=None, loglevel=logging.DEBUG,
793 identity=b'task', in_thread=False):
793 identity=b'task', in_thread=False):
794
794
795 ZMQStream = zmqstream.ZMQStream
795 ZMQStream = zmqstream.ZMQStream
796
796
797 if config:
797 if config:
798 # unwrap dict back into Config
798 # unwrap dict back into Config
799 config = Config(config)
799 config = Config(config)
800
800
801 if in_thread:
801 if in_thread:
802 # use instance() to get the same Context/Loop as our parent
802 # use instance() to get the same Context/Loop as our parent
803 ctx = zmq.Context.instance()
803 ctx = zmq.Context.instance()
804 loop = ioloop.IOLoop.instance()
804 loop = ioloop.IOLoop.instance()
805 else:
805 else:
806 # in a process, don't use instance()
806 # in a process, don't use instance()
807 # for safety with multiprocessing
807 # for safety with multiprocessing
808 ctx = zmq.Context()
808 ctx = zmq.Context()
809 loop = ioloop.IOLoop()
809 loop = ioloop.IOLoop()
810 ins = ZMQStream(ctx.socket(zmq.ROUTER),loop)
810 ins = ZMQStream(ctx.socket(zmq.ROUTER),loop)
811 util.set_hwm(ins, 0)
811 util.set_hwm(ins, 0)
812 ins.setsockopt(zmq.IDENTITY, identity + b'_in')
812 ins.setsockopt(zmq.IDENTITY, identity + b'_in')
813 ins.bind(in_addr)
813 ins.bind(in_addr)
814
814
815 outs = ZMQStream(ctx.socket(zmq.ROUTER),loop)
815 outs = ZMQStream(ctx.socket(zmq.ROUTER),loop)
816 util.set_hwm(outs, 0)
816 util.set_hwm(outs, 0)
817 outs.setsockopt(zmq.IDENTITY, identity + b'_out')
817 outs.setsockopt(zmq.IDENTITY, identity + b'_out')
818 outs.bind(out_addr)
818 outs.bind(out_addr)
819 mons = zmqstream.ZMQStream(ctx.socket(zmq.PUB),loop)
819 mons = zmqstream.ZMQStream(ctx.socket(zmq.PUB),loop)
820 util.set_hwm(mons, 0)
820 util.set_hwm(mons, 0)
821 mons.connect(mon_addr)
821 mons.connect(mon_addr)
822 nots = zmqstream.ZMQStream(ctx.socket(zmq.SUB),loop)
822 nots = zmqstream.ZMQStream(ctx.socket(zmq.SUB),loop)
823 nots.setsockopt(zmq.SUBSCRIBE, b'')
823 nots.setsockopt(zmq.SUBSCRIBE, b'')
824 nots.connect(not_addr)
824 nots.connect(not_addr)
825
825
826 querys = ZMQStream(ctx.socket(zmq.DEALER),loop)
826 querys = ZMQStream(ctx.socket(zmq.DEALER),loop)
827 querys.connect(reg_addr)
827 querys.connect(reg_addr)
828
828
829 # setup logging.
829 # setup logging.
830 if in_thread:
830 if in_thread:
831 log = Application.instance().log
831 log = Application.instance().log
832 else:
832 else:
833 if log_url:
833 if log_url:
834 log = connect_logger(logname, ctx, log_url, root="scheduler", loglevel=loglevel)
834 log = connect_logger(logname, ctx, log_url, root="scheduler", loglevel=loglevel)
835 else:
835 else:
836 log = local_logger(logname, loglevel)
836 log = local_logger(logname, loglevel)
837
837
838 scheduler = TaskScheduler(client_stream=ins, engine_stream=outs,
838 scheduler = TaskScheduler(client_stream=ins, engine_stream=outs,
839 mon_stream=mons, notifier_stream=nots,
839 mon_stream=mons, notifier_stream=nots,
840 query_stream=querys,
840 query_stream=querys,
841 loop=loop, log=log,
841 loop=loop, log=log,
842 config=config)
842 config=config)
843 scheduler.start()
843 scheduler.start()
844 if not in_thread:
844 if not in_thread:
845 try:
845 try:
846 loop.start()
846 loop.start()
847 except KeyboardInterrupt:
847 except KeyboardInterrupt:
848 scheduler.log.critical("Interrupted, exiting...")
848 scheduler.log.critical("Interrupted, exiting...")
849
849
@@ -1,192 +1,192 b''
1 """base class for parallel client tests
1 """base class for parallel client tests
2
2
3 Authors:
3 Authors:
4
4
5 * Min RK
5 * Min RK
6 """
6 """
7
7
8 #-------------------------------------------------------------------------------
8 #-------------------------------------------------------------------------------
9 # Copyright (C) 2011 The IPython Development Team
9 # Copyright (C) 2011 The IPython Development Team
10 #
10 #
11 # Distributed under the terms of the BSD License. The full license is in
11 # Distributed under the terms of the BSD License. The full license is in
12 # the file COPYING, distributed as part of this software.
12 # the file COPYING, distributed as part of this software.
13 #-------------------------------------------------------------------------------
13 #-------------------------------------------------------------------------------
14 from __future__ import print_function
14 from __future__ import print_function
15
15
16 import sys
16 import sys
17 import tempfile
17 import tempfile
18 import time
18 import time
19
19
20 from nose import SkipTest
20 from nose import SkipTest
21
21
22 import zmq
22 import zmq
23 from zmq.tests import BaseZMQTestCase
23 from zmq.tests import BaseZMQTestCase
24
24
25 from IPython.external.decorator import decorator
25 from decorator import decorator
26
26
27 from IPython.parallel import error
27 from IPython.parallel import error
28 from IPython.parallel import Client
28 from IPython.parallel import Client
29
29
30 from IPython.parallel.tests import launchers, add_engines
30 from IPython.parallel.tests import launchers, add_engines
31
31
32 # simple tasks for use in apply tests
32 # simple tasks for use in apply tests
33
33
34 def segfault():
34 def segfault():
35 """this will segfault"""
35 """this will segfault"""
36 import ctypes
36 import ctypes
37 ctypes.memset(-1,0,1)
37 ctypes.memset(-1,0,1)
38
38
39 def crash():
39 def crash():
40 """from stdlib crashers in the test suite"""
40 """from stdlib crashers in the test suite"""
41 import types
41 import types
42 if sys.platform.startswith('win'):
42 if sys.platform.startswith('win'):
43 import ctypes
43 import ctypes
44 ctypes.windll.kernel32.SetErrorMode(0x0002);
44 ctypes.windll.kernel32.SetErrorMode(0x0002);
45 args = [ 0, 0, 0, 0, b'\x04\x71\x00\x00', (), (), (), '', '', 1, b'']
45 args = [ 0, 0, 0, 0, b'\x04\x71\x00\x00', (), (), (), '', '', 1, b'']
46 if sys.version_info[0] >= 3:
46 if sys.version_info[0] >= 3:
47 # Python3 adds 'kwonlyargcount' as the second argument to Code
47 # Python3 adds 'kwonlyargcount' as the second argument to Code
48 args.insert(1, 0)
48 args.insert(1, 0)
49
49
50 co = types.CodeType(*args)
50 co = types.CodeType(*args)
51 exec(co)
51 exec(co)
52
52
53 def wait(n):
53 def wait(n):
54 """sleep for a time"""
54 """sleep for a time"""
55 import time
55 import time
56 time.sleep(n)
56 time.sleep(n)
57 return n
57 return n
58
58
59 def raiser(eclass):
59 def raiser(eclass):
60 """raise an exception"""
60 """raise an exception"""
61 raise eclass()
61 raise eclass()
62
62
63 def generate_output():
63 def generate_output():
64 """function for testing output
64 """function for testing output
65
65
66 publishes two outputs of each type, and returns
66 publishes two outputs of each type, and returns
67 a rich displayable object.
67 a rich displayable object.
68 """
68 """
69
69
70 import sys
70 import sys
71 from IPython.core.display import display, HTML, Math
71 from IPython.core.display import display, HTML, Math
72
72
73 print("stdout")
73 print("stdout")
74 print("stderr", file=sys.stderr)
74 print("stderr", file=sys.stderr)
75
75
76 display(HTML("<b>HTML</b>"))
76 display(HTML("<b>HTML</b>"))
77
77
78 print("stdout2")
78 print("stdout2")
79 print("stderr2", file=sys.stderr)
79 print("stderr2", file=sys.stderr)
80
80
81 display(Math(r"\alpha=\beta"))
81 display(Math(r"\alpha=\beta"))
82
82
83 return Math("42")
83 return Math("42")
84
84
85 # test decorator for skipping tests when libraries are unavailable
85 # test decorator for skipping tests when libraries are unavailable
86 def skip_without(*names):
86 def skip_without(*names):
87 """skip a test if some names are not importable"""
87 """skip a test if some names are not importable"""
88 @decorator
88 @decorator
89 def skip_without_names(f, *args, **kwargs):
89 def skip_without_names(f, *args, **kwargs):
90 """decorator to skip tests in the absence of numpy."""
90 """decorator to skip tests in the absence of numpy."""
91 for name in names:
91 for name in names:
92 try:
92 try:
93 __import__(name)
93 __import__(name)
94 except ImportError:
94 except ImportError:
95 raise SkipTest
95 raise SkipTest
96 return f(*args, **kwargs)
96 return f(*args, **kwargs)
97 return skip_without_names
97 return skip_without_names
98
98
99 #-------------------------------------------------------------------------------
99 #-------------------------------------------------------------------------------
100 # Classes
100 # Classes
101 #-------------------------------------------------------------------------------
101 #-------------------------------------------------------------------------------
102
102
103
103
104 class ClusterTestCase(BaseZMQTestCase):
104 class ClusterTestCase(BaseZMQTestCase):
105 timeout = 10
105 timeout = 10
106
106
107 def add_engines(self, n=1, block=True):
107 def add_engines(self, n=1, block=True):
108 """add multiple engines to our cluster"""
108 """add multiple engines to our cluster"""
109 self.engines.extend(add_engines(n))
109 self.engines.extend(add_engines(n))
110 if block:
110 if block:
111 self.wait_on_engines()
111 self.wait_on_engines()
112
112
113 def minimum_engines(self, n=1, block=True):
113 def minimum_engines(self, n=1, block=True):
114 """add engines until there are at least n connected"""
114 """add engines until there are at least n connected"""
115 self.engines.extend(add_engines(n, total=True))
115 self.engines.extend(add_engines(n, total=True))
116 if block:
116 if block:
117 self.wait_on_engines()
117 self.wait_on_engines()
118
118
119
119
120 def wait_on_engines(self, timeout=5):
120 def wait_on_engines(self, timeout=5):
121 """wait for our engines to connect."""
121 """wait for our engines to connect."""
122 n = len(self.engines)+self.base_engine_count
122 n = len(self.engines)+self.base_engine_count
123 tic = time.time()
123 tic = time.time()
124 while time.time()-tic < timeout and len(self.client.ids) < n:
124 while time.time()-tic < timeout and len(self.client.ids) < n:
125 time.sleep(0.1)
125 time.sleep(0.1)
126
126
127 assert not len(self.client.ids) < n, "waiting for engines timed out"
127 assert not len(self.client.ids) < n, "waiting for engines timed out"
128
128
129 def client_wait(self, client, jobs=None, timeout=-1):
129 def client_wait(self, client, jobs=None, timeout=-1):
130 """my wait wrapper, sets a default finite timeout to avoid hangs"""
130 """my wait wrapper, sets a default finite timeout to avoid hangs"""
131 if timeout < 0:
131 if timeout < 0:
132 timeout = self.timeout
132 timeout = self.timeout
133 return Client.wait(client, jobs, timeout)
133 return Client.wait(client, jobs, timeout)
134
134
135 def connect_client(self):
135 def connect_client(self):
136 """connect a client with my Context, and track its sockets for cleanup"""
136 """connect a client with my Context, and track its sockets for cleanup"""
137 c = Client(profile='iptest', context=self.context)
137 c = Client(profile='iptest', context=self.context)
138 c.wait = lambda *a, **kw: self.client_wait(c, *a, **kw)
138 c.wait = lambda *a, **kw: self.client_wait(c, *a, **kw)
139
139
140 for name in filter(lambda n:n.endswith('socket'), dir(c)):
140 for name in filter(lambda n:n.endswith('socket'), dir(c)):
141 s = getattr(c, name)
141 s = getattr(c, name)
142 s.setsockopt(zmq.LINGER, 0)
142 s.setsockopt(zmq.LINGER, 0)
143 self.sockets.append(s)
143 self.sockets.append(s)
144 return c
144 return c
145
145
146 def assertRaisesRemote(self, etype, f, *args, **kwargs):
146 def assertRaisesRemote(self, etype, f, *args, **kwargs):
147 try:
147 try:
148 try:
148 try:
149 f(*args, **kwargs)
149 f(*args, **kwargs)
150 except error.CompositeError as e:
150 except error.CompositeError as e:
151 e.raise_exception()
151 e.raise_exception()
152 except error.RemoteError as e:
152 except error.RemoteError as e:
153 self.assertEqual(etype.__name__, e.ename, "Should have raised %r, but raised %r"%(etype.__name__, e.ename))
153 self.assertEqual(etype.__name__, e.ename, "Should have raised %r, but raised %r"%(etype.__name__, e.ename))
154 else:
154 else:
155 self.fail("should have raised a RemoteError")
155 self.fail("should have raised a RemoteError")
156
156
157 def _wait_for(self, f, timeout=10):
157 def _wait_for(self, f, timeout=10):
158 """wait for a condition"""
158 """wait for a condition"""
159 tic = time.time()
159 tic = time.time()
160 while time.time() <= tic + timeout:
160 while time.time() <= tic + timeout:
161 if f():
161 if f():
162 return
162 return
163 time.sleep(0.1)
163 time.sleep(0.1)
164 self.client.spin()
164 self.client.spin()
165 if not f():
165 if not f():
166 print("Warning: Awaited condition never arrived")
166 print("Warning: Awaited condition never arrived")
167
167
168 def setUp(self):
168 def setUp(self):
169 BaseZMQTestCase.setUp(self)
169 BaseZMQTestCase.setUp(self)
170 self.client = self.connect_client()
170 self.client = self.connect_client()
171 # start every test with clean engine namespaces:
171 # start every test with clean engine namespaces:
172 self.client.clear(block=True)
172 self.client.clear(block=True)
173 self.base_engine_count=len(self.client.ids)
173 self.base_engine_count=len(self.client.ids)
174 self.engines=[]
174 self.engines=[]
175
175
176 def tearDown(self):
176 def tearDown(self):
177 # self.client.clear(block=True)
177 # self.client.clear(block=True)
178 # close fds:
178 # close fds:
179 for e in filter(lambda e: e.poll() is not None, launchers):
179 for e in filter(lambda e: e.poll() is not None, launchers):
180 launchers.remove(e)
180 launchers.remove(e)
181
181
182 # allow flushing of incoming messages to prevent crash on socket close
182 # allow flushing of incoming messages to prevent crash on socket close
183 self.client.wait(timeout=2)
183 self.client.wait(timeout=2)
184 # time.sleep(2)
184 # time.sleep(2)
185 self.client.spin()
185 self.client.spin()
186 self.client.close()
186 self.client.close()
187 BaseZMQTestCase.tearDown(self)
187 BaseZMQTestCase.tearDown(self)
188 # this will be redundant when pyzmq merges PR #88
188 # this will be redundant when pyzmq merges PR #88
189 # self.context.term()
189 # self.context.term()
190 # print tempfile.TemporaryFile().fileno(),
190 # print tempfile.TemporaryFile().fileno(),
191 # sys.stdout.flush()
191 # sys.stdout.flush()
192
192
@@ -1,389 +1,389 b''
1 """Some generic utilities for dealing with classes, urls, and serialization."""
1 """Some generic utilities for dealing with classes, urls, and serialization."""
2
2
3 # Copyright (c) IPython Development Team.
3 # Copyright (c) IPython Development Team.
4 # Distributed under the terms of the Modified BSD License.
4 # Distributed under the terms of the Modified BSD License.
5
5
6 import logging
6 import logging
7 import os
7 import os
8 import re
8 import re
9 import stat
9 import stat
10 import socket
10 import socket
11 import sys
11 import sys
12 import warnings
12 import warnings
13 from signal import signal, SIGINT, SIGABRT, SIGTERM
13 from signal import signal, SIGINT, SIGABRT, SIGTERM
14 try:
14 try:
15 from signal import SIGKILL
15 from signal import SIGKILL
16 except ImportError:
16 except ImportError:
17 SIGKILL=None
17 SIGKILL=None
18 from types import FunctionType
18 from types import FunctionType
19
19
20 try:
20 try:
21 import cPickle
21 import cPickle
22 pickle = cPickle
22 pickle = cPickle
23 except:
23 except:
24 cPickle = None
24 cPickle = None
25 import pickle
25 import pickle
26
26
27 import zmq
27 import zmq
28 from zmq.log import handlers
28 from zmq.log import handlers
29
29
30 from IPython.utils.log import get_logger
30 from IPython.utils.log import get_logger
31 from IPython.external.decorator import decorator
31 from decorator import decorator
32
32
33 from IPython.config.application import Application
33 from IPython.config.application import Application
34 from IPython.utils.localinterfaces import localhost, is_public_ip, public_ips
34 from IPython.utils.localinterfaces import localhost, is_public_ip, public_ips
35 from IPython.utils.py3compat import string_types, iteritems, itervalues
35 from IPython.utils.py3compat import string_types, iteritems, itervalues
36 from IPython.kernel.zmq.log import EnginePUBHandler
36 from IPython.kernel.zmq.log import EnginePUBHandler
37
37
38
38
39 #-----------------------------------------------------------------------------
39 #-----------------------------------------------------------------------------
40 # Classes
40 # Classes
41 #-----------------------------------------------------------------------------
41 #-----------------------------------------------------------------------------
42
42
43 class Namespace(dict):
43 class Namespace(dict):
44 """Subclass of dict for attribute access to keys."""
44 """Subclass of dict for attribute access to keys."""
45
45
46 def __getattr__(self, key):
46 def __getattr__(self, key):
47 """getattr aliased to getitem"""
47 """getattr aliased to getitem"""
48 if key in self:
48 if key in self:
49 return self[key]
49 return self[key]
50 else:
50 else:
51 raise NameError(key)
51 raise NameError(key)
52
52
53 def __setattr__(self, key, value):
53 def __setattr__(self, key, value):
54 """setattr aliased to setitem, with strict"""
54 """setattr aliased to setitem, with strict"""
55 if hasattr(dict, key):
55 if hasattr(dict, key):
56 raise KeyError("Cannot override dict keys %r"%key)
56 raise KeyError("Cannot override dict keys %r"%key)
57 self[key] = value
57 self[key] = value
58
58
59
59
60 class ReverseDict(dict):
60 class ReverseDict(dict):
61 """simple double-keyed subset of dict methods."""
61 """simple double-keyed subset of dict methods."""
62
62
63 def __init__(self, *args, **kwargs):
63 def __init__(self, *args, **kwargs):
64 dict.__init__(self, *args, **kwargs)
64 dict.__init__(self, *args, **kwargs)
65 self._reverse = dict()
65 self._reverse = dict()
66 for key, value in iteritems(self):
66 for key, value in iteritems(self):
67 self._reverse[value] = key
67 self._reverse[value] = key
68
68
69 def __getitem__(self, key):
69 def __getitem__(self, key):
70 try:
70 try:
71 return dict.__getitem__(self, key)
71 return dict.__getitem__(self, key)
72 except KeyError:
72 except KeyError:
73 return self._reverse[key]
73 return self._reverse[key]
74
74
75 def __setitem__(self, key, value):
75 def __setitem__(self, key, value):
76 if key in self._reverse:
76 if key in self._reverse:
77 raise KeyError("Can't have key %r on both sides!"%key)
77 raise KeyError("Can't have key %r on both sides!"%key)
78 dict.__setitem__(self, key, value)
78 dict.__setitem__(self, key, value)
79 self._reverse[value] = key
79 self._reverse[value] = key
80
80
81 def pop(self, key):
81 def pop(self, key):
82 value = dict.pop(self, key)
82 value = dict.pop(self, key)
83 self._reverse.pop(value)
83 self._reverse.pop(value)
84 return value
84 return value
85
85
86 def get(self, key, default=None):
86 def get(self, key, default=None):
87 try:
87 try:
88 return self[key]
88 return self[key]
89 except KeyError:
89 except KeyError:
90 return default
90 return default
91
91
92 #-----------------------------------------------------------------------------
92 #-----------------------------------------------------------------------------
93 # Functions
93 # Functions
94 #-----------------------------------------------------------------------------
94 #-----------------------------------------------------------------------------
95
95
96 @decorator
96 @decorator
97 def log_errors(f, self, *args, **kwargs):
97 def log_errors(f, self, *args, **kwargs):
98 """decorator to log unhandled exceptions raised in a method.
98 """decorator to log unhandled exceptions raised in a method.
99
99
100 For use wrapping on_recv callbacks, so that exceptions
100 For use wrapping on_recv callbacks, so that exceptions
101 do not cause the stream to be closed.
101 do not cause the stream to be closed.
102 """
102 """
103 try:
103 try:
104 return f(self, *args, **kwargs)
104 return f(self, *args, **kwargs)
105 except Exception:
105 except Exception:
106 self.log.error("Uncaught exception in %r" % f, exc_info=True)
106 self.log.error("Uncaught exception in %r" % f, exc_info=True)
107
107
108
108
109 def is_url(url):
109 def is_url(url):
110 """boolean check for whether a string is a zmq url"""
110 """boolean check for whether a string is a zmq url"""
111 if '://' not in url:
111 if '://' not in url:
112 return False
112 return False
113 proto, addr = url.split('://', 1)
113 proto, addr = url.split('://', 1)
114 if proto.lower() not in ['tcp','pgm','epgm','ipc','inproc']:
114 if proto.lower() not in ['tcp','pgm','epgm','ipc','inproc']:
115 return False
115 return False
116 return True
116 return True
117
117
118 def validate_url(url):
118 def validate_url(url):
119 """validate a url for zeromq"""
119 """validate a url for zeromq"""
120 if not isinstance(url, string_types):
120 if not isinstance(url, string_types):
121 raise TypeError("url must be a string, not %r"%type(url))
121 raise TypeError("url must be a string, not %r"%type(url))
122 url = url.lower()
122 url = url.lower()
123
123
124 proto_addr = url.split('://')
124 proto_addr = url.split('://')
125 assert len(proto_addr) == 2, 'Invalid url: %r'%url
125 assert len(proto_addr) == 2, 'Invalid url: %r'%url
126 proto, addr = proto_addr
126 proto, addr = proto_addr
127 assert proto in ['tcp','pgm','epgm','ipc','inproc'], "Invalid protocol: %r"%proto
127 assert proto in ['tcp','pgm','epgm','ipc','inproc'], "Invalid protocol: %r"%proto
128
128
129 # domain pattern adapted from http://www.regexlib.com/REDetails.aspx?regexp_id=391
129 # domain pattern adapted from http://www.regexlib.com/REDetails.aspx?regexp_id=391
130 # author: Remi Sabourin
130 # author: Remi Sabourin
131 pat = re.compile(r'^([\w\d]([\w\d\-]{0,61}[\w\d])?\.)*[\w\d]([\w\d\-]{0,61}[\w\d])?$')
131 pat = re.compile(r'^([\w\d]([\w\d\-]{0,61}[\w\d])?\.)*[\w\d]([\w\d\-]{0,61}[\w\d])?$')
132
132
133 if proto == 'tcp':
133 if proto == 'tcp':
134 lis = addr.split(':')
134 lis = addr.split(':')
135 assert len(lis) == 2, 'Invalid url: %r'%url
135 assert len(lis) == 2, 'Invalid url: %r'%url
136 addr,s_port = lis
136 addr,s_port = lis
137 try:
137 try:
138 port = int(s_port)
138 port = int(s_port)
139 except ValueError:
139 except ValueError:
140 raise AssertionError("Invalid port %r in url: %r"%(port, url))
140 raise AssertionError("Invalid port %r in url: %r"%(port, url))
141
141
142 assert addr == '*' or pat.match(addr) is not None, 'Invalid url: %r'%url
142 assert addr == '*' or pat.match(addr) is not None, 'Invalid url: %r'%url
143
143
144 else:
144 else:
145 # only validate tcp urls currently
145 # only validate tcp urls currently
146 pass
146 pass
147
147
148 return True
148 return True
149
149
150
150
151 def validate_url_container(container):
151 def validate_url_container(container):
152 """validate a potentially nested collection of urls."""
152 """validate a potentially nested collection of urls."""
153 if isinstance(container, string_types):
153 if isinstance(container, string_types):
154 url = container
154 url = container
155 return validate_url(url)
155 return validate_url(url)
156 elif isinstance(container, dict):
156 elif isinstance(container, dict):
157 container = itervalues(container)
157 container = itervalues(container)
158
158
159 for element in container:
159 for element in container:
160 validate_url_container(element)
160 validate_url_container(element)
161
161
162
162
163 def split_url(url):
163 def split_url(url):
164 """split a zmq url (tcp://ip:port) into ('tcp','ip','port')."""
164 """split a zmq url (tcp://ip:port) into ('tcp','ip','port')."""
165 proto_addr = url.split('://')
165 proto_addr = url.split('://')
166 assert len(proto_addr) == 2, 'Invalid url: %r'%url
166 assert len(proto_addr) == 2, 'Invalid url: %r'%url
167 proto, addr = proto_addr
167 proto, addr = proto_addr
168 lis = addr.split(':')
168 lis = addr.split(':')
169 assert len(lis) == 2, 'Invalid url: %r'%url
169 assert len(lis) == 2, 'Invalid url: %r'%url
170 addr,s_port = lis
170 addr,s_port = lis
171 return proto,addr,s_port
171 return proto,addr,s_port
172
172
173
173
174 def disambiguate_ip_address(ip, location=None):
174 def disambiguate_ip_address(ip, location=None):
175 """turn multi-ip interfaces '0.0.0.0' and '*' into a connectable address
175 """turn multi-ip interfaces '0.0.0.0' and '*' into a connectable address
176
176
177 Explicit IP addresses are returned unmodified.
177 Explicit IP addresses are returned unmodified.
178
178
179 Parameters
179 Parameters
180 ----------
180 ----------
181
181
182 ip : IP address
182 ip : IP address
183 An IP address, or the special values 0.0.0.0, or *
183 An IP address, or the special values 0.0.0.0, or *
184 location: IP address, optional
184 location: IP address, optional
185 A public IP of the target machine.
185 A public IP of the target machine.
186 If location is an IP of the current machine,
186 If location is an IP of the current machine,
187 localhost will be returned,
187 localhost will be returned,
188 otherwise location will be returned.
188 otherwise location will be returned.
189 """
189 """
190 if ip in {'0.0.0.0', '*'}:
190 if ip in {'0.0.0.0', '*'}:
191 if not location:
191 if not location:
192 # unspecified location, localhost is the only choice
192 # unspecified location, localhost is the only choice
193 ip = localhost()
193 ip = localhost()
194 elif is_public_ip(location):
194 elif is_public_ip(location):
195 # location is a public IP on this machine, use localhost
195 # location is a public IP on this machine, use localhost
196 ip = localhost()
196 ip = localhost()
197 elif not public_ips():
197 elif not public_ips():
198 # this machine's public IPs cannot be determined,
198 # this machine's public IPs cannot be determined,
199 # assume `location` is not this machine
199 # assume `location` is not this machine
200 warnings.warn("IPython could not determine public IPs", RuntimeWarning)
200 warnings.warn("IPython could not determine public IPs", RuntimeWarning)
201 ip = location
201 ip = location
202 else:
202 else:
203 # location is not this machine, do not use loopback
203 # location is not this machine, do not use loopback
204 ip = location
204 ip = location
205 return ip
205 return ip
206
206
207
207
208 def disambiguate_url(url, location=None):
208 def disambiguate_url(url, location=None):
209 """turn multi-ip interfaces '0.0.0.0' and '*' into connectable
209 """turn multi-ip interfaces '0.0.0.0' and '*' into connectable
210 ones, based on the location (default interpretation is localhost).
210 ones, based on the location (default interpretation is localhost).
211
211
212 This is for zeromq urls, such as ``tcp://*:10101``.
212 This is for zeromq urls, such as ``tcp://*:10101``.
213 """
213 """
214 try:
214 try:
215 proto,ip,port = split_url(url)
215 proto,ip,port = split_url(url)
216 except AssertionError:
216 except AssertionError:
217 # probably not tcp url; could be ipc, etc.
217 # probably not tcp url; could be ipc, etc.
218 return url
218 return url
219
219
220 ip = disambiguate_ip_address(ip,location)
220 ip = disambiguate_ip_address(ip,location)
221
221
222 return "%s://%s:%s"%(proto,ip,port)
222 return "%s://%s:%s"%(proto,ip,port)
223
223
224
224
225 #--------------------------------------------------------------------------
225 #--------------------------------------------------------------------------
226 # helpers for implementing old MEC API via view.apply
226 # helpers for implementing old MEC API via view.apply
227 #--------------------------------------------------------------------------
227 #--------------------------------------------------------------------------
228
228
229 def interactive(f):
229 def interactive(f):
230 """decorator for making functions appear as interactively defined.
230 """decorator for making functions appear as interactively defined.
231 This results in the function being linked to the user_ns as globals()
231 This results in the function being linked to the user_ns as globals()
232 instead of the module globals().
232 instead of the module globals().
233 """
233 """
234
234
235 # build new FunctionType, so it can have the right globals
235 # build new FunctionType, so it can have the right globals
236 # interactive functions never have closures, that's kind of the point
236 # interactive functions never have closures, that's kind of the point
237 if isinstance(f, FunctionType):
237 if isinstance(f, FunctionType):
238 mainmod = __import__('__main__')
238 mainmod = __import__('__main__')
239 f = FunctionType(f.__code__, mainmod.__dict__,
239 f = FunctionType(f.__code__, mainmod.__dict__,
240 f.__name__, f.__defaults__,
240 f.__name__, f.__defaults__,
241 )
241 )
242 # associate with __main__ for uncanning
242 # associate with __main__ for uncanning
243 f.__module__ = '__main__'
243 f.__module__ = '__main__'
244 return f
244 return f
245
245
246 @interactive
246 @interactive
247 def _push(**ns):
247 def _push(**ns):
248 """helper method for implementing `client.push` via `client.apply`"""
248 """helper method for implementing `client.push` via `client.apply`"""
249 user_ns = globals()
249 user_ns = globals()
250 tmp = '_IP_PUSH_TMP_'
250 tmp = '_IP_PUSH_TMP_'
251 while tmp in user_ns:
251 while tmp in user_ns:
252 tmp = tmp + '_'
252 tmp = tmp + '_'
253 try:
253 try:
254 for name, value in ns.items():
254 for name, value in ns.items():
255 user_ns[tmp] = value
255 user_ns[tmp] = value
256 exec("%s = %s" % (name, tmp), user_ns)
256 exec("%s = %s" % (name, tmp), user_ns)
257 finally:
257 finally:
258 user_ns.pop(tmp, None)
258 user_ns.pop(tmp, None)
259
259
260 @interactive
260 @interactive
261 def _pull(keys):
261 def _pull(keys):
262 """helper method for implementing `client.pull` via `client.apply`"""
262 """helper method for implementing `client.pull` via `client.apply`"""
263 if isinstance(keys, (list,tuple, set)):
263 if isinstance(keys, (list,tuple, set)):
264 return [eval(key, globals()) for key in keys]
264 return [eval(key, globals()) for key in keys]
265 else:
265 else:
266 return eval(keys, globals())
266 return eval(keys, globals())
267
267
268 @interactive
268 @interactive
269 def _execute(code):
269 def _execute(code):
270 """helper method for implementing `client.execute` via `client.apply`"""
270 """helper method for implementing `client.execute` via `client.apply`"""
271 exec(code, globals())
271 exec(code, globals())
272
272
273 #--------------------------------------------------------------------------
273 #--------------------------------------------------------------------------
274 # extra process management utilities
274 # extra process management utilities
275 #--------------------------------------------------------------------------
275 #--------------------------------------------------------------------------
276
276
277 _random_ports = set()
277 _random_ports = set()
278
278
279 def select_random_ports(n):
279 def select_random_ports(n):
280 """Selects and return n random ports that are available."""
280 """Selects and return n random ports that are available."""
281 ports = []
281 ports = []
282 for i in range(n):
282 for i in range(n):
283 sock = socket.socket()
283 sock = socket.socket()
284 sock.bind(('', 0))
284 sock.bind(('', 0))
285 while sock.getsockname()[1] in _random_ports:
285 while sock.getsockname()[1] in _random_ports:
286 sock.close()
286 sock.close()
287 sock = socket.socket()
287 sock = socket.socket()
288 sock.bind(('', 0))
288 sock.bind(('', 0))
289 ports.append(sock)
289 ports.append(sock)
290 for i, sock in enumerate(ports):
290 for i, sock in enumerate(ports):
291 port = sock.getsockname()[1]
291 port = sock.getsockname()[1]
292 sock.close()
292 sock.close()
293 ports[i] = port
293 ports[i] = port
294 _random_ports.add(port)
294 _random_ports.add(port)
295 return ports
295 return ports
296
296
297 def signal_children(children):
297 def signal_children(children):
298 """Relay interupt/term signals to children, for more solid process cleanup."""
298 """Relay interupt/term signals to children, for more solid process cleanup."""
299 def terminate_children(sig, frame):
299 def terminate_children(sig, frame):
300 log = get_logger()
300 log = get_logger()
301 log.critical("Got signal %i, terminating children..."%sig)
301 log.critical("Got signal %i, terminating children..."%sig)
302 for child in children:
302 for child in children:
303 child.terminate()
303 child.terminate()
304
304
305 sys.exit(sig != SIGINT)
305 sys.exit(sig != SIGINT)
306 # sys.exit(sig)
306 # sys.exit(sig)
307 for sig in (SIGINT, SIGABRT, SIGTERM):
307 for sig in (SIGINT, SIGABRT, SIGTERM):
308 signal(sig, terminate_children)
308 signal(sig, terminate_children)
309
309
310 def generate_exec_key(keyfile):
310 def generate_exec_key(keyfile):
311 import uuid
311 import uuid
312 newkey = str(uuid.uuid4())
312 newkey = str(uuid.uuid4())
313 with open(keyfile, 'w') as f:
313 with open(keyfile, 'w') as f:
314 # f.write('ipython-key ')
314 # f.write('ipython-key ')
315 f.write(newkey+'\n')
315 f.write(newkey+'\n')
316 # set user-only RW permissions (0600)
316 # set user-only RW permissions (0600)
317 # this will have no effect on Windows
317 # this will have no effect on Windows
318 os.chmod(keyfile, stat.S_IRUSR|stat.S_IWUSR)
318 os.chmod(keyfile, stat.S_IRUSR|stat.S_IWUSR)
319
319
320
320
321 def integer_loglevel(loglevel):
321 def integer_loglevel(loglevel):
322 try:
322 try:
323 loglevel = int(loglevel)
323 loglevel = int(loglevel)
324 except ValueError:
324 except ValueError:
325 if isinstance(loglevel, str):
325 if isinstance(loglevel, str):
326 loglevel = getattr(logging, loglevel)
326 loglevel = getattr(logging, loglevel)
327 return loglevel
327 return loglevel
328
328
329 def connect_logger(logname, context, iface, root="ip", loglevel=logging.DEBUG):
329 def connect_logger(logname, context, iface, root="ip", loglevel=logging.DEBUG):
330 logger = logging.getLogger(logname)
330 logger = logging.getLogger(logname)
331 if any([isinstance(h, handlers.PUBHandler) for h in logger.handlers]):
331 if any([isinstance(h, handlers.PUBHandler) for h in logger.handlers]):
332 # don't add a second PUBHandler
332 # don't add a second PUBHandler
333 return
333 return
334 loglevel = integer_loglevel(loglevel)
334 loglevel = integer_loglevel(loglevel)
335 lsock = context.socket(zmq.PUB)
335 lsock = context.socket(zmq.PUB)
336 lsock.connect(iface)
336 lsock.connect(iface)
337 handler = handlers.PUBHandler(lsock)
337 handler = handlers.PUBHandler(lsock)
338 handler.setLevel(loglevel)
338 handler.setLevel(loglevel)
339 handler.root_topic = root
339 handler.root_topic = root
340 logger.addHandler(handler)
340 logger.addHandler(handler)
341 logger.setLevel(loglevel)
341 logger.setLevel(loglevel)
342
342
343 def connect_engine_logger(context, iface, engine, loglevel=logging.DEBUG):
343 def connect_engine_logger(context, iface, engine, loglevel=logging.DEBUG):
344 logger = logging.getLogger()
344 logger = logging.getLogger()
345 if any([isinstance(h, handlers.PUBHandler) for h in logger.handlers]):
345 if any([isinstance(h, handlers.PUBHandler) for h in logger.handlers]):
346 # don't add a second PUBHandler
346 # don't add a second PUBHandler
347 return
347 return
348 loglevel = integer_loglevel(loglevel)
348 loglevel = integer_loglevel(loglevel)
349 lsock = context.socket(zmq.PUB)
349 lsock = context.socket(zmq.PUB)
350 lsock.connect(iface)
350 lsock.connect(iface)
351 handler = EnginePUBHandler(engine, lsock)
351 handler = EnginePUBHandler(engine, lsock)
352 handler.setLevel(loglevel)
352 handler.setLevel(loglevel)
353 logger.addHandler(handler)
353 logger.addHandler(handler)
354 logger.setLevel(loglevel)
354 logger.setLevel(loglevel)
355 return logger
355 return logger
356
356
357 def local_logger(logname, loglevel=logging.DEBUG):
357 def local_logger(logname, loglevel=logging.DEBUG):
358 loglevel = integer_loglevel(loglevel)
358 loglevel = integer_loglevel(loglevel)
359 logger = logging.getLogger(logname)
359 logger = logging.getLogger(logname)
360 if any([isinstance(h, logging.StreamHandler) for h in logger.handlers]):
360 if any([isinstance(h, logging.StreamHandler) for h in logger.handlers]):
361 # don't add a second StreamHandler
361 # don't add a second StreamHandler
362 return
362 return
363 handler = logging.StreamHandler()
363 handler = logging.StreamHandler()
364 handler.setLevel(loglevel)
364 handler.setLevel(loglevel)
365 formatter = logging.Formatter("%(asctime)s.%(msecs).03d [%(name)s] %(message)s",
365 formatter = logging.Formatter("%(asctime)s.%(msecs).03d [%(name)s] %(message)s",
366 datefmt="%Y-%m-%d %H:%M:%S")
366 datefmt="%Y-%m-%d %H:%M:%S")
367 handler.setFormatter(formatter)
367 handler.setFormatter(formatter)
368
368
369 logger.addHandler(handler)
369 logger.addHandler(handler)
370 logger.setLevel(loglevel)
370 logger.setLevel(loglevel)
371 return logger
371 return logger
372
372
373 def set_hwm(sock, hwm=0):
373 def set_hwm(sock, hwm=0):
374 """set zmq High Water Mark on a socket
374 """set zmq High Water Mark on a socket
375
375
376 in a way that always works for various pyzmq / libzmq versions.
376 in a way that always works for various pyzmq / libzmq versions.
377 """
377 """
378 import zmq
378 import zmq
379
379
380 for key in ('HWM', 'SNDHWM', 'RCVHWM'):
380 for key in ('HWM', 'SNDHWM', 'RCVHWM'):
381 opt = getattr(zmq, key, None)
381 opt = getattr(zmq, key, None)
382 if opt is None:
382 if opt is None:
383 continue
383 continue
384 try:
384 try:
385 sock.setsockopt(opt, hwm)
385 sock.setsockopt(opt, hwm)
386 except zmq.ZMQError:
386 except zmq.ZMQError:
387 pass
387 pass
388
388
389
389
@@ -1,400 +1,400 b''
1 # -*- coding: utf-8 -*-
1 # -*- coding: utf-8 -*-
2 """Decorators for labeling test objects.
2 """Decorators for labeling test objects.
3
3
4 Decorators that merely return a modified version of the original function
4 Decorators that merely return a modified version of the original function
5 object are straightforward. Decorators that return a new function object need
5 object are straightforward. Decorators that return a new function object need
6 to use nose.tools.make_decorator(original_function)(decorator) in returning the
6 to use nose.tools.make_decorator(original_function)(decorator) in returning the
7 decorator, in order to preserve metadata such as function name, setup and
7 decorator, in order to preserve metadata such as function name, setup and
8 teardown functions and so on - see nose.tools for more information.
8 teardown functions and so on - see nose.tools for more information.
9
9
10 This module provides a set of useful decorators meant to be ready to use in
10 This module provides a set of useful decorators meant to be ready to use in
11 your own tests. See the bottom of the file for the ready-made ones, and if you
11 your own tests. See the bottom of the file for the ready-made ones, and if you
12 find yourself writing a new one that may be of generic use, add it here.
12 find yourself writing a new one that may be of generic use, add it here.
13
13
14 Included decorators:
14 Included decorators:
15
15
16
16
17 Lightweight testing that remains unittest-compatible.
17 Lightweight testing that remains unittest-compatible.
18
18
19 - An @as_unittest decorator can be used to tag any normal parameter-less
19 - An @as_unittest decorator can be used to tag any normal parameter-less
20 function as a unittest TestCase. Then, both nose and normal unittest will
20 function as a unittest TestCase. Then, both nose and normal unittest will
21 recognize it as such. This will make it easier to migrate away from Nose if
21 recognize it as such. This will make it easier to migrate away from Nose if
22 we ever need/want to while maintaining very lightweight tests.
22 we ever need/want to while maintaining very lightweight tests.
23
23
24 NOTE: This file contains IPython-specific decorators. Using the machinery in
24 NOTE: This file contains IPython-specific decorators. Using the machinery in
25 IPython.external.decorators, we import either numpy.testing.decorators if numpy is
25 IPython.external.decorators, we import either numpy.testing.decorators if numpy is
26 available, OR use equivalent code in IPython.external._decorators, which
26 available, OR use equivalent code in IPython.external._decorators, which
27 we've copied verbatim from numpy.
27 we've copied verbatim from numpy.
28
28
29 Authors
29 Authors
30 -------
30 -------
31
31
32 - Fernando Perez <Fernando.Perez@berkeley.edu>
32 - Fernando Perez <Fernando.Perez@berkeley.edu>
33 """
33 """
34
34
35 #-----------------------------------------------------------------------------
35 #-----------------------------------------------------------------------------
36 # Copyright (C) 2009-2011 The IPython Development Team
36 # Copyright (C) 2009-2011 The IPython Development Team
37 #
37 #
38 # Distributed under the terms of the BSD License. The full license is in
38 # Distributed under the terms of the BSD License. The full license is in
39 # the file COPYING, distributed as part of this software.
39 # the file COPYING, distributed as part of this software.
40 #-----------------------------------------------------------------------------
40 #-----------------------------------------------------------------------------
41
41
42 #-----------------------------------------------------------------------------
42 #-----------------------------------------------------------------------------
43 # Imports
43 # Imports
44 #-----------------------------------------------------------------------------
44 #-----------------------------------------------------------------------------
45
45
46 # Stdlib imports
46 # Stdlib imports
47 import sys
47 import sys
48 import os
48 import os
49 import tempfile
49 import tempfile
50 import unittest
50 import unittest
51
51
52 # Third-party imports
52 # Third-party imports
53
53
54 # This is Michele Simionato's decorator module, kept verbatim.
54 # This is Michele Simionato's decorator module, kept verbatim.
55 from IPython.external.decorator import decorator
55 from decorator import decorator
56
56
57 # Expose the unittest-driven decorators
57 # Expose the unittest-driven decorators
58 from .ipunittest import ipdoctest, ipdocstring
58 from .ipunittest import ipdoctest, ipdocstring
59
59
60 # Grab the numpy-specific decorators which we keep in a file that we
60 # Grab the numpy-specific decorators which we keep in a file that we
61 # occasionally update from upstream: decorators.py is a copy of
61 # occasionally update from upstream: decorators.py is a copy of
62 # numpy.testing.decorators, we expose all of it here.
62 # numpy.testing.decorators, we expose all of it here.
63 from IPython.external.decorators import *
63 from IPython.external.decorators import *
64
64
65 # For onlyif_cmd_exists decorator
65 # For onlyif_cmd_exists decorator
66 from IPython.utils.process import is_cmd_found
66 from IPython.utils.process import is_cmd_found
67 from IPython.utils.py3compat import string_types
67 from IPython.utils.py3compat import string_types
68
68
69 #-----------------------------------------------------------------------------
69 #-----------------------------------------------------------------------------
70 # Classes and functions
70 # Classes and functions
71 #-----------------------------------------------------------------------------
71 #-----------------------------------------------------------------------------
72
72
73 # Simple example of the basic idea
73 # Simple example of the basic idea
74 def as_unittest(func):
74 def as_unittest(func):
75 """Decorator to make a simple function into a normal test via unittest."""
75 """Decorator to make a simple function into a normal test via unittest."""
76 class Tester(unittest.TestCase):
76 class Tester(unittest.TestCase):
77 def test(self):
77 def test(self):
78 func()
78 func()
79
79
80 Tester.__name__ = func.__name__
80 Tester.__name__ = func.__name__
81
81
82 return Tester
82 return Tester
83
83
84 # Utility functions
84 # Utility functions
85
85
86 def apply_wrapper(wrapper,func):
86 def apply_wrapper(wrapper,func):
87 """Apply a wrapper to a function for decoration.
87 """Apply a wrapper to a function for decoration.
88
88
89 This mixes Michele Simionato's decorator tool with nose's make_decorator,
89 This mixes Michele Simionato's decorator tool with nose's make_decorator,
90 to apply a wrapper in a decorator so that all nose attributes, as well as
90 to apply a wrapper in a decorator so that all nose attributes, as well as
91 function signature and other properties, survive the decoration cleanly.
91 function signature and other properties, survive the decoration cleanly.
92 This will ensure that wrapped functions can still be well introspected via
92 This will ensure that wrapped functions can still be well introspected via
93 IPython, for example.
93 IPython, for example.
94 """
94 """
95 import nose.tools
95 import nose.tools
96
96
97 return decorator(wrapper,nose.tools.make_decorator(func)(wrapper))
97 return decorator(wrapper,nose.tools.make_decorator(func)(wrapper))
98
98
99
99
100 def make_label_dec(label,ds=None):
100 def make_label_dec(label,ds=None):
101 """Factory function to create a decorator that applies one or more labels.
101 """Factory function to create a decorator that applies one or more labels.
102
102
103 Parameters
103 Parameters
104 ----------
104 ----------
105 label : string or sequence
105 label : string or sequence
106 One or more labels that will be applied by the decorator to the functions
106 One or more labels that will be applied by the decorator to the functions
107 it decorates. Labels are attributes of the decorated function with their
107 it decorates. Labels are attributes of the decorated function with their
108 value set to True.
108 value set to True.
109
109
110 ds : string
110 ds : string
111 An optional docstring for the resulting decorator. If not given, a
111 An optional docstring for the resulting decorator. If not given, a
112 default docstring is auto-generated.
112 default docstring is auto-generated.
113
113
114 Returns
114 Returns
115 -------
115 -------
116 A decorator.
116 A decorator.
117
117
118 Examples
118 Examples
119 --------
119 --------
120
120
121 A simple labeling decorator:
121 A simple labeling decorator:
122
122
123 >>> slow = make_label_dec('slow')
123 >>> slow = make_label_dec('slow')
124 >>> slow.__doc__
124 >>> slow.__doc__
125 "Labels a test as 'slow'."
125 "Labels a test as 'slow'."
126
126
127 And one that uses multiple labels and a custom docstring:
127 And one that uses multiple labels and a custom docstring:
128
128
129 >>> rare = make_label_dec(['slow','hard'],
129 >>> rare = make_label_dec(['slow','hard'],
130 ... "Mix labels 'slow' and 'hard' for rare tests.")
130 ... "Mix labels 'slow' and 'hard' for rare tests.")
131 >>> rare.__doc__
131 >>> rare.__doc__
132 "Mix labels 'slow' and 'hard' for rare tests."
132 "Mix labels 'slow' and 'hard' for rare tests."
133
133
134 Now, let's test using this one:
134 Now, let's test using this one:
135 >>> @rare
135 >>> @rare
136 ... def f(): pass
136 ... def f(): pass
137 ...
137 ...
138 >>>
138 >>>
139 >>> f.slow
139 >>> f.slow
140 True
140 True
141 >>> f.hard
141 >>> f.hard
142 True
142 True
143 """
143 """
144
144
145 if isinstance(label, string_types):
145 if isinstance(label, string_types):
146 labels = [label]
146 labels = [label]
147 else:
147 else:
148 labels = label
148 labels = label
149
149
150 # Validate that the given label(s) are OK for use in setattr() by doing a
150 # Validate that the given label(s) are OK for use in setattr() by doing a
151 # dry run on a dummy function.
151 # dry run on a dummy function.
152 tmp = lambda : None
152 tmp = lambda : None
153 for label in labels:
153 for label in labels:
154 setattr(tmp,label,True)
154 setattr(tmp,label,True)
155
155
156 # This is the actual decorator we'll return
156 # This is the actual decorator we'll return
157 def decor(f):
157 def decor(f):
158 for label in labels:
158 for label in labels:
159 setattr(f,label,True)
159 setattr(f,label,True)
160 return f
160 return f
161
161
162 # Apply the user's docstring, or autogenerate a basic one
162 # Apply the user's docstring, or autogenerate a basic one
163 if ds is None:
163 if ds is None:
164 ds = "Labels a test as %r." % label
164 ds = "Labels a test as %r." % label
165 decor.__doc__ = ds
165 decor.__doc__ = ds
166
166
167 return decor
167 return decor
168
168
169
169
170 # Inspired by numpy's skipif, but uses the full apply_wrapper utility to
170 # Inspired by numpy's skipif, but uses the full apply_wrapper utility to
171 # preserve function metadata better and allows the skip condition to be a
171 # preserve function metadata better and allows the skip condition to be a
172 # callable.
172 # callable.
173 def skipif(skip_condition, msg=None):
173 def skipif(skip_condition, msg=None):
174 ''' Make function raise SkipTest exception if skip_condition is true
174 ''' Make function raise SkipTest exception if skip_condition is true
175
175
176 Parameters
176 Parameters
177 ----------
177 ----------
178
178
179 skip_condition : bool or callable
179 skip_condition : bool or callable
180 Flag to determine whether to skip test. If the condition is a
180 Flag to determine whether to skip test. If the condition is a
181 callable, it is used at runtime to dynamically make the decision. This
181 callable, it is used at runtime to dynamically make the decision. This
182 is useful for tests that may require costly imports, to delay the cost
182 is useful for tests that may require costly imports, to delay the cost
183 until the test suite is actually executed.
183 until the test suite is actually executed.
184 msg : string
184 msg : string
185 Message to give on raising a SkipTest exception.
185 Message to give on raising a SkipTest exception.
186
186
187 Returns
187 Returns
188 -------
188 -------
189 decorator : function
189 decorator : function
190 Decorator, which, when applied to a function, causes SkipTest
190 Decorator, which, when applied to a function, causes SkipTest
191 to be raised when the skip_condition was True, and the function
191 to be raised when the skip_condition was True, and the function
192 to be called normally otherwise.
192 to be called normally otherwise.
193
193
194 Notes
194 Notes
195 -----
195 -----
196 You will see from the code that we had to further decorate the
196 You will see from the code that we had to further decorate the
197 decorator with the nose.tools.make_decorator function in order to
197 decorator with the nose.tools.make_decorator function in order to
198 transmit function name, and various other metadata.
198 transmit function name, and various other metadata.
199 '''
199 '''
200
200
201 def skip_decorator(f):
201 def skip_decorator(f):
202 # Local import to avoid a hard nose dependency and only incur the
202 # Local import to avoid a hard nose dependency and only incur the
203 # import time overhead at actual test-time.
203 # import time overhead at actual test-time.
204 import nose
204 import nose
205
205
206 # Allow for both boolean or callable skip conditions.
206 # Allow for both boolean or callable skip conditions.
207 if callable(skip_condition):
207 if callable(skip_condition):
208 skip_val = skip_condition
208 skip_val = skip_condition
209 else:
209 else:
210 skip_val = lambda : skip_condition
210 skip_val = lambda : skip_condition
211
211
212 def get_msg(func,msg=None):
212 def get_msg(func,msg=None):
213 """Skip message with information about function being skipped."""
213 """Skip message with information about function being skipped."""
214 if msg is None: out = 'Test skipped due to test condition.'
214 if msg is None: out = 'Test skipped due to test condition.'
215 else: out = msg
215 else: out = msg
216 return "Skipping test: %s. %s" % (func.__name__,out)
216 return "Skipping test: %s. %s" % (func.__name__,out)
217
217
218 # We need to define *two* skippers because Python doesn't allow both
218 # We need to define *two* skippers because Python doesn't allow both
219 # return with value and yield inside the same function.
219 # return with value and yield inside the same function.
220 def skipper_func(*args, **kwargs):
220 def skipper_func(*args, **kwargs):
221 """Skipper for normal test functions."""
221 """Skipper for normal test functions."""
222 if skip_val():
222 if skip_val():
223 raise nose.SkipTest(get_msg(f,msg))
223 raise nose.SkipTest(get_msg(f,msg))
224 else:
224 else:
225 return f(*args, **kwargs)
225 return f(*args, **kwargs)
226
226
227 def skipper_gen(*args, **kwargs):
227 def skipper_gen(*args, **kwargs):
228 """Skipper for test generators."""
228 """Skipper for test generators."""
229 if skip_val():
229 if skip_val():
230 raise nose.SkipTest(get_msg(f,msg))
230 raise nose.SkipTest(get_msg(f,msg))
231 else:
231 else:
232 for x in f(*args, **kwargs):
232 for x in f(*args, **kwargs):
233 yield x
233 yield x
234
234
235 # Choose the right skipper to use when building the actual generator.
235 # Choose the right skipper to use when building the actual generator.
236 if nose.util.isgenerator(f):
236 if nose.util.isgenerator(f):
237 skipper = skipper_gen
237 skipper = skipper_gen
238 else:
238 else:
239 skipper = skipper_func
239 skipper = skipper_func
240
240
241 return nose.tools.make_decorator(f)(skipper)
241 return nose.tools.make_decorator(f)(skipper)
242
242
243 return skip_decorator
243 return skip_decorator
244
244
245 # A version with the condition set to true, common case just to attach a message
245 # A version with the condition set to true, common case just to attach a message
246 # to a skip decorator
246 # to a skip decorator
247 def skip(msg=None):
247 def skip(msg=None):
248 """Decorator factory - mark a test function for skipping from test suite.
248 """Decorator factory - mark a test function for skipping from test suite.
249
249
250 Parameters
250 Parameters
251 ----------
251 ----------
252 msg : string
252 msg : string
253 Optional message to be added.
253 Optional message to be added.
254
254
255 Returns
255 Returns
256 -------
256 -------
257 decorator : function
257 decorator : function
258 Decorator, which, when applied to a function, causes SkipTest
258 Decorator, which, when applied to a function, causes SkipTest
259 to be raised, with the optional message added.
259 to be raised, with the optional message added.
260 """
260 """
261
261
262 return skipif(True,msg)
262 return skipif(True,msg)
263
263
264
264
265 def onlyif(condition, msg):
265 def onlyif(condition, msg):
266 """The reverse from skipif, see skipif for details."""
266 """The reverse from skipif, see skipif for details."""
267
267
268 if callable(condition):
268 if callable(condition):
269 skip_condition = lambda : not condition()
269 skip_condition = lambda : not condition()
270 else:
270 else:
271 skip_condition = lambda : not condition
271 skip_condition = lambda : not condition
272
272
273 return skipif(skip_condition, msg)
273 return skipif(skip_condition, msg)
274
274
275 #-----------------------------------------------------------------------------
275 #-----------------------------------------------------------------------------
276 # Utility functions for decorators
276 # Utility functions for decorators
277 def module_not_available(module):
277 def module_not_available(module):
278 """Can module be imported? Returns true if module does NOT import.
278 """Can module be imported? Returns true if module does NOT import.
279
279
280 This is used to make a decorator to skip tests that require module to be
280 This is used to make a decorator to skip tests that require module to be
281 available, but delay the 'import numpy' to test execution time.
281 available, but delay the 'import numpy' to test execution time.
282 """
282 """
283 try:
283 try:
284 mod = __import__(module)
284 mod = __import__(module)
285 mod_not_avail = False
285 mod_not_avail = False
286 except ImportError:
286 except ImportError:
287 mod_not_avail = True
287 mod_not_avail = True
288
288
289 return mod_not_avail
289 return mod_not_avail
290
290
291
291
292 def decorated_dummy(dec, name):
292 def decorated_dummy(dec, name):
293 """Return a dummy function decorated with dec, with the given name.
293 """Return a dummy function decorated with dec, with the given name.
294
294
295 Examples
295 Examples
296 --------
296 --------
297 import IPython.testing.decorators as dec
297 import IPython.testing.decorators as dec
298 setup = dec.decorated_dummy(dec.skip_if_no_x11, __name__)
298 setup = dec.decorated_dummy(dec.skip_if_no_x11, __name__)
299 """
299 """
300 dummy = lambda: None
300 dummy = lambda: None
301 dummy.__name__ = name
301 dummy.__name__ = name
302 return dec(dummy)
302 return dec(dummy)
303
303
304 #-----------------------------------------------------------------------------
304 #-----------------------------------------------------------------------------
305 # Decorators for public use
305 # Decorators for public use
306
306
307 # Decorators to skip certain tests on specific platforms.
307 # Decorators to skip certain tests on specific platforms.
308 skip_win32 = skipif(sys.platform == 'win32',
308 skip_win32 = skipif(sys.platform == 'win32',
309 "This test does not run under Windows")
309 "This test does not run under Windows")
310 skip_linux = skipif(sys.platform.startswith('linux'),
310 skip_linux = skipif(sys.platform.startswith('linux'),
311 "This test does not run under Linux")
311 "This test does not run under Linux")
312 skip_osx = skipif(sys.platform == 'darwin',"This test does not run under OS X")
312 skip_osx = skipif(sys.platform == 'darwin',"This test does not run under OS X")
313
313
314
314
315 # Decorators to skip tests if not on specific platforms.
315 # Decorators to skip tests if not on specific platforms.
316 skip_if_not_win32 = skipif(sys.platform != 'win32',
316 skip_if_not_win32 = skipif(sys.platform != 'win32',
317 "This test only runs under Windows")
317 "This test only runs under Windows")
318 skip_if_not_linux = skipif(not sys.platform.startswith('linux'),
318 skip_if_not_linux = skipif(not sys.platform.startswith('linux'),
319 "This test only runs under Linux")
319 "This test only runs under Linux")
320 skip_if_not_osx = skipif(sys.platform != 'darwin',
320 skip_if_not_osx = skipif(sys.platform != 'darwin',
321 "This test only runs under OSX")
321 "This test only runs under OSX")
322
322
323
323
324 _x11_skip_cond = (sys.platform not in ('darwin', 'win32') and
324 _x11_skip_cond = (sys.platform not in ('darwin', 'win32') and
325 os.environ.get('DISPLAY', '') == '')
325 os.environ.get('DISPLAY', '') == '')
326 _x11_skip_msg = "Skipped under *nix when X11/XOrg not available"
326 _x11_skip_msg = "Skipped under *nix when X11/XOrg not available"
327
327
328 skip_if_no_x11 = skipif(_x11_skip_cond, _x11_skip_msg)
328 skip_if_no_x11 = skipif(_x11_skip_cond, _x11_skip_msg)
329
329
330 # not a decorator itself, returns a dummy function to be used as setup
330 # not a decorator itself, returns a dummy function to be used as setup
331 def skip_file_no_x11(name):
331 def skip_file_no_x11(name):
332 return decorated_dummy(skip_if_no_x11, name) if _x11_skip_cond else None
332 return decorated_dummy(skip_if_no_x11, name) if _x11_skip_cond else None
333
333
334 # Other skip decorators
334 # Other skip decorators
335
335
336 # generic skip without module
336 # generic skip without module
337 skip_without = lambda mod: skipif(module_not_available(mod), "This test requires %s" % mod)
337 skip_without = lambda mod: skipif(module_not_available(mod), "This test requires %s" % mod)
338
338
339 skipif_not_numpy = skip_without('numpy')
339 skipif_not_numpy = skip_without('numpy')
340
340
341 skipif_not_matplotlib = skip_without('matplotlib')
341 skipif_not_matplotlib = skip_without('matplotlib')
342
342
343 skipif_not_sympy = skip_without('sympy')
343 skipif_not_sympy = skip_without('sympy')
344
344
345 skip_known_failure = knownfailureif(True,'This test is known to fail')
345 skip_known_failure = knownfailureif(True,'This test is known to fail')
346
346
347 known_failure_py3 = knownfailureif(sys.version_info[0] >= 3,
347 known_failure_py3 = knownfailureif(sys.version_info[0] >= 3,
348 'This test is known to fail on Python 3.')
348 'This test is known to fail on Python 3.')
349
349
350 # A null 'decorator', useful to make more readable code that needs to pick
350 # A null 'decorator', useful to make more readable code that needs to pick
351 # between different decorators based on OS or other conditions
351 # between different decorators based on OS or other conditions
352 null_deco = lambda f: f
352 null_deco = lambda f: f
353
353
354 # Some tests only run where we can use unicode paths. Note that we can't just
354 # Some tests only run where we can use unicode paths. Note that we can't just
355 # check os.path.supports_unicode_filenames, which is always False on Linux.
355 # check os.path.supports_unicode_filenames, which is always False on Linux.
356 try:
356 try:
357 f = tempfile.NamedTemporaryFile(prefix=u"tmp€")
357 f = tempfile.NamedTemporaryFile(prefix=u"tmp€")
358 except UnicodeEncodeError:
358 except UnicodeEncodeError:
359 unicode_paths = False
359 unicode_paths = False
360 else:
360 else:
361 unicode_paths = True
361 unicode_paths = True
362 f.close()
362 f.close()
363
363
364 onlyif_unicode_paths = onlyif(unicode_paths, ("This test is only applicable "
364 onlyif_unicode_paths = onlyif(unicode_paths, ("This test is only applicable "
365 "where we can use unicode in filenames."))
365 "where we can use unicode in filenames."))
366
366
367
367
368 def onlyif_cmds_exist(*commands):
368 def onlyif_cmds_exist(*commands):
369 """
369 """
370 Decorator to skip test when at least one of `commands` is not found.
370 Decorator to skip test when at least one of `commands` is not found.
371 """
371 """
372 for cmd in commands:
372 for cmd in commands:
373 try:
373 try:
374 if not is_cmd_found(cmd):
374 if not is_cmd_found(cmd):
375 return skip("This test runs only if command '{0}' "
375 return skip("This test runs only if command '{0}' "
376 "is installed".format(cmd))
376 "is installed".format(cmd))
377 except ImportError as e:
377 except ImportError as e:
378 # is_cmd_found uses pywin32 on windows, which might not be available
378 # is_cmd_found uses pywin32 on windows, which might not be available
379 if sys.platform == 'win32' and 'pywin32' in str(e):
379 if sys.platform == 'win32' and 'pywin32' in str(e):
380 return skip("This test runs only if pywin32 and command '{0}' "
380 return skip("This test runs only if pywin32 and command '{0}' "
381 "is installed".format(cmd))
381 "is installed".format(cmd))
382 raise e
382 raise e
383 return null_deco
383 return null_deco
384
384
385 def onlyif_any_cmd_exists(*commands):
385 def onlyif_any_cmd_exists(*commands):
386 """
386 """
387 Decorator to skip test unless at least one of `commands` is found.
387 Decorator to skip test unless at least one of `commands` is found.
388 """
388 """
389 for cmd in commands:
389 for cmd in commands:
390 try:
390 try:
391 if is_cmd_found(cmd):
391 if is_cmd_found(cmd):
392 return null_deco
392 return null_deco
393 except ImportError as e:
393 except ImportError as e:
394 # is_cmd_found uses pywin32 on windows, which might not be available
394 # is_cmd_found uses pywin32 on windows, which might not be available
395 if sys.platform == 'win32' and 'pywin32' in str(e):
395 if sys.platform == 'win32' and 'pywin32' in str(e):
396 return skip("This test runs only if pywin32 and commands '{0}' "
396 return skip("This test runs only if pywin32 and commands '{0}' "
397 "are installed".format(commands))
397 "are installed".format(commands))
398 raise e
398 raise e
399 return skip("This test runs only if one of the commands {0} "
399 return skip("This test runs only if one of the commands {0} "
400 "is installed".format(commands))
400 "is installed".format(commands))
@@ -1,344 +1,345 b''
1 #!/usr/bin/env python
1 #!/usr/bin/env python
2 # -*- coding: utf-8 -*-
2 # -*- coding: utf-8 -*-
3 """Setup script for IPython.
3 """Setup script for IPython.
4
4
5 Under Posix environments it works like a typical setup.py script.
5 Under Posix environments it works like a typical setup.py script.
6 Under Windows, the command sdist is not supported, since IPython
6 Under Windows, the command sdist is not supported, since IPython
7 requires utilities which are not available under Windows."""
7 requires utilities which are not available under Windows."""
8
8
9 #-----------------------------------------------------------------------------
9 #-----------------------------------------------------------------------------
10 # Copyright (c) 2008-2011, IPython Development Team.
10 # Copyright (c) 2008-2011, IPython Development Team.
11 # Copyright (c) 2001-2007, Fernando Perez <fernando.perez@colorado.edu>
11 # Copyright (c) 2001-2007, Fernando Perez <fernando.perez@colorado.edu>
12 # Copyright (c) 2001, Janko Hauser <jhauser@zscout.de>
12 # Copyright (c) 2001, Janko Hauser <jhauser@zscout.de>
13 # Copyright (c) 2001, Nathaniel Gray <n8gray@caltech.edu>
13 # Copyright (c) 2001, Nathaniel Gray <n8gray@caltech.edu>
14 #
14 #
15 # Distributed under the terms of the Modified BSD License.
15 # Distributed under the terms of the Modified BSD License.
16 #
16 #
17 # The full license is in the file COPYING.rst, distributed with this software.
17 # The full license is in the file COPYING.rst, distributed with this software.
18 #-----------------------------------------------------------------------------
18 #-----------------------------------------------------------------------------
19
19
20 #-----------------------------------------------------------------------------
20 #-----------------------------------------------------------------------------
21 # Minimal Python version sanity check
21 # Minimal Python version sanity check
22 #-----------------------------------------------------------------------------
22 #-----------------------------------------------------------------------------
23 from __future__ import print_function
23 from __future__ import print_function
24
24
25 import sys
25 import sys
26
26
27 # This check is also made in IPython/__init__, don't forget to update both when
27 # This check is also made in IPython/__init__, don't forget to update both when
28 # changing Python version requirements.
28 # changing Python version requirements.
29 v = sys.version_info
29 v = sys.version_info
30 if v[:2] < (2,7) or (v[0] >= 3 and v[:2] < (3,3)):
30 if v[:2] < (2,7) or (v[0] >= 3 and v[:2] < (3,3)):
31 error = "ERROR: IPython requires Python version 2.7 or 3.3 or above."
31 error = "ERROR: IPython requires Python version 2.7 or 3.3 or above."
32 print(error, file=sys.stderr)
32 print(error, file=sys.stderr)
33 sys.exit(1)
33 sys.exit(1)
34
34
35 PY3 = (sys.version_info[0] >= 3)
35 PY3 = (sys.version_info[0] >= 3)
36
36
37 # At least we're on the python version we need, move on.
37 # At least we're on the python version we need, move on.
38
38
39 #-------------------------------------------------------------------------------
39 #-------------------------------------------------------------------------------
40 # Imports
40 # Imports
41 #-------------------------------------------------------------------------------
41 #-------------------------------------------------------------------------------
42
42
43 # Stdlib imports
43 # Stdlib imports
44 import os
44 import os
45 import shutil
45 import shutil
46
46
47 from glob import glob
47 from glob import glob
48
48
49 # BEFORE importing distutils, remove MANIFEST. distutils doesn't properly
49 # BEFORE importing distutils, remove MANIFEST. distutils doesn't properly
50 # update it when the contents of directories change.
50 # update it when the contents of directories change.
51 if os.path.exists('MANIFEST'): os.remove('MANIFEST')
51 if os.path.exists('MANIFEST'): os.remove('MANIFEST')
52
52
53 from distutils.core import setup
53 from distutils.core import setup
54
54
55 # Our own imports
55 # Our own imports
56 from setupbase import target_update
56 from setupbase import target_update
57
57
58 from setupbase import (
58 from setupbase import (
59 setup_args,
59 setup_args,
60 find_packages,
60 find_packages,
61 find_package_data,
61 find_package_data,
62 check_package_data_first,
62 check_package_data_first,
63 find_entry_points,
63 find_entry_points,
64 build_scripts_entrypt,
64 build_scripts_entrypt,
65 find_data_files,
65 find_data_files,
66 check_for_dependencies,
66 check_for_dependencies,
67 git_prebuild,
67 git_prebuild,
68 check_submodule_status,
68 check_submodule_status,
69 update_submodules,
69 update_submodules,
70 require_submodules,
70 require_submodules,
71 UpdateSubmodules,
71 UpdateSubmodules,
72 get_bdist_wheel,
72 get_bdist_wheel,
73 CompileCSS,
73 CompileCSS,
74 JavascriptVersion,
74 JavascriptVersion,
75 css_js_prerelease,
75 css_js_prerelease,
76 install_symlinked,
76 install_symlinked,
77 install_lib_symlink,
77 install_lib_symlink,
78 install_scripts_for_symlink,
78 install_scripts_for_symlink,
79 unsymlink,
79 unsymlink,
80 )
80 )
81 from setupext import setupext
81 from setupext import setupext
82
82
83 isfile = os.path.isfile
83 isfile = os.path.isfile
84 pjoin = os.path.join
84 pjoin = os.path.join
85
85
86 #-------------------------------------------------------------------------------
86 #-------------------------------------------------------------------------------
87 # Handle OS specific things
87 # Handle OS specific things
88 #-------------------------------------------------------------------------------
88 #-------------------------------------------------------------------------------
89
89
90 if os.name in ('nt','dos'):
90 if os.name in ('nt','dos'):
91 os_name = 'windows'
91 os_name = 'windows'
92 else:
92 else:
93 os_name = os.name
93 os_name = os.name
94
94
95 # Under Windows, 'sdist' has not been supported. Now that the docs build with
95 # Under Windows, 'sdist' has not been supported. Now that the docs build with
96 # Sphinx it might work, but let's not turn it on until someone confirms that it
96 # Sphinx it might work, but let's not turn it on until someone confirms that it
97 # actually works.
97 # actually works.
98 if os_name == 'windows' and 'sdist' in sys.argv:
98 if os_name == 'windows' and 'sdist' in sys.argv:
99 print('The sdist command is not available under Windows. Exiting.')
99 print('The sdist command is not available under Windows. Exiting.')
100 sys.exit(1)
100 sys.exit(1)
101
101
102 #-------------------------------------------------------------------------------
102 #-------------------------------------------------------------------------------
103 # Make sure we aren't trying to run without submodules
103 # Make sure we aren't trying to run without submodules
104 #-------------------------------------------------------------------------------
104 #-------------------------------------------------------------------------------
105 here = os.path.abspath(os.path.dirname(__file__))
105 here = os.path.abspath(os.path.dirname(__file__))
106
106
107 def require_clean_submodules():
107 def require_clean_submodules():
108 """Check on git submodules before distutils can do anything
108 """Check on git submodules before distutils can do anything
109
109
110 Since distutils cannot be trusted to update the tree
110 Since distutils cannot be trusted to update the tree
111 after everything has been set in motion,
111 after everything has been set in motion,
112 this is not a distutils command.
112 this is not a distutils command.
113 """
113 """
114 # PACKAGERS: Add a return here to skip checks for git submodules
114 # PACKAGERS: Add a return here to skip checks for git submodules
115
115
116 # don't do anything if nothing is actually supposed to happen
116 # don't do anything if nothing is actually supposed to happen
117 for do_nothing in ('-h', '--help', '--help-commands', 'clean', 'submodule'):
117 for do_nothing in ('-h', '--help', '--help-commands', 'clean', 'submodule'):
118 if do_nothing in sys.argv:
118 if do_nothing in sys.argv:
119 return
119 return
120
120
121 status = check_submodule_status(here)
121 status = check_submodule_status(here)
122
122
123 if status == "missing":
123 if status == "missing":
124 print("checking out submodules for the first time")
124 print("checking out submodules for the first time")
125 update_submodules(here)
125 update_submodules(here)
126 elif status == "unclean":
126 elif status == "unclean":
127 print('\n'.join([
127 print('\n'.join([
128 "Cannot build / install IPython with unclean submodules",
128 "Cannot build / install IPython with unclean submodules",
129 "Please update submodules with",
129 "Please update submodules with",
130 " python setup.py submodule",
130 " python setup.py submodule",
131 "or",
131 "or",
132 " git submodule update",
132 " git submodule update",
133 "or commit any submodule changes you have made."
133 "or commit any submodule changes you have made."
134 ]))
134 ]))
135 sys.exit(1)
135 sys.exit(1)
136
136
137 require_clean_submodules()
137 require_clean_submodules()
138
138
139 #-------------------------------------------------------------------------------
139 #-------------------------------------------------------------------------------
140 # Things related to the IPython documentation
140 # Things related to the IPython documentation
141 #-------------------------------------------------------------------------------
141 #-------------------------------------------------------------------------------
142
142
143 # update the manuals when building a source dist
143 # update the manuals when building a source dist
144 if len(sys.argv) >= 2 and sys.argv[1] in ('sdist','bdist_rpm'):
144 if len(sys.argv) >= 2 and sys.argv[1] in ('sdist','bdist_rpm'):
145
145
146 # List of things to be updated. Each entry is a triplet of args for
146 # List of things to be updated. Each entry is a triplet of args for
147 # target_update()
147 # target_update()
148 to_update = [
148 to_update = [
149 # FIXME - Disabled for now: we need to redo an automatic way
149 # FIXME - Disabled for now: we need to redo an automatic way
150 # of generating the magic info inside the rst.
150 # of generating the magic info inside the rst.
151 #('docs/magic.tex',
151 #('docs/magic.tex',
152 #['IPython/Magic.py'],
152 #['IPython/Magic.py'],
153 #"cd doc && ./update_magic.sh" ),
153 #"cd doc && ./update_magic.sh" ),
154
154
155 ('docs/man/ipcluster.1.gz',
155 ('docs/man/ipcluster.1.gz',
156 ['docs/man/ipcluster.1'],
156 ['docs/man/ipcluster.1'],
157 'cd docs/man && gzip -9c ipcluster.1 > ipcluster.1.gz'),
157 'cd docs/man && gzip -9c ipcluster.1 > ipcluster.1.gz'),
158
158
159 ('docs/man/ipcontroller.1.gz',
159 ('docs/man/ipcontroller.1.gz',
160 ['docs/man/ipcontroller.1'],
160 ['docs/man/ipcontroller.1'],
161 'cd docs/man && gzip -9c ipcontroller.1 > ipcontroller.1.gz'),
161 'cd docs/man && gzip -9c ipcontroller.1 > ipcontroller.1.gz'),
162
162
163 ('docs/man/ipengine.1.gz',
163 ('docs/man/ipengine.1.gz',
164 ['docs/man/ipengine.1'],
164 ['docs/man/ipengine.1'],
165 'cd docs/man && gzip -9c ipengine.1 > ipengine.1.gz'),
165 'cd docs/man && gzip -9c ipengine.1 > ipengine.1.gz'),
166
166
167 ('docs/man/ipython.1.gz',
167 ('docs/man/ipython.1.gz',
168 ['docs/man/ipython.1'],
168 ['docs/man/ipython.1'],
169 'cd docs/man && gzip -9c ipython.1 > ipython.1.gz'),
169 'cd docs/man && gzip -9c ipython.1 > ipython.1.gz'),
170
170
171 ]
171 ]
172
172
173
173
174 [ target_update(*t) for t in to_update ]
174 [ target_update(*t) for t in to_update ]
175
175
176 #---------------------------------------------------------------------------
176 #---------------------------------------------------------------------------
177 # Find all the packages, package data, and data_files
177 # Find all the packages, package data, and data_files
178 #---------------------------------------------------------------------------
178 #---------------------------------------------------------------------------
179
179
180 packages = find_packages()
180 packages = find_packages()
181 package_data = find_package_data()
181 package_data = find_package_data()
182
182
183 data_files = find_data_files()
183 data_files = find_data_files()
184
184
185 setup_args['packages'] = packages
185 setup_args['packages'] = packages
186 setup_args['package_data'] = package_data
186 setup_args['package_data'] = package_data
187 setup_args['data_files'] = data_files
187 setup_args['data_files'] = data_files
188
188
189 #---------------------------------------------------------------------------
189 #---------------------------------------------------------------------------
190 # custom distutils commands
190 # custom distutils commands
191 #---------------------------------------------------------------------------
191 #---------------------------------------------------------------------------
192 # imports here, so they are after setuptools import if there was one
192 # imports here, so they are after setuptools import if there was one
193 from distutils.command.sdist import sdist
193 from distutils.command.sdist import sdist
194 from distutils.command.upload import upload
194 from distutils.command.upload import upload
195
195
196 class UploadWindowsInstallers(upload):
196 class UploadWindowsInstallers(upload):
197
197
198 description = "Upload Windows installers to PyPI (only used from tools/release_windows.py)"
198 description = "Upload Windows installers to PyPI (only used from tools/release_windows.py)"
199 user_options = upload.user_options + [
199 user_options = upload.user_options + [
200 ('files=', 'f', 'exe file (or glob) to upload')
200 ('files=', 'f', 'exe file (or glob) to upload')
201 ]
201 ]
202 def initialize_options(self):
202 def initialize_options(self):
203 upload.initialize_options(self)
203 upload.initialize_options(self)
204 meta = self.distribution.metadata
204 meta = self.distribution.metadata
205 base = '{name}-{version}'.format(
205 base = '{name}-{version}'.format(
206 name=meta.get_name(),
206 name=meta.get_name(),
207 version=meta.get_version()
207 version=meta.get_version()
208 )
208 )
209 self.files = os.path.join('dist', '%s.*.exe' % base)
209 self.files = os.path.join('dist', '%s.*.exe' % base)
210
210
211 def run(self):
211 def run(self):
212 for dist_file in glob(self.files):
212 for dist_file in glob(self.files):
213 self.upload_file('bdist_wininst', 'any', dist_file)
213 self.upload_file('bdist_wininst', 'any', dist_file)
214
214
215 setup_args['cmdclass'] = {
215 setup_args['cmdclass'] = {
216 'build_py': css_js_prerelease(
216 'build_py': css_js_prerelease(
217 check_package_data_first(git_prebuild('IPython'))),
217 check_package_data_first(git_prebuild('IPython'))),
218 'sdist' : css_js_prerelease(git_prebuild('IPython', sdist)),
218 'sdist' : css_js_prerelease(git_prebuild('IPython', sdist)),
219 'upload_wininst' : UploadWindowsInstallers,
219 'upload_wininst' : UploadWindowsInstallers,
220 'submodule' : UpdateSubmodules,
220 'submodule' : UpdateSubmodules,
221 'css' : CompileCSS,
221 'css' : CompileCSS,
222 'symlink': install_symlinked,
222 'symlink': install_symlinked,
223 'install_lib_symlink': install_lib_symlink,
223 'install_lib_symlink': install_lib_symlink,
224 'install_scripts_sym': install_scripts_for_symlink,
224 'install_scripts_sym': install_scripts_for_symlink,
225 'unsymlink': unsymlink,
225 'unsymlink': unsymlink,
226 'jsversion' : JavascriptVersion,
226 'jsversion' : JavascriptVersion,
227 }
227 }
228
228
229 #---------------------------------------------------------------------------
229 #---------------------------------------------------------------------------
230 # Handle scripts, dependencies, and setuptools specific things
230 # Handle scripts, dependencies, and setuptools specific things
231 #---------------------------------------------------------------------------
231 #---------------------------------------------------------------------------
232
232
233 # For some commands, use setuptools. Note that we do NOT list install here!
233 # For some commands, use setuptools. Note that we do NOT list install here!
234 # If you want a setuptools-enhanced install, just run 'setupegg.py install'
234 # If you want a setuptools-enhanced install, just run 'setupegg.py install'
235 needs_setuptools = set(('develop', 'release', 'bdist_egg', 'bdist_rpm',
235 needs_setuptools = set(('develop', 'release', 'bdist_egg', 'bdist_rpm',
236 'bdist', 'bdist_dumb', 'bdist_wininst', 'bdist_wheel',
236 'bdist', 'bdist_dumb', 'bdist_wininst', 'bdist_wheel',
237 'egg_info', 'easy_install', 'upload', 'install_egg_info',
237 'egg_info', 'easy_install', 'upload', 'install_egg_info',
238 ))
238 ))
239
239
240 if len(needs_setuptools.intersection(sys.argv)) > 0:
240 if len(needs_setuptools.intersection(sys.argv)) > 0:
241 import setuptools
241 import setuptools
242
242
243 # This dict is used for passing extra arguments that are setuptools
243 # This dict is used for passing extra arguments that are setuptools
244 # specific to setup
244 # specific to setup
245 setuptools_extra_args = {}
245 setuptools_extra_args = {}
246
246
247 # setuptools requirements
247 # setuptools requirements
248
248
249 pyzmq = 'pyzmq>=13'
249 pyzmq = 'pyzmq>=13'
250
250
251 extras_require = dict(
251 extras_require = dict(
252 parallel = [pyzmq],
252 parallel = [pyzmq],
253 qtconsole = [pyzmq, 'pygments'],
253 qtconsole = [pyzmq, 'pygments'],
254 doc = ['Sphinx>=1.1', 'numpydoc'],
254 doc = ['Sphinx>=1.1', 'numpydoc'],
255 test = ['nose>=0.10.1', 'requests'],
255 test = ['nose>=0.10.1', 'requests'],
256 terminal = [],
256 terminal = [],
257 nbformat = ['jsonschema>=2.0'],
257 nbformat = ['jsonschema>=2.0'],
258 notebook = ['tornado>=4.0', pyzmq, 'jinja2', 'pygments', 'mistune>=0.5'],
258 notebook = ['tornado>=4.0', pyzmq, 'jinja2', 'pygments', 'mistune>=0.5'],
259 nbconvert = ['pygments', 'jinja2', 'mistune>=0.3.1']
259 nbconvert = ['pygments', 'jinja2', 'mistune>=0.3.1']
260 )
260 )
261
261
262 if not sys.platform.startswith('win'):
262 if not sys.platform.startswith('win'):
263 extras_require['notebook'].append('terminado>=0.3.3')
263 extras_require['notebook'].append('terminado>=0.3.3')
264
264
265 if sys.version_info < (3, 3):
265 if sys.version_info < (3, 3):
266 extras_require['test'].append('mock')
266 extras_require['test'].append('mock')
267
267
268 extras_require['notebook'].extend(extras_require['nbformat'])
268 extras_require['notebook'].extend(extras_require['nbformat'])
269 extras_require['nbconvert'].extend(extras_require['nbformat'])
269 extras_require['nbconvert'].extend(extras_require['nbformat'])
270
270
271 install_requires = [
271 install_requires = [
272 'decorator',
272 'path.py', # required by pickleshare, remove when pickleshare is added here
273 'path.py', # required by pickleshare, remove when pickleshare is added here
273 ]
274 ]
274
275
275 # add readline
276 # add readline
276 if sys.platform == 'darwin':
277 if sys.platform == 'darwin':
277 if 'bdist_wheel' in sys.argv[1:] or not setupext.check_for_readline():
278 if 'bdist_wheel' in sys.argv[1:] or not setupext.check_for_readline():
278 install_requires.append('gnureadline')
279 install_requires.append('gnureadline')
279 elif sys.platform.startswith('win'):
280 elif sys.platform.startswith('win'):
280 extras_require['terminal'].append('pyreadline>=2.0')
281 extras_require['terminal'].append('pyreadline>=2.0')
281
282
282 everything = set()
283 everything = set()
283 for deps in extras_require.values():
284 for deps in extras_require.values():
284 everything.update(deps)
285 everything.update(deps)
285 extras_require['all'] = everything
286 extras_require['all'] = everything
286
287
287 if 'setuptools' in sys.modules:
288 if 'setuptools' in sys.modules:
288 # setup.py develop should check for submodules
289 # setup.py develop should check for submodules
289 from setuptools.command.develop import develop
290 from setuptools.command.develop import develop
290 setup_args['cmdclass']['develop'] = require_submodules(develop)
291 setup_args['cmdclass']['develop'] = require_submodules(develop)
291 setup_args['cmdclass']['bdist_wheel'] = css_js_prerelease(get_bdist_wheel())
292 setup_args['cmdclass']['bdist_wheel'] = css_js_prerelease(get_bdist_wheel())
292
293
293 setuptools_extra_args['zip_safe'] = False
294 setuptools_extra_args['zip_safe'] = False
294 setuptools_extra_args['entry_points'] = {
295 setuptools_extra_args['entry_points'] = {
295 'console_scripts': find_entry_points(),
296 'console_scripts': find_entry_points(),
296 'pygments.lexers': [
297 'pygments.lexers': [
297 'ipythonconsole = IPython.lib.lexers:IPythonConsoleLexer',
298 'ipythonconsole = IPython.lib.lexers:IPythonConsoleLexer',
298 'ipython = IPython.lib.lexers:IPythonLexer',
299 'ipython = IPython.lib.lexers:IPythonLexer',
299 'ipython3 = IPython.lib.lexers:IPython3Lexer',
300 'ipython3 = IPython.lib.lexers:IPython3Lexer',
300 ],
301 ],
301 }
302 }
302 setup_args['extras_require'] = extras_require
303 setup_args['extras_require'] = extras_require
303 requires = setup_args['install_requires'] = install_requires
304 requires = setup_args['install_requires'] = install_requires
304
305
305 # Script to be run by the windows binary installer after the default setup
306 # Script to be run by the windows binary installer after the default setup
306 # routine, to add shortcuts and similar windows-only things. Windows
307 # routine, to add shortcuts and similar windows-only things. Windows
307 # post-install scripts MUST reside in the scripts/ dir, otherwise distutils
308 # post-install scripts MUST reside in the scripts/ dir, otherwise distutils
308 # doesn't find them.
309 # doesn't find them.
309 if 'bdist_wininst' in sys.argv:
310 if 'bdist_wininst' in sys.argv:
310 if len(sys.argv) > 2 and \
311 if len(sys.argv) > 2 and \
311 ('sdist' in sys.argv or 'bdist_rpm' in sys.argv):
312 ('sdist' in sys.argv or 'bdist_rpm' in sys.argv):
312 print("ERROR: bdist_wininst must be run alone. Exiting.", file=sys.stderr)
313 print("ERROR: bdist_wininst must be run alone. Exiting.", file=sys.stderr)
313 sys.exit(1)
314 sys.exit(1)
314 setup_args['data_files'].append(
315 setup_args['data_files'].append(
315 ['Scripts', ('scripts/ipython.ico', 'scripts/ipython_nb.ico')])
316 ['Scripts', ('scripts/ipython.ico', 'scripts/ipython_nb.ico')])
316 setup_args['scripts'] = [pjoin('scripts','ipython_win_post_install.py')]
317 setup_args['scripts'] = [pjoin('scripts','ipython_win_post_install.py')]
317 setup_args['options'] = {"bdist_wininst":
318 setup_args['options'] = {"bdist_wininst":
318 {"install_script":
319 {"install_script":
319 "ipython_win_post_install.py"}}
320 "ipython_win_post_install.py"}}
320
321
321 else:
322 else:
322 # If we are installing without setuptools, call this function which will
323 # If we are installing without setuptools, call this function which will
323 # check for dependencies an inform the user what is needed. This is
324 # check for dependencies an inform the user what is needed. This is
324 # just to make life easy for users.
325 # just to make life easy for users.
325 for install_cmd in ('install', 'symlink'):
326 for install_cmd in ('install', 'symlink'):
326 if install_cmd in sys.argv:
327 if install_cmd in sys.argv:
327 check_for_dependencies()
328 check_for_dependencies()
328 break
329 break
329 # scripts has to be a non-empty list, or install_scripts isn't called
330 # scripts has to be a non-empty list, or install_scripts isn't called
330 setup_args['scripts'] = [e.split('=')[0].strip() for e in find_entry_points()]
331 setup_args['scripts'] = [e.split('=')[0].strip() for e in find_entry_points()]
331
332
332 setup_args['cmdclass']['build_scripts'] = build_scripts_entrypt
333 setup_args['cmdclass']['build_scripts'] = build_scripts_entrypt
333
334
334 #---------------------------------------------------------------------------
335 #---------------------------------------------------------------------------
335 # Do the actual setup now
336 # Do the actual setup now
336 #---------------------------------------------------------------------------
337 #---------------------------------------------------------------------------
337
338
338 setup_args.update(setuptools_extra_args)
339 setup_args.update(setuptools_extra_args)
339
340
340 def main():
341 def main():
341 setup(**setup_args)
342 setup(**setup_args)
342
343
343 if __name__ == '__main__':
344 if __name__ == '__main__':
344 main()
345 main()
1 NO CONTENT: file was removed
NO CONTENT: file was removed
1 NO CONTENT: file was removed
NO CONTENT: file was removed
General Comments 0
You need to be logged in to leave comments. Login now