update recently changed modules with Authors in docstring
MinRK

The requested changes are too big and content was truncated.

@@ -1,361 +1,362 b''
1 1 # encoding: utf-8
2 2 """
3 3 A base class for a configurable application.
4 4
5 5 Authors:
6 6
7 7 * Brian Granger
8 * Min RK
8 9 """
9 10
10 11 #-----------------------------------------------------------------------------
11 12 # Copyright (C) 2008-2011 The IPython Development Team
12 13 #
13 14 # Distributed under the terms of the BSD License. The full license is in
14 15 # the file COPYING, distributed as part of this software.
15 16 #-----------------------------------------------------------------------------
16 17
17 18 #-----------------------------------------------------------------------------
18 19 # Imports
19 20 #-----------------------------------------------------------------------------
20 21
21 22 from copy import deepcopy
22 23 import logging
23 24 import re
24 25 import sys
25 26
26 27 from IPython.config.configurable import SingletonConfigurable
27 28 from IPython.config.loader import (
28 29 KeyValueConfigLoader, PyFileConfigLoader, Config, ArgumentError
29 30 )
30 31
31 32 from IPython.utils.traitlets import (
32 33 Unicode, List, Int, Enum, Dict, Instance
33 34 )
34 35 from IPython.utils.importstring import import_item
35 36 from IPython.utils.text import indent
36 37
37 38 #-----------------------------------------------------------------------------
38 39 # Descriptions for the various sections
39 40 #-----------------------------------------------------------------------------
40 41
41 42 flag_description = """
42 43 Flags are command-line arguments passed as '--<flag>'.
43 44 These take no parameters, unlike regular key-value arguments.
44 45 They are typically used for setting boolean flags, or enabling
45 46 modes that involve setting multiple options together.
46 47 """.strip() # trim newlines off front and back
47 48
48 49 alias_description = """
49 50 These are commonly set parameters, given abbreviated aliases for convenience.
50 51 They are set in the same `name=value` way as class parameters, where
51 52 <name> is replaced by the real parameter for which it is an alias.
52 53 """.strip() # trim newlines off front and back
53 54
54 55 keyvalue_description = """
55 56 Parameters are set from command-line arguments of the form:
56 57 `Class.trait=value`. Parameters will *never* be prefixed with '-'.
57 58 This line is evaluated in Python, so simple expressions are allowed, e.g.
58 59 `C.a='range(3)'` for setting C.a=[0,1,2].
59 60 """.strip() # trim newlines off front and back
60 61
61 62 #-----------------------------------------------------------------------------
62 63 # Application class
63 64 #-----------------------------------------------------------------------------
64 65
65 66
66 67 class ApplicationError(Exception):
67 68 pass
68 69
69 70
70 71 class Application(SingletonConfigurable):
71 72 """A singleton application with full configuration support."""
72 73
73 74 # The name of the application, will usually match the name of the command
74 75 # line application
75 76 name = Unicode(u'application')
76 77
77 78 # The description of the application that is printed at the beginning
78 79 # of the help.
79 80 description = Unicode(u'This is an application.')
80 81 # default section descriptions
81 82 flag_description = Unicode(flag_description)
82 83 alias_description = Unicode(alias_description)
83 84 keyvalue_description = Unicode(keyvalue_description)
84 85
85 86
86 87 # A sequence of Configurable subclasses whose config=True attributes will
87 88 # be exposed at the command line.
88 89 classes = List([])
89 90
90 91 # The version string of this application.
91 92 version = Unicode(u'0.0')
92 93
93 94 # The log level for the application
94 95 log_level = Enum((0,10,20,30,40,50), default_value=logging.WARN,
95 96 config=True,
96 97 help="Set the log level.")
97 98
98 99 # the alias map for configurables
99 100 aliases = Dict(dict(log_level='Application.log_level'))
100 101
101 102 # flags for loading Configurables or store_const style flags
102 103 # flags are loaded from this dict by '--key' flags
103 104 # this must be a dict of two-tuples, the first element being the Config/dict
104 105 # and the second being the help string for the flag
105 106 flags = Dict()
106 107
107 108 # subcommands for launching other applications
108 109 # if this is not empty, this will be a parent Application
109 110 # this must be a dict of two-tuples, the first element being the application class/import string
110 111 # and the second being the help string for the subcommand
111 112 subcommands = Dict()
112 113 # parse_command_line will initialize a subapp, if requested
113 114 subapp = Instance('IPython.config.application.Application', allow_none=True)
114 115
115 116 # extra command-line arguments that don't set config values
116 117 extra_args = List(Unicode)
117 118
118 119
119 120 def __init__(self, **kwargs):
120 121 SingletonConfigurable.__init__(self, **kwargs)
121 122 # Add my class to self.classes so my attributes appear in command line
122 123 # options.
123 124 self.classes.insert(0, self.__class__)
124 125
125 126 # ensure self.flags dict is valid
126 127 for key,value in self.flags.iteritems():
127 128 assert len(value) == 2, "Bad flag: %r:%s"%(key,value)
128 129 assert isinstance(value[0], (dict, Config)), "Bad flag: %r:%s"%(key,value)
129 130 assert isinstance(value[1], basestring), "Bad flag: %r:%s"%(key,value)
130 131 self.init_logging()
131 132
132 133 def _config_changed(self, name, old, new):
133 134 SingletonConfigurable._config_changed(self, name, old, new)
134 135 self.log.debug('Config changed:')
135 136 self.log.debug(repr(new))
136 137
137 138 def init_logging(self):
138 139 """Start logging for this application.
139 140
140 141 The default is to log to stdout using a StreamHandler. The log level
141 142 starts at logging.WARN, but this can be adjusted by setting the
142 143 ``log_level`` attribute.
143 144 """
144 145 self.log = logging.getLogger(self.__class__.__name__)
145 146 self.log.setLevel(self.log_level)
146 147 self._log_handler = logging.StreamHandler()
147 148 self._log_formatter = logging.Formatter("[%(name)s] %(message)s")
148 149 self._log_handler.setFormatter(self._log_formatter)
149 150 self.log.addHandler(self._log_handler)
150 151
151 152 def initialize(self, argv=None):
152 153 """Do the basic steps to configure me.
153 154
154 155 Override in subclasses.
155 156 """
156 157 self.parse_command_line(argv)
157 158
158 159
159 160 def start(self):
160 161 """Start the app mainloop.
161 162
162 163 Override in subclasses.
163 164 """
164 165 if self.subapp is not None:
165 166 return self.subapp.start()
166 167
167 168 def _log_level_changed(self, name, old, new):
168 169 """Adjust the log level when log_level is set."""
169 170 self.log.setLevel(new)
170 171
171 172 def print_alias_help(self):
172 173 """print the alias part of the help"""
173 174 if not self.aliases:
174 175 return
175 176
176 177 lines = ['Aliases']
177 178 lines.append('-'*len(lines[0]))
178 179 lines.append(self.alias_description)
179 180 lines.append('')
180 181
181 182 classdict = {}
182 183 for cls in self.classes:
183 184 # include all parents (up to, but excluding Configurable) in available names
184 185 for c in cls.mro()[:-3]:
185 186 classdict[c.__name__] = c
186 187
187 188 for alias, longname in self.aliases.iteritems():
188 189 classname, traitname = longname.split('.',1)
189 190 cls = classdict[classname]
190 191
191 192 trait = cls.class_traits(config=True)[traitname]
192 193 help = cls.class_get_trait_help(trait)
193 194 help = help.replace(longname, "%s (%s)"%(alias, longname), 1)
194 195 lines.append(help)
195 196 lines.append('')
196 197 print '\n'.join(lines)
197 198
198 199 def print_flag_help(self):
199 200 """print the flag part of the help"""
200 201 if not self.flags:
201 202 return
202 203
203 204 lines = ['Flags']
204 205 lines.append('-'*len(lines[0]))
205 206 lines.append(self.flag_description)
206 207 lines.append('')
207 208
208 209 for m, (cfg,help) in self.flags.iteritems():
209 210 lines.append('--'+m)
210 211 lines.append(indent(help.strip(), flatten=True))
211 212 lines.append('')
212 213 print '\n'.join(lines)
213 214
214 215 def print_subcommands(self):
215 216 """print the subcommand part of the help"""
216 217 if not self.subcommands:
217 218 return
218 219
219 220 lines = ["Subcommands"]
220 221 lines.append('-'*len(lines[0]))
221 222 for subc, (cls,help) in self.subcommands.iteritems():
222 223 lines.append("%s : %s"%(subc, cls))
223 224 if help:
224 225 lines.append(indent(help.strip(), flatten=True))
225 226 lines.append('')
226 227 print '\n'.join(lines)
227 228
228 229 def print_help(self, classes=False):
229 230 """Print the help for each Configurable class in self.classes.
230 231
231 232 If classes=False (the default), only flags and aliases are printed
232 233 """
233 234 self.print_subcommands()
234 235 self.print_flag_help()
235 236 self.print_alias_help()
236 237
237 238 if classes:
238 239 if self.classes:
239 240 print "Class parameters"
240 241 print "----------------"
241 242 print self.keyvalue_description
242 243 print
243 244
244 245 for cls in self.classes:
245 246 cls.class_print_help()
246 247 print
247 248 else:
248 249 print "To see all available configurables, use `--help-all`"
249 250 print
250 251
251 252 def print_description(self):
252 253 """Print the application description."""
253 254 print self.description
254 255 print
255 256
256 257 def print_version(self):
257 258 """Print the version string."""
258 259 print self.version
259 260
260 261 def update_config(self, config):
261 262 """Fire the traits events when the config is updated."""
262 263 # Save a copy of the current config.
263 264 newconfig = deepcopy(self.config)
264 265 # Merge the new config into the current one.
265 266 newconfig._merge(config)
266 267 # Save the combined config as self.config, which triggers the traits
267 268 # events.
268 269 self.config = newconfig
269 270
270 271 def initialize_subcommand(self, subc, argv=None):
271 272 """Initialize a subcommand with argv"""
272 273 subapp,help = self.subcommands.get(subc)
273 274
274 275 if isinstance(subapp, basestring):
275 276 subapp = import_item(subapp)
276 277
277 278 # clear existing instances
278 279 self.__class__.clear_instance()
279 280 # instantiate
280 281 self.subapp = subapp.instance()
281 282 # and initialize subapp
282 283 self.subapp.initialize(argv)
283 284
284 285 def parse_command_line(self, argv=None):
285 286 """Parse the command line arguments."""
286 287 argv = sys.argv[1:] if argv is None else argv
287 288
288 289 if self.subcommands and len(argv) > 0:
289 290 # we have subcommands, and one may have been specified
290 291 subc, subargv = argv[0], argv[1:]
291 292 if re.match(r'^\w(\-?\w)*$', subc) and subc in self.subcommands:
292 293 # it's a subcommand, and *not* a flag or class parameter
293 294 return self.initialize_subcommand(subc, subargv)
294 295
295 296 if '-h' in argv or '--help' in argv or '--help-all' in argv:
296 297 self.print_description()
297 298 self.print_help('--help-all' in argv)
298 299 self.exit(0)
299 300
300 301 if '--version' in argv:
301 302 self.print_version()
302 303 self.exit(0)
303 304
304 305 loader = KeyValueConfigLoader(argv=argv, aliases=self.aliases,
305 306 flags=self.flags)
306 307 try:
307 308 config = loader.load_config()
308 309 except ArgumentError as e:
309 310 self.log.fatal(str(e))
310 311 self.print_description()
311 312 self.print_help()
312 313 self.exit(1)
313 314 self.update_config(config)
314 315 # store unparsed args in extra_args
315 316 self.extra_args = loader.extra_args
316 317
317 318 def load_config_file(self, filename, path=None):
318 319 """Load a .py based config file by filename and path."""
319 320 loader = PyFileConfigLoader(filename, path=path)
320 321 config = loader.load_config()
321 322 self.update_config(config)
322 323
323 324 def exit(self, exit_status=0):
324 325 self.log.debug("Exiting application: %s" % self.name)
325 326 sys.exit(exit_status)
326 327
327 328 #-----------------------------------------------------------------------------
328 329 # utility functions, for convenience
329 330 #-----------------------------------------------------------------------------
330 331
331 332 def boolean_flag(name, configurable, set_help='', unset_help=''):
332 333 """helper for building basic --trait, --no-trait flags
333 334
334 335 Parameters
335 336 ----------
336 337
337 338 name : str
338 339 The name of the flag.
339 340 configurable : str
340 341 The 'Class.trait' string of the trait to be set/unset with the flag
341 342 set_help : unicode
342 343 help string for --name flag
343 344 unset_help : unicode
344 345 help string for --no-name flag
345 346
346 347 Returns
347 348 -------
348 349
349 350 cfg : dict
350 351 A dict with two keys: 'name', and 'no-name', for setting and unsetting
351 352 the trait, respectively.
352 353 """
353 354 # default helpstrings
354 355 set_help = set_help or "set %s=True"%configurable
355 356 unset_help = unset_help or "set %s=False"%configurable
356 357
357 358 cls,trait = configurable.split('.')
358 359
359 360 setter = {cls : {trait : True}}
360 361 unsetter = {cls : {trait : False}}
361 362 return {name : (setter, set_help), 'no-'+name : (unsetter, unset_help)}
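
A minimal sketch of how `boolean_flag` feeds an Application's `flags` dict, matching the two-tuple (config, help) format validated in Application.__init__ above. `MyApp` and its `debug` trait are hypothetical names for this sketch::

    from IPython.utils.traitlets import Bool, Dict, Unicode
    from IPython.config.application import Application, boolean_flag

    # boolean_flag returns {'debug': (setter, help), 'no-debug': (unsetter, help)}
    flags = {}
    flags.update(boolean_flag('debug', 'MyApp.debug',
                              "set MyApp.debug=True",
                              "set MyApp.debug=False"))

    class MyApp(Application):
        name = Unicode(u'myapp')
        debug = Bool(False, config=True)  # hypothetical trait for this sketch
        flags = Dict(flags)

    app = MyApp()
    app.parse_command_line(['--debug'])  # merges {'MyApp': {'debug': True}}
    assert app.debug  # the config change propagated to the trait
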
@@ -1,278 +1,279 b''
1 1 #!/usr/bin/env python
2 2 # encoding: utf-8
3 3 """
4 4 A base class for objects that are configurable.
5 5
6 6 Authors:
7 7
8 8 * Brian Granger
9 9 * Fernando Perez
10 * Min RK
10 11 """
11 12
12 13 #-----------------------------------------------------------------------------
13 # Copyright (C) 2008-2010 The IPython Development Team
14 # Copyright (C) 2008-2011 The IPython Development Team
14 15 #
15 16 # Distributed under the terms of the BSD License. The full license is in
16 17 # the file COPYING, distributed as part of this software.
17 18 #-----------------------------------------------------------------------------
18 19
19 20 #-----------------------------------------------------------------------------
20 21 # Imports
21 22 #-----------------------------------------------------------------------------
22 23
23 24 from copy import deepcopy
24 25 import datetime
25 26
26 27 from loader import Config
27 28 from IPython.utils.traitlets import HasTraits, Instance
28 29 from IPython.utils.text import indent
29 30
30 31
31 32 #-----------------------------------------------------------------------------
32 33 # Helper classes for Configurables
33 34 #-----------------------------------------------------------------------------
34 35
35 36
36 37 class ConfigurableError(Exception):
37 38 pass
38 39
39 40
40 41 class MultipleInstanceError(ConfigurableError):
41 42 pass
42 43
43 44 #-----------------------------------------------------------------------------
44 45 # Configurable implementation
45 46 #-----------------------------------------------------------------------------
46 47
47 48 class Configurable(HasTraits):
48 49
49 50 config = Instance(Config,(),{})
50 51 created = None
51 52
52 53 def __init__(self, **kwargs):
53 54 """Create a conigurable given a config config.
54 55
55 56 Parameters
56 57 ----------
57 58 config : Config
58 59 If this is empty, default values are used. If config is a
59 60 :class:`Config` instance, it will be used to configure the
60 61 instance.
61 62
62 63 Notes
63 64 -----
64 65 Subclasses of Configurable must call the :meth:`__init__` method of
65 66 :class:`Configurable` *before* doing anything else and using
66 67 :func:`super`::
67 68
68 69 class MyConfigurable(Configurable):
69 70 def __init__(self, config=None):
70 71 super(MyConfigurable, self).__init__(config)
71 72 # Then any other code you need to finish initialization.
72 73
73 74 This ensures that instances will be configured properly.
74 75 """
75 76 config = kwargs.pop('config', None)
76 77 if config is not None:
77 78 # We used to deepcopy, but for now we are trying to just save
78 79 # by reference. This *could* have side effects as all components
79 80 # will share config. In fact, I did find such a side effect in
80 81 # _config_changed below. If a config attribute value was a mutable type
81 82 # all instances of a component were getting the same copy, effectively
82 83 # making that a class attribute.
83 84 # self.config = deepcopy(config)
84 85 self.config = config
85 86 # This should go second so individual keyword arguments override
86 87 # the values in config.
87 88 super(Configurable, self).__init__(**kwargs)
88 89 self.created = datetime.datetime.now()
89 90
90 91 #-------------------------------------------------------------------------
91 92 # Static trait notifications
92 93 #-------------------------------------------------------------------------
93 94
94 95 def _config_changed(self, name, old, new):
95 96 """Update all the class traits having ``config=True`` as metadata.
96 97
97 98 For any class trait with a ``config`` metadata attribute that is
98 99 ``True``, we update the trait with the value of the corresponding
99 100 config entry.
100 101 """
101 102 # Get all traits with a config metadata entry that is True
102 103 traits = self.traits(config=True)
103 104
104 105 # We auto-load config section for this class as well as any parent
105 106 # classes that are Configurable subclasses. This starts with Configurable
106 107 # and works down the mro loading the config for each section.
107 108 section_names = [cls.__name__ for cls in \
108 109 reversed(self.__class__.__mro__) if
109 110 issubclass(cls, Configurable) and issubclass(self.__class__, cls)]
110 111
111 112 for sname in section_names:
112 113 # Don't do a blind getattr as that would cause the config to
113 114 # dynamically create the section with name self.__class__.__name__.
114 115 if new._has_section(sname):
115 116 my_config = new[sname]
116 117 for k, v in traits.iteritems():
117 118 # Don't allow traitlets with config=True to start with
118 119 # uppercase. Otherwise, they are confused with Config
119 120 # subsections. But, developers shouldn't have uppercase
120 121 # attributes anyway! (PEP 8)
121 122 if k[0].upper()==k[0] and not k.startswith('_'):
122 123 raise ConfigurableError('Configurable traitlets with '
123 124 'config=True must start with a lowercase so they are '
124 125 'not confused with Config subsections: %s.%s' % \
125 126 (self.__class__.__name__, k))
126 127 try:
127 128 # Here we grab the value from the config
128 129 # If k has the naming convention of a config
129 130 # section, it will be auto created.
130 131 config_value = my_config[k]
131 132 except KeyError:
132 133 pass
133 134 else:
134 135 # print "Setting %s.%s from %s.%s=%r" % \
135 136 # (self.__class__.__name__,k,sname,k,config_value)
136 137 # We have to do a deepcopy here if we don't deepcopy the entire
137 138 # config object. If we don't, a mutable config_value will be
138 139 # shared by all instances, effectively making it a class attribute.
139 140 setattr(self, k, deepcopy(config_value))
140 141
141 142 @classmethod
142 143 def class_get_help(cls):
143 144 """Get the help string for this class in ReST format."""
144 145 cls_traits = cls.class_traits(config=True)
145 146 final_help = []
146 147 final_help.append(u'%s options' % cls.__name__)
147 148 final_help.append(len(final_help[0])*u'-')
148 149 for k,v in cls.class_traits(config=True).iteritems():
149 150 help = cls.class_get_trait_help(v)
150 151 final_help.append(help)
151 152 return '\n'.join(final_help)
152 153
153 154 @classmethod
154 155 def class_get_trait_help(cls, trait):
155 156 """Get the help string for a single """
156 157 lines = []
157 158 header = "%s.%s : %s" % (cls.__name__, trait.name, trait.__class__.__name__)
158 159 try:
159 160 dvr = repr(trait.get_default_value())
160 161 except Exception:
161 162 dvr = None # ignore defaults we can't construct
162 163 if dvr is not None:
163 164 header += ' [default: %s]'%dvr
164 165 lines.append(header)
165 166
166 167 help = trait.get_metadata('help')
167 168 if help is not None:
168 169 lines.append(indent(help.strip(), flatten=True))
169 170 if 'Enum' in trait.__class__.__name__:
170 171 # include Enum choices
171 172 lines.append(indent('Choices: %r'%(trait.values,), flatten=True))
172 173 return '\n'.join(lines)
173 174
174 175 @classmethod
175 176 def class_print_help(cls):
176 177 print cls.class_get_help()
177 178
178 179
179 180 class SingletonConfigurable(Configurable):
180 181 """A configurable that only allows one instance.
181 182
182 183 This class is for classes that should have only one instance of
183 184 themselves or of *any* subclass. To create and retrieve such an instance, use the
184 185 :meth:`SingletonConfigurable.instance` method.
185 186 """
186 187
187 188 _instance = None
188 189
189 190 @classmethod
190 191 def _walk_mro(cls):
191 192 """Walk the cls.mro() for parent classes that are also singletons
192 193
193 194 For use in instance()
194 195 """
195 196
196 197 for subclass in cls.mro():
197 198 if issubclass(cls, subclass) and \
198 199 issubclass(subclass, SingletonConfigurable) and \
199 200 subclass != SingletonConfigurable:
200 201 yield subclass
201 202
202 203 @classmethod
203 204 def clear_instance(cls):
204 205 """unset _instance for this class and singleton parents.
205 206 """
206 207 if not cls.initialized():
207 208 return
208 209 for subclass in cls._walk_mro():
209 210 if isinstance(subclass._instance, cls):
210 211 # only clear instances that are instances
211 212 # of the calling class
212 213 subclass._instance = None
213 214
214 215 @classmethod
215 216 def instance(cls, *args, **kwargs):
216 217 """Returns a global instance of this class.
217 218
218 219 This method creates a new instance if none has previously been created
219 220 and returns the previously created instance if one already exists.
220 221
221 222 The arguments and keyword arguments passed to this method are passed
222 223 on to the :meth:`__init__` method of the class upon instantiation.
223 224
224 225 Examples
225 226 --------
226 227
227 228 Create a singleton class using instance, and retrieve it::
228 229
229 230 >>> from IPython.config.configurable import SingletonConfigurable
230 231 >>> class Foo(SingletonConfigurable): pass
231 232 >>> foo = Foo.instance()
232 233 >>> foo == Foo.instance()
233 234 True
234 235
235 236 Create a subclass that is retrieved using the base class instance::
236 237
237 238 >>> class Bar(SingletonConfigurable): pass
238 239 >>> class Bam(Bar): pass
239 240 >>> bam = Bam.instance()
240 241 >>> bam == Bar.instance()
241 242 True
242 243 """
243 244 # Create and save the instance
244 245 if cls._instance is None:
245 246 inst = cls(*args, **kwargs)
246 247 # Now make sure that the instance will also be returned by
247 248 # parent classes' _instance attribute.
248 249 for subclass in cls._walk_mro():
249 250 subclass._instance = inst
250 251
251 252 if isinstance(cls._instance, cls):
252 253 return cls._instance
253 254 else:
254 255 raise MultipleInstanceError(
255 256 'Multiple incompatible subclass instances of '
256 257 '%s are being created.' % cls.__name__
257 258 )
258 259
259 260 @classmethod
260 261 def initialized(cls):
261 262 """Has an instance been created?"""
262 263 return hasattr(cls, "_instance") and cls._instance is not None
263 264
264 265
265 266 class LoggingConfigurable(Configurable):
266 267 """A parent class for Configurables that log.
267 268
268 269 Subclasses have a log trait, and the default behavior
269 270 is to get the logger from the currently running Application
270 271 via Application.instance().log.
271 272 """
272 273
273 274 log = Instance('logging.Logger')
274 275 def _log_default(self):
275 276 from IPython.config.application import Application
276 277 return Application.instance().log
277 278
278 279 No newline at end of file
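
A short sketch of the config-to-trait propagation implemented in _config_changed above: only traits tagged config=True are filled from the Config section that shares the class name. `Foo` here is a hypothetical class::

    from IPython.config.configurable import Configurable
    from IPython.config.loader import Config
    from IPython.utils.traitlets import Int

    class Foo(Configurable):
        a = Int(0, config=True)   # picked up from the 'Foo' config section
        b = Int(0)                # no config=True metadata: never configured

    c = Config()
    c.Foo.a = 10                  # attribute access auto-creates the section

    foo = Foo(config=c)
    assert foo.a == 10            # loaded (and deepcopied) from the section
    assert foo.b == 0             # untouched
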
@@ -1,513 +1,514 b''
1 1 """A simple configuration system.
2 2
3 3 Authors
4 4 -------
5 5 * Brian Granger
6 6 * Fernando Perez
7 * Min RK
7 8 """
8 9
9 10 #-----------------------------------------------------------------------------
10 # Copyright (C) 2008-2009 The IPython Development Team
11 # Copyright (C) 2008-2011 The IPython Development Team
11 12 #
12 13 # Distributed under the terms of the BSD License. The full license is in
13 14 # the file COPYING, distributed as part of this software.
14 15 #-----------------------------------------------------------------------------
15 16
16 17 #-----------------------------------------------------------------------------
17 18 # Imports
18 19 #-----------------------------------------------------------------------------
19 20
20 21 import __builtin__
21 22 import re
22 23 import sys
23 24
24 25 from IPython.external import argparse
25 26 from IPython.utils.path import filefind
26 27
27 28 #-----------------------------------------------------------------------------
28 29 # Exceptions
29 30 #-----------------------------------------------------------------------------
30 31
31 32
32 33 class ConfigError(Exception):
33 34 pass
34 35
35 36
36 37 class ConfigLoaderError(ConfigError):
37 38 pass
38 39
39 40 class ArgumentError(ConfigLoaderError):
40 41 pass
41 42
42 43 #-----------------------------------------------------------------------------
43 44 # Argparse fix
44 45 #-----------------------------------------------------------------------------
45 46
46 47 # Unfortunately argparse by default prints help messages to stderr instead of
47 48 # stdout. This makes it annoying to capture long help screens at the command
48 49 # line, since one must know how to pipe stderr, which many users don't know how
49 50 # to do. So we override the print_help method with one that defaults to
50 51 # stdout and use our class instead.
51 52
52 53 class ArgumentParser(argparse.ArgumentParser):
53 54 """Simple argparse subclass that prints help to stdout by default."""
54 55
55 56 def print_help(self, file=None):
56 57 if file is None:
57 58 file = sys.stdout
58 59 return super(ArgumentParser, self).print_help(file)
59 60
60 61 print_help.__doc__ = argparse.ArgumentParser.print_help.__doc__
61 62
62 63 #-----------------------------------------------------------------------------
63 64 # Config class for holding config information
64 65 #-----------------------------------------------------------------------------
65 66
66 67
67 68 class Config(dict):
68 69 """An attribute based dict that can do smart merges."""
69 70
70 71 def __init__(self, *args, **kwds):
71 72 dict.__init__(self, *args, **kwds)
72 73 # This sets self.__dict__ = self, but it has to be done this way
73 74 # because we are also overriding __setattr__.
74 75 dict.__setattr__(self, '__dict__', self)
75 76
76 77 def _merge(self, other):
77 78 to_update = {}
78 79 for k, v in other.iteritems():
79 80 if not self.has_key(k):
80 81 to_update[k] = v
81 82 else: # I have this key
82 83 if isinstance(v, Config):
83 84 # Recursively merge common sub Configs
84 85 self[k]._merge(v)
85 86 else:
86 87 # Plain updates for non-Configs
87 88 to_update[k] = v
88 89
89 90 self.update(to_update)
90 91
91 92 def _is_section_key(self, key):
92 93 if key[0].upper()==key[0] and not key.startswith('_'):
93 94 return True
94 95 else:
95 96 return False
96 97
97 98 def __contains__(self, key):
98 99 if self._is_section_key(key):
99 100 return True
100 101 else:
101 102 return super(Config, self).__contains__(key)
102 103 # .has_key is deprecated for dictionaries.
103 104 has_key = __contains__
104 105
105 106 def _has_section(self, key):
106 107 if self._is_section_key(key):
107 108 if super(Config, self).__contains__(key):
108 109 return True
109 110 return False
110 111
111 112 def copy(self):
112 113 return type(self)(dict.copy(self))
113 114
114 115 def __copy__(self):
115 116 return self.copy()
116 117
117 118 def __deepcopy__(self, memo):
118 119 import copy
119 120 return type(self)(copy.deepcopy(self.items()))
120 121
121 122 def __getitem__(self, key):
122 123 # We cannot use directly self._is_section_key, because it triggers
123 124 # infinite recursion on top of PyPy. Instead, we manually fish the
124 125 # bound method.
125 126 is_section_key = self.__class__._is_section_key.__get__(self)
126 127
127 128 # Because we use this for an exec namespace, we need to delegate
128 129 # the lookup of names in __builtin__ to itself. This means
129 130 # that you can't have section or attribute names that are
130 131 # builtins.
131 132 try:
132 133 return getattr(__builtin__, key)
133 134 except AttributeError:
134 135 pass
135 136 if is_section_key(key):
136 137 try:
137 138 return dict.__getitem__(self, key)
138 139 except KeyError:
139 140 c = Config()
140 141 dict.__setitem__(self, key, c)
141 142 return c
142 143 else:
143 144 return dict.__getitem__(self, key)
144 145
145 146 def __setitem__(self, key, value):
146 147 # Don't allow names in __builtin__ to be modified.
147 148 if hasattr(__builtin__, key):
148 149 raise ConfigError('Config variable names cannot have the same name '
149 150 'as a Python builtin: %s' % key)
150 151 if self._is_section_key(key):
151 152 if not isinstance(value, Config):
152 153 raise ValueError('values whose keys begin with an uppercase '
153 154 'char must be Config instances: %r, %r' % (key, value))
154 155 else:
155 156 dict.__setitem__(self, key, value)
156 157
157 158 def __getattr__(self, key):
158 159 try:
159 160 return self.__getitem__(key)
160 161 except KeyError, e:
161 162 raise AttributeError(e)
162 163
163 164 def __setattr__(self, key, value):
164 165 try:
165 166 self.__setitem__(key, value)
166 167 except KeyError, e:
167 168 raise AttributeError(e)
168 169
169 170 def __delattr__(self, key):
170 171 try:
171 172 dict.__delitem__(self, key)
172 173 except KeyError, e:
173 174 raise AttributeError(e)
174 175
175 176
176 177 #-----------------------------------------------------------------------------
177 178 # Config loading classes
178 179 #-----------------------------------------------------------------------------
179 180
180 181
181 182 class ConfigLoader(object):
182 183 """A object for loading configurations from just about anywhere.
183 184
184 185 The resulting configuration is packaged as a :class:`Struct`.
185 186
186 187 Notes
187 188 -----
188 189 A :class:`ConfigLoader` does one thing: load a config from a source
189 190 (file, command line arguments) and return the data as a :class:`Config`.
190 191 There are lots of things that :class:`ConfigLoader` does not do. It does
191 192 not implement complex logic for finding config files. It does not handle
192 193 default values or merge multiple configs. These things need to be
193 194 handled elsewhere.
194 195 """
195 196
196 197 def __init__(self):
197 198 """A base class for config loaders.
198 199
199 200 Examples
200 201 --------
201 202
202 203 >>> cl = ConfigLoader()
203 204 >>> config = cl.load_config()
204 205 >>> config
205 206 {}
206 207 """
207 208 self.clear()
208 209
209 210 def clear(self):
210 211 self.config = Config()
211 212
212 213 def load_config(self):
213 214 """Load a config from somewhere, return a :class:`Config` instance.
214 215
215 216 Usually, this will cause self.config to be set and then returned.
216 217 However, in most cases, :meth:`ConfigLoader.clear` should be called
217 218 to erase any previous state.
218 219 """
219 220 self.clear()
220 221 return self.config
221 222
222 223
223 224 class FileConfigLoader(ConfigLoader):
224 225 """A base class for file based configurations.
225 226
226 227 As we add more file based config loaders, the common logic should go
227 228 here.
228 229 """
229 230 pass
230 231
231 232
232 233 class PyFileConfigLoader(FileConfigLoader):
233 234 """A config loader for pure python files.
234 235
235 236 This calls execfile on a plain python file and looks for attributes
236 237 that are all caps. These attributes are added to the config object.
237 238 """
238 239
239 240 def __init__(self, filename, path=None):
240 241 """Build a config loader for a filename and path.
241 242
242 243 Parameters
243 244 ----------
244 245 filename : str
245 246 The file name of the config file.
246 247 path : str, list, tuple
247 248 The path to search for the config file on, or a sequence of
248 249 paths to try in order.
249 250 """
250 251 super(PyFileConfigLoader, self).__init__()
251 252 self.filename = filename
252 253 self.path = path
253 254 self.full_filename = ''
254 255 self.data = None
255 256
256 257 def load_config(self):
257 258 """Load the config from a file and return it as a Struct."""
258 259 self.clear()
259 260 self._find_file()
260 261 self._read_file_as_dict()
261 262 self._convert_to_config()
262 263 return self.config
263 264
264 265 def _find_file(self):
265 266 """Try to find the file by searching the paths."""
266 267 self.full_filename = filefind(self.filename, self.path)
267 268
268 269 def _read_file_as_dict(self):
269 270 """Load the config file into self.config, with recursive loading."""
270 271 # This closure is made available in the namespace that is used
271 272 # to exec the config file. This allows users to call
272 273 # load_subconfig('myconfig.py') to load config files recursively.
273 274 # It needs to be a closure because it has references to self.path
274 275 # and self.config. The sub-config is loaded with the same path
275 276 # as the parent, but it uses an empty config which is then merged
276 277 # with the parents.
277 278 def load_subconfig(fname):
278 279 loader = PyFileConfigLoader(fname, self.path)
279 280 try:
280 281 sub_config = loader.load_config()
281 282 except IOError:
282 283 # Pass silently if the sub config is not there. This happens
283 284 # when a user is using a profile, but not the default config.
284 285 pass
285 286 else:
286 287 self.config._merge(sub_config)
287 288
288 289 # Again, this needs to be a closure and should be used in config
289 290 # files to get the config being loaded.
290 291 def get_config():
291 292 return self.config
292 293
293 294 namespace = dict(load_subconfig=load_subconfig, get_config=get_config)
294 295 fs_encoding = sys.getfilesystemencoding() or 'ascii'
295 296 conf_filename = self.full_filename.encode(fs_encoding)
296 297 execfile(conf_filename, namespace)
297 298
298 299 def _convert_to_config(self):
299 300 if self.data is None:
300 301 raise ConfigLoaderError('self.data does not exist')
301 302
302 303
303 304 class CommandLineConfigLoader(ConfigLoader):
304 305 """A config loader for command line arguments.
305 306
306 307 As we add more command line based loaders, the common logic should go
307 308 here.
308 309 """
309 310
310 311 kv_pattern = re.compile(r'[A-Za-z]\w*(\.\w+)*\=.*')
311 312 flag_pattern = re.compile(r'\-\-\w+(\-\w)*')
312 313
313 314 class KeyValueConfigLoader(CommandLineConfigLoader):
314 315 """A config loader that loads key value pairs from the command line.
315 316
316 317 This allows command line options to be given in the following form::
317 318
318 319 ipython Global.profile="foo" InteractiveShell.autocall=False
319 320 """
320 321
321 322 def __init__(self, argv=None, aliases=None, flags=None):
322 323 """Create a key value pair config loader.
323 324
324 325 Parameters
325 326 ----------
326 327 argv : list
327 328 A list that has the form of sys.argv[1:] which has unicode
328 329 elements of the form u"key=value". If this is None (default),
329 330 then sys.argv[1:] will be used.
330 331 aliases : dict
331 332 A dict of aliases for configurable traits.
332 333 Keys are the short aliases, Values are the resolved trait.
333 334 Of the form: `{'alias' : 'Configurable.trait'}`
334 335 flags : dict
335 336 A dict of flags, keyed by str name. Values can be Config objects,
336 337 dicts, or "key=value" strings. If Config or dict, when the flag
337 338 is triggered, the config is loaded as `self.config.update(cfg)`.
338 339
339 340 Returns
340 341 -------
341 342 config : Config
342 343 The resulting Config object.
343 344
344 345 Examples
345 346 --------
346 347
347 348 >>> from IPython.config.loader import KeyValueConfigLoader
348 349 >>> cl = KeyValueConfigLoader()
349 350 >>> cl.load_config(["foo='bar'","A.name='brian'","B.number=0"])
350 351 {'A': {'name': 'brian'}, 'B': {'number': 0}, 'foo': 'bar'}
351 352 """
352 353 if argv is None:
353 354 argv = sys.argv[1:]
354 355 self.argv = argv
355 356 self.aliases = aliases or {}
356 357 self.flags = flags or {}
357 358
358 359 def load_config(self, argv=None, aliases=None, flags=None):
359 360 """Parse the configuration and generate the Config object.
360 361
361 362 Parameters
362 363 ----------
363 364 argv : list, optional
364 365 A list that has the form of sys.argv[1:] which has unicode
365 366 elements of the form u"key=value". If this is None (default),
366 367 then self.argv will be used.
367 368 aliases : dict
368 369 A dict of aliases for configurable traits.
369 370 Keys are the short aliases, Values are the resolved trait.
370 371 Of the form: `{'alias' : 'Configurable.trait'}`
371 372 flags : dict
372 373 A dict of flags, keyed by str name. Values can be Config objects
373 374 or dicts. When the flag is triggered, the config is loaded as
374 375 `self.config.update(cfg)`.
375 376 """
376 377 from IPython.config.configurable import Configurable
377 378
378 379 self.clear()
379 380 if argv is None:
380 381 argv = self.argv
381 382 if aliases is None:
382 383 aliases = self.aliases
383 384 if flags is None:
384 385 flags = self.flags
385 386
386 387 self.extra_args = []
387 388
388 389 for item in argv:
389 390 if kv_pattern.match(item):
390 391 lhs,rhs = item.split('=',1)
391 392 # Substitute longnames for aliases.
392 393 if lhs in aliases:
393 394 lhs = aliases[lhs]
394 395 exec_str = 'self.config.' + lhs + '=' + rhs
395 396 try:
396 397 # Try to see if regular Python syntax will work. This
397 398 # won't handle strings as the quote marks are removed
398 399 # by the system shell.
399 400 exec exec_str in locals(), globals()
400 401 except (NameError, SyntaxError):
401 402 # This case happens if the rhs is a string but without
402 403 # the quote marks. We add the quote marks and see if
403 404 # it succeeds. If it still fails, we let it raise.
404 405 exec_str = 'self.config.' + lhs + '="' + rhs + '"'
405 406 exec exec_str in locals(), globals()
406 407 elif flag_pattern.match(item):
407 408 # trim leading '--'
408 409 m = item[2:]
409 410 cfg,_ = flags.get(m, (None,None))
410 411 if cfg is None:
411 412 raise ArgumentError("Unrecognized flag: %r"%item)
412 413 elif isinstance(cfg, (dict, Config)):
413 414 # don't clobber whole config sections, update
414 415 # each section from config:
415 416 for sec,c in cfg.iteritems():
416 417 self.config[sec].update(c)
417 418 else:
418 419 raise ValueError("Invalid flag: %r"%item)
419 420 elif item.startswith('-'):
420 421 # this shouldn't ever be valid
421 422 raise ArgumentError("Invalid argument: %r"%item)
422 423 else:
423 424 # keep all args that aren't valid in a list,
424 425 # in case our parent knows what to do with them.
425 426 self.extra_args.append(item)
426 427 return self.config
427 428
428 429 class ArgParseConfigLoader(CommandLineConfigLoader):
429 430 """A loader that uses the argparse module to load from the command line."""
430 431
431 432 def __init__(self, argv=None, *parser_args, **parser_kw):
432 433 """Create a config loader for use with argparse.
433 434
434 435 Parameters
435 436 ----------
436 437
437 438 argv : optional, list
438 439 If given, used to read command-line arguments from, otherwise
439 440 sys.argv[1:] is used.
440 441
441 442 parser_args : tuple
442 443 A tuple of positional arguments that will be passed to the
443 444 constructor of :class:`argparse.ArgumentParser`.
444 445
445 446 parser_kw : dict
446 447 A dict of keyword arguments that will be passed to the
447 448 constructor of :class:`argparse.ArgumentParser`.
448 449
449 450 Returns
450 451 -------
451 452 config : Config
452 453 The resulting Config object.
453 454 """
454 455 super(CommandLineConfigLoader, self).__init__()
455 456 if argv is None:
456 457 argv = sys.argv[1:]
457 458 self.argv = argv
458 459 self.parser_args = parser_args
459 460 self.version = parser_kw.pop("version", None)
460 461 kwargs = dict(argument_default=argparse.SUPPRESS)
461 462 kwargs.update(parser_kw)
462 463 self.parser_kw = kwargs
463 464
464 465 def load_config(self, argv=None):
465 466 """Parse command line arguments and return as a Config object.
466 467
467 468 Parameters
468 469 ----------
469 470
470 471 argv : optional, list
471 472 If given, a list with the structure of sys.argv[1:] to parse
472 473 arguments from. If not given, the instance's self.argv attribute
473 474 (given at construction time) is used."""
474 475 self.clear()
475 476 if argv is None:
476 477 argv = self.argv
477 478 self._create_parser()
478 479 self._parse_args(argv)
479 480 self._convert_to_config()
480 481 return self.config
481 482
482 483 def get_extra_args(self):
483 484 if hasattr(self, 'extra_args'):
484 485 return self.extra_args
485 486 else:
486 487 return []
487 488
488 489 def _create_parser(self):
489 490 self.parser = ArgumentParser(*self.parser_args, **self.parser_kw)
490 491 self._add_arguments()
491 492
492 493 def _add_arguments(self):
493 494 raise NotImplementedError("subclasses must implement _add_arguments")
494 495
495 496 def _parse_args(self, args):
496 497 """self.parser->self.parsed_data"""
497 498 # decode sys.argv to support unicode command-line options
498 499 uargs = []
499 500 for a in args:
500 501 if isinstance(a, str):
501 502 # don't decode if we already got unicode
502 503 a = a.decode(sys.stdin.encoding or
503 504 sys.getdefaultencoding())
504 505 uargs.append(a)
505 506 self.parsed_data, self.extra_args = self.parser.parse_known_args(uargs)
506 507
507 508 def _convert_to_config(self):
508 509 """self.parsed_data->self.config"""
509 510 for k, v in vars(self.parsed_data).iteritems():
510 511 exec_str = 'self.config.' + k + '= v'
511 512 exec exec_str in locals(), globals()
512 513
513 514
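
A sketch of KeyValueConfigLoader with aliases and flags, mirroring the parsing loop above; the argv values are illustrative::

    from IPython.config.loader import KeyValueConfigLoader

    aliases = dict(log_level='Application.log_level')
    flags = dict(debug=({'Application': {'log_level': 10}},
                        "set log level to DEBUG"))

    loader = KeyValueConfigLoader(aliases=aliases, flags=flags)
    config = loader.load_config(['log_level=20', '--debug', 'script.py'])

    print config.Application.log_level   # 10: the flag updated the section last
    print loader.extra_args              # ['script.py'] passed through untouched
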
@@ -1,435 +1,436 b''
1 1 # encoding: utf-8
2 2 """
3 3 An application for IPython.
4 4
5 5 All top-level applications should use the classes in this module for
6 6 handling configuration and creating components.
7 7
8 8 The job of an :class:`Application` is to create the master configuration
9 9 object and then create the configurable objects, passing the config to them.
10 10
11 11 Authors:
12 12
13 13 * Brian Granger
14 14 * Fernando Perez
15 * Min RK
15 16
16 17 Notes
17 18 -----
18 19 """
19 20
20 21 #-----------------------------------------------------------------------------
21 22 # Copyright (C) 2008-2009 The IPython Development Team
22 23 #
23 24 # Distributed under the terms of the BSD License. The full license is in
24 25 # the file COPYING, distributed as part of this software.
25 26 #-----------------------------------------------------------------------------
26 27
27 28 #-----------------------------------------------------------------------------
28 29 # Imports
29 30 #-----------------------------------------------------------------------------
30 31
31 32 import logging
32 33 import os
33 34 import shutil
34 35 import sys
35 36
36 37 from IPython.config.application import Application
37 38 from IPython.config.configurable import Configurable
38 39 from IPython.config.loader import Config
39 40 from IPython.core import release, crashhandler
40 41 from IPython.utils.path import get_ipython_dir, get_ipython_package_dir, expand_path
41 42 from IPython.utils.traitlets import List, Unicode, Type, Bool, Dict
42 43
43 44 #-----------------------------------------------------------------------------
44 45 # Classes and functions
45 46 #-----------------------------------------------------------------------------
46 47
47 48
48 49 #-----------------------------------------------------------------------------
49 50 # Module errors
50 51 #-----------------------------------------------------------------------------
51 52
52 53 class ProfileDirError(Exception):
53 54 pass
54 55
55 56
56 57 #-----------------------------------------------------------------------------
57 58 # Class for managing profile directories
58 59 #-----------------------------------------------------------------------------
59 60
60 61 class ProfileDir(Configurable):
61 62 """An object to manage the profile directory and its resources.
62 63
63 64 The profile directory is used by all IPython applications, to manage
64 65 configuration, logging and security.
65 66
66 67 This object knows how to find, create and manage these directories. This
67 68 should be used by any code that wants to handle profiles.
68 69 """
69 70
70 71 security_dir_name = Unicode('security')
71 72 log_dir_name = Unicode('log')
72 73 pid_dir_name = Unicode('pid')
73 74 security_dir = Unicode(u'')
74 75 log_dir = Unicode(u'')
75 76 pid_dir = Unicode(u'')
76 77
77 78 location = Unicode(u'', config=True,
78 79 help="""Set the profile location directly. This overrides the logic used by the
79 80 `profile` option.""",
80 81 )
81 82
82 83 _location_isset = Bool(False) # flag for detecting multiply set location
83 84
84 85 def _location_changed(self, name, old, new):
85 86 if self._location_isset:
86 87 raise RuntimeError("Cannot set profile location more than once.")
87 88 self._location_isset = True
88 89 if not os.path.isdir(new):
89 90 os.makedirs(new)
90 91
91 92 # ensure config files exist:
92 93 self.security_dir = os.path.join(new, self.security_dir_name)
93 94 self.log_dir = os.path.join(new, self.log_dir_name)
94 95 self.pid_dir = os.path.join(new, self.pid_dir_name)
95 96 self.check_dirs()
96 97
97 98 def _log_dir_changed(self, name, old, new):
98 99 self.check_log_dir()
99 100
100 101 def check_log_dir(self):
101 102 if not os.path.isdir(self.log_dir):
102 103 os.mkdir(self.log_dir)
103 104
104 105 def _security_dir_changed(self, name, old, new):
105 106 self.check_security_dir()
106 107
107 108 def check_security_dir(self):
108 109 if not os.path.isdir(self.security_dir):
109 110 os.mkdir(self.security_dir, 0700)
110 111 else:
111 112 os.chmod(self.security_dir, 0700)
112 113
113 114 def _pid_dir_changed(self, name, old, new):
114 115 self.check_pid_dir()
115 116
116 117 def check_pid_dir(self):
117 118 if not os.path.isdir(self.pid_dir):
118 119 os.mkdir(self.pid_dir, 0700)
119 120 else:
120 121 os.chmod(self.pid_dir, 0700)
121 122
122 123 def check_dirs(self):
123 124 self.check_security_dir()
124 125 self.check_log_dir()
125 126 self.check_pid_dir()
126 127
127 128 def copy_config_file(self, config_file, path=None, overwrite=False):
128 129 """Copy a default config file into the active profile directory.
129 130
130 131 Default configuration files are kept in :mod:`IPython.config.default`.
131 132 This function moves these from that location to the working profile
132 133 directory.
133 134 """
134 135 dst = os.path.join(self.location, config_file)
135 136 if os.path.isfile(dst) and not overwrite:
136 137 return
137 138 if path is None:
138 139 path = os.path.join(get_ipython_package_dir(), u'config', u'profile', u'default')
139 140 src = os.path.join(path, config_file)
140 141 shutil.copy(src, dst)
141 142
142 143 @classmethod
143 144 def create_profile_dir(cls, profile_dir, config=None):
144 145 """Create a new profile directory given a full path.
145 146
146 147 Parameters
147 148 ----------
148 149 profile_dir : str
149 150 The full path to the profile directory. If it does exist, it will
150 151 be used. If not, it will be created.
151 152 """
152 153 return cls(location=profile_dir, config=config)
153 154
154 155 @classmethod
155 156 def create_profile_dir_by_name(cls, path, name=u'default', config=None):
156 157 """Create a profile dir by profile name and path.
157 158
158 159 Parameters
159 160 ----------
160 161 path : unicode
161 162 The path (directory) to put the profile directory in.
162 163 name : unicode
163 164 The name of the profile. The name of the profile directory will
164 165 be "profile_<profile>".
165 166 """
166 167 if not os.path.isdir(path):
167 168 raise ProfileDirError('Directory not found: %s' % path)
168 169 profile_dir = os.path.join(path, u'profile_' + name)
169 170 return cls(location=profile_dir, config=config)
170 171
171 172 @classmethod
172 173 def find_profile_dir_by_name(cls, ipython_dir, name=u'default', config=None):
173 174 """Find an existing profile dir by profile name, return its ProfileDir.
174 175
175 176 This searches through a sequence of paths for a profile dir. If it
176 177 is not found, a :class:`ProfileDirError` exception will be raised.
177 178
178 179 The search path algorithm is:
179 180 1. ``os.getcwd()``
180 181 2. ``ipython_dir``
181 182 3. The directories found in the ":" separated
182 183 :env:`IPYTHON_PROFILE_PATH` environment variable.
183 184
184 185 Parameters
185 186 ----------
186 187 ipython_dir : unicode or str
187 188 The IPython directory to use.
188 189 name : unicode or str
189 190 The name of the profile. The name of the profile directory
190 191 will be "profile_<profile>".
191 192 """
192 193 dirname = u'profile_' + name
193 194 profile_dir_paths = os.environ.get('IPYTHON_PROFILE_PATH','')
194 195 if profile_dir_paths:
195 196 profile_dir_paths = profile_dir_paths.split(os.pathsep)
196 197 else:
197 198 profile_dir_paths = []
198 199 paths = [os.getcwd(), ipython_dir] + profile_dir_paths
199 200 for p in paths:
200 201 profile_dir = os.path.join(p, dirname)
201 202 if os.path.isdir(profile_dir):
202 203 return cls(location=profile_dir, config=config)
203 204 else:
204 205 raise ProfileDirError('Profile directory not found in paths: %s' % dirname)
205 206
206 207 @classmethod
207 208 def find_profile_dir(cls, profile_dir, config=None):
208 209 """Find/create a profile dir and return its ProfileDir.
209 210
210 211 This will create the profile directory if it doesn't exist.
211 212
212 213 Parameters
213 214 ----------
214 215 profile_dir : unicode or str
215 216 The path of the profile directory. This is expanded using
216 217 :func:`IPython.utils.genutils.expand_path`.
217 218 """
218 219 profile_dir = expand_path(profile_dir)
219 220 if not os.path.isdir(profile_dir):
220 221 raise ProfileDirError('Profile directory not found: %s' % profile_dir)
221 222 return cls(location=profile_dir, config=config)
222 223
223 224
224 225 #-----------------------------------------------------------------------------
225 226 # Base Application Class
226 227 #-----------------------------------------------------------------------------
227 228
228 229 # aliases and flags
229 230
230 231 base_aliases = dict(
231 232 profile='BaseIPythonApplication.profile',
232 233 ipython_dir='BaseIPythonApplication.ipython_dir',
233 234 )
234 235
235 236 base_flags = dict(
236 237 debug = ({'Application' : {'log_level' : logging.DEBUG}},
237 238 "set log level to logging.DEBUG (maximize logging output)"),
238 239 quiet = ({'Application' : {'log_level' : logging.CRITICAL}},
239 240 "set log level to logging.CRITICAL (minimize logging output)"),
240 241 init = ({'BaseIPythonApplication' : {
241 242 'copy_config_files' : True,
242 243 'auto_create' : True}
243 244 }, "Initialize profile with default config files")
244 245 )
245 246
246 247
247 248 class BaseIPythonApplication(Application):
248 249
249 250 name = Unicode(u'ipython')
250 251 description = Unicode(u'IPython: an enhanced interactive Python shell.')
251 252 version = Unicode(release.version)
252 253
253 254 aliases = Dict(base_aliases)
254 255 flags = Dict(base_flags)
255 256
256 257 # Track whether the config_file has changed,
257 258 # because some logic happens only if we aren't using the default.
258 259 config_file_specified = Bool(False)
259 260
260 261 config_file_name = Unicode(u'ipython_config.py')
261 262 def _config_file_name_changed(self, name, old, new):
262 263 if new != old:
263 264 self.config_file_specified = True
264 265
265 266 # The directory that contains IPython's builtin profiles.
266 267 builtin_profile_dir = Unicode(
267 268 os.path.join(get_ipython_package_dir(), u'config', u'profile', u'default')
268 269 )
269 270
270 271 config_file_paths = List(Unicode)
271 272 def _config_file_paths_default(self):
272 273 return [os.getcwdu()]
273 274
274 275 profile = Unicode(u'default', config=True,
275 276 help="""The IPython profile to use."""
276 277 )
277 278 def _profile_changed(self, name, old, new):
278 279 self.builtin_profile_dir = os.path.join(
279 280 get_ipython_package_dir(), u'config', u'profile', new
280 281 )
281 282
282 283
283 284 ipython_dir = Unicode(get_ipython_dir(), config=True,
284 285 help="""
285 286 The name of the IPython directory. This directory is used for logging
286 287 configuration (through profiles), history storage, etc. The default
287 288 is usually $HOME/.ipython. This option can also be specified through
288 289 the environment variable IPYTHON_DIR.
289 290 """
290 291 )
291 292
292 293 overwrite = Bool(False, config=True,
293 294 help="""Whether to overwrite existing config files when copying""")
294 295 auto_create = Bool(False, config=True,
295 296 help="""Whether to create profile dir if it doesn't exist""")
296 297
297 298 config_files = List(Unicode)
298 299 def _config_files_default(self):
299 300 return [u'ipython_config.py']
300 301
301 302 copy_config_files = Bool(False, config=True,
302 303 help="""Whether to copy the default config files into the profile dir.""")
303 304
304 305 # The class to use as the crash handler.
305 306 crash_handler_class = Type(crashhandler.CrashHandler)
306 307
307 308 def __init__(self, **kwargs):
308 309 super(BaseIPythonApplication, self).__init__(**kwargs)
309 310 # ensure even default IPYTHON_DIR exists
310 311 if not os.path.exists(self.ipython_dir):
311 312 self._ipython_dir_changed('ipython_dir', self.ipython_dir, self.ipython_dir)
312 313
313 314 #-------------------------------------------------------------------------
314 315 # Various stages of Application creation
315 316 #-------------------------------------------------------------------------
316 317
317 318 def init_crash_handler(self):
318 319 """Create a crash handler, typically setting sys.excepthook to it."""
319 320 self.crash_handler = self.crash_handler_class(self)
320 321 sys.excepthook = self.crash_handler
321 322
322 323 def _ipython_dir_changed(self, name, old, new):
323 324 if old in sys.path:
324 325 sys.path.remove(old)
325 326 sys.path.append(os.path.abspath(new))
326 327 if not os.path.isdir(new):
327 328 os.makedirs(new, mode=0777)
328 329 readme = os.path.join(new, 'README')
329 330 if not os.path.exists(readme):
330 331 path = os.path.join(get_ipython_package_dir(), u'config', u'profile')
331 332 shutil.copy(os.path.join(path, 'README'), readme)
332 333 self.log.debug("IPYTHON_DIR set to: %s" % new)
333 334
334 335 def load_config_file(self, suppress_errors=True):
335 336 """Load the config file.
336 337
337 338 By default, errors in loading config are handled, and a warning
338 339 printed on screen. For testing, the suppress_errors option is set
339 340 to False, so errors will make tests fail.
340 341 """
341 342 self.log.debug("Attempting to load config file: %s" %
342 343 self.config_file_name)
343 344 try:
344 345 Application.load_config_file(
345 346 self,
346 347 self.config_file_name,
347 348 path=self.config_file_paths
348 349 )
349 350 except IOError:
350 351 # Only warn if the default config file was NOT being used.
351 352 if self.config_file_specified:
352 353 self.log.warn("Config file not found, skipping: %s" %
353 354 self.config_file_name)
354 355 except:
355 356 # For testing purposes.
356 357 if not suppress_errors:
357 358 raise
358 359 self.log.warn("Error loading config file: %s" %
359 360 self.config_file_name, exc_info=True)
360 361
361 362 def init_profile_dir(self):
362 363 """initialize the profile dir"""
363 364 try:
364 365 # location explicitly specified:
365 366 location = self.config.ProfileDir.location
366 367 except AttributeError:
367 368 # location not specified, find by profile name
368 369 try:
369 370 p = ProfileDir.find_profile_dir_by_name(self.ipython_dir, self.profile, self.config)
370 371 except ProfileDirError:
371 372 # not found, maybe create it (always create default profile)
372 373 if self.auto_create or self.profile=='default':
373 374 try:
374 375 p = ProfileDir.create_profile_dir_by_name(self.ipython_dir, self.profile, self.config)
375 376 except ProfileDirError:
376 377 self.log.fatal("Could not create profile: %r"%self.profile)
377 378 self.exit(1)
378 379 else:
379 380 self.log.info("Created profile dir: %r"%p.location)
380 381 else:
381 382 self.log.fatal("Profile %r not found."%self.profile)
382 383 self.exit(1)
383 384 else:
384 385 self.log.info("Using existing profile dir: %r"%p.location)
385 386 else:
386 387 # location is fully specified
387 388 try:
388 389 p = ProfileDir.find_profile_dir(location, self.config)
389 390 except ProfileDirError:
390 391 # not found, maybe create it
391 392 if self.auto_create:
392 393 try:
393 394 p = ProfileDir.create_profile_dir(location, self.config)
394 395 except ProfileDirError:
395 396 self.log.fatal("Could not create profile directory: %r"%location)
396 397 self.exit(1)
397 398 else:
398 399 self.log.info("Creating new profile dir: %r"%location)
399 400 else:
400 401 self.log.fatal("Profile directory %r not found."%location)
401 402 self.exit(1)
402 403 else:
403 404 self.log.info("Using existing profile dir: %r"%location)
404 405
405 406 self.profile_dir = p
406 407 self.config_file_paths.append(p.location)
407 408
408 409 def init_config_files(self):
409 410 """[optionally] copy default config files into profile dir."""
410 411 # copy config files
411 412 if self.copy_config_files:
412 413 path = self.builtin_profile_dir
413 414 src = self.profile
414 415 if not os.path.exists(path):
415 416 # use default if new profile doesn't have a preset
416 417 path = None
417 418 src = 'default'
418 419
419 420 self.log.debug("Staging %s config files into %r [overwrite=%s]"%(
420 421 src, self.profile_dir.location, self.overwrite)
421 422 )
422 423
423 424 for cfg in self.config_files:
424 425 self.profile_dir.copy_config_file(cfg, path=path, overwrite=self.overwrite)
425 426
426 427 def initialize(self, argv=None):
427 428 self.init_crash_handler()
428 429 self.parse_command_line(argv)
429 430 cl_config = self.config
430 431 self.init_profile_dir()
431 432 self.init_config_files()
432 433 self.load_config_file()
433 434 # ensure command-line opts override config-file opts:
434 435 self.update_config(cl_config)
435 436
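A note on the precedence enforced by initialize() above: the command-line config is parsed and saved first, the config file is loaded second, and the saved command-line config is then re-applied. A minimal sketch of that ordering, using plain dicts rather than the real Config API (values are hypothetical):

# file values load first, then the saved command-line opts are
# re-applied, so command-line options always win over the config file.
file_opts = {'profile': u'default'}    # hypothetical value from ipython_config.py
cl_opts = {'profile': u'mycluster'}    # hypothetical value parsed from argv
opts = {}
opts.update(file_opts)                 # load_config_file()
opts.update(cl_opts)                   # update_config(cl_config)
assert opts['profile'] == u'mycluster'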
@@ -1,32 +1,37 b''
1 """The IPython ZMQ-based parallel computing interface."""
1 """The IPython ZMQ-based parallel computing interface.
2
3 Authors:
4
5 * MinRK
6 """
2 7 #-----------------------------------------------------------------------------
3 8 # Copyright (C) 2011 The IPython Development Team
4 9 #
5 10 # Distributed under the terms of the BSD License. The full license is in
6 11 # the file COPYING, distributed as part of this software.
7 12 #-----------------------------------------------------------------------------
8 13
9 14 #-----------------------------------------------------------------------------
10 15 # Imports
11 16 #-----------------------------------------------------------------------------
12 17
13 18 import os
14 19 import zmq
15 20
16 21
17 22 if os.name == 'nt':
18 23 if zmq.__version__ < '2.1.7':
19 24 raise ImportError("IPython.parallel requires pyzmq/0MQ >= 2.1.7 on Windows, "
20 25 "and you appear to have %s"%zmq.__version__)
21 26 elif zmq.__version__ < '2.1.4':
22 27 raise ImportError("IPython.parallel requires pyzmq/0MQ >= 2.1.4, you appear to have %s"%zmq.__version__)
23 28
24 29 from IPython.utils.pickleutil import Reference
25 30
26 31 from .client.asyncresult import *
27 32 from .client.client import Client
28 33 from .client.remotefunction import *
29 34 from .client.view import *
30 35 from .controller.dependency import *
31 36
32 37
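For context, a minimal usage sketch for the package whose __init__ is shown above, assuming an ipcluster is already running with the default profile (the Client import is per the package __init__ above):

from IPython.parallel import Client

rc = Client()                            # reads ipcontroller-client.json
view = rc[:]                             # DirectView on all engines
print view.apply_sync(lambda x: 2 * x, 21)   # one result per engine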
@@ -1,257 +1,263 b''
1 1 #!/usr/bin/env python
2 2 # encoding: utf-8
3 3 """
4 The IPython cluster directory
4 The Base Application class for IPython.parallel apps
5
6 Authors:
7
8 * Brian Granger
9 * Min RK
10
5 11 """
6 12
7 13 #-----------------------------------------------------------------------------
8 # Copyright (C) 2008-2009 The IPython Development Team
14 # Copyright (C) 2008-2011 The IPython Development Team
9 15 #
10 16 # Distributed under the terms of the BSD License. The full license is in
11 17 # the file COPYING, distributed as part of this software.
12 18 #-----------------------------------------------------------------------------
13 19
14 20 #-----------------------------------------------------------------------------
15 21 # Imports
16 22 #-----------------------------------------------------------------------------
17 23
18 24 from __future__ import with_statement
19 25
20 26 import os
21 27 import logging
22 28 import re
23 29 import sys
24 30
25 31 from subprocess import Popen, PIPE
26 32
27 33 from IPython.core import release
28 34 from IPython.core.crashhandler import CrashHandler
29 35 from IPython.core.newapplication import (
30 36 BaseIPythonApplication,
31 37 base_aliases as base_ip_aliases,
32 38 base_flags as base_ip_flags
33 39 )
34 40 from IPython.utils.path import expand_path
35 41
36 42 from IPython.utils.traitlets import Unicode, Bool, Instance, Dict, List
37 43
38 44 #-----------------------------------------------------------------------------
39 45 # Module errors
40 46 #-----------------------------------------------------------------------------
41 47
42 48 class PIDFileError(Exception):
43 49 pass
44 50
45 51
46 52 #-----------------------------------------------------------------------------
47 53 # Crash handler for this application
48 54 #-----------------------------------------------------------------------------
49 55
50 56
51 57 _message_template = """\
52 58 Oops, $self.app_name crashed. We do our best to make it stable, but...
53 59
54 60 A crash report was automatically generated with the following information:
55 61 - A verbatim copy of the crash traceback.
56 62 - Data on your current $self.app_name configuration.
57 63
58 64 It was left in the file named:
59 65 \t'$self.crash_report_fname'
60 66 If you can email this file to the developers, the information in it will help
61 67 them in understanding and correcting the problem.
62 68
63 69 You can mail it to: $self.contact_name at $self.contact_email
64 70 with the subject '$self.app_name Crash Report'.
65 71
66 72 If you want to do it now, the following command will work (under Unix):
67 73 mail -s '$self.app_name Crash Report' $self.contact_email < $self.crash_report_fname
68 74
69 75 To ensure accurate tracking of this issue, please file a report about it at:
70 76 $self.bug_tracker
71 77 """
72 78
73 79 class ParallelCrashHandler(CrashHandler):
74 80 """sys.excepthook for IPython itself, leaves a detailed report on disk."""
75 81
76 82 message_template = _message_template
77 83
78 84 def __init__(self, app):
79 85 contact_name = release.authors['Min'][0]
80 86 contact_email = release.authors['Min'][1]
81 87 bug_tracker = 'http://github.com/ipython/ipython/issues'
82 88 super(ParallelCrashHandler,self).__init__(
83 89 app, contact_name, contact_email, bug_tracker
84 90 )
85 91
86 92
87 93 #-----------------------------------------------------------------------------
88 94 # Main application
89 95 #-----------------------------------------------------------------------------
90 96 base_aliases = {}
91 97 base_aliases.update(base_ip_aliases)
92 98 base_aliases.update({
93 99 'profile_dir' : 'ProfileDir.location',
94 100 'log_level' : 'BaseParallelApplication.log_level',
95 101 'work_dir' : 'BaseParallelApplication.work_dir',
96 102 'log_to_file' : 'BaseParallelApplication.log_to_file',
97 103 'clean_logs' : 'BaseParallelApplication.clean_logs',
98 104 'log_url' : 'BaseParallelApplication.log_url',
99 105 })
100 106
101 107 base_flags = {
102 108 'log-to-file' : (
103 109 {'BaseParallelApplication' : {'log_to_file' : True}},
104 110 "send log output to a file"
105 111 )
106 112 }
107 113 base_flags.update(base_ip_flags)
108 114
109 115 class BaseParallelApplication(BaseIPythonApplication):
110 116 """The base Application for IPython.parallel apps
111 117
112 118 Principal extensions to BaseIPythonApplication:
113 119
114 120 * work_dir
115 121 * remote logging via pyzmq
116 122 * IOLoop instance
117 123 """
118 124
119 125 crash_handler_class = ParallelCrashHandler
120 126
121 127 def _log_level_default(self):
122 128 # temporarily override default_log_level to INFO
123 129 return logging.INFO
124 130
125 131 work_dir = Unicode(os.getcwdu(), config=True,
126 132 help='Set the working dir for the process.'
127 133 )
128 134 def _work_dir_changed(self, name, old, new):
129 135 self.work_dir = unicode(expand_path(new))
130 136
131 137 log_to_file = Bool(config=True,
132 138 help="whether to log to a file")
133 139
134 140 clean_logs = Bool(False, config=True,
135 141 help="whether to cleanup old logfiles before starting")
136 142
137 143 log_url = Unicode('', config=True,
138 144 help="The ZMQ URL of the iplogger to aggregate logging.")
139 145
140 146 def _config_files_default(self):
141 147 return ['ipcontroller_config.py', 'ipengine_config.py', 'ipcluster_config.py']
142 148
143 149 loop = Instance('zmq.eventloop.ioloop.IOLoop')
144 150 def _loop_default(self):
145 151 from zmq.eventloop.ioloop import IOLoop
146 152 return IOLoop.instance()
147 153
148 154 aliases = Dict(base_aliases)
149 155 flags = Dict(base_flags)
150 156
151 157 def initialize(self, argv=None):
152 158 """initialize the app"""
153 159 super(BaseParallelApplication, self).initialize(argv)
154 160 self.to_work_dir()
155 161 self.reinit_logging()
156 162
157 163 def to_work_dir(self):
158 164 wd = self.work_dir
159 165 if unicode(wd) != os.getcwdu():
160 166 os.chdir(wd)
161 167 self.log.info("Changing to working dir: %s" % wd)
162 168 # This is the working dir by now.
163 169 sys.path.insert(0, '')
164 170
165 171 def reinit_logging(self):
166 172 # Remove old log files
167 173 log_dir = self.profile_dir.log_dir
168 174 if self.clean_logs:
169 175 for f in os.listdir(log_dir):
170 176 if re.match(r'%s-\d+\.(log|err|out)'%self.name,f):
171 177 os.remove(os.path.join(log_dir, f))
172 178 if self.log_to_file:
173 179 # Start logging to the new log file
174 180 log_filename = self.name + u'-' + str(os.getpid()) + u'.log'
175 181 logfile = os.path.join(log_dir, log_filename)
176 182 open_log_file = open(logfile, 'w')
177 183 else:
178 184 open_log_file = None
179 185 if open_log_file is not None:
180 186 self.log.removeHandler(self._log_handler)
181 187 self._log_handler = logging.StreamHandler(open_log_file)
182 188 self._log_formatter = logging.Formatter("[%(name)s] %(message)s")
183 189 self._log_handler.setFormatter(self._log_formatter)
184 190 self.log.addHandler(self._log_handler)
185 191
186 192 def write_pid_file(self, overwrite=False):
187 193 """Create a .pid file in the pid_dir with my pid.
188 194
189 195 This must be called after init_profile_dir, which sets `self.profile_dir`.
190 196 This raises :exc:`PIDFileError` if the pid file exists already.
191 197 """
192 198 pid_file = os.path.join(self.profile_dir.pid_dir, self.name + u'.pid')
193 199 if os.path.isfile(pid_file):
194 200 pid = self.get_pid_from_file()
195 201 if not overwrite:
196 202 raise PIDFileError(
197 203 'The pid file [%s] already exists. \nThis could mean that this '
198 204 'server is already running with [pid=%s].' % (pid_file, pid)
199 205 )
200 206 with open(pid_file, 'w') as f:
201 207 self.log.info("Creating pid file: %s" % pid_file)
202 208 f.write(repr(os.getpid())+'\n')
203 209
204 210 def remove_pid_file(self):
205 211 """Remove the pid file.
206 212
207 213 This should be called at shutdown, e.g. by registering a callback
208 214 with the event loop or :mod:`atexit`. This needs to return
209 215 ``None``.
210 216 """
211 217 pid_file = os.path.join(self.profile_dir.pid_dir, self.name + u'.pid')
212 218 if os.path.isfile(pid_file):
213 219 try:
214 220 self.log.info("Removing pid file: %s" % pid_file)
215 221 os.remove(pid_file)
216 222 except:
217 223 self.log.warn("Error removing the pid file: %s" % pid_file)
218 224
219 225 def get_pid_from_file(self):
220 226 """Get the pid from the pid file.
221 227
222 228 If the pid file doesn't exist a :exc:`PIDFileError` is raised.
223 229 """
224 230 pid_file = os.path.join(self.profile_dir.pid_dir, self.name + u'.pid')
225 231 if os.path.isfile(pid_file):
226 232 with open(pid_file, 'r') as f:
227 233 pid = int(f.read().strip())
228 234 return pid
229 235 else:
230 236 raise PIDFileError('pid file not found: %s' % pid_file)
231 237
232 238 def check_pid(self, pid):
233 239 if os.name == 'nt':
234 240 try:
235 241 import ctypes
236 242 # returns 0 if no such process (of ours) exists
237 243 # positive int otherwise
238 244 p = ctypes.windll.kernel32.OpenProcess(1,0,pid)
239 245 except Exception:
240 246 self.log.warn(
241 247 "Could not determine whether pid %i is running via `OpenProcess`. "
242 248 " Making the likely assumption that it is."%pid
243 249 )
244 250 return True
245 251 return bool(p)
246 252 else:
247 253 try:
248 254 p = Popen(['ps','x'], stdout=PIPE, stderr=PIPE)
249 255 output,_ = p.communicate()
250 256 except OSError:
251 257 self.log.warn(
252 258 "Could not determine whether pid %i is running via `ps x`. "
253 259 " Making the likely assumption that it is."%pid
254 260 )
255 261 return True
256 262 pids = map(int, re.findall(r'^\W*\d+', output, re.MULTILINE))
257 263 return pid in pids
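A short sketch of the pid-file protocol implemented by the methods above; `app` stands for a hypothetical, already-configured BaseParallelApplication subclass instance:

app.write_pid_file(overwrite=True)   # writes <pid_dir>/<name>.pid
pid = app.get_pid_from_file()        # int; raises PIDFileError if the file is absent
if not app.check_pid(pid):           # pid file is stale: process no longer running
    app.remove_pid_file()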
@@ -1,521 +1,527 b''
1 1 #!/usr/bin/env python
2 2 # encoding: utf-8
3 3 """
4 4 The ipcluster application.
5
6 Authors:
7
8 * Brian Granger
9 * MinRK
10
5 11 """
6 12
7 13 #-----------------------------------------------------------------------------
8 # Copyright (C) 2008-2009 The IPython Development Team
14 # Copyright (C) 2008-2011 The IPython Development Team
9 15 #
10 16 # Distributed under the terms of the BSD License. The full license is in
11 17 # the file COPYING, distributed as part of this software.
12 18 #-----------------------------------------------------------------------------
13 19
14 20 #-----------------------------------------------------------------------------
15 21 # Imports
16 22 #-----------------------------------------------------------------------------
17 23
18 24 import errno
19 25 import logging
20 26 import os
21 27 import re
22 28 import signal
23 29
24 30 from subprocess import check_call, CalledProcessError, PIPE
25 31 import zmq
26 32 from zmq.eventloop import ioloop
27 33
28 34 from IPython.config.application import Application, boolean_flag
29 35 from IPython.config.loader import Config
30 36 from IPython.core.newapplication import BaseIPythonApplication, ProfileDir
31 37 from IPython.utils.importstring import import_item
32 38 from IPython.utils.traitlets import Int, Unicode, Bool, CFloat, Dict, List
33 39
34 40 from IPython.parallel.apps.baseapp import (
35 41 BaseParallelApplication,
36 42 PIDFileError,
37 43 base_flags, base_aliases
38 44 )
39 45
40 46
41 47 #-----------------------------------------------------------------------------
42 48 # Module level variables
43 49 #-----------------------------------------------------------------------------
44 50
45 51
46 52 default_config_file_name = u'ipcluster_config.py'
47 53
48 54
49 55 _description = """Start an IPython cluster for parallel computing.
50 56
51 57 An IPython cluster consists of 1 controller and 1 or more engines.
52 58 This command automates the startup of these processes using a wide
53 59 range of startup methods (SSH, local processes, PBS, mpiexec,
54 60 Windows HPC Server 2008). To start a cluster with 4 engines on your
55 61 local host simply do 'ipcluster start n=4'. For more complex usage
56 62 you will typically do 'ipcluster create profile=mycluster', then edit
57 63 configuration files, followed by 'ipcluster start profile=mycluster n=4'.
58 64 """
59 65
60 66
61 67 # Exit codes for ipcluster
62 68
63 69 # This will be the exit code if the ipcluster appears to be running because
64 70 # a .pid file exists
65 71 ALREADY_STARTED = 10
66 72
67 73
68 74 # This will be the exit code if ipcluster stop is run, but there is no .pid
69 75 # file to be found.
70 76 ALREADY_STOPPED = 11
71 77
72 78 # This will be the exit code if ipcluster engines is run, but there is no .pid
73 79 # file to be found.
74 80 NO_CLUSTER = 12
75 81
76 82
77 83 #-----------------------------------------------------------------------------
78 84 # Main application
79 85 #-----------------------------------------------------------------------------
80 86 start_help = """Start an IPython cluster for parallel computing
81 87
82 88 Start an ipython cluster by its profile name or cluster
83 89 directory. Cluster directories contain configuration, log and
84 90 security related files and are named using the convention
85 91 'cluster_<profile>' and should be created using the 'start'
86 92 subcommand of 'ipcluster'. If your cluster directory is in
87 93 the cwd or the ipython directory, you can simply refer to it
88 94 using its profile name, 'ipcluster start n=4 profile=<profile>',
89 95 otherwise use the 'profile_dir' option.
90 96 """
91 97 stop_help = """Stop a running IPython cluster
92 98
93 99 Stop a running ipython cluster by its profile name or cluster
94 100 directory. Cluster directories are named using the convention
95 101 'cluster_<profile>'. If your cluster directory is in
96 102 the cwd or the ipython directory, you can simply refer to it
97 103 using its profile name, 'ipcluster stop profile=<profile>', otherwise
98 104 use the 'profile_dir' option.
99 105 """
100 106 engines_help = """Start engines connected to an existing IPython cluster
101 107
102 108 Start one or more engines to connect to an existing Cluster
103 109 by profile name or cluster directory.
104 110 Cluster directories contain configuration, log and
105 111 security related files and are named using the convention
106 112 'cluster_<profile>' and should be created using the 'start'
107 113 subcommand of 'ipcluster'. If your cluster directory is in
108 114 the cwd or the ipython directory, you can simply refer to it
109 115 using its profile name, 'ipcluster engines n=4 profile=<profile>',
110 116 otherwise use the 'profile_dir' option.
111 117 """
112 118 create_help = """Create an ipcluster profile by name
113 119
114 120 Create an ipython cluster directory by its profile name or
115 121 cluster directory path. Cluster directories contain
116 122 configuration, log and security related files and are named
117 123 using the convention 'cluster_<profile>'. By default they are
118 124 located in your ipython directory. Once created, you will
119 125 probably need to edit the configuration files in the cluster
120 126 directory to configure your cluster. Most users will create a
121 127 cluster directory by profile name,
122 128 `ipcluster create profile=mycluster`, which will put the directory
123 129 in `<ipython_dir>/cluster_mycluster`.
124 130 """
125 131 list_help = """List available cluster profiles
126 132
127 133 List all available clusters, by cluster directory, that can
128 134 be found in the current working directory or in the ipython
129 135 directory. Cluster directories are named using the convention
130 136 'cluster_<profile>'.
131 137 """
132 138
133 139
134 140 class IPClusterList(BaseIPythonApplication):
135 141 name = u'ipcluster-list'
136 142 description = list_help
137 143
138 144 # empty aliases
139 145 aliases=Dict()
140 146 flags = Dict(base_flags)
141 147
142 148 def _log_level_default(self):
143 149 return 20
144 150
145 151 def list_profile_dirs(self):
146 152 # Find the search paths
147 153 profile_dir_paths = os.environ.get('IPYTHON_PROFILE_PATH','')
148 154 if profile_dir_paths:
149 155 profile_dir_paths = profile_dir_paths.split(':')
150 156 else:
151 157 profile_dir_paths = []
152 158
153 159 ipython_dir = self.ipython_dir
154 160
155 161 paths = [os.getcwd(), ipython_dir] + profile_dir_paths
156 162 paths = list(set(paths))
157 163
158 164 self.log.info('Searching for cluster profiles in paths: %r' % paths)
159 165 for path in paths:
160 166 files = os.listdir(path)
161 167 for f in files:
162 168 full_path = os.path.join(path, f)
163 169 if os.path.isdir(full_path) and f.startswith('profile_') and \
164 170 os.path.isfile(os.path.join(full_path, 'ipcontroller_config.py')):
165 171 profile = f.split('_')[-1]
166 172 start_cmd = 'ipcluster start profile=%s n=4' % profile
167 173 print start_cmd + " ==> " + full_path
168 174
169 175 def start(self):
170 176 self.list_profile_dirs()
171 177
172 178
173 179 # `ipcluster create` will be deprecated when `ipython profile create` or equivalent exists
174 180
175 181 create_flags = {}
176 182 create_flags.update(base_flags)
177 183 create_flags.update(boolean_flag('reset', 'IPClusterCreate.overwrite',
178 184 "reset config files to defaults", "leave existing config files"))
179 185
180 186 class IPClusterCreate(BaseParallelApplication):
181 187 name = u'ipcluster-create'
182 188 description = create_help
183 189 auto_create = Bool(True)
184 190 config_file_name = Unicode(default_config_file_name)
185 191
186 192 flags = Dict(create_flags)
187 193
188 194 aliases = Dict(dict(profile='BaseIPythonApplication.profile'))
189 195
190 196 classes = [ProfileDir]
191 197
192 198
193 199 stop_aliases = dict(
194 200 signal='IPClusterStop.signal',
195 201 profile='BaseIPythonApplication.profile',
196 202 profile_dir='ProfileDir.location',
197 203 )
198 204
199 205 class IPClusterStop(BaseParallelApplication):
200 206 name = u'ipcluster'
201 207 description = stop_help
202 208 config_file_name = Unicode(default_config_file_name)
203 209
204 210 signal = Int(signal.SIGINT, config=True,
205 211 help="signal to use for stopping processes.")
206 212
207 213 aliases = Dict(stop_aliases)
208 214
209 215 def start(self):
210 216 """Start the app for the stop subcommand."""
211 217 try:
212 218 pid = self.get_pid_from_file()
213 219 except PIDFileError:
214 220 self.log.critical(
215 221 'Could not read pid file, cluster is probably not running.'
216 222 )
217 223 # Here I exit with an unusual exit status that other processes
218 224 # can watch for to learn how I exited.
219 225 self.remove_pid_file()
220 226 self.exit(ALREADY_STOPPED)
221 227
222 228 if not self.check_pid(pid):
223 229 self.log.critical(
224 230 'Cluster [pid=%r] is not running.' % pid
225 231 )
226 232 self.remove_pid_file()
227 233 # Here I exit with an unusual exit status that other processes
228 234 # can watch for to learn how I exited.
229 235 self.exit(ALREADY_STOPPED)
230 236
231 237 elif os.name=='posix':
232 238 sig = self.signal
233 239 self.log.info(
234 240 "Stopping cluster [pid=%r] with [signal=%r]" % (pid, sig)
235 241 )
236 242 try:
237 243 os.kill(pid, sig)
238 244 except OSError:
239 245 self.log.error("Stopping cluster failed, assuming already dead.",
240 246 exc_info=True)
241 247 self.remove_pid_file()
242 248 elif os.name=='nt':
243 249 try:
244 250 # kill the whole tree
245 251 p = check_call(['taskkill', '-pid', str(pid), '-t', '-f'], stdout=PIPE,stderr=PIPE)
246 252 except (CalledProcessError, OSError):
247 253 self.log.error("Stopping cluster failed, assuming already dead.",
248 254 exc_info=True)
249 255 self.remove_pid_file()
250 256
251 257 engine_aliases = {}
252 258 engine_aliases.update(base_aliases)
253 259 engine_aliases.update(dict(
254 260 n='IPClusterEngines.n',
255 261 elauncher = 'IPClusterEngines.engine_launcher_class',
256 262 ))
257 263 class IPClusterEngines(BaseParallelApplication):
258 264
259 265 name = u'ipcluster'
260 266 description = engines_help
261 267 usage = None
262 268 config_file_name = Unicode(default_config_file_name)
263 269 default_log_level = logging.INFO
264 270 classes = List()
265 271 def _classes_default(self):
266 272 from IPython.parallel.apps import launcher
267 273 launchers = launcher.all_launchers
268 274 eslaunchers = [ l for l in launchers if 'EngineSet' in l.__name__]
269 275 return [ProfileDir]+eslaunchers
270 276
271 277 n = Int(2, config=True,
272 278 help="The number of engines to start.")
273 279
274 280 engine_launcher_class = Unicode('LocalEngineSetLauncher',
275 281 config=True,
276 282 help="The class for launching a set of Engines."
277 283 )
278 284 daemonize = Bool(False, config=True,
279 285 help='Daemonize the ipcluster program. This implies --log-to-file')
280 286
281 287 def _daemonize_changed(self, name, old, new):
282 288 if new:
283 289 self.log_to_file = True
284 290
285 291 aliases = Dict(engine_aliases)
286 292 # flags = Dict(flags)
287 293 _stopping = False
288 294
289 295 def initialize(self, argv=None):
290 296 super(IPClusterEngines, self).initialize(argv)
291 297 self.init_signal()
292 298 self.init_launchers()
293 299
294 300 def init_launchers(self):
295 301 self.engine_launcher = self.build_launcher(self.engine_launcher_class)
296 302 self.engine_launcher.on_stop(lambda r: self.loop.stop())
297 303
298 304 def init_signal(self):
299 305 # Setup signals
300 306 signal.signal(signal.SIGINT, self.sigint_handler)
301 307
302 308 def build_launcher(self, clsname):
303 309 """import and instantiate a Launcher based on importstring"""
304 310 if '.' not in clsname:
305 311 # not a module, presume it's the raw name in apps.launcher
306 312 clsname = 'IPython.parallel.apps.launcher.'+clsname
307 313 # print repr(clsname)
308 314 klass = import_item(clsname)
309 315
310 316 launcher = klass(
311 317 work_dir=self.profile_dir.location, config=self.config, log=self.log
312 318 )
313 319 return launcher
314 320
315 321 def start_engines(self):
316 322 self.log.info("Starting %i engines"%self.n)
317 323 self.engine_launcher.start(
318 324 self.n,
319 325 self.profile_dir.location
320 326 )
321 327
322 328 def stop_engines(self):
323 329 self.log.info("Stopping Engines...")
324 330 if self.engine_launcher.running:
325 331 d = self.engine_launcher.stop()
326 332 return d
327 333 else:
328 334 return None
329 335
330 336 def stop_launchers(self, r=None):
331 337 if not self._stopping:
332 338 self._stopping = True
333 339 self.log.error("IPython cluster: stopping")
334 340 self.stop_engines()
335 341 # Wait a few seconds to let things shut down.
336 342 dc = ioloop.DelayedCallback(self.loop.stop, 4000, self.loop)
337 343 dc.start()
338 344
339 345 def sigint_handler(self, signum, frame):
340 346 self.log.debug("SIGINT received, stopping launchers...")
341 347 self.stop_launchers()
342 348
343 349 def start_logging(self):
344 350 # Remove old log files of the controller and engine
345 351 if self.clean_logs:
346 352 log_dir = self.profile_dir.log_dir
347 353 for f in os.listdir(log_dir):
348 354 if re.match(r'ip(engine|controller)z-\d+\.(log|err|out)',f):
349 355 os.remove(os.path.join(log_dir, f))
350 356 # This will remove old log files for ipcluster itself
351 357 # super(IPBaseParallelApplication, self).start_logging()
352 358
353 359 def start(self):
354 360 """Start the app for the engines subcommand."""
355 361 self.log.info("IPython cluster: started")
356 362 # First see if the cluster is already running
357 363
358 364 # Now log and daemonize
359 365 self.log.info(
360 366 'Starting engines with [daemon=%r]' % self.daemonize
361 367 )
362 368 # TODO: Get daemonize working on Windows or as a Windows Server.
363 369 if self.daemonize:
364 370 if os.name=='posix':
365 371 from twisted.scripts._twistd_unix import daemonize
366 372 daemonize()
367 373
368 374 dc = ioloop.DelayedCallback(self.start_engines, 0, self.loop)
369 375 dc.start()
370 376 # Now write the new pid file AFTER our new forked pid is active.
371 377 # self.write_pid_file()
372 378 try:
373 379 self.loop.start()
374 380 except KeyboardInterrupt:
375 381 pass
376 382 except zmq.ZMQError as e:
377 383 if e.errno == errno.EINTR:
378 384 pass
379 385 else:
380 386 raise
381 387
382 388 start_aliases = {}
383 389 start_aliases.update(engine_aliases)
384 390 start_aliases.update(dict(
385 391 delay='IPClusterStart.delay',
386 392 clean_logs='IPClusterStart.clean_logs',
387 393 ))
388 394
389 395 class IPClusterStart(IPClusterEngines):
390 396
391 397 name = u'ipcluster'
392 398 description = start_help
393 399 default_log_level = logging.INFO
394 400 auto_create = Bool(True, config=True,
395 401 help="whether to create the profile_dir if it doesn't exist")
396 402 classes = List()
397 403 def _classes_default(self):
398 404 from IPython.parallel.apps import launcher
399 405 return [ProfileDir]+launcher.all_launchers
400 406
401 407 clean_logs = Bool(True, config=True,
402 408 help="whether to cleanup old logs before starting")
403 409
404 410 delay = CFloat(1., config=True,
405 411 help="delay (in s) between starting the controller and the engines")
406 412
407 413 controller_launcher_class = Unicode('LocalControllerLauncher',
408 414 config=True,
409 415 help="The class for launching a Controller."
410 416 )
411 417 reset = Bool(False, config=True,
412 418 help="Whether to reset config files as part of '--create'."
413 419 )
414 420
415 421 # flags = Dict(flags)
416 422 aliases = Dict(start_aliases)
417 423
418 424 def init_launchers(self):
419 425 self.controller_launcher = self.build_launcher(self.controller_launcher_class)
420 426 self.engine_launcher = self.build_launcher(self.engine_launcher_class)
421 427 self.controller_launcher.on_stop(self.stop_launchers)
422 428
423 429 def start_controller(self):
424 430 self.controller_launcher.start(
425 431 self.profile_dir.location
426 432 )
427 433
428 434 def stop_controller(self):
429 435 # self.log.info("In stop_controller")
430 436 if self.controller_launcher and self.controller_launcher.running:
431 437 return self.controller_launcher.stop()
432 438
433 439 def stop_launchers(self, r=None):
434 440 if not self._stopping:
435 441 self.stop_controller()
436 442 super(IPClusterStart, self).stop_launchers()
437 443
438 444 def start(self):
439 445 """Start the app for the start subcommand."""
440 446 # First see if the cluster is already running
441 447 try:
442 448 pid = self.get_pid_from_file()
443 449 except PIDFileError:
444 450 pass
445 451 else:
446 452 if self.check_pid(pid):
447 453 self.log.critical(
448 454 'Cluster is already running with [pid=%s]. '
449 455 'Use "ipcluster stop" to stop the cluster.' % pid
450 456 )
451 457 # Here I exit with an unusual exit status that other processes
452 458 # can watch for to learn how I exited.
453 459 self.exit(ALREADY_STARTED)
454 460 else:
455 461 self.remove_pid_file()
456 462
457 463
458 464 # Now log and daemonize
459 465 self.log.info(
460 466 'Starting ipcluster with [daemon=%r]' % self.daemonize
461 467 )
462 468 # TODO: Get daemonize working on Windows or as a Windows Server.
463 469 if self.daemonize:
464 470 if os.name=='posix':
465 471 from twisted.scripts._twistd_unix import daemonize
466 472 daemonize()
467 473
468 474 dc = ioloop.DelayedCallback(self.start_controller, 0, self.loop)
469 475 dc.start()
470 476 dc = ioloop.DelayedCallback(self.start_engines, 1000*self.delay, self.loop)
471 477 dc.start()
472 478 # Now write the new pid file AFTER our new forked pid is active.
473 479 self.write_pid_file()
474 480 try:
475 481 self.loop.start()
476 482 except KeyboardInterrupt:
477 483 pass
478 484 except zmq.ZMQError as e:
479 485 if e.errno == errno.EINTR:
480 486 pass
481 487 else:
482 488 raise
483 489 finally:
484 490 self.remove_pid_file()
485 491
486 492 base='IPython.parallel.apps.ipclusterapp.IPCluster'
487 493
488 494 class IPBaseParallelApplication(Application):
489 495 name = u'ipcluster'
490 496 description = _description
491 497
492 498 subcommands = {'create' : (base+'Create', create_help),
493 499 'list' : (base+'List', list_help),
494 500 'start' : (base+'Start', start_help),
495 501 'stop' : (base+'Stop', stop_help),
496 502 'engines' : (base+'Engines', engines_help),
497 503 }
498 504
499 505 # no aliases or flags for parent App
500 506 aliases = Dict()
501 507 flags = Dict()
502 508
503 509 def start(self):
504 510 if self.subapp is None:
505 511 print "No subcommand specified! Must specify one of: %s"%(self.subcommands.keys())
506 512 print
507 513 self.print_subcommands()
508 514 self.exit(1)
509 515 else:
510 516 return self.subapp.start()
511 517
512 518 def launch_new_instance():
513 519 """Create and run the IPython cluster."""
514 520 app = IPBaseParallelApplication.instance()
515 521 app.initialize()
516 522 app.start()
517 523
518 524
519 525 if __name__ == '__main__':
520 526 launch_new_instance()
521 527
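A simplified sketch of how the subcommand table above resolves to a concrete app; the actual dispatch is assumed to live in the Application base class, and the argv values are illustrative:

from IPython.utils.importstring import import_item

base = 'IPython.parallel.apps.ipclusterapp.IPCluster'
subcommands = {'start': (base + 'Start', start_help)}

cls = import_item(subcommands['start'][0])    # -> IPClusterStart
app = cls.instance()
app.initialize(['n=4', 'profile=mycluster'])  # remaining argv after the subcommand
app.start()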
@@ -1,402 +1,408 b''
1 1 #!/usr/bin/env python
2 2 # encoding: utf-8
3 3 """
4 4 The IPython controller application.
5
6 Authors:
7
8 * Brian Granger
9 * MinRK
10
5 11 """
6 12
7 13 #-----------------------------------------------------------------------------
8 # Copyright (C) 2008-2009 The IPython Development Team
14 # Copyright (C) 2008-2011 The IPython Development Team
9 15 #
10 16 # Distributed under the terms of the BSD License. The full license is in
11 17 # the file COPYING, distributed as part of this software.
12 18 #-----------------------------------------------------------------------------
13 19
14 20 #-----------------------------------------------------------------------------
15 21 # Imports
16 22 #-----------------------------------------------------------------------------
17 23
18 24 from __future__ import with_statement
19 25
20 26 import os
21 27 import socket
22 28 import stat
23 29 import sys
24 30 import uuid
25 31
26 32 from multiprocessing import Process
27 33
28 34 import zmq
29 35 from zmq.devices import ProcessMonitoredQueue
30 36 from zmq.log.handlers import PUBHandler
31 37 from zmq.utils import jsonapi as json
32 38
33 39 from IPython.config.application import boolean_flag
34 40 from IPython.core.newapplication import ProfileDir
35 41
36 42 from IPython.parallel.apps.baseapp import (
37 43 BaseParallelApplication,
38 44 base_flags
39 45 )
40 46 from IPython.utils.importstring import import_item
41 47 from IPython.utils.traitlets import Instance, Unicode, Bool, List, Dict
42 48
43 49 # from IPython.parallel.controller.controller import ControllerFactory
44 50 from IPython.zmq.session import Session
45 51 from IPython.parallel.controller.heartmonitor import HeartMonitor
46 52 from IPython.parallel.controller.hub import HubFactory
47 53 from IPython.parallel.controller.scheduler import TaskScheduler,launch_scheduler
48 54 from IPython.parallel.controller.sqlitedb import SQLiteDB
49 55
50 56 from IPython.parallel.util import signal_children, split_url
51 57
52 58 # conditional import of MongoDB backend class
53 59
54 60 try:
55 61 from IPython.parallel.controller.mongodb import MongoDB
56 62 except ImportError:
57 63 maybe_mongo = []
58 64 else:
59 65 maybe_mongo = [MongoDB]
60 66
61 67
62 68 #-----------------------------------------------------------------------------
63 69 # Module level variables
64 70 #-----------------------------------------------------------------------------
65 71
66 72
67 73 #: The default config file name for this application
68 74 default_config_file_name = u'ipcontroller_config.py'
69 75
70 76
71 77 _description = """Start the IPython controller for parallel computing.
72 78
73 79 The IPython controller provides a gateway between the IPython engines and
74 80 clients. The controller needs to be started before the engines and can be
75 81 configured using command line options or using a cluster directory. Cluster
76 82 directories contain config, log and security files and are usually located in
77 83 your ipython directory and named as "cluster_<profile>". See the `profile`
78 84 and `profile_dir` options for details.
79 85 """
80 86
81 87
82 88
83 89
84 90 #-----------------------------------------------------------------------------
85 91 # The main application
86 92 #-----------------------------------------------------------------------------
87 93 flags = {}
88 94 flags.update(base_flags)
89 95 flags.update({
90 96 'usethreads' : ( {'IPControllerApp' : {'use_threads' : True}},
91 97 'Use threads instead of processes for the schedulers'),
92 98 'sqlitedb' : ({'HubFactory' : {'db_class' : 'IPython.parallel.controller.sqlitedb.SQLiteDB'}},
93 99 'use the SQLiteDB backend'),
94 100 'mongodb' : ({'HubFactory' : {'db_class' : 'IPython.parallel.controller.mongodb.MongoDB'}},
95 101 'use the MongoDB backend'),
96 102 'dictdb' : ({'HubFactory' : {'db_class' : 'IPython.parallel.controller.dictdb.DictDB'}},
97 103 'use the in-memory DictDB backend'),
98 104 'reuse' : ({'IPControllerApp' : {'reuse_files' : True}},
99 105 'reuse existing json connection files')
100 106 })
101 107
102 108 flags.update(boolean_flag('secure', 'IPControllerApp.secure',
103 109 "Use HMAC digests for authentication of messages.",
104 110 "Don't authenticate messages."
105 111 ))
106 112
107 113 class IPControllerApp(BaseParallelApplication):
108 114
109 115 name = u'ipcontroller'
110 116 description = _description
111 117 config_file_name = Unicode(default_config_file_name)
112 118 classes = [ProfileDir, Session, HubFactory, TaskScheduler, HeartMonitor, SQLiteDB] + maybe_mongo
113 119
114 120 # change default to True
115 121 auto_create = Bool(True, config=True,
116 122 help="""Whether to create profile dir if it doesn't exist.""")
117 123
118 124 reuse_files = Bool(False, config=True,
119 125 help='Whether to reuse existing json connection files.'
120 126 )
121 127 secure = Bool(True, config=True,
122 128 help='Whether to use HMAC digests for extra message authentication.'
123 129 )
124 130 ssh_server = Unicode(u'', config=True,
125 131 help="""ssh url for clients to use when connecting to the Controller
126 132 processes. It should be of the form: [user@]server[:port]. The
127 133 Controller's listening addresses must be accessible from the ssh server""",
128 134 )
129 135 location = Unicode(u'', config=True,
130 136 help="""The external IP or domain name of the Controller, used for disambiguating
131 137 engine and client connections.""",
132 138 )
133 139 import_statements = List([], config=True,
134 140 help="import statements to be run at startup. Necessary in some environments"
135 141 )
136 142
137 143 use_threads = Bool(False, config=True,
138 144 help='Use threads instead of processes for the schedulers',
139 145 )
140 146
141 147 # internal
142 148 children = List()
143 149 mq_class = Unicode('zmq.devices.ProcessMonitoredQueue')
144 150
145 151 def _use_threads_changed(self, name, old, new):
146 152 self.mq_class = 'zmq.devices.%sMonitoredQueue'%('Thread' if new else 'Process')
147 153
148 154 aliases = Dict(dict(
149 155 log_level = 'IPControllerApp.log_level',
150 156 log_url = 'IPControllerApp.log_url',
151 157 reuse_files = 'IPControllerApp.reuse_files',
152 158 secure = 'IPControllerApp.secure',
153 159 ssh = 'IPControllerApp.ssh_server',
154 160 use_threads = 'IPControllerApp.use_threads',
155 161 import_statements = 'IPControllerApp.import_statements',
156 162 location = 'IPControllerApp.location',
157 163
158 164 ident = 'Session.session',
159 165 user = 'Session.username',
160 166 exec_key = 'Session.keyfile',
161 167
162 168 url = 'HubFactory.url',
163 169 ip = 'HubFactory.ip',
164 170 transport = 'HubFactory.transport',
165 171 port = 'HubFactory.regport',
166 172
167 173 ping = 'HeartMonitor.period',
168 174
169 175 scheme = 'TaskScheduler.scheme_name',
170 176 hwm = 'TaskScheduler.hwm',
171 177
172 178
173 179 profile = "BaseIPythonApplication.profile",
174 180 profile_dir = 'ProfileDir.location',
175 181
176 182 ))
177 183 flags = Dict(flags)
178 184
179 185
180 186 def save_connection_dict(self, fname, cdict):
181 187 """save a connection dict to json file."""
182 188 c = self.config
183 189 url = cdict['url']
184 190 location = cdict['location']
185 191 if not location:
186 192 try:
187 193 proto,ip,port = split_url(url)
188 194 except AssertionError:
189 195 pass
190 196 else:
191 197 location = socket.gethostbyname_ex(socket.gethostname())[2][-1]
192 198 cdict['location'] = location
193 199 fname = os.path.join(self.profile_dir.security_dir, fname)
194 200 with open(fname, 'w') as f:
195 201 f.write(json.dumps(cdict, indent=2))
196 202 os.chmod(fname, stat.S_IRUSR|stat.S_IWUSR)
197 203
198 204 def load_config_from_json(self):
199 205 """load config from existing json connector files."""
200 206 c = self.config
201 207 # load from engine config
202 208 with open(os.path.join(self.profile_dir.security_dir, 'ipcontroller-engine.json')) as f:
203 209 cfg = json.loads(f.read())
204 210 key = c.Session.key = cfg['exec_key']
205 211 xport,addr = cfg['url'].split('://')
206 212 c.HubFactory.engine_transport = xport
207 213 ip,ports = addr.split(':')
208 214 c.HubFactory.engine_ip = ip
209 215 c.HubFactory.regport = int(ports)
210 216 self.location = cfg['location']
211 217
212 218 # load client config
213 219 with open(os.path.join(self.profile_dir.security_dir, 'ipcontroller-client.json')) as f:
214 220 cfg = json.loads(f.read())
215 221 assert key == cfg['exec_key'], "exec_key mismatch between engine and client keys"
216 222 xport,addr = cfg['url'].split('://')
217 223 c.HubFactory.client_transport = xport
218 224 ip,ports = addr.split(':')
219 225 c.HubFactory.client_ip = ip
220 226 self.ssh_server = cfg['ssh']
221 227 assert int(ports) == c.HubFactory.regport, "regport mismatch"
222 228
223 229 def init_hub(self):
224 230 c = self.config
225 231
226 232 self.do_import_statements()
227 233 reusing = self.reuse_files
228 234 if reusing:
229 235 try:
230 236 self.load_config_from_json()
231 237 except (AssertionError,IOError):
232 238 reusing=False
233 239 # check again, because reusing may have failed:
234 240 if reusing:
235 241 pass
236 242 elif self.secure:
237 243 key = str(uuid.uuid4())
238 244 # keyfile = os.path.join(self.profile_dir.security_dir, self.exec_key)
239 245 # with open(keyfile, 'w') as f:
240 246 # f.write(key)
241 247 # os.chmod(keyfile, stat.S_IRUSR|stat.S_IWUSR)
242 248 c.Session.key = key
243 249 else:
244 250 key = c.Session.key = ''
245 251
246 252 try:
247 253 self.factory = HubFactory(config=c, log=self.log)
248 254 # self.start_logging()
249 255 self.factory.init_hub()
250 256 except:
251 257 self.log.error("Couldn't construct the Controller", exc_info=True)
252 258 self.exit(1)
253 259
254 260 if not reusing:
255 261 # save to new json config files
256 262 f = self.factory
257 263 cdict = {'exec_key' : key,
258 264 'ssh' : self.ssh_server,
259 265 'url' : "%s://%s:%s"%(f.client_transport, f.client_ip, f.regport),
260 266 'location' : self.location
261 267 }
262 268 self.save_connection_dict('ipcontroller-client.json', cdict)
263 269 edict = dict(cdict)  # copy, so the client dict saved above is not mutated
264 270 edict['url'] = "%s://%s:%s" % (f.engine_transport, f.engine_ip, f.regport)
265 271 self.save_connection_dict('ipcontroller-engine.json', edict)
266 272
267 273 #
268 274 def init_schedulers(self):
269 275 children = self.children
270 276 mq = import_item(str(self.mq_class))
271 277
272 278 hub = self.factory
273 279 # maybe_inproc = 'inproc://monitor' if self.use_threads else self.monitor_url
274 280 # IOPub relay (in a Process)
275 281 q = mq(zmq.PUB, zmq.SUB, zmq.PUB, 'N/A','iopub')
276 282 q.bind_in(hub.client_info['iopub'])
277 283 q.bind_out(hub.engine_info['iopub'])
278 284 q.setsockopt_out(zmq.SUBSCRIBE, '')
279 285 q.connect_mon(hub.monitor_url)
280 286 q.daemon=True
281 287 children.append(q)
282 288
283 289 # Multiplexer Queue (in a Process)
284 290 q = mq(zmq.XREP, zmq.XREP, zmq.PUB, 'in', 'out')
285 291 q.bind_in(hub.client_info['mux'])
286 292 q.setsockopt_in(zmq.IDENTITY, 'mux')
287 293 q.bind_out(hub.engine_info['mux'])
288 294 q.connect_mon(hub.monitor_url)
289 295 q.daemon=True
290 296 children.append(q)
291 297
292 298 # Control Queue (in a Process)
293 299 q = mq(zmq.XREP, zmq.XREP, zmq.PUB, 'incontrol', 'outcontrol')
294 300 q.bind_in(hub.client_info['control'])
295 301 q.setsockopt_in(zmq.IDENTITY, 'control')
296 302 q.bind_out(hub.engine_info['control'])
297 303 q.connect_mon(hub.monitor_url)
298 304 q.daemon=True
299 305 children.append(q)
300 306 try:
301 307 scheme = self.config.TaskScheduler.scheme_name
302 308 except AttributeError:
303 309 scheme = TaskScheduler.scheme_name.get_default_value()
304 310 # Task Queue (in a Process)
305 311 if scheme == 'pure':
306 312 self.log.warn("task::using pure XREQ Task scheduler")
307 313 q = mq(zmq.XREP, zmq.XREQ, zmq.PUB, 'intask', 'outtask')
308 314 # q.setsockopt_out(zmq.HWM, hub.hwm)
309 315 q.bind_in(hub.client_info['task'][1])
310 316 q.setsockopt_in(zmq.IDENTITY, 'task')
311 317 q.bind_out(hub.engine_info['task'])
312 318 q.connect_mon(hub.monitor_url)
313 319 q.daemon=True
314 320 children.append(q)
315 321 elif scheme == 'none':
316 322 self.log.warn("task::using no Task scheduler")
317 323
318 324 else:
319 325 self.log.info("task::using Python %s Task scheduler"%scheme)
320 326 sargs = (hub.client_info['task'][1], hub.engine_info['task'],
321 327 hub.monitor_url, hub.client_info['notification'])
322 328 kwargs = dict(logname='scheduler', loglevel=self.log_level,
323 329 log_url = self.log_url, config=dict(self.config))
324 330 q = Process(target=launch_scheduler, args=sargs, kwargs=kwargs)
325 331 q.daemon=True
326 332 children.append(q)
327 333
328 334
329 335 def save_urls(self):
330 336 """save the registration urls to files."""
331 337 c = self.config
332 338
333 339 sec_dir = self.profile_dir.security_dir
334 340 cf = self.factory
335 341
336 342 with open(os.path.join(sec_dir, 'ipcontroller-engine.url'), 'w') as f:
337 343 f.write("%s://%s:%s"%(cf.engine_transport, cf.engine_ip, cf.regport))
338 344
339 345 with open(os.path.join(sec_dir, 'ipcontroller-client.url'), 'w') as f:
340 346 f.write("%s://%s:%s"%(cf.client_transport, cf.client_ip, cf.regport))
341 347
342 348
343 349 def do_import_statements(self):
344 350 statements = self.import_statements
345 351 for s in statements:
346 352 try:
347 353 self.log.info("Executing statement: '%s'" % s)
348 354 exec s in globals(), locals()
349 355 except:
350 356 self.log.error("Error running statement: %s" % s, exc_info=True)
351 357
352 358 def forward_logging(self):
353 359 if self.log_url:
354 360 self.log.info("Forwarding logging to %s"%self.log_url)
355 361 context = zmq.Context.instance()
356 362 lsock = context.socket(zmq.PUB)
357 363 lsock.connect(self.log_url)
358 364 handler = PUBHandler(lsock)
359 365 self.log.removeHandler(self._log_handler)
360 366 handler.root_topic = 'controller'
361 367 handler.setLevel(self.log_level)
362 368 self.log.addHandler(handler)
363 369 self._log_handler = handler
364 370 # #
365 371
366 372 def initialize(self, argv=None):
367 373 super(IPControllerApp, self).initialize(argv)
368 374 self.forward_logging()
369 375 self.init_hub()
370 376 self.init_schedulers()
371 377
372 378 def start(self):
373 379 # Start the subprocesses:
374 380 self.factory.start()
375 381 child_procs = []
376 382 for child in self.children:
377 383 child.start()
378 384 if isinstance(child, ProcessMonitoredQueue):
379 385 child_procs.append(child.launcher)
380 386 elif isinstance(child, Process):
381 387 child_procs.append(child)
382 388 if child_procs:
383 389 signal_children(child_procs)
384 390
385 391 self.write_pid_file(overwrite=True)
386 392
387 393 try:
388 394 self.factory.loop.start()
389 395 except KeyboardInterrupt:
390 396 self.log.critical("Interrupted, Exiting...\n")
391 397
392 398
393 399
394 400 def launch_new_instance():
395 401 """Create and run the IPython controller"""
396 402 app = IPControllerApp.instance()
397 403 app.initialize()
398 404 app.start()
399 405
400 406
401 407 if __name__ == '__main__':
402 408 launch_new_instance()
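For reference, the shape of the connection dict that save_connection_dict() above writes to ipcontroller-client.json; all field values here are illustrative, not real endpoints:

import json

cdict = {
    'exec_key' : 'a1b2c3d4-...',           # HMAC key, only set when secure=True
    'ssh'      : '',                       # optional [user@]server[:port]
    'url'      : 'tcp://127.0.0.1:55321',  # registration endpoint
    'location' : '10.0.0.5',               # external IP or hostname
}
print json.dumps(cdict, indent=2)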
@@ -1,270 +1,276 b''
1 1 #!/usr/bin/env python
2 2 # encoding: utf-8
3 3 """
4 4 The IPython engine application
5
6 Authors:
7
8 * Brian Granger
9 * MinRK
10
5 11 """
6 12
7 13 #-----------------------------------------------------------------------------
8 # Copyright (C) 2008-2009 The IPython Development Team
14 # Copyright (C) 2008-2011 The IPython Development Team
9 15 #
10 16 # Distributed under the terms of the BSD License. The full license is in
11 17 # the file COPYING, distributed as part of this software.
12 18 #-----------------------------------------------------------------------------
13 19
14 20 #-----------------------------------------------------------------------------
15 21 # Imports
16 22 #-----------------------------------------------------------------------------
17 23
18 24 import json
19 25 import os
20 26 import sys
21 27
22 28 import zmq
23 29 from zmq.eventloop import ioloop
24 30
25 31 from IPython.core.newapplication import ProfileDir
26 32 from IPython.parallel.apps.baseapp import BaseParallelApplication
27 33 from IPython.zmq.log import EnginePUBHandler
28 34
29 35 from IPython.config.configurable import Configurable
30 36 from IPython.zmq.session import Session
31 37 from IPython.parallel.engine.engine import EngineFactory
32 38 from IPython.parallel.engine.streamkernel import Kernel
33 39 from IPython.parallel.util import disambiguate_url
34 40
35 41 from IPython.utils.importstring import import_item
36 42 from IPython.utils.traitlets import Bool, Unicode, Dict, List
37 43
38 44
39 45 #-----------------------------------------------------------------------------
40 46 # Module level variables
41 47 #-----------------------------------------------------------------------------
42 48
43 49 #: The default config file name for this application
44 50 default_config_file_name = u'ipengine_config.py'
45 51
46 52 _description = """Start an IPython engine for parallel computing.
47 53
48 54 IPython engines run in parallel and perform computations on behalf of a client
49 55 and controller. A controller needs to be started before the engines. The
50 56 engine can be configured using command line options or using a cluster
51 57 directory. Cluster directories contain config, log and security files and are
52 58 usually located in your ipython directory and named as "cluster_<profile>".
53 59 See the `profile` and `profile_dir` options for details.
54 60 """
55 61
56 62
57 63 #-----------------------------------------------------------------------------
58 64 # MPI configuration
59 65 #-----------------------------------------------------------------------------
60 66
61 67 mpi4py_init = """from mpi4py import MPI as mpi
62 68 mpi.size = mpi.COMM_WORLD.Get_size()
63 69 mpi.rank = mpi.COMM_WORLD.Get_rank()
64 70 """
65 71
66 72
67 73 pytrilinos_init = """from PyTrilinos import Epetra
68 74 class SimpleStruct:
69 75 pass
70 76 mpi = SimpleStruct()
71 77 mpi.rank = 0
72 78 mpi.size = 0
73 79 """
74 80
75 81 class MPI(Configurable):
76 82 """Configurable for MPI initialization"""
77 83 use = Unicode('', config=True,
78 84 help='How to enable MPI (mpi4py, pytrilinos, or empty string to disable).'
79 85 )
80 86
81 87 def _use_changed(self, name, old, new):
82 88 # load default init script if it's not set
83 89 if not self.init_script:
84 90 self.init_script = self.default_inits.get(new, '')
85 91
86 92 init_script = Unicode('', config=True,
87 93 help="Initialization code for MPI")
88 94
89 95 default_inits = Dict({'mpi4py' : mpi4py_init, 'pytrilinos':pytrilinos_init},
90 96 config=True)
91 97
92 98
93 99 #-----------------------------------------------------------------------------
94 100 # Main application
95 101 #-----------------------------------------------------------------------------
96 102
97 103
98 104 class IPEngineApp(BaseParallelApplication):
99 105
100 106 name = Unicode(u'ipengine')
101 107 description = Unicode(_description)
102 108 config_file_name = Unicode(default_config_file_name)
103 109 classes = List([ProfileDir, Session, EngineFactory, Kernel, MPI])
104 110
105 111 startup_script = Unicode(u'', config=True,
106 112 help='specify a script to be run at startup')
107 113 startup_command = Unicode('', config=True,
108 114 help='specify a command to be run at startup')
109 115
110 116 url_file = Unicode(u'', config=True,
111 117 help="""The full location of the file containing the connection information for
112 118 the controller. If this is not given, the file must be in the
113 119 security directory of the cluster directory. This location is
114 120 resolved using the `profile` or `profile_dir` options.""",
115 121 )
116 122
117 123 url_file_name = Unicode(u'ipcontroller-engine.json')
118 124 log_url = Unicode('', config=True,
119 125 help="""The URL for the iploggerapp instance, for forwarding
120 126 logging to a central location.""")
121 127
122 128 aliases = Dict(dict(
123 129 file = 'IPEngineApp.url_file',
124 130 c = 'IPEngineApp.startup_command',
125 131 s = 'IPEngineApp.startup_script',
126 132
127 133 ident = 'Session.session',
128 134 user = 'Session.username',
129 135 exec_key = 'Session.keyfile',
130 136
131 137 url = 'EngineFactory.url',
132 138 ip = 'EngineFactory.ip',
133 139 transport = 'EngineFactory.transport',
134 140 port = 'EngineFactory.regport',
135 141 location = 'EngineFactory.location',
136 142
137 143 timeout = 'EngineFactory.timeout',
138 144
139 145 profile = "IPEngineApp.profile",
140 146 profile_dir = 'ProfileDir.location',
141 147
142 148 mpi = 'MPI.use',
143 149
144 150 log_level = 'IPEngineApp.log_level',
145 151 log_url = 'IPEngineApp.log_url'
146 152 ))
147 153
148 154 # def find_key_file(self):
149 155 # """Set the key file.
150 156 #
151 157 # Here we don't try to actually see if it exists or is valid as that
152 158 # is handled by the connection logic.
153 159 # """
154 160 # config = self.master_config
155 161 # # Find the actual controller key file
156 162 # if not config.Global.key_file:
157 163 # try_this = os.path.join(
158 164 # config.Global.profile_dir,
159 165 # config.Global.security_dir,
160 166 # config.Global.key_file_name
161 167 # )
162 168 # config.Global.key_file = try_this
163 169
164 170 def find_url_file(self):
165 171 """Set the key file.
166 172
167 173 Here we don't try to actually see if it exists for is valid as that
168 174 is hadled by the connection logic.
169 175 """
170 176 config = self.config
171 177 # Find the actual controller key file
172 178 if not self.url_file:
173 179 self.url_file = os.path.join(
174 180 self.profile_dir.security_dir,
175 181 self.url_file_name
176 182 )
177 183 def init_engine(self):
178 184 # This is the working dir by now.
179 185 sys.path.insert(0, '')
180 186 config = self.config
181 187 # print config
182 188 self.find_url_file()
183 189
184 190 # if os.path.exists(config.Global.key_file) and config.Global.secure:
185 191 # config.SessionFactory.exec_key = config.Global.key_file
186 192 if os.path.exists(self.url_file):
187 193 with open(self.url_file) as f:
188 194 d = json.loads(f.read())
189 195 for k,v in d.iteritems():
190 196 if isinstance(v, unicode):
191 197 d[k] = v.encode()
192 198 if d['exec_key']:
193 199 config.Session.key = d['exec_key']
194 200 d['url'] = disambiguate_url(d['url'], d['location'])
195 201 config.EngineFactory.url = d['url']
196 202 config.EngineFactory.location = d['location']
197 203
198 204 try:
199 205 exec_lines = config.Kernel.exec_lines
200 206 except AttributeError:
201 207 config.Kernel.exec_lines = []
202 208 exec_lines = config.Kernel.exec_lines
203 209
204 210 if self.startup_script:
205 211 enc = sys.getfilesystemencoding() or 'utf8'
206 212 cmd="execfile(%r)"%self.startup_script.encode(enc)
207 213 exec_lines.append(cmd)
208 214 if self.startup_command:
209 215 exec_lines.append(self.startup_command)
210 216
211 217 # Create the underlying shell class and Engine
212 218 # shell_class = import_item(self.master_config.Global.shell_class)
213 219 # print self.config
214 220 try:
215 221 self.engine = EngineFactory(config=config, log=self.log)
216 222 except:
217 223 self.log.error("Couldn't start the Engine", exc_info=True)
218 224 self.exit(1)
219 225
220 226 def forward_logging(self):
221 227 if self.log_url:
222 228 self.log.info("Forwarding logging to %s"%self.log_url)
223 229 context = self.engine.context
224 230 lsock = context.socket(zmq.PUB)
225 231 lsock.connect(self.log_url)
226 232 self.log.removeHandler(self._log_handler)
227 233 handler = EnginePUBHandler(self.engine, lsock)
228 234 handler.setLevel(self.log_level)
229 235 self.log.addHandler(handler)
230 236 self._log_handler = handler
231 237 #
232 238 def init_mpi(self):
233 239 global mpi
234 240 self.mpi = MPI(config=self.config)
235 241
236 242 mpi_import_statement = self.mpi.init_script
237 243 if mpi_import_statement:
238 244 try:
239 245 self.log.info("Initializing MPI:")
240 246 self.log.info(mpi_import_statement)
241 247 exec mpi_import_statement in globals()
242 248 except:
243 249 mpi = None
244 250 else:
245 251 mpi = None
246 252
247 253 def initialize(self, argv=None):
248 254 super(IPEngineApp, self).initialize(argv)
249 255 self.init_mpi()
250 256 self.init_engine()
251 257 self.forward_logging()
252 258
253 259 def start(self):
254 260 self.engine.start()
255 261 try:
256 262 self.engine.loop.start()
257 263 except KeyboardInterrupt:
258 264 self.log.critical("Engine Interrupted, shutting down...\n")
259 265
260 266
261 267 def launch_new_instance():
262 268 """Create and run the IPython engine"""
263 269 app = IPEngineApp.instance()
264 270 app.initialize()
265 271 app.start()
266 272
267 273
268 274 if __name__ == '__main__':
269 275 launch_new_instance()
270 276
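A sketch of enabling MPI for the engine via the MPI configurable above, as it might appear in an ipengine_config.py:

c = get_config()
c.MPI.use = 'mpi4py'   # runs the mpi4py_init snippet at engine startup
# or supply custom initialization code instead:
# c.MPI.init_script = "from mpi4py import MPI as mpi"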
@@ -1,96 +1,101 b''
1 1 #!/usr/bin/env python
2 2 # encoding: utf-8
3 3 """
4 4 A simple IPython logger application
5
6 Authors:
7
8 * MinRK
9
5 10 """
6 11
7 12 #-----------------------------------------------------------------------------
8 13 # Copyright (C) 2011 The IPython Development Team
9 14 #
10 15 # Distributed under the terms of the BSD License. The full license is in
11 16 # the file COPYING, distributed as part of this software.
12 17 #-----------------------------------------------------------------------------
13 18
14 19 #-----------------------------------------------------------------------------
15 20 # Imports
16 21 #-----------------------------------------------------------------------------
17 22
18 23 import os
19 24 import sys
20 25
21 26 import zmq
22 27
23 28 from IPython.core.newapplication import ProfileDir
24 29 from IPython.utils.traitlets import Bool, Dict, Unicode
25 30
26 31 from IPython.parallel.apps.baseapp import (
27 32 BaseParallelApplication,
28 33 base_aliases
29 34 )
30 35 from IPython.parallel.apps.logwatcher import LogWatcher
31 36
32 37 #-----------------------------------------------------------------------------
33 38 # Module level variables
34 39 #-----------------------------------------------------------------------------
35 40
36 41 #: The default config file name for this application
37 42 default_config_file_name = u'iplogger_config.py'
38 43
39 44 _description = """Start an IPython logger for parallel computing.
40 45
41 46 IPython controllers and engines (and your own processes) can broadcast log messages
42 47 by registering a `zmq.log.handlers.PUBHandler` with the `logging` module. The
43 48 logger can be configured using command line options or using a cluster
44 49 directory. Cluster directories contain config, log and security files and are
45 50 usually located in your ipython directory and named as "cluster_<profile>".
46 51 See the `profile` and `profile_dir` options for details.
47 52 """
48 53
49 54
50 55 #-----------------------------------------------------------------------------
51 56 # Main application
52 57 #-----------------------------------------------------------------------------
53 58 aliases = {}
54 59 aliases.update(base_aliases)
55 60 aliases.update(dict(url='LogWatcher.url', topics='LogWatcher.topics'))
56 61
57 62 class IPLoggerApp(BaseParallelApplication):
58 63
59 64 name = u'iplogger'
60 65 description = _description
61 66 config_file_name = Unicode(default_config_file_name)
62 67
63 68 classes = [LogWatcher, ProfileDir]
64 69 aliases = Dict(aliases)
65 70
66 71 def initialize(self, argv=None):
67 72 super(IPLoggerApp, self).initialize(argv)
68 73 self.init_watcher()
69 74
70 75 def init_watcher(self):
71 76 try:
72 77 self.watcher = LogWatcher(config=self.config, log=self.log)
73 78 except:
74 79 self.log.error("Couldn't start the LogWatcher", exc_info=True)
75 80 self.exit(1)
76 81 self.log.info("Listening for log messages on %r"%self.watcher.url)
77 82
78 83
79 84 def start(self):
80 85 self.watcher.start()
81 86 try:
82 87 self.watcher.loop.start()
83 88 except KeyboardInterrupt:
84 89 self.log.critical("Logging Interrupted, shutting down...\n")
85 90
86 91
87 92 def launch_new_instance():
88 93 """Create and run the IPython LogWatcher"""
89 94 app = IPLoggerApp.instance()
90 95 app.initialize()
91 96 app.start()
92 97
93 98
94 99 if __name__ == '__main__':
95 100 launch_new_instance()
96 101
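# ---------------------------------------------------------------------------
# Editor's sketch (not part of the diff above): broadcasting log records that
# this IPLoggerApp will pick up, via the `zmq.log.handlers.PUBHandler`
# registration the module description mentions. The url must match
# LogWatcher.url (default shown).
# ---------------------------------------------------------------------------
import logging

import zmq
from zmq.log.handlers import PUBHandler

ctx = zmq.Context.instance()
pub = ctx.socket(zmq.PUB)
pub.connect('tcp://127.0.0.1:20202')   # default LogWatcher.url

root = logging.getLogger()
root.setLevel(logging.INFO)
root.addHandler(PUBHandler(pub))

root.info("this record is published and logged by iplogger")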
@@ -1,1069 +1,1074 b''
1 1 #!/usr/bin/env python
2 2 # encoding: utf-8
3 3 """
4 4 Facilities for launching IPython processes asynchronously.
5
6 Authors:
7
8 * Brian Granger
9 * MinRK
5 10 """
6 11
7 12 #-----------------------------------------------------------------------------
8 # Copyright (C) 2008-2009 The IPython Development Team
13 # Copyright (C) 2008-2011 The IPython Development Team
9 14 #
10 15 # Distributed under the terms of the BSD License. The full license is in
11 16 # the file COPYING, distributed as part of this software.
12 17 #-----------------------------------------------------------------------------
13 18
14 19 #-----------------------------------------------------------------------------
15 20 # Imports
16 21 #-----------------------------------------------------------------------------
17 22
18 23 import copy
19 24 import logging
20 25 import os
21 26 import re
22 27 import stat
23 28
24 29 # signal imports, handling various platforms, versions
25 30
26 31 from signal import SIGINT, SIGTERM
27 32 try:
28 33 from signal import SIGKILL
29 34 except ImportError:
30 35 # Windows
31 36 SIGKILL=SIGTERM
32 37
33 38 try:
34 39 # Windows >= 2.7, 3.2
35 40 from signal import CTRL_C_EVENT as SIGINT
36 41 except ImportError:
37 42 pass
38 43
39 44 from subprocess import Popen, PIPE, STDOUT
40 45 try:
41 46 from subprocess import check_output
42 47 except ImportError:
43 48 # pre-2.7, define check_output with Popen
44 49 def check_output(*args, **kwargs):
45 50 kwargs.update(dict(stdout=PIPE))
46 51 p = Popen(*args, **kwargs)
47 52 out,err = p.communicate()
48 53 return out
49 54
50 55 from zmq.eventloop import ioloop
51 56
52 57 from IPython.config.application import Application
53 58 from IPython.config.configurable import LoggingConfigurable
54 59 from IPython.utils.text import EvalFormatter
55 60 from IPython.utils.traitlets import Any, Int, List, Unicode, Dict, Instance
56 61 from IPython.utils.path import get_ipython_module_path
57 62 from IPython.utils.process import find_cmd, pycmd2argv, FindCmdError
58 63
59 64 from .win32support import forward_read_events
60 65
61 66 from .winhpcjob import IPControllerTask, IPEngineTask, IPControllerJob, IPEngineSetJob
62 67
63 68 WINDOWS = os.name == 'nt'
64 69
65 70 #-----------------------------------------------------------------------------
66 71 # Paths to the kernel apps
67 72 #-----------------------------------------------------------------------------
68 73
69 74
70 75 ipcluster_cmd_argv = pycmd2argv(get_ipython_module_path(
71 76 'IPython.parallel.apps.ipclusterapp'
72 77 ))
73 78
74 79 ipengine_cmd_argv = pycmd2argv(get_ipython_module_path(
75 80 'IPython.parallel.apps.ipengineapp'
76 81 ))
77 82
78 83 ipcontroller_cmd_argv = pycmd2argv(get_ipython_module_path(
79 84 'IPython.parallel.apps.ipcontrollerapp'
80 85 ))
81 86
82 87 #-----------------------------------------------------------------------------
83 88 # Base launchers and errors
84 89 #-----------------------------------------------------------------------------
85 90
86 91
87 92 class LauncherError(Exception):
88 93 pass
89 94
90 95
91 96 class ProcessStateError(LauncherError):
92 97 pass
93 98
94 99
95 100 class UnknownStatus(LauncherError):
96 101 pass
97 102
98 103
99 104 class BaseLauncher(LoggingConfigurable):
100 105 """An abstraction for starting, stopping and signaling a process."""
101 106
102 107 # In all of the launchers, the work_dir is where child processes will be
103 108 # run. This will usually be the profile_dir, but may not be. any work_dir
104 109 # passed into the __init__ method will override the config value.
105 110 # This should not be used to set the work_dir for the actual engine
106 111 # and controller. Instead, use their own config files or the
107 112 # controller_args, engine_args attributes of the launchers to add
108 113 # the work_dir option.
109 114 work_dir = Unicode(u'.')
110 115 loop = Instance('zmq.eventloop.ioloop.IOLoop')
111 116
112 117 start_data = Any()
113 118 stop_data = Any()
114 119
115 120 def _loop_default(self):
116 121 return ioloop.IOLoop.instance()
117 122
118 123 def __init__(self, work_dir=u'.', config=None, **kwargs):
119 124 super(BaseLauncher, self).__init__(work_dir=work_dir, config=config, **kwargs)
120 125 self.state = 'before' # can be before, running, after
121 126 self.stop_callbacks = []
122 127 self.start_data = None
123 128 self.stop_data = None
124 129
125 130 @property
126 131 def args(self):
127 132 """A list of cmd and args that will be used to start the process.
128 133
129 134 This is what is passed to :func:`spawnProcess` and the first element
130 135 will be the process name.
131 136 """
132 137 return self.find_args()
133 138
134 139 def find_args(self):
135 140 """The ``.args`` property calls this to find the args list.
136 141
137 142 Subclasses should implement this to construct the cmd and args.
138 143 """
139 144 raise NotImplementedError('find_args must be implemented in a subclass')
140 145
141 146 @property
142 147 def arg_str(self):
143 148 """The string form of the program arguments."""
144 149 return ' '.join(self.args)
145 150
146 151 @property
147 152 def running(self):
148 153 """Am I running."""
149 154 if self.state == 'running':
150 155 return True
151 156 else:
152 157 return False
153 158
154 159 def start(self):
155 160 """Start the process.
156 161
157 162 This must return a deferred that fires with information about the
158 163 process starting (like a pid, job id, etc.).
159 164 """
160 165 raise NotImplementedError('start must be implemented in a subclass')
161 166
162 167 def stop(self):
163 168 """Stop the process and notify observers of stopping.
164 169
165 170 This must return a deferred that fires with information about the
166 171 process stopping, like errors that occur while the process is
167 172 attempting to be shut down. This deferred won't fire when the process
168 173 actually stops. To observe the actual process stopping, see
169 174 :func:`on_stop`.
170 175 """
171 176 raise NotImplementedError('stop must be implemented in a subclass')
172 177
173 178 def on_stop(self, f):
174 179 """Get a deferred that will fire when the process stops.
175 180
176 181 The deferred will fire with data that contains information about
177 182 the exit status of the process.
178 183 """
179 184 if self.state=='after':
180 185 return f(self.stop_data)
181 186 else:
182 187 self.stop_callbacks.append(f)
183 188
184 189 def notify_start(self, data):
185 190 """Call this to trigger startup actions.
186 191
187 192 This logs the process startup and sets the state to 'running'. It is
188 193 a pass-through so it can be used as a callback.
189 194 """
190 195
191 196 self.log.info('Process %r started: %r' % (self.args[0], data))
192 197 self.start_data = data
193 198 self.state = 'running'
194 199 return data
195 200
196 201 def notify_stop(self, data):
197 202 """Call this to trigger process stop actions.
198 203
199 204 This logs the process stopping and sets the state to 'after'. Call
200 205 this to trigger all the deferreds from :func:`on_stop`."""
201 206
202 207 self.log.info('Process %r stopped: %r' % (self.args[0], data))
203 208 self.stop_data = data
204 209 self.state = 'after'
205 210 for i in range(len(self.stop_callbacks)):
206 211 d = self.stop_callbacks.pop()
207 212 d(data)
208 213 return data
209 214
210 215 def signal(self, sig):
211 216 """Signal the process.
212 217
213 218 Return a semi-meaningless deferred after signaling the process.
214 219
215 220 Parameters
216 221 ----------
217 222 sig : str or int
218 223 'KILL', 'INT', etc., or any signal number
219 224 """
220 225 raise NotImplementedError('signal must be implemented in a subclass')
221 226
222 227
223 228 #-----------------------------------------------------------------------------
224 229 # Local process launchers
225 230 #-----------------------------------------------------------------------------
226 231
227 232
228 233 class LocalProcessLauncher(BaseLauncher):
229 234 """Start and stop an external process in an asynchronous manner.
230 235
231 236 This will launch the external process with a working directory of
232 237 ``self.work_dir``.
233 238 """
234 239
235 240 # This is used to construct self.args, which is passed to
236 241 # spawnProcess.
237 242 cmd_and_args = List([])
238 243 poll_frequency = Int(100) # in ms
239 244
240 245 def __init__(self, work_dir=u'.', config=None, **kwargs):
241 246 super(LocalProcessLauncher, self).__init__(
242 247 work_dir=work_dir, config=config, **kwargs
243 248 )
244 249 self.process = None
245 250 self.start_deferred = None
246 251 self.poller = None
247 252
248 253 def find_args(self):
249 254 return self.cmd_and_args
250 255
251 256 def start(self):
252 257 if self.state == 'before':
253 258 self.process = Popen(self.args,
254 259 stdout=PIPE,stderr=PIPE,stdin=PIPE,
255 260 env=os.environ,
256 261 cwd=self.work_dir
257 262 )
258 263 if WINDOWS:
259 264 self.stdout = forward_read_events(self.process.stdout)
260 265 self.stderr = forward_read_events(self.process.stderr)
261 266 else:
262 267 self.stdout = self.process.stdout.fileno()
263 268 self.stderr = self.process.stderr.fileno()
264 269 self.loop.add_handler(self.stdout, self.handle_stdout, self.loop.READ)
265 270 self.loop.add_handler(self.stderr, self.handle_stderr, self.loop.READ)
266 271 self.poller = ioloop.PeriodicCallback(self.poll, self.poll_frequency, self.loop)
267 272 self.poller.start()
268 273 self.notify_start(self.process.pid)
269 274 else:
270 275 s = 'The process was already started and has state: %r' % self.state
271 276 raise ProcessStateError(s)
272 277
273 278 def stop(self):
274 279 return self.interrupt_then_kill()
275 280
276 281 def signal(self, sig):
277 282 if self.state == 'running':
278 283 if WINDOWS and sig != SIGINT:
279 284 # use Windows tree-kill for better child cleanup
280 285 check_output(['taskkill', '-pid', str(self.process.pid), '-t', '-f'])
281 286 else:
282 287 self.process.send_signal(sig)
283 288
284 289 def interrupt_then_kill(self, delay=2.0):
285 290 """Send INT, wait a delay and then send KILL."""
286 291 try:
287 292 self.signal(SIGINT)
288 293 except Exception:
289 294 self.log.debug("interrupt failed")
290 295 pass
291 296 self.killer = ioloop.DelayedCallback(lambda : self.signal(SIGKILL), delay*1000, self.loop)
292 297 self.killer.start()
293 298
294 299 # callbacks, etc:
295 300
296 301 def handle_stdout(self, fd, events):
297 302 if WINDOWS:
298 303 line = self.stdout.recv()
299 304 else:
300 305 line = self.process.stdout.readline()
301 306 # a stopped process will be readable but return empty strings
302 307 if line:
303 308 self.log.info(line[:-1])
304 309 else:
305 310 self.poll()
306 311
307 312 def handle_stderr(self, fd, events):
308 313 if WINDOWS:
309 314 line = self.stderr.recv()
310 315 else:
311 316 line = self.process.stderr.readline()
312 317 # a stopped process will be readable but return empty strings
313 318 if line:
314 319 self.log.error(line[:-1])
315 320 else:
316 321 self.poll()
317 322
318 323 def poll(self):
319 324 status = self.process.poll()
320 325 if status is not None:
321 326 self.poller.stop()
322 327 self.loop.remove_handler(self.stdout)
323 328 self.loop.remove_handler(self.stderr)
324 329 self.notify_stop(dict(exit_code=status, pid=self.process.pid))
325 330 return status
326 331
327 332 class LocalControllerLauncher(LocalProcessLauncher):
328 333 """Launch a controller as a regular external process."""
329 334
330 335 controller_cmd = List(ipcontroller_cmd_argv, config=True,
331 336 help="""Popen command to launch ipcontroller.""")
332 337 # Command line arguments to ipcontroller.
333 338 controller_args = List(['--log-to-file','log_level=%i'%logging.INFO], config=True,
334 339 help="""command-line args to pass to ipcontroller""")
335 340
336 341 def find_args(self):
337 342 return self.controller_cmd + self.controller_args
338 343
339 344 def start(self, profile_dir):
340 345 """Start the controller by profile_dir."""
341 346 self.controller_args.extend(['profile_dir=%s'%profile_dir])
342 347 self.profile_dir = unicode(profile_dir)
343 348 self.log.info("Starting LocalControllerLauncher: %r" % self.args)
344 349 return super(LocalControllerLauncher, self).start()
345 350
346 351
347 352 class LocalEngineLauncher(LocalProcessLauncher):
348 353 """Launch a single engine as a regular external process."""
349 354
350 355 engine_cmd = List(ipengine_cmd_argv, config=True,
351 356 help="""command to launch the Engine.""")
352 357 # Command line arguments for ipengine.
353 358 engine_args = List(['--log-to-file','log_level=%i'%logging.INFO], config=True,
354 359 help="command-line arguments to pass to ipengine"
355 360 )
356 361
357 362 def find_args(self):
358 363 return self.engine_cmd + self.engine_args
359 364
360 365 def start(self, profile_dir):
361 366 """Start the engine by profile_dir."""
362 367 self.engine_args.extend(['profile_dir=%s'%profile_dir])
363 368 self.profile_dir = unicode(profile_dir)
364 369 return super(LocalEngineLauncher, self).start()
365 370
366 371
367 372 class LocalEngineSetLauncher(BaseLauncher):
368 373 """Launch a set of engines as regular external processes."""
369 374
370 375 # Command line arguments for ipengine.
371 376 engine_args = List(
372 377 ['--log-to-file','log_level=%i'%logging.INFO], config=True,
373 378 help="command-line arguments to pass to ipengine"
374 379 )
375 380 # launcher class
376 381 launcher_class = LocalEngineLauncher
377 382
378 383 launchers = Dict()
379 384 stop_data = Dict()
380 385
381 386 def __init__(self, work_dir=u'.', config=None, **kwargs):
382 387 super(LocalEngineSetLauncher, self).__init__(
383 388 work_dir=work_dir, config=config, **kwargs
384 389 )
385 390 self.stop_data = {}
386 391
387 392 def start(self, n, profile_dir):
388 393 """Start n engines by profile or profile_dir."""
389 394 self.profile_dir = unicode(profile_dir)
390 395 dlist = []
391 396 for i in range(n):
392 397 el = self.launcher_class(work_dir=self.work_dir, config=self.config, log=self.log)
393 398 # Copy the engine args over to each engine launcher.
394 399 el.engine_args = copy.deepcopy(self.engine_args)
395 400 el.on_stop(self._notice_engine_stopped)
396 401 d = el.start(profile_dir)
397 402 if i==0:
398 403 self.log.info("Starting LocalEngineSetLauncher: %r" % el.args)
399 404 self.launchers[i] = el
400 405 dlist.append(d)
401 406 self.notify_start(dlist)
402 407 # The consumeErrors here could be dangerous
403 408 # dfinal = gatherBoth(dlist, consumeErrors=True)
404 409 # dfinal.addCallback(self.notify_start)
405 410 return dlist
406 411
407 412 def find_args(self):
408 413 return ['engine set']
409 414
410 415 def signal(self, sig):
411 416 dlist = []
412 417 for el in self.launchers.itervalues():
413 418 d = el.signal(sig)
414 419 dlist.append(d)
415 420 # dfinal = gatherBoth(dlist, consumeErrors=True)
416 421 return dlist
417 422
418 423 def interrupt_then_kill(self, delay=1.0):
419 424 dlist = []
420 425 for el in self.launchers.itervalues():
421 426 d = el.interrupt_then_kill(delay)
422 427 dlist.append(d)
423 428 # dfinal = gatherBoth(dlist, consumeErrors=True)
424 429 return dlist
425 430
426 431 def stop(self):
427 432 return self.interrupt_then_kill()
428 433
429 434 def _notice_engine_stopped(self, data):
430 435 pid = data['pid']
431 436 for idx,el in self.launchers.iteritems():
432 437 if el.process.pid == pid:
433 438 break
434 439 self.launchers.pop(idx)
435 440 self.stop_data[idx] = data
436 441 if not self.launchers:
437 442 self.notify_stop(self.stop_data)
438 443
439 444
440 445 #-----------------------------------------------------------------------------
441 446 # MPIExec launchers
442 447 #-----------------------------------------------------------------------------
443 448
444 449
445 450 class MPIExecLauncher(LocalProcessLauncher):
446 451 """Launch an external process using mpiexec."""
447 452
448 453 mpi_cmd = List(['mpiexec'], config=True,
449 454 help="The mpiexec command to use in starting the process."
450 455 )
451 456 mpi_args = List([], config=True,
452 457 help="The command line arguments to pass to mpiexec."
453 458 )
454 459 program = List(['date'], config=True,
455 460 help="The program to start via mpiexec.")
456 461 program_args = List([], config=True,
457 462 help="The command line argument to the program."
458 463 )
459 464 n = Int(1)
460 465
461 466 def find_args(self):
462 467 """Build self.args using all the fields."""
463 468 return self.mpi_cmd + ['-n', str(self.n)] + self.mpi_args + \
464 469 self.program + self.program_args
465 470
466 471 def start(self, n):
467 472 """Start n instances of the program using mpiexec."""
468 473 self.n = n
469 474 return super(MPIExecLauncher, self).start()
470 475
471 476
472 477 class MPIExecControllerLauncher(MPIExecLauncher):
473 478 """Launch a controller using mpiexec."""
474 479
475 480 controller_cmd = List(ipcontroller_cmd_argv, config=True,
476 481 help="Popen command to launch the Controller"
477 482 )
478 483 controller_args = List(['--log-to-file','log_level=%i'%logging.INFO], config=True,
479 484 help="Command line arguments to pass to ipcontroller."
480 485 )
481 486 n = Int(1)
482 487
483 488 def start(self, profile_dir):
484 489 """Start the controller by profile_dir."""
485 490 self.controller_args.extend(['profile_dir=%s'%profile_dir])
486 491 self.profile_dir = unicode(profile_dir)
487 492 self.log.info("Starting MPIExecControllerLauncher: %r" % self.args)
488 493 return super(MPIExecControllerLauncher, self).start(1)
489 494
490 495 def find_args(self):
491 496 return self.mpi_cmd + ['-n', self.n] + self.mpi_args + \
492 497 self.controller_cmd + self.controller_args
493 498
494 499
495 500 class MPIExecEngineSetLauncher(MPIExecLauncher):
496 501
497 502 program = List(ipengine_cmd_argv, config=True,
498 503 help="Popen command for ipengine"
499 504 )
500 505 program_args = List(
501 506 ['--log-to-file','log_level=%i'%logging.INFO], config=True,
502 507 help="Command line arguments for ipengine."
503 508 )
504 509 n = Int(1)
505 510
506 511 def start(self, n, profile_dir):
507 512 """Start n engines by profile or profile_dir."""
508 513 self.program_args.extend(['profile_dir=%s'%profile_dir])
509 514 self.profile_dir = unicode(profile_dir)
510 515 self.n = n
511 516 self.log.info('Starting MPIExecEngineSetLauncher: %r' % self.args)
512 517 return super(MPIExecEngineSetLauncher, self).start(n)
513 518
514 519 #-----------------------------------------------------------------------------
515 520 # SSH launchers
516 521 #-----------------------------------------------------------------------------
517 522
518 523 # TODO: Get SSH Launcher working again.
519 524
520 525 class SSHLauncher(LocalProcessLauncher):
521 526 """A minimal launcher for ssh.
522 527
523 528 To be useful this will probably have to be extended to use the ``sshx``
524 529 idea for environment variables. There could be other things this needs
525 530 as well.
526 531 """
527 532
528 533 ssh_cmd = List(['ssh'], config=True,
529 534 help="command for starting ssh")
530 535 ssh_args = List(['-tt'], config=True,
531 536 help="args to pass to ssh")
532 537 program = List(['date'], config=True,
533 538 help="Program to launch via ssh")
534 539 program_args = List([], config=True,
535 540 help="args to pass to remote program")
536 541 hostname = Unicode('', config=True,
537 542 help="hostname on which to launch the program")
538 543 user = Unicode('', config=True,
539 544 help="username for ssh")
540 545 location = Unicode('', config=True,
541 546 help="user@hostname location for ssh in one setting")
542 547
543 548 def _hostname_changed(self, name, old, new):
544 549 if self.user:
545 550 self.location = u'%s@%s' % (self.user, new)
546 551 else:
547 552 self.location = new
548 553
549 554 def _user_changed(self, name, old, new):
550 555 self.location = u'%s@%s' % (new, self.hostname)
551 556
552 557 def find_args(self):
553 558 return self.ssh_cmd + self.ssh_args + [self.location] + \
554 559 self.program + self.program_args
555 560
556 561 def start(self, profile_dir, hostname=None, user=None):
557 562 self.profile_dir = unicode(profile_dir)
558 563 if hostname is not None:
559 564 self.hostname = hostname
560 565 if user is not None:
561 566 self.user = user
562 567
563 568 return super(SSHLauncher, self).start()
564 569
565 570 def signal(self, sig):
566 571 if self.state == 'running':
567 572 # send escaped ssh connection-closer
568 573 self.process.stdin.write('~.')
569 574 self.process.stdin.flush()
570 575
571 576
572 577
573 578 class SSHControllerLauncher(SSHLauncher):
574 579
575 580 program = List(ipcontroller_cmd_argv, config=True,
576 581 help="remote ipcontroller command.")
577 582 program_args = List(['--reuse-files', '--log-to-file','log_level=%i'%logging.INFO], config=True,
578 583 help="Command line arguments to ipcontroller.")
579 584
580 585
581 586 class SSHEngineLauncher(SSHLauncher):
582 587 program = List(ipengine_cmd_argv, config=True,
583 588 help="remote ipengine command.")
584 589 # Command line arguments for ipengine.
585 590 program_args = List(
586 591 ['--log-to-file','log_level=%i'%logging.INFO], config=True,
587 592 help="Command line arguments to ipengine."
588 593 )
589 594
590 595 class SSHEngineSetLauncher(LocalEngineSetLauncher):
591 596 launcher_class = SSHEngineLauncher
592 597 engines = Dict(config=True,
593 598 help="""dict of engines to launch. This is a dict by hostname of ints,
594 599 corresponding to the number of engines to start on that host.""")
595 600
596 601 def start(self, n, profile_dir):
597 602 """Start engines by profile or profile_dir.
598 603 `n` is ignored, and the `engines` config property is used instead.
599 604 """
600 605
601 606 self.profile_dir = unicode(profile_dir)
602 607 dlist = []
603 608 for host, n in self.engines.iteritems():
604 609 if isinstance(n, (tuple, list)):
605 610 n, args = n
606 611 else:
607 612 args = copy.deepcopy(self.engine_args)
608 613
609 614 if '@' in host:
610 615 user,host = host.split('@',1)
611 616 else:
612 617 user=None
613 618 for i in range(n):
614 619 el = self.launcher_class(work_dir=self.work_dir, config=self.config, log=self.log)
615 620
616 621 # Copy the engine args over to each engine launcher.
618 623 el.program_args = args
619 624 el.on_stop(self._notice_engine_stopped)
620 625 d = el.start(profile_dir, user=user, hostname=host)
621 626 if i==0:
622 627 self.log.info("Starting SSHEngineSetLauncher: %r" % el.args)
623 628 self.launchers[host+str(i)] = el
624 629 dlist.append(d)
625 630 self.notify_start(dlist)
626 631 return dlist
627 632
628 633
629 634
630 635 #-----------------------------------------------------------------------------
631 636 # Windows HPC Server 2008 scheduler launchers
632 637 #-----------------------------------------------------------------------------
633 638
634 639
635 640 # This is only used on Windows.
636 641 def find_job_cmd():
637 642 if WINDOWS:
638 643 try:
639 644 return find_cmd('job')
640 645 except (FindCmdError, ImportError):
641 646 # ImportError will be raised if win32api is not installed
642 647 return 'job'
643 648 else:
644 649 return 'job'
645 650
646 651
647 652 class WindowsHPCLauncher(BaseLauncher):
648 653
649 654 job_id_regexp = Unicode(r'\d+', config=True,
650 655 help="""A regular expression used to get the job id from the output of the
651 656 submit_command. """
652 657 )
653 658 job_file_name = Unicode(u'ipython_job.xml', config=True,
654 659 help="The filename of the instantiated job script.")
655 660 # The full path to the instantiated job script. This gets made dynamically
656 661 # by combining the work_dir with the job_file_name.
657 662 job_file = Unicode(u'')
658 663 scheduler = Unicode('', config=True,
659 664 help="The hostname of the scheduler to submit the job to.")
660 665 job_cmd = Unicode(find_job_cmd(), config=True,
661 666 help="The command for submitting jobs.")
662 667
663 668 def __init__(self, work_dir=u'.', config=None, **kwargs):
664 669 super(WindowsHPCLauncher, self).__init__(
665 670 work_dir=work_dir, config=config, **kwargs
666 671 )
667 672
668 673 @property
669 674 def job_file(self):
670 675 return os.path.join(self.work_dir, self.job_file_name)
671 676
672 677 def write_job_file(self, n):
673 678 raise NotImplementedError("Implement write_job_file in a subclass.")
674 679
675 680 def find_args(self):
676 681 return [u'job.exe']
677 682
678 683 def parse_job_id(self, output):
679 684 """Take the output of the submit command and return the job id."""
680 685 m = re.search(self.job_id_regexp, output)
681 686 if m is not None:
682 687 job_id = m.group()
683 688 else:
684 689 raise LauncherError("Job id couldn't be determined: %s" % output)
685 690 self.job_id = job_id
686 691 self.log.info('Job started with job id: %r' % job_id)
687 692 return job_id
688 693
689 694 def start(self, n):
690 695 """Start n copies of the process using the Win HPC job scheduler."""
691 696 self.write_job_file(n)
692 697 args = [
693 698 'submit',
694 699 '/jobfile:%s' % self.job_file,
695 700 '/scheduler:%s' % self.scheduler
696 701 ]
697 702 self.log.info("Starting Win HPC Job: %s" % (self.job_cmd + ' ' + ' '.join(args),))
698 703 # Twisted will raise DeprecationWarnings if we try to pass unicode to this
699 704 output = check_output([self.job_cmd]+args,
700 705 env=os.environ,
701 706 cwd=self.work_dir,
702 707 stderr=STDOUT
703 708 )
704 709 job_id = self.parse_job_id(output)
705 710 self.notify_start(job_id)
706 711 return job_id
707 712
708 713 def stop(self):
709 714 args = [
710 715 'cancel',
711 716 self.job_id,
712 717 '/scheduler:%s' % self.scheduler
713 718 ]
714 719 self.log.info("Stopping Win HPC Job: %s" % (self.job_cmd + ' ' + ' '.join(args),))
715 720 try:
716 721 output = check_output([self.job_cmd]+args,
717 722 env=os.environ,
718 723 cwd=self.work_dir,
719 724 stderr=STDOUT
720 725 )
721 726 except:
722 727 output = 'The job already appears to be stopped: %r' % self.job_id
723 728 self.notify_stop(dict(job_id=self.job_id, output=output)) # Pass the output of the kill cmd
724 729 return output
725 730
726 731
727 732 class WindowsHPCControllerLauncher(WindowsHPCLauncher):
728 733
729 734 job_file_name = Unicode(u'ipcontroller_job.xml', config=True,
730 735 help="WinHPC xml job file.")
731 736 extra_args = List([], config=False,
732 737 help="extra args to pass to ipcontroller")
733 738
734 739 def write_job_file(self, n):
735 740 job = IPControllerJob(config=self.config)
736 741
737 742 t = IPControllerTask(config=self.config)
738 743 # The tasks work directory is *not* the actual work directory of
739 744 # the controller. It is used as the base path for the stdout/stderr
740 745 # files that the scheduler redirects to.
741 746 t.work_directory = self.profile_dir
742 747 # Add the profile_dir argument passed in from self.start().
743 748 t.controller_args.extend(self.extra_args)
744 749 job.add_task(t)
745 750
746 751 self.log.info("Writing job description file: %s" % self.job_file)
747 752 job.write(self.job_file)
748 753
749 754 @property
750 755 def job_file(self):
751 756 return os.path.join(self.profile_dir, self.job_file_name)
752 757
753 758 def start(self, profile_dir):
754 759 """Start the controller by profile_dir."""
755 760 self.extra_args = ['profile_dir=%s'%profile_dir]
756 761 self.profile_dir = unicode(profile_dir)
757 762 return super(WindowsHPCControllerLauncher, self).start(1)
758 763
759 764
760 765 class WindowsHPCEngineSetLauncher(WindowsHPCLauncher):
761 766
762 767 job_file_name = Unicode(u'ipengineset_job.xml', config=True,
763 768 help="jobfile for ipengines job")
764 769 extra_args = List([], config=False,
765 770 help="extra args to pass to ipengine")
766 771
767 772 def write_job_file(self, n):
768 773 job = IPEngineSetJob(config=self.config)
769 774
770 775 for i in range(n):
771 776 t = IPEngineTask(config=self.config)
772 777 # The tasks work directory is *not* the actual work directory of
773 778 # the engine. It is used as the base path for the stdout/stderr
774 779 # files that the scheduler redirects to.
775 780 t.work_directory = self.profile_dir
776 781 # Add the profile_dir and from self.start().
777 782 t.engine_args.extend(self.extra_args)
778 783 job.add_task(t)
779 784
780 785 self.log.info("Writing job description file: %s" % self.job_file)
781 786 job.write(self.job_file)
782 787
783 788 @property
784 789 def job_file(self):
785 790 return os.path.join(self.profile_dir, self.job_file_name)
786 791
787 792 def start(self, n, profile_dir):
788 793 """Start the controller by profile_dir."""
789 794 self.extra_args = ['profile_dir=%s'%profile_dir]
790 795 self.profile_dir = unicode(profile_dir)
791 796 return super(WindowsHPCEngineSetLauncher, self).start(n)
792 797
793 798
794 799 #-----------------------------------------------------------------------------
795 800 # Batch (PBS) system launchers
796 801 #-----------------------------------------------------------------------------
797 802
798 803 class BatchSystemLauncher(BaseLauncher):
799 804 """Launch an external process using a batch system.
800 805
801 806 This class is designed to work with UNIX batch systems like PBS, LSF,
802 807 GridEngine, etc. The overall model is that there are different commands
803 808 like qsub, qdel, etc. that handle the starting and stopping of the process.
804 809
805 810 This class also has the notion of a batch script. The ``batch_template``
806 811 attribute can be set to a string that is a template for the batch script.
807 812 This template is instantiated using string formatting. Thus the template can
808 813 use {n} for the number of instances. Subclasses can add additional variables
809 814 to the template dict.
810 815 """
811 816
812 817 # Subclasses must fill these in. See PBSEngineSet
813 818 submit_command = List([''], config=True,
814 819 help="The name of the command line program used to submit jobs.")
815 820 delete_command = List([''], config=True,
816 821 help="The name of the command line program used to delete jobs.")
817 822 job_id_regexp = Unicode('', config=True,
818 823 help="""A regular expression used to get the job id from the output of the
819 824 submit_command.""")
820 825 batch_template = Unicode('', config=True,
821 826 help="The string that is the batch script template itself.")
822 827 batch_template_file = Unicode(u'', config=True,
823 828 help="The file that contains the batch template.")
824 829 batch_file_name = Unicode(u'batch_script', config=True,
825 830 help="The filename of the instantiated batch script.")
826 831 queue = Unicode(u'', config=True,
827 832 help="The batch queue.")
828 833
829 834 # not configurable, override in subclasses
830 835 # PBS Job Array regex
831 836 job_array_regexp = Unicode('')
832 837 job_array_template = Unicode('')
833 838 # PBS Queue regex
834 839 queue_regexp = Unicode('')
835 840 queue_template = Unicode('')
836 841 # The default batch template, override in subclasses
837 842 default_template = Unicode('')
838 843 # The full path to the instantiated batch script.
839 844 batch_file = Unicode(u'')
840 845 # the format dict used with batch_template:
841 846 context = Dict()
842 847 # the Formatter instance for rendering the templates:
843 848 formatter = Instance(EvalFormatter, (), {})
844 849
845 850
846 851 def find_args(self):
847 852 return self.submit_command + [self.batch_file]
848 853
849 854 def __init__(self, work_dir=u'.', config=None, **kwargs):
850 855 super(BatchSystemLauncher, self).__init__(
851 856 work_dir=work_dir, config=config, **kwargs
852 857 )
853 858 self.batch_file = os.path.join(self.work_dir, self.batch_file_name)
854 859
855 860 def parse_job_id(self, output):
856 861 """Take the output of the submit command and return the job id."""
857 862 m = re.search(self.job_id_regexp, output)
858 863 if m is not None:
859 864 job_id = m.group()
860 865 else:
861 866 raise LauncherError("Job id couldn't be determined: %s" % output)
862 867 self.job_id = job_id
863 868 self.log.info('Job submitted with job id: %r' % job_id)
864 869 return job_id
865 870
866 871 def write_batch_script(self, n):
867 872 """Instantiate and write the batch script to the work_dir."""
868 873 self.context['n'] = n
869 874 self.context['queue'] = self.queue
870 875 # first priority is batch_template if set
871 876 if self.batch_template_file and not self.batch_template:
872 877 # second priority is batch_template_file
873 878 with open(self.batch_template_file) as f:
874 879 self.batch_template = f.read()
875 880 if not self.batch_template:
876 881 # third (last) priority is default_template
877 882 self.batch_template = self.default_template
878 883
879 884 regex = re.compile(self.job_array_regexp)
880 885 # print regex.search(self.batch_template)
881 886 if not regex.search(self.batch_template):
882 887 self.log.info("adding job array settings to batch script")
883 888 firstline, rest = self.batch_template.split('\n',1)
884 889 self.batch_template = u'\n'.join([firstline, self.job_array_template, rest])
885 890
886 891 regex = re.compile(self.queue_regexp)
887 892 # print regex.search(self.batch_template)
888 893 if self.queue and not regex.search(self.batch_template):
889 894 self.log.info("adding PBS queue settings to batch script")
890 895 firstline, rest = self.batch_template.split('\n',1)
891 896 self.batch_template = u'\n'.join([firstline, self.queue_template, rest])
892 897
893 898 script_as_string = self.formatter.format(self.batch_template, **self.context)
894 899 self.log.info('Writing instantiated batch script: %s' % self.batch_file)
895 900
896 901 with open(self.batch_file, 'w') as f:
897 902 f.write(script_as_string)
898 903 os.chmod(self.batch_file, stat.S_IRUSR | stat.S_IWUSR | stat.S_IXUSR)
899 904
900 905 def start(self, n, profile_dir):
901 906 """Start n copies of the process using a batch system."""
902 907 # Here we save profile_dir in the context so it
903 908 # can be used in the batch script template as {profile_dir}
904 909 self.context['profile_dir'] = profile_dir
905 910 self.profile_dir = unicode(profile_dir)
906 911 self.write_batch_script(n)
907 912 output = check_output(self.args, env=os.environ)
908 913
909 914 job_id = self.parse_job_id(output)
910 915 self.notify_start(job_id)
911 916 return job_id
912 917
913 918 def stop(self):
914 919 output = check_output(self.delete_command+[self.job_id], env=os.environ)
915 920 self.notify_stop(dict(job_id=self.job_id, output=output)) # Pass the output of the kill cmd
916 921 return output
917 922
918 923
919 924 class PBSLauncher(BatchSystemLauncher):
920 925 """A BatchSystemLauncher subclass for PBS."""
921 926
922 927 submit_command = List(['qsub'], config=True,
923 928 help="The PBS submit command ['qsub']")
924 929 delete_command = List(['qdel'], config=True,
925 930 help="The PBS delete command ['qdel']")
926 931 job_id_regexp = Unicode(r'\d+', config=True,
927 932 help="Regular expression for identifying the job ID [r'\d+']")
928 933
929 934 batch_file = Unicode(u'')
930 935 job_array_regexp = Unicode('#PBS\W+-t\W+[\w\d\-\$]+')
931 936 job_array_template = Unicode('#PBS -t 1-{n}')
932 937 queue_regexp = Unicode('#PBS\W+-q\W+\$?\w+')
933 938 queue_template = Unicode('#PBS -q {queue}')
934 939
935 940
936 941 class PBSControllerLauncher(PBSLauncher):
937 942 """Launch a controller using PBS."""
938 943
939 944 batch_file_name = Unicode(u'pbs_controller', config=True,
940 945 help="batch file name for the controller job.")
941 946 default_template= Unicode("""#!/bin/sh
942 947 #PBS -V
943 948 #PBS -N ipcontroller
944 949 %s --log-to-file profile_dir={profile_dir}
945 950 """%(' '.join(ipcontroller_cmd_argv)))
946 951
947 952 def start(self, profile_dir):
948 953 """Start the controller by profile or profile_dir."""
949 954 self.log.info("Starting PBSControllerLauncher: %r" % self.args)
950 955 return super(PBSControllerLauncher, self).start(1, profile_dir)
951 956
952 957
953 958 class PBSEngineSetLauncher(PBSLauncher):
954 959 """Launch Engines using PBS"""
955 960 batch_file_name = Unicode(u'pbs_engines', config=True,
956 961 help="batch file name for the engine(s) job.")
957 962 default_template= Unicode(u"""#!/bin/sh
958 963 #PBS -V
959 964 #PBS -N ipengine
960 965 %s profile_dir={profile_dir}
961 966 """%(' '.join(ipengine_cmd_argv)))
962 967
963 968 def start(self, n, profile_dir):
964 969 """Start n engines by profile or profile_dir."""
965 970 self.log.info('Starting %i engines with PBSEngineSetLauncher: %r' % (n, self.args))
966 971 return super(PBSEngineSetLauncher, self).start(n, profile_dir)
967 972
968 973 #SGE is very similar to PBS
969 974
970 975 class SGELauncher(PBSLauncher):
971 976 """Sun GridEngine is a PBS clone with slightly different syntax"""
972 977 job_array_regexp = Unicode('#\$\W+\-t')
973 978 job_array_template = Unicode('#$ -t 1-{n}')
974 979 queue_regexp = Unicode('#\$\W+-q\W+\$?\w+')
975 980 queue_template = Unicode('#$ -q {queue}')
976 981
977 982 class SGEControllerLauncher(SGELauncher):
978 983 """Launch a controller using SGE."""
979 984
980 985 batch_file_name = Unicode(u'sge_controller', config=True,
981 986 help="batch file name for the ipcontroller job.")
982 987 default_template= Unicode(u"""#$ -V
983 988 #$ -S /bin/sh
984 989 #$ -N ipcontroller
985 990 %s --log-to-file profile_dir={profile_dir}
986 991 """%(' '.join(ipcontroller_cmd_argv)))
987 992
988 993 def start(self, profile_dir):
989 994 """Start the controller by profile or profile_dir."""
990 995 self.log.info("Starting SGEControllerLauncher: %r" % self.args)
991 996 return super(SGEControllerLauncher, self).start(1, profile_dir)
992 997
993 998 class SGEEngineSetLauncher(SGELauncher):
994 999 """Launch Engines with SGE"""
995 1000 batch_file_name = Unicode(u'sge_engines', config=True,
996 1001 help="batch file name for the engine(s) job.")
997 1002 default_template = Unicode("""#$ -V
998 1003 #$ -S /bin/sh
999 1004 #$ -N ipengine
1000 1005 %s profile_dir={profile_dir}
1001 1006 """%(' '.join(ipengine_cmd_argv)))
1002 1007
1003 1008 def start(self, n, profile_dir):
1004 1009 """Start n engines by profile or profile_dir."""
1005 1010 self.log.info('Starting %i engines with SGEEngineSetLauncher: %r' % (n, self.args))
1006 1011 return super(SGEEngineSetLauncher, self).start(n, profile_dir)
1007 1012
1008 1013
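# ---------------------------------------------------------------------------
# Editor's sketch (not part of the diff above): supplying a custom PBS batch
# template, as the BatchSystemLauncher docstring describes. {n}, {queue} and
# {profile_dir} are the context keys filled in by write_batch_script(); the
# rest of the script is a made-up example, assumed to live in a hypothetical
# ipcluster_config.py (where get_config() is available).
# ---------------------------------------------------------------------------
c = get_config()

c.PBSEngineSetLauncher.batch_template = """#!/bin/sh
#PBS -V
#PBS -q {queue}
#PBS -t 1-{n}
#PBS -N ipengine
ipengine profile_dir={profile_dir}
"""
c.PBSEngineSetLauncher.queue = u'batch'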
1009 1014 #-----------------------------------------------------------------------------
1010 1015 # A launcher for ipcluster itself!
1011 1016 #-----------------------------------------------------------------------------
1012 1017
1013 1018
1014 1019 class IPClusterLauncher(LocalProcessLauncher):
1015 1020 """Launch the ipcluster program in an external process."""
1016 1021
1017 1022 ipcluster_cmd = List(ipcluster_cmd_argv, config=True,
1018 1023 help="Popen command for ipcluster")
1019 1024 ipcluster_args = List(
1020 1025 ['--clean-logs', '--log-to-file', 'log_level=%i'%logging.INFO], config=True,
1021 1026 help="Command line arguments to pass to ipcluster.")
1022 1027 ipcluster_subcommand = Unicode('start')
1023 1028 ipcluster_n = Int(2)
1024 1029
1025 1030 def find_args(self):
1026 1031 return self.ipcluster_cmd + ['--'+self.ipcluster_subcommand] + \
1027 1032 ['n=%i'%self.ipcluster_n] + self.ipcluster_args
1028 1033
1029 1034 def start(self):
1030 1035 self.log.info("Starting ipcluster: %r" % self.args)
1031 1036 return super(IPClusterLauncher, self).start()
1032 1037
1033 1038 #-----------------------------------------------------------------------------
1034 1039 # Collections of launchers
1035 1040 #-----------------------------------------------------------------------------
1036 1041
1037 1042 local_launchers = [
1038 1043 LocalControllerLauncher,
1039 1044 LocalEngineLauncher,
1040 1045 LocalEngineSetLauncher,
1041 1046 ]
1042 1047 mpi_launchers = [
1043 1048 MPIExecLauncher,
1044 1049 MPIExecControllerLauncher,
1045 1050 MPIExecEngineSetLauncher,
1046 1051 ]
1047 1052 ssh_launchers = [
1048 1053 SSHLauncher,
1049 1054 SSHControllerLauncher,
1050 1055 SSHEngineLauncher,
1051 1056 SSHEngineSetLauncher,
1052 1057 ]
1053 1058 winhpc_launchers = [
1054 1059 WindowsHPCLauncher,
1055 1060 WindowsHPCControllerLauncher,
1056 1061 WindowsHPCEngineSetLauncher,
1057 1062 ]
1058 1063 pbs_launchers = [
1059 1064 PBSLauncher,
1060 1065 PBSControllerLauncher,
1061 1066 PBSEngineSetLauncher,
1062 1067 ]
1063 1068 sge_launchers = [
1064 1069 SGELauncher,
1065 1070 SGEControllerLauncher,
1066 1071 SGEEngineSetLauncher,
1067 1072 ]
1068 1073 all_launchers = local_launchers + mpi_launchers + ssh_launchers + winhpc_launchers\
1069 1074 + pbs_launchers + sge_launchers
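# ---------------------------------------------------------------------------
# Editor's sketch (not part of the diff above): driving one of the launchers
# from the ioloop, using only the BaseLauncher API shown above (start,
# on_stop, loop). The profile_dir path is a made-up example.
# ---------------------------------------------------------------------------
from zmq.eventloop import ioloop

def run_local_controller(profile_dir=u'/home/user/.ipython/profile_default'):
    loop = ioloop.IOLoop.instance()
    launcher = LocalControllerLauncher()
    # on_stop callbacks fire with dict(exit_code=..., pid=...)
    launcher.on_stop(lambda data: loop.stop())
    launcher.start(profile_dir)
    loop.start()
    return launcher.stop_data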
@@ -1,108 +1,115 b''
1 1 #!/usr/bin/env python
2 """A simple logger object that consolidates messages incoming from ipcluster processes."""
2 """
3 A simple logger object that consolidates messages incoming from ipcluster processes.
4
5 Authors:
6
7 * MinRK
8
9 """
3 10
4 11 #-----------------------------------------------------------------------------
5 12 # Copyright (C) 2011 The IPython Development Team
6 13 #
7 14 # Distributed under the terms of the BSD License. The full license is in
8 15 # the file COPYING, distributed as part of this software.
9 16 #-----------------------------------------------------------------------------
10 17
11 18 #-----------------------------------------------------------------------------
12 19 # Imports
13 20 #-----------------------------------------------------------------------------
14 21
15 22
16 23 import logging
17 24 import sys
18 25
19 26 import zmq
20 27 from zmq.eventloop import ioloop, zmqstream
21 28
22 29 from IPython.config.configurable import LoggingConfigurable
23 30 from IPython.utils.traitlets import Int, Unicode, Instance, List
24 31
25 32 #-----------------------------------------------------------------------------
26 33 # Classes
27 34 #-----------------------------------------------------------------------------
28 35
29 36
30 37 class LogWatcher(LoggingConfigurable):
31 38 """A simple class that receives messages on a SUB socket, as published
32 39 by subclasses of `zmq.log.handlers.PUBHandler`, and logs them itself.
33 40
34 41 This can subscribe to multiple topics, but defaults to all topics.
35 42 """
36 43
37 44 # configurables
38 45 topics = List([''], config=True,
39 46 help="The ZMQ topics to subscribe to. Default is to subscribe to all messages")
40 47 url = Unicode('tcp://127.0.0.1:20202', config=True,
41 48 help="ZMQ url on which to listen for log messages")
42 49
43 50 # internals
44 51 stream = Instance('zmq.eventloop.zmqstream.ZMQStream')
45 52
46 53 context = Instance(zmq.Context)
47 54 def _context_default(self):
48 55 return zmq.Context.instance()
49 56
50 57 loop = Instance(zmq.eventloop.ioloop.IOLoop)
51 58 def _loop_default(self):
52 59 return ioloop.IOLoop.instance()
53 60
54 61 def __init__(self, **kwargs):
55 62 super(LogWatcher, self).__init__(**kwargs)
56 63 s = self.context.socket(zmq.SUB)
57 64 s.bind(self.url)
58 65 self.stream = zmqstream.ZMQStream(s, self.loop)
59 66 self.subscribe()
60 67 self.on_trait_change(self.subscribe, 'topics')
61 68
62 69 def start(self):
63 70 self.stream.on_recv(self.log_message)
64 71
65 72 def stop(self):
66 73 self.stream.stop_on_recv()
67 74
68 75 def subscribe(self):
69 76 """Update our SUB socket's subscriptions."""
70 77 self.stream.setsockopt(zmq.UNSUBSCRIBE, '')
71 78 if '' in self.topics:
72 79 self.log.debug("Subscribing to: everything")
73 80 self.stream.setsockopt(zmq.SUBSCRIBE, '')
74 81 else:
75 82 for topic in self.topics:
76 83 self.log.debug("Subscribing to: %r"%(topic))
77 84 self.stream.setsockopt(zmq.SUBSCRIBE, topic)
78 85
79 86 def _extract_level(self, topic_str):
80 87 """Turn 'engine.0.INFO.extra' into (logging.INFO, 'engine.0.extra')"""
81 88 topics = topic_str.split('.')
82 89 for idx,t in enumerate(topics):
83 90 level = getattr(logging, t, None)
84 91 if level is not None:
85 92 break
86 93
87 94 if level is None:
88 95 level = logging.INFO
89 96 else:
90 97 topics.pop(idx)
91 98
92 99 return level, '.'.join(topics)
93 100
94 101
95 102 def log_message(self, raw):
96 103 """receive and parse a message, then log it."""
97 104 if len(raw) != 2 or '.' not in raw[0]:
98 105 self.log.error("Invalid log message: %s"%raw)
99 106 return
100 107 else:
101 108 topic, msg = raw
102 109 # don't newline, since log messages always newline:
103 110 topic,level_name = topic.rsplit('.',1)
104 111 level,topic = self._extract_level(topic)
105 112 if msg[-1] == '\n':
106 113 msg = msg[:-1]
107 114 self.log.log(level, "[%s] %s" % (topic, msg))
108 115
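# ---------------------------------------------------------------------------
# Editor's sketch (not part of the diff above): the two-part message format
# log_message() expects: [topic, body], where the level name is a dotted
# component of the topic, e.g. 'engine.0.INFO' -> (logging.INFO, 'engine.0').
# ---------------------------------------------------------------------------
import zmq

ctx = zmq.Context.instance()
pub = ctx.socket(zmq.PUB)
pub.connect('tcp://127.0.0.1:20202')   # default LogWatcher.url

# _extract_level() strips the INFO token, so this logs "[engine.0] engine 0 up"
pub.send_multipart(['engine.0.INFO', 'engine 0 up\n'])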
@@ -1,67 +1,73 b''
1 1 #!/usr/bin/env python
2 2 """Utility for forwarding file read events over a zmq socket.
3 3
4 This is necessary because select on Windows only supports"""
4 This is necessary because select on Windows only supports sockets, not FDs.
5
6 Authors:
7
8 * MinRK
9
10 """
5 11
6 12 #-----------------------------------------------------------------------------
7 13 # Copyright (C) 2011 The IPython Development Team
8 14 #
9 15 # Distributed under the terms of the BSD License. The full license is in
10 16 # the file COPYING, distributed as part of this software.
11 17 #-----------------------------------------------------------------------------
12 18
13 19 #-----------------------------------------------------------------------------
14 20 # Imports
15 21 #-----------------------------------------------------------------------------
16 22
17 23 import uuid
18 24 import zmq
19 25
20 26 from threading import Thread
21 27
22 28 #-----------------------------------------------------------------------------
23 29 # Code
24 30 #-----------------------------------------------------------------------------
25 31
26 32 class ForwarderThread(Thread):
27 33 def __init__(self, sock, fd):
28 34 Thread.__init__(self)
29 35 self.daemon=True
30 36 self.sock = sock
31 37 self.fd = fd
32 38
33 39 def run(self):
34 40 """loop through lines in self.fd, and send them over self.sock"""
35 41 line = self.fd.readline()
36 42 # allow for files opened in unicode mode
37 43 if isinstance(line, unicode):
38 44 send = self.sock.send_unicode
39 45 else:
40 46 send = self.sock.send
41 47 while line:
42 48 send(line)
43 49 line = self.fd.readline()
44 50 # line == '' means EOF
45 51 self.fd.close()
46 52 self.sock.close()
47 53
48 54 def forward_read_events(fd, context=None):
49 55 """forward read events from an FD over a socket.
50 56
51 57 This method wraps a file in a socket pair, so it can
52 58 be polled for read events by select (specifically zmq.eventloop.ioloop)
53 59 """
54 60 if context is None:
55 61 context = zmq.Context.instance()
56 62 push = context.socket(zmq.PUSH)
57 63 push.setsockopt(zmq.LINGER, -1)
58 64 pull = context.socket(zmq.PULL)
59 65 addr='inproc://%s'%uuid.uuid4()
60 66 push.bind(addr)
61 67 pull.connect(addr)
62 68 forwarder = ForwarderThread(push, fd)
63 69 forwarder.start()
64 70 return pull
65 71
66 72
67 73 __all__ = ['forward_read_events'] No newline at end of file
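# ---------------------------------------------------------------------------
# Editor's sketch (not part of the diff above): how launcher.py uses this on
# Windows -- wrap a pipe in forward_read_events() so the zmq ioloop can poll
# it like a socket. The child command is a made-up example.
# ---------------------------------------------------------------------------
from subprocess import Popen, PIPE
from zmq.eventloop import ioloop

proc = Popen(['python', '-c', 'print "hello"'], stdout=PIPE)
stdout = forward_read_events(proc.stdout)   # a PULL socket

loop = ioloop.IOLoop.instance()

def handle_stdout(fd, events):
    line = stdout.recv()       # one line forwarded from the child's stdout
    print line.rstrip()
    loop.stop()

loop.add_handler(stdout, handle_stdout, loop.READ)
loop.start()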
@@ -1,314 +1,320 b''
1 1 #!/usr/bin/env python
2 2 # encoding: utf-8
3 3 """
4 4 Job and task components for writing .xml files that the Windows HPC Server
5 5 2008 can use to start jobs.
6
7 Authors:
8
9 * Brian Granger
10 * MinRK
11
6 12 """
7 13
8 14 #-----------------------------------------------------------------------------
9 # Copyright (C) 2008-2009 The IPython Development Team
15 # Copyright (C) 2008-2011 The IPython Development Team
10 16 #
11 17 # Distributed under the terms of the BSD License. The full license is in
12 18 # the file COPYING, distributed as part of this software.
13 19 #-----------------------------------------------------------------------------
14 20
15 21 #-----------------------------------------------------------------------------
16 22 # Imports
17 23 #-----------------------------------------------------------------------------
18 24
19 25 import os
20 26 import re
21 27 import uuid
22 28
23 29 from xml.etree import ElementTree as ET
24 30
25 31 from IPython.config.configurable import Configurable
26 32 from IPython.utils.traitlets import (
27 33 Unicode, Int, List, Instance,
28 34 Enum, Bool
29 35 )
30 36
31 37 #-----------------------------------------------------------------------------
32 38 # Job and Task classes
33 39 #-----------------------------------------------------------------------------
34 40
35 41
36 42 def as_str(value):
37 43 if isinstance(value, str):
38 44 return value
39 45 elif isinstance(value, bool):
40 46 if value:
41 47 return 'true'
42 48 else:
43 49 return 'false'
44 50 elif isinstance(value, (int, float)):
45 51 return repr(value)
46 52 else:
47 53 return value
48 54
49 55
50 56 def indent(elem, level=0):
51 57 i = "\n" + level*" "
52 58 if len(elem):
53 59 if not elem.text or not elem.text.strip():
54 60 elem.text = i + " "
55 61 if not elem.tail or not elem.tail.strip():
56 62 elem.tail = i
57 63 for elem in elem:
58 64 indent(elem, level+1)
59 65 if not elem.tail or not elem.tail.strip():
60 66 elem.tail = i
61 67 else:
62 68 if level and (not elem.tail or not elem.tail.strip()):
63 69 elem.tail = i
64 70
65 71
66 72 def find_username():
67 73 domain = os.environ.get('USERDOMAIN')
68 74 username = os.environ.get('USERNAME','')
69 75 if domain is None:
70 76 return username
71 77 else:
72 78 return '%s\\%s' % (domain, username)
73 79
74 80
75 81 class WinHPCJob(Configurable):
76 82
77 83 job_id = Unicode('')
78 84 job_name = Unicode('MyJob', config=True)
79 85 min_cores = Int(1, config=True)
80 86 max_cores = Int(1, config=True)
81 87 min_sockets = Int(1, config=True)
82 88 max_sockets = Int(1, config=True)
83 89 min_nodes = Int(1, config=True)
84 90 max_nodes = Int(1, config=True)
85 91 unit_type = Unicode("Core", config=True)
86 92 auto_calculate_min = Bool(True, config=True)
87 93 auto_calculate_max = Bool(True, config=True)
88 94 run_until_canceled = Bool(False, config=True)
89 95 is_exclusive = Bool(False, config=True)
90 96 username = Unicode(find_username(), config=True)
91 97 job_type = Unicode('Batch', config=True)
92 98 priority = Enum(('Lowest','BelowNormal','Normal','AboveNormal','Highest'),
93 99 default_value='Highest', config=True)
94 100 requested_nodes = Unicode('', config=True)
95 101 project = Unicode('IPython', config=True)
96 102 xmlns = Unicode('http://schemas.microsoft.com/HPCS2008/scheduler/')
97 103 version = Unicode("2.000")
98 104 tasks = List([])
99 105
100 106 @property
101 107 def owner(self):
102 108 return self.username
103 109
104 110 def _write_attr(self, root, attr, key):
105 111 s = as_str(getattr(self, attr, ''))
106 112 if s:
107 113 root.set(key, s)
108 114
109 115 def as_element(self):
110 116 # We have to add _A_-style prefixes to get the attribute order that
111 117 # the MSFT XML parser expects.
112 118 root = ET.Element('Job')
113 119 self._write_attr(root, 'version', '_A_Version')
114 120 self._write_attr(root, 'job_name', '_B_Name')
115 121 self._write_attr(root, 'unit_type', '_C_UnitType')
116 122 self._write_attr(root, 'min_cores', '_D_MinCores')
117 123 self._write_attr(root, 'max_cores', '_E_MaxCores')
118 124 self._write_attr(root, 'min_sockets', '_F_MinSockets')
119 125 self._write_attr(root, 'max_sockets', '_G_MaxSockets')
120 126 self._write_attr(root, 'min_nodes', '_H_MinNodes')
121 127 self._write_attr(root, 'max_nodes', '_I_MaxNodes')
122 128 self._write_attr(root, 'run_until_canceled', '_J_RunUntilCanceled')
123 129 self._write_attr(root, 'is_exclusive', '_K_IsExclusive')
124 130 self._write_attr(root, 'username', '_L_UserName')
125 131 self._write_attr(root, 'job_type', '_M_JobType')
126 132 self._write_attr(root, 'priority', '_N_Priority')
127 133 self._write_attr(root, 'requested_nodes', '_O_RequestedNodes')
128 134 self._write_attr(root, 'auto_calculate_max', '_P_AutoCalculateMax')
129 135 self._write_attr(root, 'auto_calculate_min', '_Q_AutoCalculateMin')
130 136 self._write_attr(root, 'project', '_R_Project')
131 137 self._write_attr(root, 'owner', '_S_Owner')
132 138 self._write_attr(root, 'xmlns', '_T_xmlns')
133 139 dependencies = ET.SubElement(root, "Dependencies")
134 140 etasks = ET.SubElement(root, "Tasks")
135 141 for t in self.tasks:
136 142 etasks.append(t.as_element())
137 143 return root
138 144
139 145 def tostring(self):
140 146 """Return the string representation of the job description XML."""
141 147 root = self.as_element()
142 148 indent(root)
143 149 txt = ET.tostring(root, encoding="utf-8")
144 150 # Now remove the tokens used to order the attributes.
145 151 txt = re.sub(r'_[A-Z]_','',txt)
146 152 txt = '<?xml version="1.0" encoding="utf-8"?>\n' + txt
147 153 return txt
148 154
149 155 def write(self, filename):
150 156 """Write the XML job description to a file."""
151 157 txt = self.tostring()
152 158 with open(filename, 'w') as f:
153 159 f.write(txt)
154 160
155 161 def add_task(self, task):
156 162 """Add a task to the job.
157 163
158 164 Parameters
159 165 ----------
160 166 task : :class:`WinHPCTask`
161 167 The task object to add.
162 168 """
163 169 self.tasks.append(task)
164 170
165 171
166 172 class WinHPCTask(Configurable):
167 173
168 174 task_id = Unicode('')
169 175 task_name = Unicode('')
170 176 version = Unicode("2.000")
171 177 min_cores = Int(1, config=True)
172 178 max_cores = Int(1, config=True)
173 179 min_sockets = Int(1, config=True)
174 180 max_sockets = Int(1, config=True)
175 181 min_nodes = Int(1, config=True)
176 182 max_nodes = Int(1, config=True)
177 183 unit_type = Unicode("Core", config=True)
178 184 command_line = Unicode('', config=True)
179 185 work_directory = Unicode('', config=True)
180 186     is_rerunnable = Bool(True, config=True)
181 187 std_out_file_path = Unicode('', config=True)
182 188 std_err_file_path = Unicode('', config=True)
183 189 is_parametric = Bool(False, config=True)
184 190 environment_variables = Instance(dict, args=(), config=True)
185 191
186 192 def _write_attr(self, root, attr, key):
187 193 s = as_str(getattr(self, attr, ''))
188 194 if s:
189 195 root.set(key, s)
190 196
191 197 def as_element(self):
192 198 root = ET.Element('Task')
193 199 self._write_attr(root, 'version', '_A_Version')
194 200 self._write_attr(root, 'task_name', '_B_Name')
195 201 self._write_attr(root, 'min_cores', '_C_MinCores')
196 202 self._write_attr(root, 'max_cores', '_D_MaxCores')
197 203 self._write_attr(root, 'min_sockets', '_E_MinSockets')
198 204 self._write_attr(root, 'max_sockets', '_F_MaxSockets')
199 205 self._write_attr(root, 'min_nodes', '_G_MinNodes')
200 206 self._write_attr(root, 'max_nodes', '_H_MaxNodes')
201 207 self._write_attr(root, 'command_line', '_I_CommandLine')
202 208 self._write_attr(root, 'work_directory', '_J_WorkDirectory')
203 209         self._write_attr(root, 'is_rerunnable', '_K_IsRerunnable')
204 210 self._write_attr(root, 'std_out_file_path', '_L_StdOutFilePath')
205 211 self._write_attr(root, 'std_err_file_path', '_M_StdErrFilePath')
206 212 self._write_attr(root, 'is_parametric', '_N_IsParametric')
207 213 self._write_attr(root, 'unit_type', '_O_UnitType')
208 214 root.append(self.get_env_vars())
209 215 return root
210 216
211 217 def get_env_vars(self):
212 218 env_vars = ET.Element('EnvironmentVariables')
213 219 for k, v in self.environment_variables.iteritems():
214 220 variable = ET.SubElement(env_vars, "Variable")
215 221 name = ET.SubElement(variable, "Name")
216 222 name.text = k
217 223 value = ET.SubElement(variable, "Value")
218 224 value.text = v
219 225 return env_vars
220 226
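# A hedged usage sketch (hypothetical values): environment variables set on a
# task are serialized as <Variable> elements by get_env_vars().
#
# t = WinHPCTask(None)
# t.environment_variables['PYTHONPATH'] = r'C:\Python25\Lib\site-packages'
# print ET.tostring(t.get_env_vars())
# # -> <EnvironmentVariables><Variable><Name>PYTHONPATH</Name>
# #    <Value>C:\Python25\Lib\site-packages</Value></Variable></EnvironmentVariables>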
221 227
222 228
223 229 # By declaring these, we can configure the controller and engine separately!
224 230
225 231 class IPControllerJob(WinHPCJob):
226 232 job_name = Unicode('IPController', config=False)
227 233 is_exclusive = Bool(False, config=True)
228 234 username = Unicode(find_username(), config=True)
229 235 priority = Enum(('Lowest','BelowNormal','Normal','AboveNormal','Highest'),
230 236 default_value='Highest', config=True)
231 237 requested_nodes = Unicode('', config=True)
232 238 project = Unicode('IPython', config=True)
233 239
234 240
235 241 class IPEngineSetJob(WinHPCJob):
236 242 job_name = Unicode('IPEngineSet', config=False)
237 243 is_exclusive = Bool(False, config=True)
238 244 username = Unicode(find_username(), config=True)
239 245 priority = Enum(('Lowest','BelowNormal','Normal','AboveNormal','Highest'),
240 246 default_value='Highest', config=True)
241 247 requested_nodes = Unicode('', config=True)
242 248 project = Unicode('IPython', config=True)
243 249
244 250
245 251 class IPControllerTask(WinHPCTask):
246 252
247 253 task_name = Unicode('IPController', config=True)
248 254 controller_cmd = List(['ipcontroller.exe'], config=True)
249 255 controller_args = List(['--log-to-file', '--log-level', '40'], config=True)
250 256 # I don't want these to be configurable
251 257 std_out_file_path = Unicode('', config=False)
252 258 std_err_file_path = Unicode('', config=False)
253 259 min_cores = Int(1, config=False)
254 260 max_cores = Int(1, config=False)
255 261 min_sockets = Int(1, config=False)
256 262 max_sockets = Int(1, config=False)
257 263 min_nodes = Int(1, config=False)
258 264 max_nodes = Int(1, config=False)
259 265 unit_type = Unicode("Core", config=False)
260 266 work_directory = Unicode('', config=False)
261 267
262 268 def __init__(self, config=None):
263 269 super(IPControllerTask, self).__init__(config=config)
264 270 the_uuid = uuid.uuid1()
265 271 self.std_out_file_path = os.path.join('log','ipcontroller-%s.out' % the_uuid)
266 272 self.std_err_file_path = os.path.join('log','ipcontroller-%s.err' % the_uuid)
267 273
268 274 @property
269 275 def command_line(self):
270 276 return ' '.join(self.controller_cmd + self.controller_args)
271 277
272 278
273 279 class IPEngineTask(WinHPCTask):
274 280
275 281 task_name = Unicode('IPEngine', config=True)
276 282 engine_cmd = List(['ipengine.exe'], config=True)
277 283 engine_args = List(['--log-to-file', '--log-level', '40'], config=True)
278 284 # I don't want these to be configurable
279 285 std_out_file_path = Unicode('', config=False)
280 286 std_err_file_path = Unicode('', config=False)
281 287 min_cores = Int(1, config=False)
282 288 max_cores = Int(1, config=False)
283 289 min_sockets = Int(1, config=False)
284 290 max_sockets = Int(1, config=False)
285 291 min_nodes = Int(1, config=False)
286 292 max_nodes = Int(1, config=False)
287 293 unit_type = Unicode("Core", config=False)
288 294 work_directory = Unicode('', config=False)
289 295
290 296 def __init__(self, config=None):
291 297 super(IPEngineTask,self).__init__(config=config)
292 298 the_uuid = uuid.uuid1()
293 299 self.std_out_file_path = os.path.join('log','ipengine-%s.out' % the_uuid)
294 300 self.std_err_file_path = os.path.join('log','ipengine-%s.err' % the_uuid)
295 301
296 302 @property
297 303 def command_line(self):
298 304 return ' '.join(self.engine_cmd + self.engine_args)
299 305
300 306
301 307 # j = WinHPCJob(None)
302 308 # j.job_name = 'IPCluster'
303 309 # j.username = 'GNET\\bgranger'
304 310 # j.requested_nodes = 'GREEN'
305 311 #
306 312 # t = WinHPCTask(None)
307 313 # t.task_name = 'Controller'
308 314 # t.command_line = r"\\blue\domainusers$\bgranger\Python\Python25\Scripts\ipcontroller.exe --log-to-file -p default --log-level 10"
309 315 # t.work_directory = r"\\blue\domainusers$\bgranger\.ipython\cluster_default"
310 316 # t.std_out_file_path = 'controller-out.txt'
311 317 # t.std_err_file_path = 'controller-err.txt'
312 318 # t.environment_variables['PYTHONPATH'] = r"\\blue\domainusers$\bgranger\Python\Python25\Lib\site-packages"
313 319 # j.add_task(t)
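#
# A hedged sketch of the same flow using the job/task subclasses above
# (usernames and filenames are placeholders):
#
# cj = IPControllerJob(None)
# cj.username = 'DOMAIN\\user'
# cj.add_task(IPControllerTask(None))
# cj.write('ipcontroller_job.xml')
#
# ej = IPEngineSetJob(None)
# ej.username = 'DOMAIN\\user'
# for i in range(4):
#     ej.add_task(IPEngineTask(None))
# ej.write('ipengineset_job.xml')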
314 320
@@ -1,340 +1,345 b''
1 """AsyncResult objects for the client"""
1 """AsyncResult objects for the client
2
3 Authors:
4
5 * MinRK
6 """
2 7 #-----------------------------------------------------------------------------
3 8 # Copyright (C) 2010-2011 The IPython Development Team
4 9 #
5 10 # Distributed under the terms of the BSD License. The full license is in
6 11 # the file COPYING, distributed as part of this software.
7 12 #-----------------------------------------------------------------------------
8 13
9 14 #-----------------------------------------------------------------------------
10 15 # Imports
11 16 #-----------------------------------------------------------------------------
12 17
13 18 import time
14 19
15 20 from zmq import MessageTracker
16 21
17 22 from IPython.external.decorator import decorator
18 23 from IPython.parallel import error
19 24
20 25 #-----------------------------------------------------------------------------
21 26 # Classes
22 27 #-----------------------------------------------------------------------------
23 28
24 29 # global empty tracker that's always done:
25 30 finished_tracker = MessageTracker()
26 31
27 32 @decorator
28 33 def check_ready(f, self, *args, **kwargs):
29 34 """Call spin() to sync state prior to calling the method."""
30 35 self.wait(0)
31 36 if not self._ready:
32 37 raise error.TimeoutError("result not ready")
33 38 return f(self, *args, **kwargs)
34 39
35 40 class AsyncResult(object):
36 41 """Class for representing results of non-blocking calls.
37 42
38 43 Provides the same interface as :py:class:`multiprocessing.pool.AsyncResult`.
39 44 """
40 45
41 46 msg_ids = None
42 47 _targets = None
43 48 _tracker = None
44 49 _single_result = False
45 50
46 51 def __init__(self, client, msg_ids, fname='unknown', targets=None, tracker=None):
47 52 if isinstance(msg_ids, basestring):
48 53 # always a list
49 54 msg_ids = [msg_ids]
50 55 if tracker is None:
51 56 # default to always done
52 57 tracker = finished_tracker
53 58 self._client = client
54 59 self.msg_ids = msg_ids
55 60 self._fname=fname
56 61 self._targets = targets
57 62 self._tracker = tracker
58 63 self._ready = False
59 64 self._success = None
60 65 if len(msg_ids) == 1:
61 66 self._single_result = not isinstance(targets, (list, tuple))
62 67 else:
63 68 self._single_result = False
64 69
65 70 def __repr__(self):
66 71 if self._ready:
67 72 return "<%s: finished>"%(self.__class__.__name__)
68 73 else:
69 74 return "<%s: %s>"%(self.__class__.__name__,self._fname)
70 75
71 76
72 77 def _reconstruct_result(self, res):
73 78 """Reconstruct our result from actual result list (always a list)
74 79
75 80 Override me in subclasses for turning a list of results
76 81 into the expected form.
77 82 """
78 83 if self._single_result:
79 84 return res[0]
80 85 else:
81 86 return res
82 87
83 88 def get(self, timeout=-1):
84 89 """Return the result when it arrives.
85 90
86 91 If `timeout` is not ``None`` and the result does not arrive within
87 92 `timeout` seconds then ``TimeoutError`` is raised. If the
88 93 remote call raised an exception then that exception will be reraised
89 94 by get() inside a `RemoteError`.
90 95 """
91 96 if not self.ready():
92 97 self.wait(timeout)
93 98
94 99 if self._ready:
95 100 if self._success:
96 101 return self._result
97 102 else:
98 103 raise self._exception
99 104 else:
100 105 raise error.TimeoutError("Result not ready.")
101 106
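# A hedged usage sketch (`view` is a hypothetical View instance):
#
# ar = view.apply_async(lambda x: 2*x, 21)
# try:
#     result = ar.get(timeout=5)  # TimeoutError if not done within 5s
# except error.TimeoutError:
#     print "still pending:", ar.msg_ids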
102 107 def ready(self):
103 108 """Return whether the call has completed."""
104 109 if not self._ready:
105 110 self.wait(0)
106 111 return self._ready
107 112
108 113 def wait(self, timeout=-1):
109 114 """Wait until the result is available or until `timeout` seconds pass.
110 115
111 116 This method always returns None.
112 117 """
113 118 if self._ready:
114 119 return
115 120 self._ready = self._client.wait(self.msg_ids, timeout)
116 121 if self._ready:
117 122 try:
118 123 results = map(self._client.results.get, self.msg_ids)
119 124 self._result = results
120 125 if self._single_result:
121 126 r = results[0]
122 127 if isinstance(r, Exception):
123 128 raise r
124 129 else:
125 130 results = error.collect_exceptions(results, self._fname)
126 131 self._result = self._reconstruct_result(results)
127 132 except Exception, e:
128 133 self._exception = e
129 134 self._success = False
130 135 else:
131 136 self._success = True
132 137 finally:
133 138 self._metadata = map(self._client.metadata.get, self.msg_ids)
134 139
135 140
136 141 def successful(self):
137 142 """Return whether the call completed without raising an exception.
138 143
139 144 Will raise ``AssertionError`` if the result is not ready.
140 145 """
141 146 assert self.ready()
142 147 return self._success
143 148
144 149 #----------------------------------------------------------------
145 150 # Extra methods not in mp.pool.AsyncResult
146 151 #----------------------------------------------------------------
147 152
148 153 def get_dict(self, timeout=-1):
149 154 """Get the results as a dict, keyed by engine_id.
150 155
151 156 timeout behavior is described in `get()`.
152 157 """
153 158
154 159 results = self.get(timeout)
155 160 engine_ids = [ md['engine_id'] for md in self._metadata ]
156 161 bycount = sorted(engine_ids, key=lambda k: engine_ids.count(k))
157 162 maxcount = bycount.count(bycount[-1])
158 163 if maxcount > 1:
159 164 raise ValueError("Cannot build dict, %i jobs ran on engine #%i"%(
160 165 maxcount, bycount[-1]))
161 166
162 167 return dict(zip(engine_ids,results))
163 168
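# A hedged usage sketch (`dview` is a hypothetical DirectView): one result
# per engine, keyed by engine_id.
#
# import os  # needed on the engines as well
# ar = dview.apply_async(os.getpid)
# pids = ar.get_dict()  # e.g. {0: 1234, 1: 1235}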
164 169 @property
165 170 def result(self):
166 171 """result property wrapper for `get(timeout=0)`."""
167 172 return self.get()
168 173
169 174 # abbreviated alias:
170 175 r = result
171 176
172 177 @property
173 178 @check_ready
174 179 def metadata(self):
175 180 """property for accessing execution metadata."""
176 181 if self._single_result:
177 182 return self._metadata[0]
178 183 else:
179 184 return self._metadata
180 185
181 186 @property
182 187 def result_dict(self):
183 188 """result property as a dict."""
184 189 return self.get_dict()
185 190
186 191 def __dict__(self):
187 192 return self.get_dict(0)
188 193
189 194 def abort(self):
190 195 """abort my tasks."""
191 196 assert not self.ready(), "Can't abort, I am already done!"
192 197         return self._client.abort(self.msg_ids, targets=self._targets, block=True)
193 198
194 199 @property
195 200 def sent(self):
196 201 """check whether my messages have been sent."""
197 202 return self._tracker.done
198 203
199 204 def wait_for_send(self, timeout=-1):
200 205 """wait for pyzmq send to complete.
201 206
202 207 This is necessary when sending arrays that you intend to edit in-place.
203 208 `timeout` is in seconds, and will raise TimeoutError if it is reached
204 209 before the send completes.
205 210 """
206 211 return self._tracker.wait(timeout)
207 212
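# A hedged usage sketch (hypothetical names): guard in-place edits on a
# tracked send.
#
# ar = view.apply_async(f, big_array, track=True)
# ar.wait_for_send(timeout=10)  # safe to mutate big_array after this returns
# big_array[:] = 0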
208 213 #-------------------------------------
209 214 # dict-access
210 215 #-------------------------------------
211 216
212 217 @check_ready
213 218 def __getitem__(self, key):
214 219 """getitem returns result value(s) if keyed by int/slice, or metadata if key is str.
215 220 """
216 221 if isinstance(key, int):
217 222 return error.collect_exceptions([self._result[key]], self._fname)[0]
218 223 elif isinstance(key, slice):
219 224 return error.collect_exceptions(self._result[key], self._fname)
220 225 elif isinstance(key, basestring):
221 226 values = [ md[key] for md in self._metadata ]
222 227 if self._single_result:
223 228 return values[0]
224 229 else:
225 230 return values
226 231 else:
227 232 raise TypeError("Invalid key type %r, must be 'int','slice', or 'str'"%type(key))
228 233
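# A hedged usage sketch (`ar` is a hypothetical finished AsyncResult):
# int/slice keys index into results, str keys index into metadata.
#
# ar[0]         # first result
# ar[:2]        # first two results
# ar['stdout']  # captured stdout, one entry per task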
229 234 @check_ready
230 235 def __getattr__(self, key):
231 236 """getattr maps to getitem for convenient attr access to metadata."""
232 237 if key not in self._metadata[0].keys():
233 238 raise AttributeError("%r object has no attribute %r"%(
234 239 self.__class__.__name__, key))
235 240 return self.__getitem__(key)
236 241
237 242 # asynchronous iterator:
238 243 def __iter__(self):
239 244 if self._single_result:
240 245 raise TypeError("AsyncResults with a single result are not iterable.")
241 246 try:
242 247 rlist = self.get(0)
243 248 except error.TimeoutError:
244 249 # wait for each result individually
245 250 for msg_id in self.msg_ids:
246 251 ar = AsyncResult(self._client, msg_id, self._fname)
247 252 yield ar.get()
248 253 else:
249 254 # already done
250 255 for r in rlist:
251 256 yield r
252 257
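# A hedged usage sketch (`dview` is a hypothetical multi-engine DirectView):
# iterating yields one result per msg_id, waiting for each in turn.
#
# import socket  # needed on the engines as well
# ar = dview.apply_async(socket.gethostname)
# for hostname in ar:
#     print hostname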
253 258
254 259
255 260 class AsyncMapResult(AsyncResult):
256 261 """Class for representing results of non-blocking gathers.
257 262
258 263 This will properly reconstruct the gather.
259 264 """
260 265
261 266 def __init__(self, client, msg_ids, mapObject, fname=''):
262 267 AsyncResult.__init__(self, client, msg_ids, fname=fname)
263 268 self._mapObject = mapObject
264 269 self._single_result = False
265 270
266 271 def _reconstruct_result(self, res):
267 272 """Perform the gather on the actual results."""
268 273 return self._mapObject.joinPartitions(res)
269 274
270 275 # asynchronous iterator:
271 276 def __iter__(self):
272 277 try:
273 278 rlist = self.get(0)
274 279 except error.TimeoutError:
275 280 # wait for each result individually
276 281 for msg_id in self.msg_ids:
277 282 ar = AsyncResult(self._client, msg_id, self._fname)
278 283 rlist = ar.get()
279 284 try:
280 285 for r in rlist:
281 286 yield r
282 287 except TypeError:
283 288 # flattened, not a list
284 289 # this could get broken by flattened data that returns iterables
285 290 # but most calls to map do not expose the `flatten` argument
286 291 yield rlist
287 292 else:
288 293 # already done
289 294 for r in rlist:
290 295 yield r
291 296
292 297
293 298 class AsyncHubResult(AsyncResult):
294 299 """Class to wrap pending results that must be requested from the Hub.
295 300
296 301     Note that waiting/polling on these objects requires polling the Hub over the network,
297 302 so use `AsyncHubResult.wait()` sparingly.
298 303 """
299 304
300 305 def wait(self, timeout=-1):
301 306 """wait for result to complete."""
302 307 start = time.time()
303 308 if self._ready:
304 309 return
305 310 local_ids = filter(lambda msg_id: msg_id in self._client.outstanding, self.msg_ids)
306 311 local_ready = self._client.wait(local_ids, timeout)
307 312 if local_ready:
308 313 remote_ids = filter(lambda msg_id: msg_id not in self._client.results, self.msg_ids)
309 314 if not remote_ids:
310 315 self._ready = True
311 316 else:
312 317 rdict = self._client.result_status(remote_ids, status_only=False)
313 318 pending = rdict['pending']
314 319 while pending and (timeout < 0 or time.time() < start+timeout):
315 320 rdict = self._client.result_status(remote_ids, status_only=False)
316 321 pending = rdict['pending']
317 322 if pending:
318 323 time.sleep(0.1)
319 324 if not pending:
320 325 self._ready = True
321 326 if self._ready:
322 327 try:
323 328 results = map(self._client.results.get, self.msg_ids)
324 329 self._result = results
325 330 if self._single_result:
326 331 r = results[0]
327 332 if isinstance(r, Exception):
328 333 raise r
329 334 else:
330 335 results = error.collect_exceptions(results, self._fname)
331 336 self._result = self._reconstruct_result(results)
332 337 except Exception, e:
333 338 self._exception = e
334 339 self._success = False
335 340 else:
336 341 self._success = True
337 342 finally:
338 343 self._metadata = map(self._client.metadata.get, self.msg_ids)
339 344
340 345 __all__ = ['AsyncResult', 'AsyncMapResult', 'AsyncHubResult']
\ No newline at end of file
@@ -1,1368 +1,1373 b''
1 """A semi-synchronous Client for the ZMQ cluster"""
1 """A semi-synchronous Client for the ZMQ cluster
2
3 Authors:
4
5 * MinRK
6 """
2 7 #-----------------------------------------------------------------------------
3 # Copyright (C) 2010 The IPython Development Team
8 # Copyright (C) 2010-2011 The IPython Development Team
4 9 #
5 10 # Distributed under the terms of the BSD License. The full license is in
6 11 # the file COPYING, distributed as part of this software.
7 12 #-----------------------------------------------------------------------------
8 13
9 14 #-----------------------------------------------------------------------------
10 15 # Imports
11 16 #-----------------------------------------------------------------------------
12 17
13 18 import os
14 19 import json
15 20 import time
16 21 import warnings
17 22 from datetime import datetime
18 23 from getpass import getpass
19 24 from pprint import pprint
20 25
21 26 pjoin = os.path.join
22 27
23 28 import zmq
24 29 # from zmq.eventloop import ioloop, zmqstream
25 30
26 31 from IPython.utils.path import get_ipython_dir
27 32 from IPython.utils.traitlets import (HasTraits, Int, Instance, Unicode,
28 33 Dict, List, Bool, Set)
29 34 from IPython.external.decorator import decorator
30 35 from IPython.external.ssh import tunnel
31 36
32 37 from IPython.parallel import error
33 38 from IPython.parallel import util
34 39
35 40 from IPython.zmq.session import Session, Message
36 41
37 42 from .asyncresult import AsyncResult, AsyncHubResult
38 43 from IPython.core.newapplication import ProfileDir, ProfileDirError
39 44 from .view import DirectView, LoadBalancedView
40 45
41 46 #--------------------------------------------------------------------------
42 47 # Decorators for Client methods
43 48 #--------------------------------------------------------------------------
44 49
45 50 @decorator
46 51 def spin_first(f, self, *args, **kwargs):
47 52 """Call spin() to sync state prior to calling the method."""
48 53 self.spin()
49 54 return f(self, *args, **kwargs)
50 55
51 56
52 57 #--------------------------------------------------------------------------
53 58 # Classes
54 59 #--------------------------------------------------------------------------
55 60
56 61 class Metadata(dict):
57 62 """Subclass of dict for initializing metadata values.
58 63
59 64 Attribute access works on keys.
60 65
61 66 These objects have a strict set of keys - errors will raise if you try
62 67 to add new keys.
63 68 """
64 69 def __init__(self, *args, **kwargs):
65 70 dict.__init__(self)
66 71 md = {'msg_id' : None,
67 72 'submitted' : None,
68 73 'started' : None,
69 74 'completed' : None,
70 75 'received' : None,
71 76 'engine_uuid' : None,
72 77 'engine_id' : None,
73 78 'follow' : None,
74 79 'after' : None,
75 80 'status' : None,
76 81
77 82 'pyin' : None,
78 83 'pyout' : None,
79 84 'pyerr' : None,
80 85 'stdout' : '',
81 86 'stderr' : '',
82 87 }
83 88 self.update(md)
84 89 self.update(dict(*args, **kwargs))
85 90
86 91 def __getattr__(self, key):
87 92 """getattr aliased to getitem"""
88 93 if key in self.iterkeys():
89 94 return self[key]
90 95 else:
91 96 raise AttributeError(key)
92 97
93 98 def __setattr__(self, key, value):
94 99 """setattr aliased to setitem, with strict"""
95 100 if key in self.iterkeys():
96 101 self[key] = value
97 102 else:
98 103 raise AttributeError(key)
99 104
100 105 def __setitem__(self, key, value):
101 106 """strict static key enforcement"""
102 107 if key in self.iterkeys():
103 108 dict.__setitem__(self, key, value)
104 109 else:
105 110 raise KeyError(key)
106 111
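# A hedged usage sketch (hypothetical values): attribute access works on the
# fixed key set only.
#
# md = Metadata(msg_id='abc123')
# md.status = 'ok'  # known key: same as md['status'] = 'ok'
# md.engine_id      # -> None until filled in
# md.foo = 1        # raises AttributeError: the key set is fixed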
107 112
108 113 class Client(HasTraits):
109 114 """A semi-synchronous client to the IPython ZMQ cluster
110 115
111 116 Parameters
112 117 ----------
113 118
114 119 url_or_file : bytes; zmq url or path to ipcontroller-client.json
115 120 Connection information for the Hub's registration. If a json connector
116 121 file is given, then likely no further configuration is necessary.
117 122 [Default: use profile]
118 123 profile : bytes
119 124 The name of the Cluster profile to be used to find connector information.
120 125 [Default: 'default']
121 126 context : zmq.Context
122 127 Pass an existing zmq.Context instance, otherwise the client will create its own.
123 128 debug : bool
124 129 flag for lots of message printing for debug purposes
125 130 timeout : int/float
126 131 time (in seconds) to wait for connection replies from the Hub
127 132 [Default: 10]
128 133
129 134 #-------------- session related args ----------------
130 135
131 136 config : Config object
132 137 If specified, this will be relayed to the Session for configuration
133 138 username : str
134 139 set username for the session object
135 140 packer : str (import_string) or callable
136 141 Can be either the simple keyword 'json' or 'pickle', or an import_string to a
137 142 function to serialize messages. Must support same input as
138 143 JSON, and output must be bytes.
139 144 You can pass a callable directly as `pack`
140 145 unpacker : str (import_string) or callable
141 146 The inverse of packer. Only necessary if packer is specified as *not* one
142 147 of 'json' or 'pickle'.
143 148
144 149 #-------------- ssh related args ----------------
145 150 # These are args for configuring the ssh tunnel to be used
146 151 # credentials are used to forward connections over ssh to the Controller
147 152 # Note that the ip given in `addr` needs to be relative to sshserver
148 153 # The most basic case is to leave addr as pointing to localhost (127.0.0.1),
149 154 # and set sshserver as the same machine the Controller is on. However,
150 155 # the only requirement is that sshserver is able to see the Controller
151 156 # (i.e. is within the same trusted network).
152 157
153 158 sshserver : str
154 159 A string of the form passed to ssh, i.e. 'server.tld' or 'user@server.tld:port'
155 160 If keyfile or password is specified, and this is not, it will default to
156 161 the ip given in addr.
157 162 sshkey : str; path to public ssh key file
158 163 This specifies a key to be used in ssh login, default None.
159 164 Regular default ssh keys will be used without specifying this argument.
160 165 password : str
161 166 Your ssh password to sshserver. Note that if this is left None,
162 167 you will be prompted for it if passwordless key based login is unavailable.
163 168 paramiko : bool
164 169 flag for whether to use paramiko instead of shell ssh for tunneling.
165 170 [default: True on win32, False else]
166 171
167 172 ------- exec authentication args -------
168 173 If even localhost is untrusted, you can have some protection against
169 174 unauthorized execution by signing messages with HMAC digests.
170 175 Messages are still sent as cleartext, so if someone can snoop your
171 176 loopback traffic this will not protect your privacy, but will prevent
172 177 unauthorized execution.
173 178
174 179 exec_key : str
175 180 an authentication key or file containing a key
176 181 default: None
177 182
178 183
179 184 Attributes
180 185 ----------
181 186
182 187 ids : list of int engine IDs
183 188 requesting the ids attribute always synchronizes
184 189 the registration state. To request ids without synchronization,
185 190         use the semi-private _ids attribute.
186 191
187 192 history : list of msg_ids
188 193 a list of msg_ids, keeping track of all the execution
189 194 messages you have submitted in order.
190 195
191 196 outstanding : set of msg_ids
192 197 a set of msg_ids that have been submitted, but whose
193 198 results have not yet been received.
194 199
195 200 results : dict
196 201 a dict of all our results, keyed by msg_id
197 202
198 203 block : bool
199 204 determines default behavior when block not specified
200 205 in execution methods
201 206
202 207 Methods
203 208 -------
204 209
205 210 spin
206 211 flushes incoming results and registration state changes
207 212 control methods spin, and requesting `ids` also ensures up to date
208 213
209 214 wait
210 215 wait on one or more msg_ids
211 216
212 217 execution methods
213 218 apply
214 219 legacy: execute, run
215 220
216 221 data movement
217 222 push, pull, scatter, gather
218 223
219 224 query methods
220 225 queue_status, get_result, purge, result_status
221 226
222 227 control methods
223 228 abort, shutdown
224 229
225 230 """
226 231
227 232
228 233 block = Bool(False)
229 234 outstanding = Set()
230 235 results = Instance('collections.defaultdict', (dict,))
231 236 metadata = Instance('collections.defaultdict', (Metadata,))
232 237 history = List()
233 238 debug = Bool(False)
234 239 profile=Unicode('default')
235 240
236 241 _outstanding_dict = Instance('collections.defaultdict', (set,))
237 242 _ids = List()
238 243 _connected=Bool(False)
239 244 _ssh=Bool(False)
240 245 _context = Instance('zmq.Context')
241 246 _config = Dict()
242 247 _engines=Instance(util.ReverseDict, (), {})
243 248 # _hub_socket=Instance('zmq.Socket')
244 249 _query_socket=Instance('zmq.Socket')
245 250 _control_socket=Instance('zmq.Socket')
246 251 _iopub_socket=Instance('zmq.Socket')
247 252 _notification_socket=Instance('zmq.Socket')
248 253 _mux_socket=Instance('zmq.Socket')
249 254 _task_socket=Instance('zmq.Socket')
250 255 _task_scheme=Unicode()
251 256 _closed = False
252 257 _ignored_control_replies=Int(0)
253 258 _ignored_hub_replies=Int(0)
254 259
255 260 def __init__(self, url_or_file=None, profile='default', profile_dir=None, ipython_dir=None,
256 261 context=None, debug=False, exec_key=None,
257 262 sshserver=None, sshkey=None, password=None, paramiko=None,
258 263 timeout=10, **extra_args
259 264 ):
260 265 super(Client, self).__init__(debug=debug, profile=profile)
261 266 if context is None:
262 267 context = zmq.Context.instance()
263 268 self._context = context
264 269
265 270
266 271 self._setup_profile_dir(profile, profile_dir, ipython_dir)
267 272 if self._cd is not None:
268 273 if url_or_file is None:
269 274 url_or_file = pjoin(self._cd.security_dir, 'ipcontroller-client.json')
270 275 assert url_or_file is not None, "I can't find enough information to connect to a hub!"\
271 276 " Please specify at least one of url_or_file or profile."
272 277
273 278 try:
274 279 util.validate_url(url_or_file)
275 280 except AssertionError:
276 281 if not os.path.exists(url_or_file):
277 282 if self._cd:
278 283 url_or_file = os.path.join(self._cd.security_dir, url_or_file)
279 284 assert os.path.exists(url_or_file), "Not a valid connection file or url: %r"%url_or_file
280 285 with open(url_or_file) as f:
281 286 cfg = json.loads(f.read())
282 287 else:
283 288 cfg = {'url':url_or_file}
284 289
285 290 # sync defaults from args, json:
286 291 if sshserver:
287 292 cfg['ssh'] = sshserver
288 293 if exec_key:
289 294 cfg['exec_key'] = exec_key
290 295 exec_key = cfg['exec_key']
291 296 sshserver=cfg['ssh']
292 297 url = cfg['url']
293 298 location = cfg.setdefault('location', None)
294 299 cfg['url'] = util.disambiguate_url(cfg['url'], location)
295 300 url = cfg['url']
296 301
297 302 self._config = cfg
298 303
299 304 self._ssh = bool(sshserver or sshkey or password)
300 305 if self._ssh and sshserver is None:
301 306 # default to ssh via localhost
302 307 sshserver = url.split('://')[1].split(':')[0]
303 308 if self._ssh and password is None:
304 309 if tunnel.try_passwordless_ssh(sshserver, sshkey, paramiko):
305 310 password=False
306 311 else:
307 312 password = getpass("SSH Password for %s: "%sshserver)
308 313 ssh_kwargs = dict(keyfile=sshkey, password=password, paramiko=paramiko)
309 314
310 315 # configure and construct the session
311 316 if exec_key is not None:
312 317 if os.path.isfile(exec_key):
313 318 extra_args['keyfile'] = exec_key
314 319 else:
315 320 extra_args['key'] = exec_key
316 321 self.session = Session(**extra_args)
317 322
318 323 self._query_socket = self._context.socket(zmq.XREQ)
319 324 self._query_socket.setsockopt(zmq.IDENTITY, self.session.session)
320 325 if self._ssh:
321 326 tunnel.tunnel_connection(self._query_socket, url, sshserver, **ssh_kwargs)
322 327 else:
323 328 self._query_socket.connect(url)
324 329
325 330 self.session.debug = self.debug
326 331
327 332 self._notification_handlers = {'registration_notification' : self._register_engine,
328 333 'unregistration_notification' : self._unregister_engine,
329 334 'shutdown_notification' : lambda msg: self.close(),
330 335 }
331 336 self._queue_handlers = {'execute_reply' : self._handle_execute_reply,
332 337 'apply_reply' : self._handle_apply_reply}
333 338 self._connect(sshserver, ssh_kwargs, timeout)
334 339
335 340 def __del__(self):
336 341 """cleanup sockets, but _not_ context."""
337 342 self.close()
338 343
339 344 def _setup_profile_dir(self, profile, profile_dir, ipython_dir):
340 345 if ipython_dir is None:
341 346 ipython_dir = get_ipython_dir()
342 347 if profile_dir is not None:
343 348 try:
344 349 self._cd = ProfileDir.find_profile_dir(profile_dir)
345 350 return
346 351 except ProfileDirError:
347 352 pass
348 353 elif profile is not None:
349 354 try:
350 355 self._cd = ProfileDir.find_profile_dir_by_name(
351 356 ipython_dir, profile)
352 357 return
353 358 except ProfileDirError:
354 359 pass
355 360 self._cd = None
356 361
357 362 def _update_engines(self, engines):
358 363 """Update our engines dict and _ids from a dict of the form: {id:uuid}."""
359 364 for k,v in engines.iteritems():
360 365 eid = int(k)
361 366 self._engines[eid] = bytes(v) # force not unicode
362 367 self._ids.append(eid)
363 368 self._ids = sorted(self._ids)
364 369 if sorted(self._engines.keys()) != range(len(self._engines)) and \
365 370 self._task_scheme == 'pure' and self._task_socket:
366 371 self._stop_scheduling_tasks()
367 372
368 373 def _stop_scheduling_tasks(self):
369 374 """Stop scheduling tasks because an engine has been unregistered
370 375 from a pure ZMQ scheduler.
371 376 """
372 377 self._task_socket.close()
373 378 self._task_socket = None
374 379 msg = "An engine has been unregistered, and we are using pure " +\
375 380 "ZMQ task scheduling. Task farming will be disabled."
376 381 if self.outstanding:
377 382 msg += " If you were running tasks when this happened, " +\
378 383 "some `outstanding` msg_ids may never resolve."
379 384 warnings.warn(msg, RuntimeWarning)
380 385
381 386 def _build_targets(self, targets):
382 387 """Turn valid target IDs or 'all' into two lists:
383 388 (int_ids, uuids).
384 389 """
385 390 if not self._ids:
386 391 # flush notification socket if no engines yet, just in case
387 392 if not self.ids:
388 393 raise error.NoEnginesRegistered("Can't build targets without any engines")
389 394
390 395 if targets is None:
391 396 targets = self._ids
392 397 elif isinstance(targets, str):
393 398 if targets.lower() == 'all':
394 399 targets = self._ids
395 400 else:
396 401 raise TypeError("%r not valid str target, must be 'all'"%(targets))
397 402 elif isinstance(targets, int):
398 403 if targets < 0:
399 404 targets = self.ids[targets]
400 405 if targets not in self._ids:
401 406 raise IndexError("No such engine: %i"%targets)
402 407 targets = [targets]
403 408
404 409 if isinstance(targets, slice):
405 410 indices = range(len(self._ids))[targets]
406 411 ids = self.ids
407 412 targets = [ ids[i] for i in indices ]
408 413
409 414 if not isinstance(targets, (tuple, list, xrange)):
410 415 raise TypeError("targets by int/slice/collection of ints only, not %s"%(type(targets)))
411 416
412 417 return [self._engines[t] for t in targets], list(targets)
413 418
414 419 def _connect(self, sshserver, ssh_kwargs, timeout):
415 420 """setup all our socket connections to the cluster. This is called from
416 421 __init__."""
417 422
418 423 # Maybe allow reconnecting?
419 424 if self._connected:
420 425 return
421 426 self._connected=True
422 427
423 428 def connect_socket(s, url):
424 429 url = util.disambiguate_url(url, self._config['location'])
425 430 if self._ssh:
426 431 return tunnel.tunnel_connection(s, url, sshserver, **ssh_kwargs)
427 432 else:
428 433 return s.connect(url)
429 434
430 435 self.session.send(self._query_socket, 'connection_request')
431 436 r,w,x = zmq.select([self._query_socket],[],[], timeout)
432 437 if not r:
433 438 raise error.TimeoutError("Hub connection request timed out")
434 439 idents,msg = self.session.recv(self._query_socket,mode=0)
435 440 if self.debug:
436 441 pprint(msg)
437 442 msg = Message(msg)
438 443 content = msg.content
439 444 self._config['registration'] = dict(content)
440 445 if content.status == 'ok':
441 446 if content.mux:
442 447 self._mux_socket = self._context.socket(zmq.XREQ)
443 448 self._mux_socket.setsockopt(zmq.IDENTITY, self.session.session)
444 449 connect_socket(self._mux_socket, content.mux)
445 450 if content.task:
446 451 self._task_scheme, task_addr = content.task
447 452 self._task_socket = self._context.socket(zmq.XREQ)
448 453 self._task_socket.setsockopt(zmq.IDENTITY, self.session.session)
449 454 connect_socket(self._task_socket, task_addr)
450 455 if content.notification:
451 456 self._notification_socket = self._context.socket(zmq.SUB)
452 457 connect_socket(self._notification_socket, content.notification)
453 458 self._notification_socket.setsockopt(zmq.SUBSCRIBE, b'')
454 459 # if content.query:
455 460 # self._query_socket = self._context.socket(zmq.XREQ)
456 461 # self._query_socket.setsockopt(zmq.IDENTITY, self.session.session)
457 462 # connect_socket(self._query_socket, content.query)
458 463 if content.control:
459 464 self._control_socket = self._context.socket(zmq.XREQ)
460 465 self._control_socket.setsockopt(zmq.IDENTITY, self.session.session)
461 466 connect_socket(self._control_socket, content.control)
462 467 if content.iopub:
463 468 self._iopub_socket = self._context.socket(zmq.SUB)
464 469 self._iopub_socket.setsockopt(zmq.SUBSCRIBE, b'')
465 470 self._iopub_socket.setsockopt(zmq.IDENTITY, self.session.session)
466 471 connect_socket(self._iopub_socket, content.iopub)
467 472 self._update_engines(dict(content.engines))
468 473 else:
469 474 self._connected = False
470 475 raise Exception("Failed to connect!")
471 476
472 477 #--------------------------------------------------------------------------
473 478 # handlers and callbacks for incoming messages
474 479 #--------------------------------------------------------------------------
475 480
476 481 def _unwrap_exception(self, content):
477 482 """unwrap exception, and remap engine_id to int."""
478 483 e = error.unwrap_exception(content)
479 484 # print e.traceback
480 485 if e.engine_info:
481 486 e_uuid = e.engine_info['engine_uuid']
482 487 eid = self._engines[e_uuid]
483 488 e.engine_info['engine_id'] = eid
484 489 return e
485 490
486 491 def _extract_metadata(self, header, parent, content):
487 492 md = {'msg_id' : parent['msg_id'],
488 493 'received' : datetime.now(),
489 494 'engine_uuid' : header.get('engine', None),
490 495 'follow' : parent.get('follow', []),
491 496 'after' : parent.get('after', []),
492 497 'status' : content['status'],
493 498 }
494 499
495 500 if md['engine_uuid'] is not None:
496 501 md['engine_id'] = self._engines.get(md['engine_uuid'], None)
497 502
498 503 if 'date' in parent:
499 504 md['submitted'] = parent['date']
500 505 if 'started' in header:
501 506 md['started'] = header['started']
502 507 if 'date' in header:
503 508 md['completed'] = header['date']
504 509 return md
505 510
506 511 def _register_engine(self, msg):
507 512 """Register a new engine, and update our connection info."""
508 513 content = msg['content']
509 514 eid = content['id']
510 515 d = {eid : content['queue']}
511 516 self._update_engines(d)
512 517
513 518 def _unregister_engine(self, msg):
514 519 """Unregister an engine that has died."""
515 520 content = msg['content']
516 521 eid = int(content['id'])
517 522 if eid in self._ids:
518 523 self._ids.remove(eid)
519 524 uuid = self._engines.pop(eid)
520 525
521 526 self._handle_stranded_msgs(eid, uuid)
522 527
523 528 if self._task_socket and self._task_scheme == 'pure':
524 529 self._stop_scheduling_tasks()
525 530
526 531 def _handle_stranded_msgs(self, eid, uuid):
527 532 """Handle messages known to be on an engine when the engine unregisters.
528 533
529 534 It is possible that this will fire prematurely - that is, an engine will
530 535 go down after completing a result, and the client will be notified
531 536 of the unregistration and later receive the successful result.
532 537 """
533 538
534 539 outstanding = self._outstanding_dict[uuid]
535 540
536 541 for msg_id in list(outstanding):
537 542 if msg_id in self.results:
538 543                 # we already have the result
539 544 continue
540 545 try:
541 546 raise error.EngineError("Engine %r died while running task %r"%(eid, msg_id))
542 547 except:
543 548 content = error.wrap_exception()
544 549 # build a fake message:
545 550 parent = {}
546 551 header = {}
547 552 parent['msg_id'] = msg_id
548 553 header['engine'] = uuid
549 554 header['date'] = datetime.now()
550 555 msg = dict(parent_header=parent, header=header, content=content)
551 556 self._handle_apply_reply(msg)
552 557
553 558 def _handle_execute_reply(self, msg):
554 559 """Save the reply to an execute_request into our results.
555 560
556 561 execute messages are never actually used. apply is used instead.
557 562 """
558 563
559 564 parent = msg['parent_header']
560 565 msg_id = parent['msg_id']
561 566 if msg_id not in self.outstanding:
562 567 if msg_id in self.history:
563 568 print ("got stale result: %s"%msg_id)
564 569 else:
565 570 print ("got unknown result: %s"%msg_id)
566 571 else:
567 572 self.outstanding.remove(msg_id)
568 573 self.results[msg_id] = self._unwrap_exception(msg['content'])
569 574
570 575 def _handle_apply_reply(self, msg):
571 576 """Save the reply to an apply_request into our results."""
572 577 parent = msg['parent_header']
573 578 msg_id = parent['msg_id']
574 579 if msg_id not in self.outstanding:
575 580 if msg_id in self.history:
576 581 print ("got stale result: %s"%msg_id)
577 582 print self.results[msg_id]
578 583 print msg
579 584 else:
580 585 print ("got unknown result: %s"%msg_id)
581 586 else:
582 587 self.outstanding.remove(msg_id)
583 588 content = msg['content']
584 589 header = msg['header']
585 590
586 591 # construct metadata:
587 592 md = self.metadata[msg_id]
588 593 md.update(self._extract_metadata(header, parent, content))
589 594 # is this redundant?
590 595 self.metadata[msg_id] = md
591 596
592 597 e_outstanding = self._outstanding_dict[md['engine_uuid']]
593 598 if msg_id in e_outstanding:
594 599 e_outstanding.remove(msg_id)
595 600
596 601 # construct result:
597 602 if content['status'] == 'ok':
598 603 self.results[msg_id] = util.unserialize_object(msg['buffers'])[0]
599 604 elif content['status'] == 'aborted':
600 605 self.results[msg_id] = error.TaskAborted(msg_id)
601 606 elif content['status'] == 'resubmitted':
602 607 # TODO: handle resubmission
603 608 pass
604 609 else:
605 610 self.results[msg_id] = self._unwrap_exception(content)
606 611
607 612 def _flush_notifications(self):
608 613 """Flush notifications of engine registrations waiting
609 614 in ZMQ queue."""
610 615 idents,msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
611 616 while msg is not None:
612 617 if self.debug:
613 618 pprint(msg)
614 619 msg_type = msg['msg_type']
615 620 handler = self._notification_handlers.get(msg_type, None)
616 621 if handler is None:
617 622                 raise Exception("Unhandled message type: %s"%msg_type)
618 623 else:
619 624 handler(msg)
620 625 idents,msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
621 626
622 627 def _flush_results(self, sock):
623 628 """Flush task or queue results waiting in ZMQ queue."""
624 629 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
625 630 while msg is not None:
626 631 if self.debug:
627 632 pprint(msg)
628 633 msg_type = msg['msg_type']
629 634 handler = self._queue_handlers.get(msg_type, None)
630 635 if handler is None:
631 636                 raise Exception("Unhandled message type: %s"%msg_type)
632 637 else:
633 638 handler(msg)
634 639 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
635 640
636 641 def _flush_control(self, sock):
637 642 """Flush replies from the control channel waiting
638 643 in the ZMQ queue.
639 644
640 645 Currently: ignore them."""
641 646 if self._ignored_control_replies <= 0:
642 647 return
643 648 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
644 649 while msg is not None:
645 650 self._ignored_control_replies -= 1
646 651 if self.debug:
647 652 pprint(msg)
648 653 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
649 654
650 655 def _flush_ignored_control(self):
651 656 """flush ignored control replies"""
652 657 while self._ignored_control_replies > 0:
653 658 self.session.recv(self._control_socket)
654 659 self._ignored_control_replies -= 1
655 660
656 661 def _flush_ignored_hub_replies(self):
657 662 ident,msg = self.session.recv(self._query_socket, mode=zmq.NOBLOCK)
658 663 while msg is not None:
659 664 ident,msg = self.session.recv(self._query_socket, mode=zmq.NOBLOCK)
660 665
661 666 def _flush_iopub(self, sock):
662 667 """Flush replies from the iopub channel waiting
663 668 in the ZMQ queue.
664 669 """
665 670 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
666 671 while msg is not None:
667 672 if self.debug:
668 673 pprint(msg)
669 674 parent = msg['parent_header']
670 675 msg_id = parent['msg_id']
671 676 content = msg['content']
672 677 header = msg['header']
673 678 msg_type = msg['msg_type']
674 679
675 680 # init metadata:
676 681 md = self.metadata[msg_id]
677 682
678 683 if msg_type == 'stream':
679 684 name = content['name']
680 685 s = md[name] or ''
681 686 md[name] = s + content['data']
682 687 elif msg_type == 'pyerr':
683 688 md.update({'pyerr' : self._unwrap_exception(content)})
684 689 elif msg_type == 'pyin':
685 690 md.update({'pyin' : content['code']})
686 691 else:
687 692 md.update({msg_type : content.get('data', '')})
688 693
689 694             # redundant?
690 695 self.metadata[msg_id] = md
691 696
692 697 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
693 698
694 699 #--------------------------------------------------------------------------
695 700 # len, getitem
696 701 #--------------------------------------------------------------------------
697 702
698 703 def __len__(self):
699 704 """len(client) returns # of engines."""
700 705 return len(self.ids)
701 706
702 707 def __getitem__(self, key):
703 708 """index access returns DirectView multiplexer objects
704 709
705 710 Must be int, slice, or list/tuple/xrange of ints"""
706 711 if not isinstance(key, (int, slice, tuple, list, xrange)):
707 712 raise TypeError("key by int/slice/iterable of ints only, not %s"%(type(key)))
708 713 else:
709 714 return self.direct_view(key)
710 715
711 716 #--------------------------------------------------------------------------
712 717 # Begin public methods
713 718 #--------------------------------------------------------------------------
714 719
715 720 @property
716 721 def ids(self):
717 722 """Always up-to-date ids property."""
718 723 self._flush_notifications()
719 724 # always copy:
720 725 return list(self._ids)
721 726
722 727 def close(self):
723 728 if self._closed:
724 729 return
725 730 snames = filter(lambda n: n.endswith('socket'), dir(self))
726 731 for socket in map(lambda name: getattr(self, name), snames):
727 732 if isinstance(socket, zmq.Socket) and not socket.closed:
728 733 socket.close()
729 734 self._closed = True
730 735
731 736 def spin(self):
732 737 """Flush any registration notifications and execution results
733 738 waiting in the ZMQ queue.
734 739 """
735 740 if self._notification_socket:
736 741 self._flush_notifications()
737 742 if self._mux_socket:
738 743 self._flush_results(self._mux_socket)
739 744 if self._task_socket:
740 745 self._flush_results(self._task_socket)
741 746 if self._control_socket:
742 747 self._flush_control(self._control_socket)
743 748 if self._iopub_socket:
744 749 self._flush_iopub(self._iopub_socket)
745 750 if self._query_socket:
746 751 self._flush_ignored_hub_replies()
747 752
748 753 def wait(self, jobs=None, timeout=-1):
749 754 """waits on one or more `jobs`, for up to `timeout` seconds.
750 755
751 756 Parameters
752 757 ----------
753 758
754 759 jobs : int, str, or list of ints and/or strs, or one or more AsyncResult objects
755 760 ints are indices to self.history
756 761 strs are msg_ids
757 762 default: wait on all outstanding messages
758 763 timeout : float
759 764 a time in seconds, after which to give up.
760 765 default is -1, which means no timeout
761 766
762 767 Returns
763 768 -------
764 769
765 770 True : when all msg_ids are done
766 771 False : timeout reached, some msg_ids still outstanding
767 772 """
768 773 tic = time.time()
769 774 if jobs is None:
770 775 theids = self.outstanding
771 776 else:
772 777 if isinstance(jobs, (int, str, AsyncResult)):
773 778 jobs = [jobs]
774 779 theids = set()
775 780 for job in jobs:
776 781 if isinstance(job, int):
777 782 # index access
778 783 job = self.history[job]
779 784 elif isinstance(job, AsyncResult):
780 785 map(theids.add, job.msg_ids)
781 786 continue
782 787 theids.add(job)
783 788 if not theids.intersection(self.outstanding):
784 789 return True
785 790 self.spin()
786 791 while theids.intersection(self.outstanding):
787 792 if timeout >= 0 and ( time.time()-tic ) > timeout:
788 793 break
789 794 time.sleep(1e-3)
790 795 self.spin()
791 796 return len(theids.intersection(self.outstanding)) == 0
792 797
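# A hedged usage sketch (`rc` is a hypothetical connected Client, `view`/`f`
# are placeholders): block on a batch of submissions with a cap.
#
# ars = [ view.apply_async(f, i) for i in range(10) ]
# if not rc.wait(ars, timeout=30):
#     print "%i tasks still outstanding" % len(rc.outstanding)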
793 798 #--------------------------------------------------------------------------
794 799 # Control methods
795 800 #--------------------------------------------------------------------------
796 801
797 802 @spin_first
798 803 def clear(self, targets=None, block=None):
799 804 """Clear the namespace in target(s)."""
800 805 block = self.block if block is None else block
801 806 targets = self._build_targets(targets)[0]
802 807 for t in targets:
803 808 self.session.send(self._control_socket, 'clear_request', content={}, ident=t)
804 809 error = False
805 810 if block:
806 811 self._flush_ignored_control()
807 812 for i in range(len(targets)):
808 813 idents,msg = self.session.recv(self._control_socket,0)
809 814 if self.debug:
810 815 pprint(msg)
811 816 if msg['content']['status'] != 'ok':
812 817 error = self._unwrap_exception(msg['content'])
813 818 else:
814 819 self._ignored_control_replies += len(targets)
815 820 if error:
816 821 raise error
817 822
818 823
819 824 @spin_first
820 825 def abort(self, jobs=None, targets=None, block=None):
821 826 """Abort specific jobs from the execution queues of target(s).
822 827
823 828 This is a mechanism to prevent jobs that have already been submitted
824 829 from executing.
825 830
826 831 Parameters
827 832 ----------
828 833
829 834 jobs : msg_id, list of msg_ids, or AsyncResult
830 835 The jobs to be aborted
831 836
832 837
833 838 """
834 839 block = self.block if block is None else block
835 840 targets = self._build_targets(targets)[0]
836 841 msg_ids = []
837 842 if isinstance(jobs, (basestring,AsyncResult)):
838 843 jobs = [jobs]
839 844 bad_ids = filter(lambda obj: not isinstance(obj, (basestring, AsyncResult)), jobs)
840 845 if bad_ids:
841 846 raise TypeError("Invalid msg_id type %r, expected str or AsyncResult"%bad_ids[0])
842 847 for j in jobs:
843 848 if isinstance(j, AsyncResult):
844 849 msg_ids.extend(j.msg_ids)
845 850 else:
846 851 msg_ids.append(j)
847 852 content = dict(msg_ids=msg_ids)
848 853 for t in targets:
849 854 self.session.send(self._control_socket, 'abort_request',
850 855 content=content, ident=t)
851 856 error = False
852 857 if block:
853 858 self._flush_ignored_control()
854 859 for i in range(len(targets)):
855 860 idents,msg = self.session.recv(self._control_socket,0)
856 861 if self.debug:
857 862 pprint(msg)
858 863 if msg['content']['status'] != 'ok':
859 864 error = self._unwrap_exception(msg['content'])
860 865 else:
861 866 self._ignored_control_replies += len(targets)
862 867 if error:
863 868 raise error
864 869
865 870 @spin_first
866 871 def shutdown(self, targets=None, restart=False, hub=False, block=None):
867 872 """Terminates one or more engine processes, optionally including the hub."""
868 873 block = self.block if block is None else block
869 874 if hub:
870 875 targets = 'all'
871 876 targets = self._build_targets(targets)[0]
872 877 for t in targets:
873 878 self.session.send(self._control_socket, 'shutdown_request',
874 879 content={'restart':restart},ident=t)
875 880 error = False
876 881 if block or hub:
877 882 self._flush_ignored_control()
878 883 for i in range(len(targets)):
879 884 idents,msg = self.session.recv(self._control_socket, 0)
880 885 if self.debug:
881 886 pprint(msg)
882 887 if msg['content']['status'] != 'ok':
883 888 error = self._unwrap_exception(msg['content'])
884 889 else:
885 890 self._ignored_control_replies += len(targets)
886 891
887 892 if hub:
888 893 time.sleep(0.25)
889 894 self.session.send(self._query_socket, 'shutdown_request')
890 895 idents,msg = self.session.recv(self._query_socket, 0)
891 896 if self.debug:
892 897 pprint(msg)
893 898 if msg['content']['status'] != 'ok':
894 899 error = self._unwrap_exception(msg['content'])
895 900
896 901 if error:
897 902 raise error
898 903
899 904 #--------------------------------------------------------------------------
900 905 # Execution related methods
901 906 #--------------------------------------------------------------------------
902 907
903 908 def _maybe_raise(self, result):
904 909 """wrapper for maybe raising an exception if apply failed."""
905 910 if isinstance(result, error.RemoteError):
906 911 raise result
907 912
908 913 return result
909 914
910 915 def send_apply_message(self, socket, f, args=None, kwargs=None, subheader=None, track=False,
911 916 ident=None):
912 917 """construct and send an apply message via a socket.
913 918
914 919 This is the principal method with which all engine execution is performed by views.
915 920 """
916 921
917 922 assert not self._closed, "cannot use me anymore, I'm closed!"
918 923 # defaults:
919 924 args = args if args is not None else []
920 925 kwargs = kwargs if kwargs is not None else {}
921 926 subheader = subheader if subheader is not None else {}
922 927
923 928 # validate arguments
924 929 if not callable(f):
925 930 raise TypeError("f must be callable, not %s"%type(f))
926 931 if not isinstance(args, (tuple, list)):
927 932 raise TypeError("args must be tuple or list, not %s"%type(args))
928 933 if not isinstance(kwargs, dict):
929 934 raise TypeError("kwargs must be dict, not %s"%type(kwargs))
930 935 if not isinstance(subheader, dict):
931 936 raise TypeError("subheader must be dict, not %s"%type(subheader))
932 937
933 938 bufs = util.pack_apply_message(f,args,kwargs)
934 939
935 940 msg = self.session.send(socket, "apply_request", buffers=bufs, ident=ident,
936 941 subheader=subheader, track=track)
937 942
938 943 msg_id = msg['msg_id']
939 944 self.outstanding.add(msg_id)
940 945 if ident:
941 946 # possibly routed to a specific engine
942 947 if isinstance(ident, list):
943 948 ident = ident[-1]
944 949 if ident in self._engines.values():
945 950 # save for later, in case of engine death
946 951 self._outstanding_dict[ident].add(msg_id)
947 952 self.history.append(msg_id)
948 953 self.metadata[msg_id]['submitted'] = datetime.now()
949 954
950 955 return msg
951 956
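# A hedged sketch only -- views normally call this for you, and it touches
# semi-private attributes (`rc`/`f` are placeholders):
#
# msg = rc.send_apply_message(rc._task_socket, f, args=(1,), kwargs={})
# msg_id = msg['msg_id']  # track via rc.outstanding / rc.results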
952 957 #--------------------------------------------------------------------------
953 958 # construct a View object
954 959 #--------------------------------------------------------------------------
955 960
956 961 def load_balanced_view(self, targets=None):
957 962 """construct a DirectView object.
958 963
959 964 If no arguments are specified, create a LoadBalancedView
960 965 using all engines.
961 966
962 967 Parameters
963 968 ----------
964 969
965 970 targets: list,slice,int,etc. [default: use all engines]
966 971 The subset of engines across which to load-balance
967 972 """
968 973 if targets is not None:
969 974 targets = self._build_targets(targets)[1]
970 975 return LoadBalancedView(client=self, socket=self._task_socket, targets=targets)
971 976
972 977 def direct_view(self, targets='all'):
973 978 """construct a DirectView object.
974 979
975 980 If no targets are specified, create a DirectView
976 981 using all engines.
977 982
978 983 Parameters
979 984 ----------
980 985
981 986 targets: list,slice,int,etc. [default: use all engines]
982 987 The engines to use for the View
983 988 """
984 989 single = isinstance(targets, int)
985 990 targets = self._build_targets(targets)[1]
986 991 if single:
987 992 targets = targets[0]
988 993 return DirectView(client=self, socket=self._mux_socket, targets=targets)
989 994
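# A hedged usage sketch: the usual ways to build views from a connected Client.
#
# rc = Client()
# dview = rc[:]                    # DirectView on all engines, via __getitem__
# pair = rc.direct_view([0, 1])    # DirectView on engines 0 and 1
# lview = rc.load_balanced_view()  # LoadBalancedView across all engines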
990 995 #--------------------------------------------------------------------------
991 996 # Query methods
992 997 #--------------------------------------------------------------------------
993 998
994 999 @spin_first
995 1000 def get_result(self, indices_or_msg_ids=None, block=None):
996 1001 """Retrieve a result by msg_id or history index, wrapped in an AsyncResult object.
997 1002
998 1003 If the client already has the results, no request to the Hub will be made.
999 1004
1000 1005 This is a convenient way to construct AsyncResult objects, which are wrappers
1001 1006 that include metadata about execution, and allow for awaiting results that
1002 1007 were not submitted by this Client.
1003 1008
1004 1009 It can also be a convenient way to retrieve the metadata associated with
1005 1010 blocking execution, since it always retrieves
1006 1011
1007 1012 Examples
1008 1013 --------
1009 1014 ::
1010 1015
1011 1016             In [10]: ar = client.get_result(msg_id)  # a msg_id or history index, e.g. -1
1012 1017
1013 1018 Parameters
1014 1019 ----------
1015 1020
1016 1021 indices_or_msg_ids : integer history index, str msg_id, or list of either
1017 1022         The history indices or msg_ids of the results to be retrieved
1018 1023
1019 1024 block : bool
1020 1025 Whether to wait for the result to be done
1021 1026
1022 1027 Returns
1023 1028 -------
1024 1029
1025 1030 AsyncResult
1026 1031 A single AsyncResult object will always be returned.
1027 1032
1028 1033 AsyncHubResult
1029 1034 A subclass of AsyncResult that retrieves results from the Hub
1030 1035
1031 1036 """
1032 1037 block = self.block if block is None else block
1033 1038 if indices_or_msg_ids is None:
1034 1039 indices_or_msg_ids = -1
1035 1040
1036 1041 if not isinstance(indices_or_msg_ids, (list,tuple)):
1037 1042 indices_or_msg_ids = [indices_or_msg_ids]
1038 1043
1039 1044 theids = []
1040 1045 for id in indices_or_msg_ids:
1041 1046 if isinstance(id, int):
1042 1047 id = self.history[id]
1043 1048 if not isinstance(id, str):
1044 1049 raise TypeError("indices must be str or int, not %r"%id)
1045 1050 theids.append(id)
1046 1051
1047 1052 local_ids = filter(lambda msg_id: msg_id in self.history or msg_id in self.results, theids)
1048 1053 remote_ids = filter(lambda msg_id: msg_id not in local_ids, theids)
1049 1054
1050 1055 if remote_ids:
1051 1056 ar = AsyncHubResult(self, msg_ids=theids)
1052 1057 else:
1053 1058 ar = AsyncResult(self, msg_ids=theids)
1054 1059
1055 1060 if block:
1056 1061 ar.wait()
1057 1062
1058 1063 return ar
1059 1064
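A sketch of `get_result` used to re-fetch results and metadata, assuming a connected Client `rc` with at least one completed task::

    ar = rc[:].apply_async(lambda: 42)   # submit something
    ar.wait()
    ar2 = rc.get_result(-1)              # -1: most recent entry in rc.history
    print ar2.get(), ar2.metadata        # results plus execution metadata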
1060 1065 @spin_first
1061 1066 def resubmit(self, indices_or_msg_ids=None, subheader=None, block=None):
1062 1067 """Resubmit one or more tasks.
1063 1068
1064 1069 In-flight tasks may not be resubmitted.
1065 1070
1066 1071 Parameters
1067 1072 ----------
1068 1073
1069 1074 indices_or_msg_ids : integer history index, str msg_id, or list of either
1070 1075 The history indices or msg_ids of the results to be resubmitted
1071 1076
1072 1077 block : bool
1073 1078 Whether to wait for the result to be done
1074 1079
1075 1080 Returns
1076 1081 -------
1077 1082
1078 1083 AsyncHubResult
1079 1084 A subclass of AsyncResult that retrieves results from the Hub
1080 1085
1081 1086 """
1082 1087 block = self.block if block is None else block
1083 1088 if indices_or_msg_ids is None:
1084 1089 indices_or_msg_ids = -1
1085 1090
1086 1091 if not isinstance(indices_or_msg_ids, (list,tuple)):
1087 1092 indices_or_msg_ids = [indices_or_msg_ids]
1088 1093
1089 1094 theids = []
1090 1095 for id in indices_or_msg_ids:
1091 1096 if isinstance(id, int):
1092 1097 id = self.history[id]
1093 1098 if not isinstance(id, basestring):
1094 1099 raise TypeError("indices must be str or int, not %r"%id)
1095 1100 theids.append(id)
1096 1101
1097 1102 for msg_id in theids:
1098 1103 self.outstanding.discard(msg_id)
1099 1104 if msg_id in self.history:
1100 1105 self.history.remove(msg_id)
1101 1106 self.results.pop(msg_id, None)
1102 1107 self.metadata.pop(msg_id, None)
1103 1108 content = dict(msg_ids = theids)
1104 1109
1105 1110 self.session.send(self._query_socket, 'resubmit_request', content)
1106 1111
1107 1112 zmq.select([self._query_socket], [], [])
1108 1113 idents,msg = self.session.recv(self._query_socket, zmq.NOBLOCK)
1109 1114 if self.debug:
1110 1115 pprint(msg)
1111 1116 content = msg['content']
1112 1117 if content['status'] != 'ok':
1113 1118 raise self._unwrap_exception(content)
1114 1119
1115 1120 ar = AsyncHubResult(self, msg_ids=theids)
1116 1121
1117 1122 if block:
1118 1123 ar.wait()
1119 1124
1120 1125 return ar
1121 1126
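A sketch of resubmission, assuming a connected Client `rc` and a *completed* AsyncResult `ar` (in-flight tasks cannot be resubmitted)::

    ar2 = rc.resubmit(ar.msg_ids)   # re-run the same task(s)
    ar2.get()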
1122 1127 @spin_first
1123 1128 def result_status(self, msg_ids, status_only=True):
1124 1129 """Check on the status of the result(s) of the apply request with `msg_ids`.
1125 1130
1126 1131 If status_only is False, then the actual results will be retrieved, else
1127 1132 only the status of the results will be checked.
1128 1133
1129 1134 Parameters
1130 1135 ----------
1131 1136
1132 1137 msg_ids : list of msg_ids or ints
1133 1138 if int:
1134 1139 Passed as index to self.history for convenience.
1135 1140 status_only : bool (default: True)
1136 1141 if False:
1137 1142 Retrieve the actual results of completed tasks.
1138 1143
1139 1144 Returns
1140 1145 -------
1141 1146
1142 1147 results : dict
1143 1148 There will always be the keys 'pending' and 'completed', which will
1144 1149 be lists of msg_ids that are incomplete or complete. If `status_only`
1145 1150 is False, then completed results will be keyed by their `msg_id`.
1146 1151 """
1147 1152 if not isinstance(msg_ids, (list,tuple)):
1148 1153 msg_ids = [msg_ids]
1149 1154
1150 1155 theids = []
1151 1156 for msg_id in msg_ids:
1152 1157 if isinstance(msg_id, int):
1153 1158 msg_id = self.history[msg_id]
1154 1159 if not isinstance(msg_id, basestring):
1155 1160 raise TypeError("msg_ids must be str, not %r"%msg_id)
1156 1161 theids.append(msg_id)
1157 1162
1158 1163 completed = []
1159 1164 local_results = {}
1160 1165
1161 1166 # comment this block out to temporarily disable local shortcut:
1162 1167 for msg_id in theids[:]: # iterate over a copy, since we mutate theids
1163 1168 if msg_id in self.results:
1164 1169 completed.append(msg_id)
1165 1170 local_results[msg_id] = self.results[msg_id]
1166 1171 theids.remove(msg_id)
1167 1172
1168 1173 if theids: # some not locally cached
1169 1174 content = dict(msg_ids=theids, status_only=status_only)
1170 1175 msg = self.session.send(self._query_socket, "result_request", content=content)
1171 1176 zmq.select([self._query_socket], [], [])
1172 1177 idents,msg = self.session.recv(self._query_socket, zmq.NOBLOCK)
1173 1178 if self.debug:
1174 1179 pprint(msg)
1175 1180 content = msg['content']
1176 1181 if content['status'] != 'ok':
1177 1182 raise self._unwrap_exception(content)
1178 1183 buffers = msg['buffers']
1179 1184 else:
1180 1185 content = dict(completed=[],pending=[])
1181 1186
1182 1187 content['completed'].extend(completed)
1183 1188
1184 1189 if status_only:
1185 1190 return content
1186 1191
1187 1192 failures = []
1188 1193 # load cached results into result:
1189 1194 content.update(local_results)
1190 1195
1191 1196 # update cache with results:
1192 1197 for msg_id in sorted(theids):
1193 1198 if msg_id in content['completed']:
1194 1199 rec = content[msg_id]
1195 1200 parent = rec['header']
1196 1201 header = rec['result_header']
1197 1202 rcontent = rec['result_content']
1198 1203 iodict = rec['io']
1199 1204 if isinstance(rcontent, str):
1200 1205 rcontent = self.session.unpack(rcontent)
1201 1206
1202 1207 md = self.metadata[msg_id]
1203 1208 md.update(self._extract_metadata(header, parent, rcontent))
1204 1209 md.update(iodict)
1205 1210
1206 1211 if rcontent['status'] == 'ok':
1207 1212 res,buffers = util.unserialize_object(buffers)
1208 1213 else:
1209 1214 print rcontent
1210 1215 res = self._unwrap_exception(rcontent)
1211 1216 failures.append(res)
1212 1217
1213 1218 self.results[msg_id] = res
1214 1219 content[msg_id] = res
1215 1220
1216 1221 if len(theids) == 1 and failures:
1217 1222 raise failures[0]
1218 1223
1219 1224 error.collect_exceptions(failures, "result_status")
1220 1225 return content
1221 1226
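A sketch of a status-only query, assuming a connected Client `rc` with some history::

    status = rc.result_status(rc.history, status_only=True)
    print status['pending']     # msg_ids still in flight
    print status['completed']   # msg_ids that have finished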
1222 1227 @spin_first
1223 1228 def queue_status(self, targets='all', verbose=False):
1224 1229 """Fetch the status of engine queues.
1225 1230
1226 1231 Parameters
1227 1232 ----------
1228 1233
1229 1234 targets : int/str/list of ints/strs
1230 1235 the engines whose states are to be queried.
1231 1236 default : all
1232 1237 verbose : bool
1233 1238 Whether to return lengths only, or lists of ids for each element
1234 1239 """
1235 1240 engine_ids = self._build_targets(targets)[1]
1236 1241 content = dict(targets=engine_ids, verbose=verbose)
1237 1242 self.session.send(self._query_socket, "queue_request", content=content)
1238 1243 idents,msg = self.session.recv(self._query_socket, 0)
1239 1244 if self.debug:
1240 1245 pprint(msg)
1241 1246 content = msg['content']
1242 1247 status = content.pop('status')
1243 1248 if status != 'ok':
1244 1249 raise self._unwrap_exception(content)
1245 1250 content = util.rekey(content)
1246 1251 if isinstance(targets, int):
1247 1252 return content[targets]
1248 1253 else:
1249 1254 return content
1250 1255
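A sketch of the queue query, assuming a connected Client `rc`::

    print rc.queue_status()              # dict keyed by engine id
    print rc.queue_status(0)             # int target: just that engine's dict
    print rc.queue_status(verbose=True)  # lists of msg_ids, not just lengths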
1251 1256 @spin_first
1252 1257 def purge_results(self, jobs=[], targets=[]):
1253 1258 """Tell the Hub to forget results.
1254 1259
1255 1260 Individual results can be purged by msg_id, or the entire
1256 1261 history of specific targets can be purged.
1257 1262
1258 1263 Parameters
1259 1264 ----------
1260 1265
1261 1266 jobs : str or list of str or AsyncResult objects
1262 1267 the msg_ids whose results should be forgotten.
1263 1268 targets : int/str/list of ints/strs
1264 1269 The targets, by uuid or int_id, whose entire history is to be purged.
1265 1270 Use `targets='all'` to scrub everything from the Hub's memory.
1266 1271
1267 1272 default : None
1268 1273 """
1269 1274 if not targets and not jobs:
1270 1275 raise ValueError("Must specify at least one of `targets` and `jobs`")
1271 1276 if targets:
1272 1277 targets = self._build_targets(targets)[1]
1273 1278
1274 1279 # construct msg_ids from jobs
1275 1280 msg_ids = []
1276 1281 if isinstance(jobs, (basestring,AsyncResult)):
1277 1282 jobs = [jobs]
1278 1283 bad_ids = filter(lambda obj: not isinstance(obj, (basestring, AsyncResult)), jobs)
1279 1284 if bad_ids:
1280 1285 raise TypeError("Invalid msg_id type %r, expected str or AsyncResult"%bad_ids[0])
1281 1286 for j in jobs:
1282 1287 if isinstance(j, AsyncResult):
1283 1288 msg_ids.extend(j.msg_ids)
1284 1289 else:
1285 1290 msg_ids.append(j)
1286 1291
1287 1292 content = dict(targets=targets, msg_ids=msg_ids)
1288 1293 self.session.send(self._query_socket, "purge_request", content=content)
1289 1294 idents, msg = self.session.recv(self._query_socket, 0)
1290 1295 if self.debug:
1291 1296 pprint(msg)
1292 1297 content = msg['content']
1293 1298 if content['status'] != 'ok':
1294 1299 raise self._unwrap_exception(content)
1295 1300
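A sketch of the purge modes, assuming a connected Client `rc` and an AsyncResult `ar`::

    rc.purge_results(jobs=ar)         # forget specific results by AsyncResult
    rc.purge_results(targets=[0, 1])  # forget everything run on engines 0 and 1
    rc.purge_results(targets='all')   # scrub the Hub's entire memory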
1296 1301 @spin_first
1297 1302 def hub_history(self):
1298 1303 """Get the Hub's history
1299 1304
1300 1305 Just like the Client, the Hub has a history, which is a list of msg_ids.
1301 1306 This will contain the history of all clients, and, depending on configuration,
1302 1307 may contain history across multiple cluster sessions.
1303 1308
1304 1309 Any msg_id returned here is a valid argument to `get_result`.
1305 1310
1306 1311 Returns
1307 1312 -------
1308 1313
1309 1314 msg_ids : list of strs
1310 1315 list of all msg_ids, ordered by task submission time.
1311 1316 """
1312 1317
1313 1318 self.session.send(self._query_socket, "history_request", content={})
1314 1319 idents, msg = self.session.recv(self._query_socket, 0)
1315 1320
1316 1321 if self.debug:
1317 1322 pprint(msg)
1318 1323 content = msg['content']
1319 1324 if content['status'] != 'ok':
1320 1325 raise self._unwrap_exception(content)
1321 1326 else:
1322 1327 return content['history']
1323 1328
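A sketch combining `hub_history` with `get_result`, assuming a connected Client `rc`::

    msg_ids = rc.hub_history()
    if msg_ids:
        ar = rc.get_result(msg_ids[-1])  # any returned id is a valid argument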
1324 1329 @spin_first
1325 1330 def db_query(self, query, keys=None):
1326 1331 """Query the Hub's TaskRecord database
1327 1332
1328 1333 This will return a list of task record dicts that match `query`
1329 1334
1330 1335 Parameters
1331 1336 ----------
1332 1337
1333 1338 query : mongodb query dict
1334 1339 The search dict. See mongodb query docs for details.
1335 1340 keys : list of strs [optional]
1336 1341 The subset of keys to be returned. The default is to fetch everything but buffers.
1337 1342 'msg_id' will *always* be included.
1338 1343 """
1339 1344 if isinstance(keys, basestring):
1340 1345 keys = [keys]
1341 1346 content = dict(query=query, keys=keys)
1342 1347 self.session.send(self._query_socket, "db_request", content=content)
1343 1348 idents, msg = self.session.recv(self._query_socket, 0)
1344 1349 if self.debug:
1345 1350 pprint(msg)
1346 1351 content = msg['content']
1347 1352 if content['status'] != 'ok':
1348 1353 raise self._unwrap_exception(content)
1349 1354
1350 1355 records = content['records']
1351 1356
1352 1357 buffer_lens = content['buffer_lens']
1353 1358 result_buffer_lens = content['result_buffer_lens']
1354 1359 buffers = msg['buffers']
1355 1360 has_bufs = buffer_lens is not None
1356 1361 has_rbufs = result_buffer_lens is not None
1357 1362 for i,rec in enumerate(records):
1358 1363 # relink buffers
1359 1364 if has_bufs:
1360 1365 blen = buffer_lens[i]
1361 1366 rec['buffers'], buffers = buffers[:blen],buffers[blen:]
1362 1367 if has_rbufs:
1363 1368 blen = result_buffer_lens[i]
1364 1369 rec['result_buffers'], buffers = buffers[:blen],buffers[blen:]
1365 1370
1366 1371 return records
1367 1372
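A sketch of a TaskRecord query, assuming a connected Client `rc`; the operators follow MongoDB syntax, and the keys shown ('completed', 'engine_uuid') are illustrative TaskRecord fields::

    records = rc.db_query({'completed': {'$ne': None}},
                          keys=['msg_id', 'completed', 'engine_uuid'])
    for rec in records:
        print rec['msg_id'], rec['completed']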
1368 1373 __all__ = [ 'Client' ]
@@ -1,158 +1,165 b''
1 1 # encoding: utf-8
2 2
3 3 """Classes used in scattering and gathering sequences.
4 4
5 5 Scattering consists of partitioning a sequence and sending the various
6 6 pieces to individual nodes in a cluster.
7
8
9 Authors:
10
11 * Brian Granger
12 * MinRK
13
7 14 """
8 15
9 16 __docformat__ = "restructuredtext en"
10 17
11 18 #-------------------------------------------------------------------------------
12 # Copyright (C) 2008 The IPython Development Team
19 # Copyright (C) 2008-2011 The IPython Development Team
13 20 #
14 21 # Distributed under the terms of the BSD License. The full license is in
15 22 # the file COPYING, distributed as part of this software.
16 23 #-------------------------------------------------------------------------------
17 24
18 25 #-------------------------------------------------------------------------------
19 26 # Imports
20 27 #-------------------------------------------------------------------------------
21 28
22 29 import types
23 30
24 31 from IPython.utils.data import flatten as utils_flatten
25 32
26 33 #-------------------------------------------------------------------------------
27 34 # Figure out which array packages are present and their array types
28 35 #-------------------------------------------------------------------------------
29 36
30 37 arrayModules = []
31 38 try:
32 39 import Numeric
33 40 except ImportError:
34 41 pass
35 42 else:
36 43 arrayModules.append({'module':Numeric, 'type':Numeric.arraytype})
37 44 try:
38 45 import numpy
39 46 except ImportError:
40 47 pass
41 48 else:
42 49 arrayModules.append({'module':numpy, 'type':numpy.ndarray})
43 50 try:
44 51 import numarray
45 52 except ImportError:
46 53 pass
47 54 else:
48 55 arrayModules.append({'module':numarray,
49 56 'type':numarray.numarraycore.NumArray})
50 57
51 58 class Map:
52 59 """A class for partitioning a sequence using a map."""
53 60
54 61 def getPartition(self, seq, p, q):
55 62 """Returns the pth partition of q partitions of seq."""
56 63
57 64 # Test for error conditions here
58 65 if p<0 or p>=q:
59 66 print "No partition exists."
60 67 return
61 68
62 69 remainder = len(seq)%q
63 70 basesize = len(seq)/q
64 71 hi = []
65 72 lo = []
66 73 for n in range(q):
67 74 if n < remainder:
68 75 lo.append(n * (basesize + 1))
69 76 hi.append(lo[-1] + basesize + 1)
70 77 else:
71 78 lo.append(n*basesize + remainder)
72 79 hi.append(lo[-1] + basesize)
73 80
74 81
75 82 result = seq[lo[p]:hi[p]]
76 83 return result
77 84
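A quick illustration of the contiguous partitioning above: ten elements split into q=3 partitions of sizes 4, 3, 3, and reassembled by joinPartitions::

    m = Map()
    parts = [m.getPartition(range(10), p, 3) for p in range(3)]
    # parts == [[0, 1, 2, 3], [4, 5, 6], [7, 8, 9]]
    assert m.joinPartitions(parts) == range(10)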
78 85 def joinPartitions(self, listOfPartitions):
79 86 return self.concatenate(listOfPartitions)
80 87
81 88 def concatenate(self, listOfPartitions):
82 89 testObject = listOfPartitions[0]
83 90 # First see if we have a known array type
84 91 for m in arrayModules:
85 92 #print m
86 93 if isinstance(testObject, m['type']):
87 94 return m['module'].concatenate(listOfPartitions)
88 95 # Next try for Python sequence types
89 96 if isinstance(testObject, (types.ListType, types.TupleType)):
90 97 return utils_flatten(listOfPartitions)
91 98 # If we have scalars, just return listOfPartitions
92 99 return listOfPartitions
93 100
94 101 class RoundRobinMap(Map):
95 102 """Partitions a sequence in a round-robin fashion.
96 103
97 104 This currently does not work!
98 105 """
99 106
100 107 def getPartition(self, seq, p, q):
101 108 # if not isinstance(seq,(list,tuple)):
102 109 # raise NotImplementedError("cannot RR partition type %s"%type(seq))
103 110 return seq[p:len(seq):q]
104 111 #result = []
105 112 #for i in range(p,len(seq),q):
106 113 # result.append(seq[i])
107 114 #return result
108 115
109 116 def joinPartitions(self, listOfPartitions):
110 117 testObject = listOfPartitions[0]
111 118 # First see if we have a known array type
112 119 for m in arrayModules:
113 120 #print m
114 121 if isinstance(testObject, m['type']):
115 122 return self.flatten_array(m['type'], listOfPartitions)
116 123 if isinstance(testObject, (types.ListType, types.TupleType)):
117 124 return self.flatten_list(listOfPartitions)
118 125 return listOfPartitions
119 126
120 127 def flatten_array(self, klass, listOfPartitions):
121 128 test = listOfPartitions[0]
122 129 shape = list(test.shape)
123 130 shape[0] = sum([ p.shape[0] for p in listOfPartitions])
124 131 A = klass(shape)
125 132 N = shape[0]
126 133 q = len(listOfPartitions)
127 134 for p,part in enumerate(listOfPartitions):
128 135 A[p:N:q] = part
129 136 return A
130 137
131 138 def flatten_list(self, listOfPartitions):
132 139 flat = []
133 140 for i in range(len(listOfPartitions[0])):
134 141 flat.extend([ part[i] for part in listOfPartitions if len(part) > i ])
135 142 return flat
136 143 #lengths = [len(x) for x in listOfPartitions]
137 144 #maxPartitionLength = len(listOfPartitions[0])
138 145 #numberOfPartitions = len(listOfPartitions)
139 146 #concat = self.concatenate(listOfPartitions)
140 147 #totalLength = len(concat)
141 148 #result = []
142 149 #for i in range(maxPartitionLength):
143 150 # result.append(concat[i:totalLength:maxPartitionLength])
144 151 # return self.concatenate(listOfPartitions)
145 152
146 153 def mappable(obj):
147 154 """return whether an object is mappable or not."""
148 155 if isinstance(obj, (tuple,list)):
149 156 return True
150 157 for m in arrayModules:
151 158 if isinstance(obj,m['type']):
152 159 return True
153 160 return False
154 161
155 162 dists = {'b':Map,'r':RoundRobinMap}
156 163
157 164
158 165
@@ -1,200 +1,206 b''
1 """Remote Functions and decorators for Views."""
1 """Remote Functions and decorators for Views.
2
3 Authors:
4
5 * Brian Granger
6 * Min RK
7 """
2 8 #-----------------------------------------------------------------------------
3 # Copyright (C) 2010 The IPython Development Team
9 # Copyright (C) 2010-2011 The IPython Development Team
4 10 #
5 11 # Distributed under the terms of the BSD License. The full license is in
6 12 # the file COPYING, distributed as part of this software.
7 13 #-----------------------------------------------------------------------------
8 14
9 15 #-----------------------------------------------------------------------------
10 16 # Imports
11 17 #-----------------------------------------------------------------------------
12 18
13 19 import warnings
14 20
15 21 from IPython.testing.skipdoctest import skip_doctest
16 22
17 23 from . import map as Map
18 24 from .asyncresult import AsyncMapResult
19 25
20 26 #-----------------------------------------------------------------------------
21 27 # Decorators
22 28 #-----------------------------------------------------------------------------
23 29
24 30 @skip_doctest
25 31 def remote(view, block=None, **flags):
26 32 """Turn a function into a remote function.
27 33
28 34 This method can be used for map:
29 35
30 36 In [1]: @remote(view,block=True)
31 37 ...: def func(a):
32 38 ...: pass
33 39 """
34 40
35 41 def remote_function(f):
36 42 return RemoteFunction(view, f, block=block, **flags)
37 43 return remote_function
38 44
39 45 @skip_doctest
40 46 def parallel(view, dist='b', block=None, **flags):
41 47 """Turn a function into a parallel remote function.
42 48
43 49 This method can be used for map:
44 50
45 51 In [1]: @parallel(view, block=True)
46 52 ...: def func(a):
47 53 ...: pass
48 54 """
49 55
50 56 def parallel_function(f):
51 57 return ParallelFunction(view, f, dist=dist, block=block, **flags)
52 58 return parallel_function
53 59
54 60 #--------------------------------------------------------------------------
55 61 # Classes
56 62 #--------------------------------------------------------------------------
57 63
58 64 class RemoteFunction(object):
59 65 """Turn an existing function into a remote function.
60 66
61 67 Parameters
62 68 ----------
63 69
64 70 view : View instance
65 71 The view to be used for execution
66 72 f : callable
67 73 The function to be wrapped into a remote function
68 74 block : bool [default: None]
69 75 Whether to wait for results or not. The default behavior is
70 76 to use the current `block` attribute of `view`
71 77
72 78 **flags : remaining kwargs are passed to View.temp_flags
73 79 """
74 80
75 81 view = None # the remote connection
76 82 func = None # the wrapped function
77 83 block = None # whether to block
78 84 flags = None # dict of extra kwargs for temp_flags
79 85
80 86 def __init__(self, view, f, block=None, **flags):
81 87 self.view = view
82 88 self.func = f
83 89 self.block=block
84 90 self.flags=flags
85 91
86 92 def __call__(self, *args, **kwargs):
87 93 block = self.view.block if self.block is None else self.block
88 94 with self.view.temp_flags(block=block, **self.flags):
89 95 return self.view.apply(self.func, *args, **kwargs)
90 96
91 97
92 98 class ParallelFunction(RemoteFunction):
93 99 """Class for mapping a function to sequences.
94 100
95 101 This will distribute the sequences according to a mapper, and call
96 102 the function on each sub-sequence. If called via map, then the function
97 103 will be called once on each element, rather than on each sub-sequence.
98 104
99 105 Parameters
100 106 ----------
101 107
102 108 view : View instance
103 109 The view to be used for execution
104 110 f : callable
105 111 The function to be wrapped into a remote function
106 112 dist : str [default: 'b']
107 113 The key for which mapObject to use to distribute sequences
108 114 options are:
109 115 * 'b' : use contiguous chunks in order
110 116 * 'r' : use round-robin striping
111 117 block : bool [default: None]
112 118 Whether to wait for results or not. The default behavior is
113 119 to use the current `block` attribute of `view`
114 120 chunksize : int or None
115 121 The size of chunk to use when breaking up sequences in a load-balanced manner
116 122 **flags : remaining kwargs are passed to View.temp_flags
117 123 """
118 124
119 125 chunksize=None
120 126 mapObject=None
121 127
122 128 def __init__(self, view, f, dist='b', block=None, chunksize=None, **flags):
123 129 super(ParallelFunction, self).__init__(view, f, block=block, **flags)
124 130 self.chunksize = chunksize
125 131
126 132 mapClass = Map.dists[dist]
127 133 self.mapObject = mapClass()
128 134
129 135 def __call__(self, *sequences):
130 136 # check that the length of sequences match
131 137 len_0 = len(sequences[0])
132 138 for s in sequences:
133 139 if len(s)!=len_0:
134 140 msg = 'all sequences must have equal length, but %i!=%i'%(len_0,len(s))
135 141 raise ValueError(msg)
136 142 balanced = 'Balanced' in self.view.__class__.__name__
137 143 if balanced:
138 144 if self.chunksize:
139 145 nparts = len_0/self.chunksize + int(len_0%self.chunksize > 0)
140 146 else:
141 147 nparts = len_0
142 148 targets = [None]*nparts
143 149 else:
144 150 if self.chunksize:
145 151 warnings.warn("`chunksize` is ignored unless load balancing", UserWarning)
146 152 # multiplexed:
147 153 targets = self.view.targets
148 154 nparts = len(targets)
149 155
150 156 msg_ids = []
151 157 # my_f = lambda *a: map(self.func, *a)
152 158 client = self.view.client
153 159 for index, t in enumerate(targets):
154 160 args = []
155 161 for seq in sequences:
156 162 part = self.mapObject.getPartition(seq, index, nparts)
157 163 if len(part) == 0:
158 164 continue
159 165 else:
160 166 args.append(part)
161 167 if not args:
162 168 continue
163 169
164 170 # print (args)
165 171 if hasattr(self, '_map'):
166 172 f = map
167 173 args = [self.func]+args
168 174 else:
169 175 f=self.func
170 176
171 177 view = self.view if balanced else client[t]
172 178 with view.temp_flags(block=False, **self.flags):
173 179 ar = view.apply(f, *args)
174 180
175 181 msg_ids.append(ar.msg_ids[0])
176 182
177 183 r = AsyncMapResult(self.view.client, msg_ids, self.mapObject, fname=self.func.__name__)
178 184
179 185 if self.block:
180 186 try:
181 187 return r.get()
182 188 except KeyboardInterrupt:
183 189 return r
184 190 else:
185 191 return r
186 192
187 193 def map(self, *sequences):
188 194 """call a function on each element of a sequence remotely.
189 195 This should behave very much like the builtin map, but return an AsyncMapResult
190 196 if self.block is False.
191 197 """
192 198 # set _map as a flag for use inside self.__call__
193 199 self._map = True
194 200 try:
195 201 ret = self.__call__(*sequences)
196 202 finally:
197 203 del self._map
198 204 return ret
199 205
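A sketch of both wrappers in use, assuming a DirectView `view` obtained from an existing Client::

    @remote(view, block=True)
    def getpid():
        import os
        return os.getpid()

    @parallel(view, block=True)
    def echo(x):
        return x            # called once per sub-sequence

    pids = getpid()         # the whole call runs on each engine
    echo(range(8))          # each engine receives a contiguous chunk
    echo.map(range(8))      # elementwise, like the builtin map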
200 206 __all__ = ['remote', 'parallel', 'RemoteFunction', 'ParallelFunction']
@@ -1,1041 +1,1046 b''
1 """Views of remote engines."""
1 """Views of remote engines.
2
3 Authors:
4
5 * Min RK
6 """
2 7 #-----------------------------------------------------------------------------
3 # Copyright (C) 2010 The IPython Development Team
8 # Copyright (C) 2010-2011 The IPython Development Team
4 9 #
5 10 # Distributed under the terms of the BSD License. The full license is in
6 11 # the file COPYING, distributed as part of this software.
7 12 #-----------------------------------------------------------------------------
8 13
9 14 #-----------------------------------------------------------------------------
10 15 # Imports
11 16 #-----------------------------------------------------------------------------
12 17
13 18 import imp
14 19 import sys
15 20 import warnings
16 21 from contextlib import contextmanager
17 22 from types import ModuleType
18 23
19 24 import zmq
20 25
21 26 from IPython.testing.skipdoctest import skip_doctest
22 27 from IPython.utils.traitlets import HasTraits, Any, Bool, List, Dict, Set, Int, Instance, CFloat, CInt
23 28 from IPython.external.decorator import decorator
24 29
25 30 from IPython.parallel import util
26 31 from IPython.parallel.controller.dependency import Dependency, dependent
27 32
28 33 from . import map as Map
29 34 from .asyncresult import AsyncResult, AsyncMapResult
30 35 from .remotefunction import ParallelFunction, parallel, remote
31 36
32 37 #-----------------------------------------------------------------------------
33 38 # Decorators
34 39 #-----------------------------------------------------------------------------
35 40
36 41 @decorator
37 42 def save_ids(f, self, *args, **kwargs):
38 43 """Keep our history and outstanding attributes up to date after a method call."""
39 44 n_previous = len(self.client.history)
40 45 try:
41 46 ret = f(self, *args, **kwargs)
42 47 finally:
43 48 nmsgs = len(self.client.history) - n_previous
44 49 msg_ids = self.client.history[-nmsgs:]
45 50 self.history.extend(msg_ids)
46 51 map(self.outstanding.add, msg_ids)
47 52 return ret
48 53
49 54 @decorator
50 55 def sync_results(f, self, *args, **kwargs):
51 56 """sync relevant results from self.client to our results attribute."""
52 57 ret = f(self, *args, **kwargs)
53 58 delta = self.outstanding.difference(self.client.outstanding)
54 59 completed = self.outstanding.intersection(delta)
55 60 self.outstanding = self.outstanding.difference(completed)
56 61 for msg_id in completed:
57 62 self.results[msg_id] = self.client.results[msg_id]
58 63 return ret
59 64
60 65 @decorator
61 66 def spin_after(f, self, *args, **kwargs):
62 67 """call spin after the method."""
63 68 ret = f(self, *args, **kwargs)
64 69 self.spin()
65 70 return ret
66 71
67 72 #-----------------------------------------------------------------------------
68 73 # Classes
69 74 #-----------------------------------------------------------------------------
70 75
71 76 @skip_doctest
72 77 class View(HasTraits):
73 78 """Base View class for more convenient apply(f,*args,**kwargs) syntax via attributes.
74 79
75 80 Don't use this class, use subclasses.
76 81
77 82 Methods
78 83 -------
79 84
80 85 spin
81 86 flushes incoming results and registration state changes
82 87 control methods spin as well, and requesting `ids` also ensures the state is up to date
83 88
84 89 wait
85 90 wait on one or more msg_ids
86 91
87 92 execution methods
88 93 apply
89 94 legacy: execute, run
90 95
91 96 data movement
92 97 push, pull, scatter, gather
93 98
94 99 query methods
95 100 get_result, queue_status, purge_results, result_status
96 101
97 102 control methods
98 103 abort, shutdown
99 104
100 105 """
101 106 # flags
102 107 block=Bool(False)
103 108 track=Bool(True)
104 109 targets = Any()
105 110
106 111 history=List()
107 112 outstanding = Set()
108 113 results = Dict()
109 114 client = Instance('IPython.parallel.Client')
110 115
111 116 _socket = Instance('zmq.Socket')
112 117 _flag_names = List(['targets', 'block', 'track'])
113 118 _targets = Any()
114 119 _idents = Any()
115 120
116 121 def __init__(self, client=None, socket=None, **flags):
117 122 super(View, self).__init__(client=client, _socket=socket)
118 123 self.block = client.block
119 124
120 125 self.set_flags(**flags)
121 126
122 127 assert not self.__class__ is View, "Don't use base View objects, use subclasses"
123 128
124 129
125 130 def __repr__(self):
126 131 strtargets = str(self.targets)
127 132 if len(strtargets) > 16:
128 133 strtargets = strtargets[:12]+'...]'
129 134 return "<%s %s>"%(self.__class__.__name__, strtargets)
130 135
131 136 def set_flags(self, **kwargs):
132 137 """set my attribute flags by keyword.
133 138
134 139 Views determine behavior with a few attributes (`block`, `track`, etc.).
135 140 These attributes can be set all at once by name with this method.
136 141
137 142 Parameters
138 143 ----------
139 144
140 145 block : bool
141 146 whether to wait for results
142 147 track : bool
143 148 whether to create a MessageTracker to allow the user to
144 149 safely edit arrays and buffers after non-copying
145 150 sends.
146 151 """
147 152 for name, value in kwargs.iteritems():
148 153 if name not in self._flag_names:
149 154 raise KeyError("Invalid name: %r"%name)
150 155 else:
151 156 setattr(self, name, value)
152 157
153 158 @contextmanager
154 159 def temp_flags(self, **kwargs):
155 160 """temporarily set flags, for use in `with` statements.
156 161
157 162 See set_flags for permanent setting of flags
158 163
159 164 Examples
160 165 --------
161 166
162 167 >>> view.track=False
163 168 ...
164 169 >>> with view.temp_flags(track=True):
165 170 ... ar = view.apply(dostuff, my_big_array)
166 171 ... ar.tracker.wait() # wait for send to finish
167 172 >>> view.track
168 173 False
169 174
170 175 """
171 176 # preflight: save flags, and set temporaries
172 177 saved_flags = {}
173 178 for f in self._flag_names:
174 179 saved_flags[f] = getattr(self, f)
175 180 self.set_flags(**kwargs)
176 181 # yield to the with-statement block
177 182 try:
178 183 yield
179 184 finally:
180 185 # postflight: restore saved flags
181 186 self.set_flags(**saved_flags)
182 187
183 188
184 189 #----------------------------------------------------------------
185 190 # apply
186 191 #----------------------------------------------------------------
187 192
188 193 @sync_results
189 194 @save_ids
190 195 def _really_apply(self, f, args, kwargs, block=None, **options):
191 196 """wrapper for client.send_apply_message"""
192 197 raise NotImplementedError("Implement in subclasses")
193 198
194 199 def apply(self, f, *args, **kwargs):
195 200 """calls f(*args, **kwargs) on remote engines, returning the result.
196 201
197 202 This method sets all apply flags via this View's attributes.
198 203
199 204 if self.block is False:
200 205 returns AsyncResult
201 206 else:
202 207 returns actual result of f(*args, **kwargs)
203 208 """
204 209 return self._really_apply(f, args, kwargs)
205 210
206 211 def apply_async(self, f, *args, **kwargs):
207 212 """calls f(*args, **kwargs) on remote engines in a nonblocking manner.
208 213
209 214 returns AsyncResult
210 215 """
211 216 return self._really_apply(f, args, kwargs, block=False)
212 217
213 218 @spin_after
214 219 def apply_sync(self, f, *args, **kwargs):
215 220 """calls f(*args, **kwargs) on remote engines in a blocking manner,
216 221 returning the result.
217 222
218 223 returns: actual result of f(*args, **kwargs)
219 224 """
220 225 return self._really_apply(f, args, kwargs, block=True)
221 226
222 227 #----------------------------------------------------------------
223 228 # wrappers for client and control methods
224 229 #----------------------------------------------------------------
225 230 @sync_results
226 231 def spin(self):
227 232 """spin the client, and sync"""
228 233 self.client.spin()
229 234
230 235 @sync_results
231 236 def wait(self, jobs=None, timeout=-1):
232 237 """waits on one or more `jobs`, for up to `timeout` seconds.
233 238
234 239 Parameters
235 240 ----------
236 241
237 242 jobs : int, str, or list of ints and/or strs, or one or more AsyncResult objects
238 243 ints are indices to self.history
239 244 strs are msg_ids
240 245 default: wait on all outstanding messages
241 246 timeout : float
242 247 a time in seconds, after which to give up.
243 248 default is -1, which means no timeout
244 249
245 250 Returns
246 251 -------
247 252
248 253 True : when all msg_ids are done
249 254 False : timeout reached, some msg_ids still outstanding
250 255 """
251 256 if jobs is None:
252 257 jobs = self.history
253 258 return self.client.wait(jobs, timeout)
254 259
255 260 def abort(self, jobs=None, targets=None, block=None):
256 261 """Abort jobs on my engines.
257 262
258 263 Parameters
259 264 ----------
260 265
261 266 jobs : None, str, list of strs, optional
262 267 if None: abort all jobs.
263 268 else: abort specific msg_id(s).
264 269 """
265 270 block = block if block is not None else self.block
266 271 targets = targets if targets is not None else self.targets
267 272 return self.client.abort(jobs=jobs, targets=targets, block=block)
268 273
269 274 def queue_status(self, targets=None, verbose=False):
270 275 """Fetch the Queue status of my engines"""
271 276 targets = targets if targets is not None else self.targets
272 277 return self.client.queue_status(targets=targets, verbose=verbose)
273 278
274 279 def purge_results(self, jobs=[], targets=[]):
275 280 """Instruct the controller to forget specific results."""
276 281 if targets is None or targets == 'all':
277 282 targets = self.targets
278 283 return self.client.purge_results(jobs=jobs, targets=targets)
279 284
280 285 def shutdown(self, targets=None, restart=False, hub=False, block=None):
281 286 """Terminates one or more engine processes, optionally including the hub.
282 287 """
283 288 block = self.block if block is None else block
284 289 if targets is None or targets == 'all':
285 290 targets = self.targets
286 291 return self.client.shutdown(targets=targets, restart=restart, hub=hub, block=block)
287 292
288 293 @spin_after
289 294 def get_result(self, indices_or_msg_ids=None):
290 295 """return one or more results, specified by history index or msg_id.
291 296
292 297 See client.get_result for details.
293 298
294 299 """
295 300
296 301 if indices_or_msg_ids is None:
297 302 indices_or_msg_ids = -1
298 303 if isinstance(indices_or_msg_ids, int):
299 304 indices_or_msg_ids = self.history[indices_or_msg_ids]
300 305 elif isinstance(indices_or_msg_ids, (list,tuple,set)):
301 306 indices_or_msg_ids = list(indices_or_msg_ids)
302 307 for i,index in enumerate(indices_or_msg_ids):
303 308 if isinstance(index, int):
304 309 indices_or_msg_ids[i] = self.history[index]
305 310 return self.client.get_result(indices_or_msg_ids)
306 311
307 312 #-------------------------------------------------------------------
308 313 # Map
309 314 #-------------------------------------------------------------------
310 315
311 316 def map(self, f, *sequences, **kwargs):
312 317 """override in subclasses"""
313 318 raise NotImplementedError
314 319
315 320 def map_async(self, f, *sequences, **kwargs):
316 321 """Parallel version of builtin `map`, using this view's engines.
317 322
318 323 This is equivalent to map(...block=False)
319 324
320 325 See `self.map` for details.
321 326 """
322 327 if 'block' in kwargs:
323 328 raise TypeError("map_async doesn't take a `block` keyword argument.")
324 329 kwargs['block'] = False
325 330 return self.map(f,*sequences,**kwargs)
326 331
327 332 def map_sync(self, f, *sequences, **kwargs):
328 333 """Parallel version of builtin `map`, using this view's engines.
329 334
330 335 This is equivalent to map(...block=True)
331 336
332 337 See `self.map` for details.
333 338 """
334 339 if 'block' in kwargs:
335 340 raise TypeError("map_sync doesn't take a `block` keyword argument.")
336 341 kwargs['block'] = True
337 342 return self.map(f,*sequences,**kwargs)
338 343
339 344 def imap(self, f, *sequences, **kwargs):
340 345 """Parallel version of `itertools.imap`.
341 346
342 347 See `self.map` for details.
343 348
344 349 """
345 350
346 351 return iter(self.map_async(f,*sequences, **kwargs))
347 352
348 353 #-------------------------------------------------------------------
349 354 # Decorators
350 355 #-------------------------------------------------------------------
351 356
352 357 def remote(self, block=True, **flags):
353 358 """Decorator for making a RemoteFunction"""
354 359 block = self.block if block is None else block
355 360 return remote(self, block=block, **flags)
356 361
357 362 def parallel(self, dist='b', block=None, **flags):
358 363 """Decorator for making a ParallelFunction"""
359 364 block = self.block if block is None else block
360 365 return parallel(self, dist=dist, block=block, **flags)
361 366
362 367 @skip_doctest
363 368 class DirectView(View):
364 369 """Direct Multiplexer View of one or more engines.
365 370
366 371 These are created via indexed access to a client:
367 372
368 373 >>> dv_1 = client[1]
369 374 >>> dv_all = client[:]
370 375 >>> dv_even = client[::2]
371 376 >>> dv_some = client[1:3]
372 377
373 378 This object provides dictionary access to engine namespaces:
374 379
375 380 # push a=5:
376 381 >>> dv['a'] = 5
377 382 # pull 'foo':
378 383 >>> dv['foo']
379 384
380 385 """
381 386
382 387 def __init__(self, client=None, socket=None, targets=None):
383 388 super(DirectView, self).__init__(client=client, socket=socket, targets=targets)
384 389
385 390 @property
386 391 def importer(self):
387 392 """sync_imports(local=True) as a property.
388 393
389 394 See sync_imports for details.
390 395
391 396 """
392 397 return self.sync_imports(True)
393 398
394 399 @contextmanager
395 400 def sync_imports(self, local=True):
396 401 """Context Manager for performing simultaneous local and remote imports.
397 402
398 403 'import x as y' will *not* work. The 'as y' part will simply be ignored.
399 404
400 405 >>> with view.sync_imports():
401 406 ... from numpy import recarray
402 407 importing recarray from numpy on engine(s)
403 408
404 409 """
405 410 import __builtin__
406 411 local_import = __builtin__.__import__
407 412 modules = set()
408 413 results = []
409 414 @util.interactive
410 415 def remote_import(name, fromlist, level):
411 416 """the function to be passed to apply, that actually performs the import
412 417 on the engine, and loads up the user namespace.
413 418 """
414 419 import sys
415 420 user_ns = globals()
416 421 mod = __import__(name, fromlist=fromlist, level=level)
417 422 if fromlist:
418 423 for key in fromlist:
419 424 user_ns[key] = getattr(mod, key)
420 425 else:
421 426 user_ns[name] = sys.modules[name]
422 427
423 428 def view_import(name, globals={}, locals={}, fromlist=[], level=-1):
424 429 """the drop-in replacement for __import__, that optionally imports
425 430 locally as well.
426 431 """
427 432 # don't override nested imports
428 433 save_import = __builtin__.__import__
429 434 __builtin__.__import__ = local_import
430 435
431 436 if imp.lock_held():
432 437 # this is a side-effect import, don't do it remotely, or even
433 438 # ignore the local effects
434 439 return local_import(name, globals, locals, fromlist, level)
435 440
436 441 imp.acquire_lock()
437 442 if local:
438 443 mod = local_import(name, globals, locals, fromlist, level)
439 444 else:
440 445 raise NotImplementedError("remote-only imports not yet implemented")
441 446 imp.release_lock()
442 447
443 448 key = name+':'+','.join(fromlist or [])
444 449 if level == -1 and key not in modules:
445 450 modules.add(key)
446 451 if fromlist:
447 452 print "importing %s from %s on engine(s)"%(','.join(fromlist), name)
448 453 else:
449 454 print "importing %s on engine(s)"%name
450 455 results.append(self.apply_async(remote_import, name, fromlist, level))
451 456 # restore override
452 457 __builtin__.__import__ = save_import
453 458
454 459 return mod
455 460
456 461 # override __import__
457 462 __builtin__.__import__ = view_import
458 463 try:
459 464 # enter the block
460 465 yield
461 466 except ImportError:
462 467 if not local:
463 468 # ignore import errors if not doing local imports
464 469 pass
465 470 finally:
466 471 # always restore __import__
467 472 __builtin__.__import__ = local_import
468 473
469 474 for r in results:
470 475 # raise possible remote ImportErrors here
471 476 r.get()
472 477
473 478
474 479 @sync_results
475 480 @save_ids
476 481 def _really_apply(self, f, args=None, kwargs=None, targets=None, block=None, track=None):
477 482 """calls f(*args, **kwargs) on remote engines, returning the result.
478 483
479 484 This method sets all of `apply`'s flags via this View's attributes.
480 485
481 486 Parameters
482 487 ----------
483 488
484 489 f : callable
485 490
486 491 args : list [default: empty]
487 492
488 493 kwargs : dict [default: empty]
489 494
490 495 targets : target list [default: self.targets]
491 496 where to run
492 497 block : bool [default: self.block]
493 498 whether to block
494 499 track : bool [default: self.track]
495 500 whether to ask zmq to track the message, for safe non-copying sends
496 501
497 502 Returns
498 503 -------
499 504
500 505 if self.block is False:
501 506 returns AsyncResult
502 507 else:
503 508 returns actual result of f(*args, **kwargs) on the engine(s)
504 509 This will be a list if self.targets is also a list (even of length 1), or
505 510 the single result if self.targets is an integer engine id
506 511 """
507 512 args = [] if args is None else args
508 513 kwargs = {} if kwargs is None else kwargs
509 514 block = self.block if block is None else block
510 515 track = self.track if track is None else track
511 516 targets = self.targets if targets is None else targets
512 517
513 518 _idents = self.client._build_targets(targets)[0]
514 519 msg_ids = []
515 520 trackers = []
516 521 for ident in _idents:
517 522 msg = self.client.send_apply_message(self._socket, f, args, kwargs, track=track,
518 523 ident=ident)
519 524 if track:
520 525 trackers.append(msg['tracker'])
521 526 msg_ids.append(msg['msg_id'])
522 527 tracker = None if track is False else zmq.MessageTracker(*trackers)
523 528 ar = AsyncResult(self.client, msg_ids, fname=f.__name__, targets=targets, tracker=tracker)
524 529 if block:
525 530 try:
526 531 return ar.get()
527 532 except KeyboardInterrupt:
528 533 pass
529 534 return ar
530 535
531 536 @spin_after
532 537 def map(self, f, *sequences, **kwargs):
533 538 """view.map(f, *sequences, block=self.block) => list|AsyncMapResult
534 539
535 540 Parallel version of builtin `map`, using this View's `targets`.
536 541
537 542 There will be one task per target, so work will be chunked
538 543 if the sequences are longer than `targets`.
539 544
540 545 Results can be iterated as they are ready, but will become available in chunks.
541 546
542 547 Parameters
543 548 ----------
544 549
545 550 f : callable
546 551 function to be mapped
547 552 *sequences: one or more sequences of matching length
548 553 the sequences to be distributed and passed to `f`
549 554 block : bool
550 555 whether to wait for the result or not [default self.block]
551 556
552 557 Returns
553 558 -------
554 559
555 560 if block=False:
556 561 AsyncMapResult
557 562 An object like AsyncResult, but which reassembles the sequence of results
558 563 into a single list. AsyncMapResults can be iterated through before all
559 564 results are complete.
560 565 else:
561 566 list
562 567 the result of map(f,*sequences)
563 568 """
564 569
565 570 block = kwargs.pop('block', self.block)
566 571 for k in kwargs.keys():
567 572 if k not in ['block', 'track']:
568 573 raise TypeError("invalid keyword arg, %r"%k)
569 574
570 575 assert len(sequences) > 0, "must have some sequences to map onto!"
571 576 pf = ParallelFunction(self, f, block=block, **kwargs)
572 577 return pf.map(*sequences)
573 578
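A sketch of the chunked map above, assuming a DirectView `dview` spanning several engines::

    squares = dview.map_sync(lambda x: x ** 2, range(32))  # one chunk per engine
    amr = dview.map_async(lambda x: x ** 2, range(32))
    for value in amr:        # results become available chunk by chunk
        print value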
574 579 def execute(self, code, targets=None, block=None):
575 580 """Executes `code` on `targets` in blocking or nonblocking manner.
576 581
577 582 ``execute`` is always `bound` (affects engine namespace)
578 583
579 584 Parameters
580 585 ----------
581 586
582 587 code : str
583 588 the code string to be executed
584 589 block : bool
585 590 whether or not to wait until done to return
586 591 default: self.block
587 592 """
588 593 return self._really_apply(util._execute, args=(code,), block=block, targets=targets)
589 594
590 595 def run(self, filename, targets=None, block=None):
591 596 """Execute contents of `filename` on my engine(s).
592 597
593 598 This simply reads the contents of the file and calls `execute`.
594 599
595 600 Parameters
596 601 ----------
597 602
598 603 filename : str
599 604 The path to the file
600 605 targets : int/str/list of ints/strs
601 606 the engines on which to execute
602 607 default : all
603 608 block : bool
604 609 whether or not to wait until done
605 610 default: self.block
606 611
607 612 """
608 613 with open(filename, 'r') as f:
609 614 # add newline in case of trailing indented whitespace
610 615 # which will cause SyntaxError
611 616 code = f.read()+'\n'
612 617 return self.execute(code, block=block, targets=targets)
613 618
614 619 def update(self, ns):
615 620 """update remote namespace with dict `ns`
616 621
617 622 See `push` for details.
618 623 """
619 624 return self.push(ns, block=self.block, track=self.track)
620 625
621 626 def push(self, ns, targets=None, block=None, track=None):
622 627 """update remote namespace with dict `ns`
623 628
624 629 Parameters
625 630 ----------
626 631
627 632 ns : dict
628 633 dict of keys with which to update engine namespace(s)
629 634 block : bool [default : self.block]
630 635 whether to wait to be notified of engine receipt
631 636
632 637 """
633 638
634 639 block = block if block is not None else self.block
635 640 track = track if track is not None else self.track
636 641 targets = targets if targets is not None else self.targets
637 642 # applier = self.apply_sync if block else self.apply_async
638 643 if not isinstance(ns, dict):
639 644 raise TypeError("Must be a dict, not %s"%type(ns))
640 645 return self._really_apply(util._push, (ns,), block=block, track=track, targets=targets)
641 646
642 647 def get(self, key_s):
643 648 """get object(s) by `key_s` from remote namespace
644 649
645 650 see `pull` for details.
646 651 """
647 652 # block = block if block is not None else self.block
648 653 return self.pull(key_s, block=True)
649 654
650 655 def pull(self, names, targets=None, block=None):
651 656 """get object(s) by `name` from remote namespace
652 657
653 658 will return one object if it is a key.
654 659 can also take a list of keys, in which case it will return a list of objects.
655 660 """
656 661 block = block if block is not None else self.block
657 662 targets = targets if targets is not None else self.targets
658 663 applier = self.apply_sync if block else self.apply_async
659 664 if isinstance(names, basestring):
660 665 pass
661 666 elif isinstance(names, (list,tuple,set)):
662 667 for key in names:
663 668 if not isinstance(key, basestring):
664 669 raise TypeError("keys must be str, not type %r"%type(key))
665 670 else:
666 671 raise TypeError("names must be strs, not %r"%names)
667 672 return self._really_apply(util._pull, (names,), block=block, targets=targets)
668 673
669 674 def scatter(self, key, seq, dist='b', flatten=False, targets=None, block=None, track=None):
670 675 """
671 676 Partition a Python sequence and send the partitions to a set of engines.
672 677 """
673 678 block = block if block is not None else self.block
674 679 track = track if track is not None else self.track
675 680 targets = targets if targets is not None else self.targets
676 681
677 682 mapObject = Map.dists[dist]()
678 683 nparts = len(targets)
679 684 msg_ids = []
680 685 trackers = []
681 686 for index, engineid in enumerate(targets):
682 687 partition = mapObject.getPartition(seq, index, nparts)
683 688 if flatten and len(partition) == 1:
684 689 ns = {key: partition[0]}
685 690 else:
686 691 ns = {key: partition}
687 692 r = self.push(ns, block=False, track=track, targets=engineid)
688 693 msg_ids.extend(r.msg_ids)
689 694 if track:
690 695 trackers.append(r._tracker)
691 696
692 697 if track:
693 698 tracker = zmq.MessageTracker(*trackers)
694 699 else:
695 700 tracker = None
696 701
697 702 r = AsyncResult(self.client, msg_ids, fname='scatter', targets=targets, tracker=tracker)
698 703 if block:
699 704 r.wait()
700 705 else:
701 706 return r
702 707
703 708 @sync_results
704 709 @save_ids
705 710 def gather(self, key, dist='b', targets=None, block=None):
706 711 """
707 712 Gather a partitioned sequence on a set of engines as a single local seq.
708 713 """
709 714 block = block if block is not None else self.block
710 715 targets = targets if targets is not None else self.targets
711 716 mapObject = Map.dists[dist]()
712 717 msg_ids = []
713 718
714 719 for index, engineid in enumerate(targets):
715 720 msg_ids.extend(self.pull(key, block=False, targets=engineid).msg_ids)
716 721
717 722 r = AsyncMapResult(self.client, msg_ids, mapObject, fname='gather')
718 723
719 724 if block:
720 725 try:
721 726 return r.get()
722 727 except KeyboardInterrupt:
723 728 pass
724 729 return r
725 730
726 731 def __getitem__(self, key):
727 732 return self.get(key)
728 733
729 734 def __setitem__(self,key, value):
730 735 self.update({key:value})
731 736
732 737 def clear(self, targets=None, block=None):
733 738 """Clear the remote namespaces on my engines."""
734 739 block = block if block is not None else self.block
735 740 targets = targets if targets is not None else self.targets
736 741 return self.client.clear(targets=targets, block=block)
737 742
738 743 def kill(self, targets=None, block=None):
739 744 """Kill my engines."""
740 745 block = block if block is not None else self.block
741 746 targets = targets if targets is not None else self.targets
742 747 return self.client.kill(targets=targets, block=block)
743 748
744 749 #----------------------------------------
745 750 # activate for %px,%autopx magics
746 751 #----------------------------------------
747 752 def activate(self):
748 753 """Make this `View` active for parallel magic commands.
749 754
750 755 IPython has a magic command syntax to work with `MultiEngineClient` objects.
751 756 In a given IPython session there is a single active one. While
752 757 there can be many `Views` created and used by the user,
753 758 there is only one active one. The active `View` is used whenever
754 759 the magic commands %px and %autopx are used.
755 760
756 761 The activate() method is called on a given `View` to make it
757 762 active. Once this has been done, the magic commands can be used.
758 763 """
759 764
760 765 try:
761 766 # This is injected into __builtins__.
762 767 ip = get_ipython()
763 768 except NameError:
764 769 print "The IPython parallel magics (%result, %px, %autopx) only work within IPython."
765 770 else:
766 771 pmagic = ip.plugin_manager.get_plugin('parallelmagic')
767 772 if pmagic is None:
768 773 ip.magic_load_ext('parallelmagic')
769 774 pmagic = ip.plugin_manager.get_plugin('parallelmagic')
770 775
771 776 pmagic.active_view = self
772 777
773 778
774 779 @skip_doctest
775 780 class LoadBalancedView(View):
776 781 """A load-balancing View that only executes via the Task scheduler.
777 782
778 783 Load-balanced views can be created with the client's `view` method:
779 784
780 785 >>> v = client.load_balanced_view()
781 786
782 787 or targets can be specified, to restrict the potential destinations:
783 788
784 789 >>> v = client.load_balanced_view([1,3])
785 790
786 791 which would restrict load-balancing to engines 1 and 3.
787 792
788 793 """
789 794
790 795 follow=Any()
791 796 after=Any()
792 797 timeout=CFloat()
793 798 retries = CInt(0)
794 799
795 800 _task_scheme = Any()
796 801 _flag_names = List(['targets', 'block', 'track', 'follow', 'after', 'timeout', 'retries'])
797 802
798 803 def __init__(self, client=None, socket=None, **flags):
799 804 super(LoadBalancedView, self).__init__(client=client, socket=socket, **flags)
800 805 self._task_scheme=client._task_scheme
801 806
802 807 def _validate_dependency(self, dep):
803 808 """validate a dependency.
804 809
805 810 For use in `set_flags`.
806 811 """
807 812 if dep is None or isinstance(dep, (str, AsyncResult, Dependency)):
808 813 return True
809 814 elif isinstance(dep, (list,set, tuple)):
810 815 for d in dep:
811 816 if not isinstance(d, (str, AsyncResult)):
812 817 return False
813 818 elif isinstance(dep, dict):
814 819 if set(dep.keys()) != set(Dependency().as_dict().keys()):
815 820 return False
816 821 if not isinstance(dep['msg_ids'], list):
817 822 return False
818 823 for d in dep['msg_ids']:
819 824 if not isinstance(d, str):
820 825 return False
821 826 else:
822 827 return False
823 828
824 829 return True
825 830
826 831 def _render_dependency(self, dep):
827 832 """helper for building jsonable dependencies from various input forms."""
828 833 if isinstance(dep, Dependency):
829 834 return dep.as_dict()
830 835 elif isinstance(dep, AsyncResult):
831 836 return dep.msg_ids
832 837 elif dep is None:
833 838 return []
834 839 else:
835 840 # pass to Dependency constructor
836 841 return list(Dependency(dep))
837 842
838 843 def set_flags(self, **kwargs):
839 844 """set my attribute flags by keyword.
840 845
841 846 A View is a wrapper for the Client's apply method, but with attributes
842 847 that specify keyword arguments, those attributes can be set by keyword
843 848 argument with this method.
844 849
845 850 Parameters
846 851 ----------
847 852
848 853 block : bool
849 854 whether to wait for results
850 855 track : bool
851 856 whether to create a MessageTracker to allow the user to
852 857 safely edit arrays and buffers after non-copying
853 858 sends.
854 859
855 860 after : Dependency or collection of msg_ids
856 861 Only for load-balanced execution (targets=None)
857 862 Specify a list of msg_ids as a time-based dependency.
858 863 This job will only be run *after* the dependencies
859 864 have been met.
860 865
861 866 follow : Dependency or collection of msg_ids
862 867 Only for load-balanced execution (targets=None)
863 868 Specify a list of msg_ids as a location-based dependency.
864 869 This job will only be run on an engine where this dependency
865 870 is met.
866 871
867 872 timeout : float/int or None
868 873 Only for load-balanced execution (targets=None)
869 874 Specify an amount of time (in seconds) for the scheduler to
870 875 wait for dependencies to be met before failing with a
871 876 DependencyTimeout.
872 877
873 878 retries : int
874 879 Number of times a task will be retried on failure.
875 880 """
876 881
877 882 super(LoadBalancedView, self).set_flags(**kwargs)
878 883 for name in ('follow', 'after'):
879 884 if name in kwargs:
880 885 value = kwargs[name]
881 886 if self._validate_dependency(value):
882 887 setattr(self, name, value)
883 888 else:
884 889 raise ValueError("Invalid dependency: %r"%value)
885 890 if 'timeout' in kwargs:
886 891 t = kwargs['timeout']
887 892 if not isinstance(t, (int, long, float, type(None))):
888 893 raise TypeError("Invalid type for timeout: %r"%type(t))
889 894 if t is not None:
890 895 if t < 0:
891 896 raise ValueError("Invalid timeout: %s"%t)
892 897 self.timeout = t
893 898
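A sketch of the dependency flags, assuming a LoadBalancedView `lview` and two hypothetical task functions `setup` and `process`::

    ar = lview.apply_async(setup)
    lview.set_flags(after=ar, retries=2)    # later tasks wait for `ar`
    with lview.temp_flags(follow=ar, timeout=10.):
        ar2 = lview.apply_async(process)    # must run where `ar` ran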
894 899 @sync_results
895 900 @save_ids
896 901 def _really_apply(self, f, args=None, kwargs=None, block=None, track=None,
897 902 after=None, follow=None, timeout=None,
898 903 targets=None, retries=None):
899 904 """calls f(*args, **kwargs) on a remote engine, returning the result.
900 905
901 906 This method temporarily sets all of `apply`'s flags for a single call.
902 907
903 908 Parameters
904 909 ----------
905 910
906 911 f : callable
907 912
908 913 args : list [default: empty]
909 914
910 915 kwargs : dict [default: empty]
911 916
912 917 block : bool [default: self.block]
913 918 whether to block
914 919 track : bool [default: self.track]
915 920 whether to ask zmq to track the message, for safe non-copying sends
916 921
917 922 after, follow, timeout, retries, targets : as documented in `set_flags`
918 923
919 924 Returns
920 925 -------
921 926
922 927 if self.block is False:
923 928 returns AsyncResult
924 929 else:
925 930 returns actual result of f(*args, **kwargs) on the engine(s)
926 931 This will be a list if self.targets is also a list (even of length 1), or
927 932 the single result if self.targets is an integer engine id
928 933 """
929 934
930 935 # validate whether we can run
931 936 if self._socket.closed:
932 937 msg = "Task farming is disabled"
933 938 if self._task_scheme == 'pure':
934 939 msg += " because the pure ZMQ scheduler cannot handle"
935 940 msg += " disappearing engines."
936 941 raise RuntimeError(msg)
937 942
938 943 if self._task_scheme == 'pure':
939 944 # pure zmq scheme doesn't support extra features
940 945 msg = ("Pure ZMQ scheduler doesn't support the following flags: "
941 946 "follow, after, retries, targets, timeout")
942 947 if (follow or after or retries or targets or timeout):
943 948 # hard fail on Scheduler flags
944 949 raise RuntimeError(msg)
945 950 if isinstance(f, dependent):
946 951 # soft warn on functional dependencies
947 952 warnings.warn(msg, RuntimeWarning)
948 953
949 954 # build args
950 955 args = [] if args is None else args
951 956 kwargs = {} if kwargs is None else kwargs
952 957 block = self.block if block is None else block
953 958 track = self.track if track is None else track
954 959 after = self.after if after is None else after
955 960 retries = self.retries if retries is None else retries
956 961 follow = self.follow if follow is None else follow
957 962 timeout = self.timeout if timeout is None else timeout
958 963 targets = self.targets if targets is None else targets
959 964
960 965 if not isinstance(retries, int):
961 966 raise TypeError('retries must be int, not %r'%type(retries))
962 967
963 968 if targets is None:
964 969 idents = []
965 970 else:
966 971 idents = self.client._build_targets(targets)[0]
967 972
968 973 after = self._render_dependency(after)
969 974 follow = self._render_dependency(follow)
970 975 subheader = dict(after=after, follow=follow, timeout=timeout, targets=idents, retries=retries)
971 976
972 977 msg = self.client.send_apply_message(self._socket, f, args, kwargs, track=track,
973 978 subheader=subheader)
974 979 tracker = None if track is False else msg['tracker']
975 980
976 981 ar = AsyncResult(self.client, msg['msg_id'], fname=f.__name__, targets=None, tracker=tracker)
977 982
978 983 if block:
979 984 try:
980 985 return ar.get()
981 986 except KeyboardInterrupt:
982 987 pass
983 988 return ar
984 989
985 990 @spin_after
986 991 @save_ids
987 992 def map(self, f, *sequences, **kwargs):
988 993 """view.map(f, *sequences, block=self.block, chunksize=1) => list|AsyncMapResult
989 994
990 995 Parallel version of builtin `map`, load-balanced by this View.
991 996
992 997 `block` and `chunksize` can be specified by keyword only.
993 998
994 999 Each `chunksize` elements will be a separate task, and will be
995 1000 load-balanced. This lets individual elements be available for iteration
996 1001 as soon as they arrive.
997 1002
998 1003 Parameters
999 1004 ----------
1000 1005
1001 1006 f : callable
1002 1007 function to be mapped
1003 1008 *sequences: one or more sequences of matching length
1004 1009 the sequences to be distributed and passed to `f`
1005 1010 block : bool
1006 1011 whether to wait for the result or not [default self.block]
1007 1012 track : bool
1008 1013 whether to create a MessageTracker to allow the user to
1009 1014 know when it is safe to edit arrays and buffers again
1010 1015 after a non-copying send.
1011 1016 chunksize : int
1012 1017 how many elements should be in each task [default 1]
1013 1018
1014 1019 Returns
1015 1020 -------
1016 1021
1017 1022 if block=False:
1018 1023 AsyncMapResult
1019 1024 An object like AsyncResult, but which reassembles the sequence of results
1020 1025 into a single list. AsyncMapResults can be iterated through before all
1021 1026 results are complete.
1022 1027 else:
1023 1028 the result of map(f,*sequences)
1024 1029
1025 1030 """
1026 1031
1027 1032 # default
1028 1033 block = kwargs.get('block', self.block)
1029 1034 chunksize = kwargs.get('chunksize', 1)
1030 1035
1031 1036 keyset = set(kwargs.keys())
1032 1037 extra_keys = keyset.difference(set(['block', 'chunksize']))
1033 1038 if extra_keys:
1034 1039 raise TypeError("Invalid kwargs: %s"%list(extra_keys))
1035 1040
1036 1041 assert len(sequences) > 0, "must have some sequences to map onto!"
1037 1042
1038 1043 pf = ParallelFunction(self, f, block=block, chunksize=chunksize)
1039 1044 return pf.map(*sequences)
1040 1045
1041 1046 __all__ = ['LoadBalancedView', 'DirectView']
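A minimal usage sketch for the LoadBalancedView API in the file above (it assumes a running cluster, e.g. started with `ipcluster start`, and the `Client` entry point of this release):

    from IPython.parallel import Client

    rc = Client()                    # connect to the running controller
    view = rc.load_balanced_view()   # a LoadBalancedView over all engines

    # retry failed tasks twice, and give dependencies 30s to be met
    view.set_flags(retries=2, timeout=30.0)

    # blocking map: every chunk of two elements becomes one load-balanced task
    squares = view.map(lambda x: x**2, range(10), block=True, chunksize=2)

    # non-blocking apply returns an AsyncResult
    ar = view.apply(pow, 2, 10)
    print(ar.get())                  # 1024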
@@ -1,196 +1,201 b''
1 """Dependency utilities"""
1 """Dependency utilities
2
3 Authors:
4
5 * Min RK
6 """
2 7 #-----------------------------------------------------------------------------
3 8 # Copyright (C) 2010-2011 The IPython Development Team
4 9 #
5 10 # Distributed under the terms of the BSD License. The full license is in
6 11 # the file COPYING, distributed as part of this software.
7 12 #-----------------------------------------------------------------------------
8 13
9 14 from types import ModuleType
10 15
11 16 from IPython.parallel.client.asyncresult import AsyncResult
12 17 from IPython.parallel.error import UnmetDependency
13 18 from IPython.parallel.util import interactive
14 19
15 20 class depend(object):
16 21 """Dependency decorator, for use with tasks.
17 22
18 23 `@depend` lets you define a function for engine dependencies
19 24 just like you use `apply` for tasks.
20 25
21 26
22 27 Examples
23 28 --------
24 29 ::
25 30
26 31 @depend(df, a,b, c=5)
27 32 def f(m,n,p):
28 33
29 34 view.apply(f, 1,2,3)
30 35
31 36 will call df(a,b,c=5) on the engine, and if it returns False or
32 37 raises an UnmetDependency error, then the task will not be run
33 38 and another engine will be tried.
34 39 """
35 40 def __init__(self, f, *args, **kwargs):
36 41 self.f = f
37 42 self.args = args
38 43 self.kwargs = kwargs
39 44
40 45 def __call__(self, f):
41 46 return dependent(f, self.f, *self.args, **self.kwargs)
42 47
43 48 class dependent(object):
44 49 """A function that depends on another function.
45 50 This is implemented as an object rather than the closure a
46 51 traditional decorator would use, because closures are not picklable.
47 52 """
48 53
49 54 def __init__(self, f, df, *dargs, **dkwargs):
50 55 self.f = f
51 56 self.func_name = getattr(f, '__name__', 'f')
52 57 self.df = df
53 58 self.dargs = dargs
54 59 self.dkwargs = dkwargs
55 60
56 61 def __call__(self, *args, **kwargs):
57 62 # if hasattr(self.f, 'func_globals') and hasattr(self.df, 'func_globals'):
58 63 # self.df.func_globals = self.f.func_globals
59 64 if self.df(*self.dargs, **self.dkwargs) is False:
60 65 raise UnmetDependency()
61 66 return self.f(*args, **kwargs)
62 67
63 68 @property
64 69 def __name__(self):
65 70 return self.func_name
66 71
67 72 @interactive
68 73 def _require(*names):
69 74 """Helper for @require decorator."""
70 75 from IPython.parallel.error import UnmetDependency
71 76 user_ns = globals()
72 77 for name in names:
73 78 if name in user_ns:
74 79 continue
75 80 try:
76 81 exec 'import %s'%name in user_ns
77 82 except ImportError:
78 83 raise UnmetDependency(name)
79 84 return True
80 85
81 86 def require(*mods):
82 87 """Simple decorator for requiring names to be importable.
83 88
84 89 Examples
85 90 --------
86 91
87 92 In [1]: @require('numpy')
88 93 ...: def norm(a):
89 94 ...: import numpy
90 95 ...: return numpy.linalg.norm(a,2)
91 96 """
92 97 names = []
93 98 for mod in mods:
94 99 if isinstance(mod, ModuleType):
95 100 mod = mod.__name__
96 101
97 102 if isinstance(mod, basestring):
98 103 names.append(mod)
99 104 else:
100 105 raise TypeError("names must be modules or module names, not %s"%type(mod))
101 106
102 107 return depend(_require, *names)
103 108
104 109 class Dependency(set):
105 110 """An object for representing a set of msg_id dependencies.
106 111
107 112 Subclassed from set().
108 113
109 114 Parameters
110 115 ----------
111 116 dependencies: list/set of msg_ids or AsyncResult objects or output of Dependency.as_dict()
112 117 The msg_ids to depend on
113 118 all : bool [default True]
114 119 Whether the dependency should be considered met when *all* depended-upon tasks have completed
115 120 or only when *any* have been completed.
116 121 success : bool [default True]
117 122 Whether to consider successes as fulfilling dependencies.
118 123 failure : bool [default False]
119 124 Whether to consider failures as fulfilling dependencies.
120 125
121 126 If `all=success=True` and `failure=False`, then the task will fail with an ImpossibleDependency
122 127 as soon as the first depended-upon task fails.
123 128 """
124 129
125 130 all=True
126 131 success=True
127 132 failure=False
128 133
129 134 def __init__(self, dependencies=[], all=True, success=True, failure=False):
130 135 if isinstance(dependencies, dict):
131 136 # load from dict
132 137 all = dependencies.get('all', all)
133 138 success = dependencies.get('success', success)
134 139 failure = dependencies.get('failure', failure)
135 140 dependencies = dependencies.get('dependencies', [])
136 141 ids = []
137 142
138 143 # extract ids from various sources:
139 144 if isinstance(dependencies, (basestring, AsyncResult)):
140 145 dependencies = [dependencies]
141 146 for d in dependencies:
142 147 if isinstance(d, basestring):
143 148 ids.append(d)
144 149 elif isinstance(d, AsyncResult):
145 150 ids.extend(d.msg_ids)
146 151 else:
147 152 raise TypeError("invalid dependency type: %r"%type(d))
148 153
149 154 set.__init__(self, ids)
150 155 self.all = all
151 156 if not (success or failure):
152 157 raise ValueError("Must depend on at least one of successes or failures!")
153 158 self.success=success
154 159 self.failure = failure
155 160
156 161 def check(self, completed, failed=None):
157 162 """check whether our dependencies have been met."""
158 163 if len(self) == 0:
159 164 return True
160 165 against = set()
161 166 if self.success:
162 167 against = completed
163 168 if failed is not None and self.failure:
164 169 against = against.union(failed)
165 170 if self.all:
166 171 return self.issubset(against)
167 172 else:
168 173 return not self.isdisjoint(against)
169 174
170 175 def unreachable(self, completed, failed=None):
171 176 """return whether this dependency has become impossible."""
172 177 if len(self) == 0:
173 178 return False
174 179 against = set()
175 180 if not self.success:
176 181 against = completed
177 182 if failed is not None and not self.failure:
178 183 against = against.union(failed)
179 184 if self.all:
180 185 return not self.isdisjoint(against)
181 186 else:
182 187 return self.issubset(against)
183 188
184 189
185 190 def as_dict(self):
186 191 """Represent this dependency as a dict. For json compatibility."""
187 192 return dict(
188 193 dependencies=list(self),
189 194 all=self.all,
190 195 success=self.success,
191 196 failure=self.failure
192 197 )
193 198
194 199
195 200 __all__ = ['depend', 'require', 'dependent', 'Dependency']
196 201
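The Dependency semantics above are easy to exercise standalone; a short sketch (the import path is assumed from this source tree):

    from IPython.parallel.controller.dependency import Dependency

    dep = Dependency(['a', 'b'], all=True, success=True, failure=False)

    dep.check(set(['a']))                      # False: 'b' has not completed
    dep.check(set(['a', 'b']))                 # True: all dependencies succeeded

    # with failure=False, one failed dependency makes the job impossible
    dep.unreachable(set(), failed=set(['a']))  # True

    any_dep = Dependency(['a', 'b'], all=False)
    any_dep.check(set(['b']))                  # True: any one dependency suffices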
@@ -1,180 +1,185 b''
1 1 """A Task logger that presents our DB interface,
2 2 but exists entirely in memory and is implemented with dicts.
3 3
4 Authors:
5
6 * Min RK
7
8
4 9 TaskRecords are dicts of the form:
5 10 {
6 11 'msg_id' : str(uuid),
7 12 'client_uuid' : str(uuid),
8 13 'engine_uuid' : str(uuid) or None,
9 14 'header' : dict(header),
10 15 'content': dict(content),
11 16 'buffers': list(buffers),
12 17 'submitted': datetime,
13 18 'started': datetime or None,
14 19 'completed': datetime or None,
15 20 'resubmitted': datetime or None,
16 21 'result_header' : dict(header) or None,
17 22 'result_content' : dict(content) or None,
18 23 'result_buffers' : list(buffers) or None,
19 24 }
20 25 With this info, many of the special categories of tasks can be defined by query:
21 26
22 27 pending: completed is None
23 28 client's outstanding: client_uuid = uuid && completed is None
24 29 MIA: arrived is None (and completed is None)
25 30 etc.
26 31
27 32 EngineRecords are dicts of the form:
28 33 {
29 34 'eid' : int(id),
30 35 'uuid': str(uuid)
31 36 }
32 37 This may be extended, but this is all it currently contains.
33 38
34 39 We support a subset of mongodb operators:
35 40 $lt,$gt,$lte,$gte,$ne,$in,$nin,$all,$mod,$exists
36 41 """
37 42 #-----------------------------------------------------------------------------
38 # Copyright (C) 2010 The IPython Development Team
43 # Copyright (C) 2010-2011 The IPython Development Team
39 44 #
40 45 # Distributed under the terms of the BSD License. The full license is in
41 46 # the file COPYING, distributed as part of this software.
42 47 #-----------------------------------------------------------------------------
43 48
44 49
45 50 from datetime import datetime
46 51
47 52 from IPython.config.configurable import LoggingConfigurable
48 53
49 54 from IPython.utils.traitlets import Dict, Unicode, Instance
50 55
51 56 filters = {
52 57 '$lt' : lambda a,b: a < b,
53 58 '$gt' : lambda a,b: a > b,
54 59 '$eq' : lambda a,b: a == b,
55 60 '$ne' : lambda a,b: a != b,
56 61 '$lte': lambda a,b: a <= b,
57 62 '$gte': lambda a,b: a >= b,
58 63 '$in' : lambda a,b: a in b,
59 64 '$nin': lambda a,b: a not in b,
60 65 '$all': lambda a,b: all([ bb in a for bb in b ]),
61 66 '$mod': lambda a,b: a%b[0] == b[1],
62 67 '$exists' : lambda a,b: (b and a is not None) or (a is None and not b)
63 68 }
64 69
65 70
66 71 class CompositeFilter(object):
67 72 """Composite filter for matching multiple properties."""
68 73
69 74 def __init__(self, dikt):
70 75 self.tests = []
71 76 self.values = []
72 77 for key, value in dikt.iteritems():
73 78 self.tests.append(filters[key])
74 79 self.values.append(value)
75 80
76 81 def __call__(self, value):
77 82 for test,check in zip(self.tests, self.values):
78 83 if not test(value, check):
79 84 return False
80 85 return True
81 86
82 87 class BaseDB(LoggingConfigurable):
83 88 """Empty Parent class so traitlets work on DB."""
84 89 # base configurable traits:
85 90 session = Unicode("")
86 91
87 92 class DictDB(BaseDB):
88 93 """Basic in-memory dict-based object for saving Task Records.
89 94
90 95 This is the first object to present the DB interface
91 96 for logging tasks out of memory.
92 97
93 98 The interface is based on MongoDB, so adding a MongoDB
94 99 backend should be straightforward.
95 100 """
96 101
97 102 _records = Dict()
98 103
99 104 def _match_one(self, rec, tests):
100 105 """Check if a specific record matches tests."""
101 106 for key,test in tests.iteritems():
102 107 if not test(rec.get(key, None)):
103 108 return False
104 109 return True
105 110
106 111 def _match(self, check):
107 112 """Find all the matches for a check dict."""
108 113 matches = []
109 114 tests = {}
110 115 for k,v in check.iteritems():
111 116 if isinstance(v, dict):
112 117 tests[k] = CompositeFilter(v)
113 118 else:
114 119 tests[k] = lambda o, v=v: o==v # bind v per key; a plain closure would see only the last v
115 120
116 121 for rec in self._records.itervalues():
117 122 if self._match_one(rec, tests):
118 123 matches.append(rec)
119 124 return matches
120 125
121 126 def _extract_subdict(self, rec, keys):
122 127 """extract subdict of keys"""
123 128 d = {}
124 129 d['msg_id'] = rec['msg_id']
125 130 for key in keys:
126 131 d[key] = rec[key]
127 132 return d
128 133
129 134 def add_record(self, msg_id, rec):
130 135 """Add a new Task Record, by msg_id."""
131 136 if msg_id in self._records:
132 137 raise KeyError("Already have msg_id %r"%(msg_id))
133 138 self._records[msg_id] = rec
134 139
135 140 def get_record(self, msg_id):
136 141 """Get a specific Task Record, by msg_id."""
137 142 if msg_id not in self._records:
138 143 raise KeyError("No such msg_id %r"%(msg_id))
139 144 return self._records[msg_id]
140 145
141 146 def update_record(self, msg_id, rec):
142 147 """Update the data in an existing record."""
143 148 self._records[msg_id].update(rec)
144 149
145 150 def drop_matching_records(self, check):
146 151 """Remove a record from the DB."""
147 152 matches = self._match(check)
148 153 for m in matches:
149 154 del self._records[m['msg_id']]
150 155
151 156 def drop_record(self, msg_id):
152 157 """Remove a record from the DB."""
153 158 del self._records[msg_id]
154 159
155 160
156 161 def find_records(self, check, keys=None):
157 162 """Find records matching a query dict, optionally extracting subset of keys.
158 163
159 164 Returns dict keyed by msg_id of matching records.
160 165
161 166 Parameters
162 167 ----------
163 168
164 169 check: dict
165 170 mongodb-style query argument
166 171 keys: list of strs [optional]
167 172 if specified, the subset of keys to extract. msg_id will *always* be
168 173 included.
169 174 """
170 175 matches = self._match(check)
171 176 if keys:
172 177 return [ self._extract_subdict(rec, keys) for rec in matches ]
173 178 else:
174 179 return matches
175 180
176 181
177 182 def get_history(self):
178 183 """get all msg_ids, ordered by time submitted."""
179 184 msg_ids = self._records.keys()
180 185 return sorted(msg_ids, key=lambda m: self._records[m]['submitted'])
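A quick sketch of querying a DictDB directly, using the mongodb-style operators from the `filters` table (standalone; no controller required):

    from datetime import datetime
    from IPython.parallel.controller.dictdb import DictDB

    db = DictDB()
    db.add_record('m1', {'msg_id': 'm1', 'submitted': datetime(2011, 1, 1),
                         'completed': None, 'engine_uuid': 'e0'})
    db.add_record('m2', {'msg_id': 'm2', 'submitted': datetime(2011, 1, 2),
                         'completed': datetime(2011, 1, 3), 'engine_uuid': 'e1'})

    db.find_records({'completed': None})       # pending tasks -> [m1's record]
    db.find_records({'submitted': {'$gte': datetime(2011, 1, 2)}},
                    keys=['engine_uuid'])      # msg_id is always included
    db.get_history()                           # ['m1', 'm2'], by submit time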
@@ -1,165 +1,169 b''
1 1 #!/usr/bin/env python
2 2 """
3 3 A multi-heart Heartbeat system using PUB and XREP sockets. Pings are sent out on the PUB,
4 4 and hearts are tracked based on their XREQ identities.
5
6 Authors:
7
8 * Min RK
5 9 """
6 10 #-----------------------------------------------------------------------------
7 11 # Copyright (C) 2010-2011 The IPython Development Team
8 12 #
9 13 # Distributed under the terms of the BSD License. The full license is in
10 14 # the file COPYING, distributed as part of this software.
11 15 #-----------------------------------------------------------------------------
12 16
13 17 from __future__ import print_function
14 18 import time
15 19 import uuid
16 20
17 21 import zmq
18 22 from zmq.devices import ThreadDevice
19 23 from zmq.eventloop import ioloop, zmqstream
20 24
21 25 from IPython.config.configurable import LoggingConfigurable
22 26 from IPython.utils.traitlets import Set, Instance, CFloat
23 27
24 28 class Heart(object):
25 29 """A basic heart object for responding to a HeartMonitor.
26 30 This is a simple wrapper with defaults for the most common
27 31 Device model for responding to heartbeats.
28 32
29 33 It simply builds a threadsafe zmq.FORWARDER Device, defaulting to using
30 34 SUB/XREQ for in/out.
31 35
32 36 You can specify the XREQ's IDENTITY via the optional heart_id argument."""
33 37 device=None
34 38 id=None
35 39 def __init__(self, in_addr, out_addr, in_type=zmq.SUB, out_type=zmq.XREQ, heart_id=None):
36 40 self.device = ThreadDevice(zmq.FORWARDER, in_type, out_type)
37 41 self.device.daemon=True
38 42 self.device.connect_in(in_addr)
39 43 self.device.connect_out(out_addr)
40 44 if in_type == zmq.SUB:
41 45 self.device.setsockopt_in(zmq.SUBSCRIBE, "")
42 46 if heart_id is None:
43 47 heart_id = str(uuid.uuid4())
44 48 self.device.setsockopt_out(zmq.IDENTITY, heart_id)
45 49 self.id = heart_id
46 50
47 51 def start(self):
48 52 return self.device.start()
49 53
50 54 class HeartMonitor(LoggingConfigurable):
51 55 """A basic HeartMonitor class
52 56 pingstream: a PUB stream
53 57 pongstream: an XREP stream
54 58 period: the period of the heartbeat in milliseconds"""
55 59
56 60 period=CFloat(1000, config=True,
57 61 help='The period (in ms) at which the Hub pings the engines for heartbeats '
58 62 '[default: 1000]',
59 63 )
60 64
61 65 pingstream=Instance('zmq.eventloop.zmqstream.ZMQStream')
62 66 pongstream=Instance('zmq.eventloop.zmqstream.ZMQStream')
63 67 loop = Instance('zmq.eventloop.ioloop.IOLoop')
64 68 def _loop_default(self):
65 69 return ioloop.IOLoop.instance()
66 70
67 71 # not settable:
68 72 hearts=Set()
69 73 responses=Set()
70 74 on_probation=Set()
71 75 last_ping=CFloat(0)
72 76 _new_handlers = Set()
73 77 _failure_handlers = Set()
74 78 lifetime = CFloat(0)
75 79 tic = CFloat(0)
76 80
77 81 def __init__(self, **kwargs):
78 82 super(HeartMonitor, self).__init__(**kwargs)
79 83
80 84 self.pongstream.on_recv(self.handle_pong)
81 85
82 86 def start(self):
83 87 self.caller = ioloop.PeriodicCallback(self.beat, self.period, self.loop)
84 88 self.caller.start()
85 89
86 90 def add_new_heart_handler(self, handler):
87 91 """add a new handler for new hearts"""
88 92 self.log.debug("heartbeat::new_heart_handler: %s"%handler)
89 93 self._new_handlers.add(handler)
90 94
91 95 def add_heart_failure_handler(self, handler):
92 96 """add a new handler for heart failure"""
93 97 self.log.debug("heartbeat::new heart failure handler: %s"%handler)
94 98 self._failure_handlers.add(handler)
95 99
96 100 def beat(self):
97 101 self.pongstream.flush()
98 102 self.last_ping = self.lifetime
99 103
100 104 toc = time.time()
101 105 self.lifetime += toc-self.tic
102 106 self.tic = toc
103 107 # self.log.debug("heartbeat::%s"%self.lifetime)
104 108 goodhearts = self.hearts.intersection(self.responses)
105 109 missed_beats = self.hearts.difference(goodhearts)
106 110 heartfailures = self.on_probation.intersection(missed_beats)
107 111 newhearts = self.responses.difference(goodhearts)
108 112 map(self.handle_new_heart, newhearts)
109 113 map(self.handle_heart_failure, heartfailures)
110 114 self.on_probation = missed_beats.intersection(self.hearts)
111 115 self.responses = set()
112 116 # print self.on_probation, self.hearts
113 117 # self.log.debug("heartbeat::beat %.3f, %i beating hearts"%(self.lifetime, len(self.hearts)))
114 118 self.pingstream.send(str(self.lifetime))
115 119
116 120 def handle_new_heart(self, heart):
117 121 if self._new_handlers:
118 122 for handler in self._new_handlers:
119 123 handler(heart)
120 124 else:
121 125 self.log.info("heartbeat::yay, got new heart %s!"%heart)
122 126 self.hearts.add(heart)
123 127
124 128 def handle_heart_failure(self, heart):
125 129 if self._failure_handlers:
126 130 for handler in self._failure_handlers:
127 131 try:
128 132 handler(heart)
129 133 except Exception as e:
130 134 self.log.error("heartbeat::Bad Handler! %s"%handler, exc_info=True)
131 135 pass
132 136 else:
133 137 self.log.info("heartbeat::Heart %s failed :("%heart)
134 138 self.hearts.remove(heart)
135 139
136 140
137 141 def handle_pong(self, msg):
138 142 "a heart just beat"
139 143 if msg[1] == str(self.lifetime):
140 144 delta = time.time()-self.tic
141 145 # self.log.debug("heartbeat::heart %r took %.2f ms to respond"%(msg[0], 1000*delta))
142 146 self.responses.add(msg[0])
143 147 elif msg[1] == str(self.last_ping):
144 148 delta = time.time()-self.tic + (self.lifetime-self.last_ping)
145 149 self.log.warn("heartbeat::heart %r missed a beat, and took %.2f ms to respond"%(msg[0], 1000*delta))
146 150 self.responses.add(msg[0])
147 151 else:
148 152 self.log.warn("heartbeat::got bad heartbeat (possibly old?): %s (current=%.3f)"%
149 153 (msg[1],self.lifetime))
150 154
151 155
152 156 if __name__ == '__main__':
153 157 loop = ioloop.IOLoop.instance()
154 158 context = zmq.Context()
155 159 pub = context.socket(zmq.PUB)
156 160 pub.bind('tcp://127.0.0.1:5555')
157 161 xrep = context.socket(zmq.XREP)
158 162 xrep.bind('tcp://127.0.0.1:5556')
159 163
160 164 outstream = zmqstream.ZMQStream(pub, loop)
161 165 instream = zmqstream.ZMQStream(xrep, loop)
162 166
163 167 hb = HeartMonitor(loop=loop, pingstream=outstream, pongstream=instream)
164 168 hb.start()
165 169 loop.start()
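To exercise the `__main__` demo above, a matching heart can be run in a second process; a minimal sketch against the same ports:

    import time
    from IPython.parallel.controller.heartmonitor import Heart

    # connect to the monitor's PUB (5555) and XREP (5556) sockets
    heart = Heart('tcp://127.0.0.1:5555', 'tcp://127.0.0.1:5556',
                  heart_id='demo-heart-1')
    heart.start()    # the ThreadDevice echoes each ping back with our identity
    time.sleep(30)   # keep the process alive while the heart beats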
@@ -1,1284 +1,1288 b''
1 1 #!/usr/bin/env python
2 2 """The IPython Controller Hub with 0MQ
3 3 This is the master object that handles connections from engines and clients,
4 4 and monitors traffic through the various queues.
5
6 Authors:
7
8 * Min RK
5 9 """
6 10 #-----------------------------------------------------------------------------
7 11 # Copyright (C) 2010 The IPython Development Team
8 12 #
9 13 # Distributed under the terms of the BSD License. The full license is in
10 14 # the file COPYING, distributed as part of this software.
11 15 #-----------------------------------------------------------------------------
12 16
13 17 #-----------------------------------------------------------------------------
14 18 # Imports
15 19 #-----------------------------------------------------------------------------
16 20 from __future__ import print_function
17 21
18 22 import sys
19 23 import time
20 24 from datetime import datetime
21 25
22 26 import zmq
23 27 from zmq.eventloop import ioloop
24 28 from zmq.eventloop.zmqstream import ZMQStream
25 29
26 30 # internal:
27 31 from IPython.utils.importstring import import_item
28 32 from IPython.utils.traitlets import (
29 33 HasTraits, Instance, Int, Unicode, Dict, Set, Tuple, CStr
30 34 )
31 35
32 36 from IPython.parallel import error, util
33 37 from IPython.parallel.factory import RegistrationFactory
34 38
35 39 from IPython.zmq.session import SessionFactory
36 40
37 41 from .heartmonitor import HeartMonitor
38 42
39 43 #-----------------------------------------------------------------------------
40 44 # Code
41 45 #-----------------------------------------------------------------------------
42 46
43 47 def _passer(*args, **kwargs):
44 48 return
45 49
46 50 def _printer(*args, **kwargs):
47 51 print (args)
48 52 print (kwargs)
49 53
50 54 def empty_record():
51 55 """Return an empty dict with all record keys."""
52 56 return {
53 57 'msg_id' : None,
54 58 'header' : None,
55 59 'content': None,
56 60 'buffers': None,
57 61 'submitted': None,
58 62 'client_uuid' : None,
59 63 'engine_uuid' : None,
60 64 'started': None,
61 65 'completed': None,
62 66 'resubmitted': None,
63 67 'result_header' : None,
64 68 'result_content' : None,
65 69 'result_buffers' : None,
66 70 'queue' : None,
67 71 'pyin' : None,
68 72 'pyout': None,
69 73 'pyerr': None,
70 74 'stdout': '',
71 75 'stderr': '',
72 76 }
73 77
74 78 def init_record(msg):
75 79 """Initialize a TaskRecord based on a request."""
76 80 header = msg['header']
77 81 return {
78 82 'msg_id' : header['msg_id'],
79 83 'header' : header,
80 84 'content': msg['content'],
81 85 'buffers': msg['buffers'],
82 86 'submitted': header['date'],
83 87 'client_uuid' : None,
84 88 'engine_uuid' : None,
85 89 'started': None,
86 90 'completed': None,
87 91 'resubmitted': None,
88 92 'result_header' : None,
89 93 'result_content' : None,
90 94 'result_buffers' : None,
91 95 'queue' : None,
92 96 'pyin' : None,
93 97 'pyout': None,
94 98 'pyerr': None,
95 99 'stdout': '',
96 100 'stderr': '',
97 101 }
98 102
99 103
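For reference, a sketch of how `init_record` seeds a TaskRecord from an incoming request (the message dict here is hypothetical, shaped like the Session messages the Hub sees):

    from datetime import datetime

    msg = {
        'header': {'msg_id': 'abc123', 'date': datetime(2011, 6, 1)},
        'content': {},    # the request content
        'buffers': [],    # serialized args/kwargs
    }
    rec = init_record(msg)
    rec['msg_id']        # 'abc123'
    rec['submitted']     # datetime(2011, 6, 1)
    rec['completed']     # None, until a save_*_result handler fills it in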
100 104 class EngineConnector(HasTraits):
101 105 """A simple object for accessing the various zmq connections of an object.
102 106 Attributes are:
103 107 id (int): engine ID
104 108 uuid (str): uuid (unused?)
105 109 queue (str): identity of queue's XREQ socket
106 110 registration (str): identity of registration XREQ socket
107 111 heartbeat (str): identity of heartbeat XREQ socket
108 112 """
109 113 id=Int(0)
110 114 queue=CStr()
111 115 control=CStr()
112 116 registration=CStr()
113 117 heartbeat=CStr()
114 118 pending=Set()
115 119
116 120 class HubFactory(RegistrationFactory):
117 121 """The Configurable for setting up a Hub."""
118 122
119 123 # port-pairs for monitoredqueues:
120 124 hb = Tuple(Int,Int,config=True,
121 125 help="""XREQ/SUB Port pair for Engine heartbeats""")
122 126 def _hb_default(self):
123 127 return tuple(util.select_random_ports(2))
124 128
125 129 mux = Tuple(Int,Int,config=True,
126 130 help="""Engine/Client Port pair for MUX queue""")
127 131
128 132 def _mux_default(self):
129 133 return tuple(util.select_random_ports(2))
130 134
131 135 task = Tuple(Int,Int,config=True,
132 136 help="""Engine/Client Port pair for Task queue""")
133 137 def _task_default(self):
134 138 return tuple(util.select_random_ports(2))
135 139
136 140 control = Tuple(Int,Int,config=True,
137 141 help="""Engine/Client Port pair for Control queue""")
138 142
139 143 def _control_default(self):
140 144 return tuple(util.select_random_ports(2))
141 145
142 146 iopub = Tuple(Int,Int,config=True,
143 147 help="""Engine/Client Port pair for IOPub relay""")
144 148
145 149 def _iopub_default(self):
146 150 return tuple(util.select_random_ports(2))
147 151
148 152 # single ports:
149 153 mon_port = Int(config=True,
150 154 help="""Monitor (SUB) port for queue traffic""")
151 155
152 156 def _mon_port_default(self):
153 157 return util.select_random_ports(1)[0]
154 158
155 159 notifier_port = Int(config=True,
156 160 help="""PUB port for sending engine status notifications""")
157 161
158 162 def _notifier_port_default(self):
159 163 return util.select_random_ports(1)[0]
160 164
161 165 engine_ip = Unicode('127.0.0.1', config=True,
162 166 help="IP on which to listen for engine connections. [default: loopback]")
163 167 engine_transport = Unicode('tcp', config=True,
164 168 help="0MQ transport for engine connections. [default: tcp]")
165 169
166 170 client_ip = Unicode('127.0.0.1', config=True,
167 171 help="IP on which to listen for client connections. [default: loopback]")
168 172 client_transport = Unicode('tcp', config=True,
169 173 help="0MQ transport for client connections. [default : tcp]")
170 174
171 175 monitor_ip = Unicode('127.0.0.1', config=True,
172 176 help="IP on which to listen for monitor messages. [default: loopback]")
173 177 monitor_transport = Unicode('tcp', config=True,
174 178 help="0MQ transport for monitor messages. [default : tcp]")
175 179
176 180 monitor_url = Unicode('')
177 181
178 182 db_class = Unicode('IPython.parallel.controller.dictdb.DictDB', config=True,
179 183 help="""The class to use for the DB backend""")
180 184
181 185 # not configurable
182 186 db = Instance('IPython.parallel.controller.dictdb.BaseDB')
183 187 heartmonitor = Instance('IPython.parallel.controller.heartmonitor.HeartMonitor')
184 188
185 189 def _ip_changed(self, name, old, new):
186 190 self.engine_ip = new
187 191 self.client_ip = new
188 192 self.monitor_ip = new
189 193 self._update_monitor_url()
190 194
191 195 def _update_monitor_url(self):
192 196 self.monitor_url = "%s://%s:%i"%(self.monitor_transport, self.monitor_ip, self.mon_port)
193 197
194 198 def _transport_changed(self, name, old, new):
195 199 self.engine_transport = new
196 200 self.client_transport = new
197 201 self.monitor_transport = new
198 202 self._update_monitor_url()
199 203
200 204 def __init__(self, **kwargs):
201 205 super(HubFactory, self).__init__(**kwargs)
202 206 self._update_monitor_url()
203 207
204 208
205 209 def construct(self):
206 210 self.init_hub()
207 211
208 212 def start(self):
209 213 self.heartmonitor.start()
210 214 self.log.info("Heartmonitor started")
211 215
212 216 def init_hub(self):
213 217 """construct"""
214 218 client_iface = "%s://%s:"%(self.client_transport, self.client_ip) + "%i"
215 219 engine_iface = "%s://%s:"%(self.engine_transport, self.engine_ip) + "%i"
216 220
217 221 ctx = self.context
218 222 loop = self.loop
219 223
220 224 # Registrar socket
221 225 q = ZMQStream(ctx.socket(zmq.XREP), loop)
222 226 q.bind(client_iface % self.regport)
223 227 self.log.info("Hub listening on %s for registration."%(client_iface%self.regport))
224 228 if self.client_ip != self.engine_ip:
225 229 q.bind(engine_iface % self.regport)
226 230 self.log.info("Hub listening on %s for registration."%(engine_iface%self.regport))
227 231
228 232 ### Engine connections ###
229 233
230 234 # heartbeat
231 235 hpub = ctx.socket(zmq.PUB)
232 236 hpub.bind(engine_iface % self.hb[0])
233 237 hrep = ctx.socket(zmq.XREP)
234 238 hrep.bind(engine_iface % self.hb[1])
235 239 self.heartmonitor = HeartMonitor(loop=loop, config=self.config, log=self.log,
236 240 pingstream=ZMQStream(hpub,loop),
237 241 pongstream=ZMQStream(hrep,loop)
238 242 )
239 243
240 244 ### Client connections ###
241 245 # Notifier socket
242 246 n = ZMQStream(ctx.socket(zmq.PUB), loop)
243 247 n.bind(client_iface%self.notifier_port)
244 248
245 249 ### build and launch the queues ###
246 250
247 251 # monitor socket
248 252 sub = ctx.socket(zmq.SUB)
249 253 sub.setsockopt(zmq.SUBSCRIBE, "")
250 254 sub.bind(self.monitor_url)
251 255 sub.bind('inproc://monitor')
252 256 sub = ZMQStream(sub, loop)
253 257
254 258 # connect the db
255 259 self.log.info('Hub using DB backend: %r'%(self.db_class.split('.')[-1]))
256 260 # cdir = self.config.Global.cluster_dir
257 261 self.db = import_item(str(self.db_class))(session=self.session.session,
258 262 config=self.config, log=self.log)
259 263 time.sleep(.25)
260 264 try:
261 265 scheme = self.config.TaskScheduler.scheme_name
262 266 except AttributeError:
263 267 from .scheduler import TaskScheduler
264 268 scheme = TaskScheduler.scheme_name.get_default_value()
265 269 # build connection dicts
266 270 self.engine_info = {
267 271 'control' : engine_iface%self.control[1],
268 272 'mux': engine_iface%self.mux[1],
269 273 'heartbeat': (engine_iface%self.hb[0], engine_iface%self.hb[1]),
270 274 'task' : engine_iface%self.task[1],
271 275 'iopub' : engine_iface%self.iopub[1],
272 276 # 'monitor' : engine_iface%self.mon_port,
273 277 }
274 278
275 279 self.client_info = {
276 280 'control' : client_iface%self.control[0],
277 281 'mux': client_iface%self.mux[0],
278 282 'task' : (scheme, client_iface%self.task[0]),
279 283 'iopub' : client_iface%self.iopub[0],
280 284 'notification': client_iface%self.notifier_port
281 285 }
282 286 self.log.debug("Hub engine addrs: %s"%self.engine_info)
283 287 self.log.debug("Hub client addrs: %s"%self.client_info)
284 288
285 289 # resubmit stream
286 290 r = ZMQStream(ctx.socket(zmq.XREQ), loop)
287 291 url = util.disambiguate_url(self.client_info['task'][-1])
288 292 r.setsockopt(zmq.IDENTITY, self.session.session)
289 293 r.connect(url)
290 294
291 295 self.hub = Hub(loop=loop, session=self.session, monitor=sub, heartmonitor=self.heartmonitor,
292 296 query=q, notifier=n, resubmit=r, db=self.db,
293 297 engine_info=self.engine_info, client_info=self.client_info,
294 298 log=self.log)
295 299
296 300
297 301 class Hub(SessionFactory):
298 302 """The IPython Controller Hub with 0MQ connections
299 303
300 304 Parameters
301 305 ==========
302 306 loop: zmq IOLoop instance
303 307 session: Session object
304 308 <removed> context: zmq context for creating new connections (?)
305 309 queue: ZMQStream for monitoring the command queue (SUB)
306 310 query: ZMQStream for engine registration and client query requests (XREP)
307 311 heartbeat: HeartMonitor object checking the pulse of the engines
308 312 notifier: ZMQStream for broadcasting engine registration changes (PUB)
309 313 db: connection to db for out of memory logging of commands
310 314 NotImplemented
311 315 engine_info: dict of zmq connection information for engines to connect
312 316 to the queues.
313 317 client_info: dict of zmq connection information for clients to connect
314 318 to the queues.
315 319 """
316 320 # internal data structures:
317 321 ids=Set() # engine IDs
318 322 keytable=Dict()
319 323 by_ident=Dict()
320 324 engines=Dict()
321 325 clients=Dict()
322 326 hearts=Dict()
323 327 pending=Set()
324 328 queues=Dict() # pending msg_ids keyed by engine_id
325 329 tasks=Dict() # pending msg_ids submitted as tasks, keyed by client_id
326 330 completed=Dict() # completed msg_ids keyed by engine_id
327 331 all_completed=Set() # set of all completed msg_ids
328 332 dead_engines=Set() # uuids of engines that have unregistered or died
329 333 unassigned=Set() # set of task msg_ids not yet assigned a destination
330 334 incoming_registrations=Dict()
331 335 registration_timeout=Int()
332 336 _idcounter=Int(0)
333 337
334 338 # objects from constructor:
335 339 query=Instance(ZMQStream)
336 340 monitor=Instance(ZMQStream)
337 341 notifier=Instance(ZMQStream)
338 342 resubmit=Instance(ZMQStream)
339 343 heartmonitor=Instance(HeartMonitor)
340 344 db=Instance(object)
341 345 client_info=Dict()
342 346 engine_info=Dict()
343 347
344 348
345 349 def __init__(self, **kwargs):
346 350 """
347 351 # universal:
348 352 loop: IOLoop for creating future connections
349 353 session: streamsession for sending serialized data
350 354 # engine:
351 355 queue: ZMQStream for monitoring queue messages
352 356 query: ZMQStream for engine+client registration and client requests
353 357 heartbeat: HeartMonitor object for tracking engines
354 358 # extra:
355 359 db: ZMQStream for db connection (NotImplemented)
356 360 engine_info: zmq address/protocol dict for engine connections
357 361 client_info: zmq address/protocol dict for client connections
358 362 """
359 363
360 364 super(Hub, self).__init__(**kwargs)
361 365 self.registration_timeout = max(5000, 2*self.heartmonitor.period)
362 366
363 367 # validate connection dicts:
364 368 for k,v in self.client_info.iteritems():
365 369 if k == 'task':
366 370 util.validate_url_container(v[1])
367 371 else:
368 372 util.validate_url_container(v)
369 373 # util.validate_url_container(self.client_info)
370 374 util.validate_url_container(self.engine_info)
371 375
372 376 # register our callbacks
373 377 self.query.on_recv(self.dispatch_query)
374 378 self.monitor.on_recv(self.dispatch_monitor_traffic)
375 379
376 380 self.heartmonitor.add_heart_failure_handler(self.handle_heart_failure)
377 381 self.heartmonitor.add_new_heart_handler(self.handle_new_heart)
378 382
379 383 self.monitor_handlers = { 'in' : self.save_queue_request,
380 384 'out': self.save_queue_result,
381 385 'intask': self.save_task_request,
382 386 'outtask': self.save_task_result,
383 387 'tracktask': self.save_task_destination,
384 388 'incontrol': _passer,
385 389 'outcontrol': _passer,
386 390 'iopub': self.save_iopub_message,
387 391 }
388 392
389 393 self.query_handlers = {'queue_request': self.queue_status,
390 394 'result_request': self.get_results,
391 395 'history_request': self.get_history,
392 396 'db_request': self.db_query,
393 397 'purge_request': self.purge_results,
394 398 'load_request': self.check_load,
395 399 'resubmit_request': self.resubmit_task,
396 400 'shutdown_request': self.shutdown_request,
397 401 'registration_request' : self.register_engine,
398 402 'unregistration_request' : self.unregister_engine,
399 403 'connection_request': self.connection_request,
400 404 }
401 405
402 406 # ignore resubmit replies
403 407 self.resubmit.on_recv(lambda msg: None, copy=False)
404 408
405 409 self.log.info("hub::created hub")
406 410
407 411 @property
408 412 def _next_id(self):
409 413 """gemerate a new ID.
410 414
411 415 No longer reuse old ids, just count from 0."""
412 416 newid = self._idcounter
413 417 self._idcounter += 1
414 418 return newid
415 419 # newid = 0
416 420 # incoming = [id[0] for id in self.incoming_registrations.itervalues()]
417 421 # # print newid, self.ids, self.incoming_registrations
418 422 # while newid in self.ids or newid in incoming:
419 423 # newid += 1
420 424 # return newid
421 425
422 426 #-----------------------------------------------------------------------------
423 427 # message validation
424 428 #-----------------------------------------------------------------------------
425 429
426 430 def _validate_targets(self, targets):
427 431 """turn any valid targets argument into a list of integer ids"""
428 432 if targets is None:
429 433 # default to all
430 434 targets = self.ids
431 435
432 436 if isinstance(targets, (int,str,unicode)):
433 437 # only one target specified
434 438 targets = [targets]
435 439 _targets = []
436 440 for t in targets:
437 441 # map raw identities to ids
438 442 if isinstance(t, (str,unicode)):
439 443 t = self.by_ident.get(t, t)
440 444 _targets.append(t)
441 445 targets = _targets
442 446 bad_targets = [ t for t in targets if t not in self.ids ]
443 447 if bad_targets:
444 448 raise IndexError("No Such Engine: %r"%bad_targets)
445 449 if not targets:
446 450 raise IndexError("No Engines Registered")
447 451 return targets
448 452
449 453 #-----------------------------------------------------------------------------
450 454 # dispatch methods (1 per stream)
451 455 #-----------------------------------------------------------------------------
452 456
453 457
454 458 def dispatch_monitor_traffic(self, msg):
455 459 """all ME and Task queue messages come through here, as well as
456 460 IOPub traffic."""
457 461 self.log.debug("monitor traffic: %r"%msg[:2])
458 462 switch = msg[0]
459 463 try:
460 464 idents, msg = self.session.feed_identities(msg[1:])
461 465 except ValueError:
462 466 idents=[]
463 467 if not idents:
464 468 self.log.error("Bad Monitor Message: %r"%msg)
465 469 return
466 470 handler = self.monitor_handlers.get(switch, None)
467 471 if handler is not None:
468 472 handler(idents, msg)
469 473 else:
470 474 self.log.error("Invalid monitor topic: %r"%switch)
471 475
472 476
473 477 def dispatch_query(self, msg):
474 478 """Route registration requests and queries from clients."""
475 479 try:
476 480 idents, msg = self.session.feed_identities(msg)
477 481 except ValueError:
478 482 idents = []
479 483 if not idents:
480 484 self.log.error("Bad Query Message: %r"%msg)
481 485 return
482 486 client_id = idents[0]
483 487 try:
484 488 msg = self.session.unpack_message(msg, content=True)
485 489 except Exception:
486 490 content = error.wrap_exception()
487 491 self.log.error("Bad Query Message: %r"%msg, exc_info=True)
488 492 self.session.send(self.query, "hub_error", ident=client_id,
489 493 content=content)
490 494 return
491 495 # print client_id, header, parent, content
492 496 #switch on message type:
493 497 msg_type = msg['msg_type']
494 498 self.log.info("client::client %r requested %r"%(client_id, msg_type))
495 499 handler = self.query_handlers.get(msg_type, None)
496 500 try:
497 501 assert handler is not None, "Bad Message Type: %r"%msg_type
498 502 except:
499 503 content = error.wrap_exception()
500 504 self.log.error("Bad Message Type: %r"%msg_type, exc_info=True)
501 505 self.session.send(self.query, "hub_error", ident=client_id,
502 506 content=content)
503 507 return
504 508
505 509 else:
506 510 handler(idents, msg)
507 511
508 512 def dispatch_db(self, msg):
509 513 """"""
510 514 raise NotImplementedError
511 515
512 516 #---------------------------------------------------------------------------
513 517 # handler methods (1 per event)
514 518 #---------------------------------------------------------------------------
515 519
516 520 #----------------------- Heartbeat --------------------------------------
517 521
518 522 def handle_new_heart(self, heart):
519 523 """handler to attach to heartbeater.
520 524 Called when a new heart starts to beat.
521 525 Triggers completion of registration."""
522 526 self.log.debug("heartbeat::handle_new_heart(%r)"%heart)
523 527 if heart not in self.incoming_registrations:
524 528 self.log.info("heartbeat::ignoring new heart: %r"%heart)
525 529 else:
526 530 self.finish_registration(heart)
527 531
528 532
529 533 def handle_heart_failure(self, heart):
530 534 """handler to attach to heartbeater.
531 535 called when a previously registered heart fails to respond to beat request.
532 536 triggers unregistration"""
533 537 self.log.debug("heartbeat::handle_heart_failure(%r)"%heart)
534 538 eid = self.hearts.get(heart, None)
535 539 if eid is None:
536 540 self.log.info("heartbeat::ignoring heart failure %r"%heart)
537 541 else:
538 542 queue = self.engines[eid].queue
539 543 self.unregister_engine(heart, dict(content=dict(id=eid, queue=queue)))
540 544
541 545 #----------------------- MUX Queue Traffic ------------------------------
542 546
543 547 def save_queue_request(self, idents, msg):
544 548 if len(idents) < 2:
545 549 self.log.error("invalid identity prefix: %r"%idents)
546 550 return
547 551 queue_id, client_id = idents[:2]
548 552 try:
549 553 msg = self.session.unpack_message(msg)
550 554 except Exception:
551 555 self.log.error("queue::client %r sent invalid message to %r: %r"%(client_id, queue_id, msg), exc_info=True)
552 556 return
553 557
554 558 eid = self.by_ident.get(queue_id, None)
555 559 if eid is None:
556 560 self.log.error("queue::target %r not registered"%queue_id)
557 561 self.log.debug("queue:: valid are: %r"%(self.by_ident.keys()))
558 562 return
559 563 record = init_record(msg)
560 564 msg_id = record['msg_id']
561 565 record['engine_uuid'] = queue_id
562 566 record['client_uuid'] = client_id
563 567 record['queue'] = 'mux'
564 568
565 569 try:
566 570 # it's possible iopub arrived first:
567 571 existing = self.db.get_record(msg_id)
568 572 for key,evalue in existing.iteritems():
569 573 rvalue = record.get(key, None)
570 574 if evalue and rvalue and evalue != rvalue:
571 575 self.log.warn("conflicting initial state for record: %r:%r <%r> %r"%(msg_id, rvalue, key, evalue))
572 576 elif evalue and not rvalue:
573 577 record[key] = evalue
574 578 try:
575 579 self.db.update_record(msg_id, record)
576 580 except Exception:
577 581 self.log.error("DB Error updating record %r"%msg_id, exc_info=True)
578 582 except KeyError:
579 583 try:
580 584 self.db.add_record(msg_id, record)
581 585 except Exception:
582 586 self.log.error("DB Error adding record %r"%msg_id, exc_info=True)
583 587
584 588
585 589 self.pending.add(msg_id)
586 590 self.queues[eid].append(msg_id)
587 591
588 592 def save_queue_result(self, idents, msg):
589 593 if len(idents) < 2:
590 594 self.log.error("invalid identity prefix: %r"%idents)
591 595 return
592 596
593 597 client_id, queue_id = idents[:2]
594 598 try:
595 599 msg = self.session.unpack_message(msg)
596 600 except Exception:
597 601 self.log.error("queue::engine %r sent invalid message to %r: %r"%(
598 602 queue_id,client_id, msg), exc_info=True)
599 603 return
600 604
601 605 eid = self.by_ident.get(queue_id, None)
602 606 if eid is None:
603 607 self.log.error("queue::unknown engine %r is sending a reply: "%queue_id)
604 608 return
605 609
606 610 parent = msg['parent_header']
607 611 if not parent:
608 612 return
609 613 msg_id = parent['msg_id']
610 614 if msg_id in self.pending:
611 615 self.pending.remove(msg_id)
612 616 self.all_completed.add(msg_id)
613 617 self.queues[eid].remove(msg_id)
614 618 self.completed[eid].append(msg_id)
615 619 elif msg_id not in self.all_completed:
616 620 # it could be a result from a dead engine that died before delivering the
617 621 # result
618 622 self.log.warn("queue:: unknown msg finished %r"%msg_id)
619 623 return
620 624 # update record anyway, because the unregistration could have been premature
621 625 rheader = msg['header']
622 626 completed = rheader['date']
623 627 started = rheader.get('started', None)
624 628 result = {
625 629 'result_header' : rheader,
626 630 'result_content': msg['content'],
627 631 'started' : started,
628 632 'completed' : completed
629 633 }
630 634
631 635 result['result_buffers'] = msg['buffers']
632 636 try:
633 637 self.db.update_record(msg_id, result)
634 638 except Exception:
635 639 self.log.error("DB Error updating record %r"%msg_id, exc_info=True)
636 640
637 641
638 642 #--------------------- Task Queue Traffic ------------------------------
639 643
640 644 def save_task_request(self, idents, msg):
641 645 """Save the submission of a task."""
642 646 client_id = idents[0]
643 647
644 648 try:
645 649 msg = self.session.unpack_message(msg)
646 650 except Exception:
647 651 self.log.error("task::client %r sent invalid task message: %r"%(
648 652 client_id, msg), exc_info=True)
649 653 return
650 654 record = init_record(msg)
651 655
652 656 record['client_uuid'] = client_id
653 657 record['queue'] = 'task'
654 658 header = msg['header']
655 659 msg_id = header['msg_id']
656 660 self.pending.add(msg_id)
657 661 self.unassigned.add(msg_id)
658 662 try:
659 663 # it's possible iopub arrived first:
660 664 existing = self.db.get_record(msg_id)
661 665 if existing['resubmitted']:
662 666 for key in ('submitted', 'client_uuid', 'buffers'):
663 667 # don't clobber these keys on resubmit
664 668 # submitted and client_uuid should be different
665 669 # and buffers might be big, and shouldn't have changed
666 670 record.pop(key)
667 671 # still check content,header which should not change
668 672 # but are not as expensive to compare as buffers are
669 673
670 674 for key,evalue in existing.iteritems():
671 675 if key.endswith('buffers'):
672 676 # don't compare buffers
673 677 continue
674 678 rvalue = record.get(key, None)
675 679 if evalue and rvalue and evalue != rvalue:
676 680 self.log.warn("conflicting initial state for record: %r:%r <%r> %r"%(msg_id, rvalue, key, evalue))
677 681 elif evalue and not rvalue:
678 682 record[key] = evalue
679 683 try:
680 684 self.db.update_record(msg_id, record)
681 685 except Exception:
682 686 self.log.error("DB Error updating record %r"%msg_id, exc_info=True)
683 687 except KeyError:
684 688 try:
685 689 self.db.add_record(msg_id, record)
686 690 except Exception:
687 691 self.log.error("DB Error adding record %r"%msg_id, exc_info=True)
688 692 except Exception:
689 693 self.log.error("DB Error saving task request %r"%msg_id, exc_info=True)
690 694
691 695 def save_task_result(self, idents, msg):
692 696 """save the result of a completed task."""
693 697 client_id = idents[0]
694 698 try:
695 699 msg = self.session.unpack_message(msg)
696 700 except Exception:
697 701 self.log.error("task::invalid task result message sent to %r: %r"%(
698 702 client_id, msg), exc_info=True)
699 703 return
700 704
701 705 parent = msg['parent_header']
702 706 if not parent:
703 707 # print msg
704 708 self.log.warn("Task %r had no parent!"%msg)
705 709 return
706 710 msg_id = parent['msg_id']
707 711 if msg_id in self.unassigned:
708 712 self.unassigned.remove(msg_id)
709 713
710 714 header = msg['header']
711 715 engine_uuid = header.get('engine', None)
712 716 eid = self.by_ident.get(engine_uuid, None)
713 717
714 718 if msg_id in self.pending:
715 719 self.pending.remove(msg_id)
716 720 self.all_completed.add(msg_id)
717 721 if eid is not None:
718 722 self.completed[eid].append(msg_id)
719 723 if msg_id in self.tasks[eid]:
720 724 self.tasks[eid].remove(msg_id)
721 725 completed = header['date']
722 726 started = header.get('started', None)
723 727 result = {
724 728 'result_header' : header,
725 729 'result_content': msg['content'],
726 730 'started' : started,
727 731 'completed' : completed,
728 732 'engine_uuid': engine_uuid
729 733 }
730 734
731 735 result['result_buffers'] = msg['buffers']
732 736 try:
733 737 self.db.update_record(msg_id, result)
734 738 except Exception:
735 739 self.log.error("DB Error saving task request %r"%msg_id, exc_info=True)
736 740
737 741 else:
738 742 self.log.debug("task::unknown task %r finished"%msg_id)
739 743
740 744 def save_task_destination(self, idents, msg):
741 745 try:
742 746 msg = self.session.unpack_message(msg, content=True)
743 747 except Exception:
744 748 self.log.error("task::invalid task tracking message", exc_info=True)
745 749 return
746 750 content = msg['content']
747 751 # print (content)
748 752 msg_id = content['msg_id']
749 753 engine_uuid = content['engine_id']
750 754 eid = self.by_ident[engine_uuid]
751 755
752 756 self.log.info("task::task %r arrived on %r"%(msg_id, eid))
753 757 if msg_id in self.unassigned:
754 758 self.unassigned.remove(msg_id)
755 759 # else:
756 760 # self.log.debug("task::task %r not listed as MIA?!"%(msg_id))
757 761
758 762 self.tasks[eid].append(msg_id)
759 763 # self.pending[msg_id][1].update(received=datetime.now(),engine=(eid,engine_uuid))
760 764 try:
761 765 self.db.update_record(msg_id, dict(engine_uuid=engine_uuid))
762 766 except Exception:
763 767 self.log.error("DB Error saving task destination %r"%msg_id, exc_info=True)
764 768
765 769
766 770 def mia_task_request(self, idents, msg):
767 771 raise NotImplementedError
768 772 client_id = idents[0]
769 773 # content = dict(mia=self.mia,status='ok')
770 774 # self.session.send('mia_reply', content=content, idents=client_id)
771 775
772 776
773 777 #--------------------- IOPub Traffic ------------------------------
774 778
775 779 def save_iopub_message(self, topics, msg):
776 780 """save an iopub message into the db"""
777 781 # print (topics)
778 782 try:
779 783 msg = self.session.unpack_message(msg, content=True)
780 784 except Exception:
781 785 self.log.error("iopub::invalid IOPub message", exc_info=True)
782 786 return
783 787
784 788 parent = msg['parent_header']
785 789 if not parent:
786 790 self.log.error("iopub::invalid IOPub message: %r"%msg)
787 791 return
788 792 msg_id = parent['msg_id']
789 793 msg_type = msg['msg_type']
790 794 content = msg['content']
791 795
792 796 # ensure msg_id is in db
793 797 try:
794 798 rec = self.db.get_record(msg_id)
795 799 except KeyError:
796 800 rec = empty_record()
797 801 rec['msg_id'] = msg_id
798 802 self.db.add_record(msg_id, rec)
799 803 # stream
800 804 d = {}
801 805 if msg_type == 'stream':
802 806 name = content['name']
803 807 s = rec[name] or ''
804 808 d[name] = s + content['data']
805 809
806 810 elif msg_type == 'pyerr':
807 811 d['pyerr'] = content
808 812 elif msg_type == 'pyin':
809 813 d['pyin'] = content['code']
810 814 else:
811 815 d[msg_type] = content.get('data', '')
812 816
813 817 try:
814 818 self.db.update_record(msg_id, d)
815 819 except Exception:
816 820 self.log.error("DB Error saving iopub message %r"%msg_id, exc_info=True)
817 821
818 822
819 823
820 824 #-------------------------------------------------------------------------
821 825 # Registration requests
822 826 #-------------------------------------------------------------------------
823 827
824 828 def connection_request(self, client_id, msg):
825 829 """Reply with connection addresses for clients."""
826 830 self.log.info("client::client %r connected"%client_id)
827 831 content = dict(status='ok')
828 832 content.update(self.client_info)
829 833 jsonable = {}
830 834 for k,v in self.keytable.iteritems():
831 835 if v not in self.dead_engines:
832 836 jsonable[str(k)] = v
833 837 content['engines'] = jsonable
834 838 self.session.send(self.query, 'connection_reply', content, parent=msg, ident=client_id)
835 839
836 840 def register_engine(self, reg, msg):
837 841 """Register a new engine."""
838 842 content = msg['content']
839 843 try:
840 844 queue = content['queue']
841 845 except KeyError:
842 846 self.log.error("registration::queue not specified", exc_info=True)
843 847 return
844 848 heart = content.get('heartbeat', None)
845 849 """register a new engine, and create the socket(s) necessary"""
846 850 eid = self._next_id
847 851 # print (eid, queue, reg, heart)
848 852
849 853 self.log.debug("registration::register_engine(%i, %r, %r, %r)"%(eid, queue, reg, heart))
850 854
851 855 content = dict(id=eid,status='ok')
852 856 content.update(self.engine_info)
853 857 # check if requesting available IDs:
854 858 if queue in self.by_ident:
855 859 try:
856 860 raise KeyError("queue_id %r in use"%queue)
857 861 except:
858 862 content = error.wrap_exception()
859 863 self.log.error("queue_id %r in use"%queue, exc_info=True)
860 864 elif heart in self.hearts: # need to check unique hearts?
861 865 try:
862 866 raise KeyError("heart_id %r in use"%heart)
863 867 except:
864 868 self.log.error("heart_id %r in use"%heart, exc_info=True)
865 869 content = error.wrap_exception()
866 870 else:
867 871 for h, pack in self.incoming_registrations.iteritems():
868 872 if heart == h:
869 873 try:
870 874 raise KeyError("heart_id %r in use"%heart)
871 875 except:
872 876 self.log.error("heart_id %r in use"%heart, exc_info=True)
873 877 content = error.wrap_exception()
874 878 break
875 879 elif queue == pack[1]:
876 880 try:
877 881 raise KeyError("queue_id %r in use"%queue)
878 882 except:
879 883 self.log.error("queue_id %r in use"%queue, exc_info=True)
880 884 content = error.wrap_exception()
881 885 break
882 886
883 887 msg = self.session.send(self.query, "registration_reply",
884 888 content=content,
885 889 ident=reg)
886 890
887 891 if content['status'] == 'ok':
888 892 if heart in self.heartmonitor.hearts:
889 893 # already beating
890 894 self.incoming_registrations[heart] = (eid,queue,reg[0],None)
891 895 self.finish_registration(heart)
892 896 else:
893 897 purge = lambda : self._purge_stalled_registration(heart)
894 898 dc = ioloop.DelayedCallback(purge, self.registration_timeout, self.loop)
895 899 dc.start()
896 900 self.incoming_registrations[heart] = (eid,queue,reg[0],dc)
897 901 else:
898 902 self.log.error("registration::registration %i failed: %r"%(eid, content['evalue']))
899 903 return eid
900 904
901 905 def unregister_engine(self, ident, msg):
902 906 """Unregister an engine that explicitly requested to leave."""
903 907 try:
904 908 eid = msg['content']['id']
905 909 except:
906 910 self.log.error("registration::bad engine id for unregistration: %r"%ident, exc_info=True)
907 911 return
908 912 self.log.info("registration::unregister_engine(%r)"%eid)
909 913 # print (eid)
910 914 uuid = self.keytable[eid]
911 915 content=dict(id=eid, queue=uuid)
912 916 self.dead_engines.add(uuid)
913 917 # self.ids.remove(eid)
914 918 # uuid = self.keytable.pop(eid)
915 919 #
916 920 # ec = self.engines.pop(eid)
917 921 # self.hearts.pop(ec.heartbeat)
918 922 # self.by_ident.pop(ec.queue)
919 923 # self.completed.pop(eid)
920 924 handleit = lambda : self._handle_stranded_msgs(eid, uuid)
921 925 dc = ioloop.DelayedCallback(handleit, self.registration_timeout, self.loop)
922 926 dc.start()
923 927 ############## TODO: HANDLE IT ################
924 928
925 929 if self.notifier:
926 930 self.session.send(self.notifier, "unregistration_notification", content=content)
927 931
928 932 def _handle_stranded_msgs(self, eid, uuid):
929 933 """Handle messages known to be on an engine when the engine unregisters.
930 934
931 935 It is possible that this will fire prematurely - that is, an engine will
932 936 go down after completing a result, and the client will be notified
933 937 that the result failed and later receive the actual result.
934 938 """
935 939
936 940 outstanding = self.queues[eid]
937 941
938 942 for msg_id in outstanding:
939 943 self.pending.remove(msg_id)
940 944 self.all_completed.add(msg_id)
941 945 try:
942 946 raise error.EngineError("Engine %r died while running task %r"%(eid, msg_id))
943 947 except:
944 948 content = error.wrap_exception()
945 949 # build a fake header:
946 950 header = {}
947 951 header['engine'] = uuid
948 952 header['date'] = datetime.now()
949 953 rec = dict(result_content=content, result_header=header, result_buffers=[])
950 954 rec['completed'] = header['date']
951 955 rec['engine_uuid'] = uuid
952 956 try:
953 957 self.db.update_record(msg_id, rec)
954 958 except Exception:
955 959 self.log.error("DB Error handling stranded msg %r"%msg_id, exc_info=True)
956 960
957 961
958 962 def finish_registration(self, heart):
959 963 """Second half of engine registration, called after our HeartMonitor
960 964 has received a beat from the Engine's Heart."""
961 965 try:
962 966 (eid,queue,reg,purge) = self.incoming_registrations.pop(heart)
963 967 except KeyError:
964 968 self.log.error("registration::tried to finish nonexistent registration", exc_info=True)
965 969 return
966 970 self.log.info("registration::finished registering engine %i:%r"%(eid,queue))
967 971 if purge is not None:
968 972 purge.stop()
969 973 control = queue
970 974 self.ids.add(eid)
971 975 self.keytable[eid] = queue
972 976 self.engines[eid] = EngineConnector(id=eid, queue=queue, registration=reg,
973 977 control=control, heartbeat=heart)
974 978 self.by_ident[queue] = eid
975 979 self.queues[eid] = list()
976 980 self.tasks[eid] = list()
977 981 self.completed[eid] = list()
978 982 self.hearts[heart] = eid
979 983 content = dict(id=eid, queue=self.engines[eid].queue)
980 984 if self.notifier:
981 985 self.session.send(self.notifier, "registration_notification", content=content)
982 986 self.log.info("engine::Engine Connected: %i"%eid)
983 987
984 988 def _purge_stalled_registration(self, heart):
985 989 if heart in self.incoming_registrations:
986 990 eid = self.incoming_registrations.pop(heart)[0]
987 991 self.log.info("registration::purging stalled registration: %i"%eid)
988 992 else:
989 993 pass
990 994
991 995 #-------------------------------------------------------------------------
992 996 # Client Requests
993 997 #-------------------------------------------------------------------------
994 998
995 999 def shutdown_request(self, client_id, msg):
996 1000 """handle shutdown request."""
997 1001 self.session.send(self.query, 'shutdown_reply', content={'status': 'ok'}, ident=client_id)
998 1002 # also notify other clients of shutdown
999 1003 self.session.send(self.notifier, 'shutdown_notice', content={'status': 'ok'})
1000 1004 dc = ioloop.DelayedCallback(lambda : self._shutdown(), 1000, self.loop)
1001 1005 dc.start()
1002 1006
1003 1007 def _shutdown(self):
1004 1008 self.log.info("hub::hub shutting down.")
1005 1009 time.sleep(0.1)
1006 1010 sys.exit(0)
1007 1011
1008 1012
1009 1013 def check_load(self, client_id, msg):
1010 1014 content = msg['content']
1011 1015 try:
1012 1016 targets = content['targets']
1013 1017 targets = self._validate_targets(targets)
1014 1018 except:
1015 1019 content = error.wrap_exception()
1016 1020 self.session.send(self.query, "hub_error",
1017 1021 content=content, ident=client_id)
1018 1022 return
1019 1023
1020 1024 content = dict(status='ok')
1021 1025 # loads = {}
1022 1026 for t in targets:
1023 1027 content[bytes(t)] = len(self.queues[t])+len(self.tasks[t])
1024 1028 self.session.send(self.query, "load_reply", content=content, ident=client_id)
1025 1029
1026 1030
1027 1031 def queue_status(self, client_id, msg):
1028 1032 """Return the Queue status of one or more targets.
1029 1033 If verbose: return the msg_ids;
1030 1034 else: return the count of each type.
1031 1035 keys: queue (pending MUX jobs)
1032 1036 tasks (pending Task jobs)
1033 1037 completed (finished jobs from both queues)"""
1034 1038 content = msg['content']
1035 1039 targets = content['targets']
1036 1040 try:
1037 1041 targets = self._validate_targets(targets)
1038 1042 except:
1039 1043 content = error.wrap_exception()
1040 1044 self.session.send(self.query, "hub_error",
1041 1045 content=content, ident=client_id)
1042 1046 return
1043 1047 verbose = content.get('verbose', False)
1044 1048 content = dict(status='ok')
1045 1049 for t in targets:
1046 1050 queue = self.queues[t]
1047 1051 completed = self.completed[t]
1048 1052 tasks = self.tasks[t]
1049 1053 if not verbose:
1050 1054 queue = len(queue)
1051 1055 completed = len(completed)
1052 1056 tasks = len(tasks)
1053 1057 content[bytes(t)] = {'queue': queue, 'completed': completed , 'tasks': tasks}
1054 1058 content['unassigned'] = list(self.unassigned) if verbose else len(self.unassigned)
1055 1059
1056 1060 self.session.send(self.query, "queue_reply", content=content, ident=client_id)
1057 1061
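For concreteness, a queue_reply for two engines with verbose=False might carry content shaped like this (structure taken from the code above; the numbers are invented):

    content = {
        'status': 'ok',
        '0': {'queue': 2, 'completed': 40, 'tasks': 1},
        '1': {'queue': 0, 'completed': 38, 'tasks': 3},
        'unassigned': 5,  # tasks not yet assigned to any engine
    }
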
1058 1062 def purge_results(self, client_id, msg):
1059 1063 """Purge results from memory. This method is more valuable before we move
1060 1064 to a DB-based message storage mechanism."""
1061 1065 content = msg['content']
1062 1066 msg_ids = content.get('msg_ids', [])
1063 1067 reply = dict(status='ok')
1064 1068 if msg_ids == 'all':
1065 1069 try:
1066 1070 self.db.drop_matching_records(dict(completed={'$ne':None}))
1067 1071 except Exception:
1068 1072 reply = error.wrap_exception()
1069 1073 else:
1070 1074 pending = filter(lambda m: m in self.pending, msg_ids)
1071 1075 if pending:
1072 1076 try:
1073 1077 raise IndexError("msg pending: %r"%pending[0])
1074 1078 except:
1075 1079 reply = error.wrap_exception()
1076 1080 else:
1077 1081 try:
1078 1082 self.db.drop_matching_records(dict(msg_id={'$in':msg_ids}))
1079 1083 except Exception:
1080 1084 reply = error.wrap_exception()
1081 1085
1082 1086 if reply['status'] == 'ok':
1083 1087 eids = content.get('engine_ids', [])
1084 1088 for eid in eids:
1085 1089 if eid not in self.engines:
1086 1090 try:
1087 1091 raise IndexError("No such engine: %i"%eid)
1088 1092 except:
1089 1093 reply = error.wrap_exception()
1090 1094 break
1091 1095 msg_ids = self.completed.pop(eid)
1092 1096 uid = self.engines[eid].queue
1093 1097 try:
1094 1098 self.db.drop_matching_records(dict(engine_uuid=uid, completed={'$ne':None}))
1095 1099 except Exception:
1096 1100 reply = error.wrap_exception()
1097 1101 break
1098 1102
1099 1103 self.session.send(self.query, 'purge_reply', content=reply, ident=client_id)
1100 1104
1101 1105 def resubmit_task(self, client_id, msg):
1102 1106 """Resubmit one or more tasks."""
1103 1107 def finish(reply):
1104 1108 self.session.send(self.query, 'resubmit_reply', content=reply, ident=client_id)
1105 1109
1106 1110 content = msg['content']
1107 1111 msg_ids = content['msg_ids']
1108 1112 reply = dict(status='ok')
1109 1113 try:
1110 1114 records = self.db.find_records({'msg_id' : {'$in' : msg_ids}}, keys=[
1111 1115 'header', 'content', 'buffers'])
1112 1116 except Exception:
1113 1117 self.log.error('db::db error finding tasks to resubmit', exc_info=True)
1114 1118 return finish(error.wrap_exception())
1115 1119
1116 1120 # validate msg_ids
1117 1121 found_ids = [ rec['msg_id'] for rec in records ]
1118 1122 invalid_ids = filter(lambda m: m in self.pending, found_ids)
1119 1123 if len(records) > len(msg_ids):
1120 1124 try:
1121 1125 raise RuntimeError("DB appears to be in an inconsistent state. "
1122 1126 "More matching records were found than should exist")
1123 1127 except Exception:
1124 1128 return finish(error.wrap_exception())
1125 1129 elif len(records) < len(msg_ids):
1126 1130 missing = [ m for m in msg_ids if m not in found_ids ]
1127 1131 try:
1128 1132 raise KeyError("No such msg(s): %r"%missing)
1129 1133 except KeyError:
1130 1134 return finish(error.wrap_exception())
1131 1135 elif invalid_ids:
1132 1136 msg_id = invalid_ids[0]
1133 1137 try:
1134 1138 raise ValueError("Task %r appears to be inflight"%(msg_id))
1135 1139 except Exception:
1136 1140 return finish(error.wrap_exception())
1137 1141
1138 1142 # clear the existing records
1139 1143 now = datetime.now()
1140 1144 rec = empty_record()
1141 1145 map(rec.pop, ['msg_id', 'header', 'content', 'buffers', 'submitted'])
1142 1146 rec['resubmitted'] = now
1143 1147 rec['queue'] = 'task'
1144 1148 rec['client_uuid'] = client_id[0]
1145 1149 try:
1146 1150 for msg_id in msg_ids:
1147 1151 self.all_completed.discard(msg_id)
1148 1152 self.db.update_record(msg_id, rec)
1149 1153 except Exception:
1150 1154 self.log.error('db::db error updating record', exc_info=True)
1151 1155 reply = error.wrap_exception()
1152 1156 else:
1153 1157 # send the messages
1154 1158 for rec in records:
1155 1159 header = rec['header']
1156 1160 # include resubmitted in header to prevent digest collision
1157 1161 header['resubmitted'] = now
1158 1162 msg = self.session.msg(header['msg_type'])
1159 1163 msg['content'] = rec['content']
1160 1164 msg['header'] = header
1161 1165 msg['msg_id'] = rec['msg_id']
1162 1166 self.session.send(self.resubmit, msg, buffers=rec['buffers'])
1163 1167
1164 1168 finish(dict(status='ok'))
1165 1169
1166 1170
1167 1171 def _extract_record(self, rec):
1168 1172 """decompose a TaskRecord dict into subsection of reply for get_result"""
1169 1173 io_dict = {}
1170 1174 for key in 'pyin pyout pyerr stdout stderr'.split():
1171 1175 io_dict[key] = rec[key]
1172 1176 content = { 'result_content': rec['result_content'],
1173 1177 'header': rec['header'],
1174 1178 'result_header' : rec['result_header'],
1175 1179 'io' : io_dict,
1176 1180 }
1177 1181 if rec['result_buffers']:
1178 1182 buffers = map(str, rec['result_buffers'])
1179 1183 else:
1180 1184 buffers = []
1181 1185
1182 1186 return content, buffers
1183 1187
1184 1188 def get_results(self, client_id, msg):
1185 1189 """Get the result of 1 or more messages."""
1186 1190 content = msg['content']
1187 1191 msg_ids = sorted(set(content['msg_ids']))
1188 1192 statusonly = content.get('status_only', False)
1189 1193 pending = []
1190 1194 completed = []
1191 1195 content = dict(status='ok')
1192 1196 content['pending'] = pending
1193 1197 content['completed'] = completed
1194 1198 buffers = []
1195 1199 if not statusonly:
1196 1200 try:
1197 1201 matches = self.db.find_records(dict(msg_id={'$in':msg_ids}))
1198 1202 # turn match list into dict, for faster lookup
1199 1203 records = {}
1200 1204 for rec in matches:
1201 1205 records[rec['msg_id']] = rec
1202 1206 except Exception:
1203 1207 content = error.wrap_exception()
1204 1208 self.session.send(self.query, "result_reply", content=content,
1205 1209 parent=msg, ident=client_id)
1206 1210 return
1207 1211 else:
1208 1212 records = {}
1209 1213 for msg_id in msg_ids:
1210 1214 if msg_id in self.pending:
1211 1215 pending.append(msg_id)
1212 1216 elif msg_id in self.all_completed:
1213 1217 completed.append(msg_id)
1214 1218 if not statusonly:
1215 1219 c,bufs = self._extract_record(records[msg_id])
1216 1220 content[msg_id] = c
1217 1221 buffers.extend(bufs)
1218 1222 elif msg_id in records:
1219 1223 if records[msg_id]['completed']:
1220 1224 completed.append(msg_id)
1221 1225 c,bufs = self._extract_record(records[msg_id])
1222 1226 content[msg_id] = c
1223 1227 buffers.extend(bufs)
1224 1228 else:
1225 1229 pending.append(msg_id)
1226 1230 else:
1227 1231 try:
1228 1232 raise KeyError('No such message: '+msg_id)
1229 1233 except:
1230 1234 content = error.wrap_exception()
1231 1235 break
1232 1236 self.session.send(self.query, "result_reply", content=content,
1233 1237 parent=msg, ident=client_id,
1234 1238 buffers=buffers)
1235 1239
1236 1240 def get_history(self, client_id, msg):
1237 1241 """Get a list of all msg_ids in our DB records"""
1238 1242 try:
1239 1243 msg_ids = self.db.get_history()
1240 1244 except Exception as e:
1241 1245 content = error.wrap_exception()
1242 1246 else:
1243 1247 content = dict(status='ok', history=msg_ids)
1244 1248
1245 1249 self.session.send(self.query, "history_reply", content=content,
1246 1250 parent=msg, ident=client_id)
1247 1251
1248 1252 def db_query(self, client_id, msg):
1249 1253 """Perform a raw query on the task record database."""
1250 1254 content = msg['content']
1251 1255 query = content.get('query', {})
1252 1256 keys = content.get('keys', None)
1253 1257 buffers = []
1254 1258 empty = list()
1255 1259 try:
1256 1260 records = self.db.find_records(query, keys)
1257 1261 except Exception as e:
1258 1262 content = error.wrap_exception()
1259 1263 else:
1260 1264 # extract buffers from reply content:
1261 1265 if keys is not None:
1262 1266 buffer_lens = [] if 'buffers' in keys else None
1263 1267 result_buffer_lens = [] if 'result_buffers' in keys else None
1264 1268 else:
1265 1269 buffer_lens = []
1266 1270 result_buffer_lens = []
1267 1271
1268 1272 for rec in records:
1269 1273 # buffers may be None, so double check
1270 1274 if buffer_lens is not None:
1271 1275 b = rec.pop('buffers', empty) or empty
1272 1276 buffer_lens.append(len(b))
1273 1277 buffers.extend(b)
1274 1278 if result_buffer_lens is not None:
1275 1279 rb = rec.pop('result_buffers', empty) or empty
1276 1280 result_buffer_lens.append(len(rb))
1277 1281 buffers.extend(rb)
1278 1282 content = dict(status='ok', records=records, buffer_lens=buffer_lens,
1279 1283 result_buffer_lens=result_buffer_lens)
1280 1284
1281 1285 self.session.send(self.query, "db_reply", content=content,
1282 1286 parent=msg, ident=client_id,
1283 1287 buffers=buffers)
1284 1288
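These handlers back the corresponding Client methods. A hedged sketch of driving db_query from the client side of this era, assuming a running cluster (the query uses the same mongodb-style operators the DB backends translate):

    from IPython.parallel import Client

    rc = Client()
    # fetch only two columns for every finished task:
    records = rc.db_query({'completed': {'$ne': None}},
                          keys=['msg_id', 'completed'])
    print(len(records))
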
@@ -1,112 +1,117 b''
1 """A TaskRecord backend using mongodb"""
1 """A TaskRecord backend using mongodb
2
3 Authors:
4
5 * Min RK
6 """
2 7 #-----------------------------------------------------------------------------
3 # Copyright (C) 2010 The IPython Development Team
8 # Copyright (C) 2010-2011 The IPython Development Team
4 9 #
5 10 # Distributed under the terms of the BSD License. The full license is in
6 11 # the file COPYING, distributed as part of this software.
7 12 #-----------------------------------------------------------------------------
8 13
9 14 from pymongo import Connection
10 15 from pymongo.binary import Binary
11 16
12 17 from IPython.utils.traitlets import Dict, List, Unicode, Instance
13 18
14 19 from .dictdb import BaseDB
15 20
16 21 #-----------------------------------------------------------------------------
17 22 # MongoDB class
18 23 #-----------------------------------------------------------------------------
19 24
20 25 class MongoDB(BaseDB):
21 26 """MongoDB TaskRecord backend."""
22 27
23 28 connection_args = List(config=True,
24 29 help="""Positional arguments to be passed to pymongo.Connection. Only
25 30 necessary if the default mongodb configuration does not point to your
26 31 mongod instance.""")
27 32 connection_kwargs = Dict(config=True,
28 33 help="""Keyword arguments to be passed to pymongo.Connection. Only
29 34 necessary if the default mongodb configuration does not point to your
30 35 mongod instance."""
31 36 )
32 37 database = Unicode(config=True,
33 38 help="""The MongoDB database name to use for storing tasks for this session. If unspecified,
34 39 a new database will be created with the Hub's IDENT. Specifying the database will result
35 40 in tasks from previous sessions being available via Clients' db_query and
36 41 get_result methods.""")
37 42
38 43 _connection = Instance(Connection) # pymongo connection
39 44
40 45 def __init__(self, **kwargs):
41 46 super(MongoDB, self).__init__(**kwargs)
42 47 if self._connection is None:
43 48 self._connection = Connection(*self.connection_args, **self.connection_kwargs)
44 49 if not self.database:
45 50 self.database = self.session
46 51 self._db = self._connection[self.database]
47 52 self._records = self._db['task_records']
48 53 self._records.ensure_index('msg_id', unique=True)
49 54 self._records.ensure_index('submitted') # for sorting history
50 55 # for rec in self._records.find
51 56
52 57 def _binary_buffers(self, rec):
53 58 for key in ('buffers', 'result_buffers'):
54 59 if rec.get(key, None):
55 60 rec[key] = map(Binary, rec[key])
56 61 return rec
57 62
58 63 def add_record(self, msg_id, rec):
59 64 """Add a new Task Record, by msg_id."""
60 65 # print rec
61 66 rec = self._binary_buffers(rec)
62 67 self._records.insert(rec)
63 68
64 69 def get_record(self, msg_id):
65 70 """Get a specific Task Record, by msg_id."""
66 71 r = self._records.find_one({'msg_id': msg_id})
67 72 if not r:
68 73 # r will be None if nothing is found
69 74 raise KeyError(msg_id)
70 75 return r
71 76
72 77 def update_record(self, msg_id, rec):
73 78 """Update the data in an existing record."""
74 79 rec = self._binary_buffers(rec)
75 80
76 81 self._records.update({'msg_id':msg_id}, {'$set': rec})
77 82
78 83 def drop_matching_records(self, check):
79 84 """Remove a record from the DB."""
80 85 self._records.remove(check)
81 86
82 87 def drop_record(self, msg_id):
83 88 """Remove a record from the DB."""
84 89 self._records.remove({'msg_id':msg_id})
85 90
86 91 def find_records(self, check, keys=None):
87 92 """Find records matching a query dict, optionally extracting subset of keys.
88 93
89 94 Returns list of matching records.
90 95
91 96 Parameters
92 97 ----------
93 98
94 99 check: dict
95 100 mongodb-style query argument
96 101 keys: list of strs [optional]
97 102 if specified, the subset of keys to extract. msg_id will *always* be
98 103 included.
99 104 """
100 105 if keys and 'msg_id' not in keys:
101 106 keys.append('msg_id')
102 107 matches = list(self._records.find(check,keys))
103 108 for rec in matches:
104 109 rec.pop('_id')
105 110 return matches
106 111
107 112 def get_history(self):
108 113 """get all msg_ids, ordered by time submitted."""
109 114 cursor = self._records.find({},{'msg_id':1}).sort('submitted')
110 115 return [ rec['msg_id'] for rec in cursor ]
111 116
112 117
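To use this backend, point the controller's HubFactory at it from a config file. A minimal sketch (the database name and connection details are invented):

    # in e.g. ipcontroller_config.py
    c = get_config()
    c.HubFactory.db_class = 'IPython.parallel.controller.mongodb.MongoDB'
    c.MongoDB.database = u'ipython_tasks'  # fixed name: reuse across sessions
    c.MongoDB.connection_kwargs = {'host': 'localhost', 'port': 27017}
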
@@ -1,688 +1,692 b''
1 1 """The Python scheduler for rich scheduling.
2 2
3 3 The Pure ZMQ scheduler does not allow routing schemes other than LRU,
4 4 nor does it check msg_id DAG dependencies. For those, a slightly slower
5 5 Python Scheduler exists.
6
7 Authors:
8
9 * Min RK
6 10 """
7 11 #-----------------------------------------------------------------------------
8 12 # Copyright (C) 2010-2011 The IPython Development Team
9 13 #
10 14 # Distributed under the terms of the BSD License. The full license is in
11 15 # the file COPYING, distributed as part of this software.
12 16 #-----------------------------------------------------------------------------
13 17
14 18 #----------------------------------------------------------------------
15 19 # Imports
16 20 #----------------------------------------------------------------------
17 21
18 22 from __future__ import print_function
19 23
20 24 import logging
21 25 import sys
22 26
23 27 from datetime import datetime, timedelta
24 28 from random import randint, random
25 29 from types import FunctionType
26 30
27 31 try:
28 32 import numpy
29 33 except ImportError:
30 34 numpy = None
31 35
32 36 import zmq
33 37 from zmq.eventloop import ioloop, zmqstream
34 38
35 39 # local imports
36 40 from IPython.external.decorator import decorator
37 41 from IPython.config.loader import Config
38 42 from IPython.utils.traitlets import Instance, Dict, List, Set, Int, Str, Enum
39 43
40 44 from IPython.parallel import error
41 45 from IPython.parallel.factory import SessionFactory
42 46 from IPython.parallel.util import connect_logger, local_logger
43 47
44 48 from .dependency import Dependency
45 49
46 50 @decorator
47 51 def logged(f,self,*args,**kwargs):
48 52 # print ("#--------------------")
49 53 self.log.debug("scheduler::%s(*%s,**%s)"%(f.func_name, args, kwargs))
50 54 # print ("#--")
51 55 return f(self,*args, **kwargs)
52 56
53 57 #----------------------------------------------------------------------
54 58 # Chooser functions
55 59 #----------------------------------------------------------------------
56 60
57 61 def plainrandom(loads):
58 62 """Plain random pick."""
59 63 n = len(loads)
60 64 return randint(0,n-1)
61 65
62 66 def lru(loads):
63 67 """Always pick the front of the line.
64 68
65 69 The content of `loads` is ignored.
66 70
67 71 Assumes LRU ordering of loads, with oldest first.
68 72 """
69 73 return 0
70 74
71 75 def twobin(loads):
72 76 """Pick two at random, use the LRU of the two.
73 77
74 78 The content of loads is ignored.
75 79
76 80 Assumes LRU ordering of loads, with oldest first.
77 81 """
78 82 n = len(loads)
79 83 a = randint(0,n-1)
80 84 b = randint(0,n-1)
81 85 return min(a,b)
82 86
83 87 def weighted(loads):
84 88 """Pick two at random using inverse load as weight.
85 89
86 90 Return the less loaded of the two.
87 91 """
88 92 # weight 0 a million times more than 1:
89 93 weights = 1./(1e-6+numpy.array(loads))
90 94 sums = weights.cumsum()
91 95 t = sums[-1]
92 96 x = random()*t
93 97 y = random()*t
94 98 idx = 0
95 99 idy = 0
96 100 while sums[idx] < x:
97 101 idx += 1
98 102 while sums[idy] < y:
99 103 idy += 1
100 104 if weights[idy] > weights[idx]:
101 105 return idy
102 106 else:
103 107 return idx
104 108
105 109 def leastload(loads):
106 110 """Always choose the lowest load.
107 111
108 112 If the lowest load occurs more than once, the first
109 113 occurrence will be used. If loads has LRU ordering, this means
110 114 the LRU of those with the lowest load is chosen.
111 115 """
112 116 return loads.index(min(loads))
113 117
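Each chooser above maps a list of per-engine loads to the index of the engine to pick, which is the whole contract the TaskScheduler relies on. A quick illustration (assumes the functions above are in scope):

    loads = [3, 0, 2, 0]
    assert leastload(loads) == 1            # first occurrence of the minimum
    assert lru(loads) == 0                  # always the head of the line
    assert 0 <= twobin(loads) < len(loads)  # min of two random picks
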
114 118 #---------------------------------------------------------------------
115 119 # Classes
116 120 #---------------------------------------------------------------------
117 121 # store empty default dependency:
118 122 MET = Dependency([])
119 123
120 124 class TaskScheduler(SessionFactory):
121 125 """Python TaskScheduler object.
122 126
123 127 This is the simplest object that supports msg_id based
124 128 DAG dependencies. *Only* task msg_ids are checked, not
125 129 msg_ids of jobs submitted via the MUX queue.
126 130
127 131 """
128 132
129 133 hwm = Int(0, config=True, shortname='hwm',
130 134 help="""specify the High Water Mark (HWM) for the downstream
131 135 socket in the Task scheduler. This is the maximum number
132 136 of allowed outstanding tasks on each engine."""
133 137 )
134 138 scheme_name = Enum(('leastload', 'pure', 'lru', 'plainrandom', 'weighted', 'twobin'),
135 139 'leastload', config=True, shortname='scheme', allow_none=False,
136 140 help="""select the task scheduler scheme [default: Python LRU]
137 141 Options are: 'pure', 'lru', 'plainrandom', 'weighted', 'twobin','leastload'"""
138 142 )
139 143 def _scheme_name_changed(self, old, new):
140 144 self.log.debug("Using scheme %r"%new)
141 145 self.scheme = globals()[new]
142 146
143 147 # input arguments:
144 148 scheme = Instance(FunctionType) # function for determining the destination
145 149 def _scheme_default(self):
146 150 return leastload
147 151 client_stream = Instance(zmqstream.ZMQStream) # client-facing stream
148 152 engine_stream = Instance(zmqstream.ZMQStream) # engine-facing stream
149 153 notifier_stream = Instance(zmqstream.ZMQStream) # hub-facing sub stream
150 154 mon_stream = Instance(zmqstream.ZMQStream) # hub-facing pub stream
151 155
152 156 # internals:
153 157 graph = Dict() # dict by msg_id of [ msg_ids that depend on key ]
154 158 retries = Dict() # dict by msg_id of retries remaining (non-neg ints)
155 159 # waiting = List() # list of msg_ids ready to run, but haven't due to HWM
156 160 depending = Dict() # dict by msg_id of (msg_id, raw_msg, after, follow)
157 161 pending = Dict() # dict by engine_uuid of submitted tasks
158 162 completed = Dict() # dict by engine_uuid of completed tasks
159 163 failed = Dict() # dict by engine_uuid of failed tasks
160 164 destinations = Dict() # dict by msg_id of engine_uuids where jobs ran (reverse of completed+failed)
161 165 clients = Dict() # dict by msg_id for who submitted the task
162 166 targets = List() # list of target IDENTs
163 167 loads = List() # list of engine loads
164 168 # full = Set() # set of IDENTs that have HWM outstanding tasks
165 169 all_completed = Set() # set of all completed tasks
166 170 all_failed = Set() # set of all failed tasks
167 171 all_done = Set() # set of all finished tasks=union(completed,failed)
168 172 all_ids = Set() # set of all submitted task IDs
169 173 blacklist = Dict() # dict by msg_id of locations where a job has encountered UnmetDependency
170 174 auditor = Instance('zmq.eventloop.ioloop.PeriodicCallback')
171 175
172 176
173 177 def start(self):
174 178 self.engine_stream.on_recv(self.dispatch_result, copy=False)
175 179 self._notification_handlers = dict(
176 180 registration_notification = self._register_engine,
177 181 unregistration_notification = self._unregister_engine
178 182 )
179 183 self.notifier_stream.on_recv(self.dispatch_notification)
180 184 self.auditor = ioloop.PeriodicCallback(self.audit_timeouts, 2e3, self.loop) # every 2s (0.5 Hz)
181 185 self.auditor.start()
182 186 self.log.info("Scheduler started [%s]"%self.scheme_name)
183 187
184 188 def resume_receiving(self):
185 189 """Resume accepting jobs."""
186 190 self.client_stream.on_recv(self.dispatch_submission, copy=False)
187 191
188 192 def stop_receiving(self):
189 193 """Stop accepting jobs while there are no engines.
190 194 Leave them in the ZMQ queue."""
191 195 self.client_stream.on_recv(None)
192 196
193 197 #-----------------------------------------------------------------------
194 198 # [Un]Registration Handling
195 199 #-----------------------------------------------------------------------
196 200
197 201 def dispatch_notification(self, msg):
198 202 """dispatch register/unregister events."""
199 203 try:
200 204 idents,msg = self.session.feed_identities(msg)
201 205 except ValueError:
202 206 self.log.warn("task::Invalid Message: %r"%msg)
203 207 return
204 208 try:
205 209 msg = self.session.unpack_message(msg)
206 210 except ValueError:
207 211 self.log.warn("task::Unauthorized message from: %r"%idents)
208 212 return
209 213
210 214 msg_type = msg['msg_type']
211 215
212 216 handler = self._notification_handlers.get(msg_type, None)
213 217 if handler is None:
214 218 self.log.error("Unhandled message type: %r"%msg_type)
215 219 else:
216 220 try:
217 221 handler(str(msg['content']['queue']))
218 222 except KeyError:
219 223 self.log.error("task::Invalid notification msg: %r"%msg)
220 224
221 225 @logged
222 226 def _register_engine(self, uid):
223 227 """New engine with ident `uid` became available."""
224 228 # head of the line:
225 229 self.targets.insert(0,uid)
226 230 self.loads.insert(0,0)
227 231 # initialize sets
228 232 self.completed[uid] = set()
229 233 self.failed[uid] = set()
230 234 self.pending[uid] = {}
231 235 if len(self.targets) == 1:
232 236 self.resume_receiving()
233 237 # rescan the graph:
234 238 self.update_graph(None)
235 239
236 240 def _unregister_engine(self, uid):
237 241 """Existing engine with ident `uid` became unavailable."""
238 242 if len(self.targets) == 1:
239 243 # this was our only engine
240 244 self.stop_receiving()
241 245
242 246 # handle any potentially finished tasks:
243 247 self.engine_stream.flush()
244 248
245 249 # don't pop destinations, because they might be used later
246 250 # map(self.destinations.pop, self.completed.pop(uid))
247 251 # map(self.destinations.pop, self.failed.pop(uid))
248 252
249 253 # prevent this engine from receiving work
250 254 idx = self.targets.index(uid)
251 255 self.targets.pop(idx)
252 256 self.loads.pop(idx)
253 257
254 258 # wait 5 seconds before cleaning up pending jobs, since the results might
255 259 # still be incoming
256 260 if self.pending[uid]:
257 261 dc = ioloop.DelayedCallback(lambda : self.handle_stranded_tasks(uid), 5000, self.loop)
258 262 dc.start()
259 263 else:
260 264 self.completed.pop(uid)
261 265 self.failed.pop(uid)
262 266
263 267
264 268 @logged
265 269 def handle_stranded_tasks(self, engine):
266 270 """Deal with jobs resident in an engine that died."""
267 271 lost = self.pending[engine]
268 272 for msg_id in lost.keys():
269 273 if msg_id not in self.pending[engine]:
270 274 # prevent double-handling of messages
271 275 continue
272 276
273 277 raw_msg = lost[msg_id][0]
274 278 idents,msg = self.session.feed_identities(raw_msg, copy=False)
275 279 parent = self.session.unpack(msg[1].bytes)
276 280 idents = [engine, idents[0]]
277 281
278 282 # build fake error reply
279 283 try:
280 284 raise error.EngineError("Engine %r died while running task %r"%(engine, msg_id))
281 285 except:
282 286 content = error.wrap_exception()
283 287 msg = self.session.msg('apply_reply', content, parent=parent, subheader={'status':'error'})
284 288 raw_reply = map(zmq.Message, self.session.serialize(msg, ident=idents))
285 289 # and dispatch it
286 290 self.dispatch_result(raw_reply)
287 291
288 292 # finally scrub completed/failed lists
289 293 self.completed.pop(engine)
290 294 self.failed.pop(engine)
291 295
292 296
293 297 #-----------------------------------------------------------------------
294 298 # Job Submission
295 299 #-----------------------------------------------------------------------
296 300 @logged
297 301 def dispatch_submission(self, raw_msg):
298 302 """Dispatch job submission to appropriate handlers."""
299 303 # ensure targets up to date:
300 304 self.notifier_stream.flush()
301 305 try:
302 306 idents, msg = self.session.feed_identities(raw_msg, copy=False)
303 307 msg = self.session.unpack_message(msg, content=False, copy=False)
304 308 except Exception:
305 309 self.log.error("task::Invalid task msg: %r"%raw_msg, exc_info=True)
306 310 return
307 311
308 312
309 313 # send to monitor
310 314 self.mon_stream.send_multipart(['intask']+raw_msg, copy=False)
311 315
312 316 header = msg['header']
313 317 msg_id = header['msg_id']
314 318 self.all_ids.add(msg_id)
315 319
316 320 # targets
317 321 targets = set(header.get('targets', []))
318 322 retries = header.get('retries', 0)
319 323 self.retries[msg_id] = retries
320 324
321 325 # time dependencies
322 326 after = Dependency(header.get('after', []))
323 327 if after.all:
324 328 if after.success:
325 329 after.difference_update(self.all_completed)
326 330 if after.failure:
327 331 after.difference_update(self.all_failed)
328 332 if after.check(self.all_completed, self.all_failed):
329 333 # recast as empty set, if `after` already met,
330 334 # to prevent unnecessary set comparisons
331 335 after = MET
332 336
333 337 # location dependencies
334 338 follow = Dependency(header.get('follow', []))
335 339
336 340 # turn timeouts into datetime objects:
337 341 timeout = header.get('timeout', None)
338 342 if timeout:
339 343 timeout = datetime.now() + timedelta(0,timeout,0)
340 344
341 345 args = [raw_msg, targets, after, follow, timeout]
342 346
343 347 # validate and reduce dependencies:
344 348 for dep in after,follow:
345 349 # check valid:
346 350 if msg_id in dep or dep.difference(self.all_ids):
347 351 self.depending[msg_id] = args
348 352 return self.fail_unreachable(msg_id, error.InvalidDependency)
349 353 # check if unreachable:
350 354 if dep.unreachable(self.all_completed, self.all_failed):
351 355 self.depending[msg_id] = args
352 356 return self.fail_unreachable(msg_id)
353 357
354 358 if after.check(self.all_completed, self.all_failed):
355 359 # time deps already met, try to run
356 360 if not self.maybe_run(msg_id, *args):
357 361 # can't run yet
358 362 if msg_id not in self.all_failed:
359 363 # could have failed as unreachable
360 364 self.save_unmet(msg_id, *args)
361 365 else:
362 366 self.save_unmet(msg_id, *args)
363 367
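The after/follow headers consumed above are set on the client side. A hedged sketch using the LoadBalancedView API of the same era, assuming a running cluster:

    from IPython.parallel import Client

    rc = Client()
    view = rc.load_balanced_view()

    def f():
        return 2

    ar = view.apply_async(f)
    # run the second task only after ar succeeds, on the engine where ar ran:
    with view.temp_flags(after=[ar], follow=[ar]):
        ar2 = view.apply_async(f)
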
364 368 # @logged
365 369 def audit_timeouts(self):
366 370 """Audit all waiting tasks for expired timeouts."""
367 371 now = datetime.now()
368 372 for msg_id in self.depending.keys():
369 373 # must recheck, in case one failure cascaded to another:
370 374 if msg_id in self.depending:
371 375 raw,targets,after,follow,timeout = self.depending[msg_id]
372 376 if timeout and timeout < now:
373 377 self.fail_unreachable(msg_id, error.TaskTimeout)
374 378
375 379 @logged
376 380 def fail_unreachable(self, msg_id, why=error.ImpossibleDependency):
377 381 """a task has become unreachable, send a reply with an ImpossibleDependency
378 382 error."""
379 383 if msg_id not in self.depending:
380 384 self.log.error("msg %r already failed!"%msg_id)
381 385 return
382 386 raw_msg,targets,after,follow,timeout = self.depending.pop(msg_id)
383 387 for mid in follow.union(after):
384 388 if mid in self.graph:
385 389 self.graph[mid].remove(msg_id)
386 390
387 391 # FIXME: unpacking a message I've already unpacked, but didn't save:
388 392 idents,msg = self.session.feed_identities(raw_msg, copy=False)
389 393 header = self.session.unpack(msg[1].bytes)
390 394
391 395 try:
392 396 raise why()
393 397 except:
394 398 content = error.wrap_exception()
395 399
396 400 self.all_done.add(msg_id)
397 401 self.all_failed.add(msg_id)
398 402
399 403 msg = self.session.send(self.client_stream, 'apply_reply', content,
400 404 parent=header, ident=idents)
401 405 self.session.send(self.mon_stream, msg, ident=['outtask']+idents)
402 406
403 407 self.update_graph(msg_id, success=False)
404 408
405 409 @logged
406 410 def maybe_run(self, msg_id, raw_msg, targets, after, follow, timeout):
407 411 """check location dependencies, and run if they are met."""
408 412 blacklist = self.blacklist.setdefault(msg_id, set())
409 413 if follow or targets or blacklist or self.hwm:
410 414 # we need a can_run filter
411 415 def can_run(idx):
412 416 # check hwm
413 417 if self.hwm and self.loads[idx] == self.hwm:
414 418 return False
415 419 target = self.targets[idx]
416 420 # check blacklist
417 421 if target in blacklist:
418 422 return False
419 423 # check targets
420 424 if targets and target not in targets:
421 425 return False
422 426 # check follow
423 427 return follow.check(self.completed[target], self.failed[target])
424 428
425 429 indices = filter(can_run, range(len(self.targets)))
426 430
427 431 if not indices:
428 432 # couldn't run
429 433 if follow.all:
430 434 # check follow for impossibility
431 435 dests = set()
432 436 relevant = set()
433 437 if follow.success:
434 438 relevant = self.all_completed
435 439 if follow.failure:
436 440 relevant = relevant.union(self.all_failed)
437 441 for m in follow.intersection(relevant):
438 442 dests.add(self.destinations[m])
439 443 if len(dests) > 1:
440 444 self.depending[msg_id] = (raw_msg, targets, after, follow, timeout)
441 445 self.fail_unreachable(msg_id)
442 446 return False
443 447 if targets:
444 448 # check blacklist+targets for impossibility
445 449 targets.difference_update(blacklist)
446 450 if not targets or not targets.intersection(self.targets):
447 451 self.depending[msg_id] = (raw_msg, targets, after, follow, timeout)
448 452 self.fail_unreachable(msg_id)
449 453 return False
450 454 return False
451 455 else:
452 456 indices = None
453 457
454 458 self.submit_task(msg_id, raw_msg, targets, follow, timeout, indices)
455 459 return True
456 460
457 461 @logged
458 462 def save_unmet(self, msg_id, raw_msg, targets, after, follow, timeout):
459 463 """Save a message for later submission when its dependencies are met."""
460 464 self.depending[msg_id] = [raw_msg,targets,after,follow,timeout]
461 465 # track the ids in follow or after, but not those already finished
462 466 for dep_id in after.union(follow).difference(self.all_done):
463 467 if dep_id not in self.graph:
464 468 self.graph[dep_id] = set()
465 469 self.graph[dep_id].add(msg_id)
466 470
467 471 @logged
468 472 def submit_task(self, msg_id, raw_msg, targets, follow, timeout, indices=None):
469 473 """Submit a task to any of a subset of our targets."""
470 474 if indices:
471 475 loads = [self.loads[i] for i in indices]
472 476 else:
473 477 loads = self.loads
474 478 idx = self.scheme(loads)
475 479 if indices:
476 480 idx = indices[idx]
477 481 target = self.targets[idx]
478 482 # print (target, map(str, msg[:3]))
479 483 # send job to the engine
480 484 self.engine_stream.send(target, flags=zmq.SNDMORE, copy=False)
481 485 self.engine_stream.send_multipart(raw_msg, copy=False)
482 486 # update load
483 487 self.add_job(idx)
484 488 self.pending[target][msg_id] = (raw_msg, targets, MET, follow, timeout)
485 489 # notify Hub
486 490 content = dict(msg_id=msg_id, engine_id=target)
487 491 self.session.send(self.mon_stream, 'task_destination', content=content,
488 492 ident=['tracktask',self.session.session])
489 493
490 494
491 495 #-----------------------------------------------------------------------
492 496 # Result Handling
493 497 #-----------------------------------------------------------------------
494 498 @logged
495 499 def dispatch_result(self, raw_msg):
496 500 """dispatch method for result replies"""
497 501 try:
498 502 idents,msg = self.session.feed_identities(raw_msg, copy=False)
499 503 msg = self.session.unpack_message(msg, content=False, copy=False)
500 504 engine = idents[0]
501 505 try:
502 506 idx = self.targets.index(engine)
503 507 except ValueError:
504 508 pass # skip load-update for dead engines
505 509 else:
506 510 self.finish_job(idx)
507 511 except Exception:
508 512 self.log.error("task::Invalid result: %r"%raw_msg, exc_info=True)
509 513 return
510 514
511 515 header = msg['header']
512 516 parent = msg['parent_header']
513 517 if header.get('dependencies_met', True):
514 518 success = (header['status'] == 'ok')
515 519 msg_id = parent['msg_id']
516 520 retries = self.retries[msg_id]
517 521 if not success and retries > 0:
518 522 # failed
519 523 self.retries[msg_id] = retries - 1
520 524 self.handle_unmet_dependency(idents, parent)
521 525 else:
522 526 del self.retries[msg_id]
523 527 # relay to client and update graph
524 528 self.handle_result(idents, parent, raw_msg, success)
525 529 # send to Hub monitor
526 530 self.mon_stream.send_multipart(['outtask']+raw_msg, copy=False)
527 531 else:
528 532 self.handle_unmet_dependency(idents, parent)
529 533
530 534 @logged
531 535 def handle_result(self, idents, parent, raw_msg, success=True):
532 536 """handle a real task result, either success or failure"""
533 537 # first, relay result to client
534 538 engine = idents[0]
535 539 client = idents[1]
536 540 # swap_ids for XREP-XREP mirror
537 541 raw_msg[:2] = [client,engine]
538 542 # print (map(str, raw_msg[:4]))
539 543 self.client_stream.send_multipart(raw_msg, copy=False)
540 544 # now, update our data structures
541 545 msg_id = parent['msg_id']
542 546 self.blacklist.pop(msg_id, None)
543 547 self.pending[engine].pop(msg_id)
544 548 if success:
545 549 self.completed[engine].add(msg_id)
546 550 self.all_completed.add(msg_id)
547 551 else:
548 552 self.failed[engine].add(msg_id)
549 553 self.all_failed.add(msg_id)
550 554 self.all_done.add(msg_id)
551 555 self.destinations[msg_id] = engine
552 556
553 557 self.update_graph(msg_id, success)
554 558
555 559 @logged
556 560 def handle_unmet_dependency(self, idents, parent):
557 561 """handle an unmet dependency"""
558 562 engine = idents[0]
559 563 msg_id = parent['msg_id']
560 564
561 565 if msg_id not in self.blacklist:
562 566 self.blacklist[msg_id] = set()
563 567 self.blacklist[msg_id].add(engine)
564 568
565 569 args = self.pending[engine].pop(msg_id)
566 570 raw,targets,after,follow,timeout = args
567 571
568 572 if self.blacklist[msg_id] == targets:
569 573 self.depending[msg_id] = args
570 574 self.fail_unreachable(msg_id)
571 575 elif not self.maybe_run(msg_id, *args):
572 576 # resubmit failed
573 577 if msg_id not in self.all_failed:
574 578 # put it back in our dependency tree
575 579 self.save_unmet(msg_id, *args)
576 580
577 581 if self.hwm:
578 582 try:
579 583 idx = self.targets.index(engine)
580 584 except ValueError:
581 585 pass # skip load-update for dead engines
582 586 else:
583 587 if self.loads[idx] == self.hwm-1:
584 588 self.update_graph(None)
585 589
586 590
587 591
588 592 @logged
589 593 def update_graph(self, dep_id=None, success=True):
590 594 """dep_id just finished. Update our dependency
591 595 graph and submit any jobs that just became runnable.
592 596
593 597 Called with dep_id=None to update entire graph for hwm, but without finishing
594 598 a task.
595 599 """
596 600 # print ("\n\n***********")
597 601 # pprint (dep_id)
598 602 # pprint (self.graph)
599 603 # pprint (self.depending)
600 604 # pprint (self.all_completed)
601 605 # pprint (self.all_failed)
602 606 # print ("\n\n***********\n\n")
603 607 # update any jobs that depended on the dependency
604 608 jobs = self.graph.pop(dep_id, [])
605 609
606 610 # recheck *all* jobs if
607 611 # a) we have HWM and an engine just became no longer full
608 612 # or b) dep_id was given as None
609 613 if dep_id is None or self.hwm and any( [ load==self.hwm-1 for load in self.loads ]):
610 614 jobs = self.depending.keys()
611 615
612 616 for msg_id in jobs:
613 617 raw_msg, targets, after, follow, timeout = self.depending[msg_id]
614 618
615 619 if after.unreachable(self.all_completed, self.all_failed)\
616 620 or follow.unreachable(self.all_completed, self.all_failed):
617 621 self.fail_unreachable(msg_id)
618 622
619 623 elif after.check(self.all_completed, self.all_failed): # time deps met, maybe run
620 624 if self.maybe_run(msg_id, raw_msg, targets, MET, follow, timeout):
621 625
622 626 self.depending.pop(msg_id)
623 627 for mid in follow.union(after):
624 628 if mid in self.graph:
625 629 self.graph[mid].remove(msg_id)
626 630
627 631 #----------------------------------------------------------------------
628 632 # methods to be overridden by subclasses
629 633 #----------------------------------------------------------------------
630 634
631 635 def add_job(self, idx):
632 636 """Called after self.targets[idx] just got the job with header.
633 637 Override in subclasses. The default ordering is simple LRU.
634 638 The default loads are the number of outstanding jobs."""
635 639 self.loads[idx] += 1
636 640 for lis in (self.targets, self.loads):
637 641 lis.append(lis.pop(idx))
638 642
639 643
640 644 def finish_job(self, idx):
641 645 """Called after self.targets[idx] just finished a job.
642 646 Override in subclasses."""
643 647 self.loads[idx] -= 1
644 648
645 649
646 650
647 651 def launch_scheduler(in_addr, out_addr, mon_addr, not_addr, config=None,
648 652 logname='root', log_url=None, loglevel=logging.DEBUG,
649 653 identity=b'task'):
650 654 from zmq.eventloop import ioloop
651 655 from zmq.eventloop.zmqstream import ZMQStream
652 656
653 657 if config:
654 658 # unwrap dict back into Config
655 659 config = Config(config)
656 660
657 661 ctx = zmq.Context()
658 662 loop = ioloop.IOLoop()
659 663 ins = ZMQStream(ctx.socket(zmq.XREP),loop)
660 664 ins.setsockopt(zmq.IDENTITY, identity)
661 665 ins.bind(in_addr)
662 666
663 667 outs = ZMQStream(ctx.socket(zmq.XREP),loop)
664 668 outs.setsockopt(zmq.IDENTITY, identity)
665 669 outs.bind(out_addr)
666 670 mons = ZMQStream(ctx.socket(zmq.PUB),loop)
667 671 mons.connect(mon_addr)
668 672 nots = ZMQStream(ctx.socket(zmq.SUB),loop)
669 673 nots.setsockopt(zmq.SUBSCRIBE, '')
670 674 nots.connect(not_addr)
671 675
672 676 # setup logging. Note that these will not work in-process, because they clobber
673 677 # existing loggers.
674 678 if log_url:
675 679 log = connect_logger(logname, ctx, log_url, root="scheduler", loglevel=loglevel)
676 680 else:
677 681 log = local_logger(logname, loglevel)
678 682
679 683 scheduler = TaskScheduler(client_stream=ins, engine_stream=outs,
680 684 mon_stream=mons, notifier_stream=nots,
681 685 loop=loop, log=log,
682 686 config=config)
683 687 scheduler.start()
684 688 try:
685 689 loop.start()
686 690 except KeyboardInterrupt:
687 691 print ("interrupted, exiting...", file=sys.__stderr__)
688 692
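Since scheme_name and hwm are configurable traits on TaskScheduler, a controller config file can select them without touching this module. A minimal sketch:

    # in e.g. ipcontroller_config.py
    c = get_config()
    c.TaskScheduler.scheme_name = 'weighted'
    c.TaskScheduler.hwm = 2  # at most 2 outstanding tasks per engine
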
@@ -1,386 +1,391 b''
1 """A TaskRecord backend using sqlite3"""
1 """A TaskRecord backend using sqlite3
2
3 Authors:
4
5 * Min RK
6 """
2 7 #-----------------------------------------------------------------------------
3 8 # Copyright (C) 2011 The IPython Development Team
4 9 #
5 10 # Distributed under the terms of the BSD License. The full license is in
6 11 # the file COPYING, distributed as part of this software.
7 12 #-----------------------------------------------------------------------------
8 13
9 14 import json
10 15 import os
11 16 import cPickle as pickle
12 17 from datetime import datetime
13 18
14 19 import sqlite3
15 20
16 21 from zmq.eventloop import ioloop
17 22
18 23 from IPython.utils.traitlets import Unicode, Instance, List, Dict
19 24 from .dictdb import BaseDB
20 25 from IPython.utils.jsonutil import date_default, extract_dates, squash_dates
21 26
22 27 #-----------------------------------------------------------------------------
23 28 # SQLite operators, adapters, and converters
24 29 #-----------------------------------------------------------------------------
25 30
26 31 operators = {
27 32 '$lt' : "<",
28 33 '$gt' : ">",
29 34 # NULL is handled weirdly with ==, !=
30 35 '$eq' : "=",
31 36 '$ne' : "!=",
32 37 '$lte': "<=",
33 38 '$gte': ">=",
34 39 '$in' : ('=', ' OR '),
35 40 '$nin': ('!=', ' AND '),
36 41 # '$all': None,
37 42 # '$mod': None,
38 43 # '$exists' : None
39 44 }
40 45 null_operators = {
41 46 '=' : "IS NULL",
42 47 '!=' : "IS NOT NULL",
43 48 }
44 49
45 50 def _adapt_dict(d):
46 51 return json.dumps(d, default=date_default)
47 52
48 53 def _convert_dict(ds):
49 54 if ds is None:
50 55 return ds
51 56 else:
52 57 return extract_dates(json.loads(ds))
53 58
54 59 def _adapt_bufs(bufs):
55 60 # this is *horrible*
56 61 # copy buffers into single list and pickle it:
57 62 if bufs and isinstance(bufs[0], (bytes, buffer)):
58 63 return sqlite3.Binary(pickle.dumps(map(bytes, bufs),-1))
59 64 elif bufs:
60 65 return bufs
61 66 else:
62 67 return None
63 68
64 69 def _convert_bufs(bs):
65 70 if bs is None:
66 71 return []
67 72 else:
68 73 return pickle.loads(bytes(bs))
69 74
70 75 #-----------------------------------------------------------------------------
71 76 # SQLiteDB class
72 77 #-----------------------------------------------------------------------------
73 78
74 79 class SQLiteDB(BaseDB):
75 80 """SQLite3 TaskRecord backend."""
76 81
77 82 filename = Unicode('tasks.db', config=True,
78 83 help="""The filename of the sqlite task database. [default: 'tasks.db']""")
79 84 location = Unicode('', config=True,
80 85 help="""The directory containing the sqlite task database. The default
81 86 is to use the cluster_dir location.""")
82 87 table = Unicode("", config=True,
83 88 help="""The SQLite Table to use for storing tasks for this session. If unspecified,
84 89 a new table will be created with the Hub's IDENT. Specifying the table will result
85 90 in tasks from previous sessions being available via Clients' db_query and
86 91 get_result methods.""")
87 92
88 93 _db = Instance('sqlite3.Connection')
89 94 # the ordered list of column names
90 95 _keys = List(['msg_id' ,
91 96 'header' ,
92 97 'content',
93 98 'buffers',
94 99 'submitted',
95 100 'client_uuid' ,
96 101 'engine_uuid' ,
97 102 'started',
98 103 'completed',
99 104 'resubmitted',
100 105 'result_header' ,
101 106 'result_content' ,
102 107 'result_buffers' ,
103 108 'queue' ,
104 109 'pyin' ,
105 110 'pyout',
106 111 'pyerr',
107 112 'stdout',
108 113 'stderr',
109 114 ])
110 115 # sqlite datatypes for checking that db is current format
111 116 _types = Dict({'msg_id' : 'text' ,
112 117 'header' : 'dict text',
113 118 'content' : 'dict text',
114 119 'buffers' : 'bufs blob',
115 120 'submitted' : 'timestamp',
116 121 'client_uuid' : 'text',
117 122 'engine_uuid' : 'text',
118 123 'started' : 'timestamp',
119 124 'completed' : 'timestamp',
120 125 'resubmitted' : 'timestamp',
121 126 'result_header' : 'dict text',
122 127 'result_content' : 'dict text',
123 128 'result_buffers' : 'bufs blob',
124 129 'queue' : 'text',
125 130 'pyin' : 'text',
126 131 'pyout' : 'text',
127 132 'pyerr' : 'text',
128 133 'stdout' : 'text',
129 134 'stderr' : 'text',
130 135 })
131 136
132 137 def __init__(self, **kwargs):
133 138 super(SQLiteDB, self).__init__(**kwargs)
134 139 if not self.table:
135 140 # use session, and prefix _, since starting with # is illegal
136 141 self.table = '_'+self.session.replace('-','_')
137 142 if not self.location:
138 143 # get current profile
139 144 from IPython.core.newapplication import BaseIPythonApplication
140 145 if BaseIPythonApplication.initialized():
141 146 app = BaseIPythonApplication.instance()
142 147 if app.profile_dir is not None:
143 148 self.location = app.profile_dir.location
144 149 else:
145 150 self.location = u'.'
146 151 else:
147 152 self.location = u'.'
148 153 self._init_db()
149 154
150 155 # register db commit as 2s periodic callback
151 156 # to prevent clogging pipes
152 157 # assumes we are being run in a zmq ioloop app
153 158 loop = ioloop.IOLoop.instance()
154 159 pc = ioloop.PeriodicCallback(self._db.commit, 2000, loop)
155 160 pc.start()
156 161
157 162 def _defaults(self, keys=None):
158 163 """create an empty record"""
159 164 d = {}
160 165 keys = self._keys if keys is None else keys
161 166 for key in keys:
162 167 d[key] = None
163 168 return d
164 169
165 170 def _check_table(self):
166 171 """Ensure that an incorrect table doesn't exist
167 172
168 173 If a bad (old) table does exist, return False
169 174 """
170 175 cursor = self._db.execute("PRAGMA table_info(%s)"%self.table)
171 176 lines = cursor.fetchall()
172 177 if not lines:
173 178 # table does not exist
174 179 return True
175 180 types = {}
176 181 keys = []
177 182 for line in lines:
178 183 keys.append(line[1])
179 184 types[line[1]] = line[2]
180 185 if self._keys != keys:
181 186 # key mismatch
182 187 self.log.warn('keys mismatch')
183 188 return False
184 189 for key in self._keys:
185 190 if types[key] != self._types[key]:
186 191 self.log.warn(
187 192 'type mismatch: %s: %s != %s'%(key,types[key],self._types[key])
188 193 )
189 194 return False
190 195 return True
191 196
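For reference, each row that PRAGMA table_info returns has the shape (cid, name, type, notnull, dflt_value, pk), which is why the loop above reads line[1] for the column name and line[2] for its declared type:

    line = (0, u'msg_id', u'text', 0, None, 1)  # invented example row
    name, dtype = line[1], line[2]              # -> u'msg_id', u'text'
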
192 197 def _init_db(self):
193 198 """Connect to the database and get new session number."""
194 199 # register adapters
195 200 sqlite3.register_adapter(dict, _adapt_dict)
196 201 sqlite3.register_converter('dict', _convert_dict)
197 202 sqlite3.register_adapter(list, _adapt_bufs)
198 203 sqlite3.register_converter('bufs', _convert_bufs)
199 204 # connect to the db
200 205 dbfile = os.path.join(self.location, self.filename)
201 206 self._db = sqlite3.connect(dbfile, detect_types=sqlite3.PARSE_DECLTYPES,
202 207 # isolation_level = None)#,
203 208 cached_statements=64)
204 209 # print dir(self._db)
205 210 first_table = self.table
206 211 i=0
207 212 while not self._check_table():
208 213 i+=1
209 214 self.table = first_table+'_%i'%i
210 215 self.log.warn(
211 216 "Table %s exists and doesn't match db format, trying %s"%
212 217 (first_table,self.table)
213 218 )
214 219
215 220 self._db.execute("""CREATE TABLE IF NOT EXISTS %s
216 221 (msg_id text PRIMARY KEY,
217 222 header dict text,
218 223 content dict text,
219 224 buffers bufs blob,
220 225 submitted timestamp,
221 226 client_uuid text,
222 227 engine_uuid text,
223 228 started timestamp,
224 229 completed timestamp,
225 230 resubmitted timestamp,
226 231 result_header dict text,
227 232 result_content dict text,
228 233 result_buffers bufs blob,
229 234 queue text,
230 235 pyin text,
231 236 pyout text,
232 237 pyerr text,
233 238 stdout text,
234 239 stderr text)
235 240 """%self.table)
236 241 self._db.commit()
237 242
238 243 def _dict_to_list(self, d):
239 244 """turn a mongodb-style record dict into a list."""
240 245
241 246 return [ d[key] for key in self._keys ]
242 247
243 248 def _list_to_dict(self, line, keys=None):
244 249 """Inverse of dict_to_list"""
245 250 keys = self._keys if keys is None else keys
246 251 d = self._defaults(keys)
247 252 for key,value in zip(keys, line):
248 253 d[key] = value
249 254
250 255 return d
251 256
252 257 def _render_expression(self, check):
253 258 """Turn a mongodb-style search dict into an SQL query."""
254 259 expressions = []
255 260 args = []
256 261
257 262 skeys = set(check.keys())
258 263 skeys.difference_update(set(self._keys))
259 264 skeys.difference_update(set(['buffers', 'result_buffers']))
260 265 if skeys:
261 266 raise KeyError("Illegal testing key(s): %s"%skeys)
262 267
263 268 for name,sub_check in check.iteritems():
264 269 if isinstance(sub_check, dict):
265 270 for test,value in sub_check.iteritems():
266 271 try:
267 272 op = operators[test]
268 273 except KeyError:
269 274 raise KeyError("Unsupported operator: %r"%test)
270 275 if isinstance(op, tuple):
271 276 op, join = op
272 277
273 278 if value is None and op in null_operators:
274 279 expr = "%s %s"%(name, null_operators[op])
275 280 else:
276 281 expr = "%s %s ?"%(name, op)
277 282 if isinstance(value, (tuple,list)):
278 283 if op in null_operators and any([v is None for v in value]):
279 284 # equality tests don't work with NULL
280 285 raise ValueError("Cannot use %r test with NULL values on SQLite backend"%test)
281 286 expr = '( %s )'%( join.join([expr]*len(value)) )
282 287 args.extend(value)
283 288 else:
284 289 args.append(value)
285 290 expressions.append(expr)
286 291 else:
287 292 # it's an equality check
288 293 if sub_check is None:
289 294 expressions.append("%s IS NULL"%name)
290 295 else:
291 296 expressions.append("%s = ?"%name)
292 297 args.append(sub_check)
293 298
294 299 expr = " AND ".join(expressions)
295 300 return expr, args
296 301
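To make the translation concrete: a check like the one below (values invented) renders to a WHERE clause plus its parameter list, modulo dict ordering:

    check = {'engine_uuid': 'abc', 'completed': {'$ne': None}}
    # _render_expression(check) yields roughly:
    expr = "engine_uuid = ? AND completed IS NOT NULL"
    args = ['abc']
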
297 302 def add_record(self, msg_id, rec):
298 303 """Add a new Task Record, by msg_id."""
299 304 d = self._defaults()
300 305 d.update(rec)
301 306 d['msg_id'] = msg_id
302 307 line = self._dict_to_list(d)
303 308 tups = '(%s)'%(','.join(['?']*len(line)))
304 309 self._db.execute("INSERT INTO %s VALUES %s"%(self.table, tups), line)
305 310 # self._db.commit()
306 311
307 312 def get_record(self, msg_id):
308 313 """Get a specific Task Record, by msg_id."""
309 314 cursor = self._db.execute("""SELECT * FROM %s WHERE msg_id==?"""%self.table, (msg_id,))
310 315 line = cursor.fetchone()
311 316 if line is None:
312 317 raise KeyError("No such msg: %r"%msg_id)
313 318 return self._list_to_dict(line)
314 319
315 320 def update_record(self, msg_id, rec):
316 321 """Update the data in an existing record."""
317 322 query = "UPDATE %s SET "%self.table
318 323 sets = []
319 324 keys = sorted(rec.keys())
320 325 values = []
321 326 for key in keys:
322 327 sets.append('%s = ?'%key)
323 328 values.append(rec[key])
324 329 query += ', '.join(sets)
325 330 query += ' WHERE msg_id == ?'
326 331 values.append(msg_id)
327 332 self._db.execute(query, values)
328 333 # self._db.commit()
329 334
330 335 def drop_record(self, msg_id):
331 336 """Remove a record from the DB."""
332 337 self._db.execute("""DELETE FROM %s WHERE msg_id==?"""%self.table, (msg_id,))
333 338 # self._db.commit()
334 339
335 340 def drop_matching_records(self, check):
336 341 """Remove a record from the DB."""
337 342 expr,args = self._render_expression(check)
338 343 query = "DELETE FROM %s WHERE %s"%(self.table, expr)
339 344 self._db.execute(query,args)
340 345 # self._db.commit()
341 346
342 347 def find_records(self, check, keys=None):
343 348 """Find records matching a query dict, optionally extracting subset of keys.
344 349
345 350 Returns list of matching records.
346 351
347 352 Parameters
348 353 ----------
349 354
350 355 check: dict
351 356 mongodb-style query argument
352 357 keys: list of strs [optional]
353 358 if specified, the subset of keys to extract. msg_id will *always* be
354 359 included.
355 360 """
356 361 if keys:
357 362 bad_keys = [ key for key in keys if key not in self._keys ]
358 363 if bad_keys:
359 364 raise KeyError("Bad record key(s): %s"%bad_keys)
360 365
361 366 if keys:
362 367 # ensure msg_id is present and first:
363 368 if 'msg_id' in keys:
364 369 keys.remove('msg_id')
365 370 keys.insert(0, 'msg_id')
366 371 req = ', '.join(keys)
367 372 else:
368 373 req = '*'
369 374 expr,args = self._render_expression(check)
370 375 query = """SELECT %s FROM %s WHERE %s"""%(req, self.table, expr)
371 376 cursor = self._db.execute(query, args)
372 377 matches = cursor.fetchall()
373 378 records = []
374 379 for line in matches:
375 380 rec = self._list_to_dict(line, keys)
376 381 records.append(rec)
377 382 return records
378 383
379 384 def get_history(self):
380 385 """get all msg_ids, ordered by time submitted."""
381 386 query = """SELECT msg_id FROM %s ORDER by submitted ASC"""%self.table
382 387 cursor = self._db.execute(query)
383 388 # will be a list of length 1 tuples
384 389 return [ tup[0] for tup in cursor.fetchall()]
385 390
386 391 __all__ = ['SQLiteDB'] No newline at end of file
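As with the MongoDB backend, this one is selected via the controller's HubFactory; pinning the table name keeps results queryable across sessions. A minimal sketch (the table name and location are invented):

    # in e.g. ipcontroller_config.py
    c = get_config()
    c.HubFactory.db_class = 'IPython.parallel.controller.sqlitedb.SQLiteDB'
    c.SQLiteDB.table = u'ipython_tasks'
    c.SQLiteDB.location = u'/var/lib/ipython'
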
@@ -1,166 +1,170 b''
1 1 #!/usr/bin/env python
2 2 """A simple engine that talks to a controller over 0MQ.
3 3 it handles registration, etc. and launches a kernel
4 4 connected to the Controller's Schedulers.
5
6 Authors:
7
8 * Min RK
5 9 """
6 10 #-----------------------------------------------------------------------------
7 11 # Copyright (C) 2010-2011 The IPython Development Team
8 12 #
9 13 # Distributed under the terms of the BSD License. The full license is in
10 14 # the file COPYING, distributed as part of this software.
11 15 #-----------------------------------------------------------------------------
12 16
13 17 from __future__ import print_function
14 18
15 19 import sys
16 20 import time
17 21
18 22 import zmq
19 23 from zmq.eventloop import ioloop, zmqstream
20 24
21 25 # internal
22 26 from IPython.utils.traitlets import Instance, Dict, Int, Type, CFloat, Unicode
23 27 # from IPython.utils.localinterfaces import LOCALHOST
24 28
25 29 from IPython.parallel.controller.heartmonitor import Heart
26 30 from IPython.parallel.factory import RegistrationFactory
27 31 from IPython.parallel.util import disambiguate_url
28 32
29 33 from IPython.zmq.session import Message
30 34
31 35 from .streamkernel import Kernel
32 36
33 37 class EngineFactory(RegistrationFactory):
34 38 """IPython engine"""
35 39
36 40 # configurables:
37 41 out_stream_factory=Type('IPython.zmq.iostream.OutStream', config=True,
38 42 help="""The OutStream for handling stdout/err.
39 43 Typically 'IPython.zmq.iostream.OutStream'""")
40 44 display_hook_factory=Type('IPython.zmq.displayhook.DisplayHook', config=True,
41 45 help="""The class for handling displayhook.
42 46 Typically 'IPython.zmq.displayhook.DisplayHook'""")
43 47 location=Unicode(config=True,
44 48 help="""The location (an IP address) of the controller. This is
45 49 used for disambiguating URLs, to determine whether
46 50 loopback should be used to connect or the public address.""")
47 51 timeout=CFloat(2,config=True,
48 52 help="""The time (in seconds) to wait for the Controller to respond
49 53 to registration requests before giving up.""")
50 54
51 55 # not configurable:
52 56 user_ns=Dict()
53 57 id=Int(allow_none=True)
54 58 registrar=Instance('zmq.eventloop.zmqstream.ZMQStream')
55 59 kernel=Instance(Kernel)
56 60
57 61
58 62 def __init__(self, **kwargs):
59 63 super(EngineFactory, self).__init__(**kwargs)
60 64 self.ident = self.session.session
61 65 ctx = self.context
62 66
63 67 reg = ctx.socket(zmq.XREQ)
64 68 reg.setsockopt(zmq.IDENTITY, self.ident)
65 69 reg.connect(self.url)
66 70 self.registrar = zmqstream.ZMQStream(reg, self.loop)
67 71
68 72 def register(self):
69 73 """send the registration_request"""
70 74
71 75 self.log.info("registering")
72 76 content = dict(queue=self.ident, heartbeat=self.ident, control=self.ident)
73 77 self.registrar.on_recv(self.complete_registration)
74 78 # print (self.session.key)
75 79 self.session.send(self.registrar, "registration_request",content=content)
76 80
77 81 def complete_registration(self, msg):
78 82 # print msg
79 83 self._abort_dc.stop()
80 84 ctx = self.context
81 85 loop = self.loop
82 86 identity = self.ident
83 87
84 88 idents,msg = self.session.feed_identities(msg)
85 89 msg = Message(self.session.unpack_message(msg))
86 90
87 91 if msg.content.status == 'ok':
88 92 self.id = int(msg.content.id)
89 93
90 94 # create Shell Streams (MUX, Task, etc.):
91 95 queue_addr = msg.content.mux
92 96 shell_addrs = [ str(queue_addr) ]
93 97 task_addr = msg.content.task
94 98 if task_addr:
95 99 shell_addrs.append(str(task_addr))
96 100
97 101 # Uncomment this to go back to two-socket model
98 102 # shell_streams = []
99 103 # for addr in shell_addrs:
100 104 # stream = zmqstream.ZMQStream(ctx.socket(zmq.XREP), loop)
101 105 # stream.setsockopt(zmq.IDENTITY, identity)
102 106 # stream.connect(disambiguate_url(addr, self.location))
103 107 # shell_streams.append(stream)
104 108
105 109 # Now use only one shell stream for mux and tasks
106 110 stream = zmqstream.ZMQStream(ctx.socket(zmq.XREP), loop)
107 111 stream.setsockopt(zmq.IDENTITY, identity)
108 112 shell_streams = [stream]
109 113 for addr in shell_addrs:
110 114 stream.connect(disambiguate_url(addr, self.location))
111 115 # end single stream-socket
112 116
113 117 # control stream:
114 118 control_addr = str(msg.content.control)
115 119 control_stream = zmqstream.ZMQStream(ctx.socket(zmq.XREP), loop)
116 120 control_stream.setsockopt(zmq.IDENTITY, identity)
117 121 control_stream.connect(disambiguate_url(control_addr, self.location))
118 122
119 123 # create iopub stream:
120 124 iopub_addr = msg.content.iopub
121 125 iopub_stream = zmqstream.ZMQStream(ctx.socket(zmq.PUB), loop)
122 126 iopub_stream.setsockopt(zmq.IDENTITY, identity)
123 127 iopub_stream.connect(disambiguate_url(iopub_addr, self.location))
124 128
125 129 # launch heartbeat
126 130 hb_addrs = msg.content.heartbeat
127 131 # print (hb_addrs)
128 132
129 133         # Redirect output streams (stdout/stderr) and set a display hook.
130 134 if self.out_stream_factory:
131 135 sys.stdout = self.out_stream_factory(self.session, iopub_stream, u'stdout')
132 136 sys.stdout.topic = 'engine.%i.stdout'%self.id
133 137 sys.stderr = self.out_stream_factory(self.session, iopub_stream, u'stderr')
134 138 sys.stderr.topic = 'engine.%i.stderr'%self.id
135 139 if self.display_hook_factory:
136 140 sys.displayhook = self.display_hook_factory(self.session, iopub_stream)
137 141 sys.displayhook.topic = 'engine.%i.pyout'%self.id
138 142
139 143 self.kernel = Kernel(config=self.config, int_id=self.id, ident=self.ident, session=self.session,
140 144 control_stream=control_stream, shell_streams=shell_streams, iopub_stream=iopub_stream,
141 145 loop=loop, user_ns = self.user_ns, log=self.log)
142 146 self.kernel.start()
143 147 hb_addrs = [ disambiguate_url(addr, self.location) for addr in hb_addrs ]
144 148 heart = Heart(*map(str, hb_addrs), heart_id=identity)
145 149 heart.start()
146 150
147 151
148 152 else:
149 153 self.log.fatal("Registration Failed: %s"%msg)
150 154 raise Exception("Registration Failed: %s"%msg)
151 155
152 156 self.log.info("Completed registration with id %i"%self.id)
153 157
154 158
155 159 def abort(self):
156 160 self.log.fatal("Registration timed out after %.1f seconds"%self.timeout)
157 161 self.session.send(self.registrar, "unregistration_request", content=dict(id=self.id))
158 162 time.sleep(1)
159 163 sys.exit(255)
160 164
161 165 def start(self):
162 166 dc = ioloop.DelayedCallback(self.register, 0, self.loop)
163 167 dc.start()
164 168 self._abort_dc = ioloop.DelayedCallback(self.abort, self.timeout*1000, self.loop)
165 169 self._abort_dc.start()
166 170
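
# A sketch of the registration handshake, reconstructed from the fields
# complete_registration() reads above; the concrete idents and addresses
# are hypothetical examples, not values this module defines.
request_content = dict(queue='engine-ident', heartbeat='engine-ident',
                       control='engine-ident')
reply_content = dict(status='ok', id=0,
                     mux='tcp://127.0.0.1:5555',
                     task='tcp://127.0.0.1:5556',
                     control='tcp://127.0.0.1:5557',
                     iopub='tcp://127.0.0.1:5558',
                     heartbeat=['tcp://127.0.0.1:5559',
                                'tcp://127.0.0.1:5560'])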
@@ -1,225 +1,230 b''
1 """KernelStarter class that intercepts Control Queue messages, and handles process management."""
1 """KernelStarter class that intercepts Control Queue messages, and handles process management.
2
3 Authors:
4
5 * Min RK
6 """
2 7 #-----------------------------------------------------------------------------
3 8 # Copyright (C) 2010-2011 The IPython Development Team
4 9 #
5 10 # Distributed under the terms of the BSD License. The full license is in
6 11 # the file COPYING, distributed as part of this software.
7 12 #-----------------------------------------------------------------------------
8 13
9 14 import zmq
from zmq.eventloop import ioloop, zmqstream
10 15
11 16 from IPython.zmq.session import Session
12 17
13 18 class KernelStarter(object):
14 19 """Object for resetting/killing the Kernel."""
15 20
16 21
17 22 def __init__(self, session, upstream, downstream, *kernel_args, **kernel_kwargs):
18 23 self.session = session
19 24 self.upstream = upstream
20 25 self.downstream = downstream
21 26 self.kernel_args = kernel_args
22 27 self.kernel_kwargs = kernel_kwargs
23 28 self.handlers = {}
24 29 for method in 'shutdown_request shutdown_reply'.split():
25 30 self.handlers[method] = getattr(self, method)
26 31
27 32 def start(self):
28 33 self.upstream.on_recv(self.dispatch_request)
29 34 self.downstream.on_recv(self.dispatch_reply)
30 35
31 36 #--------------------------------------------------------------------------
32 37 # Dispatch methods
33 38 #--------------------------------------------------------------------------
34 39
35 40 def dispatch_request(self, raw_msg):
36 41         idents, msg = self.session.feed_identities(raw_msg)
37 42 try:
38 43 msg = self.session.unpack_message(msg, content=False)
39 44 except:
40 45 print ("bad msg: %s"%msg)
41 46
42 47 msgtype = msg['msg_type']
43 48 handler = self.handlers.get(msgtype, None)
44 49 if handler is None:
45 50 self.downstream.send_multipart(raw_msg, copy=False)
46 51 else:
47 52 handler(msg)
48 53
49 54 def dispatch_reply(self, raw_msg):
50 55         idents, msg = self.session.feed_identities(raw_msg)
51 56 try:
52 57 msg = self.session.unpack_message(msg, content=False)
53 58 except:
54 59 print ("bad msg: %s"%msg)
55 60
56 61 msgtype = msg['msg_type']
57 62 handler = self.handlers.get(msgtype, None)
58 63 if handler is None:
59 64 self.upstream.send_multipart(raw_msg, copy=False)
60 65 else:
61 66 handler(msg)
62 67
63 68 #--------------------------------------------------------------------------
64 69 # Handlers
65 70 #--------------------------------------------------------------------------
66 71
67 72 def shutdown_request(self, msg):
68 73         """Forward a shutdown request downstream to the kernel process."""
69 74 self.downstream.send_multipart(msg)
70 75
71 76 #--------------------------------------------------------------------------
72 77 # Kernel process management methods, from KernelManager:
73 78 #--------------------------------------------------------------------------
74 79
75 80 def _check_local(addr):
76 81 if isinstance(addr, tuple):
77 82 addr = addr[0]
78 83 return addr in LOCAL_IPS
79 84
80 85 def start_kernel(self, **kw):
81 86 """Starts a kernel process and configures the manager to use it.
82 87
83 88 If random ports (port=0) are being used, this method must be called
84 89 before the channels are created.
85 90
86 91         Parameters
87 92         ----------
88 93 ipython : bool, optional (default True)
89 94 Whether to use an IPython kernel instead of a plain Python kernel.
90 95 """
91 96 self.kernel = Process(target=make_kernel, args=self.kernel_args,
92 97 kwargs=self.kernel_kwargs)
93 98
94 99 def shutdown_kernel(self, restart=False):
95 100         """ Attempts to stop the kernel process cleanly. If the kernel
96 101         cannot be stopped cleanly, it is killed, if possible.
97 102 """
98 103 # FIXME: Shutdown does not work on Windows due to ZMQ errors!
99 104 if sys.platform == 'win32':
100 105 self.kill_kernel()
101 106 return
102 107
103 108 # Don't send any additional kernel kill messages immediately, to give
104 109 # the kernel a chance to properly execute shutdown actions. Wait for at
105 110 # most 1s, checking every 0.1s.
106 111 self.xreq_channel.shutdown(restart=restart)
107 112 for i in range(10):
108 113 if self.is_alive:
109 114 time.sleep(0.1)
110 115 else:
111 116 break
112 117 else:
113 118 # OK, we've waited long enough.
114 119 if self.has_kernel:
115 120 self.kill_kernel()
116 121
117 122 def restart_kernel(self, now=False):
118 123 """Restarts a kernel with the same arguments that were used to launch
119 124 it. If the old kernel was launched with random ports, the same ports
120 125 will be used for the new kernel.
121 126
122 127 Parameters
123 128 ----------
124 129 now : bool, optional
125 130 If True, the kernel is forcefully restarted *immediately*, without
126 131 having a chance to do any cleanup action. Otherwise the kernel is
127 132 given 1s to clean up before a forceful restart is issued.
128 133
129 134 In all cases the kernel is restarted, the only difference is whether
130 135 it is given a chance to perform a clean shutdown or not.
131 136 """
132 137 if self._launch_args is None:
133 138 raise RuntimeError("Cannot restart the kernel. "
134 139 "No previous call to 'start_kernel'.")
135 140 else:
136 141 if self.has_kernel:
137 142 if now:
138 143 self.kill_kernel()
139 144 else:
140 145 self.shutdown_kernel(restart=True)
141 146 self.start_kernel(**self._launch_args)
142 147
143 148 # FIXME: Messages get dropped in Windows due to probable ZMQ bug
144 149 # unless there is some delay here.
145 150 if sys.platform == 'win32':
146 151 time.sleep(0.2)
147 152
148 153 @property
149 154 def has_kernel(self):
150 155 """Returns whether a kernel process has been specified for the kernel
151 156 manager.
152 157 """
153 158 return self.kernel is not None
154 159
155 160 def kill_kernel(self):
156 161 """ Kill the running kernel. """
157 162 if self.has_kernel:
158 163 # Pause the heart beat channel if it exists.
159 164 if self._hb_channel is not None:
160 165 self._hb_channel.pause()
161 166
162 167 # Attempt to kill the kernel.
163 168 try:
164 169 self.kernel.kill()
165 170 except OSError, e:
166 171 # In Windows, we will get an Access Denied error if the process
167 172 # has already terminated. Ignore it.
168 173 if not (sys.platform == 'win32' and e.winerror == 5):
169 174 raise
170 175 self.kernel = None
171 176 else:
172 177 raise RuntimeError("Cannot kill kernel. No kernel is running!")
173 178
174 179 def interrupt_kernel(self):
175 180 """ Interrupts the kernel. Unlike ``signal_kernel``, this operation is
176 181 well supported on all platforms.
177 182 """
178 183 if self.has_kernel:
179 184 if sys.platform == 'win32':
180 185 from parentpoller import ParentPollerWindows as Poller
181 186 Poller.send_interrupt(self.kernel.win32_interrupt_event)
182 187 else:
183 188 self.kernel.send_signal(signal.SIGINT)
184 189 else:
185 190 raise RuntimeError("Cannot interrupt kernel. No kernel is running!")
186 191
187 192 def signal_kernel(self, signum):
188 193 """ Sends a signal to the kernel. Note that since only SIGTERM is
189 194 supported on Windows, this function is only useful on Unix systems.
190 195 """
191 196 if self.has_kernel:
192 197 self.kernel.send_signal(signum)
193 198 else:
194 199 raise RuntimeError("Cannot signal kernel. No kernel is running!")
195 200
196 201 @property
197 202 def is_alive(self):
198 203 """Is the kernel process still running?"""
199 204 # FIXME: not using a heartbeat means this method is broken for any
200 205 # remote kernel, it's only capable of handling local kernels.
201 206 if self.has_kernel:
202 207 if self.kernel.poll() is None:
203 208 return True
204 209 else:
205 210 return False
206 211 else:
207 212 # We didn't start the kernel with this KernelManager so we don't
208 213 # know if it is running. We should use a heartbeat for this case.
209 214 return True
210 215
211 216
212 217 def make_starter(up_addr, down_addr, *args, **kwargs):
213 218 """entry point function for launching a kernelstarter in a subprocess"""
214 219 loop = ioloop.IOLoop.instance()
215 220 ctx = zmq.Context()
216 221 session = Session()
217 222 upstream = zmqstream.ZMQStream(ctx.socket(zmq.XREQ),loop)
218 223 upstream.connect(up_addr)
219 224 downstream = zmqstream.ZMQStream(ctx.socket(zmq.XREQ),loop)
220 225 downstream.connect(down_addr)
221 226
222 227 starter = KernelStarter(session, upstream, downstream, *args, **kwargs)
223 228 starter.start()
224 229 loop.start()
225 230 No newline at end of file
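
# A minimal usage sketch for make_starter() above, assuming two XREQ
# endpoints are already bound elsewhere; the addresses are hypothetical.
# make_starter() blocks in loop.start(), which is why it is described as
# a subprocess entry point.
from multiprocessing import Process

p = Process(target=make_starter,
            args=('tcp://127.0.0.1:5670', 'tcp://127.0.0.1:5671'))
p.start()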
@@ -1,422 +1,429 b''
1 1 #!/usr/bin/env python
2 2 """
3 3 Kernel adapted from kernel.py to use ZMQ Streams
4
5 Authors:
6
7 * Min RK
8 * Brian Granger
9 * Fernando Perez
10 * Evan Patterson
4 11 """
5 12 #-----------------------------------------------------------------------------
6 13 # Copyright (C) 2010-2011 The IPython Development Team
7 14 #
8 15 # Distributed under the terms of the BSD License. The full license is in
9 16 # the file COPYING, distributed as part of this software.
10 17 #-----------------------------------------------------------------------------
11 18
12 19 #-----------------------------------------------------------------------------
13 20 # Imports
14 21 #-----------------------------------------------------------------------------
15 22
16 23 # Standard library imports.
17 24 from __future__ import print_function
18 25
19 26 import sys
20 27 import time
21 28
22 29 from code import CommandCompiler
23 30 from datetime import datetime
24 31 from pprint import pprint
25 32
26 33 # System library imports.
27 34 import zmq
28 35 from zmq.eventloop import ioloop, zmqstream
29 36
30 37 # Local imports.
31 38 from IPython.utils.traitlets import Instance, List, Int, Dict, Set, Unicode
32 39 from IPython.zmq.completer import KernelCompleter
33 40
34 41 from IPython.parallel.error import wrap_exception
35 42 from IPython.parallel.factory import SessionFactory
36 43 from IPython.parallel.util import serialize_object, unpack_apply_message
37 44
38 45 def printer(*args):
39 46 pprint(args, stream=sys.__stdout__)
40 47
41 48
42 49 class _Passer(zmqstream.ZMQStream):
43 50 """Empty class that implements `send()` that does nothing.
44 51
45 52 Subclass ZMQStream for Session typechecking
46 53
47 54 """
48 55 def __init__(self, *args, **kwargs):
49 56 pass
50 57
51 58 def send(self, *args, **kwargs):
52 59 pass
53 60 send_multipart = send
54 61
55 62
56 63 #-----------------------------------------------------------------------------
57 64 # Main kernel class
58 65 #-----------------------------------------------------------------------------
59 66
60 67 class Kernel(SessionFactory):
61 68
62 69 #---------------------------------------------------------------------------
63 70 # Kernel interface
64 71 #---------------------------------------------------------------------------
65 72
66 73 # kwargs:
67 74 exec_lines = List(Unicode, config=True,
68 75 help="List of lines to execute")
69 76
70 77 int_id = Int(-1)
71 78 user_ns = Dict(config=True, help="""Set the user's namespace of the Kernel""")
72 79
73 80 control_stream = Instance(zmqstream.ZMQStream)
74 81 task_stream = Instance(zmqstream.ZMQStream)
75 82 iopub_stream = Instance(zmqstream.ZMQStream)
76 83 client = Instance('IPython.parallel.Client')
77 84
78 85 # internals
79 86 shell_streams = List()
80 87 compiler = Instance(CommandCompiler, (), {})
81 88 completer = Instance(KernelCompleter)
82 89
83 90 aborted = Set()
84 91 shell_handlers = Dict()
85 92 control_handlers = Dict()
86 93
87 94 def _set_prefix(self):
88 95 self.prefix = "engine.%s"%self.int_id
89 96
90 97 def _connect_completer(self):
91 98 self.completer = KernelCompleter(self.user_ns)
92 99
93 100 def __init__(self, **kwargs):
94 101 super(Kernel, self).__init__(**kwargs)
95 102 self._set_prefix()
96 103 self._connect_completer()
97 104
98 105 self.on_trait_change(self._set_prefix, 'id')
99 106 self.on_trait_change(self._connect_completer, 'user_ns')
100 107
101 108 # Build dict of handlers for message types
102 109 for msg_type in ['execute_request', 'complete_request', 'apply_request',
103 110 'clear_request']:
104 111 self.shell_handlers[msg_type] = getattr(self, msg_type)
105 112
106 113 for msg_type in ['shutdown_request', 'abort_request']+self.shell_handlers.keys():
107 114 self.control_handlers[msg_type] = getattr(self, msg_type)
108 115
109 116 self._initial_exec_lines()
110 117
111 118 def _wrap_exception(self, method=None):
112 119 e_info = dict(engine_uuid=self.ident, engine_id=self.int_id, method=method)
113 120 content=wrap_exception(e_info)
114 121 return content
115 122
116 123 def _initial_exec_lines(self):
117 124 s = _Passer()
118 125         content = dict(silent=True, user_variables=[], user_expressions=[])
119 126 for line in self.exec_lines:
120 127 self.log.debug("executing initialization: %s"%line)
121 128 content.update({'code':line})
122 129 msg = self.session.msg('execute_request', content)
123 130 self.execute_request(s, [], msg)
124 131
125 132
126 133 #-------------------- control handlers -----------------------------
127 134 def abort_queues(self):
128 135 for stream in self.shell_streams:
129 136 if stream:
130 137 self.abort_queue(stream)
131 138
132 139 def abort_queue(self, stream):
133 140 while True:
134 141 idents,msg = self.session.recv(stream, zmq.NOBLOCK, content=True)
135 142 if msg is None:
136 143 return
137 144
138 145 self.log.info("Aborting:")
139 146 self.log.info(str(msg))
140 147 msg_type = msg['msg_type']
141 148 reply_type = msg_type.split('_')[0] + '_reply'
142 149 # reply_msg = self.session.msg(reply_type, {'status' : 'aborted'}, msg)
143 150 # self.reply_socket.send(ident,zmq.SNDMORE)
144 151 # self.reply_socket.send_json(reply_msg)
145 152 reply_msg = self.session.send(stream, reply_type,
146 153 content={'status' : 'aborted'}, parent=msg, ident=idents)[0]
147 154 self.log.debug(str(reply_msg))
148 155 # We need to wait a bit for requests to come in. This can probably
149 156 # be set shorter for true asynchronous clients.
150 157 time.sleep(0.05)
151 158
152 159 def abort_request(self, stream, ident, parent):
153 160         """abort a specific msg by id"""
154 161 msg_ids = parent['content'].get('msg_ids', None)
155 162 if isinstance(msg_ids, basestring):
156 163 msg_ids = [msg_ids]
157 164 if not msg_ids:
158 165 self.abort_queues()
159 166 for mid in msg_ids:
160 167 self.aborted.add(str(mid))
161 168
162 169 content = dict(status='ok')
163 170 reply_msg = self.session.send(stream, 'abort_reply', content=content,
164 171 parent=parent, ident=ident)
165 172 self.log.debug(str(reply_msg))
166 173
167 174 def shutdown_request(self, stream, ident, parent):
168 175         """Kill ourselves. This should really be handled in an external process."""
169 176 try:
170 177 self.abort_queues()
171 178 except:
172 179 content = self._wrap_exception('shutdown')
173 180 else:
174 181 content = dict(parent['content'])
175 182 content['status'] = 'ok'
176 183 msg = self.session.send(stream, 'shutdown_reply',
177 184 content=content, parent=parent, ident=ident)
178 185 self.log.debug(str(msg))
179 186 dc = ioloop.DelayedCallback(lambda : sys.exit(0), 1000, self.loop)
180 187 dc.start()
181 188
182 189 def dispatch_control(self, msg):
183 190 idents,msg = self.session.feed_identities(msg, copy=False)
184 191 try:
185 192 msg = self.session.unpack_message(msg, content=True, copy=False)
186 193 except:
187 194 self.log.error("Invalid Message", exc_info=True)
188 195 return
189 196
190 197 header = msg['header']
191 198 msg_id = header['msg_id']
192 199
193 200 handler = self.control_handlers.get(msg['msg_type'], None)
194 201 if handler is None:
195 202 self.log.error("UNKNOWN CONTROL MESSAGE TYPE: %r"%msg['msg_type'])
196 203 else:
197 204 handler(self.control_stream, idents, msg)
198 205
199 206
200 207 #-------------------- queue helpers ------------------------------
201 208
202 209 def check_dependencies(self, dependencies):
203 210 if not dependencies:
204 211 return True
205 212 if len(dependencies) == 2 and dependencies[0] in 'any all'.split():
206 213 anyorall = dependencies[0]
207 214 dependencies = dependencies[1]
208 215 else:
209 216 anyorall = 'all'
210 217 results = self.client.get_results(dependencies,status_only=True)
211 218 if results['status'] != 'ok':
212 219 return False
213 220
214 221 if anyorall == 'any':
215 222 if not results['completed']:
216 223 return False
217 224 else:
218 225 if results['pending']:
219 226 return False
220 227
221 228 return True
222 229
223 230 def check_aborted(self, msg_id):
224 231 return msg_id in self.aborted
225 232
226 233 #-------------------- queue handlers -----------------------------
227 234
228 235 def clear_request(self, stream, idents, parent):
229 236 """Clear our namespace."""
230 237 self.user_ns = {}
231 238 msg = self.session.send(stream, 'clear_reply', ident=idents, parent=parent,
232 239 content = dict(status='ok'))
233 240 self._initial_exec_lines()
234 241
235 242 def execute_request(self, stream, ident, parent):
236 243 self.log.debug('execute request %s'%parent)
237 244 try:
238 245 code = parent[u'content'][u'code']
239 246 except:
240 247 self.log.error("Got bad msg: %s"%parent, exc_info=True)
241 248 return
242 249 self.session.send(self.iopub_stream, u'pyin', {u'code':code},parent=parent,
243 250 ident='%s.pyin'%self.prefix)
244 251 started = datetime.now()
245 252 try:
246 253 comp_code = self.compiler(code, '<zmq-kernel>')
247 254 # allow for not overriding displayhook
248 255 if hasattr(sys.displayhook, 'set_parent'):
249 256 sys.displayhook.set_parent(parent)
250 257 sys.stdout.set_parent(parent)
251 258 sys.stderr.set_parent(parent)
252 259 exec comp_code in self.user_ns, self.user_ns
253 260 except:
254 261 exc_content = self._wrap_exception('execute')
255 262 # exc_msg = self.session.msg(u'pyerr', exc_content, parent)
256 263 self.session.send(self.iopub_stream, u'pyerr', exc_content, parent=parent,
257 264 ident='%s.pyerr'%self.prefix)
258 265 reply_content = exc_content
259 266 else:
260 267 reply_content = {'status' : 'ok'}
261 268
262 269 reply_msg = self.session.send(stream, u'execute_reply', reply_content, parent=parent,
263 270 ident=ident, subheader = dict(started=started))
264 271 self.log.debug(str(reply_msg))
265 272 if reply_msg['content']['status'] == u'error':
266 273 self.abort_queues()
267 274
268 275 def complete_request(self, stream, ident, parent):
269 276 matches = {'matches' : self.complete(parent),
270 277 'status' : 'ok'}
271 278 completion_msg = self.session.send(stream, 'complete_reply',
272 279 matches, parent, ident)
273 280 # print >> sys.__stdout__, completion_msg
274 281
275 282 def complete(self, msg):
276 283 return self.completer.complete(msg.content.line, msg.content.text)
277 284
278 285 def apply_request(self, stream, ident, parent):
279 286 # flush previous reply, so this request won't block it
280 287 stream.flush(zmq.POLLOUT)
281 288
282 289 try:
283 290 content = parent[u'content']
284 291 bufs = parent[u'buffers']
285 292 msg_id = parent['header']['msg_id']
286 293 # bound = parent['header'].get('bound', False)
287 294 except:
288 295 self.log.error("Got bad msg: %s"%parent, exc_info=True)
289 296 return
290 297 # pyin_msg = self.session.msg(u'pyin',{u'code':code}, parent=parent)
291 298 # self.iopub_stream.send(pyin_msg)
292 299 # self.session.send(self.iopub_stream, u'pyin', {u'code':code},parent=parent)
293 300 sub = {'dependencies_met' : True, 'engine' : self.ident,
294 301 'started': datetime.now()}
295 302 try:
296 303 # allow for not overriding displayhook
297 304 if hasattr(sys.displayhook, 'set_parent'):
298 305 sys.displayhook.set_parent(parent)
299 306 sys.stdout.set_parent(parent)
300 307 sys.stderr.set_parent(parent)
301 308 # exec "f(*args,**kwargs)" in self.user_ns, self.user_ns
302 309 working = self.user_ns
303 310 # suffix =
304 311 prefix = "_"+str(msg_id).replace("-","")+"_"
305 312
306 313 f,args,kwargs = unpack_apply_message(bufs, working, copy=False)
307 314 # if bound:
308 315 # bound_ns = Namespace(working)
309 316 # args = [bound_ns]+list(args)
310 317
311 318 fname = getattr(f, '__name__', 'f')
312 319
313 320 fname = prefix+"f"
314 321 argname = prefix+"args"
315 322 kwargname = prefix+"kwargs"
316 323 resultname = prefix+"result"
317 324
318 325 ns = { fname : f, argname : args, kwargname : kwargs , resultname : None }
319 326 # print ns
320 327 working.update(ns)
321 328 code = "%s=%s(*%s,**%s)"%(resultname, fname, argname, kwargname)
322 329 try:
323 330 exec code in working,working
324 331 result = working.get(resultname)
325 332 finally:
326 333 for key in ns.iterkeys():
327 334 working.pop(key)
328 335 # if bound:
329 336 # working.update(bound_ns)
330 337
331 338 packed_result,buf = serialize_object(result)
332 339 result_buf = [packed_result]+buf
333 340 except:
334 341 exc_content = self._wrap_exception('apply')
335 342 # exc_msg = self.session.msg(u'pyerr', exc_content, parent)
336 343 self.session.send(self.iopub_stream, u'pyerr', exc_content, parent=parent,
337 344 ident='%s.pyerr'%self.prefix)
338 345 reply_content = exc_content
339 346 result_buf = []
340 347
341 348 if exc_content['ename'] == 'UnmetDependency':
342 349 sub['dependencies_met'] = False
343 350 else:
344 351 reply_content = {'status' : 'ok'}
345 352
346 353 # put 'ok'/'error' status in header, for scheduler introspection:
347 354 sub['status'] = reply_content['status']
348 355
349 356 reply_msg = self.session.send(stream, u'apply_reply', reply_content,
350 357 parent=parent, ident=ident,buffers=result_buf, subheader=sub)
351 358
352 359 # flush i/o
353 360 # should this be before reply_msg is sent, like in the single-kernel code,
354 361 # or should nothing get in the way of real results?
355 362 sys.stdout.flush()
356 363 sys.stderr.flush()
357 364
358 365 def dispatch_queue(self, stream, msg):
359 366 self.control_stream.flush()
360 367 idents,msg = self.session.feed_identities(msg, copy=False)
361 368 try:
362 369 msg = self.session.unpack_message(msg, content=True, copy=False)
363 370 except:
364 371 self.log.error("Invalid Message", exc_info=True)
365 372 return
366 373
367 374
368 375 header = msg['header']
369 376 msg_id = header['msg_id']
370 377 if self.check_aborted(msg_id):
371 378 self.aborted.remove(msg_id)
372 379 # is it safe to assume a msg_id will not be resubmitted?
373 380 reply_type = msg['msg_type'].split('_')[0] + '_reply'
374 381 status = {'status' : 'aborted'}
375 382 reply_msg = self.session.send(stream, reply_type, subheader=status,
376 383 content=status, parent=msg, ident=idents)
377 384 return
378 385 handler = self.shell_handlers.get(msg['msg_type'], None)
379 386 if handler is None:
380 387 self.log.error("UNKNOWN MESSAGE TYPE: %r"%msg['msg_type'])
381 388 else:
382 389 handler(stream, idents, msg)
383 390
384 391 def start(self):
385 392 #### stream mode:
386 393 if self.control_stream:
387 394 self.control_stream.on_recv(self.dispatch_control, copy=False)
388 395 self.control_stream.on_err(printer)
389 396
390 397 def make_dispatcher(stream):
391 398 def dispatcher(msg):
392 399 return self.dispatch_queue(stream, msg)
393 400 return dispatcher
394 401
395 402 for s in self.shell_streams:
396 403 s.on_recv(make_dispatcher(s), copy=False)
397 404 s.on_err(printer)
398 405
399 406 if self.iopub_stream:
400 407 self.iopub_stream.on_err(printer)
401 408
402 409 #### while True mode:
403 410 # while True:
404 411 # idle = True
405 412 # try:
406 413 # msg = self.shell_stream.socket.recv_multipart(
407 414 # zmq.NOBLOCK, copy=False)
408 415 # except zmq.ZMQError, e:
409 416 # if e.errno != zmq.EAGAIN:
410 417 # raise e
411 418 # else:
412 419 # idle=False
413 420 # self.dispatch_queue(self.shell_stream, msg)
414 421 #
415 422 # if not self.task_stream.empty():
416 423 # idle=False
417 424 # msg = self.task_stream.recv_multipart()
418 425 # self.dispatch_queue(self.task_stream, msg)
419 426 # if idle:
420 427 # # don't busywait
421 428 # time.sleep(1e-3)
422 429
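
# A self-contained sketch of the name-mangling trick apply_request()
# uses above: the function and its arguments get collision-safe
# temporary names in the working namespace, the call is exec'd there,
# and the temporaries are removed afterwards. The msg_id and function
# here are hypothetical.
msg_id = 'de305d54-75b4'
working = {'x': 10}                      # stands in for self.user_ns
prefix = "_" + str(msg_id).replace("-", "") + "_"

f, args, kwargs = (lambda a, b=0: a + b), (1,), {'b': 2}
ns = {prefix + "f": f, prefix + "args": args,
      prefix + "kwargs": kwargs, prefix + "result": None}
working.update(ns)
code = "%sresult = %sf(*%sargs, **%skwargs)" % ((prefix,) * 4)
try:
    exec(code, working, working)
    result = working.get(prefix + "result")
finally:
    for key in ns:
        working.pop(key)
print(result)                            # -> 3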
@@ -1,313 +1,319 b''
1 1 # encoding: utf-8
2 2
3 """Classes and functions for kernel related errors and exceptions."""
3 """Classes and functions for kernel related errors and exceptions.
4
5 Authors:
6
7 * Brian Granger
8 * Min RK
9 """
4 10 from __future__ import print_function
5 11
6 12 import sys
7 13 import traceback
8 14
9 15 __docformat__ = "restructuredtext en"
10 16
11 17 # Tell nose to skip this module
12 18 __test__ = {}
13 19
14 20 #-------------------------------------------------------------------------------
15 # Copyright (C) 2008 The IPython Development Team
21 # Copyright (C) 2008-2011 The IPython Development Team
16 22 #
17 23 # Distributed under the terms of the BSD License. The full license is in
18 24 # the file COPYING, distributed as part of this software.
19 25 #-------------------------------------------------------------------------------
20 26
21 27 #-------------------------------------------------------------------------------
22 28 # Error classes
23 29 #-------------------------------------------------------------------------------
24 30 class IPythonError(Exception):
25 31 """Base exception that all of our exceptions inherit from.
26 32
27 33 This can be raised by code that doesn't have any more specific
28 34 information."""
29 35
30 36 pass
31 37
32 38 # Exceptions associated with the controller objects
33 39 class ControllerError(IPythonError): pass
34 40
35 41 class ControllerCreationError(ControllerError): pass
36 42
37 43
38 44 # Exceptions associated with the Engines
39 45 class EngineError(IPythonError): pass
40 46
41 47 class EngineCreationError(EngineError): pass
42 48
43 49 class KernelError(IPythonError):
44 50 pass
45 51
46 52 class NotDefined(KernelError):
47 53 def __init__(self, name):
48 54 self.name = name
49 55 self.args = (name,)
50 56
51 57 def __repr__(self):
52 58 return '<NotDefined: %s>' % self.name
53 59
54 60 __str__ = __repr__
55 61
56 62
57 63 class QueueCleared(KernelError):
58 64 pass
59 65
60 66
61 67 class IdInUse(KernelError):
62 68 pass
63 69
64 70
65 71 class ProtocolError(KernelError):
66 72 pass
67 73
68 74
69 75 class ConnectionError(KernelError):
70 76 pass
71 77
72 78
73 79 class InvalidEngineID(KernelError):
74 80 pass
75 81
76 82
77 83 class NoEnginesRegistered(KernelError):
78 84 pass
79 85
80 86
81 87 class InvalidClientID(KernelError):
82 88 pass
83 89
84 90
85 91 class InvalidDeferredID(KernelError):
86 92 pass
87 93
88 94
89 95 class SerializationError(KernelError):
90 96 pass
91 97
92 98
93 99 class MessageSizeError(KernelError):
94 100 pass
95 101
96 102
97 103 class PBMessageSizeError(MessageSizeError):
98 104 pass
99 105
100 106
101 107 class ResultNotCompleted(KernelError):
102 108 pass
103 109
104 110
105 111 class ResultAlreadyRetrieved(KernelError):
106 112 pass
107 113
108 114 class ClientError(KernelError):
109 115 pass
110 116
111 117
112 118 class TaskAborted(KernelError):
113 119 pass
114 120
115 121
116 122 class TaskTimeout(KernelError):
117 123 pass
118 124
119 125
120 126 class NotAPendingResult(KernelError):
121 127 pass
122 128
123 129
124 130 class UnpickleableException(KernelError):
125 131 pass
126 132
127 133
128 134 class AbortedPendingDeferredError(KernelError):
129 135 pass
130 136
131 137
132 138 class InvalidProperty(KernelError):
133 139 pass
134 140
135 141
136 142 class MissingBlockArgument(KernelError):
137 143 pass
138 144
139 145
140 146 class StopLocalExecution(KernelError):
141 147 pass
142 148
143 149
144 150 class SecurityError(KernelError):
145 151 pass
146 152
147 153
148 154 class FileTimeoutError(KernelError):
149 155 pass
150 156
151 157 class TimeoutError(KernelError):
152 158 pass
153 159
154 160 class UnmetDependency(KernelError):
155 161 pass
156 162
157 163 class ImpossibleDependency(UnmetDependency):
158 164 pass
159 165
160 166 class DependencyTimeout(ImpossibleDependency):
161 167 pass
162 168
163 169 class InvalidDependency(ImpossibleDependency):
164 170 pass
165 171
166 172 class RemoteError(KernelError):
167 173 """Error raised elsewhere"""
168 174 ename=None
169 175 evalue=None
170 176 traceback=None
171 177 engine_info=None
172 178
173 179 def __init__(self, ename, evalue, traceback, engine_info=None):
174 180 self.ename=ename
175 181 self.evalue=evalue
176 182 self.traceback=traceback
177 183 self.engine_info=engine_info or {}
178 184 self.args=(ename, evalue)
179 185
180 186 def __repr__(self):
181 187 engineid = self.engine_info.get('engine_id', ' ')
182 188 return "<Remote[%s]:%s(%s)>"%(engineid, self.ename, self.evalue)
183 189
184 190 def __str__(self):
185 191 sig = "%s(%s)"%(self.ename, self.evalue)
186 192 if self.traceback:
187 193 return sig + '\n' + self.traceback
188 194 else:
189 195 return sig
190 196
191 197
192 198 class TaskRejectError(KernelError):
193 199 """Exception to raise when a task should be rejected by an engine.
194 200
195 201 This exception can be used to allow a task running on an engine to test
196 202 if the engine (or the user's namespace on the engine) has the needed
197 203 task dependencies. If not, the task should raise this exception. For
198 204 the task to be retried on another engine, the task should be created
199 205 with the `retries` argument > 1.
200 206
201 207 The advantage of this approach over our older properties system is that
202 208 tasks have full access to the user's namespace on the engines and the
203 209 properties don't have to be managed or tested by the controller.
204 210 """
205 211
206 212
207 213 class CompositeError(RemoteError):
208 214 """Error for representing possibly multiple errors on engines"""
209 215 def __init__(self, message, elist):
210 216 Exception.__init__(self, *(message, elist))
211 217 # Don't use pack_exception because it will conflict with the .message
212 218 # attribute that is being deprecated in 2.6 and beyond.
213 219 self.msg = message
214 220 self.elist = elist
215 221 self.args = [ e[0] for e in elist ]
216 222
217 223 def _get_engine_str(self, ei):
218 224 if not ei:
219 225 return '[Engine Exception]'
220 226 else:
221 227 return '[%s:%s]: ' % (ei['engine_id'], ei['method'])
222 228
223 229 def _get_traceback(self, ev):
224 230 try:
225 231 tb = ev._ipython_traceback_text
226 232 except AttributeError:
227 233 return 'No traceback available'
228 234 else:
229 235 return tb
230 236
231 237 def __str__(self):
232 238 s = str(self.msg)
233 239 for en, ev, etb, ei in self.elist:
234 240 engine_str = self._get_engine_str(ei)
235 241 s = s + '\n' + engine_str + en + ': ' + str(ev)
236 242 return s
237 243
238 244 def __repr__(self):
239 245 return "CompositeError(%i)"%len(self.elist)
240 246
241 247 def print_tracebacks(self, excid=None):
242 248 if excid is None:
243 249 for (en,ev,etb,ei) in self.elist:
244 250 print (self._get_engine_str(ei))
245 251 print (etb or 'No traceback available')
246 252 print ()
247 253 else:
248 254 try:
249 255 en,ev,etb,ei = self.elist[excid]
250 256 except:
251 257 raise IndexError("an exception with index %i does not exist"%excid)
252 258 else:
253 259 print (self._get_engine_str(ei))
254 260 print (etb or 'No traceback available')
255 261
256 262 def raise_exception(self, excid=0):
257 263 try:
258 264 en,ev,etb,ei = self.elist[excid]
259 265 except:
260 266 raise IndexError("an exception with index %i does not exist"%excid)
261 267 else:
262 268 raise RemoteError(en, ev, etb, ei)
263 269
264 270
265 271 def collect_exceptions(rdict_or_list, method='unspecified'):
266 272 """check a result dict for errors, and raise CompositeError if any exist.
267 273 Passthrough otherwise."""
268 274 elist = []
269 275 if isinstance(rdict_or_list, dict):
270 276 rlist = rdict_or_list.values()
271 277 else:
272 278 rlist = rdict_or_list
273 279 for r in rlist:
274 280 if isinstance(r, RemoteError):
275 281 en, ev, etb, ei = r.ename, r.evalue, r.traceback, r.engine_info
276 282 # Sometimes we could have CompositeError in our list. Just take
277 283 # the errors out of them and put them in our new list. This
278 284 # has the effect of flattening lists of CompositeErrors into one
279 285 # CompositeError
280 286 if en=='CompositeError':
281 287 for e in ev.elist:
282 288 elist.append(e)
283 289 else:
284 290 elist.append((en, ev, etb, ei))
285 291 if len(elist)==0:
286 292 return rdict_or_list
287 293 else:
288 294 msg = "one or more exceptions from call to method: %s" % (method)
289 295 # This silliness is needed so the debugger has access to the exception
290 296 # instance (e in this case)
291 297 try:
292 298 raise CompositeError(msg, elist)
293 299 except CompositeError as e:
294 300 raise e
295 301
296 302 def wrap_exception(engine_info={}):
297 303 etype, evalue, tb = sys.exc_info()
298 304 stb = traceback.format_exception(etype, evalue, tb)
299 305 exc_content = {
300 306 'status' : 'error',
301 307 'traceback' : stb,
302 308 'ename' : unicode(etype.__name__),
303 309 'evalue' : unicode(evalue),
304 310 'engine_info' : engine_info
305 311 }
306 312 return exc_content
307 313
308 314 def unwrap_exception(content):
309 315 err = RemoteError(content['ename'], content['evalue'],
310 316 ''.join(content['traceback']),
311 317 content.get('engine_info', {}))
312 318 return err
313 319
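
# A minimal sketch of the round trip the helpers above provide: an
# exception is captured with wrap_exception() where it is raised, and
# rebuilt as a RemoteError with unwrap_exception() on the other side.
try:
    1 / 0
except:
    content = wrap_exception(dict(engine_id=0, method='execute'))

err = unwrap_exception(content)
print(repr(err))     # -> <Remote[0]:ZeroDivisionError(...)>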
@@ -1,72 +1,77 b''
1 """Base config factories."""
1 """Base config factories.
2
3 Authors:
4
5 * Min RK
6 """
2 7
3 8 #-----------------------------------------------------------------------------
4 9 # Copyright (C) 2010-2011 The IPython Development Team
5 10 #
6 11 # Distributed under the terms of the BSD License. The full license is in
7 12 # the file COPYING, distributed as part of this software.
8 13 #-----------------------------------------------------------------------------
9 14
10 15 #-----------------------------------------------------------------------------
11 16 # Imports
12 17 #-----------------------------------------------------------------------------
13 18
14 19
15 20 import logging
16 21 import os
17 22
18 23 import zmq
19 24 from zmq.eventloop.ioloop import IOLoop
20 25
21 26 from IPython.config.configurable import Configurable
22 27 from IPython.utils.traitlets import Int, Instance, Unicode
23 28
24 29 from IPython.parallel.util import select_random_ports
25 30 from IPython.zmq.session import Session, SessionFactory
26 31
27 32 #-----------------------------------------------------------------------------
28 33 # Classes
29 34 #-----------------------------------------------------------------------------
30 35
31 36
32 37 class RegistrationFactory(SessionFactory):
33 38 """The Base Configurable for objects that involve registration."""
34 39
35 40 url = Unicode('', config=True,
36 41 help="""The 0MQ url used for registration. This sets transport, ip, and port
37 42 in one variable. For example: url='tcp://127.0.0.1:12345' or
38 43 url='epgm://*:90210'""") # url takes precedence over ip,regport,transport
39 44 transport = Unicode('tcp', config=True,
40 45 help="""The 0MQ transport for communications. This will likely be
41 46 the default of 'tcp', but other values include 'ipc', 'epgm', 'inproc'.""")
42 47 ip = Unicode('127.0.0.1', config=True,
43 48 help="""The IP address for registration. This is generally either
44 49 '127.0.0.1' for loopback only or '*' for all interfaces.
45 50 [default: '127.0.0.1']""")
46 51 regport = Int(config=True,
47 52 help="""The port on which the Hub listens for registration.""")
48 53 def _regport_default(self):
49 54 return select_random_ports(1)[0]
50 55
51 56 def __init__(self, **kwargs):
52 57 super(RegistrationFactory, self).__init__(**kwargs)
53 58 self._propagate_url()
54 59 self._rebuild_url()
55 60 self.on_trait_change(self._propagate_url, 'url')
56 61 self.on_trait_change(self._rebuild_url, 'ip')
57 62 self.on_trait_change(self._rebuild_url, 'transport')
58 63 self.on_trait_change(self._rebuild_url, 'regport')
59 64
60 65 def _rebuild_url(self):
61 66 self.url = "%s://%s:%i"%(self.transport, self.ip, self.regport)
62 67
63 68 def _propagate_url(self):
64 69 """Ensure self.url contains full transport://interface:port"""
65 70 if self.url:
66 71 iface = self.url.split('://',1)
67 72 if len(iface) == 2:
68 73 self.transport,iface = iface
69 74 iface = iface.split(':')
70 75 self.ip = iface[0]
71 76 if iface[1]:
72 77 self.regport = int(iface[1])
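
# A self-contained sketch of the url <-> (transport, ip, regport) round
# trip implemented by _propagate_url()/_rebuild_url() above, using plain
# string handling; the address is a hypothetical example.
url = 'tcp://127.0.0.1:12345'
transport, iface = url.split('://', 1)
ip, port = iface.split(':')
print("%s %s %i" % (transport, ip, int(port)))    # -> tcp 127.0.0.1 12345
print("%s://%s:%i" % (transport, ip, int(port)))  # rebuilds the url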
@@ -1,129 +1,134 b''
1 """base class for parallel client tests"""
1 """base class for parallel client tests
2
3 Authors:
4
5 * Min RK
6 """
2 7
3 8 #-------------------------------------------------------------------------------
4 9 # Copyright (C) 2011 The IPython Development Team
5 10 #
6 11 # Distributed under the terms of the BSD License. The full license is in
7 12 # the file COPYING, distributed as part of this software.
8 13 #-------------------------------------------------------------------------------
9 14
10 15 import sys
11 16 import tempfile
12 17 import time
13 18
14 19 from nose import SkipTest
15 20
16 21 import zmq
17 22 from zmq.tests import BaseZMQTestCase
18 23
19 24 from IPython.external.decorator import decorator
20 25
21 26 from IPython.parallel import error
22 27 from IPython.parallel import Client
23 28
24 29 from IPython.parallel.tests import launchers, add_engines
25 30
26 31 # simple tasks for use in apply tests
27 32
28 33 def segfault():
29 34 """this will segfault"""
30 35 import ctypes
31 36 ctypes.memset(-1,0,1)
32 37
33 38 def crash():
34 39 """from stdlib crashers in the test suite"""
35 40 import types
36 41 if sys.platform.startswith('win'):
37 42 import ctypes
38 43 ctypes.windll.kernel32.SetErrorMode(0x0002);
39 44
40 45 co = types.CodeType(0, 0, 0, 0, b'\x04\x71\x00\x00',
41 46 (), (), (), '', '', 1, b'')
42 47 exec(co)
43 48
44 49 def wait(n):
45 50 """sleep for a time"""
46 51 import time
47 52 time.sleep(n)
48 53 return n
49 54
50 55 def raiser(eclass):
51 56 """raise an exception"""
52 57 raise eclass()
53 58
54 59 # test decorator for skipping tests when libraries are unavailable
55 60 def skip_without(*names):
56 61 """skip a test if some names are not importable"""
57 62 @decorator
58 63 def skip_without_names(f, *args, **kwargs):
59 64         """decorator to skip tests when the named modules are missing."""
60 65 for name in names:
61 66 try:
62 67 __import__(name)
63 68 except ImportError:
64 69 raise SkipTest
65 70 return f(*args, **kwargs)
66 71 return skip_without_names
67 72
68 73 class ClusterTestCase(BaseZMQTestCase):
69 74
70 75 def add_engines(self, n=1, block=True):
71 76 """add multiple engines to our cluster"""
72 77 self.engines.extend(add_engines(n))
73 78 if block:
74 79 self.wait_on_engines()
75 80
76 81 def wait_on_engines(self, timeout=5):
77 82 """wait for our engines to connect."""
78 83 n = len(self.engines)+self.base_engine_count
79 84 tic = time.time()
80 85 while time.time()-tic < timeout and len(self.client.ids) < n:
81 86 time.sleep(0.1)
82 87
83 88 assert not len(self.client.ids) < n, "waiting for engines timed out"
84 89
85 90 def connect_client(self):
86 91 """connect a client with my Context, and track its sockets for cleanup"""
87 92 c = Client(profile='iptest', context=self.context)
88 93 for name in filter(lambda n:n.endswith('socket'), dir(c)):
89 94 s = getattr(c, name)
90 95 s.setsockopt(zmq.LINGER, 0)
91 96 self.sockets.append(s)
92 97 return c
93 98
94 99 def assertRaisesRemote(self, etype, f, *args, **kwargs):
95 100 try:
96 101 try:
97 102 f(*args, **kwargs)
98 103 except error.CompositeError as e:
99 104 e.raise_exception()
100 105 except error.RemoteError as e:
101 106 self.assertEquals(etype.__name__, e.ename, "Should have raised %r, but raised %r"%(etype.__name__, e.ename))
102 107 else:
103 108 self.fail("should have raised a RemoteError")
104 109
105 110 def setUp(self):
106 111 BaseZMQTestCase.setUp(self)
107 112 self.client = self.connect_client()
108 113 # start every test with clean engine namespaces:
109 114 self.client.clear(block=True)
110 115 self.base_engine_count=len(self.client.ids)
111 116 self.engines=[]
112 117
113 118 def tearDown(self):
114 119 # self.client.clear(block=True)
115 120 # close fds:
116 121 for e in filter(lambda e: e.poll() is not None, launchers):
117 122 launchers.remove(e)
118 123
119 124 # allow flushing of incoming messages to prevent crash on socket close
120 125 self.client.wait(timeout=2)
121 126 # time.sleep(2)
122 127 self.client.spin()
123 128 self.client.close()
124 129 BaseZMQTestCase.tearDown(self)
125 130 # this will be redundant when pyzmq merges PR #88
126 131 # self.context.term()
127 132 # print tempfile.TemporaryFile().fileno(),
128 133 # sys.stdout.flush()
129 134 No newline at end of file
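
# An illustrative application of the skip_without decorator above;
# 'numpy' is just an example dependency, and this test class is
# hypothetical rather than part of the suite.
class ExampleTests(ClusterTestCase):

    @skip_without('numpy')
    def test_push_array(self):
        import numpy
        view = self.client[:]
        view.push(dict(a=numpy.arange(4)), block=True)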
@@ -1,68 +1,73 b''
1 """Tests for asyncresult.py"""
1 """Tests for asyncresult.py
2
3 Authors:
4
5 * Min RK
6 """
2 7
3 8 #-------------------------------------------------------------------------------
4 9 # Copyright (C) 2011 The IPython Development Team
5 10 #
6 11 # Distributed under the terms of the BSD License. The full license is in
7 12 # the file COPYING, distributed as part of this software.
8 13 #-------------------------------------------------------------------------------
9 14
10 15 #-------------------------------------------------------------------------------
11 16 # Imports
12 17 #-------------------------------------------------------------------------------
13 18
14 19
15 20 from IPython.parallel.error import TimeoutError
16 21
17 22 from IPython.parallel.tests import add_engines
18 23 from .clienttest import ClusterTestCase
19 24
20 25 def setup():
21 26 add_engines(2)
22 27
23 28 def wait(n):
24 29 import time
25 30 time.sleep(n)
26 31 return n
27 32
28 33 class AsyncResultTest(ClusterTestCase):
29 34
30 35 def test_single_result(self):
31 36 eid = self.client.ids[-1]
32 37 ar = self.client[eid].apply_async(lambda : 42)
33 38 self.assertEquals(ar.get(), 42)
34 39 ar = self.client[[eid]].apply_async(lambda : 42)
35 40 self.assertEquals(ar.get(), [42])
36 41 ar = self.client[-1:].apply_async(lambda : 42)
37 42 self.assertEquals(ar.get(), [42])
38 43
39 44 def test_get_after_done(self):
40 45 ar = self.client[-1].apply_async(lambda : 42)
41 46 ar.wait()
42 47 self.assertTrue(ar.ready())
43 48 self.assertEquals(ar.get(), 42)
44 49 self.assertEquals(ar.get(), 42)
45 50
46 51 def test_get_before_done(self):
47 52 ar = self.client[-1].apply_async(wait, 0.1)
48 53 self.assertRaises(TimeoutError, ar.get, 0)
49 54 ar.wait(0)
50 55 self.assertFalse(ar.ready())
51 56 self.assertEquals(ar.get(), 0.1)
52 57
53 58 def test_get_after_error(self):
54 59 ar = self.client[-1].apply_async(lambda : 1/0)
55 60 ar.wait()
56 61 self.assertRaisesRemote(ZeroDivisionError, ar.get)
57 62 self.assertRaisesRemote(ZeroDivisionError, ar.get)
58 63 self.assertRaisesRemote(ZeroDivisionError, ar.get_dict)
59 64
60 65 def test_get_dict(self):
61 66 n = len(self.client)
62 67 ar = self.client[:].apply_async(lambda : 5)
63 68 self.assertEquals(ar.get(), [5]*n)
64 69 d = ar.get_dict()
65 70 self.assertEquals(sorted(d.keys()), sorted(self.client.ids))
66 71 for eid,r in d.iteritems():
67 72 self.assertEquals(r, 5)
68 73
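
# The AsyncResult call sequence the tests above exercise, in the usual
# order; assumes a connected client `rc` (a hypothetical name) and the
# wait() helper defined above.
ar = rc[-1].apply_async(wait, 0.1)   # submit to the last engine
ar.wait()                            # block until the result arrives
assert ar.ready()
print(ar.get())                      # -> 0.1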
@@ -1,244 +1,249 b''
1 """Tests for parallel client.py"""
1 """Tests for parallel client.py
2
3 Authors:
4
5 * Min RK
6 """
2 7
3 8 #-------------------------------------------------------------------------------
4 9 # Copyright (C) 2011 The IPython Development Team
5 10 #
6 11 # Distributed under the terms of the BSD License. The full license is in
7 12 # the file COPYING, distributed as part of this software.
8 13 #-------------------------------------------------------------------------------
9 14
10 15 #-------------------------------------------------------------------------------
11 16 # Imports
12 17 #-------------------------------------------------------------------------------
13 18
14 19 import time
15 20 from datetime import datetime
16 21 from tempfile import mktemp
17 22
18 23 import zmq
19 24
20 25 from IPython.parallel.client import client as clientmod
21 26 from IPython.parallel import error
22 27 from IPython.parallel import AsyncResult, AsyncHubResult
23 28 from IPython.parallel import LoadBalancedView, DirectView
24 29
25 30 from clienttest import ClusterTestCase, segfault, wait, add_engines
26 31
27 32 def setup():
28 33 add_engines(4)
29 34
30 35 class TestClient(ClusterTestCase):
31 36
32 37 def test_ids(self):
33 38 n = len(self.client.ids)
34 39 self.add_engines(3)
35 40 self.assertEquals(len(self.client.ids), n+3)
36 41
37 42 def test_view_indexing(self):
38 43 """test index access for views"""
39 44 self.add_engines(2)
40 45 targets = self.client._build_targets('all')[-1]
41 46 v = self.client[:]
42 47 self.assertEquals(v.targets, targets)
43 48 t = self.client.ids[2]
44 49 v = self.client[t]
45 50 self.assert_(isinstance(v, DirectView))
46 51 self.assertEquals(v.targets, t)
47 52 t = self.client.ids[2:4]
48 53 v = self.client[t]
49 54 self.assert_(isinstance(v, DirectView))
50 55 self.assertEquals(v.targets, t)
51 56 v = self.client[::2]
52 57 self.assert_(isinstance(v, DirectView))
53 58 self.assertEquals(v.targets, targets[::2])
54 59 v = self.client[1::3]
55 60 self.assert_(isinstance(v, DirectView))
56 61 self.assertEquals(v.targets, targets[1::3])
57 62 v = self.client[:-3]
58 63 self.assert_(isinstance(v, DirectView))
59 64 self.assertEquals(v.targets, targets[:-3])
60 65 v = self.client[-1]
61 66 self.assert_(isinstance(v, DirectView))
62 67 self.assertEquals(v.targets, targets[-1])
63 68 self.assertRaises(TypeError, lambda : self.client[None])
64 69
65 70 def test_lbview_targets(self):
66 71 """test load_balanced_view targets"""
67 72 v = self.client.load_balanced_view()
68 73 self.assertEquals(v.targets, None)
69 74 v = self.client.load_balanced_view(-1)
70 75 self.assertEquals(v.targets, [self.client.ids[-1]])
71 76 v = self.client.load_balanced_view('all')
72 77 self.assertEquals(v.targets, self.client.ids)
73 78
74 79 def test_targets(self):
75 80 """test various valid targets arguments"""
76 81 build = self.client._build_targets
77 82 ids = self.client.ids
78 83 idents,targets = build(None)
79 84 self.assertEquals(ids, targets)
80 85
81 86 def test_clear(self):
82 87 """test clear behavior"""
83 88 # self.add_engines(2)
84 89 v = self.client[:]
85 90 v.block=True
86 91 v.push(dict(a=5))
87 92 v.pull('a')
88 93 id0 = self.client.ids[-1]
89 94 self.client.clear(targets=id0, block=True)
90 95 a = self.client[:-1].get('a')
91 96 self.assertRaisesRemote(NameError, self.client[id0].get, 'a')
92 97 self.client.clear(block=True)
93 98 for i in self.client.ids:
94 99 # print i
95 100 self.assertRaisesRemote(NameError, self.client[i].get, 'a')
96 101
97 102 def test_get_result(self):
98 103 """test getting results from the Hub."""
99 104 c = clientmod.Client(profile='iptest')
100 105 # self.add_engines(1)
101 106 t = c.ids[-1]
102 107 ar = c[t].apply_async(wait, 1)
103 108 # give the monitor time to notice the message
104 109 time.sleep(.25)
105 110 ahr = self.client.get_result(ar.msg_ids)
106 111 self.assertTrue(isinstance(ahr, AsyncHubResult))
107 112 self.assertEquals(ahr.get(), ar.get())
108 113 ar2 = self.client.get_result(ar.msg_ids)
109 114 self.assertFalse(isinstance(ar2, AsyncHubResult))
110 115 c.close()
111 116
112 117 def test_ids_list(self):
113 118 """test client.ids"""
114 119 # self.add_engines(2)
115 120 ids = self.client.ids
116 121 self.assertEquals(ids, self.client._ids)
117 122 self.assertFalse(ids is self.client._ids)
118 123 ids.remove(ids[-1])
119 124 self.assertNotEquals(ids, self.client._ids)
120 125
121 126 def test_queue_status(self):
122 127 # self.addEngine(4)
123 128 ids = self.client.ids
124 129 id0 = ids[0]
125 130 qs = self.client.queue_status(targets=id0)
126 131 self.assertTrue(isinstance(qs, dict))
127 132 self.assertEquals(sorted(qs.keys()), ['completed', 'queue', 'tasks'])
128 133 allqs = self.client.queue_status()
129 134 self.assertTrue(isinstance(allqs, dict))
130 135 self.assertEquals(sorted(allqs.keys()), sorted(self.client.ids + ['unassigned']))
131 136 unassigned = allqs.pop('unassigned')
132 137 for eid,qs in allqs.items():
133 138 self.assertTrue(isinstance(qs, dict))
134 139 self.assertEquals(sorted(qs.keys()), ['completed', 'queue', 'tasks'])
135 140
136 141 def test_shutdown(self):
137 142 # self.addEngine(4)
138 143 ids = self.client.ids
139 144 id0 = ids[0]
140 145 self.client.shutdown(id0, block=True)
141 146 while id0 in self.client.ids:
142 147 time.sleep(0.1)
143 148 self.client.spin()
144 149
145 150 self.assertRaises(IndexError, lambda : self.client[id0])
146 151
147 152 def test_result_status(self):
148 153 pass
149 154 # to be written
150 155
151 156 def test_db_query_dt(self):
152 157 """test db query by date"""
153 158 hist = self.client.hub_history()
154 159 middle = self.client.db_query({'msg_id' : hist[len(hist)/2]})[0]
155 160 tic = middle['submitted']
156 161 before = self.client.db_query({'submitted' : {'$lt' : tic}})
157 162 after = self.client.db_query({'submitted' : {'$gte' : tic}})
158 163 self.assertEquals(len(before)+len(after),len(hist))
159 164 for b in before:
160 165 self.assertTrue(b['submitted'] < tic)
161 166 for a in after:
162 167 self.assertTrue(a['submitted'] >= tic)
163 168 same = self.client.db_query({'submitted' : tic})
164 169 for s in same:
165 170 self.assertTrue(s['submitted'] == tic)
166 171
167 172 def test_db_query_keys(self):
168 173 """test extracting subset of record keys"""
169 174 found = self.client.db_query({'msg_id': {'$ne' : ''}},keys=['submitted', 'completed'])
170 175 for rec in found:
171 176 self.assertEquals(set(rec.keys()), set(['msg_id', 'submitted', 'completed']))
172 177
173 178 def test_db_query_msg_id(self):
174 179 """ensure msg_id is always in db queries"""
175 180 found = self.client.db_query({'msg_id': {'$ne' : ''}},keys=['submitted', 'completed'])
176 181 for rec in found:
177 182 self.assertTrue('msg_id' in rec.keys())
178 183 found = self.client.db_query({'msg_id': {'$ne' : ''}},keys=['submitted'])
179 184 for rec in found:
180 185 self.assertTrue('msg_id' in rec.keys())
181 186 found = self.client.db_query({'msg_id': {'$ne' : ''}},keys=['msg_id'])
182 187 for rec in found:
183 188 self.assertTrue('msg_id' in rec.keys())
184 189
185 190 def test_db_query_in(self):
186 191 """test db query with '$in','$nin' operators"""
187 192 hist = self.client.hub_history()
188 193 even = hist[::2]
189 194 odd = hist[1::2]
190 195 recs = self.client.db_query({ 'msg_id' : {'$in' : even}})
191 196 found = [ r['msg_id'] for r in recs ]
192 197 self.assertEquals(set(even), set(found))
193 198 recs = self.client.db_query({ 'msg_id' : {'$nin' : even}})
194 199 found = [ r['msg_id'] for r in recs ]
195 200 self.assertEquals(set(odd), set(found))
196 201
197 202 def test_hub_history(self):
198 203 hist = self.client.hub_history()
199 204 recs = self.client.db_query({ 'msg_id' : {"$ne":''}})
200 205 recdict = {}
201 206 for rec in recs:
202 207 recdict[rec['msg_id']] = rec
203 208
204 209 latest = datetime(1984,1,1)
205 210 for msg_id in hist:
206 211 rec = recdict[msg_id]
207 212 newt = rec['submitted']
208 213 self.assertTrue(newt >= latest)
209 214 latest = newt
210 215 ar = self.client[-1].apply_async(lambda : 1)
211 216 ar.get()
212 217 time.sleep(0.25)
213 218 self.assertEquals(self.client.hub_history()[-1:],ar.msg_ids)
214 219
215 220 def test_resubmit(self):
216 221 def f():
217 222 import random
218 223 return random.random()
219 224 v = self.client.load_balanced_view()
220 225 ar = v.apply_async(f)
221 226 r1 = ar.get(1)
222 227 ahr = self.client.resubmit(ar.msg_ids)
223 228 r2 = ahr.get(1)
224 229 self.assertFalse(r1 == r2)
225 230
226 231 def test_resubmit_inflight(self):
227 232 """ensure ValueError on resubmit of inflight task"""
228 233 v = self.client.load_balanced_view()
229 234 ar = v.apply_async(time.sleep,1)
230 235 # give the message a chance to arrive
231 236 time.sleep(0.2)
232 237 self.assertRaisesRemote(ValueError, self.client.resubmit, ar.msg_ids)
233 238 ar.get(2)
234 239
235 240 def test_resubmit_badkey(self):
236 241 """ensure KeyError on resubmit of nonexistant task"""
237 242 self.assertRaisesRemote(KeyError, self.client.resubmit, ['invalid'])
238 243
239 244 def test_purge_results(self):
240 245 hist = self.client.hub_history()
241 246 self.client.purge_results(hist)
242 247 newhist = self.client.hub_history()
243 248 self.assertTrue(len(newhist) == 0)
244 249
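
For orientation, a minimal sketch of the hub-database query API the client tests above exercise; it assumes a cluster already running under the 'iptest' profile:

    from IPython.parallel import Client

    rc = Client(profile='iptest')      # assumes a running 'iptest' cluster
    hist = rc.hub_history()            # msg_ids ordered by submission time
    # fetch only the timing keys; msg_id is always included in results
    recs = rc.db_query({'msg_id': {'$ne': ''}}, keys=['submitted', 'completed'])
    for rec in recs:
        assert set(rec.keys()) == set(['msg_id', 'submitted', 'completed'])
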
@@ -1,173 +1,178 b''
1 """Tests for db backends"""
1 """Tests for db backends
2
3 Authors:
4
5 * Min RK
6 """
2 7
3 8 #-------------------------------------------------------------------------------
4 9 # Copyright (C) 2011 The IPython Development Team
5 10 #
6 11 # Distributed under the terms of the BSD License. The full license is in
7 12 # the file COPYING, distributed as part of this software.
8 13 #-------------------------------------------------------------------------------
9 14
10 15 #-------------------------------------------------------------------------------
11 16 # Imports
12 17 #-------------------------------------------------------------------------------
13 18
14 19
15 20 import tempfile
16 21 import time
17 22
18 23 from datetime import datetime, timedelta
19 24 from unittest import TestCase
20 25
21 26 from nose import SkipTest
22 27
23 28 from IPython.parallel import error
24 29 from IPython.parallel.controller.dictdb import DictDB
25 30 from IPython.parallel.controller.sqlitedb import SQLiteDB
26 31 from IPython.parallel.controller.hub import init_record, empty_record
27 32
28 33 from IPython.zmq.session import Session
29 34
30 35
31 36 #-------------------------------------------------------------------------------
32 37 # TestCases
33 38 #-------------------------------------------------------------------------------
34 39
35 40 class TestDictBackend(TestCase):
36 41 def setUp(self):
37 42 self.session = Session()
38 43 self.db = self.create_db()
39 44 self.load_records(16)
40 45
41 46 def create_db(self):
42 47 return DictDB()
43 48
44 49 def load_records(self, n=1):
45 50 """load n records for testing"""
46 51 #sleep 1/10 s, to ensure the timestamp differs from previous calls
47 52 time.sleep(0.1)
48 53 msg_ids = []
49 54 for i in range(n):
50 55 msg = self.session.msg('apply_request', content=dict(a=5))
51 56 msg['buffers'] = []
52 57 rec = init_record(msg)
53 58 msg_ids.append(msg['msg_id'])
54 59 self.db.add_record(msg['msg_id'], rec)
55 60 return msg_ids
56 61
57 62 def test_add_record(self):
58 63 before = self.db.get_history()
59 64 self.load_records(5)
60 65 after = self.db.get_history()
61 66 self.assertEquals(len(after), len(before)+5)
62 67 self.assertEquals(after[:-5],before)
63 68
64 69 def test_drop_record(self):
65 70 msg_id = self.load_records()[-1]
66 71 rec = self.db.get_record(msg_id)
67 72 self.db.drop_record(msg_id)
68 73 self.assertRaises(KeyError,self.db.get_record, msg_id)
69 74
70 75 def _round_to_millisecond(self, dt):
71 76 """necessary because mongodb rounds microseconds"""
72 77 micro = dt.microsecond
73 78 extra = int(str(micro)[-3:])
74 79 return dt - timedelta(microseconds=extra)
75 80
76 81 def test_update_record(self):
77 82 now = self._round_to_millisecond(datetime.now())
78 83 #
79 84 msg_id = self.db.get_history()[-1]
80 85 rec1 = self.db.get_record(msg_id)
81 86 data = {'stdout': 'hello there', 'completed' : now}
82 87 self.db.update_record(msg_id, data)
83 88 rec2 = self.db.get_record(msg_id)
84 89 self.assertEquals(rec2['stdout'], 'hello there')
85 90 self.assertEquals(rec2['completed'], now)
86 91 rec1.update(data)
87 92 self.assertEquals(rec1, rec2)
88 93
89 94 # def test_update_record_bad(self):
90 95 # """test updating nonexistant records"""
91 96 # msg_id = str(uuid.uuid4())
92 97 # data = {'stdout': 'hello there'}
93 98 # self.assertRaises(KeyError, self.db.update_record, msg_id, data)
94 99
95 100 def test_find_records_dt(self):
96 101 """test finding records by date"""
97 102 hist = self.db.get_history()
98 103 middle = self.db.get_record(hist[len(hist)/2])
99 104 tic = middle['submitted']
100 105 before = self.db.find_records({'submitted' : {'$lt' : tic}})
101 106 after = self.db.find_records({'submitted' : {'$gte' : tic}})
102 107 self.assertEquals(len(before)+len(after),len(hist))
103 108 for b in before:
104 109 self.assertTrue(b['submitted'] < tic)
105 110 for a in after:
106 111 self.assertTrue(a['submitted'] >= tic)
107 112 same = self.db.find_records({'submitted' : tic})
108 113 for s in same:
109 114 self.assertTrue(s['submitted'] == tic)
110 115
111 116 def test_find_records_keys(self):
112 117 """test extracting subset of record keys"""
113 118 found = self.db.find_records({'msg_id': {'$ne' : ''}},keys=['submitted', 'completed'])
114 119 for rec in found:
115 120 self.assertEquals(set(rec.keys()), set(['msg_id', 'submitted', 'completed']))
116 121
117 122 def test_find_records_msg_id(self):
118 123 """ensure msg_id is always in found records"""
119 124 found = self.db.find_records({'msg_id': {'$ne' : ''}},keys=['submitted', 'completed'])
120 125 for rec in found:
121 126 self.assertTrue('msg_id' in rec.keys())
122 127 found = self.db.find_records({'msg_id': {'$ne' : ''}},keys=['submitted'])
123 128 for rec in found:
124 129 self.assertTrue('msg_id' in rec.keys())
125 130 found = self.db.find_records({'msg_id': {'$ne' : ''}},keys=['msg_id'])
126 131 for rec in found:
127 132 self.assertTrue('msg_id' in rec.keys())
128 133
129 134 def test_find_records_in(self):
130 135 """test finding records with '$in','$nin' operators"""
131 136 hist = self.db.get_history()
132 137 even = hist[::2]
133 138 odd = hist[1::2]
134 139 recs = self.db.find_records({ 'msg_id' : {'$in' : even}})
135 140 found = [ r['msg_id'] for r in recs ]
136 141 self.assertEquals(set(even), set(found))
137 142 recs = self.db.find_records({ 'msg_id' : {'$nin' : even}})
138 143 found = [ r['msg_id'] for r in recs ]
139 144 self.assertEquals(set(odd), set(found))
140 145
141 146 def test_get_history(self):
142 147 msg_ids = self.db.get_history()
143 148 latest = datetime(1984,1,1)
144 149 for msg_id in msg_ids:
145 150 rec = self.db.get_record(msg_id)
146 151 newt = rec['submitted']
147 152 self.assertTrue(newt >= latest)
148 153 latest = newt
149 154 msg_id = self.load_records(1)[-1]
150 155 self.assertEquals(self.db.get_history()[-1],msg_id)
151 156
152 157 def test_datetime(self):
153 158 """get/set timestamps with datetime objects"""
154 159 msg_id = self.db.get_history()[-1]
155 160 rec = self.db.get_record(msg_id)
156 161 self.assertTrue(isinstance(rec['submitted'], datetime))
157 162 self.db.update_record(msg_id, dict(completed=datetime.now()))
158 163 rec = self.db.get_record(msg_id)
159 164 self.assertTrue(isinstance(rec['completed'], datetime))
160 165
161 166 def test_drop_matching(self):
162 167 msg_ids = self.load_records(10)
163 168 query = {'msg_id' : {'$in':msg_ids}}
164 169 self.db.drop_matching_records(query)
165 170 recs = self.db.find_records(query)
166 171 self.assertTrue(len(recs)==0)
167 172
168 173 class TestSQLiteBackend(TestDictBackend):
169 174 def create_db(self):
170 175 return SQLiteDB(location=tempfile.gettempdir())
171 176
172 177 def tearDown(self):
173 178 self.db._db.close()
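
The same query interface can be driven directly against a backend. A minimal sketch against DictDB, using a hand-built record ('msg-1' is a hypothetical id; real records are built by init_record from a message, as in setUp above):

    from datetime import datetime
    from IPython.parallel.controller.dictdb import DictDB

    db = DictDB()
    # a hypothetical minimal record, standing in for init_record output
    db.add_record('msg-1', {'msg_id': 'msg-1', 'submitted': datetime.now()})
    db.update_record('msg-1', {'completed': datetime.now()})
    rec = db.get_record('msg-1')
    assert isinstance(rec['completed'], datetime)
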
@@ -1,101 +1,106 b''
1 """Tests for dependency.py"""
1 """Tests for dependency.py
2
3 Authors:
4
5 * Min RK
6 """
2 7
3 8 __docformat__ = "restructuredtext en"
4 9
5 10 #-------------------------------------------------------------------------------
6 11 # Copyright (C) 2011 The IPython Development Team
7 12 #
8 13 # Distributed under the terms of the BSD License. The full license is in
9 14 # the file COPYING, distributed as part of this software.
10 15 #-------------------------------------------------------------------------------
11 16
12 17 #-------------------------------------------------------------------------------
13 18 # Imports
14 19 #-------------------------------------------------------------------------------
15 20
16 21 # import
17 22 import os
18 23
19 24 from IPython.utils.pickleutil import can, uncan
20 25
21 26 import IPython.parallel as pmod
22 27 from IPython.parallel.util import interactive
23 28
24 29 from IPython.parallel.tests import add_engines
25 30 from .clienttest import ClusterTestCase
26 31
27 32 def setup():
28 33 add_engines(1)
29 34
30 35 @pmod.require('time')
31 36 def wait(n):
32 37 time.sleep(n)
33 38 return n
34 39
35 40 mixed = map(str, range(10))
36 41 completed = map(str, range(0,10,2))
37 42 failed = map(str, range(1,10,2))
38 43
39 44 class DependencyTest(ClusterTestCase):
40 45
41 46 def setUp(self):
42 47 ClusterTestCase.setUp(self)
43 48 self.user_ns = {'__builtins__' : __builtins__}
44 49 self.view = self.client.load_balanced_view()
45 50 self.dview = self.client[-1]
46 51 self.succeeded = set(map(str, range(0,25,2)))
47 52 self.failed = set(map(str, range(1,25,2)))
48 53
49 54 def assertMet(self, dep):
50 55 self.assertTrue(dep.check(self.succeeded, self.failed), "Dependency should be met")
51 56
52 57 def assertUnmet(self, dep):
53 58 self.assertFalse(dep.check(self.succeeded, self.failed), "Dependency should not be met")
54 59
55 60 def assertUnreachable(self, dep):
56 61 self.assertTrue(dep.unreachable(self.succeeded, self.failed), "Dependency should be unreachable")
57 62
58 63 def assertReachable(self, dep):
59 64 self.assertFalse(dep.unreachable(self.succeeded, self.failed), "Dependency should be reachable")
60 65
61 66 def cancan(self, f):
62 67 """decorator to pass through canning into self.user_ns"""
63 68 return uncan(can(f), self.user_ns)
64 69
65 70 def test_require_imports(self):
66 71 """test that @require imports names"""
67 72 @self.cancan
68 73 @pmod.require('urllib')
69 74 @interactive
70 75 def encode(dikt):
71 76 return urllib.urlencode(dikt)
72 77 # must pass through canning to properly connect namespaces
73 78 self.assertEquals(encode(dict(a=5)), 'a=5')
74 79
75 80 def test_success_only(self):
76 81 dep = pmod.Dependency(mixed, success=True, failure=False)
77 82 self.assertUnmet(dep)
78 83 self.assertUnreachable(dep)
79 84 dep.all=False
80 85 self.assertMet(dep)
81 86 self.assertReachable(dep)
82 87 dep = pmod.Dependency(completed, success=True, failure=False)
83 88 self.assertMet(dep)
84 89 self.assertReachable(dep)
85 90 dep.all=False
86 91 self.assertMet(dep)
87 92 self.assertReachable(dep)
88 93
89 94 def test_failure_only(self):
90 95 dep = pmod.Dependency(mixed, success=False, failure=True)
91 96 self.assertUnmet(dep)
92 97 self.assertUnreachable(dep)
93 98 dep.all=False
94 99 self.assertMet(dep)
95 100 self.assertReachable(dep)
96 101 dep = pmod.Dependency(completed, success=False, failure=True)
97 102 self.assertUnmet(dep)
98 103 self.assertUnreachable(dep)
99 104 dep.all=False
100 105 self.assertUnmet(dep)
101 106 self.assertUnreachable(dep)
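
A sketch of how the Dependency objects above behave, assuming the constructor and attributes as used in these tests (success/failure flags, with the 'all' attribute toggling all-vs-any semantics):

    from IPython.parallel import Dependency

    succeeded = set(['a', 'b'])
    failed = set(['c'])
    dep = Dependency(['a', 'c'], success=True, failure=False)
    print dep.check(succeeded, failed)        # False: 'c' has not succeeded
    print dep.unreachable(succeeded, failed)  # True: 'c' already failed
    dep.all = False                           # any single success now suffices
    print dep.check(succeeded, failed)        # True: 'a' succeeded
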
@@ -1,120 +1,125 b''
1 """test LoadBalancedView objects"""
1 """test LoadBalancedView objects
2
3 Authors:
4
5 * Min RK
6 """
2 7 # -*- coding: utf-8 -*-
3 8 #-------------------------------------------------------------------------------
4 9 # Copyright (C) 2011 The IPython Development Team
5 10 #
6 11 # Distributed under the terms of the BSD License. The full license is in
7 12 # the file COPYING, distributed as part of this software.
8 13 #-------------------------------------------------------------------------------
9 14
10 15 #-------------------------------------------------------------------------------
11 16 # Imports
12 17 #-------------------------------------------------------------------------------
13 18
14 19 import sys
15 20 import time
16 21
17 22 import zmq
18 23
19 24 from IPython import parallel as pmod
20 25 from IPython.parallel import error
21 26
22 27 from IPython.parallel.tests import add_engines
23 28
24 29 from .clienttest import ClusterTestCase, crash, wait, skip_without
25 30
26 31 def setup():
27 32 add_engines(3)
28 33
29 34 class TestLoadBalancedView(ClusterTestCase):
30 35
31 36 def setUp(self):
32 37 ClusterTestCase.setUp(self)
33 38 self.view = self.client.load_balanced_view()
34 39
35 40 def test_z_crash_task(self):
36 41 """test graceful handling of engine death (balanced)"""
37 42 # self.add_engines(1)
38 43 ar = self.view.apply_async(crash)
39 44 self.assertRaisesRemote(error.EngineError, ar.get, 10)
40 45 eid = ar.engine_id
41 46 tic = time.time()
42 47 while eid in self.client.ids and time.time()-tic < 5:
43 48 time.sleep(.01)
44 49 self.client.spin()
45 50 self.assertFalse(eid in self.client.ids, "Engine should have died")
46 51
47 52 def test_map(self):
48 53 def f(x):
49 54 return x**2
50 55 data = range(16)
51 56 r = self.view.map_sync(f, data)
52 57 self.assertEquals(r, map(f, data))
53 58
54 59 def test_abort(self):
55 60 view = self.view
56 61 ar = self.client[:].apply_async(time.sleep, .5)
57 62 ar2 = view.apply_async(lambda : 2)
58 63 ar3 = view.apply_async(lambda : 3)
59 64 view.abort(ar2)
60 65 view.abort(ar3.msg_ids)
61 66 self.assertRaises(error.TaskAborted, ar2.get)
62 67 self.assertRaises(error.TaskAborted, ar3.get)
63 68
64 69 def test_retries(self):
65 70 add_engines(3)
66 71 view = self.view
67 72 view.timeout = 1 # prevent hang if this doesn't behave
68 73 def fail():
69 74 assert False
70 75 for r in range(len(self.client)-1):
71 76 with view.temp_flags(retries=r):
72 77 self.assertRaisesRemote(AssertionError, view.apply_sync, fail)
73 78
74 79 with view.temp_flags(retries=len(self.client), timeout=0.25):
75 80 self.assertRaisesRemote(error.TaskTimeout, view.apply_sync, fail)
76 81
77 82 def test_invalid_dependency(self):
78 83 view = self.view
79 84 with view.temp_flags(after='12345'):
80 85 self.assertRaisesRemote(error.InvalidDependency, view.apply_sync, lambda : 1)
81 86
82 87 def test_impossible_dependency(self):
83 88 if len(self.client) < 2:
84 89 add_engines(2)
85 90 view = self.client.load_balanced_view()
86 91 ar1 = view.apply_async(lambda : 1)
87 92 ar1.get()
88 93 e1 = ar1.engine_id
89 94 e2 = e1
90 95 while e2 == e1:
91 96 ar2 = view.apply_async(lambda : 1)
92 97 ar2.get()
93 98 e2 = ar2.engine_id
94 99
95 100 with view.temp_flags(follow=[ar1, ar2]):
96 101 self.assertRaisesRemote(error.ImpossibleDependency, view.apply_sync, lambda : 1)
97 102
98 103
99 104 def test_follow(self):
100 105 ar = self.view.apply_async(lambda : 1)
101 106 ar.get()
102 107 ars = []
103 108 first_id = ar.engine_id
104 109
105 110 self.view.follow = ar
106 111 for i in range(5):
107 112 ars.append(self.view.apply_async(lambda : 1))
108 113 self.view.wait(ars)
109 114 for ar in ars:
110 115 self.assertEquals(ar.engine_id, first_id)
111 116
112 117 def test_after(self):
113 118 view = self.view
114 119 ar = view.apply_async(time.sleep, 0.5)
115 120 with view.temp_flags(after=ar):
116 121 ar2 = view.apply_async(lambda : 1)
117 122
118 123 ar.wait()
119 124 ar2.wait()
120 125 self.assertTrue(ar2.started > ar.completed)
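
A hedged sketch of the load-balanced scheduler flags these tests cover ('after', 'retries', 'timeout'), again assuming a running 'iptest' cluster:

    from IPython.parallel import Client

    rc = Client(profile='iptest')
    v = rc.load_balanced_view()
    ar = v.apply_async(lambda : 1)
    # run a second task only after the first completes, retrying once on error
    with v.temp_flags(after=ar, retries=1, timeout=5):
        ar2 = v.apply_async(lambda : 2)
    print ar2.get()
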
@@ -1,37 +1,42 b''
1 """Tests for mongodb backend"""
1 """Tests for mongodb backend
2
3 Authors:
4
5 * Min RK
6 """
2 7
3 8 #-------------------------------------------------------------------------------
4 9 # Copyright (C) 2011 The IPython Development Team
5 10 #
6 11 # Distributed under the terms of the BSD License. The full license is in
7 12 # the file COPYING, distributed as part of this software.
8 13 #-------------------------------------------------------------------------------
9 14
10 15 #-------------------------------------------------------------------------------
11 16 # Imports
12 17 #-------------------------------------------------------------------------------
13 18
14 19 from nose import SkipTest
15 20
16 21 from pymongo import Connection
17 22 from IPython.parallel.controller.mongodb import MongoDB
18 23
19 24 from . import test_db
20 25
21 26 try:
22 27 c = Connection()
23 28 except Exception:
24 29 c=None
25 30
26 31 class TestMongoBackend(test_db.TestDictBackend):
27 32 """MongoDB backend tests"""
28 33
29 34 def create_db(self):
30 35 try:
31 36 return MongoDB(database='iptestdb', _connection=c)
32 37 except Exception:
33 38 raise SkipTest("Couldn't connect to mongodb")
34 39
35 40 def teardown(self):
36 41 if c is not None:
37 42 c.drop_database('iptestdb')
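
The MongoDB backend is constructed the same way the test above does; a minimal sketch assuming a local mongod is reachable:

    from pymongo import Connection
    from IPython.parallel.controller.mongodb import MongoDB

    c = Connection()                           # assumes mongod on localhost
    db = MongoDB(database='iptestdb', _connection=c)
    print db.get_history()                     # same interface as DictDB
    c.drop_database('iptestdb')                # drop the scratch database
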
@@ -1,108 +1,113 b''
1 """test serialization with newserialized"""
1 """test serialization with newserialized
2
3 Authors:
4
5 * Min RK
6 """
2 7
3 8 #-------------------------------------------------------------------------------
4 9 # Copyright (C) 2011 The IPython Development Team
5 10 #
6 11 # Distributed under the terms of the BSD License. The full license is in
7 12 # the file COPYING, distributed as part of this software.
8 13 #-------------------------------------------------------------------------------
9 14
10 15 #-------------------------------------------------------------------------------
11 16 # Imports
12 17 #-------------------------------------------------------------------------------
13 18
14 19 from unittest import TestCase
15 20
16 21 from IPython.testing.decorators import parametric
17 22 from IPython.utils import newserialized as ns
18 23 from IPython.utils.pickleutil import can, uncan, CannedObject, CannedFunction
19 24 from IPython.parallel.tests.clienttest import skip_without
20 25
21 26
22 27 class CanningTestCase(TestCase):
23 28 def test_canning(self):
24 29 d = dict(a=5,b=6)
25 30 cd = can(d)
26 31 self.assertTrue(isinstance(cd, dict))
27 32
28 33 def test_canned_function(self):
29 34 f = lambda : 7
30 35 cf = can(f)
31 36 self.assertTrue(isinstance(cf, CannedFunction))
32 37
33 38 @parametric
34 39 def test_can_roundtrip(cls):
35 40 objs = [
36 41 dict(),
37 42 set(),
38 43 list(),
39 44 ['a',1,['a',1],u'e'],
40 45 ]
41 46 return map(cls.run_roundtrip, objs)
42 47
43 48 @classmethod
44 49 def run_roundtrip(self, obj):
45 50 o = uncan(can(obj))
46 51 assert o == obj, "failed assertion: %r == %r"%(o,obj)
47 52
48 53 def test_serialized_interfaces(self):
49 54
50 55 us = {'a':10, 'b':range(10)}
51 56 s = ns.serialize(us)
52 57 uus = ns.unserialize(s)
53 58 self.assertTrue(isinstance(s, ns.SerializeIt))
54 59 self.assertEquals(uus, us)
55 60
56 61 def test_pickle_serialized(self):
57 62 obj = {'a':1.45345, 'b':'asdfsdf', 'c':10000L}
58 63 original = ns.UnSerialized(obj)
59 64 originalSer = ns.SerializeIt(original)
60 65 firstData = originalSer.getData()
61 66 firstTD = originalSer.getTypeDescriptor()
62 67 firstMD = originalSer.getMetadata()
63 68 self.assertEquals(firstTD, 'pickle')
64 69 self.assertEquals(firstMD, {})
65 70 unSerialized = ns.UnSerializeIt(originalSer)
66 71 secondObj = unSerialized.getObject()
67 72 for k, v in secondObj.iteritems():
68 73 self.assertEquals(obj[k], v)
69 74 secondSer = ns.SerializeIt(ns.UnSerialized(secondObj))
70 75 self.assertEquals(firstData, secondSer.getData())
71 76 self.assertEquals(firstTD, secondSer.getTypeDescriptor() )
72 77 self.assertEquals(firstMD, secondSer.getMetadata())
73 78
74 79 @skip_without('numpy')
75 80 def test_ndarray_serialized(self):
76 81 import numpy
77 82 a = numpy.linspace(0.0, 1.0, 1000)
78 83 unSer1 = ns.UnSerialized(a)
79 84 ser1 = ns.SerializeIt(unSer1)
80 85 td = ser1.getTypeDescriptor()
81 86 self.assertEquals(td, 'ndarray')
82 87 md = ser1.getMetadata()
83 88 self.assertEquals(md['shape'], a.shape)
84 89 self.assertEquals(md['dtype'], a.dtype.str)
85 90 buff = ser1.getData()
86 91 self.assertEquals(buff, numpy.getbuffer(a))
87 92 s = ns.Serialized(buff, td, md)
88 93 final = ns.unserialize(s)
89 94 self.assertEquals(numpy.getbuffer(a), numpy.getbuffer(final))
90 95 self.assertTrue((a==final).all())
91 96 self.assertEquals(a.dtype.str, final.dtype.str)
92 97 self.assertEquals(a.shape, final.shape)
93 98 # test non-copying:
94 99 a[2] = 1e9
95 100 self.assertTrue((a==final).all())
96 101
97 102 def test_uncan_function_globals(self):
98 103 """test that uncanning a module function restores it into its module"""
99 104 from re import search
100 105 cf = can(search)
101 106 csearch = uncan(cf)
102 107 self.assertEqual(csearch.__module__, search.__module__)
103 108 self.assertNotEqual(csearch('asd', 'asdf'), None)
104 109 csearch = uncan(cf, dict(a=5))
105 110 self.assertEqual(csearch.__module__, search.__module__)
106 111 self.assertNotEqual(csearch('asd', 'asdf'), None)
107 112
108 113 No newline at end of file
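
A short sketch of the serialization round-trips tested above:

    from IPython.utils import newserialized as ns
    from IPython.utils.pickleutil import can, uncan

    obj = {'a': 10, 'b': range(10)}
    s = ns.serialize(obj)              # wraps obj in a SerializeIt
    assert ns.unserialize(s) == obj    # round-trip preserves the value

    f = lambda x: 2 * x
    g = uncan(can(f))                  # functions survive can/uncan
    assert g(3) == 6
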
@@ -1,441 +1,446 b''
1 """test View objects"""
1 """test View objects
2
3 Authors:
4
5 * Min RK
6 """
2 7 # -*- coding: utf-8 -*-
3 8 #-------------------------------------------------------------------------------
4 9 # Copyright (C) 2011 The IPython Development Team
5 10 #
6 11 # Distributed under the terms of the BSD License. The full license is in
7 12 # the file COPYING, distributed as part of this software.
8 13 #-------------------------------------------------------------------------------
9 14
10 15 #-------------------------------------------------------------------------------
11 16 # Imports
12 17 #-------------------------------------------------------------------------------
13 18
14 19 import sys
15 20 import time
16 21 from tempfile import mktemp
17 22 from StringIO import StringIO
18 23
19 24 import zmq
20 25
21 26 from IPython import parallel as pmod
22 27 from IPython.parallel import error
23 28 from IPython.parallel import AsyncResult, AsyncHubResult, AsyncMapResult
24 29 from IPython.parallel import DirectView
25 30 from IPython.parallel.util import interactive
26 31
27 32 from IPython.parallel.tests import add_engines
28 33
29 34 from .clienttest import ClusterTestCase, crash, wait, skip_without
30 35
31 36 def setup():
32 37 add_engines(3)
33 38
34 39 class TestView(ClusterTestCase):
35 40
36 41 def test_z_crash_mux(self):
37 42 """test graceful handling of engine death (direct)"""
38 43 # self.add_engines(1)
39 44 eid = self.client.ids[-1]
40 45 ar = self.client[eid].apply_async(crash)
41 46 self.assertRaisesRemote(error.EngineError, ar.get)
42 47 eid = ar.engine_id
43 48 tic = time.time()
44 49 while eid in self.client.ids and time.time()-tic < 5:
45 50 time.sleep(.01)
46 51 self.client.spin()
47 52 self.assertFalse(eid in self.client.ids, "Engine should have died")
48 53
49 54 def test_push_pull(self):
50 55 """test pushing and pulling"""
51 56 data = dict(a=10, b=1.05, c=range(10), d={'e':(1,2),'f':'hi'})
52 57 t = self.client.ids[-1]
53 58 v = self.client[t]
54 59 push = v.push
55 60 pull = v.pull
56 61 v.block=True
57 62 nengines = len(self.client)
58 63 push({'data':data})
59 64 d = pull('data')
60 65 self.assertEquals(d, data)
61 66 self.client[:].push({'data':data})
62 67 d = self.client[:].pull('data', block=True)
63 68 self.assertEquals(d, nengines*[data])
64 69 ar = push({'data':data}, block=False)
65 70 self.assertTrue(isinstance(ar, AsyncResult))
66 71 r = ar.get()
67 72 ar = self.client[:].pull('data', block=False)
68 73 self.assertTrue(isinstance(ar, AsyncResult))
69 74 r = ar.get()
70 75 self.assertEquals(r, nengines*[data])
71 76 self.client[:].push(dict(a=10,b=20))
72 77 r = self.client[:].pull(('a','b'), block=True)
73 78 self.assertEquals(r, nengines*[[10,20]])
74 79
75 80 def test_push_pull_function(self):
76 81 "test pushing and pulling functions"
77 82 def testf(x):
78 83 return 2.0*x
79 84
80 85 t = self.client.ids[-1]
81 86 v = self.client[t]
82 87 v.block=True
83 88 push = v.push
84 89 pull = v.pull
85 90 execute = v.execute
86 91 push({'testf':testf})
87 92 r = pull('testf')
88 93 self.assertEqual(r(1.0), testf(1.0))
89 94 execute('r = testf(10)')
90 95 r = pull('r')
91 96 self.assertEquals(r, testf(10))
92 97 ar = self.client[:].push({'testf':testf}, block=False)
93 98 ar.get()
94 99 ar = self.client[:].pull('testf', block=False)
95 100 rlist = ar.get()
96 101 for r in rlist:
97 102 self.assertEqual(r(1.0), testf(1.0))
98 103 execute("def g(x): return x*x")
99 104 r = pull(('testf','g'))
100 105 self.assertEquals((r[0](10),r[1](10)), (testf(10), 100))
101 106
102 107 def test_push_function_globals(self):
103 108 """test that pushed functions have access to globals"""
104 109 @interactive
105 110 def geta():
106 111 return a
107 112 # self.add_engines(1)
108 113 v = self.client[-1]
109 114 v.block=True
110 115 v['f'] = geta
111 116 self.assertRaisesRemote(NameError, v.execute, 'b=f()')
112 117 v.execute('a=5')
113 118 v.execute('b=f()')
114 119 self.assertEquals(v['b'], 5)
115 120
116 121 def test_push_function_defaults(self):
117 122 """test that pushed functions preserve default args"""
118 123 def echo(a=10):
119 124 return a
120 125 v = self.client[-1]
121 126 v.block=True
122 127 v['f'] = echo
123 128 v.execute('b=f()')
124 129 self.assertEquals(v['b'], 10)
125 130
126 131 def test_get_result(self):
127 132 """test getting results from the Hub."""
128 133 c = pmod.Client(profile='iptest')
129 134 # self.add_engines(1)
130 135 t = c.ids[-1]
131 136 v = c[t]
132 137 v2 = self.client[t]
133 138 ar = v.apply_async(wait, 1)
134 139 # give the monitor time to notice the message
135 140 time.sleep(.25)
136 141 ahr = v2.get_result(ar.msg_ids)
137 142 self.assertTrue(isinstance(ahr, AsyncHubResult))
138 143 self.assertEquals(ahr.get(), ar.get())
139 144 ar2 = v2.get_result(ar.msg_ids)
140 145 self.assertFalse(isinstance(ar2, AsyncHubResult))
141 146 c.spin()
142 147 c.close()
143 148
144 149 def test_run_newline(self):
145 150 """test that run appends newline to files"""
146 151 tmpfile = mktemp()
147 152 with open(tmpfile, 'w') as f:
148 153 f.write("""def g():
149 154 return 5
150 155 """)
151 156 v = self.client[-1]
152 157 v.run(tmpfile, block=True)
153 158 self.assertEquals(v.apply_sync(lambda f: f(), pmod.Reference('g')), 5)
154 159
155 160 def test_apply_tracked(self):
156 161 """test tracking for apply"""
157 162 # self.add_engines(1)
158 163 t = self.client.ids[-1]
159 164 v = self.client[t]
160 165 v.block=False
161 166 def echo(n=1024*1024, **kwargs):
162 167 with v.temp_flags(**kwargs):
163 168 return v.apply(lambda x: x, 'x'*n)
164 169 ar = echo(1, track=False)
165 170 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
166 171 self.assertTrue(ar.sent)
167 172 ar = echo(track=True)
168 173 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
169 174 self.assertEquals(ar.sent, ar._tracker.done)
170 175 ar._tracker.wait()
171 176 self.assertTrue(ar.sent)
172 177
173 178 def test_push_tracked(self):
174 179 t = self.client.ids[-1]
175 180 ns = dict(x='x'*1024*1024)
176 181 v = self.client[t]
177 182 ar = v.push(ns, block=False, track=False)
178 183 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
179 184 self.assertTrue(ar.sent)
180 185
181 186 ar = v.push(ns, block=False, track=True)
182 187 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
183 188 self.assertEquals(ar.sent, ar._tracker.done)
184 189 ar._tracker.wait()
185 190 self.assertTrue(ar.sent)
186 191 ar.get()
187 192
188 193 def test_scatter_tracked(self):
189 194 t = self.client.ids
190 195 x='x'*1024*1024
191 196 ar = self.client[t].scatter('x', x, block=False, track=False)
192 197 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
193 198 self.assertTrue(ar.sent)
194 199
195 200 ar = self.client[t].scatter('x', x, block=False, track=True)
196 201 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
197 202 self.assertEquals(ar.sent, ar._tracker.done)
198 203 ar._tracker.wait()
199 204 self.assertTrue(ar.sent)
200 205 ar.get()
201 206
202 207 def test_remote_reference(self):
203 208 v = self.client[-1]
204 209 v['a'] = 123
205 210 ra = pmod.Reference('a')
206 211 b = v.apply_sync(lambda x: x, ra)
207 212 self.assertEquals(b, 123)
208 213
209 214
210 215 def test_scatter_gather(self):
211 216 view = self.client[:]
212 217 seq1 = range(16)
213 218 view.scatter('a', seq1)
214 219 seq2 = view.gather('a', block=True)
215 220 self.assertEquals(seq2, seq1)
216 221 self.assertRaisesRemote(NameError, view.gather, 'asdf', block=True)
217 222
218 223 @skip_without('numpy')
219 224 def test_scatter_gather_numpy(self):
220 225 import numpy
221 226 from numpy.testing.utils import assert_array_equal, assert_array_almost_equal
222 227 view = self.client[:]
223 228 a = numpy.arange(64)
224 229 view.scatter('a', a)
225 230 b = view.gather('a', block=True)
226 231 assert_array_equal(b, a)
227 232
228 233 def test_map(self):
229 234 view = self.client[:]
230 235 def f(x):
231 236 return x**2
232 237 data = range(16)
233 238 r = view.map_sync(f, data)
234 239 self.assertEquals(r, map(f, data))
235 240
236 241 def test_scatterGatherNonblocking(self):
237 242 data = range(16)
238 243 view = self.client[:]
239 244 view.scatter('a', data, block=False)
240 245 ar = view.gather('a', block=False)
241 246 self.assertEquals(ar.get(), data)
242 247
243 248 @skip_without('numpy')
244 249 def test_scatter_gather_numpy_nonblocking(self):
245 250 import numpy
246 251 from numpy.testing.utils import assert_array_equal, assert_array_almost_equal
247 252 a = numpy.arange(64)
248 253 view = self.client[:]
249 254 ar = view.scatter('a', a, block=False)
250 255 self.assertTrue(isinstance(ar, AsyncResult))
251 256 amr = view.gather('a', block=False)
252 257 self.assertTrue(isinstance(amr, AsyncMapResult))
253 258 assert_array_equal(amr.get(), a)
254 259
255 260 def test_execute(self):
256 261 view = self.client[:]
257 262 # self.client.debug=True
258 263 execute = view.execute
259 264 ar = execute('c=30', block=False)
260 265 self.assertTrue(isinstance(ar, AsyncResult))
261 266 ar = execute('d=[0,1,2]', block=False)
262 267 self.client.wait(ar, 1)
263 268 self.assertEquals(len(ar.get()), len(self.client))
264 269 for c in view['c']:
265 270 self.assertEquals(c, 30)
266 271
267 272 def test_abort(self):
268 273 view = self.client[-1]
269 274 ar = view.execute('import time; time.sleep(0.25)', block=False)
270 275 ar2 = view.apply_async(lambda : 2)
271 276 ar3 = view.apply_async(lambda : 3)
272 277 view.abort(ar2)
273 278 view.abort(ar3.msg_ids)
274 279 self.assertRaises(error.TaskAborted, ar2.get)
275 280 self.assertRaises(error.TaskAborted, ar3.get)
276 281
277 282 def test_temp_flags(self):
278 283 view = self.client[-1]
279 284 view.block=True
280 285 with view.temp_flags(block=False):
281 286 self.assertFalse(view.block)
282 287 self.assertTrue(view.block)
283 288
284 289 def test_importer(self):
285 290 view = self.client[-1]
286 291 view.clear(block=True)
287 292 with view.importer:
288 293 import re
289 294
290 295 @interactive
291 296 def findall(pat, s):
292 297 # this globals() step isn't necessary in real code
293 298 # only to prevent a closure in the test
294 299 re = globals()['re']
295 300 return re.findall(pat, s)
296 301
297 302 self.assertEquals(view.apply_sync(findall, '\w+', 'hello world'), 'hello world'.split())
298 303
299 304 # parallel magic tests
300 305
301 306 def test_magic_px_blocking(self):
302 307 ip = get_ipython()
303 308 v = self.client[-1]
304 309 v.activate()
305 310 v.block=True
306 311
307 312 ip.magic_px('a=5')
308 313 self.assertEquals(v['a'], 5)
309 314 ip.magic_px('a=10')
310 315 self.assertEquals(v['a'], 10)
311 316 sio = StringIO()
312 317 savestdout = sys.stdout
313 318 sys.stdout = sio
314 319 ip.magic_px('print a')
315 320 sys.stdout = savestdout
316 321 sio.read()
317 322 self.assertTrue('[stdout:%i]'%v.targets in sio.buf)
318 323 self.assertTrue(sio.buf.rstrip().endswith('10'))
319 324 self.assertRaisesRemote(ZeroDivisionError, ip.magic_px, '1/0')
320 325
321 326 def test_magic_px_nonblocking(self):
322 327 ip = get_ipython()
323 328 v = self.client[-1]
324 329 v.activate()
325 330 v.block=False
326 331
327 332 ip.magic_px('a=5')
328 333 self.assertEquals(v['a'], 5)
329 334 ip.magic_px('a=10')
330 335 self.assertEquals(v['a'], 10)
331 336 sio = StringIO()
332 337 savestdout = sys.stdout
333 338 sys.stdout = sio
334 339 ip.magic_px('print a')
335 340 sys.stdout = savestdout
336 341 sio.read()
337 342 self.assertFalse('[stdout:%i]'%v.targets in sio.buf)
338 343 ip.magic_px('1/0')
339 344 ar = v.get_result(-1)
340 345 self.assertRaisesRemote(ZeroDivisionError, ar.get)
341 346
342 347 def test_magic_autopx_blocking(self):
343 348 ip = get_ipython()
344 349 v = self.client[-1]
345 350 v.activate()
346 351 v.block=True
347 352
348 353 sio = StringIO()
349 354 savestdout = sys.stdout
350 355 sys.stdout = sio
351 356 ip.magic_autopx()
352 357 ip.run_cell('\n'.join(('a=5','b=10','c=0')))
353 358 ip.run_cell('print b')
354 359 ip.run_cell("b/c")
355 360 ip.run_code(compile('b*=2', '', 'single'))
356 361 ip.magic_autopx()
357 362 sys.stdout = savestdout
358 363 sio.read()
359 364 output = sio.buf.strip()
360 365 self.assertTrue(output.startswith('%autopx enabled'))
361 366 self.assertTrue(output.endswith('%autopx disabled'))
362 367 self.assertTrue('RemoteError: ZeroDivisionError' in output)
363 368 ar = v.get_result(-2)
364 369 self.assertEquals(v['a'], 5)
365 370 self.assertEquals(v['b'], 20)
366 371 self.assertRaisesRemote(ZeroDivisionError, ar.get)
367 372
368 373 def test_magic_autopx_nonblocking(self):
369 374 ip = get_ipython()
370 375 v = self.client[-1]
371 376 v.activate()
372 377 v.block=False
373 378
374 379 sio = StringIO()
375 380 savestdout = sys.stdout
376 381 sys.stdout = sio
377 382 ip.magic_autopx()
378 383 ip.run_cell('\n'.join(('a=5','b=10','c=0')))
379 384 ip.run_cell('print b')
380 385 ip.run_cell("b/c")
381 386 ip.run_code(compile('b*=2', '', 'single'))
382 387 ip.magic_autopx()
383 388 sys.stdout = savestdout
384 389 sio.read()
385 390 output = sio.buf.strip()
386 391 self.assertTrue(output.startswith('%autopx enabled'))
387 392 self.assertTrue(output.endswith('%autopx disabled'))
388 393 self.assertFalse('ZeroDivisionError' in output)
389 394 ar = v.get_result(-2)
390 395 self.assertEquals(v['a'], 5)
391 396 self.assertEquals(v['b'], 20)
392 397 self.assertRaisesRemote(ZeroDivisionError, ar.get)
393 398
394 399 def test_magic_result(self):
395 400 ip = get_ipython()
396 401 v = self.client[-1]
397 402 v.activate()
398 403 v['a'] = 111
399 404 ra = v['a']
400 405
401 406 ar = ip.magic_result()
402 407 self.assertEquals(ar.msg_ids, [v.history[-1]])
403 408 self.assertEquals(ar.get(), 111)
404 409 ar = ip.magic_result('-2')
405 410 self.assertEquals(ar.msg_ids, [v.history[-2]])
406 411
407 412 def test_unicode_execute(self):
408 413 """test executing unicode strings"""
409 414 v = self.client[-1]
410 415 v.block=True
411 416 code=u"a=u'é'"
412 417 v.execute(code)
413 418 self.assertEquals(v['a'], u'é')
414 419
415 420 def test_unicode_apply_result(self):
416 421 """test unicode apply results"""
417 422 v = self.client[-1]
418 423 r = v.apply_sync(lambda : u'é')
419 424 self.assertEquals(r, u'é')
420 425
421 426 def test_unicode_apply_arg(self):
422 427 """test passing unicode arguments to apply"""
423 428 v = self.client[-1]
424 429
425 430 @interactive
426 431 def check_unicode(a, check):
427 432 assert isinstance(a, unicode), "%r is not unicode"%a
428 433 assert isinstance(check, bytes), "%r is not bytes"%check
429 434 assert a.encode('utf8') == check, "%s != %s"%(a,check)
430 435
431 436 for s in [ u'é', u'ßø®∫','asdf'.decode() ]:
432 437 try:
433 438 v.apply_sync(check_unicode, s, s.encode('utf8'))
434 439 except error.RemoteError as e:
435 440 if e.ename == 'AssertionError':
436 441 self.fail(e.evalue)
437 442 else:
438 443 raise e
439 444
440 445
441 446
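
Finally, a minimal sketch of the DirectView operations test_view.py exercises (push/pull, scatter/gather, and Reference), assuming a running 'iptest' cluster:

    from IPython.parallel import Client, Reference

    rc = Client(profile='iptest')
    dv = rc[:]                         # a DirectView on all engines
    dv.block = True
    dv.push(dict(a=10))
    print dv.pull('a')                 # one value per engine: [10, 10, ...]
    dv.scatter('x', range(16))         # partition a sequence across engines
    print dv.gather('x')               # gather reassembles: [0, 1, ..., 15]
    print dv.apply_sync(lambda y: y, Reference('a'))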