all ipcluster scripts in some degree of working order with new config
MinRK
@@ -1,566 +1,492
1 1 #!/usr/bin/env python
2 2 # encoding: utf-8
3 3 """
4 4 The IPython cluster directory
5 5 """
6 6
7 7 #-----------------------------------------------------------------------------
8 8 # Copyright (C) 2008-2009 The IPython Development Team
9 9 #
10 10 # Distributed under the terms of the BSD License. The full license is in
11 11 # the file COPYING, distributed as part of this software.
12 12 #-----------------------------------------------------------------------------
13 13
14 14 #-----------------------------------------------------------------------------
15 15 # Imports
16 16 #-----------------------------------------------------------------------------
17 17
18 18 from __future__ import with_statement
19 19
20 20 import os
21 21 import logging
22 22 import re
23 23 import shutil
24 24 import sys
25 25
26 26 from subprocess import Popen, PIPE
27 27
28 from IPython.config.loader import PyFileConfigLoader
28 from IPython.config.loader import PyFileConfigLoader, Config
29 29 from IPython.config.configurable import Configurable
30 from IPython.core.application import Application, BaseAppConfigLoader
30 from IPython.config.application import Application
31 31 from IPython.core.crashhandler import CrashHandler
32 from IPython.core.newapplication import BaseIPythonApplication
32 33 from IPython.core import release
33 34 from IPython.utils.path import (
34 35 get_ipython_package_dir,
36 get_ipython_dir,
35 37 expand_path
36 38 )
37 from IPython.utils.traitlets import Unicode
39 from IPython.utils.traitlets import Unicode, Bool, CStr, Instance, Dict
38 40
39 41 #-----------------------------------------------------------------------------
40 42 # Module errors
41 43 #-----------------------------------------------------------------------------
42 44
43 45 class ClusterDirError(Exception):
44 46 pass
45 47
46 48
47 49 class PIDFileError(Exception):
48 50 pass
49 51
50 52
51 53 #-----------------------------------------------------------------------------
52 54 # Class for managing cluster directories
53 55 #-----------------------------------------------------------------------------
54 56
55 57 class ClusterDir(Configurable):
56 58 """An object to manage the cluster directory and its resources.
57 59
58 60 The cluster directory is used by :command:`ipengine`,
59 61 :command:`ipcontroller` and :command:`ipcluster` to manage the
60 62 configuration, logging and security of these applications.
61 63
62 64 This object knows how to find, create and manage these directories. This
63 65 should be used by any code that wants to handle cluster directories.
64 66 """
65 67
66 68 security_dir_name = Unicode('security')
67 69 log_dir_name = Unicode('log')
68 70 pid_dir_name = Unicode('pid')
69 71 security_dir = Unicode(u'')
70 72 log_dir = Unicode(u'')
71 73 pid_dir = Unicode(u'')
72 location = Unicode(u'')
73 74
74 def __init__(self, location=u''):
75 super(ClusterDir, self).__init__(location=location)
75 location = Unicode(u'', config=True,
76 help="""Set the cluster dir. This overrides the logic used by the
77 `profile` option.""",
78 )
79 profile = Unicode(u'default',
80 help="""The string name of the profile to be used. This determines the name
81 of the cluster dir as: cluster_<profile>. The default profile is named
82 'default'. The cluster directory is resolved this way if the
83 `cluster_dir` option is not used.""", config=True
84 )
85
86 _location_isset = Bool(False) # flag for detecting multiply set location
87 _new_dir = Bool(False) # flag for whether a new dir was created
88
89 def __init__(self, **kwargs):
90 super(ClusterDir, self).__init__(**kwargs)
91 if not self.location:
92 self._profile_changed('profile', 'default', self.profile)
76 93
77 94 def _location_changed(self, name, old, new):
95 if self._location_isset:
96 raise RuntimeError("Cannot set ClusterDir more than once.")
97 self._location_isset = True
78 98 if not os.path.isdir(new):
79 99 os.makedirs(new)
100 self._new_dir = True
101 # ensure config files exist:
102 self.copy_all_config_files(overwrite=False)
80 103 self.security_dir = os.path.join(new, self.security_dir_name)
81 104 self.log_dir = os.path.join(new, self.log_dir_name)
82 105 self.pid_dir = os.path.join(new, self.pid_dir_name)
83 106 self.check_dirs()
84 107
108 def _profile_changed(self, name, old, new):
109 if self._location_isset:
110 raise RuntimeError("ClusterDir already set. Cannot set by profile.")
111 self.location = os.path.join(get_ipython_dir(), 'cluster_'+new)
112
85 113 def _log_dir_changed(self, name, old, new):
86 114 self.check_log_dir()
87 115
88 116 def check_log_dir(self):
89 117 if not os.path.isdir(self.log_dir):
90 118 os.mkdir(self.log_dir)
91 119
92 120 def _security_dir_changed(self, name, old, new):
93 121 self.check_security_dir()
94 122
95 123 def check_security_dir(self):
96 124 if not os.path.isdir(self.security_dir):
97 125 os.mkdir(self.security_dir, 0700)
98 126 os.chmod(self.security_dir, 0700)
99 127
100 128 def _pid_dir_changed(self, name, old, new):
101 129 self.check_pid_dir()
102 130
103 131 def check_pid_dir(self):
104 132 if not os.path.isdir(self.pid_dir):
105 133 os.mkdir(self.pid_dir, 0700)
106 134 os.chmod(self.pid_dir, 0700)
107 135
108 136 def check_dirs(self):
109 137 self.check_security_dir()
110 138 self.check_log_dir()
111 139 self.check_pid_dir()
112 140
113 def load_config_file(self, filename):
114 """Load a config file from the top level of the cluster dir.
115
116 Parameters
117 ----------
118 filename : unicode or str
119 The filename only of the config file that must be located in
120 the top-level of the cluster directory.
121 """
122 loader = PyFileConfigLoader(filename, self.location)
123 return loader.load_config()
124
125 141 def copy_config_file(self, config_file, path=None, overwrite=False):
126 142 """Copy a default config file into the active cluster directory.
127 143
128 144 Default configuration files are kept in :mod:`IPython.config.default`.
129 145 This function moves these from that location to the working cluster
130 146 directory.
131 147 """
132 148 if path is None:
133 149 import IPython.config.default
134 150 path = IPython.config.default.__file__.split(os.path.sep)[:-1]
135 151 path = os.path.sep.join(path)
136 152 src = os.path.join(path, config_file)
137 153 dst = os.path.join(self.location, config_file)
138 154 if not os.path.isfile(dst) or overwrite:
139 155 shutil.copy(src, dst)
140 156
141 157 def copy_all_config_files(self, path=None, overwrite=False):
142 158 """Copy all config files into the active cluster directory."""
143 159 for f in [u'ipcontroller_config.py', u'ipengine_config.py',
144 160 u'ipcluster_config.py']:
145 161 self.copy_config_file(f, path=path, overwrite=overwrite)
146 162
147 163 @classmethod
148 164 def create_cluster_dir(cls, cluster_dir):
149 165 """Create a new cluster directory given a full path.
150 166
151 167 Parameters
152 168 ----------
153 169 cluster_dir : str
154 170 The full path to the cluster directory. If it does exist, it will
155 171 be used. If not, it will be created.
156 172 """
157 173 return ClusterDir(location=cluster_dir)
158 174
159 175 @classmethod
160 176 def create_cluster_dir_by_profile(cls, path, profile=u'default'):
161 177 """Create a cluster dir by profile name and path.
162 178
163 179 Parameters
164 180 ----------
165 181 path : str
166 182 The path (directory) to put the cluster directory in.
167 183 profile : str
168 184 The name of the profile. The name of the cluster directory will
169 185 be "cluster_<profile>".
170 186 """
171 187 if not os.path.isdir(path):
172 188 raise ClusterDirError('Directory not found: %s' % path)
173 189 cluster_dir = os.path.join(path, u'cluster_' + profile)
174 190 return ClusterDir(location=cluster_dir)
175 191
176 192 @classmethod
177 193 def find_cluster_dir_by_profile(cls, ipython_dir, profile=u'default'):
178 194 """Find an existing cluster dir by profile name, return its ClusterDir.
179 195
180 196 This searches through a sequence of paths for a cluster dir. If it
181 197 is not found, a :class:`ClusterDirError` exception will be raised.
182 198
183 199 The search path algorithm is:
184 200 1. ``os.getcwd()``
185 201 2. ``ipython_dir``
186 202 3. The directories found in the ":" separated
187 203 :env:`IPCLUSTER_DIR_PATH` environment variable.
188 204
189 205 Parameters
190 206 ----------
191 207 ipython_dir : unicode or str
192 208 The IPython directory to use.
193 209 profile : unicode or str
194 210 The name of the profile. The name of the cluster directory
195 211 will be "cluster_<profile>".
196 212 """
197 213 dirname = u'cluster_' + profile
198 214 cluster_dir_paths = os.environ.get('IPCLUSTER_DIR_PATH','')
199 215 if cluster_dir_paths:
200 216 cluster_dir_paths = cluster_dir_paths.split(':')
201 217 else:
202 218 cluster_dir_paths = []
203 219 paths = [os.getcwd(), ipython_dir] + cluster_dir_paths
204 220 for p in paths:
205 221 cluster_dir = os.path.join(p, dirname)
206 222 if os.path.isdir(cluster_dir):
207 223 return ClusterDir(location=cluster_dir)
208 224 else:
209 225 raise ClusterDirError('Cluster directory not found in paths: %s' % dirname)
210 226
211 227 @classmethod
212 228 def find_cluster_dir(cls, cluster_dir):
213 229 """Find/create a cluster dir and return its ClusterDir.
214 230
215 231 This will create the cluster directory if it doesn't exist.
216 232
217 233 Parameters
218 234 ----------
219 235 cluster_dir : unicode or str
220 236 The path of the cluster directory. This is expanded using
221 237 :func:`IPython.utils.genutils.expand_path`.
222 238 """
223 239 cluster_dir = expand_path(cluster_dir)
224 240 if not os.path.isdir(cluster_dir):
225 241 raise ClusterDirError('Cluster directory not found: %s' % cluster_dir)
226 242 return ClusterDir(location=cluster_dir)
227 243
228 244
229 245 #-----------------------------------------------------------------------------
230 # Command line options
231 #-----------------------------------------------------------------------------
232
233 class ClusterDirConfigLoader(BaseAppConfigLoader):
234
235 def _add_cluster_profile(self, parser):
236 paa = parser.add_argument
237 paa('-p', '--profile',
238 dest='Global.profile',type=unicode,
239 help=
240 """The string name of the profile to be used. This determines the name
241 of the cluster dir as: cluster_<profile>. The default profile is named
242 'default'. The cluster directory is resolve this way if the
243 --cluster-dir option is not used.""",
244 metavar='Global.profile')
245
246 def _add_cluster_dir(self, parser):
247 paa = parser.add_argument
248 paa('--cluster-dir',
249 dest='Global.cluster_dir',type=unicode,
250 help="""Set the cluster dir. This overrides the logic used by the
251 --profile option.""",
252 metavar='Global.cluster_dir')
253
254 def _add_work_dir(self, parser):
255 paa = parser.add_argument
256 paa('--work-dir',
257 dest='Global.work_dir',type=unicode,
258 help='Set the working dir for the process.',
259 metavar='Global.work_dir')
260
261 def _add_clean_logs(self, parser):
262 paa = parser.add_argument
263 paa('--clean-logs',
264 dest='Global.clean_logs', action='store_true',
265 help='Delete old log flies before starting.')
266
267 def _add_no_clean_logs(self, parser):
268 paa = parser.add_argument
269 paa('--no-clean-logs',
270 dest='Global.clean_logs', action='store_false',
271 help="Don't Delete old log flies before starting.")
272
273 def _add_arguments(self):
274 super(ClusterDirConfigLoader, self)._add_arguments()
275 self._add_cluster_profile(self.parser)
276 self._add_cluster_dir(self.parser)
277 self._add_work_dir(self.parser)
278 self._add_clean_logs(self.parser)
279 self._add_no_clean_logs(self.parser)
280
281
282 #-----------------------------------------------------------------------------
283 246 # Crash handler for this application
284 247 #-----------------------------------------------------------------------------
285 248
286 249
287 250 _message_template = """\
288 251 Oops, $self.app_name crashed. We do our best to make it stable, but...
289 252
290 253 A crash report was automatically generated with the following information:
291 254 - A verbatim copy of the crash traceback.
292 255 - Data on your current $self.app_name configuration.
293 256
294 257 It was left in the file named:
295 258 \t'$self.crash_report_fname'
296 259 If you can email this file to the developers, the information in it will help
297 260 them in understanding and correcting the problem.
298 261
299 262 You can mail it to: $self.contact_name at $self.contact_email
300 263 with the subject '$self.app_name Crash Report'.
301 264
302 265 If you want to do it now, the following command will work (under Unix):
303 266 mail -s '$self.app_name Crash Report' $self.contact_email < $self.crash_report_fname
304 267
305 268 To ensure accurate tracking of this issue, please file a report about it at:
306 269 $self.bug_tracker
307 270 """
308 271
309 272 class ClusterDirCrashHandler(CrashHandler):
310 273 """sys.excepthook for IPython itself, leaves a detailed report on disk."""
311 274
312 275 message_template = _message_template
313 276
314 277 def __init__(self, app):
315 contact_name = release.authors['Brian'][0]
316 contact_email = release.authors['Brian'][1]
278 contact_name = release.authors['Min'][0]
279 contact_email = release.authors['Min'][1]
317 280 bug_tracker = 'http://github.com/ipython/ipython/issues'
318 281 super(ClusterDirCrashHandler,self).__init__(
319 282 app, contact_name, contact_email, bug_tracker
320 283 )
321 284
322 285
323 286 #-----------------------------------------------------------------------------
324 287 # Main application
325 288 #-----------------------------------------------------------------------------
326
327 class ApplicationWithClusterDir(Application):
289 base_aliases = {
290 'profile' : "ClusterDir.profile",
291 'cluster_dir' : 'ClusterDir.location',
292 'log_level' : 'Application.log_level',
293 'work_dir' : 'ClusterDirApplication.work_dir',
294 'log_to_file' : 'ClusterDirApplication.log_to_file',
295 'clean_logs' : 'ClusterDirApplication.clean_logs',
296 'log_url' : 'ClusterDirApplication.log_url',
297 }
298
299 base_flags = {
300 'debug' : ( {"Application" : {"log_level" : logging.DEBUG}}, "set loglevel to DEBUG"),
301 'clean-logs' : ( {"ClusterDirApplication" : {"clean_logs" : True}}, "cleanup old logfiles"),
302 'log-to-file' : ( {"ClusterDirApplication" : {"log_to_file" : True}}, "log to a file")
303 }
304 for k,v in base_flags.iteritems():
305 base_flags[k] = (Config(v[0]),v[1])
306
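The alias and flag tables above replace the old argparse-based ClusterDirConfigLoader: an alias maps a command-line 'name=value' pair onto a 'Class.trait' assignment, while a flag maps a bare '--name' switch onto a preset Config fragment. Below is a minimal sketch (not part of the commit) of that translation using plain dicts; apply_argv is a hypothetical stand-in for the traitlets-based loader, and values are left as strings here where the real loader would coerce them.

    def apply_argv(argv, aliases, flags):
        """Translate 'name=value' and '--flag' arguments into a nested config dict."""
        config = {}
        for arg in argv:
            if arg.startswith('--') and arg[2:] in flags:
                fragment, _help = flags[arg[2:]]           # e.g. {"ClusterDirApplication": {"clean_logs": True}}
                for cls, traits in fragment.items():
                    config.setdefault(cls, {}).update(traits)
            elif '=' in arg:
                name, value = arg.split('=', 1)
                cls, trait = aliases[name].split('.', 1)   # e.g. 'ClusterDir.profile'
                config.setdefault(cls, {})[trait] = value
        return config

    aliases = {'profile': 'ClusterDir.profile', 'cluster_dir': 'ClusterDir.location'}
    flags = {'clean-logs': ({'ClusterDirApplication': {'clean_logs': True}}, 'cleanup old logfiles')}
    print(apply_argv(['profile=mycluster', '--clean-logs'], aliases, flags))
    # {'ClusterDir': {'profile': 'mycluster'}, 'ClusterDirApplication': {'clean_logs': True}}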
307 class ClusterDirApplication(BaseIPythonApplication):
328 308 """An application that puts everything into a cluster directory.
329 309
330 310 Instead of looking for things in the ipython_dir, this type of application
331 311 will use its own private directory called the "cluster directory"
332 312 for things like config files, log files, etc.
333 313
334 314 The cluster directory is resolved as follows:
335 315
336 316 * If the ``--cluster-dir`` option is given, it is used.
337 317 * If ``--cluster-dir`` is not given, the application directory is
338 318 resolved using the profile name as ``cluster_<profile>``. The search
339 319 path for this directory is then i) the cwd, if it is found there,
340 320 and ii) the ipython_dir otherwise.
341 321
342 322 The config file for the application is to be put in the cluster
343 323 dir and named the value of the ``config_file_name`` class attribute.
344 324 """
345 325
346 command_line_loader = ClusterDirConfigLoader
347 326 crash_handler_class = ClusterDirCrashHandler
348 auto_create_cluster_dir = True
327 auto_create_cluster_dir = Bool(True, config=True,
328 help="whether to create the cluster_dir if it doesn't exist")
349 329 # temporarily override default_log_level to INFO
350 330 default_log_level = logging.INFO
331 cluster_dir = Instance(ClusterDir)
332
333 work_dir = Unicode(os.getcwdu(), config=True,
334 help='Set the working dir for the process.'
335 )
336 def _work_dir_changed(self, name, old, new):
337 self.work_dir = unicode(expand_path(new))
338
339 log_to_file = Bool(config=True,
340 help="whether to log to a file")
341
342 clean_logs = Bool(True, shortname='--clean-logs', config=True,
343 help="whether to cleanup old logfiles before starting")
351 344
352 def create_default_config(self):
353 super(ApplicationWithClusterDir, self).create_default_config()
354 self.default_config.Global.profile = u'default'
355 self.default_config.Global.cluster_dir = u''
356 self.default_config.Global.work_dir = os.getcwd()
357 self.default_config.Global.log_to_file = False
358 self.default_config.Global.log_url = None
359 self.default_config.Global.clean_logs = False
345 log_url = CStr('', shortname='--log-url', config=True,
346 help="The ZMQ URL of the iplooger to aggregate logging.")
360 347
361 def find_resources(self):
348 config_file = Unicode(u'', config=True,
349 help="""Path to ipcontroller configuration file. The default is to use
350 <appname>_config.py, as found by cluster-dir."""
351 )
352
353 aliases = Dict(base_aliases)
354 flags = Dict(base_flags)
355
356 def init_clusterdir(self):
362 357 """This resolves the cluster directory.
363 358
364 359 This tries to find the cluster directory and if successful, it will
365 360 have done:
366 361 * Sets ``self.cluster_dir`` to the :class:`ClusterDir` object for
367 362 the application.
368 363 * Ensures the cluster directory exists on disk, along with its
369 364 security, log and pid subdirectories.
370 365
371 366 The algorithm used for this is as follows:
372 367 1. Try ``ClusterDir.location``.
373 368 2. Try using ``ClusterDir.profile``.
374 369 3. If both of these fail and ``self.auto_create_cluster_dir`` is
375 370 ``True``, then create the new cluster dir in the IPython directory.
376 371 4. If all fails, then raise :class:`ClusterDirError`.
377 372 """
378
379 try:
380 cluster_dir = self.command_line_config.Global.cluster_dir
381 except AttributeError:
382 cluster_dir = self.default_config.Global.cluster_dir
383 cluster_dir = expand_path(cluster_dir)
384 try:
385 self.cluster_dir_obj = ClusterDir.find_cluster_dir(cluster_dir)
386 except ClusterDirError:
387 pass
388 else:
389 self.log.info('Using existing cluster dir: %s' % \
390 self.cluster_dir_obj.location
391 )
392 self.finish_cluster_dir()
393 return
394
395 try:
396 self.profile = self.command_line_config.Global.profile
397 except AttributeError:
398 self.profile = self.default_config.Global.profile
399 try:
400 self.cluster_dir_obj = ClusterDir.find_cluster_dir_by_profile(
401 self.ipython_dir, self.profile)
402 except ClusterDirError:
403 pass
404 else:
405 self.log.info('Using existing cluster dir: %s' % \
406 self.cluster_dir_obj.location
407 )
408 self.finish_cluster_dir()
409 return
410
411 if self.auto_create_cluster_dir:
412 self.cluster_dir_obj = ClusterDir.create_cluster_dir_by_profile(
413 self.ipython_dir, self.profile
414 )
373 self.cluster_dir = ClusterDir(config=self.config)
374 if self.cluster_dir._new_dir:
415 375 self.log.info('Creating new cluster dir: %s' % \
416 self.cluster_dir_obj.location
417 )
418 self.finish_cluster_dir()
376 self.cluster_dir.location)
419 377 else:
420 raise ClusterDirError('Could not find a valid cluster directory.')
421
422 def finish_cluster_dir(self):
423 # Set the cluster directory
424 self.cluster_dir = self.cluster_dir_obj.location
425
426 # These have to be set because they could be different from the one
427 # that we just computed. Because command line has the highest
428 # priority, this will always end up in the master_config.
429 self.default_config.Global.cluster_dir = self.cluster_dir
430 self.command_line_config.Global.cluster_dir = self.cluster_dir
431
432 def find_config_file_name(self):
433 """Find the config file name for this application."""
434 # For this type of Application it should be set as a class attribute.
435 if not hasattr(self, 'default_config_file_name'):
436 self.log.critical("No config filename found")
437 else:
438 self.config_file_name = self.default_config_file_name
439
440 def find_config_file_paths(self):
441 # Set the search path to to the cluster directory. We should NOT
442 # include IPython.config.default here as the default config files
443 # are ALWAYS automatically moved to the cluster directory.
444 conf_dir = os.path.join(get_ipython_package_dir(), 'config', 'default')
445 self.config_file_paths = (self.cluster_dir,)
446
447 def pre_construct(self):
448 # The log and security dirs were set earlier, but here we put them
449 # into the config and log them.
450 config = self.master_config
451 sdir = self.cluster_dir_obj.security_dir
452 self.security_dir = config.Global.security_dir = sdir
453 ldir = self.cluster_dir_obj.log_dir
454 self.log_dir = config.Global.log_dir = ldir
455 pdir = self.cluster_dir_obj.pid_dir
456 self.pid_dir = config.Global.pid_dir = pdir
457 self.log.info("Cluster directory set to: %s" % self.cluster_dir)
458 config.Global.work_dir = unicode(expand_path(config.Global.work_dir))
459 # Change to the working directory. We do this just before construct
460 # is called so all the components there have the right working dir.
461 self.to_work_dir()
378 self.log.info('Using existing cluster dir: %s' % \
379 self.cluster_dir.location)
462 380
463 381 def to_work_dir(self):
464 wd = self.master_config.Global.work_dir
465 if unicode(wd) != unicode(os.getcwd()):
382 wd = self.work_dir
383 if unicode(wd) != os.getcwdu():
466 384 os.chdir(wd)
467 385 self.log.info("Changing to working dir: %s" % wd)
468 386
469 def start_logging(self):
470 # Remove old log files
471 if self.master_config.Global.clean_logs:
472 log_dir = self.master_config.Global.log_dir
473 for f in os.listdir(log_dir):
474 if re.match(r'%s-\d+\.(log|err|out)'%self.name,f):
475 # if f.startswith(self.name + u'-') and f.endswith('.log'):
476 os.remove(os.path.join(log_dir, f))
477 # Start logging to the new log file
478 if self.master_config.Global.log_to_file:
479 log_filename = self.name + u'-' + str(os.getpid()) + u'.log'
480 logfile = os.path.join(self.log_dir, log_filename)
481 open_log_file = open(logfile, 'w')
482 elif self.master_config.Global.log_url:
483 open_log_file = None
484 else:
485 open_log_file = sys.stdout
486 if open_log_file is not None:
487 self.log.removeHandler(self._log_handler)
488 self._log_handler = logging.StreamHandler(open_log_file)
489 self._log_formatter = logging.Formatter("[%(name)s] %(message)s")
490 self._log_handler.setFormatter(self._log_formatter)
491 self.log.addHandler(self._log_handler)
492 # log.startLogging(open_log_file)
387 def load_config_file(self, filename, path=None):
388 """Load a .py based config file by filename and path."""
389 return Application.load_config_file(self, filename, path=path)
390 #
391 # def load_default_config_file(self):
392 # """Load a .py based config file by filename and path."""
393 # return BaseIPythonApplication.load_config_file(self)
394
395 # disable URL-logging
396 # def init_logging(self):
397 # # Remove old log files
398 # if self.master_config.Global.clean_logs:
399 # log_dir = self.master_config.Global.log_dir
400 # for f in os.listdir(log_dir):
401 # if re.match(r'%s-\d+\.(log|err|out)'%self.name,f):
402 # # if f.startswith(self.name + u'-') and f.endswith('.log'):
403 # os.remove(os.path.join(log_dir, f))
404 # # Start logging to the new log file
405 # if self.master_config.Global.log_to_file:
406 # log_filename = self.name + u'-' + str(os.getpid()) + u'.log'
407 # logfile = os.path.join(self.log_dir, log_filename)
408 # open_log_file = open(logfile, 'w')
409 # elif self.master_config.Global.log_url:
410 # open_log_file = None
411 # else:
412 # open_log_file = sys.stdout
413 # if open_log_file is not None:
414 # self.log.removeHandler(self._log_handler)
415 # self._log_handler = logging.StreamHandler(open_log_file)
416 # self._log_formatter = logging.Formatter("[%(name)s] %(message)s")
417 # self._log_handler.setFormatter(self._log_formatter)
418 # self.log.addHandler(self._log_handler)
419 # # log.startLogging(open_log_file)
493 420
494 421 def write_pid_file(self, overwrite=False):
495 422 """Create a .pid file in the pid_dir with my pid.
496 423
497 424 This must be called after the cluster dir is resolved, since it uses `self.cluster_dir.pid_dir`.
498 425 This raises :exc:`PIDFileError` if the pid file exists already.
499 426 """
500 pid_file = os.path.join(self.pid_dir, self.name + u'.pid')
427 pid_file = os.path.join(self.cluster_dir.pid_dir, self.name + u'.pid')
501 428 if os.path.isfile(pid_file):
502 429 pid = self.get_pid_from_file()
503 430 if not overwrite:
504 431 raise PIDFileError(
505 432 'The pid file [%s] already exists. \nThis could mean that this '
506 433 'server is already running with [pid=%s].' % (pid_file, pid)
507 434 )
508 435 with open(pid_file, 'w') as f:
509 436 self.log.info("Creating pid file: %s" % pid_file)
510 437 f.write(repr(os.getpid())+'\n')
511 438
512 439 def remove_pid_file(self):
513 440 """Remove the pid file.
514 441
515 442 This should be called at shutdown by registering a callback with
516 443 :func:`reactor.addSystemEventTrigger`. This needs to return
517 444 ``None``.
518 445 """
519 pid_file = os.path.join(self.pid_dir, self.name + u'.pid')
446 pid_file = os.path.join(self.cluster_dir.pid_dir, self.name + u'.pid')
520 447 if os.path.isfile(pid_file):
521 448 try:
522 449 self.log.info("Removing pid file: %s" % pid_file)
523 450 os.remove(pid_file)
524 451 except:
525 452 self.log.warn("Error removing the pid file: %s" % pid_file)
526 453
527 454 def get_pid_from_file(self):
528 455 """Get the pid from the pid file.
529 456
530 457 If the pid file doesn't exist a :exc:`PIDFileError` is raised.
531 458 """
532 pid_file = os.path.join(self.pid_dir, self.name + u'.pid')
459 pid_file = os.path.join(self.cluster_dir.pid_dir, self.name + u'.pid')
533 460 if os.path.isfile(pid_file):
534 461 with open(pid_file, 'r') as f:
535 462 pid = int(f.read().strip())
536 463 return pid
537 464 else:
538 465 raise PIDFileError('pid file not found: %s' % pid_file)
539 466
540 467 def check_pid(self, pid):
541 468 if os.name == 'nt':
542 469 try:
543 470 import ctypes
544 471 # returns 0 if no such process (of ours) exists
545 472 # positive int otherwise
546 473 p = ctypes.windll.kernel32.OpenProcess(1,0,pid)
547 474 except Exception:
548 475 self.log.warn(
549 476 "Could not determine whether pid %i is running via `OpenProcess`. "
550 477 " Making the likely assumption that it is."%pid
551 478 )
552 479 return True
553 480 return bool(p)
554 481 else:
555 482 try:
556 483 p = Popen(['ps','x'], stdout=PIPE, stderr=PIPE)
557 484 output,_ = p.communicate()
558 485 except OSError:
559 486 self.log.warn(
560 487 "Could not determine whether pid %i is running via `ps x`. "
561 488 " Making the likely assumption that it is."%pid
562 489 )
563 490 return True
564 491 pids = map(int, re.findall(r'^\W*\d+', output, re.MULTILINE))
565 492 return pid in pids
566 No newline at end of file
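Taken together, the rewritten ClusterDir resolves its location from an explicit `location` trait or, failing that, from `profile` as cluster_<profile> inside the IPython directory, and its change handlers then create the directory tree. The following is a minimal standalone sketch of that resolution rule, not the real traitlets machinery; resolve_cluster_dir and ensure_subdirs are hypothetical helpers added for illustration.

    import os

    def resolve_cluster_dir(ipython_dir, location=u'', profile=u'default'):
        """Explicit location wins; otherwise use cluster_<profile> under the IPython dir."""
        if location:
            return location
        return os.path.join(ipython_dir, 'cluster_' + profile)

    def ensure_subdirs(cluster_dir):
        """Create the security/log/pid layout that ClusterDir's change handlers maintain."""
        for sub in ('security', 'log', 'pid'):
            path = os.path.join(cluster_dir, sub)
            if not os.path.isdir(path):
                os.makedirs(path)
            if sub in ('security', 'pid'):
                os.chmod(path, 0o700)   # the diff keeps these directories owner-only

    # Example: profile 'mycluster' resolves under ~/.ipython
    print(resolve_cluster_dir(os.path.expanduser('~/.ipython'), profile='mycluster'))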
@@ -1,617 +1,550
1 1 #!/usr/bin/env python
2 2 # encoding: utf-8
3 3 """
4 4 The ipcluster application.
5 5 """
6 6
7 7 #-----------------------------------------------------------------------------
8 8 # Copyright (C) 2008-2009 The IPython Development Team
9 9 #
10 10 # Distributed under the terms of the BSD License. The full license is in
11 11 # the file COPYING, distributed as part of this software.
12 12 #-----------------------------------------------------------------------------
13 13
14 14 #-----------------------------------------------------------------------------
15 15 # Imports
16 16 #-----------------------------------------------------------------------------
17 17
18 18 import errno
19 19 import logging
20 20 import os
21 21 import re
22 22 import signal
23 23
24 24 from subprocess import check_call, CalledProcessError, PIPE
25 25 import zmq
26 26 from zmq.eventloop import ioloop
27 27
28 from IPython.external.argparse import ArgumentParser, SUPPRESS
28 from IPython.config.loader import Config
29 29 from IPython.utils.importstring import import_item
30 from IPython.utils.traitlets import Int, CStr, CUnicode, Str, Bool, CFloat, Dict, List
30 31
31 32 from IPython.parallel.apps.clusterdir import (
32 ApplicationWithClusterDir, ClusterDirConfigLoader,
33 ClusterDirError, PIDFileError
33 ClusterDirApplication, ClusterDirError,
34 PIDFileError,
35 base_flags,
34 36 )
35 37
36 38
37 39 #-----------------------------------------------------------------------------
38 40 # Module level variables
39 41 #-----------------------------------------------------------------------------
40 42
41 43
42 44 default_config_file_name = u'ipcluster_config.py'
43 45
44 46
45 47 _description = """\
46 48 Start an IPython cluster for parallel computing.\n\n
47 49
48 50 An IPython cluster consists of 1 controller and 1 or more engines.
49 51 This command automates the startup of these processes using a wide
50 52 range of startup methods (SSH, local processes, PBS, mpiexec,
51 53 Windows HPC Server 2008). To start a cluster with 4 engines on your
52 local host simply do 'ipcluster start -n 4'. For more complex usage
53 you will typically do 'ipcluster create -p mycluster', then edit
54 configuration files, followed by 'ipcluster start -p mycluster -n 4'.
54 local host simply do 'ipcluster --start n=4'. For more complex usage
55 you will typically do 'ipcluster --create profile=mycluster', then edit
56 configuration files, followed by 'ipcluster --start profile=mycluster n=4'.
55 57 """
56 58
57 59
58 60 # Exit codes for ipcluster
59 61
60 62 # This will be the exit code if the ipcluster appears to be running because
61 63 # a .pid file exists
62 64 ALREADY_STARTED = 10
63 65
64 66
65 67 # This will be the exit code if ipcluster stop is run, but there is no .pid
66 68 # file to be found.
67 69 ALREADY_STOPPED = 11
68 70
69 71 # This will be the exit code if ipcluster engines is run, but there is no .pid
70 72 # file to be found.
71 73 NO_CLUSTER = 12
72 74
73 75
74 76 #-----------------------------------------------------------------------------
75 # Command line options
77 # Main application
76 78 #-----------------------------------------------------------------------------
77
78
79 class IPClusterAppConfigLoader(ClusterDirConfigLoader):
80
81 def _add_arguments(self):
82 # Don't call ClusterDirConfigLoader._add_arguments as we don't want
83 # its defaults on self.parser. Instead, we will put those on
84 # default options on our subparsers.
85
86 # This has all the common options that all subcommands use
87 parent_parser1 = ArgumentParser(
88 add_help=False,
89 argument_default=SUPPRESS
90 )
91 self._add_ipython_dir(parent_parser1)
92 self._add_log_level(parent_parser1)
93
94 # This has all the common options that other subcommands use
95 parent_parser2 = ArgumentParser(
96 add_help=False,
97 argument_default=SUPPRESS
98 )
99 self._add_cluster_profile(parent_parser2)
100 self._add_cluster_dir(parent_parser2)
101 self._add_work_dir(parent_parser2)
102 paa = parent_parser2.add_argument
103 paa('--log-to-file',
104 action='store_true', dest='Global.log_to_file',
105 help='Log to a file in the log directory (default is stdout)')
106
107 # Create the object used to create the subparsers.
108 subparsers = self.parser.add_subparsers(
109 dest='Global.subcommand',
110 title='ipcluster subcommands',
111 description=
112 """ipcluster has a variety of subcommands. The general way of
113 running ipcluster is 'ipcluster <cmd> [options]'. To get help
114 on a particular subcommand do 'ipcluster <cmd> -h'."""
115 # help="For more help, type 'ipcluster <cmd> -h'",
116 )
117
118 # The "list" subcommand parser
119 parser_list = subparsers.add_parser(
120 'list',
121 parents=[parent_parser1],
122 argument_default=SUPPRESS,
123 help="List all clusters in cwd and ipython_dir.",
124 description=
125 """List all available clusters, by cluster directory, that can
126 be found in the current working directly or in the ipython
127 directory. Cluster directories are named using the convention
128 'cluster_<profile>'."""
129 )
130
131 # The "create" subcommand parser
132 parser_create = subparsers.add_parser(
133 'create',
134 parents=[parent_parser1, parent_parser2],
135 argument_default=SUPPRESS,
136 help="Create a new cluster directory.",
137 description=
138 """Create an ipython cluster directory by its profile name or
139 cluster directory path. Cluster directories contain
140 configuration, log and security related files and are named
141 using the convention 'cluster_<profile>'. By default they are
142 located in your ipython directory. Once created, you will
143 probably need to edit the configuration files in the cluster
144 directory to configure your cluster. Most users will create a
145 cluster directory by profile name,
146 'ipcluster create -p mycluster', which will put the directory
147 in '<ipython_dir>/cluster_mycluster'.
148 """
149 )
150 paa = parser_create.add_argument
151 paa('--reset-config',
152 dest='Global.reset_config', action='store_true',
153 help=
154 """Recopy the default config files to the cluster directory.
155 You will loose any modifications you have made to these files.""")
156
157 # The "start" subcommand parser
158 parser_start = subparsers.add_parser(
159 'start',
160 parents=[parent_parser1, parent_parser2],
161 argument_default=SUPPRESS,
162 help="Start a cluster.",
163 description=
164 """Start an ipython cluster by its profile name or cluster
79 start_help = """Start an ipython cluster by its profile name or cluster
165 80 directory. Cluster directories contain configuration, log and
166 81 security related files and are named using the convention
167 82 'cluster_<profile>' and should be created using the 'create'
168 83 subcommand of 'ipcluster'. If your cluster directory is in
169 84 the cwd or the ipython directory, you can simply refer to it
170 85 using its profile name, 'ipcluster start -n 4 -p <profile>`,
171 86 otherwise use the '--cluster-dir' option.
172 87 """
173 )
174
175 paa = parser_start.add_argument
176 paa('-n', '--number',
177 type=int, dest='Global.n',
178 help='The number of engines to start.',
179 metavar='Global.n')
180 paa('--clean-logs',
181 dest='Global.clean_logs', action='store_true',
182 help='Delete old log flies before starting.')
183 paa('--no-clean-logs',
184 dest='Global.clean_logs', action='store_false',
185 help="Don't delete old log flies before starting.")
186 paa('--daemon',
187 dest='Global.daemonize', action='store_true',
188 help='Daemonize the ipcluster program. This implies --log-to-file')
189 paa('--no-daemon',
190 dest='Global.daemonize', action='store_false',
191 help="Dont't daemonize the ipcluster program.")
192 paa('--delay',
193 type=float, dest='Global.delay',
194 help="Specify the delay (in seconds) between starting the controller and starting the engine(s).")
195
196 # The "stop" subcommand parser
197 parser_stop = subparsers.add_parser(
198 'stop',
199 parents=[parent_parser1, parent_parser2],
200 argument_default=SUPPRESS,
201 help="Stop a running cluster.",
202 description=
203 """Stop a running ipython cluster by its profile name or cluster
88 stop_help = """Stop a running ipython cluster by its profile name or cluster
204 89 directory. Cluster directories are named using the convention
205 90 'cluster_<profile>'. If your cluster directory is in
206 91 the cwd or the ipython directory, you can simply refer to it
207 92 using its profile name, 'ipcluster stop -p <profile>`, otherwise
208 93 use the '--cluster-dir' option.
209 94 """
210 )
211 paa = parser_stop.add_argument
212 paa('--signal',
213 dest='Global.signal', type=int,
214 help="The signal number to use in stopping the cluster (default=2).",
215 metavar="Global.signal")
216
217 # the "engines" subcommand parser
218 parser_engines = subparsers.add_parser(
219 'engines',
220 parents=[parent_parser1, parent_parser2],
221 argument_default=SUPPRESS,
222 help="Attach some engines to an existing controller or cluster.",
223 description=
224 """Start one or more engines to connect to an existing Cluster
95 engines_help = """Start one or more engines to connect to an existing Cluster
225 96 by profile name or cluster directory.
226 97 Cluster directories contain configuration, log and
227 98 security related files and are named using the convention
228 99 'cluster_<profile>' and should be created using the 'create'
229 100 subcommand of 'ipcluster'. If your cluster directory is in
230 101 the cwd or the ipython directory, you can simply refer to it
231 using its profile name, 'ipcluster engines -n 4 -p <profile>`,
232 otherwise use the '--cluster-dir' option.
102 using its profile name, 'ipcluster --engines n=4 profile=<profile>',
103 otherwise use the 'cluster_dir' option.
233 104 """
234 )
235 paa = parser_engines.add_argument
236 paa('-n', '--number',
237 type=int, dest='Global.n',
238 help='The number of engines to start.',
239 metavar='Global.n')
240 paa('--daemon',
241 dest='Global.daemonize', action='store_true',
242 help='Daemonize the ipcluster program. This implies --log-to-file')
243 paa('--no-daemon',
244 dest='Global.daemonize', action='store_false',
245 help="Dont't daemonize the ipcluster program.")
105 create_help = """Create an ipython cluster directory by its profile name or
106 cluster directory path. Cluster directories contain
107 configuration, log and security related files and are named
108 using the convention 'cluster_<profile>'. By default they are
109 located in your ipython directory. Once created, you will
110 probably need to edit the configuration files in the cluster
111 directory to configure your cluster. Most users will create a
112 cluster directory by profile name,
113 'ipcluster --create profile=mycluster', which will put the directory
114 in '<ipython_dir>/cluster_mycluster'.
115 """
116 list_help = """List all available clusters, by cluster directory, that can
117 be found in the current working directory or in the ipython
118 directory. Cluster directories are named using the convention
119 'cluster_<profile>'."""
246 120
247 #-----------------------------------------------------------------------------
248 # Main application
249 #-----------------------------------------------------------------------------
250 121
122 flags = {}
123 flags.update(base_flags)
124 flags.update({
125 'start' : ({ 'IPClusterApp': Config({'subcommand' : 'start'})} , start_help),
126 'stop' : ({ 'IPClusterApp': Config({'subcommand' : 'stop'})} , stop_help),
127 'create' : ({ 'IPClusterApp': Config({'subcommand' : 'create'})} , create_help),
128 'engines' : ({ 'IPClusterApp': Config({'subcommand' : 'engines'})} , engines_help),
129 'list' : ({ 'IPClusterApp': Config({'subcommand' : 'list'})} , list_help),
251 130
252 class IPClusterApp(ApplicationWithClusterDir):
131 })
132
133 class IPClusterApp(ClusterDirApplication):
253 134
254 135 name = u'ipcluster'
255 136 description = _description
256 137 usage = None
257 command_line_loader = IPClusterAppConfigLoader
258 138 default_config_file_name = default_config_file_name
259 139 default_log_level = logging.INFO
260 140 auto_create_cluster_dir = False
141 classes = List()
142 def _classes_default(self):
143 from IPython.parallel.apps import launcher
144 return launcher.all_launchers
145
146 n = Int(0, config=True,
147 help="The number of engines to start.")
148 signal = Int(signal.SIGINT, config=True,
149 help="signal to use for stopping. [default: SIGINT]")
150 delay = CFloat(1., config=True,
151 help="delay (in s) between starting the controller and the engines")
152
153 subcommand = Str('', config=True,
154 help="""ipcluster has a variety of subcommands. The general way of
155 running ipcluster is 'ipcluster --<cmd> [options]'."""
156 )
261 157
262 def create_default_config(self):
263 super(IPClusterApp, self).create_default_config()
264 self.default_config.Global.controller_launcher = \
265 'IPython.parallel.apps.launcher.LocalControllerLauncher'
266 self.default_config.Global.engine_launcher = \
267 'IPython.parallel.apps.launcher.LocalEngineSetLauncher'
268 self.default_config.Global.n = 2
269 self.default_config.Global.delay = 2
270 self.default_config.Global.reset_config = False
271 self.default_config.Global.clean_logs = True
272 self.default_config.Global.signal = signal.SIGINT
273 self.default_config.Global.daemonize = False
274
275 def find_resources(self):
276 subcommand = self.command_line_config.Global.subcommand
158 controller_launcher_class = Str('IPython.parallel.apps.launcher.LocalControllerLauncher',
159 config=True,
160 help="The class for launching a Controller."
161 )
162 engine_launcher_class = Str('IPython.parallel.apps.launcher.LocalEngineSetLauncher',
163 config=True,
164 help="The class for launching Engines."
165 )
166 reset = Bool(False, config=True,
167 help="Whether to reset config files as part of '--create'."
168 )
169 daemonize = Bool(False, config=True,
170 help='Daemonize the ipcluster program. This implies --log-to-file')
171
172 def _daemonize_changed(self, name, old, new):
173 if new:
174 self.log_to_file = True
175
176 def _n_changed(self, name, old, new):
177 # propagate n all over the place...
178 # TODO make this clean
179 # ensure all classes are covered.
180 self.config.LocalEngineSetLauncher.n=new
181 self.config.MPIExecEngineSetLauncher.n=new
182 self.config.SSHEngineSetLauncher.n=new
183 self.config.PBSEngineSetLauncher.n=new
184 self.config.SGEEngineSetLauncher.n=new
185 self.config.WinHPEngineSetLauncher.n=new
186
187 aliases = Dict(dict(
188 n='IPClusterApp.n',
189 signal = 'IPClusterApp.signal',
190 delay = 'IPClusterApp.delay',
191 clauncher = 'IPClusterApp.controller_launcher_class',
192 elauncher = 'IPClusterApp.engine_launcher_class',
193 ))
194 flags = Dict(flags)
195
196 def init_clusterdir(self):
197 subcommand = self.subcommand
277 198 if subcommand=='list':
278 199 self.list_cluster_dirs()
279 # Exit immediately because there is nothing left to do.
280 self.exit()
281 elif subcommand=='create':
200 self.exit(0)
201 if subcommand=='create':
202 reset = self.reset
282 203 self.auto_create_cluster_dir = True
283 super(IPClusterApp, self).find_resources()
204 super(IPClusterApp, self).init_clusterdir()
205 self.log.info('Copying default config files to cluster directory '
206 '[overwrite=%r]' % (reset,))
207 self.cluster_dir.copy_all_config_files(overwrite=reset)
284 208 elif subcommand=='start' or subcommand=='stop':
285 209 self.auto_create_cluster_dir = True
286 210 try:
287 super(IPClusterApp, self).find_resources()
211 super(IPClusterApp, self).init_clusterdir()
288 212 except ClusterDirError:
289 213 raise ClusterDirError(
290 214 "Could not find a cluster directory. A cluster dir must "
291 215 "be created before running 'ipcluster start'. Do "
292 216 "'ipcluster create -h' or 'ipcluster list -h' for more "
293 217 "information about creating and listing cluster dirs."
294 218 )
295 219 elif subcommand=='engines':
296 220 self.auto_create_cluster_dir = False
297 221 try:
298 super(IPClusterApp, self).find_resources()
222 super(IPClusterApp, self).init_clusterdir()
299 223 except ClusterDirError:
300 224 raise ClusterDirError(
301 225 "Could not find a cluster directory. A cluster dir must "
302 226 "be created before running 'ipcluster start'. Do "
303 227 "'ipcluster create -h' or 'ipcluster list -h' for more "
304 228 "information about creating and listing cluster dirs."
305 229 )
306 230
307 231 def list_cluster_dirs(self):
308 232 # Find the search paths
309 233 cluster_dir_paths = os.environ.get('IPCLUSTER_DIR_PATH','')
310 234 if cluster_dir_paths:
311 235 cluster_dir_paths = cluster_dir_paths.split(':')
312 236 else:
313 237 cluster_dir_paths = []
314 238 try:
315 ipython_dir = self.command_line_config.Global.ipython_dir
239 ipython_dir = self.ipython_dir
316 240 except AttributeError:
317 ipython_dir = self.default_config.Global.ipython_dir
241 ipython_dir = self.ipython_dir
318 242 paths = [os.getcwd(), ipython_dir] + \
319 243 cluster_dir_paths
320 244 paths = list(set(paths))
321 245
322 246 self.log.info('Searching for cluster dirs in paths: %r' % paths)
323 247 for path in paths:
324 248 files = os.listdir(path)
325 249 for f in files:
326 250 full_path = os.path.join(path, f)
327 251 if os.path.isdir(full_path) and f.startswith('cluster_'):
328 252 profile = full_path.split('_')[-1]
329 start_cmd = 'ipcluster start -p %s -n 4' % profile
253 start_cmd = 'ipcluster --start profile=%s n=4' % profile
330 254 print start_cmd + " ==> " + full_path
331 255
332 def pre_construct(self):
333 # IPClusterApp.pre_construct() is where we cd to the working directory.
334 super(IPClusterApp, self).pre_construct()
335 config = self.master_config
336 try:
337 daemon = config.Global.daemonize
338 if daemon:
339 config.Global.log_to_file = True
340 except AttributeError:
341 pass
342
343 def construct(self):
344 config = self.master_config
345 subcmd = config.Global.subcommand
346 reset = config.Global.reset_config
347 if subcmd == 'list':
348 return
349 if subcmd == 'create':
350 self.log.info('Copying default config files to cluster directory '
351 '[overwrite=%r]' % (reset,))
352 self.cluster_dir_obj.copy_all_config_files(overwrite=reset)
256 def init_launchers(self):
257 config = self.config
258 subcmd = self.subcommand
353 259 if subcmd =='start':
354 self.cluster_dir_obj.copy_all_config_files(overwrite=False)
355 260 self.start_logging()
356 261 self.loop = ioloop.IOLoop.instance()
357 262 # reactor.callWhenRunning(self.start_launchers)
358 263 dc = ioloop.DelayedCallback(self.start_launchers, 0, self.loop)
359 264 dc.start()
360 265 if subcmd == 'engines':
361 266 self.start_logging()
362 267 self.loop = ioloop.IOLoop.instance()
363 268 # reactor.callWhenRunning(self.start_launchers)
364 269 engine_only = lambda : self.start_launchers(controller=False)
365 270 dc = ioloop.DelayedCallback(engine_only, 0, self.loop)
366 271 dc.start()
367 272
368 273 def start_launchers(self, controller=True):
369 config = self.master_config
274 config = self.config
370 275
371 276 # Create the launchers. In both cases, we set the work_dir of
372 277 # the launcher to the cluster_dir. This is where the launcher's
373 278 # subprocesses will be launched. It is not where the controller
374 279 # and engine will be launched.
375 280 if controller:
376 cl_class = import_item(config.Global.controller_launcher)
281 clsname = self.controller_launcher_class
282 if '.' not in clsname:
283 clsname = 'IPython.parallel.apps.launcher.'+clsname
284 cl_class = import_item(clsname)
377 285 self.controller_launcher = cl_class(
378 work_dir=self.cluster_dir, config=config,
286 work_dir=self.cluster_dir.location, config=config,
379 287 logname=self.log.name
380 288 )
381 289 # Setup the observing of stopping. If the controller dies, shut
382 290 # everything down as that will be completely fatal for the engines.
383 291 self.controller_launcher.on_stop(self.stop_launchers)
384 292 # But, we don't monitor the stopping of engines. An engine dying
385 293 # is just fine and in principle a user could start a new engine.
386 294 # Also, if we did monitor engine stopping, it is difficult to
387 295 # know what to do when only some engines die. Currently, the
388 296 # observing of engine stopping is inconsistent. Some launchers
389 297 # might trigger on a single engine stopping, others wait until
390 298 # all stop. TODO: think more about how to handle this.
391 299 else:
392 300 self.controller_launcher = None
393 301
394 el_class = import_item(config.Global.engine_launcher)
302 clsname = self.engine_launcher_class
303 if '.' not in clsname:
304 # not a module, presume it's the raw name in apps.launcher
305 clsname = 'IPython.parallel.apps.launcher.'+clsname
306 print repr(clsname)
307 el_class = import_item(clsname)
308
395 309 self.engine_launcher = el_class(
396 work_dir=self.cluster_dir, config=config, logname=self.log.name
310 work_dir=self.cluster_dir.location, config=config, logname=self.log.name
397 311 )
398 312
399 313 # Setup signals
400 314 signal.signal(signal.SIGINT, self.sigint_handler)
401 315
402 316 # Start the controller and engines
403 317 self._stopping = False # Make sure stop_launchers is not called 2x.
404 318 if controller:
405 319 self.start_controller()
406 dc = ioloop.DelayedCallback(self.start_engines, 1000*config.Global.delay*controller, self.loop)
320 dc = ioloop.DelayedCallback(self.start_engines, 1000*self.delay*controller, self.loop)
407 321 dc.start()
408 322 self.startup_message()
409 323
410 324 def startup_message(self, r=None):
411 325 self.log.info("IPython cluster: started")
412 326 return r
413 327
414 328 def start_controller(self, r=None):
415 329 # self.log.info("In start_controller")
416 config = self.master_config
330 config = self.config
417 331 d = self.controller_launcher.start(
418 cluster_dir=config.Global.cluster_dir
332 cluster_dir=self.cluster_dir.location
419 333 )
420 334 return d
421 335
422 336 def start_engines(self, r=None):
423 337 # self.log.info("In start_engines")
424 config = self.master_config
338 config = self.config
425 339
426 340 d = self.engine_launcher.start(
427 config.Global.n,
428 cluster_dir=config.Global.cluster_dir
341 self.n,
342 cluster_dir=self.cluster_dir.location
429 343 )
430 344 return d
431 345
432 346 def stop_controller(self, r=None):
433 347 # self.log.info("In stop_controller")
434 348 if self.controller_launcher and self.controller_launcher.running:
435 349 return self.controller_launcher.stop()
436 350
437 351 def stop_engines(self, r=None):
438 352 # self.log.info("In stop_engines")
439 353 if self.engine_launcher.running:
440 354 d = self.engine_launcher.stop()
441 355 # d.addErrback(self.log_err)
442 356 return d
443 357 else:
444 358 return None
445 359
446 360 def log_err(self, f):
447 361 self.log.error(f.getTraceback())
448 362 return None
449 363
450 364 def stop_launchers(self, r=None):
451 365 if not self._stopping:
452 366 self._stopping = True
453 367 # if isinstance(r, failure.Failure):
454 368 # self.log.error('Unexpected error in ipcluster:')
455 369 # self.log.info(r.getTraceback())
456 370 self.log.error("IPython cluster: stopping")
457 371 # These return deferreds. We are not doing anything with them
458 372 # but we are holding refs to them as a reminder that they
459 373 # do return deferreds.
460 374 d1 = self.stop_engines()
461 375 d2 = self.stop_controller()
462 376 # Wait a few seconds to let things shut down.
463 377 dc = ioloop.DelayedCallback(self.loop.stop, 4000, self.loop)
464 378 dc.start()
465 379 # reactor.callLater(4.0, reactor.stop)
466 380
467 381 def sigint_handler(self, signum, frame):
468 382 self.stop_launchers()
469 383
470 384 def start_logging(self):
471 385 # Remove old log files of the controller and engine
472 if self.master_config.Global.clean_logs:
473 log_dir = self.master_config.Global.log_dir
386 if self.clean_logs:
387 log_dir = self.cluster_dir.log_dir
474 388 for f in os.listdir(log_dir):
475 389 if re.match(r'ip(engine|controller)z-\d+\.(log|err|out)',f):
476 390 os.remove(os.path.join(log_dir, f))
477 391 # This will remove old log files for ipcluster itself
478 super(IPClusterApp, self).start_logging()
392 # super(IPClusterApp, self).start_logging()
479 393
480 def start_app(self):
394 def start(self):
481 395 """Start the application, depending on what subcommand is used."""
482 subcmd = self.master_config.Global.subcommand
483 if subcmd=='create' or subcmd=='list':
396 subcmd = self.subcommand
397 if subcmd=='create':
398 # init_clusterdir step completed create action
484 399 return
485 400 elif subcmd=='start':
486 401 self.start_app_start()
487 402 elif subcmd=='stop':
488 403 self.start_app_stop()
489 404 elif subcmd=='engines':
490 405 self.start_app_engines()
406 else:
407 self.log.fatal("one command of '--start', '--stop', '--list', '--create', '--engines'"
408 " must be specified")
409 self.exit(-1)
491 410
492 411 def start_app_start(self):
493 412 """Start the app for the start subcommand."""
494 config = self.master_config
413 config = self.config
495 414 # First see if the cluster is already running
496 415 try:
497 416 pid = self.get_pid_from_file()
498 417 except PIDFileError:
499 418 pass
500 419 else:
501 420 if self.check_pid(pid):
502 421 self.log.critical(
503 422 'Cluster is already running with [pid=%s]. '
504 423 'use "ipcluster stop" to stop the cluster.' % pid
505 424 )
506 425 # Here I exit with an unusual exit status that other processes
507 426 # can watch for to learn how I exited.
508 427 self.exit(ALREADY_STARTED)
509 428 else:
510 429 self.remove_pid_file()
511 430
512 431
513 432 # Now log and daemonize
514 433 self.log.info(
515 'Starting ipcluster with [daemon=%r]' % config.Global.daemonize
434 'Starting ipcluster with [daemon=%r]' % self.daemonize
516 435 )
517 436 # TODO: Get daemonize working on Windows or as a Windows Server.
518 if config.Global.daemonize:
437 if self.daemonize:
519 438 if os.name=='posix':
520 439 from twisted.scripts._twistd_unix import daemonize
521 440 daemonize()
522 441
523 442 # Now write the new pid file AFTER our new forked pid is active.
524 443 self.write_pid_file()
525 444 try:
526 445 self.loop.start()
527 446 except KeyboardInterrupt:
528 447 pass
529 448 except zmq.ZMQError as e:
530 449 if e.errno == errno.EINTR:
531 450 pass
532 451 else:
533 452 raise
534 453 finally:
535 454 self.remove_pid_file()
536 455
537 456 def start_app_engines(self):
538 457 """Start the app for the start subcommand."""
539 config = self.master_config
458 config = self.config
540 459 # First see if the cluster is already running
541 460
542 461 # Now log and daemonize
543 462 self.log.info(
544 'Starting engines with [daemon=%r]' % config.Global.daemonize
463 'Starting engines with [daemon=%r]' % self.daemonize
545 464 )
546 465 # TODO: Get daemonize working on Windows or as a Windows Server.
547 if config.Global.daemonize:
466 if self.daemonize:
548 467 if os.name=='posix':
549 468 from twisted.scripts._twistd_unix import daemonize
550 469 daemonize()
551 470
552 471 # Now write the new pid file AFTER our new forked pid is active.
553 472 # self.write_pid_file()
554 473 try:
555 474 self.loop.start()
556 475 except KeyboardInterrupt:
557 476 pass
558 477 except zmq.ZMQError as e:
559 478 if e.errno == errno.EINTR:
560 479 pass
561 480 else:
562 481 raise
563 482 # self.remove_pid_file()
564 483
565 484 def start_app_stop(self):
566 485 """Start the app for the stop subcommand."""
567 config = self.master_config
486 config = self.config
568 487 try:
569 488 pid = self.get_pid_from_file()
570 489 except PIDFileError:
571 490 self.log.critical(
572 491 'Could not read pid file, cluster is probably not running.'
573 492 )
574 493 # Here I exit with an unusual exit status that other processes
575 494 # can watch for to learn how I exited.
576 495 self.remove_pid_file()
577 496 self.exit(ALREADY_STOPPED)
578 497
579 498 if not self.check_pid(pid):
580 499 self.log.critical(
581 500 'Cluster [pid=%r] is not running.' % pid
582 501 )
583 502 self.remove_pid_file()
584 503 # Here I exit with an unusual exit status that other processes
585 504 # can watch for to learn how I exited.
586 505 self.exit(ALREADY_STOPPED)
587 506
588 507 elif os.name=='posix':
589 sig = config.Global.signal
508 sig = self.signal
590 509 self.log.info(
591 510 "Stopping cluster [pid=%r] with [signal=%r]" % (pid, sig)
592 511 )
593 512 try:
594 513 os.kill(pid, sig)
595 514 except OSError:
596 515 self.log.error("Stopping cluster failed, assuming already dead.",
597 516 exc_info=True)
598 517 self.remove_pid_file()
599 518 elif os.name=='nt':
600 519 try:
601 520 # kill the whole tree
602 521 p = check_call(['taskkill', '-pid', str(pid), '-t', '-f'], stdout=PIPE,stderr=PIPE)
603 522 except (CalledProcessError, OSError):
604 523 self.log.error("Stopping cluster failed, assuming already dead.",
605 524 exc_info=True)
606 525 self.remove_pid_file()
607 526
608 527
609 528 def launch_new_instance():
610 529 """Create and run the IPython cluster."""
611 530 app = IPClusterApp()
531 app.parse_command_line()
532 cl_config = app.config
533 app.init_clusterdir()
534 if app.config_file:
535 app.load_config_file(app.config_file)
536 else:
537 app.load_config_file(app.default_config_file_name, path=app.cluster_dir.location)
538 # command-line should *override* config file, but command-line is necessary
539 # to determine clusterdir, etc.
540 app.update_config(cl_config)
541
542 app.to_work_dir()
543 app.init_launchers()
544
612 545 app.start()
613 546
614 547
615 548 if __name__ == '__main__':
616 549 launch_new_instance()
617 550
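The new launch_new_instance sequence is what ties the pieces together: parse the command line first (so the cluster dir can be located), load the cluster dir's config file, then re-apply the command-line config so it overrides the file. Here is a sketch of that precedence rule with plain dicts standing in for Config objects; merge_config is a hypothetical helper, not the real Config merge.

    def merge_config(base, override):
        """Shallow per-section merge in which 'override' values win."""
        merged = dict(base)
        for section, traits in override.items():
            merged[section] = dict(merged.get(section, {}), **traits)
        return merged

    cl_config   = {'IPClusterApp': {'n': 4}}                  # from 'ipcluster --start n=4'
    file_config = {'IPClusterApp': {'n': 2, 'delay': 5.0}}    # from ipcluster_config.py
    print(merge_config(file_config, cl_config))
    # {'IPClusterApp': {'n': 4, 'delay': 5.0}} -- the command line wins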
@@ -1,433 +1,408
1 1 #!/usr/bin/env python
2 2 # encoding: utf-8
3 3 """
4 4 The IPython controller application.
5 5 """
6 6
7 7 #-----------------------------------------------------------------------------
8 8 # Copyright (C) 2008-2009 The IPython Development Team
9 9 #
10 10 # Distributed under the terms of the BSD License. The full license is in
11 11 # the file COPYING, distributed as part of this software.
12 12 #-----------------------------------------------------------------------------
13 13
14 14 #-----------------------------------------------------------------------------
15 15 # Imports
16 16 #-----------------------------------------------------------------------------
17 17
18 18 from __future__ import with_statement
19 19
20 20 import copy
21 21 import os
22 22 import logging
23 23 import socket
24 24 import stat
25 25 import sys
26 26 import uuid
27 27
28 from multiprocessing import Process
29
28 30 import zmq
31 from zmq.devices import ProcessMonitoredQueue
29 32 from zmq.log.handlers import PUBHandler
30 33 from zmq.utils import jsonapi as json
31 34
32 35 from IPython.config.loader import Config
33 36
34 37 from IPython.parallel import factory
35 38
36 39 from IPython.parallel.apps.clusterdir import (
37 ApplicationWithClusterDir,
38 ClusterDirConfigLoader
40 ClusterDir,
41 ClusterDirApplication,
42 base_flags
43 # ClusterDirConfigLoader
39 44 )
40 from IPython.parallel.util import disambiguate_ip_address, split_url
41 # from IPython.kernel.fcutil import FCServiceFactory, FURLError
42 from IPython.utils.traitlets import Instance, Unicode
45 from IPython.utils.importstring import import_item
46 from IPython.utils.traitlets import Instance, Unicode, Bool, List, CStr, Dict
47
48 # from IPython.parallel.controller.controller import ControllerFactory
49 from IPython.parallel.streamsession import StreamSession
50 from IPython.parallel.controller.heartmonitor import HeartMonitor
51 from IPython.parallel.controller.hub import Hub, HubFactory
52 from IPython.parallel.controller.scheduler import TaskScheduler,launch_scheduler
53 from IPython.parallel.controller.sqlitedb import SQLiteDB
54
55 from IPython.parallel.util import signal_children,disambiguate_ip_address, split_url
43 56
44 from IPython.parallel.controller.controller import ControllerFactory
57 # conditional import of MongoDB backend class
58
59 try:
60 from IPython.parallel.controller.mongodb import MongoDB
61 except ImportError:
62 maybe_mongo = []
63 else:
64 maybe_mongo = [MongoDB]
45 65
46 66
47 67 #-----------------------------------------------------------------------------
48 68 # Module level variables
49 69 #-----------------------------------------------------------------------------
50 70
51 71
52 72 #: The default config file name for this application
53 73 default_config_file_name = u'ipcontroller_config.py'
54 74
55 75
56 76 _description = """Start the IPython controller for parallel computing.
57 77
58 78 The IPython controller provides a gateway between the IPython engines and
59 79 clients. The controller needs to be started before the engines and can be
60 80 configured using command line options or using a cluster directory. Cluster
61 81 directories contain config, log and security files and are usually located in
62 82 your ipython directory and named as "cluster_<profile>". See the --profile
63 83 and --cluster-dir options for details.
64 84 """
65 85
66 #-----------------------------------------------------------------------------
67 # Default interfaces
68 #-----------------------------------------------------------------------------
69
70 # The default client interfaces for FCClientServiceFactory.interfaces
71 default_client_interfaces = Config()
72 default_client_interfaces.Default.url_file = 'ipcontroller-client.url'
73
74 # Make this a dict we can pass to Config.__init__ for the default
75 default_client_interfaces = dict(copy.deepcopy(default_client_interfaces.items()))
76
77
78 86
79 # The default engine interfaces for FCEngineServiceFactory.interfaces
80 default_engine_interfaces = Config()
81 default_engine_interfaces.Default.url_file = u'ipcontroller-engine.url'
82
83 # Make this a dict we can pass to Config.__init__ for the default
84 default_engine_interfaces = dict(copy.deepcopy(default_engine_interfaces.items()))
85
86
87 #-----------------------------------------------------------------------------
88 # Service factories
89 #-----------------------------------------------------------------------------
90
91 #
92 # class FCClientServiceFactory(FCServiceFactory):
93 # """A Foolscap implementation of the client services."""
94 #
95 # cert_file = Unicode(u'ipcontroller-client.pem', config=True)
96 # interfaces = Instance(klass=Config, kw=default_client_interfaces,
97 # allow_none=False, config=True)
98 #
99 #
100 # class FCEngineServiceFactory(FCServiceFactory):
101 # """A Foolscap implementation of the engine services."""
102 #
103 # cert_file = Unicode(u'ipcontroller-engine.pem', config=True)
104 # interfaces = Instance(klass=dict, kw=default_engine_interfaces,
105 # allow_none=False, config=True)
106 #
107
108 #-----------------------------------------------------------------------------
109 # Command line options
110 #-----------------------------------------------------------------------------
111
112
113 class IPControllerAppConfigLoader(ClusterDirConfigLoader):
114
115 def _add_arguments(self):
116 super(IPControllerAppConfigLoader, self)._add_arguments()
117 paa = self.parser.add_argument
118
119 ## Hub Config:
120 paa('--mongodb',
121 dest='HubFactory.db_class', action='store_const',
122 const='IPython.parallel.controller.mongodb.MongoDB',
123 help='Use MongoDB for task storage [default: in-memory]')
124 paa('--sqlite',
125 dest='HubFactory.db_class', action='store_const',
126 const='IPython.parallel.controller.sqlitedb.SQLiteDB',
127 help='Use SQLite3 for DB task storage [default: in-memory]')
128 paa('--hb',
129 type=int, dest='HubFactory.hb', nargs=2,
130 help='The (2) ports the Hub\'s Heartmonitor will use for the heartbeat '
131 'connections [default: random]',
132 metavar='Hub.hb_ports')
133 paa('--ping',
134 type=int, dest='HubFactory.ping',
135 help='The frequency at which the Hub pings the engines for heartbeats '
136 ' (in ms) [default: 100]',
137 metavar='Hub.ping')
138
139 # Client config
140 paa('--client-ip',
141 type=str, dest='HubFactory.client_ip',
142 help='The IP address or hostname the Hub will listen on for '
143 'client connections. Both engine-ip and client-ip can be set simultaneously '
144 'via --ip [default: loopback]',
145 metavar='Hub.client_ip')
146 paa('--client-transport',
147 type=str, dest='HubFactory.client_transport',
148 help='The ZeroMQ transport the Hub will use for '
149 'client connections. Both engine-transport and client-transport can be set simultaneously '
150 'via --transport [default: tcp]',
151 metavar='Hub.client_transport')
152 paa('--query',
153 type=int, dest='HubFactory.query_port',
154 help='The port on which the Hub XREP socket will listen for result queries from clients [default: random]',
155 metavar='Hub.query_port')
156 paa('--notifier',
157 type=int, dest='HubFactory.notifier_port',
158 help='The port on which the Hub PUB socket will listen for notification connections [default: random]',
159 metavar='Hub.notifier_port')
160
161 # Engine config
162 paa('--engine-ip',
163 type=str, dest='HubFactory.engine_ip',
164 help='The IP address or hostname the Hub will listen on for '
165 'engine connections. This applies to the Hub and its schedulers'
166 'engine-ip and client-ip can be set simultaneously '
167 'via --ip [default: loopback]',
168 metavar='Hub.engine_ip')
169 paa('--engine-transport',
170 type=str, dest='HubFactory.engine_transport',
171 help='The ZeroMQ transport the Hub will use for '
172 'client connections. Both engine-transport and client-transport can be set simultaneously '
173 'via --transport [default: tcp]',
174 metavar='Hub.engine_transport')
175
176 # Scheduler config
177 paa('--mux',
178 type=int, dest='ControllerFactory.mux', nargs=2,
179 help='The (2) ports the MUX scheduler will listen on for client,engine '
180 'connections, respectively [default: random]',
181 metavar='Scheduler.mux_ports')
182 paa('--task',
183 type=int, dest='ControllerFactory.task', nargs=2,
184 help='The (2) ports the Task scheduler will listen on for client,engine '
185 'connections, respectively [default: random]',
186 metavar='Scheduler.task_ports')
187 paa('--control',
188 type=int, dest='ControllerFactory.control', nargs=2,
189 help='The (2) ports the Control scheduler will listen on for client,engine '
190 'connections, respectively [default: random]',
191 metavar='Scheduler.control_ports')
192 paa('--iopub',
193 type=int, dest='ControllerFactory.iopub', nargs=2,
194 help='The (2) ports the IOPub scheduler will listen on for client,engine '
195 'connections, respectively [default: random]',
196 metavar='Scheduler.iopub_ports')
197
198 paa('--scheme',
199 type=str, dest='HubFactory.scheme',
200 choices = ['pure', 'lru', 'plainrandom', 'weighted', 'twobin','leastload'],
201 help='select the task scheduler scheme [default: Python LRU]',
202 metavar='Scheduler.scheme')
203 paa('--usethreads',
204 dest='ControllerFactory.usethreads', action="store_true",
205 help='Use threads instead of processes for the schedulers',
206 )
207 paa('--hwm',
208 dest='TaskScheduler.hwm', type=int,
209 help='specify the High Water Mark (HWM) '
210 'in the Python scheduler. This is the maximum number '
211 'of allowed outstanding tasks on each engine.',
212 )
213
214 ## Global config
215 paa('--log-to-file',
216 action='store_true', dest='Global.log_to_file',
217 help='Log to a file in the log directory (default is stdout)')
218 paa('--log-url',
219 type=str, dest='Global.log_url',
220 help='Broadcast logs to an iploggerz process [default: disabled]')
221 paa('-r','--reuse-files',
222 action='store_true', dest='Global.reuse_files',
223 help='Try to reuse existing json connection files.')
224 paa('--no-secure',
225 action='store_false', dest='Global.secure',
226 help='Turn off execution keys (default).')
227 paa('--secure',
228 action='store_true', dest='Global.secure',
229 help='Turn on execution keys.')
230 paa('--execkey',
231 type=str, dest='Global.exec_key',
232 help='path to a file containing an execution key.',
233 metavar='keyfile')
234 paa('--ssh',
235 type=str, dest='Global.sshserver',
236 help='ssh url for clients to use when connecting to the Controller '
237 'processes. It should be of the form: [user@]server[:port]. The '
238 'Controller\'s listening addresses must be accessible from the ssh server',
239 metavar='Global.sshserver')
240 paa('--location',
241 type=str, dest='Global.location',
242 help="The external IP or domain name of this machine, used for disambiguating "
243 "engine and client connections.",
244 metavar='Global.location')
245 factory.add_session_arguments(self.parser)
246 factory.add_registration_arguments(self.parser)
247 87
248 88
249 89 #-----------------------------------------------------------------------------
250 90 # The main application
251 91 #-----------------------------------------------------------------------------
252
253
254 class IPControllerApp(ApplicationWithClusterDir):
92 flags = {}
93 flags.update(base_flags)
94 flags.update({
95 'usethreads' : ( {'IPControllerApp' : {'usethreads' : True}},
96 'Use threads instead of processes for the schedulers'),
97 'sqlitedb' : ({'HubFactory' : {'db_class' : 'IPython.parallel.controller.sqlitedb.SQLiteDB'}},
98 'use the SQLiteDB backend'),
99 'mongodb' : ({'HubFactory' : {'db_class' : 'IPython.parallel.controller.mongodb.MongoDB'}},
100 'use the MongoDB backend'),
101 'dictdb' : ({'HubFactory' : {'db_class' : 'IPython.parallel.controller.dictdb.DictDB'}},
102 'use the in-memory DictDB backend'),
103 })
104
105 flags.update()
106
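Each flag defined above is a named shorthand for a small Config fragment. A hedged sketch of what selecting the mongodb flag amounts to in config terms (the actual expansion is handled by the Application machinery):

    from IPython.config.loader import Config

    c = Config()
    # equivalent to the {'HubFactory': {'db_class': ...}} dict attached to 'mongodb' above
    c.HubFactory.db_class = 'IPython.parallel.controller.mongodb.MongoDB'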
107 class IPControllerApp(ClusterDirApplication):
255 108
256 109 name = u'ipcontroller'
257 110 description = _description
258 command_line_loader = IPControllerAppConfigLoader
111 # command_line_loader = IPControllerAppConfigLoader
259 112 default_config_file_name = default_config_file_name
260 113 auto_create_cluster_dir = True
114 classes = [ClusterDir, StreamSession, HubFactory, TaskScheduler, HeartMonitor, SQLiteDB] + maybe_mongo
261 115
262
263 def create_default_config(self):
264 super(IPControllerApp, self).create_default_config()
265 # Don't set defaults for Global.secure or Global.reuse_furls
266 # as those are set in a component.
267 self.default_config.Global.import_statements = []
268 self.default_config.Global.clean_logs = True
269 self.default_config.Global.secure = True
270 self.default_config.Global.reuse_files = False
271 self.default_config.Global.exec_key = "exec_key.key"
272 self.default_config.Global.sshserver = None
273 self.default_config.Global.location = None
274
275 def pre_construct(self):
276 super(IPControllerApp, self).pre_construct()
277 c = self.master_config
278 # The defaults for these are set in FCClientServiceFactory and
279 # FCEngineServiceFactory, so we only set them here if the global
280 # options have be set to override the class level defaults.
116 reuse_files = Bool(False, config=True,
117 help='Whether to reuse existing json connection files [default: False]'
118 )
119 secure = Bool(True, config=True,
120 help='Whether to use exec_keys for extra authentication [default: True]'
121 )
122 ssh_server = Unicode(u'', config=True,
123 help="""ssh url for clients to use when connecting to the Controller
124 processes. It should be of the form: [user@]server[:port]. The
125 Controller\'s listening addresses must be accessible from the ssh server""",
126 )
127 location = Unicode(u'', config=True,
128 help="""The external IP or domain name of the Controller, used for disambiguating
129 engine and client connections.""",
130 )
131 import_statements = List([], config=True,
132 help="import statements to be run at startup. Necessary in some environments"
133 )
134
135 usethreads = Bool(False, config=True,
136 help='Use threads instead of processes for the schedulers',
137 )
138
139 # internal
140 children = List()
141 mq_class = CStr('zmq.devices.ProcessMonitoredQueue')
142
143 def _usethreads_changed(self, name, old, new):
144 self.mq_class = 'zmq.devices.%sMonitoredQueue'%('Thread' if new else 'Process')
145
146 aliases = Dict(dict(
147 config = 'IPControllerApp.config_file',
148 # file = 'IPControllerApp.url_file',
149 log_level = 'IPControllerApp.log_level',
150 reuse_files = 'IPControllerApp.reuse_files',
151 secure = 'IPControllerApp.secure',
152 ssh = 'IPControllerApp.ssh_server',
153 usethreads = 'IPControllerApp.usethreads',
154 import_statements = 'IPControllerApp.import_statements',
155 location = 'IPControllerApp.location',
156
157 ident = 'StreamSession.session',
158 user = 'StreamSession.username',
159 exec_key = 'StreamSession.keyfile',
160
161 url = 'HubFactory.url',
162 ip = 'HubFactory.ip',
163 transport = 'HubFactory.transport',
164 port = 'HubFactory.regport',
165
166 ping = 'HeartMonitor.period',
167
168 scheme = 'TaskScheduler.scheme_name',
169 hwm = 'TaskScheduler.hwm',
170
171
172 profile = "ClusterDir.profile",
173 cluster_dir = 'ClusterDir.location',
281 174
282 # if hasattr(c.Global, 'reuse_furls'):
283 # c.FCClientServiceFactory.reuse_furls = c.Global.reuse_furls
284 # c.FCEngineServiceFactory.reuse_furls = c.Global.reuse_furls
285 # del c.Global.reuse_furls
286 # if hasattr(c.Global, 'secure'):
287 # c.FCClientServiceFactory.secure = c.Global.secure
288 # c.FCEngineServiceFactory.secure = c.Global.secure
289 # del c.Global.secure
175 ))
176 flags = Dict(flags)
290 177
178
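The aliases above map short names onto full trait paths. A hedged example, using the same bare name=value argument style the launchers later in this diff pass to ipcontroller and ipengine:

    # 'scheme=lru reuse_files=True' on the command line is treated roughly as:
    argv          = ['scheme=lru', 'reuse_files=True']
    expanded_argv = ['TaskScheduler.scheme_name=lru', 'IPControllerApp.reuse_files=True']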
291 179 def save_connection_dict(self, fname, cdict):
292 180 """save a connection dict to json file."""
293 c = self.master_config
181 c = self.config
294 182 url = cdict['url']
295 183 location = cdict['location']
296 184 if not location:
297 185 try:
298 186 proto,ip,port = split_url(url)
299 187 except AssertionError:
300 188 pass
301 189 else:
302 190 location = socket.gethostbyname_ex(socket.gethostname())[2][-1]
303 191 cdict['location'] = location
304 fname = os.path.join(c.Global.security_dir, fname)
192 fname = os.path.join(self.cluster_dir.security_dir, fname)
305 193 with open(fname, 'w') as f:
306 194 f.write(json.dumps(cdict, indent=2))
307 195 os.chmod(fname, stat.S_IRUSR|stat.S_IWUSR)
308 196
309 197 def load_config_from_json(self):
310 198 """load config from existing json connector files."""
311 c = self.master_config
199 c = self.config
312 200 # load from engine config
313 with open(os.path.join(c.Global.security_dir, 'ipcontroller-engine.json')) as f:
201 with open(os.path.join(self.cluster_dir.security_dir, 'ipcontroller-engine.json')) as f:
314 202 cfg = json.loads(f.read())
315 key = c.SessionFactory.exec_key = cfg['exec_key']
203 key = c.StreamSession.key = cfg['exec_key']
316 204 xport,addr = cfg['url'].split('://')
317 205 c.HubFactory.engine_transport = xport
318 206 ip,ports = addr.split(':')
319 207 c.HubFactory.engine_ip = ip
320 208 c.HubFactory.regport = int(ports)
321 c.Global.location = cfg['location']
209 self.location = cfg['location']
322 210
323 211 # load client config
324 with open(os.path.join(c.Global.security_dir, 'ipcontroller-client.json')) as f:
212 with open(os.path.join(self.cluster_dir.security_dir, 'ipcontroller-client.json')) as f:
325 213 cfg = json.loads(f.read())
326 214 assert key == cfg['exec_key'], "exec_key mismatch between engine and client keys"
327 215 xport,addr = cfg['url'].split('://')
328 216 c.HubFactory.client_transport = xport
329 217 ip,ports = addr.split(':')
330 218 c.HubFactory.client_ip = ip
331 c.Global.sshserver = cfg['ssh']
219 self.ssh_server = cfg['ssh']
332 220 assert int(ports) == c.HubFactory.regport, "regport mismatch"
333 221
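For reference, a hedged sketch of the connection dict written by save_connection_dict() and read back by load_config_from_json(); only the keys come from the code above, the values are placeholders:

    example_cdict = {
        'exec_key' : '<uuid4 generated in init_hub>',
        'ssh'      : '',                       # or 'user@server[:port]'
        'url'      : 'tcp://127.0.0.1:10101',  # placeholder registration url
        'location' : '10.0.0.5',               # external IP used for disambiguation
    }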
334 def construct(self):
222 def init_hub(self):
335 223 # This is the working dir by now.
336 224 sys.path.insert(0, '')
337 c = self.master_config
225 c = self.config
338 226
339 self.import_statements()
340 reusing = c.Global.reuse_files
227 self.do_import_statements()
228 reusing = self.reuse_files
341 229 if reusing:
342 230 try:
343 231 self.load_config_from_json()
344 232 except (AssertionError,IOError):
345 233 reusing=False
346 234 # check again, because reusing may have failed:
347 235 if reusing:
348 236 pass
349 elif c.Global.secure:
350 keyfile = os.path.join(c.Global.security_dir, c.Global.exec_key)
237 elif self.secure:
351 238 key = str(uuid.uuid4())
352 with open(keyfile, 'w') as f:
353 f.write(key)
354 os.chmod(keyfile, stat.S_IRUSR|stat.S_IWUSR)
355 c.SessionFactory.exec_key = key
239 # keyfile = os.path.join(self.cluster_dir.security_dir, self.exec_key)
240 # with open(keyfile, 'w') as f:
241 # f.write(key)
242 # os.chmod(keyfile, stat.S_IRUSR|stat.S_IWUSR)
243 c.StreamSession.key = key
356 244 else:
357 c.SessionFactory.exec_key = ''
358 key = None
245 key = c.StreamSession.key = ''
359 246
360 247 try:
361 self.factory = ControllerFactory(config=c, logname=self.log.name)
362 self.start_logging()
363 self.factory.construct()
248 self.factory = HubFactory(config=c, log=self.log)
249 # self.start_logging()
250 self.factory.init_hub()
364 251 except:
365 252 self.log.error("Couldn't construct the Controller", exc_info=True)
366 253 self.exit(1)
367 254
368 255 if not reusing:
369 256 # save to new json config files
370 257 f = self.factory
371 258 cdict = {'exec_key' : key,
372 'ssh' : c.Global.sshserver,
259 'ssh' : self.ssh_server,
373 260 'url' : "%s://%s:%s"%(f.client_transport, f.client_ip, f.regport),
374 'location' : c.Global.location
261 'location' : self.location
375 262 }
376 263 self.save_connection_dict('ipcontroller-client.json', cdict)
377 264 edict = cdict
378 265 edict['url']="%s://%s:%s"%((f.client_transport, f.client_ip, f.regport))
379 266 self.save_connection_dict('ipcontroller-engine.json', edict)
267
268 #
269 def init_schedulers(self):
270 children = self.children
271 mq = import_item(self.mq_class)
380 272
273 hub = self.factory
274 # maybe_inproc = 'inproc://monitor' if self.usethreads else self.monitor_url
275 # IOPub relay (in a Process)
276 q = mq(zmq.PUB, zmq.SUB, zmq.PUB, 'N/A','iopub')
277 q.bind_in(hub.client_info['iopub'])
278 q.bind_out(hub.engine_info['iopub'])
279 q.setsockopt_out(zmq.SUBSCRIBE, '')
280 q.connect_mon(hub.monitor_url)
281 q.daemon=True
282 children.append(q)
283
284 # Multiplexer Queue (in a Process)
285 q = mq(zmq.XREP, zmq.XREP, zmq.PUB, 'in', 'out')
286 q.bind_in(hub.client_info['mux'])
287 q.setsockopt_in(zmq.IDENTITY, 'mux')
288 q.bind_out(hub.engine_info['mux'])
289 q.connect_mon(hub.monitor_url)
290 q.daemon=True
291 children.append(q)
292
293 # Control Queue (in a Process)
294 q = mq(zmq.XREP, zmq.XREP, zmq.PUB, 'incontrol', 'outcontrol')
295 q.bind_in(hub.client_info['control'])
296 q.setsockopt_in(zmq.IDENTITY, 'control')
297 q.bind_out(hub.engine_info['control'])
298 q.connect_mon(hub.monitor_url)
299 q.daemon=True
300 children.append(q)
301 try:
302 scheme = self.config.TaskScheduler.scheme_name
303 except AttributeError:
304 scheme = TaskScheduler.scheme_name.get_default_value()
305 # Task Queue (in a Process)
306 if scheme == 'pure':
307 self.log.warn("task::using pure XREQ Task scheduler")
308 q = mq(zmq.XREP, zmq.XREQ, zmq.PUB, 'intask', 'outtask')
309 # q.setsockopt_out(zmq.HWM, hub.hwm)
310 q.bind_in(hub.client_info['task'][1])
311 q.setsockopt_in(zmq.IDENTITY, 'task')
312 q.bind_out(hub.engine_info['task'])
313 q.connect_mon(hub.monitor_url)
314 q.daemon=True
315 children.append(q)
316 elif scheme == 'none':
317 self.log.warn("task::using no Task scheduler")
318
319 else:
320 self.log.info("task::using Python %s Task scheduler"%scheme)
321 sargs = (hub.client_info['task'][1], hub.engine_info['task'],
322 hub.monitor_url, hub.client_info['notification'])
323 kwargs = dict(logname=self.log.name, loglevel=self.log_level,
324 config=dict(self.config))
325 q = Process(target=launch_scheduler, args=sargs, kwargs=kwargs)
326 q.daemon=True
327 children.append(q)
328
381 329
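A hedged config-file sketch of the two TaskScheduler settings consulted above (scheme_name is read in init_schedulers, hwm comes from the 'hwm' alias); the scheme names are those listed for the old --scheme option:

    # in ipcontroller_config.py
    c = get_config()
    c.TaskScheduler.scheme_name = 'lru'   # pure, lru, plainrandom, weighted, twobin, leastload
    c.TaskScheduler.hwm = 1               # max outstanding tasks per engine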
382 330 def save_urls(self):
383 331 """save the registration urls to files."""
384 c = self.master_config
332 c = self.config
385 333
386 sec_dir = c.Global.security_dir
334 sec_dir = self.cluster_dir.security_dir
387 335 cf = self.factory
388 336
389 337 with open(os.path.join(sec_dir, 'ipcontroller-engine.url'), 'w') as f:
390 338 f.write("%s://%s:%s"%(cf.engine_transport, cf.engine_ip, cf.regport))
391 339
392 340 with open(os.path.join(sec_dir, 'ipcontroller-client.url'), 'w') as f:
393 341 f.write("%s://%s:%s"%(cf.client_transport, cf.client_ip, cf.regport))
394 342
395 343
396 def import_statements(self):
397 statements = self.master_config.Global.import_statements
344 def do_import_statements(self):
345 statements = self.import_statements
398 346 for s in statements:
399 347 try:
400 348 self.log.msg("Executing statement: '%s'" % s)
401 349 exec s in globals(), locals()
402 350 except:
403 351 self.log.msg("Error running statement: %s" % s)
404 352
405 def start_logging(self):
406 super(IPControllerApp, self).start_logging()
407 if self.master_config.Global.log_url:
408 context = self.factory.context
409 lsock = context.socket(zmq.PUB)
410 lsock.connect(self.master_config.Global.log_url)
411 handler = PUBHandler(lsock)
412 handler.root_topic = 'controller'
413 handler.setLevel(self.log_level)
414 self.log.addHandler(handler)
415 #
416 def start_app(self):
353 # def start_logging(self):
354 # super(IPControllerApp, self).start_logging()
355 # if self.config.Global.log_url:
356 # context = self.factory.context
357 # lsock = context.socket(zmq.PUB)
358 # lsock.connect(self.config.Global.log_url)
359 # handler = PUBHandler(lsock)
360 # handler.root_topic = 'controller'
361 # handler.setLevel(self.log_level)
362 # self.log.addHandler(handler)
363 # #
364 def start(self):
417 365 # Start the subprocesses:
418 366 self.factory.start()
367 child_procs = []
368 for child in self.children:
369 child.start()
370 if isinstance(child, ProcessMonitoredQueue):
371 child_procs.append(child.launcher)
372 elif isinstance(child, Process):
373 child_procs.append(child)
374 if child_procs:
375 signal_children(child_procs)
376
419 377 self.write_pid_file(overwrite=True)
378
420 379 try:
421 380 self.factory.loop.start()
422 381 except KeyboardInterrupt:
423 382 self.log.critical("Interrupted, Exiting...\n")
424 383
425 384
426 385 def launch_new_instance():
427 386 """Create and run the IPython controller"""
428 387 app = IPControllerApp()
388 app.parse_command_line()
389 cl_config = app.config
390 # app.load_config_file()
391 app.init_clusterdir()
392 if app.config_file:
393 app.load_config_file(app.config_file)
394 else:
395 app.load_config_file(app.default_config_file_name, path=app.cluster_dir.location)
396 # command-line should *override* config file, but command-line is necessary
397 # to determine clusterdir, etc.
398 app.update_config(cl_config)
399
400 app.to_work_dir()
401 app.init_hub()
402 app.init_schedulers()
403
429 404 app.start()
430 405
431 406
432 407 if __name__ == '__main__':
433 408 launch_new_instance()
@@ -1,303 +1,301
1 1 #!/usr/bin/env python
2 2 # encoding: utf-8
3 3 """
4 4 The IPython engine application
5 5 """
6 6
7 7 #-----------------------------------------------------------------------------
8 8 # Copyright (C) 2008-2009 The IPython Development Team
9 9 #
10 10 # Distributed under the terms of the BSD License. The full license is in
11 11 # the file COPYING, distributed as part of this software.
12 12 #-----------------------------------------------------------------------------
13 13
14 14 #-----------------------------------------------------------------------------
15 15 # Imports
16 16 #-----------------------------------------------------------------------------
17 17
18 18 import json
19 19 import os
20 20 import sys
21 21
22 22 import zmq
23 23 from zmq.eventloop import ioloop
24 24
25 25 from IPython.parallel.apps.clusterdir import (
26 ApplicationWithClusterDir,
27 ClusterDirConfigLoader
26 ClusterDirApplication,
27 ClusterDir,
28 base_aliases,
29 # ClusterDirConfigLoader
28 30 )
29 31 from IPython.zmq.log import EnginePUBHandler
30 32
31 from IPython.parallel import factory
33 from IPython.config.configurable import Configurable
34 from IPython.parallel.streamsession import StreamSession
32 35 from IPython.parallel.engine.engine import EngineFactory
33 36 from IPython.parallel.engine.streamkernel import Kernel
34 37 from IPython.parallel.util import disambiguate_url
35 38
36 39 from IPython.utils.importstring import import_item
40 from IPython.utils.traitlets import Str, Bool, Unicode, Dict, List, CStr
37 41
38 42
39 43 #-----------------------------------------------------------------------------
40 44 # Module level variables
41 45 #-----------------------------------------------------------------------------
42 46
43 47 #: The default config file name for this application
44 48 default_config_file_name = u'ipengine_config.py'
45 49
50 _description = """Start an IPython engine for parallel computing.\n\n
51
52 IPython engines run in parallel and perform computations on behalf of a client
53 and controller. A controller needs to be started before the engines. The
54 engine can be configured using command line options or using a cluster
55 directory. Cluster directories contain config, log and security files and are
56 usually located in your ipython directory and named as "cluster_<profile>".
57 See the `profile` and `cluster_dir` options for details.
58 """
59
60
61 #-----------------------------------------------------------------------------
62 # MPI configuration
63 #-----------------------------------------------------------------------------
46 64
47 65 mpi4py_init = """from mpi4py import MPI as mpi
48 66 mpi.size = mpi.COMM_WORLD.Get_size()
49 67 mpi.rank = mpi.COMM_WORLD.Get_rank()
50 68 """
51 69
52 70
53 71 pytrilinos_init = """from PyTrilinos import Epetra
54 72 class SimpleStruct:
55 73 pass
56 74 mpi = SimpleStruct()
57 75 mpi.rank = 0
58 76 mpi.size = 0
59 77 """
60 78
79 class MPI(Configurable):
80 """Configurable for MPI initialization"""
81 use = Str('', config=True,
82 help='How to enable MPI (mpi4py, pytrilinos, or empty string to disable).'
83 )
61 84
62 _description = """Start an IPython engine for parallel computing.\n\n
85 def _on_use_changed(self, old, new):
86 # load default init script if it's not set
87 if not self.init_script:
88 self.init_script = self.default_inits.get(new, '')
89
90 init_script = Str('', config=True,
91 help="Initialization code for MPI")
92
93 default_inits = Dict({'mpi4py' : mpi4py_init, 'pytrilinos':pytrilinos_init},
94 config=True)
63 95
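A hedged example of enabling MPI on an engine via the MPI configurable above: setting use triggers _on_use_changed, which fills init_script from default_inits, and init_mpi() later exec's that script.

    # in ipengine_config.py (or mpi=mpi4py on the command line, via the alias below)
    c = get_config()
    c.MPI.use = 'mpi4py'      # or 'pytrilinos'; leave empty to disable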
64 IPython engines run in parallel and perform computations on behalf of a client
65 and controller. A controller needs to be started before the engines. The
66 engine can be configured using command line options or using a cluster
67 directory. Cluster directories contain config, log and security files and are
68 usually located in your ipython directory and named as "cluster_<profile>".
69 See the --profile and --cluster-dir options for details.
70 """
71 96
72 97 #-----------------------------------------------------------------------------
73 # Command line options
98 # Main application
74 99 #-----------------------------------------------------------------------------
75 100
76 101
77 class IPEngineAppConfigLoader(ClusterDirConfigLoader):
78
79 def _add_arguments(self):
80 super(IPEngineAppConfigLoader, self)._add_arguments()
81 paa = self.parser.add_argument
82 # Controller config
83 paa('--file', '-f',
84 type=unicode, dest='Global.url_file',
85 help='The full location of the file containing the connection information fo '
86 'controller. If this is not given, the file must be in the '
87 'security directory of the cluster directory. This location is '
88 'resolved using the --profile and --app-dir options.',
89 metavar='Global.url_file')
90 # MPI
91 paa('--mpi',
92 type=str, dest='MPI.use',
93 help='How to enable MPI (mpi4py, pytrilinos, or empty string to disable).',
94 metavar='MPI.use')
95 # Global config
96 paa('--log-to-file',
97 action='store_true', dest='Global.log_to_file',
98 help='Log to a file in the log directory (default is stdout)')
99 paa('--log-url',
100 dest='Global.log_url',
101 help="url of ZMQ logger, as started with iploggerz")
102 # paa('--execkey',
103 # type=str, dest='Global.exec_key',
104 # help='path to a file containing an execution key.',
105 # metavar='keyfile')
106 # paa('--no-secure',
107 # action='store_false', dest='Global.secure',
108 # help='Turn off execution keys.')
109 # paa('--secure',
110 # action='store_true', dest='Global.secure',
111 # help='Turn on execution keys (default).')
112 # init command
113 paa('-c',
114 type=str, dest='Global.extra_exec_lines',
102 class IPEngineApp(ClusterDirApplication):
103
104 app_name = Unicode(u'ipengine')
105 description = Unicode(_description)
106 default_config_file_name = default_config_file_name
107 classes = List([ClusterDir, StreamSession, EngineFactory, Kernel, MPI])
108
109 startup_script = Unicode(u'', config=True,
110 help='specify a script to be run at startup')
111 startup_command = Str('', config=True,
115 112 help='specify a command to be run at startup')
116 paa('-s',
117 type=unicode, dest='Global.extra_exec_file',
118 help='specify a script to be run at startup')
119
120 factory.add_session_arguments(self.parser)
121 factory.add_registration_arguments(self.parser)
122 113
114 url_file = Unicode(u'', config=True,
115 help="""The full location of the file containing the connection information for
116 the controller. If this is not given, the file must be in the
117 security directory of the cluster directory. This location is
118 resolved using the `profile` or `cluster_dir` options.""",
119 )
123 120
124 #-----------------------------------------------------------------------------
125 # Main application
126 #-----------------------------------------------------------------------------
121 url_file_name = Unicode(u'ipcontroller-engine.json')
127 122
123 aliases = Dict(dict(
124 config = 'IPEngineApp.config_file',
125 file = 'IPEngineApp.url_file',
126 c = 'IPEngineApp.startup_command',
127 s = 'IPEngineApp.startup_script',
128 128
129 class IPEngineApp(ApplicationWithClusterDir):
129 ident = 'StreamSession.session',
130 user = 'StreamSession.username',
131 exec_key = 'StreamSession.keyfile',
130 132
131 name = u'ipengine'
132 description = _description
133 command_line_loader = IPEngineAppConfigLoader
134 default_config_file_name = default_config_file_name
135 auto_create_cluster_dir = True
136
137 def create_default_config(self):
138 super(IPEngineApp, self).create_default_config()
139
140 # The engine should not clean logs as we don't want to remove the
141 # active log files of other running engines.
142 self.default_config.Global.clean_logs = False
143 self.default_config.Global.secure = True
144
145 # Global config attributes
146 self.default_config.Global.exec_lines = []
147 self.default_config.Global.extra_exec_lines = ''
148 self.default_config.Global.extra_exec_file = u''
149
150 # Configuration related to the controller
151 # This must match the filename (path not included) that the controller
152 # used for the FURL file.
153 self.default_config.Global.url_file = u''
154 self.default_config.Global.url_file_name = u'ipcontroller-engine.json'
155 # If given, this is the actual location of the controller's FURL file.
156 # If not, this is computed using the profile, app_dir and furl_file_name
157 # self.default_config.Global.key_file_name = u'exec_key.key'
158 # self.default_config.Global.key_file = u''
159
160 # MPI related config attributes
161 self.default_config.MPI.use = ''
162 self.default_config.MPI.mpi4py = mpi4py_init
163 self.default_config.MPI.pytrilinos = pytrilinos_init
164
165 def post_load_command_line_config(self):
166 pass
167
168 def pre_construct(self):
169 super(IPEngineApp, self).pre_construct()
170 # self.find_cont_url_file()
171 self.find_url_file()
172 if self.master_config.Global.extra_exec_lines:
173 self.master_config.Global.exec_lines.append(self.master_config.Global.extra_exec_lines)
174 if self.master_config.Global.extra_exec_file:
175 enc = sys.getfilesystemencoding() or 'utf8'
176 cmd="execfile(%r)"%self.master_config.Global.extra_exec_file.encode(enc)
177 self.master_config.Global.exec_lines.append(cmd)
133 url = 'EngineFactory.url',
134 ip = 'EngineFactory.ip',
135 transport = 'EngineFactory.transport',
136 port = 'EngineFactory.regport',
137 location = 'EngineFactory.location',
138
139 timeout = 'EngineFactory.timeout',
140
141 profile = "ClusterDir.profile",
142 cluster_dir = 'ClusterDir.location',
143
144 mpi = 'MPI.use',
145
146 log_level = 'IPEngineApp.log_level',
147 ))
178 148
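A hedged sketch of the two ways an engine locates the controller, per the url_file trait and find_url_file() below: an explicit file passed via the 'file' alias, or the default name inside the cluster dir's security directory (paths shown assume the default profile):

    import os

    explicit_url_file = u'/path/to/ipcontroller-engine.json'   # file=... on the command line
    default_url_file  = os.path.join('cluster_default', 'security',
                                     'ipcontroller-engine.json')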
179 149 # def find_key_file(self):
180 150 # """Set the key file.
181 151 #
182 152 # Here we don't try to actually see if it exists or is valid, as that
183 153 # is handled by the connection logic.
184 154 # """
185 155 # config = self.master_config
186 156 # # Find the actual controller key file
187 157 # if not config.Global.key_file:
188 158 # try_this = os.path.join(
189 159 # config.Global.cluster_dir,
190 160 # config.Global.security_dir,
191 161 # config.Global.key_file_name
192 162 # )
193 163 # config.Global.key_file = try_this
194 164
195 165 def find_url_file(self):
196 166 """Set the key file.
197 167
198 168 Here we don't try to actually see if it exists or is valid, as that
199 169 is handled by the connection logic.
200 170 """
201 config = self.master_config
171 config = self.config
202 172 # Find the actual controller key file
203 if not config.Global.url_file:
204 try_this = os.path.join(
205 config.Global.cluster_dir,
206 config.Global.security_dir,
207 config.Global.url_file_name
173 if not self.url_file:
174 self.url_file = os.path.join(
175 self.cluster_dir.security_dir,
176 self.url_file_name
208 177 )
209 config.Global.url_file = try_this
210 178
211 def construct(self):
179 def init_engine(self):
212 180 # This is the working dir by now.
213 181 sys.path.insert(0, '')
214 config = self.master_config
182 config = self.config
183 # print config
184 self.find_url_file()
185
215 186 # if os.path.exists(config.Global.key_file) and config.Global.secure:
216 187 # config.SessionFactory.exec_key = config.Global.key_file
217 if os.path.exists(config.Global.url_file):
218 with open(config.Global.url_file) as f:
188 if os.path.exists(self.url_file):
189 with open(self.url_file) as f:
219 190 d = json.loads(f.read())
220 191 for k,v in d.iteritems():
221 192 if isinstance(v, unicode):
222 193 d[k] = v.encode()
223 194 if d['exec_key']:
224 config.SessionFactory.exec_key = d['exec_key']
195 config.StreamSession.key = d['exec_key']
225 196 d['url'] = disambiguate_url(d['url'], d['location'])
226 config.RegistrationFactory.url=d['url']
197 config.EngineFactory.url = d['url']
227 198 config.EngineFactory.location = d['location']
228 199
200 try:
201 exec_lines = config.Kernel.exec_lines
202 except AttributeError:
203 config.Kernel.exec_lines = []
204 exec_lines = config.Kernel.exec_lines
229 205
230
231 config.Kernel.exec_lines = config.Global.exec_lines
232
233 self.start_mpi()
206 if self.startup_script:
207 enc = sys.getfilesystemencoding() or 'utf8'
208 cmd="execfile(%r)"%self.startup_script.encode(enc)
209 exec_lines.append(cmd)
210 if self.startup_command:
211 exec_lines.append(self.startup_command)
234 212
235 # Create the underlying shell class and EngineService
213 # Create the underlying shell class and Engine
236 214 # shell_class = import_item(self.master_config.Global.shell_class)
215 # print self.config
237 216 try:
238 self.engine = EngineFactory(config=config, logname=self.log.name)
217 self.engine = EngineFactory(config=config, log=self.log)
239 218 except:
240 219 self.log.error("Couldn't start the Engine", exc_info=True)
241 220 self.exit(1)
242 221
243 self.start_logging()
222 # self.start_logging()
244 223
245 224 # Create the service hierarchy
246 225 # self.main_service = service.MultiService()
247 226 # self.engine_service.setServiceParent(self.main_service)
248 227 # self.tub_service = Tub()
249 228 # self.tub_service.setServiceParent(self.main_service)
250 229 # # This needs to be called before the connection is initiated
251 230 # self.main_service.startService()
252 231
253 232 # This initiates the connection to the controller and calls
254 233 # register_engine to tell the controller we are ready to do work
255 234 # self.engine_connector = EngineConnector(self.tub_service)
256 235
257 236 # self.log.info("Using furl file: %s" % self.master_config.Global.furl_file)
258 237
259 238 # reactor.callWhenRunning(self.call_connect)
260 239
261
262 def start_logging(self):
263 super(IPEngineApp, self).start_logging()
264 if self.master_config.Global.log_url:
265 context = self.engine.context
266 lsock = context.socket(zmq.PUB)
267 lsock.connect(self.master_config.Global.log_url)
268 handler = EnginePUBHandler(self.engine, lsock)
269 handler.setLevel(self.log_level)
270 self.log.addHandler(handler)
271
272 def start_mpi(self):
240 # def start_logging(self):
241 # super(IPEngineApp, self).start_logging()
242 # if self.master_config.Global.log_url:
243 # context = self.engine.context
244 # lsock = context.socket(zmq.PUB)
245 # lsock.connect(self.master_config.Global.log_url)
246 # handler = EnginePUBHandler(self.engine, lsock)
247 # handler.setLevel(self.log_level)
248 # self.log.addHandler(handler)
249 #
250 def init_mpi(self):
273 251 global mpi
274 mpikey = self.master_config.MPI.use
275 mpi_import_statement = self.master_config.MPI.get(mpikey, None)
276 if mpi_import_statement is not None:
252 self.mpi = MPI(config=self.config)
253
254 mpi_import_statement = self.mpi.init_script
255 if mpi_import_statement:
277 256 try:
278 257 self.log.info("Initializing MPI:")
279 258 self.log.info(mpi_import_statement)
280 259 exec mpi_import_statement in globals()
281 260 except:
282 261 mpi = None
283 262 else:
284 263 mpi = None
285 264
286 265
287 def start_app(self):
266 def start(self):
288 267 self.engine.start()
289 268 try:
290 269 self.engine.loop.start()
291 270 except KeyboardInterrupt:
292 271 self.log.critical("Engine Interrupted, shutting down...\n")
293 272
294 273
295 274 def launch_new_instance():
296 """Create and run the IPython controller"""
275 """Create and run the IPython engine"""
297 276 app = IPEngineApp()
277 app.parse_command_line()
278 cl_config = app.config
279 app.init_clusterdir()
280 # app.load_config_file()
281 # print app.config
282 if app.config_file:
283 app.load_config_file(app.config_file)
284 else:
285 app.load_config_file(app.default_config_file_name, path=app.cluster_dir.location)
286
287 # command-line should *override* config file, but command-line is necessary
288 # to determine clusterdir, etc.
289 app.update_config(cl_config)
290
291 # print app.config
292 app.to_work_dir()
293 app.init_mpi()
294 app.init_engine()
295 print app.config
298 296 app.start()
299 297
300 298
301 299 if __name__ == '__main__':
302 300 launch_new_instance()
303 301
@@ -1,132 +1,132
1 1 #!/usr/bin/env python
2 2 # encoding: utf-8
3 3 """
4 4 A simple IPython logger application
5 5 """
6 6
7 7 #-----------------------------------------------------------------------------
8 8 # Copyright (C) 2011 The IPython Development Team
9 9 #
10 10 # Distributed under the terms of the BSD License. The full license is in
11 11 # the file COPYING, distributed as part of this software.
12 12 #-----------------------------------------------------------------------------
13 13
14 14 #-----------------------------------------------------------------------------
15 15 # Imports
16 16 #-----------------------------------------------------------------------------
17 17
18 18 import os
19 19 import sys
20 20
21 21 import zmq
22 22
23 23 from IPython.parallel.apps.clusterdir import (
24 ApplicationWithClusterDir,
24 ClusterDirApplication,
25 25 ClusterDirConfigLoader
26 26 )
27 27 from IPython.parallel.apps.logwatcher import LogWatcher
28 28
29 29 #-----------------------------------------------------------------------------
30 30 # Module level variables
31 31 #-----------------------------------------------------------------------------
32 32
33 33 #: The default config file name for this application
34 34 default_config_file_name = u'iplogger_config.py'
35 35
36 36 _description = """Start an IPython logger for parallel computing.\n\n
37 37
38 38 IPython controllers and engines (and your own processes) can broadcast log messages
39 39 by registering a `zmq.log.handlers.PUBHandler` with the `logging` module. The
40 40 logger can be configured using command line options or using a cluster
41 41 directory. Cluster directories contain config, log and security files and are
42 42 usually located in your ipython directory and named as "cluster_<profile>".
43 43 See the --profile and --cluster-dir options for details.
44 44 """
45 45
46 46 #-----------------------------------------------------------------------------
47 47 # Command line options
48 48 #-----------------------------------------------------------------------------
49 49
50 50
51 51 class IPLoggerAppConfigLoader(ClusterDirConfigLoader):
52 52
53 53 def _add_arguments(self):
54 54 super(IPLoggerAppConfigLoader, self)._add_arguments()
55 55 paa = self.parser.add_argument
56 56 # Controller config
57 57 paa('--url',
58 58 type=str, dest='LogWatcher.url',
59 59 help='The url the LogWatcher will listen on',
60 60 )
61 61 # Topics config
62 62 paa('--topics',
63 63 type=str, dest='LogWatcher.topics', nargs='+',
64 64 help='What topics to subscribe to',
65 65 metavar='topics')
66 66 # Global config
67 67 paa('--log-to-file',
68 68 action='store_true', dest='Global.log_to_file',
69 69 help='Log to a file in the log directory (default is stdout)')
70 70
71 71
72 72 #-----------------------------------------------------------------------------
73 73 # Main application
74 74 #-----------------------------------------------------------------------------
75 75
76 76
77 class IPLoggerApp(ApplicationWithClusterDir):
77 class IPLoggerApp(ClusterDirApplication):
78 78
79 79 name = u'iploggerz'
80 80 description = _description
81 81 command_line_loader = IPLoggerAppConfigLoader
82 82 default_config_file_name = default_config_file_name
83 83 auto_create_cluster_dir = True
84 84
85 85 def create_default_config(self):
86 86 super(IPLoggerApp, self).create_default_config()
87 87
88 88 # The engine should not clean logs as we don't want to remove the
89 89 # active log files of other running engines.
90 90 self.default_config.Global.clean_logs = False
91 91
92 92 # If given, this is the actual location of the logger's URL file.
93 93 # If not, this is computed using the profile, app_dir and furl_file_name
94 94 self.default_config.Global.url_file_name = u'iplogger.url'
95 95 self.default_config.Global.url_file = u''
96 96
97 97 def post_load_command_line_config(self):
98 98 pass
99 99
100 100 def pre_construct(self):
101 101 super(IPLoggerApp, self).pre_construct()
102 102
103 103 def construct(self):
104 104 # This is the working dir by now.
105 105 sys.path.insert(0, '')
106 106
107 107 self.start_logging()
108 108
109 109 try:
110 110 self.watcher = LogWatcher(config=self.master_config, logname=self.log.name)
111 111 except:
112 112 self.log.error("Couldn't start the LogWatcher", exc_info=True)
113 113 self.exit(1)
114 114
115 115
116 116 def start_app(self):
117 117 try:
118 118 self.watcher.start()
119 119 self.watcher.loop.start()
120 120 except KeyboardInterrupt:
121 121 self.log.critical("Logging Interrupted, shutting down...\n")
122 122
123 123
124 124 def launch_new_instance():
125 125 """Create and run the IPython LogWatcher"""
126 126 app = IPLoggerApp()
127 127 app.start()
128 128
129 129
130 130 if __name__ == '__main__':
131 131 launch_new_instance()
132 132
@@ -1,996 +1,1070
1 1 #!/usr/bin/env python
2 2 # encoding: utf-8
3 3 """
4 4 Facilities for launching IPython processes asynchronously.
5 5 """
6 6
7 7 #-----------------------------------------------------------------------------
8 8 # Copyright (C) 2008-2009 The IPython Development Team
9 9 #
10 10 # Distributed under the terms of the BSD License. The full license is in
11 11 # the file COPYING, distributed as part of this software.
12 12 #-----------------------------------------------------------------------------
13 13
14 14 #-----------------------------------------------------------------------------
15 15 # Imports
16 16 #-----------------------------------------------------------------------------
17 17
18 18 import copy
19 19 import logging
20 20 import os
21 21 import re
22 22 import stat
23 23
24 24 # signal imports, handling various platforms, versions
25 25
26 26 from signal import SIGINT, SIGTERM
27 27 try:
28 28 from signal import SIGKILL
29 29 except ImportError:
30 30 # Windows
31 31 SIGKILL=SIGTERM
32 32
33 33 try:
34 34 # Windows >= 2.7, 3.2
35 35 from signal import CTRL_C_EVENT as SIGINT
36 36 except ImportError:
37 37 pass
38 38
39 39 from subprocess import Popen, PIPE, STDOUT
40 40 try:
41 41 from subprocess import check_output
42 42 except ImportError:
43 43 # pre-2.7, define check_output with Popen
44 44 def check_output(*args, **kwargs):
45 45 kwargs.update(dict(stdout=PIPE))
46 46 p = Popen(*args, **kwargs)
47 47 out,err = p.communicate()
48 48 return out
49 49
50 50 from zmq.eventloop import ioloop
51 51
52 52 from IPython.external import Itpl
53 53 # from IPython.config.configurable import Configurable
54 54 from IPython.utils.traitlets import Any, Str, Int, List, Unicode, Dict, Instance, CUnicode
55 55 from IPython.utils.path import get_ipython_module_path
56 56 from IPython.utils.process import find_cmd, pycmd2argv, FindCmdError
57 57
58 58 from IPython.parallel.factory import LoggingFactory
59 59
60 60 from .win32support import forward_read_events
61 61
62 62 from .winhpcjob import IPControllerTask, IPEngineTask, IPControllerJob, IPEngineSetJob
63 63
64 64 WINDOWS = os.name == 'nt'
65 65
66 66 #-----------------------------------------------------------------------------
67 67 # Paths to the kernel apps
68 68 #-----------------------------------------------------------------------------
69 69
70 70
71 71 ipcluster_cmd_argv = pycmd2argv(get_ipython_module_path(
72 72 'IPython.parallel.apps.ipclusterapp'
73 73 ))
74 74
75 75 ipengine_cmd_argv = pycmd2argv(get_ipython_module_path(
76 76 'IPython.parallel.apps.ipengineapp'
77 77 ))
78 78
79 79 ipcontroller_cmd_argv = pycmd2argv(get_ipython_module_path(
80 80 'IPython.parallel.apps.ipcontrollerapp'
81 81 ))
82 82
83 83 #-----------------------------------------------------------------------------
84 84 # Base launchers and errors
85 85 #-----------------------------------------------------------------------------
86 86
87 87
88 88 class LauncherError(Exception):
89 89 pass
90 90
91 91
92 92 class ProcessStateError(LauncherError):
93 93 pass
94 94
95 95
96 96 class UnknownStatus(LauncherError):
97 97 pass
98 98
99 99
100 100 class BaseLauncher(LoggingFactory):
101 101 """An asbtraction for starting, stopping and signaling a process."""
102 102
103 103 # In all of the launchers, the work_dir is where child processes will be
104 104 # run. This will usually be the cluster_dir, but may not be. Any work_dir
105 105 # passed into the __init__ method will override the config value.
106 106 # This should not be used to set the work_dir for the actual engine
107 107 # and controller. Instead, use their own config files or the
108 108 # controller_args, engine_args attributes of the launchers to add
109 # the --work-dir option.
109 # the work_dir option.
110 110 work_dir = Unicode(u'.')
111 111 loop = Instance('zmq.eventloop.ioloop.IOLoop')
112 112
113 113 start_data = Any()
114 114 stop_data = Any()
115 115
116 116 def _loop_default(self):
117 117 return ioloop.IOLoop.instance()
118 118
119 119 def __init__(self, work_dir=u'.', config=None, **kwargs):
120 120 super(BaseLauncher, self).__init__(work_dir=work_dir, config=config, **kwargs)
121 121 self.state = 'before' # can be before, running, after
122 122 self.stop_callbacks = []
123 123 self.start_data = None
124 124 self.stop_data = None
125 125
126 126 @property
127 127 def args(self):
128 128 """A list of cmd and args that will be used to start the process.
129 129
130 130 This is what is passed to :func:`spawnProcess` and the first element
131 131 will be the process name.
132 132 """
133 133 return self.find_args()
134 134
135 135 def find_args(self):
136 136 """The ``.args`` property calls this to find the args list.
137 137
138 138 Subclasses should implement this to construct the cmd and args.
139 139 """
140 140 raise NotImplementedError('find_args must be implemented in a subclass')
141 141
142 142 @property
143 143 def arg_str(self):
144 144 """The string form of the program arguments."""
145 145 return ' '.join(self.args)
146 146
147 147 @property
148 148 def running(self):
149 149 """Am I running."""
150 150 if self.state == 'running':
151 151 return True
152 152 else:
153 153 return False
154 154
155 155 def start(self):
156 156 """Start the process.
157 157
158 158 This must return a deferred that fires with information about the
159 159 process starting (like a pid, job id, etc.).
160 160 """
161 161 raise NotImplementedError('start must be implemented in a subclass')
162 162
163 163 def stop(self):
164 164 """Stop the process and notify observers of stopping.
165 165
166 166 This must return a deferred that fires with information about the
167 167 process stopping, like errors that occur while the process is
168 168 attempting to be shut down. This deferred won't fire when the process
169 169 actually stops. To observe the actual process stopping, see
170 170 :func:`on_stop`.
171 171 """
172 172 raise NotImplementedError('stop must be implemented in a subclass')
173 173
174 174 def on_stop(self, f):
175 175 """Get a deferred that will fire when the process stops.
176 176
177 177 The deferred will fire with data that contains information about
178 178 the exit status of the process.
179 179 """
180 180 if self.state=='after':
181 181 return f(self.stop_data)
182 182 else:
183 183 self.stop_callbacks.append(f)
184 184
185 185 def notify_start(self, data):
186 186 """Call this to trigger startup actions.
187 187
188 188 This logs the process startup and sets the state to 'running'. It is
189 189 a pass-through so it can be used as a callback.
190 190 """
191 191
192 192 self.log.info('Process %r started: %r' % (self.args[0], data))
193 193 self.start_data = data
194 194 self.state = 'running'
195 195 return data
196 196
197 197 def notify_stop(self, data):
198 198 """Call this to trigger process stop actions.
199 199
200 200 This logs the process stopping and sets the state to 'after'. Call
201 201 this to trigger all the deferreds from :func:`on_stop`."""
202 202
203 203 self.log.info('Process %r stopped: %r' % (self.args[0], data))
204 204 self.stop_data = data
205 205 self.state = 'after'
206 206 for i in range(len(self.stop_callbacks)):
207 207 d = self.stop_callbacks.pop()
208 208 d(data)
209 209 return data
210 210
211 211 def signal(self, sig):
212 212 """Signal the process.
213 213
214 214 Return a semi-meaningless deferred after signaling the process.
215 215
216 216 Parameters
217 217 ----------
218 218 sig : str or int
219 219 'KILL', 'INT', etc., or any signal number
220 220 """
221 221 raise NotImplementedError('signal must be implemented in a subclass')
222 222
223 223
224 224 #-----------------------------------------------------------------------------
225 225 # Local process launchers
226 226 #-----------------------------------------------------------------------------
227 227
228 228
229 229 class LocalProcessLauncher(BaseLauncher):
230 230 """Start and stop an external process in an asynchronous manner.
231 231
232 232 This will launch the external process with a working directory of
233 233 ``self.work_dir``.
234 234 """
235 235
236 236 # This is used to construct self.args, which is passed to
237 237 # spawnProcess.
238 238 cmd_and_args = List([])
239 239 poll_frequency = Int(100) # in ms
240 240
241 241 def __init__(self, work_dir=u'.', config=None, **kwargs):
242 242 super(LocalProcessLauncher, self).__init__(
243 243 work_dir=work_dir, config=config, **kwargs
244 244 )
245 245 self.process = None
246 246 self.start_deferred = None
247 247 self.poller = None
248 248
249 249 def find_args(self):
250 250 return self.cmd_and_args
251 251
252 252 def start(self):
253 253 if self.state == 'before':
254 254 self.process = Popen(self.args,
255 255 stdout=PIPE,stderr=PIPE,stdin=PIPE,
256 256 env=os.environ,
257 257 cwd=self.work_dir
258 258 )
259 259 if WINDOWS:
260 260 self.stdout = forward_read_events(self.process.stdout)
261 261 self.stderr = forward_read_events(self.process.stderr)
262 262 else:
263 263 self.stdout = self.process.stdout.fileno()
264 264 self.stderr = self.process.stderr.fileno()
265 265 self.loop.add_handler(self.stdout, self.handle_stdout, self.loop.READ)
266 266 self.loop.add_handler(self.stderr, self.handle_stderr, self.loop.READ)
267 267 self.poller = ioloop.PeriodicCallback(self.poll, self.poll_frequency, self.loop)
268 268 self.poller.start()
269 269 self.notify_start(self.process.pid)
270 270 else:
271 271 s = 'The process was already started and has state: %r' % self.state
272 272 raise ProcessStateError(s)
273 273
274 274 def stop(self):
275 275 return self.interrupt_then_kill()
276 276
277 277 def signal(self, sig):
278 278 if self.state == 'running':
279 279 if WINDOWS and sig != SIGINT:
280 280 # use Windows tree-kill for better child cleanup
281 281 check_output(['taskkill', '-pid', str(self.process.pid), '-t', '-f'])
282 282 else:
283 283 self.process.send_signal(sig)
284 284
285 285 def interrupt_then_kill(self, delay=2.0):
286 286 """Send INT, wait a delay and then send KILL."""
287 287 try:
288 288 self.signal(SIGINT)
289 289 except Exception:
290 290 self.log.debug("interrupt failed")
291 291 pass
292 292 self.killer = ioloop.DelayedCallback(lambda : self.signal(SIGKILL), delay*1000, self.loop)
293 293 self.killer.start()
294 294
295 295 # callbacks, etc:
296 296
297 297 def handle_stdout(self, fd, events):
298 298 if WINDOWS:
299 299 line = self.stdout.recv()
300 300 else:
301 301 line = self.process.stdout.readline()
302 302 # a stopped process will be readable but return empty strings
303 303 if line:
304 304 self.log.info(line[:-1])
305 305 else:
306 306 self.poll()
307 307
308 308 def handle_stderr(self, fd, events):
309 309 if WINDOWS:
310 310 line = self.stderr.recv()
311 311 else:
312 312 line = self.process.stderr.readline()
313 313 # a stopped process will be readable but return empty strings
314 314 if line:
315 315 self.log.error(line[:-1])
316 316 else:
317 317 self.poll()
318 318
319 319 def poll(self):
320 320 status = self.process.poll()
321 321 if status is not None:
322 322 self.poller.stop()
323 323 self.loop.remove_handler(self.stdout)
324 324 self.loop.remove_handler(self.stderr)
325 325 self.notify_stop(dict(exit_code=status, pid=self.process.pid))
326 326 return status
327 327
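A minimal sketch of driving a LocalProcessLauncher by hand, to make the flow above concrete: start() spawns the child with Popen, registers stdout/stderr handlers on the zmq ioloop, and a PeriodicCallback polls for exit. The import path and the echo command are illustrative assumptions, and it is assumed that BaseLauncher binds self.loop to the global ioloop instance.

from zmq.eventloop import ioloop
from IPython.parallel.apps.launcher import LocalProcessLauncher  # assumed module path

def report_exit(data):
    # data is the dict passed to notify_stop(): {'exit_code': ..., 'pid': ...}
    print "pid %(pid)s exited with code %(exit_code)s" % data
    ioloop.IOLoop.instance().stop()

lp = LocalProcessLauncher(work_dir=u'.')
lp.cmd_and_args = ['echo', 'hello']   # what find_args() will return
lp.on_stop(report_exit)               # callback fired from notify_stop()
lp.start()                            # spawns the process, wires handlers and the poller
ioloop.IOLoop.instance().start()      # poll_frequency (ms) drives exit detection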
328 328 class LocalControllerLauncher(LocalProcessLauncher):
329 329 """Launch a controller as a regular external process."""
330 330
331 controller_cmd = List(ipcontroller_cmd_argv, config=True)
331 controller_cmd = List(ipcontroller_cmd_argv, config=True,
332 help="""Popen command to launch ipcontroller.""")
332 333 # Command line arguments to ipcontroller.
333 controller_args = List(['--log-to-file','--log-level', str(logging.INFO)], config=True)
334 controller_args = List(['--log-to-file','log_level=%i'%logging.INFO], config=True,
335 help="""command-line args to pass to ipcontroller""")
334 336
335 337 def find_args(self):
336 338 return self.controller_cmd + self.controller_args
337 339
338 340 def start(self, cluster_dir):
339 341 """Start the controller by cluster_dir."""
340 self.controller_args.extend(['--cluster-dir', cluster_dir])
342 self.controller_args.extend(['cluster_dir=%s'%cluster_dir])
341 343 self.cluster_dir = unicode(cluster_dir)
342 344 self.log.info("Starting LocalControllerLauncher: %r" % self.args)
343 345 return super(LocalControllerLauncher, self).start()
344 346
345 347
346 348 class LocalEngineLauncher(LocalProcessLauncher):
347 349 """Launch a single engine as a regular externall process."""
348 350
349 engine_cmd = List(ipengine_cmd_argv, config=True)
351 engine_cmd = List(ipengine_cmd_argv, config=True,
352 help="""command to launch the Engine.""")
350 353 # Command line arguments for ipengine.
351 engine_args = List(
352 ['--log-to-file','--log-level', str(logging.INFO)], config=True
354 engine_args = List(['--log-to-file','log_level=%i'%logging.INFO], config=True,
355 help="command-line arguments to pass to ipengine"
353 356 )
354 357
355 358 def find_args(self):
356 359 return self.engine_cmd + self.engine_args
357 360
358 361 def start(self, cluster_dir):
359 362 """Start the engine by cluster_dir."""
360 self.engine_args.extend(['--cluster-dir', cluster_dir])
363 self.engine_args.extend(['cluster_dir=%s'%cluster_dir])
361 364 self.cluster_dir = unicode(cluster_dir)
362 365 return super(LocalEngineLauncher, self).start()
363 366
364 367
365 368 class LocalEngineSetLauncher(BaseLauncher):
366 369 """Launch a set of engines as regular external processes."""
367 370
368 371 # Command line arguments for ipengine.
369 372 engine_args = List(
370 ['--log-to-file','--log-level', str(logging.INFO)], config=True
373 ['--log-to-file','log_level=%i'%logging.INFO], config=True,
374 help="command-line arguments to pass to ipengine"
371 375 )
372 376 # launcher class
373 377 launcher_class = LocalEngineLauncher
374 378
375 379 launchers = Dict()
376 380 stop_data = Dict()
377 381
378 382 def __init__(self, work_dir=u'.', config=None, **kwargs):
379 383 super(LocalEngineSetLauncher, self).__init__(
380 384 work_dir=work_dir, config=config, **kwargs
381 385 )
382 386 self.stop_data = {}
383 387
384 388 def start(self, n, cluster_dir):
385 389 """Start n engines by profile or cluster_dir."""
386 390 self.cluster_dir = unicode(cluster_dir)
387 391 dlist = []
388 392 for i in range(n):
389 393 el = self.launcher_class(work_dir=self.work_dir, config=self.config, logname=self.log.name)
390 394 # Copy the engine args over to each engine launcher.
391 395 el.engine_args = copy.deepcopy(self.engine_args)
392 396 el.on_stop(self._notice_engine_stopped)
393 397 d = el.start(cluster_dir)
394 398 if i==0:
395 399 self.log.info("Starting LocalEngineSetLauncher: %r" % el.args)
396 400 self.launchers[i] = el
397 401 dlist.append(d)
398 402 self.notify_start(dlist)
399 403 # The consumeErrors here could be dangerous
400 404 # dfinal = gatherBoth(dlist, consumeErrors=True)
401 405 # dfinal.addCallback(self.notify_start)
402 406 return dlist
403 407
404 408 def find_args(self):
405 409 return ['engine set']
406 410
407 411 def signal(self, sig):
408 412 dlist = []
409 413 for el in self.launchers.itervalues():
410 414 d = el.signal(sig)
411 415 dlist.append(d)
412 416 # dfinal = gatherBoth(dlist, consumeErrors=True)
413 417 return dlist
414 418
415 419 def interrupt_then_kill(self, delay=1.0):
416 420 dlist = []
417 421 for el in self.launchers.itervalues():
418 422 d = el.interrupt_then_kill(delay)
419 423 dlist.append(d)
420 424 # dfinal = gatherBoth(dlist, consumeErrors=True)
421 425 return dlist
422 426
423 427 def stop(self):
424 428 return self.interrupt_then_kill()
425 429
426 430 def _notice_engine_stopped(self, data):
427 431 pid = data['pid']
428 432 for idx,el in self.launchers.iteritems():
429 433 if el.process.pid == pid:
430 434 break
431 435 self.launchers.pop(idx)
432 436 self.stop_data[idx] = data
433 437 if not self.launchers:
434 438 self.notify_stop(self.stop_data)
435 439
436 440
437 441 #-----------------------------------------------------------------------------
438 442 # MPIExec launchers
439 443 #-----------------------------------------------------------------------------
440 444
441 445
442 446 class MPIExecLauncher(LocalProcessLauncher):
443 447 """Launch an external process using mpiexec."""
444 448
445 # The mpiexec command to use in starting the process.
446 mpi_cmd = List(['mpiexec'], config=True)
447 # The command line arguments to pass to mpiexec.
448 mpi_args = List([], config=True)
449 # The program to start using mpiexec.
450 program = List(['date'], config=True)
451 # The command line argument to the program.
452 program_args = List([], config=True)
453 # The number of instances of the program to start.
454 n = Int(1, config=True)
449 mpi_cmd = List(['mpiexec'], config=True,
450 help="The mpiexec command to use in starting the process."
451 )
452 mpi_args = List([], config=True,
453 help="The command line arguments to pass to mpiexec."
454 )
455 program = List(['date'], config=True,
456 help="The program to start via mpiexec.")
457 program_args = List([], config=True,
458 help="The command line argument to the program."
459 )
460 n = Int(1)
455 461
456 462 def find_args(self):
457 463 """Build self.args using all the fields."""
458 464 return self.mpi_cmd + ['-n', str(self.n)] + self.mpi_args + \
459 465 self.program + self.program_args
460 466
461 467 def start(self, n):
462 468 """Start n instances of the program using mpiexec."""
463 469 self.n = n
464 470 return super(MPIExecLauncher, self).start()
465 471
466 472
467 473 class MPIExecControllerLauncher(MPIExecLauncher):
468 474 """Launch a controller using mpiexec."""
469 475
470 controller_cmd = List(ipcontroller_cmd_argv, config=True)
471 # Command line arguments to ipcontroller.
472 controller_args = List(['--log-to-file','--log-level', str(logging.INFO)], config=True)
473 n = Int(1, config=False)
476 controller_cmd = List(ipcontroller_cmd_argv, config=True,
477 help="Popen command to launch the Contropper"
478 )
479 controller_args = List(['--log-to-file','log_level=%i'%logging.INFO], config=True,
480 help="Command line arguments to pass to ipcontroller."
481 )
482 n = Int(1)
474 483
475 484 def start(self, cluster_dir):
476 485 """Start the controller by cluster_dir."""
477 self.controller_args.extend(['--cluster-dir', cluster_dir])
486 self.controller_args.extend(['cluster_dir=%s'%cluster_dir])
478 487 self.cluster_dir = unicode(cluster_dir)
479 488 self.log.info("Starting MPIExecControllerLauncher: %r" % self.args)
480 489 return super(MPIExecControllerLauncher, self).start(1)
481 490
482 491 def find_args(self):
483 492 return self.mpi_cmd + ['-n', str(self.n)] + self.mpi_args + \
484 493 self.controller_cmd + self.controller_args
485 494
486 495
487 496 class MPIExecEngineSetLauncher(MPIExecLauncher):
488 497
489 program = List(ipengine_cmd_argv, config=True)
490 # Command line arguments for ipengine.
498 program = List(ipengine_cmd_argv, config=True,
499 help="Popen command for ipengine"
500 )
491 501 program_args = List(
492 ['--log-to-file','--log-level', str(logging.INFO)], config=True
502 ['--log-to-file','log_level=%i'%logging.INFO], config=True,
503 help="Command line arguments for ipengine."
493 504 )
494 n = Int(1, config=True)
505 n = Int(1)
495 506
496 507 def start(self, n, cluster_dir):
497 508 """Start n engines by profile or cluster_dir."""
498 self.program_args.extend(['--cluster-dir', cluster_dir])
509 self.program_args.extend(['cluster_dir=%s'%cluster_dir])
499 510 self.cluster_dir = unicode(cluster_dir)
500 511 self.n = n
501 512 self.log.info('Starting MPIExecEngineSetLauncher: %r' % self.args)
502 513 return super(MPIExecEngineSetLauncher, self).start(n)
503 514
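MPIExecLauncher.find_args() assembles mpi_cmd + ['-n', str(n)] + mpi_args + program + program_args, so the MPI behaviour is tuned entirely through configuration. A sketch of the relevant settings in an ipcluster config file; the file name and the hostfile value are illustrative assumptions:

# ipcluster_config.py (assumed location inside the cluster directory)
c = get_config()
c.MPIExecEngineSetLauncher.mpi_cmd  = ['mpiexec']                  # the default shown explicitly
c.MPIExecEngineSetLauncher.mpi_args = ['--hostfile', 'mpi_hosts']  # hypothetical extra mpiexec args
# resulting command: mpiexec -n <n> --hostfile mpi_hosts <ipengine argv> <program_args>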
504 515 #-----------------------------------------------------------------------------
505 516 # SSH launchers
506 517 #-----------------------------------------------------------------------------
507 518
508 519 # TODO: Get SSH Launcher working again.
509 520
510 521 class SSHLauncher(LocalProcessLauncher):
511 522 """A minimal launcher for ssh.
512 523
513 524 To be useful this will probably have to be extended to use the ``sshx``
514 525 idea for environment variables. There could be other things this needs
515 526 as well.
516 527 """
517 528
518 ssh_cmd = List(['ssh'], config=True)
519 ssh_args = List(['-tt'], config=True)
520 program = List(['date'], config=True)
521 program_args = List([], config=True)
522 hostname = CUnicode('', config=True)
523 user = CUnicode('', config=True)
524 location = CUnicode('')
529 ssh_cmd = List(['ssh'], config=True,
530 help="command for starting ssh")
531 ssh_args = List(['-tt'], config=True,
532 help="args to pass to ssh")
533 program = List(['date'], config=True,
534 help="Program to launch via ssh")
535 program_args = List([], config=True,
536 help="args to pass to remote program")
537 hostname = CUnicode('', config=True,
538 help="hostname on which to launch the program")
539 user = CUnicode('', config=True,
540 help="username for ssh")
541 location = CUnicode('', config=True,
542 help="user@hostname location for ssh in one setting")
525 543
526 544 def _hostname_changed(self, name, old, new):
527 545 if self.user:
528 546 self.location = u'%s@%s' % (self.user, new)
529 547 else:
530 548 self.location = new
531 549
532 550 def _user_changed(self, name, old, new):
533 551 self.location = u'%s@%s' % (new, self.hostname)
534 552
535 553 def find_args(self):
536 554 return self.ssh_cmd + self.ssh_args + [self.location] + \
537 555 self.program + self.program_args
538 556
539 557 def start(self, cluster_dir, hostname=None, user=None):
540 558 self.cluster_dir = unicode(cluster_dir)
541 559 if hostname is not None:
542 560 self.hostname = hostname
543 561 if user is not None:
544 562 self.user = user
545 563
546 564 return super(SSHLauncher, self).start()
547 565
548 566 def signal(self, sig):
549 567 if self.state == 'running':
550 568 # send escaped ssh connection-closer
551 569 self.process.stdin.write('~.')
552 570 self.process.stdin.flush()
553 571
554 572
555 573
556 574 class SSHControllerLauncher(SSHLauncher):
557 575
558 program = List(ipcontroller_cmd_argv, config=True)
559 # Command line arguments to ipcontroller.
560 program_args = List(['-r', '--log-to-file','--log-level', str(logging.INFO)], config=True)
576 program = List(ipcontroller_cmd_argv, config=True,
577 help="remote ipcontroller command.")
578 program_args = List(['--reuse-files', '--log-to-file','log_level=%i'%logging.INFO], config=True,
579 help="Command line arguments to ipcontroller.")
561 580
562 581
563 582 class SSHEngineLauncher(SSHLauncher):
564 program = List(ipengine_cmd_argv, config=True)
583 program = List(ipengine_cmd_argv, config=True,
584 help="remote ipengine command.")
565 585 # Command line arguments for ipengine.
566 586 program_args = List(
567 ['--log-to-file','--log-level', str(logging.INFO)], config=True
587 ['--log-to-file','log_level=%i'%logging.INFO], config=True,
588 help="Command line arguments to ipengine."
568 589 )
569 590
570 591 class SSHEngineSetLauncher(LocalEngineSetLauncher):
571 592 launcher_class = SSHEngineLauncher
572 engines = Dict(config=True)
593 engines = Dict(config=True,
594 help="""dict of engines to launch. This is a dict by hostname of ints,
595 corresponding to the number of engines to start on that host.""")
573 596
574 597 def start(self, n, cluster_dir):
575 598 """Start engines by profile or cluster_dir.
576 599 `n` is ignored, and the `engines` config property is used instead.
577 600 """
578 601
579 602 self.cluster_dir = unicode(cluster_dir)
580 603 dlist = []
581 604 for host, n in self.engines.iteritems():
582 605 if isinstance(n, (tuple, list)):
583 606 n, args = n
584 607 else:
585 608 args = copy.deepcopy(self.engine_args)
586 609
587 610 if '@' in host:
588 611 user,host = host.split('@',1)
589 612 else:
590 613 user=None
591 614 for i in range(n):
592 615 el = self.launcher_class(work_dir=self.work_dir, config=self.config, logname=self.log.name)
593 616
594 617 # Copy the engine args over to each engine launcher.
596 619 el.program_args = args
597 620 el.on_stop(self._notice_engine_stopped)
598 621 d = el.start(cluster_dir, user=user, hostname=host)
599 622 if i==0:
600 623 self.log.info("Starting SSHEngineSetLauncher: %r" % el.args)
601 624 self.launchers[host+str(i)] = el
602 625 dlist.append(d)
603 626 self.notify_start(dlist)
604 627 return dlist
605 628
606 629
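SSHEngineSetLauncher.engines accepts either a plain engine count per host or a (count, args) pair, and a 'user@host' key overrides the ssh user, as the start() loop above shows. A sketch of what this could look like in an ipcluster config file; the hostnames are placeholders:

# ipcluster_config.py (assumed location inside the cluster directory)
c = get_config()
c.SSHEngineSetLauncher.engines = {
    'node1.example.com'       : 2,                        # 2 engines, default args
    'alice@node2.example.com' : 4,                        # 4 engines, logging in as alice
    'node3.example.com'       : (1, ['--log-to-file']),   # (count, per-host ipengine args)
}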
607 630
608 631 #-----------------------------------------------------------------------------
609 632 # Windows HPC Server 2008 scheduler launchers
610 633 #-----------------------------------------------------------------------------
611 634
612 635
613 636 # This is only used on Windows.
614 637 def find_job_cmd():
615 638 if WINDOWS:
616 639 try:
617 640 return find_cmd('job')
618 641 except (FindCmdError, ImportError):
619 642 # ImportError will be raised if win32api is not installed
620 643 return 'job'
621 644 else:
622 645 return 'job'
623 646
624 647
625 648 class WindowsHPCLauncher(BaseLauncher):
626 649
627 # A regular expression used to get the job id from the output of the
628 # submit_command.
629 job_id_regexp = Str(r'\d+', config=True)
630 # The filename of the instantiated job script.
631 job_file_name = CUnicode(u'ipython_job.xml', config=True)
650 job_id_regexp = Str(r'\d+', config=True,
651 help="""A regular expression used to get the job id from the output of the
652 submit_command. """
653 )
654 job_file_name = CUnicode(u'ipython_job.xml', config=True,
655 help="The filename of the instantiated job script.")
632 656 # The full path to the instantiated job script. This gets made dynamically
633 657 # by combining the work_dir with the job_file_name.
634 658 job_file = CUnicode(u'')
635 # The hostname of the scheduler to submit the job to
636 scheduler = CUnicode('', config=True)
637 job_cmd = CUnicode(find_job_cmd(), config=True)
659 scheduler = CUnicode('', config=True,
660 help="The hostname of the scheduler to submit the job to.")
661 job_cmd = CUnicode(find_job_cmd(), config=True,
662 help="The command for submitting jobs.")
638 663
639 664 def __init__(self, work_dir=u'.', config=None, **kwargs):
640 665 super(WindowsHPCLauncher, self).__init__(
641 666 work_dir=work_dir, config=config, **kwargs
642 667 )
643 668
644 669 @property
645 670 def job_file(self):
646 671 return os.path.join(self.work_dir, self.job_file_name)
647 672
648 673 def write_job_file(self, n):
649 674 raise NotImplementedError("Implement write_job_file in a subclass.")
650 675
651 676 def find_args(self):
652 677 return [u'job.exe']
653 678
654 679 def parse_job_id(self, output):
655 680 """Take the output of the submit command and return the job id."""
656 681 m = re.search(self.job_id_regexp, output)
657 682 if m is not None:
658 683 job_id = m.group()
659 684 else:
660 685 raise LauncherError("Job id couldn't be determined: %s" % output)
661 686 self.job_id = job_id
662 687 self.log.info('Job started with job id: %r' % job_id)
663 688 return job_id
664 689
665 690 def start(self, n):
666 691 """Start n copies of the process using the Win HPC job scheduler."""
667 692 self.write_job_file(n)
668 693 args = [
669 694 'submit',
670 695 '/jobfile:%s' % self.job_file,
671 696 '/scheduler:%s' % self.scheduler
672 697 ]
673 698 self.log.info("Starting Win HPC Job: %s" % (self.job_cmd + ' ' + ' '.join(args),))
674 699 # Twisted will raise DeprecationWarnings if we try to pass unicode to this
675 700 output = check_output([self.job_cmd]+args,
676 701 env=os.environ,
677 702 cwd=self.work_dir,
678 703 stderr=STDOUT
679 704 )
680 705 job_id = self.parse_job_id(output)
681 706 self.notify_start(job_id)
682 707 return job_id
683 708
684 709 def stop(self):
685 710 args = [
686 711 'cancel',
687 712 self.job_id,
688 713 '/scheduler:%s' % self.scheduler
689 714 ]
690 715 self.log.info("Stopping Win HPC Job: %s" % (self.job_cmd + ' ' + ' '.join(args),))
691 716 try:
692 717 output = check_output([self.job_cmd]+args,
693 718 env=os.environ,
694 719 cwd=self.work_dir,
695 720 stderr=STDOUT
696 721 )
697 722 except:
698 723 output = 'The job already appears to be stopped: %r' % self.job_id
699 724 self.notify_stop(dict(job_id=self.job_id, output=output)) # Pass the output of the kill cmd
700 725 return output
701 726
702 727
703 728 class WindowsHPCControllerLauncher(WindowsHPCLauncher):
704 729
705 job_file_name = CUnicode(u'ipcontroller_job.xml', config=True)
706 extra_args = List([], config=False)
730 job_file_name = CUnicode(u'ipcontroller_job.xml', config=True,
731 help="WinHPC xml job file.")
732 extra_args = List([], config=False,
733 help="extra args to pass to ipcontroller")
707 734
708 735 def write_job_file(self, n):
709 736 job = IPControllerJob(config=self.config)
710 737
711 738 t = IPControllerTask(config=self.config)
712 739 # The task's work directory is *not* the actual work directory of
713 740 # the controller. It is used as the base path for the stdout/stderr
714 741 # files that the scheduler redirects to.
715 742 t.work_directory = self.cluster_dir
716 # Add the --cluster-dir and from self.start().
743 # Add the cluster_dir argument passed in from self.start().
717 744 t.controller_args.extend(self.extra_args)
718 745 job.add_task(t)
719 746
720 747 self.log.info("Writing job description file: %s" % self.job_file)
721 748 job.write(self.job_file)
722 749
723 750 @property
724 751 def job_file(self):
725 752 return os.path.join(self.cluster_dir, self.job_file_name)
726 753
727 754 def start(self, cluster_dir):
728 755 """Start the controller by cluster_dir."""
729 self.extra_args = ['--cluster-dir', cluster_dir]
756 self.extra_args = ['cluster_dir=%s'%cluster_dir]
730 757 self.cluster_dir = unicode(cluster_dir)
731 758 return super(WindowsHPCControllerLauncher, self).start(1)
732 759
733 760
734 761 class WindowsHPCEngineSetLauncher(WindowsHPCLauncher):
735 762
736 job_file_name = CUnicode(u'ipengineset_job.xml', config=True)
737 extra_args = List([], config=False)
763 job_file_name = CUnicode(u'ipengineset_job.xml', config=True,
764 help="jobfile for ipengines job")
765 extra_args = List([], config=False,
766 help="extra args to pas to ipengine")
738 767
739 768 def write_job_file(self, n):
740 769 job = IPEngineSetJob(config=self.config)
741 770
742 771 for i in range(n):
743 772 t = IPEngineTask(config=self.config)
744 773 # The task's work directory is *not* the actual work directory of
745 774 # the engine. It is used as the base path for the stdout/stderr
746 775 # files that the scheduler redirects to.
747 776 t.work_directory = self.cluster_dir
748 # Add the --cluster-dir and from self.start().
777 # Add the cluster_dir argument passed in from self.start().
749 778 t.engine_args.extend(self.extra_args)
750 779 job.add_task(t)
751 780
752 781 self.log.info("Writing job description file: %s" % self.job_file)
753 782 job.write(self.job_file)
754 783
755 784 @property
756 785 def job_file(self):
757 786 return os.path.join(self.cluster_dir, self.job_file_name)
758 787
759 788 def start(self, n, cluster_dir):
760 789 """Start the controller by cluster_dir."""
761 self.extra_args = ['--cluster-dir', cluster_dir]
790 self.extra_args = ['cluster_dir=%s'%cluster_dir]
762 791 self.cluster_dir = unicode(cluster_dir)
763 792 return super(WindowsHPCEngineSetLauncher, self).start(n)
764 793
765 794
766 795 #-----------------------------------------------------------------------------
767 796 # Batch (PBS) system launchers
768 797 #-----------------------------------------------------------------------------
769 798
770 799 class BatchSystemLauncher(BaseLauncher):
771 800 """Launch an external process using a batch system.
772 801
773 802 This class is designed to work with UNIX batch systems like PBS, LSF,
774 803 GridEngine, etc. The overall model is that there are different commands
775 804 like qsub, qdel, etc. that handle the starting and stopping of the process.
776 805
777 806 This class also has the notion of a batch script. The ``batch_template``
778 807 attribute can be set to a string that is a template for the batch script.
779 808 This template is instantiated using Itpl. Thus the template can use
780 809 ${n} for the number of instances. Subclasses can add additional variables
781 810 to the template dict.
782 811 """
783 812
784 813 # Subclasses must fill these in. See PBSEngineSet
785 # The name of the command line program used to submit jobs.
786 submit_command = List([''], config=True)
787 # The name of the command line program used to delete jobs.
788 delete_command = List([''], config=True)
789 # A regular expression used to get the job id from the output of the
790 # submit_command.
791 job_id_regexp = CUnicode('', config=True)
792 # The string that is the batch script template itself.
793 batch_template = CUnicode('', config=True)
794 # The file that contains the batch template
795 batch_template_file = CUnicode(u'', config=True)
796 # The filename of the instantiated batch script.
797 batch_file_name = CUnicode(u'batch_script', config=True)
798 # The PBS Queue
799 queue = CUnicode(u'', config=True)
814 submit_command = List([''], config=True,
815 help="The name of the command line program used to submit jobs.")
816 delete_command = List([''], config=True,
817 help="The name of the command line program used to delete jobs.")
818 job_id_regexp = CUnicode('', config=True,
819 help="""A regular expression used to get the job id from the output of the
820 submit_command.""")
821 batch_template = CUnicode('', config=True,
822 help="The string that is the batch script template itself.")
823 batch_template_file = CUnicode(u'', config=True,
824 help="The file that contains the batch template.")
825 batch_file_name = CUnicode(u'batch_script', config=True,
826 help="The filename of the instantiated batch script.")
827 queue = CUnicode(u'', config=True,
828 help="The PBS Queue.")
800 829
801 830 # not configurable, override in subclasses
802 831 # PBS Job Array regex
803 832 job_array_regexp = CUnicode('')
804 833 job_array_template = CUnicode('')
805 834 # PBS Queue regex
806 835 queue_regexp = CUnicode('')
807 836 queue_template = CUnicode('')
808 837 # The default batch template, override in subclasses
809 838 default_template = CUnicode('')
810 839 # The full path to the instantiated batch script.
811 840 batch_file = CUnicode(u'')
812 841 # the format dict used with batch_template:
813 842 context = Dict()
814 843
815 844
816 845 def find_args(self):
817 846 return self.submit_command + [self.batch_file]
818 847
819 848 def __init__(self, work_dir=u'.', config=None, **kwargs):
820 849 super(BatchSystemLauncher, self).__init__(
821 850 work_dir=work_dir, config=config, **kwargs
822 851 )
823 852 self.batch_file = os.path.join(self.work_dir, self.batch_file_name)
824 853
825 854 def parse_job_id(self, output):
826 855 """Take the output of the submit command and return the job id."""
827 856 m = re.search(self.job_id_regexp, output)
828 857 if m is not None:
829 858 job_id = m.group()
830 859 else:
831 860 raise LauncherError("Job id couldn't be determined: %s" % output)
832 861 self.job_id = job_id
833 862 self.log.info('Job submitted with job id: %r' % job_id)
834 863 return job_id
835 864
836 865 def write_batch_script(self, n):
837 866 """Instantiate and write the batch script to the work_dir."""
838 867 self.context['n'] = n
839 868 self.context['queue'] = self.queue
840 869 print self.context
841 870 # first priority is batch_template if set
842 871 if self.batch_template_file and not self.batch_template:
843 872 # second priority is batch_template_file
844 873 with open(self.batch_template_file) as f:
845 874 self.batch_template = f.read()
846 875 if not self.batch_template:
847 876 # third (last) priority is default_template
848 877 self.batch_template = self.default_template
849 878
850 879 regex = re.compile(self.job_array_regexp)
851 880 # print regex.search(self.batch_template)
852 881 if not regex.search(self.batch_template):
853 882 self.log.info("adding job array settings to batch script")
854 883 firstline, rest = self.batch_template.split('\n',1)
855 884 self.batch_template = u'\n'.join([firstline, self.job_array_template, rest])
856 885
857 886 regex = re.compile(self.queue_regexp)
858 887 # print regex.search(self.batch_template)
859 888 if self.queue and not regex.search(self.batch_template):
860 889 self.log.info("adding PBS queue settings to batch script")
861 890 firstline, rest = self.batch_template.split('\n',1)
862 891 self.batch_template = u'\n'.join([firstline, self.queue_template, rest])
863 892
864 893 script_as_string = Itpl.itplns(self.batch_template, self.context)
865 894 self.log.info('Writing instantiated batch script: %s' % self.batch_file)
866 895
867 896 with open(self.batch_file, 'w') as f:
868 897 f.write(script_as_string)
869 898 os.chmod(self.batch_file, stat.S_IRUSR | stat.S_IWUSR | stat.S_IXUSR)
870 899
871 900 def start(self, n, cluster_dir):
872 901 """Start n copies of the process using a batch system."""
873 902 # Here we save cluster_dir in the context so it
874 903 # can be used in the batch script template as
875 904 # ${cluster_dir}
876 905 self.context['cluster_dir'] = cluster_dir
877 906 self.cluster_dir = unicode(cluster_dir)
878 907 self.write_batch_script(n)
879 908 output = check_output(self.args, env=os.environ)
880 909
881 910 job_id = self.parse_job_id(output)
882 911 self.notify_start(job_id)
883 912 return job_id
884 913
885 914 def stop(self):
886 915 output = check_output(self.delete_command+[self.job_id], env=os.environ)
887 916 self.notify_stop(dict(job_id=self.job_id, output=output)) # Pass the output of the kill cmd
888 917 return output
889 918
890 919
891 920 class PBSLauncher(BatchSystemLauncher):
892 921 """A BatchSystemLauncher subclass for PBS."""
893 922
894 submit_command = List(['qsub'], config=True)
895 delete_command = List(['qdel'], config=True)
896 job_id_regexp = CUnicode(r'\d+', config=True)
923 submit_command = List(['qsub'], config=True,
924 help="The PBS submit command ['qsub']")
925 delete_command = List(['qdel'], config=True,
926 help="The PBS delete command ['qsub']")
927 job_id_regexp = CUnicode(r'\d+', config=True,
928 help="Regular expresion for identifying the job ID [r'\d+']")
897 929
898 930 batch_file = CUnicode(u'')
899 931 job_array_regexp = CUnicode('#PBS\W+-t\W+[\w\d\-\$]+')
900 932 job_array_template = CUnicode('#PBS -t 1-$n')
901 933 queue_regexp = CUnicode('#PBS\W+-q\W+\$?\w+')
902 934 queue_template = CUnicode('#PBS -q $queue')
903 935
904 936
905 937 class PBSControllerLauncher(PBSLauncher):
906 938 """Launch a controller using PBS."""
907 939
908 batch_file_name = CUnicode(u'pbs_controller', config=True)
940 batch_file_name = CUnicode(u'pbs_controller', config=True,
941 help="batch file name for the controller job.")
909 942 default_template= CUnicode("""#!/bin/sh
910 943 #PBS -V
911 944 #PBS -N ipcontroller
912 %s --log-to-file --cluster-dir $cluster_dir
945 %s --log-to-file cluster_dir=$cluster_dir
913 946 """%(' '.join(ipcontroller_cmd_argv)))
914 947
915 948 def start(self, cluster_dir):
916 949 """Start the controller by profile or cluster_dir."""
917 950 self.log.info("Starting PBSControllerLauncher: %r" % self.args)
918 951 return super(PBSControllerLauncher, self).start(1, cluster_dir)
919 952
920 953
921 954 class PBSEngineSetLauncher(PBSLauncher):
922 955 """Launch Engines using PBS"""
923 batch_file_name = CUnicode(u'pbs_engines', config=True)
956 batch_file_name = CUnicode(u'pbs_engines', config=True,
957 help="batch file name for the engine(s) job.")
924 958 default_template= CUnicode(u"""#!/bin/sh
925 959 #PBS -V
926 960 #PBS -N ipengine
927 %s --cluster-dir $cluster_dir
961 %s cluster_dir=$cluster_dir
928 962 """%(' '.join(ipengine_cmd_argv)))
929 963
930 964 def start(self, n, cluster_dir):
931 965 """Start n engines by profile or cluster_dir."""
932 966 self.log.info('Starting %i engines with PBSEngineSetLauncher: %r' % (n, self.args))
933 967 return super(PBSEngineSetLauncher, self).start(n, cluster_dir)
934 968
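Per the BatchSystemLauncher docstring, batch_template is expanded with Itpl, so $n, $queue and $cluster_dir are available, and write_batch_script injects the job-array and queue lines when a custom template omits them. A sketch of a hand-written template in an ipcluster config file; the file name is an assumption and the literal `ipengine` stands in for the expanded ipengine_cmd_argv:

# ipcluster_config.py (assumed location inside the cluster directory)
c = get_config()
c.PBSEngineSetLauncher.queue = u'batch'
c.PBSEngineSetLauncher.batch_template = u"""#!/bin/sh
#PBS -V
#PBS -q $queue
#PBS -t 1-$n
ipengine cluster_dir=$cluster_dir
"""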
935 969 #SGE is very similar to PBS
936 970
937 971 class SGELauncher(PBSLauncher):
938 972 """Sun GridEngine is a PBS clone with slightly different syntax"""
939 973 job_array_regexp = CUnicode('#$$\W+-t\W+[\w\d\-\$]+')
940 974 job_array_template = CUnicode('#$$ -t 1-$n')
941 975 queue_regexp = CUnicode('#$$\W+-q\W+\$?\w+')
942 976 queue_template = CUnicode('#$$ -q $queue')
943 977
944 978 class SGEControllerLauncher(SGELauncher):
945 979 """Launch a controller using SGE."""
946 980
947 batch_file_name = CUnicode(u'sge_controller', config=True)
981 batch_file_name = CUnicode(u'sge_controller', config=True,
982 help="batch file name for the ipontroller job.")
948 983 default_template= CUnicode(u"""#$$ -V
949 984 #$$ -S /bin/sh
950 985 #$$ -N ipcontroller
951 %s --log-to-file --cluster-dir $cluster_dir
986 %s --log-to-file cluster_dir=$cluster_dir
952 987 """%(' '.join(ipcontroller_cmd_argv)))
953 988
954 989 def start(self, cluster_dir):
955 990 """Start the controller by profile or cluster_dir."""
956 991 self.log.info("Starting PBSControllerLauncher: %r" % self.args)
957 992 return super(PBSControllerLauncher, self).start(1, cluster_dir)
958 993
959 994 class SGEEngineSetLauncher(SGELauncher):
960 995 """Launch Engines with SGE"""
961 batch_file_name = CUnicode(u'sge_engines', config=True)
996 batch_file_name = CUnicode(u'sge_engines', config=True,
997 help="batch file name for the engine(s) job.")
962 998 default_template = CUnicode("""#$$ -V
963 999 #$$ -S /bin/sh
964 1000 #$$ -N ipengine
965 %s --cluster-dir $cluster_dir
1001 %s cluster_dir=$cluster_dir
966 1002 """%(' '.join(ipengine_cmd_argv)))
967 1003
968 1004 def start(self, n, cluster_dir):
969 1005 """Start n engines by profile or cluster_dir."""
970 1006 self.log.info('Starting %i engines with SGEEngineSetLauncher: %r' % (n, self.args))
971 1007 return super(SGEEngineSetLauncher, self).start(n, cluster_dir)
972 1008
973 1009
974 1010 #-----------------------------------------------------------------------------
975 1011 # A launcher for ipcluster itself!
976 1012 #-----------------------------------------------------------------------------
977 1013
978 1014
979 1015 class IPClusterLauncher(LocalProcessLauncher):
980 1016 """Launch the ipcluster program in an external process."""
981 1017
982 ipcluster_cmd = List(ipcluster_cmd_argv, config=True)
983 # Command line arguments to pass to ipcluster.
1018 ipcluster_cmd = List(ipcluster_cmd_argv, config=True,
1019 help="Popen command for ipcluster")
984 1020 ipcluster_args = List(
985 ['--clean-logs', '--log-to-file', '--log-level', str(logging.INFO)], config=True)
1021 ['--clean-logs', '--log-to-file', 'log_level=%i'%logging.INFO], config=True,
1022 help="Command line arguments to pass to ipcluster.")
986 1023 ipcluster_subcommand = Str('start')
987 1024 ipcluster_n = Int(2)
988 1025
989 1026 def find_args(self):
990 return self.ipcluster_cmd + [self.ipcluster_subcommand] + \
991 ['-n', repr(self.ipcluster_n)] + self.ipcluster_args
1027 return self.ipcluster_cmd + ['--'+self.ipcluster_subcommand] + \
1028 ['n=%i'%self.ipcluster_n] + self.ipcluster_args
992 1029
993 1030 def start(self):
994 1031 self.log.info("Starting ipcluster: %r" % self.args)
995 1032 return super(IPClusterLauncher, self).start()
996 1033
1034 #-----------------------------------------------------------------------------
1035 # Collections of launchers
1036 #-----------------------------------------------------------------------------
1037
1038 local_launchers = [
1039 LocalControllerLauncher,
1040 LocalEngineLauncher,
1041 LocalEngineSetLauncher,
1042 ]
1043 mpi_launchers = [
1044 MPIExecLauncher,
1045 MPIExecControllerLauncher,
1046 MPIExecEngineSetLauncher,
1047 ]
1048 ssh_launchers = [
1049 SSHLauncher,
1050 SSHControllerLauncher,
1051 SSHEngineLauncher,
1052 SSHEngineSetLauncher,
1053 ]
1054 winhpc_launchers = [
1055 WindowsHPCLauncher,
1056 WindowsHPCControllerLauncher,
1057 WindowsHPCEngineSetLauncher,
1058 ]
1059 pbs_launchers = [
1060 PBSLauncher,
1061 PBSControllerLauncher,
1062 PBSEngineSetLauncher,
1063 ]
1064 sge_launchers = [
1065 SGELauncher,
1066 SGEControllerLauncher,
1067 SGEEngineSetLauncher,
1068 ]
1069 all_launchers = local_launchers + mpi_launchers + ssh_launchers + winhpc_launchers\
1070 + pbs_launchers + sge_launchers No newline at end of file
@@ -1,1356 +1,1356
1 1 """A semi-synchronous Client for the ZMQ cluster"""
2 2 #-----------------------------------------------------------------------------
3 3 # Copyright (C) 2010 The IPython Development Team
4 4 #
5 5 # Distributed under the terms of the BSD License. The full license is in
6 6 # the file COPYING, distributed as part of this software.
7 7 #-----------------------------------------------------------------------------
8 8
9 9 #-----------------------------------------------------------------------------
10 10 # Imports
11 11 #-----------------------------------------------------------------------------
12 12
13 13 import os
14 14 import json
15 15 import time
16 16 import warnings
17 17 from datetime import datetime
18 18 from getpass import getpass
19 19 from pprint import pprint
20 20
21 21 pjoin = os.path.join
22 22
23 23 import zmq
24 24 # from zmq.eventloop import ioloop, zmqstream
25 25
26 26 from IPython.utils.path import get_ipython_dir
27 27 from IPython.utils.traitlets import (HasTraits, Int, Instance, CUnicode,
28 28 Dict, List, Bool, Str, Set)
29 29 from IPython.external.decorator import decorator
30 30 from IPython.external.ssh import tunnel
31 31
32 32 from IPython.parallel import error
33 33 from IPython.parallel import streamsession as ss
34 34 from IPython.parallel import util
35 35
36 36 from .asyncresult import AsyncResult, AsyncHubResult
37 37 from IPython.parallel.apps.clusterdir import ClusterDir, ClusterDirError
38 38 from .view import DirectView, LoadBalancedView
39 39
40 40 #--------------------------------------------------------------------------
41 41 # Decorators for Client methods
42 42 #--------------------------------------------------------------------------
43 43
44 44 @decorator
45 45 def spin_first(f, self, *args, **kwargs):
46 46 """Call spin() to sync state prior to calling the method."""
47 47 self.spin()
48 48 return f(self, *args, **kwargs)
49 49
50 50
51 51 #--------------------------------------------------------------------------
52 52 # Classes
53 53 #--------------------------------------------------------------------------
54 54
55 55 class Metadata(dict):
56 56 """Subclass of dict for initializing metadata values.
57 57
58 58 Attribute access works on keys.
59 59
60 60 These objects have a strict set of keys - errors will raise if you try
61 61 to add new keys.
62 62 """
63 63 def __init__(self, *args, **kwargs):
64 64 dict.__init__(self)
65 65 md = {'msg_id' : None,
66 66 'submitted' : None,
67 67 'started' : None,
68 68 'completed' : None,
69 69 'received' : None,
70 70 'engine_uuid' : None,
71 71 'engine_id' : None,
72 72 'follow' : None,
73 73 'after' : None,
74 74 'status' : None,
75 75
76 76 'pyin' : None,
77 77 'pyout' : None,
78 78 'pyerr' : None,
79 79 'stdout' : '',
80 80 'stderr' : '',
81 81 }
82 82 self.update(md)
83 83 self.update(dict(*args, **kwargs))
84 84
85 85 def __getattr__(self, key):
86 86 """getattr aliased to getitem"""
87 87 if key in self.iterkeys():
88 88 return self[key]
89 89 else:
90 90 raise AttributeError(key)
91 91
92 92 def __setattr__(self, key, value):
93 93 """setattr aliased to setitem, with strict"""
94 94 if key in self.iterkeys():
95 95 self[key] = value
96 96 else:
97 97 raise AttributeError(key)
98 98
99 99 def __setitem__(self, key, value):
100 100 """strict static key enforcement"""
101 101 if key in self.iterkeys():
102 102 dict.__setitem__(self, key, value)
103 103 else:
104 104 raise KeyError(key)
105 105
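A quick sketch of the strict-key behaviour described in the docstring above; the values are illustrative:

md = Metadata(msg_id='abc123')
md.status = 'ok'            # attribute access maps onto the fixed key set
print md['engine_id']       # unset fields default to None
try:
    md.not_a_key = 1        # unknown keys are rejected by design
except AttributeError:
    print "strict keys enforced"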
106 106
107 107 class Client(HasTraits):
108 108 """A semi-synchronous client to the IPython ZMQ cluster
109 109
110 110 Parameters
111 111 ----------
112 112
113 113 url_or_file : bytes; zmq url or path to ipcontroller-client.json
114 114 Connection information for the Hub's registration. If a json connector
115 115 file is given, then likely no further configuration is necessary.
116 116 [Default: use profile]
117 117 profile : bytes
118 118 The name of the Cluster profile to be used to find connector information.
119 119 [Default: 'default']
120 120 context : zmq.Context
121 121 Pass an existing zmq.Context instance, otherwise the client will create its own.
122 122 username : bytes
123 123 set username to be passed to the Session object
124 124 debug : bool
125 125 flag for lots of message printing for debug purposes
126 126
127 127 #-------------- ssh related args ----------------
128 128 # These are args for configuring the ssh tunnel to be used
129 129 # credentials are used to forward connections over ssh to the Controller
130 130 # Note that the ip given in `addr` needs to be relative to sshserver
131 131 # The most basic case is to leave addr as pointing to localhost (127.0.0.1),
132 132 # and set sshserver as the same machine the Controller is on. However,
133 133 # the only requirement is that sshserver is able to see the Controller
134 134 # (i.e. is within the same trusted network).
135 135
136 136 sshserver : str
137 137 A string of the form passed to ssh, i.e. 'server.tld' or 'user@server.tld:port'
138 138 If keyfile or password is specified, and this is not, it will default to
139 139 the ip given in addr.
140 140 sshkey : str; path to public ssh key file
141 141 This specifies a key to be used in ssh login, default None.
142 142 Regular default ssh keys will be used without specifying this argument.
143 143 password : str
144 144 Your ssh password to sshserver. Note that if this is left None,
145 145 you will be prompted for it if passwordless key based login is unavailable.
146 146 paramiko : bool
147 147 flag for whether to use paramiko instead of shell ssh for tunneling.
148 148 [default: True on win32, False else]
149 149
150 150 ------- exec authentication args -------
151 151 If even localhost is untrusted, you can have some protection against
152 152 unauthorized execution by using a key. Messages are still sent
153 153 as cleartext, so if someone can snoop your loopback traffic this will
154 154 not help against malicious attacks.
155 155
156 156 exec_key : str
157 157 an authentication key or file containing a key
158 158 default: None
159 159
160 160
161 161 Attributes
162 162 ----------
163 163
164 164 ids : list of int engine IDs
165 165 requesting the ids attribute always synchronizes
166 166 the registration state. To request ids without synchronization,
167 167 use the semi-private _ids attribute.
168 168
169 169 history : list of msg_ids
170 170 a list of msg_ids, keeping track of all the execution
171 171 messages you have submitted in order.
172 172
173 173 outstanding : set of msg_ids
174 174 a set of msg_ids that have been submitted, but whose
175 175 results have not yet been received.
176 176
177 177 results : dict
178 178 a dict of all our results, keyed by msg_id
179 179
180 180 block : bool
181 181 determines default behavior when block not specified
182 182 in execution methods
183 183
184 184 Methods
185 185 -------
186 186
187 187 spin
188 188 flushes incoming results and registration state changes
189 189 control methods spin, and requesting `ids` also ensures the state is up to date
190 190
191 191 wait
192 192 wait on one or more msg_ids
193 193
194 194 execution methods
195 195 apply
196 196 legacy: execute, run
197 197
198 198 data movement
199 199 push, pull, scatter, gather
200 200
201 201 query methods
202 202 queue_status, get_result, purge, result_status
203 203
204 204 control methods
205 205 abort, shutdown
206 206
207 207 """
208 208
209 209
210 210 block = Bool(False)
211 211 outstanding = Set()
212 212 results = Instance('collections.defaultdict', (dict,))
213 213 metadata = Instance('collections.defaultdict', (Metadata,))
214 214 history = List()
215 215 debug = Bool(False)
216 216 profile=CUnicode('default')
217 217
218 218 _outstanding_dict = Instance('collections.defaultdict', (set,))
219 219 _ids = List()
220 220 _connected=Bool(False)
221 221 _ssh=Bool(False)
222 222 _context = Instance('zmq.Context')
223 223 _config = Dict()
224 224 _engines=Instance(util.ReverseDict, (), {})
225 225 # _hub_socket=Instance('zmq.Socket')
226 226 _query_socket=Instance('zmq.Socket')
227 227 _control_socket=Instance('zmq.Socket')
228 228 _iopub_socket=Instance('zmq.Socket')
229 229 _notification_socket=Instance('zmq.Socket')
230 230 _mux_socket=Instance('zmq.Socket')
231 231 _task_socket=Instance('zmq.Socket')
232 232 _task_scheme=Str()
233 233 _closed = False
234 234 _ignored_control_replies=Int(0)
235 235 _ignored_hub_replies=Int(0)
236 236
237 237 def __init__(self, url_or_file=None, profile='default', cluster_dir=None, ipython_dir=None,
238 238 context=None, username=None, debug=False, exec_key=None,
239 239 sshserver=None, sshkey=None, password=None, paramiko=None,
240 240 timeout=10
241 241 ):
242 242 super(Client, self).__init__(debug=debug, profile=profile)
243 243 if context is None:
244 244 context = zmq.Context.instance()
245 245 self._context = context
246 246
247 247
248 248 self._setup_cluster_dir(profile, cluster_dir, ipython_dir)
249 249 if self._cd is not None:
250 250 if url_or_file is None:
251 251 url_or_file = pjoin(self._cd.security_dir, 'ipcontroller-client.json')
252 252 assert url_or_file is not None, "I can't find enough information to connect to a hub!"\
253 253 " Please specify at least one of url_or_file or profile."
254 254
255 255 try:
256 256 util.validate_url(url_or_file)
257 257 except AssertionError:
258 258 if not os.path.exists(url_or_file):
259 259 if self._cd:
260 260 url_or_file = os.path.join(self._cd.security_dir, url_or_file)
261 261 assert os.path.exists(url_or_file), "Not a valid connection file or url: %r"%url_or_file
262 262 with open(url_or_file) as f:
263 263 cfg = json.loads(f.read())
264 264 else:
265 265 cfg = {'url':url_or_file}
266 266
267 267 # sync defaults from args, json:
268 268 if sshserver:
269 269 cfg['ssh'] = sshserver
270 270 if exec_key:
271 271 cfg['exec_key'] = exec_key
272 272 exec_key = cfg['exec_key']
273 273 sshserver=cfg['ssh']
274 274 url = cfg['url']
275 275 location = cfg.setdefault('location', None)
276 276 cfg['url'] = util.disambiguate_url(cfg['url'], location)
277 277 url = cfg['url']
278 278
279 279 self._config = cfg
280 280
281 281 self._ssh = bool(sshserver or sshkey or password)
282 282 if self._ssh and sshserver is None:
283 283 # default to ssh via localhost
284 284 sshserver = url.split('://')[1].split(':')[0]
285 285 if self._ssh and password is None:
286 286 if tunnel.try_passwordless_ssh(sshserver, sshkey, paramiko):
287 287 password=False
288 288 else:
289 289 password = getpass("SSH Password for %s: "%sshserver)
290 290 ssh_kwargs = dict(keyfile=sshkey, password=password, paramiko=paramiko)
291 291 if exec_key is not None and os.path.isfile(exec_key):
292 292 arg = 'keyfile'
293 293 else:
294 294 arg = 'key'
295 295 key_arg = {arg:exec_key}
296 296 if username is None:
297 297 self.session = ss.StreamSession(**key_arg)
298 298 else:
299 self.session = ss.StreamSession(username, **key_arg)
299 self.session = ss.StreamSession(username=username, **key_arg)
300 300 self._query_socket = self._context.socket(zmq.XREQ)
301 301 self._query_socket.setsockopt(zmq.IDENTITY, self.session.session)
302 302 if self._ssh:
303 303 tunnel.tunnel_connection(self._query_socket, url, sshserver, **ssh_kwargs)
304 304 else:
305 305 self._query_socket.connect(url)
306 306
307 307 self.session.debug = self.debug
308 308
309 309 self._notification_handlers = {'registration_notification' : self._register_engine,
310 310 'unregistration_notification' : self._unregister_engine,
311 311 'shutdown_notification' : lambda msg: self.close(),
312 312 }
313 313 self._queue_handlers = {'execute_reply' : self._handle_execute_reply,
314 314 'apply_reply' : self._handle_apply_reply}
315 315 self._connect(sshserver, ssh_kwargs, timeout)
316 316
317 317 def __del__(self):
318 318 """cleanup sockets, but _not_ context."""
319 319 self.close()
320 320
321 321 def _setup_cluster_dir(self, profile, cluster_dir, ipython_dir):
322 322 if ipython_dir is None:
323 323 ipython_dir = get_ipython_dir()
324 324 if cluster_dir is not None:
325 325 try:
326 326 self._cd = ClusterDir.find_cluster_dir(cluster_dir)
327 327 return
328 328 except ClusterDirError:
329 329 pass
330 330 elif profile is not None:
331 331 try:
332 332 self._cd = ClusterDir.find_cluster_dir_by_profile(
333 333 ipython_dir, profile)
334 334 return
335 335 except ClusterDirError:
336 336 pass
337 337 self._cd = None
338 338
339 339 def _update_engines(self, engines):
340 340 """Update our engines dict and _ids from a dict of the form: {id:uuid}."""
341 341 for k,v in engines.iteritems():
342 342 eid = int(k)
343 343 self._engines[eid] = bytes(v) # force not unicode
344 344 self._ids.append(eid)
345 345 self._ids = sorted(self._ids)
346 346 if sorted(self._engines.keys()) != range(len(self._engines)) and \
347 347 self._task_scheme == 'pure' and self._task_socket:
348 348 self._stop_scheduling_tasks()
349 349
350 350 def _stop_scheduling_tasks(self):
351 351 """Stop scheduling tasks because an engine has been unregistered
352 352 from a pure ZMQ scheduler.
353 353 """
354 354 self._task_socket.close()
355 355 self._task_socket = None
356 356 msg = "An engine has been unregistered, and we are using pure " +\
357 357 "ZMQ task scheduling. Task farming will be disabled."
358 358 if self.outstanding:
359 359 msg += " If you were running tasks when this happened, " +\
360 360 "some `outstanding` msg_ids may never resolve."
361 361 warnings.warn(msg, RuntimeWarning)
362 362
363 363 def _build_targets(self, targets):
364 364 """Turn valid target IDs or 'all' into two lists:
365 365 (int_ids, uuids).
366 366 """
367 367 if not self._ids:
368 368 # flush notification socket if no engines yet, just in case
369 369 if not self.ids:
370 370 raise error.NoEnginesRegistered("Can't build targets without any engines")
371 371
372 372 if targets is None:
373 373 targets = self._ids
374 374 elif isinstance(targets, str):
375 375 if targets.lower() == 'all':
376 376 targets = self._ids
377 377 else:
378 378 raise TypeError("%r not valid str target, must be 'all'"%(targets))
379 379 elif isinstance(targets, int):
380 380 if targets < 0:
381 381 targets = self.ids[targets]
382 382 if targets not in self._ids:
383 383 raise IndexError("No such engine: %i"%targets)
384 384 targets = [targets]
385 385
386 386 if isinstance(targets, slice):
387 387 indices = range(len(self._ids))[targets]
388 388 ids = self.ids
389 389 targets = [ ids[i] for i in indices ]
390 390
391 391 if not isinstance(targets, (tuple, list, xrange)):
392 392 raise TypeError("targets by int/slice/collection of ints only, not %s"%(type(targets)))
393 393
394 394 return [self._engines[t] for t in targets], list(targets)
395 395
396 396 def _connect(self, sshserver, ssh_kwargs, timeout):
397 397 """setup all our socket connections to the cluster. This is called from
398 398 __init__."""
399 399
400 400 # Maybe allow reconnecting?
401 401 if self._connected:
402 402 return
403 403 self._connected=True
404 404
405 405 def connect_socket(s, url):
406 406 url = util.disambiguate_url(url, self._config['location'])
407 407 if self._ssh:
408 408 return tunnel.tunnel_connection(s, url, sshserver, **ssh_kwargs)
409 409 else:
410 410 return s.connect(url)
411 411
412 412 self.session.send(self._query_socket, 'connection_request')
413 413 r,w,x = zmq.select([self._query_socket],[],[], timeout)
414 414 if not r:
415 415 raise error.TimeoutError("Hub connection request timed out")
416 416 idents,msg = self.session.recv(self._query_socket,mode=0)
417 417 if self.debug:
418 418 pprint(msg)
419 419 msg = ss.Message(msg)
420 420 content = msg.content
421 421 self._config['registration'] = dict(content)
422 422 if content.status == 'ok':
423 423 if content.mux:
424 424 self._mux_socket = self._context.socket(zmq.XREQ)
425 425 self._mux_socket.setsockopt(zmq.IDENTITY, self.session.session)
426 426 connect_socket(self._mux_socket, content.mux)
427 427 if content.task:
428 428 self._task_scheme, task_addr = content.task
429 429 self._task_socket = self._context.socket(zmq.XREQ)
430 430 self._task_socket.setsockopt(zmq.IDENTITY, self.session.session)
431 431 connect_socket(self._task_socket, task_addr)
432 432 if content.notification:
433 433 self._notification_socket = self._context.socket(zmq.SUB)
434 434 connect_socket(self._notification_socket, content.notification)
435 435 self._notification_socket.setsockopt(zmq.SUBSCRIBE, b'')
436 436 # if content.query:
437 437 # self._query_socket = self._context.socket(zmq.XREQ)
438 438 # self._query_socket.setsockopt(zmq.IDENTITY, self.session.session)
439 439 # connect_socket(self._query_socket, content.query)
440 440 if content.control:
441 441 self._control_socket = self._context.socket(zmq.XREQ)
442 442 self._control_socket.setsockopt(zmq.IDENTITY, self.session.session)
443 443 connect_socket(self._control_socket, content.control)
444 444 if content.iopub:
445 445 self._iopub_socket = self._context.socket(zmq.SUB)
446 446 self._iopub_socket.setsockopt(zmq.SUBSCRIBE, b'')
447 447 self._iopub_socket.setsockopt(zmq.IDENTITY, self.session.session)
448 448 connect_socket(self._iopub_socket, content.iopub)
449 449 self._update_engines(dict(content.engines))
450 450 else:
451 451 self._connected = False
452 452 raise Exception("Failed to connect!")
453 453
454 454 #--------------------------------------------------------------------------
455 455 # handlers and callbacks for incoming messages
456 456 #--------------------------------------------------------------------------
457 457
458 458 def _unwrap_exception(self, content):
459 459 """unwrap exception, and remap engine_id to int."""
460 460 e = error.unwrap_exception(content)
461 461 # print e.traceback
462 462 if e.engine_info:
463 463 e_uuid = e.engine_info['engine_uuid']
464 464 eid = self._engines[e_uuid]
465 465 e.engine_info['engine_id'] = eid
466 466 return e
467 467
468 468 def _extract_metadata(self, header, parent, content):
469 469 md = {'msg_id' : parent['msg_id'],
470 470 'received' : datetime.now(),
471 471 'engine_uuid' : header.get('engine', None),
472 472 'follow' : parent.get('follow', []),
473 473 'after' : parent.get('after', []),
474 474 'status' : content['status'],
475 475 }
476 476
477 477 if md['engine_uuid'] is not None:
478 478 md['engine_id'] = self._engines.get(md['engine_uuid'], None)
479 479
480 480 if 'date' in parent:
481 481 md['submitted'] = datetime.strptime(parent['date'], util.ISO8601)
482 482 if 'started' in header:
483 483 md['started'] = datetime.strptime(header['started'], util.ISO8601)
484 484 if 'date' in header:
485 485 md['completed'] = datetime.strptime(header['date'], util.ISO8601)
486 486 return md
487 487
488 488 def _register_engine(self, msg):
489 489 """Register a new engine, and update our connection info."""
490 490 content = msg['content']
491 491 eid = content['id']
492 492 d = {eid : content['queue']}
493 493 self._update_engines(d)
494 494
495 495 def _unregister_engine(self, msg):
496 496 """Unregister an engine that has died."""
497 497 content = msg['content']
498 498 eid = int(content['id'])
499 499 if eid in self._ids:
500 500 self._ids.remove(eid)
501 501 uuid = self._engines.pop(eid)
502 502
503 503 self._handle_stranded_msgs(eid, uuid)
504 504
505 505 if self._task_socket and self._task_scheme == 'pure':
506 506 self._stop_scheduling_tasks()
507 507
508 508 def _handle_stranded_msgs(self, eid, uuid):
509 509 """Handle messages known to be on an engine when the engine unregisters.
510 510
511 511 It is possible that this will fire prematurely - that is, an engine will
512 512 go down after completing a result, and the client will be notified
513 513 of the unregistration and later receive the successful result.
514 514 """
515 515
516 516 outstanding = self._outstanding_dict[uuid]
517 517
518 518 for msg_id in list(outstanding):
519 519 if msg_id in self.results:
520 520 # we already have this result
521 521 continue
522 522 try:
523 523 raise error.EngineError("Engine %r died while running task %r"%(eid, msg_id))
524 524 except:
525 525 content = error.wrap_exception()
526 526 # build a fake message:
527 527 parent = {}
528 528 header = {}
529 529 parent['msg_id'] = msg_id
530 530 header['engine'] = uuid
531 531 header['date'] = datetime.now().strftime(util.ISO8601)
532 532 msg = dict(parent_header=parent, header=header, content=content)
533 533 self._handle_apply_reply(msg)
534 534
535 535 def _handle_execute_reply(self, msg):
536 536 """Save the reply to an execute_request into our results.
537 537
538 538 execute messages are never actually used. apply is used instead.
539 539 """
540 540
541 541 parent = msg['parent_header']
542 542 msg_id = parent['msg_id']
543 543 if msg_id not in self.outstanding:
544 544 if msg_id in self.history:
545 545 print ("got stale result: %s"%msg_id)
546 546 else:
547 547 print ("got unknown result: %s"%msg_id)
548 548 else:
549 549 self.outstanding.remove(msg_id)
550 550 self.results[msg_id] = self._unwrap_exception(msg['content'])
551 551
552 552 def _handle_apply_reply(self, msg):
553 553 """Save the reply to an apply_request into our results."""
554 554 parent = msg['parent_header']
555 555 msg_id = parent['msg_id']
556 556 if msg_id not in self.outstanding:
557 557 if msg_id in self.history:
558 558 print ("got stale result: %s"%msg_id)
559 559 print self.results[msg_id]
560 560 print msg
561 561 else:
562 562 print ("got unknown result: %s"%msg_id)
563 563 else:
564 564 self.outstanding.remove(msg_id)
565 565 content = msg['content']
566 566 header = msg['header']
567 567
568 568 # construct metadata:
569 569 md = self.metadata[msg_id]
570 570 md.update(self._extract_metadata(header, parent, content))
571 571 # is this redundant?
572 572 self.metadata[msg_id] = md
573 573
574 574 e_outstanding = self._outstanding_dict[md['engine_uuid']]
575 575 if msg_id in e_outstanding:
576 576 e_outstanding.remove(msg_id)
577 577
578 578 # construct result:
579 579 if content['status'] == 'ok':
580 580 self.results[msg_id] = util.unserialize_object(msg['buffers'])[0]
581 581 elif content['status'] == 'aborted':
582 582 self.results[msg_id] = error.TaskAborted(msg_id)
583 583 elif content['status'] == 'resubmitted':
584 584 # TODO: handle resubmission
585 585 pass
586 586 else:
587 587 self.results[msg_id] = self._unwrap_exception(content)
588 588
589 589 def _flush_notifications(self):
590 590 """Flush notifications of engine registrations waiting
591 591 in ZMQ queue."""
592 592 msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
593 593 while msg is not None:
594 594 if self.debug:
595 595 pprint(msg)
596 596 msg = msg[-1]
597 597 msg_type = msg['msg_type']
598 598 handler = self._notification_handlers.get(msg_type, None)
599 599 if handler is None:
600 600 raise Exception("Unhandled message type: %s"%msg.msg_type)
601 601 else:
602 602 handler(msg)
603 603 msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
604 604
605 605 def _flush_results(self, sock):
606 606 """Flush task or queue results waiting in ZMQ queue."""
607 607 msg = self.session.recv(sock, mode=zmq.NOBLOCK)
608 608 while msg is not None:
609 609 if self.debug:
610 610 pprint(msg)
611 611 msg = msg[-1]
612 612 msg_type = msg['msg_type']
613 613 handler = self._queue_handlers.get(msg_type, None)
614 614 if handler is None:
615 615 raise Exception("Unhandled message type: %s"%msg.msg_type)
616 616 else:
617 617 handler(msg)
618 618 msg = self.session.recv(sock, mode=zmq.NOBLOCK)
619 619
620 620 def _flush_control(self, sock):
621 621 """Flush replies from the control channel waiting
622 622 in the ZMQ queue.
623 623
624 624 Currently: ignore them."""
625 625 if self._ignored_control_replies <= 0:
626 626 return
627 627 msg = self.session.recv(sock, mode=zmq.NOBLOCK)
628 628 while msg is not None:
629 629 self._ignored_control_replies -= 1
630 630 if self.debug:
631 631 pprint(msg)
632 632 msg = self.session.recv(sock, mode=zmq.NOBLOCK)
633 633
634 634 def _flush_ignored_control(self):
635 635 """flush ignored control replies"""
636 636 while self._ignored_control_replies > 0:
637 637 self.session.recv(self._control_socket)
638 638 self._ignored_control_replies -= 1
639 639
640 640 def _flush_ignored_hub_replies(self):
641 641 msg = self.session.recv(self._query_socket, mode=zmq.NOBLOCK)
642 642 while msg is not None:
643 643 msg = self.session.recv(self._query_socket, mode=zmq.NOBLOCK)
644 644
645 645 def _flush_iopub(self, sock):
646 646 """Flush replies from the iopub channel waiting
647 647 in the ZMQ queue.
648 648 """
649 649 msg = self.session.recv(sock, mode=zmq.NOBLOCK)
650 650 while msg is not None:
651 651 if self.debug:
652 652 pprint(msg)
653 653 msg = msg[-1]
654 654 parent = msg['parent_header']
655 655 msg_id = parent['msg_id']
656 656 content = msg['content']
657 657 header = msg['header']
658 658 msg_type = msg['msg_type']
659 659
660 660 # init metadata:
661 661 md = self.metadata[msg_id]
662 662
663 663 if msg_type == 'stream':
664 664 name = content['name']
665 665 s = md[name] or ''
666 666 md[name] = s + content['data']
667 667 elif msg_type == 'pyerr':
668 668 md.update({'pyerr' : self._unwrap_exception(content)})
669 669 elif msg_type == 'pyin':
670 670 md.update({'pyin' : content['code']})
671 671 else:
672 672 md.update({msg_type : content.get('data', '')})
673 673
674 674 # redundant?
675 675 self.metadata[msg_id] = md
676 676
677 677 msg = self.session.recv(sock, mode=zmq.NOBLOCK)
678 678
679 679 #--------------------------------------------------------------------------
680 680 # len, getitem
681 681 #--------------------------------------------------------------------------
682 682
683 683 def __len__(self):
684 684 """len(client) returns # of engines."""
685 685 return len(self.ids)
686 686
687 687 def __getitem__(self, key):
688 688 """index access returns DirectView multiplexer objects
689 689
690 690 Must be int, slice, or list/tuple/xrange of ints"""
691 691 if not isinstance(key, (int, slice, tuple, list, xrange)):
692 692 raise TypeError("key by int/slice/iterable of ints only, not %s"%(type(key)))
693 693 else:
694 694 return self.direct_view(key)
695 695
696 696 #--------------------------------------------------------------------------
697 697 # Begin public methods
698 698 #--------------------------------------------------------------------------
699 699
700 700 @property
701 701 def ids(self):
702 702 """Always up-to-date ids property."""
703 703 self._flush_notifications()
704 704 # always copy:
705 705 return list(self._ids)
706 706
707 707 def close(self):
708 708 if self._closed:
709 709 return
710 710 snames = filter(lambda n: n.endswith('socket'), dir(self))
711 711 for socket in map(lambda name: getattr(self, name), snames):
712 712 if isinstance(socket, zmq.Socket) and not socket.closed:
713 713 socket.close()
714 714 self._closed = True
715 715
716 716 def spin(self):
717 717 """Flush any registration notifications and execution results
718 718 waiting in the ZMQ queue.
719 719 """
720 720 if self._notification_socket:
721 721 self._flush_notifications()
722 722 if self._mux_socket:
723 723 self._flush_results(self._mux_socket)
724 724 if self._task_socket:
725 725 self._flush_results(self._task_socket)
726 726 if self._control_socket:
727 727 self._flush_control(self._control_socket)
728 728 if self._iopub_socket:
729 729 self._flush_iopub(self._iopub_socket)
730 730 if self._query_socket:
731 731 self._flush_ignored_hub_replies()
732 732
733 733 def wait(self, jobs=None, timeout=-1):
734 734 """waits on one or more `jobs`, for up to `timeout` seconds.
735 735
736 736 Parameters
737 737 ----------
738 738
739 739 jobs : int, str, or list of ints and/or strs, or one or more AsyncResult objects
740 740 ints are indices to self.history
741 741 strs are msg_ids
742 742 default: wait on all outstanding messages
743 743 timeout : float
744 744 a time in seconds, after which to give up.
745 745 default is -1, which means no timeout
746 746
747 747 Returns
748 748 -------
749 749
750 750 True : when all msg_ids are done
751 751 False : timeout reached, some msg_ids still outstanding
752 752 """
753 753 tic = time.time()
754 754 if jobs is None:
755 755 theids = self.outstanding
756 756 else:
757 757 if isinstance(jobs, (int, str, AsyncResult)):
758 758 jobs = [jobs]
759 759 theids = set()
760 760 for job in jobs:
761 761 if isinstance(job, int):
762 762 # index access
763 763 job = self.history[job]
764 764 elif isinstance(job, AsyncResult):
765 765 map(theids.add, job.msg_ids)
766 766 continue
767 767 theids.add(job)
768 768 if not theids.intersection(self.outstanding):
769 769 return True
770 770 self.spin()
771 771 while theids.intersection(self.outstanding):
772 772 if timeout >= 0 and ( time.time()-tic ) > timeout:
773 773 break
774 774 time.sleep(1e-3)
775 775 self.spin()
776 776 return len(theids.intersection(self.outstanding)) == 0
777 777
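A minimal usage sketch of waiting on submitted work (illustrative only; assumes a cluster already started with :command:`ipcluster start`)::

    from IPython.parallel import Client
    import time

    rc = Client()                       # assumes a running cluster
    view = rc[:]                        # DirectView on all engines
    ar = view.apply_async(time.sleep, 2)
    # block for up to 5 seconds on just this submission
    if not rc.wait(ar, timeout=5):
        print("still outstanding: %s" % rc.outstanding)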
778 778 #--------------------------------------------------------------------------
779 779 # Control methods
780 780 #--------------------------------------------------------------------------
781 781
782 782 @spin_first
783 783 def clear(self, targets=None, block=None):
784 784 """Clear the namespace in target(s)."""
785 785 block = self.block if block is None else block
786 786 targets = self._build_targets(targets)[0]
787 787 for t in targets:
788 788 self.session.send(self._control_socket, 'clear_request', content={}, ident=t)
789 789 error = False
790 790 if block:
791 791 self._flush_ignored_control()
792 792 for i in range(len(targets)):
793 793 idents,msg = self.session.recv(self._control_socket,0)
794 794 if self.debug:
795 795 pprint(msg)
796 796 if msg['content']['status'] != 'ok':
797 797 error = self._unwrap_exception(msg['content'])
798 798 else:
799 799 self._ignored_control_replies += len(targets)
800 800 if error:
801 801 raise error
802 802
803 803
804 804 @spin_first
805 805 def abort(self, jobs=None, targets=None, block=None):
806 806 """Abort specific jobs from the execution queues of target(s).
807 807
808 808 This is a mechanism to prevent jobs that have already been submitted
809 809 from executing.
810 810
811 811 Parameters
812 812 ----------
813 813
814 814 jobs : msg_id, list of msg_ids, or AsyncResult
815 815 The jobs to be aborted
816 816
817 817
818 818 """
819 819 block = self.block if block is None else block
820 820 targets = self._build_targets(targets)[0]
821 821 msg_ids = []
822 822 if isinstance(jobs, (basestring,AsyncResult)):
823 823 jobs = [jobs]
824 824 bad_ids = filter(lambda obj: not isinstance(obj, (basestring, AsyncResult)), jobs)
825 825 if bad_ids:
826 826 raise TypeError("Invalid msg_id type %r, expected str or AsyncResult"%bad_ids[0])
827 827 for j in jobs:
828 828 if isinstance(j, AsyncResult):
829 829 msg_ids.extend(j.msg_ids)
830 830 else:
831 831 msg_ids.append(j)
832 832 content = dict(msg_ids=msg_ids)
833 833 for t in targets:
834 834 self.session.send(self._control_socket, 'abort_request',
835 835 content=content, ident=t)
836 836 error = False
837 837 if block:
838 838 self._flush_ignored_control()
839 839 for i in range(len(targets)):
840 840 idents,msg = self.session.recv(self._control_socket,0)
841 841 if self.debug:
842 842 pprint(msg)
843 843 if msg['content']['status'] != 'ok':
844 844 error = self._unwrap_exception(msg['content'])
845 845 else:
846 846 self._ignored_control_replies += len(targets)
847 847 if error:
848 848 raise error
849 849
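A sketch of aborting work that is still queued, assuming ``rc`` is a connected Client::

    import time

    v = rc[0]                                    # DirectView on engine 0
    ars = [v.apply_async(time.sleep, 5) for i in range(4)]
    # the first job is already running; drop the ones still queued behind it
    rc.abort(jobs=ars[1:], targets=0, block=True)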
850 850 @spin_first
851 851 def shutdown(self, targets=None, restart=False, hub=False, block=None):
852 852 """Terminates one or more engine processes, optionally including the hub."""
853 853 block = self.block if block is None else block
854 854 if hub:
855 855 targets = 'all'
856 856 targets = self._build_targets(targets)[0]
857 857 for t in targets:
858 858 self.session.send(self._control_socket, 'shutdown_request',
859 859 content={'restart':restart},ident=t)
860 860 error = False
861 861 if block or hub:
862 862 self._flush_ignored_control()
863 863 for i in range(len(targets)):
864 864 idents,msg = self.session.recv(self._control_socket, 0)
865 865 if self.debug:
866 866 pprint(msg)
867 867 if msg['content']['status'] != 'ok':
868 868 error = self._unwrap_exception(msg['content'])
869 869 else:
870 870 self._ignored_control_replies += len(targets)
871 871
872 872 if hub:
873 873 time.sleep(0.25)
874 874 self.session.send(self._query_socket, 'shutdown_request')
875 875 idents,msg = self.session.recv(self._query_socket, 0)
876 876 if self.debug:
877 877 pprint(msg)
878 878 if msg['content']['status'] != 'ok':
879 879 error = self._unwrap_exception(msg['content'])
880 880
881 881 if error:
882 882 raise error
883 883
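For example (a sketch, assuming ``rc`` is a connected Client; these calls really do terminate remote processes)::

    rc.shutdown(targets=0, block=True)       # stop a single engine
    rc.shutdown(hub=True, block=True)        # stop all engines and the Hub itself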
884 884 #--------------------------------------------------------------------------
885 885 # Execution related methods
886 886 #--------------------------------------------------------------------------
887 887
888 888 def _maybe_raise(self, result):
889 889 """wrapper for maybe raising an exception if apply failed."""
890 890 if isinstance(result, error.RemoteError):
891 891 raise result
892 892
893 893 return result
894 894
895 895 def send_apply_message(self, socket, f, args=None, kwargs=None, subheader=None, track=False,
896 896 ident=None):
897 897 """construct and send an apply message via a socket.
898 898
899 899 This is the principal method with which all engine execution is performed by views.
900 900 """
901 901
902 902 assert not self._closed, "cannot use me anymore, I'm closed!"
903 903 # defaults:
904 904 args = args if args is not None else []
905 905 kwargs = kwargs if kwargs is not None else {}
906 906 subheader = subheader if subheader is not None else {}
907 907
908 908 # validate arguments
909 909 if not callable(f):
910 910 raise TypeError("f must be callable, not %s"%type(f))
911 911 if not isinstance(args, (tuple, list)):
912 912 raise TypeError("args must be tuple or list, not %s"%type(args))
913 913 if not isinstance(kwargs, dict):
914 914 raise TypeError("kwargs must be dict, not %s"%type(kwargs))
915 915 if not isinstance(subheader, dict):
916 916 raise TypeError("subheader must be dict, not %s"%type(subheader))
917 917
918 918 bufs = util.pack_apply_message(f,args,kwargs)
919 919
920 920 msg = self.session.send(socket, "apply_request", buffers=bufs, ident=ident,
921 921 subheader=subheader, track=track)
922 922
923 923 msg_id = msg['msg_id']
924 924 self.outstanding.add(msg_id)
925 925 if ident:
926 926 # possibly routed to a specific engine
927 927 if isinstance(ident, list):
928 928 ident = ident[-1]
929 929 if ident in self._engines.values():
930 930 # save for later, in case of engine death
931 931 self._outstanding_dict[ident].add(msg_id)
932 932 self.history.append(msg_id)
933 933 self.metadata[msg_id]['submitted'] = datetime.now()
934 934
935 935 return msg
936 936
937 937 #--------------------------------------------------------------------------
938 938 # construct a View object
939 939 #--------------------------------------------------------------------------
940 940
941 941 def load_balanced_view(self, targets=None):
942 942 """construct a LoadBalancedView object.
943 943
944 944 If no arguments are specified, create a LoadBalancedView
945 945 using all engines.
946 946
947 947 Parameters
948 948 ----------
949 949
950 950 targets: list,slice,int,etc. [default: use all engines]
951 951 The subset of engines across which to load-balance
952 952 """
953 953 if targets is not None:
954 954 targets = self._build_targets(targets)[1]
955 955 return LoadBalancedView(client=self, socket=self._task_socket, targets=targets)
956 956
957 957 def direct_view(self, targets='all'):
958 958 """construct a DirectView object.
959 959
960 960 If no targets are specified, create a DirectView
961 961 using all engines.
962 962
963 963 Parameters
964 964 ----------
965 965
966 966 targets: list,slice,int,etc. [default: use all engines]
967 967 The engines to use for the View
968 968 """
969 969 single = isinstance(targets, int)
970 970 targets = self._build_targets(targets)[1]
971 971 if single:
972 972 targets = targets[0]
973 973 return DirectView(client=self, socket=self._mux_socket, targets=targets)
974 974
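A sketch contrasting the two view constructors (assumes a running cluster)::

    import time
    from IPython.parallel import Client

    rc = Client()
    dv = rc.direct_view('all')            # equivalent to rc[:]
    lv = rc.load_balanced_view()          # dynamic, least-load scheduling

    dv.apply_sync(time.gmtime)            # runs once on every engine
    amr = lv.map_async(lambda x: x ** 2, range(16))
    print(amr.get())                      # each element computed on some engine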
975 975 #--------------------------------------------------------------------------
976 976 # Query methods
977 977 #--------------------------------------------------------------------------
978 978
979 979 @spin_first
980 980 def get_result(self, indices_or_msg_ids=None, block=None):
981 981 """Retrieve a result by msg_id or history index, wrapped in an AsyncResult object.
982 982
983 983 If the client already has the results, no request to the Hub will be made.
984 984
985 985 This is a convenient way to construct AsyncResult objects, which are wrappers
986 986 that include metadata about execution, and allow for awaiting results that
987 987 were not submitted by this Client.
988 988
989 989 It can also be a convenient way to retrieve the metadata associated with
990 990 blocking execution, since the metadata is always retrieved along with the result.
991 991
992 992 Examples
993 993 --------
994 994 ::
995 995
996 996 In [10]: r = client.apply()
997 997
998 998 Parameters
999 999 ----------
1000 1000
1001 1001 indices_or_msg_ids : integer history index, str msg_id, or list of either
1002 1002 The indices or msg_ids to be retrieved
1003 1003
1004 1004 block : bool
1005 1005 Whether to wait for the result to be done
1006 1006
1007 1007 Returns
1008 1008 -------
1009 1009
1010 1010 AsyncResult
1011 1011 A single AsyncResult object will always be returned.
1012 1012
1013 1013 AsyncHubResult
1014 1014 A subclass of AsyncResult that retrieves results from the Hub
1015 1015
1016 1016 """
1017 1017 block = self.block if block is None else block
1018 1018 if indices_or_msg_ids is None:
1019 1019 indices_or_msg_ids = -1
1020 1020
1021 1021 if not isinstance(indices_or_msg_ids, (list,tuple)):
1022 1022 indices_or_msg_ids = [indices_or_msg_ids]
1023 1023
1024 1024 theids = []
1025 1025 for id in indices_or_msg_ids:
1026 1026 if isinstance(id, int):
1027 1027 id = self.history[id]
1028 1028 if not isinstance(id, str):
1029 1029 raise TypeError("indices must be str or int, not %r"%id)
1030 1030 theids.append(id)
1031 1031
1032 1032 local_ids = filter(lambda msg_id: msg_id in self.history or msg_id in self.results, theids)
1033 1033 remote_ids = filter(lambda msg_id: msg_id not in local_ids, theids)
1034 1034
1035 1035 if remote_ids:
1036 1036 ar = AsyncHubResult(self, msg_ids=theids)
1037 1037 else:
1038 1038 ar = AsyncResult(self, msg_ids=theids)
1039 1039
1040 1040 if block:
1041 1041 ar.wait()
1042 1042
1043 1043 return ar
1044 1044
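A sketch of re-fetching a result by history index, assuming ``rc`` is a connected Client::

    ar = rc[0].apply_async(lambda x: x + 1, 41)
    ar.get()                                   # block until done
    # rebuild an AsyncResult for the most recent submission by history index
    ar2 = rc.get_result(-1)
    print(ar2.get())
    print(rc.metadata[ar2.msg_ids[0]]['engine_uuid'])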
1045 1045 @spin_first
1046 1046 def resubmit(self, indices_or_msg_ids=None, subheader=None, block=None):
1047 1047 """Resubmit one or more tasks.
1048 1048
1049 1049 in-flight tasks may not be resubmitted.
1050 1050
1051 1051 Parameters
1052 1052 ----------
1053 1053
1054 1054 indices_or_msg_ids : integer history index, str msg_id, or list of either
1055 1055 The indices or msg_ids to be retrieved
1056 1056
1057 1057 block : bool
1058 1058 Whether to wait for the result to be done
1059 1059
1060 1060 Returns
1061 1061 -------
1062 1062
1063 1063 AsyncHubResult
1064 1064 A subclass of AsyncResult that retrieves results from the Hub
1065 1065
1066 1066 """
1067 1067 block = self.block if block is None else block
1068 1068 if indices_or_msg_ids is None:
1069 1069 indices_or_msg_ids = -1
1070 1070
1071 1071 if not isinstance(indices_or_msg_ids, (list,tuple)):
1072 1072 indices_or_msg_ids = [indices_or_msg_ids]
1073 1073
1074 1074 theids = []
1075 1075 for id in indices_or_msg_ids:
1076 1076 if isinstance(id, int):
1077 1077 id = self.history[id]
1078 1078 if not isinstance(id, str):
1079 1079 raise TypeError("indices must be str or int, not %r"%id)
1080 1080 theids.append(id)
1081 1081
1082 1082 for msg_id in theids:
1083 1083 self.outstanding.discard(msg_id)
1084 1084 if msg_id in self.history:
1085 1085 self.history.remove(msg_id)
1086 1086 self.results.pop(msg_id, None)
1087 1087 self.metadata.pop(msg_id, None)
1088 1088 content = dict(msg_ids = theids)
1089 1089
1090 1090 self.session.send(self._query_socket, 'resubmit_request', content)
1091 1091
1092 1092 zmq.select([self._query_socket], [], [])
1093 1093 idents,msg = self.session.recv(self._query_socket, zmq.NOBLOCK)
1094 1094 if self.debug:
1095 1095 pprint(msg)
1096 1096 content = msg['content']
1097 1097 if content['status'] != 'ok':
1098 1098 raise self._unwrap_exception(content)
1099 1099
1100 1100 ar = AsyncHubResult(self, msg_ids=theids)
1101 1101
1102 1102 if block:
1103 1103 ar.wait()
1104 1104
1105 1105 return ar
1106 1106
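A sketch of resubmitting a finished task, assuming ``rc`` is a connected Client::

    import time

    lv = rc.load_balanced_view()
    ar = lv.apply_async(time.time)
    ar.get()                          # must be finished; in-flight tasks can't be resubmitted
    ar2 = rc.resubmit(ar.msg_ids)     # run the same task again
    print(ar2.get())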
1107 1107 @spin_first
1108 1108 def result_status(self, msg_ids, status_only=True):
1109 1109 """Check on the status of the result(s) of the apply request with `msg_ids`.
1110 1110
1111 1111 If status_only is False, then the actual results will be retrieved, else
1112 1112 only the status of the results will be checked.
1113 1113
1114 1114 Parameters
1115 1115 ----------
1116 1116
1117 1117 msg_ids : list of msg_ids
1118 1118 if int:
1119 1119 Passed as index to self.history for convenience.
1120 1120 status_only : bool (default: True)
1121 1121 if False:
1122 1122 Retrieve the actual results of completed tasks.
1123 1123
1124 1124 Returns
1125 1125 -------
1126 1126
1127 1127 results : dict
1128 1128 There will always be the keys 'pending' and 'completed', which will
1129 1129 be lists of msg_ids that are incomplete or complete. If `status_only`
1130 1130 is False, then completed results will be keyed by their `msg_id`.
1131 1131 """
1132 1132 if not isinstance(msg_ids, (list,tuple)):
1133 1133 msg_ids = [msg_ids]
1134 1134
1135 1135 theids = []
1136 1136 for msg_id in msg_ids:
1137 1137 if isinstance(msg_id, int):
1138 1138 msg_id = self.history[msg_id]
1139 1139 if not isinstance(msg_id, basestring):
1140 1140 raise TypeError("msg_ids must be str, not %r"%msg_id)
1141 1141 theids.append(msg_id)
1142 1142
1143 1143 completed = []
1144 1144 local_results = {}
1145 1145
1146 1146 # comment this block out to temporarily disable local shortcut:
1147 1147 for msg_id in list(theids): # iterate over a copy; theids is pruned in the loop
1148 1148 if msg_id in self.results:
1149 1149 completed.append(msg_id)
1150 1150 local_results[msg_id] = self.results[msg_id]
1151 1151 theids.remove(msg_id)
1152 1152
1153 1153 if theids: # some not locally cached
1154 1154 content = dict(msg_ids=theids, status_only=status_only)
1155 1155 msg = self.session.send(self._query_socket, "result_request", content=content)
1156 1156 zmq.select([self._query_socket], [], [])
1157 1157 idents,msg = self.session.recv(self._query_socket, zmq.NOBLOCK)
1158 1158 if self.debug:
1159 1159 pprint(msg)
1160 1160 content = msg['content']
1161 1161 if content['status'] != 'ok':
1162 1162 raise self._unwrap_exception(content)
1163 1163 buffers = msg['buffers']
1164 1164 else:
1165 1165 content = dict(completed=[],pending=[])
1166 1166
1167 1167 content['completed'].extend(completed)
1168 1168
1169 1169 if status_only:
1170 1170 return content
1171 1171
1172 1172 failures = []
1173 1173 # load cached results into result:
1174 1174 content.update(local_results)
1175 1175 # update cache with results:
1176 1176 for msg_id in sorted(theids):
1177 1177 if msg_id in content['completed']:
1178 1178 rec = content[msg_id]
1179 1179 parent = rec['header']
1180 1180 header = rec['result_header']
1181 1181 rcontent = rec['result_content']
1182 1182 iodict = rec['io']
1183 1183 if isinstance(rcontent, str):
1184 1184 rcontent = self.session.unpack(rcontent)
1185 1185
1186 1186 md = self.metadata[msg_id]
1187 1187 md.update(self._extract_metadata(header, parent, rcontent))
1188 1188 md.update(iodict)
1189 1189
1190 1190 if rcontent['status'] == 'ok':
1191 1191 res,buffers = util.unserialize_object(buffers)
1192 1192 else:
1193 1193 print rcontent
1194 1194 res = self._unwrap_exception(rcontent)
1195 1195 failures.append(res)
1196 1196
1197 1197 self.results[msg_id] = res
1198 1198 content[msg_id] = res
1199 1199
1200 1200 if len(theids) == 1 and failures:
1201 1201 raise failures[0]
1202 1202
1203 1203 error.collect_exceptions(failures, "result_status")
1204 1204 return content
1205 1205
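A sketch of polling completion without fetching results, assuming ``rc`` is a connected Client::

    import time

    lv = rc.load_balanced_view()
    ars = [lv.apply_async(time.sleep, 2) for i in range(4)]
    msg_ids = [mid for ar in ars for mid in ar.msg_ids]
    status = rc.result_status(msg_ids, status_only=True)
    print("%i done, %i pending" % (len(status['completed']), len(status['pending'])))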
1206 1206 @spin_first
1207 1207 def queue_status(self, targets='all', verbose=False):
1208 1208 """Fetch the status of engine queues.
1209 1209
1210 1210 Parameters
1211 1211 ----------
1212 1212
1213 1213 targets : int/str/list of ints/strs
1214 1214 the engines whose states are to be queried.
1215 1215 default : all
1216 1216 verbose : bool
1217 1217 Whether to return lengths only, or lists of ids for each element
1218 1218 """
1219 1219 engine_ids = self._build_targets(targets)[1]
1220 1220 content = dict(targets=engine_ids, verbose=verbose)
1221 1221 self.session.send(self._query_socket, "queue_request", content=content)
1222 1222 idents,msg = self.session.recv(self._query_socket, 0)
1223 1223 if self.debug:
1224 1224 pprint(msg)
1225 1225 content = msg['content']
1226 1226 status = content.pop('status')
1227 1227 if status != 'ok':
1228 1228 raise self._unwrap_exception(content)
1229 1229 content = util.rekey(content)
1230 1230 if isinstance(targets, int):
1231 1231 return content[targets]
1232 1232 else:
1233 1233 return content
1234 1234
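For example (assumes ``rc`` is a connected Client)::

    print(rc.queue_status())                           # summary counts keyed by engine id
    print(rc.queue_status(targets=0, verbose=True))    # full msg_id lists for engine 0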
1235 1235 @spin_first
1236 1236 def purge_results(self, jobs=[], targets=[]):
1237 1237 """Tell the Hub to forget results.
1238 1238
1239 1239 Individual results can be purged by msg_id, or the entire
1240 1240 history of specific targets can be purged.
1241 1241
1242 1242 Parameters
1243 1243 ----------
1244 1244
1245 1245 jobs : str or list of str or AsyncResult objects
1246 1246 the msg_ids whose results should be forgotten.
1247 1247 targets : int/str/list of ints/strs
1248 1248 The targets, by uuid or int_id, whose entire history is to be purged.
1249 1249 Use `targets='all'` to scrub everything from the Hub's memory.
1250 1250
1251 1251 default : None
1252 1252 """
1253 1253 if not targets and not jobs:
1254 1254 raise ValueError("Must specify at least one of `targets` and `jobs`")
1255 1255 if targets:
1256 1256 targets = self._build_targets(targets)[1]
1257 1257
1258 1258 # construct msg_ids from jobs
1259 1259 msg_ids = []
1260 1260 if isinstance(jobs, (basestring,AsyncResult)):
1261 1261 jobs = [jobs]
1262 1262 bad_ids = filter(lambda obj: not isinstance(obj, (basestring, AsyncResult)), jobs)
1263 1263 if bad_ids:
1264 1264 raise TypeError("Invalid msg_id type %r, expected str or AsyncResult"%bad_ids[0])
1265 1265 for j in jobs:
1266 1266 if isinstance(j, AsyncResult):
1267 1267 msg_ids.extend(j.msg_ids)
1268 1268 else:
1269 1269 msg_ids.append(j)
1270 1270
1271 1271 content = dict(targets=targets, msg_ids=msg_ids)
1272 1272 self.session.send(self._query_socket, "purge_request", content=content)
1273 1273 idents, msg = self.session.recv(self._query_socket, 0)
1274 1274 if self.debug:
1275 1275 pprint(msg)
1276 1276 content = msg['content']
1277 1277 if content['status'] != 'ok':
1278 1278 raise self._unwrap_exception(content)
1279 1279
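A sketch, assuming ``rc`` is a connected Client and ``ar`` is a finished AsyncResult whose record is no longer needed::

    rc.purge_results(jobs=ar)           # forget one batch of results
    rc.purge_results(targets='all')     # or wipe every engine's history on the Hub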
1280 1280 @spin_first
1281 1281 def hub_history(self):
1282 1282 """Get the Hub's history
1283 1283
1284 1284 Just like the Client, the Hub has a history, which is a list of msg_ids.
1285 1285 This will contain the history of all clients, and, depending on configuration,
1286 1286 may contain history across multiple cluster sessions.
1287 1287
1288 1288 Any msg_id returned here is a valid argument to `get_result`.
1289 1289
1290 1290 Returns
1291 1291 -------
1292 1292
1293 1293 msg_ids : list of strs
1294 1294 list of all msg_ids, ordered by task submission time.
1295 1295 """
1296 1296
1297 1297 self.session.send(self._query_socket, "history_request", content={})
1298 1298 idents, msg = self.session.recv(self._query_socket, 0)
1299 1299
1300 1300 if self.debug:
1301 1301 pprint(msg)
1302 1302 content = msg['content']
1303 1303 if content['status'] != 'ok':
1304 1304 raise self._unwrap_exception(content)
1305 1305 else:
1306 1306 return content['history']
1307 1307
1308 1308 @spin_first
1309 1309 def db_query(self, query, keys=None):
1310 1310 """Query the Hub's TaskRecord database
1311 1311
1312 1312 This will return a list of task record dicts that match `query`
1313 1313
1314 1314 Parameters
1315 1315 ----------
1316 1316
1317 1317 query : mongodb query dict
1318 1318 The search dict. See mongodb query docs for details.
1319 1319 keys : list of strs [optional]
1320 1320 The subset of keys to be returned. The default is to fetch everything but buffers.
1321 1321 'msg_id' will *always* be included.
1322 1322 """
1323 1323 if isinstance(keys, basestring):
1324 1324 keys = [keys]
1325 1325 content = dict(query=query, keys=keys)
1326 1326 self.session.send(self._query_socket, "db_request", content=content)
1327 1327 idents, msg = self.session.recv(self._query_socket, 0)
1328 1328 if self.debug:
1329 1329 pprint(msg)
1330 1330 content = msg['content']
1331 1331 if content['status'] != 'ok':
1332 1332 raise self._unwrap_exception(content)
1333 1333
1334 1334 records = content['records']
1335 1335 buffer_lens = content['buffer_lens']
1336 1336 result_buffer_lens = content['result_buffer_lens']
1337 1337 buffers = msg['buffers']
1338 1338 has_bufs = buffer_lens is not None
1339 1339 has_rbufs = result_buffer_lens is not None
1340 1340 for i,rec in enumerate(records):
1341 1341 # relink buffers
1342 1342 if has_bufs:
1343 1343 blen = buffer_lens[i]
1344 1344 rec['buffers'], buffers = buffers[:blen],buffers[blen:]
1345 1345 if has_rbufs:
1346 1346 blen = result_buffer_lens[i]
1347 1347 rec['result_buffers'], buffers = buffers[:blen],buffers[blen:]
1348 1348 # turn timestamps back into times
1349 1349 for key in 'submitted started completed resubmitted'.split():
1350 1350 maybedate = rec.get(key, None)
1351 1351 if maybedate and util.ISO8601_RE.match(maybedate):
1352 1352 rec[key] = datetime.strptime(maybedate, util.ISO8601)
1353 1353
1354 1354 return records
1355 1355
1356 1356 __all__ = [ 'Client' ]
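The two Hub-side query helpers compose naturally. A sketch, assuming ``rc`` is a connected Client; how much of the mongodb operator syntax is honored depends on the configured DB backend::

    all_ids = rc.hub_history()          # every msg_id the Hub has recorded, oldest first
    recs = rc.db_query({'completed': {'$ne': None}},
                       keys=['msg_id', 'started', 'completed'])
    for rec in recs:
        print(rec['msg_id'], rec['completed'] - rec['started'])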
@@ -1,163 +1,165
1 1 #!/usr/bin/env python
2 2 """
3 3 A multi-heart Heartbeat system using PUB and XREP sockets. pings are sent out on the PUB,
4 4 and hearts are tracked based on their XREQ identities.
5 5 """
6 6 #-----------------------------------------------------------------------------
7 7 # Copyright (C) 2010-2011 The IPython Development Team
8 8 #
9 9 # Distributed under the terms of the BSD License. The full license is in
10 10 # the file COPYING, distributed as part of this software.
11 11 #-----------------------------------------------------------------------------
12 12
13 13 from __future__ import print_function
14 14 import time
15 15 import uuid
16 16
17 17 import zmq
18 18 from zmq.devices import ProcessDevice, ThreadDevice
19 19 from zmq.eventloop import ioloop, zmqstream
20 20
21 from IPython.utils.traitlets import Set, Instance, CFloat, Bool
21 from IPython.utils.traitlets import Set, Instance, CFloat, Bool, CStr
22 22 from IPython.parallel.factory import LoggingFactory
23 23
24 24 class Heart(object):
25 25 """A basic heart object for responding to a HeartMonitor.
26 26 This is a simple wrapper with defaults for the most common
27 27 Device model for responding to heartbeats.
28 28
29 29 It simply builds a threadsafe zmq.FORWARDER Device, defaulting to using
30 30 SUB/XREQ for in/out.
31 31
32 32 You can specify the XREQ's IDENTITY via the optional heart_id argument."""
33 33 device=None
34 34 id=None
35 35 def __init__(self, in_addr, out_addr, in_type=zmq.SUB, out_type=zmq.XREQ, heart_id=None):
36 36 self.device = ThreadDevice(zmq.FORWARDER, in_type, out_type)
37 37 self.device.daemon=True
38 38 self.device.connect_in(in_addr)
39 39 self.device.connect_out(out_addr)
40 40 if in_type == zmq.SUB:
41 41 self.device.setsockopt_in(zmq.SUBSCRIBE, "")
42 42 if heart_id is None:
43 43 heart_id = str(uuid.uuid4())
44 44 self.device.setsockopt_out(zmq.IDENTITY, heart_id)
45 45 self.id = heart_id
46 46
47 47 def start(self):
48 48 return self.device.start()
49 49
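A sketch of running a standalone heart, assuming the Hub's heartbeat (ping PUB, pong XREP) addresses are known; the ports below are the same placeholders used in the ``__main__`` block at the bottom of this file::

    from IPython.parallel.controller.heartmonitor import Heart

    # placeholder addresses for the Hub's heartbeat pair
    heart = Heart('tcp://127.0.0.1:5555', 'tcp://127.0.0.1:5556',
                  heart_id='demo-heart')
    heart.start()     # runs the FORWARDER device in a daemon thread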
50 50 class HeartMonitor(LoggingFactory):
51 51 """A basic HeartMonitor class
52 52 pingstream: a PUB stream
53 53 pongstream: an XREP stream
54 54 period: the period of the heartbeat in milliseconds"""
55 55
56 period=CFloat(1000, config=True) # in milliseconds
56 period=CFloat(1000, config=True,
57 help='The period (in ms) at which the Hub pings the engines for heartbeats '
58 '[default: 1000]',
59 )
57 60
58 61 pingstream=Instance('zmq.eventloop.zmqstream.ZMQStream')
59 62 pongstream=Instance('zmq.eventloop.zmqstream.ZMQStream')
60 63 loop = Instance('zmq.eventloop.ioloop.IOLoop')
61 64 def _loop_default(self):
62 65 return ioloop.IOLoop.instance()
63 debug=Bool(False)
64 66
65 67 # not settable:
66 68 hearts=Set()
67 69 responses=Set()
68 70 on_probation=Set()
69 71 last_ping=CFloat(0)
70 72 _new_handlers = Set()
71 73 _failure_handlers = Set()
72 74 lifetime = CFloat(0)
73 75 tic = CFloat(0)
74 76
75 77 def __init__(self, **kwargs):
76 78 super(HeartMonitor, self).__init__(**kwargs)
77 79
78 80 self.pongstream.on_recv(self.handle_pong)
79 81
80 82 def start(self):
81 83 self.caller = ioloop.PeriodicCallback(self.beat, self.period, self.loop)
82 84 self.caller.start()
83 85
84 86 def add_new_heart_handler(self, handler):
85 87 """add a new handler for new hearts"""
86 88 self.log.debug("heartbeat::new_heart_handler: %s"%handler)
87 89 self._new_handlers.add(handler)
88 90
89 91 def add_heart_failure_handler(self, handler):
90 92 """add a new handler for heart failure"""
91 93 self.log.debug("heartbeat::new heart failure handler: %s"%handler)
92 94 self._failure_handlers.add(handler)
93 95
94 96 def beat(self):
95 97 self.pongstream.flush()
96 98 self.last_ping = self.lifetime
97 99
98 100 toc = time.time()
99 101 self.lifetime += toc-self.tic
100 102 self.tic = toc
101 103 # self.log.debug("heartbeat::%s"%self.lifetime)
102 104 goodhearts = self.hearts.intersection(self.responses)
103 105 missed_beats = self.hearts.difference(goodhearts)
104 106 heartfailures = self.on_probation.intersection(missed_beats)
105 107 newhearts = self.responses.difference(goodhearts)
106 108 map(self.handle_new_heart, newhearts)
107 109 map(self.handle_heart_failure, heartfailures)
108 110 self.on_probation = missed_beats.intersection(self.hearts)
109 111 self.responses = set()
110 112 # print self.on_probation, self.hearts
111 113 # self.log.debug("heartbeat::beat %.3f, %i beating hearts"%(self.lifetime, len(self.hearts)))
112 114 self.pingstream.send(str(self.lifetime))
113 115
114 116 def handle_new_heart(self, heart):
115 117 if self._new_handlers:
116 118 for handler in self._new_handlers:
117 119 handler(heart)
118 120 else:
119 121 self.log.info("heartbeat::yay, got new heart %s!"%heart)
120 122 self.hearts.add(heart)
121 123
122 124 def handle_heart_failure(self, heart):
123 125 if self._failure_handlers:
124 126 for handler in self._failure_handlers:
125 127 try:
126 128 handler(heart)
127 129 except Exception as e:
128 130 self.log.error("heartbeat::Bad Handler! %s"%handler, exc_info=True)
129 131 pass
130 132 else:
131 133 self.log.info("heartbeat::Heart %s failed :("%heart)
132 134 self.hearts.remove(heart)
133 135
134 136
135 137 def handle_pong(self, msg):
136 138 "a heart just beat"
137 139 if msg[1] == str(self.lifetime):
138 140 delta = time.time()-self.tic
139 141 # self.log.debug("heartbeat::heart %r took %.2f ms to respond"%(msg[0], 1000*delta))
140 142 self.responses.add(msg[0])
141 143 elif msg[1] == str(self.last_ping):
142 144 delta = time.time()-self.tic + (self.lifetime-self.last_ping)
143 145 self.log.warn("heartbeat::heart %r missed a beat, and took %.2f ms to respond"%(msg[0], 1000*delta))
144 146 self.responses.add(msg[0])
145 147 else:
146 148 self.log.warn("heartbeat::got bad heartbeat (possibly old?): %s (current=%.3f)"%
147 149 (msg[1],self.lifetime))
148 150
149 151
150 152 if __name__ == '__main__':
151 153 loop = ioloop.IOLoop.instance()
152 154 context = zmq.Context()
153 155 pub = context.socket(zmq.PUB)
154 156 pub.bind('tcp://127.0.0.1:5555')
155 157 xrep = context.socket(zmq.XREP)
156 158 xrep.bind('tcp://127.0.0.1:5556')
157 159
158 160 outstream = zmqstream.ZMQStream(pub, loop)
159 161 instream = zmqstream.ZMQStream(xrep, loop)
160 162
161 163 hb = HeartMonitor(loop=loop, pingstream=outstream, pongstream=instream)
162 164
163 165 loop.start()
@@ -1,1282 +1,1293
1 1 #!/usr/bin/env python
2 2 """The IPython Controller Hub with 0MQ
3 3 This is the master object that handles connections from engines and clients,
4 4 and monitors traffic through the various queues.
5 5 """
6 6 #-----------------------------------------------------------------------------
7 7 # Copyright (C) 2010 The IPython Development Team
8 8 #
9 9 # Distributed under the terms of the BSD License. The full license is in
10 10 # the file COPYING, distributed as part of this software.
11 11 #-----------------------------------------------------------------------------
12 12
13 13 #-----------------------------------------------------------------------------
14 14 # Imports
15 15 #-----------------------------------------------------------------------------
16 16 from __future__ import print_function
17 17
18 18 import sys
19 19 import time
20 20 from datetime import datetime
21 21
22 22 import zmq
23 23 from zmq.eventloop import ioloop
24 24 from zmq.eventloop.zmqstream import ZMQStream
25 25
26 26 # internal:
27 27 from IPython.utils.importstring import import_item
28 from IPython.utils.traitlets import HasTraits, Instance, Int, CStr, Str, Dict, Set, List, Bool
28 from IPython.utils.traitlets import (
29 HasTraits, Instance, Int, CStr, Str, Dict, Set, List, Bool, Tuple
30 )
29 31
30 32 from IPython.parallel import error, util
31 33 from IPython.parallel.factory import RegistrationFactory, LoggingFactory
32 34
33 35 from .heartmonitor import HeartMonitor
34 36
35 37 #-----------------------------------------------------------------------------
36 38 # Code
37 39 #-----------------------------------------------------------------------------
38 40
39 41 def _passer(*args, **kwargs):
40 42 return
41 43
42 44 def _printer(*args, **kwargs):
43 45 print (args)
44 46 print (kwargs)
45 47
46 48 def empty_record():
47 49 """Return an empty dict with all record keys."""
48 50 return {
49 51 'msg_id' : None,
50 52 'header' : None,
51 53 'content': None,
52 54 'buffers': None,
53 55 'submitted': None,
54 56 'client_uuid' : None,
55 57 'engine_uuid' : None,
56 58 'started': None,
57 59 'completed': None,
58 60 'resubmitted': None,
59 61 'result_header' : None,
60 62 'result_content' : None,
61 63 'result_buffers' : None,
62 64 'queue' : None,
63 65 'pyin' : None,
64 66 'pyout': None,
65 67 'pyerr': None,
66 68 'stdout': '',
67 69 'stderr': '',
68 70 }
69 71
70 72 def init_record(msg):
71 73 """Initialize a TaskRecord based on a request."""
72 74 header = msg['header']
73 75 return {
74 76 'msg_id' : header['msg_id'],
75 77 'header' : header,
76 78 'content': msg['content'],
77 79 'buffers': msg['buffers'],
78 80 'submitted': datetime.strptime(header['date'], util.ISO8601),
79 81 'client_uuid' : None,
80 82 'engine_uuid' : None,
81 83 'started': None,
82 84 'completed': None,
83 85 'resubmitted': None,
84 86 'result_header' : None,
85 87 'result_content' : None,
86 88 'result_buffers' : None,
87 89 'queue' : None,
88 90 'pyin' : None,
89 91 'pyout': None,
90 92 'pyerr': None,
91 93 'stdout': '',
92 94 'stderr': '',
93 95 }
94 96
95 97
96 98 class EngineConnector(HasTraits):
97 99 """A simple object for accessing the various zmq connections of an object.
98 100 Attributes are:
99 101 id (int): engine ID
100 102 uuid (str): uuid (unused?)
101 103 queue (str): identity of queue's XREQ socket
102 104 registration (str): identity of registration XREQ socket
103 105 heartbeat (str): identity of heartbeat XREQ socket
104 106 """
105 107 id=Int(0)
106 108 queue=Str()
107 109 control=Str()
108 110 registration=Str()
109 111 heartbeat=Str()
110 112 pending=Set()
111 113
112 114 class HubFactory(RegistrationFactory):
113 115 """The Configurable for setting up a Hub."""
114 116
115 # name of a scheduler scheme
116 scheme = Str('leastload', config=True)
117
118 117 # port-pairs for monitoredqueues:
119 hb = Instance(list, config=True)
118 hb = Tuple(Int,Int,config=True,
119 help="""XREQ/SUB Port pair for Engine heartbeats""")
120 120 def _hb_default(self):
121 return util.select_random_ports(2)
121 return tuple(util.select_random_ports(2))
122
123 mux = Tuple(Int,Int,config=True,
124 help="""Engine/Client Port pair for MUX queue""")
122 125
123 mux = Instance(list, config=True)
124 126 def _mux_default(self):
125 return util.select_random_ports(2)
127 return tuple(util.select_random_ports(2))
126 128
127 task = Instance(list, config=True)
129 task = Tuple(Int,Int,config=True,
130 help="""Engine/Client Port pair for Task queue""")
128 131 def _task_default(self):
129 return util.select_random_ports(2)
132 return tuple(util.select_random_ports(2))
133
134 control = Tuple(Int,Int,config=True,
135 help="""Engine/Client Port pair for Control queue""")
130 136
131 control = Instance(list, config=True)
132 137 def _control_default(self):
133 return util.select_random_ports(2)
138 return tuple(util.select_random_ports(2))
139
140 iopub = Tuple(Int,Int,config=True,
141 help="""Engine/Client Port pair for IOPub relay""")
134 142
135 iopub = Instance(list, config=True)
136 143 def _iopub_default(self):
137 return util.select_random_ports(2)
144 return tuple(util.select_random_ports(2))
138 145
139 146 # single ports:
140 mon_port = Instance(int, config=True)
147 mon_port = Int(config=True,
148 help="""Monitor (SUB) port for queue traffic""")
149
141 150 def _mon_port_default(self):
142 151 return util.select_random_ports(1)[0]
143 152
144 notifier_port = Instance(int, config=True)
153 notifier_port = Int(config=True,
154 help="""PUB port for sending engine status notifications""")
155
145 156 def _notifier_port_default(self):
146 157 return util.select_random_ports(1)[0]
147 158
148 ping = Int(1000, config=True) # ping frequency
159 engine_ip = CStr('127.0.0.1', config=True,
160 help="IP on which to listen for engine connections. [default: loopback]")
161 engine_transport = CStr('tcp', config=True,
162 help="0MQ transport for engine connections. [default: tcp]")
149 163
150 engine_ip = CStr('127.0.0.1', config=True)
151 engine_transport = CStr('tcp', config=True)
164 client_ip = CStr('127.0.0.1', config=True,
165 help="IP on which to listen for client connections. [default: loopback]")
166 client_transport = CStr('tcp', config=True,
167 help="0MQ transport for client connections. [default : tcp]")
152 168
153 client_ip = CStr('127.0.0.1', config=True)
154 client_transport = CStr('tcp', config=True)
155
156 monitor_ip = CStr('127.0.0.1', config=True)
157 monitor_transport = CStr('tcp', config=True)
169 monitor_ip = CStr('127.0.0.1', config=True,
170 help="IP on which to listen for monitor messages. [default: loopback]")
171 monitor_transport = CStr('tcp', config=True,
172 help="0MQ transport for monitor messages. [default : tcp]")
158 173
159 174 monitor_url = CStr('')
160 175
161 db_class = CStr('IPython.parallel.controller.dictdb.DictDB', config=True)
176 db_class = CStr('IPython.parallel.controller.dictdb.DictDB', config=True,
177 help="""The class to use for the DB backend""")
162 178
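Since these are all ``config=True`` traits, a cluster profile can pin them down. A hypothetical :file:`ipcontroller_config.py` fragment along these lines (IPs and ports are placeholders)::

    c = get_config()

    c.HubFactory.engine_ip = '10.0.0.1'       # where engines connect
    c.HubFactory.client_ip = '10.0.0.1'       # where clients connect
    c.HubFactory.hb  = (10100, 10101)         # pin the heartbeat port pair
    c.HubFactory.mux = (10200, 10201)         # pin the MUX queue port pair
    c.HubFactory.db_class = 'IPython.parallel.controller.dictdb.DictDB'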
163 179 # not configurable
164 180 db = Instance('IPython.parallel.controller.dictdb.BaseDB')
165 181 heartmonitor = Instance('IPython.parallel.controller.heartmonitor.HeartMonitor')
166 subconstructors = List()
167 _constructed = Bool(False)
168 182
169 183 def _ip_changed(self, name, old, new):
170 184 self.engine_ip = new
171 185 self.client_ip = new
172 186 self.monitor_ip = new
173 187 self._update_monitor_url()
174 188
175 189 def _update_monitor_url(self):
176 190 self.monitor_url = "%s://%s:%i"%(self.monitor_transport, self.monitor_ip, self.mon_port)
177 191
178 192 def _transport_changed(self, name, old, new):
179 193 self.engine_transport = new
180 194 self.client_transport = new
181 195 self.monitor_transport = new
182 196 self._update_monitor_url()
183 197
184 198 def __init__(self, **kwargs):
185 199 super(HubFactory, self).__init__(**kwargs)
186 200 self._update_monitor_url()
187 201 # self.on_trait_change(self._sync_ips, 'ip')
188 202 # self.on_trait_change(self._sync_transports, 'transport')
189 self.subconstructors.append(self.construct_hub)
203 # self.subconstructors.append(self.construct_hub)
190 204
191 205
192 206 def construct(self):
193 assert not self._constructed, "already constructed!"
194
195 for subc in self.subconstructors:
196 subc()
197
198 self._constructed = True
199
207 self.init_hub()
200 208
201 209 def start(self):
202 assert self._constructed, "must be constructed by self.construct() first!"
203 210 self.heartmonitor.start()
204 211 self.log.info("Heartmonitor started")
205 212
206 def construct_hub(self):
213 def init_hub(self):
207 214 """construct"""
208 215 client_iface = "%s://%s:"%(self.client_transport, self.client_ip) + "%i"
209 216 engine_iface = "%s://%s:"%(self.engine_transport, self.engine_ip) + "%i"
210 217
211 218 ctx = self.context
212 219 loop = self.loop
213 220
214 221 # Registrar socket
215 222 q = ZMQStream(ctx.socket(zmq.XREP), loop)
216 223 q.bind(client_iface % self.regport)
217 224 self.log.info("Hub listening on %s for registration."%(client_iface%self.regport))
218 225 if self.client_ip != self.engine_ip:
219 226 q.bind(engine_iface % self.regport)
220 227 self.log.info("Hub listening on %s for registration."%(engine_iface%self.regport))
221 228
222 229 ### Engine connections ###
223 230
224 231 # heartbeat
225 232 hpub = ctx.socket(zmq.PUB)
226 233 hpub.bind(engine_iface % self.hb[0])
227 234 hrep = ctx.socket(zmq.XREP)
228 235 hrep.bind(engine_iface % self.hb[1])
229 236 self.heartmonitor = HeartMonitor(loop=loop, pingstream=ZMQStream(hpub,loop), pongstream=ZMQStream(hrep,loop),
230 period=self.ping, logname=self.log.name)
237 config=self.config)
231 238
232 239 ### Client connections ###
233 240 # Notifier socket
234 241 n = ZMQStream(ctx.socket(zmq.PUB), loop)
235 242 n.bind(client_iface%self.notifier_port)
236 243
237 244 ### build and launch the queues ###
238 245
239 246 # monitor socket
240 247 sub = ctx.socket(zmq.SUB)
241 248 sub.setsockopt(zmq.SUBSCRIBE, "")
242 249 sub.bind(self.monitor_url)
243 250 sub.bind('inproc://monitor')
244 251 sub = ZMQStream(sub, loop)
245 252
246 253 # connect the db
247 254 self.log.info('Hub using DB backend: %r'%(self.db_class.split()[-1]))
248 255 # cdir = self.config.Global.cluster_dir
249 256 self.db = import_item(self.db_class)(session=self.session.session, config=self.config)
250 257 time.sleep(.25)
251
258 try:
259 scheme = self.config.TaskScheduler.scheme_name
260 except AttributeError:
261 from .scheduler import TaskScheduler
262 scheme = TaskScheduler.scheme_name.get_default_value()
252 263 # build connection dicts
253 264 self.engine_info = {
254 265 'control' : engine_iface%self.control[1],
255 266 'mux': engine_iface%self.mux[1],
256 267 'heartbeat': (engine_iface%self.hb[0], engine_iface%self.hb[1]),
257 268 'task' : engine_iface%self.task[1],
258 269 'iopub' : engine_iface%self.iopub[1],
259 270 # 'monitor' : engine_iface%self.mon_port,
260 271 }
261 272
262 273 self.client_info = {
263 274 'control' : client_iface%self.control[0],
264 275 'mux': client_iface%self.mux[0],
265 'task' : (self.scheme, client_iface%self.task[0]),
276 'task' : (scheme, client_iface%self.task[0]),
266 277 'iopub' : client_iface%self.iopub[0],
267 278 'notification': client_iface%self.notifier_port
268 279 }
269 280 self.log.debug("Hub engine addrs: %s"%self.engine_info)
270 281 self.log.debug("Hub client addrs: %s"%self.client_info)
271 282
272 283 # resubmit stream
273 284 r = ZMQStream(ctx.socket(zmq.XREQ), loop)
274 285 url = util.disambiguate_url(self.client_info['task'][-1])
275 286 r.setsockopt(zmq.IDENTITY, self.session.session)
276 287 r.connect(url)
277 288
278 289 self.hub = Hub(loop=loop, session=self.session, monitor=sub, heartmonitor=self.heartmonitor,
279 290 query=q, notifier=n, resubmit=r, db=self.db,
280 291 engine_info=self.engine_info, client_info=self.client_info,
281 292 logname=self.log.name)
282 293
283 294
284 295 class Hub(LoggingFactory):
285 296 """The IPython Controller Hub with 0MQ connections
286 297
287 298 Parameters
288 299 ==========
289 300 loop: zmq IOLoop instance
290 301 session: StreamSession object
291 302 <removed> context: zmq context for creating new connections (?)
292 303 queue: ZMQStream for monitoring the command queue (SUB)
293 304 query: ZMQStream for engine registration and client queries requests (XREP)
294 305 heartbeat: HeartMonitor object checking the pulse of the engines
295 306 notifier: ZMQStream for broadcasting engine registration changes (PUB)
296 307 db: connection to db for out of memory logging of commands
297 308 NotImplemented
298 309 engine_info: dict of zmq connection information for engines to connect
299 310 to the queues.
300 311 client_info: dict of zmq connection information for engines to connect
301 312 to the queues.
302 313 """
303 314 # internal data structures:
304 315 ids=Set() # engine IDs
305 316 keytable=Dict()
306 317 by_ident=Dict()
307 318 engines=Dict()
308 319 clients=Dict()
309 320 hearts=Dict()
310 321 pending=Set()
311 322 queues=Dict() # pending msg_ids keyed by engine_id
312 323 tasks=Dict() # pending msg_ids submitted as tasks, keyed by client_id
313 324 completed=Dict() # completed msg_ids keyed by engine_id
314 325 all_completed=Set() # set of all completed msg_ids
315 326 dead_engines=Set() # set of uuids of engines that have died
316 327 unassigned=Set() # set of task msg_ids not yet assigned a destination
317 328 incoming_registrations=Dict()
318 329 registration_timeout=Int()
319 330 _idcounter=Int(0)
320 331
321 332 # objects from constructor:
322 333 loop=Instance(ioloop.IOLoop)
323 334 query=Instance(ZMQStream)
324 335 monitor=Instance(ZMQStream)
325 336 notifier=Instance(ZMQStream)
326 337 resubmit=Instance(ZMQStream)
327 338 heartmonitor=Instance(HeartMonitor)
328 339 db=Instance(object)
329 340 client_info=Dict()
330 341 engine_info=Dict()
331 342
332 343
333 344 def __init__(self, **kwargs):
334 345 """
335 346 # universal:
336 347 loop: IOLoop for creating future connections
337 348 session: streamsession for sending serialized data
338 349 # engine:
339 350 queue: ZMQStream for monitoring queue messages
340 351 query: ZMQStream for engine+client registration and client requests
341 352 heartbeat: HeartMonitor object for tracking engines
342 353 # extra:
343 354 db: ZMQStream for db connection (NotImplemented)
344 355 engine_info: zmq address/protocol dict for engine connections
345 356 client_info: zmq address/protocol dict for client connections
346 357 """
347 358
348 359 super(Hub, self).__init__(**kwargs)
349 360 self.registration_timeout = max(5000, 2*self.heartmonitor.period)
350 361
351 362 # validate connection dicts:
352 363 for k,v in self.client_info.iteritems():
353 364 if k == 'task':
354 365 util.validate_url_container(v[1])
355 366 else:
356 367 util.validate_url_container(v)
357 368 # util.validate_url_container(self.client_info)
358 369 util.validate_url_container(self.engine_info)
359 370
360 371 # register our callbacks
361 372 self.query.on_recv(self.dispatch_query)
362 373 self.monitor.on_recv(self.dispatch_monitor_traffic)
363 374
364 375 self.heartmonitor.add_heart_failure_handler(self.handle_heart_failure)
365 376 self.heartmonitor.add_new_heart_handler(self.handle_new_heart)
366 377
367 378 self.monitor_handlers = { 'in' : self.save_queue_request,
368 379 'out': self.save_queue_result,
369 380 'intask': self.save_task_request,
370 381 'outtask': self.save_task_result,
371 382 'tracktask': self.save_task_destination,
372 383 'incontrol': _passer,
373 384 'outcontrol': _passer,
374 385 'iopub': self.save_iopub_message,
375 386 }
376 387
377 388 self.query_handlers = {'queue_request': self.queue_status,
378 389 'result_request': self.get_results,
379 390 'history_request': self.get_history,
380 391 'db_request': self.db_query,
381 392 'purge_request': self.purge_results,
382 393 'load_request': self.check_load,
383 394 'resubmit_request': self.resubmit_task,
384 395 'shutdown_request': self.shutdown_request,
385 396 'registration_request' : self.register_engine,
386 397 'unregistration_request' : self.unregister_engine,
387 398 'connection_request': self.connection_request,
388 399 }
389 400
390 401 # ignore resubmit replies
391 402 self.resubmit.on_recv(lambda msg: None, copy=False)
392 403
393 404 self.log.info("hub::created hub")
394 405
395 406 @property
396 407 def _next_id(self):
397 408 """generate a new ID.
398 409
399 410 No longer reuse old ids, just count from 0."""
400 411 newid = self._idcounter
401 412 self._idcounter += 1
402 413 return newid
403 414 # newid = 0
404 415 # incoming = [id[0] for id in self.incoming_registrations.itervalues()]
405 416 # # print newid, self.ids, self.incoming_registrations
406 417 # while newid in self.ids or newid in incoming:
407 418 # newid += 1
408 419 # return newid
409 420
410 421 #-----------------------------------------------------------------------------
411 422 # message validation
412 423 #-----------------------------------------------------------------------------
413 424
414 425 def _validate_targets(self, targets):
415 426 """turn any valid targets argument into a list of integer ids"""
416 427 if targets is None:
417 428 # default to all
418 429 targets = self.ids
419 430
420 431 if isinstance(targets, (int,str,unicode)):
421 432 # only one target specified
422 433 targets = [targets]
423 434 _targets = []
424 435 for t in targets:
425 436 # map raw identities to ids
426 437 if isinstance(t, (str,unicode)):
427 438 t = self.by_ident.get(t, t)
428 439 _targets.append(t)
429 440 targets = _targets
430 441 bad_targets = [ t for t in targets if t not in self.ids ]
431 442 if bad_targets:
432 443 raise IndexError("No Such Engine: %r"%bad_targets)
433 444 if not targets:
434 445 raise IndexError("No Engines Registered")
435 446 return targets
436 447
437 448 #-----------------------------------------------------------------------------
438 449 # dispatch methods (1 per stream)
439 450 #-----------------------------------------------------------------------------
440 451
441 452 # def dispatch_registration_request(self, msg):
442 453 # """"""
443 454 # self.log.debug("registration::dispatch_register_request(%s)"%msg)
444 455 # idents,msg = self.session.feed_identities(msg)
445 456 # if not idents:
446 457 # self.log.error("Bad Query Message: %s"%msg, exc_info=True)
447 458 # return
448 459 # try:
449 460 # msg = self.session.unpack_message(msg,content=True)
450 461 # except:
451 462 # self.log.error("registration::got bad registration message: %s"%msg, exc_info=True)
452 463 # return
453 464 #
454 465 # msg_type = msg['msg_type']
455 466 # content = msg['content']
456 467 #
457 468 # handler = self.query_handlers.get(msg_type, None)
458 469 # if handler is None:
459 470 # self.log.error("registration::got bad registration message: %s"%msg)
460 471 # else:
461 472 # handler(idents, msg)
462 473
463 474 def dispatch_monitor_traffic(self, msg):
464 475 """all ME and Task queue messages come through here, as well as
465 476 IOPub traffic."""
466 477 self.log.debug("monitor traffic: %r"%msg[:2])
467 478 switch = msg[0]
468 479 idents, msg = self.session.feed_identities(msg[1:])
469 480 if not idents:
470 481 self.log.error("Bad Monitor Message: %r"%msg)
471 482 return
472 483 handler = self.monitor_handlers.get(switch, None)
473 484 if handler is not None:
474 485 handler(idents, msg)
475 486 else:
476 487 self.log.error("Invalid monitor topic: %r"%switch)
477 488
478 489
479 490 def dispatch_query(self, msg):
480 491 """Route registration requests and queries from clients."""
481 492 idents, msg = self.session.feed_identities(msg)
482 493 if not idents:
483 494 self.log.error("Bad Query Message: %r"%msg)
484 495 return
485 496 client_id = idents[0]
486 497 try:
487 498 msg = self.session.unpack_message(msg, content=True)
488 499 except:
489 500 content = error.wrap_exception()
490 501 self.log.error("Bad Query Message: %r"%msg, exc_info=True)
491 502 self.session.send(self.query, "hub_error", ident=client_id,
492 503 content=content)
493 504 return
494 505
495 506 # print client_id, header, parent, content
496 507 #switch on message type:
497 508 msg_type = msg['msg_type']
498 509 self.log.info("client::client %r requested %r"%(client_id, msg_type))
499 510 handler = self.query_handlers.get(msg_type, None)
500 511 try:
501 512 assert handler is not None, "Bad Message Type: %r"%msg_type
502 513 except:
503 514 content = error.wrap_exception()
504 515 self.log.error("Bad Message Type: %r"%msg_type, exc_info=True)
505 516 self.session.send(self.query, "hub_error", ident=client_id,
506 517 content=content)
507 518 return
508 519
509 520 else:
510 521 handler(idents, msg)
511 522
512 523 def dispatch_db(self, msg):
513 524 """"""
514 525 raise NotImplementedError
515 526
516 527 #---------------------------------------------------------------------------
517 528 # handler methods (1 per event)
518 529 #---------------------------------------------------------------------------
519 530
520 531 #----------------------- Heartbeat --------------------------------------
521 532
522 533 def handle_new_heart(self, heart):
523 534 """handler to attach to heartbeater.
524 535 Called when a new heart starts to beat.
525 536 Triggers completion of registration."""
526 537 self.log.debug("heartbeat::handle_new_heart(%r)"%heart)
527 538 if heart not in self.incoming_registrations:
528 539 self.log.info("heartbeat::ignoring new heart: %r"%heart)
529 540 else:
530 541 self.finish_registration(heart)
531 542
532 543
533 544 def handle_heart_failure(self, heart):
534 545 """handler to attach to heartbeater.
535 546 called when a previously registered heart fails to respond to beat request.
536 547 triggers unregistration"""
537 548 self.log.debug("heartbeat::handle_heart_failure(%r)"%heart)
538 549 eid = self.hearts.get(heart, None)
539 550 if eid is None:
540 551 self.log.info("heartbeat::ignoring heart failure %r"%heart)
541 552 else:
542 553 queue = self.engines[eid].queue
543 554 self.unregister_engine(heart, dict(content=dict(id=eid, queue=queue)))
544 555
545 556 #----------------------- MUX Queue Traffic ------------------------------
546 557
547 558 def save_queue_request(self, idents, msg):
548 559 if len(idents) < 2:
549 560 self.log.error("invalid identity prefix: %s"%idents)
550 561 return
551 562 queue_id, client_id = idents[:2]
552 563 try:
553 564 msg = self.session.unpack_message(msg, content=False)
554 565 except:
555 566 self.log.error("queue::client %r sent invalid message to %r: %s"%(client_id, queue_id, msg), exc_info=True)
556 567 return
557 568
558 569 eid = self.by_ident.get(queue_id, None)
559 570 if eid is None:
560 571 self.log.error("queue::target %r not registered"%queue_id)
561 572 self.log.debug("queue:: valid are: %s"%(self.by_ident.keys()))
562 573 return
563 574
564 575 header = msg['header']
565 576 msg_id = header['msg_id']
566 577 record = init_record(msg)
567 578 record['engine_uuid'] = queue_id
568 579 record['client_uuid'] = client_id
569 580 record['queue'] = 'mux'
570 581
571 582 try:
572 583 # it's possible iopub arrived first:
573 584 existing = self.db.get_record(msg_id)
574 585 for key,evalue in existing.iteritems():
575 586 rvalue = record.get(key, None)
576 587 if evalue and rvalue and evalue != rvalue:
577 588 self.log.warn("conflicting initial state for record: %r:%r <%r> %r"%(msg_id, rvalue, key, evalue))
578 589 elif evalue and not rvalue:
579 590 record[key] = evalue
580 591 self.db.update_record(msg_id, record)
581 592 except KeyError:
582 593 self.db.add_record(msg_id, record)
583 594
584 595 self.pending.add(msg_id)
585 596 self.queues[eid].append(msg_id)
586 597
587 598 def save_queue_result(self, idents, msg):
588 599 if len(idents) < 2:
589 600 self.log.error("invalid identity prefix: %s"%idents)
590 601 return
591 602
592 603 client_id, queue_id = idents[:2]
593 604 try:
594 605 msg = self.session.unpack_message(msg, content=False)
595 606 except:
596 607 self.log.error("queue::engine %r sent invalid message to %r: %s"%(
597 608 queue_id,client_id, msg), exc_info=True)
598 609 return
599 610
600 611 eid = self.by_ident.get(queue_id, None)
601 612 if eid is None:
602 613 self.log.error("queue::unknown engine %r is sending a reply: "%queue_id)
603 614 # self.log.debug("queue:: %s"%msg[2:])
604 615 return
605 616
606 617 parent = msg['parent_header']
607 618 if not parent:
608 619 return
609 620 msg_id = parent['msg_id']
610 621 if msg_id in self.pending:
611 622 self.pending.remove(msg_id)
612 623 self.all_completed.add(msg_id)
613 624 self.queues[eid].remove(msg_id)
614 625 self.completed[eid].append(msg_id)
615 626 elif msg_id not in self.all_completed:
616 627 # it could be a result from a dead engine that died before delivering the
617 628 # result
618 629 self.log.warn("queue:: unknown msg finished %s"%msg_id)
619 630 return
620 631 # update record anyway, because the unregistration could have been premature
621 632 rheader = msg['header']
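                # header timestamps arrive as ISO8601 strings; parse them back to datetimes for the DB record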
622 633 completed = datetime.strptime(rheader['date'], util.ISO8601)
623 634 started = rheader.get('started', None)
624 635 if started is not None:
625 636 started = datetime.strptime(started, util.ISO8601)
626 637 result = {
627 638 'result_header' : rheader,
628 639 'result_content': msg['content'],
629 640 'started' : started,
630 641 'completed' : completed
631 642 }
632 643
633 644 result['result_buffers'] = msg['buffers']
634 645 try:
635 646 self.db.update_record(msg_id, result)
636 647 except Exception:
637 648 self.log.error("DB Error updating record %r"%msg_id, exc_info=True)
638 649
639 650
640 651 #--------------------- Task Queue Traffic ------------------------------
641 652
642 653 def save_task_request(self, idents, msg):
643 654 """Save the submission of a task."""
644 655 client_id = idents[0]
645 656
646 657 try:
647 658 msg = self.session.unpack_message(msg, content=False)
648 659 except:
649 660 self.log.error("task::client %r sent invalid task message: %s"%(
650 661 client_id, msg), exc_info=True)
651 662 return
652 663 record = init_record(msg)
653 664
654 665 record['client_uuid'] = client_id
655 666 record['queue'] = 'task'
656 667 header = msg['header']
657 668 msg_id = header['msg_id']
658 669 self.pending.add(msg_id)
659 670 self.unassigned.add(msg_id)
660 671 try:
661 672             # it's possible iopub arrived first:
662 673 existing = self.db.get_record(msg_id)
663 674 if existing['resubmitted']:
664 675 for key in ('submitted', 'client_uuid', 'buffers'):
665 676 # don't clobber these keys on resubmit
666 677 # submitted and client_uuid should be different
667 678 # and buffers might be big, and shouldn't have changed
668 679 record.pop(key)
669 680                 # still check content and header, which should not change
670 681                 # but are not as expensive to compare as buffers
671 682
672 683 for key,evalue in existing.iteritems():
673 684 if key.endswith('buffers'):
674 685 # don't compare buffers
675 686 continue
676 687 rvalue = record.get(key, None)
677 688 if evalue and rvalue and evalue != rvalue:
678 689 self.log.warn("conflicting initial state for record: %r:%r <%r> %r"%(msg_id, rvalue, key, evalue))
679 690 elif evalue and not rvalue:
680 691 record[key] = evalue
681 692 self.db.update_record(msg_id, record)
682 693 except KeyError:
683 694 self.db.add_record(msg_id, record)
684 695 except Exception:
685 696 self.log.error("DB Error saving task request %r"%msg_id, exc_info=True)
686 697
687 698 def save_task_result(self, idents, msg):
688 699 """save the result of a completed task."""
689 700 client_id = idents[0]
690 701 try:
691 702 msg = self.session.unpack_message(msg, content=False)
692 703 except:
693 704             self.log.error("task::invalid task result message sent to %r: %s"%(
694 705                 client_id, msg), exc_info=True)
695 706             return
697 708
698 709 parent = msg['parent_header']
699 710 if not parent:
700 711 # print msg
701 712 self.log.warn("Task %r had no parent!"%msg)
702 713 return
703 714 msg_id = parent['msg_id']
704 715 if msg_id in self.unassigned:
705 716 self.unassigned.remove(msg_id)
706 717
707 718 header = msg['header']
708 719 engine_uuid = header.get('engine', None)
709 720 eid = self.by_ident.get(engine_uuid, None)
710 721
711 722 if msg_id in self.pending:
712 723 self.pending.remove(msg_id)
713 724 self.all_completed.add(msg_id)
714 725 if eid is not None:
715 726 self.completed[eid].append(msg_id)
716 727 if msg_id in self.tasks[eid]:
717 728 self.tasks[eid].remove(msg_id)
718 729 completed = datetime.strptime(header['date'], util.ISO8601)
719 730 started = header.get('started', None)
720 731 if started is not None:
721 732 started = datetime.strptime(started, util.ISO8601)
722 733 result = {
723 734 'result_header' : header,
724 735 'result_content': msg['content'],
725 736 'started' : started,
726 737 'completed' : completed,
727 738 'engine_uuid': engine_uuid
728 739 }
729 740
730 741 result['result_buffers'] = msg['buffers']
731 742 try:
732 743 self.db.update_record(msg_id, result)
733 744 except Exception:
734 745 self.log.error("DB Error saving task request %r"%msg_id, exc_info=True)
735 746
736 747 else:
737 748 self.log.debug("task::unknown task %s finished"%msg_id)
738 749
739 750 def save_task_destination(self, idents, msg):
740 751 try:
741 752 msg = self.session.unpack_message(msg, content=True)
742 753 except:
743 754 self.log.error("task::invalid task tracking message", exc_info=True)
744 755 return
745 756 content = msg['content']
746 757 # print (content)
747 758 msg_id = content['msg_id']
748 759 engine_uuid = content['engine_id']
749 760 eid = self.by_ident[engine_uuid]
750 761
751 762 self.log.info("task::task %s arrived on %s"%(msg_id, eid))
752 763 if msg_id in self.unassigned:
753 764 self.unassigned.remove(msg_id)
754 765 # else:
755 766 # self.log.debug("task::task %s not listed as MIA?!"%(msg_id))
756 767
757 768 self.tasks[eid].append(msg_id)
758 769 # self.pending[msg_id][1].update(received=datetime.now(),engine=(eid,engine_uuid))
759 770 try:
760 771 self.db.update_record(msg_id, dict(engine_uuid=engine_uuid))
761 772 except Exception:
762 773 self.log.error("DB Error saving task destination %r"%msg_id, exc_info=True)
763 774
764 775
765 776 def mia_task_request(self, idents, msg):
766 777 raise NotImplementedError
767 778 client_id = idents[0]
768 779 # content = dict(mia=self.mia,status='ok')
769 780 # self.session.send('mia_reply', content=content, idents=client_id)
770 781
771 782
772 783 #--------------------- IOPub Traffic ------------------------------
773 784
774 785 def save_iopub_message(self, topics, msg):
775 786 """save an iopub message into the db"""
776 787 # print (topics)
777 788 try:
778 789 msg = self.session.unpack_message(msg, content=True)
779 790 except:
780 791 self.log.error("iopub::invalid IOPub message", exc_info=True)
781 792 return
782 793
783 794 parent = msg['parent_header']
784 795 if not parent:
785 796 self.log.error("iopub::invalid IOPub message: %s"%msg)
786 797 return
787 798 msg_id = parent['msg_id']
788 799 msg_type = msg['msg_type']
789 800 content = msg['content']
790 801
791 802 # ensure msg_id is in db
792 803 try:
793 804 rec = self.db.get_record(msg_id)
794 805 except KeyError:
795 806 rec = empty_record()
796 807 rec['msg_id'] = msg_id
797 808 self.db.add_record(msg_id, rec)
798 809 # stream
799 810 d = {}
800 811 if msg_type == 'stream':
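                    # append this chunk to any stream output already stored for this msg_id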
801 812 name = content['name']
802 813 s = rec[name] or ''
803 814 d[name] = s + content['data']
804 815
805 816 elif msg_type == 'pyerr':
806 817 d['pyerr'] = content
807 818 elif msg_type == 'pyin':
808 819 d['pyin'] = content['code']
809 820 else:
810 821 d[msg_type] = content.get('data', '')
811 822
812 823 try:
813 824 self.db.update_record(msg_id, d)
814 825 except Exception:
815 826 self.log.error("DB Error saving iopub message %r"%msg_id, exc_info=True)
816 827
817 828
818 829
819 830 #-------------------------------------------------------------------------
820 831 # Registration requests
821 832 #-------------------------------------------------------------------------
822 833
823 834 def connection_request(self, client_id, msg):
824 835 """Reply with connection addresses for clients."""
825 836 self.log.info("client::client %s connected"%client_id)
826 837 content = dict(status='ok')
827 838 content.update(self.client_info)
828 839 jsonable = {}
829 840 for k,v in self.keytable.iteritems():
830 841 if v not in self.dead_engines:
831 842 jsonable[str(k)] = v
832 843 content['engines'] = jsonable
833 844 self.session.send(self.query, 'connection_reply', content, parent=msg, ident=client_id)
834 845
835 846 def register_engine(self, reg, msg):
836 847 """Register a new engine."""
837 848 content = msg['content']
838 849 try:
839 850 queue = content['queue']
840 851 except KeyError:
841 852 self.log.error("registration::queue not specified", exc_info=True)
842 853 return
843 854 heart = content.get('heartbeat', None)
844 855         # register a new engine, and create the socket(s) necessary
845 856 eid = self._next_id
846 857 # print (eid, queue, reg, heart)
847 858
848 859 self.log.debug("registration::register_engine(%i, %r, %r, %r)"%(eid, queue, reg, heart))
849 860
850 861 content = dict(id=eid,status='ok')
851 862 content.update(self.engine_info)
852 863 # check if requesting available IDs:
853 864 if queue in self.by_ident:
854 865 try:
855 866 raise KeyError("queue_id %r in use"%queue)
856 867 except:
857 868 content = error.wrap_exception()
858 869 self.log.error("queue_id %r in use"%queue, exc_info=True)
859 870 elif heart in self.hearts: # need to check unique hearts?
860 871 try:
861 872 raise KeyError("heart_id %r in use"%heart)
862 873 except:
863 874 self.log.error("heart_id %r in use"%heart, exc_info=True)
864 875 content = error.wrap_exception()
865 876 else:
866 877 for h, pack in self.incoming_registrations.iteritems():
867 878 if heart == h:
868 879 try:
869 880 raise KeyError("heart_id %r in use"%heart)
870 881 except:
871 882 self.log.error("heart_id %r in use"%heart, exc_info=True)
872 883 content = error.wrap_exception()
873 884 break
874 885 elif queue == pack[1]:
875 886 try:
876 887 raise KeyError("queue_id %r in use"%queue)
877 888 except:
878 889 self.log.error("queue_id %r in use"%queue, exc_info=True)
879 890 content = error.wrap_exception()
880 891 break
881 892
882 893 msg = self.session.send(self.query, "registration_reply",
883 894 content=content,
884 895 ident=reg)
885 896
886 897 if content['status'] == 'ok':
887 898 if heart in self.heartmonitor.hearts:
888 899 # already beating
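                # incoming_registrations maps heart -> (eid, queue ident, registration ident, purge timer or None)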
889 900 self.incoming_registrations[heart] = (eid,queue,reg[0],None)
890 901 self.finish_registration(heart)
891 902 else:
892 903 purge = lambda : self._purge_stalled_registration(heart)
893 904 dc = ioloop.DelayedCallback(purge, self.registration_timeout, self.loop)
894 905 dc.start()
895 906 self.incoming_registrations[heart] = (eid,queue,reg[0],dc)
896 907 else:
897 908 self.log.error("registration::registration %i failed: %s"%(eid, content['evalue']))
898 909 return eid
899 910
900 911 def unregister_engine(self, ident, msg):
901 912 """Unregister an engine that explicitly requested to leave."""
902 913 try:
903 914 eid = msg['content']['id']
904 915 except:
905 916 self.log.error("registration::bad engine id for unregistration: %s"%ident, exc_info=True)
906 917 return
907 918 self.log.info("registration::unregister_engine(%s)"%eid)
908 919 # print (eid)
909 920 uuid = self.keytable[eid]
910 921 content=dict(id=eid, queue=uuid)
911 922 self.dead_engines.add(uuid)
912 923 # self.ids.remove(eid)
913 924 # uuid = self.keytable.pop(eid)
914 925 #
915 926 # ec = self.engines.pop(eid)
916 927 # self.hearts.pop(ec.heartbeat)
917 928 # self.by_ident.pop(ec.queue)
918 929 # self.completed.pop(eid)
919 930 handleit = lambda : self._handle_stranded_msgs(eid, uuid)
920 931 dc = ioloop.DelayedCallback(handleit, self.registration_timeout, self.loop)
921 932 dc.start()
922 933 ############## TODO: HANDLE IT ################
923 934
924 935 if self.notifier:
925 936 self.session.send(self.notifier, "unregistration_notification", content=content)
926 937
927 938 def _handle_stranded_msgs(self, eid, uuid):
928 939 """Handle messages known to be on an engine when the engine unregisters.
929 940
930 941 It is possible that this will fire prematurely - that is, an engine will
931 942 go down after completing a result, and the client will be notified
932 943 that the result failed and later receive the actual result.
933 944 """
934 945
935 946 outstanding = self.queues[eid]
936 947
937 948 for msg_id in outstanding:
938 949 self.pending.remove(msg_id)
939 950 self.all_completed.add(msg_id)
940 951 try:
941 952 raise error.EngineError("Engine %r died while running task %r"%(eid, msg_id))
942 953 except:
943 954 content = error.wrap_exception()
944 955 # build a fake header:
945 956 header = {}
946 957 header['engine'] = uuid
947 958 header['date'] = datetime.now()
948 959 rec = dict(result_content=content, result_header=header, result_buffers=[])
949 960 rec['completed'] = header['date']
950 961 rec['engine_uuid'] = uuid
951 962 try:
952 963 self.db.update_record(msg_id, rec)
953 964 except Exception:
954 965 self.log.error("DB Error handling stranded msg %r"%msg_id, exc_info=True)
955 966
956 967
957 968 def finish_registration(self, heart):
958 969 """Second half of engine registration, called after our HeartMonitor
959 970 has received a beat from the Engine's Heart."""
960 971 try:
961 972 (eid,queue,reg,purge) = self.incoming_registrations.pop(heart)
962 973 except KeyError:
963 974             self.log.error("registration::tried to finish nonexistent registration", exc_info=True)
964 975 return
965 976 self.log.info("registration::finished registering engine %i:%r"%(eid,queue))
966 977 if purge is not None:
967 978 purge.stop()
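            # the control channel currently shares the engine's queue identity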
968 979 control = queue
969 980 self.ids.add(eid)
970 981 self.keytable[eid] = queue
971 982 self.engines[eid] = EngineConnector(id=eid, queue=queue, registration=reg,
972 983 control=control, heartbeat=heart)
973 984 self.by_ident[queue] = eid
974 985 self.queues[eid] = list()
975 986 self.tasks[eid] = list()
976 987 self.completed[eid] = list()
977 988 self.hearts[heart] = eid
978 989 content = dict(id=eid, queue=self.engines[eid].queue)
979 990 if self.notifier:
980 991 self.session.send(self.notifier, "registration_notification", content=content)
981 992 self.log.info("engine::Engine Connected: %i"%eid)
982 993
983 994 def _purge_stalled_registration(self, heart):
984 995 if heart in self.incoming_registrations:
985 996 eid = self.incoming_registrations.pop(heart)[0]
986 997 self.log.info("registration::purging stalled registration: %i"%eid)
987 998 else:
988 999 pass
989 1000
990 1001 #-------------------------------------------------------------------------
991 1002 # Client Requests
992 1003 #-------------------------------------------------------------------------
993 1004
994 1005 def shutdown_request(self, client_id, msg):
995 1006 """handle shutdown request."""
996 1007 self.session.send(self.query, 'shutdown_reply', content={'status': 'ok'}, ident=client_id)
997 1008 # also notify other clients of shutdown
998 1009 self.session.send(self.notifier, 'shutdown_notice', content={'status': 'ok'})
999 1010 dc = ioloop.DelayedCallback(lambda : self._shutdown(), 1000, self.loop)
1000 1011 dc.start()
1001 1012
1002 1013 def _shutdown(self):
1003 1014 self.log.info("hub::hub shutting down.")
1004 1015 time.sleep(0.1)
1005 1016 sys.exit(0)
1006 1017
1007 1018
1008 1019 def check_load(self, client_id, msg):
1009 1020 content = msg['content']
1010 1021 try:
1011 1022 targets = content['targets']
1012 1023 targets = self._validate_targets(targets)
1013 1024 except:
1014 1025 content = error.wrap_exception()
1015 1026 self.session.send(self.query, "hub_error",
1016 1027 content=content, ident=client_id)
1017 1028 return
1018 1029
1019 1030 content = dict(status='ok')
1020 1031 # loads = {}
1021 1032 for t in targets:
1022 1033 content[bytes(t)] = len(self.queues[t])+len(self.tasks[t])
1023 1034 self.session.send(self.query, "load_reply", content=content, ident=client_id)
1024 1035
1025 1036
1026 1037 def queue_status(self, client_id, msg):
1027 1038 """Return the Queue status of one or more targets.
1028 1039 if verbose: return the msg_ids
1029 1040 else: return len of each type.
1030 1041 keys: queue (pending MUX jobs)
1031 1042 tasks (pending Task jobs)
1032 1043 completed (finished jobs from both queues)"""
1033 1044 content = msg['content']
1034 1045 targets = content['targets']
1035 1046 try:
1036 1047 targets = self._validate_targets(targets)
1037 1048 except:
1038 1049 content = error.wrap_exception()
1039 1050 self.session.send(self.query, "hub_error",
1040 1051 content=content, ident=client_id)
1041 1052 return
1042 1053 verbose = content.get('verbose', False)
1043 1054 content = dict(status='ok')
1044 1055 for t in targets:
1045 1056 queue = self.queues[t]
1046 1057 completed = self.completed[t]
1047 1058 tasks = self.tasks[t]
1048 1059 if not verbose:
1049 1060 queue = len(queue)
1050 1061 completed = len(completed)
1051 1062 tasks = len(tasks)
1052 1063 content[bytes(t)] = {'queue': queue, 'completed': completed , 'tasks': tasks}
1053 1064 content['unassigned'] = list(self.unassigned) if verbose else len(self.unassigned)
1054 1065
1055 1066 self.session.send(self.query, "queue_reply", content=content, ident=client_id)
1056 1067
1057 1068 def purge_results(self, client_id, msg):
1058 1069 """Purge results from memory. This method is more valuable before we move
1059 1070 to a DB based message storage mechanism."""
1060 1071 content = msg['content']
1061 1072 msg_ids = content.get('msg_ids', [])
1062 1073 reply = dict(status='ok')
1063 1074 if msg_ids == 'all':
1064 1075 try:
1065 1076 self.db.drop_matching_records(dict(completed={'$ne':None}))
1066 1077 except Exception:
1067 1078 reply = error.wrap_exception()
1068 1079 else:
1069 1080 pending = filter(lambda m: m in self.pending, msg_ids)
1070 1081 if pending:
1071 1082 try:
1072 1083 raise IndexError("msg pending: %r"%pending[0])
1073 1084 except:
1074 1085 reply = error.wrap_exception()
1075 1086 else:
1076 1087 try:
1077 1088 self.db.drop_matching_records(dict(msg_id={'$in':msg_ids}))
1078 1089 except Exception:
1079 1090 reply = error.wrap_exception()
1080 1091
1081 1092 if reply['status'] == 'ok':
1082 1093 eids = content.get('engine_ids', [])
1083 1094 for eid in eids:
1084 1095 if eid not in self.engines:
1085 1096 try:
1086 1097 raise IndexError("No such engine: %i"%eid)
1087 1098 except:
1088 1099 reply = error.wrap_exception()
1089 1100 break
1090 1101 msg_ids = self.completed.pop(eid)
1091 1102 uid = self.engines[eid].queue
1092 1103 try:
1093 1104 self.db.drop_matching_records(dict(engine_uuid=uid, completed={'$ne':None}))
1094 1105 except Exception:
1095 1106 reply = error.wrap_exception()
1096 1107 break
1097 1108
1098 1109 self.session.send(self.query, 'purge_reply', content=reply, ident=client_id)
1099 1110
1100 1111 def resubmit_task(self, client_id, msg):
1101 1112 """Resubmit one or more tasks."""
1102 1113 def finish(reply):
1103 1114 self.session.send(self.query, 'resubmit_reply', content=reply, ident=client_id)
1104 1115
1105 1116 content = msg['content']
1106 1117 msg_ids = content['msg_ids']
1107 1118 reply = dict(status='ok')
1108 1119 try:
1109 1120 records = self.db.find_records({'msg_id' : {'$in' : msg_ids}}, keys=[
1110 1121 'header', 'content', 'buffers'])
1111 1122 except Exception:
1112 1123 self.log.error('db::db error finding tasks to resubmit', exc_info=True)
1113 1124 return finish(error.wrap_exception())
1114 1125
1115 1126 # validate msg_ids
1116 1127 found_ids = [ rec['msg_id'] for rec in records ]
1117 1128 invalid_ids = filter(lambda m: m in self.pending, found_ids)
1118 1129 if len(records) > len(msg_ids):
1119 1130 try:
1120 1131                 raise RuntimeError("DB appears to be in an inconsistent state. "
1121 1132                     "More matching records were found than should exist")
1122 1133 except Exception:
1123 1134 return finish(error.wrap_exception())
1124 1135 elif len(records) < len(msg_ids):
1125 1136 missing = [ m for m in msg_ids if m not in found_ids ]
1126 1137 try:
1127 1138 raise KeyError("No such msg(s): %s"%missing)
1128 1139 except KeyError:
1129 1140 return finish(error.wrap_exception())
1130 1141 elif invalid_ids:
1131 1142 msg_id = invalid_ids[0]
1132 1143 try:
1133 1144 raise ValueError("Task %r appears to be inflight"%(msg_id))
1134 1145 except Exception:
1135 1146 return finish(error.wrap_exception())
1136 1147
1137 1148 # clear the existing records
1138 1149 rec = empty_record()
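        # drop the fields that must keep their originally submitted values;
        # the DB update below only resets the remaining ones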
1139 1150 map(rec.pop, ['msg_id', 'header', 'content', 'buffers', 'submitted'])
1140 1151 rec['resubmitted'] = datetime.now()
1141 1152 rec['queue'] = 'task'
1142 1153 rec['client_uuid'] = client_id[0]
1143 1154 try:
1144 1155 for msg_id in msg_ids:
1145 1156 self.all_completed.discard(msg_id)
1146 1157 self.db.update_record(msg_id, rec)
1147 1158 except Exception:
1148 1159             self.log.error('db::db error updating record', exc_info=True)
1149 1160 reply = error.wrap_exception()
1150 1161 else:
1151 1162 # send the messages
1152 1163 for rec in records:
1153 1164 header = rec['header']
1154 1165 msg = self.session.msg(header['msg_type'])
1155 1166 msg['content'] = rec['content']
1156 1167 msg['header'] = header
1157 1168 msg['msg_id'] = rec['msg_id']
1158 1169 self.session.send(self.resubmit, msg, buffers=rec['buffers'])
1159 1170
1160 1171 finish(dict(status='ok'))
1161 1172
1162 1173
1163 1174 def _extract_record(self, rec):
1164 1175 """decompose a TaskRecord dict into subsection of reply for get_result"""
1165 1176 io_dict = {}
1166 1177 for key in 'pyin pyout pyerr stdout stderr'.split():
1167 1178 io_dict[key] = rec[key]
1168 1179 content = { 'result_content': rec['result_content'],
1169 1180 'header': rec['header'],
1170 1181 'result_header' : rec['result_header'],
1171 1182 'io' : io_dict,
1172 1183 }
1173 1184 if rec['result_buffers']:
1174 1185 buffers = map(str, rec['result_buffers'])
1175 1186 else:
1176 1187 buffers = []
1177 1188
1178 1189 return content, buffers
1179 1190
1180 1191 def get_results(self, client_id, msg):
1181 1192 """Get the result of 1 or more messages."""
1182 1193 content = msg['content']
1183 1194 msg_ids = sorted(set(content['msg_ids']))
1184 1195 statusonly = content.get('status_only', False)
1185 1196 pending = []
1186 1197 completed = []
1187 1198 content = dict(status='ok')
1188 1199 content['pending'] = pending
1189 1200 content['completed'] = completed
1190 1201 buffers = []
1191 1202 if not statusonly:
1192 1203 try:
1193 1204 matches = self.db.find_records(dict(msg_id={'$in':msg_ids}))
1194 1205 # turn match list into dict, for faster lookup
1195 1206 records = {}
1196 1207 for rec in matches:
1197 1208 records[rec['msg_id']] = rec
1198 1209 except Exception:
1199 1210 content = error.wrap_exception()
1200 1211 self.session.send(self.query, "result_reply", content=content,
1201 1212 parent=msg, ident=client_id)
1202 1213 return
1203 1214 else:
1204 1215 records = {}
1205 1216 for msg_id in msg_ids:
1206 1217 if msg_id in self.pending:
1207 1218 pending.append(msg_id)
1208 1219 elif msg_id in self.all_completed:
1209 1220 completed.append(msg_id)
1210 1221 if not statusonly:
1211 1222 c,bufs = self._extract_record(records[msg_id])
1212 1223 content[msg_id] = c
1213 1224 buffers.extend(bufs)
1214 1225 elif msg_id in records:
1215 1226                 if records[msg_id]['completed']:
1216 1227 completed.append(msg_id)
1217 1228 c,bufs = self._extract_record(records[msg_id])
1218 1229 content[msg_id] = c
1219 1230 buffers.extend(bufs)
1220 1231 else:
1221 1232 pending.append(msg_id)
1222 1233 else:
1223 1234 try:
1224 1235 raise KeyError('No such message: '+msg_id)
1225 1236 except:
1226 1237 content = error.wrap_exception()
1227 1238 break
1228 1239 self.session.send(self.query, "result_reply", content=content,
1229 1240 parent=msg, ident=client_id,
1230 1241 buffers=buffers)
1231 1242
1232 1243 def get_history(self, client_id, msg):
1233 1244 """Get a list of all msg_ids in our DB records"""
1234 1245 try:
1235 1246 msg_ids = self.db.get_history()
1236 1247 except Exception as e:
1237 1248 content = error.wrap_exception()
1238 1249 else:
1239 1250 content = dict(status='ok', history=msg_ids)
1240 1251
1241 1252 self.session.send(self.query, "history_reply", content=content,
1242 1253 parent=msg, ident=client_id)
1243 1254
1244 1255 def db_query(self, client_id, msg):
1245 1256 """Perform a raw query on the task record database."""
1246 1257 content = msg['content']
1247 1258 query = content.get('query', {})
1248 1259 keys = content.get('keys', None)
1249 1260 query = util.extract_dates(query)
1250 1261 buffers = []
1251 1262 empty = list()
1252 1263
1253 1264 try:
1254 1265 records = self.db.find_records(query, keys)
1255 1266 except Exception as e:
1256 1267 content = error.wrap_exception()
1257 1268 else:
1258 1269 # extract buffers from reply content:
1259 1270 if keys is not None:
1260 1271 buffer_lens = [] if 'buffers' in keys else None
1261 1272 result_buffer_lens = [] if 'result_buffers' in keys else None
1262 1273 else:
1263 1274 buffer_lens = []
1264 1275 result_buffer_lens = []
1265 1276
1266 1277 for rec in records:
1267 1278 # buffers may be None, so double check
1268 1279 if buffer_lens is not None:
1269 1280 b = rec.pop('buffers', empty) or empty
1270 1281 buffer_lens.append(len(b))
1271 1282 buffers.extend(b)
1272 1283 if result_buffer_lens is not None:
1273 1284 rb = rec.pop('result_buffers', empty) or empty
1274 1285 result_buffer_lens.append(len(rb))
1275 1286 buffers.extend(rb)
1276 1287 content = dict(status='ok', records=records, buffer_lens=buffer_lens,
1277 1288 result_buffer_lens=result_buffer_lens)
1278 1289
1279 1290 self.session.send(self.query, "db_reply", content=content,
1280 1291 parent=msg, ident=client_id,
1281 1292 buffers=buffers)
1282 1293
@@ -1,101 +1,112
1 1 """A TaskRecord backend using mongodb"""
2 2 #-----------------------------------------------------------------------------
3 3 # Copyright (C) 2010 The IPython Development Team
4 4 #
5 5 # Distributed under the terms of the BSD License. The full license is in
6 6 # the file COPYING, distributed as part of this software.
7 7 #-----------------------------------------------------------------------------
8 8
9 9 from pymongo import Connection
10 10 from pymongo.binary import Binary
11 11
12 12 from IPython.utils.traitlets import Dict, List, CUnicode, CStr, Instance
13 13
14 14 from .dictdb import BaseDB
15 15
16 16 #-----------------------------------------------------------------------------
17 17 # MongoDB class
18 18 #-----------------------------------------------------------------------------
19 19
20 20 class MongoDB(BaseDB):
21 21 """MongoDB TaskRecord backend."""
22 22
23 connection_args = List(config=True) # args passed to pymongo.Connection
24 connection_kwargs = Dict(config=True) # kwargs passed to pymongo.Connection
25 database = CUnicode(config=True) # name of the mongodb database
23 connection_args = List(config=True,
24 help="""Positional arguments to be passed to pymongo.Connection. Only
25 necessary if the default mongodb configuration does not point to your
26 mongod instance.""")
27 connection_kwargs = Dict(config=True,
28 help="""Keyword arguments to be passed to pymongo.Connection. Only
29 necessary if the default mongodb configuration does not point to your
30 mongod instance."""
31 )
32 database = CUnicode(config=True,
33 help="""The MongoDB database name to use for storing tasks for this session. If unspecified,
34 a new database will be created with the Hub's IDENT. Specifying the database will result
35 in tasks from previous sessions being available via Clients' db_query and
36 get_result methods.""")
26 37
27 38 _connection = Instance(Connection) # pymongo connection
28 39
29 40 def __init__(self, **kwargs):
30 41 super(MongoDB, self).__init__(**kwargs)
31 42 if self._connection is None:
32 43 self._connection = Connection(*self.connection_args, **self.connection_kwargs)
33 44 if not self.database:
34 45 self.database = self.session
35 46 self._db = self._connection[self.database]
36 47 self._records = self._db['task_records']
37 48 self._records.ensure_index('msg_id', unique=True)
38 49 self._records.ensure_index('submitted') # for sorting history
39 50 # for rec in self._records.find
40 51
41 52 def _binary_buffers(self, rec):
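        # wrap raw byte buffers in bson Binary so pymongo can store them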
42 53 for key in ('buffers', 'result_buffers'):
43 54 if rec.get(key, None):
44 55 rec[key] = map(Binary, rec[key])
45 56 return rec
46 57
47 58 def add_record(self, msg_id, rec):
48 59 """Add a new Task Record, by msg_id."""
49 60 # print rec
50 61 rec = self._binary_buffers(rec)
51 62 self._records.insert(rec)
52 63
53 64 def get_record(self, msg_id):
54 65 """Get a specific Task Record, by msg_id."""
55 66 r = self._records.find_one({'msg_id': msg_id})
56 67 if not r:
57 68             # r will be None if nothing is found
58 69 raise KeyError(msg_id)
59 70 return r
60 71
61 72 def update_record(self, msg_id, rec):
62 73 """Update the data in an existing record."""
63 74 rec = self._binary_buffers(rec)
64 75
65 76 self._records.update({'msg_id':msg_id}, {'$set': rec})
66 77
67 78 def drop_matching_records(self, check):
68 79 """Remove a record from the DB."""
69 80 self._records.remove(check)
70 81
71 82 def drop_record(self, msg_id):
72 83 """Remove a record from the DB."""
73 84 self._records.remove({'msg_id':msg_id})
74 85
75 86 def find_records(self, check, keys=None):
76 87 """Find records matching a query dict, optionally extracting subset of keys.
77 88
78 89 Returns list of matching records.
79 90
80 91 Parameters
81 92 ----------
82 93
83 94 check: dict
84 95 mongodb-style query argument
85 96 keys: list of strs [optional]
86 97 if specified, the subset of keys to extract. msg_id will *always* be
87 98 included.
88 99 """
89 100 if keys and 'msg_id' not in keys:
90 101 keys.append('msg_id')
91 102 matches = list(self._records.find(check,keys))
92 103 for rec in matches:
93 104 rec.pop('_id')
94 105 return matches
95 106
96 107 def get_history(self):
97 108 """get all msg_ids, ordered by time submitted."""
98 109 cursor = self._records.find({},{'msg_id':1}).sort('submitted')
99 110 return [ rec['msg_id'] for rec in cursor ]
100 111
101 112
@@ -1,665 +1,677
1 1 """The Python scheduler for rich scheduling.
2 2
3 3 The Pure ZMQ scheduler does not allow routing schemes other than LRU,
4 4 nor does it check msg_id DAG dependencies. For those, a slightly slower
5 5 Python Scheduler exists.
6 6 """
7 7 #-----------------------------------------------------------------------------
8 8 # Copyright (C) 2010-2011 The IPython Development Team
9 9 #
10 10 # Distributed under the terms of the BSD License. The full license is in
11 11 # the file COPYING, distributed as part of this software.
12 12 #-----------------------------------------------------------------------------
13 13
14 14 #----------------------------------------------------------------------
15 15 # Imports
16 16 #----------------------------------------------------------------------
17 17
18 18 from __future__ import print_function
19 19
20 20 import logging
21 21 import sys
22 22
23 23 from datetime import datetime, timedelta
24 24 from random import randint, random
25 25 from types import FunctionType
26 26
27 27 try:
28 28 import numpy
29 29 except ImportError:
30 30 numpy = None
31 31
32 32 import zmq
33 33 from zmq.eventloop import ioloop, zmqstream
34 34
35 35 # local imports
36 36 from IPython.external.decorator import decorator
37 37 from IPython.config.loader import Config
38 from IPython.utils.traitlets import Instance, Dict, List, Set, Int
38 from IPython.utils.traitlets import Instance, Dict, List, Set, Int, Str, Enum
39 39
40 40 from IPython.parallel import error
41 41 from IPython.parallel.factory import SessionFactory
42 42 from IPython.parallel.util import connect_logger, local_logger
43 43
44 44 from .dependency import Dependency
45 45
46 46 @decorator
47 47 def logged(f,self,*args,**kwargs):
48 48 # print ("#--------------------")
49 49 self.log.debug("scheduler::%s(*%s,**%s)"%(f.func_name, args, kwargs))
50 50 # print ("#--")
51 51 return f(self,*args, **kwargs)
52 52
53 53 #----------------------------------------------------------------------
54 54 # Chooser functions
55 55 #----------------------------------------------------------------------
56 56
57 57 def plainrandom(loads):
58 58 """Plain random pick."""
59 59 n = len(loads)
60 60 return randint(0,n-1)
61 61
62 62 def lru(loads):
63 63 """Always pick the front of the line.
64 64
65 65 The content of `loads` is ignored.
66 66
67 67 Assumes LRU ordering of loads, with oldest first.
68 68 """
69 69 return 0
70 70
71 71 def twobin(loads):
72 72 """Pick two at random, use the LRU of the two.
73 73
74 74 The content of loads is ignored.
75 75
76 76 Assumes LRU ordering of loads, with oldest first.
77 77 """
78 78 n = len(loads)
79 79 a = randint(0,n-1)
80 80 b = randint(0,n-1)
81 81 return min(a,b)
82 82
83 83 def weighted(loads):
84 84 """Pick two at random using inverse load as weight.
85 85
86 86 Return the less loaded of the two.
87 87 """
88 88 # weight 0 a million times more than 1:
89 89 weights = 1./(1e-6+numpy.array(loads))
90 90 sums = weights.cumsum()
91 91 t = sums[-1]
92 92 x = random()*t
93 93 y = random()*t
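    # inverse-CDF sampling: walk the cumulative sums to find the bin each draw falls in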
94 94 idx = 0
95 95 idy = 0
96 96 while sums[idx] < x:
97 97 idx += 1
98 98 while sums[idy] < y:
99 99 idy += 1
100 100 if weights[idy] > weights[idx]:
101 101 return idy
102 102 else:
103 103 return idx
104 104
105 105 def leastload(loads):
106 106 """Always choose the lowest load.
107 107
108 108 If the lowest load occurs more than once, the first
109 109     occurrence will be used. If loads has LRU ordering, this means
110 110 the LRU of those with the lowest load is chosen.
111 111 """
112 112 return loads.index(min(loads))
113 113
114 114 #---------------------------------------------------------------------
115 115 # Classes
116 116 #---------------------------------------------------------------------
117 117 # store empty default dependency:
118 118 MET = Dependency([])
119 119
120 120 class TaskScheduler(SessionFactory):
121 121 """Python TaskScheduler object.
122 122
123 123 This is the simplest object that supports msg_id based
124 124 DAG dependencies. *Only* task msg_ids are checked, not
125 125 msg_ids of jobs submitted via the MUX queue.
126 126
127 127 """
128 128
129 hwm = Int(0, config=True) # limit number of outstanding tasks
129 hwm = Int(0, config=True, shortname='hwm',
130 help="""specify the High Water Mark (HWM) for the downstream
131 socket in the Task scheduler. This is the maximum number
132 of allowed outstanding tasks on each engine."""
133 )
134 scheme_name = Enum(('leastload', 'pure', 'lru', 'plainrandom', 'weighted', 'twobin'),
135 'leastload', config=True, shortname='scheme', allow_none=False,
136         help="""select the task scheduler scheme [default: leastload]
137         Options are: 'pure', 'lru', 'plainrandom', 'weighted', 'twobin', 'leastload'"""
138 )
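    # on change, resolve the scheme name to the chooser function of the same name in this module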
139 def _scheme_name_changed(self, old, new):
140 self.log.debug("Using scheme %r"%new)
141 self.scheme = globals()[new]
130 142
131 143 # input arguments:
132 144 scheme = Instance(FunctionType, default=leastload) # function for determining the destination
133 145 client_stream = Instance(zmqstream.ZMQStream) # client-facing stream
134 146 engine_stream = Instance(zmqstream.ZMQStream) # engine-facing stream
135 147 notifier_stream = Instance(zmqstream.ZMQStream) # hub-facing sub stream
136 148 mon_stream = Instance(zmqstream.ZMQStream) # hub-facing pub stream
137 149
138 150 # internals:
139 151 graph = Dict() # dict by msg_id of [ msg_ids that depend on key ]
140 152 retries = Dict() # dict by msg_id of retries remaining (non-neg ints)
141 153 # waiting = List() # list of msg_ids ready to run, but haven't due to HWM
142 154 depending = Dict() # dict by msg_id of (msg_id, raw_msg, after, follow)
143 155 pending = Dict() # dict by engine_uuid of submitted tasks
144 156 completed = Dict() # dict by engine_uuid of completed tasks
145 157 failed = Dict() # dict by engine_uuid of failed tasks
146 158 destinations = Dict() # dict by msg_id of engine_uuids where jobs ran (reverse of completed+failed)
147 159 clients = Dict() # dict by msg_id for who submitted the task
148 160 targets = List() # list of target IDENTs
149 161 loads = List() # list of engine loads
150 162 # full = Set() # set of IDENTs that have HWM outstanding tasks
151 163 all_completed = Set() # set of all completed tasks
152 164 all_failed = Set() # set of all failed tasks
153 165 all_done = Set() # set of all finished tasks=union(completed,failed)
154 166 all_ids = Set() # set of all submitted task IDs
155 167 blacklist = Dict() # dict by msg_id of locations where a job has encountered UnmetDependency
156 168 auditor = Instance('zmq.eventloop.ioloop.PeriodicCallback')
157 169
158 170
159 171 def start(self):
160 172 self.engine_stream.on_recv(self.dispatch_result, copy=False)
161 173 self._notification_handlers = dict(
162 174 registration_notification = self._register_engine,
163 175 unregistration_notification = self._unregister_engine
164 176 )
165 177 self.notifier_stream.on_recv(self.dispatch_notification)
166 178         self.auditor = ioloop.PeriodicCallback(self.audit_timeouts, 2e3, self.loop) # every 2s (0.5 Hz)
167 179 self.auditor.start()
168 180 self.log.info("Scheduler started...%r"%self)
169 181
170 182 def resume_receiving(self):
171 183 """Resume accepting jobs."""
172 184 self.client_stream.on_recv(self.dispatch_submission, copy=False)
173 185
174 186 def stop_receiving(self):
175 187 """Stop accepting jobs while there are no engines.
176 188 Leave them in the ZMQ queue."""
177 189 self.client_stream.on_recv(None)
178 190
179 191 #-----------------------------------------------------------------------
180 192 # [Un]Registration Handling
181 193 #-----------------------------------------------------------------------
182 194
183 195 def dispatch_notification(self, msg):
184 196 """dispatch register/unregister events."""
185 197 idents,msg = self.session.feed_identities(msg)
186 198 msg = self.session.unpack_message(msg)
187 199 msg_type = msg['msg_type']
188 200 handler = self._notification_handlers.get(msg_type, None)
189 201 if handler is None:
190 202 raise Exception("Unhandled message type: %s"%msg_type)
191 203 else:
192 204 try:
193 205 handler(str(msg['content']['queue']))
194 206 except KeyError:
195 207 self.log.error("task::Invalid notification msg: %s"%msg)
196 208
197 209 @logged
198 210 def _register_engine(self, uid):
199 211 """New engine with ident `uid` became available."""
200 212 # head of the line:
201 213 self.targets.insert(0,uid)
202 214 self.loads.insert(0,0)
203 215 # initialize sets
204 216 self.completed[uid] = set()
205 217 self.failed[uid] = set()
206 218 self.pending[uid] = {}
207 219 if len(self.targets) == 1:
208 220 self.resume_receiving()
209 221 # rescan the graph:
210 222 self.update_graph(None)
211 223
212 224 def _unregister_engine(self, uid):
213 225 """Existing engine with ident `uid` became unavailable."""
214 226 if len(self.targets) == 1:
215 227 # this was our only engine
216 228 self.stop_receiving()
217 229
218 230 # handle any potentially finished tasks:
219 231 self.engine_stream.flush()
220 232
221 233 # don't pop destinations, because they might be used later
222 234 # map(self.destinations.pop, self.completed.pop(uid))
223 235 # map(self.destinations.pop, self.failed.pop(uid))
224 236
225 237 # prevent this engine from receiving work
226 238 idx = self.targets.index(uid)
227 239 self.targets.pop(idx)
228 240 self.loads.pop(idx)
229 241
230 242 # wait 5 seconds before cleaning up pending jobs, since the results might
231 243 # still be incoming
232 244 if self.pending[uid]:
233 245 dc = ioloop.DelayedCallback(lambda : self.handle_stranded_tasks(uid), 5000, self.loop)
234 246 dc.start()
235 247 else:
236 248 self.completed.pop(uid)
237 249 self.failed.pop(uid)
238 250
239 251
240 252 @logged
241 253 def handle_stranded_tasks(self, engine):
242 254 """Deal with jobs resident in an engine that died."""
243 255 lost = self.pending[engine]
244 256 for msg_id in lost.keys():
245 257 if msg_id not in self.pending[engine]:
246 258 # prevent double-handling of messages
247 259 continue
248 260
249 261 raw_msg = lost[msg_id][0]
250 262
251 263 idents,msg = self.session.feed_identities(raw_msg, copy=False)
252 264 msg = self.session.unpack_message(msg, copy=False, content=False)
253 265 parent = msg['header']
254 266 idents = [engine, idents[0]]
255 267
256 268 # build fake error reply
257 269 try:
258 270 raise error.EngineError("Engine %r died while running task %r"%(engine, msg_id))
259 271 except:
260 272 content = error.wrap_exception()
261 273 msg = self.session.msg('apply_reply', content, parent=parent, subheader={'status':'error'})
262 274 raw_reply = map(zmq.Message, self.session.serialize(msg, ident=idents))
263 275 # and dispatch it
264 276 self.dispatch_result(raw_reply)
265 277
266 278 # finally scrub completed/failed lists
267 279 self.completed.pop(engine)
268 280 self.failed.pop(engine)
269 281
270 282
271 283 #-----------------------------------------------------------------------
272 284 # Job Submission
273 285 #-----------------------------------------------------------------------
274 286 @logged
275 287 def dispatch_submission(self, raw_msg):
276 288 """Dispatch job submission to appropriate handlers."""
277 289 # ensure targets up to date:
278 290 self.notifier_stream.flush()
279 291 try:
280 292 idents, msg = self.session.feed_identities(raw_msg, copy=False)
281 293 msg = self.session.unpack_message(msg, content=False, copy=False)
282 294 except Exception:
283 295             self.log.error("task::Invalid task: %s"%raw_msg, exc_info=True)
284 296 return
285 297
286 298 # send to monitor
287 299 self.mon_stream.send_multipart(['intask']+raw_msg, copy=False)
288 300
289 301 header = msg['header']
290 302 msg_id = header['msg_id']
291 303 self.all_ids.add(msg_id)
292 304
293 305 # targets
294 306 targets = set(header.get('targets', []))
295 307 retries = header.get('retries', 0)
296 308 self.retries[msg_id] = retries
297 309
298 310 # time dependencies
299 311 after = Dependency(header.get('after', []))
300 312 if after.all:
301 313 if after.success:
302 314 after.difference_update(self.all_completed)
303 315 if after.failure:
304 316 after.difference_update(self.all_failed)
305 317 if after.check(self.all_completed, self.all_failed):
306 318 # recast as empty set, if `after` already met,
307 319 # to prevent unnecessary set comparisons
308 320 after = MET
309 321
310 322 # location dependencies
311 323 follow = Dependency(header.get('follow', []))
312 324
313 325 # turn timeouts into datetime objects:
314 326 timeout = header.get('timeout', None)
315 327 if timeout:
316 328 timeout = datetime.now() + timedelta(0,timeout,0)
317 329
318 330 args = [raw_msg, targets, after, follow, timeout]
319 331
320 332 # validate and reduce dependencies:
321 333 for dep in after,follow:
322 334 # check valid:
323 335 if msg_id in dep or dep.difference(self.all_ids):
324 336 self.depending[msg_id] = args
325 337 return self.fail_unreachable(msg_id, error.InvalidDependency)
326 338 # check if unreachable:
327 339 if dep.unreachable(self.all_completed, self.all_failed):
328 340 self.depending[msg_id] = args
329 341 return self.fail_unreachable(msg_id)
330 342
331 343 if after.check(self.all_completed, self.all_failed):
332 344 # time deps already met, try to run
333 345 if not self.maybe_run(msg_id, *args):
334 346 # can't run yet
335 347 if msg_id not in self.all_failed:
336 348 # could have failed as unreachable
337 349 self.save_unmet(msg_id, *args)
338 350 else:
339 351 self.save_unmet(msg_id, *args)
340 352
341 353 # @logged
342 354 def audit_timeouts(self):
343 355 """Audit all waiting tasks for expired timeouts."""
344 356 now = datetime.now()
345 357 for msg_id in self.depending.keys():
346 358 # must recheck, in case one failure cascaded to another:
347 359 if msg_id in self.depending:
348 360                 raw,targets,after,follow,timeout = self.depending[msg_id]
349 361 if timeout and timeout < now:
350 362 self.fail_unreachable(msg_id, error.TaskTimeout)
351 363
352 364 @logged
353 365 def fail_unreachable(self, msg_id, why=error.ImpossibleDependency):
354 366 """a task has become unreachable, send a reply with an ImpossibleDependency
355 367 error."""
356 368 if msg_id not in self.depending:
357 369 self.log.error("msg %r already failed!"%msg_id)
358 370 return
359 371 raw_msg,targets,after,follow,timeout = self.depending.pop(msg_id)
360 372 for mid in follow.union(after):
361 373 if mid in self.graph:
362 374 self.graph[mid].remove(msg_id)
363 375
364 376 # FIXME: unpacking a message I've already unpacked, but didn't save:
365 377 idents,msg = self.session.feed_identities(raw_msg, copy=False)
366 378 msg = self.session.unpack_message(msg, copy=False, content=False)
367 379 header = msg['header']
368 380
369 381 try:
370 382 raise why()
371 383 except:
372 384 content = error.wrap_exception()
373 385
374 386 self.all_done.add(msg_id)
375 387 self.all_failed.add(msg_id)
376 388
377 389 msg = self.session.send(self.client_stream, 'apply_reply', content,
378 390 parent=header, ident=idents)
379 391 self.session.send(self.mon_stream, msg, ident=['outtask']+idents)
380 392
381 393 self.update_graph(msg_id, success=False)
382 394
383 395 @logged
384 396 def maybe_run(self, msg_id, raw_msg, targets, after, follow, timeout):
385 397 """check location dependencies, and run if they are met."""
386 398 blacklist = self.blacklist.setdefault(msg_id, set())
387 399 if follow or targets or blacklist or self.hwm:
388 400 # we need a can_run filter
389 401 def can_run(idx):
390 402 # check hwm
391 403 if self.hwm and self.loads[idx] == self.hwm:
392 404 return False
393 405 target = self.targets[idx]
394 406 # check blacklist
395 407 if target in blacklist:
396 408 return False
397 409 # check targets
398 410 if targets and target not in targets:
399 411 return False
400 412 # check follow
401 413 return follow.check(self.completed[target], self.failed[target])
402 414
403 415 indices = filter(can_run, range(len(self.targets)))
404 416
405 417 if not indices:
406 418 # couldn't run
407 419 if follow.all:
408 420 # check follow for impossibility
409 421 dests = set()
410 422 relevant = set()
411 423 if follow.success:
412 424 relevant = self.all_completed
413 425 if follow.failure:
414 426 relevant = relevant.union(self.all_failed)
415 427 for m in follow.intersection(relevant):
416 428 dests.add(self.destinations[m])
417 429 if len(dests) > 1:
418 430 self.depending[msg_id] = (raw_msg, targets, after, follow, timeout)
419 431 self.fail_unreachable(msg_id)
420 432 return False
421 433 if targets:
422 434 # check blacklist+targets for impossibility
423 435 targets.difference_update(blacklist)
424 436 if not targets or not targets.intersection(self.targets):
425 437 self.depending[msg_id] = (raw_msg, targets, after, follow, timeout)
426 438 self.fail_unreachable(msg_id)
427 439 return False
428 440 return False
429 441 else:
430 442 indices = None
431 443
432 444 self.submit_task(msg_id, raw_msg, targets, follow, timeout, indices)
433 445 return True
434 446
435 447 @logged
436 448 def save_unmet(self, msg_id, raw_msg, targets, after, follow, timeout):
437 449 """Save a message for later submission when its dependencies are met."""
438 450 self.depending[msg_id] = [raw_msg,targets,after,follow,timeout]
439 451 # track the ids in follow or after, but not those already finished
440 452 for dep_id in after.union(follow).difference(self.all_done):
441 453 if dep_id not in self.graph:
442 454 self.graph[dep_id] = set()
443 455 self.graph[dep_id].add(msg_id)
444 456
445 457 @logged
446 458 def submit_task(self, msg_id, raw_msg, targets, follow, timeout, indices=None):
447 459 """Submit a task to any of a subset of our targets."""
448 460 if indices:
449 461 loads = [self.loads[i] for i in indices]
450 462 else:
451 463 loads = self.loads
452 464 idx = self.scheme(loads)
453 465 if indices:
454 466 idx = indices[idx]
455 467 target = self.targets[idx]
456 468 # print (target, map(str, msg[:3]))
457 469 # send job to the engine
458 470 self.engine_stream.send(target, flags=zmq.SNDMORE, copy=False)
459 471 self.engine_stream.send_multipart(raw_msg, copy=False)
460 472 # update load
461 473 self.add_job(idx)
462 474 self.pending[target][msg_id] = (raw_msg, targets, MET, follow, timeout)
463 475 # notify Hub
464 476 content = dict(msg_id=msg_id, engine_id=target)
465 477 self.session.send(self.mon_stream, 'task_destination', content=content,
466 478 ident=['tracktask',self.session.session])
467 479
468 480
469 481 #-----------------------------------------------------------------------
470 482 # Result Handling
471 483 #-----------------------------------------------------------------------
472 484 @logged
473 485 def dispatch_result(self, raw_msg):
474 486 """dispatch method for result replies"""
475 487 try:
476 488 idents,msg = self.session.feed_identities(raw_msg, copy=False)
477 489 msg = self.session.unpack_message(msg, content=False, copy=False)
478 490 engine = idents[0]
479 491 try:
480 492 idx = self.targets.index(engine)
481 493 except ValueError:
482 494 pass # skip load-update for dead engines
483 495 else:
484 496 self.finish_job(idx)
485 497 except Exception:
486 498             self.log.error("task::Invalid result: %s"%raw_msg, exc_info=True)
487 499 return
488 500
489 501 header = msg['header']
490 502 parent = msg['parent_header']
491 503 if header.get('dependencies_met', True):
492 504 success = (header['status'] == 'ok')
493 505 msg_id = parent['msg_id']
494 506 retries = self.retries[msg_id]
495 507 if not success and retries > 0:
496 508 # failed
497 509 self.retries[msg_id] = retries - 1
498 510 self.handle_unmet_dependency(idents, parent)
499 511 else:
500 512 del self.retries[msg_id]
501 513 # relay to client and update graph
502 514 self.handle_result(idents, parent, raw_msg, success)
503 515 # send to Hub monitor
504 516 self.mon_stream.send_multipart(['outtask']+raw_msg, copy=False)
505 517 else:
506 518 self.handle_unmet_dependency(idents, parent)
507 519
508 520 @logged
509 521 def handle_result(self, idents, parent, raw_msg, success=True):
510 522 """handle a real task result, either success or failure"""
511 523 # first, relay result to client
512 524 engine = idents[0]
513 525 client = idents[1]
514 526 # swap_ids for XREP-XREP mirror
515 527 raw_msg[:2] = [client,engine]
516 528 # print (map(str, raw_msg[:4]))
517 529 self.client_stream.send_multipart(raw_msg, copy=False)
518 530 # now, update our data structures
519 531 msg_id = parent['msg_id']
520 532 self.blacklist.pop(msg_id, None)
521 533 self.pending[engine].pop(msg_id)
522 534 if success:
523 535 self.completed[engine].add(msg_id)
524 536 self.all_completed.add(msg_id)
525 537 else:
526 538 self.failed[engine].add(msg_id)
527 539 self.all_failed.add(msg_id)
528 540 self.all_done.add(msg_id)
529 541 self.destinations[msg_id] = engine
530 542
531 543 self.update_graph(msg_id, success)
532 544
533 545 @logged
534 546 def handle_unmet_dependency(self, idents, parent):
535 547 """handle an unmet dependency"""
536 548 engine = idents[0]
537 549 msg_id = parent['msg_id']
538 550
539 551 if msg_id not in self.blacklist:
540 552 self.blacklist[msg_id] = set()
541 553 self.blacklist[msg_id].add(engine)
542 554
543 555 args = self.pending[engine].pop(msg_id)
544 556 raw,targets,after,follow,timeout = args
545 557
546 558 if self.blacklist[msg_id] == targets:
547 559 self.depending[msg_id] = args
548 560 self.fail_unreachable(msg_id)
549 561 elif not self.maybe_run(msg_id, *args):
550 562 # resubmit failed
551 563 if msg_id not in self.all_failed:
552 564 # put it back in our dependency tree
553 565 self.save_unmet(msg_id, *args)
554 566
555 567 if self.hwm:
556 568 try:
557 569 idx = self.targets.index(engine)
558 570 except ValueError:
559 571 pass # skip load-update for dead engines
560 572 else:
561 573 if self.loads[idx] == self.hwm-1:
562 574 self.update_graph(None)
563 575
564 576
565 577
566 578 @logged
567 579 def update_graph(self, dep_id=None, success=True):
568 580 """dep_id just finished. Update our dependency
569 581         graph and submit any jobs that just became runnable.
570 582
571 583 Called with dep_id=None to update entire graph for hwm, but without finishing
572 584 a task.
573 585 """
574 586 # print ("\n\n***********")
575 587 # pprint (dep_id)
576 588 # pprint (self.graph)
577 589 # pprint (self.depending)
578 590 # pprint (self.all_completed)
579 591 # pprint (self.all_failed)
580 592 # print ("\n\n***********\n\n")
581 593 # update any jobs that depended on the dependency
582 594 jobs = self.graph.pop(dep_id, [])
583 595
584 596 # recheck *all* jobs if
585 597 # a) we have HWM and an engine just become no longer full
586 598 # or b) dep_id was given as None
587 599 if dep_id is None or self.hwm and any( [ load==self.hwm-1 for load in self.loads ]):
588 600 jobs = self.depending.keys()
589 601
590 602 for msg_id in jobs:
591 603 raw_msg, targets, after, follow, timeout = self.depending[msg_id]
592 604
593 605 if after.unreachable(self.all_completed, self.all_failed) or follow.unreachable(self.all_completed, self.all_failed):
594 606 self.fail_unreachable(msg_id)
595 607
596 608 elif after.check(self.all_completed, self.all_failed): # time deps met, maybe run
597 609 if self.maybe_run(msg_id, raw_msg, targets, MET, follow, timeout):
598 610
599 611 self.depending.pop(msg_id)
600 612 for mid in follow.union(after):
601 613 if mid in self.graph:
602 614 self.graph[mid].remove(msg_id)
603 615
604 616 #----------------------------------------------------------------------
605 617 # methods to be overridden by subclasses
606 618 #----------------------------------------------------------------------
607 619
608 620 def add_job(self, idx):
609 621 """Called after self.targets[idx] just got the job with header.
610 622 Override with subclasses. The default ordering is simple LRU.
611 623 The default loads are the number of outstanding jobs."""
612 624 self.loads[idx] += 1
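        # rotate this engine to the back of the lists, keeping LRU order (oldest first)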
613 625 for lis in (self.targets, self.loads):
614 626 lis.append(lis.pop(idx))
615 627
616 628
617 629 def finish_job(self, idx):
618 630 """Called after self.targets[idx] just finished a job.
619 631 Override with subclasses."""
620 632 self.loads[idx] -= 1
621 633
622 634
623 635
624 636 def launch_scheduler(in_addr, out_addr, mon_addr, not_addr, config=None,logname='ZMQ',
625 log_addr=None, loglevel=logging.DEBUG, scheme='lru',
637 log_addr=None, loglevel=logging.DEBUG,
626 638 identity=b'task'):
627 639 from zmq.eventloop import ioloop
628 640 from zmq.eventloop.zmqstream import ZMQStream
629 641
630 642 if config:
631 643 # unwrap dict back into Config
632 644 config = Config(config)
633 645
634 646 ctx = zmq.Context()
635 647 loop = ioloop.IOLoop()
636 648 ins = ZMQStream(ctx.socket(zmq.XREP),loop)
637 649 ins.setsockopt(zmq.IDENTITY, identity)
638 650 ins.bind(in_addr)
639 651
640 652 outs = ZMQStream(ctx.socket(zmq.XREP),loop)
641 653 outs.setsockopt(zmq.IDENTITY, identity)
642 654 outs.bind(out_addr)
643 655 mons = ZMQStream(ctx.socket(zmq.PUB),loop)
644 656 mons.connect(mon_addr)
645 657 nots = ZMQStream(ctx.socket(zmq.SUB),loop)
646 658 nots.setsockopt(zmq.SUBSCRIBE, '')
647 659 nots.connect(not_addr)
648 660
649 scheme = globals().get(scheme, None)
661 # scheme = globals().get(scheme, None)
650 662 # setup logging
651 663 if log_addr:
652 664 connect_logger(logname, ctx, log_addr, root="scheduler", loglevel=loglevel)
653 665 else:
654 666 local_logger(logname, loglevel)
655 667
656 668 scheduler = TaskScheduler(client_stream=ins, engine_stream=outs,
657 669 mon_stream=mons, notifier_stream=nots,
658 scheme=scheme, loop=loop, logname=logname,
670 loop=loop, logname=logname,
659 671 config=config)
660 672 scheduler.start()
661 673 try:
662 674 loop.start()
663 675 except KeyboardInterrupt:
664 676 print ("interrupted, exiting...", file=sys.__stderr__)
665 677
@@ -1,326 +1,333
1 1 """A TaskRecord backend using sqlite3"""
2 2 #-----------------------------------------------------------------------------
3 3 # Copyright (C) 2011 The IPython Development Team
4 4 #
5 5 # Distributed under the terms of the BSD License. The full license is in
6 6 # the file COPYING, distributed as part of this software.
7 7 #-----------------------------------------------------------------------------
8 8
9 9 import json
10 10 import os
11 11 import cPickle as pickle
12 12 from datetime import datetime
13 13
14 14 import sqlite3
15 15
16 16 from zmq.eventloop import ioloop
17 17
18 18 from IPython.utils.traitlets import CUnicode, CStr, Instance, List
19 19 from .dictdb import BaseDB
20 20 from IPython.parallel.util import ISO8601
21 21
22 22 #-----------------------------------------------------------------------------
23 23 # SQLite operators, adapters, and converters
24 24 #-----------------------------------------------------------------------------
25 25
26 26 operators = {
27 27 '$lt' : "<",
28 28 '$gt' : ">",
29 29 # null is handled weird with ==,!=
30 30 '$eq' : "=",
31 31 '$ne' : "!=",
32 32 '$lte': "<=",
33 33 '$gte': ">=",
34 34 '$in' : ('=', ' OR '),
35 35 '$nin': ('!=', ' AND '),
36 36 # '$all': None,
37 37 # '$mod': None,
38 38 # '$exists' : None
39 39 }
40 40 null_operators = {
41 41 '=' : "IS NULL",
42 42 '!=' : "IS NOT NULL",
43 43 }
44 44
45 45 def _adapt_datetime(dt):
46 46 return dt.strftime(ISO8601)
47 47
48 48 def _convert_datetime(ds):
49 49 if ds is None:
50 50 return ds
51 51 else:
52 52 return datetime.strptime(ds, ISO8601)
53 53
54 54 def _adapt_dict(d):
55 55 return json.dumps(d)
56 56
57 57 def _convert_dict(ds):
58 58 if ds is None:
59 59 return ds
60 60 else:
61 61 return json.loads(ds)
62 62
63 63 def _adapt_bufs(bufs):
64 64 # this is *horrible*
65 65 # copy buffers into single list and pickle it:
66 66 if bufs and isinstance(bufs[0], (bytes, buffer)):
67 67 return sqlite3.Binary(pickle.dumps(map(bytes, bufs),-1))
68 68 elif bufs:
69 69 return bufs
70 70 else:
71 71 return None
72 72
73 73 def _convert_bufs(bs):
74 74 if bs is None:
75 75 return []
76 76 else:
77 77 return pickle.loads(bytes(bs))
78 78
79 79 #-----------------------------------------------------------------------------
80 80 # SQLiteDB class
81 81 #-----------------------------------------------------------------------------
82 82
83 83 class SQLiteDB(BaseDB):
84 84 """SQLite3 TaskRecord backend."""
85 85
86 filename = CUnicode('tasks.db', config=True)
87 location = CUnicode('', config=True)
88 table = CUnicode("", config=True)
86 filename = CUnicode('tasks.db', config=True,
87 help="""The filename of the sqlite task database. [default: 'tasks.db']""")
88 location = CUnicode('', config=True,
89 help="""The directory containing the sqlite task database. The default
90 is to use the cluster_dir location.""")
91 table = CUnicode("", config=True,
92 help="""The SQLite Table to use for storing tasks for this session. If unspecified,
93 a new table will be created with the Hub's IDENT. Specifying the table will result
94 in tasks from previous sessions being available via Clients' db_query and
95 get_result methods.""")
89 96
90 97 _db = Instance('sqlite3.Connection')
91 98 _keys = List(['msg_id' ,
92 99 'header' ,
93 100 'content',
94 101 'buffers',
95 102 'submitted',
96 103 'client_uuid' ,
97 104 'engine_uuid' ,
98 105 'started',
99 106 'completed',
100 107 'resubmitted',
101 108 'result_header' ,
102 109 'result_content' ,
103 110 'result_buffers' ,
104 111 'queue' ,
105 112 'pyin' ,
106 113 'pyout',
107 114 'pyerr',
108 115 'stdout',
109 116 'stderr',
110 117 ])
111 118
112 119 def __init__(self, **kwargs):
113 120 super(SQLiteDB, self).__init__(**kwargs)
114 121 if not self.table:
115 122 # use session, and prefix _, since starting with # is illegal
116 123 self.table = '_'+self.session.replace('-','_')
117 124 if not self.location:
118 125 if hasattr(self.config.Global, 'cluster_dir'):
119 126 self.location = self.config.Global.cluster_dir
120 127 else:
121 128 self.location = '.'
122 129 self._init_db()
123 130
124 131 # register db commit as 2s periodic callback
125 132 # to prevent clogging pipes
126 133 # assumes we are being run in a zmq ioloop app
127 134 loop = ioloop.IOLoop.instance()
128 135 pc = ioloop.PeriodicCallback(self._db.commit, 2000, loop)
129 136 pc.start()
130 137
131 138 def _defaults(self, keys=None):
132 139 """create an empty record"""
133 140 d = {}
134 141 keys = self._keys if keys is None else keys
135 142 for key in keys:
136 143 d[key] = None
137 144 return d
138 145
139 146 def _init_db(self):
140 147 """Connect to the database and get new session number."""
141 148 # register adapters
142 149 sqlite3.register_adapter(datetime, _adapt_datetime)
143 150 sqlite3.register_converter('datetime', _convert_datetime)
144 151 sqlite3.register_adapter(dict, _adapt_dict)
145 152 sqlite3.register_converter('dict', _convert_dict)
146 153 sqlite3.register_adapter(list, _adapt_bufs)
147 154 sqlite3.register_converter('bufs', _convert_bufs)
148 155 # connect to the db
149 156 dbfile = os.path.join(self.location, self.filename)
150 157 self._db = sqlite3.connect(dbfile, detect_types=sqlite3.PARSE_DECLTYPES,
151 158 # isolation_level = None)#,
152 159 cached_statements=64)
153 160 # print dir(self._db)
154 161
155 162 self._db.execute("""CREATE TABLE IF NOT EXISTS %s
156 163 (msg_id text PRIMARY KEY,
157 164 header dict text,
158 165 content dict text,
159 166 buffers bufs blob,
160 167 submitted datetime text,
161 168 client_uuid text,
162 169 engine_uuid text,
163 170 started datetime text,
164 171 completed datetime text,
165 172 resubmitted datetime text,
166 173 result_header dict text,
167 174 result_content dict text,
168 175 result_buffers bufs blob,
169 176 queue text,
170 177 pyin text,
171 178 pyout text,
172 179 pyerr text,
173 180 stdout text,
174 181 stderr text)
175 182 """%self.table)
176 183 self._db.commit()
177 184
178 185 def _dict_to_list(self, d):
179 186 """turn a mongodb-style record dict into a list."""
180 187
181 188 return [ d[key] for key in self._keys ]
182 189
183 190 def _list_to_dict(self, line, keys=None):
184 191 """Inverse of dict_to_list"""
185 192 keys = self._keys if keys is None else keys
186 193 d = self._defaults(keys)
187 194 for key,value in zip(keys, line):
188 195 d[key] = value
189 196
190 197 return d
191 198
192 199 def _render_expression(self, check):
193 200 """Turn a mongodb-style search dict into an SQL query."""
194 201 expressions = []
195 202 args = []
196 203
197 204 skeys = set(check.keys())
198 205 skeys.difference_update(set(self._keys))
199 206 skeys.difference_update(set(['buffers', 'result_buffers']))
200 207 if skeys:
201 208 raise KeyError("Illegal testing key(s): %s"%skeys)
202 209
203 210 for name,sub_check in check.iteritems():
204 211 if isinstance(sub_check, dict):
205 212 for test,value in sub_check.iteritems():
206 213 try:
207 214 op = operators[test]
208 215 except KeyError:
209 216 raise KeyError("Unsupported operator: %r"%test)
210 217 if isinstance(op, tuple):
211 218 op, join = op
212 219
213 220 if value is None and op in null_operators:
214 221                     expr = "%s %s"%(name, null_operators[op])
215 222 else:
216 223 expr = "%s %s ?"%(name, op)
217 224 if isinstance(value, (tuple,list)):
218 225 if op in null_operators and any([v is None for v in value]):
219 226 # equality tests don't work with NULL
220 227 raise ValueError("Cannot use %r test with NULL values on SQLite backend"%test)
221 228 expr = '( %s )'%( join.join([expr]*len(value)) )
222 229 args.extend(value)
223 230 else:
224 231 args.append(value)
225 232 expressions.append(expr)
226 233 else:
227 234 # it's an equality check
228 235 if sub_check is None:
229 236                 expressions.append("%s IS NULL"%name)
230 237 else:
231 238 expressions.append("%s = ?"%name)
232 239 args.append(sub_check)
233 240
234 241 expr = " AND ".join(expressions)
235 242 return expr, args
236 243
237 244 def add_record(self, msg_id, rec):
238 245 """Add a new Task Record, by msg_id."""
239 246 d = self._defaults()
240 247 d.update(rec)
241 248 d['msg_id'] = msg_id
242 249 line = self._dict_to_list(d)
243 250 tups = '(%s)'%(','.join(['?']*len(line)))
244 251 self._db.execute("INSERT INTO %s VALUES %s"%(self.table, tups), line)
245 252 # self._db.commit()
246 253
247 254 def get_record(self, msg_id):
248 255 """Get a specific Task Record, by msg_id."""
249 256 cursor = self._db.execute("""SELECT * FROM %s WHERE msg_id==?"""%self.table, (msg_id,))
250 257 line = cursor.fetchone()
251 258 if line is None:
252 259 raise KeyError("No such msg: %r"%msg_id)
253 260 return self._list_to_dict(line)
254 261
255 262 def update_record(self, msg_id, rec):
256 263 """Update the data in an existing record."""
257 264 query = "UPDATE %s SET "%self.table
258 265 sets = []
259 266 keys = sorted(rec.keys())
260 267 values = []
261 268 for key in keys:
262 269 sets.append('%s = ?'%key)
263 270 values.append(rec[key])
264 271 query += ', '.join(sets)
265 272 query += ' WHERE msg_id == ?'
266 273 values.append(msg_id)
267 274 self._db.execute(query, values)
268 275 # self._db.commit()
269 276
270 277 def drop_record(self, msg_id):
271 278 """Remove a record from the DB."""
272 279 self._db.execute("""DELETE FROM %s WHERE msg_id==?"""%self.table, (msg_id,))
273 280 # self._db.commit()
274 281
275 282 def drop_matching_records(self, check):
276 283 """Remove a record from the DB."""
277 284 expr,args = self._render_expression(check)
278 285 query = "DELETE FROM %s WHERE %s"%(self.table, expr)
279 286 self._db.execute(query,args)
280 287 # self._db.commit()
281 288
282 289 def find_records(self, check, keys=None):
283 290 """Find records matching a query dict, optionally extracting subset of keys.
284 291
285 292 Returns list of matching records.
286 293
287 294 Parameters
288 295 ----------
289 296
290 297 check: dict
291 298 mongodb-style query argument
292 299 keys: list of strs [optional]
293 300 if specified, the subset of keys to extract. msg_id will *always* be
294 301 included.
295 302 """
296 303 if keys:
297 304 bad_keys = [ key for key in keys if key not in self._keys ]
298 305 if bad_keys:
299 306 raise KeyError("Bad record key(s): %s"%bad_keys)
300 307
301 308 if keys:
302 309 # ensure msg_id is present and first:
303 310 if 'msg_id' in keys:
304 311 keys.remove('msg_id')
305 312 keys.insert(0, 'msg_id')
306 313 req = ', '.join(keys)
307 314 else:
308 315 req = '*'
309 316 expr,args = self._render_expression(check)
310 317 query = """SELECT %s FROM %s WHERE %s"""%(req, self.table, expr)
311 318 cursor = self._db.execute(query, args)
312 319 matches = cursor.fetchall()
313 320 records = []
314 321 for line in matches:
315 322 rec = self._list_to_dict(line, keys)
316 323 records.append(rec)
317 324 return records
318 325
319 326 def get_history(self):
320 327 """get all msg_ids, ordered by time submitted."""
321 328 query = """SELECT msg_id FROM %s ORDER by submitted ASC"""%self.table
322 329 cursor = self._db.execute(query)
323 330 # will be a list of length 1 tuples
324 331 return [ tup[0] for tup in cursor.fetchall()]
325 332
326 333 __all__ = ['SQLiteDB'] No newline at end of file
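To make the mongodb-style query translation above concrete, here is a small sketch of the check dictionaries find_records() accepts and the WHERE clauses _render_expression() builds from them. It assumes an already-initialized SQLiteDB instance named db; the field values are invented:

    # Sketch only: `db` is a hypothetical, already-initialized SQLiteDB instance.
    from datetime import datetime, timedelta

    t_stop = datetime.now()
    t_start = t_stop - timedelta(hours=1)

    # plain equality            ->  "engine_uuid = ?"
    recs = db.find_records({'engine_uuid': 'abcd-1234'})

    # comparison operators      ->  "completed >= ? AND completed < ?"
    recs = db.find_records({'completed': {'$gte': t_start, '$lt': t_stop}},
                           keys=['msg_id', 'completed'])

    # '$in' expands to an OR group  ->  "( queue = ? OR queue = ? )"
    recs = db.find_records({'queue': {'$in': ['mux', 'task']}})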
@@ -1,156 +1,166
1 1 #!/usr/bin/env python
2 2 """A simple engine that talks to a controller over 0MQ.
3 3 It handles registration, etc. and launches a kernel
4 4 connected to the Controller's Schedulers.
5 5 """
6 6 #-----------------------------------------------------------------------------
7 7 # Copyright (C) 2010-2011 The IPython Development Team
8 8 #
9 9 # Distributed under the terms of the BSD License. The full license is in
10 10 # the file COPYING, distributed as part of this software.
11 11 #-----------------------------------------------------------------------------
12 12
13 13 from __future__ import print_function
14 14
15 15 import sys
16 16 import time
17 17
18 18 import zmq
19 19 from zmq.eventloop import ioloop, zmqstream
20 20
21 21 # internal
22 22 from IPython.utils.traitlets import Instance, Str, Dict, Int, Type, CFloat
23 23 # from IPython.utils.localinterfaces import LOCALHOST
24 24
25 25 from IPython.parallel.controller.heartmonitor import Heart
26 26 from IPython.parallel.factory import RegistrationFactory
27 27 from IPython.parallel.streamsession import Message
28 28 from IPython.parallel.util import disambiguate_url
29 29
30 30 from .streamkernel import Kernel
31 31
32 32 class EngineFactory(RegistrationFactory):
33 33 """IPython engine"""
34 34
35 35 # configurables:
36 user_ns=Dict(config=True)
37 out_stream_factory=Type('IPython.zmq.iostream.OutStream', config=True)
38 display_hook_factory=Type('IPython.zmq.displayhook.DisplayHook', config=True)
39 location=Str(config=True)
40 timeout=CFloat(2,config=True)
36 out_stream_factory=Type('IPython.zmq.iostream.OutStream', config=True,
37 help="""The OutStream for handling stdout/err.
38 Typically 'IPython.zmq.iostream.OutStream'""")
39 display_hook_factory=Type('IPython.zmq.displayhook.DisplayHook', config=True,
40 help="""The class for handling displayhook.
41 Typically 'IPython.zmq.displayhook.DisplayHook'""")
42 location=Str(config=True,
43         help="""The location (an IP address) of the controller. This is
44         used for disambiguating URLs, to determine whether loopback
45         or the public address should be used to connect.""")
46 timeout=CFloat(2,config=True,
47 help="""The time (in seconds) to wait for the Controller to respond
48 to registration requests before giving up.""")
41 49
42 50 # not configurable:
51 user_ns=Dict()
43 52 id=Int(allow_none=True)
44 53 registrar=Instance('zmq.eventloop.zmqstream.ZMQStream')
45 54 kernel=Instance(Kernel)
46 55
47 56
48 57 def __init__(self, **kwargs):
49 58 super(EngineFactory, self).__init__(**kwargs)
59 self.ident = self.session.session
50 60 ctx = self.context
51 61
52 62 reg = ctx.socket(zmq.XREQ)
53 63 reg.setsockopt(zmq.IDENTITY, self.ident)
54 64 reg.connect(self.url)
55 65 self.registrar = zmqstream.ZMQStream(reg, self.loop)
56 66
57 67 def register(self):
58 68 """send the registration_request"""
59 69
60 70 self.log.info("registering")
61 71 content = dict(queue=self.ident, heartbeat=self.ident, control=self.ident)
62 72 self.registrar.on_recv(self.complete_registration)
63 73 # print (self.session.key)
64 74 self.session.send(self.registrar, "registration_request",content=content)
65 75
66 76 def complete_registration(self, msg):
67 77 # print msg
68 78 self._abort_dc.stop()
69 79 ctx = self.context
70 80 loop = self.loop
71 81 identity = self.ident
72 82
73 83 idents,msg = self.session.feed_identities(msg)
74 84 msg = Message(self.session.unpack_message(msg))
75 85
76 86 if msg.content.status == 'ok':
77 87 self.id = int(msg.content.id)
78 88
79 89 # create Shell Streams (MUX, Task, etc.):
80 90 queue_addr = msg.content.mux
81 91 shell_addrs = [ str(queue_addr) ]
82 92 task_addr = msg.content.task
83 93 if task_addr:
84 94 shell_addrs.append(str(task_addr))
85 95
86 96 # Uncomment this to go back to two-socket model
87 97 # shell_streams = []
88 98 # for addr in shell_addrs:
89 99 # stream = zmqstream.ZMQStream(ctx.socket(zmq.XREP), loop)
90 100 # stream.setsockopt(zmq.IDENTITY, identity)
91 101 # stream.connect(disambiguate_url(addr, self.location))
92 102 # shell_streams.append(stream)
93 103
94 104 # Now use only one shell stream for mux and tasks
95 105 stream = zmqstream.ZMQStream(ctx.socket(zmq.XREP), loop)
96 106 stream.setsockopt(zmq.IDENTITY, identity)
97 107 shell_streams = [stream]
98 108 for addr in shell_addrs:
99 109 stream.connect(disambiguate_url(addr, self.location))
100 110 # end single stream-socket
101 111
102 112 # control stream:
103 113 control_addr = str(msg.content.control)
104 114 control_stream = zmqstream.ZMQStream(ctx.socket(zmq.XREP), loop)
105 115 control_stream.setsockopt(zmq.IDENTITY, identity)
106 116 control_stream.connect(disambiguate_url(control_addr, self.location))
107 117
108 118 # create iopub stream:
109 119 iopub_addr = msg.content.iopub
110 120 iopub_stream = zmqstream.ZMQStream(ctx.socket(zmq.PUB), loop)
111 121 iopub_stream.setsockopt(zmq.IDENTITY, identity)
112 122 iopub_stream.connect(disambiguate_url(iopub_addr, self.location))
113 123
114 124 # launch heartbeat
115 125 hb_addrs = msg.content.heartbeat
116 126 # print (hb_addrs)
117 127
118 128 # # Redirect input streams and set a display hook.
119 129 if self.out_stream_factory:
120 130 sys.stdout = self.out_stream_factory(self.session, iopub_stream, u'stdout')
121 131 sys.stdout.topic = 'engine.%i.stdout'%self.id
122 132 sys.stderr = self.out_stream_factory(self.session, iopub_stream, u'stderr')
123 133 sys.stderr.topic = 'engine.%i.stderr'%self.id
124 134 if self.display_hook_factory:
125 135 sys.displayhook = self.display_hook_factory(self.session, iopub_stream)
126 136 sys.displayhook.topic = 'engine.%i.pyout'%self.id
127 137
128 138 self.kernel = Kernel(config=self.config, int_id=self.id, ident=self.ident, session=self.session,
129 139 control_stream=control_stream, shell_streams=shell_streams, iopub_stream=iopub_stream,
130 loop=loop, user_ns = self.user_ns, logname=self.log.name)
140 loop=loop, user_ns = self.user_ns, log=self.log)
131 141 self.kernel.start()
132 142 hb_addrs = [ disambiguate_url(addr, self.location) for addr in hb_addrs ]
133 143 heart = Heart(*map(str, hb_addrs), heart_id=identity)
134 144 # ioloop.DelayedCallback(heart.start, 1000, self.loop).start()
135 145 heart.start()
136 146
137 147
138 148 else:
139 149 self.log.fatal("Registration Failed: %s"%msg)
140 150 raise Exception("Registration Failed: %s"%msg)
141 151
142 152 self.log.info("Completed registration with id %i"%self.id)
143 153
144 154
145 155 def abort(self):
146 self.log.fatal("Registration timed out")
156 self.log.fatal("Registration timed out after %.1f seconds"%self.timeout)
147 157 self.session.send(self.registrar, "unregistration_request", content=dict(id=self.id))
148 158 time.sleep(1)
149 159 sys.exit(255)
150 160
151 161 def start(self):
152 162 dc = ioloop.DelayedCallback(self.register, 0, self.loop)
153 163 dc.start()
154 164 self._abort_dc = ioloop.DelayedCallback(self.abort, self.timeout*1000, self.loop)
155 165 self._abort_dc.start()
156 166
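The traits that gained help strings above are all config=True, so here is a hedged sketch of setting them from an IPython config file; the file name and the values are assumptions, not part of this diff:

    # Sketch of a config-file fragment (e.g. an ipengine config file),
    # using the usual `c = get_config()` convention.
    c = get_config()

    # give the controller five seconds to answer registration before aborting
    c.EngineFactory.timeout = 5.0

    # public IP of the controller, used when disambiguating loopback vs. remote URLs
    c.EngineFactory.location = '10.0.0.5'

    # the defaults, spelled out for completeness
    c.EngineFactory.out_stream_factory = 'IPython.zmq.iostream.OutStream'
    c.EngineFactory.display_hook_factory = 'IPython.zmq.displayhook.DisplayHook'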
@@ -1,431 +1,433
1 1 #!/usr/bin/env python
2 2 """
3 3 Kernel adapted from kernel.py to use ZMQ Streams
4 4 """
5 5 #-----------------------------------------------------------------------------
6 6 # Copyright (C) 2010-2011 The IPython Development Team
7 7 #
8 8 # Distributed under the terms of the BSD License. The full license is in
9 9 # the file COPYING, distributed as part of this software.
10 10 #-----------------------------------------------------------------------------
11 11
12 12 #-----------------------------------------------------------------------------
13 13 # Imports
14 14 #-----------------------------------------------------------------------------
15 15
16 16 # Standard library imports.
17 17 from __future__ import print_function
18 18
19 19 import sys
20 20 import time
21 21
22 22 from code import CommandCompiler
23 23 from datetime import datetime
24 24 from pprint import pprint
25 25
26 26 # System library imports.
27 27 import zmq
28 28 from zmq.eventloop import ioloop, zmqstream
29 29
30 30 # Local imports.
31 from IPython.utils.traitlets import Instance, List, Int, Dict, Set, Str
31 from IPython.utils.traitlets import Instance, List, Int, Dict, Set, Str, CStr
32 32 from IPython.zmq.completer import KernelCompleter
33 33
34 34 from IPython.parallel.error import wrap_exception
35 35 from IPython.parallel.factory import SessionFactory
36 36 from IPython.parallel.util import serialize_object, unpack_apply_message, ISO8601
37 37
38 38 def printer(*args):
39 39 pprint(args, stream=sys.__stdout__)
40 40
41 41
42 42 class _Passer(zmqstream.ZMQStream):
43 43 """Empty class that implements `send()` that does nothing.
44 44
45 45 Subclass ZMQStream for StreamSession typechecking
46 46
47 47 """
48 48 def __init__(self, *args, **kwargs):
49 49 pass
50 50
51 51 def send(self, *args, **kwargs):
52 52 pass
53 53 send_multipart = send
54 54
55 55
56 56 #-----------------------------------------------------------------------------
57 57 # Main kernel class
58 58 #-----------------------------------------------------------------------------
59 59
60 60 class Kernel(SessionFactory):
61 61
62 62 #---------------------------------------------------------------------------
63 63 # Kernel interface
64 64 #---------------------------------------------------------------------------
65 65
66 66 # kwargs:
67 int_id = Int(-1, config=True)
68 user_ns = Dict(config=True)
69 exec_lines = List(config=True)
67 exec_lines = List(CStr, config=True,
68 help="List of lines to execute")
69
70 int_id = Int(-1)
71 user_ns = Dict(config=True, help="""Set the user's namespace of the Kernel""")
70 72
71 73 control_stream = Instance(zmqstream.ZMQStream)
72 74 task_stream = Instance(zmqstream.ZMQStream)
73 75 iopub_stream = Instance(zmqstream.ZMQStream)
74 76 client = Instance('IPython.parallel.Client')
75 77
76 78 # internals
77 79 shell_streams = List()
78 80 compiler = Instance(CommandCompiler, (), {})
79 81 completer = Instance(KernelCompleter)
80 82
81 83 aborted = Set()
82 84 shell_handlers = Dict()
83 85 control_handlers = Dict()
84 86
85 87 def _set_prefix(self):
86 88 self.prefix = "engine.%s"%self.int_id
87 89
88 90 def _connect_completer(self):
89 91 self.completer = KernelCompleter(self.user_ns)
90 92
91 93 def __init__(self, **kwargs):
92 94 super(Kernel, self).__init__(**kwargs)
93 95 self._set_prefix()
94 96 self._connect_completer()
95 97
96 98 self.on_trait_change(self._set_prefix, 'id')
97 99 self.on_trait_change(self._connect_completer, 'user_ns')
98 100
99 101 # Build dict of handlers for message types
100 102 for msg_type in ['execute_request', 'complete_request', 'apply_request',
101 103 'clear_request']:
102 104 self.shell_handlers[msg_type] = getattr(self, msg_type)
103 105
104 106 for msg_type in ['shutdown_request', 'abort_request']+self.shell_handlers.keys():
105 107 self.control_handlers[msg_type] = getattr(self, msg_type)
106 108
107 109 self._initial_exec_lines()
108 110
109 111 def _wrap_exception(self, method=None):
110 112 e_info = dict(engine_uuid=self.ident, engine_id=self.int_id, method=method)
111 113 content=wrap_exception(e_info)
112 114 return content
113 115
114 116 def _initial_exec_lines(self):
115 117 s = _Passer()
116 118 content = dict(silent=True, user_variable=[],user_expressions=[])
117 119 for line in self.exec_lines:
118 120 self.log.debug("executing initialization: %s"%line)
119 121 content.update({'code':line})
120 122 msg = self.session.msg('execute_request', content)
121 123 self.execute_request(s, [], msg)
122 124
123 125
124 126 #-------------------- control handlers -----------------------------
125 127 def abort_queues(self):
126 128 for stream in self.shell_streams:
127 129 if stream:
128 130 self.abort_queue(stream)
129 131
130 132 def abort_queue(self, stream):
131 133 while True:
132 134 try:
133 135 msg = self.session.recv(stream, zmq.NOBLOCK,content=True)
134 136 except zmq.ZMQError as e:
135 137 if e.errno == zmq.EAGAIN:
136 138 break
137 139 else:
138 140 return
139 141 else:
140 142 if msg is None:
141 143 return
142 144 else:
143 145 idents,msg = msg
144 146
145 147 # assert self.reply_socketly_socket.rcvmore(), "Unexpected missing message part."
146 148 # msg = self.reply_socket.recv_json()
147 149 self.log.info("Aborting:")
148 150 self.log.info(str(msg))
149 151 msg_type = msg['msg_type']
150 152 reply_type = msg_type.split('_')[0] + '_reply'
151 153 # reply_msg = self.session.msg(reply_type, {'status' : 'aborted'}, msg)
152 154 # self.reply_socket.send(ident,zmq.SNDMORE)
153 155 # self.reply_socket.send_json(reply_msg)
154 156 reply_msg = self.session.send(stream, reply_type,
155 157 content={'status' : 'aborted'}, parent=msg, ident=idents)[0]
156 158 self.log.debug(str(reply_msg))
157 159 # We need to wait a bit for requests to come in. This can probably
158 160 # be set shorter for true asynchronous clients.
159 161 time.sleep(0.05)
160 162
161 163 def abort_request(self, stream, ident, parent):
162 164         """abort a specific msg by id"""
163 165 msg_ids = parent['content'].get('msg_ids', None)
164 166 if isinstance(msg_ids, basestring):
165 167 msg_ids = [msg_ids]
166 168 if not msg_ids:
167 169 self.abort_queues()
168 170 for mid in msg_ids:
169 171 self.aborted.add(str(mid))
170 172
171 173 content = dict(status='ok')
172 174 reply_msg = self.session.send(stream, 'abort_reply', content=content,
173 175 parent=parent, ident=ident)
174 176 self.log.debug(str(reply_msg))
175 177
176 178 def shutdown_request(self, stream, ident, parent):
177 179 """kill ourself. This should really be handled in an external process"""
178 180 try:
179 181 self.abort_queues()
180 182 except:
181 183 content = self._wrap_exception('shutdown')
182 184 else:
183 185 content = dict(parent['content'])
184 186 content['status'] = 'ok'
185 187 msg = self.session.send(stream, 'shutdown_reply',
186 188 content=content, parent=parent, ident=ident)
187 189 self.log.debug(str(msg))
188 190 dc = ioloop.DelayedCallback(lambda : sys.exit(0), 1000, self.loop)
189 191 dc.start()
190 192
191 193 def dispatch_control(self, msg):
192 194 idents,msg = self.session.feed_identities(msg, copy=False)
193 195 try:
194 196 msg = self.session.unpack_message(msg, content=True, copy=False)
195 197 except:
196 198 self.log.error("Invalid Message", exc_info=True)
197 199 return
198 200
199 201 header = msg['header']
200 202 msg_id = header['msg_id']
201 203
202 204 handler = self.control_handlers.get(msg['msg_type'], None)
203 205 if handler is None:
204 206 self.log.error("UNKNOWN CONTROL MESSAGE TYPE: %r"%msg['msg_type'])
205 207 else:
206 208 handler(self.control_stream, idents, msg)
207 209
208 210
209 211 #-------------------- queue helpers ------------------------------
210 212
211 213 def check_dependencies(self, dependencies):
212 214 if not dependencies:
213 215 return True
214 216 if len(dependencies) == 2 and dependencies[0] in 'any all'.split():
215 217 anyorall = dependencies[0]
216 218 dependencies = dependencies[1]
217 219 else:
218 220 anyorall = 'all'
219 221 results = self.client.get_results(dependencies,status_only=True)
220 222 if results['status'] != 'ok':
221 223 return False
222 224
223 225 if anyorall == 'any':
224 226 if not results['completed']:
225 227 return False
226 228 else:
227 229 if results['pending']:
228 230 return False
229 231
230 232 return True
231 233
232 234 def check_aborted(self, msg_id):
233 235 return msg_id in self.aborted
234 236
235 237 #-------------------- queue handlers -----------------------------
236 238
237 239 def clear_request(self, stream, idents, parent):
238 240 """Clear our namespace."""
239 241 self.user_ns = {}
240 242 msg = self.session.send(stream, 'clear_reply', ident=idents, parent=parent,
241 243 content = dict(status='ok'))
242 244 self._initial_exec_lines()
243 245
244 246 def execute_request(self, stream, ident, parent):
245 247 self.log.debug('execute request %s'%parent)
246 248 try:
247 249 code = parent[u'content'][u'code']
248 250 except:
249 251 self.log.error("Got bad msg: %s"%parent, exc_info=True)
250 252 return
251 253 self.session.send(self.iopub_stream, u'pyin', {u'code':code},parent=parent,
252 254 ident='%s.pyin'%self.prefix)
253 255 started = datetime.now().strftime(ISO8601)
254 256 try:
255 257 comp_code = self.compiler(code, '<zmq-kernel>')
256 258 # allow for not overriding displayhook
257 259 if hasattr(sys.displayhook, 'set_parent'):
258 260 sys.displayhook.set_parent(parent)
259 261 sys.stdout.set_parent(parent)
260 262 sys.stderr.set_parent(parent)
261 263 exec comp_code in self.user_ns, self.user_ns
262 264 except:
263 265 exc_content = self._wrap_exception('execute')
264 266 # exc_msg = self.session.msg(u'pyerr', exc_content, parent)
265 267 self.session.send(self.iopub_stream, u'pyerr', exc_content, parent=parent,
266 268 ident='%s.pyerr'%self.prefix)
267 269 reply_content = exc_content
268 270 else:
269 271 reply_content = {'status' : 'ok'}
270 272
271 273 reply_msg = self.session.send(stream, u'execute_reply', reply_content, parent=parent,
272 274 ident=ident, subheader = dict(started=started))
273 275 self.log.debug(str(reply_msg))
274 276 if reply_msg['content']['status'] == u'error':
275 277 self.abort_queues()
276 278
277 279 def complete_request(self, stream, ident, parent):
278 280 matches = {'matches' : self.complete(parent),
279 281 'status' : 'ok'}
280 282 completion_msg = self.session.send(stream, 'complete_reply',
281 283 matches, parent, ident)
282 284 # print >> sys.__stdout__, completion_msg
283 285
284 286 def complete(self, msg):
285 287 return self.completer.complete(msg.content.line, msg.content.text)
286 288
287 289 def apply_request(self, stream, ident, parent):
288 290 # flush previous reply, so this request won't block it
289 291 stream.flush(zmq.POLLOUT)
290 292
291 293 try:
292 294 content = parent[u'content']
293 295 bufs = parent[u'buffers']
294 296 msg_id = parent['header']['msg_id']
295 297 # bound = parent['header'].get('bound', False)
296 298 except:
297 299 self.log.error("Got bad msg: %s"%parent, exc_info=True)
298 300 return
299 301 # pyin_msg = self.session.msg(u'pyin',{u'code':code}, parent=parent)
300 302 # self.iopub_stream.send(pyin_msg)
301 303 # self.session.send(self.iopub_stream, u'pyin', {u'code':code},parent=parent)
302 304 sub = {'dependencies_met' : True, 'engine' : self.ident,
303 305 'started': datetime.now().strftime(ISO8601)}
304 306 try:
305 307 # allow for not overriding displayhook
306 308 if hasattr(sys.displayhook, 'set_parent'):
307 309 sys.displayhook.set_parent(parent)
308 310 sys.stdout.set_parent(parent)
309 311 sys.stderr.set_parent(parent)
310 312 # exec "f(*args,**kwargs)" in self.user_ns, self.user_ns
311 313 working = self.user_ns
312 314 # suffix =
313 315 prefix = "_"+str(msg_id).replace("-","")+"_"
314 316
315 317 f,args,kwargs = unpack_apply_message(bufs, working, copy=False)
316 318 # if bound:
317 319 # bound_ns = Namespace(working)
318 320 # args = [bound_ns]+list(args)
319 321
320 322 fname = getattr(f, '__name__', 'f')
321 323
322 324 fname = prefix+"f"
323 325 argname = prefix+"args"
324 326 kwargname = prefix+"kwargs"
325 327 resultname = prefix+"result"
326 328
327 329 ns = { fname : f, argname : args, kwargname : kwargs , resultname : None }
328 330 # print ns
329 331 working.update(ns)
330 332 code = "%s=%s(*%s,**%s)"%(resultname, fname, argname, kwargname)
331 333 try:
332 334 exec code in working,working
333 335 result = working.get(resultname)
334 336 finally:
335 337 for key in ns.iterkeys():
336 338 working.pop(key)
337 339 # if bound:
338 340 # working.update(bound_ns)
339 341
340 342 packed_result,buf = serialize_object(result)
341 343 result_buf = [packed_result]+buf
342 344 except:
343 345 exc_content = self._wrap_exception('apply')
344 346 # exc_msg = self.session.msg(u'pyerr', exc_content, parent)
345 347 self.session.send(self.iopub_stream, u'pyerr', exc_content, parent=parent,
346 348 ident='%s.pyerr'%self.prefix)
347 349 reply_content = exc_content
348 350 result_buf = []
349 351
350 352 if exc_content['ename'] == 'UnmetDependency':
351 353 sub['dependencies_met'] = False
352 354 else:
353 355 reply_content = {'status' : 'ok'}
354 356
355 357 # put 'ok'/'error' status in header, for scheduler introspection:
356 358 sub['status'] = reply_content['status']
357 359
358 360 reply_msg = self.session.send(stream, u'apply_reply', reply_content,
359 361 parent=parent, ident=ident,buffers=result_buf, subheader=sub)
360 362
361 363 # flush i/o
362 364 # should this be before reply_msg is sent, like in the single-kernel code,
363 365 # or should nothing get in the way of real results?
364 366 sys.stdout.flush()
365 367 sys.stderr.flush()
366 368
367 369 def dispatch_queue(self, stream, msg):
368 370 self.control_stream.flush()
369 371 idents,msg = self.session.feed_identities(msg, copy=False)
370 372 try:
371 373 msg = self.session.unpack_message(msg, content=True, copy=False)
372 374 except:
373 375 self.log.error("Invalid Message", exc_info=True)
374 376 return
375 377
376 378
377 379 header = msg['header']
378 380 msg_id = header['msg_id']
379 381 if self.check_aborted(msg_id):
380 382 self.aborted.remove(msg_id)
381 383 # is it safe to assume a msg_id will not be resubmitted?
382 384 reply_type = msg['msg_type'].split('_')[0] + '_reply'
383 385 status = {'status' : 'aborted'}
384 386 reply_msg = self.session.send(stream, reply_type, subheader=status,
385 387 content=status, parent=msg, ident=idents)
386 388 return
387 389 handler = self.shell_handlers.get(msg['msg_type'], None)
388 390 if handler is None:
389 391 self.log.error("UNKNOWN MESSAGE TYPE: %r"%msg['msg_type'])
390 392 else:
391 393 handler(stream, idents, msg)
392 394
393 395 def start(self):
394 396 #### stream mode:
395 397 if self.control_stream:
396 398 self.control_stream.on_recv(self.dispatch_control, copy=False)
397 399 self.control_stream.on_err(printer)
398 400
399 401 def make_dispatcher(stream):
400 402 def dispatcher(msg):
401 403 return self.dispatch_queue(stream, msg)
402 404 return dispatcher
403 405
404 406 for s in self.shell_streams:
405 407 s.on_recv(make_dispatcher(s), copy=False)
406 408 s.on_err(printer)
407 409
408 410 if self.iopub_stream:
409 411 self.iopub_stream.on_err(printer)
410 412
411 413 #### while True mode:
412 414 # while True:
413 415 # idle = True
414 416 # try:
415 417 # msg = self.shell_stream.socket.recv_multipart(
416 418 # zmq.NOBLOCK, copy=False)
417 419 # except zmq.ZMQError, e:
418 420 # if e.errno != zmq.EAGAIN:
419 421 # raise e
420 422 # else:
421 423 # idle=False
422 424 # self.dispatch_queue(self.shell_stream, msg)
423 425 #
424 426 # if not self.task_stream.empty():
425 427 # idle=False
426 428 # msg = self.task_stream.recv_multipart()
427 429 # self.dispatch_queue(self.task_stream, msg)
428 430 # if idle:
429 431 # # don't busywait
430 432 # time.sleep(1e-3)
431 433
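The namespace mangling in apply_request() is easier to follow in isolation; this standalone sketch reproduces the same prefix/exec/cleanup pattern with made-up inputs (it is not code from this changeset):

    # Sketch only: re-creates the apply_request() naming pattern outside the kernel.
    def run_prefixed(user_ns, msg_id, f, args, kwargs):
        # mirror the "_<msg_id>_" prefix used to avoid clobbering user names
        prefix = "_" + str(msg_id).replace("-", "") + "_"
        fname, argname = prefix + "f", prefix + "args"
        kwargname, resultname = prefix + "kwargs", prefix + "result"

        ns = {fname: f, argname: args, kwargname: kwargs, resultname: None}
        user_ns.update(ns)
        code = "%s=%s(*%s,**%s)" % (resultname, fname, argname, kwargname)
        try:
            exec code in user_ns, user_ns      # Python 2 exec, as in the kernel
            return user_ns.get(resultname)
        finally:
            # remove the temporary names, leaving only the user's own bindings
            for key in ns:
                user_ns.pop(key)

    print(run_prefixed({}, '1234-abcd', lambda x, y=1: x + y, (2,), {'y': 3}))  # 5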
@@ -1,152 +1,95
1 1 """Base config factories."""
2 2
3 3 #-----------------------------------------------------------------------------
4 4 # Copyright (C) 2008-2009 The IPython Development Team
5 5 #
6 6 # Distributed under the terms of the BSD License. The full license is in
7 7 # the file COPYING, distributed as part of this software.
8 8 #-----------------------------------------------------------------------------
9 9
10 10 #-----------------------------------------------------------------------------
11 11 # Imports
12 12 #-----------------------------------------------------------------------------
13 13
14 14
15 15 import logging
16 16 import os
17 import uuid
18 17
19 18 from zmq.eventloop.ioloop import IOLoop
20 19
21 20 from IPython.config.configurable import Configurable
22 from IPython.utils.importstring import import_item
23 21 from IPython.utils.traitlets import Str,Int,Instance, CUnicode, CStr
24 22
25 23 import IPython.parallel.streamsession as ss
26 24 from IPython.parallel.util import select_random_ports
27 25
28 26 #-----------------------------------------------------------------------------
29 27 # Classes
30 28 #-----------------------------------------------------------------------------
31 29 class LoggingFactory(Configurable):
32 30 """A most basic class that has a `log` (type: `Logger`) attribute, set via a `logname` Trait."""
33 31 log = Instance('logging.Logger', ('ZMQ', logging.WARN))
34 32 logname = CUnicode('ZMQ')
35 33 def _logname_changed(self, name, old, new):
36 34 self.log = logging.getLogger(new)
37 35
38 36
39 37 class SessionFactory(LoggingFactory):
40 38 """The Base factory from which every factory in IPython.parallel inherits"""
41 39
42 packer = Str('',config=True)
43 unpacker = Str('',config=True)
44 ident = CStr('',config=True)
45 def _ident_default(self):
46 return str(uuid.uuid4())
47 username = CUnicode(os.environ.get('USER','username'),config=True)
48 exec_key = CUnicode('',config=True)
49 40 # not configurable:
50 41 context = Instance('zmq.Context', (), {})
51 42 session = Instance('IPython.parallel.streamsession.StreamSession')
52 43 loop = Instance('zmq.eventloop.ioloop.IOLoop', allow_none=False)
53 44 def _loop_default(self):
54 45 return IOLoop.instance()
55 46
56 47
57 48 def __init__(self, **kwargs):
58 49 super(SessionFactory, self).__init__(**kwargs)
59 exec_key = self.exec_key or None
60 # set the packers:
61 if not self.packer:
62 packer_f = unpacker_f = None
63 elif self.packer.lower() == 'json':
64 packer_f = ss.json_packer
65 unpacker_f = ss.json_unpacker
66 elif self.packer.lower() == 'pickle':
67 packer_f = ss.pickle_packer
68 unpacker_f = ss.pickle_unpacker
69 else:
70 packer_f = import_item(self.packer)
71 unpacker_f = import_item(self.unpacker)
72 50
73 51 # construct the session
74 self.session = ss.StreamSession(self.username, self.ident, packer=packer_f, unpacker=unpacker_f, key=exec_key)
52 self.session = ss.StreamSession(**kwargs)
75 53
76 54
77 55 class RegistrationFactory(SessionFactory):
78 56 """The Base Configurable for objects that involve registration."""
79 57
80 url = Str('', config=True) # url takes precedence over ip,regport,transport
81 transport = Str('tcp', config=True)
82 ip = Str('127.0.0.1', config=True)
83 regport = Instance(int, config=True)
58 url = Str('', config=True,
59 help="""The 0MQ url used for registration. This sets transport, ip, and port
60 in one variable. For example: url='tcp://127.0.0.1:12345' or
61 url='epgm://*:90210'""") # url takes precedence over ip,regport,transport
62 transport = Str('tcp', config=True,
63 help="""The 0MQ transport for communications. This will likely be
64 the default of 'tcp', but other values include 'ipc', 'epgm', 'inproc'.""")
65 ip = Str('127.0.0.1', config=True,
66 help="""The IP address for registration. This is generally either
67 '127.0.0.1' for loopback only or '*' for all interfaces.
68 [default: '127.0.0.1']""")
69 regport = Int(config=True,
70 help="""The port on which the Hub listens for registration.""")
84 71 def _regport_default(self):
85 # return 10101
86 72 return select_random_ports(1)[0]
87 73
88 74 def __init__(self, **kwargs):
89 75 super(RegistrationFactory, self).__init__(**kwargs)
90 76 self._propagate_url()
91 77 self._rebuild_url()
92 78 self.on_trait_change(self._propagate_url, 'url')
93 79 self.on_trait_change(self._rebuild_url, 'ip')
94 80 self.on_trait_change(self._rebuild_url, 'transport')
95 81 self.on_trait_change(self._rebuild_url, 'regport')
96 82
97 83 def _rebuild_url(self):
98 84 self.url = "%s://%s:%i"%(self.transport, self.ip, self.regport)
99 85
100 86 def _propagate_url(self):
101 87 """Ensure self.url contains full transport://interface:port"""
102 88 if self.url:
103 89 iface = self.url.split('://',1)
104 90 if len(iface) == 2:
105 91 self.transport,iface = iface
106 92 iface = iface.split(':')
107 93 self.ip = iface[0]
108 94 if iface[1]:
109 95 self.regport = int(iface[1])
110
111 #-----------------------------------------------------------------------------
112 # argparse argument extenders
113 #-----------------------------------------------------------------------------
114
115
116 def add_session_arguments(parser):
117 paa = parser.add_argument
118 paa('--ident',
119 type=str, dest='SessionFactory.ident',
120 help='set the ZMQ and session identity [default: random uuid]',
121 metavar='identity')
122 # paa('--execkey',
123 # type=str, dest='SessionFactory.exec_key',
124 # help='path to a file containing an execution key.',
125 # metavar='execkey')
126 paa('--packer',
127 type=str, dest='SessionFactory.packer',
128 help='method to serialize messages: {json,pickle} [default: json]',
129 metavar='packer')
130 paa('--unpacker',
131 type=str, dest='SessionFactory.unpacker',
132 help='inverse function of `packer`. Only necessary when using something other than json|pickle',
133 metavar='packer')
134
135 def add_registration_arguments(parser):
136 paa = parser.add_argument
137 paa('--ip',
138 type=str, dest='RegistrationFactory.ip',
139 help="The IP used for registration [default: localhost]",
140 metavar='ip')
141 paa('--transport',
142 type=str, dest='RegistrationFactory.transport',
143 help="The ZeroMQ transport used for registration [default: tcp]",
144 metavar='transport')
145 paa('--url',
146 type=str, dest='RegistrationFactory.url',
147 help='set transport,ip,regport in one go, e.g. tcp://127.0.0.1:10101',
148 metavar='url')
149 paa('--regport',
150 type=int, dest='RegistrationFactory.regport',
151 help="The port used for registration [default: 10101]",
152 metavar='ip')
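A short sketch of the url/transport/ip/regport round-tripping that _propagate_url() and _rebuild_url() implement. Building a RegistrationFactory standalone like this is an assumption for illustration; it glosses over the session wiring the real factories perform:

    # Sketch only: demonstrates the trait propagation, not a supported standalone use.
    rf = RegistrationFactory(url='tcp://192.168.1.10:10101')
    print("%s %s %i" % (rf.transport, rf.ip, rf.regport))   # tcp 192.168.1.10 10101

    # changing one component rebuilds the full url via _rebuild_url()
    rf.ip = '127.0.0.1'
    print(rf.url)                                           # tcp://127.0.0.1:10101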
@@ -1,419 +1,446
1 1 #!/usr/bin/env python
2 2 """edited session.py to work with streams, and move msg_type to the header
3 3 """
4 4 #-----------------------------------------------------------------------------
5 5 # Copyright (C) 2010-2011 The IPython Development Team
6 6 #
7 7 # Distributed under the terms of the BSD License. The full license is in
8 8 # the file COPYING, distributed as part of this software.
9 9 #-----------------------------------------------------------------------------
10 10
11 11
12 12 import os
13 13 import pprint
14 14 import uuid
15 15 from datetime import datetime
16 16
17 17 try:
18 18 import cPickle
19 19 pickle = cPickle
20 20 except:
21 21 cPickle = None
22 22 import pickle
23 23
24 24 import zmq
25 25 from zmq.utils import jsonapi
26 26 from zmq.eventloop.zmqstream import ZMQStream
27 27
28 from IPython.config.configurable import Configurable
29 from IPython.utils.importstring import import_item
30 from IPython.utils.traitlets import Str, CStr, CUnicode, Bool, Any
31
28 32 from .util import ISO8601
29 33
34
30 35 def squash_unicode(obj):
31 36 """coerce unicode back to bytestrings."""
32 37 if isinstance(obj,dict):
33 38 for key in obj.keys():
34 39 obj[key] = squash_unicode(obj[key])
35 40 if isinstance(key, unicode):
36 41 obj[squash_unicode(key)] = obj.pop(key)
37 42 elif isinstance(obj, list):
38 43 for i,v in enumerate(obj):
39 44 obj[i] = squash_unicode(v)
40 45 elif isinstance(obj, unicode):
41 46 obj = obj.encode('utf8')
42 47 return obj
43 48
44 49 def _date_default(obj):
45 50 if isinstance(obj, datetime):
46 51 return obj.strftime(ISO8601)
47 52 else:
48 53 raise TypeError("%r is not JSON serializable"%obj)
49 54
50 55 _default_key = 'on_unknown' if jsonapi.jsonmod.__name__ == 'jsonlib' else 'default'
51 56 json_packer = lambda obj: jsonapi.dumps(obj, **{_default_key:_date_default})
52 57 json_unpacker = lambda s: squash_unicode(jsonapi.loads(s))
53 58
54 59 pickle_packer = lambda o: pickle.dumps(o,-1)
55 60 pickle_unpacker = pickle.loads
56 61
57 62 default_packer = json_packer
58 63 default_unpacker = json_unpacker
59 64
60 65
61 66 DELIM="<IDS|MSG>"
62 67
63 68 class Message(object):
64 69 """A simple message object that maps dict keys to attributes.
65 70
66 71 A Message can be created from a dict and a dict from a Message instance
67 72 simply by calling dict(msg_obj)."""
68 73
69 74 def __init__(self, msg_dict):
70 75 dct = self.__dict__
71 76 for k, v in dict(msg_dict).iteritems():
72 77 if isinstance(v, dict):
73 78 v = Message(v)
74 79 dct[k] = v
75 80
76 81 # Having this iterator lets dict(msg_obj) work out of the box.
77 82 def __iter__(self):
78 83 return iter(self.__dict__.iteritems())
79 84
80 85 def __repr__(self):
81 86 return repr(self.__dict__)
82 87
83 88 def __str__(self):
84 89 return pprint.pformat(self.__dict__)
85 90
86 91 def __contains__(self, k):
87 92 return k in self.__dict__
88 93
89 94 def __getitem__(self, k):
90 95 return self.__dict__[k]
91 96
92 97
93 98 def msg_header(msg_id, msg_type, username, session):
94 99 date=datetime.now().strftime(ISO8601)
95 100 return locals()
96 101
97 102 def extract_header(msg_or_header):
98 103 """Given a message or header, return the header."""
99 104 if not msg_or_header:
100 105 return {}
101 106 try:
102 107 # See if msg_or_header is the entire message.
103 108 h = msg_or_header['header']
104 109 except KeyError:
105 110 try:
106 111 # See if msg_or_header is just the header
107 112 h = msg_or_header['msg_id']
108 113 except KeyError:
109 114 raise
110 115 else:
111 116 h = msg_or_header
112 117 if not isinstance(h, dict):
113 118 h = dict(h)
114 119 return h
115 120
116 class StreamSession(object):
121 class StreamSession(Configurable):
117 122 """tweaked version of IPython.zmq.session.Session, for development in Parallel"""
118 debug=False
119 key=None
120
121 def __init__(self, username=None, session=None, packer=None, unpacker=None, key=None, keyfile=None):
122 if username is None:
123 username = os.environ.get('USER','username')
124 self.username = username
125 if session is None:
126 self.session = str(uuid.uuid4())
123 debug=Bool(False, config=True, help="""Debug output in the StreamSession""")
124 packer = Str('json',config=True,
125 help="""The name of the packer for serializing messages.
126 Should be one of 'json', 'pickle', or an import name
127 for a custom serializer.""")
128 def _packer_changed(self, name, old, new):
129 if new.lower() == 'json':
130 self.pack = json_packer
131 self.unpack = json_unpacker
132 elif new.lower() == 'pickle':
133 self.pack = pickle_packer
134 self.unpack = pickle_unpacker
127 135 else:
128 self.session = session
129 self.msg_id = str(uuid.uuid4())
130 if packer is None:
131 self.pack = default_packer
136 self.pack = import_item(new)
137
138 unpacker = Str('json',config=True,
139         help="""The name of the unpacker for deserializing messages.
140 Only used with custom functions for `packer`.""")
141 def _unpacker_changed(self, name, old, new):
142 if new.lower() == 'json':
143 self.pack = json_packer
144 self.unpack = json_unpacker
145 elif new.lower() == 'pickle':
146 self.pack = pickle_packer
147 self.unpack = pickle_unpacker
132 148 else:
133 if not callable(packer):
134 raise TypeError("packer must be callable, not %s"%type(packer))
135 self.pack = packer
149 self.unpack = import_item(new)
136 150
137 if unpacker is None:
138 self.unpack = default_unpacker
139 else:
140 if not callable(unpacker):
141 raise TypeError("unpacker must be callable, not %s"%type(unpacker))
142 self.unpack = unpacker
151 session = CStr('',config=True,
152 help="""The UUID identifying this session.""")
153 def _session_default(self):
154 return str(uuid.uuid4())
155 username = CUnicode(os.environ.get('USER','username'),config=True,
156 help="""Username for the Session. Default is your system username.""")
157 key = CStr('', config=True,
158 help="""execution key, for extra authentication.""")
159
160 keyfile = CUnicode('', config=True,
161 help="""path to file containing execution key.""")
162 def _keyfile_changed(self, name, old, new):
163 with open(new, 'rb') as f:
164 self.key = f.read().strip()
165
166 pack = Any(default_packer) # the actual packer function
167 def _pack_changed(self, name, old, new):
168 if not callable(new):
169 raise TypeError("packer must be callable, not %s"%type(new))
143 170
144 if key is not None and keyfile is not None:
145 raise TypeError("Must specify key OR keyfile, not both")
146 if keyfile is not None:
147 with open(keyfile) as f:
148 self.key = f.read().strip()
149 else:
150 self.key = key
151 if isinstance(self.key, unicode):
152 self.key = self.key.encode('utf8')
153 # print key, keyfile, self.key
171     unpack = Any(default_unpacker) # the actual unpacker function
172 def _unpack_changed(self, name, old, new):
173 if not callable(new):
174             raise TypeError("unpacker must be callable, not %s"%type(new))
175
176 def __init__(self, **kwargs):
177 super(StreamSession, self).__init__(**kwargs)
154 178 self.none = self.pack({})
155
179
180 @property
181 def msg_id(self):
182 """always return new uuid"""
183 return str(uuid.uuid4())
184
156 185 def msg_header(self, msg_type):
157 h = msg_header(self.msg_id, msg_type, self.username, self.session)
158 self.msg_id = str(uuid.uuid4())
159 return h
186 return msg_header(self.msg_id, msg_type, self.username, self.session)
160 187
161 188 def msg(self, msg_type, content=None, parent=None, subheader=None):
162 189 msg = {}
163 190 msg['header'] = self.msg_header(msg_type)
164 191 msg['msg_id'] = msg['header']['msg_id']
165 192 msg['parent_header'] = {} if parent is None else extract_header(parent)
166 193 msg['msg_type'] = msg_type
167 194 msg['content'] = {} if content is None else content
168 195 sub = {} if subheader is None else subheader
169 196 msg['header'].update(sub)
170 197 return msg
171 198
172 199 def check_key(self, msg_or_header):
173 200 """Check that a message's header has the right key"""
174 if self.key is None:
201 if not self.key:
175 202 return True
176 203 header = extract_header(msg_or_header)
177 return header.get('key', None) == self.key
204 return header.get('key', '') == self.key
178 205
179 206
180 207 def serialize(self, msg, ident=None):
181 208 content = msg.get('content', {})
182 209 if content is None:
183 210 content = self.none
184 211 elif isinstance(content, dict):
185 212 content = self.pack(content)
186 213 elif isinstance(content, bytes):
187 214 # content is already packed, as in a relayed message
188 215 pass
189 216 elif isinstance(content, unicode):
190 217 # should be bytes, but JSON often spits out unicode
191 218 content = content.encode('utf8')
192 219 else:
193 220 raise TypeError("Content incorrect type: %s"%type(content))
194 221
195 222 to_send = []
196 223
197 224 if isinstance(ident, list):
198 225 # accept list of idents
199 226 to_send.extend(ident)
200 227 elif ident is not None:
201 228 to_send.append(ident)
202 229 to_send.append(DELIM)
203 if self.key is not None:
230 if self.key:
204 231 to_send.append(self.key)
205 232 to_send.append(self.pack(msg['header']))
206 233 to_send.append(self.pack(msg['parent_header']))
207 234 to_send.append(content)
208 235
209 236 return to_send
210 237
211 238 def send(self, stream, msg_or_type, content=None, buffers=None, parent=None, subheader=None, ident=None, track=False):
212 239 """Build and send a message via stream or socket.
213 240
214 241 Parameters
215 242 ----------
216 243
217 244 stream : zmq.Socket or ZMQStream
218 245 the socket-like object used to send the data
219 246 msg_or_type : str or Message/dict
220 247 Normally, msg_or_type will be a msg_type unless a message is being sent more
221 248 than once.
222 249
223 250 content : dict or None
224 251 the content of the message (ignored if msg_or_type is a message)
225 252 buffers : list or None
226 253 the already-serialized buffers to be appended to the message
227 254 parent : Message or dict or None
228 255 the parent or parent header describing the parent of this message
229 256 subheader : dict or None
230 257 extra header keys for this message's header
231 258 ident : bytes or list of bytes
232 259 the zmq.IDENTITY routing path
233 260 track : bool
234 261 whether to track. Only for use with Sockets, because ZMQStream objects cannot track messages.
235 262
236 263 Returns
237 264 -------
238 265 msg : message dict
239 266 the constructed message
240 267 (msg,tracker) : (message dict, MessageTracker)
241 268 if track=True, then a 2-tuple will be returned, the first element being the constructed
242 269 message, and the second being the MessageTracker
243 270
244 271 """
245 272
246 273 if not isinstance(stream, (zmq.Socket, ZMQStream)):
247 274 raise TypeError("stream must be Socket or ZMQStream, not %r"%type(stream))
248 275 elif track and isinstance(stream, ZMQStream):
249 276 raise TypeError("ZMQStream cannot track messages")
250 277
251 278 if isinstance(msg_or_type, (Message, dict)):
252 279 # we got a Message, not a msg_type
253 280 # don't build a new Message
254 281 msg = msg_or_type
255 282 else:
256 283 msg = self.msg(msg_or_type, content, parent, subheader)
257 284
258 285 buffers = [] if buffers is None else buffers
259 286 to_send = self.serialize(msg, ident)
260 287 flag = 0
261 288 if buffers:
262 289 flag = zmq.SNDMORE
263 290 _track = False
264 291 else:
265 292 _track=track
266 293 if track:
267 294 tracker = stream.send_multipart(to_send, flag, copy=False, track=_track)
268 295 else:
269 296 tracker = stream.send_multipart(to_send, flag, copy=False)
270 297 for b in buffers[:-1]:
271 298 stream.send(b, flag, copy=False)
272 299 if buffers:
273 300 if track:
274 301 tracker = stream.send(buffers[-1], copy=False, track=track)
275 302 else:
276 303 tracker = stream.send(buffers[-1], copy=False)
277 304
278 305 # omsg = Message(msg)
279 306 if self.debug:
280 307 pprint.pprint(msg)
281 308 pprint.pprint(to_send)
282 309 pprint.pprint(buffers)
283 310
284 311 msg['tracker'] = tracker
285 312
286 313 return msg
287 314
288 315 def send_raw(self, stream, msg, flags=0, copy=True, ident=None):
289 316 """Send a raw message via ident path.
290 317
291 318 Parameters
292 319 ----------
293 320 msg : list of sendable buffers"""
294 321 to_send = []
295 322 if isinstance(ident, bytes):
296 323 ident = [ident]
297 324 if ident is not None:
298 325 to_send.extend(ident)
299 326 to_send.append(DELIM)
300 if self.key is not None:
327 if self.key:
301 328 to_send.append(self.key)
302 329 to_send.extend(msg)
303 330         stream.send_multipart(to_send, flags, copy=copy)
304 331
305 332 def recv(self, socket, mode=zmq.NOBLOCK, content=True, copy=True):
306 333 """receives and unpacks a message
307 334 returns [idents], msg"""
308 335 if isinstance(socket, ZMQStream):
309 336 socket = socket.socket
310 337 try:
311 338 msg = socket.recv_multipart(mode, copy=copy)
312 339 except zmq.ZMQError as e:
313 340 if e.errno == zmq.EAGAIN:
314 341 # We can convert EAGAIN to None as we know in this case
315 342 # recv_multipart won't return None.
316 343 return None
317 344 else:
318 345 raise
319 346 # return an actual Message object
320 347 # determine the number of idents by trying to unpack them.
321 348 # this is terrible:
322 349 idents, msg = self.feed_identities(msg, copy)
323 350 try:
324 351 return idents, self.unpack_message(msg, content=content, copy=copy)
325 352 except Exception as e:
326 353 print (idents, msg)
327 354 # TODO: handle it
328 355 raise e
329 356
330 357 def feed_identities(self, msg, copy=True):
331 358 """feed until DELIM is reached, then return the prefix as idents and remainder as
332 359 msg. This is easily broken by setting an IDENT to DELIM, but that would be silly.
333 360
334 361 Parameters
335 362 ----------
336 363 msg : a list of Message or bytes objects
337 364 the message to be split
338 365 copy : bool
339 366 flag determining whether the arguments are bytes or Messages
340 367
341 368 Returns
342 369 -------
343 370 (idents,msg) : two lists
344 371             idents will always be a list of bytes - the identity prefix
345 372 msg will be a list of bytes or Messages, unchanged from input
346 373 msg should be unpackable via self.unpack_message at this point.
347 374 """
348 ikey = int(self.key is not None)
375 ikey = int(self.key != '')
349 376 minlen = 3 + ikey
350 377 msg = list(msg)
351 378 idents = []
352 379 while len(msg) > minlen:
353 380 if copy:
354 381 s = msg[0]
355 382 else:
356 383 s = msg[0].bytes
357 384 if s == DELIM:
358 385 msg.pop(0)
359 386 break
360 387 else:
361 388 idents.append(s)
362 389 msg.pop(0)
363 390
364 391 return idents, msg
365 392
366 393 def unpack_message(self, msg, content=True, copy=True):
367 394 """Return a message object from the format
368 395 sent by self.send.
369 396
370 397         Parameters
371 398         ----------
372 399
373 400 content : bool (True)
374 401 whether to unpack the content dict (True),
375 402 or leave it serialized (False)
376 403
377 404 copy : bool (True)
378 405 whether to return the bytes (True),
379 406 or the non-copying Message object in each place (False)
380 407
381 408 """
382 ikey = int(self.key is not None)
409 ikey = int(self.key != '')
383 410 minlen = 3 + ikey
384 411 message = {}
385 412 if not copy:
386 413 for i in range(minlen):
387 414 msg[i] = msg[i].bytes
388 415 if ikey:
389 416 if not self.key == msg[0]:
390 417 raise KeyError("Invalid Session Key: %s"%msg[0])
391 418 if not len(msg) >= minlen:
392 419 raise TypeError("malformed message, must have at least %i elements"%minlen)
393 420 message['header'] = self.unpack(msg[ikey+0])
394 421 message['msg_type'] = message['header']['msg_type']
395 422 message['parent_header'] = self.unpack(msg[ikey+1])
396 423 if content:
397 424 message['content'] = self.unpack(msg[ikey+2])
398 425 else:
399 426 message['content'] = msg[ikey+2]
400 427
401 428 message['buffers'] = msg[ikey+3:]# [ m.buffer for m in msg[3:] ]
402 429 return message
403 430
404 431
405 432 def test_msg2obj():
406 433 am = dict(x=1)
407 434 ao = Message(am)
408 435 assert ao.x == am['x']
409 436
410 437 am['y'] = dict(z=1)
411 438 ao = Message(am)
412 439 assert ao.y.z == am['y']['z']
413 440
414 441 k1, k2 = 'y', 'z'
415 442 assert ao[k1][k2] == am[k1][k2]
416 443
417 444 am2 = dict(ao)
418 445 assert am['x'] == am2['x']
419 446 assert am['y']['z'] == am2['y']['z']
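Lastly, a sketch of the new trait-driven StreamSession construction that replaces the old positional __init__; the key value is invented and no 0MQ sockets are involved:

    # Sketch only: exercising the new configurable traits without any sockets.
    s = StreamSession(username='editor', packer='pickle', key='illustrative-key')

    msg = s.msg('execute_request', content={'code': 'a = 5'})
    assert msg['header']['msg_type'] == 'execute_request'
    assert msg['header']['username'] == 'editor'

    # serialize() inserts the key (when set) right after the DELIM marker
    parts = s.serialize(msg, ident='engine-0')
    assert parts[:3] == ['engine-0', DELIM, 'illustrative-key']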
1 NO CONTENT: file was removed