update recently changed modules with Authors in docstring
MinRK

The requested changes are too big and the content was truncated; the final hunk below is cut off mid-file.

--- a/IPython/config/application.py
+++ b/IPython/config/application.py
@@ -1,361 +1,362 @@
 # encoding: utf-8
 """
 A base class for a configurable application.

 Authors:

 * Brian Granger
+* Min RK
 """

 #-----------------------------------------------------------------------------
 # Copyright (C) 2008-2011 The IPython Development Team
 #
 # Distributed under the terms of the BSD License. The full license is in
 # the file COPYING, distributed as part of this software.
 #-----------------------------------------------------------------------------

 #-----------------------------------------------------------------------------
 # Imports
 #-----------------------------------------------------------------------------

 from copy import deepcopy
 import logging
 import re
 import sys

 from IPython.config.configurable import SingletonConfigurable
 from IPython.config.loader import (
     KeyValueConfigLoader, PyFileConfigLoader, Config, ArgumentError
 )

 from IPython.utils.traitlets import (
     Unicode, List, Int, Enum, Dict, Instance
 )
 from IPython.utils.importstring import import_item
 from IPython.utils.text import indent

 #-----------------------------------------------------------------------------
 # Descriptions for the various sections
 #-----------------------------------------------------------------------------

 flag_description = """
 Flags are command-line arguments passed as '--<flag>'.
 These take no parameters, unlike regular key-value arguments.
 They are typically used for setting boolean flags, or enabling
 modes that involve setting multiple options together.
 """.strip() # trim newlines off front and back

 alias_description = """
 These are commonly set parameters, given abbreviated aliases for convenience.
 They are set in the same `name=value` way as class parameters, where
 <name> is replaced by the real parameter for which it is an alias.
 """.strip() # trim newlines off front and back

 keyvalue_description = """
 Parameters are set from command-line arguments of the form:
 `Class.trait=value`. Parameters will *never* be prefixed with '-'.
 This line is evaluated in Python, so simple expressions are allowed, e.g.
 `C.a='range(3)'` for setting C.a=[0,1,2]
 """.strip() # trim newlines off front and back

 #-----------------------------------------------------------------------------
 # Application class
 #-----------------------------------------------------------------------------


 class ApplicationError(Exception):
     pass


 class Application(SingletonConfigurable):
     """A singleton application with full configuration support."""

     # The name of the application, will usually match the name of the command
     # line application
     name = Unicode(u'application')

     # The description of the application that is printed at the beginning
     # of the help.
     description = Unicode(u'This is an application.')
     # default section descriptions
     flag_description = Unicode(flag_description)
     alias_description = Unicode(alias_description)
     keyvalue_description = Unicode(keyvalue_description)


     # A sequence of Configurable subclasses whose config=True attributes will
     # be exposed at the command line.
     classes = List([])

     # The version string of this application.
     version = Unicode(u'0.0')

     # The log level for the application
     log_level = Enum((0,10,20,30,40,50), default_value=logging.WARN,
                     config=True,
                     help="Set the log level.")

     # the alias map for configurables
     aliases = Dict(dict(log_level='Application.log_level'))

     # flags for loading Configurables or store_const style flags
     # flags are loaded from this dict by '--key' flags
     # this must be a dict of two-tuples, the first element being the Config/dict
     # and the second being the help string for the flag
     flags = Dict()

     # subcommands for launching other applications
     # if this is not empty, this will be a parent Application
     # this must be a dict of two-tuples, the first element being the application class/import string
     # and the second being the help string for the subcommand
     subcommands = Dict()
     # parse_command_line will initialize a subapp, if requested
     subapp = Instance('IPython.config.application.Application', allow_none=True)

     # extra command-line arguments that don't set config values
     extra_args = List(Unicode)


     def __init__(self, **kwargs):
         SingletonConfigurable.__init__(self, **kwargs)
         # Add my class to self.classes so my attributes appear in command line
         # options.
         self.classes.insert(0, self.__class__)

         # ensure self.flags dict is valid
         for key,value in self.flags.iteritems():
             assert len(value) == 2, "Bad flag: %r:%s"%(key,value)
             assert isinstance(value[0], (dict, Config)), "Bad flag: %r:%s"%(key,value)
             assert isinstance(value[1], basestring), "Bad flag: %r:%s"%(key,value)
         self.init_logging()

     def _config_changed(self, name, old, new):
         SingletonConfigurable._config_changed(self, name, old, new)
         self.log.debug('Config changed:')
         self.log.debug(repr(new))

     def init_logging(self):
         """Start logging for this application.

         The default is to log to stdout using a StreamHandler. The log level
         starts at logging.WARN, but this can be adjusted by setting the
         ``log_level`` attribute.
         """
         self.log = logging.getLogger(self.__class__.__name__)
         self.log.setLevel(self.log_level)
         self._log_handler = logging.StreamHandler()
         self._log_formatter = logging.Formatter("[%(name)s] %(message)s")
         self._log_handler.setFormatter(self._log_formatter)
         self.log.addHandler(self._log_handler)

     def initialize(self, argv=None):
         """Do the basic steps to configure me.

         Override in subclasses.
         """
         self.parse_command_line(argv)


     def start(self):
         """Start the app mainloop.

         Override in subclasses.
         """
         if self.subapp is not None:
             return self.subapp.start()

     def _log_level_changed(self, name, old, new):
         """Adjust the log level when log_level is set."""
         self.log.setLevel(new)

     def print_alias_help(self):
         """print the alias part of the help"""
         if not self.aliases:
             return

         lines = ['Aliases']
         lines.append('-'*len(lines[0]))
         lines.append(self.alias_description)
         lines.append('')

         classdict = {}
         for cls in self.classes:
             # include all parents (up to, but excluding Configurable) in available names
             for c in cls.mro()[:-3]:
                 classdict[c.__name__] = c

         for alias, longname in self.aliases.iteritems():
             classname, traitname = longname.split('.',1)
             cls = classdict[classname]

             trait = cls.class_traits(config=True)[traitname]
             help = cls.class_get_trait_help(trait)
             help = help.replace(longname, "%s (%s)"%(alias, longname), 1)
             lines.append(help)
             lines.append('')
         print '\n'.join(lines)

     def print_flag_help(self):
         """print the flag part of the help"""
         if not self.flags:
             return

         lines = ['Flags']
         lines.append('-'*len(lines[0]))
         lines.append(self.flag_description)
         lines.append('')

         for m, (cfg,help) in self.flags.iteritems():
             lines.append('--'+m)
             lines.append(indent(help.strip(), flatten=True))
             lines.append('')
         print '\n'.join(lines)

     def print_subcommands(self):
         """print the subcommand part of the help"""
         if not self.subcommands:
             return

         lines = ["Subcommands"]
         lines.append('-'*len(lines[0]))
         for subc, (cls,help) in self.subcommands.iteritems():
             lines.append("%s : %s"%(subc, cls))
             if help:
                 lines.append(indent(help.strip(), flatten=True))
         lines.append('')
         print '\n'.join(lines)

     def print_help(self, classes=False):
         """Print the help for each Configurable class in self.classes.

         If classes=False (the default), only flags and aliases are printed
         """
         self.print_subcommands()
         self.print_flag_help()
         self.print_alias_help()

         if classes:
             if self.classes:
                 print "Class parameters"
                 print "----------------"
                 print self.keyvalue_description
                 print

             for cls in self.classes:
                 cls.class_print_help()
                 print
         else:
             print "To see all available configurables, use `--help-all`"
             print

     def print_description(self):
         """Print the application description."""
         print self.description
         print

     def print_version(self):
         """Print the version string."""
         print self.version

     def update_config(self, config):
         """Fire the traits events when the config is updated."""
         # Save a copy of the current config.
         newconfig = deepcopy(self.config)
         # Merge the new config into the current one.
         newconfig._merge(config)
         # Save the combined config as self.config, which triggers the traits
         # events.
         self.config = newconfig

     def initialize_subcommand(self, subc, argv=None):
         """Initialize a subcommand with argv"""
         subapp,help = self.subcommands.get(subc)

         if isinstance(subapp, basestring):
             subapp = import_item(subapp)

         # clear existing instances
         self.__class__.clear_instance()
         # instantiate
         self.subapp = subapp.instance()
         # and initialize subapp
         self.subapp.initialize(argv)

     def parse_command_line(self, argv=None):
         """Parse the command line arguments."""
         argv = sys.argv[1:] if argv is None else argv

         if self.subcommands and len(argv) > 0:
             # we have subcommands, and one may have been specified
             subc, subargv = argv[0], argv[1:]
             if re.match(r'^\w(\-?\w)*$', subc) and subc in self.subcommands:
                 # it's a subcommand, and *not* a flag or class parameter
                 return self.initialize_subcommand(subc, subargv)

         if '-h' in argv or '--help' in argv or '--help-all' in argv:
             self.print_description()
             self.print_help('--help-all' in argv)
             self.exit(0)

         if '--version' in argv:
             self.print_version()
             self.exit(0)

         loader = KeyValueConfigLoader(argv=argv, aliases=self.aliases,
                                         flags=self.flags)
         try:
             config = loader.load_config()
         except ArgumentError as e:
             self.log.fatal(str(e))
             self.print_description()
             self.print_help()
             self.exit(1)
         self.update_config(config)
         # store unparsed args in extra_args
         self.extra_args = loader.extra_args

     def load_config_file(self, filename, path=None):
         """Load a .py based config file by filename and path."""
         loader = PyFileConfigLoader(filename, path=path)
         config = loader.load_config()
         self.update_config(config)

     def exit(self, exit_status=0):
         self.log.debug("Exiting application: %s" % self.name)
         sys.exit(exit_status)

 #-----------------------------------------------------------------------------
 # utility functions, for convenience
 #-----------------------------------------------------------------------------

 def boolean_flag(name, configurable, set_help='', unset_help=''):
     """helper for building basic --trait, --no-trait flags

     Parameters
     ----------

     name : str
         The name of the flag.
     configurable : str
         The 'Class.trait' string of the trait to be set/unset with the flag
     set_help : unicode
         help string for --name flag
     unset_help : unicode
         help string for --no-name flag

     Returns
     -------

     cfg : dict
         A dict with two keys: 'name', and 'no-name', for setting and unsetting
         the trait, respectively.
     """
     # default helpstrings
     set_help = set_help or "set %s=True"%configurable
     unset_help = unset_help or "set %s=False"%configurable

     cls,trait = configurable.split('.')

     setter = {cls : {trait : True}}
     unsetter = {cls : {trait : False}}
     return {name : (setter, set_help), 'no-'+name : (unsetter, unset_help)}
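
For reference, a minimal sketch (not part of the commit; ``MyApp`` and its traits are hypothetical) of how the pieces in ``application.py`` fit together: ``aliases`` maps short ``name=value`` arguments onto ``'Class.trait'`` strings, and ``flags`` holds the ``(Config/dict, help)`` two-tuples that ``__init__`` validates, here built with ``boolean_flag``::

    from IPython.config.application import Application, boolean_flag
    from IPython.utils.traitlets import Bool, Dict, Int, Unicode

    class MyApp(Application):
        name = Unicode(u'myapp')
        description = Unicode(u'A tiny demo application.')

        n = Int(1, config=True, help="Number of workers to start.")
        debug = Bool(False, config=True, help="Enable debug output.")

        # short `name=value` forms for Class.trait assignments
        aliases = Dict(dict(n='MyApp.n', log_level='Application.log_level'))
        # --debug / --no-debug toggle MyApp.debug
        flags = Dict(boolean_flag('debug', 'MyApp.debug',
            "set MyApp.debug=True", "set MyApp.debug=False"))

    app = MyApp.instance()
    # as if invoked as: myapp --debug n=4 log_level=10
    app.parse_command_line(['--debug', 'n=4', 'log_level=10'])
    print app.n, app.debug    # -> 4 True

This assumes the ``KeyValueConfigLoader`` semantics described in the section descriptions above: flags take ``--`` prefixes, while aliases and ``Class.trait=value`` parameters do not.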
--- a/IPython/config/configurable.py
+++ b/IPython/config/configurable.py
@@ -1,278 +1,279 @@
 #!/usr/bin/env python
 # encoding: utf-8
 """
 A base class for objects that are configurable.

 Authors:

 * Brian Granger
 * Fernando Perez
+* Min RK
 """

 #-----------------------------------------------------------------------------
-# Copyright (C) 2008-2010 The IPython Development Team
+# Copyright (C) 2008-2011 The IPython Development Team
 #
 # Distributed under the terms of the BSD License. The full license is in
 # the file COPYING, distributed as part of this software.
 #-----------------------------------------------------------------------------

 #-----------------------------------------------------------------------------
 # Imports
 #-----------------------------------------------------------------------------

 from copy import deepcopy
 import datetime

 from loader import Config
 from IPython.utils.traitlets import HasTraits, Instance
 from IPython.utils.text import indent


 #-----------------------------------------------------------------------------
 # Helper classes for Configurables
 #-----------------------------------------------------------------------------


 class ConfigurableError(Exception):
     pass


 class MultipleInstanceError(ConfigurableError):
     pass

 #-----------------------------------------------------------------------------
 # Configurable implementation
 #-----------------------------------------------------------------------------

 class Configurable(HasTraits):

     config = Instance(Config,(),{})
     created = None

     def __init__(self, **kwargs):
         """Create a configurable given a config.

         Parameters
         ----------
         config : Config
             If this is empty, default values are used. If config is a
             :class:`Config` instance, it will be used to configure the
             instance.

         Notes
         -----
         Subclasses of Configurable must call the :meth:`__init__` method of
         :class:`Configurable` *before* doing anything else and using
         :func:`super`::

             class MyConfigurable(Configurable):
                 def __init__(self, config=None):
                     super(MyConfigurable, self).__init__(config)
                     # Then any other code you need to finish initialization.

         This ensures that instances will be configured properly.
         """
         config = kwargs.pop('config', None)
         if config is not None:
             # We used to deepcopy, but for now we are trying to just save
             # by reference. This *could* have side effects as all components
             # will share config. In fact, I did find such a side effect in
             # _config_changed below. If a config attribute value was a mutable type
             # all instances of a component were getting the same copy, effectively
             # making that a class attribute.
             # self.config = deepcopy(config)
             self.config = config
         # This should go second so individual keyword arguments override
         # the values in config.
         super(Configurable, self).__init__(**kwargs)
         self.created = datetime.datetime.now()

     #-------------------------------------------------------------------------
     # Static trait notifications
     #-------------------------------------------------------------------------

     def _config_changed(self, name, old, new):
         """Update all the class traits having ``config=True`` as metadata.

         For any class trait with a ``config`` metadata attribute that is
         ``True``, we update the trait with the value of the corresponding
         config entry.
         """
         # Get all traits with a config metadata entry that is True
         traits = self.traits(config=True)

         # We auto-load config section for this class as well as any parent
         # classes that are Configurable subclasses. This starts with Configurable
         # and works down the mro loading the config for each section.
         section_names = [cls.__name__ for cls in \
             reversed(self.__class__.__mro__) if
             issubclass(cls, Configurable) and issubclass(self.__class__, cls)]

         for sname in section_names:
             # Don't do a blind getattr as that would cause the config to
             # dynamically create the section with name self.__class__.__name__.
             if new._has_section(sname):
                 my_config = new[sname]
                 for k, v in traits.iteritems():
                     # Don't allow traitlets with config=True to start with
                     # uppercase. Otherwise, they are confused with Config
                     # subsections. But, developers shouldn't have uppercase
                     # attributes anyways! (PEP 8)
                     if k[0].upper()==k[0] and not k.startswith('_'):
                         raise ConfigurableError('Configurable traitlets with '
                         'config=True must start with a lowercase so they are '
                         'not confused with Config subsections: %s.%s' % \
                         (self.__class__.__name__, k))
                     try:
                         # Here we grab the value from the config
                         # If k has the naming convention of a config
                         # section, it will be auto created.
                         config_value = my_config[k]
                     except KeyError:
                         pass
                     else:
                         # print "Setting %s.%s from %s.%s=%r" % \
                         #     (self.__class__.__name__,k,sname,k,config_value)
                         # We have to do a deepcopy here if we don't deepcopy the entire
                         # config object. If we don't, a mutable config_value will be
                         # shared by all instances, effectively making it a class attribute.
                         setattr(self, k, deepcopy(config_value))

     @classmethod
     def class_get_help(cls):
         """Get the help string for this class in ReST format."""
         cls_traits = cls.class_traits(config=True)
         final_help = []
         final_help.append(u'%s options' % cls.__name__)
         final_help.append(len(final_help[0])*u'-')
         for k,v in cls.class_traits(config=True).iteritems():
             help = cls.class_get_trait_help(v)
             final_help.append(help)
         return '\n'.join(final_help)

     @classmethod
     def class_get_trait_help(cls, trait):
         """Get the help string for a single trait."""
         lines = []
         header = "%s.%s : %s" % (cls.__name__, trait.name, trait.__class__.__name__)
         try:
             dvr = repr(trait.get_default_value())
         except Exception:
             dvr = None # ignore defaults we can't construct
         if dvr is not None:
             header += ' [default: %s]'%dvr
         lines.append(header)

         help = trait.get_metadata('help')
         if help is not None:
             lines.append(indent(help.strip(), flatten=True))
         if 'Enum' in trait.__class__.__name__:
             # include Enum choices
             lines.append(indent('Choices: %r'%(trait.values,), flatten=True))
         return '\n'.join(lines)

     @classmethod
     def class_print_help(cls):
         print cls.class_get_help()


 class SingletonConfigurable(Configurable):
     """A configurable that only allows one instance.

     This class is for classes that should only have one instance of itself
     or *any* subclass. To create and retrieve such a class use the
     :meth:`SingletonConfigurable.instance` method.
     """

     _instance = None

     @classmethod
     def _walk_mro(cls):
         """Walk the cls.mro() for parent classes that are also singletons

         For use in instance()
         """

         for subclass in cls.mro():
             if issubclass(cls, subclass) and \
                 issubclass(subclass, SingletonConfigurable) and \
                 subclass != SingletonConfigurable:
                 yield subclass

     @classmethod
     def clear_instance(cls):
         """unset _instance for this class and singleton parents.
         """
         if not cls.initialized():
             return
         for subclass in cls._walk_mro():
             if isinstance(subclass._instance, cls):
                 # only clear instances that are instances
                 # of the calling class
                 subclass._instance = None

     @classmethod
     def instance(cls, *args, **kwargs):
         """Returns a global instance of this class.

         This method creates a new instance if none has previously been created
         and returns a previously created instance if one already exists.

         The arguments and keyword arguments passed to this method are passed
         on to the :meth:`__init__` method of the class upon instantiation.

         Examples
         --------

         Create a singleton class using instance, and retrieve it::

             >>> from IPython.config.configurable import SingletonConfigurable
             >>> class Foo(SingletonConfigurable): pass
             >>> foo = Foo.instance()
             >>> foo == Foo.instance()
             True

         Create a subclass that is retrieved using the base class instance::

             >>> class Bar(SingletonConfigurable): pass
             >>> class Bam(Bar): pass
             >>> bam = Bam.instance()
             >>> bam == Bar.instance()
             True
         """
         # Create and save the instance
         if cls._instance is None:
             inst = cls(*args, **kwargs)
             # Now make sure that the instance will also be returned by
             # parent classes' _instance attribute.
             for subclass in cls._walk_mro():
                 subclass._instance = inst

         if isinstance(cls._instance, cls):
             return cls._instance
         else:
             raise MultipleInstanceError(
                 'Multiple incompatible subclass instances of '
                 '%s are being created.' % cls.__name__
             )

     @classmethod
     def initialized(cls):
         """Has an instance been created?"""
         return hasattr(cls, "_instance") and cls._instance is not None


 class LoggingConfigurable(Configurable):
     """A parent class for Configurables that log.

     Subclasses have a log trait, and the default behavior
     is to get the logger from the currently running Application
     via Application.instance().log.
     """

     log = Instance('logging.Logger')
     def _log_default(self):
         from IPython.config.application import Application
         return Application.instance().log
 
\ No newline at end of file
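
A second sketch (also not part of the commit; ``Worker`` is hypothetical) of the behavior ``Configurable._config_changed`` implements above: traits flagged ``config=True`` are filled in from the ``Config`` section whose name matches the class, and capitalized attribute access auto-creates sections::

    from IPython.config.loader import Config
    from IPython.config.configurable import Configurable
    from IPython.utils.traitlets import Int

    class Worker(Configurable):
        threads = Int(1, config=True)

    c = Config()
    c.Worker.threads = 4   # 'Worker' is capitalized, so a sub-Config is auto-created
    w = Worker(config=c)
    print w.threads         # -> 4, loaded from the Worker section at __init__ time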
@@ -1,513 +1,514 b''
1 """A simple configuration system.
1 """A simple configuration system.
2
2
3 Authors
3 Authors
4 -------
4 -------
5 * Brian Granger
5 * Brian Granger
6 * Fernando Perez
6 * Fernando Perez
7 * Min RK
7 """
8 """
8
9
9 #-----------------------------------------------------------------------------
10 #-----------------------------------------------------------------------------
10 # Copyright (C) 2008-2009 The IPython Development Team
11 # Copyright (C) 2008-2011 The IPython Development Team
11 #
12 #
12 # Distributed under the terms of the BSD License. The full license is in
13 # Distributed under the terms of the BSD License. The full license is in
13 # the file COPYING, distributed as part of this software.
14 # the file COPYING, distributed as part of this software.
14 #-----------------------------------------------------------------------------
15 #-----------------------------------------------------------------------------
15
16
16 #-----------------------------------------------------------------------------
17 #-----------------------------------------------------------------------------
17 # Imports
18 # Imports
18 #-----------------------------------------------------------------------------
19 #-----------------------------------------------------------------------------
19
20
20 import __builtin__
21 import __builtin__
21 import re
22 import re
22 import sys
23 import sys
23
24
24 from IPython.external import argparse
25 from IPython.external import argparse
25 from IPython.utils.path import filefind
26 from IPython.utils.path import filefind
26
27
27 #-----------------------------------------------------------------------------
28 #-----------------------------------------------------------------------------
28 # Exceptions
29 # Exceptions
29 #-----------------------------------------------------------------------------
30 #-----------------------------------------------------------------------------
30
31
31
32
32 class ConfigError(Exception):
33 class ConfigError(Exception):
33 pass
34 pass
34
35
35
36
36 class ConfigLoaderError(ConfigError):
37 class ConfigLoaderError(ConfigError):
37 pass
38 pass
38
39
39 class ArgumentError(ConfigLoaderError):
40 class ArgumentError(ConfigLoaderError):
40 pass
41 pass
41
42
42 #-----------------------------------------------------------------------------
43 #-----------------------------------------------------------------------------
43 # Argparse fix
44 # Argparse fix
44 #-----------------------------------------------------------------------------
45 #-----------------------------------------------------------------------------
45
46
46 # Unfortunately argparse by default prints help messages to stderr instead of
47 # Unfortunately argparse by default prints help messages to stderr instead of
47 # stdout. This makes it annoying to capture long help screens at the command
48 # stdout. This makes it annoying to capture long help screens at the command
48 # line, since one must know how to pipe stderr, which many users don't know how
49 # line, since one must know how to pipe stderr, which many users don't know how
49 # to do. So we override the print_help method with one that defaults to
50 # to do. So we override the print_help method with one that defaults to
50 # stdout and use our class instead.
51 # stdout and use our class instead.
51
52
52 class ArgumentParser(argparse.ArgumentParser):
53 class ArgumentParser(argparse.ArgumentParser):
53 """Simple argparse subclass that prints help to stdout by default."""
54 """Simple argparse subclass that prints help to stdout by default."""
54
55
55 def print_help(self, file=None):
56 def print_help(self, file=None):
56 if file is None:
57 if file is None:
57 file = sys.stdout
58 file = sys.stdout
58 return super(ArgumentParser, self).print_help(file)
59 return super(ArgumentParser, self).print_help(file)
59
60
60 print_help.__doc__ = argparse.ArgumentParser.print_help.__doc__
61 print_help.__doc__ = argparse.ArgumentParser.print_help.__doc__
61
62
62 #-----------------------------------------------------------------------------
63 #-----------------------------------------------------------------------------
63 # Config class for holding config information
64 # Config class for holding config information
64 #-----------------------------------------------------------------------------
65 #-----------------------------------------------------------------------------
65
66
66
67
67 class Config(dict):
68 class Config(dict):
68 """An attribute based dict that can do smart merges."""
69 """An attribute based dict that can do smart merges."""
69
70
70 def __init__(self, *args, **kwds):
71 def __init__(self, *args, **kwds):
71 dict.__init__(self, *args, **kwds)
72 dict.__init__(self, *args, **kwds)
72 # This sets self.__dict__ = self, but it has to be done this way
73 # This sets self.__dict__ = self, but it has to be done this way
73 # because we are also overriding __setattr__.
74 # because we are also overriding __setattr__.
74 dict.__setattr__(self, '__dict__', self)
75 dict.__setattr__(self, '__dict__', self)
75
76
76 def _merge(self, other):
77 def _merge(self, other):
77 to_update = {}
78 to_update = {}
78 for k, v in other.iteritems():
79 for k, v in other.iteritems():
79 if not self.has_key(k):
80 if not self.has_key(k):
80 to_update[k] = v
81 to_update[k] = v
81 else: # I have this key
82 else: # I have this key
82 if isinstance(v, Config):
83 if isinstance(v, Config):
83 # Recursively merge common sub Configs
84 # Recursively merge common sub Configs
84 self[k]._merge(v)
85 self[k]._merge(v)
85 else:
86 else:
86 # Plain updates for non-Configs
87 # Plain updates for non-Configs
87 to_update[k] = v
88 to_update[k] = v
88
89
89 self.update(to_update)
90 self.update(to_update)
90
91
91 def _is_section_key(self, key):
92 def _is_section_key(self, key):
92 if key[0].upper()==key[0] and not key.startswith('_'):
93 if key[0].upper()==key[0] and not key.startswith('_'):
93 return True
94 return True
94 else:
95 else:
95 return False
96 return False
96
97
97 def __contains__(self, key):
98 def __contains__(self, key):
98 if self._is_section_key(key):
99 if self._is_section_key(key):
99 return True
100 return True
100 else:
101 else:
101 return super(Config, self).__contains__(key)
102 return super(Config, self).__contains__(key)
102 # .has_key is deprecated for dictionaries.
103 # .has_key is deprecated for dictionaries.
103 has_key = __contains__
104 has_key = __contains__
104
105
105 def _has_section(self, key):
106 def _has_section(self, key):
106 if self._is_section_key(key):
107 if self._is_section_key(key):
107 if super(Config, self).__contains__(key):
108 if super(Config, self).__contains__(key):
108 return True
109 return True
109 return False
110 return False
110
111
111 def copy(self):
112 def copy(self):
112 return type(self)(dict.copy(self))
113 return type(self)(dict.copy(self))
113
114
114 def __copy__(self):
115 def __copy__(self):
115 return self.copy()
116 return self.copy()
116
117
117 def __deepcopy__(self, memo):
118 def __deepcopy__(self, memo):
118 import copy
119 import copy
119 return type(self)(copy.deepcopy(self.items()))
120 return type(self)(copy.deepcopy(self.items()))
120
121
121 def __getitem__(self, key):
122 def __getitem__(self, key):
122 # We cannot use directly self._is_section_key, because it triggers
123 # We cannot use directly self._is_section_key, because it triggers
123 # infinite recursion on top of PyPy. Instead, we manually fish the
124 # infinite recursion on top of PyPy. Instead, we manually fish the
124 # bound method.
125 # bound method.
125 is_section_key = self.__class__._is_section_key.__get__(self)
126 is_section_key = self.__class__._is_section_key.__get__(self)
126
127
127 # Because we use this for an exec namespace, we need to delegate
128 # Because we use this for an exec namespace, we need to delegate
128 # the lookup of names in __builtin__ to itself. This means
129 # the lookup of names in __builtin__ to itself. This means
129 # that you can't have section or attribute names that are
130 # that you can't have section or attribute names that are
130 # builtins.
131 # builtins.
131 try:
132 try:
132 return getattr(__builtin__, key)
133 return getattr(__builtin__, key)
133 except AttributeError:
134 except AttributeError:
134 pass
135 pass
135 if is_section_key(key):
136 if is_section_key(key):
136 try:
137 try:
137 return dict.__getitem__(self, key)
138 return dict.__getitem__(self, key)
138 except KeyError:
139 except KeyError:
139 c = Config()
140 c = Config()
140 dict.__setitem__(self, key, c)
141 dict.__setitem__(self, key, c)
141 return c
142 return c
142 else:
143 else:
143 return dict.__getitem__(self, key)
144 return dict.__getitem__(self, key)
144
145
145 def __setitem__(self, key, value):
146 def __setitem__(self, key, value):
146 # Don't allow names in __builtin__ to be modified.
147 # Don't allow names in __builtin__ to be modified.
147 if hasattr(__builtin__, key):
148 if hasattr(__builtin__, key):
148 raise ConfigError('Config variable names cannot have the same name '
149 raise ConfigError('Config variable names cannot have the same name '
149 'as a Python builtin: %s' % key)
150 'as a Python builtin: %s' % key)
150 if self._is_section_key(key):
151 if self._is_section_key(key):
151 if not isinstance(value, Config):
152 if not isinstance(value, Config):
152 raise ValueError('values whose keys begin with an uppercase '
153 raise ValueError('values whose keys begin with an uppercase '
153 'char must be Config instances: %r, %r' % (key, value))
154 'char must be Config instances: %r, %r' % (key, value))
154 else:
155 else:
155 dict.__setitem__(self, key, value)
156 dict.__setitem__(self, key, value)
156
157
157 def __getattr__(self, key):
158 def __getattr__(self, key):
158 try:
159 try:
159 return self.__getitem__(key)
160 return self.__getitem__(key)
160 except KeyError, e:
161 except KeyError, e:
161 raise AttributeError(e)
162 raise AttributeError(e)
162
163
163 def __setattr__(self, key, value):
164 def __setattr__(self, key, value):
164 try:
165 try:
165 self.__setitem__(key, value)
166 self.__setitem__(key, value)
166 except KeyError, e:
167 except KeyError, e:
167 raise AttributeError(e)
168 raise AttributeError(e)
168
169
169 def __delattr__(self, key):
170 def __delattr__(self, key):
170 try:
171 try:
171 dict.__delitem__(self, key)
172 dict.__delitem__(self, key)
172 except KeyError, e:
173 except KeyError, e:
173 raise AttributeError(e)
174 raise AttributeError(e)
174
175
175
176
176 #-----------------------------------------------------------------------------
177 #-----------------------------------------------------------------------------
177 # Config loading classes
178 # Config loading classes
178 #-----------------------------------------------------------------------------
179 #-----------------------------------------------------------------------------
179
180
180
181
181 class ConfigLoader(object):
182 class ConfigLoader(object):
182 """A object for loading configurations from just about anywhere.
183 """A object for loading configurations from just about anywhere.
183
184
184 The resulting configuration is packaged as a :class:`Struct`.
185 The resulting configuration is packaged as a :class:`Struct`.
185
186
186 Notes
187 Notes
187 -----
188 -----
188 A :class:`ConfigLoader` does one thing: load a config from a source
189 A :class:`ConfigLoader` does one thing: load a config from a source
189 (file, command line arguments) and returns the data as a :class:`Struct`.
190 (file, command line arguments) and returns the data as a :class:`Struct`.
190 There are lots of things that :class:`ConfigLoader` does not do. It does
191 There are lots of things that :class:`ConfigLoader` does not do. It does
191 not implement complex logic for finding config files. It does not handle
192 not implement complex logic for finding config files. It does not handle
192 default values or merge multiple configs. These things need to be
193 default values or merge multiple configs. These things need to be
193 handled elsewhere.
194 handled elsewhere.
194 """
195 """
195
196
196 def __init__(self):
197 def __init__(self):
197 """A base class for config loaders.
198 """A base class for config loaders.
198
199
199 Examples
200 Examples
200 --------
201 --------
201
202
202 >>> cl = ConfigLoader()
203 >>> cl = ConfigLoader()
203 >>> config = cl.load_config()
204 >>> config = cl.load_config()
204 >>> config
205 >>> config
205 {}
206 {}
206 """
207 """
207 self.clear()
208 self.clear()
208
209
209 def clear(self):
210 def clear(self):
210 self.config = Config()
211 self.config = Config()
211
212
212 def load_config(self):
213 def load_config(self):
213 """Load a config from somewhere, return a :class:`Config` instance.
214 """Load a config from somewhere, return a :class:`Config` instance.
214
215
215 Usually, this will cause self.config to be set and then returned.
216 Usually, this will cause self.config to be set and then returned.
216 However, in most cases, :meth:`ConfigLoader.clear` should be called
217 However, in most cases, :meth:`ConfigLoader.clear` should be called
217 to erase any previous state.
218 to erase any previous state.
218 """
219 """
219 self.clear()
220 self.clear()
220 return self.config
221 return self.config
221
222
222
223
223 class FileConfigLoader(ConfigLoader):
224 class FileConfigLoader(ConfigLoader):
224 """A base class for file based configurations.
225 """A base class for file based configurations.
225
226
226 As we add more file based config loaders, the common logic should go
227 As we add more file based config loaders, the common logic should go
227 here.
228 here.
228 """
229 """
229 pass
230 pass
230
231
231
232
class PyFileConfigLoader(FileConfigLoader):
    """A config loader for pure python files.

    This calls execfile on a plain python file and looks for attributes
    that are all caps. These attributes are added to the config object.
    """

    def __init__(self, filename, path=None):
        """Build a config loader for a filename and path.

        Parameters
        ----------
        filename : str
            The file name of the config file.
        path : str, list, tuple
            The path to search for the config file on, or a sequence of
            paths to try in order.
        """
        super(PyFileConfigLoader, self).__init__()
        self.filename = filename
        self.path = path
        self.full_filename = ''
        self.data = None

    def load_config(self):
        """Load the config from a file and return it as a Config object."""
        self.clear()
        self._find_file()
        self._read_file_as_dict()
        self._convert_to_config()
        return self.config

    def _find_file(self):
        """Try to find the file by searching the paths."""
        self.full_filename = filefind(self.filename, self.path)

    def _read_file_as_dict(self):
        """Load the config file into self.config, with recursive loading."""
        # This closure is made available in the namespace that is used
        # to exec the config file.  This allows users to call
        # load_subconfig('myconfig.py') to load config files recursively.
        # It needs to be a closure because it has references to self.path
        # and self.config.  The sub-config is loaded with the same path
        # as the parent, but it uses an empty config which is then merged
        # with the parent's.
        def load_subconfig(fname):
            loader = PyFileConfigLoader(fname, self.path)
            try:
                sub_config = loader.load_config()
            except IOError:
                # Pass silently if the sub config is not there. This happens
                # when a user is using a profile, but not the default config.
                pass
            else:
                self.config._merge(sub_config)

        # Again, this needs to be a closure and should be used in config
        # files to get the config being loaded.
        def get_config():
            return self.config

        namespace = dict(load_subconfig=load_subconfig, get_config=get_config)
        fs_encoding = sys.getfilesystemencoding() or 'ascii'
        conf_filename = self.full_filename.encode(fs_encoding)
        execfile(conf_filename, namespace)

    def _convert_to_config(self):
        if self.data is None:
            raise ConfigLoaderError('self.data does not exist')


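# Illustrative sketch, not part of the original module: what a config
# file executed by PyFileConfigLoader can look like.  get_config() and
# load_subconfig() are the closures injected into the file's namespace
# by _read_file_as_dict() above; the file and trait names are made up.
#
#     # ipython_config.py
#     c = get_config()                    # the Config object being built
#     c.InteractiveShell.autocall = 1     # set any configurable trait
#     load_subconfig('base_config.py')    # merge another config file

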
class CommandLineConfigLoader(ConfigLoader):
    """A config loader for command line arguments.

    As we add more command line based loaders, the common logic should go
    here.
    """

kv_pattern = re.compile(r'[A-Za-z]\w*(\.\w+)*\=.*')
flag_pattern = re.compile(r'\-\-\w+(\-\w)*')

class KeyValueConfigLoader(CommandLineConfigLoader):
    """A config loader that loads key value pairs from the command line.

    This allows command line options to be given in the following form::

        ipython Global.profile="foo" InteractiveShell.autocall=False
    """

    def __init__(self, argv=None, aliases=None, flags=None):
        """Create a key value pair config loader.

        Parameters
        ----------
        argv : list
            A list that has the form of sys.argv[1:] which has unicode
            elements of the form u"key=value". If this is None (default),
            then sys.argv[1:] will be used.
        aliases : dict
            A dict of aliases for configurable traits.
            Keys are the short aliases, Values are the resolved trait.
            Of the form: `{'alias' : 'Configurable.trait'}`
        flags : dict
            A dict of flags, keyed by str name. Values can be Config objects,
            dicts, or "key=value" strings. If Config or dict, when the flag
            is triggered, the config is loaded as `self.config.update(m)`.

        Returns
        -------
        config : Config
            The resulting Config object.

        Examples
        --------

        >>> from IPython.config.loader import KeyValueConfigLoader
        >>> cl = KeyValueConfigLoader()
        >>> cl.load_config(["foo='bar'","A.name='brian'","B.number=0"])
        {'A': {'name': 'brian'}, 'B': {'number': 0}, 'foo': 'bar'}
        """
        if argv is None:
            argv = sys.argv[1:]
        self.argv = argv
        self.aliases = aliases or {}
        self.flags = flags or {}

    def load_config(self, argv=None, aliases=None, flags=None):
        """Parse the configuration and generate the Config object.

        Parameters
        ----------
        argv : list, optional
            A list that has the form of sys.argv[1:] which has unicode
            elements of the form u"key=value". If this is None (default),
            then self.argv will be used.
        aliases : dict
            A dict of aliases for configurable traits.
            Keys are the short aliases, Values are the resolved trait.
            Of the form: `{'alias' : 'Configurable.trait'}`
        flags : dict
            A dict of flags, keyed by str name. Values can be Config objects
            or dicts. When the flag is triggered, the config is loaded as
            `self.config.update(cfg)`.
        """
        from IPython.config.configurable import Configurable

        self.clear()
        if argv is None:
            argv = self.argv
        if aliases is None:
            aliases = self.aliases
        if flags is None:
            flags = self.flags

        self.extra_args = []

        for item in argv:
            if kv_pattern.match(item):
                lhs,rhs = item.split('=',1)
                # Substitute longnames for aliases.
                if lhs in aliases:
                    lhs = aliases[lhs]
                exec_str = 'self.config.' + lhs + '=' + rhs
                try:
                    # Try to see if regular Python syntax will work. This
                    # won't handle strings as the quote marks are removed
                    # by the system shell.
                    exec exec_str in locals(), globals()
                except (NameError, SyntaxError):
                    # This case happens if the rhs is a string but without
                    # the quote marks.  We add the quote marks and see if
                    # it succeeds. If it still fails, we let it raise.
                    exec_str = 'self.config.' + lhs + '="' + rhs + '"'
                    exec exec_str in locals(), globals()
            elif flag_pattern.match(item):
                # trim leading '--'
                m = item[2:]
                cfg,_ = flags.get(m, (None,None))
                if cfg is None:
                    raise ArgumentError("Unrecognized flag: %r"%item)
                elif isinstance(cfg, (dict, Config)):
                    # don't clobber whole config sections, update
                    # each section from config:
                    for sec,c in cfg.iteritems():
                        self.config[sec].update(c)
                else:
                    raise ValueError("Invalid flag: %r"%item)
            elif item.startswith('-'):
                # this shouldn't ever be valid
                raise ArgumentError("Invalid argument: %r"%item)
            else:
                # keep all args that aren't valid in a list,
                # in case our parent knows what to do with them.
                self.extra_args.append(item)
        return self.config

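# Illustrative sketch, not part of the original module: driving the
# loader with aliases and flags.  Flag values are (config, help) pairs,
# matching the tuple unpacking in load_config() above; the names used
# here are hypothetical.
#
#     aliases = {'profile' : 'Global.profile'}
#     flags = {'debug' : ({'Global' : {'log_level' : 10}}, "set log_level to DEBUG")}
#     loader = KeyValueConfigLoader(aliases=aliases, flags=flags)
#     config = loader.load_config(['profile="foo"', '--debug'])
#     # config.Global.profile == 'foo'; config.Global.log_level == 10
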
class ArgParseConfigLoader(CommandLineConfigLoader):
    """A loader that uses the argparse module to load from the command line."""

    def __init__(self, argv=None, *parser_args, **parser_kw):
        """Create a config loader for use with argparse.

        Parameters
        ----------

        argv : optional, list
            If given, used to read command-line arguments from, otherwise
            sys.argv[1:] is used.

        parser_args : tuple
            A tuple of positional arguments that will be passed to the
            constructor of :class:`argparse.ArgumentParser`.

        parser_kw : dict
            A dict of keyword arguments that will be passed to the
            constructor of :class:`argparse.ArgumentParser`.

        Returns
        -------
        config : Config
            The resulting Config object.
        """
        super(ArgParseConfigLoader, self).__init__()
        if argv is None:
            argv = sys.argv[1:]
        self.argv = argv
        self.parser_args = parser_args
        self.version = parser_kw.pop("version", None)
        kwargs = dict(argument_default=argparse.SUPPRESS)
        kwargs.update(parser_kw)
        self.parser_kw = kwargs

    def load_config(self, argv=None):
        """Parse command line arguments and return as a Config object.

        Parameters
        ----------

        argv : optional, list
            If given, a list with the structure of sys.argv[1:] to parse
            arguments from. If not given, the instance's self.argv attribute
            (given at construction time) is used."""
        self.clear()
        if argv is None:
            argv = self.argv
        self._create_parser()
        self._parse_args(argv)
        self._convert_to_config()
        return self.config

    def get_extra_args(self):
        if hasattr(self, 'extra_args'):
            return self.extra_args
        else:
            return []

    def _create_parser(self):
        self.parser = ArgumentParser(*self.parser_args, **self.parser_kw)
        self._add_arguments()

    def _add_arguments(self):
        raise NotImplementedError("subclasses must implement _add_arguments")

    def _parse_args(self, args):
        """self.parser->self.parsed_data"""
        # decode sys.argv to support unicode command-line options
        uargs = []
        for a in args:
            if isinstance(a, str):
                # don't decode if we already got unicode
                a = a.decode(sys.stdin.encoding or
                             sys.getdefaultencoding())
            uargs.append(a)
        self.parsed_data, self.extra_args = self.parser.parse_known_args(uargs)

    def _convert_to_config(self):
        """self.parsed_data->self.config"""
        for k, v in vars(self.parsed_data).iteritems():
            exec_str = 'self.config.' + k + '= v'
            exec exec_str in locals(), globals()


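# Illustrative sketch, not part of the original module: the minimal
# concrete subclass, supplying the _add_arguments() hook required by
# _create_parser().  Using a dotted dest makes _convert_to_config()
# write the value to the matching config section; MyLoader is made up.
#
#     class MyLoader(ArgParseConfigLoader):
#         def _add_arguments(self):
#             self.parser.add_argument('-n', dest='Global.n', type=int)
#
#     config = MyLoader().load_config(['-n', '4'])
#     # config.Global.n == 4

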
@@ -1,435 +1,436 @@
# encoding: utf-8
"""
An application for IPython.

All top-level applications should use the classes in this module for
handling configuration and creating components.

The job of an :class:`Application` is to create the master configuration
object and then create the configurable objects, passing the config to them.

Authors:

* Brian Granger
* Fernando Perez
* Min RK

Notes
-----
"""

#-----------------------------------------------------------------------------
# Copyright (C) 2008-2009 The IPython Development Team
#
# Distributed under the terms of the BSD License. The full license is in
# the file COPYING, distributed as part of this software.
#-----------------------------------------------------------------------------

#-----------------------------------------------------------------------------
# Imports
#-----------------------------------------------------------------------------

import logging
import os
import shutil
import sys

from IPython.config.application import Application
from IPython.config.configurable import Configurable
from IPython.config.loader import Config
from IPython.core import release, crashhandler
from IPython.utils.path import get_ipython_dir, get_ipython_package_dir, expand_path
from IPython.utils.traitlets import List, Unicode, Type, Bool, Dict

#-----------------------------------------------------------------------------
# Classes and functions
#-----------------------------------------------------------------------------


#-----------------------------------------------------------------------------
# Module errors
#-----------------------------------------------------------------------------

class ProfileDirError(Exception):
    pass


#-----------------------------------------------------------------------------
# Class for managing profile directories
#-----------------------------------------------------------------------------

class ProfileDir(Configurable):
    """An object to manage the profile directory and its resources.

    The profile directory is used by all IPython applications, to manage
    configuration, logging and security.

    This object knows how to find, create and manage these directories. This
    should be used by any code that wants to handle profiles.
    """

    security_dir_name = Unicode('security')
    log_dir_name = Unicode('log')
    pid_dir_name = Unicode('pid')
    security_dir = Unicode(u'')
    log_dir = Unicode(u'')
    pid_dir = Unicode(u'')

    location = Unicode(u'', config=True,
        help="""Set the profile location directly. This overrides the logic used by the
        `profile` option.""",
        )

    _location_isset = Bool(False) # flag for detecting multiply set location

    def _location_changed(self, name, old, new):
        if self._location_isset:
            raise RuntimeError("Cannot set profile location more than once.")
        self._location_isset = True
        if not os.path.isdir(new):
            os.makedirs(new)

        # ensure config files exist:
        self.security_dir = os.path.join(new, self.security_dir_name)
        self.log_dir = os.path.join(new, self.log_dir_name)
        self.pid_dir = os.path.join(new, self.pid_dir_name)
        self.check_dirs()

    def _log_dir_changed(self, name, old, new):
        self.check_log_dir()

    def check_log_dir(self):
        if not os.path.isdir(self.log_dir):
            os.mkdir(self.log_dir)

    def _security_dir_changed(self, name, old, new):
        self.check_security_dir()

    def check_security_dir(self):
        if not os.path.isdir(self.security_dir):
            os.mkdir(self.security_dir, 0700)
        else:
            os.chmod(self.security_dir, 0700)

    def _pid_dir_changed(self, name, old, new):
        self.check_pid_dir()

    def check_pid_dir(self):
        if not os.path.isdir(self.pid_dir):
            os.mkdir(self.pid_dir, 0700)
        else:
            os.chmod(self.pid_dir, 0700)

    def check_dirs(self):
        self.check_security_dir()
        self.check_log_dir()
        self.check_pid_dir()

    def copy_config_file(self, config_file, path=None, overwrite=False):
        """Copy a default config file into the active profile directory.

        Default configuration files are kept in :mod:`IPython.config.default`.
        This function copies these from that location to the working profile
        directory.
        """
        dst = os.path.join(self.location, config_file)
        if os.path.isfile(dst) and not overwrite:
            return
        if path is None:
            path = os.path.join(get_ipython_package_dir(), u'config', u'profile', u'default')
        src = os.path.join(path, config_file)
        shutil.copy(src, dst)

    @classmethod
    def create_profile_dir(cls, profile_dir, config=None):
        """Create a new profile directory given a full path.

        Parameters
        ----------
        profile_dir : str
            The full path to the profile directory.  If it already exists,
            it will be used.  If not, it will be created.
        """
        return cls(location=profile_dir, config=config)

    @classmethod
    def create_profile_dir_by_name(cls, path, name=u'default', config=None):
        """Create a profile dir by profile name and path.

        Parameters
        ----------
        path : unicode
            The path (directory) to put the profile directory in.
        name : unicode
            The name of the profile.  The name of the profile directory will
            be "profile_<name>".
        """
        if not os.path.isdir(path):
            raise ProfileDirError('Directory not found: %s' % path)
        profile_dir = os.path.join(path, u'profile_' + name)
        return cls(location=profile_dir, config=config)

    @classmethod
    def find_profile_dir_by_name(cls, ipython_dir, name=u'default', config=None):
        """Find an existing profile dir by profile name, return its ProfileDir.

        This searches through a sequence of paths for a profile dir.  If it
        is not found, a :class:`ProfileDirError` exception will be raised.

        The search path algorithm is:
        1. ``os.getcwd()``
        2. ``ipython_dir``
        3. The directories found in the ``os.pathsep`` separated
           :env:`IPYTHON_PROFILE_PATH` environment variable.

        Parameters
        ----------
        ipython_dir : unicode or str
            The IPython directory to use.
        name : unicode or str
            The name of the profile.  The name of the profile directory
            will be "profile_<name>".
        """
        dirname = u'profile_' + name
        profile_dir_paths = os.environ.get('IPYTHON_PROFILE_PATH','')
        if profile_dir_paths:
            profile_dir_paths = profile_dir_paths.split(os.pathsep)
        else:
            profile_dir_paths = []
        paths = [os.getcwd(), ipython_dir] + profile_dir_paths
        for p in paths:
            profile_dir = os.path.join(p, dirname)
            if os.path.isdir(profile_dir):
                return cls(location=profile_dir, config=config)
        else:
            raise ProfileDirError('Profile directory not found in paths: %s' % dirname)

    @classmethod
    def find_profile_dir(cls, profile_dir, config=None):
        """Find a profile dir by full path and return its ProfileDir.

        This raises :class:`ProfileDirError` if the directory doesn't exist.

        Parameters
        ----------
        profile_dir : unicode or str
            The path of the profile directory.  This is expanded using
            :func:`IPython.utils.path.expand_path`.
        """
        profile_dir = expand_path(profile_dir)
        if not os.path.isdir(profile_dir):
            raise ProfileDirError('Profile directory not found: %s' % profile_dir)
        return cls(location=profile_dir, config=config)


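# Illustrative sketch, not part of the original module: typical
# ProfileDir use.  The paths and profile name are hypothetical.
#
#     pd = ProfileDir.create_profile_dir_by_name(u'/home/user/.ipython', u'mycluster')
#     pd.copy_config_file(u'ipython_config.py')   # stage the default config
#     pd = ProfileDir.find_profile_dir_by_name(u'/home/user/.ipython', u'mycluster')
#     print pd.location   # /home/user/.ipython/profile_mycluster

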
#-----------------------------------------------------------------------------
# Base Application Class
#-----------------------------------------------------------------------------

# aliases and flags

base_aliases = dict(
    profile='BaseIPythonApplication.profile',
    ipython_dir='BaseIPythonApplication.ipython_dir',
)

base_flags = dict(
    debug = ({'Application' : {'log_level' : logging.DEBUG}},
            "set log level to logging.DEBUG (maximize logging output)"),
    quiet = ({'Application' : {'log_level' : logging.CRITICAL}},
            "set log level to logging.CRITICAL (minimize logging output)"),
    init = ({'BaseIPythonApplication' : {
                    'copy_config_files' : True,
                    'auto_create' : True}
            }, "Initialize profile with default config files")
)


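# Illustrative note, not part of the original module: a flag is shorthand
# for the key=value assignments in its config dict, so (for example):
#
#     ipython --init
#
# is equivalent to:
#
#     ipython BaseIPythonApplication.copy_config_files=True BaseIPythonApplication.auto_create=True
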
class BaseIPythonApplication(Application):

    name = Unicode(u'ipython')
    description = Unicode(u'IPython: an enhanced interactive Python shell.')
    version = Unicode(release.version)

    aliases = Dict(base_aliases)
    flags = Dict(base_flags)

    # Track whether the config_file has changed,
    # because some logic happens only if we aren't using the default.
    config_file_specified = Bool(False)

    config_file_name = Unicode(u'ipython_config.py')
    def _config_file_name_changed(self, name, old, new):
        if new != old:
            self.config_file_specified = True

    # The directory that contains IPython's builtin profiles.
    builtin_profile_dir = Unicode(
        os.path.join(get_ipython_package_dir(), u'config', u'profile', u'default')
    )

    config_file_paths = List(Unicode)
    def _config_file_paths_default(self):
        return [os.getcwdu()]

    profile = Unicode(u'default', config=True,
        help="""The IPython profile to use."""
    )
    def _profile_changed(self, name, old, new):
        self.builtin_profile_dir = os.path.join(
            get_ipython_package_dir(), u'config', u'profile', new
        )


    ipython_dir = Unicode(get_ipython_dir(), config=True,
        help="""
        The name of the IPython directory. This directory is used for logging
        configuration (through profiles), history storage, etc. The default
        is usually $HOME/.ipython. This option can also be specified through
        the environment variable IPYTHON_DIR.
        """
    )

    overwrite = Bool(False, config=True,
        help="""Whether to overwrite existing config files when copying""")
    auto_create = Bool(False, config=True,
        help="""Whether to create profile dir if it doesn't exist""")

    config_files = List(Unicode)
    def _config_files_default(self):
        return [u'ipython_config.py']

    copy_config_files = Bool(False, config=True,
        help="""Whether to copy the default config files into the profile dir.""")

    # The class to use as the crash handler.
    crash_handler_class = Type(crashhandler.CrashHandler)

    def __init__(self, **kwargs):
        super(BaseIPythonApplication, self).__init__(**kwargs)
        # ensure even default IPYTHON_DIR exists
        if not os.path.exists(self.ipython_dir):
            self._ipython_dir_changed('ipython_dir', self.ipython_dir, self.ipython_dir)

    #-------------------------------------------------------------------------
    # Various stages of Application creation
    #-------------------------------------------------------------------------

    def init_crash_handler(self):
        """Create a crash handler, typically setting sys.excepthook to it."""
        self.crash_handler = self.crash_handler_class(self)
        sys.excepthook = self.crash_handler

    def _ipython_dir_changed(self, name, old, new):
        if old in sys.path:
            sys.path.remove(old)
        sys.path.append(os.path.abspath(new))
        if not os.path.isdir(new):
            os.makedirs(new, mode=0777)
        readme = os.path.join(new, 'README')
        if not os.path.exists(readme):
            path = os.path.join(get_ipython_package_dir(), u'config', u'profile')
            shutil.copy(os.path.join(path, 'README'), readme)
        self.log.debug("IPYTHON_DIR set to: %s" % new)

    def load_config_file(self, suppress_errors=True):
        """Load the config file.

        By default, errors in loading config are handled, and a warning
        printed on screen. For testing, the suppress_errors option is set
        to False, so errors will make tests fail.
        """
        self.log.debug("Attempting to load config file: %s" %
                       self.config_file_name)
        try:
            Application.load_config_file(
                self,
                self.config_file_name,
                path=self.config_file_paths
            )
        except IOError:
            # Only warn if the default config file was NOT being used.
            if self.config_file_specified:
                self.log.warn("Config file not found, skipping: %s" %
                              self.config_file_name)
        except:
            # For testing purposes.
            if not suppress_errors:
                raise
            self.log.warn("Error loading config file: %s" %
                          self.config_file_name, exc_info=True)

    def init_profile_dir(self):
        """initialize the profile dir"""
        try:
            # location explicitly specified:
            location = self.config.ProfileDir.location
        except AttributeError:
            # location not specified, find by profile name
            try:
                p = ProfileDir.find_profile_dir_by_name(self.ipython_dir, self.profile, self.config)
            except ProfileDirError:
                # not found, maybe create it (always create default profile)
                if self.auto_create or self.profile=='default':
                    try:
                        p = ProfileDir.create_profile_dir_by_name(self.ipython_dir, self.profile, self.config)
                    except ProfileDirError:
                        self.log.fatal("Could not create profile: %r"%self.profile)
                        self.exit(1)
                    else:
                        self.log.info("Created profile dir: %r"%p.location)
                else:
                    self.log.fatal("Profile %r not found."%self.profile)
                    self.exit(1)
            else:
                self.log.info("Using existing profile dir: %r"%p.location)
        else:
            # location is fully specified
            try:
                p = ProfileDir.find_profile_dir(location, self.config)
            except ProfileDirError:
                # not found, maybe create it
                if self.auto_create:
                    try:
                        p = ProfileDir.create_profile_dir(location, self.config)
                    except ProfileDirError:
                        self.log.fatal("Could not create profile directory: %r"%location)
                        self.exit(1)
                    else:
                        self.log.info("Creating new profile dir: %r"%location)
                else:
                    self.log.fatal("Profile directory %r not found."%location)
                    self.exit(1)
            else:
                self.log.info("Using existing profile dir: %r"%location)

        self.profile_dir = p
        self.config_file_paths.append(p.location)

    def init_config_files(self):
        """[optionally] copy default config files into profile dir."""
        # copy config files
        if self.copy_config_files:
            path = self.builtin_profile_dir
            src = self.profile
            if not os.path.exists(path):
                # use default if new profile doesn't have a preset
                path = None
                src = 'default'

            self.log.debug("Staging %s config files into %r [overwrite=%s]"%(
                    src, self.profile_dir.location, self.overwrite)
            )

            for cfg in self.config_files:
                self.profile_dir.copy_config_file(cfg, path=path, overwrite=self.overwrite)

    def initialize(self, argv=None):
        self.init_crash_handler()
        self.parse_command_line(argv)
        cl_config = self.config
        self.init_profile_dir()
        self.init_config_files()
        self.load_config_file()
        # enforce cl-opts override configfile opts:
        self.update_config(cl_config)

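# Illustrative sketch, not part of the original module: a minimal
# concrete application.  initialize() runs the stages above in order:
# crash handler, command line, profile dir, config files, config file
# loading.  MyApp and its traits are hypothetical.
#
#     class MyApp(BaseIPythonApplication):
#         name = Unicode(u'myapp')
#         config_file_name = Unicode(u'myapp_config.py')
#
#     app = MyApp()
#     app.initialize()    # parses sys.argv[1:] and loads myapp_config.py
#     print app.profile_dir.location
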
@@ -1,32 +1,37 @@
"""The IPython ZMQ-based parallel computing interface.

Authors:

* MinRK
"""
#-----------------------------------------------------------------------------
# Copyright (C) 2011 The IPython Development Team
#
# Distributed under the terms of the BSD License. The full license is in
# the file COPYING, distributed as part of this software.
#-----------------------------------------------------------------------------

#-----------------------------------------------------------------------------
# Imports
#-----------------------------------------------------------------------------

import os
import zmq


if os.name == 'nt':
    if zmq.__version__ < '2.1.7':
        raise ImportError("IPython.parallel requires pyzmq/0MQ >= 2.1.7 on Windows, "
                          "and you appear to have %s"%zmq.__version__)
elif zmq.__version__ < '2.1.4':
    raise ImportError("IPython.parallel requires pyzmq/0MQ >= 2.1.4, you appear to have %s"%zmq.__version__)

from IPython.utils.pickleutil import Reference

from .client.asyncresult import *
from .client.client import Client
from .client.remotefunction import *
from .client.view import *
from .controller.dependency import *


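# Illustrative sketch, not part of the original package: the most common
# entry point exposed here.  This assumes a cluster is already running
# (e.g. started with `ipcluster start`).
#
#     from IPython.parallel import Client
#     rc = Client()          # connect using the default profile
#     view = rc[:]           # a DirectView on all engines
#     squares = view.map_sync(lambda x: x**2, range(8))
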
@@ -1,257 +1,263 @@
1 #!/usr/bin/env python
1 #!/usr/bin/env python
2 # encoding: utf-8
2 # encoding: utf-8
3 """
3 """
4 The IPython cluster directory
4 The Base Application class for IPython.parallel apps
5
6 Authors:
7
8 * Brian Granger
9 * Min RK
10
5 """
11 """
6
12
7 #-----------------------------------------------------------------------------
13 #-----------------------------------------------------------------------------
8 # Copyright (C) 2008-2009 The IPython Development Team
14 # Copyright (C) 2008-2011 The IPython Development Team
9 #
15 #
10 # Distributed under the terms of the BSD License. The full license is in
16 # Distributed under the terms of the BSD License. The full license is in
11 # the file COPYING, distributed as part of this software.
17 # the file COPYING, distributed as part of this software.
12 #-----------------------------------------------------------------------------
18 #-----------------------------------------------------------------------------
13
19
14 #-----------------------------------------------------------------------------
20 #-----------------------------------------------------------------------------
15 # Imports
21 # Imports
16 #-----------------------------------------------------------------------------
22 #-----------------------------------------------------------------------------
17
23
18 from __future__ import with_statement
24 from __future__ import with_statement
19
25
20 import os
26 import os
21 import logging
27 import logging
22 import re
28 import re
23 import sys
29 import sys
24
30
25 from subprocess import Popen, PIPE
31 from subprocess import Popen, PIPE
26
32
27 from IPython.core import release
33 from IPython.core import release
28 from IPython.core.crashhandler import CrashHandler
34 from IPython.core.crashhandler import CrashHandler
29 from IPython.core.newapplication import (
35 from IPython.core.newapplication import (
30 BaseIPythonApplication,
36 BaseIPythonApplication,
31 base_aliases as base_ip_aliases,
37 base_aliases as base_ip_aliases,
32 base_flags as base_ip_flags
38 base_flags as base_ip_flags
33 )
39 )
34 from IPython.utils.path import expand_path
40 from IPython.utils.path import expand_path
35
41
36 from IPython.utils.traitlets import Unicode, Bool, Instance, Dict, List
42 from IPython.utils.traitlets import Unicode, Bool, Instance, Dict, List
37
43
38 #-----------------------------------------------------------------------------
44 #-----------------------------------------------------------------------------
39 # Module errors
45 # Module errors
40 #-----------------------------------------------------------------------------
46 #-----------------------------------------------------------------------------
41
47
42 class PIDFileError(Exception):
48 class PIDFileError(Exception):
43 pass
49 pass
44
50
45
51
46 #-----------------------------------------------------------------------------
52 #-----------------------------------------------------------------------------
47 # Crash handler for this application
53 # Crash handler for this application
48 #-----------------------------------------------------------------------------
54 #-----------------------------------------------------------------------------
49
55
50
56
51 _message_template = """\
57 _message_template = """\
52 Oops, $self.app_name crashed. We do our best to make it stable, but...
58 Oops, $self.app_name crashed. We do our best to make it stable, but...
53
59
54 A crash report was automatically generated with the following information:
60 A crash report was automatically generated with the following information:
55 - A verbatim copy of the crash traceback.
61 - A verbatim copy of the crash traceback.
56 - Data on your current $self.app_name configuration.
62 - Data on your current $self.app_name configuration.
57
63
58 It was left in the file named:
64 It was left in the file named:
59 \t'$self.crash_report_fname'
65 \t'$self.crash_report_fname'
60 If you can email this file to the developers, the information in it will help
66 If you can email this file to the developers, the information in it will help
61 them in understanding and correcting the problem.
67 them in understanding and correcting the problem.
62
68
63 You can mail it to: $self.contact_name at $self.contact_email
69 You can mail it to: $self.contact_name at $self.contact_email
64 with the subject '$self.app_name Crash Report'.
70 with the subject '$self.app_name Crash Report'.
65
71
66 If you want to do it now, the following command will work (under Unix):
72 If you want to do it now, the following command will work (under Unix):
67 mail -s '$self.app_name Crash Report' $self.contact_email < $self.crash_report_fname
73 mail -s '$self.app_name Crash Report' $self.contact_email < $self.crash_report_fname
68
74
69 To ensure accurate tracking of this issue, please file a report about it at:
75 To ensure accurate tracking of this issue, please file a report about it at:
70 $self.bug_tracker
76 $self.bug_tracker
71 """
77 """
72
78
73 class ParallelCrashHandler(CrashHandler):
79 class ParallelCrashHandler(CrashHandler):
74 """sys.excepthook for IPython itself, leaves a detailed report on disk."""
80 """sys.excepthook for IPython itself, leaves a detailed report on disk."""
75
81
76 message_template = _message_template
82 message_template = _message_template
77
83
78 def __init__(self, app):
84 def __init__(self, app):
79 contact_name = release.authors['Min'][0]
85 contact_name = release.authors['Min'][0]
80 contact_email = release.authors['Min'][1]
86 contact_email = release.authors['Min'][1]
81 bug_tracker = 'http://github.com/ipython/ipython/issues'
87 bug_tracker = 'http://github.com/ipython/ipython/issues'
82 super(ParallelCrashHandler,self).__init__(
88 super(ParallelCrashHandler,self).__init__(
83 app, contact_name, contact_email, bug_tracker
89 app, contact_name, contact_email, bug_tracker
84 )
90 )
85
91
86
92
87 #-----------------------------------------------------------------------------
93 #-----------------------------------------------------------------------------
88 # Main application
94 # Main application
89 #-----------------------------------------------------------------------------
95 #-----------------------------------------------------------------------------
90 base_aliases = {}
96 base_aliases = {}
91 base_aliases.update(base_ip_aliases)
97 base_aliases.update(base_ip_aliases)
92 base_aliases.update({
98 base_aliases.update({
93 'profile_dir' : 'ProfileDir.location',
99 'profile_dir' : 'ProfileDir.location',
94 'log_level' : 'BaseParallelApplication.log_level',
100 'log_level' : 'BaseParallelApplication.log_level',
95 'work_dir' : 'BaseParallelApplication.work_dir',
101 'work_dir' : 'BaseParallelApplication.work_dir',
96 'log_to_file' : 'BaseParallelApplication.log_to_file',
102 'log_to_file' : 'BaseParallelApplication.log_to_file',
97 'clean_logs' : 'BaseParallelApplication.clean_logs',
103 'clean_logs' : 'BaseParallelApplication.clean_logs',
98 'log_url' : 'BaseParallelApplication.log_url',
104 'log_url' : 'BaseParallelApplication.log_url',
99 })
105 })
100
106
101 base_flags = {
107 base_flags = {
102 'log-to-file' : (
108 'log-to-file' : (
103 {'BaseParallelApplication' : {'log_to_file' : True}},
109 {'BaseParallelApplication' : {'log_to_file' : True}},
104 "send log output to a file"
110 "send log output to a file"
105 )
111 )
106 }
112 }
107 base_flags.update(base_ip_flags)
113 base_flags.update(base_ip_flags)
108
114
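A hedged sketch of how these alias and flag tables are consumed from the command line, assuming the KeyValueConfigLoader(argv, aliases, flags) signature from IPython.config.loader; the argv values are illustrative:

    from IPython.config.loader import KeyValueConfigLoader

    loader = KeyValueConfigLoader(argv=['log_level=10', '--log-to-file'],
                                  aliases=base_aliases, flags=base_flags)
    config = loader.load_config()
    print config.BaseParallelApplication.log_level    # 10
    print config.BaseParallelApplication.log_to_file  # True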
109 class BaseParallelApplication(BaseIPythonApplication):
115 class BaseParallelApplication(BaseIPythonApplication):
110 """The base Application for IPython.parallel apps
116 """The base Application for IPython.parallel apps
111
117
112 Principal extensions to BaseIPythonApplication:
118 Principal extensions to BaseIPythonApplication:
113
119
114 * work_dir
120 * work_dir
115 * remote logging via pyzmq
121 * remote logging via pyzmq
116 * IOLoop instance
122 * IOLoop instance
117 """
123 """
118
124
119 crash_handler_class = ParallelCrashHandler
125 crash_handler_class = ParallelCrashHandler
120
126
121 def _log_level_default(self):
127 def _log_level_default(self):
122 # temporarily override default_log_level to INFO
128 # temporarily override default_log_level to INFO
123 return logging.INFO
129 return logging.INFO
124
130
125 work_dir = Unicode(os.getcwdu(), config=True,
131 work_dir = Unicode(os.getcwdu(), config=True,
126 help='Set the working dir for the process.'
132 help='Set the working dir for the process.'
127 )
133 )
128 def _work_dir_changed(self, name, old, new):
134 def _work_dir_changed(self, name, old, new):
129 self.work_dir = unicode(expand_path(new))
135 self.work_dir = unicode(expand_path(new))
130
136
131 log_to_file = Bool(config=True,
137 log_to_file = Bool(config=True,
132 help="whether to log to a file")
138 help="whether to log to a file")
133
139
134 clean_logs = Bool(False, config=True,
140 clean_logs = Bool(False, config=True,
135 help="whether to cleanup old logfiles before starting")
141 help="whether to cleanup old logfiles before starting")
136
142
137 log_url = Unicode('', config=True,
143 log_url = Unicode('', config=True,
138 help="The ZMQ URL of the iplogger to aggregate logging.")
144 help="The ZMQ URL of the iplogger to aggregate logging.")
139
145
140 def _config_files_default(self):
146 def _config_files_default(self):
141 return ['ipcontroller_config.py', 'ipengine_config.py', 'ipcluster_config.py']
147 return ['ipcontroller_config.py', 'ipengine_config.py', 'ipcluster_config.py']
142
148
143 loop = Instance('zmq.eventloop.ioloop.IOLoop')
149 loop = Instance('zmq.eventloop.ioloop.IOLoop')
144 def _loop_default(self):
150 def _loop_default(self):
145 from zmq.eventloop.ioloop import IOLoop
151 from zmq.eventloop.ioloop import IOLoop
146 return IOLoop.instance()
152 return IOLoop.instance()
147
153
148 aliases = Dict(base_aliases)
154 aliases = Dict(base_aliases)
149 flags = Dict(base_flags)
155 flags = Dict(base_flags)
150
156
151 def initialize(self, argv=None):
157 def initialize(self, argv=None):
152 """initialize the app"""
158 """initialize the app"""
153 super(BaseParallelApplication, self).initialize(argv)
159 super(BaseParallelApplication, self).initialize(argv)
154 self.to_work_dir()
160 self.to_work_dir()
155 self.reinit_logging()
161 self.reinit_logging()
156
162
157 def to_work_dir(self):
163 def to_work_dir(self):
158 wd = self.work_dir
164 wd = self.work_dir
159 if unicode(wd) != os.getcwdu():
165 if unicode(wd) != os.getcwdu():
160 os.chdir(wd)
166 os.chdir(wd)
161 self.log.info("Changing to working dir: %s" % wd)
167 self.log.info("Changing to working dir: %s" % wd)
162 # This is the working dir by now.
168 # This is the working dir by now.
163 sys.path.insert(0, '')
169 sys.path.insert(0, '')
164
170
165 def reinit_logging(self):
171 def reinit_logging(self):
166 # Remove old log files
172 # Remove old log files
167 log_dir = self.profile_dir.log_dir
173 log_dir = self.profile_dir.log_dir
168 if self.clean_logs:
174 if self.clean_logs:
169 for f in os.listdir(log_dir):
175 for f in os.listdir(log_dir):
170 if re.match(r'%s-\d+\.(log|err|out)'%self.name,f):
176 if re.match(r'%s-\d+\.(log|err|out)'%self.name,f):
171 os.remove(os.path.join(log_dir, f))
177 os.remove(os.path.join(log_dir, f))
172 if self.log_to_file:
178 if self.log_to_file:
173 # Start logging to the new log file
179 # Start logging to the new log file
174 log_filename = self.name + u'-' + str(os.getpid()) + u'.log'
180 log_filename = self.name + u'-' + str(os.getpid()) + u'.log'
175 logfile = os.path.join(log_dir, log_filename)
181 logfile = os.path.join(log_dir, log_filename)
176 open_log_file = open(logfile, 'w')
182 open_log_file = open(logfile, 'w')
177 else:
183 else:
178 open_log_file = None
184 open_log_file = None
179 if open_log_file is not None:
185 if open_log_file is not None:
180 self.log.removeHandler(self._log_handler)
186 self.log.removeHandler(self._log_handler)
181 self._log_handler = logging.StreamHandler(open_log_file)
187 self._log_handler = logging.StreamHandler(open_log_file)
182 self._log_formatter = logging.Formatter("[%(name)s] %(message)s")
188 self._log_formatter = logging.Formatter("[%(name)s] %(message)s")
183 self._log_handler.setFormatter(self._log_formatter)
189 self._log_handler.setFormatter(self._log_formatter)
184 self.log.addHandler(self._log_handler)
190 self.log.addHandler(self._log_handler)
185
191
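For reference, a standalone sketch of the per-process log naming used in reinit_logging, which is exactly what the cleanup regex matches (the app name is illustrative):

    import os

    name = 'ipcluster'  # illustrative; self.name in the app
    log_filename = name + u'-' + str(os.getpid()) + u'.log'
    print log_filename  # e.g. ipcluster-12345.log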
186 def write_pid_file(self, overwrite=False):
192 def write_pid_file(self, overwrite=False):
187 """Create a .pid file in the pid_dir with my pid.
193 """Create a .pid file in the pid_dir with my pid.
188
194
189 This must be called after the profile dir is set up, which provides `self.profile_dir.pid_dir`.
195 This must be called after the profile dir is set up, which provides `self.profile_dir.pid_dir`.
190 This raises :exc:`PIDFileError` if the pid file exists already, unless `overwrite` is True.
196 This raises :exc:`PIDFileError` if the pid file exists already, unless `overwrite` is True.
191 """
197 """
192 pid_file = os.path.join(self.profile_dir.pid_dir, self.name + u'.pid')
198 pid_file = os.path.join(self.profile_dir.pid_dir, self.name + u'.pid')
193 if os.path.isfile(pid_file):
199 if os.path.isfile(pid_file):
194 pid = self.get_pid_from_file()
200 pid = self.get_pid_from_file()
195 if not overwrite:
201 if not overwrite:
196 raise PIDFileError(
202 raise PIDFileError(
197 'The pid file [%s] already exists. \nThis could mean that this '
203 'The pid file [%s] already exists. \nThis could mean that this '
198 'server is already running with [pid=%s].' % (pid_file, pid)
204 'server is already running with [pid=%s].' % (pid_file, pid)
199 )
205 )
200 with open(pid_file, 'w') as f:
206 with open(pid_file, 'w') as f:
201 self.log.info("Creating pid file: %s" % pid_file)
207 self.log.info("Creating pid file: %s" % pid_file)
202 f.write(repr(os.getpid())+'\n')
208 f.write(repr(os.getpid())+'\n')
203
209
204 def remove_pid_file(self):
210 def remove_pid_file(self):
205 """Remove the pid file.
211 """Remove the pid file.
206
212
207 This should be called at shutdown, either from a registered callback
213 This should be called at shutdown, either from a registered callback
208 or from a ``finally`` clause around the event loop. This needs to return
214 or from a ``finally`` clause around the event loop. This needs to return
209 ``None``.
215 ``None``.
210 """
216 """
211 pid_file = os.path.join(self.profile_dir.pid_dir, self.name + u'.pid')
217 pid_file = os.path.join(self.profile_dir.pid_dir, self.name + u'.pid')
212 if os.path.isfile(pid_file):
218 if os.path.isfile(pid_file):
213 try:
219 try:
214 self.log.info("Removing pid file: %s" % pid_file)
220 self.log.info("Removing pid file: %s" % pid_file)
215 os.remove(pid_file)
221 os.remove(pid_file)
216 except OSError:
222 except OSError:
217 self.log.warn("Error removing the pid file: %s" % pid_file)
223 self.log.warn("Error removing the pid file: %s" % pid_file)
218
224
219 def get_pid_from_file(self):
225 def get_pid_from_file(self):
220 """Get the pid from the pid file.
226 """Get the pid from the pid file.
221
227
222 If the pid file doesn't exist a :exc:`PIDFileError` is raised.
228 If the pid file doesn't exist a :exc:`PIDFileError` is raised.
223 """
229 """
224 pid_file = os.path.join(self.profile_dir.pid_dir, self.name + u'.pid')
230 pid_file = os.path.join(self.profile_dir.pid_dir, self.name + u'.pid')
225 if os.path.isfile(pid_file):
231 if os.path.isfile(pid_file):
226 with open(pid_file, 'r') as f:
232 with open(pid_file, 'r') as f:
227 pid = int(f.read().strip())
233 pid = int(f.read().strip())
228 return pid
234 return pid
229 else:
235 else:
230 raise PIDFileError('pid file not found: %s' % pid_file)
236 raise PIDFileError('pid file not found: %s' % pid_file)
231
237
232 def check_pid(self, pid):
238 def check_pid(self, pid):
233 if os.name == 'nt':
239 if os.name == 'nt':
234 try:
240 try:
235 import ctypes
241 import ctypes
236 # returns 0 if no such process (of ours) exists
242 # returns 0 if no such process (of ours) exists
237 # positive int otherwise
243 # positive int otherwise
238 p = ctypes.windll.kernel32.OpenProcess(1,0,pid)
244 p = ctypes.windll.kernel32.OpenProcess(1,0,pid)
239 except Exception:
245 except Exception:
240 self.log.warn(
246 self.log.warn(
241 "Could not determine whether pid %i is running via `OpenProcess`. "
247 "Could not determine whether pid %i is running via `OpenProcess`. "
242 " Making the likely assumption that it is."%pid
248 " Making the likely assumption that it is."%pid
243 )
249 )
244 return True
250 return True
245 return bool(p)
251 return bool(p)
246 else:
252 else:
247 try:
253 try:
248 p = Popen(['ps','x'], stdout=PIPE, stderr=PIPE)
254 p = Popen(['ps','x'], stdout=PIPE, stderr=PIPE)
249 output,_ = p.communicate()
255 output,_ = p.communicate()
250 except OSError:
256 except OSError:
251 self.log.warn(
257 self.log.warn(
252 "Could not determine whether pid %i is running via `ps x`. "
258 "Could not determine whether pid %i is running via `ps x`. "
253 " Making the likely assumption that it is."%pid
259 " Making the likely assumption that it is."%pid
254 )
260 )
255 return True
261 return True
256 pids = map(int, re.findall(r'^\W*\d+', output, re.MULTILINE))
262 pids = map(int, re.findall(r'^\W*\d+', output, re.MULTILINE))
257 return pid in pids
263 return pid in pids
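Taken together, the pid-file methods above form a simple write/read/verify round trip. A standalone sketch, assuming a fixed path (the real code derives it from profile_dir.pid_dir):

    import os

    pid_file = '/tmp/ipcluster.pid'      # hypothetical path
    with open(pid_file, 'w') as f:       # mirrors write_pid_file
        f.write(repr(os.getpid()) + '\n')

    with open(pid_file, 'r') as f:       # mirrors get_pid_from_file
        pid = int(f.read().strip())

    assert pid == os.getpid()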
@@ -1,521 +1,527 b''
1 #!/usr/bin/env python
1 #!/usr/bin/env python
2 # encoding: utf-8
2 # encoding: utf-8
3 """
3 """
4 The ipcluster application.
4 The ipcluster application.
5
6 Authors:
7
8 * Brian Granger
9 * MinRK
10
5 """
11 """
6
12
7 #-----------------------------------------------------------------------------
13 #-----------------------------------------------------------------------------
8 # Copyright (C) 2008-2009 The IPython Development Team
14 # Copyright (C) 2008-2011 The IPython Development Team
9 #
15 #
10 # Distributed under the terms of the BSD License. The full license is in
16 # Distributed under the terms of the BSD License. The full license is in
11 # the file COPYING, distributed as part of this software.
17 # the file COPYING, distributed as part of this software.
12 #-----------------------------------------------------------------------------
18 #-----------------------------------------------------------------------------
13
19
14 #-----------------------------------------------------------------------------
20 #-----------------------------------------------------------------------------
15 # Imports
21 # Imports
16 #-----------------------------------------------------------------------------
22 #-----------------------------------------------------------------------------
17
23
18 import errno
24 import errno
19 import logging
25 import logging
20 import os
26 import os
21 import re
27 import re
22 import signal
28 import signal
23
29
24 from subprocess import check_call, CalledProcessError, PIPE
30 from subprocess import check_call, CalledProcessError, PIPE
25 import zmq
31 import zmq
26 from zmq.eventloop import ioloop
32 from zmq.eventloop import ioloop
27
33
28 from IPython.config.application import Application, boolean_flag
34 from IPython.config.application import Application, boolean_flag
29 from IPython.config.loader import Config
35 from IPython.config.loader import Config
30 from IPython.core.newapplication import BaseIPythonApplication, ProfileDir
36 from IPython.core.newapplication import BaseIPythonApplication, ProfileDir
31 from IPython.utils.importstring import import_item
37 from IPython.utils.importstring import import_item
32 from IPython.utils.traitlets import Int, Unicode, Bool, CFloat, Dict, List
38 from IPython.utils.traitlets import Int, Unicode, Bool, CFloat, Dict, List
33
39
34 from IPython.parallel.apps.baseapp import (
40 from IPython.parallel.apps.baseapp import (
35 BaseParallelApplication,
41 BaseParallelApplication,
36 PIDFileError,
42 PIDFileError,
37 base_flags, base_aliases
43 base_flags, base_aliases
38 )
44 )
39
45
40
46
41 #-----------------------------------------------------------------------------
47 #-----------------------------------------------------------------------------
42 # Module level variables
48 # Module level variables
43 #-----------------------------------------------------------------------------
49 #-----------------------------------------------------------------------------
44
50
45
51
46 default_config_file_name = u'ipcluster_config.py'
52 default_config_file_name = u'ipcluster_config.py'
47
53
48
54
49 _description = """Start an IPython cluster for parallel computing.
55 _description = """Start an IPython cluster for parallel computing.
50
56
51 An IPython cluster consists of 1 controller and 1 or more engines.
57 An IPython cluster consists of 1 controller and 1 or more engines.
52 This command automates the startup of these processes using a wide
58 This command automates the startup of these processes using a wide
53 range of startup methods (SSH, local processes, PBS, mpiexec,
59 range of startup methods (SSH, local processes, PBS, mpiexec,
54 Windows HPC Server 2008). To start a cluster with 4 engines on your
60 Windows HPC Server 2008). To start a cluster with 4 engines on your
55 local host simply do 'ipcluster start n=4'. For more complex usage
61 local host simply do 'ipcluster start n=4'. For more complex usage
56 you will typically do 'ipcluster create profile=mycluster', then edit
62 you will typically do 'ipcluster create profile=mycluster', then edit
57 configuration files, followed by 'ipcluster start profile=mycluster n=4'.
63 configuration files, followed by 'ipcluster start profile=mycluster n=4'.
58 """
64 """
59
65
60
66
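The configuration files mentioned above are plain Python files. A hedged sketch of a few ipcluster_config.py settings corresponding to traits defined later in this file; the launcher name and values are illustrative:

    c = get_config()

    c.IPClusterEngines.n = 8
    c.IPClusterStart.delay = 2.0
    c.IPClusterEngines.engine_launcher_class = 'SSHEngineSetLauncher'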
61 # Exit codes for ipcluster
67 # Exit codes for ipcluster
62
68
63 # This will be the exit code if the ipcluster appears to be running because
69 # This will be the exit code if the ipcluster appears to be running because
64 # a .pid file exists
70 # a .pid file exists
65 ALREADY_STARTED = 10
71 ALREADY_STARTED = 10
66
72
67
73
68 # This will be the exit code if ipcluster stop is run, but there is no .pid
74 # This will be the exit code if ipcluster stop is run, but there is no .pid
69 # file to be found.
75 # file to be found.
70 ALREADY_STOPPED = 11
76 ALREADY_STOPPED = 11
71
77
72 # This will be the exit code if ipcluster engines is run, but there is no .pid
78 # This will be the exit code if ipcluster engines is run, but there is no .pid
73 # file to be found.
79 # file to be found.
74 NO_CLUSTER = 12
80 NO_CLUSTER = 12
75
81
76
82
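An external wrapper could branch on these exit codes. A hedged sketch (the subprocess invocation is illustrative, not part of this module):

    import subprocess

    rc = subprocess.call(['ipcluster', 'stop', 'profile=mycluster'])
    if rc == 11:  # ALREADY_STOPPED
        print "no cluster was running for profile 'mycluster'"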
77 #-----------------------------------------------------------------------------
83 #-----------------------------------------------------------------------------
78 # Main application
84 # Main application
79 #-----------------------------------------------------------------------------
85 #-----------------------------------------------------------------------------
80 start_help = """Start an IPython cluster for parallel computing
86 start_help = """Start an IPython cluster for parallel computing
81
87
82 Start an ipython cluster by its profile name or cluster
88 Start an ipython cluster by its profile name or cluster
83 directory. Cluster directories contain configuration, log and
89 directory. Cluster directories contain configuration, log and
84 security related files and are named using the convention
90 security related files and are named using the convention
85 'cluster_<profile>' and should be created using the 'create'
91 'cluster_<profile>' and should be created using the 'create'
86 subcommand of 'ipcluster'. If your cluster directory is in
92 subcommand of 'ipcluster'. If your cluster directory is in
87 the cwd or the ipython directory, you can simply refer to it
93 the cwd or the ipython directory, you can simply refer to it
88 using its profile name, 'ipcluster start n=4 profile=<profile>',
94 using its profile name, 'ipcluster start n=4 profile=<profile>',
89 otherwise use the 'profile_dir' option.
95 otherwise use the 'profile_dir' option.
90 """
96 """
91 stop_help = """Stop a running IPython cluster
97 stop_help = """Stop a running IPython cluster
92
98
93 Stop a running ipython cluster by its profile name or cluster
99 Stop a running ipython cluster by its profile name or cluster
94 directory. Cluster directories are named using the convention
100 directory. Cluster directories are named using the convention
95 'cluster_<profile>'. If your cluster directory is in
101 'cluster_<profile>'. If your cluster directory is in
96 the cwd or the ipython directory, you can simply refer to it
102 the cwd or the ipython directory, you can simply refer to it
97 using its profile name, 'ipcluster stop profile=<profile>', otherwise
103 using its profile name, 'ipcluster stop profile=<profile>', otherwise
98 use the 'profile_dir' option.
104 use the 'profile_dir' option.
99 """
105 """
100 engines_help = """Start engines connected to an existing IPython cluster
106 engines_help = """Start engines connected to an existing IPython cluster
101
107
102 Start one or more engines to connect to an existing Cluster
108 Start one or more engines to connect to an existing Cluster
103 by profile name or cluster directory.
109 by profile name or cluster directory.
104 Cluster directories contain configuration, log and
110 Cluster directories contain configuration, log and
105 security related files and are named using the convention
111 security related files and are named using the convention
106 'cluster_<profile>' and should be created using the 'create'
112 'cluster_<profile>' and should be created using the 'create'
107 subcommand of 'ipcluster'. If your cluster directory is in
113 subcommand of 'ipcluster'. If your cluster directory is in
108 the cwd or the ipython directory, you can simply refer to it
114 the cwd or the ipython directory, you can simply refer to it
109 using its profile name, 'ipcluster engines n=4 profile=<profile>',
115 using its profile name, 'ipcluster engines n=4 profile=<profile>',
110 otherwise use the 'profile_dir' option.
116 otherwise use the 'profile_dir' option.
111 """
117 """
112 create_help = """Create an ipcluster profile by name
118 create_help = """Create an ipcluster profile by name
113
119
114 Create an ipython cluster directory by its profile name or
120 Create an ipython cluster directory by its profile name or
115 cluster directory path. Cluster directories contain
121 cluster directory path. Cluster directories contain
116 configuration, log and security related files and are named
122 configuration, log and security related files and are named
117 using the convention 'cluster_<profile>'. By default they are
123 using the convention 'cluster_<profile>'. By default they are
118 located in your ipython directory. Once created, you will
124 located in your ipython directory. Once created, you will
119 probably need to edit the configuration files in the cluster
125 probably need to edit the configuration files in the cluster
120 directory to configure your cluster. Most users will create a
126 directory to configure your cluster. Most users will create a
121 cluster directory by profile name,
127 cluster directory by profile name,
122 `ipcluster create profile=mycluster`, which will put the directory
128 `ipcluster create profile=mycluster`, which will put the directory
123 in `<ipython_dir>/cluster_mycluster`.
129 in `<ipython_dir>/cluster_mycluster`.
124 """
130 """
125 list_help = """List available cluster profiles
131 list_help = """List available cluster profiles
126
132
127 List all available clusters, by cluster directory, that can
133 List all available clusters, by cluster directory, that can
128 be found in the current working directory or in the ipython
134 be found in the current working directory or in the ipython
129 directory. Cluster directories are named using the convention
135 directory. Cluster directories are named using the convention
130 'cluster_<profile>'.
136 'cluster_<profile>'.
131 """
137 """
132
138
133
139
134 class IPClusterList(BaseIPythonApplication):
140 class IPClusterList(BaseIPythonApplication):
135 name = u'ipcluster-list'
141 name = u'ipcluster-list'
136 description = list_help
142 description = list_help
137
143
138 # empty aliases
144 # empty aliases
139 aliases=Dict()
145 aliases=Dict()
140 flags = Dict(base_flags)
146 flags = Dict(base_flags)
141
147
142 def _log_level_default(self):
148 def _log_level_default(self):
143 return 20
149 return 20
144
150
145 def list_profile_dirs(self):
151 def list_profile_dirs(self):
146 # Find the search paths
152 # Find the search paths
147 profile_dir_paths = os.environ.get('IPYTHON_PROFILE_PATH','')
153 profile_dir_paths = os.environ.get('IPYTHON_PROFILE_PATH','')
148 if profile_dir_paths:
154 if profile_dir_paths:
149 profile_dir_paths = profile_dir_paths.split(':')
155 profile_dir_paths = profile_dir_paths.split(':')
150 else:
156 else:
151 profile_dir_paths = []
157 profile_dir_paths = []
152
158
153 ipython_dir = self.ipython_dir
159 ipython_dir = self.ipython_dir
154
160
155 paths = [os.getcwd(), ipython_dir] + profile_dir_paths
161 paths = [os.getcwd(), ipython_dir] + profile_dir_paths
156 paths = list(set(paths))
162 paths = list(set(paths))
157
163
158 self.log.info('Searching for cluster profiles in paths: %r' % paths)
164 self.log.info('Searching for cluster profiles in paths: %r' % paths)
159 for path in paths:
165 for path in paths:
160 files = os.listdir(path)
166 files = os.listdir(path)
161 for f in files:
167 for f in files:
162 full_path = os.path.join(path, f)
168 full_path = os.path.join(path, f)
163 if os.path.isdir(full_path) and f.startswith('profile_') and \
169 if os.path.isdir(full_path) and f.startswith('profile_') and \
164 os.path.isfile(os.path.join(full_path, 'ipcontroller_config.py')):
170 os.path.isfile(os.path.join(full_path, 'ipcontroller_config.py')):
165 profile = f.split('_')[-1]
171 profile = f.split('_')[-1]
166 start_cmd = 'ipcluster start profile=%s n=4' % profile
172 start_cmd = 'ipcluster start profile=%s n=4' % profile
167 print start_cmd + " ==> " + full_path
173 print start_cmd + " ==> " + full_path
168
174
169 def start(self):
175 def start(self):
170 self.list_profile_dirs()
176 self.list_profile_dirs()
171
177
172
178
173 # `ipcluster create` will be deprecated when `ipython profile create` or equivalent exists
179 # `ipcluster create` will be deprecated when `ipython profile create` or equivalent exists
174
180
175 create_flags = {}
181 create_flags = {}
176 create_flags.update(base_flags)
182 create_flags.update(base_flags)
177 create_flags.update(boolean_flag('reset', 'IPClusterCreate.overwrite',
183 create_flags.update(boolean_flag('reset', 'IPClusterCreate.overwrite',
178 "reset config files to defaults", "leave existing config files"))
184 "reset config files to defaults", "leave existing config files"))
179
185
180 class IPClusterCreate(BaseParallelApplication):
186 class IPClusterCreate(BaseParallelApplication):
181 name = u'ipcluster-create'
187 name = u'ipcluster-create'
182 description = create_help
188 description = create_help
183 auto_create = Bool(True)
189 auto_create = Bool(True)
184 config_file_name = Unicode(default_config_file_name)
190 config_file_name = Unicode(default_config_file_name)
185
191
186 flags = Dict(create_flags)
192 flags = Dict(create_flags)
187
193
188 aliases = Dict(dict(profile='BaseIPythonApplication.profile'))
194 aliases = Dict(dict(profile='BaseIPythonApplication.profile'))
189
195
190 classes = [ProfileDir]
196 classes = [ProfileDir]
191
197
192
198
193 stop_aliases = dict(
199 stop_aliases = dict(
194 signal='IPClusterStop.signal',
200 signal='IPClusterStop.signal',
195 profile='BaseIPythonApplication.profile',
201 profile='BaseIPythonApplication.profile',
196 profile_dir='ProfileDir.location',
202 profile_dir='ProfileDir.location',
197 )
203 )
198
204
199 class IPClusterStop(BaseParallelApplication):
205 class IPClusterStop(BaseParallelApplication):
200 name = u'ipcluster'
206 name = u'ipcluster'
201 description = stop_help
207 description = stop_help
202 config_file_name = Unicode(default_config_file_name)
208 config_file_name = Unicode(default_config_file_name)
203
209
204 signal = Int(signal.SIGINT, config=True,
210 signal = Int(signal.SIGINT, config=True,
205 help="signal to use for stopping processes.")
211 help="signal to use for stopping processes.")
206
212
207 aliases = Dict(stop_aliases)
213 aliases = Dict(stop_aliases)
208
214
209 def start(self):
215 def start(self):
210 """Start the app for the stop subcommand."""
216 """Start the app for the stop subcommand."""
211 try:
217 try:
212 pid = self.get_pid_from_file()
218 pid = self.get_pid_from_file()
213 except PIDFileError:
219 except PIDFileError:
214 self.log.critical(
220 self.log.critical(
215 'Could not read pid file, cluster is probably not running.'
221 'Could not read pid file, cluster is probably not running.'
216 )
222 )
217 # Here I exit with an unusual exit status that other processes
223 # Here I exit with an unusual exit status that other processes
218 # can watch for to learn how I exited.
224 # can watch for to learn how I exited.
219 self.remove_pid_file()
225 self.remove_pid_file()
220 self.exit(ALREADY_STOPPED)
226 self.exit(ALREADY_STOPPED)
221
227
222 if not self.check_pid(pid):
228 if not self.check_pid(pid):
223 self.log.critical(
229 self.log.critical(
224 'Cluster [pid=%r] is not running.' % pid
230 'Cluster [pid=%r] is not running.' % pid
225 )
231 )
226 self.remove_pid_file()
232 self.remove_pid_file()
227 # Here I exit with an unusual exit status that other processes
233 # Here I exit with an unusual exit status that other processes
228 # can watch for to learn how I exited.
234 # can watch for to learn how I exited.
229 self.exit(ALREADY_STOPPED)
235 self.exit(ALREADY_STOPPED)
230
236
231 elif os.name=='posix':
237 elif os.name=='posix':
232 sig = self.signal
238 sig = self.signal
233 self.log.info(
239 self.log.info(
234 "Stopping cluster [pid=%r] with [signal=%r]" % (pid, sig)
240 "Stopping cluster [pid=%r] with [signal=%r]" % (pid, sig)
235 )
241 )
236 try:
242 try:
237 os.kill(pid, sig)
243 os.kill(pid, sig)
238 except OSError:
244 except OSError:
239 self.log.error("Stopping cluster failed, assuming already dead.",
245 self.log.error("Stopping cluster failed, assuming already dead.",
240 exc_info=True)
246 exc_info=True)
241 self.remove_pid_file()
247 self.remove_pid_file()
242 elif os.name=='nt':
248 elif os.name=='nt':
243 try:
249 try:
244 # kill the whole tree
250 # kill the whole tree
245 p = check_call(['taskkill', '-pid', str(pid), '-t', '-f'], stdout=PIPE,stderr=PIPE)
251 p = check_call(['taskkill', '-pid', str(pid), '-t', '-f'], stdout=PIPE,stderr=PIPE)
246 except (CalledProcessError, OSError):
252 except (CalledProcessError, OSError):
247 self.log.error("Stopping cluster failed, assuming already dead.",
253 self.log.error("Stopping cluster failed, assuming already dead.",
248 exc_info=True)
254 exc_info=True)
249 self.remove_pid_file()
255 self.remove_pid_file()
250
256
251 engine_aliases = {}
257 engine_aliases = {}
252 engine_aliases.update(base_aliases)
258 engine_aliases.update(base_aliases)
253 engine_aliases.update(dict(
259 engine_aliases.update(dict(
254 n='IPClusterEngines.n',
260 n='IPClusterEngines.n',
255 elauncher = 'IPClusterEngines.engine_launcher_class',
261 elauncher = 'IPClusterEngines.engine_launcher_class',
256 ))
262 ))
257 class IPClusterEngines(BaseParallelApplication):
263 class IPClusterEngines(BaseParallelApplication):
258
264
259 name = u'ipcluster'
265 name = u'ipcluster'
260 description = engines_help
266 description = engines_help
261 usage = None
267 usage = None
262 config_file_name = Unicode(default_config_file_name)
268 config_file_name = Unicode(default_config_file_name)
263 default_log_level = logging.INFO
269 default_log_level = logging.INFO
264 classes = List()
270 classes = List()
265 def _classes_default(self):
271 def _classes_default(self):
266 from IPython.parallel.apps import launcher
272 from IPython.parallel.apps import launcher
267 launchers = launcher.all_launchers
273 launchers = launcher.all_launchers
268 eslaunchers = [ l for l in launchers if 'EngineSet' in l.__name__]
274 eslaunchers = [ l for l in launchers if 'EngineSet' in l.__name__]
269 return [ProfileDir]+eslaunchers
275 return [ProfileDir]+eslaunchers
270
276
271 n = Int(2, config=True,
277 n = Int(2, config=True,
272 help="The number of engines to start.")
278 help="The number of engines to start.")
273
279
274 engine_launcher_class = Unicode('LocalEngineSetLauncher',
280 engine_launcher_class = Unicode('LocalEngineSetLauncher',
275 config=True,
281 config=True,
276 help="The class for launching a set of Engines."
282 help="The class for launching a set of Engines."
277 )
283 )
278 daemonize = Bool(False, config=True,
284 daemonize = Bool(False, config=True,
279 help='Daemonize the ipcluster program. This implies --log-to-file')
285 help='Daemonize the ipcluster program. This implies --log-to-file')
280
286
281 def _daemonize_changed(self, name, old, new):
287 def _daemonize_changed(self, name, old, new):
282 if new:
288 if new:
283 self.log_to_file = True
289 self.log_to_file = True
284
290
285 aliases = Dict(engine_aliases)
291 aliases = Dict(engine_aliases)
286 # flags = Dict(flags)
292 # flags = Dict(flags)
287 _stopping = False
293 _stopping = False
288
294
289 def initialize(self, argv=None):
295 def initialize(self, argv=None):
290 super(IPClusterEngines, self).initialize(argv)
296 super(IPClusterEngines, self).initialize(argv)
291 self.init_signal()
297 self.init_signal()
292 self.init_launchers()
298 self.init_launchers()
293
299
294 def init_launchers(self):
300 def init_launchers(self):
295 self.engine_launcher = self.build_launcher(self.engine_launcher_class)
301 self.engine_launcher = self.build_launcher(self.engine_launcher_class)
296 self.engine_launcher.on_stop(lambda r: self.loop.stop())
302 self.engine_launcher.on_stop(lambda r: self.loop.stop())
297
303
298 def init_signal(self):
304 def init_signal(self):
299 # Setup signals
305 # Setup signals
300 signal.signal(signal.SIGINT, self.sigint_handler)
306 signal.signal(signal.SIGINT, self.sigint_handler)
301
307
302 def build_launcher(self, clsname):
308 def build_launcher(self, clsname):
303 """import and instantiate a Launcher based on importstring"""
309 """import and instantiate a Launcher based on importstring"""
304 if '.' not in clsname:
310 if '.' not in clsname:
305 # not a module, presume it's the raw name in apps.launcher
311 # not a module, presume it's the raw name in apps.launcher
306 clsname = 'IPython.parallel.apps.launcher.'+clsname
312 clsname = 'IPython.parallel.apps.launcher.'+clsname
307 # print repr(clsname)
313 # print repr(clsname)
308 klass = import_item(clsname)
314 klass = import_item(clsname)
309
315
310 launcher = klass(
316 launcher = klass(
311 work_dir=self.profile_dir.location, config=self.config, log=self.log
317 work_dir=self.profile_dir.location, config=self.config, log=self.log
312 )
318 )
313 return launcher
319 return launcher
314
320
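A minimal sketch of the importstring resolution build_launcher performs, with 'SSHEngineSetLauncher' as an example class name:

    from IPython.utils.importstring import import_item

    clsname = 'SSHEngineSetLauncher'
    if '.' not in clsname:
        # bare name: assume it lives in apps.launcher
        clsname = 'IPython.parallel.apps.launcher.' + clsname
    klass = import_item(clsname)
    print klass.__name__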
315 def start_engines(self):
321 def start_engines(self):
316 self.log.info("Starting %i engines"%self.n)
322 self.log.info("Starting %i engines"%self.n)
317 self.engine_launcher.start(
323 self.engine_launcher.start(
318 self.n,
324 self.n,
319 self.profile_dir.location
325 self.profile_dir.location
320 )
326 )
321
327
322 def stop_engines(self):
328 def stop_engines(self):
323 self.log.info("Stopping Engines...")
329 self.log.info("Stopping Engines...")
324 if self.engine_launcher.running:
330 if self.engine_launcher.running:
325 d = self.engine_launcher.stop()
331 d = self.engine_launcher.stop()
326 return d
332 return d
327 else:
333 else:
328 return None
334 return None
329
335
330 def stop_launchers(self, r=None):
336 def stop_launchers(self, r=None):
331 if not self._stopping:
337 if not self._stopping:
332 self._stopping = True
338 self._stopping = True
333 self.log.error("IPython cluster: stopping")
339 self.log.error("IPython cluster: stopping")
334 self.stop_engines()
340 self.stop_engines()
335 # Wait a few seconds to let things shut down.
341 # Wait a few seconds to let things shut down.
336 dc = ioloop.DelayedCallback(self.loop.stop, 4000, self.loop)
342 dc = ioloop.DelayedCallback(self.loop.stop, 4000, self.loop)
337 dc.start()
343 dc.start()
338
344
339 def sigint_handler(self, signum, frame):
345 def sigint_handler(self, signum, frame):
340 self.log.debug("SIGINT received, stopping launchers...")
346 self.log.debug("SIGINT received, stopping launchers...")
341 self.stop_launchers()
347 self.stop_launchers()
342
348
343 def start_logging(self):
349 def start_logging(self):
344 # Remove old log files of the controller and engine
350 # Remove old log files of the controller and engine
345 if self.clean_logs:
351 if self.clean_logs:
346 log_dir = self.profile_dir.log_dir
352 log_dir = self.profile_dir.log_dir
347 for f in os.listdir(log_dir):
353 for f in os.listdir(log_dir):
348 if re.match(r'ip(engine|controller)z-\d+\.(log|err|out)',f):
354 if re.match(r'ip(engine|controller)z-\d+\.(log|err|out)',f):
349 os.remove(os.path.join(log_dir, f))
355 os.remove(os.path.join(log_dir, f))
350 # This will remove old log files for ipcluster itself
356 # This will remove old log files for ipcluster itself
351 # super(IPBaseParallelApplication, self).start_logging()
357 # super(IPBaseParallelApplication, self).start_logging()
352
358
353 def start(self):
359 def start(self):
354 """Start the app for the engines subcommand."""
360 """Start the app for the engines subcommand."""
355 self.log.info("IPython cluster: started")
361 self.log.info("IPython cluster: started")
356 # First see if the cluster is already running
362 # First see if the cluster is already running
357
363
358 # Now log and daemonize
364 # Now log and daemonize
359 self.log.info(
365 self.log.info(
360 'Starting engines with [daemon=%r]' % self.daemonize
366 'Starting engines with [daemon=%r]' % self.daemonize
361 )
367 )
362 # TODO: Get daemonize working on Windows or as a Windows Server.
368 # TODO: Get daemonize working on Windows or as a Windows Server.
363 if self.daemonize:
369 if self.daemonize:
364 if os.name=='posix':
370 if os.name=='posix':
365 from twisted.scripts._twistd_unix import daemonize
371 from twisted.scripts._twistd_unix import daemonize
366 daemonize()
372 daemonize()
367
373
368 dc = ioloop.DelayedCallback(self.start_engines, 0, self.loop)
374 dc = ioloop.DelayedCallback(self.start_engines, 0, self.loop)
369 dc.start()
375 dc.start()
370 # Now write the new pid file AFTER our new forked pid is active.
376 # Now write the new pid file AFTER our new forked pid is active.
371 # self.write_pid_file()
377 # self.write_pid_file()
372 try:
378 try:
373 self.loop.start()
379 self.loop.start()
374 except KeyboardInterrupt:
380 except KeyboardInterrupt:
375 pass
381 pass
376 except zmq.ZMQError as e:
382 except zmq.ZMQError as e:
377 if e.errno == errno.EINTR:
383 if e.errno == errno.EINTR:
378 pass
384 pass
379 else:
385 else:
380 raise
386 raise
381
387
382 start_aliases = {}
388 start_aliases = {}
383 start_aliases.update(engine_aliases)
389 start_aliases.update(engine_aliases)
384 start_aliases.update(dict(
390 start_aliases.update(dict(
385 delay='IPClusterStart.delay',
391 delay='IPClusterStart.delay',
386 clean_logs='IPClusterStart.clean_logs',
392 clean_logs='IPClusterStart.clean_logs',
387 ))
393 ))
388
394
389 class IPClusterStart(IPClusterEngines):
395 class IPClusterStart(IPClusterEngines):
390
396
391 name = u'ipcluster'
397 name = u'ipcluster'
392 description = start_help
398 description = start_help
393 default_log_level = logging.INFO
399 default_log_level = logging.INFO
394 auto_create = Bool(True, config=True,
400 auto_create = Bool(True, config=True,
395 help="whether to create the profile_dir if it doesn't exist")
401 help="whether to create the profile_dir if it doesn't exist")
396 classes = List()
402 classes = List()
397 def _classes_default(self,):
403 def _classes_default(self,):
398 from IPython.parallel.apps import launcher
404 from IPython.parallel.apps import launcher
399 return [ProfileDir]+launcher.all_launchers
405 return [ProfileDir]+launcher.all_launchers
400
406
401 clean_logs = Bool(True, config=True,
407 clean_logs = Bool(True, config=True,
402 help="whether to cleanup old logs before starting")
408 help="whether to cleanup old logs before starting")
403
409
404 delay = CFloat(1., config=True,
410 delay = CFloat(1., config=True,
405 help="delay (in s) between starting the controller and the engines")
411 help="delay (in s) between starting the controller and the engines")
406
412
407 controller_launcher_class = Unicode('LocalControllerLauncher',
413 controller_launcher_class = Unicode('LocalControllerLauncher',
408 config=True,
414 config=True,
409 help="The class for launching a Controller."
415 help="The class for launching a Controller."
410 )
416 )
411 reset = Bool(False, config=True,
417 reset = Bool(False, config=True,
412 help="Whether to reset config files as part of '--create'."
418 help="Whether to reset config files as part of '--create'."
413 )
419 )
414
420
415 # flags = Dict(flags)
421 # flags = Dict(flags)
416 aliases = Dict(start_aliases)
422 aliases = Dict(start_aliases)
417
423
418 def init_launchers(self):
424 def init_launchers(self):
419 self.controller_launcher = self.build_launcher(self.controller_launcher_class)
425 self.controller_launcher = self.build_launcher(self.controller_launcher_class)
420 self.engine_launcher = self.build_launcher(self.engine_launcher_class)
426 self.engine_launcher = self.build_launcher(self.engine_launcher_class)
421 self.controller_launcher.on_stop(self.stop_launchers)
427 self.controller_launcher.on_stop(self.stop_launchers)
422
428
423 def start_controller(self):
429 def start_controller(self):
424 self.controller_launcher.start(
430 self.controller_launcher.start(
425 self.profile_dir.location
431 self.profile_dir.location
426 )
432 )
427
433
428 def stop_controller(self):
434 def stop_controller(self):
429 # self.log.info("In stop_controller")
435 # self.log.info("In stop_controller")
430 if self.controller_launcher and self.controller_launcher.running:
436 if self.controller_launcher and self.controller_launcher.running:
431 return self.controller_launcher.stop()
437 return self.controller_launcher.stop()
432
438
433 def stop_launchers(self, r=None):
439 def stop_launchers(self, r=None):
434 if not self._stopping:
440 if not self._stopping:
435 self.stop_controller()
441 self.stop_controller()
436 super(IPClusterStart, self).stop_launchers()
442 super(IPClusterStart, self).stop_launchers()
437
443
438 def start(self):
444 def start(self):
439 """Start the app for the start subcommand."""
445 """Start the app for the start subcommand."""
440 # First see if the cluster is already running
446 # First see if the cluster is already running
441 try:
447 try:
442 pid = self.get_pid_from_file()
448 pid = self.get_pid_from_file()
443 except PIDFileError:
449 except PIDFileError:
444 pass
450 pass
445 else:
451 else:
446 if self.check_pid(pid):
452 if self.check_pid(pid):
447 self.log.critical(
453 self.log.critical(
448 'Cluster is already running with [pid=%s]. '
454 'Cluster is already running with [pid=%s]. '
449 'use "ipcluster stop" to stop the cluster.' % pid
455 'use "ipcluster stop" to stop the cluster.' % pid
450 )
456 )
451 # Here I exit with an unusual exit status that other processes
457 # Here I exit with an unusual exit status that other processes
452 # can watch for to learn how I exited.
458 # can watch for to learn how I exited.
453 self.exit(ALREADY_STARTED)
459 self.exit(ALREADY_STARTED)
454 else:
460 else:
455 self.remove_pid_file()
461 self.remove_pid_file()
456
462
457
463
458 # Now log and daemonize
464 # Now log and daemonize
459 self.log.info(
465 self.log.info(
460 'Starting ipcluster with [daemon=%r]' % self.daemonize
466 'Starting ipcluster with [daemon=%r]' % self.daemonize
461 )
467 )
462 # TODO: Get daemonize working on Windows or as a Windows Server.
468 # TODO: Get daemonize working on Windows or as a Windows Server.
463 if self.daemonize:
469 if self.daemonize:
464 if os.name=='posix':
470 if os.name=='posix':
465 from twisted.scripts._twistd_unix import daemonize
471 from twisted.scripts._twistd_unix import daemonize
466 daemonize()
472 daemonize()
467
473
468 dc = ioloop.DelayedCallback(self.start_controller, 0, self.loop)
474 dc = ioloop.DelayedCallback(self.start_controller, 0, self.loop)
469 dc.start()
475 dc.start()
470 dc = ioloop.DelayedCallback(self.start_engines, 1000*self.delay, self.loop)
476 dc = ioloop.DelayedCallback(self.start_engines, 1000*self.delay, self.loop)
471 dc.start()
477 dc.start()
472 # Now write the new pid file AFTER our new forked pid is active.
478 # Now write the new pid file AFTER our new forked pid is active.
473 self.write_pid_file()
479 self.write_pid_file()
474 try:
480 try:
475 self.loop.start()
481 self.loop.start()
476 except KeyboardInterrupt:
482 except KeyboardInterrupt:
477 pass
483 pass
478 except zmq.ZMQError as e:
484 except zmq.ZMQError as e:
479 if e.errno == errno.EINTR:
485 if e.errno == errno.EINTR:
480 pass
486 pass
481 else:
487 else:
482 raise
488 raise
483 finally:
489 finally:
484 self.remove_pid_file()
490 self.remove_pid_file()
485
491
486 base='IPython.parallel.apps.ipclusterapp.IPCluster'
492 base='IPython.parallel.apps.ipclusterapp.IPCluster'
487
493
488 class IPBaseParallelApplication(Application):
494 class IPBaseParallelApplication(Application):
489 name = u'ipcluster'
495 name = u'ipcluster'
490 description = _description
496 description = _description
491
497
492 subcommands = {'create' : (base+'Create', create_help),
498 subcommands = {'create' : (base+'Create', create_help),
493 'list' : (base+'List', list_help),
499 'list' : (base+'List', list_help),
494 'start' : (base+'Start', start_help),
500 'start' : (base+'Start', start_help),
495 'stop' : (base+'Stop', stop_help),
501 'stop' : (base+'Stop', stop_help),
496 'engines' : (base+'Engines', engines_help),
502 'engines' : (base+'Engines', engines_help),
497 }
503 }
498
504
499 # no aliases or flags for parent App
505 # no aliases or flags for parent App
500 aliases = Dict()
506 aliases = Dict()
501 flags = Dict()
507 flags = Dict()
502
508
503 def start(self):
509 def start(self):
504 if self.subapp is None:
510 if self.subapp is None:
505 print "No subcommand specified! Must specify one of: %s"%(self.subcommands.keys())
511 print "No subcommand specified! Must specify one of: %s"%(self.subcommands.keys())
506 print
512 print
507 self.print_subcommands()
513 self.print_subcommands()
508 self.exit(1)
514 self.exit(1)
509 else:
515 else:
510 return self.subapp.start()
516 return self.subapp.start()
511
517
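For illustration, this is roughly how a subcommand entry resolves to its application class (a sketch mirroring the subcommands dict above, not the dispatch code itself):

    from IPython.utils.importstring import import_item

    cls_str, help_str = IPBaseParallelApplication.subcommands['start']
    cls = import_item(cls_str)  # -> IPClusterStart
    print cls.__name__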
512 def launch_new_instance():
518 def launch_new_instance():
513 """Create and run the IPython cluster."""
519 """Create and run the IPython cluster."""
514 app = IPBaseParallelApplication.instance()
520 app = IPBaseParallelApplication.instance()
515 app.initialize()
521 app.initialize()
516 app.start()
522 app.start()
517
523
518
524
519 if __name__ == '__main__':
525 if __name__ == '__main__':
520 launch_new_instance()
526 launch_new_instance()
521
527
@@ -1,402 +1,408 b''
1 #!/usr/bin/env python
1 #!/usr/bin/env python
2 # encoding: utf-8
2 # encoding: utf-8
3 """
3 """
4 The IPython controller application.
4 The IPython controller application.
5
6 Authors:
7
8 * Brian Granger
9 * MinRK
10
5 """
11 """
6
12
7 #-----------------------------------------------------------------------------
13 #-----------------------------------------------------------------------------
8 # Copyright (C) 2008-2009 The IPython Development Team
14 # Copyright (C) 2008-2011 The IPython Development Team
9 #
15 #
10 # Distributed under the terms of the BSD License. The full license is in
16 # Distributed under the terms of the BSD License. The full license is in
11 # the file COPYING, distributed as part of this software.
17 # the file COPYING, distributed as part of this software.
12 #-----------------------------------------------------------------------------
18 #-----------------------------------------------------------------------------
13
19
14 #-----------------------------------------------------------------------------
20 #-----------------------------------------------------------------------------
15 # Imports
21 # Imports
16 #-----------------------------------------------------------------------------
22 #-----------------------------------------------------------------------------
17
23
18 from __future__ import with_statement
24 from __future__ import with_statement
19
25
20 import os
26 import os
21 import socket
27 import socket
22 import stat
28 import stat
23 import sys
29 import sys
24 import uuid
30 import uuid
25
31
26 from multiprocessing import Process
32 from multiprocessing import Process
27
33
28 import zmq
34 import zmq
29 from zmq.devices import ProcessMonitoredQueue
35 from zmq.devices import ProcessMonitoredQueue
30 from zmq.log.handlers import PUBHandler
36 from zmq.log.handlers import PUBHandler
31 from zmq.utils import jsonapi as json
37 from zmq.utils import jsonapi as json
32
38
33 from IPython.config.application import boolean_flag
39 from IPython.config.application import boolean_flag
34 from IPython.core.newapplication import ProfileDir
40 from IPython.core.newapplication import ProfileDir
35
41
36 from IPython.parallel.apps.baseapp import (
42 from IPython.parallel.apps.baseapp import (
37 BaseParallelApplication,
43 BaseParallelApplication,
38 base_flags
44 base_flags
39 )
45 )
40 from IPython.utils.importstring import import_item
46 from IPython.utils.importstring import import_item
41 from IPython.utils.traitlets import Instance, Unicode, Bool, List, Dict
47 from IPython.utils.traitlets import Instance, Unicode, Bool, List, Dict
42
48
43 # from IPython.parallel.controller.controller import ControllerFactory
49 # from IPython.parallel.controller.controller import ControllerFactory
44 from IPython.zmq.session import Session
50 from IPython.zmq.session import Session
45 from IPython.parallel.controller.heartmonitor import HeartMonitor
51 from IPython.parallel.controller.heartmonitor import HeartMonitor
46 from IPython.parallel.controller.hub import HubFactory
52 from IPython.parallel.controller.hub import HubFactory
47 from IPython.parallel.controller.scheduler import TaskScheduler,launch_scheduler
53 from IPython.parallel.controller.scheduler import TaskScheduler,launch_scheduler
48 from IPython.parallel.controller.sqlitedb import SQLiteDB
54 from IPython.parallel.controller.sqlitedb import SQLiteDB
49
55
50 from IPython.parallel.util import signal_children, split_url
56 from IPython.parallel.util import signal_children, split_url
51
57
52 # conditional import of MongoDB backend class
58 # conditional import of MongoDB backend class
53
59
54 try:
60 try:
55 from IPython.parallel.controller.mongodb import MongoDB
61 from IPython.parallel.controller.mongodb import MongoDB
56 except ImportError:
62 except ImportError:
57 maybe_mongo = []
63 maybe_mongo = []
58 else:
64 else:
59 maybe_mongo = [MongoDB]
65 maybe_mongo = [MongoDB]
60
66
61
67
62 #-----------------------------------------------------------------------------
68 #-----------------------------------------------------------------------------
63 # Module level variables
69 # Module level variables
64 #-----------------------------------------------------------------------------
70 #-----------------------------------------------------------------------------
65
71
66
72
67 #: The default config file name for this application
73 #: The default config file name for this application
68 default_config_file_name = u'ipcontroller_config.py'
74 default_config_file_name = u'ipcontroller_config.py'
69
75
70
76
71 _description = """Start the IPython controller for parallel computing.
77 _description = """Start the IPython controller for parallel computing.
72
78
73 The IPython controller provides a gateway between the IPython engines and
79 The IPython controller provides a gateway between the IPython engines and
74 clients. The controller needs to be started before the engines and can be
80 clients. The controller needs to be started before the engines and can be
75 configured using command line options or using a cluster directory. Cluster
81 configured using command line options or using a cluster directory. Cluster
76 directories contain config, log and security files and are usually located in
82 directories contain config, log and security files and are usually located in
77 your ipython directory and named as "cluster_<profile>". See the `profile`
83 your ipython directory and named as "cluster_<profile>". See the `profile`
78 and `profile_dir` options for details.
84 and `profile_dir` options for details.
79 """
85 """
80
86
81
87
82
88
83
89
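As with ipcluster, the controller is driven by a plain-Python config file. A hedged sketch of a few ipcontroller_config.py settings matching traits and aliases defined below; all values are illustrative:

    c = get_config()

    c.IPControllerApp.reuse_files = True
    c.IPControllerApp.ssh_server = 'user@login.example.com'
    c.HubFactory.ip = '*'
    c.TaskScheduler.scheme_name = 'lru'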
84 #-----------------------------------------------------------------------------
90 #-----------------------------------------------------------------------------
85 # The main application
91 # The main application
86 #-----------------------------------------------------------------------------
92 #-----------------------------------------------------------------------------
87 flags = {}
93 flags = {}
88 flags.update(base_flags)
94 flags.update(base_flags)
89 flags.update({
95 flags.update({
90 'usethreads' : ( {'IPControllerApp' : {'use_threads' : True}},
96 'usethreads' : ( {'IPControllerApp' : {'use_threads' : True}},
91 'Use threads instead of processes for the schedulers'),
97 'Use threads instead of processes for the schedulers'),
92 'sqlitedb' : ({'HubFactory' : {'db_class' : 'IPython.parallel.controller.sqlitedb.SQLiteDB'}},
98 'sqlitedb' : ({'HubFactory' : {'db_class' : 'IPython.parallel.controller.sqlitedb.SQLiteDB'}},
93 'use the SQLiteDB backend'),
99 'use the SQLiteDB backend'),
94 'mongodb' : ({'HubFactory' : {'db_class' : 'IPython.parallel.controller.mongodb.MongoDB'}},
100 'mongodb' : ({'HubFactory' : {'db_class' : 'IPython.parallel.controller.mongodb.MongoDB'}},
95 'use the MongoDB backend'),
101 'use the MongoDB backend'),
96 'dictdb' : ({'HubFactory' : {'db_class' : 'IPython.parallel.controller.dictdb.DictDB'}},
102 'dictdb' : ({'HubFactory' : {'db_class' : 'IPython.parallel.controller.dictdb.DictDB'}},
97 'use the in-memory DictDB backend'),
103 'use the in-memory DictDB backend'),
98 'reuse' : ({'IPControllerApp' : {'reuse_files' : True}},
104 'reuse' : ({'IPControllerApp' : {'reuse_files' : True}},
99 'reuse existing json connection files')
105 'reuse existing json connection files')
100 })
106 })
101
107
102 flags.update(boolean_flag('secure', 'IPControllerApp.secure',
108 flags.update(boolean_flag('secure', 'IPControllerApp.secure',
103 "Use HMAC digests for authentication of messages.",
109 "Use HMAC digests for authentication of messages.",
104 "Don't authenticate messages."
110 "Don't authenticate messages."
105 ))
111 ))
106
112
107 class IPControllerApp(BaseParallelApplication):
113 class IPControllerApp(BaseParallelApplication):
108
114
109 name = u'ipcontroller'
115 name = u'ipcontroller'
110 description = _description
116 description = _description
111 config_file_name = Unicode(default_config_file_name)
117 config_file_name = Unicode(default_config_file_name)
112 classes = [ProfileDir, Session, HubFactory, TaskScheduler, HeartMonitor, SQLiteDB] + maybe_mongo
118 classes = [ProfileDir, Session, HubFactory, TaskScheduler, HeartMonitor, SQLiteDB] + maybe_mongo
113
119
114 # change default to True
120 # change default to True
115 auto_create = Bool(True, config=True,
121 auto_create = Bool(True, config=True,
116 help="""Whether to create profile dir if it doesn't exist.""")
122 help="""Whether to create profile dir if it doesn't exist.""")
117
123
118 reuse_files = Bool(False, config=True,
124 reuse_files = Bool(False, config=True,
119 help='Whether to reuse existing json connection files.'
125 help='Whether to reuse existing json connection files.'
120 )
126 )
121 secure = Bool(True, config=True,
127 secure = Bool(True, config=True,
122 help='Whether to use HMAC digests for extra message authentication.'
128 help='Whether to use HMAC digests for extra message authentication.'
123 )
129 )
124 ssh_server = Unicode(u'', config=True,
130 ssh_server = Unicode(u'', config=True,
125 help="""ssh url for clients to use when connecting to the Controller
131 help="""ssh url for clients to use when connecting to the Controller
126 processes. It should be of the form: [user@]server[:port]. The
132 processes. It should be of the form: [user@]server[:port]. The
127 Controller's listening addresses must be accessible from the ssh server""",
133 Controller's listening addresses must be accessible from the ssh server""",
128 )
134 )
129 location = Unicode(u'', config=True,
135 location = Unicode(u'', config=True,
130 help="""The external IP or domain name of the Controller, used for disambiguating
136 help="""The external IP or domain name of the Controller, used for disambiguating
131 engine and client connections.""",
137 engine and client connections.""",
132 )
138 )
133 import_statements = List([], config=True,
139 import_statements = List([], config=True,
134 help="import statements to be run at startup. Necessary in some environments"
140 help="import statements to be run at startup. Necessary in some environments"
135 )
141 )
136
142
137 use_threads = Bool(False, config=True,
143 use_threads = Bool(False, config=True,
138 help='Use threads instead of processes for the schedulers',
144 help='Use threads instead of processes for the schedulers',
139 )
145 )
140
146
141 # internal
147 # internal
142 children = List()
148 children = List()
143 mq_class = Unicode('zmq.devices.ProcessMonitoredQueue')
149 mq_class = Unicode('zmq.devices.ProcessMonitoredQueue')
144
150
145 def _use_threads_changed(self, name, old, new):
151 def _use_threads_changed(self, name, old, new):
146 self.mq_class = 'zmq.devices.%sMonitoredQueue'%('Thread' if new else 'Process')
152 self.mq_class = 'zmq.devices.%sMonitoredQueue'%('Thread' if new else 'Process')
147
153
148 aliases = Dict(dict(
154 aliases = Dict(dict(
149 log_level = 'IPControllerApp.log_level',
155 log_level = 'IPControllerApp.log_level',
150 log_url = 'IPControllerApp.log_url',
156 log_url = 'IPControllerApp.log_url',
151 reuse_files = 'IPControllerApp.reuse_files',
157 reuse_files = 'IPControllerApp.reuse_files',
152 secure = 'IPControllerApp.secure',
158 secure = 'IPControllerApp.secure',
153 ssh = 'IPControllerApp.ssh_server',
159 ssh = 'IPControllerApp.ssh_server',
154 use_threads = 'IPControllerApp.use_threads',
160 use_threads = 'IPControllerApp.use_threads',
155 import_statements = 'IPControllerApp.import_statements',
161 import_statements = 'IPControllerApp.import_statements',
156 location = 'IPControllerApp.location',
162 location = 'IPControllerApp.location',
157
163
158 ident = 'Session.session',
164 ident = 'Session.session',
159 user = 'Session.username',
165 user = 'Session.username',
160 exec_key = 'Session.keyfile',
166 exec_key = 'Session.keyfile',
161
167
162 url = 'HubFactory.url',
168 url = 'HubFactory.url',
163 ip = 'HubFactory.ip',
169 ip = 'HubFactory.ip',
164 transport = 'HubFactory.transport',
170 transport = 'HubFactory.transport',
165 port = 'HubFactory.regport',
171 port = 'HubFactory.regport',
166
172
167 ping = 'HeartMonitor.period',
173 ping = 'HeartMonitor.period',
168
174
169 scheme = 'TaskScheduler.scheme_name',
175 scheme = 'TaskScheduler.scheme_name',
170 hwm = 'TaskScheduler.hwm',
176 hwm = 'TaskScheduler.hwm',
171
177
172
178
173 profile = "BaseIPythonApplication.profile",
179 profile = "BaseIPythonApplication.profile",
174 profile_dir = 'ProfileDir.location',
180 profile_dir = 'ProfileDir.location',
175
181
176 ))
182 ))
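With these aliases, the abbreviated key-value form on the command line expands to full trait assignments; for example (host name illustrative), `ipcontroller reuse_files=True secure=False ssh=user@gateway` sets `IPControllerApp.reuse_files`, `IPControllerApp.secure`, and `IPControllerApp.ssh_server`.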
177 flags = Dict(flags)
183 flags = Dict(flags)
178
184
179
185
180 def save_connection_dict(self, fname, cdict):
186 def save_connection_dict(self, fname, cdict):
181 """save a connection dict to json file."""
187 """save a connection dict to json file."""
182 c = self.config
188 c = self.config
183 url = cdict['url']
189 url = cdict['url']
184 location = cdict['location']
190 location = cdict['location']
185 if not location:
191 if not location:
186 try:
192 try:
187 proto,ip,port = split_url(url)
193 proto,ip,port = split_url(url)
188 except AssertionError:
194 except AssertionError:
189 pass
195 pass
190 else:
196 else:
191 location = socket.gethostbyname_ex(socket.gethostname())[2][-1]
197 location = socket.gethostbyname_ex(socket.gethostname())[2][-1]
192 cdict['location'] = location
198 cdict['location'] = location
193 fname = os.path.join(self.profile_dir.security_dir, fname)
199 fname = os.path.join(self.profile_dir.security_dir, fname)
194 with open(fname, 'w') as f:
200 with open(fname, 'w') as f:
195 f.write(json.dumps(cdict, indent=2))
201 f.write(json.dumps(cdict, indent=2))
196 os.chmod(fname, stat.S_IRUSR|stat.S_IWUSR)
202 os.chmod(fname, stat.S_IRUSR|stat.S_IWUSR)
197
203
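Roughly what save_connection_dict produces on disk; every value below is illustrative:

import json

cdict = {
    'exec_key': '8c97f3a0-...',        # '' when secure=False
    'ssh': '',                         # optional [user@]server[:port]
    'url': 'tcp://127.0.0.1:10101',    # registration endpoint
    'location': '10.0.0.5',            # external IP of the controller
}
print(json.dumps(cdict, indent=2))     # the real file is written mode 0600 (owner-only)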
198 def load_config_from_json(self):
204 def load_config_from_json(self):
199 """load config from existing json connector files."""
205 """load config from existing json connector files."""
200 c = self.config
206 c = self.config
201 # load from engine config
207 # load from engine config
202 with open(os.path.join(self.profile_dir.security_dir, 'ipcontroller-engine.json')) as f:
208 with open(os.path.join(self.profile_dir.security_dir, 'ipcontroller-engine.json')) as f:
203 cfg = json.loads(f.read())
209 cfg = json.loads(f.read())
204 key = c.Session.key = cfg['exec_key']
210 key = c.Session.key = cfg['exec_key']
205 xport,addr = cfg['url'].split('://')
211 xport,addr = cfg['url'].split('://')
206 c.HubFactory.engine_transport = xport
212 c.HubFactory.engine_transport = xport
207 ip,ports = addr.split(':')
213 ip,ports = addr.split(':')
208 c.HubFactory.engine_ip = ip
214 c.HubFactory.engine_ip = ip
209 c.HubFactory.regport = int(ports)
215 c.HubFactory.regport = int(ports)
210 self.location = cfg['location']
216 self.location = cfg['location']
211
217
212 # load client config
218 # load client config
213 with open(os.path.join(self.profile_dir.security_dir, 'ipcontroller-client.json')) as f:
219 with open(os.path.join(self.profile_dir.security_dir, 'ipcontroller-client.json')) as f:
214 cfg = json.loads(f.read())
220 cfg = json.loads(f.read())
215 assert key == cfg['exec_key'], "exec_key mismatch between engine and client keys"
221 assert key == cfg['exec_key'], "exec_key mismatch between engine and client keys"
216 xport,addr = cfg['url'].split('://')
222 xport,addr = cfg['url'].split('://')
217 c.HubFactory.client_transport = xport
223 c.HubFactory.client_transport = xport
218 ip,ports = addr.split(':')
224 ip,ports = addr.split(':')
219 c.HubFactory.client_ip = ip
225 c.HubFactory.client_ip = ip
220 self.ssh_server = cfg['ssh']
226 self.ssh_server = cfg['ssh']
221 assert int(ports) == c.HubFactory.regport, "regport mismatch"
227 assert int(ports) == c.HubFactory.regport, "regport mismatch"
222
228
223 def init_hub(self):
229 def init_hub(self):
224 c = self.config
230 c = self.config
225
231
226 self.do_import_statements()
232 self.do_import_statements()
227 reusing = self.reuse_files
233 reusing = self.reuse_files
228 if reusing:
234 if reusing:
229 try:
235 try:
230 self.load_config_from_json()
236 self.load_config_from_json()
231 except (AssertionError,IOError):
237 except (AssertionError,IOError):
232 reusing=False
238 reusing=False
233 # check again, because reusing may have failed:
239 # check again, because reusing may have failed:
234 if reusing:
240 if reusing:
235 pass
241 pass
236 elif self.secure:
242 elif self.secure:
237 key = str(uuid.uuid4())
243 key = str(uuid.uuid4())
238 # keyfile = os.path.join(self.profile_dir.security_dir, self.exec_key)
244 # keyfile = os.path.join(self.profile_dir.security_dir, self.exec_key)
239 # with open(keyfile, 'w') as f:
245 # with open(keyfile, 'w') as f:
240 # f.write(key)
246 # f.write(key)
241 # os.chmod(keyfile, stat.S_IRUSR|stat.S_IWUSR)
247 # os.chmod(keyfile, stat.S_IRUSR|stat.S_IWUSR)
242 c.Session.key = key
248 c.Session.key = key
243 else:
249 else:
244 key = c.Session.key = ''
250 key = c.Session.key = ''
245
251
246 try:
252 try:
247 self.factory = HubFactory(config=c, log=self.log)
253 self.factory = HubFactory(config=c, log=self.log)
248 # self.start_logging()
254 # self.start_logging()
249 self.factory.init_hub()
255 self.factory.init_hub()
250 except:
256 except:
251 self.log.error("Couldn't construct the Controller", exc_info=True)
257 self.log.error("Couldn't construct the Controller", exc_info=True)
252 self.exit(1)
258 self.exit(1)
253
259
254 if not reusing:
260 if not reusing:
255 # save to new json config files
261 # save to new json config files
256 f = self.factory
262 f = self.factory
257 cdict = {'exec_key' : key,
263 cdict = {'exec_key' : key,
258 'ssh' : self.ssh_server,
264 'ssh' : self.ssh_server,
259 'url' : "%s://%s:%s"%(f.client_transport, f.client_ip, f.regport),
265 'url' : "%s://%s:%s"%(f.client_transport, f.client_ip, f.regport),
260 'location' : self.location
266 'location' : self.location
261 }
267 }
262 self.save_connection_dict('ipcontroller-client.json', cdict)
268 self.save_connection_dict('ipcontroller-client.json', cdict)
263 edict = cdict
269 edict = cdict
264         edict['url'] = "%s://%s:%s" % (f.engine_transport, f.engine_ip, f.regport)
270         edict['url'] = "%s://%s:%s" % (f.engine_transport, f.engine_ip, f.regport)
265 self.save_connection_dict('ipcontroller-engine.json', edict)
271 self.save_connection_dict('ipcontroller-engine.json', edict)
266
272
267 #
273 #
268 def init_schedulers(self):
274 def init_schedulers(self):
269 children = self.children
275 children = self.children
270 mq = import_item(str(self.mq_class))
276 mq = import_item(str(self.mq_class))
271
277
272 hub = self.factory
278 hub = self.factory
273 # maybe_inproc = 'inproc://monitor' if self.use_threads else self.monitor_url
279 # maybe_inproc = 'inproc://monitor' if self.use_threads else self.monitor_url
274 # IOPub relay (in a Process)
280 # IOPub relay (in a Process)
275 q = mq(zmq.PUB, zmq.SUB, zmq.PUB, 'N/A','iopub')
281 q = mq(zmq.PUB, zmq.SUB, zmq.PUB, 'N/A','iopub')
276 q.bind_in(hub.client_info['iopub'])
282 q.bind_in(hub.client_info['iopub'])
277 q.bind_out(hub.engine_info['iopub'])
283 q.bind_out(hub.engine_info['iopub'])
278 q.setsockopt_out(zmq.SUBSCRIBE, '')
284 q.setsockopt_out(zmq.SUBSCRIBE, '')
279 q.connect_mon(hub.monitor_url)
285 q.connect_mon(hub.monitor_url)
280 q.daemon=True
286 q.daemon=True
281 children.append(q)
287 children.append(q)
282
288
283 # Multiplexer Queue (in a Process)
289 # Multiplexer Queue (in a Process)
284 q = mq(zmq.XREP, zmq.XREP, zmq.PUB, 'in', 'out')
290 q = mq(zmq.XREP, zmq.XREP, zmq.PUB, 'in', 'out')
285 q.bind_in(hub.client_info['mux'])
291 q.bind_in(hub.client_info['mux'])
286 q.setsockopt_in(zmq.IDENTITY, 'mux')
292 q.setsockopt_in(zmq.IDENTITY, 'mux')
287 q.bind_out(hub.engine_info['mux'])
293 q.bind_out(hub.engine_info['mux'])
288 q.connect_mon(hub.monitor_url)
294 q.connect_mon(hub.monitor_url)
289 q.daemon=True
295 q.daemon=True
290 children.append(q)
296 children.append(q)
291
297
292 # Control Queue (in a Process)
298 # Control Queue (in a Process)
293 q = mq(zmq.XREP, zmq.XREP, zmq.PUB, 'incontrol', 'outcontrol')
299 q = mq(zmq.XREP, zmq.XREP, zmq.PUB, 'incontrol', 'outcontrol')
294 q.bind_in(hub.client_info['control'])
300 q.bind_in(hub.client_info['control'])
295 q.setsockopt_in(zmq.IDENTITY, 'control')
301 q.setsockopt_in(zmq.IDENTITY, 'control')
296 q.bind_out(hub.engine_info['control'])
302 q.bind_out(hub.engine_info['control'])
297 q.connect_mon(hub.monitor_url)
303 q.connect_mon(hub.monitor_url)
298 q.daemon=True
304 q.daemon=True
299 children.append(q)
305 children.append(q)
300 try:
306 try:
301 scheme = self.config.TaskScheduler.scheme_name
307 scheme = self.config.TaskScheduler.scheme_name
302 except AttributeError:
308 except AttributeError:
303 scheme = TaskScheduler.scheme_name.get_default_value()
309 scheme = TaskScheduler.scheme_name.get_default_value()
304 # Task Queue (in a Process)
310 # Task Queue (in a Process)
305 if scheme == 'pure':
311 if scheme == 'pure':
306 self.log.warn("task::using pure XREQ Task scheduler")
312 self.log.warn("task::using pure XREQ Task scheduler")
307 q = mq(zmq.XREP, zmq.XREQ, zmq.PUB, 'intask', 'outtask')
313 q = mq(zmq.XREP, zmq.XREQ, zmq.PUB, 'intask', 'outtask')
308 # q.setsockopt_out(zmq.HWM, hub.hwm)
314 # q.setsockopt_out(zmq.HWM, hub.hwm)
309 q.bind_in(hub.client_info['task'][1])
315 q.bind_in(hub.client_info['task'][1])
310 q.setsockopt_in(zmq.IDENTITY, 'task')
316 q.setsockopt_in(zmq.IDENTITY, 'task')
311 q.bind_out(hub.engine_info['task'])
317 q.bind_out(hub.engine_info['task'])
312 q.connect_mon(hub.monitor_url)
318 q.connect_mon(hub.monitor_url)
313 q.daemon=True
319 q.daemon=True
314 children.append(q)
320 children.append(q)
315 elif scheme == 'none':
321 elif scheme == 'none':
316 self.log.warn("task::using no Task scheduler")
322 self.log.warn("task::using no Task scheduler")
317
323
318 else:
324 else:
319 self.log.info("task::using Python %s Task scheduler"%scheme)
325 self.log.info("task::using Python %s Task scheduler"%scheme)
320 sargs = (hub.client_info['task'][1], hub.engine_info['task'],
326 sargs = (hub.client_info['task'][1], hub.engine_info['task'],
321 hub.monitor_url, hub.client_info['notification'])
327 hub.monitor_url, hub.client_info['notification'])
322 kwargs = dict(logname='scheduler', loglevel=self.log_level,
328 kwargs = dict(logname='scheduler', loglevel=self.log_level,
323 log_url = self.log_url, config=dict(self.config))
329 log_url = self.log_url, config=dict(self.config))
324 q = Process(target=launch_scheduler, args=sargs, kwargs=kwargs)
330 q = Process(target=launch_scheduler, args=sargs, kwargs=kwargs)
325 q.daemon=True
331 q.daemon=True
326 children.append(q)
332 children.append(q)
327
333
328
334
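A minimal standalone sketch of the MonitoredQueue relay the schedulers above rely on (pyzmq's zmq.devices; endpoints illustrative):

import zmq
from zmq.devices import ProcessMonitoredQueue

q = ProcessMonitoredQueue(zmq.XREP, zmq.XREP, zmq.PUB, 'in', 'out')
q.bind_in('tcp://127.0.0.1:10201')    # client-facing socket
q.bind_out('tcp://127.0.0.1:10202')   # engine-facing socket
q.bind_mon('tcp://127.0.0.1:10203')   # copies of all traffic, e.g. for the Hub
q.daemon = True
q.start()                             # runs the relay in a child process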
329 def save_urls(self):
335 def save_urls(self):
330 """save the registration urls to files."""
336 """save the registration urls to files."""
331 c = self.config
337 c = self.config
332
338
333 sec_dir = self.profile_dir.security_dir
339 sec_dir = self.profile_dir.security_dir
334 cf = self.factory
340 cf = self.factory
335
341
336 with open(os.path.join(sec_dir, 'ipcontroller-engine.url'), 'w') as f:
342 with open(os.path.join(sec_dir, 'ipcontroller-engine.url'), 'w') as f:
337 f.write("%s://%s:%s"%(cf.engine_transport, cf.engine_ip, cf.regport))
343 f.write("%s://%s:%s"%(cf.engine_transport, cf.engine_ip, cf.regport))
338
344
339 with open(os.path.join(sec_dir, 'ipcontroller-client.url'), 'w') as f:
345 with open(os.path.join(sec_dir, 'ipcontroller-client.url'), 'w') as f:
340 f.write("%s://%s:%s"%(cf.client_transport, cf.client_ip, cf.regport))
346 f.write("%s://%s:%s"%(cf.client_transport, cf.client_ip, cf.regport))
341
347
342
348
343 def do_import_statements(self):
349 def do_import_statements(self):
344 statements = self.import_statements
350 statements = self.import_statements
345 for s in statements:
351 for s in statements:
346 try:
352 try:
347                 self.log.info("Executing statement: '%s'" % s)
353                 self.log.info("Executing statement: '%s'" % s)
348 exec s in globals(), locals()
354 exec s in globals(), locals()
349 except:
355 except:
350             self.log.error("Error running statement: %s" % s)
356             self.log.error("Error running statement: %s" % s)
351
357
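A sketch of how import_statements might be set in ipcontroller_config.py (module names illustrative):

c = get_config()
c.IPControllerApp.import_statements = [
    'import numpy',
    'from os.path import join',
]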
352 def forward_logging(self):
358 def forward_logging(self):
353 if self.log_url:
359 if self.log_url:
354 self.log.info("Forwarding logging to %s"%self.log_url)
360 self.log.info("Forwarding logging to %s"%self.log_url)
355 context = zmq.Context.instance()
361 context = zmq.Context.instance()
356 lsock = context.socket(zmq.PUB)
362 lsock = context.socket(zmq.PUB)
357 lsock.connect(self.log_url)
363 lsock.connect(self.log_url)
358 handler = PUBHandler(lsock)
364 handler = PUBHandler(lsock)
359 self.log.removeHandler(self._log_handler)
365 self.log.removeHandler(self._log_handler)
360 handler.root_topic = 'controller'
366 handler.root_topic = 'controller'
361 handler.setLevel(self.log_level)
367 handler.setLevel(self.log_level)
362 self.log.addHandler(handler)
368 self.log.addHandler(handler)
363 self._log_handler = handler
369 self._log_handler = handler
364 # #
370 # #
365
371
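A standalone sketch of the receiving end of forward_logging (the real consumer is the LogWatcher in iploggerapp; the url here is illustrative):

import zmq

ctx = zmq.Context.instance()
sub = ctx.socket(zmq.SUB)
sub.bind('tcp://127.0.0.1:20202')            # controllers connect their PUB socket here
sub.setsockopt(zmq.SUBSCRIBE, 'controller')  # matches the root_topic set above
while True:
    topic, msg = sub.recv_multipart()        # PUBHandler sends [topic, text]
    print('%s | %s' % (topic, msg))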
366 def initialize(self, argv=None):
372 def initialize(self, argv=None):
367 super(IPControllerApp, self).initialize(argv)
373 super(IPControllerApp, self).initialize(argv)
368 self.forward_logging()
374 self.forward_logging()
369 self.init_hub()
375 self.init_hub()
370 self.init_schedulers()
376 self.init_schedulers()
371
377
372 def start(self):
378 def start(self):
373 # Start the subprocesses:
379 # Start the subprocesses:
374 self.factory.start()
380 self.factory.start()
375 child_procs = []
381 child_procs = []
376 for child in self.children:
382 for child in self.children:
377 child.start()
383 child.start()
378 if isinstance(child, ProcessMonitoredQueue):
384 if isinstance(child, ProcessMonitoredQueue):
379 child_procs.append(child.launcher)
385 child_procs.append(child.launcher)
380 elif isinstance(child, Process):
386 elif isinstance(child, Process):
381 child_procs.append(child)
387 child_procs.append(child)
382 if child_procs:
388 if child_procs:
383 signal_children(child_procs)
389 signal_children(child_procs)
384
390
385 self.write_pid_file(overwrite=True)
391 self.write_pid_file(overwrite=True)
386
392
387 try:
393 try:
388 self.factory.loop.start()
394 self.factory.loop.start()
389 except KeyboardInterrupt:
395 except KeyboardInterrupt:
390 self.log.critical("Interrupted, Exiting...\n")
396 self.log.critical("Interrupted, Exiting...\n")
391
397
392
398
393
399
394 def launch_new_instance():
400 def launch_new_instance():
395 """Create and run the IPython controller"""
401 """Create and run the IPython controller"""
396 app = IPControllerApp.instance()
402 app = IPControllerApp.instance()
397 app.initialize()
403 app.initialize()
398 app.start()
404 app.start()
399
405
400
406
401 if __name__ == '__main__':
407 if __name__ == '__main__':
402 launch_new_instance()
408 launch_new_instance()
@@ -1,270 +1,276 b''
1 #!/usr/bin/env python
1 #!/usr/bin/env python
2 # encoding: utf-8
2 # encoding: utf-8
3 """
3 """
4 The IPython engine application
4 The IPython engine application
5
6 Authors:
7
8 * Brian Granger
9 * MinRK
10
5 """
11 """
6
12
7 #-----------------------------------------------------------------------------
13 #-----------------------------------------------------------------------------
8 # Copyright (C) 2008-2009 The IPython Development Team
14 # Copyright (C) 2008-2011 The IPython Development Team
9 #
15 #
10 # Distributed under the terms of the BSD License. The full license is in
16 # Distributed under the terms of the BSD License. The full license is in
11 # the file COPYING, distributed as part of this software.
17 # the file COPYING, distributed as part of this software.
12 #-----------------------------------------------------------------------------
18 #-----------------------------------------------------------------------------
13
19
14 #-----------------------------------------------------------------------------
20 #-----------------------------------------------------------------------------
15 # Imports
21 # Imports
16 #-----------------------------------------------------------------------------
22 #-----------------------------------------------------------------------------
17
23
18 import json
24 import json
19 import os
25 import os
20 import sys
26 import sys
21
27
22 import zmq
28 import zmq
23 from zmq.eventloop import ioloop
29 from zmq.eventloop import ioloop
24
30
25 from IPython.core.newapplication import ProfileDir
31 from IPython.core.newapplication import ProfileDir
26 from IPython.parallel.apps.baseapp import BaseParallelApplication
32 from IPython.parallel.apps.baseapp import BaseParallelApplication
27 from IPython.zmq.log import EnginePUBHandler
33 from IPython.zmq.log import EnginePUBHandler
28
34
29 from IPython.config.configurable import Configurable
35 from IPython.config.configurable import Configurable
30 from IPython.zmq.session import Session
36 from IPython.zmq.session import Session
31 from IPython.parallel.engine.engine import EngineFactory
37 from IPython.parallel.engine.engine import EngineFactory
32 from IPython.parallel.engine.streamkernel import Kernel
38 from IPython.parallel.engine.streamkernel import Kernel
33 from IPython.parallel.util import disambiguate_url
39 from IPython.parallel.util import disambiguate_url
34
40
35 from IPython.utils.importstring import import_item
41 from IPython.utils.importstring import import_item
36 from IPython.utils.traitlets import Bool, Unicode, Dict, List
42 from IPython.utils.traitlets import Bool, Unicode, Dict, List
37
43
38
44
39 #-----------------------------------------------------------------------------
45 #-----------------------------------------------------------------------------
40 # Module level variables
46 # Module level variables
41 #-----------------------------------------------------------------------------
47 #-----------------------------------------------------------------------------
42
48
43 #: The default config file name for this application
49 #: The default config file name for this application
44 default_config_file_name = u'ipengine_config.py'
50 default_config_file_name = u'ipengine_config.py'
45
51
46 _description = """Start an IPython engine for parallel computing.
52 _description = """Start an IPython engine for parallel computing.
47
53
48 IPython engines run in parallel and perform computations on behalf of a client
54 IPython engines run in parallel and perform computations on behalf of a client
49 and controller. A controller needs to be started before the engines. The
55 and controller. A controller needs to be started before the engines. The
50 engine can be configured using command line options or using a cluster
56 engine can be configured using command line options or using a cluster
51 directory. Cluster directories contain config, log and security files and are
57 directory. Cluster directories contain config, log and security files and are
52 usually located in your ipython directory and named as "cluster_<profile>".
58 usually located in your ipython directory and named as "cluster_<profile>".
53 See the `profile` and `profile_dir` options for details.
59 See the `profile` and `profile_dir` options for details.
54 """
60 """
55
61
56
62
57 #-----------------------------------------------------------------------------
63 #-----------------------------------------------------------------------------
58 # MPI configuration
64 # MPI configuration
59 #-----------------------------------------------------------------------------
65 #-----------------------------------------------------------------------------
60
66
61 mpi4py_init = """from mpi4py import MPI as mpi
67 mpi4py_init = """from mpi4py import MPI as mpi
62 mpi.size = mpi.COMM_WORLD.Get_size()
68 mpi.size = mpi.COMM_WORLD.Get_size()
63 mpi.rank = mpi.COMM_WORLD.Get_rank()
69 mpi.rank = mpi.COMM_WORLD.Get_rank()
64 """
70 """
65
71
66
72
67 pytrilinos_init = """from PyTrilinos import Epetra
73 pytrilinos_init = """from PyTrilinos import Epetra
68 class SimpleStruct:
74 class SimpleStruct:
69 pass
75 pass
70 mpi = SimpleStruct()
76 mpi = SimpleStruct()
71 mpi.rank = 0
77 mpi.rank = 0
72 mpi.size = 0
78 mpi.size = 0
73 """
79 """
74
80
75 class MPI(Configurable):
81 class MPI(Configurable):
76 """Configurable for MPI initialization"""
82 """Configurable for MPI initialization"""
77 use = Unicode('', config=True,
83 use = Unicode('', config=True,
78 help='How to enable MPI (mpi4py, pytrilinos, or empty string to disable).'
84 help='How to enable MPI (mpi4py, pytrilinos, or empty string to disable).'
79 )
85 )
80
86
81 def _on_use_changed(self, old, new):
87 def _on_use_changed(self, old, new):
82 # load default init script if it's not set
88 # load default init script if it's not set
83 if not self.init_script:
89 if not self.init_script:
84 self.init_script = self.default_inits.get(new, '')
90 self.init_script = self.default_inits.get(new, '')
85
91
86 init_script = Unicode('', config=True,
92 init_script = Unicode('', config=True,
87 help="Initialization code for MPI")
93 help="Initialization code for MPI")
88
94
89 default_inits = Dict({'mpi4py' : mpi4py_init, 'pytrilinos':pytrilinos_init},
95 default_inits = Dict({'mpi4py' : mpi4py_init, 'pytrilinos':pytrilinos_init},
90 config=True)
96 config=True)
91
97
92
98
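A sketch of enabling MPI in ipengine_config.py (or via the mpi alias defined below); the same thing on this era's command line would be `ipengine mpi=mpi4py`:

c = get_config()
c.MPI.use = 'mpi4py'   # triggers _on_use_changed, loading mpi4py_init above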
93 #-----------------------------------------------------------------------------
99 #-----------------------------------------------------------------------------
94 # Main application
100 # Main application
95 #-----------------------------------------------------------------------------
101 #-----------------------------------------------------------------------------
96
102
97
103
98 class IPEngineApp(BaseParallelApplication):
104 class IPEngineApp(BaseParallelApplication):
99
105
100 app_name = Unicode(u'ipengine')
106 app_name = Unicode(u'ipengine')
101 description = Unicode(_description)
107 description = Unicode(_description)
102 config_file_name = Unicode(default_config_file_name)
108 config_file_name = Unicode(default_config_file_name)
103 classes = List([ProfileDir, Session, EngineFactory, Kernel, MPI])
109 classes = List([ProfileDir, Session, EngineFactory, Kernel, MPI])
104
110
105 startup_script = Unicode(u'', config=True,
111 startup_script = Unicode(u'', config=True,
106 help='specify a script to be run at startup')
112 help='specify a script to be run at startup')
107 startup_command = Unicode('', config=True,
113 startup_command = Unicode('', config=True,
108 help='specify a command to be run at startup')
114 help='specify a command to be run at startup')
109
115
110 url_file = Unicode(u'', config=True,
116 url_file = Unicode(u'', config=True,
111 help="""The full location of the file containing the connection information for
117 help="""The full location of the file containing the connection information for
112 the controller. If this is not given, the file must be in the
118 the controller. If this is not given, the file must be in the
113 security directory of the cluster directory. This location is
119 security directory of the cluster directory. This location is
114 resolved using the `profile` or `profile_dir` options.""",
120 resolved using the `profile` or `profile_dir` options.""",
115 )
121 )
116
122
117 url_file_name = Unicode(u'ipcontroller-engine.json')
123 url_file_name = Unicode(u'ipcontroller-engine.json')
118 log_url = Unicode('', config=True,
124 log_url = Unicode('', config=True,
119 help="""The URL for the iploggerapp instance, for forwarding
125 help="""The URL for the iploggerapp instance, for forwarding
120 logging to a central location.""")
126 logging to a central location.""")
121
127
122 aliases = Dict(dict(
128 aliases = Dict(dict(
123 file = 'IPEngineApp.url_file',
129 file = 'IPEngineApp.url_file',
124 c = 'IPEngineApp.startup_command',
130 c = 'IPEngineApp.startup_command',
125 s = 'IPEngineApp.startup_script',
131 s = 'IPEngineApp.startup_script',
126
132
127 ident = 'Session.session',
133 ident = 'Session.session',
128 user = 'Session.username',
134 user = 'Session.username',
129 exec_key = 'Session.keyfile',
135 exec_key = 'Session.keyfile',
130
136
131 url = 'EngineFactory.url',
137 url = 'EngineFactory.url',
132 ip = 'EngineFactory.ip',
138 ip = 'EngineFactory.ip',
133 transport = 'EngineFactory.transport',
139 transport = 'EngineFactory.transport',
134 port = 'EngineFactory.regport',
140 port = 'EngineFactory.regport',
135 location = 'EngineFactory.location',
141 location = 'EngineFactory.location',
136
142
137 timeout = 'EngineFactory.timeout',
143 timeout = 'EngineFactory.timeout',
138
144
139 profile = "IPEngineApp.profile",
145 profile = "IPEngineApp.profile",
140 profile_dir = 'ProfileDir.location',
146 profile_dir = 'ProfileDir.location',
141
147
142 mpi = 'MPI.use',
148 mpi = 'MPI.use',
143
149
144 log_level = 'IPEngineApp.log_level',
150 log_level = 'IPEngineApp.log_level',
145 log_url = 'IPEngineApp.log_url'
151 log_url = 'IPEngineApp.log_url'
146 ))
152 ))
147
153
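For example (paths illustrative), `ipengine file=/path/to/ipcontroller-engine.json timeout=10` expands via these aliases to `IPEngineApp.url_file` and `EngineFactory.timeout`.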
148 # def find_key_file(self):
154 # def find_key_file(self):
149 # """Set the key file.
155 # """Set the key file.
150 #
156 #
151 #     Here we don't try to actually see if it exists or is valid, as that
157 #     Here we don't try to actually see if it exists or is valid, as that
152 #     is handled by the connection logic.
158 #     is handled by the connection logic.
153 # """
159 # """
154 # config = self.master_config
160 # config = self.master_config
155 # # Find the actual controller key file
161 # # Find the actual controller key file
156 # if not config.Global.key_file:
162 # if not config.Global.key_file:
157 # try_this = os.path.join(
163 # try_this = os.path.join(
158 # config.Global.profile_dir,
164 # config.Global.profile_dir,
159 # config.Global.security_dir,
165 # config.Global.security_dir,
160 # config.Global.key_file_name
166 # config.Global.key_file_name
161 # )
167 # )
162 # config.Global.key_file = try_this
168 # config.Global.key_file = try_this
163
169
164 def find_url_file(self):
170 def find_url_file(self):
165         """Set the url file.
171         """Set the url file.
166
172
167         Here we don't try to actually see if it exists or is valid, as that
173         Here we don't try to actually see if it exists or is valid, as that
168         is handled by the connection logic.
174         is handled by the connection logic.
169 """
175 """
170 config = self.config
176 config = self.config
171         # Find the actual controller url file
177         # Find the actual controller url file
172 if not self.url_file:
178 if not self.url_file:
173 self.url_file = os.path.join(
179 self.url_file = os.path.join(
174 self.profile_dir.security_dir,
180 self.profile_dir.security_dir,
175 self.url_file_name
181 self.url_file_name
176 )
182 )
177 def init_engine(self):
183 def init_engine(self):
178 # This is the working dir by now.
184 # This is the working dir by now.
179 sys.path.insert(0, '')
185 sys.path.insert(0, '')
180 config = self.config
186 config = self.config
181 # print config
187 # print config
182 self.find_url_file()
188 self.find_url_file()
183
189
184 # if os.path.exists(config.Global.key_file) and config.Global.secure:
190 # if os.path.exists(config.Global.key_file) and config.Global.secure:
185 # config.SessionFactory.exec_key = config.Global.key_file
191 # config.SessionFactory.exec_key = config.Global.key_file
186 if os.path.exists(self.url_file):
192 if os.path.exists(self.url_file):
187 with open(self.url_file) as f:
193 with open(self.url_file) as f:
188 d = json.loads(f.read())
194 d = json.loads(f.read())
189 for k,v in d.iteritems():
195 for k,v in d.iteritems():
190 if isinstance(v, unicode):
196 if isinstance(v, unicode):
191 d[k] = v.encode()
197 d[k] = v.encode()
192 if d['exec_key']:
198 if d['exec_key']:
193 config.Session.key = d['exec_key']
199 config.Session.key = d['exec_key']
194 d['url'] = disambiguate_url(d['url'], d['location'])
200 d['url'] = disambiguate_url(d['url'], d['location'])
195 config.EngineFactory.url = d['url']
201 config.EngineFactory.url = d['url']
196 config.EngineFactory.location = d['location']
202 config.EngineFactory.location = d['location']
197
203
198 try:
204 try:
199 exec_lines = config.Kernel.exec_lines
205 exec_lines = config.Kernel.exec_lines
200 except AttributeError:
206 except AttributeError:
201 config.Kernel.exec_lines = []
207 config.Kernel.exec_lines = []
202 exec_lines = config.Kernel.exec_lines
208 exec_lines = config.Kernel.exec_lines
203
209
204 if self.startup_script:
210 if self.startup_script:
205 enc = sys.getfilesystemencoding() or 'utf8'
211 enc = sys.getfilesystemencoding() or 'utf8'
206 cmd="execfile(%r)"%self.startup_script.encode(enc)
212 cmd="execfile(%r)"%self.startup_script.encode(enc)
207 exec_lines.append(cmd)
213 exec_lines.append(cmd)
208 if self.startup_command:
214 if self.startup_command:
209 exec_lines.append(self.startup_command)
215 exec_lines.append(self.startup_command)
210
216
211 # Create the underlying shell class and Engine
217 # Create the underlying shell class and Engine
212 # shell_class = import_item(self.master_config.Global.shell_class)
218 # shell_class = import_item(self.master_config.Global.shell_class)
213 # print self.config
219 # print self.config
214 try:
220 try:
215 self.engine = EngineFactory(config=config, log=self.log)
221 self.engine = EngineFactory(config=config, log=self.log)
216 except:
222 except:
217 self.log.error("Couldn't start the Engine", exc_info=True)
223 self.log.error("Couldn't start the Engine", exc_info=True)
218 self.exit(1)
224 self.exit(1)
219
225
220 def forward_logging(self):
226 def forward_logging(self):
221 if self.log_url:
227 if self.log_url:
222 self.log.info("Forwarding logging to %s"%self.log_url)
228 self.log.info("Forwarding logging to %s"%self.log_url)
223 context = self.engine.context
229 context = self.engine.context
224 lsock = context.socket(zmq.PUB)
230 lsock = context.socket(zmq.PUB)
225 lsock.connect(self.log_url)
231 lsock.connect(self.log_url)
226 self.log.removeHandler(self._log_handler)
232 self.log.removeHandler(self._log_handler)
227 handler = EnginePUBHandler(self.engine, lsock)
233 handler = EnginePUBHandler(self.engine, lsock)
228 handler.setLevel(self.log_level)
234 handler.setLevel(self.log_level)
229 self.log.addHandler(handler)
235 self.log.addHandler(handler)
230 self._log_handler = handler
236 self._log_handler = handler
231 #
237 #
232 def init_mpi(self):
238 def init_mpi(self):
233 global mpi
239 global mpi
234 self.mpi = MPI(config=self.config)
240 self.mpi = MPI(config=self.config)
235
241
236 mpi_import_statement = self.mpi.init_script
242 mpi_import_statement = self.mpi.init_script
237 if mpi_import_statement:
243 if mpi_import_statement:
238 try:
244 try:
239 self.log.info("Initializing MPI:")
245 self.log.info("Initializing MPI:")
240 self.log.info(mpi_import_statement)
246 self.log.info(mpi_import_statement)
241 exec mpi_import_statement in globals()
247 exec mpi_import_statement in globals()
242 except:
248 except:
243 mpi = None
249 mpi = None
244 else:
250 else:
245 mpi = None
251 mpi = None
246
252
247 def initialize(self, argv=None):
253 def initialize(self, argv=None):
248 super(IPEngineApp, self).initialize(argv)
254 super(IPEngineApp, self).initialize(argv)
249 self.init_mpi()
255 self.init_mpi()
250 self.init_engine()
256 self.init_engine()
251 self.forward_logging()
257 self.forward_logging()
252
258
253 def start(self):
259 def start(self):
254 self.engine.start()
260 self.engine.start()
255 try:
261 try:
256 self.engine.loop.start()
262 self.engine.loop.start()
257 except KeyboardInterrupt:
263 except KeyboardInterrupt:
258 self.log.critical("Engine Interrupted, shutting down...\n")
264 self.log.critical("Engine Interrupted, shutting down...\n")
259
265
260
266
261 def launch_new_instance():
267 def launch_new_instance():
262 """Create and run the IPython engine"""
268 """Create and run the IPython engine"""
263 app = IPEngineApp.instance()
269 app = IPEngineApp.instance()
264 app.initialize()
270 app.initialize()
265 app.start()
271 app.start()
266
272
267
273
268 if __name__ == '__main__':
274 if __name__ == '__main__':
269 launch_new_instance()
275 launch_new_instance()
270
276
@@ -1,96 +1,101 b''
1 #!/usr/bin/env python
1 #!/usr/bin/env python
2 # encoding: utf-8
2 # encoding: utf-8
3 """
3 """
4 A simple IPython logger application
4 A simple IPython logger application
5
6 Authors:
7
8 * MinRK
9
5 """
10 """
6
11
7 #-----------------------------------------------------------------------------
12 #-----------------------------------------------------------------------------
8 # Copyright (C) 2011 The IPython Development Team
13 # Copyright (C) 2011 The IPython Development Team
9 #
14 #
10 # Distributed under the terms of the BSD License. The full license is in
15 # Distributed under the terms of the BSD License. The full license is in
11 # the file COPYING, distributed as part of this software.
16 # the file COPYING, distributed as part of this software.
12 #-----------------------------------------------------------------------------
17 #-----------------------------------------------------------------------------
13
18
14 #-----------------------------------------------------------------------------
19 #-----------------------------------------------------------------------------
15 # Imports
20 # Imports
16 #-----------------------------------------------------------------------------
21 #-----------------------------------------------------------------------------
17
22
18 import os
23 import os
19 import sys
24 import sys
20
25
21 import zmq
26 import zmq
22
27
23 from IPython.core.newapplication import ProfileDir
28 from IPython.core.newapplication import ProfileDir
24 from IPython.utils.traitlets import Bool, Dict, Unicode
29 from IPython.utils.traitlets import Bool, Dict, Unicode
25
30
26 from IPython.parallel.apps.baseapp import (
31 from IPython.parallel.apps.baseapp import (
27 BaseParallelApplication,
32 BaseParallelApplication,
28 base_aliases
33 base_aliases
29 )
34 )
30 from IPython.parallel.apps.logwatcher import LogWatcher
35 from IPython.parallel.apps.logwatcher import LogWatcher
31
36
32 #-----------------------------------------------------------------------------
37 #-----------------------------------------------------------------------------
33 # Module level variables
38 # Module level variables
34 #-----------------------------------------------------------------------------
39 #-----------------------------------------------------------------------------
35
40
36 #: The default config file name for this application
41 #: The default config file name for this application
37 default_config_file_name = u'iplogger_config.py'
42 default_config_file_name = u'iplogger_config.py'
38
43
39 _description = """Start an IPython logger for parallel computing.
44 _description = """Start an IPython logger for parallel computing.
40
45
41 IPython controllers and engines (and your own processes) can broadcast log messages
46 IPython controllers and engines (and your own processes) can broadcast log messages
42 by registering a `zmq.log.handlers.PUBHandler` with the `logging` module. The
47 by registering a `zmq.log.handlers.PUBHandler` with the `logging` module. The
43 logger can be configured using command line options or using a cluster
48 logger can be configured using command line options or using a cluster
44 directory. Cluster directories contain config, log and security files and are
49 directory. Cluster directories contain config, log and security files and are
45 usually located in your ipython directory and named as "cluster_<profile>".
50 usually located in your ipython directory and named as "cluster_<profile>".
46 See the `profile` and `profile_dir` options for details.
51 See the `profile` and `profile_dir` options for details.
47 """
52 """
48
53
49
54
50 #-----------------------------------------------------------------------------
55 #-----------------------------------------------------------------------------
51 # Main application
56 # Main application
52 #-----------------------------------------------------------------------------
57 #-----------------------------------------------------------------------------
53 aliases = {}
58 aliases = {}
54 aliases.update(base_aliases)
59 aliases.update(base_aliases)
55 aliases.update(dict(url='LogWatcher.url', topics='LogWatcher.topics'))
60 aliases.update(dict(url='LogWatcher.url', topics='LogWatcher.topics'))
56
61
57 class IPLoggerApp(BaseParallelApplication):
62 class IPLoggerApp(BaseParallelApplication):
58
63
59     name = u'iplogger'
64     name = u'iplogger'
60 description = _description
65 description = _description
61 config_file_name = Unicode(default_config_file_name)
66 config_file_name = Unicode(default_config_file_name)
62
67
63 classes = [LogWatcher, ProfileDir]
68 classes = [LogWatcher, ProfileDir]
64 aliases = Dict(aliases)
69 aliases = Dict(aliases)
65
70
66 def initialize(self, argv=None):
71 def initialize(self, argv=None):
67 super(IPLoggerApp, self).initialize(argv)
72 super(IPLoggerApp, self).initialize(argv)
68 self.init_watcher()
73 self.init_watcher()
69
74
70 def init_watcher(self):
75 def init_watcher(self):
71 try:
76 try:
72 self.watcher = LogWatcher(config=self.config, log=self.log)
77 self.watcher = LogWatcher(config=self.config, log=self.log)
73 except:
78 except:
74 self.log.error("Couldn't start the LogWatcher", exc_info=True)
79 self.log.error("Couldn't start the LogWatcher", exc_info=True)
75 self.exit(1)
80 self.exit(1)
76 self.log.info("Listening for log messages on %r"%self.watcher.url)
81 self.log.info("Listening for log messages on %r"%self.watcher.url)
77
82
78
83
79 def start(self):
84 def start(self):
80 self.watcher.start()
85 self.watcher.start()
81 try:
86 try:
82 self.watcher.loop.start()
87 self.watcher.loop.start()
83 except KeyboardInterrupt:
88 except KeyboardInterrupt:
84 self.log.critical("Logging Interrupted, shutting down...\n")
89 self.log.critical("Logging Interrupted, shutting down...\n")
85
90
86
91
87 def launch_new_instance():
92 def launch_new_instance():
88 """Create and run the IPython LogWatcher"""
93 """Create and run the IPython LogWatcher"""
89 app = IPLoggerApp.instance()
94 app = IPLoggerApp.instance()
90 app.initialize()
95 app.initialize()
91 app.start()
96 app.start()
92
97
93
98
94 if __name__ == '__main__':
99 if __name__ == '__main__':
95 launch_new_instance()
100 launch_new_instance()
96
101
@@ -1,1069 +1,1074 b''
1 #!/usr/bin/env python
1 #!/usr/bin/env python
2 # encoding: utf-8
2 # encoding: utf-8
3 """
3 """
4 Facilities for launching IPython processes asynchronously.
4 Facilities for launching IPython processes asynchronously.
5
6 Authors:
7
8 * Brian Granger
9 * MinRK
5 """
10 """
6
11
7 #-----------------------------------------------------------------------------
12 #-----------------------------------------------------------------------------
8 # Copyright (C) 2008-2009 The IPython Development Team
13 # Copyright (C) 2008-2011 The IPython Development Team
9 #
14 #
10 # Distributed under the terms of the BSD License. The full license is in
15 # Distributed under the terms of the BSD License. The full license is in
11 # the file COPYING, distributed as part of this software.
16 # the file COPYING, distributed as part of this software.
12 #-----------------------------------------------------------------------------
17 #-----------------------------------------------------------------------------
13
18
14 #-----------------------------------------------------------------------------
19 #-----------------------------------------------------------------------------
15 # Imports
20 # Imports
16 #-----------------------------------------------------------------------------
21 #-----------------------------------------------------------------------------
17
22
18 import copy
23 import copy
19 import logging
24 import logging
20 import os
25 import os
21 import re
26 import re
22 import stat
27 import stat
23
28
24 # signal imports, handling various platforms, versions
29 # signal imports, handling various platforms, versions
25
30
26 from signal import SIGINT, SIGTERM
31 from signal import SIGINT, SIGTERM
27 try:
32 try:
28 from signal import SIGKILL
33 from signal import SIGKILL
29 except ImportError:
34 except ImportError:
30 # Windows
35 # Windows
31 SIGKILL=SIGTERM
36 SIGKILL=SIGTERM
32
37
33 try:
38 try:
34 # Windows >= 2.7, 3.2
39 # Windows >= 2.7, 3.2
35 from signal import CTRL_C_EVENT as SIGINT
40 from signal import CTRL_C_EVENT as SIGINT
36 except ImportError:
41 except ImportError:
37 pass
42 pass
38
43
39 from subprocess import Popen, PIPE, STDOUT
44 from subprocess import Popen, PIPE, STDOUT
40 try:
45 try:
41 from subprocess import check_output
46 from subprocess import check_output
42 except ImportError:
47 except ImportError:
43 # pre-2.7, define check_output with Popen
48 # pre-2.7, define check_output with Popen
44 def check_output(*args, **kwargs):
49 def check_output(*args, **kwargs):
45 kwargs.update(dict(stdout=PIPE))
50 kwargs.update(dict(stdout=PIPE))
46 p = Popen(*args, **kwargs)
51 p = Popen(*args, **kwargs)
47 out,err = p.communicate()
52 out,err = p.communicate()
48 return out
53 return out
49
54
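Usage sketch: the stdlib version and the pre-2.7 fallback above behave the same for simple calls:

import sys

out = check_output([sys.executable, '-c', 'print("ok")'])
print(out)   # 'ok\n'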
50 from zmq.eventloop import ioloop
55 from zmq.eventloop import ioloop
51
56
52 from IPython.config.application import Application
57 from IPython.config.application import Application
53 from IPython.config.configurable import LoggingConfigurable
58 from IPython.config.configurable import LoggingConfigurable
54 from IPython.utils.text import EvalFormatter
59 from IPython.utils.text import EvalFormatter
55 from IPython.utils.traitlets import Any, Int, List, Unicode, Dict, Instance
60 from IPython.utils.traitlets import Any, Int, List, Unicode, Dict, Instance
56 from IPython.utils.path import get_ipython_module_path
61 from IPython.utils.path import get_ipython_module_path
57 from IPython.utils.process import find_cmd, pycmd2argv, FindCmdError
62 from IPython.utils.process import find_cmd, pycmd2argv, FindCmdError
58
63
59 from .win32support import forward_read_events
64 from .win32support import forward_read_events
60
65
61 from .winhpcjob import IPControllerTask, IPEngineTask, IPControllerJob, IPEngineSetJob
66 from .winhpcjob import IPControllerTask, IPEngineTask, IPControllerJob, IPEngineSetJob
62
67
63 WINDOWS = os.name == 'nt'
68 WINDOWS = os.name == 'nt'
64
69
65 #-----------------------------------------------------------------------------
70 #-----------------------------------------------------------------------------
66 # Paths to the kernel apps
71 # Paths to the kernel apps
67 #-----------------------------------------------------------------------------
72 #-----------------------------------------------------------------------------
68
73
69
74
70 ipcluster_cmd_argv = pycmd2argv(get_ipython_module_path(
75 ipcluster_cmd_argv = pycmd2argv(get_ipython_module_path(
71 'IPython.parallel.apps.ipclusterapp'
76 'IPython.parallel.apps.ipclusterapp'
72 ))
77 ))
73
78
74 ipengine_cmd_argv = pycmd2argv(get_ipython_module_path(
79 ipengine_cmd_argv = pycmd2argv(get_ipython_module_path(
75 'IPython.parallel.apps.ipengineapp'
80 'IPython.parallel.apps.ipengineapp'
76 ))
81 ))
77
82
78 ipcontroller_cmd_argv = pycmd2argv(get_ipython_module_path(
83 ipcontroller_cmd_argv = pycmd2argv(get_ipython_module_path(
79 'IPython.parallel.apps.ipcontrollerapp'
84 'IPython.parallel.apps.ipcontrollerapp'
80 ))
85 ))
81
86
82 #-----------------------------------------------------------------------------
87 #-----------------------------------------------------------------------------
83 # Base launchers and errors
88 # Base launchers and errors
84 #-----------------------------------------------------------------------------
89 #-----------------------------------------------------------------------------
85
90
86
91
87 class LauncherError(Exception):
92 class LauncherError(Exception):
88 pass
93 pass
89
94
90
95
91 class ProcessStateError(LauncherError):
96 class ProcessStateError(LauncherError):
92 pass
97 pass
93
98
94
99
95 class UnknownStatus(LauncherError):
100 class UnknownStatus(LauncherError):
96 pass
101 pass
97
102
98
103
99 class BaseLauncher(LoggingConfigurable):
104 class BaseLauncher(LoggingConfigurable):
100     """An abstraction for starting, stopping and signaling a process."""
105     """An abstraction for starting, stopping and signaling a process."""
101
106
102 # In all of the launchers, the work_dir is where child processes will be
107 # In all of the launchers, the work_dir is where child processes will be
103     # run. This will usually be the profile_dir, but may not be. Any work_dir
108     # run. This will usually be the profile_dir, but may not be. Any work_dir
104 # passed into the __init__ method will override the config value.
109 # passed into the __init__ method will override the config value.
105 # This should not be used to set the work_dir for the actual engine
110 # This should not be used to set the work_dir for the actual engine
106 # and controller. Instead, use their own config files or the
111 # and controller. Instead, use their own config files or the
107 # controller_args, engine_args attributes of the launchers to add
112 # controller_args, engine_args attributes of the launchers to add
108 # the work_dir option.
113 # the work_dir option.
109 work_dir = Unicode(u'.')
114 work_dir = Unicode(u'.')
110 loop = Instance('zmq.eventloop.ioloop.IOLoop')
115 loop = Instance('zmq.eventloop.ioloop.IOLoop')
111
116
112 start_data = Any()
117 start_data = Any()
113 stop_data = Any()
118 stop_data = Any()
114
119
115 def _loop_default(self):
120 def _loop_default(self):
116 return ioloop.IOLoop.instance()
121 return ioloop.IOLoop.instance()
117
122
118 def __init__(self, work_dir=u'.', config=None, **kwargs):
123 def __init__(self, work_dir=u'.', config=None, **kwargs):
119 super(BaseLauncher, self).__init__(work_dir=work_dir, config=config, **kwargs)
124 super(BaseLauncher, self).__init__(work_dir=work_dir, config=config, **kwargs)
120 self.state = 'before' # can be before, running, after
125 self.state = 'before' # can be before, running, after
121 self.stop_callbacks = []
126 self.stop_callbacks = []
122 self.start_data = None
127 self.start_data = None
123 self.stop_data = None
128 self.stop_data = None
124
129
125 @property
130 @property
126 def args(self):
131 def args(self):
127 """A list of cmd and args that will be used to start the process.
132 """A list of cmd and args that will be used to start the process.
128
133
129 This is what is passed to :func:`spawnProcess` and the first element
134 This is what is passed to :func:`spawnProcess` and the first element
130 will be the process name.
135 will be the process name.
131 """
136 """
132 return self.find_args()
137 return self.find_args()
133
138
134 def find_args(self):
139 def find_args(self):
135 """The ``.args`` property calls this to find the args list.
140 """The ``.args`` property calls this to find the args list.
136
141
137         Subclasses should implement this to construct the cmd and args.
142         Subclasses should implement this to construct the cmd and args.
138 """
143 """
139 raise NotImplementedError('find_args must be implemented in a subclass')
144 raise NotImplementedError('find_args must be implemented in a subclass')
140
145
141 @property
146 @property
142 def arg_str(self):
147 def arg_str(self):
143 """The string form of the program arguments."""
148 """The string form of the program arguments."""
144 return ' '.join(self.args)
149 return ' '.join(self.args)
145
150
146 @property
151 @property
147 def running(self):
152 def running(self):
148 """Am I running."""
153 """Am I running."""
149 if self.state == 'running':
154 if self.state == 'running':
150 return True
155 return True
151 else:
156 else:
152 return False
157 return False
153
158
154 def start(self):
159 def start(self):
155 """Start the process.
160 """Start the process.
156
161
157 This must return a deferred that fires with information about the
162 This must return a deferred that fires with information about the
158 process starting (like a pid, job id, etc.).
163 process starting (like a pid, job id, etc.).
159 """
164 """
160 raise NotImplementedError('start must be implemented in a subclass')
165 raise NotImplementedError('start must be implemented in a subclass')
161
166
162 def stop(self):
167 def stop(self):
163 """Stop the process and notify observers of stopping.
168 """Stop the process and notify observers of stopping.
164
169
165 This must return a deferred that fires with information about the
170 This must return a deferred that fires with information about the
166 processing stopping, like errors that occur while the process is
171 processing stopping, like errors that occur while the process is
167 attempting to be shut down. This deferred won't fire when the process
172 attempting to be shut down. This deferred won't fire when the process
168 actually stops. To observe the actual process stopping, see
173 actually stops. To observe the actual process stopping, see
169 :func:`observe_stop`.
174 :func:`observe_stop`.
170 """
175 """
171 raise NotImplementedError('stop must be implemented in a subclass')
176 raise NotImplementedError('stop must be implemented in a subclass')
172
177
173 def on_stop(self, f):
178 def on_stop(self, f):
174 """Get a deferred that will fire when the process stops.
179 """Get a deferred that will fire when the process stops.
175
180
176 The deferred will fire with data that contains information about
181 The deferred will fire with data that contains information about
177 the exit status of the process.
182 the exit status of the process.
178 """
183 """
179 if self.state=='after':
184 if self.state=='after':
180 return f(self.stop_data)
185 return f(self.stop_data)
181 else:
186 else:
182 self.stop_callbacks.append(f)
187 self.stop_callbacks.append(f)
183
188
184 def notify_start(self, data):
189 def notify_start(self, data):
185 """Call this to trigger startup actions.
190 """Call this to trigger startup actions.
186
191
187 This logs the process startup and sets the state to 'running'. It is
192 This logs the process startup and sets the state to 'running'. It is
188 a pass-through so it can be used as a callback.
193 a pass-through so it can be used as a callback.
189 """
194 """
190
195
191 self.log.info('Process %r started: %r' % (self.args[0], data))
196 self.log.info('Process %r started: %r' % (self.args[0], data))
192 self.start_data = data
197 self.start_data = data
193 self.state = 'running'
198 self.state = 'running'
194 return data
199 return data
195
200
196 def notify_stop(self, data):
201 def notify_stop(self, data):
197 """Call this to trigger process stop actions.
202 """Call this to trigger process stop actions.
198
203
199 This logs the process stopping and sets the state to 'after'. Call
204 This logs the process stopping and sets the state to 'after'. Call
200 this to trigger all the deferreds from :func:`observe_stop`."""
205 this to trigger all the deferreds from :func:`observe_stop`."""
201
206
202 self.log.info('Process %r stopped: %r' % (self.args[0], data))
207 self.log.info('Process %r stopped: %r' % (self.args[0], data))
203 self.stop_data = data
208 self.stop_data = data
204 self.state = 'after'
209 self.state = 'after'
205 for i in range(len(self.stop_callbacks)):
210 for i in range(len(self.stop_callbacks)):
206 d = self.stop_callbacks.pop()
211 d = self.stop_callbacks.pop()
207 d(data)
212 d(data)
208 return data
213 return data
209
214
210 def signal(self, sig):
215 def signal(self, sig):
211 """Signal the process.
216 """Signal the process.
212
217
213 Return a semi-meaningless deferred after signaling the process.
218 Return a semi-meaningless deferred after signaling the process.
214
219
215 Parameters
220 Parameters
216 ----------
221 ----------
217 sig : str or int
222 sig : str or int
218 'KILL', 'INT', etc., or any signal number
223 'KILL', 'INT', etc., or any signal number
219 """
224 """
220 raise NotImplementedError('signal must be implemented in a subclass')
225 raise NotImplementedError('signal must be implemented in a subclass')
221
226
222
227
223 #-----------------------------------------------------------------------------
228 #-----------------------------------------------------------------------------
224 # Local process launchers
229 # Local process launchers
225 #-----------------------------------------------------------------------------
230 #-----------------------------------------------------------------------------
226
231
227
232
228 class LocalProcessLauncher(BaseLauncher):
233 class LocalProcessLauncher(BaseLauncher):
229 """Start and stop an external process in an asynchronous manner.
234 """Start and stop an external process in an asynchronous manner.
230
235
231 This will launch the external process with a working directory of
236 This will launch the external process with a working directory of
232 ``self.work_dir``.
237 ``self.work_dir``.
233 """
238 """
234
239
235     # This is used to construct self.args, which is passed to
240     # This is used to construct self.args, which is passed to
236 # spawnProcess.
241 # spawnProcess.
237 cmd_and_args = List([])
242 cmd_and_args = List([])
238 poll_frequency = Int(100) # in ms
243 poll_frequency = Int(100) # in ms
239
244
240 def __init__(self, work_dir=u'.', config=None, **kwargs):
245 def __init__(self, work_dir=u'.', config=None, **kwargs):
241 super(LocalProcessLauncher, self).__init__(
246 super(LocalProcessLauncher, self).__init__(
242 work_dir=work_dir, config=config, **kwargs
247 work_dir=work_dir, config=config, **kwargs
243 )
248 )
244 self.process = None
249 self.process = None
245 self.start_deferred = None
250 self.start_deferred = None
246 self.poller = None
251 self.poller = None
247
252
248 def find_args(self):
253 def find_args(self):
249 return self.cmd_and_args
254 return self.cmd_and_args
250
255
251 def start(self):
256 def start(self):
252 if self.state == 'before':
257 if self.state == 'before':
253 self.process = Popen(self.args,
258 self.process = Popen(self.args,
254 stdout=PIPE,stderr=PIPE,stdin=PIPE,
259 stdout=PIPE,stderr=PIPE,stdin=PIPE,
255 env=os.environ,
260 env=os.environ,
256 cwd=self.work_dir
261 cwd=self.work_dir
257 )
262 )
258 if WINDOWS:
263 if WINDOWS:
259 self.stdout = forward_read_events(self.process.stdout)
264 self.stdout = forward_read_events(self.process.stdout)
260 self.stderr = forward_read_events(self.process.stderr)
265 self.stderr = forward_read_events(self.process.stderr)
261 else:
266 else:
262 self.stdout = self.process.stdout.fileno()
267 self.stdout = self.process.stdout.fileno()
263 self.stderr = self.process.stderr.fileno()
268 self.stderr = self.process.stderr.fileno()
264 self.loop.add_handler(self.stdout, self.handle_stdout, self.loop.READ)
269 self.loop.add_handler(self.stdout, self.handle_stdout, self.loop.READ)
265 self.loop.add_handler(self.stderr, self.handle_stderr, self.loop.READ)
270 self.loop.add_handler(self.stderr, self.handle_stderr, self.loop.READ)
266 self.poller = ioloop.PeriodicCallback(self.poll, self.poll_frequency, self.loop)
271 self.poller = ioloop.PeriodicCallback(self.poll, self.poll_frequency, self.loop)
267 self.poller.start()
272 self.poller.start()
268 self.notify_start(self.process.pid)
273 self.notify_start(self.process.pid)
269 else:
274 else:
270 s = 'The process was already started and has state: %r' % self.state
275 s = 'The process was already started and has state: %r' % self.state
271 raise ProcessStateError(s)
276 raise ProcessStateError(s)
272
277
273 def stop(self):
278 def stop(self):
274 return self.interrupt_then_kill()
279 return self.interrupt_then_kill()
275
280
276 def signal(self, sig):
281 def signal(self, sig):
277 if self.state == 'running':
282 if self.state == 'running':
278 if WINDOWS and sig != SIGINT:
283 if WINDOWS and sig != SIGINT:
279 # use Windows tree-kill for better child cleanup
284 # use Windows tree-kill for better child cleanup
280 check_output(['taskkill', '-pid', str(self.process.pid), '-t', '-f'])
285 check_output(['taskkill', '-pid', str(self.process.pid), '-t', '-f'])
281 else:
286 else:
282 self.process.send_signal(sig)
287 self.process.send_signal(sig)
283
288
284 def interrupt_then_kill(self, delay=2.0):
289 def interrupt_then_kill(self, delay=2.0):
285 """Send INT, wait a delay and then send KILL."""
290 """Send INT, wait a delay and then send KILL."""
286 try:
291 try:
287 self.signal(SIGINT)
292 self.signal(SIGINT)
288 except Exception:
293 except Exception:
289 self.log.debug("interrupt failed")
294 self.log.debug("interrupt failed")
290 pass
295 pass
291 self.killer = ioloop.DelayedCallback(lambda : self.signal(SIGKILL), delay*1000, self.loop)
296 self.killer = ioloop.DelayedCallback(lambda : self.signal(SIGKILL), delay*1000, self.loop)
292 self.killer.start()
297 self.killer.start()
293
298
294 # callbacks, etc:
299 # callbacks, etc:
295
300
296 def handle_stdout(self, fd, events):
301 def handle_stdout(self, fd, events):
297 if WINDOWS:
302 if WINDOWS:
298 line = self.stdout.recv()
303 line = self.stdout.recv()
299 else:
304 else:
300 line = self.process.stdout.readline()
305 line = self.process.stdout.readline()
301 # a stopped process will be readable but return empty strings
306 # a stopped process will be readable but return empty strings
302 if line:
307 if line:
303 self.log.info(line[:-1])
308 self.log.info(line[:-1])
304 else:
309 else:
305 self.poll()
310 self.poll()
306
311
307 def handle_stderr(self, fd, events):
312 def handle_stderr(self, fd, events):
308 if WINDOWS:
313 if WINDOWS:
309 line = self.stderr.recv()
314 line = self.stderr.recv()
310 else:
315 else:
311 line = self.process.stderr.readline()
316 line = self.process.stderr.readline()
312 # a stopped process will be readable but return empty strings
317 # a stopped process will be readable but return empty strings
313 if line:
318 if line:
314 self.log.error(line[:-1])
319 self.log.error(line[:-1])
315 else:
320 else:
316 self.poll()
321 self.poll()
317
322
318 def poll(self):
323 def poll(self):
319 status = self.process.poll()
324 status = self.process.poll()
320 if status is not None:
325 if status is not None:
321 self.poller.stop()
326 self.poller.stop()
322 self.loop.remove_handler(self.stdout)
327 self.loop.remove_handler(self.stdout)
323 self.loop.remove_handler(self.stderr)
328 self.loop.remove_handler(self.stderr)
324 self.notify_stop(dict(exit_code=status, pid=self.process.pid))
329 self.notify_stop(dict(exit_code=status, pid=self.process.pid))
325 return status
330 return status
326
331
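# Usage sketch (illustrative, not part of the module): run a one-off local
# command under the zmq IOLoop and log its output. This assumes only the
# imports this module already makes (Popen, ioloop, etc.):
#
#     from zmq.eventloop import ioloop
#     launcher = LocalProcessLauncher(work_dir=u'.')
#     launcher.cmd_and_args = ['echo', 'hello']
#     launcher.on_stop(lambda data: ioloop.IOLoop.instance().stop())
#     launcher.start()
#     ioloop.IOLoop.instance().start()   # runs until poll() reports the exit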
327 class LocalControllerLauncher(LocalProcessLauncher):
332 class LocalControllerLauncher(LocalProcessLauncher):
328 """Launch a controller as a regular external process."""
333 """Launch a controller as a regular external process."""
329
334
330 controller_cmd = List(ipcontroller_cmd_argv, config=True,
335 controller_cmd = List(ipcontroller_cmd_argv, config=True,
331 help="""Popen command to launch ipcontroller.""")
336 help="""Popen command to launch ipcontroller.""")
332 # Command line arguments to ipcontroller.
337 # Command line arguments to ipcontroller.
333 controller_args = List(['--log-to-file','log_level=%i'%logging.INFO], config=True,
338 controller_args = List(['--log-to-file','log_level=%i'%logging.INFO], config=True,
334 help="""command-line args to pass to ipcontroller""")
339 help="""command-line args to pass to ipcontroller""")
335
340
336 def find_args(self):
341 def find_args(self):
337 return self.controller_cmd + self.controller_args
342 return self.controller_cmd + self.controller_args
338
343
339 def start(self, profile_dir):
344 def start(self, profile_dir):
340 """Start the controller by profile_dir."""
345 """Start the controller by profile_dir."""
341 self.controller_args.extend(['profile_dir=%s'%profile_dir])
346 self.controller_args.extend(['profile_dir=%s'%profile_dir])
342 self.profile_dir = unicode(profile_dir)
347 self.profile_dir = unicode(profile_dir)
343 self.log.info("Starting LocalControllerLauncher: %r" % self.args)
348 self.log.info("Starting LocalControllerLauncher: %r" % self.args)
344 return super(LocalControllerLauncher, self).start()
349 return super(LocalControllerLauncher, self).start()
345
350
346
351
347 class LocalEngineLauncher(LocalProcessLauncher):
352 class LocalEngineLauncher(LocalProcessLauncher):
348 """Launch a single engine as a regular externall process."""
353 """Launch a single engine as a regular externall process."""
349
354
350 engine_cmd = List(ipengine_cmd_argv, config=True,
355 engine_cmd = List(ipengine_cmd_argv, config=True,
351 help="""command to launch the Engine.""")
356 help="""command to launch the Engine.""")
352 # Command line arguments for ipengine.
357 # Command line arguments for ipengine.
353 engine_args = List(['--log-to-file','log_level=%i'%logging.INFO], config=True,
358 engine_args = List(['--log-to-file','log_level=%i'%logging.INFO], config=True,
354 help="command-line arguments to pass to ipengine"
359 help="command-line arguments to pass to ipengine"
355 )
360 )
356
361
357 def find_args(self):
362 def find_args(self):
358 return self.engine_cmd + self.engine_args
363 return self.engine_cmd + self.engine_args
359
364
360 def start(self, profile_dir):
365 def start(self, profile_dir):
361 """Start the engine by profile_dir."""
366 """Start the engine by profile_dir."""
362 self.engine_args.extend(['profile_dir=%s'%profile_dir])
367 self.engine_args.extend(['profile_dir=%s'%profile_dir])
363 self.profile_dir = unicode(profile_dir)
368 self.profile_dir = unicode(profile_dir)
364 return super(LocalEngineLauncher, self).start()
369 return super(LocalEngineLauncher, self).start()
365
370
366
371
367 class LocalEngineSetLauncher(BaseLauncher):
372 class LocalEngineSetLauncher(BaseLauncher):
368 """Launch a set of engines as regular external processes."""
373 """Launch a set of engines as regular external processes."""
369
374
370 # Command line arguments for ipengine.
375 # Command line arguments for ipengine.
371 engine_args = List(
376 engine_args = List(
372 ['--log-to-file','log_level=%i'%logging.INFO], config=True,
377 ['--log-to-file','log_level=%i'%logging.INFO], config=True,
373 help="command-line arguments to pass to ipengine"
378 help="command-line arguments to pass to ipengine"
374 )
379 )
375 # launcher class
380 # launcher class
376 launcher_class = LocalEngineLauncher
381 launcher_class = LocalEngineLauncher
377
382
378 launchers = Dict()
383 launchers = Dict()
379 stop_data = Dict()
384 stop_data = Dict()
380
385
381 def __init__(self, work_dir=u'.', config=None, **kwargs):
386 def __init__(self, work_dir=u'.', config=None, **kwargs):
382 super(LocalEngineSetLauncher, self).__init__(
387 super(LocalEngineSetLauncher, self).__init__(
383 work_dir=work_dir, config=config, **kwargs
388 work_dir=work_dir, config=config, **kwargs
384 )
389 )
385 self.stop_data = {}
390 self.stop_data = {}
386
391
387 def start(self, n, profile_dir):
392 def start(self, n, profile_dir):
388 """Start n engines by profile or profile_dir."""
393 """Start n engines by profile or profile_dir."""
389 self.profile_dir = unicode(profile_dir)
394 self.profile_dir = unicode(profile_dir)
390 dlist = []
395 dlist = []
391 for i in range(n):
396 for i in range(n):
392 el = self.launcher_class(work_dir=self.work_dir, config=self.config, log=self.log)
397 el = self.launcher_class(work_dir=self.work_dir, config=self.config, log=self.log)
393 # Copy the engine args over to each engine launcher.
398 # Copy the engine args over to each engine launcher.
394 el.engine_args = copy.deepcopy(self.engine_args)
399 el.engine_args = copy.deepcopy(self.engine_args)
395 el.on_stop(self._notice_engine_stopped)
400 el.on_stop(self._notice_engine_stopped)
396 d = el.start(profile_dir)
401 d = el.start(profile_dir)
397 if i==0:
402 if i==0:
398 self.log.info("Starting LocalEngineSetLauncher: %r" % el.args)
403 self.log.info("Starting LocalEngineSetLauncher: %r" % el.args)
399 self.launchers[i] = el
404 self.launchers[i] = el
400 dlist.append(d)
405 dlist.append(d)
401 self.notify_start(dlist)
406 self.notify_start(dlist)
402 # The consumeErrors here could be dangerous
407 # The consumeErrors here could be dangerous
403 # dfinal = gatherBoth(dlist, consumeErrors=True)
408 # dfinal = gatherBoth(dlist, consumeErrors=True)
404 # dfinal.addCallback(self.notify_start)
409 # dfinal.addCallback(self.notify_start)
405 return dlist
410 return dlist
406
411
407 def find_args(self):
412 def find_args(self):
408 return ['engine set']
413 return ['engine set']
409
414
410 def signal(self, sig):
415 def signal(self, sig):
411 dlist = []
416 dlist = []
412 for el in self.launchers.itervalues():
417 for el in self.launchers.itervalues():
413 d = el.signal(sig)
418 d = el.signal(sig)
414 dlist.append(d)
419 dlist.append(d)
415 # dfinal = gatherBoth(dlist, consumeErrors=True)
420 # dfinal = gatherBoth(dlist, consumeErrors=True)
416 return dlist
421 return dlist
417
422
418 def interrupt_then_kill(self, delay=1.0):
423 def interrupt_then_kill(self, delay=1.0):
419 dlist = []
424 dlist = []
420 for el in self.launchers.itervalues():
425 for el in self.launchers.itervalues():
421 d = el.interrupt_then_kill(delay)
426 d = el.interrupt_then_kill(delay)
422 dlist.append(d)
427 dlist.append(d)
423 # dfinal = gatherBoth(dlist, consumeErrors=True)
428 # dfinal = gatherBoth(dlist, consumeErrors=True)
424 return dlist
429 return dlist
425
430
426 def stop(self):
431 def stop(self):
427 return self.interrupt_then_kill()
432 return self.interrupt_then_kill()
428
433
429 def _notice_engine_stopped(self, data):
434 def _notice_engine_stopped(self, data):
430 pid = data['pid']
435 pid = data['pid']
431 for idx,el in self.launchers.iteritems():
436 for idx,el in self.launchers.iteritems():
432 if el.process.pid == pid:
437 if el.process.pid == pid:
433 break
438 break
434 self.launchers.pop(idx)
439 self.launchers.pop(idx)
435 self.stop_data[idx] = data
440 self.stop_data[idx] = data
436 if not self.launchers:
441 if not self.launchers:
437 self.notify_stop(self.stop_data)
442 self.notify_stop(self.stop_data)
438
443
439
444
440 #-----------------------------------------------------------------------------
445 #-----------------------------------------------------------------------------
441 # MPIExec launchers
446 # MPIExec launchers
442 #-----------------------------------------------------------------------------
447 #-----------------------------------------------------------------------------
443
448
444
449
445 class MPIExecLauncher(LocalProcessLauncher):
450 class MPIExecLauncher(LocalProcessLauncher):
446 """Launch an external process using mpiexec."""
451 """Launch an external process using mpiexec."""
447
452
448 mpi_cmd = List(['mpiexec'], config=True,
453 mpi_cmd = List(['mpiexec'], config=True,
449 help="The mpiexec command to use in starting the process."
454 help="The mpiexec command to use in starting the process."
450 )
455 )
451 mpi_args = List([], config=True,
456 mpi_args = List([], config=True,
452 help="The command line arguments to pass to mpiexec."
457 help="The command line arguments to pass to mpiexec."
453 )
458 )
454 program = List(['date'], config=True,
459 program = List(['date'], config=True,
455 help="The program to start via mpiexec.")
460 help="The program to start via mpiexec.")
456 program_args = List([], config=True,
461 program_args = List([], config=True,
457 help="The command line argument to the program."
462 help="The command line argument to the program."
458 )
463 )
459 n = Int(1)
464 n = Int(1)
460
465
461 def find_args(self):
466 def find_args(self):
462 """Build self.args using all the fields."""
467 """Build self.args using all the fields."""
463 return self.mpi_cmd + ['-n', str(self.n)] + self.mpi_args + \
468 return self.mpi_cmd + ['-n', str(self.n)] + self.mpi_args + \
464 self.program + self.program_args
469 self.program + self.program_args
465
470
466 def start(self, n):
471 def start(self, n):
467 """Start n instances of the program using mpiexec."""
472 """Start n instances of the program using mpiexec."""
468 self.n = n
473 self.n = n
469 return super(MPIExecLauncher, self).start()
474 return super(MPIExecLauncher, self).start()
470
475
471
476
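# Illustrative: with the defaults above, calling start(4) makes find_args()
# produce
#
#     ['mpiexec', '-n', '4', 'date']
#
# i.e. mpi_cmd + ['-n', str(n)] + mpi_args + program + program_args.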
472 class MPIExecControllerLauncher(MPIExecLauncher):
477 class MPIExecControllerLauncher(MPIExecLauncher):
473 """Launch a controller using mpiexec."""
478 """Launch a controller using mpiexec."""
474
479
475 controller_cmd = List(ipcontroller_cmd_argv, config=True,
480 controller_cmd = List(ipcontroller_cmd_argv, config=True,
476 help="Popen command to launch the Contropper"
481 help="Popen command to launch the Contropper"
477 )
482 )
478 controller_args = List(['--log-to-file','log_level=%i'%logging.INFO], config=True,
483 controller_args = List(['--log-to-file','log_level=%i'%logging.INFO], config=True,
479 help="Command line arguments to pass to ipcontroller."
484 help="Command line arguments to pass to ipcontroller."
480 )
485 )
481 n = Int(1)
486 n = Int(1)
482
487
483 def start(self, profile_dir):
488 def start(self, profile_dir):
484 """Start the controller by profile_dir."""
489 """Start the controller by profile_dir."""
485 self.controller_args.extend(['profile_dir=%s'%profile_dir])
490 self.controller_args.extend(['profile_dir=%s'%profile_dir])
486 self.profile_dir = unicode(profile_dir)
491 self.profile_dir = unicode(profile_dir)
487 self.log.info("Starting MPIExecControllerLauncher: %r" % self.args)
492 self.log.info("Starting MPIExecControllerLauncher: %r" % self.args)
488 return super(MPIExecControllerLauncher, self).start(1)
493 return super(MPIExecControllerLauncher, self).start(1)
489
494
490 def find_args(self):
495 def find_args(self):
491 return self.mpi_cmd + ['-n', str(self.n)] + self.mpi_args + \
496 return self.mpi_cmd + ['-n', str(self.n)] + self.mpi_args + \
492 self.controller_cmd + self.controller_args
497 self.controller_cmd + self.controller_args
493
498
494
499
495 class MPIExecEngineSetLauncher(MPIExecLauncher):
500 class MPIExecEngineSetLauncher(MPIExecLauncher):
496
501
497 program = List(ipengine_cmd_argv, config=True,
502 program = List(ipengine_cmd_argv, config=True,
498 help="Popen command for ipengine"
503 help="Popen command for ipengine"
499 )
504 )
500 program_args = List(
505 program_args = List(
501 ['--log-to-file','log_level=%i'%logging.INFO], config=True,
506 ['--log-to-file','log_level=%i'%logging.INFO], config=True,
502 help="Command line arguments for ipengine."
507 help="Command line arguments for ipengine."
503 )
508 )
504 n = Int(1)
509 n = Int(1)
505
510
506 def start(self, n, profile_dir):
511 def start(self, n, profile_dir):
507 """Start n engines by profile or profile_dir."""
512 """Start n engines by profile or profile_dir."""
508 self.program_args.extend(['profile_dir=%s'%profile_dir])
513 self.program_args.extend(['profile_dir=%s'%profile_dir])
509 self.profile_dir = unicode(profile_dir)
514 self.profile_dir = unicode(profile_dir)
510 self.n = n
515 self.n = n
511 self.log.info('Starting MPIExecEngineSetLauncher: %r' % self.args)
516 self.log.info('Starting MPIExecEngineSetLauncher: %r' % self.args)
512 return super(MPIExecEngineSetLauncher, self).start(n)
517 return super(MPIExecEngineSetLauncher, self).start(n)
513
518
514 #-----------------------------------------------------------------------------
519 #-----------------------------------------------------------------------------
515 # SSH launchers
520 # SSH launchers
516 #-----------------------------------------------------------------------------
521 #-----------------------------------------------------------------------------
517
522
518 # TODO: Get SSH Launcher working again.
523 # TODO: Get SSH Launcher working again.
519
524
520 class SSHLauncher(LocalProcessLauncher):
525 class SSHLauncher(LocalProcessLauncher):
521 """A minimal launcher for ssh.
526 """A minimal launcher for ssh.
522
527
523 To be useful this will probably have to be extended to use the ``sshx``
528 To be useful this will probably have to be extended to use the ``sshx``
524 idea for environment variables. There could be other things this needs
529 idea for environment variables. There could be other things this needs
525 as well.
530 as well.
526 """
531 """
527
532
528 ssh_cmd = List(['ssh'], config=True,
533 ssh_cmd = List(['ssh'], config=True,
529 help="command for starting ssh")
534 help="command for starting ssh")
530 ssh_args = List(['-tt'], config=True,
535 ssh_args = List(['-tt'], config=True,
531 help="args to pass to ssh")
536 help="args to pass to ssh")
532 program = List(['date'], config=True,
537 program = List(['date'], config=True,
533 help="Program to launch via ssh")
538 help="Program to launch via ssh")
534 program_args = List([], config=True,
539 program_args = List([], config=True,
535 help="args to pass to remote program")
540 help="args to pass to remote program")
536 hostname = Unicode('', config=True,
541 hostname = Unicode('', config=True,
537 help="hostname on which to launch the program")
542 help="hostname on which to launch the program")
538 user = Unicode('', config=True,
543 user = Unicode('', config=True,
539 help="username for ssh")
544 help="username for ssh")
540 location = Unicode('', config=True,
545 location = Unicode('', config=True,
541 help="user@hostname location for ssh in one setting")
546 help="user@hostname location for ssh in one setting")
542
547
543 def _hostname_changed(self, name, old, new):
548 def _hostname_changed(self, name, old, new):
544 if self.user:
549 if self.user:
545 self.location = u'%s@%s' % (self.user, new)
550 self.location = u'%s@%s' % (self.user, new)
546 else:
551 else:
547 self.location = new
552 self.location = new
548
553
549 def _user_changed(self, name, old, new):
554 def _user_changed(self, name, old, new):
550 self.location = u'%s@%s' % (new, self.hostname)
555 self.location = u'%s@%s' % (new, self.hostname)
551
556
552 def find_args(self):
557 def find_args(self):
553 return self.ssh_cmd + self.ssh_args + [self.location] + \
558 return self.ssh_cmd + self.ssh_args + [self.location] + \
554 self.program + self.program_args
559 self.program + self.program_args
555
560
556 def start(self, profile_dir, hostname=None, user=None):
561 def start(self, profile_dir, hostname=None, user=None):
557 self.profile_dir = unicode(profile_dir)
562 self.profile_dir = unicode(profile_dir)
558 if hostname is not None:
563 if hostname is not None:
559 self.hostname = hostname
564 self.hostname = hostname
560 if user is not None:
565 if user is not None:
561 self.user = user
566 self.user = user
562
567
563 return super(SSHLauncher, self).start()
568 return super(SSHLauncher, self).start()
564
569
565 def signal(self, sig):
570 def signal(self, sig):
566 if self.state == 'running':
571 if self.state == 'running':
567 # send escaped ssh connection-closer
572 # send escaped ssh connection-closer
568 self.process.stdin.write('~.')
573 self.process.stdin.write('~.')
569 self.process.stdin.flush()
574 self.process.stdin.flush()
570
575
571
576
572
577
573 class SSHControllerLauncher(SSHLauncher):
578 class SSHControllerLauncher(SSHLauncher):
574
579
575 program = List(ipcontroller_cmd_argv, config=True,
580 program = List(ipcontroller_cmd_argv, config=True,
576 help="remote ipcontroller command.")
581 help="remote ipcontroller command.")
577 program_args = List(['--reuse-files', '--log-to-file','log_level=%i'%logging.INFO], config=True,
582 program_args = List(['--reuse-files', '--log-to-file','log_level=%i'%logging.INFO], config=True,
578 help="Command line arguments to ipcontroller.")
583 help="Command line arguments to ipcontroller.")
579
584
580
585
581 class SSHEngineLauncher(SSHLauncher):
586 class SSHEngineLauncher(SSHLauncher):
582 program = List(ipengine_cmd_argv, config=True,
587 program = List(ipengine_cmd_argv, config=True,
583 help="remote ipengine command.")
588 help="remote ipengine command.")
584 # Command line arguments for ipengine.
589 # Command line arguments for ipengine.
585 program_args = List(
590 program_args = List(
586 ['--log-to-file','log_level=%i'%logging.INFO], config=True,
591 ['--log-to-file','log_level=%i'%logging.INFO], config=True,
587 help="Command line arguments to ipengine."
592 help="Command line arguments to ipengine."
588 )
593 )
589
594
590 class SSHEngineSetLauncher(LocalEngineSetLauncher):
595 class SSHEngineSetLauncher(LocalEngineSetLauncher):
591 launcher_class = SSHEngineLauncher
596 launcher_class = SSHEngineLauncher
592 engines = Dict(config=True,
597 engines = Dict(config=True,
593 help="""dict of engines to launch. This is a dict by hostname of ints,
598 help="""dict of engines to launch. This is a dict by hostname of ints,
594 corresponding to the number of engines to start on that host.""")
599 corresponding to the number of engines to start on that host.""")
595
600
596 def start(self, n, profile_dir):
601 def start(self, n, profile_dir):
597 """Start engines by profile or profile_dir.
602 """Start engines by profile or profile_dir.
598 `n` is ignored, and the `engines` config property is used instead.
603 `n` is ignored, and the `engines` config property is used instead.
599 """
604 """
600
605
601 self.profile_dir = unicode(profile_dir)
606 self.profile_dir = unicode(profile_dir)
602 dlist = []
607 dlist = []
603 for host, n in self.engines.iteritems():
608 for host, n in self.engines.iteritems():
604 if isinstance(n, (tuple, list)):
609 if isinstance(n, (tuple, list)):
605 n, args = n
610 n, args = n
606 else:
611 else:
607 args = copy.deepcopy(self.engine_args)
612 args = copy.deepcopy(self.engine_args)
608
613
609 if '@' in host:
614 if '@' in host:
610 user,host = host.split('@',1)
615 user,host = host.split('@',1)
611 else:
616 else:
612 user=None
617 user=None
613 for i in range(n):
618 for i in range(n):
614 el = self.launcher_class(work_dir=self.work_dir, config=self.config, log=self.log)
619 el = self.launcher_class(work_dir=self.work_dir, config=self.config, log=self.log)
615
620
616 # Copy the engine args over to each engine launcher.
621 # Copy the engine args over to each engine launcher.
618 el.program_args = args
623 el.program_args = args
619 el.on_stop(self._notice_engine_stopped)
624 el.on_stop(self._notice_engine_stopped)
620 d = el.start(profile_dir, user=user, hostname=host)
625 d = el.start(profile_dir, user=user, hostname=host)
621 if i==0:
626 if i==0:
622 self.log.info("Starting SSHEngineSetLauncher: %r" % el.args)
627 self.log.info("Starting SSHEngineSetLauncher: %r" % el.args)
623 self.launchers[host+str(i)] = el
628 self.launchers[host+str(i)] = el
624 dlist.append(d)
629 dlist.append(d)
625 self.notify_start(dlist)
630 self.notify_start(dlist)
626 return dlist
631 return dlist
627
632
628
633
629
634
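# Illustrative config sketch for SSHEngineSetLauncher.engines (hostnames are
# hypothetical). Values are either an int, or an (n, args) tuple overriding
# engine_args for that host:
#
#     c.SSHEngineSetLauncher.engines = {
#         'node1.example.com'     : 2,
#         'bob@node2.example.com' : (4, ['--log-to-file']),
#     }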
630 #-----------------------------------------------------------------------------
635 #-----------------------------------------------------------------------------
631 # Windows HPC Server 2008 scheduler launchers
636 # Windows HPC Server 2008 scheduler launchers
632 #-----------------------------------------------------------------------------
637 #-----------------------------------------------------------------------------
633
638
634
639
635 # This is only used on Windows.
640 # This is only used on Windows.
636 def find_job_cmd():
641 def find_job_cmd():
637 if WINDOWS:
642 if WINDOWS:
638 try:
643 try:
639 return find_cmd('job')
644 return find_cmd('job')
640 except (FindCmdError, ImportError):
645 except (FindCmdError, ImportError):
641 # ImportError will be raised if win32api is not installed
646 # ImportError will be raised if win32api is not installed
642 return 'job'
647 return 'job'
643 else:
648 else:
644 return 'job'
649 return 'job'
645
650
646
651
647 class WindowsHPCLauncher(BaseLauncher):
652 class WindowsHPCLauncher(BaseLauncher):
648
653
649 job_id_regexp = Unicode(r'\d+', config=True,
654 job_id_regexp = Unicode(r'\d+', config=True,
650 help="""A regular expression used to get the job id from the output of the
655 help="""A regular expression used to get the job id from the output of the
651 submit_command. """
656 submit_command. """
652 )
657 )
653 job_file_name = Unicode(u'ipython_job.xml', config=True,
658 job_file_name = Unicode(u'ipython_job.xml', config=True,
654 help="The filename of the instantiated job script.")
659 help="The filename of the instantiated job script.")
655 # The full path to the instantiated job script. This gets made dynamically
660 # The full path to the instantiated job script. This gets made dynamically
656 # by combining the work_dir with the job_file_name.
661 # by combining the work_dir with the job_file_name.
657 job_file = Unicode(u'')
662 job_file = Unicode(u'')
658 scheduler = Unicode('', config=True,
663 scheduler = Unicode('', config=True,
659 help="The hostname of the scheduler to submit the job to.")
664 help="The hostname of the scheduler to submit the job to.")
660 job_cmd = Unicode(find_job_cmd(), config=True,
665 job_cmd = Unicode(find_job_cmd(), config=True,
661 help="The command for submitting jobs.")
666 help="The command for submitting jobs.")
662
667
663 def __init__(self, work_dir=u'.', config=None, **kwargs):
668 def __init__(self, work_dir=u'.', config=None, **kwargs):
664 super(WindowsHPCLauncher, self).__init__(
669 super(WindowsHPCLauncher, self).__init__(
665 work_dir=work_dir, config=config, **kwargs
670 work_dir=work_dir, config=config, **kwargs
666 )
671 )
667
672
668 @property
673 @property
669 def job_file(self):
674 def job_file(self):
670 return os.path.join(self.work_dir, self.job_file_name)
675 return os.path.join(self.work_dir, self.job_file_name)
671
676
672 def write_job_file(self, n):
677 def write_job_file(self, n):
673 raise NotImplementedError("Implement write_job_file in a subclass.")
678 raise NotImplementedError("Implement write_job_file in a subclass.")
674
679
675 def find_args(self):
680 def find_args(self):
676 return [u'job.exe']
681 return [u'job.exe']
677
682
678 def parse_job_id(self, output):
683 def parse_job_id(self, output):
679 """Take the output of the submit command and return the job id."""
684 """Take the output of the submit command and return the job id."""
680 m = re.search(self.job_id_regexp, output)
685 m = re.search(self.job_id_regexp, output)
681 if m is not None:
686 if m is not None:
682 job_id = m.group()
687 job_id = m.group()
683 else:
688 else:
684 raise LauncherError("Job id couldn't be determined: %s" % output)
689 raise LauncherError("Job id couldn't be determined: %s" % output)
685 self.job_id = job_id
690 self.job_id = job_id
686 self.log.info('Job started with job id: %r' % job_id)
691 self.log.info('Job started with job id: %r' % job_id)
687 return job_id
692 return job_id
688
693
689 def start(self, n):
694 def start(self, n):
690 """Start n copies of the process using the Win HPC job scheduler."""
695 """Start n copies of the process using the Win HPC job scheduler."""
691 self.write_job_file(n)
696 self.write_job_file(n)
692 args = [
697 args = [
693 'submit',
698 'submit',
694 '/jobfile:%s' % self.job_file,
699 '/jobfile:%s' % self.job_file,
695 '/scheduler:%s' % self.scheduler
700 '/scheduler:%s' % self.scheduler
696 ]
701 ]
697 self.log.info("Starting Win HPC Job: %s" % (self.job_cmd + ' ' + ' '.join(args),))
702 self.log.info("Starting Win HPC Job: %s" % (self.job_cmd + ' ' + ' '.join(args),))
698 # Twisted will raise DeprecationWarnings if we try to pass unicode to this
703 # Twisted will raise DeprecationWarnings if we try to pass unicode to this
699 output = check_output([self.job_cmd]+args,
704 output = check_output([self.job_cmd]+args,
700 env=os.environ,
705 env=os.environ,
701 cwd=self.work_dir,
706 cwd=self.work_dir,
702 stderr=STDOUT
707 stderr=STDOUT
703 )
708 )
704 job_id = self.parse_job_id(output)
709 job_id = self.parse_job_id(output)
705 self.notify_start(job_id)
710 self.notify_start(job_id)
706 return job_id
711 return job_id
707
712
708 def stop(self):
713 def stop(self):
709 args = [
714 args = [
710 'cancel',
715 'cancel',
711 self.job_id,
716 self.job_id,
712 '/scheduler:%s' % self.scheduler
717 '/scheduler:%s' % self.scheduler
713 ]
718 ]
714 self.log.info("Stopping Win HPC Job: %s" % (self.job_cmd + ' ' + ' '.join(args),))
719 self.log.info("Stopping Win HPC Job: %s" % (self.job_cmd + ' ' + ' '.join(args),))
715 try:
720 try:
716 output = check_output([self.job_cmd]+args,
721 output = check_output([self.job_cmd]+args,
717 env=os.environ,
722 env=os.environ,
718 cwd=self.work_dir,
723 cwd=self.work_dir,
719 stderr=STDOUT
724 stderr=STDOUT
720 )
725 )
721 except:
726 except:
722 output = 'The job already appears to be stopped: %r' % self.job_id
727 output = 'The job already appears to be stopped: %r' % self.job_id
723 self.notify_stop(dict(job_id=self.job_id, output=output)) # Pass the output of the kill cmd
728 self.notify_stop(dict(job_id=self.job_id, output=output)) # Pass the output of the kill cmd
724 return output
729 return output
725
730
726
731
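# Illustrative: parse_job_id() pulls the first job_id_regexp match out of
# the scheduler's stdout, so with hypothetical output
#
#     launcher.parse_job_id('Job has been submitted. ID: 4096')
#
# it stores and returns '4096'.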
727 class WindowsHPCControllerLauncher(WindowsHPCLauncher):
732 class WindowsHPCControllerLauncher(WindowsHPCLauncher):
728
733
729 job_file_name = Unicode(u'ipcontroller_job.xml', config=True,
734 job_file_name = Unicode(u'ipcontroller_job.xml', config=True,
730 help="WinHPC xml job file.")
735 help="WinHPC xml job file.")
731 extra_args = List([], config=False,
736 extra_args = List([], config=False,
732 help="extra args to pass to ipcontroller")
737 help="extra args to pass to ipcontroller")
733
738
734 def write_job_file(self, n):
739 def write_job_file(self, n):
735 job = IPControllerJob(config=self.config)
740 job = IPControllerJob(config=self.config)
736
741
737 t = IPControllerTask(config=self.config)
742 t = IPControllerTask(config=self.config)
738 # The task's work directory is *not* the actual work directory of
743 # The task's work directory is *not* the actual work directory of
739 # the controller. It is used as the base path for the stdout/stderr
744 # the controller. It is used as the base path for the stdout/stderr
740 # files that the scheduler redirects to.
745 # files that the scheduler redirects to.
741 t.work_directory = self.profile_dir
746 t.work_directory = self.profile_dir
742 # Add the profile_dir argument set in self.start().
747 # Add the profile_dir argument set in self.start().
743 t.controller_args.extend(self.extra_args)
748 t.controller_args.extend(self.extra_args)
744 job.add_task(t)
749 job.add_task(t)
745
750
746 self.log.info("Writing job description file: %s" % self.job_file)
751 self.log.info("Writing job description file: %s" % self.job_file)
747 job.write(self.job_file)
752 job.write(self.job_file)
748
753
749 @property
754 @property
750 def job_file(self):
755 def job_file(self):
751 return os.path.join(self.profile_dir, self.job_file_name)
756 return os.path.join(self.profile_dir, self.job_file_name)
752
757
753 def start(self, profile_dir):
758 def start(self, profile_dir):
754 """Start the controller by profile_dir."""
759 """Start the controller by profile_dir."""
755 self.extra_args = ['profile_dir=%s'%profile_dir]
760 self.extra_args = ['profile_dir=%s'%profile_dir]
756 self.profile_dir = unicode(profile_dir)
761 self.profile_dir = unicode(profile_dir)
757 return super(WindowsHPCControllerLauncher, self).start(1)
762 return super(WindowsHPCControllerLauncher, self).start(1)
758
763
759
764
760 class WindowsHPCEngineSetLauncher(WindowsHPCLauncher):
765 class WindowsHPCEngineSetLauncher(WindowsHPCLauncher):
761
766
762 job_file_name = Unicode(u'ipengineset_job.xml', config=True,
767 job_file_name = Unicode(u'ipengineset_job.xml', config=True,
763 help="jobfile for ipengines job")
768 help="jobfile for ipengines job")
764 extra_args = List([], config=False,
769 extra_args = List([], config=False,
765 help="extra args to pas to ipengine")
770 help="extra args to pas to ipengine")
766
771
767 def write_job_file(self, n):
772 def write_job_file(self, n):
768 job = IPEngineSetJob(config=self.config)
773 job = IPEngineSetJob(config=self.config)
769
774
770 for i in range(n):
775 for i in range(n):
771 t = IPEngineTask(config=self.config)
776 t = IPEngineTask(config=self.config)
772 # The task's work directory is *not* the actual work directory of
777 # The task's work directory is *not* the actual work directory of
773 # the engine. It is used as the base path for the stdout/stderr
778 # the engine. It is used as the base path for the stdout/stderr
774 # files that the scheduler redirects to.
779 # files that the scheduler redirects to.
775 t.work_directory = self.profile_dir
780 t.work_directory = self.profile_dir
776 # Add the profile_dir argument set in self.start().
781 # Add the profile_dir argument set in self.start().
777 t.engine_args.extend(self.extra_args)
782 t.engine_args.extend(self.extra_args)
778 job.add_task(t)
783 job.add_task(t)
779
784
780 self.log.info("Writing job description file: %s" % self.job_file)
785 self.log.info("Writing job description file: %s" % self.job_file)
781 job.write(self.job_file)
786 job.write(self.job_file)
782
787
783 @property
788 @property
784 def job_file(self):
789 def job_file(self):
785 return os.path.join(self.profile_dir, self.job_file_name)
790 return os.path.join(self.profile_dir, self.job_file_name)
786
791
787 def start(self, n, profile_dir):
792 def start(self, n, profile_dir):
788 """Start the controller by profile_dir."""
793 """Start the controller by profile_dir."""
789 self.extra_args = ['profile_dir=%s'%profile_dir]
794 self.extra_args = ['profile_dir=%s'%profile_dir]
790 self.profile_dir = unicode(profile_dir)
795 self.profile_dir = unicode(profile_dir)
791 return super(WindowsHPCEngineSetLauncher, self).start(n)
796 return super(WindowsHPCEngineSetLauncher, self).start(n)
792
797
793
798
794 #-----------------------------------------------------------------------------
799 #-----------------------------------------------------------------------------
795 # Batch (PBS) system launchers
800 # Batch (PBS) system launchers
796 #-----------------------------------------------------------------------------
801 #-----------------------------------------------------------------------------
797
802
798 class BatchSystemLauncher(BaseLauncher):
803 class BatchSystemLauncher(BaseLauncher):
799 """Launch an external process using a batch system.
804 """Launch an external process using a batch system.
800
805
801 This class is designed to work with UNIX batch systems like PBS, LSF,
806 This class is designed to work with UNIX batch systems like PBS, LSF,
802 GridEngine, etc. The overall model is that there are different commands
807 GridEngine, etc. The overall model is that there are different commands
803 like qsub, qdel, etc. that handle the starting and stopping of the process.
808 like qsub, qdel, etc. that handle the starting and stopping of the process.
804
809
805 This class also has the notion of a batch script. The ``batch_template``
810 This class also has the notion of a batch script. The ``batch_template``
806 attribute can be set to a string that is a template for the batch script.
811 attribute can be set to a string that is a template for the batch script.
807 This template is instantiated using string formatting. Thus the template can
812 This template is instantiated using string formatting. Thus the template can
808 use {n} for the number of instances. Subclasses can add additional variables
813 use {n} for the number of instances. Subclasses can add additional variables
809 to the template dict.
814 to the template dict.
810 """
815 """
811
816
812 # Subclasses must fill these in. See PBSEngineSet
817 # Subclasses must fill these in. See PBSEngineSet
813 submit_command = List([''], config=True,
818 submit_command = List([''], config=True,
814 help="The name of the command line program used to submit jobs.")
819 help="The name of the command line program used to submit jobs.")
815 delete_command = List([''], config=True,
820 delete_command = List([''], config=True,
816 help="The name of the command line program used to delete jobs.")
821 help="The name of the command line program used to delete jobs.")
817 job_id_regexp = Unicode('', config=True,
822 job_id_regexp = Unicode('', config=True,
818 help="""A regular expression used to get the job id from the output of the
823 help="""A regular expression used to get the job id from the output of the
819 submit_command.""")
824 submit_command.""")
820 batch_template = Unicode('', config=True,
825 batch_template = Unicode('', config=True,
821 help="The string that is the batch script template itself.")
826 help="The string that is the batch script template itself.")
822 batch_template_file = Unicode(u'', config=True,
827 batch_template_file = Unicode(u'', config=True,
823 help="The file that contains the batch template.")
828 help="The file that contains the batch template.")
824 batch_file_name = Unicode(u'batch_script', config=True,
829 batch_file_name = Unicode(u'batch_script', config=True,
825 help="The filename of the instantiated batch script.")
830 help="The filename of the instantiated batch script.")
826 queue = Unicode(u'', config=True,
831 queue = Unicode(u'', config=True,
827 help="The PBS Queue.")
832 help="The PBS Queue.")
828
833
829 # not configurable, override in subclasses
834 # not configurable, override in subclasses
830 # PBS Job Array regex
835 # PBS Job Array regex
831 job_array_regexp = Unicode('')
836 job_array_regexp = Unicode('')
832 job_array_template = Unicode('')
837 job_array_template = Unicode('')
833 # PBS Queue regex
838 # PBS Queue regex
834 queue_regexp = Unicode('')
839 queue_regexp = Unicode('')
835 queue_template = Unicode('')
840 queue_template = Unicode('')
836 # The default batch template, override in subclasses
841 # The default batch template, override in subclasses
837 default_template = Unicode('')
842 default_template = Unicode('')
838 # The full path to the instantiated batch script.
843 # The full path to the instantiated batch script.
839 batch_file = Unicode(u'')
844 batch_file = Unicode(u'')
840 # the format dict used with batch_template:
845 # the format dict used with batch_template:
841 context = Dict()
846 context = Dict()
842 # the Formatter instance for rendering the templates:
847 # the Formatter instance for rendering the templates:
843 formatter = Instance(EvalFormatter, (), {})
848 formatter = Instance(EvalFormatter, (), {})
844
849
845
850
846 def find_args(self):
851 def find_args(self):
847 return self.submit_command + [self.batch_file]
852 return self.submit_command + [self.batch_file]
848
853
849 def __init__(self, work_dir=u'.', config=None, **kwargs):
854 def __init__(self, work_dir=u'.', config=None, **kwargs):
850 super(BatchSystemLauncher, self).__init__(
855 super(BatchSystemLauncher, self).__init__(
851 work_dir=work_dir, config=config, **kwargs
856 work_dir=work_dir, config=config, **kwargs
852 )
857 )
853 self.batch_file = os.path.join(self.work_dir, self.batch_file_name)
858 self.batch_file = os.path.join(self.work_dir, self.batch_file_name)
854
859
855 def parse_job_id(self, output):
860 def parse_job_id(self, output):
856 """Take the output of the submit command and return the job id."""
861 """Take the output of the submit command and return the job id."""
857 m = re.search(self.job_id_regexp, output)
862 m = re.search(self.job_id_regexp, output)
858 if m is not None:
863 if m is not None:
859 job_id = m.group()
864 job_id = m.group()
860 else:
865 else:
861 raise LauncherError("Job id couldn't be determined: %s" % output)
866 raise LauncherError("Job id couldn't be determined: %s" % output)
862 self.job_id = job_id
867 self.job_id = job_id
863 self.log.info('Job submitted with job id: %r' % job_id)
868 self.log.info('Job submitted with job id: %r' % job_id)
864 return job_id
869 return job_id
865
870
866 def write_batch_script(self, n):
871 def write_batch_script(self, n):
867 """Instantiate and write the batch script to the work_dir."""
872 """Instantiate and write the batch script to the work_dir."""
868 self.context['n'] = n
873 self.context['n'] = n
869 self.context['queue'] = self.queue
874 self.context['queue'] = self.queue
870 # first priority is batch_template if set
875 # first priority is batch_template if set
871 if self.batch_template_file and not self.batch_template:
876 if self.batch_template_file and not self.batch_template:
872 # second priority is batch_template_file
877 # second priority is batch_template_file
873 with open(self.batch_template_file) as f:
878 with open(self.batch_template_file) as f:
874 self.batch_template = f.read()
879 self.batch_template = f.read()
875 if not self.batch_template:
880 if not self.batch_template:
876 # third (last) priority is default_template
881 # third (last) priority is default_template
877 self.batch_template = self.default_template
882 self.batch_template = self.default_template
878
883
879 regex = re.compile(self.job_array_regexp)
884 regex = re.compile(self.job_array_regexp)
880 # print regex.search(self.batch_template)
885 # print regex.search(self.batch_template)
881 if not regex.search(self.batch_template):
886 if not regex.search(self.batch_template):
882 self.log.info("adding job array settings to batch script")
887 self.log.info("adding job array settings to batch script")
883 firstline, rest = self.batch_template.split('\n',1)
888 firstline, rest = self.batch_template.split('\n',1)
884 self.batch_template = u'\n'.join([firstline, self.job_array_template, rest])
889 self.batch_template = u'\n'.join([firstline, self.job_array_template, rest])
885
890
886 regex = re.compile(self.queue_regexp)
891 regex = re.compile(self.queue_regexp)
887 # print regex.search(self.batch_template)
892 # print regex.search(self.batch_template)
888 if self.queue and not regex.search(self.batch_template):
893 if self.queue and not regex.search(self.batch_template):
889 self.log.info("adding PBS queue settings to batch script")
894 self.log.info("adding PBS queue settings to batch script")
890 firstline, rest = self.batch_template.split('\n',1)
895 firstline, rest = self.batch_template.split('\n',1)
891 self.batch_template = u'\n'.join([firstline, self.queue_template, rest])
896 self.batch_template = u'\n'.join([firstline, self.queue_template, rest])
892
897
893 script_as_string = self.formatter.format(self.batch_template, **self.context)
898 script_as_string = self.formatter.format(self.batch_template, **self.context)
894 self.log.info('Writing instantiated batch script: %s' % self.batch_file)
899 self.log.info('Writing instantiated batch script: %s' % self.batch_file)
895
900
896 with open(self.batch_file, 'w') as f:
901 with open(self.batch_file, 'w') as f:
897 f.write(script_as_string)
902 f.write(script_as_string)
898 os.chmod(self.batch_file, stat.S_IRUSR | stat.S_IWUSR | stat.S_IXUSR)
903 os.chmod(self.batch_file, stat.S_IRUSR | stat.S_IWUSR | stat.S_IXUSR)
899
904
900 def start(self, n, profile_dir):
905 def start(self, n, profile_dir):
901 """Start n copies of the process using a batch system."""
906 """Start n copies of the process using a batch system."""
902 # Here we save profile_dir in the context so it
907 # Here we save profile_dir in the context so it
903 # can be used in the batch script template as {profile_dir}
908 # can be used in the batch script template as {profile_dir}
904 self.context['profile_dir'] = profile_dir
909 self.context['profile_dir'] = profile_dir
905 self.profile_dir = unicode(profile_dir)
910 self.profile_dir = unicode(profile_dir)
906 self.write_batch_script(n)
911 self.write_batch_script(n)
907 output = check_output(self.args, env=os.environ)
912 output = check_output(self.args, env=os.environ)
908
913
909 job_id = self.parse_job_id(output)
914 job_id = self.parse_job_id(output)
910 self.notify_start(job_id)
915 self.notify_start(job_id)
911 return job_id
916 return job_id
912
917
913 def stop(self):
918 def stop(self):
914 output = check_output(self.delete_command+[self.job_id], env=os.environ)
919 output = check_output(self.delete_command+[self.job_id], env=os.environ)
915 self.notify_stop(dict(job_id=self.job_id, output=output)) # Pass the output of the kill cmd
920 self.notify_stop(dict(job_id=self.job_id, output=output)) # Pass the output of the kill cmd
916 return output
921 return output
917
922
918
923
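# Illustrative: overriding the batch template from a config object. The
# template is rendered with EvalFormatter, so {n}, {queue} and {profile_dir}
# are filled from self.context (the values themselves are hypothetical):
#
#     c.PBSEngineSetLauncher.batch_template = """#!/bin/sh
#     #PBS -q {queue}
#     #PBS -t 1-{n}
#     ipengine profile_dir={profile_dir}
#     """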
919 class PBSLauncher(BatchSystemLauncher):
924 class PBSLauncher(BatchSystemLauncher):
920 """A BatchSystemLauncher subclass for PBS."""
925 """A BatchSystemLauncher subclass for PBS."""
921
926
922 submit_command = List(['qsub'], config=True,
927 submit_command = List(['qsub'], config=True,
923 help="The PBS submit command ['qsub']")
928 help="The PBS submit command ['qsub']")
924 delete_command = List(['qdel'], config=True,
929 delete_command = List(['qdel'], config=True,
925 help="The PBS delete command ['qsub']")
930 help="The PBS delete command ['qsub']")
926 job_id_regexp = Unicode(r'\d+', config=True,
931 job_id_regexp = Unicode(r'\d+', config=True,
927 help="Regular expresion for identifying the job ID [r'\d+']")
932 help="Regular expresion for identifying the job ID [r'\d+']")
928
933
929 batch_file = Unicode(u'')
934 batch_file = Unicode(u'')
930 job_array_regexp = Unicode('#PBS\W+-t\W+[\w\d\-\$]+')
935 job_array_regexp = Unicode('#PBS\W+-t\W+[\w\d\-\$]+')
931 job_array_template = Unicode('#PBS -t 1-{n}')
936 job_array_template = Unicode('#PBS -t 1-{n}')
932 queue_regexp = Unicode('#PBS\W+-q\W+\$?\w+')
937 queue_regexp = Unicode('#PBS\W+-q\W+\$?\w+')
933 queue_template = Unicode('#PBS -q {queue}')
938 queue_template = Unicode('#PBS -q {queue}')
934
939
935
940
936 class PBSControllerLauncher(PBSLauncher):
941 class PBSControllerLauncher(PBSLauncher):
937 """Launch a controller using PBS."""
942 """Launch a controller using PBS."""
938
943
939 batch_file_name = Unicode(u'pbs_controller', config=True,
944 batch_file_name = Unicode(u'pbs_controller', config=True,
940 help="batch file name for the controller job.")
945 help="batch file name for the controller job.")
941 default_template= Unicode("""#!/bin/sh
946 default_template= Unicode("""#!/bin/sh
942 #PBS -V
947 #PBS -V
943 #PBS -N ipcontroller
948 #PBS -N ipcontroller
944 %s --log-to-file profile_dir={profile_dir}
949 %s --log-to-file profile_dir={profile_dir}
945 """%(' '.join(ipcontroller_cmd_argv)))
950 """%(' '.join(ipcontroller_cmd_argv)))
946
951
947 def start(self, profile_dir):
952 def start(self, profile_dir):
948 """Start the controller by profile or profile_dir."""
953 """Start the controller by profile or profile_dir."""
949 self.log.info("Starting PBSControllerLauncher: %r" % self.args)
954 self.log.info("Starting PBSControllerLauncher: %r" % self.args)
950 return super(PBSControllerLauncher, self).start(1, profile_dir)
955 return super(PBSControllerLauncher, self).start(1, profile_dir)
951
956
952
957
953 class PBSEngineSetLauncher(PBSLauncher):
958 class PBSEngineSetLauncher(PBSLauncher):
954 """Launch Engines using PBS"""
959 """Launch Engines using PBS"""
955 batch_file_name = Unicode(u'pbs_engines', config=True,
960 batch_file_name = Unicode(u'pbs_engines', config=True,
956 help="batch file name for the engine(s) job.")
961 help="batch file name for the engine(s) job.")
957 default_template= Unicode(u"""#!/bin/sh
962 default_template= Unicode(u"""#!/bin/sh
958 #PBS -V
963 #PBS -V
959 #PBS -N ipengine
964 #PBS -N ipengine
960 %s profile_dir={profile_dir}
965 %s profile_dir={profile_dir}
961 """%(' '.join(ipengine_cmd_argv)))
966 """%(' '.join(ipengine_cmd_argv)))
962
967
963 def start(self, n, profile_dir):
968 def start(self, n, profile_dir):
964 """Start n engines by profile or profile_dir."""
969 """Start n engines by profile or profile_dir."""
965 self.log.info('Starting %i engines with PBSEngineSetLauncher: %r' % (n, self.args))
970 self.log.info('Starting %i engines with PBSEngineSetLauncher: %r' % (n, self.args))
966 return super(PBSEngineSetLauncher, self).start(n, profile_dir)
971 return super(PBSEngineSetLauncher, self).start(n, profile_dir)
967
972
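# Illustrative: for start(8, u'/home/me/.ipython/profile_default') with
# queue = 'short' (values hypothetical), write_batch_script() injects the
# queue and job-array lines after the shebang, rendering roughly:
#
#     #!/bin/sh
#     #PBS -q short
#     #PBS -t 1-8
#     #PBS -V
#     #PBS -N ipengine
#     ipengine profile_dir=/home/me/.ipython/profile_default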
968 #SGE is very similar to PBS
973 #SGE is very similar to PBS
969
974
970 class SGELauncher(PBSLauncher):
975 class SGELauncher(PBSLauncher):
971 """Sun GridEngine is a PBS clone with slightly different syntax"""
976 """Sun GridEngine is a PBS clone with slightly different syntax"""
972 job_array_regexp = Unicode('#\$\W+\-t')
977 job_array_regexp = Unicode('#\$\W+\-t')
973 job_array_template = Unicode('#$ -t 1-{n}')
978 job_array_template = Unicode('#$ -t 1-{n}')
974 queue_regexp = Unicode('#\$\W+-q\W+\$?\w+')
979 queue_regexp = Unicode('#\$\W+-q\W+\$?\w+')
975 queue_template = Unicode('#$ -q {queue}')
980 queue_template = Unicode('#$ -q {queue}')
976
981
977 class SGEControllerLauncher(SGELauncher):
982 class SGEControllerLauncher(SGELauncher):
978 """Launch a controller using SGE."""
983 """Launch a controller using SGE."""
979
984
980 batch_file_name = Unicode(u'sge_controller', config=True,
985 batch_file_name = Unicode(u'sge_controller', config=True,
981 help="batch file name for the ipontroller job.")
986 help="batch file name for the ipontroller job.")
982 default_template= Unicode(u"""#$ -V
987 default_template= Unicode(u"""#$ -V
983 #$ -S /bin/sh
988 #$ -S /bin/sh
984 #$ -N ipcontroller
989 #$ -N ipcontroller
985 %s --log-to-file profile_dir={profile_dir}
990 %s --log-to-file profile_dir={profile_dir}
986 """%(' '.join(ipcontroller_cmd_argv)))
991 """%(' '.join(ipcontroller_cmd_argv)))
987
992
988 def start(self, profile_dir):
993 def start(self, profile_dir):
989 """Start the controller by profile or profile_dir."""
994 """Start the controller by profile or profile_dir."""
990 self.log.info("Starting PBSControllerLauncher: %r" % self.args)
995 self.log.info("Starting PBSControllerLauncher: %r" % self.args)
991 return super(SGEControllerLauncher, self).start(1, profile_dir)
996 return super(SGEControllerLauncher, self).start(1, profile_dir)
992
997
993 class SGEEngineSetLauncher(SGELauncher):
998 class SGEEngineSetLauncher(SGELauncher):
994 """Launch Engines with SGE"""
999 """Launch Engines with SGE"""
995 batch_file_name = Unicode(u'sge_engines', config=True,
1000 batch_file_name = Unicode(u'sge_engines', config=True,
996 help="batch file name for the engine(s) job.")
1001 help="batch file name for the engine(s) job.")
997 default_template = Unicode("""#$ -V
1002 default_template = Unicode("""#$ -V
998 #$ -S /bin/sh
1003 #$ -S /bin/sh
999 #$ -N ipengine
1004 #$ -N ipengine
1000 %s profile_dir={profile_dir}
1005 %s profile_dir={profile_dir}
1001 """%(' '.join(ipengine_cmd_argv)))
1006 """%(' '.join(ipengine_cmd_argv)))
1002
1007
1003 def start(self, n, profile_dir):
1008 def start(self, n, profile_dir):
1004 """Start n engines by profile or profile_dir."""
1009 """Start n engines by profile or profile_dir."""
1005 self.log.info('Starting %i engines with SGEEngineSetLauncher: %r' % (n, self.args))
1010 self.log.info('Starting %i engines with SGEEngineSetLauncher: %r' % (n, self.args))
1006 return super(SGEEngineSetLauncher, self).start(n, profile_dir)
1011 return super(SGEEngineSetLauncher, self).start(n, profile_dir)
1007
1012
1008
1013
1009 #-----------------------------------------------------------------------------
1014 #-----------------------------------------------------------------------------
1010 # A launcher for ipcluster itself!
1015 # A launcher for ipcluster itself!
1011 #-----------------------------------------------------------------------------
1016 #-----------------------------------------------------------------------------
1012
1017
1013
1018
1014 class IPClusterLauncher(LocalProcessLauncher):
1019 class IPClusterLauncher(LocalProcessLauncher):
1015 """Launch the ipcluster program in an external process."""
1020 """Launch the ipcluster program in an external process."""
1016
1021
1017 ipcluster_cmd = List(ipcluster_cmd_argv, config=True,
1022 ipcluster_cmd = List(ipcluster_cmd_argv, config=True,
1018 help="Popen command for ipcluster")
1023 help="Popen command for ipcluster")
1019 ipcluster_args = List(
1024 ipcluster_args = List(
1020 ['--clean-logs', '--log-to-file', 'log_level=%i'%logging.INFO], config=True,
1025 ['--clean-logs', '--log-to-file', 'log_level=%i'%logging.INFO], config=True,
1021 help="Command line arguments to pass to ipcluster.")
1026 help="Command line arguments to pass to ipcluster.")
1022 ipcluster_subcommand = Unicode('start')
1027 ipcluster_subcommand = Unicode('start')
1023 ipcluster_n = Int(2)
1028 ipcluster_n = Int(2)
1024
1029
1025 def find_args(self):
1030 def find_args(self):
1026 return self.ipcluster_cmd + ['--'+self.ipcluster_subcommand] + \
1031 return self.ipcluster_cmd + ['--'+self.ipcluster_subcommand] + \
1027 ['n=%i'%self.ipcluster_n] + self.ipcluster_args
1032 ['n=%i'%self.ipcluster_n] + self.ipcluster_args
1028
1033
1029 def start(self):
1034 def start(self):
1030 self.log.info("Starting ipcluster: %r" % self.args)
1035 self.log.info("Starting ipcluster: %r" % self.args)
1031 return super(IPClusterLauncher, self).start()
1036 return super(IPClusterLauncher, self).start()
1032
1037
1033 #-----------------------------------------------------------------------------
1038 #-----------------------------------------------------------------------------
1034 # Collections of launchers
1039 # Collections of launchers
1035 #-----------------------------------------------------------------------------
1040 #-----------------------------------------------------------------------------
1036
1041
1037 local_launchers = [
1042 local_launchers = [
1038 LocalControllerLauncher,
1043 LocalControllerLauncher,
1039 LocalEngineLauncher,
1044 LocalEngineLauncher,
1040 LocalEngineSetLauncher,
1045 LocalEngineSetLauncher,
1041 ]
1046 ]
1042 mpi_launchers = [
1047 mpi_launchers = [
1043 MPIExecLauncher,
1048 MPIExecLauncher,
1044 MPIExecControllerLauncher,
1049 MPIExecControllerLauncher,
1045 MPIExecEngineSetLauncher,
1050 MPIExecEngineSetLauncher,
1046 ]
1051 ]
1047 ssh_launchers = [
1052 ssh_launchers = [
1048 SSHLauncher,
1053 SSHLauncher,
1049 SSHControllerLauncher,
1054 SSHControllerLauncher,
1050 SSHEngineLauncher,
1055 SSHEngineLauncher,
1051 SSHEngineSetLauncher,
1056 SSHEngineSetLauncher,
1052 ]
1057 ]
1053 winhpc_launchers = [
1058 winhpc_launchers = [
1054 WindowsHPCLauncher,
1059 WindowsHPCLauncher,
1055 WindowsHPCControllerLauncher,
1060 WindowsHPCControllerLauncher,
1056 WindowsHPCEngineSetLauncher,
1061 WindowsHPCEngineSetLauncher,
1057 ]
1062 ]
1058 pbs_launchers = [
1063 pbs_launchers = [
1059 PBSLauncher,
1064 PBSLauncher,
1060 PBSControllerLauncher,
1065 PBSControllerLauncher,
1061 PBSEngineSetLauncher,
1066 PBSEngineSetLauncher,
1062 ]
1067 ]
1063 sge_launchers = [
1068 sge_launchers = [
1064 SGELauncher,
1069 SGELauncher,
1065 SGEControllerLauncher,
1070 SGEControllerLauncher,
1066 SGEEngineSetLauncher,
1071 SGEEngineSetLauncher,
1067 ]
1072 ]
1068 all_launchers = local_launchers + mpi_launchers + ssh_launchers + winhpc_launchers\
1073 all_launchers = local_launchers + mpi_launchers + ssh_launchers + winhpc_launchers\
1069 + pbs_launchers + sge_launchers
1074 + pbs_launchers + sge_launchers
@@ -1,108 +1,115 b''
1 #!/usr/bin/env python
1 #!/usr/bin/env python
2 """A simple logger object that consolidates messages incoming from ipcluster processes."""
2 """
3 A simple logger object that consolidates messages incoming from ipcluster processes.
4
5 Authors:
6
7 * MinRK
8
9 """
3
10
4 #-----------------------------------------------------------------------------
11 #-----------------------------------------------------------------------------
5 # Copyright (C) 2011 The IPython Development Team
12 # Copyright (C) 2011 The IPython Development Team
6 #
13 #
7 # Distributed under the terms of the BSD License. The full license is in
14 # Distributed under the terms of the BSD License. The full license is in
8 # the file COPYING, distributed as part of this software.
15 # the file COPYING, distributed as part of this software.
9 #-----------------------------------------------------------------------------
16 #-----------------------------------------------------------------------------
10
17
11 #-----------------------------------------------------------------------------
18 #-----------------------------------------------------------------------------
12 # Imports
19 # Imports
13 #-----------------------------------------------------------------------------
20 #-----------------------------------------------------------------------------
14
21
15
22
16 import logging
23 import logging
17 import sys
24 import sys
18
25
19 import zmq
26 import zmq
20 from zmq.eventloop import ioloop, zmqstream
27 from zmq.eventloop import ioloop, zmqstream
21
28
22 from IPython.config.configurable import LoggingConfigurable
29 from IPython.config.configurable import LoggingConfigurable
23 from IPython.utils.traitlets import Int, Unicode, Instance, List
30 from IPython.utils.traitlets import Int, Unicode, Instance, List
24
31
25 #-----------------------------------------------------------------------------
32 #-----------------------------------------------------------------------------
26 # Classes
33 # Classes
27 #-----------------------------------------------------------------------------
34 #-----------------------------------------------------------------------------
28
35
29
36
30 class LogWatcher(LoggingConfigurable):
37 class LogWatcher(LoggingConfigurable):
31 """A simple class that receives messages on a SUB socket, as published
38 """A simple class that receives messages on a SUB socket, as published
32 by subclasses of `zmq.log.handlers.PUBHandler`, and logs them itself.
39 by subclasses of `zmq.log.handlers.PUBHandler`, and logs them itself.
33
40
34 This can subscribe to multiple topics, but defaults to all topics.
41 This can subscribe to multiple topics, but defaults to all topics.
35 """
42 """
36
43
37 # configurables
44 # configurables
38 topics = List([''], config=True,
45 topics = List([''], config=True,
39 help="The ZMQ topics to subscribe to. Default is to subscribe to all messages")
46 help="The ZMQ topics to subscribe to. Default is to subscribe to all messages")
40 url = Unicode('tcp://127.0.0.1:20202', config=True,
47 url = Unicode('tcp://127.0.0.1:20202', config=True,
41 help="ZMQ url on which to listen for log messages")
48 help="ZMQ url on which to listen for log messages")
42
49
43 # internals
50 # internals
44 stream = Instance('zmq.eventloop.zmqstream.ZMQStream')
51 stream = Instance('zmq.eventloop.zmqstream.ZMQStream')
45
52
46 context = Instance(zmq.Context)
53 context = Instance(zmq.Context)
47 def _context_default(self):
54 def _context_default(self):
48 return zmq.Context.instance()
55 return zmq.Context.instance()
49
56
50 loop = Instance(zmq.eventloop.ioloop.IOLoop)
57 loop = Instance(zmq.eventloop.ioloop.IOLoop)
51 def _loop_default(self):
58 def _loop_default(self):
52 return ioloop.IOLoop.instance()
59 return ioloop.IOLoop.instance()
53
60
54 def __init__(self, **kwargs):
61 def __init__(self, **kwargs):
55 super(LogWatcher, self).__init__(**kwargs)
62 super(LogWatcher, self).__init__(**kwargs)
56 s = self.context.socket(zmq.SUB)
63 s = self.context.socket(zmq.SUB)
57 s.bind(self.url)
64 s.bind(self.url)
58 self.stream = zmqstream.ZMQStream(s, self.loop)
65 self.stream = zmqstream.ZMQStream(s, self.loop)
59 self.subscribe()
66 self.subscribe()
60 self.on_trait_change(self.subscribe, 'topics')
67 self.on_trait_change(self.subscribe, 'topics')
61
68
62 def start(self):
69 def start(self):
63 self.stream.on_recv(self.log_message)
70 self.stream.on_recv(self.log_message)
64
71
65 def stop(self):
72 def stop(self):
66 self.stream.stop_on_recv()
73 self.stream.stop_on_recv()
67
74
68 def subscribe(self):
75 def subscribe(self):
69 """Update our SUB socket's subscriptions."""
76 """Update our SUB socket's subscriptions."""
70 self.stream.setsockopt(zmq.UNSUBSCRIBE, '')
77 self.stream.setsockopt(zmq.UNSUBSCRIBE, '')
71 if '' in self.topics:
78 if '' in self.topics:
72 self.log.debug("Subscribing to: everything")
79 self.log.debug("Subscribing to: everything")
73 self.stream.setsockopt(zmq.SUBSCRIBE, '')
80 self.stream.setsockopt(zmq.SUBSCRIBE, '')
74 else:
81 else:
75 for topic in self.topics:
82 for topic in self.topics:
76 self.log.debug("Subscribing to: %r"%(topic))
83 self.log.debug("Subscribing to: %r"%(topic))
77 self.stream.setsockopt(zmq.SUBSCRIBE, topic)
84 self.stream.setsockopt(zmq.SUBSCRIBE, topic)
78
85
79 def _extract_level(self, topic_str):
86 def _extract_level(self, topic_str):
80 """Turn 'engine.0.INFO.extra' into (logging.INFO, 'engine.0.extra')"""
87 """Turn 'engine.0.INFO.extra' into (logging.INFO, 'engine.0.extra')"""
81 topics = topic_str.split('.')
88 topics = topic_str.split('.')
82 for idx,t in enumerate(topics):
89 for idx,t in enumerate(topics):
83 level = getattr(logging, t, None)
90 level = getattr(logging, t, None)
84 if level is not None:
91 if level is not None:
85 break
92 break
86
93
87 if level is None:
94 if level is None:
88 level = logging.INFO
95 level = logging.INFO
89 else:
96 else:
90 topics.pop(idx)
97 topics.pop(idx)
91
98
92 return level, '.'.join(topics)
99 return level, '.'.join(topics)
93
100
94
101
95 def log_message(self, raw):
102 def log_message(self, raw):
96 """receive and parse a message, then log it."""
103 """receive and parse a message, then log it."""
97 if len(raw) != 2 or '.' not in raw[0]:
104 if len(raw) != 2 or '.' not in raw[0]:
98 self.log.error("Invalid log message: %s"%raw)
105 self.log.error("Invalid log message: %s"%raw)
99 return
106 return
100 else:
107 else:
101 topic, msg = raw
108 topic, msg = raw
102 # don't add a newline; log messages already end with one:
109 # don't add a newline; log messages already end with one:
103 topic,level_name = topic.rsplit('.',1)
110 topic,level_name = topic.rsplit('.',1)
104 level,topic = self._extract_level(topic)
111 level,topic = self._extract_level(topic)
105 if msg[-1] == '\n':
112 if msg[-1] == '\n':
106 msg = msg[:-1]
113 msg = msg[:-1]
107 self.log.log(level, "[%s] %s" % (topic, msg))
114 self.log.log(level, "[%s] %s" % (topic, msg))
108
115
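For context, a minimal sketch of driving the LogWatcher defined above (the import path, port, and topic are assumptions for illustration, not confirmed by this changeset):

    import logging
    from zmq.eventloop import ioloop
    # hypothetical import path for the class in this diff
    from IPython.parallel.apps.logwatcher import LogWatcher

    logging.basicConfig(level=logging.DEBUG)
    # subscribe only to 'engine.*' topics; topics=[''] (the default) means everything
    watcher = LogWatcher(topics=['engine'], url='tcp://127.0.0.1:20202')
    watcher.start()                    # route incoming messages to log_message
    ioloop.IOLoop.instance().start()   # consume and re-log messages until interrupted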
@@ -1,67 +1,73 b''
1 #!/usr/bin/env python
1 #!/usr/bin/env python
2 """Utility for forwarding file read events over a zmq socket.
2 """Utility for forwarding file read events over a zmq socket.
3
3
4 This is necessary because select on Windows only supports sockets, not file descriptors."""
4 This is necessary because select on Windows only supports sockets, not file descriptors.
5
6 Authors:
7
8 * MinRK
9
10 """
5
11
6 #-----------------------------------------------------------------------------
12 #-----------------------------------------------------------------------------
7 # Copyright (C) 2011 The IPython Development Team
13 # Copyright (C) 2011 The IPython Development Team
8 #
14 #
9 # Distributed under the terms of the BSD License. The full license is in
15 # Distributed under the terms of the BSD License. The full license is in
10 # the file COPYING, distributed as part of this software.
16 # the file COPYING, distributed as part of this software.
11 #-----------------------------------------------------------------------------
17 #-----------------------------------------------------------------------------
12
18
13 #-----------------------------------------------------------------------------
19 #-----------------------------------------------------------------------------
14 # Imports
20 # Imports
15 #-----------------------------------------------------------------------------
21 #-----------------------------------------------------------------------------
16
22
17 import uuid
23 import uuid
18 import zmq
24 import zmq
19
25
20 from threading import Thread
26 from threading import Thread
21
27
22 #-----------------------------------------------------------------------------
28 #-----------------------------------------------------------------------------
23 # Code
29 # Code
24 #-----------------------------------------------------------------------------
30 #-----------------------------------------------------------------------------
25
31
26 class ForwarderThread(Thread):
32 class ForwarderThread(Thread):
27 def __init__(self, sock, fd):
33 def __init__(self, sock, fd):
28 Thread.__init__(self)
34 Thread.__init__(self)
29 self.daemon=True
35 self.daemon=True
30 self.sock = sock
36 self.sock = sock
31 self.fd = fd
37 self.fd = fd
32
38
33 def run(self):
39 def run(self):
34 """loop through lines in self.fd, and send them over self.sock"""
40 """loop through lines in self.fd, and send them over self.sock"""
35 line = self.fd.readline()
41 line = self.fd.readline()
36 # allow for files opened in unicode mode
42 # allow for files opened in unicode mode
37 if isinstance(line, unicode):
43 if isinstance(line, unicode):
38 send = self.sock.send_unicode
44 send = self.sock.send_unicode
39 else:
45 else:
40 send = self.sock.send
46 send = self.sock.send
41 while line:
47 while line:
42 send(line)
48 send(line)
43 line = self.fd.readline()
49 line = self.fd.readline()
44 # line == '' means EOF
50 # line == '' means EOF
45 self.fd.close()
51 self.fd.close()
46 self.sock.close()
52 self.sock.close()
47
53
48 def forward_read_events(fd, context=None):
54 def forward_read_events(fd, context=None):
49 """forward read events from an FD over a socket.
55 """forward read events from an FD over a socket.
50
56
51 This method wraps a file in a socket pair, so it can
57 This method wraps a file in a socket pair, so it can
52 be polled for read events by select (specifically zmq.eventloop.ioloop)
58 be polled for read events by select (specifically zmq.eventloop.ioloop)
53 """
59 """
54 if context is None:
60 if context is None:
55 context = zmq.Context.instance()
61 context = zmq.Context.instance()
56 push = context.socket(zmq.PUSH)
62 push = context.socket(zmq.PUSH)
57 push.setsockopt(zmq.LINGER, -1)
63 push.setsockopt(zmq.LINGER, -1)
58 pull = context.socket(zmq.PULL)
64 pull = context.socket(zmq.PULL)
59 addr='inproc://%s'%uuid.uuid4()
65 addr='inproc://%s'%uuid.uuid4()
60 push.bind(addr)
66 push.bind(addr)
61 pull.connect(addr)
67 pull.connect(addr)
62 forwarder = ForwarderThread(push, fd)
68 forwarder = ForwarderThread(push, fd)
63 forwarder.start()
69 forwarder.start()
64 return pull
70 return pull
65
71
66
72
67 __all__ = ['forward_read_events'] No newline at end of file
73 __all__ = ['forward_read_events']
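As a usage sketch for forward_read_events, here is one way to poll a child process's stdout from the ioloop (the subprocess command and callback are illustrative assumptions; forward_read_events is assumed to be in scope from the module above):

    import subprocess
    from zmq.eventloop import ioloop, zmqstream

    # On Windows, select() only handles sockets, so wrap the pipe's read end
    # in the inproc PUSH/PULL pair built by forward_read_events.
    proc = subprocess.Popen(['ping', '127.0.0.1'], stdout=subprocess.PIPE)
    pull = forward_read_events(proc.stdout)

    loop = ioloop.IOLoop.instance()
    stream = zmqstream.ZMQStream(pull, loop)

    def on_line(msg):
        # each message carries one line read from the child's stdout
        print(msg[0])

    stream.on_recv(on_line)
    loop.start()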
@@ -1,314 +1,320 b''
1 #!/usr/bin/env python
1 #!/usr/bin/env python
2 # encoding: utf-8
2 # encoding: utf-8
3 """
3 """
4 Job and task components for writing .xml files that the Windows HPC Server
4 Job and task components for writing .xml files that the Windows HPC Server
5 2008 can use to start jobs.
5 2008 can use to start jobs.
6
7 Authors:
8
9 * Brian Granger
10 * MinRK
11
6 """
12 """
7
13
8 #-----------------------------------------------------------------------------
14 #-----------------------------------------------------------------------------
9 # Copyright (C) 2008-2009 The IPython Development Team
15 # Copyright (C) 2008-2011 The IPython Development Team
10 #
16 #
11 # Distributed under the terms of the BSD License. The full license is in
17 # Distributed under the terms of the BSD License. The full license is in
12 # the file COPYING, distributed as part of this software.
18 # the file COPYING, distributed as part of this software.
13 #-----------------------------------------------------------------------------
19 #-----------------------------------------------------------------------------
14
20
15 #-----------------------------------------------------------------------------
21 #-----------------------------------------------------------------------------
16 # Imports
22 # Imports
17 #-----------------------------------------------------------------------------
23 #-----------------------------------------------------------------------------
18
24
19 import os
25 import os
20 import re
26 import re
21 import uuid
27 import uuid
22
28
23 from xml.etree import ElementTree as ET
29 from xml.etree import ElementTree as ET
24
30
25 from IPython.config.configurable import Configurable
31 from IPython.config.configurable import Configurable
26 from IPython.utils.traitlets import (
32 from IPython.utils.traitlets import (
27 Unicode, Int, List, Instance,
33 Unicode, Int, List, Instance,
28 Enum, Bool
34 Enum, Bool
29 )
35 )
30
36
31 #-----------------------------------------------------------------------------
37 #-----------------------------------------------------------------------------
32 # Job and Task classes
38 # Job and Task classes
33 #-----------------------------------------------------------------------------
39 #-----------------------------------------------------------------------------
34
40
35
41
36 def as_str(value):
42 def as_str(value):
37 if isinstance(value, str):
43 if isinstance(value, str):
38 return value
44 return value
39 elif isinstance(value, bool):
45 elif isinstance(value, bool):
40 if value:
46 if value:
41 return 'true'
47 return 'true'
42 else:
48 else:
43 return 'false'
49 return 'false'
44 elif isinstance(value, (int, float)):
50 elif isinstance(value, (int, float)):
45 return repr(value)
51 return repr(value)
46 else:
52 else:
47 return value
53 return value
48
54
49
55
50 def indent(elem, level=0):
56 def indent(elem, level=0):
51 i = "\n" + level*" "
57 i = "\n" + level*" "
52 if len(elem):
58 if len(elem):
53 if not elem.text or not elem.text.strip():
59 if not elem.text or not elem.text.strip():
54 elem.text = i + " "
60 elem.text = i + " "
55 if not elem.tail or not elem.tail.strip():
61 if not elem.tail or not elem.tail.strip():
56 elem.tail = i
62 elem.tail = i
57 for elem in elem:
63 for elem in elem:
58 indent(elem, level+1)
64 indent(elem, level+1)
59 if not elem.tail or not elem.tail.strip():
65 if not elem.tail or not elem.tail.strip():
60 elem.tail = i
66 elem.tail = i
61 else:
67 else:
62 if level and (not elem.tail or not elem.tail.strip()):
68 if level and (not elem.tail or not elem.tail.strip()):
63 elem.tail = i
69 elem.tail = i
64
70
65
71
66 def find_username():
72 def find_username():
67 domain = os.environ.get('USERDOMAIN')
73 domain = os.environ.get('USERDOMAIN')
68 username = os.environ.get('USERNAME','')
74 username = os.environ.get('USERNAME','')
69 if domain is None:
75 if domain is None:
70 return username
76 return username
71 else:
77 else:
72 return '%s\\%s' % (domain, username)
78 return '%s\\%s' % (domain, username)
73
79
74
80
75 class WinHPCJob(Configurable):
81 class WinHPCJob(Configurable):
76
82
77 job_id = Unicode('')
83 job_id = Unicode('')
78 job_name = Unicode('MyJob', config=True)
84 job_name = Unicode('MyJob', config=True)
79 min_cores = Int(1, config=True)
85 min_cores = Int(1, config=True)
80 max_cores = Int(1, config=True)
86 max_cores = Int(1, config=True)
81 min_sockets = Int(1, config=True)
87 min_sockets = Int(1, config=True)
82 max_sockets = Int(1, config=True)
88 max_sockets = Int(1, config=True)
83 min_nodes = Int(1, config=True)
89 min_nodes = Int(1, config=True)
84 max_nodes = Int(1, config=True)
90 max_nodes = Int(1, config=True)
85 unit_type = Unicode("Core", config=True)
91 unit_type = Unicode("Core", config=True)
86 auto_calculate_min = Bool(True, config=True)
92 auto_calculate_min = Bool(True, config=True)
87 auto_calculate_max = Bool(True, config=True)
93 auto_calculate_max = Bool(True, config=True)
88 run_until_canceled = Bool(False, config=True)
94 run_until_canceled = Bool(False, config=True)
89 is_exclusive = Bool(False, config=True)
95 is_exclusive = Bool(False, config=True)
90 username = Unicode(find_username(), config=True)
96 username = Unicode(find_username(), config=True)
91 job_type = Unicode('Batch', config=True)
97 job_type = Unicode('Batch', config=True)
92 priority = Enum(('Lowest','BelowNormal','Normal','AboveNormal','Highest'),
98 priority = Enum(('Lowest','BelowNormal','Normal','AboveNormal','Highest'),
93 default_value='Highest', config=True)
99 default_value='Highest', config=True)
94 requested_nodes = Unicode('', config=True)
100 requested_nodes = Unicode('', config=True)
95 project = Unicode('IPython', config=True)
101 project = Unicode('IPython', config=True)
96 xmlns = Unicode('http://schemas.microsoft.com/HPCS2008/scheduler/')
102 xmlns = Unicode('http://schemas.microsoft.com/HPCS2008/scheduler/')
97 version = Unicode("2.000")
103 version = Unicode("2.000")
98 tasks = List([])
104 tasks = List([])
99
105
100 @property
106 @property
101 def owner(self):
107 def owner(self):
102 return self.username
108 return self.username
103
109
104 def _write_attr(self, root, attr, key):
110 def _write_attr(self, root, attr, key):
105 s = as_str(getattr(self, attr, ''))
111 s = as_str(getattr(self, attr, ''))
106 if s:
112 if s:
107 root.set(key, s)
113 root.set(key, s)
108
114
109 def as_element(self):
115 def as_element(self):
110 # We have to add the _A_-style prefixes to get the attribute order
116 # We have to add the _A_-style prefixes to get the attribute order
111 # that the MSFT XML parser expects.
117 # that the MSFT XML parser expects.
112 root = ET.Element('Job')
118 root = ET.Element('Job')
113 self._write_attr(root, 'version', '_A_Version')
119 self._write_attr(root, 'version', '_A_Version')
114 self._write_attr(root, 'job_name', '_B_Name')
120 self._write_attr(root, 'job_name', '_B_Name')
115 self._write_attr(root, 'unit_type', '_C_UnitType')
121 self._write_attr(root, 'unit_type', '_C_UnitType')
116 self._write_attr(root, 'min_cores', '_D_MinCores')
122 self._write_attr(root, 'min_cores', '_D_MinCores')
117 self._write_attr(root, 'max_cores', '_E_MaxCores')
123 self._write_attr(root, 'max_cores', '_E_MaxCores')
118 self._write_attr(root, 'min_sockets', '_F_MinSockets')
124 self._write_attr(root, 'min_sockets', '_F_MinSockets')
119 self._write_attr(root, 'max_sockets', '_G_MaxSockets')
125 self._write_attr(root, 'max_sockets', '_G_MaxSockets')
120 self._write_attr(root, 'min_nodes', '_H_MinNodes')
126 self._write_attr(root, 'min_nodes', '_H_MinNodes')
121 self._write_attr(root, 'max_nodes', '_I_MaxNodes')
127 self._write_attr(root, 'max_nodes', '_I_MaxNodes')
122 self._write_attr(root, 'run_until_canceled', '_J_RunUntilCanceled')
128 self._write_attr(root, 'run_until_canceled', '_J_RunUntilCanceled')
123 self._write_attr(root, 'is_exclusive', '_K_IsExclusive')
129 self._write_attr(root, 'is_exclusive', '_K_IsExclusive')
124 self._write_attr(root, 'username', '_L_UserName')
130 self._write_attr(root, 'username', '_L_UserName')
125 self._write_attr(root, 'job_type', '_M_JobType')
131 self._write_attr(root, 'job_type', '_M_JobType')
126 self._write_attr(root, 'priority', '_N_Priority')
132 self._write_attr(root, 'priority', '_N_Priority')
127 self._write_attr(root, 'requested_nodes', '_O_RequestedNodes')
133 self._write_attr(root, 'requested_nodes', '_O_RequestedNodes')
128 self._write_attr(root, 'auto_calculate_max', '_P_AutoCalculateMax')
134 self._write_attr(root, 'auto_calculate_max', '_P_AutoCalculateMax')
129 self._write_attr(root, 'auto_calculate_min', '_Q_AutoCalculateMin')
135 self._write_attr(root, 'auto_calculate_min', '_Q_AutoCalculateMin')
130 self._write_attr(root, 'project', '_R_Project')
136 self._write_attr(root, 'project', '_R_Project')
131 self._write_attr(root, 'owner', '_S_Owner')
137 self._write_attr(root, 'owner', '_S_Owner')
132 self._write_attr(root, 'xmlns', '_T_xmlns')
138 self._write_attr(root, 'xmlns', '_T_xmlns')
133 dependencies = ET.SubElement(root, "Dependencies")
139 dependencies = ET.SubElement(root, "Dependencies")
134 etasks = ET.SubElement(root, "Tasks")
140 etasks = ET.SubElement(root, "Tasks")
135 for t in self.tasks:
141 for t in self.tasks:
136 etasks.append(t.as_element())
142 etasks.append(t.as_element())
137 return root
143 return root
138
144
139 def tostring(self):
145 def tostring(self):
140 """Return the string representation of the job description XML."""
146 """Return the string representation of the job description XML."""
141 root = self.as_element()
147 root = self.as_element()
142 indent(root)
148 indent(root)
143 txt = ET.tostring(root, encoding="utf-8")
149 txt = ET.tostring(root, encoding="utf-8")
144 # Now remove the tokens used to order the attributes.
150 # Now remove the tokens used to order the attributes.
145 txt = re.sub(r'_[A-Z]_','',txt)
151 txt = re.sub(r'_[A-Z]_','',txt)
146 txt = '<?xml version="1.0" encoding="utf-8"?>\n' + txt
152 txt = '<?xml version="1.0" encoding="utf-8"?>\n' + txt
147 return txt
153 return txt
148
154
149 def write(self, filename):
155 def write(self, filename):
150 """Write the XML job description to a file."""
156 """Write the XML job description to a file."""
151 txt = self.tostring()
157 txt = self.tostring()
152 with open(filename, 'w') as f:
158 with open(filename, 'w') as f:
153 f.write(txt)
159 f.write(txt)
154
160
155 def add_task(self, task):
161 def add_task(self, task):
156 """Add a task to the job.
162 """Add a task to the job.
157
163
158 Parameters
164 Parameters
159 ----------
165 ----------
160 task : :class:`WinHPCTask`
166 task : :class:`WinHPCTask`
161 The task object to add.
167 The task object to add.
162 """
168 """
163 self.tasks.append(task)
169 self.tasks.append(task)
164
170
165
171
166 class WinHPCTask(Configurable):
172 class WinHPCTask(Configurable):
167
173
168 task_id = Unicode('')
174 task_id = Unicode('')
169 task_name = Unicode('')
175 task_name = Unicode('')
170 version = Unicode("2.000")
176 version = Unicode("2.000")
171 min_cores = Int(1, config=True)
177 min_cores = Int(1, config=True)
172 max_cores = Int(1, config=True)
178 max_cores = Int(1, config=True)
173 min_sockets = Int(1, config=True)
179 min_sockets = Int(1, config=True)
174 max_sockets = Int(1, config=True)
180 max_sockets = Int(1, config=True)
175 min_nodes = Int(1, config=True)
181 min_nodes = Int(1, config=True)
176 max_nodes = Int(1, config=True)
182 max_nodes = Int(1, config=True)
177 unit_type = Unicode("Core", config=True)
183 unit_type = Unicode("Core", config=True)
178 command_line = Unicode('', config=True)
184 command_line = Unicode('', config=True)
179 work_directory = Unicode('', config=True)
185 work_directory = Unicode('', config=True)
180 is_rerunnable = Bool(True, config=True)
186 is_rerunnable = Bool(True, config=True)
181 std_out_file_path = Unicode('', config=True)
187 std_out_file_path = Unicode('', config=True)
182 std_err_file_path = Unicode('', config=True)
188 std_err_file_path = Unicode('', config=True)
183 is_parametric = Bool(False, config=True)
189 is_parametric = Bool(False, config=True)
184 environment_variables = Instance(dict, args=(), config=True)
190 environment_variables = Instance(dict, args=(), config=True)
185
191
186 def _write_attr(self, root, attr, key):
192 def _write_attr(self, root, attr, key):
187 s = as_str(getattr(self, attr, ''))
193 s = as_str(getattr(self, attr, ''))
188 if s:
194 if s:
189 root.set(key, s)
195 root.set(key, s)
190
196
191 def as_element(self):
197 def as_element(self):
192 root = ET.Element('Task')
198 root = ET.Element('Task')
193 self._write_attr(root, 'version', '_A_Version')
199 self._write_attr(root, 'version', '_A_Version')
194 self._write_attr(root, 'task_name', '_B_Name')
200 self._write_attr(root, 'task_name', '_B_Name')
195 self._write_attr(root, 'min_cores', '_C_MinCores')
201 self._write_attr(root, 'min_cores', '_C_MinCores')
196 self._write_attr(root, 'max_cores', '_D_MaxCores')
202 self._write_attr(root, 'max_cores', '_D_MaxCores')
197 self._write_attr(root, 'min_sockets', '_E_MinSockets')
203 self._write_attr(root, 'min_sockets', '_E_MinSockets')
198 self._write_attr(root, 'max_sockets', '_F_MaxSockets')
204 self._write_attr(root, 'max_sockets', '_F_MaxSockets')
199 self._write_attr(root, 'min_nodes', '_G_MinNodes')
205 self._write_attr(root, 'min_nodes', '_G_MinNodes')
200 self._write_attr(root, 'max_nodes', '_H_MaxNodes')
206 self._write_attr(root, 'max_nodes', '_H_MaxNodes')
201 self._write_attr(root, 'command_line', '_I_CommandLine')
207 self._write_attr(root, 'command_line', '_I_CommandLine')
202 self._write_attr(root, 'work_directory', '_J_WorkDirectory')
208 self._write_attr(root, 'work_directory', '_J_WorkDirectory')
203 self._write_attr(root, 'is_rerunnable', '_K_IsRerunnable')
209 self._write_attr(root, 'is_rerunnable', '_K_IsRerunnable')
204 self._write_attr(root, 'std_out_file_path', '_L_StdOutFilePath')
210 self._write_attr(root, 'std_out_file_path', '_L_StdOutFilePath')
205 self._write_attr(root, 'std_err_file_path', '_M_StdErrFilePath')
211 self._write_attr(root, 'std_err_file_path', '_M_StdErrFilePath')
206 self._write_attr(root, 'is_parametric', '_N_IsParametric')
212 self._write_attr(root, 'is_parametric', '_N_IsParametric')
207 self._write_attr(root, 'unit_type', '_O_UnitType')
213 self._write_attr(root, 'unit_type', '_O_UnitType')
208 root.append(self.get_env_vars())
214 root.append(self.get_env_vars())
209 return root
215 return root
210
216
211 def get_env_vars(self):
217 def get_env_vars(self):
212 env_vars = ET.Element('EnvironmentVariables')
218 env_vars = ET.Element('EnvironmentVariables')
213 for k, v in self.environment_variables.iteritems():
219 for k, v in self.environment_variables.iteritems():
214 variable = ET.SubElement(env_vars, "Variable")
220 variable = ET.SubElement(env_vars, "Variable")
215 name = ET.SubElement(variable, "Name")
221 name = ET.SubElement(variable, "Name")
216 name.text = k
222 name.text = k
217 value = ET.SubElement(variable, "Value")
223 value = ET.SubElement(variable, "Value")
218 value.text = v
224 value.text = v
219 return env_vars
225 return env_vars
220
226
221
227
222
228
223 # By declaring these, we can configure the controller and engine separately!
229 # By declaring these, we can configure the controller and engine separately!
224
230
225 class IPControllerJob(WinHPCJob):
231 class IPControllerJob(WinHPCJob):
226 job_name = Unicode('IPController', config=False)
232 job_name = Unicode('IPController', config=False)
227 is_exclusive = Bool(False, config=True)
233 is_exclusive = Bool(False, config=True)
228 username = Unicode(find_username(), config=True)
234 username = Unicode(find_username(), config=True)
229 priority = Enum(('Lowest','BelowNormal','Normal','AboveNormal','Highest'),
235 priority = Enum(('Lowest','BelowNormal','Normal','AboveNormal','Highest'),
230 default_value='Highest', config=True)
236 default_value='Highest', config=True)
231 requested_nodes = Unicode('', config=True)
237 requested_nodes = Unicode('', config=True)
232 project = Unicode('IPython', config=True)
238 project = Unicode('IPython', config=True)
233
239
234
240
235 class IPEngineSetJob(WinHPCJob):
241 class IPEngineSetJob(WinHPCJob):
236 job_name = Unicode('IPEngineSet', config=False)
242 job_name = Unicode('IPEngineSet', config=False)
237 is_exclusive = Bool(False, config=True)
243 is_exclusive = Bool(False, config=True)
238 username = Unicode(find_username(), config=True)
244 username = Unicode(find_username(), config=True)
239 priority = Enum(('Lowest','BelowNormal','Normal','AboveNormal','Highest'),
245 priority = Enum(('Lowest','BelowNormal','Normal','AboveNormal','Highest'),
240 default_value='Highest', config=True)
246 default_value='Highest', config=True)
241 requested_nodes = Unicode('', config=True)
247 requested_nodes = Unicode('', config=True)
242 project = Unicode('IPython', config=True)
248 project = Unicode('IPython', config=True)
243
249
244
250
245 class IPControllerTask(WinHPCTask):
251 class IPControllerTask(WinHPCTask):
246
252
247 task_name = Unicode('IPController', config=True)
253 task_name = Unicode('IPController', config=True)
248 controller_cmd = List(['ipcontroller.exe'], config=True)
254 controller_cmd = List(['ipcontroller.exe'], config=True)
249 controller_args = List(['--log-to-file', '--log-level', '40'], config=True)
255 controller_args = List(['--log-to-file', '--log-level', '40'], config=True)
250 # I don't want these to be configurable
256 # I don't want these to be configurable
251 std_out_file_path = Unicode('', config=False)
257 std_out_file_path = Unicode('', config=False)
252 std_err_file_path = Unicode('', config=False)
258 std_err_file_path = Unicode('', config=False)
253 min_cores = Int(1, config=False)
259 min_cores = Int(1, config=False)
254 max_cores = Int(1, config=False)
260 max_cores = Int(1, config=False)
255 min_sockets = Int(1, config=False)
261 min_sockets = Int(1, config=False)
256 max_sockets = Int(1, config=False)
262 max_sockets = Int(1, config=False)
257 min_nodes = Int(1, config=False)
263 min_nodes = Int(1, config=False)
258 max_nodes = Int(1, config=False)
264 max_nodes = Int(1, config=False)
259 unit_type = Unicode("Core", config=False)
265 unit_type = Unicode("Core", config=False)
260 work_directory = Unicode('', config=False)
266 work_directory = Unicode('', config=False)
261
267
262 def __init__(self, config=None):
268 def __init__(self, config=None):
263 super(IPControllerTask, self).__init__(config=config)
269 super(IPControllerTask, self).__init__(config=config)
264 the_uuid = uuid.uuid1()
270 the_uuid = uuid.uuid1()
265 self.std_out_file_path = os.path.join('log','ipcontroller-%s.out' % the_uuid)
271 self.std_out_file_path = os.path.join('log','ipcontroller-%s.out' % the_uuid)
266 self.std_err_file_path = os.path.join('log','ipcontroller-%s.err' % the_uuid)
272 self.std_err_file_path = os.path.join('log','ipcontroller-%s.err' % the_uuid)
267
273
268 @property
274 @property
269 def command_line(self):
275 def command_line(self):
270 return ' '.join(self.controller_cmd + self.controller_args)
276 return ' '.join(self.controller_cmd + self.controller_args)
271
277
272
278
273 class IPEngineTask(WinHPCTask):
279 class IPEngineTask(WinHPCTask):
274
280
275 task_name = Unicode('IPEngine', config=True)
281 task_name = Unicode('IPEngine', config=True)
276 engine_cmd = List(['ipengine.exe'], config=True)
282 engine_cmd = List(['ipengine.exe'], config=True)
277 engine_args = List(['--log-to-file', '--log-level', '40'], config=True)
283 engine_args = List(['--log-to-file', '--log-level', '40'], config=True)
278 # I don't want these to be configurable
284 # I don't want these to be configurable
279 std_out_file_path = Unicode('', config=False)
285 std_out_file_path = Unicode('', config=False)
280 std_err_file_path = Unicode('', config=False)
286 std_err_file_path = Unicode('', config=False)
281 min_cores = Int(1, config=False)
287 min_cores = Int(1, config=False)
282 max_cores = Int(1, config=False)
288 max_cores = Int(1, config=False)
283 min_sockets = Int(1, config=False)
289 min_sockets = Int(1, config=False)
284 max_sockets = Int(1, config=False)
290 max_sockets = Int(1, config=False)
285 min_nodes = Int(1, config=False)
291 min_nodes = Int(1, config=False)
286 max_nodes = Int(1, config=False)
292 max_nodes = Int(1, config=False)
287 unit_type = Unicode("Core", config=False)
293 unit_type = Unicode("Core", config=False)
288 work_directory = Unicode('', config=False)
294 work_directory = Unicode('', config=False)
289
295
290 def __init__(self, config=None):
296 def __init__(self, config=None):
291 super(IPEngineTask,self).__init__(config=config)
297 super(IPEngineTask,self).__init__(config=config)
292 the_uuid = uuid.uuid1()
298 the_uuid = uuid.uuid1()
293 self.std_out_file_path = os.path.join('log','ipengine-%s.out' % the_uuid)
299 self.std_out_file_path = os.path.join('log','ipengine-%s.out' % the_uuid)
294 self.std_err_file_path = os.path.join('log','ipengine-%s.err' % the_uuid)
300 self.std_err_file_path = os.path.join('log','ipengine-%s.err' % the_uuid)
295
301
296 @property
302 @property
297 def command_line(self):
303 def command_line(self):
298 return ' '.join(self.engine_cmd + self.engine_args)
304 return ' '.join(self.engine_cmd + self.engine_args)
299
305
300
306
301 # j = WinHPCJob(None)
307 # j = WinHPCJob(None)
302 # j.job_name = 'IPCluster'
308 # j.job_name = 'IPCluster'
303 # j.username = 'GNET\\bgranger'
309 # j.username = 'GNET\\bgranger'
304 # j.requested_nodes = 'GREEN'
310 # j.requested_nodes = 'GREEN'
305 #
311 #
306 # t = WinHPCTask(None)
312 # t = WinHPCTask(None)
307 # t.task_name = 'Controller'
313 # t.task_name = 'Controller'
308 # t.command_line = r"\\blue\domainusers$\bgranger\Python\Python25\Scripts\ipcontroller.exe --log-to-file -p default --log-level 10"
314 # t.command_line = r"\\blue\domainusers$\bgranger\Python\Python25\Scripts\ipcontroller.exe --log-to-file -p default --log-level 10"
309 # t.work_directory = r"\\blue\domainusers$\bgranger\.ipython\cluster_default"
315 # t.work_directory = r"\\blue\domainusers$\bgranger\.ipython\cluster_default"
310 # t.std_out_file_path = 'controller-out.txt'
316 # t.std_out_file_path = 'controller-out.txt'
311 # t.std_err_file_path = 'controller-err.txt'
317 # t.std_err_file_path = 'controller-err.txt'
312 # t.environment_variables['PYTHONPATH'] = r"\\blue\domainusers$\bgranger\Python\Python25\Lib\site-packages"
318 # t.environment_variables['PYTHONPATH'] = r"\\blue\domainusers$\bgranger\Python\Python25\Lib\site-packages"
313 # j.add_task(t)
319 # j.add_task(t)
314
320
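The _A_/_B_ attribute keys above are ordering tokens: ElementTree of this era serializes attributes sorted by name, and tostring() strips the tokens with a regex afterwards, leaving the attributes in the order the MSFT scheduler expects. A self-contained sketch of the trick (illustrative values; on Python 3.8+ ElementTree preserves insertion order, so the tokens would be unnecessary there):

    import re
    from xml.etree import ElementTree as ET

    root = ET.Element('Job')
    root.set('_B_Name', 'MyJob')      # set out of order on purpose
    root.set('_A_Version', '2.000')

    txt = ET.tostring(root)           # sorted keys put _A_Version before _B_Name
    if isinstance(txt, bytes):        # tostring() returns bytes on Python 3
        txt = txt.decode('utf-8')
    print(re.sub(r'_[A-Z]_', '', txt))   # <Job Version="2.000" Name="MyJob" />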
@@ -1,340 +1,345 b''
1 """AsyncResult objects for the client"""
1 """AsyncResult objects for the client
2
3 Authors:
4
5 * MinRK
6 """
2 #-----------------------------------------------------------------------------
7 #-----------------------------------------------------------------------------
3 # Copyright (C) 2010-2011 The IPython Development Team
8 # Copyright (C) 2010-2011 The IPython Development Team
4 #
9 #
5 # Distributed under the terms of the BSD License. The full license is in
10 # Distributed under the terms of the BSD License. The full license is in
6 # the file COPYING, distributed as part of this software.
11 # the file COPYING, distributed as part of this software.
7 #-----------------------------------------------------------------------------
12 #-----------------------------------------------------------------------------
8
13
9 #-----------------------------------------------------------------------------
14 #-----------------------------------------------------------------------------
10 # Imports
15 # Imports
11 #-----------------------------------------------------------------------------
16 #-----------------------------------------------------------------------------
12
17
13 import time
18 import time
14
19
15 from zmq import MessageTracker
20 from zmq import MessageTracker
16
21
17 from IPython.external.decorator import decorator
22 from IPython.external.decorator import decorator
18 from IPython.parallel import error
23 from IPython.parallel import error
19
24
20 #-----------------------------------------------------------------------------
25 #-----------------------------------------------------------------------------
21 # Classes
26 # Classes
22 #-----------------------------------------------------------------------------
27 #-----------------------------------------------------------------------------
23
28
24 # global empty tracker that's always done:
29 # global empty tracker that's always done:
25 finished_tracker = MessageTracker()
30 finished_tracker = MessageTracker()
26
31
27 @decorator
32 @decorator
28 def check_ready(f, self, *args, **kwargs):
33 def check_ready(f, self, *args, **kwargs):
29 """Call spin() to sync state prior to calling the method."""
34 """Call spin() to sync state prior to calling the method."""
30 self.wait(0)
35 self.wait(0)
31 if not self._ready:
36 if not self._ready:
32 raise error.TimeoutError("result not ready")
37 raise error.TimeoutError("result not ready")
33 return f(self, *args, **kwargs)
38 return f(self, *args, **kwargs)
34
39
35 class AsyncResult(object):
40 class AsyncResult(object):
36 """Class for representing results of non-blocking calls.
41 """Class for representing results of non-blocking calls.
37
42
38 Provides the same interface as :py:class:`multiprocessing.pool.AsyncResult`.
43 Provides the same interface as :py:class:`multiprocessing.pool.AsyncResult`.
39 """
44 """
40
45
41 msg_ids = None
46 msg_ids = None
42 _targets = None
47 _targets = None
43 _tracker = None
48 _tracker = None
44 _single_result = False
49 _single_result = False
45
50
46 def __init__(self, client, msg_ids, fname='unknown', targets=None, tracker=None):
51 def __init__(self, client, msg_ids, fname='unknown', targets=None, tracker=None):
47 if isinstance(msg_ids, basestring):
52 if isinstance(msg_ids, basestring):
48 # always a list
53 # always a list
49 msg_ids = [msg_ids]
54 msg_ids = [msg_ids]
50 if tracker is None:
55 if tracker is None:
51 # default to always done
56 # default to always done
52 tracker = finished_tracker
57 tracker = finished_tracker
53 self._client = client
58 self._client = client
54 self.msg_ids = msg_ids
59 self.msg_ids = msg_ids
55 self._fname=fname
60 self._fname=fname
56 self._targets = targets
61 self._targets = targets
57 self._tracker = tracker
62 self._tracker = tracker
58 self._ready = False
63 self._ready = False
59 self._success = None
64 self._success = None
60 if len(msg_ids) == 1:
65 if len(msg_ids) == 1:
61 self._single_result = not isinstance(targets, (list, tuple))
66 self._single_result = not isinstance(targets, (list, tuple))
62 else:
67 else:
63 self._single_result = False
68 self._single_result = False
64
69
65 def __repr__(self):
70 def __repr__(self):
66 if self._ready:
71 if self._ready:
67 return "<%s: finished>"%(self.__class__.__name__)
72 return "<%s: finished>"%(self.__class__.__name__)
68 else:
73 else:
69 return "<%s: %s>"%(self.__class__.__name__,self._fname)
74 return "<%s: %s>"%(self.__class__.__name__,self._fname)
70
75
71
76
72 def _reconstruct_result(self, res):
77 def _reconstruct_result(self, res):
73 """Reconstruct our result from actual result list (always a list)
78 """Reconstruct our result from actual result list (always a list)
74
79
75 Override me in subclasses for turning a list of results
80 Override me in subclasses for turning a list of results
76 into the expected form.
81 into the expected form.
77 """
82 """
78 if self._single_result:
83 if self._single_result:
79 return res[0]
84 return res[0]
80 else:
85 else:
81 return res
86 return res
82
87
83 def get(self, timeout=-1):
88 def get(self, timeout=-1):
84 """Return the result when it arrives.
89 """Return the result when it arrives.
85
90
86 If `timeout` is not ``None`` and the result does not arrive within
91 If `timeout` is not ``None`` and the result does not arrive within
87 `timeout` seconds then ``TimeoutError`` is raised. If the
92 `timeout` seconds then ``TimeoutError`` is raised. If the
88 remote call raised an exception then that exception will be reraised
93 remote call raised an exception then that exception will be reraised
89 by get() inside a `RemoteError`.
94 by get() inside a `RemoteError`.
90 """
95 """
91 if not self.ready():
96 if not self.ready():
92 self.wait(timeout)
97 self.wait(timeout)
93
98
94 if self._ready:
99 if self._ready:
95 if self._success:
100 if self._success:
96 return self._result
101 return self._result
97 else:
102 else:
98 raise self._exception
103 raise self._exception
99 else:
104 else:
100 raise error.TimeoutError("Result not ready.")
105 raise error.TimeoutError("Result not ready.")
101
106
102 def ready(self):
107 def ready(self):
103 """Return whether the call has completed."""
108 """Return whether the call has completed."""
104 if not self._ready:
109 if not self._ready:
105 self.wait(0)
110 self.wait(0)
106 return self._ready
111 return self._ready
107
112
108 def wait(self, timeout=-1):
113 def wait(self, timeout=-1):
109 """Wait until the result is available or until `timeout` seconds pass.
114 """Wait until the result is available or until `timeout` seconds pass.
110
115
111 This method always returns None.
116 This method always returns None.
112 """
117 """
113 if self._ready:
118 if self._ready:
114 return
119 return
115 self._ready = self._client.wait(self.msg_ids, timeout)
120 self._ready = self._client.wait(self.msg_ids, timeout)
116 if self._ready:
121 if self._ready:
117 try:
122 try:
118 results = map(self._client.results.get, self.msg_ids)
123 results = map(self._client.results.get, self.msg_ids)
119 self._result = results
124 self._result = results
120 if self._single_result:
125 if self._single_result:
121 r = results[0]
126 r = results[0]
122 if isinstance(r, Exception):
127 if isinstance(r, Exception):
123 raise r
128 raise r
124 else:
129 else:
125 results = error.collect_exceptions(results, self._fname)
130 results = error.collect_exceptions(results, self._fname)
126 self._result = self._reconstruct_result(results)
131 self._result = self._reconstruct_result(results)
127 except Exception, e:
132 except Exception, e:
128 self._exception = e
133 self._exception = e
129 self._success = False
134 self._success = False
130 else:
135 else:
131 self._success = True
136 self._success = True
132 finally:
137 finally:
133 self._metadata = map(self._client.metadata.get, self.msg_ids)
138 self._metadata = map(self._client.metadata.get, self.msg_ids)
134
139
135
140
136 def successful(self):
141 def successful(self):
137 """Return whether the call completed without raising an exception.
142 """Return whether the call completed without raising an exception.
138
143
139 Will raise ``AssertionError`` if the result is not ready.
144 Will raise ``AssertionError`` if the result is not ready.
140 """
145 """
141 assert self.ready()
146 assert self.ready()
142 return self._success
147 return self._success
143
148
144 #----------------------------------------------------------------
149 #----------------------------------------------------------------
145 # Extra methods not in mp.pool.AsyncResult
150 # Extra methods not in mp.pool.AsyncResult
146 #----------------------------------------------------------------
151 #----------------------------------------------------------------
147
152
148 def get_dict(self, timeout=-1):
153 def get_dict(self, timeout=-1):
149 """Get the results as a dict, keyed by engine_id.
154 """Get the results as a dict, keyed by engine_id.
150
155
151 timeout behavior is described in `get()`.
156 timeout behavior is described in `get()`.
152 """
157 """
153
158
154 results = self.get(timeout)
159 results = self.get(timeout)
155 engine_ids = [ md['engine_id'] for md in self._metadata ]
160 engine_ids = [ md['engine_id'] for md in self._metadata ]
156 bycount = sorted(engine_ids, key=lambda k: engine_ids.count(k))
161 bycount = sorted(engine_ids, key=lambda k: engine_ids.count(k))
157 maxcount = bycount.count(bycount[-1])
162 maxcount = bycount.count(bycount[-1])
158 if maxcount > 1:
163 if maxcount > 1:
159 raise ValueError("Cannot build dict, %i jobs ran on engine #%i"%(
164 raise ValueError("Cannot build dict, %i jobs ran on engine #%i"%(
160 maxcount, bycount[-1]))
165 maxcount, bycount[-1]))
161
166
162 return dict(zip(engine_ids,results))
167 return dict(zip(engine_ids,results))
163
168
164 @property
169 @property
165 def result(self):
170 def result(self):
166 """result property wrapper for `get(timeout=0)`."""
171 """result property wrapper for `get(timeout=0)`."""
167 return self.get()
172 return self.get()
168
173
169 # abbreviated alias:
174 # abbreviated alias:
170 r = result
175 r = result
171
176
172 @property
177 @property
173 @check_ready
178 @check_ready
174 def metadata(self):
179 def metadata(self):
175 """property for accessing execution metadata."""
180 """property for accessing execution metadata."""
176 if self._single_result:
181 if self._single_result:
177 return self._metadata[0]
182 return self._metadata[0]
178 else:
183 else:
179 return self._metadata
184 return self._metadata
180
185
181 @property
186 @property
182 def result_dict(self):
187 def result_dict(self):
183 """result property as a dict."""
188 """result property as a dict."""
184 return self.get_dict()
189 return self.get_dict()
185
190
186 def __dict__(self):
191 def __dict__(self):
187 return self.get_dict(0)
192 return self.get_dict(0)
188
193
189 def abort(self):
194 def abort(self):
190 """abort my tasks."""
195 """abort my tasks."""
191 assert not self.ready(), "Can't abort, I am already done!"
196 assert not self.ready(), "Can't abort, I am already done!"
192 return self._client.abort(self.msg_ids, targets=self._targets, block=True)
197 return self._client.abort(self.msg_ids, targets=self._targets, block=True)
193
198
194 @property
199 @property
195 def sent(self):
200 def sent(self):
196 """check whether my messages have been sent."""
201 """check whether my messages have been sent."""
197 return self._tracker.done
202 return self._tracker.done
198
203
199 def wait_for_send(self, timeout=-1):
204 def wait_for_send(self, timeout=-1):
200 """wait for pyzmq send to complete.
205 """wait for pyzmq send to complete.
201
206
202 This is necessary when sending arrays that you intend to edit in-place.
207 This is necessary when sending arrays that you intend to edit in-place.
203 `timeout` is in seconds, and will raise TimeoutError if it is reached
208 `timeout` is in seconds, and will raise TimeoutError if it is reached
204 before the send completes.
209 before the send completes.
205 """
210 """
206 return self._tracker.wait(timeout)
211 return self._tracker.wait(timeout)
207
212
208 #-------------------------------------
213 #-------------------------------------
209 # dict-access
214 # dict-access
210 #-------------------------------------
215 #-------------------------------------
211
216
212 @check_ready
217 @check_ready
213 def __getitem__(self, key):
218 def __getitem__(self, key):
214 """getitem returns result value(s) if keyed by int/slice, or metadata if key is str.
219 """getitem returns result value(s) if keyed by int/slice, or metadata if key is str.
215 """
220 """
216 if isinstance(key, int):
221 if isinstance(key, int):
217 return error.collect_exceptions([self._result[key]], self._fname)[0]
222 return error.collect_exceptions([self._result[key]], self._fname)[0]
218 elif isinstance(key, slice):
223 elif isinstance(key, slice):
219 return error.collect_exceptions(self._result[key], self._fname)
224 return error.collect_exceptions(self._result[key], self._fname)
220 elif isinstance(key, basestring):
225 elif isinstance(key, basestring):
221 values = [ md[key] for md in self._metadata ]
226 values = [ md[key] for md in self._metadata ]
222 if self._single_result:
227 if self._single_result:
223 return values[0]
228 return values[0]
224 else:
229 else:
225 return values
230 return values
226 else:
231 else:
227 raise TypeError("Invalid key type %r, must be 'int','slice', or 'str'"%type(key))
232 raise TypeError("Invalid key type %r, must be 'int','slice', or 'str'"%type(key))
228
233
229 @check_ready
234 @check_ready
230 def __getattr__(self, key):
235 def __getattr__(self, key):
231 """getattr maps to getitem for convenient attr access to metadata."""
236 """getattr maps to getitem for convenient attr access to metadata."""
232 if key not in self._metadata[0].keys():
237 if key not in self._metadata[0].keys():
233 raise AttributeError("%r object has no attribute %r"%(
238 raise AttributeError("%r object has no attribute %r"%(
234 self.__class__.__name__, key))
239 self.__class__.__name__, key))
235 return self.__getitem__(key)
240 return self.__getitem__(key)
236
241
237 # asynchronous iterator:
242 # asynchronous iterator:
238 def __iter__(self):
243 def __iter__(self):
239 if self._single_result:
244 if self._single_result:
240 raise TypeError("AsyncResults with a single result are not iterable.")
245 raise TypeError("AsyncResults with a single result are not iterable.")
241 try:
246 try:
242 rlist = self.get(0)
247 rlist = self.get(0)
243 except error.TimeoutError:
248 except error.TimeoutError:
244 # wait for each result individually
249 # wait for each result individually
245 for msg_id in self.msg_ids:
250 for msg_id in self.msg_ids:
246 ar = AsyncResult(self._client, msg_id, self._fname)
251 ar = AsyncResult(self._client, msg_id, self._fname)
247 yield ar.get()
252 yield ar.get()
248 else:
253 else:
249 # already done
254 # already done
250 for r in rlist:
255 for r in rlist:
251 yield r
256 yield r
252
257
253
258
254
259
255 class AsyncMapResult(AsyncResult):
260 class AsyncMapResult(AsyncResult):
256 """Class for representing results of non-blocking gathers.
261 """Class for representing results of non-blocking gathers.
257
262
258 This will properly reconstruct the gather.
263 This will properly reconstruct the gather.
259 """
264 """
260
265
261 def __init__(self, client, msg_ids, mapObject, fname=''):
266 def __init__(self, client, msg_ids, mapObject, fname=''):
262 AsyncResult.__init__(self, client, msg_ids, fname=fname)
267 AsyncResult.__init__(self, client, msg_ids, fname=fname)
263 self._mapObject = mapObject
268 self._mapObject = mapObject
264 self._single_result = False
269 self._single_result = False
265
270
266 def _reconstruct_result(self, res):
271 def _reconstruct_result(self, res):
267 """Perform the gather on the actual results."""
272 """Perform the gather on the actual results."""
268 return self._mapObject.joinPartitions(res)
273 return self._mapObject.joinPartitions(res)
269
274
270 # asynchronous iterator:
275 # asynchronous iterator:
271 def __iter__(self):
276 def __iter__(self):
272 try:
277 try:
273 rlist = self.get(0)
278 rlist = self.get(0)
274 except error.TimeoutError:
279 except error.TimeoutError:
275 # wait for each result individually
280 # wait for each result individually
276 for msg_id in self.msg_ids:
281 for msg_id in self.msg_ids:
277 ar = AsyncResult(self._client, msg_id, self._fname)
282 ar = AsyncResult(self._client, msg_id, self._fname)
278 rlist = ar.get()
283 rlist = ar.get()
279 try:
284 try:
280 for r in rlist:
285 for r in rlist:
281 yield r
286 yield r
282 except TypeError:
287 except TypeError:
283 # flattened, not a list
288 # flattened, not a list
284 # this could get broken by flattened data that returns iterables
289 # this could get broken by flattened data that returns iterables
285 # but most calls to map do not expose the `flatten` argument
290 # but most calls to map do not expose the `flatten` argument
286 yield rlist
291 yield rlist
287 else:
292 else:
288 # already done
293 # already done
289 for r in rlist:
294 for r in rlist:
290 yield r
295 yield r
291
296
292
297
293 class AsyncHubResult(AsyncResult):
298 class AsyncHubResult(AsyncResult):
294 """Class to wrap pending results that must be requested from the Hub.
299 """Class to wrap pending results that must be requested from the Hub.
295
300
296 Note that waiting/polling on these objects requires polling the Hub over the network,
301 Note that waiting/polling on these objects requires polling the Hub over the network,
297 so use `AsyncHubResult.wait()` sparingly.
302 so use `AsyncHubResult.wait()` sparingly.
298 """
303 """
299
304
300 def wait(self, timeout=-1):
305 def wait(self, timeout=-1):
301 """wait for result to complete."""
306 """wait for result to complete."""
302 start = time.time()
307 start = time.time()
303 if self._ready:
308 if self._ready:
304 return
309 return
305 local_ids = filter(lambda msg_id: msg_id in self._client.outstanding, self.msg_ids)
310 local_ids = filter(lambda msg_id: msg_id in self._client.outstanding, self.msg_ids)
306 local_ready = self._client.wait(local_ids, timeout)
311 local_ready = self._client.wait(local_ids, timeout)
307 if local_ready:
312 if local_ready:
308 remote_ids = filter(lambda msg_id: msg_id not in self._client.results, self.msg_ids)
313 remote_ids = filter(lambda msg_id: msg_id not in self._client.results, self.msg_ids)
309 if not remote_ids:
314 if not remote_ids:
310 self._ready = True
315 self._ready = True
311 else:
316 else:
312 rdict = self._client.result_status(remote_ids, status_only=False)
317 rdict = self._client.result_status(remote_ids, status_only=False)
313 pending = rdict['pending']
318 pending = rdict['pending']
314 while pending and (timeout < 0 or time.time() < start+timeout):
319 while pending and (timeout < 0 or time.time() < start+timeout):
315 rdict = self._client.result_status(remote_ids, status_only=False)
320 rdict = self._client.result_status(remote_ids, status_only=False)
316 pending = rdict['pending']
321 pending = rdict['pending']
317 if pending:
322 if pending:
318 time.sleep(0.1)
323 time.sleep(0.1)
319 if not pending:
324 if not pending:
320 self._ready = True
325 self._ready = True
321 if self._ready:
326 if self._ready:
322 try:
327 try:
323 results = map(self._client.results.get, self.msg_ids)
328 results = map(self._client.results.get, self.msg_ids)
324 self._result = results
329 self._result = results
325 if self._single_result:
330 if self._single_result:
326 r = results[0]
331 r = results[0]
327 if isinstance(r, Exception):
332 if isinstance(r, Exception):
328 raise r
333 raise r
329 else:
334 else:
330 results = error.collect_exceptions(results, self._fname)
335 results = error.collect_exceptions(results, self._fname)
331 self._result = self._reconstruct_result(results)
336 self._result = self._reconstruct_result(results)
332 except Exception, e:
337 except Exception, e:
333 self._exception = e
338 self._exception = e
334 self._success = False
339 self._success = False
335 else:
340 else:
336 self._success = True
341 self._success = True
337 finally:
342 finally:
338 self._metadata = map(self._client.metadata.get, self.msg_ids)
343 self._metadata = map(self._client.metadata.get, self.msg_ids)
339
344
340 __all__ = ['AsyncResult', 'AsyncMapResult', 'AsyncHubResult'] No newline at end of file
345 __all__ = ['AsyncResult', 'AsyncMapResult', 'AsyncHubResult']
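To ground the API above, a hedged sketch of the AsyncResult lifecycle from the client side (a running cluster is assumed; rc[:] and apply_async follow the IPython.parallel view API of this era):

    from IPython.parallel import Client

    rc = Client()                     # connect using the default profile
    view = rc[:]                      # a DirectView over all engines

    ar = view.apply_async(lambda x: x * x, 7)   # non-blocking; returns AsyncResult
    print(ar.ready())                 # False until every msg_id has completed
    ar.wait(1)                        # block up to 1s; always returns None
    print(ar.get(timeout=5))          # one 49 per engine, or raises TimeoutError

    # results keyed by engine_id; raises ValueError if an engine ran two jobs
    print(ar.get_dict())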
@@ -1,1368 +1,1373 b''
1 """A semi-synchronous Client for the ZMQ cluster"""
1 """A semi-synchronous Client for the ZMQ cluster
2
3 Authors:
4
5 * MinRK
6 """
2 #-----------------------------------------------------------------------------
7 #-----------------------------------------------------------------------------
3 # Copyright (C) 2010 The IPython Development Team
8 # Copyright (C) 2010-2011 The IPython Development Team
4 #
9 #
5 # Distributed under the terms of the BSD License. The full license is in
10 # Distributed under the terms of the BSD License. The full license is in
6 # the file COPYING, distributed as part of this software.
11 # the file COPYING, distributed as part of this software.
7 #-----------------------------------------------------------------------------
12 #-----------------------------------------------------------------------------
8
13
9 #-----------------------------------------------------------------------------
14 #-----------------------------------------------------------------------------
10 # Imports
15 # Imports
11 #-----------------------------------------------------------------------------
16 #-----------------------------------------------------------------------------
12
17
13 import os
18 import os
14 import json
19 import json
15 import time
20 import time
16 import warnings
21 import warnings
17 from datetime import datetime
22 from datetime import datetime
18 from getpass import getpass
23 from getpass import getpass
19 from pprint import pprint
24 from pprint import pprint
20
25
21 pjoin = os.path.join
26 pjoin = os.path.join
22
27
23 import zmq
28 import zmq
24 # from zmq.eventloop import ioloop, zmqstream
29 # from zmq.eventloop import ioloop, zmqstream
25
30
26 from IPython.utils.path import get_ipython_dir
31 from IPython.utils.path import get_ipython_dir
27 from IPython.utils.traitlets import (HasTraits, Int, Instance, Unicode,
32 from IPython.utils.traitlets import (HasTraits, Int, Instance, Unicode,
28 Dict, List, Bool, Set)
33 Dict, List, Bool, Set)
29 from IPython.external.decorator import decorator
34 from IPython.external.decorator import decorator
30 from IPython.external.ssh import tunnel
35 from IPython.external.ssh import tunnel
31
36
32 from IPython.parallel import error
37 from IPython.parallel import error
33 from IPython.parallel import util
38 from IPython.parallel import util
34
39
35 from IPython.zmq.session import Session, Message
40 from IPython.zmq.session import Session, Message
36
41
37 from .asyncresult import AsyncResult, AsyncHubResult
42 from .asyncresult import AsyncResult, AsyncHubResult
38 from IPython.core.newapplication import ProfileDir, ProfileDirError
43 from IPython.core.newapplication import ProfileDir, ProfileDirError
39 from .view import DirectView, LoadBalancedView
44 from .view import DirectView, LoadBalancedView
40
45
41 #--------------------------------------------------------------------------
46 #--------------------------------------------------------------------------
42 # Decorators for Client methods
47 # Decorators for Client methods
43 #--------------------------------------------------------------------------
48 #--------------------------------------------------------------------------
44
49
45 @decorator
50 @decorator
46 def spin_first(f, self, *args, **kwargs):
51 def spin_first(f, self, *args, **kwargs):
47 """Call spin() to sync state prior to calling the method."""
52 """Call spin() to sync state prior to calling the method."""
48 self.spin()
53 self.spin()
49 return f(self, *args, **kwargs)
54 return f(self, *args, **kwargs)
50
55
51
56
52 #--------------------------------------------------------------------------
57 #--------------------------------------------------------------------------
53 # Classes
58 # Classes
54 #--------------------------------------------------------------------------
59 #--------------------------------------------------------------------------
55
60
56 class Metadata(dict):
61 class Metadata(dict):
57 """Subclass of dict for initializing metadata values.
62 """Subclass of dict for initializing metadata values.
58
63
59 Attribute access works on keys.
64 Attribute access works on keys.
60
65
61 These objects have a strict set of keys - an error is raised if you try
66 These objects have a strict set of keys - an error is raised if you try
62 to add a new key.
67 to add a new key.
63 """
68 """
64 def __init__(self, *args, **kwargs):
69 def __init__(self, *args, **kwargs):
65 dict.__init__(self)
70 dict.__init__(self)
66 md = {'msg_id' : None,
71 md = {'msg_id' : None,
67 'submitted' : None,
72 'submitted' : None,
68 'started' : None,
73 'started' : None,
69 'completed' : None,
74 'completed' : None,
70 'received' : None,
75 'received' : None,
71 'engine_uuid' : None,
76 'engine_uuid' : None,
72 'engine_id' : None,
77 'engine_id' : None,
73 'follow' : None,
78 'follow' : None,
74 'after' : None,
79 'after' : None,
75 'status' : None,
80 'status' : None,
76
81
77 'pyin' : None,
82 'pyin' : None,
78 'pyout' : None,
83 'pyout' : None,
79 'pyerr' : None,
84 'pyerr' : None,
80 'stdout' : '',
85 'stdout' : '',
81 'stderr' : '',
86 'stderr' : '',
82 }
87 }
83 self.update(md)
88 self.update(md)
84 self.update(dict(*args, **kwargs))
89 self.update(dict(*args, **kwargs))
85
90
86 def __getattr__(self, key):
91 def __getattr__(self, key):
87 """getattr aliased to getitem"""
92 """getattr aliased to getitem"""
88 if key in self.iterkeys():
93 if key in self.iterkeys():
89 return self[key]
94 return self[key]
90 else:
95 else:
91 raise AttributeError(key)
96 raise AttributeError(key)
92
97
93 def __setattr__(self, key, value):
98 def __setattr__(self, key, value):
94 """setattr aliased to setitem, with strict"""
99 """setattr aliased to setitem, with strict"""
95 if key in self.iterkeys():
100 if key in self.iterkeys():
96 self[key] = value
101 self[key] = value
97 else:
102 else:
98 raise AttributeError(key)
103 raise AttributeError(key)
99
104
100 def __setitem__(self, key, value):
105 def __setitem__(self, key, value):
101 """strict static key enforcement"""
106 """strict static key enforcement"""
102 if key in self.iterkeys():
107 if key in self.iterkeys():
103 dict.__setitem__(self, key, value)
108 dict.__setitem__(self, key, value)
104 else:
109 else:
105 raise KeyError(key)
110 raise KeyError(key)
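
    # Illustrative usage (not in the original source): attribute access reads
    # keys, and unknown keys are rejected ('bogus' is a hypothetical key).
    #
    #     >>> md = Metadata(status='ok')
    #     >>> md.status
    #     'ok'
    #     >>> md.stdout            # stream keys default to ''
    #     ''
    #     >>> md['bogus'] = 1
    #     Traceback (most recent call last):
    #     ...
    #     KeyError: 'bogus'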


class Client(HasTraits):
    """A semi-synchronous client to the IPython ZMQ cluster

    Parameters
    ----------

    url_or_file : bytes; zmq url or path to ipcontroller-client.json
        Connection information for the Hub's registration. If a json connector
        file is given, then likely no further configuration is necessary.
        [Default: use profile]
    profile : bytes
        The name of the Cluster profile to be used to find connector information.
        [Default: 'default']
    context : zmq.Context
        Pass an existing zmq.Context instance, otherwise the client will create its own.
    debug : bool
        flag for lots of message printing for debug purposes
    timeout : int/float
        time (in seconds) to wait for connection replies from the Hub
        [Default: 10]

    #-------------- session related args ----------------

    config : Config object
        If specified, this will be relayed to the Session for configuration
    username : str
        set username for the session object
    packer : str (import_string) or callable
        Can be either the simple keyword 'json' or 'pickle', or an import_string to a
        function to serialize messages. Must support same input as
        JSON, and output must be bytes.
        You can pass a callable directly as `pack`
    unpacker : str (import_string) or callable
        The inverse of packer. Only necessary if packer is specified as *not* one
        of 'json' or 'pickle'.

    #-------------- ssh related args ----------------
    # These are args for configuring the ssh tunnel to be used
    # credentials are used to forward connections over ssh to the Controller
    # Note that the ip given in `addr` needs to be relative to sshserver
    # The most basic case is to leave addr as pointing to localhost (127.0.0.1),
    # and set sshserver as the same machine the Controller is on. However,
    # the only requirement is that sshserver is able to see the Controller
    # (i.e. is within the same trusted network).

    sshserver : str
        A string of the form passed to ssh, i.e. 'server.tld' or 'user@server.tld:port'
        If keyfile or password is specified, and this is not, it will default to
        the ip given in addr.
    sshkey : str; path to public ssh key file
        This specifies a key to be used in ssh login, default None.
        Regular default ssh keys will be used without specifying this argument.
    password : str
        Your ssh password to sshserver. Note that if this is left None,
        you will be prompted for it if passwordless key based login is unavailable.
    paramiko : bool
        flag for whether to use paramiko instead of shell ssh for tunneling.
        [default: True on win32, False else]

    ------- exec authentication args -------
    If even localhost is untrusted, you can have some protection against
    unauthorized execution by signing messages with HMAC digests.
    Messages are still sent as cleartext, so if someone can snoop your
    loopback traffic this will not protect your privacy, but will prevent
    unauthorized execution.

    exec_key : str
        an authentication key or file containing a key
        default: None


    Attributes
    ----------

    ids : list of int engine IDs
        requesting the ids attribute always synchronizes
        the registration state. To request ids without synchronization,
        use the semi-private _ids attribute.

    history : list of msg_ids
        a list of msg_ids, keeping track of all the execution
        messages you have submitted in order.

    outstanding : set of msg_ids
        a set of msg_ids that have been submitted, but whose
        results have not yet been received.

    results : dict
        a dict of all our results, keyed by msg_id

    block : bool
        determines default behavior when block not specified
        in execution methods

    Methods
    -------

    spin
        flushes incoming results and registration state changes
        control methods spin, and requesting `ids` also ensures up to date

    wait
        wait on one or more msg_ids

    execution methods
        apply
        legacy: execute, run

    data movement
        push, pull, scatter, gather

    query methods
        queue_status, get_result, purge, result_status

    control methods
        abort, shutdown

    """


    block = Bool(False)
    outstanding = Set()
    results = Instance('collections.defaultdict', (dict,))
    metadata = Instance('collections.defaultdict', (Metadata,))
    history = List()
    debug = Bool(False)
    profile=Unicode('default')

    _outstanding_dict = Instance('collections.defaultdict', (set,))
    _ids = List()
    _connected=Bool(False)
    _ssh=Bool(False)
    _context = Instance('zmq.Context')
    _config = Dict()
    _engines=Instance(util.ReverseDict, (), {})
    # _hub_socket=Instance('zmq.Socket')
    _query_socket=Instance('zmq.Socket')
    _control_socket=Instance('zmq.Socket')
    _iopub_socket=Instance('zmq.Socket')
    _notification_socket=Instance('zmq.Socket')
    _mux_socket=Instance('zmq.Socket')
    _task_socket=Instance('zmq.Socket')
    _task_scheme=Unicode()
    _closed = False
    _ignored_control_replies=Int(0)
    _ignored_hub_replies=Int(0)

    def __init__(self, url_or_file=None, profile='default', profile_dir=None, ipython_dir=None,
            context=None, debug=False, exec_key=None,
            sshserver=None, sshkey=None, password=None, paramiko=None,
            timeout=10, **extra_args
            ):
        super(Client, self).__init__(debug=debug, profile=profile)
        if context is None:
            context = zmq.Context.instance()
        self._context = context


        self._setup_profile_dir(profile, profile_dir, ipython_dir)
        if self._cd is not None:
            if url_or_file is None:
                url_or_file = pjoin(self._cd.security_dir, 'ipcontroller-client.json')
        assert url_or_file is not None, "I can't find enough information to connect to a hub!"\
            " Please specify at least one of url_or_file or profile."

        try:
            util.validate_url(url_or_file)
        except AssertionError:
            if not os.path.exists(url_or_file):
                if self._cd:
                    url_or_file = os.path.join(self._cd.security_dir, url_or_file)
                assert os.path.exists(url_or_file), "Not a valid connection file or url: %r"%url_or_file
            with open(url_or_file) as f:
                cfg = json.loads(f.read())
        else:
            cfg = {'url':url_or_file}

        # sync defaults from args, json:
        if sshserver:
            cfg['ssh'] = sshserver
        if exec_key:
            cfg['exec_key'] = exec_key
        exec_key = cfg['exec_key']
        sshserver=cfg['ssh']
        url = cfg['url']
        location = cfg.setdefault('location', None)
        cfg['url'] = util.disambiguate_url(cfg['url'], location)
        url = cfg['url']

        self._config = cfg

        self._ssh = bool(sshserver or sshkey or password)
        if self._ssh and sshserver is None:
            # default to ssh via localhost
            sshserver = url.split('://')[1].split(':')[0]
        if self._ssh and password is None:
            if tunnel.try_passwordless_ssh(sshserver, sshkey, paramiko):
                password=False
            else:
                password = getpass("SSH Password for %s: "%sshserver)
        ssh_kwargs = dict(keyfile=sshkey, password=password, paramiko=paramiko)

        # configure and construct the session
        if exec_key is not None:
            if os.path.isfile(exec_key):
                extra_args['keyfile'] = exec_key
            else:
                extra_args['key'] = exec_key
        self.session = Session(**extra_args)

        self._query_socket = self._context.socket(zmq.XREQ)
        self._query_socket.setsockopt(zmq.IDENTITY, self.session.session)
        if self._ssh:
            tunnel.tunnel_connection(self._query_socket, url, sshserver, **ssh_kwargs)
        else:
            self._query_socket.connect(url)

        self.session.debug = self.debug

        self._notification_handlers = {'registration_notification' : self._register_engine,
                                    'unregistration_notification' : self._unregister_engine,
                                    'shutdown_notification' : lambda msg: self.close(),
                                    }
        self._queue_handlers = {'execute_reply' : self._handle_execute_reply,
                                'apply_reply' : self._handle_apply_reply}
        self._connect(sshserver, ssh_kwargs, timeout)
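
    # Illustrative sketch (not in the original source): typical constructions,
    # assuming a running cluster ('mycluster', the path, and the server are
    # hypothetical):
    #
    #     >>> rc = Client()                        # default profile's json file
    #     >>> rc = Client(profile='mycluster')     # look up by profile name
    #     >>> rc = Client('/path/to/ipcontroller-client.json',
    #     ...             sshserver='user@gateway.tld')  # tunnel over ssh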

    def __del__(self):
        """cleanup sockets, but _not_ context."""
        self.close()

    def _setup_profile_dir(self, profile, profile_dir, ipython_dir):
        if ipython_dir is None:
            ipython_dir = get_ipython_dir()
        if profile_dir is not None:
            try:
                self._cd = ProfileDir.find_profile_dir(profile_dir)
                return
            except ProfileDirError:
                pass
        elif profile is not None:
            try:
                self._cd = ProfileDir.find_profile_dir_by_name(
                    ipython_dir, profile)
                return
            except ProfileDirError:
                pass
        self._cd = None

    def _update_engines(self, engines):
        """Update our engines dict and _ids from a dict of the form: {id:uuid}."""
        for k,v in engines.iteritems():
            eid = int(k)
            self._engines[eid] = bytes(v) # force not unicode
            self._ids.append(eid)
        self._ids = sorted(self._ids)
        if sorted(self._engines.keys()) != range(len(self._engines)) and \
                        self._task_scheme == 'pure' and self._task_socket:
            self._stop_scheduling_tasks()

    def _stop_scheduling_tasks(self):
        """Stop scheduling tasks because an engine has been unregistered
        from a pure ZMQ scheduler.
        """
        self._task_socket.close()
        self._task_socket = None
        msg = "An engine has been unregistered, and we are using pure " +\
              "ZMQ task scheduling. Task farming will be disabled."
        if self.outstanding:
            msg += " If you were running tasks when this happened, " +\
                   "some `outstanding` msg_ids may never resolve."
        warnings.warn(msg, RuntimeWarning)

    def _build_targets(self, targets):
        """Turn valid target IDs or 'all' into two lists:
        (int_ids, uuids).
        """
        if not self._ids:
            # flush notification socket if no engines yet, just in case
            if not self.ids:
                raise error.NoEnginesRegistered("Can't build targets without any engines")

        if targets is None:
            targets = self._ids
        elif isinstance(targets, str):
            if targets.lower() == 'all':
                targets = self._ids
            else:
                raise TypeError("%r not valid str target, must be 'all'"%(targets))
        elif isinstance(targets, int):
            if targets < 0:
                targets = self.ids[targets]
            if targets not in self._ids:
                raise IndexError("No such engine: %i"%targets)
            targets = [targets]

        if isinstance(targets, slice):
            indices = range(len(self._ids))[targets]
            ids = self.ids
            targets = [ ids[i] for i in indices ]

        if not isinstance(targets, (tuple, list, xrange)):
            raise TypeError("targets by int/slice/collection of ints only, not %s"%(type(targets)))

        return [self._engines[t] for t in targets], list(targets)
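
    # Illustrative sketch (not in the original source): with engines 0-3
    # registered, targets normalize to (uuids, int_ids), e.g.
    #
    #     >>> rc._build_targets('all')           # every registered engine
    #     >>> rc._build_targets(-1)              # negative ints index self.ids
    #     >>> rc._build_targets(slice(0, 4, 2))  # engines [0, 2]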

    def _connect(self, sshserver, ssh_kwargs, timeout):
        """setup all our socket connections to the cluster. This is called from
        __init__."""

        # Maybe allow reconnecting?
        if self._connected:
            return
        self._connected=True

        def connect_socket(s, url):
            url = util.disambiguate_url(url, self._config['location'])
            if self._ssh:
                return tunnel.tunnel_connection(s, url, sshserver, **ssh_kwargs)
            else:
                return s.connect(url)

        self.session.send(self._query_socket, 'connection_request')
        r,w,x = zmq.select([self._query_socket],[],[], timeout)
        if not r:
            raise error.TimeoutError("Hub connection request timed out")
        idents,msg = self.session.recv(self._query_socket,mode=0)
        if self.debug:
            pprint(msg)
        msg = Message(msg)
        content = msg.content
        self._config['registration'] = dict(content)
        if content.status == 'ok':
            if content.mux:
                self._mux_socket = self._context.socket(zmq.XREQ)
                self._mux_socket.setsockopt(zmq.IDENTITY, self.session.session)
                connect_socket(self._mux_socket, content.mux)
            if content.task:
                self._task_scheme, task_addr = content.task
                self._task_socket = self._context.socket(zmq.XREQ)
                self._task_socket.setsockopt(zmq.IDENTITY, self.session.session)
                connect_socket(self._task_socket, task_addr)
            if content.notification:
                self._notification_socket = self._context.socket(zmq.SUB)
                connect_socket(self._notification_socket, content.notification)
                self._notification_socket.setsockopt(zmq.SUBSCRIBE, b'')
            # if content.query:
            #     self._query_socket = self._context.socket(zmq.XREQ)
            #     self._query_socket.setsockopt(zmq.IDENTITY, self.session.session)
            #     connect_socket(self._query_socket, content.query)
            if content.control:
                self._control_socket = self._context.socket(zmq.XREQ)
                self._control_socket.setsockopt(zmq.IDENTITY, self.session.session)
                connect_socket(self._control_socket, content.control)
            if content.iopub:
                self._iopub_socket = self._context.socket(zmq.SUB)
                self._iopub_socket.setsockopt(zmq.SUBSCRIBE, b'')
                self._iopub_socket.setsockopt(zmq.IDENTITY, self.session.session)
                connect_socket(self._iopub_socket, content.iopub)
            self._update_engines(dict(content.engines))
        else:
            self._connected = False
            raise Exception("Failed to connect!")

    #--------------------------------------------------------------------------
    # handlers and callbacks for incoming messages
    #--------------------------------------------------------------------------

    def _unwrap_exception(self, content):
        """unwrap exception, and remap engine_id to int."""
        e = error.unwrap_exception(content)
        # print e.traceback
        if e.engine_info:
            e_uuid = e.engine_info['engine_uuid']
            eid = self._engines[e_uuid]
            e.engine_info['engine_id'] = eid
        return e

    def _extract_metadata(self, header, parent, content):
        md = {'msg_id' : parent['msg_id'],
              'received' : datetime.now(),
              'engine_uuid' : header.get('engine', None),
              'follow' : parent.get('follow', []),
              'after' : parent.get('after', []),
              'status' : content['status'],
            }

        if md['engine_uuid'] is not None:
            md['engine_id'] = self._engines.get(md['engine_uuid'], None)

        if 'date' in parent:
            md['submitted'] = parent['date']
        if 'started' in header:
            md['started'] = header['started']
        if 'date' in header:
            md['completed'] = header['date']
        return md

    def _register_engine(self, msg):
        """Register a new engine, and update our connection info."""
        content = msg['content']
        eid = content['id']
        d = {eid : content['queue']}
        self._update_engines(d)

    def _unregister_engine(self, msg):
        """Unregister an engine that has died."""
        content = msg['content']
        eid = int(content['id'])
        if eid in self._ids:
            self._ids.remove(eid)
            uuid = self._engines.pop(eid)

            self._handle_stranded_msgs(eid, uuid)

        if self._task_socket and self._task_scheme == 'pure':
            self._stop_scheduling_tasks()

    def _handle_stranded_msgs(self, eid, uuid):
        """Handle messages known to be on an engine when the engine unregisters.

        It is possible that this will fire prematurely - that is, an engine will
        go down after completing a result, and the client will be notified
        of the unregistration and later receive the successful result.
        """

        outstanding = self._outstanding_dict[uuid]

        for msg_id in list(outstanding):
            if msg_id in self.results:
                # we already have the result
                continue
            try:
                raise error.EngineError("Engine %r died while running task %r"%(eid, msg_id))
            except:
                content = error.wrap_exception()
            # build a fake message:
            parent = {}
            header = {}
            parent['msg_id'] = msg_id
            header['engine'] = uuid
            header['date'] = datetime.now()
            msg = dict(parent_header=parent, header=header, content=content)
            self._handle_apply_reply(msg)

    def _handle_execute_reply(self, msg):
        """Save the reply to an execute_request into our results.

        execute messages are never actually used. apply is used instead.
        """

        parent = msg['parent_header']
        msg_id = parent['msg_id']
        if msg_id not in self.outstanding:
            if msg_id in self.history:
                print ("got stale result: %s"%msg_id)
            else:
                print ("got unknown result: %s"%msg_id)
        else:
            self.outstanding.remove(msg_id)
            self.results[msg_id] = self._unwrap_exception(msg['content'])

    def _handle_apply_reply(self, msg):
        """Save the reply to an apply_request into our results."""
        parent = msg['parent_header']
        msg_id = parent['msg_id']
        if msg_id not in self.outstanding:
            if msg_id in self.history:
                print ("got stale result: %s"%msg_id)
                print self.results[msg_id]
                print msg
            else:
                print ("got unknown result: %s"%msg_id)
        else:
            self.outstanding.remove(msg_id)
        content = msg['content']
        header = msg['header']

        # construct metadata:
        md = self.metadata[msg_id]
        md.update(self._extract_metadata(header, parent, content))
        # is this redundant?
        self.metadata[msg_id] = md

        e_outstanding = self._outstanding_dict[md['engine_uuid']]
        if msg_id in e_outstanding:
            e_outstanding.remove(msg_id)

        # construct result:
        if content['status'] == 'ok':
            self.results[msg_id] = util.unserialize_object(msg['buffers'])[0]
        elif content['status'] == 'aborted':
            self.results[msg_id] = error.TaskAborted(msg_id)
        elif content['status'] == 'resubmitted':
            # TODO: handle resubmission
            pass
        else:
            self.results[msg_id] = self._unwrap_exception(content)

    def _flush_notifications(self):
        """Flush notifications of engine registrations waiting
        in ZMQ queue."""
        idents,msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
        while msg is not None:
            if self.debug:
                pprint(msg)
            msg_type = msg['msg_type']
            handler = self._notification_handlers.get(msg_type, None)
            if handler is None:
                raise Exception("Unhandled message type: %s"%msg.msg_type)
            else:
                handler(msg)
            idents,msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)

    def _flush_results(self, sock):
        """Flush task or queue results waiting in ZMQ queue."""
        idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
        while msg is not None:
            if self.debug:
                pprint(msg)
            msg_type = msg['msg_type']
            handler = self._queue_handlers.get(msg_type, None)
            if handler is None:
                raise Exception("Unhandled message type: %s"%msg.msg_type)
            else:
                handler(msg)
            idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)

    def _flush_control(self, sock):
        """Flush replies from the control channel waiting
        in the ZMQ queue.

        Currently: ignore them."""
        if self._ignored_control_replies <= 0:
            return
        idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
        while msg is not None:
            self._ignored_control_replies -= 1
            if self.debug:
                pprint(msg)
            idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)

    def _flush_ignored_control(self):
        """flush ignored control replies"""
        while self._ignored_control_replies > 0:
            self.session.recv(self._control_socket)
            self._ignored_control_replies -= 1

    def _flush_ignored_hub_replies(self):
        ident,msg = self.session.recv(self._query_socket, mode=zmq.NOBLOCK)
        while msg is not None:
            ident,msg = self.session.recv(self._query_socket, mode=zmq.NOBLOCK)

    def _flush_iopub(self, sock):
        """Flush replies from the iopub channel waiting
        in the ZMQ queue.
        """
        idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
        while msg is not None:
            if self.debug:
                pprint(msg)
            parent = msg['parent_header']
            msg_id = parent['msg_id']
            content = msg['content']
            header = msg['header']
            msg_type = msg['msg_type']

            # init metadata:
            md = self.metadata[msg_id]

            if msg_type == 'stream':
                name = content['name']
                s = md[name] or ''
                md[name] = s + content['data']
            elif msg_type == 'pyerr':
                md.update({'pyerr' : self._unwrap_exception(content)})
            elif msg_type == 'pyin':
                md.update({'pyin' : content['code']})
            else:
                md.update({msg_type : content.get('data', '')})

            # redundant?
            self.metadata[msg_id] = md

            idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)

    #--------------------------------------------------------------------------
    # len, getitem
    #--------------------------------------------------------------------------

    def __len__(self):
        """len(client) returns # of engines."""
        return len(self.ids)

    def __getitem__(self, key):
        """index access returns DirectView multiplexer objects

        Must be int, slice, or list/tuple/xrange of ints"""
        if not isinstance(key, (int, slice, tuple, list, xrange)):
            raise TypeError("key by int/slice/iterable of ints only, not %s"%(type(key)))
        else:
            return self.direct_view(key)
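
    # Illustrative usage (not in the original source): indexing returns
    # DirectViews, e.g.
    #
    #     >>> rc[0]       # DirectView on engine 0
    #     >>> rc[::2]     # DirectView on every other engine
    #     >>> len(rc)     # number of registered engines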

    #--------------------------------------------------------------------------
    # Begin public methods
    #--------------------------------------------------------------------------

    @property
    def ids(self):
        """Always up-to-date ids property."""
        self._flush_notifications()
        # always copy:
        return list(self._ids)

    def close(self):
        if self._closed:
            return
        snames = filter(lambda n: n.endswith('socket'), dir(self))
        for socket in map(lambda name: getattr(self, name), snames):
            if isinstance(socket, zmq.Socket) and not socket.closed:
                socket.close()
        self._closed = True

    def spin(self):
        """Flush any registration notifications and execution results
        waiting in the ZMQ queue.
        """
        if self._notification_socket:
            self._flush_notifications()
        if self._mux_socket:
            self._flush_results(self._mux_socket)
        if self._task_socket:
            self._flush_results(self._task_socket)
        if self._control_socket:
            self._flush_control(self._control_socket)
        if self._iopub_socket:
            self._flush_iopub(self._iopub_socket)
        if self._query_socket:
            self._flush_ignored_hub_replies()

    def wait(self, jobs=None, timeout=-1):
        """waits on one or more `jobs`, for up to `timeout` seconds.

        Parameters
        ----------

        jobs : int, str, or list of ints and/or strs, or one or more AsyncResult objects
            ints are indices to self.history
            strs are msg_ids
            default: wait on all outstanding messages
        timeout : float
            a time in seconds, after which to give up.
            default is -1, which means no timeout

        Returns
        -------

        True : when all msg_ids are done
        False : timeout reached, some msg_ids still outstanding
        """
        tic = time.time()
        if jobs is None:
            theids = self.outstanding
        else:
            if isinstance(jobs, (int, str, AsyncResult)):
                jobs = [jobs]
            theids = set()
            for job in jobs:
                if isinstance(job, int):
                    # index access
                    job = self.history[job]
                elif isinstance(job, AsyncResult):
                    map(theids.add, job.msg_ids)
                    continue
                theids.add(job)
        if not theids.intersection(self.outstanding):
            return True
        self.spin()
        while theids.intersection(self.outstanding):
            if timeout >= 0 and ( time.time()-tic ) > timeout:
                break
            time.sleep(1e-3)
            self.spin()
        return len(theids.intersection(self.outstanding)) == 0
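
    # Illustrative usage (not in the original source): block on specific jobs
    # instead of everything outstanding (`slow_func` is hypothetical):
    #
    #     >>> ar = rc[0].apply_async(slow_func)
    #     >>> rc.wait([ar], timeout=5)   # True iff done within 5 seconds
    #     >>> rc.wait()                  # wait on all outstanding msg_ids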

    #--------------------------------------------------------------------------
    # Control methods
    #--------------------------------------------------------------------------

    @spin_first
    def clear(self, targets=None, block=None):
        """Clear the namespace in target(s)."""
        block = self.block if block is None else block
        targets = self._build_targets(targets)[0]
        for t in targets:
            self.session.send(self._control_socket, 'clear_request', content={}, ident=t)
        error = False
        if block:
            self._flush_ignored_control()
            for i in range(len(targets)):
                idents,msg = self.session.recv(self._control_socket,0)
                if self.debug:
                    pprint(msg)
                if msg['content']['status'] != 'ok':
                    error = self._unwrap_exception(msg['content'])
        else:
            self._ignored_control_replies += len(targets)
        if error:
            raise error


    @spin_first
    def abort(self, jobs=None, targets=None, block=None):
        """Abort specific jobs from the execution queues of target(s).

        This is a mechanism to prevent jobs that have already been submitted
        from executing.

        Parameters
        ----------

        jobs : msg_id, list of msg_ids, or AsyncResult
            The jobs to be aborted


        """
        block = self.block if block is None else block
        targets = self._build_targets(targets)[0]
        msg_ids = []
        if isinstance(jobs, (basestring,AsyncResult)):
            jobs = [jobs]
        bad_ids = filter(lambda obj: not isinstance(obj, (basestring, AsyncResult)), jobs)
        if bad_ids:
            raise TypeError("Invalid msg_id type %r, expected str or AsyncResult"%bad_ids[0])
        for j in jobs:
            if isinstance(j, AsyncResult):
                msg_ids.extend(j.msg_ids)
            else:
                msg_ids.append(j)
        content = dict(msg_ids=msg_ids)
        for t in targets:
            self.session.send(self._control_socket, 'abort_request',
                    content=content, ident=t)
        error = False
        if block:
            self._flush_ignored_control()
            for i in range(len(targets)):
                idents,msg = self.session.recv(self._control_socket,0)
                if self.debug:
                    pprint(msg)
                if msg['content']['status'] != 'ok':
                    error = self._unwrap_exception(msg['content'])
        else:
            self._ignored_control_replies += len(targets)
        if error:
            raise error
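
    # Illustrative usage (not in the original source): abort queued jobs before
    # they execute (`slow_func` is hypothetical):
    #
    #     >>> ar = rc[:].apply_async(slow_func)
    #     >>> rc.abort(ar)               # abort by AsyncResult
    #     >>> rc.abort('a-msg-id')       # or by msg_id string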

    @spin_first
    def shutdown(self, targets=None, restart=False, hub=False, block=None):
        """Terminates one or more engine processes, optionally including the hub."""
        block = self.block if block is None else block
        if hub:
            targets = 'all'
        targets = self._build_targets(targets)[0]
        for t in targets:
            self.session.send(self._control_socket, 'shutdown_request',
                        content={'restart':restart},ident=t)
        error = False
        if block or hub:
            self._flush_ignored_control()
            for i in range(len(targets)):
                idents,msg = self.session.recv(self._control_socket, 0)
                if self.debug:
                    pprint(msg)
                if msg['content']['status'] != 'ok':
                    error = self._unwrap_exception(msg['content'])
        else:
            self._ignored_control_replies += len(targets)

        if hub:
            time.sleep(0.25)
            self.session.send(self._query_socket, 'shutdown_request')
            idents,msg = self.session.recv(self._query_socket, 0)
            if self.debug:
                pprint(msg)
            if msg['content']['status'] != 'ok':
                error = self._unwrap_exception(msg['content'])

        if error:
            raise error
899 #--------------------------------------------------------------------------
904 #--------------------------------------------------------------------------
900 # Execution related methods
905 # Execution related methods
901 #--------------------------------------------------------------------------
906 #--------------------------------------------------------------------------
902
907
903 def _maybe_raise(self, result):
908 def _maybe_raise(self, result):
904 """wrapper for maybe raising an exception if apply failed."""
909 """wrapper for maybe raising an exception if apply failed."""
905 if isinstance(result, error.RemoteError):
910 if isinstance(result, error.RemoteError):
906 raise result
911 raise result
907
912
908 return result
913 return result
909
914
910 def send_apply_message(self, socket, f, args=None, kwargs=None, subheader=None, track=False,
915 def send_apply_message(self, socket, f, args=None, kwargs=None, subheader=None, track=False,
911 ident=None):
916 ident=None):
912 """construct and send an apply message via a socket.
917 """construct and send an apply message via a socket.
913
918
914 This is the principal method with which all engine execution is performed by views.
919 This is the principal method with which all engine execution is performed by views.
915 """
920 """
916
921
917 assert not self._closed, "cannot use me anymore, I'm closed!"
922 assert not self._closed, "cannot use me anymore, I'm closed!"
918 # defaults:
923 # defaults:
919 args = args if args is not None else []
924 args = args if args is not None else []
920 kwargs = kwargs if kwargs is not None else {}
925 kwargs = kwargs if kwargs is not None else {}
921 subheader = subheader if subheader is not None else {}
926 subheader = subheader if subheader is not None else {}
922
927
923 # validate arguments
928 # validate arguments
924 if not callable(f):
929 if not callable(f):
925 raise TypeError("f must be callable, not %s"%type(f))
930 raise TypeError("f must be callable, not %s"%type(f))
926 if not isinstance(args, (tuple, list)):
931 if not isinstance(args, (tuple, list)):
927 raise TypeError("args must be tuple or list, not %s"%type(args))
932 raise TypeError("args must be tuple or list, not %s"%type(args))
928 if not isinstance(kwargs, dict):
933 if not isinstance(kwargs, dict):
929 raise TypeError("kwargs must be dict, not %s"%type(kwargs))
934 raise TypeError("kwargs must be dict, not %s"%type(kwargs))
930 if not isinstance(subheader, dict):
935 if not isinstance(subheader, dict):
931 raise TypeError("subheader must be dict, not %s"%type(subheader))
936 raise TypeError("subheader must be dict, not %s"%type(subheader))
932
937
933 bufs = util.pack_apply_message(f,args,kwargs)
938 bufs = util.pack_apply_message(f,args,kwargs)
934
939
935 msg = self.session.send(socket, "apply_request", buffers=bufs, ident=ident,
940 msg = self.session.send(socket, "apply_request", buffers=bufs, ident=ident,
936 subheader=subheader, track=track)
941 subheader=subheader, track=track)
937
942
938 msg_id = msg['msg_id']
943 msg_id = msg['msg_id']
939 self.outstanding.add(msg_id)
944 self.outstanding.add(msg_id)
940 if ident:
945 if ident:
941 # possibly routed to a specific engine
946 # possibly routed to a specific engine
942 if isinstance(ident, list):
947 if isinstance(ident, list):
943 ident = ident[-1]
948 ident = ident[-1]
944 if ident in self._engines.values():
949 if ident in self._engines.values():
945 # save for later, in case of engine death
950 # save for later, in case of engine death
946 self._outstanding_dict[ident].add(msg_id)
951 self._outstanding_dict[ident].add(msg_id)
947 self.history.append(msg_id)
952 self.history.append(msg_id)
948 self.metadata[msg_id]['submitted'] = datetime.now()
953 self.metadata[msg_id]['submitted'] = datetime.now()
949
954
950 return msg
955 return msg
951
956
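
A hedged sketch of how this method is reached in practice: user code goes through a View's `apply` rather than calling `send_apply_message` directly. A running cluster (e.g. started with `ipcluster start`) is assumed:

    from IPython.parallel import Client

    rc = Client()                 # connect to a running cluster (assumption)
    view = rc[:]                  # a DirectView over all engines

    # View.apply builds an apply_request via send_apply_message; f must be
    # callable, args a list/tuple, kwargs a dict, or a TypeError is raised.
    ar = view.apply(lambda x: x * 2, 21)
    print ar.get()                # one result per engine, e.g. [42, 42]
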
952 #--------------------------------------------------------------------------
957 #--------------------------------------------------------------------------
953 # construct a View object
958 # construct a View object
954 #--------------------------------------------------------------------------
959 #--------------------------------------------------------------------------
955
960
956 def load_balanced_view(self, targets=None):
961 def load_balanced_view(self, targets=None):
957 """construct a DirectView object.
962 """construct a DirectView object.
958
963
959 If no arguments are specified, create a LoadBalancedView
964 If no arguments are specified, create a LoadBalancedView
960 using all engines.
965 using all engines.
961
966
962 Parameters
967 Parameters
963 ----------
968 ----------
964
969
965 targets: list,slice,int,etc. [default: use all engines]
970 targets: list,slice,int,etc. [default: use all engines]
966 The subset of engines across which to load-balance
971 The subset of engines across which to load-balance
967 """
972 """
968 if targets is not None:
973 if targets is not None:
969 targets = self._build_targets(targets)[1]
974 targets = self._build_targets(targets)[1]
970 return LoadBalancedView(client=self, socket=self._task_socket, targets=targets)
975 return LoadBalancedView(client=self, socket=self._task_socket, targets=targets)
971
976
972 def direct_view(self, targets='all'):
977 def direct_view(self, targets='all'):
973 """construct a DirectView object.
978 """construct a DirectView object.
974
979
975 If no targets are specified, create a DirectView
980 If no targets are specified, create a DirectView
976 using all engines.
981 using all engines.
977
982
978 Parameters
983 Parameters
979 ----------
984 ----------
980
985
981 targets: list,slice,int,etc. [default: use all engines]
986 targets: list,slice,int,etc. [default: use all engines]
982 The engines to use for the View
987 The engines to use for the View
983 """
988 """
984 single = isinstance(targets, int)
989 single = isinstance(targets, int)
985 targets = self._build_targets(targets)[1]
990 targets = self._build_targets(targets)[1]
986 if single:
991 if single:
987 targets = targets[0]
992 targets = targets[0]
988 return DirectView(client=self, socket=self._mux_socket, targets=targets)
993 return DirectView(client=self, socket=self._mux_socket, targets=targets)
989
994
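
A short sketch of constructing each kind of view (again assuming a connected Client `rc` against a running cluster):

    from IPython.parallel import Client

    rc = Client()                      # assumes a running cluster
    lview = rc.load_balanced_view()    # tasks go to whichever engine is free
    dview = rc.direct_view('all')      # explicit multiplexing over every engine
    e0 = rc.direct_view(0)             # an int target yields a single-engine view
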
990 #--------------------------------------------------------------------------
995 #--------------------------------------------------------------------------
991 # Query methods
996 # Query methods
992 #--------------------------------------------------------------------------
997 #--------------------------------------------------------------------------
993
998
994 @spin_first
999 @spin_first
995 def get_result(self, indices_or_msg_ids=None, block=None):
1000 def get_result(self, indices_or_msg_ids=None, block=None):
996 """Retrieve a result by msg_id or history index, wrapped in an AsyncResult object.
1001 """Retrieve a result by msg_id or history index, wrapped in an AsyncResult object.
997
1002
998 If the client already has the results, no request to the Hub will be made.
1003 If the client already has the results, no request to the Hub will be made.
999
1004
1000 This is a convenient way to construct AsyncResult objects, which are wrappers
1005 This is a convenient way to construct AsyncResult objects, which are wrappers
1001 that include metadata about execution, and allow for awaiting results that
1006 that include metadata about execution, and allow for awaiting results that
1002 were not submitted by this Client.
1007 were not submitted by this Client.
1003
1008
1004 It can also be a convenient way to retrieve the metadata associated with
1009 It can also be a convenient way to retrieve the metadata associated with
1005 blocking execution, since it always retrieves the metadata as well.
1010 blocking execution, since it always retrieves the metadata as well.
1006
1011
1007 Examples
1012 Examples
1008 --------
1013 --------
1009 ::
1014 ::
1010
1015
1011 In [10]: ar = client.get_result(msg_id)
1016 In [10]: ar = client.get_result(msg_id)
1012
1017
1013 Parameters
1018 Parameters
1014 ----------
1019 ----------
1015
1020
1016 indices_or_msg_ids : integer history index, str msg_id, or list of either
1021 indices_or_msg_ids : integer history index, str msg_id, or list of either
1017 The indices or msg_ids of the results to be retrieved
1022 The indices or msg_ids of the results to be retrieved
1018
1023
1019 block : bool
1024 block : bool
1020 Whether to wait for the result to be done
1025 Whether to wait for the result to be done
1021
1026
1022 Returns
1027 Returns
1023 -------
1028 -------
1024
1029
1025 AsyncResult
1030 AsyncResult
1026 A single AsyncResult object will always be returned.
1031 A single AsyncResult object will always be returned.
1027
1032
1028 AsyncHubResult
1033 AsyncHubResult
1029 A subclass of AsyncResult that retrieves results from the Hub
1034 A subclass of AsyncResult that retrieves results from the Hub
1030
1035
1031 """
1036 """
1032 block = self.block if block is None else block
1037 block = self.block if block is None else block
1033 if indices_or_msg_ids is None:
1038 if indices_or_msg_ids is None:
1034 indices_or_msg_ids = -1
1039 indices_or_msg_ids = -1
1035
1040
1036 if not isinstance(indices_or_msg_ids, (list,tuple)):
1041 if not isinstance(indices_or_msg_ids, (list,tuple)):
1037 indices_or_msg_ids = [indices_or_msg_ids]
1042 indices_or_msg_ids = [indices_or_msg_ids]
1038
1043
1039 theids = []
1044 theids = []
1040 for id in indices_or_msg_ids:
1045 for id in indices_or_msg_ids:
1041 if isinstance(id, int):
1046 if isinstance(id, int):
1042 id = self.history[id]
1047 id = self.history[id]
1043 if not isinstance(id, str):
1048 if not isinstance(id, str):
1044 raise TypeError("indices must be str or int, not %r"%id)
1049 raise TypeError("indices must be str or int, not %r"%id)
1045 theids.append(id)
1050 theids.append(id)
1046
1051
1047 local_ids = filter(lambda msg_id: msg_id in self.history or msg_id in self.results, theids)
1052 local_ids = filter(lambda msg_id: msg_id in self.history or msg_id in self.results, theids)
1048 remote_ids = filter(lambda msg_id: msg_id not in local_ids, theids)
1053 remote_ids = filter(lambda msg_id: msg_id not in local_ids, theids)
1049
1054
1050 if remote_ids:
1055 if remote_ids:
1051 ar = AsyncHubResult(self, msg_ids=theids)
1056 ar = AsyncHubResult(self, msg_ids=theids)
1052 else:
1057 else:
1053 ar = AsyncResult(self, msg_ids=theids)
1058 ar = AsyncResult(self, msg_ids=theids)
1054
1059
1055 if block:
1060 if block:
1056 ar.wait()
1061 ar.wait()
1057
1062
1058 return ar
1063 return ar
1059
1064
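
A minimal usage sketch, assuming a connected Client `rc` with at least one prior task:

    from IPython.parallel import Client

    rc = Client()                    # assumes a running cluster
    rc[0].apply(sum, [1, 2, 3])      # submit something to engine 0
    ar = rc.get_result(-1)           # wrap the most recent msg_id
    print ar.get()                   # -> 6
    print ar.metadata                # timing and I/O captured with the result
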
1060 @spin_first
1065 @spin_first
1061 def resubmit(self, indices_or_msg_ids=None, subheader=None, block=None):
1066 def resubmit(self, indices_or_msg_ids=None, subheader=None, block=None):
1062 """Resubmit one or more tasks.
1067 """Resubmit one or more tasks.
1063
1068
1064 In-flight tasks may not be resubmitted.
1069 In-flight tasks may not be resubmitted.
1065
1070
1066 Parameters
1071 Parameters
1067 ----------
1072 ----------
1068
1073
1069 indices_or_msg_ids : integer history index, str msg_id, or list of either
1074 indices_or_msg_ids : integer history index, str msg_id, or list of either
1070 The indices or msg_ids of the tasks to be resubmitted
1075 The indices or msg_ids of the tasks to be resubmitted
1071
1076
1072 block : bool
1077 block : bool
1073 Whether to wait for the result to be done
1078 Whether to wait for the result to be done
1074
1079
1075 Returns
1080 Returns
1076 -------
1081 -------
1077
1082
1078 AsyncHubResult
1083 AsyncHubResult
1079 A subclass of AsyncResult that retrieves results from the Hub
1084 A subclass of AsyncResult that retrieves results from the Hub
1080
1085
1081 """
1086 """
1082 block = self.block if block is None else block
1087 block = self.block if block is None else block
1083 if indices_or_msg_ids is None:
1088 if indices_or_msg_ids is None:
1084 indices_or_msg_ids = -1
1089 indices_or_msg_ids = -1
1085
1090
1086 if not isinstance(indices_or_msg_ids, (list,tuple)):
1091 if not isinstance(indices_or_msg_ids, (list,tuple)):
1087 indices_or_msg_ids = [indices_or_msg_ids]
1092 indices_or_msg_ids = [indices_or_msg_ids]
1088
1093
1089 theids = []
1094 theids = []
1090 for id in indices_or_msg_ids:
1095 for id in indices_or_msg_ids:
1091 if isinstance(id, int):
1096 if isinstance(id, int):
1092 id = self.history[id]
1097 id = self.history[id]
1093 if not isinstance(id, str):
1098 if not isinstance(id, str):
1094 raise TypeError("indices must be str or int, not %r"%id)
1099 raise TypeError("indices must be str or int, not %r"%id)
1095 theids.append(id)
1100 theids.append(id)
1096
1101
1097 for msg_id in theids:
1102 for msg_id in theids:
1098 self.outstanding.discard(msg_id)
1103 self.outstanding.discard(msg_id)
1099 if msg_id in self.history:
1104 if msg_id in self.history:
1100 self.history.remove(msg_id)
1105 self.history.remove(msg_id)
1101 self.results.pop(msg_id, None)
1106 self.results.pop(msg_id, None)
1102 self.metadata.pop(msg_id, None)
1107 self.metadata.pop(msg_id, None)
1103 content = dict(msg_ids = theids)
1108 content = dict(msg_ids = theids)
1104
1109
1105 self.session.send(self._query_socket, 'resubmit_request', content)
1110 self.session.send(self._query_socket, 'resubmit_request', content)
1106
1111
1107 zmq.select([self._query_socket], [], [])
1112 zmq.select([self._query_socket], [], [])
1108 idents,msg = self.session.recv(self._query_socket, zmq.NOBLOCK)
1113 idents,msg = self.session.recv(self._query_socket, zmq.NOBLOCK)
1109 if self.debug:
1114 if self.debug:
1110 pprint(msg)
1115 pprint(msg)
1111 content = msg['content']
1116 content = msg['content']
1112 if content['status'] != 'ok':
1117 if content['status'] != 'ok':
1113 raise self._unwrap_exception(content)
1118 raise self._unwrap_exception(content)
1114
1119
1115 ar = AsyncHubResult(self, msg_ids=theids)
1120 ar = AsyncHubResult(self, msg_ids=theids)
1116
1121
1117 if block:
1122 if block:
1118 ar.wait()
1123 ar.wait()
1119
1124
1120 return ar
1125 return ar
1121
1126
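
A sketch of resubmission (the task must have finished, since in-flight tasks cannot be resubmitted; assumes a running cluster):

    from IPython.parallel import Client

    rc = Client()
    ar0 = rc[0].apply(sum, [1, 2, 3])
    ar0.get()                          # wait for completion first
    ar = rc.resubmit(rc.history[-1])   # re-run it via the Hub
    print ar.get()                     # -> 6 again
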
1122 @spin_first
1127 @spin_first
1123 def result_status(self, msg_ids, status_only=True):
1128 def result_status(self, msg_ids, status_only=True):
1124 """Check on the status of the result(s) of the apply request with `msg_ids`.
1129 """Check on the status of the result(s) of the apply request with `msg_ids`.
1125
1130
1126 If status_only is False, then the actual results will be retrieved, else
1131 If status_only is False, then the actual results will be retrieved, else
1127 only the status of the results will be checked.
1132 only the status of the results will be checked.
1128
1133
1129 Parameters
1134 Parameters
1130 ----------
1135 ----------
1131
1136
1132 msg_ids : list of msg_ids
1137 msg_ids : list of msg_ids
1133 if int:
1138 if int:
1134 Passed as index to self.history for convenience.
1139 Passed as index to self.history for convenience.
1135 status_only : bool (default: True)
1140 status_only : bool (default: True)
1136 if False:
1141 if False:
1137 Retrieve the actual results of completed tasks.
1142 Retrieve the actual results of completed tasks.
1138
1143
1139 Returns
1144 Returns
1140 -------
1145 -------
1141
1146
1142 results : dict
1147 results : dict
1143 There will always be the keys 'pending' and 'completed', which will
1148 There will always be the keys 'pending' and 'completed', which will
1144 be lists of msg_ids that are incomplete or complete. If `status_only`
1149 be lists of msg_ids that are incomplete or complete. If `status_only`
1145 is False, then completed results will be keyed by their `msg_id`.
1150 is False, then completed results will be keyed by their `msg_id`.
1146 """
1151 """
1147 if not isinstance(msg_ids, (list,tuple)):
1152 if not isinstance(msg_ids, (list,tuple)):
1148 msg_ids = [msg_ids]
1153 msg_ids = [msg_ids]
1149
1154
1150 theids = []
1155 theids = []
1151 for msg_id in msg_ids:
1156 for msg_id in msg_ids:
1152 if isinstance(msg_id, int):
1157 if isinstance(msg_id, int):
1153 msg_id = self.history[msg_id]
1158 msg_id = self.history[msg_id]
1154 if not isinstance(msg_id, basestring):
1159 if not isinstance(msg_id, basestring):
1155 raise TypeError("msg_ids must be str, not %r"%msg_id)
1160 raise TypeError("msg_ids must be str, not %r"%msg_id)
1156 theids.append(msg_id)
1161 theids.append(msg_id)
1157
1162
1158 completed = []
1163 completed = []
1159 local_results = {}
1164 local_results = {}
1160
1165
1161 # comment this block out to temporarily disable local shortcut:
1166 # comment this block out to temporarily disable local shortcut:
1162 for msg_id in theids:
1167 for msg_id in theids:
1163 if msg_id in self.results:
1168 if msg_id in self.results:
1164 completed.append(msg_id)
1169 completed.append(msg_id)
1165 local_results[msg_id] = self.results[msg_id]
1170 local_results[msg_id] = self.results[msg_id]
1166 theids.remove(msg_id)
1171 theids.remove(msg_id)
1167
1172
1168 if theids: # some not locally cached
1173 if theids: # some not locally cached
1169 content = dict(msg_ids=theids, status_only=status_only)
1174 content = dict(msg_ids=theids, status_only=status_only)
1170 msg = self.session.send(self._query_socket, "result_request", content=content)
1175 msg = self.session.send(self._query_socket, "result_request", content=content)
1171 zmq.select([self._query_socket], [], [])
1176 zmq.select([self._query_socket], [], [])
1172 idents,msg = self.session.recv(self._query_socket, zmq.NOBLOCK)
1177 idents,msg = self.session.recv(self._query_socket, zmq.NOBLOCK)
1173 if self.debug:
1178 if self.debug:
1174 pprint(msg)
1179 pprint(msg)
1175 content = msg['content']
1180 content = msg['content']
1176 if content['status'] != 'ok':
1181 if content['status'] != 'ok':
1177 raise self._unwrap_exception(content)
1182 raise self._unwrap_exception(content)
1178 buffers = msg['buffers']
1183 buffers = msg['buffers']
1179 else:
1184 else:
1180 content = dict(completed=[],pending=[])
1185 content = dict(completed=[],pending=[])
1181
1186
1182 content['completed'].extend(completed)
1187 content['completed'].extend(completed)
1183
1188
1184 if status_only:
1189 if status_only:
1185 return content
1190 return content
1186
1191
1187 failures = []
1192 failures = []
1188 # load cached results into result:
1193 # load cached results into result:
1189 content.update(local_results)
1194 content.update(local_results)
1190
1195
1191 # update cache with results:
1196 # update cache with results:
1192 for msg_id in sorted(theids):
1197 for msg_id in sorted(theids):
1193 if msg_id in content['completed']:
1198 if msg_id in content['completed']:
1194 rec = content[msg_id]
1199 rec = content[msg_id]
1195 parent = rec['header']
1200 parent = rec['header']
1196 header = rec['result_header']
1201 header = rec['result_header']
1197 rcontent = rec['result_content']
1202 rcontent = rec['result_content']
1198 iodict = rec['io']
1203 iodict = rec['io']
1199 if isinstance(rcontent, str):
1204 if isinstance(rcontent, str):
1200 rcontent = self.session.unpack(rcontent)
1205 rcontent = self.session.unpack(rcontent)
1201
1206
1202 md = self.metadata[msg_id]
1207 md = self.metadata[msg_id]
1203 md.update(self._extract_metadata(header, parent, rcontent))
1208 md.update(self._extract_metadata(header, parent, rcontent))
1204 md.update(iodict)
1209 md.update(iodict)
1205
1210
1206 if rcontent['status'] == 'ok':
1211 if rcontent['status'] == 'ok':
1207 res,buffers = util.unserialize_object(buffers)
1212 res,buffers = util.unserialize_object(buffers)
1208 else:
1213 else:
1209 print rcontent
1214 print rcontent
1210 res = self._unwrap_exception(rcontent)
1215 res = self._unwrap_exception(rcontent)
1211 failures.append(res)
1216 failures.append(res)
1212
1217
1213 self.results[msg_id] = res
1218 self.results[msg_id] = res
1214 content[msg_id] = res
1219 content[msg_id] = res
1215
1220
1216 if len(theids) == 1 and failures:
1221 if len(theids) == 1 and failures:
1217 raise failures[0]
1222 raise failures[0]
1218
1223
1219 error.collect_exceptions(failures, "result_status")
1224 error.collect_exceptions(failures, "result_status")
1220 return content
1225 return content
1221
1226
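
For example, a quick status check on recent tasks (assumes a connected Client `rc`):

    from IPython.parallel import Client

    rc = Client()
    rc[:].apply(sum, [1, 2, 3])                 # enqueue something first
    status = rc.result_status(rc.history[-1:])
    print status['pending']                     # msg_ids still running
    print status['completed']                   # msg_ids with results available
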
1222 @spin_first
1227 @spin_first
1223 def queue_status(self, targets='all', verbose=False):
1228 def queue_status(self, targets='all', verbose=False):
1224 """Fetch the status of engine queues.
1229 """Fetch the status of engine queues.
1225
1230
1226 Parameters
1231 Parameters
1227 ----------
1232 ----------
1228
1233
1229 targets : int/str/list of ints/strs
1234 targets : int/str/list of ints/strs
1230 the engines whose states are to be queried.
1235 the engines whose states are to be queried.
1231 default : all
1236 default : all
1232 verbose : bool
1237 verbose : bool
1233 Whether to return lengths only, or lists of ids for each element
1238 Whether to return lengths only, or lists of ids for each element
1234 """
1239 """
1235 engine_ids = self._build_targets(targets)[1]
1240 engine_ids = self._build_targets(targets)[1]
1236 content = dict(targets=engine_ids, verbose=verbose)
1241 content = dict(targets=engine_ids, verbose=verbose)
1237 self.session.send(self._query_socket, "queue_request", content=content)
1242 self.session.send(self._query_socket, "queue_request", content=content)
1238 idents,msg = self.session.recv(self._query_socket, 0)
1243 idents,msg = self.session.recv(self._query_socket, 0)
1239 if self.debug:
1244 if self.debug:
1240 pprint(msg)
1245 pprint(msg)
1241 content = msg['content']
1246 content = msg['content']
1242 status = content.pop('status')
1247 status = content.pop('status')
1243 if status != 'ok':
1248 if status != 'ok':
1244 raise self._unwrap_exception(content)
1249 raise self._unwrap_exception(content)
1245 content = util.rekey(content)
1250 content = util.rekey(content)
1246 if isinstance(targets, int):
1251 if isinstance(targets, int):
1247 return content[targets]
1252 return content[targets]
1248 else:
1253 else:
1249 return content
1254 return content
1250
1255
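
A usage sketch; the per-engine keys shown in the comment ('queue', 'completed', 'tasks') are what the Hub is assumed to report:

    from IPython.parallel import Client

    rc = Client()                # assumes a running cluster
    print rc.queue_status()      # e.g. {0: {'queue': 0, 'completed': 4, 'tasks': 0}, ...}
    print rc.queue_status(0)     # an int target returns just that engine's dict
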
1251 @spin_first
1256 @spin_first
1252 def purge_results(self, jobs=[], targets=[]):
1257 def purge_results(self, jobs=[], targets=[]):
1253 """Tell the Hub to forget results.
1258 """Tell the Hub to forget results.
1254
1259
1255 Individual results can be purged by msg_id, or the entire
1260 Individual results can be purged by msg_id, or the entire
1256 history of specific targets can be purged.
1261 history of specific targets can be purged.
1257
1262
1258 Parameters
1263 Parameters
1259 ----------
1264 ----------
1260
1265
1261 jobs : str or list of str or AsyncResult objects
1266 jobs : str or list of str or AsyncResult objects
1262 the msg_ids whose results should be forgotten.
1267 the msg_ids whose results should be forgotten.
1263 targets : int/str/list of ints/strs
1268 targets : int/str/list of ints/strs
1264 The targets, by uuid or int_id, whose entire history is to be purged.
1269 The targets, by uuid or int_id, whose entire history is to be purged.
1265 Use `targets='all'` to scrub everything from the Hub's memory.
1270 Use `targets='all'` to scrub everything from the Hub's memory.
1266
1271
1267 default : None
1272 default : None
1268 """
1273 """
1269 if not targets and not jobs:
1274 if not targets and not jobs:
1270 raise ValueError("Must specify at least one of `targets` and `jobs`")
1275 raise ValueError("Must specify at least one of `targets` and `jobs`")
1271 if targets:
1276 if targets:
1272 targets = self._build_targets(targets)[1]
1277 targets = self._build_targets(targets)[1]
1273
1278
1274 # construct msg_ids from jobs
1279 # construct msg_ids from jobs
1275 msg_ids = []
1280 msg_ids = []
1276 if isinstance(jobs, (basestring,AsyncResult)):
1281 if isinstance(jobs, (basestring,AsyncResult)):
1277 jobs = [jobs]
1282 jobs = [jobs]
1278 bad_ids = filter(lambda obj: not isinstance(obj, (basestring, AsyncResult)), jobs)
1283 bad_ids = filter(lambda obj: not isinstance(obj, (basestring, AsyncResult)), jobs)
1279 if bad_ids:
1284 if bad_ids:
1280 raise TypeError("Invalid msg_id type %r, expected str or AsyncResult"%bad_ids[0])
1285 raise TypeError("Invalid msg_id type %r, expected str or AsyncResult"%bad_ids[0])
1281 for j in jobs:
1286 for j in jobs:
1282 if isinstance(j, AsyncResult):
1287 if isinstance(j, AsyncResult):
1283 msg_ids.extend(j.msg_ids)
1288 msg_ids.extend(j.msg_ids)
1284 else:
1289 else:
1285 msg_ids.append(j)
1290 msg_ids.append(j)
1286
1291
1287 content = dict(targets=targets, msg_ids=msg_ids)
1292 content = dict(targets=targets, msg_ids=msg_ids)
1288 self.session.send(self._query_socket, "purge_request", content=content)
1293 self.session.send(self._query_socket, "purge_request", content=content)
1289 idents, msg = self.session.recv(self._query_socket, 0)
1294 idents, msg = self.session.recv(self._query_socket, 0)
1290 if self.debug:
1295 if self.debug:
1291 pprint(msg)
1296 pprint(msg)
1292 content = msg['content']
1297 content = msg['content']
1293 if content['status'] != 'ok':
1298 if content['status'] != 'ok':
1294 raise self._unwrap_exception(content)
1299 raise self._unwrap_exception(content)
1295
1300
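
A sketch of the two purge modes described above (assumes a connected Client `rc`):

    from IPython.parallel import Client

    rc = Client()
    ar = rc[0].apply(sum, [1, 2, 3])
    ar.get()
    rc.purge_results(jobs=ar)          # forget this one task's records
    rc.purge_results(targets='all')    # or scrub every engine's history on the Hub
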
1296 @spin_first
1301 @spin_first
1297 def hub_history(self):
1302 def hub_history(self):
1298 """Get the Hub's history
1303 """Get the Hub's history
1299
1304
1300 Just like the Client, the Hub has a history, which is a list of msg_ids.
1305 Just like the Client, the Hub has a history, which is a list of msg_ids.
1301 This will contain the history of all clients, and, depending on configuration,
1306 This will contain the history of all clients, and, depending on configuration,
1302 may contain history across multiple cluster sessions.
1307 may contain history across multiple cluster sessions.
1303
1308
1304 Any msg_id returned here is a valid argument to `get_result`.
1309 Any msg_id returned here is a valid argument to `get_result`.
1305
1310
1306 Returns
1311 Returns
1307 -------
1312 -------
1308
1313
1309 msg_ids : list of strs
1314 msg_ids : list of strs
1310 list of all msg_ids, ordered by task submission time.
1315 list of all msg_ids, ordered by task submission time.
1311 """
1316 """
1312
1317
1313 self.session.send(self._query_socket, "history_request", content={})
1318 self.session.send(self._query_socket, "history_request", content={})
1314 idents, msg = self.session.recv(self._query_socket, 0)
1319 idents, msg = self.session.recv(self._query_socket, 0)
1315
1320
1316 if self.debug:
1321 if self.debug:
1317 pprint(msg)
1322 pprint(msg)
1318 content = msg['content']
1323 content = msg['content']
1319 if content['status'] != 'ok':
1324 if content['status'] != 'ok':
1320 raise self._unwrap_exception(content)
1325 raise self._unwrap_exception(content)
1321 else:
1326 else:
1322 return content['history']
1327 return content['history']
1323
1328
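
For instance (assuming a running cluster), any id from the Hub's history can be handed straight to `get_result`:

    from IPython.parallel import Client

    rc = Client()
    all_ids = rc.hub_history()           # includes other clients' tasks too
    if all_ids:
        ar = rc.get_result(all_ids[-1])  # valid even if another client submitted it
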
1324 @spin_first
1329 @spin_first
1325 def db_query(self, query, keys=None):
1330 def db_query(self, query, keys=None):
1326 """Query the Hub's TaskRecord database
1331 """Query the Hub's TaskRecord database
1327
1332
1328 This will return a list of task record dicts that match `query`
1333 This will return a list of task record dicts that match `query`
1329
1334
1330 Parameters
1335 Parameters
1331 ----------
1336 ----------
1332
1337
1333 query : mongodb query dict
1338 query : mongodb query dict
1334 The search dict. See mongodb query docs for details.
1339 The search dict. See mongodb query docs for details.
1335 keys : list of strs [optional]
1340 keys : list of strs [optional]
1336 The subset of keys to be returned. The default is to fetch everything but buffers.
1341 The subset of keys to be returned. The default is to fetch everything but buffers.
1337 'msg_id' will *always* be included.
1342 'msg_id' will *always* be included.
1338 """
1343 """
1339 if isinstance(keys, basestring):
1344 if isinstance(keys, basestring):
1340 keys = [keys]
1345 keys = [keys]
1341 content = dict(query=query, keys=keys)
1346 content = dict(query=query, keys=keys)
1342 self.session.send(self._query_socket, "db_request", content=content)
1347 self.session.send(self._query_socket, "db_request", content=content)
1343 idents, msg = self.session.recv(self._query_socket, 0)
1348 idents, msg = self.session.recv(self._query_socket, 0)
1344 if self.debug:
1349 if self.debug:
1345 pprint(msg)
1350 pprint(msg)
1346 content = msg['content']
1351 content = msg['content']
1347 if content['status'] != 'ok':
1352 if content['status'] != 'ok':
1348 raise self._unwrap_exception(content)
1353 raise self._unwrap_exception(content)
1349
1354
1350 records = content['records']
1355 records = content['records']
1351
1356
1352 buffer_lens = content['buffer_lens']
1357 buffer_lens = content['buffer_lens']
1353 result_buffer_lens = content['result_buffer_lens']
1358 result_buffer_lens = content['result_buffer_lens']
1354 buffers = msg['buffers']
1359 buffers = msg['buffers']
1355 has_bufs = buffer_lens is not None
1360 has_bufs = buffer_lens is not None
1356 has_rbufs = result_buffer_lens is not None
1361 has_rbufs = result_buffer_lens is not None
1357 for i,rec in enumerate(records):
1362 for i,rec in enumerate(records):
1358 # relink buffers
1363 # relink buffers
1359 if has_bufs:
1364 if has_bufs:
1360 blen = buffer_lens[i]
1365 blen = buffer_lens[i]
1361 rec['buffers'], buffers = buffers[:blen],buffers[blen:]
1366 rec['buffers'], buffers = buffers[:blen],buffers[blen:]
1362 if has_rbufs:
1367 if has_rbufs:
1363 blen = result_buffer_lens[i]
1368 blen = result_buffer_lens[i]
1364 rec['result_buffers'], buffers = buffers[:blen],buffers[blen:]
1369 rec['result_buffers'], buffers = buffers[:blen],buffers[blen:]
1365
1370
1366 return records
1371 return records
1367
1372
1368 __all__ = [ 'Client' ]
1373 __all__ = [ 'Client' ]
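
A sketch of `db_query` tying the query methods together; 'started' and 'completed' are assumed TaskRecord keys, and the query syntax follows mongodb conventions as the docstring says:

    from IPython.parallel import Client

    rc = Client()                       # assumes a running cluster
    recent = rc.hub_history()[-5:]
    records = rc.db_query({'msg_id': {'$in': recent}},
                          keys=['msg_id', 'started', 'completed'])
    for rec in records:
        print rec['msg_id'], rec['completed']
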
@@ -1,158 +1,165 b''
1 # encoding: utf-8
1 # encoding: utf-8
2
2
3 """Classes used in scattering and gathering sequences.
3 """Classes used in scattering and gathering sequences.
4
4
5 Scattering consists of partitioning a sequence and sending the various
5 Scattering consists of partitioning a sequence and sending the various
6 pieces to individual nodes in a cluster.
6 pieces to individual nodes in a cluster.
7
8
9 Authors:
10
11 * Brian Granger
12 * MinRK
13
7 """
14 """
8
15
9 __docformat__ = "restructuredtext en"
16 __docformat__ = "restructuredtext en"
10
17
11 #-------------------------------------------------------------------------------
18 #-------------------------------------------------------------------------------
12 # Copyright (C) 2008 The IPython Development Team
19 # Copyright (C) 2008-2011 The IPython Development Team
13 #
20 #
14 # Distributed under the terms of the BSD License. The full license is in
21 # Distributed under the terms of the BSD License. The full license is in
15 # the file COPYING, distributed as part of this software.
22 # the file COPYING, distributed as part of this software.
16 #-------------------------------------------------------------------------------
23 #-------------------------------------------------------------------------------
17
24
18 #-------------------------------------------------------------------------------
25 #-------------------------------------------------------------------------------
19 # Imports
26 # Imports
20 #-------------------------------------------------------------------------------
27 #-------------------------------------------------------------------------------
21
28
22 import types
29 import types
23
30
24 from IPython.utils.data import flatten as utils_flatten
31 from IPython.utils.data import flatten as utils_flatten
25
32
26 #-------------------------------------------------------------------------------
33 #-------------------------------------------------------------------------------
27 # Figure out which array packages are present and their array types
34 # Figure out which array packages are present and their array types
28 #-------------------------------------------------------------------------------
35 #-------------------------------------------------------------------------------
29
36
30 arrayModules = []
37 arrayModules = []
31 try:
38 try:
32 import Numeric
39 import Numeric
33 except ImportError:
40 except ImportError:
34 pass
41 pass
35 else:
42 else:
36 arrayModules.append({'module':Numeric, 'type':Numeric.arraytype})
43 arrayModules.append({'module':Numeric, 'type':Numeric.arraytype})
37 try:
44 try:
38 import numpy
45 import numpy
39 except ImportError:
46 except ImportError:
40 pass
47 pass
41 else:
48 else:
42 arrayModules.append({'module':numpy, 'type':numpy.ndarray})
49 arrayModules.append({'module':numpy, 'type':numpy.ndarray})
43 try:
50 try:
44 import numarray
51 import numarray
45 except ImportError:
52 except ImportError:
46 pass
53 pass
47 else:
54 else:
48 arrayModules.append({'module':numarray,
55 arrayModules.append({'module':numarray,
49 'type':numarray.numarraycore.NumArray})
56 'type':numarray.numarraycore.NumArray})
50
57
51 class Map:
58 class Map:
52 """A class for partitioning a sequence using a map."""
59 """A class for partitioning a sequence using a map."""
53
60
54 def getPartition(self, seq, p, q):
61 def getPartition(self, seq, p, q):
55 """Returns the pth partition of q partitions of seq."""
62 """Returns the pth partition of q partitions of seq."""
56
63
57 # Test for error conditions here
64 # Test for error conditions here
58 if p<0 or p>=q:
65 if p<0 or p>=q:
59 print "No partition exists."
66 print "No partition exists."
60 return
67 return
61
68
62 remainder = len(seq)%q
69 remainder = len(seq)%q
63 basesize = len(seq)/q
70 basesize = len(seq)/q
64 hi = []
71 hi = []
65 lo = []
72 lo = []
66 for n in range(q):
73 for n in range(q):
67 if n < remainder:
74 if n < remainder:
68 lo.append(n * (basesize + 1))
75 lo.append(n * (basesize + 1))
69 hi.append(lo[-1] + basesize + 1)
76 hi.append(lo[-1] + basesize + 1)
70 else:
77 else:
71 lo.append(n*basesize + remainder)
78 lo.append(n*basesize + remainder)
72 hi.append(lo[-1] + basesize)
79 hi.append(lo[-1] + basesize)
73
80
74
81
75 result = seq[lo[p]:hi[p]]
82 result = seq[lo[p]:hi[p]]
76 return result
83 return result
77
84
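
A quick worked example of the chunked partitioning above (Python 2 integer division is assumed for `len(seq)/q`, as in the module):

    m = Map()
    seq = range(10)
    parts = [m.getPartition(seq, p, 3) for p in range(3)]
    # parts == [[0, 1, 2, 3], [4, 5, 6], [7, 8, 9]]
    # the first (len(seq) % q) partitions each carry one extra element
    m.joinPartitions(parts)    # -> [0, 1, 2, ..., 9]
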
78 def joinPartitions(self, listOfPartitions):
85 def joinPartitions(self, listOfPartitions):
79 return self.concatenate(listOfPartitions)
86 return self.concatenate(listOfPartitions)
80
87
81 def concatenate(self, listOfPartitions):
88 def concatenate(self, listOfPartitions):
82 testObject = listOfPartitions[0]
89 testObject = listOfPartitions[0]
83 # First see if we have a known array type
90 # First see if we have a known array type
84 for m in arrayModules:
91 for m in arrayModules:
85 #print m
92 #print m
86 if isinstance(testObject, m['type']):
93 if isinstance(testObject, m['type']):
87 return m['module'].concatenate(listOfPartitions)
94 return m['module'].concatenate(listOfPartitions)
88 # Next try for Python sequence types
95 # Next try for Python sequence types
89 if isinstance(testObject, (types.ListType, types.TupleType)):
96 if isinstance(testObject, (types.ListType, types.TupleType)):
90 return utils_flatten(listOfPartitions)
97 return utils_flatten(listOfPartitions)
91 # If we have scalars, just return listOfPartitions
98 # If we have scalars, just return listOfPartitions
92 return listOfPartitions
99 return listOfPartitions
93
100
94 class RoundRobinMap(Map):
101 class RoundRobinMap(Map):
95 """Partitions a sequence in a roun robin fashion.
102 """Partitions a sequence in a roun robin fashion.
96
103
97 This currently does not work!
104 This currently does not work!
98 """
105 """
99
106
100 def getPartition(self, seq, p, q):
107 def getPartition(self, seq, p, q):
101 # if not isinstance(seq,(list,tuple)):
108 # if not isinstance(seq,(list,tuple)):
102 # raise NotImplementedError("cannot RR partition type %s"%type(seq))
109 # raise NotImplementedError("cannot RR partition type %s"%type(seq))
103 return seq[p:len(seq):q]
110 return seq[p:len(seq):q]
104 #result = []
111 #result = []
105 #for i in range(p,len(seq),q):
112 #for i in range(p,len(seq),q):
106 # result.append(seq[i])
113 # result.append(seq[i])
107 #return result
114 #return result
108
115
109 def joinPartitions(self, listOfPartitions):
116 def joinPartitions(self, listOfPartitions):
110 testObject = listOfPartitions[0]
117 testObject = listOfPartitions[0]
111 # First see if we have a known array type
118 # First see if we have a known array type
112 for m in arrayModules:
119 for m in arrayModules:
113 #print m
120 #print m
114 if isinstance(testObject, m['type']):
121 if isinstance(testObject, m['type']):
115 return self.flatten_array(m['type'], listOfPartitions)
122 return self.flatten_array(m['type'], listOfPartitions)
116 if isinstance(testObject, (types.ListType, types.TupleType)):
123 if isinstance(testObject, (types.ListType, types.TupleType)):
117 return self.flatten_list(listOfPartitions)
124 return self.flatten_list(listOfPartitions)
118 return listOfPartitions
125 return listOfPartitions
119
126
120 def flatten_array(self, klass, listOfPartitions):
127 def flatten_array(self, klass, listOfPartitions):
121 test = listOfPartitions[0]
128 test = listOfPartitions[0]
122 shape = list(test.shape)
129 shape = list(test.shape)
123 shape[0] = sum([ p.shape[0] for p in listOfPartitions])
130 shape[0] = sum([ p.shape[0] for p in listOfPartitions])
124 A = klass(shape)
131 A = klass(shape)
125 N = shape[0]
132 N = shape[0]
126 q = len(listOfPartitions)
133 q = len(listOfPartitions)
127 for p,part in enumerate(listOfPartitions):
134 for p,part in enumerate(listOfPartitions):
128 A[p:N:q] = part
135 A[p:N:q] = part
129 return A
136 return A
130
137
131 def flatten_list(self, listOfPartitions):
138 def flatten_list(self, listOfPartitions):
132 flat = []
139 flat = []
133 for i in range(len(listOfPartitions[0])):
140 for i in range(len(listOfPartitions[0])):
134 flat.extend([ part[i] for part in listOfPartitions if len(part) > i ])
141 flat.extend([ part[i] for part in listOfPartitions if len(part) > i ])
135 return flat
142 return flat
136 #lengths = [len(x) for x in listOfPartitions]
143 #lengths = [len(x) for x in listOfPartitions]
137 #maxPartitionLength = len(listOfPartitions[0])
144 #maxPartitionLength = len(listOfPartitions[0])
138 #numberOfPartitions = len(listOfPartitions)
145 #numberOfPartitions = len(listOfPartitions)
139 #concat = self.concatenate(listOfPartitions)
146 #concat = self.concatenate(listOfPartitions)
140 #totalLength = len(concat)
147 #totalLength = len(concat)
141 #result = []
148 #result = []
142 #for i in range(maxPartitionLength):
149 #for i in range(maxPartitionLength):
143 # result.append(concat[i:totalLength:maxPartitionLength])
150 # result.append(concat[i:totalLength:maxPartitionLength])
144 # return self.concatenate(listOfPartitions)
151 # return self.concatenate(listOfPartitions)
145
152
146 def mappable(obj):
153 def mappable(obj):
147 """return whether an object is mappable or not."""
154 """return whether an object is mappable or not."""
148 if isinstance(obj, (tuple,list)):
155 if isinstance(obj, (tuple,list)):
149 return True
156 return True
150 for m in arrayModules:
157 for m in arrayModules:
151 if isinstance(obj,m['type']):
158 if isinstance(obj,m['type']):
152 return True
159 return True
153 return False
160 return False
154
161
155 dists = {'b':Map,'r':RoundRobinMap}
162 dists = {'b':Map,'r':RoundRobinMap}
156
163
157
164
158
165
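
For contrast, a sketch of the round-robin striping: the slicing itself behaves as shown for plain lists, even though the docstring above flags the class as not fully working (the array paths are presumably the broken part):

    r = RoundRobinMap()
    parts = [r.getPartition(range(10), p, 3) for p in range(3)]
    # parts == [[0, 3, 6, 9], [1, 4, 7], [2, 5, 8]]
    r.joinPartitions(parts)    # re-interleaves: [0, 1, 2, ..., 9]
    # the dists registry maps 'b' -> Map (chunked) and 'r' -> RoundRobinMap
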
@@ -1,200 +1,206 b''
1 """Remote Functions and decorators for Views."""
1 """Remote Functions and decorators for Views.
2
3 Authors:
4
5 * Brian Granger
6 * Min RK
7 """
2 #-----------------------------------------------------------------------------
8 #-----------------------------------------------------------------------------
3 # Copyright (C) 2010 The IPython Development Team
9 # Copyright (C) 2010-2011 The IPython Development Team
4 #
10 #
5 # Distributed under the terms of the BSD License. The full license is in
11 # Distributed under the terms of the BSD License. The full license is in
6 # the file COPYING, distributed as part of this software.
12 # the file COPYING, distributed as part of this software.
7 #-----------------------------------------------------------------------------
13 #-----------------------------------------------------------------------------
8
14
9 #-----------------------------------------------------------------------------
15 #-----------------------------------------------------------------------------
10 # Imports
16 # Imports
11 #-----------------------------------------------------------------------------
17 #-----------------------------------------------------------------------------
12
18
13 import warnings
19 import warnings
14
20
15 from IPython.testing.skipdoctest import skip_doctest
21 from IPython.testing.skipdoctest import skip_doctest
16
22
17 from . import map as Map
23 from . import map as Map
18 from .asyncresult import AsyncMapResult
24 from .asyncresult import AsyncMapResult
19
25
20 #-----------------------------------------------------------------------------
26 #-----------------------------------------------------------------------------
21 # Decorators
27 # Decorators
22 #-----------------------------------------------------------------------------
28 #-----------------------------------------------------------------------------
23
29
24 @skip_doctest
30 @skip_doctest
25 def remote(view, block=None, **flags):
31 def remote(view, block=None, **flags):
26 """Turn a function into a remote function.
32 """Turn a function into a remote function.
27
33
28 This method can be used for map:
34 This method can be used for map:
29
35
30 In [1]: @remote(view,block=True)
36 In [1]: @remote(view,block=True)
31 ...: def func(a):
37 ...: def func(a):
32 ...: pass
38 ...: pass
33 """
39 """
34
40
35 def remote_function(f):
41 def remote_function(f):
36 return RemoteFunction(view, f, block=block, **flags)
42 return RemoteFunction(view, f, block=block, **flags)
37 return remote_function
43 return remote_function
38
44
39 @skip_doctest
45 @skip_doctest
40 def parallel(view, dist='b', block=None, **flags):
46 def parallel(view, dist='b', block=None, **flags):
41 """Turn a function into a parallel remote function.
47 """Turn a function into a parallel remote function.
42
48
43 This method can be used for map:
49 This method can be used for map:
44
50
45 In [1]: @parallel(view, block=True)
51 In [1]: @parallel(view, block=True)
46 ...: def func(a):
52 ...: def func(a):
47 ...: pass
53 ...: pass
48 """
54 """
49
55
50 def parallel_function(f):
56 def parallel_function(f):
51 return ParallelFunction(view, f, dist=dist, block=block, **flags)
57 return ParallelFunction(view, f, dist=dist, block=block, **flags)
52 return parallel_function
58 return parallel_function
53
59
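
Extending the docstring examples above, a sketch of `.map` through the decorator (the import path `IPython.parallel.client.remotefunction` and a running cluster are assumptions):

    from IPython.parallel import Client
    from IPython.parallel.client.remotefunction import parallel

    rc = Client()
    view = rc[:]

    @parallel(view, block=True)
    def mul(a, b):
        return a * b

    mul.map(range(4), range(4))    # element-wise, like builtin map: [0, 1, 4, 9]
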
54 #--------------------------------------------------------------------------
60 #--------------------------------------------------------------------------
55 # Classes
61 # Classes
56 #--------------------------------------------------------------------------
62 #--------------------------------------------------------------------------
57
63
58 class RemoteFunction(object):
64 class RemoteFunction(object):
59 """Turn an existing function into a remote function.
65 """Turn an existing function into a remote function.
60
66
61 Parameters
67 Parameters
62 ----------
68 ----------
63
69
64 view : View instance
70 view : View instance
65 The view to be used for execution
71 The view to be used for execution
66 f : callable
72 f : callable
67 The function to be wrapped into a remote function
73 The function to be wrapped into a remote function
68 block : bool [default: None]
74 block : bool [default: None]
69 Whether to wait for results or not. The default behavior is
75 Whether to wait for results or not. The default behavior is
70 to use the current `block` attribute of `view`
76 to use the current `block` attribute of `view`
71
77
72 **flags : remaining kwargs are passed to View.temp_flags
78 **flags : remaining kwargs are passed to View.temp_flags
73 """
79 """
74
80
75 view = None # the remote connection
81 view = None # the remote connection
76 func = None # the wrapped function
82 func = None # the wrapped function
77 block = None # whether to block
83 block = None # whether to block
78 flags = None # dict of extra kwargs for temp_flags
84 flags = None # dict of extra kwargs for temp_flags
79
85
80 def __init__(self, view, f, block=None, **flags):
86 def __init__(self, view, f, block=None, **flags):
81 self.view = view
87 self.view = view
82 self.func = f
88 self.func = f
83 self.block=block
89 self.block=block
84 self.flags=flags
90 self.flags=flags
85
91
86 def __call__(self, *args, **kwargs):
92 def __call__(self, *args, **kwargs):
87 block = self.view.block if self.block is None else self.block
93 block = self.view.block if self.block is None else self.block
88 with self.view.temp_flags(block=block, **self.flags):
94 with self.view.temp_flags(block=block, **self.flags):
89 return self.view.apply(self.func, *args, **kwargs)
95 return self.view.apply(self.func, *args, **kwargs)
90
96
91
97
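
A minimal sketch of wrapping a function directly (same assumptions as above: a running cluster, and the class imported from `IPython.parallel.client.remotefunction`):

    from IPython.parallel import Client
    from IPython.parallel.client.remotefunction import RemoteFunction

    rc = Client()

    def getpid():
        import os                # imported on the engine, where the call runs
        return os.getpid()

    rf = RemoteFunction(rc[:], getpid, block=True)
    rf()                         # one pid per engine, e.g. [12345, 12346]
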
92 class ParallelFunction(RemoteFunction):
98 class ParallelFunction(RemoteFunction):
93 """Class for mapping a function to sequences.
99 """Class for mapping a function to sequences.
94
100
95 This will distribute the sequences according to a mapper, and call
101 This will distribute the sequences according to a mapper, and call
96 the function on each sub-sequence. If called via map, then the function
102 the function on each sub-sequence. If called via map, then the function
97 will be called once on each element, rather than on each sub-sequence.
103 will be called once on each element, rather than on each sub-sequence.
98
104
99 Parameters
105 Parameters
100 ----------
106 ----------
101
107
102 view : View instance
108 view : View instance
103 The view to be used for execution
109 The view to be used for execution
104 f : callable
110 f : callable
105 The function to be wrapped into a remote function
111 The function to be wrapped into a remote function
106 dist : str [default: 'b']
112 dist : str [default: 'b']
107 The key for which mapObject to use to distribute sequences
113 The key for which mapObject to use to distribute sequences
108 options are:
114 options are:
109 * 'b' : use contiguous chunks in order
115 * 'b' : use contiguous chunks in order
110 * 'r' : use round-robin striping
116 * 'r' : use round-robin striping
111 block : bool [default: None]
117 block : bool [default: None]
112 Whether to wait for results or not. The default behavior is
118 Whether to wait for results or not. The default behavior is
113 to use the current `block` attribute of `view`
119 to use the current `block` attribute of `view`
114 chunksize : int or None
120 chunksize : int or None
115 The size of chunk to use when breaking up sequences in a load-balanced manner
121 The size of chunk to use when breaking up sequences in a load-balanced manner
116 **flags : remaining kwargs are passed to View.temp_flags
122 **flags : remaining kwargs are passed to View.temp_flags
117 """
123 """
118
124
119 chunksize=None
125 chunksize=None
120 mapObject=None
126 mapObject=None
121
127
122 def __init__(self, view, f, dist='b', block=None, chunksize=None, **flags):
128 def __init__(self, view, f, dist='b', block=None, chunksize=None, **flags):
123 super(ParallelFunction, self).__init__(view, f, block=block, **flags)
129 super(ParallelFunction, self).__init__(view, f, block=block, **flags)
124 self.chunksize = chunksize
130 self.chunksize = chunksize
125
131
126 mapClass = Map.dists[dist]
132 mapClass = Map.dists[dist]
127 self.mapObject = mapClass()
133 self.mapObject = mapClass()
128
134
129 def __call__(self, *sequences):
135 def __call__(self, *sequences):
130 # check that the length of sequences match
136 # check that the length of sequences match
131 len_0 = len(sequences[0])
137 len_0 = len(sequences[0])
132 for s in sequences:
138 for s in sequences:
133 if len(s)!=len_0:
139 if len(s)!=len_0:
134 msg = 'all sequences must have equal length, but %i!=%i'%(len_0,len(s))
140 msg = 'all sequences must have equal length, but %i!=%i'%(len_0,len(s))
135 raise ValueError(msg)
141 raise ValueError(msg)
136 balanced = 'Balanced' in self.view.__class__.__name__
142 balanced = 'Balanced' in self.view.__class__.__name__
137 if balanced:
143 if balanced:
138 if self.chunksize:
144 if self.chunksize:
139 nparts = len_0/self.chunksize + int(len_0%self.chunksize > 0)
145 nparts = len_0/self.chunksize + int(len_0%self.chunksize > 0)
140 else:
146 else:
141 nparts = len_0
147 nparts = len_0
142 targets = [None]*nparts
148 targets = [None]*nparts
143 else:
149 else:
144 if self.chunksize:
150 if self.chunksize:
145 warnings.warn("`chunksize` is ignored unless load balancing", UserWarning)
151 warnings.warn("`chunksize` is ignored unless load balancing", UserWarning)
146 # multiplexed:
152 # multiplexed:
147 targets = self.view.targets
153 targets = self.view.targets
148 nparts = len(targets)
154 nparts = len(targets)
149
155
150 msg_ids = []
156 msg_ids = []
151 # my_f = lambda *a: map(self.func, *a)
157 # my_f = lambda *a: map(self.func, *a)
152 client = self.view.client
158 client = self.view.client
153 for index, t in enumerate(targets):
159 for index, t in enumerate(targets):
154 args = []
160 args = []
155 for seq in sequences:
161 for seq in sequences:
156 part = self.mapObject.getPartition(seq, index, nparts)
162 part = self.mapObject.getPartition(seq, index, nparts)
157 if len(part) == 0:
163 if len(part) == 0:
158 continue
164 continue
159 else:
165 else:
160 args.append(part)
166 args.append(part)
161 if not args:
167 if not args:
162 continue
168 continue
163
169
164 # print (args)
170 # print (args)
165 if hasattr(self, '_map'):
171 if hasattr(self, '_map'):
166 f = map
172 f = map
167 args = [self.func]+args
173 args = [self.func]+args
168 else:
174 else:
169 f=self.func
175 f=self.func
170
176
171 view = self.view if balanced else client[t]
177 view = self.view if balanced else client[t]
172 with view.temp_flags(block=False, **self.flags):
178 with view.temp_flags(block=False, **self.flags):
173 ar = view.apply(f, *args)
179 ar = view.apply(f, *args)
174
180
175 msg_ids.append(ar.msg_ids[0])
181 msg_ids.append(ar.msg_ids[0])
176
182
177 r = AsyncMapResult(self.view.client, msg_ids, self.mapObject, fname=self.func.__name__)
183 r = AsyncMapResult(self.view.client, msg_ids, self.mapObject, fname=self.func.__name__)
178
184
179 if self.block:
185 if self.block:
180 try:
186 try:
181 return r.get()
187 return r.get()
182 except KeyboardInterrupt:
188 except KeyboardInterrupt:
183 return r
189 return r
184 else:
190 else:
185 return r
191 return r
186
192
187 def map(self, *sequences):
193 def map(self, *sequences):
188 """call a function on each element of a sequence remotely.
194 """call a function on each element of a sequence remotely.
189 This should behave very much like the builtin map, but return an AsyncMapResult
195 This should behave very much like the builtin map, but return an AsyncMapResult
190 if self.block is False.
196 if self.block is False.
191 """
197 """
192 # set _map as a flag for use inside self.__call__
198 # set _map as a flag for use inside self.__call__
193 self._map = True
199 self._map = True
194 try:
200 try:
195 ret = self.__call__(*sequences)
201 ret = self.__call__(*sequences)
196 finally:
202 finally:
197 del self._map
203 del self._map
198 return ret
204 return ret
199
205
200 __all__ = ['remote', 'parallel', 'RemoteFunction', 'ParallelFunction']
206 __all__ = ['remote', 'parallel', 'RemoteFunction', 'ParallelFunction']
@@ -1,1041 +1,1046 b''
1 """Views of remote engines."""
1 """Views of remote engines.
2
3 Authors:
4
5 * Min RK
6 """
2 #-----------------------------------------------------------------------------
7 #-----------------------------------------------------------------------------
3 # Copyright (C) 2010 The IPython Development Team
8 # Copyright (C) 2010-2011 The IPython Development Team
4 #
9 #
5 # Distributed under the terms of the BSD License. The full license is in
10 # Distributed under the terms of the BSD License. The full license is in
6 # the file COPYING, distributed as part of this software.
11 # the file COPYING, distributed as part of this software.
7 #-----------------------------------------------------------------------------
12 #-----------------------------------------------------------------------------
8
13
9 #-----------------------------------------------------------------------------
14 #-----------------------------------------------------------------------------
10 # Imports
15 # Imports
11 #-----------------------------------------------------------------------------
16 #-----------------------------------------------------------------------------
12
17
13 import imp
18 import imp
14 import sys
19 import sys
15 import warnings
20 import warnings
16 from contextlib import contextmanager
21 from contextlib import contextmanager
17 from types import ModuleType
22 from types import ModuleType
18
23
19 import zmq
24 import zmq
20
25
21 from IPython.testing.skipdoctest import skip_doctest
26 from IPython.testing.skipdoctest import skip_doctest
22 from IPython.utils.traitlets import HasTraits, Any, Bool, List, Dict, Set, Int, Instance, CFloat, CInt
27 from IPython.utils.traitlets import HasTraits, Any, Bool, List, Dict, Set, Int, Instance, CFloat, CInt
23 from IPython.external.decorator import decorator
28 from IPython.external.decorator import decorator
24
29
25 from IPython.parallel import util
30 from IPython.parallel import util
26 from IPython.parallel.controller.dependency import Dependency, dependent
31 from IPython.parallel.controller.dependency import Dependency, dependent
27
32
28 from . import map as Map
33 from . import map as Map
29 from .asyncresult import AsyncResult, AsyncMapResult
34 from .asyncresult import AsyncResult, AsyncMapResult
30 from .remotefunction import ParallelFunction, parallel, remote
35 from .remotefunction import ParallelFunction, parallel, remote
31
36
32 #-----------------------------------------------------------------------------
37 #-----------------------------------------------------------------------------
33 # Decorators
38 # Decorators
34 #-----------------------------------------------------------------------------
39 #-----------------------------------------------------------------------------
35
40
36 @decorator
41 @decorator
37 def save_ids(f, self, *args, **kwargs):
42 def save_ids(f, self, *args, **kwargs):
38 """Keep our history and outstanding attributes up to date after a method call."""
43 """Keep our history and outstanding attributes up to date after a method call."""
39 n_previous = len(self.client.history)
44 n_previous = len(self.client.history)
40 try:
45 try:
41 ret = f(self, *args, **kwargs)
46 ret = f(self, *args, **kwargs)
42 finally:
47 finally:
43 nmsgs = len(self.client.history) - n_previous
48 nmsgs = len(self.client.history) - n_previous
44 msg_ids = self.client.history[-nmsgs:]
49 msg_ids = self.client.history[-nmsgs:]
45 self.history.extend(msg_ids)
50 self.history.extend(msg_ids)
46 map(self.outstanding.add, msg_ids)
51 map(self.outstanding.add, msg_ids)
47 return ret
52 return ret
48
53
49 @decorator
54 @decorator
50 def sync_results(f, self, *args, **kwargs):
55 def sync_results(f, self, *args, **kwargs):
51 """sync relevant results from self.client to our results attribute."""
56 """sync relevant results from self.client to our results attribute."""
52 ret = f(self, *args, **kwargs)
57 ret = f(self, *args, **kwargs)
53 delta = self.outstanding.difference(self.client.outstanding)
58 delta = self.outstanding.difference(self.client.outstanding)
54 completed = self.outstanding.intersection(delta)
59 completed = self.outstanding.intersection(delta)
55 self.outstanding = self.outstanding.difference(completed)
60 self.outstanding = self.outstanding.difference(completed)
56 for msg_id in completed:
61 for msg_id in completed:
57 self.results[msg_id] = self.client.results[msg_id]
62 self.results[msg_id] = self.client.results[msg_id]
58 return ret
63 return ret
59
64
60 @decorator
65 @decorator
61 def spin_after(f, self, *args, **kwargs):
66 def spin_after(f, self, *args, **kwargs):
62 """call spin after the method."""
67 """call spin after the method."""
63 ret = f(self, *args, **kwargs)
68 ret = f(self, *args, **kwargs)
64 self.spin()
69 self.spin()
65 return ret
70 return ret
66
71
67 #-----------------------------------------------------------------------------
72 #-----------------------------------------------------------------------------
68 # Classes
73 # Classes
69 #-----------------------------------------------------------------------------
74 #-----------------------------------------------------------------------------
70
75
71 @skip_doctest
76 @skip_doctest
72 class View(HasTraits):
77 class View(HasTraits):
73 """Base View class for more convenint apply(f,*args,**kwargs) syntax via attributes.
78 """Base View class for more convenint apply(f,*args,**kwargs) syntax via attributes.
74
79
75 Don't use this class, use subclasses.
80 Don't use this class, use subclasses.
76
81
77 Methods
82 Methods
78 -------
83 -------
79
84
80 spin
85 spin
81 flushes incoming results and registration state changes
86 flushes incoming results and registration state changes
82 control methods spin, and requesting `ids` also ensures up to date
87 control methods spin, and requesting `ids` also ensures up to date
83
88
84 wait
89 wait
85 wait on one or more msg_ids
90 wait on one or more msg_ids
86
91
87 execution methods
92 execution methods
88 apply
93 apply
89 legacy: execute, run
94 legacy: execute, run
90
95
91 data movement
96 data movement
92 push, pull, scatter, gather
97 push, pull, scatter, gather
93
98
94 query methods
99 query methods
95 get_result, queue_status, purge_results, result_status
100 get_result, queue_status, purge_results, result_status
96
101
97 control methods
102 control methods
98 abort, shutdown
103 abort, shutdown
99
104
100 """
105 """
101 # flags
106 # flags
102 block=Bool(False)
107 block=Bool(False)
103 track=Bool(True)
108 track=Bool(True)
104 targets = Any()
109 targets = Any()
105
110
106 history=List()
111 history=List()
107 outstanding = Set()
112 outstanding = Set()
108 results = Dict()
113 results = Dict()
109 client = Instance('IPython.parallel.Client')
114 client = Instance('IPython.parallel.Client')
110
115
111 _socket = Instance('zmq.Socket')
116 _socket = Instance('zmq.Socket')
112 _flag_names = List(['targets', 'block', 'track'])
117 _flag_names = List(['targets', 'block', 'track'])
113 _targets = Any()
118 _targets = Any()
114 _idents = Any()
119 _idents = Any()
115
120
116 def __init__(self, client=None, socket=None, **flags):
121 def __init__(self, client=None, socket=None, **flags):
117 super(View, self).__init__(client=client, _socket=socket)
122 super(View, self).__init__(client=client, _socket=socket)
118 self.block = client.block
123 self.block = client.block
119
124
120 self.set_flags(**flags)
125 self.set_flags(**flags)
121
126
122 assert not self.__class__ is View, "Don't use base View objects, use subclasses"
127 assert not self.__class__ is View, "Don't use base View objects, use subclasses"
123
128
124
129
125 def __repr__(self):
130 def __repr__(self):
126 strtargets = str(self.targets)
131 strtargets = str(self.targets)
127 if len(strtargets) > 16:
132 if len(strtargets) > 16:
128 strtargets = strtargets[:12]+'...]'
133 strtargets = strtargets[:12]+'...]'
129 return "<%s %s>"%(self.__class__.__name__, strtargets)
134 return "<%s %s>"%(self.__class__.__name__, strtargets)
130
135
131 def set_flags(self, **kwargs):
136 def set_flags(self, **kwargs):
132 """set my attribute flags by keyword.
137 """set my attribute flags by keyword.
133
138
134 Views determine behavior with a few attributes (`block`, `track`, etc.).
139 Views determine behavior with a few attributes (`block`, `track`, etc.).
135 These attributes can be set all at once by name with this method.
140 These attributes can be set all at once by name with this method.
136
141
137 Parameters
142 Parameters
138 ----------
143 ----------
139
144
140 block : bool
145 block : bool
141 whether to wait for results
146 whether to wait for results
142 track : bool
147 track : bool
143 whether to create a MessageTracker to allow the user to
148 whether to create a MessageTracker to allow the user to
144 safely edit arrays and buffers after
149 safely edit arrays and buffers after
145 non-copying sends.
150 non-copying sends.
146 """
151 """
147 for name, value in kwargs.iteritems():
152 for name, value in kwargs.iteritems():
148 if name not in self._flag_names:
153 if name not in self._flag_names:
149 raise KeyError("Invalid name: %r"%name)
154 raise KeyError("Invalid name: %r"%name)
150 else:
155 else:
151 setattr(self, name, value)
156 setattr(self, name, value)
152
157
153 @contextmanager
158 @contextmanager
154 def temp_flags(self, **kwargs):
159 def temp_flags(self, **kwargs):
155 """temporarily set flags, for use in `with` statements.
160 """temporarily set flags, for use in `with` statements.
156
161
157 See set_flags for permanent setting of flags
162 See set_flags for permanent setting of flags
158
163
159 Examples
164 Examples
160 --------
165 --------
161
166
162 >>> view.track=False
167 >>> view.track=False
163 ...
168 ...
164 >>> with view.temp_flags(track=True):
169 >>> with view.temp_flags(track=True):
165 ... ar = view.apply(dostuff, my_big_array)
170 ... ar = view.apply(dostuff, my_big_array)
166 ... ar.tracker.wait() # wait for send to finish
171 ... ar.tracker.wait() # wait for send to finish
167 >>> view.track
172 >>> view.track
168 False
173 False
169
174
170 """
175 """
171 # preflight: save flags, and set temporaries
176 # preflight: save flags, and set temporaries
172 saved_flags = {}
177 saved_flags = {}
173 for f in self._flag_names:
178 for f in self._flag_names:
174 saved_flags[f] = getattr(self, f)
179 saved_flags[f] = getattr(self, f)
175 self.set_flags(**kwargs)
180 self.set_flags(**kwargs)
176 # yield to the with-statement block
181 # yield to the with-statement block
177 try:
182 try:
178 yield
183 yield
179 finally:
184 finally:
180 # postflight: restore saved flags
185 # postflight: restore saved flags
181 self.set_flags(**saved_flags)
186 self.set_flags(**saved_flags)
182
187
183
188
184 #----------------------------------------------------------------
189 #----------------------------------------------------------------
185 # apply
190 # apply
186 #----------------------------------------------------------------
191 #----------------------------------------------------------------
187
192
188 @sync_results
193 @sync_results
189 @save_ids
194 @save_ids
190 def _really_apply(self, f, args, kwargs, block=None, **options):
195 def _really_apply(self, f, args, kwargs, block=None, **options):
191 """wrapper for client.send_apply_message"""
196 """wrapper for client.send_apply_message"""
192 raise NotImplementedError("Implement in subclasses")
197 raise NotImplementedError("Implement in subclasses")
193
198
194 def apply(self, f, *args, **kwargs):
199 def apply(self, f, *args, **kwargs):
195 """calls f(*args, **kwargs) on remote engines, returning the result.
200 """calls f(*args, **kwargs) on remote engines, returning the result.
196
201
197 This method sets all apply flags via this View's attributes.
202 This method sets all apply flags via this View's attributes.
198
203
199 if self.block is False:
204 if self.block is False:
200 returns AsyncResult
205 returns AsyncResult
201 else:
206 else:
202 returns actual result of f(*args, **kwargs)
207 returns actual result of f(*args, **kwargs)
203 """
208 """
204 return self._really_apply(f, args, kwargs)
209 return self._really_apply(f, args, kwargs)
205
210
206 def apply_async(self, f, *args, **kwargs):
211 def apply_async(self, f, *args, **kwargs):
207 """calls f(*args, **kwargs) on remote engines in a nonblocking manner.
212 """calls f(*args, **kwargs) on remote engines in a nonblocking manner.
208
213
209 returns AsyncResult
214 returns AsyncResult
210 """
215 """
211 return self._really_apply(f, args, kwargs, block=False)
216 return self._really_apply(f, args, kwargs, block=False)
212
217
213 @spin_after
218 @spin_after
214 def apply_sync(self, f, *args, **kwargs):
219 def apply_sync(self, f, *args, **kwargs):
215 """calls f(*args, **kwargs) on remote engines in a blocking manner,
220 """calls f(*args, **kwargs) on remote engines in a blocking manner,
216 returning the result.
221 returning the result.
217
222
218 returns: actual result of f(*args, **kwargs)
223 returns: actual result of f(*args, **kwargs)
219 """
224 """
220 return self._really_apply(f, args, kwargs, block=True)
225 return self._really_apply(f, args, kwargs, block=True)
221
226
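A minimal usage sketch of the three apply variants (assuming a running IPython cluster and a Client connection; `double` is an illustrative function, not part of this diff):

    from IPython.parallel import Client

    rc = Client()
    view = rc[:]                          # a DirectView on all engines

    def double(x):
        return 2 * x

    ar = view.apply_async(double, 5)      # nonblocking: returns an AsyncResult
    print ar.get()                        # wait for, then fetch, the result

    print view.apply_sync(double, 5)      # blocking: returns the result itself
    # view.apply(double, 5) consults view.block to pick one of the above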
222 #----------------------------------------------------------------
227 #----------------------------------------------------------------
223 # wrappers for client and control methods
228 # wrappers for client and control methods
224 #----------------------------------------------------------------
229 #----------------------------------------------------------------
225 @sync_results
230 @sync_results
226 def spin(self):
231 def spin(self):
227 """spin the client, and sync"""
232 """spin the client, and sync"""
228 self.client.spin()
233 self.client.spin()
229
234
230 @sync_results
235 @sync_results
231 def wait(self, jobs=None, timeout=-1):
236 def wait(self, jobs=None, timeout=-1):
232 """waits on one or more `jobs`, for up to `timeout` seconds.
237 """waits on one or more `jobs`, for up to `timeout` seconds.
233
238
234 Parameters
239 Parameters
235 ----------
240 ----------
236
241
237 jobs : int, str, or list of ints and/or strs, or one or more AsyncResult objects
242 jobs : int, str, or list of ints and/or strs, or one or more AsyncResult objects
238 ints are indices to self.history
243 ints are indices to self.history
239 strs are msg_ids
244 strs are msg_ids
240 default: wait on all outstanding messages
245 default: wait on all outstanding messages
241 timeout : float
246 timeout : float
242 a time in seconds, after which to give up.
247 a time in seconds, after which to give up.
243 default is -1, which means no timeout
248 default is -1, which means no timeout
244
249
245 Returns
250 Returns
246 -------
251 -------
247
252
248 True : when all msg_ids are done
253 True : when all msg_ids are done
249 False : timeout reached, some msg_ids still outstanding
254 False : timeout reached, some msg_ids still outstanding
250 """
255 """
251 if jobs is None:
256 if jobs is None:
252 jobs = self.history
257 jobs = self.history
253 return self.client.wait(jobs, timeout)
258 return self.client.wait(jobs, timeout)
254
259
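For example, continuing the assumptions of the earlier sketch:

    view.apply_async(double, 1)
    view.apply_async(double, 2)
    done = view.wait(timeout=5)     # jobs=None: wait on everything in view.history
    if not done:
        print "some jobs still outstanding after 5s"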
255 def abort(self, jobs=None, targets=None, block=None):
260 def abort(self, jobs=None, targets=None, block=None):
256 """Abort jobs on my engines.
261 """Abort jobs on my engines.
257
262
258 Parameters
263 Parameters
259 ----------
264 ----------
260
265
261 jobs : None, str, list of strs, optional
266 jobs : None, str, list of strs, optional
262 if None: abort all jobs.
267 if None: abort all jobs.
263 else: abort specific msg_id(s).
268 else: abort specific msg_id(s).
264 """
269 """
265 block = block if block is not None else self.block
270 block = block if block is not None else self.block
266 targets = targets if targets is not None else self.targets
271 targets = targets if targets is not None else self.targets
267 return self.client.abort(jobs=jobs, targets=targets, block=block)
272 return self.client.abort(jobs=jobs, targets=targets, block=block)
268
273
269 def queue_status(self, targets=None, verbose=False):
274 def queue_status(self, targets=None, verbose=False):
270 """Fetch the Queue status of my engines"""
275 """Fetch the Queue status of my engines"""
271 targets = targets if targets is not None else self.targets
276 targets = targets if targets is not None else self.targets
272 return self.client.queue_status(targets=targets, verbose=verbose)
277 return self.client.queue_status(targets=targets, verbose=verbose)
273
278
274 def purge_results(self, jobs=[], targets=[]):
279 def purge_results(self, jobs=[], targets=[]):
275 """Instruct the controller to forget specific results."""
280 """Instruct the controller to forget specific results."""
276 if targets is None or targets == 'all':
281 if targets is None or targets == 'all':
277 targets = self.targets
282 targets = self.targets
278 return self.client.purge_results(jobs=jobs, targets=targets)
283 return self.client.purge_results(jobs=jobs, targets=targets)
279
284
280 def shutdown(self, targets=None, restart=False, hub=False, block=None):
285 def shutdown(self, targets=None, restart=False, hub=False, block=None):
281 """Terminates one or more engine processes, optionally including the hub.
286 """Terminates one or more engine processes, optionally including the hub.
282 """
287 """
283 block = self.block if block is None else block
288 block = self.block if block is None else block
284 if targets is None or targets == 'all':
289 if targets is None or targets == 'all':
285 targets = self.targets
290 targets = self.targets
286 return self.client.shutdown(targets=targets, restart=restart, hub=hub, block=block)
291 return self.client.shutdown(targets=targets, restart=restart, hub=hub, block=block)
287
292
288 @spin_after
293 @spin_after
289 def get_result(self, indices_or_msg_ids=None):
294 def get_result(self, indices_or_msg_ids=None):
290 """return one or more results, specified by history index or msg_id.
295 """return one or more results, specified by history index or msg_id.
291
296
292 See client.get_result for details.
297 See client.get_result for details.
293
298
294 """
299 """
295
300
296 if indices_or_msg_ids is None:
301 if indices_or_msg_ids is None:
297 indices_or_msg_ids = -1
302 indices_or_msg_ids = -1
298 if isinstance(indices_or_msg_ids, int):
303 if isinstance(indices_or_msg_ids, int):
299 indices_or_msg_ids = self.history[indices_or_msg_ids]
304 indices_or_msg_ids = self.history[indices_or_msg_ids]
300 elif isinstance(indices_or_msg_ids, (list,tuple,set)):
305 elif isinstance(indices_or_msg_ids, (list,tuple,set)):
301 indices_or_msg_ids = list(indices_or_msg_ids)
306 indices_or_msg_ids = list(indices_or_msg_ids)
302 for i,index in enumerate(indices_or_msg_ids):
307 for i,index in enumerate(indices_or_msg_ids):
303 if isinstance(index, int):
308 if isinstance(index, int):
304 indices_or_msg_ids[i] = self.history[index]
309 indices_or_msg_ids[i] = self.history[index]
305 return self.client.get_result(indices_or_msg_ids)
310 return self.client.get_result(indices_or_msg_ids)
306
311
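A short sketch of re-fetching results by history index or msg_id (the msg_id string below is a hypothetical placeholder):

    ar = view.get_result(-1)                  # most recent submission
    ar = view.get_result([-2, -1])            # last two, by history index
    ar = view.get_result('some-msg-id')       # or by msg_id string
    print ar.get()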
307 #-------------------------------------------------------------------
312 #-------------------------------------------------------------------
308 # Map
313 # Map
309 #-------------------------------------------------------------------
314 #-------------------------------------------------------------------
310
315
311 def map(self, f, *sequences, **kwargs):
316 def map(self, f, *sequences, **kwargs):
312 """override in subclasses"""
317 """override in subclasses"""
313 raise NotImplementedError
318 raise NotImplementedError
314
319
315 def map_async(self, f, *sequences, **kwargs):
320 def map_async(self, f, *sequences, **kwargs):
316 """Parallel version of builtin `map`, using this view's engines.
321 """Parallel version of builtin `map`, using this view's engines.
317
322
318 This is equivalent to map(..., block=False)
323 This is equivalent to map(..., block=False)
319
324
320 See `self.map` for details.
325 See `self.map` for details.
321 """
326 """
322 if 'block' in kwargs:
327 if 'block' in kwargs:
323 raise TypeError("map_async doesn't take a `block` keyword argument.")
328 raise TypeError("map_async doesn't take a `block` keyword argument.")
324 kwargs['block'] = False
329 kwargs['block'] = False
325 return self.map(f,*sequences,**kwargs)
330 return self.map(f,*sequences,**kwargs)
326
331
327 def map_sync(self, f, *sequences, **kwargs):
332 def map_sync(self, f, *sequences, **kwargs):
328 """Parallel version of builtin `map`, using this view's engines.
333 """Parallel version of builtin `map`, using this view's engines.
329
334
330 This is equivalent to map(..., block=True)
335 This is equivalent to map(..., block=True)
331
336
332 See `self.map` for details.
337 See `self.map` for details.
333 """
338 """
334 if 'block' in kwargs:
339 if 'block' in kwargs:
335 raise TypeError("map_sync doesn't take a `block` keyword argument.")
340 raise TypeError("map_sync doesn't take a `block` keyword argument.")
336 kwargs['block'] = True
341 kwargs['block'] = True
337 return self.map(f,*sequences,**kwargs)
342 return self.map(f,*sequences,**kwargs)
338
343
339 def imap(self, f, *sequences, **kwargs):
344 def imap(self, f, *sequences, **kwargs):
340 """Parallel version of `itertools.imap`.
345 """Parallel version of `itertools.imap`.
341
346
342 See `self.map` for details.
347 See `self.map` for details.
343
348
344 """
349 """
345
350
346 return iter(self.map_async(f,*sequences, **kwargs))
351 return iter(self.map_async(f,*sequences, **kwargs))
347
352
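A sketch of the map family, under the same assumptions as the earlier apply sketch:

    amr = view.map_async(double, range(8))    # AsyncMapResult, returns at once
    for item in amr:                           # iterate as chunks complete
        print item

    doubled = view.map_sync(double, range(8))      # a plain list
    for item in view.imap(double, range(8)):       # iterator interface
        print item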
348 #-------------------------------------------------------------------
353 #-------------------------------------------------------------------
349 # Decorators
354 # Decorators
350 #-------------------------------------------------------------------
355 #-------------------------------------------------------------------
351
356
352 def remote(self, block=None, **flags):
357 def remote(self, block=None, **flags):
353 """Decorator for making a RemoteFunction"""
358 """Decorator for making a RemoteFunction"""
354 block = self.block if block is None else block
359 block = self.block if block is None else block
355 return remote(self, block=block, **flags)
360 return remote(self, block=block, **flags)
356
361
357 def parallel(self, dist='b', block=None, **flags):
362 def parallel(self, dist='b', block=None, **flags):
358 """Decorator for making a ParallelFunction"""
363 """Decorator for making a ParallelFunction"""
359 block = self.block if block is None else block
364 block = self.block if block is None else block
360 return parallel(self, dist=dist, block=block, **flags)
365 return parallel(self, dist=dist, block=block, **flags)
361
366
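A sketch of the decorator forms (function names here are illustrative):

    @view.remote(block=True)
    def hostname():
        import socket
        return socket.gethostname()

    hostname()                     # executes on the view's engines

    @view.parallel(dist='b', block=True)
    def square(x):
        return x * x

    square.map(range(16))          # element-wise, distributed over engines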
362 @skip_doctest
367 @skip_doctest
363 class DirectView(View):
368 class DirectView(View):
364 """Direct Multiplexer View of one or more engines.
369 """Direct Multiplexer View of one or more engines.
365
370
366 These are created via indexed access to a client:
371 These are created via indexed access to a client:
367
372
368 >>> dv_1 = client[1]
373 >>> dv_1 = client[1]
369 >>> dv_all = client[:]
374 >>> dv_all = client[:]
370 >>> dv_even = client[::2]
375 >>> dv_even = client[::2]
371 >>> dv_some = client[1:3]
376 >>> dv_some = client[1:3]
372
377
373 This object provides dictionary access to engine namespaces:
378 This object provides dictionary access to engine namespaces:
374
379
375 # push a=5:
380 # push a=5:
376 >>> dv['a'] = 5
381 >>> dv['a'] = 5
377 # pull 'foo':
382 # pull 'foo':
378 >>> dv['foo']
383 >>> dv['foo']
379
384
380 """
385 """
381
386
382 def __init__(self, client=None, socket=None, targets=None):
387 def __init__(self, client=None, socket=None, targets=None):
383 super(DirectView, self).__init__(client=client, socket=socket, targets=targets)
388 super(DirectView, self).__init__(client=client, socket=socket, targets=targets)
384
389
385 @property
390 @property
386 def importer(self):
391 def importer(self):
387 """sync_imports(local=True) as a property.
392 """sync_imports(local=True) as a property.
388
393
389 See sync_imports for details.
394 See sync_imports for details.
390
395
391 """
396 """
392 return self.sync_imports(True)
397 return self.sync_imports(True)
393
398
394 @contextmanager
399 @contextmanager
395 def sync_imports(self, local=True):
400 def sync_imports(self, local=True):
396 """Context Manager for performing simultaneous local and remote imports.
401 """Context Manager for performing simultaneous local and remote imports.
397
402
398 'import x as y' will *not* work. The 'as y' part will simply be ignored.
403 'import x as y' will *not* work. The 'as y' part will simply be ignored.
399
404
400 >>> with view.sync_imports():
405 >>> with view.sync_imports():
401 ... from numpy import recarray
406 ... from numpy import recarray
402 importing recarray from numpy on engine(s)
407 importing recarray from numpy on engine(s)
403
408
404 """
409 """
405 import __builtin__
410 import __builtin__
406 local_import = __builtin__.__import__
411 local_import = __builtin__.__import__
407 modules = set()
412 modules = set()
408 results = []
413 results = []
409 @util.interactive
414 @util.interactive
410 def remote_import(name, fromlist, level):
415 def remote_import(name, fromlist, level):
411 """the function to be passed to apply, that actually performs the import
416 """the function to be passed to apply, that actually performs the import
412 on the engine, and loads up the user namespace.
417 on the engine, and loads up the user namespace.
413 """
418 """
414 import sys
419 import sys
415 user_ns = globals()
420 user_ns = globals()
416 mod = __import__(name, fromlist=fromlist, level=level)
421 mod = __import__(name, fromlist=fromlist, level=level)
417 if fromlist:
422 if fromlist:
418 for key in fromlist:
423 for key in fromlist:
419 user_ns[key] = getattr(mod, key)
424 user_ns[key] = getattr(mod, key)
420 else:
425 else:
421 user_ns[name] = sys.modules[name]
426 user_ns[name] = sys.modules[name]
422
427
423 def view_import(name, globals={}, locals={}, fromlist=[], level=-1):
428 def view_import(name, globals={}, locals={}, fromlist=[], level=-1):
424 """the drop-in replacement for __import__, that optionally imports
429 """the drop-in replacement for __import__, that optionally imports
425 locally as well.
430 locally as well.
426 """
431 """
427 # don't override nested imports
432 # don't override nested imports
428 save_import = __builtin__.__import__
433 save_import = __builtin__.__import__
429 __builtin__.__import__ = local_import
434 __builtin__.__import__ = local_import
430
435
431 if imp.lock_held():
436 if imp.lock_held():
432 # this is a side-effect import, don't do it remotely, or even
437 # this is a side-effect import, don't do it remotely, or even
433 # ignore the local effects
438 # ignore the local effects
434 return local_import(name, globals, locals, fromlist, level)
439 return local_import(name, globals, locals, fromlist, level)
435
440
436 imp.acquire_lock()
441 imp.acquire_lock()
437 if local:
442 if local:
438 mod = local_import(name, globals, locals, fromlist, level)
443 mod = local_import(name, globals, locals, fromlist, level)
439 else:
444 else:
440 raise NotImplementedError("remote-only imports not yet implemented")
445 raise NotImplementedError("remote-only imports not yet implemented")
441 imp.release_lock()
446 imp.release_lock()
442
447
443 key = name+':'+','.join(fromlist or [])
448 key = name+':'+','.join(fromlist or [])
444 if level == -1 and key not in modules:
449 if level == -1 and key not in modules:
445 modules.add(key)
450 modules.add(key)
446 if fromlist:
451 if fromlist:
447 print "importing %s from %s on engine(s)"%(','.join(fromlist), name)
452 print "importing %s from %s on engine(s)"%(','.join(fromlist), name)
448 else:
453 else:
449 print "importing %s on engine(s)"%name
454 print "importing %s on engine(s)"%name
450 results.append(self.apply_async(remote_import, name, fromlist, level))
455 results.append(self.apply_async(remote_import, name, fromlist, level))
451 # restore override
456 # restore override
452 __builtin__.__import__ = save_import
457 __builtin__.__import__ = save_import
453
458
454 return mod
459 return mod
455
460
456 # override __import__
461 # override __import__
457 __builtin__.__import__ = view_import
462 __builtin__.__import__ = view_import
458 try:
463 try:
459 # enter the block
464 # enter the block
460 yield
465 yield
461 except ImportError:
466 except ImportError:
462 if local:
467 if local:
463 # re-raise locally; remote ImportErrors are raised via r.get() below
468 # re-raise locally; remote ImportErrors are raised via r.get() below
464 raise
469 raise
465 finally:
470 finally:
466 # always restore __import__
471 # always restore __import__
467 __builtin__.__import__ = local_import
472 __builtin__.__import__ = local_import
468
473
469 for r in results:
474 for r in results:
470 # raise possible remote ImportErrors here
475 # raise possible remote ImportErrors here
471 r.get()
476 r.get()
472
477
473
478
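A usage sketch; `importer` is the same context manager with local=True:

    with view.sync_imports():      # imports happen locally and on all engines
        import numpy
        from numpy import recarray

    with view.importer:            # equivalent property form
        import os

Remember that 'import x as y' is not supported: the 'as y' part is ignored.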
474 @sync_results
479 @sync_results
475 @save_ids
480 @save_ids
476 def _really_apply(self, f, args=None, kwargs=None, targets=None, block=None, track=None):
481 def _really_apply(self, f, args=None, kwargs=None, targets=None, block=None, track=None):
477 """calls f(*args, **kwargs) on remote engines, returning the result.
482 """calls f(*args, **kwargs) on remote engines, returning the result.
478
483
479 This method sets all of `apply`'s flags via this View's attributes.
484 This method sets all of `apply`'s flags via this View's attributes.
480
485
481 Parameters
486 Parameters
482 ----------
487 ----------
483
488
484 f : callable
489 f : callable
485
490
486 args : list [default: empty]
491 args : list [default: empty]
487
492
488 kwargs : dict [default: empty]
493 kwargs : dict [default: empty]
489
494
490 targets : target list [default: self.targets]
495 targets : target list [default: self.targets]
491 where to run
496 where to run
492 block : bool [default: self.block]
497 block : bool [default: self.block]
493 whether to block
498 whether to block
494 track : bool [default: self.track]
499 track : bool [default: self.track]
495 whether to ask zmq to track the message, for safe non-copying sends
500 whether to ask zmq to track the message, for safe non-copying sends
496
501
497 Returns
502 Returns
498 -------
503 -------
499
504
500 if self.block is False:
505 if self.block is False:
501 returns AsyncResult
506 returns AsyncResult
502 else:
507 else:
503 returns actual result of f(*args, **kwargs) on the engine(s)
508 returns actual result of f(*args, **kwargs) on the engine(s)
504 This will be a list if self.targets is also a list (even length 1), or
509 This will be a list if self.targets is also a list (even length 1), or
505 the single result if self.targets is an integer engine id
510 the single result if self.targets is an integer engine id
506 """
511 """
507 args = [] if args is None else args
512 args = [] if args is None else args
508 kwargs = {} if kwargs is None else kwargs
513 kwargs = {} if kwargs is None else kwargs
509 block = self.block if block is None else block
514 block = self.block if block is None else block
510 track = self.track if track is None else track
515 track = self.track if track is None else track
511 targets = self.targets if targets is None else targets
516 targets = self.targets if targets is None else targets
512
517
513 _idents = self.client._build_targets(targets)[0]
518 _idents = self.client._build_targets(targets)[0]
514 msg_ids = []
519 msg_ids = []
515 trackers = []
520 trackers = []
516 for ident in _idents:
521 for ident in _idents:
517 msg = self.client.send_apply_message(self._socket, f, args, kwargs, track=track,
522 msg = self.client.send_apply_message(self._socket, f, args, kwargs, track=track,
518 ident=ident)
523 ident=ident)
519 if track:
524 if track:
520 trackers.append(msg['tracker'])
525 trackers.append(msg['tracker'])
521 msg_ids.append(msg['msg_id'])
526 msg_ids.append(msg['msg_id'])
522 tracker = None if track is False else zmq.MessageTracker(*trackers)
527 tracker = None if track is False else zmq.MessageTracker(*trackers)
523 ar = AsyncResult(self.client, msg_ids, fname=f.__name__, targets=targets, tracker=tracker)
528 ar = AsyncResult(self.client, msg_ids, fname=f.__name__, targets=targets, tracker=tracker)
524 if block:
529 if block:
525 try:
530 try:
526 return ar.get()
531 return ar.get()
527 except KeyboardInterrupt:
532 except KeyboardInterrupt:
528 pass
533 pass
529 return ar
534 return ar
530
535
531 @spin_after
536 @spin_after
532 def map(self, f, *sequences, **kwargs):
537 def map(self, f, *sequences, **kwargs):
533 """view.map(f, *sequences, block=self.block) => list|AsyncMapResult
538 """view.map(f, *sequences, block=self.block) => list|AsyncMapResult
534
539
535 Parallel version of builtin `map`, using this View's `targets`.
540 Parallel version of builtin `map`, using this View's `targets`.
536
541
537 There will be one task per target, so work will be chunked
542 There will be one task per target, so work will be chunked
538 if the sequences are longer than `targets`.
543 if the sequences are longer than `targets`.
539
544
540 Results can be iterated as they are ready, but will become available in chunks.
545 Results can be iterated as they are ready, but will become available in chunks.
541
546
542 Parameters
547 Parameters
543 ----------
548 ----------
544
549
545 f : callable
550 f : callable
546 function to be mapped
551 function to be mapped
547 *sequences: one or more sequences of matching length
552 *sequences: one or more sequences of matching length
548 the sequences to be distributed and passed to `f`
553 the sequences to be distributed and passed to `f`
549 block : bool
554 block : bool
550 whether to wait for the result or not [default self.block]
555 whether to wait for the result or not [default self.block]
551
556
552 Returns
557 Returns
553 -------
558 -------
554
559
555 if block=False:
560 if block=False:
556 AsyncMapResult
561 AsyncMapResult
557 An object like AsyncResult, but which reassembles the sequence of results
562 An object like AsyncResult, but which reassembles the sequence of results
558 into a single list. AsyncMapResults can be iterated through before all
563 into a single list. AsyncMapResults can be iterated through before all
559 results are complete.
564 results are complete.
560 else:
565 else:
561 list
566 list
562 the result of map(f,*sequences)
567 the result of map(f,*sequences)
563 """
568 """
564
569
565 block = kwargs.pop('block', self.block)
570 block = kwargs.pop('block', self.block)
566 for k in kwargs.keys():
571 for k in kwargs.keys():
567 if k not in ['block', 'track']:
572 if k not in ['block', 'track']:
568 raise TypeError("invalid keyword arg, %r"%k)
573 raise TypeError("invalid keyword arg, %r"%k)
569
574
570 assert len(sequences) > 0, "must have some sequences to map onto!"
575 assert len(sequences) > 0, "must have some sequences to map onto!"
571 pf = ParallelFunction(self, f, block=block, **kwargs)
576 pf = ParallelFunction(self, f, block=block, **kwargs)
572 return pf.map(*sequences)
577 return pf.map(*sequences)
573
578
574 def execute(self, code, targets=None, block=None):
579 def execute(self, code, targets=None, block=None):
575 """Executes `code` on `targets` in blocking or nonblocking manner.
580 """Executes `code` on `targets` in blocking or nonblocking manner.
576
581
577 ``execute`` is always `bound` (affects engine namespace)
582 ``execute`` is always `bound` (affects engine namespace)
578
583
579 Parameters
584 Parameters
580 ----------
585 ----------
581
586
582 code : str
587 code : str
583 the code string to be executed
588 the code string to be executed
584 block : bool
589 block : bool
585 whether or not to wait until done to return
590 whether or not to wait until done to return
586 default: self.block
591 default: self.block
587 """
592 """
588 return self._really_apply(util._execute, args=(code,), block=block, targets=targets)
593 return self._really_apply(util._execute, args=(code,), block=block, targets=targets)
589
594
590 def run(self, filename, targets=None, block=None):
595 def run(self, filename, targets=None, block=None):
591 """Execute contents of `filename` on my engine(s).
596 """Execute contents of `filename` on my engine(s).
592
597
593 This simply reads the contents of the file and calls `execute`.
598 This simply reads the contents of the file and calls `execute`.
594
599
595 Parameters
600 Parameters
596 ----------
601 ----------
597
602
598 filename : str
603 filename : str
599 The path to the file
604 The path to the file
600 targets : int/str/list of ints/strs
605 targets : int/str/list of ints/strs
601 the engines on which to execute
606 the engines on which to execute
602 default : all
607 default : all
603 block : bool
608 block : bool
604 whether or not to wait until done
609 whether or not to wait until done
605 default: self.block
610 default: self.block
606
611
607 """
612 """
608 with open(filename, 'r') as f:
613 with open(filename, 'r') as f:
609 # add newline in case of trailing indented whitespace
614 # add newline in case of trailing indented whitespace
610 # which will cause SyntaxError
615 # which will cause SyntaxError
611 code = f.read()+'\n'
616 code = f.read()+'\n'
612 return self.execute(code, block=block, targets=targets)
617 return self.execute(code, block=block, targets=targets)
613
618
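A sketch of code execution in the engines' namespaces (the filename is a hypothetical example):

    ar = view.execute('a = 10', block=False)    # nonblocking, always bound
    ar.get()
    view.run('setup_engines.py', block=True)    # read the file, then execute it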
614 def update(self, ns):
619 def update(self, ns):
615 """update remote namespace with dict `ns`
620 """update remote namespace with dict `ns`
616
621
617 See `push` for details.
622 See `push` for details.
618 """
623 """
619 return self.push(ns, block=self.block, track=self.track)
624 return self.push(ns, block=self.block, track=self.track)
620
625
621 def push(self, ns, targets=None, block=None, track=None):
626 def push(self, ns, targets=None, block=None, track=None):
622 """update remote namespace with dict `ns`
627 """update remote namespace with dict `ns`
623
628
624 Parameters
629 Parameters
625 ----------
630 ----------
626
631
627 ns : dict
632 ns : dict
628 dict of keys with which to update engine namespace(s)
633 dict of keys with which to update engine namespace(s)
629 block : bool [default : self.block]
634 block : bool [default : self.block]
630 whether to wait to be notified of engine receipt
635 whether to wait to be notified of engine receipt
631
636
632 """
637 """
633
638
634 block = block if block is not None else self.block
639 block = block if block is not None else self.block
635 track = track if track is not None else self.track
640 track = track if track is not None else self.track
636 targets = targets if targets is not None else self.targets
641 targets = targets if targets is not None else self.targets
637 # applier = self.apply_sync if block else self.apply_async
642 # applier = self.apply_sync if block else self.apply_async
638 if not isinstance(ns, dict):
643 if not isinstance(ns, dict):
639 raise TypeError("Must be a dict, not %s"%type(ns))
644 raise TypeError("Must be a dict, not %s"%type(ns))
640 return self._really_apply(util._push, (ns,), block=block, track=track, targets=targets)
645 return self._really_apply(util._push, (ns,), block=block, track=track, targets=targets)
641
646
642 def get(self, key_s):
647 def get(self, key_s):
643 """get object(s) by `key_s` from remote namespace
648 """get object(s) by `key_s` from remote namespace
644
649
645 see `pull` for details.
650 see `pull` for details.
646 """
651 """
647 # block = block if block is not None else self.block
652 # block = block if block is not None else self.block
648 return self.pull(key_s, block=True)
653 return self.pull(key_s, block=True)
649
654
650 def pull(self, names, targets=None, block=None):
655 def pull(self, names, targets=None, block=None):
651 """get object(s) by `name` from remote namespace
656 """get object(s) by `name` from remote namespace
652
657
653 will return one object if `names` is a single key.
658 will return one object if `names` is a single key.
654 can also take a list of keys, in which case it will return a list of objects.
659 can also take a list of keys, in which case it will return a list of objects.
655 """
660 """
656 block = block if block is not None else self.block
661 block = block if block is not None else self.block
657 targets = targets if targets is not None else self.targets
662 targets = targets if targets is not None else self.targets
658 applier = self.apply_sync if block else self.apply_async
663 applier = self.apply_sync if block else self.apply_async
659 if isinstance(names, basestring):
664 if isinstance(names, basestring):
660 pass
665 pass
661 elif isinstance(names, (list,tuple,set)):
666 elif isinstance(names, (list,tuple,set)):
662 for key in names:
667 for key in names:
663 if not isinstance(key, basestring):
668 if not isinstance(key, basestring):
664 raise TypeError("keys must be str, not type %r"%type(key))
669 raise TypeError("keys must be str, not type %r"%type(key))
665 else:
670 else:
666 raise TypeError("names must be strs, not %r"%names)
671 raise TypeError("names must be strs, not %r"%names)
667 return self._really_apply(util._pull, (names,), block=block, targets=targets)
672 return self._really_apply(util._pull, (names,), block=block, targets=targets)
668
673
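A sketch of namespace access, shown on a single-engine view so each pull returns one value (illustrative names):

    dv = rc[0]
    dv.push(dict(a=1, b=2), block=True)        # update the engine namespace
    dv['c'] = 3                                # __setitem__ -> update -> push
    a = dv.pull('a', block=True)               # one key -> one object
    a, b = dv.pull(['a', 'b'], block=True)     # list of keys -> list of objects
    c = dv['c']                                # __getitem__ -> get -> blocking pull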
669 def scatter(self, key, seq, dist='b', flatten=False, targets=None, block=None, track=None):
674 def scatter(self, key, seq, dist='b', flatten=False, targets=None, block=None, track=None):
670 """
675 """
671 Partition a Python sequence and send the partitions to a set of engines.
676 Partition a Python sequence and send the partitions to a set of engines.
672 """
677 """
673 block = block if block is not None else self.block
678 block = block if block is not None else self.block
674 track = track if track is not None else self.track
679 track = track if track is not None else self.track
675 targets = targets if targets is not None else self.targets
680 targets = targets if targets is not None else self.targets
676
681
677 mapObject = Map.dists[dist]()
682 mapObject = Map.dists[dist]()
678 nparts = len(targets)
683 nparts = len(targets)
679 msg_ids = []
684 msg_ids = []
680 trackers = []
685 trackers = []
681 for index, engineid in enumerate(targets):
686 for index, engineid in enumerate(targets):
682 partition = mapObject.getPartition(seq, index, nparts)
687 partition = mapObject.getPartition(seq, index, nparts)
683 if flatten and len(partition) == 1:
688 if flatten and len(partition) == 1:
684 ns = {key: partition[0]}
689 ns = {key: partition[0]}
685 else:
690 else:
686 ns = {key: partition}
691 ns = {key: partition}
687 r = self.push(ns, block=False, track=track, targets=engineid)
692 r = self.push(ns, block=False, track=track, targets=engineid)
688 msg_ids.extend(r.msg_ids)
693 msg_ids.extend(r.msg_ids)
689 if track:
694 if track:
690 trackers.append(r._tracker)
695 trackers.append(r._tracker)
691
696
692 if track:
697 if track:
693 tracker = zmq.MessageTracker(*trackers)
698 tracker = zmq.MessageTracker(*trackers)
694 else:
699 else:
695 tracker = None
700 tracker = None
696
701
697 r = AsyncResult(self.client, msg_ids, fname='scatter', targets=targets, tracker=tracker)
702 r = AsyncResult(self.client, msg_ids, fname='scatter', targets=targets, tracker=tracker)
698 if block:
703 if block:
699 r.wait()
704 r.wait()
700 else:
705 else:
701 return r
706 return r
702
707
703 @sync_results
708 @sync_results
704 @save_ids
709 @save_ids
705 def gather(self, key, dist='b', targets=None, block=None):
710 def gather(self, key, dist='b', targets=None, block=None):
706 """
711 """
707 Gather a partitioned sequence on a set of engines as a single local seq.
712 Gather a partitioned sequence on a set of engines as a single local seq.
708 """
713 """
709 block = block if block is not None else self.block
714 block = block if block is not None else self.block
710 targets = targets if targets is not None else self.targets
715 targets = targets if targets is not None else self.targets
711 mapObject = Map.dists[dist]()
716 mapObject = Map.dists[dist]()
712 msg_ids = []
717 msg_ids = []
713
718
714 for index, engineid in enumerate(targets):
719 for index, engineid in enumerate(targets):
715 msg_ids.extend(self.pull(key, block=False, targets=engineid).msg_ids)
720 msg_ids.extend(self.pull(key, block=False, targets=engineid).msg_ids)
716
721
717 r = AsyncMapResult(self.client, msg_ids, mapObject, fname='gather')
722 r = AsyncMapResult(self.client, msg_ids, mapObject, fname='gather')
718
723
719 if block:
724 if block:
720 try:
725 try:
721 return r.get()
726 return r.get()
722 except KeyboardInterrupt:
727 except KeyboardInterrupt:
723 pass
728 pass
724 return r
729 return r
725
730
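A scatter/gather round trip, sketched under the same cluster assumptions as above:

    dv = rc[:]
    dv.scatter('x', range(16), block=True)            # partition across engines
    dv.execute('y = [i**2 for i in x]', block=True)   # work on each partition
    squares = dv.gather('y', block=True)              # reassemble one local list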
726 def __getitem__(self, key):
731 def __getitem__(self, key):
727 return self.get(key)
732 return self.get(key)
728
733
729 def __setitem__(self,key, value):
734 def __setitem__(self,key, value):
730 self.update({key:value})
735 self.update({key:value})
731
736
732 def clear(self, targets=None, block=None):
737 def clear(self, targets=None, block=None):
733 """Clear the remote namespaces on my engines."""
738 """Clear the remote namespaces on my engines."""
734 block = block if block is not None else self.block
739 block = block if block is not None else self.block
735 targets = targets if targets is not None else self.targets
740 targets = targets if targets is not None else self.targets
736 return self.client.clear(targets=targets, block=block)
741 return self.client.clear(targets=targets, block=block)
737
742
738 def kill(self, targets=None, block=None):
743 def kill(self, targets=None, block=None):
739 """Kill my engines."""
744 """Kill my engines."""
740 block = block if block is not None else self.block
745 block = block if block is not None else self.block
741 targets = targets if targets is not None else self.targets
746 targets = targets if targets is not None else self.targets
742 return self.client.kill(targets=targets, block=block)
747 return self.client.kill(targets=targets, block=block)
743
748
744 #----------------------------------------
749 #----------------------------------------
745 # activate for %px,%autopx magics
750 # activate for %px,%autopx magics
746 #----------------------------------------
751 #----------------------------------------
747 def activate(self):
752 def activate(self):
748 """Make this `View` active for parallel magic commands.
753 """Make this `View` active for parallel magic commands.
749
754
750 IPython has a magic command syntax to work with `MultiEngineClient` objects.
755 IPython has a magic command syntax to work with `MultiEngineClient` objects.
751 In a given IPython session there is a single active one. While
756 In a given IPython session there is a single active one. While
752 there can be many `Views` created and used by the user,
757 there can be many `Views` created and used by the user,
753 there is only one active one. The active `View` is used whenever
758 there is only one active one. The active `View` is used whenever
754 the magic commands %px and %autopx are used.
759 the magic commands %px and %autopx are used.
755
760
756 The activate() method is called on a given `View` to make it
761 The activate() method is called on a given `View` to make it
757 active. Once this has been done, the magic commands can be used.
762 active. Once this has been done, the magic commands can be used.
758 """
763 """
759
764
760 try:
765 try:
761 # This is injected into __builtins__.
766 # This is injected into __builtins__.
762 ip = get_ipython()
767 ip = get_ipython()
763 except NameError:
768 except NameError:
764 print "The IPython parallel magics (%result, %px, %autopx) only work within IPython."
769 print "The IPython parallel magics (%result, %px, %autopx) only work within IPython."
765 else:
770 else:
766 pmagic = ip.plugin_manager.get_plugin('parallelmagic')
771 pmagic = ip.plugin_manager.get_plugin('parallelmagic')
767 if pmagic is None:
772 if pmagic is None:
768 ip.magic_load_ext('parallelmagic')
773 ip.magic_load_ext('parallelmagic')
769 pmagic = ip.plugin_manager.get_plugin('parallelmagic')
774 pmagic = ip.plugin_manager.get_plugin('parallelmagic')
770
775
771 pmagic.active_view = self
776 pmagic.active_view = self
772
777
773
778
774 @skip_doctest
779 @skip_doctest
775 class LoadBalancedView(View):
780 class LoadBalancedView(View):
776 """An load-balancing View that only executes via the Task scheduler.
781 """An load-balancing View that only executes via the Task scheduler.
777
782
778 Load-balanced views can be created with the client's `view` method:
783 Load-balanced views can be created with the client's `view` method:
779
784
780 >>> v = client.load_balanced_view()
785 >>> v = client.load_balanced_view()
781
786
782 or targets can be specified, to restrict the potential destinations:
787 or targets can be specified, to restrict the potential destinations:
783
788
784 >>> v = client.load_balanced_view([1,3])
789 >>> v = client.load_balanced_view([1,3])
785
790
786 which would restrict load-balancing to engines 1 and 3.
791 which would restrict load-balancing to engines 1 and 3.
787
792
788 """
793 """
789
794
790 follow = Any()
795 follow = Any()
791 after = Any()
796 after = Any()
792 timeout = CFloat()
797 timeout = CFloat()
793 retries = CInt(0)
798 retries = CInt(0)
794
799
795 _task_scheme = Any()
800 _task_scheme = Any()
796 _flag_names = List(['targets', 'block', 'track', 'follow', 'after', 'timeout', 'retries'])
801 _flag_names = List(['targets', 'block', 'track', 'follow', 'after', 'timeout', 'retries'])
797
802
798 def __init__(self, client=None, socket=None, **flags):
803 def __init__(self, client=None, socket=None, **flags):
799 super(LoadBalancedView, self).__init__(client=client, socket=socket, **flags)
804 super(LoadBalancedView, self).__init__(client=client, socket=socket, **flags)
800 self._task_scheme = client._task_scheme
805 self._task_scheme = client._task_scheme
801
806
802 def _validate_dependency(self, dep):
807 def _validate_dependency(self, dep):
803 """validate a dependency.
808 """validate a dependency.
804
809
805 For use in `set_flags`.
810 For use in `set_flags`.
806 """
811 """
807 if dep is None or isinstance(dep, (str, AsyncResult, Dependency)):
812 if dep is None or isinstance(dep, (str, AsyncResult, Dependency)):
808 return True
813 return True
809 elif isinstance(dep, (list,set, tuple)):
814 elif isinstance(dep, (list,set, tuple)):
810 for d in dep:
815 for d in dep:
811 if not isinstance(d, (str, AsyncResult)):
816 if not isinstance(d, (str, AsyncResult)):
812 return False
817 return False
813 elif isinstance(dep, dict):
818 elif isinstance(dep, dict):
814 if set(dep.keys()) != set(Dependency().as_dict().keys()):
819 if set(dep.keys()) != set(Dependency().as_dict().keys()):
815 return False
820 return False
816 if not isinstance(dep['msg_ids'], list):
821 if not isinstance(dep['msg_ids'], list):
817 return False
822 return False
818 for d in dep['msg_ids']:
823 for d in dep['msg_ids']:
819 if not isinstance(d, str):
824 if not isinstance(d, str):
820 return False
825 return False
821 else:
826 else:
822 return False
827 return False
823
828
824 return True
829 return True
825
830
826 def _render_dependency(self, dep):
831 def _render_dependency(self, dep):
827 """helper for building jsonable dependencies from various input forms."""
832 """helper for building jsonable dependencies from various input forms."""
828 if isinstance(dep, Dependency):
833 if isinstance(dep, Dependency):
829 return dep.as_dict()
834 return dep.as_dict()
830 elif isinstance(dep, AsyncResult):
835 elif isinstance(dep, AsyncResult):
831 return dep.msg_ids
836 return dep.msg_ids
832 elif dep is None:
837 elif dep is None:
833 return []
838 return []
834 else:
839 else:
835 # pass to Dependency constructor
840 # pass to Dependency constructor
836 return list(Dependency(dep))
841 return list(Dependency(dep))
837
842
838 def set_flags(self, **kwargs):
843 def set_flags(self, **kwargs):
839 """set my attribute flags by keyword.
844 """set my attribute flags by keyword.
840
845
841 A View is a wrapper for the Client's apply method, but with attributes
846 A View is a wrapper for the Client's apply method, but with attributes
842 that specify keyword arguments, those attributes can be set by keyword
847 that specify keyword arguments, those attributes can be set by keyword
843 argument with this method.
848 argument with this method.
844
849
845 Parameters
850 Parameters
846 ----------
851 ----------
847
852
848 block : bool
853 block : bool
849 whether to wait for results
854 whether to wait for results
850 track : bool
855 track : bool
851 whether to create a MessageTracker to allow the user to
856 whether to create a MessageTracker to allow the user to
852 safely edit arrays and buffers after
857 safely edit arrays and buffers after
853 non-copying sends.
858 non-copying sends.
854
859
855 after : Dependency or collection of msg_ids
860 after : Dependency or collection of msg_ids
856 Only for load-balanced execution (targets=None)
861 Only for load-balanced execution (targets=None)
857 Specify a list of msg_ids as a time-based dependency.
862 Specify a list of msg_ids as a time-based dependency.
858 This job will only be run *after* the dependencies
863 This job will only be run *after* the dependencies
859 have been met.
864 have been met.
860
865
861 follow : Dependency or collection of msg_ids
866 follow : Dependency or collection of msg_ids
862 Only for load-balanced execution (targets=None)
867 Only for load-balanced execution (targets=None)
863 Specify a list of msg_ids as a location-based dependency.
868 Specify a list of msg_ids as a location-based dependency.
864 This job will only be run on an engine where this dependency
869 This job will only be run on an engine where this dependency
865 is met.
870 is met.
866
871
867 timeout : float/int or None
872 timeout : float/int or None
868 Only for load-balanced execution (targets=None)
873 Only for load-balanced execution (targets=None)
869 Specify an amount of time (in seconds) for the scheduler to
874 Specify an amount of time (in seconds) for the scheduler to
870 wait for dependencies to be met before failing with a
875 wait for dependencies to be met before failing with a
871 DependencyTimeout.
876 DependencyTimeout.
872
877
873 retries : int
878 retries : int
874 Number of times a task will be retried on failure.
879 Number of times a task will be retried on failure.
875 """
880 """
876
881
877 super(LoadBalancedView, self).set_flags(**kwargs)
882 super(LoadBalancedView, self).set_flags(**kwargs)
878 for name in ('follow', 'after'):
883 for name in ('follow', 'after'):
879 if name in kwargs:
884 if name in kwargs:
880 value = kwargs[name]
885 value = kwargs[name]
881 if self._validate_dependency(value):
886 if self._validate_dependency(value):
882 setattr(self, name, value)
887 setattr(self, name, value)
883 else:
888 else:
884 raise ValueError("Invalid dependency: %r"%value)
889 raise ValueError("Invalid dependency: %r"%value)
885 if 'timeout' in kwargs:
890 if 'timeout' in kwargs:
886 t = kwargs['timeout']
891 t = kwargs['timeout']
887 if not isinstance(t, (int, long, float, type(None))):
892 if not isinstance(t, (int, long, float, type(None))):
888 raise TypeError("Invalid type for timeout: %r"%type(t))
893 raise TypeError("Invalid type for timeout: %r"%type(t))
889 if t is not None:
894 if t is not None:
890 if t < 0:
895 if t < 0:
891 raise ValueError("Invalid timeout: %s"%t)
896 raise ValueError("Invalid timeout: %s"%t)
892 self.timeout = t
897 self.timeout = t
893
898
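A sketch of the scheduling flags on a load-balanced view, continuing the assumptions of the earlier sketches (`double` as defined there):

    lbv = rc.load_balanced_view()
    ar0 = lbv.apply_async(double, 1)
    with lbv.temp_flags(after=[ar0], timeout=30, retries=2):
        ar1 = lbv.apply_async(double, 2)   # scheduled only after ar0 completes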
894 @sync_results
899 @sync_results
895 @save_ids
900 @save_ids
896 def _really_apply(self, f, args=None, kwargs=None, block=None, track=None,
901 def _really_apply(self, f, args=None, kwargs=None, block=None, track=None,
897 after=None, follow=None, timeout=None,
902 after=None, follow=None, timeout=None,
898 targets=None, retries=None):
903 targets=None, retries=None):
899 """calls f(*args, **kwargs) on a remote engine, returning the result.
904 """calls f(*args, **kwargs) on a remote engine, returning the result.
900
905
901 This method temporarily sets all of `apply`'s flags for a single call.
906 This method temporarily sets all of `apply`'s flags for a single call.
902
907
903 Parameters
908 Parameters
904 ----------
909 ----------
905
910
906 f : callable
911 f : callable
907
912
908 args : list [default: empty]
913 args : list [default: empty]
909
914
910 kwargs : dict [default: empty]
915 kwargs : dict [default: empty]
911
916
912 block : bool [default: self.block]
917 block : bool [default: self.block]
913 whether to block
918 whether to block
914 track : bool [default: self.track]
919 track : bool [default: self.track]
915 whether to ask zmq to track the message, for safe non-copying sends
920 whether to ask zmq to track the message, for safe non-copying sends
916
921
917 after, follow, timeout, targets, retries : [default: the View's attributes]
922 after, follow, timeout, targets, retries : [default: the View's attributes]
918
923
919 Returns
924 Returns
920 -------
925 -------
921
926
922 if self.block is False:
927 if self.block is False:
923 returns AsyncResult
928 returns AsyncResult
924 else:
929 else:
925 returns actual result of f(*args, **kwargs) on the engine(s)
930 returns actual result of f(*args, **kwargs) on the engine(s)
926 This will be a list if self.targets is also a list (even length 1), or
931 This will be a list if self.targets is also a list (even length 1), or
927 the single result if self.targets is an integer engine id
932 the single result if self.targets is an integer engine id
928 """
933 """
929
934
930 # validate whether we can run
935 # validate whether we can run
931 if self._socket.closed:
936 if self._socket.closed:
932 msg = "Task farming is disabled"
937 msg = "Task farming is disabled"
933 if self._task_scheme == 'pure':
938 if self._task_scheme == 'pure':
934 msg += " because the pure ZMQ scheduler cannot handle"
939 msg += " because the pure ZMQ scheduler cannot handle"
935 msg += " disappearing engines."
940 msg += " disappearing engines."
936 raise RuntimeError(msg)
941 raise RuntimeError(msg)
937
942
938 if self._task_scheme == 'pure':
943 if self._task_scheme == 'pure':
939 # pure zmq scheme doesn't support extra features
944 # pure zmq scheme doesn't support extra features
940 msg = "Pure ZMQ scheduler doesn't support the following flags:"
945 msg = "Pure ZMQ scheduler doesn't support the following flags:"
941 "follow, after, retries, targets, timeout"
946 "follow, after, retries, targets, timeout"
942 if (follow or after or retries or targets or timeout):
947 if (follow or after or retries or targets or timeout):
943 # hard fail on Scheduler flags
948 # hard fail on Scheduler flags
944 raise RuntimeError(msg)
949 raise RuntimeError(msg)
945 if isinstance(f, dependent):
950 if isinstance(f, dependent):
946 # soft warn on functional dependencies
951 # soft warn on functional dependencies
947 warnings.warn(msg, RuntimeWarning)
952 warnings.warn(msg, RuntimeWarning)
948
953
949 # build args
954 # build args
950 args = [] if args is None else args
955 args = [] if args is None else args
951 kwargs = {} if kwargs is None else kwargs
956 kwargs = {} if kwargs is None else kwargs
952 block = self.block if block is None else block
957 block = self.block if block is None else block
953 track = self.track if track is None else track
958 track = self.track if track is None else track
954 after = self.after if after is None else after
959 after = self.after if after is None else after
955 retries = self.retries if retries is None else retries
960 retries = self.retries if retries is None else retries
956 follow = self.follow if follow is None else follow
961 follow = self.follow if follow is None else follow
957 timeout = self.timeout if timeout is None else timeout
962 timeout = self.timeout if timeout is None else timeout
958 targets = self.targets if targets is None else targets
963 targets = self.targets if targets is None else targets
959
964
960 if not isinstance(retries, int):
965 if not isinstance(retries, int):
961 raise TypeError('retries must be int, not %r'%type(retries))
966 raise TypeError('retries must be int, not %r'%type(retries))
962
967
963 if targets is None:
968 if targets is None:
964 idents = []
969 idents = []
965 else:
970 else:
966 idents = self.client._build_targets(targets)[0]
971 idents = self.client._build_targets(targets)[0]
967
972
968 after = self._render_dependency(after)
973 after = self._render_dependency(after)
969 follow = self._render_dependency(follow)
974 follow = self._render_dependency(follow)
970 subheader = dict(after=after, follow=follow, timeout=timeout, targets=idents, retries=retries)
975 subheader = dict(after=after, follow=follow, timeout=timeout, targets=idents, retries=retries)
971
976
972 msg = self.client.send_apply_message(self._socket, f, args, kwargs, track=track,
977 msg = self.client.send_apply_message(self._socket, f, args, kwargs, track=track,
973 subheader=subheader)
978 subheader=subheader)
974 tracker = None if track is False else msg['tracker']
979 tracker = None if track is False else msg['tracker']
975
980
976 ar = AsyncResult(self.client, msg['msg_id'], fname=f.__name__, targets=None, tracker=tracker)
981 ar = AsyncResult(self.client, msg['msg_id'], fname=f.__name__, targets=None, tracker=tracker)
977
982
978 if block:
983 if block:
979 try:
984 try:
980 return ar.get()
985 return ar.get()
981 except KeyboardInterrupt:
986 except KeyboardInterrupt:
982 pass
987 pass
983 return ar
988 return ar
984
989
985 @spin_after
990 @spin_after
986 @save_ids
991 @save_ids
987 def map(self, f, *sequences, **kwargs):
992 def map(self, f, *sequences, **kwargs):
988 """view.map(f, *sequences, block=self.block, chunksize=1) => list|AsyncMapResult
993 """view.map(f, *sequences, block=self.block, chunksize=1) => list|AsyncMapResult
989
994
990 Parallel version of builtin `map`, load-balanced by this View.
995 Parallel version of builtin `map`, load-balanced by this View.
991
996
992 `block`, and `chunksize` can be specified by keyword only.
997 `block`, and `chunksize` can be specified by keyword only.
993
998
994 Each `chunksize` elements will be a separate task, and will be
999 Each `chunksize` elements will be a separate task, and will be
995 load-balanced. This lets individual elements be available for iteration
1000 load-balanced. This lets individual elements be available for iteration
996 as soon as they arrive.
1001 as soon as they arrive.
997
1002
998 Parameters
1003 Parameters
999 ----------
1004 ----------
1000
1005
1001 f : callable
1006 f : callable
1002 function to be mapped
1007 function to be mapped
1003 *sequences: one or more sequences of matching length
1008 *sequences: one or more sequences of matching length
1004 the sequences to be distributed and passed to `f`
1009 the sequences to be distributed and passed to `f`
1005 block : bool
1010 block : bool
1006 whether to wait for the result or not [default self.block]
1011 whether to wait for the result or not [default self.block]
1007 track : bool
1012 track : bool
1008 whether to create a MessageTracker to allow the user to
1013 whether to create a MessageTracker to allow the user to
1009 safely edit arrays and buffers after
1014 safely edit arrays and buffers after
1010 non-copying sends.
1015 non-copying sends.
1011 chunksize : int
1016 chunksize : int
1012 how many elements should be in each task [default 1]
1017 how many elements should be in each task [default 1]
1013
1018
1014 Returns
1019 Returns
1015 -------
1020 -------
1016
1021
1017 if block=False:
1022 if block=False:
1018 AsyncMapResult
1023 AsyncMapResult
1019 An object like AsyncResult, but which reassembles the sequence of results
1024 An object like AsyncResult, but which reassembles the sequence of results
1020 into a single list. AsyncMapResults can be iterated through before all
1025 into a single list. AsyncMapResults can be iterated through before all
1021 results are complete.
1026 results are complete.
1022 else:
1027 else:
1023 the result of map(f,*sequences)
1028 the result of map(f,*sequences)
1024
1029
1025 """
1030 """
1026
1031
1027 # default
1032 # default
1028 block = kwargs.get('block', self.block)
1033 block = kwargs.get('block', self.block)
1029 chunksize = kwargs.get('chunksize', 1)
1034 chunksize = kwargs.get('chunksize', 1)
1030
1035
1031 keyset = set(kwargs.keys())
1036 keyset = set(kwargs.keys())
1032 extra_keys = keyset.difference(set(['block', 'chunksize'])) # difference() returns the leftover keys; difference_update() returns None
1037 extra_keys = keyset.difference(set(['block', 'chunksize'])) # difference() returns the leftover keys; difference_update() returns None
1033 if extra_keys:
1038 if extra_keys:
1034 raise TypeError("Invalid kwargs: %s"%list(extra_keys))
1039 raise TypeError("Invalid kwargs: %s"%list(extra_keys))
1035
1040
1036 assert len(sequences) > 0, "must have some sequences to map onto!"
1041 assert len(sequences) > 0, "must have some sequences to map onto!"
1037
1042
1038 pf = ParallelFunction(self, f, block=block, chunksize=chunksize)
1043 pf = ParallelFunction(self, f, block=block, chunksize=chunksize)
1039 return pf.map(*sequences)
1044 return pf.map(*sequences)
1040
1045
1041 __all__ = ['LoadBalancedView', 'DirectView']
1046 __all__ = ['LoadBalancedView', 'DirectView']
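For reference, a minimal usage sketch of the load-balanced `map` above (it assumes a running cluster, e.g. one started with `ipcluster start`; the `rc` and `view` names are illustrative):

    from IPython.parallel import Client

    def square(x):
        return x * x

    rc = Client()                      # connect to the default cluster
    view = rc.load_balanced_view()     # a LoadBalancedView over all engines
    # chunksize=2 packs two elements into each load-balanced task
    amr = view.map(square, range(10), block=False, chunksize=2)
    for result in amr:                 # AsyncMapResult yields results as they arrive
        print result

With `block=False` the call returns immediately; iterating the AsyncMapResult blocks only until each chunk's result is ready.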
@@ -1,196 +1,201 b''
1 """Dependency utilities"""
1 """Dependency utilities
2
3 Authors:
4
5 * Min RK
6 """
2 #-----------------------------------------------------------------------------
7 #-----------------------------------------------------------------------------
3 # Copyright (C) 2010-2011 The IPython Development Team
8 # Copyright (C) 2010-2011 The IPython Development Team
4 #
9 #
5 # Distributed under the terms of the BSD License. The full license is in
10 # Distributed under the terms of the BSD License. The full license is in
6 # the file COPYING, distributed as part of this software.
11 # the file COPYING, distributed as part of this software.
7 #-----------------------------------------------------------------------------
12 #-----------------------------------------------------------------------------
8
13
9 from types import ModuleType
14 from types import ModuleType
10
15
11 from IPython.parallel.client.asyncresult import AsyncResult
16 from IPython.parallel.client.asyncresult import AsyncResult
12 from IPython.parallel.error import UnmetDependency
17 from IPython.parallel.error import UnmetDependency
13 from IPython.parallel.util import interactive
18 from IPython.parallel.util import interactive
14
19
15 class depend(object):
20 class depend(object):
16 """Dependency decorator, for use with tasks.
21 """Dependency decorator, for use with tasks.
17
22
18 `@depend` lets you define a function for engine dependencies
23 `@depend` lets you define a function for engine dependencies
19 just like you use `apply` for tasks.
24 just like you use `apply` for tasks.
20
25
21
26
22 Examples
27 Examples
23 --------
28 --------
24 ::
29 ::
25
30
26 @depend(df, a,b, c=5)
31 @depend(df, a,b, c=5)
27 def f(m,n,p):
32 def f(m,n,p):
28
33
29 view.apply(f, 1,2,3)
34 view.apply(f, 1,2,3)
30
35
31 will call df(a,b,c=5) on the engine, and if it returns False or
36 will call df(a,b,c=5) on the engine, and if it returns False or
32 raises an UnmetDependency error, then the task will not be run
37 raises an UnmetDependency error, then the task will not be run
33 and another engine will be tried.
38 and another engine will be tried.
34 """
39 """
35 def __init__(self, f, *args, **kwargs):
40 def __init__(self, f, *args, **kwargs):
36 self.f = f
41 self.f = f
37 self.args = args
42 self.args = args
38 self.kwargs = kwargs
43 self.kwargs = kwargs
39
44
40 def __call__(self, f):
45 def __call__(self, f):
41 return dependent(f, self.f, *self.args, **self.kwargs)
46 return dependent(f, self.f, *self.args, **self.kwargs)
42
47
43 class dependent(object):
48 class dependent(object):
44 """A function that depends on another function.
49 """A function that depends on another function.
45 This is an object rather than a closure, because the
50 This is an object rather than a closure, because the
46 closures produced by traditional decorators are not picklable.
51 closures produced by traditional decorators are not picklable.
47 """
52 """
48
53
49 def __init__(self, f, df, *dargs, **dkwargs):
54 def __init__(self, f, df, *dargs, **dkwargs):
50 self.f = f
55 self.f = f
51 self.func_name = getattr(f, '__name__', 'f')
56 self.func_name = getattr(f, '__name__', 'f')
52 self.df = df
57 self.df = df
53 self.dargs = dargs
58 self.dargs = dargs
54 self.dkwargs = dkwargs
59 self.dkwargs = dkwargs
55
60
56 def __call__(self, *args, **kwargs):
61 def __call__(self, *args, **kwargs):
57 # if hasattr(self.f, 'func_globals') and hasattr(self.df, 'func_globals'):
62 # if hasattr(self.f, 'func_globals') and hasattr(self.df, 'func_globals'):
58 # self.df.func_globals = self.f.func_globals
63 # self.df.func_globals = self.f.func_globals
59 if self.df(*self.dargs, **self.dkwargs) is False:
64 if self.df(*self.dargs, **self.dkwargs) is False:
60 raise UnmetDependency()
65 raise UnmetDependency()
61 return self.f(*args, **kwargs)
66 return self.f(*args, **kwargs)
62
67
63 @property
68 @property
64 def __name__(self):
69 def __name__(self):
65 return self.func_name
70 return self.func_name
66
71
67 @interactive
72 @interactive
68 def _require(*names):
73 def _require(*names):
69 """Helper for @require decorator."""
74 """Helper for @require decorator."""
70 from IPython.parallel.error import UnmetDependency
75 from IPython.parallel.error import UnmetDependency
71 user_ns = globals()
76 user_ns = globals()
72 for name in names:
77 for name in names:
73 if name in user_ns:
78 if name in user_ns:
74 continue
79 continue
75 try:
80 try:
76 exec 'import %s'%name in user_ns
81 exec 'import %s'%name in user_ns
77 except ImportError:
82 except ImportError:
78 raise UnmetDependency(name)
83 raise UnmetDependency(name)
79 return True
84 return True
80
85
81 def require(*mods):
86 def require(*mods):
82 """Simple decorator for requiring names to be importable.
87 """Simple decorator for requiring names to be importable.
83
88
84 Examples
89 Examples
85 --------
90 --------
86
91
87 In [1]: @require('numpy')
92 In [1]: @require('numpy')
88 ...: def norm(a):
93 ...: def norm(a):
89 ...: import numpy
94 ...: import numpy
90 ...: return numpy.linalg.norm(a,2)
95 ...: return numpy.linalg.norm(a,2)
91 """
96 """
92 names = []
97 names = []
93 for mod in mods:
98 for mod in mods:
94 if isinstance(mod, ModuleType):
99 if isinstance(mod, ModuleType):
95 mod = mod.__name__
100 mod = mod.__name__
96
101
97 if isinstance(mod, basestring):
102 if isinstance(mod, basestring):
98 names.append(mod)
103 names.append(mod)
99 else:
104 else:
100 raise TypeError("names must be modules or module names, not %s"%type(mod))
105 raise TypeError("names must be modules or module names, not %s"%type(mod))
101
106
102 return depend(_require, *names)
107 return depend(_require, *names)
103
108
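A hedged sketch of `require` in use, showing that module objects and module names are both accepted (the `solve` function is illustrative):

    import numpy

    @require(numpy, 'scipy')
    def solve(a, b):
        import scipy.linalg
        return scipy.linalg.solve(a, b)

A task wrapping `solve` will raise UnmetDependency, and be retried on another engine, wherever numpy or scipy cannot be imported.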
104 class Dependency(set):
109 class Dependency(set):
105 """An object for representing a set of msg_id dependencies.
110 """An object for representing a set of msg_id dependencies.
106
111
107 Subclassed from set().
112 Subclassed from set().
108
113
109 Parameters
114 Parameters
110 ----------
115 ----------
111 dependencies: list/set of msg_ids or AsyncResult objects or output of Dependency.as_dict()
116 dependencies: list/set of msg_ids or AsyncResult objects or output of Dependency.as_dict()
112 The msg_ids to depend on
117 The msg_ids to depend on
113 all : bool [default True]
118 all : bool [default True]
114 Whether the dependency should be considered met when *all* depended-upon tasks have completed
119 Whether the dependency should be considered met when *all* depended-upon tasks have completed
115 or only when *any* have been completed.
120 or only when *any* have been completed.
116 success : bool [default True]
121 success : bool [default True]
117 Whether to consider successes as fulfilling dependencies.
122 Whether to consider successes as fulfilling dependencies.
118 failure : bool [default False]
123 failure : bool [default False]
119 Whether to consider failures as fulfilling dependencies.
124 Whether to consider failures as fulfilling dependencies.
120
125
121 If `all=success=True` and `failure=False`, then the task will fail with an ImpossibleDependency
126 If `all=success=True` and `failure=False`, then the task will fail with an ImpossibleDependency
122 as soon as the first depended-upon task fails.
127 as soon as the first depended-upon task fails.
123 """
128 """
124
129
125 all=True
130 all=True
126 success=True
131 success=True
127 failure=False # class-level default; matches the docstring and the __init__ default
132 failure=False # class-level default; matches the docstring and the __init__ default
128
133
129 def __init__(self, dependencies=[], all=True, success=True, failure=False):
134 def __init__(self, dependencies=[], all=True, success=True, failure=False):
130 if isinstance(dependencies, dict):
135 if isinstance(dependencies, dict):
131 # load from dict
136 # load from dict
132 all = dependencies.get('all', all)
137 all = dependencies.get('all', all)
133 success = dependencies.get('success', success)
138 success = dependencies.get('success', success)
134 failure = dependencies.get('failure', failure)
139 failure = dependencies.get('failure', failure)
135 dependencies = dependencies.get('dependencies', [])
140 dependencies = dependencies.get('dependencies', [])
136 ids = []
141 ids = []
137
142
138 # extract ids from various sources:
143 # extract ids from various sources:
139 if isinstance(dependencies, (basestring, AsyncResult)):
144 if isinstance(dependencies, (basestring, AsyncResult)):
140 dependencies = [dependencies]
145 dependencies = [dependencies]
141 for d in dependencies:
146 for d in dependencies:
142 if isinstance(d, basestring):
147 if isinstance(d, basestring):
143 ids.append(d)
148 ids.append(d)
144 elif isinstance(d, AsyncResult):
149 elif isinstance(d, AsyncResult):
145 ids.extend(d.msg_ids)
150 ids.extend(d.msg_ids)
146 else:
151 else:
147 raise TypeError("invalid dependency type: %r"%type(d))
152 raise TypeError("invalid dependency type: %r"%type(d))
148
153
149 set.__init__(self, ids)
154 set.__init__(self, ids)
150 self.all = all
155 self.all = all
151 if not (success or failure):
156 if not (success or failure):
152 raise ValueError("Must depend on at least one of successes or failures!")
157 raise ValueError("Must depend on at least one of successes or failures!")
153 self.success=success
158 self.success=success
154 self.failure = failure
159 self.failure = failure
155
160
156 def check(self, completed, failed=None):
161 def check(self, completed, failed=None):
157 """check whether our dependencies have been met."""
162 """check whether our dependencies have been met."""
158 if len(self) == 0:
163 if len(self) == 0:
159 return True
164 return True
160 against = set()
165 against = set()
161 if self.success:
166 if self.success:
162 against = completed
167 against = completed
163 if failed is not None and self.failure:
168 if failed is not None and self.failure:
164 against = against.union(failed)
169 against = against.union(failed)
165 if self.all:
170 if self.all:
166 return self.issubset(against)
171 return self.issubset(against)
167 else:
172 else:
168 return not self.isdisjoint(against)
173 return not self.isdisjoint(against)
169
174
170 def unreachable(self, completed, failed=None):
175 def unreachable(self, completed, failed=None):
171 """return whether this dependency has become impossible."""
176 """return whether this dependency has become impossible."""
172 if len(self) == 0:
177 if len(self) == 0:
173 return False
178 return False
174 against = set()
179 against = set()
175 if not self.success:
180 if not self.success:
176 against = completed
181 against = completed
177 if failed is not None and not self.failure:
182 if failed is not None and not self.failure:
178 against = against.union(failed)
183 against = against.union(failed)
179 if self.all:
184 if self.all:
180 return not self.isdisjoint(against)
185 return not self.isdisjoint(against)
181 else:
186 else:
182 return self.issubset(against)
187 return self.issubset(against)
183
188
184
189
185 def as_dict(self):
190 def as_dict(self):
186 """Represent this dependency as a dict. For json compatibility."""
191 """Represent this dependency as a dict. For json compatibility."""
187 return dict(
192 return dict(
188 dependencies=list(self),
193 dependencies=list(self),
189 all=self.all,
194 all=self.all,
190 success=self.success,
195 success=self.success,
191 failure=self.failure
196 failure=self.failure
192 )
197 )
193
198
194
199
195 __all__ = ['depend', 'require', 'dependent', 'Dependency']
200 __all__ = ['depend', 'require', 'dependent', 'Dependency']
196
201
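To make the check/unreachable semantics concrete, a small sketch with hypothetical msg_ids:

    dep = Dependency(['msg_a', 'msg_b'], all=True, success=True, failure=False)

    dep.check(set(['msg_a']))                      # False: msg_b still pending
    dep.check(set(['msg_a', 'msg_b']))             # True: all dependencies met
    dep.unreachable(set(), failed=set(['msg_b']))  # True: msg_b failed, so the
                                                   # dependency can never be met
    d = dep.as_dict()   # json-friendly form; Dependency(d) round-trips it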
@@ -1,180 +1,185 b''
1 """A Task logger that presents our DB interface,
1 """A Task logger that presents our DB interface,
2 but exists entirely in memory and implemented with dicts.
2 but exists entirely in memory and implemented with dicts.
3
3
4 Authors:
5
6 * Min RK
7
8
4 TaskRecords are dicts of the form:
9 TaskRecords are dicts of the form:
5 {
10 {
6 'msg_id' : str(uuid),
11 'msg_id' : str(uuid),
7 'client_uuid' : str(uuid),
12 'client_uuid' : str(uuid),
8 'engine_uuid' : str(uuid) or None,
13 'engine_uuid' : str(uuid) or None,
9 'header' : dict(header),
14 'header' : dict(header),
10 'content': dict(content),
15 'content': dict(content),
11 'buffers': list(buffers),
16 'buffers': list(buffers),
12 'submitted': datetime,
17 'submitted': datetime,
13 'started': datetime or None,
18 'started': datetime or None,
14 'completed': datetime or None,
19 'completed': datetime or None,
15 'resubmitted': datetime or None,
20 'resubmitted': datetime or None,
16 'result_header' : dict(header) or None,
21 'result_header' : dict(header) or None,
17 'result_content' : dict(content) or None,
22 'result_content' : dict(content) or None,
18 'result_buffers' : list(buffers) or None,
23 'result_buffers' : list(buffers) or None,
19 }
24 }
20 With this info, many of the special categories of tasks can be defined by query:
25 With this info, many of the special categories of tasks can be defined by query:
21
26
22 pending: completed is None
27 pending: completed is None
23 client's outstanding: client_uuid = uuid && completed is None
28 client's outstanding: client_uuid = uuid && completed is None
24 MIA: started is None (and completed is None)
29 MIA: started is None (and completed is None)
25 etc.
30 etc.
26
31
27 EngineRecords are dicts of the form:
32 EngineRecords are dicts of the form:
28 {
33 {
29 'eid' : int(id),
34 'eid' : int(id),
30 'uuid': str(uuid)
35 'uuid': str(uuid)
31 }
36 }
32 This may be extended, but is currently limited to these two keys.
37 This may be extended, but is currently limited to these two keys.
33
38
34 We support a subset of mongodb operators:
39 We support a subset of mongodb operators:
35 $lt,$gt,$lte,$gte,$ne,$in,$nin,$all,$mod,$exists
40 $lt,$gt,$lte,$gte,$ne,$in,$nin,$all,$mod,$exists
36 """
41 """
37 #-----------------------------------------------------------------------------
42 #-----------------------------------------------------------------------------
38 # Copyright (C) 2010 The IPython Development Team
43 # Copyright (C) 2010-2011 The IPython Development Team
39 #
44 #
40 # Distributed under the terms of the BSD License. The full license is in
45 # Distributed under the terms of the BSD License. The full license is in
41 # the file COPYING, distributed as part of this software.
46 # the file COPYING, distributed as part of this software.
42 #-----------------------------------------------------------------------------
47 #-----------------------------------------------------------------------------
43
48
44
49
45 from datetime import datetime
50 from datetime import datetime
46
51
47 from IPython.config.configurable import LoggingConfigurable
52 from IPython.config.configurable import LoggingConfigurable
48
53
49 from IPython.utils.traitlets import Dict, Unicode, Instance
54 from IPython.utils.traitlets import Dict, Unicode, Instance
50
55
51 filters = {
56 filters = {
52 '$lt' : lambda a,b: a < b,
57 '$lt' : lambda a,b: a < b,
53 '$gt' : lambda a,b: a > b,
58 '$gt' : lambda a,b: a > b,
54 '$eq' : lambda a,b: a == b,
59 '$eq' : lambda a,b: a == b,
55 '$ne' : lambda a,b: a != b,
60 '$ne' : lambda a,b: a != b,
56 '$lte': lambda a,b: a <= b,
61 '$lte': lambda a,b: a <= b,
57 '$gte': lambda a,b: a >= b,
62 '$gte': lambda a,b: a >= b,
58 '$in' : lambda a,b: a in b,
63 '$in' : lambda a,b: a in b,
59 '$nin': lambda a,b: a not in b,
64 '$nin': lambda a,b: a not in b,
60 '$all': lambda a,b: all([ bb in a for bb in b ]), # record value must contain every query element
65 '$all': lambda a,b: all([ bb in a for bb in b ]), # record value must contain every query element
61 '$mod': lambda a,b: a%b[0] == b[1],
66 '$mod': lambda a,b: a%b[0] == b[1],
62 '$exists' : lambda a,b: (b and a is not None) or (a is None and not b)
67 '$exists' : lambda a,b: (b and a is not None) or (a is None and not b)
63 }
68 }
64
69
65
70
66 class CompositeFilter(object):
71 class CompositeFilter(object):
67 """Composite filter for matching multiple properties."""
72 """Composite filter for matching multiple properties."""
68
73
69 def __init__(self, dikt):
74 def __init__(self, dikt):
70 self.tests = []
75 self.tests = []
71 self.values = []
76 self.values = []
72 for key, value in dikt.iteritems():
77 for key, value in dikt.iteritems():
73 self.tests.append(filters[key])
78 self.tests.append(filters[key])
74 self.values.append(value)
79 self.values.append(value)
75
80
76 def __call__(self, value):
81 def __call__(self, value):
77 for test,check in zip(self.tests, self.values):
82 for test,check in zip(self.tests, self.values):
78 if not test(value, check):
83 if not test(value, check):
79 return False
84 return False
80 return True
85 return True
81
86
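For instance, a CompositeFilter combining two operators behaves as their conjunction (values here are arbitrary):

    f = CompositeFilter({'$gte': 3, '$lt': 10})
    f(5)     # True: 3 <= 5 and 5 < 10
    f(12)    # False: fails the $lt test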
82 class BaseDB(LoggingConfigurable):
87 class BaseDB(LoggingConfigurable):
83 """Empty Parent class so traitlets work on DB."""
88 """Empty Parent class so traitlets work on DB."""
84 # base configurable traits:
89 # base configurable traits:
85 session = Unicode("")
90 session = Unicode("")
86
91
87 class DictDB(BaseDB):
92 class DictDB(BaseDB):
88 """Basic in-memory dict-based object for saving Task Records.
93 """Basic in-memory dict-based object for saving Task Records.
89
94
90 This is the first object to present the DB interface
95 This is the first object to present the DB interface
91 for logging tasks out of memory.
96 for logging tasks out of memory.
92
97
93 The interface is based on MongoDB, so adding a MongoDB
98 The interface is based on MongoDB, so adding a MongoDB
94 backend should be straightforward.
99 backend should be straightforward.
95 """
100 """
96
101
97 _records = Dict()
102 _records = Dict()
98
103
99 def _match_one(self, rec, tests):
104 def _match_one(self, rec, tests):
100 """Check if a specific record matches tests."""
105 """Check if a specific record matches tests."""
101 for key,test in tests.iteritems():
106 for key,test in tests.iteritems():
102 if not test(rec.get(key, None)):
107 if not test(rec.get(key, None)):
103 return False
108 return False
104 return True
109 return True
105
110
106 def _match(self, check):
111 def _match(self, check):
107 """Find all the matches for a check dict."""
112 """Find all the matches for a check dict."""
108 matches = []
113 matches = []
109 tests = {}
114 tests = {}
110 for k,v in check.iteritems():
115 for k,v in check.iteritems():
111 if isinstance(v, dict):
116 if isinstance(v, dict):
112 tests[k] = CompositeFilter(v)
117 tests[k] = CompositeFilter(v)
113 else:
118 else:
114 tests[k] = lambda o, v=v: o==v # bind v now; a plain closure would capture only the last loop value
119 tests[k] = lambda o, v=v: o==v # bind v now; a plain closure would capture only the last loop value
115
120
116 for rec in self._records.itervalues():
121 for rec in self._records.itervalues():
117 if self._match_one(rec, tests):
122 if self._match_one(rec, tests):
118 matches.append(rec)
123 matches.append(rec)
119 return matches
124 return matches
120
125
121 def _extract_subdict(self, rec, keys):
126 def _extract_subdict(self, rec, keys):
122 """extract subdict of keys"""
127 """extract subdict of keys"""
123 d = {}
128 d = {}
124 d['msg_id'] = rec['msg_id']
129 d['msg_id'] = rec['msg_id']
125 for key in keys:
130 for key in keys:
126 d[key] = rec[key]
131 d[key] = rec[key]
127 return d
132 return d
128
133
129 def add_record(self, msg_id, rec):
134 def add_record(self, msg_id, rec):
130 """Add a new Task Record, by msg_id."""
135 """Add a new Task Record, by msg_id."""
131 if msg_id in self._records:
136 if msg_id in self._records:
132 raise KeyError("Already have msg_id %r"%(msg_id))
137 raise KeyError("Already have msg_id %r"%(msg_id))
133 self._records[msg_id] = rec
138 self._records[msg_id] = rec
134
139
135 def get_record(self, msg_id):
140 def get_record(self, msg_id):
136 """Get a specific Task Record, by msg_id."""
141 """Get a specific Task Record, by msg_id."""
137 if msg_id not in self._records:
142 if msg_id not in self._records:
138 raise KeyError("No such msg_id %r"%(msg_id))
143 raise KeyError("No such msg_id %r"%(msg_id))
139 return self._records[msg_id]
144 return self._records[msg_id]
140
145
141 def update_record(self, msg_id, rec):
146 def update_record(self, msg_id, rec):
142 """Update the data in an existing record."""
147 """Update the data in an existing record."""
143 self._records[msg_id].update(rec)
148 self._records[msg_id].update(rec)
144
149
145 def drop_matching_records(self, check):
150 def drop_matching_records(self, check):
146 """Remove a record from the DB."""
151 """Remove a record from the DB."""
147 matches = self._match(check)
152 matches = self._match(check)
148 for m in matches:
153 for m in matches:
149 del self._records[m['msg_id']]
154 del self._records[m['msg_id']]
150
155
151 def drop_record(self, msg_id):
156 def drop_record(self, msg_id):
152 """Remove a record from the DB."""
157 """Remove a record from the DB."""
153 del self._records[msg_id]
158 del self._records[msg_id]
154
159
155
160
156 def find_records(self, check, keys=None):
161 def find_records(self, check, keys=None):
157 """Find records matching a query dict, optionally extracting subset of keys.
162 """Find records matching a query dict, optionally extracting subset of keys.
158
163
159 Returns dict keyed by msg_id of matching records.
164 Returns dict keyed by msg_id of matching records.
160
165
161 Parameters
166 Parameters
162 ----------
167 ----------
163
168
164 check: dict
169 check: dict
165 mongodb-style query argument
170 mongodb-style query argument
166 keys: list of strs [optional]
171 keys: list of strs [optional]
167 if specified, the subset of keys to extract. msg_id will *always* be
172 if specified, the subset of keys to extract. msg_id will *always* be
168 included.
173 included.
169 """
174 """
170 matches = self._match(check)
175 matches = self._match(check)
171 if keys:
176 if keys:
172 return [ self._extract_subdict(rec, keys) for rec in matches ]
177 return [ self._extract_subdict(rec, keys) for rec in matches ]
173 else:
178 else:
174 return matches
179 return matches
175
180
176
181
177 def get_history(self):
182 def get_history(self):
178 """get all msg_ids, ordered by time submitted."""
183 """get all msg_ids, ordered by time submitted."""
179 msg_ids = self._records.keys()
184 msg_ids = self._records.keys()
180 return sorted(msg_ids, key=lambda m: self._records[m]['submitted'])
185 return sorted(msg_ids, key=lambda m: self._records[m]['submitted'])
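A short sketch of querying a DictDB with the mongodb-style operators (the record contents are hypothetical and truncated to the relevant keys):

    from datetime import datetime, timedelta

    db = DictDB()
    db.add_record('m1', {'msg_id': 'm1', 'submitted': datetime.now(),
                         'completed': None, 'engine_uuid': None})

    pending = db.find_records({'completed': None})      # completed is None
    cutoff = datetime.now() - timedelta(hours=1)
    recent = db.find_records({'submitted': {'$gte': cutoff}},
                             keys=['submitted'])        # msg_id always included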
@@ -1,165 +1,169 b''
1 #!/usr/bin/env python
1 #!/usr/bin/env python
2 """
2 """
3 A multi-heart Heartbeat system using PUB and XREP sockets. Pings are sent out on the PUB,
3 A multi-heart Heartbeat system using PUB and XREP sockets. Pings are sent out on the PUB,
4 and hearts are tracked based on their XREQ identities.
4 and hearts are tracked based on their XREQ identities.
5
6 Authors:
7
8 * Min RK
5 """
9 """
6 #-----------------------------------------------------------------------------
10 #-----------------------------------------------------------------------------
7 # Copyright (C) 2010-2011 The IPython Development Team
11 # Copyright (C) 2010-2011 The IPython Development Team
8 #
12 #
9 # Distributed under the terms of the BSD License. The full license is in
13 # Distributed under the terms of the BSD License. The full license is in
10 # the file COPYING, distributed as part of this software.
14 # the file COPYING, distributed as part of this software.
11 #-----------------------------------------------------------------------------
15 #-----------------------------------------------------------------------------
12
16
13 from __future__ import print_function
17 from __future__ import print_function
14 import time
18 import time
15 import uuid
19 import uuid
16
20
17 import zmq
21 import zmq
18 from zmq.devices import ThreadDevice
22 from zmq.devices import ThreadDevice
19 from zmq.eventloop import ioloop, zmqstream
23 from zmq.eventloop import ioloop, zmqstream
20
24
21 from IPython.config.configurable import LoggingConfigurable
25 from IPython.config.configurable import LoggingConfigurable
22 from IPython.utils.traitlets import Set, Instance, CFloat
26 from IPython.utils.traitlets import Set, Instance, CFloat
23
27
24 class Heart(object):
28 class Heart(object):
25 """A basic heart object for responding to a HeartMonitor.
29 """A basic heart object for responding to a HeartMonitor.
26 This is a simple wrapper with defaults for the most common
30 This is a simple wrapper with defaults for the most common
27 Device model for responding to heartbeats.
31 Device model for responding to heartbeats.
28
32
29 It simply builds a threadsafe zmq.FORWARDER Device, defaulting to using
33 It simply builds a threadsafe zmq.FORWARDER Device, defaulting to using
30 SUB/XREQ for in/out.
34 SUB/XREQ for in/out.
31
35
32 You can specify the XREQ's IDENTITY via the optional heart_id argument."""
36 You can specify the XREQ's IDENTITY via the optional heart_id argument."""
33 device=None
37 device=None
34 id=None
38 id=None
35 def __init__(self, in_addr, out_addr, in_type=zmq.SUB, out_type=zmq.XREQ, heart_id=None):
39 def __init__(self, in_addr, out_addr, in_type=zmq.SUB, out_type=zmq.XREQ, heart_id=None):
36 self.device = ThreadDevice(zmq.FORWARDER, in_type, out_type)
40 self.device = ThreadDevice(zmq.FORWARDER, in_type, out_type)
37 self.device.daemon=True
41 self.device.daemon=True
38 self.device.connect_in(in_addr)
42 self.device.connect_in(in_addr)
39 self.device.connect_out(out_addr)
43 self.device.connect_out(out_addr)
40 if in_type == zmq.SUB:
44 if in_type == zmq.SUB:
41 self.device.setsockopt_in(zmq.SUBSCRIBE, "")
45 self.device.setsockopt_in(zmq.SUBSCRIBE, "")
42 if heart_id is None:
46 if heart_id is None:
43 heart_id = str(uuid.uuid4())
47 heart_id = str(uuid.uuid4())
44 self.device.setsockopt_out(zmq.IDENTITY, heart_id)
48 self.device.setsockopt_out(zmq.IDENTITY, heart_id)
45 self.id = heart_id
49 self.id = heart_id
46
50
47 def start(self):
51 def start(self):
48 return self.device.start()
52 return self.device.start()
49
53
50 class HeartMonitor(LoggingConfigurable):
54 class HeartMonitor(LoggingConfigurable):
51 """A basic HeartMonitor class
55 """A basic HeartMonitor class
52 pingstream: a PUB stream
56 pingstream: a PUB stream
53 pongstream: an XREP stream
57 pongstream: an XREP stream
54 period: the period of the heartbeat in milliseconds"""
58 period: the period of the heartbeat in milliseconds"""
55
59
56 period=CFloat(1000, config=True,
60 period=CFloat(1000, config=True,
57 help='The period (in ms) at which the Hub pings the engines for '
61 help='The period (in ms) at which the Hub pings the engines for '
58 'heartbeats [default: 1000]',
62 'heartbeats [default: 1000]',
59 )
63 )
60
64
61 pingstream=Instance('zmq.eventloop.zmqstream.ZMQStream')
65 pingstream=Instance('zmq.eventloop.zmqstream.ZMQStream')
62 pongstream=Instance('zmq.eventloop.zmqstream.ZMQStream')
66 pongstream=Instance('zmq.eventloop.zmqstream.ZMQStream')
63 loop = Instance('zmq.eventloop.ioloop.IOLoop')
67 loop = Instance('zmq.eventloop.ioloop.IOLoop')
64 def _loop_default(self):
68 def _loop_default(self):
65 return ioloop.IOLoop.instance()
69 return ioloop.IOLoop.instance()
66
70
67 # not settable:
71 # not settable:
68 hearts=Set()
72 hearts=Set()
69 responses=Set()
73 responses=Set()
70 on_probation=Set()
74 on_probation=Set()
71 last_ping=CFloat(0)
75 last_ping=CFloat(0)
72 _new_handlers = Set()
76 _new_handlers = Set()
73 _failure_handlers = Set()
77 _failure_handlers = Set()
74 lifetime = CFloat(0)
78 lifetime = CFloat(0)
75 tic = CFloat(0)
79 tic = CFloat(0)
76
80
77 def __init__(self, **kwargs):
81 def __init__(self, **kwargs):
78 super(HeartMonitor, self).__init__(**kwargs)
82 super(HeartMonitor, self).__init__(**kwargs)
79
83
80 self.pongstream.on_recv(self.handle_pong)
84 self.pongstream.on_recv(self.handle_pong)
81
85
82 def start(self):
86 def start(self):
83 self.caller = ioloop.PeriodicCallback(self.beat, self.period, self.loop)
87 self.caller = ioloop.PeriodicCallback(self.beat, self.period, self.loop)
84 self.caller.start()
88 self.caller.start()
85
89
86 def add_new_heart_handler(self, handler):
90 def add_new_heart_handler(self, handler):
87 """add a new handler for new hearts"""
91 """add a new handler for new hearts"""
88 self.log.debug("heartbeat::new_heart_handler: %s"%handler)
92 self.log.debug("heartbeat::new_heart_handler: %s"%handler)
89 self._new_handlers.add(handler)
93 self._new_handlers.add(handler)
90
94
91 def add_heart_failure_handler(self, handler):
95 def add_heart_failure_handler(self, handler):
92 """add a new handler for heart failure"""
96 """add a new handler for heart failure"""
93 self.log.debug("heartbeat::new heart failure handler: %s"%handler)
97 self.log.debug("heartbeat::new heart failure handler: %s"%handler)
94 self._failure_handlers.add(handler)
98 self._failure_handlers.add(handler)
95
99
96 def beat(self):
100 def beat(self):
97 self.pongstream.flush()
101 self.pongstream.flush()
98 self.last_ping = self.lifetime
102 self.last_ping = self.lifetime
99
103
100 toc = time.time()
104 toc = time.time()
101 self.lifetime += toc-self.tic
105 self.lifetime += toc-self.tic
102 self.tic = toc
106 self.tic = toc
103 # self.log.debug("heartbeat::%s"%self.lifetime)
107 # self.log.debug("heartbeat::%s"%self.lifetime)
104 goodhearts = self.hearts.intersection(self.responses)
108 goodhearts = self.hearts.intersection(self.responses)
105 missed_beats = self.hearts.difference(goodhearts)
109 missed_beats = self.hearts.difference(goodhearts)
106 heartfailures = self.on_probation.intersection(missed_beats)
110 heartfailures = self.on_probation.intersection(missed_beats)
107 newhearts = self.responses.difference(goodhearts)
111 newhearts = self.responses.difference(goodhearts)
108 map(self.handle_new_heart, newhearts)
112 map(self.handle_new_heart, newhearts)
109 map(self.handle_heart_failure, heartfailures)
113 map(self.handle_heart_failure, heartfailures)
110 self.on_probation = missed_beats.intersection(self.hearts)
114 self.on_probation = missed_beats.intersection(self.hearts)
111 self.responses = set()
115 self.responses = set()
112 # print self.on_probation, self.hearts
116 # print self.on_probation, self.hearts
113 # self.log.debug("heartbeat::beat %.3f, %i beating hearts"%(self.lifetime, len(self.hearts)))
117 # self.log.debug("heartbeat::beat %.3f, %i beating hearts"%(self.lifetime, len(self.hearts)))
114 self.pingstream.send(str(self.lifetime))
118 self.pingstream.send(str(self.lifetime))
115
119
116 def handle_new_heart(self, heart):
120 def handle_new_heart(self, heart):
117 if self._new_handlers:
121 if self._new_handlers:
118 for handler in self._new_handlers:
122 for handler in self._new_handlers:
119 handler(heart)
123 handler(heart)
120 else:
124 else:
121 self.log.info("heartbeat::yay, got new heart %s!"%heart)
125 self.log.info("heartbeat::yay, got new heart %s!"%heart)
122 self.hearts.add(heart)
126 self.hearts.add(heart)
123
127
124 def handle_heart_failure(self, heart):
128 def handle_heart_failure(self, heart):
125 if self._failure_handlers:
129 if self._failure_handlers:
126 for handler in self._failure_handlers:
130 for handler in self._failure_handlers:
127 try:
131 try:
128 handler(heart)
132 handler(heart)
129 except Exception:
133 except Exception:
130 self.log.error("heartbeat::Bad Handler! %s"%handler, exc_info=True)
134 self.log.error("heartbeat::Bad Handler! %s"%handler, exc_info=True)
132 else:
136 else:
133 self.log.info("heartbeat::Heart %s failed :("%heart)
137 self.log.info("heartbeat::Heart %s failed :("%heart)
134 self.hearts.remove(heart)
138 self.hearts.remove(heart)
135
139
136
140
137 def handle_pong(self, msg):
141 def handle_pong(self, msg):
138 "a heart just beat"
142 "a heart just beat"
139 if msg[1] == str(self.lifetime):
143 if msg[1] == str(self.lifetime):
140 delta = time.time()-self.tic
144 delta = time.time()-self.tic
141 # self.log.debug("heartbeat::heart %r took %.2f ms to respond"%(msg[0], 1000*delta))
145 # self.log.debug("heartbeat::heart %r took %.2f ms to respond"%(msg[0], 1000*delta))
142 self.responses.add(msg[0])
146 self.responses.add(msg[0])
143 elif msg[1] == str(self.last_ping):
147 elif msg[1] == str(self.last_ping):
144 delta = time.time()-self.tic + (self.lifetime-self.last_ping)
148 delta = time.time()-self.tic + (self.lifetime-self.last_ping)
145 self.log.warn("heartbeat::heart %r missed a beat, and took %.2f ms to respond"%(msg[0], 1000*delta))
149 self.log.warn("heartbeat::heart %r missed a beat, and took %.2f ms to respond"%(msg[0], 1000*delta))
146 self.responses.add(msg[0])
150 self.responses.add(msg[0])
147 else:
151 else:
148 self.log.warn("heartbeat::got bad heartbeat (possibly old?): %s (current=%.3f)"%
152 self.log.warn("heartbeat::got bad heartbeat (possibly old?): %s (current=%.3f)"%
149 (msg[1],self.lifetime))
153 (msg[1],self.lifetime))
150
154
151
155
152 if __name__ == '__main__':
156 if __name__ == '__main__':
153 loop = ioloop.IOLoop.instance()
157 loop = ioloop.IOLoop.instance()
154 context = zmq.Context()
158 context = zmq.Context()
155 pub = context.socket(zmq.PUB)
159 pub = context.socket(zmq.PUB)
156 pub.bind('tcp://127.0.0.1:5555')
160 pub.bind('tcp://127.0.0.1:5555')
157 xrep = context.socket(zmq.XREP)
161 xrep = context.socket(zmq.XREP)
158 xrep.bind('tcp://127.0.0.1:5556')
162 xrep.bind('tcp://127.0.0.1:5556')
159
163
160 outstream = zmqstream.ZMQStream(pub, loop)
164 outstream = zmqstream.ZMQStream(pub, loop)
161 instream = zmqstream.ZMQStream(xrep, loop)
165 instream = zmqstream.ZMQStream(xrep, loop)
162
166
163 hb = HeartMonitor(loop=loop, pingstream=outstream, pongstream=instream) # ctor takes keyword args only; call hb.start() to begin the beat
167 hb = HeartMonitor(loop=loop, pingstream=outstream, pongstream=instream) # ctor takes keyword args only; call hb.start() to begin the beat
164
168
165 loop.start()
169 loop.start()
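Extending the demo above, handlers can be attached before the loop starts; the callback names here are hypothetical:

    def note_new(heart):
        print("heartbeat::engine %s connected" % heart)

    def note_failure(heart):
        print("heartbeat::engine %s died" % heart)

    hb.add_new_heart_handler(note_new)
    hb.add_heart_failure_handler(note_failure)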
@@ -1,1284 +1,1288 b''
1 #!/usr/bin/env python
1 #!/usr/bin/env python
2 """The IPython Controller Hub with 0MQ
2 """The IPython Controller Hub with 0MQ
3 This is the master object that handles connections from engines and clients,
3 This is the master object that handles connections from engines and clients,
4 and monitors traffic through the various queues.
4 and monitors traffic through the various queues.
5
6 Authors:
7
8 * Min RK
5 """
9 """
6 #-----------------------------------------------------------------------------
10 #-----------------------------------------------------------------------------
7 # Copyright (C) 2010 The IPython Development Team
11 # Copyright (C) 2010 The IPython Development Team
8 #
12 #
9 # Distributed under the terms of the BSD License. The full license is in
13 # Distributed under the terms of the BSD License. The full license is in
10 # the file COPYING, distributed as part of this software.
14 # the file COPYING, distributed as part of this software.
11 #-----------------------------------------------------------------------------
15 #-----------------------------------------------------------------------------
12
16
13 #-----------------------------------------------------------------------------
17 #-----------------------------------------------------------------------------
14 # Imports
18 # Imports
15 #-----------------------------------------------------------------------------
19 #-----------------------------------------------------------------------------
16 from __future__ import print_function
20 from __future__ import print_function
17
21
18 import sys
22 import sys
19 import time
23 import time
20 from datetime import datetime
24 from datetime import datetime
21
25
22 import zmq
26 import zmq
23 from zmq.eventloop import ioloop
27 from zmq.eventloop import ioloop
24 from zmq.eventloop.zmqstream import ZMQStream
28 from zmq.eventloop.zmqstream import ZMQStream
25
29
26 # internal:
30 # internal:
27 from IPython.utils.importstring import import_item
31 from IPython.utils.importstring import import_item
28 from IPython.utils.traitlets import (
32 from IPython.utils.traitlets import (
29 HasTraits, Instance, Int, Unicode, Dict, Set, Tuple, CStr
33 HasTraits, Instance, Int, Unicode, Dict, Set, Tuple, CStr
30 )
34 )
31
35
32 from IPython.parallel import error, util
36 from IPython.parallel import error, util
33 from IPython.parallel.factory import RegistrationFactory
37 from IPython.parallel.factory import RegistrationFactory
34
38
35 from IPython.zmq.session import SessionFactory
39 from IPython.zmq.session import SessionFactory
36
40
37 from .heartmonitor import HeartMonitor
41 from .heartmonitor import HeartMonitor
38
42
39 #-----------------------------------------------------------------------------
43 #-----------------------------------------------------------------------------
40 # Code
44 # Code
41 #-----------------------------------------------------------------------------
45 #-----------------------------------------------------------------------------
42
46
43 def _passer(*args, **kwargs):
47 def _passer(*args, **kwargs):
44 return
48 return
45
49
46 def _printer(*args, **kwargs):
50 def _printer(*args, **kwargs):
47 print (args)
51 print (args)
48 print (kwargs)
52 print (kwargs)
49
53
50 def empty_record():
54 def empty_record():
51 """Return an empty dict with all record keys."""
55 """Return an empty dict with all record keys."""
52 return {
56 return {
53 'msg_id' : None,
57 'msg_id' : None,
54 'header' : None,
58 'header' : None,
55 'content': None,
59 'content': None,
56 'buffers': None,
60 'buffers': None,
57 'submitted': None,
61 'submitted': None,
58 'client_uuid' : None,
62 'client_uuid' : None,
59 'engine_uuid' : None,
63 'engine_uuid' : None,
60 'started': None,
64 'started': None,
61 'completed': None,
65 'completed': None,
62 'resubmitted': None,
66 'resubmitted': None,
63 'result_header' : None,
67 'result_header' : None,
64 'result_content' : None,
68 'result_content' : None,
65 'result_buffers' : None,
69 'result_buffers' : None,
66 'queue' : None,
70 'queue' : None,
67 'pyin' : None,
71 'pyin' : None,
68 'pyout': None,
72 'pyout': None,
69 'pyerr': None,
73 'pyerr': None,
70 'stdout': '',
74 'stdout': '',
71 'stderr': '',
75 'stderr': '',
72 }
76 }
73
77
74 def init_record(msg):
78 def init_record(msg):
75 """Initialize a TaskRecord based on a request."""
79 """Initialize a TaskRecord based on a request."""
76 header = msg['header']
80 header = msg['header']
77 return {
81 return {
78 'msg_id' : header['msg_id'],
82 'msg_id' : header['msg_id'],
79 'header' : header,
83 'header' : header,
80 'content': msg['content'],
84 'content': msg['content'],
81 'buffers': msg['buffers'],
85 'buffers': msg['buffers'],
82 'submitted': header['date'],
86 'submitted': header['date'],
83 'client_uuid' : None,
87 'client_uuid' : None,
84 'engine_uuid' : None,
88 'engine_uuid' : None,
85 'started': None,
89 'started': None,
86 'completed': None,
90 'completed': None,
87 'resubmitted': None,
91 'resubmitted': None,
88 'result_header' : None,
92 'result_header' : None,
89 'result_content' : None,
93 'result_content' : None,
90 'result_buffers' : None,
94 'result_buffers' : None,
91 'queue' : None,
95 'queue' : None,
92 'pyin' : None,
96 'pyin' : None,
93 'pyout': None,
97 'pyout': None,
94 'pyerr': None,
98 'pyerr': None,
95 'stdout': '',
99 'stdout': '',
96 'stderr': '',
100 'stderr': '',
97 }
101 }
98
102
99
103
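A quick sketch of `init_record` applied to a request message (the field values are made up):

    from datetime import datetime

    msg = {
        'header': {'msg_id': 'abc123', 'date': datetime(2011, 7, 1)},
        'content': {},
        'buffers': [],
    }
    rec = init_record(msg)
    rec['submitted']   # datetime(2011, 7, 1), copied from the header
    rec['completed']   # None until a result is saved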
100 class EngineConnector(HasTraits):
104 class EngineConnector(HasTraits):
101 """A simple object for accessing the various zmq connections of an object.
105 """A simple object for accessing the various zmq connections of an object.
102 Attributes are:
106 Attributes are:
103 id (int): engine ID
107 id (int): engine ID
104 uuid (str): uuid (unused?)
108 uuid (str): uuid (unused?)
105 queue (str): identity of queue's XREQ socket
109 queue (str): identity of queue's XREQ socket
106 registration (str): identity of registration XREQ socket
110 registration (str): identity of registration XREQ socket
107 heartbeat (str): identity of heartbeat XREQ socket
111 heartbeat (str): identity of heartbeat XREQ socket
108 """
112 """
109 id=Int(0)
113 id=Int(0)
110 queue=CStr()
114 queue=CStr()
111 control=CStr()
115 control=CStr()
112 registration=CStr()
116 registration=CStr()
113 heartbeat=CStr()
117 heartbeat=CStr()
114 pending=Set()
118 pending=Set()
115
119
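A hedged configuration sketch for the HubFactory defined just below; in practice the factory is driven by `ipcontroller`, and the port numbers here are arbitrary examples:

    from IPython.config.loader import Config

    cfg = Config()
    cfg.HubFactory.engine_ip = '0.0.0.0'   # listen beyond loopback
    cfg.HubFactory.hb = (10100, 10101)     # pin the heartbeat XREQ/SUB pair

    factory = HubFactory(config=cfg)
    factory.construct()                    # builds factory.hub
    factory.start()                        # starts the heart monitor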
116 class HubFactory(RegistrationFactory):
120 class HubFactory(RegistrationFactory):
117 """The Configurable for setting up a Hub."""
121 """The Configurable for setting up a Hub."""
118
122
119 # port-pairs for monitoredqueues:
123 # port-pairs for monitoredqueues:
120 hb = Tuple(Int,Int,config=True,
124 hb = Tuple(Int,Int,config=True,
121 help="""XREQ/SUB Port pair for Engine heartbeats""")
125 help="""XREQ/SUB Port pair for Engine heartbeats""")
122 def _hb_default(self):
126 def _hb_default(self):
123 return tuple(util.select_random_ports(2))
127 return tuple(util.select_random_ports(2))
124
128
125 mux = Tuple(Int,Int,config=True,
129 mux = Tuple(Int,Int,config=True,
126 help="""Engine/Client Port pair for MUX queue""")
130 help="""Engine/Client Port pair for MUX queue""")
127
131
128 def _mux_default(self):
132 def _mux_default(self):
129 return tuple(util.select_random_ports(2))
133 return tuple(util.select_random_ports(2))
130
134
131 task = Tuple(Int,Int,config=True,
135 task = Tuple(Int,Int,config=True,
132 help="""Engine/Client Port pair for Task queue""")
136 help="""Engine/Client Port pair for Task queue""")
133 def _task_default(self):
137 def _task_default(self):
134 return tuple(util.select_random_ports(2))
138 return tuple(util.select_random_ports(2))
135
139
136 control = Tuple(Int,Int,config=True,
140 control = Tuple(Int,Int,config=True,
137 help="""Engine/Client Port pair for Control queue""")
141 help="""Engine/Client Port pair for Control queue""")
138
142
139 def _control_default(self):
143 def _control_default(self):
140 return tuple(util.select_random_ports(2))
144 return tuple(util.select_random_ports(2))
141
145
142 iopub = Tuple(Int,Int,config=True,
146 iopub = Tuple(Int,Int,config=True,
143 help="""Engine/Client Port pair for IOPub relay""")
147 help="""Engine/Client Port pair for IOPub relay""")
144
148
145 def _iopub_default(self):
149 def _iopub_default(self):
146 return tuple(util.select_random_ports(2))
150 return tuple(util.select_random_ports(2))
147
151
148 # single ports:
152 # single ports:
149 mon_port = Int(config=True,
153 mon_port = Int(config=True,
150 help="""Monitor (SUB) port for queue traffic""")
154 help="""Monitor (SUB) port for queue traffic""")
151
155
152 def _mon_port_default(self):
156 def _mon_port_default(self):
153 return util.select_random_ports(1)[0]
157 return util.select_random_ports(1)[0]
154
158
155 notifier_port = Int(config=True,
159 notifier_port = Int(config=True,
156 help="""PUB port for sending engine status notifications""")
160 help="""PUB port for sending engine status notifications""")
157
161
158 def _notifier_port_default(self):
162 def _notifier_port_default(self):
159 return util.select_random_ports(1)[0]
163 return util.select_random_ports(1)[0]
160
164
161 engine_ip = Unicode('127.0.0.1', config=True,
165 engine_ip = Unicode('127.0.0.1', config=True,
162 help="IP on which to listen for engine connections. [default: loopback]")
166 help="IP on which to listen for engine connections. [default: loopback]")
163 engine_transport = Unicode('tcp', config=True,
167 engine_transport = Unicode('tcp', config=True,
164 help="0MQ transport for engine connections. [default: tcp]")
168 help="0MQ transport for engine connections. [default: tcp]")
165
169
166 client_ip = Unicode('127.0.0.1', config=True,
170 client_ip = Unicode('127.0.0.1', config=True,
167 help="IP on which to listen for client connections. [default: loopback]")
171 help="IP on which to listen for client connections. [default: loopback]")
168 client_transport = Unicode('tcp', config=True,
172 client_transport = Unicode('tcp', config=True,
169 help="0MQ transport for client connections. [default : tcp]")
173 help="0MQ transport for client connections. [default : tcp]")
170
174
171 monitor_ip = Unicode('127.0.0.1', config=True,
175 monitor_ip = Unicode('127.0.0.1', config=True,
172 help="IP on which to listen for monitor messages. [default: loopback]")
176 help="IP on which to listen for monitor messages. [default: loopback]")
173 monitor_transport = Unicode('tcp', config=True,
177 monitor_transport = Unicode('tcp', config=True,
174 help="0MQ transport for monitor messages. [default : tcp]")
178 help="0MQ transport for monitor messages. [default : tcp]")
175
179
176 monitor_url = Unicode('')
180 monitor_url = Unicode('')
177
181
178 db_class = Unicode('IPython.parallel.controller.dictdb.DictDB', config=True,
182 db_class = Unicode('IPython.parallel.controller.dictdb.DictDB', config=True,
179 help="""The class to use for the DB backend""")
183 help="""The class to use for the DB backend""")
180
184
181 # not configurable
185 # not configurable
182 db = Instance('IPython.parallel.controller.dictdb.BaseDB')
186 db = Instance('IPython.parallel.controller.dictdb.BaseDB')
183 heartmonitor = Instance('IPython.parallel.controller.heartmonitor.HeartMonitor')
187 heartmonitor = Instance('IPython.parallel.controller.heartmonitor.HeartMonitor')
184
188
185 def _ip_changed(self, name, old, new):
189 def _ip_changed(self, name, old, new):
186 self.engine_ip = new
190 self.engine_ip = new
187 self.client_ip = new
191 self.client_ip = new
188 self.monitor_ip = new
192 self.monitor_ip = new
189 self._update_monitor_url()
193 self._update_monitor_url()
190
194
191 def _update_monitor_url(self):
195 def _update_monitor_url(self):
192 self.monitor_url = "%s://%s:%i"%(self.monitor_transport, self.monitor_ip, self.mon_port)
196 self.monitor_url = "%s://%s:%i"%(self.monitor_transport, self.monitor_ip, self.mon_port)
193
197
194 def _transport_changed(self, name, old, new):
198 def _transport_changed(self, name, old, new):
195 self.engine_transport = new
199 self.engine_transport = new
196 self.client_transport = new
200 self.client_transport = new
197 self.monitor_transport = new
201 self.monitor_transport = new
198 self._update_monitor_url()
202 self._update_monitor_url()
199
203
200 def __init__(self, **kwargs):
204 def __init__(self, **kwargs):
201 super(HubFactory, self).__init__(**kwargs)
205 super(HubFactory, self).__init__(**kwargs)
202 self._update_monitor_url()
206 self._update_monitor_url()
203
207
204
208
205 def construct(self):
209 def construct(self):
206 self.init_hub()
210 self.init_hub()
207
211
208 def start(self):
212 def start(self):
209 self.heartmonitor.start()
213 self.heartmonitor.start()
210 self.log.info("Heartmonitor started")
214 self.log.info("Heartmonitor started")
211
215
212 def init_hub(self):
216 def init_hub(self):
213 """construct"""
217 """construct"""
214 client_iface = "%s://%s:"%(self.client_transport, self.client_ip) + "%i"
218 client_iface = "%s://%s:"%(self.client_transport, self.client_ip) + "%i"
215 engine_iface = "%s://%s:"%(self.engine_transport, self.engine_ip) + "%i"
219 engine_iface = "%s://%s:"%(self.engine_transport, self.engine_ip) + "%i"
216
220
217 ctx = self.context
221 ctx = self.context
218 loop = self.loop
222 loop = self.loop
219
223
220 # Registrar socket
224 # Registrar socket
221 q = ZMQStream(ctx.socket(zmq.XREP), loop)
225 q = ZMQStream(ctx.socket(zmq.XREP), loop)
222 q.bind(client_iface % self.regport)
226 q.bind(client_iface % self.regport)
223 self.log.info("Hub listening on %s for registration."%(client_iface%self.regport))
227 self.log.info("Hub listening on %s for registration."%(client_iface%self.regport))
224 if self.client_ip != self.engine_ip:
228 if self.client_ip != self.engine_ip:
225 q.bind(engine_iface % self.regport)
229 q.bind(engine_iface % self.regport)
226 self.log.info("Hub listening on %s for registration."%(engine_iface%self.regport))
230 self.log.info("Hub listening on %s for registration."%(engine_iface%self.regport))
227
231
228 ### Engine connections ###
232 ### Engine connections ###
229
233
230 # heartbeat
234 # heartbeat
231 hpub = ctx.socket(zmq.PUB)
235 hpub = ctx.socket(zmq.PUB)
232 hpub.bind(engine_iface % self.hb[0])
236 hpub.bind(engine_iface % self.hb[0])
233 hrep = ctx.socket(zmq.XREP)
237 hrep = ctx.socket(zmq.XREP)
234 hrep.bind(engine_iface % self.hb[1])
238 hrep.bind(engine_iface % self.hb[1])
235 self.heartmonitor = HeartMonitor(loop=loop, config=self.config, log=self.log,
239 self.heartmonitor = HeartMonitor(loop=loop, config=self.config, log=self.log,
236 pingstream=ZMQStream(hpub,loop),
240 pingstream=ZMQStream(hpub,loop),
237 pongstream=ZMQStream(hrep,loop)
241 pongstream=ZMQStream(hrep,loop)
238 )
242 )
239
243
240 ### Client connections ###
244 ### Client connections ###
241 # Notifier socket
245 # Notifier socket
242 n = ZMQStream(ctx.socket(zmq.PUB), loop)
246 n = ZMQStream(ctx.socket(zmq.PUB), loop)
243 n.bind(client_iface%self.notifier_port)
247 n.bind(client_iface%self.notifier_port)
244
248
245 ### build and launch the queues ###
249 ### build and launch the queues ###
246
250
247 # monitor socket
251 # monitor socket
248 sub = ctx.socket(zmq.SUB)
252 sub = ctx.socket(zmq.SUB)
249 sub.setsockopt(zmq.SUBSCRIBE, "")
253 sub.setsockopt(zmq.SUBSCRIBE, "")
250 sub.bind(self.monitor_url)
254 sub.bind(self.monitor_url)
251 sub.bind('inproc://monitor')
255 sub.bind('inproc://monitor')
252 sub = ZMQStream(sub, loop)
256 sub = ZMQStream(sub, loop)
253
257
254 # connect the db
258 # connect the db
255 self.log.info('Hub using DB backend: %r'%(self.db_class.split()[-1]))
259 self.log.info('Hub using DB backend: %r'%(self.db_class.split()[-1]))
256 # cdir = self.config.Global.cluster_dir
260 # cdir = self.config.Global.cluster_dir
257 self.db = import_item(str(self.db_class))(session=self.session.session,
261 self.db = import_item(str(self.db_class))(session=self.session.session,
258 config=self.config, log=self.log)
262 config=self.config, log=self.log)
259 time.sleep(.25)
263 time.sleep(.25)
260 try:
264 try:
261 scheme = self.config.TaskScheduler.scheme_name
265 scheme = self.config.TaskScheduler.scheme_name
262 except AttributeError:
266 except AttributeError:
263 from .scheduler import TaskScheduler
267 from .scheduler import TaskScheduler
264 scheme = TaskScheduler.scheme_name.get_default_value()
268 scheme = TaskScheduler.scheme_name.get_default_value()
265 # build connection dicts
269 # build connection dicts
266 self.engine_info = {
270 self.engine_info = {
267 'control' : engine_iface%self.control[1],
271 'control' : engine_iface%self.control[1],
268 'mux': engine_iface%self.mux[1],
272 'mux': engine_iface%self.mux[1],
269 'heartbeat': (engine_iface%self.hb[0], engine_iface%self.hb[1]),
273 'heartbeat': (engine_iface%self.hb[0], engine_iface%self.hb[1]),
270 'task' : engine_iface%self.task[1],
274 'task' : engine_iface%self.task[1],
271 'iopub' : engine_iface%self.iopub[1],
275 'iopub' : engine_iface%self.iopub[1],
272 # 'monitor' : engine_iface%self.mon_port,
276 # 'monitor' : engine_iface%self.mon_port,
273 }
277 }
274
278
275 self.client_info = {
279 self.client_info = {
276 'control' : client_iface%self.control[0],
280 'control' : client_iface%self.control[0],
277 'mux': client_iface%self.mux[0],
281 'mux': client_iface%self.mux[0],
278 'task' : (scheme, client_iface%self.task[0]),
282 'task' : (scheme, client_iface%self.task[0]),
279 'iopub' : client_iface%self.iopub[0],
283 'iopub' : client_iface%self.iopub[0],
280 'notification': client_iface%self.notifier_port
284 'notification': client_iface%self.notifier_port
281 }
285 }
282 self.log.debug("Hub engine addrs: %s"%self.engine_info)
286 self.log.debug("Hub engine addrs: %s"%self.engine_info)
283 self.log.debug("Hub client addrs: %s"%self.client_info)
287 self.log.debug("Hub client addrs: %s"%self.client_info)
284
288
285 # resubmit stream
289 # resubmit stream
286 r = ZMQStream(ctx.socket(zmq.XREQ), loop)
290 r = ZMQStream(ctx.socket(zmq.XREQ), loop)
287 url = util.disambiguate_url(self.client_info['task'][-1])
291 url = util.disambiguate_url(self.client_info['task'][-1])
288 r.setsockopt(zmq.IDENTITY, self.session.session)
292 r.setsockopt(zmq.IDENTITY, self.session.session)
289 r.connect(url)
293 r.connect(url)
290
294
291 self.hub = Hub(loop=loop, session=self.session, monitor=sub, heartmonitor=self.heartmonitor,
295 self.hub = Hub(loop=loop, session=self.session, monitor=sub, heartmonitor=self.heartmonitor,
292 query=q, notifier=n, resubmit=r, db=self.db,
296 query=q, notifier=n, resubmit=r, db=self.db,
293 engine_info=self.engine_info, client_info=self.client_info,
297 engine_info=self.engine_info, client_info=self.client_info,
294 log=self.log)
298 log=self.log)
295
299
296
300
297 class Hub(SessionFactory):
301 class Hub(SessionFactory):
298 """The IPython Controller Hub with 0MQ connections
302 """The IPython Controller Hub with 0MQ connections
299
303
300 Parameters
304 Parameters
301 ==========
305 ==========
302 loop: zmq IOLoop instance
306 loop: zmq IOLoop instance
303 session: Session object
307 session: Session object
304 <removed> context: zmq context for creating new connections (?)
308 <removed> context: zmq context for creating new connections (?)
305 queue: ZMQStream for monitoring the command queue (SUB)
309 queue: ZMQStream for monitoring the command queue (SUB)
306 query: ZMQStream for engine registration and client query requests (XREP)
310 query: ZMQStream for engine registration and client query requests (XREP)
307 heartbeat: HeartMonitor object checking the pulse of the engines
311 heartbeat: HeartMonitor object checking the pulse of the engines
308 notifier: ZMQStream for broadcasting engine registration changes (PUB)
312 notifier: ZMQStream for broadcasting engine registration changes (PUB)
309 db: connection to db for out of memory logging of commands
313 db: connection to db for out of memory logging of commands
310 NotImplemented
314 NotImplemented
311 engine_info: dict of zmq connection information for engines to connect
315 engine_info: dict of zmq connection information for engines to connect
312 to the queues.
316 to the queues.
313 client_info: dict of zmq connection information for clients to connect
317 client_info: dict of zmq connection information for clients to connect
314 to the queues.
318 to the queues.
315 """
319 """
316 # internal data structures:
320 # internal data structures:
317 ids=Set() # engine IDs
321 ids=Set() # engine IDs
318 keytable=Dict()
322 keytable=Dict()
319 by_ident=Dict()
323 by_ident=Dict()
320 engines=Dict()
324 engines=Dict()
321 clients=Dict()
325 clients=Dict()
322 hearts=Dict()
326 hearts=Dict()
323 pending=Set()
327 pending=Set()
324 queues=Dict() # pending msg_ids keyed by engine_id
328 queues=Dict() # pending msg_ids keyed by engine_id
325 tasks=Dict() # pending msg_ids submitted as tasks, keyed by client_id
329 tasks=Dict() # pending msg_ids submitted as tasks, keyed by client_id
326 completed=Dict() # completed msg_ids keyed by engine_id
330 completed=Dict() # completed msg_ids keyed by engine_id
327 all_completed=Set() # set of all completed msg_ids
331 all_completed=Set() # set of all completed msg_ids
328 dead_engines=Set() # set of uuids of engines that have died
332 dead_engines=Set() # set of uuids of engines that have died
329 unassigned=Set() # set of task msg_ids not yet assigned a destination
333 unassigned=Set() # set of task msg_ids not yet assigned a destination
330 incoming_registrations=Dict()
334 incoming_registrations=Dict()
331 registration_timeout=Int()
335 registration_timeout=Int()
332 _idcounter=Int(0)
336 _idcounter=Int(0)
333
337
334 # objects from constructor:
338 # objects from constructor:
335 query=Instance(ZMQStream)
339 query=Instance(ZMQStream)
336 monitor=Instance(ZMQStream)
340 monitor=Instance(ZMQStream)
337 notifier=Instance(ZMQStream)
341 notifier=Instance(ZMQStream)
338 resubmit=Instance(ZMQStream)
342 resubmit=Instance(ZMQStream)
339 heartmonitor=Instance(HeartMonitor)
343 heartmonitor=Instance(HeartMonitor)
340 db=Instance(object)
344 db=Instance(object)
341 client_info=Dict()
345 client_info=Dict()
342 engine_info=Dict()
346 engine_info=Dict()
343
347
344
348
    def __init__(self, **kwargs):
        """
        # universal:
        loop: IOLoop for creating future connections
        session: streamsession for sending serialized data
        # engine:
        queue: ZMQStream for monitoring queue messages
        query: ZMQStream for engine+client registration and client requests
        heartbeat: HeartMonitor object for tracking engines
        # extra:
        db: ZMQStream for db connection (NotImplemented)
        engine_info: zmq address/protocol dict for engine connections
        client_info: zmq address/protocol dict for client connections
        """

        super(Hub, self).__init__(**kwargs)
        self.registration_timeout = max(5000, 2*self.heartmonitor.period)

        # validate connection dicts:
        for k,v in self.client_info.iteritems():
            if k == 'task':
                util.validate_url_container(v[1])
            else:
                util.validate_url_container(v)
        # util.validate_url_container(self.client_info)
        util.validate_url_container(self.engine_info)

        # register our callbacks
        self.query.on_recv(self.dispatch_query)
        self.monitor.on_recv(self.dispatch_monitor_traffic)

        self.heartmonitor.add_heart_failure_handler(self.handle_heart_failure)
        self.heartmonitor.add_new_heart_handler(self.handle_new_heart)

        self.monitor_handlers = {'in': self.save_queue_request,
                                 'out': self.save_queue_result,
                                 'intask': self.save_task_request,
                                 'outtask': self.save_task_result,
                                 'tracktask': self.save_task_destination,
                                 'incontrol': _passer,
                                 'outcontrol': _passer,
                                 'iopub': self.save_iopub_message,
                                 }

        self.query_handlers = {'queue_request': self.queue_status,
                               'result_request': self.get_results,
                               'history_request': self.get_history,
                               'db_request': self.db_query,
                               'purge_request': self.purge_results,
                               'load_request': self.check_load,
                               'resubmit_request': self.resubmit_task,
                               'shutdown_request': self.shutdown_request,
                               'registration_request': self.register_engine,
                               'unregistration_request': self.unregister_engine,
                               'connection_request': self.connection_request,
                               }

        # ignore resubmit replies
        self.resubmit.on_recv(lambda msg: None, copy=False)

        self.log.info("hub::created hub")

    @property
    def _next_id(self):
        """generate a new ID.

        No longer reuse old ids, just count from 0."""
        newid = self._idcounter
        self._idcounter += 1
        return newid
        # newid = 0
        # incoming = [id[0] for id in self.incoming_registrations.itervalues()]
        # # print newid, self.ids, self.incoming_registrations
        # while newid in self.ids or newid in incoming:
        #     newid += 1
        # return newid

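The commented-out block above is the old scheme, which scanned for the lowest free id. A small sketch (hypothetical Counter class) of why the monotonic counter is preferable: ids are never recycled, so a stale reference to a dead engine can never silently point at a newer one.

```python
class Counter(object):
    """Minimal stand-in for the _idcounter/_next_id pair above."""
    def __init__(self):
        self._idcounter = 0

    @property
    def _next_id(self):
        newid = self._idcounter
        self._idcounter += 1
        return newid

c = Counter()
# ids simply count up, even if engine 0 has long since unregistered:
assert [c._next_id for _ in range(3)] == [0, 1, 2]
```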
    #-----------------------------------------------------------------------------
    # message validation
    #-----------------------------------------------------------------------------

    def _validate_targets(self, targets):
        """turn any valid targets argument into a list of integer ids"""
        if targets is None:
            # default to all
            targets = self.ids

        if isinstance(targets, (int,str,unicode)):
            # only one target specified
            targets = [targets]
        _targets = []
        for t in targets:
            # map raw identities to ids
            if isinstance(t, (str,unicode)):
                t = self.by_ident.get(t, t)
            _targets.append(t)
        targets = _targets
        bad_targets = [ t for t in targets if t not in self.ids ]
        if bad_targets:
            raise IndexError("No Such Engine: %r"%bad_targets)
        if not targets:
            raise IndexError("No Engines Registered")
        return targets

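For illustration, the same normalization as a standalone function (py2's `unicode` check dropped for brevity; `ids` and `by_ident` here are hypothetical stand-ins for the Hub's tables):

```python
ids = set([0, 1, 2])
by_ident = {'queue-uuid-1': 1}

def validate_targets(targets):
    if targets is None:
        targets = ids                      # default to all registered engines
    if isinstance(targets, (int, str)):
        targets = [targets]                # a single target becomes a list
    targets = [by_ident.get(t, t) for t in targets]  # uuids -> integer ids
    bad = [t for t in targets if t not in ids]
    if bad:
        raise IndexError("No Such Engine: %r" % bad)
    if not targets:
        raise IndexError("No Engines Registered")
    return targets

assert validate_targets('queue-uuid-1') == [1]
assert sorted(validate_targets(None)) == [0, 1, 2]
```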
    #-----------------------------------------------------------------------------
    # dispatch methods (1 per stream)
    #-----------------------------------------------------------------------------


    def dispatch_monitor_traffic(self, msg):
        """all ME and Task queue messages come through here, as well as
        IOPub traffic."""
        self.log.debug("monitor traffic: %r"%msg[:2])
        switch = msg[0]
        try:
            idents, msg = self.session.feed_identities(msg[1:])
        except ValueError:
            idents = []
        if not idents:
            self.log.error("Bad Monitor Message: %r"%msg)
            return
        handler = self.monitor_handlers.get(switch, None)
        if handler is not None:
            handler(idents, msg)
        else:
            self.log.error("Invalid monitor topic: %r"%switch)

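A reduced model of this dispatch, with toy stand-ins: frame 0 names the topic, and the handler table turns the topic into a handler; unknown topics are logged rather than raising. (In the real code the identity/message split is done by Session.feed_identities, which is simplified away here.)

```python
def dispatch(msg, handlers, log_error):
    switch = msg[0]                  # topic frame selects the handler
    frames = msg[1:]                 # identities + serialized message
    handler = handlers.get(switch, None)
    if handler is not None:
        handler(frames)
    else:
        log_error("Invalid monitor topic: %r" % switch)

seen = []
dispatch(['intask', 'client-1', '<header>', '<content>'],
         {'intask': seen.append},
         lambda s: None)
assert seen == [['client-1', '<header>', '<content>']]
```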
    def dispatch_query(self, msg):
        """Route registration requests and queries from clients."""
        try:
            idents, msg = self.session.feed_identities(msg)
        except ValueError:
            idents = []
        if not idents:
            self.log.error("Bad Query Message: %r"%msg)
            return
        client_id = idents[0]
        try:
            msg = self.session.unpack_message(msg, content=True)
        except Exception:
            content = error.wrap_exception()
            self.log.error("Bad Query Message: %r"%msg, exc_info=True)
            self.session.send(self.query, "hub_error", ident=client_id,
                    content=content)
            return
        # print client_id, header, parent, content
        # switch on message type:
        msg_type = msg['msg_type']
        self.log.info("client::client %r requested %r"%(client_id, msg_type))
        handler = self.query_handlers.get(msg_type, None)
        try:
            assert handler is not None, "Bad Message Type: %r"%msg_type
        except:
            content = error.wrap_exception()
            self.log.error("Bad Message Type: %r"%msg_type, exc_info=True)
            self.session.send(self.query, "hub_error", ident=client_id,
                    content=content)
            return
        else:
            handler(idents, msg)

    def dispatch_db(self, msg):
        """"""
        raise NotImplementedError

    #---------------------------------------------------------------------------
    # handler methods (1 per event)
    #---------------------------------------------------------------------------

    #----------------------- Heartbeat --------------------------------------

    def handle_new_heart(self, heart):
        """handler to attach to heartbeater.
        Called when a new heart starts to beat.
        Triggers completion of registration."""
        self.log.debug("heartbeat::handle_new_heart(%r)"%heart)
        if heart not in self.incoming_registrations:
            self.log.info("heartbeat::ignoring new heart: %r"%heart)
        else:
            self.finish_registration(heart)


    def handle_heart_failure(self, heart):
        """handler to attach to heartbeater.
        called when a previously registered heart fails to respond to beat request.
        triggers unregistration"""
        self.log.debug("heartbeat::handle_heart_failure(%r)"%heart)
        eid = self.hearts.get(heart, None)
        if eid is None:
            self.log.info("heartbeat::ignoring heart failure %r"%heart)
        else:
            # only look up the engine's queue uuid once we know eid is valid
            queue = self.engines[eid].queue
            self.unregister_engine(heart, dict(content=dict(id=eid, queue=queue)))

    #----------------------- MUX Queue Traffic ------------------------------

    def save_queue_request(self, idents, msg):
        if len(idents) < 2:
            self.log.error("invalid identity prefix: %r"%idents)
            return
        queue_id, client_id = idents[:2]
        try:
            msg = self.session.unpack_message(msg)
        except Exception:
            self.log.error("queue::client %r sent invalid message to %r: %r"%(client_id, queue_id, msg), exc_info=True)
            return

        eid = self.by_ident.get(queue_id, None)
        if eid is None:
            self.log.error("queue::target %r not registered"%queue_id)
            self.log.debug("queue:: valid are: %r"%(self.by_ident.keys()))
            return
        record = init_record(msg)
        msg_id = record['msg_id']
        record['engine_uuid'] = queue_id
        record['client_uuid'] = client_id
        record['queue'] = 'mux'

        try:
            # it's possible iopub arrived first:
            existing = self.db.get_record(msg_id)
            for key,evalue in existing.iteritems():
                rvalue = record.get(key, None)
                if evalue and rvalue and evalue != rvalue:
                    self.log.warn("conflicting initial state for record: %r:%r <%r> %r"%(msg_id, rvalue, key, evalue))
                elif evalue and not rvalue:
                    record[key] = evalue
            try:
                self.db.update_record(msg_id, record)
            except Exception:
                self.log.error("DB Error updating record %r"%msg_id, exc_info=True)
        except KeyError:
            try:
                self.db.add_record(msg_id, record)
            except Exception:
                self.log.error("DB Error adding record %r"%msg_id, exc_info=True)

        self.pending.add(msg_id)
        self.queues[eid].append(msg_id)

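The merge rule used above, in isolation: values present only in the existing record (e.g. when iopub output arrived before the request itself) are copied in, while conflicting non-empty values keep the new record's version and are logged. A hedged sketch with a hypothetical free function:

```python
def merge_records(existing, record, warn):
    for key, evalue in existing.items():
        rvalue = record.get(key, None)
        if evalue and rvalue and evalue != rvalue:
            warn(key)               # conflict: keep record's value, but warn
        elif evalue and not rvalue:
            record[key] = evalue    # fill gaps from the existing record
    return record

conflicts = []
merged = merge_records({'stdout': 'early output', 'queue': 'task'},
                       {'queue': 'mux'}, conflicts.append)
assert merged == {'queue': 'mux', 'stdout': 'early output'}
assert conflicts == ['queue']
```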
    def save_queue_result(self, idents, msg):
        if len(idents) < 2:
            self.log.error("invalid identity prefix: %r"%idents)
            return

        client_id, queue_id = idents[:2]
        try:
            msg = self.session.unpack_message(msg)
        except Exception:
            self.log.error("queue::engine %r sent invalid message to %r: %r"%(
                    queue_id, client_id, msg), exc_info=True)
            return

        eid = self.by_ident.get(queue_id, None)
        if eid is None:
            self.log.error("queue::unknown engine %r is sending a reply: "%queue_id)
            return

        parent = msg['parent_header']
        if not parent:
            return
        msg_id = parent['msg_id']
        if msg_id in self.pending:
            self.pending.remove(msg_id)
            self.all_completed.add(msg_id)
            self.queues[eid].remove(msg_id)
            self.completed[eid].append(msg_id)
        elif msg_id not in self.all_completed:
            # it could be a result from a dead engine that died before delivering the
            # result
            self.log.warn("queue:: unknown msg finished %r"%msg_id)
            return
        # update record anyway, because the unregistration could have been premature
        rheader = msg['header']
        completed = rheader['date']
        started = rheader.get('started', None)
        result = {
            'result_header' : rheader,
            'result_content': msg['content'],
            'started' : started,
            'completed' : completed
        }

        result['result_buffers'] = msg['buffers']
        try:
            self.db.update_record(msg_id, result)
        except Exception:
            self.log.error("DB Error updating record %r"%msg_id, exc_info=True)

    #--------------------- Task Queue Traffic ------------------------------

    def save_task_request(self, idents, msg):
        """Save the submission of a task."""
        client_id = idents[0]

        try:
            msg = self.session.unpack_message(msg)
        except Exception:
            self.log.error("task::client %r sent invalid task message: %r"%(
                    client_id, msg), exc_info=True)
            return
        record = init_record(msg)

        record['client_uuid'] = client_id
        record['queue'] = 'task'
        header = msg['header']
        msg_id = header['msg_id']
        self.pending.add(msg_id)
        self.unassigned.add(msg_id)
        try:
            # it's possible iopub arrived first:
            existing = self.db.get_record(msg_id)
            if existing['resubmitted']:
                for key in ('submitted', 'client_uuid', 'buffers'):
                    # don't clobber these keys on resubmit
                    # submitted and client_uuid should be different
                    # and buffers might be big, and shouldn't have changed
                    record.pop(key)
                # still check content,header which should not change
                # but are not as expensive to compare as buffers

            for key,evalue in existing.iteritems():
                if key.endswith('buffers'):
                    # don't compare buffers
                    continue
                rvalue = record.get(key, None)
                if evalue and rvalue and evalue != rvalue:
                    self.log.warn("conflicting initial state for record: %r:%r <%r> %r"%(msg_id, rvalue, key, evalue))
                elif evalue and not rvalue:
                    record[key] = evalue
            try:
                self.db.update_record(msg_id, record)
            except Exception:
                self.log.error("DB Error updating record %r"%msg_id, exc_info=True)
        except KeyError:
            try:
                self.db.add_record(msg_id, record)
            except Exception:
                self.log.error("DB Error adding record %r"%msg_id, exc_info=True)
        except Exception:
            self.log.error("DB Error saving task request %r"%msg_id, exc_info=True)

    def save_task_result(self, idents, msg):
        """save the result of a completed task."""
        client_id = idents[0]
        try:
            msg = self.session.unpack_message(msg)
        except Exception:
            self.log.error("task::invalid task result message sent to %r: %r"%(
                    client_id, msg), exc_info=True)
            return

        parent = msg['parent_header']
        if not parent:
            # print msg
            self.log.warn("Task %r had no parent!"%msg)
            return
        msg_id = parent['msg_id']
        if msg_id in self.unassigned:
            self.unassigned.remove(msg_id)

        header = msg['header']
        engine_uuid = header.get('engine', None)
        eid = self.by_ident.get(engine_uuid, None)

        if msg_id in self.pending:
            self.pending.remove(msg_id)
            self.all_completed.add(msg_id)
            if eid is not None:
                self.completed[eid].append(msg_id)
                if msg_id in self.tasks[eid]:
                    self.tasks[eid].remove(msg_id)
            completed = header['date']
            started = header.get('started', None)
            result = {
                'result_header' : header,
                'result_content': msg['content'],
                'started' : started,
                'completed' : completed,
                'engine_uuid': engine_uuid
            }

            result['result_buffers'] = msg['buffers']
            try:
                self.db.update_record(msg_id, result)
            except Exception:
                self.log.error("DB Error saving task result %r"%msg_id, exc_info=True)

        else:
            self.log.debug("task::unknown task %r finished"%msg_id)

    def save_task_destination(self, idents, msg):
        try:
            msg = self.session.unpack_message(msg, content=True)
        except Exception:
            self.log.error("task::invalid task tracking message", exc_info=True)
            return
        content = msg['content']
        # print (content)
        msg_id = content['msg_id']
        engine_uuid = content['engine_id']
        eid = self.by_ident[engine_uuid]

        self.log.info("task::task %r arrived on %r"%(msg_id, eid))
        if msg_id in self.unassigned:
            self.unassigned.remove(msg_id)
        # else:
        #     self.log.debug("task::task %r not listed as MIA?!"%(msg_id))

        self.tasks[eid].append(msg_id)
        # self.pending[msg_id][1].update(received=datetime.now(),engine=(eid,engine_uuid))
        try:
            self.db.update_record(msg_id, dict(engine_uuid=engine_uuid))
        except Exception:
            self.log.error("DB Error saving task destination %r"%msg_id, exc_info=True)


    def mia_task_request(self, idents, msg):
        raise NotImplementedError
        client_id = idents[0]
        # content = dict(mia=self.mia,status='ok')
        # self.session.send('mia_reply', content=content, idents=client_id)

    #--------------------- IOPub Traffic ------------------------------

    def save_iopub_message(self, topics, msg):
        """save an iopub message into the db"""
        # print (topics)
        try:
            msg = self.session.unpack_message(msg, content=True)
        except Exception:
            self.log.error("iopub::invalid IOPub message", exc_info=True)
            return

        parent = msg['parent_header']
        if not parent:
            self.log.error("iopub::invalid IOPub message: %r"%msg)
            return
        msg_id = parent['msg_id']
        msg_type = msg['msg_type']
        content = msg['content']

        # ensure msg_id is in db
        try:
            rec = self.db.get_record(msg_id)
        except KeyError:
            rec = empty_record()
            rec['msg_id'] = msg_id
            self.db.add_record(msg_id, rec)
        # stream
        d = {}
        if msg_type == 'stream':
            name = content['name']
            s = rec[name] or ''
            d[name] = s + content['data']
        elif msg_type == 'pyerr':
            d['pyerr'] = content
        elif msg_type == 'pyin':
            d['pyin'] = content['code']
        else:
            d[msg_type] = content.get('data', '')

        try:
            self.db.update_record(msg_id, d)
        except Exception:
            self.log.error("DB Error saving iopub message %r"%msg_id, exc_info=True)



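How the 'stream' branch accumulates output: each iopub stream message appends its data to whatever is already stored for that channel in the record. A toy version over a plain dict (names hypothetical):

```python
rec = {'stdout': None, 'stderr': None}

def save_stream(rec, content):
    name = content['name']
    s = rec[name] or ''              # None on first write -> start from ''
    rec[name] = s + content['data']

save_stream(rec, {'name': 'stdout', 'data': 'hello '})
save_stream(rec, {'name': 'stdout', 'data': 'world'})
assert rec['stdout'] == 'hello world'
```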
    #-------------------------------------------------------------------------
    # Registration requests
    #-------------------------------------------------------------------------

    def connection_request(self, client_id, msg):
        """Reply with connection addresses for clients."""
        self.log.info("client::client %r connected"%client_id)
        content = dict(status='ok')
        content.update(self.client_info)
        jsonable = {}
        for k,v in self.keytable.iteritems():
            if v not in self.dead_engines:
                jsonable[str(k)] = v
        content['engines'] = jsonable
        self.session.send(self.query, 'connection_reply', content, parent=msg, ident=client_id)

    def register_engine(self, reg, msg):
        """Register a new engine, and create the socket(s) necessary."""
        content = msg['content']
        try:
            queue = content['queue']
        except KeyError:
            self.log.error("registration::queue not specified", exc_info=True)
            return
        heart = content.get('heartbeat', None)
        eid = self._next_id
        # print (eid, queue, reg, heart)

        self.log.debug("registration::register_engine(%i, %r, %r, %r)"%(eid, queue, reg, heart))

        content = dict(id=eid,status='ok')
        content.update(self.engine_info)
        # check if requesting available IDs:
        if queue in self.by_ident:
            try:
                raise KeyError("queue_id %r in use"%queue)
            except:
                content = error.wrap_exception()
                self.log.error("queue_id %r in use"%queue, exc_info=True)
        elif heart in self.hearts: # need to check unique hearts?
            try:
                raise KeyError("heart_id %r in use"%heart)
            except:
                self.log.error("heart_id %r in use"%heart, exc_info=True)
                content = error.wrap_exception()
        else:
            for h, pack in self.incoming_registrations.iteritems():
                if heart == h:
                    try:
                        raise KeyError("heart_id %r in use"%heart)
                    except:
                        self.log.error("heart_id %r in use"%heart, exc_info=True)
                        content = error.wrap_exception()
                    break
                elif queue == pack[1]:
                    try:
                        raise KeyError("queue_id %r in use"%queue)
                    except:
                        self.log.error("queue_id %r in use"%queue, exc_info=True)
                        content = error.wrap_exception()
                    break

        msg = self.session.send(self.query, "registration_reply",
                content=content,
                ident=reg)

        if content['status'] == 'ok':
            if heart in self.heartmonitor.hearts:
                # already beating
                self.incoming_registrations[heart] = (eid,queue,reg[0],None)
                self.finish_registration(heart)
            else:
                purge = lambda : self._purge_stalled_registration(heart)
                dc = ioloop.DelayedCallback(purge, self.registration_timeout, self.loop)
                dc.start()
                self.incoming_registrations[heart] = (eid,queue,reg[0],dc)
        else:
            self.log.error("registration::registration %i failed: %r"%(eid, content['evalue']))
        return eid

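Registration is a two-phase handshake: the engine is remembered as pending here, and only becomes live once its heart beats (finish_registration); if no beat arrives within registration_timeout, the pending entry is purged. A toy model with explicit events standing in for the IOLoop callbacks (all names hypothetical):

```python
incoming = {}

def register(heart, eid):
    incoming[heart] = eid            # phase 1: remember the pending engine

def on_beat(heart, finished):
    if heart in incoming:
        finished.append(incoming.pop(heart))   # phase 2: complete

def on_timeout(heart):
    incoming.pop(heart, None)        # stalled: forget the registration

finished = []
register('heart-0', 0)
on_beat('heart-0', finished)         # beat arrived in time
register('heart-1', 1)
on_timeout('heart-1')                # no beat: purged
assert finished == [0] and incoming == {}
```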
    def unregister_engine(self, ident, msg):
        """Unregister an engine that explicitly requested to leave."""
        try:
            eid = msg['content']['id']
        except:
            self.log.error("registration::bad engine id for unregistration: %r"%ident, exc_info=True)
            return
        self.log.info("registration::unregister_engine(%r)"%eid)
        # print (eid)
        uuid = self.keytable[eid]
        content=dict(id=eid, queue=uuid)
        self.dead_engines.add(uuid)
        # self.ids.remove(eid)
        # uuid = self.keytable.pop(eid)
        #
        # ec = self.engines.pop(eid)
        # self.hearts.pop(ec.heartbeat)
        # self.by_ident.pop(ec.queue)
        # self.completed.pop(eid)
        handleit = lambda : self._handle_stranded_msgs(eid, uuid)
        dc = ioloop.DelayedCallback(handleit, self.registration_timeout, self.loop)
        dc.start()
        ############## TODO: HANDLE IT ################

        if self.notifier:
            self.session.send(self.notifier, "unregistration_notification", content=content)

    def _handle_stranded_msgs(self, eid, uuid):
        """Handle messages known to be on an engine when the engine unregisters.

        It is possible that this will fire prematurely - that is, an engine will
        go down after completing a result, and the client will be notified
        that the result failed and later receive the actual result.
        """

        outstanding = self.queues[eid]

        for msg_id in outstanding:
            self.pending.remove(msg_id)
            self.all_completed.add(msg_id)
            try:
                raise error.EngineError("Engine %r died while running task %r"%(eid, msg_id))
            except:
                content = error.wrap_exception()
            # build a fake header:
            header = {}
            header['engine'] = uuid
            header['date'] = datetime.now()
            rec = dict(result_content=content, result_header=header, result_buffers=[])
            rec['completed'] = header['date']
            rec['engine_uuid'] = uuid
            try:
                self.db.update_record(msg_id, rec)
            except Exception:
                self.log.error("DB Error handling stranded msg %r"%msg_id, exc_info=True)


    def finish_registration(self, heart):
        """Second half of engine registration, called after our HeartMonitor
        has received a beat from the Engine's Heart."""
        try:
            (eid,queue,reg,purge) = self.incoming_registrations.pop(heart)
        except KeyError:
            self.log.error("registration::tried to finish nonexistent registration", exc_info=True)
            return
        self.log.info("registration::finished registering engine %i:%r"%(eid,queue))
        if purge is not None:
            purge.stop()
        control = queue
        self.ids.add(eid)
        self.keytable[eid] = queue
        self.engines[eid] = EngineConnector(id=eid, queue=queue, registration=reg,
                control=control, heartbeat=heart)
        self.by_ident[queue] = eid
        self.queues[eid] = list()
        self.tasks[eid] = list()
        self.completed[eid] = list()
        self.hearts[heart] = eid
        content = dict(id=eid, queue=self.engines[eid].queue)
        if self.notifier:
            self.session.send(self.notifier, "registration_notification", content=content)
        self.log.info("engine::Engine Connected: %i"%eid)

    def _purge_stalled_registration(self, heart):
        if heart in self.incoming_registrations:
            eid = self.incoming_registrations.pop(heart)[0]
            self.log.info("registration::purging stalled registration: %i"%eid)

    #-------------------------------------------------------------------------
    # Client Requests
    #-------------------------------------------------------------------------

    def shutdown_request(self, client_id, msg):
        """handle shutdown request."""
        self.session.send(self.query, 'shutdown_reply', content={'status': 'ok'}, ident=client_id)
        # also notify other clients of shutdown
        self.session.send(self.notifier, 'shutdown_notice', content={'status': 'ok'})
        dc = ioloop.DelayedCallback(lambda : self._shutdown(), 1000, self.loop)
        dc.start()

    def _shutdown(self):
        self.log.info("hub::hub shutting down.")
        time.sleep(0.1)
        sys.exit(0)


    def check_load(self, client_id, msg):
        content = msg['content']
        try:
            targets = content['targets']
            targets = self._validate_targets(targets)
        except:
            content = error.wrap_exception()
            self.session.send(self.query, "hub_error",
                    content=content, ident=client_id)
            return

        content = dict(status='ok')
        # loads = {}
        for t in targets:
            content[bytes(t)] = len(self.queues[t])+len(self.tasks[t])
        self.session.send(self.query, "load_reply", content=content, ident=client_id)


    def queue_status(self, client_id, msg):
        """Return the Queue status of one or more targets.

        If verbose: return the msg_ids; else: return the len of each type.

        Keys:
            queue (pending MUX jobs)
            tasks (pending Task jobs)
            completed (finished jobs from both queues)
        """
        content = msg['content']
        targets = content['targets']
        try:
            targets = self._validate_targets(targets)
        except:
            content = error.wrap_exception()
            self.session.send(self.query, "hub_error",
                    content=content, ident=client_id)
            return
        verbose = content.get('verbose', False)
        content = dict(status='ok')
        for t in targets:
            queue = self.queues[t]
            completed = self.completed[t]
            tasks = self.tasks[t]
            if not verbose:
                queue = len(queue)
                completed = len(completed)
                tasks = len(tasks)
            content[bytes(t)] = {'queue': queue, 'completed': completed, 'tasks': tasks}
        content['unassigned'] = list(self.unassigned) if verbose else len(self.unassigned)

        self.session.send(self.query, "queue_reply", content=content, ident=client_id)

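For reference, the shape of a resulting queue_reply, assuming two engines with ids 0 and 1 and verbose=False (all values fabricated for illustration; with verbose=True the per-engine entries would be lists of msg_ids instead of counts):

```python
reply = {
    'status': 'ok',
    '0': {'queue': 2, 'completed': 10, 'tasks': 1},
    '1': {'queue': 0, 'completed': 12, 'tasks': 3},
    'unassigned': 1,   # tasks not yet assigned to any engine
}
# per-engine entries are plain counts in non-verbose mode:
assert sum(v['tasks'] for k, v in reply.items()
           if isinstance(v, dict)) == 4
```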
    def purge_results(self, client_id, msg):
        """Purge results from memory. This method is more valuable before we move
        to a DB-based message storage mechanism."""
        content = msg['content']
        msg_ids = content.get('msg_ids', [])
        reply = dict(status='ok')
        if msg_ids == 'all':
            try:
                self.db.drop_matching_records(dict(completed={'$ne':None}))
            except Exception:
                reply = error.wrap_exception()
        else:
            pending = filter(lambda m: m in self.pending, msg_ids)
            if pending:
                try:
                    raise IndexError("msg pending: %r"%pending[0])
                except:
                    reply = error.wrap_exception()
            else:
                try:
                    self.db.drop_matching_records(dict(msg_id={'$in':msg_ids}))
                except Exception:
                    reply = error.wrap_exception()

        if reply['status'] == 'ok':
            eids = content.get('engine_ids', [])
            for eid in eids:
                if eid not in self.engines:
                    try:
                        raise IndexError("No such engine: %i"%eid)
                    except:
                        reply = error.wrap_exception()
                    break
                msg_ids = self.completed.pop(eid)
                uid = self.engines[eid].queue
                try:
                    self.db.drop_matching_records(dict(engine_uuid=uid, completed={'$ne':None}))
                except Exception:
                    reply = error.wrap_exception()
                    break

        self.session.send(self.query, 'purge_reply', content=reply, ident=client_id)

    def resubmit_task(self, client_id, msg):
        """Resubmit one or more tasks."""
        def finish(reply):
            self.session.send(self.query, 'resubmit_reply', content=reply, ident=client_id)

        content = msg['content']
        msg_ids = content['msg_ids']
        reply = dict(status='ok')
        try:
            records = self.db.find_records({'msg_id' : {'$in' : msg_ids}}, keys=[
                'header', 'content', 'buffers'])
        except Exception:
            self.log.error('db::db error finding tasks to resubmit', exc_info=True)
            return finish(error.wrap_exception())

        # validate msg_ids
        found_ids = [ rec['msg_id'] for rec in records ]
        invalid_ids = filter(lambda m: m in self.pending, found_ids)
        if len(records) > len(msg_ids):
            try:
                raise RuntimeError("DB appears to be in an inconsistent state. "
                    "More matching records were found than should exist")
            except Exception:
                return finish(error.wrap_exception())
        elif len(records) < len(msg_ids):
            missing = [ m for m in msg_ids if m not in found_ids ]
            try:
                raise KeyError("No such msg(s): %r"%missing)
            except KeyError:
                return finish(error.wrap_exception())
        elif invalid_ids:
            msg_id = invalid_ids[0]
            try:
                raise ValueError("Task %r appears to be inflight"%(msg_id))
            except Exception:
                return finish(error.wrap_exception())

        # clear the existing records
        now = datetime.now()
        rec = empty_record()
        map(rec.pop, ['msg_id', 'header', 'content', 'buffers', 'submitted'])
        rec['resubmitted'] = now
        rec['queue'] = 'task'
        rec['client_uuid'] = client_id[0]
        try:
            for msg_id in msg_ids:
                self.all_completed.discard(msg_id)
                self.db.update_record(msg_id, rec)
        except Exception:
            self.log.error('db::db error updating record', exc_info=True)
            reply = error.wrap_exception()
        else:
            # send the messages
            for rec in records:
                header = rec['header']
                # include resubmitted in header to prevent digest collision
                header['resubmitted'] = now
                msg = self.session.msg(header['msg_type'])
                msg['content'] = rec['content']
                msg['header'] = header
                msg['msg_id'] = rec['msg_id']
                self.session.send(self.resubmit, msg, buffers=rec['buffers'])

        # report success, or the wrapped DB error from above
        finish(reply)


    def _extract_record(self, rec):
        """decompose a TaskRecord dict into subsection of reply for get_result"""
        io_dict = {}
        for key in 'pyin pyout pyerr stdout stderr'.split():
            io_dict[key] = rec[key]
        content = { 'result_content': rec['result_content'],
                    'header': rec['header'],
                    'result_header' : rec['result_header'],
                    'io' : io_dict,
                  }
        if rec['result_buffers']:
            buffers = map(str, rec['result_buffers'])
        else:
            buffers = []

        return content, buffers

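What this decomposition produces for one TaskRecord, with fabricated values: the io channels are grouped under 'io', while result_buffers travel separately as raw frames rather than inside the content dict.

```python
record = {
    'result_content': {'status': 'ok'}, 'header': {'msg_id': 'abc'},
    'result_header': {'date': '2011-01-01'},
    'pyin': 'a=1', 'pyout': None, 'pyerr': None,
    'stdout': '', 'stderr': '', 'result_buffers': ['\x00\x01'],
}
io_dict = dict((key, record[key])
               for key in 'pyin pyout pyerr stdout stderr'.split())
content = {'result_content': record['result_content'],
           'header': record['header'],
           'result_header': record['result_header'],
           'io': io_dict}
buffers = list(map(str, record['result_buffers']))  # sent as raw frames
assert content['io']['pyin'] == 'a=1'
```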
1184 def get_results(self, client_id, msg):
1188 def get_results(self, client_id, msg):
1185 """Get the result of 1 or more messages."""
1189 """Get the result of 1 or more messages."""
1186 content = msg['content']
1190 content = msg['content']
1187 msg_ids = sorted(set(content['msg_ids']))
1191 msg_ids = sorted(set(content['msg_ids']))
1188 statusonly = content.get('status_only', False)
1192 statusonly = content.get('status_only', False)
1189 pending = []
1193 pending = []
1190 completed = []
1194 completed = []
1191 content = dict(status='ok')
1195 content = dict(status='ok')
1192 content['pending'] = pending
1196 content['pending'] = pending
1193 content['completed'] = completed
1197 content['completed'] = completed
1194 buffers = []
1198 buffers = []
1195 if not statusonly:
1199 if not statusonly:
1196 try:
1200 try:
1197 matches = self.db.find_records(dict(msg_id={'$in':msg_ids}))
1201 matches = self.db.find_records(dict(msg_id={'$in':msg_ids}))
1198 # turn match list into dict, for faster lookup
1202 # turn match list into dict, for faster lookup
1199 records = {}
1203 records = {}
1200 for rec in matches:
1204 for rec in matches:
1201 records[rec['msg_id']] = rec
1205 records[rec['msg_id']] = rec
1202 except Exception:
1206 except Exception:
1203 content = error.wrap_exception()
1207 content = error.wrap_exception()
1204 self.session.send(self.query, "result_reply", content=content,
1208 self.session.send(self.query, "result_reply", content=content,
1205 parent=msg, ident=client_id)
1209 parent=msg, ident=client_id)
1206 return
1210 return
1207 else:
1211 else:
1208 records = {}
1212 records = {}
1209 for msg_id in msg_ids:
1213 for msg_id in msg_ids:
1210 if msg_id in self.pending:
1214 if msg_id in self.pending:
1211 pending.append(msg_id)
1215 pending.append(msg_id)
1212 elif msg_id in self.all_completed:
1216 elif msg_id in self.all_completed:
1213 completed.append(msg_id)
1217 completed.append(msg_id)
1214 if not statusonly:
1218 if not statusonly:
1215 c,bufs = self._extract_record(records[msg_id])
1219 c,bufs = self._extract_record(records[msg_id])
1216 content[msg_id] = c
1220 content[msg_id] = c
1217 buffers.extend(bufs)
1221 buffers.extend(bufs)
1218 elif msg_id in records:
1222 elif msg_id in records:
1219 if records[msg_id]['completed']:
1223 if records[msg_id]['completed']:
1220 completed.append(msg_id)
1224 completed.append(msg_id)
1221 c,bufs = self._extract_record(records[msg_id])
1225 c,bufs = self._extract_record(records[msg_id])
1222 content[msg_id] = c
1226 content[msg_id] = c
1223 buffers.extend(bufs)
1227 buffers.extend(bufs)
1224 else:
1228 else:
1225 pending.append(msg_id)
1229 pending.append(msg_id)
1226 else:
1230 else:
1227 try:
1231 try:
1228 raise KeyError('No such message: '+msg_id)
1232 raise KeyError('No such message: '+msg_id)
1229 except:
1233 except:
1230 content = error.wrap_exception()
1234 content = error.wrap_exception()
1231 break
1235 break
1232 self.session.send(self.query, "result_reply", content=content,
1236 self.session.send(self.query, "result_reply", content=content,
1233 parent=msg, ident=client_id,
1237 parent=msg, ident=client_id,
1234 buffers=buffers)
1238 buffers=buffers)
1235
1239
1236 def get_history(self, client_id, msg):
1240 def get_history(self, client_id, msg):
1237 """Get a list of all msg_ids in our DB records"""
1241 """Get a list of all msg_ids in our DB records"""
1238 try:
1242 try:
1239 msg_ids = self.db.get_history()
1243 msg_ids = self.db.get_history()
1240 except Exception:
1244 except Exception:
1241 content = error.wrap_exception()
1245 content = error.wrap_exception()
1242 else:
1246 else:
1243 content = dict(status='ok', history=msg_ids)
1247 content = dict(status='ok', history=msg_ids)
1244
1248
1245 self.session.send(self.query, "history_reply", content=content,
1249 self.session.send(self.query, "history_reply", content=content,
1246 parent=msg, ident=client_id)
1250 parent=msg, ident=client_id)
1247
1251
1248 def db_query(self, client_id, msg):
1252 def db_query(self, client_id, msg):
1249 """Perform a raw query on the task record database."""
1253 """Perform a raw query on the task record database."""
1250 content = msg['content']
1254 content = msg['content']
1251 query = content.get('query', {})
1255 query = content.get('query', {})
1252 keys = content.get('keys', None)
1256 keys = content.get('keys', None)
1253 buffers = []
1257 buffers = []
1254 empty = list()
1258 empty = list()
1255 try:
1259 try:
1256 records = self.db.find_records(query, keys)
1260 records = self.db.find_records(query, keys)
1257 except Exception:
1261 except Exception:
1258 content = error.wrap_exception()
1262 content = error.wrap_exception()
1259 else:
1263 else:
1260 # extract buffers from reply content:
1264 # extract buffers from reply content:
1261 if keys is not None:
1265 if keys is not None:
1262 buffer_lens = [] if 'buffers' in keys else None
1266 buffer_lens = [] if 'buffers' in keys else None
1263 result_buffer_lens = [] if 'result_buffers' in keys else None
1267 result_buffer_lens = [] if 'result_buffers' in keys else None
1264 else:
1268 else:
1265 buffer_lens = []
1269 buffer_lens = []
1266 result_buffer_lens = []
1270 result_buffer_lens = []
1267
1271
1268 for rec in records:
1272 for rec in records:
1269 # buffers may be None, so double check
1273 # buffers may be None, so double check
1270 if buffer_lens is not None:
1274 if buffer_lens is not None:
1271 b = rec.pop('buffers', empty) or empty
1275 b = rec.pop('buffers', empty) or empty
1272 buffer_lens.append(len(b))
1276 buffer_lens.append(len(b))
1273 buffers.extend(b)
1277 buffers.extend(b)
1274 if result_buffer_lens is not None:
1278 if result_buffer_lens is not None:
1275 rb = rec.pop('result_buffers', empty) or empty
1279 rb = rec.pop('result_buffers', empty) or empty
1276 result_buffer_lens.append(len(rb))
1280 result_buffer_lens.append(len(rb))
1277 buffers.extend(rb)
1281 buffers.extend(rb)
1278 content = dict(status='ok', records=records, buffer_lens=buffer_lens,
1282 content = dict(status='ok', records=records, buffer_lens=buffer_lens,
1279 result_buffer_lens=result_buffer_lens)
1283 result_buffer_lens=result_buffer_lens)
1280
1284
1281 self.session.send(self.query, "db_reply", content=content,
1285 self.session.send(self.query, "db_reply", content=content,
1282 parent=msg, ident=client_id,
1286 parent=msg, ident=client_id,
1283 buffers=buffers)
1287 buffers=buffers)
1284
1288
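The buffer_lens / result_buffer_lens bookkeeping above exists so the client
can slice the flat buffers list back onto the stripped records. A sketch of
that inverse step on the receiving end (reattach_buffers is a hypothetical
helper and assumes both length lists were included in the reply):

    def reattach_buffers(records, buffer_lens, result_buffer_lens, buffers):
        """Undo the db_reply packing: per record, buffers then result_buffers."""
        i = 0
        for rec, blen, rblen in zip(records, buffer_lens, result_buffer_lens):
            rec['buffers'] = buffers[i:i + blen]
            i += blen
            rec['result_buffers'] = buffers[i:i + rblen]
            i += rblen
        return records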
@@ -1,112 +1,117 b''
1 """A TaskRecord backend using mongodb"""
1 """A TaskRecord backend using mongodb
2
3 Authors:
4
5 * Min RK
6 """
2 #-----------------------------------------------------------------------------
7 #-----------------------------------------------------------------------------
3 # Copyright (C) 2010 The IPython Development Team
8 # Copyright (C) 2010-2011 The IPython Development Team
4 #
9 #
5 # Distributed under the terms of the BSD License. The full license is in
10 # Distributed under the terms of the BSD License. The full license is in
6 # the file COPYING, distributed as part of this software.
11 # the file COPYING, distributed as part of this software.
7 #-----------------------------------------------------------------------------
12 #-----------------------------------------------------------------------------
8
13
9 from pymongo import Connection
14 from pymongo import Connection
10 from pymongo.binary import Binary
15 from pymongo.binary import Binary
11
16
12 from IPython.utils.traitlets import Dict, List, Unicode, Instance
17 from IPython.utils.traitlets import Dict, List, Unicode, Instance
13
18
14 from .dictdb import BaseDB
19 from .dictdb import BaseDB
15
20
16 #-----------------------------------------------------------------------------
21 #-----------------------------------------------------------------------------
17 # MongoDB class
22 # MongoDB class
18 #-----------------------------------------------------------------------------
23 #-----------------------------------------------------------------------------
19
24
20 class MongoDB(BaseDB):
25 class MongoDB(BaseDB):
21 """MongoDB TaskRecord backend."""
26 """MongoDB TaskRecord backend."""
22
27
23 connection_args = List(config=True,
28 connection_args = List(config=True,
24 help="""Positional arguments to be passed to pymongo.Connection. Only
29 help="""Positional arguments to be passed to pymongo.Connection. Only
25 necessary if the default mongodb configuration does not point to your
30 necessary if the default mongodb configuration does not point to your
26 mongod instance.""")
31 mongod instance.""")
27 connection_kwargs = Dict(config=True,
32 connection_kwargs = Dict(config=True,
28 help="""Keyword arguments to be passed to pymongo.Connection. Only
33 help="""Keyword arguments to be passed to pymongo.Connection. Only
29 necessary if the default mongodb configuration does not point to your
34 necessary if the default mongodb configuration does not point to your
30 mongod instance."""
35 mongod instance."""
31 )
36 )
32 database = Unicode(config=True,
37 database = Unicode(config=True,
33 help="""The MongoDB database name to use for storing tasks for this session. If unspecified,
38 help="""The MongoDB database name to use for storing tasks for this session. If unspecified,
34 a new database will be created with the Hub's IDENT. Specifying the database will result
39 a new database will be created with the Hub's IDENT. Specifying the database will result
35 in tasks from previous sessions being available via Clients' db_query and
40 in tasks from previous sessions being available via Clients' db_query and
36 get_result methods.""")
41 get_result methods.""")
37
42
38 _connection = Instance(Connection) # pymongo connection
43 _connection = Instance(Connection) # pymongo connection
39
44
40 def __init__(self, **kwargs):
45 def __init__(self, **kwargs):
41 super(MongoDB, self).__init__(**kwargs)
46 super(MongoDB, self).__init__(**kwargs)
42 if self._connection is None:
47 if self._connection is None:
43 self._connection = Connection(*self.connection_args, **self.connection_kwargs)
48 self._connection = Connection(*self.connection_args, **self.connection_kwargs)
44 if not self.database:
49 if not self.database:
45 self.database = self.session
50 self.database = self.session
46 self._db = self._connection[self.database]
51 self._db = self._connection[self.database]
47 self._records = self._db['task_records']
52 self._records = self._db['task_records']
48 self._records.ensure_index('msg_id', unique=True)
53 self._records.ensure_index('msg_id', unique=True)
49 self._records.ensure_index('submitted') # for sorting history
54 self._records.ensure_index('submitted') # for sorting history
50 # for rec in self._records.find
55 # for rec in self._records.find
51
56
52 def _binary_buffers(self, rec):
57 def _binary_buffers(self, rec):
53 for key in ('buffers', 'result_buffers'):
58 for key in ('buffers', 'result_buffers'):
54 if rec.get(key, None):
59 if rec.get(key, None):
55 rec[key] = map(Binary, rec[key])
60 rec[key] = map(Binary, rec[key])
56 return rec
61 return rec
57
62
58 def add_record(self, msg_id, rec):
63 def add_record(self, msg_id, rec):
59 """Add a new Task Record, by msg_id."""
64 """Add a new Task Record, by msg_id."""
60 # print rec
65 # print rec
61 rec = self._binary_buffers(rec)
66 rec = self._binary_buffers(rec)
62 self._records.insert(rec)
67 self._records.insert(rec)
63
68
64 def get_record(self, msg_id):
69 def get_record(self, msg_id):
65 """Get a specific Task Record, by msg_id."""
70 """Get a specific Task Record, by msg_id."""
66 r = self._records.find_one({'msg_id': msg_id})
71 r = self._records.find_one({'msg_id': msg_id})
67 if not r:
72 if not r:
68 # r will be None if nothing is found
73 # r will be None if nothing is found
69 raise KeyError(msg_id)
74 raise KeyError(msg_id)
70 return r
75 return r
71
76
72 def update_record(self, msg_id, rec):
77 def update_record(self, msg_id, rec):
73 """Update the data in an existing record."""
78 """Update the data in an existing record."""
74 rec = self._binary_buffers(rec)
79 rec = self._binary_buffers(rec)
75
80
76 self._records.update({'msg_id':msg_id}, {'$set': rec})
81 self._records.update({'msg_id':msg_id}, {'$set': rec})
77
82
78 def drop_matching_records(self, check):
83 def drop_matching_records(self, check):
79 """Remove a record from the DB."""
84 """Remove a record from the DB."""
80 self._records.remove(check)
85 self._records.remove(check)
81
86
82 def drop_record(self, msg_id):
87 def drop_record(self, msg_id):
83 """Remove a record from the DB."""
88 """Remove a record from the DB."""
84 self._records.remove({'msg_id':msg_id})
89 self._records.remove({'msg_id':msg_id})
85
90
86 def find_records(self, check, keys=None):
91 def find_records(self, check, keys=None):
87 """Find records matching a query dict, optionally extracting subset of keys.
92 """Find records matching a query dict, optionally extracting subset of keys.
88
93
89 Returns list of matching records.
94 Returns list of matching records.
90
95
91 Parameters
96 Parameters
92 ----------
97 ----------
93
98
94 check: dict
99 check: dict
95 mongodb-style query argument
100 mongodb-style query argument
96 keys: list of strs [optional]
101 keys: list of strs [optional]
97 if specified, the subset of keys to extract. msg_id will *always* be
102 if specified, the subset of keys to extract. msg_id will *always* be
98 included.
103 included.
99 """
104 """
100 if keys and 'msg_id' not in keys:
105 if keys and 'msg_id' not in keys:
101 keys.append('msg_id')
106 keys.append('msg_id')
102 matches = list(self._records.find(check,keys))
107 matches = list(self._records.find(check,keys))
103 for rec in matches:
108 for rec in matches:
104 rec.pop('_id')
109 rec.pop('_id')
105 return matches
110 return matches
106
111
107 def get_history(self):
112 def get_history(self):
108 """get all msg_ids, ordered by time submitted."""
113 """get all msg_ids, ordered by time submitted."""
109 cursor = self._records.find({},{'msg_id':1}).sort('submitted')
114 cursor = self._records.find({},{'msg_id':1}).sort('submitted')
110 return [ rec['msg_id'] for rec in cursor ]
115 return [ rec['msg_id'] for rec in cursor ]
111
116
112
117
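Selecting this backend is a controller configuration choice. A sketch of the
relevant lines in ipcontroller_config.py (the db_class path assumes the
module layout shown in this diff; host and database values are illustrative):

    c = get_config()
    c.HubFactory.db_class = 'IPython.parallel.controller.mongodb.MongoDB'
    # only needed if mongod is not at pymongo's default host/port:
    c.MongoDB.connection_kwargs = {'host': 'localhost', 'port': 27017}
    # pin the database name so task records outlive a single session:
    c.MongoDB.database = u'ipython_tasks'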
@@ -1,688 +1,692 b''
1 """The Python scheduler for rich scheduling.
1 """The Python scheduler for rich scheduling.
2
2
3 The Pure ZMQ scheduler does not allow routing schemes other than LRU,
3 The Pure ZMQ scheduler does not allow routing schemes other than LRU,
4 nor does it check msg_id DAG dependencies. For those, a slightly slower
4 nor does it check msg_id DAG dependencies. For those, a slightly slower
5 Python Scheduler exists.
5 Python Scheduler exists.
6
7 Authors:
8
9 * Min RK
6 """
10 """
7 #-----------------------------------------------------------------------------
11 #-----------------------------------------------------------------------------
8 # Copyright (C) 2010-2011 The IPython Development Team
12 # Copyright (C) 2010-2011 The IPython Development Team
9 #
13 #
10 # Distributed under the terms of the BSD License. The full license is in
14 # Distributed under the terms of the BSD License. The full license is in
11 # the file COPYING, distributed as part of this software.
15 # the file COPYING, distributed as part of this software.
12 #-----------------------------------------------------------------------------
16 #-----------------------------------------------------------------------------
13
17
14 #----------------------------------------------------------------------
18 #----------------------------------------------------------------------
15 # Imports
19 # Imports
16 #----------------------------------------------------------------------
20 #----------------------------------------------------------------------
17
21
18 from __future__ import print_function
22 from __future__ import print_function
19
23
20 import logging
24 import logging
21 import sys
25 import sys
22
26
23 from datetime import datetime, timedelta
27 from datetime import datetime, timedelta
24 from random import randint, random
28 from random import randint, random
25 from types import FunctionType
29 from types import FunctionType
26
30
27 try:
31 try:
28 import numpy
32 import numpy
29 except ImportError:
33 except ImportError:
30 numpy = None
34 numpy = None
31
35
32 import zmq
36 import zmq
33 from zmq.eventloop import ioloop, zmqstream
37 from zmq.eventloop import ioloop, zmqstream
34
38
35 # local imports
39 # local imports
36 from IPython.external.decorator import decorator
40 from IPython.external.decorator import decorator
37 from IPython.config.loader import Config
41 from IPython.config.loader import Config
38 from IPython.utils.traitlets import Instance, Dict, List, Set, Int, Str, Enum
42 from IPython.utils.traitlets import Instance, Dict, List, Set, Int, Str, Enum
39
43
40 from IPython.parallel import error
44 from IPython.parallel import error
41 from IPython.parallel.factory import SessionFactory
45 from IPython.parallel.factory import SessionFactory
42 from IPython.parallel.util import connect_logger, local_logger
46 from IPython.parallel.util import connect_logger, local_logger
43
47
44 from .dependency import Dependency
48 from .dependency import Dependency
45
49
46 @decorator
50 @decorator
47 def logged(f,self,*args,**kwargs):
51 def logged(f,self,*args,**kwargs):
48 # print ("#--------------------")
52 # print ("#--------------------")
49 self.log.debug("scheduler::%s(*%s,**%s)"%(f.func_name, args, kwargs))
53 self.log.debug("scheduler::%s(*%s,**%s)"%(f.func_name, args, kwargs))
50 # print ("#--")
54 # print ("#--")
51 return f(self,*args, **kwargs)
55 return f(self,*args, **kwargs)
52
56
53 #----------------------------------------------------------------------
57 #----------------------------------------------------------------------
54 # Chooser functions
58 # Chooser functions
55 #----------------------------------------------------------------------
59 #----------------------------------------------------------------------
56
60
57 def plainrandom(loads):
61 def plainrandom(loads):
58 """Plain random pick."""
62 """Plain random pick."""
59 n = len(loads)
63 n = len(loads)
60 return randint(0,n-1)
64 return randint(0,n-1)
61
65
62 def lru(loads):
66 def lru(loads):
63 """Always pick the front of the line.
67 """Always pick the front of the line.
64
68
65 The content of `loads` is ignored.
69 The content of `loads` is ignored.
66
70
67 Assumes LRU ordering of loads, with oldest first.
71 Assumes LRU ordering of loads, with oldest first.
68 """
72 """
69 return 0
73 return 0
70
74
71 def twobin(loads):
75 def twobin(loads):
72 """Pick two at random, use the LRU of the two.
76 """Pick two at random, use the LRU of the two.
73
77
74 The content of loads is ignored.
78 The content of loads is ignored.
75
79
76 Assumes LRU ordering of loads, with oldest first.
80 Assumes LRU ordering of loads, with oldest first.
77 """
81 """
78 n = len(loads)
82 n = len(loads)
79 a = randint(0,n-1)
83 a = randint(0,n-1)
80 b = randint(0,n-1)
84 b = randint(0,n-1)
81 return min(a,b)
85 return min(a,b)
82
86
83 def weighted(loads):
87 def weighted(loads):
84 """Pick two at random using inverse load as weight.
88 """Pick two at random using inverse load as weight.
85
89
86 Return the less loaded of the two.
90 Return the less loaded of the two.
87 """
91 """
88 # weight 0 a million times more than 1:
92 # weight 0 a million times more than 1:
89 weights = 1./(1e-6+numpy.array(loads))
93 weights = 1./(1e-6+numpy.array(loads))
90 sums = weights.cumsum()
94 sums = weights.cumsum()
91 t = sums[-1]
95 t = sums[-1]
92 x = random()*t
96 x = random()*t
93 y = random()*t
97 y = random()*t
94 idx = 0
98 idx = 0
95 idy = 0
99 idy = 0
96 while sums[idx] < x:
100 while sums[idx] < x:
97 idx += 1
101 idx += 1
98 while sums[idy] < y:
102 while sums[idy] < y:
99 idy += 1
103 idy += 1
100 if weights[idy] > weights[idx]:
104 if weights[idy] > weights[idx]:
101 return idy
105 return idy
102 else:
106 else:
103 return idx
107 return idx
104
108
105 def leastload(loads):
109 def leastload(loads):
106 """Always choose the lowest load.
110 """Always choose the lowest load.
107
111
108 If the lowest load occurs more than once, the first
112 If the lowest load occurs more than once, the first
109 occurrence will be used. If loads has LRU ordering, this means
113 occurrence will be used. If loads has LRU ordering, this means
110 the LRU of those with the lowest load is chosen.
114 the LRU of those with the lowest load is chosen.
111 """
115 """
112 return loads.index(min(loads))
116 return loads.index(min(loads))
113
117
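Each chooser above implements the same contract: take the list of engine
loads (LRU-ordered, oldest first) and return the index of the engine that
should receive the next task. A quick illustration (weighted additionally
requires numpy; the loads are made up):

    loads = [2, 0, 1, 0]           # outstanding tasks per engine, LRU order
    assert leastload(loads) == 1   # lowest load, first occurrence wins
    assert lru(loads) == 0         # ignores the loads entirely
    idx = weighted(loads)          # random, biased toward idle engines
    assert 0 <= idx < len(loads)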
114 #---------------------------------------------------------------------
118 #---------------------------------------------------------------------
115 # Classes
119 # Classes
116 #---------------------------------------------------------------------
120 #---------------------------------------------------------------------
117 # store empty default dependency:
121 # store empty default dependency:
118 MET = Dependency([])
122 MET = Dependency([])
119
123
120 class TaskScheduler(SessionFactory):
124 class TaskScheduler(SessionFactory):
121 """Python TaskScheduler object.
125 """Python TaskScheduler object.
122
126
123 This is the simplest object that supports msg_id based
127 This is the simplest object that supports msg_id based
124 DAG dependencies. *Only* task msg_ids are checked, not
128 DAG dependencies. *Only* task msg_ids are checked, not
125 msg_ids of jobs submitted via the MUX queue.
129 msg_ids of jobs submitted via the MUX queue.
126
130
127 """
131 """
128
132
129 hwm = Int(0, config=True, shortname='hwm',
133 hwm = Int(0, config=True, shortname='hwm',
130 help="""specify the High Water Mark (HWM) for the downstream
134 help="""specify the High Water Mark (HWM) for the downstream
131 socket in the Task scheduler. This is the maximum number
135 socket in the Task scheduler. This is the maximum number
132 of allowed outstanding tasks on each engine."""
136 of allowed outstanding tasks on each engine."""
133 )
137 )
134 scheme_name = Enum(('leastload', 'pure', 'lru', 'plainrandom', 'weighted', 'twobin'),
138 scheme_name = Enum(('leastload', 'pure', 'lru', 'plainrandom', 'weighted', 'twobin'),
135 'leastload', config=True, shortname='scheme', allow_none=False,
139 'leastload', config=True, shortname='scheme', allow_none=False,
136 help="""select the task scheduler scheme [default: Python LRU]
140 help="""select the task scheduler scheme [default: Python LRU]
137 Options are: 'pure', 'lru', 'plainrandom', 'weighted', 'twobin','leastload'"""
141 Options are: 'pure', 'lru', 'plainrandom', 'weighted', 'twobin','leastload'"""
138 )
142 )
139 def _scheme_name_changed(self, old, new):
143 def _scheme_name_changed(self, old, new):
140 self.log.debug("Using scheme %r"%new)
144 self.log.debug("Using scheme %r"%new)
141 self.scheme = globals()[new]
145 self.scheme = globals()[new]
142
146
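Both traits are meant to be driven from the controller's config file. A
sketch of the corresponding ipcontroller_config.py lines (values are
illustrative):

    c = get_config()
    c.TaskScheduler.scheme_name = 'weighted'
    c.TaskScheduler.hwm = 1   # at most one outstanding task per engine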
143 # input arguments:
147 # input arguments:
144 scheme = Instance(FunctionType) # function for determining the destination
148 scheme = Instance(FunctionType) # function for determining the destination
145 def _scheme_default(self):
149 def _scheme_default(self):
146 return leastload
150 return leastload
147 client_stream = Instance(zmqstream.ZMQStream) # client-facing stream
151 client_stream = Instance(zmqstream.ZMQStream) # client-facing stream
148 engine_stream = Instance(zmqstream.ZMQStream) # engine-facing stream
152 engine_stream = Instance(zmqstream.ZMQStream) # engine-facing stream
149 notifier_stream = Instance(zmqstream.ZMQStream) # hub-facing sub stream
153 notifier_stream = Instance(zmqstream.ZMQStream) # hub-facing sub stream
150 mon_stream = Instance(zmqstream.ZMQStream) # hub-facing pub stream
154 mon_stream = Instance(zmqstream.ZMQStream) # hub-facing pub stream
151
155
152 # internals:
156 # internals:
153 graph = Dict() # dict by msg_id of [ msg_ids that depend on key ]
157 graph = Dict() # dict by msg_id of [ msg_ids that depend on key ]
154 retries = Dict() # dict by msg_id of retries remaining (non-neg ints)
158 retries = Dict() # dict by msg_id of retries remaining (non-neg ints)
155 # waiting = List() # list of msg_ids ready to run, but haven't due to HWM
159 # waiting = List() # list of msg_ids ready to run, but haven't due to HWM
156 depending = Dict() # dict by msg_id of (msg_id, raw_msg, after, follow)
160 depending = Dict() # dict by msg_id of (msg_id, raw_msg, after, follow)
157 pending = Dict() # dict by engine_uuid of submitted tasks
161 pending = Dict() # dict by engine_uuid of submitted tasks
158 completed = Dict() # dict by engine_uuid of completed tasks
162 completed = Dict() # dict by engine_uuid of completed tasks
159 failed = Dict() # dict by engine_uuid of failed tasks
163 failed = Dict() # dict by engine_uuid of failed tasks
160 destinations = Dict() # dict by msg_id of engine_uuids where jobs ran (reverse of completed+failed)
164 destinations = Dict() # dict by msg_id of engine_uuids where jobs ran (reverse of completed+failed)
161 clients = Dict() # dict by msg_id for who submitted the task
165 clients = Dict() # dict by msg_id for who submitted the task
162 targets = List() # list of target IDENTs
166 targets = List() # list of target IDENTs
163 loads = List() # list of engine loads
167 loads = List() # list of engine loads
164 # full = Set() # set of IDENTs that have HWM outstanding tasks
168 # full = Set() # set of IDENTs that have HWM outstanding tasks
165 all_completed = Set() # set of all completed tasks
169 all_completed = Set() # set of all completed tasks
166 all_failed = Set() # set of all failed tasks
170 all_failed = Set() # set of all failed tasks
167 all_done = Set() # set of all finished tasks=union(completed,failed)
171 all_done = Set() # set of all finished tasks=union(completed,failed)
168 all_ids = Set() # set of all submitted task IDs
172 all_ids = Set() # set of all submitted task IDs
169 blacklist = Dict() # dict by msg_id of locations where a job has encountered UnmetDependency
173 blacklist = Dict() # dict by msg_id of locations where a job has encountered UnmetDependency
170 auditor = Instance('zmq.eventloop.ioloop.PeriodicCallback')
174 auditor = Instance('zmq.eventloop.ioloop.PeriodicCallback')
171
175
172
176
173 def start(self):
177 def start(self):
174 self.engine_stream.on_recv(self.dispatch_result, copy=False)
178 self.engine_stream.on_recv(self.dispatch_result, copy=False)
175 self._notification_handlers = dict(
179 self._notification_handlers = dict(
176 registration_notification = self._register_engine,
180 registration_notification = self._register_engine,
177 unregistration_notification = self._unregister_engine
181 unregistration_notification = self._unregister_engine
178 )
182 )
179 self.notifier_stream.on_recv(self.dispatch_notification)
183 self.notifier_stream.on_recv(self.dispatch_notification)
180 self.auditor = ioloop.PeriodicCallback(self.audit_timeouts, 2e3, self.loop) # every 2s (0.5 Hz)
184 self.auditor = ioloop.PeriodicCallback(self.audit_timeouts, 2e3, self.loop) # every 2s (0.5 Hz)
181 self.auditor.start()
185 self.auditor.start()
182 self.log.info("Scheduler started [%s]"%self.scheme_name)
186 self.log.info("Scheduler started [%s]"%self.scheme_name)
183
187
184 def resume_receiving(self):
188 def resume_receiving(self):
185 """Resume accepting jobs."""
189 """Resume accepting jobs."""
186 self.client_stream.on_recv(self.dispatch_submission, copy=False)
190 self.client_stream.on_recv(self.dispatch_submission, copy=False)
187
191
188 def stop_receiving(self):
192 def stop_receiving(self):
189 """Stop accepting jobs while there are no engines.
193 """Stop accepting jobs while there are no engines.
190 Leave them in the ZMQ queue."""
194 Leave them in the ZMQ queue."""
191 self.client_stream.on_recv(None)
195 self.client_stream.on_recv(None)
192
196
193 #-----------------------------------------------------------------------
197 #-----------------------------------------------------------------------
194 # [Un]Registration Handling
198 # [Un]Registration Handling
195 #-----------------------------------------------------------------------
199 #-----------------------------------------------------------------------
196
200
197 def dispatch_notification(self, msg):
201 def dispatch_notification(self, msg):
198 """dispatch register/unregister events."""
202 """dispatch register/unregister events."""
199 try:
203 try:
200 idents,msg = self.session.feed_identities(msg)
204 idents,msg = self.session.feed_identities(msg)
201 except ValueError:
205 except ValueError:
202 self.log.warn("task::Invalid Message: %r"%msg)
206 self.log.warn("task::Invalid Message: %r"%msg)
203 return
207 return
204 try:
208 try:
205 msg = self.session.unpack_message(msg)
209 msg = self.session.unpack_message(msg)
206 except ValueError:
210 except ValueError:
207 self.log.warn("task::Unauthorized message from: %r"%idents)
211 self.log.warn("task::Unauthorized message from: %r"%idents)
208 return
212 return
209
213
210 msg_type = msg['msg_type']
214 msg_type = msg['msg_type']
211
215
212 handler = self._notification_handlers.get(msg_type, None)
216 handler = self._notification_handlers.get(msg_type, None)
213 if handler is None:
217 if handler is None:
214 self.log.error("Unhandled message type: %r"%msg_type)
218 self.log.error("Unhandled message type: %r"%msg_type)
215 else:
219 else:
216 try:
220 try:
217 handler(str(msg['content']['queue']))
221 handler(str(msg['content']['queue']))
218 except KeyError:
222 except KeyError:
219 self.log.error("task::Invalid notification msg: %r"%msg)
223 self.log.error("task::Invalid notification msg: %r"%msg)
220
224
221 @logged
225 @logged
222 def _register_engine(self, uid):
226 def _register_engine(self, uid):
223 """New engine with ident `uid` became available."""
227 """New engine with ident `uid` became available."""
224 # head of the line:
228 # head of the line:
225 self.targets.insert(0,uid)
229 self.targets.insert(0,uid)
226 self.loads.insert(0,0)
230 self.loads.insert(0,0)
227 # initialize sets
231 # initialize sets
228 self.completed[uid] = set()
232 self.completed[uid] = set()
229 self.failed[uid] = set()
233 self.failed[uid] = set()
230 self.pending[uid] = {}
234 self.pending[uid] = {}
231 if len(self.targets) == 1:
235 if len(self.targets) == 1:
232 self.resume_receiving()
236 self.resume_receiving()
233 # rescan the graph:
237 # rescan the graph:
234 self.update_graph(None)
238 self.update_graph(None)
235
239
236 def _unregister_engine(self, uid):
240 def _unregister_engine(self, uid):
237 """Existing engine with ident `uid` became unavailable."""
241 """Existing engine with ident `uid` became unavailable."""
238 if len(self.targets) == 1:
242 if len(self.targets) == 1:
239 # this was our only engine
243 # this was our only engine
240 self.stop_receiving()
244 self.stop_receiving()
241
245
242 # handle any potentially finished tasks:
246 # handle any potentially finished tasks:
243 self.engine_stream.flush()
247 self.engine_stream.flush()
244
248
245 # don't pop destinations, because they might be used later
249 # don't pop destinations, because they might be used later
246 # map(self.destinations.pop, self.completed.pop(uid))
250 # map(self.destinations.pop, self.completed.pop(uid))
247 # map(self.destinations.pop, self.failed.pop(uid))
251 # map(self.destinations.pop, self.failed.pop(uid))
248
252
249 # prevent this engine from receiving work
253 # prevent this engine from receiving work
250 idx = self.targets.index(uid)
254 idx = self.targets.index(uid)
251 self.targets.pop(idx)
255 self.targets.pop(idx)
252 self.loads.pop(idx)
256 self.loads.pop(idx)
253
257
254 # wait 5 seconds before cleaning up pending jobs, since the results might
258 # wait 5 seconds before cleaning up pending jobs, since the results might
255 # still be incoming
259 # still be incoming
256 if self.pending[uid]:
260 if self.pending[uid]:
257 dc = ioloop.DelayedCallback(lambda : self.handle_stranded_tasks(uid), 5000, self.loop)
261 dc = ioloop.DelayedCallback(lambda : self.handle_stranded_tasks(uid), 5000, self.loop)
258 dc.start()
262 dc.start()
259 else:
263 else:
260 self.completed.pop(uid)
264 self.completed.pop(uid)
261 self.failed.pop(uid)
265 self.failed.pop(uid)
262
266
263
267
264 @logged
268 @logged
265 def handle_stranded_tasks(self, engine):
269 def handle_stranded_tasks(self, engine):
266 """Deal with jobs resident in an engine that died."""
270 """Deal with jobs resident in an engine that died."""
267 lost = self.pending[engine]
271 lost = self.pending[engine]
268 for msg_id in lost.keys():
272 for msg_id in lost.keys():
269 if msg_id not in self.pending[engine]:
273 if msg_id not in self.pending[engine]:
270 # prevent double-handling of messages
274 # prevent double-handling of messages
271 continue
275 continue
272
276
273 raw_msg = lost[msg_id][0]
277 raw_msg = lost[msg_id][0]
274 idents,msg = self.session.feed_identities(raw_msg, copy=False)
278 idents,msg = self.session.feed_identities(raw_msg, copy=False)
275 parent = self.session.unpack(msg[1].bytes)
279 parent = self.session.unpack(msg[1].bytes)
276 idents = [engine, idents[0]]
280 idents = [engine, idents[0]]
277
281
278 # build fake error reply
282 # build fake error reply
279 try:
283 try:
280 raise error.EngineError("Engine %r died while running task %r"%(engine, msg_id))
284 raise error.EngineError("Engine %r died while running task %r"%(engine, msg_id))
281 except:
285 except:
282 content = error.wrap_exception()
286 content = error.wrap_exception()
283 msg = self.session.msg('apply_reply', content, parent=parent, subheader={'status':'error'})
287 msg = self.session.msg('apply_reply', content, parent=parent, subheader={'status':'error'})
284 raw_reply = map(zmq.Message, self.session.serialize(msg, ident=idents))
288 raw_reply = map(zmq.Message, self.session.serialize(msg, ident=idents))
285 # and dispatch it
289 # and dispatch it
286 self.dispatch_result(raw_reply)
290 self.dispatch_result(raw_reply)
287
291
288 # finally scrub completed/failed lists
292 # finally scrub completed/failed lists
289 self.completed.pop(engine)
293 self.completed.pop(engine)
290 self.failed.pop(engine)
294 self.failed.pop(engine)
291
295
292
296
293 #-----------------------------------------------------------------------
297 #-----------------------------------------------------------------------
294 # Job Submission
298 # Job Submission
295 #-----------------------------------------------------------------------
299 #-----------------------------------------------------------------------
296 @logged
300 @logged
297 def dispatch_submission(self, raw_msg):
301 def dispatch_submission(self, raw_msg):
298 """Dispatch job submission to appropriate handlers."""
302 """Dispatch job submission to appropriate handlers."""
299 # ensure targets up to date:
303 # ensure targets up to date:
300 self.notifier_stream.flush()
304 self.notifier_stream.flush()
301 try:
305 try:
302 idents, msg = self.session.feed_identities(raw_msg, copy=False)
306 idents, msg = self.session.feed_identities(raw_msg, copy=False)
303 msg = self.session.unpack_message(msg, content=False, copy=False)
307 msg = self.session.unpack_message(msg, content=False, copy=False)
304 except Exception:
308 except Exception:
305 self.log.error("task::Invaid task msg: %r"%raw_msg, exc_info=True)
309 self.log.error("task::Invaid task msg: %r"%raw_msg, exc_info=True)
306 return
310 return
307
311
308
312
309 # send to monitor
313 # send to monitor
310 self.mon_stream.send_multipart(['intask']+raw_msg, copy=False)
314 self.mon_stream.send_multipart(['intask']+raw_msg, copy=False)
311
315
312 header = msg['header']
316 header = msg['header']
313 msg_id = header['msg_id']
317 msg_id = header['msg_id']
314 self.all_ids.add(msg_id)
318 self.all_ids.add(msg_id)
315
319
316 # targets
320 # targets
317 targets = set(header.get('targets', []))
321 targets = set(header.get('targets', []))
318 retries = header.get('retries', 0)
322 retries = header.get('retries', 0)
319 self.retries[msg_id] = retries
323 self.retries[msg_id] = retries
320
324
321 # time dependencies
325 # time dependencies
322 after = Dependency(header.get('after', []))
326 after = Dependency(header.get('after', []))
323 if after.all:
327 if after.all:
324 if after.success:
328 if after.success:
325 after.difference_update(self.all_completed)
329 after.difference_update(self.all_completed)
326 if after.failure:
330 if after.failure:
327 after.difference_update(self.all_failed)
331 after.difference_update(self.all_failed)
328 if after.check(self.all_completed, self.all_failed):
332 if after.check(self.all_completed, self.all_failed):
329 # recast as empty set, if `after` already met,
333 # recast as empty set, if `after` already met,
330 # to prevent unnecessary set comparisons
334 # to prevent unnecessary set comparisons
331 after = MET
335 after = MET
332
336
333 # location dependencies
337 # location dependencies
334 follow = Dependency(header.get('follow', []))
338 follow = Dependency(header.get('follow', []))
335
339
336 # turn timeouts into datetime objects:
340 # turn timeouts into datetime objects:
337 timeout = header.get('timeout', None)
341 timeout = header.get('timeout', None)
338 if timeout:
342 if timeout:
339 timeout = datetime.now() + timedelta(0,timeout,0)
343 timeout = datetime.now() + timedelta(0,timeout,0)
340
344
341 args = [raw_msg, targets, after, follow, timeout]
345 args = [raw_msg, targets, after, follow, timeout]
342
346
343 # validate and reduce dependencies:
347 # validate and reduce dependencies:
344 for dep in after,follow:
348 for dep in after,follow:
345 # check valid:
349 # check valid:
346 if msg_id in dep or dep.difference(self.all_ids):
350 if msg_id in dep or dep.difference(self.all_ids):
347 self.depending[msg_id] = args
351 self.depending[msg_id] = args
348 return self.fail_unreachable(msg_id, error.InvalidDependency)
352 return self.fail_unreachable(msg_id, error.InvalidDependency)
349 # check if unreachable:
353 # check if unreachable:
350 if dep.unreachable(self.all_completed, self.all_failed):
354 if dep.unreachable(self.all_completed, self.all_failed):
351 self.depending[msg_id] = args
355 self.depending[msg_id] = args
352 return self.fail_unreachable(msg_id)
356 return self.fail_unreachable(msg_id)
353
357
354 if after.check(self.all_completed, self.all_failed):
358 if after.check(self.all_completed, self.all_failed):
355 # time deps already met, try to run
359 # time deps already met, try to run
356 if not self.maybe_run(msg_id, *args):
360 if not self.maybe_run(msg_id, *args):
357 # can't run yet
361 # can't run yet
358 if msg_id not in self.all_failed:
362 if msg_id not in self.all_failed:
359 # could have failed as unreachable
363 # could have failed as unreachable
360 self.save_unmet(msg_id, *args)
364 self.save_unmet(msg_id, *args)
361 else:
365 else:
362 self.save_unmet(msg_id, *args)
366 self.save_unmet(msg_id, *args)
363
367
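From the client's side, the after/follow headers validated above are set
per-task on a load-balanced view. A sketch (lview, load_data and process are
hypothetical; temp_flags scopes the dependencies to one submission):

    ar = lview.apply(load_data)            # some prior task
    with lview.temp_flags(after=ar, follow=ar, retries=1, timeout=60):
        # runs after load_data finishes, on the engine where it ran
        ar2 = lview.apply(process)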
364 # @logged
368 # @logged
365 def audit_timeouts(self):
369 def audit_timeouts(self):
366 """Audit all waiting tasks for expired timeouts."""
370 """Audit all waiting tasks for expired timeouts."""
367 now = datetime.now()
371 now = datetime.now()
368 for msg_id in self.depending.keys():
372 for msg_id in self.depending.keys():
369 # must recheck, in case one failure cascaded to another:
373 # must recheck, in case one failure cascaded to another:
370 if msg_id in self.depending:
374 if msg_id in self.depending:
371 raw,targets,after,follow,timeout = self.depending[msg_id]
375 raw,targets,after,follow,timeout = self.depending[msg_id]
372 if timeout and timeout < now:
376 if timeout and timeout < now:
373 self.fail_unreachable(msg_id, error.TaskTimeout)
377 self.fail_unreachable(msg_id, error.TaskTimeout)
374
378
375 @logged
379 @logged
376 def fail_unreachable(self, msg_id, why=error.ImpossibleDependency):
380 def fail_unreachable(self, msg_id, why=error.ImpossibleDependency):
377 """a task has become unreachable, send a reply with an ImpossibleDependency
381 """a task has become unreachable, send a reply with an ImpossibleDependency
378 error."""
382 error."""
379 if msg_id not in self.depending:
383 if msg_id not in self.depending:
380 self.log.error("msg %r already failed!"%msg_id)
384 self.log.error("msg %r already failed!"%msg_id)
381 return
385 return
382 raw_msg,targets,after,follow,timeout = self.depending.pop(msg_id)
386 raw_msg,targets,after,follow,timeout = self.depending.pop(msg_id)
383 for mid in follow.union(after):
387 for mid in follow.union(after):
384 if mid in self.graph:
388 if mid in self.graph:
385 self.graph[mid].remove(msg_id)
389 self.graph[mid].remove(msg_id)
386
390
387 # FIXME: unpacking a message I've already unpacked, but didn't save:
391 # FIXME: unpacking a message I've already unpacked, but didn't save:
388 idents,msg = self.session.feed_identities(raw_msg, copy=False)
392 idents,msg = self.session.feed_identities(raw_msg, copy=False)
389 header = self.session.unpack(msg[1].bytes)
393 header = self.session.unpack(msg[1].bytes)
390
394
391 try:
395 try:
392 raise why()
396 raise why()
393 except:
397 except:
394 content = error.wrap_exception()
398 content = error.wrap_exception()
395
399
396 self.all_done.add(msg_id)
400 self.all_done.add(msg_id)
397 self.all_failed.add(msg_id)
401 self.all_failed.add(msg_id)
398
402
399 msg = self.session.send(self.client_stream, 'apply_reply', content,
403 msg = self.session.send(self.client_stream, 'apply_reply', content,
400 parent=header, ident=idents)
404 parent=header, ident=idents)
401 self.session.send(self.mon_stream, msg, ident=['outtask']+idents)
405 self.session.send(self.mon_stream, msg, ident=['outtask']+idents)
402
406
403 self.update_graph(msg_id, success=False)
407 self.update_graph(msg_id, success=False)
404
408
405 @logged
409 @logged
406 def maybe_run(self, msg_id, raw_msg, targets, after, follow, timeout):
410 def maybe_run(self, msg_id, raw_msg, targets, after, follow, timeout):
407 """check location dependencies, and run if they are met."""
411 """check location dependencies, and run if they are met."""
408 blacklist = self.blacklist.setdefault(msg_id, set())
412 blacklist = self.blacklist.setdefault(msg_id, set())
409 if follow or targets or blacklist or self.hwm:
413 if follow or targets or blacklist or self.hwm:
410 # we need a can_run filter
414 # we need a can_run filter
411 def can_run(idx):
415 def can_run(idx):
412 # check hwm
416 # check hwm
413 if self.hwm and self.loads[idx] == self.hwm:
417 if self.hwm and self.loads[idx] == self.hwm:
414 return False
418 return False
415 target = self.targets[idx]
419 target = self.targets[idx]
416 # check blacklist
420 # check blacklist
417 if target in blacklist:
421 if target in blacklist:
418 return False
422 return False
419 # check targets
423 # check targets
420 if targets and target not in targets:
424 if targets and target not in targets:
421 return False
425 return False
422 # check follow
426 # check follow
423 return follow.check(self.completed[target], self.failed[target])
427 return follow.check(self.completed[target], self.failed[target])
424
428
425 indices = filter(can_run, range(len(self.targets)))
429 indices = filter(can_run, range(len(self.targets)))
426
430
427 if not indices:
431 if not indices:
428 # couldn't run
432 # couldn't run
429 if follow.all:
433 if follow.all:
430 # check follow for impossibility
434 # check follow for impossibility
431 dests = set()
435 dests = set()
432 relevant = set()
436 relevant = set()
433 if follow.success:
437 if follow.success:
434 relevant = self.all_completed
438 relevant = self.all_completed
435 if follow.failure:
439 if follow.failure:
436 relevant = relevant.union(self.all_failed)
440 relevant = relevant.union(self.all_failed)
437 for m in follow.intersection(relevant):
441 for m in follow.intersection(relevant):
438 dests.add(self.destinations[m])
442 dests.add(self.destinations[m])
439 if len(dests) > 1:
443 if len(dests) > 1:
440 self.depending[msg_id] = (raw_msg, targets, after, follow, timeout)
444 self.depending[msg_id] = (raw_msg, targets, after, follow, timeout)
441 self.fail_unreachable(msg_id)
445 self.fail_unreachable(msg_id)
442 return False
446 return False
443 if targets:
447 if targets:
444 # check blacklist+targets for impossibility
448 # check blacklist+targets for impossibility
445 targets.difference_update(blacklist)
449 targets.difference_update(blacklist)
446 if not targets or not targets.intersection(self.targets):
450 if not targets or not targets.intersection(self.targets):
447 self.depending[msg_id] = (raw_msg, targets, after, follow, timeout)
451 self.depending[msg_id] = (raw_msg, targets, after, follow, timeout)
448 self.fail_unreachable(msg_id)
452 self.fail_unreachable(msg_id)
449 return False
453 return False
450 return False
454 return False
451 else:
455 else:
452 indices = None
456 indices = None
453
457
454 self.submit_task(msg_id, raw_msg, targets, follow, timeout, indices)
458 self.submit_task(msg_id, raw_msg, targets, follow, timeout, indices)
455 return True
459 return True
456
460
457 @logged
461 @logged
458 def save_unmet(self, msg_id, raw_msg, targets, after, follow, timeout):
462 def save_unmet(self, msg_id, raw_msg, targets, after, follow, timeout):
459 """Save a message for later submission when its dependencies are met."""
463 """Save a message for later submission when its dependencies are met."""
460 self.depending[msg_id] = [raw_msg,targets,after,follow,timeout]
464 self.depending[msg_id] = [raw_msg,targets,after,follow,timeout]
461 # track the ids in follow or after, but not those already finished
465 # track the ids in follow or after, but not those already finished
462 for dep_id in after.union(follow).difference(self.all_done):
466 for dep_id in after.union(follow).difference(self.all_done):
463 if dep_id not in self.graph:
467 if dep_id not in self.graph:
464 self.graph[dep_id] = set()
468 self.graph[dep_id] = set()
465 self.graph[dep_id].add(msg_id)
469 self.graph[dep_id].add(msg_id)
466
470
467 @logged
471 @logged
468 def submit_task(self, msg_id, raw_msg, targets, follow, timeout, indices=None):
472 def submit_task(self, msg_id, raw_msg, targets, follow, timeout, indices=None):
469 """Submit a task to any of a subset of our targets."""
473 """Submit a task to any of a subset of our targets."""
470 if indices:
474 if indices:
471 loads = [self.loads[i] for i in indices]
475 loads = [self.loads[i] for i in indices]
472 else:
476 else:
473 loads = self.loads
477 loads = self.loads
474 idx = self.scheme(loads)
478 idx = self.scheme(loads)
475 if indices:
479 if indices:
476 idx = indices[idx]
480 idx = indices[idx]
477 target = self.targets[idx]
481 target = self.targets[idx]
478 # print (target, map(str, msg[:3]))
482 # print (target, map(str, msg[:3]))
479 # send job to the engine
483 # send job to the engine
480 self.engine_stream.send(target, flags=zmq.SNDMORE, copy=False)
484 self.engine_stream.send(target, flags=zmq.SNDMORE, copy=False)
481 self.engine_stream.send_multipart(raw_msg, copy=False)
485 self.engine_stream.send_multipart(raw_msg, copy=False)
482 # update load
486 # update load
483 self.add_job(idx)
487 self.add_job(idx)
484 self.pending[target][msg_id] = (raw_msg, targets, MET, follow, timeout)
488 self.pending[target][msg_id] = (raw_msg, targets, MET, follow, timeout)
485 # notify Hub
489 # notify Hub
486 content = dict(msg_id=msg_id, engine_id=target)
490 content = dict(msg_id=msg_id, engine_id=target)
487 self.session.send(self.mon_stream, 'task_destination', content=content,
491 self.session.send(self.mon_stream, 'task_destination', content=content,
488 ident=['tracktask',self.session.session])
492 ident=['tracktask',self.session.session])
489
493
490
494
491 #-----------------------------------------------------------------------
495 #-----------------------------------------------------------------------
492 # Result Handling
496 # Result Handling
493 #-----------------------------------------------------------------------
497 #-----------------------------------------------------------------------
494 @logged
498 @logged
495 def dispatch_result(self, raw_msg):
499 def dispatch_result(self, raw_msg):
496 """dispatch method for result replies"""
500 """dispatch method for result replies"""
497 try:
501 try:
498 idents,msg = self.session.feed_identities(raw_msg, copy=False)
502 idents,msg = self.session.feed_identities(raw_msg, copy=False)
499 msg = self.session.unpack_message(msg, content=False, copy=False)
503 msg = self.session.unpack_message(msg, content=False, copy=False)
500 engine = idents[0]
504 engine = idents[0]
501 try:
505 try:
502 idx = self.targets.index(engine)
506 idx = self.targets.index(engine)
503 except ValueError:
507 except ValueError:
504 pass # skip load-update for dead engines
508 pass # skip load-update for dead engines
505 else:
509 else:
506 self.finish_job(idx)
510 self.finish_job(idx)
507 except Exception:
511 except Exception:
508 self.log.error("task::Invaid result: %r"%raw_msg, exc_info=True)
512 self.log.error("task::Invaid result: %r"%raw_msg, exc_info=True)
509 return
513 return
510
514
511 header = msg['header']
515 header = msg['header']
512 parent = msg['parent_header']
516 parent = msg['parent_header']
513 if header.get('dependencies_met', True):
517 if header.get('dependencies_met', True):
514 success = (header['status'] == 'ok')
518 success = (header['status'] == 'ok')
515 msg_id = parent['msg_id']
519 msg_id = parent['msg_id']
516 retries = self.retries[msg_id]
520 retries = self.retries[msg_id]
517 if not success and retries > 0:
521 if not success and retries > 0:
518 # failed
522 # failed
519 self.retries[msg_id] = retries - 1
523 self.retries[msg_id] = retries - 1
520 self.handle_unmet_dependency(idents, parent)
524 self.handle_unmet_dependency(idents, parent)
521 else:
525 else:
522 del self.retries[msg_id]
526 del self.retries[msg_id]
523 # relay to client and update graph
527 # relay to client and update graph
524 self.handle_result(idents, parent, raw_msg, success)
528 self.handle_result(idents, parent, raw_msg, success)
525 # send to Hub monitor
529 # send to Hub monitor
526 self.mon_stream.send_multipart(['outtask']+raw_msg, copy=False)
530 self.mon_stream.send_multipart(['outtask']+raw_msg, copy=False)
527 else:
531 else:
528 self.handle_unmet_dependency(idents, parent)
532 self.handle_unmet_dependency(idents, parent)
529
533
530 @logged
534 @logged
531 def handle_result(self, idents, parent, raw_msg, success=True):
535 def handle_result(self, idents, parent, raw_msg, success=True):
532 """handle a real task result, either success or failure"""
536 """handle a real task result, either success or failure"""
533 # first, relay result to client
537 # first, relay result to client
534 engine = idents[0]
538 engine = idents[0]
535 client = idents[1]
539 client = idents[1]
536 # swap_ids for XREP-XREP mirror
540 # swap_ids for XREP-XREP mirror
537 raw_msg[:2] = [client,engine]
541 raw_msg[:2] = [client,engine]
538 # print (map(str, raw_msg[:4]))
542 # print (map(str, raw_msg[:4]))
539 self.client_stream.send_multipart(raw_msg, copy=False)
543 self.client_stream.send_multipart(raw_msg, copy=False)
540 # now, update our data structures
544 # now, update our data structures
541 msg_id = parent['msg_id']
545 msg_id = parent['msg_id']
542 self.blacklist.pop(msg_id, None)
546 self.blacklist.pop(msg_id, None)
543 self.pending[engine].pop(msg_id)
547 self.pending[engine].pop(msg_id)
544 if success:
548 if success:
545 self.completed[engine].add(msg_id)
549 self.completed[engine].add(msg_id)
546 self.all_completed.add(msg_id)
550 self.all_completed.add(msg_id)
547 else:
551 else:
548 self.failed[engine].add(msg_id)
552 self.failed[engine].add(msg_id)
549 self.all_failed.add(msg_id)
553 self.all_failed.add(msg_id)
550 self.all_done.add(msg_id)
554 self.all_done.add(msg_id)
551 self.destinations[msg_id] = engine
555 self.destinations[msg_id] = engine
552
556
553 self.update_graph(msg_id, success)
557 self.update_graph(msg_id, success)
554
558
555 @logged
559 @logged
556 def handle_unmet_dependency(self, idents, parent):
560 def handle_unmet_dependency(self, idents, parent):
557 """handle an unmet dependency"""
561 """handle an unmet dependency"""
558 engine = idents[0]
562 engine = idents[0]
559 msg_id = parent['msg_id']
563 msg_id = parent['msg_id']
560
564
561 if msg_id not in self.blacklist:
565 if msg_id not in self.blacklist:
562 self.blacklist[msg_id] = set()
566 self.blacklist[msg_id] = set()
563 self.blacklist[msg_id].add(engine)
567 self.blacklist[msg_id].add(engine)
564
568
565 args = self.pending[engine].pop(msg_id)
569 args = self.pending[engine].pop(msg_id)
566 raw,targets,after,follow,timeout = args
570 raw,targets,after,follow,timeout = args
567
571
568 if self.blacklist[msg_id] == targets:
572 if self.blacklist[msg_id] == targets:
569 self.depending[msg_id] = args
573 self.depending[msg_id] = args
570 self.fail_unreachable(msg_id)
574 self.fail_unreachable(msg_id)
571 elif not self.maybe_run(msg_id, *args):
575 elif not self.maybe_run(msg_id, *args):
572 # resubmit failed
576 # resubmit failed
573 if msg_id not in self.all_failed:
577 if msg_id not in self.all_failed:
574 # put it back in our dependency tree
578 # put it back in our dependency tree
575 self.save_unmet(msg_id, *args)
579 self.save_unmet(msg_id, *args)
576
580
577 if self.hwm:
581 if self.hwm:
578 try:
582 try:
579 idx = self.targets.index(engine)
583 idx = self.targets.index(engine)
580 except ValueError:
584 except ValueError:
581 pass # skip load-update for dead engines
585 pass # skip load-update for dead engines
582 else:
586 else:
583 if self.loads[idx] == self.hwm-1:
587 if self.loads[idx] == self.hwm-1:
584 self.update_graph(None)
588 self.update_graph(None)
585
589
586
590
587
591
588 @logged
592 @logged
589 def update_graph(self, dep_id=None, success=True):
593 def update_graph(self, dep_id=None, success=True):
590 """dep_id just finished. Update our dependency
594 """dep_id just finished. Update our dependency
591 graph and submit any jobs that just became runnable.
595 graph and submit any jobs that just became runnable.
592
596
593 Called with dep_id=None to update entire graph for hwm, but without finishing
597 Called with dep_id=None to update entire graph for hwm, but without finishing
594 a task.
598 a task.
595 """
599 """
596 # print ("\n\n***********")
600 # print ("\n\n***********")
597 # pprint (dep_id)
601 # pprint (dep_id)
598 # pprint (self.graph)
602 # pprint (self.graph)
599 # pprint (self.depending)
603 # pprint (self.depending)
600 # pprint (self.all_completed)
604 # pprint (self.all_completed)
601 # pprint (self.all_failed)
605 # pprint (self.all_failed)
602 # print ("\n\n***********\n\n")
606 # print ("\n\n***********\n\n")
603 # update any jobs that depended on the dependency
607 # update any jobs that depended on the dependency
604 jobs = self.graph.pop(dep_id, [])
608 jobs = self.graph.pop(dep_id, [])
605
609
606 # recheck *all* jobs if
610 # recheck *all* jobs if
606 # a) we have HWM and an engine just became no longer full
611 # a) we have HWM and an engine just became no longer full
608 # or b) dep_id was given as None
612 # or b) dep_id was given as None
609 if dep_id is None or (self.hwm and any(load == self.hwm-1 for load in self.loads)):
613 if dep_id is None or (self.hwm and any(load == self.hwm-1 for load in self.loads)):
610 jobs = self.depending.keys()
614 jobs = self.depending.keys()
611
615
612 for msg_id in jobs:
616 for msg_id in jobs:
613 raw_msg, targets, after, follow, timeout = self.depending[msg_id]
617 raw_msg, targets, after, follow, timeout = self.depending[msg_id]
614
618
615 if after.unreachable(self.all_completed, self.all_failed)\
619 if after.unreachable(self.all_completed, self.all_failed)\
616 or follow.unreachable(self.all_completed, self.all_failed):
620 or follow.unreachable(self.all_completed, self.all_failed):
617 self.fail_unreachable(msg_id)
621 self.fail_unreachable(msg_id)
618
622
619 elif after.check(self.all_completed, self.all_failed): # time deps met, maybe run
623 elif after.check(self.all_completed, self.all_failed): # time deps met, maybe run
620 if self.maybe_run(msg_id, raw_msg, targets, MET, follow, timeout):
624 if self.maybe_run(msg_id, raw_msg, targets, MET, follow, timeout):
621
625
622 self.depending.pop(msg_id)
626 self.depending.pop(msg_id)
623 for mid in follow.union(after):
627 for mid in follow.union(after):
624 if mid in self.graph:
628 if mid in self.graph:
625 self.graph[mid].remove(msg_id)
629 self.graph[mid].remove(msg_id)
626
630
627 #----------------------------------------------------------------------
631 #----------------------------------------------------------------------
628 # methods to be overridden by subclasses
632 # methods to be overridden by subclasses
629 #----------------------------------------------------------------------
633 #----------------------------------------------------------------------
630
634
631 def add_job(self, idx):
635 def add_job(self, idx):
632 """Called after self.targets[idx] just got the job with header.
636 """Called after self.targets[idx] just got the job with header.
633 Override with subclasses. The default ordering is simple LRU.
637 Override with subclasses. The default ordering is simple LRU.
634 The default loads are the number of outstanding jobs."""
638 The default loads are the number of outstanding jobs."""
635 self.loads[idx] += 1
639 self.loads[idx] += 1
636 for lis in (self.targets, self.loads):
640 for lis in (self.targets, self.loads):
637 lis.append(lis.pop(idx))
641 lis.append(lis.pop(idx))
638
642
639
643
640 def finish_job(self, idx):
644 def finish_job(self, idx):
641 """Called after self.targets[idx] just finished a job.
645 """Called after self.targets[idx] just finished a job.
642 Override in subclasses."""
646 Override in subclasses."""
643 self.loads[idx] -= 1
647 self.loads[idx] -= 1
644
648
645
649
646
650
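add_job and finish_job are the whole scheduling-policy surface. A hypothetical subclass (not part of this changeset) that keeps engines sorted by ascending load instead of rotating LRU, assuming, as the LRU default implies, that the scheduler prefers engines near the front of self.targets:

    class LeastLoadScheduler(TaskScheduler):
        """Hypothetical policy: keep targets ordered by ascending load."""

        def add_job(self, idx):
            self.loads[idx] += 1
            self._resort()

        def finish_job(self, idx):
            self.loads[idx] -= 1
            self._resort()

        def _resort(self):
            # re-pair each engine with its load and sort by load, keeping
            # the two parallel lists (targets, loads) in sync
            pairs = sorted(zip(self.loads, self.targets))
            self.loads = [load for load, engine in pairs]
            self.targets = [engine for load, engine in pairs]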
647 def launch_scheduler(in_addr, out_addr, mon_addr, not_addr, config=None,
651 def launch_scheduler(in_addr, out_addr, mon_addr, not_addr, config=None,
648 logname='root', log_url=None, loglevel=logging.DEBUG,
652 logname='root', log_url=None, loglevel=logging.DEBUG,
649 identity=b'task'):
653 identity=b'task'):
650 from zmq.eventloop import ioloop
654 from zmq.eventloop import ioloop
651 from zmq.eventloop.zmqstream import ZMQStream
655 from zmq.eventloop.zmqstream import ZMQStream
652
656
653 if config:
657 if config:
654 # unwrap dict back into Config
658 # unwrap dict back into Config
655 config = Config(config)
659 config = Config(config)
656
660
657 ctx = zmq.Context()
661 ctx = zmq.Context()
658 loop = ioloop.IOLoop()
662 loop = ioloop.IOLoop()
659 ins = ZMQStream(ctx.socket(zmq.XREP),loop)
663 ins = ZMQStream(ctx.socket(zmq.XREP),loop)
660 ins.setsockopt(zmq.IDENTITY, identity)
664 ins.setsockopt(zmq.IDENTITY, identity)
661 ins.bind(in_addr)
665 ins.bind(in_addr)
662
666
663 outs = ZMQStream(ctx.socket(zmq.XREP),loop)
667 outs = ZMQStream(ctx.socket(zmq.XREP),loop)
664 outs.setsockopt(zmq.IDENTITY, identity)
668 outs.setsockopt(zmq.IDENTITY, identity)
665 outs.bind(out_addr)
669 outs.bind(out_addr)
666 mons = ZMQStream(ctx.socket(zmq.PUB),loop)
670 mons = ZMQStream(ctx.socket(zmq.PUB),loop)
667 mons.connect(mon_addr)
671 mons.connect(mon_addr)
668 nots = ZMQStream(ctx.socket(zmq.SUB),loop)
672 nots = ZMQStream(ctx.socket(zmq.SUB),loop)
669 nots.setsockopt(zmq.SUBSCRIBE, '')
673 nots.setsockopt(zmq.SUBSCRIBE, '')
670 nots.connect(not_addr)
674 nots.connect(not_addr)
671
675
672 # setup logging. Note that these will not work in-process, because they clobber
676 # setup logging. Note that these will not work in-process, because they clobber
673 # existing loggers.
677 # existing loggers.
674 if log_url:
678 if log_url:
675 log = connect_logger(logname, ctx, log_url, root="scheduler", loglevel=loglevel)
679 log = connect_logger(logname, ctx, log_url, root="scheduler", loglevel=loglevel)
676 else:
680 else:
677 log = local_logger(logname, loglevel)
681 log = local_logger(logname, loglevel)
678
682
679 scheduler = TaskScheduler(client_stream=ins, engine_stream=outs,
683 scheduler = TaskScheduler(client_stream=ins, engine_stream=outs,
680 mon_stream=mons, notifier_stream=nots,
684 mon_stream=mons, notifier_stream=nots,
681 loop=loop, log=log,
685 loop=loop, log=log,
682 config=config)
686 config=config)
683 scheduler.start()
687 scheduler.start()
684 try:
688 try:
685 loop.start()
689 loop.start()
686 except KeyboardInterrupt:
690 except KeyboardInterrupt:
687 print ("interrupted, exiting...", file=sys.__stderr__)
691 print ("interrupted, exiting...", file=sys.__stderr__)
688
692
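Since launch_scheduler blocks in loop.start() and clobbers existing loggers, it is meant to run in a dedicated process. A hedged launch sketch; all four addresses are placeholders that would normally come from the controller's configuration:

    from multiprocessing import Process

    p = Process(target=launch_scheduler,
                args=('tcp://127.0.0.1:5555',    # in_addr: client-facing XREP
                      'tcp://127.0.0.1:5556',    # out_addr: engine-facing XREP
                      'tcp://127.0.0.1:5557',    # mon_addr: monitor PUB
                      'tcp://127.0.0.1:5558'),   # not_addr: notification SUB
                kwargs=dict(logname='scheduler', identity=b'task'))
    p.start()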
@@ -1,386 +1,391 b''
1 """A TaskRecord backend using sqlite3"""
1 """A TaskRecord backend using sqlite3
2
3 Authors:
4
5 * Min RK
6 """
2 #-----------------------------------------------------------------------------
7 #-----------------------------------------------------------------------------
3 # Copyright (C) 2011 The IPython Development Team
8 # Copyright (C) 2011 The IPython Development Team
4 #
9 #
5 # Distributed under the terms of the BSD License. The full license is in
10 # Distributed under the terms of the BSD License. The full license is in
6 # the file COPYING, distributed as part of this software.
11 # the file COPYING, distributed as part of this software.
7 #-----------------------------------------------------------------------------
12 #-----------------------------------------------------------------------------
8
13
9 import json
14 import json
10 import os
15 import os
11 import cPickle as pickle
16 import cPickle as pickle
12 from datetime import datetime
17 from datetime import datetime
13
18
14 import sqlite3
19 import sqlite3
15
20
16 from zmq.eventloop import ioloop
21 from zmq.eventloop import ioloop
17
22
18 from IPython.utils.traitlets import Unicode, Instance, List, Dict
23 from IPython.utils.traitlets import Unicode, Instance, List, Dict
19 from .dictdb import BaseDB
24 from .dictdb import BaseDB
20 from IPython.utils.jsonutil import date_default, extract_dates, squash_dates
25 from IPython.utils.jsonutil import date_default, extract_dates, squash_dates
21
26
22 #-----------------------------------------------------------------------------
27 #-----------------------------------------------------------------------------
23 # SQLite operators, adapters, and converters
28 # SQLite operators, adapters, and converters
24 #-----------------------------------------------------------------------------
29 #-----------------------------------------------------------------------------
25
30
26 operators = {
31 operators = {
27 '$lt' : "<",
32 '$lt' : "<",
28 '$gt' : ">",
33 '$gt' : ">",
29 # NULL needs special handling with ==/!= (see null_operators below)
34 # NULL needs special handling with ==/!= (see null_operators below)
30 '$eq' : "=",
35 '$eq' : "=",
31 '$ne' : "!=",
36 '$ne' : "!=",
32 '$lte': "<=",
37 '$lte': "<=",
33 '$gte': ">=",
38 '$gte': ">=",
34 '$in' : ('=', ' OR '),
39 '$in' : ('=', ' OR '),
35 '$nin': ('!=', ' AND '),
40 '$nin': ('!=', ' AND '),
36 # '$all': None,
41 # '$all': None,
37 # '$mod': None,
42 # '$mod': None,
38 # '$exists' : None
43 # '$exists' : None
39 }
44 }
40 null_operators = {
45 null_operators = {
41 '=' : "IS NULL",
46 '=' : "IS NULL",
42 '!=' : "IS NOT NULL",
47 '!=' : "IS NOT NULL",
43 }
48 }
44
49
45 def _adapt_dict(d):
50 def _adapt_dict(d):
46 return json.dumps(d, default=date_default)
51 return json.dumps(d, default=date_default)
47
52
48 def _convert_dict(ds):
53 def _convert_dict(ds):
49 if ds is None:
54 if ds is None:
50 return ds
55 return ds
51 else:
56 else:
52 return extract_dates(json.loads(ds))
57 return extract_dates(json.loads(ds))
53
58
54 def _adapt_bufs(bufs):
59 def _adapt_bufs(bufs):
55 # this is *horrible*
60 # this is *horrible*
56 # copy buffers into single list and pickle it:
61 # copy buffers into single list and pickle it:
57 if bufs and isinstance(bufs[0], (bytes, buffer)):
62 if bufs and isinstance(bufs[0], (bytes, buffer)):
58 return sqlite3.Binary(pickle.dumps(map(bytes, bufs),-1))
63 return sqlite3.Binary(pickle.dumps(map(bytes, bufs),-1))
59 elif bufs:
64 elif bufs:
60 return bufs
65 return bufs
61 else:
66 else:
62 return None
67 return None
63
68
64 def _convert_bufs(bs):
69 def _convert_bufs(bs):
65 if bs is None:
70 if bs is None:
66 return []
71 return []
67 else:
72 else:
68 return pickle.loads(bytes(bs))
73 return pickle.loads(bytes(bs))
69
74
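Each adapter/converter pair above is meant to be a round trip, with date_default and extract_dates (imported from IPython.utils.jsonutil for exactly this) carrying datetimes through JSON. A small sanity check, assuming those helpers behave as their names suggest:

    from datetime import datetime

    rec = {'msg_id': 'abc', 'submitted': datetime(2011, 5, 1, 12, 30)}
    text = _adapt_dict(rec)       # dict -> JSON text, datetimes ISO-encoded
    back = _convert_dict(text)    # JSON text -> dict, datetimes re-extracted
    assert back['submitted'] == rec['submitted']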
70 #-----------------------------------------------------------------------------
75 #-----------------------------------------------------------------------------
71 # SQLiteDB class
76 # SQLiteDB class
72 #-----------------------------------------------------------------------------
77 #-----------------------------------------------------------------------------
73
78
74 class SQLiteDB(BaseDB):
79 class SQLiteDB(BaseDB):
75 """SQLite3 TaskRecord backend."""
80 """SQLite3 TaskRecord backend."""
76
81
77 filename = Unicode('tasks.db', config=True,
82 filename = Unicode('tasks.db', config=True,
78 help="""The filename of the sqlite task database. [default: 'tasks.db']""")
83 help="""The filename of the sqlite task database. [default: 'tasks.db']""")
79 location = Unicode('', config=True,
84 location = Unicode('', config=True,
80 help="""The directory containing the sqlite task database. The default
85 help="""The directory containing the sqlite task database. The default
81 is to use the cluster_dir location.""")
86 is to use the cluster_dir location.""")
82 table = Unicode("", config=True,
87 table = Unicode("", config=True,
83 help="""The SQLite Table to use for storing tasks for this session. If unspecified,
88 help="""The SQLite Table to use for storing tasks for this session. If unspecified,
84 a new table will be created with the Hub's IDENT. Specifying the table will result
89 a new table will be created with the Hub's IDENT. Specifying the table will result
85 in tasks from previous sessions being available via Clients' db_query and
90 in tasks from previous sessions being available via Clients' db_query and
86 get_result methods.""")
91 get_result methods.""")
87
92
88 _db = Instance('sqlite3.Connection')
93 _db = Instance('sqlite3.Connection')
89 # the ordered list of column names
94 # the ordered list of column names
90 _keys = List(['msg_id' ,
95 _keys = List(['msg_id' ,
91 'header' ,
96 'header' ,
92 'content',
97 'content',
93 'buffers',
98 'buffers',
94 'submitted',
99 'submitted',
95 'client_uuid' ,
100 'client_uuid' ,
96 'engine_uuid' ,
101 'engine_uuid' ,
97 'started',
102 'started',
98 'completed',
103 'completed',
99 'resubmitted',
104 'resubmitted',
100 'result_header' ,
105 'result_header' ,
101 'result_content' ,
106 'result_content' ,
102 'result_buffers' ,
107 'result_buffers' ,
103 'queue' ,
108 'queue' ,
104 'pyin' ,
109 'pyin' ,
105 'pyout',
110 'pyout',
106 'pyerr',
111 'pyerr',
107 'stdout',
112 'stdout',
108 'stderr',
113 'stderr',
109 ])
114 ])
110 # sqlite datatypes for checking that db is current format
115 # sqlite datatypes for checking that db is current format
111 _types = Dict({'msg_id' : 'text' ,
116 _types = Dict({'msg_id' : 'text' ,
112 'header' : 'dict text',
117 'header' : 'dict text',
113 'content' : 'dict text',
118 'content' : 'dict text',
114 'buffers' : 'bufs blob',
119 'buffers' : 'bufs blob',
115 'submitted' : 'timestamp',
120 'submitted' : 'timestamp',
116 'client_uuid' : 'text',
121 'client_uuid' : 'text',
117 'engine_uuid' : 'text',
122 'engine_uuid' : 'text',
118 'started' : 'timestamp',
123 'started' : 'timestamp',
119 'completed' : 'timestamp',
124 'completed' : 'timestamp',
120 'resubmitted' : 'timestamp',
125 'resubmitted' : 'timestamp',
121 'result_header' : 'dict text',
126 'result_header' : 'dict text',
122 'result_content' : 'dict text',
127 'result_content' : 'dict text',
123 'result_buffers' : 'bufs blob',
128 'result_buffers' : 'bufs blob',
124 'queue' : 'text',
129 'queue' : 'text',
125 'pyin' : 'text',
130 'pyin' : 'text',
126 'pyout' : 'text',
131 'pyout' : 'text',
127 'pyerr' : 'text',
132 'pyerr' : 'text',
128 'stdout' : 'text',
133 'stdout' : 'text',
129 'stderr' : 'text',
134 'stderr' : 'text',
130 })
135 })
131
136
132 def __init__(self, **kwargs):
137 def __init__(self, **kwargs):
133 super(SQLiteDB, self).__init__(**kwargs)
138 super(SQLiteDB, self).__init__(**kwargs)
134 if not self.table:
139 if not self.table:
135 # use the session id, prefixed with _, since a table name may not start with a digit
140 # use the session id, prefixed with _, since a table name may not start with a digit
136 self.table = '_'+self.session.replace('-','_')
141 self.table = '_'+self.session.replace('-','_')
137 if not self.location:
142 if not self.location:
138 # get current profile
143 # get current profile
139 from IPython.core.newapplication import BaseIPythonApplication
144 from IPython.core.newapplication import BaseIPythonApplication
140 if BaseIPythonApplication.initialized():
145 if BaseIPythonApplication.initialized():
141 app = BaseIPythonApplication.instance()
146 app = BaseIPythonApplication.instance()
142 if app.profile_dir is not None:
147 if app.profile_dir is not None:
143 self.location = app.profile_dir.location
148 self.location = app.profile_dir.location
144 else:
149 else:
145 self.location = u'.'
150 self.location = u'.'
146 else:
151 else:
147 self.location = u'.'
152 self.location = u'.'
148 self._init_db()
153 self._init_db()
149
154
150 # register db commit as 2s periodic callback
155 # register db commit as 2s periodic callback
151 # to prevent clogging pipes
156 # to prevent clogging pipes
152 # assumes we are being run in a zmq ioloop app
157 # assumes we are being run in a zmq ioloop app
153 loop = ioloop.IOLoop.instance()
158 loop = ioloop.IOLoop.instance()
154 pc = ioloop.PeriodicCallback(self._db.commit, 2000, loop)
159 pc = ioloop.PeriodicCallback(self._db.commit, 2000, loop)
155 pc.start()
160 pc.start()
156
161
157 def _defaults(self, keys=None):
162 def _defaults(self, keys=None):
158 """create an empty record"""
163 """create an empty record"""
159 d = {}
164 d = {}
160 keys = self._keys if keys is None else keys
165 keys = self._keys if keys is None else keys
161 for key in keys:
166 for key in keys:
162 d[key] = None
167 d[key] = None
163 return d
168 return d
164
169
165 def _check_table(self):
170 def _check_table(self):
166 """Ensure that an incorrect table doesn't exist
171 """Ensure that an incorrect table doesn't exist
167
172
168 If a bad (old) table does exist, return False
173 If a bad (old) table does exist, return False
169 """
174 """
170 cursor = self._db.execute("PRAGMA table_info(%s)"%self.table)
175 cursor = self._db.execute("PRAGMA table_info(%s)"%self.table)
171 lines = cursor.fetchall()
176 lines = cursor.fetchall()
172 if not lines:
177 if not lines:
173 # table does not exist
178 # table does not exist
174 return True
179 return True
175 types = {}
180 types = {}
176 keys = []
181 keys = []
177 for line in lines:
182 for line in lines:
178 keys.append(line[1])
183 keys.append(line[1])
179 types[line[1]] = line[2]
184 types[line[1]] = line[2]
180 if self._keys != keys:
185 if self._keys != keys:
181 # key mismatch
186 # key mismatch
182 self.log.warn('keys mismatch')
187 self.log.warn('keys mismatch')
183 return False
188 return False
184 for key in self._keys:
189 for key in self._keys:
185 if types[key] != self._types[key]:
190 if types[key] != self._types[key]:
186 self.log.warn(
191 self.log.warn(
187 'type mismatch: %s: %s != %s'%(key,types[key],self._types[key])
192 'type mismatch: %s: %s != %s'%(key,types[key],self._types[key])
188 )
193 )
189 return False
194 return False
190 return True
195 return True
191
196
192 def _init_db(self):
197 def _init_db(self):
193 """Connect to the database and get new session number."""
198 """Connect to the database and get new session number."""
194 # register adapters
199 # register adapters
195 sqlite3.register_adapter(dict, _adapt_dict)
200 sqlite3.register_adapter(dict, _adapt_dict)
196 sqlite3.register_converter('dict', _convert_dict)
201 sqlite3.register_converter('dict', _convert_dict)
197 sqlite3.register_adapter(list, _adapt_bufs)
202 sqlite3.register_adapter(list, _adapt_bufs)
198 sqlite3.register_converter('bufs', _convert_bufs)
203 sqlite3.register_converter('bufs', _convert_bufs)
199 # connect to the db
204 # connect to the db
200 dbfile = os.path.join(self.location, self.filename)
205 dbfile = os.path.join(self.location, self.filename)
201 self._db = sqlite3.connect(dbfile, detect_types=sqlite3.PARSE_DECLTYPES,
206 self._db = sqlite3.connect(dbfile, detect_types=sqlite3.PARSE_DECLTYPES,
202 # isolation_level = None)#,
207 # isolation_level = None)#,
203 cached_statements=64)
208 cached_statements=64)
204 # print dir(self._db)
209 # print dir(self._db)
205 first_table = self.table
210 first_table = self.table
206 i=0
211 i=0
207 while not self._check_table():
212 while not self._check_table():
208 i+=1
213 i+=1
209 self.table = first_table+'_%i'%i
214 self.table = first_table+'_%i'%i
210 self.log.warn(
215 self.log.warn(
211 "Table %s exists and doesn't match db format, trying %s"%
216 "Table %s exists and doesn't match db format, trying %s"%
212 (first_table,self.table)
217 (first_table,self.table)
213 )
218 )
214
219
215 self._db.execute("""CREATE TABLE IF NOT EXISTS %s
220 self._db.execute("""CREATE TABLE IF NOT EXISTS %s
216 (msg_id text PRIMARY KEY,
221 (msg_id text PRIMARY KEY,
217 header dict text,
222 header dict text,
218 content dict text,
223 content dict text,
219 buffers bufs blob,
224 buffers bufs blob,
220 submitted timestamp,
225 submitted timestamp,
221 client_uuid text,
226 client_uuid text,
222 engine_uuid text,
227 engine_uuid text,
223 started timestamp,
228 started timestamp,
224 completed timestamp,
229 completed timestamp,
225 resubmitted timestamp,
230 resubmitted timestamp,
226 result_header dict text,
231 result_header dict text,
227 result_content dict text,
232 result_content dict text,
228 result_buffers bufs blob,
233 result_buffers bufs blob,
229 queue text,
234 queue text,
230 pyin text,
235 pyin text,
231 pyout text,
236 pyout text,
232 pyerr text,
237 pyerr text,
233 stdout text,
238 stdout text,
234 stderr text)
239 stderr text)
235 """%self.table)
240 """%self.table)
236 self._db.commit()
241 self._db.commit()
237
242
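The made-up column types in the CREATE TABLE above ('dict text', 'bufs blob') work because the connection was opened with detect_types=sqlite3.PARSE_DECLTYPES: sqlite3 takes the first word of a column's declared type and looks it up among the registered converters. A self-contained demonstration of just that mechanism:

    import json
    import sqlite3

    sqlite3.register_adapter(dict, json.dumps)
    sqlite3.register_converter('dict', json.loads)

    db = sqlite3.connect(':memory:', detect_types=sqlite3.PARSE_DECLTYPES)
    # 'dict' (first word of the declared type) selects the converter
    db.execute("CREATE TABLE t (header dict text)")
    db.execute("INSERT INTO t VALUES (?)", ({'a': 1},))
    print(db.execute("SELECT header FROM t").fetchone()[0])  # {u'a': 1}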
238 def _dict_to_list(self, d):
243 def _dict_to_list(self, d):
239 """turn a mongodb-style record dict into a list."""
244 """turn a mongodb-style record dict into a list."""
240
245
241 return [ d[key] for key in self._keys ]
246 return [ d[key] for key in self._keys ]
242
247
243 def _list_to_dict(self, line, keys=None):
248 def _list_to_dict(self, line, keys=None):
244 """Inverse of dict_to_list"""
249 """Inverse of dict_to_list"""
245 keys = self._keys if keys is None else keys
250 keys = self._keys if keys is None else keys
246 d = self._defaults(keys)
251 d = self._defaults(keys)
247 for key,value in zip(keys, line):
252 for key,value in zip(keys, line):
248 d[key] = value
253 d[key] = value
249
254
250 return d
255 return d
251
256
252 def _render_expression(self, check):
257 def _render_expression(self, check):
253 """Turn a mongodb-style search dict into an SQL query."""
258 """Turn a mongodb-style search dict into an SQL query."""
254 expressions = []
259 expressions = []
255 args = []
260 args = []
256
261
257 skeys = set(check.keys())
262 skeys = set(check.keys())
258 skeys.difference_update(set(self._keys))
263 skeys.difference_update(set(self._keys))
259 skeys.difference_update(set(['buffers', 'result_buffers']))
264 skeys.difference_update(set(['buffers', 'result_buffers']))
260 if skeys:
265 if skeys:
261 raise KeyError("Illegal testing key(s): %s"%skeys)
266 raise KeyError("Illegal testing key(s): %s"%skeys)
262
267
263 for name,sub_check in check.iteritems():
268 for name,sub_check in check.iteritems():
264 if isinstance(sub_check, dict):
269 if isinstance(sub_check, dict):
265 for test,value in sub_check.iteritems():
270 for test,value in sub_check.iteritems():
266 try:
271 try:
267 op = operators[test]
272 op = operators[test]
268 except KeyError:
273 except KeyError:
269 raise KeyError("Unsupported operator: %r"%test)
274 raise KeyError("Unsupported operator: %r"%test)
270 if isinstance(op, tuple):
275 if isinstance(op, tuple):
271 op, join = op
276 op, join = op
272
277
273 if value is None and op in null_operators:
278 if value is None and op in null_operators:
274 expr = "%s %s"%(name, null_operators[op])
279 expr = "%s %s"%(name, null_operators[op])
275 else:
280 else:
276 expr = "%s %s ?"%(name, op)
281 expr = "%s %s ?"%(name, op)
277 if isinstance(value, (tuple,list)):
282 if isinstance(value, (tuple,list)):
278 if op in null_operators and any([v is None for v in value]):
283 if op in null_operators and any([v is None for v in value]):
279 # equality tests don't work with NULL
284 # equality tests don't work with NULL
280 raise ValueError("Cannot use %r test with NULL values on SQLite backend"%test)
285 raise ValueError("Cannot use %r test with NULL values on SQLite backend"%test)
281 expr = '( %s )'%( join.join([expr]*len(value)) )
286 expr = '( %s )'%( join.join([expr]*len(value)) )
282 args.extend(value)
287 args.extend(value)
283 else:
288 else:
284 args.append(value)
289 args.append(value)
285 expressions.append(expr)
290 expressions.append(expr)
286 else:
291 else:
287 # it's an equality check
292 # it's an equality check
288 if sub_check is None:
293 if sub_check is None:
289 expressions.append("%s IS NULL"%name)
294 expressions.append("%s IS NULL"%name)
290 else:
295 else:
291 expressions.append("%s = ?"%name)
296 expressions.append("%s = ?"%name)
292 args.append(sub_check)
297 args.append(sub_check)
293
298
294 expr = " AND ".join(expressions)
299 expr = " AND ".join(expressions)
295 return expr, args
300 return expr, args
296
301
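Putting the operator tables and this method together, a mongodb-style check compiles to a parameterized WHERE fragment. For example (hedged sketch, with db standing in for a SQLiteDB instance; clause order follows dict iteration order):

    from datetime import datetime

    expr, args = db._render_expression(
        {'engine_uuid': 'abc', 'completed': {'$gt': datetime(2011, 1, 1)}})
    # expr == "engine_uuid = ? AND completed > ?"   (or the two clauses swapped)
    # args == ['abc', datetime(2011, 1, 1)]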
297 def add_record(self, msg_id, rec):
302 def add_record(self, msg_id, rec):
298 """Add a new Task Record, by msg_id."""
303 """Add a new Task Record, by msg_id."""
299 d = self._defaults()
304 d = self._defaults()
300 d.update(rec)
305 d.update(rec)
301 d['msg_id'] = msg_id
306 d['msg_id'] = msg_id
302 line = self._dict_to_list(d)
307 line = self._dict_to_list(d)
303 tups = '(%s)'%(','.join(['?']*len(line)))
308 tups = '(%s)'%(','.join(['?']*len(line)))
304 self._db.execute("INSERT INTO %s VALUES %s"%(self.table, tups), line)
309 self._db.execute("INSERT INTO %s VALUES %s"%(self.table, tups), line)
305 # self._db.commit()
310 # self._db.commit()
306
311
307 def get_record(self, msg_id):
312 def get_record(self, msg_id):
308 """Get a specific Task Record, by msg_id."""
313 """Get a specific Task Record, by msg_id."""
309 cursor = self._db.execute("""SELECT * FROM %s WHERE msg_id==?"""%self.table, (msg_id,))
314 cursor = self._db.execute("""SELECT * FROM %s WHERE msg_id==?"""%self.table, (msg_id,))
310 line = cursor.fetchone()
315 line = cursor.fetchone()
311 if line is None:
316 if line is None:
312 raise KeyError("No such msg: %r"%msg_id)
317 raise KeyError("No such msg: %r"%msg_id)
313 return self._list_to_dict(line)
318 return self._list_to_dict(line)
314
319
315 def update_record(self, msg_id, rec):
320 def update_record(self, msg_id, rec):
316 """Update the data in an existing record."""
321 """Update the data in an existing record."""
317 query = "UPDATE %s SET "%self.table
322 query = "UPDATE %s SET "%self.table
318 sets = []
323 sets = []
319 keys = sorted(rec.keys())
324 keys = sorted(rec.keys())
320 values = []
325 values = []
321 for key in keys:
326 for key in keys:
322 sets.append('%s = ?'%key)
327 sets.append('%s = ?'%key)
323 values.append(rec[key])
328 values.append(rec[key])
324 query += ', '.join(sets)
329 query += ', '.join(sets)
325 query += ' WHERE msg_id == ?'
330 query += ' WHERE msg_id == ?'
326 values.append(msg_id)
331 values.append(msg_id)
327 self._db.execute(query, values)
332 self._db.execute(query, values)
328 # self._db.commit()
333 # self._db.commit()
329
334
330 def drop_record(self, msg_id):
335 def drop_record(self, msg_id):
331 """Remove a record from the DB."""
336 """Remove a record from the DB."""
332 self._db.execute("""DELETE FROM %s WHERE msg_id==?"""%self.table, (msg_id,))
337 self._db.execute("""DELETE FROM %s WHERE msg_id==?"""%self.table, (msg_id,))
333 # self._db.commit()
338 # self._db.commit()
334
339
335 def drop_matching_records(self, check):
340 def drop_matching_records(self, check):
336 """Remove a record from the DB."""
341 """Remove a record from the DB."""
337 expr,args = self._render_expression(check)
342 expr,args = self._render_expression(check)
338 query = "DELETE FROM %s WHERE %s"%(self.table, expr)
343 query = "DELETE FROM %s WHERE %s"%(self.table, expr)
339 self._db.execute(query,args)
344 self._db.execute(query,args)
340 # self._db.commit()
345 # self._db.commit()
341
346
342 def find_records(self, check, keys=None):
347 def find_records(self, check, keys=None):
343 """Find records matching a query dict, optionally extracting subset of keys.
348 """Find records matching a query dict, optionally extracting subset of keys.
344
349
345 Returns list of matching records.
350 Returns list of matching records.
346
351
347 Parameters
352 Parameters
348 ----------
353 ----------
349
354
350 check: dict
355 check: dict
351 mongodb-style query argument
356 mongodb-style query argument
352 keys: list of strs [optional]
357 keys: list of strs [optional]
353 if specified, the subset of keys to extract. msg_id will *always* be
358 if specified, the subset of keys to extract. msg_id will *always* be
354 included.
359 included.
355 """
360 """
356 if keys:
361 if keys:
357 bad_keys = [ key for key in keys if key not in self._keys ]
362 bad_keys = [ key for key in keys if key not in self._keys ]
358 if bad_keys:
363 if bad_keys:
359 raise KeyError("Bad record key(s): %s"%bad_keys)
364 raise KeyError("Bad record key(s): %s"%bad_keys)
360
365
361 if keys:
366 if keys:
362 # ensure msg_id is present and first:
367 # ensure msg_id is present and first:
363 if 'msg_id' in keys:
368 if 'msg_id' in keys:
364 keys.remove('msg_id')
369 keys.remove('msg_id')
365 keys.insert(0, 'msg_id')
370 keys.insert(0, 'msg_id')
366 req = ', '.join(keys)
371 req = ', '.join(keys)
367 else:
372 else:
368 req = '*'
373 req = '*'
369 expr,args = self._render_expression(check)
374 expr,args = self._render_expression(check)
370 query = """SELECT %s FROM %s WHERE %s"""%(req, self.table, expr)
375 query = """SELECT %s FROM %s WHERE %s"""%(req, self.table, expr)
371 cursor = self._db.execute(query, args)
376 cursor = self._db.execute(query, args)
372 matches = cursor.fetchall()
377 matches = cursor.fetchall()
373 records = []
378 records = []
374 for line in matches:
379 for line in matches:
375 rec = self._list_to_dict(line, keys)
380 rec = self._list_to_dict(line, keys)
376 records.append(rec)
381 records.append(rec)
377 return records
382 return records
378
383
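Typical use, sketched with a hypothetical db instance: fetch only a couple of columns for every task on one engine that raised an exception. Note that msg_id is injected automatically as the first key:

    recs = db.find_records({'engine_uuid': 'abc',
                            'pyerr': {'$ne': None}},   # -> "pyerr IS NOT NULL"
                           keys=['completed'])
    for rec in recs:
        print(rec['msg_id'], rec['completed'])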
379 def get_history(self):
384 def get_history(self):
380 """get all msg_ids, ordered by time submitted."""
385 """get all msg_ids, ordered by time submitted."""
381 query = """SELECT msg_id FROM %s ORDER by submitted ASC"""%self.table
386 query = """SELECT msg_id FROM %s ORDER by submitted ASC"""%self.table
382 cursor = self._db.execute(query)
387 cursor = self._db.execute(query)
383 # will be a list of length 1 tuples
388 # will be a list of length 1 tuples
384 return [ tup[0] for tup in cursor.fetchall()]
389 return [ tup[0] for tup in cursor.fetchall()]
385
390
386 __all__ = ['SQLiteDB'] No newline at end of file
391 __all__ = ['SQLiteDB']
@@ -1,166 +1,170 b''
1 #!/usr/bin/env python
1 #!/usr/bin/env python
2 """A simple engine that talks to a controller over 0MQ.
2 """A simple engine that talks to a controller over 0MQ.
3 It handles registration, etc., and launches a kernel
3 It handles registration, etc., and launches a kernel
4 connected to the Controller's Schedulers.
4 connected to the Controller's Schedulers.
5
6 Authors:
7
8 * Min RK
5 """
9 """
6 #-----------------------------------------------------------------------------
10 #-----------------------------------------------------------------------------
7 # Copyright (C) 2010-2011 The IPython Development Team
11 # Copyright (C) 2010-2011 The IPython Development Team
8 #
12 #
9 # Distributed under the terms of the BSD License. The full license is in
13 # Distributed under the terms of the BSD License. The full license is in
10 # the file COPYING, distributed as part of this software.
14 # the file COPYING, distributed as part of this software.
11 #-----------------------------------------------------------------------------
15 #-----------------------------------------------------------------------------
12
16
13 from __future__ import print_function
17 from __future__ import print_function
14
18
15 import sys
19 import sys
16 import time
20 import time
17
21
18 import zmq
22 import zmq
19 from zmq.eventloop import ioloop, zmqstream
23 from zmq.eventloop import ioloop, zmqstream
20
24
21 # internal
25 # internal
22 from IPython.utils.traitlets import Instance, Dict, Int, Type, CFloat, Unicode
26 from IPython.utils.traitlets import Instance, Dict, Int, Type, CFloat, Unicode
23 # from IPython.utils.localinterfaces import LOCALHOST
27 # from IPython.utils.localinterfaces import LOCALHOST
24
28
25 from IPython.parallel.controller.heartmonitor import Heart
29 from IPython.parallel.controller.heartmonitor import Heart
26 from IPython.parallel.factory import RegistrationFactory
30 from IPython.parallel.factory import RegistrationFactory
27 from IPython.parallel.util import disambiguate_url
31 from IPython.parallel.util import disambiguate_url
28
32
29 from IPython.zmq.session import Message
33 from IPython.zmq.session import Message
30
34
31 from .streamkernel import Kernel
35 from .streamkernel import Kernel
32
36
33 class EngineFactory(RegistrationFactory):
37 class EngineFactory(RegistrationFactory):
34 """IPython engine"""
38 """IPython engine"""
35
39
36 # configurables:
40 # configurables:
37 out_stream_factory=Type('IPython.zmq.iostream.OutStream', config=True,
41 out_stream_factory=Type('IPython.zmq.iostream.OutStream', config=True,
38 help="""The OutStream for handling stdout/err.
42 help="""The OutStream for handling stdout/err.
39 Typically 'IPython.zmq.iostream.OutStream'""")
43 Typically 'IPython.zmq.iostream.OutStream'""")
40 display_hook_factory=Type('IPython.zmq.displayhook.DisplayHook', config=True,
44 display_hook_factory=Type('IPython.zmq.displayhook.DisplayHook', config=True,
41 help="""The class for handling displayhook.
45 help="""The class for handling displayhook.
42 Typically 'IPython.zmq.displayhook.DisplayHook'""")
46 Typically 'IPython.zmq.displayhook.DisplayHook'""")
43 location=Unicode(config=True,
47 location=Unicode(config=True,
44 help="""The location (an IP address) of the controller. This is
48 help="""The location (an IP address) of the controller. This is
45 used for disambiguating URLs, to determine whether
49 used for disambiguating URLs, to determine whether
46 loopback should be used to connect or the public address.""")
50 loopback should be used to connect or the public address.""")
47 timeout=CFloat(2,config=True,
51 timeout=CFloat(2,config=True,
48 help="""The time (in seconds) to wait for the Controller to respond
52 help="""The time (in seconds) to wait for the Controller to respond
49 to registration requests before giving up.""")
53 to registration requests before giving up.""")
50
54
51 # not configurable:
55 # not configurable:
52 user_ns=Dict()
56 user_ns=Dict()
53 id=Int(allow_none=True)
57 id=Int(allow_none=True)
54 registrar=Instance('zmq.eventloop.zmqstream.ZMQStream')
58 registrar=Instance('zmq.eventloop.zmqstream.ZMQStream')
55 kernel=Instance(Kernel)
59 kernel=Instance(Kernel)
56
60
57
61
58 def __init__(self, **kwargs):
62 def __init__(self, **kwargs):
59 super(EngineFactory, self).__init__(**kwargs)
63 super(EngineFactory, self).__init__(**kwargs)
60 self.ident = self.session.session
64 self.ident = self.session.session
61 ctx = self.context
65 ctx = self.context
62
66
63 reg = ctx.socket(zmq.XREQ)
67 reg = ctx.socket(zmq.XREQ)
64 reg.setsockopt(zmq.IDENTITY, self.ident)
68 reg.setsockopt(zmq.IDENTITY, self.ident)
65 reg.connect(self.url)
69 reg.connect(self.url)
66 self.registrar = zmqstream.ZMQStream(reg, self.loop)
70 self.registrar = zmqstream.ZMQStream(reg, self.loop)
67
71
68 def register(self):
72 def register(self):
69 """send the registration_request"""
73 """send the registration_request"""
70
74
71 self.log.info("registering")
75 self.log.info("registering")
72 content = dict(queue=self.ident, heartbeat=self.ident, control=self.ident)
76 content = dict(queue=self.ident, heartbeat=self.ident, control=self.ident)
73 self.registrar.on_recv(self.complete_registration)
77 self.registrar.on_recv(self.complete_registration)
74 # print (self.session.key)
78 # print (self.session.key)
75 self.session.send(self.registrar, "registration_request",content=content)
79 self.session.send(self.registrar, "registration_request",content=content)
76
80
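For orientation, complete_registration below reads a reply of roughly this shape (a descriptive sketch reconstructed from the fields the handler accesses, not a literal wire dump; addresses are placeholders):

    reply_content = {
        'status'   : 'ok',
        'id'       : 0,                            # integer engine id
        'mux'      : 'tcp://127.0.0.1:5555',       # shell (MUX) queue
        'task'     : 'tcp://127.0.0.1:5556',       # task queue, may be empty
        'control'  : 'tcp://127.0.0.1:5557',
        'iopub'    : 'tcp://127.0.0.1:5558',
        'heartbeat': ['tcp://127.0.0.1:5559',      # ping / pong pair
                      'tcp://127.0.0.1:5560'],
    }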
77 def complete_registration(self, msg):
81 def complete_registration(self, msg):
78 # print msg
82 # print msg
79 self._abort_dc.stop()
83 self._abort_dc.stop()
80 ctx = self.context
84 ctx = self.context
81 loop = self.loop
85 loop = self.loop
82 identity = self.ident
86 identity = self.ident
83
87
84 idents,msg = self.session.feed_identities(msg)
88 idents,msg = self.session.feed_identities(msg)
85 msg = Message(self.session.unpack_message(msg))
89 msg = Message(self.session.unpack_message(msg))
86
90
87 if msg.content.status == 'ok':
91 if msg.content.status == 'ok':
88 self.id = int(msg.content.id)
92 self.id = int(msg.content.id)
89
93
90 # create Shell Streams (MUX, Task, etc.):
94 # create Shell Streams (MUX, Task, etc.):
91 queue_addr = msg.content.mux
95 queue_addr = msg.content.mux
92 shell_addrs = [ str(queue_addr) ]
96 shell_addrs = [ str(queue_addr) ]
93 task_addr = msg.content.task
97 task_addr = msg.content.task
94 if task_addr:
98 if task_addr:
95 shell_addrs.append(str(task_addr))
99 shell_addrs.append(str(task_addr))
96
100
97 # Uncomment this to go back to two-socket model
101 # Uncomment this to go back to two-socket model
98 # shell_streams = []
102 # shell_streams = []
99 # for addr in shell_addrs:
103 # for addr in shell_addrs:
100 # stream = zmqstream.ZMQStream(ctx.socket(zmq.XREP), loop)
104 # stream = zmqstream.ZMQStream(ctx.socket(zmq.XREP), loop)
101 # stream.setsockopt(zmq.IDENTITY, identity)
105 # stream.setsockopt(zmq.IDENTITY, identity)
102 # stream.connect(disambiguate_url(addr, self.location))
106 # stream.connect(disambiguate_url(addr, self.location))
103 # shell_streams.append(stream)
107 # shell_streams.append(stream)
104
108
105 # Now use only one shell stream for mux and tasks
109 # Now use only one shell stream for mux and tasks
106 stream = zmqstream.ZMQStream(ctx.socket(zmq.XREP), loop)
110 stream = zmqstream.ZMQStream(ctx.socket(zmq.XREP), loop)
107 stream.setsockopt(zmq.IDENTITY, identity)
111 stream.setsockopt(zmq.IDENTITY, identity)
108 shell_streams = [stream]
112 shell_streams = [stream]
109 for addr in shell_addrs:
113 for addr in shell_addrs:
110 stream.connect(disambiguate_url(addr, self.location))
114 stream.connect(disambiguate_url(addr, self.location))
111 # end single stream-socket
115 # end single stream-socket
112
116
113 # control stream:
117 # control stream:
114 control_addr = str(msg.content.control)
118 control_addr = str(msg.content.control)
115 control_stream = zmqstream.ZMQStream(ctx.socket(zmq.XREP), loop)
119 control_stream = zmqstream.ZMQStream(ctx.socket(zmq.XREP), loop)
116 control_stream.setsockopt(zmq.IDENTITY, identity)
120 control_stream.setsockopt(zmq.IDENTITY, identity)
117 control_stream.connect(disambiguate_url(control_addr, self.location))
121 control_stream.connect(disambiguate_url(control_addr, self.location))
118
122
119 # create iopub stream:
123 # create iopub stream:
120 iopub_addr = msg.content.iopub
124 iopub_addr = msg.content.iopub
121 iopub_stream = zmqstream.ZMQStream(ctx.socket(zmq.PUB), loop)
125 iopub_stream = zmqstream.ZMQStream(ctx.socket(zmq.PUB), loop)
122 iopub_stream.setsockopt(zmq.IDENTITY, identity)
126 iopub_stream.setsockopt(zmq.IDENTITY, identity)
123 iopub_stream.connect(disambiguate_url(iopub_addr, self.location))
127 iopub_stream.connect(disambiguate_url(iopub_addr, self.location))
124
128
125 # launch heartbeat
129 # launch heartbeat
126 hb_addrs = msg.content.heartbeat
130 hb_addrs = msg.content.heartbeat
127 # print (hb_addrs)
131 # print (hb_addrs)
128
132
129 # # Redirect input streams and set a display hook.
133 # # Redirect input streams and set a display hook.
130 if self.out_stream_factory:
134 if self.out_stream_factory:
131 sys.stdout = self.out_stream_factory(self.session, iopub_stream, u'stdout')
135 sys.stdout = self.out_stream_factory(self.session, iopub_stream, u'stdout')
132 sys.stdout.topic = 'engine.%i.stdout'%self.id
136 sys.stdout.topic = 'engine.%i.stdout'%self.id
133 sys.stderr = self.out_stream_factory(self.session, iopub_stream, u'stderr')
137 sys.stderr = self.out_stream_factory(self.session, iopub_stream, u'stderr')
134 sys.stderr.topic = 'engine.%i.stderr'%self.id
138 sys.stderr.topic = 'engine.%i.stderr'%self.id
135 if self.display_hook_factory:
139 if self.display_hook_factory:
136 sys.displayhook = self.display_hook_factory(self.session, iopub_stream)
140 sys.displayhook = self.display_hook_factory(self.session, iopub_stream)
137 sys.displayhook.topic = 'engine.%i.pyout'%self.id
141 sys.displayhook.topic = 'engine.%i.pyout'%self.id
138
142
139 self.kernel = Kernel(config=self.config, int_id=self.id, ident=self.ident, session=self.session,
143 self.kernel = Kernel(config=self.config, int_id=self.id, ident=self.ident, session=self.session,
140 control_stream=control_stream, shell_streams=shell_streams, iopub_stream=iopub_stream,
144 control_stream=control_stream, shell_streams=shell_streams, iopub_stream=iopub_stream,
141 loop=loop, user_ns = self.user_ns, log=self.log)
145 loop=loop, user_ns = self.user_ns, log=self.log)
142 self.kernel.start()
146 self.kernel.start()
143 hb_addrs = [ disambiguate_url(addr, self.location) for addr in hb_addrs ]
147 hb_addrs = [ disambiguate_url(addr, self.location) for addr in hb_addrs ]
144 heart = Heart(*map(str, hb_addrs), heart_id=identity)
148 heart = Heart(*map(str, hb_addrs), heart_id=identity)
145 heart.start()
149 heart.start()
146
150
147
151
148 else:
152 else:
149 self.log.fatal("Registration Failed: %s"%msg)
153 self.log.fatal("Registration Failed: %s"%msg)
150 raise Exception("Registration Failed: %s"%msg)
154 raise Exception("Registration Failed: %s"%msg)
151
155
152 self.log.info("Completed registration with id %i"%self.id)
156 self.log.info("Completed registration with id %i"%self.id)
153
157
154
158
155 def abort(self):
159 def abort(self):
156 self.log.fatal("Registration timed out after %.1f seconds"%self.timeout)
160 self.log.fatal("Registration timed out after %.1f seconds"%self.timeout)
157 self.session.send(self.registrar, "unregistration_request", content=dict(id=self.id))
161 self.session.send(self.registrar, "unregistration_request", content=dict(id=self.id))
158 time.sleep(1)
162 time.sleep(1)
159 sys.exit(255)
163 sys.exit(255)
160
164
161 def start(self):
165 def start(self):
162 dc = ioloop.DelayedCallback(self.register, 0, self.loop)
166 dc = ioloop.DelayedCallback(self.register, 0, self.loop)
163 dc.start()
167 dc.start()
164 self._abort_dc = ioloop.DelayedCallback(self.abort, self.timeout*1000, self.loop)
168 self._abort_dc = ioloop.DelayedCallback(self.abort, self.timeout*1000, self.loop)
165 self._abort_dc.start()
169 self._abort_dc.start()
166
170
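start and abort form a race: register fires immediately, abort fires after timeout seconds, and complete_registration stops the abort callback on success. The same idiom in isolation, assuming a running zmq ioloop:

    from zmq.eventloop import ioloop

    loop = ioloop.IOLoop.instance()

    def attempt():
        pass  # e.g. send a request and install a reply handler

    def give_up():
        loop.stop()  # nothing answered within the deadline

    ioloop.DelayedCallback(attempt, 0, loop).start()
    abort_dc = ioloop.DelayedCallback(give_up, 5 * 1000, loop)
    abort_dc.start()
    # a successful reply handler would call abort_dc.stop() before it fires
    loop.start()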
@@ -1,225 +1,230 b''
1 """KernelStarter class that intercepts Control Queue messages, and handles process management."""
1 """KernelStarter class that intercepts Control Queue messages, and handles process management.
2
3 Authors:
4
5 * Min RK
6 """
2 #-----------------------------------------------------------------------------
7 #-----------------------------------------------------------------------------
3 # Copyright (C) 2010-2011 The IPython Development Team
8 # Copyright (C) 2010-2011 The IPython Development Team
4 #
9 #
5 # Distributed under the terms of the BSD License. The full license is in
10 # Distributed under the terms of the BSD License. The full license is in
6 # the file COPYING, distributed as part of this software.
11 # the file COPYING, distributed as part of this software.
7 #-----------------------------------------------------------------------------
12 #-----------------------------------------------------------------------------
8
13
9 import zmq; from zmq.eventloop import ioloop, zmqstream  # zmq/zmqstream are used by make_starter below
14 import zmq; from zmq.eventloop import ioloop, zmqstream  # zmq/zmqstream are used by make_starter below
10
15
11 from IPython.zmq.session import Session
16 from IPython.zmq.session import Session
12
17
13 class KernelStarter(object):
18 class KernelStarter(object):
14 """Object for resetting/killing the Kernel."""
19 """Object for resetting/killing the Kernel."""
15
20
16
21
17 def __init__(self, session, upstream, downstream, *kernel_args, **kernel_kwargs):
22 def __init__(self, session, upstream, downstream, *kernel_args, **kernel_kwargs):
18 self.session = session
23 self.session = session
19 self.upstream = upstream
24 self.upstream = upstream
20 self.downstream = downstream
25 self.downstream = downstream
21 self.kernel_args = kernel_args
26 self.kernel_args = kernel_args
22 self.kernel_kwargs = kernel_kwargs
27 self.kernel_kwargs = kernel_kwargs
23 self.handlers = {}
28 self.handlers = {}
24 for method in 'shutdown_request shutdown_reply'.split():
29 for method in 'shutdown_request shutdown_reply'.split():
25 self.handlers[method] = getattr(self, method)
30 self.handlers[method] = getattr(self, method)
26
31
27 def start(self):
32 def start(self):
28 self.upstream.on_recv(self.dispatch_request)
33 self.upstream.on_recv(self.dispatch_request)
29 self.downstream.on_recv(self.dispatch_reply)
34 self.downstream.on_recv(self.dispatch_reply)
30
35
31 #--------------------------------------------------------------------------
36 #--------------------------------------------------------------------------
32 # Dispatch methods
37 # Dispatch methods
33 #--------------------------------------------------------------------------
38 #--------------------------------------------------------------------------
34
39
35 def dispatch_request(self, raw_msg):
40 def dispatch_request(self, raw_msg):
36 idents, msg = self.session.feed_identities(raw_msg)
41 idents, msg = self.session.feed_identities(raw_msg)
37 try:
42 try:
38 msg = self.session.unpack_message(msg, content=False)
43 msg = self.session.unpack_message(msg, content=False)
39 except Exception:
44 except Exception:
40 print ("bad msg: %s"%msg); return
45 print ("bad msg: %s"%msg); return
41
46
42 msgtype = msg['msg_type']
47 msgtype = msg['msg_type']
43 handler = self.handlers.get(msgtype, None)
48 handler = self.handlers.get(msgtype, None)
44 if handler is None:
49 if handler is None:
45 self.downstream.send_multipart(raw_msg, copy=False)
50 self.downstream.send_multipart(raw_msg, copy=False)
46 else:
51 else:
47 handler(msg)
52 handler(msg)
48
53
49 def dispatch_reply(self, raw_msg):
54 def dispatch_reply(self, raw_msg):
50 idents, msg = self.session.feed_identities(raw_msg)
55 idents, msg = self.session.feed_identities(raw_msg)
51 try:
56 try:
52 msg = self.session.unpack_message(msg, content=False)
57 msg = self.session.unpack_message(msg, content=False)
53 except Exception:
58 except Exception:
54 print ("bad msg: %s"%msg); return
59 print ("bad msg: %s"%msg); return
55
60
56 msgtype = msg['msg_type']
61 msgtype = msg['msg_type']
57 handler = self.handlers.get(msgtype, None)
62 handler = self.handlers.get(msgtype, None)
58 if handler is None:
63 if handler is None:
59 self.upstream.send_multipart(raw_msg, copy=False)
64 self.upstream.send_multipart(raw_msg, copy=False)
60 else:
65 else:
61 handler(msg)
66 handler(msg)
62
67
63 #--------------------------------------------------------------------------
68 #--------------------------------------------------------------------------
64 # Handlers
69 # Handlers
65 #--------------------------------------------------------------------------
70 #--------------------------------------------------------------------------
66
71
67 def shutdown_request(self, msg):
72 def shutdown_request(self, msg):
68 """"""
73 """"""
69 self.downstream.send_multipart(msg)
74 self.downstream.send_multipart(msg)
70
75
71 #--------------------------------------------------------------------------
76 #--------------------------------------------------------------------------
72 # Kernel process management methods, from KernelManager:
77 # Kernel process management methods, from KernelManager:
73 #--------------------------------------------------------------------------
78 #--------------------------------------------------------------------------
74
79
75 def _check_local(addr):
80 def _check_local(addr):
76 if isinstance(addr, tuple):
81 if isinstance(addr, tuple):
77 addr = addr[0]
82 addr = addr[0]
78 return addr in LOCAL_IPS
83 return addr in LOCAL_IPS
79
84
80 def start_kernel(self, **kw):
85 def start_kernel(self, **kw):
81 """Starts a kernel process and configures the manager to use it.
86 """Starts a kernel process and configures the manager to use it.
82
87
83 If random ports (port=0) are being used, this method must be called
88 If random ports (port=0) are being used, this method must be called
84 before the channels are created.
89 before the channels are created.
85
90
86 Parameters
91 Parameters
87 ----------
92 ----------
88 ipython : bool, optional (default True)
93 ipython : bool, optional (default True)
89 Whether to use an IPython kernel instead of a plain Python kernel.
94 Whether to use an IPython kernel instead of a plain Python kernel.
90 """
95 """
91 self.kernel = Process(target=make_kernel, args=self.kernel_args,
96 self.kernel = Process(target=make_kernel, args=self.kernel_args,
92 kwargs=self.kernel_kwargs)
97 kwargs=self.kernel_kwargs)
93
98
94 def shutdown_kernel(self, restart=False):
99 def shutdown_kernel(self, restart=False):
95 """ Attempts to the stop the kernel process cleanly. If the kernel
100 """ Attempts to the stop the kernel process cleanly. If the kernel
96 cannot be stopped, it is killed, if possible.
101 cannot be stopped, it is killed, if possible.
97 """
102 """
98 # FIXME: Shutdown does not work on Windows due to ZMQ errors!
103 # FIXME: Shutdown does not work on Windows due to ZMQ errors!
99 if sys.platform == 'win32':
104 if sys.platform == 'win32':
100 self.kill_kernel()
105 self.kill_kernel()
101 return
106 return
102
107
103 # Don't send any additional kernel kill messages immediately, to give
108 # Don't send any additional kernel kill messages immediately, to give
104 # the kernel a chance to properly execute shutdown actions. Wait for at
109 # the kernel a chance to properly execute shutdown actions. Wait for at
105 # most 1s, checking every 0.1s.
110 # most 1s, checking every 0.1s.
106 self.xreq_channel.shutdown(restart=restart)
111 self.xreq_channel.shutdown(restart=restart)
107 for i in range(10):
112 for i in range(10):
108 if self.is_alive:
113 if self.is_alive:
109 time.sleep(0.1)
114 time.sleep(0.1)
110 else:
115 else:
111 break
116 break
112 else:
117 else:
113 # OK, we've waited long enough.
118 # OK, we've waited long enough.
114 if self.has_kernel:
119 if self.has_kernel:
115 self.kill_kernel()
120 self.kill_kernel()
116
121
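The for/else above is the standard poll-with-deadline idiom: the else clause runs only when the loop was never broken, meaning the kernel stayed alive through all ten checks. Distilled against a plain subprocess (proc is a hypothetical subprocess.Popen):

    import time

    for _ in range(10):            # wait at most 1s, checking every 0.1s
        if proc.poll() is None:    # still running
            time.sleep(0.1)
        else:
            break                  # exited cleanly; else clause is skipped
    else:
        proc.kill()                # never broke out of the loop: force-kill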
117 def restart_kernel(self, now=False):
122 def restart_kernel(self, now=False):
118 """Restarts a kernel with the same arguments that were used to launch
123 """Restarts a kernel with the same arguments that were used to launch
119 it. If the old kernel was launched with random ports, the same ports
124 it. If the old kernel was launched with random ports, the same ports
120 will be used for the new kernel.
125 will be used for the new kernel.
121
126
122 Parameters
127 Parameters
123 ----------
128 ----------
124 now : bool, optional
129 now : bool, optional
125 If True, the kernel is forcefully restarted *immediately*, without
130 If True, the kernel is forcefully restarted *immediately*, without
126 having a chance to do any cleanup action. Otherwise the kernel is
131 having a chance to do any cleanup action. Otherwise the kernel is
127 given 1s to clean up before a forceful restart is issued.
132 given 1s to clean up before a forceful restart is issued.
128
133
129 In all cases the kernel is restarted, the only difference is whether
134 In all cases the kernel is restarted, the only difference is whether
130 it is given a chance to perform a clean shutdown or not.
135 it is given a chance to perform a clean shutdown or not.
131 """
136 """
132 if self._launch_args is None:
137 if self._launch_args is None:
133 raise RuntimeError("Cannot restart the kernel. "
138 raise RuntimeError("Cannot restart the kernel. "
134 "No previous call to 'start_kernel'.")
139 "No previous call to 'start_kernel'.")
135 else:
140 else:
136 if self.has_kernel:
141 if self.has_kernel:
137 if now:
142 if now:
138 self.kill_kernel()
143 self.kill_kernel()
139 else:
144 else:
140 self.shutdown_kernel(restart=True)
145 self.shutdown_kernel(restart=True)
141 self.start_kernel(**self._launch_args)
146 self.start_kernel(**self._launch_args)
142
147
143 # FIXME: Messages get dropped in Windows due to probable ZMQ bug
148 # FIXME: Messages get dropped in Windows due to probable ZMQ bug
144 # unless there is some delay here.
149 # unless there is some delay here.
145 if sys.platform == 'win32':
150 if sys.platform == 'win32':
146 time.sleep(0.2)
151 time.sleep(0.2)
147
152
148 @property
153 @property
149 def has_kernel(self):
154 def has_kernel(self):
150 """Returns whether a kernel process has been specified for the kernel
155 """Returns whether a kernel process has been specified for the kernel
151 manager.
156 manager.
152 """
157 """
153 return self.kernel is not None
158 return self.kernel is not None
154
159
155 def kill_kernel(self):
160 def kill_kernel(self):
156 """ Kill the running kernel. """
161 """ Kill the running kernel. """
157 if self.has_kernel:
162 if self.has_kernel:
158 # Pause the heart beat channel if it exists.
163 # Pause the heart beat channel if it exists.
159 if self._hb_channel is not None:
164 if self._hb_channel is not None:
160 self._hb_channel.pause()
165 self._hb_channel.pause()
161
166
162 # Attempt to kill the kernel.
167 # Attempt to kill the kernel.
163 try:
168 try:
164 self.kernel.kill()
169 self.kernel.kill()
165 except OSError, e:
170 except OSError, e:
166 # In Windows, we will get an Access Denied error if the process
171 # In Windows, we will get an Access Denied error if the process
167 # has already terminated. Ignore it.
172 # has already terminated. Ignore it.
168 if not (sys.platform == 'win32' and e.winerror == 5):
173 if not (sys.platform == 'win32' and e.winerror == 5):
169 raise
174 raise
170 self.kernel = None
175 self.kernel = None
171 else:
176 else:
172 raise RuntimeError("Cannot kill kernel. No kernel is running!")
177 raise RuntimeError("Cannot kill kernel. No kernel is running!")
173
178
174 def interrupt_kernel(self):
179 def interrupt_kernel(self):
175 """ Interrupts the kernel. Unlike ``signal_kernel``, this operation is
180 """ Interrupts the kernel. Unlike ``signal_kernel``, this operation is
176 well supported on all platforms.
181 well supported on all platforms.
177 """
182 """
178 if self.has_kernel:
183 if self.has_kernel:
179 if sys.platform == 'win32':
184 if sys.platform == 'win32':
180 from parentpoller import ParentPollerWindows as Poller
185 from parentpoller import ParentPollerWindows as Poller
181 Poller.send_interrupt(self.kernel.win32_interrupt_event)
186 Poller.send_interrupt(self.kernel.win32_interrupt_event)
182 else:
187 else:
183 self.kernel.send_signal(signal.SIGINT)
188 self.kernel.send_signal(signal.SIGINT)
184 else:
189 else:
185 raise RuntimeError("Cannot interrupt kernel. No kernel is running!")
190 raise RuntimeError("Cannot interrupt kernel. No kernel is running!")
186
191
187 def signal_kernel(self, signum):
192 def signal_kernel(self, signum):
188 """ Sends a signal to the kernel. Note that since only SIGTERM is
193 """ Sends a signal to the kernel. Note that since only SIGTERM is
189 supported on Windows, this function is only useful on Unix systems.
194 supported on Windows, this function is only useful on Unix systems.
190 """
195 """
191 if self.has_kernel:
196 if self.has_kernel:
192 self.kernel.send_signal(signum)
197 self.kernel.send_signal(signum)
193 else:
198 else:
194 raise RuntimeError("Cannot signal kernel. No kernel is running!")
199 raise RuntimeError("Cannot signal kernel. No kernel is running!")
195
200
196 @property
201 @property
197 def is_alive(self):
202 def is_alive(self):
198 """Is the kernel process still running?"""
203 """Is the kernel process still running?"""
199 # FIXME: not using a heartbeat means this method is broken for any
204 # FIXME: not using a heartbeat means this method is broken for any
200 # remote kernel, it's only capable of handling local kernels.
205 # remote kernel, it's only capable of handling local kernels.
201 if self.has_kernel:
206 if self.has_kernel:
202 if self.kernel.poll() is None:
207 if self.kernel.poll() is None:
203 return True
208 return True
204 else:
209 else:
205 return False
210 return False
206 else:
211 else:
207 # We didn't start the kernel with this KernelManager so we don't
212 # We didn't start the kernel with this KernelManager so we don't
208 # know if it is running. We should use a heartbeat for this case.
213 # know if it is running. We should use a heartbeat for this case.
209 return True
214 return True
210
215
211
216
212 def make_starter(up_addr, down_addr, *args, **kwargs):
217 def make_starter(up_addr, down_addr, *args, **kwargs):
213 """entry point function for launching a kernelstarter in a subprocess"""
218 """entry point function for launching a kernelstarter in a subprocess"""
214 loop = ioloop.IOLoop.instance()
219 loop = ioloop.IOLoop.instance()
215 ctx = zmq.Context()
220 ctx = zmq.Context()
216 session = Session()
221 session = Session()
217 upstream = zmqstream.ZMQStream(ctx.socket(zmq.XREQ),loop)
222 upstream = zmqstream.ZMQStream(ctx.socket(zmq.XREQ),loop)
218 upstream.connect(up_addr)
223 upstream.connect(up_addr)
219 downstream = zmqstream.ZMQStream(ctx.socket(zmq.XREQ),loop)
224 downstream = zmqstream.ZMQStream(ctx.socket(zmq.XREQ),loop)
220 downstream.connect(down_addr)
225 downstream.connect(down_addr)
221
226
222 starter = KernelStarter(session, upstream, downstream, *args, **kwargs)
227 starter = KernelStarter(session, upstream, downstream, *args, **kwargs)
223 starter.start()
228 starter.start()
224 loop.start()
229 loop.start()
225 No newline at end of file
230
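make_starter blocks in loop.start(), so it is intended as a subprocess entry point (note it relies on zmq and zmqstream being imported at module level, per the import line above). A hedged launch sketch with placeholder relay addresses:

    from multiprocessing import Process

    p = Process(target=make_starter,
                args=('tcp://127.0.0.1:5565',    # up_addr: controller-facing
                      'tcp://127.0.0.1:5566'))   # down_addr: kernel-facing
    p.start()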
@@ -1,422 +1,429 b''
1 #!/usr/bin/env python
1 #!/usr/bin/env python
2 """
2 """
3 Kernel adapted from kernel.py to use ZMQ Streams
3 Kernel adapted from kernel.py to use ZMQ Streams
4
5 Authors:
6
7 * Min RK
8 * Brian Granger
9 * Fernando Perez
10 * Evan Patterson
4 """
11 """
5 #-----------------------------------------------------------------------------
12 #-----------------------------------------------------------------------------
6 # Copyright (C) 2010-2011 The IPython Development Team
13 # Copyright (C) 2010-2011 The IPython Development Team
7 #
14 #
8 # Distributed under the terms of the BSD License. The full license is in
15 # Distributed under the terms of the BSD License. The full license is in
9 # the file COPYING, distributed as part of this software.
16 # the file COPYING, distributed as part of this software.
10 #-----------------------------------------------------------------------------
17 #-----------------------------------------------------------------------------
11
18
12 #-----------------------------------------------------------------------------
19 #-----------------------------------------------------------------------------
13 # Imports
20 # Imports
14 #-----------------------------------------------------------------------------
21 #-----------------------------------------------------------------------------
15
22
16 # Standard library imports.
23 # Standard library imports.
17 from __future__ import print_function
24 from __future__ import print_function
18
25
19 import sys
26 import sys
20 import time
27 import time
21
28
22 from code import CommandCompiler
29 from code import CommandCompiler
23 from datetime import datetime
30 from datetime import datetime
24 from pprint import pprint
31 from pprint import pprint
25
32
26 # System library imports.
33 # System library imports.
27 import zmq
34 import zmq
28 from zmq.eventloop import ioloop, zmqstream
35 from zmq.eventloop import ioloop, zmqstream
29
36
30 # Local imports.
37 # Local imports.
31 from IPython.utils.traitlets import Instance, List, Int, Dict, Set, Unicode
38 from IPython.utils.traitlets import Instance, List, Int, Dict, Set, Unicode
32 from IPython.zmq.completer import KernelCompleter
39 from IPython.zmq.completer import KernelCompleter
33
40
34 from IPython.parallel.error import wrap_exception
41 from IPython.parallel.error import wrap_exception
35 from IPython.parallel.factory import SessionFactory
42 from IPython.parallel.factory import SessionFactory
36 from IPython.parallel.util import serialize_object, unpack_apply_message
43 from IPython.parallel.util import serialize_object, unpack_apply_message
37
44
38 def printer(*args):
45 def printer(*args):
39 pprint(args, stream=sys.__stdout__)
46 pprint(args, stream=sys.__stdout__)
40
47
41
48
42 class _Passer(zmqstream.ZMQStream):
49 class _Passer(zmqstream.ZMQStream):
43 """Empty class that implements `send()` that does nothing.
50 """Empty class that implements `send()` that does nothing.
44
51
45 Subclass ZMQStream for Session typechecking
52 Subclass ZMQStream for Session typechecking
46
53
47 """
54 """
48 def __init__(self, *args, **kwargs):
55 def __init__(self, *args, **kwargs):
49 pass
56 pass
50
57
51 def send(self, *args, **kwargs):
58 def send(self, *args, **kwargs):
52 pass
59 pass
53 send_multipart = send
60 send_multipart = send
54
61
55
62
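_Passer is a null object: it satisfies Session's type check for a stream while discarding every send. A hedged sketch of the intent, with session standing in for a real Session instance:

    nowhere = _Passer()
    # looks like a ZMQStream to the Session API, but writes nothing
    session.send(nowhere, 'status', content={'execution_state': 'idle'})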
56 #-----------------------------------------------------------------------------
63 #-----------------------------------------------------------------------------
57 # Main kernel class
64 # Main kernel class
58 #-----------------------------------------------------------------------------
65 #-----------------------------------------------------------------------------
59
66
60 class Kernel(SessionFactory):
67 class Kernel(SessionFactory):
61
68
62 #---------------------------------------------------------------------------
69 #---------------------------------------------------------------------------
63 # Kernel interface
70 # Kernel interface
64 #---------------------------------------------------------------------------
71 #---------------------------------------------------------------------------
65
72
66 # kwargs:
73 # kwargs:
67 exec_lines = List(Unicode, config=True,
74 exec_lines = List(Unicode, config=True,
68 help="List of lines to execute")
75 help="List of lines to execute")
69
76
70 int_id = Int(-1)
77 int_id = Int(-1)
71 user_ns = Dict(config=True, help="""Set the user's namespace of the Kernel""")
78 user_ns = Dict(config=True, help="""Set the user's namespace of the Kernel""")
72
79
73 control_stream = Instance(zmqstream.ZMQStream)
80 control_stream = Instance(zmqstream.ZMQStream)
74 task_stream = Instance(zmqstream.ZMQStream)
81 task_stream = Instance(zmqstream.ZMQStream)
75 iopub_stream = Instance(zmqstream.ZMQStream)
82 iopub_stream = Instance(zmqstream.ZMQStream)
76 client = Instance('IPython.parallel.Client')
83 client = Instance('IPython.parallel.Client')
77
84
78 # internals
85 # internals
79 shell_streams = List()
86 shell_streams = List()
80 compiler = Instance(CommandCompiler, (), {})
87 compiler = Instance(CommandCompiler, (), {})
81 completer = Instance(KernelCompleter)
88 completer = Instance(KernelCompleter)
82
89
83 aborted = Set()
90 aborted = Set()
84 shell_handlers = Dict()
91 shell_handlers = Dict()
85 control_handlers = Dict()
92 control_handlers = Dict()
86
93
87 def _set_prefix(self):
94 def _set_prefix(self):
88 self.prefix = "engine.%s"%self.int_id
95 self.prefix = "engine.%s"%self.int_id
89
96
90 def _connect_completer(self):
97 def _connect_completer(self):
91 self.completer = KernelCompleter(self.user_ns)
98 self.completer = KernelCompleter(self.user_ns)
92
99
93 def __init__(self, **kwargs):
100 def __init__(self, **kwargs):
94 super(Kernel, self).__init__(**kwargs)
101 super(Kernel, self).__init__(**kwargs)
95 self._set_prefix()
102 self._set_prefix()
96 self._connect_completer()
103 self._connect_completer()
97
104
98 self.on_trait_change(self._set_prefix, 'id')
105 self.on_trait_change(self._set_prefix, 'id')
99 self.on_trait_change(self._connect_completer, 'user_ns')
106 self.on_trait_change(self._connect_completer, 'user_ns')
100
107
101 # Build dict of handlers for message types
108 # Build dict of handlers for message types
102 for msg_type in ['execute_request', 'complete_request', 'apply_request',
109 for msg_type in ['execute_request', 'complete_request', 'apply_request',
103 'clear_request']:
110 'clear_request']:
104 self.shell_handlers[msg_type] = getattr(self, msg_type)
111 self.shell_handlers[msg_type] = getattr(self, msg_type)
105
112
106 for msg_type in ['shutdown_request', 'abort_request']+self.shell_handlers.keys():
113 for msg_type in ['shutdown_request', 'abort_request']+self.shell_handlers.keys():
107 self.control_handlers[msg_type] = getattr(self, msg_type)
114 self.control_handlers[msg_type] = getattr(self, msg_type)
108
115
109 self._initial_exec_lines()
116 self._initial_exec_lines()
110
117
111 def _wrap_exception(self, method=None):
118 def _wrap_exception(self, method=None):
112 e_info = dict(engine_uuid=self.ident, engine_id=self.int_id, method=method)
119 e_info = dict(engine_uuid=self.ident, engine_id=self.int_id, method=method)
113 content=wrap_exception(e_info)
120 content=wrap_exception(e_info)
114 return content
121 return content
115
122
116 def _initial_exec_lines(self):
123 def _initial_exec_lines(self):
117 s = _Passer()
124 s = _Passer()
118 content = dict(silent=True, user_variable=[],user_expressions=[])
125 content = dict(silent=True, user_variable=[],user_expressions=[])
119 for line in self.exec_lines:
126 for line in self.exec_lines:
120 self.log.debug("executing initialization: %s"%line)
127 self.log.debug("executing initialization: %s"%line)
121 content.update({'code':line})
128 content.update({'code':line})
122 msg = self.session.msg('execute_request', content)
129 msg = self.session.msg('execute_request', content)
123 self.execute_request(s, [], msg)
130 self.execute_request(s, [], msg)
124
131
125
132
126 #-------------------- control handlers -----------------------------
133 #-------------------- control handlers -----------------------------
127 def abort_queues(self):
134 def abort_queues(self):
128 for stream in self.shell_streams:
135 for stream in self.shell_streams:
129 if stream:
136 if stream:
130 self.abort_queue(stream)
137 self.abort_queue(stream)
131
138
132 def abort_queue(self, stream):
139 def abort_queue(self, stream):
133 while True:
140 while True:
134 idents,msg = self.session.recv(stream, zmq.NOBLOCK, content=True)
141 idents,msg = self.session.recv(stream, zmq.NOBLOCK, content=True)
135 if msg is None:
142 if msg is None:
136 return
143 return
137
144
138 self.log.info("Aborting:")
145 self.log.info("Aborting:")
139 self.log.info(str(msg))
146 self.log.info(str(msg))
140 msg_type = msg['msg_type']
147 msg_type = msg['msg_type']
141 reply_type = msg_type.split('_')[0] + '_reply'
148 reply_type = msg_type.split('_')[0] + '_reply'
142 # reply_msg = self.session.msg(reply_type, {'status' : 'aborted'}, msg)
149 # reply_msg = self.session.msg(reply_type, {'status' : 'aborted'}, msg)
143 # self.reply_socket.send(ident,zmq.SNDMORE)
150 # self.reply_socket.send(ident,zmq.SNDMORE)
144 # self.reply_socket.send_json(reply_msg)
151 # self.reply_socket.send_json(reply_msg)
145 reply_msg = self.session.send(stream, reply_type,
152 reply_msg = self.session.send(stream, reply_type,
146 content={'status' : 'aborted'}, parent=msg, ident=idents)[0]
153 content={'status' : 'aborted'}, parent=msg, ident=idents)[0]
147 self.log.debug(str(reply_msg))
154 self.log.debug(str(reply_msg))
148 # We need to wait a bit for requests to come in. This can probably
155 # We need to wait a bit for requests to come in. This can probably
149 # be set shorter for true asynchronous clients.
156 # be set shorter for true asynchronous clients.
150 time.sleep(0.05)
157 time.sleep(0.05)
151
158
152 def abort_request(self, stream, ident, parent):
159 def abort_request(self, stream, ident, parent):
153 """abort a specifig msg by id"""
160 """abort a specifig msg by id"""
154 msg_ids = parent['content'].get('msg_ids', None)
161 msg_ids = parent['content'].get('msg_ids', None)
155 if isinstance(msg_ids, basestring):
162 if isinstance(msg_ids, basestring):
156 msg_ids = [msg_ids]
163 msg_ids = [msg_ids]
157 if not msg_ids:
164 if not msg_ids:
158 self.abort_queues()
165 self.abort_queues()
159 for mid in msg_ids:
166 for mid in msg_ids:
160 self.aborted.add(str(mid))
167 self.aborted.add(str(mid))
161
168
162 content = dict(status='ok')
169 content = dict(status='ok')
163 reply_msg = self.session.send(stream, 'abort_reply', content=content,
170 reply_msg = self.session.send(stream, 'abort_reply', content=content,
164 parent=parent, ident=ident)
171 parent=parent, ident=ident)
165 self.log.debug(str(reply_msg))
172 self.log.debug(str(reply_msg))
166
173
167 def shutdown_request(self, stream, ident, parent):
174 def shutdown_request(self, stream, ident, parent):
168 """kill ourself. This should really be handled in an external process"""
175 """kill ourself. This should really be handled in an external process"""
169 try:
176 try:
170 self.abort_queues()
177 self.abort_queues()
171 except:
178 except:
172 content = self._wrap_exception('shutdown')
179 content = self._wrap_exception('shutdown')
173 else:
180 else:
174 content = dict(parent['content'])
181 content = dict(parent['content'])
175 content['status'] = 'ok'
182 content['status'] = 'ok'
176 msg = self.session.send(stream, 'shutdown_reply',
183 msg = self.session.send(stream, 'shutdown_reply',
177 content=content, parent=parent, ident=ident)
184 content=content, parent=parent, ident=ident)
178 self.log.debug(str(msg))
185 self.log.debug(str(msg))
179 dc = ioloop.DelayedCallback(lambda : sys.exit(0), 1000, self.loop)
186 dc = ioloop.DelayedCallback(lambda : sys.exit(0), 1000, self.loop)
180 dc.start()
187 dc.start()
181
188
182 def dispatch_control(self, msg):
189 def dispatch_control(self, msg):
183 idents,msg = self.session.feed_identities(msg, copy=False)
190 idents,msg = self.session.feed_identities(msg, copy=False)
184 try:
191 try:
185 msg = self.session.unpack_message(msg, content=True, copy=False)
192 msg = self.session.unpack_message(msg, content=True, copy=False)
186 except:
193 except:
187 self.log.error("Invalid Message", exc_info=True)
194 self.log.error("Invalid Message", exc_info=True)
188 return
195 return
189
196
190 header = msg['header']
197 header = msg['header']
191 msg_id = header['msg_id']
198 msg_id = header['msg_id']
192
199
193 handler = self.control_handlers.get(msg['msg_type'], None)
200 handler = self.control_handlers.get(msg['msg_type'], None)
194 if handler is None:
201 if handler is None:
195 self.log.error("UNKNOWN CONTROL MESSAGE TYPE: %r"%msg['msg_type'])
202 self.log.error("UNKNOWN CONTROL MESSAGE TYPE: %r"%msg['msg_type'])
196 else:
203 else:
197 handler(self.control_stream, idents, msg)
204 handler(self.control_stream, idents, msg)
198
205
199
206
200 #-------------------- queue helpers ------------------------------
207 #-------------------- queue helpers ------------------------------
201
208
202 def check_dependencies(self, dependencies):
209 def check_dependencies(self, dependencies):
203 if not dependencies:
210 if not dependencies:
204 return True
211 return True
205 if len(dependencies) == 2 and dependencies[0] in 'any all'.split():
212 if len(dependencies) == 2 and dependencies[0] in 'any all'.split():
206 anyorall = dependencies[0]
213 anyorall = dependencies[0]
207 dependencies = dependencies[1]
214 dependencies = dependencies[1]
208 else:
215 else:
209 anyorall = 'all'
216 anyorall = 'all'
210 results = self.client.get_results(dependencies,status_only=True)
217 results = self.client.get_results(dependencies,status_only=True)
211 if results['status'] != 'ok':
218 if results['status'] != 'ok':
212 return False
219 return False
213
220
214 if anyorall == 'any':
221 if anyorall == 'any':
215 if not results['completed']:
222 if not results['completed']:
216 return False
223 return False
217 else:
224 else:
218 if results['pending']:
225 if results['pending']:
219 return False
226 return False
220
227
221 return True
228 return True
222
229
223 def check_aborted(self, msg_id):
230 def check_aborted(self, msg_id):
224 return msg_id in self.aborted
231 return msg_id in self.aborted
225
232
226 #-------------------- queue handlers -----------------------------
233 #-------------------- queue handlers -----------------------------
227
234
228 def clear_request(self, stream, idents, parent):
235 def clear_request(self, stream, idents, parent):
229 """Clear our namespace."""
236 """Clear our namespace."""
230 self.user_ns = {}
237 self.user_ns = {}
231 msg = self.session.send(stream, 'clear_reply', ident=idents, parent=parent,
238 msg = self.session.send(stream, 'clear_reply', ident=idents, parent=parent,
232 content = dict(status='ok'))
239 content = dict(status='ok'))
233 self._initial_exec_lines()
240 self._initial_exec_lines()
234
241
235 def execute_request(self, stream, ident, parent):
242 def execute_request(self, stream, ident, parent):
236 self.log.debug('execute request %s'%parent)
243 self.log.debug('execute request %s'%parent)
237 try:
244 try:
238 code = parent[u'content'][u'code']
245 code = parent[u'content'][u'code']
239 except:
246 except:
240 self.log.error("Got bad msg: %s"%parent, exc_info=True)
247 self.log.error("Got bad msg: %s"%parent, exc_info=True)
241 return
248 return
242 self.session.send(self.iopub_stream, u'pyin', {u'code':code},parent=parent,
249 self.session.send(self.iopub_stream, u'pyin', {u'code':code},parent=parent,
243 ident='%s.pyin'%self.prefix)
250 ident='%s.pyin'%self.prefix)
244 started = datetime.now()
251 started = datetime.now()
245 try:
252 try:
246 comp_code = self.compiler(code, '<zmq-kernel>')
253 comp_code = self.compiler(code, '<zmq-kernel>')
247 # allow for not overriding displayhook
254 # allow for not overriding displayhook
248 if hasattr(sys.displayhook, 'set_parent'):
255 if hasattr(sys.displayhook, 'set_parent'):
249 sys.displayhook.set_parent(parent)
256 sys.displayhook.set_parent(parent)
250 sys.stdout.set_parent(parent)
257 sys.stdout.set_parent(parent)
251 sys.stderr.set_parent(parent)
258 sys.stderr.set_parent(parent)
252 exec comp_code in self.user_ns, self.user_ns
259 exec comp_code in self.user_ns, self.user_ns
253 except:
260 except:
254 exc_content = self._wrap_exception('execute')
261 exc_content = self._wrap_exception('execute')
255 # exc_msg = self.session.msg(u'pyerr', exc_content, parent)
262 # exc_msg = self.session.msg(u'pyerr', exc_content, parent)
256 self.session.send(self.iopub_stream, u'pyerr', exc_content, parent=parent,
263 self.session.send(self.iopub_stream, u'pyerr', exc_content, parent=parent,
257 ident='%s.pyerr'%self.prefix)
264 ident='%s.pyerr'%self.prefix)
258 reply_content = exc_content
265 reply_content = exc_content
259 else:
266 else:
260 reply_content = {'status' : 'ok'}
267 reply_content = {'status' : 'ok'}
261
268
262 reply_msg = self.session.send(stream, u'execute_reply', reply_content, parent=parent,
269 reply_msg = self.session.send(stream, u'execute_reply', reply_content, parent=parent,
263 ident=ident, subheader = dict(started=started))
270 ident=ident, subheader = dict(started=started))
264 self.log.debug(str(reply_msg))
271 self.log.debug(str(reply_msg))
265 if reply_msg['content']['status'] == u'error':
272 if reply_msg['content']['status'] == u'error':
266 self.abort_queues()
273 self.abort_queues()
267
274
268 def complete_request(self, stream, ident, parent):
275 def complete_request(self, stream, ident, parent):
269 matches = {'matches' : self.complete(parent),
276 matches = {'matches' : self.complete(parent),
270 'status' : 'ok'}
277 'status' : 'ok'}
271 completion_msg = self.session.send(stream, 'complete_reply',
278 completion_msg = self.session.send(stream, 'complete_reply',
272 matches, parent, ident)
279 matches, parent, ident)
273 # print >> sys.__stdout__, completion_msg
280 # print >> sys.__stdout__, completion_msg
274
281
275 def complete(self, msg):
282 def complete(self, msg):
276 return self.completer.complete(msg.content.line, msg.content.text)
283 return self.completer.complete(msg.content.line, msg.content.text)
277
284
278 def apply_request(self, stream, ident, parent):
285 def apply_request(self, stream, ident, parent):
279 # flush previous reply, so this request won't block it
286 # flush previous reply, so this request won't block it
280 stream.flush(zmq.POLLOUT)
287 stream.flush(zmq.POLLOUT)
281
288
282 try:
289 try:
283 content = parent[u'content']
290 content = parent[u'content']
284 bufs = parent[u'buffers']
291 bufs = parent[u'buffers']
285 msg_id = parent['header']['msg_id']
292 msg_id = parent['header']['msg_id']
286 # bound = parent['header'].get('bound', False)
293 # bound = parent['header'].get('bound', False)
287 except:
294 except:
288 self.log.error("Got bad msg: %s"%parent, exc_info=True)
295 self.log.error("Got bad msg: %s"%parent, exc_info=True)
289 return
296 return
290 # pyin_msg = self.session.msg(u'pyin',{u'code':code}, parent=parent)
297 # pyin_msg = self.session.msg(u'pyin',{u'code':code}, parent=parent)
291 # self.iopub_stream.send(pyin_msg)
298 # self.iopub_stream.send(pyin_msg)
292 # self.session.send(self.iopub_stream, u'pyin', {u'code':code},parent=parent)
299 # self.session.send(self.iopub_stream, u'pyin', {u'code':code},parent=parent)
293 sub = {'dependencies_met' : True, 'engine' : self.ident,
300 sub = {'dependencies_met' : True, 'engine' : self.ident,
294 'started': datetime.now()}
301 'started': datetime.now()}
295 try:
302 try:
296 # allow for not overriding displayhook
303 # allow for not overriding displayhook
297 if hasattr(sys.displayhook, 'set_parent'):
304 if hasattr(sys.displayhook, 'set_parent'):
298 sys.displayhook.set_parent(parent)
305 sys.displayhook.set_parent(parent)
299 sys.stdout.set_parent(parent)
306 sys.stdout.set_parent(parent)
300 sys.stderr.set_parent(parent)
307 sys.stderr.set_parent(parent)
301 # exec "f(*args,**kwargs)" in self.user_ns, self.user_ns
308 # exec "f(*args,**kwargs)" in self.user_ns, self.user_ns
302 working = self.user_ns
309 working = self.user_ns
303 # suffix =
310 # suffix =
304 prefix = "_"+str(msg_id).replace("-","")+"_"
311 prefix = "_"+str(msg_id).replace("-","")+"_"
305
312
306 f,args,kwargs = unpack_apply_message(bufs, working, copy=False)
313 f,args,kwargs = unpack_apply_message(bufs, working, copy=False)
307 # if bound:
314 # if bound:
308 # bound_ns = Namespace(working)
315 # bound_ns = Namespace(working)
309 # args = [bound_ns]+list(args)
316 # args = [bound_ns]+list(args)
310
317
311 fname = getattr(f, '__name__', 'f')
318 fname = getattr(f, '__name__', 'f')
312
319
313 fname = prefix+"f"
320 fname = prefix+"f"
314 argname = prefix+"args"
321 argname = prefix+"args"
315 kwargname = prefix+"kwargs"
322 kwargname = prefix+"kwargs"
316 resultname = prefix+"result"
323 resultname = prefix+"result"
317
324
318 ns = { fname : f, argname : args, kwargname : kwargs , resultname : None }
325 ns = { fname : f, argname : args, kwargname : kwargs , resultname : None }
319 # print ns
326 # print ns
320 working.update(ns)
327 working.update(ns)
321 code = "%s=%s(*%s,**%s)"%(resultname, fname, argname, kwargname)
328 code = "%s=%s(*%s,**%s)"%(resultname, fname, argname, kwargname)
322 try:
329 try:
323 exec code in working,working
330 exec code in working,working
324 result = working.get(resultname)
331 result = working.get(resultname)
325 finally:
332 finally:
326 for key in ns.iterkeys():
333 for key in ns.iterkeys():
327 working.pop(key)
334 working.pop(key)
328 # if bound:
335 # if bound:
329 # working.update(bound_ns)
336 # working.update(bound_ns)
330
337
331 packed_result,buf = serialize_object(result)
338 packed_result,buf = serialize_object(result)
332 result_buf = [packed_result]+buf
339 result_buf = [packed_result]+buf
333 except:
340 except:
334 exc_content = self._wrap_exception('apply')
341 exc_content = self._wrap_exception('apply')
335 # exc_msg = self.session.msg(u'pyerr', exc_content, parent)
342 # exc_msg = self.session.msg(u'pyerr', exc_content, parent)
336 self.session.send(self.iopub_stream, u'pyerr', exc_content, parent=parent,
343 self.session.send(self.iopub_stream, u'pyerr', exc_content, parent=parent,
337 ident='%s.pyerr'%self.prefix)
344 ident='%s.pyerr'%self.prefix)
338 reply_content = exc_content
345 reply_content = exc_content
339 result_buf = []
346 result_buf = []
340
347
341 if exc_content['ename'] == 'UnmetDependency':
348 if exc_content['ename'] == 'UnmetDependency':
342 sub['dependencies_met'] = False
349 sub['dependencies_met'] = False
343 else:
350 else:
344 reply_content = {'status' : 'ok'}
351 reply_content = {'status' : 'ok'}
345
352
346 # put 'ok'/'error' status in header, for scheduler introspection:
353 # put 'ok'/'error' status in header, for scheduler introspection:
347 sub['status'] = reply_content['status']
354 sub['status'] = reply_content['status']
348
355
349 reply_msg = self.session.send(stream, u'apply_reply', reply_content,
356 reply_msg = self.session.send(stream, u'apply_reply', reply_content,
350 parent=parent, ident=ident,buffers=result_buf, subheader=sub)
357 parent=parent, ident=ident,buffers=result_buf, subheader=sub)
351
358
352 # flush i/o
359 # flush i/o
353 # should this be before reply_msg is sent, like in the single-kernel code,
360 # should this be before reply_msg is sent, like in the single-kernel code,
354 # or should nothing get in the way of real results?
361 # or should nothing get in the way of real results?
355 sys.stdout.flush()
362 sys.stdout.flush()
356 sys.stderr.flush()
363 sys.stderr.flush()
357
364
358 def dispatch_queue(self, stream, msg):
365 def dispatch_queue(self, stream, msg):
359 self.control_stream.flush()
366 self.control_stream.flush()
360 idents,msg = self.session.feed_identities(msg, copy=False)
367 idents,msg = self.session.feed_identities(msg, copy=False)
361 try:
368 try:
362 msg = self.session.unpack_message(msg, content=True, copy=False)
369 msg = self.session.unpack_message(msg, content=True, copy=False)
363 except:
370 except:
364 self.log.error("Invalid Message", exc_info=True)
371 self.log.error("Invalid Message", exc_info=True)
365 return
372 return
366
373
367
374
368 header = msg['header']
375 header = msg['header']
369 msg_id = header['msg_id']
376 msg_id = header['msg_id']
370 if self.check_aborted(msg_id):
377 if self.check_aborted(msg_id):
371 self.aborted.remove(msg_id)
378 self.aborted.remove(msg_id)
372 # is it safe to assume a msg_id will not be resubmitted?
379 # is it safe to assume a msg_id will not be resubmitted?
373 reply_type = msg['msg_type'].split('_')[0] + '_reply'
380 reply_type = msg['msg_type'].split('_')[0] + '_reply'
374 status = {'status' : 'aborted'}
381 status = {'status' : 'aborted'}
375 reply_msg = self.session.send(stream, reply_type, subheader=status,
382 reply_msg = self.session.send(stream, reply_type, subheader=status,
376 content=status, parent=msg, ident=idents)
383 content=status, parent=msg, ident=idents)
377 return
384 return
378 handler = self.shell_handlers.get(msg['msg_type'], None)
385 handler = self.shell_handlers.get(msg['msg_type'], None)
379 if handler is None:
386 if handler is None:
380 self.log.error("UNKNOWN MESSAGE TYPE: %r"%msg['msg_type'])
387 self.log.error("UNKNOWN MESSAGE TYPE: %r"%msg['msg_type'])
381 else:
388 else:
382 handler(stream, idents, msg)
389 handler(stream, idents, msg)
383
390
384 def start(self):
391 def start(self):
385 #### stream mode:
392 #### stream mode:
386 if self.control_stream:
393 if self.control_stream:
387 self.control_stream.on_recv(self.dispatch_control, copy=False)
394 self.control_stream.on_recv(self.dispatch_control, copy=False)
388 self.control_stream.on_err(printer)
395 self.control_stream.on_err(printer)
389
396
390 def make_dispatcher(stream):
397 def make_dispatcher(stream):
391 def dispatcher(msg):
398 def dispatcher(msg):
392 return self.dispatch_queue(stream, msg)
399 return self.dispatch_queue(stream, msg)
393 return dispatcher
400 return dispatcher
394
401
395 for s in self.shell_streams:
402 for s in self.shell_streams:
396 s.on_recv(make_dispatcher(s), copy=False)
403 s.on_recv(make_dispatcher(s), copy=False)
397 s.on_err(printer)
404 s.on_err(printer)
398
405
399 if self.iopub_stream:
406 if self.iopub_stream:
400 self.iopub_stream.on_err(printer)
407 self.iopub_stream.on_err(printer)
401
408
402 #### while True mode:
409 #### while True mode:
403 # while True:
410 # while True:
404 # idle = True
411 # idle = True
405 # try:
412 # try:
406 # msg = self.shell_stream.socket.recv_multipart(
413 # msg = self.shell_stream.socket.recv_multipart(
407 # zmq.NOBLOCK, copy=False)
414 # zmq.NOBLOCK, copy=False)
408 # except zmq.ZMQError, e:
415 # except zmq.ZMQError, e:
409 # if e.errno != zmq.EAGAIN:
416 # if e.errno != zmq.EAGAIN:
410 # raise e
417 # raise e
411 # else:
418 # else:
412 # idle=False
419 # idle=False
413 # self.dispatch_queue(self.shell_stream, msg)
420 # self.dispatch_queue(self.shell_stream, msg)
414 #
421 #
415 # if not self.task_stream.empty():
422 # if not self.task_stream.empty():
416 # idle=False
423 # idle=False
417 # msg = self.task_stream.recv_multipart()
424 # msg = self.task_stream.recv_multipart()
418 # self.dispatch_queue(self.task_stream, msg)
425 # self.dispatch_queue(self.task_stream, msg)
419 # if idle:
426 # if idle:
420 # # don't busywait
427 # # don't busywait
421 # time.sleep(1e-3)
428 # time.sleep(1e-3)
422
429
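
The core of `apply_request` above is a rename-and-exec pattern: the function and its arguments are bound in the user namespace under collision-safe names derived from the msg_id, a generated call string is exec'd there, and the temporaries are popped again afterwards. A minimal standalone sketch of that pattern (the function name and the driver at the bottom are illustrative, not kernel API)::

    def run_in_namespace(f, args, kwargs, working, msg_id):
        # Bind callable and arguments under collision-safe, msg_id-derived names.
        prefix = "_" + str(msg_id).replace("-", "") + "_"
        fname, argname = prefix + "f", prefix + "args"
        kwargname, resultname = prefix + "kwargs", prefix + "result"
        ns = {fname: f, argname: args, kwargname: kwargs, resultname: None}
        working.update(ns)
        try:
            code = "%s = %s(*%s, **%s)" % (resultname, fname, argname, kwargname)
            exec code in working, working   # Python 2, as in the kernel above
            return working.get(resultname)
        finally:
            # pop the temporaries so the user namespace stays clean
            for key in ns:
                working.pop(key)

    user_ns = {}
    print run_in_namespace(lambda a, b: a + b, (1, 2), {}, user_ns, "abc-123")  # -> 3
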
@@ -1,313 +1,319 b''
1 # encoding: utf-8
1 # encoding: utf-8
2
2
3 """Classes and functions for kernel related errors and exceptions."""
3 """Classes and functions for kernel related errors and exceptions.
4
5 Authors:
6
7 * Brian Granger
8 * Min RK
9 """
4 from __future__ import print_function
10 from __future__ import print_function
5
11
6 import sys
12 import sys
7 import traceback
13 import traceback
8
14
9 __docformat__ = "restructuredtext en"
15 __docformat__ = "restructuredtext en"
10
16
11 # Tell nose to skip this module
17 # Tell nose to skip this module
12 __test__ = {}
18 __test__ = {}
13
19
14 #-------------------------------------------------------------------------------
20 #-------------------------------------------------------------------------------
15 # Copyright (C) 2008 The IPython Development Team
21 # Copyright (C) 2008-2011 The IPython Development Team
16 #
22 #
17 # Distributed under the terms of the BSD License. The full license is in
23 # Distributed under the terms of the BSD License. The full license is in
18 # the file COPYING, distributed as part of this software.
24 # the file COPYING, distributed as part of this software.
19 #-------------------------------------------------------------------------------
25 #-------------------------------------------------------------------------------
20
26
21 #-------------------------------------------------------------------------------
27 #-------------------------------------------------------------------------------
22 # Error classes
28 # Error classes
23 #-------------------------------------------------------------------------------
29 #-------------------------------------------------------------------------------
24 class IPythonError(Exception):
30 class IPythonError(Exception):
25 """Base exception that all of our exceptions inherit from.
31 """Base exception that all of our exceptions inherit from.
26
32
27 This can be raised by code that doesn't have any more specific
33 This can be raised by code that doesn't have any more specific
28 information."""
34 information."""
29
35
30 pass
36 pass
31
37
32 # Exceptions associated with the controller objects
38 # Exceptions associated with the controller objects
33 class ControllerError(IPythonError): pass
39 class ControllerError(IPythonError): pass
34
40
35 class ControllerCreationError(ControllerError): pass
41 class ControllerCreationError(ControllerError): pass
36
42
37
43
38 # Exceptions associated with the Engines
44 # Exceptions associated with the Engines
39 class EngineError(IPythonError): pass
45 class EngineError(IPythonError): pass
40
46
41 class EngineCreationError(EngineError): pass
47 class EngineCreationError(EngineError): pass
42
48
43 class KernelError(IPythonError):
49 class KernelError(IPythonError):
44 pass
50 pass
45
51
46 class NotDefined(KernelError):
52 class NotDefined(KernelError):
47 def __init__(self, name):
53 def __init__(self, name):
48 self.name = name
54 self.name = name
49 self.args = (name,)
55 self.args = (name,)
50
56
51 def __repr__(self):
57 def __repr__(self):
52 return '<NotDefined: %s>' % self.name
58 return '<NotDefined: %s>' % self.name
53
59
54 __str__ = __repr__
60 __str__ = __repr__
55
61
56
62
57 class QueueCleared(KernelError):
63 class QueueCleared(KernelError):
58 pass
64 pass
59
65
60
66
61 class IdInUse(KernelError):
67 class IdInUse(KernelError):
62 pass
68 pass
63
69
64
70
65 class ProtocolError(KernelError):
71 class ProtocolError(KernelError):
66 pass
72 pass
67
73
68
74
69 class ConnectionError(KernelError):
75 class ConnectionError(KernelError):
70 pass
76 pass
71
77
72
78
73 class InvalidEngineID(KernelError):
79 class InvalidEngineID(KernelError):
74 pass
80 pass
75
81
76
82
77 class NoEnginesRegistered(KernelError):
83 class NoEnginesRegistered(KernelError):
78 pass
84 pass
79
85
80
86
81 class InvalidClientID(KernelError):
87 class InvalidClientID(KernelError):
82 pass
88 pass
83
89
84
90
85 class InvalidDeferredID(KernelError):
91 class InvalidDeferredID(KernelError):
86 pass
92 pass
87
93
88
94
89 class SerializationError(KernelError):
95 class SerializationError(KernelError):
90 pass
96 pass
91
97
92
98
93 class MessageSizeError(KernelError):
99 class MessageSizeError(KernelError):
94 pass
100 pass
95
101
96
102
97 class PBMessageSizeError(MessageSizeError):
103 class PBMessageSizeError(MessageSizeError):
98 pass
104 pass
99
105
100
106
101 class ResultNotCompleted(KernelError):
107 class ResultNotCompleted(KernelError):
102 pass
108 pass
103
109
104
110
105 class ResultAlreadyRetrieved(KernelError):
111 class ResultAlreadyRetrieved(KernelError):
106 pass
112 pass
107
113
108 class ClientError(KernelError):
114 class ClientError(KernelError):
109 pass
115 pass
110
116
111
117
112 class TaskAborted(KernelError):
118 class TaskAborted(KernelError):
113 pass
119 pass
114
120
115
121
116 class TaskTimeout(KernelError):
122 class TaskTimeout(KernelError):
117 pass
123 pass
118
124
119
125
120 class NotAPendingResult(KernelError):
126 class NotAPendingResult(KernelError):
121 pass
127 pass
122
128
123
129
124 class UnpickleableException(KernelError):
130 class UnpickleableException(KernelError):
125 pass
131 pass
126
132
127
133
128 class AbortedPendingDeferredError(KernelError):
134 class AbortedPendingDeferredError(KernelError):
129 pass
135 pass
130
136
131
137
132 class InvalidProperty(KernelError):
138 class InvalidProperty(KernelError):
133 pass
139 pass
134
140
135
141
136 class MissingBlockArgument(KernelError):
142 class MissingBlockArgument(KernelError):
137 pass
143 pass
138
144
139
145
140 class StopLocalExecution(KernelError):
146 class StopLocalExecution(KernelError):
141 pass
147 pass
142
148
143
149
144 class SecurityError(KernelError):
150 class SecurityError(KernelError):
145 pass
151 pass
146
152
147
153
148 class FileTimeoutError(KernelError):
154 class FileTimeoutError(KernelError):
149 pass
155 pass
150
156
151 class TimeoutError(KernelError):
157 class TimeoutError(KernelError):
152 pass
158 pass
153
159
154 class UnmetDependency(KernelError):
160 class UnmetDependency(KernelError):
155 pass
161 pass
156
162
157 class ImpossibleDependency(UnmetDependency):
163 class ImpossibleDependency(UnmetDependency):
158 pass
164 pass
159
165
160 class DependencyTimeout(ImpossibleDependency):
166 class DependencyTimeout(ImpossibleDependency):
161 pass
167 pass
162
168
163 class InvalidDependency(ImpossibleDependency):
169 class InvalidDependency(ImpossibleDependency):
164 pass
170 pass
165
171
166 class RemoteError(KernelError):
172 class RemoteError(KernelError):
167 """Error raised elsewhere"""
173 """Error raised elsewhere"""
168 ename=None
174 ename=None
169 evalue=None
175 evalue=None
170 traceback=None
176 traceback=None
171 engine_info=None
177 engine_info=None
172
178
173 def __init__(self, ename, evalue, traceback, engine_info=None):
179 def __init__(self, ename, evalue, traceback, engine_info=None):
174 self.ename=ename
180 self.ename=ename
175 self.evalue=evalue
181 self.evalue=evalue
176 self.traceback=traceback
182 self.traceback=traceback
177 self.engine_info=engine_info or {}
183 self.engine_info=engine_info or {}
178 self.args=(ename, evalue)
184 self.args=(ename, evalue)
179
185
180 def __repr__(self):
186 def __repr__(self):
181 engineid = self.engine_info.get('engine_id', ' ')
187 engineid = self.engine_info.get('engine_id', ' ')
182 return "<Remote[%s]:%s(%s)>"%(engineid, self.ename, self.evalue)
188 return "<Remote[%s]:%s(%s)>"%(engineid, self.ename, self.evalue)
183
189
184 def __str__(self):
190 def __str__(self):
185 sig = "%s(%s)"%(self.ename, self.evalue)
191 sig = "%s(%s)"%(self.ename, self.evalue)
186 if self.traceback:
192 if self.traceback:
187 return sig + '\n' + self.traceback
193 return sig + '\n' + self.traceback
188 else:
194 else:
189 return sig
195 return sig
190
196
191
197
192 class TaskRejectError(KernelError):
198 class TaskRejectError(KernelError):
193 """Exception to raise when a task should be rejected by an engine.
199 """Exception to raise when a task should be rejected by an engine.
194
200
195 This exception can be used to allow a task running on an engine to test
201 This exception can be used to allow a task running on an engine to test
196 if the engine (or the user's namespace on the engine) has the needed
202 if the engine (or the user's namespace on the engine) has the needed
197 task dependencies. If not, the task should raise this exception. For
203 task dependencies. If not, the task should raise this exception. For
198 the task to be retried on another engine, the task should be created
204 the task to be retried on another engine, the task should be created
199 with the `retries` argument > 1.
205 with the `retries` argument > 1.
200
206
201 The advantage of this approach over our older properties system is that
207 The advantage of this approach over our older properties system is that
202 tasks have full access to the user's namespace on the engines and the
208 tasks have full access to the user's namespace on the engines and the
203 properties don't have to be managed or tested by the controller.
209 properties don't have to be managed or tested by the controller.
204 """
210 """
205
211
206
212
207 class CompositeError(RemoteError):
213 class CompositeError(RemoteError):
208 """Error for representing possibly multiple errors on engines"""
214 """Error for representing possibly multiple errors on engines"""
209 def __init__(self, message, elist):
215 def __init__(self, message, elist):
210 Exception.__init__(self, *(message, elist))
216 Exception.__init__(self, *(message, elist))
211 # Don't use pack_exception because it will conflict with the .message
217 # Don't use pack_exception because it will conflict with the .message
212 # attribute that is being deprecated in 2.6 and beyond.
218 # attribute that is being deprecated in 2.6 and beyond.
213 self.msg = message
219 self.msg = message
214 self.elist = elist
220 self.elist = elist
215 self.args = [ e[0] for e in elist ]
221 self.args = [ e[0] for e in elist ]
216
222
217 def _get_engine_str(self, ei):
223 def _get_engine_str(self, ei):
218 if not ei:
224 if not ei:
219 return '[Engine Exception]'
225 return '[Engine Exception]'
220 else:
226 else:
221 return '[%s:%s]: ' % (ei['engine_id'], ei['method'])
227 return '[%s:%s]: ' % (ei['engine_id'], ei['method'])
222
228
223 def _get_traceback(self, ev):
229 def _get_traceback(self, ev):
224 try:
230 try:
225 tb = ev._ipython_traceback_text
231 tb = ev._ipython_traceback_text
226 except AttributeError:
232 except AttributeError:
227 return 'No traceback available'
233 return 'No traceback available'
228 else:
234 else:
229 return tb
235 return tb
230
236
231 def __str__(self):
237 def __str__(self):
232 s = str(self.msg)
238 s = str(self.msg)
233 for en, ev, etb, ei in self.elist:
239 for en, ev, etb, ei in self.elist:
234 engine_str = self._get_engine_str(ei)
240 engine_str = self._get_engine_str(ei)
235 s = s + '\n' + engine_str + en + ': ' + str(ev)
241 s = s + '\n' + engine_str + en + ': ' + str(ev)
236 return s
242 return s
237
243
238 def __repr__(self):
244 def __repr__(self):
239 return "CompositeError(%i)"%len(self.elist)
245 return "CompositeError(%i)"%len(self.elist)
240
246
241 def print_tracebacks(self, excid=None):
247 def print_tracebacks(self, excid=None):
242 if excid is None:
248 if excid is None:
243 for (en,ev,etb,ei) in self.elist:
249 for (en,ev,etb,ei) in self.elist:
244 print (self._get_engine_str(ei))
250 print (self._get_engine_str(ei))
245 print (etb or 'No traceback available')
251 print (etb or 'No traceback available')
246 print ()
252 print ()
247 else:
253 else:
248 try:
254 try:
249 en,ev,etb,ei = self.elist[excid]
255 en,ev,etb,ei = self.elist[excid]
250 except:
256 except:
251 raise IndexError("an exception with index %i does not exist"%excid)
257 raise IndexError("an exception with index %i does not exist"%excid)
252 else:
258 else:
253 print (self._get_engine_str(ei))
259 print (self._get_engine_str(ei))
254 print (etb or 'No traceback available')
260 print (etb or 'No traceback available')
255
261
256 def raise_exception(self, excid=0):
262 def raise_exception(self, excid=0):
257 try:
263 try:
258 en,ev,etb,ei = self.elist[excid]
264 en,ev,etb,ei = self.elist[excid]
259 except:
265 except:
260 raise IndexError("an exception with index %i does not exist"%excid)
266 raise IndexError("an exception with index %i does not exist"%excid)
261 else:
267 else:
262 raise RemoteError(en, ev, etb, ei)
268 raise RemoteError(en, ev, etb, ei)
263
269
264
270
265 def collect_exceptions(rdict_or_list, method='unspecified'):
271 def collect_exceptions(rdict_or_list, method='unspecified'):
266 """check a result dict for errors, and raise CompositeError if any exist.
272 """check a result dict for errors, and raise CompositeError if any exist.
267 Pass through otherwise."""
273 Pass through otherwise."""
268 elist = []
274 elist = []
269 if isinstance(rdict_or_list, dict):
275 if isinstance(rdict_or_list, dict):
270 rlist = rdict_or_list.values()
276 rlist = rdict_or_list.values()
271 else:
277 else:
272 rlist = rdict_or_list
278 rlist = rdict_or_list
273 for r in rlist:
279 for r in rlist:
274 if isinstance(r, RemoteError):
280 if isinstance(r, RemoteError):
275 en, ev, etb, ei = r.ename, r.evalue, r.traceback, r.engine_info
281 en, ev, etb, ei = r.ename, r.evalue, r.traceback, r.engine_info
276 # Sometimes we could have CompositeError in our list. Just take
282 # Sometimes we could have CompositeError in our list. Just take
277 # the errors out of them and put them in our new list. This
283 # the errors out of them and put them in our new list. This
278 # has the effect of flattening lists of CompositeErrors into one
284 # has the effect of flattening lists of CompositeErrors into one
279 # CompositeError
285 # CompositeError
280 if en=='CompositeError':
286 if en=='CompositeError':
281 for e in ev.elist:
287 for e in ev.elist:
282 elist.append(e)
288 elist.append(e)
283 else:
289 else:
284 elist.append((en, ev, etb, ei))
290 elist.append((en, ev, etb, ei))
285 if len(elist)==0:
291 if len(elist)==0:
286 return rdict_or_list
292 return rdict_or_list
287 else:
293 else:
288 msg = "one or more exceptions from call to method: %s" % (method)
294 msg = "one or more exceptions from call to method: %s" % (method)
289 # This silliness is needed so the debugger has access to the exception
295 # This silliness is needed so the debugger has access to the exception
290 # instance (e in this case)
296 # instance (e in this case)
291 try:
297 try:
292 raise CompositeError(msg, elist)
298 raise CompositeError(msg, elist)
293 except CompositeError as e:
299 except CompositeError as e:
294 raise e
300 raise e
295
301
296 def wrap_exception(engine_info={}):
302 def wrap_exception(engine_info={}):
297 etype, evalue, tb = sys.exc_info()
303 etype, evalue, tb = sys.exc_info()
298 stb = traceback.format_exception(etype, evalue, tb)
304 stb = traceback.format_exception(etype, evalue, tb)
299 exc_content = {
305 exc_content = {
300 'status' : 'error',
306 'status' : 'error',
301 'traceback' : stb,
307 'traceback' : stb,
302 'ename' : unicode(etype.__name__),
308 'ename' : unicode(etype.__name__),
303 'evalue' : unicode(evalue),
309 'evalue' : unicode(evalue),
304 'engine_info' : engine_info
310 'engine_info' : engine_info
305 }
311 }
306 return exc_content
312 return exc_content
307
313
308 def unwrap_exception(content):
314 def unwrap_exception(content):
309 err = RemoteError(content['ename'], content['evalue'],
315 err = RemoteError(content['ename'], content['evalue'],
310 ''.join(content['traceback']),
316 ''.join(content['traceback']),
311 content.get('engine_info', {}))
317 content.get('engine_info', {}))
312 return err
318 return err
313
319
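
Given `print_tracebacks` and `raise_exception` on `CompositeError` above, the usual client-side handling is: catch the composite, dump every engine's traceback, then re-raise one failure as a plain `RemoteError`. A hedged sketch, assuming `ar` is an AsyncResult obtained from a connected client::

    from IPython.parallel import error

    def get_or_raise_first(ar):
        # Resolve an AsyncResult; if several engines failed, show all
        # tracebacks and surface the first failure as a RemoteError.
        try:
            return ar.get()
        except error.CompositeError as e:
            e.print_tracebacks()   # one traceback block per failed engine
            e.raise_exception(0)   # re-raise exception 0 as a RemoteError
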
@@ -1,72 +1,77 b''
1 """Base config factories."""
1 """Base config factories.
2
3 Authors:
4
5 * Min RK
6 """
2
7
3 #-----------------------------------------------------------------------------
8 #-----------------------------------------------------------------------------
4 # Copyright (C) 2010-2011 The IPython Development Team
9 # Copyright (C) 2010-2011 The IPython Development Team
5 #
10 #
6 # Distributed under the terms of the BSD License. The full license is in
11 # Distributed under the terms of the BSD License. The full license is in
7 # the file COPYING, distributed as part of this software.
12 # the file COPYING, distributed as part of this software.
8 #-----------------------------------------------------------------------------
13 #-----------------------------------------------------------------------------
9
14
10 #-----------------------------------------------------------------------------
15 #-----------------------------------------------------------------------------
11 # Imports
16 # Imports
12 #-----------------------------------------------------------------------------
17 #-----------------------------------------------------------------------------
13
18
14
19
15 import logging
20 import logging
16 import os
21 import os
17
22
18 import zmq
23 import zmq
19 from zmq.eventloop.ioloop import IOLoop
24 from zmq.eventloop.ioloop import IOLoop
20
25
21 from IPython.config.configurable import Configurable
26 from IPython.config.configurable import Configurable
22 from IPython.utils.traitlets import Int, Instance, Unicode
27 from IPython.utils.traitlets import Int, Instance, Unicode
23
28
24 from IPython.parallel.util import select_random_ports
29 from IPython.parallel.util import select_random_ports
25 from IPython.zmq.session import Session, SessionFactory
30 from IPython.zmq.session import Session, SessionFactory
26
31
27 #-----------------------------------------------------------------------------
32 #-----------------------------------------------------------------------------
28 # Classes
33 # Classes
29 #-----------------------------------------------------------------------------
34 #-----------------------------------------------------------------------------
30
35
31
36
32 class RegistrationFactory(SessionFactory):
37 class RegistrationFactory(SessionFactory):
33 """The Base Configurable for objects that involve registration."""
38 """The Base Configurable for objects that involve registration."""
34
39
35 url = Unicode('', config=True,
40 url = Unicode('', config=True,
36 help="""The 0MQ url used for registration. This sets transport, ip, and port
41 help="""The 0MQ url used for registration. This sets transport, ip, and port
37 in one variable. For example: url='tcp://127.0.0.1:12345' or
42 in one variable. For example: url='tcp://127.0.0.1:12345' or
38 url='epgm://*:90210'""") # url takes precedence over ip,regport,transport
43 url='epgm://*:90210'""") # url takes precedence over ip,regport,transport
39 transport = Unicode('tcp', config=True,
44 transport = Unicode('tcp', config=True,
40 help="""The 0MQ transport for communications. This will likely be
45 help="""The 0MQ transport for communications. This will likely be
41 the default of 'tcp', but other values include 'ipc', 'epgm', 'inproc'.""")
46 the default of 'tcp', but other values include 'ipc', 'epgm', 'inproc'.""")
42 ip = Unicode('127.0.0.1', config=True,
47 ip = Unicode('127.0.0.1', config=True,
43 help="""The IP address for registration. This is generally either
48 help="""The IP address for registration. This is generally either
44 '127.0.0.1' for loopback only or '*' for all interfaces.
49 '127.0.0.1' for loopback only or '*' for all interfaces.
45 [default: '127.0.0.1']""")
50 [default: '127.0.0.1']""")
46 regport = Int(config=True,
51 regport = Int(config=True,
47 help="""The port on which the Hub listens for registration.""")
52 help="""The port on which the Hub listens for registration.""")
48 def _regport_default(self):
53 def _regport_default(self):
49 return select_random_ports(1)[0]
54 return select_random_ports(1)[0]
50
55
51 def __init__(self, **kwargs):
56 def __init__(self, **kwargs):
52 super(RegistrationFactory, self).__init__(**kwargs)
57 super(RegistrationFactory, self).__init__(**kwargs)
53 self._propagate_url()
58 self._propagate_url()
54 self._rebuild_url()
59 self._rebuild_url()
55 self.on_trait_change(self._propagate_url, 'url')
60 self.on_trait_change(self._propagate_url, 'url')
56 self.on_trait_change(self._rebuild_url, 'ip')
61 self.on_trait_change(self._rebuild_url, 'ip')
57 self.on_trait_change(self._rebuild_url, 'transport')
62 self.on_trait_change(self._rebuild_url, 'transport')
58 self.on_trait_change(self._rebuild_url, 'regport')
63 self.on_trait_change(self._rebuild_url, 'regport')
59
64
60 def _rebuild_url(self):
65 def _rebuild_url(self):
61 self.url = "%s://%s:%i"%(self.transport, self.ip, self.regport)
66 self.url = "%s://%s:%i"%(self.transport, self.ip, self.regport)
62
67
63 def _propagate_url(self):
68 def _propagate_url(self):
64 """Ensure self.url contains full transport://interface:port"""
69 """Ensure self.url contains full transport://interface:port"""
65 if self.url:
70 if self.url:
66 iface = self.url.split('://',1)
71 iface = self.url.split('://',1)
67 if len(iface) == 2:
72 if len(iface) == 2:
68 self.transport,iface = iface
73 self.transport,iface = iface
69 iface = iface.split(':')
74 iface = iface.split(':')
70 self.ip = iface[0]
75 self.ip = iface[0]
71 if iface[1]:
76 if iface[1]:
72 self.regport = int(iface[1])
77 self.regport = int(iface[1])
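
`_propagate_url` above is the inverse of `_rebuild_url`: it splits a full 0MQ url back into transport, ip, and port. The same parsing as a standalone function, for clarity (a sketch outside the traitlets machinery, without the partial-url handling)::

    def split_url(url):
        # 'tcp://127.0.0.1:12345' -> ('tcp', '127.0.0.1', 12345)
        transport, iface = url.split('://', 1)
        ip, port = iface.split(':')
        return transport, ip, int(port)

    print split_url('tcp://127.0.0.1:12345')
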
@@ -1,129 +1,134 b''
1 """base class for parallel client tests"""
1 """base class for parallel client tests
2
3 Authors:
4
5 * Min RK
6 """
2
7
3 #-------------------------------------------------------------------------------
8 #-------------------------------------------------------------------------------
4 # Copyright (C) 2011 The IPython Development Team
9 # Copyright (C) 2011 The IPython Development Team
5 #
10 #
6 # Distributed under the terms of the BSD License. The full license is in
11 # Distributed under the terms of the BSD License. The full license is in
7 # the file COPYING, distributed as part of this software.
12 # the file COPYING, distributed as part of this software.
8 #-------------------------------------------------------------------------------
13 #-------------------------------------------------------------------------------
9
14
10 import sys
15 import sys
11 import tempfile
16 import tempfile
12 import time
17 import time
13
18
14 from nose import SkipTest
19 from nose import SkipTest
15
20
16 import zmq
21 import zmq
17 from zmq.tests import BaseZMQTestCase
22 from zmq.tests import BaseZMQTestCase
18
23
19 from IPython.external.decorator import decorator
24 from IPython.external.decorator import decorator
20
25
21 from IPython.parallel import error
26 from IPython.parallel import error
22 from IPython.parallel import Client
27 from IPython.parallel import Client
23
28
24 from IPython.parallel.tests import launchers, add_engines
29 from IPython.parallel.tests import launchers, add_engines
25
30
26 # simple tasks for use in apply tests
31 # simple tasks for use in apply tests
27
32
28 def segfault():
33 def segfault():
29 """this will segfault"""
34 """this will segfault"""
30 import ctypes
35 import ctypes
31 ctypes.memset(-1,0,1)
36 ctypes.memset(-1,0,1)
32
37
33 def crash():
38 def crash():
34 """from stdlib crashers in the test suite"""
39 """from stdlib crashers in the test suite"""
35 import types
40 import types
36 if sys.platform.startswith('win'):
41 if sys.platform.startswith('win'):
37 import ctypes
42 import ctypes
38 ctypes.windll.kernel32.SetErrorMode(0x0002);
43 ctypes.windll.kernel32.SetErrorMode(0x0002);
39
44
40 co = types.CodeType(0, 0, 0, 0, b'\x04\x71\x00\x00',
45 co = types.CodeType(0, 0, 0, 0, b'\x04\x71\x00\x00',
41 (), (), (), '', '', 1, b'')
46 (), (), (), '', '', 1, b'')
42 exec(co)
47 exec(co)
43
48
44 def wait(n):
49 def wait(n):
45 """sleep for a time"""
50 """sleep for a time"""
46 import time
51 import time
47 time.sleep(n)
52 time.sleep(n)
48 return n
53 return n
49
54
50 def raiser(eclass):
55 def raiser(eclass):
51 """raise an exception"""
56 """raise an exception"""
52 raise eclass()
57 raise eclass()
53
58
54 # test decorator for skipping tests when libraries are unavailable
59 # test decorator for skipping tests when libraries are unavailable
55 def skip_without(*names):
60 def skip_without(*names):
56 """skip a test if some names are not importable"""
61 """skip a test if some names are not importable"""
57 @decorator
62 @decorator
58 def skip_without_names(f, *args, **kwargs):
63 def skip_without_names(f, *args, **kwargs):
59 """decorator to skip tests in the absence of numpy."""
64 """decorator to skip tests in the absence of numpy."""
60 for name in names:
65 for name in names:
61 try:
66 try:
62 __import__(name)
67 __import__(name)
63 except ImportError:
68 except ImportError:
64 raise SkipTest
69 raise SkipTest
65 return f(*args, **kwargs)
70 return f(*args, **kwargs)
66 return skip_without_names
71 return skip_without_names
67
72
68 class ClusterTestCase(BaseZMQTestCase):
73 class ClusterTestCase(BaseZMQTestCase):
69
74
70 def add_engines(self, n=1, block=True):
75 def add_engines(self, n=1, block=True):
71 """add multiple engines to our cluster"""
76 """add multiple engines to our cluster"""
72 self.engines.extend(add_engines(n))
77 self.engines.extend(add_engines(n))
73 if block:
78 if block:
74 self.wait_on_engines()
79 self.wait_on_engines()
75
80
76 def wait_on_engines(self, timeout=5):
81 def wait_on_engines(self, timeout=5):
77 """wait for our engines to connect."""
82 """wait for our engines to connect."""
78 n = len(self.engines)+self.base_engine_count
83 n = len(self.engines)+self.base_engine_count
79 tic = time.time()
84 tic = time.time()
80 while time.time()-tic < timeout and len(self.client.ids) < n:
85 while time.time()-tic < timeout and len(self.client.ids) < n:
81 time.sleep(0.1)
86 time.sleep(0.1)
82
87
83 assert not len(self.client.ids) < n, "waiting for engines timed out"
88 assert not len(self.client.ids) < n, "waiting for engines timed out"
84
89
85 def connect_client(self):
90 def connect_client(self):
86 """connect a client with my Context, and track its sockets for cleanup"""
91 """connect a client with my Context, and track its sockets for cleanup"""
87 c = Client(profile='iptest', context=self.context)
92 c = Client(profile='iptest', context=self.context)
88 for name in filter(lambda n:n.endswith('socket'), dir(c)):
93 for name in filter(lambda n:n.endswith('socket'), dir(c)):
89 s = getattr(c, name)
94 s = getattr(c, name)
90 s.setsockopt(zmq.LINGER, 0)
95 s.setsockopt(zmq.LINGER, 0)
91 self.sockets.append(s)
96 self.sockets.append(s)
92 return c
97 return c
93
98
94 def assertRaisesRemote(self, etype, f, *args, **kwargs):
99 def assertRaisesRemote(self, etype, f, *args, **kwargs):
95 try:
100 try:
96 try:
101 try:
97 f(*args, **kwargs)
102 f(*args, **kwargs)
98 except error.CompositeError as e:
103 except error.CompositeError as e:
99 e.raise_exception()
104 e.raise_exception()
100 except error.RemoteError as e:
105 except error.RemoteError as e:
101 self.assertEquals(etype.__name__, e.ename, "Should have raised %r, but raised %r"%(etype.__name__, e.ename))
106 self.assertEquals(etype.__name__, e.ename, "Should have raised %r, but raised %r"%(etype.__name__, e.ename))
102 else:
107 else:
103 self.fail("should have raised a RemoteError")
108 self.fail("should have raised a RemoteError")
104
109
105 def setUp(self):
110 def setUp(self):
106 BaseZMQTestCase.setUp(self)
111 BaseZMQTestCase.setUp(self)
107 self.client = self.connect_client()
112 self.client = self.connect_client()
108 # start every test with clean engine namespaces:
113 # start every test with clean engine namespaces:
109 self.client.clear(block=True)
114 self.client.clear(block=True)
110 self.base_engine_count=len(self.client.ids)
115 self.base_engine_count=len(self.client.ids)
111 self.engines=[]
116 self.engines=[]
112
117
113 def tearDown(self):
118 def tearDown(self):
114 # self.client.clear(block=True)
119 # self.client.clear(block=True)
115 # close fds:
120 # close fds:
116 for e in filter(lambda e: e.poll() is not None, launchers):
121 for e in filter(lambda e: e.poll() is not None, launchers):
117 launchers.remove(e)
122 launchers.remove(e)
118
123
119 # allow flushing of incoming messages to prevent crash on socket close
124 # allow flushing of incoming messages to prevent crash on socket close
120 self.client.wait(timeout=2)
125 self.client.wait(timeout=2)
121 # time.sleep(2)
126 # time.sleep(2)
122 self.client.spin()
127 self.client.spin()
123 self.client.close()
128 self.client.close()
124 BaseZMQTestCase.tearDown(self)
129 BaseZMQTestCase.tearDown(self)
125 # this will be redundant when pyzmq merges PR #88
130 # this will be redundant when pyzmq merges PR #88
126 # self.context.term()
131 # self.context.term()
127 # print tempfile.TemporaryFile().fileno(),
132 # print tempfile.TemporaryFile().fileno(),
128 # sys.stdout.flush()
133 # sys.stdout.flush()
129 No newline at end of file
134
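
For reference, `skip_without` above is applied as a method decorator in the concrete test classes; a hypothetical example (the test name is made up, and the `push` call assumes the usual DirectView API)::

    class TestWithNumpy(ClusterTestCase):

        @skip_without('numpy')
        def test_push_array(self):
            # runs only when numpy imports cleanly; otherwise the
            # decorator raises nose's SkipTest before the body executes
            import numpy
            a = numpy.arange(4)
            self.client[:].push(dict(a=a), block=True)
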
@@ -1,68 +1,73 b''
1 """Tests for asyncresult.py"""
1 """Tests for asyncresult.py
2
3 Authors:
4
5 * Min RK
6 """
2
7
3 #-------------------------------------------------------------------------------
8 #-------------------------------------------------------------------------------
4 # Copyright (C) 2011 The IPython Development Team
9 # Copyright (C) 2011 The IPython Development Team
5 #
10 #
 # Distributed under the terms of the BSD License. The full license is in
 # the file COPYING, distributed as part of this software.
 #-------------------------------------------------------------------------------
 
 #-------------------------------------------------------------------------------
 # Imports
 #-------------------------------------------------------------------------------
 
 
 from IPython.parallel.error import TimeoutError
 
 from IPython.parallel.tests import add_engines
 from .clienttest import ClusterTestCase
 
 def setup():
     add_engines(2)
 
 def wait(n):
     import time
     time.sleep(n)
     return n
 
 class AsyncResultTest(ClusterTestCase):
 
     def test_single_result(self):
         eid = self.client.ids[-1]
         ar = self.client[eid].apply_async(lambda : 42)
         self.assertEquals(ar.get(), 42)
         ar = self.client[[eid]].apply_async(lambda : 42)
         self.assertEquals(ar.get(), [42])
         ar = self.client[-1:].apply_async(lambda : 42)
         self.assertEquals(ar.get(), [42])
 
     def test_get_after_done(self):
         ar = self.client[-1].apply_async(lambda : 42)
         ar.wait()
         self.assertTrue(ar.ready())
         self.assertEquals(ar.get(), 42)
         self.assertEquals(ar.get(), 42)
 
     def test_get_before_done(self):
         ar = self.client[-1].apply_async(wait, 0.1)
         self.assertRaises(TimeoutError, ar.get, 0)
         ar.wait(0)
         self.assertFalse(ar.ready())
         self.assertEquals(ar.get(), 0.1)
 
     def test_get_after_error(self):
         ar = self.client[-1].apply_async(lambda : 1/0)
         ar.wait()
         self.assertRaisesRemote(ZeroDivisionError, ar.get)
         self.assertRaisesRemote(ZeroDivisionError, ar.get)
         self.assertRaisesRemote(ZeroDivisionError, ar.get_dict)
 
     def test_get_dict(self):
         n = len(self.client)
         ar = self.client[:].apply_async(lambda : 5)
         self.assertEquals(ar.get(), [5]*n)
         d = ar.get_dict()
         self.assertEquals(sorted(d.keys()), sorted(self.client.ids))
         for eid,r in d.iteritems():
             self.assertEquals(r, 5)
 
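The tests above exercise the basic AsyncResult life cycle: apply_async() returns immediately, wait()/ready() poll for completion, and get() either returns the value, raises TimeoutError when a timeout expires, or re-raises a remote exception. A minimal sketch of that flow, assuming a running cluster; the 'iptest' profile name is only illustrative:

    from IPython.parallel import Client
    from IPython.parallel.error import TimeoutError

    rc = Client(profile='iptest')          # any running cluster profile works
    ar = rc[-1].apply_async(lambda : 42)   # returns an AsyncResult immediately
    try:
        ar.get(0)                          # zero timeout: raises if not done yet
    except TimeoutError:
        ar.wait()                          # block until the engine replies
    assert ar.ready() and ar.get() == 42   # get() is repeatable once done
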
@@ -1,244 +1,249 @@
-"""Tests for parallel client.py"""
+"""Tests for parallel client.py
+
+Authors:
+
+* Min RK
+"""
 
 #-------------------------------------------------------------------------------
 # Copyright (C) 2011 The IPython Development Team
 #
 # Distributed under the terms of the BSD License. The full license is in
 # the file COPYING, distributed as part of this software.
 #-------------------------------------------------------------------------------
 
 #-------------------------------------------------------------------------------
 # Imports
 #-------------------------------------------------------------------------------
 
 import time
 from datetime import datetime
 from tempfile import mktemp
 
 import zmq
 
 from IPython.parallel.client import client as clientmod
 from IPython.parallel import error
 from IPython.parallel import AsyncResult, AsyncHubResult
 from IPython.parallel import LoadBalancedView, DirectView
 
 from clienttest import ClusterTestCase, segfault, wait, add_engines
 
 def setup():
     add_engines(4)
 
 class TestClient(ClusterTestCase):
 
     def test_ids(self):
         n = len(self.client.ids)
         self.add_engines(3)
         self.assertEquals(len(self.client.ids), n+3)
 
     def test_view_indexing(self):
         """test index access for views"""
         self.add_engines(2)
         targets = self.client._build_targets('all')[-1]
         v = self.client[:]
         self.assertEquals(v.targets, targets)
         t = self.client.ids[2]
         v = self.client[t]
         self.assert_(isinstance(v, DirectView))
         self.assertEquals(v.targets, t)
         t = self.client.ids[2:4]
         v = self.client[t]
         self.assert_(isinstance(v, DirectView))
         self.assertEquals(v.targets, t)
         v = self.client[::2]
         self.assert_(isinstance(v, DirectView))
         self.assertEquals(v.targets, targets[::2])
         v = self.client[1::3]
         self.assert_(isinstance(v, DirectView))
         self.assertEquals(v.targets, targets[1::3])
         v = self.client[:-3]
         self.assert_(isinstance(v, DirectView))
         self.assertEquals(v.targets, targets[:-3])
         v = self.client[-1]
         self.assert_(isinstance(v, DirectView))
         self.assertEquals(v.targets, targets[-1])
         self.assertRaises(TypeError, lambda : self.client[None])
 
     def test_lbview_targets(self):
         """test load_balanced_view targets"""
         v = self.client.load_balanced_view()
         self.assertEquals(v.targets, None)
         v = self.client.load_balanced_view(-1)
         self.assertEquals(v.targets, [self.client.ids[-1]])
         v = self.client.load_balanced_view('all')
         self.assertEquals(v.targets, self.client.ids)
 
     def test_targets(self):
         """test various valid targets arguments"""
         build = self.client._build_targets
         ids = self.client.ids
         idents,targets = build(None)
         self.assertEquals(ids, targets)
 
     def test_clear(self):
         """test clear behavior"""
         # self.add_engines(2)
         v = self.client[:]
         v.block=True
         v.push(dict(a=5))
         v.pull('a')
         id0 = self.client.ids[-1]
         self.client.clear(targets=id0, block=True)
         a = self.client[:-1].get('a')
         self.assertRaisesRemote(NameError, self.client[id0].get, 'a')
         self.client.clear(block=True)
         for i in self.client.ids:
             # print i
             self.assertRaisesRemote(NameError, self.client[i].get, 'a')
 
     def test_get_result(self):
         """test getting results from the Hub."""
         c = clientmod.Client(profile='iptest')
         # self.add_engines(1)
         t = c.ids[-1]
         ar = c[t].apply_async(wait, 1)
         # give the monitor time to notice the message
         time.sleep(.25)
         ahr = self.client.get_result(ar.msg_ids)
         self.assertTrue(isinstance(ahr, AsyncHubResult))
         self.assertEquals(ahr.get(), ar.get())
         ar2 = self.client.get_result(ar.msg_ids)
         self.assertFalse(isinstance(ar2, AsyncHubResult))
         c.close()
 
     def test_ids_list(self):
         """test client.ids"""
         # self.add_engines(2)
         ids = self.client.ids
         self.assertEquals(ids, self.client._ids)
         self.assertFalse(ids is self.client._ids)
         ids.remove(ids[-1])
         self.assertNotEquals(ids, self.client._ids)
 
     def test_queue_status(self):
         # self.addEngine(4)
         ids = self.client.ids
         id0 = ids[0]
         qs = self.client.queue_status(targets=id0)
         self.assertTrue(isinstance(qs, dict))
         self.assertEquals(sorted(qs.keys()), ['completed', 'queue', 'tasks'])
         allqs = self.client.queue_status()
         self.assertTrue(isinstance(allqs, dict))
         self.assertEquals(sorted(allqs.keys()), sorted(self.client.ids + ['unassigned']))
         unassigned = allqs.pop('unassigned')
         for eid,qs in allqs.items():
             self.assertTrue(isinstance(qs, dict))
             self.assertEquals(sorted(qs.keys()), ['completed', 'queue', 'tasks'])
 
     def test_shutdown(self):
         # self.addEngine(4)
         ids = self.client.ids
         id0 = ids[0]
         self.client.shutdown(id0, block=True)
         while id0 in self.client.ids:
             time.sleep(0.1)
             self.client.spin()
 
         self.assertRaises(IndexError, lambda : self.client[id0])
 
     def test_result_status(self):
         pass
         # to be written
 
     def test_db_query_dt(self):
         """test db query by date"""
         hist = self.client.hub_history()
         middle = self.client.db_query({'msg_id' : hist[len(hist)/2]})[0]
         tic = middle['submitted']
         before = self.client.db_query({'submitted' : {'$lt' : tic}})
         after = self.client.db_query({'submitted' : {'$gte' : tic}})
         self.assertEquals(len(before)+len(after),len(hist))
         for b in before:
             self.assertTrue(b['submitted'] < tic)
         for a in after:
             self.assertTrue(a['submitted'] >= tic)
         same = self.client.db_query({'submitted' : tic})
         for s in same:
             self.assertTrue(s['submitted'] == tic)
 
     def test_db_query_keys(self):
         """test extracting subset of record keys"""
         found = self.client.db_query({'msg_id': {'$ne' : ''}},keys=['submitted', 'completed'])
         for rec in found:
             self.assertEquals(set(rec.keys()), set(['msg_id', 'submitted', 'completed']))
 
     def test_db_query_msg_id(self):
         """ensure msg_id is always in db queries"""
         found = self.client.db_query({'msg_id': {'$ne' : ''}},keys=['submitted', 'completed'])
         for rec in found:
             self.assertTrue('msg_id' in rec.keys())
         found = self.client.db_query({'msg_id': {'$ne' : ''}},keys=['submitted'])
         for rec in found:
             self.assertTrue('msg_id' in rec.keys())
         found = self.client.db_query({'msg_id': {'$ne' : ''}},keys=['msg_id'])
         for rec in found:
             self.assertTrue('msg_id' in rec.keys())
 
     def test_db_query_in(self):
         """test db query with '$in','$nin' operators"""
         hist = self.client.hub_history()
         even = hist[::2]
         odd = hist[1::2]
         recs = self.client.db_query({ 'msg_id' : {'$in' : even}})
         found = [ r['msg_id'] for r in recs ]
         self.assertEquals(set(even), set(found))
         recs = self.client.db_query({ 'msg_id' : {'$nin' : even}})
         found = [ r['msg_id'] for r in recs ]
         self.assertEquals(set(odd), set(found))
 
     def test_hub_history(self):
         hist = self.client.hub_history()
         recs = self.client.db_query({ 'msg_id' : {"$ne":''}})
         recdict = {}
         for rec in recs:
             recdict[rec['msg_id']] = rec
 
         latest = datetime(1984,1,1)
         for msg_id in hist:
             rec = recdict[msg_id]
             newt = rec['submitted']
             self.assertTrue(newt >= latest)
             latest = newt
         ar = self.client[-1].apply_async(lambda : 1)
         ar.get()
         time.sleep(0.25)
         self.assertEquals(self.client.hub_history()[-1:],ar.msg_ids)
 
     def test_resubmit(self):
         def f():
             import random
             return random.random()
         v = self.client.load_balanced_view()
         ar = v.apply_async(f)
         r1 = ar.get(1)
         ahr = self.client.resubmit(ar.msg_ids)
         r2 = ahr.get(1)
         self.assertFalse(r1 == r2)
 
     def test_resubmit_inflight(self):
         """ensure ValueError on resubmit of inflight task"""
         v = self.client.load_balanced_view()
         ar = v.apply_async(time.sleep,1)
         # give the message a chance to arrive
         time.sleep(0.2)
         self.assertRaisesRemote(ValueError, self.client.resubmit, ar.msg_ids)
         ar.get(2)
 
     def test_resubmit_badkey(self):
         """ensure KeyError on resubmit of nonexistent task"""
         self.assertRaisesRemote(KeyError, self.client.resubmit, ['invalid'])
 
     def test_purge_results(self):
         hist = self.client.hub_history()
         self.client.purge_results(hist)
         newhist = self.client.hub_history()
         self.assertTrue(len(newhist) == 0)
 
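Every file in this changeset receives the same edit: the one-line module docstring is expanded into the multi-line form with an Authors section used across the IPython codebase. The target shape (the module summary line here is illustrative):

    """Tests for some_module.py

    Authors:

    * Min RK
    """
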
@@ -1,173 +1,178 @@
-"""Tests for db backends"""
+"""Tests for db backends
+
+Authors:
+
+* Min RK
+"""
 
 #-------------------------------------------------------------------------------
 # Copyright (C) 2011 The IPython Development Team
 #
 # Distributed under the terms of the BSD License. The full license is in
 # the file COPYING, distributed as part of this software.
 #-------------------------------------------------------------------------------
 
 #-------------------------------------------------------------------------------
 # Imports
 #-------------------------------------------------------------------------------
 
 
 import tempfile
 import time
 
 from datetime import datetime, timedelta
 from unittest import TestCase
 
 from nose import SkipTest
 
 from IPython.parallel import error
 from IPython.parallel.controller.dictdb import DictDB
 from IPython.parallel.controller.sqlitedb import SQLiteDB
 from IPython.parallel.controller.hub import init_record, empty_record
 
 from IPython.zmq.session import Session
 
 
 #-------------------------------------------------------------------------------
 # TestCases
 #-------------------------------------------------------------------------------
 
 class TestDictBackend(TestCase):
     def setUp(self):
         self.session = Session()
         self.db = self.create_db()
         self.load_records(16)
 
     def create_db(self):
         return DictDB()
 
     def load_records(self, n=1):
         """load n records for testing"""
         #sleep 1/10 s, to ensure timestamp is different to previous calls
         time.sleep(0.1)
         msg_ids = []
         for i in range(n):
             msg = self.session.msg('apply_request', content=dict(a=5))
             msg['buffers'] = []
             rec = init_record(msg)
             msg_ids.append(msg['msg_id'])
             self.db.add_record(msg['msg_id'], rec)
         return msg_ids
 
     def test_add_record(self):
         before = self.db.get_history()
         self.load_records(5)
         after = self.db.get_history()
         self.assertEquals(len(after), len(before)+5)
         self.assertEquals(after[:-5],before)
 
     def test_drop_record(self):
         msg_id = self.load_records()[-1]
         rec = self.db.get_record(msg_id)
         self.db.drop_record(msg_id)
         self.assertRaises(KeyError,self.db.get_record, msg_id)
 
     def _round_to_millisecond(self, dt):
         """necessary because mongodb rounds microseconds"""
         micro = dt.microsecond
         extra = int(str(micro)[-3:])
         return dt - timedelta(microseconds=extra)
 
     def test_update_record(self):
         now = self._round_to_millisecond(datetime.now())
         #
         msg_id = self.db.get_history()[-1]
         rec1 = self.db.get_record(msg_id)
         data = {'stdout': 'hello there', 'completed' : now}
         self.db.update_record(msg_id, data)
         rec2 = self.db.get_record(msg_id)
         self.assertEquals(rec2['stdout'], 'hello there')
         self.assertEquals(rec2['completed'], now)
         rec1.update(data)
         self.assertEquals(rec1, rec2)
 
     # def test_update_record_bad(self):
     #     """test updating nonexistent records"""
     #     msg_id = str(uuid.uuid4())
     #     data = {'stdout': 'hello there'}
     #     self.assertRaises(KeyError, self.db.update_record, msg_id, data)
 
     def test_find_records_dt(self):
         """test finding records by date"""
         hist = self.db.get_history()
         middle = self.db.get_record(hist[len(hist)/2])
         tic = middle['submitted']
         before = self.db.find_records({'submitted' : {'$lt' : tic}})
         after = self.db.find_records({'submitted' : {'$gte' : tic}})
         self.assertEquals(len(before)+len(after),len(hist))
         for b in before:
             self.assertTrue(b['submitted'] < tic)
         for a in after:
             self.assertTrue(a['submitted'] >= tic)
         same = self.db.find_records({'submitted' : tic})
         for s in same:
             self.assertTrue(s['submitted'] == tic)
 
     def test_find_records_keys(self):
         """test extracting subset of record keys"""
         found = self.db.find_records({'msg_id': {'$ne' : ''}},keys=['submitted', 'completed'])
         for rec in found:
             self.assertEquals(set(rec.keys()), set(['msg_id', 'submitted', 'completed']))
 
     def test_find_records_msg_id(self):
         """ensure msg_id is always in found records"""
         found = self.db.find_records({'msg_id': {'$ne' : ''}},keys=['submitted', 'completed'])
         for rec in found:
             self.assertTrue('msg_id' in rec.keys())
         found = self.db.find_records({'msg_id': {'$ne' : ''}},keys=['submitted'])
         for rec in found:
             self.assertTrue('msg_id' in rec.keys())
         found = self.db.find_records({'msg_id': {'$ne' : ''}},keys=['msg_id'])
         for rec in found:
             self.assertTrue('msg_id' in rec.keys())
 
     def test_find_records_in(self):
         """test finding records with '$in','$nin' operators"""
         hist = self.db.get_history()
         even = hist[::2]
         odd = hist[1::2]
         recs = self.db.find_records({ 'msg_id' : {'$in' : even}})
         found = [ r['msg_id'] for r in recs ]
         self.assertEquals(set(even), set(found))
         recs = self.db.find_records({ 'msg_id' : {'$nin' : even}})
         found = [ r['msg_id'] for r in recs ]
         self.assertEquals(set(odd), set(found))
 
     def test_get_history(self):
         msg_ids = self.db.get_history()
         latest = datetime(1984,1,1)
         for msg_id in msg_ids:
             rec = self.db.get_record(msg_id)
             newt = rec['submitted']
             self.assertTrue(newt >= latest)
             latest = newt
         msg_id = self.load_records(1)[-1]
         self.assertEquals(self.db.get_history()[-1],msg_id)
 
     def test_datetime(self):
         """get/set timestamps with datetime objects"""
         msg_id = self.db.get_history()[-1]
         rec = self.db.get_record(msg_id)
         self.assertTrue(isinstance(rec['submitted'], datetime))
         self.db.update_record(msg_id, dict(completed=datetime.now()))
         rec = self.db.get_record(msg_id)
         self.assertTrue(isinstance(rec['completed'], datetime))
 
     def test_drop_matching(self):
         msg_ids = self.load_records(10)
         query = {'msg_id' : {'$in':msg_ids}}
         self.db.drop_matching_records(query)
         recs = self.db.find_records(query)
         self.assertTrue(len(recs)==0)
 
 class TestSQLiteBackend(TestDictBackend):
     def create_db(self):
         return SQLiteDB(location=tempfile.gettempdir())
 
     def tearDown(self):
         self.db._db.close()
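TestDictBackend doubles as the contract for every backend: subclasses override only create_db(), and the Mongo-style operators ($lt, $gte, $in, $nin) must behave identically on DictDB, SQLiteDB, and MongoDB. A minimal sketch of that shared query surface, assuming DictDB accepts an arbitrary dict as a record (only calls that appear in the tests above are used):

    from datetime import datetime
    from IPython.parallel.controller.dictdb import DictDB

    db = DictDB()
    # records are plain dicts keyed by msg_id; 'submitted' is a datetime
    db.add_record('a', {'msg_id': 'a', 'submitted': datetime(2011, 1, 1)})
    db.add_record('b', {'msg_id': 'b', 'submitted': datetime(2011, 6, 1)})
    # the same operator syntax works regardless of backend
    recent = db.find_records({'submitted': {'$gte': datetime(2011, 3, 1)}})
    assert [r['msg_id'] for r in recent] == ['b']
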
@@ -1,101 +1,106 @@
-"""Tests for dependency.py"""
+"""Tests for dependency.py
+
+Authors:
+
+* Min RK
+"""
 
 __docformat__ = "restructuredtext en"
 
 #-------------------------------------------------------------------------------
 # Copyright (C) 2011 The IPython Development Team
 #
 # Distributed under the terms of the BSD License. The full license is in
 # the file COPYING, distributed as part of this software.
 #-------------------------------------------------------------------------------
 
 #-------------------------------------------------------------------------------
 # Imports
 #-------------------------------------------------------------------------------
 
 # import
 import os
 
 from IPython.utils.pickleutil import can, uncan
 
 import IPython.parallel as pmod
 from IPython.parallel.util import interactive
 
 from IPython.parallel.tests import add_engines
 from .clienttest import ClusterTestCase
 
 def setup():
     add_engines(1)
 
 @pmod.require('time')
 def wait(n):
     time.sleep(n)
     return n
 
 mixed = map(str, range(10))
 completed = map(str, range(0,10,2))
 failed = map(str, range(1,10,2))
 
 class DependencyTest(ClusterTestCase):
 
     def setUp(self):
         ClusterTestCase.setUp(self)
         self.user_ns = {'__builtins__' : __builtins__}
         self.view = self.client.load_balanced_view()
         self.dview = self.client[-1]
         self.succeeded = set(map(str, range(0,25,2)))
         self.failed = set(map(str, range(1,25,2)))
 
     def assertMet(self, dep):
         self.assertTrue(dep.check(self.succeeded, self.failed), "Dependency should be met")
 
     def assertUnmet(self, dep):
         self.assertFalse(dep.check(self.succeeded, self.failed), "Dependency should not be met")
 
     def assertUnreachable(self, dep):
         self.assertTrue(dep.unreachable(self.succeeded, self.failed), "Dependency should be unreachable")
 
     def assertReachable(self, dep):
         self.assertFalse(dep.unreachable(self.succeeded, self.failed), "Dependency should be reachable")
 
     def cancan(self, f):
         """decorator to pass through canning into self.user_ns"""
         return uncan(can(f), self.user_ns)
 
     def test_require_imports(self):
         """test that @require imports names"""
         @self.cancan
         @pmod.require('urllib')
         @interactive
         def encode(dikt):
             return urllib.urlencode(dikt)
         # must pass through canning to properly connect namespaces
         self.assertEquals(encode(dict(a=5)), 'a=5')
 
     def test_success_only(self):
         dep = pmod.Dependency(mixed, success=True, failure=False)
         self.assertUnmet(dep)
         self.assertUnreachable(dep)
         dep.all=False
         self.assertMet(dep)
         self.assertReachable(dep)
         dep = pmod.Dependency(completed, success=True, failure=False)
         self.assertMet(dep)
         self.assertReachable(dep)
         dep.all=False
         self.assertMet(dep)
         self.assertReachable(dep)
 
     def test_failure_only(self):
         dep = pmod.Dependency(mixed, success=False, failure=True)
         self.assertUnmet(dep)
         self.assertUnreachable(dep)
         dep.all=False
         self.assertMet(dep)
         self.assertReachable(dep)
         dep = pmod.Dependency(completed, success=False, failure=True)
         self.assertUnmet(dep)
         self.assertUnreachable(dep)
         dep.all=False
         self.assertUnmet(dep)
         self.assertUnreachable(dep)
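A Dependency is a set of msg_ids plus three flags: success and failure select which outcomes count as satisfying it, while all (the default) chooses between all-of and any-of semantics; check() asks whether it is currently met, unreachable() whether it can no longer be met. A standalone sketch of the behaviour the assertions above encode (pure data class, no cluster needed):

    import IPython.parallel as pmod

    succeeded = set(map(str, range(0, 10, 2)))  # '0', '2', ...
    failed = set(map(str, range(1, 10, 2)))     # '1', '3', ...
    dep = pmod.Dependency(map(str, range(10)), success=True, failure=False)
    assert not dep.check(succeeded, failed)     # not *all* ids succeeded
    assert dep.unreachable(succeeded, failed)   # some ids already failed
    dep.all = False                             # switch to any-of semantics
    assert dep.check(succeeded, failed)         # at least one id succeeded
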
@@ -1,120 +1,125 @@
-"""test LoadBalancedView objects"""
+"""test LoadBalancedView objects
+
+Authors:
+
+* Min RK
+"""
 # -*- coding: utf-8 -*-
 #-------------------------------------------------------------------------------
 # Copyright (C) 2011 The IPython Development Team
 #
 # Distributed under the terms of the BSD License. The full license is in
 # the file COPYING, distributed as part of this software.
 #-------------------------------------------------------------------------------
 
 #-------------------------------------------------------------------------------
 # Imports
 #-------------------------------------------------------------------------------
 
 import sys
 import time
 
 import zmq
 
 from IPython import parallel as pmod
 from IPython.parallel import error
 
 from IPython.parallel.tests import add_engines
 
 from .clienttest import ClusterTestCase, crash, wait, skip_without
 
 def setup():
     add_engines(3)
 
 class TestLoadBalancedView(ClusterTestCase):
 
     def setUp(self):
         ClusterTestCase.setUp(self)
         self.view = self.client.load_balanced_view()
 
     def test_z_crash_task(self):
         """test graceful handling of engine death (balanced)"""
         # self.add_engines(1)
         ar = self.view.apply_async(crash)
         self.assertRaisesRemote(error.EngineError, ar.get, 10)
         eid = ar.engine_id
         tic = time.time()
         while eid in self.client.ids and time.time()-tic < 5:
             time.sleep(.01)
             self.client.spin()
         self.assertFalse(eid in self.client.ids, "Engine should have died")
 
     def test_map(self):
         def f(x):
             return x**2
         data = range(16)
         r = self.view.map_sync(f, data)
         self.assertEquals(r, map(f, data))
 
     def test_abort(self):
         view = self.view
         ar = self.client[:].apply_async(time.sleep, .5)
         ar2 = view.apply_async(lambda : 2)
         ar3 = view.apply_async(lambda : 3)
         view.abort(ar2)
         view.abort(ar3.msg_ids)
         self.assertRaises(error.TaskAborted, ar2.get)
         self.assertRaises(error.TaskAborted, ar3.get)
 
     def test_retries(self):
         add_engines(3)
         view = self.view
         view.timeout = 1 # prevent hang if this doesn't behave
         def fail():
             assert False
         for r in range(len(self.client)-1):
             with view.temp_flags(retries=r):
                 self.assertRaisesRemote(AssertionError, view.apply_sync, fail)
 
         with view.temp_flags(retries=len(self.client), timeout=0.25):
             self.assertRaisesRemote(error.TaskTimeout, view.apply_sync, fail)
 
     def test_invalid_dependency(self):
         view = self.view
         with view.temp_flags(after='12345'):
             self.assertRaisesRemote(error.InvalidDependency, view.apply_sync, lambda : 1)
 
     def test_impossible_dependency(self):
         if len(self.client) < 2:
             add_engines(2)
         view = self.client.load_balanced_view()
         ar1 = view.apply_async(lambda : 1)
         ar1.get()
         e1 = ar1.engine_id
         e2 = e1
         while e2 == e1:
             ar2 = view.apply_async(lambda : 1)
             ar2.get()
             e2 = ar2.engine_id
 
         with view.temp_flags(follow=[ar1, ar2]):
             self.assertRaisesRemote(error.ImpossibleDependency, view.apply_sync, lambda : 1)
 
 
     def test_follow(self):
         ar = self.view.apply_async(lambda : 1)
         ar.get()
         ars = []
         first_id = ar.engine_id
 
         self.view.follow = ar
         for i in range(5):
             ars.append(self.view.apply_async(lambda : 1))
         self.view.wait(ars)
         for ar in ars:
             self.assertEquals(ar.engine_id, first_id)
 
     def test_after(self):
         view = self.view
         ar = view.apply_async(time.sleep, 0.5)
         with view.temp_flags(after=ar):
             ar2 = view.apply_async(lambda : 1)
 
         ar.wait()
         ar2.wait()
         self.assertTrue(ar2.started > ar.completed)
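view.temp_flags() is a context manager that applies scheduling flags (retries, timeout, after, follow) only to submissions made inside the with block, restoring the previous values on exit; that is why each test above can tweak the scheduler without affecting its neighbours. A sketch of the idiom, assuming a connected client (the 'iptest' profile name is illustrative):

    from IPython.parallel import Client, error

    rc = Client(profile='iptest')
    view = rc.load_balanced_view()
    def fail():
        assert False
    with view.temp_flags(retries=2):    # flags live only inside this block
        try:
            view.apply_sync(fail)       # retried before the error surfaces
        except error.RemoteError:
            pass
    # outside the block, view.retries is back to its previous value
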
@@ -1,37 +1,42 @@
-"""Tests for mongodb backend"""
+"""Tests for mongodb backend
+
+Authors:
+
+* Min RK
+"""
 
 #-------------------------------------------------------------------------------
 # Copyright (C) 2011 The IPython Development Team
 #
 # Distributed under the terms of the BSD License. The full license is in
 # the file COPYING, distributed as part of this software.
 #-------------------------------------------------------------------------------
 
 #-------------------------------------------------------------------------------
 # Imports
 #-------------------------------------------------------------------------------
 
 from nose import SkipTest
 
 from pymongo import Connection
 from IPython.parallel.controller.mongodb import MongoDB
 
 from . import test_db
 
 try:
     c = Connection()
 except Exception:
     c=None
 
 class TestMongoBackend(test_db.TestDictBackend):
     """MongoDB backend tests"""
 
     def create_db(self):
         try:
             return MongoDB(database='iptestdb', _connection=c)
         except Exception:
             raise SkipTest("Couldn't connect to mongodb")
 
     def teardown(self):
         if c is not None:
             c.drop_database('iptestdb')
@@ -1,108 +1,113 @@
-"""test serialization with newserialized"""
+"""test serialization with newserialized
+
+Authors:
+
+* Min RK
+"""
 
 #-------------------------------------------------------------------------------
 # Copyright (C) 2011 The IPython Development Team
 #
 # Distributed under the terms of the BSD License. The full license is in
 # the file COPYING, distributed as part of this software.
 #-------------------------------------------------------------------------------
 
 #-------------------------------------------------------------------------------
 # Imports
 #-------------------------------------------------------------------------------
 
 from unittest import TestCase
 
 from IPython.testing.decorators import parametric
 from IPython.utils import newserialized as ns
 from IPython.utils.pickleutil import can, uncan, CannedObject, CannedFunction
 from IPython.parallel.tests.clienttest import skip_without
 
 
 class CanningTestCase(TestCase):
     def test_canning(self):
         d = dict(a=5,b=6)
         cd = can(d)
         self.assertTrue(isinstance(cd, dict))
 
     def test_canned_function(self):
         f = lambda : 7
         cf = can(f)
         self.assertTrue(isinstance(cf, CannedFunction))
 
     @parametric
     def test_can_roundtrip(cls):
         objs = [
             dict(),
             set(),
             list(),
             ['a',1,['a',1],u'e'],
         ]
         return map(cls.run_roundtrip, objs)
 
     @classmethod
     def run_roundtrip(self, obj):
         o = uncan(can(obj))
         assert o == obj, "failed assertion: %r == %r"%(o,obj)
 
     def test_serialized_interfaces(self):
 
         us = {'a':10, 'b':range(10)}
         s = ns.serialize(us)
         uus = ns.unserialize(s)
         self.assertTrue(isinstance(s, ns.SerializeIt))
         self.assertEquals(uus, us)
 
     def test_pickle_serialized(self):
         obj = {'a':1.45345, 'b':'asdfsdf', 'c':10000L}
         original = ns.UnSerialized(obj)
         originalSer = ns.SerializeIt(original)
         firstData = originalSer.getData()
         firstTD = originalSer.getTypeDescriptor()
         firstMD = originalSer.getMetadata()
         self.assertEquals(firstTD, 'pickle')
         self.assertEquals(firstMD, {})
         unSerialized = ns.UnSerializeIt(originalSer)
         secondObj = unSerialized.getObject()
         for k, v in secondObj.iteritems():
             self.assertEquals(obj[k], v)
         secondSer = ns.SerializeIt(ns.UnSerialized(secondObj))
         self.assertEquals(firstData, secondSer.getData())
         self.assertEquals(firstTD, secondSer.getTypeDescriptor())
         self.assertEquals(firstMD, secondSer.getMetadata())
 
     @skip_without('numpy')
     def test_ndarray_serialized(self):
         import numpy
         a = numpy.linspace(0.0, 1.0, 1000)
         unSer1 = ns.UnSerialized(a)
         ser1 = ns.SerializeIt(unSer1)
         td = ser1.getTypeDescriptor()
         self.assertEquals(td, 'ndarray')
         md = ser1.getMetadata()
         self.assertEquals(md['shape'], a.shape)
         self.assertEquals(md['dtype'], a.dtype.str)
         buff = ser1.getData()
         self.assertEquals(buff, numpy.getbuffer(a))
         s = ns.Serialized(buff, td, md)
         final = ns.unserialize(s)
         self.assertEquals(numpy.getbuffer(a), numpy.getbuffer(final))
         self.assertTrue((a==final).all())
         self.assertEquals(a.dtype.str, final.dtype.str)
         self.assertEquals(a.shape, final.shape)
         # test non-copying:
         a[2] = 1e9
         self.assertTrue((a==final).all())
 
     def test_uncan_function_globals(self):
         """test that uncanning a module function restores it into its module"""
         from re import search
         cf = can(search)
         csearch = uncan(cf)
         self.assertEqual(csearch.__module__, search.__module__)
         self.assertNotEqual(csearch('asd', 'asdf'), None)
         csearch = uncan(cf, dict(a=5))
         self.assertEqual(csearch.__module__, search.__module__)
         self.assertNotEqual(csearch('asd', 'asdf'), None)
 
-
\ No newline at end of file
+
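can() wraps objects (functions and closures especially) so they pickle safely between client and engines, and uncan() reverses it, optionally rebinding the function's globals to a supplied namespace; the tests above check that the round trip preserves both values and module identity. A compact sketch using only the calls shown:

    from IPython.utils.pickleutil import can, uncan, CannedFunction

    f = lambda x: x * 2
    cf = can(f)                  # a picklable CannedFunction wrapper
    assert isinstance(cf, CannedFunction)
    g = uncan(cf, {})            # second argument: namespace for globals
    assert g(21) == 42
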
@@ -1,441 +1,446 b''
1 """test View objects"""
1 """test View objects
2
3 Authors:
4
5 * Min RK
6 """
2 # -*- coding: utf-8 -*-
7 # -*- coding: utf-8 -*-
3 #-------------------------------------------------------------------------------
8 #-------------------------------------------------------------------------------
4 # Copyright (C) 2011 The IPython Development Team
9 # Copyright (C) 2011 The IPython Development Team
5 #
10 #
6 # Distributed under the terms of the BSD License. The full license is in
11 # Distributed under the terms of the BSD License. The full license is in
7 # the file COPYING, distributed as part of this software.
12 # the file COPYING, distributed as part of this software.
8 #-------------------------------------------------------------------------------
13 #-------------------------------------------------------------------------------
9
14
10 #-------------------------------------------------------------------------------
15 #-------------------------------------------------------------------------------
11 # Imports
16 # Imports
12 #-------------------------------------------------------------------------------
17 #-------------------------------------------------------------------------------
13
18
14 import sys
19 import sys
15 import time
20 import time
16 from tempfile import mktemp
21 from tempfile import mktemp
17 from StringIO import StringIO
22 from StringIO import StringIO
18
23
19 import zmq
24 import zmq
20
25
21 from IPython import parallel as pmod
26 from IPython import parallel as pmod
22 from IPython.parallel import error
27 from IPython.parallel import error
23 from IPython.parallel import AsyncResult, AsyncHubResult, AsyncMapResult
28 from IPython.parallel import AsyncResult, AsyncHubResult, AsyncMapResult
24 from IPython.parallel import DirectView
29 from IPython.parallel import DirectView
25 from IPython.parallel.util import interactive
30 from IPython.parallel.util import interactive
26
31
27 from IPython.parallel.tests import add_engines
32 from IPython.parallel.tests import add_engines
28
33
29 from .clienttest import ClusterTestCase, crash, wait, skip_without
34 from .clienttest import ClusterTestCase, crash, wait, skip_without
30
35
31 def setup():
36 def setup():
32 add_engines(3)
37 add_engines(3)
33
38
34 class TestView(ClusterTestCase):
39 class TestView(ClusterTestCase):
35
40
36 def test_z_crash_mux(self):
41 def test_z_crash_mux(self):
37 """test graceful handling of engine death (direct)"""
42 """test graceful handling of engine death (direct)"""
38 # self.add_engines(1)
43 # self.add_engines(1)
39 eid = self.client.ids[-1]
44 eid = self.client.ids[-1]
40 ar = self.client[eid].apply_async(crash)
45 ar = self.client[eid].apply_async(crash)
41 self.assertRaisesRemote(error.EngineError, ar.get)
46 self.assertRaisesRemote(error.EngineError, ar.get)
42 eid = ar.engine_id
47 eid = ar.engine_id
43 tic = time.time()
48 tic = time.time()
44 while eid in self.client.ids and time.time()-tic < 5:
49 while eid in self.client.ids and time.time()-tic < 5:
45 time.sleep(.01)
50 time.sleep(.01)
46 self.client.spin()
51 self.client.spin()
47 self.assertFalse(eid in self.client.ids, "Engine should have died")
52 self.assertFalse(eid in self.client.ids, "Engine should have died")
48
53
    def test_push_pull(self):
        """test pushing and pulling"""
        data = dict(a=10, b=1.05, c=range(10), d={'e':(1,2),'f':'hi'})
        t = self.client.ids[-1]
        v = self.client[t]
        push = v.push
        pull = v.pull
        v.block=True
        nengines = len(self.client)
        push({'data':data})
        d = pull('data')
        self.assertEquals(d, data)
        self.client[:].push({'data':data})
        d = self.client[:].pull('data', block=True)
        self.assertEquals(d, nengines*[data])
        ar = push({'data':data}, block=False)
        self.assertTrue(isinstance(ar, AsyncResult))
        r = ar.get()
        ar = self.client[:].pull('data', block=False)
        self.assertTrue(isinstance(ar, AsyncResult))
        r = ar.get()
        self.assertEquals(r, nengines*[data])
        self.client[:].push(dict(a=10,b=20))
        r = self.client[:].pull(('a','b'), block=True)
        self.assertEquals(r, nengines*[[10,20]])

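    # A minimal usage sketch of the push/pull round-trip exercised above
    # (illustrative only; the leading underscore keeps it out of the suite):
    def _demo_push_pull(self):
        v = self.client[:]              # DirectView over all engines
        v.push({'a': 5}, block=True)    # broadcast 'a' to every engine
        return v.pull('a', block=True)  # one value per engine, e.g. [5, 5, 5]
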
    def test_push_pull_function(self):
        "test pushing and pulling functions"
        def testf(x):
            return 2.0*x

        t = self.client.ids[-1]
        v = self.client[t]
        v.block=True
        push = v.push
        pull = v.pull
        execute = v.execute
        push({'testf':testf})
        r = pull('testf')
        self.assertEqual(r(1.0), testf(1.0))
        execute('r = testf(10)')
        r = pull('r')
        self.assertEquals(r, testf(10))
        ar = self.client[:].push({'testf':testf}, block=False)
        ar.get()
        ar = self.client[:].pull('testf', block=False)
        rlist = ar.get()
        for r in rlist:
            self.assertEqual(r(1.0), testf(1.0))
        execute("def g(x): return x*x")
        r = pull(('testf','g'))
        self.assertEquals((r[0](10),r[1](10)), (testf(10), 100))

    def test_push_function_globals(self):
        """test that pushed functions have access to globals"""
        @interactive
        def geta():
            return a
        # self.add_engines(1)
        v = self.client[-1]
        v.block=True
        v['f'] = geta
        self.assertRaisesRemote(NameError, v.execute, 'b=f()')
        v.execute('a=5')
        v.execute('b=f()')
        self.assertEquals(v['b'], 5)

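    # The @interactive decorator makes a locally-defined function behave as
    # if it had been typed at the engine's interactive prompt: its globals
    # resolve in the engine's user namespace, which is why geta() above can
    # see 'a' only after v.execute('a=5') has defined it remotely.
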
    def test_push_function_defaults(self):
        """test that pushed functions preserve default args"""
        def echo(a=10):
            return a
        v = self.client[-1]
        v.block=True
        v['f'] = echo
        v.execute('b=f()')
        self.assertEquals(v['b'], 10)

    def test_get_result(self):
        """test getting results from the Hub."""
        c = pmod.Client(profile='iptest')
        # self.add_engines(1)
        t = c.ids[-1]
        v = c[t]
        v2 = self.client[t]
        ar = v.apply_async(wait, 1)
        # give the monitor time to notice the message
        time.sleep(.25)
        ahr = v2.get_result(ar.msg_ids)
        self.assertTrue(isinstance(ahr, AsyncHubResult))
        self.assertEquals(ahr.get(), ar.get())
        ar2 = v2.get_result(ar.msg_ids)
        self.assertFalse(isinstance(ar2, AsyncHubResult))
        c.spin()
        c.close()

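    # get_result asks the Hub for results this client has not yet seen,
    # which is why the first call above returns an AsyncHubResult; once the
    # result is cached locally, the same msg_ids yield a plain AsyncResult.
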
    def test_run_newline(self):
        """test that run appends newline to files"""
        tmpfile = mktemp()
        with open(tmpfile, 'w') as f:
            f.write("""def g():
                return 5
                """)
        v = self.client[-1]
        v.run(tmpfile, block=True)
        self.assertEquals(v.apply_sync(lambda f: f(), pmod.Reference('g')), 5)

    def test_apply_tracked(self):
        """test tracking for apply"""
        # self.add_engines(1)
        t = self.client.ids[-1]
        v = self.client[t]
        v.block=False
        def echo(n=1024*1024, **kwargs):
            with v.temp_flags(**kwargs):
                return v.apply(lambda x: x, 'x'*n)
        ar = echo(1, track=False)
        self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
        self.assertTrue(ar.sent)
        ar = echo(track=True)
        self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
        self.assertEquals(ar.sent, ar._tracker.done)
        ar._tracker.wait()
        self.assertTrue(ar.sent)

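    # With track=True, the send is backed by a zmq.MessageTracker; waiting
    # on it (ar._tracker.wait()) guarantees zmq has finished transmitting
    # the message buffers, after which it is safe to mutate or free objects
    # that were sent without copying (this matters for large arrays).
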
    def test_push_tracked(self):
        t = self.client.ids[-1]
        ns = dict(x='x'*1024*1024)
        v = self.client[t]
        ar = v.push(ns, block=False, track=False)
        self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
        self.assertTrue(ar.sent)

        ar = v.push(ns, block=False, track=True)
        self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
        self.assertEquals(ar.sent, ar._tracker.done)
        ar._tracker.wait()
        self.assertTrue(ar.sent)
        ar.get()

    def test_scatter_tracked(self):
        t = self.client.ids
        x='x'*1024*1024
        ar = self.client[t].scatter('x', x, block=False, track=False)
        self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
        self.assertTrue(ar.sent)

        ar = self.client[t].scatter('x', x, block=False, track=True)
        self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
        self.assertEquals(ar.sent, ar._tracker.done)
        ar._tracker.wait()
        self.assertTrue(ar.sent)
        ar.get()

    def test_remote_reference(self):
        v = self.client[-1]
        v['a'] = 123
        ra = pmod.Reference('a')
        b = v.apply_sync(lambda x: x, ra)
        self.assertEquals(b, 123)

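    # Sketch: a Reference is resolved to the engine-side object when the
    # call runs, so large remote values never have to round-trip through
    # the client (illustrative only; not collected by the test runner):
    def _demo_reference(self):
        v = self.client[-1]
        v['data'] = range(1000)          # lives in the engine's namespace
        rdata = pmod.Reference('data')
        return v.apply_sync(len, rdata)  # -> 1000; only the length comes back
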
    def test_scatter_gather(self):
        view = self.client[:]
        seq1 = range(16)
        view.scatter('a', seq1)
        seq2 = view.gather('a', block=True)
        self.assertEquals(seq2, seq1)
        self.assertRaisesRemote(NameError, view.gather, 'asdf', block=True)

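    # scatter() splits a sequence into roughly equal partitions, one per
    # engine; gather() concatenates the pieces back in engine order, so a
    # scatter/gather round-trip reproduces the original sequence.
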
    @skip_without('numpy')
    def test_scatter_gather_numpy(self):
        import numpy
        from numpy.testing.utils import assert_array_equal, assert_array_almost_equal
        view = self.client[:]
        a = numpy.arange(64)
        view.scatter('a', a)
        b = view.gather('a', block=True)
        assert_array_equal(b, a)

    def test_map(self):
        view = self.client[:]
        def f(x):
            return x**2
        data = range(16)
        r = view.map_sync(f, data)
        self.assertEquals(r, map(f, data))

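    # map_sync behaves like the builtin map, but partitions the sequence
    # across the engines and reassembles the results in order. A minimal
    # sketch (illustrative only; not collected by the test runner):
    def _demo_map(self):
        view = self.client[:]
        return view.map_sync(lambda x: x ** 2, range(32))  # [0, 1, 4, ..., 961]
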
    def test_scatterGatherNonblocking(self):
        data = range(16)
        view = self.client[:]
        view.scatter('a', data, block=False)
        ar = view.gather('a', block=False)
        self.assertEquals(ar.get(), data)

    @skip_without('numpy')
    def test_scatter_gather_numpy_nonblocking(self):
        import numpy
        from numpy.testing.utils import assert_array_equal, assert_array_almost_equal
        a = numpy.arange(64)
        view = self.client[:]
        ar = view.scatter('a', a, block=False)
        self.assertTrue(isinstance(ar, AsyncResult))
        amr = view.gather('a', block=False)
        self.assertTrue(isinstance(amr, AsyncMapResult))
        assert_array_equal(amr.get(), a)

    def test_execute(self):
        view = self.client[:]
        # self.client.debug=True
        execute = view.execute
        ar = execute('c=30', block=False)
        self.assertTrue(isinstance(ar, AsyncResult))
        ar = execute('d=[0,1,2]', block=False)
        self.client.wait(ar, 1)
        self.assertEquals(len(ar.get()), len(self.client))
        for c in view['c']:
            self.assertEquals(c, 30)

    def test_abort(self):
        view = self.client[-1]
        ar = view.execute('import time; time.sleep(0.25)', block=False)
        ar2 = view.apply_async(lambda : 2)
        ar3 = view.apply_async(lambda : 3)
        view.abort(ar2)
        view.abort(ar3.msg_ids)
        self.assertRaises(error.TaskAborted, ar2.get)
        self.assertRaises(error.TaskAborted, ar3.get)

    def test_temp_flags(self):
        view = self.client[-1]
        view.block=True
        with view.temp_flags(block=False):
            self.assertFalse(view.block)
        self.assertTrue(view.block)

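    # temp_flags is a context manager that applies temporary view settings
    # (such as block and track) for the duration of the with-block and
    # restores the previous values on exit, even if the body raises.
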
    def test_importer(self):
        view = self.client[-1]
        view.clear(block=True)
        with view.importer:
            import re

        @interactive
        def findall(pat, s):
            # this globals() step isn't necessary in real code
            # only to prevent a closure in the test
            re = globals()['re']
            return re.findall(pat, s)

        self.assertEquals(view.apply_sync(findall, r'\w+', 'hello world'), 'hello world'.split())

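    # view.importer replays imports performed inside the with-block on the
    # engines as well, so the 're' imported above is defined both locally
    # and remotely before findall is applied.
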
    # parallel magic tests

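    # The %px, %autopx, and %result magics act on the "active" view selected
    # by v.activate(). A sketch of an interactive session (illustrative):
    #
    #     In [1]: v = rc[-1]; v.activate()
    #     In [2]: %px a = 5       # executes 'a = 5' on the view's targets
    #     In [3]: %result         # AsyncResult for the most recent command
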
    def test_magic_px_blocking(self):
        ip = get_ipython()
        v = self.client[-1]
        v.activate()
        v.block=True

        ip.magic_px('a=5')
        self.assertEquals(v['a'], 5)
        ip.magic_px('a=10')
        self.assertEquals(v['a'], 10)
        sio = StringIO()
        savestdout = sys.stdout
        sys.stdout = sio
        ip.magic_px('print a')
        sys.stdout = savestdout
        sio.read()
        self.assertTrue('[stdout:%i]'%v.targets in sio.buf)
        self.assertTrue(sio.buf.rstrip().endswith('10'))
        self.assertRaisesRemote(ZeroDivisionError, ip.magic_px, '1/0')

    def test_magic_px_nonblocking(self):
        ip = get_ipython()
        v = self.client[-1]
        v.activate()
        v.block=False

        ip.magic_px('a=5')
        self.assertEquals(v['a'], 5)
        ip.magic_px('a=10')
        self.assertEquals(v['a'], 10)
        sio = StringIO()
        savestdout = sys.stdout
        sys.stdout = sio
        ip.magic_px('print a')
        sys.stdout = savestdout
        sio.read()
        self.assertFalse('[stdout:%i]'%v.targets in sio.buf)
        ip.magic_px('1/0')
        ar = v.get_result(-1)
        self.assertRaisesRemote(ZeroDivisionError, ar.get)

    def test_magic_autopx_blocking(self):
        ip = get_ipython()
        v = self.client[-1]
        v.activate()
        v.block=True

        sio = StringIO()
        savestdout = sys.stdout
        sys.stdout = sio
        ip.magic_autopx()
        ip.run_cell('\n'.join(('a=5','b=10','c=0')))
        ip.run_cell('print b')
        ip.run_cell("b/c")
        ip.run_code(compile('b*=2', '', 'single'))
        ip.magic_autopx()
        sys.stdout = savestdout
        sio.read()
        output = sio.buf.strip()
        self.assertTrue(output.startswith('%autopx enabled'))
        self.assertTrue(output.endswith('%autopx disabled'))
        self.assertTrue('RemoteError: ZeroDivisionError' in output)
        ar = v.get_result(-2)
        self.assertEquals(v['a'], 5)
        self.assertEquals(v['b'], 20)
        self.assertRaisesRemote(ZeroDivisionError, ar.get)

    def test_magic_autopx_nonblocking(self):
        ip = get_ipython()
        v = self.client[-1]
        v.activate()
        v.block=False

        sio = StringIO()
        savestdout = sys.stdout
        sys.stdout = sio
        ip.magic_autopx()
        ip.run_cell('\n'.join(('a=5','b=10','c=0')))
        ip.run_cell('print b')
        ip.run_cell("b/c")
        ip.run_code(compile('b*=2', '', 'single'))
        ip.magic_autopx()
        sys.stdout = savestdout
        sio.read()
        output = sio.buf.strip()
        self.assertTrue(output.startswith('%autopx enabled'))
        self.assertTrue(output.endswith('%autopx disabled'))
        self.assertFalse('ZeroDivisionError' in output)
        ar = v.get_result(-2)
        self.assertEquals(v['a'], 5)
        self.assertEquals(v['b'], 20)
        self.assertRaisesRemote(ZeroDivisionError, ar.get)

    def test_magic_result(self):
        ip = get_ipython()
        v = self.client[-1]
        v.activate()
        v['a'] = 111
        ra = v['a']

        ar = ip.magic_result()
        self.assertEquals(ar.msg_ids, [v.history[-1]])
        self.assertEquals(ar.get(), 111)
        ar = ip.magic_result('-2')
        self.assertEquals(ar.msg_ids, [v.history[-2]])

    def test_unicode_execute(self):
        """test executing unicode strings"""
        v = self.client[-1]
        v.block=True
        code=u"a=u'é'"
        v.execute(code)
        self.assertEquals(v['a'], u'é')

    def test_unicode_apply_result(self):
        """test unicode apply results"""
        v = self.client[-1]
        r = v.apply_sync(lambda : u'é')
        self.assertEquals(r, u'é')

    def test_unicode_apply_arg(self):
        """test passing unicode arguments to apply"""
        v = self.client[-1]

        @interactive
        def check_unicode(a, check):
            assert isinstance(a, unicode), "%r is not unicode"%a
            assert isinstance(check, bytes), "%r is not bytes"%check
            assert a.encode('utf8') == check, "%s != %s"%(a,check)

        for s in [ u'é', u'ßø®∫', 'asdf'.decode() ]:
            try:
                v.apply_sync(check_unicode, s, s.encode('utf8'))
            except error.RemoteError as e:
                if e.ename == 'AssertionError':
                    self.fail(e.evalue)
                else:
                    raise e