reorganize Factory classes to follow relocation of Session object
MinRK -
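The diff below drops the old LoggingFactory base class in favor of plain Configurable launchers, and callers now hand the logger object itself to each launcher (`log=self.log`) instead of a logger name (`logname=self.log.name`). A minimal sketch of the new calling convention, using a hypothetical standalone config and logger (applications normally pass their own `self.config` and `self.log`):

    import logging
    from IPython.config.loader import Config
    from IPython.parallel.apps.launcher import LocalEngineSetLauncher

    config = Config()                        # stand-in for the application's config
    log = logging.getLogger('ipcluster')     # stand-in for the application's self.log

    # before this commit: LocalEngineSetLauncher(..., logname=log.name)
    launcher = LocalEngineSetLauncher(work_dir=u'.', config=config, log=log)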
@@ -1,521 +1,521 b''
1 1 #!/usr/bin/env python
2 2 # encoding: utf-8
3 3 """
4 4 The ipcluster application.
5 5 """
6 6
7 7 #-----------------------------------------------------------------------------
8 8 # Copyright (C) 2008-2009 The IPython Development Team
9 9 #
10 10 # Distributed under the terms of the BSD License. The full license is in
11 11 # the file COPYING, distributed as part of this software.
12 12 #-----------------------------------------------------------------------------
13 13
14 14 #-----------------------------------------------------------------------------
15 15 # Imports
16 16 #-----------------------------------------------------------------------------
17 17
18 18 import errno
19 19 import logging
20 20 import os
21 21 import re
22 22 import signal
23 23
24 24 from subprocess import check_call, CalledProcessError, PIPE
25 25 import zmq
26 26 from zmq.eventloop import ioloop
27 27
28 28 from IPython.config.application import Application, boolean_flag
29 29 from IPython.config.loader import Config
30 30 from IPython.core.newapplication import BaseIPythonApplication, ProfileDir
31 31 from IPython.utils.importstring import import_item
32 32 from IPython.utils.traitlets import Int, Unicode, Bool, CFloat, Dict, List
33 33
34 34 from IPython.parallel.apps.baseapp import (
35 35 BaseParallelApplication,
36 36 PIDFileError,
37 37 base_flags, base_aliases
38 38 )
39 39
40 40
41 41 #-----------------------------------------------------------------------------
42 42 # Module level variables
43 43 #-----------------------------------------------------------------------------
44 44
45 45
46 46 default_config_file_name = u'ipcluster_config.py'
47 47
48 48
49 49 _description = """Start an IPython cluster for parallel computing.
50 50
51 51 An IPython cluster consists of 1 controller and 1 or more engines.
52 52 This command automates the startup of these processes using a wide
53 53 range of startup methods (SSH, local processes, PBS, mpiexec,
54 54 Windows HPC Server 2008). To start a cluster with 4 engines on your
55 55 local host simply do 'ipcluster start n=4'. For more complex usage
56 56 you will typically do 'ipcluster create profile=mycluster', then edit
57 57 configuration files, followed by 'ipcluster start profile=mycluster n=4'.
58 58 """
59 59
60 60
61 61 # Exit codes for ipcluster
62 62
63 63 # This will be the exit code if the ipcluster appears to be running because
64 64 # a .pid file exists
65 65 ALREADY_STARTED = 10
66 66
67 67
68 68 # This will be the exit code if ipcluster stop is run, but there is not .pid
69 69 # file to be found.
70 70 ALREADY_STOPPED = 11
71 71
72 72 # This will be the exit code if ipcluster engines is run, but there is not .pid
73 73 # file to be found.
74 74 NO_CLUSTER = 12
75 75
76 76
77 77 #-----------------------------------------------------------------------------
78 78 # Main application
79 79 #-----------------------------------------------------------------------------
80 80 start_help = """Start an IPython cluster for parallel computing
81 81
82 82 Start an ipython cluster by its profile name or cluster
83 83 directory. Cluster directories contain configuration, log and
84 84 security related files and are named using the convention
85 85 'cluster_<profile>' and should be created using the 'create'
86 86 subcommand of 'ipcluster'. If your cluster directory is in
87 87 the cwd or the ipython directory, you can simply refer to it
88 88 using its profile name, `ipcluster start n=4 profile=<profile>`,
89 89 otherwise use the 'profile_dir' option.
90 90 """
91 91 stop_help = """Stop a running IPython cluster
92 92
93 93 Stop a running ipython cluster by its profile name or cluster
94 94 directory. Cluster directories are named using the convention
95 95 'cluster_<profile>'. If your cluster directory is in
96 96 the cwd or the ipython directory, you can simply refer to it
97 97 using its profile name, `ipcluster stop profile=<profile>`, otherwise
98 98 use the 'profile_dir' option.
99 99 """
100 100 engines_help = """Start engines connected to an existing IPython cluster
101 101
102 102 Start one or more engines to connect to an existing Cluster
103 103 by profile name or cluster directory.
104 104 Cluster directories contain configuration, log and
105 105 security related files and are named using the convention
106 106 'cluster_<profile>' and should be created using the 'create'
107 107 subcommand of 'ipcluster'. If your cluster directory is in
108 108 the cwd or the ipython directory, you can simply refer to it
109 109 using its profile name, `ipcluster engines n=4 profile=<profile>`,
110 110 otherwise use the 'profile_dir' option.
111 111 """
112 112 create_help = """Create an ipcluster profile by name
113 113
114 114 Create an ipython cluster directory by its profile name or
115 115 cluster directory path. Cluster directories contain
116 116 configuration, log and security related files and are named
117 117 using the convention 'cluster_<profile>'. By default they are
118 118 located in your ipython directory. Once created, you will
119 119 probably need to edit the configuration files in the cluster
120 120 directory to configure your cluster. Most users will create a
121 121 cluster directory by profile name,
122 122 `ipcluster create profile=mycluster`, which will put the directory
123 123 in `<ipython_dir>/cluster_mycluster`.
124 124 """
125 125 list_help = """List available cluster profiles
126 126
127 127 List all available clusters, by cluster directory, that can
128 128 be found in the current working directory or in the ipython
129 129 directory. Cluster directories are named using the convention
130 130 'cluster_<profile>'.
131 131 """
132 132
133 133
134 134 class IPClusterList(BaseIPythonApplication):
135 135 name = u'ipcluster-list'
136 136 description = list_help
137 137
138 138 # empty aliases
139 139 aliases=Dict()
140 140 flags = Dict(base_flags)
141 141
142 142 def _log_level_default(self):
143 143 return 20
144 144
145 145 def list_profile_dirs(self):
146 146 # Find the search paths
147 147 profile_dir_paths = os.environ.get('IPYTHON_PROFILE_PATH','')
148 148 if profile_dir_paths:
149 149 profile_dir_paths = profile_dir_paths.split(':')
150 150 else:
151 151 profile_dir_paths = []
152 152
153 153 ipython_dir = self.ipython_dir
154 154
155 155 paths = [os.getcwd(), ipython_dir] + profile_dir_paths
156 156 paths = list(set(paths))
157 157
158 158 self.log.info('Searching for cluster profiles in paths: %r' % paths)
159 159 for path in paths:
160 160 files = os.listdir(path)
161 161 for f in files:
162 162 full_path = os.path.join(path, f)
163 163 if os.path.isdir(full_path) and f.startswith('profile_') and \
164 164 os.path.isfile(os.path.join(full_path, 'ipcontroller_config.py')):
165 165 profile = f.split('_')[-1]
166 166 start_cmd = 'ipcluster start profile=%s n=4' % profile
167 167 print start_cmd + " ==> " + full_path
168 168
169 169 def start(self):
170 170 self.list_profile_dirs()
171 171
172 172
173 173 # `ipcluster create` will be deprecated when `ipython profile create` or equivalent exists
174 174
175 175 create_flags = {}
176 176 create_flags.update(base_flags)
177 177 create_flags.update(boolean_flag('reset', 'IPClusterCreate.overwrite',
178 178 "reset config files to defaults", "leave existing config files"))
179 179
180 180 class IPClusterCreate(BaseParallelApplication):
181 181 name = u'ipcluster-create'
182 182 description = create_help
183 183 auto_create = Bool(True)
184 184 config_file_name = Unicode(default_config_file_name)
185 185
186 186 flags = Dict(create_flags)
187 187
188 188 aliases = Dict(dict(profile='BaseIPythonApplication.profile'))
189 189
190 190 classes = [ProfileDir]
191 191
192 192
193 193 stop_aliases = dict(
194 194 signal='IPClusterStop.signal',
195 195 profile='BaseIPythonApplication.profile',
196 196 profile_dir='ProfileDir.location',
197 197 )
198 198
199 199 class IPClusterStop(BaseParallelApplication):
200 200 name = u'ipcluster'
201 201 description = stop_help
202 202 config_file_name = Unicode(default_config_file_name)
203 203
204 204 signal = Int(signal.SIGINT, config=True,
205 205 help="signal to use for stopping processes.")
206 206
207 207 aliases = Dict(stop_aliases)
208 208
209 209 def start(self):
210 210 """Start the app for the stop subcommand."""
211 211 try:
212 212 pid = self.get_pid_from_file()
213 213 except PIDFileError:
214 214 self.log.critical(
215 215 'Could not read pid file, cluster is probably not running.'
216 216 )
217 217 # Here I exit with an unusual exit status that other processes
218 218 # can watch for to learn how I exited.
219 219 self.remove_pid_file()
220 220 self.exit(ALREADY_STOPPED)
221 221
222 222 if not self.check_pid(pid):
223 223 self.log.critical(
224 224 'Cluster [pid=%r] is not running.' % pid
225 225 )
226 226 self.remove_pid_file()
227 227 # Here I exit with an unusual exit status that other processes
228 228 # can watch for to learn how I exited.
229 229 self.exit(ALREADY_STOPPED)
230 230
231 231 elif os.name=='posix':
232 232 sig = self.signal
233 233 self.log.info(
234 234 "Stopping cluster [pid=%r] with [signal=%r]" % (pid, sig)
235 235 )
236 236 try:
237 237 os.kill(pid, sig)
238 238 except OSError:
239 239 self.log.error("Stopping cluster failed, assuming already dead.",
240 240 exc_info=True)
241 241 self.remove_pid_file()
242 242 elif os.name=='nt':
243 243 try:
244 244 # kill the whole tree
245 245 p = check_call(['taskkill', '-pid', str(pid), '-t', '-f'], stdout=PIPE,stderr=PIPE)
246 246 except (CalledProcessError, OSError):
247 247 self.log.error("Stopping cluster failed, assuming already dead.",
248 248 exc_info=True)
249 249 self.remove_pid_file()
250 250
251 251 engine_aliases = {}
252 252 engine_aliases.update(base_aliases)
253 253 engine_aliases.update(dict(
254 254 n='IPClusterEngines.n',
255 255 elauncher = 'IPClusterEngines.engine_launcher_class',
256 256 ))
257 257 class IPClusterEngines(BaseParallelApplication):
258 258
259 259 name = u'ipcluster'
260 260 description = engines_help
261 261 usage = None
262 262 config_file_name = Unicode(default_config_file_name)
263 263 default_log_level = logging.INFO
264 264 classes = List()
265 265 def _classes_default(self):
266 266 from IPython.parallel.apps import launcher
267 267 launchers = launcher.all_launchers
268 268 eslaunchers = [ l for l in launchers if 'EngineSet' in l.__name__]
269 269 return [ProfileDir]+eslaunchers
270 270
271 271 n = Int(2, config=True,
272 272 help="The number of engines to start.")
273 273
274 274 engine_launcher_class = Unicode('LocalEngineSetLauncher',
275 275 config=True,
276 276 help="The class for launching a set of Engines."
277 277 )
278 278 daemonize = Bool(False, config=True,
279 279 help='Daemonize the ipcluster program. This implies --log-to-file')
280 280
281 281 def _daemonize_changed(self, name, old, new):
282 282 if new:
283 283 self.log_to_file = True
284 284
285 285 aliases = Dict(engine_aliases)
286 286 # flags = Dict(flags)
287 287 _stopping = False
288 288
289 289 def initialize(self, argv=None):
290 290 super(IPClusterEngines, self).initialize(argv)
291 291 self.init_signal()
292 292 self.init_launchers()
293 293
294 294 def init_launchers(self):
295 295 self.engine_launcher = self.build_launcher(self.engine_launcher_class)
296 296 self.engine_launcher.on_stop(lambda r: self.loop.stop())
297 297
298 298 def init_signal(self):
299 299 # Setup signals
300 300 signal.signal(signal.SIGINT, self.sigint_handler)
301 301
302 302 def build_launcher(self, clsname):
303 303 """import and instantiate a Launcher based on importstring"""
304 304 if '.' not in clsname:
305 305 # not a module, presume it's the raw name in apps.launcher
306 306 clsname = 'IPython.parallel.apps.launcher.'+clsname
307 307 # print repr(clsname)
308 308 klass = import_item(clsname)
309 309
310 310 launcher = klass(
311 work_dir=self.profile_dir.location, config=self.config, logname=self.log.name
311 work_dir=self.profile_dir.location, config=self.config, log=self.log
312 312 )
313 313 return launcher
314 314
315 315 def start_engines(self):
316 316 self.log.info("Starting %i engines"%self.n)
317 317 self.engine_launcher.start(
318 318 self.n,
319 319 self.profile_dir.location
320 320 )
321 321
322 322 def stop_engines(self):
323 323 self.log.info("Stopping Engines...")
324 324 if self.engine_launcher.running:
325 325 d = self.engine_launcher.stop()
326 326 return d
327 327 else:
328 328 return None
329 329
330 330 def stop_launchers(self, r=None):
331 331 if not self._stopping:
332 332 self._stopping = True
333 333 self.log.error("IPython cluster: stopping")
334 334 self.stop_engines()
335 335 # Wait a few seconds to let things shut down.
336 336 dc = ioloop.DelayedCallback(self.loop.stop, 4000, self.loop)
337 337 dc.start()
338 338
339 339 def sigint_handler(self, signum, frame):
340 340 self.log.debug("SIGINT received, stopping launchers...")
341 341 self.stop_launchers()
342 342
343 343 def start_logging(self):
344 344 # Remove old log files of the controller and engine
345 345 if self.clean_logs:
346 346 log_dir = self.profile_dir.log_dir
347 347 for f in os.listdir(log_dir):
348 348 if re.match(r'ip(engine|controller)z-\d+\.(log|err|out)',f):
349 349 os.remove(os.path.join(log_dir, f))
350 350 # This will remove old log files for ipcluster itself
351 351 # super(IPBaseParallelApplication, self).start_logging()
352 352
353 353 def start(self):
354 354 """Start the app for the engines subcommand."""
355 355 self.log.info("IPython cluster: started")
356 356 # First see if the cluster is already running
357 357
358 358 # Now log and daemonize
359 359 self.log.info(
360 360 'Starting engines with [daemon=%r]' % self.daemonize
361 361 )
362 362 # TODO: Get daemonize working on Windows or as a Windows Server.
363 363 if self.daemonize:
364 364 if os.name=='posix':
365 365 from twisted.scripts._twistd_unix import daemonize
366 366 daemonize()
367 367
368 368 dc = ioloop.DelayedCallback(self.start_engines, 0, self.loop)
369 369 dc.start()
370 370 # Now write the new pid file AFTER our new forked pid is active.
371 371 # self.write_pid_file()
372 372 try:
373 373 self.loop.start()
374 374 except KeyboardInterrupt:
375 375 pass
376 376 except zmq.ZMQError as e:
377 377 if e.errno == errno.EINTR:
378 378 pass
379 379 else:
380 380 raise
381 381
382 382 start_aliases = {}
383 383 start_aliases.update(engine_aliases)
384 384 start_aliases.update(dict(
385 385 delay='IPClusterStart.delay',
386 386 clean_logs='IPClusterStart.clean_logs',
387 387 ))
388 388
389 389 class IPClusterStart(IPClusterEngines):
390 390
391 391 name = u'ipcluster'
392 392 description = start_help
393 393 default_log_level = logging.INFO
394 394 auto_create = Bool(True, config=True,
395 395 help="whether to create the profile_dir if it doesn't exist")
396 396 classes = List()
397 397 def _classes_default(self,):
398 398 from IPython.parallel.apps import launcher
399 399 return [ProfileDir]+launcher.all_launchers
400 400
401 401 clean_logs = Bool(True, config=True,
402 402 help="whether to cleanup old logs before starting")
403 403
404 404 delay = CFloat(1., config=True,
405 405 help="delay (in s) between starting the controller and the engines")
406 406
407 407 controller_launcher_class = Unicode('LocalControllerLauncher',
408 408 config=True,
409 409 help="The class for launching a Controller."
410 410 )
411 411 reset = Bool(False, config=True,
412 412 help="Whether to reset config files as part of '--create'."
413 413 )
414 414
415 415 # flags = Dict(flags)
416 416 aliases = Dict(start_aliases)
417 417
418 418 def init_launchers(self):
419 419 self.controller_launcher = self.build_launcher(self.controller_launcher_class)
420 420 self.engine_launcher = self.build_launcher(self.engine_launcher_class)
421 421 self.controller_launcher.on_stop(self.stop_launchers)
422 422
423 423 def start_controller(self):
424 424 self.controller_launcher.start(
425 425 self.profile_dir.location
426 426 )
427 427
428 428 def stop_controller(self):
429 429 # self.log.info("In stop_controller")
430 430 if self.controller_launcher and self.controller_launcher.running:
431 431 return self.controller_launcher.stop()
432 432
433 433 def stop_launchers(self, r=None):
434 434 if not self._stopping:
435 435 self.stop_controller()
436 436 super(IPClusterStart, self).stop_launchers()
437 437
438 438 def start(self):
439 439 """Start the app for the start subcommand."""
440 440 # First see if the cluster is already running
441 441 try:
442 442 pid = self.get_pid_from_file()
443 443 except PIDFileError:
444 444 pass
445 445 else:
446 446 if self.check_pid(pid):
447 447 self.log.critical(
448 448 'Cluster is already running with [pid=%s]. '
449 449 'Use "ipcluster stop" to stop the cluster.' % pid
450 450 )
451 451 # Here I exit with an unusual exit status that other processes
452 452 # can watch for to learn how I exited.
453 453 self.exit(ALREADY_STARTED)
454 454 else:
455 455 self.remove_pid_file()
456 456
457 457
458 458 # Now log and daemonize
459 459 self.log.info(
460 460 'Starting ipcluster with [daemon=%r]' % self.daemonize
461 461 )
462 462 # TODO: Get daemonize working on Windows or as a Windows Server.
463 463 if self.daemonize:
464 464 if os.name=='posix':
465 465 from twisted.scripts._twistd_unix import daemonize
466 466 daemonize()
467 467
468 468 dc = ioloop.DelayedCallback(self.start_controller, 0, self.loop)
469 469 dc.start()
470 470 dc = ioloop.DelayedCallback(self.start_engines, 1000*self.delay, self.loop)
471 471 dc.start()
472 472 # Now write the new pid file AFTER our new forked pid is active.
473 473 self.write_pid_file()
474 474 try:
475 475 self.loop.start()
476 476 except KeyboardInterrupt:
477 477 pass
478 478 except zmq.ZMQError as e:
479 479 if e.errno == errno.EINTR:
480 480 pass
481 481 else:
482 482 raise
483 483 finally:
484 484 self.remove_pid_file()
485 485
486 486 base='IPython.parallel.apps.ipclusterapp.IPCluster'
487 487
488 488 class IPBaseParallelApplication(Application):
489 489 name = u'ipcluster'
490 490 description = _description
491 491
492 492 subcommands = {'create' : (base+'Create', create_help),
493 493 'list' : (base+'List', list_help),
494 494 'start' : (base+'Start', start_help),
495 495 'stop' : (base+'Stop', stop_help),
496 496 'engines' : (base+'Engines', engines_help),
497 497 }
498 498
499 499 # no aliases or flags for parent App
500 500 aliases = Dict()
501 501 flags = Dict()
502 502
503 503 def start(self):
504 504 if self.subapp is None:
505 505 print "No subcommand specified! Must specify one of: %s"%(self.subcommands.keys())
506 506 print
507 507 self.print_subcommands()
508 508 self.exit(1)
509 509 else:
510 510 return self.subapp.start()
511 511
512 512 def launch_new_instance():
513 513 """Create and run the IPython cluster."""
514 514 app = IPBaseParallelApplication.instance()
515 515 app.initialize()
516 516 app.start()
517 517
518 518
519 519 if __name__ == '__main__':
520 520 launch_new_instance()
521 521
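The launcher class names configured above (`engine_launcher_class`, `controller_launcher_class`) are resolved by `build_launcher`: a bare name is prefixed with `IPython.parallel.apps.launcher.` before being imported. A sketch of an `ipcluster_config.py` fragment selecting non-default launchers (values are illustrative; the class names appear in the launcher module diffed below):

    # ipcluster_config.py -- illustrative values only
    c = get_config()

    c.IPClusterEngines.n = 8
    c.IPClusterEngines.engine_launcher_class = 'MPIExecEngineSetLauncher'
    c.IPClusterStart.controller_launcher_class = 'PBSControllerLauncher'
    c.IPClusterStart.delay = 2.0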
@@ -1,96 +1,96 b''
1 1 #!/usr/bin/env python
2 2 # encoding: utf-8
3 3 """
4 4 A simple IPython logger application
5 5 """
6 6
7 7 #-----------------------------------------------------------------------------
8 8 # Copyright (C) 2011 The IPython Development Team
9 9 #
10 10 # Distributed under the terms of the BSD License. The full license is in
11 11 # the file COPYING, distributed as part of this software.
12 12 #-----------------------------------------------------------------------------
13 13
14 14 #-----------------------------------------------------------------------------
15 15 # Imports
16 16 #-----------------------------------------------------------------------------
17 17
18 18 import os
19 19 import sys
20 20
21 21 import zmq
22 22
23 23 from IPython.core.newapplication import ProfileDir
24 24 from IPython.utils.traitlets import Bool, Dict, Unicode
25 25
26 26 from IPython.parallel.apps.baseapp import (
27 27 BaseParallelApplication,
28 28 base_aliases
29 29 )
30 30 from IPython.parallel.apps.logwatcher import LogWatcher
31 31
32 32 #-----------------------------------------------------------------------------
33 33 # Module level variables
34 34 #-----------------------------------------------------------------------------
35 35
36 36 #: The default config file name for this application
37 37 default_config_file_name = u'iplogger_config.py'
38 38
39 39 _description = """Start an IPython logger for parallel computing.
40 40
41 41 IPython controllers and engines (and your own processes) can broadcast log messages
42 42 by registering a `zmq.log.handlers.PUBHandler` with the `logging` module. The
43 43 logger can be configured using command line options or using a cluster
44 44 directory. Cluster directories contain config, log and security files and are
45 45 usually located in your ipython directory and named as "cluster_<profile>".
46 46 See the `profile` and `profile_dir` options for details.
47 47 """
48 48
49 49
50 50 #-----------------------------------------------------------------------------
51 51 # Main application
52 52 #-----------------------------------------------------------------------------
53 53 aliases = {}
54 54 aliases.update(base_aliases)
55 55 aliases.update(dict(url='LogWatcher.url', topics='LogWatcher.topics'))
56 56
57 57 class IPLoggerApp(BaseParallelApplication):
58 58
59 59 name = u'iploggerz'
60 60 description = _description
61 61 config_file_name = Unicode(default_config_file_name)
62 62
63 63 classes = [LogWatcher, ProfileDir]
64 64 aliases = Dict(aliases)
65 65
66 66 def initialize(self, argv=None):
67 67 super(IPLoggerApp, self).initialize(argv)
68 68 self.init_watcher()
69 69
70 70 def init_watcher(self):
71 71 try:
72 self.watcher = LogWatcher(config=self.config, logname=self.log.name)
72 self.watcher = LogWatcher(config=self.config, log=self.log)
73 73 except:
74 74 self.log.error("Couldn't start the LogWatcher", exc_info=True)
75 75 self.exit(1)
76 76 self.log.info("Listening for log messages on %r"%self.watcher.url)
77 77
78 78
79 79 def start(self):
80 80 self.watcher.start()
81 81 try:
82 82 self.watcher.loop.start()
83 83 except KeyboardInterrupt:
84 84 self.log.critical("Logging Interrupted, shutting down...\n")
85 85
86 86
87 87 def launch_new_instance():
88 88 """Create and run the IPython LogWatcher"""
89 89 app = IPLoggerApp.instance()
90 90 app.initialize()
91 91 app.start()
92 92
93 93
94 94 if __name__ == '__main__':
95 95 launch_new_instance()
96 96
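As the iplogger description notes, any process can broadcast to the LogWatcher by registering a `zmq.log.handlers.PUBHandler` with the `logging` module. A minimal sketch, assuming an iplogger is listening on the given (hypothetical) URL:

    import logging
    import zmq
    from zmq.log.handlers import PUBHandler

    ctx = zmq.Context()
    sock = ctx.socket(zmq.PUB)
    sock.connect('tcp://127.0.0.1:20202')   # hypothetical LogWatcher.url

    handler = PUBHandler(sock)
    handler.root_topic = 'engine'           # topic the LogWatcher can subscribe to

    log = logging.getLogger('myapp')
    log.addHandler(handler)
    log.setLevel(logging.INFO)
    log.info('hello from an external process')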
@@ -1,1070 +1,1069 b''
1 1 #!/usr/bin/env python
2 2 # encoding: utf-8
3 3 """
4 4 Facilities for launching IPython processes asynchronously.
5 5 """
6 6
7 7 #-----------------------------------------------------------------------------
8 8 # Copyright (C) 2008-2009 The IPython Development Team
9 9 #
10 10 # Distributed under the terms of the BSD License. The full license is in
11 11 # the file COPYING, distributed as part of this software.
12 12 #-----------------------------------------------------------------------------
13 13
14 14 #-----------------------------------------------------------------------------
15 15 # Imports
16 16 #-----------------------------------------------------------------------------
17 17
18 18 import copy
19 19 import logging
20 20 import os
21 21 import re
22 22 import stat
23 23
24 24 # signal imports, handling various platforms, versions
25 25
26 26 from signal import SIGINT, SIGTERM
27 27 try:
28 28 from signal import SIGKILL
29 29 except ImportError:
30 30 # Windows
31 31 SIGKILL=SIGTERM
32 32
33 33 try:
34 34 # Windows >= 2.7, 3.2
35 35 from signal import CTRL_C_EVENT as SIGINT
36 36 except ImportError:
37 37 pass
38 38
39 39 from subprocess import Popen, PIPE, STDOUT
40 40 try:
41 41 from subprocess import check_output
42 42 except ImportError:
43 43 # pre-2.7, define check_output with Popen
44 44 def check_output(*args, **kwargs):
45 45 kwargs.update(dict(stdout=PIPE))
46 46 p = Popen(*args, **kwargs)
47 47 out,err = p.communicate()
48 48 return out
49 49
50 50 from zmq.eventloop import ioloop
51 51
52 # from IPython.config.configurable import Configurable
52 from IPython.config.configurable import Configurable
53 53 from IPython.utils.text import EvalFormatter
54 54 from IPython.utils.traitlets import Any, Int, List, Unicode, Dict, Instance
55 55 from IPython.utils.path import get_ipython_module_path
56 56 from IPython.utils.process import find_cmd, pycmd2argv, FindCmdError
57 57
58 from IPython.parallel.factory import LoggingFactory
59
60 58 from .win32support import forward_read_events
61 59
62 60 from .winhpcjob import IPControllerTask, IPEngineTask, IPControllerJob, IPEngineSetJob
63 61
64 62 WINDOWS = os.name == 'nt'
65 63
66 64 #-----------------------------------------------------------------------------
67 65 # Paths to the kernel apps
68 66 #-----------------------------------------------------------------------------
69 67
70 68
71 69 ipcluster_cmd_argv = pycmd2argv(get_ipython_module_path(
72 70 'IPython.parallel.apps.ipclusterapp'
73 71 ))
74 72
75 73 ipengine_cmd_argv = pycmd2argv(get_ipython_module_path(
76 74 'IPython.parallel.apps.ipengineapp'
77 75 ))
78 76
79 77 ipcontroller_cmd_argv = pycmd2argv(get_ipython_module_path(
80 78 'IPython.parallel.apps.ipcontrollerapp'
81 79 ))
82 80
83 81 #-----------------------------------------------------------------------------
84 82 # Base launchers and errors
85 83 #-----------------------------------------------------------------------------
86 84
87 85
88 86 class LauncherError(Exception):
89 87 pass
90 88
91 89
92 90 class ProcessStateError(LauncherError):
93 91 pass
94 92
95 93
96 94 class UnknownStatus(LauncherError):
97 95 pass
98 96
99 97
100 class BaseLauncher(LoggingFactory):
98 class BaseLauncher(Configurable):
101 99 """An asbtraction for starting, stopping and signaling a process."""
102 100
103 101 # In all of the launchers, the work_dir is where child processes will be
104 102 # run. This will usually be the profile_dir, but may not be. Any work_dir
105 103 # passed into the __init__ method will override the config value.
106 104 # This should not be used to set the work_dir for the actual engine
107 105 # and controller. Instead, use their own config files or the
108 106 # controller_args, engine_args attributes of the launchers to add
109 107 # the work_dir option.
110 108 work_dir = Unicode(u'.')
111 109 loop = Instance('zmq.eventloop.ioloop.IOLoop')
110 log = Instance('logging.Logger', ('root',))
112 111
113 112 start_data = Any()
114 113 stop_data = Any()
115 114
116 115 def _loop_default(self):
117 116 return ioloop.IOLoop.instance()
118 117
119 118 def __init__(self, work_dir=u'.', config=None, **kwargs):
120 119 super(BaseLauncher, self).__init__(work_dir=work_dir, config=config, **kwargs)
121 120 self.state = 'before' # can be before, running, after
122 121 self.stop_callbacks = []
123 122 self.start_data = None
124 123 self.stop_data = None
125 124
126 125 @property
127 126 def args(self):
128 127 """A list of cmd and args that will be used to start the process.
129 128
130 129 This is what is passed to :func:`spawnProcess` and the first element
131 130 will be the process name.
132 131 """
133 132 return self.find_args()
134 133
135 134 def find_args(self):
136 135 """The ``.args`` property calls this to find the args list.
137 136
138 137 Subclasses should implement this to construct the cmd and args.
139 138 """
140 139 raise NotImplementedError('find_args must be implemented in a subclass')
141 140
142 141 @property
143 142 def arg_str(self):
144 143 """The string form of the program arguments."""
145 144 return ' '.join(self.args)
146 145
147 146 @property
148 147 def running(self):
149 148 """Am I running."""
150 149 if self.state == 'running':
151 150 return True
152 151 else:
153 152 return False
154 153
155 154 def start(self):
156 155 """Start the process.
157 156
158 157 This must return a deferred that fires with information about the
159 158 process starting (like a pid, job id, etc.).
160 159 """
161 160 raise NotImplementedError('start must be implemented in a subclass')
162 161
163 162 def stop(self):
164 163 """Stop the process and notify observers of stopping.
165 164
166 165 This must return a deferred that fires with information about the
167 166 processing stopping, like errors that occur while the process is
168 167 attempting to be shut down. This deferred won't fire when the process
169 168 actually stops. To observe the actual process stopping, see
170 169 :func:`observe_stop`.
171 170 """
172 171 raise NotImplementedError('stop must be implemented in a subclass')
173 172
174 173 def on_stop(self, f):
175 174 """Get a deferred that will fire when the process stops.
176 175
177 176 The deferred will fire with data that contains information about
178 177 the exit status of the process.
179 178 """
180 179 if self.state=='after':
181 180 return f(self.stop_data)
182 181 else:
183 182 self.stop_callbacks.append(f)
184 183
185 184 def notify_start(self, data):
186 185 """Call this to trigger startup actions.
187 186
188 187 This logs the process startup and sets the state to 'running'. It is
189 188 a pass-through so it can be used as a callback.
190 189 """
191 190
192 191 self.log.info('Process %r started: %r' % (self.args[0], data))
193 192 self.start_data = data
194 193 self.state = 'running'
195 194 return data
196 195
197 196 def notify_stop(self, data):
198 197 """Call this to trigger process stop actions.
199 198
200 199 This logs the process stopping and sets the state to 'after'. Call
201 200 this to trigger all the deferreds from :func:`observe_stop`."""
202 201
203 202 self.log.info('Process %r stopped: %r' % (self.args[0], data))
204 203 self.stop_data = data
205 204 self.state = 'after'
206 205 for i in range(len(self.stop_callbacks)):
207 206 d = self.stop_callbacks.pop()
208 207 d(data)
209 208 return data
210 209
211 210 def signal(self, sig):
212 211 """Signal the process.
213 212
214 213 Return a semi-meaningless deferred after signaling the process.
215 214
216 215 Parameters
217 216 ----------
218 217 sig : str or int
219 218 'KILL', 'INT', etc., or any signal number
220 219 """
221 220 raise NotImplementedError('signal must be implemented in a subclass')
222 221
223 222
224 223 #-----------------------------------------------------------------------------
225 224 # Local process launchers
226 225 #-----------------------------------------------------------------------------
227 226
228 227
229 228 class LocalProcessLauncher(BaseLauncher):
230 229 """Start and stop an external process in an asynchronous manner.
231 230
232 231 This will launch the external process with a working directory of
233 232 ``self.work_dir``.
234 233 """
235 234
236 235 # This is used to construct self.args, which is passed to
237 236 # spawnProcess.
238 237 cmd_and_args = List([])
239 238 poll_frequency = Int(100) # in ms
240 239
241 240 def __init__(self, work_dir=u'.', config=None, **kwargs):
242 241 super(LocalProcessLauncher, self).__init__(
243 242 work_dir=work_dir, config=config, **kwargs
244 243 )
245 244 self.process = None
246 245 self.start_deferred = None
247 246 self.poller = None
248 247
249 248 def find_args(self):
250 249 return self.cmd_and_args
251 250
252 251 def start(self):
253 252 if self.state == 'before':
254 253 self.process = Popen(self.args,
255 254 stdout=PIPE,stderr=PIPE,stdin=PIPE,
256 255 env=os.environ,
257 256 cwd=self.work_dir
258 257 )
259 258 if WINDOWS:
260 259 self.stdout = forward_read_events(self.process.stdout)
261 260 self.stderr = forward_read_events(self.process.stderr)
262 261 else:
263 262 self.stdout = self.process.stdout.fileno()
264 263 self.stderr = self.process.stderr.fileno()
265 264 self.loop.add_handler(self.stdout, self.handle_stdout, self.loop.READ)
266 265 self.loop.add_handler(self.stderr, self.handle_stderr, self.loop.READ)
267 266 self.poller = ioloop.PeriodicCallback(self.poll, self.poll_frequency, self.loop)
268 267 self.poller.start()
269 268 self.notify_start(self.process.pid)
270 269 else:
271 270 s = 'The process was already started and has state: %r' % self.state
272 271 raise ProcessStateError(s)
273 272
274 273 def stop(self):
275 274 return self.interrupt_then_kill()
276 275
277 276 def signal(self, sig):
278 277 if self.state == 'running':
279 278 if WINDOWS and sig != SIGINT:
280 279 # use Windows tree-kill for better child cleanup
281 280 check_output(['taskkill', '-pid', str(self.process.pid), '-t', '-f'])
282 281 else:
283 282 self.process.send_signal(sig)
284 283
285 284 def interrupt_then_kill(self, delay=2.0):
286 285 """Send INT, wait a delay and then send KILL."""
287 286 try:
288 287 self.signal(SIGINT)
289 288 except Exception:
290 289 self.log.debug("interrupt failed")
291 290 pass
292 291 self.killer = ioloop.DelayedCallback(lambda : self.signal(SIGKILL), delay*1000, self.loop)
293 292 self.killer.start()
294 293
295 294 # callbacks, etc:
296 295
297 296 def handle_stdout(self, fd, events):
298 297 if WINDOWS:
299 298 line = self.stdout.recv()
300 299 else:
301 300 line = self.process.stdout.readline()
302 301 # a stopped process will be readable but return empty strings
303 302 if line:
304 303 self.log.info(line[:-1])
305 304 else:
306 305 self.poll()
307 306
308 307 def handle_stderr(self, fd, events):
309 308 if WINDOWS:
310 309 line = self.stderr.recv()
311 310 else:
312 311 line = self.process.stderr.readline()
313 312 # a stopped process will be readable but return empty strings
314 313 if line:
315 314 self.log.error(line[:-1])
316 315 else:
317 316 self.poll()
318 317
319 318 def poll(self):
320 319 status = self.process.poll()
321 320 if status is not None:
322 321 self.poller.stop()
323 322 self.loop.remove_handler(self.stdout)
324 323 self.loop.remove_handler(self.stderr)
325 324 self.notify_stop(dict(exit_code=status, pid=self.process.pid))
326 325 return status
327 326
328 327 class LocalControllerLauncher(LocalProcessLauncher):
329 328 """Launch a controller as a regular external process."""
330 329
331 330 controller_cmd = List(ipcontroller_cmd_argv, config=True,
332 331 help="""Popen command to launch ipcontroller.""")
333 332 # Command line arguments to ipcontroller.
334 333 controller_args = List(['--log-to-file','log_level=%i'%logging.INFO], config=True,
335 334 help="""command-line args to pass to ipcontroller""")
336 335
337 336 def find_args(self):
338 337 return self.controller_cmd + self.controller_args
339 338
340 339 def start(self, profile_dir):
341 340 """Start the controller by profile_dir."""
342 341 self.controller_args.extend(['profile_dir=%s'%profile_dir])
343 342 self.profile_dir = unicode(profile_dir)
344 343 self.log.info("Starting LocalControllerLauncher: %r" % self.args)
345 344 return super(LocalControllerLauncher, self).start()
346 345
347 346
348 347 class LocalEngineLauncher(LocalProcessLauncher):
349 348 """Launch a single engine as a regular externall process."""
350 349
351 350 engine_cmd = List(ipengine_cmd_argv, config=True,
352 351 help="""command to launch the Engine.""")
353 352 # Command line arguments for ipengine.
354 353 engine_args = List(['--log-to-file','log_level=%i'%logging.INFO], config=True,
355 354 help="command-line arguments to pass to ipengine"
356 355 )
357 356
358 357 def find_args(self):
359 358 return self.engine_cmd + self.engine_args
360 359
361 360 def start(self, profile_dir):
362 361 """Start the engine by profile_dir."""
363 362 self.engine_args.extend(['profile_dir=%s'%profile_dir])
364 363 self.profile_dir = unicode(profile_dir)
365 364 return super(LocalEngineLauncher, self).start()
366 365
367 366
368 367 class LocalEngineSetLauncher(BaseLauncher):
369 368 """Launch a set of engines as regular external processes."""
370 369
371 370 # Command line arguments for ipengine.
372 371 engine_args = List(
373 372 ['--log-to-file','log_level=%i'%logging.INFO], config=True,
374 373 help="command-line arguments to pass to ipengine"
375 374 )
376 375 # launcher class
377 376 launcher_class = LocalEngineLauncher
378 377
379 378 launchers = Dict()
380 379 stop_data = Dict()
381 380
382 381 def __init__(self, work_dir=u'.', config=None, **kwargs):
383 382 super(LocalEngineSetLauncher, self).__init__(
384 383 work_dir=work_dir, config=config, **kwargs
385 384 )
386 385 self.stop_data = {}
387 386
388 387 def start(self, n, profile_dir):
389 388 """Start n engines by profile or profile_dir."""
390 389 self.profile_dir = unicode(profile_dir)
391 390 dlist = []
392 391 for i in range(n):
393 el = self.launcher_class(work_dir=self.work_dir, config=self.config, logname=self.log.name)
392 el = self.launcher_class(work_dir=self.work_dir, config=self.config, log=self.log)
394 393 # Copy the engine args over to each engine launcher.
395 394 el.engine_args = copy.deepcopy(self.engine_args)
396 395 el.on_stop(self._notice_engine_stopped)
397 396 d = el.start(profile_dir)
398 397 if i==0:
399 398 self.log.info("Starting LocalEngineSetLauncher: %r" % el.args)
400 399 self.launchers[i] = el
401 400 dlist.append(d)
402 401 self.notify_start(dlist)
403 402 # The consumeErrors here could be dangerous
404 403 # dfinal = gatherBoth(dlist, consumeErrors=True)
405 404 # dfinal.addCallback(self.notify_start)
406 405 return dlist
407 406
408 407 def find_args(self):
409 408 return ['engine set']
410 409
411 410 def signal(self, sig):
412 411 dlist = []
413 412 for el in self.launchers.itervalues():
414 413 d = el.signal(sig)
415 414 dlist.append(d)
416 415 # dfinal = gatherBoth(dlist, consumeErrors=True)
417 416 return dlist
418 417
419 418 def interrupt_then_kill(self, delay=1.0):
420 419 dlist = []
421 420 for el in self.launchers.itervalues():
422 421 d = el.interrupt_then_kill(delay)
423 422 dlist.append(d)
424 423 # dfinal = gatherBoth(dlist, consumeErrors=True)
425 424 return dlist
426 425
427 426 def stop(self):
428 427 return self.interrupt_then_kill()
429 428
430 429 def _notice_engine_stopped(self, data):
431 430 pid = data['pid']
432 431 for idx,el in self.launchers.iteritems():
433 432 if el.process.pid == pid:
434 433 break
435 434 self.launchers.pop(idx)
436 435 self.stop_data[idx] = data
437 436 if not self.launchers:
438 437 self.notify_stop(self.stop_data)
439 438
440 439
441 440 #-----------------------------------------------------------------------------
442 441 # MPIExec launchers
443 442 #-----------------------------------------------------------------------------
444 443
445 444
446 445 class MPIExecLauncher(LocalProcessLauncher):
447 446 """Launch an external process using mpiexec."""
448 447
449 448 mpi_cmd = List(['mpiexec'], config=True,
450 449 help="The mpiexec command to use in starting the process."
451 450 )
452 451 mpi_args = List([], config=True,
453 452 help="The command line arguments to pass to mpiexec."
454 453 )
455 454 program = List(['date'], config=True,
456 455 help="The program to start via mpiexec.")
457 456 program_args = List([], config=True,
458 457 help="The command line argument to the program."
459 458 )
460 459 n = Int(1)
461 460
462 461 def find_args(self):
463 462 """Build self.args using all the fields."""
464 463 return self.mpi_cmd + ['-n', str(self.n)] + self.mpi_args + \
465 464 self.program + self.program_args
466 465
467 466 def start(self, n):
468 467 """Start n instances of the program using mpiexec."""
469 468 self.n = n
470 469 return super(MPIExecLauncher, self).start()
471 470
472 471
473 472 class MPIExecControllerLauncher(MPIExecLauncher):
474 473 """Launch a controller using mpiexec."""
475 474
476 475 controller_cmd = List(ipcontroller_cmd_argv, config=True,
477 476 help="Popen command to launch the Contropper"
478 477 )
479 478 controller_args = List(['--log-to-file','log_level=%i'%logging.INFO], config=True,
480 479 help="Command line arguments to pass to ipcontroller."
481 480 )
482 481 n = Int(1)
483 482
484 483 def start(self, profile_dir):
485 484 """Start the controller by profile_dir."""
486 485 self.controller_args.extend(['profile_dir=%s'%profile_dir])
487 486 self.profile_dir = unicode(profile_dir)
488 487 self.log.info("Starting MPIExecControllerLauncher: %r" % self.args)
489 488 return super(MPIExecControllerLauncher, self).start(1)
490 489
491 490 def find_args(self):
492 491 return self.mpi_cmd + ['-n', str(self.n)] + self.mpi_args + \
493 492 self.controller_cmd + self.controller_args
494 493
495 494
496 495 class MPIExecEngineSetLauncher(MPIExecLauncher):
497 496
498 497 program = List(ipengine_cmd_argv, config=True,
499 498 help="Popen command for ipengine"
500 499 )
501 500 program_args = List(
502 501 ['--log-to-file','log_level=%i'%logging.INFO], config=True,
503 502 help="Command line arguments for ipengine."
504 503 )
505 504 n = Int(1)
506 505
507 506 def start(self, n, profile_dir):
508 507 """Start n engines by profile or profile_dir."""
509 508 self.program_args.extend(['profile_dir=%s'%profile_dir])
510 509 self.profile_dir = unicode(profile_dir)
511 510 self.n = n
512 511 self.log.info('Starting MPIExecEngineSetLauncher: %r' % self.args)
513 512 return super(MPIExecEngineSetLauncher, self).start(n)
514 513
515 514 #-----------------------------------------------------------------------------
516 515 # SSH launchers
517 516 #-----------------------------------------------------------------------------
518 517
519 518 # TODO: Get SSH Launcher working again.
520 519
521 520 class SSHLauncher(LocalProcessLauncher):
522 521 """A minimal launcher for ssh.
523 522
524 523 To be useful this will probably have to be extended to use the ``sshx``
525 524 idea for environment variables. There could be other things this needs
526 525 as well.
527 526 """
528 527
529 528 ssh_cmd = List(['ssh'], config=True,
530 529 help="command for starting ssh")
531 530 ssh_args = List(['-tt'], config=True,
532 531 help="args to pass to ssh")
533 532 program = List(['date'], config=True,
534 533 help="Program to launch via ssh")
535 534 program_args = List([], config=True,
536 535 help="args to pass to remote program")
537 536 hostname = Unicode('', config=True,
538 537 help="hostname on which to launch the program")
539 538 user = Unicode('', config=True,
540 539 help="username for ssh")
541 540 location = Unicode('', config=True,
542 541 help="user@hostname location for ssh in one setting")
543 542
544 543 def _hostname_changed(self, name, old, new):
545 544 if self.user:
546 545 self.location = u'%s@%s' % (self.user, new)
547 546 else:
548 547 self.location = new
549 548
550 549 def _user_changed(self, name, old, new):
551 550 self.location = u'%s@%s' % (new, self.hostname)
552 551
553 552 def find_args(self):
554 553 return self.ssh_cmd + self.ssh_args + [self.location] + \
555 554 self.program + self.program_args
556 555
557 556 def start(self, profile_dir, hostname=None, user=None):
558 557 self.profile_dir = unicode(profile_dir)
559 558 if hostname is not None:
560 559 self.hostname = hostname
561 560 if user is not None:
562 561 self.user = user
563 562
564 563 return super(SSHLauncher, self).start()
565 564
566 565 def signal(self, sig):
567 566 if self.state == 'running':
568 567 # send escaped ssh connection-closer
569 568 self.process.stdin.write('~.')
570 569 self.process.stdin.flush()
571 570
572 571
573 572
574 573 class SSHControllerLauncher(SSHLauncher):
575 574
576 575 program = List(ipcontroller_cmd_argv, config=True,
577 576 help="remote ipcontroller command.")
578 577 program_args = List(['--reuse-files', '--log-to-file','log_level=%i'%logging.INFO], config=True,
579 578 help="Command line arguments to ipcontroller.")
580 579
581 580
582 581 class SSHEngineLauncher(SSHLauncher):
583 582 program = List(ipengine_cmd_argv, config=True,
584 583 help="remote ipengine command.")
585 584 # Command line arguments for ipengine.
586 585 program_args = List(
587 586 ['--log-to-file','log_level=%i'%logging.INFO], config=True,
588 587 help="Command line arguments to ipengine."
589 588 )
590 589
591 590 class SSHEngineSetLauncher(LocalEngineSetLauncher):
592 591 launcher_class = SSHEngineLauncher
593 592 engines = Dict(config=True,
594 593 help="""dict of engines to launch. This is a dict by hostname of ints,
595 594 corresponding to the number of engines to start on that host.""")
596 595
597 596 def start(self, n, profile_dir):
598 597 """Start engines by profile or profile_dir.
599 598 `n` is ignored, and the `engines` config property is used instead.
600 599 """
601 600
602 601 self.profile_dir = unicode(profile_dir)
603 602 dlist = []
604 603 for host, n in self.engines.iteritems():
605 604 if isinstance(n, (tuple, list)):
606 605 n, args = n
607 606 else:
608 607 args = copy.deepcopy(self.engine_args)
609 608
610 609 if '@' in host:
611 610 user,host = host.split('@',1)
612 611 else:
613 612 user=None
614 613 for i in range(n):
615 el = self.launcher_class(work_dir=self.work_dir, config=self.config, logname=self.log.name)
614 el = self.launcher_class(work_dir=self.work_dir, config=self.config, log=self.log)
616 615
617 616 # Copy the engine args over to each engine launcher.
618 617 i
619 618 el.program_args = args
620 619 el.on_stop(self._notice_engine_stopped)
621 620 d = el.start(profile_dir, user=user, hostname=host)
622 621 if i==0:
623 622 self.log.info("Starting SSHEngineSetLauncher: %r" % el.args)
624 623 self.launchers[host+str(i)] = el
625 624 dlist.append(d)
626 625 self.notify_start(dlist)
627 626 return dlist
628 627
629 628
630 629
631 630 #-----------------------------------------------------------------------------
632 631 # Windows HPC Server 2008 scheduler launchers
633 632 #-----------------------------------------------------------------------------
634 633
635 634
636 635 # This is only used on Windows.
637 636 def find_job_cmd():
638 637 if WINDOWS:
639 638 try:
640 639 return find_cmd('job')
641 640 except (FindCmdError, ImportError):
642 641 # ImportError will be raised if win32api is not installed
643 642 return 'job'
644 643 else:
645 644 return 'job'
646 645
647 646
648 647 class WindowsHPCLauncher(BaseLauncher):
649 648
650 649 job_id_regexp = Unicode(r'\d+', config=True,
651 650 help="""A regular expression used to get the job id from the output of the
652 651 submit_command. """
653 652 )
654 653 job_file_name = Unicode(u'ipython_job.xml', config=True,
655 654 help="The filename of the instantiated job script.")
656 655 # The full path to the instantiated job script. This gets made dynamically
657 656 # by combining the work_dir with the job_file_name.
658 657 job_file = Unicode(u'')
659 658 scheduler = Unicode('', config=True,
660 659 help="The hostname of the scheduler to submit the job to.")
661 660 job_cmd = Unicode(find_job_cmd(), config=True,
662 661 help="The command for submitting jobs.")
663 662
664 663 def __init__(self, work_dir=u'.', config=None, **kwargs):
665 664 super(WindowsHPCLauncher, self).__init__(
666 665 work_dir=work_dir, config=config, **kwargs
667 666 )
668 667
669 668 @property
670 669 def job_file(self):
671 670 return os.path.join(self.work_dir, self.job_file_name)
672 671
673 672 def write_job_file(self, n):
674 673 raise NotImplementedError("Implement write_job_file in a subclass.")
675 674
676 675 def find_args(self):
677 676 return [u'job.exe']
678 677
679 678 def parse_job_id(self, output):
680 679 """Take the output of the submit command and return the job id."""
681 680 m = re.search(self.job_id_regexp, output)
682 681 if m is not None:
683 682 job_id = m.group()
684 683 else:
685 684 raise LauncherError("Job id couldn't be determined: %s" % output)
686 685 self.job_id = job_id
687 686 self.log.info('Job started with job id: %r' % job_id)
688 687 return job_id
689 688
690 689 def start(self, n):
691 690 """Start n copies of the process using the Win HPC job scheduler."""
692 691 self.write_job_file(n)
693 692 args = [
694 693 'submit',
695 694 '/jobfile:%s' % self.job_file,
696 695 '/scheduler:%s' % self.scheduler
697 696 ]
698 697 self.log.info("Starting Win HPC Job: %s" % (self.job_cmd + ' ' + ' '.join(args),))
699 698 # Twisted will raise DeprecationWarnings if we try to pass unicode to this
700 699 output = check_output([self.job_cmd]+args,
701 700 env=os.environ,
702 701 cwd=self.work_dir,
703 702 stderr=STDOUT
704 703 )
705 704 job_id = self.parse_job_id(output)
706 705 self.notify_start(job_id)
707 706 return job_id
708 707
709 708 def stop(self):
710 709 args = [
711 710 'cancel',
712 711 self.job_id,
713 712 '/scheduler:%s' % self.scheduler
714 713 ]
715 714 self.log.info("Stopping Win HPC Job: %s" % (self.job_cmd + ' ' + ' '.join(args),))
716 715 try:
717 716 output = check_output([self.job_cmd]+args,
718 717 env=os.environ,
719 718 cwd=self.work_dir,
720 719 stderr=STDOUT
721 720 )
722 721 except:
723 722 output = 'The job already appears to be stopped: %r' % self.job_id
724 723 self.notify_stop(dict(job_id=self.job_id, output=output)) # Pass the output of the kill cmd
725 724 return output
726 725
727 726
728 727 class WindowsHPCControllerLauncher(WindowsHPCLauncher):
729 728
730 729 job_file_name = Unicode(u'ipcontroller_job.xml', config=True,
731 730 help="WinHPC xml job file.")
732 731 extra_args = List([], config=False,
733 732 help="extra args to pass to ipcontroller")
734 733
735 734 def write_job_file(self, n):
736 735 job = IPControllerJob(config=self.config)
737 736
738 737 t = IPControllerTask(config=self.config)
739 738 # The task's work directory is *not* the actual work directory of
740 739 # the controller. It is used as the base path for the stdout/stderr
741 740 # files that the scheduler redirects to.
742 741 t.work_directory = self.profile_dir
743 742 # Add the profile_dir from self.start().
744 743 t.controller_args.extend(self.extra_args)
745 744 job.add_task(t)
746 745
747 746 self.log.info("Writing job description file: %s" % self.job_file)
748 747 job.write(self.job_file)
749 748
750 749 @property
751 750 def job_file(self):
752 751 return os.path.join(self.profile_dir, self.job_file_name)
753 752
754 753 def start(self, profile_dir):
755 754 """Start the controller by profile_dir."""
756 755 self.extra_args = ['profile_dir=%s'%profile_dir]
757 756 self.profile_dir = unicode(profile_dir)
758 757 return super(WindowsHPCControllerLauncher, self).start(1)
759 758
760 759
761 760 class WindowsHPCEngineSetLauncher(WindowsHPCLauncher):
762 761
763 762 job_file_name = Unicode(u'ipengineset_job.xml', config=True,
764 763 help="jobfile for ipengines job")
765 764 extra_args = List([], config=False,
766 765 help="extra args to pas to ipengine")
767 766
768 767 def write_job_file(self, n):
769 768 job = IPEngineSetJob(config=self.config)
770 769
771 770 for i in range(n):
772 771 t = IPEngineTask(config=self.config)
773 772 # The task's work directory is *not* the actual work directory of
774 773 # the engine. It is used as the base path for the stdout/stderr
775 774 # files that the scheduler redirects to.
776 775 t.work_directory = self.profile_dir
777 776 # Add the profile_dir from self.start().
778 777 t.engine_args.extend(self.extra_args)
779 778 job.add_task(t)
780 779
781 780 self.log.info("Writing job description file: %s" % self.job_file)
782 781 job.write(self.job_file)
783 782
784 783 @property
785 784 def job_file(self):
786 785 return os.path.join(self.profile_dir, self.job_file_name)
787 786
788 787 def start(self, n, profile_dir):
789 788 """Start the controller by profile_dir."""
790 789 self.extra_args = ['profile_dir=%s'%profile_dir]
791 790 self.profile_dir = unicode(profile_dir)
792 791 return super(WindowsHPCEngineSetLauncher, self).start(n)
793 792
794 793
795 794 #-----------------------------------------------------------------------------
796 795 # Batch (PBS) system launchers
797 796 #-----------------------------------------------------------------------------
798 797
799 798 class BatchSystemLauncher(BaseLauncher):
800 799 """Launch an external process using a batch system.
801 800
802 801 This class is designed to work with UNIX batch systems like PBS, LSF,
803 802 GridEngine, etc. The overall model is that there are different commands
804 803 like qsub, qdel, etc. that handle the starting and stopping of the process.
805 804
806 805 This class also has the notion of a batch script. The ``batch_template``
807 806 attribute can be set to a string that is a template for the batch script.
808 807 This template is instantiated using string formatting. Thus the template can
809 808 use {n} for the number of instances. Subclasses can add additional variables
810 809 to the template dict.
811 810 """
812 811
813 812 # Subclasses must fill these in. See PBSEngineSet
814 813 submit_command = List([''], config=True,
815 814 help="The name of the command line program used to submit jobs.")
816 815 delete_command = List([''], config=True,
817 816 help="The name of the command line program used to delete jobs.")
818 817 job_id_regexp = Unicode('', config=True,
819 818 help="""A regular expression used to get the job id from the output of the
820 819 submit_command.""")
821 820 batch_template = Unicode('', config=True,
822 821 help="The string that is the batch script template itself.")
823 822 batch_template_file = Unicode(u'', config=True,
824 823 help="The file that contains the batch template.")
825 824 batch_file_name = Unicode(u'batch_script', config=True,
826 825 help="The filename of the instantiated batch script.")
827 826 queue = Unicode(u'', config=True,
828 827 help="The PBS Queue.")
829 828
830 829 # not configurable, override in subclasses
831 830 # PBS Job Array regex
832 831 job_array_regexp = Unicode('')
833 832 job_array_template = Unicode('')
834 833 # PBS Queue regex
835 834 queue_regexp = Unicode('')
836 835 queue_template = Unicode('')
837 836 # The default batch template, override in subclasses
838 837 default_template = Unicode('')
839 838 # The full path to the instantiated batch script.
840 839 batch_file = Unicode(u'')
841 840 # the format dict used with batch_template:
842 841 context = Dict()
843 842 # the Formatter instance for rendering the templates:
844 843 formatter = Instance(EvalFormatter, (), {})
845 844
846 845
847 846 def find_args(self):
848 847 return self.submit_command + [self.batch_file]
849 848
850 849 def __init__(self, work_dir=u'.', config=None, **kwargs):
851 850 super(BatchSystemLauncher, self).__init__(
852 851 work_dir=work_dir, config=config, **kwargs
853 852 )
854 853 self.batch_file = os.path.join(self.work_dir, self.batch_file_name)
855 854
856 855 def parse_job_id(self, output):
857 856 """Take the output of the submit command and return the job id."""
858 857 m = re.search(self.job_id_regexp, output)
859 858 if m is not None:
860 859 job_id = m.group()
861 860 else:
862 861 raise LauncherError("Job id couldn't be determined: %s" % output)
863 862 self.job_id = job_id
864 863 self.log.info('Job submitted with job id: %r' % job_id)
865 864 return job_id
866 865
867 866 def write_batch_script(self, n):
868 867 """Instantiate and write the batch script to the work_dir."""
869 868 self.context['n'] = n
870 869 self.context['queue'] = self.queue
871 870 # first priority is batch_template if set
872 871 if self.batch_template_file and not self.batch_template:
873 872 # second priority is batch_template_file
874 873 with open(self.batch_template_file) as f:
875 874 self.batch_template = f.read()
876 875 if not self.batch_template:
877 876 # third (last) priority is default_template
878 877 self.batch_template = self.default_template
879 878
880 879 regex = re.compile(self.job_array_regexp)
881 880 # print regex.search(self.batch_template)
882 881 if not regex.search(self.batch_template):
883 882 self.log.info("adding job array settings to batch script")
884 883 firstline, rest = self.batch_template.split('\n',1)
885 884 self.batch_template = u'\n'.join([firstline, self.job_array_template, rest])
886 885
887 886 regex = re.compile(self.queue_regexp)
888 887 # print regex.search(self.batch_template)
889 888 if self.queue and not regex.search(self.batch_template):
890 889 self.log.info("adding PBS queue settings to batch script")
891 890 firstline, rest = self.batch_template.split('\n',1)
892 891 self.batch_template = u'\n'.join([firstline, self.queue_template, rest])
893 892
894 893 script_as_string = self.formatter.format(self.batch_template, **self.context)
895 894 self.log.info('Writing instantiated batch script: %s' % self.batch_file)
896 895
897 896 with open(self.batch_file, 'w') as f:
898 897 f.write(script_as_string)
899 898 os.chmod(self.batch_file, stat.S_IRUSR | stat.S_IWUSR | stat.S_IXUSR)
900 899
901 900 def start(self, n, profile_dir):
902 901 """Start n copies of the process using a batch system."""
903 902 # Here we save profile_dir in the context so it
904 903 # can be used in the batch script template as {profile_dir}
905 904 self.context['profile_dir'] = profile_dir
906 905 self.profile_dir = unicode(profile_dir)
907 906 self.write_batch_script(n)
908 907 output = check_output(self.args, env=os.environ)
909 908
910 909 job_id = self.parse_job_id(output)
911 910 self.notify_start(job_id)
912 911 return job_id
913 912
914 913 def stop(self):
915 914 output = check_output(self.delete_command+[self.job_id], env=os.environ)
916 915 self.notify_stop(dict(job_id=self.job_id, output=output)) # Pass the output of the kill cmd
917 916 return output
918 917
919 918
920 919 class PBSLauncher(BatchSystemLauncher):
921 920 """A BatchSystemLauncher subclass for PBS."""
922 921
923 922 submit_command = List(['qsub'], config=True,
924 923 help="The PBS submit command ['qsub']")
925 924 delete_command = List(['qdel'], config=True,
926 925 help="The PBS delete command ['qsub']")
927 926 job_id_regexp = Unicode(r'\d+', config=True,
928 927 help="Regular expresion for identifying the job ID [r'\d+']")
929 928
930 929 batch_file = Unicode(u'')
931 930 job_array_regexp = Unicode('#PBS\W+-t\W+[\w\d\-\$]+')
932 931 job_array_template = Unicode('#PBS -t 1-{n}')
933 932 queue_regexp = Unicode('#PBS\W+-q\W+\$?\w+')
934 933 queue_template = Unicode('#PBS -q {queue}')
935 934
936 935
937 936 class PBSControllerLauncher(PBSLauncher):
938 937 """Launch a controller using PBS."""
939 938
940 939 batch_file_name = Unicode(u'pbs_controller', config=True,
941 940 help="batch file name for the controller job.")
942 941 default_template= Unicode("""#!/bin/sh
943 942 #PBS -V
944 943 #PBS -N ipcontroller
945 944 %s --log-to-file profile_dir={profile_dir}
946 945 """%(' '.join(ipcontroller_cmd_argv)))
947 946
948 947 def start(self, profile_dir):
949 948 """Start the controller by profile or profile_dir."""
950 949 self.log.info("Starting PBSControllerLauncher: %r" % self.args)
951 950 return super(PBSControllerLauncher, self).start(1, profile_dir)
952 951
953 952
954 953 class PBSEngineSetLauncher(PBSLauncher):
955 954 """Launch Engines using PBS"""
956 955 batch_file_name = Unicode(u'pbs_engines', config=True,
957 956 help="batch file name for the engine(s) job.")
958 957 default_template= Unicode(u"""#!/bin/sh
959 958 #PBS -V
960 959 #PBS -N ipengine
961 960 %s profile_dir={profile_dir}
962 961 """%(' '.join(ipengine_cmd_argv)))
963 962
964 963 def start(self, n, profile_dir):
965 964 """Start n engines by profile or profile_dir."""
966 965 self.log.info('Starting %i engines with PBSEngineSetLauncher: %r' % (n, self.args))
967 966 return super(PBSEngineSetLauncher, self).start(n, profile_dir)
968 967
969 968 #SGE is very similar to PBS
970 969
971 970 class SGELauncher(PBSLauncher):
972 971 """Sun GridEngine is a PBS clone with slightly different syntax"""
973 972 job_array_regexp = Unicode('#\$\W+\-t')
974 973 job_array_template = Unicode('#$ -t 1-{n}')
975 974 queue_regexp = Unicode('#\$\W+-q\W+\$?\w+')
976 975 queue_template = Unicode('#$ -q {queue}')
977 976
978 977 class SGEControllerLauncher(SGELauncher):
979 978 """Launch a controller using SGE."""
980 979
981 980 batch_file_name = Unicode(u'sge_controller', config=True,
982 981 help="batch file name for the ipontroller job.")
983 982 default_template= Unicode(u"""#$ -V
984 983 #$ -S /bin/sh
985 984 #$ -N ipcontroller
986 985 %s --log-to-file profile_dir={profile_dir}
987 986 """%(' '.join(ipcontroller_cmd_argv)))
988 987
989 988 def start(self, profile_dir):
990 989 """Start the controller by profile or profile_dir."""
991 990 self.log.info("Starting PBSControllerLauncher: %r" % self.args)
992 991 return super(SGEControllerLauncher, self).start(1, profile_dir)
993 992
994 993 class SGEEngineSetLauncher(SGELauncher):
995 994 """Launch Engines with SGE"""
996 995 batch_file_name = Unicode(u'sge_engines', config=True,
997 996 help="batch file name for the engine(s) job.")
998 997 default_template = Unicode("""#$ -V
999 998 #$ -S /bin/sh
1000 999 #$ -N ipengine
1001 1000 %s profile_dir={profile_dir}
1002 1001 """%(' '.join(ipengine_cmd_argv)))
1003 1002
1004 1003 def start(self, n, profile_dir):
1005 1004 """Start n engines by profile or profile_dir."""
1006 1005 self.log.info('Starting %i engines with SGEEngineSetLauncher: %r' % (n, self.args))
1007 1006 return super(SGEEngineSetLauncher, self).start(n, profile_dir)
1008 1007
1009 1008
1010 1009 #-----------------------------------------------------------------------------
1011 1010 # A launcher for ipcluster itself!
1012 1011 #-----------------------------------------------------------------------------
1013 1012
1014 1013
1015 1014 class IPClusterLauncher(LocalProcessLauncher):
1016 1015 """Launch the ipcluster program in an external process."""
1017 1016
1018 1017 ipcluster_cmd = List(ipcluster_cmd_argv, config=True,
1019 1018 help="Popen command for ipcluster")
1020 1019 ipcluster_args = List(
1021 1020 ['--clean-logs', '--log-to-file', 'log_level=%i'%logging.INFO], config=True,
1022 1021 help="Command line arguments to pass to ipcluster.")
1023 1022 ipcluster_subcommand = Unicode('start')
1024 1023 ipcluster_n = Int(2)
1025 1024
1026 1025 def find_args(self):
1027 1026 return self.ipcluster_cmd + ['--'+self.ipcluster_subcommand] + \
1028 1027 ['n=%i'%self.ipcluster_n] + self.ipcluster_args
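# For example (illustrative), with the defaults above find_args() evaluates to
# ipcluster_cmd_argv + ['--start', 'n=2', '--clean-logs', '--log-to-file', 'log_level=20'].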
1029 1028
1030 1029 def start(self):
1031 1030 self.log.info("Starting ipcluster: %r" % self.args)
1032 1031 return super(IPClusterLauncher, self).start()
1033 1032
1034 1033 #-----------------------------------------------------------------------------
1035 1034 # Collections of launchers
1036 1035 #-----------------------------------------------------------------------------
1037 1036
1038 1037 local_launchers = [
1039 1038 LocalControllerLauncher,
1040 1039 LocalEngineLauncher,
1041 1040 LocalEngineSetLauncher,
1042 1041 ]
1043 1042 mpi_launchers = [
1044 1043 MPIExecLauncher,
1045 1044 MPIExecControllerLauncher,
1046 1045 MPIExecEngineSetLauncher,
1047 1046 ]
1048 1047 ssh_launchers = [
1049 1048 SSHLauncher,
1050 1049 SSHControllerLauncher,
1051 1050 SSHEngineLauncher,
1052 1051 SSHEngineSetLauncher,
1053 1052 ]
1054 1053 winhpc_launchers = [
1055 1054 WindowsHPCLauncher,
1056 1055 WindowsHPCControllerLauncher,
1057 1056 WindowsHPCEngineSetLauncher,
1058 1057 ]
1059 1058 pbs_launchers = [
1060 1059 PBSLauncher,
1061 1060 PBSControllerLauncher,
1062 1061 PBSEngineSetLauncher,
1063 1062 ]
1064 1063 sge_launchers = [
1065 1064 SGELauncher,
1066 1065 SGEControllerLauncher,
1067 1066 SGEEngineSetLauncher,
1068 1067 ]
1069 1068 all_launchers = local_launchers + mpi_launchers + ssh_launchers + winhpc_launchers\
1070 1069 + pbs_launchers + sge_launchers
@@ -1,108 +1,110 b''
1 1 #!/usr/bin/env python
2 2 """A simple logger object that consolidates messages incoming from ipcluster processes."""
3 3
4 4 #-----------------------------------------------------------------------------
5 5 # Copyright (C) 2011 The IPython Development Team
6 6 #
7 7 # Distributed under the terms of the BSD License. The full license is in
8 8 # the file COPYING, distributed as part of this software.
9 9 #-----------------------------------------------------------------------------
10 10
11 11 #-----------------------------------------------------------------------------
12 12 # Imports
13 13 #-----------------------------------------------------------------------------
14 14
15 15
16 16 import logging
17 17 import sys
18 18
19 19 import zmq
20 20 from zmq.eventloop import ioloop, zmqstream
21 21
22 from IPython.config.configurable import Configurable
22 23 from IPython.utils.traitlets import Int, Unicode, Instance, List
23 24
24 from IPython.parallel.factory import LoggingFactory
25
26 25 #-----------------------------------------------------------------------------
27 26 # Classes
28 27 #-----------------------------------------------------------------------------
29 28
30 29
31 class LogWatcher(LoggingFactory):
30 class LogWatcher(Configurable):
32 31 """A simple class that receives messages on a SUB socket, as published
33 32 by subclasses of `zmq.log.handlers.PUBHandler`, and logs them itself.
34 33
35 34 This can subscribe to multiple topics, but defaults to all topics.
36 35 """
36
37 log = Instance('logging.Logger', ('root',))
38
37 39 # configurables
38 40 topics = List([''], config=True,
39 41 help="The ZMQ topics to subscribe to. Default is to subscribe to all messages")
40 42 url = Unicode('tcp://127.0.0.1:20202', config=True,
41 43 help="ZMQ url on which to listen for log messages")
42 44
43 45 # internals
44 46 stream = Instance('zmq.eventloop.zmqstream.ZMQStream')
45 47
46 48 context = Instance(zmq.Context)
47 49 def _context_default(self):
48 50 return zmq.Context.instance()
49 51
50 52 loop = Instance(zmq.eventloop.ioloop.IOLoop)
51 53 def _loop_default(self):
52 54 return ioloop.IOLoop.instance()
53 55
54 56 def __init__(self, **kwargs):
55 57 super(LogWatcher, self).__init__(**kwargs)
56 58 s = self.context.socket(zmq.SUB)
57 59 s.bind(self.url)
58 60 self.stream = zmqstream.ZMQStream(s, self.loop)
59 61 self.subscribe()
60 62 self.on_trait_change(self.subscribe, 'topics')
61 63
62 64 def start(self):
63 65 self.stream.on_recv(self.log_message)
64 66
65 67 def stop(self):
66 68 self.stream.stop_on_recv()
67 69
68 70 def subscribe(self):
69 71 """Update our SUB socket's subscriptions."""
70 72 self.stream.setsockopt(zmq.UNSUBSCRIBE, '')
71 73 if '' in self.topics:
72 74 self.log.debug("Subscribing to: everything")
73 75 self.stream.setsockopt(zmq.SUBSCRIBE, '')
74 76 else:
75 77 for topic in self.topics:
76 78 self.log.debug("Subscribing to: %r"%(topic))
77 79 self.stream.setsockopt(zmq.SUBSCRIBE, topic)
78 80
79 81 def _extract_level(self, topic_str):
80 82 """Turn 'engine.0.INFO.extra' into (logging.INFO, 'engine.0.extra')"""
81 83 topics = topic_str.split('.')
82 84 for idx,t in enumerate(topics):
83 85 level = getattr(logging, t, None)
84 86 if level is not None:
85 87 break
86 88
87 89 if level is None:
88 90 level = logging.INFO
89 91 else:
90 92 topics.pop(idx)
91 93
92 94 return level, '.'.join(topics)
93 95
94 96
95 97 def log_message(self, raw):
96 98 """receive and parse a message, then log it."""
97 99 if len(raw) != 2 or '.' not in raw[0]:
98 100 self.log.error("Invalid log message: %s"%raw)
99 101 return
100 102 else:
101 103 topic, msg = raw
102 104 # don't newline, since log messages always newline:
103 105 topic,level_name = topic.rsplit('.',1)
104 106 level,topic = self._extract_level(topic)
105 107 if msg[-1] == '\n':
106 108 msg = msg[:-1]
107 109 self.log.log(level, "[%s] %s" % (topic, msg))
108 110
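# Minimal manual test (an illustrative addition, not part of the original
# module): bind a LogWatcher on its default url and echo everything it
# receives to stderr until interrupted.
if __name__ == '__main__':
    watcher = LogWatcher()
    watcher.log.addHandler(logging.StreamHandler())
    watcher.log.setLevel(logging.DEBUG)
    watcher.start()
    ioloop.IOLoop.instance().start()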
@@ -1,165 +1,166 b''
1 1 #!/usr/bin/env python
2 2 """
3 3 A multi-heart Heartbeat system using PUB and XREP sockets. pings are sent out on the PUB,
4 4 and hearts are tracked based on their XREQ identities.
5 5 """
6 6 #-----------------------------------------------------------------------------
7 7 # Copyright (C) 2010-2011 The IPython Development Team
8 8 #
9 9 # Distributed under the terms of the BSD License. The full license is in
10 10 # the file COPYING, distributed as part of this software.
11 11 #-----------------------------------------------------------------------------
12 12
13 13 from __future__ import print_function
14 14 import time
15 15 import uuid
16 16
17 17 import zmq
18 18 from zmq.devices import ThreadDevice
19 19 from zmq.eventloop import ioloop, zmqstream
20 20
21 from IPython.config.configurable import Configurable
21 22 from IPython.utils.traitlets import Set, Instance, CFloat
22 from IPython.parallel.factory import LoggingFactory
23 23
24 24 class Heart(object):
25 25 """A basic heart object for responding to a HeartMonitor.
26 26 This is a simple wrapper with defaults for the most common
27 27 Device model for responding to heartbeats.
28 28
29 29 It simply builds a threadsafe zmq.FORWARDER Device, defaulting to using
30 30 SUB/XREQ for in/out.
31 31
32 32 You can specify the XREQ's IDENTITY via the optional heart_id argument."""
33 33 device=None
34 34 id=None
35 35 def __init__(self, in_addr, out_addr, in_type=zmq.SUB, out_type=zmq.XREQ, heart_id=None):
36 36 self.device = ThreadDevice(zmq.FORWARDER, in_type, out_type)
37 37 self.device.daemon=True
38 38 self.device.connect_in(in_addr)
39 39 self.device.connect_out(out_addr)
40 40 if in_type == zmq.SUB:
41 41 self.device.setsockopt_in(zmq.SUBSCRIBE, "")
42 42 if heart_id is None:
43 43 heart_id = str(uuid.uuid4())
44 44 self.device.setsockopt_out(zmq.IDENTITY, heart_id)
45 45 self.id = heart_id
46 46
47 47 def start(self):
48 48 return self.device.start()
49 49
50 class HeartMonitor(LoggingFactory):
50 class HeartMonitor(Configurable):
51 51 """A basic HeartMonitor class
52 52 pingstream: a PUB stream
53 53 pongstream: an XREP stream
54 54 period: the period of the heartbeat in milliseconds"""
55 55
56 56 period=CFloat(1000, config=True,
57 57 help='The frequency at which the Hub pings the engines for heartbeats '
58 58 ' (in ms) [default: 100]',
59 59 )
60 60
61 log = Instance('logging.Logger', ('root',))
61 62 pingstream=Instance('zmq.eventloop.zmqstream.ZMQStream')
62 63 pongstream=Instance('zmq.eventloop.zmqstream.ZMQStream')
63 64 loop = Instance('zmq.eventloop.ioloop.IOLoop')
64 65 def _loop_default(self):
65 66 return ioloop.IOLoop.instance()
66 67
67 68 # not settable:
68 69 hearts=Set()
69 70 responses=Set()
70 71 on_probation=Set()
71 72 last_ping=CFloat(0)
72 73 _new_handlers = Set()
73 74 _failure_handlers = Set()
74 75 lifetime = CFloat(0)
75 76 tic = CFloat(0)
76 77
77 78 def __init__(self, **kwargs):
78 79 super(HeartMonitor, self).__init__(**kwargs)
79 80
80 81 self.pongstream.on_recv(self.handle_pong)
81 82
82 83 def start(self):
83 84 self.caller = ioloop.PeriodicCallback(self.beat, self.period, self.loop)
84 85 self.caller.start()
85 86
86 87 def add_new_heart_handler(self, handler):
87 88 """add a new handler for new hearts"""
88 89 self.log.debug("heartbeat::new_heart_handler: %s"%handler)
89 90 self._new_handlers.add(handler)
90 91
91 92 def add_heart_failure_handler(self, handler):
92 93 """add a new handler for heart failure"""
93 94 self.log.debug("heartbeat::new heart failure handler: %s"%handler)
94 95 self._failure_handlers.add(handler)
95 96
96 97 def beat(self):
97 98 self.pongstream.flush()
98 99 self.last_ping = self.lifetime
99 100
100 101 toc = time.time()
101 102 self.lifetime += toc-self.tic
102 103 self.tic = toc
103 104 # self.log.debug("heartbeat::%s"%self.lifetime)
104 105 goodhearts = self.hearts.intersection(self.responses)
105 106 missed_beats = self.hearts.difference(goodhearts)
106 107 heartfailures = self.on_probation.intersection(missed_beats)
107 108 newhearts = self.responses.difference(goodhearts)
108 109 map(self.handle_new_heart, newhearts)
109 110 map(self.handle_heart_failure, heartfailures)
110 111 self.on_probation = missed_beats.intersection(self.hearts)
111 112 self.responses = set()
112 113 # print self.on_probation, self.hearts
113 114 # self.log.debug("heartbeat::beat %.3f, %i beating hearts"%(self.lifetime, len(self.hearts)))
114 115 self.pingstream.send(str(self.lifetime))
115 116
116 117 def handle_new_heart(self, heart):
117 118 if self._new_handlers:
118 119 for handler in self._new_handlers:
119 120 handler(heart)
120 121 else:
121 122 self.log.info("heartbeat::yay, got new heart %s!"%heart)
122 123 self.hearts.add(heart)
123 124
124 125 def handle_heart_failure(self, heart):
125 126 if self._failure_handlers:
126 127 for handler in self._failure_handlers:
127 128 try:
128 129 handler(heart)
129 130 except Exception as e:
130 131 self.log.error("heartbeat::Bad Handler! %s"%handler, exc_info=True)
131 132 pass
132 133 else:
133 134 self.log.info("heartbeat::Heart %s failed :("%heart)
134 135 self.hearts.remove(heart)
135 136
136 137
137 138 def handle_pong(self, msg):
138 139 "a heart just beat"
139 140 if msg[1] == str(self.lifetime):
140 141 delta = time.time()-self.tic
141 142 # self.log.debug("heartbeat::heart %r took %.2f ms to respond"%(msg[0], 1000*delta))
142 143 self.responses.add(msg[0])
143 144 elif msg[1] == str(self.last_ping):
144 145 delta = time.time()-self.tic + (self.lifetime-self.last_ping)
145 146 self.log.warn("heartbeat::heart %r missed a beat, and took %.2f ms to respond"%(msg[0], 1000*delta))
146 147 self.responses.add(msg[0])
147 148 else:
148 149 self.log.warn("heartbeat::got bad heartbeat (possibly old?): %s (current=%.3f)"%
149 150 (msg[1],self.lifetime))
150 151
151 152
152 153 if __name__ == '__main__':
153 154 loop = ioloop.IOLoop.instance()
154 155 context = zmq.Context()
155 156 pub = context.socket(zmq.PUB)
156 157 pub.bind('tcp://127.0.0.1:5555')
157 158 xrep = context.socket(zmq.XREP)
158 159 xrep.bind('tcp://127.0.0.1:5556')
159 160
160 161 outstream = zmqstream.ZMQStream(pub, loop)
161 162 instream = zmqstream.ZMQStream(xrep, loop)
162 163
163 164 hb = HeartMonitor(loop=loop, pingstream=outstream, pongstream=instream)
164 165
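# Illustrative additions (not in the original demo): start the periodic
# pinger and attach a local Heart so the monitor has something to track.
hb.start()
heart = Heart('tcp://127.0.0.1:5555', 'tcp://127.0.0.1:5556')
heart.start()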
165 166 loop.start()
@@ -1,1274 +1,1277 b''
1 1 #!/usr/bin/env python
2 2 """The IPython Controller Hub with 0MQ
3 3 This is the master object that handles connections from engines and clients,
4 4 and monitors traffic through the various queues.
5 5 """
6 6 #-----------------------------------------------------------------------------
7 7 # Copyright (C) 2010 The IPython Development Team
8 8 #
9 9 # Distributed under the terms of the BSD License. The full license is in
10 10 # the file COPYING, distributed as part of this software.
11 11 #-----------------------------------------------------------------------------
12 12
13 13 #-----------------------------------------------------------------------------
14 14 # Imports
15 15 #-----------------------------------------------------------------------------
16 16 from __future__ import print_function
17 17
18 18 import sys
19 19 import time
20 20 from datetime import datetime
21 21
22 22 import zmq
23 23 from zmq.eventloop import ioloop
24 24 from zmq.eventloop.zmqstream import ZMQStream
25 25
26 26 # internal:
27 27 from IPython.utils.importstring import import_item
28 28 from IPython.utils.traitlets import (
29 29 HasTraits, Instance, Int, Unicode, Dict, Set, Tuple, CStr
30 30 )
31 31 from IPython.utils.jsonutil import ISO8601, extract_dates
32 32
33 33 from IPython.parallel import error, util
34 from IPython.parallel.factory import RegistrationFactory, LoggingFactory
34 from IPython.parallel.factory import RegistrationFactory
35
36 from IPython.zmq.session import SessionFactory
35 37
36 38 from .heartmonitor import HeartMonitor
37 39
38 40 #-----------------------------------------------------------------------------
39 41 # Code
40 42 #-----------------------------------------------------------------------------
41 43
42 44 def _passer(*args, **kwargs):
43 45 return
44 46
45 47 def _printer(*args, **kwargs):
46 48 print (args)
47 49 print (kwargs)
48 50
49 51 def empty_record():
50 52 """Return an empty dict with all record keys."""
51 53 return {
52 54 'msg_id' : None,
53 55 'header' : None,
54 56 'content': None,
55 57 'buffers': None,
56 58 'submitted': None,
57 59 'client_uuid' : None,
58 60 'engine_uuid' : None,
59 61 'started': None,
60 62 'completed': None,
61 63 'resubmitted': None,
62 64 'result_header' : None,
63 65 'result_content' : None,
64 66 'result_buffers' : None,
65 67 'queue' : None,
66 68 'pyin' : None,
67 69 'pyout': None,
68 70 'pyerr': None,
69 71 'stdout': '',
70 72 'stderr': '',
71 73 }
72 74
73 75 def init_record(msg):
74 76 """Initialize a TaskRecord based on a request."""
75 77 header = extract_dates(msg['header'])
76 78 return {
77 79 'msg_id' : header['msg_id'],
78 80 'header' : header,
79 81 'content': msg['content'],
80 82 'buffers': msg['buffers'],
81 83 'submitted': header['date'],
82 84 'client_uuid' : None,
83 85 'engine_uuid' : None,
84 86 'started': None,
85 87 'completed': None,
86 88 'resubmitted': None,
87 89 'result_header' : None,
88 90 'result_content' : None,
89 91 'result_buffers' : None,
90 92 'queue' : None,
91 93 'pyin' : None,
92 94 'pyout': None,
93 95 'pyerr': None,
94 96 'stdout': '',
95 97 'stderr': '',
96 98 }
97 99
98 100
99 101 class EngineConnector(HasTraits):
100 102 """A simple object for accessing the various zmq connections of an object.
101 103 Attributes are:
102 104 id (int): engine ID
103 105 uuid (str): uuid (unused?)
104 106 queue (str): identity of queue's XREQ socket
105 107 registration (str): identity of registration XREQ socket
106 108 heartbeat (str): identity of heartbeat XREQ socket
107 109 """
108 110 id=Int(0)
109 111 queue=CStr()
110 112 control=CStr()
111 113 registration=CStr()
112 114 heartbeat=CStr()
113 115 pending=Set()
114 116
115 117 class HubFactory(RegistrationFactory):
116 118 """The Configurable for setting up a Hub."""
117 119
118 120 # port-pairs for monitoredqueues:
119 121 hb = Tuple(Int,Int,config=True,
120 122 help="""XREQ/SUB Port pair for Engine heartbeats""")
121 123 def _hb_default(self):
122 124 return tuple(util.select_random_ports(2))
123 125
124 126 mux = Tuple(Int,Int,config=True,
125 127 help="""Engine/Client Port pair for MUX queue""")
126 128
127 129 def _mux_default(self):
128 130 return tuple(util.select_random_ports(2))
129 131
130 132 task = Tuple(Int,Int,config=True,
131 133 help="""Engine/Client Port pair for Task queue""")
132 134 def _task_default(self):
133 135 return tuple(util.select_random_ports(2))
134 136
135 137 control = Tuple(Int,Int,config=True,
136 138 help="""Engine/Client Port pair for Control queue""")
137 139
138 140 def _control_default(self):
139 141 return tuple(util.select_random_ports(2))
140 142
141 143 iopub = Tuple(Int,Int,config=True,
142 144 help="""Engine/Client Port pair for IOPub relay""")
143 145
144 146 def _iopub_default(self):
145 147 return tuple(util.select_random_ports(2))
146 148
147 149 # single ports:
148 150 mon_port = Int(config=True,
149 151 help="""Monitor (SUB) port for queue traffic""")
150 152
151 153 def _mon_port_default(self):
152 154 return util.select_random_ports(1)[0]
153 155
154 156 notifier_port = Int(config=True,
155 157 help="""PUB port for sending engine status notifications""")
156 158
157 159 def _notifier_port_default(self):
158 160 return util.select_random_ports(1)[0]
159 161
160 162 engine_ip = Unicode('127.0.0.1', config=True,
161 163 help="IP on which to listen for engine connections. [default: loopback]")
162 164 engine_transport = Unicode('tcp', config=True,
163 165 help="0MQ transport for engine connections. [default: tcp]")
164 166
165 167 client_ip = Unicode('127.0.0.1', config=True,
166 168 help="IP on which to listen for client connections. [default: loopback]")
167 169 client_transport = Unicode('tcp', config=True,
168 170 help="0MQ transport for client connections. [default : tcp]")
169 171
170 172 monitor_ip = Unicode('127.0.0.1', config=True,
171 173 help="IP on which to listen for monitor messages. [default: loopback]")
172 174 monitor_transport = Unicode('tcp', config=True,
173 175 help="0MQ transport for monitor messages. [default : tcp]")
174 176
175 177 monitor_url = Unicode('')
176 178
177 179 db_class = Unicode('IPython.parallel.controller.dictdb.DictDB', config=True,
178 180 help="""The class to use for the DB backend""")
179 181
180 182 # not configurable
181 183 db = Instance('IPython.parallel.controller.dictdb.BaseDB')
182 184 heartmonitor = Instance('IPython.parallel.controller.heartmonitor.HeartMonitor')
183 185
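# Illustrative configuration sketch (values are examples, not defaults): the
# port pairs and addresses above are ordinary configurables, so an
# ipcontroller_config.py could pin them with, e.g.:
#
#   c = get_config()
#   c.HubFactory.hb = (10200, 10201)
#   c.HubFactory.mux = (10202, 10203)
#   c.HubFactory.engine_ip = '10.0.0.1'
#   c.HubFactory.db_class = 'IPython.parallel.controller.dictdb.DictDB'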
184 186 def _ip_changed(self, name, old, new):
185 187 self.engine_ip = new
186 188 self.client_ip = new
187 189 self.monitor_ip = new
188 190 self._update_monitor_url()
189 191
190 192 def _update_monitor_url(self):
191 193 self.monitor_url = "%s://%s:%i"%(self.monitor_transport, self.monitor_ip, self.mon_port)
192 194
193 195 def _transport_changed(self, name, old, new):
194 196 self.engine_transport = new
195 197 self.client_transport = new
196 198 self.monitor_transport = new
197 199 self._update_monitor_url()
198 200
199 201 def __init__(self, **kwargs):
200 202 super(HubFactory, self).__init__(**kwargs)
201 203 self._update_monitor_url()
202 204
203 205
204 206 def construct(self):
205 207 self.init_hub()
206 208
207 209 def start(self):
208 210 self.heartmonitor.start()
209 211 self.log.info("Heartmonitor started")
210 212
211 213 def init_hub(self):
212 214 """construct"""
213 215 client_iface = "%s://%s:"%(self.client_transport, self.client_ip) + "%i"
214 216 engine_iface = "%s://%s:"%(self.engine_transport, self.engine_ip) + "%i"
215 217
216 218 ctx = self.context
217 219 loop = self.loop
218 220
219 221 # Registrar socket
220 222 q = ZMQStream(ctx.socket(zmq.XREP), loop)
221 223 q.bind(client_iface % self.regport)
222 224 self.log.info("Hub listening on %s for registration."%(client_iface%self.regport))
223 225 if self.client_ip != self.engine_ip:
224 226 q.bind(engine_iface % self.regport)
225 227 self.log.info("Hub listening on %s for registration."%(engine_iface%self.regport))
226 228
227 229 ### Engine connections ###
228 230
229 231 # heartbeat
230 232 hpub = ctx.socket(zmq.PUB)
231 233 hpub.bind(engine_iface % self.hb[0])
232 234 hrep = ctx.socket(zmq.XREP)
233 235 hrep.bind(engine_iface % self.hb[1])
234 self.heartmonitor = HeartMonitor(loop=loop, pingstream=ZMQStream(hpub,loop), pongstream=ZMQStream(hrep,loop),
235 config=self.config)
236 self.heartmonitor = HeartMonitor(loop=loop, config=self.config, log=self.log,
237 pingstream=ZMQStream(hpub,loop),
238 pongstream=ZMQStream(hrep,loop)
239 )
236 240
237 241 ### Client connections ###
238 242 # Notifier socket
239 243 n = ZMQStream(ctx.socket(zmq.PUB), loop)
240 244 n.bind(client_iface%self.notifier_port)
241 245
242 246 ### build and launch the queues ###
243 247
244 248 # monitor socket
245 249 sub = ctx.socket(zmq.SUB)
246 250 sub.setsockopt(zmq.SUBSCRIBE, "")
247 251 sub.bind(self.monitor_url)
248 252 sub.bind('inproc://monitor')
249 253 sub = ZMQStream(sub, loop)
250 254
251 255 # connect the db
252 256 self.log.info('Hub using DB backend: %r'%(self.db_class.split('.')[-1]))
253 257 # cdir = self.config.Global.cluster_dir
254 258 self.db = import_item(str(self.db_class))(session=self.session.session, config=self.config)
255 259 time.sleep(.25)
256 260 try:
257 261 scheme = self.config.TaskScheduler.scheme_name
258 262 except AttributeError:
259 263 from .scheduler import TaskScheduler
260 264 scheme = TaskScheduler.scheme_name.get_default_value()
261 265 # build connection dicts
262 266 self.engine_info = {
263 267 'control' : engine_iface%self.control[1],
264 268 'mux': engine_iface%self.mux[1],
265 269 'heartbeat': (engine_iface%self.hb[0], engine_iface%self.hb[1]),
266 270 'task' : engine_iface%self.task[1],
267 271 'iopub' : engine_iface%self.iopub[1],
268 272 # 'monitor' : engine_iface%self.mon_port,
269 273 }
270 274
271 275 self.client_info = {
272 276 'control' : client_iface%self.control[0],
273 277 'mux': client_iface%self.mux[0],
274 278 'task' : (scheme, client_iface%self.task[0]),
275 279 'iopub' : client_iface%self.iopub[0],
276 280 'notification': client_iface%self.notifier_port
277 281 }
278 282 self.log.debug("Hub engine addrs: %s"%self.engine_info)
279 283 self.log.debug("Hub client addrs: %s"%self.client_info)
280 284
281 285 # resubmit stream
282 286 r = ZMQStream(ctx.socket(zmq.XREQ), loop)
283 287 url = util.disambiguate_url(self.client_info['task'][-1])
284 288 r.setsockopt(zmq.IDENTITY, self.session.session)
285 289 r.connect(url)
286 290
287 291 self.hub = Hub(loop=loop, session=self.session, monitor=sub, heartmonitor=self.heartmonitor,
288 292 query=q, notifier=n, resubmit=r, db=self.db,
289 293 engine_info=self.engine_info, client_info=self.client_info,
290 logname=self.log.name)
294 log=self.log)
291 295
292 296
293 class Hub(LoggingFactory):
297 class Hub(SessionFactory):
294 298 """The IPython Controller Hub with 0MQ connections
295 299
296 300 Parameters
297 301 ==========
298 302 loop: zmq IOLoop instance
299 303 session: Session object
300 304 <removed> context: zmq context for creating new connections (?)
301 305 queue: ZMQStream for monitoring the command queue (SUB)
302 306 query: ZMQStream for engine registration and client queries requests (XREP)
303 307 heartbeat: HeartMonitor object checking the pulse of the engines
304 308 notifier: ZMQStream for broadcasting engine registration changes (PUB)
305 309 db: connection to db for out of memory logging of commands
306 310 NotImplemented
307 311 engine_info: dict of zmq connection information for engines to connect
308 312 to the queues.
309 313 client_info: dict of zmq connection information for engines to connect
310 314 to the queues.
311 315 """
312 316 # internal data structures:
313 317 ids=Set() # engine IDs
314 318 keytable=Dict()
315 319 by_ident=Dict()
316 320 engines=Dict()
317 321 clients=Dict()
318 322 hearts=Dict()
319 323 pending=Set()
320 324 queues=Dict() # pending msg_ids keyed by engine_id
321 325 tasks=Dict() # pending msg_ids submitted as tasks, keyed by client_id
322 326 completed=Dict() # completed msg_ids keyed by engine_id
323 327 all_completed=Set() # set of all completed msg_ids
324 328 dead_engines=Set() # uuids of engines that have unregistered or died
325 329 unassigned=Set() # set of task msg_ids not yet assigned a destination
326 330 incoming_registrations=Dict()
327 331 registration_timeout=Int()
328 332 _idcounter=Int(0)
329 333
330 334 # objects from constructor:
331 loop=Instance(ioloop.IOLoop)
332 335 query=Instance(ZMQStream)
333 336 monitor=Instance(ZMQStream)
334 337 notifier=Instance(ZMQStream)
335 338 resubmit=Instance(ZMQStream)
336 339 heartmonitor=Instance(HeartMonitor)
337 340 db=Instance(object)
338 341 client_info=Dict()
339 342 engine_info=Dict()
340 343
341 344
342 345 def __init__(self, **kwargs):
343 346 """
344 347 # universal:
345 348 loop: IOLoop for creating future connections
346 349 session: streamsession for sending serialized data
347 350 # engine:
348 351 queue: ZMQStream for monitoring queue messages
349 352 query: ZMQStream for engine+client registration and client requests
350 353 heartbeat: HeartMonitor object for tracking engines
351 354 # extra:
352 355 db: ZMQStream for db connection (NotImplemented)
353 356 engine_info: zmq address/protocol dict for engine connections
354 357 client_info: zmq address/protocol dict for client connections
355 358 """
356 359
357 360 super(Hub, self).__init__(**kwargs)
358 361 self.registration_timeout = max(5000, 2*self.heartmonitor.period)
359 362
360 363 # validate connection dicts:
361 364 for k,v in self.client_info.iteritems():
362 365 if k == 'task':
363 366 util.validate_url_container(v[1])
364 367 else:
365 368 util.validate_url_container(v)
366 369 # util.validate_url_container(self.client_info)
367 370 util.validate_url_container(self.engine_info)
368 371
369 372 # register our callbacks
370 373 self.query.on_recv(self.dispatch_query)
371 374 self.monitor.on_recv(self.dispatch_monitor_traffic)
372 375
373 376 self.heartmonitor.add_heart_failure_handler(self.handle_heart_failure)
374 377 self.heartmonitor.add_new_heart_handler(self.handle_new_heart)
375 378
376 379 self.monitor_handlers = { 'in' : self.save_queue_request,
377 380 'out': self.save_queue_result,
378 381 'intask': self.save_task_request,
379 382 'outtask': self.save_task_result,
380 383 'tracktask': self.save_task_destination,
381 384 'incontrol': _passer,
382 385 'outcontrol': _passer,
383 386 'iopub': self.save_iopub_message,
384 387 }
385 388
386 389 self.query_handlers = {'queue_request': self.queue_status,
387 390 'result_request': self.get_results,
388 391 'history_request': self.get_history,
389 392 'db_request': self.db_query,
390 393 'purge_request': self.purge_results,
391 394 'load_request': self.check_load,
392 395 'resubmit_request': self.resubmit_task,
393 396 'shutdown_request': self.shutdown_request,
394 397 'registration_request' : self.register_engine,
395 398 'unregistration_request' : self.unregister_engine,
396 399 'connection_request': self.connection_request,
397 400 }
398 401
399 402 # ignore resubmit replies
400 403 self.resubmit.on_recv(lambda msg: None, copy=False)
401 404
402 405 self.log.info("hub::created hub")
403 406
404 407 @property
405 408 def _next_id(self):
406 409 """gemerate a new ID.
407 410
408 411 No longer reuse old ids, just count from 0."""
409 412 newid = self._idcounter
410 413 self._idcounter += 1
411 414 return newid
412 415 # newid = 0
413 416 # incoming = [id[0] for id in self.incoming_registrations.itervalues()]
414 417 # # print newid, self.ids, self.incoming_registrations
415 418 # while newid in self.ids or newid in incoming:
416 419 # newid += 1
417 420 # return newid
418 421
419 422 #-----------------------------------------------------------------------------
420 423 # message validation
421 424 #-----------------------------------------------------------------------------
422 425
423 426 def _validate_targets(self, targets):
424 427 """turn any valid targets argument into a list of integer ids"""
425 428 if targets is None:
426 429 # default to all
427 430 targets = self.ids
428 431
429 432 if isinstance(targets, (int,str,unicode)):
430 433 # only one target specified
431 434 targets = [targets]
432 435 _targets = []
433 436 for t in targets:
434 437 # map raw identities to ids
435 438 if isinstance(t, (str,unicode)):
436 439 t = self.by_ident.get(t, t)
437 440 _targets.append(t)
438 441 targets = _targets
439 442 bad_targets = [ t for t in targets if t not in self.ids ]
440 443 if bad_targets:
441 444 raise IndexError("No Such Engine: %r"%bad_targets)
442 445 if not targets:
443 446 raise IndexError("No Engines Registered")
444 447 return targets
445 448
446 449 #-----------------------------------------------------------------------------
447 450 # dispatch methods (1 per stream)
448 451 #-----------------------------------------------------------------------------
449 452
450 453
451 454 def dispatch_monitor_traffic(self, msg):
452 455 """all ME and Task queue messages come through here, as well as
453 456 IOPub traffic."""
454 457 self.log.debug("monitor traffic: %r"%msg[:2])
455 458 switch = msg[0]
456 459 try:
457 460 idents, msg = self.session.feed_identities(msg[1:])
458 461 except ValueError:
459 462 idents=[]
460 463 if not idents:
461 464 self.log.error("Bad Monitor Message: %r"%msg)
462 465 return
463 466 handler = self.monitor_handlers.get(switch, None)
464 467 if handler is not None:
465 468 handler(idents, msg)
466 469 else:
467 470 self.log.error("Invalid monitor topic: %r"%switch)
468 471
469 472
470 473 def dispatch_query(self, msg):
471 474 """Route registration requests and queries from clients."""
472 475 try:
473 476 idents, msg = self.session.feed_identities(msg)
474 477 except ValueError:
475 478 idents = []
476 479 if not idents:
477 480 self.log.error("Bad Query Message: %r"%msg)
478 481 return
479 482 client_id = idents[0]
480 483 try:
481 484 msg = self.session.unpack_message(msg, content=True)
482 485 except Exception:
483 486 content = error.wrap_exception()
484 487 self.log.error("Bad Query Message: %r"%msg, exc_info=True)
485 488 self.session.send(self.query, "hub_error", ident=client_id,
486 489 content=content)
487 490 return
488 491 print( idents, msg)
489 492 # print client_id, header, parent, content
490 493 #switch on message type:
491 494 msg_type = msg['msg_type']
492 495 self.log.info("client::client %r requested %r"%(client_id, msg_type))
493 496 handler = self.query_handlers.get(msg_type, None)
494 497 try:
495 498 assert handler is not None, "Bad Message Type: %r"%msg_type
496 499 except:
497 500 content = error.wrap_exception()
498 501 self.log.error("Bad Message Type: %r"%msg_type, exc_info=True)
499 502 self.session.send(self.query, "hub_error", ident=client_id,
500 503 content=content)
501 504 return
502 505
503 506 else:
504 507 handler(idents, msg)
505 508
506 509 def dispatch_db(self, msg):
507 510 """"""
508 511 raise NotImplementedError
509 512
510 513 #---------------------------------------------------------------------------
511 514 # handler methods (1 per event)
512 515 #---------------------------------------------------------------------------
513 516
514 517 #----------------------- Heartbeat --------------------------------------
515 518
516 519 def handle_new_heart(self, heart):
517 520 """handler to attach to heartbeater.
518 521 Called when a new heart starts to beat.
519 522 Triggers completion of registration."""
520 523 self.log.debug("heartbeat::handle_new_heart(%r)"%heart)
521 524 if heart not in self.incoming_registrations:
522 525 self.log.info("heartbeat::ignoring new heart: %r"%heart)
523 526 else:
524 527 self.finish_registration(heart)
525 528
526 529
527 530 def handle_heart_failure(self, heart):
528 531 """handler to attach to heartbeater.
529 532 called when a previously registered heart fails to respond to beat request.
530 533 triggers unregistration"""
531 534 self.log.debug("heartbeat::handle_heart_failure(%r)"%heart)
532 535 eid = self.hearts.get(heart, None)
533 536 if eid is None:
534 537 self.log.info("heartbeat::ignoring heart failure %r"%heart)
535 538 else:
536 539 queue = self.engines[eid].queue
537 540 self.unregister_engine(heart, dict(content=dict(id=eid, queue=queue)))
538 541
539 542 #----------------------- MUX Queue Traffic ------------------------------
540 543
541 544 def save_queue_request(self, idents, msg):
542 545 if len(idents) < 2:
543 546 self.log.error("invalid identity prefix: %r"%idents)
544 547 return
545 548 queue_id, client_id = idents[:2]
546 549 try:
547 550 msg = self.session.unpack_message(msg, content=False)
548 551 except Exception:
549 552 self.log.error("queue::client %r sent invalid message to %r: %r"%(client_id, queue_id, msg), exc_info=True)
550 553 return
551 554
552 555 eid = self.by_ident.get(queue_id, None)
553 556 if eid is None:
554 557 self.log.error("queue::target %r not registered"%queue_id)
555 558 self.log.debug("queue:: valid are: %r"%(self.by_ident.keys()))
556 559 return
557 560
558 561 header = msg['header']
559 562 msg_id = header['msg_id']
560 563 record = init_record(msg)
561 564 record['engine_uuid'] = queue_id
562 565 record['client_uuid'] = client_id
563 566 record['queue'] = 'mux'
564 567
565 568 try:
566 569 # it's possible iopub arrived first:
567 570 existing = self.db.get_record(msg_id)
568 571 for key,evalue in existing.iteritems():
569 572 rvalue = record.get(key, None)
570 573 if evalue and rvalue and evalue != rvalue:
571 574 self.log.warn("conflicting initial state for record: %r:%r <%r> %r"%(msg_id, rvalue, key, evalue))
572 575 elif evalue and not rvalue:
573 576 record[key] = evalue
574 577 self.db.update_record(msg_id, record)
575 578 except KeyError:
576 579 self.db.add_record(msg_id, record)
577 580
578 581 self.pending.add(msg_id)
579 582 self.queues[eid].append(msg_id)
580 583
581 584 def save_queue_result(self, idents, msg):
582 585 if len(idents) < 2:
583 586 self.log.error("invalid identity prefix: %r"%idents)
584 587 return
585 588
586 589 client_id, queue_id = idents[:2]
587 590 try:
588 591 msg = self.session.unpack_message(msg, content=False)
589 592 except Exception:
590 593 self.log.error("queue::engine %r sent invalid message to %r: %r"%(
591 594 queue_id,client_id, msg), exc_info=True)
592 595 return
593 596
594 597 eid = self.by_ident.get(queue_id, None)
595 598 if eid is None:
596 599 self.log.error("queue::unknown engine %r is sending a reply: "%queue_id)
597 600 return
598 601
599 602 parent = msg['parent_header']
600 603 if not parent:
601 604 return
602 605 msg_id = parent['msg_id']
603 606 if msg_id in self.pending:
604 607 self.pending.remove(msg_id)
605 608 self.all_completed.add(msg_id)
606 609 self.queues[eid].remove(msg_id)
607 610 self.completed[eid].append(msg_id)
608 611 elif msg_id not in self.all_completed:
609 612 # it could be a result from a dead engine that died before delivering the
610 613 # result
611 614 self.log.warn("queue:: unknown msg finished %r"%msg_id)
612 615 return
613 616 # update record anyway, because the unregistration could have been premature
614 617 rheader = extract_dates(msg['header'])
615 618 completed = rheader['date']
616 619 started = rheader.get('started', None)
617 620 result = {
618 621 'result_header' : rheader,
619 622 'result_content': msg['content'],
620 623 'started' : started,
621 624 'completed' : completed
622 625 }
623 626
624 627 result['result_buffers'] = msg['buffers']
625 628 try:
626 629 self.db.update_record(msg_id, result)
627 630 except Exception:
628 631 self.log.error("DB Error updating record %r"%msg_id, exc_info=True)
629 632
630 633
631 634 #--------------------- Task Queue Traffic ------------------------------
632 635
633 636 def save_task_request(self, idents, msg):
634 637 """Save the submission of a task."""
635 638 client_id = idents[0]
636 639
637 640 try:
638 641 msg = self.session.unpack_message(msg, content=False)
639 642 except Exception:
640 643 self.log.error("task::client %r sent invalid task message: %r"%(
641 644 client_id, msg), exc_info=True)
642 645 return
643 646 record = init_record(msg)
644 647
645 648 record['client_uuid'] = client_id
646 649 record['queue'] = 'task'
647 650 header = msg['header']
648 651 msg_id = header['msg_id']
649 652 self.pending.add(msg_id)
650 653 self.unassigned.add(msg_id)
651 654 try:
652 655 # it's possible iopub arrived first:
653 656 existing = self.db.get_record(msg_id)
654 657 if existing['resubmitted']:
655 658 for key in ('submitted', 'client_uuid', 'buffers'):
656 659 # don't clobber these keys on resubmit
657 660 # submitted and client_uuid should be different
658 661 # and buffers might be big, and shouldn't have changed
659 662 record.pop(key)
660 663 # still check content,header which should not change
661 664 # but are not expensive to compare as buffers
662 665
663 666 for key,evalue in existing.iteritems():
664 667 if key.endswith('buffers'):
665 668 # don't compare buffers
666 669 continue
667 670 rvalue = record.get(key, None)
668 671 if evalue and rvalue and evalue != rvalue:
669 672 self.log.warn("conflicting initial state for record: %r:%r <%r> %r"%(msg_id, rvalue, key, evalue))
670 673 elif evalue and not rvalue:
671 674 record[key] = evalue
672 675 self.db.update_record(msg_id, record)
673 676 except KeyError:
674 677 self.db.add_record(msg_id, record)
675 678 except Exception:
676 679 self.log.error("DB Error saving task request %r"%msg_id, exc_info=True)
677 680
678 681 def save_task_result(self, idents, msg):
679 682 """save the result of a completed task."""
680 683 client_id = idents[0]
681 684 try:
682 685 msg = self.session.unpack_message(msg, content=False)
683 686 except Exception:
684 687 self.log.error("task::invalid task result message send to %r: %r"%(
685 688 client_id, msg), exc_info=True)
686 689 return
687 690
688 691 parent = msg['parent_header']
689 692 if not parent:
690 693 # print msg
691 694 self.log.warn("Task %r had no parent!"%msg)
692 695 return
693 696 msg_id = parent['msg_id']
694 697 if msg_id in self.unassigned:
695 698 self.unassigned.remove(msg_id)
696 699
697 700 header = extract_dates(msg['header'])
698 701 engine_uuid = header.get('engine', None)
699 702 eid = self.by_ident.get(engine_uuid, None)
700 703
701 704 if msg_id in self.pending:
702 705 self.pending.remove(msg_id)
703 706 self.all_completed.add(msg_id)
704 707 if eid is not None:
705 708 self.completed[eid].append(msg_id)
706 709 if msg_id in self.tasks[eid]:
707 710 self.tasks[eid].remove(msg_id)
708 711 completed = header['date']
709 712 started = header.get('started', None)
710 713 result = {
711 714 'result_header' : header,
712 715 'result_content': msg['content'],
713 716 'started' : started,
714 717 'completed' : completed,
715 718 'engine_uuid': engine_uuid
716 719 }
717 720
718 721 result['result_buffers'] = msg['buffers']
719 722 try:
720 723 self.db.update_record(msg_id, result)
721 724 except Exception:
722 725 self.log.error("DB Error saving task request %r"%msg_id, exc_info=True)
723 726
724 727 else:
725 728 self.log.debug("task::unknown task %r finished"%msg_id)
726 729
727 730 def save_task_destination(self, idents, msg):
728 731 try:
729 732 msg = self.session.unpack_message(msg, content=True)
730 733 except Exception:
731 734 self.log.error("task::invalid task tracking message", exc_info=True)
732 735 return
733 736 content = msg['content']
734 737 # print (content)
735 738 msg_id = content['msg_id']
736 739 engine_uuid = content['engine_id']
737 740 eid = self.by_ident[engine_uuid]
738 741
739 742 self.log.info("task::task %r arrived on %r"%(msg_id, eid))
740 743 if msg_id in self.unassigned:
741 744 self.unassigned.remove(msg_id)
742 745 # else:
743 746 # self.log.debug("task::task %r not listed as MIA?!"%(msg_id))
744 747
745 748 self.tasks[eid].append(msg_id)
746 749 # self.pending[msg_id][1].update(received=datetime.now(),engine=(eid,engine_uuid))
747 750 try:
748 751 self.db.update_record(msg_id, dict(engine_uuid=engine_uuid))
749 752 except Exception:
750 753 self.log.error("DB Error saving task destination %r"%msg_id, exc_info=True)
751 754
752 755
753 756 def mia_task_request(self, idents, msg):
754 757 raise NotImplementedError
755 758 client_id = idents[0]
756 759 # content = dict(mia=self.mia,status='ok')
757 760 # self.session.send('mia_reply', content=content, idents=client_id)
758 761
759 762
760 763 #--------------------- IOPub Traffic ------------------------------
761 764
762 765 def save_iopub_message(self, topics, msg):
763 766 """save an iopub message into the db"""
764 767 # print (topics)
765 768 try:
766 769 msg = self.session.unpack_message(msg, content=True)
767 770 except Exception:
768 771 self.log.error("iopub::invalid IOPub message", exc_info=True)
769 772 return
770 773
771 774 parent = msg['parent_header']
772 775 if not parent:
773 776 self.log.error("iopub::invalid IOPub message: %r"%msg)
774 777 return
775 778 msg_id = parent['msg_id']
776 779 msg_type = msg['msg_type']
777 780 content = msg['content']
778 781
779 782 # ensure msg_id is in db
780 783 try:
781 784 rec = self.db.get_record(msg_id)
782 785 except KeyError:
783 786 rec = empty_record()
784 787 rec['msg_id'] = msg_id
785 788 self.db.add_record(msg_id, rec)
786 789 # stream
787 790 d = {}
788 791 if msg_type == 'stream':
789 792 name = content['name']
790 793 s = rec[name] or ''
791 794 d[name] = s + content['data']
792 795
793 796 elif msg_type == 'pyerr':
794 797 d['pyerr'] = content
795 798 elif msg_type == 'pyin':
796 799 d['pyin'] = content['code']
797 800 else:
798 801 d[msg_type] = content.get('data', '')
799 802
800 803 try:
801 804 self.db.update_record(msg_id, d)
802 805 except Exception:
803 806 self.log.error("DB Error saving iopub message %r"%msg_id, exc_info=True)
804 807
805 808
806 809
807 810 #-------------------------------------------------------------------------
808 811 # Registration requests
809 812 #-------------------------------------------------------------------------
810 813
811 814 def connection_request(self, client_id, msg):
812 815 """Reply with connection addresses for clients."""
813 816 self.log.info("client::client %r connected"%client_id)
814 817 content = dict(status='ok')
815 818 content.update(self.client_info)
816 819 jsonable = {}
817 820 for k,v in self.keytable.iteritems():
818 821 if v not in self.dead_engines:
819 822 jsonable[str(k)] = v
820 823 content['engines'] = jsonable
821 824 self.session.send(self.query, 'connection_reply', content, parent=msg, ident=client_id)
822 825
823 826 def register_engine(self, reg, msg):
824 827 """Register a new engine."""
825 828 content = msg['content']
826 829 try:
827 830 queue = content['queue']
828 831 except KeyError:
829 832 self.log.error("registration::queue not specified", exc_info=True)
830 833 return
831 834 heart = content.get('heartbeat', None)
832 835 """register a new engine, and create the socket(s) necessary"""
833 836 eid = self._next_id
834 837 # print (eid, queue, reg, heart)
835 838
836 839 self.log.debug("registration::register_engine(%i, %r, %r, %r)"%(eid, queue, reg, heart))
837 840
838 841 content = dict(id=eid,status='ok')
839 842 content.update(self.engine_info)
840 843 # check if requesting available IDs:
841 844 if queue in self.by_ident:
842 845 try:
843 846 raise KeyError("queue_id %r in use"%queue)
844 847 except:
845 848 content = error.wrap_exception()
846 849 self.log.error("queue_id %r in use"%queue, exc_info=True)
847 850 elif heart in self.hearts: # need to check unique hearts?
848 851 try:
849 852 raise KeyError("heart_id %r in use"%heart)
850 853 except:
851 854 self.log.error("heart_id %r in use"%heart, exc_info=True)
852 855 content = error.wrap_exception()
853 856 else:
854 857 for h, pack in self.incoming_registrations.iteritems():
855 858 if heart == h:
856 859 try:
857 860 raise KeyError("heart_id %r in use"%heart)
858 861 except:
859 862 self.log.error("heart_id %r in use"%heart, exc_info=True)
860 863 content = error.wrap_exception()
861 864 break
862 865 elif queue == pack[1]:
863 866 try:
864 867 raise KeyError("queue_id %r in use"%queue)
865 868 except:
866 869 self.log.error("queue_id %r in use"%queue, exc_info=True)
867 870 content = error.wrap_exception()
868 871 break
869 872
870 873 msg = self.session.send(self.query, "registration_reply",
871 874 content=content,
872 875 ident=reg)
873 876
874 877 if content['status'] == 'ok':
875 878 if heart in self.heartmonitor.hearts:
876 879 # already beating
877 880 self.incoming_registrations[heart] = (eid,queue,reg[0],None)
878 881 self.finish_registration(heart)
879 882 else:
880 883 purge = lambda : self._purge_stalled_registration(heart)
881 884 dc = ioloop.DelayedCallback(purge, self.registration_timeout, self.loop)
882 885 dc.start()
883 886 self.incoming_registrations[heart] = (eid,queue,reg[0],dc)
884 887 else:
885 888 self.log.error("registration::registration %i failed: %r"%(eid, content['evalue']))
886 889 return eid
887 890
888 891 def unregister_engine(self, ident, msg):
889 892 """Unregister an engine that explicitly requested to leave."""
890 893 try:
891 894 eid = msg['content']['id']
892 895 except:
893 896 self.log.error("registration::bad engine id for unregistration: %r"%ident, exc_info=True)
894 897 return
895 898 self.log.info("registration::unregister_engine(%r)"%eid)
896 899 # print (eid)
897 900 uuid = self.keytable[eid]
898 901 content=dict(id=eid, queue=uuid)
899 902 self.dead_engines.add(uuid)
900 903 # self.ids.remove(eid)
901 904 # uuid = self.keytable.pop(eid)
902 905 #
903 906 # ec = self.engines.pop(eid)
904 907 # self.hearts.pop(ec.heartbeat)
905 908 # self.by_ident.pop(ec.queue)
906 909 # self.completed.pop(eid)
907 910 handleit = lambda : self._handle_stranded_msgs(eid, uuid)
908 911 dc = ioloop.DelayedCallback(handleit, self.registration_timeout, self.loop)
909 912 dc.start()
910 913 ############## TODO: HANDLE IT ################
911 914
912 915 if self.notifier:
913 916 self.session.send(self.notifier, "unregistration_notification", content=content)
914 917
915 918 def _handle_stranded_msgs(self, eid, uuid):
916 919 """Handle messages known to be on an engine when the engine unregisters.
917 920
918 921 It is possible that this will fire prematurely - that is, an engine will
919 922 go down after completing a result, and the client will be notified
920 923 that the result failed and later receive the actual result.
921 924 """
922 925
923 926 outstanding = self.queues[eid]
924 927
925 928 for msg_id in outstanding:
926 929 self.pending.remove(msg_id)
927 930 self.all_completed.add(msg_id)
928 931 try:
929 932 raise error.EngineError("Engine %r died while running task %r"%(eid, msg_id))
930 933 except:
931 934 content = error.wrap_exception()
932 935 # build a fake header:
933 936 header = {}
934 937 header['engine'] = uuid
935 938 header['date'] = datetime.now()
936 939 rec = dict(result_content=content, result_header=header, result_buffers=[])
937 940 rec['completed'] = header['date']
938 941 rec['engine_uuid'] = uuid
939 942 try:
940 943 self.db.update_record(msg_id, rec)
941 944 except Exception:
942 945 self.log.error("DB Error handling stranded msg %r"%msg_id, exc_info=True)
943 946
944 947
945 948 def finish_registration(self, heart):
946 949 """Second half of engine registration, called after our HeartMonitor
947 950 has received a beat from the Engine's Heart."""
948 951 try:
949 952 (eid,queue,reg,purge) = self.incoming_registrations.pop(heart)
950 953 except KeyError:
951 954 self.log.error("registration::tried to finish nonexistant registration", exc_info=True)
952 955 return
953 956 self.log.info("registration::finished registering engine %i:%r"%(eid,queue))
954 957 if purge is not None:
955 958 purge.stop()
956 959 control = queue
957 960 self.ids.add(eid)
958 961 self.keytable[eid] = queue
959 962 self.engines[eid] = EngineConnector(id=eid, queue=queue, registration=reg,
960 963 control=control, heartbeat=heart)
961 964 self.by_ident[queue] = eid
962 965 self.queues[eid] = list()
963 966 self.tasks[eid] = list()
964 967 self.completed[eid] = list()
965 968 self.hearts[heart] = eid
966 969 content = dict(id=eid, queue=self.engines[eid].queue)
967 970 if self.notifier:
968 971 self.session.send(self.notifier, "registration_notification", content=content)
969 972 self.log.info("engine::Engine Connected: %i"%eid)
970 973
971 974 def _purge_stalled_registration(self, heart):
972 975 if heart in self.incoming_registrations:
973 976 eid = self.incoming_registrations.pop(heart)[0]
974 977 self.log.info("registration::purging stalled registration: %i"%eid)
975 978 else:
976 979 pass
977 980
978 981 #-------------------------------------------------------------------------
979 982 # Client Requests
980 983 #-------------------------------------------------------------------------
981 984
982 985 def shutdown_request(self, client_id, msg):
983 986 """handle shutdown request."""
984 987 self.session.send(self.query, 'shutdown_reply', content={'status': 'ok'}, ident=client_id)
985 988 # also notify other clients of shutdown
986 989 self.session.send(self.notifier, 'shutdown_notice', content={'status': 'ok'})
987 990 dc = ioloop.DelayedCallback(lambda : self._shutdown(), 1000, self.loop)
988 991 dc.start()
989 992
990 993 def _shutdown(self):
991 994 self.log.info("hub::hub shutting down.")
992 995 time.sleep(0.1)
993 996 sys.exit(0)
994 997
995 998
996 999 def check_load(self, client_id, msg):
997 1000 content = msg['content']
998 1001 try:
999 1002 targets = content['targets']
1000 1003 targets = self._validate_targets(targets)
1001 1004 except:
1002 1005 content = error.wrap_exception()
1003 1006 self.session.send(self.query, "hub_error",
1004 1007 content=content, ident=client_id)
1005 1008 return
1006 1009
1007 1010 content = dict(status='ok')
1008 1011 # loads = {}
1009 1012 for t in targets:
1010 1013 content[bytes(t)] = len(self.queues[t])+len(self.tasks[t])
1011 1014 self.session.send(self.query, "load_reply", content=content, ident=client_id)
1012 1015
1013 1016
1014 1017 def queue_status(self, client_id, msg):
1015 1018 """Return the Queue status of one or more targets.
1016 1019 if verbose: return the msg_ids
1017 1020 else: return len of each type.
1018 1021 keys: queue (pending MUX jobs)
1019 1022 tasks (pending Task jobs)
1020 1023 completed (finished jobs from both queues)"""
1021 1024 content = msg['content']
1022 1025 targets = content['targets']
1023 1026 try:
1024 1027 targets = self._validate_targets(targets)
1025 1028 except:
1026 1029 content = error.wrap_exception()
1027 1030 self.session.send(self.query, "hub_error",
1028 1031 content=content, ident=client_id)
1029 1032 return
1030 1033 verbose = content.get('verbose', False)
1031 1034 content = dict(status='ok')
1032 1035 for t in targets:
1033 1036 queue = self.queues[t]
1034 1037 completed = self.completed[t]
1035 1038 tasks = self.tasks[t]
1036 1039 if not verbose:
1037 1040 queue = len(queue)
1038 1041 completed = len(completed)
1039 1042 tasks = len(tasks)
1040 1043 content[bytes(t)] = {'queue': queue, 'completed': completed , 'tasks': tasks}
1041 1044 content['unassigned'] = list(self.unassigned) if verbose else len(self.unassigned)
1042 1045
1043 1046 self.session.send(self.query, "queue_reply", content=content, ident=client_id)
1044 1047
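# Illustrative shape of a non-verbose reply content (engine ids are the keys):
#   {'status': 'ok', '0': {'queue': 2, 'completed': 5, 'tasks': 1}, 'unassigned': 0}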
1045 1048 def purge_results(self, client_id, msg):
1046 1049 """Purge results from memory. This method is more valuable before we move
1047 1050 to a DB based message storage mechanism."""
1048 1051 content = msg['content']
1049 1052 msg_ids = content.get('msg_ids', [])
1050 1053 reply = dict(status='ok')
1051 1054 if msg_ids == 'all':
1052 1055 try:
1053 1056 self.db.drop_matching_records(dict(completed={'$ne':None}))
1054 1057 except Exception:
1055 1058 reply = error.wrap_exception()
1056 1059 else:
1057 1060 pending = filter(lambda m: m in self.pending, msg_ids)
1058 1061 if pending:
1059 1062 try:
1060 1063 raise IndexError("msg pending: %r"%pending[0])
1061 1064 except:
1062 1065 reply = error.wrap_exception()
1063 1066 else:
1064 1067 try:
1065 1068 self.db.drop_matching_records(dict(msg_id={'$in':msg_ids}))
1066 1069 except Exception:
1067 1070 reply = error.wrap_exception()
1068 1071
1069 1072 if reply['status'] == 'ok':
1070 1073 eids = content.get('engine_ids', [])
1071 1074 for eid in eids:
1072 1075 if eid not in self.engines:
1073 1076 try:
1074 1077 raise IndexError("No such engine: %i"%eid)
1075 1078 except:
1076 1079 reply = error.wrap_exception()
1077 1080 break
1078 1081 msg_ids = self.completed.pop(eid)
1079 1082 uid = self.engines[eid].queue
1080 1083 try:
1081 1084 self.db.drop_matching_records(dict(engine_uuid=uid, completed={'$ne':None}))
1082 1085 except Exception:
1083 1086 reply = error.wrap_exception()
1084 1087 break
1085 1088
1086 1089 self.session.send(self.query, 'purge_reply', content=reply, ident=client_id)
1087 1090
1088 1091 def resubmit_task(self, client_id, msg):
1089 1092 """Resubmit one or more tasks."""
1090 1093 def finish(reply):
1091 1094 self.session.send(self.query, 'resubmit_reply', content=reply, ident=client_id)
1092 1095
1093 1096 content = msg['content']
1094 1097 msg_ids = content['msg_ids']
1095 1098 reply = dict(status='ok')
1096 1099 try:
1097 1100 records = self.db.find_records({'msg_id' : {'$in' : msg_ids}}, keys=[
1098 1101 'header', 'content', 'buffers'])
1099 1102 except Exception:
1100 1103 self.log.error('db::db error finding tasks to resubmit', exc_info=True)
1101 1104 return finish(error.wrap_exception())
1102 1105
1103 1106 # validate msg_ids
1104 1107 found_ids = [ rec['msg_id'] for rec in records ]
1105 1108 invalid_ids = filter(lambda m: m in self.pending, found_ids)
1106 1109 if len(records) > len(msg_ids):
1107 1110 try:
1108 1111 raise RuntimeError("DB appears to be in an inconsistent state. "
1109 1112 "More matching records were found than should exist")
1110 1113 except Exception:
1111 1114 return finish(error.wrap_exception())
1112 1115 elif len(records) < len(msg_ids):
1113 1116 missing = [ m for m in msg_ids if m not in found_ids ]
1114 1117 try:
1115 1118 raise KeyError("No such msg(s): %r"%missing)
1116 1119 except KeyError:
1117 1120 return finish(error.wrap_exception())
1118 1121 elif invalid_ids:
1119 1122 msg_id = invalid_ids[0]
1120 1123 try:
1121 1124 raise ValueError("Task %r appears to be inflight"%(msg_id))
1122 1125 except Exception:
1123 1126 return finish(error.wrap_exception())
1124 1127
1125 1128 # clear the existing records
1126 1129 now = datetime.now()
1127 1130 rec = empty_record()
1128 1131 map(rec.pop, ['msg_id', 'header', 'content', 'buffers', 'submitted'])
1129 1132 rec['resubmitted'] = now
1130 1133 rec['queue'] = 'task'
1131 1134 rec['client_uuid'] = client_id[0]
1132 1135 try:
1133 1136 for msg_id in msg_ids:
1134 1137 self.all_completed.discard(msg_id)
1135 1138 self.db.update_record(msg_id, rec)
1136 1139 except Exception:
1137 1140 self.log.error('db::db error updating record', exc_info=True)
1138 1141 reply = error.wrap_exception()
1139 1142 else:
1140 1143 # send the messages
1141 1144 now_s = now.strftime(ISO8601)
1142 1145 for rec in records:
1143 1146 header = rec['header']
1144 1147 # include resubmitted in header to prevent digest collision
1145 1148 header['resubmitted'] = now_s
1146 1149 msg = self.session.msg(header['msg_type'])
1147 1150 msg['content'] = rec['content']
1148 1151 msg['header'] = header
1149 1152 msg['msg_id'] = rec['msg_id']
1150 1153 self.session.send(self.resubmit, msg, buffers=rec['buffers'])
1151 1154
1152 1155 finish(dict(status='ok'))
1153 1156
1154 1157
1155 1158 def _extract_record(self, rec):
1156 1159 """decompose a TaskRecord dict into subsection of reply for get_result"""
1157 1160 io_dict = {}
1158 1161 for key in 'pyin pyout pyerr stdout stderr'.split():
1159 1162 io_dict[key] = rec[key]
1160 1163 content = { 'result_content': rec['result_content'],
1161 1164 'header': rec['header'],
1162 1165 'result_header' : rec['result_header'],
1163 1166 'io' : io_dict,
1164 1167 }
1165 1168 if rec['result_buffers']:
1166 1169 buffers = map(str, rec['result_buffers'])
1167 1170 else:
1168 1171 buffers = []
1169 1172
1170 1173 return content, buffers
1171 1174
1172 1175 def get_results(self, client_id, msg):
1173 1176 """Get the result of 1 or more messages."""
1174 1177 content = msg['content']
1175 1178 msg_ids = sorted(set(content['msg_ids']))
1176 1179 statusonly = content.get('status_only', False)
1177 1180 pending = []
1178 1181 completed = []
1179 1182 content = dict(status='ok')
1180 1183 content['pending'] = pending
1181 1184 content['completed'] = completed
1182 1185 buffers = []
1183 1186 if not statusonly:
1184 1187 try:
1185 1188 matches = self.db.find_records(dict(msg_id={'$in':msg_ids}))
1186 1189 # turn match list into dict, for faster lookup
1187 1190 records = {}
1188 1191 for rec in matches:
1189 1192 records[rec['msg_id']] = rec
1190 1193 except Exception:
1191 1194 content = error.wrap_exception()
1192 1195 self.session.send(self.query, "result_reply", content=content,
1193 1196 parent=msg, ident=client_id)
1194 1197 return
1195 1198 else:
1196 1199 records = {}
1197 1200 for msg_id in msg_ids:
1198 1201 if msg_id in self.pending:
1199 1202 pending.append(msg_id)
1200 1203 elif msg_id in self.all_completed:
1201 1204 completed.append(msg_id)
1202 1205 if not statusonly:
1203 1206 c,bufs = self._extract_record(records[msg_id])
1204 1207 content[msg_id] = c
1205 1208 buffers.extend(bufs)
1206 1209 elif msg_id in records:
1207 1210 if records[msg_id]['completed']:
1208 1211 completed.append(msg_id)
1209 1212 c,bufs = self._extract_record(records[msg_id])
1210 1213 content[msg_id] = c
1211 1214 buffers.extend(bufs)
1212 1215 else:
1213 1216 pending.append(msg_id)
1214 1217 else:
1215 1218 try:
1216 1219 raise KeyError('No such message: '+msg_id)
1217 1220 except:
1218 1221 content = error.wrap_exception()
1219 1222 break
1220 1223 self.session.send(self.query, "result_reply", content=content,
1221 1224 parent=msg, ident=client_id,
1222 1225 buffers=buffers)
1223 1226
1224 1227 def get_history(self, client_id, msg):
1225 1228 """Get a list of all msg_ids in our DB records"""
1226 1229 try:
1227 1230 msg_ids = self.db.get_history()
1228 1231 except Exception as e:
1229 1232 content = error.wrap_exception()
1230 1233 else:
1231 1234 content = dict(status='ok', history=msg_ids)
1232 1235
1233 1236 self.session.send(self.query, "history_reply", content=content,
1234 1237 parent=msg, ident=client_id)
1235 1238
1236 1239 def db_query(self, client_id, msg):
1237 1240 """Perform a raw query on the task record database."""
1238 1241 content = msg['content']
1239 1242 query = content.get('query', {})
1240 1243 keys = content.get('keys', None)
1241 1244 query = util.extract_dates(query)
1242 1245 buffers = []
1243 1246 empty = list()
1244 1247
1245 1248 try:
1246 1249 records = self.db.find_records(query, keys)
1247 1250 except Exception as e:
1248 1251 content = error.wrap_exception()
1249 1252 else:
1250 1253 # extract buffers from reply content:
1251 1254 if keys is not None:
1252 1255 buffer_lens = [] if 'buffers' in keys else None
1253 1256 result_buffer_lens = [] if 'result_buffers' in keys else None
1254 1257 else:
1255 1258 buffer_lens = []
1256 1259 result_buffer_lens = []
1257 1260
1258 1261 for rec in records:
1259 1262 # buffers may be None, so double check
1260 1263 if buffer_lens is not None:
1261 1264 b = rec.pop('buffers', empty) or empty
1262 1265 buffer_lens.append(len(b))
1263 1266 buffers.extend(b)
1264 1267 if result_buffer_lens is not None:
1265 1268 rb = rec.pop('result_buffers', empty) or empty
1266 1269 result_buffer_lens.append(len(rb))
1267 1270 buffers.extend(rb)
1268 1271 content = dict(status='ok', records=records, buffer_lens=buffer_lens,
1269 1272 result_buffer_lens=result_buffer_lens)
1270 1273
1271 1274 self.session.send(self.query, "db_reply", content=content,
1272 1275 parent=msg, ident=client_id,
1273 1276 buffers=buffers)
1274 1277
@@ -1,687 +1,688 b''
1 1 """The Python scheduler for rich scheduling.
2 2
3 3 The Pure ZMQ scheduler does not allow routing schemes other than LRU,
4 4 nor does it check msg_id DAG dependencies. For those, a slightly slower
5 5 Python Scheduler exists.
6 6 """
7 7 #-----------------------------------------------------------------------------
8 8 # Copyright (C) 2010-2011 The IPython Development Team
9 9 #
10 10 # Distributed under the terms of the BSD License. The full license is in
11 11 # the file COPYING, distributed as part of this software.
12 12 #-----------------------------------------------------------------------------
13 13
14 14 #----------------------------------------------------------------------
15 15 # Imports
16 16 #----------------------------------------------------------------------
17 17
18 18 from __future__ import print_function
19 19
20 20 import logging
21 21 import sys
22 22
23 23 from datetime import datetime, timedelta
24 24 from random import randint, random
25 25 from types import FunctionType
26 26
27 27 try:
28 28 import numpy
29 29 except ImportError:
30 30 numpy = None
31 31
32 32 import zmq
33 33 from zmq.eventloop import ioloop, zmqstream
34 34
35 35 # local imports
36 36 from IPython.external.decorator import decorator
37 37 from IPython.config.loader import Config
38 38 from IPython.utils.traitlets import Instance, Dict, List, Set, Int, Str, Enum
39 39
40 40 from IPython.parallel import error
41 41 from IPython.parallel.factory import SessionFactory
42 42 from IPython.parallel.util import connect_logger, local_logger
43 43
44 44 from .dependency import Dependency
45 45
46 46 @decorator
47 47 def logged(f,self,*args,**kwargs):
48 48 # print ("#--------------------")
49 49 self.log.debug("scheduler::%s(*%s,**%s)"%(f.func_name, args, kwargs))
50 50 # print ("#--")
51 51 return f(self,*args, **kwargs)
52 52
53 53 #----------------------------------------------------------------------
54 54 # Chooser functions
55 55 #----------------------------------------------------------------------
56 56
57 57 def plainrandom(loads):
58 58 """Plain random pick."""
59 59 n = len(loads)
60 60 return randint(0,n-1)
61 61
62 62 def lru(loads):
63 63 """Always pick the front of the line.
64 64
65 65 The content of `loads` is ignored.
66 66
67 67 Assumes LRU ordering of loads, with oldest first.
68 68 """
69 69 return 0
70 70
71 71 def twobin(loads):
72 72 """Pick two at random, use the LRU of the two.
73 73
74 74 The content of loads is ignored.
75 75
76 76 Assumes LRU ordering of loads, with oldest first.
77 77 """
78 78 n = len(loads)
79 79 a = randint(0,n-1)
80 80 b = randint(0,n-1)
81 81 return min(a,b)
82 82
83 83 def weighted(loads):
84 84 """Pick two at random using inverse load as weight.
85 85
86 86 Return the less loaded of the two.
87 87 """
88 88 # weight 0 a million times more than 1:
89 89 weights = 1./(1e-6+numpy.array(loads))
90 90 sums = weights.cumsum()
91 91 t = sums[-1]
92 92 x = random()*t
93 93 y = random()*t
94 94 idx = 0
95 95 idy = 0
96 96 while sums[idx] < x:
97 97 idx += 1
98 98 while sums[idy] < y:
99 99 idy += 1
100 100 if weights[idy] > weights[idx]:
101 101 return idy
102 102 else:
103 103 return idx
104 104
105 105 def leastload(loads):
106 106 """Always choose the lowest load.
107 107
108 108 If the lowest load occurs more than once, the first
108 108 occurrence will be used. If loads has LRU ordering, this means
110 110 the LRU of those with the lowest load is chosen.
111 111 """
112 112 return loads.index(min(loads))
113 113
114 114 #---------------------------------------------------------------------
115 115 # Classes
116 116 #---------------------------------------------------------------------
117 117 # store empty default dependency:
118 118 MET = Dependency([])
119 119
120 120 class TaskScheduler(SessionFactory):
121 121 """Python TaskScheduler object.
122 122
123 123 This is the simplest object that supports msg_id based
124 124 DAG dependencies. *Only* task msg_ids are checked, not
125 125 msg_ids of jobs submitted via the MUX queue.
126 126
127 127 """
128 128
129 129 hwm = Int(0, config=True, shortname='hwm',
130 130 help="""specify the High Water Mark (HWM) for the downstream
131 131 socket in the Task scheduler. This is the maximum number
132 132 of allowed outstanding tasks on each engine."""
133 133 )
134 134 scheme_name = Enum(('leastload', 'pure', 'lru', 'plainrandom', 'weighted', 'twobin'),
135 135 'leastload', config=True, shortname='scheme', allow_none=False,
136 136 help="""select the task scheduler scheme [default: Python LRU]
137 137 Options are: 'pure', 'lru', 'plainrandom', 'weighted', 'twobin','leastload'"""
138 138 )
139 139 def _scheme_name_changed(self, old, new):
140 140 self.log.debug("Using scheme %r"%new)
141 141 self.scheme = globals()[new]
142 142
143 143 # input arguments:
144 144 scheme = Instance(FunctionType) # function for determining the destination
145 145 def _scheme_default(self):
146 146 return leastload
147 147 client_stream = Instance(zmqstream.ZMQStream) # client-facing stream
148 148 engine_stream = Instance(zmqstream.ZMQStream) # engine-facing stream
149 149 notifier_stream = Instance(zmqstream.ZMQStream) # hub-facing sub stream
150 150 mon_stream = Instance(zmqstream.ZMQStream) # hub-facing pub stream
151 151
152 152 # internals:
153 153 graph = Dict() # dict by msg_id of [ msg_ids that depend on key ]
154 154 retries = Dict() # dict by msg_id of retries remaining (non-neg ints)
155 155 # waiting = List() # list of msg_ids ready to run, but haven't due to HWM
156 156 depending = Dict() # dict by msg_id of (msg_id, raw_msg, after, follow)
157 157 pending = Dict() # dict by engine_uuid of submitted tasks
158 158 completed = Dict() # dict by engine_uuid of completed tasks
159 159 failed = Dict() # dict by engine_uuid of failed tasks
160 160 destinations = Dict() # dict by msg_id of engine_uuids where jobs ran (reverse of completed+failed)
161 161 clients = Dict() # dict by msg_id for who submitted the task
162 162 targets = List() # list of target IDENTs
163 163 loads = List() # list of engine loads
164 164 # full = Set() # set of IDENTs that have HWM outstanding tasks
165 165 all_completed = Set() # set of all completed tasks
166 166 all_failed = Set() # set of all failed tasks
167 167 all_done = Set() # set of all finished tasks=union(completed,failed)
168 168 all_ids = Set() # set of all submitted task IDs
169 169 blacklist = Dict() # dict by msg_id of locations where a job has encountered UnmetDependency
170 170 auditor = Instance('zmq.eventloop.ioloop.PeriodicCallback')
171 171
172 172
173 173 def start(self):
174 174 self.engine_stream.on_recv(self.dispatch_result, copy=False)
175 175 self._notification_handlers = dict(
176 176 registration_notification = self._register_engine,
177 177 unregistration_notification = self._unregister_engine
178 178 )
179 179 self.notifier_stream.on_recv(self.dispatch_notification)
180 180 self.auditor = ioloop.PeriodicCallback(self.audit_timeouts, 2e3, self.loop) # every 2 seconds
181 181 self.auditor.start()
182 182 self.log.info("Scheduler started...%r"%self)
183 183
184 184 def resume_receiving(self):
185 185 """Resume accepting jobs."""
186 186 self.client_stream.on_recv(self.dispatch_submission, copy=False)
187 187
188 188 def stop_receiving(self):
189 189 """Stop accepting jobs while there are no engines.
190 190 Leave them in the ZMQ queue."""
191 191 self.client_stream.on_recv(None)
192 192
193 193 #-----------------------------------------------------------------------
194 194 # [Un]Registration Handling
195 195 #-----------------------------------------------------------------------
196 196
197 197 def dispatch_notification(self, msg):
198 198 """dispatch register/unregister events."""
199 199 try:
200 200 idents,msg = self.session.feed_identities(msg)
201 201 except ValueError:
202 202 self.log.warn("task::Invalid Message: %r"%msg)
203 203 return
204 204 try:
205 205 msg = self.session.unpack_message(msg)
206 206 except ValueError:
207 207 self.log.warn("task::Unauthorized message from: %r"%idents)
208 208 return
209 209
210 210 msg_type = msg['msg_type']
211 211
212 212 handler = self._notification_handlers.get(msg_type, None)
213 213 if handler is None:
214 214 self.log.error("Unhandled message type: %r"%msg_type)
215 215 else:
216 216 try:
217 217 handler(str(msg['content']['queue']))
218 218 except KeyError:
219 219 self.log.error("task::Invalid notification msg: %r"%msg)
220 220
221 221 @logged
222 222 def _register_engine(self, uid):
223 223 """New engine with ident `uid` became available."""
224 224 # head of the line:
225 225 self.targets.insert(0,uid)
226 226 self.loads.insert(0,0)
227 227 # initialize sets
228 228 self.completed[uid] = set()
229 229 self.failed[uid] = set()
230 230 self.pending[uid] = {}
231 231 if len(self.targets) == 1:
232 232 self.resume_receiving()
233 233 # rescan the graph:
234 234 self.update_graph(None)
235 235
236 236 def _unregister_engine(self, uid):
237 237 """Existing engine with ident `uid` became unavailable."""
238 238 if len(self.targets) == 1:
239 239 # this was our only engine
240 240 self.stop_receiving()
241 241
242 242 # handle any potentially finished tasks:
243 243 self.engine_stream.flush()
244 244
245 245 # don't pop destinations, because they might be used later
246 246 # map(self.destinations.pop, self.completed.pop(uid))
247 247 # map(self.destinations.pop, self.failed.pop(uid))
248 248
249 249 # prevent this engine from receiving work
250 250 idx = self.targets.index(uid)
251 251 self.targets.pop(idx)
252 252 self.loads.pop(idx)
253 253
254 254 # wait 5 seconds before cleaning up pending jobs, since the results might
255 255 # still be incoming
256 256 if self.pending[uid]:
257 257 dc = ioloop.DelayedCallback(lambda : self.handle_stranded_tasks(uid), 5000, self.loop)
258 258 dc.start()
259 259 else:
260 260 self.completed.pop(uid)
261 261 self.failed.pop(uid)
262 262
263 263
264 264 @logged
265 265 def handle_stranded_tasks(self, engine):
266 266 """Deal with jobs resident in an engine that died."""
267 267 lost = self.pending[engine]
268 268 for msg_id in lost.keys():
269 269 if msg_id not in self.pending[engine]:
270 270 # prevent double-handling of messages
271 271 continue
272 272
273 273 raw_msg = lost[msg_id][0]
274 274 idents,msg = self.session.feed_identities(raw_msg, copy=False)
275 275 parent = self.session.unpack(msg[1].bytes)
276 276 idents = [engine, idents[0]]
277 277
278 278 # build fake error reply
279 279 try:
280 280 raise error.EngineError("Engine %r died while running task %r"%(engine, msg_id))
281 281 except:
282 282 content = error.wrap_exception()
283 283 msg = self.session.msg('apply_reply', content, parent=parent, subheader={'status':'error'})
284 284 raw_reply = map(zmq.Message, self.session.serialize(msg, ident=idents))
285 285 # and dispatch it
286 286 self.dispatch_result(raw_reply)
287 287
288 288 # finally scrub completed/failed lists
289 289 self.completed.pop(engine)
290 290 self.failed.pop(engine)
291 291
292 292
293 293 #-----------------------------------------------------------------------
294 294 # Job Submission
295 295 #-----------------------------------------------------------------------
296 296 @logged
297 297 def dispatch_submission(self, raw_msg):
298 298 """Dispatch job submission to appropriate handlers."""
299 299 # ensure targets up to date:
300 300 self.notifier_stream.flush()
301 301 try:
302 302 idents, msg = self.session.feed_identities(raw_msg, copy=False)
303 303 msg = self.session.unpack_message(msg, content=False, copy=False)
304 304 except Exception:
305 305 self.log.error("task::Invaid task msg: %r"%raw_msg, exc_info=True)
306 306 return
307 307
308 308
309 309 # send to monitor
310 310 self.mon_stream.send_multipart(['intask']+raw_msg, copy=False)
311 311
312 312 header = msg['header']
313 313 msg_id = header['msg_id']
314 314 self.all_ids.add(msg_id)
315 315
316 316 # targets
317 317 targets = set(header.get('targets', []))
318 318 retries = header.get('retries', 0)
319 319 self.retries[msg_id] = retries
320 320
321 321 # time dependencies
322 322 after = Dependency(header.get('after', []))
323 323 if after.all:
324 324 if after.success:
325 325 after.difference_update(self.all_completed)
326 326 if after.failure:
327 327 after.difference_update(self.all_failed)
328 328 if after.check(self.all_completed, self.all_failed):
329 329 # recast as empty set, if `after` already met,
330 330 # to prevent unnecessary set comparisons
331 331 after = MET
332 332
333 333 # location dependencies
334 334 follow = Dependency(header.get('follow', []))
335 335
336 336 # turn timeouts into datetime objects:
337 337 timeout = header.get('timeout', None)
338 338 if timeout:
339 339 timeout = datetime.now() + timedelta(0,timeout,0)
340 340
341 341 args = [raw_msg, targets, after, follow, timeout]
342 342
343 343 # validate and reduce dependencies:
344 344 for dep in after,follow:
345 345 # check valid:
346 346 if msg_id in dep or dep.difference(self.all_ids):
347 347 self.depending[msg_id] = args
348 348 return self.fail_unreachable(msg_id, error.InvalidDependency)
349 349 # check if unreachable:
350 350 if dep.unreachable(self.all_completed, self.all_failed):
351 351 self.depending[msg_id] = args
352 352 return self.fail_unreachable(msg_id)
353 353
354 354 if after.check(self.all_completed, self.all_failed):
355 355 # time deps already met, try to run
356 356 if not self.maybe_run(msg_id, *args):
357 357 # can't run yet
358 358 if msg_id not in self.all_failed:
359 359 # could have failed as unreachable
360 360 self.save_unmet(msg_id, *args)
361 361 else:
362 362 self.save_unmet(msg_id, *args)
363 363
364 364 # @logged
365 365 def audit_timeouts(self):
366 366 """Audit all waiting tasks for expired timeouts."""
367 367 now = datetime.now()
368 368 for msg_id in self.depending.keys():
369 369 # must recheck, in case one failure cascaded to another:
370 370 if msg_id in self.depending:
371 371 raw, targets, after, follow, timeout = self.depending[msg_id]
372 372 if timeout and timeout < now:
373 373 self.fail_unreachable(msg_id, error.TaskTimeout)
374 374
375 375 @logged
376 376 def fail_unreachable(self, msg_id, why=error.ImpossibleDependency):
377 377 """a task has become unreachable, send a reply with an ImpossibleDependency
378 378 error."""
379 379 if msg_id not in self.depending:
380 380 self.log.error("msg %r already failed!"%msg_id)
381 381 return
382 382 raw_msg,targets,after,follow,timeout = self.depending.pop(msg_id)
383 383 for mid in follow.union(after):
384 384 if mid in self.graph:
385 385 self.graph[mid].remove(msg_id)
386 386
387 387 # FIXME: unpacking a message I've already unpacked, but didn't save:
388 388 idents,msg = self.session.feed_identities(raw_msg, copy=False)
389 389 header = self.session.unpack(msg[1].bytes)
390 390
391 391 try:
392 392 raise why()
393 393 except:
394 394 content = error.wrap_exception()
395 395
396 396 self.all_done.add(msg_id)
397 397 self.all_failed.add(msg_id)
398 398
399 399 msg = self.session.send(self.client_stream, 'apply_reply', content,
400 400 parent=header, ident=idents)
401 401 self.session.send(self.mon_stream, msg, ident=['outtask']+idents)
402 402
403 403 self.update_graph(msg_id, success=False)
404 404
405 405 @logged
406 406 def maybe_run(self, msg_id, raw_msg, targets, after, follow, timeout):
407 407 """check location dependencies, and run if they are met."""
408 408 blacklist = self.blacklist.setdefault(msg_id, set())
409 409 if follow or targets or blacklist or self.hwm:
410 410 # we need a can_run filter
411 411 def can_run(idx):
412 412 # check hwm
413 413 if self.hwm and self.loads[idx] == self.hwm:
414 414 return False
415 415 target = self.targets[idx]
416 416 # check blacklist
417 417 if target in blacklist:
418 418 return False
419 419 # check targets
420 420 if targets and target not in targets:
421 421 return False
422 422 # check follow
423 423 return follow.check(self.completed[target], self.failed[target])
424 424
425 425 indices = filter(can_run, range(len(self.targets)))
426 426
427 427 if not indices:
428 428 # couldn't run
429 429 if follow.all:
430 430 # check follow for impossibility
431 431 dests = set()
432 432 relevant = set()
433 433 if follow.success:
434 434 relevant = self.all_completed
435 435 if follow.failure:
436 436 relevant = relevant.union(self.all_failed)
437 437 for m in follow.intersection(relevant):
438 438 dests.add(self.destinations[m])
439 439 if len(dests) > 1:
440 440 self.depending[msg_id] = (raw_msg, targets, after, follow, timeout)
441 441 self.fail_unreachable(msg_id)
442 442 return False
443 443 if targets:
444 444 # check blacklist+targets for impossibility
445 445 targets.difference_update(blacklist)
446 446 if not targets or not targets.intersection(self.targets):
447 447 self.depending[msg_id] = (raw_msg, targets, after, follow, timeout)
448 448 self.fail_unreachable(msg_id)
449 449 return False
450 450 return False
451 451 else:
452 452 indices = None
453 453
454 454 self.submit_task(msg_id, raw_msg, targets, follow, timeout, indices)
455 455 return True
456 456
457 457 @logged
458 458 def save_unmet(self, msg_id, raw_msg, targets, after, follow, timeout):
459 459 """Save a message for later submission when its dependencies are met."""
460 460 self.depending[msg_id] = [raw_msg,targets,after,follow,timeout]
461 461 # track the ids in follow or after, but not those already finished
462 462 for dep_id in after.union(follow).difference(self.all_done):
463 463 if dep_id not in self.graph:
464 464 self.graph[dep_id] = set()
465 465 self.graph[dep_id].add(msg_id)
466 466
467 467 @logged
468 468 def submit_task(self, msg_id, raw_msg, targets, follow, timeout, indices=None):
469 469 """Submit a task to any of a subset of our targets."""
470 470 if indices:
471 471 loads = [self.loads[i] for i in indices]
472 472 else:
473 473 loads = self.loads
474 474 idx = self.scheme(loads)
475 475 if indices:
476 476 idx = indices[idx]
477 477 target = self.targets[idx]
478 478 # print (target, map(str, msg[:3]))
479 479 # send job to the engine
480 480 self.engine_stream.send(target, flags=zmq.SNDMORE, copy=False)
481 481 self.engine_stream.send_multipart(raw_msg, copy=False)
482 482 # update load
483 483 self.add_job(idx)
484 484 self.pending[target][msg_id] = (raw_msg, targets, MET, follow, timeout)
485 485 # notify Hub
486 486 content = dict(msg_id=msg_id, engine_id=target)
487 487 self.session.send(self.mon_stream, 'task_destination', content=content,
488 488 ident=['tracktask',self.session.session])
489 489
490 490
491 491 #-----------------------------------------------------------------------
492 492 # Result Handling
493 493 #-----------------------------------------------------------------------
494 494 @logged
495 495 def dispatch_result(self, raw_msg):
496 496 """dispatch method for result replies"""
497 497 try:
498 498 idents,msg = self.session.feed_identities(raw_msg, copy=False)
499 499 msg = self.session.unpack_message(msg, content=False, copy=False)
500 500 engine = idents[0]
501 501 try:
502 502 idx = self.targets.index(engine)
503 503 except ValueError:
504 504 pass # skip load-update for dead engines
505 505 else:
506 506 self.finish_job(idx)
507 507 except Exception:
508 508 self.log.error("task::Invaid result: %r"%raw_msg, exc_info=True)
509 509 return
510 510
511 511 header = msg['header']
512 512 parent = msg['parent_header']
513 513 if header.get('dependencies_met', True):
514 514 success = (header['status'] == 'ok')
515 515 msg_id = parent['msg_id']
516 516 retries = self.retries[msg_id]
517 517 if not success and retries > 0:
518 518 # failed
519 519 self.retries[msg_id] = retries - 1
520 520 self.handle_unmet_dependency(idents, parent)
521 521 else:
522 522 del self.retries[msg_id]
523 523 # relay to client and update graph
524 524 self.handle_result(idents, parent, raw_msg, success)
525 525 # send to Hub monitor
526 526 self.mon_stream.send_multipart(['outtask']+raw_msg, copy=False)
527 527 else:
528 528 self.handle_unmet_dependency(idents, parent)
529 529
530 530 @logged
531 531 def handle_result(self, idents, parent, raw_msg, success=True):
532 532 """handle a real task result, either success or failure"""
533 533 # first, relay result to client
534 534 engine = idents[0]
535 535 client = idents[1]
536 536 # swap_ids for XREP-XREP mirror
537 537 raw_msg[:2] = [client,engine]
538 538 # print (map(str, raw_msg[:4]))
539 539 self.client_stream.send_multipart(raw_msg, copy=False)
540 540 # now, update our data structures
541 541 msg_id = parent['msg_id']
542 542 self.blacklist.pop(msg_id, None)
543 543 self.pending[engine].pop(msg_id)
544 544 if success:
545 545 self.completed[engine].add(msg_id)
546 546 self.all_completed.add(msg_id)
547 547 else:
548 548 self.failed[engine].add(msg_id)
549 549 self.all_failed.add(msg_id)
550 550 self.all_done.add(msg_id)
551 551 self.destinations[msg_id] = engine
552 552
553 553 self.update_graph(msg_id, success)
554 554
555 555 @logged
556 556 def handle_unmet_dependency(self, idents, parent):
557 557 """handle an unmet dependency"""
558 558 engine = idents[0]
559 559 msg_id = parent['msg_id']
560 560
561 561 if msg_id not in self.blacklist:
562 562 self.blacklist[msg_id] = set()
563 563 self.blacklist[msg_id].add(engine)
564 564
565 565 args = self.pending[engine].pop(msg_id)
566 566 raw,targets,after,follow,timeout = args
567 567
568 568 if self.blacklist[msg_id] == targets:
569 569 self.depending[msg_id] = args
570 570 self.fail_unreachable(msg_id)
571 571 elif not self.maybe_run(msg_id, *args):
572 572 # resubmit failed
573 573 if msg_id not in self.all_failed:
574 574 # put it back in our dependency tree
575 575 self.save_unmet(msg_id, *args)
576 576
577 577 if self.hwm:
578 578 try:
579 579 idx = self.targets.index(engine)
580 580 except ValueError:
581 581 pass # skip load-update for dead engines
582 582 else:
583 583 if self.loads[idx] == self.hwm-1:
584 584 self.update_graph(None)
585 585
586 586
587 587
588 588 @logged
589 589 def update_graph(self, dep_id=None, success=True):
590 590 """dep_id just finished. Update our dependency
591 591 graph and submit any jobs that just became runable.
592 592
593 593 Called with dep_id=None to update entire graph for hwm, but without finishing
594 594 a task.
595 595 """
596 596 # print ("\n\n***********")
597 597 # pprint (dep_id)
598 598 # pprint (self.graph)
599 599 # pprint (self.depending)
600 600 # pprint (self.all_completed)
601 601 # pprint (self.all_failed)
602 602 # print ("\n\n***********\n\n")
603 603 # update any jobs that depended on the dependency
604 604 jobs = self.graph.pop(dep_id, [])
605 605
606 606 # recheck *all* jobs if
607 607 # a) we have HWM and an engine just became no longer full
608 608 # or b) dep_id was given as None
609 609 if dep_id is None or (self.hwm and any(load == self.hwm-1 for load in self.loads)):
610 610 jobs = self.depending.keys()
611 611
612 612 for msg_id in jobs:
613 613 raw_msg, targets, after, follow, timeout = self.depending[msg_id]
614 614
615 if after.unreachable(self.all_completed, self.all_failed) or follow.unreachable(self.all_completed, self.all_failed):
615 if after.unreachable(self.all_completed, self.all_failed)\
616 or follow.unreachable(self.all_completed, self.all_failed):
616 617 self.fail_unreachable(msg_id)
617 618
618 619 elif after.check(self.all_completed, self.all_failed): # time deps met, maybe run
619 620 if self.maybe_run(msg_id, raw_msg, targets, MET, follow, timeout):
620 621
621 622 self.depending.pop(msg_id)
622 623 for mid in follow.union(after):
623 624 if mid in self.graph:
624 625 self.graph[mid].remove(msg_id)
625 626
626 627 #----------------------------------------------------------------------
627 628 # methods to be overridden by subclasses
628 629 #----------------------------------------------------------------------
629 630
630 631 def add_job(self, idx):
631 632 """Called after self.targets[idx] just got the job with header.
632 633 Override in subclasses. The default ordering is simple LRU.
633 634 The default loads are the number of outstanding jobs."""
634 635 self.loads[idx] += 1
635 636 for lis in (self.targets, self.loads):
636 637 lis.append(lis.pop(idx))
637 638
638 639
639 640 def finish_job(self, idx):
640 641 """Called after self.targets[idx] just finished a job.
641 642 Override in subclasses."""
642 643 self.loads[idx] -= 1
643 644
644 645
645 646
646 def launch_scheduler(in_addr, out_addr, mon_addr, not_addr, config=None,logname='ZMQ',
647 log_url=None, loglevel=logging.DEBUG,
648 identity=b'task'):
647 def launch_scheduler(in_addr, out_addr, mon_addr, not_addr, config=None,
648 logname='root', log_url=None, loglevel=logging.DEBUG,
649 identity=b'task'):
649 650 from zmq.eventloop import ioloop
650 651 from zmq.eventloop.zmqstream import ZMQStream
651 652
652 653 if config:
653 654 # unwrap dict back into Config
654 655 config = Config(config)
655 656
656 657 ctx = zmq.Context()
657 658 loop = ioloop.IOLoop()
658 659 ins = ZMQStream(ctx.socket(zmq.XREP),loop)
659 660 ins.setsockopt(zmq.IDENTITY, identity)
660 661 ins.bind(in_addr)
661 662
662 663 outs = ZMQStream(ctx.socket(zmq.XREP),loop)
663 664 outs.setsockopt(zmq.IDENTITY, identity)
664 665 outs.bind(out_addr)
665 666 mons = ZMQStream(ctx.socket(zmq.PUB),loop)
666 667 mons.connect(mon_addr)
667 668 nots = ZMQStream(ctx.socket(zmq.SUB),loop)
668 669 nots.setsockopt(zmq.SUBSCRIBE, '')
669 670 nots.connect(not_addr)
670 671
671 672 # setup logging. Note that these will not work in-process, because they clobber
672 673 # existing loggers.
673 674 if log_url:
674 connect_logger(logname, ctx, log_url, root="scheduler", loglevel=loglevel)
675 log = connect_logger(logname, ctx, log_url, root="scheduler", loglevel=loglevel)
675 676 else:
676 local_logger(logname, loglevel)
677 log = local_logger(logname, loglevel)
677 678
678 679 scheduler = TaskScheduler(client_stream=ins, engine_stream=outs,
679 680 mon_stream=mons, notifier_stream=nots,
680 loop=loop, logname=logname,
681 loop=loop, log=log,
681 682 config=config)
682 683 scheduler.start()
683 684 try:
684 685 loop.start()
685 686 except KeyboardInterrupt:
686 687 print ("interrupted, exiting...", file=sys.__stderr__)
687 688
@@ -1,99 +1,72 b''
1 1 """Base config factories."""
2 2
3 3 #-----------------------------------------------------------------------------
4 # Copyright (C) 2008-2009 The IPython Development Team
4 # Copyright (C) 2010-2011 The IPython Development Team
5 5 #
6 6 # Distributed under the terms of the BSD License. The full license is in
7 7 # the file COPYING, distributed as part of this software.
8 8 #-----------------------------------------------------------------------------
9 9
10 10 #-----------------------------------------------------------------------------
11 11 # Imports
12 12 #-----------------------------------------------------------------------------
13 13
14 14
15 15 import logging
16 16 import os
17 17
18 18 import zmq
19 19 from zmq.eventloop.ioloop import IOLoop
20 20
21 21 from IPython.config.configurable import Configurable
22 22 from IPython.utils.traitlets import Int, Instance, Unicode
23 23
24 24 from IPython.parallel.util import select_random_ports
25 from IPython.zmq.session import Session
25 from IPython.zmq.session import Session, SessionFactory
26 26
27 27 #-----------------------------------------------------------------------------
28 28 # Classes
29 29 #-----------------------------------------------------------------------------
30 class LoggingFactory(Configurable):
31 """A most basic class, that has a `log` (type:`Logger`) attribute, set via a `logname` Trait."""
32 log = Instance('logging.Logger', ('ZMQ', logging.WARN))
33 logname = Unicode('ZMQ')
34 def _logname_changed(self, name, old, new):
35 self.log = logging.getLogger(new)
36
37 30
38 class SessionFactory(LoggingFactory):
39 """The Base factory from which every factory in IPython.parallel inherits"""
40
41 # not configurable:
42 context = Instance('zmq.Context')
43 def _context_default(self):
44 return zmq.Context.instance()
45
46 session = Instance('IPython.zmq.session.Session')
47 loop = Instance('zmq.eventloop.ioloop.IOLoop', allow_none=False)
48 def _loop_default(self):
49 return IOLoop.instance()
50
51
52 def __init__(self, **kwargs):
53 super(SessionFactory, self).__init__(**kwargs)
54
55 # construct the session
56 self.session = Session(**kwargs)
57
58 31
59 32 class RegistrationFactory(SessionFactory):
60 33 """The Base Configurable for objects that involve registration."""
61 34
62 35 url = Unicode('', config=True,
63 36 help="""The 0MQ url used for registration. This sets transport, ip, and port
64 37 in one variable. For example: url='tcp://127.0.0.1:12345' or
65 38 url='epgm://*:90210'""") # url takes precedence over ip,regport,transport
66 39 transport = Unicode('tcp', config=True,
67 40 help="""The 0MQ transport for communications. This will likely be
68 41 the default of 'tcp', but other values include 'ipc', 'epgm', 'inproc'.""")
69 42 ip = Unicode('127.0.0.1', config=True,
70 43 help="""The IP address for registration. This is generally either
71 44 '127.0.0.1' for loopback only or '*' for all interfaces.
72 45 [default: '127.0.0.1']""")
73 46 regport = Int(config=True,
74 47 help="""The port on which the Hub listens for registration.""")
75 48 def _regport_default(self):
76 49 return select_random_ports(1)[0]
77 50
78 51 def __init__(self, **kwargs):
79 52 super(RegistrationFactory, self).__init__(**kwargs)
80 53 self._propagate_url()
81 54 self._rebuild_url()
82 55 self.on_trait_change(self._propagate_url, 'url')
83 56 self.on_trait_change(self._rebuild_url, 'ip')
84 57 self.on_trait_change(self._rebuild_url, 'transport')
85 58 self.on_trait_change(self._rebuild_url, 'regport')
86 59
87 60 def _rebuild_url(self):
88 61 self.url = "%s://%s:%i"%(self.transport, self.ip, self.regport)
89 62
90 63 def _propagate_url(self):
91 64 """Ensure self.url contains full transport://interface:port"""
92 65 if self.url:
93 66 iface = self.url.split('://',1)
94 67 if len(iface) == 2:
95 68 self.transport,iface = iface
96 69 iface = iface.split(':')
97 70 self.ip = iface[0]
98 71 if iface[1]:
99 72 self.regport = int(iface[1])
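For reference, a small sketch (not part of the diff; the address and port numbers are arbitrary) of the two-way url/ip/port propagation that RegistrationFactory implements above:

    f = RegistrationFactory(url='tcp://10.0.0.5:10101')
    f.transport, f.ip, f.regport    # -> ('tcp', '10.0.0.5', 10101), set by _propagate_url
    f.regport = 10102               # _rebuild_url fires on the trait change
    f.url                           # -> 'tcp://10.0.0.5:10102'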
@@ -1,466 +1,468 b''
1 1 """some generic utilities for dealing with classes, urls, and serialization"""
2 2 #-----------------------------------------------------------------------------
3 3 # Copyright (C) 2010-2011 The IPython Development Team
4 4 #
5 5 # Distributed under the terms of the BSD License. The full license is in
6 6 # the file COPYING, distributed as part of this software.
7 7 #-----------------------------------------------------------------------------
8 8
9 9 #-----------------------------------------------------------------------------
10 10 # Imports
11 11 #-----------------------------------------------------------------------------
12 12
13 13 # Standard library imports.
14 14 import logging
15 15 import os
16 16 import re
17 17 import stat
18 18 import socket
19 19 import sys
20 20 from signal import signal, SIGINT, SIGABRT, SIGTERM
21 21 try:
22 22 from signal import SIGKILL
23 23 except ImportError:
24 24 SIGKILL=None
25 25
26 26 try:
27 27 import cPickle
28 28 pickle = cPickle
29 29 except:
30 30 cPickle = None
31 31 import pickle
32 32
33 33 # System library imports
34 34 import zmq
35 35 from zmq.log import handlers
36 36
37 37 # IPython imports
38 38 from IPython.utils.pickleutil import can, uncan, canSequence, uncanSequence
39 39 from IPython.utils.newserialized import serialize, unserialize
40 40 from IPython.zmq.log import EnginePUBHandler
41 41
42 42 #-----------------------------------------------------------------------------
43 43 # Classes
44 44 #-----------------------------------------------------------------------------
45 45
46 46 class Namespace(dict):
47 47 """Subclass of dict for attribute access to keys."""
48 48
49 49 def __getattr__(self, key):
50 50 """getattr aliased to getitem"""
51 51 if key in self.iterkeys():
52 52 return self[key]
53 53 else:
54 54 raise NameError(key)
55 55
56 56 def __setattr__(self, key, value):
57 57 """setattr aliased to setitem, with strict"""
58 58 if hasattr(dict, key):
59 59 raise KeyError("Cannot override dict keys %r"%key)
60 60 self[key] = value
61 61
62 62
63 63 class ReverseDict(dict):
64 64 """simple double-keyed subset of dict methods."""
65 65
66 66 def __init__(self, *args, **kwargs):
67 67 dict.__init__(self, *args, **kwargs)
68 68 self._reverse = dict()
69 69 for key, value in self.iteritems():
70 70 self._reverse[value] = key
71 71
72 72 def __getitem__(self, key):
73 73 try:
74 74 return dict.__getitem__(self, key)
75 75 except KeyError:
76 76 return self._reverse[key]
77 77
78 78 def __setitem__(self, key, value):
79 79 if key in self._reverse:
80 80 raise KeyError("Can't have key %r on both sides!"%key)
81 81 dict.__setitem__(self, key, value)
82 82 self._reverse[value] = key
83 83
84 84 def pop(self, key):
85 85 value = dict.pop(self, key)
86 86 self._reverse.pop(value)
87 87 return value
88 88
89 89 def get(self, key, default=None):
90 90 try:
91 91 return self[key]
92 92 except KeyError:
93 93 return default
94 94
95 95 #-----------------------------------------------------------------------------
96 96 # Functions
97 97 #-----------------------------------------------------------------------------
98 98
99 99 def validate_url(url):
100 100 """validate a url for zeromq"""
101 101 if not isinstance(url, basestring):
102 102 raise TypeError("url must be a string, not %r"%type(url))
103 103 url = url.lower()
104 104
105 105 proto_addr = url.split('://')
106 106 assert len(proto_addr) == 2, 'Invalid url: %r'%url
107 107 proto, addr = proto_addr
108 108 assert proto in ['tcp','pgm','epgm','ipc','inproc'], "Invalid protocol: %r"%proto
109 109
110 110 # domain pattern adapted from http://www.regexlib.com/REDetails.aspx?regexp_id=391
111 111 # author: Remi Sabourin
112 112 pat = re.compile(r'^([\w\d]([\w\d\-]{0,61}[\w\d])?\.)*[\w\d]([\w\d\-]{0,61}[\w\d])?$')
113 113
114 114 if proto == 'tcp':
115 115 lis = addr.split(':')
116 116 assert len(lis) == 2, 'Invalid url: %r'%url
117 117 addr,s_port = lis
118 118 try:
119 119 port = int(s_port)
120 120 except ValueError:
121 121 raise AssertionError("Invalid port %r in url: %r"%(s_port, url))
122 122
123 123 assert addr == '*' or pat.match(addr) is not None, 'Invalid url: %r'%url
124 124
125 125 else:
126 126 # only validate tcp urls currently
127 127 pass
128 128
129 129 return True
130 130
131 131
132 132 def validate_url_container(container):
133 133 """validate a potentially nested collection of urls."""
134 134 if isinstance(container, basestring):
135 135 url = container
136 136 return validate_url(url)
137 137 elif isinstance(container, dict):
138 138 container = container.itervalues()
139 139
140 140 for element in container:
141 141 validate_url_container(element)
142 142
143 143
144 144 def split_url(url):
145 145 """split a zmq url (tcp://ip:port) into ('tcp','ip','port')."""
146 146 proto_addr = url.split('://')
147 147 assert len(proto_addr) == 2, 'Invalid url: %r'%url
148 148 proto, addr = proto_addr
149 149 lis = addr.split(':')
150 150 assert len(lis) == 2, 'Invalid url: %r'%url
151 151 addr,s_port = lis
152 152 return proto,addr,s_port
153 153
154 154 def disambiguate_ip_address(ip, location=None):
155 155 """turn multi-ip interfaces '0.0.0.0' and '*' into connectable
156 156 ones, based on the location (default interpretation of location is localhost)."""
157 157 if ip in ('0.0.0.0', '*'):
158 158 external_ips = socket.gethostbyname_ex(socket.gethostname())[2]
159 159 if location is None or location in external_ips:
160 160 ip='127.0.0.1'
161 161 elif location:
162 162 return location
163 163 return ip
164 164
165 165 def disambiguate_url(url, location=None):
166 166 """turn multi-ip interfaces '0.0.0.0' and '*' into connectable
167 167 ones, based on the location (default interpretation is localhost).
168 168
169 169 This is for zeromq urls, such as tcp://*:10101."""
170 170 try:
171 171 proto,ip,port = split_url(url)
172 172 except AssertionError:
173 173 # probably not tcp url; could be ipc, etc.
174 174 return url
175 175
176 176 ip = disambiguate_ip_address(ip,location)
177 177
178 178 return "%s://%s:%s"%(proto,ip,port)
179 179
180 180
181 181 def rekey(dikt):
182 182 """Rekey a dict that has been forced to use str keys where there should be
183 183 ints by json. This belongs in the jsonutil added by fperez."""
184 184 for k in dikt.iterkeys():
185 185 if isinstance(k, str):
186 186 ik=fk=None
187 187 try:
188 188 ik = int(k)
189 189 except ValueError:
190 190 try:
191 191 fk = float(k)
192 192 except ValueError:
193 193 continue
194 194 if ik is not None:
195 195 nk = ik
196 196 else:
197 197 nk = fk
198 198 if nk in dikt:
199 199 raise KeyError("already have key %r"%nk)
200 200 dikt[nk] = dikt.pop(k)
201 201 return dikt
202 202
203 203 def serialize_object(obj, threshold=64e-6):
204 204 """Serialize an object into a list of sendable buffers.
205 205
206 206 Parameters
207 207 ----------
208 208
209 209 obj : object
210 210 The object to be serialized
211 211 threshold : float
212 212 The threshold for not double-pickling the content.
213 213
214 214
215 215 Returns
216 216 -------
217 217 ('pmd', [bufs]) :
218 218 where pmd is the pickled metadata wrapper,
219 219 bufs is a list of data buffers
220 220 """
221 221 databuffers = []
222 222 if isinstance(obj, (list, tuple)):
223 223 clist = canSequence(obj)
224 224 slist = map(serialize, clist)
225 225 for s in slist:
226 226 if s.typeDescriptor in ('buffer', 'ndarray') or s.getDataSize() > threshold:
227 227 databuffers.append(s.getData())
228 228 s.data = None
229 229 return pickle.dumps(slist,-1), databuffers
230 230 elif isinstance(obj, dict):
231 231 sobj = {}
232 232 for k in sorted(obj.iterkeys()):
233 233 s = serialize(can(obj[k]))
234 234 if s.typeDescriptor in ('buffer', 'ndarray') or s.getDataSize() > threshold:
235 235 databuffers.append(s.getData())
236 236 s.data = None
237 237 sobj[k] = s
238 238 return pickle.dumps(sobj,-1),databuffers
239 239 else:
240 240 s = serialize(can(obj))
241 241 if s.typeDescriptor in ('buffer', 'ndarray') or s.getDataSize() > threshold:
242 242 databuffers.append(s.getData())
243 243 s.data = None
244 244 return pickle.dumps(s,-1),databuffers
245 245
246 246
247 247 def unserialize_object(bufs):
248 248 """reconstruct an object serialized by serialize_object from data buffers."""
249 249 bufs = list(bufs)
250 250 sobj = pickle.loads(bufs.pop(0))
251 251 if isinstance(sobj, (list, tuple)):
252 252 for s in sobj:
253 253 if s.data is None:
254 254 s.data = bufs.pop(0)
255 255 return uncanSequence(map(unserialize, sobj)), bufs
256 256 elif isinstance(sobj, dict):
257 257 newobj = {}
258 258 for k in sorted(sobj.iterkeys()):
259 259 s = sobj[k]
260 260 if s.data is None:
261 261 s.data = bufs.pop(0)
262 262 newobj[k] = uncan(unserialize(s))
263 263 return newobj, bufs
264 264 else:
265 265 if sobj.data is None:
266 266 sobj.data = bufs.pop(0)
267 267 return uncan(unserialize(sobj)), bufs
268 268
269 269 def pack_apply_message(f, args, kwargs, threshold=64e-6):
270 270 """pack up a function, args, and kwargs to be sent over the wire
271 271 as a series of buffers. Any object whose data is larger than `threshold`
272 272 will not have their data copied (currently only numpy arrays support zero-copy)"""
273 273 msg = [pickle.dumps(can(f),-1)]
274 274 databuffers = [] # for large objects
275 275 sargs, bufs = serialize_object(args,threshold)
276 276 msg.append(sargs)
277 277 databuffers.extend(bufs)
278 278 skwargs, bufs = serialize_object(kwargs,threshold)
279 279 msg.append(skwargs)
280 280 databuffers.extend(bufs)
281 281 msg.extend(databuffers)
282 282 return msg
283 283
284 284 def unpack_apply_message(bufs, g=None, copy=True):
285 285 """unpack f,args,kwargs from buffers packed by pack_apply_message()
286 286 Returns: original f,args,kwargs"""
287 287 bufs = list(bufs) # allow us to pop
288 288 assert len(bufs) >= 3, "not enough buffers!"
289 289 if not copy:
290 290 for i in range(3):
291 291 bufs[i] = bufs[i].bytes
292 292 cf = pickle.loads(bufs.pop(0))
293 293 sargs = list(pickle.loads(bufs.pop(0)))
294 294 skwargs = dict(pickle.loads(bufs.pop(0)))
295 295 # print sargs, skwargs
296 296 f = uncan(cf, g)
297 297 for sa in sargs:
298 298 if sa.data is None:
299 299 m = bufs.pop(0)
300 300 if sa.getTypeDescriptor() in ('buffer', 'ndarray'):
301 301 # always use a buffer, until memoryviews get sorted out
302 302 sa.data = buffer(m)
303 303 # disable memoryview support
304 304 # if copy:
305 305 # sa.data = buffer(m)
306 306 # else:
307 307 # sa.data = m.buffer
308 308 else:
309 309 if copy:
310 310 sa.data = m
311 311 else:
312 312 sa.data = m.bytes
313 313
314 314 args = uncanSequence(map(unserialize, sargs), g)
315 315 kwargs = {}
316 316 for k in sorted(skwargs.iterkeys()):
317 317 sa = skwargs[k]
318 318 if sa.data is None:
319 319 m = bufs.pop(0)
320 320 if sa.getTypeDescriptor() in ('buffer', 'ndarray'):
321 321 # always use a buffer, until memoryviews get sorted out
322 322 sa.data = buffer(m)
323 323 # disable memoryview support
324 324 # if copy:
325 325 # sa.data = buffer(m)
326 326 # else:
327 327 # sa.data = m.buffer
328 328 else:
329 329 if copy:
330 330 sa.data = m
331 331 else:
332 332 sa.data = m.bytes
333 333
334 334 kwargs[k] = uncan(unserialize(sa), g)
335 335
336 336 return f,args,kwargs
337 337
338 338 #--------------------------------------------------------------------------
339 339 # helpers for implementing old MEC API via view.apply
340 340 #--------------------------------------------------------------------------
341 341
342 342 def interactive(f):
343 343 """decorator for making functions appear as interactively defined.
344 344 This results in the function being linked to the user_ns as globals()
345 345 instead of the module globals().
346 346 """
347 347 f.__module__ = '__main__'
348 348 return f
349 349
350 350 @interactive
351 351 def _push(ns):
352 352 """helper method for implementing `client.push` via `client.apply`"""
353 353 globals().update(ns)
354 354
355 355 @interactive
356 356 def _pull(keys):
357 357 """helper method for implementing `client.pull` via `client.apply`"""
358 358 user_ns = globals()
359 359 if isinstance(keys, (list,tuple, set)):
360 360 for key in keys:
361 361 if not user_ns.has_key(key):
362 362 raise NameError("name '%s' is not defined"%key)
363 363 return map(user_ns.get, keys)
364 364 else:
365 365 if not user_ns.has_key(keys):
366 366 raise NameError("name '%s' is not defined"%keys)
367 367 return user_ns.get(keys)
368 368
369 369 @interactive
370 370 def _execute(code):
371 371 """helper method for implementing `client.execute` via `client.apply`"""
372 372 exec code in globals()
373 373
374 374 #--------------------------------------------------------------------------
375 375 # extra process management utilities
376 376 #--------------------------------------------------------------------------
377 377
378 378 _random_ports = set()
379 379
380 380 def select_random_ports(n):
381 381 """Selects and return n random ports that are available."""
382 382 ports = []
383 383 for i in xrange(n):
384 384 sock = socket.socket()
385 385 sock.bind(('', 0))
386 386 while sock.getsockname()[1] in _random_ports:
387 387 sock.close()
388 388 sock = socket.socket()
389 389 sock.bind(('', 0))
390 390 ports.append(sock)
391 391 for i, sock in enumerate(ports):
392 392 port = sock.getsockname()[1]
393 393 sock.close()
394 394 ports[i] = port
395 395 _random_ports.add(port)
396 396 return ports
397 397
398 398 def signal_children(children):
399 399 """Relay interupt/term signals to children, for more solid process cleanup."""
400 400 def terminate_children(sig, frame):
401 401 logging.critical("Got signal %i, terminating children..."%sig)
402 402 for child in children:
403 403 child.terminate()
404 404
405 405 sys.exit(sig != SIGINT)
406 406 # sys.exit(sig)
407 407 for sig in (SIGINT, SIGABRT, SIGTERM):
408 408 signal(sig, terminate_children)
409 409
410 410 def generate_exec_key(keyfile):
411 411 import uuid
412 412 newkey = str(uuid.uuid4())
413 413 with open(keyfile, 'w') as f:
414 414 # f.write('ipython-key ')
415 415 f.write(newkey+'\n')
416 416 # set user-only RW permissions (0600)
417 417 # this will have no effect on Windows
418 418 os.chmod(keyfile, stat.S_IRUSR|stat.S_IWUSR)
419 419
420 420
421 421 def integer_loglevel(loglevel):
422 422 try:
423 423 loglevel = int(loglevel)
424 424 except ValueError:
425 425 if isinstance(loglevel, str):
426 426 loglevel = getattr(logging, loglevel)
427 427 return loglevel
428 428
429 429 def connect_logger(logname, context, iface, root="ip", loglevel=logging.DEBUG):
430 430 logger = logging.getLogger(logname)
431 431 if any([isinstance(h, handlers.PUBHandler) for h in logger.handlers]):
432 432 # don't add a second PUBHandler
433 433 return
434 434 loglevel = integer_loglevel(loglevel)
435 435 lsock = context.socket(zmq.PUB)
436 436 lsock.connect(iface)
437 437 handler = handlers.PUBHandler(lsock)
438 438 handler.setLevel(loglevel)
439 439 handler.root_topic = root
440 440 logger.addHandler(handler)
441 441 logger.setLevel(loglevel)
442 442
443 443 def connect_engine_logger(context, iface, engine, loglevel=logging.DEBUG):
444 444 logger = logging.getLogger()
445 445 if any([isinstance(h, handlers.PUBHandler) for h in logger.handlers]):
446 446 # don't add a second PUBHandler
447 447 return logger
448 448 loglevel = integer_loglevel(loglevel)
449 449 lsock = context.socket(zmq.PUB)
450 450 lsock.connect(iface)
451 451 handler = EnginePUBHandler(engine, lsock)
452 452 handler.setLevel(loglevel)
453 453 logger.addHandler(handler)
454 454 logger.setLevel(loglevel)
455 return logger
455 456
456 457 def local_logger(logname, loglevel=logging.DEBUG):
457 458 loglevel = integer_loglevel(loglevel)
458 459 logger = logging.getLogger(logname)
459 460 if any([isinstance(h, logging.StreamHandler) for h in logger.handlers]):
460 461 # don't add a second StreamHandler
461 462 return logger
462 463 handler = logging.StreamHandler()
463 464 handler.setLevel(loglevel)
464 465 logger.addHandler(handler)
465 466 logger.setLevel(loglevel)
467 return logger
466 468
@@ -1,479 +1,511 b''
1 1 #!/usr/bin/env python
2 2 """edited session.py to work with streams, and move msg_type to the header
3 3 """
4 4 #-----------------------------------------------------------------------------
5 5 # Copyright (C) 2010-2011 The IPython Development Team
6 6 #
7 7 # Distributed under the terms of the BSD License. The full license is in
8 8 # the file COPYING, distributed as part of this software.
9 9 #-----------------------------------------------------------------------------
10 10
11 11 #-----------------------------------------------------------------------------
12 12 # Imports
13 13 #-----------------------------------------------------------------------------
14 14
15 15 import hmac
16 import logging
16 17 import os
17 18 import pprint
18 19 import uuid
19 20 from datetime import datetime
20 21
21 22 try:
22 23 import cPickle
23 24 pickle = cPickle
24 25 except:
25 26 cPickle = None
26 27 import pickle
27 28
28 29 import zmq
29 30 from zmq.utils import jsonapi
31 from zmq.eventloop.ioloop import IOLoop
30 32 from zmq.eventloop.zmqstream import ZMQStream
31 33
32 34 from IPython.config.configurable import Configurable
33 35 from IPython.utils.importstring import import_item
34 36 from IPython.utils.jsonutil import date_default
35 37 from IPython.utils.traitlets import CStr, Unicode, Bool, Any, Instance, Set
36 38
37 39 #-----------------------------------------------------------------------------
38 40 # utility functions
39 41 #-----------------------------------------------------------------------------
40 42
41 43 def squash_unicode(obj):
42 44 """coerce unicode back to bytestrings."""
43 45 if isinstance(obj,dict):
44 46 for key in obj.keys():
45 47 obj[key] = squash_unicode(obj[key])
46 48 if isinstance(key, unicode):
47 49 obj[squash_unicode(key)] = obj.pop(key)
48 50 elif isinstance(obj, list):
49 51 for i,v in enumerate(obj):
50 52 obj[i] = squash_unicode(v)
51 53 elif isinstance(obj, unicode):
52 54 obj = obj.encode('utf8')
53 55 return obj
54 56
55 57 #-----------------------------------------------------------------------------
56 58 # globals and defaults
57 59 #-----------------------------------------------------------------------------
58 60
59 61 _default_key = 'on_unknown' if jsonapi.jsonmod.__name__ == 'jsonlib' else 'default'
60 62 json_packer = lambda obj: jsonapi.dumps(obj, **{_default_key:date_default})
61 63 json_unpacker = lambda s: squash_unicode(jsonapi.loads(s))
62 64
63 65 pickle_packer = lambda o: pickle.dumps(o,-1)
64 66 pickle_unpacker = pickle.loads
65 67
66 68 default_packer = json_packer
67 69 default_unpacker = json_unpacker
68 70
69 71
70 72 DELIM="<IDS|MSG>"
71 73
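# Illustrative note (editor addition): round-tripping through the defaults above,
# json_unpacker(json_packer(dict(a=1))) returns {'a': 1}, with any unicode keys
# coerced back to bytestrings by squash_unicode.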
72 74 #-----------------------------------------------------------------------------
73 75 # Classes
74 76 #-----------------------------------------------------------------------------
75 77
78 class SessionFactory(Configurable):
79 """The Base class for configurables that have a Session, Context, logger,
80 and IOLoop.
81 """
82
83 log = Instance('logging.Logger', ('', logging.WARN))
84
85 logname = Unicode('')
86 def _logname_changed(self, name, old, new):
87 self.log = logging.getLogger(new)
88
89 # not configurable:
90 context = Instance('zmq.Context')
91 def _context_default(self):
92 return zmq.Context.instance()
93
94 session = Instance('IPython.zmq.session.Session')
95
96 loop = Instance('zmq.eventloop.ioloop.IOLoop', allow_none=False)
97 def _loop_default(self):
98 return IOLoop.instance()
99
100 def __init__(self, **kwargs):
101 super(SessionFactory, self).__init__(**kwargs)
102
103 if self.session is None:
104 # construct the session
105 self.session = Session(**kwargs)
106
107
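# Illustrative sketch (editor addition, not part of this changeset): a minimal
# consumer of SessionFactory. The subclass name and its start() method are
# hypothetical, chosen only to show what the factory provides.
#
#     class EchoFactory(SessionFactory):
#         def start(self):
#             # session, context, loop and log are all supplied by the factory
#             self.log.info("session id: %s" % self.session.session)
#             self.loop.start()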
76 108 class Message(object):
77 109 """A simple message object that maps dict keys to attributes.
78 110
79 111 A Message can be created from a dict and a dict from a Message instance
80 112 simply by calling dict(msg_obj)."""
81 113
82 114 def __init__(self, msg_dict):
83 115 dct = self.__dict__
84 116 for k, v in dict(msg_dict).iteritems():
85 117 if isinstance(v, dict):
86 118 v = Message(v)
87 119 dct[k] = v
88 120
89 121 # Having this iterator lets dict(msg_obj) work out of the box.
90 122 def __iter__(self):
91 123 return iter(self.__dict__.iteritems())
92 124
93 125 def __repr__(self):
94 126 return repr(self.__dict__)
95 127
96 128 def __str__(self):
97 129 return pprint.pformat(self.__dict__)
98 130
99 131 def __contains__(self, k):
100 132 return k in self.__dict__
101 133
102 134 def __getitem__(self, k):
103 135 return self.__dict__[k]
104 136
105 137
106 138 def msg_header(msg_id, msg_type, username, session):
107 139 date=datetime.now()
108 140 return locals()
109 141
110 142 def extract_header(msg_or_header):
111 143 """Given a message or header, return the header."""
112 144 if not msg_or_header:
113 145 return {}
114 146 try:
115 147 # See if msg_or_header is the entire message.
116 148 h = msg_or_header['header']
117 149 except KeyError:
118 150 try:
119 151 # See if msg_or_header is just the header
120 152 h = msg_or_header['msg_id']
121 153 except KeyError:
122 154 raise
123 155 else:
124 156 h = msg_or_header
125 157 if not isinstance(h, dict):
126 158 h = dict(h)
127 159 return h
128 160
129 161 class Session(Configurable):
130 162 """tweaked version of IPython.zmq.session.Session, for development in Parallel"""
131 163 debug=Bool(False, config=True, help="""Debug output in the Session""")
132 164 packer = Unicode('json',config=True,
133 165 help="""The name of the packer for serializing messages.
134 166 Should be one of 'json', 'pickle', or an import name
135 167 for a custom serializer.""")
136 168 def _packer_changed(self, name, old, new):
137 169 if new.lower() == 'json':
138 170 self.pack = json_packer
139 171 self.unpack = json_unpacker
140 172 elif new.lower() == 'pickle':
141 173 self.pack = pickle_packer
142 174 self.unpack = pickle_unpacker
143 175 else:
144 176 self.pack = import_item(new)
145 177
146 unpacker = Unicode('json',config=True,
178 unpacker = Unicode('json', config=True,
147 179 help="""The name of the unpacker for unserializing messages.
148 180 Only used with custom functions for `packer`.""")
149 181 def _unpacker_changed(self, name, old, new):
150 182 if new.lower() == 'json':
151 183 self.pack = json_packer
152 184 self.unpack = json_unpacker
153 185 elif new.lower() == 'pickle':
154 186 self.pack = pickle_packer
155 187 self.unpack = pickle_unpacker
156 188 else:
157 189 self.unpack = import_item(new)
158 190
159 session = CStr('',config=True,
191 session = CStr('', config=True,
160 192 help="""The UUID identifying this session.""")
161 193 def _session_default(self):
162 194 return bytes(uuid.uuid4())
163 195 username = Unicode(os.environ.get('USER','username'), config=True,
164 196 help="""Username for the Session. Default is your system username.""")
165 197
166 198 # message signature related traits:
167 199 key = CStr('', config=True,
168 200 help="""execution key, for extra authentication.""")
169 201 def _key_changed(self, name, old, new):
170 202 if new:
171 203 self.auth = hmac.HMAC(new)
172 204 else:
173 205 self.auth = None
174 206 auth = Instance(hmac.HMAC)
175 207 counters = Instance('collections.defaultdict', (int,))
176 208 digest_history = Set()
177 209
178 210 keyfile = Unicode('', config=True,
179 211 help="""path to file containing execution key.""")
180 212 def _keyfile_changed(self, name, old, new):
181 213 with open(new, 'rb') as f:
182 214 self.key = f.read().strip()
183 215
184 216 pack = Any(default_packer) # the actual packer function
185 217 def _pack_changed(self, name, old, new):
186 218 if not callable(new):
187 219 raise TypeError("packer must be callable, not %s"%type(new))
188 220
189 221 unpack = Any(default_unpacker) # the actual unpacker function
190 222 def _unpack_changed(self, name, old, new):
191 223 if not callable(new):
192 224 raise TypeError("packer must be callable, not %s"%type(new))
193 225
194 226 def __init__(self, **kwargs):
195 227 super(Session, self).__init__(**kwargs)
196 228 self.none = self.pack({})
197 229
198 230 @property
199 231 def msg_id(self):
200 232 """always return new uuid"""
201 233 return str(uuid.uuid4())
202 234
203 235 def msg_header(self, msg_type):
204 236 return msg_header(self.msg_id, msg_type, self.username, self.session)
205 237
206 238 def msg(self, msg_type, content=None, parent=None, subheader=None):
207 239 msg = {}
208 240 msg['header'] = self.msg_header(msg_type)
209 241 msg['msg_id'] = msg['header']['msg_id']
210 242 msg['parent_header'] = {} if parent is None else extract_header(parent)
211 243 msg['msg_type'] = msg_type
212 244 msg['content'] = {} if content is None else content
213 245 sub = {} if subheader is None else subheader
214 246 msg['header'].update(sub)
215 247 return msg
216 248
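# Illustrative note (editor addition): a dict built by msg() above, e.g.
# Session().msg('execute_request', content=dict(code='1+1')), carries the keys
# 'header' (msg_id, msg_type, username, session, date, plus any subheader),
# 'msg_id', 'msg_type', 'parent_header', and 'content'.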
217 249 def check_key(self, msg_or_header):
218 250 """Check that a message's header has the right key"""
219 251 if not self.key:
220 252 return True
221 253 header = extract_header(msg_or_header)
222 254 return header.get('key', '') == self.key
223 255
224 256 def sign(self, msg):
225 257 """Sign a message with HMAC digest. If no auth, return b''."""
226 258 if self.auth is None:
227 259 return b''
228 260 h = self.auth.copy()
229 261 for m in msg:
230 262 h.update(m)
231 263 return h.hexdigest()
232 264
233 265 def serialize(self, msg, ident=None):
234 266 content = msg.get('content', {})
235 267 if content is None:
236 268 content = self.none
237 269 elif isinstance(content, dict):
238 270 content = self.pack(content)
239 271 elif isinstance(content, bytes):
240 272 # content is already packed, as in a relayed message
241 273 pass
242 274 elif isinstance(content, unicode):
243 275 # should be bytes, but JSON often spits out unicode
244 276 content = content.encode('utf8')
245 277 else:
246 278 raise TypeError("Content incorrect type: %s"%type(content))
247 279
248 280 real_message = [self.pack(msg['header']),
249 281 self.pack(msg['parent_header']),
250 282 content
251 283 ]
252 284
253 285 to_send = []
254 286
255 287 if isinstance(ident, list):
256 288 # accept list of idents
257 289 to_send.extend(ident)
258 290 elif ident is not None:
259 291 to_send.append(ident)
260 292 to_send.append(DELIM)
261 293
262 294 signature = self.sign(real_message)
263 295 to_send.append(signature)
264 296
265 297 to_send.extend(real_message)
266 298
267 299 return to_send
268 300
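# Frame layout produced by serialize() above and put on the wire by send() below
# (editor note, derived from the code):
#     [ident, ..., DELIM, signature, pack(header), pack(parent_header), content]
# with any raw buffers appended after the content frame by send().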
269 301 def send(self, stream, msg_or_type, content=None, parent=None, ident=None,
270 302 buffers=None, subheader=None, track=False):
271 303 """Build and send a message via stream or socket.
272 304
273 305 Parameters
274 306 ----------
275 307
276 308 stream : zmq.Socket or ZMQStream
277 309 the socket-like object used to send the data
278 310 msg_or_type : str or Message/dict
279 311 Normally, msg_or_type will be a msg_type unless a message is being sent more
280 312 than once.
281 313
282 314 content : dict or None
283 315 the content of the message (ignored if msg_or_type is a message)
284 316 parent : Message or dict or None
285 317 the parent or parent header describing the parent of this message
286 318 ident : bytes or list of bytes
287 319 the zmq.IDENTITY routing path
288 320 subheader : dict or None
289 321 extra header keys for this message's header
290 322 buffers : list or None
291 323 the already-serialized buffers to be appended to the message
292 324 track : bool
293 325 whether to track. Only for use with Sockets,
294 326 because ZMQStream objects cannot track messages.
295 327
296 328 Returns
297 329 -------
298 330 msg : message dict
299 331 the constructed message
300 332 (msg,tracker) : (message dict, MessageTracker)
301 333 if track=True, then a 2-tuple will be returned,
302 334 the first element being the constructed
303 335 message, and the second being the MessageTracker
304 336
305 337 """
306 338
307 339 if not isinstance(stream, (zmq.Socket, ZMQStream)):
308 340 raise TypeError("stream must be Socket or ZMQStream, not %r"%type(stream))
309 341 elif track and isinstance(stream, ZMQStream):
310 342 raise TypeError("ZMQStream cannot track messages")
311 343
312 344 if isinstance(msg_or_type, (Message, dict)):
313 345 # we got a Message, not a msg_type
314 346 # don't build a new Message
315 347 msg = msg_or_type
316 348 else:
317 349 msg = self.msg(msg_or_type, content, parent, subheader)
318 350
319 351 buffers = [] if buffers is None else buffers
320 352 to_send = self.serialize(msg, ident)
321 353 flag = 0
322 354 if buffers:
323 355 flag = zmq.SNDMORE
324 356 _track = False
325 357 else:
326 358 _track=track
327 359 if track:
328 360 tracker = stream.send_multipart(to_send, flag, copy=False, track=_track)
329 361 else:
330 362 tracker = stream.send_multipart(to_send, flag, copy=False)
331 363 for b in buffers[:-1]:
332 364 stream.send(b, flag, copy=False)
333 365 if buffers:
334 366 if track:
335 367 tracker = stream.send(buffers[-1], copy=False, track=track)
336 368 else:
337 369 tracker = stream.send(buffers[-1], copy=False)
338 370
339 371 # omsg = Message(msg)
340 372 if self.debug:
341 373 pprint.pprint(msg)
342 374 pprint.pprint(to_send)
343 375 pprint.pprint(buffers)
344 376
345 377 msg['tracker'] = tracker
346 378
347 379 return msg
348 380
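# Illustrative calls (editor addition; socket names, msg_type and idents are placeholders):
#     session.send(stream, 'ping', content={})
#     session.send(router, 'apply_request', content={},
#                  ident=b'engine.0', buffers=[raw_bytes])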
349 381 def send_raw(self, stream, msg, flags=0, copy=True, ident=None):
350 382 """Send a raw message via ident path.
351 383
352 384 Parameters
353 385 ----------
354 386 msg : list of sendable buffers"""
355 387 to_send = []
356 388 if isinstance(ident, bytes):
357 389 ident = [ident]
358 390 if ident is not None:
359 391 to_send.extend(ident)
360 392
361 393 to_send.append(DELIM)
362 394 to_send.append(self.sign(msg))
363 395 to_send.extend(msg)
364 396 stream.send_multipart(to_send, flags, copy=copy)
365 397
366 398 def recv(self, socket, mode=zmq.NOBLOCK, content=True, copy=True):
367 399 """receives and unpacks a message
368 400 returns [idents], msg"""
369 401 if isinstance(socket, ZMQStream):
370 402 socket = socket.socket
371 403 try:
372 404 msg = socket.recv_multipart(mode)
373 405 except zmq.ZMQError as e:
374 406 if e.errno == zmq.EAGAIN:
375 407 # We can convert EAGAIN to None as we know in this case
376 408 # recv_multipart won't return None.
377 409 return None,None
378 410 else:
379 411 raise
380 412 # return an actual Message object
381 413 # determine the number of idents by trying to unpack them.
382 414 # this is terrible:
383 415 idents, msg = self.feed_identities(msg, copy)
384 416 try:
385 417 return idents, self.unpack_message(msg, content=content, copy=copy)
386 418 except Exception as e:
387 419 print (idents, msg)
388 420 # TODO: handle it
389 421 raise e
390 422
391 423 def feed_identities(self, msg, copy=True):
392 424 """feed until DELIM is reached, then return the prefix as idents and remainder as
393 425 msg. This is easily broken by setting an IDENT to DELIM, but that would be silly.
394 426
395 427 Parameters
396 428 ----------
397 429 msg : a list of Message or bytes objects
398 430 the message to be split
399 431 copy : bool
400 432 flag determining whether the arguments are bytes or Messages
401 433
402 434 Returns
403 435 -------
404 436 (idents,msg) : two lists
405 437 idents will always be a list of bytes - the identity prefix
406 438 msg will be a list of bytes or Messages, unchanged from input
407 439 msg should be unpackable via self.unpack_message at this point.
408 440 """
409 441 if copy:
410 442 idx = msg.index(DELIM)
411 443 return msg[:idx], msg[idx+1:]
412 444 else:
413 445 failed = True
414 446 for idx,m in enumerate(msg):
415 447 if m.bytes == DELIM:
416 448 failed = False
417 449 break
418 450 if failed:
419 451 raise ValueError("DELIM not in msg")
420 452 idents, msg = msg[:idx], msg[idx+1:]
421 453 return [m.bytes for m in idents], msg
422 454
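# Illustrative note (editor addition): with copy=True this is a plain list split, e.g.
#     feed_identities([b'engine.0', DELIM, b'', b'{}', b'{}', b'{}'])
#     returns ([b'engine.0'], [b'', b'{}', b'{}', b'{}'])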
423 455 def unpack_message(self, msg, content=True, copy=True):
424 456 """Return a message object from the format
425 457 sent by self.send.
426 458
427 459 Parameters
428 460 ----------
429 461
430 462 content : bool (True)
431 463 whether to unpack the content dict (True),
432 464 or leave it serialized (False)
433 465
434 466 copy : bool (True)
435 467 whether to return the bytes (True),
436 468 or the non-copying Message object in each place (False)
437 469
438 470 """
439 471 minlen = 4
440 472 message = {}
441 473 if not copy:
442 474 for i in range(minlen):
443 475 msg[i] = msg[i].bytes
444 476 if self.auth is not None:
445 477 signature = msg[0]
446 478 if signature in self.digest_history:
447 479 raise ValueError("Duplicate Signature: %r"%signature)
448 480 self.digest_history.add(signature)
449 481 check = self.sign(msg[1:4])
450 482 if not signature == check:
451 483 raise ValueError("Invalid Signature: %r"%signature)
452 484 if not len(msg) >= minlen:
453 485 raise TypeError("malformed message, must have at least %i elements"%minlen)
454 486 message['header'] = self.unpack(msg[1])
455 487 message['msg_type'] = message['header']['msg_type']
456 488 message['parent_header'] = self.unpack(msg[2])
457 489 if content:
458 490 message['content'] = self.unpack(msg[3])
459 491 else:
460 492 message['content'] = msg[3]
461 493
462 494 message['buffers'] = msg[4:]
463 495 return message
464 496
465 497 def test_msg2obj():
466 498 am = dict(x=1)
467 499 ao = Message(am)
468 500 assert ao.x == am['x']
469 501
470 502 am['y'] = dict(z=1)
471 503 ao = Message(am)
472 504 assert ao.y.z == am['y']['z']
473 505
474 506 k1, k2 = 'y', 'z'
475 507 assert ao[k1][k2] == am[k1][k2]
476 508
477 509 am2 = dict(ao)
478 510 assert am['x'] == am2['x']
479 511 assert am['y']['z'] == am2['y']['z']
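# Editor-added sketch, not part of the original changeset: a send()/recv()
# round trip over an in-process PAIR pair, in the same spirit as test_msg2obj
# above. The inproc endpoint name is arbitrary.
def test_send_recv():
    ctx = zmq.Context.instance()
    a = ctx.socket(zmq.PAIR)
    b = ctx.socket(zmq.PAIR)
    a.bind('inproc://session-roundtrip')
    b.connect('inproc://session-roundtrip')
    s = Session()
    s.send(a, 'ping', content=dict(n=1))
    # mode=0 blocks instead of the default zmq.NOBLOCK
    idents, received = s.recv(b, mode=0)
    assert idents == []
    assert received['msg_type'] == 'ping'
    assert received['content']['n'] == 1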